"""
SubTraining.py - Train New PyOghma_ML Neural Network Models

This script provides a command-line interface for training new PyOghma_ML neural
network models from scratch. It supports different network architectures with
optimized hyperparameters for organic photovoltaic device analysis.

The script implements network-specific configurations and training strategies
for Residual, Difference, and Point network types, each optimized for different
aspects of OPV device modeling and prediction.

Usage:
    python SubTraining.py <sim_dir> <idx> <type> <limit> <inputs>

Example:
    python SubTraining.py ./training 0 Residual 1000000 2
"""

from argparse import ArgumentParser, RawDescriptionHelpFormatter

import PyOghma_ML as OML
from PyOghma_ML.Training import Model_Settings
import tensorflow as tf

# Parse command-line arguments for training configuration.
#
# RawDescriptionHelpFormatter is a module-level class of ``argparse`` (it is
# not an attribute of ArgumentParser); it keeps the hand-formatted epilog
# below from being re-wrapped in the ``--help`` output.
parser = ArgumentParser(
    description="Train new PyOghma_ML neural network models from scratch",
    formatter_class=RawDescriptionHelpFormatter,
    epilog="""
Network Types:
  Residual    - ResNet-style networks with skip connections for deep learning
  Difference  - Networks trained on experimental-simulation differences
  Point       - Single-point prediction networks for device parameters

Examples:
  python SubTraining.py ./training 0 Residual 1000000 2
  python SubTraining.py ./training 5 Point 500000 1
    """
)
parser.add_argument("sim_dir", type=str,
                   help="Path to the simulation directory containing the training data.")
parser.add_argument("idx", type=int,
                   help="Index of the specific network to train (0-based indexing).")
parser.add_argument("type", type=str, choices=['Residual', 'Difference', 'Point'],
                   help="Type of network architecture to train.")
# ``limit`` and ``inputs`` are numeric; letting argparse convert them means a
# malformed value is rejected with a usage message instead of surfacing later
# as a ValueError traceback. Downstream ``int(...)`` casts remain valid.
parser.add_argument("limit", type=int,
                   help="Maximum number of training permutations/samples to use.")
parser.add_argument("inputs", type=int,
                   help="Number of input features/branches for the model architecture.")
args = parser.parse_args()


def create_model_settings(network_type: str, limit: int, inputs: int) -> Model_Settings:
    """
    Build the Model_Settings preset used when training a network from scratch.

    Args:
        network_type: One of 'Residual', 'Difference' or 'Point'. Any other
            value prints a notice and falls back to a reduced preset.
        limit: Forwarded unchanged as ``permutations_limit``.
        inputs: Forwarded unchanged as ``inputs``.

    Returns:
        Model_Settings: Settings object built from the preset selected by
        ``network_type``.
    """
    # Per-type keyword arguments. The table is rebuilt on every call so each
    # Model_Settings receives its own fresh list objects.
    # NOTE: ``inital_learning_rate`` (sic) is the keyword spelling this file
    # passes to Model_Settings; presumably it matches that class's parameter.
    presets = {
        # Four 128-node layers, large batches, fastest learning-rate decay.
        'Residual': {
            'layer_nodes': [128, 128, 128, 128],
            'dropout': [0.05] * 4,
            'inital_learning_rate': 8e-5,
            'batch_size': 16384,
            'epochs': 1024,
            'patience': 16,
            'decay_rate': 6e-1,
        },
        # Four 256-node layers with the smallest starting learning rate.
        'Difference': {
            'layer_nodes': [256, 256, 256, 256],
            'dropout': [0.05] * 4,
            'inital_learning_rate': 1e-5,
            'batch_size': 16384,
            'epochs': 1024,
            'patience': 16,
            'decay_rate': 8e-1,
        },
        # Four 256-node layers, smaller batches, longer epoch and patience budgets.
        'Point': {
            'layer_nodes': [256, 256, 256, 256],
            'dropout': [0.05] * 4,
            'inital_learning_rate': 1e-4,
            'batch_size': 1024,
            'epochs': 4096,
            'patience': 128,
            'decay_rate': 8e-1,
        },
    }

    chosen = presets.get(network_type)
    if chosen is None:
        # Unrecognised type: only the layer layout is specified, every other
        # option is left to Model_Settings' own defaults.
        print(f"Unknown network type: {network_type}. Using default settings.")
        chosen = {'layer_nodes': [128, 128, 128, 128]}

    # Initializer and activation are identical across all presets.
    return Model_Settings(
        initializer='he_normal',
        activation='silu',
        **chosen,
        permutations_limit=limit,
        inputs=inputs,
    )


# Build the hyperparameter preset for the requested network type.
print(f"Configuring new {args.type} network (index {args.idx}) with {args.inputs} inputs...")
m = create_model_settings(args.type, int(args.limit), int(args.inputs))

# Construct the network object for the simulation directory with that preset.
A = OML.Networks.initialise(args.sim_dir, network_type=args.type, model_settings=m)

# The index is passed to train() for every network type.
# NOTE(review): an earlier comment here said only 'Residual' trains a specific
# index while other types train everything; if so, that is decided inside
# train(), not in this script -- confirm against PyOghma_ML.
A.train(args.idx)

# Terminate explicitly with a success status. ``raise SystemExit`` replaces the
# ``exit()`` helper, which is injected by the ``site`` module for interactive
# sessions and is not available when Python runs without ``site`` (-S).
raise SystemExit(0)

