# -*- coding: utf-8 -*-
"""
A simple abstract base class with method stubs to enable users to extend
QuakeMigrate with custom onset functions that remain compatible with the core
of the package.

Also contains a lightweight class to encapsulate the data generated by the
onset function, to be used for migration or phase picking.

:copyright:
    2020 - 2021, QuakeMigrate developers.
:license:
    GNU General Public License, Version 3
    (https://www.gnu.org/licenses/gpl-3.0.html)

"""

from abc import ABC, abstractmethod

import numpy as np

import quakemigrate.util as util


class Onset(ABC):
    """
    Abstract base class for QuakeMigrate onset functions.

    Subclasses must implement `calculate_onsets()` and override the `pre_pad`
    and `post_pad` properties; until they do, the class cannot be
    instantiated.

    Attributes
    ----------
    sampling_rate : int
        Desired sampling rate for input data; sampling rate at which the onset
        functions will be computed. Required keyword argument.
    pre_pad : float
        Option to override the default pre-pad duration of data to read before
        computing 4-D coalescence in detect() and locate(). The base
        implementation stores it in `_pre_pad`, initialised to 0.
    post_pad : float
        Option to override the default post-pad duration of data to read
        before computing 4-D coalescence in detect() and locate(). The base
        implementation stores it in `_post_pad`, initialised to 0.

    Methods
    -------
    calculate_onsets()
        Generate onset functions that represent seismic phase arrivals.
        Abstract - must be implemented by subclasses.
    pad(timespan)
        Create appropriate padding to include the taper.
    gaussian_halfwidth(phase)
        Stub that raises AttributeError; override it in order to use the
        GaussianPicker module with a custom Onset.

    """

    def __init__(self, **kwargs):
        """
        Instantiate the Onset object.

        Parameters
        ----------
        sampling_rate : int
            Sampling rate at which the onset functions will be computed. Must
            be supplied as a keyword argument. Any other keyword arguments are
            ignored by this base implementation.

        Raises
        ------
        ValueError
            If `sampling_rate` is not supplied (or is None).

        """

        self.sampling_rate = kwargs.get("sampling_rate")
        if self.sampling_rate is None:
            raise ValueError("Must specify 'sampling_rate' for any Onset.")

        # Storage used by the base-class pre_pad / post_pad property stubs.
        self._pre_pad = 0
        self._post_pad = 0

    def __str__(self):
        """Return short summary string of the Onset object."""

        return "Base Onset object - add a __str__ method to your Onset class"

    def pad(self, timespan):
        """
        Determine the durations of data with which to pre- and post-pad the
        timespan, including extra padding for any tapering applied to data.

        The timespan is first extended by `pre_pad` + `post_pad`; 6% of the
        extended timespan, rounded up with `np.ceil`, is then added to each of
        `pre_pad` and `post_pad`. Both totals are passed through
        `util.trim2sample` with this onset's `sampling_rate` (presumably to
        align them to the sample interval - see quakemigrate.util).

        Parameters
        ----------
        timespan : float
            The time window to pad.

        Returns
        -------
        pre_pad : float
            Total pre-pad duration: the onset-specific `pre_pad` plus the
            taper padding.
        post_pad : float
            Total post-pad duration: the onset-specific `post_pad` plus the
            taper padding.

        """

        # Add additional padding for any tapering applied to data; the taper
        # allowance is sized on the window *including* the onset's own pads.
        timespan += (self.pre_pad + self.post_pad)
        taper_pad = np.ceil(timespan*0.06)
        pre_pad = util.trim2sample(self.pre_pad + taper_pad,
                                   self.sampling_rate)
        post_pad = util.trim2sample(self.post_pad + taper_pad,
                                    self.sampling_rate)

        return pre_pad, post_pad

    def gaussian_halfwidth(self, phase):
        """
        Method stub for Gaussian half-width estimate.

        Raises
        ------
        AttributeError
            Always; custom Onsets must override this to use GaussianPicker.

        """

        msg = ("In order to use the GaussianPicker module with a custom Onset,"
               " you need to provide a 'gaussian_halfwidth' method.")

        raise AttributeError(msg)

    @abstractmethod
    def calculate_onsets(self):
        """Method stub for calculation of onset functions."""
        pass

    @property
    @abstractmethod
    def pre_pad(self):
        """Get property stub for pre_pad."""
        return self._pre_pad

    @pre_pad.setter
    @abstractmethod
    def pre_pad(self, value):
        """Set property stub for pre_pad."""
        self._pre_pad = value

    @property
    @abstractmethod
    def post_pad(self):
        """Get property stub for post_pad."""
        return self._post_pad

    @post_pad.setter
    @abstractmethod
    def post_pad(self, value):
        """Set property stub for post_pad."""
        self._post_pad = value


class OnsetData:
    """
    Lightweight container for the onset functions produced by an onset
    detection algorithm (characteristic function), together with the
    intermediate processed waveforms and the metadata needed to use them.

    It records which onset functions are available for each station and phase,
    and keeps the filtered / otherwise pre-processed waveform data from which
    those onset functions were calculated.

    Parameters
    ----------
    onsets : dict of dicts
        Outer keys are station names; each inner dict is keyed by phase, e.g.
        {"station": {"P": `p_onset`, "S": `s_onset`}}. Each onset function is
        the raw seismic data transformed by a characteristic function designed
        to highlight phase arrivals.
    phases : list of str
        Phases for which onsets have been calculated (e.g. ["P", "S"]).
    channel_maps : dict of str
        Data component map for each phase (e.g. {"P": "Z"}).
    filtered_waveforms : `obspy.Stream` object
        Filtered and/or resampled and otherwise processed seismic data produced
        while generating the onset functions. Only waveforms that passed the
        quality control criteria are included, all at a common sampling rate -
        see `sampling_rate`.
    availability : dict
        Keyed by "station_phase"; each value is 1 or 0 depending on whether an
        onset function is available for that station and phase (as decided by
        data availability and quality checks).
    starttime : `obspy.UTCDateTime` object
        Start time of the onset functions.
    endtime : `obspy.UTCDateTime` object
        End time of the onset functions.
    sampling_rate : int
        Sampling rate of the filtered waveforms and onset functions.

    Each parameter is stored unchanged as an attribute of the same name.

    """

    def __init__(self, onsets, phases, channel_maps, filtered_waveforms,
                 availability, starttime, endtime, sampling_rate):
        """Store the onset functions and their metadata on the instance."""

        # Onset functions plus the supporting data used to build them.
        (self.onsets, self.phases, self.channel_maps,
         self.filtered_waveforms, self.availability) = (
            onsets, phases, channel_maps, filtered_waveforms, availability)

        # Time span and sampling rate common to every onset function.
        self.starttime, self.endtime, self.sampling_rate = (
            starttime, endtime, sampling_rate)
