Snowflake 多节点 ML 作业

使用 Snowflake 多节点 ML 作业在 Snowflake ML 容器运行时内跨多个计算节点运行分布式机器学习 (ML) 工作流程。将工作分配到多个节点以处理大型数据集和复杂模型,从而提高性能。有关 Snowflake ML 作业的信息,请参阅 Snowflake ML 作业

Snowflake 多节点 ML 作业支持跨多个节点的分布式执行,以扩展 Snowflake ML 作业的功能。您可以获得:

  • 可扩展性能:水平扩展以处理过大而无法容纳在单个节点上的数据集

  • 更短的训练时间:通过并行化加快复杂模型训练

  • 资源效率:优化数据密集型工作负载的资源利用率

  • 框架集成:无缝使用分布式框架,例如 分布式建模类 和 Ray (https://www.ray.io/)。

当您运行具有多个节点的 Snowflake ML 作业时,会发生以下情况:

  • 一个节点用作头节点(协调器)

  • 其他节点用作工作节点(计算资源)

  • 这些节点共同构成了 Snowflake 中的单一逻辑 ML 作业实体

单节点 ML 作业只有一个头节点。具有三个活动节点的多节点作业具有一个头节点和两个工作节点。所有三个节点都参与运行您的工作负载。

先决条件

要使用 Snowflake 多节点 ML 作业,需要满足以下先决条件。

重要

Snowflake 多节点 ML 作业目前仅支持 Python 3.10 客户端。如果您需要其他 Python 版本的支持,请联系您的 Snowflake 账户团队。

要设置多节点作业,请执行以下操作:

  1. 在您的 Python 3.10 环境中安装 Snowflake ML Python 包。

    pip install snowflake-ml-python>=1.9.2
    
    Copy
  2. 创建具有足够节点的计算池以支持您的多节点作业:

    CREATE COMPUTE POOL IF NOT EXISTS MY_COMPUTE_POOL
      MIN_NODES = 1
      MAX_NODES = <NUM_INSTANCES>
      INSTANCE_FAMILY = <INSTANCE_FAMILY>;
    
    Copy

    重要

    您必须将 MAX_NODES 设置为大于或等于用于运行训练作业的目标实例的数量。如果您请求的节点数超过了训练作业的预定用量,则作业可能会失败或出现不可预测的行为。有关运行训练作业的信息,请参阅 运行多节点 ML 作业

为多节点作业编写代码

对于多节点作业,您的代码需要使用 分布式建模类 或 Ray (https://www.ray.io/) 进行设计,以进行分布式处理。

以下是使用分布式建模类或 Ray 时的关键模式和注意事项:

了解节点初始化和可用性

在多节点作业中,工作节点可以在不同的时间异步初始化:

  • 节点可能无法全部同时启动,尤其是计算池资源有限时

  • 有些节点可能会在其他节点启动几秒甚至几分钟后启动

  • ML 作业会自动等待指定的 target_instances 可用,然后再执行您的有效负载。如果预期的节点在超时时间内不可用,则作业将失败并出现错误。有关自定义此行为的更多信息,请参阅 高级配置:使用 min_instances

您可以通过 Ray 检查作业中的可用节点:

import ray
ray.init(address="auto", ignore_reinit_error=True)  # Ray is automatically initialized in multi-node jobs
nodes_info = ray.nodes()
print(f"Available nodes: {len(nodes_info)}")
Copy

分布式处理模式

您可以在多节点作业的有效负载正文中应用多种模式以进行分布式处理。这些模式利用 分布式建模类 和 Ray (https://www.ray.io/):

使用 Snowflake 的分布式训练 API

Snowflake 为常见 ML 框架提供经过优化的训练器:

# Inside the ML Job payload body
from snowflake.ml.modeling.distributors.xgboost import XGBEstimator, XGBScalingConfig

# Configure scaling for distributed execution
scaling_config = XGBScalingConfig()

# Create distributed estimator
estimator = XGBEstimator(
    n_estimators=100,
    params={"objective": "reg:squarederror"},
    scaling_config=scaling_config
)

# Train using distributed resources
# NOTE: data_connector and feature_cols excluded for brevity
model = estimator.fit(data_connector, input_cols=feature_cols, label_col="target")
Copy

有关可用 APIs 的更多信息,请参阅 分布式建模类

使用原生 Ray 任务

另一种方法是使用 Ray 基于任务的编程模型:

# Inside the ML Job payload body
import ray

@ray.remote
def process_chunk(data_chunk):
    # Process a chunk of data
    return processed_result

# Distribute work across available workers
data_chunks = split_data(large_dataset)
futures = [process_chunk.remote(chunk) for chunk in data_chunks]
results = ray.get(futures)
Copy

有关更多信息,请参阅 ` Ray 的任务编程文档<https://docs.ray.io/en/latest/ray-core/tasks.html (https://docs.ray.io/en/latest/ray-core/tasks.html)>`_。

运行多节点 ML 作业

您可以使用与单节点作业相同的方法运行多节点 ML 作业,使用 target_instances 参数:

使用远程装饰器

from snowflake.ml.jobs import remote

@remote(
    "MY_COMPUTE_POOL",
    stage_name="payload_stage",
    session=session,
    target_instances=3  # Specify the number of nodes
)
def distributed_training(data_table: str):

    from snowflake.ml.modeling.distributors.xgboost import XGBEstimator, XGBScalingConfig

    # Configure scaling for distributed execution
    scaling_config = XGBScalingConfig()

    # Create distributed estimator
    estimator = XGBEstimator(
        n_estimators=100,
        params={"objective": "reg:squarederror"},
        scaling_config=scaling_config
    )

    # Train using distributed resources
    # NOTE: data_connector and feature_cols excluded for brevity
    model = estimator.fit(data_connector, input_cols=feature_cols, label_col="target")


job = distributed_training("<my_training_data>")
Copy

运行 Python 文件

from snowflake.ml.jobs import submit_file

job = submit_file(
    "<script_path>",
    "MY_COMPUTE_POOL",
    stage_name="<payload_stage>",
    session=session,
    target_instances=<num_training_nodes>  # Specify the number of nodes
)
Copy

运行目录

from snowflake.ml.jobs import submit_directory

job = submit_directory(
    "<script_directory>",
    "MY_COMPUTE_POOL",
    entrypoint="<script_name>",
    stage_name="<payload_stage>",
    session=session,
    target_instances=<num_training_nodes>  # Specify the number of nodes
)
Copy

高级配置:使用 min_instances

为实现更灵活的资源管理,您可以使用可选 min_instances 参数来指定继续执行作业所需的最小实例数。如果已设置 min_instances,则只要最小节点数可用,即使该数量小于 target_instances,作业有效负载也会立即执行。

这在您需要执行以下操作时很有用:

  • 如果无法立即实现完整目标,则使用更少的节点开始训练

  • 减少计算池资源有限时的等待时间

  • 实施可适应不同资源可用性的容错工作流程

from snowflake.ml.jobs import remote

@remote(
    "MY_COMPUTE_POOL",
    stage_name="payload_stage",
    session=session,
    target_instances=5,  # Prefer 5 nodes
    min_instances=3      # But start with at least 3 nodes
)
def flexible_distributed_training(data_table: str):
    import ray

    # Check how many nodes we actually got
    available_nodes = len(ray.nodes())
    print(f"Training with {available_nodes} nodes")

    # Adapt your training logic based on available resources
    from snowflake.ml.modeling.distributors.xgboost import XGBEstimator, XGBScalingConfig

    scaling_config = XGBScalingConfig(
        num_workers=available_nodes
    )

    estimator = XGBEstimator(
        n_estimators=100,
        params={"objective": "reg:squarederror"},
        scaling_config=scaling_config
    )

    # Train using available distributed resources
    model = estimator.fit(data_connector, input_cols=feature_cols, label_col="target")

job = flexible_distributed_training("<my_training_data>")
Copy

管理多节点作业

监控作业状态

作业状态监控与单节点作业相同:

from snowflake.ml.jobs import MLJob, get_job, list_jobs

# List all jobs
jobs = list_jobs()

# Retrieve an existing job based on ID
job = get_job("<job_id>")  # job is an MLJob instance

# Basic job information
print(f"Job ID: {job.id}")
print(f"Status: {job.status}")  # PENDING, RUNNING, FAILED, DONE

# Wait for completion
job.wait()
Copy

按节点访问日志

在多节点作业中,您可以访问来自特定实例的日志:

# Get logs from the default (head) instance
logs_default = job.get_logs()

# Get logs from specific instances by ID
logs_instance0 = job.get_logs(instance_id=0)
logs_instance1 = job.get_logs(instance_id=1)
logs_instance2 = job.get_logs(instance_id=2)

# Display logs in the notebook/console
job.show_logs()  # Default (head) instance logs
job.show_logs(instance_id=0)  # Instance 0 logs (not necessarily the head node)
Copy

已知问题和限制

使用以下信息来解决您可能遇到的常见问题。

  • 节点连接失败:如果工作节点无法连接到头节点,则头节点可以完成其任务,然后在工作器完成其任务之前自行关闭。为避免连接失败,请在作业中实施结果收集逻辑。

  • 内存耗尽:如果作业由于内存问题而失败,请增加节点大小或使用更多节点,且每个节点的数据量更少。

  • 节点可用性超时:如果所需数量的实例(target_instancesmin_instances)在预定义的超时时间内不可用,则作业将失败。确保您的计算池有足够的容量或调整您的实例要求。

语言: 中文