模型注册表教程
在本教程中探索模型注册表的全部功能 — 从注册模型和检查其结构,到加载特定的模型版本以供进一步使用。
模型注册表
在本教程中,为了简单起见,我们将利用本地跟踪服务器和模型注册表。但是,对于生产用例,我们建议使用远程跟踪服务器。
步骤 0:安装依赖项
pip install --upgrade mlflow
步骤 1:注册模型
要使用 MLflow 模型注册表,您需要将您的 MLflow 模型添加到其中。这可以通过使用以下命令之一注册给定的模型来完成。
mlflow.<model_flavor>.log_model(registered_model_name=<model_name>)
: 在将模型记录到跟踪服务器的同时注册模型。mlflow.register_model(<model_uri>, <model_name>)
: 在将模型记录到跟踪服务器之后注册模型。 请注意,您必须先记录模型才能运行此命令以获取模型 URI。
MLflow 有很多模型风格。 在下面的示例中,我们将利用 scikit-learn 的 RandomForestRegressor 来演示注册模型的最简单方法,但请注意,您可以利用任何受支持的模型风格。 在下面的代码片段中,我们启动一个 mlflow 运行并训练一个随机森林模型。 然后,我们记录一些相关的超参数、模型的均方误差 (MSE),最后记录并注册模型本身。
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import mlflow
import mlflow.sklearn
with mlflow.start_run() as run:
X, y = make_regression(n_features=4, n_informative=2, random_state=0, shuffle=False)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
params = {"max_depth": 2, "random_state": 42}
model = RandomForestRegressor(**params)
model.fit(X_train, y_train)
# Log parameters and metrics using the MLflow APIs
mlflow.log_params(params)
y_pred = model.predict(X_test)
mlflow.log_metrics({"mse": mean_squared_error(y_test, y_pred)})
# Log the sklearn model and register as version 1
mlflow.sklearn.log_model(
sk_model=model,
name="sklearn-model",
input_example=X_train,
registered_model_name="sk-learn-random-forest-reg-model",
)
Successfully registered model 'sk-learn-random-forest-reg-model'.
Created version '1' of model 'sk-learn-random-forest-reg-model'.
太棒了! 我们已经注册了一个模型。
在继续之前,让我们强调一些重要的实现说明。
- 要注册模型,您可以利用
mlflow.sklearn.log_model()
中的registered_model_name
参数或在记录模型后调用mlflow.register_model()
。 通常,我们建议前者,因为它更简洁。 - 模型签名为我们的模型输入和输出提供验证。
log_model()
中的input_example
会自动推断并记录签名。 同样,我们建议使用此实现,因为它很简洁。
探索注册的模型
现在我们已经记录了一个实验并注册了与该实验运行关联的模型,让我们观察一下这些信息实际上是如何存储在 MLflow UI 和本地目录中的。 请注意,我们也可以通过编程方式获取此信息,但为了解释的目的,我们将使用 MLflow UI。
步骤 1:探索 mlruns
目录
鉴于我们正在使用本地文件系统作为我们的跟踪服务器和模型注册表,让我们观察一下在先前步骤中运行 python 脚本时创建的目录结构。
在深入研究之前,重要的是要注意 MLflow 旨在从用户那里抽象出复杂性,并且此目录结构仅用于说明目的。 此外,在远程部署中,建议用于生产用例,跟踪服务器将位于对象存储(S3、ADLS、GCS 等)上,而模型注册表将位于关系数据库(PostgreSQL、MySQL 等)上。
mlruns/
├── 0/ # Experiment ID
│ ├── bc6dc2a4f38d47b4b0c99d154bbc77ad/ # Run ID
│ │ ├── metrics/
│ │ │ └── mse # Example metric file for mean squared error
│ │ ├── artifacts/ # Artifacts associated with our run
│ │ │ └── sklearn-model/
│ │ │ ├── python_env.yaml
│ │ │ ├── requirements.txt # Python package requirements
│ │ │ ├── MLmodel # MLflow model file with model metadata
│ │ │ ├── model.pkl # Serialized model file
│ │ │ ├── input_example.json
│ │ │ └── conda.yaml
│ │ ├── tags/
│ │ │ ├── mlflow.user
│ │ │ ├── mlflow.source.git.commit
│ │ │ ├── mlflow.runName
│ │ │ ├── mlflow.source.name
│ │ │ ├── mlflow.log-model.history
│ │ │ └── mlflow.source.type
│ │ ├── params/
│ │ │ ├── max_depth
│ │ │ └── random_state
│ │ └── meta.yaml
│ └── meta.yaml
├── models/ # Model Registry Directory
├── sk-learn-random-forest-reg-model/ # Registered model name
│ ├── version-1/ # Model version directory
│ │ └── meta.yaml
│ └── meta.yaml
跟踪服务器按实验 ID 和运行 ID 组织,负责存储我们的实验工件、参数和指标。 另一方面,模型注册表仅存储元数据,其中包含指向我们的跟踪服务器的指针。
正如您所看到的,支持自动日志记录的风格提供了许多额外的开箱即用的信息。 另请注意,即使我们的感兴趣的模型没有自动日志记录,我们也可以通过显式日志记录调用轻松存储此信息。
另一个有趣的亮点是,默认情况下,您可以通过三种方式来管理模型的环境:python_env.yaml
(python 虚拟环境)、requirements.txt
(PyPi requirements) 和 conda.yaml
(conda 环境)。
好的,现在我们对记录的内容有了非常高的理解,让我们使用 MLflow UI 来查看此信息。
步骤 2:启动跟踪服务器
在与您的 mlruns
文件夹相同的目录中,运行以下命令。
mlflow server --host 127.0.0.1 --port 8080
[2024-02-05 12:43:10 -0500] [26393] [INFO] Starting gunicorn 20.1.0
[2024-02-05 12:43:10 -0500] [26393] [INFO] Listening at: http://127.0.0.1:8080 (26393)
[2024-02-05 12:43:10 -0500] [26393] [INFO] Using worker: sync
[2024-02-05 12:43:10 -0500] [26414] [INFO] Booting worker with pid: 26414
[2024-02-05 12:43:11 -0500] [26416] [INFO] Booting worker with pid: 26416
[2024-02-05 12:43:11 -0500] [26428] [INFO] Booting worker with pid: 26428
[2024-02-05 12:43:11 -0500] [26437] [INFO] Booting worker with pid: 26437
步骤 3:查看跟踪服务器
假设没有错误,您可以转到您的 Web 浏览器并访问 https://:8080
以查看 MLflow UI。
首先,让我们离开实验跟踪选项卡并访问模型注册表。
接下来,让我们添加标签和模型版本别名,以促进模型部署。 您可以通过单击模型版本表中的相应 Add
链接或铅笔图标来添加或编辑标签和别名。 让我们...
- 添加一个模型版本标签,其键为
problem_type
,值为regression
。 - 添加一个模型版本别名
the_best_model_ever
。
加载注册的模型
要对注册的模型版本执行推理,我们需要将其加载到内存中。 有很多方法可以找到我们的模型版本,但最佳方法取决于您拥有的信息。 但是,本着快速入门的精神,下面的代码片段展示了通过特定的模型 URI 从模型注册表加载模型并执行推理的最简单方法。
import mlflow.sklearn
from sklearn.datasets import make_regression
model_name = "sk-learn-random-forest-reg-model"
model_version = "latest"
# Load the model from the Model Registry
model_uri = f"models:/{model_name}/{model_version}"
model = mlflow.sklearn.load_model(model_uri)
# Generate a new dataset for prediction and predict
X_new, _ = make_regression(n_features=4, n_informative=2, random_state=0, shuffle=False)
y_pred_new = model.predict(X_new)
print(y_pred_new)
请注意,如果您没有使用 sklearn,如果支持您的模型风格,您应该使用特定的模型风格加载方法,例如 mlflow.<flavor>.load_model()
。 如果不支持模型风格,您应该利用mlflow.pyfunc.load_model()
。 在本教程中,我们使用 sklearn 作为演示目的。
示例 0:通过跟踪服务器加载
模型 URI 是序列化模型的唯一标识符。 鉴于模型工件与实验一起存储在跟踪服务器中,您可以使用以下模型 URI 绕过模型注册表并将工件加载到内存中。
- 绝对本地路径:
mlflow.sklearn.load_model("/Users/me/path/to/local/model")
- 相对本地路径:
mlflow.sklearn.load_model("relative/path/to/local/model")
- 运行 ID:
mlflow.sklearn.load_model(f"runs:/{mlflow_run_id}/{run_relative_path_to_model}")
但是,除非您与记录模型的环境相同,否则您通常不会拥有以上信息。 相反,您应该通过利用模型的名称和版本来加载模型。
示例 1:通过名称和版本加载
要通过 model_name
和单调递增的 model_version
将模型加载到内存中,请使用以下方法
model = mlflow.sklearn.load_model(f"models:/{model_name}/{model_version}")
虽然此方法快速且简单,但单调递增的模型版本缺乏灵活性。 通常,利用模型版本别名会更有效。
示例 2:通过模型版本别名加载
模型版本别名是模型版本的用户定义标识符。 鉴于它们在模型注册后是可变的,因此它们将模型版本与使用它们的代码分离。
例如,假设我们有一个名为 production_model
的模型版本别名,对应于一个生产模型。 当我们的团队构建了一个更好的模型并准备好部署时,我们不必更改我们的服务工作负载代码。 相反,在 MLflow 中,我们将 production_model
别名从旧模型版本重新分配给新模型版本。 这可以在 UI 中简单地完成。 在 API 中,我们使用相同的模型名称、别名名称和新模型版本 ID 运行client.set_registered_model_alias。 就这么简单!
在上一页中,我们向我们的模型添加了一个模型版本别名,但这是一个编程示例。
import mlflow.sklearn
from mlflow import MlflowClient
client = MlflowClient()
# Set model version alias
model_name = "sk-learn-random-forest-reg-model"
model_version_alias = "the_best_model_ever"
client.set_registered_model_alias(
model_name, model_version_alias, "1"
) # Duplicate of step in UI
# Get information about the model
model_info = client.get_model_version_by_alias(model_name, model_version_alias)
model_tags = model_info.tags
print(model_tags)
# Get the model version using a model URI
model_uri = f"models:/{model_name}@{model_version_alias}"
model = mlflow.sklearn.load_model(model_uri)
print(model)
{'problem_type': 'regression'}
RandomForestRegressor(max_depth=2, random_state=42)
模型版本别名是高度动态的,可以对应于对您的团队有意义的任何内容。 最常见的例子是部署状态。 例如,假设我们在生产中有一个 champion
模型,但正在开发 challenger
模型,该模型有望胜过我们的生产模型。 您可以使用 champion
和 challenger
模型版本别名来唯一标识这些模型版本,以便于访问。
就这样! 您现在应该可以轻松地...
- 注册模型
- 通过 MLflow UI 查找模型并修改标签和模型版本别名
- 加载注册的模型以进行推理