# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/102b_models.InceptionTimePlus.ipynb (unless otherwise specified).

# Public names exported by a star-import of this module.
__all__ = ['InceptionModulePlus', 'InceptionBlockPlus', 'InceptionTimePlus', 'InCoordTime', 'XCoordTime',
           'InceptionTimePlus17x17', 'InceptionTimePlus32x32', 'InceptionTimePlus47x47', 'InceptionTimePlus62x62',
           'MultiInceptionTimePlus']

# Cell
from ..imports import *
from ..utils import *
from .layers import *
from .utils import *
# Import-time side effect: sets PyTorch's CPU thread count to `cpus`.
# NOTE(review): `cpus` is not defined in this file; it presumably comes from the star imports above — confirm.
torch.set_num_threads(cpus)

# Cell
# This is an unofficial PyTorch implementation by Ignacio Oguiza - oguiza@gmail.com modified from:

# Fawaz, H. I., Lucas, B., Forestier, G., Pelletier, C., Schmidt, D. F., Weber, J., ... & Petitjean, F. (2019).
# InceptionTime: Finding AlexNet for Time Series Classification. arXiv preprint arXiv:1909.04939.
# Official InceptionTime tensorflow implementation: https://github.com/hfawaz/InceptionTime


class InceptionModulePlus(Module):
    """One inception module: parallel convolutions plus a max-pool branch, merged by `Concat`.

    Args:
        ni: number of input channels.
        nf: number of filters produced by each branch.
        ks: kernel size(s). An integer is expanded to `[ks, ks // 2, ks // 4]`; a sequence is used
            as given. Even sizes are reduced by one so that every kernel is odd.
            The norm / attention / squeeze-excite layers are sized for `nf * 4` channels, i.e.
            three convolution branches plus the max-pool branch.
        bottleneck: apply a kernel-size-1 `Conv` before the convolution branches. Forced to
            False when `ni == nf`.
        padding, coord, separable, stride: forwarded to `Conv`.
        dilation: the i-th convolution branch uses `dilation ** i`.
        conv_dropout: dropout probability applied after the norm layer (skipped when falsy).
        sa: add a `SimpleSelfAttention` layer after dropout.
        se: if truthy, add a `SqueezeExciteBlock(reduction=se)` followed by `BN1d` after the activation.
        norm, zero_norm: forwarded to `Norm`.
        bn_1st: accepted for signature compatibility; not used by this module.
        act: activation class (instantiated with `act_kwargs`), or None for no activation.
        act_kwargs: keyword arguments for `act`. The default dict is shared between calls and is
            never mutated here.
    """

    def __init__(self, ni, nf, ks=40, bottleneck=True, padding='same', coord=False, separable=False, dilation=1, stride=1, conv_dropout=0., sa=False, se=None,
                 norm='Batch', zero_norm=False, bn_1st=True, act=nn.ReLU, act_kwargs={}):
        if isinstance(ks, Integral): ks = [ks // (2**i) for i in range(3)]
        ks = [ksi if ksi % 2 != 0 else ksi - 1 for ksi in ks]  # ensure odd ks for padding='same'
        bottleneck = False if ni == nf else bottleneck
        self.bottleneck = Conv(ni, nf, 1, coord=coord, bias=False) if bottleneck else noop
        # The branches read the bottleneck output when there is one, otherwise the raw input.
        branch_ni = nf if bottleneck else ni
        self.convs = nn.ModuleList()
        for i, ksi in enumerate(ks):
            self.convs.append(Conv(branch_ni, nf, ksi, padding=padding, coord=coord, separable=separable,
                                   dilation=dilation**i, stride=stride, bias=False))
        # The max-pool branch always works on the module input (it bypasses the bottleneck).
        self.mp_conv = nn.Sequential(*[nn.MaxPool1d(3, stride=1, padding=1), Conv(ni, nf, 1, coord=coord, bias=False)])
        self.concat = Concat()
        self.norm = Norm(nf * 4, norm=norm, zero_norm=zero_norm)
        self.conv_dropout = nn.Dropout(conv_dropout) if conv_dropout else noop
        self.sa = SimpleSelfAttention(nf * 4) if sa else noop
        self.act = act(**act_kwargs) if act else noop
        self.se = nn.Sequential(SqueezeExciteBlock(nf * 4, reduction=se), BN1d(nf * 4)) if se else noop

        self._init_cnn(self)

    def _init_cnn(self, m):
        """Recursively initialise `m`: zero tensor biases, Kaiming-normal conv/linear weights.

        The checks are made on the visited module `m`. Testing `self` here would make the whole
        traversal a no-op, because `self` is never a conv/linear layer and has no `bias`.
        """
        bias = getattr(m, 'bias', None)
        # Only touch real tensors; some layers may expose `bias` as None or a plain flag.
        if isinstance(bias, torch.Tensor): nn.init.constant_(bias, 0)
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): nn.init.kaiming_normal_(m.weight)
        for l in m.children(): self._init_cnn(l)

    def forward(self, x):
        """Apply bottleneck -> parallel branches -> concat -> norm -> dropout -> sa -> act -> se."""
        input_tensor = x
        x = self.bottleneck(x)
        x = self.concat([l(x) for l in self.convs] + [self.mp_conv(input_tensor)])
        x = self.norm(x)
        x = self.conv_dropout(x)
        x = self.sa(x)
        x = self.act(x)
        x = self.se(x)
        return x


