# -*- coding: utf-8 -*-
"""
seasonal

@author: Colin
"""
import random
import numpy as np
import pandas as pd
from autots.tools.lunar import moon_phase
from autots.tools.window_functions import sliding_window_view


def seasonal_int(include_one: bool = False, small=False, very_small=False):
    """Draw a random integer from a weighted pool of common seasonal periods.

    Args:
        include_one (bool): if False, a draw of 1 is discarded and redrawn
        small (bool): if True, cap the result at 364
        very_small (bool): if True, redraw until the result is no greater than 30

    Returns:
        int: the selected seasonal period
    """
    # candidate period -> relative sampling weight (-1 means "pick any 2..100")
    weights = {
        -1: 0.1,  # random int
        1: 0.05,  # previous day
        2: 0.1,
        4: 0.05,  # quarters
        7: 0.15,  # week
        10: 0.01,
        12: 0.1,  # months
        24: 0.1,  # months or hours
        28: 0.1,  # days in month to weekday
        60: 0.05,
        96: 0.04,  # quarter in days
        168: 0.01,
        364: 0.1,  # year to weekday
        1440: 0.01,
        420: 0.01,
        52: 0.01,
        84: 0.01,
    }
    (choice,) = random.choices(list(weights), weights=list(weights.values()), k=1)
    if choice == -1:
        choice = random.randint(2, 100)
    if choice == 1 and not include_one:
        # 1 is not wanted: draw again with the same constraints
        choice = seasonal_int(
            include_one=include_one, small=small, very_small=very_small
        )
    if small:
        choice = min(choice, 364)
    if very_small:
        # the nested call already enforces the limit, so this loops at most once
        while choice > 30:
            choice = seasonal_int(include_one=include_one, very_small=very_small)
    return int(choice)


# Names of the feature sets that date_part() can build. Also consulted by
# create_seasonality_feature() to route a seasonality name to date_part().
date_part_methods = [
    "recurring",
    "simple",
    "expanded",
    "simple_2",
    "simple_3",
    "lunar_phase",
    "simple_binarized",
    "simple_binarized_poly",
    "expanded_binarized",
    "common_fourier",
    "common_fourier_rw",
]


