利用可视化和 MLflow 进行深入的模型分析
引言
在任何机器学习项目中,理解所开发模型的行为、性能和特征都非常重要。清晰、信息丰富的可视化在此理解中起着至关重要的作用,能够深入洞察模型的模式、错误和效率。
在本指南的这一部分,我们将查看一个笔记本,其中涉及与回归任务相关的常用且有用的图表的生成和存储。
我们将探讨两种主要的方式来记录图表以及我们记录的模型:
- 通过
mlflow.log_figure()
进行直接图表记录:我们将使用生成图表的内存中的图表引用。 - 通过
mlflow.log_artifact()
记录本地图表文件:这将允许我们将本地存储的图像记录到运行中。
可视化在模型分析中的作用
可视化是洞察机器学习模型复杂世界的窗口。它们能够探索各个方面:
- 理解数据:初步的可视化可以深入探索数据,揭示可以指导整个建模过程的模式、异常和关系。
- 模型评估:残差图和预测误差图等图表有助于诊断模型问题并评估其性能。
- 超参数调优:可视化有助于理解不同超参数对模型性能的影响,指导选择过程。
- 错误分析:它们有助于分析模型产生的错误的类型和模式,为可能的改进提供洞察。
关于图表程序化生成的一个警告
在本指南这一小节配套的笔记本中,您将看到图表在函数内部声明。这种方法与机器学习教程和指南中常见的示例有所不同,因此有必要阐明为什么提供的示例选择了这种方法。
核心问题:状态性
笔记本天生会在单元格之间维护状态。虽然此特性可能有所助益,但它在确保代码和输出的可靠性和准确性方面带来了重大挑战,特别是对于可视化而言。
乱序执行的挑战
笔记本环境中最显著的问题之一是乱序执行的可能性。单元格可以按任何顺序运行,从而导致变量或输出无法反映最新的代码更改。这个问题对于可视化来说尤为突出。如果在生成图表后在另一个单元格中显示,乱序运行单元格可能导致显示过时或不正确的可视化。
确保可视化准确性
为了让可视化能够达到传递准确、清晰和可靠信息的目的,它们必须与当前的数据和模型状态相对应。在笔记本环境中确保这种对应关系需要仔细管理单元格执行顺序和状态,这可能既繁琐又容易出错。
为什么使用函数生成图表
为了缓解这些挑战,示例代码选择在函数内部声明图表。这种方法提供了几个优势:
- 封装:通过将图表生成封装在函数内部,代码确保每次调用函数时都会使用当前数据状态生成图表。这种封装避免了乱序单元格执行影响图表准确性的陷阱。
- 灵活性和可重用性:函数提供了使用不同参数和数据生成图表的灵活性,而无需重复代码。这种可重用性提高了代码的可维护性和可读性。
- 与 MLflow 集成:函数与 MLflow 无缝集成,允许将图表与指标、参数和模型一起记录,确保可视化与特定的运行和模型状态相对应。这种集成在 MLflow UI 中提供了模型、指标和图表的可靠且整合的视图,避免了笔记本中可能出现的视图不一致问题。
- 避免在 Stdout 中显示:基于函数的方法避免了将图表直接打印到笔记本的标准输出(stdout)。直接打印会使笔记本变得混乱,增加保存的笔记本文件大小,并导致笔记本中显示多个图表时产生混淆。通过在 MLflow 中直接记录图表,示例代码保持了笔记本的整洁,确保图表与特定的模型运行相对应,并利用 MLflow 的 UI 来查看和比较图表。
通过将图表生成封装并限定在训练上下文(即 mlflow.start_run()
内部),我们可以获得笔记本带来的所有命令式迭代代码开发的灵活性、易用性和优势,同时避免记录反映旧的、无效的或不准确的数据或模型状态的图表。
将可视化与 MLflow 集成的优势
将可视化与 MLflow 集成带来了几个显著的优势:
- 持久存储:将可视化与模型一起存储在 MLflow 中,确保它们可供未来参考,防止因会话终止或其他问题而丢失。
- 溯源:它为可视化提供了清晰的溯源信息,确保它们提供的洞察始终可以追溯到精确的模型版本和数据集。
- 一致性:确保可视化与正确的模型版本相对应,防止混淆和错误。
- 可访问性:使所有团队成员都能轻松访问可视化,增强协作和洞察共享。
生成图表
在本指南本节的配套笔记本中,提供了许多与回归相关的图表示例。其中一些,例如相关矩阵图,与特征数据集相关;而另一些,例如系数图,仅在我们拥有训练好的模型后才相关。
无论我们是否使用训练好的模型,记录这些图像工件的方法都是类似的。
定义图表
在复杂的数据可视化世界中,图表的结构化和有组织呈现至关重要。下面是一个生成箱线图的示例,它比较了一个连续变量与一个分类(有序)变量。该示例使用了典型的 matplotlib
实现,并用 seaborn
进行了增强,以获得更精美的视觉效果。这种结构对于确保建模代码的清晰度和可读性至关重要。通过将图表生成定义为一个独立的、可调用的函数,我们保持了干净和有组织的代码库。这种方法尤其在笔记本环境中至关重要,以确保每次训练迭代都有一个特定且明确的图表生成引用,直接链接到训练迭代中使用的数据的精确状态。这种方法缓解了与声明式定义和具体化图表相关的风险,这些图表如果在数据修改后未重新生成,可能导致数据表示不一致和错误。
def plot_box_weekend(df, style="seaborn", plot_size=(10, 8)):
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
sns.boxplot(data=df, x="weekend", y="demand", ax=ax, color="lightgray")
sns.stripplot(
data=df,
x="weekend",
y="demand",
ax=ax,
hue="weekend",
palette={0: "blue", 1: "green"},
alpha=0.15,
jitter=0.3,
size=5,
)
ax.set_title("Box Plot of Demand on Weekends vs. Weekdays", fontsize=14)
ax.set_xlabel("Weekend (0: No, 1: Yes)", fontsize=12)
ax.set_ylabel("Demand", fontsize=12)
for i in ax.get_xticklabels() + ax.get_yticklabels():
i.set_fontsize(10)
ax.legend_.remove()
plt.tight_layout()
plt.close(fig)
return fig
关键要素
- 应用标题:在图表中包含标题不仅仅是一种形式,它对于确保清晰度和易理解性至关重要,尤其是在 MLflow UI 中。精心设计的标题提供了全面的概览,有助于立即理解并消除任何歧义或困惑。
- 覆盖默认大小:调整字体和图表大小等各种元素的默认大小对于确保图表在 MLflow UI 中的可读性和视觉吸引力至关重要。它确保图表无论在何种查看平台或屏幕尺寸下都保持清晰可读。
- 轴标签:正确标记的轴是易于理解和自洽图表的基础。它们提供了关于数据维度的清晰信息,使得图表无需外部参考或解释即可理解。
- 关闭图表:在返回图表之前将其关闭,确保笔记本环境干净整洁。这可以防止图表在笔记本的标准输出中无意中显示,避免混淆并维护笔记本的组织结构。
- 移除图例:移除图表中自动生成的图例可增强视觉清晰度和可读性。它防止了不必要的混乱,使图表更简洁明了,确保焦点集中在重要的数据表示上。
定义保存到本地的图表
在某些情况下,先将图表保存到本地再记录到 MLflow 会更有优势。下面的示例展示了如何生成相关矩阵图,并在调用时保存图像,而不是返回内存中的引用。这种方法虽然不同,但仍与 MLflow 无缝兼容,确保了相同的组织和访问级别,并在图表的访问和使用方面提供了额外的灵活性。
def plot_correlation_matrix_and_save(
df, style="seaborn", plot_size=(10, 8), path="/tmp/corr_plot.png"
):
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
# Calculate the correlation matrix
corr = df.corr()
# Generate a mask for the upper triangle
mask = np.triu(np.ones_like(corr, dtype=bool))
# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(
corr,
mask=mask,
cmap="coolwarm",
vmax=0.3,
center=0,
square=True,
linewidths=0.5,
annot=True,
fmt=".2f",
)
ax.set_title("Feature Correlation Matrix", fontsize=14)
plt.tight_layout()
plt.close(fig)
# convert to filesystem path spec for os compatibility
save_path = pathlib.Path(path)
fig.savefig(path)
关键洞察
- 相关性热力图:在这种情况下使用热力图提供了一种视觉直观且有效的特征相关性表示方法。它使得识别不同特征之间的关系变得容易,从而增强了理解性和分析深度。
- 标题和布局调整:包含清晰且描述性的标题以及布局调整可确保清晰紧凑的展示,提高了图表的使用性和解释便捷性。
- 图表本地保存:在本地保存图表提供了轻松访问和参考的方式,确保它不依赖于笔记本的执行状态。它提供了访问的灵活性,并确保图表独立可用,有助于更条理化和高效的数据分析和模型评估过程。
记录图表图像
在下面主笔记本中的代码片段中,我们将训练和图表生成作为单个原子操作执行。如前所述,这有助于确保无论笔记本中任何其他单元格的状态如何,生成的图表都将引用用于训练和评估模型的训练数据状态。
对于除相关矩阵之外的所有图表,我们在调用 mlflow.log_figure()
时,都使用了图表的直接 matplotlib
Figure
对象引用。对于相关矩阵,我们操作的是本地保存的 .png
图像文件。这需要使用更通用的工件写入器(它支持任何文件类型)mlflow.log_artifact()
。
为了简化,如果您想将大量图表记录到模型中,建议使用目录范围的 mlflow.log_artifacts()
。这个 API 会记录给定本地目录路径中的所有文件,而无需显式命名每个文件并进行大量 log_artifact()
调用。如果使用基于目录的 log_artifacts()
,请确保您的本地文件名相关且具有足够的说明性,以便在 MLflow UI 中区分图表内容。虽然 log_artifact()
允许您在记录到 MLflow 时重命名给定文件的名称,但批量处理的 log_artifacts()
API 不允许(文件名将原样传输)。
mlflow.set_tracking_uri("http://127.0.0.1:8080")
mlflow.set_experiment("Visualizations Demo")
X = my_data.drop(columns=["demand", "date"])
y = my_data["demand"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
fig1 = plot_time_series_demand(my_data, window_size=28)
fig2 = plot_box_weekend(my_data)
fig3 = plot_scatter_demand_price(my_data)
fig4 = plot_density_weekday_weekend(my_data)
# Execute the correlation plot, saving the plot to a local temporary directory
plot_correlation_matrix_and_save(my_data)
# Define our Ridge model
model = Ridge(alpha=1.0)
# Train the model
model.fit(X_train, y_train)
# Make predictions
y_pred = model.predict(X_test)
# Calculate error metrics
mse = mean_squared_error(y_test, y_pred)
rmse = math.sqrt(mse)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
msle = mean_squared_log_error(y_test, y_pred)
medae = median_absolute_error(y_test, y_pred)
# Generate prediction-dependent plots
fig5 = plot_residuals(y_test, y_pred)
fig6 = plot_coefficients(model, X_test.columns)
fig7 = plot_prediction_error(y_test, y_pred)
fig8 = plot_qq(y_test, y_pred)
# Start an MLflow run for logging metrics, parameters, the model, and our figures
with mlflow.start_run() as run:
# Log the model
mlflow.sklearn.log_model(
sk_model=model, input_example=X_test, artifact_path="model"
)
# Log the metrics
mlflow.log_metrics(
{"mse": mse, "rmse": rmse, "mae": mae, "r2": r2, "msle": msle, "medae": medae}
)
# Log the hyperparameter
mlflow.log_param("alpha", 1.0)
# Log plots
mlflow.log_figure(fig1, "time_series_demand.png")
mlflow.log_figure(fig2, "box_weekend.png")
mlflow.log_figure(fig3, "scatter_demand_price.png")
mlflow.log_figure(fig4, "density_weekday_weekend.png")
mlflow.log_figure(fig5, "residuals_plot.png")
mlflow.log_figure(fig6, "coefficients_plot.png")
mlflow.log_figure(fig7, "prediction_errors.png")
mlflow.log_figure(fig8, "qq_plot.png")
# Log the saved correlation matrix plot by referring to the local file system location
mlflow.log_artifact("/tmp/corr_plot.png")
在 UI 中查看图表
在执行此训练单元格后,如果我们前往 MLflow UI,可以在工件查看器窗格中看到所有已定义的图表。无论图表是使用 log_figure()
API 记录的,还是从本地文件系统获取并通过 log_artifacts()
记录的,我们都能够看到与我们的数据和训练模型相关的运行图表,捕获了运行进行时的状态。
挑战
您能想到一些与数据验证、回归建模或总体预测质量相关的其他图表吗?
如果您感兴趣,请点击下方按钮获取笔记本的副本,并按照说明进行操作。
下载笔记本并使用 Jupyter 打开后:
- 实现更多代表您在此类模型训练(或再训练)时希望看到的可视化的图表。
- 将每个图表保存到一个公共目录,而不是返回图形对象。
- 确保所有图表文件名唯一且能指示图表内容。
- 使用
mlflow.log_artifacts()
(而不是mlflow.log_artifact()
)将目录内容记录到运行中。 - 验证图表在 MLflow UI 中的渲染效果。
log_artifacts()
API 有一个可选的 artifact_path
参数,可以覆盖默认值 None
,以便在 MLflow 工件存储(和 UI)中将这些额外的图表隔离到自己的目录中。如果您记录数十个具有不同类别分组的图表,这会非常有用,无需用主根目录中的大量文件填充工件查看器中的 UI 显示窗格。
总结
可视化是构建高质量模型的关键部分。MLflow 原生集成了记录图形、图表和图像的功能,使得整合可视化变得非常简单,不仅适用于训练数据,也适用于训练事件的结果。
借助可在模型训练上下文中进行范围界定的简单、高级 API,可以消除状态不一致,确保每个图表都能准确反映训练时的数据和模型状态。