跳到主要内容

模型注册表教程

在本教程中,您将探索模型注册表的全部功能——从注册模型、检查其结构,到加载特定模型版本以供进一步使用。

模型注册表

在本教程中,我们将为了简单起见,利用本地跟踪服务器和模型注册表。但是,对于生产用例,我们建议使用远程跟踪服务器

步骤 0:安装依赖项

pip install --upgrade mlflow

步骤 1:注册模型

要使用 MLflow 模型注册表,您需要将 MLflow 模型添加到其中。这可以通过以下任一命令注册给定模型来完成:

  • mlflow.<model_flavor>.log_model(registered_model_name=<model_name>):在将模型记录到跟踪服务器**时**注册模型。
  • mlflow.register_model(<model_uri>, <model_name>):在将模型记录到跟踪服务器**后**注册模型。请注意,您必须在运行此命令之前记录模型才能获得模型 URI。

MLflow 有许多模型格式。在下面的示例中,我们将利用 scikit-learn 的 RandomForestRegressor 来演示注册模型的最简单方法,但请注意,您可以利用任何受支持的模型格式。在下面的代码片段中,我们启动一个 mlflow 运行并训练一个随机森林模型。然后,我们记录一些相关的超参数、模型的均方误差 (MSE),最后记录并注册模型本身。

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

import mlflow
import mlflow.sklearn

with mlflow.start_run() as run:
X, y = make_regression(n_features=4, n_informative=2, random_state=0, shuffle=False)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)

params = {"max_depth": 2, "random_state": 42}
model = RandomForestRegressor(**params)
model.fit(X_train, y_train)

# Log parameters and metrics using the MLflow APIs
mlflow.log_params(params)

y_pred = model.predict(X_test)
mlflow.log_metrics({"mse": mean_squared_error(y_test, y_pred)})

# Log the sklearn model and register as version 1
mlflow.sklearn.log_model(
sk_model=model,
name="sklearn-model",
input_example=X_train,
registered_model_name="sk-learn-random-forest-reg-model",
)
示例输出
Successfully registered model 'sk-learn-random-forest-reg-model'.
Created version '1' of model 'sk-learn-random-forest-reg-model'.

太好了!我们已经注册了一个模型。

在继续之前,让我们强调一些重要的实现说明。

  • 要注册模型,您可以使用 mlflow.sklearn.log_model() 中的 registered_model_name 参数,或者在记录模型后调用 mlflow.register_model()。通常,我们建议前者,因为它更简洁。
  • 模型签名提供对我们的模型输入和输出的验证。log_model() 中的 input_example 会自动推断并记录签名。同样,我们建议使用此实现,因为它很简洁。

探索已注册的模型

现在我们已经记录了一个实验并注册了与该实验运行相关的模型,让我们观察一下这些信息实际上是如何存储在 MLflow UI 和我们的本地目录中的。请注意,我们也可以以编程方式获取这些信息,但出于解释目的,我们将使用 MLflow UI。

步骤 1:探索 mlruns 目录

鉴于我们将本地文件系统用作跟踪服务器和模型注册表,让我们观察一下在上一​​步中运行 Python 脚本时创建的目录结构。

在深入研究之前,需要注意的是,MLflow 的设计旨在为用户抽象复杂性,并且此目录结构仅用于说明目的。此外,在生产用例推荐的远程部署中,跟踪服务器将位于对象存储(S3、ADLS、GCS 等)上,模型注册表将位于关系数据库(PostgreSQL、MySQL 等)上。

mlruns/
├── 0/ # Experiment ID
│ ├── bc6dc2a4f38d47b4b0c99d154bbc77ad/ # Run ID
│ │ ├── metrics/
│ │ │ └── mse # Example metric file for mean squared error
│ │ ├── artifacts/ # Artifacts associated with our run
│ │ │ └── sklearn-model/
│ │ │ ├── python_env.yaml
│ │ │ ├── requirements.txt # Python package requirements
│ │ │ ├── MLmodel # MLflow model file with model metadata
│ │ │ ├── model.pkl # Serialized model file
│ │ │ ├── input_example.json
│ │ │ └── conda.yaml
│ │ ├── tags/
│ │ │ ├── mlflow.user
│ │ │ ├── mlflow.source.git.commit
│ │ │ ├── mlflow.runName
│ │ │ ├── mlflow.source.name
│ │ │ ├── mlflow.log-model.history
│ │ │ └── mlflow.source.type
│ │ ├── params/
│ │ │ ├── max_depth
│ │ │ └── random_state
│ │ └── meta.yaml
│ └── meta.yaml
├── models/ # Model Registry Directory
├── sk-learn-random-forest-reg-model/ # Registered model name
│ ├── version-1/ # Model version directory
│ │ └── meta.yaml
│ └── meta.yaml

跟踪服务器按实验 ID运行 ID进行组织,负责存储我们的实验伪影、参数和指标。另一方面,模型注册表仅存储指向我们跟踪服务器的元数据。

如您所见,支持自动日志记录的格式开箱即用地提供了大量额外信息。另请注意,即使我们没有对感兴趣的模型进行自动日志记录,我们也可以通过显式日志记录调用轻松存储此信息。

