# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/arima.ipynb (unless otherwise specified).

# Names exported by `from <module> import *`.
# NOTE(review): neither name is defined in the visible part of this module;
# presumably both are defined further down -- confirm.
__all__ = ['auto_arima_f', 'predict_arima']

# Cell
import math
import os
import warnings
from collections import namedtuple
from functools import partial

import numpy as np
import pandas as pd
import statsmodels.api as sm
from numba import njit
from scipy.optimize import minimize

from .utils import AirPassengers as ap

# Internal Cell
# Lightweight stand-in for an optimizer result, used when no parameter is free
# and the optimizer is therefore not run. Fields: success, status, x, fun.
OptimResult = namedtuple('OptimResult', 'success status x fun')

# Internal Cell
@njit
def partrans(p, raw, new):
    """Transform the first ``p`` entries of ``raw`` and store them in ``new``.

    Each value is first passed through ``tanh``; the results are then combined
    by a recursion that, for every order ``1 .. p-1``, updates all lower-order
    entries using the entry at that order. The output is written in place to
    ``new[:p]``; nothing is returned.

    Raises
    ------
    ValueError
        If ``p`` is greater than 100.
    """
    if p > 100:
        raise ValueError('can only transform 100 pars in arima0')

    new[:p] = np.tanh(raw[:p])
    scratch = new[:p].copy()

    order = 1
    while order < p:
        top = new[order]
        # Update the scratch copy first so that every read of `new` below
        # still sees the values from the previous order.
        for idx in range(order):
            scratch[idx] = scratch[idx] - top * new[order - 1 - idx]
        for idx in range(order):
            new[idx] = scratch[idx]
        order += 1

# Internal Cell
@njit
def arima_gradtrans(x, arma):
    """Finite-difference Jacobian of the ``partrans`` re-parameterisation at ``x``.

    Parameters
    ----------
    x : 1-D float array
        Parameter vector; the first ``arma[0]`` entries and the ``arma[2]``
        entries starting at offset ``arma[0] + arma[1]`` are the ones that
        ``partrans`` is applied to.
    arma : sequence of int
        Only the first three entries (``mp``, ``mq``, ``msp``) are read.

    Returns
    -------
    2-D float array of shape ``(len(x), len(x))``
        Identity matrix, except for the two diagonal blocks described above,
        where entry ``[i, j]`` is the forward difference (step ``1e-3``) of
        transformed coefficient ``j`` with respect to raw parameter ``i``.
    """
    eps = 1e-3
    mp, mq, msp = arma[:3]
    n = len(x)
    y = np.identity(n)
    # Work buffers sized to the 100-parameter limit enforced by `partrans`.
    w1 = np.empty(100)
    w2 = np.empty(100)
    w3 = np.empty(100)
    if mp > 0:
        for i in range(mp):
            w1[i] = x[i]
        partrans(mp, w1, w2)
        for i in range(mp):
            w1[i] += eps
            partrans(mp, w1, w3)
            for j in range(mp):
                y[i, j] = (w3[j] - w2[j]) / eps
            w1[i] -= eps
    if msp > 0:
        v = mp + mq
        for i in range(msp):
            w1[i] = x[i + v]
        partrans(msp, w1, w2)
        # Perturb each seasonal parameter in turn and fill the whole row of
        # the seasonal block (previously a stale index filled one row only).
        for i in range(msp):
            w1[i] += eps
            partrans(msp, w1, w3)
            for j in range(msp):
                y[i + v, j + v] = (w3[j] - w2[j]) / eps
            w1[i] -= eps
    return y

# Internal Cell
@njit
def arima_undopars(x, arma):
    """Return a copy of ``x`` with the ``partrans`` transform applied.

    The transform is applied to the leading ``arma[0]`` entries and to the
    ``arma[2]`` entries that start at offset ``arma[0] + arma[1]``; all other
    entries are copied unchanged. ``x`` itself is not modified.

    Note: this name is defined a second time further down in this module; the
    later definition is the one that remains bound after import.
    """
    n_ar, n_ma, n_sar = arma[:3]

    out = x.copy()
    if n_ar > 0:
        partrans(n_ar, x, out)
    seasonal_start = n_ar + n_ma
    if n_sar > 0:
        partrans(n_sar, x[seasonal_start:], out[seasonal_start:])
    return out

# Internal Cell
@njit
def tsconv(a, b):
    """Return the full discrete convolution of the sequences ``a`` and ``b``.

    The result is a float array of length ``len(a) + len(b) - 1`` whose
    element ``k`` is the sum of ``a[i] * b[j]`` over all ``i + j == k``
    (i.e. the coefficients of the product of the two polynomials).
    """
    len_a = len(a)
    len_b = len(b)
    out = np.zeros(len_a + len_b - 1)

    for shift in range(len_a):
        scale = a[shift]
        for offset in range(len_b):
            out[shift + offset] = out[shift + offset] + scale * b[offset]

    return out

# Internal Cell
@njit
def inclu2(np_, xnext, xrow, ynext, d, rbar, thetab):
    """Fold one equation ``xnext . p = ynext`` into a packed triangular system.

    All updates happen in place:

    * ``xrow`` receives a working copy of ``xnext`` (``xnext`` is not changed),
    * ``d`` holds one scaling value per row,
    * ``rbar`` stores the strictly upper triangle row by row,
    * ``thetab`` holds the right-hand side.

    Nothing is returned. Processing stops early as soon as a row whose
    previous ``d`` entry was exactly zero has been updated.
    """
    xrow[:np_] = xnext[:np_]

    rpos = 0
    for row in range(np_):
        pivot = xrow[row]
        if pivot == 0.:
            # Nothing to eliminate: step over this row's slice of `rbar`.
            rpos = rpos + np_ - row - 1
            continue

        d_old = d[row]
        d_new = d_old + pivot * pivot
        d[row] = d_new
        if d_new != 0.:
            cbar = d_old / d_new
            sbar = pivot / d_new
        else:
            cbar = math.inf
            sbar = math.inf

        for col in range(row + 1, np_):
            x_col = xrow[col]
            r_old = rbar[rpos]
            xrow[col] = x_col - pivot * r_old
            rbar[rpos] = cbar * r_old + sbar * x_col
            rpos += 1

        y_old = ynext
        ynext = y_old - pivot * thetab[row]
        thetab[row] = cbar * thetab[row] + sbar * y_old
        if d_old == 0.:
            return

# Internal Cell
@njit
def invpartrans(p, phi, new):
    """Inverse of ``partrans``: write the unconstrained values into ``new[:p]``.

    The first ``p`` entries of ``phi`` are copied into ``new``, the recursion
    used by ``partrans`` is run backwards (from order ``p - 1`` down to 1) and
    finally ``atanh`` is applied element-wise. The result is stored in place
    in ``new[:p]``; nothing is returned and ``phi`` is left untouched.

    Raises
    ------
    ValueError
        If ``p`` is greater than 100.
    """
    if p > 100:
        raise ValueError('can only transform 100 pars in arima0')

    # Write *into* the caller's buffer. Rebinding the local name (as the
    # previous version did) left the output array untouched.
    new[:p] = phi[:p]
    work = new[:p].copy()
    for j in range(p - 1, 0, -1):
        a = new[j]
        for k in range(j):
            work[k] = (new[k] + a * new[j - k - 1]) / (1 - a * a)
        for k in range(j):
            new[k] = work[k]
    for j in range(p):
        new[j] = math.atanh(new[j])

# Internal Cell
@njit
def arima_undopars(x, arma):
    """Return a copy of ``x`` with the ``partrans`` transform applied.

    ``partrans`` is applied to the first ``arma[0]`` entries and to the
    ``arma[2]`` entries beginning at offset ``arma[0] + arma[1]``. Every other
    entry is copied through unchanged and ``x`` is not modified.

    Note: a definition with the same name and body appears earlier in this
    module; this later one replaces it at import time.
    """
    ar_count, ma_count, sar_count = arma[:3]

    transformed = x.copy()
    if ar_count > 0:
        partrans(ar_count, x, transformed)
    if sar_count > 0:
        start = ar_count + ma_count
        partrans(sar_count, x[start:], transformed[start:])
    return transformed

# Internal Cell
@njit
def ARIMA_invtrans(x, arma):
    """Return a copy of ``x`` after calling ``invpartrans`` on its AR blocks.

    ``invpartrans`` is invoked for the first ``arma[0]`` entries and for the
    ``arma[2]`` entries starting at offset ``arma[0] + arma[1]``, with the
    copy passed as the output buffer. ``x`` itself is not modified.
    """
    n_ar, n_ma, n_sar = arma[:3]
    out = x.copy()

    if n_ar > 0:
        invpartrans(n_ar, x, out)
    if n_sar > 0:
        seasonal_start = n_ar + n_ma
        invpartrans(n_sar, x[seasonal_start:], out[seasonal_start:])
    return out

