"""
SubTraining_Existing.py - Continue Training Existing Neural Network Models

This script provides a command-line interface for continuing the training of
existing PyOghma_ML neural network models. It supports three network
architectures (Residual, Difference, Point), each with its own tuned
hyperparameters.

The script implements selective training based on network indices and applies
different learning configurations depending on the network type and index.
This allows for fine-tuned training continuation with architecture-specific
optimization strategies.

Usage:
    python SubTraining_Existing.py <sim_dir> <idx> <type> <limit> <inputs>

Example:
    python SubTraining_Existing.py ./training 8 Residual 1000000 2
"""

import traceback
from argparse import ArgumentParser, RawDescriptionHelpFormatter

import PyOghma_ML as OML
from PyOghma_ML.Training import Model_Settings
import tensorflow as tf

# Parse command-line arguments for training configuration
parser = ArgumentParser(
    description="Continue training existing PyOghma_ML neural network models",
    # RawDescriptionHelpFormatter is defined in the argparse module itself, not
    # as an attribute of ArgumentParser; it keeps the epilog's line breaks.
    formatter_class=RawDescriptionHelpFormatter,
    epilog="""
Network Types:
  Residual    - ResNet-style networks with skip connections
  Difference  - Networks trained on experimental-simulation differences  
  Point       - Single-point prediction networks

Examples:
  python SubTraining_Existing.py ./training 8 Residual 1000000 2
  python SubTraining_Existing.py ./training 12 Point 500000 1
    """
)
parser.add_argument("sim_dir", type=str,
                    help="Path to the simulation directory containing the training data and existing models.")
parser.add_argument("idx", type=int,
                    help="Index of the specific network to continue training (0-based indexing).")
parser.add_argument("type", type=str, choices=['Residual', 'Difference', 'Point'],
                    help="Type of network architecture to train.")
# limit and inputs are consumed as integers; parsing them as int here means a
# malformed value is reported with argparse's usage message instead of a bare
# ValueError later in the script.
parser.add_argument("limit", type=int,
                    help="Maximum number of training permutations/samples to use.")
parser.add_argument("inputs", type=int,
                    help="Number of input features/branches for the model architecture.")
args = parser.parse_args()

# Configure model settings based on network type and index
# This section defines optimized hyperparameters for different network architectures
# and training scenarios based on empirical performance analysis.

# Indices each known network type may continue training on. Stored as lists so
# the rejection message prints them exactly as before.
_ALLOWED_INDICES = {
    # Residual: only indices that have shown good convergence.
    'Residual': [8, 14, 15, 16, 17, 18, 19, 20, 24, 29, 30, 31, 32, 33, 34],
    'Difference': list(range(7, 19)),
    'Point': list(range(7, 19)),
}


def _settings_kwargs(network_type: str, idx: int, limit: int, inputs: int) -> dict:
    """Return the ``Model_Settings`` keyword arguments for one network.

    For a known network type, ``idx`` is assumed to have already been checked
    against ``_ALLOWED_INDICES``. An unknown type gets a minimal fallback set
    that omits dropout, learning rate, batch size, epochs, patience and decay,
    leaving them to ``Model_Settings``' own defaults.

    Args:
        network_type: 'Residual', 'Difference', 'Point', or anything else
            for the fallback.
        idx: Network index; only affects the Residual learning rate.
        limit: Training data limit (``permutations_limit``).
        inputs: Number of input features.

    Returns:
        A fresh dict of keyword arguments, safe for the caller to mutate.
    """
    # Present in every configuration, including the fallback.
    kwargs = {
        'initializer': 'he_normal',
        'activation': 'silu',
        'permutations_limit': limit,
        'inputs': inputs,
    }
    if network_type not in _ALLOWED_INDICES:
        kwargs['layer_nodes'] = [128, 128, 128, 128]
        return kwargs

    # Shared by all known network types (Point overrides batch_size below).
    kwargs.update(
        dropout=[0.05, 0.05, 0.05, 0.05],  # light regularization
        batch_size=16384,                  # large batch for stable gradients
        epochs=2048,
        decay_rate=9.8e-1,                 # slow decay for fine-tuning
    )

    if network_type == 'Residual':
        if idx == 24:
            learning_rate = 3e-6           # more sensitive network: lower LR
        elif idx in range(29, 35):
            learning_rate = 1e-5           # later-stage networks: higher LR
        else:
            learning_rate = 8e-6           # indices 8 and 14-20
        kwargs.update(
            layer_nodes=[128, 128, 128, 128],
            inital_learning_rate=learning_rate,  # spelling matches Model_Settings
            patience=32,
        )
    elif network_type == 'Difference':
        kwargs.update(
            layer_nodes=[256, 256, 256, 256],  # wider for complex difference patterns
            inital_learning_rate=1e-5,
            patience=256,
        )
    else:  # 'Point'
        kwargs.update(
            layer_nodes=[256, 256, 256, 256],
            inital_learning_rate=2e-6,         # conservative for fine-tuning
            batch_size=1024,                   # smaller batch for Point networks
            patience=128,
        )
    return kwargs


def create_model_settings(network_type: str, idx: int, limit: int, inputs: int) -> "Model_Settings":
    """
    Create optimized model settings based on network type and index.

    Args:
        network_type: Type of network ('Residual', 'Difference', 'Point')
        idx: Network index for selective configuration
        limit: Training data limit
        inputs: Number of input features

    Returns:
        Model_Settings: Configured settings object. An unknown network type
        gets a minimal default configuration, and a notice is printed.

    Raises:
        SystemExit: If ``idx`` is not allowed for a known network type. The
            exit status is 0, the same as the original ``exit()`` call.
    """
    allowed_indices = _ALLOWED_INDICES.get(network_type)
    if allowed_indices is None:
        print(f"Unknown network type: {network_type}. Using default settings.")
    elif idx not in allowed_indices:
        print(f"Index {idx} not in allowed list for {network_type} networks: {allowed_indices}")
        # Use SystemExit rather than the site-injected exit(), which is meant
        # for the interactive shell only.
        # NOTE(review): status stays 0 as before; confirm whether job drivers
        # rely on that before making it non-zero.
        raise SystemExit
    return Model_Settings(**_settings_kwargs(network_type, idx, limit, inputs))


# Create model settings using the configuration function
print(f"Configuring {args.type} network (index {args.idx}) with {args.inputs} inputs...")
m = create_model_settings(args.type, args.idx, int(args.limit), int(args.inputs))

# Initialize network and continue training
print(f"Initializing {args.type} network from directory: {args.sim_dir}")
try:
    # Create network instance with configured settings
    network = OML.Networks.initialise(args.sim_dir, network_type=args.type, model_settings=m)

    # Continue training the existing model at the specified index
    print(f"Continuing training for network index {args.idx}...")
    network.train_existing(args.idx)
except Exception as e:
    # Top-level boundary: report the failure and exit non-zero. The traceback
    # (written to stderr) shows where in initialise/train_existing it failed,
    # which the bare message alone does not.
    print(f"Error during training: {e}")
    traceback.print_exc()
    # SystemExit instead of the site-injected exit(), which is meant for the
    # interactive shell only.
    raise SystemExit(1)
else:
    print(f"Training completed successfully for {args.type} network (index {args.idx})")

# Falling off the end of the script exits with status 0.
print("Script execution completed.")

