# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/models.ipynb (unless otherwise specified).

__all__ = ['ScorelinePrediction', 'Outcomes', 'OutcomePrediction', 'scoreline_to_outcome', 'scorelines_to_outcomes',
           'DixonColes']

# Cell
import collections
import dataclasses
import enum
import functools
import itertools
import typing
import warnings

import numpy as np
import scipy.stats
import scipy.optimize

import mezzala.blocks
import mezzala.weights
import mezzala.parameters

# Cell


@dataclasses.dataclass(frozen=True)
class ScorelinePrediction:
    """
    Probability assigned to one exact scoreline.

    Instances are immutable (the dataclass is frozen).
    """
    home_goals: int  # Number of goals for the home side in this scoreline
    away_goals: int  # Number of goals for the away side in this scoreline
    probability: float  # Probability assigned to this exact scoreline


# Cell


class Outcomes(enum.Enum):
    """
    The three possible match results.

    Each member's value is a human-readable label.
    """

    HOME_WIN = 'Home win'
    DRAW = 'Draw'
    AWAY_WIN = 'Away win'

    def __repr__(self):
        # Render as a constructor-style expression built from the member's
        # label, e.g. Outcomes('Draw')
        label = self.value
        return "Outcomes('{}')".format(label)


@dataclasses.dataclass(frozen=True)
class OutcomePrediction:
    """
    Probability assigned to one match outcome.

    Instances are immutable (the dataclass is frozen).
    """
    outcome: Outcomes  # Which result (home win, draw or away win) this refers to
    probability: float  # Probability assigned to that result

# Cell


def scoreline_to_outcome(home_goals, away_goals):
    """
    Map a scoreline onto the match outcome it implies.

    Returns `Outcomes.HOME_WIN`, `Outcomes.DRAW` or `Outcomes.AWAY_WIN`
    depending on how `home_goals` compares with `away_goals`. If none of the
    comparisons hold (e.g. one of the inputs is NaN), `None` is returned.
    """
    if home_goals > away_goals:
        result = Outcomes.HOME_WIN
    elif home_goals == away_goals:
        result = Outcomes.DRAW
    elif home_goals < away_goals:
        result = Outcomes.AWAY_WIN
    else:
        # The inputs are not ordered relative to each other
        result = None
    return result


def scorelines_to_outcomes(scorelines):
    """
    Collapse scoreline predictions into outcome predictions.

    `scorelines` is an iterable of objects exposing `home_goals`, `away_goals`
    and `probability`. It is traversed once per outcome.

    Returns a dict with one entry per member of `Outcomes`, mapping the outcome
    to an `OutcomePrediction` whose probability is the sum of the probabilities
    of every scoreline that results in that outcome.
    """
    def _total_probability(target):
        # Sum the probability of each scoreline that maps onto `target`
        return sum(
            scoreline.probability
            for scoreline in scorelines
            if scoreline_to_outcome(scoreline.home_goals, scoreline.away_goals) == target
        )

    predictions = {}
    for outcome in Outcomes:
        predictions[outcome] = OutcomePrediction(outcome, _total_probability(outcome))
    return predictions

# Cell

# Model blocks used by `DixonColes` when the caller does not supply `blocks`.
# The order of this list is not significant: `DixonColes.blocks` re-sorts the
# blocks by their `PRIORITY` attribute before using them.
_DEFAULT_BLOCKS = [
    mezzala.blocks.BaseRate(),
    mezzala.blocks.HomeAdvantage(),
    mezzala.blocks.TeamStrength(),
]


