# Diffusion Model taken from DeepFindr (https://www.youtube.com/@DeepFindr)
import math

import torch
from torch import nn

from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
  GaussianDiffuser,
)
from diffusion_models.models.base_diffusion_model import BaseDiffusionModel


class Block(nn.Module):
  """A single U-Net stage with timestep conditioning.

  The stage runs ``conv3x3 -> ReLU -> BatchNorm``, adds a ReLU-activated linear
  projection of the time embedding (broadcast over the two trailing spatial
  dimensions), runs a second ``conv3x3 -> ReLU -> BatchNorm``, and finally
  resamples with a kernel-4, stride-2, padding-1 layer: a strided ``Conv2d``
  on the way down, a ``ConvTranspose2d`` on the way up.

  Args:
    in_ch: Channel width of the stage input. An upsampling stage receives a
      channel-wise concatenation of two ``in_ch`` tensors, so its first
      convolution expects ``2 * in_ch`` channels.
    out_ch: Channel width produced by every layer in the stage.
    time_emb_dim: Width (last dimension) of the time embedding given to
      ``forward``.
    up: ``True`` for a decoder (upsampling) stage, ``False`` for an encoder
      (downsampling) stage.
  """

  def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
    super().__init__()
    first_conv_in = 2 * in_ch if up else in_ch
    resample = nn.ConvTranspose2d if up else nn.Conv2d
    # Registration order is significant: it fixes the order in which
    # parameters are randomly initialised and the order of state_dict keys.
    self.time_mlp = nn.Linear(time_emb_dim, out_ch)
    self.conv1 = nn.Conv2d(first_conv_in, out_ch, kernel_size=3, padding=1)
    self.transform = resample(out_ch, out_ch, kernel_size=4, stride=2, padding=1)
    self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
    self.bnorm1 = nn.BatchNorm2d(out_ch)
    self.bnorm2 = nn.BatchNorm2d(out_ch)
    self.relu = nn.ReLU()

  def forward(
    self,
    x,
    t,
  ):
    """Run the stage on feature map ``x`` conditioned on time embedding ``t``.

    Args:
      x: Feature map; presumably ``(batch, channels, H, W)`` — TODO confirm.
      t: Time embedding whose last dimension is ``time_emb_dim``.

    Returns:
      Feature map with ``out_ch`` channels whose spatial size is
      ``floor(H / 2)`` for a down stage or ``2 * H`` for an up stage.
    """
    features = self.bnorm1(self.relu(self.conv1(x)))
    # Append two singleton axes so the per-channel time bias broadcasts
    # over the spatial dimensions of the feature map.
    time_bias = self.relu(self.time_mlp(t))[..., None, None]
    features = features + time_bias
    features = self.bnorm2(self.relu(self.conv2(features)))
    return self.transform(features)


class SinusoidalPositionEmbeddings(nn.Module):
  """Fixed (non-learned) sinusoidal embedding of scalar timesteps.

  With ``k = dim // 2``, each timestep ``t`` is mapped to
  ``[sin(t*f_0), ..., sin(t*f_{k-1}), cos(t*f_0), ..., cos(t*f_{k-1})]``
  where ``f_i = 10000 ** (-i / (k - 1))``, i.e. frequencies spaced
  geometrically from 1 down to 1/10000. The sin block followed by the cos
  block (concatenated, not interleaved) is the same layout used by the DDPM
  reference ``get_timestep_embedding``. Because the result is fed through
  learned layers, the ordering is a convention, not a correctness issue.

  Args:
    dim: Number of output features. If ``dim`` is odd, a zero column is
      appended so the output always has exactly ``dim`` columns.
  """

  def __init__(self, dim):
    super().__init__()
    self.dim = dim

  def forward(self, time):
    """Embed a batch of timesteps.

    Args:
      time: 1-D tensor of timesteps (indexed as ``time[:, None]``), integer
        or floating point.

    Returns:
      Floating-point tensor of shape ``(len(time), dim)`` on ``time.device``.
    """
    half_dim = self.dim // 2
    # For dim in (2, 3), half_dim - 1 == 0 would raise ZeroDivisionError;
    # clamping the denominator to 1 leaves every larger dim unchanged.
    exponent_scale = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=time.device) * -exponent_scale)
    angles = time[:, None] * freqs[None, :]
    embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
    if self.dim % 2:
      # sin/cos pairs only fill 2 * half_dim columns; pad so callers that
      # size layers by ``dim`` (e.g. nn.Linear(dim, ...)) get a matching width.
      embeddings = torch.nn.functional.pad(embeddings, (0, 1))
    return embeddings


