import torch
import torch.nn as nn
import torch.nn.functional as F
from . import mobilenetv2, resnet
from .utils import IntermediateLayerGetter, _SimpleSegmentationModel


class DeepLabV3Impl(_SimpleSegmentationModel):
  """DeepLabV3 segmentation model.

  Described in `"Rethinking Atrous Convolution for Semantic Image
  Segmentation" <https://arxiv.org/abs/1706.05587>`_. This class adds no
  behaviour of its own; everything is inherited from
  ``_SimpleSegmentationModel``.

  Arguments:
      backbone (nn.Module): feature extractor. It is expected to return an
          OrderedDict[Tensor] whose "out" entry is the last feature map, plus
          an "aux" entry when an auxiliary classifier is used.
      classifier (nn.Module): consumes the "out" entry produced by the
          backbone and yields a dense prediction.
      aux_classifier (nn.Module, optional): auxiliary classifier used during
          training.
  """


class DeepLabHeadV3Plus(nn.Module):
  """DeepLabV3+ head: fuses ASPP output with a projected low-level feature.

  Arguments:
      in_channels (int): channel count of ``feature['out']`` (the ASPP input).
      low_level_channels (int): channel count of ``feature['low_level']``.
      num_classes (int): output channels of the final 1x1 convolution.
      aspp_dilate (sequence of int): dilation rates forwarded to ``ASPP``.
  """

  def __init__(self,
               in_channels,
               low_level_channels,
               num_classes,
               aspp_dilate=(12, 24, 36)):
    # The default is an immutable tuple so it cannot be shared and mutated
    # across instances; ``ASPP`` accepts any iterable of rates.
    super(DeepLabHeadV3Plus, self).__init__()
    low_level_proj_channels = 48
    # ``ASPP`` in this file always projects to 256 output channels.
    aspp_out_channels = 256

    self.project = nn.Sequential(
        nn.Conv2d(low_level_channels, low_level_proj_channels, 1, bias=False),
        nn.BatchNorm2d(low_level_proj_channels),
        nn.ReLU(inplace=True),
    )

    self.aspp = ASPP(in_channels, aspp_dilate)

    # The classifier consumes the channel-wise concatenation of both branches
    # (48 + 256 = 304 input channels).
    self.classifier = nn.Sequential(
        nn.Conv2d(low_level_proj_channels + aspp_out_channels,
                  256,
                  3,
                  padding=1,
                  bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, 1))
    self._init_weight()

  def forward(self, feature):
    """Compute the prediction from a feature mapping.

    Arguments:
        feature: mapping holding the tensors ``'low_level'`` and ``'out'``.

    Returns:
        The classifier output, at the spatial size of the low-level feature.
    """
    low_level_feature = self.project(feature['low_level'])
    output_feature = self.aspp(feature['out'])
    # Bring the ASPP output up to the low-level resolution before fusing.
    output_feature = F.interpolate(output_feature,
                                   size=low_level_feature.shape[2:],
                                   mode='bilinear',
                                   align_corners=False)
    return self.classifier(torch.cat([low_level_feature, output_feature],
                                     dim=1))

  def _init_weight(self):
    """Re-initialise every submodule, including those inside ``self.aspp``.

    Conv2d weights get Kaiming-normal init; BatchNorm2d / GroupNorm get
    weight 1 and bias 0.
    """
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
      elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


class DeepLabHead(nn.Module):
  """DeepLabV3 head: ASPP followed by a 3x3 conv block and a 1x1 classifier.

  Arguments:
      in_channels (int): channel count of ``feature['out']`` (the ASPP input).
      num_classes (int): output channels of the final 1x1 convolution.
      aspp_dilate (sequence of int): dilation rates forwarded to ``ASPP``.
  """

  def __init__(self, in_channels, num_classes, aspp_dilate=(12, 24, 36)):
    # The default is an immutable tuple so it cannot be shared and mutated
    # across instances; ``ASPP`` accepts any iterable of rates.
    super(DeepLabHead, self).__init__()
    # ``ASPP`` in this file always projects to 256 output channels.
    aspp_out_channels = 256

    self.classifier = nn.Sequential(
        ASPP(in_channels, aspp_dilate),
        nn.Conv2d(aspp_out_channels, 256, 3, padding=1, bias=False),
        nn.BatchNorm2d(256), nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, 1))
    self._init_weight()

  def forward(self, feature):
    """Run the classifier on ``feature['out']`` and return its output."""
    return self.classifier(feature['out'])

  def _init_weight(self):
    """Re-initialise every submodule, including those inside the ASPP block.

    Conv2d weights get Kaiming-normal init; BatchNorm2d / GroupNorm get
    weight 1 and bias 0.
    """
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
      elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


