MLflow 模型服务
MLflow 凭借其全面的服务能力,将您训练好的模型转化为可用于生产环境的推理服务器。通过标准化的 REST API,可在本地、云端或通过托管端点进行部署。
REST API 端点
自动生成标准化的 REST 端点,用于模型推理,并具有一致的请求/响应格式。
多框架支持
通过 MLflow 的 flavor 系统,以统一的部署模式为任何 ML 框架的模型提供服务。
自定义应用程序
构建具有自定义逻辑、预处理和业务规则的复杂服务应用程序。
可扩展部署
部署到各种目标,从本地开发服务器到云平台和 Kubernetes 集群。
快速入门
通过以下简单步骤,在几分钟内启动您的模型服务
- 1. 提供模型服务
- 2. 进行预测
选择您的服务方式
# Serve a logged model
mlflow models serve -m "models:/<model-id>" -p 5000
# Serve a registered model
mlflow models serve -m "models:/<model-name>/<model-version>" -p 5000
# Serve a model from local path
mlflow models serve -m ./path/to/model -p 5000
您的模型将在 https://:5000
可用
通过 HTTP 发送预测请求
curl -X POST https://:5000/invocations \
-H "Content-Type: application/json" \
-d '{"inputs": [[1, 2, 3, 4]]}'
使用 Python
import requests
import json
data = {
"dataframe_split": {
"columns": ["feature1", "feature2", "feature3", "feature4"],
"data": [[1, 2, 3, 4]],
}
}
response = requests.post(
"https://:5000/invocations",
headers={"Content-Type": "application/json"},
data=json.dumps(data),
)
print(response.json())
模型服务如何工作
MLflow 通过精心编排的过程将您训练好的模型转换为可用于生产环境的 HTTP 服务器,该过程涵盖了从模型加载到请求处理的所有环节。
服务器启动和模型加载
当您运行 mlflow models serve
时,MLflow 首先分析您的模型元数据以确定如何加载它。每个模型都包含一个 MLmodel
文件,该文件指定它使用的“flavor”——无论是 scikit-learn、PyTorch、TensorFlow 还是自定义 PyFunc 模型。
MLflow 将模型工件下载到本地目录,并创建一个带有标准化端点的 FastAPI 服务器。服务器使用适当的特定于 flavor 的加载逻辑加载您的模型。例如,scikit-learn 模型使用 pickle 加载,而 PyTorch 模型加载其状态字典和模型类。
服务器公开了四个关键端点
POST /invocations
- 主要的预测端点GET /ping
和GET /health
- 用于监控的健康检查GET /version
- 返回服务器和模型信息
请求处理管道
当预测请求到达 /invocations
时,MLflow 会通过几个验证和转换步骤对其进行处理
输入格式检测:MLflow 自动检测您正在使用的输入格式。它支持多种格式以适应不同的用例
dataframe_split
:带有独立列和数据数组的 Pandas DataFramedataframe_records
:表示行的字典列表instances
:用于个体预测的 TensorFlow Serving 格式inputs
:用于更复杂输入的命名张量格式
Schema 验证:如果您的模型包含签名(输入/输出 schema),MLflow 会根据它验证传入数据。这会在数据到达您的模型之前捕获类型不匹配和缺失列。
参数提取:MLflow 将预测数据与可选参数分开。像语言模型的 temperature
或分类器的 threshold
等参数被提取并单独传递给支持它们的模型。
模型预测和响应
一旦输入经过验证和格式化,MLflow 就会调用模型的 predict()
方法。框架会自动检测您的模型是否接受参数并适当地调用它
# For models that accept parameters
raw_predictions = model.predict(data, params=params)
# For traditional models
raw_predictions = model.predict(data)
然后 MLflow 将预测结果序列化回 JSON,处理各种数据类型,包括 NumPy 数组、pandas DataFrames 和 Python 列表。响应格式取决于您的输入格式——传统请求被包装在 predictions
对象中,而 LLM 风格的请求则返回未包装的结果。
Flavor 系统
MLflow 的 flavor 系统是使服务在不同 ML 框架中保持一致的关键。每个 flavor 都实现了特定于框架的加载和预测逻辑,同时公开了统一的接口。
当您使用 mlflow.sklearn.log_model()
或 mlflow.pytorch.log_model()
记录模型时,MLflow 会创建一个特定于 flavor 的表示和一个 PyFunc 包装器。PyFunc 包装器提供了服务层期望的标准化 predict()
接口,而 flavor 则处理特定于框架的细节,如张量操作或数据预处理。
这种架构意味着您可以使用相同的服务命令和 API 来为 scikit-learn、PyTorch、TensorFlow 和自定义模型提供服务。
错误处理和调试
MLflow 的服务基础设施包括全面的错误处理,以帮助您调试问题
- Schema 错误:关于数据类型不匹配或缺失列的详细消息
- 格式错误:当输入格式不正确或不明确时提供清晰的指导
- 模型错误:来自模型预测代码的完整堆栈跟踪
- 服务器错误:超时和资源相关错误处理
服务器记录所有请求和错误,从而更容易诊断生产问题。
输入格式示例
以下是 MLflow 接受的主要输入格式
// dataframe_split format
{
"dataframe_split": {
"columns": ["feature1", "feature2", "feature3"],
"data": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
}
}
// dataframe_records format
{
"dataframe_records": [
{"feature1": 1.0, "feature2": 2.0, "feature3": 3.0},
{"feature1": 4.0, "feature2": 5.0, "feature3": 6.0}
]
}
// instances format (for simple models)
{
"instances": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
}
所有格式都返回一致的响应结构,其中包含您的预测结果以及模型提供的任何附加元数据。
关键实现概念
- 模型准备
- 服务器配置
- 高级模式
准备您的模型以成功提供服务
- 模型签名:定义输入/输出 schema 以进行自动请求验证
- 环境管理:捕获依赖项以确保可重现部署
- 模型注册表:使用别名进行无缝生产更新
- 元数据:包含相关上下文以进行调试和监控
import mlflow
from mlflow.models.signature import infer_signature
from mlflow.tracking import MlflowClient
# Log model with comprehensive serving metadata
signature = infer_signature(X_train, model.predict(X_train))
mlflow.sklearn.log_model(
sk_model=model,
name="my_model",
signature=signature,
registered_model_name="production_model",
input_example=X_train[:5], # Visible example for the MLflow UI
)
# Use aliases for production deployment
client = MlflowClient()
client.set_registered_model_alias(
name="production_model", alias="production", version="1"
)
配置您的服务基础设施以获得最佳性能
- 请求处理:设置适当的超时和批处理大小
- 资源分配:根据模型复杂性和负载配置工作器
- 输入格式:为您的数据类型选择正确的格式
- 错误处理:实施适当的日志记录和监控
# Configure server for production workloads
export MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT=60
export GUNICORN_CMD_ARGS="--timeout 60 --workers 4"
# Serve with optimal settings
mlflow models serve \
--model-uri models:/my_model@production \
--port 5000 \
--env-manager local # For production, use conda or virtualenv
使用自定义 PyFunc 模型实现高级服务模式
- 预处理逻辑:在模型内部处理数据转换
- 多模型集成:组合来自多个模型的预测
- 业务逻辑:集成验证和后处理规则
- 性能优化:批处理和缓存策略
import joblib
import mlflow
import pandas as pd
import numpy as np
from typing import Dict, Any
class EnsembleModel(mlflow.pyfunc.PythonModel):
def load_context(self, context):
"""Load multiple models for ensemble prediction"""
self.model_a = joblib.load(context.artifacts["model_a"])
self.model_b = joblib.load(context.artifacts["model_b"])
self.preprocessor = joblib.load(context.artifacts["preprocessor"])
# Load ensemble weights from config
self.weights = context.model_config.get("weights", [0.5, 0.5])
def predict(self, model_input: pd.DataFrame) -> np.ndarray:
"""Combine predictions from multiple models"""
# Preprocess input
processed = self.preprocessor.transform(model_input)
# Get predictions from both models
pred_a = self.model_a.predict(processed)
pred_b = self.model_b.predict(processed)
# Weighted average ensemble
ensemble_pred = self.weights[0] * pred_a + self.weights[1] * pred_b
return ensemble_pred
# Log ensemble model with artifacts
artifacts = {
"model_a": "path/to/model_a.pkl",
"model_b": "path/to/model_b.pkl",
"preprocessor": "path/to/preprocessor.pkl",
}
mlflow.pyfunc.log_model(
name="ensemble_model",
python_model=EnsembleModel(),
artifacts=artifacts,
model_config={"weights": [0.6, 0.4]},
pip_requirements=["scikit-learn", "pandas", "numpy"],
)
完整示例:从训练到生产
遵循此分步指南,从模型训练到部署 REST API
- 1. 训练与记录
- 2. 提升模型
- 3. 启动服务器
- 4. 进行预测
使用自动日志记录训练一个简单模型
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import pandas as pd
# Load sample data
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# Enable sklearn autologging with model registration
mlflow.sklearn.autolog(registered_model_name="iris_classifier")
# Train model - MLflow automatically logs everything
with mlflow.start_run() as run:
model = RandomForestClassifier(n_estimators=10, random_state=42)
model.fit(X_train, y_train)
# Autologging automatically captures:
# - Model artifacts
# - Training parameters (n_estimators, random_state, etc.)
# - Training metrics (score on training data)
# - Model signature (inferred from training data)
# - Input example
# Optional: Log additional custom metrics
accuracy = model.score(X_test, y_test)
mlflow.log_metric("test_accuracy", accuracy)
print(f"Run ID: {run.info.run_id}")
print("Model automatically logged and registered!")
为生产设置模型别名
from mlflow.tracking import MlflowClient
client = MlflowClient()
# Get the latest registered version (autologging creates version 1)
model_version = client.get_registered_model("iris_classifier").latest_versions[0]
# Set production alias (replaces deprecated stages)
client.set_registered_model_alias(
name="iris_classifier", alias="production", version=model_version.version
)
print(f"Model version {model_version.version} tagged as 'production'")
# Model URI for serving (using alias)
model_uri = "models:/iris_classifier@production"
print(f"Production model URI: {model_uri}")
提供已注册的模型服务
# Serve using model alias (MLflow 3.x way)
mlflow models serve \
--model-uri "models:/iris_classifier@production" \
--port 5000 \
--env-manager local
# Server will start at https://:5000
# Available endpoints:
# - POST /invocations (predictions)
# - GET /ping (health check)
# - GET /version (model info)
替代服务方法
# Serve by specific version number
mlflow models serve \
--model-uri "models:/iris_classifier/1" \
--port 5000
# Serve from run URI
mlflow models serve \
--model-uri "runs:/<run-id>/model" \
--port 5000
向已提供服务的模型发送请求
import requests
import json
import pandas as pd
# Prepare test data (same format as training)
test_data = {
"dataframe_split": {
"columns": [
"sepal length (cm)",
"sepal width (cm)",
"petal length (cm)",
"petal width (cm)",
],
"data": [
[5.1, 3.5, 1.4, 0.2], # setosa
[6.2, 2.9, 4.3, 1.3], # versicolor
[7.3, 2.9, 6.3, 1.8], # virginica
],
}
}
# Make prediction request
response = requests.post(
"https://:5000/invocations",
headers={"Content-Type": "application/json"},
data=json.dumps(test_data),
)
# Parse response
if response.status_code == 200:
predictions = response.json()
print("Predictions:", predictions)
# Output: {"predictions": [0, 1, 2]}
else:
print(f"Error: {response.status_code}, {response.text}")
# Health check
health = requests.get("https://:5000/ping")
print("Health status:", health.status_code) # Should be 200
# Model info
info = requests.get("https://:5000/version")
print("Model version info:", info.json())
后续步骤
准备好构建更高级的服务应用程序了吗?探索这些专业主题
每个部分中的示例都设计为实用且即用。从上面的快速入门开始,然后探索符合您部署需求的用例。