跳到主要内容

社区模型风格

MLflow 活跃的社区已为专业 ML 框架和用例开发了多种风格,将 MLflow 的功能扩展到内置风格之外。这些由社区维护的包可以与特定领域工具无缝集成,用于时间序列预测、异常检测、可视化等等。

快速入门

安装社区风格

大多数社区风格可通过 PyPI 获取

# Time series forecasting
pip install mlflow[sktime]
pip install mlflavors

# Visualization and plotting
pip install mlflow-vizmod

# Big data and cloud platforms
pip install bigmlflow
pip install mlflow[aliyun-oss]

基本使用模式

所有社区风格都遵循 MLflow 的标准接口

import mlflow
import community_flavor # Replace with actual flavor

# Train your model
model = SomeModel()
model.fit(data)

# Log with MLflow
with mlflow.start_run():
community_flavor.log_model(model, "model_path")

# Load for inference
loaded_model = community_flavor.load_model("model_uri")
predictions = loaded_model.predict(new_data)

Sktime

用于时间序列预测、分类和转换的统一接口。

pip install sktime[mlflow]
import pandas as pd
from sktime.datasets import load_airline
from sktime.forecasting.arima import AutoARIMA
from sktime.utils import mlflow_sktime

# Load data and train model
airline = load_airline()
model = AutoARIMA(sp=12, d=0, max_p=2, max_q=2, suppress_warnings=True)
model.fit(airline, fh=[1, 2, 3])

# Save and load with MLflow
mlflow_sktime.save_model(sktime_model=model, path="model")
loaded_model = mlflow_sktime.load_model(model_uri="model")

# Make predictions
predictions = loaded_model.predict()
print(predictions)

# Load as PyFunc for serving
loaded_pyfunc = mlflow_sktime.pyfunc.load_model(model_uri="model")
pyfunc_predictions = loaded_pyfunc.predict(pd.DataFrame())

MLflavors 包

在一个包中支持多种时间序列和 ML 框架。

pip install mlflavors

支持的框架

框架类别示例用例
Orbit时间序列贝叶斯预测
StatsForecast时间序列统计模型
PyOD异常检测离群值检测
SDV合成数据隐私保护数据生成

PyOD 异常检测示例

import mlflow
from pyod.models.knn import KNN
from pyod.utils.data import generate_data
import mlflavors

# Generate synthetic data
contamination = 0.1
n_train, n_test = 200, 100
X_train, X_test, _, y_test = generate_data(
n_train=n_train, n_test=n_test, contamination=contamination
)

with mlflow.start_run():
# Train KNN detector
clf = KNN()
clf.fit(X_train)

# Log model
mlflavors.pyod.log_model(
pyod_model=clf, artifact_path="anomaly_detector", serialization_format="pickle"
)

# Evaluate
scores = clf.decision_function(X_test)
mlflow.log_metric("mean_anomaly_score", scores.mean())

服务 PyOD 模型

# Load as PyFunc
loaded_pyfunc = mlflavors.pyod.pyfunc.load_model(model_uri="model_uri")

# Create configuration for inference
import pandas as pd

predict_conf = pd.DataFrame([{"X": X_test, "predict_method": "decision_function"}])

anomaly_scores = loaded_pyfunc.predict(predict_conf)[0]

框架支持矩阵

按用例

用例框架安装主要特性
时间序列预测Sktime, Orbit, StatsForecastpip install sktime[mlflow]统一 API,多种算法
异常检测PyODpip install mlflavors40+ 检测算法
可视化通过 VizMod 支持 Altair, Plotlypip install mlflow-vizmod交互式图表作为模型
合成数据SDVpip install mlflavors隐私保护数据生成
大数据机器学习BigMLpip install bigmlflow基于云的监督学习

集成模式

大多数社区风格遵循此模式

import mlflow
import community_flavor

# 1. Train your model
model = SomeFrameworkModel()
model.fit(training_data)

# 2. Log with MLflow
with mlflow.start_run():
community_flavor.log_model(
model=model,
artifact_path="model",
# Framework-specific parameters
serialization_format="pickle",
custom_config={"param": "value"},
)

# 3. Load for inference
loaded_model = community_flavor.load_model(model_uri)
predictions = loaded_model.predict(new_data)

# 4. Load as PyFunc for generic serving
pyfunc_model = community_flavor.pyfunc.load_model(model_uri)
generic_predictions = pyfunc_model.predict(input_dataframe)

最佳实践

开发指南

  1. 遵循 MLflow 约定 - 实现 save_model()log_model()load_model() 函数。为通用推理添加 PyFunc 风格。包含全面的错误处理。

  2. 配置管理 - 对复杂推理参数使用单行 DataFrames。使所有参数 JSON 可序列化,以便进行 REST API 服务。为可选参数提供合理的默认值。

  3. 测试策略 - 测试保存/加载往返功能。验证 PyFunc 兼容性。使用示例请求测试模型服务。

性能优化

# Efficient serialization for large models
def save_model(model, path, serialization_format="pickle"):
if serialization_format == "joblib":
# Use joblib for sklearn-compatible models
import joblib

joblib.dump(model, os.path.join(path, "model.joblib"))
elif serialization_format == "cloudpickle":
# Use cloudpickle for complex models with custom objects
import cloudpickle

with open(os.path.join(path, "model.pkl"), "wb") as f:
cloudpickle.dump(model, f)

错误处理

def load_model(model_uri):
try:
# Attempt to load model
return _load_model_internal(model_uri)
except Exception as e:
raise mlflow.exceptions.MlflowException(
f"Failed to load {FLAVOR_NAME} model. "
f"Ensure model was saved with compatible version. Error: {str(e)}"
)

社区资源

贡献新风格

  1. 创建 GitHub 仓库 - 遵循命名约定:mlflow-{framework}。包含全面的文档。添加示例笔记本。

  2. 包结构

    mlflow-myframework/
    ├── setup.py
    ├── README.md
    ├── mlflow_myframework/
    │ ├── __init__.py
    │ └── flavor.py
    ├── examples/
    │ └── example_usage.ipynb
    └── tests/
    └── test_flavor.py
  3. 文档要求 - 安装说明。基本使用示例。API 参考。模型服务示例。

获取帮助

MLflow 讨论: GitHub 讨论社区 Slack:加入 MLflow 社区工作区。 Stack Overflow:mlflow 和框架名称标记问题。 特定于框架:检查各个风格仓库以获取问题。


准备好扩展 MLflow 了吗? 从探索现有社区风格开始,然后考虑为尚未支持的框架贡献您自己的风格!