# Generates sample data for the audience lookalike modeling custom template.
# Creates two tables, one used by the provider, one used by the consumer.
# Run this in both the provider and consumer accounts to try out the sample SQL worksheets.
# Upload this Python worksheet into your provider and consumer accounts to generate sample data
# for the audience lookalike modeling provider and consumer code.
# Set 'Handler = main' and 'Return type = String' in the worksheet settings.
# For details about the features and use cases of this template, see the [Snowflake documentation
# topic for this template](https://docs.snowflake.com/user-guide/cleanrooms/lookalike-audience-modeling-template).
# Add values for placeholders where indicated.

import pandas as pd
import numpy as np
import random
import hashlib
import base64

import snowflake.snowpark

# Target location: set DATABASE_NAME and SCHEMA_NAME to a database and schema
# where you have write privileges (replace the placeholders below).
DATABASE_NAME = "<database_name>"
SCHEMA_NAME = "<schema_name>"

# Names of the tables that main() writes into DATABASE_NAME.SCHEMA_NAME.
PROVIDER_TABLE = "LAL_PROVIDER_FEATURES"
CONSUMER_TABLE = "LAL_CONSUMER_SEED_AUDIENCE"

# Data-generation knobs read by generate_data().
NUM_USERS = 50_000            # users in the provider feature table
CONSUMER_SEED_SIZE = 2_000    # provider users sampled into the consumer seed audience
HIGH_VALUE_THRESHOLD = 2_500  # minimum sales for the HIGH_VALUE segment

def generate_hashed_email(unique_id):
    """Return the base64-encoded SHA-256 digest of a salted synthetic email.

    The address is built from a fixed salt and *unique_id*, so the same id
    always yields the same hash string.
    """
    salt = "my_lal_salt_v2"
    raw_address = f"{salt}_user_{unique_id}@example.com".encode('utf-8')
    return base64.b64encode(hashlib.sha256(raw_address).digest()).decode('utf-8')

def generate_data(seed=42, num_users=None, seed_size=None, high_value_threshold=None):
    """Generate the provider feature table and the consumer seed audience.

    The provider table holds one row per synthetic user with demographic and
    activity features. The consumer table is a random subset of those users
    with a synthetic sales amount that grows with the same features, so a
    lookalike model trained on the provider features has a signal to learn.

    All randomness comes from one generator seeded with ``seed``. This matters
    because the script is run separately in the provider and the consumer
    account: with unseeded randomness each run assigns different features to
    the same hashed email, and the consumer's HIGH_VALUE labels no longer line
    up with the provider's features.

    Args:
        seed: Seed for the random generator. Use the same value in both
            accounts so the two tables agree. ``None`` produces a fresh,
            non-reproducible dataset.
        num_users: Number of provider users; defaults to ``NUM_USERS``.
        seed_size: Number of users in the consumer seed audience; defaults to
            ``CONSUMER_SEED_SIZE`` and is capped at the number of users.
        high_value_threshold: Minimum sales for the ``HIGH_VALUE`` segment;
            defaults to ``HIGH_VALUE_THRESHOLD``.

    Returns:
        A ``(provider_df, consumer_df)`` tuple of pandas DataFrames.
    """
    # Resolve defaults at call time so the module-level constants stay the
    # single source of truth.
    num_users = NUM_USERS if num_users is None else num_users
    seed_size = CONSUMER_SEED_SIZE if seed_size is None else seed_size
    if high_value_threshold is None:
        high_value_threshold = HIGH_VALUE_THRESHOLD

    rng = np.random.default_rng(seed)

    print(f"Generating data for {num_users} total users...")
    provider_df = _generate_provider_features(num_users, rng)
    print(f"Generated {len(provider_df)} provider feature records.")

    consumer_df = _generate_consumer_seed(provider_df, seed_size, high_value_threshold, rng)
    print(f"Generated {len(consumer_df)} consumer seed audience records.")

    return provider_df, consumer_df


