from rtree import index
import numpy as np
import time
import glob
from shapely.geometry import Polygon, Point
import csv
import pandas as pd
import xarray as xr
import geopandas as gpd
import os
from scipy.interpolate import interp1d
import random

# Randall competency curve: settlement probability P_t tabulated against larval
# age. A linear interpolator over the table is built once at import time; outside
# the tabulated ages it extrapolates linearly, so returned values may fall
# outside [0, 1] (comparisons against a uniform draw still behave sensibly).
curve_data = pd.read_csv('/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/Randall_comp_curves/Amil_cca_curve_SB.csv')
larval_age, P_t = (curve_data[column].values for column in ('LarvalAge', 'P_t'))
P_t_interpolator = interp1d(x=larval_age, y=P_t, fill_value="extrapolate")


def _find_polygon(point, idx, geoms, predicate):
    """Return the position of the first R-tree candidate polygon satisfying predicate(point, polygon), else None."""
    return next((j for j in idx.intersection(point.bounds) if predicate(point, geoms[j])), None)


def process_file(file_path, gdf_all, idx, Tmin, num_days=30, num_hours_per_day=24,
                 pathout=None, year=None, max_particles=10000):
    """Count settlement events between reef polygons for each timestep of one tracking file.

    Reads ``lon``/``lat`` from the netCDF file at ``file_path`` (shape
    ``(num_particles, num_timesteps)``). Values of -999 are treated as beached
    and set to NaN. Only particles with a valid position at t=0 are kept, and
    each is assigned the polygon it starts in.

    For every timestep t in ``[Tmin, min(num_days * num_hours_per_day, num_timesteps))``,
    each particle with a valid position settles with probability
    ``P_t_interpolator(t / num_hours_per_day)``. A settling particle lying within a
    polygon increments ``transitions_matrix[source, sink]``. One
    ``(num_polygons, num_polygons)`` matrix is built per timestep. The matrices of
    each complete day are stacked along axis 2 and saved to
    ``{pathout}/transitions_matrix_day{day}_{year}.npy``.

    Args:
        file_path: Path of the particle-tracking netCDF file.
        gdf_all: Reef polygons. Row *positions* must match the ids stored in ``idx``.
        idx: rtree index whose ids are positional indices into ``gdf_all``
            (e.g. built with ``enumerate(gdf_all.geometry)``).
        Tmin: First timestep to process (used to resume an interrupted run).
        num_days: Number of days to process.
        num_hours_per_day: Timesteps per day. Larval age in days is
            ``t / num_hours_per_day``.
        pathout: Output directory. If omitted, the module-level ``pathout``
            (set by the script's ``__main__`` block) is used.
        year: Year tag used in output file names. If omitted, the module-level
            ``year`` is used.
        max_particles: Randomly subsample to at most this many particles
            (a debugging aid). ``None`` uses all particles.

    Returns:
        ``(transitions_matrix, num_timesteps)``: the last per-timestep matrix
        computed (all zeros if no timestep was processed) and the number of
        timesteps in the file.

    Raises:
        ValueError: If a day must be saved but ``pathout``/``year`` are unknown.

    Notes:
        Particles that start outside every polygon have no source and are never counted.
        NOTE(review): a particle is not removed after settling, so it can be
        counted at several timesteps. Confirm this is intended.
    """
    # Backward compatibility: the script's __main__ block defines these as module globals.
    if pathout is None:
        pathout = globals().get('pathout')
    if year is None:
        year = globals().get('year')

    print(f"Starting from Tmin Yobama {Tmin}", flush=True)
    # Load the coordinate arrays, then release the file handle.
    with xr.open_dataset(file_path, mode='r') as data_xarray:
        lon = data_xarray['lon'].values
        lat = data_xarray['lat'].values

    # Count sentinels *before* replacing them, otherwise this always reports 0.
    print(f"-999 count at t=0: {(lon[:, 0] == -999).sum()}, {(lat[:, 0] == -999).sum()}", flush=True)
    lon[lon == -999] = np.nan  # set beached particles to NAN
    lat[lat == -999] = np.nan
    print(f"NaN count at t=0: {np.isnan(lon[:, 0]).sum()}, {np.isnan(lat[:, 0]).sum()}", flush=True)

    # Keep only particles with a valid release position (sentinels are NaN by now).
    valid_particles = ~np.isnan(lon[:, 0]) & ~np.isnan(lat[:, 0])
    lon = lon[valid_particles, :]
    lat = lat[valid_particles, :]

    num_particles = lon.shape[0]
    if max_particles is not None and num_particles > max_particles:
        sampled_indices = np.random.choice(num_particles, max_particles, replace=False)
        lon = lon[sampled_indices, :]
        lat = lat[sampled_indices, :]
    print(f"Debugging with {lon.shape[0]} particles", flush=True)

    num_particles, num_timesteps = lon.shape
    print(f"num of particles {num_particles}", flush=True)

    num_polygons = len(gdf_all)
    # Positional list: R-tree ids are positions, whereas gdf_all.geometry[j] would be a
    # *label* lookup, which is wrong after sort_values without reset_index.
    geoms = list(gdf_all.geometry)

    print("Finding initial polys", flush=True)
    initial_polygons = [
        _find_polygon(Point(lon[i, 0], lat[i, 0]), idx, geoms, Point.intersects)
        for i in range(num_particles)
    ]
    num_with_source = sum(p is not None for p in initial_polygons)
    print(f"Valid particles at t=0 w initial polygons: {num_with_source}", flush=True)

    steps_per_day = num_hours_per_day
    transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)
    all_daily_matrices = []
    last_t = min(num_days * steps_per_day, num_timesteps)

    # Loop over each timestep -- the slowest part of the script
    for t in range(int(Tmin), last_t):
        day = t // steps_per_day
        if t % steps_per_day == 0:
            print(f"Processing day {day}", flush=True)

        transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)
        lon_t = lon[:, t]
        lat_t = lat[:, t]
        # NaN positions are beached or out-of-domain particles; they cannot settle.
        active = np.flatnonzero(~np.isnan(lon_t) & ~np.isnan(lat_t))
        # Settlement probability depends only on larval age, so compute it once per timestep.
        probability = P_t_interpolator(t / steps_per_day)
        # One uniform draw per active particle, in particle order: the same random
        # stream as drawing np.random.rand() inside a per-particle loop.
        settling = active[np.random.rand(active.size) <= probability]
        for i in settling:
            source = initial_polygons[i]
            if source is None:
                # No source polygon. Indexing with None would increment a whole row.
                continue
            sink = _find_polygon(Point(lon_t[i], lat_t[i]), idx, geoms, Point.within)
            if sink is not None:
                transitions_matrix[source, sink] += 1

        print(f"Shape of transitions_matrix for day {day}: {transitions_matrix.shape}")
        all_daily_matrices.append(transitions_matrix)

        # End of a day: stack this day's per-timestep matrices into one 3-D array and save it.
        if (t + 1) % steps_per_day == 0:
            daily_array = np.stack(all_daily_matrices, axis=2)
            print(f"Shape of daily_array for day {day}: {daily_array.shape}")

            if pathout is None or year is None:
                raise ValueError("process_file needs pathout and year to save daily matrices")
            daily_filename = f"{pathout}/transitions_matrix_day{day}_{year}.npy"
            np.save(daily_filename, daily_array)
            print(f"Saved daily transitions matrix for day {day} to {daily_filename}", flush=True)

            all_daily_matrices = []

    if all_daily_matrices:
        print(f"Warning: {len(all_daily_matrices)} timestep(s) of an incomplete day were not saved", flush=True)

    return transitions_matrix, num_timesteps



