"""Lookalike model training job for ML Jobs in Data Clean Rooms.

This script runs inside a container on a compute pool. It trains an XGBoost
binary classifier on the publisher's feature data joined with the advertiser's
seed audience, then saves the serialized model to a cleanroom table.

The script is submitted via the ML Jobs code spec and receives configuration
through the --args JSON parameter, which includes source_table references
resolved by the clean room at runtime.

Usage:
  Staged to a Snowflake internal stage and referenced in the ML Jobs code spec.
  Not intended to be run directly.
"""

import argparse
import json

import numpy as np
import pandas as pd
import xgboost
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
import pickle
import codecs

from snowflake.snowpark.context import get_active_session


def main():
    """Train the lookalike XGBoost classifier and persist it to a table.

    Reads the ``--args`` command-line option as a JSON object and uses its
    ``source_table`` list: the first entry is loaded as the publisher feature
    table and the second as the advertiser seed audience. Seed rows are
    labeled 1 when ``AUDIENCE_SEGMENT`` equals ``"HIGH_VALUE"`` and 0
    otherwise, inner-joined to the publisher rows on ``HASHED_EMAIL``, and
    used to fit a binary classifier. The pickled, base64-encoded package
    ``[booster, one_hot_encoder, feature_cols]`` is written to
    ``cleanroom.lal_mljobs_model`` (overwriting it), and a JSON summary is
    printed to stdout.

    Raises:
        ValueError: If fewer than two source tables are supplied, or if the
            join yields fewer than two rows (too few for a train/test split).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--args", type=str, default="{}")
    parsed = parser.parse_args()

    # An empty --args string is treated the same as an empty JSON object.
    args = json.loads(parsed.args) if parsed.args else {}

    session = get_active_session()

    source_tables = args.get("source_table", [])
    if len(source_tables) < 2:
        raise ValueError(
            f"Expected at least 2 source tables (publisher features + seed audience), got {len(source_tables)}"
        )

    print(f"Loading publisher features from: {source_tables[0]}")
    pub_df = session.table(source_tables[0]).to_pandas()

    print(f"Loading seed audience from: {source_tables[1]}")
    seed_df = session.table(source_tables[1]).to_pandas()

    print(f"Publisher features: {len(pub_df)} rows, Seed audience: {len(seed_df)} rows")

    # Positive class = seed members tagged HIGH_VALUE; every other seed row is
    # a negative. Publisher rows that are not in the seed are dropped by the
    # inner join and never reach training.
    seed_df["LABEL"] = (seed_df["AUDIENCE_SEGMENT"] == "HIGH_VALUE").astype(int)
    labeled = pub_df.merge(seed_df[["HASHED_EMAIL", "LABEL"]], on="HASHED_EMAIL", how="inner")

    print(f"Joined rows (seed members with features): {len(labeled)}")

    # Fail early with an explicit message: with fewer than two rows the
    # 80/20 split below cannot produce a non-empty train and test set.
    if len(labeled) < 2:
        raise ValueError(
            "Need at least 2 joined rows to build a train/test split, "
            f"got {len(labeled)}"
        )

    # Every publisher column except the join key is a model feature.
    feature_cols = [c for c in pub_df.columns if c not in ("HASHED_EMAIL",)]
    X = labeled[feature_cols].copy()
    y = labeled["LABEL"].values

    # Coerce each column to numeric where possible; anything that cannot be
    # parsed is treated as a categorical string column.
    for col in X.columns:
        try:
            X[col] = pd.to_numeric(X[col])
        except (ValueError, TypeError):
            X[col] = X[col].astype(str)

    categorical = X.select_dtypes(include=["object", "string"])
    ohe = preprocessing.OneHotEncoder(handle_unknown="ignore")
    if categorical.shape[1] > 0:
        categorical_ohe = ohe.fit_transform(categorical).toarray()
    else:
        # No categorical features. Fitting the encoder on a zero-sample array
        # is rejected by scikit-learn's input validation (ValueError), so the
        # encoder is deliberately left unfitted and still shipped in the
        # package to keep its three-element layout.
        # NOTE(review): confirm the scoring code skips ohe.transform when
        # there are no categorical columns.
        categorical_ohe = np.empty((len(X), 0))

    # Final matrix layout: one-hot columns first, then the numeric columns.
    non_categorical = X.select_dtypes(exclude=["object", "string"])
    train_x = np.concatenate(
        (categorical_ohe, non_categorical.to_numpy(dtype=float)), axis=1
    )

    X_train, X_test, y_train, y_test = train_test_split(
        train_x, y, test_size=0.2, random_state=42
    )

    xg_train = xgboost.DMatrix(X_train, label=y_train)
    xg_test = xgboost.DMatrix(X_test, label=y_test)
    params = {
        "objective": "binary:logistic",
        "max_depth": 3,
        "nthread": 1,
        "eval_metric": "auc",
    }
    evals_result = {}
    # Keyword arguments: `evals` is keyword-only in current XGBoost releases
    # and passing it positionally is deprecated.
    model = xgboost.train(
        params,
        xg_train,
        num_boost_round=20,
        evals=[(xg_train, "train"), (xg_test, "test")],
        evals_result=evals_result,
    )

    # Reported AUCs are the best value seen across boosting rounds, which is
    # not necessarily the value at the final round that is saved.
    train_auc = float(np.max(evals_result["train"]["auc"]))
    test_auc = float(np.max(evals_result["test"]["auc"]))
    print(f"Training complete. Train AUC: {train_auc:.4f}, Test AUC: {test_auc:.4f}")

    # Bundle everything needed to rebuild the feature matrix at scoring time
    # and store it as base64 text so it fits in a string column.
    model_package = [model, ohe, feature_cols]
    serialized = codecs.encode(pickle.dumps(model_package), "base64").decode()

    model_df = session.create_dataframe(
        [{"MODEL_ID": "lal_model_v1", "MODEL_DATA": serialized}]
    )
    model_df.write.save_as_table("cleanroom.lal_mljobs_model", mode="overwrite")
    print("Model saved to cleanroom.lal_mljobs_model")

    # Machine-readable job summary; num_samples counts all publisher rows,
    # not only the joined rows used for training.
    result = {
        "status": "completed",
        "num_samples": len(pub_df),
        "seed_size": len(seed_df),
        "train_auc": train_auc,
        "test_auc": test_auc,
    }
    print(json.dumps(result))

# Run the training job only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
