import abc
import dataclasses
from typing import Dict, List, Type

from .enums import AggregatorType


class Aggregator(abc.ABC):
    """Base class for aggregator methods.

    The instances contain parameters to compose the associated SQL query.
    Backend packages should extend these according to support, and implement
    the methods needed for conversion to string.
    """

    def __repr__(self) -> str:
        """Return ``Name(field=value, ...)``, or ``Name()`` without fields.

        Note that ``@dataclasses.dataclass`` generates its own ``__repr__``
        by default, which then takes precedence over this one; this method
        is what non-dataclass subclasses (and dataclasses declared with
        ``repr=False``) end up using.
        """
        if dataclasses.is_dataclass(self):
            params = ", ".join(
                "{0}={1!r}".format(f.name, getattr(self, f.name))
                for f in dataclasses.fields(self)
            )
        else:
            params = ""
        # Formatting `self` goes through __str__, so a subclass overriding
        # __str__ also changes the name rendered here.
        return "{0}({1})".format(self, params)

    def __str__(self) -> str:
        """Return the name of the concrete class."""
        return type(self).__name__

    def get_params(self) -> Dict[str, str]:
        """Returns a dict with the serialized params of the instance.

        Base class just returns an empty dict. More complex subclasses must
        override this method."""
        return {}

    @classmethod
    def new(cls, kwargs: dict) -> "Aggregator":
        """Build an instance of `cls` from a dict of candidate parameters.

        When `cls` is a dataclass, only the entries of `kwargs` whose key
        matches a declared field are forwarded to the constructor; any other
        key is silently dropped. When `cls` is not a dataclass it is
        instantiated without arguments and `kwargs` is ignored.

        Missing fields are not filled in here: whatever error the
        constructor raises for them propagates to the caller.
        """
        if not dataclasses.is_dataclass(cls):
            return cls()
        field_names = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in kwargs.items() if k in field_names})

    @staticmethod
    def from_enum(enum: AggregatorType) -> Type["Aggregator"]:
        """Return the Aggregator subclass associated with `enum`.

        The class itself is returned, not an instance; callers instantiate
        it afterwards (for example through `new`).

        Raises:
            KeyError: if `enum` has no entry in the lookup table.
        """
        # The table is built at call time because the subclasses are defined
        # further down in this module, after this class body is executed.
        agg_classes = {
            AggregatorType.SUM: Sum,
            AggregatorType.COUNT: Count,
            AggregatorType.AVERAGE: Average,
            AggregatorType.MAX: Max,
            AggregatorType.MIN: Min,
            AggregatorType.BASICGROUPEDMEDIAN: BasicGroupedMedian,
            AggregatorType.WEIGHTEDSUM: WeightedSum,
            AggregatorType.WEIGHTEDAVERAGE: WeightedAverage,
            AggregatorType.REPLICATEWEIGHTMOE: ReplicateWeightMoe,
            AggregatorType.CALCULATEDMOE: CalculatedMoe,
            AggregatorType.WEIGHTEDAVERAGEMOE: WeightedAverageMoe,
        }
        return agg_classes[enum]


class Sum(Aggregator):
    """Sum aggregator.

    Declares no parameters of its own; all behavior is inherited from
    `Aggregator`.
    """


class Count(Aggregator):
    """Count aggregator.

    Declares no parameters of its own; all behavior is inherited from
    `Aggregator`.
    """


class Average(Aggregator):
    """Average aggregator.

    Declares no parameters of its own; all behavior is inherited from
    `Aggregator`.
    """


class Max(Aggregator):
    """Max aggregator.

    Declares no parameters of its own; all behavior is inherited from
    `Aggregator`.
    """


class Min(Aggregator):
    """Min aggregator.

    Declares no parameters of its own; all behavior is inherited from
    `Aggregator`.
    """


