跳到主要内容

MLflow 数据集跟踪

mlflow.data 模块是贯穿整个机器学习生命周期的数据集管理的综合解决方案。它使您能够跟踪、版本控制和管理用于训练、验证和评估的数据集,提供从原始数据到模型预测的完整血缘关系。

为什么数据集跟踪很重要

数据集跟踪对于可重现的机器学习至关重要,并提供以下几个关键优势

  • 数据血缘:跟踪从原始数据源到模型输入的完整历程
  • 可重现性:确保实验可以使用相同的数据集重现
  • 版本控制:随着数据集的演变来管理其不同版本
  • 协作:跨团队共享数据集及其元数据
  • 评估集成:与 MLflow 的评估功能无缝集成
  • 生产监控:跟踪生产推理和评估中使用的 নিরী数据集

核心组件

MLflow 的数据集跟踪围绕两个主要的抽象概念展开

数据集 (Dataset)

Dataset 抽象是一个元数据跟踪对象,其中包含有关已记录数据集的全面信息。Dataset 对象中存储的信息包括

核心属性

  • 名称 (Name):数据集的描述性标识符(如果未指定,默认为 "dataset")
  • 摘要 (Digest):用于识别数据集的唯一哈希/指纹(自动计算)
  • 源 (Source):包含有关原始数据位置的血缘信息的 DatasetSource
  • 架构 (Schema):可选的数据集架构(实现特定,例如 MLflow Schema)
  • 分析 (Profile):可选的汇总统计信息(实现特定,例如行数、列统计信息)

支持的数据集类型

特殊数据集类型

  • EvaluationDataset - 内部数据集类型,专门与 mlflow.models.evaluate() 结合用于模型评估工作流

数据集源 (DatasetSource)

DatasetSource 组件提供与数据原始来源的关联血缘关系,无论它是文件 URL、S3 存储桶、数据库表还是任何其他数据源。这确保您可以始终追溯到数据的起源地。

可以使用 mlflow.data.get_source() API 检索 DatasetSource,该 API 接受 DatasetDatasetEntityDatasetInput 的实例。

快速开始:基本数据集跟踪

以下是开始基本数据集跟踪的方法

python
import mlflow.data
import pandas as pd

# Load your data
dataset_source_url = "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/winequality-white.csv"
raw_data = pd.read_csv(dataset_source_url, delimiter=";")

# Create a Dataset object
dataset = mlflow.data.from_pandas(
raw_data, source=dataset_source_url, name="wine-quality-white", targets="quality"
)

# Log the dataset to an MLflow run
with mlflow.start_run():
mlflow.log_input(dataset, context="training")

# Your training code here
# model = train_model(raw_data)
# mlflow.sklearn.log_model(model, "model")

数据集信息和元数据

创建数据集时,MLflow 会自动捕获丰富的元数据

python
# Access dataset metadata
print(f"Dataset name: {dataset.name}") # Defaults to "dataset" if not specified
print(
f"Dataset digest: {dataset.digest}"
) # Unique hash identifier (computed automatically)
print(f"Dataset source: {dataset.source}") # DatasetSource object
print(
f"Dataset profile: {dataset.profile}"
) # Optional: implementation-specific statistics
print(f"Dataset schema: {dataset.schema}") # Optional: implementation-specific schema

示例输出

text
Dataset name: wine-quality-white
Dataset digest: 2a1e42c4
Dataset profile: {"num_rows": 4898, "num_elements": 58776}
Dataset schema: {"mlflow_colspec": [
{"type": "double", "name": "fixed acidity"},
{"type": "double", "name": "volatile acidity"},
...
{"type": "long", "name": "quality"}
]}
Dataset source: <DatasetSource object>
数据集属性

profileschema 属性是实现特定的,可能因数据集类型(PandasDataset、SparkDataset 等)而异。某些数据集类型可能会为这些属性返回 None

数据集源和血缘

MLflow 支持来自各种来源的数据集

python
# From local file
local_dataset = mlflow.data.from_pandas(
df, source="/path/to/local/file.csv", name="local-data"
)

# From cloud storage
s3_dataset = mlflow.data.from_pandas(
df, source="s3://bucket/data.parquet", name="s3-data"
)

# From database
db_dataset = mlflow.data.from_pandas(
df, source="postgresql://user:pass@host/db", name="db-data"
)

# From URL
url_dataset = mlflow.data.from_pandas(
df, source="https://example.com/data.csv", name="web-data"
)

MLflow UI 中的数据集跟踪

当您将数据集记录到 MLflow 运行中时,它们会出现在 MLflow UI 中,并附带全面的元数据。您可以直接在界面中查看数据集信息、架构和血缘关系。

Dataset in MLflow UI

UI 显示

  • 数据集名称和摘要
  • 带有列类型的架构信息
  • 分析统计信息(行数等)
  • 源血缘信息
  • 使用数据集的上下文

与 MLflow 评估的集成

MLflow 数据集最强大的功能之一是它们与 MLflow 评估功能的无缝集成。在使用 mlflow.models.evaluate() 时,MLflow 会自动在内部将各种数据类型转换为 EvaluationDataset 对象。

EvaluationDataset

MLflow 在处理 mlflow.models.evaluate() 时使用内部 EvaluationDataset 类。此数据集类型会从您的输入数据自动创建,并专门为评估工作流提供优化的哈希和元数据跟踪。

直接将数据集与 MLflow 评估一起使用

python
import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Prepare data and train model
data = pd.read_csv("classification_data.csv")
X = data.drop("target", axis=1)
y = data["target"]
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)

