# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/hierarchical.ipynb (unless otherwise specified).

__all__ = ['Labour', 'TourismLarge', 'TourismSmall', 'Traffic', 'Wiki2', 'HierarchicalInfo', 'HierarchicalData']

# Cell
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd

from .utils import download_file, Info

# Cell
@dataclass
class Labour:
    """Metadata for the `Labour` hierarchical dataset group.

    Attributes
    ----------
    freq: str
        Frequency string of the series (presumably a pandas offset alias).
    horizon: int
        Forecast horizon (presumably in steps of `freq`).
    seasonality: int
        Seasonal period of the series.
    test_size: int
        Size of the test split (NOTE(review): units not visible here — confirm).
    tags_names: Tuple[str, ...]
        Names of the hierarchy levels, ordered from the top (total) level down.
    """
    freq: str = 'MS'
    horizon: int = 8
    seasonality: int = 12
    test_size: int = 125
    # Variable-length tuple of level names: `Tuple[str]` would mean a 1-tuple.
    tags_names: Tuple[str, ...] = (
        'Country',
        'Country/Region',
        'Country/Gender/Region',
        'Country/Employment/Gender/Region',
    )

# Cell
@dataclass
class TourismLarge:
    """Metadata for the `TourismLarge` hierarchical dataset group.

    Attributes
    ----------
    freq: str
        Frequency string of the series (presumably a pandas offset alias).
    horizon: int
        Forecast horizon (presumably in steps of `freq`).
    seasonality: int
        Seasonal period of the series.
    test_size: int
        Size of the test split (NOTE(review): units not visible here — confirm).
    tags_names: Tuple[str, ...]
        Names of the hierarchy levels, ordered from the top (total) level down.
    """
    freq: str = 'MS'
    horizon: int = 12
    seasonality: int = 12
    test_size: int = 57
    # Variable-length tuple of level names: `Tuple[str]` would mean a 1-tuple.
    tags_names: Tuple[str, ...] = (
        'Country',
        'Country/State',
        'Country/State/Zone',
        'Country/State/Zone/Region',
        'Country/Purpose',
        'Country/State/Purpose',
        'Country/State/Zone/Purpose',
        'Country/State/Zone/Region/Purpose',
    )

# Cell
@dataclass
class TourismSmall:
    """Metadata for the `TourismSmall` hierarchical dataset group.

    Attributes
    ----------
    freq: str
        Frequency string of the series (presumably a pandas offset alias).
    horizon: int
        Forecast horizon (presumably in steps of `freq`).
    seasonality: int
        Seasonal period of the series.
    test_size: int
        Size of the test split (NOTE(review): units not visible here — confirm).
    tags_names: Tuple[str, ...]
        Names of the hierarchy levels, ordered from the top (total) level down.
    """
    freq: str = 'Q'
    horizon: int = 8
    seasonality: int = 4
    test_size: int = 9
    # Variable-length tuple of level names: `Tuple[str]` would mean a 1-tuple.
    tags_names: Tuple[str, ...] = (
        'Country',
        'Country/Purpose',
        'Country/Purpose/State',
        'Country/Purpose/State/CityNonCity',
    )

# Cell
@dataclass
class Traffic:
    """Metadata for the `Traffic` hierarchical dataset group.

    Attributes
    ----------
    freq: str
        Frequency string of the series (presumably a pandas offset alias).
    horizon: int
        Forecast horizon (presumably in steps of `freq`).
    seasonality: int
        Seasonal period of the series.
    test_size: int
        Size of the test split (NOTE(review): units not visible here — confirm).
    tags_names: Tuple[str, ...]
        Names of the hierarchy levels, ordered from the top (total) level down.
    """
    freq: str = 'D'
    horizon: int = 14
    seasonality: int = 7
    test_size: int = 91
    # Variable-length tuple of level names: `Tuple[str]` would mean a 1-tuple.
    tags_names: Tuple[str, ...] = (
        'Level1',
        'Level2',
        'Level3',
        'Level4',
    )

# Cell
@dataclass
class Wiki2:
    """Metadata for the `Wiki2` hierarchical dataset group.

    Attributes
    ----------
    freq: str
        Frequency string of the series (presumably a pandas offset alias).
    horizon: int
        Forecast horizon (presumably in steps of `freq`).
    seasonality: int
        Seasonal period of the series.
    test_size: int
        Size of the test split (NOTE(review): units not visible here — confirm).
    tags_names: Tuple[str, ...]
        Names of the hierarchy levels, ordered from the top (total) level down.
    """
    freq: str = 'D'
    horizon: int = 14
    seasonality: int = 7
    test_size: int = 91
    # Variable-length tuple of level names: `Tuple[str]` would mean a 1-tuple.
    tags_names: Tuple[str, ...] = (
        'Views',
        'Views/Country',
        'Views/Country/Access',
        'Views/Country/Access/Agent',
        'Views/Country/Access/Agent/Topic',
    )

