# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/timegpt.ipynb.

# %% auto 0
__all__ = ['main_logger', 'httpx_logger']

# %% ../nbs/timegpt.ipynb 3
import inspect
import json
import logging
import os
import requests
import warnings
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
from tenacity import (
    retry,
    stop_after_attempt,
    wait_fixed,
    stop_after_delay,
    RetryCallState,
    retry_if_exception,
    retry_if_not_exception_type,
)
from utilsforecast.processing import (
    backtest_splits,
    drop_index_if_pandas,
    join,
    maybe_compute_sort_indices,
    take_rows,
    vertical_concat,
)

from .client import ApiError, Nixtla, SingleSeriesForecast

# Import-time logging setup: configure the root logger at INFO level so the
# progress messages emitted by this module are shown.
logging.basicConfig(level=logging.INFO)
# Logger used by every class and function in this module.
main_logger = logging.getLogger(__name__)
# Only let ERROR (and above) records through on the "httpx" logger.
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.ERROR)

# %% ../nbs/timegpt.ipynb 5
# Default date features per pandas offset alias. Looked up with the frequency
# of the data when the caller asks for date features without listing them.
# Both the legacy aliases and the aliases introduced in pandas 2.2
# (e.g. "h", "s", "ns", "ME", "QE", "YE") are listed, because `pd.infer_freq`
# reports one or the other depending on the installed pandas version.
date_features_by_freq = {
    # Daily frequencies
    "B": ["year", "month", "day", "weekday"],
    "C": ["year", "month", "day", "weekday"],
    "D": ["year", "month", "day", "weekday"],
    # Weekly
    "W": ["year", "week", "weekday"],
    # Monthly
    "M": ["year", "month"],
    "ME": ["year", "month"],
    "SM": ["year", "month", "day"],
    "SME": ["year", "month", "day"],
    "BM": ["year", "month"],
    "BME": ["year", "month"],
    "CBM": ["year", "month"],
    "CBME": ["year", "month"],
    "MS": ["year", "month"],
    "SMS": ["year", "month", "day"],
    "BMS": ["year", "month"],
    "CBMS": ["year", "month"],
    # Quarterly
    "Q": ["year", "quarter"],
    "QE": ["year", "quarter"],
    "BQ": ["year", "quarter"],
    "BQE": ["year", "quarter"],
    "QS": ["year", "quarter"],
    "BQS": ["year", "quarter"],
    # Yearly
    "A": ["year"],
    "Y": ["year"],
    "YE": ["year"],
    "BA": ["year"],
    "BY": ["year"],
    "BYE": ["year"],
    "AS": ["year"],
    "YS": ["year"],
    "BAS": ["year"],
    "BYS": ["year"],
    # Hourly
    "BH": ["year", "month", "day", "hour", "weekday"],
    "bh": ["year", "month", "day", "hour", "weekday"],
    "H": ["year", "month", "day", "hour"],
    "h": ["year", "month", "day", "hour"],
    # Minutely
    "T": ["year", "month", "day", "hour", "minute"],
    "min": ["year", "month", "day", "hour", "minute"],
    # Secondly
    "S": ["year", "month", "day", "hour", "minute", "second"],
    "s": ["year", "month", "day", "hour", "minute", "second"],
    # Milliseconds
    "L": ["year", "month", "day", "hour", "minute", "second", "millisecond"],
    "ms": ["year", "month", "day", "hour", "minute", "second", "millisecond"],
    # Microseconds
    "U": ["year", "month", "day", "hour", "minute", "second", "microsecond"],
    "us": ["year", "month", "day", "hour", "minute", "second", "microsecond"],
    # Nanoseconds (no default features)
    "N": [],
    "ns": [],
}

