import tensorflow as tf

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import unicodedata
import re
import numpy as np
import os
import io
import time

from dlex.configs import Params
from dlex.datasets import TensorflowDataset as Dataset
from dlex.tf.models.base import BaseModel, Batch


class Encoder(tf.keras.Model):
    """Embedding + single-layer GRU encoder.

    Args:
        vocab_size: Number of entries in the embedding table.
        embedding_dim: Width of each embedding vector.
        enc_units: Number of GRU units, i.e. the size of the hidden state.
        batch_sz: Default batch size used by ``initialize_hidden_state``.
    """

    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super().__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        # return_sequences + return_state: the GRU yields both the per-step
        # outputs and the final state, which ``call`` hands back to the caller.
        self.gru = tf.keras.layers.GRU(
            self.enc_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):
        """Embed ``x`` and run it through the GRU starting from ``hidden``.

        Args:
            x: Token ids to embed (assumed integer tensor of shape
                (batch, time) -- TODO confirm against the dataset).
            hidden: Initial GRU state; its leading dimension must match
                the batch dimension of ``x``.

        Returns:
            Tuple ``(output, state)`` as produced by the GRU: the outputs
            for every time step and the final hidden state.
        """
        x = self.embedding(x)
        output, state = self.gru(x, initial_state=hidden)
        return output, state

    def initialize_hidden_state(self, batch_size=None):
        """Return an all-zero initial state of shape (batch_size, enc_units).

        Args:
            batch_size: Number of rows in the returned state. Defaults to
                the ``batch_sz`` given at construction time; pass the real
                batch size when a batch is smaller than the configured one
                (e.g. the last batch of an epoch).
        """
        if batch_size is None:
            batch_size = self.batch_sz
        return tf.zeros((batch_size, self.enc_units))


class BahdanauAttention(tf.keras.Model):
    """Additive attention: score = V(tanh(W1(values) + W2(query))).

    Args:
        units: Output width of the two projection layers ``W1`` and ``W2``.
    """

    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        """Compute the attention context for ``query`` over ``values``.

        Args:
            query: Decoder state, shape (batch_size, hidden size).
            values: Encoder outputs, shape (batch_size, max_length, hidden size).

        Returns:
            Tuple ``(context_vector, attention_weights)`` with shapes
            (batch_size, hidden_size) and (batch_size, max_length, 1).
        """
        # Insert a time axis -> (batch_size, 1, hidden size) so the projected
        # query broadcasts against every time step of the projected values.
        query_with_time_axis = tf.expand_dims(query, 1)

        # Both projections have shape (..., units); their sum is squashed by
        # tanh and reduced to a single score per time step by self.V.
        projected_values = self.W1(values)
        projected_query = self.W2(query_with_time_axis)
        score = self.V(tf.nn.tanh(projected_values + projected_query))

        # Normalise over the time axis: (batch_size, max_length, 1).
        attention_weights = tf.nn.softmax(score, axis=1)

        # Weighted sum of the values over time: (batch_size, hidden_size).
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)

        return context_vector, attention_weights


class Decoder(tf.keras.Model):
    """Single-step GRU decoder with Bahdanau attention over encoder outputs.

    Args:
        vocab_size: Size of the embedding table and of the output projection.
        embedding_dim: Width of each embedding vector.
        dec_units: Number of GRU units; also the width of the attention
            projections.
        batch_sz: Stored on the instance as ``batch_sz``; not otherwise used
            inside this class.
    """

    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super().__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            self.dec_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform')
        # Projects GRU outputs to one logit per vocabulary entry.
        self.fc = tf.keras.layers.Dense(vocab_size)

        # Attention over the encoder outputs.
        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output):
        """Run one decoding step.

        Args:
            x: Current input token ids, one per example (shape (batch_size, 1)
                given the shape comments below).
            hidden: Previous decoder state, used as the attention query.
            enc_output: Encoder outputs, shape
                (batch_size, max_length, hidden_size).

        Returns:
            Tuple ``(logits, state, attention_weights)`` where ``logits`` has
            shape (batch_size, vocab).
        """
        # NOTE(review): ``hidden`` only feeds the attention query; the GRU
        # below is called without ``initial_state``, so it does not continue
        # from ``hidden``. Confirm this is intended.
        context_vector, attention_weights = self.attention(hidden, enc_output)

        # (batch_size, 1, embedding_dim)
        embedded = self.embedding(x)

        # Prepend the context to the embedding along the feature axis:
        # (batch_size, 1, embedding_dim + hidden_size)
        context_step = tf.expand_dims(context_vector, 1)
        gru_input = tf.concat([context_step, embedded], axis=-1)

        output, state = self.gru(gru_input)

        # Collapse the length-1 time axis: (batch_size * 1, hidden_size)
        flat_output = tf.reshape(output, (-1, output.shape[2]))

        # (batch_size, vocab)
        logits = self.fc(flat_output)

        return logits, state, attention_weights


