跳到主要内容

MLflow 与 OpenAI Whisper 简介

下载此 Notebook

在本教程中,了解如何将 OpenAI 的 Whisper(一个 ASR 系统)与 MLflow 集成。

您将在此教程中学到什么

  • 使用 Whisper 模型建立音频转录管道
  • 使用 MLflow 记录和管理 Whisper 模型。
  • 推断和理解 Whisper 模型签名
  • 加载并与存储在 MLflow 中的 Whisper 模型交互。
  • 利用 MLflow 的 pyfunc 进行 Whisper 模型服务和转录任务。

什么是 Whisper?

Whisper 由 OpenAI 开发,是一个多功能的 ASR 模型,经过训练可实现高准确度的语音转文本转换。其突出之处在于它在各种口音和环境上进行训练,并通过 Transformers 库提供,易于使用。

为什么将 MLflow 与 Whisper 结合使用?

将 MLflow 与 Whisper 集成增强了 ASR 模型管理

  • 实验跟踪:有助于跟踪模型配置和性能,以获得最佳结果。
  • 模型管理:集中管理不同版本的 Whisper 模型,增强组织性和可访问性。
  • 可复现性:通过跟踪重现模型行为所需的所有组件,确保转录的一致性。
  • 部署:简化 Whisper 模型在各种生产环境中的部署,确保高效应用。

有兴趣了解更多关于 Whisper 的信息吗?要阅读更多关于 Whisper 为 ASR 领域带来的转录能力的重大突破,您可以阅读白皮书,并在 OpenAI 的研究网站上查看更多关于活跃开发和阅读更多关于进展的信息。

准备好增强您的语音转文本能力了吗?让我们使用 MLflow 和 Whisper 探索自动语音识别!

# Disable tokenizers warnings when constructing pipelines
%env TOKENIZERS_PARALLELISM=false

import warnings

# Disable a few less-than-useful UserWarnings from setuptools and pydantic
warnings.filterwarnings("ignore", category=UserWarning)
env: TOKENIZERS_PARALLELISM=false

设置环境和获取音频数据

使用 Whisper 进行转录的初步步骤:获取音频并设置 MLflow。

在使用 OpenAI 的 Whisper 进行音频转录过程之前,需要进行一些准备步骤,以确保一切就绪,获得顺畅高效的转录体验。

音频获取

第一步是获取一个音频文件进行处理。在本教程中,我们使用一个来自 NASA 的公开可用音频文件。此样本音频提供了一个实际示例,以展示 Whisper 的转录能力。

模型和管道初始化

我们从 Transformers 库加载 Whisper 模型及其分词器和特征提取器。这些组件对于处理音频数据并将其转换为 Whisper 模型能够理解和转录的格式至关重要。接下来,我们使用 Whisper 模型创建一个转录管道。此管道简化了将音频数据馈送给模型并获取转录的过程。

MLflow 环境设置

除了模型和音频数据设置外,我们还初始化了 MLflow 环境。MLflow 用于跟踪和管理我们的实验,提供了一种有组织的方式来记录转录过程和结果。

以下代码块涵盖了这些初始设置步骤,为使用 Whisper 模型进行音频转录任务奠定了基础。

import requests
import transformers

import mlflow

# Acquire an audio file that is in the public domain
resp = requests.get(
"https://www.nasa.gov/wp-content/uploads/2015/01/590325main_ringtone_kennedy_WeChoose.mp3"
)
resp.raise_for_status()
audio = resp.content

# Set the task that our pipeline implementation will be using
task = "automatic-speech-recognition"

# Define the model instance
architecture = "openai/whisper-large-v3"

# Load the components and necessary configuration for Whisper ASR from the Hugging Face Hub
model = transformers.WhisperForConditionalGeneration.from_pretrained(architecture)
tokenizer = transformers.WhisperTokenizer.from_pretrained(architecture)
feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(architecture)
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

