
MLflow Plugins

MLflow's plugin architecture enables seamless integration with third-party tools and custom infrastructure. As a framework-agnostic platform, MLflow provides developer APIs for extending functionality such as storage, authentication, execution backends, and model evaluation.

Quickstart

Installing and Using a Plugin

Try the built-in test plugin to see how plugins work:

# Clone MLflow and install example plugin
git clone https://github.com/mlflow/mlflow
cd mlflow
pip install -e tests/resources/mlflow-test-plugin
# Use the plugin with custom tracking URI scheme
MLFLOW_TRACKING_URI=file-plugin:$(pwd)/mlruns python examples/quickstart/mlflow_tracking.py

# Launch MLflow UI to view results
mlflow server --backend-store-uri ./mlruns

Open http://localhost:5000 to view your tracked experiments.


Plugin Benefits

Plugins let you integrate MLflow with your existing infrastructure without modifying core MLflow code, ensuring smooth upgrades and maintenance.

Plugin Types and Use Cases

MLflow supports several plugin types, each addressing a different integration need:

Storage and Persistence

| Plugin Type | Purpose | Example Use Cases |
| --- | --- | --- |
| Tracking store | Custom experiment data storage | Enterprise databases, cloud data warehouses |
| Artifact repository | Custom artifact storage | Internal blob stores, specialized filesystems |
| Model registry store | Custom model registry backend | Enterprise model catalogs, version control systems |

Authentication and Request Headers

| Plugin Type | Purpose | Example Use Cases |
| --- | --- | --- |
| Request auth provider | Custom authentication | OAuth, API keys, certificate-based auth |
| Request header provider | Custom HTTP headers | Environment identification, compliance headers |
| Run context provider | Automatic run metadata | Git info, environment details, custom tags |

Execution and Evaluation

| Plugin Type | Purpose | Example Use Cases |
| --- | --- | --- |
| Project backend | Custom execution environments | Internal clusters, job schedulers, cloud platforms |
| Model evaluator | Custom evaluation metrics | Domain-specific validation, custom test suites |
| Deployment | Custom serving platforms | Internal serving infrastructure, edge deployment |

Developing Custom Plugins

Plugin Structure

Create your plugin as a standalone Python package:

# setup.py
from setuptools import setup

setup(
    name="my-mlflow-plugin",
    version="0.1.0",
    install_requires=["mlflow>=2.0.0"],
    entry_points={
        # Define plugin entry points
        "mlflow.tracking_store": "my-scheme=my_plugin.store:MyTrackingStore",
        "mlflow.artifact_repository": "my-scheme=my_plugin.artifacts:MyArtifactRepo",
        "mlflow.run_context_provider": "unused=my_plugin.context:MyContextProvider",
        "mlflow.request_auth_provider": "unused=my_plugin.auth:MyAuthProvider",
        "mlflow.model_evaluator": "my-evaluator=my_plugin.evaluator:MyEvaluator",
        "mlflow.project_backend": "my-backend=my_plugin.backend:MyBackend",
        "mlflow.deployments": "my-target=my_plugin.deployment",
        "mlflow.app": "my-app=my_plugin.app:create_app",
    },
)
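
After `pip install -e .`, you can sanity-check that the entry points are registered before pointing MLflow at the plugin. A small sketch using the standard library (the `group=` keyword requires Python 3.10+):

# Confirm the plugin's entry points are visible in the current environment
from importlib.metadata import entry_points

for ep in entry_points(group="mlflow.tracking_store"):
    # Expect: my-scheme -> my_plugin.store:MyTrackingStore
    print(ep.name, "->", ep.value)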

Storage Plugins

# my_plugin/store.py
from mlflow.store.tracking.abstract_store import AbstractStore


class MyTrackingStore(AbstractStore):
    """Custom tracking store for scheme 'my-scheme://'"""

    def __init__(self, store_uri):
        super().__init__()
        self.store_uri = store_uri
        # Initialize your custom storage backend

    def create_experiment(self, name, artifact_location=None, tags=None):
        # Implement experiment creation logic
        pass

    def log_metric(self, run_id, metric):
        # Implement metric logging logic
        pass

    def log_param(self, run_id, param):
        # Implement parameter logging logic
        pass

    # Implement other required AbstractStore methods...
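
Once the package is installed, MLflow resolves the `my-scheme://` prefix to this store through the `mlflow.tracking_store` entry point. A minimal usage sketch; the URI host and path are hypothetical:

import mlflow

# MLflow dispatches to MyTrackingStore because the URI scheme
# matches the registered entry-point name ("my-scheme")
mlflow.set_tracking_uri("my-scheme://my-storage-host/experiments")
mlflow.set_experiment("plugin_demo")

with mlflow.start_run():
    mlflow.log_param("alpha", 0.5)   # routed to MyTrackingStore.log_param
    mlflow.log_metric("rmse", 0.23)  # routed to MyTrackingStore.log_metric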

Authentication Plugins

# my_plugin/auth.py
from mlflow.tracking.request_auth.abstract_request_auth_provider import (
    RequestAuthProvider,
)


class MyAuthProvider(RequestAuthProvider):
    """Custom authentication provider"""

    def get_name(self):
        return "my_auth_provider"

    def get_auth(self):
        # Return authentication object for HTTP requests
        # Can be anything that requests.auth accepts
        return MyCustomAuth()


class MyCustomAuth:
    """Custom authentication class"""

    def __call__(self, request):
        # Add authentication headers to request
        token = self._get_token()
        request.headers["Authorization"] = f"Bearer {token}"
        return request

    def _get_token(self):
        # Implement token retrieval logic
        # E.g., read from file, environment, or API call
        pass

Usage:

export MLFLOW_TRACKING_AUTH=my_auth_provider
python your_mlflow_script.py
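
The run context provider registered in the setup.py above (`my_plugin.context:MyContextProvider`) follows a similar two-method shape. A minimal sketch, assuming the `RunContextProvider` base class path from the MLflow source tree; the tag values are illustrative:

# my_plugin/context.py
from mlflow.tracking.context.abstract_context import RunContextProvider


class MyContextProvider(RunContextProvider):
    """Attach custom tags to every run created in this environment."""

    def in_context(self):
        # Return True when this provider's tags should apply
        return True

    def tags(self):
        # Tags set automatically at run creation time
        return {"my_plugin.cluster": "on-prem-01"}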

Execution Plugins

Project Backend Plugins

# my_plugin/backend.py
from mlflow.projects.backend import AbstractBackend
from mlflow.projects.submitted_run import SubmittedRun


class MyBackend(AbstractBackend):
    """Custom execution backend for MLflow Projects"""

    def run(
        self,
        project_uri,
        entry_point,
        parameters,
        version,
        backend_config,
        tracking_uri,
        experiment_id,
    ):
        """Execute project on custom infrastructure"""

        # Parse backend configuration
        cluster_config = backend_config.get("cluster_config", {})

        # Submit job to your execution system
        job_id = self._submit_job(
            project_uri=project_uri,
            entry_point=entry_point,
            parameters=parameters,
            cluster_config=cluster_config,
        )

        # Return SubmittedRun for monitoring
        return MySubmittedRun(job_id, tracking_uri)

    def _submit_job(self, project_uri, entry_point, parameters, cluster_config):
        # Implement job submission to your infrastructure
        # Return job ID for monitoring
        pass


class MySubmittedRun(SubmittedRun):
    """Handle for submitted run"""

    def __init__(self, job_id, tracking_uri):
        self.job_id = job_id
        self.tracking_uri = tracking_uri
        super().__init__()

    def wait(self):
        # Wait for job completion and return success status
        return self._poll_job_status()

    def cancel(self):
        # Cancel the running job
        self._cancel_job()

    def get_status(self):
        # Get current job status
        return self._get_job_status()
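
With the package installed, the backend is selected by the name registered under the `mlflow.project_backend` entry point (`my-backend` above). A hedged example invocation, where `backend-config.json` is a hypothetical file supplying `cluster_config` to MyBackend:

mlflow run https://github.com/mlflow/mlflow-example \
    --backend my-backend \
    --backend-config backend-config.json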

Model Evaluation Plugins

# my_plugin/evaluator.py
from mlflow.models.evaluation import ModelEvaluator
from mlflow.models import EvaluationResult


class MyEvaluator(ModelEvaluator):
    """Custom model evaluator"""

    def can_evaluate(self, *, model_type, evaluator_config, **kwargs):
        """Check if this evaluator can handle the model type"""
        supported_types = ["classifier", "regressor"]
        return model_type in supported_types

    def evaluate(
        self, *, model, model_type, dataset, run_id, evaluator_config, **kwargs
    ):
        """Perform custom evaluation"""

        # Get predictions
        predictions = model.predict(dataset.features_data)

        # Compute custom metrics
        metrics = self._compute_custom_metrics(
            predictions, dataset.labels_data, evaluator_config
        )

        # Generate custom artifacts
        artifacts = self._generate_artifacts(predictions, dataset, evaluator_config)

        return EvaluationResult(metrics=metrics, artifacts=artifacts)

    def _compute_custom_metrics(self, predictions, labels, config):
        # Implement domain-specific metrics
        return {
            "custom_score": self._calculate_custom_score(predictions, labels),
            "business_metric": self._calculate_business_metric(predictions, labels),
        }

    def _generate_artifacts(self, predictions, dataset, config):
        # Generate custom plots, reports, etc.
        return {}
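
The evaluator is then selected by its registered name (`my-evaluator` above). A usage sketch, assuming a logged model URI and a pandas DataFrame `eval_df` with a `label` column (both hypothetical):

import mlflow

result = mlflow.evaluate(
    "models:/my_model/1",         # hypothetical registered-model URI
    data=eval_df,                 # hypothetical evaluation DataFrame
    targets="label",
    model_type="classifier",
    evaluators=["my-evaluator"],  # entry-point name from setup.py
)
print(result.metrics)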

SQL Server Plugin

Store artifacts directly in a SQL Server database:

pip install mlflow[sqlserver]
import mlflow

# Use SQL Server as artifact store
db_uri = "mssql+pyodbc://user:pass@host:port/db?driver=ODBC+Driver+17+for+SQL+Server"
mlflow.create_experiment("sql_experiment", artifact_location=db_uri)

with mlflow.start_run():
    mlflow.onnx.log_model(model, name="model")  # Stored as BLOB in SQL Server

Alibaba Cloud (Aliyun) OSS Plugin

Integrate with Alibaba Cloud Object Storage Service:

pip install mlflow[aliyun-oss]
import os
import mlflow

# Configure OSS credentials
os.environ["MLFLOW_OSS_ENDPOINT_URL"] = "https://oss-region.aliyuncs.com"
os.environ["MLFLOW_OSS_KEY_ID"] = "your_access_key"
os.environ["MLFLOW_OSS_KEY_SECRET"] = "your_secret_key"

# Use OSS as artifact store
mlflow.create_experiment("oss_experiment", artifact_location="oss://bucket/path")

XetHub Plugin

Use XetHub for versioned artifact storage:

pip install mlflow[xethub]
import mlflow

# Authenticate with XetHub (via CLI or environment variables)
mlflow.create_experiment(
    "xet_experiment", artifact_location="xet://username/repo/branch"
)

Elasticsearch Plugin

Use Elasticsearch for experiment tracking:

pip install mlflow-elasticsearchstore

Testing Your Plugin

# tests/test_my_plugin.py
import pytest
import mlflow
from my_plugin.store import MyTrackingStore


class TestMyTrackingStore:
    def setup_method(self):
        self.store = MyTrackingStore("my-scheme://test")

    def test_create_experiment(self):
        experiment_id = self.store.create_experiment("test_exp")
        assert experiment_id is not None

    def test_log_metric(self):
        experiment_id = self.store.create_experiment("test_exp")
        run = self.store.create_run(experiment_id, "user", "test_run")

        metric = mlflow.entities.Metric("accuracy", 0.95, 12345, 0)
        self.store.log_metric(run.info.run_id, metric)

        # Verify metric was logged correctly
        stored_run = self.store.get_run(run.info.run_id)
        assert "accuracy" in stored_run.data.metrics
        assert stored_run.data.metrics["accuracy"] == 0.95

    def test_log_artifact(self):
        # Test artifact logging functionality
        pass

Distribution and Publishing

Package Structure

my-mlflow-plugin/
├── setup.py              # Package configuration
├── README.md             # Plugin documentation
├── my_plugin/
│   ├── __init__.py
│   ├── store.py          # Tracking store implementation
│   ├── artifacts.py      # Artifact repository implementation
│   ├── auth.py           # Authentication provider
│   └── evaluator.py      # Model evaluator
├── tests/
│   ├── test_store.py
│   ├── test_artifacts.py
│   └── test_integration.py
└── examples/
    └── example_usage.py

Publishing to PyPI

# Build distribution packages
python setup.py sdist bdist_wheel

# Upload to PyPI
pip install twine
twine upload dist/*

Documentation Template

# My MLflow Plugin

## Installation
```bash
pip install my-mlflow-plugin
```

## Configuration
```bash
export MY_PLUGIN_CONFIG="value"
```

## Usage
```python
import mlflow

mlflow.set_tracking_uri("my-scheme://config")
```

## Features
- Feature 1
- Feature 2

## Examples
See the examples/ directory for complete usage examples.

Best Practices

Plugin Development

  • Follow the MLflow interfaces - implement all required abstract methods
  • Handle errors gracefully - provide clear error messages for configuration problems
  • Support authentication - integrate with existing credential systems
  • Add comprehensive logging - help users debug configuration issues
  • Version compatibility - test against multiple MLflow versions

Performance Optimization

  • Batch operations - implement efficient bulk logging where possible (see the sketch after this list)
  • Connection pooling - reuse connections to external systems
  • Async operations - use asynchronous I/O for storage operations where beneficial
  • Caching - cache frequently accessed metadata
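
As an example of the batch-operations point, the tracking store base class exposes a `log_batch(run_id, metrics, params, tags)` method that a custom store can override to persist many entities in one round trip. A minimal sketch for MyTrackingStore; the `_bulk_write` helper is hypothetical:

# In my_plugin/store.py: one backend round trip instead of
# one write per metric/param/tag
def log_batch(self, run_id, metrics, params, tags):
    records = (
        [("metric", m.key, m.value) for m in metrics]
        + [("param", p.key, p.value) for p in params]
        + [("tag", t.key, t.value) for t in tags]
    )
    self._bulk_write(run_id, records)  # hypothetical bulk writer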

Security Considerations

  • Credential management - never log or expose sensitive credentials
  • Input validation - validate all user inputs and URIs
  • Access control - respect existing authentication and authorization
  • Secure communication - use TLS/SSL for network communication

Testing Strategy

  • Unit tests - test individual plugin components
  • Integration tests - test complete workflows against MLflow
  • Performance tests - verify acceptable performance characteristics
  • Compatibility tests - test against different MLflow versions

Ready to extend MLflow? Start with the example test plugin to see all the plugin types in action, then build your own custom integration!