# Generates sample data for the lookalike audience modeling collaboration.
# Creates two tables, one used by the publisher (provider features), one used by the advertiser (seed audience).
# Run this in both accounts to try out the sample SQL worksheets.
# Upload this Python worksheet into your publisher and advertiser accounts to generate sample data.
# In the worksheet settings, set 'Handler = main' (it must match the function name exactly) and 'Return type = String'.
# For details about the features and use cases of this template, see the Snowflake documentation
# topic: https://docs.snowflake.com/user-guide/cleanrooms/collab-lookalike-modeling
# Replace values in angle brackets with your own values.

import pandas as pd
import numpy as np
import random
import hashlib
import base64

import snowflake.snowpark

# Target location for the generated tables: replace the angle-bracket placeholders with a
# database and schema where you have write privileges.
DATABASE_NAME = "<database_name>"
SCHEMA_NAME = "<schema_name>"

# Names of the two tables that main() writes.
PUBLISHER_TABLE = "LAL_PUBLISHER_FEATURES"
ADVERTISER_TABLE = "LAL_ADVERTISER_SEED_AUDIENCE"

# Population size, and how many of those users are sampled into the advertiser's seed audience.
NUM_USERS = 50_000
SEED_SIZE = 2_000
# Seed-audience users whose simulated sales reach this value are labeled 'HIGH_VALUE'.
HIGH_VALUE_THRESHOLD = 2_500

def generate_hashed_email(unique_id):
    """Return a synthetic hashed email for ``unique_id``.

    Builds the address ``"<salt>_user_<unique_id>@example.com"`` using a fixed salt, hashes its
    UTF-8 bytes with SHA-256, and returns the base64-encoded digest as a str (44 characters).
    The function uses no randomness, so a given ``unique_id`` always yields the same value.
    """
    salt_prefix = "my_lal_collab_salt_v1"
    address_bytes = f"{salt_prefix}_user_{unique_id}@example.com".encode('utf-8')
    return base64.b64encode(hashlib.sha256(address_bytes).digest()).decode('utf-8')

def _simulate_publisher_features(num_users, py_rng, np_rng):
    """Return a DataFrame with one synthetic publisher-feature row per RID in ``range(num_users)``."""
    statuses = ['MEMBER', 'FORMER_MEMBER', 'NON_MEMBER']
    age_bands = [1, 2, 3, 4, 5]
    regions = [f'REGION_CODE_{i}' for i in range(1, 11)]

    rows = []
    for rid in range(num_users):
        region = py_rng.choice(regions)
        age = py_rng.choice(age_bands)

        # REGION_CODE_1 / AGE_BAND 3 users draw DAYS_ACTIVE from a higher distribution; the same
        # groups also get boosted sales in the seed audience, so these features correlate with
        # AUDIENCE_SEGMENT.
        if region == 'REGION_CODE_1' or age == 3:
            days_active = int(np_rng.normal(150, 40))
        else:
            days_active = int(np_rng.normal(80, 20))

        rows.append({
            'RID': rid,
            'HASHED_EMAIL': generate_hashed_email(rid),
            'STATUS': py_rng.choice(statuses),
            'AGE_BAND': age,
            'REGION_CODE': region,
            # Normal draws can be zero or negative; clamp to at least one day.
            'DAYS_ACTIVE': max(1, days_active),
        })
    return pd.DataFrame(rows)


def _simulate_seed_audience(publisher_df, seed_size, np_rng):
    """Sample ``seed_size`` publisher users and simulate advertiser sales and segment for each."""
    # Passing the seeded RandomState makes the chosen seed users reproducible as well.
    seed_users_df = publisher_df.sample(n=seed_size, random_state=np_rng)

    rows = []
    for _, row in seed_users_df.iterrows():
        sales = int(np_rng.normal(500, 100))
        if row['REGION_CODE'] == 'REGION_CODE_1':
            sales += int(np_rng.normal(1500, 200))
        if row['AGE_BAND'] == 3:
            sales += int(np_rng.normal(1000, 200))
        if row['DAYS_ACTIVE'] > 120:
            sales += int(np_rng.normal(300, 50))

        segment = 'HIGH_VALUE' if sales >= HIGH_VALUE_THRESHOLD else 'LOW_VALUE'

        rows.append({
            'RID': row['RID'],
            'HASHED_EMAIL': row['HASHED_EMAIL'],
            'SALES_DLR': max(1, sales),
            'AUDIENCE_SEGMENT': segment,
        })
    return pd.DataFrame(rows)