def date_part(
    DTindex,
    method: str = 'simple',
    set_index: bool = True,
    polynomial_degree: int = None,
):
    """Create date part columns from pd.DatetimeIndex.

    Args:
        DTindex (pd.DatetimeIndex): datetime index to provide dates
        method (str): name of the feature set to build. Any name containing
            "_poly" has that suffix removed and polynomial_degree forced to 2.
            Unrecognized names fall back to "simple".
            simple - just day, year, month, weekday
            expanded - simple plus all other available features
            recurring - all features that should commonly repeat without aging
            simple_2 - month, day, weekday, weekend, scaled epoch
            simple_3 - one-hot month and weekday, weekend, quarter, julian date
            lunar_phase - simple_3 plus a moon phase column
            simple_binarized - one-hot month and weekday, day, weekend, julian date
            expanded_binarized - one-hot month, weekday, day, weekdayofmonth,
                plus weekend, quarter, julian date
            common_fourier - fourier terms chosen from the apparent frequency
            common_fourier_rw - common_fourier plus a compressed epoch column
        set_index (bool): if True, return DTindex as index of df
        polynomial_degree (int): add this degree of sklearn polynomial features if not None

    Returns:
        pd.Dataframe with DTindex
    """
    if "_poly" in method:
        method = method.replace("_poly", "")
        polynomial_degree = 2
    if method == 'recurring':
        date_part_df = pd.DataFrame(
            {
                'month': DTindex.month,
                'day': DTindex.day,
                'weekday': DTindex.weekday,
                'weekend': (DTindex.weekday > 4).astype(int),
                'hour': DTindex.hour,
                'quarter': DTindex.quarter,
                'midyear': (
                    (DTindex.dayofyear > 74) & (DTindex.dayofyear < 258)
                ).astype(
                    int
                ),  # 2 season
            }
        )
    elif method == "simple_2":
        date_part_df = pd.DataFrame(
            {
                'month': DTindex.month,
                'day': DTindex.day,
                'weekday': DTindex.weekday,
                'weekend': (DTindex.weekday > 4).astype(int),
                'epoch': pd.to_numeric(
                    DTindex, errors='coerce', downcast='integer'
                ).values
                / 100000000000,
            }
        )
    elif method in ["simple_3", "lunar_phase"]:
        # trying to *prevent* it from learning holidays for this one
        date_part_df = pd.DataFrame(
            {
                'month': pd.Categorical(
                    DTindex.month, categories=list(range(1, 13)), ordered=True
                ),
                'weekday': pd.Categorical(
                    DTindex.weekday, categories=list(range(7)), ordered=True
                ),
                'weekend': (DTindex.weekday > 4).astype(int),
                'quarter': DTindex.quarter,
                'epoch': DTindex.to_julian_date(),
            }
        )
        date_part_df = pd.get_dummies(date_part_df, columns=['month', 'weekday'])
        if method == "lunar_phase":
            date_part_df['phase'] = moon_phase(DTindex)
    elif "simple_binarized" in method:
        date_part_df = pd.DataFrame(
            {
                'month': pd.Categorical(
                    DTindex.month, categories=list(range(1, 13)), ordered=True
                ),
                'weekday': pd.Categorical(
                    DTindex.weekday, categories=list(range(7)), ordered=True
                ),
                'day': DTindex.day,
                'weekend': (DTindex.weekday > 4).astype(int),
                'epoch': DTindex.to_julian_date(),
            }
        )
        date_part_df = pd.get_dummies(date_part_df, columns=['month', 'weekday'])
    # the needle must be on the left: `method in "expanded_binarized"` would
    # also match the plain "expanded" method (a substring) and hijack it
    elif "expanded_binarized" in method:
        date_part_df = pd.DataFrame(
            {
                'month': pd.Categorical(
                    DTindex.month, categories=list(range(1, 13)), ordered=True
                ),
                'weekday': pd.Categorical(
                    DTindex.weekday, categories=list(range(7)), ordered=True
                ),
                'day': pd.Categorical(
                    DTindex.day, categories=list(range(1, 32)), ordered=True
                ),
                'weekdayofmonth': pd.Categorical(
                    (DTindex.day - 1) // 7 + 1,
                    categories=list(range(1, 6)),
                    ordered=True,
                ),
                'weekend': (DTindex.weekday > 4).astype(int),
                'quarter': DTindex.quarter,
                'epoch': DTindex.to_julian_date(),
            }
        )
        date_part_df = pd.get_dummies(
            date_part_df, columns=['month', 'weekday', 'day', 'weekdayofmonth']
        )
    elif method in ["common_fourier", "common_fourier_rw"]:
        seasonal_list = []
        DTmin = DTindex.min()
        DTmax = DTindex.max()
        # years covered per observation, used as a rough frequency detector
        # less than one year of data is always going to be an issue
        seasonal_ratio = (DTmax.year - DTmin.year + 1) / len(DTindex)
        # hourly
        if seasonal_ratio < 0.001:  # 0.00011 to 0.00023
            t = DTindex - pd.Timestamp("2030-01-01")
            # NOTE(review): only whole days and the minutes component are used;
            # the hours component is not added, so timestamps that differ only
            # by hour share the same t - confirm this is intended
            t = (t.days * 24) + (t.components['minutes'] / 60)
            # add hourly, weekly, yearly
            seasonal_list.append(fourier_series(t, p=8766, n=10))
            seasonal_list.append(fourier_series(t, p=24, n=3))
            seasonal_list.append(fourier_series(t, p=168, n=5))
            # interactions
            seasonal_list.append(
                fourier_series(t, p=168, n=5) * fourier_series(t, p=24, n=5)
            )
            seasonal_list.append(
                fourier_series(t, p=168, n=3) * fourier_series(t, p=8766, n=3)
            )
        # daily (+ business day)
        elif seasonal_ratio < 0.012:  # 0.0027 to 0.0055
            t = (DTindex - pd.Timestamp("2030-01-01")).days
            # add yearly and weekly seasonality
            seasonal_list.append(fourier_series(t, p=365.25, n=10))
            seasonal_list.append(fourier_series(t, p=7, n=3))
            # interaction
            # NOTE(review): both factors use p=7, so this is the weekly series
            # times itself, unlike the hourly branch which pairs two different
            # periods - confirm whether p=365.25 was intended for one factor
            seasonal_list.append(
                fourier_series(t, p=7, n=5) * fourier_series(t, p=7, n=5)
            )
        # weekly
        elif seasonal_ratio < 0.05:  # 0.019 to 0.038
            t = (DTindex - pd.Timestamp("2030-01-01")).days
            seasonal_list.append(fourier_series(t, p=365.25, n=10))
            seasonal_list.append(fourier_series(t, p=28, n=3))
        # monthly
        elif seasonal_ratio < 0.5:  # 0.083 to 0.154
            t = (DTindex - pd.Timestamp("2030-01-01")).days
            seasonal_list.append(fourier_series(t, p=365.25, n=3))
            seasonal_list.append(fourier_series(t, p=1461, n=10))
        # yearly
        else:
            t = (DTindex - pd.Timestamp("2030-01-01")).days
            seasonal_list.append(fourier_series(t, p=1461, n=10))
        date_part_df = pd.DataFrame(np.concatenate(seasonal_list, axis=1)).rename(
            columns=lambda x: "seasonalitycommonfourier_" + str(x)
        )
        if method == "common_fourier_rw":
            date_part_df['epoch'] = (DTindex.to_julian_date() ** 0.65).astype(int)
    else:
        # method == "simple" (also the base columns for "expanded")
        date_part_df = pd.DataFrame(
            {
                'year': DTindex.year,
                'month': DTindex.month,
                'day': DTindex.day,
                'weekday': DTindex.weekday,
            }
        )
        if method == 'expanded':
            # fall back to the older .week accessor if isocalendar() fails
            try:
                weekyear = DTindex.isocalendar().week.to_numpy()
            except Exception:
                weekyear = DTindex.week
            date_part_df2 = pd.DataFrame(
                {
                    'hour': DTindex.hour,
                    'week': weekyear,
                    'quarter': DTindex.quarter,
                    'dayofyear': DTindex.dayofyear,
                    'midyear': (
                        (DTindex.dayofyear > 74) & (DTindex.dayofyear < 258)
                    ).astype(
                        int
                    ),  # 2 season
                    'weekend': (DTindex.weekday > 4).astype(int),
                    'weekdayofmonth': (DTindex.day - 1) // 7 + 1,
                    'month_end': (DTindex.is_month_end).astype(int),
                    'month_start': (DTindex.is_month_start).astype(int),
                    "quarter_end": (DTindex.is_quarter_end).astype(int),
                    'year_end': (DTindex.is_year_end).astype(int),
                    'daysinmonth': DTindex.daysinmonth,
                    'epoch': pd.to_numeric(
                        DTindex, errors='coerce', downcast='integer'
                    ).values
                    - 946684800000000000,
                    'us_election_year': (DTindex.year % 4 == 0).astype(int),
                }
            )
            date_part_df = pd.concat([date_part_df, date_part_df2], axis=1)
    if polynomial_degree is not None:
        # imported lazily so sklearn is only needed when polynomial features are used
        from sklearn.preprocessing import PolynomialFeatures

        date_part_df = pd.DataFrame(
            PolynomialFeatures(polynomial_degree, include_bias=False).fit_transform(
                date_part_df
            )
        )
        date_part_df.columns = ['dp' + str(x) for x in date_part_df.columns]
    if set_index:
        date_part_df.index = DTindex
    return date_part_df