# %% ../nbs/timegpt.ipynb 6
class _TimeGPTModel:
    def __init__(
        self,
        client: Nixtla,
        h: int,
        id_col: str = "unique_id",
        time_col: str = "ds",
        target_col: str = "y",
        freq: Optional[str] = None,
        level: Optional[List[Union[int, float]]] = None,
        finetune_steps: int = 0,
        finetune_loss: str = "default",
        clean_ex_first: bool = True,
        date_features: Union[bool, List[str]] = False,
        date_features_to_one_hot: Union[bool, List[str]] = True,
        model: str = "timegpt-1",
        max_retries: int = 6,
        retry_interval: int = 10,
        max_wait_time: int = 6 * 60,
    ):
        """Store the configuration shared by the forecast, anomaly and CV flows.

        Parameters
        ----------
        client : Nixtla
            API client whose `timegpt_*` endpoints are called.
        h : int
            Forecast horizon; sent to the forecast endpoint as `fh` and used
            to build future dates. Some code paths check for `None`.
        id_col, time_col, target_col : str
            Caller-side column names; they are renamed to
            "unique_id" / "ds" / "y" on input and restored on output.
        freq : str, optional
            Frequency of the data. When `None` it is taken from the index or
            inferred from the `ds` column.
        level : list of int or float, optional
            Forwarded to the API endpoints as `level`.
        finetune_steps, finetune_loss, clean_ex_first, model
            Forwarded unchanged to the API endpoints.
        date_features : bool or list
            `True` selects defaults from `date_features_by_freq`; a list may
            hold attribute names or callables.
        date_features_to_one_hot : bool or list
            `True` one-hot encodes every non-callable date feature; a list
            names the columns to encode.
        max_retries, retry_interval, max_wait_time : int
            Retry policy used by `_retry_strategy` (attempts, seconds between
            attempts, total seconds).
        """
        self.client = client
        self.h = h
        self.id_col = id_col
        self.time_col = time_col
        self.target_col = target_col
        # frequency as given by the caller; `self.freq` below is the working
        # value that the flows may overwrite with an inferred one
        self.base_freq = freq
        self.level = level
        self.finetune_steps = finetune_steps
        self.finetune_loss = finetune_loss
        self.clean_ex_first = clean_ex_first
        self.date_features = date_features
        self.date_features_to_one_hot = date_features_to_one_hot
        self.model = model
        self.max_retries = max_retries
        self.retry_interval = retry_interval
        self.max_wait_time = max_wait_time
        # variables defined by each flow
        # feature weights returned by the API when exogenous variables are used
        self.weights_x: Optional[pd.DataFrame] = None
        self.freq: Optional[str] = self.base_freq
        # set to True when a synthetic "unique_id" column had to be added
        self.drop_uid: bool = False
        # declared here, assigned later by preprocess_dataframes / set_model_params
        self.x_cols: List[str]
        self.input_size: int
        self.model_horizon: int

    def _retry_strategy(self):
        """Build the tenacity `retry` decorator used for every API call.

        A call is retried when it raises anything that is not an `ApiError`,
        or an `ApiError` whose `body` is a plain string. Retrying stops after
        `self.max_retries` attempts or `self.max_wait_time` seconds, whichever
        comes first, waiting `self.retry_interval` seconds between attempts.
        The last exception is re-raised to the caller.
        """

        def _log_failed_attempt(retry_state: RetryCallState):
            """Called after each retry attempt."""
            main_logger.info(f"Attempt {retry_state.attempt_number} failed...")

        def _is_text_body_api_error(exception):
            # ApiError carrying a string body (as opposed to a structured one)
            return isinstance(exception, ApiError) and isinstance(
                exception.body, str
            )

        stop_condition = stop_after_attempt(self.max_retries) | stop_after_delay(
            self.max_wait_time
        )
        retry_condition = retry_if_exception(
            _is_text_body_api_error
        ) | retry_if_not_exception_type(ApiError)
        return retry(
            stop=stop_condition,
            wait=wait_fixed(self.retry_interval),
            reraise=True,
            after=_log_failed_attempt,
            retry=retry_condition,
        )

    def _call_api(self, method, kwargs):
        """Call `method(**kwargs)` under the retry policy and unwrap the result.

        When the response contains a "data" key, the value under that key is
        returned; otherwise the response is returned as is.
        """
        call_with_retries = self._retry_strategy()(method)
        response = call_with_retries(**kwargs)
        return response["data"] if "data" in response else response

    def transform_inputs(self, df: pd.DataFrame, X_df: pd.DataFrame):
        """Normalise caller dataframes to the internal unique_id/ds/y layout.

        Works on copies. When no frequency was given at construction and the
        index of `df` exposes a `freq` attribute, the working frequency is
        taken from the index and the index is moved into a column. Columns are
        renamed from the configured names, `ds` is converted to strings and a
        constant "unique_id" column is added when missing (remembered in
        `self.drop_uid`). `X_df`, when given, gets the same renaming, id
        column and string conversion.

        Returns
        -------
        tuple
            `(df, X_df)`; `X_df` is returned unchanged when it is `None`.
        """

        def _ds_to_str(frame):
            # timestamps are kept as strings in the internal layout
            if frame.dtypes.ds != "object":
                frame["ds"] = frame["ds"].astype(str)
            return frame

        df = df.copy()
        main_logger.info("Validating inputs...")
        take_freq_from_index = self.base_freq is None and hasattr(df.index, "freq")
        if not take_freq_from_index:
            self.freq = self.base_freq
        else:
            index_freq = df.index.freq
            if index_freq is None:
                self.freq = None
            else:
                rule_code = index_freq.rule_code
                main_logger.info(f"Inferred freq: {rule_code}")
                self.freq = rule_code
            # an unnamed index becomes the "ds" column
            if df.index.name is None:
                df.index.name = "ds"
            df = df.reset_index()

        renamer = {
            self.id_col: "unique_id",
            self.time_col: "ds",
            self.target_col: "y",
        }
        df = _ds_to_str(df.rename(columns=renamer))
        if "unique_id" not in df.columns:
            # single series without an id: add one and drop it again on output
            df = df.assign(unique_id="ts_0")
            self.drop_uid = True

        if X_df is None:
            return df, X_df

        X_df = X_df.copy()
        X_df = X_df.rename(columns=renamer)
        if "unique_id" not in X_df.columns:
            X_df = X_df.assign(unique_id="ts_0")
        X_df = _ds_to_str(X_df)
        return df, X_df

    def transform_outputs(self, fcst_df: pd.DataFrame):
        renamer = {
            "unique_id": self.id_col,
            "ds": self.time_col,
            "y": self.target_col,
        }
        if self.drop_uid:
            fcst_df = fcst_df.drop(columns="unique_id")
        fcst_df = fcst_df.rename(columns=renamer)
        return fcst_df

    def infer_freq(self, df: pd.DataFrame):
        # special freqs that need to be checked
        # for example to ensure 'W'-> 'W-MON'
        special_freqs = ["W", "M", "Q", "Y", "A"]
        if self.freq is None or self.freq in special_freqs:
            unique_id = df.iloc[0]["unique_id"]
            df_id = df.query("unique_id == @unique_id")
            inferred_freq = pd.infer_freq(df_id["ds"].sort_values())
            if inferred_freq is None:
                raise Exception(
                    "Could not infer frequency of ds column. This could be due to "
                    "inconsistent intervals. Please check your data for missing, "
                    "duplicated or irregular timestamps"
                )
            if self.freq is not None:
                # check we have the same base frequency
                # except when we have yearly frequency (A, and Y means the same)
                if (self.freq != inferred_freq[0] and self.freq != "Y") or (
                    self.freq == "Y" and inferred_freq[0] != "A"
                ):
                    raise Exception(
                        f"Failed to infer special date, inferred freq {inferred_freq}"
                    )
            main_logger.info(f"Inferred freq: {inferred_freq}")
            self.freq = inferred_freq

    def resample_dataframe(self, df: pd.DataFrame):
        df = df.copy()
        df["ds"] = pd.to_datetime(df["ds"])
        resampled_df = (
            df.set_index("ds").groupby("unique_id").resample(self.freq).bfill()
        )
        resampled_df = resampled_df.drop(columns="unique_id").reset_index()
        resampled_df["ds"] = resampled_df["ds"].astype(str)
        return resampled_df

    def make_future_dataframe(self, df: pd.DataFrame, reconvert: bool = True):
        last_dates = df.groupby("unique_id")["ds"].max()

        def _future_date_range(last_date):
            future_dates = pd.date_range(last_date, freq=self.freq, periods=self.h + 1)
            future_dates = future_dates[-self.h :]
            return future_dates

        future_df = last_dates.apply(_future_date_range).reset_index()
        future_df = future_df.explode("ds").reset_index(drop=True)
        if reconvert and df.dtypes["ds"] == "object":
            # avoid date 000
            future_df["ds"] = future_df["ds"].astype(str)
        return future_df

    def compute_date_feature(self, dates, feature):
        if callable(feature):
            feat_name = feature.__name__
            feat_vals = feature(dates)
        else:
            feat_name = feature
            if feature in ("week", "weekofyear"):
                dates = dates.isocalendar()
            feat_vals = getattr(dates, feature)
        if not isinstance(feat_vals, pd.DataFrame):
            vals = np.asarray(feat_vals)
            feat_vals = pd.DataFrame({feat_name: vals})
        feat_vals["ds"] = dates
        return feat_vals

    def add_date_features(
        self,
        df: pd.DataFrame,
        X_df: Optional[pd.DataFrame],
    ):
        # df contains exogenous variables
        # X_df are the future values of the exogenous variables
        # construct dates
        train_dates = df["ds"].unique().tolist()
        # if we dont have future exogenos variables
        # we need to compute the future dates
        if (self.h is not None) and X_df is None:
            X_df = self.make_future_dataframe(df=df)
            future_dates = X_df["ds"].unique().tolist()
        elif X_df is not None:
            future_dates = X_df["ds"].unique().tolist()
        else:
            future_dates = []
        dates = pd.DatetimeIndex(np.unique(train_dates + future_dates).tolist())
        date_features_df = pd.DataFrame({"ds": dates})
        for feature in self.date_features:
            feat_df = self.compute_date_feature(dates, feature)
            date_features_df = date_features_df.merge(feat_df, on=["ds"], how="left")
        if df.dtypes["ds"] == "object":
            date_features_df["ds"] = date_features_df["ds"].astype(str)
        if self.date_features_to_one_hot is not None:
            date_features_df = pd.get_dummies(
                date_features_df,
                columns=self.date_features_to_one_hot,
                dtype=int,
            )
        # remove duplicated columns if any
        date_features_df = date_features_df.drop(
            columns=[
                col
                for col in date_features_df.columns
                if col in df.columns and col not in ["unique_id", "ds"]
            ]
        )
        # add date features to df
        df = df.merge(date_features_df, on="ds", how="left")
        # add date features to X_df
        if X_df is not None:
            X_df = X_df.merge(date_features_df, on="ds", how="left")
        return df, X_df

    def preprocess_X_df(self, X_df: pd.DataFrame):
        if X_df.isna().any().any():
            raise Exception("Some of your exogenous variables contain NA, please check")
        X_df = X_df.sort_values(["unique_id", "ds"]).reset_index(drop=True)
        X_df = self.resample_dataframe(X_df)
        return X_df

    def preprocess_dataframes(
        self,
        df: pd.DataFrame,
        X_df: Optional[pd.DataFrame],
    ):
        self.infer_freq(df=df)
        """Returns Y_df and X_df dataframes in the structure expected by the endpoints."""
        # add date features logic
        if isinstance(self.date_features, bool):
            if self.date_features:
                self.date_features = date_features_by_freq.get(self.freq)
                if self.date_features is None:
                    warnings.warn(
                        f"Non default date features for {self.freq} "
                        "please pass a list of date features"
                    )
            else:
                self.date_features = None

        if self.date_features is not None:
            if isinstance(self.date_features_to_one_hot, bool):
                if self.date_features_to_one_hot:
                    self.date_features_to_one_hot = [
                        feat for feat in self.date_features if not callable(feat)
                    ]
                    self.date_features_to_one_hot = (
                        None
                        if not self.date_features_to_one_hot
                        else self.date_features_to_one_hot
                    )
                else:
                    self.date_features_to_one_hot = None
            df, X_df = self.add_date_features(df=df, X_df=X_df)
        y_cols = ["unique_id", "ds", "y"]
        Y_df = df[y_cols]
        if Y_df["y"].isna().any():
            raise Exception("Your target variable contains NA, please check")
        # Azul: efficient this code
        # and think about returning dates that are not in the training set
        Y_df = self.resample_dataframe(Y_df)
        x_cols = []
        if X_df is not None:
            x_cols = X_df.drop(columns=["unique_id", "ds"]).columns.to_list()
            if not all(col in df.columns for col in x_cols):
                raise Exception(
                    "You must include the exogenous variables in the `df` object, "
                    f'exogenous variables {",".join(x_cols)}'
                )
            if (self.h is not None) and (
                len(X_df) != df["unique_id"].nunique() * self.h
            ):
                raise Exception(
                    f"You have to pass the {self.h} future values of your "
                    "exogenous variables for each time series"
                )
            X_df_history = df[["unique_id", "ds"] + x_cols]
            X_df = pd.concat([X_df_history, X_df])
            X_df = self.preprocess_X_df(X_df)
        elif (X_df is None) and (self.h is None) and (len(y_cols) < df.shape[1]):
            # case for just insample,
            # we dont need h
            X_df = df.drop(columns="y")
            x_cols = X_df.drop(columns=["unique_id", "ds"]).columns.to_list()
            X_df = self.preprocess_X_df(X_df)
        self.x_cols = x_cols
        return Y_df, X_df

    def dataframes_to_dict(self, Y_df: pd.DataFrame, X_df: pd.DataFrame):
        to_dict_args = {"orient": "split"}
        if "index" in inspect.signature(pd.DataFrame.to_dict).parameters:
            to_dict_args["index"] = False
        y = Y_df.to_dict(**to_dict_args)
        x = X_df.to_dict(**to_dict_args) if X_df is not None else None
        # A: I'm aware that sel.x_cols exists, but
        # I want to be sure that we are logging the correct
        # x cols :kiss:
        if x:
            x_cols = [col for col in x["columns"] if col not in ["unique_id", "ds"]]
            main_logger.info(
                f'Using the following exogenous variables: {", ".join(x_cols)}'
            )
        return y, x

    def set_model_params(self):
        """Fetch the model's input size and horizon for the current frequency.

        Calls the `timegpt_model_params` endpoint and stores the values found
        under "detail" in `self.input_size` and `self.model_horizon`.
        """
        endpoint = self.client.timegpt_model_params
        request = SingleSeriesForecast(freq=self.freq, model=self.model)
        detail = self._call_api(endpoint, {"request": request})["detail"]
        # read both values before assigning so a missing key leaves the
        # instance untouched
        input_size, horizon = detail["input_size"], detail["horizon"]
        self.input_size = input_size
        self.model_horizon = horizon

    def validate_input_size(self, Y_df: pd.DataFrame):
        min_history = Y_df.groupby("unique_id").size().min()
        if min_history < self.input_size + self.model_horizon:
            raise Exception(
                "Your time series data is too short "
                "Please be sure that your unique time series contain "
                f"at least {self.input_size + self.model_horizon} observations"
            )

    def forecast(
        self,
        df: pd.DataFrame,
        X_df: Optional[pd.DataFrame] = None,
        add_history: bool = False,
    ):
        """Forecast `self.h` steps ahead for every series in `df`.

        Parameters
        ----------
        df : pd.DataFrame
            History of the series, using the configured column names.
        X_df : pd.DataFrame, optional
            Future values of the exogenous variables.
        add_history : bool
            Also request the in-sample (historical) forecasts and prepend
            them to the result.

        Returns
        -------
        pd.DataFrame
            The forecasts, with columns renamed back to the caller's names.
            When the response carries feature weights they are stored in
            `self.weights_x`.

        Raises
        ------
        Exception
            From input validation, e.g. when the series are too short for
            fine-tuning, prediction intervals or historical forecasts.
        """
        df, X_df = self.transform_inputs(df=df, X_df=X_df)
        main_logger.info("Preprocessing dataframes...")
        Y_df, X_df = self.preprocess_dataframes(df=df, X_df=X_df)
        self.set_model_params()
        if self.h > self.model_horizon:
            main_logger.warning(
                'The specified horizon "h" exceeds the model horizon. '
                "This may lead to less accurate forecasts. "
                "Please consider using a smaller horizon."
            )
        # Only the tail of each series is sent when all of these hold:
        # - no fine-tuning
        # - no exogenous regressors
        # - `level` is set
        # - no historical forecasts requested
        # NOTE(review): the previous comment said the input is restricted
        # when prediction intervals are NOT wanted, but the condition requires
        # `level is not None`. The condition is kept as is; confirm which of
        # the two is intended.
        restrict_input = (
            self.finetune_steps == 0
            and X_df is None
            and self.level is not None
            and not add_history
        )
        if restrict_input:
            # add sufficient info to compute
            # conformal interval
            main_logger.info("Restricting input...")
            new_input_size = 3 * self.input_size + max(self.model_horizon, self.h)
            Y_df = Y_df.groupby("unique_id").tail(new_input_size)
        # Validate before any request is sent: historical forecasts need the
        # same minimum length, so there is no point in paying for the forecast
        # call first and failing afterwards.
        if self.finetune_steps > 0 or self.level is not None or add_history:
            self.validate_input_size(Y_df=Y_df)
        y, x = self.dataframes_to_dict(Y_df, X_df)
        main_logger.info("Calling Forecast Endpoint...")
        payload = dict(
            y=y,
            x=x,
            fh=self.h,
            freq=self.freq,
            level=self.level,
            finetune_steps=self.finetune_steps,
            finetune_loss=self.finetune_loss,
            clean_ex_first=self.clean_ex_first,
            model=self.model,
        )
        response_timegpt = self._call_api(
            self.client.timegpt_multi_series,
            payload,
        )
        if "weights_x" in response_timegpt:
            self.weights_x = pd.DataFrame(
                {
                    "features": self.x_cols,
                    "weights": response_timegpt["weights_x"],
                }
            )
        fcst_df = pd.DataFrame(**response_timegpt["forecast"])
        if add_history:
            main_logger.info("Calling Historical Forecast Endpoint...")
            response_timegpt = self._call_api(
                self.client.timegpt_multi_series_historic,
                dict(
                    y=y,
                    x=x,
                    freq=self.freq,
                    level=self.level,
                    clean_ex_first=self.clean_ex_first,
                    model=self.model,
                ),
            )
            fitted_df = pd.DataFrame(**response_timegpt["forecast"])
            # the historical response repeats the target; only fitted values
            # are kept
            fitted_df = fitted_df.drop(columns="y")
            fcst_df = pd.concat([fitted_df, fcst_df]).sort_values(["unique_id", "ds"])
        fcst_df = self.transform_outputs(fcst_df)
        return fcst_df

    def detect_anomalies(self, df: pd.DataFrame):
        """Call the anomaly detection endpoint for the series in `df`.

        Exogenous variables, if any, must be columns of `df`; no future
        values are involved. A single level is sent: `self.level` itself when
        it is a number, otherwise its first element.

        Returns
        -------
        pd.DataFrame
            The endpoint result without the target column, with columns
            renamed back to the caller's names.
        """
        # Naming caveat: the X_df passed INTO the preprocessing methods means
        # "future exogenous values", while the X_df they RETURN holds history
        # and future stacked together. Here no future values exist, so None is
        # passed in and the returned X_df (if any) is built from `df` alone.
        df, _ = self.transform_inputs(df=df, X_df=None)
        main_logger.info("Preprocessing dataframes...")
        Y_df, X_df = self.preprocess_dataframes(df=df, X_df=None)
        main_logger.info("Calling Anomaly Detector Endpoint...")
        y, x = self.dataframes_to_dict(Y_df, X_df)
        if isinstance(self.level, (int, float)):
            level = [self.level]
        else:
            level = [self.level[0]]
        payload = dict(
            y=y,
            x=x,
            freq=self.freq,
            level=level,
            clean_ex_first=self.clean_ex_first,
            model=self.model,
        )
        response_timegpt = self._call_api(
            self.client.timegpt_multi_series_anomalies,
            payload,
        )
        if "weights_x" in response_timegpt:
            self.weights_x = pd.DataFrame(
                {
                    "features": self.x_cols,
                    "weights": response_timegpt["weights_x"],
                }
            )
        anomalies_df = pd.DataFrame(**response_timegpt["forecast"]).drop(columns="y")
        return self.transform_outputs(anomalies_df)

    def cross_validation(
        self,
        df: pd.DataFrame,
        n_windows: int = 1,
        step_size: Optional[int] = None,
    ):
        """Evaluate forecasts over `n_windows` backtesting windows.

        Each window forecasts `self.h` steps from its cutoff and joins the
        predictions with the held-out values. Columns of `df` other than
        id/time/target are treated as exogenous variables, and their values
        in the validation part are passed as the future `X_df`.

        Parameters
        ----------
        df : pd.DataFrame
            History of the series, using the configured column names.
        n_windows : int
            Number of windows.
        step_size : int, optional
            Step between consecutive windows; defaults to `self.h`.

        Returns
        -------
        pd.DataFrame
            One row per series, timestamp and window, with the id, time,
            "cutoff" and target columns first, followed by the forecast
            columns.

        Raises
        ------
        ValueError
            If a window yields fewer joined rows than validation rows.
        """
        # transform_inputs returns X_df=None whenever X_df=None is passed
        df, _ = self.transform_inputs(df=df, X_df=None)
        self.infer_freq(df)
        df["ds"] = pd.to_datetime(df["ds"])
        # mlforecast cv code
        results = []
        sort_idxs = maybe_compute_sort_indices(df, "unique_id", "ds")
        if sort_idxs is not None:
            df = take_rows(df, sort_idxs)
        splits = backtest_splits(
            df,
            n_windows=n_windows,
            h=self.h,
            id_col="unique_id",
            time_col="ds",
            freq=pd.tseries.frequencies.to_offset(self.freq),
            step_size=self.h if step_size is None else step_size,
        )
        for cutoffs, train, valid in splits:
            if len(valid.columns) > 3:
                # if we have uid, ds, y + exogenous vars
                train_future = valid.drop(columns="y")
            else:
                train_future = None
            y_pred = self.forecast(
                df=train,
                X_df=train_future,
            )
            # forecast() returns the caller's column names; bring them back
            # to the internal ones before joining
            y_pred, _ = self.transform_inputs(df=y_pred, X_df=None)
            y_pred = join(y_pred, cutoffs, on="unique_id", how="left")
            y_pred["ds"] = pd.to_datetime(y_pred["ds"])
            result = join(
                valid[["unique_id", "ds", "y"]],
                y_pred,
                on=["unique_id", "ds"],
            )
            if result.shape[0] < valid.shape[0]:
                raise ValueError(
                    "Cross validation result produced less results than expected. "
                    "Please verify that the frequency parameter (freq) matches your series' "
                    "and that there aren't any missing periods."
                )
            results.append(result)
        out = vertical_concat(results)
        out = drop_index_if_pandas(out)
        first_out_cols = ["unique_id", "ds", "cutoff", "y"]
        remaining_cols = [c for c in out.columns if c not in first_out_cols]
        # explicit copy: assigning a column on a column-selected frame would
        # trigger pandas' SettingWithCopyWarning
        fcst_cv_df = out[first_out_cols + remaining_cols].copy()
        fcst_cv_df["ds"] = fcst_cv_df["ds"].astype(str)
        fcst_cv_df = self.transform_outputs(fcst_cv_df)
        return fcst_cv_df