# Instantiate our pipeline for ASR using the Whisper model
audio_transcription_pipeline = transformers.pipeline(
task=task, model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

格式化转录输出

在本节中,我们介绍一个实用函数,该函数仅用于增强此 Jupyter Notebook 演示中转录输出的可读性。请务必注意,此函数是为演示目的而设计的,不应包含在生产代码中,也不应用于本教程之外的任何其他目的。

format_transcription 函数接受一个长字符串的转录文本,并通过将其分割成句子并插入换行符来对其进行格式化。这使得在 Notebook 环境中打印输出时更易于阅读。

def format_transcription(transcription):
"""
Function for formatting a long string by splitting into sentences and adding newlines.
"""
# Split the transcription into sentences, ensuring we don't split on abbreviations or initials
sentences = [
sentence.strip() + ("." if not sentence.endswith(".") else "")
for sentence in transcription.split(". ")
if sentence
]

# Join the sentences with a newline character
return "
".join(sentences)

执行转录管道

使用 Whisper 管道执行音频转录并查看输出。

设置好 Whisper 模型和音频转录管道后,我们的下一步是处理音频文件以提取其转录文本。本教程的这一部分至关重要,因为它展示了 Whisper 模型在将口语转换为书面文本方面的实际应用。

转录过程

下面的代码块将音频文件馈送给管道,然后生成转录文本。前面定义的 format_transcription 函数通过使用句子分割和换行符格式化输出来增强可读性。

保存前测试的重要性

在将模型保存在 MLflow 中之前测试转录管道至关重要。此步骤验证模型是否按预期工作,确保准确性和可靠性。这种验证可以避免部署后出现问题,并确认模型与其训练数据保持一致的性能。它还提供了一个基准,用于与从 MLflow 加载模型后的输出进行比较,确保性能的一致性。

执行以下代码以转录音频并评估 Whisper 模型提供的转录文本的质量和准确性。

# Verify that our pipeline is capable of processing an audio file and transcribing it
transcription = audio_transcription_pipeline(audio)

print(format_transcription(transcription["text"]))
We choose to go to the moon in this decade and do the other things.
Not because they are easy, but because they are hard.
3, 2, 1, 0.
All engines running.
Liftoff.
We have a liftoff.
32 minutes past the hour.
Liftoff on Apollo 11.

模型签名和配置

为 Whisper 生成模型签名,以了解其输入和输出数据要求。

模型签名对于定义 Whisper 模型输入和输出的模式至关重要,明确了预期的数据类型和结构。此步骤确保模型正确处理输入并输出结构化数据。

处理不同的音频格式

虽然默认签名涵盖了二进制音频数据,但 transformers flavor 支持多种格式,包括 numpy 数组和基于 URL 的输入。这种灵活性使得 Whisper 可以从各种来源进行转录,尽管此处未演示基于 URL 的转录。

模型配置

设置模型配置包括音频处理的步长等参数。这些设置可以调整以适应不同的转录需求,从而增强 Whisper 在特定场景下的性能。

运行下一个代码块,推断模型的签名并配置关键参数,使 Whisper 的功能符合您项目的要求。

# Specify parameters and their defaults that we would like to be exposed for manipulation during inference time
model_config = {
"chunk_length_s": 20,
"stride_length_s": [5, 3],
}

# Define the model signature by using the input and output of our pipeline, as well as specifying our inference parameters that will allow for those parameters to
# be overridden at inference time.
signature = mlflow.models.infer_signature(
audio,
mlflow.transformers.generate_signature_output(audio_transcription_pipeline, audio),
params=model_config,
)

# Visualize the signature
signature
inputs: 
[binary]
outputs: 
[string]
params: 
['chunk_length_s': long (default: 20), 'stride_length_s': long (default: [5, 3]) (shape: (-1,))]

创建实验

我们创建一个新的 MLflow 实验,以便我们要记录模型的运行不会记录到默认实验,而是拥有自己的上下文相关条目。

# If you are running this tutorial in local mode, leave the next line commented out.
# Otherwise, uncomment the following line and set your tracking uri to your local or remote tracking server.

# mlflow.set_tracking_uri("http://127.0.0.1:8080")

mlflow.set_experiment("Whisper Transcription ASR")
<Experiment: artifact_location='file:///Users/benjamin.wilson/repos/mlflow-fork/mlflow/docs/source/llms/transformers/tutorials/audio-transcription/mlruns/864092483920291025', creation_time=1701294423466, experiment_id='864092483920291025', last_update_time=1701294423466, lifecycle_stage='active', name='Whisper Transcription ASR', tags={}>

使用 MLflow 记录模型

学习如何使用 MLflow 记录 Whisper 模型及其配置。

在 MLflow 中记录 Whisper 模型是捕获模型复现、共享和部署所需关键信息的重要步骤。此过程包括

模型记录的关键组件

  • 模型信息:包括模型、其签名和输入示例。
  • 模型配置:为模型设置的任何特定参数,例如块长度步长

使用 MLflow 的 log_model 函数

此函数在 MLflow 运行中使用,用于记录模型及其配置。它确保记录了模型使用所需的所有必要组件。

执行下一个单元格中的代码将在当前的 MLflow 实验中记录 Whisper 模型。这包括将模型存储在指定的 artifact 路径中,并记录在推断过程中将应用的默认配置。

# Log the pipeline
with mlflow.start_run():
model_info = mlflow.transformers.log_model(
transformers_model=audio_transcription_pipeline,
artifact_path="whisper_transcriber",
signature=signature,
input_example=audio,
model_config=model_config,
# Since MLflow 2.11.0, you can save the model in 'reference-only' mode to reduce storage usage by not saving
# the base model weights but only the reference to the HuggingFace model hub. To enable this, uncomment the
# following line:
# save_pretrained=False,
)

加载和使用模型管道

探索如何从 MLflow 加载和使用 Whisper 模型管道。

在 MLflow 中记录 Whisper 模型后,下一个关键步骤是加载并使用它进行推断。此过程确保我们记录的模型按预期运行,并可以有效地用于音频转录等任务。

加载模型

使用 MLflow 的 load_model 函数以其原生格式加载模型。此步骤验证模型在 MLflow 中记录后可以被无缝检索和使用。

使用加载的模型

加载后,模型即可进行推断。我们通过将 MP3 音频文件传递给模型并获取其转录文本来演示这一点。此测试是对模型记录后能力的实际演示。

此步骤是在转向更复杂的部署场景之前的一种验证形式。确保模型在其原生格式下正常工作有助于故障排除并简化部署过程,特别是对于像 Whisper 这样大型且复杂的模型。

# Load the pipeline in its native format
loaded_transcriber = mlflow.transformers.load_model(model_uri=model_info.model_uri)

# Perform transcription with the native pipeline implementation
transcription = loaded_transcriber(audio)

print(f"
Whisper native output transcription:
{format_transcription(transcription['text'])}")
2023/11/30 12:51:43 INFO mlflow.transformers: 'runs:/f7503a09d20f4fb481544968b5ed28dd/whisper_transcriber' resolved as 'file:///Users/benjamin.wilson/repos/mlflow-fork/mlflow/docs/source/llms/transformers/tutorials/audio-transcription/mlruns/864092483920291025/f7503a09d20f4fb481544968b5ed28dd/artifacts/whisper_transcriber'
Loading checkpoint shards:   0%|          | 0/13 [00:00<?, ?it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Whisper native output transcription:
We choose to go to the moon in this decade and do the other things.
Not because they are easy, but because they are hard.
3, 2, 1, 0.
All engines running.
Liftoff.
We have a liftoff.
32 minutes past the hour.
Liftoff on Apollo 11.

使用 Pyfunc Flavor 进行推断

了解 MLflow 的 pyfunc flavor 如何促进灵活的模型部署。

MLflow 的 pyfunc flavor 提供了一个用于模型推断的通用接口,可在各种机器学习框架和部署环境中提供灵活性。此功能对于部署原始框架可能不可用或需要更具适应性接口的模型非常有益。

使用 Pyfunc 加载和预测

下面的代码演示了如何将 Whisper 模型加载为 pyfunc 并使用它进行预测。此方法突出了 MLflow 在不同场景中适应和部署模型的能力。

输出格式注意事项

请注意使用 pyfunc 与原生格式相比输出格式的差异。pyfunc 输出符合标准的 pyfunc 输出签名,通常表示为 List[str] 类型,与 MLflow 更广泛的模型输出标准保持一致。

# Load the saved transcription pipeline as a generic python function
pyfunc_transcriber = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)

# Ensure that the pyfunc wrapper is capable of transcribing passed-in audio
pyfunc_transcription = pyfunc_transcriber.predict([audio])

# Note: the pyfunc return type if `return_timestamps` is set is a JSON encoded string.
print(f"
Pyfunc output transcription:
{format_transcription(pyfunc_transcription[0])}")
Loading checkpoint shards:   0%|          | 0/13 [00:00<?, ?it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2023/11/30 12:52:02 WARNING mlflow.transformers: params provided to the `predict` method will override the inference configuration saved with the model. If the params provided are not valid for the pipeline, MlflowException will be raised.
Pyfunc output transcription:
We choose to go to the moon in this decade and do the other things.
Not because they are easy, but because they are hard.
3, 2, 1, 0.
All engines running.
Liftoff.
We have a liftoff.
32 minutes past the hour.
Liftoff on Apollo 11.

教程总结

在本教程中,我们探讨了如何

  • 使用 OpenAI Whisper 模型设置音频转录管道。
  • 格式化和准备音频数据进行转录。
  • 使用 MLflow 记录、加载和使用模型,利用原生和 pyfunc 两种 flavors 进行推断。
  • 格式化输出,以便在 Jupyter Notebook 环境中提高可读性和实用性。

我们已经看到了使用 MLflow 管理机器学习生命周期的好处,包括实验跟踪、模型版本控制、可复现性和部署。通过将 MLflow 与 Transformers 库集成,我们简化了使用最先进 NLP 模型的过程,使得跟踪、管理和部署尖端 NLP 应用变得更加容易。