教程:测试 Python Snowpark

简介

本教程介绍测试 Snowpark Python 代码的基础知识。

您将学习的内容

在本教程中,您将学习如何进行以下操作:

  • 在连接到 Snowflake 时测试您的 Snowpark 代码。

    您可以使用标准测试工具(例如 PyTest)来测试您的 Snowpark Python UDFs、DataFrame 转换和存储过程。

  • 使用本地测试框架在本地测试您的 Snowpark Python DataFrames,而无需连接到 Snowflake 账户。

    在部署代码更改之前,您可以使用本地测试框架在您的开发机器上进行本地测试。

先决条件

要使用本地测试框架,请执行以下操作:

  • 您必须使用版本 1.11.1 或更高版本的 Snowpark Python 库。

  • 受支持的 Python 版本包括:

    • 3.8

    • 3.9

    • 3.10

    • 3.11

设置项目

在本部分中,您将克隆项目存储库并设置教程所需的环境。

  1. 克隆项目存储库。

    git clone https://github.com/Snowflake-Labs/sftutorial-snowpark-testing
    
    Copy

    如果没有安装 Git,请访问版本存储库页面,点击 Code » Download Contents 下载内容。

  2. 使用账户凭据设置环境变量。Snowpark API 将使用这些变量对您的 Snowflake 账户进行身份验证。

    # Linux/MacOS
    export SNOWSQL_ACCOUNT=<replace with your account identifer>
    export SNOWSQL_USER=<replace with your username>
    export SNOWSQL_ROLE=<replace with your role>
    export SNOWSQL_PWD=<replace with your password>
    export SNOWSQL_DATABASE=<replace with your database>
    export SNOWSQL_SCHEMA=<replace with your schema>
    export SNOWSQL_WAREHOUSE=<replace with your warehouse>
    
    Copy
    # Windows/PowerShell
    $env:SNOWSQL_ACCOUNT = "<replace with your account identifer>"
    $env:SNOWSQL_USER = "<replace with your username>"
    $env:SNOWSQL_ROLE = "<replace with your role>"
    $env:SNOWSQL_PWD = "<replace with your password>"
    $env:SNOWSQL_DATABASE = "<replace with your database>"
    $env:SNOWSQL_SCHEMA = "<replace with your schema>"
    $env:SNOWSQL_WAREHOUSE = "<replace with your warehouse>"
    
    Copy

    可选项:您可以通过编辑 bash 配置文件 (Linux/MacOS) 或使用 System Properties 菜单 (Windows) 永久设置该环境变量。

  3. 使用 Anaconda 创建并激活 Conda 环境:

    conda env create --file environment.yml
    conda activate snowpark-testing
    
    Copy
  4. 运行 setup/create_table.py,在账户中创建示例表。这个 Python 脚本将创建一个名为 CITIBIKE 的数据库、一个名为 PUBLIC 的架构和一个名为 TRIPS 的小型表。

    python setup/create_table.py
    
    Copy

现在您已准备好进入下一部分。在本部分中,您执行了以下操作:

  • 克隆了教程存储库。

  • 创建了包含账户信息的环境变量。

  • 为项目创建了 Conda 环境。

  • 使用 Snowpark API 连接到 Snowflake,并创建了示例数据库、架构和表。

尝试存储过程

示例项目包括一个存储过程处理程序 (sproc.py) 和三个DataFrames 转换器方法 (transformers.py)。存储过程处理程序使用 UDF 和 DataFrame 转换器读取来源表 CITIBIKE.PUBLIC.TRIPS,并创建两个事实表:MONTH_FACTSBIKE_FACTS

您可以通过运行此命令从命令行执行存储过程。

python project/sproc.py
Copy

现在您已经熟悉了项目,在下一部分中,您将设置测试目录,并为 Snowflake 会话创建 PyTest 夹具。

为 Snowflake 会话创建 PyTest 夹具