class Attention(BaseModel):
    """Encoder-decoder model with attention, trained with teacher forcing.

    The encoder and decoder are built from ``params.model`` and the
    dataset's ``input_size`` / ``output_size``; ``call`` returns the
    teacher-forced loss summed over the target time steps.
    """

    def __init__(self, params: Params, dataset: Dataset):
        super().__init__(params, dataset)

        self._encoder = self._build_encoder()
        self._decoder = self._build_decoder()

        # reduction='none' keeps one loss value per example so that
        # ``loss_function`` can zero out masked positions before averaging.
        self._loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')

    def _build_encoder(self) -> Encoder:
        """Create the encoder from ``params.model.encoder`` and the dataset."""
        cfg = self.params.model
        return Encoder(
            self.dataset.input_size,
            cfg.encoder.input_size,
            cfg.encoder.hidden_size,
            self.params.train.batch_size)

    def _build_decoder(self) -> Decoder:
        """Create the decoder from ``params.model.decoder`` and the dataset."""
        cfg = self.params.model
        return Decoder(
            self.dataset.output_size,
            cfg.decoder.input_size,
            cfg.decoder.hidden_size,
            self.params.train.batch_size)

    def get_loss(self):
        # NOTE(review): this only instantiates an Adam optimizer, discards it
        # and implicitly returns None; it neither computes nor returns a
        # loss. Confirm what the base class expects here.
        optimizer = tf.keras.optimizers.Adam()

    def loss_function(self, real, pred):
        """Masked sparse cross-entropy between targets and logits.

        Args:
            real: Target token ids; entries equal to 0 are masked out
                (presumably 0 is the padding index -- verify against the
                dataset).
            pred: Logits for the same positions.

        Returns:
            Scalar mean of the per-position losses, where masked positions
            contribute 0 but are still counted in the mean.
        """
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        loss_ = self._loss_fn(real, pred)

        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask

        return tf.reduce_mean(loss_)

    def call(self, batch: Batch):
        """Compute the teacher-forced loss for one batch.

        Args:
            batch: Batch whose ``X`` holds the source token ids and ``Y``
                the target token ids; ``X.shape[0]`` must be a static batch
                size.

        Returns:
            Sum over target steps ``t = 1 .. Y.shape[1] - 1`` of the masked
            loss at that step.
        """
        batch_size = batch.X.shape[0]

        # Size the initial encoder state from the actual batch rather than
        # the configured batch size, so a smaller batch (e.g. the last one
        # of an epoch) does not produce a shape mismatch in the GRU.
        enc_hidden = tf.zeros((batch_size, self._encoder.enc_units))
        enc_output, enc_hidden = self._encoder(batch.X, enc_hidden)

        dec_hidden = enc_hidden

        # Every sequence starts decoding from the start-of-sequence token.
        dec_input = tf.expand_dims([self.dataset.sos_token_idx] * batch_size, 1)

        # Teacher forcing - feeding the target as the next input
        loss = 0
        for t in range(1, batch.Y.shape[1]):
            # passing enc_output to the decoder
            predictions, dec_hidden, _ = self._decoder(dec_input, dec_hidden, enc_output)

            loss += self.loss_function(batch.Y[:, t], predictions)

            # using teacher forcing
            dec_input = tf.expand_dims(batch.Y[:, t], 1)
        return loss

    @property
    def trainable_variables(self):
        """Encoder variables followed by decoder variables."""
        return self._encoder.trainable_variables + self._decoder.trainable_variables


class NMT(Attention):
    """Subclass of ``Attention`` that adds no behavior of its own."""

    def __init__(self, params: Params, dataset: Dataset):
        # Construction is delegated unchanged to Attention.
        super().__init__(params=params, dataset=dataset)