Python 用户定义的聚合函数

用户定义的聚合函数 (UDAFs) 一行或多行作为输入并生成单行输出。它们可对多行数值进行数学计算,如求和、求平均值、计数、求最小值或最大值、标准偏差和估算,还可进行一些非数学运算。

Python UDAFs 为您提供了一种编写自己的聚合函数的方法,这些聚合函数类似于 Snowflake 系统定义的 SQL 聚合函数

限制

  • aggregate_state 在序列化版本中的最大大小为 8 MB,因此请尝试控制汇总状态的大小。

  • 不能将 UDAF 作为 窗口函数 调用(换言之,使用 OVER 子句)。

  • 目前不支持来自 Python UDAFs 的事件表日志记录。

  • 在聚合函数上,当使用 AGGREGATE 参数时,不支持 IMMUTABLE。因此,所有聚合函数默认都是 VOLATILE。

  • 当使用带有 DISTINCT() 关键字的 UDAF 时,必须使用 UDAF 的完全限定路径。

    例如,以下操作将失败:

    SELECT python_udaf(DISTINCT(x)) FROM my_table;
    
    Copy

    以下操作将成功:

    SELECT testdb.schema.python_udaf(DISTINCT(x)) FROM my_table;
    
    Copy

聚合函数处理程序界面

聚合函数汇总子节点中的状态,然后最终将这些汇总状态序列化并发送到父节点,在那里对它们进行合并以及计算最终结果。

要定义聚合函数,您必须定义一个 Python 类(即函数的处理程序),其中包含 Snowflake 在运行时调用的方法。下表介绍了这些方法。请参阅本主题的其他部分示例。

方法

要求

描述

__init__

必填

将汇总的内部状态初始化。

aggregate_state

必填

返回汇总的当前状态。

  • 此方法必须有一个 @property 装饰器 (https://docs.python.org/3.8/library/functions.html#property)。

  • 汇总状态对象可以是任何可使用 ` Python pickle 库 <https://docs.python.org/3/library/pickle.html#what-can-be-pickled-and-unpickled (https://docs.python.org/3/library/pickle.html#what-can-be-pickled-and-unpickled)>`_ 序列化的 Python 数据类型。

  • 对于简单的汇总状态,可使用原始 Python 数据类型。对于更复杂的汇总状态,请使用 ` Python 数据类 <https://docs.python.org/3/library/dataclasses.html (https://docs.python.org/3/library/dataclasses.html)>`_。

accumulate

必填

根据新的输入行累积汇总的状态。

merge

必填

结合两个中间汇总状态。

finish

必填

根据汇总状态生成最终结果。

该图显示了在子节点中累积的输入值,然后被发送到父节点并合并以产生最终结果。

示例:计算总和

以下示例中的代码定义了 python_sum 用户定义的聚合函数 (UDAF) 以返回数值的总和。

  1. 创建 UDAF。

    CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT)
    RETURNS INT
    LANGUAGE PYTHON
    RUNTIME_VERSION=3.8
    handler = 'PythonSum'
    AS $$
    class PythonSum:
      def __init__(self):
        # This aggregate state is a primitive Python data type.
        self._partial_sum = 0
    
      @property
      def aggregate_state(self):
        return self._partial_sum
    
      def accumulate(self, input_value):
        self._partial_sum += input_value
    
      def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum
    
      def finish(self):
        return self._partial_sum
    $$;
    
    Copy
  2. 创建测试数据表。

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    SELECT * FROM sales;
    
    Copy
  3. 调用 python_sum UDAF。

    SELECT python_sum(price) FROM sales;
    
    Copy
  4. 将结果与 Snowflake 系统定义的 SQL 函数 SUM 进行比较,并看到结果是相同的。

    SELECT sum(col) FROM sales;
    
    Copy
  5. 根据销售表中的商品类型对销售额进行分组。

    SELECT item, python_sum(price) FROM sales GROUP BY item;
    
    Copy

示例:计算平均值

以下示例中的代码定义了 python_avg 用户定义的聚合函数以返回数值的总和。

  1. 创建函数。

    CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT)
    RETURNS FLOAT
    LANGUAGE PYTHON
    RUNTIME_VERSION = 3.8
    HANDLER = 'PythonAvg'
    AS $$
    from dataclasses import dataclass
    
    @dataclass
    class AvgAggState:
        count: int
        sum: int
    
    class PythonAvg:
        def __init__(self):
            # This aggregate state is an object data type.
            self._agg_state = AvgAggState(0, 0)
    
        @property
        def aggregate_state(self):
            return self._agg_state
    
        def accumulate(self, input_value):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            self._agg_state.sum = sum + input_value
            self._agg_state.count = count + 1
    
        def merge(self, other_agg_state):
            sum = self._agg_state.sum
            count = self._agg_state.count
    
            other_sum = other_agg_state.sum
            other_count = other_agg_state.count
    
            self._agg_state.sum = sum + other_sum
            self._agg_state.count = count + other_count
    
        def finish(self):
            sum = self._agg_state.sum
            count = self._agg_state.count
            return sum / count
    $$;
    
    Copy
  2. 创建测试数据表。

    CREATE OR REPLACE TABLE sales(item STRING, price INT);
    INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
    
    Copy
  3. 调用 python_avg 用户定义的函数。

    SELECT python_avg(price) FROM sales;
    
    Copy
  4. 将结果与 Snowflake 系统定义的 SQL 函数 AVG 进行比较,并看到结果是相同的。

    SELECT avg(price) FROM sales;
    
    Copy
  5. 在销售表中按商品类型对平均值进行分组。

    SELECT item, python_avg(price) FROM sales GROUP BY item;
    
    Copy

示例:仅返回唯一值

以下示例中的代码采用一个数组,并返回一个仅包含唯一值的数组。

CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        self._agg_state.update(input)

    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)

    def finish(self):
        return list(self._agg_state)
$$;


CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');

SELECT * FROM array_table;

SELECT pythonGetUniqueValues(x) FROM array_table;
Copy

示例:返回字符串计数

以下示例中的代码返回一个对象中所有字符串实例的计数。

CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
RETURNS OBJECT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict

class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1

    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count

    def finish(self):
        return dict(self._agg_state)
$$;

CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));

SELECT pythonMapCount(x) FROM string_table;
Copy

示例:返回前 k 个最大值

以下示例中的代码返回 k 的前几个最大值的列表。此代码在最小堆上累积否定的输入值,然后返回前几个 k 最大值。

CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List

@dataclass
class AggState:
    minheap: List[int]
    k: int

class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    @staticmethod
    def get_top_k_items(minheap, k):
      # Return k smallest elements if there are more than k elements on the min heap.
      if (len(minheap) > k):
        return [heapq.heappop(minheap) for i in range(k)]
      return minheap

    def accumulate(self, input, k):
        self._agg_state.k = k

        # Store the input as negative value, as heapq is a min heap.
        heapq.heappush(self._agg_state.minheap, -input)

        # Store only top k items on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k

        # Merge two min heaps by popping off elements from one and pushing them onto another.
        while(len(other_agg_state.minheap) > 0):
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))

        # Store only k elements on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)

    def finish(self):
        return [-x for x in self._agg_state.minheap]
$$;


CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));

-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;
Copy
语言: 中文