# Internal Cell
@njit
def getQ0(phi, theta):
    """Compute the ``(r, r)`` initial covariance block for an ARMA model.

    Parameters
    ----------
    phi : 1-D float array
        AR coefficients (length ``p``, may be empty).
    theta : 1-D float array
        MA coefficients (length ``q``, may be empty).

    Returns
    -------
    2-D float array of shape ``(r, r)`` with ``r = max(p, q + 1)``.

    Notes
    -----
    Three cases are handled:

    * ``r == 1``: closed form ``1 / (1 - phi[0]**2)`` (or ``1`` when ``p == 0``).
    * ``p > 0``: a linear system is built row by row in ``xnext``, folded into
      a packed triangular system with ``inclu2`` and solved by back
      substitution.
    * ``p == 0`` (pure MA): the packed upper triangle is accumulated directly
      from the products of the MA coefficients.

    In the last two cases the packed triangle is finally expanded into a full
    symmetric matrix.
    """
    p = len(phi)
    q = len(theta)
    r = max(p, q + 1)

    # Sizes of the packed upper triangle and of the triangular factor.
    np_ = r * (r + 1) // 2
    nrbar = np_ * (np_ - 1) // 2

    # V holds the packed products of (1, theta_1, ..., theta_q, 0, ...).
    V = np.zeros(np_)
    ind = 0
    for j in range(r):
        vj = 0.
        if j == 0:
            vj = 1.
        elif j - 1 < q:
            vj = theta[j - 1]

        for i in range(j, r):
            vi = 0.
            if i == 0:
                vi = 1.0
            elif i - 1 < q:
                vi = theta[i - 1]
            V[ind] = vi * vj
            ind += 1

    # Flat work buffer: first used in packed form, finally reshaped to (r, r).
    res = np.zeros(r * r)

    if r == 1:
        if p == 0:
            res[0] = 1.
        else:
            res[0] = 1. / (1. - phi[0] * phi[0])

        res = res.reshape((r, r))
        return res

    if p > 0:
        rbar = np.zeros(nrbar)
        thetab = np.zeros(np_)
        xnext = np.zeros(np_)
        xrow = np.zeros(np_)

        ind = 0
        ind1 = -1
        npr = np_ - r
        npr1 = npr + 1
        indj = npr
        ind2 = npr - 1

        for j in range(r):
            phij = phi[j] if j < p else 0.
            xnext[indj] = 0.
            indj += 1
            indi = npr1 + j
            for i in range(j, r):
                ynext = V[ind]
                ind += 1
                phii = phi[i] if i < p else 0.
                if j != r - 1:
                    xnext[indj] = -phii
                    if i != r - 1:
                        xnext[indi] -= phij
                        ind1 += 1
                        xnext[ind1] = -1.
                xnext[npr] = -phii * phij
                ind2 += 1
                if ind2 >= np_:
                    ind2 = 0
                xnext[ind2] += 1.
                inclu2(np_, xnext, xrow, ynext, res, rbar, thetab)
                # Reset the entries touched for this equation.
                xnext[ind2] = 0.
                if i != r - 1:
                    xnext[indi] = 0.
                    indi += 1
                    xnext[ind1] = 0.

        # Back substitution on the packed triangular system.
        ithisr = nrbar - 1
        im = np_ - 1
        for i in range(np_):
            bi = thetab[im]
            jm = np_ - 1
            for j in range(i):
                bi -= rbar[ithisr] * res[jm]
                ithisr -= 1
                jm -= 1
            res[im] = bi
            im -= 1

        # Now reorder p: move the last r solved values to the front.
        ind = npr
        for i in range(r):
            xnext[i] = res[ind]
            ind += 1
        ind = np_ - 1
        ind1 = npr - 1
        for i in range(npr):
            res[ind] = res[ind1]
            ind -= 1
            ind1 -= 1
        for i in range(r):
            res[i] = xnext[i]
    else:
        # Pure MA: fill the packed triangle from the end. Each entry that is
        # not first in its group adds an already-computed entry located at
        # `indn`, which walks backwards from the end of the triangle.
        indn = np_
        ind = np_
        for i in range(r):
            for j in range(i + 1):
                ind -= 1
                res[ind] = V[ind]
                if j != 0:
                    indn -= 1
                    res[ind] += res[indn]

    # Unpack to a full matrix
    ind = np_
    for i in range(r - 1, 0, -1):
        for j in range(r - 1, i - 1, -1):
            ind -= 1
            res[r * i + j] = res[ind]

    # Mirror across the diagonal.
    for i in range(r - 1):
        for j in range(i + 1, r):
            res[i + r * j] = res[j + r * i]

    res = res.reshape((r, r))
    return res

# Internal Cell
@njit
def arima_transpar(params_in, arma, trans):
    """Expand a packed parameter vector into full AR and MA polynomials.

    Parameters
    ----------
    params_in : 1-D float array
        Parameters ordered as ``mp`` AR, ``mq`` MA, ``msp`` seasonal AR and
        ``msq`` seasonal MA values (any further entries are ignored).
    arma : sequence of int
        ``(mp, mq, msp, msq, ns, ...)``; only the first five entries are read,
        ``ns`` being the seasonal period.
    trans : bool
        When true, ``partrans`` is first applied to the AR and seasonal AR
        blocks of a copy of ``params_in``.

    Returns
    -------
    (phi, theta) : tuple of 1-D float arrays
        Of lengths ``mp + ns * msp`` and ``mq + ns * msq``. For ``ns > 0`` the
        seasonal and non-seasonal polynomials are multiplied out.
    """
    #TODO check trans=True results
    mp, mq, msp, msq, ns = arma[:5]
    p = mp + ns * msp
    q = mq + ns * msq

    phi = np.zeros(p)
    theta = np.zeros(q)
    params = params_in.copy()

    if trans:
        if mp > 0:
            partrans(mp, params_in, params)
        v = mp + mq
        if msp > 0:
            partrans(msp, params_in[v:], params[v:])
    if ns > 0:
        phi[:mp] = params[:mp]
        phi[mp:p] = 0.
        theta[:mq] = params[mp:mp+mq]
        theta[mq:q] = 0.
        # Multiply the non-seasonal polynomial by the seasonal one.
        for j in range(msp):
            phi[(j + 1) * ns - 1] += params[j + mp + mq]
            for i in range(mp):
                phi[(j + 1) * ns + i] -= params[i] * params[j + mp + mq]

        for j in range(msq):
            theta[(j + 1) * ns - 1] += params[j + mp + mq + msp]
            for i in range(mq):
                theta[(j + 1) * ns + i] += params[i + mp] * params[j + mp + mq + msp]
    else:
        phi[:mp] = params[:mp]
        # MA coefficients come from the parameter vector (the previous
        # version copied `theta` from itself).
        theta[:mq] = params[mp:mp + mq]

    return phi, theta

# Internal Cell
@njit
def arima_css(y, arma, phi, theta, ncond):
    """Conditional sum of squares of an ARMA model on the differenced series.

    ``y`` is copied, differenced ``arma[5]`` times at lag 1 and ``arma[6]``
    times at lag ``arma[4]`` (in place, from the end backwards, so the leading
    values are left as they were), and the ARMA recursion is then run from
    index ``ncond`` onwards.

    Returns
    -------
    (res, resid)
        ``res`` is the mean of the squared non-NaN residuals; ``resid`` is the
        residual array, whose first ``ncond`` entries are zero.
    """
    n_obs = len(y)
    n_ar = len(phi)
    n_ma = len(theta)
    used = 0
    total = 0.0

    w = y.copy()

    # Ordinary differencing.
    for _ in range(arma[5]):
        for t in range(n_obs - 1, 0, -1):
            w[t] -= w[t - 1]

    # Seasonal differencing.
    period = arma[4]
    for _ in range(arma[6]):
        for t in range(n_obs - 1, period - 1, -1):
            w[t] -= w[t - period]

    resid = np.empty(n_obs)
    resid[:ncond] = 0.
    for t in range(ncond, n_obs):
        err = w[t]
        for j in range(n_ar):
            lagged = t - j - 1
            if lagged >= 0:
                err -= phi[j] * w[lagged]

        # Only residuals produced since `ncond` take part in the MA term.
        for j in range(min(t - ncond, n_ma)):
            lagged = t - j - 1
            if lagged >= 0:
                err -= theta[j] * resid[lagged]

        resid[t] = err

        if not np.isnan(err):
            used += 1
            total += err * err

    return total / used, resid

# Internal Cell
@njit
def _make_arima(phi, theta, delta, kappa = 1e6, tol = np.finfo(float).eps):
    # check nas phi
    # check nas theta
    p = len(phi)
    q = len(theta)
    r = max(p, q + 1)
    d = len(delta)

    rd = r + d
    Z = np.concatenate((np.array([1.]), np.zeros(r - 1), delta))
    T = np.zeros((rd, rd))

    if p > 0:
        T[:p, 0] = phi
    if r > 1:
        for i in range(1, r):
            T[i - 1, i] = 1

    if d > 0:
        T[r] = Z
        if d > 1:
            for ind in range(1, d):
                T[r + ind, r + ind - 1] = 1

    if q < r - 1:
        theta = np.concatenate((theta, np.zeros(r - 1 - q)))

    R = np.concatenate((np.array([1.]), theta, np.zeros(d)))
    V = R * R.reshape(-1, 1)
    h = 0.
    a = np.zeros(rd)
    Pn = np.zeros((rd, rd))
    P = np.zeros((rd, rd))

    if r > 1:
        Pn[:r, :r] = getQ0(phi, theta)
    else:
        Pn[0, 0] = 1 / (1 - phi[0] ** 2) if p > 0 else 1.

    if d > 0:
        for i in range(d):
            Pn[r + i, r + i] = kappa

    return phi, theta, delta, Z, a, P, T, V, h, Pn

def make_arima(phi, theta, delta, kappa = 1e6, tol = np.finfo(np.float64).eps):
    """Build the state-space model and return it as a dict.

    Thin wrapper around ``_make_arima`` that labels the returned components
    with the keys ``'phi'``, ``'theta'``, ``'delta'``, ``'Z'``, ``'a'``,
    ``'P'``, ``'T'``, ``'V'``, ``'h'`` and ``'Pn'``.
    """
    components = _make_arima(phi, theta, delta, kappa, tol)
    names = ('phi', 'theta', 'delta', 'Z', 'a', 'P', 'T', 'V', 'h', 'Pn')
    return {name: value for name, value in zip(names, components)}

