"""
Prediction Script for PyOghma ML Models

This script orchestrates the prediction process for experimental data using trained
neural network models. It supports multiple device types, experiments, and network
architectures for comprehensive analysis of photovoltaic device performance.

Usage:
    python Predict.py

The script will:
    1. Load experimental data files for specified device types and experiments
    2. Map device types to appropriate network materials
    3. Find and iterate through trained simulation directories
    4. Generate predictions using SubPredict.py or SubPredict_SunsVoc.py
    5. Save results in organized CSV format

Supported Experiments:
    - light_dark_jv: Light and dark J-V curve analysis
    - sunsvoc: Suns-Voc measurements with varying illumination intensities

Supported Device Types:
    - PM6BTPeC9: PM6:BTP-eC9 organic solar cells
    - PM6Y12_ZnO: PM6:Y12 with ZnO interlayer
    - PM6Y12_PEIZn: PM6:Y12 with PEI-Zn interlayer
    - W108-2: Reference device type

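Directory Structure:
    Results/
    ├── Point/
    ├── Difference/
    └── Residual/
        └── <experimental_file_name>.csv

    Output folders are named after the network type (the first "_"-separated
    token of the simulation directory name), suffixed with "_<n>" (e.g.
    Residual_4) when that directory name ends in 2, 4 or 8. The script
    currently only uses simulation directories whose names start with
    "Residual_".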

Author: Cai Williams
Email: cai.williams@physik.tu-chemnitz.de
"""

import json
import os
import subprocess
import sys

from natsort import natsorted


def get_experimental(device_type, directory, exp):
    """
    Get a naturally sorted list of experimental data files for a device type and experiment.

    The directory ``<directory>/<device_type>/<exp>`` is scanned. Only entries
    whose names end with the suffix associated with ``exp`` are kept:
    ``'am15.dat'`` for ``'light_dark_jv'``, ``'0000000uIllu_IV.dat'`` for
    ``'sunsvoc'``, and every entry for any other experiment name.

    Args:
        device_type (str): Type of device (e.g., 'PM6Y12_ZnO', 'PM6BTPeC9')
        directory (str): Base directory path containing experimental data
        exp (str): Experiment type ('light_dark_jv' or 'sunsvoc')

    Returns:
        list[str]: Naturally sorted paths of the matching files. Each path is
        the experiment directory joined with the file name. It is therefore
        absolute only when ``directory`` is absolute.

    Raises:
        FileNotFoundError: If the experiment directory does not exist or is
            not a directory.

    Example (illustrative paths):
        >>> files = get_experimental('PM6Y12_ZnO', '/data/exp', 'light_dark_jv')
        >>> print(files)
        ['/data/exp/PM6Y12_ZnO/light_dark_jv/device1_am15.dat',
         '/data/exp/PM6Y12_ZnO/light_dark_jv/device2_am15.dat']
    """
    exp_dir = os.path.join(directory, device_type, exp)

    # Use isdir rather than exists. A regular file at this path would pass an
    # exists() check, and os.listdir would then raise NotADirectoryError
    # instead of the documented FileNotFoundError.
    if not os.path.isdir(exp_dir):
        raise FileNotFoundError(f"Experiment directory not found: {exp_dir}")

    # File-name suffix that identifies the relevant measurement files.
    endings = {
        "light_dark_jv": 'am15.dat',  # AM1.5 illumination files
        "sunsvoc": '0000000uIllu_IV.dat',  # Base illumination files for Suns-Voc
    }
    ending = endings.get(exp, '')  # '' matches every entry

    files = [os.path.join(exp_dir, f) for f in os.listdir(exp_dir) if f.endswith(ending)]
    # Natural sorting gives proper numerical order (e.g. "device10" after "device9").
    return natsorted(files)


def get_illumination_files(base_file, extra_info):
    """
    Build the Suns-Voc file list that belongs to the same measurement as ``base_file``.

    The last two underscore-separated tokens of ``base_file`` (e.g.
    ``0000000uIllu`` and ``IV.dat``) are dropped. The remaining stem, plus a
    trailing underscore, is combined with every intensity suffix of the
    requested configuration.

    Args:
        base_file (str): Base experimental file path, typically ending in
            '0000000uIllu_IV.dat'.
        extra_info (str): Number of illumination levels ('2', '4' or '8').

    Returns:
        list[str]: One file path per illumination level, in the fixed order
        of the configuration.

    Raises:
        ValueError: If ``extra_info`` is not one of the supported configurations.
    """
    head_tokens = base_file.split('_')[:-2]
    stem = '_'.join(head_tokens) + '_'

    suffix_table = {
        '2': ('0000000uIllu_IV.dat', '1000000uIllu_IV.dat'),
        '4': ('0000000uIllu_IV.dat', '0160000uIllu_IV.dat',
              '0320000uIllu_IV.dat', '1000000uIllu_IV.dat'),
        '8': ('0000000uIllu_IV.dat', '0100000uIllu_IV.dat', '0160000uIllu_IV.dat',
              '0200000uIllu_IV.dat', '0250000uIllu_IV.dat', '0320000uIllu_IV.dat',
              '0630000uIllu_IV.dat', '1000000uIllu_IV.dat'),
    }

    suffixes = suffix_table.get(extra_info)
    if suffixes is None:
        raise ValueError(f"Unsupported illumination configuration: {extra_info}")

    return [f"{stem}{suffix}" for suffix in suffixes]


def map_device_to_material(device_type):
    """
    Translate a device-type label into the material tag used in network directory names.

    Args:
        device_type (str): Device type identifier, e.g. 'PM6Y12_ZnO'.

    Returns:
        str: Network material identifier, e.g. 'pm6y12'.

    Raises:
        ValueError: If ``device_type`` has no known material.
    """
    materials_by_device = {
        'PM6BTPeC9': 'pm6ec9',
        'W108-2': 'pm6ec9',
        'PM6Y12_ZnO': 'pm6y12',
        'PM6Y12_PEIZn': 'pm6y12_peizn',
    }

    material = materials_by_device.get(device_type)
    if material is None:
        raise ValueError(f"Unknown device type: {device_type}")
    return material


def select_sim_dirs(dir_names, net_mat, known_materials=()):
    """
    Return the entries of ``dir_names`` that are 'Residual_' networks for ``net_mat``.

    A plain substring test is not enough: 'pm6y12' is contained in
    'pm6y12_peizn', so PEI-Zn networks would otherwise also be selected for
    the ZnO device. Any name that additionally contains a longer known
    material tag which itself embeds ``net_mat`` is therefore rejected.

    Args:
        dir_names (Iterable[str]): Directory names (not paths) to filter.
        net_mat (str): Material tag to match.
        known_materials (Iterable[str]): All material tags in use. This is
            used only to reject names belonging to a longer, overlapping tag.

    Returns:
        list[str]: Matching names, in their input order.
    """
    # Longer tags that embed net_mat; a directory naming one of these belongs
    # to that other material.
    longer_tags = [tag for tag in known_materials if tag != net_mat and net_mat in tag]
    return [
        name
        for name in dir_names
        if name.startswith('Residual_')
        and net_mat in name
        and not any(tag in name for tag in longer_tags)
    ]


if __name__ == "__main__":
    # Configuration: Define simulation and experimental data directories
    sim_dir_og = os.path.join('/', 'media', 'cai', 'Big', 'Simulated_Data', 'opkm', 'Networks_Extended')
    base_exp_dir = os.path.join('/', 'media', 'cai', 'Big', 'Experimental_Data', 'Data_From_Chen')
    res_dir = os.path.join(os.getcwd(), 'Results')

    # Processing configuration
    device_types = ["PM6Y12_ZnO"]  # Add more device types: ["PM6BTPeC9", "W108-2"]
    experiment_type = "light_dark_jv"  # Options: "light_dark_jv", "sunsvoc"

    # Fail fast on a misconfigured experiment type. Otherwise the same error
    # would be reported once per simulation directory.
    if experiment_type not in ("light_dark_jv", "sunsvoc"):
        raise ValueError(f"Unknown experiment type: {experiment_type}")

    # Material tags of every supported device. They are used to tell
    # overlapping tags such as 'pm6y12' and 'pm6y12_peizn' apart when matching
    # simulation directories. Keep this list in sync with map_device_to_material.
    known_materials = {
        map_device_to_material(name)
        for name in ("PM6BTPeC9", "W108-2", "PM6Y12_ZnO", "PM6Y12_PEIZn")
    }

    print("Starting prediction pipeline...")
    print(f"Simulation directory: {sim_dir_og}")
    print(f"Experimental directory: {base_exp_dir}")
    print(f"Results directory: {res_dir}")
    print(f"Processing device types: {device_types}")
    print(f"Experiment type: {experiment_type}")

    for device_idx, device_type in enumerate(device_types, 1):
        print(f"\n=== Processing device {device_idx}/{len(device_types)}: {device_type} ===")

        try:
            # Map device type to network material
            net_mat = map_device_to_material(device_type)
            print(f"Mapped to network material: {net_mat}")

            # Get experimental files for this device type
            exp_files = get_experimental(device_type, base_exp_dir, experiment_type)
            print(f"Found {len(exp_files)} experimental files")

            if not exp_files:
                print(f"No experimental files found for {device_type}. Skipping...")
                continue

        except Exception as e:
            print(f"Error processing device {device_type}: {e}")
            continue

        # The matching simulation directories depend only on the material, so
        # list them once per device rather than once per experimental file.
        try:
            sim_names = select_sim_dirs(os.listdir(sim_dir_og), net_mat, known_materials)
        except OSError as e:
            print(f"Error finding simulation directories: {e}")
            continue
        sim_dirs = [os.path.join(sim_dir_og, d) for d in natsorted(sim_names)]
        print(f"Found {len(sim_dirs)} matching simulation directories")

        # Process each experimental file
        for file_idx, exp_file in enumerate(exp_files, 1):
            file_name = os.path.basename(exp_file).split('.')[0]
            print(f"\n--- Processing file {file_idx}/{len(exp_files)}: {file_name} ---")

            # Process each simulation directory
            for sim_idx, sim_dir in enumerate(sim_dirs, 1):
                sim_name = os.path.basename(sim_dir)
                print(f"    Processing simulation {sim_idx}/{len(sim_dirs)}: {sim_name}")

                training_dir = os.path.join(sim_dir, 'training')
                conversion_dir = os.path.join(sim_dir, 'conversion')
                name_parts = sim_name.split('_')
                network_type = name_parts[0]
                extra_info = name_parts[-1]

                # Determine output directory name. A trailing 2/4/8 token selects
                # the multi-illumination variant; otherwise '.' is passed on.
                if extra_info not in ('2', '4', '8'):
                    extra_info = '.'
                    output_name = network_type
                else:
                    output_name = f"{network_type}_{extra_info}"

                # Check if output already exists
                output_file = os.path.join(res_dir, output_name, f"{file_name}.csv")

                if os.path.isfile(output_file):
                    print(f"    ✓ Output {file_name}.csv already exists in {output_name}. Skipping...")
                    continue

                print(f"    → Running prediction for {file_name} with network type {output_name}")

                try:
                    if experiment_type == 'light_dark_jv':
                        # Run SubPredict for J-V experiment
                        command = [
                            sys.executable, "SubPredict.py", training_dir, conversion_dir,
                            exp_file, res_dir, network_type, "JV", "Deibel", extra_info,
                        ]
                        success_msg = f"    ✓ Successfully processed {file_name} with {output_name}"
                    else:
                        # Process Suns-Voc experiment with multiple illumination files
                        if extra_info in ('2', '4', '8'):
                            illumination_files = get_illumination_files(exp_file, extra_info)
                            file_list_str = str(illumination_files)
                        else:
                            file_list_str = exp_file
                        command = [
                            sys.executable, "SubPredict_SunsVoc.py", training_dir, conversion_dir,
                            file_list_str, res_dir, network_type, "JV_I4", "Deibel", extra_info,
                        ]
                        success_msg = (
                            f"    ✓ Successfully processed Suns-Voc for {file_name} with {output_name}"
                        )

                    # sys.executable runs the helper with this script's own
                    # interpreter and virtualenv. A bare "python" would use
                    # whatever happens to be first on PATH.
                    subprocess.run(command, check=True, capture_output=True, text=True)
                    print(success_msg)

                except subprocess.CalledProcessError as e:
                    print(f"    ✗ Failed to process {file_name} with {output_name}: {e}")
                    if e.stderr:
                        print(f"    Error details: {e.stderr}")
                except Exception as e:
                    print(f"    ✗ Unexpected error processing {file_name} with {output_name}: {e}")

    print("\n=== Prediction pipeline completed ===")
    print(f"Results saved to: {res_dir}")