model = RandomForestClassifier()
model.fit(X_train, y_train)

# Create evaluation dataset
eval_data = X_test.copy()
eval_data["target"] = y_test

eval_dataset = mlflow.data.from_pandas(
eval_data, targets="target", name="evaluation-set"
)

with mlflow.start_run():
# Log model
mlflow.sklearn.log_model(model, name="model", input_example=X_test)

# Evaluate using the dataset
result = mlflow.models.evaluate(
model="runs:/{}/model".format(mlflow.active_run().info.run_id),
data=eval_dataset,
model_type="classifier",
)

print(f"Accuracy: {result.metrics['accuracy_score']:.3f}")

MLflow 评估集成示例

这是一个完整的示例,展示了数据集如何与 MLflow 的评估功能集成

Dataset Evaluation in MLflow UI

评估运行显示了数据集、模型、指标和评估工件(如混淆矩阵)是如何一起记录的,从而提供了评估过程的完整视图。

高级数据集管理

跟踪数据集随时间推移的版本演变

python
def create_versioned_dataset(data, version, base_name="customer-data"):
"""Create a versioned dataset with metadata."""

dataset = mlflow.data.from_pandas(
data,
source=f"data_pipeline_v{version}",
name=f"{base_name}-v{version}",
targets="target",
)

with mlflow.start_run(run_name=f"Dataset_Version_{version}"):
mlflow.log_input(dataset, context="versioning")

# Log version metadata
mlflow.log_params(
{
"dataset_version": version,
"data_size": len(data),
"features_count": len(data.columns) - 1,
"target_distribution": data["target"].value_counts().to_dict(),
}
)

# Log data quality metrics
mlflow.log_metrics(
{
"missing_values_pct": (data.isnull().sum().sum() / data.size) * 100,
"duplicate_rows": data.duplicated().sum(),
"target_balance": data["target"].std(),
}
)

return dataset


# Create multiple versions
v1_dataset = create_versioned_dataset(data_v1, "1.0")
v2_dataset = create_versioned_dataset(data_v2, "2.0")
v3_dataset = create_versioned_dataset(data_v3, "3.0")

生产用例

监控生产批量预测中使用的 নিরী数据集

python
def monitor_batch_predictions(batch_data, model_version, date):
"""Monitor production batch prediction datasets."""

# Create dataset for batch predictions
batch_dataset = mlflow.data.from_pandas(
batch_data,
source=f"production_batch_{date}",
name=f"batch_predictions_{date}",
targets="true_label" if "true_label" in batch_data.columns else None,
predictions="prediction" if "prediction" in batch_data.columns else None,
)

with mlflow.start_run(run_name=f"Batch_Monitor_{date}"):
mlflow.log_input(batch_dataset, context="production_batch")

# Log production metadata
mlflow.log_params(
{
"batch_date": date,
"model_version": model_version,
"batch_size": len(batch_data),
"has_ground_truth": "true_label" in batch_data.columns,
}
)

# Monitor prediction distribution
if "prediction" in batch_data.columns:
pred_metrics = {
"prediction_mean": batch_data["prediction"].mean(),
"prediction_std": batch_data["prediction"].std(),
"unique_predictions": batch_data["prediction"].nunique(),
}
mlflow.log_metrics(pred_metrics)

# Evaluate if ground truth is available
if all(col in batch_data.columns for col in ["prediction", "true_label"]):
result = mlflow.models.evaluate(data=batch_dataset, model_type="classifier")
print(f"Batch accuracy: {result.metrics.get('accuracy_score', 'N/A')}")

return batch_dataset


# Usage
batch_dataset = monitor_batch_predictions(daily_batch_data, "v2.1", "2024-01-15")

最佳实践

在使用 MLflow 数据集时,请遵循以下最佳实践

数据质量:在记录数据集之前,始终验证数据质量。检查缺失值、重复项和数据类型。

命名约定:为数据集使用一致、描述性的名称,其中包含版本信息和上下文。

源文档:始终指定有意义的源 URL 或标识符,以便您可以追溯到原始数据。

上下文规范:在记录数据集时使用清晰的上下文标签(例如,“训练”、“验证”、“评估”、“生产”)。

元数据记录:包含有关数据收集、预处理步骤和数据特征的相关元数据。

版本控制:明确跟踪数据集版本,尤其是在数据预处理或收集方法发生变化时。

摘要计算:不同数据集类型的摘要计算方式不同

  • 标准数据集:基于数据内容和结构
  • MetaDataset:仅基于元数据(名称、源、架构)- 不进行实际数据哈希处理
  • EvaluationDataset:使用大数据的样本行进行优化哈希处理

源灵活性DatasetSource 支持各种源类型,包括 HTTP URL、文件路径、数据库连接和云存储位置。

评估集成:通过明确指定目标和预测列来设计数据集,以考虑评估。

主要优势

MLflow 数据集跟踪为 ML 团队提供了几个关键优势

可重现性:确保即使在数据源不断演变的情况下,实验也能使用相同的数据集重现。

血缘跟踪:维护从源到模型预测的完整数据血缘关系,从而实现更好的调试和合规性。

协作:使用一致的界面跨团队成员共享数据集及其元数据。

评估集成:与 MLflow 评估功能无缝集成,实现全面的模型评估。

生产监控:跟踪生产系统中使用的数据集,用于性能监控和数据漂移检测。

质量保证:自动捕获数据质量指标并随时间推移进行监控。

无论您是跟踪训练数据集、管理评估数据,还是监控生产批量预测,MLflow 的数据集跟踪功能都为可靠、可重现的机器学习工作流提供了基础。