跳到主要内容

MLflow 5分钟追踪快速入门

下载此笔记本

本notebook演示了如何使用本地MLflow Tracking服务器来记录、注册,然后将模型作为通用Python函数(pyfunc)加载,以在Pandas DataFrame上执行推理。

在本notebook中,我们将使用MLflow fluent API与MLflow Tracking服务器进行所有交互。

import pandas as pd
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import mlflow
from mlflow.models import infer_signature

设置MLflow Tracking URI

根据您运行此notebook的位置,您初始化与MLflow Tracking服务器接口的配置可能会有所不同。

在此示例中,我们使用本地运行的追踪服务器,但也有其他选项可用(最简单的是使用Databricks免费试用版中的免费托管服务)。

有关设置追踪服务器URI以及配置对托管或自管理MLflow追踪服务器的访问的更多信息,请参阅此处运行notebook的指南

# NOTE: review the links mentioned above for guidance on connecting to a managed tracking server, such as the Databricks Managed MLflow

mlflow.set_tracking_uri(uri="http://127.0.0.1:8080")

加载训练数据并训练一个简单模型

对于我们的快速入门,我们将使用scikit-learn中包含的熟悉的iris数据集。在数据分割之后,我们将在训练数据上训练一个简单的逻辑回归分类器,并计算我们的保留测试数据的一些错误指标。

请注意,此部分中唯一与MLflow相关的活动是我们使用`param`字典来提供模型的超参数;这是为了在我们准备记录模型及其相关元数据时,更容易记录这些设置。

# Load the Iris dataset
X, y = datasets.load_iris(return_X_y=True)

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the model hyperparameters
params = {"solver": "lbfgs", "max_iter": 1000, "multi_class": "auto", "random_state": 8888}

# Train the model
lr = LogisticRegression(**params)
lr.fit(X_train, y_train)

# Predict on the test set
y_pred = lr.predict(X_test)

# Calculate accuracy as a target loss metric
accuracy = accuracy_score(y_test, y_pred)

定义MLflow实验

为了将特定项目或想法的任何不同运行分组在一起,我们可以定义一个实验,它将每个迭代(运行)组合在一起。定义一个与我们正在进行的工作相关的唯一名称有助于组织,并减少以后查找运行的工作量(搜索)。

mlflow.set_experiment("MLflow Quickstart")
<Experiment: artifact_location='mlflow-artifacts:/846578415685150448', creation_time=1699374480748, experiment_id='846578415685150448', last_update_time=1699374480748, lifecycle_stage='active', name='MLflow Quickstart', tags={}>

将模型、超参数和损失指标记录到MLflow。

为了记录我们的模型以及拟合模型时使用的超参数,以及与在保留数据上验证拟合模型相关的指标,我们启动了一个运行上下文,如下所示。在该上下文的范围内,我们调用的任何fluent API(例如`mlflow.log_params()`或`mlflow.sklearn.log_model()`)都将关联并一起记录到同一运行中。

# Start an MLflow run
with mlflow.start_run():
# Log the hyperparameters
mlflow.log_params(params)

# Log the loss metric
mlflow.log_metric("accuracy", accuracy)

# Set a tag that we can use to remind ourselves what this run was for
mlflow.set_tag("Training Info", "Basic LR model for iris data")

# Infer the model signature
signature = infer_signature(X_train, lr.predict(X_train))

# Log the model
model_info = mlflow.sklearn.log_model(
sk_model=lr,
name="iris_model",
signature=signature,
input_example=X_train,
registered_model_name="tracking-quickstart",
)
/Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils.
warnings.warn("Setuptools is replacing distutils.")
Registered model 'tracking-quickstart' already exists. Creating a new version of this model...
2023/11/07 12:17:01 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-quickstart, version 3
Created version '3' of model 'tracking-quickstart'.

将我们保存的模型加载为Python函数

尽管我们可以使用`mlflow.sklearn.load_model()`将模型作为原生scikit-learn格式加载回来,但下面我们正在将模型作为通用Python函数加载,这是该模型用于在线模型服务的加载方式。不过,我们仍然可以使用`pyfunc`表示进行批处理用例,如下所示。

loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

使用我们的模型预测Pandas DataFrame上的鸢尾花类别

predictions = loaded_model.predict(X_test)

iris_feature_names = datasets.load_iris().feature_names

# Convert X_test validation feature data to a Pandas DataFrame
result = pd.DataFrame(X_test, columns=iris_feature_names)

# Add the actual classes to the DataFrame
result["actual_class"] = y_test

# Add the model predictions to the DataFrame
result["predicted_class"] = predictions

result[:4]
萼片长度 (cm) 萼片宽度 (cm) 花瓣长度 (cm) 花瓣宽度 (cm) 实际类别 预测类别
0 6.1 2.8 4.7 1.2 1 1
1 5.7 3.8 1.7 0.3 0 0
2 7.7 2.6 6.9 2.3 2 2
3 6.0 2.9 4.5 1.5 1 1