import os
from abc import abstractmethod
from typing import Collection, Tuple, List, Union

import mongomock
import pymongo.collection
from fastapi import HTTPException

from optimade.filterparser import LarkParser
from optimade.filtertransformers.mongo import MongoTransformer
from optimade.models import EntryResource

from .config import CONFIG
from .mappers import BaseResourceMapper
from .query_params import EntryListingQueryParams, SingleEntryQueryParams

# CI can force the real-MongoDB backend by exporting OPTIMADE_CI_FORCE_MONGO
# as an integer: any non-zero value enables it, an unparsable value disables it.
try:
    CI_FORCE_MONGO = int(os.environ.get("OPTIMADE_CI_FORCE_MONGO", 0)) != 0
except (TypeError, ValueError):  # pragma: no cover
    CI_FORCE_MONGO = False


# Choose the MongoDB client once, at import time: the mongomock client unless
# the config requests a real server or CI forces one, in which case a pymongo
# client is connected to CONFIG.mongo_uri. Both `MongoClient` and `client`
# end up bound at module level.
if not (CONFIG.use_real_mongo or CI_FORCE_MONGO):
    from mongomock import MongoClient

    client = MongoClient()
    print("Using: Mock MongoDB (mongomock)")
else:
    from pymongo import MongoClient

    client = MongoClient(CONFIG.mongo_uri)
    print("Using: Real MongoDB (pymongo)")


class EntryCollection(Collection):  # pylint: disable=inherit-non-class
    """Base wrapper around a backend collection of OPTIMADE entries.

    Subclasses must implement :meth:`find`, which translates OPTIMADE query
    parameters into a backend query and builds ``resource_cls`` instances.

    """

    def __init__(
        self,
        collection,
        resource_cls: EntryResource,
        resource_mapper: BaseResourceMapper,
    ):
        """
        Args:
            collection: the backend collection object queries are delegated to.
            resource_cls: the OPTIMADE entry model class (not an instance) used
                to build results; its ``schema()`` is read by
                :meth:`get_attribute_fields`.
            resource_mapper: mapper translating between OPTIMADE field names
                and backend field names.

        """
        self.collection = collection
        self.parser = LarkParser()
        self.resource_cls = resource_cls
        self.resource_mapper = resource_mapper

    def __len__(self):
        return self.collection.count()

    def __iter__(self):
        return self.collection.find()

    def __contains__(self, entry):
        return self.collection.count(entry) > 0

    def get_attribute_fields(self) -> set:
        """Return the names of all fields defined under ``attributes`` in the
        JSON schema of ``resource_cls``.

        The ``attributes`` property may be defined inline, as an ``allOf``
        list of sub-schemas, and/or as a ``$ref`` pointing elsewhere in the
        schema (e.g. ``#/definitions/...``); all of these are resolved.

        Returns:
            set: the attribute field names.

        """
        schema = self.resource_cls.schema()
        # Work on a copy: the dict returned by `schema()` may be shared/cached
        # by the model class, so it must not be mutated in place here.
        attributes = dict(schema["properties"]["attributes"])
        for sub_schema in attributes.pop("allOf", []):
            attributes.update(sub_schema)
        if "$ref" in attributes:
            # "#/definitions/Name" -> walk schema["definitions"]["Name"]
            node = schema
            for key in attributes["$ref"].split("/")[1:]:
                node = node[key]
            attributes = node
        return set(attributes["properties"].keys())

    @abstractmethod
    def find(
        self, params: EntryListingQueryParams
    ) -> Tuple[List[EntryResource], int, bool, set]:
        """
        Fetches results and indicates if more data is available.

        Also gives the total number of data available in the absence of page_limit.

        Args:
            params (EntryListingQueryParams): entry listing URL query params

        Returns:
            Tuple[List[Entry], int, bool, set]: (results, data_returned,
                more_data_available, fields), where ``fields`` is a set of
                field names (in MongoCollection: the known fields that were
                not requested via ``response_fields``).

        """

    def count(self, **kwargs):
        """Return the number of documents, delegating to the backend's
        ``count`` with ``kwargs`` passed through unchanged."""
        return self.collection.count(**kwargs)


