在 Python 中为 DataFrames 创建用户定义的表函数 (UDTFs)

Snowpark API 提供了一些方法,您可以使用这些方法通过用 Python 编写的处理程序创建用户定义的表函数。本主题说明如何创建这些类型的函数。

本主题内容:

简介

您可以使用 Snowpark API 创建用户定义的表函数 (UDTF)。

执行此操作的方式与使用 API 创建标量用户定义函数 (UDF) 类似,如 在 Python 中为 DataFrames 创建用户定义函数 (UDFs) 中所述。主要区别包括注册 UDTF 时所需的 UDF 处理程序要求和参数值。

要在 Snowpark 中创建和注册 UDTF,必须执行以下操作:

  • 实施 UDTF 处理程序。

    该处理程序包含 UDTF 的逻辑。UDTF 处理程序必须实施 Snowflake 在调用 UDTF 时将在运行时调用的函数。有关更多信息,请参阅 实施 UDTF 处理程序

  • 在 Snowflake 数据库中注册 UDTF 及其处理程序。

    您可以使用 Snowpark API 注册 UDTF 及其处理程序。注册 UDTF 后,您可以通过 SQL 或使用 Snowpark API 进行调用。有关注册的更多信息,请参阅 注册 UDTF

有关调用 UDTF 的信息,请参阅 调用用户定义的表函数 (UDTFs)

实施 UDTF 处理程序

正如 用 Python 编写 UDTF 中详细描述的那样,UDTF 处理程序类必须实施 Snowflake 在调用 UDTF 时调用的方法。无论您是使用 Snowpark API 注册 UDTF,还是使用 CREATE FUNCTION 语句通过 SQL 创建,都可以使用编写的类作为处理程序。

处理程序类的方法旨在处理 UDTF 接收的行和分区。

UDTF 处理程序类实施了以下方法,Snowflake 在运行时会调用这些方法:

  • __init__ 方法。可选。调用以初始化输入分区的有状态处理。

  • process 方法。必填。为每个输入行调用。该方法以元组的形式返回表格值。

  • end_partition 方法。可选。调用以完成输入分区的处理。

    虽然 Snowflake 支持大型分区,会调整超时以成功处理分区,但特别大的分区可能导致处理超时(例如 end_partition 需要太长时间才能完成)。如果您需要针对特定使用场景调整超时阈值,请联系 Snowflake 支持部门 (https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge)。

有关处理程序的详细信息和示例,请参阅 用 Python 编写 UDTF

注册 UDTF

实施 UDTF 处理程序后,可以使用 Snowpark API 在 Snowflake 数据库上注册 UDTF。注册 UDTF 将创建 UDTF,以便可以调用它。

您可以像注册标量 UDF 一样,将 UDTF 注册为命名函数或匿名函数。有关注册标量 UDF 的相关信息,请参阅 创建匿名 UDF创建和注册命名的 UDF

注册 UDTF 时,指定 Snowflake 创建 UDTF 所需的参数值。(其中许多参数在功能上对应 SQL 中 CREATE FUNCTION 语句的子句。有关更多信息,请参阅 CREATE FUNCTION。)

这些参数中的大多数与创建标量 UDF 时指定的参数相同(有关更多信息,请参阅 在 Python 中为 DataFrames 创建用户定义函数 (UDFs))。主要区别在于 UDTF 返回表格值,以及它的处理程序是类,而不是函数。有关参数的完整列表,请参阅下面链接的 APIs 文档。

要在 Snowpark 中注册 UDTF,可以使用以下方法之一,指定在数据库中创建 UDTF 所需的参数值。有关区分这些选项的信息,请参阅 UDFRegistration,其中描述了注册标量 UDF 的类似选项。

定义 UDTF 的输入类型和输出架构

注册 UDTF 时,需要指定有关该函数的参数和输出值的详细信息。这样做是为了使函数本身声明的类型与该函数的底层处理程序的类型精确对应。

有关示例,请参阅 示例 (在本主题和 snowflake.snowpark.udtf.UDTFRegistration 参考中)。