@delegates(InceptionModulePlus.__init__)
class InceptionBlockPlus(Module):
    """Stack of `depth` `InceptionModulePlus` layers with a residual connection every third module.

    Modules are handled in groups of three. The module that closes a group (`d % 3 == 2`) is built
    without its own activation and without squeeze-excite, and is the only one that receives
    `zero_norm` and `sa`; its output is added to a shortcut of the group input and then activated.

    Args:
        ni: number of input channels.
        nf: filters per inception branch; every module outputs `nf * 4` channels.
        residual: add the shortcut connections.
        depth: number of inception modules.
        coord, norm, zero_norm, act, act_kwargs, sa, se: forwarded to the modules as described above.
        keep_prob: keep probability used in training mode for the closing module of each group.
            Either a scalar (applied to every group) or a sequence indexed by group (`i // 3`).
        **kwargs: forwarded to every `InceptionModulePlus`.
    """

    def __init__(self, ni, nf, residual=True, depth=6, coord=False, norm='Batch', zero_norm=False, act=nn.ReLU, act_kwargs={}, sa=False, se=None,
                 keep_prob=1., **kwargs):
        self.residual, self.depth = residual, depth
        self.inception, self.shortcut, self.act = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        for d in range(depth):
            closes_group = d % 3 == 2
            self.inception.append(InceptionModulePlus(ni if d == 0 else nf * 4, nf, coord=coord, norm=norm,
                                                      zero_norm=zero_norm if closes_group else False,
                                                      act=None if closes_group else act, act_kwargs=act_kwargs,
                                                      sa=sa if closes_group else False,
                                                      se=None if closes_group else se,
                                                      **kwargs))
            if self.residual and closes_group:
                # The first shortcut maps the block input (ni channels); later ones map nf * 4 channels.
                n_in, n_out = ni if d == 2 else nf * 4, nf * 4
                self.shortcut.append(Norm(n_in, norm=norm) if n_in == n_out else ConvBlock(n_in, n_out, 1, coord=coord, bias=False, norm=norm, act=None))
                self.act.append(act(**act_kwargs))
        self.add = Add()
        self.keep_prob = keep_prob

    def _group_keep_prob(self, group):
        """Return the keep probability of residual group `group` (scalar or per-group `keep_prob`)."""
        kp = self.keep_prob
        return kp if np.ndim(kp) == 0 else kp[group]

    def forward(self, x):
        """Run the modules in order, applying shortcut + activation at the end of each group."""
        res = x
        for i in range(self.depth):
            group = i // 3
            closes_group = self.residual and i % 3 == 2
            # keep_prob is only consulted where a drop can actually happen. Indexing it for every
            # module fails with a scalar keep_prob, and overruns a per-group sequence when depth
            # is not a multiple of 3.
            drop = False
            if closes_group and self.training:
                kp = self._group_keep_prob(group)
                drop = kp < 1. and kp < random.random()
            if drop:
                # Skip the closing module: the group output is the activated shortcut only
                # (the output of the group's earlier modules is discarded).
                res = x = self.act[group](self.shortcut[group](res))
            else:
                x = self.inception[i](x)
                if closes_group: res = x = self.act[group](self.add(x, self.shortcut[group](res)))
        return x


