"""
This module provides functionality for loading trained neural network models and making predictions.

It includes a class for managing the loading of TensorFlow/Keras models from a specified directory
and using them to generate predictions from experimental features. Models with multiple input
branches are supported by splitting the feature matrix column-wise, and input features are
converted to floating point before prediction.
"""
import os
# These environment variables are assigned before `tensorflow` is imported below.
# NOTE(review): presumably TF_CPP_MIN_LOG_LEVEL='1' quiets TensorFlow's native
# (C++) INFO logging, which is only honoured if set before the import — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
# NOTE(review): pycaret is not imported in the code visible here; this variable is
# presumably read by another module in the same process — verify it is still needed.
os.environ['PYCARET_CUSTOM_LOGGING_LEVEL'] = 'CRITICAL'
import tensorflow as tf
import numpy as np


class Predicting:
    """
    A class for loading trained neural network models and making predictions.

    This class handles loading a TensorFlow/Keras model from a specified
    directory and using it to make predictions on experimental features.

    Attributes:
        network_dir (str): Directory the model is loaded from.
        experimental_feature (numpy.ndarray): Float copy of the input features.
        inputs (int): Number of model input branches; 2, 4 and 8 trigger
            column-wise splitting of the features in predict().
        model: The object returned by tf.keras.models.load_model.
    """

    # Branch counts for which predict() splits the feature columns into
    # consecutive parts; any other value sends the whole array to one input.
    _SPLIT_INPUTS = (2, 4, 8)

    def __init__(self, network_dir: str, experimental_feature: np.ndarray, inputs: int = 1) -> None:
        """
        Initialize a Predicting instance.

        Stores the configuration, converts the experimental features to a float
        array and loads the trained model via load_network().

        Args:
            network_dir (str): Directory path containing the trained network model.
                Should contain 'model.keras' file or numbered model files.
            experimental_feature (array-like): Experimental input features to use
                for generating predictions. Copied into a new float NumPy array,
                so later changes to the caller's array do not affect predictions.
            inputs (int, optional): Number of input branches for the model.
                Defaults to 1 for single-input models.

        Raises:
            FileNotFoundError: Propagated from load_network() when the model
                file cannot be located.
        """
        self.network_dir = network_dir
        # np.array(..., dtype=float) always copies (as .astype(float) did) and
        # additionally accepts plain Python sequences, not only ndarrays.
        self.experimental_feature = np.array(experimental_feature, dtype=float)
        # Stored before loading so every attribute exists once load_network() runs.
        self.inputs = inputs
        self.load_network()

    def load_network(self) -> None:
        """
        Load the trained neural network model from the specified directory.

        The file to load is chosen from the number of entries in the directory:

            - exactly one entry: 'model.keras'
            - N > 1 entries: 'model{N}.keras' (e.g. 'model3.keras' for 3 entries)

        NOTE(review): this assumes the directory holds only model files saved as
        model.keras, model2.keras, ..., modelN.keras; any extra entry changes N
        and points at the wrong (or a missing) file.

        Raises:
            FileNotFoundError: If the directory is empty or the expected model
                file does not exist (os.listdir also raises it when network_dir
                itself is missing).
            Any exception raised by tf.keras.models.load_model propagates unchanged.

        Note:
            The loaded model is stored in self.model for use in predictions.
        """
        count = len(os.listdir(self.network_dir))
        if count == 0:
            raise FileNotFoundError(f"No model files found in {self.network_dir!r}")

        filename = 'model.keras' if count == 1 else f'model{count}.keras'
        model_path = os.path.join(self.network_dir, filename)
        # Fail with a clear message instead of an opaque loader error when the
        # count-based naming scheme does not match the directory contents.
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"Expected model file {model_path!r} based on {count} entries in "
                f"{self.network_dir!r}, but it does not exist"
            )
        self.model = tf.keras.models.load_model(model_path)

    def predict(self) -> np.ndarray:
        """
        Make predictions using the loaded neural network model.

        Input Processing:
            - inputs in (2, 4, 8): the feature columns (axis 1) are split into
              that many consecutive parts of width ``n_columns // inputs``; the
              last part also receives any leftover columns. The parts are passed
              to the model as a list, first part to the first input, and so on.
            - any other value (including the default 1): the full feature array
              is passed as a single input.

        Returns:
            numpy.ndarray: The result of self.model.predict for the prepared inputs.

        Raises:
            ValueError: If splitting is requested but the features are not at
                least 2-D or have fewer columns than input branches (which would
                produce empty inputs).
            Errors raised by self.model.predict (e.g. shape mismatches with the
            model's inputs) propagate unchanged.
        """
        x = self.experimental_feature
        if self.inputs in self._SPLIT_INPUTS:
            return self.model.predict(self._split_features(x, int(self.inputs)))
        return self.model.predict(x)

    @staticmethod
    def _split_features(x: np.ndarray, parts: int) -> list:
        """Split x along axis 1 into `parts` consecutive slices; the last slice absorbs the remainder."""
        if np.ndim(x) < 2:
            raise ValueError(
                f"Splitting features across {parts} inputs requires a 2-D array, "
                f"got shape {np.shape(x)}"
            )
        width = np.shape(x)[1]
        step = width // parts
        if step == 0:
            raise ValueError(
                f"Cannot split {width} feature columns across {parts} inputs"
            )
        # Boundaries 0, step, 2*step, ..., (parts-1)*step, width: identical to the
        # hand-written slices x[:, :xs], x[:, xs:2*xs], ..., x[:, (parts-1)*xs:].
        edges = [i * step for i in range(parts)] + [width]
        return [x[:, lo:hi] for lo, hi in zip(edges, edges[1:])]
    