"""
Numba version of bias-corrected and accelerated (BCa) bootstrap.
"""
import numpy as np
from numba import njit, boolean
import warnings
try:
    from scipy.special import ndtri, ndtr
except:
    warnings.warn('scipy not available. Numba BCa bootstrap will not work.', RuntimeWarning)
    
@njit
def resample_paired_nb(X, Y, func, output_len=1, R=int(1e6), seed=0):
    """Numba paired bootstrap: evaluate ``func(x, y)`` on R resamples of (X, Y) pairs.

    Observations are resampled jointly, so index ``j`` always picks the pair
    ``(X[j], Y[j])``.

    X, Y: assumed 1-D arrays of equal length N (N is taken from ``X.size``)
        -- TODO confirm for non-1-D inputs.
    func: numba-compiled callable, called as ``func(x, y)`` on each resample;
        its result is stored in a row of length ``output_len``.
    R: number of bootstrap replicates.
    seed: seeds the random generator used inside this jitted function.

    Returns an array of shape (R, output_len).
    """
    np.random.seed(seed)
    N = X.size
    # Column 0 holds X and column 1 holds Y, so a row index selects a pair.
    data_paired = np.vstack((X, Y)).T
    # Draw all R*N indices once, but gather one replicate at a time.
    # This avoids materialising the full (R, N, 2) resampled array, which
    # dominates memory for large R.
    idx_matrix = np.random.randint(low=0, high=N, size=R*N).reshape(R, N)

    boot_sample = np.empty((R, output_len))
    for i in range(R):
        x, y = data_paired[idx_matrix[i]].T
        boot_sample[i] = func(x, y)
    return boot_sample

@njit
def resample_nb(X, func, output_len=1, R=int(1e6), seed=0):
    """Numba bootstrap: evaluate ``func`` on R resamples of the rows of ``X``.

    X: array of shape (N_samples, n_vars).
    func: numba-compiled callable, called as ``func(sample)`` with an array
        of shape (N_samples, n_vars); its result is stored in a row of
        length ``output_len``.
    R: number of bootstrap replicates.
    seed: seeds the random generator used inside this jitted function.

    Returns an array of shape (R, output_len).
    """
    np.random.seed(seed)
    N = X.shape[0]
    # Draw all R*N row indices once, but gather one replicate at a time.
    # This avoids materialising the full (R, N, n_vars) resampled array.
    idx_matrix = np.random.randint(low=0, high=N, size=R*N).reshape(R, N)

    boot_sample = np.empty((R, output_len))
    for i in range(R):
        boot_sample[i] = func(X[idx_matrix[i]])
    return boot_sample

def resample(X, func, output_len=1, R=int(1e4), seed=0):
    """Pure-NumPy bootstrap: evaluate ``func`` on R resamples of the rows of ``X``.

    Counterpart of ``resample_nb`` for statistics that are not numba-compiled.

    Parameters
    ----------
    X : ndarray of shape (N_samples, n_vars)
        Sample; rows are resampled with replacement.
    func : callable
        Called as ``func(sample)`` with an array shaped like ``X``; its result
        is stored in a row of length ``output_len``.
    output_len : int
        Number of values returned by ``func`` per call.
    R : int
        Number of bootstrap replicates.
    seed : int
        Passed to ``np.random.seed``; note this reseeds NumPy's global RNG.

    Returns
    -------
    ndarray of shape (R, output_len)
    """
    np.random.seed(seed)
    N = X.shape[0]
    # Draw all R*N row indices once, but gather one replicate at a time.
    # This avoids materialising the full (R, N, n_vars) resampled array.
    idx_matrix = np.random.randint(low=0, high=N, size=R*N).reshape(R, N)

    boot_sample = np.empty((R, output_len))
    for i, idx in enumerate(idx_matrix):
        boot_sample[i] = func(X[idx])
    return boot_sample
    
