# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/distributed.fugue.ipynb.

# %% auto 0
__all__ = ['FugueBackend']

# %% ../nbs/distributed.fugue.ipynb 5
from typing import Any, Dict

import numpy as np
import pandas as pd
try:
    from fugue import transform
except ModuleNotFoundError as e:
    # fugue is imported behind a guard: when it is missing, re-raise the same
    # exception type with installation instructions, chained to the original.
    msg = (
        f'{e}. To use fugue you have to install it. '
        'Please run `pip install fugue`.'
    )
    raise ModuleNotFoundError(msg) from e
from ..core import StatsForecast
from .core import ParallelBackend
from triad import Schema

# %% ../nbs/distributed.fugue.ipynb 6
class FugueBackend(ParallelBackend):
    """Backend that runs StatsForecast per `unique_id` partition through fugue's `transform`.

    `forecast` and `cross_validation` hand the input dataframe to `transform`,
    partitioned by `unique_id`; each partition is processed by a single-job
    `StatsForecast` instance.
    """

    def __init__(
            self,
            engine: Any = None, # Fugue engine forwarded to `transform(engine=...)`
            conf: Any = None, # Engine configuration forwarded to `transform(engine_conf=...)`
            **transform_kwargs: Any # Extra keyword arguments forwarded to fugue's `transform`
        ):
        self._conf = conf
        self._engine = engine
        # Copy so the stored mapping is independent of the caller's kwargs dict.
        self._transform_kwargs = dict(transform_kwargs)

    def __getstate__(self) -> Dict[str, Any]:
        """Return an empty pickling state, so none of the instance attributes are serialized."""
        # NOTE(review): presumably this keeps the engine/conf out of the payload
        # when the bound worker methods are shipped to executors -- confirm.
        return dict()

    def forecast(
            self,
            df, # DataFrame with columns `unique_id`, `ds`, `y`, and exogenous variables
            models, # List of instantiated models (`statsforecast.models`)
            freq, # Frequency of the data
            **kwargs: Any,
        ) -> Any:
        """Run `StatsForecast.forecast` on every `unique_id` partition of `df`.

        Extra `kwargs` are passed through to `StatsForecast.forecast`.
        Returns whatever fugue's `transform` returns.
        """
        output_schema = self._get_output_schema(models)
        return self._run_transform(
            df, self._forecast_series, output_schema, models, freq, kwargs
        )

    def cross_validation(
            self,
            df, # DataFrame with columns `unique_id`, `ds`, `y`, and exogenous variables
            models, # List of instantiated models (`statsforecast.models`)
            freq, # Frequency of the data
            **kwargs: Any,
        ) -> Any:
        """Run `StatsForecast.cross_validation` on every `unique_id` partition of `df`.

        Extra `kwargs` are passed through to `StatsForecast.cross_validation`.
        Returns whatever fugue's `transform` returns.
        """
        output_schema = self._get_output_schema(models, mode="cv")
        return self._run_transform(
            df, self._cv, output_schema, models, freq, kwargs
        )

    def _run_transform(self, df, fn, output_schema, models, freq, kwargs) -> Any:
        """Call fugue's `transform` with `fn` applied per `unique_id` partition."""
        worker_params = {"models": models, "freq": freq, "kwargs": kwargs}
        # Fugue schema expression: the input columns without `y`, followed by
        # the columns described by `output_schema`.
        schema_expr = "*-y+" + str(output_schema)
        return transform(
            df,
            fn,
            params=worker_params,
            schema=schema_expr,
            partition={"by": "unique_id"},
            engine=self._engine,
            engine_conf=self._conf,
            **self._transform_kwargs,
        )

    def _forecast_series(self, df: pd.DataFrame, models, freq, kwargs) -> pd.DataFrame:
        """Forecast one partition: index by `unique_id`, forecast with one job, reset the index."""
        indexed = df.set_index("unique_id")
        sf = StatsForecast(df=indexed, models=models, freq=freq, n_jobs=1)
        predictions = sf.forecast(**kwargs)
        return predictions.reset_index()

    def _cv(self, df: pd.DataFrame, models, freq, kwargs) -> pd.DataFrame:
        """Cross-validate one partition: index by `unique_id`, run with one job, reset the index."""
        indexed = df.set_index("unique_id")
        sf = StatsForecast(df=indexed, models=models, freq=freq, n_jobs=1)
        cv_results = sf.cross_validation(**kwargs)
        return cv_results.reset_index()

    def _get_output_schema(self, models, mode="forecast") -> Schema:
        """Build the schema of the columns appended to the output.

        One `float32` column per model, named by `repr(model)`. When `mode` is
        `"cv"`, a `cutoff` datetime column and a `float32` `y` column come first.
        """
        leading = []
        if mode == "cv":
            leading = [("cutoff", "datetime"), ("y", np.float32)]
        model_cols = [(repr(model), np.float32) for model in models]
        return Schema([*leading, *model_cols])
