"""
Extended Training Script for Existing PyOghma ML Models

This script extends the training of existing neural network models with additional
data or modified configurations. Unlike Training.py, this script is designed to
work with models that have already been trained and need further refinement.

Usage:
    python Training_Existing.py

The script will:
    1. Scan extended simulation directories for network configurations
    2. Process ALL models (including existing ones) for extended training
    3. Launch extended training processes using SubTraining_Existing.py
    4. Support continued training with larger datasets or modified parameters

Directory Structure:
    sim_dir/
    ├── Point_pm6y12/training/faster/nets.json
    ├── Difference_pm6y12/training/faster/nets.json
    └── Residual_pm6y12/training/faster/nets.json

Author: Cai Williams
Email: cai.williams@physik.tu-chemnitz.de
"""

import json
import os
import subprocess
import sys

import rainbow_tqdm
from tqdm import tqdm


def process_dirs(dirs, sim_dir_og, network_type, train_size, extra_arg="1"):
    """
    Launch extended (continued) training for every network in each directory.

    For each name in ``dirs`` this reads
    ``<sim_dir_og>/<name>/training/faster/nets.json``, takes the keys of its
    ``"sims"`` mapping, and runs ``SubTraining_Existing.py`` in a child process
    once per key. No check for already-trained models is made here: every key
    found in the configuration is submitted.

    Args:
        dirs (list[str]): Directory names to process (e.g., ['Point_pm6y12']).
        sim_dir_og (str): Base simulation directory containing those names.
        network_type (str): Network architecture passed through to the child
            script (e.g., 'Residual', 'Difference', 'Point').
        train_size (int | float): Number of training samples; converted with
            ``int()`` before being passed on, so values like ``20e6`` work.
        extra_arg (str, optional): Additional argument forwarded verbatim to
            SubTraining_Existing.py. Defaults to "1".

    Returns:
        None. Progress and errors are reported with ``print``.

    Note:
        This function is best-effort and does not raise for per-directory or
        per-model problems. A missing directory, a missing/unreadable/invalid
        ``nets.json``, or a ``nets.json`` without a ``"sims"`` mapping skips
        that directory; a failing child process is reported (including its
        captured stderr) and the loop moves on to the next model.

        ``SubTraining_Existing.py`` is given as a relative path, so it is
        resolved against the current working directory of this process.

    Example:
        >>> process_dirs(['Point_pm6y12'], '/path/to/extended_sims', 'Point', 20e6)
        Processing extended training directory: /path/to/extended_sims/Point_pm6y12/training
        Found 10 network configurations for extended training
        Extending training for model 1/10: net_123
        ✓ Successfully extended training for net_123
    """
    # Use the interpreter that is running this script so the child process
    # sees the same environment; fall back to a PATH lookup only when the
    # interpreter path is unavailable.
    python_exe = sys.executable or "python"

    for d in dirs:
        sim_dir = os.path.join(sim_dir_og, d, 'training')
        print(f'Processing extended training directory: {sim_dir}')

        # Check if the directory exists
        if not os.path.exists(sim_dir):
            print(f"Warning: Directory {sim_dir} does not exist. Skipping...")
            continue

        net_dir = os.path.join(sim_dir, 'faster', 'nets.json')

        # Load the network configuration; only the read/parse can raise here,
        # so the try body is kept to exactly that.
        try:
            with open(net_dir, 'r', encoding='utf-8') as f:
                nets = json.load(f)
        except FileNotFoundError:
            print(f"Error: Configuration file {net_dir} not found. Skipping directory...")
            continue
        except json.JSONDecodeError as e:
            print(f"Error: Invalid JSON in {net_dir}: {e}. Skipping directory...")
            continue
        except OSError as e:
            print(f"Error: Could not read {net_dir}: {e}. Skipping directory...")
            continue

        # Valid JSON is not necessarily a valid configuration: guard against a
        # missing or wrongly-typed "sims" entry instead of crashing the batch.
        sims = nets.get('sims') if isinstance(nets, dict) else None
        if not isinstance(sims, dict):
            print(f"Error: No 'sims' mapping found in {net_dir}. Skipping directory...")
            continue

        files = list(sims)  # Network keys, in file order
        print(f"Found {len(files)} network configurations for extended training")

        # The child script receives the positional index of the network,
        # not its key; the key is only used for the messages printed here.
        for idx, net_key in enumerate(tqdm(files, desc=f"Extending training in {d}")):
            print(f"Extending training for model {idx + 1}/{len(files)}: {net_key}")

            try:
                # check=True turns a non-zero exit status into
                # CalledProcessError; output is captured so stderr can be
                # shown on failure.
                subprocess.run([
                    python_exe, "SubTraining_Existing.py", sim_dir, str(idx),
                    network_type, str(int(train_size)), str(extra_arg)
                ], check=True, capture_output=True, text=True)
                print(f"✓ Successfully extended training for {net_key}")
            except subprocess.CalledProcessError as e:
                print(f"✗ Failed to extend training for {net_key}: {e}")
                print(f"Error output: {e.stderr}")
            except Exception as e:
                # Deliberate catch-all at the per-model boundary (e.g. the
                # interpreter cannot be started, or train_size is not
                # convertible to int) so one bad model does not stop the rest.
                print(f"✗ Unexpected error extending training for {net_key}: {e}")

