# Generates sample data for the inventory forecasting custom template.
# Creates two tables: one used by the provider and one used by the consumer.
# Run this in both accounts to try out the sample SQL worksheets.
# Upload this code as a Python worksheet into your provider and consumer accounts to generate the sample data
# used by the inventory forecasting provider and consumer code.
# In the worksheet settings, set 'Handler = main' and 'Return type = String'.
# For details about the features and use cases of this template, see the [Snowflake documentation
# topic for this template](https://docs.snowflake.com/user-guide/cleanrooms/inventory-forecasting-template).
# Replace the placeholder values where indicated (DATABASE_NAME and SCHEMA_NAME below).

import pandas as pd
import numpy as np
from datetime import datetime, timedelta

import snowflake.snowpark

# Set DATABASE_NAME and SCHEMA_NAME to a database and schema where you have write privileges.
DATABASE_NAME = "<database_name>"
SCHEMA_NAME = "<schema_name>"

# Names of the tables written into DATABASE_NAME.SCHEMA_NAME.
PROVIDER_TABLE = "INVENTORY_PROVIDER_SALES_HISTORY"
CONSUMER_TABLE = "INVENTORY_CONSUMER_STOCK_LEVELS"

# Size of the generated data set: every product is stocked in every store.
NUM_PRODUCTS = 500
NUM_STORES = 10
HISTORY_DAYS = 2 * 365  # two years of daily sales history

# Demand model parameters.
BASE_SALES_MEAN = 50  # mean of the daily baseline units sold, before any lift
PROMO_LIFT_MULTIPLIER = 3.5  # multiplier applied on promotion days
WEEKEND_LIFT_MULTIPLIER = 1.5  # multiplier applied on weekend days without a promotion

def generate_data(*, seed=None):
    """Generate related provider sales history and consumer stock levels.

    Every PRODUCT_ID x STORE_ID combination (NUM_PRODUCTS x NUM_STORES, in
    product-major order) appears in both outputs:

    * provider_df -- one row per combination per day for the HISTORY_DAYS days
      ending yesterday. Columns: PRODUCT_ID, STORE_ID, SALES_DATE (ISO date
      string), UNITS_SOLD (int) and WAS_ON_PROMOTION (bool). Daily baseline
      units are drawn from N(BASE_SALES_MEAN, 10), truncated toward zero and
      clipped at 0. They are then multiplied by PROMO_LIFT_MULTIPLIER on
      promotion days (the 20th of each month), or else by
      WEEKEND_LIFT_MULTIPLIER on Saturdays and Sundays, and truncated again.
    * consumer_df -- one row per combination. Columns: PRODUCT_ID, STORE_ID,
      CURRENT_INVENTORY (int drawn from N(BASE_SALES_MEAN * 7, 20)) and
      UPCOMING_PROMOTION_DATE (ISO date string, 14 days after today).

    Args:
        seed: Optional seed for the random generator. Pass an int to get
            reproducible data; ``None`` (the default) uses fresh entropy.

    Returns:
        A ``(provider_df, consumer_df)`` tuple of pandas DataFrames.
    """
    rng = np.random.default_rng(seed)
    # Read the clock once so the history window and the promotion date are
    # anchored to the same day, even if generation runs across midnight.
    today = datetime.now().date()

    print(f"Generating for {NUM_PRODUCTS} products across {NUM_STORES} stores...")

    # Master list of product/store combinations, product-major:
    # (PROD_0001, STORE_001), (PROD_0001, STORE_002), ..., (PROD_0002, STORE_001), ...
    product_ids = [f'PROD_{i:04d}' for i in range(1, NUM_PRODUCTS + 1)]
    store_ids = [f'STORE_{i:03d}' for i in range(1, NUM_STORES + 1)]
    pair_products = np.repeat(np.array(product_ids, dtype=object), NUM_STORES)
    pair_stores = np.tile(np.array(store_ids, dtype=object), NUM_PRODUCTS)
    num_pairs = len(pair_products)

    print(f"Generating {HISTORY_DAYS} days of sales history...")
    base_date = today - timedelta(days=HISTORY_DAYS)
    dates = [base_date + timedelta(days=x) for x in range(HISTORY_DAYS)]
    sales_dates = np.array([d.isoformat() for d in dates], dtype=object)
    # Promotions fall on the 20th of each month (no month has a 40th day).
    on_promo = np.array([d.day % 20 == 0 for d in dates], dtype=bool)
    on_weekend = np.array([d.weekday() >= 5 for d in dates], dtype=bool)
    # The promotion lift takes precedence over the weekend lift; they never stack.
    lift = np.where(
        on_promo,
        PROMO_LIFT_MULTIPLIER,
        np.where(on_weekend, WEEKEND_LIFT_MULTIPLIER, 1.0),
    )

    # One draw per (combination, day). astype(int64) truncates toward zero
    # like int(), then negatives are clipped to 0.
    draws = rng.normal(BASE_SALES_MEAN, 10, size=(num_pairs, HISTORY_DAYS))
    base_sales = np.maximum(0, draws.astype(np.int64))
    # lift broadcasts across the day axis; truncate the lifted value to an int.
    units_sold = (base_sales * lift).astype(np.int64)

    # Row-major ravel() keeps each combination's days contiguous, matching
    # repeat() for the id columns and tile() for the per-day columns.
    provider_df = pd.DataFrame({
        'PRODUCT_ID': np.repeat(pair_products, HISTORY_DAYS),
        'STORE_ID': np.repeat(pair_stores, HISTORY_DAYS),
        'SALES_DATE': np.tile(sales_dates, num_pairs),
        'UNITS_SOLD': units_sold.ravel(),
        'WAS_ON_PROMOTION': np.tile(on_promo, num_pairs),
    })
    print(f"Generated {len(provider_df)} historical sales records.")

    print("Generating current consumer inventory and promo plan...")
    promo_date = today + timedelta(days=14)
    current_inventory = rng.normal(BASE_SALES_MEAN * 7, 20, size=num_pairs).astype(np.int64)

    consumer_df = pd.DataFrame({
        'PRODUCT_ID': pair_products,
        'STORE_ID': pair_stores,
        'CURRENT_INVENTORY': current_inventory,
        # Every combination shares the same upcoming promotion date.
        'UPCOMING_PROMOTION_DATE': promo_date.isoformat(),
    })
    print(f"Generated {len(consumer_df)} consumer inventory records.")

    return provider_df, consumer_df