# Internal Cell
@njit
def arima_like(y, phi, theta, delta, a, P, Pn, up, use_resid):
    """Run a filtering recursion over ``y`` and accumulate likelihood terms.

    Parameters
    ----------
    y : 1-D float array
        Observations; NaN entries are treated as missing (prediction only).
    phi, theta, delta : 1-D float arrays
        AR, MA and differencing coefficients.
    a : 1-D float array of length ``rd``
        State vector; overwritten in place at every step.
    P, Pn : 2-D float arrays
        Covariance work arrays, accessed through ``ravel()``.
        NOTE(review): for contiguous inputs ``ravel`` yields a view, so the
        caller's arrays are presumably updated in place -- confirm.
    up : int
        The covariance prediction step is only executed for ``l > up``.
    use_resid : bool
        Whether standardized residuals are collected.

    Returns
    -------
    (ssq, sumlog, nu, rsResid)
        Sum of ``resid**2 / gain``, sum of ``log(gain)``, number of
        observations that contributed (those with ``gain < 1e4``), and the
        standardized residuals (``None`` when ``use_resid`` is false).
    """
    n = len(y)
    rd = len(a)
    p = len(phi)
    q = len(theta)
    d = len(delta)
    # Dimension of the ARMA part of the state.
    r = rd - d

    sumlog = 0.
    ssq = 0.
    nu = 0

    # Flat views used with explicit index arithmetic below.
    P = P.ravel()
    Pnew = Pn.ravel()
    anew = np.empty(rd)
    M = np.empty(rd)
    if d > 0:
        # Scratch matrix, only needed when differencing states exist.
        mm = np.empty(rd * rd)

    if use_resid:
        rsResid = np.empty(n)

    for l in range(n):
        # --- State prediction: anew = T a ---
        for i in range(r):
            tmp = a[i + 1] if i < r - 1 else 0.
            if i < p:
                tmp += phi[i] * a[0]
            anew[i] = tmp
        if d > 0:
            for i in range(r + 1, rd):
                anew[i] = a[i - 1]
            tmp = a[0]
            for i in range(d):
                tmp += delta[i] * a[r + i]
            anew[r] = tmp
        # --- Covariance prediction (skipped while l <= up) ---
        if l > up:
            if d == 0:
                # No differencing states: build Pnew entry by entry.
                for i in range(r):
                    vi = 0.
                    if i == 0:
                        vi = 1.
                    elif i - 1 < q:
                        vi = theta[i - 1]
                    for j in range(r):
                        tmp = 0.
                        if j == 0:
                            tmp = vi
                        elif j - 1 < q:
                            tmp = vi * theta[j - 1]
                        if i < p and j < p:
                            tmp += phi[i] * phi[j] * P[0]
                        if i < r - 1 and j < r -1:
                            tmp += P[i + 1 + r * (j + 1)]
                        if i < p and j < r - 1:
                            tmp += phi[i] * P[j + 1]
                        if j < p and i < r -1:
                            tmp += phi[j] * P[i + 1]
                        Pnew[i + r * j] = tmp
            else:
                # mm = TP
                for i in range(r):
                    for j in range(rd):
                        tmp = 0.
                        if i < p:
                            tmp += phi[i] * P[rd * j]
                        if i < r - 1:
                            tmp += P[i + 1 + rd * j]
                        mm[i + rd * j] = tmp
                for j in range(rd):
                    tmp = P[rd * j]
                    for k in range(d):
                        tmp += delta[k] * P[r + k + rd * j]
                    mm[r + rd * j] = tmp
                for i in range(1, d):
                    for j in range(rd):
                        mm[r + i + rd * j] = P[r + i - 1 + rd * j]

                # Pnew = mmT'
                for i in range(r):
                    for j in range(rd):
                        tmp = 0.
                        if i < p:
                            tmp += phi[i] * mm[j]
                        if i < r - 1:
                            tmp += mm[rd * (i + 1) + j]
                        Pnew[j + rd * i] = tmp
                for j in range(rd):
                    tmp = mm[j]
                    for k in range(d):
                        tmp += delta[k] * mm[rd * (r + k) + j]
                    Pnew[rd * r + j] = tmp
                for i in range(1, d):
                    for j in range(rd):
                        Pnew[rd * (r + i) + j] = mm[rd * (r + i - 1) + j]
                # Add the outer product of (1, theta) to the leading block.
                for i in range(q + 1):
                    vi = 1. if i == 0 else theta[i - 1]
                    for j in range(q + 1):
                        Pnew[i + rd * j] += vi * (1. if j == 0 else theta[j - 1])

        if not math.isnan(y[l]):
            # --- Update step for an observed value ---
            resid = y[l] - anew[0]
            for i in range(d):
                resid -= delta[i] * anew[r + i]
            for i in range(rd):
                tmp = Pnew[i]
                for j in range(d):
                    tmp += Pnew[i + (r + j) * rd] * delta[j]
                M[i] = tmp
            gain = M[0]
            for j in range(d):
                gain += delta[j] * M[r + j]
            # Observations with a very large gain do not contribute to the
            # likelihood terms.
            if gain < 1e4:
                nu += 1
                ssq += resid * resid / gain if gain != 0. else math.inf
                sumlog += math.log(gain)
            if use_resid:
                 rsResid[l] = resid / math.sqrt(gain) if gain != 0. else math.inf
            for i in range(rd):
                a[i] = anew[i] + M[i] * resid / gain if gain != 0. else math.inf
            for i in range(rd):
                for j in range(rd):
                    P[i + j * rd] = Pnew[i + j * rd] - M[i] * M[j] / gain if gain != 0. else math.inf
        else:
            # Missing observation: carry the prediction forward unchanged.
            a[:] = anew[:]
            P[:] = Pnew[:]
            if use_resid:
                rsResid[l] = np.nan
    if not use_resid:
        rsResid = None
    return ssq, sumlog, nu, rsResid

# Internal Cell
@njit
def diff1d(x, lag, differences):
    """Lagged differences of a 1-D float array, NaN-padded to the input length.

    Parameters
    ----------
    x : 1-D float array
    lag : int
        Distance between the two elements being subtracted.
    differences : int
        Number of times the differencing is applied.

    Returns
    -------
    1-D array of the same length as ``x``. After each pass the first ``lag``
    entries are NaN, so NaNs accumulate at the front with every extra pass.
    When ``lag`` is at least ``x.size`` the result is all NaN.
    """
    y = x.copy()
    # Clamp the padding length: without it a lag longer than the series would
    # write past the end of the array (unchecked in compiled code).
    n_pad = min(lag, x.size)
    for _ in range(differences):
        x = y.copy()
        for i in range(n_pad):
            y[i] = np.nan
        for i in range(lag, x.size):
            y[i] = x[i] - x[i - lag]
    return y

@njit
def diff2d(x, lag, differences):
    """Apply ``diff1d`` independently to every column of a 2-D array.

    Returns an array with the same shape as ``x``.
    """
    out = np.empty_like(x)
    n_cols = x.shape[1]
    for col in range(n_cols):
        out[:, col] = diff1d(x[:, col], lag, differences)
    return out


def diff(x, lag, differences):
    """Lagged differences of a 1-D or 2-D array with the NaN padding removed.

    For 1-D input every NaN element is dropped; for 2-D input only the rows
    that are NaN in all columns are dropped.

    Raises
    ------
    ValueError
        If ``x`` is neither 1-D nor 2-D (the message is the number of
        dimensions).
    """
    if x.ndim == 1:
        differenced = diff1d(x, lag, differences)
        keep = ~np.isnan(differenced)
    elif x.ndim == 2:
        differenced = diff2d(x, lag, differences)
        keep = ~np.isnan(differenced).all(1)
    else:
        raise ValueError(x.ndim)
    return differenced[keep]

