# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/data_model.ipynb (unless otherwise specified).

# Public API: the names exported by a star-import of this module.
__all__ = ['DataFreq', 'DataFormat', 'DataConfig', 'BacktestConfig', 'FeaturesConfig', 'ForecastConfig', 'ModelConfig',
           'LocalConfig', 'ClusterConfig', 'DistributedModelName', 'DistributedModelConfig', 'DistributedConfig',
           'FlowConfig', 'DateFeatures', 'Transforms']

# Cell
from enum import Enum
from typing import Dict, List, Optional, Union

import window_ops.ewm
import window_ops.expanding
import window_ops.rolling
from pydantic import BaseModel, root_validator
from typing_extensions import Literal

from .core import date_features_dtypes

# Internal Cell
# Registry of every name exported (via `__all__`) by the window_ops
# submodules, mapped to the corresponding attribute. Submodules are visited
# in the order rolling, expanding, ewm; on a name clash the later one wins.
_available_tfms = {
    tfm_name: getattr(submodule, tfm_name)
    for submodule in (window_ops.rolling, window_ops.expanding, window_ops.ewm)
    for tfm_name in submodule.__all__
}

# Cell
# Literal type whose allowed values are the keys of `date_features_dtypes`.
DateFeatures = Literal[(*date_features_dtypes.keys(),)]

# Literal type whose allowed values are the registered transform names.
Transforms = Literal[(*_available_tfms,)]


class DataFreq(str, Enum):
    """Pandas frequencies.

    Each member's value is a pandas offset alias string. Because the enum
    also subclasses `str`, a member compares equal to its alias, e.g.
    `DataFreq.W_MON == 'W-MON'`.

    The anchored weekly, quarterly and annual families cover every weekday /
    month, including `W-SUN`, `Q-DEC` and `A-DEC`, which pandas accepts as
    explicit spellings of the plain `W`, `Q` and `A` aliases.
    """

    # Plain (unanchored) offset aliases.
    B = 'B'
    C = 'C'
    D = 'D'
    W = 'W'
    M = 'M'
    SM = 'SM'
    BM = 'BM'
    CBM = 'CBM'
    MS = 'MS'
    SMS = 'SMS'
    BMS = 'BMS'
    CBMS = 'CBMS'
    Q = 'Q'
    BQ = 'BQ'
    QS = 'QS'
    BQS = 'BQS'
    A = 'A'
    Y = 'Y'
    BA = 'BA'
    BY = 'BY'
    AS = 'AS'
    YS = 'YS'
    BAS = 'BAS'
    BYS = 'BYS'
    BH = 'BH'
    H = 'H'
    T = 'T'
    S = 'S'
    L = 'L'
    U = 'U'
    N = 'N'
    # Weekly frequencies anchored on a weekday.
    W_MON = 'W-MON'
    W_TUE = 'W-TUE'
    W_WED = 'W-WED'
    W_THU = 'W-THU'
    W_FRI = 'W-FRI'
    W_SAT = 'W-SAT'
    W_SUN = 'W-SUN'
    # Quarterly frequencies anchored on the month in which the year ends.
    Q_JAN = 'Q-JAN'
    Q_FEB = 'Q-FEB'
    Q_MAR = 'Q-MAR'
    Q_APR = 'Q-APR'
    Q_MAY = 'Q-MAY'
    Q_JUN = 'Q-JUN'
    Q_JUL = 'Q-JUL'
    Q_AUG = 'Q-AUG'
    Q_SEP = 'Q-SEP'
    Q_OCT = 'Q-OCT'
    Q_NOV = 'Q-NOV'
    Q_DEC = 'Q-DEC'
    # Annual frequencies anchored on the end of the given month.
    A_JAN = 'A-JAN'
    A_FEB = 'A-FEB'
    A_MAR = 'A-MAR'
    A_APR = 'A-APR'
    A_MAY = 'A-MAY'
    A_JUN = 'A-JUN'
    A_JUL = 'A-JUL'
    A_AUG = 'A-AUG'
    A_SEP = 'A-SEP'
    A_OCT = 'A-OCT'
    A_NOV = 'A-NOV'
    A_DEC = 'A-DEC'


class DataFormat(str, Enum):
    """Allowed data formats.

    Subclasses `str`, so each member compares equal to its plain string value.
    """

    csv = 'csv'
    parquet = 'parquet'


class DataConfig(BaseModel):
    """Data configuration.

    Describes where the data lives and how it is stored. How the string
    fields are interpreted is decided by the code consuming this model, which
    is not part of this module.
    """

    # NOTE(review): presumably the common location under which `input` and
    # `output` are resolved -- confirm against the consumer.
    prefix: str
    # NOTE(review): presumably where the input data is read from -- confirm.
    input: str
    # NOTE(review): presumably where results are written -- confirm.
    output: str
    # Storage format; restricted to the `DataFormat` members.
    format: DataFormat
    # Optional list of strings.
    # NOTE(review): presumably identifies dynamic-feature data -- confirm.
    dynamic: Optional[List[str]]


