模型可解释性

训练过程中,机器学习模型会推断输入和输出之间的关系,而不是要求事先明确声明这些关系。这使得 ML 技术无需大量设置即可处理涉及许多变量的复杂情景,特别是在特定结果的因果因素复杂或不明确的情况下,但由此产生的模型可能是黑盒模型。如果一个模型表现不佳,就很难了解原因,也很难知道如何改进其性能。黑盒模型还可能掩盖隐式偏见,无法为决策提供明确的理由。金融和医疗等对可信系统有监管要求的行业,可能需要更有力的证据来证明模型是基于正确的原因得出了正确的结果。

为了帮助解决这些问题,Snowflake Model Registry 包含了一个基于 夏普利值 (https://towardsdatascience.com/the-shapley-value-for-ml-models-f1100bff78d1) 的可解释性函数。夏普利值是一种将机器学习模型的输出归因于其输入特征的方法。通过考虑所有可能的特征组合,夏普利值衡量了每个特征对模型预测的平均边际贡献。该方法确保了归因重要程度的公平性,并为理解复杂模型奠定了坚实基础。虽然计算量很大,但从夏普利值中获得的洞察对于模型的可解释性和调试非常宝贵。

例如,假设我们有一个预测房屋价格的模型,根据房屋的大小、位置、卧室数量以及是否允许养宠物等特征来训练该模型。在这个示例中,房屋的平均价格为 10 万美元,而模型对一所面积为 2000 平方英尺、位于海边、有三间卧室且不允许养宠物的房屋的最终预测价格是 25 万美元。如下表所示,每个特征值都可能对最终的模型预测有所贡献。

特征

对比房屋平均价格的贡献

大小

2000

+$50,000

位置

海边

+$75,000

卧室

3

+$50,000

宠物

-$25,000

综合这些因素,就不难理解为何这套房子的价格会比住宅平均价格高出 15 万美元了。夏普利值会对最终结果产生积极或消极的影响,从而导致结果与平均值有所差别。在本例中,房子不允许养宠物,那么居住意愿就会降低,因此这一特征值的贡献是负 25,000 美元。

平均值是使用后台数据计算得出的,是整个数据集的代表性样本。有关更多信息,请参阅 使用后台数据记录模型

支持的模型类型

本预览版支持以下 Python 原生模型包。

  • XGBoost

  • CatBoost

  • LightGBM

  • Scikit-Learn

支持以下来自 snowflake.ml.modeling 的 Snowpark ML 建模类。

  • XGBoost

  • LightGBM

  • Scikit-learn(管道模型除外)

对于使用 Snowpark ML 1.6.2 及更高版本记录的上述模型,默认支持可解释性。实施使用 SHAP 库 (https://pypi.org/project/shap/)。

使用后台数据记录模型

后台数据通常是具有代表性的数据样本,是基于夏普利值的说明的重要组成部分。后台数据为夏普利算法提供了“平均”输入的概念,可以将单个说明与之进行比较。

夏普利值是通过系统地扰动输入特征并将其替换为后台数据计算得出的。由于它报告的是与后台数据的偏差,因此在比较多个数据集的夏普利值时,必须使用一致的后台数据。

一些基于树的模型在训练过程中会在其结构中隐式编码后台数据,可能不需要明确的后台数据。然而,大多数模型为了进行有用的解释,需要单独提供后台数据,如果提供了后台数据,所有模型(包括基于树的模型)都能得到更准确的解释。

在记录模型时,您可以使用 sample_input_data 参数传递该模型,来提供多达 1,000 行的后台数据,如下所示。

备注

如果模型需要明确的后台数据来计算夏普利值,那么没有这些数据就无法实现可解释性。

mv = reg.log_model(
    catboost_model,
    model_name="diamond_catboost_explain_enabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data = xs, # xs will be used as background data
)
Copy

您还可以在用签名记录模型的同时提供后台数据,如下所示。

mv = reg.log_model(
    catboost_model,
    model_name="diamond_catboost_explain_enabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    signatures={"predict": predict_signature, "predict_proba": predict_proba_signature},
    sample_input_data = xs, # xs will be used as background data
    options= {"enable_explainability": True} # you will need to set this flag in order to pass both signatures and background data
)
Copy

检索可解释性值

具有可解释性的模型具有一个名为 explain 的方法,可以返回模型特征的夏普利值。

由于夏普利值是对特定输入所做预测的解释,因此必须将输入数据传递给 explain,才能生成需要解释的预测。

Snowflake 模型版本对象有一个名为 explain 的方法,您可以使用 Python 通过 ModelVersion.run 调用它。

reg = Registry(...)
mv = reg.get_model("Explainable_Catboost_Model").default
explanations = mv.run(input_data, function_name="explain")
Copy

以下是使用 SQL 检索解释的示例。

WITH MV_ALIAS AS MODEL DATABASE.SCHEMA.DIAMOND_CATBOOST_MODEL VERSION EXPLAIN_V0
SELECT *,
      FROM DATABASE.SCHEMA.DIAMOND_DATA,
          TABLE(MV_ALIAS!EXPLAIN(CUT, COLOR, CLARITY, CARAT, DEPTH, TABLE_PCT, X, Y, Z));
Copy

重要

如果您使用的 snowflake-ml-python 版本低于 1.7.0,您可能会在使用 XGBoost 模型时遇到 UnicodeDecodeError: 'utf-8' codec can't decode byte 错误。这是因为 SHAP 库 (https://pypi.org/project/shap/) 的 0.42.1 版本与 Snowflake 支持的最新 XGBoost 版本 (2.1.1) 不兼容。如果无法将 snowflake-ml-python 升级到 1.7.0 或更高版本,请将 XGBoost 版本降级到 2.0.3,并在将 relax_version 选项设置为 False 的情况下记录模型,如以下示例所示。

mv_new = reg.log_model(
    model,
    model_name="model_with_explain_enabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data = xs,
    options={"relax_version": False}
)
Copy

为现有模型添加可解释性

使用版本低于 1.6.2 的 Snowpark ML 登录注册表的模型不具备可解释性特征。由于模型版本是不可变的,因此必须创建一个新的模型版本才能为现有模型添加可解释性。您可以使用 ModelVersion.load 来检索代表模型实施的 Python 对象,然后将其作为新的模型版本记录到注册表中。请务必以 sample_input_data 的形式传递后台数据。该方法如下所示。

重要

将模型加载到的 Python 环境必须与部署模型的环境完全相同(即 Python 的版本和所有库的版本相同)。有关详细信息,请参阅 加载模型版本

mv_old = reg.get_model("model_without_explain_enabled").default
model = mv_old.load()
mv_new = reg.log_model(
    model,
    model_name="model_with_explain_enabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data = xs
)
Copy

记录没有可解释性的模型

如果模型支持可解释性,则默认启用可解释性。要在注册表中记录没有可解释性的模型版本,可在记录模型时为 enable_explainability 选项传递 False,如下所示。

mv = reg.log_model(
    catboost_model,
    model_name="diamond_catboost_explain_enabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data = xs,
    options= {"enable_explainability": False}
)
Copy
语言: 中文