class AtrousSeparableConvolution(nn.Module):
  """Depthwise (optionally dilated) convolution followed by a 1x1 pointwise one.

  Arguments:
      in_channels (int): input channels; also the group count of the
          depthwise convolution.
      out_channels (int): channels produced by the pointwise convolution.
      kernel_size: kernel size of the depthwise convolution.
      stride, padding, dilation: forwarded to the depthwise convolution.
      bias (bool): whether both convolutions carry a bias term.
  """

  def __init__(self,
               in_channels,
               out_channels,
               kernel_size,
               stride=1,
               padding=0,
               dilation=1,
               bias=True):
    super().__init__()

    # One filter per input channel (groups == in_channels).
    depthwise = nn.Conv2d(in_channels,
                          in_channels,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding,
                          dilation=dilation,
                          bias=bias,
                          groups=in_channels)
    # 1x1 convolution that mixes channels.
    pointwise = nn.Conv2d(in_channels,
                          out_channels,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=bias)
    self.body = nn.Sequential(depthwise, pointwise)

    self._init_weight()

  def forward(self, x):
    """Apply the depthwise then the pointwise convolution to ``x``."""
    return self.body(x)

  def _init_weight(self):
    """Kaiming-normal init for convs; weight 1 / bias 0 for norm layers."""
    norm_types = (nn.BatchNorm2d, nn.GroupNorm)
    for layer in self.modules():
      if isinstance(layer, nn.Conv2d):
        nn.init.kaiming_normal_(layer.weight)
        continue
      if isinstance(layer, norm_types):
        nn.init.constant_(layer.weight, 1)
        nn.init.constant_(layer.bias, 0)


class ASPPConv(nn.Sequential):
  """ASPP branch: 3x3 dilated conv (no bias) -> BatchNorm2d -> ReLU.

  ``padding`` is set equal to ``dilation`` so the 3x3 convolution keeps the
  spatial size of its input.
  """

  def __init__(self, in_channels, out_channels, dilation):
    atrous_conv = nn.Conv2d(in_channels,
                            out_channels,
                            3,
                            padding=dilation,
                            dilation=dilation,
                            bias=False)
    super().__init__(atrous_conv, nn.BatchNorm2d(out_channels),
                     nn.ReLU(inplace=True))


class ASPPPooling(nn.Sequential):
  """ASPP image-level branch.

  Global average pool -> 1x1 conv (no bias) -> BatchNorm2d -> ReLU, then the
  1x1 result is bilinearly resized back to the input's spatial size.
  """

  def __init__(self, in_channels, out_channels):
    layers = (
        nn.AdaptiveAvgPool2d(1),
        nn.Conv2d(in_channels, out_channels, 1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
    super().__init__(*layers)

  def forward(self, x):
    """Return the pooled features resized to the last two dims of ``x``."""
    target_size = x.shape[-2:]
    pooled = super().forward(x)
    return F.interpolate(pooled,
                         size=target_size,
                         mode='bilinear',
                         align_corners=False)


class ASPP(nn.Module):
  """Atrous Spatial Pyramid Pooling.

  Runs parallel branches on the same input: a 1x1 conv branch, one
  ``ASPPConv`` per entry of ``atrous_rates``, and an ``ASPPPooling`` branch.
  The outputs are concatenated along the channel dimension and projected to
  256 channels.

  Arguments:
      in_channels (int): channel count of the input tensor.
      atrous_rates (iterable of int): dilation rate for each ``ASPPConv``
          branch. Any number of rates is accepted; three rates give the
          classic five-branch layout.
  """

  def __init__(self, in_channels, atrous_rates):
    super(ASPP, self).__init__()
    out_channels = 256

    # Branch order matters: it fixes the ``convs.<i>`` state_dict keys.
    branches = [
        nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
                      nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))
    ]
    branches.extend(
        ASPPConv(in_channels, out_channels, rate) for rate in atrous_rates)
    branches.append(ASPPPooling(in_channels, out_channels))

    self.convs = nn.ModuleList(branches)

    # The projection input width follows the real branch count rather than a
    # hard-coded 5, so it stays correct for any number of rates.
    self.project = nn.Sequential(
        nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Dropout(0.1),
    )

  def forward(self, x):
    """Apply every branch to ``x``, concatenate on dim 1, and project."""
    branch_outputs = [conv(x) for conv in self.convs]
    return self.project(torch.cat(branch_outputs, dim=1))


