import math
import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Precomputes a buffer ``pe`` of shape ``(max_len, 1, d_model)``. Even
    channels hold ``sin(pos * w_i)`` and odd channels hold ``cos(pos * w_i)``,
    with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Number of encoding channels. Odd values are supported; the
            last channel is then a sine term without a matching cosine.
        max_len: Number of positions precomputed (positions ``0..max_len-1``).
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there are ceil(d/2) sine columns but only floor(d/2)
        # cosine columns, so drop the last frequency to keep shapes matching.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model): the singleton middle dim
        # lets the encoding broadcast over the batch dimension in forward().
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions to ``x``.

        ``x`` is expected to be sequence-first, ``(seq_len, batch, d_model)``;
        the ``(seq_len, 1, d_model)`` slice of ``pe`` broadcasts over the batch
        dimension. ``seq_len`` must not exceed ``max_len``.
        """
        return x + self.pe[:x.size(0)]


class PositionalEncoding2D(nn.Module):
    """
    2D positional encoding for pytorch-geometric Data objects.

    Each node's features are offset by the concatenation of a sinusoidal
    encoding of its x coordinate (first ``d_model // 2`` channels) and of its
    y coordinate (last ``d_model // 2`` channels).

    Args:
        d_model: Width of the node features; must be divisible by 4 so each
            coordinate gets an even number of encoding channels.
        max_height: Number of distinct y positions; y is wrapped modulo this.
        max_width: Number of distinct x positions; x is wrapped modulo this.
    """

    def __init__(self, d_model, max_height=10000, max_width=10000):
        super().__init__()
        assert d_model % 4 == 0
        self.max_height = max_height
        self.max_width = max_width
        self.x_pe = PositionalEncoding(d_model // 2, max_width)
        self.y_pe = PositionalEncoding(d_model // 2, max_height)

    def forward(self, data):
        """Add the 2D positional encoding to ``data.x`` and return ``data``.

        Args:
            data: Object with ``pos`` (per-node coordinates, column 0 = x,
                column 1 = y) and ``x`` (node features whose last dimension is
                ``d_model``); presumably a ``torch_geometric.data.Data``.

        Returns:
            The same ``data`` object, with ``data.x`` rebound to a new tensor
            of the same dtype.
        """
        # Index tensors must be integral; pos is frequently stored as float.
        # Cast before the modulo so float rounding cannot yield an index equal
        # to max_width / max_height. Non-integer coordinates are truncated.
        pos = data.pos.long()
        pos_x = pos[:, 0] % self.max_width
        pos_y = pos[:, 1] % self.max_height
        # pe buffers are (max_len, 1, d_model // 2); index the singleton dim away.
        x_encoding = self.x_pe.pe[pos_x, 0]
        y_encoding = self.y_pe.pe[pos_y, 0]
        encoding = torch.cat((x_encoding, y_encoding), dim=1)
        # Out-of-place add: an in-place `+=` fails on leaf tensors that require
        # grad and would permanently modify any storage data.x shares, so the
        # encoding would compound on repeated calls. Casting keeps the feature
        # dtype identical to what the in-place op produced.
        data.x = data.x + encoding.to(data.x.dtype)
        return data
