利用可视化和 MLflow 进行深入的模型分析
简介
在任何机器学习项目中,理解所开发模型的行为、性能和特征都非常重要。清晰、信息丰富的可视化在理解模型方面起着至关重要的作用,可以深入了解模型的模式、错误和效率。
在本指南的这一部分,我们将介绍一个与回归任务相关的常用且有用的图表生成和存储的笔记本。
我们将主要通过以下两种方式与我们记录的模型一起记录图表:
- 直接图表记录:通过
mlflow.log_figure(),我们将使用内存中的图表引用来记录生成的图表。 - 记录本地图表文件:通过
mlflow.log_artifact(),允许我们将本地存储的图像记录到运行中。
可视化在模型分析中的作用
可视化充当了深入了解机器学习模型复杂世界的窗口。它们使得探索各种方面成为可能:
- 理解数据:初始可视化可以深入了解数据,揭示模式、异常和关系,这些信息可以为整个建模过程提供依据。
- 模型评估:残差图和预测误差图等图表有助于诊断模型问题和评估其性能。
- 超参数调优:可视化有助于理解不同超参数对模型性能的影响,指导选择过程。
- 错误分析:它们有助于分析模型所犯错误的类型和模式,从而深入了解可能的改进之处。
关于程序化生成图表的警告
在本指南此小节的配套笔记本中,您将看到在函数中声明图表。这种方法不同于机器学习教程和指南中常见的示例,因此有必要说明选择此方法的原因。
核心问题:状态性
笔记本在单元格之间固有地维护状态。虽然此功能可能很有益,但它对确保代码和输出的可靠性和准确性构成了重大挑战,尤其是对于可视化。
乱序执行的挑战
笔记本环境中最大的问题之一是乱序执行的可能性。单元格可以按任何顺序运行,导致变量或输出的状态无法反映最新的代码更改。对于可视化而言,这个问题尤为严重。如果生成了一个图表,然后在单独的单元格中显示它,那么乱序执行单元格可能会导致显示过时或不正确的可视化。
确保准确的可视化
为了使可视化能够达到传达准确、清晰和可靠信息的目的,它们必须与当前数据和模型的状态相对应。在笔记本环境中确保这种对应关系需要仔细管理单元格的执行顺序和状态,这可能很繁琐且容易出错。
为什么使用函数来生成图表
为了缓解这些挑战,示例代码选择在函数中声明图表。这种方法提供了几个优点:
- 封装:通过将图表生成封装在函数中,代码确保每次调用函数时,都会使用当前的数据状态生成图表。这种封装避免了乱序单元格执行影响图表准确性的陷阱。
- 灵活性和可重用性:函数提供在不复制代码的情况下使用不同参数和数据生成图表的灵活性。这种可重用性增强了代码的可维护性和可读性。
- 与 MLflow 集成:函数与 MLflow 无缝集成,允许将图表与指标、参数和模型一起记录,确保可视化与特定的运行和模型状态对应。这种集成在 MLflow UI 中提供了模型、指标和图表的可靠且统一的视图,避免了笔记本中可能出现的零散视图。
- 避免输出到标准输出:基于函数的方法避免了将图表直接打印到笔记本的标准输出。直接打印会弄乱笔记本,增加保存的笔记本大小,并可能导致笔记本中显示多个图表时出现混淆。通过直接将图表记录到 MLflow,示例代码使笔记本保持整洁,确保图表与特定模型运行相对应,并利用 MLflow 的 UI 查看和比较图表。
通过将图表的生成封装并限定在训练上下文(在 mlflow.start_run() 中)内,我们可以获得笔记本提供的命令式迭代代码开发的所有灵活性、易用性和优势,而无需承担记录过时、无效或不准确且不反映实际数据或模型状态的图表的风险。
将可视化与 MLflow 集成的优势
将可视化与 MLflow 集成具有多项显著优势:
- 持久存储:将可视化与模型一起存储在 MLflow 中,确保它们可供将来参考,防止因会话终止或其他问题而丢失。
- 出处:它为可视化提供了清晰的出处,确保它们提供的见解始终可以追溯到确切的模型版本和数据集。
- 一致性:确保可视化与正确的模型版本相对应,避免混淆和错误。
- 可访问性:使所有团队成员都能轻松访问可视化,增强协作和见解共享。
生成图表
在本指南本节的配套笔记本中,有许多与回归相关的图表样本。有些,如相关矩阵图,与特征数据集相关,而另一些,如系数图,仅在我们训练好模型后才相关。
无论我们是否使用训练好的模型,记录这些图像构件的方法都是相似的。
定义图表
在数据可视化的复杂世界中,图表的结构化和有组织地呈现至关重要。下面是一个生成箱线图的示例,该图将连续变量与分类(有序)变量进行比较。该示例采用了典型的 matplotlib 实现,并使用 seaborn 进行了增强,以获得更精炼的视觉效果。这种结构对于确保我们建模代码的清晰度和可读性至关重要。通过将图表生成定义为一个单独的可调用函数,我们可以维护一个干净有序的代码库。这种方法至关重要,尤其是在笔记本环境中,以确保每次训练迭代都有一个特定且明确的图表生成引用,该引用直接链接到训练迭代中使用的确切数据状态。这种方法可以减轻与声明性定义和具体化图表相关的风险,因为如果数据修改后未重新生成,这些图表可能会导致数据表示不一致和错误。
def plot_box_weekend(df, style="seaborn", plot_size=(10, 8)):
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
sns.boxplot(data=df, x="weekend", y="demand", ax=ax, color="lightgray")
sns.stripplot(
data=df,
x="weekend",
y="demand",
ax=ax,
hue="weekend",
palette={0: "blue", 1: "green"},
alpha=0.15,
jitter=0.3,
size=5,
)
ax.set_title("Box Plot of Demand on Weekends vs. Weekdays", fontsize=14)
ax.set_xlabel("Weekend (0: No, 1: Yes)", fontsize=12)
ax.set_ylabel("Demand", fontsize=12)
for i in ax.get_xticklabels() + ax.get_yticklabels():
i.set_fontsize(10)
ax.legend_.remove()
plt.tight_layout()
plt.close(fig)
return fig
关键要素
- 标题应用:在图表中包含标题不仅是形式,更是确保清晰度和可理解性的必要条件,尤其是在 MLflow UI 中。精心制作的标题提供了全面的概述,有助于立即理解,消除任何歧义或混淆。
- 覆盖默认大小:调整字体大小和图表大小等各种元素的默认大小,对于确保 MLflow UI 中图表的可读性和视觉吸引力至关重要。它确保图表在查看平台或屏幕尺寸方面保持可读性和清晰度。
- 轴标签:正确标记的轴是可理解的、自给自足的图表的支柱。它们提供关于数据维度的清晰信息,使得图表无需外部参考或解释即可理解。
- 图形关闭:在返回图形之前关闭图形可确保干净整洁的笔记本环境。它防止图表意外显示在笔记本的标准输出中,避免混淆并保持笔记本的组织结构。
- 图例移除:移除图表中自动生成的图例可提高视觉清晰度和可读性。它消除了不必要的混乱,使图表更加简洁明了,确保重点放在重要的数据表示上。
定义要本地保存的图表
在将图表记录到 MLflow 之前将其本地保存有时更有优势。下面的示例说明了相关矩阵图的生成,在调用时保存图像,而不是返回内存引用。这种方法虽然不同,但仍然与 MLflow 无缝兼容,确保了相同的组织和访问级别,并且在图表访问和使用方面具有额外的灵活性。
def plot_correlation_matrix_and_save(
df, style="seaborn", plot_size=(10, 8), path="/tmp/corr_plot.png"
):
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
# Calculate the correlation matrix
corr = df.corr()
# Generate a mask for the upper triangle
mask = np.triu(np.ones_like(corr, dtype=bool))
# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(
corr,
mask=mask,
cmap="coolwarm",
vmax=0.3,
center=0,
square=True,
linewidths=0.5,
annot=True,
fmt=".2f",
)
ax.set_title("Feature Correlation Matrix", fontsize=14)
plt.tight_layout()
plt.close(fig)
# convert to filesystem path spec for os compatibility
save_path = pathlib.Path(path)
fig.savefig(path)
关键见解
- 相关性热力图:在此背景下使用热力图提供了直观有效的特征相关性表示。它允许轻松识别不同特征之间的关系,增强了可理解性和分析深度。
- 标题和布局调整:包含清晰描述性的标题以及布局调整可确保清晰紧凑的呈现,从而增强图表的使用性和解释的便捷性。
- 本地保存图表:将图形本地保存可提供便捷的访问和参考,确保它不与笔记本的执行状态相关联。它提供了访问的灵活性,并确保图表独立可用,从而有助于更组织和高效的数据分析和模型评估过程。
记录图表图像
在主笔记本下面的代码片段中,我们将训练和图表生成执行为单个原子操作。如前所述,这有助于确保,无论笔记本中任何其他单元格的状态如何,生成的图表都将引用用于训练和评估模型的训练数据状态。
对于除相关矩阵以外的所有图表,我们在调用 mlflow.log_figure() 时使用 matplotlib 的直接 Figure 对象引用。对于相关矩阵,我们正在处理本地保存的 .png 图像文件。这需要使用更通用的伪影写入器(支持任何文件类型)mlflow.log_artifact()。
为了简单起见,如果您想将大量图表记录到模型中,建议使用作用域为目录的 mlflow.log_artifacts()。此 API 将记录给定本地目录路径中的所有文件,而无需显式命名每个文件并进行大量 log_artifact() 调用。如果使用基于目录的 log_artifacts(),请确保您的本地文件名足够相关且富有表现力,以便在 MLflow UI 中区分图表的内容。虽然 log_artifact() 允许您在将文件记录到 MLflow 时重命名文件,但批量处理 log_artifacts() API 则不允许(文件名将原样传输)。
mlflow.set_tracking_uri("http://127.0.0.1:8080")
mlflow.set_experiment("Visualizations Demo")
X = my_data.drop(columns=["demand", "date"])
y = my_data["demand"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
fig1 = plot_time_series_demand(my_data, window_size=28)
fig2 = plot_box_weekend(my_data)
fig3 = plot_scatter_demand_price(my_data)
fig4 = plot_density_weekday_weekend(my_data)
# Execute the correlation plot, saving the plot to a local temporary directory
plot_correlation_matrix_and_save(my_data)
# Define our Ridge model
model = Ridge(alpha=1.0)
# Train the model
model.fit(X_train, y_train)
# Make predictions
y_pred = model.predict(X_test)
# Calculate error metrics
mse = mean_squared_error(y_test, y_pred)
rmse = math.sqrt(mse)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
msle = mean_squared_log_error(y_test, y_pred)
medae = median_absolute_error(y_test, y_pred)
# Generate prediction-dependent plots
fig5 = plot_residuals(y_test, y_pred)
fig6 = plot_coefficients(model, X_test.columns)
fig7 = plot_prediction_error(y_test, y_pred)
fig8 = plot_qq(y_test, y_pred)
# Start an MLflow run for logging metrics, parameters, the model, and our figures
with mlflow.start_run() as run:
# Log the model
mlflow.sklearn.log_model(sk_model=model, input_example=X_test, name="model")
# Log the metrics
mlflow.log_metrics(
{"mse": mse, "rmse": rmse, "mae": mae, "r2": r2, "msle": msle, "medae": medae}
)
# Log the hyperparameter
mlflow.log_param("alpha", 1.0)
# Log plots
mlflow.log_figure(fig1, "time_series_demand.png")
mlflow.log_figure(fig2, "box_weekend.png")
mlflow.log_figure(fig3, "scatter_demand_price.png")
mlflow.log_figure(fig4, "density_weekday_weekend.png")
mlflow.log_figure(fig5, "residuals_plot.png")
mlflow.log_figure(fig6, "coefficients_plot.png")
mlflow.log_figure(fig7, "prediction_errors.png")
mlflow.log_figure(fig8, "qq_plot.png")
# Log the saved correlation matrix plot by referring to the local file system location
mlflow.log_artifact("/tmp/corr_plot.png")
在 UI 中查看图表
在执行此训练单元格后,如果我们转到 MLflow UI,我们可以在伪影查看器窗格中看到所有已定义的图表。无论图表是使用 log_figure() API 记录的,还是从本地文件系统获取并通过 log_artifacts() 记录的,我们都可以看到与我们的数据和训练模型相关的、捕获运行进行时状态的、与运行相关的图表。
挑战
您能否想到一些额外的图表,对于数据验证、回归建模或一般的预测质量来说是相关的?
如果您有兴趣,请单击下面的按钮获取笔记本副本,并按照说明进行操作。
下载笔记本并用 Jupyter 打开后
- 实现一些额外的图表,这些图表代表了您在训练(或重新训练)此类模型时希望看到的视图。
- 不要返回图形,而是将每个图表保存到一个公共目录。
- 确保所有图表文件名都是唯一的,并且能指示图表内容。
- 使用
mlflow.log_artifacts()(而不是mlflow.log_artifact())来记录目录内容到运行中。 - 验证 MLflow UI 中图表的渲染。
log_artifacts() API 有一个可选的 artifact_path 参数,可以从默认值 None 覆盖,以便将这些附加图表隔离到 MLflow 伪影存储(以及 UI)中的自己的目录中。如果您要记录数十个具有明显类别分组的图表,这会非常有用,而无需在伪影查看器的主要根目录中填充大量文件。
结论
可视化是构建高质量模型的重要组成部分。MLflow 通过其记录图形、图表和图像的原生集成,使得不仅可以轻松地将用于训练的数据可视化,还可以轻松地将训练事件的结果可视化。
通过可以在模型正在训练的上下文范围内使用的简单、高级 API,可以消除状态不一致,确保每个图表都精确反映训练时数据和模型的状态。