import numpy as np
import pywt
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
import time


def cwt_for_batch(
    data, scales=np.arange(1, 33), mother_wavelet:str='mexh', use_abs:bool=False,
    downsampling_ratio:float=None, remove_last_row_column:bool=False):
    """Apply the Continuous Wavelet Transform (CWT) to a batch of signals, then optionally resize and clip the result.  
    The last axis of `data` is treated as time. A typical input is an `N x C x L` array, where

    - `N` is the batch size (number of data points).
    - `C` is the number of channels (features).
    - `L` is the number of time steps (sequence length).
    - `H` is the number of elements in the `scales` argument.

    More generally, an input of shape `(..., L)` yields an output of shape `(..., H, L)`:
    the scale axis produced by the CWT is moved so that it sits right before the time axis,
    making the last two axes an `H x L` "image" per signal.

    Post-processing, applied in this order:

    1. If `downsampling_ratio` is truthy, the last two axes are resized by that factor with
       `scipy.ndimage.zoom(..., order=0)`; all leading axes keep a factor of 1.
       A ratio below 1 shrinks the images, so `(H, L)` becomes roughly `(r*H, r*L)`.
    2. If `remove_last_row_column` is True, the last row (last scale) and the last column
       (last time step) of every image are dropped. The original author's stated purpose
       for this is reducing baseline drift and motion artifacts.

    ### Args:

        `data` (NumPy array): Data of shape (..., timesteps), preferably (batchsize, channels, timesteps).
        `scales` (NumPy array, optional): Array of scales. Defaults to `np.arange(1, 33)`. Passed to `pywt.cwt` via `cwt_for_tensor`.
        `mother_wavelet` (str, optional): Mother wavelet. Defaults to `'mexh'` (Mexican Hat). Passed to `pywt.cwt` via `cwt_for_tensor`.
        `use_abs` (bool, optional): If True, the CWT is computed on `np.abs(data)` instead of `data`. Defaults to False.
        `downsampling_ratio` (float, optional): Resize factor for the last two axes (smaller than 1 to downsample).
            Defaults to None; `None` and `0` both skip the resize step.
        `remove_last_row_column` (bool, optional): Whether to drop the last row and column of each image. Defaults to False.

    ### Returns:

        CWT coefficients of shape `(..., H', L')`, or `(N, C, H', L')` in the preferred case, where `H'`/`L'`
        are the (optionally resized) sizes, each reduced by one when clipping is requested.
    """
    signal = data if not use_abs else np.abs(data)                        # ... L
    coefs = cwt_for_tensor(signal, mother_wavelet, scales)                # H ... L
    # Relocate the leading scale axis to the second-to-last position.
    coefs = np.moveaxis(coefs, 0, -2)                                     # ... H L
    if downsampling_ratio:
        # Leave every leading (batch/channel) axis untouched; resize only the image axes.
        n_leading = coefs.ndim - 2
        factors = [1] * n_leading + [downsampling_ratio, downsampling_ratio]
        coefs = zoom(coefs, factors, order=0)                             # ... rH rL
    if not remove_last_row_column:
        return coefs
    return coefs[..., :-1, :-1]                                           # ... rH-1 rL-1
    



def cwt_for_tensor(matrix, mother_wavelet='mexh', scales=np.arange(1, 33)):
    """Compute the CWT coefficients of a data matrix with `pywt.cwt`.

    Thin wrapper that forwards its arguments to `pywt.cwt` and keeps only the
    coefficients (the first element of the returned pair); the second element is discarded.

    L: number of samples (sequence length)
    C: number of data channels (features)
    H: number of scales

    ### Args:
        `matrix`: C x L matrix of data. `cwt_for_batch` in this file also passes arrays with
            additional leading axes, i.e. shape (..., L).
        `mother_wavelet` (str, optional): Mother wavelet. Defaults to 'mexh' for Mexican Hat Wavelet.
        `scales` (Numpy array, optional): Array of scales. Defaults to np.arange(1, 33).

    ### Returns:
        coef: H x C x L matrix of CWT coefficients, where H is the height of the image generated, and L is the width.
    """
    # pywt.cwt's leading positional parameters are (data, scales, wavelet).
    transform = pywt.cwt(matrix, scales, mother_wavelet)
    return transform[0]



def show_wavelet(coef):
    """Display CWT coefficients as image(s) using matplotlib.

    A 2-D input is drawn as a single image. Any other input is treated as
    H x C x L and drawn as C vertically stacked subplots, one per channel.

    ### Args:
        `coef` (Numpy array): H x L or H x C x L matrix of coefficients returned directly by the CWT function.
        H is the height (size of the scales array passed to the CWT function)
        L is the width (sequence length)
        C is the number of channels (features)
    """
    dims = np.shape(coef)

    # Single-channel case: one scale-vs-sample image.
    if len(dims) == 2:
        plt.matshow(coef)
        plt.ylabel('Scale')
        plt.xlabel('Samples')
        plt.show()
        return

    # Multi-channel case: the channel axis is axis 1.
    n_channels = dims[1]
    for position in range(1, n_channels + 1):
        plt.subplot(n_channels, 1, position)
        plt.imshow(coef[:, position - 1, :], aspect='auto')
        plt.title('Channel {}'.format(position))
        plt.ylabel('Scales')
    # Called once after the loop, so only the current (last drawn) subplot gets the x label.
    plt.xlabel('Samples')
    plt.subplots_adjust(hspace=0.5)
    plt.show()



