import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mhkit.wave.resource import significant_wave_height as _sig_wave_height
from mhkit.wave.resource import peak_period as _peak_period
from mhkit.river.graphics import _xy_plot


def plot_spectrum(S, ax=None):
    """
    Plots a wave spectrum against angular frequency omega.

    The frequency index of ``S`` is converted from Hz to rad/s (times 2*pi)
    and the spectral density is divided by 2*pi, so the plotted density is
    per rad/s.

    Parameters
    ------------
    S: pandas DataFrame
        Spectral density [m^2/Hz] indexed by frequency [Hz]
    ax : matplotlib axes object
        Axes for plotting.  If None, then a new figure is created.

    Returns
    ---------
    ax : matplotlib pyplot axes
    """
    assert isinstance(S, pd.DataFrame), 'S must be of type pd.DataFrame'

    freq_hz = S.index
    omega = freq_hz * 2 * np.pi
    density_per_rad = S / (2 * np.pi)

    return _xy_plot(omega, density_per_rad, fmt='-', xlabel='omega [rad/s]',
                    ylabel='Spectral density [m$^2$s/rad]', ax=ax)

def plot_elevation_timeseries(eta, ax=None):
    """
    Plot wave surface elevation time-series

    Parameters
    ------------
    eta: pandas DataFrame
        Wave surface elevation [m] indexed by time [datetime or s]
    ax : matplotlib axes object
        Axes for plotting.  If None, then a new figure is created.

    Returns
    ---------
    ax : matplotlib pyplot axes

    """

    assert isinstance(eta, pd.DataFrame), 'eta must be of type pd.DataFrame'

    # Raw string: '\e' is not a valid Python escape sequence, and a plain
    # literal triggers a SyntaxWarning on Python 3.12+. The text passed to
    # matplotlib (the LaTeX '\eta') is unchanged.
    ax = _xy_plot(eta.index, eta, fmt='-', xlabel='Time',
                  ylabel=r'$\eta$ [m]', ax=ax)

    return ax