def generate_data(random_seed=42):
    """Generate the publisher feature table and the advertiser seed-audience table.

    The publisher table has NUM_USERS rows (RID, HASHED_EMAIL, STATUS, AGE_BAND, REGION_CODE,
    DAYS_ACTIVE). The advertiser table has SEED_SIZE rows sampled from it (RID, HASHED_EMAIL,
    SALES_DLR, AUDIENCE_SEGMENT). Simulated sales depend on the sampled user's features, and a
    user is labeled 'HIGH_VALUE' when sales >= HIGH_VALUE_THRESHOLD.

    Args:
        random_seed: Seed for every random draw (feature values, seed-audience sample, sales).
            This script is run separately in the publisher and advertiser accounts. With a fixed
            seed both runs produce identical data, so the advertiser's sales and segments are
            derived from the same feature values the publisher stores. Without a fixed seed the
            two accounts' tables would not agree. Both accounts must use the same seed (and,
            presumably, compatible Python/NumPy versions). Pass ``None`` for non-deterministic
            output.

    Returns:
        ``(publisher_df, advertiser_df)`` as pandas DataFrames.

    Raises:
        ValueError: from ``DataFrame.sample`` if SEED_SIZE exceeds NUM_USERS.
    """
    print(f"Generating data for {NUM_USERS} total users...")

    # Local generators instead of the global `random` / `np.random` state: seeding them makes
    # runs reproducible without side effects on other code. RandomState's stream is kept stable
    # across NumPy releases.
    py_rng = random.Random(random_seed)
    np_rng = np.random.RandomState(random_seed)

    publisher_df = _simulate_publisher_features(NUM_USERS, py_rng, np_rng)
    print(f"Generated {len(publisher_df)} publisher feature records.")

    advertiser_df = _simulate_seed_audience(publisher_df, SEED_SIZE, np_rng)
    print(f"Generated {len(advertiser_df)} advertiser seed audience records.")

    return publisher_df, advertiser_df

def main(session: snowflake.snowpark.Session):
    """Worksheet handler: generate the sample data and write both tables.

    Writes the publisher DataFrame to PUBLISHER_TABLE and the advertiser DataFrame to
    ADVERTISER_TABLE in DATABASE_NAME.SCHEMA_NAME via ``session.write_pandas``, creating each
    table if needed and overwriting any existing contents.

    Args:
        session: The Snowpark session passed in by the worksheet runtime.

    Returns:
        A status string. Failures are reported in the returned string ("Failed with error: ...")
        rather than raised, so the worksheet shows the message as its result.
    """
    import traceback

    # Fail fast with an actionable message if the angle-bracket placeholders were left in place.
    if "<" in DATABASE_NAME or "<" in SCHEMA_NAME:
        return (
            "Failed with error: replace the <database_name> and <schema_name> placeholders in "
            "DATABASE_NAME and SCHEMA_NAME with a database and schema where you have write "
            "privileges."
        )

    try:
        publisher_df, advertiser_df = generate_data()

        targets = (
            ("publisher", publisher_df, PUBLISHER_TABLE),
            ("advertiser", advertiser_df, ADVERTISER_TABLE),
        )
        for label, df, table in targets:
            print(f"Writing {label} data to {DATABASE_NAME}.{SCHEMA_NAME}.{table}...")
            # NOTE(review): write_pandas quotes identifiers by default (quote_identifiers=True),
            # so DATABASE_NAME / SCHEMA_NAME are likely matched case-sensitively; use the exact
            # (usually uppercase) names. Confirm against the Snowpark docs.
            session.write_pandas(
                df,
                table,
                database=DATABASE_NAME,
                schema=SCHEMA_NAME,
                auto_create_table=True,
                overwrite=True,
            )
            print(f"Success! Created {table}.")

        return (
            f"Successfully created tables: {DATABASE_NAME}.{SCHEMA_NAME}.{PUBLISHER_TABLE} "
            f"and {DATABASE_NAME}.{SCHEMA_NAME}.{ADVERTISER_TABLE}"
        )

    except Exception as e:
        # Top-level boundary of the worksheet: report instead of raising, but keep the traceback.
        print("\n--- ERROR during Snowpark execution ---")
        print(traceback.format_exc())
        return f"Failed with error: {e}"
