import logging
from ..static import LOGGER_PRINT_LEVEL
# Module logger setup: ensure the root logger has a default handler
# (logging.basicConfig() does nothing if the root logger is already
# configured), then set this module's logger to LOGGER_PRINT_LEVEL, which is
# imported from the parent package's `static` module.
# NOTE(review): basicConfig() at import time configures logging for the whole
# process; consider leaving that to the application entry point.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(LOGGER_PRINT_LEVEL)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class Attention(nn.Module):
    """Dot-product attention followed by a tanh output projection.

    Each step of ``output`` is scored against every position of ``context``
    by dot product. The scores are softmax-normalised over the context
    positions and used to take a weighted sum of ``context`` (the "mix").
    The mix is concatenated with ``output`` and passed through
    ``tanh(Linear(2 * dim -> dim))``.

    Args:
        dim: Feature size of both ``output`` and ``context``.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        # Projects the concatenation [mix; output] (2 * dim) back to dim.
        self.linear_out = nn.Linear(dim * 2, dim)

    def forward(self, output, context):
        """Attend from ``output`` over ``context``.

        Args:
            output: Query tensor of shape ``(batch, out_len, dim)``.
            context: Memory tensor of shape ``(batch, in_len, dim)``.

        Returns:
            Tuple ``(output, attn)``:

            - ``output``: the attended representation, shape
              ``(batch, out_len, dim)``.
            - ``attn``: the attention weights, shape
              ``(batch, out_len, in_len)``; each row sums to 1.
        """
        # (batch, out_len, dim) x (batch, dim, in_len) -> (batch, out_len, in_len)
        scores = torch.bmm(output, context.transpose(1, 2))

        # Scores that are exactly zero are treated as padding (e.g. all-zero
        # context vectors) and excluded from the softmax. A boolean mask is
        # required by masked_fill. If a whole row would be masked, the softmax
        # would produce NaN, so such rows are left unmasked; their scores are
        # all zero, which yields uniform weights.
        mask = scores == 0
        mask = mask & ~mask.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)

        # (batch, out_len, in_len) x (batch, in_len, dim) -> (batch, out_len, dim)
        mix = torch.bmm(attn, context)

        # concat -> (batch, out_len, 2 * dim). nn.Linear acts on the last
        # dimension, so no flattening is needed.
        combined = torch.cat((mix, output), dim=2)
        output = torch.tanh(self.linear_out(combined))

        return output, attn