"""Lookalike model scoring job for ML Jobs in Data Clean Rooms.

This script runs inside a container on a compute pool. It loads the trained
model from the cleanroom table written by the training job, scores the full
publisher population, and writes scored results to a cleanroom table for
activation.

The script is submitted via the ML Jobs code spec and receives configuration
through the --args JSON parameter, which includes source_table references
resolved by the clean room at runtime.

Usage:
  Staged to a Snowflake internal stage and referenced in the ML Jobs code spec.
  Not intended to be run directly.
"""

import argparse
import json

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn import preprocessing
import pickle
import codecs

from snowflake.snowpark.context import get_active_session


def _parse_job_args(argv=None):
    """Parse the ``--args`` JSON payload supplied by the ML Jobs code spec.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The decoded payload as a dict. An empty ``--args`` value or a JSON
        ``null`` yields ``{}``.

    Raises:
        ValueError: If the payload is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError``) or does not decode to a JSON object.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--args", type=str, default="{}")
    parsed = parser.parse_args(argv)

    args = json.loads(parsed.args) if parsed.args else {}
    if args is None:
        return {}
    if not isinstance(args, dict):
        raise ValueError(f"--args must be a JSON object, got {type(args).__name__}")
    return args


def _coerce_threshold(value):
    """Return ``value`` as a float score threshold.

    The payload is arbitrary JSON, so the threshold may arrive as a string
    (e.g. ``"0.7"``). Comparing the float score column against a string would
    fail late inside pandas, so convert once up front.

    Raises:
        ValueError: If ``value`` cannot be interpreted as a number.
    """
    try:
        return float(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"score_threshold must be numeric, got {value!r}") from err


def _normalize_source_tables(value):
    """Return the ``source_table`` argument as a list of table names.

    A bare string is wrapped in a list, because indexing the string directly
    would silently use its first character as the table name. ``None`` and an
    empty string yield an empty list.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value] if value else []
    return list(value)


def _load_model(session, table_name):
    """Load ``(model, ohe, feature_cols)`` from the first row of ``table_name``.

    The ``MODEL_DATA`` column is decoded as base64 text and unpickled into a
    3-tuple. NOTE(review): if the table ever holds several rows, which row is
    "first" is not guaranteed. Confirm the training job writes exactly one.

    Raises:
        ValueError: If the model table has no rows.
    """
    model_row = session.table(table_name).to_pandas()
    if model_row.empty:
        raise ValueError(f"Model table {table_name} is empty; run the training job first")
    serialized = model_row["MODEL_DATA"].iloc[0]
    # SECURITY: pickle.loads executes arbitrary code embedded in the payload.
    # This is only acceptable if the model table can be written solely by the
    # training job. NOTE(review): confirm the clean room enforces that.
    model, ohe, feature_cols = pickle.loads(codecs.decode(serialized.encode(), "base64"))
    return model, ohe, feature_cols


def _build_feature_matrix(pub_df, feature_cols, ohe):
    """Build the XGBoost input matrix for ``pub_df``.

    Each feature column is converted to numeric when all its values parse.
    Otherwise it is cast to ``str`` and one-hot encoded with the fitted
    ``ohe``. The encoded categorical block comes first, followed by the numeric
    columns as floats. NOTE(review): this layout must match the training job's.

    Raises:
        ValueError: If any of ``feature_cols`` is missing from ``pub_df``.
    """
    # A tuple would be interpreted by pandas as a single (MultiIndex) key.
    feature_cols = list(feature_cols)
    missing = [col for col in feature_cols if col not in pub_df.columns]
    if missing:
        raise ValueError(f"Publisher features table is missing columns: {missing}")

    features = pub_df[feature_cols].copy()
    for col in features.columns:
        try:
            features[col] = pd.to_numeric(features[col])
        except (ValueError, TypeError):
            features[col] = features[col].astype(str)

    categorical = features.select_dtypes(include=["object", "string"])
    if categorical.empty:
        cat_encoded = np.empty((len(features), 0))
    else:
        encoded = ohe.transform(categorical)
        # OneHotEncoder returns a sparse matrix by default, but a dense ndarray
        # when fitted with sparse_output=False; accept both.
        cat_encoded = encoded.toarray() if hasattr(encoded, "toarray") else np.asarray(encoded)

    non_categorical = features.select_dtypes(exclude=["object", "string"])
    return np.concatenate((cat_encoded, non_categorical.to_numpy(dtype=float)), axis=1)


def main(argv=None):
    """Score the publisher population with the trained lookalike model.

    Reads the pickled model from ``cleanroom.lal_mljobs_model`` and scores every
    row of the first ``source_table``. It then overwrites
    ``cleanroom.lal_mljobs_scored_results`` with ``HASHED_EMAIL``/``SCORE`` for
    all users. Finally it prints a JSON summary of the users whose score is
    strictly greater than ``score_threshold`` (default 0.5).

    Args:
        argv: Command-line arguments; ``None`` (the default) means
            ``sys.argv[1:]``. Exposed so the job can be invoked in-process.

    Raises:
        ValueError: If ``--args`` is malformed, ``score_threshold`` is not
            numeric, no source table is given, the model table is empty, or
            feature columns are missing from the publisher table.
    """
    model_table = "cleanroom.lal_mljobs_model"
    results_table = "cleanroom.lal_mljobs_scored_results"

    args = _parse_job_args(argv)
    score_threshold = _coerce_threshold(args.get("score_threshold", 0.5))

    # Validate configuration before acquiring a session so bad input fails fast.
    source_tables = _normalize_source_tables(args.get("source_table"))
    if not source_tables:
        raise ValueError(
            f"Expected at least 1 source table (publisher features), got {len(source_tables)}"
        )

    session = get_active_session()

    print(f"Loading trained model from {model_table}...")
    model, ohe, feature_cols = _load_model(session, model_table)
    print("Model loaded successfully")

    print(f"Loading publisher features from: {source_tables[0]}")
    pub_df = session.table(source_tables[0]).to_pandas()
    print(f"Scoring {len(pub_df)} users...")

    user_ids = pub_df["HASHED_EMAIL"].values
    X_pred = _build_feature_matrix(pub_df, feature_cols, ohe)
    scores = model.predict(xgb.DMatrix(X_pred))

    results_df = pd.DataFrame({
        "HASHED_EMAIL": user_ids,
        "SCORE": scores.astype(float),
    })

    # Only users strictly above the threshold count toward the audience;
    # every scored user is still written to the results table.
    mask = results_df["SCORE"] > score_threshold
    audience_size = int(mask.sum())
    avg_score = float(results_df.loc[mask, "SCORE"].mean()) if audience_size > 0 else 0.0

    print(f"Found {audience_size} users above threshold {score_threshold}")
    print(f"Average score: {avg_score:.4f}")

    sp_results = session.create_dataframe(results_df)
    sp_results.write.save_as_table(results_table, mode="overwrite")
    print(f"Scored results written to {results_table}")

    result = {
        "status": "completed",
        "total_scored": len(user_ids),
        "audience_size": audience_size,
        "avg_score": avg_score,
        "score_threshold": score_threshold,
    }
    print(json.dumps(result))


if __name__ == "__main__":
    # Run the scoring job only when this file is executed as a script (which
    # is how the ML Jobs code spec runs it); importing it does not start a job.
    main()
