"""
Training Script for PyOghma ML Models

This script orchestrates the training of multiple neural network models across
different device configurations and network architectures. It iterates through
simulation directories and trains models using the SubTraining.py script.

Usage:
    python Training.py

The script will:
    1. Scan predefined simulation directories for network configurations
    2. Check for existing trained models to avoid retraining
    3. Launch training processes for missing models using SubTraining.py
    4. Support different network types: Point, Difference, and Residual

Example Directory Structure:
    sim_dir/
    ├── Point_pm6y12/training/faster/nets.json
    ├── Difference_pm6y12/training/faster/nets.json
    └── Residual_pm6y12/training/faster/nets.json

Author: Cai Williams
Email: cai.williams@physik.tu-chemnitz.de
"""

import json
import os
import subprocess
import sys

import rainbow_tqdm
from tqdm import tqdm


def process_dirs(dirs, sim_dir_og, network_type, train_size, extra_arg="1"):
    """
    Process training directories and train missing models.

    For each directory name in ``dirs`` this function:
    1. Builds ``<sim_dir_og>/<dir>/training`` and skips it if it is not a directory
    2. Loads the network keys from ``faster/nets.json`` (skips the directory if
       the file is missing, is not valid JSON, or has no ``'sims'`` object)
    3. Iterates through the network keys in file order
    4. Skips a network whose ``faster/<key>/model.keras`` file already exists
    5. Launches SubTraining.py in a child process for every other network

    Problems are reported with ``print`` and the loop moves on, so one broken
    directory or one failed training run does not abort the remaining work.

    Args:
        dirs (list[str]): List of directory names to process (e.g., ['Point_pm6y12'])
        sim_dir_og (str): Base simulation directory path
        network_type (str): Network type forwarded to SubTraining.py
            (the callers in this file pass 'Residual', 'Difference' or 'Point')
        train_size (int | float): Number of training samples; converted with
            ``int()`` before being forwarded, so values such as ``10e6`` work.
        extra_arg (str, optional): Additional argument forwarded to
            SubTraining.py as its last command-line argument. Defaults to "1".

    Returns:
        None

    Example:
        process_dirs(['Point_pm6y12'], '/path/to/sims', 'Point', 10e6)
    """
    for d in dirs:
        sim_dir = os.path.join(sim_dir_og, d, 'training')
        print(f'Processing simulation directory: {sim_dir}')

        # isdir (not exists): a regular file at this path is not a usable
        # training directory either.
        if not os.path.isdir(sim_dir):
            print(f"Warning: Directory {sim_dir} does not exist. Skipping...")
            continue

        net_dir = os.path.join(sim_dir, 'faster', 'nets.json')
        files = _load_network_keys(net_dir)
        if files is None:
            continue
        print(f"Found {len(files)} network configurations to process")

        # idx is the position of the key inside nets.json; it is what
        # SubTraining.py receives to identify the network to train.
        for idx, net_key in enumerate(tqdm(files, desc=f"Training models in {d}")):
            model_path = os.path.join(sim_dir, 'faster', net_key, 'model.keras')

            if os.path.isfile(model_path):
                print(f"{net_key}: Model already exists, skipping...")
                continue

            print(f"Training model {idx + 1}/{len(files)}: {net_key}")
            try:
                # sys.executable guarantees the child runs under the same
                # interpreter/environment as this script, whatever "python"
                # happens to resolve to on PATH.
                # NOTE: "SubTraining.py" is resolved relative to the current
                # working directory.
                subprocess.run([
                    sys.executable, "SubTraining.py", sim_dir, str(idx),
                    network_type, str(int(train_size)), str(extra_arg)
                ], check=True, capture_output=True, text=True)
            except subprocess.CalledProcessError as e:
                print(f"✗ Failed to train {net_key}: {e}")
                print(f"Error output: {e.stderr}")
            except Exception as e:
                # Deliberately broad: keep going with the remaining networks
                # even if launching this one failed unexpectedly.
                print(f"✗ Unexpected error training {net_key}: {e}")
            else:
                print(f"✓ Successfully trained {net_key}")