# Internal Cell
def arima(x: np.ndarray,
          order=(0, 0, 0),
          seasonal={'order': (0, 0, 0), 'period': 1},
          xreg=None,
          include_mean=True,
          transform_pars=True,
          fixed=None,
          init=None,
          method='CSS',
          SSinit='Gardner1980',
          optim_method='BFGS',
          kappa = 1e6,
          tol=1e-8,
          optim_control = {'maxiter': 100}):
    """Fit a (seasonal) ARIMA model with optional external regressors.

    Parameters
    ----------
    x : 1-D array
        The series. Integer input is converted to ``float64``. The array is
        never modified.
    order : tuple of 3 non-negative ints
        ``(p, d, q)``.
    seasonal : dict
        Must contain ``'order'`` (tuple of 3 non-negative ints) and
        ``'period'``. A period of ``None`` or ``0`` is replaced by 1 with a
        warning; the dict passed in is not modified.
    xreg : 2-D float array or None
        External regressors with one row per observation.
    include_mean : bool
        Add an intercept column when no differencing is requested.
    transform_pars : bool
        Optimise the AR parameters in a transformed space (ML only); switched
        off automatically when nothing is optimised or AR terms are fixed.
    fixed : 1-D array or None
        One entry per coefficient; NaN marks a free coefficient.
    init : 1-D array or None
        Starting values; NaN entries fall back to the defaults. The array
        passed in is not modified.
    method : {'CSS', 'CSS-ML', 'ML'}
        ``'CSS-ML'`` falls back to ``'ML'`` when missing values are present.
    SSinit : str
        ``'Gardner1980'`` selects ``getQ0`` for the initial covariance.
    optim_method, tol, optim_control
        Forwarded to ``scipy.optimize.minimize``.
    kappa : float
        Prior variance of the differencing states.

    Returns
    -------
    dict with keys ``'coef'``, ``'sigma2'``, ``'var_coef'``, ``'mask'``,
    ``'loglik'``, ``'aic'``, ``'arma'``, ``'residuals'``, ``'code'``,
    ``'n_cond'``, ``'nobs'`` and ``'model'``.

    Raises
    ------
    ValueError
        For invalid arguments, too few observations or a non-stationary AR
        part detected by the stationarity checks.
    """
    SSG = SSinit == 'Gardner1980'
    # Private copy: the period may be overwritten below and neither the
    # caller's dict nor the shared default must be changed.
    seasonal = dict(seasonal)

    def upARIMA(mod, phi, theta):
        """Refresh ``mod`` in place for new ``phi``/``theta`` and return it."""
        p = len(phi)
        q = len(theta)
        mod['phi'] = phi
        mod['theta'] = theta
        r = max(p, q + 1)
        if p > 0:
            mod['T'][:p, 0] = phi
        if r > 1:
            if SSG:
                mod['Pn'][:r, :r] = getQ0(phi, theta)
            else:
                # NOTE(review): `getQ0bis` is not defined in the visible part
                # of this module -- confirm it exists.
                mod['Pn'][:r, :r] = getQ0bis(phi, theta, tol=0)
        else:
            # r == 1 means p <= 1, so use the scalar coefficient.
            mod['Pn'][0, 0] = 1 / (1 - phi[0] ** 2) if p > 0 else 1
        mod['a'][:] = 0  # reset the state vector in place
        return mod

    def arimaSS(y, mod):
        """Run ``arima_like`` on ``y`` collecting standardized residuals."""
        # arima_like(y, phi, theta, delta, a, P, Pn, up, use_resid)
        return arima_like(
            y,
            mod['phi'],
            mod['theta'],
            mod['delta'],
            mod['a'],
            mod['P'],
            mod['Pn'],
            0,
            True,
        )

    def armafn(p, x, trans):
        """Objective for the ML fit at the free parameters ``p``."""
        x = x.copy()
        par = coef.copy()
        par[mask] = p
        trarma = arima_transpar(par, arma, trans)
        Z = upARIMA(mod, trarma[0], trarma[1])
        if Z is None:
            return np.finfo(np.float64).max
        if ncxreg > 0:
            x -= np.dot(xreg, par[narma + np.arange(ncxreg)])
        res = arima_like(x,
                         Z['phi'],
                         Z['theta'],
                         Z['delta'],
                         Z['a'],
                         Z['P'],
                         Z['Pn'],
                         0,
                         False,
                        )
        if res[2] == 0.:
            return math.inf

        s2 = res[0] / res[2]
        if s2 <= 0:
            return math.nan
        return 0.5 * (math.log(s2) + res[1] / res[2])

    def arCheck(ar):
        """Return True when the AR polynomial has all roots outside the unit circle."""
        # Index of the last non-zero coefficient of (1, -ar); element 0 is
        # always non-zero so the set is never empty.
        p = np.flatnonzero(np.append(1, -ar) != 0).max()
        if not p:
            return True
        coefs = np.append(1, -ar[:p])
        roots = np.polynomial.polynomial.polyroots(coefs)
        return all(np.abs(roots) > 1)

    def maInvert(ma):
        """Return MA coefficients with roots inside the unit circle reflected outside."""
        q = len(ma)
        # Trailing zeros must be dropped before the root finding.
        q0 = np.flatnonzero(np.append(1, ma) != 0).max()
        if not q0:
            return ma
        coefs = np.append(1, ma[:q0])
        roots = np.polynomial.polynomial.polyroots(coefs)
        ind = np.abs(roots) < 1
        if not ind.any():
            # Already invertible.
            return ma
        if q0 == 1:
            return np.append(1 / ma[0], np.repeat(0, q - q0))
        roots[ind] = 1 / roots[ind]
        # Rebuild the polynomial from its (reflected) roots.
        x = 1
        for r in roots:
            x = np.append(x, 0) - np.append(0, x) / r
        return np.append(x.real[1:], np.repeat(0, q - q0))

    if x.ndim > 1:
        raise ValueError('Only implemented for univariate time series')

    if x.dtype not in (np.float32, np.float64):
        x = x.astype(np.float64)
    n = len(x)

    if len(order) != 3 or any(o < 0 or not isinstance(o, int) for o in order):
        raise ValueError(f'order must be 3 non-negative integers, got {order}')
    if 'order' not in seasonal:
        raise ValueError('order must be a key in seasonal')
    if len(seasonal['order']) != 3 or any(o < 0 or not isinstance(o, int) for o in seasonal['order']):
        raise ValueError('order must be 3 non-negative integers')

    if seasonal['period'] is None or seasonal['period'] == 0:
        warnings.warn('Setting seasonal period to 1')
        seasonal['period'] = 1

    # arma = (p, q, P, Q, period, d, D)
    arma = (*order[::2],
            *seasonal['order'][::2],
            seasonal['period'],
            order[1],
            seasonal['order'][1])
    narma = sum(arma[:4])

    # Differencing polynomial, built by repeated convolution.
    Delta = np.array([1.])
    for i in range(order[1]):
        Delta = tsconv(Delta, np.array([1., -1.]))

    for i in range(seasonal['order'][1]):
        Delta = tsconv(Delta, np.array([1] + [0]*(seasonal['period'] - 1) + [-1]))
    Delta = - Delta[1:]
    nd = order[1] + seasonal['order'][1]
    n_used = (~np.isnan(x)).sum() - len(Delta)

    if xreg is None:
        ncxreg = 0
    else:
        if xreg.shape[0] != n:
            raise Exception('lengths of `x` and `xreg` do not match')

        ncxreg = xreg.shape[1]

    nmxreg = [f'ex_{i+1}' for i in range(ncxreg)]
    if include_mean and (nd == 0):
        intercept = np.ones(n, dtype=np.float64).reshape(-1, 1)
        if xreg is None:
            xreg = intercept
        else:
            # The intercept is an extra *column* of the regressor matrix.
            xreg = np.concatenate([intercept, xreg], axis=1)
        ncxreg += 1
        nmxreg = ['intercept'] + nmxreg

    # check nas for method CSS-ML
    if method == 'CSS-ML':
        anyna = np.isnan(x).any()
        if ncxreg:
            anyna |= np.isnan(xreg).any()
        if anyna:
            method = 'ML'
    if method.startswith('CSS'):
        ncond = order[1] + seasonal['order'][1] * seasonal['period']
        ncond1 = order[0] + seasonal['order'][0] * seasonal['period']
        ncond = ncond + ncond1
    else:
        ncond = 0

    if fixed is None:
        fixed = np.full(narma + ncxreg, np.nan)
    else:
        if len(fixed) != narma + ncxreg:
            raise Exception('wrong length for `fixed`')
    # True where a coefficient is free (to be estimated).
    mask = np.isnan(fixed)

    no_optim = not mask.any()

    if no_optim:
        transform_pars = False

    if transform_pars:
        ind = arma[0] + arma[1] + np.arange(arma[2])
        # check masks and more
        if any(~mask[np.arange(arma[0])]) or any(~mask[ind]):
            warnings.warn('some AR parameters were fixed: setting transform_pars = False')
            transform_pars = False

    init0 = np.zeros(narma)
    parscale = np.ones(narma)

    # xreg processing
    if ncxreg:
        cn = nmxreg
        orig_xreg = (ncxreg == 1) | (~mask[narma + np.arange(ncxreg)]).any()
        if not orig_xreg:
            # Rotate the regressors onto their right singular vectors; the
            # coefficients are rotated back after the fit.
            _, _, vt = np.linalg.svd(xreg[(~np.isnan(xreg)).all(1)],
                                     full_matrices=False)
            xreg = np.dot(xreg, vt.T)
        dx = x
        dxreg = xreg
        if order[1] > 0:
            dx = diff(dx, 1, order[1])
            dxreg = diff(dxreg, 1, order[1])
        if seasonal['period'] > 1 and seasonal['order'][1] > 0:
            dx = diff(dx, seasonal['period'], seasonal['order'][1])
            dxreg = diff(dxreg, seasonal['period'], seasonal['order'][1])
        if len(dx) > dxreg.shape[1]:
            # Regression on the differenced data gives the starting values.
            model = sm.OLS(dx, dxreg)
            result = model.fit()
            fit = {'coefs': result.params, 'stderrs': result.bse}
        else:
            raise RuntimeError
        isna = np.isnan(x) | np.isnan(xreg).any(1)
        n_used = (~isna).sum() - len(Delta)
        init0 = np.append(init0, fit['coefs'])
        ses = fit['stderrs']
        parscale = np.append(parscale, 10 * ses)

    if n_used <= 0:
        raise ValueError('Too few non-missing observations')

    if init is not None:
        if len(init) != len(init0):
            raise ValueError(f'init should have length {len(init0)}')
        # Copy so that the caller's array is never written to.
        init = np.array(init, dtype=np.float64)
        nan_mask = np.isnan(init)
        if nan_mask.any():
            init[nan_mask] = init0[nan_mask]
        if method == 'ML':
            # check stationarity
            if arma[0] > 0:
                if not arCheck(init[:arma[0]]):
                    raise ValueError('non-stationary AR part')
            if arma[2] > 0:
                if not arCheck(init[sum(arma[:2]) + np.arange(arma[2])]):
                    raise ValueError('non-stationary seasonal AR part')
            if transform_pars:
                init = ARIMA_invtrans(init, arma)
    else:
        init = init0

    def arma_css_op(p):
        """Objective for the CSS fit at the free parameters ``p``."""
        # Fixed coefficients come from `coef`, free ones from `p`.
        par = coef.copy()
        par[mask] = p
        phi, theta = arima_transpar(par, arma, False)
        y = x
        if ncxreg > 0:
            y = x - np.dot(xreg, par[narma + np.arange(ncxreg)])
        res, _ = arima_css(y, arma, phi, theta, ncond)

        return 0.5 * np.log(res)

    coef = np.array(fixed)
    # parscale definition, think about it, scipy doesnt use it
    if method == 'CSS':
        if no_optim:
            res = OptimResult(True, 0, np.array([]), arma_css_op(np.array([])))
        else:
            res = minimize(arma_css_op, init[mask], method=optim_method, tol=tol, options=optim_control)

        if res.status > 0:
            warnings.warn(
                f'possible convergence problem: minimize gave code {res.status}]'
            )

        coef[mask] = res.x
        phi, theta = arima_transpar(coef, arma, False)
        mod = make_arima(phi, theta, Delta, kappa)
        x_adj = x
        if ncxreg > 0:
            # Out-of-place: `x` may be the caller's array.
            x_adj = x - np.dot(xreg, coef[narma + np.arange(ncxreg)])
        val = arima_css(x_adj, arma, phi, theta, ncond)
        sigma2 = val[0]
        var = None if no_optim else res.hess_inv / n_used
    else:
        if method == 'CSS-ML':
            # A CSS fit provides the starting values for ML.
            if no_optim:
                res = OptimResult(True, 0, np.array([]), arma_css_op(np.array([])))
            else:
                res = minimize(arma_css_op, init[mask], method=optim_method, tol=tol, options=optim_control)
            init[mask] = res.x
            if arma[0] > 0:
                if not arCheck(init[:arma[0]]):
                    raise ValueError('non-stationary AR part from CSS')
            if arma[2] > 0:
                if not arCheck(init[sum(arma[:2]) + np.arange(arma[2])]):
                    raise ValueError('non-stationary seasonal AR part from CSS')
            ncond = 0
            if transform_pars:
                init = ARIMA_invtrans(init, arma)
                # enforce invertibility of the MA parts
                if arma[1] > 0:
                    ind = arma[0] + np.arange(arma[1])
                    init[ind] = maInvert(init[ind])
                if arma[3] > 0:
                    ind = np.sum(arma[:3]) + np.arange(arma[3])
                    init[ind] = maInvert(init[ind])
        trarma = arima_transpar(init, arma, transform_pars)
        mod = make_arima(trarma[0], trarma[1], Delta, kappa)
        if no_optim:
            res = OptimResult(True, 0, np.array([]), armafn(np.array([]), x, transform_pars))
        else:
            res = minimize(armafn, init[mask], args=(x, transform_pars,),
                           method=optim_method, tol=tol, options=optim_control)
        coef[mask] = res.x
        if transform_pars:
            if arma[1] > 0:
                ind = arma[0] + np.arange(arma[1])
                if mask[ind].all():
                    coef[ind] = maInvert(coef[ind])
            if arma[3] > 0:
                ind = np.sum(arma[:3]) + np.arange(arma[3])
                if mask[ind].all():
                    coef[ind] = maInvert(coef[ind])
            if any(coef[mask] != res.x):
                # The MA part was inverted: re-run the ML optimiser from the
                # new point so that `hess_inv` refers to it.
                oldcode = res.status
                res = minimize(armafn, coef[mask], args=(x, True,),
                               method=optim_method, tol=tol, options=optim_control)
                res.status = oldcode
                coef[mask] = res.x
            # Delta-method covariance in the untransformed parameterisation.
            A = arima_gradtrans(coef, arma)
            A = A[np.ix_(mask, mask)]
            var = np.dot(A.T, np.dot(res.hess_inv, A)) / n_used
            coef = arima_undopars(coef, arma)
        else:
            var = None if no_optim else res.hess_inv / n_used
        trarma = arima_transpar(coef, arma, False)
        mod = make_arima(trarma[0], trarma[1], Delta, kappa)
        if ncxreg > 0:
            val = arimaSS(x - np.dot(xreg, coef[narma + np.arange(ncxreg)]), mod)
        else:
            val = arimaSS(x, mod)
        val = (val[0], val[3])
        sigma2 = val[0] / n_used

    value = 2 * n_used * res.fun + n_used + n_used * np.log(2 * np.pi)
    aic = value + 2 * sum(mask) + 2 if method != 'CSS' else np.nan

    nm = []
    if arma[0] > 0: nm.extend([f'ar{i+1}' for i in range(arma[0])])
    if arma[1] > 0: nm.extend([f'ma{i+1}' for i in range(arma[1])])
    if arma[2] > 0: nm.extend([f'sar{i+1}' for i in range(arma[2])])
    if arma[3] > 0: nm.extend([f'sma{i+1}' for i in range(arma[3])])
    if ncxreg > 0:
        nm += cn
        if not orig_xreg:
            # Undo the rotation applied to the regressors.
            ind = narma + np.arange(ncxreg)
            coef[ind] = np.dot(vt.T, coef[ind])
            A = np.identity(narma + ncxreg)
            A[np.ix_(ind, ind)] = vt.T
            A = A[np.ix_(mask, mask)]
            var = np.dot(np.dot(A, var), A.T)
    coef = dict(zip(nm, coef))
    resid = val[1]

    ans = {
        'coef': coef,
        'sigma2': sigma2,
        'var_coef': var,
        'mask': mask,
        'loglik': -0.5 * value,
        'aic': aic,
        'arma': arma,
        'residuals': resid,
        #'series': series,
        'code': res.status,
        'n_cond': ncond,
        'nobs': n_used,
        'model': mod
    }
    return ans