@delegates(InceptionModulePlus.__init__)
class InceptionTimePlus(Module):
    """InceptionTime-style model: an `InceptionBlockPlus` body followed by a head.

    Args:
        c_in: number of input channels.
        c_out: number of outputs of the head.
        seq_len: sequence length; only used (and then required) when `flatten=True`.
        nf: filters per inception branch. `nb_filters` is used instead when `nf` is None.
        nb_filters: alternative name for `nf`, kept for compatibility.
        concat_pool: use `GACP1d` instead of `GAP1d` in the default head (doubles its input features).
        fc_dropout: dropout probability placed before the linear layer of the default head.
        depth: number of inception modules.
        stoch_depth: keep probability of the last residual group; per-group keep probabilities are
            spaced linearly from 1 to this value. 0 disables the schedule (every group is kept).
        y_range: if truthy, a `SigmoidRange(*y_range)` layer is appended to the default head.
        flatten: apply `Flatten` to the body output and size the head for `nf * 4 * seq_len` features.
        custom_head: callable `(head_nf, c_out) -> module` used instead of the default head.
        **kwargs: forwarded to `InceptionBlockPlus`.
    """

    def __init__(self, c_in, c_out, seq_len=None, nf=32, nb_filters=None, concat_pool=False, fc_dropout=0., depth=6, stoch_depth=1., y_range=None,
                 flatten=False, custom_head=None, **kwargs):

        nf = ifnone(nf, nb_filters) # for compatibility
        self.fc_dropout, self.c_out, self.y_range = fc_dropout, c_out, y_range

        # One keep probability per residual group (a group is three inception modules).
        n_groups = depth // 3
        # Compare by value: `is not 0` never matches 0.0 and depends on int caching.
        if stoch_depth != 0: keep_prob = np.linspace(1, stoch_depth, n_groups)
        # Parenthesised on purpose: `[1] * depth // 3` would try to floor-divide a list.
        else: keep_prob = np.array([1] * n_groups)
        self.inceptionblock = InceptionBlockPlus(c_in, nf, depth=depth, keep_prob=keep_prob, **kwargs)

        self.head_nf = nf * 4
        if flatten: self.head_nf *= seq_len  # seq_len must be provided when flatten=True
        self.flatten = Flatten() if flatten else None
        if custom_head: self.head = custom_head(self.head_nf, c_out)
        else: self.head = self.create_head(self.head_nf, c_out, concat_pool=concat_pool, fc_dropout=fc_dropout, y_range=y_range)

    def create_head(self, nf, c_out, concat_pool=False, fc_dropout=0., y_range=None, **kwargs):
        """Build the default head: pooling -> optional dropout -> linear -> optional `SigmoidRange`.

        Does not read any attribute of `self`; extra `**kwargs` are accepted and ignored.
        """
        if concat_pool: nf = nf * 2
        layers = [GACP1d(1) if concat_pool else GAP1d(1)]
        if fc_dropout: layers += [nn.Dropout(fc_dropout)]
        layers += [nn.Linear(nf, c_out)]
        if y_range: layers += [SigmoidRange(*y_range)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Body -> optional flatten -> head."""
        x = self.inceptionblock(x)
        if self.flatten is not None: x = self.flatten(x)
        x = self.head(x)
        return x


class InCoordTime(InceptionTimePlus):
    """`InceptionTimePlus` whose `coord` and `zero_norm` options default to True."""

    def __init__(self, *args, coord=True, zero_norm=True, **kwargs):
        # Fold the overridden defaults back into the keyword arguments passed to the parent.
        kwargs.update(coord=coord, zero_norm=zero_norm)
        super().__init__(*args, **kwargs)


class XCoordTime(InceptionTimePlus):
    """`InceptionTimePlus` whose `coord`, `separable` and `zero_norm` options default to True."""

    def __init__(self, *args, coord=True, separable=True, zero_norm=True, **kwargs):
        # Fold the overridden defaults back into the keyword arguments passed to the parent.
        kwargs.update(coord=coord, separable=separable, zero_norm=zero_norm)
        super().__init__(*args, **kwargs)

# Preset constructors. Each `partial` is given a `__name__` matching its variable name.
InceptionTimePlus17x17 = partial(InceptionTimePlus, nf=17, depth=3)
InceptionTimePlus17x17.__name__ = 'InceptionTimePlus17x17'

# Plain alias of the class itself (signature defaults: nf=32, depth=6).
InceptionTimePlus32x32 = InceptionTimePlus

InceptionTimePlus47x47 = partial(InceptionTimePlus, nf=47, depth=9)
InceptionTimePlus47x47.__name__ = 'InceptionTimePlus47x47'

InceptionTimePlus62x62 = partial(InceptionTimePlus, nf=62, depth=9)
InceptionTimePlus62x62.__name__ = 'InceptionTimePlus62x62'

# Cell
@delegates(InceptionTimePlus.__init__)
class MultiInceptionTimePlus(Module):
    r"""Model with several `InceptionTimePlus` branches whose outputs feed one shared head.

    The input is split along dim 1 according to `feat_mask`; each chunk goes through its own
    branch, and the branch outputs are concatenated along dim 1 before the head.

    Args:
        - feat_mask: list with number of features that will be passed to each body
          (a single integer creates a single branch).
        - c_out: number of outputs of the shared head.
        - seq_len: forwarded to every branch.
        - **kwargs: forwarded to every branch and, except for `nf`, to `create_head`.
    """
    _arch = InceptionTimePlus

    def __init__(self, feat_mask, c_out, seq_len=None, **kwargs):
        # Integral (not just int) so that e.g. numpy integers are wrapped too.
        self.feat_mask = [int(feat_mask)] if isinstance(feat_mask, Integral) else feat_mask
        self.c_out = c_out

        # Body
        self.branches = nn.ModuleList()
        self.head_nf = 0
        for feat in self.feat_mask:
            m = create_model(self._arch, c_in=feat, c_out=c_out, seq_len=seq_len, **kwargs)
            self.head_nf += m.head_nf
            # Replace the branch's own head so the branch returns its features.
            # NOTE(review): `Noop` is presumably a pass-through module — confirm.
            m.head = Noop
            self.branches.append(m)

        # Head
        # `nf` is passed positionally below (as the summed feature count), so a user-supplied
        # `nf=` meant for the branches must not be forwarded again: that would raise
        # "got multiple values for argument 'nf'".
        head_kwargs = {k: v for k, v in kwargs.items() if k != 'nf'}
        self.head = self._arch.create_head(self, self.head_nf, c_out, **head_kwargs)

    def forward(self, x):
        """Split `x` by `feat_mask`, run each branch on its chunk, concatenate and apply the head."""
        chunks = torch.split(x, self.feat_mask, dim=1)
        # Single concatenation instead of growing the tensor one branch at a time.
        out = torch.cat([branch(chunk) for branch, chunk in zip(self.branches, chunks)], dim=1)
        return self.head(out)