@njit
def jackknife_resampling(data):
    """Performs jackknife resampling on numpy arrays.

    Jackknife resampling is a technique to generate 'n' deterministic samples
    of size 'n-1' from a measured sample of size 'n'. The i-th
    sample  is generated by removing the i-th measurement
    of the original sample.

    Parameters
    ----------
    data : ndarray
        Sample with observations along axis 0 (1-D or N-D).

    Returns
    -------
    ndarray of shape (n, n - 1) + data.shape[1:]
        ``resamples[i]`` is ``data`` with its i-th row removed. The output
        is allocated with ``np.empty``'s default dtype (float64), so values
        are cast to float. Memory grows as O(n**2) times the row size.
    """
    n = data.shape[0]
    if data.ndim > 1:
        resamples = np.empty((n, n - 1) + data.shape[1:])
        # Boolean row mask: all True except the row being left out.
        base_mask = np.ones((n), dtype=boolean)
        for i in range(n): # np.delete does not support 'axis' argument in numba.
            mask_i = base_mask.copy()
            mask_i[i] = False
            resamples[i] = data[mask_i]
    else:
        resamples = np.empty((n, n - 1))
        for i in range(n):
            # For 1-D input np.delete(data, i) drops element i.
            resamples[i] = np.delete(data, i)
    return resamples

@njit
def jackknife_stat_nb(data, statistic):
    """Numba jackknife: evaluate ``statistic`` on every leave-one-out sample.

    Parameters
    ----------
    data : ndarray
        Sample with observations along axis 0.
    statistic : callable
        Must be callable from numba nopython code (e.g. an @njit function),
        because it is invoked inside this jitted function.

    Returns
    -------
    ndarray
        Entry i is ``statistic`` applied to ``data`` with observation i removed.
        NOTE(review): converting the list with ``np.array`` presumably requires
        ``statistic`` to return a scalar in nopython mode -- confirm before
        using array-valued statistics.
    """
    resamples = jackknife_resampling(data)
    stats = np.array([statistic(r) for r in resamples])
    return stats

def jackknife_stat_(data, statistic):
    """Pure-Python jackknife: evaluate ``statistic`` on every leave-one-out sample.

    Non-numba counterpart of ``jackknife_stat_nb``; ``statistic`` may be any
    Python callable.

    Returns an array whose i-th entry is ``statistic`` applied to ``data``
    with observation i removed.
    """
    leave_one_out = jackknife_resampling(data)
    return np.array(list(map(statistic, leave_one_out)))

def _percentile_of_score(a, score, axis, account_equal=False):
    """Vectorized, simplified `scipy.stats.percentileofscore`.

    Unlike `stats.percentileofscore`, the percentile returned is a fraction
    in [0, 1].
    """
    B = a.shape[axis]
    if account_equal:
        return ((a < score).sum(axis=axis) + (a <= score).sum(axis=axis)) / (2 * B)
    else:
        return (a < score).sum(axis=axis) / B
    
def CI_bca(data, statistic, alternative='two-sided', alpha=0.05, R=int(2e5), account_equal=False, use_numba=True, **kwargs):
    """Bias-corrected and accelerated (BCa) bootstrap confidence interval.

    Parameters
    ----------
    data : ndarray
        Sample of shape (N,) or (N, n_vars). 1-D input is passed to the
        resampler as (N, 1). ``statistic`` is called on the original
        ``data`` for the point estimate and the jackknife.
    statistic : callable
        Maps a sample to the statistic. When ``use_numba`` is True it must
        be callable from numba nopython code (e.g. an @njit function).
    alternative : {'two-sided', 'less', 'greater'}
        'two-sided' returns [lower, upper]; 'less' returns [-inf, upper];
        'greater' returns [lower, inf].
    alpha : float
        Significance level; the nominal coverage is ``1 - alpha``.
    R : int
        Number of bootstrap replicates.
    account_equal : bool
        When estimating the bias correction, count replicates equal to the
        point estimate as one half.
    use_numba : bool
        Use the numba-compiled resampling and jackknife routines.
    **kwargs
        Forwarded to the resampling function (e.g. ``seed``, ``output_len``).

    Returns
    -------
    ndarray
        The interval endpoints, taken as percentiles of the bootstrap
        replicates at the BCa-adjusted levels.

    Raises
    ------
    ValueError
        If ``alternative`` is not one of the accepted values.
    """
    if alternative == 'two-sided':
        probs = np.array([alpha/2, 1 - alpha/2])
    elif alternative == 'less':
        probs = np.array([0, 1-alpha])
    elif alternative == 'greater':
        probs = np.array([alpha, 1])
    else:
        raise ValueError(f"alternative '{alternative}' not valid. Available: 'two-sided', 'less', 'greater'.")

    resample_func = resample_nb if use_numba else resample
    theta_hat_b = resample_func(data[:, None] if data.ndim == 1 else data,
                                statistic, R=R, **kwargs).squeeze()
    alpha_bca = _bca_interval(data, statistic, probs, theta_hat_b, account_equal, use_numba)[0]

    if alternative == 'two-sided':
        return np.percentile(theta_hat_b, alpha_bca*100, axis=0)

    # One-sided: the 0 or 1 level is dropped inside _bca_interval, which then
    # returns the single remaining level as a scalar. Normalise before
    # indexing, because indexing a NumPy scalar raises IndexError.
    level = np.atleast_1d(alpha_bca)[0]
    bound = np.percentile(theta_hat_b, level*100, axis=0)
    if alternative == 'less':
        return np.array([-np.inf, bound])
    # alternative == 'greater' (validated above)
    return np.array([bound, np.inf])
    