@dataclasses.dataclass
class BasicGroupedMedian(Aggregator):
    """Aggregator carrying the two parameters of a basic grouped median.

    NOTE(review): the exact meaning of the two fields is not visible in this
    module; presumably they name the aggregator and the dimension used to
    form the groups -- confirm against the backend implementation.
    """
    # Both fields are required positional parameters of the constructor.
    group_aggregator: str
    group_dimension: str

    def get_params(self) -> Dict[str, str]:
        """Return the dataclass fields of this instance as a dict."""
        params = dataclasses.asdict(self)
        return params


@dataclasses.dataclass
class WeightedSum(Aggregator):
    """Sum of the measure's value column weighted by another column.

    Formula: `sum(column * weight_column)`

    The roll-up happens in two steps:

    1. `sum(column * weight_column)` as `weighted_sum_first`
    2. `sum(weighted_sum_first)` as `weighted_sum_final`
    """
    # Name of the column holding the weights.
    weight_column: str

    def get_params(self) -> Dict[str, str]:
        """Return the dataclass fields of this instance as a dict."""
        params = dataclasses.asdict(self)
        return params


@dataclasses.dataclass
class WeightedAverage(Aggregator):
    """Average of the measure's value column weighted by another column.

    Formula: `sum(column * weight_column) / sum(weight_column)`
    """
    # Name of the column holding the weights.
    weight_column: str

    def get_params(self) -> Dict[str, str]:
        """Return the dataclass fields of this instance as a dict."""
        params = dataclasses.asdict(self)
        return params


@dataclasses.dataclass
class ReplicateWeightMoe(Aggregator):
    """Margin of Error computed from a list of secondary columns.

    The measure column is the primary value, and the list of secondary
    columns is provided to the MOE aggregator. The general equation is:

    `cv * pow(df * (pow(sum(column) - sum(secondary_columns[0]), 2) + pow(sum(column) - sum(secondary_columns[1]), 2) + ...), 0.5)`

    where:

    - cv = critical value; for a 90% confidence interval it's 1.645
    - df = design factor / #samples
    """
    critical_value: float
    design_factor: float
    secondary_columns: List[str]

    def get_params(self) -> Dict[str, str]:
        """Return the dataclass fields of this instance as a dict.

        NOTE(review): the values keep their declared types (float and list),
        which is wider than the `Dict[str, str]` annotation suggests.
        """
        params = dataclasses.asdict(self)
        return params


@dataclasses.dataclass
class CalculatedMoe(Aggregator):
    """Aggregation of Margins of Error that are already computed per row.

    The rows carry their own MOE; this aggregator only combines them:

    `sqrt(sum(power(moe / cv, 2))) * cv`

    where cv = critical value; for a 90% confidence interval it's 1.645
    """
    critical_value: float

    def get_params(self) -> Dict[str, str]:
        """Return the dataclass fields of this instance as a dict.

        NOTE(review): `critical_value` stays a float in the result, which is
        wider than the `Dict[str, str]` annotation suggests.
        """
        params = dataclasses.asdict(self)
        return params


@dataclasses.dataclass
class WeightedAverageMoe(Aggregator):
    """Margin of Error of a weighted average, from secondary weight columns.

    The measure column is the primary value, and a list of secondary weight
    columns is provided to the MOE aggregator. The general equation is:

    `cv * pow(df * (pow(( sum(column * primary_weight)/sum(primary_weight) ) - ( sum(column * secondary_weight_columns[0])/sum(secondary_weight_columns[0]) ), 2) + pow(( sum(column * primary_weight)/sum(primary_weight) ) - ( sum(column * secondary_weight_columns[1])/sum(secondary_weight_columns[1]) ), 2) + ...), 0.5)`

    where:

    - cv = critical value; for a 90% confidence interval it's 1.645
    - df = design factor / #samples
    """
    critical_value: float
    design_factor: float
    primary_weight: str
    secondary_weight_columns: List[str]

    def get_params(self) -> Dict[str, str]:
        """Return the dataclass fields of this instance as a dict.

        NOTE(review): the values keep their declared types (float, str and
        list), which is wider than the `Dict[str, str]` annotation suggests.
        """
        params = dataclasses.asdict(self)
        return params
