跨数据分区训练模型¶
使用多模型训练 (MMT) 可在数据分区间高效训练多个机器学习模型。该工具可以自动处理分布式编排、模型存储及成果持久化。
MMT 根据指定列对 Snowpark DataFrame 进行分区,并在每个分区上并行训练独立模型。您只需专注于模型训练逻辑,MMT 将自动处理基础设施复杂性并实现自动扩展。
您可以利用 MMT 在不同数据分段间高效训练多个模型。该工具适用于以下场景:训练区域性销售预测模型、构建需为不同客户群体定制模型的个性化推荐系统,或创建特定细分市场的预测模型。MMT 自动处理分布式模型训练,免除管理分布式计算基础设施的复杂性。
您可以使用 MMT 基于开源机器学习模型和框架(如 XGBoost、scikit-learn、PyTorch 和 TensorFlow)进行模型训练。MMT 会自动序列化模型工件,确保您在推理时能直接使用。
您还可实现 ModelSerde 接口来训练自定义模型或使用不受支持的 ML 框架。这使您可以将 MMT 与您使用的任何机器学习框架或自定义模型架构无缝集成。
重要
使用 MMT 之前,请确保满足以下条件:
容器运行时环境:MMT 需要 Snowflake ML 容器运行时环境。
暂存区访问权限:MMT 会自动将模型工件存储在 Snowflake 暂存区中。请确保您拥有访问指定命名暂存区的相应权限。
ML 框架支持:内置集成支持 XGBoost、scikit-learn、PyTorch 和 TensorFlow。自定义模型需实现 ModelSerde 接口。
以下部分将通过示例工作流程指导您使用 MMT。
使用 MMT 训练模型¶
本部分通过五个关键步骤演示完整的 MMT 工作流程:
导入数据 – 使用 Snowpark 加载训练数据
定义训练函数 – 定义训练函数
跨分区训练模型 – 使用 MMT 并行训练每个分区的模型
访问经过训练的模型 – 检索和使用每个分区的训练模型
模型持久化存储和检索 – 将模型保存到暂存区,稍后再恢复
该工作流程可自动处理分布式训练、模型序列化及跨数据分区的工件存储。
导入数据¶
使用 Snowpark 会话开始导入数据。多模型训练函数会根据您指定的列将导入的数据拆分为不同分区。
使用 MMT 之前,请创建一个 Snowpark 会话。有关更多信息,请参阅 为 Snowpark Python 创建会话。
以下代码使用 Snowpark 会话导入训练数据。
# Example: sales_data with columns: region, feature1, feature2, feature3, target
sales_data = session.table("SALES_TRAINING_DATA")
定义训练函数¶
获取数据后,您需要定义 MMT 用于跨分区训练模型的训练函数。训练函数接收数据连接器和上下文对象,后者指向当前训练的数据分区。除提供使用 TensorFlow 和 PyTorch 的示例外,本部分还包含用于训练 XGBoost 模型的训练函数示例。
训练函数必须严格遵循以下签名格式:(data_connector, context)
。对于每个数据分区,MMT 将使用以下实参调用 train_xgboost_model
:
data_connector
:选择使用 时默认使用的角色和仓库。提供 MMT 分区数据访问权限的数据连接器。train_xgboost_model
函数将数据框转换为 Pandas。context
:选择使用 时默认使用的角色和仓库。向train_xgboost_model
函数提供partition_id
的对象。此 ID 即为分区列的名称。
您无需自行调用此函数,MMT 将自动处理跨所有分区的执行。
使用以下代码定义训练函数。更改代码以反映数据特征后,即可将其传递给 MMT 函数。
使用 XGBoost 跨数据分区训练模型。XGBoost 能够为结构化数据提供卓越性能,并自动处理缺失值。
def train_xgboost_model(data_connector, context):
df = data_connector.to_pandas()
print(f"Training model for partition: {context.partition_id}")
# Prepare features and target
X = df[['feature1', 'feature2', 'feature3']]
y = df['target']
# Train the model
from xgboost import XGBRegressor
model = XGBRegressor(
n_estimators=100,
max_depth=6,
learning_rate=0.1,
random_state=42
)
model.fit(X, y)
return model
trainer = ManyModelTraining(train_xgboost_model, "model_stage")
使用 PyTorch 跨数据分区训练深度学习模型。PyTorch 提供灵活的神经网络架构和动态计算图。
def train_pytorch_model(data_connector, context):
import torch
import torch.nn as nn
df = data_connector.to_pandas()
# ... prepare data for PyTorch ...
model = nn.Sequential(nn.Linear(10, 1))
# ... training logic ...
return model # Automatically saved as model.pth
from snowflake.ml.modeling.distributors.many_model import TorchSerde
trainer = ManyModelTraining(train_pytorch_model, "models_stage", serde=TorchSerde())
使用 TensorFlow 跨数据分区训练深度学习模型。TensorFlow 为研究及生产部署提供全面工具支持。
def train_tf_model(data_connector, context):
import tensorflow as tf
df = data_connector.to_pandas()
# ... prepare data for TensorFlow ...
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
# ... training logic ...
return model # Automatically saved as model.h5
from snowflake.ml.modeling.distributors.many_model import TensorFlowSerde
trainer = ManyModelTraining(train_tf_model, "models_stage", serde=TensorFlowSerde())
通过实现 ModelSerde 接口使用自定义模型或不支持的 ML 框架。此示例展示带有自定义元数据处理的 scikit-learn。
from snowflake.ml.modeling.distributors.many_model import ModelSerde
import json
class ScikitLearnSerde(ModelSerde):
'''Custom serializer for scikit-learn models with metadata'''
@property
def filename(self) -> str:
return "sklearn_model.joblib"
def write(self, model, file_path: str) -> None:
import joblib
# Save model with metadata
model_data = {
'model': model,
'feature_names': getattr(model, 'feature_names_in_', None),
'model_type': type(model).__name__
}
joblib.dump(model_data, file_path)
def read(self, file_path: str):
import joblib
return joblib.load(file_path)
def train_sklearn_model(data_connector, context):
from sklearn.ensemble import RandomForestRegressor
df = data_connector.to_pandas()
X, y = df[['feature1', 'feature2']], df['target']
model = RandomForestRegressor()
model.fit(X, y)
return model # Automatically saved with metadata
trainer = ManyModelTraining(train_sklearn_model, "models_stage", serde=ScikitLearnSerde())
跨分区训练模型¶
定义训练函数后,即可使用 MMT 跨分区训练模型。指定用于分区的列以及保存模型的暂存区。
以下代码通过 region
列对数据进行分区,并使用 train_xgboost_model
函数并行训练每个区域的独立模型。
例如,如果 region
列包含以下可能值:
North
South
East
West
Central
ManyModelTraining
函数将为上述每个区域创建独立的数据分区,并在各分区上训练模型。
from snowflake.ml.modeling.distributors.many_model import ManyModelTraining
trainer = ManyModelTraining(train_xgboost_model, "model_stage") # Specify the stage to store the models
training_run = trainer.run(
partition_by="region", # Train separate models for each region
snowpark_dataframe=sales_data,
run_id="regional_models_v1" # Specify a unique ID for the training run
)
# Monitor training progress
final_status = training_run.wait()
print(f"Training completed with status: {final_status}")
模型存储在暂存区路径 run_id/{partition_id}
中,其中 partition_id
为分区列值。
访问经过训练的模型¶
MMT 完成后,您将获得存储在指定暂存区中的各数据分区训练模型。每个模型均基于其特定分区数据进行训练。例如,“North”模型仅使用“North”区域数据进行训练。
训练运行对象提供用于访问这些模型及检查各分区训练状态的方法。
以下代码检查训练运行状态并检索各分区的训练模型:
if final_status == RunStatus.SUCCESS:
# Access models for each partition
for partition_id in training_run.partition_details:
trained_model = training_run.get_model(partition_id)
print(f"Model for {partition_id}: {trained_model}")
# You can now use the model for predictions or further analysis
# Example: model.predict(new_data)
else:
# Handle training failures
for partition_id, details in training_run.partition_details.items():
if details.status != "DONE":
print(f"Training failed for {partition_id}")
error_logs = details.logs
模型持久化存储和检索¶
MMT 在训练过程中会自动将训练好的模型持久化存储到指定的 Snowflake 暂存区。每个模型均按照包含运行 ID 和分区标识符的结构化路径存储,便于后续整理与检索模型。
自动持久化存储意味着您无需手动保存模型。MMT 会为您处理序列化和存储操作,消除因会话超时或连接问题导致训练模型丢失的风险。
即使原始会话已结束,您仍可恢复之前运行的训练并访问其模型。该持久化机制使您能够:
在不同会话间恢复工作
与团队成员共享经过训练的模型
构建模型版本控制工作流程
与下游推理管道集成
模型会自动保存至指定暂存区,并可后续检索:
# Restore training run from stage
restored_run = ManyModelTraining.restore_from("regional_models_v1", "model_stage")
# Access models from restored run
north_model = restored_run.get_model("North")
south_model = restored_run.get_model("South")
训练自定义模型¶
对于自定义模型或不支持的 ML ML 框架,请实现 ModelSerde 接口。您可为自定义模型定义自身的序列化与反序列化逻辑这使您可以将 MMT 与您使用的任何机器学习框架或自定义模型架构无缝集成。
from snowflake.ml.modeling.distributors.many_model import ModelSerde
class CustomModelSerde(ModelSerde):
def serialize(self, model, path):
# Custom serialization logic
pass
def deserialize(self, path):
# Custom deserialization logic
pass
def train_custom_model(data_connector, context):
# Your custom training logic
model = your_custom_model_training(data_connector.to_pandas())
return model
trainer = ManyModelTraining(
train_custom_model,
"custom_model_stage",
model_serde=CustomModelSerde()
)
与模型注册表集成¶
MMT 可以与 Snowflake 模型注册表集成,以实现增强的模型管理。模型注册表为您的组织提供集中式的模型版本控制、元数据跟踪和部署管理。当使用 MMT 训练多个模型时,此集成尤为宝贵,因为它可帮助您从单一位置组织、跟踪和管理所有分区专用模型。
通过将模型注册表与 MMT 结合使用后,您能够执行以下操作:
跟踪分区特定模型的不同迭代版本
存储模型性能指标、训练参数和沿袭信息
管理每个分区部署到生产环境的具体模型版本
通过适当的访问控制和文档实现跨团队模型共享
为模型部署实施审批工作流程和合规性跟踪
# Register trained models to Model Registry
for partition_id in training_run.partition_details:
model = training_run.get_model(partition_id)
# Register to Model Registry
model_ref = registry.log_model(
model,
model_name=f"sales_model_{partition_id.lower()}",
version_name="v1"
)