from collections import defaultdict
from typing import Iterable, List, Tuple, Union, cast

import numpy as np
from scipy.stats import multivariate_normal, norm

from physrisk.api.v1.common import VulnerabilityCurve, VulnerabilityCurves
from physrisk.kernel.assets import Asset, ThermalPowerGeneratingAsset, TurbineKind
from physrisk.kernel.impact_distrib import ImpactDistrib, ImpactType
from physrisk.kernel.vulnerability_model import DeterministicVulnerabilityModel, VulnerabilityModelBase

from ..kernel.curve import ExceedanceCurve
from ..kernel.hazard_event_distrib import HazardEventDistrib
from ..kernel.hazard_model import (
    HazardDataRequest,
    HazardDataResponse,
    HazardEventDataResponse,
    HazardParameterDataResponse,
)
from ..kernel.hazards import (
    AirTemperature,
    ChronicHeat,
    CoastalInundation,
    Drought,
    RiverineInundation,
    WaterRisk,
    WaterTemperature,
)
from ..kernel.vulnerability_distrib import VulnerabilityDistrib
from ..kernel.vulnerability_model import applies_to_assets, applies_to_events, get_vulnerability_curves_from_resource


class ThermalPowerGenerationInundationModel(DeterministicVulnerabilityModel):
    """Inundation vulnerability model for thermal power generation (base of the riverine and coastal variants).

    Flood-depth exceedance curves are turned into probability bins, and each bin intensity is mapped to an
    impact via the WRI vulnerability curves. Depths at or below the asset's protection depth get zero impact.
    """

    # Embedded resource holding the WRI vulnerability curves. Curve impacts are treated as numbers of
    # disrupted days per year (they are divided by 365 in get_distributions).
    _default_resource = "WRI thermal power plant physical climate vulnerability factors"

    # delimitation of the area for the hazard data expressed in metres (within [0,1000]).
    _default_buffer = 1000

    def __init__(
        self, *, hazard_type: type, indicator_id: str, resource: str = _default_resource, buffer: int = _default_buffer
    ):
        """
        Inundation vulnerability model for thermal power generation.
        Applies to both riverine and coastal inundation.

        Args:
                hazard_type (type): hazard class (e.g. RiverineInundation or CoastalInundation). Vulnerability
                    curves are selected whose event_type equals the name of this class's direct base class.
                indicator_id (str): ID of the hazard indicator to which this applies.
                resource (str): embedded resource identifier used to infer vulnerability table.
                buffer (int): delimitation of the area for the hazard data expressed in metres (within [0,1000]).
        """

        curve_set: VulnerabilityCurves = get_vulnerability_curves_from_resource(resource)

        # for this model, key for looking up curves is asset_type, e.g. 'Steam/Recirculating'
        self.vulnerability_curves = dict(
            (c.asset_type, c) for c in curve_set.items if c.event_type == hazard_type.__base__.__name__  # type:ignore
        )
        # Same curves grouped by turbine kind (the part of asset_type before '/'); used in get_distributions
        # when the asset's turbine is known but its cooling technology is not.
        self.vuln_curves_by_type = defaultdict(list)
        for key in self.vulnerability_curves:
            self.vuln_curves_by_type[TurbineKind[key.split("/")[0]]].append(self.vulnerability_curves[key])

        # Impact type of the first matching curve (all curves' impact types are parsed), or disruption when
        # no curve matches.
        impact_type = (
            ImpactType.disruption
            if len(self.vulnerability_curves) == 0
            else [ImpactType[self.vulnerability_curves[key].impact_type.lower()] for key in self.vulnerability_curves][
                0
            ]
        )

        # The deterministic damage curves are left empty: impacts are instead interpolated from
        # self.vulnerability_curves in get_distributions, which this class overrides.
        super().__init__(
            indicator_id=indicator_id,
            hazard_type=hazard_type,
            impact_type=impact_type,
            damage_curve_intensities=[],
            damage_curve_impacts=[],
            buffer=buffer,
        )

    def get_data_requests(
        self, asset: Asset, *, scenario: str, year: int
    ) -> Union[HazardDataRequest, Iterable[HazardDataRequest]]:
        """Provide the list of hazard event data requests required in order to calculate
        the VulnerabilityDistrib and HazardEventDistrib for the asset.

        Returns a (scenario, baseline) pair of requests; get_distributions unpacks them in that order and
        uses the baseline only to derive the asset's protection depth.

        NOTE(review): the "baseline" request uses the requested scenario at year 2030 rather than a
        historical run — confirm this is intended.
        """
        request_scenario = HazardDataRequest(
            self.hazard_type,
            asset.longitude,
            asset.latitude,
            scenario=scenario,
            year=year,
            indicator_id=self.indicator_id,
            buffer=self.buffer,
        )
        request_baseline = HazardDataRequest(
            self.hazard_type,
            asset.longitude,
            asset.latitude,
            scenario=scenario,
            year=2030,
            indicator_id=self.indicator_id,
            buffer=self.buffer,
        )
        return request_scenario, request_baseline

    def get_distributions(
        self, asset: Asset, event_data_responses: Iterable[HazardDataResponse]
    ) -> Tuple[VulnerabilityDistrib, HazardEventDistrib]:
        """Build the vulnerability and hazard-event distributions for an asset.

        Args:
            asset: must be a ThermalPowerGeneratingAsset (asserted).
            event_data_responses: the (scenario, baseline) HazardEventDataResponse pair, in the order
                produced by get_data_requests.

        Returns:
            (VulnerabilityDistrib, HazardEventDistrib) built on the same intensity array.
        """
        assert isinstance(asset, ThermalPowerGeneratingAsset)

        (response_scenario, response_baseline) = event_data_responses
        assert isinstance(response_scenario, HazardEventDataResponse)
        assert isinstance(response_baseline, HazardEventDataResponse)

        # Protection depth: the baseline depth read off the baseline exceedance curve at the exceedance
        # probability 1 / (asset protection return period); 0.0 when the baseline has no data.
        baseline_curve = ExceedanceCurve(1.0 / response_baseline.return_periods, response_baseline.intensities)
        protection_depth = (
            0.0
            if len(response_baseline.intensities) == 0
            else baseline_curve.get_value(1.0 / asset.get_inundation_protection_return_period())
        )

        # When the protection depth lies strictly inside the scenario curve's value range, add it as a point
        # of the curve (presumably so it becomes a bin boundary below).
        intensity_curve = ExceedanceCurve(1.0 / response_scenario.return_periods, response_scenario.intensities)
        if 0 < len(intensity_curve.values):
            if intensity_curve.values[0] < protection_depth:
                if protection_depth < intensity_curve.values[-1]:
                    intensity_curve = intensity_curve.add_value_point(protection_depth)

        # Prepend the exceedance probability of the lowest curve point to the bin probabilities.
        # NOTE(review): presumably this aligns len(probs) with len(intensities) as returned by
        # get_probability_bins — confirm against ExceedanceCurve.
        intensities, probs = intensity_curve.get_probability_bins()
        if 0 < len(intensity_curve.values):
            probs = np.insert(probs, 0, intensity_curve.probs[0])

        # Curve selection: all curves when the turbine is unknown; the exact 'Turbine/Cooling' curve when
        # both are known; otherwise every curve for the turbine kind.
        curves: List[VulnerabilityCurve] = []
        if asset.turbine is None:
            curves = [self.vulnerability_curves[key] for key in self.vulnerability_curves]
        elif asset.cooling is not None:
            key = "/".join([asset.turbine.name, asset.cooling.name])
            if key in self.vulnerability_curves:
                curves = [self.vulnerability_curves[key]]
        elif asset.turbine in self.vuln_curves_by_type:
            curves = self.vuln_curves_by_type[asset.turbine]

        # Impact per intensity: worst (max) interpolated impact across the selected curves, converted from
        # days to a fraction of the year; zero at or below the protection depth.
        if 0 < len(curves):
            impacts = [
                (
                    np.max([np.interp(intensity, curve.intensity, curve.impact_mean) for curve in curves]) / 365.0
                    if protection_depth < intensity
                    else 0.0
                )
                for intensity in intensities
            ]
        else:
            impacts = [0.0 for _ in intensities]

        # Identity matrix: each intensity bin maps deterministically to its own impact bin.
        vul = VulnerabilityDistrib(self.hazard_type, intensities, impacts, np.eye(len(probs), len(probs)))
        event = HazardEventDistrib(self.hazard_type, intensities, probs)
        return vul, event


