"""
Generator for simulated car noise signals (engine harmonics plus filtered noise sources).
"""
from __future__ import annotations

import numpy as np
from scipy.signal import butter, lfilter


def _butter_filter(
    order: int,
    cutoff_hz: float | list | tuple | np.ndarray,
    btype: str,
    sample_rate: int,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Design a Butterworth filter in transfer-function (numerator/denominator) form.

    Args:
        order (int): Order of the filter.
        cutoff_hz (float, list, tuple or np.ndarray): Cutoff frequency (or band
            edges) in Hz.
        btype (str): Type of the filter, forwarded to ``scipy.signal.butter``.
        sample_rate (int): Sampling frequency in Hz.

    Returns:
        numer (np.ndarray): Numerator coefficients of the filter.
        denom (np.ndarray): Denominator coefficients of the filter.
    """
    # butter() without an ``fs`` argument expects cutoffs normalised to the
    # Nyquist frequency. np.asarray also accepts tuples and scalars; the
    # previous list-only conversion made tuple cutoffs fail on division.
    nyquist_hz = sample_rate / 2.0
    numer, denom = butter(
        order, np.asarray(cutoff_hz) / nyquist_hz, btype=btype, output="ba"
    )
    return numer, denom


class CarNoiseSignalGenerator:
    """
    A class to generate car noise signal.

    The constructor takes the sample_rate and duration of
    the generated signals.

    The method `generate_car_noise` takes parameters for the
    noise and generates the signal. These parameters
    can be generated by the CarNoiseParameters class.


    Example:
        >>> car_noise_parameters = CarNoiseParameters(random_flag=True)
        >>> parameters = car_noise_parameters.gen_parameters(speed_kph=100)

        >>> car_noise = CarNoiseSignalGenerator(
        ...     sample_rate=44100, duration_secs=5, random_flag=True
        ... )
        >>> car_noise_signal = car_noise.generate_car_noise(parameters, 3, 0.5)
    """

    # Not referenced by any method of this class; kept as a public constant.
    # NOTE(review): presumably a reference level used by callers - confirm.
    REFERENCE_CONSTANT_DB = 30
    # Overall gain applied to every channel returned by `generate_car_noise`.
    FINAL_MULTIPLIER = 1 / 2000
    # Backward-compatible alias for the historical misspelled name.
    FINAl_MULTIPLIER = FINAL_MULTIPLIER

    def __init__(
        self,
        sample_rate: int,
        duration_secs: float,
        random_flag: bool = True,
    ):
        """Constructor takes the sample_rate and duration of the generated signals.

        Args:
            sample_rate (int): Sample rate of the generated signal in Hz.
            duration_secs (float): Duration of the generated signal in seconds.
                Non-integer durations are supported.
            random_flag (bool, optional): Flag to indicate whether add some randomness
                to the output signals. Defaults to True.
        """
        self.sample_rate = sample_rate
        self.duration_secs = duration_secs
        self.random_flag = random_flag

        self.sampleduration_secs = 1 / self.sample_rate
        # Round the total sample count, not the duration: rounding the duration
        # first silently truncated e.g. 2.5 s to 2 s worth of samples.
        self.duration_samples = round(self.duration_secs * self.sample_rate)
        # Sample instants in SECONDS (the "_ms" suffix is historical; kept so
        # external users of this attribute keep working).
        self.timevector_ms = np.arange(self.duration_samples) * self.sampleduration_secs

    def generate_car_noise(
        self,
        noise_parameters: dict,
        number_noise_sources: int,
        commonness_factor: float,
    ) -> np.ndarray:
        """
        Method that generates the car noise signal.
        It designs the filters described by the parameters and calls the methods
        that generate the independent parts.

        Args:
            noise_parameters (dict): Dictionary with the parameters for the noise
                as generated by the CarNoiseParameters class. Must contain
                "reference_level_db", "speed", "rpm", "engine_num_harmonics" and,
                under each of "primary_filter", "secondary_filter", "bump",
                "dip_low" and "dip_high", a dict of keyword arguments for
                ``_butter_filter`` (everything except ``sample_rate``).
            number_noise_sources (int): Number of non-engine noise sources.
            commonness_factor (float): Weight of a shared noise signal mixed into
                every non-engine source to correlate them (0 keeps them
                independent). The engine row is not affected.

        Returns:
            np.ndarray: Car noise signal of shape
                (number_noise_sources + 1, duration_samples). Row 0 is the engine
                noise; rows 1..number_noise_sources are the other noise sources.

        Raises:
            ValueError: If any of the five filter parameter entries is missing.
        """
        # Reference level as provided by the parameter generator.
        referencelevel_db = noise_parameters["reference_level_db"]

        # Design every filter once; all sources share the same coefficients.
        filter_numer_denom: dict = {}
        for filter_type in (
            "primary_filter",
            "secondary_filter",
            "bump",
            "dip_low",
            "dip_high",
        ):
            if filter_type not in noise_parameters:
                raise ValueError(f"Parameter for filter {filter_type} are mandatory.")
            numer, denom = _butter_filter(
                **noise_parameters[filter_type], sample_rate=self.sample_rate
            )
            filter_numer_denom[filter_type] = {"numer": numer, "denom": denom}

        # Row 0 holds the engine noise, rows 1..N the other sources.
        car_noise = np.zeros((number_noise_sources + 1, self.duration_samples))
        extra_noise_for_coherence = np.zeros(self.duration_samples)

        car_noise[0, :] = self.generate_engine_noise(
            speed=noise_parameters["speed"],
            rpm=noise_parameters["rpm"],
            engine_num_harmonics=noise_parameters["engine_num_harmonics"],
            reference_level_db=referencelevel_db,
            primary_filter=filter_numer_denom["primary_filter"],
            secondary_filter=filter_numer_denom["secondary_filter"],
        )

        # N + 1 sources are drawn: the first N are stored, and the last one is
        # only used as the shared signal that correlates them.
        for n in range(1, number_noise_sources + 2):
            source_n_noise = self.generate_source_noise(
                reference_level_db=referencelevel_db,
                primary_filter=filter_numer_denom["primary_filter"],
                secondary_filter=filter_numer_denom["secondary_filter"],
                bump_filter=filter_numer_denom["bump"],
                dip_low_filter=filter_numer_denom["dip_low"],
                dip_high_filter=filter_numer_denom["dip_high"],
            )
            if n <= number_noise_sources:
                car_noise[n, :] = source_n_noise
            else:
                extra_noise_for_coherence = source_n_noise

        # Add some correlation between the non-engine sources.
        for n in range(1, number_noise_sources + 1):
            car_noise[n, :] = self.apply_commonness(
                car_noise[n, :], extra_noise_for_coherence, commonness_factor
            )

        return car_noise * self.FINAL_MULTIPLIER

    def generate_source_noise(
        self,
        reference_level_db: float,
        primary_filter: dict[str, np.ndarray],
        secondary_filter: dict[str, np.ndarray],
        bump_filter: dict[str, np.ndarray],
        dip_low_filter: dict[str, np.ndarray],
        dip_high_filter: dict[str, np.ndarray],
    ) -> np.ndarray:
        """
        Method that generates the noise of a single source.

        Steps:
        1. Draw Gaussian noise with a level of reference_level_db + 35 dB and
           apply the primary filter.
        2. Add Gaussian "bump" noise (level from `get_bump_params`) that has
           been passed through the bump filter.
        3. Pass the sum through the dip-low and dip-high filters in parallel
           and add their outputs.
        4. Apply the secondary filter.

        Each filter is a dict with "numer" and "denom" coefficient arrays.

        Args:
            reference_level_db (float): Reference level in dB
            primary_filter (Dict[str, np.ndarray]): Primary filter
            secondary_filter (Dict[str, np.ndarray]): Secondary filter
            bump_filter (Dict[str, np.ndarray]): Bump filter
            dip_low_filter (Dict[str, np.ndarray]): Low dip filter
            dip_high_filter (Dict[str, np.ndarray]): High dip filter

        Returns:
            np.ndarray: Noise of a single source, length duration_samples.
        """
        lowpasslevel_db = reference_level_db + 35  # puts it slightly quieter than the tones
        lowpassnoise_gaussianstd = 10 ** (lowpasslevel_db / 20)

        # The broadband noise is drawn before get_bump_params() so the random
        # stream is consumed in the same order as before.
        source_noise = (
            np.random.normal(size=self.duration_samples) * lowpassnoise_gaussianstd
        )
        source_noise = lfilter(
            primary_filter["numer"], primary_filter["denom"], source_noise
        )

        # Add bump to the noise.
        _, bump_gaussianstd = self.get_bump_params(reference_level_db)
        bump_waveform = np.random.normal(size=self.duration_samples) * bump_gaussianstd
        bump_waveform = lfilter(
            bump_filter["numer"], bump_filter["denom"], bump_waveform
        )
        source_noise = source_noise + bump_waveform

        # Add dip: recombine the low and high filtered parts.
        dip_lowerpart = lfilter(
            dip_low_filter["numer"], dip_low_filter["denom"], source_noise
        )
        dip_upperpart = lfilter(
            dip_high_filter["numer"], dip_high_filter["denom"], source_noise
        )
        source_noise = dip_lowerpart + dip_upperpart

        # Final secondary filter.
        return lfilter(
            secondary_filter["numer"], secondary_filter["denom"], source_noise
        )

    def generate_engine_noise(
        self,
        speed: float,
        rpm: float,
        reference_level_db: float,
        engine_num_harmonics: int,
        primary_filter: dict[str, np.ndarray],
        secondary_filter: dict[str, np.ndarray],
    ) -> np.ndarray:
        """
        Method that generates the noise of the engine.

        It sums one sinusoid per harmonic (frequencies and levels from
        `get_engine_params`, uniformly random starting phase), then applies
        the primary and secondary filters.

        Args:
            speed (float): Speed of the car in km/h
            rpm (float): RPM of the engine
            reference_level_db (float): Reference level in dB
            engine_num_harmonics (int): Number of harmonics of the engine
            primary_filter (Dict[str, np.ndarray]): Primary filter
            secondary_filter (Dict[str, np.ndarray]): Secondary filter

        Returns:
            np.ndarray: Noise of the engine, length duration_samples.
        """
        harmonic_freqs_hz, harmonic_power_db = self.get_engine_params(
            speed, rpm, reference_level_db, engine_num_harmonics
        )

        engine_noise = np.zeros(self.duration_samples)
        for c in range(engine_num_harmonics):
            amplitude = 10 ** (harmonic_power_db[c] / 20)
            # A phase is drawn for every harmonic, including muted (-99 dB) ones.
            phase_rads = np.random.rand() * 2 * np.pi
            angularfreq_rads = 2 * np.pi * harmonic_freqs_hz[c]
            engine_noise += amplitude * np.sin(
                angularfreq_rads * self.timevector_ms + phase_rads
            )

        # Applies primary and secondary filters.
        engine_noise = lfilter(
            primary_filter["numer"], primary_filter["denom"], engine_noise
        )
        return lfilter(
            secondary_filter["numer"], secondary_filter["denom"], engine_noise
        )

    def get_bump_params(self, reference_level_db: float) -> tuple[float, float]:
        """
        Method that gets the parameters of the bump noise.

        The bump level is the low-pass level (reference + 35 dB) plus 6 dB when
        `random_flag` is False, or plus a random integer in [0, 6] dB otherwise.

        Args:
            reference_level_db (float): Reference level in dB

        Returns:
            lowpass_noise_gaussian_std (float): Standard deviation of the low pass noise
            bump_gaussian_std (float): Standard deviation of the bump noise
        """
        lowpasslevel_db = reference_level_db + 35  # puts it slightly quieter than the tones
        lowpass_noise_gaussian_std = 10 ** (lowpasslevel_db / 20)

        bump_level_db = lowpasslevel_db + np.random.choice(
            np.arange(0 if self.random_flag else 6, 6 + 1, 1)
        )
        bump_gaussian_std = 10 ** (bump_level_db / 20)

        return lowpass_noise_gaussian_std, bump_gaussian_std

    def get_engine_params(
        self,
        speed: float,
        rpm: float,
        reference_level_db: float,
        engine_num_harmonics: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Method that gets the parameters of the engine noise.

        Harmonic k (1-based) sits at k * rpm / 60 Hz. Its level is
        reference_level_db - 3 + 0.134 * speed, plus a jitter drawn from
        -2..2 dB in 0.1 dB steps when `random_flag` is set (0 dB otherwise).
        Muted harmonics are set to -99 dB. Without randomness, harmonics
        5, 7, 13, 14, 15, 19, 21 and 23 are muted. With randomness, each
        harmonic is muted with probability 0.333.

        Args:
            speed (float): Speed of the car in km/h
            rpm (float): RPM of the engine
            reference_level_db (float): Reference level in dB
            engine_num_harmonics (int): Number of harmonics of the engine

        Returns:
            harmoniccomplex_freqs_hz (np.ndarray): Frequency of each harmonic
            harmoniccomplex_power_db (np.ndarray): Level of each harmonic in dB
        """
        harmoniccomplex_f0 = rpm / 60
        harmoniccomplex_freqs_hz = harmoniccomplex_f0 * np.arange(
            1, engine_num_harmonics + 1
        )

        tone_speeddependence_level_dbperkph = 2 * 0.067
        toneconstant_db = -3
        base_level_db = (
            reference_level_db
            + toneconstant_db
            + tone_speeddependence_level_dbperkph * speed
        )

        # Loop-invariant jitter grid (-2 .. 2 dB, 0.1 dB steps).
        jitter_grid_db = np.arange(-2, 2.1, 0.1)
        muted_harmonics = (5, 7, 13, 14, 15, 19, 21, 23)

        harmoniccomplex_power_db = np.zeros(engine_num_harmonics)
        for n in range(engine_num_harmonics):
            # Without randomness the jitter is exactly 0 dB. The old code called
            # np.random.choice on the empty np.arange(0, 0, 0.1), which raised
            # ValueError.
            jitter_db = np.random.choice(jitter_grid_db) if self.random_flag else 0.0
            harmoniccomplex_power_db[n] = base_level_db + jitter_db

            if not self.random_flag:
                if n + 1 in muted_harmonics:
                    harmoniccomplex_power_db[n] = -99
            elif np.random.uniform(0, 1) < 0.333:
                harmoniccomplex_power_db[n] = -99

        return harmoniccomplex_freqs_hz, harmoniccomplex_power_db

    @staticmethod
    def apply_commonness(
        target_signal: np.ndarray,
        coherence_signal: np.ndarray,
        commonness_factor: float,
    ) -> np.ndarray:
        """Function to apply correlation between the target signal
        using the coherence signal.

        Computes ((1 - c) * target + c * coherence) / sqrt((1 - c)**2 + c**2),
        where c is the commonness factor.

        Args:
            target_signal (np.ndarray): Target signal
            coherence_signal (np.ndarray): Coherence signal
            commonness_factor (float): Commonness factor

        Returns:
            target_signal (np.ndarray): Target signal with the coherence signal
        """
        target_signal = (
            1 - commonness_factor
        ) * target_signal + commonness_factor * coherence_signal

        # Correct amplitude for the sum of two gaussians.
        correctionfactor = 1 / np.sqrt(
            (1 - commonness_factor) ** 2 + commonness_factor**2
        )
        return target_signal * correctionfactor
