import numpy as np
from scipy.interpolate import interp1d

def gmpe(index, M, rrup, SOF, Vs30):
    """Evaluate the Idriss (2014) median model at one tabulated period.

    Parameters
    ----------
    index : int or array_like of int
        Column into the coefficient tables below. Column ``k`` corresponds
        to period ``[0.001, 0.01, 0.02, 0.03, 0.05, 0.075, 0.1, 0.15, 0.2,
        0.25, 0.3, 0.4, 0.5, 0.75, 1, 1.5, 2, 3, 4, 5, 7.5, 10][k]`` seconds.
        A one-element integer array is accepted as well.
    M : array_like
        Magnitude(s). Values above 8.5 give a NaN ``lny``; NaN magnitudes
        give NaN for both outputs.
    rrup : array_like
        Rupture distance(s) (presumably km -- confirm with callers).
        Transposed before use, so an (n, 1) column becomes a (1, n) row.
    SOF : str
        Style of faulting. ``'reverse'`` and ``'reverse-oblique'`` switch on
        the ``phi * F`` term; ``'strike-slip'``, ``'normal'``,
        ``'normal-oblique'`` and ``'unspecified'`` do not.
    Vs30 : float or array_like
        Vs30 in m/s; values above 1200 are capped at 1200.

    Returns
    -------
    lny : ndarray
        Natural log of the median spectral acceleration, squeezed.
    sigma : ndarray
        Total standard deviation in ln units: a constant 0.70 with the
        shape of ``M`` (NaN where ``M`` is NaN).

    Raises
    ------
    KeyError
        If ``SOF`` is not one of the labels listed above.
    """
    # Vs30 is capped at 1200 m/s. np.minimum, unlike ``if Vs30 > 1200``,
    # also works element-wise when Vs30 is an array.
    Vs30 = np.minimum(Vs30, 1200)
    M = np.asarray(M, dtype=float)
    # The transpose turns an (n, 1) column into a (1, n) row; no-op for 1-D.
    R = np.asarray(rrup, dtype=float).T

    # Constant total sigma. The published Idriss (2014) period- and
    # magnitude-dependent sigma, 1.18 + 0.035 ln(T) - 0.06 M with T clipped
    # to [0.05, 3] s and M clipped to [5, 7.5], is NOT used here.
    # ``0 * M`` gives sigma the shape of M and propagates NaN magnitudes.
    sigma = 0 * M + 0.70

    sof = {'strike-slip':     0,
           'normal':          0,
           'normal-oblique':  0,
           'reverse':         1,
           'reverse-oblique': 1,
           'unspecified':     0}

    F = sof[SOF]

    # Coefficients for M <= 6.75. Rows: alpha1, alpha2, beta1, beta2;
    # columns: the 22 periods listed in the docstring.
    coef_lo = np.array([[ 7.0887,  7.0887,  7.1157,  7.2087,  6.2638,  5.9051,  7.5791,  8.0190,  9.2812,  9.5804,  9.8912,  9.5342,  9.2142,  8.3517,  7.0453,  5.1307,  3.3610,  0.1784, -2.4301, -4.3570, -7.8275, -9.2857],
                        [ 0.2058,  0.2058,  0.2058,  0.2058,  0.0625,  0.1128,  0.0848,  0.1713,  0.1041,  0.0875,  0.0003,  0.0027,  0.0399,  0.0689,  0.1600,  0.2429,  0.3966,  0.7560,  0.9283,  1.1209,  1.4016,  1.5574],
                        [ 2.9935,  2.9935,  2.9935,  2.9935,  2.8664,  2.9406,  3.0190,  2.7871,  2.8611,  2.8289,  2.8423,  2.8300,  2.8560,  2.7544,  2.7339,  2.6800,  2.6837,  2.6907,  2.5782,  2.5468,  2.4478,  2.3922],
                        [-0.2287, -0.2287, -0.2287, -0.2287, -0.2418, -0.2513, -0.2516, -0.2236, -0.2229, -0.2200, -0.2284, -0.2318, -0.2337, -0.2392, -0.2398, -0.2417, -0.2450, -0.2389, -0.2514, -0.2541, -0.2593, -0.2586]])

    # Coefficients for 6.75 < M <= 8.5, same layout as coef_lo.
    coef_hi = np.array([[ 9.0138,  9.0138,  9.0408,  9.1338,  7.9837,  7.7560,  9.4252,  9.6242, 11.1300, 11.3629, 11.7818, 11.6097, 11.4484, 10.9065,  9.8565,  8.3363,  6.8656,  4.1178,  1.8102,  0.0977, -3.0563, -4.4387],
                        [-0.0794, -0.0794, -0.0794, -0.0794, -0.1923, -0.1614, -0.1887, -0.0665, -0.1698, -0.1766, -0.2798, -0.3048, -0.2911, -0.3097, -0.2565, -0.2320, -0.1226,  0.1724,  0.3001,  0.4609,  0.6948,  0.8393],
                        [ 2.9935,  2.9935,  2.9935,  2.9935,  2.7995,  2.8143,  2.8131,  2.4091,  2.4938,  2.3773,  2.3772,  2.3413,  2.3477,  2.2042,  2.1493,  2.0408,  2.0013,  1.9408,  1.7763,  1.7030,  1.5212,  1.4195],
                        [-0.2287, -0.2287, -0.2287, -0.2287, -0.2319, -0.2326, -0.2211, -0.1676, -0.1685, -0.1531, -0.1595, -0.1594, -0.1584, -0.1577, -0.1532, -0.1470, -0.1439, -0.1278, -0.1326, -0.1291, -0.1220, -0.1145]])

    # Magnitude-independent coefficients, one value per period.
    a3_all    = np.array([ 0.0589,  0.0589,  0.0589,  0.0589,  0.0417,  0.0527,  0.0442,  0.0329,  0.0188,  0.0095, -0.0039, -0.0133, -0.0224, -0.0267, -0.0198, -0.0367, -0.0291, -0.0214, -0.0240, -0.0202, -0.0219, -0.0035])
    zeta_all  = np.array([-0.8540, -0.8540, -0.8540, -0.8540, -0.6310, -0.5910, -0.7570, -0.9110, -0.9980, -1.0420, -1.0300, -1.0190, -1.0230, -1.0560, -1.0090, -0.8980, -0.8510, -0.7610, -0.675 , -0.6290, -0.5310, -0.586])
    gamma_all = np.array([-0.0027, -0.0027, -0.0027, -0.0027, -0.0061, -0.0056, -0.0042, -0.0046, -0.0030, -0.0028, -0.0029, -0.0028, -0.0021, -0.0029, -0.0032, -0.0033, -0.0032, -0.0031, -0.0051, -0.0059, -0.0057, -0.0061])
    phi_all   = np.array([ 0.0800,  0.0800,  0.0800,  0.0800,  0.0800,  0.0800,  0.0800,  0.0800,  0.0800,  0.0800,  0.0800,  0.0800,  0.0800,  0.0800,  0.0600,  0.0400,  0.0200,  0.0200,  0     ,  0     ,  0     ,  0])

    lo = coef_lo[:, index]
    hi = coef_hi[:, index]
    # Pick the magnitude-dependent set per element. NaN magnitudes compare
    # False and take the ``hi`` set, but they yield NaN anyway.
    small = M <= 6.75
    a1 = np.where(small, lo[0], hi[0])
    a2 = np.where(small, lo[1], hi[1])
    b1 = np.where(small, lo[2], hi[2])
    b2 = np.where(small, lo[3], hi[3])
    a3 = a3_all[index]
    zeta = zeta_all[index]
    gamma = gamma_all[index]
    phi = phi_all[index]

    lny = (a1 + a2 * M + a3 * (8.5 - M)**2
           - (b1 + b2 * M) * np.log(R + 10)
           + zeta * np.log(Vs30) + gamma * R + phi * F)

    # No coefficient set covers M > 8.5. Report NaN there instead of a
    # finite value built from zeroed coefficients.
    lny = np.where(M <= 8.5, lny, np.nan)

    return np.squeeze(lny), sigma

