import math
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
import cvxpy as cp

def window_op(xaxis, region, filter = 'rbf', SD = 1):
    '''
    Build the weighting operator that integrates a signal over one region.

    xaxis : [start, end) of the sampled axis, e.g., [100, 3000]. The signal is taken
            to be sampled at positions start, start+1, ..., up to but excluding end
            (window_fs() passes [0, len(x)], i.e. sample indices).
    region : [lo, hi] sub-range of the axis in the same units, e.g., [100, 200].
             Must satisfy lo <= hi and overlap the sampled positions.
    filter : 'spike' / 'vanilla' : all weight on the sample nearest the region centre.
             'uniform' / 'rectangle' / 'average' : equal weights on the samples inside
                 [lo, hi]. Uniform is just an averaging filter.
             'triangle' : weights peak at the centre and fall linearly to 0 at lo and hi.
             'rbf' / 'gaussian' : Gaussian kernel centred on the region.
             'sinc' and 'logistic' are planned but not implemented yet.
    SD : for the rbf kernel, the region will lie inside +/-SD standard deviations
         around its centre, i.e. sigma = (hi - lo) / (2 * SD). Must be > 0.

    Return : op array of length end - start. Its weights sum to 1, so (op * x).sum()
             is a weighted average of the signal x over the region. A zero-width region
             (lo == hi), or one that falls entirely between two samples, degenerates to
             a spike on the sample nearest the region centre.

    Raises : ValueError for an empty axis, an invalid region, a non-positive SD or an
             unknown filter; NotImplementedError for 'sinc' / 'logistic'.
    '''
    spike_names = ('spike', 'vanilla')
    uniform_names = ('uniform', 'rectangle', 'average')
    rbf_names = ('rbf', 'gaussian')

    # Validate the filter choice before doing any work.
    if filter in ('sinc', 'logistic'):
        raise NotImplementedError("filter '" + str(filter) + "' is not implemented yet")
    if filter not in spike_names + uniform_names + rbf_names + ('triangle',):
        raise ValueError('unknown filter: ' + str(filter))
    if filter in rbf_names and not SD > 0:
        raise ValueError('SD must be positive, got ' + str(SD))

    positions = np.arange(xaxis[0], xaxis[1])
    if positions.size == 0:
        raise ValueError('xaxis must span at least one sample, got ' + str(xaxis))

    lo, hi = region[0], region[1]
    if hi < lo or hi < positions[0] or lo > positions[-1]:
        raise ValueError('region ' + str(region) + ' must satisfy lo <= hi and overlap the axis ' + str(xaxis))

    center = (lo + hi) / 2
    width = hi - lo

    # Fallback operator: all weight on the sample nearest the region centre.
    spike = np.zeros(positions.size)
    spike[np.argmin(np.abs(positions - center))] = 1.0

    # A zero-width region has no extent to spread weight over, whatever the filter.
    if filter in spike_names or width == 0:
        return spike

    if filter in uniform_names:
        op = ((positions >= lo) & (positions <= hi)).astype(float)
    elif filter == 'triangle':
        op = np.clip(1 - np.abs(positions - center) / (width / 2), 0, None)
    else:  # 'rbf' / 'gaussian'
        sigma = width / (2 * SD)
        op = np.exp(-0.5 * ((positions - center) / sigma) ** 2)

    total = op.sum()
    if total == 0:
        # The region lies between two samples, so no sample received weight.
        return spike
    return op / total


def window_fs(X, regions, filter = 'rbf', SD = 1):
    '''
    Convert each data sample into binned features.

    Break down the axis into sections (regions). Each section's feature is a
    weighted integral of the signal intensities in that region: the sum of the
    signal multiplied by the operator returned by window_op() (e.g., a radial
    basis function or a uniform averaging filter).

    X : iterable of 1-D signals (e.g., a 2-D array with one sample per row).
    regions : list of [lo, hi] regions. The axis passed to window_op() is
              [0, len(x)], so regions are presumably in sample-index units.
    filter : filter operator applied to each region; passed straight to
             window_op(). See window_op() for the accepted names.
    SD : passed straight to window_op() (width parameter of the rbf kernel).

    Return : np.array with one row per sample and one column per region.
    '''
    # The operators depend only on the signal length (not on its values),
    # so build them once per distinct length instead of once per sample.
    ops_by_length = {}

    Fss = []
    for x in X:
        x = np.asarray(x)
        n = len(x)
        if n not in ops_by_length:
            ops_by_length[n] = [window_op([0, n], region, filter, SD) for region in regions]

        Fs = [(op * x).sum() for op in ops_by_length[n]]  # discrete features for one sample
        Fss.append(Fs)

    return np.array(Fss)