# Get most recent Tmin so can restart the run from the last outputted day
def get_most_recent_Tmin(pathout, year):
    """Return the timestep at which to resume processing ``year``.

    Scans ``pathout`` for files named ``transitions_matrix_day<D>_<year>.npy``.
    It returns ``(max D + 1) * 24``, the first timestep of the day after the
    latest saved day. Returns 0 when no such file exists. Files whose day field
    is not a plain integer are ignored.
    """
    prefix = 'transitions_matrix_day'
    suffix = f'_{year}.npy'
    # Escape the directory so glob metacharacters in it ([, *, ?) are matched literally.
    pattern = os.path.join(glob.escape(pathout), f'{prefix}*{suffix}')

    saved_days = []
    for path in glob.glob(pattern):
        # Parse only the file name: the directory part may itself contain "day".
        day_field = os.path.basename(path)[len(prefix):-len(suffix)]
        if day_field.isdecimal():
            saved_days.append((int(day_field), path))

    if not saved_days:
        return 0  # If no files are found, start from day 0

    most_recent_day, most_recent_file = max(saved_days)

    # Restart from the first timestep of the next day
    Tmin = (most_recent_day + 1) * 24
    print(f"Most recent file: {most_recent_file}, Tmin set to: {Tmin}")

    return Tmin


# Main to run the code
if __name__ == "__main__":
    # Path to netcdf particle tracking files
    base_path = '/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/SensitivityAnalysis/pout/'
    pathout = '/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/conmat/'

    # Currently running a single year (30 years to run in total)
    years = range(1996, 1997)

    # Reef polygon shapefile
    gdf_all = gpd.read_file("/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/Coral_communities/Regions_w_geom_FINAL_crop.shp")
    # Sort by id, then renumber rows 0..n-1. Without reset_index the index labels keep
    # their pre-sort values, so label lookups such as gdf_all.geometry[j] would not
    # match the positional ids j given to the R-tree below.
    gdf_all = gdf_all.sort_values(by='id').reset_index(drop=True)
    idx = index.Index((j, geom.bounds, None) for j, geom in enumerate(gdf_all.geometry))

    # Process each year sequentially. The module-level names pathout and year set
    # here can be read by process_file when it saves its daily output.
    for year in years:
        Tmin = get_most_recent_Tmin(pathout, year)
        print(f"Using Tmin: {Tmin} for year {year}", flush = True)
        file_path = os.path.join(base_path, f'ParcelsOut_Diffusion_{year}.nc')
        transitions_matrix, num_timesteps = process_file(file_path, gdf_all, idx, Tmin=Tmin)  # Capture both outputs

    print(f"All transition matrices saved individually to {pathout}.")
