Snowflake 数据集

数据集是专为机器学习工作流程设计的全新 Snowflake 架构级对象。Snowflake 数据集保存按版本组织的数据集合,其中每个版本都保存数据的物化快照,保证了不可变性、高效的数据访问以及与常用的深度学习框架的互操作性。

在以下情况下使用 Snowflake 数据集:

  • 需要管理大型数据集并对其进行版本控制,以便进行可重复的机器学习模型训练和测试。

  • 分布式训练或数据流式处理需要精细的文件级访问和/或数据混洗。

  • 需要与外部机器学习框架和工具集成。

  • 您需要跟踪用于创建 ML 模型的沿袭。

数据集是物化的数据对象。您可以使用 Snowflake ML 或 SQL 命令与它们交互。它们不会出现在 Snowsight 数据数据库对象资源管理器中。

备注

安装

从 1.7.5 版开始,Dataset Python SDK 已包含在 Snowpark ML(Python 包 snowflake-ml-python)中。有关安装说明,请参阅 在本地使用 Snowflake ML

所需权限

创建数据集需要 CREATE DATASET 架构级权限。修改数据集,例如添加或删除数据集版本需要数据集的 OWNERSHIP。从数据集读取数据只需要数据集的 USAGE 权限(或 OWNERSHIP)。有关在 Snowflake 中授予权限的详细信息,请参阅 GRANT <privileges>

小技巧

使用 setup_feature_store 方法或 权限设置 SQL 脚本 为 Snowflake 特征商店设置权限还会设置数据集权限。如果已经通过上述方法之一设置了特征商店权限,则无需进一步操作。

创建和使用数据集

您可以使用 SQL 或 Python 创建和管理数据集。有关使用 SQL 命令的更多信息,请参阅 SQL 命令。有关使用 Python API 的信息,请参阅 snowflake.ml.dataset

通过将 Snowpark DataFrame 传递给 snowflake.ml.dataset.create_from_dataframe 函数来创建数据集。

from snowflake import snowpark
from snowflake.ml import dataset

# Create Snowpark Session
# See https://docs.snowflake.com/en/developer-guide/snowpark/python/creating-session
session = snowpark.Session.builder.configs(connection_parameters).create()

# Create a Snowpark DataFrame to serve as a data source
# In this example, we generate a random table with 100 rows and 1 column
df = session.sql(
  "select uniform(0, 10, random(1)) as x, uniform(0, 10, random(2)) as y from table(generator(rowcount => 100))"
)

# Materialize DataFrame contents into a Dataset
ds1 = dataset.create_from_dataframe(
    session,
    "my_dataset",
    "version1",
    input_dataframe=df)
Copy

数据集已进行版本控制。每个版本都是数据集所管理数据的不可变的时间点快照。Python API 包含一个 Dataset.selected_version 属性,用于指示是否选择使用给定数据集。dataset.create_from_dataframedataset.load_dataset 工厂方法会自动设置该属性,因此创建数据集时会自动选择创建的版本。Dataset.select_versionDataset.create_version 方法也可用于显式切换版本。从数据集读取时,会读取活动的选定版本。

# Inspect currently selected version
print(ds1.selected_version) # DatasetVersion(dataset='my_dataset', version='version1')
print(ds1.selected_version.created_on) # Prints creation timestamp

# List all versions in the Dataset
print(ds1.list_versions()) # ["version1"]

# Create a new version
ds2 = ds1.create_version("version2", df)
print(ds1.selected_version.name)  # "version1"
print(ds2.selected_version.name)  # "version2"
print(ds1.list_versions())        # ["version1", "version2"]

# selected_version is immutable, meaning switching versions with
# ds1.select_version() returns a new Dataset object without
# affecting ds1.selected_version
ds3 = ds1.select_version("version2")
print(ds1.selected_version.name)  # "version1"
print(ds3.selected_version.name)  # "version2"
Copy

从数据集读取数据

数据集版本数据以 Apache Parquet 格式存储为大小均匀的文件。Dataset 类提供了与 FileSet 类似的 API,用于从 Snowflake 数据集读取数据,包括 TensorFlow 和 PyTorch 的内置连接器。API 具有可扩展性,支持自定义框架连接器。

从数据集读取数据需要活动的选定版本。

连接至 TensorFlow

数据集可转换为 TensorFlow 的 tf.data.Dataset 并分批流式处理,以实现高效的训练和评估。

import tensorflow as tf

# Convert Snowflake Dataset to TensorFlow Dataset
tf_dataset = ds1.read.to_tf_dataset(batch_size=32)

