import functools
import importlib.resources
import json
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional, Protocol, Sequence, Tuple, Union

import numpy as np
from scipy import stats

import physrisk.data.static.vulnerability
from physrisk.kernel.impact_distrib import ImpactDistrib, ImpactType

from ..api.v1.common import VulnerabilityCurve, VulnerabilityCurves
from .assets import Asset
from .curve import ExceedanceCurve
from .hazard_event_distrib import HazardEventDistrib
from .hazard_model import HazardDataRequest, HazardDataResponse, HazardEventDataResponse
from .vulnerability_distrib import VulnerabilityDistrib
from .vulnerability_matrix_provider import VulnMatrixProvider

# Registry filled by the applies_to_events / applies_to_assets decorators, keyed by the decorated function's __name__.
PLUGINS: dict = {}


def repeat(num_times):
    """Return a decorator that invokes the decorated function ``num_times`` times per call.

    Every positional and keyword argument is forwarded on each invocation. The wrapper returns
    the value produced by the final invocation. If ``num_times`` is zero or negative, the
    function is never called and the wrapper returns ``None``.

    Args:
        num_times: number of times the decorated function is invoked per call of the wrapper.
    """

    def decorator_repeat(func):
        # functools.wraps keeps __name__/__doc__ of func, so registries keyed on __name__ still work.
        @functools.wraps(func)
        def wrapper_repeat(*args, **kwargs):
            result = None
            for _ in range(num_times):
                result = func(*args, **kwargs)
            return result

        return wrapper_repeat

    return decorator_repeat


def applies_to_events(event_types):
    """Build a decorator that records the decorated function in ``PLUGINS``.

    The function is stored under its ``__name__`` (overwriting any previous entry of the same name)
    and is returned unchanged. ``event_types`` is accepted but not currently used.
    """

    def register_plugin(func):
        plugin_name = func.__name__
        PLUGINS[plugin_name] = func
        return func

    return register_plugin


def applies_to_assets(asset_types):
    """Build a decorator that records the decorated function in ``PLUGINS``.

    The function is stored under its ``__name__`` (overwriting any previous entry of the same name)
    and is returned unchanged. ``asset_types`` is accepted but not currently used.
    """

    def register_plugin(func):
        plugin_name = func.__name__
        PLUGINS[plugin_name] = func
        return func

    return register_plugin


def get_vulnerability_curves_from_resource(id: str) -> VulnerabilityCurves:
    """Load a set of vulnerability curves bundled as package data.

    Reads the JSON resource ``<id>.json`` from the ``physrisk.data.static.vulnerability`` package. The
    JSON must be an object: its top-level keys are passed as keyword arguments to VulnerabilityCurves.

    Args:
        id: resource name without the ``.json`` extension.

    Returns:
        The VulnerabilityCurves built from the resource.
    """
    resource_name = id + ".json"
    with importlib.resources.open_text(physrisk.data.static.vulnerability, resource_name) as stream:
        payload = json.load(stream)
    return VulnerabilityCurves(**payload)


def delta_cdf(y):
    """Return the CDF of a point mass located at ``y``.

    The returned callable evaluates element-wise with ``np.where``: 0 where ``x < y``, otherwise 1.
    """

    def point_mass_cdf(x):
        return np.where(x < y, 0, 1)

    return point_mass_cdf


def checked_beta_distrib(mean, std, scaling_factor=1.0):
    """Return an impact CDF for the given mean and standard deviation, guarding degenerate inputs.

    If ``std`` is zero, or ``mean`` is zero or equal to ``scaling_factor``, a point-mass CDF at ``mean``
    is returned via delta_cdf. For these inputs beta_distrib would divide by zero (``std / mean`` and
    ``1 / cv**2``) or produce a non-positive shape parameter. Otherwise the beta CDF from beta_distrib
    is returned.
    """
    is_degenerate = std == 0 or mean == 0 or mean == scaling_factor
    if not is_degenerate:
        return beta_distrib(mean, std, scaling_factor)
    return delta_cdf(mean)


