跳到主要内容

MLflow Keras 3.0 集成

简介

Keras 3.0 是一个高级神经网络 API,可在 TensorFlow、JAX 和 PyTorch 后端上运行。它提供了一个用户友好的界面,用于构建和训练深度学习模型,并且可以在不更改代码的情况下灵活切换后端。

MLflow 的 Keras 集成可为深度学习工作流提供实验跟踪、模型版本管理和部署功能。

为什么选择 MLflow + Keras?

自动日志记录

使用一行代码即可实现全面的实验跟踪:mlflow.tensorflow.autolog() 可自动记录指标、参数和模型。

实验跟踪

跨所有 Keras 实验跟踪训练指标、超参数、模型架构和工件。

模型注册表

使用 MLflow 的模型注册表和提供基础结构来版本化、分阶段和部署 Keras 模型。

多后端支持

跨 TensorFlow、JAX 和 PyTorch 后端一致地跟踪实验。

自动日志记录

通过一行代码实现全面的自动日志记录

python
import mlflow
import numpy as np
from tensorflow import keras

# Enable autologging
mlflow.tensorflow.autolog()

# Prepare sample data
X_train = np.random.rand(1000, 20)
y_train = np.random.randint(0, 2, 1000)

# Define model
model = keras.Sequential(
[
keras.layers.Dense(64, activation="relu", input_shape=(20,)),
keras.layers.Dense(32, activation="relu"),
keras.layers.Dense(1, activation="sigmoid"),
]
)

model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Training with automatic logging
with mlflow.start_run():
model.fit(X_train, y_train, epochs=10, batch_size=32, validation_split=0.2)

自动日志记录可自动捕获训练指标、模型参数、优化器配置和模型工件。

配置自动日志记录行为

python
mlflow.tensorflow.autolog(
log_models=True,
log_input_examples=True,
log_model_signatures=True,
log_every_n_steps=1,
)

使用 Keras 回调进行手动日志记录

如需更多控制,请使用 mlflow.tensorflow.MlflowCallback()

python
import mlflow
import numpy as np
from tensorflow import keras

# Prepare sample data
X_train = np.random.rand(100, 20)
y_train = np.random.randint(0, 2, 100)

# Define and compile model
model = keras.Sequential(
[
keras.layers.Dense(64, activation="relu", input_shape=(20,)),
keras.layers.Dense(1, activation="sigmoid"),
]
)

model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Create an MLflow run and add the callback
with mlflow.start_run() as run:
model.fit(
X_train,
y_train,
epochs=10,
batch_size=32,
validation_split=0.2,
callbacks=[mlflow.tensorflow.MlflowCallback(run)],
)

模型日志记录

使用 mlflow.tensorflow.log_model() 保存 Keras 模型

python
import mlflow
from tensorflow import keras

# Define model
model = keras.Sequential(
[
keras.layers.Dense(64, activation="relu", input_shape=(20,)),
keras.layers.Dense(1, activation="sigmoid"),
]
)

# Train model (code omitted for brevity)

# Log the model to MLflow
model_info = mlflow.tensorflow.log_model(model, name="model")

# Later, load the model for inference
loaded_model = mlflow.tensorflow.load_model(model_info.model_uri)
predictions = loaded_model.predict(X_test)

模型注册表集成

注册 Keras 模型以进行版本控制和部署

python
import mlflow
from tensorflow import keras
from mlflow import MlflowClient

with mlflow.start_run():
# Create a simple model for demonstration
model = keras.Sequential(
[
keras.layers.Conv2D(32, 3, activation="relu", input_shape=(28, 28, 1)),
keras.layers.MaxPooling2D(2),
keras.layers.Flatten(),
keras.layers.Dense(10, activation="softmax"),
]
)

# Log model to registry
model_info = mlflow.tensorflow.log_model(
model, name="keras_model", registered_model_name="ImageClassifier"
)

# Tag for tracking
mlflow.set_tags({"model_type": "cnn", "dataset": "mnist", "framework": "keras"})

# Set alias for production deployment
client = MlflowClient()
client.set_registered_model_alias(
name="ImageClassifier",
alias="champion",
version=model_info.registered_model_version,
)

了解更多