def fourier_series(t, p=365.25, n=10):
    """Build cosine and sine fourier terms for the time values in t.

    Args:
        t (array-like): 1-D sequence of time values
        p (float): period, in the same units as t
        n (int): number of harmonics (orders 1 through n)

    Returns:
        np.ndarray of shape (len(t), 2 * n): the n cosine columns followed by
        the n sine columns
    """
    # angular frequency of each harmonic: 2 pi k / p for k = 1..n
    harmonics = 2 * np.pi * np.arange(1, n + 1) / p
    # one row per time value, one column per harmonic
    angles = harmonics * np.asarray(t)[:, None]
    return np.concatenate((np.cos(angles), np.sin(angles)), axis=1)


def create_seasonality_feature(DTindex, t, seasonality, history_days=None):
    """Build a DataFrame of features for a single seasonality specification.

    For consistency, every result has a default range index, not a date index.

    Args:
        DTindex (pd.DatetimeIndex): dates to build features for
        t (array-like): time values passed to fourier_series; only used when
            seasonality is numeric
        seasonality (int, float, or str): a number produces 10 orders of fourier
            terms with period seasonality / history_days; a string selects one of
            "dayofweek", "month", "weekend", "weekdayofmonth", "hour",
            "daysinmonth", "quarter", or any name in date_part_methods
        history_days (int): divisor for numeric seasonality; defaults to the
            number of days between the first and last date in DTindex

    Returns:
        pd.DataFrame of seasonality features

    Raises:
        ValueError: if seasonality is not recognized
    """
    # fourier orders
    if isinstance(seasonality, (int, float)):
        if history_days is None:
            history_days = (DTindex.max() - DTindex.min()).days
        return pd.DataFrame(
            fourier_series(np.asarray(t), seasonality / history_days, n=10)
        ).rename(columns=lambda x: f"seasonality{seasonality}_" + str(x))
    # dateparts
    elif seasonality == "dayofweek":
        return pd.get_dummies(
            pd.Categorical(DTindex.weekday, categories=list(range(7)), ordered=True)
        ).rename(columns=lambda x: f"{seasonality}_" + str(x))
    elif seasonality == "month":
        return pd.get_dummies(
            pd.Categorical(DTindex.month, categories=list(range(1, 13)), ordered=True)
        ).rename(columns=lambda x: f"{seasonality}_" + str(x))
    elif seasonality == "weekend":
        return pd.DataFrame((DTindex.weekday > 4).astype(int), columns=["weekend"])
    elif seasonality == "weekdayofmonth":
        return pd.get_dummies(
            pd.Categorical(
                (DTindex.day - 1) // 7 + 1,
                categories=list(range(1, 6)),
                ordered=True,
            )
        ).rename(columns=lambda x: f"{seasonality}_" + str(x))
    elif seasonality == "hour":
        # DatetimeIndex.hour runs 0..23; categories must match or midnight
        # would be encoded as an all-zero row
        return pd.get_dummies(
            pd.Categorical(DTindex.hour, categories=list(range(24)), ordered=True)
        ).rename(columns=lambda x: f"{seasonality}_" + str(x))
    elif seasonality == "daysinmonth":
        return pd.DataFrame({'daysinmonth': DTindex.daysinmonth})
    elif seasonality == "quarter":
        return pd.get_dummies(
            pd.Categorical(DTindex.quarter, categories=list(range(1, 5)), ordered=True)
        ).rename(columns=lambda x: f"{seasonality}_" + str(x))
    elif seasonality in date_part_methods:
        return date_part(DTindex, method=seasonality, set_index=False)
    else:
        # raise (not return) so a bad name fails here rather than downstream
        raise ValueError(f"Seasonality `{seasonality}` not recognized")