PyTest 夹具 (https://docs.pytest.org/en/6.2.x/fixture.html) 是在测试(或测试的模块)之前执行的函数,通常用于提供数据或连接测试。在本项目中,您将创建一个PyTest 夹具,它将返回 Snowpark Session 对象。您的测试用例将使用此会话连接到 Snowflake。

  1. 在项目根目录下创建 test 目录。

    mkdir test
    
    Copy
  2. test 目录下,新建一个名为 conftest.py 的 Python 文件。在 conftest.py 中,为 Session 对象创建PyTest 夹具:

    import pytest
    from project.utils import get_env_var_config
    from snowflake.snowpark.session import Session
    
    @pytest.fixture
    def session() -> Session:
        return Session.builder.configs(get_env_var_config()).create()
    
    Copy

为 DataFrame 转换器添加单元测试

  1. test 目录中,新建一个名为 test_transformers.py 的 Python 文件。

  2. test_transformers.py 文件中,导入转换器方法。

    # test/test_transformers.py
    
    from project.transformers import add_rider_age, calc_bike_facts, calc_month_facts
    
    Copy
  3. 接下来,为这些转换器创建单元测试。通常的惯例是为每个测试创建一个方法,名称为 test_<name of method>。在我们的案例中,测试内容如下:

    # test/test_transformers.py
    from project.transformers import add_rider_age, calc_bike_facts, calc_month_facts
    def test_add_rider_age(session):
        ...
    
    def test_calc_bike_facts(session):
        ...
    
    
    def test_calc_month_facts(session):
        ...
    
    Copy

    每个测试用例中的 session 参数指的是您在上一部分中创建的 PyTest 夹具。

  4. 现在为每个转换器执行测试用例。使用以下模式。

    1. 创建输入 DataFrame。

    2. 创建预期输出 DataFrame。

    3. 将步骤 1 的输入 DataFrame 传递给转换器方法。

    4. 将步骤 3 的输出与步骤 2 的预计输出进行比较。

    # test/test_transformers.py
    from project.transformers import add_rider_age, calc_bike_facts, calc_month_facts
    from snowflake.snowpark.types import StructType, StructField, IntegerType, FloatType
    
    def test_add_rider_age(session: Session):
        input = session.create_dataframe(
            [
                [1980],
                [1995],
                [2000]
            ],
            schema=StructType([StructField("BIRTH_YEAR", IntegerType())])
        )
    
        expected = session.create_dataframe(
            [
                [1980, 43],
                [1995, 28],
                [2000, 23]
            ],
            schema=StructType([StructField("BIRTH_YEAR", IntegerType()), StructField("RIDER_AGE", IntegerType())])
        )
    
        actual = add_rider_age(input)
        assert expected.collect() == actual.collect()
    
    
    def test_calc_bike_facts(session: Session):
        input = session.create_dataframe([
                [1, 10, 20],
                [1, 5, 30],
                [2, 20, 50],
                [2, 10, 60]
            ],
            schema=StructType([
                StructField("BIKEID", IntegerType()),
                StructField("TRIPDURATION", IntegerType()),
                StructField("RIDER_AGE", IntegerType())
            ])
        )
    
        expected = session.create_dataframe([
                [1, 2, 7.5, 25.0],
                [2, 2, 15.0, 55.0],
            ],
            schema=StructType([
                StructField("BIKEID", IntegerType()),
                StructField("COUNT", IntegerType()),
                StructField("AVG_TRIPDURATION", FloatType()),
                StructField("AVG_RIDER_AGE", FloatType())
            ])
        )
    
        actual = calc_bike_facts(input)
        assert expected.collect() == actual.collect()
    
    
    def test_calc_month_facts(session: Session):
        from patches import patch_to_timestamp
    
        input = session.create_dataframe(
            data=[
                ['2018-03-01 09:47:00.000 +0000', 1, 10,  15],
                ['2018-03-01 09:47:14.000 +0000', 2, 20, 12],
                ['2018-04-01 09:47:04.000 +0000', 3, 6,  30]
            ],
            schema=['STARTTIME', 'BIKE_ID', 'TRIPDURATION', 'RIDER_AGE']
        )
    
        expected = session.create_dataframe(
            data=[
                ['Mar', 2, 15, 13.5],
                ['Apr', 1, 6, 30.0]
            ],
            schema=['MONTH', 'COUNT', 'AVG_TRIPDURATION', 'AVG_RIDER_AGE']
        )
    
        actual = calc_month_facts(input)
    
        assert expected.collect() == actual.collect()
    
    Copy
  5. 现在可以运行 PyTest 来运行所有单元测试。

    pytest test/test_transformers.py
    
    Copy

为存储过程添加集成测试

我们已经对 DataFrame 转换器方法进行了单元测试,现在让我们为存储过程添加一个集成测试。测试用例将遵循该模式:

  1. 创建一个表,显示存储过程的输入数据。

  2. 创建两个 DataFrames,其中包含存储过程的两个输出表的预期内容。

  3. 调用存储过程。

  4. 将实际输出表与步骤 2 中的DataFrames 进行比较。

  5. 清理:删除步骤 1 的输入表和步骤 3 的输出表。

test 目录下创建名为 test_sproc.py 的 Python 文件。

从项目目录中导入存储过程手册,并创建测试用例。

# test/test_sproc.py
from project.sproc import create_fact_tables

def test_create_fact_tables(session):
    ...
Copy

实施测试用例,从创建输入表开始。

# test/test_sproc.py
from project.sproc import create_fact_tables
from snowflake.snowpark.types import *

def test_create_fact_tables(session):
    DB = 'CITIBIKE'
    SCHEMA = 'TEST'

    # Set up source table
    tbl = session.create_dataframe(
        data=[
            [1983, '2018-03-01 09:47:00.000 +0000', 551, 30958],
            [1988, '2018-03-01 09:47:01.000 +0000', 242, 19278],
            [1992, '2018-03-01 09:47:01.000 +0000', 768, 18461],
            [1980, '2018-03-01 09:47:03.000 +0000', 690, 15533],
            [1991, '2018-03-01 09:47:03.000 +0000', 490, 32449],
            [1959, '2018-03-01 09:47:04.000 +0000', 457, 29411],
            [1971, '2018-03-01 09:47:08.000 +0000', 279, 28015],
            [1964, '2018-03-01 09:47:09.000 +0000', 546, 15148],
            [1983, '2018-03-01 09:47:11.000 +0000', 358, 16967],
            [1985, '2018-03-01 09:47:12.000 +0000', 848, 20644],
            [1984, '2018-03-01 09:47:14.000 +0000', 295, 16365]
        ],
        schema=['BIRTH_YEAR', 'STARTTIME', 'TRIPDURATION',    'BIKEID'],
    )

    tbl.write.mode('overwrite').save_as_table([DB, SCHEMA, 'TRIPS_TEST'], mode='overwrite')
Copy

接下来,为预期输出表创建 DataFrames。

# test/test_sproc.py
from project.sproc import create_fact_tables
from snowflake.snowpark.types import *

def test_create_fact_tables(session):
    DB = 'CITIBIKE'
    SCHEMA = 'TEST'

    # Set up source table
    tbl = session.create_dataframe(
        data=[
            [1983, '2018-03-01 09:47:00.000 +0000', 551, 30958],
            [1988, '2018-03-01 09:47:01.000 +0000', 242, 19278],
            [1992, '2018-03-01 09:47:01.000 +0000', 768, 18461],
            [1980, '2018-03-01 09:47:03.000 +0000', 690, 15533],
            [1991, '2018-03-01 09:47:03.000 +0000', 490, 32449],
            [1959, '2018-03-01 09:47:04.000 +0000', 457, 29411],
            [1971, '2018-03-01 09:47:08.000 +0000', 279, 28015],
            [1964, '2018-03-01 09:47:09.000 +0000', 546, 15148],
            [1983, '2018-03-01 09:47:11.000 +0000', 358, 16967],
            [1985, '2018-03-01 09:47:12.000 +0000', 848, 20644],
            [1984, '2018-03-01 09:47:14.000 +0000', 295, 16365]
        ],
        schema=['BIRTH_YEAR', 'STARTTIME', 'TRIPDURATION',    'BIKEID'],
    )

    tbl.write.mode('overwrite').save_as_table([DB, SCHEMA, 'TRIPS_TEST'], mode='overwrite')

    # Expected values
    n_rows_expected = 12
    bike_facts_expected = session.create_dataframe(
        data=[
            [30958, 1, 551.0, 40.0],
            [19278, 1, 242.0, 35.0],
            [18461, 1, 768.0, 31.0],
            [15533, 1, 690.0, 43.0],
            [32449, 1, 490.0, 32.0],
            [29411, 1, 457.0, 64.0],
            [28015, 1, 279.0, 52.0],
            [15148, 1, 546.0, 59.0],
            [16967, 1, 358.0, 40.0],
            [20644, 1, 848.0, 38.0],
            [16365, 1, 295.0, 39.0]
        ],
        schema=StructType([
            StructField("BIKEID", IntegerType()),
            StructField("COUNT", IntegerType()),
            StructField("AVG_TRIPDURATION", FloatType()),
            StructField("AVG_RIDER_AGE", FloatType())
        ])
    ).collect()

    month_facts_expected = session.create_dataframe(
        data=[['Mar', 11, 502.18182, 43.00000]],
        schema=StructType([
            StructField("MONTH", StringType()),
            StructField("COUNT", IntegerType()),
            StructField("AVG_TRIPDURATION", DecimalType()),
            StructField("AVG_RIDER_AGE", DecimalType())
        ])
    ).collect()
Copy

最后,调用存储过程并读取输出表。将实际表格与 DataFrame 内容进行比较。

# test/test_sproc.py
from project.sproc import create_fact_tables
from snowflake.snowpark.types import *

def test_create_fact_tables(session):
    DB = 'CITIBIKE'
    SCHEMA = 'TEST'

    # Set up source table
    tbl = session.create_dataframe(
        data=[
            [1983, '2018-03-01 09:47:00.000 +0000', 551, 30958],
            [1988, '2018-03-01 09:47:01.000 +0000', 242, 19278],
            [1992, '2018-03-01 09:47:01.000 +0000', 768, 18461],
            [1980, '2018-03-01 09:47:03.000 +0000', 690, 15533],
            [1991, '2018-03-01 09:47:03.000 +0000', 490, 32449],
            [1959, '2018-03-01 09:47:04.000 +0000', 457, 29411],
            [1971, '2018-03-01 09:47:08.000 +0000', 279, 28015],
            [1964, '2018-03-01 09:47:09.000 +0000', 546, 15148],
            [1983, '2018-03-01 09:47:11.000 +0000', 358, 16967],
            [1985, '2018-03-01 09:47:12.000 +0000', 848, 20644],
            [1984, '2018-03-01 09:47:14.000 +0000', 295, 16365]
        ],
        schema=['BIRTH_YEAR', 'STARTTIME', 'TRIPDURATION',    'BIKEID'],
    )

    tbl.write.mode('overwrite').save_as_table([DB, SCHEMA, 'TRIPS_TEST'], mode='overwrite')

    # Expected values
    n_rows_expected = 12
    bike_facts_expected = session.create_dataframe(
        data=[
            [30958, 1, 551.0, 40.0],
            [19278, 1, 242.0, 35.0],
            [18461, 1, 768.0, 31.0],
            [15533, 1, 690.0, 43.0],
            [32449, 1, 490.0, 32.0],
            [29411, 1, 457.0, 64.0],
            [28015, 1, 279.0, 52.0],
            [15148, 1, 546.0, 59.0],
            [16967, 1, 358.0, 40.0],
            [20644, 1, 848.0, 38.0],
            [16365, 1, 295.0, 39.0]
        ],
        schema=StructType([
            StructField("BIKEID", IntegerType()),
            StructField("COUNT", IntegerType()),
            StructField("AVG_TRIPDURATION", FloatType()),
            StructField("AVG_RIDER_AGE", FloatType())
        ])
    ).collect()

    month_facts_expected = session.create_dataframe(
        data=[['Mar', 11, 502.18182, 43.00000]],
        schema=StructType([
            StructField("MONTH", StringType()),
            StructField("COUNT", IntegerType()),
            StructField("AVG_TRIPDURATION", DecimalType()),
            StructField("AVG_RIDER_AGE", DecimalType())
        ])
    ).collect()

    # Call sproc, get actual values
    n_rows_actual = create_fact_tables(session, 'TRIPS_TEST')
    bike_facts_actual = session.table([DB, SCHEMA, 'bike_facts']).collect()
    month_facts_actual = session.table([DB, SCHEMA, 'month_facts']).collect()

    # Comparisons
    assert n_rows_expected == n_rows_actual
    assert bike_facts_expected == bike_facts_actual
    assert month_facts_expected ==  month_facts_actual
Copy

要运行测试用例,请从终端运行 pytest

pytest test/test_sproc.py
Copy

要运行项目中的所有测试,请运行 pytest,无需其他选项。

pytest
Copy

配置本地测试

至此,您就拥有了 DataFrame 转换器和存储过程的PyTest 测试套件。在每个测试用例中,Session 夹具用于连接到您的 Snowflake 账户,从 Snowpark Python API 发送SQL,以及检索响应。

或者,您也可以使用本地测试框架在本地运行转换,无需连接到 Snowflake。在大型测试套件中,这可以大大加快测试执行速度。本部分介绍如何更新测试套件以使用本地测试框架功能。

  1. 首先更新 PyTest Session 夹具。我们将为 PyTest 添加一个命令行选项,以便在本地和实时测试模式之间切换。

    # test/conftest.py
    
    import pytest
    from project.utils import get_env_var_config
    from snowflake.snowpark.session import Session
    
    def pytest_addoption(parser):
        parser.addoption("--snowflake-session", action="store", default="live")
    
    @pytest.fixture(scope='module')
    def session(request) -> Session:
        if request.config.getoption('--snowflake-session') == 'local':
            return Session.builder.configs({'local_testing': True}).create()
        else:
            return Session.builder.configs(get_env_var_config()).create()
    
    Copy
  2. 我们必须首先对该方法进行修补,因为并非所有内置函数都支持本地测试框架,例如 calc_month_facts() 转换器中使用的 monthname() 函数。在测试目录下创建名为 patches.py 的文件。在该文件中,粘贴以下代码。

    from snowflake.snowpark.mock.functions import patch
    from snowflake.snowpark.functions import monthname
    from snowflake.snowpark.mock.snowflake_data_type import ColumnEmulator, ColumnType
    from snowflake.snowpark.types import StringType
    import datetime
    import calendar
    
    @patch(monthname)
    def patch_monthname(column: ColumnEmulator) -> ColumnEmulator:
        ret_column = ColumnEmulator(data=[
            calendar.month_abbr[datetime.datetime.strptime(row, '%Y-%m-%d %H:%M:%S.%f %z').month]
            for row in column])
        ret_column.sf_type = ColumnType(StringType(), True)
        return ret_column
    
    Copy

    上面的补丁只接受一个参数 column,它是一个类似 pandas.Series 的对象,包含列内的数据行。然后,我们使用来自 Python 模块 datetimecalendar 的方法组合来模拟内置 monthname() 列的功能。最后,我们将返回类型设置为 String,因为内置方法返回的是与月份相对应的字符串(“1 月”、“2 月”、“3 月”等)。

  3. 接下来,将此方法导入 DataFrame 转换器和存储过程的测试中。

    # test/test_transformers.py
    
    # No changes to the other unit test methods
    
    def test_calc_month_facts(request, session):
        # Add conditional to include the patch if local testing is being used
        if request.config.getoption('--snowflake-session') == 'local':
            from patches import patch_monthname
    
        # No other changes
    
    Copy
  4. 使用本地旗标重新运行 pytest

    pytest test/test_transformers.py --snowflake-session local
    
    Copy
  5. 现在将相同的补丁应用到存储过程测试中。

    #test/test_sproc.py
    
    def test_create_fact_tables(request, session):
        # Add conditional to include the patch if local testing is being used
        if request.config.getoption('--snowflake-session') == 'local':
            from patches import patch_monthname
    
        # No other changes required
    
    Copy
  6. 使用本地旗标重新运行 pytest。

    pytest test/test_sproc.py --snowflake-session local
    
    Copy
  7. 最后,让我们比较一下在本地和使用实时连接运行完整测试套件所需的时间。我们将使用 time 命令来测量这两个命令所需的时间。让我们从实时连接开始。

    time pytest
    
    Copy

    在这种情况下,测试套件的运行时间为 7.89 秒。(具体时间可能因电脑、网络连接和其他因素而异)。

    =================================== test session starts ==========================
    platform darwin -- Python 3.9.18, pytest-7.4.3, pluggy-1.3.0
    rootdir: /Users/jfreeberg/Desktop/snowpark-testing-tutorial
    configfile: pytest.ini
    collected 4 items
    
    test/test_sproc.py .                                                             [ 25%]
    test/test_transformers.py ...                                                    [100%]
    
    =================================== 4 passed in 6.86s =================================
    pytest  1.63s user 1.86s system 44% cpu 7.893 total
    

    现在,让我们试试本地测试框架:

    time pytest --snowflake-session local
    
    Copy

    在本地测试框架下,执行测试套件仅需 1 秒钟!

    ================================== test session starts ================================
    platform darwin -- Python 3.9.18, pytest-7.4.3, pluggy-1.3.0
    rootdir: /Users/jfreeberg/Desktop/snowpark-testing-tutorial
    configfile: pytest.ini
    collected 4 items
    
    test/test_sproc.py .                                                             [ 25%]
    test/test_transformers.py ...                                                    [100%]
    
    =================================== 4 passed in 0.10s ==================================
    pytest --snowflake-session local  1.37s user 1.70s system 281% cpu 1.093 total
    

了解详情

您已成功完成!做得很好。

在本教程中,您可以从端到端角度了解如何测试 Python Snowpark 代码。在此过程中,您执行了以下操作:

语言: 中文