"""VGG for CIFAR10/100

Reference:
Simonyan, Karen, and Andrew Zisserman. Very Deep Convolutional Networks for Large-Scale Image Recognition (ICLR '15)
"""

import torch.nn as nn

# Layer layouts consumed by ``make_layers``: an integer is the output channel
# count of a 3x3 (padding 1) convolution, and "M" is a 2x2, stride-2 max-pool.
# The key letters are the ones used by the factories at the bottom of this
# file: A -> vgg11, B -> vgg13, D -> vgg16, E -> vgg19 (conv layers counted
# per entry; the classifier adds three linear layers on top).
cfg = dict(
    # 8 conv layers
    A=[
        64, "M",
        128, "M",
        256, 256, "M",
        512, 512, "M",
        512, 512, "M",
    ],
    # 10 conv layers
    B=[
        64, 64, "M",
        128, 128, "M",
        256, 256, "M",
        512, 512, "M",
        512, 512, "M",
    ],
    # 13 conv layers
    D=[
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, "M",
        512, 512, 512, "M",
        512, 512, 512, "M",
    ],
    # 16 conv layers
    E=[
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, 256, "M",
        512, 512, 512, 512, "M",
        512, 512, 512, 512, "M",
    ],
)


class VGG(nn.Module):
    """VGG-style classifier: a caller-supplied trunk followed by an MLP head.

    Args:
        dataset: ``"cifar10"`` or ``"cifar100"``; selects 10 or 100 output
            classes respectively.
        features: Module applied to the input first. Its output, flattened per
            sample, must contain exactly 512 values because the head starts
            with ``nn.Linear(512, 4096)``. A ``make_layers`` trunk built from
            any of this file's configs (five 2x2 pools, last stage 512
            channels) satisfies this for 32x32 inputs; other input sizes are
            not accounted for here.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """

    # Supported dataset names -> number of output classes.
    _NUM_CLASSES = {"cifar10": 10, "cifar100": 100}

    def __init__(self, dataset, features):
        super().__init__()
        self.features = features

        try:
            num_classes = self._NUM_CLASSES[dataset]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs, which the original equality
            # checks also rejected with ValueError.
            raise ValueError(
                f"Incorrect Dataset Input. Got {dataset!r}; "
                f"expected one of {sorted(self._NUM_CLASSES)}."
            ) from None

        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),  # PyTorch default p=0.5
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Run the trunk, flatten each sample, and return class logits."""
        output = self.features(x)
        # flatten(1) keeps the batch dimension. Unlike view(), it also accepts
        # non-contiguous trunk outputs (e.g. channels_last), copying if needed.
        output = output.flatten(1)
        return self.classifier(output)


def make_layers(cfg, batch_norm=False, *, in_channels=3):
    """Build a VGG convolutional trunk from a layer spec.

    Args:
        cfg: Sequence of layer specs. An integer ``n`` appends a 3x3,
            padding-1 ``nn.Conv2d`` with ``n`` output channels, then
            ``nn.BatchNorm2d(n)`` if ``batch_norm`` is true, then an in-place
            ``nn.ReLU``. The string ``"M"`` appends a 2x2, stride-2
            ``nn.MaxPool2d``.
        batch_norm: Insert ``nn.BatchNorm2d`` after every convolution.
        in_channels: Channel count of the network input (default 3, e.g. RGB
            images).

    Returns:
        ``nn.Sequential`` holding the layers in ``cfg`` order.
    """
    layers = []
    channels = in_channels
    for spec in cfg:
        if spec == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue

        # padding=1 with a 3x3 kernel preserves spatial size; only pools shrink it.
        layers.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(spec))
        layers.append(nn.ReLU(inplace=True))
        # The next convolution consumes this one's output channels.
        channels = spec

    return nn.Sequential(*layers)


def vgg11(Dataset):
    """Return a batch-normalised VGG-11 (layer config ``"A"``).

    ``Dataset`` is passed straight to ``VGG`` ("cifar10" or "cifar100").
    """
    trunk = make_layers(cfg["A"], batch_norm=True)
    return VGG(Dataset, trunk)


def vgg13(Dataset):
    """Return a batch-normalised VGG-13 (layer config ``"B"``).

    ``Dataset`` is passed straight to ``VGG`` ("cifar10" or "cifar100").
    """
    trunk = make_layers(cfg["B"], batch_norm=True)
    return VGG(Dataset, trunk)


def vgg16(Dataset):
    """Return a batch-normalised VGG-16 (layer config ``"D"``).

    ``Dataset`` is passed straight to ``VGG`` ("cifar10" or "cifar100").
    """
    trunk = make_layers(cfg["D"], batch_norm=True)
    return VGG(Dataset, trunk)


def vgg19(Dataset):
    """Return a batch-normalised VGG-19 (layer config ``"E"``).

    ``Dataset`` is passed straight to ``VGG`` ("cifar10" or "cifar100").
    """
    trunk = make_layers(cfg["E"], batch_norm=True)
    return VGG(Dataset, trunk)