def beta_distrib(mean, std, scaling_factor):
    """Return the CDF of a beta distribution on [0, scaling_factor] matching ``mean`` and ``std``.

    The shape parameters come from the method of moments applied to ``x / scaling_factor``.
    The returned callable has signature ``(x, a=<alpha>, b=<beta>)``, so the fitted shape
    parameters can be overridden by keyword.
    """
    cv = std / mean
    headroom = scaling_factor - mean
    alpha = (headroom / (cv * cv) - mean) / scaling_factor
    beta = alpha * headroom / mean

    def beta_cdf(x, a=alpha, b=beta):
        return stats.beta.cdf(x / scaling_factor, a, b)

    return beta_cdf


class DataRequester(Protocol):
    """Structural type for objects that declare the hazard data they need for an asset."""

    def get_data_requests(
        self, asset: Asset, *, scenario: str, year: int
    ) -> Union[HazardDataRequest, Iterable[HazardDataRequest]]:
        """Return the hazard data request, or requests, needed for ``asset`` for ``scenario`` and ``year``."""


class EventBased(Protocol):
    """Structural type for event-based models."""

    def impact_samples(self, asset: Asset, data_responses: Iterable[HazardDataResponse]) -> np.ndarray:
        """Generate impact samples from the events received from the hazard model.

        The events may be provided as an array of severities, expressed as return periods.
        """


class VulnerabilityModelBase(ABC, DataRequester):
    """Abstract base for vulnerability models.

    A model declares the hazard data it needs for an asset (get_data_requests) and converts the
    corresponding hazard data responses into an ImpactDistrib (get_impact).
    """

    def __init__(self, indicator_id: str, hazard_type: type, impact_type: ImpactType):
        """
        Args:
            indicator_id: ID of the hazard indicator requested by the model.
            hazard_type: the hazard type (a class) to which the model applies.
            impact_type: the type of impact produced by the model.
        """
        self.indicator_id = indicator_id
        self.hazard_type = hazard_type
        self.impact_type = impact_type
        # Supported event (hazard) types; empty here, subclasses may populate it.
        self._event_types: List[type] = []
        # Supported asset types; empty here, subclasses may populate it.
        self._asset_types: List[type] = []

    @abstractmethod
    def get_data_requests(
        self, asset: Asset, *, scenario: str, year: int
    ) -> Union[HazardDataRequest, Iterable[HazardDataRequest]]:
        """Provide the one or more hazard event data requests required in order to calculate
        the VulnerabilityDistrib and HazardEventDistrib for the asset."""
        ...

    @abstractmethod
    def get_impact(self, asset: Asset, event_data: Iterable[HazardDataResponse]) -> ImpactDistrib:
        """Return the impact distribution of the asset.

        Args:
            asset: the asset.
            event_data: the responses to the requests made by get_data_requests, in the same order.
                Annotated as Iterable (rather than List) to match the implementations, which only iterate it.
        """
        ...


class VulnerabilityModelAcuteBase(VulnerabilityModelBase):
    """Vulnerability model whose impact on an Asset is derived from a VulnerabilityDistrib and a
    HazardEventDistrib.

    Subclasses implement get_distributions; get_impact and get_impact_details combine the two
    distributions into an ImpactDistrib.
    """

    def __init__(self, indicator_id: str, hazard_type: type, impact_type: ImpactType):
        super().__init__(indicator_id=indicator_id, hazard_type=hazard_type, impact_type=impact_type)

    @abstractmethod
    def get_distributions(
        self, asset: Asset, event_data_responses: Iterable[HazardDataResponse]
    ) -> Tuple[VulnerabilityDistrib, HazardEventDistrib]:
        """Return distributions for asset: VulnerabilityDistrib and HazardEventDistrib.
        The hazard event data is used to do this.

        Args:
            asset: the asset.
            event_data_responses: the responses to the requests made by get_data_requests, in the same order.
        """
        ...

    def get_impact(self, asset: Asset, data_responses: Iterable[HazardDataResponse]) -> ImpactDistrib:
        """Return the impact distribution of the asset (the first element of get_impact_details).

        Args:
            asset: the asset.
            data_responses: the responses to the requests made by get_data_requests, in the same order.
        """
        impact, _, _ = self.get_impact_details(asset, data_responses)
        return impact

    def get_impact_details(
        self, asset: Asset, data_responses: Iterable[HazardDataResponse]
    ) -> Tuple[ImpactDistrib, VulnerabilityDistrib, HazardEventDistrib]:
        """Return impact distribution along with vulnerability and hazard event distributions used to infer this.

        Args:
            asset: the asset.
            data_responses: the responses to the requests made by get_data_requests, in the same order.

        Returns:
            Tuple of (impact distribution, vulnerability distribution, hazard event distribution).
        """
        vulnerability_dist, event_dist = self.get_distributions(asset, data_responses)
        # Propagate the hazard event (intensity-bin) probabilities through the vulnerability matrix.
        # Rows of prob_matrix are assumed to correspond to intensity bins and columns to impact bins.
        impact_prob = vulnerability_dist.prob_matrix.T @ event_dist.prob
        return (
            ImpactDistrib(
                vulnerability_dist.event_type, vulnerability_dist.impact_bins, impact_prob, impact_type=self.impact_type
            ),
            vulnerability_dist,
            event_dist,
        )

    def _check_event_type(self):
        """Raise NotImplementedError if self.hazard_type is not among self._event_types."""
        if self.hazard_type not in self._event_types:
            raise NotImplementedError(f"model does not support events of type {self.hazard_type.__name__}")