def group_lasso(X_scaled, y, WIDTH, offset = 0, LAMBDA = 1, ALPHA = 0.5):
    """
    Group Lasso Feature Selection (sparse group lasso solved with cvxpy).

    Minimizes
        ||[1, X] theta - y||^2 / 2
        + ALPHA * LAMBDA * ||w||_1
        + (1 - ALPHA) * LAMBDA * sum_g ||w_g||_2
    where theta = [bias, w]. The groups g partition the n features into:
      - a leading group [0, offset), skipped when offset == 0;
      - consecutive WIDTH-sized sliding-window groups starting at offset
        (the last one may be shorter).

    Parameters
    ----------
    X_scaled : X, 2-D (m samples, n features), should be rescaled;
    y : target var, one value per row of X_scaled;
    WIDTH : sliding window's width (group size);
    offset : start index of the first full window; must be < WIDTH;
    LAMBDA : regularization coefficient;
    ALPHA : ratio of L1 vs Group;

    Returns
    -------
    THETA : array of n feature coefficients (the bias/intercept is dropped).

    Raises
    ------
    RuntimeError if the solver produced no solution.
    """

    assert(offset < WIDTH)

    m, n = X_scaled.shape
    # Prepend a column of ones so theta[0] acts as the bias. The bias is left
    # out of both penalties below.
    X_scaled_e = np.hstack((np.ones((m, 1)), X_scaled))

    theta = cp.Variable(n + 1)
    w = theta[1:]  # feature coefficients only

    # One L2 norm per group. Empty groups contribute nothing and are skipped,
    # and the loop already reaches the last feature, so no tail term is needed.
    # The features are already scaled, so no group-specific weights are used.
    group_loss = 0
    if min(offset, n) > 0:
        group_loss = group_loss + cp.norm(w[:offset])
    for start in range(offset, n, WIDTH):
        group_loss = group_loss + cp.norm(w[start:start + WIDTH])

    objective = cp.Minimize(cp.sum_squares(X_scaled_e @ theta - y) / 2
                            + ALPHA * LAMBDA * cp.norm(w, 1)
                            + (1 - ALPHA) * LAMBDA * group_loss
                           )
    prob = cp.Problem(objective, [])
    prob.solve()

    if theta.value is None:
        raise RuntimeError('group lasso solver returned no solution (status: ' + str(prob.status) + ')')

    return theta.value[1:]  # skip the bias/intercept