# Train a TensorFlow model
for batch in tf_dataset:
    # Extract and build tensors as needed
    input_tensor = tf.stack(list(batch.values()), axis=-1)

    # Forward pass (details not included for brevity)
    outputs = model(input_tensor)
Copy

连接至 PyTorch

数据集还支持转换为 PyTorch DataPipes,而且可以分批流式处理,以实现高效的训练和评估。

import torch

# Convert Snowflake Dataset to PyTorch DataPipe
pt_datapipe = ds1.read.to_torch_datapipe(batch_size=32)

# Train a PyTorch model
for batch in pt_datapipe:
    # Extract and build tensors as needed
    input_tensor = torch.stack([torch.from_numpy(v) for v in batch.values()], dim=-1)

    # Forward pass (details not included for brevity)
    outputs = model(input_tensor)
Copy

连接至 Snowpark ML

数据集还可以转换回 Snowpark DataFrames,以便与 Snowpark ML 建模集成。转换后的 Snowpark DataFrame 与数据集创建时提供的 DataFrame 并不相同,而是指向数据集版本中的物化数据。

from snowflake.ml.modeling.ensemble import random_forest_regressor

# Get a Snowpark DataFrame
ds_df = ds1.read.to_snowpark_dataframe()

# Note ds_df != df
ds_df.explain()
df.explain()

# Train a model in Snowpark ML
xgboost_model = random_forest_regressor.RandomForestRegressor(
    n_estimators=100,
    random_state=42,
    input_cols=["X"],
    label_cols=["Y"],
)
xgboost_model.fit(ds_df)
Copy

直接文件访问

数据集 API 还公开了一个 fsspec (https://filesystem-spec.readthedocs.io/en/latest/) 接口,可用于与外部库(如 PyArrow、Dask 或任何其他支持 fsspec 并允许分布式和/或基于流的模型训练的包)构建自定义集成。

print(ds1.read.files()) # ['snow://dataset/my_dataset/versions/version1/data_0_0_0.snappy.parquet']

import pyarrow.parquet as pq
pd_ds = pq.ParquetDataset(ds1.read.files(), filesystem=ds1.read.filesystem())

import dask.dataframe as dd
dd_df = dd.read_parquet(ds1.read.files(), filesystem=ds1.read.filesystem())
Copy

使用 SQL 从数据集版本读取

您可以使用标准 Snowflake SQL 命令从数据集版本读取数据。您可以使用 SQL 命令执行以下操作:

  • 列出文件

  • 推断架构

  • 直接从暂存区查询数据。

重要

您必须拥有对数据集的 USAGE 或 OWNERSHIP 权限才能从中读取。

列出数据集版本中的文件

使用 LIST snow_url 命令列出数据集版本中的文件。使用以下 SQL 语法列出数据集版本内的所有文件:

LIST 'snow://dataset/<dataset_name>/versions/<dataset_version>'
Copy

分析文件并获取列定义

使用 INFER_SCHEMA 函数分析数据集版本中的文件并检索列定义。使用以下 SQL 示例列出数据集版本内的所有文件:

INFER_SCHEMA(
  LOCATION => 'snow://dataset/<dataset_name>/versions/<dataset_version>',
  FILE_FORMAT => '<file_format_name>'
)
Copy

您必须使用示例中指定的模式来获取数据集版本的位置。

对于 FILE_FORMAT,请指定 PARQUET

以下示例创建文件格式并运行 INFER_SCHEMA 函数:

CREATE FILE FORMAT my_parquet_format TYPE = PARQUET;

SELECT *
FROM TABLE(
    INFER_SCHEMA(
        FILE_FORMAT => 'snow://dataset/MYDS/versions/v1,
        FILE_FORMAT => 'my_parquet_format'
    )
);
Copy

暂存区查询

直接从存储在数据集版本中的文件查询数据,方法与查询外部表类似。使用以下 SQL 示例帮助您入门:

SELECT $1
FROM 'snow://dataset/foo/versions/V1'
( FILE_FORMAT => 'my_parquet_format',
PATTERN => '.*data.*' ) t;
Copy

SQL 命令

您可以使用 SQL 命令创建和管理数据集。有关更多信息,请参阅:

目前的限制和已知问题

  • 数据集名称是 SQL 标识符,须遵守 Snowflake 标识符要求

  • 数据集版本为字符串,最大长度为 128 个字符。某些字符不允许使用,会生成错误消息。

  • 对具有宽架构(超过约 4000 列)的数据集的某些查询操作未全面优化。在即将发布的版本中,这种情况应该会有所改善。

语言: 中文