# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/ces.ipynb.

# %% auto 0
__all__ = ['ces_target_fn']

# %% ../nbs/ces.ipynb 1
import math
import os

import numpy as np
from numba import njit
from .ets import nelder_mead
from statsmodels.tsa.seasonal import seasonal_decompose

# %% ../nbs/ces.ipynb 4
# Module-level constants
# Integer codes for the seasonal component; `switch_ces` maps the letters
# "N", "S", "P" and "F" onto these values.
NONE, SIMPLE, PARTIAL, FULL = 0, 1, 2, 3
# Numerical tolerances and sentinels.
TOL = 1e-10
HUGEN = 1e10
NA = -99999.0  # sentinel compared against forecasts and likelihood values
smalno = np.finfo(float).eps
# numba compilation flags; enabled when the variable is "true" (any case).
NOGIL = os.environ.get("NUMBA_RELEASE_GIL", "False").lower() == "true"
CACHE = os.environ.get("NUMBA_CACHE", "False").lower() == "true"

# %% ../nbs/ces.ipynb 6
def initstate(y, m, seasontype):
    """Build the initial CES state matrix for a series.

    Parameters
    ----------
    y : np.ndarray
        Observed series.
    m : int
        Seasonal period. For ``seasontype == "N"`` a single state row is
        used and ``m`` only influences how many leading observations are
        averaged.
    seasontype : str
        ``"N"`` (none), ``"S"`` (simple), ``"P"`` (partial) or ``"F"`` (full).

    Returns
    -------
    np.ndarray
        ``float32`` array with one row per lag and 2, 3 or 4 columns for
        ``"N"``/``"S"``, ``"P"`` and ``"F"`` respectively.

    Raises
    ------
    ValueError
        If ``seasontype`` is not one of the four supported codes.
    """
    # Validate before allocating so a bad code fails with a clear message.
    if seasontype not in ("N", "S", "P", "F"):
        raise ValueError(f"Unknown seasontype: {seasontype}")

    n = len(y)
    components = 2 + (seasontype == "P") + 2 * (seasontype == "F")
    lags = 1 if seasontype == "N" else m
    states = np.zeros((lags, components), dtype=np.float32)

    if seasontype == "N":
        # Average the first max(10, m) observations (capped at the length).
        idx = min(max(10, m), n)
        mean_ = np.mean(y[:idx])
        states[0, 0] = mean_
        states[0, 1] = mean_ / 1.1
    elif seasontype == "S":
        # One row per position in the first seasonal cycle.
        states[:lags, 0] = y[:lags]
        states[:lags, 1] = y[:lags] / 1.1
    else:
        # "P" and "F" share the first two columns and the seasonal column;
        # "F" adds a fourth column derived from the seasonal one.
        states[:lags, 0] = np.mean(y[:lags])
        states[:lags, 1] = states[:lags, 0] / 1.1
        states[:lags, 2] = seasonal_decompose(y, period=lags).seasonal[:lags]
        if seasontype == "F":
            states[:lags, 3] = states[:lags, 2] / 1.1

    return states

# %% ../nbs/ces.ipynb 8
@njit(nogil=NOGIL, cache=CACHE)
def _ces_pass(
    y, states, m, season, alpha_0, alpha_1, beta_0, beta_1, e, amse, denom, f, nmse
):
    """Run one filtering sweep over ``y`` and return the sum of squared errors.

    For every observation the one- to ``nmse``-step forecasts are written
    into ``f``, the one-step error is stored in ``e``, the running
    multi-step mean squared errors in ``amse`` (with counters ``denom``) are
    updated, and the state row is updated with the observation. After the
    sweep the last ``m`` rows of ``states`` are filled from a final ``m``-step
    forecast. Returns ``NA`` as soon as a one-step forecast equals the ``NA``
    sentinel.
    """
    n = len(y)
    sse = 0.0
    for i in range(m, n + m):
        # Row i of `states` corresponds to observation t of `y`.
        t = i - m
        # one step forecast
        cesfcst(states, i, m, season, f, nmse, alpha_0, alpha_1, beta_0, beta_1)
        if math.fabs(f[0] - NA) < TOL:
            return NA
        e[t] = y[t] - f[0]
        for j in range(nmse):
            # f[0] targets y[t], so f[j] targets y[t + j].
            if (t + j) < n:
                denom[j] += 1.0
                tmp = y[t + j] - f[j]
                amse[j] = (amse[j] * (denom[j] - 1.0) + (tmp * tmp)) / denom[j]
        # update state
        cesupdate(states, i, m, season, alpha_0, alpha_1, beta_0, beta_1, y[t])
        sse = sse + e[t] * e[t]
    new_states = cesfcst(
        states, n + m, m, season, f, m, alpha_0, alpha_1, beta_0, beta_1
    )
    states[-m:] = new_states[-m:]
    return sse