class SimpleUnet(BaseDiffusionModel):
  """A simplified variant of the U-Net architecture used in DDPM.

  An initial 3x3 convolution lifts the image to ``down_channels[0]`` channels.
  A stack of downsampling ``Block``s then walks through ``down_channels``, and
  a mirrored stack of upsampling ``Block``s walks back. Before each decoder
  stage, the matching encoder activation is concatenated along the channel
  axis (skip connection). A final 1x1 convolution maps back to
  ``image_channels``. Every ``Block`` is conditioned on one shared
  embedding of the timestep (sinusoidal features followed by Linear + ReLU).
  """

  def __init__(
    self,
    diffuser: GaussianDiffuser,
    image_channels: int,
    *,
    down_channels: "tuple[int, ...]" = (64, 128, 256, 512, 1024),
    time_emb_dim: int = 32,
  ):
    """A Simplified variant of the Unet architecture used in DDPM.

    Args:
      diffuser: A gaussian diffuser, forwarded to ``BaseDiffusionModel``.
      image_channels: The number of image channels (used for both the input
        and the output).
      down_channels: Encoder channel widths, from the initial projection down
        to the bottleneck. The decoder uses the same widths in reverse order.
        The default reproduces the original fixed architecture.
      time_emb_dim: Width of the timestep embedding; use an even value.

    Raises:
      ValueError: If ``down_channels`` is empty.
    """
    super().__init__(diffuser=diffuser)
    down_channels = tuple(down_channels)
    if not down_channels:
      raise ValueError("down_channels must contain at least one width")
    # The decoder mirrors the encoder so each skip connection lines up.
    up_channels = down_channels[::-1]

    # Time embedding
    self.time_mlp = nn.Sequential(
      SinusoidalPositionEmbeddings(time_emb_dim),
      nn.Linear(time_emb_dim, time_emb_dim),
      nn.ReLU(),
    )

    # Initial projection
    self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

    # Downsample: each Block maps spatial size H -> floor(H / 2).
    self.downs = nn.ModuleList(
      [
        Block(c_in, c_out, time_emb_dim)
        for c_in, c_out in zip(down_channels, down_channels[1:])
      ]
    )
    # Upsample: each Block consumes [x, skip] (2 * c_in channels) and
    # doubles the spatial size.
    self.ups = nn.ModuleList(
      [
        Block(c_in, c_out, time_emb_dim, up=True)
        for c_in, c_out in zip(up_channels, up_channels[1:])
      ]
    )

    # 1x1 projection from the last decoder width back to the image channels.
    self.output = nn.Conv2d(up_channels[-1], image_channels, 1)

  def forward(self, x, timestep):
    """Run the U-Net on image batch ``x`` at the given timesteps.

    Args:
      x: Image batch; presumably ``(batch, image_channels, H, W)`` — TODO
        confirm against callers.
      timestep: 1-D tensor of timesteps, one per batch element (the
        sinusoidal embedding indexes it as ``timestep[:, None]``).

    Returns:
      Tensor with ``image_channels`` channels and the same spatial size as
      ``x``.

    Raises:
      ValueError: If ``H`` or ``W`` is not divisible by ``2 ** len(self.downs)``.
        Without this check, such sizes either fail inside ``torch.cat`` or,
        for sizes like 17, silently return a smaller output than the input.
    """
    factor = 2 ** len(self.downs)
    height, width = x.shape[-2], x.shape[-1]
    if height % factor or width % factor:
      raise ValueError(
        f"SimpleUnet requires spatial dims divisible by {factor}, "
        f"got {height}x{width}"
      )

    # Embed time
    t = self.time_mlp(timestep)
    # Initial conv
    x = self.conv0(x)
    # Encoder: keep every stage output for the skip connections.
    residual_inputs = []
    for down in self.downs:
      x = down(x, t)
      residual_inputs.append(x)
    # Decoder: consume skips deepest-first, appending them as extra channels.
    for up in self.ups:
      x = torch.cat((x, residual_inputs.pop()), dim=1)
      x = up(x, t)
    return self.output(x)
