跨数据分区训练模型

使用多模型训练 (MMT) 可在数据分区间高效训练多个机器学习模型。该工具可以自动处理分布式编排、模型存储及成果持久化。

MMT 根据指定列对 Snowpark DataFrame 进行分区,并在每个分区上并行训练独立模型。您只需专注于模型训练逻辑,MMT 将自动处理基础设施复杂性并实现自动扩展。

您可以利用 MMT 在不同数据分段间高效训练多个模型。该工具适用于以下场景:训练区域性销售预测模型、构建需为不同客户群体定制模型的个性化推荐系统,或创建特定细分市场的预测模型。MMT 自动处理分布式模型训练,免除管理分布式计算基础设施的复杂性。

您可以使用 MMT 基于开源机器学习模型和框架(如 XGBoost、scikit-learn、PyTorch 和 TensorFlow)进行模型训练。MMT 会自动序列化模型工件,确保您在推理时能直接使用。

您还可实现 ModelSerde 接口来训练自定义模型或使用不受支持的 ML 框架。这使您可以将 MMT 与您使用的任何机器学习框架或自定义模型架构无缝集成。

重要

使用 MMT 之前,请确保满足以下条件:

  • 容器运行时环境:MMT 需要 Snowflake ML 容器运行时环境。

  • 暂存区访问权限:MMT 会自动将模型工件存储在 Snowflake 暂存区中。请确保您拥有访问指定命名暂存区的相应权限。

  • ML 框架支持:内置集成支持 XGBoost、scikit-learn、PyTorch 和 TensorFlow。自定义模型需实现 ModelSerde 接口。

以下部分将通过示例工作流程指导您使用 MMT。

使用 MMT 训练模型

本部分通过五个关键步骤演示完整的 MMT 工作流程:

  1. 导入数据 – 使用 Snowpark 加载训练数据

  2. 定义训练函数 – 定义训练函数

  3. 跨分区训练模型 – 使用 MMT 并行训练每个分区的模型

  4. 访问经过训练的模型 – 检索和使用每个分区的训练模型

  5. 模型持久化存储和检索 – 将模型保存到暂存区,稍后再恢复

该工作流程可自动处理分布式训练、模型序列化及跨数据分区的工件存储。

导入数据

使用 Snowpark 会话开始导入数据。多模型训练函数会根据您指定的列将导入的数据拆分为不同分区。

使用 MMT 之前,请创建一个 Snowpark 会话。有关更多信息,请参阅 为 Snowpark Python 创建会话

以下代码使用 Snowpark 会话导入训练数据。

# Example: sales_data with columns: region, feature1, feature2, feature3, target
sales_data = session.table("SALES_TRAINING_DATA")
Copy

定义训练函数

获取数据后,您需要定义 MMT 用于跨分区训练模型的训练函数。训练函数接收数据连接器和上下文对象,后者指向当前训练的数据分区。除提供使用 TensorFlow 和 PyTorch 的示例外,本部分还包含用于训练 XGBoost 模型的训练函数示例。

训练函数必须严格遵循以下签名格式:(data_connector, context)。对于每个数据分区,MMT 将使用以下实参调用 train_xgboost_model

  • data_connector:选择使用 时默认使用的角色和仓库。提供 MMT 分区数据访问权限的数据连接器。train_xgboost_model 函数将数据框转换为 Pandas。

  • context:选择使用 时默认使用的角色和仓库。向 train_xgboost_model 函数提供 partition_id 的对象。此 ID 即为分区列的名称。

您无需自行调用此函数,MMT 将自动处理跨所有分区的执行。

使用以下代码定义训练函数。更改代码以反映数据特征后,即可将其传递给 MMT 函数。

使用 XGBoost 跨数据分区训练模型。XGBoost 能够为结构化数据提供卓越性能,并自动处理缺失值。

def train_xgboost_model(data_connector, context):
    df = data_connector.to_pandas()
    print(f"Training model for partition: {context.partition_id}")

    # Prepare features and target
    X = df[['feature1', 'feature2', 'feature3']]
    y = df['target']

    # Train the model
    from xgboost import XGBRegressor
    model = XGBRegressor(
        n_estimators=100,
        max_depth=6,
        learning_rate=0.1,
        random_state=42
    )
    model.fit(X, y)
    return model

trainer = ManyModelTraining(train_xgboost_model, "model_stage")
Copy

跨分区训练模型

定义训练函数后,即可使用 MMT 跨分区训练模型。指定用于分区的列以及保存模型的暂存区。

