MLflow 5分钟追踪快速入门
本notebook演示了如何使用本地MLflow Tracking服务器来记录、注册,然后将模型作为通用Python函数(pyfunc)加载,以在Pandas DataFrame上执行推理。
在本notebook中,我们将使用MLflow fluent API与MLflow Tracking服务器进行所有交互。
import pandas as pd
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import mlflow
from mlflow.models import infer_signature
设置MLflow Tracking URI
根据您运行此notebook的位置,您初始化与MLflow Tracking服务器接口的配置可能会有所不同。
在此示例中,我们使用本地运行的追踪服务器,但也有其他选项可用(最简单的是使用Databricks免费试用版中的免费托管服务)。
有关设置追踪服务器URI以及配置对托管或自管理MLflow追踪服务器的访问的更多信息,请参阅此处运行notebook的指南。
# NOTE: review the links mentioned above for guidance on connecting to a managed tracking server, such as the Databricks Managed MLflow
mlflow.set_tracking_uri(uri="http://127.0.0.1:8080")
加载训练数据并训练一个简单模型
对于我们的快速入门,我们将使用scikit-learn中包含的熟悉的iris数据集。在数据分割之后,我们将在训练数据上训练一个简单的逻辑回归分类器,并计算我们的保留测试数据的一些错误指标。
请注意,此部分中唯一与MLflow相关的活动是我们使用`param`字典来提供模型的超参数;这是为了在我们准备记录模型及其相关元数据时,更容易记录这些设置。
# Load the Iris dataset
X, y = datasets.load_iris(return_X_y=True)
# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Define the model hyperparameters
params = {"solver": "lbfgs", "max_iter": 1000, "multi_class": "auto", "random_state": 8888}
# Train the model
lr = LogisticRegression(**params)
lr.fit(X_train, y_train)
# Predict on the test set
y_pred = lr.predict(X_test)
# Calculate accuracy as a target loss metric
accuracy = accuracy_score(y_test, y_pred)
定义MLflow实验
为了将特定项目或想法的任何不同运行分组在一起,我们可以定义一个实验,它将每个迭代(运行)组合在一起。定义一个与我们正在进行的工作相关的唯一名称有助于组织,并减少以后查找运行的工作量(搜索)。
mlflow.set_experiment("MLflow Quickstart")
<Experiment: artifact_location='mlflow-artifacts:/846578415685150448', creation_time=1699374480748, experiment_id='846578415685150448', last_update_time=1699374480748, lifecycle_stage='active', name='MLflow Quickstart', tags={}>
将模型、超参数和损失指标记录到MLflow。
为了记录我们的模型以及拟合模型时使用的超参数,以及与在保留数据上验证拟合模型相关的指标,我们启动了一个运行上下文,如下所示。在该上下文的范围内,我们调用的任何fluent API(例如`mlflow.log_params()`或`mlflow.sklearn.log_model()`)都将关联并一起记录到同一运行中。
# Start an MLflow run
with mlflow.start_run():
# Log the hyperparameters
mlflow.log_params(params)
# Log the loss metric
mlflow.log_metric("accuracy", accuracy)
# Set a tag that we can use to remind ourselves what this run was for
mlflow.set_tag("Training Info", "Basic LR model for iris data")
# Infer the model signature
signature = infer_signature(X_train, lr.predict(X_train))
# Log the model
model_info = mlflow.sklearn.log_model(
sk_model=lr,
name="iris_model",
signature=signature,
input_example=X_train,
registered_model_name="tracking-quickstart",
)
/Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils. warnings.warn("Setuptools is replacing distutils.") Registered model 'tracking-quickstart' already exists. Creating a new version of this model... 2023/11/07 12:17:01 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-quickstart, version 3 Created version '3' of model 'tracking-quickstart'.
将我们保存的模型加载为Python函数
尽管我们可以使用`mlflow.sklearn.load_model()`将模型作为原生scikit-learn格式加载回来,但下面我们正在将模型作为通用Python函数加载,这是该模型用于在线模型服务的加载方式。不过,我们仍然可以使用`pyfunc`表示进行批处理用例,如下所示。
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
Downloading artifacts: 0%| | 0/6 [00:00<?, ?it/s]
使用我们的模型预测Pandas DataFrame上的鸢尾花类别
predictions = loaded_model.predict(X_test)
iris_feature_names = datasets.load_iris().feature_names
# Convert X_test validation feature data to a Pandas DataFrame
result = pd.DataFrame(X_test, columns=iris_feature_names)
# Add the actual classes to the DataFrame
result["actual_class"] = y_test
# Add the model predictions to the DataFrame
result["predicted_class"] = predictions
result[:4]
萼片长度 (cm) | 萼片宽度 (cm) | 花瓣长度 (cm) | 花瓣宽度 (cm) | 实际类别 | 预测类别 | |
---|---|---|---|---|---|---|
0 | 6.1 | 2.8 | 4.7 | 1.2 | 1 | 1 |
1 | 5.7 | 3.8 | 1.7 | 0.3 | 0 | 0 |
2 | 7.7 | 2.6 | 6.9 | 2.3 | 2 | 2 |
3 | 6.0 | 2.9 | 4.5 | 1.5 | 1 | 1 |