# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/053_models.ROCKET.ipynb.

# %% auto 0
__all__ = ['RocketClassifier', 'load_rocket', 'RocketRegressor']

# %% ../../nbs/053_models.ROCKET.ipynb 3
import sklearn
from sklearn.linear_model import RidgeClassifierCV, RidgeCV
from sklearn.metrics import make_scorer
from sklearn.preprocessing import StandardScaler

from ..data.external import *
from ..imports import *
from .layers import *

# %% ../../nbs/053_models.ROCKET.ipynb 4
class RocketClassifier(sklearn.pipeline.Pipeline):
    """Time series classification using ROCKET features and a linear classifier"""
    
    def __init__(self, num_kernels=10_000, normalize_input=True, random_state=None, 
                 alphas=np.logspace(-3, 3, 7), normalize_features=True, memory=None, verbose=False, scoring=None, class_weight=None, **kwargs):
        """
        RocketClassifier is recommended for up to 10k time series. 
        For a larger dataset, you can use ROCKET (in Pytorch).

        The pipeline is: Rocket transform -> (optional) StandardScaler(with_mean=False) -> RidgeClassifierCV.

        Rocket args:            
            num_kernels     : int, number of random convolutional kernels (default 10,000)
            normalize_input : boolean, whether or not to normalise the input time series per instance (default True)
            random_state    : Optional random seed (default None)

        Pipeline / classifier args:
            alphas             : regularization strengths forwarded to RidgeClassifierCV (default np.logspace(-3, 3, 7))
            normalize_features : boolean, whether to add the StandardScaler step after the Rocket transform (default True)
            memory, verbose    : not used directly here; kept as constructor args and stored via store_attr()
                                 (presumably consumed by sklearn's Pipeline -- confirm against the sklearn version in use)
            scoring            : forwarded to RidgeClassifierCV; None uses that estimator's own default scoring
            class_weight       : forwarded to RidgeClassifierCV (default None)
            **kwargs           : any extra keyword arguments are forwarded to RidgeClassifierCV

        Raises:
            ImportError: if sktime (or its Rocket transformer) cannot be imported.
        """
        try: 
            from sktime.transformations.panel.rocket import Rocket
        except ImportError as err:
            # Raising a bare string is itself a TypeError; raise a real exception and keep
            # the underlying import failure chained so the root cause stays visible.
            raise ImportError("You need to install sktime to be able to use RocketClassifier") from err
            
        self.steps = [('rocket', Rocket(num_kernels=num_kernels, normalise=normalize_input, random_state=random_state))]
        if normalize_features:
            # with_mean=False: only rescale the features, do not center them
            self.steps += [('scalar', StandardScaler(with_mean=False))]
        self.steps += [('ridgeclassifiercv', RidgeClassifierCV(alphas=alphas, scoring=scoring, class_weight=class_weight, **kwargs))]
        store_attr()
        self._validate_steps()

    def __repr__(self):  
        """Return a Pipeline-style representation listing the configured steps."""
        return f'Pipeline(steps={self.steps.copy()})'

    def save(self, fname='Rocket', path='./models'):
        """Pickle this pipeline to `<path>/<fname>.pkl`.

        Args:
            fname : file name without the `.pkl` extension (default 'Rocket')
            path  : target directory; created (with parents) if it does not exist (default './models')
        """
        path = Path(path)
        filename = path/fname
        target = Path(f'{filename}.pkl')
        # Without this, saving with the default path fails when ./models has not been created yet.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, 'wb') as output:
            pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)

# %% ../../nbs/053_models.ROCKET.ipynb 5
def load_rocket(fname='Rocket', path='./models'):
    """Load and return the object pickled at `<path>/<fname>.pkl`.

    Args:
        fname : file name without the `.pkl` extension (default 'Rocket')
        path  : directory containing the pickle file (default './models')

    NOTE(review): unpickling can execute arbitrary code -- only load files from trusted sources.
    """
    pkl_file = f'{Path(path)/fname}.pkl'
    with open(pkl_file, 'rb') as fh:
        return pickle.load(fh)

# %% ../../nbs/053_models.ROCKET.ipynb 6
class RocketRegressor(sklearn.pipeline.Pipeline):
    """Time series regression using ROCKET features and a linear regressor"""
    
    def __init__(self, num_kernels=10_000, normalize_input=True, random_state=None, 
                 alphas=np.logspace(-3, 3, 7), normalize_features=True, memory=None, verbose=False, scoring=None, **kwargs):
        """
        RocketRegressor is recommended for up to 10k time series. 
        For a larger dataset, you can use ROCKET (in Pytorch).

        The pipeline is: Rocket transform -> (optional) StandardScaler(with_mean=False) -> RidgeCV.

        Rocket args:            
            num_kernels     : int, number of random convolutional kernels (default 10,000)
            normalize_input : boolean, whether or not to normalise the input time series per instance (default True)
            random_state    : Optional random seed (default None)

        Pipeline / regressor args:
            alphas             : regularization strengths forwarded to RidgeCV (default np.logspace(-3, 3, 7))
            normalize_features : boolean, whether to add the StandardScaler step after the Rocket transform (default True)
            memory, verbose    : not used directly here; kept as constructor args and stored via store_attr()
                                 (presumably consumed by sklearn's Pipeline -- confirm against the sklearn version in use)
            scoring            : forwarded to RidgeCV; None uses that estimator's own default scoring
            **kwargs           : any extra keyword arguments are forwarded to RidgeCV

        Raises:
            ImportError: if sktime (or its Rocket transformer) cannot be imported.
        """
        try: 
            from sktime.transformations.panel.rocket import Rocket
        except ImportError as err:
            # Raising a bare string is itself a TypeError; raise a real exception and keep
            # the underlying import failure chained so the root cause stays visible.
            raise ImportError("You need to install sktime to be able to use RocketRegressor") from err
            
        self.steps = [('rocket', Rocket(num_kernels=num_kernels, normalise=normalize_input, random_state=random_state))]
        if normalize_features:
            # with_mean=False: only rescale the features, do not center them
            self.steps += [('scalar', StandardScaler(with_mean=False))]
        self.steps += [('ridgecv', RidgeCV(alphas=alphas, scoring=scoring, **kwargs))]
        store_attr()
        self._validate_steps()

    def __repr__(self):  
        """Return a Pipeline-style representation listing the configured steps."""
        return f'Pipeline(steps={self.steps.copy()})'

    def save(self, fname='Rocket', path='./models'):
        """Pickle this pipeline to `<path>/<fname>.pkl`.

        Args:
            fname : file name without the `.pkl` extension (default 'Rocket')
            path  : target directory; created (with parents) if it does not exist (default './models')
        """
        path = Path(path)
        filename = path/fname
        target = Path(f'{filename}.pkl')
        # Without this, saving with the default path fails when ./models has not been created yet.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, 'wb') as output:
            pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)
