MLflow Plugins

MLflow's plugin architecture enables seamless integration with third-party tools and custom infrastructure. As a framework-agnostic platform, MLflow exposes developer APIs for extending functionality such as storage, authentication, execution backends, and model evaluation.

Quickstart

Installing and Using a Plugin

Try the built-in test plugin to see how plugins work:

```bash
# Clone MLflow and install the example plugin
git clone https://github.com/mlflow/mlflow
cd mlflow
pip install -e . tests/resources/mlflow-test-plugin

# Use the plugin with a custom tracking URI scheme
MLFLOW_TRACKING_URI=file-plugin:$(pwd)/mlruns python examples/quickstart/mlflow_tracking.py

# Launch the MLflow UI to view results
mlflow server --backend-store-uri ./mlruns
```

Open http://localhost:5000 to view your tracked experiments.

Figure: Quickstart UI

Plugin Benefits

Plugins let you integrate MLflow with your existing infrastructure without modifying MLflow's core code, keeping upgrades and maintenance smooth.

Plugin Types and Use Cases

MLflow supports nine plugin types, each addressing a different integration need:

Storage and Persistence

| Plugin Type | Purpose | Example Use Cases |
|---|---|---|
| Tracking Store | Custom storage for experiment data | Enterprise databases, cloud data warehouses |
| Artifact Repository | Custom artifact storage | Internal blob storage, specialized filesystems |
| Model Registry Store | Custom model registry backend | Enterprise model catalogs, version control systems |

Authentication and Headers

| Plugin Type | Purpose | Example Use Cases |
|---|---|---|
| Request Auth Provider | Custom authentication | OAuth, API keys, certificate-based auth |
| Request Header Provider | Custom HTTP headers | Environment identification, compliance headers |
| Run Context Provider | Automatic run metadata | Git info, environment details, custom tags |
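
Of these, the Request Header Provider has no dedicated example later on this page. Here is a minimal sketch, assuming the `RequestHeaderProvider` base class in `mlflow.tracking.request_header.abstract_request_header_provider` (module and header names below are illustrative assumptions):

```python
# my_plugin/headers.py -- illustrative module, registered under the
# "mlflow.request_header_provider" entry point group
from mlflow.tracking.request_header.abstract_request_header_provider import (
    RequestHeaderProvider,
)


class MyHeaderProvider(RequestHeaderProvider):
    """Attach custom HTTP headers to outgoing MLflow tracking requests."""

    def in_context(self):
        # Return True when this provider should be active
        return True

    def request_headers(self):
        # Headers merged into every tracking request (names are examples)
        return {"X-Environment": "staging", "X-Compliance-Zone": "eu-west"}
```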

Execution and Evaluation

| Plugin Type | Purpose | Example Use Cases |
|---|---|---|
| Project Backend | Custom execution environments | Internal clusters, job schedulers, cloud platforms |
| Model Evaluator | Custom evaluation metrics | Domain-specific validation, custom test suites |
| Deployment | Custom serving platforms | Internal serving infrastructure, edge deployment |
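
The deployment plugin differs from the others: its entry point (see the `setup.py` example below) names a module rather than a class, and that module is expected to expose a client based on `mlflow.deployments.BaseDeploymentClient` plus a couple of module-level hooks. A rough sketch, with signatures abbreviated (consult the `mlflow.deployments` API docs for the exact set):

```python
# my_plugin/deployment.py -- rough sketch; method signatures abbreviated
from mlflow.deployments import BaseDeploymentClient


def run_local(name, model_uri, flavor=None, config=None):
    """Deploy the model locally for testing (module-level hook)."""
    ...


def target_help():
    """Return help text for this deployment target (module-level hook)."""
    return "Deploy MLflow models to my-target."


class MyDeploymentClient(BaseDeploymentClient):
    def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
        ...

    def update_deployment(self, name, model_uri=None, flavor=None, config=None, endpoint=None):
        ...

    def delete_deployment(self, name, config=None, endpoint=None):
        ...

    def list_deployments(self, endpoint=None):
        return []

    def get_deployment(self, name, endpoint=None):
        ...

    def predict(self, deployment_name=None, inputs=None, endpoint=None):
        ...
```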

Developing Custom Plugins

Plugin Structure

Create your plugin as a standalone Python package:

```python
# setup.py
from setuptools import setup

setup(
    name="my-mlflow-plugin",
    version="0.1.0",
    install_requires=["mlflow>=2.0.0"],
    entry_points={
        # Define plugin entry points
        "mlflow.tracking_store": "my-scheme=my_plugin.store:MyTrackingStore",
        "mlflow.artifact_repository": "my-scheme=my_plugin.artifacts:MyArtifactRepo",
        "mlflow.run_context_provider": "unused=my_plugin.context:MyContextProvider",
        "mlflow.request_auth_provider": "unused=my_plugin.auth:MyAuthProvider",
        "mlflow.model_evaluator": "my-evaluator=my_plugin.evaluator:MyEvaluator",
        "mlflow.project_backend": "my-backend=my_plugin.backend:MyBackend",
        "mlflow.deployments": "my-target=my_plugin.deployment",
        "mlflow.app": "my-app=my_plugin.app:create_app",
    },
)
```
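
Once the package is installed, MLflow discovers these implementations through the standard Python entry point mechanism; nothing else needs to be registered. A quick way to check what got registered, using the standard-library `importlib.metadata` (the `group=` keyword requires Python 3.10+):

```python
from importlib.metadata import entry_points

# List tracking stores visible to MLflow, including ours after installing
# the package above with `pip install -e .`
for ep in entry_points(group="mlflow.tracking_store"):
    print(ep.name, "->", ep.value)
```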

Storage Plugins

```python
# my_plugin/store.py
from mlflow.store.tracking.abstract_store import AbstractStore


class MyTrackingStore(AbstractStore):
    """Custom tracking store for scheme 'my-scheme://'"""

    def __init__(self, store_uri):
        super().__init__()
        self.store_uri = store_uri
        # Initialize your custom storage backend

    def create_experiment(self, name, artifact_location=None, tags=None):
        # Implement experiment creation logic
        pass

    def log_metric(self, run_id, metric):
        # Implement metric logging logic
        pass

    def log_param(self, run_id, param):
        # Implement parameter logging logic
        pass

    # Implement other required AbstractStore methods...
```
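
The `setup.py` above also points `my-scheme` artifact URIs at `my_plugin.artifacts:MyArtifactRepo`. A matching sketch, assuming the `ArtifactRepository` base class from `mlflow.store.artifact.artifact_repo`:

```python
# my_plugin/artifacts.py
from mlflow.store.artifact.artifact_repo import ArtifactRepository


class MyArtifactRepo(ArtifactRepository):
    """Custom artifact repository for 'my-scheme://' artifact URIs."""

    def log_artifact(self, local_file, artifact_path=None):
        # Upload a single local file to the backing store
        ...

    def log_artifacts(self, local_dir, artifact_path=None):
        # Upload a directory tree, preserving relative paths
        ...

    def list_artifacts(self, path=None):
        # Return a list of mlflow.entities.FileInfo objects
        return []

    def _download_file(self, remote_file_path, local_path):
        # Fetch one remote file and write it to local_path
        ...
```

With both classes installed, `mlflow.set_tracking_uri("my-scheme://...")` is all a user needs; MLflow resolves the URI scheme to these implementations at runtime.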

Authentication Plugins

```python
# my_plugin/auth.py
from mlflow.tracking.request_auth.abstract_request_auth_provider import (
    RequestAuthProvider,
)


class MyAuthProvider(RequestAuthProvider):
    """Custom authentication provider"""

    def get_name(self):
        return "my_auth_provider"

    def get_auth(self):
        # Return authentication object for HTTP requests
        # Can be anything that requests.auth accepts
        return MyCustomAuth()


class MyCustomAuth:
    """Custom authentication class"""

    def __call__(self, request):
        # Add authentication headers to the request
        token = self._get_token()
        request.headers["Authorization"] = f"Bearer {token}"
        return request

    def _get_token(self):
        # Implement token retrieval logic
        # e.g., read from a file, the environment, or an API call
        pass
```
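
Because the returned object only needs to be something `requests` accepts, an equivalent variant can subclass `requests.auth.AuthBase` directly. A minimal sketch (the environment variable name is made up for illustration):

```python
import os

from requests.auth import AuthBase


class EnvTokenAuth(AuthBase):
    """Bearer-token auth that reads the token from an environment variable."""

    def __call__(self, request):
        token = os.environ.get("MY_PLUGIN_TOKEN", "")
        request.headers["Authorization"] = f"Bearer {token}"
        return request
```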

Usage

```bash
export MLFLOW_TRACKING_AUTH=my_auth_provider
python your_mlflow_script.py
```

Execution Plugins

Project Backend Plugins

```python
# my_plugin/backend.py
from mlflow.projects.backend import AbstractBackend
from mlflow.projects.submitted_run import SubmittedRun


class MyBackend(AbstractBackend):
    """Custom execution backend for MLflow Projects"""

    def run(
        self,
        project_uri,
        entry_point,
        parameters,
        version,
        backend_config,
        tracking_uri,
        experiment_id,
    ):
        """Execute the project on custom infrastructure"""

        # Parse backend configuration
        cluster_config = backend_config.get("cluster_config", {})

        # Submit the job to your execution system
        job_id = self._submit_job(
            project_uri=project_uri,
            entry_point=entry_point,
            parameters=parameters,
            cluster_config=cluster_config,
        )

        # Return a SubmittedRun for monitoring
        return MySubmittedRun(job_id, tracking_uri)

    def _submit_job(self, project_uri, entry_point, parameters, cluster_config):
        # Implement job submission to your infrastructure
        # Return a job ID for monitoring
        pass


class MySubmittedRun(SubmittedRun):
    """Handle for a submitted run"""

    def __init__(self, job_id, tracking_uri):
        self.job_id = job_id
        self.tracking_uri = tracking_uri
        super().__init__()

    def wait(self):
        # Wait for job completion and return success status
        return self._poll_job_status()

    def cancel(self):
        # Cancel the running job
        self._cancel_job()

    def get_status(self):
        # Get the current job status
        return self._get_job_status()
```
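
With the backend installed, it is selected by the name registered in `setup.py` (`my-backend` in this example). A usage sketch, where the `--backend-config` JSON is our own convention parsed in `MyBackend.run`:

```bash
# Run a project on the custom backend; the JSON is passed to
# MyBackend.run as backend_config ("cluster_config" is our own key)
mlflow run https://github.com/mlflow/mlflow-example \
  --backend my-backend \
  --backend-config '{"cluster_config": {"num_workers": 4}}'
```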

Model Evaluator Plugins

```python
# my_plugin/evaluator.py
from mlflow.models.evaluation import ModelEvaluator
from mlflow.models import EvaluationResult


class MyEvaluator(ModelEvaluator):
    """Custom model evaluator"""

    def can_evaluate(self, *, model_type, evaluator_config, **kwargs):
        """Check whether this evaluator can handle the model type"""
        supported_types = ["classifier", "regressor"]
        return model_type in supported_types

    def evaluate(
        self, *, model, model_type, dataset, run_id, evaluator_config, **kwargs
    ):
        """Perform custom evaluation"""

        # Get predictions
        predictions = model.predict(dataset.features_data)

        # Compute custom metrics
        metrics = self._compute_custom_metrics(
            predictions, dataset.labels_data, evaluator_config
        )

        # Generate custom artifacts
        artifacts = self._generate_artifacts(predictions, dataset, evaluator_config)

        return EvaluationResult(metrics=metrics, artifacts=artifacts)

    def _compute_custom_metrics(self, predictions, labels, config):
        # Implement domain-specific metrics
        return {
            "custom_score": self._calculate_custom_score(predictions, labels),
            "business_metric": self._calculate_business_metric(predictions, labels),
        }

    def _generate_artifacts(self, predictions, dataset, config):
        # Generate custom plots, reports, etc.
        return {}
```
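
The evaluator is then selected by its entry point name (`my-evaluator`) when calling `mlflow.evaluate`. A usage sketch with toy data; the model URI and column names are placeholders:

```python
import mlflow
import pandas as pd

# Toy evaluation data; column names are placeholders
eval_df = pd.DataFrame({"feature": [1.0, 2.0, 3.0], "label": [0, 1, 1]})

result = mlflow.evaluate(
    model="models:/my-model/1",  # placeholder registered-model URI
    data=eval_df,
    targets="label",
    model_type="classifier",
    evaluators=["my-evaluator"],  # name from the entry point in setup.py
)
print(result.metrics)
```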

SQL Server Plugin

Store artifacts directly in a SQL Server database:

```bash
pip install mlflow[sqlserver]
```

```python
import mlflow

# Use SQL Server as the artifact store
db_uri = "mssql+pyodbc://user:pass@host:port/db?driver=ODBC+Driver+17+for+SQL+Server"
mlflow.create_experiment("sql_experiment", artifact_location=db_uri)

with mlflow.start_run():
    # Assumes `model` is a trained ONNX model
    mlflow.onnx.log_model(model, name="model")  # Stored as a BLOB in SQL Server
```

Alibaba Cloud OSS Plugin

Integrate with Alibaba Cloud Object Storage Service:

```bash
pip install mlflow[aliyun-oss]
```

```python
import os
import mlflow

# Configure OSS credentials
os.environ["MLFLOW_OSS_ENDPOINT_URL"] = "https://oss-region.aliyuncs.com"
os.environ["MLFLOW_OSS_KEY_ID"] = "your_access_key"
os.environ["MLFLOW_OSS_KEY_SECRET"] = "your_secret_key"

# Use OSS as the artifact store
mlflow.create_experiment("oss_experiment", artifact_location="oss://bucket/path")
```

XetHub Plugin

Use XetHub for versioned artifact storage:

```bash
pip install mlflow[xethub]
```

```python
import mlflow

# Authenticate with XetHub (via CLI or environment variables)
mlflow.create_experiment(
    "xet_experiment", artifact_location="xet://username/repo/branch"
)
```

Elasticsearch Plugin

Use Elasticsearch as an experiment tracking backend:

```bash
pip install mlflow-elasticsearchstore
```

Testing Your Plugin

```python
# tests/test_my_plugin.py
import pytest
import mlflow
from my_plugin.store import MyTrackingStore


class TestMyTrackingStore:
    def setup_method(self):
        self.store = MyTrackingStore("my-scheme://test")

    def test_create_experiment(self):
        experiment_id = self.store.create_experiment("test_exp")
        assert experiment_id is not None

    def test_log_metric(self):
        experiment_id = self.store.create_experiment("test_exp")
        run = self.store.create_run(experiment_id, "user", "test_run")

        metric = mlflow.entities.Metric("accuracy", 0.95, 12345, 0)
        self.store.log_metric(run.info.run_id, metric)

        # Verify the metric was logged correctly
        stored_run = self.store.get_run(run.info.run_id)
        assert "accuracy" in stored_run.data.metrics
        assert stored_run.data.metrics["accuracy"] == 0.95

    def test_log_artifact(self):
        # Test artifact logging functionality
        pass
```
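
A straightforward way to run the suite locally (assuming pytest):

```bash
# Install the plugin in editable mode together with pytest, then run tests
pip install -e . pytest
pytest tests/ -v
```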

Distribution and Publishing

Package Structure

```
my-mlflow-plugin/
├── setup.py              # Package configuration
├── README.md             # Plugin documentation
├── my_plugin/
│   ├── __init__.py
│   ├── store.py          # Tracking store implementation
│   ├── artifacts.py      # Artifact repository implementation
│   ├── auth.py           # Authentication provider
│   └── evaluator.py      # Model evaluator
├── tests/
│   ├── test_store.py
│   ├── test_artifacts.py
│   └── test_integration.py
└── examples/
    └── example_usage.py
```

Publishing to PyPI

```bash
# Build distribution packages
python setup.py sdist bdist_wheel

# Upload to PyPI
pip install twine
twine upload dist/*
```
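
Direct `setup.py` invocation still works but is deprecated in current packaging guidance; an equivalent flow using the `build` frontend:

```bash
pip install build twine
python -m build      # writes the sdist and wheel to dist/
twine upload dist/*
```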

Documentation Template

````markdown
# My MLflow Plugin

## Installation

```bash
pip install my-mlflow-plugin
```

## Configuration

```bash
export MY_PLUGIN_CONFIG="value"
```

## Usage

```python
import mlflow

mlflow.set_tracking_uri("my-scheme://config")
```

## Features

- Feature 1
- Feature 2

## Examples

See the `examples/` directory for complete usage examples.
````

Best Practices

Plugin Development

- Follow MLflow interfaces - implement all required abstract methods
- Handle errors gracefully - provide clear error messages for configuration problems
- Support authentication - integrate with existing credential systems
- Add comprehensive logging - help users debug configuration issues
- Maintain version compatibility - test against multiple MLflow versions

Performance Optimization

- Batch operations - implement efficient bulk logging where possible
- Connection pooling - reuse connections to external systems
- Async operations - use asynchronous I/O for storage operations where it pays off
- Caching - cache frequently accessed metadata

Security Considerations

- Credential management - never log or expose sensitive credentials
- Input validation - validate all user input and URIs
- Access control - respect existing authentication and authorization
- Secure communication - use TLS/SSL for network traffic

Testing Strategy

- Unit tests - test individual plugin components
- Integration tests - test complete workflows against MLflow
- Performance tests - verify acceptable performance characteristics
- Compatibility tests - test against different MLflow versions

Ready to extend MLflow? Start with the example test plugin to see all the plugin types in action, then build your own custom integration!