"""
Code from https://github.com/nrontsis/PILCO
"""

import tensorflow as tf
import gpflow
import numpy as np
from typeguard import typechecked

# Floating-point dtype taken from the gpflow settings; used in this module as
# the ``dtype`` of the identity matrices built with ``tf.eye``.
float_type = gpflow.settings.dtypes.float_type


def randomize(model, mean=1, sigma=0.01):
    """
    Re-initialise the hyperparameters of a GP model with Gaussian jitter.

    Each touched hyperparameter is assigned ``mean + sigma * eps`` where
    ``eps`` is drawn with ``np.random.normal``:

    * ``model.kern.lengthscales`` -- one draw per element of its ``shape``
    * ``model.kern.variance``     -- one draw per element of its ``shape``
    * ``model.likelihood.variance`` -- a single scalar draw, and only when
      the parameter's ``trainable`` flag is set

    The draws are made in the order listed above, so seeding
    ``np.random`` makes the result reproducible.

    :param model: object exposing ``kern.lengthscales``, ``kern.variance``
        and ``likelihood.variance``, each providing ``assign``
    :param mean: centre of the new values (default 1)
    :param sigma: standard deviation of the jitter (default 0.01)

    NOTE(review): these parameters are presumably positive-constrained, so
    ``mean``/``sigma`` should be chosen such that draws stay positive.
    """

    def _jittered(size=None):
        # size=None yields a single Python float, like np.random.normal()
        return mean + sigma * np.random.normal(size=size)

    kern = model.kern
    kern.lengthscales.assign(_jittered(kern.lengthscales.shape))
    kern.variance.assign(_jittered(kern.variance.shape))

    noise = model.likelihood.variance
    if noise.trainable:
        noise.assign(_jittered())