class VulnerabilityModel(VulnerabilityModelAcuteBase):
    """A vulnerability model that requires only specification of distributions of impacts for given intensities,
    by implementing get_impact_curve."""

    def __init__(
        self,
        *,
        indicator_id: str = "",
        hazard_type: type,
        impact_type: ImpactType,
        impact_bin_edges,
        buffer: Optional[int] = None,
    ):
        """
        Args:
            indicator_id: ID of the hazard indicator to request. Defaults to "".
            hazard_type: the hazard type (a class) to which the model applies.
            impact_type: the type of impact produced by the model.
            impact_bin_edges: edges of the impact bins used for the VulnerabilityDistrib.
            buffer: passed to HazardDataRequest; delimitation of the area for the hazard data in metres.
        """
        super().__init__(indicator_id, hazard_type, impact_type)
        self.impact_bin_edges = impact_bin_edges
        self.buffer = buffer

    def get_data_requests(
        self, asset: Asset, *, scenario: str, year: int
    ) -> Union[HazardDataRequest, Iterable[HazardDataRequest]]:
        """Request the model's hazard indicator at the asset location for the given scenario and year."""
        return HazardDataRequest(
            self.hazard_type,
            asset.longitude,
            asset.latitude,
            scenario=scenario,
            year=year,
            indicator_id=self.indicator_id,
            buffer=self.buffer,
        )

    def get_distributions(
        self, asset: Asset, event_data_responses: Iterable[HazardDataResponse]
    ) -> Tuple[VulnerabilityDistrib, HazardEventDistrib]:
        """Build the vulnerability and hazard event distributions from a single hazard event response.

        Intensity bins and their probabilities come from the exceedance curve defined by the response's
        return periods and intensities. The vulnerability matrix is obtained from get_impact_curve,
        evaluated at the intensity bin centres.

        Args:
            asset: the asset.
            event_data_responses: exactly one response, which must be a HazardEventDataResponse.
        """
        # Unpacking enforces exactly one response (ValueError otherwise).
        (event_data,) = event_data_responses
        assert isinstance(event_data, HazardEventDataResponse)

        intensity_curve = ExceedanceCurve(1.0 / event_data.return_periods, event_data.intensities)
        intensity_bin_edges, probs = intensity_curve.get_probability_bins()

        intensity_bin_centres = (intensity_bin_edges[1:] + intensity_bin_edges[:-1]) / 2
        vul = VulnerabilityDistrib(
            self.hazard_type,
            intensity_bin_edges,
            self.impact_bin_edges,
            self.get_impact_curve(intensity_bin_centres, asset).to_prob_matrix(self.impact_bin_edges),
        )

        event = HazardEventDistrib(self.hazard_type, intensity_bin_edges, probs)
        return vul, event

    @abstractmethod
    def get_impact_curve(self, intensity_bin_centres: np.ndarray, asset: Asset) -> VulnMatrixProvider:
        """Defines a VulnMatrixProvider. The VulnMatrixProvider returns probabilities of specified impact bins
        for the intensity bin centres."""
        ...