@applies_to_events([CoastalInundation])
@applies_to_assets([ThermalPowerGeneratingAsset])
class ThermalPowerGenerationCoastalInundationModel(ThermalPowerGenerationInundationModel):
    """Coastal-inundation flavour of the thermal power generation inundation model.

    Fixes the hazard type to CoastalInundation; everything else is inherited.
    """

    def __init__(
        self,
        *,
        indicator_id: str = "flood_depth",
        resource: str = ThermalPowerGenerationInundationModel._default_resource,
    ):
        """
        Args:
                indicator_id (str): ID of the coastal inundation indicator to request.
                resource (str): embedded resource identifier used to infer vulnerability table.
        """
        coastal = CoastalInundation
        super().__init__(
            indicator_id=indicator_id,
            resource=resource,
            hazard_type=coastal,
        )


@applies_to_events([RiverineInundation])
@applies_to_assets([ThermalPowerGeneratingAsset])
class ThermalPowerGenerationRiverineInundationModel(ThermalPowerGenerationInundationModel):
    """Riverine-inundation flavour of the thermal power generation inundation model.

    Fixes the hazard type to RiverineInundation; everything else is inherited.
    """

    def __init__(
        self,
        *,
        indicator_id: str = "flood_depth",
        resource: str = ThermalPowerGenerationInundationModel._default_resource,
    ):
        """
        Args:
                indicator_id (str): ID of the riverine inundation indicator to request.
                resource (str): embedded resource identifier used to infer vulnerability table.
        """
        riverine = RiverineInundation
        super().__init__(
            indicator_id=indicator_id,
            resource=resource,
            hazard_type=riverine,
        )


