# AUTOGENERATED! DO NOT EDIT! File to edit: ../12_DM1.ipynb.

# %% auto 0
# Names exported by a star-import of this module.
__all__ = ['BasicConvNet', 'PreNormResidual', 'FeedForward', 'MLPMixer', 'BasicUNet', 'NoiseConditionedUNet',
           'NoiseAndClassConditionedUNet']

# %% ../12_DM1.ipynb 4
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from .data_utils import get_mnist_dl

# %% ../12_DM1.ipynb 18
#|code-fold: true
class BasicConvNet(nn.Module):
    """Plain stack of seven 2D convolutions with ReLU between them.

    Every conv uses ``padding = kernel_size // 2``, so for an odd
    ``kernel_size`` the spatial size of the output equals that of the input.
    The hidden channel widths are fixed at 16, 32, 64, 64, 64, 16; no
    activation is applied after the final conv.

    Args:
        in_channels: number of channels in the input image.
        out_channels: number of channels in the output image.
        kernel_size: side length of every conv kernel (default 5).
    """
    def __init__(self, in_channels, out_channels, kernel_size=5):
        super().__init__()
        same_pad = kernel_size // 2  # "same" padding for odd kernels
        widths = [in_channels, 16, 32, 64, 64, 64, 16]
        layers = []
        # One conv + ReLU for each consecutive pair of widths ...
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size, padding=same_pad))
            layers.append(nn.ReLU())
        # ... then a final conv, without activation, to the requested channels.
        layers.append(nn.Conv2d(widths[-1], out_channels, kernel_size, padding=same_pad))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv stack and return the result."""
        out = self.net(x)
        return out

# %% ../12_DM1.ipynb 19
#|code-fold: true
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce

class PreNormResidual(nn.Module):
    """Wrap ``fn`` with a pre-applied LayerNorm and a skip connection.

    Computes ``fn(LayerNorm(x)) + x``. The LayerNorm normalises over the last
    dimension, which must have size ``dim``; the output of ``fn`` must be
    addable to ``x``.

    Args:
        dim: size of the last dimension of the input.
        fn: callable (typically an ``nn.Module``) applied to the normalised input.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return the residual sum of ``x`` and ``fn`` applied to normalised ``x``."""
        normed = self.norm(x)
        transformed = self.fn(normed)
        return transformed + x

def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
    """Build a two-layer feed-forward block: dense -> GELU -> dropout -> dense -> dropout.

    Args:
        dim: input and output feature size.
        expansion_factor: the hidden size is ``int(dim * expansion_factor)``.
        dropout: dropout probability used after the activation and after the
            second dense layer.
        dense: factory called as ``dense(in_features, out_features)`` to build
            each of the two projection layers (``nn.Linear`` by default).

    Returns:
        An ``nn.Sequential`` holding the five layers in order.
    """
    hidden = int(dim * expansion_factor)
    layers = [
        dense(dim, hidden),   # expand (or contract) to the hidden size
        nn.GELU(),
        nn.Dropout(dropout),
        dense(hidden, dim),   # project back to the original size
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)

def MLPMixer(*, image_size, channels, patch_size, dim, depth, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.):
    """A minimal image-to-image MLP Mixer, adapted from lucidrains' implementation.

    The image is cut into ``patch_size`` x ``patch_size`` patches, each patch is
    linearly embedded to ``dim`` features, ``depth`` mixer blocks are applied,
    and the patch features are then folded back into an image and mapped to
    ``channels`` channels with a 1x1 conv, so the output has the same shape as
    the input.

    Args (keyword-only):
        image_size: int (square image) or ``(height, width)`` tuple.
        channels: number of image channels (input and output).
        patch_size: side length of each square patch; must divide both image sides.
        dim: per-patch embedding size; must be divisible by ``patch_size ** 2``
            so the patch features can be folded back into pixels.
        depth: number of mixer blocks.
        expansion_factor: hidden-size multiplier for the across-patch feed-forward.
        expansion_factor_token: hidden-size multiplier for the per-patch feed-forward.
        dropout: dropout probability used inside every feed-forward block.

    Returns:
        An ``nn.Sequential`` model.

    Raises:
        AssertionError: if the image is not divisible by ``patch_size`` or
            ``dim`` is not divisible by ``patch_size ** 2``.
    """
    # Get image height and width (the same if image_size isn't a tuple):
    image_h, image_w = image_size if isinstance(image_size, tuple) else (image_size, image_size)
    # Check they divide neatly by patch_size
    assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
    # The final Rearrange splits each dim-sized patch vector into (p1 p2 c);
    # fail here with a clear message rather than inside einops at forward time.
    patch_area = patch_size ** 2
    assert (dim % patch_area) == 0, 'dim must be divisible by patch_size ** 2'
    grid_h, grid_w = image_h // patch_size, image_w // patch_size
    num_patches = grid_h * grid_w
    # Prep the two layer types: a 1x1 Conv1d mixes across patches (the patch
    # axis is its channel axis), nn.Linear mixes within each patch's features.
    chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear
    # Return the model (a stack of [FeedForward(chan_first), FeedForward(chan_last)] pairs
    # with layer norm on the inputs and a skip connection thanks to PreNormResidual)
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
        nn.Linear(patch_area * channels, dim),
        *[nn.Sequential(
            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
            PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
        ) for _ in range(depth)],
        Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h = grid_h, w = grid_w, p1 = patch_size, p2 = patch_size),
        nn.Conv2d(dim // patch_area, channels, kernel_size=1) # Back to right number of channels
    )

# %% ../12_DM1.ipynb 20
#|code-fold: true
class BasicUNet(nn.Module):
    """A minimal three-level UNet made of 5x5 convolutions.

    The down path applies three convs (to 32, 64, 64 channels) with SiLU,
    halving the resolution with max-pooling after the first two. The up path
    mirrors it (64, 32, ``out_channels``), doubling the resolution before the
    last two convs and adding the matching down-path activation as a skip
    connection. SiLU is applied after every conv, including the last one.

    Because of the two 2x poolings, the input height and width need to be
    divisible by 4 for the skip connections to line up.

    Args:
        in_channels: number of channels in the input image.
        out_channels: number of channels in the output image.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        down_channels = [(in_channels, 32), (32, 64), (64, 64)]
        up_channels = [(64, 64), (64, 32), (32, out_channels)]
        self.down_layers = nn.ModuleList(
            [nn.Conv2d(c_in, c_out, kernel_size=5, padding=2) for c_in, c_out in down_channels]
        )
        self.up_layers = nn.ModuleList(
            [nn.Conv2d(c_in, c_out, kernel_size=5, padding=2) for c_in, c_out in up_channels]
        )
        self.act = nn.SiLU()
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        """Run the down path, then the up path with skip connections."""
        skips = []
        last_down = len(self.down_layers) - 1
        for idx, conv in enumerate(self.down_layers):
            x = self.act(conv(x))
            skips.append(x)  # Stored for the matching up-path layer
            if idx != last_down:
                x = self.downscale(x)  # No pooling after the deepest layer
        for idx, conv in enumerate(self.up_layers):
            if idx:
                x = self.upscale(x)  # The deepest level is already at the right size
            # In-place skip add. On the first pass the popped tensor is `x`
            # itself, so the deepest activation is simply doubled.
            x += skips.pop()
            x = self.act(conv(x))
        return x

