Evaluation Datasets SDK Reference
Complete API reference for programmatically creating, managing, and querying evaluation datasets.
Evaluation datasets require an MLflow tracking server with a SQL backend (PostgreSQL, MySQL, SQLite, or MSSQL). This feature is not available with FileStore (local filesystem-based tracking).
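For local development, a SQLite backend is the quickest way to meet this requirement. A minimal sketch (the database path and port are arbitrary choices):
# Start a tracking server with a SQL backend (run in a terminal):
#   mlflow server --backend-store-uri sqlite:///mlflow.db --port 5000
import mlflow

# Point the client at the SQL-backed server before using dataset APIs
mlflow.set_tracking_uri("http://localhost:5000")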
Creating Datasets
Use mlflow.genai.datasets.create_dataset() to create a new evaluation dataset:
from mlflow.genai.datasets import create_dataset
# Create a new dataset
dataset = create_dataset(
name="customer_support_qa",
experiment_id=["0"], # Link to experiments
tags={"version": "1.0", "team": "ml-platform", "status": "active"},
)
print(f"Created dataset: {dataset.dataset_id}")
You can also use the mlflow.tracking.MlflowClient API:
from mlflow import MlflowClient
client = MlflowClient()
dataset = client.create_dataset(
name="customer_support_qa",
experiment_id=["0"],
tags={"version": "1.0"},
)
Adding Records to a Dataset
Use the mlflow.entities.EvaluationDataset.merge_records() method to add new records to your dataset. Records can be added from dictionaries, DataFrames, or traces:
- From dictionaries
- From traces
- From DataFrames
Add records directly from Python dictionaries:
# Add records with inputs and expectations (ground truth)
new_records = [
{
"inputs": {"question": "What are your business hours?"},
"expectations": {
"expected_answer": "We're open Monday-Friday 9am-5pm EST",
"must_mention_hours": True,
"must_include_timezone": True,
},
},
{
"inputs": {"question": "How do I reset my password?"},
"expectations": {
"expected_answer": (
"Click 'Forgot Password' and follow the email instructions"
),
"must_include_steps": True,
},
},
]
dataset.merge_records(new_records)
print(f"Dataset now has {len(dataset.records)} records")
Add records from MLflow traces:
import mlflow
# Search for traces to add to the dataset
traces = mlflow.search_traces(
experiment_ids=["0"],
filter_string="attributes.name = 'chat_completion'",
max_results=50,
return_type="list",
)
# Add traces directly to the dataset
dataset.merge_records(traces)
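If the dataset should only contain successful interactions, you can narrow the trace search before merging. A sketch using the standard trace status attribute:
# Only merge traces that completed successfully
ok_traces = mlflow.search_traces(
    experiment_ids=["0"],
    filter_string="attributes.status = 'OK'",
    return_type="list",
)
dataset.merge_records(ok_traces)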
Add records from a pandas DataFrame:
import pandas as pd
# Create DataFrame with structured data (ground truth expectations)
df = pd.DataFrame(
[
{
"inputs": {
"question": "What is MLflow?",
"context": "general",
},
"expectations": {
"expected_answer": "MLflow is an open-source platform for ML lifecycle",
"must_mention": ["tracking", "experiments"],
},
"tags": {"priority": "high"},
},
{
"inputs": {
"question": "How to track experiments?",
"context": "technical",
},
"expectations": {
"expected_answer": "Use mlflow.start_run() and mlflow.log_params()",
"must_mention": ["log_params", "start_run"],
},
"tags": {"priority": "medium"},
},
]
)
dataset.merge_records(df)
Updating Existing Records
The mlflow.entities.EvaluationDataset.merge_records() method handles updates intelligently. Records are matched by a hash of their inputs: if a record with the same inputs already exists, its expectations and tags are merged rather than creating a duplicate:
# Initial record
dataset.merge_records(
[
{
"inputs": {"question": "What is MLflow?"},
"expectations": {
"expected_answer": "MLflow is a platform for ML",
"must_mention_tracking": True,
},
}
]
)
# Update with same inputs but enhanced expectations
dataset.merge_records(
[
{
"inputs": {"question": "What is MLflow?"}, # Same inputs = update
"expectations": {
# Updates existing value
"expected_answer": (
"MLflow is an open-source platform for managing the ML lifecycle"
),
"must_mention_models": True, # Adds new expectation
# Note: "must_mention_tracking": True is preserved
},
}
]
)
# Result: One record with merged expectations
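To verify the result of a merge, you can inspect the record through the DataFrame view. A sketch, assuming to_df() exposes inputs and expectations as dictionary-valued columns:
df = dataset.to_df()
# Locate the merged record by its inputs
mask = df["inputs"].apply(lambda i: i.get("question") == "What is MLflow?")
# Expect expected_answer, must_mention_tracking, and must_mention_models together
print(df.loc[mask, "expectations"].iloc[0])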
Retrieving Datasets
Retrieve an existing dataset by ID, or search across datasets:
- Get by ID
- Search datasets
from mlflow.genai.datasets import get_dataset
# Get a specific dataset by ID
dataset = get_dataset(dataset_id="d-7f2e3a9b8c1d4e5f")
# Access dataset properties
print(f"Name: {dataset.name}")
print(f"Records: {len(dataset.records)}")
print(f"Schema: {dataset.schema}")
print(f"Tags: {dataset.tags}")
from mlflow.genai.datasets import search_datasets
# Search for datasets with filters
datasets = search_datasets(
experiment_ids=["0"],
filter_string="tags.status = 'active' AND name LIKE '%support%'",
order_by=["last_update_time DESC"],
max_results=10,
)
for ds in datasets:
print(f"{ds.name} ({ds.dataset_id}): {len(ds.records)} records")
For details on filter syntax, see the Search Filter Reference below.
Managing Tags
Add, update, or remove tags on a dataset:
from mlflow.genai.datasets import set_dataset_tags, delete_dataset_tag
# Set or update tags
set_dataset_tags(
dataset_id=dataset.dataset_id,
tags={"status": "production", "validated": "true", "version": "2.0"},
)
# Delete a specific tag
delete_dataset_tag(dataset_id=dataset.dataset_id, key="deprecated")
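To confirm the changes, you can re-fetch the dataset and inspect its tags property (shown under Retrieving Datasets above):
from mlflow.genai.datasets import get_dataset

# Re-fetch and verify: 'deprecated' removed, 'status' now 'production'
refreshed = get_dataset(dataset_id=dataset.dataset_id)
print(refreshed.tags)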
Deleting Datasets
Permanently delete a dataset and all of its records:
from mlflow.genai.datasets import delete_dataset
# Delete the entire dataset
delete_dataset(dataset_id="d-1a2b3c4d5e6f7890")
Dataset deletion is permanent and cannot be undone. All records will be removed.
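Because deletion is irreversible, consider snapshotting the records first. A minimal sketch that exports the dataset to JSON before deleting (the file name is arbitrary):
from mlflow.genai.datasets import delete_dataset, get_dataset

dataset = get_dataset(dataset_id="d-1a2b3c4d5e6f7890")
# Keep an offline copy of the records before the irreversible delete
dataset.to_df().to_json("dataset_backup.json", orient="records")
delete_dataset(dataset_id=dataset.dataset_id)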
Working with Dataset Records
The mlflow.entities.EvaluationDataset object provides several ways to access and analyze records:
# Access all records
all_records = dataset.records
# Convert to DataFrame for analysis
df = dataset.to_df()
print(df.head())
# View dataset schema (auto-computed from records)
print(dataset.schema)
# View dataset profile (statistics)
print(dataset.profile)
# Get record count
print(f"Total records: {len(dataset.records)}")
Advanced Topics
Understanding Input Uniqueness
Records are considered unique based on their entire inputs dictionary. Even minor differences create separate records:
# These are treated as different records due to different inputs
record_a = {
"inputs": {"question": "What is MLflow?", "temperature": 0.7},
"expectations": {"expected_answer": "MLflow is an ML platform"},
}
record_b = {
"inputs": {
"question": "What is MLflow?",
"temperature": 0.8,
}, # Different temperature
"expectations": {"expected_answer": "MLflow is an ML platform"},
}
dataset.merge_records([record_a, record_b])
# Results in 2 separate records due to different temperature values
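If volatile fields such as sampling parameters should not define record identity, strip them before merging. A sketch of one possible convention (the VOLATILE_KEYS set is an assumption for illustration, not part of the API):
# Hypothetical convention: drop fields that should not affect record identity
VOLATILE_KEYS = {"temperature", "top_p"}  # assumed volatile sampling parameters

def normalize(record):
    inputs = {k: v for k, v in record["inputs"].items() if k not in VOLATILE_KEYS}
    return {**record, "inputs": inputs}

# Both records now normalize to identical inputs and merge into one
dataset.merge_records([normalize(record_a), normalize(record_b)])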
Source Type Inference
MLflow automatically assigns a source type to each record before sending it to the backend, according to the following rules:
Automatic inference
When no explicit source is provided, MLflow infers the source type from the record's characteristics.
Client-side processing
Source type inference happens in merge_records(), before records are sent to the tracking backend.
Manual override
You can always specify explicit source information to override the automatic inference.
Inference Rules
- TRACE source
- HUMAN source
- CODE source
Records from MLflow traces are automatically assigned the TRACE source type:
# When adding traces directly (automatic TRACE source)
traces = mlflow.search_traces(experiment_ids=["0"], return_type="list")
dataset.merge_records(traces)
# Or when using DataFrame from search_traces
traces_df = mlflow.search_traces(experiment_ids=["0"]) # Returns DataFrame
# Automatically detects traces and assigns TRACE source
dataset.merge_records(traces_df)
Records with expectations are inferred as HUMAN source:
# Records with expectations indicate human review/annotation
human_curated = [
{
"inputs": {"question": "What is MLflow?"},
"expectations": {
"expected_answer": "MLflow is an open-source ML platform",
"must_mention": ["tracking", "models", "deployment"],
}
# Automatically inferred as HUMAN source
}
]
dataset.merge_records(human_curated)
Records containing only inputs (no expectations) are inferred as CODE source:
# Records without expectations are inferred as CODE source
generated_tests = [{"inputs": {"question": f"Test question {i}"}} for i in range(100)]
dataset.merge_records(generated_tests)
Manual Source Override
You can explicitly specify the source type and metadata for any record:
# Specify HUMAN source with metadata
human_curated = {
"inputs": {"question": "What are your business hours?"},
"expectations": {
"expected_answer": "We're open Monday-Friday 9am-5pm EST",
"must_include_timezone": True,
},
"source": {
"source_type": "HUMAN",
"source_data": {"curator": "support_team", "date": "2024-11-01"},
},
}
# Specify DOCUMENT source
from_docs = {
"inputs": {"question": "How to install MLflow?"},
"expectations": {
"expected_answer": "pip install mlflow",
"must_mention_pip": True,
},
"source": {
"source_type": "DOCUMENT",
"source_data": {"document_id": "install_guide", "page": 1},
},
}
dataset.merge_records([human_curated, from_docs])
Available Source Types
TRACE
Production data captured via MLflow tracing - assigned automatically when adding traces
HUMAN
Subject-matter-expert annotations - inferred for records with expectations
CODE
Programmatically generated tests - inferred for records without expectations
DOCUMENT
Test cases derived from documentation or specifications - must be specified explicitly
UNSPECIFIED
Source unknown or not provided - used for legacy or imported data
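For imports with unknown provenance, the UNSPECIFIED type can be set explicitly using the same source dictionary shown in the manual override example. A minimal sketch (legacy_tests.csv and its question column are hypothetical):
import pandas as pd

legacy = pd.read_csv("legacy_tests.csv")  # hypothetical legacy export
records = [
    {
        "inputs": {"question": row["question"]},
        "source": {"source_type": "UNSPECIFIED", "source_data": {}},
    }
    for _, row in legacy.iterrows()
]
dataset.merge_records(records)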
Search Filter Reference
Searchable Fields
| Field | Type | Example |
|---|---|---|
| name | string | name = 'production_tests' |
| tags.<key> | string | tags.status = 'validated' |
| created_by | string | created_by = 'alice@company.com' |
| last_updated_by | string | last_updated_by = 'bob@company.com' |
| created_time | timestamp | created_time > 1698800000000 |
| last_update_time | timestamp | last_update_time > 1698800000000 |
Filter Operators
- =, !=: exact match
- LIKE, ILIKE: pattern matching with the % wildcard (ILIKE is case-insensitive)
- >, <, >=, <=: numeric/timestamp comparison
- AND: combine conditions (OR is not currently supported)
Common Filter Examples
| Filter Expression | Description | Use Case |
|---|---|---|
| name = 'production_qa' | Exact name match | Find a specific dataset |
| name LIKE '%test%' | Pattern matching | Find all test datasets |
| tags.status = 'validated' | Tag equality | Find production-ready datasets |
| tags.version = '2.0' AND tags.team = 'ml' | Multiple tag conditions | Find a specific team's version |
| created_by = 'alice@company.com' | Creator filter | Find datasets by author |
| created_time > 1698800000000 | Time-based filter | Find recently created datasets |
# Complex filter example
datasets = search_datasets(
filter_string="""
tags.status = 'production'
AND name LIKE '%customer%'
AND created_time > 1698800000000
""",
order_by=["last_update_time DESC"],
)
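Timestamp fields are compared as Unix epoch milliseconds (as in the examples above), so filter values are easiest to compute programmatically. A sketch that finds datasets created within the last 7 days:
import time

from mlflow.genai.datasets import search_datasets

# created_time is in Unix epoch milliseconds
cutoff_ms = int((time.time() - 7 * 86400) * 1000)
recent = search_datasets(
    filter_string=f"created_time > {cutoff_ms}",
    order_by=["created_time DESC"],
)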