# Internal Cell
@njit
def kalman_forecast(n, Z, a, P, T, V, h):
    """Propagate a state-space model ``n`` steps ahead.

    At every step the state is advanced with ``T``, the covariance becomes
    ``T P T' + V`` and the forecast variance is ``h + Z' P Z``. The inputs
    ``a`` and ``P`` are copied first and therefore left unchanged.

    Returns
    -------
    (forecasts, se)
        Two length-``n`` arrays: the forecasts ``Z . a`` and the forecast
        variances (as computed, no square root is taken).
    """
    dim = len(a)

    state = a.copy()
    cov = P.copy()
    t_cov = np.empty((dim, dim))
    forecasts = np.empty(n)
    se = np.empty(n)

    for step in range(n):
        state = T @ state
        forecasts[step] = state @ Z

        # t_cov = T @ cov
        for i in range(dim):
            for j in range(dim):
                acc = 0.
                for k in range(dim):
                    acc += T[i, k] * cov[k, j]
                t_cov[i, j] = acc

        # cov = V + t_cov @ T', accumulating the forecast variance on the fly;
        # `cov` can be overwritten because only `t_cov` is read from here on.
        variance = h
        for i in range(dim):
            for j in range(dim):
                acc = V[i, j]
                for k in range(dim):
                    acc += t_cov[i, k] * T[j, k]
                cov[i, j] = acc
                variance += Z[i] * Z[j] * acc
        se[step] = variance

    return forecasts, se

# Internal Cell
def checkarima(obj):
    """Return True when any coefficient variance yields a NaN standard error.

    ``obj['var_coef']`` is the coefficient covariance matrix; a ``None``
    value (no covariance available) counts as "no problem" and gives False.
    """
    var_coef = obj['var_coef']
    if var_coef is None:
        return False
    std_errors = np.sqrt(np.diag(var_coef))
    return any(np.isnan(std_errors))

# Internal Cell
def myarima(
    x,
    order=(0, 0, 0),
    seasonal={'order': (0, 0, 0), 'period': 1},
    constant=True,
    ic='aic',
    trace=False,
    approximation=False,
    offset=0,
    xreg=None,
    method=None,
    **kwargs
):
    """Fit one candidate ARIMA model and attach information criteria.

    Parameters
    ----------
    x : 1-D float array
        The series (NaN marks missing values).
    order : tuple
        ``(p, d, q)``.
    seasonal : dict
        ``{'order': (P, D, Q), 'period': m}``.
    constant : bool
        Include a drift term (exactly one difference in total) or a mean
        (forwarded as ``include_mean`` otherwise).
    ic : {'aic', 'aicc', 'bic'}
        Criterion copied to ``fit['ic']``.
    trace : bool
        Print the fit.
    approximation : bool
        Selects ``'CSS'`` instead of ``'CSS-ML'`` when ``method`` is None.
    offset : float
        Added to the approximate AIC computed for CSS fits.
    xreg : 2-D float array or None
        External regressors.
    method : str or None
        Estimation method forwarded to ``arima``.

    Returns
    -------
    dict
        The result of ``arima`` extended with ``'aic'``, ``'bic'``,
        ``'aicc'``, ``'ic'`` and ``'xreg'``. ``'ic'`` is infinite when the
        criterion is unavailable, a root of the AR or MA polynomial is too
        close to the unit circle, or the coefficient variances are invalid.
        When fitting raises ``ValueError`` the result is ``{'ic': inf}``.
    """
    missing = np.isnan(x)
    nonmissing_idxs = np.where(~missing)[0]
    firstnonmiss = nonmissing_idxs.min()
    lastnonmiss = nonmissing_idxs.max()
    # Number of observed values between the first and last observed one,
    # both ends included.
    n = np.sum(~missing[firstnonmiss:lastnonmiss + 1])
    m = seasonal['period']
    seas_order = seasonal['order']
    use_season = np.sum(seas_order) > 0 and m > 0
    diffs = order[1] + seas_order[1]
    if method is None:
        if approximation:
            method = 'CSS'
        else:
            method = 'CSS-ML'
    try:
        if diffs == 1 and constant:
            # NOTE(review): a user-supplied `xreg` is replaced by the drift
            # column here rather than combined with it -- confirm intent.
            xreg = np.arange(1, x.size + 1, dtype=np.float64).reshape(-1, 1)  # drift
            if use_season:
                fit = arima(x, order, seasonal, xreg, method=method)
            else:
                fit = arima(x, order, xreg=xreg, method=method)
        else:
            if use_season:
                fit = arima(
                    x, order, seasonal, include_mean=constant, method=method, xreg=xreg
                )
            else:
                fit = arima(x, order, include_mean=constant, method=method, xreg=xreg)
        nstar = n - order[1] - seas_order[1] * m
        if diffs == 1 and constant:
            fit['xreg'] = xreg
        # Free coefficients plus the innovation variance.
        npar = fit['mask'].sum() + 1
        if method == 'CSS':
            fit['aic'] = offset + nstar * math.log(fit['sigma2']) + 2 * npar
        if not math.isnan(fit['aic']):
            fit['bic'] = fit['aic'] + npar * (math.log(nstar) - 2)
            fit['aicc'] = fit['aic'] + 2 * npar * (npar + 1) / (nstar - npar - 1)
            fit['ic'] = fit[ic]
        else:
            fit['ic'] = fit['aic'] = fit['bic'] = fit['aicc'] = math.inf
        # Residuals of missing observations are NaN and must be skipped.
        fit['sigma2'] = np.nansum(fit['residuals']**2) / (nstar - npar + 1)
        minroot = 2
        if order[0] + seas_order[0] > 0:
            testvec = fit['model']['phi']
            k = abs(testvec) > 1e-8
            # Number of leading coefficients up to and including the last
            # non-negligible one.
            if k.sum() > 0:
                last_nonzero = np.max(np.where(k)[0]) + 1
            else:
                last_nonzero = 0
            if last_nonzero > 0:
                testvec = testvec[:last_nonzero]
                # AR polynomial: 1 - phi_1 z - ... - phi_p z^p
                proots = np.polynomial.polynomial.polyroots(np.append(1, -testvec))
                if proots.size > 0:
                    minroot = min(minroot, *abs(proots))
        if order[2] + seas_order[2] > 0 and fit['ic'] < math.inf:
            testvec = fit['model']['theta']
            k = abs(testvec) > 1e-8
            if np.sum(k) > 0:
                last_nonzero = np.max(np.where(k)[0]) + 1
            else:
                last_nonzero = 0
            if last_nonzero > 0:
                testvec = testvec[:last_nonzero]
                # MA polynomial: 1 + theta_1 z + ... + theta_q z^q
                proots = np.polynomial.polynomial.polyroots(np.append(1, testvec))
                if proots.size > 0:
                    minroot = min(minroot, *abs(proots))
        # Reject models with a root close to the unit circle or with invalid
        # coefficient variances.
        if minroot < 1 + 0.1 or checkarima(fit):
            fit['ic'] = math.inf
        if trace:
            print(fit)
        fit['xreg'] = xreg
        return fit
    except ValueError:
        # A model that cannot be estimated is reported as infinitely bad so
        # that a search over candidates can simply skip it.
        return {'ic': math.inf}

# Internal Cell
def search_arima(
    x,
    d=0,
    D=0,
    max_p=5,
    max_q=5,
    max_P=2,
    max_Q=2,
    max_order=5,
    stationary=False,
    ic='aic',
    trace=False,
    approximation=False,
    xreg=None,
    offset=None,
    allow_drift=True,
    allow_mean=True,
    parallel=False,
    num_cores=2,
    period=1,
    **kwargs
):
    """Exhaustively search ARIMA orders and return the fit with the lowest IC.

    Every combination ``(p, q, P, Q)`` with ``0 <= p <= max_p``,
    ``0 <= q <= max_q``, ``0 <= P <= max_P`` and ``0 <= Q <= max_Q`` whose sum
    does not exceed ``max_order`` is fitted with ``myarima``; the fit with the
    smallest ``fit['ic']`` wins.

    Parameters
    ----------
    x : array-like
        Series handed to ``myarima``.
    d, D : int
        Non-seasonal and seasonal differencing orders used for every model.
    max_p, max_q, max_P, max_Q : int
        Inclusive upper bounds of the searched orders.
    max_order : int
        Largest admissible value of ``p + q + P + Q``.
    ic, trace, approximation, xreg : optional
        Forwarded to ``myarima``.
    offset : float, optional
        Forwarded to ``myarima`` when it is not None.
    parallel : bool
        Only ``False`` is supported.
    period : int
        Seasonal period placed in the ``seasonal`` dict of each candidate.
    **kwargs
        ``method`` is forwarded to ``myarima`` when present; other entries are
        ignored.

    ``stationary``, ``allow_drift``, ``allow_mean`` and ``num_cores`` are
    accepted for interface compatibility but do not influence the search.

    Returns
    -------
    dict
        The best fit as returned by ``myarima``.

    Raises
    ------
    NotImplementedError
        If ``parallel`` is true.
    ValueError
        If no candidate produced a finite information criterion.
    """
    if parallel:
        raise NotImplementedError('parallel=True')

    # Optional settings are forwarded only when supplied so that myarima's own
    # defaults keep applying otherwise.
    fit_options = {
        'ic': ic,
        'trace': trace,
        'approximation': approximation,
        'xreg': xreg,
    }
    if offset is not None:
        fit_options['offset'] = offset
    if 'method' in kwargs:
        fit_options['method'] = kwargs['method']

    best_ic = np.inf
    best_fit = None
    # The bounds are inclusive: with max_P == max_Q == 0 (non-seasonal data)
    # the purely non-seasonal models must still be visited.
    for i in range(max_p + 1):
        for j in range(max_q + 1):
            for I in range(max_P + 1):
                for J in range(max_Q + 1):
                    if i + j + I + J > max_order:
                        continue
                    fit = myarima(
                        x,
                        order=(i, d, j),
                        seasonal={'order': (I, D, J), 'period': period},
                        **fit_options
                    )
                    if fit['ic'] < best_ic:
                        best_ic = fit['ic']
                        best_fit = fit
    if best_fit is None:
        raise ValueError('No suitable ARIMA model found')
    return best_fit