def convert_to_separable_conv(module):
  """Recursively replace spatial ``nn.Conv2d`` layers with separable ones.

  Every ``nn.Conv2d`` whose first kernel dimension is larger than 1 is
  swapped for an ``AtrousSeparableConvolution`` built from the same
  channel / kernel / stride / padding / dilation settings. Only these
  hyper-parameters are carried over; the original weights are not copied.
  All other modules are kept and their children are converted in place.

  Arguments:
      module (nn.Module): root of the module tree to convert.

  Returns:
      nn.Module: ``module`` itself, or its replacement when ``module`` was a
      convertible convolution.
  """
  new_module = module
  if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1:
    # ``module.bias`` is a Parameter or None, but the constructor expects a
    # bool: truth-testing a multi-element tensor raises, so pass presence.
    has_bias = module.bias is not None
    new_module = AtrousSeparableConvolution(module.in_channels,
                                            module.out_channels,
                                            module.kernel_size, module.stride,
                                            module.padding, module.dilation,
                                            has_bias)
  for name, child in module.named_children():
    new_module.add_module(name, convert_to_separable_conv(child))
  return new_module


def _segm_resnet(name, backbone_name, num_classes, output_stride,
                 pretrained_backbone):
  """Build a DeepLabV3 / DeepLabV3+ model on a ResNet backbone.

  Args:
      name (str): ``'deeplabv3'`` or ``'deeplabv3plus'``.
      backbone_name (str): key of the constructor looked up in the ``resnet``
          module.
      num_classes (int): number of classes predicted by the head.
      output_stride (int): 8 selects dilation in the last two stages with ASPP
          rates [12, 24, 36]; any other value selects dilation in the last
          stage only with ASPP rates [6, 12, 18].
      pretrained_backbone (bool): forwarded as ``pretrained`` to the backbone
          constructor.

  Returns:
      DeepLabV3Impl: backbone wrapped in ``IntermediateLayerGetter`` plus head.

  Raises:
      NotImplementedError: if ``name`` is not a supported architecture.
  """
  # Reject unknown architectures before the backbone is built. Without this
  # check an unknown name only failed later with an UnboundLocalError.
  # NotImplementedError matches what `_load_model` raises for unknown
  # backbones.
  if name not in ('deeplabv3plus', 'deeplabv3'):
    raise NotImplementedError(
        'unsupported architecture %r; expected "deeplabv3" or "deeplabv3plus"'
        % (name,))

  if output_stride == 8:
    replace_stride_with_dilation = [False, True, True]
    aspp_dilate = [12, 24, 36]
  else:
    replace_stride_with_dilation = [False, False, True]
    aspp_dilate = [6, 12, 18]

  backbone = resnet.__dict__[backbone_name](
      pretrained=pretrained_backbone,
      replace_stride_with_dilation=replace_stride_with_dilation)

  # NOTE(review): these widths presumably match layer4 / layer1 of bottleneck
  # ResNets (e.g. resnet50/101) -- confirm before using other variants.
  inplanes = 2048
  low_level_planes = 256

  if name == 'deeplabv3plus':
    return_layers = {'layer4': 'out', 'layer1': 'low_level'}
    classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes,
                                   aspp_dilate)
  else:  # name == 'deeplabv3'
    return_layers = {'layer4': 'out'}
    classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
  backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

  model = DeepLabV3Impl(backbone, classifier)
  return model


def _segm_mobilenet(name, backbone_name, num_classes, output_stride,
                    pretrained_backbone):
  """Build a DeepLabV3 / DeepLabV3+ model on a MobileNetV2 backbone.

  Args:
      name (str): ``'deeplabv3'`` or ``'deeplabv3plus'``.
      backbone_name (str): not used by this builder; MobileNetV2 is always
          constructed.
      num_classes (int): number of classes predicted by the head.
      output_stride (int): forwarded to the backbone; 8 selects ASPP rates
          [12, 24, 36], any other value selects [6, 12, 18].
      pretrained_backbone (bool): forwarded as ``pretrained`` to the backbone
          constructor.

  Returns:
      DeepLabV3Impl: backbone wrapped in ``IntermediateLayerGetter`` plus head.

  Raises:
      NotImplementedError: if ``name`` is not a supported architecture.
  """
  # Reject unknown architectures before the backbone is built. Without this
  # check an unknown name only failed later with an UnboundLocalError.
  # NotImplementedError matches what `_load_model` raises for unknown
  # backbones.
  if name not in ('deeplabv3plus', 'deeplabv3'):
    raise NotImplementedError(
        'unsupported architecture %r; expected "deeplabv3" or "deeplabv3plus"'
        % (name,))

  if output_stride == 8:
    aspp_dilate = [12, 24, 36]
  else:
    aspp_dilate = [6, 12, 18]

  backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone,
                                      output_stride=output_stride)

  # Split the feature extractor into low- and high-level stages so that
  # IntermediateLayerGetter can tap each one, then drop the original
  # `features` and `classifier` attributes.
  backbone.low_level_features = backbone.features[0:4]
  backbone.high_level_features = backbone.features[4:-1]
  backbone.features = None
  backbone.classifier = None

  # NOTE(review): presumably the channel widths of the two slices above --
  # confirm against the mobilenetv2 module.
  inplanes = 320
  low_level_planes = 24

  if name == 'deeplabv3plus':
    return_layers = {
        'high_level_features': 'out',
        'low_level_features': 'low_level'
    }
    classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes,
                                   aspp_dilate)
  else:  # name == 'deeplabv3'
    return_layers = {'high_level_features': 'out'}
    classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
  backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

  model = DeepLabV3Impl(backbone, classifier)
  return model