# %% ../12_DM1.ipynb 39
#|code-fold: true
class NoiseConditionedUNet(nn.Module):
    """Wraps a BasicUNet but adds an extra input channel for the noise conditioning.

    Args:
        in_channels: number of channels in the image input ``x``.
        out_channels: number of channels in the prediction.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = BasicUNet(in_channels+1, out_channels)

    def forward(self, x, noise_amount):
        """Predict from ``x`` conditioned on ``noise_amount``.

        Args:
            x: image batch of shape ``(batch, channels, height, width)``.
            noise_amount: a Python number applied to the whole batch, or a
                tensor holding one value per sample (a single-element tensor
                is broadcast over the batch). Tensors are cast to the dtype
                and device of ``x``.

        Returns:
            The output of the wrapped network.
        """
        # Shape of x (dim 2 is height, dim 3 is width)
        bs, _, height, width = x.shape

        # Get noise_amount as a single channel 'image' the same spatial shape as x
        if torch.is_tensor(noise_amount):
            # Match x so the concatenation below neither promotes the dtype
            # nor fails on a device mismatch.
            noise_amount = noise_amount.to(device=x.device, dtype=x.dtype)
        else:
            noise_amount = x.new_full((bs,), noise_amount)
        # If x.shape is [8,3,28,28] noise_plane is [8,1,28,28]
        noise_plane = noise_amount.reshape(-1, 1, 1, 1).expand(bs, 1, height, width)

        # Concatenate this onto x to get the final net input:
        net_input = torch.cat((x, noise_plane), 1)

        # Now pass through the net to get the prediction as before
        return self.net(net_input)

# %% ../12_DM1.ipynb 50
#|code-fold: show
class NoiseAndClassConditionedUNet(nn.Module):
    """Wraps a BasicUNet but adds an extra input channel for the noise conditioning and several
    input channels for class conditioning. An nn.Embedding layer maps num_classes to class_emb_channels.

    Args:
        in_channels: number of channels in the image input ``x``.
        out_channels: number of channels in the prediction.
        num_classes: number of distinct class labels.
        class_emb_channels: size of the learned embedding for each class, which
            is also the number of extra input channels it occupies.
    """
    def __init__(self, in_channels, out_channels, num_classes=10, class_emb_channels=4):
        super().__init__()
        self.class_emb = nn.Embedding(num_classes, class_emb_channels) # Map num_classes discrete classes to class_emb_channels numbers
        self.net = BasicUNet(in_channels+1+class_emb_channels, out_channels) # Note input channels = in_channels+1+class_emb_channels

    def forward(self, x, noise_amount, class_labels):
        """Predict from ``x`` conditioned on ``noise_amount`` and ``class_labels``.

        Args:
            x: image batch of shape ``(batch, channels, height, width)``.
            noise_amount: a Python number applied to the whole batch, or a
                tensor holding one value per sample (a single-element tensor
                is broadcast over the batch). Tensors are cast to the dtype
                and device of ``x``.
            class_labels: integer tensor with one label per sample, or a
                Python int / sequence of ints. A single label is broadcast
                over the batch. Labels are moved to the device of ``x``.

        Returns:
            The output of the wrapped network.
        """
        # Shape of x (dim 2 is height, dim 3 is width):
        bs, _, height, width = x.shape

        # Get noise_amount as a single channel 'image' the same spatial shape as x
        if torch.is_tensor(noise_amount):
            # Match x so the concatenation below neither promotes the dtype
            # nor fails on a device mismatch.
            noise_amount = noise_amount.to(device=x.device, dtype=x.dtype)
        else:
            noise_amount = x.new_full((bs,), noise_amount)
        noise_plane = noise_amount.reshape(-1, 1, 1, 1).expand(bs, 1, height, width)

        # And the class cond. Plain ints are accepted too, mirroring noise_amount.
        if not torch.is_tensor(class_labels):
            class_labels = torch.as_tensor(class_labels, dtype=torch.long)
        class_labels = class_labels.to(x.device).reshape(-1)
        class_cond = self.class_emb(class_labels) # Map to embedding dimension: [n, class_emb_channels]
        # Broadcast each embedding value over the full image: [bs, class_emb_channels, height, width]
        class_cond = class_cond.reshape(-1, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], height, width)

        # Net input is now x, noise amount and class cond concatenated together
        net_input = torch.cat((x, noise_plane, class_cond), 1)
        return self.net(net_input)