def seasonal_window_match(
    DTindex, k, window_size, forecast_length, datepart_method, distance_metric
):
    """Find the k historical windows whose date features best match the latest window.

    Args:
        DTindex (pd.DatetimeIndex): dates to search over
        k (int): number of best-matching windows to select
        window_size (int): number of rows in each compared window
        forecast_length (int): number of positions to return after each window
        datepart_method (str): method passed to date_part to build the features
        distance_metric (str): "mae", "mqae", or "mse"

    Returns:
        tuple: (positions, scores). positions is an integer array of shape
        (forecast_length, k) holding the row positions that follow each selected
        window; when k > 5, positions at or past len(DTindex) are replaced by -1.
        scores holds the distance of every candidate window, averaged over the
        window axis (assumes sliding_window_view puts the window axis last, as
        numpy's version does - TODO confirm).

    Raises:
        ValueError: if distance_metric is not recognized
    """
    features = date_part(DTindex, method=datepart_method).to_numpy()

    # when k is larger, can be more aggressive on allowing a longer portion into view
    min_k = 5
    n_tail = min(window_size, forecast_length) if k > min_k else forecast_length
    # candidate windows, excluding the tail of the data
    windows = sliding_window_view(features[:-n_tail, :], window_size, axis=0)
    # the most recent window, transposed to line up with the candidates
    last_window = features[-window_size:, :].T

    # distance between every candidate window and the most recent window
    if distance_metric == "mae":
        scores = np.mean(np.abs(windows - last_window), axis=2)
    elif distance_metric == "mqae":
        quantile = 0.85
        abs_err = np.abs(windows - last_window)
        if abs_err.shape[2] <= 1:
            kept = abs_err
        else:
            # keep only the smallest errors, dropping the largest ones
            n_keep = max(int(abs_err.shape[2] * quantile), 1)
            kept = np.partition(abs_err, n_keep, axis=2)[..., :n_keep]
        scores = np.mean(kept, axis=2)
    elif distance_metric == "mse":
        scores = np.mean((windows - last_window) ** 2, axis=2)
    else:
        raise ValueError(f"distance_metric: {distance_metric} not recognized")

    # starting positions of the k windows with the smallest average score
    best_starts = np.argpartition(scores.mean(axis=1), k - 1, axis=0)[:k]
    # take the period starting AFTER the window
    steps = np.broadcast_to(
        np.arange(0, forecast_length)[..., None],
        (forecast_length, best_starts.shape[0]),
    )
    positions = steps + best_starts + window_size
    # for data over the end, fill last value
    if k > min_k:
        positions = np.where(positions >= len(DTindex), -1, positions)
    return positions, scores
