Community Model Flavors
MLflow's active community has developed a variety of flavors for specialized ML frameworks and use cases, extending MLflow beyond its built-in flavors. These community-maintained packages integrate seamlessly with domain-specific tools for time series forecasting, anomaly detection, visualization, and more.
Quick Start
Installing Community Flavors
Most community flavors are available from PyPI:
# Time series forecasting
pip install sktime[mlflow]
pip install mlflavors
# Visualization and plotting
pip install mlflow-vizmod
# Big data and cloud platforms
pip install bigmlflow
pip install mlflow[aliyun-oss]
Basic Usage Pattern
All community flavors follow MLflow's standard interface:
import mlflow
import community_flavor # Replace with actual flavor
# Train your model
model = SomeModel()
model.fit(data)
# Log with MLflow
with mlflow.start_run():
    community_flavor.log_model(model, "model_path")
# Load for inference
loaded_model = community_flavor.load_model("model_uri")
predictions = loaded_model.predict(new_data)
Featured Community Flavors
- Time Series
- Visualization
- Big Data & Cloud
- MLflow Go
- Custom Flavors
Sktime
A unified interface for time series forecasting, classification, and transformation.
pip install sktime[mlflow]
import pandas as pd
from sktime.datasets import load_airline
from sktime.forecasting.arima import AutoARIMA
from sktime.utils import mlflow_sktime
# Load data and train model
airline = load_airline()
model = AutoARIMA(sp=12, d=0, max_p=2, max_q=2, suppress_warnings=True)
model.fit(airline, fh=[1, 2, 3])
# Save and load with MLflow
mlflow_sktime.save_model(sktime_model=model, path="model")
loaded_model = mlflow_sktime.load_model(model_uri="model")
# Make predictions
predictions = loaded_model.predict()
print(predictions)
# Load as PyFunc for serving
loaded_pyfunc = mlflow_sktime.pyfunc.load_model(model_uri="model")
pyfunc_predictions = loaded_pyfunc.predict(pd.DataFrame())
The MLflavors Package
Support for multiple time series and ML frameworks in a single package.
pip install mlflavors
Supported Frameworks
Framework | Category | Example Use Case |
---|---|---|
Orbit | Time series | Bayesian forecasting |
StatsForecast | Time series | Statistical models |
PyOD | Anomaly detection | Outlier detection |
SDV | Synthetic data | Privacy-preserving data generation |
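The PyOD walkthrough below shows the mlflavors logging pattern in full. As a quick illustration of how the same pattern would look for another framework in the table, here is a hedged StatsForecast sketch — it assumes mlflavors exposes a statsforecast module whose log_model mirrors the pyod API shown below; check the mlflavors documentation for the exact signature.
import mlflow
import pandas as pd
from statsforecast import StatsForecast
from statsforecast.models import AutoARIMA
import mlflavors

# Toy monthly series in StatsForecast's expected long format
train_df = pd.DataFrame(
    {
        "unique_id": ["series_1"] * 24,
        "ds": pd.date_range("2022-01-01", periods=24, freq="MS"),
        "y": range(24),
    }
)

sf = StatsForecast(models=[AutoARIMA(season_length=12)], freq="MS")
sf.fit(train_df)

with mlflow.start_run():
    # Assumed API, mirroring mlflavors.pyod.log_model below
    mlflavors.statsforecast.log_model(
        statsforecast_model=sf,
        artifact_path="statistical_forecaster",
        serialization_format="pickle",
    )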
PyOD Anomaly Detection Example
import mlflow
from pyod.models.knn import KNN
from pyod.utils.data import generate_data
import mlflavors
# Generate synthetic data
contamination = 0.1
n_train, n_test = 200, 100
X_train, X_test, _, y_test = generate_data(
    n_train=n_train, n_test=n_test, contamination=contamination
)

with mlflow.start_run():
    # Train KNN detector
    clf = KNN()
    clf.fit(X_train)

    # Log model
    mlflavors.pyod.log_model(
        pyod_model=clf, artifact_path="anomaly_detector", serialization_format="pickle"
    )

    # Evaluate
    scores = clf.decision_function(X_test)
    mlflow.log_metric("mean_anomaly_score", scores.mean())
Serving PyOD Models
# Load as PyFunc
loaded_pyfunc = mlflavors.pyod.pyfunc.load_model(model_uri="model_uri")
# Create configuration for inference
import pandas as pd
predict_conf = pd.DataFrame([{"X": X_test, "predict_method": "decision_function"}])
anomaly_scores = loaded_pyfunc.predict(predict_conf)[0]
MLflow VizMod
Treat visualizations as models for versioning, tracking, and deployment.
pip install mlflow-vizmod
Creating and Logging Interactive Visualizations
import mlflow
from sklearn.datasets import load_iris
import altair as alt
import mlflow_vismod

# Load data as a pandas DataFrame and rename columns to Altair-friendly names
df_iris = load_iris(as_frame=True).frame.rename(
    columns={"sepal length (cm)": "sepal_length", "sepal width (cm)": "sepal_width"}
)

# Create Altair visualization
viz_iris = (
    alt.Chart(df_iris)
    .mark_circle(size=60)
    .encode(x="sepal_length:Q", y="sepal_width:Q", color="target:N")
    .properties(height=375, width=575)
    .interactive()
)

# Log visualization as a model
with mlflow.start_run():
    mlflow_vismod.log_model(
        model=viz_iris,
        artifact_path="iris_viz",
        style="vegalite",
        input_example=df_iris.head(5),
    )
Benefits:
- Version control: track how visualizations change over time.
- Reproducibility: recreate a visualization exactly from the same data.
- Deployment: serve interactive visualizations as web services.
- Collaboration: share visualizations with consistent metadata (a speculative load-back sketch follows this list).
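If you want to pull a logged visualization back for reuse, the flavor-API pattern suggests a load counterpart. This is a speculative sketch — the load_model function and its signature here are assumptions, so check the mlflow-vizmod README for the actual API.
import mlflow_vismod

# Assumed API: a load_model counterpart in the standard flavor pattern
loaded_viz = mlflow_vismod.load_model(model_uri="runs:/<RUN_ID>/iris_viz")
loaded_viz  # an Altair chart renders directly in a notebook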
BigML Integration
Deploy and manage BigML supervised models through MLflow.
pip install bigmlflow
import json
import mlflow
import bigmlflow
# Load BigML model from JSON
MODEL_FILE = "logistic_regression.json"
with mlflow.start_run():
    with open(MODEL_FILE) as handler:
        model = json.load(handler)

    # Log BigML model
    bigmlflow.log_model(
        model,
        artifact_path="bigml_model",
        registered_model_name="production_classifier",
    )
# Load and use for inference
loaded_model = bigmlflow.load_model("model_uri")
predictions = loaded_model.predict(test_dataframe)
# Load as PyFunc
pyfunc_model = mlflow.pyfunc.load_model("model_uri")
pyfunc_predictions = pyfunc_model.predict(test_dataframe)
Key features:
- BigML integration: direct support for BigML's supervised models.
- PyFunc compatibility: works with MLflow's generic Python function interface.
- Model Registry: register BigML models for production deployment.
- DataFrame inference: standard pandas DataFrame input/output.
MLflow Go Backend
A high-performance Go implementation of the MLflow tracking server for better scalability and performance.
pip install mlflow-go-backend
Performance Benefits
- Faster API calls for key tracking operations
- Higher concurrency - handles more simultaneous requests
- Improved throughput for high-volume ML workloads
- Drop-in replacement for existing MLflow deployments
Server Usage
Replace your existing MLflow server command:
# Traditional MLflow server
mlflow server --backend-store-uri postgresql://user:pass@localhost:5432/mlflow
# High-performance Go backend
mlflow-go server --backend-store-uri postgresql://user:pass@localhost:5432/mlflow
All existing MLflow server options are supported:
mlflow-go server \
--backend-store-uri postgresql://user:pass@localhost:5432/mlflow \
--artifacts-destination s3://my-mlflow-artifacts \
--host 0.0.0.0 \
--port 5000 \
--workers 4
Client Usage
Enable the Go backend in your Python code:
import mlflow
import mlflow_go_backend
# Enable the Go client implementation
mlflow_go_backend.enable_go()
# Set tracking URI (database required)
mlflow.set_tracking_uri("postgresql://user:pass@localhost:5432/mlflow")
# Use MLflow as normal - all operations now use Go backend
mlflow.set_experiment("high-performance-experiment")
with mlflow.start_run():
    mlflow.log_param("algorithm", "xgboost")
    mlflow.log_metric("accuracy", 0.95)
    mlflow.log_artifact("model.pkl")
Direct Store Usage
Use the Go backend directly with MLflow stores:
import logging
import mlflow
import mlflow_go_backend
# Enable debug logging to see Go backend in action
logging.basicConfig()
logging.getLogger("mlflow_go_backend").setLevel(logging.DEBUG)
# Enable Go implementation
mlflow_go_backend.enable_go()
# Get high-performance tracking store
tracking_store = mlflow.tracking._tracking_service.utils._get_store(
    "postgresql://user:pass@localhost:5432/mlflow"
)

# All operations now use Go backend
experiment = tracking_store.get_experiment(0)
runs = tracking_store.search_runs([experiment.experiment_id], "")

# Get high-performance model registry store
model_registry_store = mlflow.tracking._model_registry.utils._get_store(
    "postgresql://user:pass@localhost:5432/mlflow"
)
# Model registry operations also use Go backend
latest_versions = model_registry_store.get_latest_versions("production_model")
Performance Benchmarks
Preliminary benchmarks show significant performance gains:
API Response Times
- Search runs: 60% faster than the Python implementation
- Log metrics: 45% faster batch logging
- Get experiments: 70% faster retrieval
Concurrency
- Handles 2x the concurrent requests
- Better resource utilization under load
- Reduced memory footprint for server operations (a measurement sketch follows this list)
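To sanity-check these concurrency numbers against your own deployment, a minimal sketch like the following can help. It uses only the documented MlflowClient API; the tracking URI and experiment name are placeholders for your setup.
import time
from concurrent.futures import ThreadPoolExecutor

from mlflow.tracking import MlflowClient

client = MlflowClient(tracking_uri="http://127.0.0.1:5000")  # placeholder URI
experiment_id = client.create_experiment("go-backend-concurrency-test")


def log_run(i):
    # Each worker creates its own run and logs a handful of metrics
    run = client.create_run(experiment_id, run_name=f"worker-{i}")
    for step in range(10):
        client.log_metric(run.info.run_id, "value", float(step), step=step)
    client.set_terminated(run.info.run_id)


start = time.time()
with ThreadPoolExecutor(max_workers=8) as pool:
    list(pool.map(log_run, range(32)))
print(f"Completed 32 concurrent runs in {time.time() - start:.2f}s")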
Requirements and Limitations
Database required: the Go backend currently requires a database-backed store. File-based storage is not supported.
# Supported database URIs
mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_tracking_uri("postgresql://user:pass@host:5432/db")
mlflow.set_tracking_uri("mysql://user:pass@host:3306/db")
# Not supported yet
# mlflow.set_tracking_uri("file:///local/mlruns")
Endpoint coverage: most MLflow REST API endpoints are implemented in Go. Missing endpoints automatically fall back to the Python implementation for seamless compatibility.
Migration Guide
Step 1: Install the Go Backend
pip install mlflow-go-backend
Step 2: Update Your Server Command
# Old command
mlflow server --backend-store-uri postgresql://...
# New command
mlflow-go server --backend-store-uri postgresql://...
Step 3: Enable It in Client Code
import mlflow_go_backend
mlflow_go_backend.enable_go()
Step 4: Verify Performance
import time
import mlflow
# Benchmark your workload
start_time = time.time()
with mlflow.start_run():
    for i in range(1000):
        mlflow.log_metric(f"metric_{i}", i)
duration = time.time() - start_time
print(f"Logged 1000 metrics in {duration:.2f} seconds")
Contributing to the Go Backend
The MLflow Go backend is actively looking for contributors:
- Missing endpoints: help implement the remaining REST API endpoints in Go
- Performance optimization: improve the existing Go implementations
- Testing: add coverage for edge cases and performance scenarios
- Documentation: improve setup and usage documentation
New to Go? The project maintainers have curated learning resources to help you contribute effectively.
Community support: join the #mlflow-go channel on the official MLflow Slack to ask questions and collaborate.
Creating Your Own Flavor
Build custom flavors for specialized ML frameworks not covered by the existing options.
Flavor Structure Requirements
Every custom flavor must implement these core functions:
# Required functions for any custom flavor
def save_model(model, path, **kwargs):
    """Save model to specified path with MLflow format"""
    pass


def log_model(model, artifact_path, **kwargs):
    """Log model to current MLflow run"""
    pass


def load_model(model_uri):
    """Load model from MLflow format"""
    pass


def _load_pyfunc(path):
    """Load model as PyFunc for generic inference"""
    pass
Example: A Custom Sktime Flavor Implementation
import os
import pickle

import mlflow
import sktime
from mlflow import pyfunc
from mlflow.models import Model
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.utils.environment import _CONDA_ENV_FILE_NAME, _PYTHON_ENV_FILE_NAME

FLAVOR_NAME = "sktime"
SERIALIZATION_FORMAT_PICKLE = "pickle"
def save_model(
    sktime_model, path, conda_env=None, serialization_format=SERIALIZATION_FORMAT_PICKLE
):
    """Save sktime model in MLflow format"""
    # Validate and prepare save path
    os.makedirs(path, exist_ok=True)

    # Create MLflow model configuration
    mlflow_model = Model()

    # Save the actual model
    model_data_subpath = "model.pkl"
    model_data_path = os.path.join(path, model_data_subpath)
    with open(model_data_path, "wb") as f:
        pickle.dump(sktime_model, f)

    # Add PyFunc flavor for generic inference
    pyfunc.add_to_model(
        mlflow_model,
        loader_module="custom_sktime_flavor",  # Your module name
        model_path=model_data_subpath,
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
    )

    # Add custom flavor configuration
    mlflow_model.add_flavor(
        FLAVOR_NAME,
        pickled_model=model_data_subpath,
        sktime_version=sktime.__version__,
        serialization_format=serialization_format,
    )

    # Save MLmodel configuration file
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
def log_model(sktime_model, artifact_path, **kwargs):
    """Log sktime model to current MLflow run"""
    return Model.log(
        artifact_path=artifact_path,
        flavor=custom_sktime_flavor,  # Your module reference
        sktime_model=sktime_model,
        **kwargs,
    )


def load_model(model_uri):
    """Load sktime model from MLflow format"""
    local_model_path = mlflow.artifacts.download_artifacts(model_uri)

    # Read flavor configuration
    model_config = Model.load(os.path.join(local_model_path, MLMODEL_FILE_NAME))
    flavor_conf = model_config.flavors[FLAVOR_NAME]

    # Load the pickled model
    model_path = os.path.join(local_model_path, flavor_conf["pickled_model"])
    with open(model_path, "rb") as f:
        return pickle.load(f)
class SktimeModelWrapper:
    """PyFunc wrapper for sktime models"""

    def __init__(self, sktime_model):
        self.sktime_model = sktime_model

    def predict(self, context, model_input):
        """Predict using configuration DataFrame"""
        if len(model_input) != 1:
            raise ValueError("Configuration DataFrame must have exactly 1 row")

        config = model_input.iloc[0].to_dict()
        predict_method = config.get("predict_method", "predict")

        if predict_method == "predict":
            fh = config.get("fh", None)
            X = config.get("X", None)
            return self.sktime_model.predict(fh=fh, X=X)
        elif predict_method == "predict_interval":
            fh = config.get("fh", None)
            X = config.get("X", None)
            coverage = config.get("coverage", 0.9)
            return self.sktime_model.predict_interval(fh=fh, X=X, coverage=coverage)
        else:
            raise ValueError(f"Unsupported predict_method: {predict_method}")


def _load_pyfunc(path):
    """Load model as PyFunc"""
    model = load_model(path)
    return SktimeModelWrapper(model)
Usage Example
import mlflow
import pandas as pd
from sktime.forecasting.naive import NaiveForecaster
from sktime.datasets import load_longley

import custom_sktime_flavor  # the custom flavor module implemented above

# Train model
y, X = load_longley()
forecaster = NaiveForecaster()
forecaster.fit(y, X=X)

# Log with custom flavor
with mlflow.start_run():
    custom_sktime_flavor.log_model(
        sktime_model=forecaster, artifact_path="custom_forecaster"
    )
    model_uri = mlflow.get_artifact_uri("custom_forecaster")
# Load and use natively
loaded_model = custom_sktime_flavor.load_model(model_uri)
native_predictions = loaded_model.predict(fh=[1, 2, 3])
# Load as PyFunc for serving
loaded_pyfunc = mlflow.pyfunc.load_model(model_uri)
# Create configuration for PyFunc prediction
config_df = pd.DataFrame(
    [
        {
            "predict_method": "predict_interval",
            "fh": [1, 2, 3, 4],
            "coverage": [0.9, 0.95],
            "X": X.tail(4).values.tolist(),  # JSON serializable
        }
    ]
)
pyfunc_predictions = loaded_pyfunc.predict(config_df)
Model Serving
# Serve your custom flavor model
mlflow models serve -m runs:/RUN_ID/custom_forecaster --host 127.0.0.1 --port 5000
# Request predictions from served model
import requests
import pandas as pd
config_df = pd.DataFrame([{"predict_method": "predict", "fh": [1, 2, 3, 4]}])
response = requests.post(
    "http://127.0.0.1:5000/invocations",
    json={"dataframe_split": config_df.to_dict(orient="split")},
)
predictions = response.json()
Framework Support Matrix
By Use Case
Use Case | Framework | Installation | Key Features |
---|---|---|---|
Time series forecasting | Sktime, Orbit, StatsForecast | pip install sktime[mlflow] | Unified API, many algorithms |
Anomaly detection | PyOD | pip install mlflavors | 40+ detection algorithms |
Visualization | Altair, Plotly via VizMod | pip install mlflow-vizmod | Interactive charts as models |
Synthetic data | SDV | pip install mlflavors | Privacy-preserving data generation |
Big data ML | BigML | pip install bigmlflow | Cloud-based supervised learning |
Integration Patterns
- Standard Pattern
- Serving Pattern
- Deployment Pattern
Most community flavors follow this pattern:
import mlflow
import community_flavor
# 1. Train your model
model = SomeFrameworkModel()
model.fit(training_data)
# 2. Log with MLflow
with mlflow.start_run():
    community_flavor.log_model(
        model=model,
        artifact_path="model",
        # Framework-specific parameters
        serialization_format="pickle",
        custom_config={"param": "value"},
    )
# 3. Load for inference
loaded_model = community_flavor.load_model(model_uri)
predictions = loaded_model.predict(new_data)
# 4. Load as PyFunc for generic serving
pyfunc_model = community_flavor.pyfunc.load_model(model_uri)
generic_predictions = pyfunc_model.predict(input_dataframe)
Configuration-Based Serving for Complex Models
import pandas as pd
# Many community flavors use configuration DataFrames
# for complex inference parameters
config_df = pd.DataFrame(
    [
        {
            "predict_method": "predict_interval",  # What type of prediction
            "fh": [1, 2, 3, 4],  # Forecast horizon
            "coverage": [0.9, 0.95],  # Confidence intervals
            "X": exogenous_data.tolist(),  # Additional features
            "custom_param": "value",  # Framework-specific options
        }
    ]
)
# Use configuration with PyFunc model
pyfunc_model = community_flavor.pyfunc.load_model(model_uri)
predictions = pyfunc_model.predict(config_df)
Production Deployment Workflow
# 1. Register model in MLflow Model Registry
mlflow.register_model(
    model_uri="runs:/RUN_ID/model",
    name="production_forecaster",
)

# 2. Transition to production stage
client = mlflow.MlflowClient()
client.transition_model_version_stage(
    name="production_forecaster",
    version=1,
    stage="Production",
)
# 3. Serve model
# Option A: Local serving
mlflow models serve \
-m "models:/production_forecaster/Production" \
--host 0.0.0.0 --port 5000
# Option B: Cloud deployment (Azure ML example)
mlflow deployments create \
-t azureml \
-m "models:/production_forecaster/Production" \
--name forecaster-service
Best Practices
Development Guidelines
- Follow MLflow conventions - implement the save_model(), log_model(), and load_model() functions, add a PyFunc flavor for generic inference, and include comprehensive error handling.
- Configuration management - use single-row DataFrames for complex inference parameters, keep all parameters JSON-serializable for REST API serving, and provide sensible defaults for optional parameters.
- Testing strategy - test save/load round trips, validate PyFunc compatibility, and test model serving with example requests (see the round-trip test sketch after this list).
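As a starting point for the round-trip test, here is a minimal pytest sketch. It assumes the custom_sktime_flavor module from the example above is importable; everything else uses standard sktime and pandas APIs.
import pandas as pd
from sktime.datasets import load_airline
from sktime.forecasting.naive import NaiveForecaster

import custom_sktime_flavor  # the module implemented above


def test_save_load_round_trip(tmp_path):
    # Fit a small model on the airline dataset
    y = load_airline()
    forecaster = NaiveForecaster()
    forecaster.fit(y)

    # Save and reload through the custom flavor
    model_path = str(tmp_path / "model")
    custom_sktime_flavor.save_model(sktime_model=forecaster, path=model_path)
    loaded = custom_sktime_flavor.load_model(model_path)

    # Native predictions must survive the round trip
    expected = forecaster.predict(fh=[1, 2, 3])
    actual = loaded.predict(fh=[1, 2, 3])
    pd.testing.assert_series_equal(expected, actual)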
Performance Optimization
import os


# Efficient serialization for large models
def save_model(model, path, serialization_format="pickle"):
    if serialization_format == "joblib":
        # Use joblib for sklearn-compatible models
        import joblib

        joblib.dump(model, os.path.join(path, "model.joblib"))
    elif serialization_format == "cloudpickle":
        # Use cloudpickle for complex models with custom objects
        import cloudpickle

        with open(os.path.join(path, "model.pkl"), "wb") as f:
            cloudpickle.dump(model, f)
Error Handling
def load_model(model_uri):
    try:
        # Attempt to load model
        return _load_model_internal(model_uri)
    except Exception as e:
        raise mlflow.exceptions.MlflowException(
            f"Failed to load {FLAVOR_NAME} model. "
            f"Ensure model was saved with compatible version. Error: {str(e)}"
        )
Community Resources
Contributing a New Flavor
- Create a GitHub repository - follow the naming convention mlflow-{framework}, include comprehensive documentation, and add example notebooks.
- Package structure (a minimal setup.py sketch follows this list):
mlflow-myframework/
├── setup.py
├── README.md
├── mlflow_myframework/
│   ├── __init__.py
│   └── flavor.py
├── examples/
│   └── example_usage.ipynb
└── tests/
    └── test_flavor.py
- Documentation requirements - installation instructions, basic usage examples, an API reference, and model serving examples.
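To round out the package structure above, a minimal setup.py might look like the following. The package name, version, and dependency pins are placeholders for your framework.
# setup.py - minimal packaging sketch for the layout above (names are placeholders)
from setuptools import find_packages, setup

setup(
    name="mlflow-myframework",
    version="0.1.0",
    description="MLflow flavor for MyFramework",
    packages=find_packages(exclude=["tests", "examples"]),
    install_requires=[
        "mlflow>=2.0",
        # "myframework>=1.0",  # your framework's package
    ],
    python_requires=">=3.8",
)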
Getting Help
- MLflow discussions: GitHub Discussions.
- Community Slack: join the MLflow community workspace.
- Stack Overflow: tag questions with mlflow and the framework name.
- Framework-specific: check the individual flavor repositories for issues.
Ready to extend MLflow? Start by exploring the existing community flavors, then consider contributing your own flavor for frameworks that aren't supported yet!