使用 MLflow 进行深度学习(第 2 部分)
在深度学习领域,对私有数据集上的预训练大语言模型(LLM)进行微调是一种出色的定制化选项,可以提高模型在特定任务上的相关性。这种做法不仅常见,而且对于开发专门化模型至关重要,尤其是在文本分类和摘要等任务中。
在这种情况下,像 MLflow 这样的工具是无价的。MLflow 这类跟踪工具有助于确保训练过程的每个方面——指标、参数和构件——都能被可复现地跟踪和记录,从而方便对调优迭代进行分析、比较和共享。
在这篇博文中,我们将使用 MLflow 2.12 和最近推出的 MLflow 深度学习功能来跟踪微调大语言模型进行文本分类的所有重要方面,包括使用自动记录训练检查点来简化恢复训练的过程。
用例:为文本分类微调 Transformer 模型
我们在这篇博文中的示例场景使用了 unfair-TOS 数据集。
在当今世界,很难找到一个不附带具有法律约束力的服务条款的服务、平台,甚至是消费品。这些百科全书般大小的协议,充满了密集的法律术语和有时令人费解的细节,以至于大多数人只是不经阅读就接受了。然而,有报告指出,有时其中会嵌入一些可疑的不公平条款。
通过机器学习(ML)解决服务条款(TOS)协议中的不公平条款尤为重要,因为这关系到影响消费者的法律协议的透明度和公平性的迫切需求。请看一个 TOS 协议示例中的条款:“我们可能随时修订这些条款。这些变更不具追溯力,且最新版本的条款将始终……” 该条款规定,服务提供商可以随时以任何理由暂停或终止服务,无论是否通知。大多数人会认为这相当不公平。
虽然这句话深埋在一份相当冗长的文件中,但 ML 算法不会像人类那样因费力梳理文本、识别可能显得有些不公平的条款而感到疲惫。通过自动识别潜在的不公平条款,基于 Transformer 的深度学习(DL)模型可以帮助保护消费者免受剥削性行为的侵害,确保更好地遵守法律标准,并培养服务提供商和用户之间的信任。
一个基础的预训练 Transformer 模型,在没有经过专门微调的情况下,在准确识别不公平服务条款方面面临几个挑战。首先,它缺乏理解复杂法律语言所必需的领域特定知识。其次,其训练目标过于宽泛,无法捕捉法律分析所需的细微解释。最后,它可能无法有效识别决定合同条款公平性的微妙上下文含义,使其在这种专业化任务中效果较差。
使用提示工程来处理识别不公平服务条款的问题,若采用闭源大语言模型,其成本可能会高得令人望而却步。这种方法需要大量的试错来完善提示,却无法调整底层模型机制。每次迭代都会消耗大量计算资源,尤其是在使用小样本提示时,这会导致成本不断攀升,却无法保证准确性或有效性相应提高。
在这种背景下,使用 RoBERTa-base 模型特别有效,前提是经过微调。该模型足够强大,可以处理像辨别文本中嵌入指令这样的复杂任务,同时又足够紧凑,可以在适度的硬件(如 Nvidia T4 GPU)上进行微调。
什么是 PEFT?
参数高效微调(PEFT)方法的优势在于,它们保持了预训练模型的大部分参数固定,而只训练少数附加层,或者在与模型权重交互时修改所使用的参数。这种方法不仅在训练期间节省了内存,还显著减少了总训练时间。与为了针对特定任务定制性能而微调基础模型权重的替代方案相比,PEFT 方法可以在时间和金钱上节省大量成本,同时能以比全面微调训练任务所需更少的数据,提供同等或更好的性能结果。
集成 Hugging-Face 模型和 PyTorch Lightning 框架
PyTorch Lightning 与 Hugging Face 的 Transformers 库无缝集成,实现了简化的模型训练工作流,充分利用了 Lightning 易于使用的高级 API 和 HF 的最先进预训练模型。Lightning 与 Transformers 的 PEFT 模块相结合,通过降低代码复杂性并允许使用高质量的预优化模型来处理各种 NLP 任务,从而提高了生产力和可扩展性。
以下是一个使用 PyTorch Lightning 和 HuggingFace 的 peft
模块配置基于 PEFT 的基础模型微调的示例。
from typing import List
from lightning import LightningModule
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModelForSequenceClassification
class TransformerModule(LightningModule):
def __init__(
self,
pretrained_model: str,
num_classes: int = 2,
lora_alpha: int = 32,
lora_dropout: float = 0.1,
r: int = 8,
lr: float = 2e-4
):
super().__init__()
self.model = self.create_model(pretrained_model, num_classes, lora_alpha, lora_dropout, r)
self.lr = lr
self.save_hyperparameters("pretrained_model")
def create_model(self, pretrained_model, num_classes, lora_alpha, lora_dropout, r):
"""Create and return the PEFT model with the given configuration.
Args:
pretrained_model: The path or identifier for the pretrained model.
num_classes: The number of classes for the sequence classification.
lora_alpha: The alpha parameter for LoRA.
lora_dropout: The dropout rate for LoRA.
r: The rank of LoRA adaptations.
Returns:
Model: A model configured with PEFT.
"""
model = AutoModelForSequenceClassification.from_pretrained(
pretrained_model_name_or_path=pretrained_model,
num_labels=num_classes
)
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout
)
return get_peft_model(model, peft_config)
def forward(self, input_ids: List[int], attention_mask: List[int], label: List[int]):
"""Calculate the loss by passing inputs to the model and comparing against ground truth labels.
Args:
input_ids: List of token indices to be fed to the model.
attention_mask: List to indicate to the model which tokens should be attended to, and which should not.
label: List of ground truth labels associated with the input data.
Returns:
torch.Tensor: The computed loss from the model as a tensor.
"""
return self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=label
)
有关完整实现的其他参考资料,可以在配套的仓库中查看。
为基于 PEFT 的微调配置 MLflow
在启动训练过程之前,配置 MLflow 至关重要,以便为训练运行记录所有系统指标、损失指标和参数。从 MLflow 2.12 开始,TensorFlow 和 PyTorch 的自动记录功能现已支持在训练期间对模型权重进行检查点设置,从而在定义的周期频率下提供模型权重的快照,以便在发生错误或计算环境丢失时恢复训练。以下是如何启用此功能的示例。
import mlflow
mlflow.enable_system_metrics_logging()
mlflow.pytorch.autolog(checkpoint_save_best_only = False, checkpoint_save_freq='epoch')
在上面的代码中,我们正在做以下事情:
- 启用系统指标记录:系统资源将被记录到 MLflow,以便了解在整个训练过程中内存、CPU、GPU、磁盘使用和网络流量的瓶颈所在。
- 配置自动记录以记录所有周期的参数、指标和检查点:深度学习涉及试验各种模型架构和超参数设置。自动记录在系统地记录这些实验中扮演着至关重要的角色,使得比较不同的运行并确定哪些配置能产生最佳结果变得更加容易。检查点在每个周期都会被记录,从而能够在项目的初步探索阶段对所有中间周期进行详细评估。然而,通常不建议在后期开发阶段记录所有周期,以避免在最终训练阶段出现过多的数据写入和延迟。
自动记录的检查点指标和模型构件将在模型训练时在 MLflow UI 中可见,如下所示。
记录和早停的重要性
Pytorch Lightning Trainer
回调与 MLflow 的集成在此次训练练习中至关重要。该集成可以在模型微调期间全面跟踪和记录指标、参数和构件,而无需显式调用 MLflow 的记录 API。此外,自动记录 API 允许修改默认的记录行为,可以更改记录频率,允许在每个周期、指定数量的周期后或在明确定义的步骤中进行记录。
早停
早停是神经网络训练中一项关键的正则化技术,旨在通过在验证性能停滞不前时停止训练来帮助防止过拟合。Pytorch Lightning 包含的 API 允许对训练停止进行简单的高级控制,如下所示。
配置 Pytorch Trainer 回调与早停
以下示例展示了在 Lightning
中配置 Trainer
对象,以利用早停来防止过拟合。配置完成后,通过在 Trainer
对象上调用 fit
来执行训练。通过提供 EarlyStopping
回调,并结合 MLflow 的自动记录功能,将使用、记录和跟踪适当的周期数,而无需任何额外的工作。
from dataclasses import dataclass, field
import os
from data import LexGlueDataModule
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping
import mlflow
@dataclass
class TrainConfig:
pretrained_model: str = "bert-base-uncased"
num_classes: int = 2
lr: float = 2e-4
max_length: int = 128
batch_size: int = 256
num_workers: int = os.cpu_count()
max_epochs: int = 10
debug_mode_sample: int | None = None
max_time: dict[str, float] = field(default_factory=lambda: {"hours": 3})
model_checkpoint_dir: str = "/local_disk0/tmp/model-checkpoints"
min_delta: float = 0.005
patience: int = 4
train_config = TrainConfig()
# Instantiate the custom Transformer class for PEFT training
nlp_model = TransformerModule(
pretrained_model=train_config.pretrained_model,
num_classes=train_config.num_classes,
lr=train_config.lr,
)
datamodule = LexGlueDataModule(
pretrained_model=train_config.pretrained_model,
max_length=train_config.max_length,
batch_size=train_config.batch_size,
num_workers=train_config.num_workers,
debug_mode_sample=train_config.debug_mode_sample,
)
# Log system metrics while training loop is running
mlflow.enable_system_metrics_logging()
# Automatically log per-epoch parameters, metrics, and checkpoint weights
mlflow.pytorch.autolog(checkpoint_save_best_only = False)
# Define the Trainer configuration
trainer = Trainer(
callbacks=[
EarlyStopping(
monitor="Val_F1_Score",
min_delta=train_config.min_delta,
patience=train_config.patience,
verbose=True,
mode="max",
)
],
default_root_dir=train_config.model_checkpoint_dir,
fast_dev_run=bool(train_config.debug_mode_sample),
max_epochs=train_config.max_epochs,
max_time=train_config.max_time,
precision="32-true"
)
# Execute the training run
trainer.fit(model=nlp_model, datamodule=datamodule)
MLflow 中的可视化与共享功能
MLflow 2.12 中新引入的针对深度学习的可视化功能,使您能够比较不同运行和构件在各个周期(epoch)中的表现。在比较训练运行时,MLflow 能够生成有用的可视化图表,这些图表可以集成到仪表板中,便于分享。此外,指标与参数的集中存储,可以有效分析训练效果,如下图所示。
何时停止训练?
在训练深度学习模型时,了解何时停止至关重要。高效的训练(为了最小化进行训练的总成本)和最佳的模型性能在很大程度上依赖于防止模型在训练数据上过拟合。训练时间过长的模型将不可避免地变得非常善于“记住”训练数据,导致模型在面对新数据时性能下降。评估这种行为的一个直接方法是确保在训练循环中捕获验证数据集指标(在不在训练数据集中的数据上评估损失指标)。将 MLflow 回调集成到 PyTorch Lightning Trainer 中,可以在可配置的迭代中迭代记录损失指标,从而能够对训练性能进行易于调试的评估,确保可以在适当的时间强制执行停止标准以防止过拟合。
使用 MLflow 评估微调模型的周期检查点
通过 MLflow 对您的训练过程进行细致的跟踪和记录,您可以灵活地在任何任意检查点检索和测试您的模型。为此,您可以使用 mlflow.pytorch.load_model() API 从特定的运行中加载模型,并使用 `predict()` 方法进行评估。
在下面的示例中,我们将从第 3 个周期的模型检查点加载模型,并使用 Lightning
训练模块根据已保存的训练周期的检查点状态生成预测。
import mlflow
mlflow.pytorch.autolog(disable = True)
run_id = '<Add the run ID>'
model = mlflow.pytorch.load_checkpoint(TransformerModule, run_id, 3)
examples_to_test = ["We reserve the right to modify the service price at any time and retroactively apply the adjusted price to historical service usage."]
train_module = Trainer()
tokenizer = AutoTokenizer.from_pretrained(train_config.pretrained_model)
tokens = tokenizer(examples_to_test,
max_length=train_config.max_length,
padding="max_length",
truncation=True)
ds = Dataset.from_dict(dict(tokens))
ds.set_format(
type="torch", columns=["input_ids", "attention_mask"]
)
train_module.predict(model, dataloaders = DataLoader(ds))
总结
将 MLflow 集成到预训练语言模型的微调过程中,尤其是在自定义命名实体识别、文本分类和指令遵循等应用中,代表了在管理和优化深度学习工作流方面的一大进步。在这些工作流中利用 MLflow 的自动记录和跟踪功能,不仅增强了模型开发的可复现性和效率,还促进了一个协作环境,使得见解和改进可以轻松共享和实施。
随着我们不断推动这些模型所能达到的极限,像 MLflow 这样的工具将在发挥其全部潜力方面发挥关键作用。
如果您有兴趣完整地查看整个示例,请随时查看完整的示例实现。
查看代码
我们提供的代码将深入探讨其他方面,如从检查点进行训练、集成 MLflow 和 TensorBoard,以及使用 Pyfunc 进行模型包装等。这些资源专为在 Databricks 社区版上实现而定制。完整示例仓库中的主运行笔记本可以在这里找到。
立即开始使用 MLflow 2.12
立即深入了解最新的 MLflow 更新,提升您管理机器学习项目的方式!凭借我们最新的增强功能,包括高级指标聚合、自动捕获系统指标、直观的特征分组和简化的搜索功能,MLflow 将把您的机器学习工作流提升到新的高度。立即开始使用 MLflow 的前沿工具和功能。
反馈
我们重视您的意见!我们的功能优先级是根据 2023 年末 MLflow 调查的反馈来确定的。请填写我们的2024 年春季调查,通过参与,您可以帮助确保您最想要的功能在 MLflow 中得到实现。