跳到主要内容

MLflow 5 分钟跟踪快速入门

下载此笔记本

本笔记本演示了如何使用本地 MLflow 跟踪服务器来记录、注册模型,然后将模型加载为通用 Python 函数 (pyfunc),以对 Pandas DataFrame 执行推理。

在本笔记本中,我们将使用 MLflow fluent API 来执行与 MLflow 跟踪服务器的所有交互。

import pandas as pd
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import mlflow
from mlflow.models import infer_signature

设置 MLflow 跟踪 URI

根据您运行此笔记本的位置,您初始化与 MLflow 跟踪服务器的接口的配置可能会有所不同。

在此示例中,我们使用本地运行的跟踪服务器,但还有其他选项可用(最简单的方法是使用 Databricks 免费试用版中的免费托管服务)。

请参阅此处的运行笔记本指南,以获取有关设置跟踪服务器 URI 和配置对托管或自托管 MLflow 跟踪服务器的访问权限的更多信息。

# NOTE: review the links mentioned above for guidance on connecting to a managed tracking server, such as the Databricks Managed MLflow

mlflow.set_tracking_uri(uri="http://127.0.0.1:8080")

加载训练数据并训练一个简单的模型

对于我们的快速入门,我们将使用 scikit-learn 中包含的熟悉的 iris 数据集。在拆分数据后,我们将在训练数据上训练一个简单的逻辑回归分类器,并计算保留测试数据上的一些误差指标。

请注意,此部分中唯一与 MLflow 相关的活动围绕着我们使用 param 字典来提供我们模型的超参数这一事实; 这是为了在我们准备好记录我们的模型及其相关元数据时,更容易记录这些设置。

# Load the Iris dataset
X, y = datasets.load_iris(return_X_y=True)

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the model hyperparameters
params = {"solver": "lbfgs", "max_iter": 1000, "multi_class": "auto", "random_state": 8888}

# Train the model
lr = LogisticRegression(**params)
lr.fit(X_train, y_train)

# Predict on the test set
y_pred = lr.predict(X_test)

# Calculate accuracy as a target loss metric
accuracy = accuracy_score(y_test, y_pred)

定义 MLflow 实验

为了将特定项目或想法的任何不同运行分组在一起,我们可以定义一个实验,该实验将每个迭代(运行)分组在一起。 定义与我们正在处理的内容相关的唯一名称有助于组织,并减少了稍后查找运行的工作量(搜索)。

mlflow.set_experiment("MLflow Quickstart")
<Experiment: artifact_location='mlflow-artifacts:/846578415685150448', creation_time=1699374480748, experiment_id='846578415685150448', last_update_time=1699374480748, lifecycle_stage='active', name='MLflow Quickstart', tags={}>

将模型、超参数和损失指标记录到 MLflow。

为了记录我们的模型以及拟合模型时使用的超参数,以及与验证保留数据上拟合模型相关的指标,我们启动一个运行上下文,如下所示。 在该上下文的范围内,我们调用的任何 fluent API(例如 mlflow.log_params()mlflow.sklearn.log_model())将被关联并一起记录到相同的运行中。

# Start an MLflow run
with mlflow.start_run():
# Log the hyperparameters
mlflow.log_params(params)

# Log the loss metric
mlflow.log_metric("accuracy", accuracy)

# Set a tag that we can use to remind ourselves what this run was for
mlflow.set_tag("Training Info", "Basic LR model for iris data")

# Infer the model signature
signature = infer_signature(X_train, lr.predict(X_train))

# Log the model
model_info = mlflow.sklearn.log_model(
sk_model=lr,
name="iris_model",
signature=signature,
input_example=X_train,
registered_model_name="tracking-quickstart",
)
/Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils.
warnings.warn("Setuptools is replacing distutils.")
Registered model 'tracking-quickstart' already exists. Creating a new version of this model...
2023/11/07 12:17:01 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-quickstart, version 3
Created version '3' of model 'tracking-quickstart'.

将我们保存的模型加载为 Python 函数

虽然我们可以使用 mlflow.sklearn.load_model() 将我们的模型加载回原生 scikit-learn 格式,但下面我们将模型加载为通用 Python 函数,这是在线模型服务加载此模型的方式。 但是,我们仍然可以将 pyfunc 表示形式用于批量用例,如下所示。

loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

使用我们的模型来预测 Pandas DataFrame 上的 iris 类类型

predictions = loaded_model.predict(X_test)

iris_feature_names = datasets.load_iris().feature_names

# Convert X_test validation feature data to a Pandas DataFrame
result = pd.DataFrame(X_test, columns=iris_feature_names)

# Add the actual classes to the DataFrame
result["actual_class"] = y_test

# Add the model predictions to the DataFrame
result["predicted_class"] = predictions

result[:4]
萼片长度 (cm) 萼片宽度 (cm) 花瓣长度 (cm) 花瓣宽度 (cm) 实际类别 预测类别
0 6.1 2.8 4.7 1.2 1 1
1 5.7 3.8 1.7 0.3 0 0
2 7.7 2.6 6.9 2.3 2 2
3 6.0 2.9 4.5 1.5 1 1