# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/blocks.ipynb (unless otherwise specified).

__all__ = ['BaseRate', 'HomeAdvantage', 'TeamStrength', 'KeyBlock', 'ConstantBlock']

# Cell
import abc

import numpy as np

import mezzala.parameters

# Internal Cell


class ModelBlockABC(abc.ABC):
    """
    Common interface shared by all model blocks.

    Every hook below has a do-nothing default that returns a new, empty
    list, so a concrete block only overrides the hooks it contributes to.
    """

    # Class-level default. `TeamStrength` in this module raises it to 1 so
    # that its parameters come first.
    PRIORITY = 0

    def param_keys(self, adapter, data):
        """Return the parameter keys this block adds (none by default)."""
        return list()

    def constraints(self, adapter, data):
        """Return the constraints this block imposes (none by default)."""
        return list()

    def home_terms(self, adapter, data):
        """
        Return this block's terms for the home team (none by default).

        NOTE: the concrete blocks in this module name the second argument
        `row` and pass it to the adapter as a single record.
        """
        return list()

    def away_terms(self, adapter, data):
        """
        Return this block's terms for the away team (none by default).

        NOTE: the concrete blocks in this module name the second argument
        `row` and pass it to the adapter as a single record.
        """
        return list()

# Cell


class BaseRate(ModelBlockABC):
    """
    Model the average goalscoring rate with its own dedicated parameter.

    Having a separate base rate is useful because it leaves both the team
    offence and the team defence parameters centered around 1.0.
    """

    def __init__(self):
        """Create the block; there is nothing to configure."""

    def __repr__(self):
        return 'BaseRate()'

    def param_keys(self, adapter, data):
        """Return the single average-rate key; both arguments are unused."""
        avg_key = mezzala.parameters.AVG_KEY
        return [avg_key]

    def home_terms(self, adapter, row):
        """Return the average-rate key as the only home-side term."""
        avg_key = mezzala.parameters.AVG_KEY
        return [avg_key]

    def away_terms(self, adapter, row):
        """Return the average-rate key as the only away-side term."""
        avg_key = mezzala.parameters.AVG_KEY
        return [avg_key]

# Cell


class HomeAdvantage(ModelBlockABC):
    """
    Model home advantage with a single parameter.

    The same, constant home advantage is assumed to apply to every match
    in the dataset. Only the home side receives the term; the away-side
    terms are the (empty) default inherited from the base class.
    """

    def __init__(self):
        """Create the block; there is nothing to configure."""
        # TODO: allow HFA on/off depending on the data?

    def __repr__(self):
        return 'HomeAdvantage()'

    def param_keys(self, adapter, data):
        """Return the single home-advantage key; both arguments are unused."""
        hfa_key = mezzala.parameters.HFA_KEY
        return [hfa_key]

    def home_terms(self, adapter, row):
        """Return the home-advantage key as the only home-side term."""
        hfa_key = mezzala.parameters.HFA_KEY
        return [hfa_key]

# Cell


class TeamStrength(ModelBlockABC):
    """
    Estimate team offence and team defence parameters.

    Every team that appears in the data, as either the home or the away
    side, gets one offence key and one defence key. A single equality
    constraint forces the (exponentiated) offence parameters to average
    to 1.
    """

    # This is a gross hack so that we know that the
    # team strength parameters come first, and thus can
    # do the constraints (which are positionally indexed)
    PRIORITY = 1

    def __init__(self):
        pass

    def __repr__(self):
        return 'TeamStrength()'

    def _teams(self, adapter, data):
        """
        Return the set of all teams found in `data`.

        The rows are walked exactly once, so `data` may be a one-shot
        iterator: iterating it separately for home and away teams would
        leave the second pass empty and silently drop away-only teams.
        """
        teams = set()
        for row in data:
            teams.add(adapter.home_team(row))
            teams.add(adapter.away_team(row))
        return teams

    def offence_key(self, label):
        """Return the offence parameter key for the team `label`."""
        return mezzala.parameters.OffenceParameterKey(label)

    def defence_key(self, label):
        """Return the defence parameter key for the team `label`."""
        return mezzala.parameters.DefenceParameterKey(label)

    def param_keys(self, adapter, data):
        """
        Return all offence keys followed by all defence keys.

        Both halves iterate the same `teams` object, so a team sits at the
        same position in each half. The offence keys must come first:
        `constraints` addresses them by position.
        """
        teams = self._teams(adapter, data)

        offence = [self.offence_key(t) for t in teams]
        defence = [self.defence_key(t) for t in teams]

        return offence + defence

    def constraints(self, adapter, data):
        """
        Return the equality constraint on the team offence parameters.

        The constraint is a dict with a `'fun'` callable and `'type': 'eq'`
        (presumably the format expected by `scipy.optimize.minimize` --
        confirm against the caller). `'fun'` is zero when the mean of the
        exponentiated first `n_teams` entries of the parameter vector is 1.

        No constraint is returned when the data contains no teams, since
        the mean of an empty slice is NaN rather than a usable value.
        """
        n_teams = len(self._teams(adapter, data))
        if n_teams == 0:
            return []
        return [
            # Force team offence parameters to average to 1
            {'fun': lambda x: 1 - np.mean(np.exp(x[0:n_teams])),
             'type': 'eq'},
        ]

    def home_terms(self, adapter, row):
        """Return the home team's offence key and the away team's defence key."""
        return [
            self.offence_key(adapter.home_team(row)),
            self.defence_key(adapter.away_team(row))
        ]

    def away_terms(self, adapter, row):
        """Return the away team's offence key and the home team's defence key."""
        return [
            self.offence_key(adapter.away_team(row)),
            self.defence_key(adapter.home_team(row))
        ]

# Cell


class KeyBlock(ModelBlockABC):
    """
    Generic model block for adding arbitrary model terms from the data
    to both home and away team.

    `key` is a callable that is applied to a single row; the value it
    returns is used as the parameter key for that row, so it must be
    hashable (the keys are de-duplicated through a set).
    """
    def __init__(self, key):
        self.key = key

    def __repr__(self):
        # Show the key callable: a bare 'KeyBlock()' would suggest a
        # no-argument constructor, which this class does not have.
        return 'KeyBlock({!r})'.format(self.key)

    def param_keys(self, adapter, data):
        """Return the distinct keys produced by `key` over all rows of `data`."""
        return list({self.key(r) for r in data})

    def home_terms(self, adapter, row):
        """Return the key of `row` as the only home-side term."""
        return [self.key(row)]

    def away_terms(self, adapter, row):
        """Return the key of `row` as the only away-side term."""
        return [self.key(row)]

# Cell


class ConstantBlock(ModelBlockABC):
    """
    A model block for adding specific model terms to the parameter keys.

    Can be useful in conjunction with `LumpedAdapter` to ensure that certain parameters
    are in the model (even if they aren't estimated)

    The block only contributes parameter keys; its home and away terms are
    the empty defaults inherited from the base class.
    """
    def __init__(self, *args):
        self.terms = args

    def __repr__(self):
        # List the terms so that the repr mirrors the constructor call.
        # With no terms this is still exactly 'ConstantBlock()'.
        return 'ConstantBlock({})'.format(', '.join(map(repr, self.terms)))

    def param_keys(self, adapter, data):
        """Return the terms given at construction, as a new list."""
        return list(self.terms)