#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from abc import abstractmethod

import torch
from torch import nn
from torch.nn.init import kaiming_normal_

__author__ = "Christian Heider Nielsen"
__doc__ = r"""
           """

from draugr.torch_utilities import to_tensor

__all__ = ["VAE"]


class VAE(torch.nn.Module):
    """Base class for variational autoencoders.

    Subclasses supply ``encode`` and ``decode``. This class provides
    weight-initialisation helpers, the reparameterisation trick, and helpers
    that draw or accept latent vectors and decode them.

    Note: ``@abstractmethod`` is not enforced here, because ``torch.nn.Module``
    does not use ``ABCMeta``. Unimplemented methods instead raise
    ``NotImplementedError`` when called.
    """

    class View(nn.Module):
        """Module that reshapes its input to a fixed ``size`` via ``Tensor.reshape``."""

        def __init__(self, size):
            """
            Args:
              size: target shape, passed verbatim to ``Tensor.reshape``.
            """
            super().__init__()
            self.size = size

        def forward(self, tensor):
            """Reshape ``tensor`` to ``self.size``.

            Args:
              tensor: input tensor; its element count must be compatible with
                ``self.size``.

            Returns:
              The reshaped tensor.
            """
            return tensor.reshape(self.size)

    @staticmethod
    def kaiming_init(m):
        """Kaiming-normal initialise linear/conv layers and reset batch-norm params.

        ``nn.Linear`` and ``nn.Conv2d`` get ``kaiming_normal_`` weights and a
        zero bias (if present). ``nn.BatchNorm1d`` and ``nn.BatchNorm2d`` get
        weight 1 and bias 0 (if present). Any other module is left untouched.

        Args:
          m: the module to initialise.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            kaiming_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # Batch norms built with affine=False have weight and bias set to None.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.fill_(0)

    @staticmethod
    def normal_init(m, mean, std):
        """Normal-initialise linear/conv weights and reset batch-norm params.

        ``nn.Linear`` and ``nn.Conv2d`` weights are drawn from
        ``N(mean, std**2)`` and their bias (if present) is zeroed.
        ``nn.BatchNorm1d`` and ``nn.BatchNorm2d`` get weight 1 and bias 0
        (if present). Any other module is left untouched.

        Args:
          m: the module to initialise.
          mean: mean of the normal distribution for the weights.
          std: standard deviation of the normal distribution for the weights.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            m.weight.data.normal_(mean, std)
            # Test the parameter itself: for bias=False layers ``m.bias`` is
            # None, so ``m.bias.data`` cannot be dereferenced.
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            # Batch norms built with affine=False have weight and bias set to None.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()

    def weight_init(self):
        """Apply ``kaiming_init`` to every module yielded by ``self.modules()``, including ``self``."""
        for m in self.modules():
            self.kaiming_init(m)

    def __init__(self, latent_size=10):
        """
        Args:
          latent_size: dimensionality of the latent vector. ``sample`` uses it
            to draw ``z``, and ``sample_from`` checks it against the last
            dimension of its input.
        """
        super().__init__()
        self._latent_size = latent_size

    @abstractmethod
    def encode(self, *x: torch.Tensor) -> torch.Tensor:
        """Encode inputs; must be implemented by subclasses.

        Args:
          *x: subclass-defined inputs.

        Raises:
          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @abstractmethod
    def decode(self, *x: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors; must be implemented by subclasses.

        ``sample`` calls this with the latent batch ``z`` of shape
        ``(num, latent_size)`` first, followed by any extra positional
        arguments given to ``sample``.

        Args:
          *x: latent tensor followed by subclass-defined extra inputs.

        Raises:
          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def sample(self, *x, num=1) -> torch.Tensor:
        """Draw ``num`` latent vectors from a standard normal and decode them.

        Args:
          *x: extra positional arguments forwarded to ``decode`` after ``z``.
            Their meaning depends on the subclass (presumably conditioning
            inputs).
          num: number of latent vectors to draw.

        Returns:
          The output of ``decode``, moved to the CPU.
        """
        # z is drawn with the default (CPU) generator, then moved to the device
        # of the model's first parameter.
        z = torch.randn(num, self._latent_size).to(
            device=next(self.parameters()).device
        )
        return self.decode(z, *x).to("cpu")

    @staticmethod
    def reparameterise(mean, log_var) -> torch.Tensor:
        """Reparameterisation trick: ``mean + eps * exp(log_var / 2)`` with ``eps ~ N(0, 1)``.

        The randomness is isolated in ``eps``, so the result stays
        differentiable with respect to ``mean`` and ``log_var``.

        Args:
          mean: mean of the distribution.
          log_var: log-variance of the distribution, broadcastable with ``mean``.

        Returns:
          A sample with the shape of ``eps * std`` broadcast against ``mean``.
        """
        std = log_var.div(2).exp()  # exp(0.5 * log(sigma^2)) = sigma
        eps = torch.randn_like(std)  # eps ~ N(0, 1), same shape/dtype/device as std
        return eps.mul(std).add_(mean)  # in-place add on the freshly allocated product

    def sample_from(self, *encoding) -> torch.Tensor:
        """Decode caller-supplied latent vector(s).

        Args:
          *encoding: forwarded verbatim to ``draugr.torch_utilities.to_tensor``.
            The resulting tensor's last dimension must equal the latent size.

        Returns:
          The decoded output, moved to the CPU.

        Raises:
          AssertionError: if the last dimension does not match the latent size.
            The check is skipped when Python runs with ``-O``.
        """
        sample = to_tensor(*encoding).to(device=next(self.parameters()).device)
        assert sample.shape[-1] == self._latent_size, (
            f"sample.shape[-1]:{sample.shape[-1]} !="
            f" self._latent_size:{self._latent_size}"
        )
        # NOTE(review): ``*sample`` unpacks along dim 0, so each row becomes a
        # separate positional argument to ``decode``. ``sample`` instead passes
        # the whole batch as ``z``, so this path only suits single-row
        # encodings. Confirm intent before changing; callers may depend on it.
        sample = self.decode(*sample).to("cpu")
        return sample