def test_cwt_for_tensor():
    """Smoke-test `cwt_for_tensor()` on a two-channel sinusoid and plot the result.

    Builds a C x L signal (32 and 64 cycles over the window), runs the CWT with
    H scales, checks that the coefficients have shape (H, C, L), and then shows
    them with `show_wavelet()`.
    """
    print("---------------------------------------")
    print("Testing cwt_for_tensor()...")
    L = 512
    C = 2
    H = 32
    t = np.arange(L)
    x1 = np.sin(2 * np.pi * 32 * t / L)
    x2 = np.sin(2 * np.pi * 64 * t / L)
    x = np.array([x1, x2])
    print("L: ", L)
    print("C: ", C)
    print("H: ", H)
    print("Shape of t: ", t.shape)
    print("Shape of x: ", x.shape)
    # Pass the scales explicitly so that H really is the number of scales used
    # (np.arange(1, H + 1) equals the function's default when H == 32).
    coeff = cwt_for_tensor(x, scales=np.arange(1, H + 1))
    print("Shape of coeff: ", coeff.shape)
    # Verify the result instead of only printing it.
    assert coeff.shape == (H, C, L), "unexpected coefficient shape {}".format(coeff.shape)
    show_wavelet(coeff)



def test_cwt_for_batch():
    """Smoke-test `cwt_for_batch()` on random data, with and without post-processing.

    Checks that an N x C x L batch becomes N x C x H x L, and that downsampling by
    `ratio` followed by dropping the last row/column yields the expected smaller images.
    """
    print("---------------------------------------")
    print("Testing cwt_for_batch()...")
    N = 10
    L = 512
    C = 2
    H = 32
    ratio = 0.25
    # Pass the scales explicitly so that H really is the number of scales used
    # (np.arange(1, H + 1) equals the function's default when H == 32).
    scales = np.arange(1, H + 1)
    dataset = np.random.rand(N, C, L)
    print("Shape of dataset: ", dataset.shape)
    new_dataset = cwt_for_batch(dataset, scales=scales)
    print("Shape of new_dataset : ", new_dataset.shape)
    assert new_dataset.shape == (N, C, H, L), "unexpected shape {}".format(new_dataset.shape)
    print("Trying with preprocessing ...")
    new_dataset = cwt_for_batch(
        dataset, scales=scales, downsampling_ratio=ratio, remove_last_row_column=True)
    print("Shape of new_dataset : ", new_dataset.shape)
    # zoom rounds each scaled axis length; clipping then removes one row and one column.
    expected = (N, C, round(H * ratio) - 1, round(L * ratio) - 1)
    assert new_dataset.shape == expected, "unexpected shape {}".format(new_dataset.shape)
    

def test_cwt_for_batch_with_timing(dur_sec=60, fs_Hz=1000, seqlen_sec=1.0, nchan=4, num_scales=32): 
    """Time `cwt_for_batch()` on random float32 data, with and without post-processing.

    The input has shape (N, C, L) with `N = int(dur_sec*fs_Hz)`, `C = int(nchan)` and
    `L = int(seqlen_sec*fs_Hz)`; the scales are `1..num_scales` as float32.
    The elapsed time of each run is printed in seconds.

    NOTE: with the defaults the un-downsampled coefficient array holds
    N*C*num_scales*L = 60000*4*32*1000 = 7.68e9 elements, so shrink the arguments
    for a quick run.

    ### Args:
        `dur_sec` (optional): Multiplied by `fs_Hz` to get the number of sequences N. Defaults to 60.
        `fs_Hz` (optional): Sampling rate used to derive N and L. Defaults to 1000.
        `seqlen_sec` (optional): Multiplied by `fs_Hz` to get the sequence length L. Defaults to 1.0.
        `nchan` (optional): Number of channels C. Defaults to 4.
        `num_scales` (optional): Number of CWT scales. Defaults to 32.
    """
    # Original author's measurement: --> 280 sec in both cases with these defaults
    # It is the cwt itself that is computationally expensive, not downsampling or clipping.
    print("---------------------------------------")
    print("Testing cwt_for_batch() with timing ...")
    # See how long it takes to get cwt of an entire trial of data with 4 channels
    N = int(dur_sec*fs_Hz)
    C = int(nchan)
    L = int(seqlen_sec*fs_Hz)
    scales = np.arange(1,num_scales+1).astype(np.float32)
    data = np.random.rand(N, C, L).astype(np.float32)
    print("Shape of input array: ", data.shape)
    print("Trying without any postprocessing ...")
    # perf_counter() is the clock intended for measuring intervals; time.time()
    # is wall-clock time and can jump if the system clock is adjusted.
    start = time.perf_counter()
    # The result is deliberately not bound to a name, so it can be freed before the next run.
    cwt_for_batch(data, scales=scales)
    print("Elapsed time: ", time.perf_counter() - start)
    print("Trying with postprocessing ...")
    start = time.perf_counter()
    cwt_for_batch(data, scales=scales, downsampling_ratio=0.25, remove_last_row_column=True)
    print("Elapsed time: ", time.perf_counter() - start)

    
    



if __name__ == '__main__':
    # Manual entry point: uncomment one of the calls below to run that demo
    # when this file is executed as a script. Nothing runs by default.
    # test_cwt_for_tensor()
    # test_cwt_for_batch()
    # test_cwt_for_batch_with_timing()
    pass