Python user-defined aggregate functions¶
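User-defined aggregate functions (UDAFs) take one or more rows as input and produce a single row of output. They can perform mathematical calculations over values from multiple rows, such as sum, average, count, minimum or maximum, standard deviation, and estimation, and they can also perform some non-mathematical operations.
Python UDAFs give you a way to write your own aggregate functions that work like the system-defined SQL aggregate functions in Snowflake.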
Limitations¶
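The serialized version of aggregate_state has a maximum size of 8 MB, so try to keep the size of the aggregate state under control.
You can't call a UDAF as a window function (in other words, with an OVER clause).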
Logging to an event table from Python UDAFs isn't currently supported.
IMMUTABLE isn't supported on aggregate functions (that is, when you use the AGGREGATE parameter). As a result, all aggregate functions are VOLATILE by default.
When you use a UDAF with the DISTINCT keyword, you must call the UDAF by its fully qualified path.
For example, the following call fails:
SELECT python_udaf(DISTINCT(x)) FROM my_table;
The following call succeeds:
SELECT testdb.schema.python_udaf(DISTINCT(x)) FROM my_table;
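The 8 MB limit applies to the serialized aggregate state, so it can help to measure a candidate state locally before you deploy the function. The following is a minimal sketch that assumes the state is serialized with Python's pickle module and treats 8 MB as 8 × 2^20 bytes. Both are assumptions, because this topic doesn't specify the serialization format. The helper state_size_ok is hypothetical and isn't a Snowflake API.

```python
import pickle

# 8 MB limit on the serialized aggregate state (interpreted here as 8 * 2**20 bytes).
MAX_STATE_BYTES = 8 * 1024 * 1024

def state_size_ok(state) -> bool:
    """Hypothetical helper: True if the pickled state fits under the limit.

    Assumes pickle-based serialization; this is not a Snowflake API.
    """
    return len(pickle.dumps(state)) <= MAX_STATE_BYTES

# A small counter dictionary is far below the limit.
small_state = {"foo": 1000, "bar": 2000}
# A set of one million distinct strings serializes to well over 8 MB.
large_state = {f"value-{i}" for i in range(1_000_000)}

print(state_size_ok(small_state))  # True
print(state_size_ok(large_state))  # False
```

States that grow with the number of distinct inputs, such as sets or per-key dictionaries, are the ones most likely to approach the limit.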
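Aggregate function handler interface¶
An aggregate function aggregates state in child nodes. Those aggregate states are then serialized and sent to the parent node, where they are merged and the final result is computed.
To define an aggregate function, you must define a Python class (the function's handler) that contains the methods Snowflake calls at runtime. The following table describes these methods. For working handlers, see the examples in the rest of this topic.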
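Method | Requirement | Description |
---|---|---|
__init__ | Required | Initializes the internal state of the aggregate. |
aggregate_state | Required | Returns the current state of the aggregate. The examples in this topic define it as a @property. |
accumulate | Required | Accumulates the state of the aggregate based on a new input row. |
merge | Required | Combines two intermediate aggregate states. |
finish | Required | Produces the final result based on the aggregate state. |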
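You can exercise the lifecycle described above locally with a small driver: accumulate on each child node, serialize the state, merge at the parent, then call finish. This is a sketch under stated assumptions. The simulate_udaf helper is hypothetical, pickle stands in for whatever serialization Snowflake actually uses, and a real query plan may split rows across nodes differently.

```python
import pickle

class PythonSum:
    """Minimal handler implementing the five required members."""

    def __init__(self):
        # Internal state: a running total.
        self._partial_sum = 0

    @property
    def aggregate_state(self):
        return self._partial_sum

    def accumulate(self, input_value):
        self._partial_sum += input_value

    def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum

    def finish(self):
        return self._partial_sum


def simulate_udaf(handler_cls, partitions):
    """Hypothetical local driver that mimics child nodes and one parent node."""
    serialized_states = []
    for rows in partitions:
        child = handler_cls()  # one handler instance per child node
        for row in rows:
            child.accumulate(row)
        # Child states are serialized before being sent to the parent
        # (pickle is an assumption made for this sketch).
        serialized_states.append(pickle.dumps(child.aggregate_state))

    parent = handler_cls()
    for blob in serialized_states:
        parent.merge(pickle.loads(blob))
    return parent.finish()


print(simulate_udaf(PythonSum, [[1, 2, 3], [4, 5], []]))  # 15
```

The empty third partition shows that a child node that receives no rows simply contributes its initial state (0) to the merge.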
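Example: Calculate a sum¶
The code in the following example defines a python_sum user-defined aggregate function (UDAF) that returns the sum of numeric values.
Create the UDAF.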
CREATE OR REPLACE AGGREGATE FUNCTION PYTHON_SUM(a INT)
RETURNS INT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonSum'
AS $$
class PythonSum:
    def __init__(self):
        # This aggregate state is a primitive Python data type.
        self._partial_sum = 0
    @property
    def aggregate_state(self):
        return self._partial_sum
    def accumulate(self, input_value):
        self._partial_sum += input_value
    def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum
    def finish(self):
        return self._partial_sum
$$;
Create a table of test data.
CREATE OR REPLACE TABLE sales(item STRING, price INT);
INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
SELECT * FROM sales;
Call the python_sum UDAF.
SELECT python_sum(price) FROM sales;
Compare the result with the output of the Snowflake system-defined SQL function SUM. The two results are the same.
SELECT sum(price) FROM sales;
Group the sales totals by item type in the sales table.
SELECT item, python_sum(price) FROM sales GROUP BY item;
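To check the expected output of the queries above without a Snowflake session, you can drive the handler class directly in Python. This sketch assumes one handler instance per group for GROUP BY. It reuses PythonSum unchanged and doesn't exercise merge.

```python
from collections import defaultdict

class PythonSum:
    def __init__(self):
        self._partial_sum = 0
    @property
    def aggregate_state(self):
        return self._partial_sum
    def accumulate(self, input_value):
        self._partial_sum += input_value
    def merge(self, other_partial_sum):
        self._partial_sum += other_partial_sum
    def finish(self):
        return self._partial_sum

# Same rows as the sales test table above.
sales = [("car", 10000), ("motorcycle", 5000), ("car", 7500),
         ("motorcycle", 3500), ("motorcycle", 1500), ("car", 20000)]

# Ungrouped call: should match SUM(price).
total = PythonSum()
for _, price in sales:
    total.accumulate(price)
print(total.finish())  # 47500

# GROUP BY item: one handler instance per group (an assumption of this sketch).
groups = defaultdict(PythonSum)
for item, price in sales:
    groups[item].accumulate(price)
by_item = {item: handler.finish() for item, handler in groups.items()}
print(by_item)  # {'car': 37500, 'motorcycle': 10000}
```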
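Example: Calculate an average¶
The code in the following example defines a python_avg user-defined aggregate function that returns the average of numeric values.
Create the function.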
CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT)
RETURNS FLOAT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonAvg'
AS $$
from dataclasses import dataclass
@dataclass
class AvgAggState:
    count: int
    sum: int
class PythonAvg:
    def __init__(self):
        # This aggregate state is an object data type.
        self._agg_state = AvgAggState(0, 0)
    @property
    def aggregate_state(self):
        return self._agg_state
    def accumulate(self, input_value):
        sum = self._agg_state.sum
        count = self._agg_state.count
        self._agg_state.sum = sum + input_value
        self._agg_state.count = count + 1
    def merge(self, other_agg_state):
        sum = self._agg_state.sum
        count = self._agg_state.count
        other_sum = other_agg_state.sum
        other_count = other_agg_state.count
        self._agg_state.sum = sum + other_sum
        self._agg_state.count = count + other_count
    def finish(self):
        sum = self._agg_state.sum
        count = self._agg_state.count
        return sum / count
$$;
Create a table of test data.
CREATE OR REPLACE TABLE sales(item STRING, price INT);
INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
Call the python_avg user-defined function.
SELECT python_avg(price) FROM sales;
Compare the result with the output of the Snowflake system-defined SQL function AVG. The two results are the same.
SELECT avg(price) FROM sales;
Group the average price by item type in the sales table.
SELECT item, python_avg(price) FROM sales GROUP BY item;
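PythonAvg stores both a count and a sum, rather than a running average, and the reason becomes clear when partial states are merged. The sketch below splits the sales prices across two hypothetical child nodes (the split is arbitrary). It then compares merging (count, sum) pairs against naively averaging the two partial averages.

```python
from dataclasses import dataclass

@dataclass
class AvgAggState:
    count: int
    sum: int

def accumulate_partition(values):
    """Build the (count, sum) state that one child node would hold."""
    state = AvgAggState(0, 0)
    for v in values:
        state.sum += v
        state.count += 1
    return state

# Prices from the sales table, split unevenly across two hypothetical nodes.
part_a = [10000, 5000, 7500, 3500]
part_b = [1500, 20000]
a, b = accumulate_partition(part_a), accumulate_partition(part_b)

# Correct: merge counts and sums, then divide once (as PythonAvg does).
merged = AvgAggState(a.count + b.count, a.sum + b.sum)
correct = merged.sum / merged.count

# Incorrect: averaging partial averages ignores how many rows each node saw.
naive = (a.sum / a.count + b.sum / b.count) / 2

print(round(correct, 2))  # 7916.67
print(naive)              # 8625.0
```

Only the first approach matches AVG, because each partition's contribution has to be weighted by its row count.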
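Example: Return only unique values¶
The code in the following example takes arrays as input and returns a single array that contains only the unique values found across all of them.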
CREATE OR REPLACE AGGREGATE FUNCTION pythonGetUniqueValues(input ARRAY)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonGetUniqueValues'
AS $$
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()
    @property
    def aggregate_state(self):
        return self._agg_state
    def accumulate(self, input):
        self._agg_state.update(input)
    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)
    def finish(self):
        return list(self._agg_state)
$$;
CREATE OR REPLACE TABLE array_table(x array) AS
SELECT ARRAY_CONSTRUCT(0, 1, 2, 3, 4, 'foo', 'bar', 'snowflake') UNION ALL
SELECT ARRAY_CONSTRUCT(1, 3, 5, 7, 9, 'foo', 'barbar', 'snowpark') UNION ALL
SELECT ARRAY_CONSTRUCT(0, 2, 4, 6, 8, 'snow');
SELECT * FROM array_table;
SELECT pythonGetUniqueValues(x) FROM array_table;
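The aggregate state is a Python set, so the order of the returned array isn't guaranteed, and a local check should compare contents rather than positions. The sketch below runs the handler over the three test arrays, which are split across two hypothetical child nodes.

```python
class PythonGetUniqueValues:
    def __init__(self):
        self._agg_state = set()
    @property
    def aggregate_state(self):
        return self._agg_state
    def accumulate(self, input):
        self._agg_state.update(input)
    def merge(self, other_agg_state):
        self._agg_state.update(other_agg_state)
    def finish(self):
        return list(self._agg_state)

# Same arrays as array_table above.
rows = [
    [0, 1, 2, 3, 4, "foo", "bar", "snowflake"],
    [1, 3, 5, 7, 9, "foo", "barbar", "snowpark"],
    [0, 2, 4, 6, 8, "snow"],
]

# Hypothetical split: the first node sees two rows, the second node sees one.
left, right = PythonGetUniqueValues(), PythonGetUniqueValues()
for row in rows[:2]:
    left.accumulate(row)
right.accumulate(rows[2])
left.merge(right.aggregate_state)

result = left.finish()
print(len(result))  # 16 (the integers 0-9 plus six distinct strings)
```

Sorting this result in Python would raise a TypeError, because Python 3 can't order a mix of int and str values. This is another reason to compare the output as a set.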
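Example: Return a count of strings¶
The code in the following example returns an object that maps each string, lowercased, to the number of times it appears in the input.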
CREATE OR REPLACE AGGREGATE FUNCTION pythonMapCount(input STRING)
RETURNS OBJECT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonMapCount'
AS $$
from collections import defaultdict
class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)
    @property
    def aggregate_state(self):
        return self._agg_state
    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1
    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count
    def finish(self):
        return dict(self._agg_state)
$$;
CREATE OR REPLACE TABLE string_table(x STRING);
INSERT INTO string_table SELECT 'foo' FROM TABLE(GENERATOR(ROWCOUNT => 1000));
INSERT INTO string_table SELECT 'bar' FROM TABLE(GENERATOR(ROWCOUNT => 2000));
INSERT INTO string_table SELECT 'snowflake' FROM TABLE(GENERATOR(ROWCOUNT => 50));
INSERT INTO string_table SELECT 'snowpark' FROM TABLE(GENERATOR(ROWCOUNT => 123));
INSERT INTO string_table SELECT 'SnOw' FROM TABLE(GENERATOR(ROWCOUNT => 1));
INSERT INTO string_table SELECT 'snow' FROM TABLE(GENERATOR(ROWCOUNT => 4));
SELECT pythonMapCount(x) FROM string_table;
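The handler lowercases each input before counting it, so 'SnOw' and 'snow' share one key. The following local sketch replays the inserted row counts across two hypothetical child nodes and checks the merged object. The assignment of strings to nodes is arbitrary.

```python
from collections import defaultdict

class PythonMapCount:
    def __init__(self):
        self._agg_state = defaultdict(int)
    @property
    def aggregate_state(self):
        return self._agg_state
    def accumulate(self, input):
        # Increment count of lowercase input
        self._agg_state[input.lower()] += 1
    def merge(self, other_agg_state):
        for item, count in other_agg_state.items():
            self._agg_state[item] += count
    def finish(self):
        return dict(self._agg_state)

# (value, row count) pairs matching the INSERT statements above.
inserts = [("foo", 1000), ("bar", 2000), ("snowflake", 50),
           ("snowpark", 123), ("SnOw", 1), ("snow", 4)]

# Hypothetical split: even-indexed inserts go to node_a, odd-indexed to node_b.
node_a, node_b = PythonMapCount(), PythonMapCount()
for i, (value, n) in enumerate(inserts):
    node = node_a if i % 2 == 0 else node_b
    for _ in range(n):
        node.accumulate(value)
node_a.merge(node_b.aggregate_state)

result = node_a.finish()
print(result["snow"])  # 5: one 'SnOw' plus four 'snow', merged from two nodes
```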
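Example: Return the top k largest values¶
The code in the following example returns a list of the k largest input values. It accumulates the negated input values on a min heap and then returns the top k largest values.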
CREATE OR REPLACE AGGREGATE FUNCTION pythonTopK(input INT, k INT)
RETURNS ARRAY
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonTopK'
AS $$
import heapq
from dataclasses import dataclass
import itertools
from typing import List
@dataclass
class AggState:
    minheap: List[int]
    k: int
class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)
    @property
    def aggregate_state(self):
        return self._agg_state
    @staticmethod
    def get_top_k_items(minheap, k):
        # Return k smallest elements if there are more than k elements on the min heap.
        if (len(minheap) > k):
            return [heapq.heappop(minheap) for i in range(k)]
        return minheap
    def accumulate(self, input, k):
        self._agg_state.k = k
        # Store the input as negative value, as heapq is a min heap.
        heapq.heappush(self._agg_state.minheap, -input)
        # Store only top k items on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
        # Merge two min heaps by popping off elements from one and pushing them onto another.
        while(len(other_agg_state.minheap) > 0):
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))
        # Store only k elements on the min heap.
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
    def finish(self):
        return [-x for x in self._agg_state.minheap]
$$;
CREATE OR REPLACE TABLE numbers_table(num_column INT);
INSERT INTO numbers_table SELECT 5 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 1 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 9 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 7 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 10 FROM TABLE(GENERATOR(ROWCOUNT => 10));
INSERT INTO numbers_table SELECT 3 FROM TABLE(GENERATOR(ROWCOUNT => 10));
-- Return top 15 largest values from numbers_table.
SELECT pythonTopK(num_column, 15) FROM numbers_table;
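The following sketch drives PythonTopK locally with the same data as numbers_table, split across two hypothetical child nodes. It compares the merged result with heapq.nlargest from the Python standard library. finish returns values in heap order rather than sorted order, so the sketch sorts the result before comparing.

```python
import heapq
from dataclasses import dataclass
from typing import List

@dataclass
class AggState:
    minheap: List[int]
    k: int

class PythonTopK:
    def __init__(self):
        self._agg_state = AggState([], 0)
    @property
    def aggregate_state(self):
        return self._agg_state
    @staticmethod
    def get_top_k_items(minheap, k):
        # Pop the k smallest (negated) values when the heap holds more than k.
        if len(minheap) > k:
            return [heapq.heappop(minheap) for _ in range(k)]
        return minheap
    def accumulate(self, input, k):
        self._agg_state.k = k
        # Negate so the min heap keeps the largest original values.
        heapq.heappush(self._agg_state.minheap, -input)
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
    def merge(self, other_agg_state):
        k = self._agg_state.k if self._agg_state.k > 0 else other_agg_state.k
        while other_agg_state.minheap:
            heapq.heappush(self._agg_state.minheap, heapq.heappop(other_agg_state.minheap))
        self._agg_state.minheap = self.get_top_k_items(self._agg_state.minheap, k)
    def finish(self):
        return [-x for x in self._agg_state.minheap]

# Same data as numbers_table: ten copies each of 5, 1, 9, 7, 10, and 3.
values = [v for v in (5, 1, 9, 7, 10, 3) for _ in range(10)]

# Hypothetical split across two child nodes.
node_a, node_b = PythonTopK(), PythonTopK()
for v in values[:30]:
    node_a.accumulate(v, 15)
for v in values[30:]:
    node_b.accumulate(v, 15)
node_a.merge(node_b.aggregate_state)

result = sorted(node_a.finish(), reverse=True)
print(result == heapq.nlargest(15, values))  # True
```

With this data, the expected answer is ten 10s followed by five 9s.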