import numpy as np
from sklearn.linear_model import Lasso
from scipy.sparse import csc_matrix, eye, diags
from scipy.sparse.linalg import spsolve
"""
get_zdFF.py calculates a standardized dF/F signal based on the calcium-independent
and calcium-dependent signals commonly recorded using fiber photometry calcium imaging

October 2019 Ekaterina Martianova ekaterina.martianova.1@ulaval.ca

Reference:
  (1) Martianova, E., Aronson, S., Proulx, C.D. Multi-Fiber Photometry
      to Record Neural Activity in Freely Moving Animal. J. Vis. Exp.
      (152), e60278, doi:10.3791/60278 (2019)
      https://www.jove.com/video/60278/multi-fiber-photometry-to-record-neural-activity-freely-moving

airPLS.py Copyright 2014 Renato Lombardo - renato.lombardo@unipa.it
Baseline correction using adaptive iteratively reweighted penalized least squares

This program is a translation in python of the R source code of airPLS version 2.0
by Yizeng Liang and Zhang Zhimin - https://code.google.com/p/airpls

Reference:
Z.-M. Zhang, S. Chen, and Y.-Z. Liang, Baseline correction using adaptive iteratively
reweighted penalized least squares. Analyst 135 (5), 1138-1146 (2010).

LICENCE
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
"""


def add_zdFF(df, method='airPLS', **kwargs):
    """
    High-level entry point: compute z-scored dF/F ('zdFF') and attach it to the DataFrame.

    The 'FrameCounter' index level is moved into a regular column. Rows that share the
    remaining index value are then processed independently as one data stream each.
    :param df: DataFrame with a 'FrameCounter' index level and 'Signal' / 'Reference' columns
               (presumably one group per recording/fiber on the other index levels - confirm with caller)
    :param method: name of the zdFF algorithm; only 'airPLS' is currently supported
    :param kwargs: extra keyword arguments forwarded to zdFF_airPLS
                   (smooth_win, remove, lambd, porder, itermax)
    :return: DataFrame with an added 'zdFF (airPLS)' column; frames trimmed by zdFF_airPLS are dropped
    :raises NotImplementedError: if method is anything other than 'airPLS'
    """
    if method != 'airPLS':
        raise NotImplementedError(f'Method {method} is not yet implemented.')

    flat = df.reset_index('FrameCounter', drop=False)
    # Each remaining index value is treated as a separate recording.
    return flat.groupby(flat.index, group_keys=False).apply(
        lambda group: zdFF_airPLS(group, **kwargs)
    )


def zdFF_airPLS(
    df, smooth_win=10, remove=200, lambd=5e4, porder=1, itermax=50
):
    """
    Low-level implementation of the airPLS pipeline.

    Handles the rows of df as a single data stream and calculates a z-scored dF/F signal
    from the fiber photometry calcium-independent ('Reference') and calcium-dependent
    ('Signal') channels.
        :param df: DataFrame with raw data, containing columns 'FrameCounter', 'Signal' and 'Reference'
        :param smooth_win: window length for the moving-average smoothing, integer
        :param remove: number of frames (by FrameCounter value) dropped at the beginning and at the
                end of the trace, e.g. to remove an initial steep slope; integer
        :param lambd: smoothing parameter for airPLS. The larger lambda is,
                the smoother the resulting background
        :param porder: order of the difference penalty used by airPLS for baseline fitting
        :param itermax: maximum number of airPLS iterations
        :return: copy of the trimmed df with 'zdFF (airPLS)' as an additional column
    """
    # Remove beginning and end of the recording. Take an explicit copy so the column added
    # at the end goes into an independent frame, not a slice of the caller's data.
    frames = df['FrameCounter']
    df = df[(frames > remove) & (frames < frames.max() - remove)].copy()

    # Smooth signals. Work on float numpy arrays: smooth_signal may hand back its input
    # unchanged for small windows, and the steps below need ndarray methods.
    reference = smooth_signal(df['Reference'].to_numpy(dtype=float), smooth_win)
    signal = smooth_signal(df['Signal'].to_numpy(dtype=float), smooth_win)

    # Remove the slow baseline trend using airPLS. Not in-place, so the arrays backing df are never modified.
    reference = reference - airPLS(reference, lambda_=lambd, porder=porder, itermax=itermax)
    signal = signal - airPLS(signal, lambda_=lambd, porder=porder, itermax=itermax)

    # Standardize signals: centre on the median, scale by the standard deviation
    # TODO: why use median and not mean?
    reference = (reference - np.median(reference)) / np.std(reference)
    signal = (signal - np.median(signal)) / np.std(signal)

    # Align reference signal to calcium signal using a Lasso regression whose slope
    # is constrained to be non-negative (positive=True)
    lin = Lasso(
        alpha=0.0001,
        precompute=True,
        max_iter=1000,
        positive=True,
        random_state=9999,
        selection="random",
    )
    lin.fit(reference.reshape(-1, 1), signal)
    fitted_reference = lin.predict(reference.reshape(-1, 1))

    df['zdFF (airPLS)'] = signal - fitted_reference
    return df


