MLflow 5 分钟跟踪快速入门
本笔记本演示了如何使用本地 MLflow 跟踪服务器来记录、注册模型,然后将模型加载为通用 Python 函数 (pyfunc),以对 Pandas DataFrame 执行推理。
在本笔记本中,我们将使用 MLflow fluent API 来执行与 MLflow 跟踪服务器的所有交互。
import pandas as pd
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import mlflow
from mlflow.models import infer_signature
设置 MLflow 跟踪 URI
根据您运行此笔记本的位置,您初始化与 MLflow 跟踪服务器的接口的配置可能会有所不同。
在此示例中,我们使用本地运行的跟踪服务器,但还有其他选项可用(最简单的方法是使用 Databricks 免费试用版中的免费托管服务)。
请参阅此处的运行笔记本指南,以获取有关设置跟踪服务器 URI 和配置对托管或自托管 MLflow 跟踪服务器的访问权限的更多信息。
# NOTE: review the links mentioned above for guidance on connecting to a managed tracking server, such as the Databricks Managed MLflow
mlflow.set_tracking_uri(uri="http://127.0.0.1:8080")
加载训练数据并训练一个简单的模型
对于我们的快速入门,我们将使用 scikit-learn 中包含的熟悉的 iris 数据集。在拆分数据后,我们将在训练数据上训练一个简单的逻辑回归分类器,并计算保留测试数据上的一些误差指标。
请注意,此部分中唯一与 MLflow 相关的活动围绕着我们使用 param
字典来提供我们模型的超参数这一事实; 这是为了在我们准备好记录我们的模型及其相关元数据时,更容易记录这些设置。
# Load the Iris dataset
X, y = datasets.load_iris(return_X_y=True)
# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Define the model hyperparameters
params = {"solver": "lbfgs", "max_iter": 1000, "multi_class": "auto", "random_state": 8888}
# Train the model
lr = LogisticRegression(**params)
lr.fit(X_train, y_train)
# Predict on the test set
y_pred = lr.predict(X_test)
# Calculate accuracy as a target loss metric
accuracy = accuracy_score(y_test, y_pred)
定义 MLflow 实验
为了将特定项目或想法的任何不同运行分组在一起,我们可以定义一个实验,该实验将每个迭代(运行)分组在一起。 定义与我们正在处理的内容相关的唯一名称有助于组织,并减少了稍后查找运行的工作量(搜索)。
mlflow.set_experiment("MLflow Quickstart")
<Experiment: artifact_location='mlflow-artifacts:/846578415685150448', creation_time=1699374480748, experiment_id='846578415685150448', last_update_time=1699374480748, lifecycle_stage='active', name='MLflow Quickstart', tags={}>
将模型、超参数和损失指标记录到 MLflow。
为了记录我们的模型以及拟合模型时使用的超参数,以及与验证保留数据上拟合模型相关的指标,我们启动一个运行上下文,如下所示。 在该上下文的范围内,我们调用的任何 fluent API(例如 mlflow.log_params()
或 mlflow.sklearn.log_model()
)将被关联并一起记录到相同的运行中。
# Start an MLflow run
with mlflow.start_run():
# Log the hyperparameters
mlflow.log_params(params)
# Log the loss metric
mlflow.log_metric("accuracy", accuracy)
# Set a tag that we can use to remind ourselves what this run was for
mlflow.set_tag("Training Info", "Basic LR model for iris data")
# Infer the model signature
signature = infer_signature(X_train, lr.predict(X_train))
# Log the model
model_info = mlflow.sklearn.log_model(
sk_model=lr,
name="iris_model",
signature=signature,
input_example=X_train,
registered_model_name="tracking-quickstart",
)
/Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils. warnings.warn("Setuptools is replacing distutils.") Registered model 'tracking-quickstart' already exists. Creating a new version of this model... 2023/11/07 12:17:01 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-quickstart, version 3 Created version '3' of model 'tracking-quickstart'.
将我们保存的模型加载为 Python 函数
虽然我们可以使用 mlflow.sklearn.load_model()
将我们的模型加载回原生 scikit-learn 格式,但下面我们将模型加载为通用 Python 函数,这是在线模型服务加载此模型的方式。 但是,我们仍然可以将 pyfunc
表示形式用于批量用例,如下所示。
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
Downloading artifacts: 0%| | 0/6 [00:00<?, ?it/s]
使用我们的模型来预测 Pandas DataFrame 上的 iris 类类型
predictions = loaded_model.predict(X_test)
iris_feature_names = datasets.load_iris().feature_names
# Convert X_test validation feature data to a Pandas DataFrame
result = pd.DataFrame(X_test, columns=iris_feature_names)
# Add the actual classes to the DataFrame
result["actual_class"] = y_test
# Add the model predictions to the DataFrame
result["predicted_class"] = predictions
result[:4]
萼片长度 (cm) | 萼片宽度 (cm) | 花瓣长度 (cm) | 花瓣宽度 (cm) | 实际类别 | 预测类别 | |
---|---|---|---|---|---|---|
0 | 6.1 | 2.8 | 4.7 | 1.2 | 1 | 1 |
1 | 5.7 | 3.8 | 1.7 | 0.3 | 0 | 0 |
2 | 7.7 | 2.6 | 6.9 | 2.3 | 2 | 2 |
3 | 6.0 | 2.9 | 4.5 | 1.5 | 1 | 1 |