# %% ../nbs/timegpt.ipynb 7
def validate_model_parameter(func):
    """Decorator that rejects unsupported values of the `model` keyword.

    The decorated method must live on an object exposing `supported_models`.
    Only a `model` passed by keyword is checked; a call that omits it or
    passes it positionally is forwarded untouched.

    Raises
    ------
    ValueError
        If `model` is given by keyword and is not in `self.supported_models`.
    """
    import functools

    # wraps() keeps the name, docstring and signature of the wrapped method
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if "model" in kwargs and kwargs["model"] not in self.supported_models:
            raise ValueError(
                f'unsupported model: {kwargs["model"]} '
                f'supported models: {", ".join(self.supported_models)}'
            )
        return func(self, *args, **kwargs)

    return wrapper

# %% ../nbs/timegpt.ipynb 8
def remove_unused_categories(df: pd.DataFrame, col: str):
    """Drop the unused category levels of `df[col]`.

    `df` is returned untouched when it is `None`, when `col` is not one of
    its columns, or when the column is not categorical. Otherwise a copy with
    the trimmed column is returned.
    """
    is_categorical_col = (
        df is not None and col in df and df[col].dtype == "category"
    )
    if not is_categorical_col:
        return df
    trimmed = df.copy()
    trimmed[col] = trimmed[col].cat.remove_unused_categories()
    return trimmed