# Internal Cell
def Arima(
    x,
    order=(0, 0, 0),
    seasonal={'order': (0, 0, 0), 'period': 1},
    xreg=None,
    include_mean=True,
    include_drift=False,
    include_constant=None,
    blambda=None,
    biasadj=False,
    method='CSS',
    model=None,
):
    """Fit an ARIMA model with ``arima`` and attach extra summary statistics.

    Parameters
    ----------
    x : array-like
        Series to fit.
    order : tuple
        Non-seasonal ``(p, d, q)``.
    seasonal : dict
        ``{'order': (P, D, Q), 'period': m}``. The default dict is only read,
        never modified.
    xreg : numpy.ndarray, optional
        Float regressors. Replaced by a drift column when a drift is fitted.
    include_mean, include_drift : bool
        Whether to fit a mean / a linear drift term.
    include_constant : bool, optional
        When given, overrides ``include_mean`` and ``include_drift``: a true
        value enables the mean, plus the drift when exactly one difference is
        taken; a false value disables both.
    blambda : optional
        Box-Cox parameter; when given ``x`` is transformed with ``boxcox``.
    biasadj : bool
        Stored on ``blambda`` as the ``biasadj`` attribute if it has none.
    method : str
        Fitting method forwarded to ``arima``.
    model : optional
        Re-fitting an existing model is not implemented.

    Returns
    -------
    dict
        The ``arima`` result with ``aicc``, ``bic``, ``xreg``, ``lambda``,
        ``x`` and ``sigma2`` added or replaced.

    Raises
    ------
    ValueError
        If ``xreg`` is not a float array or the series is too short for the
        requested differencing.
    NotImplementedError
        If ``model`` is given.
    """
    origx = x
    seas_order = seasonal['order']
    if blambda is not None:
        x = boxcox(x, blambda)
        # NOTE(review): setattr raises on plain Python/NumPy numbers; confirm
        # which type `blambda` is expected to have.
        if not hasattr(blambda, 'biasadj'):
            setattr(blambda, 'biasadj', biasadj)
    if xreg is not None:
        if xreg.dtype not in (np.float32, np.float64):
            raise ValueError('xreg should be a float array')
    if len(x) <= order[1]:
        raise ValueError('Not enough data to fit the model')
    if len(x) <= order[1] + seas_order[1] * seasonal['period']:
        raise ValueError('Not enough data to fit the model')
    if include_constant is not None:
        if include_constant:
            include_mean = True
            if order[1] + seas_order[1] == 1:
                include_drift = True
        else:
            include_mean = include_drift = False
    if order[1] + seas_order[1] > 1 and include_drift:
        warnings.warn("No drift term fitted as the order of difference is 2 or more.")
        include_drift = False
    if model is not None:
        # The original left this branch empty and then failed on an unbound
        # local; fail explicitly instead.
        raise NotImplementedError('Refitting an existing model is not supported')
    if include_drift:
        xreg = np.arange(1, x.size + 1, dtype=np.float64).reshape(-1, 1)  # drift
    if xreg is None:
        tmp = arima(x, order, seasonal, include_mean=include_mean, method=method)
    else:
        tmp = arima(
            x, order, seasonal, xreg, include_mean, method=method,
        )
    npar = np.sum(tmp['mask']) + 1
    missing = np.isnan(tmp['residuals'])
    nonmiss_idxs = np.where(~missing)[0]
    firstnonmiss = np.min(nonmiss_idxs)
    lastnonmiss = np.max(nonmiss_idxs)
    # `lastnonmiss` is the index of the last observed residual, so the slice
    # has to extend one past it to count that observation too.
    n = np.sum(~missing[firstnonmiss:lastnonmiss + 1])
    # Effective sample size after non-seasonal and seasonal differencing.
    nstar = n - tmp['arma'][5] - tmp['arma'][6] * tmp['arma'][4]
    tmp['aicc'] = tmp['aic'] + 2*npar*(nstar / (nstar - npar - 1) - 1)
    tmp['bic'] = tmp['aic'] + npar*(math.log(nstar) - 2)
    tmp['xreg'] = xreg
    tmp['lambda'] = blambda
    tmp['x'] = origx
    tmp['sigma2'] = np.sum(tmp['residuals']**2) / (nstar - npar + 1)
    return tmp

# Internal Cell
def is_constant(x):
    """Return whether every element of ``x`` equals its first element.

    An empty input is reported as constant instead of raising ``IndexError``.
    Arrays containing NaN are never constant because NaN compares unequal to
    itself.
    """
    if len(x) == 0:
        return True
    return np.all(x[0] == x)

# Internal Cell
def mstl(x, period, blambda=None, s_window=7 + 4 * np.arange(1, 7)):
    """Split a series into trend, optional seasonal part and remainder.

    Parameters
    ----------
    x : numpy.ndarray
        Series to decompose. For a two-dimensional array only the first
        column is used.
    period : int
        Seasonal period. Values above 1 use ``statsmodels`` STL; otherwise the
        trend comes from ``supersmoother`` and no seasonal part is produced.
    blambda : optional
        Box-Cox parameter; the transformation is not implemented yet.
    s_window : sequence of int
        Only the first entry is used, as the STL ``seasonal`` smoother length.

    Returns
    -------
    pandas.DataFrame
        Columns ``data``, ``trend``, ``seasonal`` (only when ``period > 1``)
        and ``remainder``.

    Raises
    ------
    ImportError
        If ``period <= 1`` and ``supersmoother`` is not installed.
    """
    if x.ndim == 2:
        x = x[:, 0]
    # Captured after the column selection so that the NaN mask and the `data`
    # column below have the same one-dimensional shape as the components.
    origx = x
    n = len(x)
    msts = period
    if np.isnan(x).any():
        ...  # na.interp
    if blambda is not None:
        ...  # boxcox
    if msts > 1:
        fit = sm.tsa.STL(x, period=msts, seasonal=s_window[0]).fit()
        seas = fit.seasonal
        deseas = x - seas
        trend = fit.trend
    else:
        try:
            from supersmoother import SuperSmoother
        except ImportError as e:
            raise ImportError(
                'supersmoother is required for mstl with period=1'
            ) from e
        msts = None
        deseas = x
        t = 1 + np.arange(n)
        trend = SuperSmoother().fit(t, x).predict(t)
    # Keep missing observations missing in the remainder.
    deseas[np.isnan(origx)] = np.nan
    remainder = deseas - trend
    output = {'data': origx, 'trend': trend}
    if msts is not None:
        output['seasonal'] = seas
    output['remainder'] = remainder
    return pd.DataFrame(output)

# Internal Cell
def seas_heuristic(x, period):
    """Measure the seasonal strength of ``x`` from its ``mstl`` decomposition.

    The strength is ``1 - var(remainder) / var(remainder + seasonal)`` clipped
    to ``[0, 1]``. ``nan`` is returned when the decomposition has no
    ``seasonal`` column.
    """
    decomposition = mstl(x, period)
    resid = decomposition['remainder']
    resid_var = np.var(resid, ddof=1)
    if 'seasonal' not in decomposition:
        return math.nan
    deseasoned_var = np.var(resid + decomposition['seasonal'], ddof=1)
    return max(0, min(1, 1 - resid_var / deseasoned_var))

# Internal Cell
def nsdiffs(x, test='seas', alpha=0.05, period=1, max_D=1, **kwargs):
    """Estimate the number of seasonal differences needed for ``x``.

    A seasonal difference is taken while ``seas_heuristic`` reports a seasonal
    strength above 0.64, up to ``max_D`` times.

    Parameters
    ----------
    x : numpy.ndarray
        Series to test.
    test : str
        Name of the seasonal test. ``'hegy'`` and ``'ch'`` are not
        implemented; the value is otherwise only used in messages.
    alpha : float
        Significance level, clamped to ``[0.01, 0.1]`` (and to 0.05 for
        ``test='ocsb'``).
    period : int or float
        Seasonal period; non-integers are rounded before differencing.
    max_D : int
        Maximum number of seasonal differences.
    **kwargs
        Ignored.

    Returns
    -------
    int
        Number of seasonal differences.

    Raises
    ------
    NotImplementedError
        For ``test`` in ``('hegy', 'ch')``.
    ValueError
        If the series is not constant and ``period == 1``.
    """
    D = 0
    if alpha < 0.01:
        warnings.warn(
            "Specified alpha value is less than the minimum, setting alpha=0.01"
        )
        alpha = 0.01
    elif alpha > 0.1:
        warnings.warn(
            "Specified alpha value is larger than the maximum, setting alpha=0.1"
        )
        alpha = 0.1
    # Only complain when the requested level actually differs from 5%.
    if test == 'ocsb' and alpha != 0.05:
        warnings.warn(
            "Significance levels other than 5% are not currently supported by test='ocsb', defaulting to alpha = 0.05."
        )
        alpha = 0.05
    if test in ('hegy', 'ch'):
        raise NotImplementedError
    if is_constant(x):
        return D
    if period == 1:
        raise ValueError('Non seasonal data')
    elif period < 1:
        warnings.warn(
            "I can't handle data with frequency less than 1. Seasonality will be ignored."
        )
        return 0
    if period >= len(x):
        return 0

    def run_tests(x, test, alpha):
        # Falls back to "no difference" when the heuristic fails.
        try:
            needs_diff = seas_heuristic(x, period) > 0.64
            if needs_diff not in (0, 1):
                raise ValueError(f'Found {needs_diff} in seasonal test.')
        except Exception as e:
            warnings.warn(
                f"The chosen seasonal unit root test encountered an error when testing for the {D} difference.\n"
                f"From {test}(): {e}\n"
                f"{D} seasonal differences will be used. Consider using a different unit root test."
            )
            needs_diff = 0
        return needs_diff

    dodiff = run_tests(x, test, alpha)
    # NumPy integers are integral too and must not trigger the rounding warning.
    if dodiff and not isinstance(period, (int, np.integer)):
        warnings.warn(
            "The time series frequency has been rounded to support seasonal differencing."
        )
        period = round(period)
    while dodiff and D < max_D:
        D += 1
        x = diff(x, period, 1)
        if is_constant(x):
            return D
        if len(x) >= 2*period and D < max_D:
            dodiff = run_tests(x, test, alpha)
        else:
            dodiff = False
    return D

