# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05-price-moe.ipynb (unless otherwise specified).

# Public names exported by `from <module> import *`
__all__ = [
    'construct_dispatchable_lims_df',
    'construct_pred_mask_df',
    'AxTransformer',
    'set_ticks',
    'set_date_ticks',
    'construct_df_pred',
    'construct_pred_ts',
    'calc_error_metrics',
    'get_model_pred_ts',
    'weighted_mean_s',
]

# Cell
import json
import pandas as pd
import numpy as np

import pickle
import scipy
from sklearn import linear_model
from sklearn.metrics import r2_score
from collections.abc import Iterable

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from ipypb import track
from IPython.display import JSON

from moepy import lowess, eda
from .surface import PicklableFunction

# Cell
def construct_dispatchable_lims_df(s_dispatchable, rolling_w=3, daily_quantiles=None):
    """Identifies the rolling limits to be used in masking

    Parameters
    ----------
    s_dispatchable : pd.Series
        Time-indexed series (resampled to daily buckets here).
    rolling_w : int
        Width of the rolling-mean window in weeks (the window is `rolling_w*7` days).
    daily_quantiles : list of float, optional
        Quantiles computed for each day; defaults to `[0.001, 0.999]`.

    Returns
    -------
    pd.DataFrame
        One row per day (midnight, tz-naive index) and one column per quantile.
        Holds the rolling mean of the daily quantiles, with leading/trailing
        gaps back-/forward-filled. The final day is dropped.
    """
    # None sentinel instead of a shared mutable list default
    if daily_quantiles is None:
        daily_quantiles = [0.001, 0.999]

    df_dispatchable_lims = (s_dispatchable
                            .resample('1D')
                            .quantile(daily_quantiles)
                            .unstack()
                            .rolling(rolling_w*7)
                            .mean()
                            .bfill()
                            .ffill()
                            # drop the last day - presumably because it may be incomplete (TODO confirm)
                            .iloc[:-1, :]
                           )

    # Normalise the index to plain midnight dates (drops any time, timezone and freq info)
    df_dispatchable_lims.index = pd.to_datetime(df_dispatchable_lims.index.strftime('%Y-%m-%d'))

    return df_dispatchable_lims

def construct_pred_mask_df(df_pred, df_dispatchable_lims):
    """Constructs a DataFrame mask for the prediction

    Parameters
    ----------
    df_pred : pd.DataFrame
        Prediction surface with numeric row labels (x values) and one column per day.
    df_dispatchable_lims : pd.DataFrame
        Daily limits indexed by day; column 0 is the lower and column 1 the upper bound.

    Returns
    -------
    pd.DataFrame
        Boolean mask with the same rows as `df_pred` and one datetime column for
        each day in `df_dispatchable_lims`. A cell is True when the row's x value
        lies strictly between that day's lower and upper limit.
        The caller's `df_pred` is not modified.
    """
    # Restrict (and order) the surface's columns to the days that have limits
    df_pred = df_pred[df_dispatchable_lims.index]

    # Broadcast (n_x, 1) x-values against (1, n_days) limits -> (n_x, n_days)
    x_vals = np.asarray(df_pred.index.values)[:, np.newaxis]
    lower_lims = np.asarray(df_dispatchable_lims.iloc[:, 0].values)[np.newaxis, :]
    upper_lims = np.asarray(df_dispatchable_lims.iloc[:, 1].values)[np.newaxis, :]

    df_pred_mask = pd.DataFrame(
        (x_vals > lower_lims) & (x_vals < upper_lims),
        index=df_pred.index,
        columns=pd.to_datetime(df_pred.columns),
    )

    return df_pred_mask

