开始使用 Keras 3.0 + MLflow
本教程是一个端到端的教程,关于如何使用 Keras 3.0 训练 MINIST 分类器并使用 MLflow 记录结果。它将演示如何使用 mlflow.keras.MlflowCallback
以及如何通过继承它来实现自定义日志记录逻辑。
Keras 是一个高级 API,设计宗旨是简单、灵活和强大——让从初学者到高级用户的所有人都能快速构建、训练和评估模型。Keras 3.0,也称为 Keras Core,是对 Keras 代码库的完全重写,它基于模块化后端架构。这使得 Keras 工作流程可以在任意框架上运行——从 TensorFlow、JAX 和 PyTorch 开始。
安装包
pip install -q keras mlflow jax jaxlib torch tensorflow
导入包 / 配置后端
Keras 3.0 本质上是多后端的,因此您需要在导入包之前设置后端环境变量。
import os
# You can use 'tensorflow', 'torch' or 'jax' as backend. Make sure to set the environment variable before importing.
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import numpy as np
import mlflow
Using TensorFlow backend
加载数据集
我们将使用 MNIST 数据集。这是一个手写数字数据集,将用于图像分类任务。共有 10 个类别,对应 10 个数字。
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=3)
x_test = np.expand_dims(x_test, axis=3)
x_train[0].shape
(28, 28, 1)
# Visualize Dataset
import matplotlib.pyplot as plt
grid = 3
fig, axes = plt.subplots(grid, grid, figsize=(6, 6))
for i in range(grid):
for j in range(grid):
axes[i][j].imshow(x_train[i * grid + j])
axes[i][j].set_title(f"label={y_train[i * grid + j]}")
plt.tight_layout()
构建模型
我们将使用 Keras 3.0 的 Sequential API 构建一个简单的 CNN。
NUM_CLASSES = 10
INPUT_SHAPE = (28, 28, 1)
def initialize_model():
return keras.Sequential(
[
keras.Input(shape=INPUT_SHAPE),
keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(NUM_CLASSES, activation="softmax"),
]
)
model = initialize_model()
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ conv2d (Conv2D) │ (None, 26, 26, 32) │ 320 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_1 (Conv2D) │ (None, 24, 24, 32) │ 9,248 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_2 (Conv2D) │ (None, 22, 22, 32) │ 9,248 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ global_average_pooling2d │ (None, 32) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense (Dense) │ (None, 10) │ 330 │ └─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 19,146 (74.79 KB)
Trainable params: 19,146 (74.79 KB)
Non-trainable params: 0 (0.00 B)
训练模型 (默认回调函数)
我们将在数据集上拟合模型,使用 MLflow 的 mlflow.keras.MlflowCallback
在训练期间记录指标。
BATCH_SIZE = 64 # adjust this based on the memory of your machine
EPOCHS = 3
按 Epoch 记录日志
一个 Epoch 定义为一次通过整个训练数据集的过程。
model = initialize_model()
model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(),
optimizer=keras.optimizers.Adam(),
metrics=["accuracy"],
)
run = mlflow.start_run()
model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
validation_split=0.1,
callbacks=[mlflow.keras.MlflowCallback(run)],
)
mlflow.end_run()
Epoch 1/3 [1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 34ms/step - accuracy: 0.5922 - loss: 1.2862 - val_accuracy: 0.9427 - val_loss: 0.2075 Epoch 2/3 [1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 33ms/step - accuracy: 0.9330 - loss: 0.2286 - val_accuracy: 0.9348 - val_loss: 0.2020 Epoch 3/3 [1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 33ms/step - accuracy: 0.9499 - loss: 0.1671 - val_accuracy: 0.9558 - val_loss: 0.1491
记录结果
运行的回调函数会将参数、指标和工件记录到 MLflow 仪表盘。
按 Batch 记录日志
在每个 Epoch 内,训练数据集根据定义的 BATCH_SIZE
被分解成多个 Batch。如果我们将回调函数设置为不按 Epoch 记录日志(log_every_epoch=False
),并每 5 个 Batch 记录一次日志(log_every_n_steps=5
),我们可以调整日志记录方式以基于 Batch 进行。
model = initialize_model()
model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(),
optimizer=keras.optimizers.Adam(),
metrics=["accuracy"],
)
with mlflow.start_run() as run:
model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
validation_split=0.1,
callbacks=[mlflow.keras.MlflowCallback(run, log_every_epoch=False, log_every_n_steps=5)],
)
Epoch 1/3 [1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 34ms/step - accuracy: 0.6151 - loss: 1.2100 - val_accuracy: 0.9373 - val_loss: 0.2144 Epoch 2/3 [1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 34ms/step - accuracy: 0.9274 - loss: 0.2459 - val_accuracy: 0.9608 - val_loss: 0.1338 Epoch 3/3 [1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 34ms/step - accuracy: 0.9477 - loss: 0.1738 - val_accuracy: 0.9577 - val_loss: 0.1454
记录结果
如果按 Epoch 记录日志,由于只有 3 个 Epoch,我们将只有三个数据点
通过按 Batch 记录日志,我们可以获得更多数据点,但它们可能更嘈杂
class MlflowCallbackLogPerBatch(mlflow.keras.MlflowCallback):
def on_batch_end(self, batch, logs=None):
if self.log_every_n_steps is None or logs is None:
return
if (batch + 1) % self.log_every_n_steps == 0:
self.metrics_logger.record_metrics(logs, self._log_step)
self._log_step += self.log_every_n_steps
model = initialize_model()
model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(),
optimizer=keras.optimizers.Adam(),
metrics=["accuracy"],
)
with mlflow.start_run() as run:
model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
validation_split=0.1,
callbacks=[MlflowCallbackLogPerBatch(run, log_every_epoch=False, log_every_n_steps=5)],
)
Epoch 1/3 [1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 34ms/step - accuracy: 0.5645 - loss: 1.4105 - val_accuracy: 0.9187 - val_loss: 0.2826 Epoch 2/3 [1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 34ms/step - accuracy: 0.9257 - loss: 0.2615 - val_accuracy: 0.9602 - val_loss: 0.1368 Epoch 3/3 [1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 34ms/step - accuracy: 0.9456 - loss: 0.1800 - val_accuracy: 0.9678 - val_loss: 0.1037
评估
与训练类似,您可以使用回调函数记录评估结果。
with mlflow.start_run() as run:
model.evaluate(x_test, y_test, callbacks=[mlflow.keras_core.MlflowCallback(run)])
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.9541 - loss: 0.1487