MLflow 5 分钟跟踪快速入门
本笔记本演示了如何使用本地 MLflow 跟踪服务器来记录、注册,然后将模型加载为通用的 Python 函数 (pyfunc),以便对 Pandas DataFrame 进行推理。
在本笔记本中,我们将使用 MLflow 声明式 API 来执行与 MLflow 跟踪服务器的所有交互。
import pandas as pd
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import mlflow
from mlflow.models import infer_signature
设置 MLflow 跟踪 URI
根据您运行本笔记本的位置,您初始化与 MLflow 跟踪服务器的接口的配置可能会有所不同。
在此示例中,我们使用的是本地运行的跟踪服务器,但也有其他选项可用(最简单的方法是使用 Databricks 免费试用中的免费托管服务)。
请参阅 此处的运行笔记本指南,了解有关设置跟踪服务器 URI 和配置对托管或自托管 MLflow 跟踪服务器的访问的更多信息。
# NOTE: review the links mentioned above for guidance on connecting to a managed tracking server, such as the Databricks Managed MLflow
mlflow.set_tracking_uri(uri="http://127.0.0.1:8080")
加载训练数据并训练一个简单模型
在我们的快速入门中,我们将使用 scikit-learn 中包含的熟悉的鸢尾花数据集。在数据分割后,我们将使用训练数据训练一个简单的逻辑回归分类器,并计算一些在保留的测试数据上的错误度量。
请注意,此部分中与 MLflow 相关的唯一活动是使用了 param 字典来提供模型的超参数;这是为了在我们准备好记录模型及其相关元数据时,更轻松地记录这些设置。
# Load the Iris dataset
X, y = datasets.load_iris(return_X_y=True)
# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Define the model hyperparameters
params = {"solver": "lbfgs", "max_iter": 1000, "multi_class": "auto", "random_state": 8888}
# Train the model
lr = LogisticRegression(**params)
lr.fit(X_train, y_train)
# Predict on the test set
y_pred = lr.predict(X_test)
# Calculate accuracy as a target loss metric
accuracy = accuracy_score(y_test, y_pred)
定义 MLflow 实验
为了将特定项目或想法的任何独立运行分组在一起,我们可以定义一个实验,该实验将把每个迭代(运行)分组。定义一个与我们正在处理的内容相关的唯一名称有助于组织,并减少找到运行所需的工作量(搜索)。
mlflow.set_experiment("MLflow Quickstart")
<Experiment: artifact_location='mlflow-artifacts:/846578415685150448', creation_time=1699374480748, experiment_id='846578415685150448', last_update_time=1699374480748, lifecycle_stage='active', name='MLflow Quickstart', tags={}>
将模型、超参数和损失度量记录到 MLflow。
为了记录我们的模型、拟合模型时使用的超参数以及在保留数据上验证拟合模型的度量,我们启动一个运行上下文,如下所示。在该上下文的作用域内,我们调用的任何声明式 API(如 mlflow.log_params() 或 mlflow.sklearn.log_model())都将与同一运行相关联并一起记录。
# Start an MLflow run
with mlflow.start_run():
# Log the hyperparameters
mlflow.log_params(params)
# Log the loss metric
mlflow.log_metric("accuracy", accuracy)
# Set a tag that we can use to remind ourselves what this run was for
mlflow.set_tag("Training Info", "Basic LR model for iris data")
# Infer the model signature
signature = infer_signature(X_train, lr.predict(X_train))
# Log the model
model_info = mlflow.sklearn.log_model(
sk_model=lr,
name="iris_model",
signature=signature,
input_example=X_train,
registered_model_name="tracking-quickstart",
)
/Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils.
warnings.warn("Setuptools is replacing distutils.")
Registered model 'tracking-quickstart' already exists. Creating a new version of this model...
2023/11/07 12:17:01 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-quickstart, version 3
Created version '3' of model 'tracking-quickstart'.
将我们的模型加载为 Python 函数
虽然我们可以使用 mlflow.sklearn.load_model() 将模型加载回为原生的 scikit-learn 格式,但下面我们将模型加载为通用的 Python 函数,这也是模型在线模型服务加载模型的方式。尽管如此,我们仍然可以为批量用例使用 pyfunc 表示形式,如下所示。
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
Downloading artifacts: 0%| | 0/6 [00:00<?, ?it/s]
使用我们的模型预测 Pandas DataFrame 上的鸢尾花类类型
predictions = loaded_model.predict(X_test)
iris_feature_names = datasets.load_iris().feature_names
# Convert X_test validation feature data to a Pandas DataFrame
result = pd.DataFrame(X_test, columns=iris_feature_names)
# Add the actual classes to the DataFrame
result["actual_class"] = y_test
# Add the model predictions to the DataFrame
result["predicted_class"] = predictions
result[:4]
| 萼片长度 (cm) | 萼片宽度 (cm) | 花瓣长度 (cm) | 花瓣宽度 (cm) | actual_class | predicted_class | |
|---|---|---|---|---|---|---|
| 0 | 6.1 | 2.8 | 4.7 | 1.2 | 1 | 1 |
| 1 | 5.7 | 3.8 | 1.7 | 0.3 | 0 | 0 |
| 2 | 7.7 | 2.6 | 6.9 | 2.3 | 2 | 2 |
| 3 | 6.0 | 2.9 | 4.5 | 1.5 | 1 | 1 |