# Cell
class AxTransformer:
    """Helper class for cleaning axis tick locations and labels

    Fits a linear regression that maps the values written in an axis' current
    tick labels to the positions of those ticks. Arbitrary values (in label
    units) can then be converted to axis positions with `transform`.
    """
    def __init__(self, datetime_vals=False):
        # datetime_vals: if truthy, tick values are parsed as dates via pd.to_datetime
        self.datetime_vals = datetime_vals
        self.lr = linear_model.LinearRegression()

    def process_tick_vals(self, tick_vals):
        """Converts tick values (a scalar, an iterable or label strings) to a numeric array

        Dates are returned as int64 nanoseconds since the epoch; everything else
        is returned as float64.
        """
        if not isinstance(tick_vals, Iterable) or isinstance(tick_vals, str):
            tick_vals = [tick_vals]

        if self.datetime_vals:
            # Normalise to ns before casting. `astype(int)` is platform-dependent
            # (int32 on Windows, which pandas refuses for datetimes), and the
            # resolution of parsed dates may differ between pandas versions.
            dt_vals = pd.to_datetime(tick_vals).values
            return np.asarray(dt_vals, dtype='datetime64[ns]').astype('int64')

        # Matplotlib renders negative labels with U+2212 (unicode minus), which
        # float() cannot parse; scikit-learn also rejects arrays of strings.
        tick_vals = [val.replace('\u2212', '-') if isinstance(val, str) else val for val in tick_vals]

        return np.asarray(tick_vals, dtype=float)

    def fit(self, ax, axis='x'):
        """Fits the label-value -> tick-location mapping from the current ticks of `ax`

        `axis` is 'x' or 'y'. The label texts must already be populated
        (NOTE(review): they may be empty until the figure is drawn - confirm).
        """
        axis = getattr(ax, f'get_{axis}axis')()

        tick_locs = axis.get_ticklocs()
        tick_vals = self.process_tick_vals([label.get_text() for label in axis.get_ticklabels()])

        self.lr.fit(tick_vals.reshape(-1, 1), tick_locs)

    def transform(self, tick_vals):
        """Maps values given in label units to axis tick locations using the fitted regression"""
        tick_vals = self.process_tick_vals(tick_vals)
        tick_locs = self.lr.predict(np.array(tick_vals).reshape(-1, 1))

        return tick_locs

def set_ticks(ax, tick_locs, tick_labels=None, axis='y'):
    """Sets ticks at standard numerical locations

    `tick_locs` are expressed in the units shown by the axis' current tick
    labels. An AxTransformer fitted on those labels converts them to axis
    positions. `tick_labels` defaults to `tick_locs` themselves.
    Returns the same `ax`.
    """
    if tick_labels is not None:
        labels = tick_labels
    else:
        labels = tick_locs

    label_to_position = AxTransformer()
    label_to_position.fit(ax, axis=axis)
    positions = label_to_position.transform(tick_locs)

    set_axis_ticks = getattr(ax, f'set_{axis}ticks')
    set_axis_ticklabels = getattr(ax, f'set_{axis}ticklabels')
    set_axis_ticks(positions)
    set_axis_ticklabels(labels)

    ax.tick_params(axis=axis, which='both', bottom=True, top=False, labelbottom=True)

    return ax

def set_date_ticks(ax, start_date, end_date, axis='y', date_format='%Y-%m-%d', **date_range_kwargs):
    """Sets ticks at datetime locations

    Places a tick at every date of
    `pd.date_range(start_date, end_date, **date_range_kwargs)`, labelled using
    `date_format`. Positions come from an AxTransformer (datetime mode) fitted
    on the axis' existing date labels. Returns the same `ax`.
    """
    tick_dates = pd.date_range(start_date, end_date, **date_range_kwargs)
    tick_labels = tick_dates.strftime(date_format)

    date_to_position = AxTransformer(datetime_vals=True)
    date_to_position.fit(ax, axis=axis)
    tick_positions = date_to_position.transform(tick_dates)

    set_axis_ticks = getattr(ax, f'set_{axis}ticks')
    set_axis_ticklabels = getattr(ax, f'set_{axis}ticklabels')
    set_axis_ticks(tick_positions)
    set_axis_ticklabels(tick_labels)

    ax.tick_params(axis=axis, which='both', bottom=True, top=False, labelbottom=True)

    return ax

# Cell
def construct_df_pred(model_fp, x_pred=np.linspace(-2, 61, 631), dt_pred=pd.date_range('2009-01-01', '2020-12-31', freq='1D')):
    """Constructs the prediction surface for the specified pre-fitted model

    Parameters
    ----------
    model_fp : str or path-like
        Path to a pickled model exposing `predict(x_pred=..., dt_pred=...)`.
    x_pred : array-like
        Values used for the surface's rows.
    dt_pred : pd.DatetimeIndex
        Dates used for the surface's columns.

    Returns
    -------
    pd.DataFrame
        The model's prediction, with the row index rounded to 1 decimal place.
    """
    # NOTE: unpickling can execute arbitrary code - only load model files from trusted sources.
    # The context manager ensures the file handle is closed (the original leaked it).
    with open(model_fp, 'rb') as model_file:
        smooth_dates = pickle.load(model_file)

    df_pred = smooth_dates.predict(x_pred=x_pred, dt_pred=dt_pred)
    # Round row labels to 1 d.p. so lookups via round(val, 1) (see construct_pred_ts) hit exact labels
    df_pred.index = np.round(df_pred.index, 1)

    return df_pred

