使用 MLflow 记录可视化
在本指南的这一部分,我们将强调使用 MLflow 进行可视化日志记录的重要性。将可视化与训练后的模型一起保留可以增强模型的可解释性、审计和溯源,从而确保机器学习生命周期的稳健性和透明度。
我们在做什么?
- 存储可视化工件: 我们正在将各种绘图作为可视化工件记录在 MLflow 中,以确保它们始终可访问并与相应的模型和运行数据对齐。
- 增强模型可解释性: 这些可视化有助于理解和解释模型行为,从而提高模型的透明度和责任性。
它如何应用于 MLflow?
- 集成可视化日志记录: MLflow 无缝集成了日志记录和访问可视化工件的工具,从而提高了处理视觉上下文和见解的简便性和效率。
- 便捷访问: 记录的图形可在 MLflow UI 的运行视图窗格中显示,从而确保快速便捷地进行分析和审查。
注意
虽然 MLflow 为可视化日志记录提供了简洁性和便利性,但至关重要的是要确保可视化工件与相应的模型数据的一致性和相关性,从而保持模型信息的完整性和全面性。
为什么一致的日志记录很重要?
- 审计和溯源: 可视化的持续和全面日志记录对于审计目的至关重要,可确保每个模型都附带相关的视觉见解,以进行彻底的分析和审查。
- 增强模型理解: 适当的视觉上下文增强了对模型行为的理解,有助于有效的模型评估和验证。
总之,MLflow 的可视化日志记录功能在确保全面、透明和高效的机器学习生命周期方面发挥着宝贵的作用,从而加强了模型的可解释性、审计和溯源。
生成合成苹果销售数据
在下一节中,我们将深入探讨使用 generate_apple_sales_data_with_promo_adjustment
函数生成用于预测苹果销售需求的合成数据。此函数模拟了与苹果销售相关的各种特征,从而提供了丰富的数据集以供探索和建模。
我们在做什么?
- 模拟真实数据: 生成具有日期、平均温度、降雨量、周末标志等特征的数据集,以模拟苹果销售的真实场景。
- 纳入各种影响: 该函数纳入了促销调整、季节性和竞争对手定价等影响,从而影响了“需求”目标变量。
它如何应用于数据生成?
- 全面数据集: 合成数据集提供了一组全面的特征和交互,非常适合探索需求预测的各个方面和维度。
- 自由和灵活性: 合成性质允许不受约束的探索和分析,免受真实世界数据的敏感性和约束。
注意
虽然合成数据为探索和学习提供了许多优势,但至关重要的是要承认其在捕获真实世界复杂性和细微差别方面的局限性。
为什么承认局限性很重要?
- 真实世界的复杂性: 合成数据可能无法捕获真实世界数据中存在的所有复杂模式和异常,从而可能导致过于简化的模型和见解。
- 可转移到真实世界场景: 确保从合成数据得出的见解和模型可转移到真实世界场景需要仔细考虑和验证。
总之,generate_apple_sales_data_with_promo_adjustment
函数提供了一个强大的工具,用于生成用于预测苹果销售需求的全面合成数据集,从而促进广泛的探索和分析,同时承认合成数据的局限性。
import math
import pathlib
from datetime import datetime, timedelta
import matplotlib.pylab as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from sklearn.linear_model import Ridge
from sklearn.metrics import (
mean_absolute_error,
mean_squared_error,
mean_squared_log_error,
median_absolute_error,
r2_score,
)
from sklearn.model_selection import train_test_split
import mlflow
def generate_apple_sales_data_with_promo_adjustment(
base_demand: int = 1000,
n_rows: int = 5000,
competitor_price_effect: float = -50.0,
):
"""
Generates a synthetic dataset for predicting apple sales demand with multiple
influencing factors.
This function creates a pandas DataFrame with features relevant to apple sales.
The features include date, average_temperature, rainfall, weekend flag, holiday flag,
promotional flag, price_per_kg, competitor's price, marketing intensity, stock availability,
and the previous day's demand. The target variable, 'demand', is generated based on a
combination of these features with some added noise.
Args:
base_demand (int, optional): Base demand for apples. Defaults to 1000.
n_rows (int, optional): Number of rows (days) of data to generate. Defaults to 5000.
competitor_price_effect (float, optional): Effect of competitor's price being lower
on our sales. Defaults to -50.
Returns:
pd.DataFrame: DataFrame with features and target variable for apple sales prediction.
Example:
>>> df = generate_apple_sales_data_with_promo_adjustment(base_demand=1200, n_rows=6000)
>>> df.head()
"""
# Set seed for reproducibility
np.random.seed(9999)
# Create date range
dates = [datetime.now() - timedelta(days=i) for i in range(n_rows)]
dates.reverse()
# Generate features
df = pd.DataFrame(
{
"date": dates,
"average_temperature": np.random.uniform(10, 35, n_rows),
"rainfall": np.random.exponential(5, n_rows),
"weekend": [(date.weekday() >= 5) * 1 for date in dates],
"holiday": np.random.choice([0, 1], n_rows, p=[0.97, 0.03]),
"price_per_kg": np.random.uniform(0.5, 3, n_rows),
"month": [date.month for date in dates],
}
)
# Introduce inflation over time (years)
df["inflation_multiplier"] = 1 + (df["date"].dt.year - df["date"].dt.year.min()) * 0.03
# Incorporate seasonality due to apple harvests
df["harvest_effect"] = np.sin(2 * np.pi * (df["month"] - 3) / 12) + np.sin(
2 * np.pi * (df["month"] - 9) / 12
)
# Modify the price_per_kg based on harvest effect
df["price_per_kg"] = df["price_per_kg"] - df["harvest_effect"] * 0.5
# Adjust promo periods to coincide with periods lagging peak harvest by 1 month
peak_months = [4, 10] # months following the peak availability
df["promo"] = np.where(
df["month"].isin(peak_months),
1,
np.random.choice([0, 1], n_rows, p=[0.85, 0.15]),
)
# Generate target variable based on features
base_price_effect = -df["price_per_kg"] * 50
seasonality_effect = df["harvest_effect"] * 50
promo_effect = df["promo"] * 200
df["demand"] = (
base_demand
+ base_price_effect
+ seasonality_effect
+ promo_effect
+ df["weekend"] * 300
+ np.random.normal(0, 50, n_rows)
) * df["inflation_multiplier"] # adding random noise
# Add previous day's demand
df["previous_days_demand"] = df["demand"].shift(1)
df["previous_days_demand"].fillna(method="bfill", inplace=True) # fill the first row
# Introduce competitor pricing
df["competitor_price_per_kg"] = np.random.uniform(0.5, 3, n_rows)
df["competitor_price_effect"] = (
df["competitor_price_per_kg"] < df["price_per_kg"]
) * competitor_price_effect
# Stock availability based on past sales price (3 days lag with logarithmic decay)
log_decay = -np.log(df["price_per_kg"].shift(3) + 1) + 2
df["stock_available"] = np.clip(log_decay, 0.7, 1)
# Marketing intensity based on stock availability
# Identify where stock is above threshold
high_stock_indices = df[df["stock_available"] > 0.95].index
# For each high stock day, increase marketing intensity for the next week
for idx in high_stock_indices:
df.loc[idx : min(idx + 7, n_rows - 1), "marketing_intensity"] = np.random.uniform(0.7, 1)
# If the marketing_intensity column already has values, this will preserve them;
# if not, it sets default values
fill_values = pd.Series(np.random.uniform(0, 0.5, n_rows), index=df.index)
df["marketing_intensity"].fillna(fill_values, inplace=True)
# Adjust demand with new factors
df["demand"] = df["demand"] + df["competitor_price_effect"] + df["marketing_intensity"]
# Drop temporary columns
df.drop(
columns=[
"inflation_multiplier",
"harvest_effect",
"month",
"competitor_price_effect",
"stock_available",
],
inplace=True,
)
return df
生成苹果销售数据
在此单元格中,我们调用 generate_apple_sales_data_with_promo_adjustment
函数来生成苹果销售数据集。
使用的参数:
base_demand
:设置为 1000,表示苹果的基本需求。n_rows
:设置为 10,000,用于确定生成的数据集中的行数或数据点数。competitor_price_effect
:设置为 -25.0,表示当竞争对手的价格较低时对我们销售的影响。
通过运行此单元格,我们将获得一个数据集 my_data
,其中包含具有上述配置的合成苹果销售数据。此数据集将用于本笔记本后续步骤中的进一步探索和分析。
您可以在生成单元格之后的单元格中查看数据。
my_data = generate_apple_sales_data_with_promo_adjustment(
base_demand=1000, n_rows=10_000, competitor_price_effect=-25.0
)
my_data
日期 | 平均温度 | 降雨量 | 周末 | 假期 | 每公斤价格 | 促销 | 需求 | 前几天的需求 | 每公斤竞争对手价格 | 营销力度 | |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1996-05-11 13:10:40.689999 | 30.584727 | 1.831006 | 1 | 0 | 1.578387 | 1 | 1301.647352 | 1326.324266 | 0.755725 | 0.323086 |
1 | 1996-05-12 13:10:40.689999 | 15.465069 | 0.761303 | 1 | 0 | 1.965125 | 0 | 1143.972638 | 1326.324266 | 0.913934 | 0.030371 |
2 | 1996-05-13 13:10:40.689998 | 10.786525 | 1.427338 | 0 | 0 | 1.497623 | 0 | 890.319248 | 1168.942267 | 2.879262 | 0.354226 |
3 | 1996-05-14 13:10:40.689997 | 23.648154 | 3.737435 | 0 | 0 | 1.952936 | 0 | 811.206168 | 889.965021 | 0.826015 | 0.953000 |
4 | 1996-05-15 13:10:40.689997 | 13.861391 | 5.598549 | 0 | 0 | 2.059993 | 0 | 822.279469 | 835.253168 | 1.130145 | 0.953000 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
9995 | 2023-09-22 13:10:40.682895 | 23.358868 | 7.061220 | 0 | 0 | 1.556829 | 1 | 1981.195884 | 2089.644454 | 0.560507 | 0.889971 |
9996 | 2023-09-23 13:10:40.682895 | 14.859048 | 0.868655 | 1 | 0 | 1.632918 | 0 | 2180.698138 | 2005.305913 | 2.460766 | 0.884467 |
9997 | 2023-09-24 13:10:40.682894 | 17.941035 | 13.739986 | 1 | 0 | 0.827723 | 1 | 2675.093671 | 2179.813671 | 1.321922 | 0.884467 |
9998 | 2023-09-25 13:10:40.682893 | 14.533862 | 1.610512 | 0 | 0 | 0.589172 | 0 | 1703.287285 | 2674.209204 | 2.604095 | 0.812706 |
9999 | 2023-09-26 13:10:40.682889 | 13.048549 | 5.287508 | 0 | 0 | 1.794122 | 1 | 1971.029266 | 1702.474579 | 1.261635 | 0.750458 |
10000 行 × 11 列
需求的时序可视化
在本节中,我们将创建一个时序图,以可视化需求数据及其滚动平均值。
为什么这很重要?
可视化时序数据对于识别模式、了解可变性和做出更明智的决策至关重要。通过同时绘制滚动平均值,我们可以消除短期波动并突出长期趋势或周期。这种视觉辅助工具对于理解数据和做出更准确和明智的预测和决策至关重要。
代码结构:
- 输入验证:代码首先确保数据是 pandas DataFrame。
- 日期转换:它将“日期”列转换为 datetime 格式,以便进行准确的绘图。
- 滚动平均值计算:它计算“需求”的滚动平均值,并指定窗口大小(
window_size
),默认为 7 天。 - 绘图:它将原始需求数据和计算出的滚动平均值绘制在同一图上进行比较。原始需求数据以低 alpha 值绘制,使其显示为“幽灵”,以确保滚动平均值突出显示。
- 标签和图例:添加足够的标签和图例以提高清晰度。
为什么要返回图形?
我们返回图形对象 (fig
) 而不是直接呈现它,以便模型训练事件的每次迭代都可以将图形作为已记录的工件提供给 MLflow。这种方法允许我们将数据可视化的状态与用于训练的数据的状态完全持久化。MLflow 可以存储此图形对象,从而可以在 MLflow UI 中轻松检索和呈现,确保可视化始终可访问并与相关的模型和数据信息配对。
def plot_time_series_demand(data, window_size=7, style="seaborn", plot_size=(16, 12)):
if not isinstance(data, pd.DataFrame):
raise TypeError("df must be a pandas DataFrame.")
df = data.copy()
df["date"] = pd.to_datetime(df["date"])
# Calculate the rolling average
df["rolling_avg"] = df["demand"].rolling(window=window_size).mean()
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
# Plot the original time series data with low alpha (transparency)
ax.plot(df["date"], df["demand"], "b-o", label="Original Demand", alpha=0.15)
# Plot the rolling average
ax.plot(
df["date"],
df["rolling_avg"],
"r",
label=f"{window_size}-Day Rolling Average",
)
# Set labels and title
ax.set_title(
f"Time Series Plot of Demand with {window_size} day Rolling Average",
fontsize=14,
)
ax.set_xlabel("Date", fontsize=12)
ax.set_ylabel("Demand", fontsize=12)
# Add legend to explain the lines
ax.legend()
plt.tight_layout()
plt.close(fig)
return fig
使用箱形图可视化周末与工作日的需求
在本节中,我们将使用箱形图来可视化周末与工作日的需求分布。此可视化有助于了解基于星期几的需求的可变性和集中趋势。
为什么这很重要?
了解周末和工作日之间的需求差异对于在库存、人员配备和其他运营方面做出明智的决策至关重要。它有助于识别需求较高的时期,从而可以更好地进行资源分配和计划。
代码结构:
- 箱形图:代码使用 Seaborn 创建一个箱形图,该图显示了周末 (1) 和工作日 (0) 的需求分布。箱形图提供了有关两个类别的需求数据的中位数、四分位数和可能的异常值的见解。
- 添加单个数据点:为了提供更多上下文,单个数据点将作为条带图叠加在箱形图上。它们会抖动以实现更好的可视化效果,并根据日期类型进行颜色编码。
- 样式设置:对绘图进行样式设置以提高清晰度,并删除不必要的图例以增强可读性。
为什么要返回图形?
与时序图一样,此函数还会返回图形对象 (fig
) 而不是直接显示它。
def plot_box_weekend(df, style="seaborn", plot_size=(10, 8)):
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
sns.boxplot(data=df, x="weekend", y="demand", ax=ax, color="lightgray")
sns.stripplot(
data=df,
x="weekend",
y="demand",
ax=ax,
hue="weekend",
palette={0: "blue", 1: "green"},
alpha=0.15,
jitter=0.3,
size=5,
)
ax.set_title("Box Plot of Demand on Weekends vs. Weekdays", fontsize=14)
ax.set_xlabel("Weekend (0: No, 1: Yes)", fontsize=12)
ax.set_ylabel("Demand", fontsize=12)
for i in ax.get_xticklabels() + ax.get_yticklabels():
i.set_fontsize(10)
ax.legend_.remove()
plt.tight_layout()
plt.close(fig)
return fig
探索需求与每公斤价格之间的关系
在此可视化中,我们将创建一个散点图来调查 demand
和 price_per_kg
之间的关系。了解这种关系对于定价策略和需求预测至关重要。
为什么这很重要?
- 洞察定价策略: 此可视化有助于揭示需求如何随每公斤价格变化,从而为设定价格以优化销售额和收入提供有价值的见解。
- 了解需求弹性: 它有助于了解关于价格的需求弹性,从而有助于做出有关促销和折扣的明智和数据驱动的决策。
代码结构:
- 散点图: 代码生成一个散点图,其中每个点的位置由
price_per_kg
和demand
决定,颜色表示当天是周末还是工作日。这种颜色编码有助于快速识别特定于周末或工作日的模式。 - 透明度和抖动: 点以透明度 (
alpha=0.15
) 绘制以处理过度绘图,从而可以可视化点的密度。 - 回归线: 对于每个子组(周末和工作日),将单独的回归线拟合并绘制在同一轴上。这些线清楚地表明了每个组的需求趋势,涉及每公斤价格。
def plot_scatter_demand_price(df, style="seaborn", plot_size=(10, 8)):
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
# Scatter plot with jitter, transparency, and color-coded based on weekend
sns.scatterplot(
data=df,
x="price_per_kg",
y="demand",
hue="weekend",
palette={0: "blue", 1: "green"},
alpha=0.15,
ax=ax,
)
# Fit a simple regression line for each subgroup
sns.regplot(
data=df[df["weekend"] == 0],
x="price_per_kg",
y="demand",
scatter=False,
color="blue",
ax=ax,
)
sns.regplot(
data=df[df["weekend"] == 1],
x="price_per_kg",
y="demand",
scatter=False,
color="green",
ax=ax,
)
ax.set_title("Scatter Plot of Demand vs Price per kg with Regression Line", fontsize=14)
ax.set_xlabel("Price per kg", fontsize=12)
ax.set_ylabel("Demand", fontsize=12)
for i in ax.get_xticklabels() + ax.get_yticklabels():
i.set_fontsize(10)
plt.tight_layout()
plt.close(fig)
return fig
可视化需求密度:工作日与周末
此可视化使我们可以分别观察工作日和周末的 demand
分布。
为什么这很重要?
- 需求分布洞察: 了解工作日与周末的需求分布可以为库存管理和人员配备需求提供信息。
- 为业务战略提供信息: 这种洞察对于做出关于促销、折扣和其他可能在特定日期更有效的策略的数据驱动决策至关重要。
代码结构:
- 密度图: 代码生成
demand
的密度图,分为工作日和周末。 - 颜色编码组: 两个组(工作日和周末)采用颜色编码(分别为蓝色和绿色),从而可以轻松区分它们。
- 透明度和填充: 密度曲线下的区域填充有浅色、透明的颜色 (
alpha=0.15
),以便轻松可视化,同时避免视觉混乱。
视觉元素是什么?
- 两条密度曲线: 绘图包含两条密度曲线,一条用于工作日,另一条用于周末。这些曲线提供了每个组的需求分布的清晰视觉表示。
- 图例: 添加了图例以帮助识别哪条曲线对应于哪个组(工作日或周末)。
def plot_density_weekday_weekend(df, style="seaborn", plot_size=(10, 8)):
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
# Plot density for weekdays
sns.kdeplot(
df[df["weekend"] == 0]["demand"],
color="blue",
label="Weekday",
ax=ax,
fill=True,
alpha=0.15,
)
# Plot density for weekends
sns.kdeplot(
df[df["weekend"] == 1]["demand"],
color="green",
label="Weekend",
ax=ax,
fill=True,
alpha=0.15,
)
ax.set_title("Density Plot of Demand by Weekday/Weekend", fontsize=14)
ax.set_xlabel("Demand", fontsize=12)
ax.legend(fontsize=12)
for i in ax.get_xticklabels() + ax.get_yticklabels():
i.set_fontsize(10)
plt.tight_layout()
plt.close(fig)
return fig
模型系数的可视化
在本节中,我们将使用条形图来可视化训练模型中特征的系数。
为什么这很重要?
了解系数的大小和方向对于解释模型至关重要。它有助于识别影响预测的最重要特征。这种洞察对于特征选择、工程设计,并最终提高模型性能至关重要。
代码结构:
- 上下文设置:代码首先将绘图样式设置为“seaborn”,以增强美观性。
- 图形初始化:它创建图形和轴以进行绘图。
- 条形图:它使用水平条形图 (
barh
) 来可视化每个特征的系数。y 轴表示特征名称,x 轴表示系数值。此可视化使您可以轻松地比较系数,从而深入了解它们对于目标变量的相对重要性和影响。 - 标题和标签:它为 x(“系数值”)和 y(“特征”)轴设置适当的标题(“系数图”)和标签,以确保清晰度和可理解性。
通过可视化系数,我们可以更深入地了解模型,从而更轻松地解释模型的预测并做出关于特征重要性和影响的更明智的决策。
def plot_coefficients(model, feature_names, style="seaborn", plot_size=(10, 8)):
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
ax.barh(feature_names, model.coef_)
ax.set_title("Coefficient Plot", fontsize=14)
ax.set_xlabel("Coefficient Value", fontsize=12)
ax.set_ylabel("Features", fontsize=12)
plt.tight_layout()
plt.close(fig)
return fig
残差的可视化
在本节中,我们将创建一个图来可视化模型的残差,残差是观察值和预测值之间的差异。
为什么这很重要?
残差图是回归分析中的一种基本诊断工具,用于调查预测变量和响应变量之间关系的不可预测性。它有助于识别非线性、异方差和异常值。此图有助于验证误差呈正态分布且具有恒定方差的假设,这对于回归模型预测的可靠性至关重要。
代码结构:
- 残差计算:代码首先通过计算实际值 (
y_test
) 和预测值 (y_pred
) 之间的差来计算残差。 - 上下文设置:代码将绘图样式设置为“seaborn”,以获得视觉上吸引人的绘图。
- 图形初始化:它创建图形和轴以进行绘图。
- 残差绘图:它利用 Seaborn 中的
residplot
创建残差图,并使用 lowess(局部加权散点图平滑)线来突出显示残差中的趋势。 - 零线:它在零处添加一条虚线,以用作观察残差的参考。线上的残差表示欠预测,而线下的残差表示过度预测。
- 标题和标签:它为 x(“预测值”)和 y(“残差”)轴设置适当的标题(“残差图”)和标签,以确保清晰度和可理解性。
通过检查残差图,我们可以就模型的充分性以及可能需要进一步改进或增加复杂性做出更明智的决策。
def plot_residuals(y_test, y_pred, style="seaborn", plot_size=(10, 8)):
residuals = y_test - y_pred
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
sns.residplot(
x=y_pred,
y=residuals,
lowess=True,
ax=ax,
line_kws={"color": "red", "lw": 1},
)
ax.axhline(y=0, color="black", linestyle="--")
ax.set_title("Residual Plot", fontsize=14)
ax.set_xlabel("Predicted values", fontsize=12)
ax.set_ylabel("Residuals", fontsize=12)
for label in ax.get_xticklabels() + ax.get_yticklabels():
label.set_fontsize(10)
plt.tight_layout()
plt.close(fig)
return fig
预测错误的可视化
在本节中,我们将创建一个图来可视化预测错误,展示来自我们模型的实际值和预测值之间的差异。
为什么这很重要?
了解预测错误对于评估模型的性能至关重要。预测错误图提供了有关错误分布的见解,并有助于识别趋势、偏差或异常值。此可视化是模型评估的关键组成部分,有助于识别模型可能需要改进的领域,并确保模型可以很好地推广到新数据。
代码结构:
- 上下文设置:代码将绘图样式设置为“seaborn”,以获得干净且有吸引力的绘图。
- 图形初始化:它初始化用于绘图的图形和轴。
- 散点图:代码根据误差(实际值 - 预测值)绘制预测值。绘图上的每个点表示一个特定的观察值,其在 y 轴上的位置表示误差的大小和方向(高于零表示欠预测,低于零表示过度预测)。
- 零线:一条红色虚线在 y=0 处绘制为参考,有助于轻松识别误差。此线上方的点是欠预测,下方的点是过度预测。
- 标题和标签:它为 x(“预测值”)和 y(“误差”)轴添加标题(“预测误差图”)和标签,以提高清晰度和理解性。
通过分析预测误差图,从业者可以深入了解模型的性能,从而有助于进一步改进和增强模型,以获得更好和更可靠的预测。
def plot_prediction_error(y_test, y_pred, style="seaborn", plot_size=(10, 8)):
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
ax.scatter(y_pred, y_test - y_pred)
ax.axhline(y=0, color="red", linestyle="--")
ax.set_title("Prediction Error Plot", fontsize=14)
ax.set_xlabel("Predicted Values", fontsize=12)
ax.set_ylabel("Errors", fontsize=12)
plt.tight_layout()
plt.close(fig)
return fig
分位数-分位数图 (QQ 图) 的可视化
在本节中,我们将生成一个 QQ 图,以可视化模型预测的残差分布。
为什么这很重要?
QQ 图对于评估模型中的残差是否遵循正态分布至关重要,这是线性回归模型中的基本假设。如果 QQ 图中的点没有紧密地遵循该线并显示出一种模式,则表明残差可能不是正态分布,这可能意味着模型存在异方差或非线性等问题。
代码结构:
- 残差计算:代码首先通过从实际测试值中减去预测值来计算残差。
- 上下文设置:绘图样式设置为“seaborn”,以增强美感。
- 图形初始化:初始化图形和轴以进行绘图。
- QQ 图生成:
stats.probplot
函数用于生成 QQ 图。它绘制残差的分位数与正态分布的分位数。 - 标题添加:将标题(“QQ 图”)添加到绘图中,以提高清晰度。
通过仔细分析 QQ 图,我们可以确保我们模型的残差满足正态性假设。如果不是,则可能需要探索其他模型类型或转换,以提高模型的性能和可靠性。
def plot_qq(y_test, y_pred, style="seaborn", plot_size=(10, 8)):
residuals = y_test - y_pred
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
stats.probplot(residuals, dist="norm", plot=ax)
ax.set_title("QQ Plot", fontsize=14)
plt.tight_layout()
plt.close(fig)
return fig
特征相关矩阵
在本节中,我们将生成一个特征相关矩阵,以可视化数据集中不同特征之间的关系。
注意: 与本笔记本中的其他绘图不同,我们将绘图的本地副本保存到磁盘,以显示任意文件的备用日志记录机制,即 log_artifact()
API。在下面的主模型训练和日志记录部分中,您将看到此绘图如何添加到 MLflow 运行中。
为什么这很重要?
了解不同特征之间的相关性至关重要,可以
- 识别多重共线性,这可能会影响模型性能和可解释性。
- 深入了解变量之间的关系,从而可以为特征工程和选择提供信息。
- 发现不同特征之间潜在的因果关系或交互作用,从而可以为领域理解和进一步分析提供信息。
代码结构:
- 相关性计算:代码首先计算所提供 DataFrame 的相关矩阵。
- 屏蔽:为相关矩阵的上三角形创建一个屏蔽,因为该矩阵是对称的,我们不需要可视化重复信息。
- 热图生成:生成热图以可视化相关系数。颜色梯度和注释提供了有关变量之间关系的清晰见解。
- 标题添加:添加标题以明确标识绘图。
通过分析相关矩阵,我们可以做出更明智的关于特征选择的决策,并更好地了解我们数据集中的关系。
def plot_correlation_matrix_and_save(
df, style="seaborn", plot_size=(10, 8), path="/tmp/corr_plot.png"
):
with plt.style.context(style=style):
fig, ax = plt.subplots(figsize=plot_size)
# Calculate the correlation matrix
corr = df.corr()
# Generate a mask for the upper triangle
mask = np.triu(np.ones_like(corr, dtype=bool))
# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(
corr,
mask=mask,
cmap="coolwarm",
vmax=0.3,
center=0,
square=True,
linewidths=0.5,
annot=True,
fmt=".2f",
)
ax.set_title("Feature Correlation Matrix", fontsize=14)
plt.tight_layout()
plt.close(fig)
# convert to filesystem path spec for os compatibility
save_path = pathlib.Path(path)
fig.savefig(save_path)
模型训练和可视化的主要执行的详细概述
本节深入探讨了为模型训练、预测、误差计算和可视化执行的全面工作流程。每个步骤的重要性以及特定选择的原因都将得到彻底讨论。
结构化执行的优势
以结构化方式执行模型训练和评估的所有关键步骤至关重要。它提供了一个框架,可确保考虑建模过程的每个方面,从而提供更可靠和更强大的模型。这种简化的执行有助于避免被忽视的错误或偏差,并保证模型在所有必要的方面都得到评估。
将可视化记录到 MLflow 的重要性
将可视化记录到 MLflow 具有以下几个主要优势
-
永久性:与笔记本的临时状态(其中单元格可能会以无序的方式运行,从而导致潜在的误解)不同,将绘图记录到 MLflow 可确保可视化与特定运行永久存储。这种永久性确保保留了模型训练和评估的视觉上下文,消除了混乱并确保了解释的清晰性。
-
溯源:通过记录可视化,可以捕获模型训练时数据的确切状态和关系。此做法对于很久以前训练的模型至关重要。它提供了一个可靠的参考点,可以了解模型在训练时的行为和数据特征,从而确保见解和解释在一段时间内保持有效和可靠。
-
可访问性:将可视化存储在 MLflow 中使所有团队成员或相关人员都可以轻松访问它们。这种可视化的集中存储增强了协作,使不同的团队成员可以轻松地查看、分析和解释可视化,从而导致更明智和集体的决策。
代码的详细结构:
-
设置 MLflow:
- 定义 MLflow 的跟踪 URI。
- 设置名为“可视化演示”的实验,所有运行和日志都将存储在该实验下。
-
数据准备:
X
和y
分别定义为特征和目标变量。- 数据集分为训练集和测试集,以确保在未见过的数据上评估模型的性能。
-
初始绘图生成:
- 生成初始绘图,包括时序、箱形图、散点图和密度图。
- 这些绘图初步了解了数据及其特征。
-
模型定义和训练:
- 使用
alpha
1.0 定义 Ridge 回归模型。 - 在训练数据上训练模型,学习数据中的关系和模式。
- 使用
-
预测和误差计算:
- 训练后的模型用于对测试数据进行预测。
- 计算各种误差指标,包括 MSE、RMSE、MAE、R2、MSLE 和 MedAE,以评估模型的性能。
-
其他绘图生成:
- 生成其他绘图,包括残差图、系数图、预测误差图和 QQ 图。
- 这些绘图进一步深入了解了模型的性能、残差行为和误差分布。
-
记录到 MLflow:
- 训练后的模型、计算的指标、定义的参数 (
alpha
) 和所有生成的绘图都记录到 MLflow 中。 - 此日志记录确保与模型相关的所有信息和可视化都存储在集中的、可访问的位置。
- 训练后的模型、计算的指标、定义的参数 (
结论:
通过执行此全面且结构化的代码,我们可以确保涵盖模型训练、评估和解释的每个方面。将所有相关信息和可视化记录到 MLflow 的做法进一步提高了模型及其性能的可靠性、可访问性和可解释性,从而有助于进行更明智和可靠的模型部署和利用。
mlflow.set_tracking_uri("http://127.0.0.1:8080")
mlflow.set_experiment("Visualizations Demo")
X = my_data.drop(columns=["demand", "date"])
y = my_data["demand"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
fig1 = plot_time_series_demand(my_data, window_size=28)
fig2 = plot_box_weekend(my_data)
fig3 = plot_scatter_demand_price(my_data)
fig4 = plot_density_weekday_weekend(my_data)
# Execute the correlation plot, saving the plot to a local temporary directory
plot_correlation_matrix_and_save(my_data)
# Define our Ridge model
model = Ridge(alpha=1.0)
# Train the model
model.fit(X_train, y_train)
# Make predictions
y_pred = model.predict(X_test)
# Calculate error metrics
mse = mean_squared_error(y_test, y_pred)
rmse = math.sqrt(mse)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
msle = mean_squared_log_error(y_test, y_pred)
medae = median_absolute_error(y_test, y_pred)
# Generate prediction-dependent plots
fig5 = plot_residuals(y_test, y_pred)
fig6 = plot_coefficients(model, X_test.columns)
fig7 = plot_prediction_error(y_test, y_pred)
fig8 = plot_qq(y_test, y_pred)
# Start an MLflow run for logging metrics, parameters, the model, and our figures
with mlflow.start_run() as run:
# Log the model
mlflow.sklearn.log_model(sk_model=model, input_example=X_test, name="model")
# Log the metrics
mlflow.log_metrics(
{"mse": mse, "rmse": rmse, "mae": mae, "r2": r2, "msle": msle, "medae": medae}
)
# Log the hyperparameter
mlflow.log_param("alpha", 1.0)
# Log plots
mlflow.log_figure(fig1, "time_series_demand.png")
mlflow.log_figure(fig2, "box_weekend.png")
mlflow.log_figure(fig3, "scatter_demand_price.png")
mlflow.log_figure(fig4, "density_weekday_weekend.png")
mlflow.log_figure(fig5, "residuals_plot.png")
mlflow.log_figure(fig6, "coefficients_plot.png")
mlflow.log_figure(fig7, "prediction_errors.png")
mlflow.log_figure(fig8, "qq_plot.png")
# Log the saved correlation matrix plot by referring to the local file system location
mlflow.log_artifact("/tmp/corr_plot.png")
2023/09/26 13:10:41 INFO mlflow.tracking.fluent: Experiment with name 'Visualizations Demo' does not exist. Creating a new experiment. /Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/mlflow/models/signature.py:333: UserWarning: Hint: Inferred schema contains integer column(s). Integer columns in Python cannot represent missing values. If your input data contains missing values at inference time, it will be encoded as floats and will cause a schema enforcement error. The best way to avoid this problem is to infer the model schema based on a realistic data sample (training dataset) that includes missing values. Alternatively, you can declare integer columns as doubles (float64) whenever these columns may have missing values. See `Handling Integers With Missing Values <https://www.mlflow.org/docs/latest/models.html#handling-integers-with-missing-values>`_ for more details. input_schema = _infer_schema(input_ex) /Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils. warnings.warn("Setuptools is replacing distutils.")