跳到主要内容

加载已注册模型

要对已注册模型版本执行推理,我们需要将其加载到内存中。有多种方法可以找到我们的模型版本,但最佳方法取决于您拥有的可用信息。然而,本着快速入门的精神,下面的代码片段展示了通过特定模型 URI 从模型注册表加载模型并执行推理的最简单方法。

import mlflow.sklearn
from sklearn.datasets import make_regression

model_name = "sk-learn-random-forest-reg-model"
model_version = "latest"

# Load the model from the Model Registry
model_uri = f"models:/{model_name}/{model_version}"
model = mlflow.sklearn.load_model(model_uri)

# Generate a new dataset for prediction and predict
X_new, _ = make_regression(n_features=4, n_informative=2, random_state=0, shuffle=False)
y_pred_new = model.predict(X_new)

print(y_pred_new)

请注意,如果您不使用 sklearn,并且您的模型 Flavor(风味)受支持,则应使用特定的模型 Flavor 加载方法,例如 mlflow.<flavor>.load_model()。如果模型 Flavor 不受支持,您应利用 mlflow.pyfunc.load_model()。在本教程中,我们全程使用 sklearn 进行演示。

示例 0:通过 Tracking Server(跟踪服务器)加载

模型 URI 是序列化模型的唯一标识符。鉴于模型 Artifact(工件)与实验一起存储在 Tracking Server(跟踪服务器)中,您可以使用以下模型 URI 绕过模型注册表并将 Artifact 加载到内存中。

  1. 绝对本地路径: mlflow.sklearn.load_model("/Users/me/path/to/local/model")
  2. 相对本地路径: mlflow.sklearn.load_model("relative/path/to/local/model")
  3. 运行 ID: mlflow.sklearn.load_model(f"runs:/{mlflow_run_id}/{run_relative_path_to_model}")

然而,除非您与记录模型时处于同一环境中,否则通常不会拥有上述信息。相反,您应该通过利用模型的名称和版本来加载模型。

示例 1:通过名称和版本加载

要通过 model_name 和单调递增的 model_version 将模型加载到内存中,请使用以下方法

model = mlflow.sklearn.load_model(f"models:/{model_name}/{model_version}")

尽管此方法快速简便,但单调递增的模型版本缺乏灵活性。通常,利用模型版本别名会更高效。

示例 2:通过模型版本别名加载

模型版本别名是用户定义的模型版本标识符。鉴于它们在模型注册后是可变的,因此它们将模型版本与使用它们的代码解耦。

例如,假设我们有一个名为 production_model 的模型版本别名,对应于生产模型。当我们的团队构建了一个更好的模型并准备好部署时,我们无需更改服务工作负载代码。相反,在 MLflow 中,我们将 production_model 别名从旧模型版本重新分配给新模型版本。这可以在 UI 中轻松完成。在 API 中,我们使用相同的模型名称、别名名称和模型版本 ID 运行 client.set_registered_model_alias。就是这么简单!

在上一页中,我们为模型添加了模型版本别名,但这里是一个程序化示例。

import mlflow.sklearn
from mlflow import MlflowClient

client = MlflowClient()

# Set model version alias
model_name = "sk-learn-random-forest-reg-model"
model_version_alias = "the_best_model_ever"
client.set_registered_model_alias(
model_name, model_version_alias, "1"
) # Duplicate of step in UI

# Get information about the model
model_info = client.get_model_version_by_alias(model_name, model_version_alias)
model_tags = model_info.tags
print(model_tags)

# Get the model version using a model URI
model_uri = f"models:/{model_name}@{model_version_alias}"
model = mlflow.sklearn.load_model(model_uri)

print(model)
输出
{'problem_type': 'regression'}
RandomForestRegressor(max_depth=2, random_state=42)

模型版本别名高度动态,可以对应于对您的团队有意义的任何事物。最常见的例子是部署状态。例如,假设我们在生产环境中有一个 champion 模型,但正在开发一个有望超越我们生产模型的 challenger 模型。您可以使用 championchallenger 模型版本别名来唯一标识这些模型版本,以便于访问。

就是这样!您现在应该能够轻松地...

  1. 注册模型
  2. 通过 MLflow UI 查找模型并修改标签和模型版本别名
  3. 加载已注册模型进行推理