class BacktestConfig(BaseModel):
    """Backtest configuration.

    Both fields are required integers; no range checks are applied here.
    """

    # NOTE(review): presumably the number of backtest windows -- confirm.
    n_windows: int
    # NOTE(review): presumably the length of each window in periods -- confirm.
    window_size: int


class FeaturesConfig(BaseModel):
    """Features configuration.

    Only `freq` is declared without `Optional`; every other field may be None.
    """

    # Series frequency; restricted to the `DataFreq` members.
    freq: DataFreq
    # Optional list of integers.
    # NOTE(review): presumably the lags of the target to use as features.
    lags: Optional[List[int]]
    # Optional mapping from an integer key to a list of transform specs. Each
    # spec is either a transform name allowed by `Transforms`, or a dict
    # mapping such a name to another dict.
    # NOTE(review): the key is presumably the lag and the inner dict the
    # transform's keyword arguments -- confirm against the consumer.
    lag_transforms: Optional[Dict[int, List[Union[Transforms, Dict[Transforms, Dict]]]]]
    # Optional list of names, each restricted to the `DateFeatures` literal.
    date_features: Optional[List[DateFeatures]]
    # Optional list of strings.
    # NOTE(review): presumably names of static feature columns -- confirm.
    static_features: Optional[List[str]]
    # Optional integer.
    # NOTE(review): presumably how many trailing observations to keep per series.
    keep_last_n: Optional[int]
    # Optional integer.
    # NOTE(review): presumably the number of threads used to compute features.
    num_threads: Optional[int]


class ForecastConfig(BaseModel):
    """Forecast configuration."""

    # Required integer; no range check is applied here.
    # NOTE(review): presumably the number of future periods to predict.
    horizon: int


class ModelConfig(BaseModel):
    """Model configuration.

    name must include the modules, e.g. sklearn.ensemble.RandomForestRegressor."""

    # Fully qualified name of the model class (module path plus class name).
    name: str
    # Optional free-form dict.
    # NOTE(review): presumably keyword arguments for the model constructor.
    params: Optional[Dict]


class LocalConfig(BaseModel):
    """Configuration for local pipeline."""

    # Required description of the model to use, validated as a `ModelConfig`.
    model: ModelConfig


class ClusterConfig(BaseModel):
    """Cluster configuration.

    class_name must include the modules, e.g. dask.distributed.LocalCluster"""

    # Fully qualified name of the cluster class (module path plus class name).
    class_name: str
    # Required free-form dict (not Optional, unlike the `params` fields in this module).
    # NOTE(review): presumably keyword arguments for the cluster class.
    class_kwargs: Dict


class DistributedModelName(str, Enum):
    """Available models for distributed training.

    Subclasses `str`, so each member compares equal to its plain string value.
    """

    XGBForecast = 'XGBForecast'
    LGBMForecast = 'LGBMForecast'


class DistributedModelConfig(BaseModel):
    """Configuration for distributed models."""

    # Which distributed model to use; restricted to `DistributedModelName`.
    name: DistributedModelName
    # Optional free-form dict.
    # NOTE(review): presumably keyword arguments for the model constructor.
    params: Optional[Dict]


class DistributedConfig(BaseModel):
    """Configuration for distributed training.

    Both the model and the cluster descriptions are required.
    """

    # Distributed model description, validated as a `DistributedModelConfig`.
    model: DistributedModelConfig
    # Cluster description, validated as a `ClusterConfig`.
    cluster: ClusterConfig


class FlowConfig(BaseModel):
    """Flow configuration.

    Top-level model combining the data, features, backtest, forecast and
    execution sections. `data` and `features` are required; `backtest` and
    `forecast` are optional. Exactly one of `local` and `distributed` must be
    provided, which `check_local_or_distributed` enforces.
    """

    data: DataConfig
    features: FeaturesConfig
    backtest: Optional[BacktestConfig]
    forecast: Optional[ForecastConfig]
    local: Optional[LocalConfig]
    distributed: Optional[DistributedConfig]

    # skip_on_failure=True: when a field fails its own validation it is absent
    # from `values`. Without this flag, a malformed `local` or `distributed`
    # section would also trigger a misleading "Must specify either local or
    # distributed." error on top of the real field error.
    @root_validator(skip_on_failure=True)
    def check_local_or_distributed(cls, values):
        """Ensure exactly one execution mode is configured.

        Args:
            values: Mapping of field name to validated value.

        Returns:
            The unchanged `values` mapping.

        Raises:
            ValueError: If both `local` and `distributed` are set, or if
                neither is set. pydantic reports it as a validation error.
        """
        has_local = values.get('local') is not None
        has_distributed = values.get('distributed') is not None
        if has_local and has_distributed:
            raise ValueError('Must specify either local or distributed, not both.')
        if not has_local and not has_distributed:
            raise ValueError('Must specify either local or distributed.')
        return values