def I2014_REV(To, M, Rrup, Rhyp, Ztor, Zhyp, Vs30):
    """Idriss (2014) prediction for reverse-faulting earthquakes.

    At a tabulated period the model is evaluated directly. Between two
    tabulated periods, ln(SA) and sigma are interpolated linearly in
    ln(period).

    Parameters
    ----------
    To : float
        Spectral period in seconds. 0 is treated as 0.001 s, the first
        tabulated period. Must lie within [0.001, 10] s.
    M : array_like
        Magnitude(s). A float copy is made, so the caller's array is never
        modified. Values above 8.5 are set to NaN, which makes the
        corresponding outputs NaN.
    Rrup : array_like
        Rupture distance(s), passed straight to ``gmpe``.
    Rhyp, Ztor, Zhyp :
        Unused by this model. Presumably accepted only so the signature
        matches other GMPE functions -- confirm with callers.
    Vs30 : float
        Passed straight to ``gmpe``.

    Returns
    -------
    lny : ndarray
        Natural log of the median spectral acceleration.
    sigma : ndarray
        Total standard deviation in ln units.
    tau, phi : ndarray
        Arrays of NaN with ``np.size(M)`` entries. This implementation does
        not provide these components (presumably the between- and
        within-event standard deviations).

    Raises
    ------
    ValueError
        If ``To`` is outside the tabulated period range or is NaN.
    """
    SOF = 'reverse'

    if To == 0:
        To = 0.001

    period = np.array([0.001, 0.01, 0.02, 0.03, 0.05, 0.075, 0.1, 0.15, 0.2,
                       0.25, 0.3, 0.4, 0.5, 0.75, 1, 1.5, 2, 3, 4, 5, 7.5, 10])

    # Without this check, np.max/np.min below would fail on an empty
    # selection. The negated form also rejects NaN.
    if not (period[0] <= To <= period[-1]):
        raise ValueError(
            f"To = {To} s is outside the tabulated period range "
            f"[{period[0]}, {period[-1]}] s"
        )

    T_lo = np.max(period[period <= To])
    T_hi = np.min(period[period >= To])
    # Column of T_lo in the coefficient tables, as a one-element array.
    index = np.where(np.abs(period - T_lo) < 1e-6)[0]

    # Float copy: storing NaN into an integer array would raise, and the
    # caller's array must not be modified. Use np.nan (np.NaN was removed
    # in NumPy 2.0).
    M = np.array(M, dtype=float)
    M[M > 8.5] = np.nan

    if T_lo == T_hi:
        lny, sigma = gmpe(index, M, Rrup, SOF, Vs30)
    else:
        lny_lo, sigma_lo = gmpe(index,     M, Rrup, SOF, Vs30)
        lny_hi, sigma_hi = gmpe(index + 1, M, Rrup, SOF, Vs30)

        # Linear interpolation in ln(T), vectorised over all elements. This
        # is the same arithmetic as scipy's interp1d(kind='linear'):
        # slope * (x_new - x_lo) + y_lo.
        x_lo, x_hi = np.log(T_lo), np.log(T_hi)
        dx = np.log(To) - x_lo
        Y_sa = np.vstack([lny_lo, lny_hi])
        Y_sigma = np.vstack([sigma_lo, sigma_hi])
        lny = (Y_sa[1] - Y_sa[0]) / (x_hi - x_lo) * dx + Y_sa[0]
        sigma = (Y_sigma[1] - Y_sigma[0]) / (x_hi - x_lo) * dx + Y_sigma[0]

    tau = np.full(np.size(M), np.nan)
    phi = np.full(np.size(M), np.nan)

    return lny, sigma, tau, phi