import torch
import torch.nn as nn
from torch.distributions import Normal, OneHotCategorical
from torch.nn import functional as F


class MixtureDensityNetwork(nn.Module):
    """
    Mixture density network; implementation mostly from https://github.com/tonyduan/mdn.
    Used as the final/output block of a network.
    [ Bishop, 1994 ]

    ``forward`` returns a pair ``(pi, normal)``: the distribution over mixture
    components and the per-component Gaussians. The static helpers ``loss``,
    ``sample``, ``mean`` and ``stddev`` all take that pair as input and treat
    dim 1 as the component axis and dim 2 as the response axis.

    Parameters
    ----------
    dim_in: int; dimensionality of the covariates
    dim_out: int; dimensionality of the response variable
    n_components: int; number of components in the mixture model
    """

    def __init__(self, dim_in, dim_out, n_components):
        super().__init__()
        self.pi_network = CategoricalNetwork(dim_in, n_components)
        self.normal_network = MixtureDiagNormalNetwork(dim_in, dim_out,
                                                       n_components)

    def forward(self, x):
        """
        Run both heads on the same input.

        Parameters
        ----------
        x: tensor; covariates, fed unchanged to both sub-networks

        Returns
        -------
        (pi, normal): outputs of ``pi_network`` and ``normal_network``
        """
        return self.pi_network(x), self.normal_network(x)

    @staticmethod
    def loss(pi, normal, y):
        """
        Per-sample negative log-likelihood of ``y`` under the mixture.

        Parameters
        ----------
        pi: distribution over components exposing ``logits`` of shape
            (batch, n_components)
        normal: distribution whose ``loc`` has shape
            (batch, n_components, dim_out)
        y: tensor of shape (batch, dim_out); observed responses

        Returns
        -------
        tensor of shape (batch,); not reduced over the batch
        """
        # Evaluate every component at the same target.
        loglik = normal.log_prob(y.unsqueeze(1).expand_as(normal.loc))
        # Diagonal covariance: the joint log-density is the sum over outputs.
        loglik = torch.sum(loglik, dim=2)
        # Stay in log-space for the mixing weights. torch.log(pi.probs)
        # underflows to -inf for very unlikely components, which silently
        # drops them from the logsumexp and yields inf/NaN gradients.
        # log_softmax is a no-op on already-normalised logits.
        log_pi = torch.log_softmax(pi.logits, dim=-1)
        return -torch.logsumexp(log_pi + loglik, dim=1)

    @staticmethod
    def sample(pi, normal):
        """
        Draw one sample per batch row from the mixture.

        A one-hot component indicator is sampled from ``pi`` and used to mask
        a draw taken from every component of ``normal``; summing over the
        component axis keeps only the selected component's draw.

        Returns
        -------
        tensor of shape (batch, dim_out)
        """
        samples = torch.sum(pi.sample().unsqueeze(2) * normal.sample(), dim=1)
        return samples

    @staticmethod
    def mean(pi, normal):
        """
        Mean of the mixture: the ``pi``-weighted sum of the component means.

        Returns
        -------
        tensor of shape (batch, dim_out)
        """
        means = torch.sum(pi.mean.unsqueeze(2) * normal.mean, dim=1)
        return means

    @staticmethod
    def stddev(pi, normal):
        """
        Per-dimension standard deviation of the mixture.

        Uses the law of total variance,
        ``Var[y] = sum_k pi_k * (sigma_k^2 + (mu_k - mu)^2)`` with
        ``mu = sum_k pi_k * mu_k``. Summing the variances of the individual
        ``indicator_k * component_k`` products is not valid here because the
        one-hot indicators are negatively correlated with each other.

        Returns
        -------
        tensor of shape (batch, dim_out)
        """
        weights = pi.mean.unsqueeze(2)
        mixture_mean = torch.sum(weights * normal.mean, dim=1, keepdim=True)
        # Within-component variance plus spread of the component means; every
        # term is non-negative, so the sqrt below never sees a negative value.
        spread = (normal.mean - mixture_mean) ** 2
        var = torch.sum(weights * (normal.variance + spread), dim=1)
        return torch.sqrt(var)


class MixtureDiagNormalNetwork(nn.Module):
    """
    Two-layer MLP head that outputs the Gaussian components of a mixture.

    The final linear layer emits ``2 * out_dim * n_components`` values per
    row: the first half are the component means, the second half are raw
    scale values that ``forward`` maps to strictly positive numbers.

    Parameters
    ----------
    in_dim: int; size of the input features
    out_dim: int; size of the response variable
    n_components: int; number of mixture components
    hidden_dim: int or None; width of the hidden layer, ``in_dim`` when None
    """

    def __init__(self, in_dim, out_dim, n_components, hidden_dim=None):
        super().__init__()
        self.n_components = n_components
        width = in_dim if hidden_dim is None else hidden_dim
        self.network = nn.Sequential(
            nn.Linear(in_dim, width),
            nn.ReLU(),
            nn.Linear(width, 2 * out_dim * n_components),
        )

    def _group_by_component(self, flat):
        """Cut dim 1 into equal per-component chunks and stack them on a new dim 1."""
        chunk = flat.shape[1] // self.n_components
        return torch.stack(flat.split(chunk, dim=1), dim=1)

    def forward(self, x):
        """
        Map ``x`` to a ``Normal`` whose ``loc`` and ``scale`` carry the
        component axis at dim 1.
        """
        raw = self.network(x)
        # First half of dim 1 holds means, second half holds raw scales.
        loc_flat, scale_flat = raw.split(raw.shape[1] // 2, dim=1)
        # improve stability by using elu instead of exp; elu(.) + 1 is
        # non-negative and the small constant keeps it strictly positive
        scale_flat = F.elu(scale_flat) + 1 + 1e-15
        loc = self._group_by_component(loc_flat)
        scale = self._group_by_component(scale_flat)
        return Normal(loc, scale)


class CategoricalNetwork(nn.Module):
    """
    Two-layer MLP head that outputs a one-hot categorical distribution.

    Parameters
    ----------
    in_dim: int; size of the input features
    out_dim: int; number of categories (one logit per category)
    hidden_dim: int or None; width of the hidden layer, ``in_dim`` when None
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        width = in_dim if hidden_dim is None else hidden_dim
        # Layers are built in forward order so parameter names stay
        # ``network.0.*`` and ``network.2.*``.
        layers = [
            nn.Linear(in_dim, width),
            nn.ReLU(),
            nn.Linear(width, out_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``OneHotCategorical`` parameterised by the network output as logits."""
        return OneHotCategorical(logits=self.network(x))