# Cell
def construct_pred_ts(s, df_pred):
    """Uses the time-adaptive LOWESS surface to generate time-series prediction

    For each (timestamp, value) pair in `s`, reads the surface cell at row
    `round(value, 1)` and column `timestamp.strftime('%Y-%m-%d')`.
    Raises KeyError if a rounded value or a date is missing from `df_pred`.

    Returns a float64 Series aligned with `s.index`.
    """
    # Series.iteritems() was removed in pandas 2.0; items() works across versions.
    # Values are collected first and the Series built once, avoiding a .loc write per element.
    pred_vals = [
        df_pred.loc[round(val, 1), dt_idx.strftime('%Y-%m-%d')]
        for dt_idx, val in track(s.items(), total=s.size)
    ]

    s_pred_ts = pd.Series(pred_vals, index=s.index, dtype='float64')

    return s_pred_ts

# Cell
def calc_error_metrics(s_err, max_err_quantile=1):
    """Calculates several error metrics using the passed error series

    Missing values are dropped. Errors whose absolute value exceeds the
    `max_err_quantile` quantile of absolute errors are excluded (the default
    of 1 keeps everything).

    Returns a dict with the median absolute error, the mean absolute error
    and the root-mean-square error.
    """
    abs_err = s_err.dropna().abs()
    cutoff = abs_err.quantile(max_err_quantile)
    kept_abs_err = abs_err[abs_err <= cutoff]

    # squaring |e| equals squaring e, so RMSE can be computed from the absolute errors
    return {
        'median_abs_err': kept_abs_err.median(),
        'mean_abs_err': kept_abs_err.mean(),
        'root_mean_square_error': np.sqrt((kept_abs_err**2).mean()),
    }

# Cell
def get_model_pred_ts(s, model_fp, s_demand=None, x_pred=np.linspace(-2, 61, 631), dt_pred=pd.date_range('2009-01-01', '2020-12-31', freq='1D')):
    """Constructs the time-series prediction for the specified pre-fitted model

    Loads the prediction surface from `model_fp` and evaluates it along `s`
    (and, when given, along `s_demand`). Both series are first stripped of
    NaNs and trimmed to the surface's date span.

    Returns the prediction for `s`, or a `(pred_s, pred_s_demand)` tuple when
    `s_demand` is provided.
    """
    df_pred = construct_df_pred(model_fp, x_pred=x_pred, dt_pred=dt_pred)

    # Window covers the first surface day up to 23:30 on the last one
    # (presumably the final half-hour of that day - confirm data frequency)
    window_start = df_pred.columns.min()
    window_end = df_pred.columns.max() + pd.Timedelta(hours=23, minutes=30)

    def predict_within_window(series):
        trimmed = series.dropna().loc[window_start:window_end]
        return construct_pred_ts(trimmed, df_pred)

    s_pred_ts = predict_within_window(s)

    if s_demand is None:
        return s_pred_ts

    return s_pred_ts, predict_within_window(s_demand)

# Cell
def weighted_mean_s(s, s_weight=None, dt_rng=pd.date_range('2009-12-01', '2021-01-01', freq='W'), end_dt_delta_days=7):
    """Calculates the weighted average of a series

    For every `start_dt` in `dt_rng`, the values of `s` from `start_dt` to
    `start_dt + end_dt_delta_days` days are averaged. They are weighted by the
    same slice of `s_weight`, or left unweighted when `s_weight` is None.

    Label slicing includes both endpoints, so with the default weekly `dt_rng`
    and 7-day delta, consecutive windows share their boundary timestamp.

    Windows that contain no data, or whose weights sum to zero, yield NaN
    instead of raising (np.average raises ZeroDivisionError in that case).

    Assumes `s` and `s_weight` have a sorted DatetimeIndex and that their
    slices line up element-for-element - TODO confirm with callers.

    Returns a float64 Series indexed by window start.
    """
    capture_prices = dict()
    window_length = pd.Timedelta(days=end_dt_delta_days)

    for start_dt in dt_rng:
        end_dt = start_dt + window_length
        s_window = s[start_dt:end_dt]
        weights = None if s_weight is None else s_weight[start_dt:end_dt]

        # e.g. when dt_rng extends beyond the data: nothing to average
        if s_window.size == 0 or (weights is not None and weights.sum() == 0):
            capture_prices[start_dt] = np.nan
        else:
            capture_prices[start_dt] = np.average(s_window, weights=weights)

    s_capture_prices = pd.Series(capture_prices, dtype='float64')
    s_capture_prices.index = pd.to_datetime(s_capture_prices.index)

    return s_capture_prices