Sentence Transformer¶
Snowflake Model Registry 支持使用 Sentence Transformers (sentence_transformers.SentenceTransformer) 的模型。有关更多信息,请参阅 Sentence Transformer 文档 (https://sbert.net/)。
要使注册表知道目标方法的签名,您必须指定示例输入数据或用于定义模型方法的输入和输出架构的签名。
对于示例输入数据,请指定 Snowpark DataFrame 作为 sample_input_data 参数的值。例如,您可以指定一个值,例如 sample_input = pd.DataFrame(["This is a sample sentence."], columns=["TEXT"])。
如果您使用的是签名参数,请指定一个字典作为 signatures 参数的值。该字典定义模型的输入和输出方法。例如,以下代码定义模型的 encode 方法的输入和输出架构:
from snowflake.ml.model.model_signature import ModelSignature, FeatureSpec, DataType
signatures = {
"encode": ModelSignature(
inputs=[FeatureSpec(dtype=DataType.STRING, name='TEXT')],
outputs=[FeatureSpec(dtype=DataType.FLOAT, name='EMBEDDINGS', shape=(-1,))]
)
}
调用 log_model 时,您可以使用 options 字典中的以下附加选项:
选项 |
描述 |
|---|---|
|
可在模型对象上使用的方法的名称列表。默认情况下,Sentence Transformer 模型具有以下目标方法(假设方法存在): |
|
部署到具有 GPU 的平台时使用的 CUDA 运行时版本;默认值为 11.8。如果手动设置为 |
以下示例:
加载预先训练的 Sentence Transformer 模型。
将其记录到 Snowflake ML Model Registry。
使用记录的模型进行推理。
备注
在示例中,reg 是 snowflake.ml.registry.Registry 的实例。有关创建注册表对象的信息,请参阅 Snowflake Model Registry。
from sentence_transformers import SentenceTransformer
import pandas as pd
# 1. Initialize the model
# This example uses the 'all-MiniLM-L6-v2' model, which is a popular
# and efficient model for generating sentence embeddings.
model = SentenceTransformer('all-MiniLM-L6-v2')
# 2. Prepare sample input data
# Sentence Transformers expect a single column of text data for the 'encode' method.
sentences = ["This is an example sentence", "Each sentence is converted into a vector"]
sample_input = pd.DataFrame(sentences, columns=["TEXT"])
# 3. Log the model
# Provide the model object, a name, and a version.
# Including sample_input_data allows the registry to infer the input/output signatures.
model_ref = reg.log_model(
model=model,
model_name="my_sentence_transformer",
version_name="v1",
sample_input_data=sample_input,
)
# 4. Use the model for inference
# The 'run' method executes the default 'encode' function on the input data.
result_df = model_ref.run(sample_input, function_name="encode")
# The result is a DataFrame where the output column (usually named 'outputs')
# contains the embeddings as arrays of floats.
print(result_df)