Snowflake 数据集

数据集是专为机器学习工作流程设计的全新 Snowflake 架构级对象。Snowflake 数据集保存按版本组织的数据集合,其中每个版本都保存数据的物化快照,保证了不可变性、高效的数据访问以及与常用的深度学习框架的互操作性。

备注

虽然数据集是 SQL 对象,但只能与 Snowpark ML 一起使用。数据集不会出现在 Snowsight 数据库对象资源管理器中,不能使用 SQL 命令来处理数据集。

在以下情况下,应该使用 Snowflake 数据集:

  • 需要管理大型数据集并对其进行版本控制,以便进行可重复的机器学习模型训练和测试。

  • 希望利用 Snowflake 可扩展且安全的数据存储和处理功能。

  • 分布式训练或数据流式处理需要精细的文件级访问和/或数据混洗。

  • 需要与外部机器学习框架和工具集成。

备注

物化数据集会产生存储成本。为尽量减少这些成本,请删除未使用的数据集。

安装

从 1.5.0 版开始,Dataset Python SDK 已包含在 Snowpark ML(Python 包 snowflake-ml-python)中。有关安装说明,请参阅 在本地使用 Snowflake ML

所需权限

创建数据集需要 CREATE DATASET 架构级权限。修改数据集,例如添加或删除数据集版本需要数据集的 OWNERSHIP。从数据集读取数据只需要数据集的 USAGE 权限(或 OWNERSHIP)。有关在 Snowflake 中授予权限的详细信息,请参阅 GRANT <privileges>

小技巧

使用 setup_feature_store 方法或 权限设置 SQL 脚本 为 Snowflake 特征商店设置权限还会设置数据集权限。如果已经通过上述方法之一设置了特征商店权限,则无需进一步操作。

创建和使用数据集

数据集是通过向 snowflake.ml.dataset.create_from_dataframe 函数传递 Snowpark DataFrame 创建的。

from snowflake import snowpark
from snowflake.ml import dataset

# Create Snowpark Session
# See https://docs.snowflake.com/en/developer-guide/snowpark/python/creating-session
session = snowpark.Session.builder.configs(connection_parameters).create()

# Create a Snowpark DataFrame to serve as a data source
# In this example, we generate a random table with 100 rows and 1 column
df = session.sql(
    "select uniform(0, 10, random(1)) as x, uniform(0, 10, random(2)) as y from table(generator(rowcount => 100))"
)

# Materialize DataFrame contents into a Dataset
ds1 = dataset.create_from_dataframe(
    session,
    "my_dataset",
    "version1",
    input_dataframe=df)
Copy

数据集已进行版本控制。每个版本都是数据集所管理数据的不可变的时间点快照。Python API 包含一个 Dataset.selected_version 属性,用于指示是否选择使用给定数据集。dataset.create_from_dataframedataset.load_dataset 工厂方法会自动设置该属性,因此创建数据集时会自动选择创建的版本。Dataset.select_versionDataset.create_version 方法也可用于显式切换版本。从数据集读取时,会读取活动的选定版本。

# Inspect currently selected version
print(ds1.selected_version) # DatasetVersion(dataset='my_dataset', version='version1')
print(ds1.selected_version.created_on) # Prints creation timestamp

# List all versions in the Dataset
print(ds1.list_versions()) # ["version1"]

# Create a new version
ds2 = ds1.create_version("version2", df)
print(ds1.selected_version.name)  # "version1"
print(ds2.selected_version.name)  # "version2"
print(ds1.list_versions())        # ["version1", "version2"]

# selected_version is immutable, meaning switching versions with
# ds1.select_version() returns a new Dataset object without
# affecting ds1.selected_version
ds3 = ds1.select_version("version2")
print(ds1.selected_version.name)  # "version1"
print(ds3.selected_version.name)  # "version2"
Copy

从数据集读取数据

数据集版本数据以 Apache Parquet 格式存储为大小均匀的文件。Dataset 类提供了与 FileSet 类似的 API,用于从 Snowflake 数据集读取数据,包括 TensorFlow 和 PyTorch 的内置连接器。API 具有可扩展性,支持自定义框架连接器。

从数据集读取数据需要活动的选定版本。

连接至 TensorFlow

数据集可转换为 TensorFlow 的 tf.data.Dataset 并分批流式处理,以实现高效的训练和评估。

import tensorflow as tf

# Convert Snowflake Dataset to TensorFlow Dataset
tf_dataset = ds1.read.to_tf_dataset(batch_size=32)

# Train a TensorFlow model
for batch in tf_dataset:
    # Extract and build tensors as needed
    input_tensor = tf.stack(list(batch.values()), axis=-1)

    # Forward pass (details not included for brevity)
    outputs = model(input_tensor)
Copy

连接至 PyTorch

数据集还支持转换为 PyTorch DataPipes,而且可以分批流式处理,以实现高效的训练和评估。

import torch

# Convert Snowflake Dataset to PyTorch DataPipe
pt_datapipe = ds1.read.to_torch_datapipe(batch_size=32)

# Train a PyTorch model
for batch in pt_datapipe:
    # Extract and build tensors as needed
    input_tensor = torch.stack([torch.from_numpy(v) for v in batch.values()], dim=-1)

    # Forward pass (details not included for brevity)
    outputs = model(input_tensor)
Copy

连接至 Snowpark ML

数据集还可以转换回 Snowpark DataFrames,以便与 Snowpark ML 建模集成。转换后的 Snowpark DataFrame 与数据集创建时提供的 DataFrame 并不相同,而是指向数据集版本中的物化数据。

from snowflake.ml.modeling.ensemble import random_forest_regressor

# Get a Snowpark DataFrame
ds_df = ds1.read.to_snowpark_dataframe()

# Note ds_df != df
ds_df.explain()
df.explain()

# Train a model in Snowpark ML
xgboost_model = random_forest_regressor.RandomForestRegressor(
    n_estimators=100,
    random_state=42,
    input_cols=["X"],
    label_cols=["Y"],
)
xgboost_model.fit(ds_df)
Copy

直接文件访问

数据集 API 还公开了一个 fsspec (https://filesystem-spec.readthedocs.io/en/latest/) 接口,可用于与外部库(如 PyArrow、Dask 或任何其他支持 fsspec 并允许分布式和/或基于流的模型训练的包)构建自定义集成。

print(ds1.read.files()) # ['snow://dataset/my_dataset/versions/version1/data_0_0_0.snappy.parquet']

import pyarrow.parquet as pq
pd_ds = pq.ParquetDataset(ds1.read.files(), filesystem=ds1.read.filesystem())

import dask.dataframe as dd
dd_df = dd.read_parquet(ds1.read.files(), filesystem=ds1.read.filesystem())
Copy

目前的限制和已知问题

  • 数据集名称是 SQL 标识符,须遵守 Snowflake 标识符要求

  • 数据集版本为字符串,最大长度为 128 个字符。某些字符不允许使用,会生成错误消息。

  • 对具有宽架构(超过约 4000 列)的数据集的某些查询操作未全面优化。在即将发布的版本中,这种情况应该会有所改善。

语言: 中文