def _bca_interval(data, statistic, probs, theta_hat_b, account_equal, use_numba):
    """Bias-corrected and accelerated (BCa) percentile levels.

    Parameters
    ----------
    data : ndarray
        Original sample, passed to ``statistic`` and to the jackknife.
    statistic : callable
        The statistic; must be numba-callable when ``use_numba`` is True.
    probs : ndarray
        Nominal percentile levels in [0, 1]. Entries equal to 0 or 1 (the
        open end of a one-sided interval) are dropped before adjustment.
    theta_hat_b : ndarray
        Bootstrap replicates of the statistic, with replicates along axis 0.
    account_equal : bool
        Count replicates equal to the point estimate as one half when
        estimating the bias correction z0.
    use_numba : bool
        Select the numba-compiled jackknife instead of the pure-Python one.

    Returns
    -------
    alpha_bca : ndarray or scalar
        Adjusted percentile levels clipped to [0, 1]. A scalar is returned
        when only one level remains after dropping 0/1.
    a_hat : acceleration estimate (returned for testing).
    """
    # Bias correction z0: normal quantile of the fraction of bootstrap
    # replicates lying below the point estimate. Replicates are on axis 0
    # (this only differs from axis=-1 for vector-valued statistics).
    theta_hat = statistic(data)
    percentile = _percentile_of_score(theta_hat_b, theta_hat, axis=0, account_equal=account_equal)
    z0_hat = ndtri(percentile)

    # Acceleration: jackknife estimate a = sum(U**3) / (6 * sum(U**2)**1.5).
    # The n-power normalisations below cancel in the ratio.
    jackknife_computer = jackknife_stat_nb if use_numba else jackknife_stat_
    theta_hat_jk = jackknife_computer(data, statistic)  # leave-one-out statistics
    n = theta_hat_jk.shape[0]
    theta_hat_jk_dot = theta_hat_jk.mean(axis=0)

    U = (n - 1) * (theta_hat_jk_dot - theta_hat_jk)
    num = (U**3).sum(axis=0) / n**3
    den = (U**2).sum(axis=0) / n**2
    a_hat = 1/6 * num / (den**(3/2))

    def compute_alpha(p):
        # BCa level: Phi(z0 + (z0 + z_p) / (1 - a * (z0 + z_p))).
        z_alpha = ndtri(p)
        num = z0_hat + z_alpha
        return ndtr(z0_hat + num/(1 - a_hat*num))

    alpha_bca = np.asarray(compute_alpha(probs[(probs != 0) & (probs != 1)]))
    # Reduce both comparisons with .any(). A bare multi-element boolean
    # array in an `or` raises "truth value ... is ambiguous".
    if (alpha_bca > 1).any() or (alpha_bca < 0).any():
        warnings.warn('percentiles must be in [0, 1]. bca percentiles: {}\nForcing percentiles in [0,1]...'.format(alpha_bca), RuntimeWarning)
        alpha_bca = np.clip(alpha_bca, 0, 1)
    if alpha_bca.size == 1: # single tail interval
        alpha_bca = alpha_bca[0]
    return alpha_bca, a_hat  # return a_hat for testing