# Internal Cell
def ndiffs(x, alpha=0.05, test='kpss', kind='level', max_d=2):
    """Estimate the number of first differences needed to make ``x`` stationary.

    The series is differenced while the KPSS test (constant-only regression)
    has a p-value below ``alpha``, up to ``max_d`` times.

    Parameters
    ----------
    x : numpy.ndarray
        Series to test; missing values are removed first.
    alpha : float
        Significance level, clamped to ``[0.01, 0.1]``.
    test : str
        Only used in the warning message; the KPSS test is always run.
    kind : str
        Currently unused.
    max_d : int
        Maximum number of differences.

    Returns
    -------
    int
        Number of differences; 0 for an empty, all-missing or constant series.
    """
    x = x[~np.isnan(x)]
    d = 0
    if alpha < 0.01:
        warnings.warn(
            "Specified alpha value is less than the minimum, setting alpha=0.01"
        )
        alpha = 0.01
    elif alpha > 0.1:
        warnings.warn(
            "Specified alpha value is larger than the maximum, setting alpha=0.1"
        )
        alpha = 0.1
    # Nothing is left to test when every value was missing.
    if len(x) == 0 or is_constant(x):
        return d

    def run_tests(x, test, alpha):
        # Returns True when the KPSS test rejects stationarity; a failing test
        # is reported and treated as "no difference needed".
        try:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                nlags = math.floor(3 * math.sqrt(len(x)) / 13)
                reject = sm.tsa.kpss(x, 'c', nlags=nlags)[1] < alpha
        except Exception as e:
            warnings.warn(
                f"The chosen unit root test encountered an error when testing for the {d} difference.\n"
                f"From {test}(): {e}\n"
                f"{d} differences will be used. Consider using a different unit root test."
            )
            reject = False
        return reject

    dodiff = run_tests(x, test, alpha)
    while dodiff and d < max_d:
        d += 1
        # The first element of the differenced series is dropped.
        x = diff(x, 1, 1)[1:]
        if is_constant(x):
            return d
        dodiff = run_tests(x, test, alpha)
    return d

# Internal Cell
def newmodel(p, d, q, P, D, Q, constant, results):
    """Tell whether a model specification is absent from ``results``.

    The first seven columns of each row of ``results`` hold
    ``(p, d, q, P, D, Q, constant)``; the candidate is new when no row matches
    it in all seven entries.
    """
    candidate = np.array([p, d, q, P, D, Q, constant])
    row_matches = np.all(results[:, :7] == candidate, axis=1)
    return not np.any(row_matches)