# %% ../nbs/timegpt.ipynb 9
def partition_by_uid(func):
    """Decorator that runs `func` on groups of series and stacks the results.

    The wrapped callable is invoked as
    `wrapper(self, num_partitions, **kwargs)`, where `kwargs` must contain
    `df` and `id_col` and may contain `X_df`. The series ids are split into
    at most `num_partitions` groups, `func` is called once per group with
    `num_partitions=1`, and the results are concatenated. When partitioning
    is not requested, or there are not enough series to form more than one
    group, `func` is called once with the inputs as given.
    """
    import functools

    # wraps() keeps the name and docstring of the wrapped method
    @functools.wraps(func)
    def wrapper(self, num_partitions, **kwargs):
        if num_partitions is None or num_partitions == 1:
            return func(self, **kwargs, num_partitions=1)
        df = kwargs["df"]
        id_col = kwargs["id_col"]
        # Partition on the caller's id column; fall back to "unique_id" for
        # inputs that already use the internal name.
        uid_col = id_col if id_col in df.columns else "unique_id"
        uids = df[uid_col].unique()
        # Never create more groups than there are series: an empty group
        # would call `func` with an empty dataframe.
        n_groups = min(num_partitions, len(uids))
        if n_groups <= 1:
            return func(self, **kwargs, num_partitions=1)
        X_df = kwargs.get("X_df")
        shared_kwargs = {
            key: value for key, value in kwargs.items() if key not in ("df", "X_df")
        }
        results_df = []
        for uids_split in np.array_split(uids, n_groups):
            df_uids = df[df[uid_col].isin(uids_split)]
            if X_df is not None:
                X_df_uids = X_df[X_df[uid_col].isin(uids_split)]
            else:
                X_df_uids = None
            # categorical ids keep every level after filtering; trim them so
            # each group only knows about its own series
            df_uids = remove_unused_categories(df_uids, col=id_col)
            X_df_uids = remove_unused_categories(X_df_uids, col=id_col)
            kwargs_uids = {"df": df_uids, **shared_kwargs}
            if X_df_uids is not None:
                kwargs_uids["X_df"] = X_df_uids
            results_df.append(func(self, **kwargs_uids, num_partitions=1))
        return pd.concat(results_df).reset_index(drop=True)

    return wrapper

