MLflow Spark MLlib 集成
Apache Spark MLlib 提供分布式机器学习算法,用于在集群上处理大规模数据集。MLflow 与 Spark MLlib 集成,用于跟踪分布式机器学习管道、管理模型,并支持从集群训练到独立推理的灵活部署。
为什么选择 MLflow + Spark MLlib?
管道跟踪
自动记录带有所有阶段、转换器和估计器的 Spark ML 管道。跟踪每个管道组件的参数并维护完整的血缘关系。
格式灵活性
以原生的 Spark 格式保存模型以用于分布式批处理,或以 PyFunc 格式保存模型以便在 Spark 集群外部进行推理,并自动进行 DataFrame 转换。
数据源自动记录
自动跟踪数据源的路径、格式和版本。为分布式机器学习工作流维护完整的数据血缘关系。
跨平台部署
使用 PyFunc 包装器部署 Spark 模型以用于 REST API 和边缘计算,或转换为 ONNX 以实现平台无关的推理。
基本模型记录
使用 mlflow.spark.log_model() 记录 Spark MLlib 模型
python
import mlflow
import mlflow.spark
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder.appName("MLflowSparkExample").getOrCreate()
# Prepare training data
training = spark.createDataFrame(
[
(0, "a b c d e spark", 1.0),
(1, "b d", 0.0),
(2, "spark f g h", 1.0),
(3, "hadoop mapreduce", 0.0),
],
["id", "text", "label"],
)
# Create ML Pipeline
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
# Train and log the model
with mlflow.start_run():
model = pipeline.fit(training)
# Log the entire pipeline
model_info = mlflow.spark.log_model(
spark_model=model, artifact_path="spark-pipeline"
)
# Log parameters manually
mlflow.log_params(
{
"max_iter": lr.getMaxIter(),
"reg_param": lr.getRegParam(),
"num_features": hashingTF.getNumFeatures(),
}
)
print(f"Model logged with URI: {model_info.model_uri}")
自动以 Spark 原生和 PyFunc 格式记录完整的管道、所有阶段、参数和模型。
模型格式和加载
- 原生 Spark 格式
- PyFunc 格式
保留完整的 Spark ML 功能以进行分布式处理
python
# Load as native Spark model (requires Spark session)
spark_model = mlflow.spark.load_model(model_info.model_uri)
# Use for distributed batch scoring
test_data = spark.createDataFrame(
[(4, "spark i j k"), (5, "l m n"), (6, "spark hadoop spark"), (7, "apache hadoop")],
["id", "text"],
)
predictions = spark_model.transform(test_data)
predictions.show()
支持在 Spark 集群外部进行推理
python
import pandas as pd
# Load as PyFunc model
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
# Use with pandas DataFrame
test_data = pd.DataFrame(
{"text": ["spark machine learning", "hadoop distributed computing"]}
)
predictions = pyfunc_model.predict(test_data)
print(predictions)
PyFunc 会自动将 pandas DataFrames 转换为 Spark 格式,并为推理创建本地 Spark 会话。请注意,Apache Spark 库仍需要作为依赖项。
数据源自动记录
在模型训练期间自动跟踪数据源
python
import mlflow.spark
mlflow.spark.autolog()
with mlflow.start_run():
raw_data = spark.read.parquet("s3://my-bucket/training-data/")
model = pipeline.fit(raw_data)
mlflow.spark.log_model(model, artifact_path="model")
需要 Spark 3.0+、MLflow-Spark JAR 配置,并且不支持在 Databricks 共享/无服务器集群上使用。记录所有数据源读取的路径、格式和版本。
模型签名
为 Spark ML 模型自动推断签名
python
from mlflow.models import infer_signature
from pyspark.ml.functions import array_to_vector
vector_data = spark.createDataFrame(
[([3.0, 4.0], 0.0), ([5.0, 6.0], 1.0)], ["features_array", "label"]
).select(array_to_vector("features_array").alias("features"), "label")
lr = LogisticRegression(featuresCol="features", labelCol="label")
model = lr.fit(vector_data)
predictions = model.transform(vector_data)
# Infer signature from pandas DataFrames
signature = infer_signature(
vector_data.limit(2).toPandas(),
predictions.select("prediction").limit(2).toPandas(),
)
with mlflow.start_run():
mlflow.spark.log_model(
spark_model=model,
artifact_path="vector_model",
signature=signature,
)
ONNX 转换
将 Spark 模型转换为 ONNX(实验性)
python
import onnxmltools
with mlflow.start_run():
model = pipeline.fit(training_data)
mlflow.spark.log_model(spark_model=model, artifact_path="spark_model")
onnx_model = onnxmltools.convert_sparkml(model, name="SparkMLPipeline")
onnxmltools.utils.save_model(onnx_model, "model.onnx")
mlflow.log_artifact("model.onnx")
模型注册表
注册和提升 Spark 模型
python
from mlflow import MlflowClient
client = MlflowClient()
with mlflow.start_run():
model = pipeline.fit(train_data)
mlflow.spark.log_model(
spark_model=model,
artifact_path="production_candidate",
registered_model_name="CustomerSegmentationModel",
)
mlflow.set_tags(
{
"validation_passed": "true",
"deployment_target": "batch_scoring",
}
)
model_version = client.get_latest_versions(
"CustomerSegmentationModel", stages=["None"]
)[0]
client.transition_model_version_stage(
name="CustomerSegmentationModel", version=model_version.version, stage="Staging"
)