from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np

from physrisk.api.v1.common import VulnerabilityCurve, VulnerabilityCurves
from physrisk.kernel.assets import Asset, RealEstateAsset
from physrisk.kernel.hazard_model import HazardDataRequest, HazardDataResponse, HazardParameterDataResponse
from physrisk.kernel.impact_distrib import ImpactDistrib, ImpactType
from physrisk.kernel.vulnerability_matrix_provider import VulnMatrixProvider
from physrisk.kernel.vulnerability_model import VulnerabilityModel

from ..kernel.hazards import ChronicHeat, CoastalInundation, PluvialInundation, RiverineInundation, Wind
from ..kernel.vulnerability_model import (
    DeterministicVulnerabilityModel,
    VulnerabilityModelBase,
    applies_to_events,
    checked_beta_distrib,
    get_vulnerability_curves_from_resource,
)


class RealEstateInundationModel(VulnerabilityModel):
    """Depth-damage-curve based inundation vulnerability model for real estate assets.

    Curves are loaded from an embedded resource and looked up by (location, asset type). If the curve
    for an asset has no standard deviations, those of the most similar curve of the same asset type
    that does have standard deviations are used instead (see `closest_curve_of_type`).
    """

    _default_impact_bin_edges = np.array([0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
    _default_resource = "EU JRC global flood depth-damage functions"

    def __init__(
        self,
        *,
        hazard_type: type,
        indicator_id: str,
        resource: str = _default_resource,
        impact_bin_edges=_default_impact_bin_edges,
    ):
        """
        Inundation vulnerability model for real estate assets. Used (via subclasses in this module) for
        riverine, coastal and pluvial inundation.

        Args:
            hazard_type: Hazard type the model applies to, e.g. RiverineInundation.
            indicator_id: Identifier of the hazard indicator requested from the hazard model.
            resource: Embedded resource identifier used to infer the vulnerability matrix.
            impact_bin_edges: Edges of the impact (fractional damage/disruption) bins.
        """
        curve_set: VulnerabilityCurves = get_vulnerability_curves_from_resource(resource)

        # For this model the key for looking up curves is (location, asset_type),
        # e.g. ('Asian', 'Building/Industrial'). Later duplicates of a key replace earlier ones.
        self.vulnerability_curves: Dict[Tuple[str, str], VulnerabilityCurve] = {
            (c.location, c.asset_type): c for c in curve_set.items
        }
        # All curves grouped by asset type; searched when a proxy for missing standard deviations is needed.
        self.vuln_curves_by_type: Dict[str, List[VulnerabilityCurve]] = defaultdict(list)
        # Cache of proxy curves supplying standard deviations, keyed like `vulnerability_curves`.
        self.proxy_curves: Dict[Tuple[str, str], VulnerabilityCurve] = {}
        for item in curve_set.items:
            self.vuln_curves_by_type[item.asset_type].append(item)

        # The impact type is taken from the first curve of the resource; default to damage if there are none.
        if self.vulnerability_curves:
            first_curve = next(iter(self.vulnerability_curves.values()))
            impact_type = ImpactType[first_curve.impact_type.lower()]
        else:
            impact_type = ImpactType.damage
        super().__init__(
            indicator_id=indicator_id,
            hazard_type=hazard_type,
            impact_type=impact_type,
            impact_bin_edges=impact_bin_edges,
        )

    def get_impact_curve(self, intensity_bin_centres: np.ndarray, asset: Asset):
        """Build the vulnerability matrix provider for `asset`.

        Mean impacts are interpolated from the asset's curve at `intensity_bin_centres`. Standard
        deviations are interpolated from the same curve or, if it has none, from a proxy curve (cached per
        key). Each (mean, standard deviation) pair is turned into a beta distribution CDF.

        Raises:
            KeyError: if there is no curve for (asset.location, asset.type).
        """
        assert isinstance(asset, RealEstateAsset)

        key = (asset.location, asset.type)
        curve = self.vulnerability_curves[key]

        std_curve = curve
        if len(curve.impact_std) == 0:
            # The curve has no uncertainty information: borrow it from the most similar curve.
            if key not in self.proxy_curves:
                self.proxy_curves[key] = self.closest_curve_of_type(curve, asset)
            std_curve = self.proxy_curves[key]

        impact_means = np.interp(intensity_bin_centres, curve.intensity, curve.impact_mean)
        impact_stddevs = np.interp(intensity_bin_centres, std_curve.intensity, std_curve.impact_std)

        return VulnMatrixProvider(
            intensity_bin_centres,
            impact_cdfs=[checked_beta_distrib(m, s) for m, s in zip(impact_means, impact_stddevs)],
        )

    def closest_curve_of_type(self, curve: VulnerabilityCurve, asset: RealEstateAsset):
        """Return the curve of the asset's type, having standard deviations, whose mean impacts are closest
        to those of `curve` in the sum-of-squared-differences sense.

        Raises:
            ValueError: if no curve of the asset's type has standard deviations.
        """
        candidate_set = [cand for cand in self.vuln_curves_by_type[asset.type] if len(cand.impact_std) > 0]
        if not candidate_set:
            raise ValueError(
                f"no vulnerability curve of asset type '{asset.type}' has standard deviations to use as a proxy"
            )
        sum_square_diffs = np.array([self.sum_square_diff(curve, cand) for cand in candidate_set])
        return candidate_set[int(np.argmin(sum_square_diffs))]

    def sum_square_diff(self, curve1: VulnerabilityCurve, curve2: VulnerabilityCurve):
        """Sum of squared differences between the mean impacts of `curve1` and those of `curve2`
        interpolated at the intensities of `curve1`."""
        return np.sum(
            (np.asarray(curve1.impact_mean) - np.interp(curve1.intensity, curve2.intensity, curve2.impact_mean))
            ** 2
        )


@applies_to_events([CoastalInundation])
class RealEstateCoastalInundationModel(RealEstateInundationModel):
    """Real estate inundation model specialised to the CoastalInundation hazard."""

    def __init__(
        self,
        *,
        indicator_id: str = "flood_depth",
        resource: str = RealEstateInundationModel._default_resource,
        impact_bin_edges=RealEstateInundationModel._default_impact_bin_edges,
    ):
        # NOTE(review): which coastal scenario variant is used (e.g. with subsidence and a given
        # sea-level-rise percentile) is presumably decided by the hazard model for `indicator_id` - confirm.
        super().__init__(
            impact_bin_edges=impact_bin_edges,
            resource=resource,
            indicator_id=indicator_id,
            hazard_type=CoastalInundation,
        )


class RealEstatePluvialInundationModel(RealEstateInundationModel):
    """Real estate inundation model specialised to the PluvialInundation hazard.

    NOTE(review): unlike the coastal and riverine models in this module, this class is not decorated with
    `applies_to_events` - confirm whether that is intentional.
    """

    def __init__(
        self,
        *,
        indicator_id: str = "flood_depth",
        resource: str = RealEstateInundationModel._default_resource,
        impact_bin_edges=RealEstateInundationModel._default_impact_bin_edges,
    ):
        """
        Args:
            indicator_id: Identifier of the hazard indicator requested from the hazard model.
            resource: Embedded resource identifier used to infer the vulnerability matrix.
            impact_bin_edges: Edges of the impact (fractional damage/disruption) bins.
        """
        super().__init__(
            hazard_type=PluvialInundation,
            indicator_id=indicator_id,
            resource=resource,
            impact_bin_edges=impact_bin_edges,
        )


@applies_to_events([RiverineInundation])
class RealEstateRiverineInundationModel(RealEstateInundationModel):
    """Real estate inundation model specialised to the RiverineInundation hazard."""

    def __init__(
        self,
        *,
        indicator_id: str = "flood_depth",
        resource: str = RealEstateInundationModel._default_resource,
        impact_bin_edges=RealEstateInundationModel._default_impact_bin_edges,
    ):
        # All configuration is forwarded; only the hazard type is fixed by this subclass.
        super().__init__(
            impact_bin_edges=impact_bin_edges,
            resource=resource,
            indicator_id=indicator_id,
            hazard_type=RiverineInundation,
        )


class GenericTropicalCycloneModel(DeterministicVulnerabilityModel):
    def __init__(self, *, v_half: float = 74.7):
        """A very simple generic tropical cyclone vulnerability model.

        A deterministic damage curve is tabulated at wind speeds 0, 10, ..., 90 m/s using `wind_damage`.

        Args:
            v_half: Wind speed (m/s) at which the fractional damage is 0.5. Defaults to 74.7 m/s;
                region-specific values are suggested by Eberenz et al. (2021), see `wind_damage`.
        """
        intensities = np.arange(0, 100, 10)  # m/s
        impacts = self.wind_damage(intensities, v_half)
        super().__init__(
            hazard_type=Wind,
            damage_curve_intensities=intensities,
            damage_curve_impacts=impacts,
            indicator_id="max_speed",
            impact_type=ImpactType.damage,
        )

    def wind_damage(self, v: np.ndarray, v_half: float, v_thresh: float = 25.7):
        """Calculates damage based on functional form of
        Emanuel K. Global warming effects on US hurricane damage. Weather, Climate, and Society. 2011 Oct 1;3(4):261-8.
        Using a threshold speed of 25.7 m/s by default.
        A review of the origin of parameters is available in
        Eberenz S, Lüthi S, Bresch DN. Regional tropical cyclone impact functions for
        globally consistent risk assessments.
        Natural Hazards and Earth System Sciences. 2021 Jan 29;21(1):393-415.
        which also provides suggested region-specific variations.

        Damage is vn^3 / (1 + vn^3) with vn = max(v - v_thresh, 0) / (v_half - v_thresh): zero at or below
        v_thresh, 0.5 at v_half, tending to 1 at high speeds.

        Args:
            v (np.ndarray[float]): Wind speeds at which to calculate the fractional damage.
            v_half (float): The 'v_half' function parameter.
            v_thresh (float): Threshold wind speed (m/s) below which there is no damage.

        Returns:
            np.ndarray[float]: Fractional damage.
        """
        vn = np.where(v > v_thresh, v - v_thresh, 0) / (v_half - v_thresh)
        return vn**3 / (1 + vn**3)


class CoolingModel(VulnerabilityModelBase):
    # 200 W/K is a nominal total-asset heat transfer coefficient. It is approximately the
    # heat loss of a fairly recently built residential property.
    # For 2000 degree days of heating required in a year, the corresponding heating requirement
    # would be 200 * 2000 * 24 / 1000 = 9600 kWh
    # https://www.thegreenage.co.uk/how-much-energy-does-my-home-use/ has a gentle introduction to
    # degree days for home cooling/heating.
    _default_transfer_coeff = 200  # W/K
    # Coefficient of performance of the cooling system: heat removed per unit of electricity (dimensionless).
    _default_cooling_cop = 3

    def __init__(self, threshold_temp_c: float = 23):
        """Simple degree-days-based model for calculating cooling requirements as annual kWh of
        electricity equivalent. The main limitation of the approach is that solar radiation and
        humidity are not taken into account. Limitations of similar approaches, and ways to address
        them, are discussed, for example, in:

        Berardi U, Jafarpur P. Assessing the impact of climate change on building heating
        and cooling energy demand in Canada.
        Renewable and Sustainable Energy Reviews. 2020 Apr 1;121:109681.23.

        Cellura M, Guarino F, Longo S, Tumminia G. Climate change and the building sector:
        Modelling and energy implications to an office building in southern Europe.
        Energy for Sustainable Development. 2018 Aug 1;45:46-65.

        Args:
            threshold_temp_c: Temperature (°C) above which degree days are counted.
        """
        # NOTE(review): VulnerabilityModelBase.__init__ is not called; attributes are set directly.
        # Confirm the base class does not require initialisation.
        self.indicator_id = "mean_degree_days/above/index"
        self.hazard_type = ChronicHeat
        self.threshold_temp_c = threshold_temp_c

    def get_data_requests(self, asset: Asset, *, scenario: str, year: int):
        """Request the degree-days indicator at the asset location for the given scenario and year."""
        return HazardDataRequest(
            self.hazard_type,
            asset.longitude,
            asset.latitude,
            scenario=scenario,
            year=year,
            indicator_id=self.indicator_id,
        )

    def get_impact(self, asset: Asset, data_responses: List[HazardDataResponse]) -> ImpactDistrib:
        """Return the (deterministic) annual electricity, in kWh, required to cool the asset.

        Raises:
            ValueError: if `data_responses` does not contain exactly one response.
        """
        (data,) = data_responses
        assert isinstance(data, HazardParameterDataResponse)
        # The response holds degree days for several threshold temperatures; interpolate at ours.
        # np.interp assumes data.param_defns is increasing - presumably true for this indicator, TODO confirm.
        deg_days = float(np.interp(self.threshold_temp_c, data.param_defns, data.parameters))
        heat_transfer = deg_days * self._default_transfer_coeff * 24 / 1000  # kWh of heat removed from asset
        annual_electricity = heat_transfer / self._default_cooling_cop  # kWh of electricity required for heat removal
        # This is a non-probabilistic model: electricity use occurs with probability 1.
        return ImpactDistrib(self.hazard_type, [annual_electricity, annual_electricity], [1])