@applies_to_events([Drought])
@applies_to_assets([ThermalPowerGeneratingAsset])
class ThermalPowerGenerationDroughtModel(VulnerabilityModelBase):
    # Embedded resource holding the WRI vulnerability curves.
    _default_resource = "WRI thermal power plant physical climate vulnerability factors"
    _impact_based_on_a_single_point = False

    def __init__(
        self,
        *,
        resource: str = _default_resource,
        impact_based_on_a_single_point: bool = _impact_based_on_a_single_point,
    ):
        """
        Drought vulnerability model for thermal power generation.

        Args:
                resource (str): embedded resource identifier used to infer vulnerability table.
                impact_based_on_a_single_point (bool): if True, request the single indicator
                    "months/spei3m/below/-2" and spread it over the curve assuming a standard-normal SPEI;
                    otherwise request "months/spei12m/below/index" (month counts for a set of thresholds).
        """
        hazard_type = Drought
        curve_set: VulnerabilityCurves = get_vulnerability_curves_from_resource(resource)

        # for this model, key for looking up curves is asset_type, e.g. 'Steam/Recirculating'
        self.vulnerability_curves = {c.asset_type: c for c in curve_set.items if c.event_type == hazard_type.__name__}

        # Same curves grouped by turbine kind (the part of asset_type before '/').
        self.vuln_curves_by_type = defaultdict(list)
        for key, curve in self.vulnerability_curves.items():
            self.vuln_curves_by_type[TurbineKind[key.split("/")[0]]].append(curve)

        # Impact type of the first curve (all are parsed), or disruption when no curve matches.
        impact_types = [ImpactType[c.impact_type.lower()] for c in self.vulnerability_curves.values()]
        impact_type = impact_types[0] if impact_types else ImpactType.disruption

        super().__init__(
            indicator_id="months/spei3m/below/-2" if impact_based_on_a_single_point else "months/spei12m/below/index",
            hazard_type=hazard_type,
            impact_type=impact_type,
        )

    def get_data_requests(
        self, asset: Asset, *, scenario: str, year: int
    ) -> Union[HazardDataRequest, Iterable[HazardDataRequest]]:
        """Request the drought indicator selected at construction for the asset location."""
        return HazardDataRequest(
            self.hazard_type,
            asset.longitude,
            asset.latitude,
            scenario=scenario,
            year=year,
            indicator_id=self.indicator_id,
        )

    def _select_curves(self, asset: ThermalPowerGeneratingAsset) -> List[VulnerabilityCurve]:
        """Return the vulnerability curves applicable to the asset's turbine/cooling configuration."""
        if asset.turbine is None:
            # Unknown turbine: consider every curve.
            return list(self.vulnerability_curves.values())
        if asset.cooling is not None:
            # Fully specified asset: only the exact 'Turbine/Cooling' curve, if any.
            key = "/".join([asset.turbine.name, asset.cooling.name])
            return [self.vulnerability_curves[key]] if key in self.vulnerability_curves else []
        if asset.turbine in self.vuln_curves_by_type:
            # Known turbine, unknown cooling: every curve for that turbine kind.
            return self.vuln_curves_by_type[asset.turbine]
        return []

    def get_impact(self, asset: Asset, data_responses: List[HazardDataResponse]) -> ImpactDistrib:
        """Compute the drought impact distribution for a thermal power generating asset.

        Args:
            asset: must be a ThermalPowerGeneratingAsset (asserted).
            data_responses: list whose first element is the HazardParameterDataResponse for the
                requested SPEI indicator.

        Returns:
            ImpactDistrib whose first bin (impact 0.0) carries the probability of no drought month.
        """
        assert isinstance(asset, ThermalPowerGeneratingAsset)
        response = cast(HazardParameterDataResponse, data_responses[0])

        # The unit being number of months per year, we divide by 12 to express the result as a year fraction.
        intensities = np.array(response.parameters / 12.0)
        if len(intensities) == 1:
            thresholds = np.array([-2.0])  # hard-coded: matches the "months/spei3m/below/-2" indicator
        else:
            thresholds = np.array(response.param_defns)
            # "Fraction of year below threshold i" -> "fraction between threshold i+1 and threshold i".
            intensities[:-1] -= intensities[1:]

        curves = self._select_curves(asset)

        if 0 < len(curves):
            if len(intensities) == 1:
                # Spread the single "below -2" fraction across each curve's SPEI bands assuming SPEI ~ N(0, 1):
                # P(SPEI < x | SPEI < -2) = Phi(x) / Phi(-2). Keep the worst expected impact across curves.
                denominator = norm.cdf(thresholds[0])
                worst_impact = 0.0
                for curve in curves:
                    probabilities = np.array([norm.cdf(spei) / denominator for spei in curve.intensity])
                    probabilities[:-1] -= probabilities[1:]
                    expected_impact = sum(
                        [probability * mean for probability, mean in zip(probabilities, curve.impact_mean)]
                    )
                    worst_impact = max(worst_impact, expected_impact)
                impacts = [worst_impact]
            else:
                # Curves are reversed because np.interp needs increasing x values
                # (presumably the curves list SPEI thresholds in decreasing order).
                impacts = [
                    np.max([np.interp(threshold, curve.intensity[::-1], curve.impact_mean[::-1]) for curve in curves])
                    for threshold in thresholds
                ]
        else:
            impacts = [0.0 for _ in thresholds]

        # The point injected at the beginning of impacts/intensities
        # allows to successfully call to_exceedance() in the get_impact API:
        impact_distrib = ImpactDistrib(
            self.hazard_type,
            [0.0] + impacts,
            np.concatenate((np.array([1.0 - sum(intensities)]), intensities)),
            self.impact_type,
        )
        return impact_distrib