以下代码通过 region 列对数据进行分区,并使用 train_xgboost_model 函数并行训练每个区域的独立模型。

例如,如果 region 列包含以下可能值:

  • North

  • South

  • East

  • West

  • Central

ManyModelTraining 函数将为上述每个区域创建独立的数据分区,并在各分区上训练模型。

from snowflake.ml.modeling.distributors.many_model import ManyModelTraining

trainer = ManyModelTraining(train_xgboost_model, "model_stage") # Specify the stage to store the models
training_run = trainer.run(
    partition_by="region",  # Train separate models for each region
    snowpark_dataframe=sales_data,
    run_id="regional_models_v1" # Specify a unique ID for the training run
)

# Monitor training progress
final_status = training_run.wait()
print(f"Training completed with status: {final_status}")
Copy

模型存储在暂存区路径 run_id/{partition_id} 中,其中 partition_id 为分区列值。

访问经过训练的模型

MMT 完成后,您将获得存储在指定暂存区中的各数据分区训练模型。每个模型均基于其特定分区数据进行训练。例如,“North”模型仅使用“North”区域数据进行训练。

训练运行对象提供用于访问这些模型及检查各分区训练状态的方法。

以下代码检查训练运行状态并检索各分区的训练模型:

if final_status == RunStatus.SUCCESS:
    # Access models for each partition
    for partition_id in training_run.partition_details:
        trained_model = training_run.get_model(partition_id)
        print(f"Model for {partition_id}: {trained_model}")

        # You can now use the model for predictions or further analysis
        # Example: model.predict(new_data)
else:
    # Handle training failures
    for partition_id, details in training_run.partition_details.items():
        if details.status != "DONE":
            print(f"Training failed for {partition_id}")
            error_logs = details.logs
Copy

模型持久化存储和检索

MMT 在训练过程中会自动将训练好的模型持久化存储到指定的 Snowflake 暂存区。每个模型均按照包含运行 ID 和分区标识符的结构化路径存储,便于后续整理与检索模型。

自动持久化存储意味着您无需手动保存模型。MMT 会为您处理序列化和存储操作,消除因会话超时或连接问题导致训练模型丢失的风险。

即使原始会话已结束,您仍可恢复之前运行的训练并访问其模型。该持久化机制使您能够:

  • 在不同会话间恢复工作

  • 与团队成员共享经过训练的模型

  • 构建模型版本控制工作流程

  • 与下游推理管道集成

模型会自动保存至指定暂存区,并可后续检索:

# Restore training run from stage
restored_run = ManyModelTraining.restore_from("regional_models_v1", "model_stage")

# Access models from restored run
north_model = restored_run.get_model("North")
south_model = restored_run.get_model("South")
Copy

训练自定义模型

对于自定义模型或不支持的 ML ML 框架,请实现 ModelSerde 接口。您可为自定义模型定义自身的序列化与反序列化逻辑这使您可以将 MMT 与您使用的任何机器学习框架或自定义模型架构无缝集成。

from snowflake.ml.modeling.distributors.many_model import ModelSerde

class CustomModelSerde(ModelSerde):
    def serialize(self, model, path):
        # Custom serialization logic
        pass

    def deserialize(self, path):
        # Custom deserialization logic
        pass

def train_custom_model(data_connector, context):
    # Your custom training logic
    model = your_custom_model_training(data_connector.to_pandas())
    return model

trainer = ManyModelTraining(
    train_custom_model,
    "custom_model_stage",
    model_serde=CustomModelSerde()
)
Copy

与模型注册表集成

MMT 可以与 Snowflake 模型注册表集成,以实现增强的模型管理。模型注册表为您的组织提供集中式的模型版本控制、元数据跟踪和部署管理。当使用 MMT 训练多个模型时,此集成尤为宝贵,因为它可帮助您从单一位置组织、跟踪和管理所有分区专用模型。

通过将模型注册表与 MMT 结合使用后,您能够执行以下操作:

  • 跟踪分区特定模型的不同迭代版本

  • 存储模型性能指标、训练参数和沿袭信息

  • 管理每个分区部署到生产环境的具体模型版本

  • 通过适当的访问控制和文档实现跨团队模型共享

  • 为模型部署实施审批工作流程和合规性跟踪

# Register trained models to Model Registry
for partition_id in training_run.partition_details:
    model = training_run.get_model(partition_id)

    # Register to Model Registry
    model_ref = registry.log_model(
        model,
        model_name=f"sales_model_{partition_id.lower()}",
        version_name="v1"
    )
Copy
语言: 中文