class CurveBasedVulnerabilityModel(VulnerabilityModel):
    """VulnerabilityModel whose impact distributions are derived from a VulnerabilityCurve.

    Subclasses supply the curve via get_vulnerability_curve.
    """

    def get_impact_curve(self, intensity_bin_centres: np.ndarray, asset: Asset) -> VulnMatrixProvider:
        """Interpolate the curve's impact mean/std at each intensity bin centre and build a CDF per centre.

        Each CDF comes from checked_beta_distrib with its default scaling factor.
        """
        curve: VulnerabilityCurve = self.get_vulnerability_curve(asset)
        means = np.interp(intensity_bin_centres, curve.intensity, curve.impact_mean)
        stddevs = np.interp(intensity_bin_centres, curve.intensity, curve.impact_std)
        cdfs = list(map(checked_beta_distrib, means, stddevs))
        return VulnMatrixProvider(intensity_bin_centres, impact_cdfs=cdfs)

    @abstractmethod
    def get_vulnerability_curve(self, asset: Asset) -> VulnerabilityCurve:
        """Return the vulnerability curve applicable to the asset."""


class DeterministicVulnerabilityModel(VulnerabilityModelAcuteBase):
    def __init__(
        self,
        *,
        hazard_type: type,
        damage_curve_intensities: Sequence[float],
        damage_curve_impacts: Sequence[float],
        indicator_id: str,
        impact_type: ImpactType,
        buffer: Optional[int] = None,
    ):
        """A vulnerability model that requires only specification of a damage/disruption curve.
        This simple model contains no uncertainty around damage/disruption. The damage curve is passed via the
        constructor. The edges of the (hazard) intensity bins are determined by the granularity of
        the hazard data itself. The impact bin edges are inferred from the intensity bin edges, by
        looking up the impact corresponding to the hazard indicator intensity from the damage curve.

            Args:
                hazard_type (type): The hazard type (a class) to which the model applies.
                damage_curve_intensities (Sequence[float]): Intensities
                (i.e. hazard indicator values) of the damage/disruption (aka impact) curve.
                damage_curve_impacts (Sequence[float]): Fractional damage to asset/disruption
                to operation resulting from a hazard of the corresponding intensity.
                indicator_id (str): ID of the hazard indicator to which this applies.
                impact_type (ImpactType): The type of impact produced by the model.
                buffer (Optional[int]): Delimitation of the area for the hazard data in metres (within [0,1000]).
        """
        super().__init__(indicator_id=indicator_id, hazard_type=hazard_type, impact_type=impact_type)
        self.damage_curve_intensities = damage_curve_intensities
        self.damage_curve_impacts = damage_curve_impacts
        self.buffer = buffer

    def get_data_requests(
        self, asset: Asset, *, scenario: str, year: int
    ) -> Union[HazardDataRequest, Iterable[HazardDataRequest]]:
        """Request the model's hazard indicator at the asset location for the given scenario and year."""
        return HazardDataRequest(
            self.hazard_type,
            asset.longitude,
            asset.latitude,
            scenario=scenario,
            year=year,
            indicator_id=self.indicator_id,
            buffer=self.buffer,
        )

    def get_distributions(
        self, asset: Asset, event_data_responses: Iterable[HazardDataResponse]
    ) -> Tuple[VulnerabilityDistrib, HazardEventDistrib]:
        """Build the vulnerability and hazard event distributions from a single hazard event response.

        Args:
            asset: the asset.
            event_data_responses: exactly one response, which must be a HazardEventDataResponse.
        """
        # Unpacking enforces exactly one response (ValueError otherwise).
        (event_data,) = event_data_responses
        assert isinstance(event_data, HazardEventDataResponse)

        intensity_curve = ExceedanceCurve(1.0 / event_data.return_periods, event_data.intensities)
        intensity_bin_edges, probs = intensity_curve.get_probability_bins()

        # look up the impact bin edges
        impact_bins_edges = np.interp(intensity_bin_edges, self.damage_curve_intensities, self.damage_curve_impacts)

        # the vulnerability distribution probabilities are an identity matrix:
        # we assume that if the intensity falls within a certain bin then the impacts *will* fall within the
        # bin where the edges are obtained by applying the damage curve to the intensity bin edges.
        # self.hazard_type is already the hazard class: passing type(self.hazard_type) would record its
        # metaclass (e.g. `type`) as the event type, which then propagates into the ImpactDistrib.
        vul = VulnerabilityDistrib(
            self.hazard_type, intensity_bin_edges, impact_bins_edges, np.eye(len(impact_bins_edges) - 1)
        )
        event = HazardEventDistrib(self.hazard_type, intensity_bin_edges, probs)
        return vul, event