@applies_to_events([AirTemperature])
@applies_to_assets([ThermalPowerGeneratingAsset])
class ThermalPowerGenerationAirTemperatureModel(VulnerabilityModelBase):
    # Embedded resource holding the WRI vulnerability curves.
    _default_resource = "WRI thermal power plant physical climate vulnerability factors"
    _default_temperatures = [25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0]

    def __init__(self, *, resource: str = _default_resource, temperatures: List[float] = _default_temperatures):
        """
        Air temperature vulnerability model for thermal power generation.

        Args:
                resource (str): embedded resource identifier used to infer vulnerability table.
                temperatures (list[Float]): thresholds (degC) of the "days with average temperature above"
                    indicators. Assumed to be in increasing order, since consecutive indicator values are
                    differenced to obtain per-band probabilities. A copy is stored.
        """
        curve_set: VulnerabilityCurves = get_vulnerability_curves_from_resource(resource)

        # for this model, key for looking up curves is asset_type, e.g. 'Steam/Recirculating'
        self.vulnerability_curves = dict((c.asset_type, c) for c in curve_set.items if c.event_type == "AirTemperature")
        self.vuln_curves_by_type = defaultdict(list)
        for key in self.vulnerability_curves:
            self.vuln_curves_by_type[TurbineKind[key.split("/")[0]]].append(self.vulnerability_curves[key])

        # Impact type of the first curve (all are parsed), or disruption when no curve matches.
        impact_types = [ImpactType[c.impact_type.lower()] for c in self.vulnerability_curves.values()]
        impact_type = impact_types[0] if impact_types else ImpactType.disruption

        # Copy: the default is a class-level list shared by every instance, so storing it by reference
        # would let a mutation of one model's thresholds leak into all later models (and the caller's list).
        self.temperatures = list(temperatures)

        super().__init__(indicator_id="days_tas/above/{temp_c}c", hazard_type=AirTemperature, impact_type=impact_type)

    def get_data_requests(
        self, asset: Asset, *, scenario: str, year: int
    ) -> Union[HazardDataRequest, Iterable[HazardDataRequest]]:
        """Request, per threshold, the scenario/year indicator followed by the historical (2005) one."""
        data_request = []
        for temperature in self.temperatures:
            data_request.append(
                HazardDataRequest(
                    ChronicHeat,
                    asset.longitude,
                    asset.latitude,
                    scenario=scenario,
                    year=year,
                    indicator_id=self.indicator_id.format(temp_c=str(int(temperature))),
                )
            )
        for temperature in self.temperatures:
            data_request.append(
                HazardDataRequest(
                    ChronicHeat,
                    asset.longitude,
                    asset.latitude,
                    scenario="historical",
                    year=2005,
                    indicator_id=self.indicator_id.format(temp_c=str(int(temperature))),
                )
            )
        return data_request

    def _select_curves(self, asset: ThermalPowerGeneratingAsset) -> List[VulnerabilityCurve]:
        """Return the vulnerability curves applicable to the asset's turbine/cooling configuration."""
        if asset.turbine is None:
            return list(self.vulnerability_curves.values())
        if asset.cooling is not None:
            key = "/".join([asset.turbine.name, asset.cooling.name])
            return [self.vulnerability_curves[key]] if key in self.vulnerability_curves else []
        if asset.turbine in self.vuln_curves_by_type:
            return self.vuln_curves_by_type[asset.turbine]
        return []

    def get_impact(self, asset: Asset, data_responses: List[HazardDataResponse]) -> ImpactDistrib:
        """Compute the air-temperature impact distribution for a thermal power generating asset.

        Args:
            asset: must be a ThermalPowerGeneratingAsset (asserted).
            data_responses: 2 * len(self.temperatures) responses in get_data_requests order: scenario
                values for every threshold, then historical values for every threshold.

        Returns:
            ImpactDistrib whose first bin (impact 0.0) carries the probability of being below the lowest
            threshold.
        """
        assert isinstance(asset, ThermalPowerGeneratingAsset)

        assert 2 * len(self.temperatures) == len(data_responses)

        # The unit being number of days per year, we divide by 365 to express the result as a year fraction.
        baseline = [
            1.0 - cast(HazardParameterDataResponse, data_response).parameter / 365.0
            for data_response in data_responses[len(self.temperatures) :]
        ]

        # Threshold when it no longer makes technical or economical sense to keep power plant running.
        shutdown_air_temperature = 50

        # Temperature at which the power plant generates electricity with the designed maximum efficiency:
        # the historical temperature exceeded on 10% of the days of the year.
        design_air_temperature = np.interp(0.9, baseline, self.temperatures)

        intensities = np.array(
            [
                cast(HazardParameterDataResponse, data_response).parameter / 365.0
                for data_response in data_responses[: len(self.temperatures)]
            ]
        )
        # "Fraction of days above threshold i" -> "fraction between threshold i and threshold i+1".
        intensities[:-1] -= intensities[1:]

        curves = self._select_curves(asset)

        if 0 < len(curves):
            # Full disruption above the shutdown temperature, none below the design temperature, otherwise
            # the worst curve impact at the excess over the design temperature.
            impacts = [
                (
                    1.0
                    if shutdown_air_temperature < temperature
                    else (
                        0.0
                        if temperature < design_air_temperature
                        else np.max(
                            [
                                np.interp(temperature - design_air_temperature, curve.intensity, curve.impact_mean)
                                for curve in curves
                            ]
                        )
                    )
                )
                for temperature in self.temperatures
            ]
        else:
            impacts = [0.0 for _ in self.temperatures]

        # The point injected at the beginning of impacts/intensities
        # allows to successfully call to_exceedance() in the get_impact API:
        impact_distrib = ImpactDistrib(
            self.hazard_type,
            [0.0] + impacts,
            np.concatenate((np.array([1.0 - sum(intensities)]), intensities)),
            self.impact_type,
        )
        return impact_distrib