class DixonColes:
    """
    Dixon-Coles models in Python

    Home and away goals are modelled as Poisson counts whose log-rates are
    linear combinations of the fitted parameters. The `tau` adjustment,
    controlled by the `rho` parameter, modifies the likelihood of the
    0-0, 0-1, 1-0 and 1-1 scorelines.

    The terms entering each rate are supplied by "blocks"; access to the
    fields of a data row goes through the `adapter`.
    """

    def __init__(self, adapter, blocks=None, weight=None, params=None):
        """
        Create a model.

        adapter: object passed to every block, and used to read home/away
            goals from a row (`adapter.home_goals(row)`, `adapter.away_goals(row)`).
        blocks: iterable of model blocks. Defaults to the module-level
            `_DEFAULT_BLOCKS`. The iterable is copied, so later changes to the
            caller's list do not affect the model.
        weight: callable mapping a row to its weight in the likelihood.
            Defaults to a new `mezzala.weights.UniformWeight()`.
        params: optional dict of parameter name -> value. When present, it is
            used to initialise the optimiser in `fit` and directly by the
            predict methods.
        """
        # NOTE: Should params be stored internally as separate lists of keys and values?
        # Then `params` (the dict) can be a property?
        self.params = params
        self.adapter = adapter
        # Defaults are resolved here rather than in the signature so that
        # no mutable/default instance is shared through the function object
        self.weight = mezzala.weights.UniformWeight() if weight is None else weight
        self._blocks = list(_DEFAULT_BLOCKS if blocks is None else blocks)

    def __repr__(self):
        return (
            f'DixonColes(adapter={self.adapter!r}, '
            f'blocks={self.blocks!r}, '
            f'weight={self.weight!r})'
        )

    @property
    def blocks(self):
        """ Returns a new list of the model's blocks, highest `PRIORITY` first """
        # Make sure blocks are always in the correct order
        return sorted(self._blocks, key=lambda x: -x.PRIORITY)

    def home_goals(self, row):
        """ Returns home goals scored """
        return self.adapter.home_goals(row)

    def away_goals(self, row):
        """ Returns away goals scored """
        return self.adapter.away_goals(row)

    def parse_params(self, data):
        """
        Returns a tuple of (parameter_names, [constraints])

        Parameter names are the blocks' parameter keys (in block order)
        followed by the rho key; constraints are gathered from every block.
        """
        blocks = self.blocks
        base_params = [mezzala.parameters.RHO_KEY]
        block_params = list(itertools.chain.from_iterable(
            b.param_keys(self.adapter, data) for b in blocks
        ))
        constraints = list(itertools.chain.from_iterable(
            b.constraints(self.adapter, data) for b in blocks
        ))
        return block_params + base_params, constraints

    def _home_terms(self, row):
        """ Returns a dict of parameter key -> coefficient for the home rate of `row` """
        return dict(itertools.chain.from_iterable(
            b.home_terms(self.adapter, row) for b in self.blocks
        ))

    def _away_terms(self, row):
        """ Returns a dict of parameter key -> coefficient for the away rate of `row` """
        return dict(itertools.chain.from_iterable(
            b.away_terms(self.adapter, row) for b in self.blocks
        ))

    # Core methods

    @staticmethod
    def _assign_params(param_keys, param_values):
        """ Pair parameter names with their values, returning a dict """
        return dict(zip(param_keys, param_values))

    def _create_feature_matrices(self, param_keys, data):
        """
        Create X (feature) matrices for home and away poisson rates

        Both returned arrays have one row per element of `data` and one column
        per element of `param_keys`. A parameter that a row's terms do not
        mention gets a coefficient of 0.
        """
        shape = (len(data), len(param_keys))
        home_X = np.empty(shape)
        away_X = np.empty(shape)
        for row_i, row in enumerate(data):
            home_rate_terms = self._home_terms(row)
            away_rate_terms = self._away_terms(row)
            for param_i, param_key in enumerate(param_keys):
                home_X[row_i, param_i] = home_rate_terms.get(param_key, 0)
                away_X[row_i, param_i] = away_rate_terms.get(param_key, 0)
        return home_X, away_X

    @staticmethod
    def _tau(home_goals, away_goals, home_rate, away_rate, rho):
        """
        Returns the low-score adjustment factor for each scoreline.

        The factor is 1 everywhere except for the 0-0, 0-1, 1-0 and 1-1
        scorelines, where it depends on `rho` (and, for all but 1-1, on the
        rates). `home_goals` and `away_goals` must be arrays; the rates may be
        scalars or arrays that broadcast against them.
        """
        tau = np.ones(len(home_goals))
        tau = np.where((home_goals == 0) & (away_goals == 0), 1 - home_rate*away_rate*rho, tau)
        tau = np.where((home_goals == 0) & (away_goals == 1), 1 + home_rate*rho, tau)
        tau = np.where((home_goals == 1) & (away_goals == 0), 1 + away_rate*rho, tau)
        tau = np.where((home_goals == 1) & (away_goals == 1), 1 - rho, tau)

        return tau

    def _log_like(self, home_goals, away_goals, home_rate, away_rate, rho):
        """
        Returns the log-likelihood of each scoreline: the two Poisson
        log-probabilities plus the log of the `tau` adjustment.
        """
        return (
            scipy.stats.poisson.logpmf(home_goals, home_rate) +
            scipy.stats.poisson.logpmf(away_goals, away_rate) +
            np.log(self._tau(home_goals, away_goals, home_rate, away_rate, rho))
        )

    def objective_fn(self, data, home_goals, away_goals, weights, home_X, away_X, rho_ix, xs):
        """
        Returns the negative weighted log-likelihood for parameter vector `xs`.

        `data` is not used by the calculation (everything needed is passed in
        precomputed); `rho_ix` is the position of rho within `xs`.
        """
        rho = xs[rho_ix]

        # Parameters are estimated in log-space, but `scipy.stats.poisson`
        # expects real number inputs, so we have to use `np.exp`
        home_rate = np.exp(np.dot(home_X, xs))
        away_rate = np.exp(np.dot(away_X, xs))

        log_like = self._log_like(home_goals, away_goals, home_rate, away_rate, rho)
        pseudo_log_like = log_like * weights
        return -np.sum(pseudo_log_like)

    def fit(self, data, **kwargs):
        """
        Fit the model to `data` and store the estimates in `self.params`.

        `data` must support `len()` and repeated iteration. Any extra keyword
        arguments are forwarded to `scipy.optimize.minimize`.

        Returns the model itself.
        """
        param_keys, constraints = self.parse_params(data)

        init_params = (
            # Attempt to initialise parameters from any already-existing parameters
            # This substantially speeds up fitting during (e.g.) backtesting
            np.asarray([self.params.get(p, 0) for p in param_keys])
            # If the model has no parameters, just initialise with 0s
            if self.params
            else np.zeros(len(param_keys))
        )

        # Precalculate the things we can (for speed)

        # Create X (feature) matrices for home and away poisson rates
        home_X, away_X = self._create_feature_matrices(param_keys, data)

        # Get home goals, away goals, and weights from the data
        n_rows = len(data)
        home_goals, away_goals = np.empty(n_rows), np.empty(n_rows)
        weights = np.empty(n_rows)
        for i, row in enumerate(data):
            home_goals[i] = self.home_goals(row)
            away_goals[i] = self.away_goals(row)
            weights[i] = self.weight(row)

        # Get the index of the Rho correlation parameter
        rho_ix = param_keys.index(mezzala.parameters.RHO_KEY)

        # Optimise!
        with warnings.catch_warnings():
            # This is a hack
            # Because we haven't properly constrained `rho`, it's possible for 0 or even negative
            # values of `tau` (and therefore invalid probabilities)
            # Ignoring the warnings has little practical impact, since the model
            # will still find the objective function's minimum point regardless
            warnings.simplefilter('ignore')

            estimate = scipy.optimize.minimize(
                lambda xs: self.objective_fn(data, home_goals, away_goals, weights, home_X, away_X, rho_ix, xs),
                x0=init_params,
                constraints=constraints,
                **kwargs
            )

        # Parse the estimates into parameter map
        self.params = self._assign_params(param_keys, estimate.x)

        return self

    def predict_one(self, row, up_to=26):
        """
        Returns a list of `ScorelinePrediction`, one for every combination of
        home and away goals from 0 to `up_to - 1`.

        The model must already have parameters (from `fit`, or passed to the
        constructor), including the rho parameter.
        """
        scorelines = list(itertools.product(range(up_to), repeat=2))

        home_goals = np.asarray([h for h, _ in scorelines])
        away_goals = np.asarray([a for _, a in scorelines])

        param_keys = list(self.params)
        param_values = np.asarray(list(self.params.values()))

        home_X, away_X = self._create_feature_matrices(param_keys, [row])

        # Single-row matrices give length-1 rate arrays, which broadcast
        # against the scoreline arrays below
        home_rate = np.exp(np.dot(home_X, param_values))
        away_rate = np.exp(np.dot(away_X, param_values))
        rho = self.params[mezzala.parameters.RHO_KEY]

        probs = np.exp(self._log_like(home_goals, away_goals, home_rate, away_rate, rho))

        return [ScorelinePrediction(*vals) for vals in zip(home_goals, away_goals, probs)]

    def predict(self, data, up_to=26):
        """ Returns `predict_one(row, up_to)` for every row of `data`, as a list """
        return [self.predict_one(row, up_to=up_to) for row in data]