# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/076_models.MultiRocketPlus.ipynb.

# %% auto 0
# Public names of this module; `MultiRocket` is an alias of `MultiRocketPlus`.
__all__ = ['MultiRocket', 'Flatten', 'MultiRocketFeaturesPlus', 'MultiRocketBackbonePlus', 'MultiRocketPlus']

# %% ../../nbs/076_models.MultiRocketPlus.ipynb 3
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from collections import OrderedDict
import itertools
from .layers import rocket_nd_head

# %% ../../nbs/076_models.MultiRocketPlus.ipynb 4
class Flatten(nn.Module):
    "Flatten every dimension after the first (batch) one: (B, ...) -> (B, prod(...))."
    def forward(self, x):
        # `reshape` returns a view whenever possible (same as `view` for contiguous inputs)
        # but, unlike `view`, also works on non-contiguous tensors (e.g. after a transpose).
        return x.reshape(x.size(0), -1)

# %% ../../nbs/076_models.MultiRocketPlus.ipynb 5
def _LPVV(o_pos, dim=2):
    "Longest stretch of positive values (-1, 1)" 
    shape = list(o_pos.shape)
    shape[dim] = 1
    o_pos = torch.cat([torch.zeros(shape, device=o_pos.device), o_pos], dim)
    o_arange_shape = [1] * o_pos.ndim
    o_arange_shape[dim] = -1
    o_arange = torch.arange(o_pos.shape[dim], device=o_pos.device).reshape(o_arange_shape)
    o_pos = torch.where(o_pos == 1, 0, o_arange)
    o_pos = o_pos.cummax(dim).values
    return ((o_arange - o_pos).max(dim).values / (o_pos.shape[dim] - 1)) * 2 - 1

def _MPV(o, dim=2):
    "Mean of Positive Values (any positive value)"
    o = torch.where(o > 0, o, torch.nan)
    o = torch.nanmean(o, dim)
    return torch.nan_to_num(o)

def _RSPV(o, dim=2):
    "Relative Sum of Positive Values (-1, 1)"
    o_sum = torch.clamp_min(torch.abs(o).sum(dim), 1e-8)
    o_pos_sum = torch.nansum(F.relu(o), dim)
    return (o_pos_sum / o_sum) * 2 - 1

def _MIPV(o, o_pos, dim=2):
    "Mean of Indices of Positive Values (-1, 1)"
    seq_len = o.shape[dim]
    o_arange_shape = [1] * o_pos.ndim
    o_arange_shape[dim] = -1
    o_arange = torch.arange(o_pos.shape[dim], device=o.device).reshape(o_arange_shape)
    o = torch.where(o_pos, o_arange, torch.nan)
    o = torch.nanmean(o, dim)
    return (torch.nan_to_num(o) / seq_len) * 2 - 1

def _PPV(o_pos, dim=2):
    "Proportion of Positive Values (-1, 1)"
    return (o_pos).float().mean(dim) * 2 - 1