def _load_model(arch_type, backbone, num_classes, output_stride,
                pretrained_backbone):
  """Dispatch model construction to the builder matching ``backbone``.

  ``'mobilenetv2'`` selects ``_segm_mobilenet``; any name starting with
  ``'resnet'`` selects ``_segm_resnet``. Every other backbone raises
  ``NotImplementedError``.
  """
  if backbone == 'mobilenetv2':
    build = _segm_mobilenet
  elif backbone.startswith('resnet'):
    build = _segm_resnet
  else:
    raise NotImplementedError

  return build(arch_type,
               backbone,
               num_classes,
               output_stride=output_stride,
               pretrained_backbone=pretrained_backbone)


# Deeplab v3


def deeplabv3_resnet50(num_classes=21,
                       output_stride=8,
                       pretrained_backbone=True):
  """Build a DeepLabV3 model on top of a ResNet-50 backbone.

    Args:
        num_classes (int): how many classes the model predicts.
        output_stride (int): output stride used by deeplab.
        pretrained_backbone (bool): when True, the pretrained backbone is used.
    """
  options = dict(output_stride=output_stride,
                 pretrained_backbone=pretrained_backbone)
  return _load_model('deeplabv3', 'resnet50', num_classes, **options)


def deeplabv3_resnet101(num_classes=21,
                        output_stride=8,
                        pretrained_backbone=True):
  """Build a DeepLabV3 model on top of a ResNet-101 backbone.

    Args:
        num_classes (int): how many classes the model predicts.
        output_stride (int): output stride used by deeplab.
        pretrained_backbone (bool): when True, the pretrained backbone is used.
    """
  options = dict(output_stride=output_stride,
                 pretrained_backbone=pretrained_backbone)
  return _load_model('deeplabv3', 'resnet101', num_classes, **options)


def deeplabv3_mobilenet(num_classes=21,
                        output_stride=8,
                        pretrained_backbone=True):
  """Build a DeepLabV3 model on top of a MobileNetv2 backbone.

    Args:
        num_classes (int): how many classes the model predicts.
        output_stride (int): output stride used by deeplab.
        pretrained_backbone (bool): when True, the pretrained backbone is used.
    """
  options = dict(output_stride=output_stride,
                 pretrained_backbone=pretrained_backbone)
  return _load_model('deeplabv3', 'mobilenetv2', num_classes, **options)


# Deeplab v3+


def deeplabv3plus_resnet50(num_classes=21,
                           output_stride=8,
                           pretrained_backbone=True):
  """Build a DeepLabV3+ model on top of a ResNet-50 backbone.

    Args:
        num_classes (int): how many classes the model predicts.
        output_stride (int): output stride used by deeplab.
        pretrained_backbone (bool): when True, the pretrained backbone is used.
    """
  options = dict(output_stride=output_stride,
                 pretrained_backbone=pretrained_backbone)
  return _load_model('deeplabv3plus', 'resnet50', num_classes, **options)


def deeplabv3plus_resnet101(num_classes=21,
                            output_stride=8,
                            pretrained_backbone=True):
  """Build a DeepLabV3+ model on top of a ResNet-101 backbone.

    Args:
        num_classes (int): how many classes the model predicts.
        output_stride (int): output stride used by deeplab.
        pretrained_backbone (bool): when True, the pretrained backbone is used.
    """
  options = dict(output_stride=output_stride,
                 pretrained_backbone=pretrained_backbone)
  return _load_model('deeplabv3plus', 'resnet101', num_classes, **options)


def deeplabv3plus_mobilenet(num_classes=21,
                            output_stride=8,
                            pretrained_backbone=True):
  """Build a DeepLabV3+ model on top of a MobileNetv2 backbone.

    Args:
        num_classes (int): how many classes the model predicts.
        output_stride (int): output stride used by deeplab.
        pretrained_backbone (bool): when True, the pretrained backbone is used.
    """
  options = dict(output_stride=output_stride,
                 pretrained_backbone=pretrained_backbone)
  return _load_model('deeplabv3plus', 'mobilenetv2', num_classes, **options)