注册 UDTF 时需要为其指定以下内容:

  • 其输入参数的类型,作为注册函数的 input_types 参数值。如果您在 process 方法的声明中提供类型提示,则该 input_types 参数是可选的。

    将此值指定为基于 `snowflake.snowpark.types 的类型列表。DataType<https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/types>`_. 例如,可以指定 input_types=[StringType(), IntegerType()]

  • 其表格输出的架构,作为注册函数的 output_schema 参数值。

    output_schema 可以是以下其中一项:

    • UDTF 的返回值中列的名称列表。

      该列表将仅包含列名,因此还必须在 process 方法的声明中提供类型提示。

    • 表示输出表的列名*和*类型的 StructType

      以下示例中的代码将架构作为值分配给 output 变量,然后在注册 UDTF 时使用该变量。

      >>> from snowflake.snowpark.types import StructField, StructType, StringType, IntegerType, FloatType
      >>> from snowflake.snowpark.functions import udtf, table_function
      >>> schema = StructType([
      ...     StructField("symbol", StringType())
      ...     StructField("cost", IntegerType()),
      ... ])
      >>> @udtf(output_schema=schema,input_types=[StringType(), IntegerType(), FloatType()],stage_location="straut_udf",is_permanent=True,name="test_udtf",replace=True)
      ... class StockSale:
      ...     def process(self, symbol, quantity, price):
      ...         cost = quantity * price
      ...         yield (symbol, cost)
      
      Copy

示例

以下是示例的简要列表。有关更多示例,请参阅 snowflake.snowpark.udtf.UDTFRegistration

使用 udtf 函数注册 UDTF

注册该函数。

>>> from snowflake.snowpark.types import IntegerType, StructField, StructType
>>> from snowflake.snowpark.functions import udtf, lit
>>> class GeneratorUDTF:
...     def process(self, n):
...         for i in range(n):
...             yield (i, )
>>> generator_udtf = udtf(GeneratorUDTF, output_schema=StructType([StructField("number", IntegerType())]), input_types=[IntegerType()])
Copy

调用该函数。

>>> session.table_function(generator_udtf(lit(3))).collect()  # Query it by calling it
[Row(NUMBER=0), Row(NUMBER=1), Row(NUMBER=2)]
>>> session.table_function(generator_udtf.name, lit(3)).collect()  # Query it by using the name
[Row(NUMBER=0), Row(NUMBER=1), Row(NUMBER=2)]
Copy

使用 register 函数注册 UDTF

注册该函数。

>>> from collections import Counter
>>> from typing import Iterable, Tuple
>>> from snowflake.snowpark.functions import lit
>>> class MyWordCount:
...     def __init__(self):
...         self._total_per_partition = 0
...
...     def process(self, s1: str) -> Iterable[Tuple[str, int]]:
...         words = s1.split()
...         self._total_per_partition = len(words)
...         counter = Counter(words)
...         yield from counter.items()
...
...     def end_partition(self):
...         yield ("partition_total", self._total_per_partition)
>>> udtf_name = "word_count_udtf"
>>> word_count_udtf = session.udtf.register(
...     MyWordCount, ["word", "count"], name=udtf_name, is_permanent=False, replace=True
... )
Copy

调用该函数。

>>> # Call it by its name
>>> df1 = session.table_function(udtf_name, lit("w1 w2 w2 w3 w3 w3"))
>>> df1.show()
-----------------------------
|"WORD"           |"COUNT"  |
-----------------------------
|w1               |1        |
|w2               |2        |
|w3               |3        |
|partition_total  |6        |
-----------------------------
Copy

使用 register_from_file 函数注册 UDTF

注册该函数。

>>> from snowflake.snowpark.types import IntegerType, StructField, StructType
>>> from snowflake.snowpark.functions import udtf, lit
>>> _ = session.sql("create or replace temp stage mystage").collect()
>>> _ = session.file.put("tests/resources/test_udtf_dir/test_udtf_file.py", "@mystage", auto_compress=False)
>>> generator_udtf = session.udtf.register_from_file(
...     file_path="@mystage/test_udtf_file.py",
...     handler_name="GeneratorUDTF",
...     output_schema=StructType([StructField("number", IntegerType())]),
...     input_types=[IntegerType()]
... )
Copy

调用该函数。

>>> session.table_function(generator_udtf(lit(3))).collect()
[Row(NUMBER=0), Row(NUMBER=1), Row(NUMBER=2)]
Copy
语言: 中文