# Cell
def auto_arima_f(
    x,
    d=None,
    D=None,
    max_p=5,
    max_q=5,
    max_P=2,
    max_Q=2,
    max_order=5,
    max_d=2,
    max_D=1,
    start_p=2,
    start_q=2,
    start_P=1,
    start_Q=1,
    stationary=False,
    seasonal=True,
    ic='aicc',
    stepwise=True,
    nmodels=94,
    trace=False,
    approximation=None,
    method=None,
    truncate=None,
    xreg=None,
    test='kpss',
    test_kwargs=None,
    seasonal_test='seas',
    seasonal_test_kwargs=None,
    allowdrift=True,
    allowmean=True,
    blambda=None,
    biasadj=False,
    parallel=False,
    num_cores=2,
    period=1,
):
    """Select and fit an ARIMA model for a univariate series.

    The differencing orders are settled first (``nsdiffs`` for ``D`` and
    ``ndiffs`` for ``d`` unless they are given). The ARMA orders are then
    searched, either stepwise around a starting model or exhaustively through
    ``search_arima``, ranking candidates by ``ic``.

    Parameters
    ----------
    x : numpy.ndarray
        One-dimensional series. Leading missing values are dropped.
    d, D : int, optional
        Non-seasonal / seasonal differencing orders; estimated when None.
    max_p, max_q, max_P, max_Q : int
        Upper bounds of the searched orders. They are further capped by a
        third of the series length (per season for the seasonal orders) and,
        for seasonal data, ``max_p`` and ``max_q`` by ``period - 1``.
    max_order : int
        Passed to ``search_arima`` as the largest ``p + q + P + Q``.
    max_d, max_D : int
        Passed to ``ndiffs`` / ``nsdiffs`` as differencing limits.
    start_p, start_q, start_P, start_Q : int
        Orders of the first model of the stepwise search.
    stationary : bool
        If true, ``d`` and ``D`` are forced to 0.
    seasonal : bool
        If false, the period is treated as 1.
    ic : str
        Key of the information criterion used for ranking; replaced by
        ``'aic'`` when the series has at most three observations.
    stepwise : bool
        Use the stepwise search instead of ``search_arima``.
    nmodels : int
        Number of rows of the table recording stepwise fits.
    trace : bool
        Print progress information.
    approximation : bool, optional
        Defaults to ``len(x) > 150 or period > 12``.
    method : str, optional
        Forwarded to the fitting routines.
    truncate : int, optional
        With ``approximation``, only the last ``truncate`` observations are
        used.
    xreg : numpy.ndarray, optional
        Two-dimensional float regressors.
    test, test_kwargs
        Forwarded to ``ndiffs``.
    seasonal_test, seasonal_test_kwargs
        Forwarded to ``nsdiffs``.
    allowdrift, allowmean : bool
        Allow a drift (exactly one difference) / a mean (no differences).
    blambda, biasadj
        Box-Cox parameter and its bias-adjustment flag.
    parallel, num_cores
        Parallel fitting is not supported by the stepwise search.
    period : int
        Seasonal period.

    Returns
    -------
    dict
        The selected fit.

    Raises
    ------
    ValueError
        For multivariate or all-missing input, non-float or rank-deficient
        ``xreg``, too little data after differencing, or when no suitable
        model is found.
    """
    if approximation is None:
        approximation = len(x) > 150 or period > 12
    if stepwise and parallel:
        warnings.warn("Parallel computer is only implemented when stepwise=FALSE, the model will be fit in serial.")
        parallel = False
    if trace and parallel:
        warnings.warn("Tracing model searching in parallel is not supported.")
        trace = False
    if x.ndim > 1:
        raise ValueError("auto_arima can only handle univariate time series")
    if test_kwargs is None:
        test_kwargs = {}
    if seasonal_test_kwargs is None:
        seasonal_test_kwargs = {}
    x = x.copy()
    origx = x
    missing = np.isnan(x)
    nonmissing_idxs = np.where(~missing)[0]
    if nonmissing_idxs.size == 0:
        raise ValueError('all data are missing')
    firstnonmiss = nonmissing_idxs.min()
    lastnonmiss = nonmissing_idxs.max()
    # The slice must include the last observed value itself.
    series_len = np.sum(~missing[firstnonmiss:lastnonmiss + 1])
    x = x[firstnonmiss:]
    if xreg is not None:
        if xreg.dtype not in (np.float32, np.float64):
            raise ValueError('xreg should be a float array')
        xreg = xreg[firstnonmiss:]
    if is_constant(x):
        # `Arima` accepts no `fixed` argument, so the mean is estimated by the
        # fit instead of being pinned to the sample mean.
        if allowmean:
            fit = Arima(x, order=(0, 0, 0))
        else:
            fit = Arima(x, order=(0, 0, 0), include_mean=False)
        fit['x'] = origx
        fit['constant'] = True
        return fit
    m = period if seasonal else 1
    if m < 1:
        m = 1
    else:
        m = round(m)
    max_p = min(max_p, series_len // 3)
    max_q = min(max_q, series_len // 3)
    max_P = min(max_P, math.floor(series_len / 3 / m))
    max_Q = min(max_Q, math.floor(series_len / 3 / m))
    if series_len <= 3:
        ic = 'aic'
    if blambda is not None:
        x = boxcox(x, blambda)
        # NOTE(review): setattr raises on plain Python/NumPy numbers; confirm
        # which type `blambda` is expected to have.
        setattr(blambda, 'biasadj', biasadj)
    if xreg is not None:
        # The differencing orders are estimated on the residuals of a
        # regression of x on the non-constant regressors.
        xregg = xreg
        xx = x.copy()
        constant_columns = np.array([is_constant(col) for col in xregg.T])
        if constant_columns.all():
            xregg = None
        else:
            if constant_columns.any():
                xregg = xregg[:, ~constant_columns]
            # Intercept column plus regressors, complete rows only.
            design = np.hstack([np.ones((xregg.shape[0], 1)), xregg])
            complete = design[~np.isnan(design).any(1)]
            sv = np.linalg.svd(complete, compute_uv=False)
            if sv.min() / sv.sum() < np.finfo(np.float64).eps:
                raise ValueError('xreg is rank deficient')
            j = (~np.isnan(x)) & (~np.isnan(xregg).any(1))
            xx[j] = sm.OLS(x[j], design[j]).fit().resid
    else:
        xx = x
        xregg = None
    if stationary:
        d = D = 0
    if m == 1:
        D = max_P = max_Q = 0
    elif D is None and len(xx) <= 2 * m:
        D = 0
    elif D is None:
        D = nsdiffs(xx, period=m, test=seasonal_test, max_D=max_D, **seasonal_test_kwargs)
        if D > 0 and xregg is not None:
            diffxreg = diff(xregg, m, D)
            # Back off if differencing turns a regressor into a constant.
            if any(is_constant(col) for col in diffxreg.T):
                D -= 1
        if D > 0:
            dx = diff(xx, m, D)
            if np.isnan(dx).all():
                D -= 1
    if D > 0:
        dx = diff(xx, m, D)
    else:
        dx = xx
    if xregg is not None:
        if D > 0:
            diffxreg = diff(xregg, m, D)
        else:
            diffxreg = xregg
    if d is None:
        d = ndiffs(dx, test=test, max_d=max_d, **test_kwargs)
        if d > 0 and xregg is not None:
            diffxreg = diff(diffxreg, 1, d)
            if any(is_constant(col) for col in diffxreg.T):
                d -= 1
        if d > 0:
            diffdx = diff(dx, 1, d)
            if np.isnan(diffdx).all():
                d -= 1
    if D >= 2:
        warnings.warn("Having more than one seasonal differences is not recommended. Please consider using only one seasonal difference.")
    elif D + d > 2:
        warnings.warn("Having 3 or more differencing operations is not recommended. Please consider reducing the total number of differences.")
    if d > 0:
        dx = diff(dx, 1, d)
    if len(dx) == 0:
        raise ValueError('not enough data to proceed')
    elif is_constant(dx):
        # The differenced series is constant: fit the pure differencing model.
        # As above, no `fixed` value is passed because `Arima` has no such
        # argument.
        if xreg is None:
            if D > 0 and d == 0:
                fit = Arima(
                    x,
                    order=(0, d, 0),
                    seasonal={'order': (0, D, 0), 'period': m},
                    include_constant=True,
                    method=method,
                )
            elif D > 0 and d > 0:
                fit = Arima(
                    x,
                    order=(0, d, 0),
                    seasonal={'order': (0, D, 0), 'period': m},
                    method=method,
                )
            elif d == 2:
                fit = Arima(x, order=(0, d, 0), method=method)
            elif d < 2:
                fit = Arima(
                    x,
                    order=(0, d, 0),
                    include_constant=True,
                    method=method,
                )
            else:
                raise ValueError("Data follow a simple polynomial and are not suitable for ARIMA modelling.")
        else:
            if D > 0:
                fit = Arima(
                    x,
                    order=(0, d, 0),
                    seasonal={'order': (0, D, 0), 'period': m},
                    xreg=xreg,
                    method=method
                )
            else:
                fit = Arima(x, order=(0, d, 0), xreg=xreg, method=method)
        fit['x'] = origx
        return fit
    if m > 1:
        if max_p > 0:
            max_p = min(max_p, m - 1)
        if max_q > 0:
            max_q = min(max_q, m - 1)
    if approximation:
        if truncate is not None:
            if len(x) > truncate:
                x = x[-truncate:]
        # The offset calibrates the approximate criterion against the
        # likelihood of the pure differencing model; fall back to 0 if that
        # model cannot be fitted.
        try:
            if D == 0:
                fit = arima(x, order=(0, d, 0), xreg=xreg)
            else:
                fit = arima(
                    x,
                    order=(0, d, 0),
                    seasonal={'order': (0, D, 0), 'period': m},
                    xreg=xreg
                )
            offset = -2*fit['loglik'] - series_len*math.log(fit['sigma2'])
        except Exception:
            offset = 0
    else:
        offset = 0
    allowdrift = allowdrift and (d + D) == 1
    allowmean = allowmean and (d + D) == 0
    constant = allowdrift or allowmean
    if approximation and trace:
        print('Fitting models using approximations to speed things up')
    if not stepwise:
        bestfit = search_arima(
            x,
            d,
            D,
            max_p,
            max_q,
            max_P,
            max_Q,
            max_order,
            stationary,
            ic,
            trace,
            approximation,
            method=method,
            xreg=xreg,
            offset=offset,
            allow_drift=allowdrift,
            allow_mean=allowmean,
            parallel=parallel,
            num_cores=num_cores,
            period=m,
        )
        bestfit['lambda'] = blambda
        bestfit['x'] = origx
        if trace:
            print(f"Best model: arma={bestfit['arma']}")
        return bestfit
    if len(x) < 10:
        start_p = min(start_p, 1)
        start_q = min(start_q, 1)
        start_P = 0
        start_Q = 0
    p = start_p = min(start_p, max_p)
    q = start_q = min(start_q, max_q)
    P = start_P = min(start_P, max_P)
    Q = start_Q = min(start_Q, max_Q)
    # One row per fitted model: (p, d, q, P, D, Q, constant, ic).
    results = np.full((nmodels, 8), np.nan)
    p_myarima = partial(
        myarima,
        x=x,
        constant=constant,
        ic=ic,
        trace=trace,
        approximation=approximation,
        offset=offset,
        xreg=xreg,
        method=method,
    )
    # Starting model.
    bestfit = p_myarima(
        order=(p, d, q),
        seasonal={'order': (P, D, Q), 'period': m},
    )
    results[0] = (p, d, q, P, D, Q, constant, bestfit['ic'])
    # Model without ARMA terms.
    fit = p_myarima(
        order=(0, d, 0),
        seasonal={'order': (0, D, 0), 'period': m},
    )
    results[1] = (0, d, 0, 0, D, 0, constant, fit['ic'])
    if fit['ic'] < bestfit['ic']:
        bestfit = fit
        p = q = P = Q = 0
    # `k` is the index of the last filled row of `results`.
    k = 1
    # First-order AR model.
    if max_p > 0 or max_P > 0:
        p_ = int(max_p > 0)
        P_ = int(m > 1 and max_P > 0)
        fit = p_myarima(
            order=(p_, d, 0),
            seasonal={'order': (P_, D, 0), 'period': m},
        )
        results[k + 1] = (p_, d, 0, P_, D, 0, constant, fit['ic'])
        if fit['ic'] < bestfit['ic']:
            bestfit = fit
            p = p_
            P = P_
            q = Q = 0
        k += 1
    # First-order MA model.
    if max_q > 0 or max_Q > 0:
        q_ = int(max_q > 0)
        Q_ = int(m > 1 and max_Q > 0)
        fit = p_myarima(
            order=(0, d, q_),
            seasonal={'order': (0, D, Q_), 'period': m},
        )
        results[k + 1] = (0, d, q_, 0, D, Q_, constant, fit['ic'])
        if fit['ic'] < bestfit['ic']:
            bestfit = fit
            p = P = 0
            Q = Q_
            q = q_
        k += 1
    # Model without ARMA terms and without constant.
    if constant:
        fit = p_myarima(
            order=(0, d, 0),
            seasonal={'order': (0, D, 0), 'period': m},
            constant=False,
        )
        results[k + 1] = (0, d, 0, 0, D, 0, 0, fit['ic'])
        if fit['ic'] < bestfit['ic']:
            bestfit = fit
            p = q = P = Q = 0
        k += 1

    def try_params(p, d, q, P, D, Q, constant, k, bestfit):
        # Fit one candidate, record it in the next free row and report whether
        # it beats the current best. Nothing is fitted once the table is full.
        k += 1
        improved = False
        if k >= results.shape[0]:
            return k, bestfit, improved
        fit = p_myarima(
            order=(p, d, q),
            seasonal={'order': (P, D, Q), 'period': m},
            # Override the value bound in the partial so that toggling the
            # constant really fits a different model.
            constant=constant,
        )
        results[k] = (p, d, q, P, D, Q, constant, fit['ic'])
        if fit['ic'] < bestfit['ic']:
            bestfit = fit
            improved = True
        return k, bestfit, improved

    # Order changes tried around the current model, first for (P, Q) and then
    # for (p, q): single steps down and up, followed by the diagonal steps.
    moves = ((-1, 0), (0, -1), (1, 0), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1))

    def can_move(value, delta, upper):
        # A step must keep the order inside [0, upper].
        if delta < 0:
            return value > 0
        if delta > 0:
            return value < upper
        return True

    def neighbours(p, q, P, Q, constant):
        # Yield the admissible variations of the current model in search order.
        for dP, dQ in moves:
            if can_move(P, dP, max_P) and can_move(Q, dQ, max_Q):
                yield p, q, P + dP, Q + dQ, constant
        for dp, dq in moves:
            if can_move(p, dp, max_p) and can_move(q, dq, max_q):
                yield p + dp, q + dq, P, Q, constant
        if allowdrift or allowmean:
            yield p, q, P, Q, not constant

    startk = 0
    # Keep going while the previous pass fitted at least one new model and the
    # table still has room.
    while startk < k and k < nmodels:
        startk = k
        for candidate in list(neighbours(p, q, P, Q, constant)):
            new_p, new_q, new_P, new_Q, new_constant = candidate
            # Rows 0..k are filled, so the slice has to end at k + 1.
            if not newmodel(new_p, d, new_q, new_P, D, new_Q, new_constant, results[:k + 1]):
                continue
            k, bestfit, improved = try_params(
                new_p, d, new_q, new_P, D, new_Q, new_constant, k, bestfit
            )
            if improved:
                # Move to the better model and restart the neighbourhood scan.
                p, q, P, Q, constant = candidate
                break
    if k >= nmodels:
        warnings.warn(
            f"Stepwise search was stopped early due to reaching the model number limit: nmodels={nmodels}"
        )
    # Approximate criteria only rank the candidates; re-fit them in that order
    # without approximation and keep the first one that is admissible.
    if approximation and bestfit.get('arma') is not None:
        if trace:
            print("Now re-fitting the best model(s) without approximations...\n")
        icorder = np.argsort(results[:, 7])
        n_fitted = np.sum(~np.isnan(results[:, 7]))
        for idx in icorder[:n_fitted]:
            p, q, P, Q, constant = map(int, results[idx, [0, 2, 3, 5, 6]])
            fit = myarima(
                x,
                (p, d, q),
                {'order': (P, D, Q), 'period': m},
                constant=results[idx, 6],
                ic=ic,
                trace=trace,
                approximation=False,
                method=method,
                xreg=xreg,
            )
            if fit['ic'] < math.inf:
                bestfit = fit
                break
    if math.isinf(bestfit['ic']) and method != 'CSS':
        raise ValueError('No suitable ARIMA model found')
    return bestfit

# Cell
def predict_arima(model, n_ahead, newxreg=None, se_fit=True):
    """Forecast ``n_ahead`` steps from a fitted ARIMA model.

    Parameters
    ----------
    model : dict
        Fitted model. The keys ``'arma'``, ``'coef'``, ``'model'`` and (when
        ``se_fit``) ``'sigma2'`` are read.
    n_ahead : int
        Number of steps to forecast.
    newxreg : numpy.ndarray, optional
        Future regressor values, one row per step. Required when the model has
        regression coefficients other than the intercept.
    se_fit : bool
        Also return the standard errors of the forecasts.

    Returns
    -------
    numpy.ndarray or tuple of numpy.ndarray
        The forecasts, or ``(forecasts, standard_errors)`` when ``se_fit``.

    Raises
    ------
    ValueError
        If regression coefficients are present but no regressor values are
        available for them.
    """
    arma = model['arma']
    coef_names = list(model['coef'].keys())
    coef_values = list(model['coef'].values())
    # Number of ARMA coefficients; anything beyond belongs to the regression.
    narma = sum(arma[:4])
    if len(coef_values) > narma:
        if coef_names[narma] == "intercept":
            intercept = np.ones((n_ahead, 1), dtype=np.float64)
            if newxreg is None:
                newxreg = intercept
            else:
                # The intercept is an extra column, not extra rows.
                newxreg = np.concatenate([intercept, newxreg], axis=1)
        if newxreg is None:
            raise ValueError(
                "newxreg is required: the model has regression coefficients "
                "other than the intercept"
            )
        xm = np.matmul(newxreg, np.asarray(coef_values[narma:])).flatten()
    else:
        xm = 0

    # NOTE: the invertibility warnings for the MA and seasonal MA parts that
    # the R implementation emits have not been ported.

    pred, se = kalman_forecast(n_ahead, *(model['model'][var] for var in ['Z', 'a', 'P', 'T', 'V', 'h']))
    pred += xm
    if se_fit:
        se = np.sqrt(se * model['sigma2'])
        return pred, se

    return pred