# %% ../nbs/timegpt.ipynb 10
class _TimeGPT:
    """
    A class used to interact with the TimeGPT API.

    It holds a `Nixtla` client plus the retry settings, and provides token
    validation, the pandas implementations of forecasting, anomaly detection
    and cross validation (`_forecast`, `_detect_anomalies`,
    `_cross_validation`) and a plotting helper.
    """

    def __init__(
        self,
        token: Optional[str] = None,
        environment: Optional[str] = None,
        max_retries: int = 6,
        retry_interval: int = 10,
        max_wait_time: int = 6 * 60,
    ):
        """
        Constructs all the necessary attributes for the TimeGPT object.

        Parameters
        ----------
        token : str, (default=None)
            The authorization token interacts with the TimeGPT API.
            If not provided, it will be inferred by the TIMEGPT_TOKEN environment variable.
        environment : str, (default=None)
            Custom environment. Pass only if provided.
            If None, "https://dashboard.nixtla.io/api" is used.
        max_retries : int, (default=6)
            The maximum number of attempts to make when calling the API before giving up.
            It defines how many times the client will retry the API call if it fails.
            Default value is 6, indicating the client will attempt the API call up to 6 times in total
        retry_interval : int, (default=10)
            The interval in seconds between consecutive retry attempts.
            This is the waiting period before the client tries to call the API again after a failed attempt.
            Default value is 10 seconds, meaning the client waits for 10 seconds between retries.
        max_wait_time : int, (default=360)
            The maximum total time in seconds that the client will spend on all retry attempts before giving up.
            This sets an upper limit on the cumulative waiting time for all retry attempts.
            If this time is exceeded, the client will stop retrying and raise an exception.
            Default value is 360 seconds, meaning the client will cease retrying if the total time
            spent on retries exceeds 360 seconds.
            The client throws a ReadTimeout error after 60 seconds of inactivity. If you want to
            catch these errors, use max_wait_time >> 60.

        Raises
        ------
        Exception
            If no token is passed and TIMEGPT_TOKEN is not set.
        """
        if token is None:
            token = os.environ.get("TIMEGPT_TOKEN")
        if token is None:
            raise Exception(
                "The token must be set either by passing `token` "
                "or by setting the TIMEGPT_TOKEN environment variable."
            )
        if environment is None:
            environment = "https://dashboard.nixtla.io/api"
        self.client = Nixtla(base_url=environment, token=token)
        self.max_retries = max_retries
        self.retry_interval = retry_interval
        self.max_wait_time = max_wait_time
        self.supported_models = ["timegpt-1", "timegpt-1-long-horizon"]
        # custom attr: overwritten with the model's `weights_x` after each
        # forecast / anomaly detection / cross validation call.
        self.weights_x: Optional[pd.DataFrame] = None

    def validate_token(self, log: bool = True) -> bool:
        """Returns True if your token is valid.

        The response of `client.validate_token()` is considered valid when its
        "message" equals "success" or, when there is no "message" key, when its
        "detail" contains "Forecasting! :)". If the response carries a
        "support" entry and `log` is True, it is logged at INFO level.
        """
        validation = self.client.validate_token()
        valid = False
        if "message" in validation:
            if validation["message"] == "success":
                valid = True
        elif "detail" in validation:
            if "Forecasting! :)" in validation["detail"]:
                valid = True
        if "support" in validation and log:
            main_logger.info(f'Happy Forecasting! :), {validation["support"]}')
        return valid

    @validate_model_parameter
    @partition_by_uid
    def _forecast(
        self,
        df: pd.DataFrame,
        h: int,
        freq: Optional[str] = None,
        id_col: str = "unique_id",
        time_col: str = "ds",
        target_col: str = "y",
        X_df: Optional[pd.DataFrame] = None,
        level: Optional[List[Union[int, float]]] = None,
        finetune_steps: int = 0,
        finetune_loss: str = "default",
        clean_ex_first: bool = True,
        validate_token: bool = False,
        add_history: bool = False,
        date_features: Union[bool, List[str]] = False,
        date_features_to_one_hot: Union[bool, List[str]] = True,
        model: str = "timegpt-1",
        num_partitions: int = 1,
    ):
        """Forecast a pandas frame with a freshly built `_TimeGPTModel`.

        Call with keyword arguments: the `partition_by_uid` wrapper only
        accepts `num_partitions` positionally. `num_partitions` is consumed by
        that wrapper and is not used in this body.

        Stores the model's `weights_x` on `self.weights_x` and returns the
        result of `_TimeGPTModel.forecast`.

        Raises
        ------
        Exception
            If `validate_token` is True and the token is not valid.
        """
        if validate_token and not self.validate_token(log=False):
            raise Exception("Token not valid, please email ops@nixtla.io")
        timegpt_model = _TimeGPTModel(
            client=self.client,
            h=h,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            freq=freq,
            level=level,
            finetune_steps=finetune_steps,
            finetune_loss=finetune_loss,
            clean_ex_first=clean_ex_first,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            model=model,
            max_retries=self.max_retries,
            retry_interval=self.retry_interval,
            max_wait_time=self.max_wait_time,
        )
        fcst_df = timegpt_model.forecast(df=df, X_df=X_df, add_history=add_history)
        self.weights_x = timegpt_model.weights_x
        return fcst_df

    @validate_model_parameter
    @partition_by_uid
    def _detect_anomalies(
        self,
        df: pd.DataFrame,
        freq: Optional[str] = None,
        id_col: str = "unique_id",
        time_col: str = "ds",
        target_col: str = "y",
        level: Union[int, float] = 99,
        clean_ex_first: bool = True,
        validate_token: bool = False,
        date_features: Union[bool, List[str]] = False,
        date_features_to_one_hot: Union[bool, List[str]] = True,
        model: str = "timegpt-1",
        num_partitions: int = 1,
    ):
        """Detect anomalies in a pandas frame with a freshly built `_TimeGPTModel`.

        Call with keyword arguments: the `partition_by_uid` wrapper only
        accepts `num_partitions` positionally. `num_partitions` is consumed by
        that wrapper and is not used in this body. The model is built with
        `h=None`.

        Stores the model's `weights_x` on `self.weights_x` and returns the
        result of `_TimeGPTModel.detect_anomalies`.

        Raises
        ------
        Exception
            If `validate_token` is True and the token is not valid.
        """
        if validate_token and not self.validate_token(log=False):
            raise Exception("Token not valid, please email ops@nixtla.io")
        timegpt_model = _TimeGPTModel(
            client=self.client,
            h=None,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            freq=freq,
            level=level,
            clean_ex_first=clean_ex_first,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            model=model,
            max_retries=self.max_retries,
            retry_interval=self.retry_interval,
            max_wait_time=self.max_wait_time,
        )
        anomalies_df = timegpt_model.detect_anomalies(df=df)
        self.weights_x = timegpt_model.weights_x
        return anomalies_df

    @validate_model_parameter
    @partition_by_uid
    def _cross_validation(
        self,
        df: pd.DataFrame,
        h: int,
        freq: Optional[str] = None,
        id_col: str = "unique_id",
        time_col: str = "ds",
        target_col: str = "y",
        level: Optional[List[Union[int, float]]] = None,
        validate_token: bool = False,
        n_windows: int = 1,
        step_size: Optional[int] = None,
        finetune_steps: int = 0,
        finetune_loss: str = "default",
        clean_ex_first: bool = True,
        date_features: Union[bool, List[str]] = False,
        date_features_to_one_hot: Union[bool, List[str]] = True,
        model: str = "timegpt-1",
        num_partitions: int = 1,
    ):
        """Cross validate a pandas frame with a freshly built `_TimeGPTModel`.

        Call with keyword arguments: the `partition_by_uid` wrapper only
        accepts `num_partitions` positionally. `num_partitions` is consumed by
        that wrapper and is not used in this body.

        Stores the model's `weights_x` on `self.weights_x` and returns the
        result of `_TimeGPTModel.cross_validation`.

        Raises
        ------
        Exception
            If `validate_token` is True and the token is not valid.
        """
        if validate_token and not self.validate_token(log=False):
            raise Exception("Token not valid, please email ops@nixtla.io")
        timegpt_model = _TimeGPTModel(
            client=self.client,
            h=h,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            freq=freq,
            level=level,
            finetune_steps=finetune_steps,
            finetune_loss=finetune_loss,
            clean_ex_first=clean_ex_first,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            model=model,
            max_retries=self.max_retries,
            retry_interval=self.retry_interval,
            max_wait_time=self.max_wait_time,
        )
        cv_df = timegpt_model.cross_validation(
            df=df, n_windows=n_windows, step_size=step_size
        )
        self.weights_x = timegpt_model.weights_x
        return cv_df

    def plot(
        self,
        df: pd.DataFrame,
        forecasts_df: Optional[pd.DataFrame] = None,
        id_col: str = "unique_id",
        time_col: str = "ds",
        target_col: str = "y",
        unique_ids: Union[Optional[List[str]], np.ndarray] = None,
        plot_random: bool = True,
        models: Optional[List[str]] = None,
        level: Optional[List[float]] = None,
        max_insample_length: Optional[int] = None,
        plot_anomalies: bool = False,
        engine: str = "matplotlib",
        resampler_kwargs: Optional[Dict] = None,
    ):
        """Plot forecasts and insample values.

        Parameters
        ----------
        df : pandas.DataFrame
            The DataFrame on which the function will operate. Expected to contain at least the following columns:
            - time_col:
                Column name in `df` that contains the time indices of the time series. This is typically a datetime
                column with regular intervals, e.g., hourly, daily, monthly data points.
            - target_col:
                Column name in `df` that contains the target variable of the time series, i.e., the variable we
                wish to predict or analyze.
            Additionally, you can pass multiple time series (stacked in the dataframe) considering an additional column:
            - id_col:
                Column name in `df` that identifies unique time series. Each unique value in this column
                corresponds to a unique time series.
            If `id_col` is missing, a constant id "ts_0" is added.
        forecasts_df : pandas.DataFrame, optional (default=None)
            DataFrame with columns [`unique_id`, `ds`] and models.
            If it contains an `anomaly` column it is treated as the output of
            anomaly detection: `level`, `models` and `plot_anomalies` are then
            derived from it and the passed values are ignored.
        id_col : str (default='unique_id')
            Column that identifies each serie.
        time_col : str (default='ds')
            Column that identifies each timestep; it is converted with `pandas.to_datetime`.
        target_col : str (default='y')
            Column that contains the target.
        unique_ids : List[str], optional (default=None)
            Time Series to plot.
            If None, time series are selected randomly.
        plot_random : bool (default=True)
            Select time series to plot randomly.
        models : List[str], optional (default=None)
            List of models to plot.
        level : List[float], optional (default=None)
            List of prediction intervals to plot if passed.
        max_insample_length : int, optional (default=None)
            Max number of train/insample observations to be plotted.
        plot_anomalies : bool (default=False)
            Plot anomalies for each prediction interval.
        engine : str (default='matplotlib')
            Library used to plot. 'plotly', 'plotly-resampler' or 'matplotlib'.
        resampler_kwargs : dict
            Kwargs to be passed to plotly-resampler constructor.
            For further customization ("show_dash") call the method,
            store the plotting object and add the extra arguments to
            its `show_dash` method.

        Returns
        -------
        The object returned by `utilsforecast.plotting.plot_series`.

        Raises
        ------
        Exception
            If the plotting dependencies are not installed.
        """
        try:
            from utilsforecast.plotting import plot_series
        except ModuleNotFoundError as err:
            # Keep the original exception type/message, but chain the cause
            # so the missing module is visible in the traceback.
            raise Exception(
                "You have to install additional dependencies to use this method, "
                'please install them using `pip install "nixtlats[plotting]"`'
            ) from err
        # Work on copies so the caller's frames are never modified.
        df = df.copy()
        if id_col not in df:
            df[id_col] = "ts_0"
        df[time_col] = pd.to_datetime(df[time_col])
        if forecasts_df is not None:
            forecasts_df = forecasts_df.copy()
            if id_col not in forecasts_df:
                forecasts_df[id_col] = "ts_0"
            forecasts_df[time_col] = pd.to_datetime(forecasts_df[time_col])
            if "anomaly" in forecasts_df:
                # special case to plot outputs
                # from detect_anomalies
                forecasts_df = forecasts_df.drop(columns="anomaly")
                # Recover the level from the first "TimeGPT-lo-<level>" column.
                cols = forecasts_df.columns
                cols = cols[cols.str.contains("TimeGPT-lo-")]
                level = cols.str.replace("TimeGPT-lo-", "")[0]
                level = float(level) if "." in level else int(level)
                level = [level]
                plot_anomalies = True
                models = ["TimeGPT"]
                forecasts_df = df.merge(forecasts_df, how="left")
                # Keep one row per series, grouping on the caller's id column
                # (not a hard-coded "unique_id"); copy before assigning so the
                # write below does not target a slice of another frame.
                df = df.groupby(id_col).head(1).copy()
                # prevent double plotting
                df.loc[:, target_col] = np.nan
        return plot_series(
            df=df,
            forecasts_df=forecasts_df,
            ids=unique_ids,
            plot_random=plot_random,
            models=models,
            level=level,
            max_insample_length=max_insample_length,
            plot_anomalies=plot_anomalies,
            engine=engine,
            resampler_kwargs=resampler_kwargs,
            palette="tab20b",
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
        )

