import copy
import threading
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple

from optuna.distributions import BaseDistribution
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


# A search space as a set of (param_name, distribution) pairs; a set makes the
# union/intersection bookkeeping cheap.
SearchSpaceSetT = Set[Tuple[str, BaseDistribution]]
# The same (param_name, distribution) pairs as an ordered list; this is the form
# returned by get_cached_extra_study_property.
SearchSpaceListT = List[Tuple[str, BaseDistribution]]

# In-memory cache of per-study extra properties, keyed by study ID.
# The lock is held by get_cached_extra_study_property for the whole
# lookup / update / read sequence on this dict and its entries.
# NOTE(review): nothing visible in this module evicts entries, so the dict
# grows with the number of distinct study IDs seen by the process — confirm
# this is acceptable for long-running servers.
cached_extra_study_property_cache_lock = threading.Lock()
cached_extra_study_property_cache: Dict[int, "_CachedExtraStudyProperty"] = {}


def get_cached_extra_study_property(
    study_id: int, trials: List[FrozenTrial]
) -> Tuple[SearchSpaceListT, SearchSpaceListT, List[Tuple[str, bool]], bool]:
    """Bring the cached extra properties of a study up to date and return them.

    The whole sequence — fetching (or creating) the per-study entry, feeding it
    ``trials``, storing it, and reading its properties — runs while holding
    ``cached_extra_study_property_cache_lock``.

    Args:
        study_id: Key of the cache entry for the study.
        trials: Trials of the study, forwarded to the entry's ``update``.

    Returns:
        A tuple ``(intersection, union, union_user_attrs,
        has_intermediate_values)`` read from the cache entry.
    """
    with cached_extra_study_property_cache_lock:
        entry = cached_extra_study_property_cache.get(study_id)
        if entry is None:
            entry = _CachedExtraStudyProperty()
        entry.update(trials)
        # A newly created entry is inserted only after ``update`` returns.
        cached_extra_study_property_cache[study_id] = entry

        # Read every property while the lock is still held.
        intersection = entry.intersection
        union = entry.union
        user_attrs = entry.union_user_attrs
        has_intermediate = entry.has_intermediate_values

    return intersection, union, user_attrs, has_intermediate


class _CachedExtraStudyProperty:
    """Incrementally maintained summary of a study's trials.

    Across the trials passed to :meth:`update` it tracks:

    * the intersection and the union of ``(param_name, distribution)`` pairs
      over non-failed trials,
    * every user attribute name seen, and whether all of its values were
      numbers (``int``/``float``),
    * whether any non-failed trial reported an intermediate value.

    A cursor records the lowest trial number that still has to be examined, so
    repeated calls with a growing trial list only revisit new trials and trials
    that were unfinished last time.
    """

    def __init__(self) -> None:
        # Lowest trial number to (re-)examine on the next ``update``; -1 means
        # nothing has been examined yet.
        self._cursor: int = -1
        # TODO: intersection_search_space and union_search_space look more clear since now we have
        # union_user_attrs.
        # ``None`` until the first non-failed trial is seen, so that trial seeds
        # the intersection instead of being intersected with an empty set.
        self._intersection: Optional[SearchSpaceSetT] = None
        self._union: SearchSpaceSetT = set()
        self._union_user_attrs: Dict[str, bool] = {}  # attr_name: is_sortable (= is_number)
        self.has_intermediate_values: bool = False

    @property
    def intersection(self) -> SearchSpaceListT:
        """Intersection search space sorted by parameter name ([] if none yet)."""
        if self._intersection is None:
            return []
        return sorted(self._intersection, key=lambda x: x[0])

    @property
    def union(self) -> SearchSpaceListT:
        """Union search space sorted by parameter name."""
        return sorted(self._union, key=lambda x: x[0])

    @property
    def union_user_attrs(self) -> List[Tuple[str, bool]]:
        """``(attr_name, is_sortable)`` pairs sorted by attribute name."""
        # The sorted copy must be returned: ``sorted`` does not sort in place.
        return sorted(self._union_user_attrs.items(), key=lambda x: x[0])

    def update(self, trials: List[FrozenTrial]) -> None:
        """Fold new and previously-unfinished trials into the cached properties.

        Args:
            trials: Trials of the study. Assumed to be ordered by ascending
                ``trial.number`` — the early ``break`` and the use of
                ``trials[-1]`` rely on it (TODO confirm with callers).
        """
        if not trials:
            return

        # Optimistically move the cursor past the newest trial, so a fully
        # finished study is not rescanned from scratch on every call. Any
        # unfinished trial met below pulls the cursor back to its own number
        # (the smallest one wins because we walk from newest to oldest).
        next_cursor = trials[-1].number + 1
        for trial in reversed(trials):
            if self._cursor > trial.number:
                break

            if not trial.state.is_finished():
                next_cursor = trial.number

            self._update_user_attrs(trial)
            if trial.state != TrialState.FAIL:
                self._update_intermediate_values(trial)
                self._update_search_space(trial)

        self._cursor = next_cursor

    def _update_user_attrs(self, trial: FrozenTrial) -> None:
        """Merge ``trial``'s user attribute names and sortability into the union."""
        # TODO(c-bata): Support numpy-specific number types.
        for attr_name, value in trial.user_attrs.items():
            # bool is a subclass of int, so boolean values count as sortable.
            current_is_sortable = isinstance(value, (int, float))
            # An attribute stays sortable only while every observed value is a
            # number; once False it never flips back to True.
            self._union_user_attrs[attr_name] = (
                self._union_user_attrs.get(attr_name, True) and current_is_sortable
            )

    def _update_intermediate_values(self, trial: FrozenTrial) -> None:
        """Latch ``has_intermediate_values`` once any trial reports one."""
        if not self.has_intermediate_values and len(trial.intermediate_values) > 0:
            self.has_intermediate_values = True

    def _update_search_space(self, trial: FrozenTrial) -> None:
        """Add ``trial``'s distributions to the union and intersect them in."""
        # NOTE(review): ``update`` also passes RUNNING trials here; if such a
        # trial has suggested only some of its parameters so far, the
        # intersection shrinks and cannot grow back later — confirm intended.
        current = set(trial.distributions.items())
        self._union = self._union | current

        if self._intersection is None:
            self._intersection = copy.copy(current)
        else:
            self._intersection = self._intersection & current