def plot_matrix(M, xlabel='Te', ylabel='Hm0', zlabel=None, show_values=True,
                ax=None):
    """
    Plots values in the matrix as a scatter diagram

    Parameters
    ------------
    M: pandas DataFrame
        Matrix with numeric labels for x and y axis, and numeric entries.
        An example would be the average capture length matrix generated by
        mhkit.device.wave, or something similar.
    xlabel: string (optional)
        Title of the x-axis
    ylabel: string (optional)
        Title of the y-axis
    zlabel: string (optional)
        Colorbar label
    show_values : bool (optional)
        Show values (formatted to two decimals) on the scatter diagram;
        NaN entries are left blank
    ax : matplotlib axes object
        Axes for plotting.  If None, then a new figure is created.

    Returns
    ---------
    ax : matplotlib pyplot axes

    """

    assert isinstance(M, pd.DataFrame), 'M must be of type pd.DataFrame'

    if ax is None:
        plt.figure()
        ax = plt.gca()

    # Columns map to x positions and rows to y positions (origin at bottom).
    im = ax.imshow(M, origin='lower', aspect='auto')

    # Attach the colorbar through ax's own figure and to ax explicitly;
    # plt.colorbar(im) relies on pyplot's current figure/axes, which need
    # not be the axes the image was drawn on when ax is supplied.
    cbar = ax.figure.colorbar(im, ax=ax)
    if zlabel:
        cbar.set_label(zlabel, rotation=270, labelpad=15)

    # Set x and y label
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Show values in the plot. Positional access (rather than M.loc with
    # labels) stays correct when row/column labels are duplicated.
    if show_values:
        values = M.to_numpy(dtype=float)
        for row, col in np.argwhere(~np.isnan(values)):
            ax.text(col, row, format(values[row, col], '.2f'),
                    ha="center", va="center")

    # Reset x and y ticks to show the DataFrame labels
    ax.set_xticks(np.arange(len(M.columns)))
    ax.set_yticks(np.arange(len(M.index)))
    ax.set_xticklabels(M.columns)
    ax.set_yticklabels(M.index)

    return ax


def _chakrabarti_inputs(H, lambda_w, D):
    """
    Validate the inputs of `plot_chakrabarti` and gather them in a DataFrame.

    Scalars become a single row. Array-like inputs must share the same shape
    once singleton dimensions are squeezed out; they are flattened and
    combined positionally (pandas Series indexes are not used for alignment).

    Returns
    -------
    mvals : pandas DataFrame
        Float columns 'H', 'lambda_w' and 'D', one row per plotted point.
    """
    named = (('H', H), ('lambda_w', lambda_w), ('D', D))
    numeric_types = (np.ndarray, pd.Series, int, float, np.integer,
                     np.floating)
    for name, value in named:
        assert isinstance(value, numeric_types), \
               f'{name} must be a real numeric type'

    shapes = {np.squeeze(np.asarray(value)).shape for _, value in named}
    assert len(shapes) == 1, 'D, H, and lambda_w must be same shape'

    return pd.DataFrame({
        name: np.atleast_1d(np.asarray(value, dtype=float)).ravel()
        for name, value in named
    })


def _chakrabarti_regions(ax):
    """Draw the breaking limit and region boundaries/labels on `ax`."""
    text_style = dict(fontstyle='italic', fontsize='small', clip_on=True)

    # Deep water breaking limit H/lambda_w = 0.14. Since
    # KC * Diffraction = (H/D) * (pi*D/lambda_w) = pi*H/lambda_w, the limit
    # is the curve KC = 0.14*pi / Diffraction. Draw it from x = 10 down to
    # the left axis limit, extended to at least x = 0.01.
    x_left = min(ax.get_xlim()[0], .01)
    x = np.logspace(1, np.log10(x_left), 2)
    ax.plot(x, 0.14 * np.pi / x, 'k-')
    ax.text(1, 7, 'wave\nbreaking\n$H/\\lambda_w > 0.14$',
            ha='center', va='center', **text_style)

    # Horizontal boundaries at KC = 20, 1.5 and 0.25 (upper bounds of the
    # low, small and negligible drag regions). Each runs from the left axis
    # limit (read once, after the breaking line) to the breaking limit.
    x_start = ax.get_xlim()[0]
    for kc_level in (20, 1.5, 0.25):
        ax.plot([x_start, 0.14 * np.pi / kc_level], [kc_level, kc_level],
                'k--')
    ax.text(0.0125, 30, 'drag', ha='center', va='top', **text_style)
    ax.text(0.02, 7, 'inertia \n& drag', ha='center', va='center',
            **text_style)
    ax.text(8e-2, 0.7, 'large\ninertia', ha='center', va='center',
            **text_style)
    ax.text(8e-2, 6e-2, 'all\ninertia', ha='center', va='center',
            **text_style)

    # Left bound of the diffraction region: vertical line at
    # Diffraction = 0.5 from the bottom axis limit up to the breaking limit.
    drv = 0.5
    y_bottom = ax.get_ylim()[0]
    ax.plot([drv, drv], [y_bottom, 0.14 * np.pi / drv], 'k--')
    ax.text(2, 6e-2, 'diffraction', ha='center', va='center', **text_style)


def plot_chakrabarti(H, lambda_w, D, ax=None):
    """
    Plots, in the style of Chakrabarti (2005), relative importance of viscous,
    inertia, and diffraction phenomena

    Chakrabarti, Subrata. Handbook of Offshore Engineering (2-volume set).
    Elsevier, 2005.

    Each (H, lambda_w, D) triple is drawn as one point at
    x = Diffraction parameter (pi*D/lambda_w), y = KC parameter (H/D).

    Parameters
    ------------
    H: float or numpy array or pandas Series
        Wave height [m]
    lambda_w: float or numpy array or pandas Series
        Wave length [m]
    D: float or numpy array or pandas Series
        Characteristic length [m]
    ax : matplotlib axes object (optional)
        Axes for plotting.  If None, then a new figure is created.

    Returns
    ---------
    ax : matplotlib pyplot axes

    Raises
    ------
    AssertionError
        If an input is not a real numeric type, or if the inputs do not
        share the same shape (scalars count as shape ()).

    Examples
    --------
    **Using floats**

    >>> plt.figure()
    >>> D = 5
    >>> H = 8
    >>> lambda_w = 200
    >>> wave.graphics.plot_chakrabarti(H, lambda_w, D)

    **Using numpy array**

    >>> plt.figure()
    >>> D = np.linspace(5,15,5)
    >>> H = 8*np.ones_like(D)
    >>> lambda_w = 200*np.ones_like(D)
    >>> wave.graphics.plot_chakrabarti(H, lambda_w, D)

    **Using pandas DataFrame**

    >>> plt.figure()
    >>> D = np.linspace(5,15,5)
    >>> H = 8*np.ones_like(D)
    >>> lambda_w = 200*np.ones_like(D)
    >>> df = pd.DataFrame([H.flatten(),lambda_w.flatten(),D.flatten()],
    ...                   index=['H','lambda_w','D']).transpose()
    >>> wave.graphics.plot_chakrabarti(df.H, df.lambda_w, df.D)
    """
    mvals = _chakrabarti_inputs(H, lambda_w, D)

    if ax is None:
        plt.figure()
        ax = plt.gca()

    ax.set_xscale('log')
    ax.set_yscale('log')

    KC = mvals['H'] / mvals['D']
    diffraction = np.pi * mvals['D'] / mvals['lambda_w']

    for H_i, lambda_i, D_i, x_i, y_i in zip(mvals['H'], mvals['lambda_w'],
                                            mvals['D'], diffraction, KC):
        label = f'$H$ = {H_i:g}, $\\lambda_w$ = {lambda_i:g}, $D$ = {D_i:g}'
        ax.plot(x_i, y_i, 'o', label=label)

    # Autoscale if ANY point meets these thresholds (not just the last one);
    # otherwise use the fixed default view.
    if (np.any(KC >= 10) or np.any(KC <= .02) or np.any(diffraction >= 50)
            or np.any(mvals['lambda_w'] >= 1000)):
        ax.autoscale(enable=True, axis='both', tight=True)
    else:
        ax.set_xlim((0.01, 10))
        ax.set_ylim((0.01, 50))

    _chakrabarti_regions(ax)

    # A legend is only drawn when more than one point is plotted.
    if len(mvals) > 1:
        ax.legend(fontsize='xx-small', ncol=2)

    ax.set_xlabel('Diffraction parameter, $\\frac{\\pi D}{\\lambda_w}$')
    ax.set_ylabel('KC parameter, $\\frac{H}{D}$')

    ax.figure.tight_layout()

    return ax

def plot_environmental_contour(x1, x2, x1_contour, x2_contour, **kwargs):
    '''
    Plots an overlay of the x1 and x2 variables to the calculated
    environmental contours.

    Parameters
    ----------
    x1: numpy array
        x-axis data
    x2: numpy array
        y-axis data
    x1_contour: numpy array
        Calculated x1 contour values. 1-D for a single contour, or 2-D with
        one contour per column.
    x2_contour: numpy array
        Calculated x2 contour values, with the same number of dimensions
        as x1_contour.
    **kwargs : optional
        x_label: string (optional)
            x-axis label. Default None (label left unchanged).
        y_label: string (optional)
            y-axis label. Default None (label left unchanged).
        data_label: string (optional)
            Legend label for x1, x2 data (e.g. 'Buoy 46022').
            Default None.
        contour_label: string or list of strings (optional)
            Legend label for x1_contour, x2_contour contour data
            (e.g. '100-year contour'). If given, one label per contour.
            Default None.
        ax : matplotlib axes object (optional)
            Axes for plotting.  If None, then a new figure is created.
            Default None.

    Returns
    -------
    ax : matplotlib pyplot axes

    Raises
    ------
    AssertionError
        On wrong argument types, contours of different dimensionality, or a
        contour label count that differs from the number of contours.
    '''

    assert isinstance(x1, np.ndarray), 'x1 must be of type np.ndarray'
    assert isinstance(x2, np.ndarray), 'x2 must be of type np.ndarray'
    assert isinstance(x1_contour, np.ndarray), ('x1_contour must be of '
                                                'type np.ndarray')
    assert isinstance(x2_contour, np.ndarray), ('x2_contour must be of '
                                                'type np.ndarray')
    x_label = kwargs.get("x_label", None)
    y_label = kwargs.get("y_label", None)
    data_label = kwargs.get("data_label", None)
    contour_label = kwargs.get("contour_label", None)
    ax = kwargs.get("ax", None)
    # None is the documented default for both labels, so it must be allowed.
    assert data_label is None or isinstance(data_label, str), \
        'data_label must be of type str'
    assert contour_label is None or isinstance(contour_label, (str, list)), \
        'contour_label must be of type str or list'
    assert x2_contour.ndim == x1_contour.ndim, ('contour must be of '
        f'equal dimension got {x2_contour.ndim} and {x1_contour.ndim}')

    # Treat a 1-D contour as a single-column 2-D contour.
    if x2_contour.ndim == 1:
        x2_contour = x2_contour.reshape(-1, 1)
        x1_contour = x1_contour.reshape(-1, 1)

    N_contours = x2_contour.shape[1]

    if contour_label is None:
        contour_label = [None] * N_contours
    else:
        if isinstance(contour_label, str):
            contour_label = [contour_label]
        N_c_labels = len(contour_label)
        assert N_c_labels == N_contours, ('If specified, the '
            'number of contour labels must be equal to the '
            f'number of contour years. Got {N_c_labels} and {N_contours}')

    for x1_c, x2_c, label in zip(x1_contour.T, x2_contour.T, contour_label):
        ax = _xy_plot(x1_c, x2_c, '-', label=label, ax=ax)

    # With zero contour columns _xy_plot never ran, so no axes exist yet.
    if ax is None:
        plt.figure()
        ax = plt.gca()

    # Draw on ax itself (not pyplot's current axes) and return the axes,
    # not the list of Line2D objects that plot() returns.
    ax.plot(x1, x2, 'bo', alpha=0.1, label=data_label)
    ax.legend(loc='lower right')
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    ax.figure.tight_layout()
    return ax