class MongoCollection(EntryCollection):
    """Entry collection backed by a pymongo or mongomock collection.

    OPTIMADE filters are parsed with the v0.10.1 grammar, transformed into
    MongoDB queries, and field names are translated to backend names through
    the resource mapper's aliases.

    """

    def __init__(
        self,
        collection: Union[
            pymongo.collection.Collection, mongomock.collection.Collection
        ],
        resource_cls: EntryResource,
        resource_mapper: BaseResourceMapper,
    ):
        """
        Args:
            collection: the pymongo or mongomock collection to query.
            resource_cls: the OPTIMADE entry model class used to build results.
            resource_mapper: mapper translating OPTIMADE <-> backend field names.

        Raises:
            RuntimeError: if any alias (either side of a pair) starts with '$',
                which would clash with MongoDB operators.

        """
        super().__init__(collection, resource_cls, resource_mapper)
        self.transformer = MongoTransformer()

        self.provider = CONFIG.provider["prefix"]
        self.provider_fields = CONFIG.provider_fields.get(resource_mapper.ENDPOINT, [])
        # Replaces the base-class parser: the MongoTransformer only supports
        # v0.10.1 as the latest grammar.
        self.parser = LarkParser(version=(0, 10, 1), variant="default")

        # check aliases do not clash with mongo operators
        self._mapper_aliases = self.resource_mapper.all_aliases()
        if any(
            alias[0].startswith("$") or alias[1].startswith("$")
            for alias in self._mapper_aliases
        ):
            raise RuntimeError(
                f"Cannot define an alias starting with a '$': {self._mapper_aliases}"
            )

    def __len__(self):
        return self.collection.estimated_document_count()

    def __contains__(self, entry):
        return self.collection.count_documents(entry.dict()) > 0

    def count(self, **kwargs):
        """Count documents with ``count_documents``.

        Only the keyword arguments ``count_documents`` accepts here (``filter``,
        ``skip``, ``limit``, ``hint``, ``maxTimeMS``) are forwarded; any others
        (e.g. ``projection``, ``sort``) are silently dropped, so the full
        ``find`` criteria can be passed in directly.

        """
        allowed = ("filter", "skip", "limit", "hint", "maxTimeMS")
        kwargs = {key: value for key, value in kwargs.items() if key in allowed}
        # "filter" is a required argument of count_documents()
        kwargs.setdefault("filter", {})
        return self.collection.count_documents(**kwargs)

    def find(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Tuple[Union[List[EntryResource], EntryResource, None], int, bool, set]:
        """Run the query described by ``params``.

        Args:
            params: listing params (``/structures``) or single-entry params
                (``/structures/{entry_id}``).

        Returns:
            (results, data_returned, more_data_available, excluded_fields):
            ``results`` is a list for listing params, or a single entry (or
            ``None``) for single-entry params; ``data_returned`` is the total
            number of matches independent of pagination for listing params;
            ``excluded_fields`` are the known fields not in ``response_fields``.

        Raises:
            HTTPException: 404 if single-entry params match more than one
                entry; see also ``_parse_params``.

        """
        criteria = self._parse_params(params)

        all_fields = criteria.pop("fields")
        if getattr(params, "response_fields", False):
            fields = set(params.response_fields.split(","))
            fields |= self.resource_mapper.get_required_fields()
        else:
            fields = all_fields.copy()

        results = [
            self.resource_cls(**self.resource_mapper.map_back(doc))
            for doc in self.collection.find(**criteria)
        ]

        nresults_now = len(results)
        if isinstance(params, EntryListingQueryParams):
            # data_returned is the number of matches independent of
            # pagination, so neither `limit` nor `skip` may restrict the count.
            criteria_nolimit = criteria.copy()
            criteria_nolimit.pop("limit", None)
            skip = criteria_nolimit.pop("skip", 0)
            data_returned = self.count(**criteria_nolimit)
            more_data_available = skip + nresults_now < data_returned
        else:
            # SingleEntryQueryParams, e.g., /structures/{entry_id}
            data_returned = nresults_now
            more_data_available = False
            if nresults_now > 1:
                raise HTTPException(
                    status_code=404,
                    detail=f"Instead of a single entry, {nresults_now} entries were found",
                )
            results = results[0] if results else None

        return results, data_returned, more_data_available, all_fields - fields

    def _alias_filter(self, _filter: dict) -> dict:
        """Recursively rename aliased field names in a MongoDB filter so the
        query uses backend field names.

        Dict keys are passed through ``resource_mapper.alias_for`` and their
        values processed recursively; lists are processed element-wise; any
        other value is returned unchanged.

        """
        # if there are no defined aliases, just skip
        if not self._mapper_aliases:
            return _filter

        if isinstance(_filter, dict):
            return {
                self.resource_mapper.alias_for(key): self._alias_filter(value)
                for key, value in _filter.items()
            }

        if isinstance(_filter, list):
            return [self._alias_filter(subdict) for subdict in _filter]

        # a string or other scalar: no more aliases to resolve
        return _filter

    def _parse_params(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> dict:
        """Translate OPTIMADE query params into keyword arguments for
        ``collection.find`` plus a ``fields`` entry (the set of all known
        OPTIMADE and provider-specific field names).

        Raises:
            HTTPException: 400 for a non-'json' ``response_format``; 403 when
                ``page_limit`` exceeds ``CONFIG.page_limit_max``.

        """
        cursor_kwargs = {}

        if getattr(params, "filter", False):
            tree = self.parser.parse(params.filter)
            mongo_filter = self.transformer.transform(tree)
            cursor_kwargs["filter"] = self._alias_filter(mongo_filter)
        else:
            cursor_kwargs["filter"] = {}

        if (
            getattr(params, "response_format", False)
            and params.response_format != "json"
        ):
            raise HTTPException(
                status_code=400, detail="Only 'json' response_format supported"
            )

        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise HTTPException(
                    status_code=403,  # Forbidden
                    detail=f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        # All OPTiMaDe fields
        fields = self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS.copy()
        fields |= self.get_attribute_fields()
        # All provider-specific fields
        fields |= {self.provider + _ for _ in self.provider_fields}
        cursor_kwargs["fields"] = fields
        cursor_kwargs["projection"] = [
            self.resource_mapper.alias_for(f) for f in fields
        ]

        if getattr(params, "sort", False):
            sort_spec = []
            for elt in params.sort.split(","):
                # A leading '-' requests descending order.
                field, sort_dir = (elt[1:], -1) if elt.startswith("-") else (elt, 1)
                # Sort on the backend field name, as done for filter/projection.
                sort_spec.append((self.resource_mapper.alias_for(field), sort_dir))
            cursor_kwargs["sort"] = sort_spec

        if getattr(params, "page_offset", False):
            cursor_kwargs["skip"] = params.page_offset

        return cursor_kwargs