@njit(nogil=NOGIL, cache=CACHE)
def cescalc(
    y: np.ndarray,
    states: np.ndarray,  # states
    m: int,
    season: int,
    alpha_0: float,
    alpha_1: float,
    beta_0: float,
    beta_1: float,
    e: np.ndarray,
    amse: np.ndarray,
    nmse: int,
    backfit: int,
) -> float:
    """Filter ``y`` with the CES recursions and return ``n * log(SSE)``.

    Parameters
    ----------
    y : observed series. When ``backfit`` is truthy it is reversed in place
        twice, so it ends in its original order unless an early ``NA``
        return happens after the first reversal.
    states : state matrix, updated in place (and reversed in place together
        with ``y`` when back-fitting).
    m : seasonal period; forced to 1 when ``season == NONE``.
    season : seasonal code (``NONE``, ``SIMPLE``, ``PARTIAL`` or ``FULL``).
    alpha_0, alpha_1, beta_0, beta_1 : smoothing parameters passed through
        to ``cesfcst`` / ``cesupdate``.
    e : output array receiving the one-step errors.
    amse : output array; its first ``nmse`` entries receive running mean
        squared multi-step errors.
        NOTE(review): the running means are not reset between back-fitting
        passes, so with ``backfit`` they average over all three sweeps.
    nmse : number of forecast horizons tracked in ``amse``.
    backfit : when truthy, run forward, backward and forward again.

    Returns
    -------
    float
        ``n * log(sum of squared one-step errors)`` of the last forward
        sweep, or ``NA`` if a sweep hit the ``NA`` sentinel.
    """
    denom = np.zeros(nmse)
    m = 1 if season == NONE else m
    f = np.zeros(max(nmse, m))
    amse[:nmse] = 0.0
    n = len(y)

    sse = _ces_pass(
        y, states, m, season, alpha_0, alpha_1, beta_0, beta_1, e, amse, denom, f, nmse
    )
    if math.fabs(sse - NA) < TOL:
        return NA
    lik = n * math.log(sse)
    if not backfit:
        return lik

    # Sweep the reversed series, then reverse back and fit once more; each
    # sweep starts from the states left behind by the previous one.
    for _ in range(2):
        y[:] = y[::-1]
        states[:] = states[::-1]
        e[:] = e[::-1]
        sse = _ces_pass(
            y,
            states,
            m,
            season,
            alpha_0,
            alpha_1,
            beta_0,
            beta_1,
            e,
            amse,
            denom,
            f,
            nmse,
        )
        if math.fabs(sse - NA) < TOL:
            return NA
    lik = n * math.log(sse)
    return lik

# %% ../nbs/ces.ipynb 9
@njit(nogil=NOGIL, cache=CACHE)
def cesfcst(states, i, m, season, f, h, alpha_0, alpha_1, beta_0, beta_1):
    """Write ``h`` recursive forecasts into ``f`` and return the scratch states.

    A ``float32`` scratch matrix of ``m + h`` rows is seeded with rows
    ``i - m .. i - 1`` of ``states``. Each forecast is then passed to
    ``cesupdate`` as if it were the observation, which produces the next
    scratch row. ``states`` itself is not modified.
    """
    scratch = np.zeros((m + h, states.shape[1]), dtype=np.float32)
    scratch[:m] = states[(i - m) : i]
    for step in range(h):
        row = m + step
        # Simple seasonality reads the row one period back (row - m == step);
        # the other variants read the immediately preceding row.
        if season == NONE or season == PARTIAL or season == FULL:
            f[step] = scratch[row - 1, 0]
        else:
            f[step] = scratch[step, 0]
        if season > SIMPLE:
            f[step] += scratch[step, 2]
        cesupdate(
            scratch, row, m, season, alpha_0, alpha_1, beta_0, beta_1, f[step]
        )
    return scratch

