import torch
from torch import nn
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class PositionalEncoding(nn.Module):
    r"""Inject information about the absolute position of each token in the sequence.

    The positional encodings have the same dimension as the embeddings, so the
    two can be summed. Sine and cosine functions of different frequencies are
    used:

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin\left(pos / 10000^{2i / d_{model}}\right)

        \text{PosEncoder}(pos, 2i+1) = \cos\left(pos / 10000^{2i / d_{model}}\right)

    where :math:`pos` is the token position and :math:`i` is the embedding index.
    Odd ``d_model`` values are supported: the last (even-indexed) channel gets a
    sine term with no matching cosine channel.

    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(512)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model) for every even channel index 2i, computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd channel than even channels,
        # so the cosine half must drop the last frequency to match shapes.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # (max_len, 1, d_model): sequence-first layout with a broadcastable batch axis.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        r"""Add positional encodings to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]

        The sequence length must not exceed the ``max_len`` given at construction.

        Examples:
            >>> output = pos_encoder(x)
        """
        # self.pe is (max_len, 1, d_model); its singleton axis broadcasts over batch.
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


def generate_square_subsequent_mask(sz):
    """Build an additive causal attention mask of shape (sz, sz).

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf when
    j > i (future positions are blocked). The result is a float32 tensor on the
    default (CPU) device.
    """
    # Start from an all -inf matrix and zero out the diagonal and everything below it.
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float)
    return torch.triu(blocked, diagonal=1)


class TransformerLM(nn.Module):
    '''Causally-masked Transformer encoder stack over token embeddings.

    Adapted from PyTorch-examples. ``forward`` takes a batch-first tensor of
    token ids and returns the final hidden states (after the closing
    LayerNorm) in batch-first layout. No projection back to the vocabulary is
    applied here.

    Args:
        vocab_size: number of token ids accepted by the embedding.
        nb_heads: number of attention heads per encoder layer.
        dim_res: model (residual stream / embedding) width.
        dim_ff: hidden width of each layer's feed-forward block.
        nb_layers: number of stacked encoder layers.
        dropout: dropout probability used by the positional encoder and layers.
    '''
    def __init__(self, vocab_size, nb_heads, dim_res, dim_ff, nb_layers, dropout):
        super().__init__()
        # Cached causal mask; rebuilt in forward when the sequence length or the
        # input device changes (it is a plain attribute, so .to() does not move it).
        self.src_mask = None
        self.encoder = nn.Embedding(vocab_size, dim_res)
        self.pos_encoder = PositionalEncoding(dim_res, dropout)

        layer_norm = nn.LayerNorm(dim_res)
        encoder_layers = TransformerEncoderLayer(dim_res, nb_heads, dim_ff, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nb_layers, layer_norm)
        self.dim_res = dim_res

        self.init_weights()
        # NOTE(review): in_len / batch_first look like flags read by external
        # training code; their consumers are not visible here.
        self.in_len = 1
        self.batch_first = True

    def init_hidden(self, batch_size):
        """Return a dummy (batch_size, 1) zero "hidden state".

        The tensor matches the parameters' dtype and device. ``forward`` does
        not use ``hidden``; it only passes it through.
        """
        w = next(self.parameters())
        return torch.zeros((batch_size, 1), dtype=w.dtype, device=w.device)

    def init_weights(self):
        """Initialise the token embedding uniformly in [-0.1, 0.1]."""
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)

    def forward(self, src, hidden=None):
        """Encode a batch of token-id sequences with causal self-attention.

        Args:
            src: integer token ids of shape (batch, seq).
            hidden: ignored; returned unchanged for interface compatibility.

        Returns:
            A tuple ``(output, hidden)``, where ``output`` has shape
            (batch, seq, dim_res).
        """
        # The encoder stack runs sequence-first: (batch, seq) -> (seq, batch).
        src = src.t()
        seq_len = src.size(0)
        device = src.device
        cached = self.src_mask
        # Rebuild on a device change too, otherwise a mask created before
        # moving the model (e.g. to GPU) would be reused on the wrong device.
        if cached is None or cached.size(0) != seq_len or cached.device != device:
            self.src_mask = generate_square_subsequent_mask(seq_len).to(device)

        # Scale embeddings by sqrt(d_model) as in "Attention Is All You Need".
        src = self.encoder(src) * math.sqrt(self.dim_res)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)

        # Back to batch-first: (seq, batch, dim) -> (batch, seq, dim).
        return output.permute(1, 0, 2), hidden