def smooth_signal(x, window_len=10, window="flat"):
    """
    Smooth the data using a window with the requested size.

    This method is based on the convolution of a scaled window with the signal.
    The signal is prepared by introducing reflected copies of the signal
    (with the window size) at both ends, so that transient parts are minimized
    at the beginning and end of the output signal.
    Based on: https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html
    :param x: 1-D array-like, the input signal
    :param window_len: length of the smoothing window. With an odd length the window is centred
                exactly on each sample. With an even length it is centred half a sample
                before it (between samples i-1 and i).
    :param window: the type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'.
                A 'flat' window produces a moving average.
    :return: np.ndarray, the smoothed signal, same length as x. If window_len < 3,
             x is returned unchanged (as an array).
    :raises ValueError: if x is not 1-D, is shorter than window_len, or window is unknown
    """
    x = np.asarray(x)

    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")

    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")

    if window_len < 3:
        return x

    if window not in ("flat", "hanning", "hamming", "bartlett", "blackman"):
        raise ValueError(
            "Window is one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'"
        )

    # Pad each end with window_len - 1 reflected samples (edge sample excluded)
    s = np.r_[x[window_len - 1 : 0 : -1], x, x[-2 : -window_len - 1 : -1]]

    if window == "flat":  # Moving average
        w = np.ones(window_len, "d")
    else:
        w = getattr(np, window)(window_len)  # np.hanning, np.hamming, np.bartlett, np.blackman

    y = np.convolve(w / w.sum(), s, mode="valid")

    # The valid convolution has len(x) + window_len - 1 samples. Output sample k + start
    # is the window centred on x[k] (odd windows) or between x[k-1] and x[k] (even windows).
    start = (window_len - 1) // 2
    return y[start : start + x.size]


def whittaker_smooth(x, w, lambda_, differences=1):
    """
    Penalized least squares (Whittaker) smoother used for background fitting.

    Solves (W + lambda_ * D^T D) z = W x for z, where W = diag(w) and D is the
    finite-difference matrix of order ``differences``.
        :param x: 1-D input data (e.g. a chromatogram or spectrum)
        :param w: per-point weights, same length as x. airPLS sets the weight to zero for points
                  it treats as peaks.
        :param lambda_: parameter that can be adjusted by user. The larger lambda is,
                 the smoother the resulting background
        :param differences: order of the difference penalty (1 = first differences,
                 2 = second differences, ...)
        :return: np array of the whittaker smooth, same length as x
    """
    x = np.asarray(x, dtype=float).ravel()
    w = np.asarray(w, dtype=float).ravel()
    m = x.size
    # Build the difference operator by differencing the identity `differences` times
    # (numpy.diff() does not work with sparse matrices).
    D = eye(m, format="csc")
    for _ in range(differences):
        D = D[1:] - D[:-1]
    W = diags(w, 0, shape=(m, m))
    A = csc_matrix(W + lambda_ * D.T @ D)
    # W @ x for a diagonal W is just the element-wise product
    background = spsolve(A, w * x)
    return np.asarray(background)


def airPLS(x, lambda_=100, porder=1, itermax=15):
    """
    Adaptive iteratively reweighted penalized least squares (airPLS) for baseline fitting.

    Repeatedly fits whittaker_smooth to x. After each pass, points at or above the
    current fit (residual >= 0, treated as peaks) get weight 0. Points below the fit get
    weights exp(i * |residual| / dssn), so the next fit is pulled towards them.
    Iteration stops once dssn (the magnitude of the summed negative residuals) drops
    below 0.1 % of sum(|x|), or becomes zero.
        :param x: 1-D input data (e.g. a chromatogram or spectrum)
        :param lambda_: parameter that can be adjusted by user. The larger lambda is,
                 the smoother the resulting background
        :param porder: order of the difference penalty passed to whittaker_smooth
        :param itermax: maximal number of iterations, must be >= 1
        :return: the fitted background vector, same length as x
        :raises ValueError: if itermax < 1
    """
    if itermax < 1:
        raise ValueError("itermax must be at least 1.")
    x = np.asarray(x, dtype=float)
    m = x.shape[0]
    w = np.ones(m)
    tolerance = 0.001 * np.abs(x).sum()  # x is not modified inside the loop
    for i in range(1, itermax + 1):
        z = whittaker_smooth(x, w, lambda_, porder)
        d = x - z
        below = d < 0
        dssn = np.abs(d[below].sum())
        # dssn == 0: no point lies below the fit, so there is nothing left to reweight
        # (also avoids a division by zero and max() of an empty array below).
        if dssn == 0 or dssn < tolerance:
            break
        if i == itermax:
            print("WARNING max iteration reached!")
            break
        # d >= 0 means that this point is part of a peak, so its weight is set to 0 in order to ignore it
        w[d >= 0] = 0
        w[below] = np.exp(i * np.abs(d[below]) / dssn)
        # Both end points get the weight of the smallest-magnitude negative residual
        w[0] = np.exp(i * d[below].max() / dssn)
        w[-1] = w[0]
    return z
