snowflake.ml.experiment.ExperimentTracking

class snowflake.ml.experiment.ExperimentTracking(session: Session, *, database_name: Optional[str] = None, schema_name: Optional[str] = None)

Bases: object

Class to manage experiments in Snowflake.

Initializes experiment tracking within a pre-created schema.

Parameters:
  • session – The Snowpark Session to connect with Snowflake.

  • database_name – The name of the database. If None, the current database of the session will be used. Defaults to None.

  • schema_name – The name of the schema. If None, the current schema of the session will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.

Raises:

ValueError – If no database is provided and no active database exists in the session.
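The database and schema fallback rules above can be summarized as a small pure function. This is a hypothetical model of the documented behavior, not the library's implementation. `resolve_tracking_location`, `session_database` and `session_schema` are illustrative names, where the last two stand in for the session's current database and schema (None when nothing is active).

```python
from typing import Optional, Tuple


def resolve_tracking_location(
    database_name: Optional[str],
    schema_name: Optional[str],
    session_database: Optional[str],
    session_schema: Optional[str],
) -> Tuple[str, str]:
    """Hypothetical model of how the constructor picks its database and schema."""
    # An explicit database wins; otherwise fall back to the session's current one.
    database = database_name if database_name is not None else session_database
    if database is None:
        # Documented case: no database given and no active database in the session.
        raise ValueError("No database provided and no active database in the session.")
    if schema_name is not None:
        schema = schema_name
    elif session_schema is not None:
        schema = session_schema
    else:
        # Documented fallback when the session has no active schema.
        schema = "PUBLIC"
    return database, schema
```

In real code you only pass the names to `ExperimentTracking(session, database_name=..., schema_name=...)`. The constructor performs this resolution itself.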

Methods

delete_experiment(experiment_name: str) → None

Delete an experiment.

Parameters:

experiment_name – The name of the experiment.

delete_run(run_name: str) → None

Delete a run.

Parameters:

run_name – The name of the run to be deleted.

Raises:

RuntimeError – If no experiment is set.
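Because delete_run acts on the experiment currently set, a cleanup routine should call set_experiment first. The sketch below shows that ordering. `delete_runs` and its `drop_experiment` flag are hypothetical, and `exp` is assumed to be an ExperimentTracking instance. Note that set_experiment creates the experiment if it does not already exist.

```python
from typing import Iterable


def delete_runs(exp, experiment_name: str, run_names: Iterable[str],
                drop_experiment: bool = False) -> None:
    """Delete several runs of one experiment, optionally followed by the experiment."""
    # delete_run raises RuntimeError when no experiment is set, so set it first.
    exp.set_experiment(experiment_name)
    for run_name in run_names:
        exp.delete_run(run_name)
    if drop_experiment:
        exp.delete_experiment(experiment_name)
```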

download_artifacts(run_name: str, artifact_path: Optional[str] = None, target_path: Optional[str] = None) → None

Download artifacts from a run to a local directory.

Parameters:
  • run_name – Name of the run to download artifacts from.

  • artifact_path – Optional path to a file or subdirectory within the run’s artifact directory to download. If None, downloads all artifacts from the root of the run’s artifact directory.

  • target_path – Optional local directory to download files into. If None, downloads into the current working directory.

Raises:

RuntimeError – If no experiment is currently set.
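A minimal sketch of downloading one run's outputs. It assumes `exp` is an ExperimentTracking instance. `fetch_run_outputs` and its `subdir`/`dest` parameters are hypothetical. Creating the destination directory up front is a precaution; the docs do not say whether download_artifacts creates it.

```python
import os
from typing import Optional


def fetch_run_outputs(exp, experiment_name: str, run_name: str,
                      subdir: Optional[str] = None, dest: str = "artifacts") -> str:
    """Download a run's artifacts (or one file/subdirectory of them) into `dest`."""
    # download_artifacts raises RuntimeError unless an experiment is set.
    exp.set_experiment(experiment_name)
    os.makedirs(dest, exist_ok=True)
    # subdir=None downloads everything under the run's artifact root.
    exp.download_artifacts(run_name, artifact_path=subdir, target_path=dest)
    return os.path.abspath(dest)
```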

end_run(run_name: Optional[str] = None) → None

End a run. If no run name is provided, the current run is ended; otherwise, the specified run is ended.

Parameters:

run_name – The name of the run to be ended. If None, the current run is ended.

Raises:

RuntimeError – If no run is active.
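To make sure end_run is called even when training code raises, start_run and end_run can be paired in a context manager. `tracked_run` is a hypothetical helper, not part of the library, and `exp` is assumed to be an ExperimentTracking instance with an experiment already set.

```python
from contextlib import contextmanager
from typing import Optional


@contextmanager
def tracked_run(exp, run_name: Optional[str] = None):
    """Start a run, yield it, and always end it, even if the body raises."""
    run = exp.start_run(run_name)
    try:
        yield run
    finally:
        # With a name, end that run; with None, end_run ends the current run.
        exp.end_run(run_name)
```

Usage would look like `with tracked_run(exp, "baseline"): exp.log_metric("loss", 0.3)`.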

list_artifacts(run_name: str, artifact_path: Optional[str] = None) → list[snowflake.ml.experiment._client.artifact.ArtifactInfo]

List artifacts for a given run within the current experiment.

Parameters:
  • run_name – Name of the run to list artifacts from.

  • artifact_path – Optional subdirectory within the run’s artifact directory to scope the listing. If None, lists from the root of the run’s artifact directory.

Returns:

A list of artifact entries under the specified path.

Raises:

RuntimeError – If no experiment is currently set.
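list_artifacts can be used to check whether a path holds anything before downloading it. In the sketch below, `download_if_present` is hypothetical, `exp` is assumed to be an ExperimentTracking instance with an experiment already set, and only the length of the returned list is used (no ArtifactInfo fields are assumed).

```python
from typing import Optional


def download_if_present(exp, run_name: str, artifact_path: Optional[str],
                        target_path: str) -> int:
    """Download `artifact_path` only when its listing is non-empty; return the entry count."""
    entries = exp.list_artifacts(run_name, artifact_path=artifact_path)
    if entries:
        exp.download_artifacts(run_name, artifact_path=artifact_path, target_path=target_path)
    return len(entries)
```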

log_artifact(local_path: str, artifact_path: Optional[str] = None) → None

Log an artifact or a directory of artifacts under the current run. If no run is active, this method will create a new run.

Parameters:
  • local_path – The path to the local file or directory to log.

  • artifact_path – The directory within the run directory to write the artifacts to. If None, the artifacts will be logged in the root directory of the run.
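A minimal sketch of logging a generated file as an artifact. It assumes `exp` is an ExperimentTracking instance. `log_json_report` and its defaults are hypothetical. The file lives in a temporary directory, so log_artifact must be called before that directory is cleaned up.

```python
import json
import os
import tempfile


def log_json_report(exp, report: dict, filename: str = "report.json",
                    artifact_path: str = "reports") -> None:
    """Write `report` to a temporary JSON file and log it under `artifact_path`."""
    with tempfile.TemporaryDirectory() as tmp:
        local_path = os.path.join(tmp, filename)
        with open(local_path, "w", encoding="utf-8") as fh:
            json.dump(report, fh, indent=2)
        # Logged inside the `with` block, while the temporary file still exists.
        # If no run is active, log_artifact starts a new one.
        exp.log_artifact(local_path, artifact_path=artifact_path)
```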

log_metric(key: str, value: float, step: int = 0) → None

Log a metric under the current run. If no run is active, this method will create a new run.

Parameters:
  • key – The name of the metric.

  • value – The value of the metric.

  • step – The step of the metric. Defaults to 0.
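Passing an explicit step keeps successive values of the same metric distinct instead of all landing on the default step 0. In this sketch, `log_loss_curve` is a hypothetical helper and `exp` is assumed to be an ExperimentTracking instance.

```python
from typing import Iterable


def log_loss_curve(exp, losses: Iterable[float]) -> None:
    """Log one `loss` value per training step, using the step index as `step`."""
    for step, loss in enumerate(losses):
        # The first call starts a run if none is active.
        exp.log_metric("loss", float(loss), step=step)
```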

log_metrics(metrics: dict[str, float], step: int = 0) → None

Log metrics under the current run. If no run is active, this method will create a new run.

Parameters:
  • metrics – Dictionary containing metric keys and float values.

  • step – The step of the metrics. Defaults to 0.
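log_metrics records several related values at the same step in one call. The sketch computes precision, recall and F1 from confusion counts for one epoch and logs them together. `log_epoch_metrics` is hypothetical and `exp` is assumed to be an ExperimentTracking instance.

```python
from typing import Dict


def log_epoch_metrics(exp, epoch: int, tp: int, fp: int, fn: int) -> Dict[str, float]:
    """Compute precision/recall/F1 from counts and log them at step=epoch."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    metrics = {"precision": precision, "recall": recall, "f1": f1}
    exp.log_metrics(metrics, step=epoch)
    return metrics
```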

log_model(model: Union[catboost.CatBoost, lightgbm.LGBMModel, lightgbm.Booster, prophet.Prophet, CustomModel, sklearn.base.BaseEstimator, sklearn.pipeline.Pipeline, xgboost.XGBModel, xgboost.Booster, torch.nn.Module, torch.jit.ScriptModule, tensorflow.Module, keras.Model, base.BaseEstimator, mlflow.pyfunc.PyFuncModel, transformers.Pipeline, sentence_transformers.SentenceTransformer, HuggingFacePipelineModel, ModelVersion], *, model_name: str, version_name: Optional[str] = None, comment: Optional[str] = None, metrics: Optional[dict[str, Any]] = None, conda_dependencies: Optional[list[str]] = None, pip_requirements: Optional[list[str]] = None, artifact_repository_map: Optional[dict[str, str]] = None, resource_constraint: Optional[dict[str, str]] = None, target_platforms: Optional[list[Union[snowflake.ml.model.target_platform.TargetPlatform, str]]] = None, python_version: Optional[str] = None, signatures: Optional[dict[str, ModelSignature]] = None, sample_input_data: DataFrame]] = None, user_files: Optional[dict[str, list[str]]] = None, code_paths: Optional[list[str]] = None, ext_modules: Optional[list[module]] = None, task: Task = Task.UNKNOWN, options: Optional[Union[BaseModelSaveOption, CatBoostModelSaveOptions, CustomModelSaveOption, LGBMModelSaveOptions, ProphetSaveOptions, SKLModelSaveOptions, XGBModelSaveOptions, SNOWModelSaveOptions, PyTorchSaveOptions, TorchScriptSaveOptions, TensorflowSaveOptions, MLFlowSaveOptions, HuggingFaceSaveOptions, SentenceTransformersSaveOptions, KerasSaveOptions]] = None) → ModelVersion

Log a model with various parameters and metadata, or a ModelVersion object.

Parameters:
  • model – Supported model or ModelVersion object:

    • Supported model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML, PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline, Sentence Transformers, or Custom Model.

    • ModelVersion: Source ModelVersion object used to create the new ModelVersion object.

  • model_name – Name to identify the model. This must be a valid Snowflake SQL Identifier. Alphanumeric characters and underscores are permitted. See https://docs.snowflake.cn/en/sql-reference/identifiers-syntax for more.

  • version_name – Version identifier for the model. Combination of model_name and version_name must be unique. If not specified, a random name will be generated.

  • comment – Comment associated with the model version. Defaults to None.

  • metrics – A JSON serializable dictionary containing metrics linked to the model version. Defaults to None.

  • conda_dependencies – List of Conda package specifications. Use “[channel::]package [operator version]” syntax to specify a dependency. This is the recommended way to specify dependencies. When no channel is specified, the Snowflake Anaconda Channel is used. Defaults to None.

  • pip_requirements – List of Pip package specifications. Defaults to None. Models running in a Snowflake Warehouse must also specify a pip artifact repository (see artifact_repository_map). Otherwise, models with pip requirements are runnable only in Snowpark Container Services. See https://docs.snowflake.cn/en/developer-guide/snowflake-ml/model-registry/container for more.

  • artifact_repository_map

    Specifies a mapping of package channels or platforms to custom artifact repositories. Defaults to None. Currently, the mapping applies only to Warehouse execution. Note: This feature is currently in Public Preview. Format: {channel_name: artifact_repository_name}, where:

    • channel_name: Currently must be ‘pip’.

    • artifact_repository_name: The identifier of the artifact repository to fetch packages from, e.g. snowflake.snowpark.pypi_shared_repository.

  • resource_constraint – Mapping of resource constraint keys and values, e.g. {“architecture”: “x86”}.

  • target_platforms

    List of target platforms to run the model. The only acceptable inputs are a combination of “WAREHOUSE” and “SNOWPARK_CONTAINER_SERVICES”, or a target platform constant:

    • [“WAREHOUSE”] or snowflake.ml.model.target_platform.WAREHOUSE_ONLY (Warehouse only)

    • [“SNOWPARK_CONTAINER_SERVICES”] or snowflake.ml.model.target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY (Snowpark Container Services only)

    • [“WAREHOUSE”, “SNOWPARK_CONTAINER_SERVICES”] or snowflake.ml.model.target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES (Both)

    Defaults to None. When None, both target platforms are used.

  • python_version – Python version in which the model is run. Defaults to None.

  • signatures – Model data signatures for inputs and outputs for various target methods. If None, sample_input_data is used to infer the signatures for models whose signatures cannot be inferred automatically. If not None, sample_input_data should not be specified. Defaults to None.

  • sample_input_data – Sample input data to infer model signatures from. It is also used as background data for explanations and to capture data lineage. Defaults to None.

  • user_files – Dictionary where the keys are subdirectories, and values are lists of local file name strings. The local file name strings can include wildcards (? or *) for matching multiple files.

  • code_paths – List of directories containing code to import. Defaults to None.

  • ext_modules – List of external modules to pickle with the model object. Only supported when logging the following types of model: Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None.

  • task – The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION, TABULAR_BINARY_CLASSIFICATION, TABULAR_MULTI_CLASSIFICATION, TABULAR_RANKING, or UNKNOWN. By default, it is set to Task.UNKNOWN and may be overridden by a task inferred from the model object.

  • options (Dict[str, Any], optional) –

    Additional model saving options.

    Model Saving Options include:

    • embed_local_ml_library: Embed local Snowpark ML into the code directory or folder. Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda Channel. Otherwise, defaults to False.

    • relax_version: Whether to relax the version constraints of the dependencies when running in the Warehouse. It detects any ==x.y.z specifier and replaces it with >=x.y, <(x+1). Defaults to True.

    • function_type: Set the method function type globally. To set method function types individually see function_type in model_options.

    • volatility: Set the volatility for all model methods globally (use Volatility.VOLATILE or Volatility.IMMUTABLE). Volatility.VOLATILE functions may return different results for the same arguments, while Volatility.IMMUTABLE functions always return the same result for the same arguments. Defaults are set automatically based on model type: supported models (sklearn, xgboost, pytorch, huggingface_pipeline, mlflow, etc.) default to IMMUTABLE, while custom models default to VOLATILE. Individual method volatility can be set in method_options and will override this global setting.

    • target_methods: List of target methods to register when logging the model. This option is not used in MLFlow models. Defaults to None, in which case the model handler’s default target methods will be used.

    • save_location: Location to save the model and metadata.

    • method_options: Per-method saving options. This dictionary has method names as keys and dictionaries of the desired options as values. See the example below.

      The following are the available method options:

      • case_sensitive: Indicates whether the method and its signature should be case sensitive. If so, you must double-quote the method name when you refer to it in SQL. This is helpful if you need case to tell your methods or features apart, or if a method or feature name contains non-alphabetic characters. Defaults to False.

      • max_batch_size: Maximum batch size that the method can accept in the Snowflake Warehouse. Defaults to None, in which case it is determined automatically by Snowflake.

      • function_type: One of supported model method function types (FUNCTION or TABLE_FUNCTION).

      • volatility: Volatility level for the function (use Volatility.VOLATILE or Volatility.IMMUTABLE). Volatility.VOLATILE functions may return different results for the same arguments, while Volatility.IMMUTABLE functions always return the same result for the same arguments. This per-method setting overrides any global volatility setting. Defaults to None (no volatility specified).
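The relax_version rewrite described above (==x.y.z becomes >=x.y, <(x+1)) can be illustrated with a short function. This is a sketch of the documented rule only, not the library's code. `relax_pin` is a hypothetical name, and it handles only plain name==x.y.z pins, leaving every other specifier unchanged.

```python
import re

# Matches a plain "name==x.y.z" pin (a simplified subset of requirement syntax).
_PINNED = re.compile(r"^(?P<name>[A-Za-z0-9_.\-]+)==(?P<major>\d+)\.(?P<minor>\d+)\.\d+$")


def relax_pin(spec: str) -> str:
    """Rewrite 'pkg==x.y.z' as 'pkg>=x.y,<(x+1)'; return other specs unchanged."""
    match = _PINNED.match(spec.strip())
    if match is None:
        return spec
    major, minor = int(match["major"]), int(match["minor"])
    return f"{match['name']}>={major}.{minor},<{major + 1}"
```

For example, `relax_pin("numpy==1.26.4")` gives `"numpy>=1.26,<2"`, which keeps the minor-version floor but accepts any later 1.x release.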

Raises:
  • ValueError – If extra arguments are specified when a ModelVersion is provided.

  • Exception – If the model logging fails.

Returns:

ModelVersion object corresponding to the model just logged.

Return type:

ModelVersion

Example:

from snowflake.ml.registry import Registry

# Create a Snowpark session (connection details omitted).
session = ...

# A trained model of one of the supported types.
model = ...

registry = Registry(session=session)

# Define `method_options` for each inference method if needed.
method_options = {
    "predict": {
        "case_sensitive": True,
    },
}

registry.log_model(
    model=model,
    model_name="my_model",
    options={"method_options": method_options},
)
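The same options can be passed through ExperimentTracking.log_model so that the model is logged from inside a run. In this sketch, `log_model_in_run` is a hypothetical helper, `exp` is assumed to be an ExperimentTracking instance on which set_experiment has already been called, and `model` is any supported model type. Only parameters listed in the signature above are used.

```python
from typing import Any, Optional


def log_model_in_run(exp, model: Any, model_name: str, accuracy: float,
                     sample_input: Optional[Any] = None):
    """Log `model` inside a fresh run and return the resulting ModelVersion."""
    exp.start_run()  # no name given, so a default run name is generated
    try:
        return exp.log_model(
            model,
            model_name=model_name,
            metrics={"accuracy": accuracy},  # JSON-serializable metrics for the version
            sample_input_data=sample_input,  # used to infer signatures when needed
            options={"method_options": {"predict": {"case_sensitive": True}}},
        )
    finally:
        exp.end_run()  # ends the current run even if logging fails
```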
log_param(key: str, value: Any) → None

Log a parameter under the current run. If no run is active, this method will create a new run.

Parameters:
  • key – The name of the parameter.

  • value – The value of the parameter. Values can be of any type, but will be converted to string.

log_params(params: dict[str, Any]) → None

Log parameters under the current run. If no run is active, this method will create a new run.

Parameters:

params – Dictionary containing parameter keys and values. Values can be of any type, but will be converted to string.
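Since parameter values are stored as strings, it can help to preview what will be recorded. This sketch assumes the conversion behaves like Python's str(); the docs only say values are converted to string. `as_logged_params` is a hypothetical helper.

```python
from typing import Any, Dict, Mapping


def as_logged_params(params: Mapping[str, Any]) -> Dict[str, str]:
    """Preview parameter values as strings (assumes str()-style conversion)."""
    return {key: str(value) for key, value in params.items()}
```

One practical consequence under this assumption: `1` and `"1"` are recorded identically, so type information is not preserved.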

set_experiment(experiment_name: str) → Experiment

Set the experiment in context. Creates a new experiment if it doesn’t exist.

Parameters:

experiment_name – The name of the experiment.

Returns:

The experiment that was set.

Return type:

Experiment

start_run(run_name: Optional[str] = None) → Run

Start a new run. If the name of an existing run is provided, that run is resumed if it is still running.

Parameters:

run_name – The name of the run. If None, a default name will be generated.

Returns:

The run that was started or resumed.

Return type:

Run

Raises:

RuntimeError – If a run is already active, or if a run with the same name exists but is not running.
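The start_run and end_run rules can be modeled in memory to make the create/resume/error cases concrete. This is a hypothetical sketch of the documented behavior only. `RunRulesSketch`, the "RUNNING"/"FINISHED" labels and the `run_N` default names are illustrative assumptions; the real class stores runs in Snowflake and generates its own names, and the docs do not say what end_run does with an unknown run name.

```python
import itertools
from typing import Dict, Optional


class RunRulesSketch:
    """Hypothetical in-memory model of the documented start_run/end_run rules."""

    def __init__(self, existing: Optional[Dict[str, str]] = None) -> None:
        # run name -> "RUNNING" or "FINISHED"; `existing` stands in for runs
        # already stored in the experiment (e.g. left by an earlier session).
        self.status: Dict[str, str] = dict(existing or {})
        self.active: Optional[str] = None
        self._counter = itertools.count(1)

    def start_run(self, run_name: Optional[str] = None) -> str:
        if self.active is not None:
            raise RuntimeError("A run is already active.")
        if run_name is None:
            # Stand-in for the generated default name; skip names already taken.
            run_name = f"run_{next(self._counter)}"
            while run_name in self.status:
                run_name = f"run_{next(self._counter)}"
        elif run_name in self.status and self.status[run_name] != "RUNNING":
            raise RuntimeError(f"Run {run_name!r} exists but is not running.")
        # A new name creates a run; an existing running run is resumed.
        self.status[run_name] = "RUNNING"
        self.active = run_name
        return run_name

    def end_run(self, run_name: Optional[str] = None) -> None:
        name = run_name if run_name is not None else self.active
        if name is None:
            raise RuntimeError("No run is active.")
        self.status[name] = "FINISHED"
        if name == self.active:
            self.active = None
```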