还有一个有趣的提示是,默认情况下,您有三种方法来管理模型的环境:python_env.yaml(python 虚拟环境)、requirements.txt(PyPi 需求)和 conda.yaml(conda 环境)。

好的,现在我们对记录的内容有一个非常高层次的了解,让我们使用 MLflow UI 来查看这些信息。

步骤 2:启动跟踪服务器

在您的 mlruns 文件夹所在的同一目录中,运行以下命令。

mlflow server --host 127.0.0.1 --port 8080
INFO:     Started server process [26393]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://127.0.0.1:8080 (Press CTRL+C to quit)

步骤 3:查看跟踪服务器

如果没有错误,您可以转到您的网络浏览器并访问 https://:8080 来查看 MLflow UI。

首先,我们离开实验跟踪选项卡,访问模型注册表。

Model information from the mlflow
ui.

接下来,让我们添加标签和模型版本别名以方便模型部署。您可以通过点击模型版本表中的相应 添加 链接或铅笔图标来添加或编辑标签和别名。让我们...

  1. 添加一个键为 problem_type,值为 regression 的模型版本标签。
  2. 添加一个模型版本别名为 the_best_model_ever

Model information from the mlflow
ui.

加载已注册的模型

要对已注册的模型版本执行推理,我们需要将其加载到内存中。有很多方法可以找到我们的模型版本,但最佳方法取决于您可用的信息。然而,为了快速入门,以下代码片段展示了通过特定模型 URI 从模型注册表加载模型并执行推理的最简单方法。

import mlflow.sklearn
from sklearn.datasets import make_regression

model_name = "sk-learn-random-forest-reg-model"
model_version = "latest"

# Load the model from the Model Registry
model_uri = f"models:/{model_name}/{model_version}"
model = mlflow.sklearn.load_model(model_uri)

# Generate a new dataset for prediction and predict
X_new, _ = make_regression(n_features=4, n_informative=2, random_state=0, shuffle=False)
y_pred_new = model.predict(X_new)

print(y_pred_new)

请注意,如果您不使用 sklearn,如果您的模型格式受支持,您应该使用特定的模型格式加载方法,例如 mlflow.<flavor>.load_model()。如果模型格式不受支持,您应该利用mlflow.pyfunc.load_model()。在本教程中,我们以 sklearn 为例进行演示。

示例 0:通过跟踪服务器加载

模型 URI 是序列化模型的唯一标识符。鉴于模型伪影与跟踪服务器中的实验一起存储,您可以使用以下模型 URI 来绕过模型注册表并将伪影加载到内存中。

  1. 绝对本地路径mlflow.sklearn.load_model("/Users/me/path/to/local/model")
  2. 相对本地路径mlflow.sklearn.load_model("relative/path/to/local/model")
  3. 运行 IDmlflow.sklearn.load_model(f"runs:/{mlflow_run_id}/{run_relative_path_to_model}")

但是,除非您处于记录模型的相同环境中,否则通常不会拥有上述信息。相反,您应该通过利用模型的名称和版本来加载模型。

示例 1:通过名称和版本加载

要通过 model_name 和单调递增的 model_version 将模型加载到内存中,请使用以下方法:

model = mlflow.sklearn.load_model(f"models:/{model_name}/{model_version}")

虽然此方法快速简便,但单调递增的模型版本缺乏灵活性。通常,利用模型版本别名更有效。

示例 2:通过模型版本别名加载

模型版本别名是用户定义的模型版本标识符。鉴于它们在模型注册后是可变的,它们将模型版本与使用它们的代码分离。

例如,假设我们有一个名为 production_model 的模型版本别名,它对应一个生产模型。当我们的团队构建一个更适合部署的模型时,我们不必更改我们的服务工作负载代码。相反,在 MLflow 中,我们将 production_model 别名从旧模型版本重新分配给新模型版本。这可以通过 UI 轻松完成。在 API 中,我们运行client.set_registered_model_alias,使用相同的模型名称、别名名称和**新**模型版本 ID。就是这么简单!

在上一​​页中,我们为模型添加了一个模型版本别名,但这是一个以编程方式实现的示例。

import mlflow.sklearn
from mlflow import MlflowClient

client = MlflowClient()

# Set model version alias
model_name = "sk-learn-random-forest-reg-model"
model_version_alias = "the_best_model_ever"
client.set_registered_model_alias(
model_name, model_version_alias, "1"
) # Duplicate of step in UI

# Get information about the model
model_info = client.get_model_version_by_alias(model_name, model_version_alias)
model_tags = model_info.tags
print(model_tags)

# Get the model version using a model URI
model_uri = f"models:/{model_name}@{model_version_alias}"
model = mlflow.sklearn.load_model(model_uri)

print(model)
输出
{'problem_type': 'regression'}
RandomForestRegressor(max_depth=2, random_state=42)

模型版本别名高度动态,可以对应于对您的团队有意义的任何内容。最常见的例子是部署状态。例如,假设我们在生产环境中有一个 champion 模型,但正在开发一个 challenger 模型,该模型有望超越我们的生产模型。您可以使用 championchallenger 模型版本别名来唯一标识这些模型版本,以便于访问。

就是这样!您现在应该可以轻松地...

  1. 注册模型
  2. 通过 MLflow UI 查找模型并修改标签和模型版本别名
  3. 加载已注册的模型以进行推理