# %% ../nbs/ces.ipynb 10
@njit(nogil=NOGIL, cache=CACHE)
def cesupdate(
    states, i, m, season, alpha_0, alpha_1, beta_0, beta_1, y  # kind of season
):
    """Fill row ``i`` of ``states`` in place from earlier rows and ``y``.

    Columns 0 and 1 are always written. Column 2 is written for ``PARTIAL``
    and ``FULL``, column 3 only for ``FULL``. Nothing is returned.
    """
    # Simple seasonality looks one full period back for columns 0 and 1;
    # every other variant uses the previous row.
    if season == NONE or season == PARTIAL or season == FULL:
        prev = i - 1
    else:
        prev = i - m
    seas = i - m  # seasonal columns always look one period back

    # Error of the observation against the fitted value.
    e = y - states[prev, 0]
    if season > SIMPLE:
        e -= states[seas, 2]

    states[i, 0] = (
        states[prev, 0] - (1.0 - alpha_1) * states[prev, 1] + (alpha_0 - alpha_1) * e
    )
    states[i, 1] = (
        states[prev, 0] + (1.0 - alpha_0) * states[prev, 1] + (alpha_0 + alpha_1) * e
    )

    if season == PARTIAL:
        states[i, 2] = states[seas, 2] + beta_0 * e
    elif season == FULL:
        states[i, 2] = (
            states[seas, 2] - (1 - beta_1) * states[seas, 3] + (beta_0 - beta_1) * e
        )
        states[i, 3] = (
            states[seas, 2] + (1 - beta_0) * states[seas, 3] + (beta_0 + beta_1) * e
        )

# %% ../nbs/ces.ipynb 11
@njit(nogil=NOGIL, cache=CACHE)
def cesforecast(states, n, m, season, f, h, alpha_0, alpha_1, beta_0, beta_1):
    """Forecast ``h`` steps past the ``n`` fitted observations.

    The forecasts are written into ``f``; the scratch state matrix built by
    ``cesfcst`` is returned.
    """
    # Non-seasonal models use a single lag whatever `m` was passed.
    if season == NONE:
        m = 1
    return cesfcst(states, m + n, m, season, f, h, alpha_0, alpha_1, beta_0, beta_1)

# %% ../nbs/ces.ipynb 20
@njit(nogil=NOGIL, cache=CACHE)
def initparamces(
    alpha_0: float, alpha_1: float, beta_0: float, beta_1: float, seasontype: str
):
    """Fill in starting values for unset (NaN) parameters.

    A NaN parameter that the given ``seasontype`` uses is replaced by a
    starting value and its ``optimize_*`` flag is set to 1; a supplied value
    is kept with flag 0. Beta parameters the season type does not use are
    returned as NaN with flag 0.

    Returns a dict with keys ``alpha_0``, ``alpha_1``, ``beta_0``, ``beta_1``
    and the matching ``optimize_<name>`` flags.
    """
    # Flags start at "keep the supplied value" and are raised below.
    optimize_alpha_0 = 0
    optimize_alpha_1 = 0
    optimize_beta_0 = 0
    optimize_beta_1 = 0

    if np.isnan(alpha_0):
        alpha_0 = 1.3
        optimize_alpha_0 = 1
    if np.isnan(alpha_1):
        alpha_1 = 1.0
        optimize_alpha_1 = 1

    if seasontype == "P":
        if np.isnan(beta_0):
            beta_0 = 0.1
            optimize_beta_0 = 1
        beta_1 = np.nan  # no optimize
    elif seasontype == "F":
        if np.isnan(beta_0):
            beta_0 = 1.3
            optimize_beta_0 = 1
        if np.isnan(beta_1):
            beta_1 = 1.0
            optimize_beta_1 = 1
    else:
        # no optimize
        beta_0 = np.nan
        beta_1 = np.nan

    return {
        "alpha_0": alpha_0,
        "optimize_alpha_0": optimize_alpha_0,
        "alpha_1": alpha_1,
        "optimize_alpha_1": optimize_alpha_1,
        "beta_0": beta_0,
        "optimize_beta_0": optimize_beta_0,
        "beta_1": beta_1,
        "optimize_beta_1": optimize_beta_1,
    }

