注册模型
在本教程中,为简单起见,我们将使用本地追踪服务器和模型注册表。然而,对于生产用例,我们推荐使用远程追踪服务器。
步骤 0: 安装依赖
pip install --upgrade mlflow
步骤 1: 注册模型
要使用 MLflow 模型注册表,您需要将 MLflow 模型添加到其中。这可以通过以下任一命令注册给定模型来完成:
mlflow.<model_flavor>.log_model(registered_model_name=<model_name>)
: 在将模型记录到追踪服务器的**同时**注册模型。mlflow.register_model(<model_uri>, <model_name>)
: 在将模型记录到追踪服务器**之后**注册模型。请注意,在运行此命令获取模型 URI 之前,您必须先记录模型。
MLflow 有多种模型风格(model flavors)。在下面的示例中,我们将利用 scikit-learn 的 RandomForestRegressor 来演示注册模型的最简单方法,但请注意,您可以使用任何支持的模型风格。在下面的代码片段中,我们启动一个 MLflow 运行并训练一个随机森林模型。然后我们记录一些相关的超参数、模型的均方误差 (MSE),最后记录并注册模型本身。
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import mlflow
import mlflow.sklearn
with mlflow.start_run() as run:
X, y = make_regression(n_features=4, n_informative=2, random_state=0, shuffle=False)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
params = {"max_depth": 2, "random_state": 42}
model = RandomForestRegressor(**params)
model.fit(X_train, y_train)
# Log parameters and metrics using the MLflow APIs
mlflow.log_params(params)
y_pred = model.predict(X_test)
mlflow.log_metrics({"mse": mean_squared_error(y_test, y_pred)})
# Log the sklearn model and register as version 1
mlflow.sklearn.log_model(
sk_model=model,
artifact_path="sklearn-model",
input_example=X_train,
registered_model_name="sk-learn-random-forest-reg-model",
)
示例输出
Successfully registered model 'sk-learn-random-forest-reg-model'.
Created version '1' of model 'sk-learn-random-forest-reg-model'.
太好了!我们已经注册了一个模型。
继续之前,让我们强调一些重要的实现注意事项。
- 要注册模型,您可以在
mlflow.sklearn.log_model()
中使用registered_model_name
参数,或者在记录模型后调用mlflow.register_model()
。通常,我们建议使用前者,因为它更简洁。 - 模型签名 (Model Signatures) 为我们的模型输入和输出提供验证。
log_model()
中的input_example
会自动推断并记录签名。同样,我们建议使用这种实现方式,因为它很简洁。