def _generate_provider_features(num_users, rng):
    """Build the provider feature table with one row per RID ``0..num_users-1``."""
    statuses = ['MEMBER', 'FORMER_MEMBER', 'NON_MEMBER']
    age_bands = [1, 2, 3, 4, 5]  # e.g., 18-24, 25-34, etc.
    regions = [f'REGION_CODE_{i}' for i in range(1, 11)]

    region = rng.choice(regions, size=num_users)
    age = rng.choice(age_bands, size=num_users)
    status = rng.choice(statuses, size=num_users)

    # Users in REGION_CODE_1 or age band 3 are more active. The consumer sales
    # below reward the same traits, which is the pattern the model should find.
    engaged = (region == 'REGION_CODE_1') | (age == 3)
    days_active = np.where(
        engaged,
        rng.normal(150, 40, num_users),
        rng.normal(80, 20, num_users),
    ).astype(np.int64)  # truncates toward zero, like int()

    return pd.DataFrame({
        'RID': np.arange(num_users),
        'HASHED_EMAIL': [generate_hashed_email(i) for i in range(num_users)],
        'STATUS': status,
        'AGE_BAND': age,
        'REGION_CODE': region,
        'DAYS_ACTIVE': np.maximum(1, days_active),
    })


def _generate_consumer_seed(provider_df, seed_size, high_value_threshold, rng):
    """Sample seed users from *provider_df* and attach synthetic sales labels."""
    # Sample without replacement; cap the size so a small provider table does
    # not make the sampling fail.
    size = min(max(seed_size, 0), len(provider_df))
    picked = rng.choice(len(provider_df), size=size, replace=False)
    seed_users = provider_df.iloc[picked]

    rid = seed_users['RID'].to_numpy()
    region = seed_users['REGION_CODE'].to_numpy()
    age = seed_users['AGE_BAND'].to_numpy()
    days_active = seed_users['DAYS_ACTIVE'].to_numpy()

    def _normal_int(mean, std):
        # One normal draw per seed user, truncated like int(np.random.normal(...)).
        return rng.normal(mean, std, size).astype(np.int64)

    sales = _normal_int(500, 100)
    sales += np.where(region == 'REGION_CODE_1', _normal_int(1500, 200), 0)
    sales += np.where(age == 3, _normal_int(1000, 200), 0)
    sales += np.where(days_active > 120, _normal_int(300, 50), 0)

    # The segment is decided on the raw sales value; SALES_DLR is clipped to >= 1.
    segment = np.where(sales >= high_value_threshold, 'HIGH_VALUE', 'LOW_VALUE')

    return pd.DataFrame({
        'RID': rid,
        'EMAIL_ID': rid % 100,
        'HASHED_EMAIL': seed_users['HASHED_EMAIL'].to_numpy(),
        'SALES_DLR': np.maximum(1, sales),
        'AUDIENCE_SEGMENT': segment,
    })

def main(session: snowflake.snowpark.Session):
    """Worksheet handler: generate the sample data and write both tables.

    Writes PROVIDER_TABLE and CONSUMER_TABLE into DATABASE_NAME.SCHEMA_NAME.
    Missing tables are auto-created and existing ones are overwritten.

    Args:
        session: The Snowpark session passed in by the Python worksheet.

    Returns:
        A status string for the worksheet result. Failures are reported in the
        returned string (and printed) rather than raised.
    """
    # Fail fast with an actionable message if the placeholders at the top of
    # the file were not replaced. Otherwise the problem only surfaces later,
    # as a less obvious error from write_pandas.
    placeholders = [
        value for value in (DATABASE_NAME, SCHEMA_NAME)
        if value.startswith("<") and value.endswith(">")
    ]
    if placeholders:
        return (
            "Failed with error: replace the placeholder(s) "
            f"{', '.join(placeholders)} in DATABASE_NAME / SCHEMA_NAME with a "
            "database and schema where you have write privileges."
        )

    try:
        provider_df, consumer_df = generate_data()
        provider_fqn = _write_table(session, provider_df, PROVIDER_TABLE, "provider")
        consumer_fqn = _write_table(session, consumer_df, CONSUMER_TABLE, "consumer")
    except Exception as e:
        # Top-level handler boundary: surface the error as the worksheet
        # result instead of an unhandled stack trace.
        print("\n--- ERROR during Snowpark execution ---")
        print(e)
        return f"Failed with error: {e}"

    return f"Successfully created tables: {provider_fqn} and {consumer_fqn}"


def _write_table(session, df, table_name, label):
    """Write *df* to DATABASE_NAME.SCHEMA_NAME.<table_name>; return that name.

    The table is auto-created if missing and overwritten if it exists.
    *label* is only used in the progress messages.
    """
    fqn = f"{DATABASE_NAME}.{SCHEMA_NAME}.{table_name}"
    print(f"Writing {label} data to {fqn}...")
    session.write_pandas(
        df,
        table_name,
        database=DATABASE_NAME,
        schema=SCHEMA_NAME,
        auto_create_table=True,
        overwrite=True,
    )
    print(f"Success! Created {table_name}.")
    return fqn