# %% ../nbs/ces.ipynb 22
@njit(nogil=NOGIL, cache=CACHE)
def switch_ces(x: str):
    """Translate a season-type letter into its integer code."""
    codes = {"N": 0, "S": 1, "P": 2, "F": 3}
    return codes[x]

# %% ../nbs/ces.ipynb 24
@njit(nogil=NOGIL, cache=CACHE)
def pegelsresid_ces(
    y: np.ndarray,
    m: int,
    init_states: np.ndarray,
    n_components: int,
    seasontype: str,
    alpha_0: float,
    alpha_1: float,
    beta_0: float,
    beta_1: float,
    nmse: int,
):
    """Run the back-fitted CES filter for fixed parameters.

    Returns ``(amse, e, states, lik)``: the multi-step mean squared errors,
    the one-step residuals, the full ``float32`` state matrix and the value
    returned by ``cescalc`` (NaN when that value is the -99999 sentinel).
    """
    # len(y) + 2 * m rows, the first m of which hold the initial states.
    n_rows = len(y) + 2 * m
    states = np.zeros((n_rows, n_components), dtype=np.float32)
    states[:m] = init_states
    e = np.full_like(y, fill_value=np.nan)
    amse = np.full(nmse, fill_value=np.nan)
    season = switch_ces(seasontype)
    lik = cescalc(
        y, states, m, season, alpha_0, alpha_1, beta_0, beta_1, e, amse, nmse, 1
    )
    # Report the failure sentinel as NaN.
    if not np.isnan(lik) and np.abs(lik + 99999) < 1e-7:
        lik = np.nan
    return amse, e, states, lik

# %% ../nbs/ces.ipynb 25
@njit(nogil=NOGIL, cache=CACHE)
def ces_target_fn(
    optimal_param,
    init_alpha_0,
    init_alpha_1,
    init_beta_0,
    init_beta_1,
    opt_alpha_0,
    opt_alpha_1,
    opt_beta_0,
    opt_beta_1,
    y,
    m,
    init_states,
    n_components,
    seasontype,
    nmse,
):
    """Objective evaluated by the optimiser: the value returned by ``cescalc``.

    ``optimal_param`` holds only the parameters being optimised, in the order
    alpha_0, alpha_1, beta_0, beta_1. A parameter whose ``opt_*`` flag is
    falsy takes its ``init_*`` value instead.

    Returns the back-fitted ``cescalc`` value, floored at ``-1e10``. An
    evaluation that produces NaN or the -99999 failure sentinel returns
    ``+inf``, so a failed parameter set always ranks worse than a valid one
    when the objective is minimised.
    """
    states = np.zeros((len(y) + 2 * m, n_components), dtype=np.float32)
    states[:m] = init_states

    # Consume entries of `optimal_param` in the fixed parameter order.
    j = 0
    if opt_alpha_0:
        alpha_0 = optimal_param[j]
        j += 1
    else:
        alpha_0 = init_alpha_0

    if opt_alpha_1:
        alpha_1 = optimal_param[j]
        j += 1
    else:
        alpha_1 = init_alpha_1

    if opt_beta_0:
        beta_0 = optimal_param[j]
        j += 1
    else:
        beta_0 = init_beta_0

    if opt_beta_1:
        beta_1 = optimal_param[j]
        j += 1
    else:
        beta_1 = init_beta_1

    e = np.full_like(y, fill_value=np.nan)
    amse = np.full(nmse, fill_value=np.nan)
    lik = cescalc(
        y=y,
        states=states,
        m=m,
        season=switch_ces(seasontype),
        alpha_0=alpha_0,
        alpha_1=alpha_1,
        beta_0=beta_0,
        beta_1=beta_1,
        e=e,
        amse=amse,
        nmse=nmse,
        backfit=1,
    )
    # Floor (near-)perfect fits so the objective stays finite.
    if lik < -1e10:
        lik = -1e10
    # Failed evaluations must be the worst value for a minimiser, not the best.
    if math.isnan(lik) or math.fabs(lik + 99999) < 1e-7:
        lik = np.inf
    return lik

