MLflow Keras 3.0 集成
简介
Keras 3.0 是一个高级神经网络 API,可在 TensorFlow、JAX 和 PyTorch 后端上运行。它提供了一个用户友好的界面,用于构建和训练深度学习模型,并且可以在不更改代码的情况下灵活切换后端。
MLflow 的 Keras 集成可为深度学习工作流提供实验跟踪、模型版本管理和部署功能。
为什么选择 MLflow + Keras?
自动日志记录
使用一行代码即可实现全面的实验跟踪:mlflow.tensorflow.autolog() 可自动记录指标、参数和模型。
实验跟踪
跨所有 Keras 实验跟踪训练指标、超参数、模型架构和工件。
模型注册表
使用 MLflow 的模型注册表和提供基础结构来版本化、分阶段和部署 Keras 模型。
多后端支持
跨 TensorFlow、JAX 和 PyTorch 后端一致地跟踪实验。
自动日志记录
通过一行代码实现全面的自动日志记录
python
import mlflow
import numpy as np
from tensorflow import keras
# Enable autologging
mlflow.tensorflow.autolog()
# Prepare sample data
X_train = np.random.rand(1000, 20)
y_train = np.random.randint(0, 2, 1000)
# Define model
model = keras.Sequential(
[
keras.layers.Dense(64, activation="relu", input_shape=(20,)),
keras.layers.Dense(32, activation="relu"),
keras.layers.Dense(1, activation="sigmoid"),
]
)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# Training with automatic logging
with mlflow.start_run():
model.fit(X_train, y_train, epochs=10, batch_size=32, validation_split=0.2)
自动日志记录可自动捕获训练指标、模型参数、优化器配置和模型工件。
配置自动日志记录行为
python
mlflow.tensorflow.autolog(
log_models=True,
log_input_examples=True,
log_model_signatures=True,
log_every_n_steps=1,
)
使用 Keras 回调进行手动日志记录
如需更多控制,请使用 mlflow.tensorflow.MlflowCallback()
python
import mlflow
import numpy as np
from tensorflow import keras
# Prepare sample data
X_train = np.random.rand(100, 20)
y_train = np.random.randint(0, 2, 100)
# Define and compile model
model = keras.Sequential(
[
keras.layers.Dense(64, activation="relu", input_shape=(20,)),
keras.layers.Dense(1, activation="sigmoid"),
]
)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# Create an MLflow run and add the callback
with mlflow.start_run() as run:
model.fit(
X_train,
y_train,
epochs=10,
batch_size=32,
validation_split=0.2,
callbacks=[mlflow.tensorflow.MlflowCallback(run)],
)
模型日志记录
使用 mlflow.tensorflow.log_model() 保存 Keras 模型
python
import mlflow
from tensorflow import keras
# Define model
model = keras.Sequential(
[
keras.layers.Dense(64, activation="relu", input_shape=(20,)),
keras.layers.Dense(1, activation="sigmoid"),
]
)
# Train model (code omitted for brevity)
# Log the model to MLflow
model_info = mlflow.tensorflow.log_model(model, name="model")
# Later, load the model for inference
loaded_model = mlflow.tensorflow.load_model(model_info.model_uri)
predictions = loaded_model.predict(X_test)
模型注册表集成
注册 Keras 模型以进行版本控制和部署
python
import mlflow
from tensorflow import keras
from mlflow import MlflowClient
with mlflow.start_run():
# Create a simple model for demonstration
model = keras.Sequential(
[
keras.layers.Conv2D(32, 3, activation="relu", input_shape=(28, 28, 1)),
keras.layers.MaxPooling2D(2),
keras.layers.Flatten(),
keras.layers.Dense(10, activation="softmax"),
]
)
# Log model to registry
model_info = mlflow.tensorflow.log_model(
model, name="keras_model", registered_model_name="ImageClassifier"
)
# Tag for tracking
mlflow.set_tags({"model_type": "cnn", "dataset": "mnist", "framework": "keras"})
# Set alias for production deployment
client = MlflowClient()
client.set_registered_model_alias(
name="ImageClassifier",
alias="champion",
version=model_info.registered_model_version,
)