# Main extended training configuration
if __name__ == "__main__":
    # Root folder holding one sub-directory per network set to be extended.
    base_parts = ('media', 'cai', 'Big', 'Simulated_Data', 'opkm', 'Networks_Extended')
    sim_dir_og = os.path.join('/', *base_parts)

    # Run table: each entry is (dirs, network_type, train_size) and is passed
    # to process_dirs in the order listed. Comment entries in or out to select
    # which extended training runs are executed.
    extended_runs = [
        # Basic network type extended training
        # (['Point_pm6y12'], 'Point', 20e6),
        # (['Difference_pm6y12'], 'Difference', 10e6),
        # (['Residual_pm6y12'], 'Residual', 10e6),

        # PM6EC9 material extended training
        # (['Point_pm6ec9'], 'Point', 10e6),
        # (['Difference_pm6ec9'], 'Difference', 10e6),
        (['Residual_pm6ec9'], 'Residual', 10e6),
    ]

    print("Starting extended training...")
    for run_dirs, run_network_type, run_train_size in extended_runs:
        process_dirs(run_dirs, sim_dir_og, run_network_type, run_train_size)
    print("Extended training completed!")

    # ------------------------------------------------------------------
    # Disabled scenarios. These are kept for reference only; note that as
    # written they sit after the "completed" message above.
    # ------------------------------------------------------------------

    # Extended training with PEI-Zn interlayer (limits are broken - needs investigation)
    # process_dirs(['Point_pm6y12_peizn'], sim_dir_og, 'Point', 5e6)
    # process_dirs(['Difference_pm6y12_peizn'], sim_dir_og, 'Difference', 5e6)
    # process_dirs(['Residual_pm6y12_peizn'], sim_dir_og, 'Residual', 5e6)

    # Multi-material extended training configurations
    # process_dirs(['Point_pm6ec9', 'Point_pm6y12'], sim_dir_og, 'Point', 5e6)
    # process_dirs(['Difference_pm6ec9', 'Difference_pm6y12'], sim_dir_og, 'Difference', 5e6)

    # Extended residual network training with multiple materials
    # process_dirs([
    #     'Residual_pm6ec9', 'Residual_pm6y12', 'Residual_pm6y12_peizn',
    #     'Residual_ptq10y12_pm6y12'
    # ], sim_dir_og, 'Residual', 5e6)

    # Extended training size sensitivity analysis
    # for size_exp in [1, 2, 3, 4, 5, 6]:
    #     train_size = 5 * (10 ** size_exp)
    #     process_dirs([f'Residual_pm6y12_5e{size_exp}'], sim_dir_og, 'Residual', train_size)

    # Extended training with layer multiplication factors
    # for factor in [2, 4, 8]:
    #     process_dirs([f'Residual_pm6y12_{factor}'], sim_dir_og, 'Residual',
    #                 5e6 if factor != 8 else 4e6, extra_arg=str(factor))

    # Manual configuration example for custom settings. It is stored as a bare
    # string literal, so it is never executed; copy it out of the quotes to use it.
    """
    import PyOghma_ML_Private as OML
    from PyOghma_ML_Private.Training import Model_Settings
    
    # Custom extended model configuration
    m = Model_Settings()
    m.activation = 'silu'
    m.initializer = 'he_normal'
    m.batch_size = 2048
    m.layer_nodes = [64, 64, 64, 64]
    m.initial_learning_rate = 0.05
    m.gamma_learning_rate = 1e-4
    m.power_learning_rate = 2
    m.epochs = 128
    m.inputs = 4
    m.permutations_limit = int(5e6)
    
    # Manual extended training execution
    sim_dir = os.path.join(sim_dir_og, 'Residual_pm6y12_4', 'training')
    A = OML.Networks.initialise(sim_dir, network_type='Residual', model_settings=m)
    A.train()
    """
