#!/usr/bin/env python3


import torch
import torch.nn as nn

from collections import OrderedDict


def get_op(op_name: str, C_in: int, stride: int = 1, affine: bool = True, p: float = None,
           C_out: int = None):
    """Build the candidate operation named ``op_name``.

    Args:
        op_name: one of 'none', 'skip_connect', 'factorized_reduce', 'conv',
            'avg_pool_3x3', 'max_pool_3x3', 'sep_conv_3x3', 'sep_conv_5x5',
            'sep_conv_7x7', 'dil_conv_3x3', 'dil_conv_5x5', 'conv_7x1_1x7'.
        C_in: number of input channels.
        stride: spatial stride of the operation.
        affine: forwarded to every BatchNorm2d that is created.
        p: if not None, an ``nn.Dropout(p)`` is appended (not for 'none' or
            the stride-1 'skip_connect', which return early).
        C_out: number of output channels; defaults to ``C_in``.

    Returns:
        An ``nn.Module`` implementing the operation.

    Raises:
        ValueError: if ``op_name`` is not recognised.
    """
    C_out = C_out if C_out is not None else C_in
    if op_name == 'none':
        return Zero(stride)
    if op_name == 'skip_connect' and stride == 1:
        return nn.Identity()

    # Pooling gets no ReLU/BN, and DilConv already bundles its own ReLU ... BN,
    # so the outer ReLU/BN wrapper is only added for the remaining ops.
    needs_relu_bn = not any(tag in op_name for tag in ('pool', 'sep_conv', 'dil_conv'))
    sep_kernel = {'sep_conv_3x3': 3, 'sep_conv_5x5': 5, 'sep_conv_7x7': 7}

    seq = nn.Sequential()
    if needs_relu_bn:
        seq.add_module('relu', nn.ReLU())

    if op_name == 'conv':
        seq.add_module('conv', nn.Conv2d(C_in, C_out, 1, stride, 0, bias=False))
    elif op_name == 'avg_pool_3x3':
        seq.add_module('pool', nn.AvgPool2d(3, stride, 1, count_include_pad=False))
    elif op_name == 'max_pool_3x3':
        seq.add_module('pool', nn.MaxPool2d(3, stride, 1))
    elif op_name in ('skip_connect', 'factorized_reduce'):
        # NOTE: FactorizedReduce always uses stride 2, independent of `stride`.
        seq.add_module('reduce', FactorizedReduce(C_in, C_out))
    elif op_name in sep_kernel:
        k = sep_kernel[op_name]
        pad = k // 2  # 'same' padding for an odd kernel
        seq.add_module('dil_conv1', DilConv(C_in, C_out, k, stride, pad, dilation=1, affine=affine))
        # The second stage consumes the first stage's C_out channels.
        seq.add_module('dil_conv2', DilConv(C_out, C_out, k, 1, pad, dilation=1, affine=affine))
    elif op_name == 'dil_conv_3x3':
        seq = DilConv(C_in, C_out, 3, stride, 2, dilation=2, affine=affine)
    elif op_name == 'dil_conv_5x5':
        seq = DilConv(C_in, C_out, 5, stride, 4, dilation=2, affine=affine)
    elif op_name == 'conv_7x1_1x7':
        # Pad and stride only along the axis each kernel spans, so the pair
        # downsamples by `stride` once per axis and preserves 'same' size.
        seq.add_module('conv1', nn.Conv2d(C_in, C_out, (7, 1), (stride, 1), (3, 0), bias=False))
        seq.add_module('conv2', nn.Conv2d(C_out, C_out, (1, 7), (1, stride), (0, 3), bias=False))
    else:
        raise ValueError(f"Unknown op_name: {op_name!r}")

    if needs_relu_bn:
        seq.add_module('bn', nn.BatchNorm2d(C_out, affine=affine))
    if p is not None:
        seq.add_module('dropout', nn.Dropout(p=p))
    return seq


def DilConv(C_in: int, C_out: int, kernel_size: int, stride: int, padding: int, dilation: int, affine: bool = True):
    """Return a ReLU -> depthwise (optionally dilated) conv -> 1x1 conv -> BN stack.

    The depthwise conv has ``groups=C_in`` and carries ``kernel_size``,
    ``stride``, ``padding`` and ``dilation``; the 1x1 pointwise conv maps
    C_in -> C_out. Both convs are bias-free. With dilation 2, a 3x3 kernel
    spans a 5x5 window and a 5x5 kernel spans a 9x9 window.
    Submodule names are 'relu', 'conv1', 'conv2', 'bn'.
    """
    depthwise = nn.Conv2d(C_in, C_in, kernel_size, stride, padding,
                          dilation=dilation, groups=C_in, bias=False)
    pointwise = nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False)

    layers = OrderedDict()
    layers['relu'] = nn.ReLU()
    layers['conv1'] = depthwise
    layers['conv2'] = pointwise
    layers['bn'] = nn.BatchNorm2d(C_out, affine=affine)
    return nn.Sequential(layers)


class Zero(nn.Module):
    """Operation whose output is all zeros.

    With stride != 1, the last two dims are first subsampled with step
    ``stride`` so the zero map has the strided spatial size.
    """

    def __init__(self, stride: int = 1):
        super().__init__()
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        step = self.stride
        if step != 1:
            # keep every `step`-th element along dims 2 and 3
            x = x[:, :, ::step, ::step]
        return x.mul(0.)


class FactorizedReduce(nn.Module):
    """Halve the spatial size with two offset stride-2 1x1 convolutions.

    One bias-free 1x1 stride-2 conv sees the input directly. A second conv
    sees the input shifted by one pixel down and right. Their outputs are
    concatenated on dim 1, giving ``C_out // 2 + (C_out - C_out // 2) == C_out``
    channels. No ReLU/BN is applied here.
    """

    def __init__(self, C_in: int, C_out: int):
        super().__init__()
        C_out_1 = C_out // 2
        self.conv1 = nn.Conv2d(C_in, C_out_1, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out - C_out_1, 1, stride=2, padding=0, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Zero-pad one row/column at the bottom/right before shifting, so the
        # shifted branch keeps x's spatial size. Without this, an odd H or W
        # gives ceil(size/2) vs floor(size/2) outputs and torch.cat fails.
        # For even sizes the padded element is never sampled by the stride-2
        # conv, so the result is identical to the unpadded shift.
        shifted = nn.functional.pad(x, (0, 1, 0, 1))[:, :, 1:, 1:]
        return torch.cat([self.conv1(x), self.conv2(shifted)], dim=1)