class MGPR(gpflow.Parameterized):
    """
    Bundle of independent single-output GP regressors over shared inputs.

    One ``gpflow.models.GPR`` with an ARD RBF kernel is created per output
    column of the training targets. Besides plain prediction, the class
    builds the TensorFlow expressions for approximate prediction at noisy
    (Gaussian-distributed) inputs via moment matching.
    """

    def __init__(self, action_dim, state_dim, x, y, name=None):
        """
        :param action_dim: number of action dimensions
        :param state_dim: number of state dimensions; also the number of GP
            models created (model ``i`` is trained on ``y[:, i:i + 1]``)
        :param x: training inputs, 2-D array; ``x.shape[1]`` is the kernel
            input dimension (presumably ``action_dim + state_dim``)
        :param y: training targets, 2-D array with one column per output
        :param name: optional name forwarded to ``gpflow.Parameterized``
        """
        super(MGPR, self).__init__(name)

        self.num_outputs = state_dim
        self.num_dims = action_dim + state_dim
        self.optimizers = []
        self.models = []
        self._create_models(X=x, Y=y)

    def _create_models(self, X: np.ndarray, Y: np.ndarray):
        """Build, compile and store one GPR model per output dimension."""
        for i in range(self.num_outputs):
            kern = gpflow.kernels.RBF(input_dim=X.shape[1], ARD=True)
            # Priors have to be attached before the model gets compiled.
            kern.lengthscales.prior = gpflow.priors.Gamma(1, 10)
            kern.variance.prior = gpflow.priors.Gamma(1.5, 2)
            # TODO: Maybe fix noise for better conditioning
            model = gpflow.models.GPR(X, Y[:, i:i + 1], kern, name="{}_{}".format(self.name, i))
            model.clear()
            model.compile()
            self.models.append(model)

    def set_XY(self, X: np.ndarray, Y: np.ndarray):
        """
        Replace the training data of every model.

        All models receive the same ``X``; model ``i`` receives the single
        column ``Y[:, i:i + 1]``.
        """
        for i, model in enumerate(self.models):
            model.X = X
            model.Y = Y[:, i:i + 1]

    def optimize(self, restarts=1):
        """
        Fit the hyperparameters of every model, with optional random restarts.

        On the first call an L-BFGS-B ``ScipyOptimizer`` is created and run
        once for each model; that initial fit consumes one unit of
        ``restarts``. Afterwards each model is re-initialised with
        ``randomize`` and re-optimised ``restarts`` times, and the parameter
        set with the highest ``compute_log_likelihood()`` (including the
        parameters the model had before the restarts) is assigned back.

        :param restarts: total number of optimisation runs per model on the
            first call; number of randomised runs per model on later calls
        """
        if not self.optimizers:  # This is the first call to optimize()
            for model in self.models:
                # Create a gpflow.train.ScipyOptimizer object for every model embedded in mgpr
                optimizer = gpflow.train.ScipyOptimizer(method='L-BFGS-B')
                optimizer.minimize(model)
                self.optimizers.append(optimizer)
            # The initial fit uses up exactly one restart for the whole
            # bundle. Decrementing once per model would shrink the budget by
            # the number of outputs and silently drop requested restarts.
            restarts -= 1

        for model, optimizer in zip(self.models, self.optimizers):
            session = optimizer._model.enquire_session(None)
            best_parameters = model.read_values(session=session)
            best_likelihood = model.compute_log_likelihood()
            for _ in range(restarts):
                randomize(model)
                # Re-run the already-built optimizer from the new start point.
                optimizer._optimizer.minimize(session=session,
                                              feed_dict=optimizer._gen_feed_dict(optimizer._model, None),
                                              step_callback=None)
                likelihood = model.compute_log_likelihood()
                if likelihood > best_likelihood:
                    best_parameters = model.read_values(session=session)
                    best_likelihood = likelihood
            model.assign(best_parameters)

    def predict(self, x):
        """
        Predict with every model at the deterministic inputs ``x``.

        :return: tuple ``(means, variances)`` of numpy arrays, each stacking
            the per-model ``predict_f`` outputs along a new leading axis
        """
        means = []
        variances = []
        for model in self.models:
            mean, var = model.predict_f(x)
            means.append(mean)
            variances.append(var)
        return np.array(means), np.array(variances)

    def predict_on_noisy_inputs(self, m, s):
        """
        Moment-matched prediction for an input with mean ``m`` and
        covariance ``s``; computes the factorizations and delegates to
        ``predict_given_factorizations``.
        """
        iK, beta = self.calculate_factorizations()
        return self.predict_given_factorizations(m, s, iK, beta)

    def calculate_factorizations(self):
        """
        Compute the data-dependent terms reused by the noisy-input prediction.

        :return: ``(iK, beta)`` where, per output, ``iK`` is the inverse of
            the kernel matrix plus noise variance times identity, and
            ``beta`` is that inverse applied to the output's training targets
        """
        K = self.K(self.X)
        batched_eye = tf.eye(tf.shape(self.X)[0], batch_shape=[self.num_outputs], dtype=float_type)
        L = tf.cholesky(K + self.noise[:, None, None] * batched_eye)
        iK = tf.cholesky_solve(L, batched_eye)
        # self.Y concatenates the per-model target columns along axis 1, so
        # the transpose puts the output index first and [:, :, None] turns
        # each row into a column vector, matching the leading batch
        # dimension of L.
        Y_ = tf.transpose(self.Y)[:, :, None]
        beta = tf.cholesky_solve(L, Y_)[:, :, 0]
        return iK, beta

    def predict_given_factorizations(self, m, s, iK, beta):
        """
        Approximate GP regression at noisy inputs via moment matching

        IN:  m    -- mean of the input distribution (row vector)
             s    -- covariance of the input distribution (2-D)
             iK   -- per-output inverse of the noisy kernel matrix, as
                     returned by calculate_factorizations()
             beta -- per-output iK applied to the training targets, as
                     returned by calculate_factorizations()
        OUT: mean (row vector) and covariance of the prediction, and
             inv(s) times the input-output covariance
        """

        # Replicate s for every pair of outputs and the centred training
        # inputs for every output.
        s = tf.tile(s[None, None, :, :], [self.num_outputs, self.num_outputs, 1, 1])
        inp = tf.tile(self.centralized_input(m)[None, :, :], [self.num_outputs, 1, 1])

        # Calculate M and V: mean and inv(s) times input-output covariance
        iL = tf.matrix_diag(1 / self.lengthscales)
        iN = inp @ iL
        B = iL @ s[0, ...] @ iL + tf.eye(self.num_dims, dtype=float_type)

        # Redefine iN as in^T and t --> t^T
        # B is symmetric so it's the same
        t = tf.linalg.transpose(
            tf.matrix_solve(B, tf.linalg.transpose(iN), adjoint=True),
        )

        lb = tf.exp(-tf.reduce_sum(iN * t, -1) / 2) * beta
        tiL = t @ iL
        c = self.variance / tf.sqrt(tf.linalg.det(B))

        M = (tf.reduce_sum(lb, -1) * c)[:, None]
        V = tf.matmul(tiL, lb[:, :, None], adjoint_a=True)[..., 0] * c[:, None]

        # Calculate S: Predictive Covariance
        # R is built for every pair of outputs by broadcasting the
        # lengthscales along the first and second axis respectively.
        R = s @ tf.matrix_diag(
            1 / tf.square(self.lengthscales[None, :, :]) +
            1 / tf.square(self.lengthscales[:, None, :])
        ) + tf.eye(self.num_dims, dtype=float_type)

        # TODO: change this block according to the PR of tensorflow. Maybe move it into a function?
        X = inp[None, :, :, :] / tf.square(self.lengthscales[:, None, None, :])
        X2 = -inp[:, None, :, :] / tf.square(self.lengthscales[None, :, None, :])
        Q = tf.matrix_solve(R, s) / 2
        Xs = tf.reduce_sum(X @ Q * X, -1)
        X2s = tf.reduce_sum(X2 @ Q * X2, -1)
        maha = -2 * tf.matmul(X @ Q, X2, adjoint_b=True) + \
               Xs[:, :, :, None] + X2s[:, :, None, :]
        #
        k = tf.log(self.variance)[:, None] - \
            tf.reduce_sum(tf.square(iN), -1) / 2
        L = tf.exp(k[:, None, :, None] + k[None, :, None, :] + maha)
        S = (tf.tile(beta[:, None, None, :], [1, self.num_outputs, 1, 1])
             @ L @
             tf.tile(beta[None, :, :, None], [self.num_outputs, 1, 1, 1])
             )[:, :, 0, 0]

        # Take the entries of L where both output indices coincide, then
        # apply the corrections to the diagonal of S.
        diagL = tf.transpose(tf.linalg.diag_part(tf.transpose(L)))
        S = S - tf.diag(tf.reduce_sum(tf.multiply(iK, diagL), [1, 2]))
        S = S / tf.sqrt(tf.linalg.det(R))
        S = S + tf.diag(self.variance)
        S = S - M @ tf.transpose(M)

        return tf.transpose(M), S, tf.transpose(V)

    def centralized_input(self, m):
        """Return the training inputs shifted by ``m`` (``self.X - m``)."""
        return self.X - m

    def K(self, X1, X2=None):
        """Stack the kernel matrices of all models along a new leading axis."""
        return tf.stack(
            [model.kern.K(X1, X2) for model in self.models]
        )

    @property
    def Y(self):
        """Training targets of all models, concatenated along axis 1."""
        return tf.concat(
            [model.Y.parameter_tensor for model in self.models],
            axis=1
        )

    @property
    def X(self):
        """Training inputs, read from the first model."""
        return self.models[0].X.parameter_tensor

    @property
    def lengthscales(self):
        """Constrained kernel lengthscales of all models, stacked per output."""
        return tf.stack(
            [model.kern.lengthscales.constrained_tensor for model in self.models]
        )

    @property
    def variance(self):
        """Constrained kernel variances of all models, stacked per output."""
        return tf.stack(
            [model.kern.variance.constrained_tensor for model in self.models]
        )

    @property
    def noise(self):
        """Constrained likelihood variances of all models, stacked per output."""
        return tf.stack(
            [model.likelihood.variance.constrained_tensor for model in self.models]
        )
