from collections import namedtuple
from typing import Iterator, NamedTuple, Iterable, Type

from hcube.api.backend import CubeBackend
from hcube.api.models.aggregation import AggregationOp, Count
from hcube.api.models.cube import Cube
from hcube.api.models.filters import (
    Filter,
    ListFilter,
    IsNullFilter,
    ComparisonFilter,
    ComparisonType,
    NegativeListFilter,
    EqualityFilter,
)
from hcube.api.models.ordering import OrderDirection, OrderSpec
from hcube.api.models.query import CubeQuery
from hcube.api.models.transforms import ExplicitMappingTransform
from hcube.settings import GlobalSettings


class NaiveCubeBackend(CubeBackend):
    """
    Naive pure-python list based Cube backend.

    It is meant as an example and documentation and not for any serious work.
    All records are kept in memory in ``self.values`` and every query is answered
    by a linear scan over them.
    """

    def __init__(self):
        self.values = []

    def store_records(self, cube: Type[Cube], records: Iterable[NamedTuple]):
        """
        Store `records` for `cube` after passing them through `cube.cleanup_records`.

        Any previously stored records are replaced, not appended to.
        """
        self.values = list(cube.cleanup_records(records))

    def get_records(self, query: CubeQuery) -> Iterator[NamedTuple]:
        """
        Evaluate `query` against the stored records.

        Records passing all filters are grouped by `query.groups`; aggregations and
        transforms are computed per group, then orderings and the limit are applied.
        When neither groups nor aggregations are given, records are grouped by all
        dimensions and metrics of the cube.

        :returns: a list of namedtuples whose fields are the group fields, followed by
            aggregation names, followed by transform names
        """
        if query.groups or query.aggregations:
            fields = query.groups
        else:
            # no groups and no aggregations - we are dealing with all dimensions (and metrics)
            fields = list(query.cube._dimensions.values()) + list(query.cube._metrics.values())
        grouped = self._group_and_aggregate(query, fields)
        typ = namedtuple(
            "AggRecord",
            [grp.name for grp in fields]
            + [agg.name for agg in query.aggregations]
            + [tr.name for tr in query.transforms],
        )
        ret = [typ(*key, **values) for key, values in grouped.items()]
        for sorter in reversed(query.orderings):
            # list.sort is stable, so sorting from the least important ordering to the most
            # important one yields the combined ordering
            ret.sort(
                key=lambda rec, spec=sorter: self._sort_key(rec, spec),
                reverse=sorter.direction == OrderDirection.DESC,
            )
        # an ungrouped aggregation always yields exactly one record, even if nothing matched;
        # a grouped query over no matching data has no group keys to report and yields nothing
        if not ret and query.aggregations and not query.groups:
            empty = {agg.name: self._empty_value(agg.op) for agg in query.aggregations}
            # transforms have no source value to map when no record matched
            empty.update({tr.name: None for tr in query.transforms})
            ret = [typ(**empty)]
        # limit (a falsy limit means "no limit")
        if query.limit:
            ret = ret[: query.limit]
        return ret

    def _group_and_aggregate(self, query: CubeQuery, fields: list) -> dict:
        """
        Group records passing `query.filters` by `fields` and compute aggregations/transforms.

        :returns: dict mapping each group key (tuple of the `fields` values) to a dict of
            aggregation and transform values; groups appear in order of first occurrence
        :raises ValueError: if a transform of an unsupported type is requested
        """
        aggreg_base = {agg.name: self._base_value(agg.op) for agg in query.aggregations}
        result = {}
        # distinct values seen so far, tracked per (group key, aggregation name) so that
        # neither different groups nor different calls share count-distinct state
        distinct_seen = {}
        for record in self.values:
            if not all(self._row_passes_filter(record, fltr) for fltr in query.filters):
                continue
            key = tuple(getattr(record, field.name) for field in fields)
            if key not in result:
                result[key] = dict(aggreg_base)
            row = result[key]
            for agg in query.aggregations:
                if isinstance(agg, Count) and agg.distinct:
                    seen = distinct_seen.setdefault((key, agg.name), set())
                    seen.add(getattr(record, agg.distinct.name))
                    row[agg.name] = len(seen)
                else:
                    value = getattr(record, agg.metric.name) if agg.metric else None
                    row[agg.name] = self._aggregate(agg.op, row[agg.name], value)
            for transform in query.transforms:
                if not isinstance(transform, ExplicitMappingTransform):
                    raise ValueError(
                        f"Transformation {transform.__class__} is not supported by this backend"
                    )
                base_value = getattr(record, transform.dimension.name)
                # the last matching record of the group determines the transformed value
                row[transform.name] = transform.mapping.get(base_value, base_value)
        return result

    def get_count(self, query: CubeQuery) -> int:
        """Return the number of records `get_records` would return for `query`."""
        return len(list(self.get_records(query)))

    def delete_records(self, query: CubeQuery) -> None:
        """
        Delete all stored records that pass every filter of `query`.

        A query without filters deletes all records.
        """
        self.values = [
            record
            for record in self.values
            if not all(self._row_passes_filter(record, fltr) for fltr in query.filters)
        ]

    @classmethod
    def _sort_key(cls, record, sorter: OrderSpec):
        """
        Return a sort key for `record` that places null values after non-null ones.

        All nulls share the same key (their second element is the dimension's default),
        so they never need to be compared with real values. Because DESC orderings
        reverse the whole sort, nulls end up first for descending orderings.
        """
        value = getattr(record, sorter.dimension.name)
        if value is None:
            # sort nulls last - use 1 as first value
            return 1, sorter.dimension.default
        return 0, value

    @classmethod
    def _row_passes_filter(cls, row, fltr: Filter) -> bool:
        """
        Return whether `row` satisfies filter `fltr`.

        Comparison filters never match a null value (instead of failing on `None`).

        :raises ValueError: for filter types or comparisons this backend does not support
        """
        value = getattr(row, fltr.dimension.name)
        if isinstance(fltr, ListFilter):
            return value in fltr.values
        if isinstance(fltr, NegativeListFilter):
            return value not in fltr.values
        if isinstance(fltr, IsNullFilter):
            return (value is None) == fltr.is_null
        if isinstance(fltr, ComparisonFilter):
            if fltr.comparison == ComparisonType.GT:
                return value is not None and value > fltr.value
            if fltr.comparison == ComparisonType.GTE:
                return value is not None and value >= fltr.value
            if fltr.comparison == ComparisonType.LT:
                return value is not None and value < fltr.value
            if fltr.comparison == ComparisonType.LTE:
                return value is not None and value <= fltr.value
        if isinstance(fltr, EqualityFilter):
            return value == fltr.dimension.to_python(fltr.value)
        raise ValueError(f"unsupported filter {fltr}")

    @classmethod
    def _aggregate(cls, op: AggregationOp, base, value):
        """
        Fold `value` into the running aggregate `base` for aggregation `op`.

        COUNT counts every matching record; SUM, MIN and MAX ignore null values.

        :raises ValueError: for aggregation operations this backend does not support
        """
        if op == AggregationOp.COUNT:
            return base + 1
        if value is None and op in (AggregationOp.SUM, AggregationOp.MIN, AggregationOp.MAX):
            return base
        if op == AggregationOp.SUM:
            return base + value
        if op == AggregationOp.MIN:
            return value if base is None else min(base, value)
        if op == AggregationOp.MAX:
            return value if base is None else max(base, value)
        raise ValueError(f"unsupported aggregation {op}")

    @classmethod
    def _base_value(cls, op: AggregationOp):
        """Return the starting value of the running aggregate for `op`."""
        if op in (AggregationOp.COUNT, AggregationOp.SUM):
            return 0
        return None

    @classmethod
    def _empty_value(cls, op: AggregationOp):
        """
        Return the value reported for `op` when no record matched an ungrouped query.

        COUNT is always 0; other operations give 0 or None depending on
        `GlobalSettings.aggregates_zero_for_empty_data`.
        """
        if op in (AggregationOp.COUNT,):
            return 0
        return 0 if GlobalSettings.aggregates_zero_for_empty_data else None
