import pandas as pd
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Conv1D, Dense, Flatten, Input, MaxPooling1D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

from conmo.conf import Label, RandomSeed
from conmo.algorithms.algorithm import PretrainedAlgorithm


class PretrainedCNN1D(PretrainedAlgorithm):
    """
    1D convolutional neural network that is either trained from scratch or
    loaded from a previously saved Keras model.

    Parameters
    ----------
    pretrained: bool
        If True, the model is loaded from disk (see ``load_weights``) instead
        of being trained.
    input_len: int
        Number of features per sample; the network input shape is
        ``(input_len, 1)``.
    random_seed: int, optional
        TensorFlow seed used when training from scratch. Falls back to
        ``RandomSeed.RANDOM_SEED`` when not given. Ignored if ``pretrained``.
    path: str, optional
        Forwarded to ``PretrainedAlgorithm``; read back as ``self.path`` by
        ``load_weights`` when ``pretrained`` is True.
    """

    def __init__(self, pretrained: bool, input_len: int, random_seed: int = None, path: str = None) -> None:
        super().__init__(pretrained, path)
        self.input_len = input_len
        if not self.pretrained:
            # Use the project-wide default seed when the caller gives none.
            self.random_seed = random_seed if random_seed is not None else RandomSeed.RANDOM_SEED

    @staticmethod
    def _to_3d(data: pd.DataFrame):
        """Reshape a (samples, features) DataFrame into a (samples, features, 1) array for Conv1D."""
        return data.to_numpy().reshape(data.shape[0], data.shape[1], 1)

    def fit_predict(self, data_train: pd.DataFrame, data_test: pd.DataFrame, labels_train: pd.DataFrame, labels_test: pd.DataFrame) -> pd.DataFrame:
        """
        Train the network (or load a pretrained one) and predict on the test data.

        Parameters
        ----------
        data_train: pd.DataFrame
            Training samples, one row per sample and ``input_len`` columns.
            Only used when the model is not pretrained.
        data_test: pd.DataFrame
            Test samples with the same column layout as ``data_train``.
        labels_train: pd.DataFrame
            Training targets. Only used when the model is not pretrained.
        labels_test: pd.DataFrame
            Only its index is used, to index the returned predictions.

        Returns
        -------
        pd.DataFrame
            Model predictions indexed like ``labels_test``, with columns
            ``Label.BATTERIES_DEG_TYPES``.
        """
        self.model = None
        if not self.pretrained:
            # Set TensorFlow random seed for reproducible training.
            tf.random.set_seed(self.random_seed)

            # Create a new 1D convolutional network (build_cnn_1d takes no extra argument).
            self.model = self.build_cnn_1d()

            # Compile model
            self.model.compile(loss='mse', optimizer=Adam(
                learning_rate=0.001), metrics=['mse'])

            # Train on the training data only; 20% of it is held out for validation.
            # NOTE(review): the best weights saved by this checkpoint are never
            # reloaded, so prediction uses the last-epoch weights — confirm intended.
            callbacks = [ModelCheckpoint(filepath='./checkpoints/checkpoint', monitor='val_loss',
                                         mode='min', verbose=1, save_best_only=True, save_weights_only=True)]
            # Conv1D needs (samples, steps, channels), so add the channel axis
            # exactly as is done for the test data below.
            self.model.fit(self._to_3d(data_train), labels_train.to_numpy(),
                           epochs=50, batch_size=32, validation_split=0.2, verbose=2,
                           callbacks=callbacks, shuffle=True)
        else:
            # A pretrained model is saved on disk, so there is no need to train.
            self.load_weights()

        # Predict with the test data, reshaped to (samples, features, 1).
        pred = self.model.predict(self._to_3d(data_test))

        # Generate output dataframe
        return pd.DataFrame(pred, index=labels_test.index, columns=Label.BATTERIES_DEG_TYPES)

    def load_weights(self) -> None:
        """Load the full saved Keras model from ``self.path`` (without compiling it) into ``self.model``."""
        self.model = tf.keras.models.load_model(self.path, compile=False)

    def build_cnn_1d(self) -> Sequential:
        """
        Auxiliary method for building the 1D convolutional neural network.

        Returns
        -------
        model: tf.keras.Model
            Uncompiled Keras model taking inputs of shape ``(input_len, 1)``
            and producing 3 sigmoid outputs.
        """
        model = Sequential([
            # input layer
            Input(shape=(self.input_len, 1)),
            # 1D convolutional layer
            Conv1D(filters=32, kernel_size=4, strides=2, activation='relu'),
            # Max pooling layer
            MaxPooling1D(pool_size=2, strides=2),
            # 1D convolutional layer
            Conv1D(filters=32, kernel_size=4, strides=2, activation='relu'),
            # Max pooling layer
            MaxPooling1D(pool_size=2, strides=2),
            # Three fully connected layers of sizes 128, 64 and 3 (output)
            Flatten(),
            Dense(128, activation='relu'),
            Dense(64, activation='relu'),
            # NOTE(review): 3 outputs presumably match len(Label.BATTERIES_DEG_TYPES) — confirm.
            Dense(3, activation='sigmoid')
        ])
        return model
