scikit-learn

注册表支持使用 scikit-learn 创建的模型(从 sklearn.base.BaseEstimatorsklearn.pipeline.Pipeline 派生的模型)。

调用 options 时,可以在 log_model 字典中使用下列附加选项:

选项

描述

target_methods

模型对象上可用方法的名称列表。scikit-learn 模型默认具有以下目标方法(假设方法存在):predicttransformpredict_probapredict_log_probadecision_function

在登记 scikit-learn 模型时,您必须指定 sample_input_datasignatures 参数,以确保注册表了解目标方法的签名。

示例

from sklearn import datasets, ensemble

iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True)
clf = ensemble.RandomForestClassifier(random_state=42)
clf.fit(iris_X, iris_y)
model_ref = registry.log_model(
    clf,
    model_name="RandomForestClassifier",
    version_name="v1",
    sample_input_data=iris_X,
    options={
        "method_options": {
            "predict": {"case_sensitive": True},
            "predict_proba": {"case_sensitive": True},
            "predict_log_proba": {"case_sensitive": True},
        }
    },
)
model_ref.run(iris_X[-10:], function_name='"predict_proba"')
Copy

管道:

from sklearn import datasets, ensemble, pipeline, preprocessing

iris_X, iris_y = datasets.load_iris(return_X_y=True, as_frame=True)
pipe = pipeline.Pipeline([
    ('scaler', preprocessing.StandardScaler()),
    ('classifier', ensemble.RandomForestClassifier(random_state=42)),
])
pipe.fit(iris_X, iris_y)
model_ref = registry.log_model(
    pipe,
    model_name="Pipeline",
    version_name="v1",
    sample_input_data=iris_X,
    options={
        "method_options": {
            "predict": {"case_sensitive": True},
            "predict_proba": {"case_sensitive": True},
            "predict_log_proba": {"case_sensitive": True},
        }
    },
)
model_ref.run(iris_X[-10:], function_name='"predict_proba"')
Copy

备注

您可以将 scikit-learn 预处理与 XGBoost 模型结合,作为 scikit-learn 管道。

语言: 中文