def group_lasso_cv(X_scaled, y, MAXF, WIDTHS, LAMBDAS, ALPHAS, cv_size = 0.2, verbose = False, stratify = True):
    '''
    Optimize hyper-parameters by grid search.

    For every (window width, alpha, lambda) combination:
      1. Randomly split the data into a training part and a held-out part of
         relative size cv_size.
      2. Fit group_lasso() on the training part and keep the MAXF features with
         the largest |coefficient|.
      3. Fit a LinearRegression on those features using the training part, and
         score it (R^2) on the held-out part.
    Only one offset is tried per width: int(w / 2), i.e. a half-window shift.

    Parameters
    ----------
    MAXF : max features to be selected. We compare each iteration's score with the same number of features.
    WIDTHS : a list of window width / group size.
    LAMBDAS : a list of lambdas (regularization).
    ALPHAS : a list of alphas.
    cv_size : cross validation set size. Default 20%.
    verbose : print the hyper-parameters, selected features and score of each run.
    stratify : if True (default), stratify the split on y. scikit-learn requires
               every y value to occur at least twice for this, so pass False for
               continuous targets.

    Returns
    -------
    HPARAMS : list of hyper-parameter description strings.
    FSIS : list of selected feature-index lists (largest |coefficient| first).
    THETAS : list of coefficient arrays returned by group_lasso().
    SCORES : list of held-out R^2 scores (0 when no feature was selected).
    All four lists are aligned by index.
    '''

    SCORES = []
    HPARAMS = []  # hyper-parameter values
    FSIS = []
    THETAS = []

    pbar = tqdm(total=len(WIDTHS) * len(ALPHAS) * len(LAMBDAS))

    for w in WIDTHS:
        offset = int(w / 2)
        for alpha in ALPHAS:
            for lam in LAMBDAS:

                train_X, test_X, train_y, test_y = train_test_split(
                    X_scaled, y, test_size=cv_size, stratify=y if stratify else None)

                hparam = 'Window Size: ' + str(w) + ', offset = ' + str(offset) + ', alpha = ' + str(alpha) + ', lambda = ' + str(lam)
                HPARAMS.append(hparam)

                if verbose:
                    print('=== ' + hparam + ' ===')

                THETA = group_lasso(train_X, train_y,
                                    w, offset,
                                    LAMBDA = lam, ALPHA = alpha)
                THETAS.append(THETA)

                # Indices of the MAXF largest |coefficients|, largest first.
                # Truncating after reversing keeps MAXF == 0 empty, whereas
                # [-MAXF:] would select every feature.
                biggest_gl_fs = np.argsort(np.abs(THETA))[::-1][:MAXF]
                FSIS.append(list(biggest_gl_fs))

                if verbose:
                    print('Selected Feature Indices: ', biggest_gl_fs)

                if len(biggest_gl_fs) == 0:
                    # No selected features
                    SCORES.append(0)
                else:
                    # Fit on the training split and score on the held-out split,
                    # so the score measures generalization rather than in-sample fit.
                    reg = LinearRegression().fit(train_X[:, biggest_gl_fs], train_y)
                    SCORES.append(reg.score(test_X[:, biggest_gl_fs], test_y))

                if verbose:
                    print('R2 = ', SCORES[-1])

                pbar.update(1)

    pbar.close()

    assert (len(set([len(HPARAMS), len(FSIS), len(THETAS), len(SCORES)])) == 1)
    return HPARAMS, FSIS, THETAS, SCORES

def select_features_from_group_lasso_cv(HPARAMS, FSIS, THETAS, SCORES, MAXF = 50, THRESH = 1.0):
    '''
    This is a further processing that selects MAXF most common important features.

    Runs (hyper-parameter combinations) whose score is >= THRESH are kept. Their
    selected feature indices are pooled and counted, and the MAXF most frequent
    indices are returned. The coefficient curves of the kept runs are plotted,
    each shifted up by 0.1 so the curves do not overlap.

    Parameters
    ----------
    HPARAMS, FSIS, THETAS, SCORES : returned by group_lasso_cv()
    MAXF : number of most common features to report and return.
    THRESH : minimum score (an entry of SCORES) a run needs in order to be kept.

    Returns
    -------
    np.array of up to MAXF feature indices, most frequent first.
    The array is empty if no run reaches THRESH.
    '''

    CAT_FS = []  # pooled feature indices of every kept run

    # matplotlib rejects a zero figure height, which MAXF == 0 would produce.
    plt.figure(figsize = (16, max(1, math.ceil(MAXF / 2))))

    n_curves = 0
    for idx, score in enumerate(SCORES):
        # only keep runs whose score >= THRESH
        if score >= THRESH:
            CAT_FS += FSIS[idx]
            plt.plot(THETAS[idx] + n_curves * 0.1, label = str(HPARAMS[idx]))
            n_curves += 1

    most_common = Counter(CAT_FS).most_common(MAXF)
    print('top-' + str(MAXF) + ' common features and their frequencies: ', most_common)

    plt.yticks([])
    # A legend is only drawn when there is something to label and it stays readable.
    if 0 < n_curves <= 10:
        plt.legend()
    plt.show()

    return np.array([feature for feature, _ in most_common])