@applies_to_events([WaterTemperature])
@applies_to_assets([ThermalPowerGeneratingAsset])
class ThermalPowerGenerationWaterTemperatureModel(VulnerabilityModelBase):
    # Embedded resource holding the WRI vulnerability curves (including regulatory discharge limits).
    _default_resource = "WRI thermal power plant physical climate vulnerability factors"
    _default_correlation = 0.5

    def __init__(self, *, resource: str = _default_resource, correlation: float = _default_correlation):
        """
        Water temperature vulnerability model for thermal power generation.

        Args:
                resource (str): embedded resource identifier used to infer vulnerability table.
                correlation (float): correlation specifying the Gaussian copula which joins
                                     the marginal distributions of water temperature and WBGT.
        """
        curve_set: VulnerabilityCurves = get_vulnerability_curves_from_resource(resource)
        self.gaussian_copula = multivariate_normal(
            mean=np.array([0.0, 0.0]), cov=np.array([[1.0, correlation], [correlation, 1.0]])
        )

        # for this model, key for looking up curves is asset_type, e.g. 'Steam/Recirculating'
        self.vulnerability_curves = dict(
            (c.asset_type, c) for c in curve_set.items if c.event_type == "WaterTemperature"
        )
        self.vuln_curves_by_type = defaultdict(list)
        for key in self.vulnerability_curves:
            self.vuln_curves_by_type[TurbineKind[key.split("/")[0]]].append(self.vulnerability_curves[key])

        # Curves for the regulatory limit on discharge water temperature, keyed by asset_type.
        self.regulatory_discharge_curves = dict(
            (c.asset_type, c) for c in curve_set.items if c.event_type == "RegulatoryDischargeWaterLimit"
        )

        # Impact type of the first curve (all are parsed), or disruption when no curve matches.
        impact_types = [ImpactType[c.impact_type.lower()] for c in self.vulnerability_curves.values()]
        impact_type = impact_types[0] if impact_types else ImpactType.disruption

        super().__init__(indicator_id="weeks_water_temp_above", hazard_type=WaterTemperature, impact_type=impact_type)

    def get_data_requests(
        self, asset: Asset, *, scenario: str, year: int
    ) -> Union[HazardDataRequest, Iterable[HazardDataRequest]]:
        """Request, in order: scenario and historical (1991) water temperature, then scenario and
        historical (2005) WBGT. get_impact relies on this order."""
        data_request = []
        data_request.append(
            HazardDataRequest(
                ChronicHeat,
                asset.longitude,
                asset.latitude,
                scenario=scenario,
                year=year,
                indicator_id=self.indicator_id,
            ),
        )
        data_request.append(
            HazardDataRequest(
                ChronicHeat,
                asset.longitude,
                asset.latitude,
                scenario="historical",
                year=1991,
                indicator_id=self.indicator_id,
            ),
        )
        data_request.append(
            HazardDataRequest(
                ChronicHeat,
                asset.longitude,
                asset.latitude,
                scenario=scenario,
                year=year,
                indicator_id="days_wbgt_above",
            ),
        )
        data_request.append(
            HazardDataRequest(
                ChronicHeat,
                asset.longitude,
                asset.latitude,
                scenario="historical",
                year=2005,
                indicator_id="days_wbgt_above",
            ),
        )
        return data_request

    def _select_curves(self, asset: ThermalPowerGeneratingAsset) -> List[VulnerabilityCurve]:
        """Return the vulnerability curves applicable to the asset's turbine/cooling configuration."""
        if asset.turbine is None:
            return list(self.vulnerability_curves.values())
        if asset.cooling is not None:
            key = "/".join([asset.turbine.name, asset.cooling.name])
            return [self.vulnerability_curves[key]] if key in self.vulnerability_curves else []
        if asset.turbine in self.vuln_curves_by_type:
            return self.vuln_curves_by_type[asset.turbine]
        return []

    @staticmethod
    def _to_bin_probabilities(exceedance_fractions: np.ndarray) -> np.ndarray:
        """Convert fractions of the year above increasing thresholds into per-bin probabilities.

        Returns a NEW array (the input is never modified): the probability of being below the first
        threshold, followed by the probability of each band between consecutive thresholds, and finally
        the probability of being above the last threshold.
        """
        bins = np.array(exceedance_fractions)  # copy, because it is differenced in place below
        bins[:-1] -= bins[1:]
        return np.concatenate((np.array([1.0 - sum(bins)]), bins))

    def get_impact(self, asset: Asset, data_responses: List[HazardDataResponse]) -> ImpactDistrib:
        """Compute the water-temperature impact distribution for a thermal power generating asset.

        Args:
            asset: must be a ThermalPowerGeneratingAsset (asserted).
            data_responses: the four responses in get_data_requests order (scenario and historical
                weeks_water_temp_above, then scenario and historical days_wbgt_above).

        Returns:
            The distribution of the applicable curve with the largest mean impact, or an all-zero
            distribution when no curve applies.
        """
        assert isinstance(asset, ThermalPowerGeneratingAsset)
        assert len(data_responses) == 4

        water_temperature_scenario = cast(HazardParameterDataResponse, data_responses[0])
        water_temperature_baseline = cast(HazardParameterDataResponse, data_responses[1])
        wbgt_scenario = cast(HazardParameterDataResponse, data_responses[2])
        wbgt_baseline = cast(HazardParameterDataResponse, data_responses[3])

        # Water temperature below which the power plant does not experience generation losses:
        # the historical water temperature exceeded during 10% of the weeks of the year.
        design_intake_water_temperature = cast(
            float,
            np.interp(
                0.9,
                # The unit being number of weeks per year, we divide by 52 to express the result as a year fraction.
                1.0 - water_temperature_baseline.parameters / 52.0,
                water_temperature_baseline.param_defns,
            ),
        )

        # Linear relationship between the outlet (discharge) and intake water temperatures
        # (outlet = 1.0191 * intake + 9.7951): cap the design intake temperature at the value for which
        # the outlet temperature reaches 35.
        design_intake_water_temperature_for_recirculating_steam_unit = min(
            design_intake_water_temperature, (35.0 - 9.7951) / 1.0191
        )

        # WBGT below which the recirculating steam unit does not experience any water-temperature-related
        # generation losses: the historical WBGT exceeded on 1% of the days of the year.
        design_wbgt_threshold = cast(
            float,
            np.interp(
                0.99,
                # The unit being number of days per year, we divide by 365 to express the result as a year fraction.
                1.0 - wbgt_baseline.parameters / 365.0,
                wbgt_baseline.param_defns,
            ),
        )

        # Scenario fraction of the year with WBGT above the design threshold.
        impact_scale_for_recirculating_steam_unit = cast(
            float,
            np.interp(
                design_wbgt_threshold,
                wbgt_scenario.param_defns,
                wbgt_scenario.parameters,
            )
            / 365.0,
        )

        intake_water_temperatures = water_temperature_scenario.param_defns
        intake_water_temperature_intensities = water_temperature_scenario.parameters / 52.0

        if impact_scale_for_recirculating_steam_unit == 0.0:
            intake_water_temperature_intensities_for_recirculating_steam_unit = intake_water_temperature_intensities
        else:
            # Probability of water temperature above each threshold conditional on WBGT above its design
            # threshold, with the two joined by the Gaussian copula.
            gaussian_threshold: float = norm.ppf(impact_scale_for_recirculating_steam_unit)
            intake_water_temperature_intensities_for_recirculating_steam_unit = np.array(
                [
                    (
                        intake_water_temperature_intensity
                        if intake_water_temperature_intensity == 0.0 or intake_water_temperature_intensity == 1.0
                        else self.gaussian_copula.cdf(
                            np.array([norm.ppf(intake_water_temperature_intensity), gaussian_threshold])
                        )
                        / impact_scale_for_recirculating_steam_unit
                    )
                    for intake_water_temperature_intensity in intake_water_temperature_intensities
                ]
            )

        # _to_bin_probabilities works on copies. The two arrays above are the SAME object when the impact
        # scale is 0.0, so differencing them in place would difference the shared array twice and corrupt
        # the recirculating-steam probabilities.
        intake_water_temperature_intensities = self._to_bin_probabilities(intake_water_temperature_intensities)
        intake_water_temperature_intensities_for_recirculating_steam_unit = self._to_bin_probabilities(
            intake_water_temperature_intensities_for_recirculating_steam_unit
        )

        curves = self._select_curves(asset)

        impact_distrib_by_curve: List[ImpactDistrib] = []
        for curve in curves:
            scale = 1.0
            threshold = design_intake_water_temperature
            intensities = intake_water_temperature_intensities
            if curve.asset_type == "Steam/Recirculating":
                scale = impact_scale_for_recirculating_steam_unit
                threshold = design_intake_water_temperature_for_recirculating_steam_unit
                intensities = intake_water_temperature_intensities_for_recirculating_steam_unit
            impacts = [
                (
                    0.0
                    if intake_water_temperature < threshold
                    else scale
                    * cast(
                        float,
                        np.interp(
                            intake_water_temperature - threshold,
                            curve.intensity,
                            curve.impact_mean,
                        ),
                    )
                )
                for intake_water_temperature in intake_water_temperatures
            ]
            # The regulatory discharge limit (if any) acts as a floor on the impact.
            if curve.asset_type in self.regulatory_discharge_curves:
                regulatory_discharge_curve = self.regulatory_discharge_curves[curve.asset_type]
                impacts = [
                    max(
                        impact,
                        cast(
                            float,
                            np.interp(
                                intake_water_temperature,
                                regulatory_discharge_curve.intensity,
                                regulatory_discharge_curve.impact_mean,
                            ),
                        ),
                    )
                    for impact, intake_water_temperature in zip(impacts, intake_water_temperatures)
                ]

            # The point injected at the beginning of impacts/intensities
            # allows to successfully call to_exceedance() in the get_impact API:
            impact_distrib_by_curve.append(
                ImpactDistrib(
                    self.hazard_type,
                    [0.0] + impacts,
                    intensities,
                    self.impact_type,
                )
            )

        if 0 < len(impact_distrib_by_curve):
            # Worst curve by mean impact (stable sort: the last one wins on ties).
            impact_distrib = sorted(impact_distrib_by_curve, key=lambda x: x.mean_impact())[-1]
        else:
            impact_distrib = ImpactDistrib(
                self.hazard_type,
                [0.0] + [0.0 for _ in intake_water_temperatures],
                intake_water_temperature_intensities,
                self.impact_type,
            )

        return impact_distrib