# %% ../nbs/timegpt.ipynb 11
class TimeGPT(_TimeGPT):
    """Public TimeGPT client.

    Each public method runs the pandas implementation inherited from
    `_TimeGPT` when `df` is a `pandas.DataFrame`, and otherwise forwards the
    same arguments to a `_DistributedTimeGPT` instance.
    """

    def _instantiate_distributed_timegpt(self):
        """Build a `_DistributedTimeGPT` that reuses this client's token, base URL and retry settings."""
        # Imported lazily, only when a non-pandas frame is actually used.
        from nixtlats.distributed.timegpt import _DistributedTimeGPT

        wrapper = self.client._client_wrapper
        return _DistributedTimeGPT(
            token=wrapper._token,
            environment=wrapper._base_url,
            max_retries=self.max_retries,
            retry_interval=self.retry_interval,
            max_wait_time=self.max_wait_time,
        )

    def forecast(
        self,
        df: pd.DataFrame,
        h: int,
        freq: Optional[str] = None,
        id_col: str = "unique_id",
        time_col: str = "ds",
        target_col: str = "y",
        X_df: Optional[pd.DataFrame] = None,
        level: Optional[List[Union[int, float]]] = None,
        finetune_steps: int = 0,
        finetune_loss: str = "default",
        clean_ex_first: bool = True,
        validate_token: bool = False,
        add_history: bool = False,
        date_features: Union[bool, List[str]] = False,
        date_features_to_one_hot: Union[bool, List[str]] = True,
        model: str = "timegpt-1",
        num_partitions: Optional[int] = None,
    ):
        """Forecast your time series using TimeGPT.

        Parameters
        ----------
        df : pandas.DataFrame
            Input data. It must contain at least:
            - time_col:
                Column with the time indices of the series, typically a datetime
                column with regular intervals (hourly, daily, monthly, ...).
            - target_col:
                Column with the target variable, i.e. the values to predict
                or analyze.
            Several series can be stacked in the same frame by adding:
            - id_col:
                Column identifying each series; every distinct value is a
                separate time series.
            Any object that is not a pandas DataFrame is handed to the
            distributed implementation.
        h : int
            Forecast horizon.
        freq : str
            Frequency of the data. Inferred automatically when not given.
            See [pandas' available frequencies](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases).
        id_col : str (default='unique_id')
            Column that identifies each series.
        time_col : str (default='ds')
            Column that identifies each timestep; values can be timestamps or integers.
        target_col : str (default='y')
            Column that contains the target.
        X_df : pandas.DataFrame, optional (default=None)
            DataFrame with [`unique_id`, `ds`] columns and the future
            exogenous variables of `df`.
        level : List[float], optional (default=None)
            Confidence levels between 0 and 100 for prediction intervals.
        finetune_steps : int (default=0)
            Number of steps used to finetune TimeGPT on the new data.
        finetune_loss : str (default='default')
            Loss function used for finetuning. Options are: `default`, `mae`, `mse`, `rmse`, `mape`, and `smape`.
        clean_ex_first : bool (default=True)
            Clean the exogenous signal before making forecasts with TimeGPT.
        validate_token : bool (default=False)
            If True, validates the token before sending requests.
        add_history : bool (default=False)
            Return fitted values of the model.
        date_features : bool or list of str or callable, optional (default=False)
            Features computed from the dates. Can be pandas date attributes
            or functions that take the dates as input. If True, the most
            used date features for the frequency of `df` are added
            automatically.
        date_features_to_one_hot : bool or list of str (default=True)
            Date features to one-hot encode. If `date_features=True`, all
            date features are one-hot encoded by default.
        model : str (default='timegpt-1')
            Model to use. Options are: `timegpt-1`, and `timegpt-1-long-horizon`.
            `timegpt-1-long-horizon` is recommended when predicting more than
            one seasonal period given the frequency of your data.
        num_partitions : int (default=None)
            Number of partitions to use. If None, the number of partitions
            equals the available parallel resources in distributed
            environments.

        Returns
        -------
        fcsts_df : pandas.DataFrame
            DataFrame with TimeGPT forecasts for point predictions and
            probabilistic predictions (if level is not None).
        """
        # Pick the backend once, then make a single call with shared arguments.
        if isinstance(df, pd.DataFrame):
            run_forecast = self._forecast
        else:
            run_forecast = self._instantiate_distributed_timegpt().forecast
        return run_forecast(
            df=df,
            h=h,
            freq=freq,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            X_df=X_df,
            level=level,
            finetune_steps=finetune_steps,
            finetune_loss=finetune_loss,
            clean_ex_first=clean_ex_first,
            validate_token=validate_token,
            add_history=add_history,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            model=model,
            num_partitions=num_partitions,
        )

    def detect_anomalies(
        self,
        df: pd.DataFrame,
        freq: Optional[str] = None,
        id_col: str = "unique_id",
        time_col: str = "ds",
        target_col: str = "y",
        level: Union[int, float] = 99,
        clean_ex_first: bool = True,
        validate_token: bool = False,
        date_features: Union[bool, List[str]] = False,
        date_features_to_one_hot: Union[bool, List[str]] = True,
        model: str = "timegpt-1",
        num_partitions: Optional[int] = None,
    ):
        """Detect anomalies in your time series using TimeGPT.

        Parameters
        ----------
        df : pandas.DataFrame
            Input data. It must contain at least:
            - time_col:
                Column with the time indices of the series, typically a datetime
                column with regular intervals (hourly, daily, monthly, ...).
            - target_col:
                Column with the target variable, i.e. the values to predict
                or analyze.
            Several series can be stacked in the same frame by adding:
            - id_col:
                Column identifying each series; every distinct value is a
                separate time series.
            Any object that is not a pandas DataFrame is handed to the
            distributed implementation.
        freq : str
            Frequency of the data. Inferred automatically when not given.
            See [pandas' available frequencies](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases).
        id_col : str (default='unique_id')
            Column that identifies each series.
        time_col : str (default='ds')
            Column that identifies each timestep; values can be timestamps or integers.
        target_col : str (default='y')
            Column that contains the target.
        level : float (default=99)
            Confidence level between 0 and 100 for detecting the anomalies.
        clean_ex_first : bool (default=True)
            Clean the exogenous signal before making forecasts with TimeGPT.
        validate_token : bool (default=False)
            If True, validates the token before sending requests.
        date_features : bool or list of str or callable, optional (default=False)
            Features computed from the dates. Can be pandas date attributes
            or functions that take the dates as input. If True, the most
            used date features for the frequency of `df` are added
            automatically.
        date_features_to_one_hot : bool or list of str (default=True)
            Date features to one-hot encode. If `date_features=True`, all
            date features are one-hot encoded by default.
        model : str (default='timegpt-1')
            Model to use. Options are: `timegpt-1`, and `timegpt-1-long-horizon`.
            `timegpt-1-long-horizon` is recommended for forecasting when
            predicting more than one seasonal period given the frequency of
            your data.
        num_partitions : int (default=None)
            Number of partitions to use. If None, the number of partitions
            equals the available parallel resources in distributed
            environments.

        Returns
        -------
        anomalies_df : pandas.DataFrame
            DataFrame with anomalies flagged with 1 detected by TimeGPT.
        """
        # Pick the backend once, then make a single call with shared arguments.
        if isinstance(df, pd.DataFrame):
            run_detection = self._detect_anomalies
        else:
            run_detection = self._instantiate_distributed_timegpt().detect_anomalies
        return run_detection(
            df=df,
            freq=freq,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            level=level,
            clean_ex_first=clean_ex_first,
            validate_token=validate_token,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            model=model,
            num_partitions=num_partitions,
        )

    def cross_validation(
        self,
        df: pd.DataFrame,
        h: int,
        freq: Optional[str] = None,
        id_col: str = "unique_id",
        time_col: str = "ds",
        target_col: str = "y",
        level: Optional[List[Union[int, float]]] = None,
        validate_token: bool = False,
        n_windows: int = 1,
        step_size: Optional[int] = None,
        finetune_steps: int = 0,
        finetune_loss: str = "default",
        clean_ex_first: bool = True,
        date_features: Union[bool, List[str]] = False,
        date_features_to_one_hot: Union[bool, List[str]] = True,
        model: str = "timegpt-1",
        num_partitions: Optional[int] = None,
    ):
        """Perform cross validation in your time series using TimeGPT.

        Parameters
        ----------
        df : pandas.DataFrame
            Input data. It must contain at least:
            - time_col:
                Column with the time indices of the series, typically a datetime
                column with regular intervals (hourly, daily, monthly, ...).
            - target_col:
                Column with the target variable, i.e. the values to predict
                or analyze.
            Several series can be stacked in the same frame by adding:
            - id_col:
                Column identifying each series; every distinct value is a
                separate time series.
            Any object that is not a pandas DataFrame is handed to the
            distributed implementation.
        h : int
            Forecast horizon.
        freq : str
            Frequency of the data. Inferred automatically when not given.
            See [pandas' available frequencies](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases).
        id_col : str (default='unique_id')
            Column that identifies each series.
        time_col : str (default='ds')
            Column that identifies each timestep; values can be timestamps or integers.
        target_col : str (default='y')
            Column that contains the target.
        level : List[float], optional (default=None)
            Confidence levels between 0 and 100 for prediction intervals.
        validate_token : bool (default=False)
            If True, validates the token before sending requests.
        n_windows : int (default=1)
            Number of windows to evaluate.
        step_size : int, optional (default=None)
            Step size between each cross validation window. If None it will be equal to `h`.
        finetune_steps : int (default=0)
            Number of steps used to finetune TimeGPT on the new data.
        finetune_loss : str (default='default')
            Loss function used for finetuning. Options are: `default`, `mae`, `mse`, `rmse`, `mape`, and `smape`.
        clean_ex_first : bool (default=True)
            Clean the exogenous signal before making forecasts with TimeGPT.
        date_features : bool or list of str or callable, optional (default=False)
            Features computed from the dates. Can be pandas date attributes
            or functions that take the dates as input. If True, the most
            used date features for the frequency of `df` are added
            automatically.
        date_features_to_one_hot : bool or list of str (default=True)
            Date features to one-hot encode. If `date_features=True`, all
            date features are one-hot encoded by default.
        model : str (default='timegpt-1')
            Model to use. Options are: `timegpt-1`, and `timegpt-1-long-horizon`.
            `timegpt-1-long-horizon` is recommended for forecasting when
            predicting more than one seasonal period given the frequency of
            your data.
        num_partitions : int (default=None)
            Number of partitions to use. If None, the number of partitions
            equals the available parallel resources in distributed
            environments.

        Returns
        -------
        cv_df : pandas.DataFrame
            DataFrame with cross validation forecasts.
        """
        # Pick the backend once, then make a single call with shared arguments.
        if isinstance(df, pd.DataFrame):
            run_cv = self._cross_validation
        else:
            run_cv = self._instantiate_distributed_timegpt().cross_validation
        return run_cv(
            df=df,
            h=h,
            freq=freq,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            level=level,
            finetune_steps=finetune_steps,
            finetune_loss=finetune_loss,
            clean_ex_first=clean_ex_first,
            validate_token=validate_token,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            model=model,
            n_windows=n_windows,
            step_size=step_size,
            num_partitions=num_partitions,
        )