# Cell
# Registry of all hierarchical dataset groups; `HierarchicalData.load` checks
# `group` against `HierarchicalInfo.groups` and fetches the group's metadata
# class with `HierarchicalInfo[group]` (e.g. group name 'Labour').
_hierarchical_groups = (Labour, TourismLarge, TourismSmall, Traffic, Wiki2)
HierarchicalInfo = Info(_hierarchical_groups)

# Cell
class HierarchicalData:
    """Loader for the hierarchical forecasting benchmark datasets.

    All groups are shipped in a single archive at `source_url`. Each group is
    read from `{directory}/hierarchical/{group}/` (files `agg_mat.csv` and
    `data.csv`).
    """

    source_url: str = 'https://nixtla-public.s3.amazonaws.com/hierarchical-data/datasets.zip'

    @staticmethod
    def _get_levels_from_S(S: pd.DataFrame) -> list:
        """Split the rows of the summing matrix `S` into hierarchy levels.

        Assumes the first row is the single top (total) series and that the
        rows of each level together cover every bottom series exactly once.
        Under that assumption (and positive row sums), the cumulative row sum
        is a multiple of the number of bottom series (`S.shape[1]`) exactly
        at the last row of each level.

        Parameters
        ----------
        S: pd.DataFrame
            Summing matrix of size (hierarchies, bottom), rows ordered by level.

        Returns
        -------
        levels: list
            One array of row labels of `S` per level, top level first.

        Raises
        ------
        Exception
            If the inferred levels do not cover every row of `S`.
        """
        cut_idxs, = np.where(S.sum(axis=1).cumsum() % S.shape[1] == 0.)
        # cut_idxs[0] is the top row itself; each later level spans the rows
        # after one cut index up to and including the next.
        levels = [S.iloc[[0]].index.values]
        levels += [
            S.iloc[(start + 1):(end + 1)].index.values
            for start, end in zip(cut_idxs[:-1], cut_idxs[1:])
        ]
        # Explicit check instead of `assert` so it is not stripped under `-O`.
        if sum(len(lv) for lv in levels) != S.shape[0]:
            raise Exception('levels inferred from `S` do not cover all of its rows')
        return levels

    @staticmethod
    def load(directory: str,
             group: str,
             cache: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame, dict]:
        """
        Downloads hierarchical forecasting benchmark datasets.

        Parameters
        ----------
        directory: str
            Directory where data will be downloaded.
        group: str
            Group name; must be one of `HierarchicalInfo.groups`.
        cache: bool
            If `True`, returns the pickled result from
            `{directory}/hierarchical/{group}.p` when that file exists, and
            writes it after parsing otherwise.

        Returns
        -------
        Y_df: pd.DataFrame
            Target time series with columns ['unique_id', 'ds', 'y'].
            Its `unique_id`s appear in the same order as the rows of `S`.
        S: pd.DataFrame
            Summing matrix of size (hierarchies, bottom).
        tags: dict
            Maps each level name (the group's `tags_names`) to the array of
            `unique_id`s belonging to that level, top level first.

        Raises
        ------
        Exception
            If `group` is unknown, or if the series order in the data does
            not match the rows of `S`.
        """
        if group not in HierarchicalInfo.groups:
            raise Exception(f'group not found {group}')

        root = Path(f'{directory}/hierarchical')
        file_cache = root / f'{group}.p'

        if cache and file_cache.is_file():
            Y_df, S, tags = pd.read_pickle(file_cache)
            return Y_df, S, tags

        HierarchicalData.download(directory)
        group_path = root / group
        S = pd.read_csv(group_path / 'agg_mat.csv', index_col=0)
        # data.csv is wide (first column -> index); after transposing, each row
        # is one series, and stacking yields the long (unique_id, ds) -> y format.
        Y_df = pd.read_csv(group_path / 'data.csv', index_col=0).T
        Y_df = Y_df.stack()
        Y_df.name = 'y'
        Y_df.index = Y_df.index.set_names(['unique_id', 'ds'])
        Y_df = Y_df.reset_index()
        if group == 'Labour':
            # for labour we avoid covid periods
            Y_df = Y_df.query('ds < "2020-01-01"').reset_index(drop=True)

        # np.array_equal returns False (rather than raising, as an elementwise
        # `==` against an Index does) when the number of series differs.
        if not np.array_equal(Y_df['unique_id'].unique(), S.index.values):
            raise Exception('mismatch order between `Y_df` and `S`')

        cls_group = HierarchicalInfo[group]
        tags = dict(zip(cls_group.tags_names, HierarchicalData._get_levels_from_S(S)))

        if cache:
            pd.to_pickle((Y_df, S, tags), file_cache)

        return Y_df, S, tags

    @staticmethod
    def download(directory: str) -> None:
        """
        Download Hierarchical Datasets.

        Calls `download_file` with `decompress=True` to fetch `source_url`
        into `{directory}/hierarchical/`. Does nothing if that directory
        already exists.

        Parameters
        ----------
        directory: str
            Directory path to download dataset.
        """
        path = f'{directory}/hierarchical/'
        if not Path(path).exists():
            download_file(path, HierarchicalData.source_url, decompress=True)