@applies_to_events([WaterRisk])
@applies_to_assets([ThermalPowerGeneratingAsset])
class ThermalPowerGenerationWaterStressModel(VulnerabilityModelBase):
    # Embedded resource holding the WRI vulnerability curves.
    _default_resource = "WRI thermal power plant physical climate vulnerability factors"

    def __init__(self, *, resource: str = _default_resource):
        """
        Water stress vulnerability model for thermal power generation.

        Args:
                resource (str): embedded resource identifier used to infer vulnerability table.
        """
        curve_set: VulnerabilityCurves = get_vulnerability_curves_from_resource(resource)

        # Curves keyed by asset_type, e.g. 'Steam/Recirculating'.
        self.vulnerability_curves = {c.asset_type: c for c in curve_set.items if c.event_type == "WaterStress"}

        # Same curves grouped by turbine kind (the part of asset_type before '/').
        self.vuln_curves_by_type = defaultdict(list)
        for asset_type, curve in self.vulnerability_curves.items():
            turbine_kind = TurbineKind[asset_type.split("/")[0]]
            self.vuln_curves_by_type[turbine_kind].append(curve)

        # Impact type of the first curve (all are parsed), or disruption when no curve matches.
        parsed_impact_types = [ImpactType[c.impact_type.lower()] for c in self.vulnerability_curves.values()]
        if parsed_impact_types:
            impact_type = parsed_impact_types[0]
        else:
            impact_type = ImpactType.disruption

        super().__init__(indicator_id="water_stress", hazard_type=WaterRisk, impact_type=impact_type)

    def get_data_requests(
        self, asset: Asset, *, scenario: str, year: int
    ) -> Union[HazardDataRequest, Iterable[HazardDataRequest]]:
        """Request scenario water stress, scenario water supply and historical (1999) water supply, in that order."""
        request_specs = [
            (scenario, year, "water_stress"),
            (scenario, year, "water_supply"),
            ("historical", 1999, "water_supply"),
        ]
        return [
            HazardDataRequest(
                WaterRisk,
                asset.longitude,
                asset.latitude,
                scenario=spec_scenario,
                year=spec_year,
                indicator_id=spec_indicator,
            )
            for spec_scenario, spec_year, spec_indicator in request_specs
        ]

    def _select_curves(self, asset: ThermalPowerGeneratingAsset) -> List[VulnerabilityCurve]:
        """Return the vulnerability curves applicable to the asset's turbine/cooling configuration."""
        if asset.turbine is None:
            return list(self.vulnerability_curves.values())
        if asset.cooling is not None:
            key = "/".join([asset.turbine.name, asset.cooling.name])
            return [self.vulnerability_curves[key]] if key in self.vulnerability_curves else []
        if asset.turbine in self.vuln_curves_by_type:
            return self.vuln_curves_by_type[asset.turbine]
        return []

    def get_impact(self, asset: Asset, data_responses: List[HazardDataResponse]) -> ImpactDistrib:
        """Compute the water-stress impact distribution for a thermal power generating asset.

        Args:
            asset: must be a ThermalPowerGeneratingAsset (asserted).
            data_responses: the three responses in get_data_requests order.

        Returns:
            Two-point ImpactDistrib: impact 0.0 with probability 1 - p and the curve impact with
            probability p, where p is the probability of water stress above 40%.
        """
        assert isinstance(asset, ThermalPowerGeneratingAsset)
        assert len(data_responses) == 3

        # Water stress is (naively) taken to be water_stress - 0.5 + U(0,1), hence
        # P(stress > 0.4) = 0.1 + water_stress, clipped to [0, 1].
        water_stress = cast(HazardParameterDataResponse, data_responses[0]).parameter
        probability_water_stress_above_40pct = max(0.0, min(1.0, 0.1 + water_stress))

        # Relative change of water supply versus the historical baseline (0.0 when the baseline is zero).
        baseline_water_supply = cast(HazardParameterDataResponse, data_responses[2]).parameter
        if baseline_water_supply == 0.0:
            supply_reduction_rate = 0.0
        else:
            future_water_supply = cast(HazardParameterDataResponse, data_responses[1]).parameter
            supply_reduction_rate = future_water_supply / baseline_water_supply - 1.0

        curves = self._select_curves(asset)
        if curves:
            impact = np.max([np.interp(-supply_reduction_rate, c.intensity, c.impact_mean) for c in curves])
        else:
            impact = 0.0

        # The point injected at the beginning of impacts/intensities
        # allows to successfully call to_exceedance() in the get_impact API:
        return ImpactDistrib(
            self.hazard_type,
            np.array([0.0, impact]),
            np.array([1.0 - probability_water_stress_above_40pct, probability_water_stress_above_40pct]),
            self.impact_type,
        )