def _load_network_keys(config_path):
    """
    Return the keys of the ``'sims'`` object stored in a nets.json file.

    Args:
        config_path (str): Path to the nets.json file.

    Returns:
        list | None: The keys in file order, or ``None`` (after printing the
        reason) when the file is missing, is not valid JSON, or does not
        contain a ``'sims'`` object.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            nets = json.load(f)
    except FileNotFoundError:
        print(f"Error: Configuration file {config_path} not found. Skipping directory...")
        return None
    except json.JSONDecodeError as e:
        print(f"Error: Invalid JSON in {config_path}: {e}. Skipping directory...")
        return None

    # Valid JSON is not necessarily a valid configuration: guard the lookup
    # instead of letting a KeyError/TypeError abort the whole batch.
    sims = nets.get('sims') if isinstance(nets, dict) else None
    if not isinstance(sims, dict):
        print(f"Error: No 'sims' section found in {config_path}. Skipping directory...")
        return None
    return list(sims)

# Main training configuration
if __name__ == "__main__":
    # Root folder containing one sub-directory per network configuration
    sim_dir_og = os.path.join('/', 'media', 'cai', 'Big', 'Simulated_Data', 'opkm', 'Networks')

    # Enabled pm6y12 runs, executed in order: (banner, directory name, network type)
    pm6y12_runs = (
        ("Starting training for Point networks...", 'Point_pm6y12', 'Point'),
        ("Starting training for Difference networks...", 'Difference_pm6y12', 'Difference'),
        ("Starting training for Residual networks...", 'Residual_pm6y12', 'Residual'),
    )
    for banner, run_dir, net_type in pm6y12_runs:
        print(banner)
        process_dirs([run_dir], sim_dir_og, net_type, 10e6)

    # Enabled pm6ec9 runs share a single banner: (directory name, network type)
    print("Training PM6EC9 networks...")
    pm6ec9_runs = (
        ('Point_pm6ec9', 'Point'),
        ('Difference_pm6ec9', 'Difference'),
        ('Residual_pm6ec9', 'Residual'),
    )
    for run_dir, net_type in pm6ec9_runs:
        process_dirs([run_dir], sim_dir_og, net_type, 10e6)

    print("Training completed!")

    # ------------------------------------------------------------------
    # Disabled training scenarios: uncomment a block below to enable it.
    # ------------------------------------------------------------------

    # Training with PEI-Zn interlayer (currently has limits issues)
    # process_dirs(['Point_pm6y12_peizn'], sim_dir_og, 'Point', 5e6)
    # process_dirs(['Difference_pm6y12_peizn'], sim_dir_og, 'Difference', 5e6)
    # process_dirs(['Residual_pm6y12_peizn'], sim_dir_og, 'Residual', 5e6)

    # Multi-material training configurations
    # process_dirs(['Point_pm6ec9', 'Point_pm6y12'], sim_dir_og, 'Point', 5e6)
    # process_dirs(['Difference_pm6ec9', 'Difference_pm6y12'], sim_dir_og, 'Difference', 5e6)

    # Extended residual network training
    # process_dirs([
    #     'Residual_pm6ec9', 'Residual_pm6y12', 'Residual_pm6y12_peizn',
    #     'Residual_ptq10y12_pm6y12'
    # ], sim_dir_og, 'Residual', 5e6)

    # Training size sensitivity analysis
    # for size_exp in [1, 2, 3, 4, 5, 6]:
    #     train_size = 5 * (10 ** size_exp)
    #     process_dirs([f'Residual_pm6y12_5e{size_exp}'], sim_dir_og, 'Residual', train_size)

    # Layer multiplication factor experiments
    # for factor in [2, 4, 8]:
    #     process_dirs([f'Residual_pm6y12_{factor}'], sim_dir_og, 'Residual',
    #                 5e6 if factor != 8 else 4e6, extra_arg=str(factor))


    # Manual training configuration example (for advanced users).
    # The string literal below is never executed; copy its contents out and
    # adapt them for custom training settings.
    """
    import PyOghma_ML_Private as OML
    from PyOghma_ML_Private.Training import Model_Settings
    
    # Custom model configuration
    m = Model_Settings()
    m.activation = 'silu'
    m.initializer = 'he_normal' 
    m.batch_size = 2048
    m.layer_nodes = [64, 64, 64, 64]
    m.initial_learning_rate = 0.05
    m.gamma_learning_rate = 1e-4
    m.power_learning_rate = 2
    m.epochs = 128
    m.inputs = 4
    m.permutations_limit = int(5e6)
    
    # Manual training execution
    sim_dir = os.path.join(sim_dir_og, 'Residual_pm6y12_4', 'training')
    A = OML.Networks.initialise(sim_dir, network_type='Residual', model_settings=m)
    A.train()
    """
