# https://pytorch.org/vision/main/_modules/torchvision/models/efficientnet.html
# lightly modified from standard pytorch implementation

import copy
import math
from functools import partial
from typing import Callable, Optional, List, Sequence

import torch
from torch import nn, Tensor

from torchvision._internally_replaced_utils import load_state_dict_from_url
from torchvision.ops.misc import ConvNormActivation

from torchvision.models.efficientnet import MBConvConfig, MBConv


class EfficientNet(nn.Module):  # could make lightning, but I think it's clearer to do that one level up
    """
    EfficientNet feature extractor with an optional classification head.

    The forward pass runs ``features`` -> ``avgpool`` -> flatten -> ``head``.
    """

    def __init__(
        self,
        inverted_residual_setting: List[MBConvConfig],
        dropout: float,
        include_top: bool,
        input_channels: int = 3,
        stochastic_depth_prob: float = 0.2,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        Assemble the stem convolution, the MBConv stages, the final 1x1 convolution and the head.

        Args:
            inverted_residual_setting (List[MBConvConfig]): Network structure, one config per stage
            dropout (float): Dropout probability of the head. Only used when ``include_top`` is True
            include_top (bool): If True, the head is Dropout + Linear(num_classes).
                If False, the head is an identity and the pooled, flattened features are returned
            input_channels (int): Number of channels of the input to the first convolution
            stochastic_depth_prob (float): Maximum stochastic depth probability; each block gets
                this value scaled by its index over the total number of blocks
            num_classes (int): Output size of the linear layer in the head
            block (Optional[Callable[..., nn.Module]]): Building block used within the stages (defaults to MBConv)
            norm_layer (Optional[Callable[..., nn.Module]]): Normalization layer to use (defaults to nn.BatchNorm2d)

        Raises:
            ValueError: if ``inverted_residual_setting`` is empty
            TypeError: if ``inverted_residual_setting`` is not a sequence of MBConvConfig
        """
        super().__init__()
        # _log_api_usage_once(self)

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        setting_is_valid = isinstance(inverted_residual_setting, Sequence) and all(
            isinstance(stage_cnf, MBConvConfig) for stage_cnf in inverted_residual_setting
        )
        if not setting_is_valid:
            raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")

        block = MBConv if block is None else block
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer

        # stem: input_channels is a constructor argument here rather than a hard-coded value
        stem_out_channels = inverted_residual_setting[0].input_channels
        stem = ConvNormActivation(
            input_channels,
            stem_out_channels,
            kernel_size=3,
            stride=2,
            norm_layer=norm_layer,
            activation_layer=nn.SiLU,
        )
        layers: List[nn.Module] = [stem]

        # one nn.Sequential per stage, each holding cnf.num_layers blocks
        n_blocks_total = sum(stage_cnf.num_layers for stage_cnf in inverted_residual_setting)
        n_blocks_built = 0
        for stage_cnf in inverted_residual_setting:
            stage_blocks: List[nn.Module] = []
            for position_in_stage in range(stage_cnf.num_layers):
                # shallow copy so the caller's config is not modified
                block_cnf = copy.copy(stage_cnf)

                # only the first block of a stage changes channel count and applies the stride
                if position_in_stage > 0:
                    block_cnf.input_channels = block_cnf.out_channels
                    block_cnf.stride = 1

                # stochastic depth probability grows linearly with the block's overall index
                sd_prob = stochastic_depth_prob * float(n_blocks_built) / n_blocks_total

                stage_blocks.append(block(block_cnf, sd_prob, norm_layer))
                n_blocks_built += 1

            layers.append(nn.Sequential(*stage_blocks))

        # final 1x1 convolution widens the last stage's output by a factor of 4
        final_in_channels = inverted_residual_setting[-1].out_channels
        final_out_channels = 4 * final_in_channels
        layers.append(
            ConvNormActivation(
                final_in_channels,
                final_out_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.SiLU,
            )
        )

        self.features = nn.Sequential(*layers)
        # pytorch version includes the pooling outside of include_top
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        if not include_top:
            self.head = nn.Identity()  # does nothing
        else:  # should never be true for Zoobot
            self.head = nn.Sequential(
                # TODO replace with PermaDropout via self.training
                nn.Dropout(p=dropout, inplace=True),
                nn.Linear(final_out_channels, num_classes),
            )

        self._initialise_parameters()

    def _initialise_parameters(self) -> None:
        """Re-initialise conv, norm and linear parameters, visiting modules in self.modules() order."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                bound = 1.0 / math.sqrt(module.out_features)
                nn.init.uniform_(module.weight, -bound, bound)
                nn.init.zeros_(module.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run features, global average pooling, flatten from dim 1, then the head."""
        feature_maps = self.features(x)
        pooled = torch.flatten(self.avgpool(feature_maps), 1)
        return self.head(pooled)

    def forward(self, x: Tensor) -> Tensor:
        """Delegate to ``_forward_impl``."""
        return self._forward_impl(x)


def _efficientnet(
    arch: str,
    width_mult: float,
    depth_mult: float,
    dropout: float,
    use_imagenet_weights: bool,
    include_top: bool,
    input_channels: int,
    stochastic_depth_prob: float,
    progress: bool
) -> EfficientNet:
    """
    Build an EfficientNet variant and optionally load ImageNet weights into it.

    Args:
        arch (str): Key into ``model_urls`` identifying the checkpoint, e.g. "efficientnet_b0"
        width_mult (float): Width multiplier forwarded to every MBConvConfig
        depth_mult (float): Depth multiplier forwarded to every MBConvConfig
        dropout (float): Dropout probability of the head
        use_imagenet_weights (bool): If True, download the checkpoint for ``arch`` and load it
        include_top (bool): Whether the model gets a classification head. Must be True when
            ``use_imagenet_weights`` is True
        input_channels (int): Number of channels of the input to the first convolution
        stochastic_depth_prob (float): Maximum stochastic depth probability
        progress (bool): If True, displays a progress bar of the download

    Returns:
        EfficientNet: the constructed (and optionally pre-loaded) model

    Raises:
        AssertionError: if ``use_imagenet_weights`` is True but ``include_top`` is False
        ValueError: if no checkpoint URL is registered for ``arch``
    """
    # Validate the weight request first, so an invalid request fails before the model is built.
    checkpoint_url = None
    if use_imagenet_weights:
        # The checkpoint contains head weights; without a head the strict load below would
        # reject them anyway, so fail early with a clearer message.
        assert include_top, "use_imagenet_weights=True requires include_top=True"
        checkpoint_url = model_urls.get(arch, None)
        if checkpoint_url is None:
            raise ValueError(f"No checkpoint is available for model type {arch}")

    bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult)
    # ratio, kernel, stride, input channels, output channels, layers, width/depth kwargs
    inverted_residual_setting = [
        bneck_conf(1, 3, 1, 32, 16, 1),
        bneck_conf(6, 3, 2, 16, 24, 2),
        bneck_conf(6, 5, 2, 24, 40, 2),
        bneck_conf(6, 3, 2, 40, 80, 3),
        bneck_conf(6, 5, 1, 80, 112, 3),
        bneck_conf(6, 5, 2, 112, 192, 4),
        bneck_conf(6, 3, 1, 192, 320, 1),
    ]
    model = EfficientNet(
        inverted_residual_setting,
        dropout=dropout,
        include_top=include_top,
        input_channels=input_channels,
        stochastic_depth_prob=stochastic_depth_prob,
    )

    if checkpoint_url is not None:
        state_dict = load_state_dict_from_url(checkpoint_url, progress=progress)
        # The torchvision checkpoints store the classifier under "classifier.*", while the
        # model in this file registers it as "head". Rename those keys so the strict load
        # below does not fail on missing/unexpected keys. Keys without the prefix are untouched.
        # NOTE(review): the checkpoints presumably expect 3 input channels; any other
        # input_channels is expected to fail on a shape mismatch - confirm.
        upstream_prefix, local_prefix = "classifier.", "head."
        for key in list(state_dict):
            if key.startswith(upstream_prefix):
                state_dict[local_prefix + key[len(upstream_prefix):]] = state_dict.pop(key)
        model.load_state_dict(state_dict)
    return model


def efficientnet_b0(
    input_channels,
    stochastic_depth_prob: float = 0.2,
    use_imagenet_weights: bool = False,
    include_top: bool = True,
    progress: bool = True) -> EfficientNet:
    """
    Constructs a EfficientNet B0 architecture from
    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.

    Uses width and depth multipliers of 1.0 and a head dropout of 0.2.

    Args:
        input_channels: Number of channels of the input to the first convolution
        stochastic_depth_prob (float): The stochastic depth probability passed to the model
        use_imagenet_weights (bool): If True, returns a model pre-trained on ImageNet
        include_top (bool): If True, the model includes the classification head
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # added include_top and input_channels, renamed pretrained to use_imagenet_weights
    b0_scaling = dict(
        arch="efficientnet_b0",
        width_mult=1.0,
        depth_mult=1.0,
        dropout=0.2,
    )
    return _efficientnet(
        input_channels=input_channels,
        include_top=include_top,
        stochastic_depth_prob=stochastic_depth_prob,
        use_imagenet_weights=use_imagenet_weights,
        progress=progress,
        **b0_scaling)


def efficientnet_b2(
    input_channels,
    stochastic_depth_prob: float = 0.2,
    use_imagenet_weights: bool = False,
    include_top: bool = True,
    progress: bool = True) -> EfficientNet:
    """
    Constructs a EfficientNet B2 architecture: width multiplier 1.1, depth multiplier 1.2, head dropout 0.3.

    Takes the same arguments as efficientnet_b0, except that use_imagenet_weights
    must be False: the assertion below rejects any truthy value.
    """
    # added include_top and input_channels, renamed pretrained to use_imagenet_weights
    assert not use_imagenet_weights
    b2_scaling = dict(
        arch="efficientnet_b2",
        width_mult=1.1,
        depth_mult=1.2,
        dropout=0.3,
    )
    return _efficientnet(
        input_channels=input_channels,
        include_top=include_top,
        # I added as an arg, for extra hparam to search (optionally - default=0.2)
        stochastic_depth_prob=stochastic_depth_prob,
        use_imagenet_weights=use_imagenet_weights,  # will likely fail
        progress=progress,
        **b2_scaling)


def efficientnet_b4(
    input_channels,
    stochastic_depth_prob: float = 0.2,
    use_imagenet_weights: bool = False,
    include_top: bool = True,
    progress: bool = True) -> EfficientNet:
    """
    Constructs a EfficientNet B4 architecture: width multiplier 1.4, depth multiplier 1.8, head dropout 0.4.

    Takes the same arguments as efficientnet_b0, except that use_imagenet_weights
    must be False: the assertion below rejects any truthy value.
    """
    # added include_top and input_channels, renamed pretrained to use_imagenet_weights
    assert not use_imagenet_weights
    b4_scaling = dict(
        arch="efficientnet_b4",
        width_mult=1.4,
        depth_mult=1.8,
        dropout=0.4,
    )
    return _efficientnet(
        input_channels=input_channels,
        include_top=include_top,
        stochastic_depth_prob=stochastic_depth_prob,
        use_imagenet_weights=use_imagenet_weights,  # will likely fail
        progress=progress,
        **b4_scaling)


# TODO efficientnet_v2_*, perhaps?

# Checkpoint URLs keyed by architecture name. Each URL has the form
# https://download.pytorch.org/models/efficientnet_<variant>_<source>-<digest>.pth
model_urls = {
    f"efficientnet_{variant}": (
        f"https://download.pytorch.org/models/efficientnet_{variant}_{source}-{digest}.pth"
    )
    for variant, source, digest in (
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        ("b0", "rwightman", "3dd342df"),
        ("b1", "rwightman", "533bc792"),
        ("b2", "rwightman", "bcdf34b7"),
        ("b3", "rwightman", "cf984f9c"),
        ("b4", "rwightman", "7eb33cd5"),
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        ("b5", "lukemelas", "b6417697"),
        ("b6", "lukemelas", "c76e70fd"),
        ("b7", "lukemelas", "dcc49843"),
    )
}