# %% ../../nbs/076_models.MultiRocketPlus.ipynb 6
class MultiRocketFeaturesPlus(nn.Module):
    """MultiRocket+ random-convolution feature extractor (a single branch).

    Every kernel has length `kernel_size` and weight -1 everywhere except at
    `ceil(kernel_size / 3)` positions, whose weight is chosen so that the kernel sums to 0.
    The input is convolved with the kernel bank at several dilations. Each convolution output
    is compared against per-kernel bias thresholds (quantiles of the convolution output) and
    summarised with 4 pooling statistics (see `apply_pooling_ops`).

    Biases come from `fit` if it was called; otherwise they are estimated from the first batch
    passed through `forward`.

    Args:
        c_in: number of input channels.
        seq_len: input length, used to choose the dilations.
        num_features: total feature budget; `self.num_features` is the per-statistic count
            (`num_features // 4`, rounded down to a multiple of the number of kernels).
        max_dilations_per_kernel: upper bound on the number of dilations.
        kernel_size: kernel length.
        max_num_channels: max channels combined per (dilation, kernel) in the multivariate case;
            a falsy value means all channels.
        max_num_kernels: cap on the number of kernels (exact count when kernel_size > 9).
        diff: unused; kept for backward compatibility.
    """
    # True only while `fit` runs: `forward` then only estimates biases and returns None.
    fitting = False

    def __init__(self, c_in, seq_len, num_features=10_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=9, max_num_kernels=84, diff=False):
        super().__init__()

        self.c_in, self.seq_len = c_in, seq_len
        self.kernel_size, self.max_num_channels = kernel_size, max_num_channels

        # Kernels: (num_kernels, 1, kernel_size), -1 everywhere except `pos_values` at `indices`.
        # Repeated once per input channel so the grouped conv applies the bank to each channel.
        indices, pos_values = self.get_indices(kernel_size, max_num_kernels)
        self.num_kernels = len(indices)
        kernels = (-torch.ones(self.num_kernels, 1, self.kernel_size)).scatter_(2, indices, pos_values)
        self.indices = indices
        self.kernels = nn.Parameter(kernels.repeat(c_in, 1, 1), requires_grad=False)
        # 4 pooling statistics share the budget; keep a multiple of num_kernels.
        num_features = num_features // 4
        self.num_features = num_features // self.num_kernels * self.num_kernels
        self.max_dilations_per_kernel = max_dilations_per_kernel

        # Dilations
        self.set_dilations(seq_len)

        # Channel combinations (multivariate)
        if c_in > 1:
            self.set_channel_combinations(c_in, max_num_channels)

        # Bias placeholders, filled by `fit` or by the first `forward` call.
        for i in range(self.num_dilations):
            self.register_buffer(f'biases_{i}', torch.empty(
                (self.num_kernels, self.num_features_per_dilation[i])))
        self.register_buffer('prefit', torch.BoolTensor([False]))

    def forward(self, x):
        """Return the pooled features of `x`, or `None` when called in fitting mode.

        x: (batch, c_in, length) tensor. The result is (batch, 4 * self.num_features),
        concatenated dilation by dilation.
        """
        _features = []
        for i, (dilation, padding) in enumerate(zip(self.dilations, self.padding)):
            # Per dilation, alternate which half of the kernels (even/odd index) is pooled over the
            # full 'same'-padded output and which half only over the unpadded region.
            _padding1 = i % 2

            # Convolution: each channel is convolved with the whole kernel bank (grouped conv).
            C = F.conv1d(x, self.kernels, padding=padding,
                         dilation=dilation, groups=self.c_in)
            if self.c_in > 1:  # multivariate
                # Sum each kernel's response over its random channel subset for this dilation.
                C = C.reshape(x.shape[0], self.c_in, self.num_kernels, -1)
                channel_combination = getattr(self, f'channel_combinations_{i}')
                C = torch.mul(C, channel_combination)
                C = C.sum(1)

            # Bias: (re)estimate while fitting, or on the first call if `fit` was never run.
            if not self.prefit or self.fitting:
                num_features_this_dilation = self.num_features_per_dilation[i]
                bias_this_dilation = self.get_bias(C, num_features_this_dilation)
                setattr(self, f'biases_{i}', bias_this_dilation)
                if self.fitting:
                    if i < self.num_dilations - 1:
                        continue
                    # All biases estimated; fitting mode produces no features.
                    # In-place update keeps the flag on the module's device.
                    self.prefit.fill_(True)
                    return
                elif i == self.num_dilations - 1:
                    self.prefit.fill_(True)
            else:
                bias_this_dilation = getattr(self, f'biases_{i}')

            # Features. The end index is explicit because `padding:-padding` is an empty slice
            # when padding == 0 (e.g. kernel_size=2 at dilation 1).
            length = C.shape[-1]
            _features.append(self.apply_pooling_ops(
                C[:, _padding1::2], bias_this_dilation[_padding1::2]))
            _features.append(self.apply_pooling_ops(
                C[:, 1 - _padding1::2, padding:length - padding], bias_this_dilation[1 - _padding1::2]))

        return torch.cat(_features, dim=1)

    def fit(self, X, chunksize=None):
        """Estimate the bias thresholds from a random subset of `X` (no features are returned).

        X: np.ndarray or torch.Tensor indexed along dim 0 by sample.
        chunksize: number of samples drawn without replacement; defaults to
            `num_dilations * num_kernels` and is always capped at `X.shape[0]`.
        """
        num_samples = X.shape[0]
        if chunksize is None:
            chunksize = min(num_samples, self.num_dilations * self.num_kernels)
        else:
            chunksize = min(num_samples, chunksize)
        idxs = np.random.choice(num_samples, chunksize, False)
        X_fit = torch.from_numpy(X[idxs]) if isinstance(X, np.ndarray) else X[idxs]
        self.fitting = True
        try:
            self(X_fit.to(self.kernels.device))
        finally:
            # Always leave fitting mode, even if the forward pass raises.
            self.fitting = False

    def apply_pooling_ops(self, C, bias):
        """Compute PPV, MPV, MIPV and LSPV of `C` against every bias threshold.

        C: (batch, n_kernels, length) convolution output; bias: (n_kernels, n_biases).
        Returns (batch, 4 * n_kernels * n_biases), concatenated statistic by statistic.
        """
        C = C.unsqueeze(-1)                                   # (B, K, L, 1)
        bias = bias.view(1, bias.shape[0], 1, bias.shape[1])  # (1, K, 1, Q)
        pos_vals = (C > bias)                                 # (B, K, L, Q)
        ppv = _PPV(pos_vals).flatten(1)
        mpv = _MPV(C - bias).flatten(1)
        # `_RSPV(C - bias)` is an available alternative to MPV; it is not used here.
        mipv = _MIPV(C, pos_vals).flatten(1)
        lspv = _LPVV(pos_vals).flatten(1)
        return torch.cat((ppv, mpv, mipv, lspv), dim=1)

    def set_dilations(self, input_length):
        """Pick up to `max_dilations_per_kernel` exponentially spaced dilations and split the
        per-kernel feature budget between them; also stores the 'same' padding per dilation."""
        num_features_per_kernel = self.num_features // self.num_kernels
        true_max_dilations_per_kernel = min(
            num_features_per_kernel, self.max_dilations_per_kernel)
        multiplier = num_features_per_kernel / true_max_dilations_per_kernel
        # The largest dilation makes the dilated kernel span exactly `input_length`.
        max_exponent = np.log2((input_length - 1) / (self.kernel_size - 1))
        dilations, num_features_per_dilation = np.unique(
            np.logspace(0, max_exponent, true_max_dilations_per_kernel, base=2).astype(np.int32),
            return_counts=True)
        num_features_per_dilation = (num_features_per_dilation * multiplier).astype(np.int32)
        # Hand out the features lost to integer truncation round-robin.
        remainder = num_features_per_kernel - num_features_per_dilation.sum()
        i = 0
        while remainder > 0:
            num_features_per_dilation[i] += 1
            remainder -= 1
            i = (i + 1) % len(num_features_per_dilation)
        self.num_features_per_dilation = num_features_per_dilation
        self.num_dilations = len(dilations)
        self.dilations = dilations
        self.padding = [((self.kernel_size - 1) * dilation) // 2 for dilation in dilations]

    def set_channel_combinations(self, num_channels, max_num_channels):
        """Draw, for every (dilation, kernel) pair, a random subset of input channels.

        Subset sizes are log-uniformly drawn in [1, max_num_channels]. One 0/1 mask buffer
        `channel_combinations_{i}` of shape (1, num_channels, num_kernels, 1) is registered
        per dilation.
        """
        num_combinations = self.num_kernels * self.num_dilations
        if max_num_channels:
            max_num_channels = min(num_channels, max_num_channels)
        else:
            max_num_channels = num_channels
        max_exponent_channels = np.log2(max_num_channels + 1)
        num_channels_per_combination = (
            2 ** np.random.uniform(0, max_exponent_channels, num_combinations)).astype(np.int32)
        self.num_channels_per_combination = num_channels_per_combination
        channel_combinations = torch.zeros((1, num_channels, num_combinations, 1))
        for i in range(num_combinations):
            channel_combinations[:, np.random.choice(
                num_channels, num_channels_per_combination[i], False), i] = 1
        channel_combinations = torch.split(
            channel_combinations, self.num_kernels, 2)  # split by dilation
        for i, channel_combination in enumerate(channel_combinations):
            self.register_buffer(
                f'channel_combinations_{i}', channel_combination)  # per dilation

    def get_quantiles(self, n):
        "First `n` terms of the golden-ratio sequence frac(k * phi), k = 1..n, as a float tensor."
        return torch.tensor([(_ * ((np.sqrt(5) + 1) / 2)) % 1 for _ in range(1, n + 1)]).float()

    def get_bias(self, C, num_features_this_dilation):
        """Return (num_kernels, num_features_this_dilation) bias thresholds.

        For each kernel k, a random batch element is drawn and the biases are quantiles of
        that element's response C[sample, k, :] at the `get_quantiles` levels.
        """
        isp = torch.randint(C.shape[0], (self.num_kernels,))
        # C[isp] is (K, K, L); its diagonal picks C[isp[k], k, :] for every kernel k.
        samples = C[isp].diagonal().T
        biases = torch.quantile(samples, self.get_quantiles(
            num_features_this_dilation).to(C.device), dim=1).T
        return biases

    def get_indices(self, kernel_size, max_num_kernels):
        """Return `(indices, pos_value)` describing the kernel bank.

        indices: LongTensor (num_kernels, 1, ceil(kernel_size / 3)) of the positions that get
        the positive weight. For kernel_size <= 9 it holds every combination (randomly
        subsampled to `max_num_kernels` if needed); otherwise `max_num_kernels` random ones.
        pos_value: num_neg / num_pos, which makes every kernel sum to 0.
        """
        num_pos_values = math.ceil(kernel_size / 3)
        num_neg_values = kernel_size - num_pos_values
        pos_values = num_neg_values / num_pos_values
        if kernel_size > 9:
            random_kernels = [np.sort(np.random.choice(kernel_size, num_pos_values, False)).reshape(
                1, -1) for _ in range(max_num_kernels)]
            indices = torch.from_numpy(
                np.concatenate(random_kernels, 0)).unsqueeze(1)
        else:
            indices = torch.LongTensor(list(itertools.combinations(
                np.arange(kernel_size), num_pos_values))).unsqueeze(1)
            if max_num_kernels and len(indices) > max_num_kernels:
                indices = indices[np.sort(np.random.choice(
                    len(indices), max_num_kernels, False))]
        return indices, pos_values

# %% ../../nbs/076_models.MultiRocketPlus.ipynb 7
class MultiRocketBackbonePlus(nn.Module):
    """MultiRocket+ feature extractor for the raw series and, optionally, its first difference.

    When `use_diff` is True, the feature budget `num_features` is split evenly between
    `branch_x` (raw input, length `seq_len`) and `branch_x_diff` (`torch.diff` of the input
    along the last dimension, length `seq_len - 1`). Their outputs are concatenated.
    `self.num_features` is the size of that output.
    """
    def __init__(self, c_in, seq_len, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84, use_diff=True):
        super().__init__()

        branch_kwargs = dict(max_dilations_per_kernel=max_dilations_per_kernel, kernel_size=kernel_size,
                             max_num_channels=max_num_channels, max_num_kernels=max_num_kernels)
        num_features_per_branch = num_features // (1 + use_diff)
        self.branch_x = MultiRocketFeaturesPlus(c_in, seq_len, num_features=num_features_per_branch, **branch_kwargs)
        # Each branch outputs 4 pooling statistics per (kernel, bias) pair.
        self.num_features = self.branch_x.num_features * 4
        if use_diff:
            # The first-difference series is one step shorter than the input.
            self.branch_x_diff = MultiRocketFeaturesPlus(c_in, seq_len - 1, num_features=num_features_per_branch, **branch_kwargs)
            self.num_features += self.branch_x_diff.num_features * 4
        self.use_diff = use_diff

    def forward(self, x):
        x_features = self.branch_x(x)
        if not self.use_diff:
            return x_features
        # The difference series goes through its own branch: it was built for length
        # seq_len - 1 and keeps its own biases. Reusing `branch_x` would evaluate it against
        # thresholds fitted on the raw series and leave `branch_x_diff` unused.
        x_diff_features = self.branch_x_diff(torch.diff(x))
        return torch.cat([x_features, x_diff_features], dim=-1)

# %% ../../nbs/076_models.MultiRocketPlus.ipynb 8
class MultiRocketPlus(nn.Sequential):
    """MultiRocket+ model: a `MultiRocketBackbonePlus` feature extractor followed by a head.

    Head selection, first match wins:
    - `custom_head` given: used as-is if it is an `nn.Module`, otherwise called as
      `custom_head(head_nf, c_out, 1)`.
    - `d` given: `rocket_nd_head(...)`.
    - otherwise: Flatten -> [BatchNorm1d if use_bn] -> [Dropout if fc_dropout] -> Linear,
      with the Linear weight and bias set to 0 when `zero_init` is True.
    """

    def __init__(self, c_in, c_out, seq_len, d=None, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84,
                 use_bn=True, fc_dropout=0, custom_head=None, zero_init=True, use_diff=True):
        backbone = MultiRocketBackbonePlus(
            c_in, seq_len,
            num_features=num_features,
            max_dilations_per_kernel=max_dilations_per_kernel,
            kernel_size=kernel_size,
            max_num_channels=max_num_channels,
            max_num_kernels=max_num_kernels,
            use_diff=use_diff,
        )
        head_nf = backbone.num_features

        if custom_head is None and d is None:
            layers = [Flatten()]
            if use_bn:
                layers.append(nn.BatchNorm1d(head_nf))
            if fc_dropout:
                layers.append(nn.Dropout(fc_dropout))
            linear = nn.Linear(head_nf, c_out)
            if zero_init:
                nn.init.constant_(linear.weight.data, 0)
                nn.init.constant_(linear.bias.data, 0)
            layers.append(linear)
            head = nn.Sequential(*layers)
        elif custom_head is None:
            head = rocket_nd_head(head_nf, c_out, seq_len=None, d=d, use_bn=use_bn, fc_dropout=fc_dropout, zero_init=zero_init)
        elif isinstance(custom_head, nn.Module):
            head = custom_head
        else:
            head = custom_head(head_nf, c_out, 1)

        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))
        # Number of features fed into the head.
        self.head_nf = head_nf

MultiRocket = MultiRocketPlus