# %% ../nbs/ces.ipynb 26
def optimize_ces_target_fn(
    init_par, optimize_params, y, m, init_states, n_components, seasontype, nmse
):
    """Optimise the flagged CES parameters with Nelder-Mead.

    Parameters
    ----------
    init_par : mapping with ``alpha_0``, ``alpha_1``, ``beta_0``, ``beta_1``
        starting or fixed values.
    optimize_params : mapping with the same keys; a truthy value means the
        parameter is optimised.
    y, m, init_states, n_components, seasontype, nmse : forwarded unchanged
        to ``ces_target_fn``.

    Returns
    -------
    The object returned by ``nelder_mead``, or ``None`` when no parameter is
    flagged for optimisation.
    """
    # Fixed order: it must match how `ces_target_fn` unpacks its vector.
    names = ("alpha_0", "alpha_1", "beta_0", "beta_1")
    lower_bounds = {"alpha_0": 0.01, "alpha_1": 0.01, "beta_0": 0.01, "beta_1": 0.01}
    upper_bounds = {"alpha_0": 1.8, "alpha_1": 1.9, "beta_0": 1.5, "beta_1": 1.5}

    active = [name for name in names if optimize_params.get(name)]
    if not active:
        return None

    x0 = np.array([init_par[name] for name in active], dtype=np.float32)
    # Keep only the bounds of the optimised parameters so that position k of
    # the bounds always describes position k of x0.
    lower = np.array([lower_bounds[name] for name in active])
    upper = np.array([upper_bounds[name] for name in active])

    res = nelder_mead(
        ces_target_fn,
        x0,
        args=(
            init_par["alpha_0"],
            init_par["alpha_1"],
            init_par["beta_0"],
            init_par["beta_1"],
            optimize_params["alpha_0"],
            optimize_params["alpha_1"],
            optimize_params["beta_0"],
            optimize_params["beta_1"],
            y,
            m,
            init_states,
            n_components,
            seasontype,
            nmse,
        ),
        tol_std=1e-4,
        lower=lower,
        upper=upper,
        max_iter=1_000,
        adaptive=True,
    )
    return res

# %% ../nbs/ces.ipynb 27
def cesmodel(
    y: np.ndarray,
    m: int,
    seasontype: str,
    alpha_0: float,
    alpha_1: float,
    beta_0: float,
    beta_1: float,
    nmse: int,
):
    """Fit one CES model of the given season type.

    NaN parameters are optimised; supplied ones are kept fixed.

    Returns
    -------
    dict
        Keys ``loglik``, ``aic``, ``bic``, ``aicc``, ``mse``, ``amse``,
        ``fit`` (optimiser result or ``None``), ``residuals``, ``m``,
        ``states``, ``par``, ``n`` and ``seasontype``.
    """
    if seasontype == "N":
        m = 1
    # initial parameters (arguments in the order of initparamces' signature:
    # alpha_0, alpha_1, beta_0, beta_1)
    par = initparamces(alpha_0, alpha_1, beta_0, beta_1, seasontype)
    optimize_params = {
        key.replace("optimize_", ""): val for key, val in par.items() if "optim" in key
    }
    par = {key: val for key, val in par.items() if "optim" not in key}
    # initial states
    init_state = initstate(y, m, seasontype)
    n_components = init_state.shape[1]
    # parameter optimization
    fred = optimize_ces_target_fn(
        init_par=par,
        optimize_params=optimize_params,
        y=y,
        m=m,
        init_states=init_state,
        n_components=n_components,
        seasontype=seasontype,
        nmse=nmse,
    )
    # `fred` is None when nothing was optimised; then `par` stays as is.
    if fred is not None:
        fit_par = fred.x
        j = 0
        # The optimised vector holds only the flagged parameters, in this order.
        for key in ("alpha_0", "alpha_1", "beta_0", "beta_1"):
            if optimize_params[key]:
                par[key] = fit_par[j]
                j += 1

    amse, e, states, lik = pegelsresid_ces(
        y=y,
        m=m,
        init_states=init_state,
        n_components=n_components,
        seasontype=seasontype,
        nmse=nmse,
        **par
    )
    np_ = n_components + 1
    ny = len(y)
    aic = lik + 2 * np_
    bic = lik + np.log(ny) * np_
    # The small-sample correction is undefined when its denominator is zero.
    if ny - np_ - 1 != 0.0:
        aicc = aic + 2 * np_ * (np_ + 1) / (ny - np_ - 1)
    else:
        aicc = np.inf

    mse = amse[0]
    amse = np.mean(amse)

    return dict(
        loglik=-0.5 * lik,
        aic=aic,
        bic=bic,
        aicc=aicc,
        mse=mse,
        amse=amse,
        fit=fred,
        residuals=e,
        m=m,
        states=states,
        par=par,
        n=len(y),
        seasontype=seasontype,
    )

