
MLflow Plugins

MLflow's plugin architecture enables seamless integration with third-party tools and custom infrastructure. As a framework-agnostic platform, MLflow provides developer APIs for extending functionality such as storage, authentication, execution backends, and model evaluation.

Quickstart

Installing and Using Plugins

Try the built-in test plugin to see how plugins work:

# Clone MLflow and install example plugin
git clone https://github.com/mlflow/mlflow
cd mlflow
pip install -e tests/resources/mlflow-test-plugin
# Use the plugin with custom tracking URI scheme
MLFLOW_TRACKING_URI=file-plugin:$(pwd)/mlruns python examples/quickstart/mlflow_tracking.py

# Launch MLflow UI to view results
mlflow server --backend-store-uri ./mlruns

Open http://localhost:5000 to view your tracked experiment.

[Image: Quickstart UI]

Plugin Benefits

Plugins let you integrate MLflow with your existing infrastructure without modifying MLflow's core code, ensuring smooth upgrades and maintenance.

Plugin Types & Use Cases

MLflow supports eight types of plugins, each serving a different integration need.

Storage and Persistence

| Plugin Type | Purpose | Example Use Cases |
| --- | --- | --- |
| Tracking Store | Custom experiment data storage | Enterprise databases, cloud data warehouses |
| Artifact Repository | Custom artifact storage | Internal blob storage, specialized filesystems |
| Model Registry Store | Custom model registry backend | Enterprise model catalogs, version control systems |

Authentication and Headers

| Plugin Type | Purpose | Example Use Cases |
| --- | --- | --- |
| Request Auth Provider | Custom authentication | OAuth, API keys, certificate-based auth |
| Request Header Provider | Custom HTTP headers | Environment identification, compliance headers |
| Run Context Provider | Automatic run metadata | Git info, environment details, custom tags |

Execution and Evaluation

| Plugin Type | Purpose | Example Use Cases |
| --- | --- | --- |
| Project Backend | Custom execution environments | Internal clusters, job schedulers, cloud platforms |
| Model Evaluator | Custom evaluation metrics | Domain-specific validation, custom test suites |
| Deployment | Custom serving platforms | Internal serving infrastructure, edge deployment |

Developing Custom Plugins

Plugin Structure

Create your plugin as a standalone Python package:

# setup.py
from setuptools import setup

setup(
    name="my-mlflow-plugin",
    version="0.1.0",
    install_requires=["mlflow>=2.0.0"],
    entry_points={
        # Define plugin entry points
        "mlflow.tracking_store": "my-scheme=my_plugin.store:MyTrackingStore",
        "mlflow.artifact_repository": "my-scheme=my_plugin.artifacts:MyArtifactRepo",
        "mlflow.run_context_provider": "unused=my_plugin.context:MyContextProvider",
        "mlflow.request_auth_provider": "unused=my_plugin.auth:MyAuthProvider",
        "mlflow.model_evaluator": "my-evaluator=my_plugin.evaluator:MyEvaluator",
        "mlflow.project_backend": "my-backend=my_plugin.backend:MyBackend",
        "mlflow.deployments": "my-target=my_plugin.deployment",
        "mlflow.app": "my-app=my_plugin.app:create_app",
    },
)
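
Once these entry points are declared, a standard pip install is all MLflow needs to discover the plugin; no registration call is required. During development, an editable install works the same way:

# Install the plugin package; MLflow picks up the entry points automatically
pip install -e .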

Storage Plugins

# my_plugin/store.py
from mlflow.store.tracking.abstract_store import AbstractStore


class MyTrackingStore(AbstractStore):
    """Custom tracking store for scheme 'my-scheme://'"""

    def __init__(self, store_uri):
        super().__init__()
        self.store_uri = store_uri
        # Initialize your custom storage backend

    def create_experiment(self, name, artifact_location=None, tags=None):
        # Implement experiment creation logic
        pass

    def log_metric(self, run_id, metric):
        # Implement metric logging logic
        pass

    def log_param(self, run_id, param):
        # Implement parameter logging logic
        pass

    # Implement other required AbstractStore methods...
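
The setup.py above also registers an artifact repository for the same my-scheme URI scheme. A minimal sketch of that class, assuming MLflow's ArtifactRepository base class and a hypothetical _upload helper for your storage backend:

# my_plugin/artifacts.py
from mlflow.store.artifact.artifact_repo import ArtifactRepository


class MyArtifactRepo(ArtifactRepository):
    """Custom artifact repository for 'my-scheme://' artifact URIs"""

    def log_artifact(self, local_file, artifact_path=None):
        # Upload a single local file to the backing store
        self._upload(local_file, artifact_path)  # hypothetical helper

    def log_artifacts(self, local_dir, artifact_path=None):
        # Upload a local directory tree of artifacts
        pass

    def list_artifacts(self, path=None):
        # Return FileInfo entities for artifacts under `path`
        return []

    def _download_file(self, remote_file_path, local_path):
        # Fetch a single remote file into `local_path`
        pass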

Authentication Plugins

# my_plugin/auth.py
from mlflow.tracking.request_auth.abstract_request_auth_provider import (
    RequestAuthProvider,
)


class MyAuthProvider(RequestAuthProvider):
    """Custom authentication provider"""

    def get_name(self):
        return "my_auth_provider"

    def get_auth(self):
        # Return authentication object for HTTP requests
        # Can be anything that requests.auth accepts
        return MyCustomAuth()


class MyCustomAuth:
    """Custom authentication class"""

    def __call__(self, request):
        # Add authentication headers to request
        token = self._get_token()
        request.headers["Authorization"] = f"Bearer {token}"
        return request

    def _get_token(self):
        # Implement token retrieval logic
        # E.g., read from file, environment, or API call
        pass

Usage

export MLFLOW_TRACKING_AUTH=my_auth_provider
python your_mlflow_script.py
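
The request header and run context providers from the tables above follow the same pattern as the auth provider; each implements a small abstract base class. A sketch, assuming MLflow's RequestHeaderProvider and RunContextProvider base classes (the header provider is registered under the mlflow.request_header_provider entry-point group):

# my_plugin/headers.py
from mlflow.tracking.request_header.abstract_request_header_provider import (
    RequestHeaderProvider,
)


class MyHeaderProvider(RequestHeaderProvider):
    """Attach custom HTTP headers to outgoing MLflow requests"""

    def in_context(self):
        # Return True when these headers should be applied
        return True

    def request_headers(self):
        return {"X-My-Environment": "staging"}


# my_plugin/context.py
from mlflow.tracking.context.abstract_context import RunContextProvider


class MyContextProvider(RunContextProvider):
    """Add custom tags to every run created from this environment"""

    def in_context(self):
        return True

    def tags(self):
        # Tags merged into each newly created run
        return {"my.plugin.cluster": "on-prem"}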

Execution Plugins

Project Backend Plugin

# my_plugin/backend.py
from mlflow.projects.backend import AbstractBackend
from mlflow.projects.submitted_run import SubmittedRun


class MyBackend(AbstractBackend):
    """Custom execution backend for MLflow Projects"""

    def run(
        self,
        project_uri,
        entry_point,
        parameters,
        version,
        backend_config,
        tracking_uri,
        experiment_id,
    ):
        """Execute project on custom infrastructure"""

        # Parse backend configuration
        cluster_config = backend_config.get("cluster_config", {})

        # Submit job to your execution system
        job_id = self._submit_job(
            project_uri=project_uri,
            entry_point=entry_point,
            parameters=parameters,
            cluster_config=cluster_config,
        )

        # Return SubmittedRun for monitoring
        return MySubmittedRun(job_id, tracking_uri)

    def _submit_job(self, project_uri, entry_point, parameters, cluster_config):
        # Implement job submission to your infrastructure
        # Return job ID for monitoring
        pass


class MySubmittedRun(SubmittedRun):
    """Handle for submitted run"""

    def __init__(self, job_id, tracking_uri):
        self.job_id = job_id
        self.tracking_uri = tracking_uri
        super().__init__()

    def wait(self):
        # Wait for job completion and return success status
        return self._poll_job_status()

    def cancel(self):
        # Cancel the running job
        self._cancel_job()

    def get_status(self):
        # Get current job status
        return self._get_job_status()
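
Usage

With the plugin installed, select the backend by name on the MLflow Projects CLI. The JSON file passed via --backend-config is what arrives as the backend_config dict in run() above (cluster_config.json is a placeholder name):

# Run a project on the custom backend registered as "my-backend"
mlflow run . --backend my-backend --backend-config cluster_config.json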

Model Evaluator Plugin

# my_plugin/evaluator.py
from mlflow.models.evaluation import ModelEvaluator
from mlflow.models import EvaluationResult


class MyEvaluator(ModelEvaluator):
    """Custom model evaluator"""

    def can_evaluate(self, *, model_type, evaluator_config, **kwargs):
        """Check if this evaluator can handle the model type"""
        supported_types = ["classifier", "regressor"]
        return model_type in supported_types

    def evaluate(
        self, *, model, model_type, dataset, run_id, evaluator_config, **kwargs
    ):
        """Perform custom evaluation"""

        # Get predictions
        predictions = model.predict(dataset.features_data)

        # Compute custom metrics
        metrics = self._compute_custom_metrics(
            predictions, dataset.labels_data, evaluator_config
        )

        # Generate custom artifacts
        artifacts = self._generate_artifacts(predictions, dataset, evaluator_config)

        return EvaluationResult(metrics=metrics, artifacts=artifacts)

    def _compute_custom_metrics(self, predictions, labels, config):
        # Implement domain-specific metrics
        return {
            "custom_score": self._calculate_custom_score(predictions, labels),
            "business_metric": self._calculate_business_metric(predictions, labels),
        }

    def _generate_artifacts(self, predictions, dataset, config):
        # Generate custom plots, reports, etc.
        return {}
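
Usage

A custom evaluator is selected by its registered name via mlflow.evaluate. A sketch, assuming a pandas DataFrame eval_df with a "label" column (both hypothetical):

import mlflow

# Evaluate a model with the evaluator registered as "my-evaluator"
result = mlflow.evaluate(
    model="models:/my_model/1",
    data=eval_df,
    targets="label",
    model_type="classifier",
    evaluators=["my-evaluator"],
)

Deployment Plugin

The mlflow.deployments entry point in setup.py points at a module rather than a class: MLflow expects the module to expose a deployment client plus the module-level functions run_local and target_help. A sketch under that assumption (method signatures vary slightly across MLflow versions):

# my_plugin/deployment.py
from mlflow.deployments import BaseDeploymentClient


class MyDeploymentClient(BaseDeploymentClient):
    """Client for the custom 'my-target' deployment target"""

    def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
        # Deploy the model and return a dict describing the deployment
        return {"name": name, "flavor": flavor}

    def update_deployment(self, name, model_uri=None, flavor=None, config=None, endpoint=None):
        return {"name": name, "flavor": flavor}

    def delete_deployment(self, name, config=None, endpoint=None):
        pass

    def list_deployments(self, endpoint=None):
        return []

    def get_deployment(self, name, endpoint=None):
        return {"name": name}

    def predict(self, deployment_name=None, inputs=None, endpoint=None):
        # Score `inputs` against the deployed model
        pass


def run_local(name, model_uri, flavor=None, config=None):
    # Deploy the model locally for testing
    raise NotImplementedError


def target_help():
    return "Help text for the 'my-target' deployment target"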

SQL Server Plugin

Store artifacts directly in a SQL Server database:

pip install mlflow[sqlserver]

import mlflow

# Use SQL Server as artifact store
db_uri = "mssql+pyodbc://user:pass@host:port/db?driver=ODBC+Driver+17+for+SQL+Server"
mlflow.create_experiment("sql_experiment", artifact_location=db_uri)

with mlflow.start_run():
    mlflow.onnx.log_model(model, name="model")  # Stored as BLOB in SQL Server

Alibaba Cloud OSS Plugin

Integrate with Alibaba Cloud Object Storage Service:

pip install mlflow[aliyun-oss]

import os
import mlflow

# Configure OSS credentials
os.environ["MLFLOW_OSS_ENDPOINT_URL"] = "https://oss-region.aliyuncs.com"
os.environ["MLFLOW_OSS_KEY_ID"] = "your_access_key"
os.environ["MLFLOW_OSS_KEY_SECRET"] = "your_secret_key"

# Use OSS as artifact store
mlflow.create_experiment("oss_experiment", artifact_location="oss://bucket/path")

XetHub Plugin

Use XetHub for versioned artifact storage:

pip install mlflow[xethub]

import mlflow

# Authenticate with XetHub (via CLI or environment variables)
mlflow.create_experiment(
    "xet_experiment", artifact_location="xet://username/repo/branch"
)

Elasticsearch Plugin

Use Elasticsearch for experiment tracking:

pip install mlflow-elasticsearchstore

Testing Your Plugin

# tests/test_my_plugin.py
import pytest
import mlflow
from my_plugin.store import MyTrackingStore


class TestMyTrackingStore:
    def setup_method(self):
        self.store = MyTrackingStore("my-scheme://test")

    def test_create_experiment(self):
        experiment_id = self.store.create_experiment("test_exp")
        assert experiment_id is not None

    def test_log_metric(self):
        experiment_id = self.store.create_experiment("test_exp")
        run = self.store.create_run(
            experiment_id, user_id="user", start_time=12345, tags=[], run_name="test_run"
        )

        metric = mlflow.entities.Metric("accuracy", 0.95, 12345, 0)
        self.store.log_metric(run.info.run_id, metric)

        # Verify metric was logged correctly
        stored_run = self.store.get_run(run.info.run_id)
        assert "accuracy" in stored_run.data.metrics
        assert stored_run.data.metrics["accuracy"] == 0.95

    def test_log_artifact(self):
        # Test artifact logging functionality
        pass
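
An end-to-end check can also drive the plugin through the public MLflow API rather than the store class directly, assuming the package is installed so the my-scheme URI resolves to MyTrackingStore:

# tests/test_integration.py
import mlflow


def test_end_to_end_logging():
    mlflow.set_tracking_uri("my-scheme://test")

    with mlflow.start_run() as run:
        mlflow.log_param("alpha", 0.5)
        mlflow.log_metric("accuracy", 0.95)

    # Read the run back through the same store
    finished = mlflow.get_run(run.info.run_id)
    assert finished.data.metrics["accuracy"] == 0.95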

Distribution and Publishing

Package Structure

my-mlflow-plugin/
├── setup.py              # Package configuration
├── README.md             # Plugin documentation
├── my_plugin/
│   ├── __init__.py
│   ├── store.py          # Tracking store implementation
│   ├── artifacts.py      # Artifact repository implementation
│   ├── auth.py           # Authentication provider
│   └── evaluator.py      # Model evaluator
├── tests/
│   ├── test_store.py
│   ├── test_artifacts.py
│   └── test_integration.py
└── examples/
    └── example_usage.py

Publishing to PyPI

# Build distribution packages
python setup.py sdist bdist_wheel

# Upload to PyPI
pip install twine
twine upload dist/*
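
Before uploading, it is worth confirming that the installed package actually exposes its entry points. A quick check with importlib.metadata (the group keyword requires Python 3.10+):

# Verify that MLflow can discover the tracking store entry point
from importlib.metadata import entry_points

for ep in entry_points(group="mlflow.tracking_store"):
    print(ep.name, "->", ep.value)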

Documentation Template

# My MLflow Plugin

## Installation
```bash
pip install my-mlflow-plugin
```

## Configuration
```bash
export MY_PLUGIN_CONFIG="value"
```

## Usage
```python
import mlflow

mlflow.set_tracking_uri("my-scheme://config")
```

## Features
- Feature 1
- Feature 2

## Examples
See the examples/ directory for complete usage examples.

Best Practices

Plugin Development

  • Follow the MLflow interfaces - implement all required abstract methods
  • Handle errors gracefully - provide clear error messages for configuration problems
  • Support authentication - integrate with existing credential systems
  • Add comprehensive logging - help users debug configuration issues
  • Maintain version compatibility - test against multiple MLflow versions

Performance Optimization

  • Batch operations - implement efficient batched logging where possible (see the sketch after this list)
  • Connection pooling - reuse connections to external systems
  • Async operations - use async I/O for storage operations where it pays off
  • Caching - cache frequently accessed metadata
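
As a concrete example of the first point, AbstractStore exposes a log_batch method that custom stores can override to write many entities in one round trip. A minimal sketch for the tracking store from earlier, assuming a hypothetical _bulk_insert helper on your backend:

# In MyTrackingStore (my_plugin/store.py):
    def log_batch(self, run_id, metrics, params, tags):
        # Persist all entities in one backend call instead of one call per value
        self._bulk_insert(  # hypothetical bulk-write helper
            run_id,
            metrics=[(m.key, m.value, m.timestamp, m.step) for m in metrics],
            params=[(p.key, p.value) for p in params],
            tags=[(t.key, t.value) for t in tags],
        )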

Security Considerations

  • Credential management - never log or expose sensitive credentials
  • Input validation - validate all user input and URIs
  • Access control - respect existing authentication and authorization
  • Secure communication - use TLS/SSL for network communication

Testing Strategy

  • Unit tests - test individual plugin components
  • Integration tests - test complete workflows against MLflow
  • Performance tests - verify acceptable performance characteristics
  • Compatibility tests - test with different MLflow versions

Ready to extend MLflow? Start with the example test plugin to see how every plugin type works, then build your own custom integration!