def _write_table(session, df, table_name, label):
    """Write *df* to DATABASE_NAME.SCHEMA_NAME.<table_name>, replacing any existing table.

    *label* ("provider" / "consumer") is used only in the progress messages.
    auto_create_table=True creates the table from the DataFrame's columns, and
    overwrite=True replaces an existing table, so the script can be re-run.
    """
    print(f"Writing {label} data to {DATABASE_NAME}.{SCHEMA_NAME}.{table_name}...")
    session.write_pandas(
        df,
        table_name,
        database=DATABASE_NAME,
        schema=SCHEMA_NAME,
        auto_create_table=True,
        overwrite=True,
    )
    print(f"Success! Created {table_name}.")


def main(session: snowflake.snowpark.Session):
    """Worksheet handler: generate the sample data and write both tables.

    Args:
        session: The active Snowpark session supplied by the worksheet.

    Returns:
        A human-readable status string. Errors are caught and reported in the
        returned string instead of being raised, so the worksheet always shows
        a result message.
    """
    import traceback

    # Fail fast with an actionable message if the placeholders were never filled in;
    # otherwise write_pandas fails later with a much less obvious identifier error.
    unset = [
        name
        for name, value in (("DATABASE_NAME", DATABASE_NAME), ("SCHEMA_NAME", SCHEMA_NAME))
        if value.startswith("<") and value.endswith(">")
    ]
    if unset:
        return (
            f"Failed with error: replace the placeholder value of {' and '.join(unset)} "
            "with a database/schema where you have write privileges."
        )

    try:
        provider_df, consumer_df = generate_data()

        _write_table(session, provider_df, PROVIDER_TABLE, "provider")
        _write_table(session, consumer_df, CONSUMER_TABLE, "consumer")

        return f"Successfully created tables: {DATABASE_NAME}.{SCHEMA_NAME}.{PROVIDER_TABLE} and {DATABASE_NAME}.{SCHEMA_NAME}.{CONSUMER_TABLE}"

    except Exception as e:
        # Deliberate top-level boundary for the worksheet: print the full
        # traceback for debugging and report the failure as the result string.
        print("\n--- ERROR during Snowpark execution ---")
        print(traceback.format_exc())
        return f"Failed with error: {e}"