# %% ../nbs/ces.ipynb 29
def pegelsfcast_C(h, obj, npaths=None, level=None, bootstrap=None):
    """Return ``h`` point forecasts for a fitted model dict ``obj``.

    ``obj`` must provide ``m``, ``n``, ``states``, ``seasontype`` and
    ``par``. ``npaths``, ``level`` and ``bootstrap`` are accepted but not
    used here.
    """
    # cesforecast fills this buffer in place.
    out = np.full(h, np.nan)
    period = obj["m"]
    n_obs = obj["n"]
    fitted_states = obj["states"]
    season_code = switch_ces(obj["seasontype"])
    cesforecast(
        states=fitted_states,
        n=n_obs,
        m=period,
        season=season_code,
        h=h,
        f=out,
        **obj["par"]
    )
    return out

# %% ../nbs/ces.ipynb 30
def forecast_ces(obj, h):
    """Return ``{"mean": forecasts}`` with ``h`` point forecasts for ``obj``."""
    return {"mean": pegelsfcast_C(h, obj)}

# %% ../nbs/ces.ipynb 32
def auto_ces(
    y,
    m,
    model="Z",
    alpha_0=None,
    alpha_1=None,
    beta_0=None,
    beta_1=None,
    opt_crit="lik",
    nmse=3,
    ic="aicc",
):
    """Fit CES models and return the one with the smallest information criterion.

    Parameters
    ----------
    y : observed series.
    m : seasonal period; ``m <= 1`` or ``len(y) <= m`` forces the
        non-seasonal model.
    model : ``"N"``, ``"S"``, ``"P"``, ``"F"`` or ``"Z"`` (try all four).
    alpha_0, alpha_1, beta_0, beta_1 : fixed parameter values, or ``None``
        to have them optimised.
    opt_crit : accepted but not used in this function.
    nmse : number of horizons for the multi-step MSE, between 1 and 30.
    ic : key of the criterion used for selection (``"aic"``, ``"bic"`` or
        ``"aicc"`` are produced by ``cesmodel``).

    Returns
    -------
    dict
        The ``cesmodel`` result of the selected model.

    Raises
    ------
    ValueError
        If ``nmse`` is out of range or ``model`` is not a valid code.
    NotImplementedError
        If the series has no more observations than parameters.
    Exception
        If no candidate has a usable criterion value.
    """
    # converting params to floats
    # to improve numba compilation
    if alpha_0 is None:
        alpha_0 = np.nan
    if alpha_1 is None:
        alpha_1 = np.nan
    if beta_0 is None:
        beta_0 = np.nan
    if beta_1 is None:
        beta_1 = np.nan
    if nmse < 1 or nmse > 30:
        raise ValueError("nmse out of range")
    # refitting a model is not implemented yet
    if model not in ["Z", "N", "S", "P", "F"]:
        raise ValueError("Invalid model type")

    seasontype = model
    if m < 1 or len(y) <= m or m == 1:
        seasontype = "N"
    n = len(y)
    npars = 2
    if seasontype == "P":
        npars += 1
    if seasontype in ["F", "Z"]:
        npars += 2
    # ses for non-optimized tiny datasets
    if n <= npars:
        # we need HoltWintersZZ function
        raise NotImplementedError("tiny datasets")

    candidates = ["N", "S", "P", "F"] if seasontype == "Z" else [seasontype]
    # Track the winner separately from its score so that a legitimately
    # infinite (negative) criterion is not mistaken for "nothing fitted".
    best_fit = None
    best_ic = np.inf
    for stype in candidates:
        fit = cesmodel(
            y=y,
            m=m,
            seasontype=stype,
            alpha_0=alpha_0,
            alpha_1=alpha_1,
            beta_0=beta_0,
            beta_1=beta_1,
            nmse=nmse,
        )
        fit_ic = fit[ic]
        if not np.isnan(fit_ic) and fit_ic < best_ic:
            best_fit = fit
            best_ic = fit_ic
    if best_fit is None:
        raise Exception("no model able to be fitted")
    return best_fit
