"""A general-purpose model script. For different data, you should rewrite and tune it."""
from __future__ import print_function, division

from typing import Union, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Linear, BatchNorm1d, Module, ModuleList
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import PairTensor, OptTensor

from pyg_extension.nn.basemodel import BaseCrystalModel
from pyg_extension.nn.general import lift_jump_index_select


class CGConv(MessagePassing):
    r"""The crystal graph convolutional operator from the
    `"Crystal Graph Convolutional Neural Networks for an
    Accurate and Interpretable Prediction of Material Properties"
    <https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301>`_
    paper, **without** the self/residual term.

    This layer returns only the aggregated gated messages

    .. math::
        \mathbf{x}^{\prime}_i = \bigoplus_{j \in \mathcal{N}(i)}
        \sigma \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right)
        \odot g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s  \right)

    where :math:`\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j,
    \mathbf{e}_{i,j} ]` denotes the concatenation of central node features,
    neighboring node features and edge features (the edge part is omitted
    when no edge features are available), and :math:`\bigoplus` is the
    aggregation selected by ``aggr``.
    In addition, :math:`\sigma` and :math:`g` denote the sigmoid and softplus
    functions, respectively. If ``batch_norm=True``, batch normalization is
    applied to the aggregated result.

    Unlike the original formulation, :math:`\mathbf{x}_i` is NOT added back
    to the output; the residual connection is left to the caller (e.g.
    :class:`GeoResPack` in this module adds ``out + relu(conv(out))``).

    Args:
        channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        dim (int, optional): Edge feature dimensionality. (default: :obj:`0`)
        aggr (string, optional): The aggregation operator to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"add"`)
        batch_norm (bool, optional): If set to :obj:`True`, will make use of
            batch normalization. (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F)` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          a graph container ``data`` exposing ``edge_index``
          :math:`(2, |\mathcal{E}|)` (or ``adj_t`` when ``edge_index`` is
          missing/``None``) and optionally ``edge_attr``
          :math:`(|\mathcal{E}|, D)`
        - **output:** node features :math:`(|\mathcal{V}|, F)` or
          :math:`(|\mathcal{V_t}|, F_{t})` if bipartite
    """

    def __init__(self, channels: Union[int, Tuple[int, int]], dim: int = 0,
                 aggr: str = 'add', batch_norm: bool = False,
                 bias: bool = True, **kwargs):
        super().__init__(aggr=aggr, **kwargs)
        # Keep the user-supplied form (int or tuple) for __repr__.
        self.channels = channels
        self.dim = dim
        self.batch_norm = batch_norm

        if isinstance(channels, int):
            channels = (channels, channels)

        # Both gates see z = [x_i, x_j, e_ij], hence sum(channels) + dim inputs;
        # the output lives in the target space (channels[1]).
        self.lin_f = Linear(sum(channels) + dim, channels[1], bias=bias)
        self.lin_s = Linear(sum(channels) + dim, channels[1], bias=bias)
        if batch_norm:
            self.bn = BatchNorm1d(channels[1])
        else:
            self.bn = None

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both gate projections and, if present, batch norm."""
        self.lin_f.reset_parameters()
        self.lin_s.reset_parameters()
        if self.bn is not None:
            self.bn.reset_parameters()

    def message(self, x_i, x_j, edge_attr: OptTensor) -> Tensor:
        """Build one gated message per edge.

        ``z`` is the concatenation of target features, source features and
        (if given) edge features; the message is
        ``sigmoid(lin_f(z)) * softplus(lin_s(z))``.
        """
        if edge_attr is None:
            z = torch.cat([x_i, x_j], dim=-1)
        else:
            z = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.lin_f(z).sigmoid() * F.softplus(self.lin_s(z))

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.channels}, dim={self.dim})'

    def forward(self, x: Union[Tensor, PairTensor], data) -> Tensor:
        """Run one message-passing step on the graph stored in ``data``.

        Args:
            x: Node features, or a ``(source, target)`` pair for bipartite
                graphs.
            data: Graph container. ``data.edge_index`` is used when it exists
                and is not ``None``; otherwise ``data.adj_t`` is used.
                ``data.edge_attr`` (if it exists and is not ``None``) is
                concatenated into every message.

        Returns:
            The aggregated messages, batch-normalized when the layer was built
            with ``batch_norm=True``. The input ``x`` is not added back.
        """
        # getattr with a default tolerates containers that simply lack the
        # attribute, falling back to the sparse adjacency / no edge features.
        edge_index = getattr(data, 'edge_index', None)
        if edge_index is None:
            edge_index = data.adj_t
        edge_attr = getattr(data, 'edge_attr', None)

        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        if self.bn is not None:
            out = self.bn(out)
        return out

    def __lift__(self, src, edge_index, dim):
        # Delegates the "lift" step to the project helper
        # lift_jump_index_select; its exact semantics are defined in
        # pyg_extension.nn.general (not visible here).
        return lift_jump_index_select(self, src, edge_index, dim)


class GeoResPack(Module):
    """Residual stack of graph convolutions with an input projection.

    Node features are first projected from ``nc_node_hidden`` to
    ``nc_node_interaction`` channels by a linear layer. They then pass through
    ``n_res`` convolution layers, each applied as a residual update
    ``out = out + relu(conv(x=out, data=data))``.

    Args:
        m: Convolution class. Each layer is built as
            ``m(nc_node_interaction, nc_edge_hidden, aggr=..., bias=..., **kwargs)``
            and called as ``conv(x=..., data=...)``. It must also provide
            ``reset_parameters()`` for :meth:`reset_parameters` to work.
        nc_node_hidden (int): Size of the incoming node features.
        nc_node_interaction (int): Node feature size inside the residual
            stack, which is also the output size.
        nc_edge_hidden (int): Edge feature size, passed to ``m`` as its second
            positional argument.
        n_res (int): Number of residual convolution layers.
        **kwargs: Extra keyword arguments forwarded to every ``m(...)`` call.
            ``aggr`` defaults to ``'mean'`` and ``bias`` to ``True``; both
            may be overridden here.
    """

    def __init__(self, m, nc_node_hidden=64, nc_node_interaction=64, nc_edge_hidden=3,
                 n_res=1, **kwargs):
        super().__init__()

        self.res_lin0 = Linear(nc_node_hidden, nc_node_interaction)

        # Overridable defaults: passing aggr/bias through kwargs used to fail
        # with "got multiple values for keyword argument".
        kwargs.setdefault('aggr', 'mean')
        kwargs.setdefault('bias', True)
        self.res_layer = ModuleList(
            [m(nc_node_interaction, nc_edge_hidden, **kwargs) for _ in range(n_res)]
        )

        self.n_res = n_res

    def reset_parameters(self):
        """Re-initialize the input projection and every residual layer."""
        self.res_lin0.reset_parameters()
        # ModuleList itself has no reset_parameters(); reset each child.
        for conv in self.res_layer:
            conv.reset_parameters()

    def forward(self, h, data):
        """Project ``h`` and apply the residual convolution stack.

        Args:
            h: Node features with ``nc_node_hidden`` channels in the last
                dimension.
            data: Graph container forwarded unchanged to every conv layer.

        Returns:
            Node features with ``nc_node_interaction`` channels.
        """
        out = self.res_lin0(h)

        for conv in self.res_layer:
            out = out + F.relu(conv(x=out, data=data))

        return out


class CrystalGraphConvNeuralNet(BaseCrystalModel):
    """Crystal graph convolutional neural network (CGCNN).

    A :class:`BaseCrystalModel` whose interaction block is a
    :class:`GeoResPack` of :class:`CGConv` layers built with batch
    normalization. The number of residual CGConv layers is
    ``self.num_interactions``.

    The keyword defaults below (``nfeat_edge=3``, ``nc_node_interaction=16``,
    ``nc_node_hidden=8``) are simply forwarded to ``BaseCrystalModel``;
    presumably they are the edge-feature count and node channel sizes
    (TODO confirm against the base class).
    """

    def __init__(self, *args, nfeat_edge=3, nc_node_interaction=16,
                 nc_node_hidden=8, **kwargs):
        super().__init__(
            *args,
            nfeat_edge=nfeat_edge,
            nc_node_interaction=nc_node_interaction,
            nc_node_hidden=nc_node_hidden,
            **kwargs,
        )
        # This network does not consume a global state vector.
        self.nfeat_state = None

    def get_interactions_layer(self):
        """Build ``self.layer_interaction`` as a residual stack of CGConv.

        Reads ``nc_node_hidden``, ``nc_node_interaction``, ``nc_edge_hidden``,
        ``num_interactions`` and ``interaction_kwargs`` from ``self``
        (presumably populated by ``BaseCrystalModel`` - verify there).
        """
        interaction = GeoResPack(
            CGConv,
            self.nc_node_hidden,
            self.nc_node_interaction,
            nc_edge_hidden=self.nc_edge_hidden,
            batch_norm=True,
            n_res=self.num_interactions,
            **self.interaction_kwargs,
        )
        self.layer_interaction = interaction


# Short public alias for CrystalGraphConvNeuralNet.
CGCNN = CrystalGraphConvNeuralNet
