MLflow 跟踪
MLflow 追踪是一个 API 和 UI,用于在运行机器学习代码时记录参数、代码版本、指标和输出文件,并用于后续结果的可视化。MLflow 追踪提供了 Python、REST、R 和 Java API。
快速入门
如果您以前没有使用过 MLflow 追踪,我们强烈建议您阅读以下快速入门教程。
概念
运行
MLflow 追踪围绕运行的概念进行组织,运行是数据科学代码的执行过程,例如,一次 python train.py
执行。每次运行都会记录元数据(关于您的运行的各种信息,例如指标、参数、开始和结束时间)和产物(运行的输出文件,例如模型权重、图像等)。
模型
模型表示在您的运行期间生成的经过训练的机器学习产物。记录的模型包含其自己的元数据和产物,与运行类似。
实验
一个实验将特定任务的运行和模型组合在一起。您可以使用 CLI、API 或 UI 创建实验。MLflow API 和 UI 还允许您创建和搜索实验。有关如何将运行组织到实验中的更多详细信息,请参阅将运行组织到实验中。
追踪运行
MLflow 追踪 API 提供了一组函数来追踪您的运行。例如,您可以调用 mlflow.start_run()
来开始新的运行,然后调用日志记录函数,例如 mlflow.log_param()
和 mlflow.log_metric()
分别记录参数和指标。请访问追踪 API 文档以获取有关使用这些 API 的更多详细信息。
import mlflow
with mlflow.start_run():
mlflow.log_param("lr", 0.001)
# Your ml code
...
mlflow.log_metric("val_loss", val_loss)
或者,自动日志记录提供了启动 MLflow 追踪的超快速设置。这个强大的功能允许您记录指标、参数和模型,而无需显式的日志语句——您所需要做的就是在训练代码之前调用 mlflow.autolog()
。自动日志记录支持流行的库,如 Scikit-learn、XGBoost、PyTorch、Keras、Spark 等。请参阅自动日志记录文档,了解支持的库以及如何将自动日志记录 API 与它们一起使用。
import mlflow
mlflow.autolog()
# Your training code...
默认情况下,在没有任何特定服务器/数据库配置的情况下,MLflow 追踪会将数据记录到本地的 mlruns
目录。如果您想将运行记录到其他位置,例如远程数据库和云存储,以便与团队共享结果,请按照设置 MLflow 追踪环境部分中的说明进行操作。
以编程方式搜索已记录的模型
MLflow 3 通过 mlflow.search_logged_models()
引入了强大的模型搜索功能。此 API 允许您使用类似 SQL 的语法,根据性能指标、参数和模型属性在您的实验中查找特定模型。
import mlflow
# Find high-performing models across experiments
top_models = mlflow.search_logged_models(
experiment_ids=["1", "2"],
filter_string="metrics.accuracy > 0.95 AND params.model_type = 'RandomForest'",
order_by=[{"field_name": "metrics.f1_score", "ascending": False}],
max_results=5,
)
# Get the best model for deployment
best_model = mlflow.search_logged_models(
experiment_ids=["1"],
filter_string="metrics.accuracy > 0.9",
max_results=1,
order_by=[{"field_name": "metrics.accuracy", "ascending": False}],
output_format="list",
)[0]
# Load the best model directly
loaded_model = mlflow.pyfunc.load_model(f"models:/{best_model.model_id}")
主要特性
- 类 SQL 过滤:使用
metrics.
、params.
和属性前缀构建复杂查询 - 数据集感知搜索:根据特定数据集过滤指标,以便进行公平的模型比较
- 灵活排序:按多个条件排序以找到最佳模型
- 直接模型加载:使用新的
models:/<model_id>
URI 格式直接访问模型
有关全面的示例和高级搜索模式,请参阅搜索已记录模型指南。
以编程方式查询运行
您还可以使用 MlflowClient 以编程方式访问追踪 UI 中的所有功能。
例如,以下代码片段搜索实验中具有最佳验证损失的运行。
client = mlflow.tracking.MlflowClient()
experiment_id = "0"
best_run = client.search_runs(
experiment_id, order_by=["metrics.val_loss ASC"], max_results=1
)[0]
print(best_run.info)
# {'run_id': '...', 'metrics': {'val_loss': 0.123}, ...}
追踪模型
MLflow 3 引入了增强的模型追踪功能,允许您在单个运行中记录多个模型检查点,并追踪它们在不同数据集上的性能。这对于深度学习工作流程特别有用,您可以在不同的训练阶段保存和比较模型检查点。
记录模型检查点
您可以在训练期间的不同步骤使用模型日志记录函数中的 step
参数记录模型检查点。每个记录的模型都会获得一个唯一的模型 ID,您可以使用它稍后引用该模型。
import mlflow
import mlflow.pytorch
with mlflow.start_run() as run:
for epoch in range(100):
# Train your model
train_model(model, epoch)
# Log model checkpoint every 10 epochs
if epoch % 10 == 0:
model_info = mlflow.pytorch.log_model(
pytorch_model=model,
name=f"checkpoint-epoch-{epoch}",
step=epoch,
input_example=sample_input,
)
# Log metrics linked to this specific model checkpoint
accuracy = evaluate_model(model, validation_data)
mlflow.log_metric(
key="accuracy",
value=accuracy,
step=epoch,
model_id=model_info.model_id, # Link metric to specific model
dataset=validation_dataset,
)
将指标链接到模型和数据集
MLflow 3 允许您将指标链接到特定的模型检查点和数据集,从而提供更好的模型性能可追溯性
# Create a dataset reference
train_dataset = mlflow.data.from_pandas(train_df, name="training_data")
# Log metric with model and dataset links
mlflow.log_metric(
key="f1_score",
value=0.95,
step=epoch,
model_id=model_info.model_id, # Links to specific model checkpoint
dataset=train_dataset, # Links to specific dataset
)
搜索和排名模型检查点
使用 mlflow.search_logged_models()
根据其性能指标搜索和排名模型检查点
# Search for all models in a run, ordered by accuracy
ranked_models = mlflow.search_logged_models(
filter_string=f"source_run_id='{run.info.run_id}'",
order_by=[{"field_name": "metrics.accuracy", "ascending": False}],
output_format="list",
)
# Get the best performing model
best_model = ranked_models[0]
print(f"Best model: {best_model.name}")
print(f"Accuracy: {best_model.metrics[0].value}")
# Load the best model for inference
loaded_model = mlflow.pyfunc.load_model(f"models:/{best_model.model_id}")
MLflow 3 中的模型 URI
MLflow 3 引入了一种新的模型 URI 格式,它使用模型 ID 而不是运行 ID,提供更直接的模型引用
# New MLflow 3 model URI format
model_uri = f"models:/{model_info.model_id}"
loaded_model = mlflow.pyfunc.load_model(model_uri)
# This replaces the older run-based URI format:
# model_uri = f"runs:/{run_id}/model_path"
这种新方法具有多项优势
- 直接模型引用:无需知道运行 ID 和产物路径
- 更好的模型生命周期管理:每个模型检查点都有其唯一的标识符
- 改进的模型比较:轻松比较同一运行中的不同检查点
- 增强的可追溯性:模型、指标和数据集之间有清晰的链接
追踪数据集
MLflow 提供了追踪与模型训练事件相关联的数据集的能力。这些与数据集相关的元数据可以通过使用 mlflow.log_input()
API 进行存储。要了解更多信息,请访问MLflow 数据文档,查看此 API 中可用的功能。
探索运行、模型和结果
追踪 UI
追踪 UI 允许您直观地探索您的实验、运行和模型,如本页面顶部所示。
- 基于实验的运行列表和比较(包括跨多个实验的运行比较)
- 按参数或指标值搜索运行
- 可视化运行指标
- 下载运行结果(产物和元数据)
这些功能也适用于模型,如下所示。
如果您将运行记录到本地 mlruns
目录,请在该目录的上一级目录中运行以下命令,然后在浏览器中访问 http://127.0.0.1:5000。
mlflow ui --port 5000
或者,MLflow 追踪服务器提供相同的 UI 并支持运行产物的远程存储。在这种情况下,您可以从任何可以连接到您的追踪服务器的机器上,通过 http://<您的 MLflow 追踪服务器的 IP 地址>:5000
查看 UI。
设置 MLflow 追踪环境
如果您只想将实验数据和模型记录到本地文件,则可以跳过此部分。
MLflow 追踪支持您的开发工作流程中的许多不同场景。本节将指导您如何为您的特定用例设置 MLflow 追踪环境。从鸟瞰的角度来看,MLflow 追踪环境由以下组件组成。
组件
MLflow 追踪 API
您可以在您的 ML 代码中调用 MLflow 追踪 API 来记录运行,并在必要时与 MLflow 追踪服务器通信。
后端存储
后端存储为每个运行持久化各种元数据,例如运行 ID、开始和结束时间、参数、指标等。MLflow 支持两种类型的后端存储:基于文件系统的(如本地文件)和基于数据库的(如 PostgreSQL)。
此外,如果您正在与托管服务(例如 Databricks 或 Azure 机器学习)交互,您将与一个外部管理且无法直接访问的基于 REST 的后端存储进行交互。
产物存储
产物存储为每个运行持久化(通常较大的)产物,例如模型权重(例如,一个 pickle 化的 scikit-learn 模型)、图像(例如 PNG 文件)、模型和数据文件(例如 Parquet 文件)。MLflow 默认将产物存储在本地文件(mlruns
)中,但也支持不同的存储选项,例如 Amazon S3 和 Azure Blob Storage。
对于作为 MLflow 产物记录的模型,您可以通过 models:/<model_id>
格式的模型 URI 引用该模型,其中 'model_id' 是分配给已记录模型的唯一标识符。这取代了旧的 runs:/<run_id>/<artifact_path>
格式,并提供了更直接的模型引用。
如果模型已在 MLflow 模型注册表中注册,您还可以通过 models:/<model-name>/<model-version>
格式的模型 URI 引用该模型,详情请参阅MLflow 模型注册表。
MLflow 追踪服务器(可选)
MLflow 追踪服务器是一个独立的 HTTP 服务器,提供 REST API 以访问后端和/或产物存储。追踪服务器还提供了配置要服务的数据、管理访问控制、版本控制等方面的灵活性。请阅读MLflow 追踪服务器文档了解更多详细信息。
常见设置
通过正确配置这些组件,您可以创建适合您团队开发工作流程的 MLflow 追踪环境。以下图表和表格展示了 MLflow 追踪环境的几种常见设置。
1. 本地主机(默认) | 2. 使用本地数据库进行本地追踪 | 3. 使用MLflow 追踪服务器进行远程追踪 | |
---|---|---|---|
场景 | 个人开发 | 个人开发 | 团队开发 |
用例 | 默认情况下,MLflow 将每次运行的元数据和产物记录到本地目录 mlruns 。这是开始使用 MLflow 追踪的最简单方式,无需设置任何外部服务器、数据库和存储。 | MLflow 客户端可以与 SQLAlchemy 兼容的数据库(例如 SQLite、PostgreSQL、MySQL)作为后端进行交互。将元数据保存到数据库可以更清晰地管理您的实验数据,同时省去了设置服务器的精力。 | MLflow 追踪服务器可以配置产物 HTTP 代理,通过追踪服务器传递产物请求,从而存储和检索产物,而无需与底层对象存储服务交互。这对于团队开发场景特别有用,在这种场景下,您希望将产物和实验元数据存储在具有适当访问控制的共享位置。 |
教程 | 快速入门 | 使用本地数据库追踪实验 | 使用 MLflow 追踪服务器进行远程实验追踪 |
使用MLflow 追踪服务器的其他配置
MLflow 追踪服务器为其他特殊用例提供了可定制性。请按照使用 MLflow 追踪服务器进行远程实验追踪了解基本设置,并继续阅读以下材料以获取满足您需求的高级配置。
- 本地追踪服务器
- 仅产物模式
- 直接访问产物
在本地使用 MLflow 追踪服务器
您当然可以在本地运行 MLflow 追踪服务器。虽然这与直接使用本地文件或数据库相比没有太多额外好处,但对于在本地测试您的团队开发工作流程或在容器环境中运行您的机器学习代码可能很有用。
在仅产物模式下运行 MLflow 追踪服务器
MLflow 追踪服务器有一个 --artifacts-only
选项,允许服务器专门处理(代理)产物,而不允许处理元数据。当您身处大型组织或训练超大型模型时,这尤其有用。在这些场景中,您可能具有高产物传输量,并且可以通过分离产物服务流量来避免影响追踪功能。请阅读可选地仅使用追踪服务器实例进行产物处理以获取有关如何使用此模式的更多详细信息。
禁用产物代理以允许直接访问产物
MLflow 追踪服务器默认同时提供产物和元数据。然而,在某些情况下,您可能希望允许直接访问远程产物存储,以避免代理开销,同时保留元数据追踪的功能。这可以通过使用 --no-serve-artifacts
选项启动服务器来禁用产物代理实现。有关如何设置此项,请参阅不通过代理访问产物来使用追踪服务器。
常见问题
我可以并行启动多个运行吗?
是的,MLflow 支持并行启动多个运行,例如多进程/多线程。有关更多详细信息,请参阅在一个程序中启动多个运行。
我如何整齐地组织许多 MLflow 运行?
MLflow 提供了几种组织运行的方式
- 将运行组织到实验中 - 实验是您运行的逻辑容器。您可以使用 CLI、API 或 UI 创建实验。
- 创建子运行 - 您可以在单个父运行下创建子运行以将其分组。例如,您可以为交叉验证实验中的每个折叠创建一个子运行。
- 为运行添加标签 - 您可以将任意标签与每次运行关联起来,这允许您根据标签过滤和搜索运行。
我可以在不运行追踪服务器的情况下直接访问远程存储吗?
是的,虽然在团队开发工作流程中将 MLflow 追踪服务器作为产物访问的代理是最佳实践,但如果您将其用于个人项目或测试,则可能不需要这样做。您可以通过以下变通方法实现这一点
- 设置产物配置,例如凭据和端点,就像您为 MLflow 追踪服务器所做的那样。有关更多详细信息,请参阅配置产物存储。
- 创建一个具有显式产物位置的实验,
experiment_name = "your_experiment_name"
mlflow.create_experiment(experiment_name, artifact_location="s3://your-bucket")
mlflow.set_experiment(experiment_name)
此实验下的运行将直接将产物记录到远程存储。
如何将 MLflow 追踪与模型注册表集成?
要将模型注册表功能与 MLflow 追踪结合使用,您必须使用数据库支持的存储(例如 PostgreSQL),并使用相应模型风格的 log_model
方法记录模型。一旦模型被记录,您就可以通过 UI 或 API 在模型注册表中添加、修改、更新或删除模型。有关如何为您的工作流程正确配置后端存储,请参阅后端存储和常见设置。
如何包含关于运行的额外描述文本?
系统标签 mlflow.note.content
可用于添加关于此运行的描述性备注。虽然其他系统标签是自动设置的,但此标签默认未设置,用户可以覆盖它以包含有关运行的额外信息。内容将显示在运行页面的“备注”部分下。