分布式训练¶
Snowflake Container Runtime 提供了灵活的训练环境,可用于在 Snowflake 的基础设施上训练模型。您可以使用开源包,也可以使用 Snowflake ML 分布式训练器进行多节点和多设备训练。
分布式训练器可自动在多个节点和 GPUs 上扩展您的机器学习工作负载。Snowflake 分发器无需复杂配置,即可以智能方式管理集群资源,从而使分布式训练易于访问且高效。
当您处于以下情况时,请使用标准的开源库
在单节点环境中处理小型数据集
使用模型快速进行原型设计和试验
提升和转移工作流程,无需进行分布式处理
使用 Snowflake 分布式训练器执行以下操作:
利用超过单个计算节点内存的数据集训练模型
高效利用多个 GPUs
自动利用所有计算多节点 MLJobs 或扩展笔记本集群
Snowflake ML 分布式训练¶
Snowflake ML 为常用的机器学习框架提供分布式训练器,包括 XGBoost、LightGBM 和 PyTorch。这些训练器经过优化,可在 Snowflake 的基础设施上运行,并且可以自动扩展到多个节点和 GPUs。
自动资源管理 - Snowflake 自动发现并使用所有可用的集群资源
简化设置 - Container Runtime 环境由 Snowflake 提供的 Ray 集群提供支持,无需用户配置
无缝 Snowflake 集成 - 与 Snowflake 数据连接器和暂存区直接兼容
可选扩展配置 - 高级用户可以在需要时进行微调
数据加载¶
对于开源和 Snowflake 分布式训练器来说,引入数据最有效的方式是使用 Snowflake Data Connector:
训练方法¶
开源训练¶
当您需要最大程度的灵活性和对训练过程的控制时,请使用标准的开源库。通过开源训练,只需进行极少的修改,您就可以直接使用常用的 ML 框架(如 XGBoost、LightGBM 和 PyTorch),同时仍然可以受益于 Snowflake 的基础设施和数据连接。
以下示例使用 XGBoost 和 LightGBM 训练模型。
要使用开源 XGBoost 进行训练,请在使用数据连接器加载数据后,将其转换为 Pandas 数据帧并直接使用 XGB 库:
分布式训练¶
分布式 XGBEstimator 类有类似的 API,但有一些关键区别:
在类初始化期间,XGBoost 训练参数通过“params”参数传递给
XGBEstimator。DataConnector 对象可以直接传递给估算器的
fit函数,以及定义特征的输入列和定义目标的标签列。在实例化
XGBEstimator类时,您可以提供扩展配置。但是,Snowflake 默认使用所有可用资源。
评估模型¶
可以通过传递 eval_set 和使用 verbose_eval 将评估数据打印到控制台来评估模型。此外,推理可以作为第二步来完成。为方便起见,分布式估算器提供了 predict 方法,但它不会以分布式方式进行推理。我们建议在训练后将拟合模型转换为 OSS xgboost 估算器,以便进行推理并记录到模型注册表。
注册模型¶
要将模型注册到 Snowflake 模型注册表,请使用由 estimator.get_booster 提供并从 estimator.fit 返回的开源提升器。有关更多信息,请参阅 XGBoost。
PyTorch¶
Snowflake Distributor PyTorch 原生支持 Snowflake 后端的 Distributed Data Parallel 模型。要在 Snowflake 上使用 DDP,请利用开源 PyTorch 模块,并进行一些特定于 Snowflake 的修改:
使用
ShardedDataConnector加载数据,以自动将数据分片到与分布式训练器的world_size匹配的数量的分区中。在 Snowflake 训练上下文中调用get_shard,以检索与该工作器进程相关的分片。在训练函数中,使用
context对象获取特定于进程的信息,例如排名、本地排名和训练所需的数据。使用上下文的
get_model_dir保存模型,以查找模型的存储位置。这将在本地存储模型以进行单节点训练,并将模型同步到 Snowflake 暂存区以进行分布式训练。如果未提供暂存区位置,则默认使用您的用户暂存区。
加载数据¶
训练模型¶
检索模型¶
如果您使用多节点 DDP,则该模型将作为共享的永久存储自动同步到 Snowflake 暂存区。
以下代码会从暂存区获取模型。它使用 artifact_stage_location 参数来指定存储模型工件的暂存区的位置。
保存在 stage_location 变量中的函数会获取训练完成后模型在暂存区中的位置。模型对象保存在 "DB_NAME.SCHEMA_NAME.STAGE_NAME/model/{request_id}" 下。