import numpy as np
from util import pyfghutil
import multiprocessing as mp
from scipy.fft import ifft

class FGHMatrixObj:
    """Dense square matrix over all points of a D-dimensional product grid.

    Attributes:
        D: number of grid dimensions, ``len(N)``.
        N: per-dimension point counts, as passed to the constructor.
        Npts: total number of grid points, ``np.prod(N)``.
        mat: backing ``(Npts, Npts)`` numpy array of the requested dtype,
            zero-initialised.
    """

    def __init__(self, N, dtype):
        self.D = len(N)
        self.N = N
        self.Npts = np.prod(N)
        self.mat = np.zeros((self.Npts, self.Npts), dtype=dtype)

    def getValueByPt(self, pt1, pt2):
        """Return ``mat[pt1, pt2]``."""
        # Subscript (not call) the ndarray; calling it raised TypeError.
        return self.mat[pt1, pt2]

    def getValueByIdx(self, idx1, idx2):
        """Return the element addressed by two indices converted via pyfghutil.

        Each index is converted with ``pyfghutil.IndexToPoint`` and the
        results are used to subscript ``mat``.
        NOTE(review): the ByPt methods subscript ``mat`` directly while the
        ByIdx methods convert first -- confirm that IndexToPoint's return value
        matches how ``mat`` rows/columns are laid out.
        """
        pt1 = pyfghutil.IndexToPoint(self.D, self.N, idx1)
        pt2 = pyfghutil.IndexToPoint(self.D, self.N, idx2)
        return self.mat[pt1, pt2]

    def setValueByPt(self, pt1, pt2, val):
        """Store ``val`` at ``mat[pt1, pt2]``."""
        self.mat[pt1, pt2] = val

    def setValueByIdx(self, idx1, idx2, val):
        """Store ``val`` at the element addressed via ``pyfghutil.IndexToPoint``.

        See ``getValueByIdx`` for the index conversion.
        """
        pt1 = pyfghutil.IndexToPoint(self.D, self.N, idx1)
        pt2 = pyfghutil.IndexToPoint(self.D, self.N, idx2)
        self.mat[pt1, pt2] = val



# A function to calculate the B matrix for one dimension
# B(j,l) = (4*pi)/(L*N) * sum(p=1,n)(p*sin(2*pi*p*(l-j)/N))
# where n = (N-1)/2
# L = length of the dimension
# N = number of points in the dimension
# Each dimension has its own B matrix

def bmatrixgen(N, L):
    """Return the N x N B matrix for one periodic grid dimension.

    The matrix is built spectrally:
    1. The Fourier multipliers ``2*pi*i*p/L`` for p = -n .. N-1-n, with
       n = (N - 1) // 2, are inverse-FFT'd.
    2. The result is multiplied by ``exp(-2*pi*i*n*k/N)`` to undo the index
       offset of n used when laying out the multipliers.
    3. The resulting kernel is spread circulantly:
       ``B[j, t] = Re(kernel[(j - t) mod N])``.

    Args:
        N: number of points in the dimension.
        L: length of the dimension.

    Returns:
        Contiguous float ndarray of shape (N, N).
    """
    n = (N - 1) // 2
    k = np.arange(N)
    multipliers = 2 * np.pi * (1j) * (k - n) / L
    kernel = ifft(multipliers, n=N)
    kernel = kernel * np.exp(-2 * np.pi * (1j) * n * k / N)
    # offsets[j, t] == (j - t) mod N selects the circulant kernel entry.
    offsets = (N + k[:, None] - k[None, :]) % N
    # astype() copies into a fresh contiguous float64 array (not a view).
    return np.real(kernel[offsets]).astype(float)

# A function to calculate the C matrix for one dimension
# C(j,l) = (-8*pi*pi)/(L*L*N) * sum(p=1,n)(p*p*cos(2*pi*p*(l-j)/N))
# where n = (N-1)/2
# L = length of the dimension
# N = number of points in the dimension
# Each dimension has its own C matrix

def cmatrixgen(N, L):
    """Return the N x N C matrix for one periodic grid dimension.

    The matrix is built spectrally:
    1. The Fourier multipliers ``-(2*pi*p/L)**2`` for p = -n .. N-1-n, with
       n = (N - 1) // 2, are inverse-FFT'd (as complex input).
    2. The result is multiplied by ``exp(-2*pi*i*n*k/N)`` to undo the index
       offset of n.
    3. The resulting kernel is spread circulantly:
       ``C[j, t] = Re(kernel[(j - t) mod N])``.

    Args:
        N: number of points in the dimension.
        L: length of the dimension.

    Returns:
        Contiguous float ndarray of shape (N, N).
    """
    n = (N - 1) // 2
    k = np.arange(N)
    shifted = k - n
    # Keep the ifft input complex, exactly as the element-wise version stored it.
    multipliers = (-4 * np.pi * np.pi * shifted * shifted / (L * L)).astype(complex)
    kernel = ifft(multipliers, n=N)
    kernel = kernel * np.exp(-2 * np.pi * (1j) * n * k / N)
    # offsets[j, t] == (j - t) mod N selects the circulant kernel entry.
    offsets = (N + k[:, None] - k[None, :]) % N
    return np.real(kernel[offsets]).astype(float)

# A function to calculate the individual values for the TMatrix
# Right now, only "approximation 2" is tested.  The other approximations will be for a
# future release.

def Tab(d, NValue, LValue, mu, c_matrix_insert, dimensionCounterArray, approximation, b_matrix_insert, GMat):
    """Compute one element of the T matrix.

    Args:
        d: number of grid dimensions.
        NValue: per-dimension point counts. Approximation 2 reads
            NValue[0..2], so it assumes three dimensions.
        LValue: unused here; kept so the caller's argument list is unchanged.
        mu: unused here; kept so the caller's argument list is unchanged.
        c_matrix_insert: per-dimension C matrices (used by approximations 3
            and 4), indexed ``c_matrix_insert[i][row, col]``.
        dimensionCounterArray: index pairs, one per dimension i, read as
            ``dimensionCounterArray[2*i]`` and ``dimensionCounterArray[2*i + 1]``.
        approximation: 2, 3 or 4; selects the formula used below.
        b_matrix_insert: per-dimension B matrices (used by approximations 2
            and 3). They are stored in reverse dimension order (see the note
            in the approximation-2 branch).
        GMat: G data. Approximation 2 indexes it as ``GMat[p][k][l][r][s]``;
            approximations 3 and 4 index it as ``GMat[r][s]``.

    Returns:
        The matrix element.

    Raises:
        ValueError: if ``approximation`` is not 2, 3 or 4. (This replaces a
            call to ``exit()``, whose SystemExit is not reported back to a
            multiprocessing parent and can leave ``Pool.starmap`` waiting.)
    """

    def delta(x, y):
        # Kronecker delta, returned as a float.
        return 1.0 if x == y else 0.0

    total = 0.0
    if approximation == 4:
        for T in range(d):
            # Count the *other* dimensions whose two indices are equal.
            Deltacounter = 0
            for Ccounter in range(d):
                if Ccounter != T and dimensionCounterArray[Ccounter * 2] == dimensionCounterArray[(Ccounter * 2) + 1]:
                    Deltacounter += 1
            # Dimension T's C term contributes only when all d - 1 others match.
            if Deltacounter == (d - 1):
                try:
                    total += (-1.0 / 2.0) * (GMat[T][T]) * (
                        c_matrix_insert[T][dimensionCounterArray[(T * 2) + 1], dimensionCounterArray[T * 2]])
                except Exception:
                    # Print the failing lookup for debugging, then re-raise so a
                    # dropped term never yields a silently wrong matrix element.
                    print("Trying to access: " + str(dimensionCounterArray[(T * 2) + 1]) + ", " + str(
                        dimensionCounterArray[T * 2]))
                    print(c_matrix_insert[T])
                    print("TAB error")
                    raise
    elif approximation == 3:
        # "C" terms first: dimension C contributes only when all other dimensions match.
        for C in range(d):
            Deltacounter = 0
            for Ccounter in range(d):
                if Ccounter != C and dimensionCounterArray[Ccounter * 2] == dimensionCounterArray[(Ccounter * 2) + 1]:
                    Deltacounter += 1
            if Deltacounter == d - 1:
                total += float((GMat[C][C])) * (
                    c_matrix_insert[C][dimensionCounterArray[(C * 2) + 1], dimensionCounterArray[C * 2]])
        # "B" terms second.
        for B in range(d):
            if dimensionCounterArray[B * 2] == dimensionCounterArray[(B * 2) + 1]:
                # Dimension B's indices match: multiply the B-matrix elements
                # of every other dimension.
                temptotal = 1.0
                for BSecond in range(d):
                    if BSecond != B:
                        temptotal *= b_matrix_insert[BSecond][
                            dimensionCounterArray[(BSecond * 2) + 1], dimensionCounterArray[BSecond * 2]]
                # Off-diagonal G element between the lowest and highest remaining
                # dimensions, doubled (the term and its flipped counterpart).
                GRange = [*range(d)]
                GRange.remove(B)
                Gx = min(GRange)
                Gy = max(GRange)
                total += float((GMat[Gx][Gy])) * temptotal * 2
        # Outside of summation multiplication of -hbar^2/2 (hbar = 1).
        total *= (-1.0 * 1.0 ** 2) / (2.0)
    elif approximation == 2:
        # Three-dimensional formula: index pairs (t, j), (u, k), (v, l).
        t = int(dimensionCounterArray[0])
        j = int(dimensionCounterArray[1])
        u = int(dimensionCounterArray[2])
        k = int(dimensionCounterArray[3])
        v = int(dimensionCounterArray[4])
        l = int(dimensionCounterArray[5])
        # BMatrix calls are backwards
        # So BMatrix[0] would be B1, but it's actually B3
        # Diagonal G terms: B·B product along one dimension, deltas on the others.
        sums = 0.0
        for p in range(NValue[0]):
            sums += b_matrix_insert[2][j, p] * b_matrix_insert[2][p, t] * GMat[p][k][l][0][0]
        total += -0.5 * sums * delta(k, u) * delta(l, v)

        sums = 0.0
        for p in range(NValue[1]):
            sums += b_matrix_insert[1][k, p] * b_matrix_insert[1][p, u] * GMat[j][p][l][1][1]
        total += -0.5 * sums * delta(j, t) * delta(l, v)

        sums = 0.0
        for p in range(NValue[2]):
            sums += b_matrix_insert[0][l, p] * b_matrix_insert[0][p, v] * GMat[j][k][p][2][2]
        total += -0.5 * sums * delta(j, t) * delta(k, u)

        # Mixed G terms: one B element from each of two dimensions, delta on the third.
        total += -0.5 * (b_matrix_insert[2][j, t] * b_matrix_insert[1][k, u] * GMat[t][k][l][0][1]) * delta(v, l)
        total += -0.5 * (b_matrix_insert[2][j, t] * b_matrix_insert[0][l, v] * GMat[t][k][l][0][2]) * delta(k, u)
        total += -0.5 * (b_matrix_insert[1][k, u] * b_matrix_insert[2][j, t] * GMat[j][u][l][1][0]) * delta(v, l)
        total += -0.5 * (b_matrix_insert[1][k, u] * b_matrix_insert[0][l, v] * GMat[j][u][l][1][2]) * delta(j, t)
        total += -0.5 * (b_matrix_insert[0][l, v] * b_matrix_insert[2][j, t] * GMat[j][k][v][2][0]) * delta(k, u)
        total += -0.5 * (b_matrix_insert[0][l, v] * b_matrix_insert[1][k, u] * GMat[j][k][v][2][1]) * delta(j, t)
    else:
        raise ValueError("This current approximation is incorrect or not supported: " + str(approximation))

    return total


# A function that splits the Tmatrix calculation into blocks to be calculated in parallel.
# Each block uses the Tab function above to calculate individual matrix elements.

def TBlockCalc(dimensions, NValue, LValue, mu, c_matrix, approximation, blockX, blockY, b_matrix, gmatrix):
    """Compute one square NValue[0] x NValue[0] block of the T matrix.

    Block coordinates are 0-indexed. Block (blockX, blockY) covers these
    global ranges:
    - rows ``NValue[0]*blockX`` .. ``NValue[0]*(blockX+1) - 1``;
    - columns ``NValue[0]*blockY`` .. ``NValue[0]*(blockY+1) - 1``.

    Each element is produced by ``Tab`` from the counter array that
    ``pyfghutil.AlphaAndBetaToCounter`` returns for the (alpha, beta) pair.
    For approximations above 2, ``Tab`` receives only the entry
    ``gmatrix[c0][c2][c4]`` selected by counter positions 0, 2 and 4.
    Otherwise it receives the whole ``gmatrix``.

    Returns:
        float ndarray of shape (NValue[0], NValue[0]) with local indices.
    """
    size = NValue[0]
    rowBase = size * blockX
    colBase = size * blockY
    block = np.zeros((size, size), float)
    for r, alpha in enumerate(range(rowBase, rowBase + size)):
        for c, beta in enumerate(range(colBase, colBase + size)):
            counter = pyfghutil.AlphaAndBetaToCounter(alpha, beta, dimensions, NValue)
            if approximation > 2:
                i0, i1, i2 = int(counter[0]), int(counter[2]), int(counter[4])
                gArg = gmatrix[i0][i1][i2]
            else:
                gArg = gmatrix
            block[r, c] = Tab(dimensions, NValue, LValue, mu, c_matrix, counter,
                              approximation, b_matrix, gArg)
    return block


# The function to calculate a TMatrix using the dataObject class from input
def TMatrixCalc(dataObject, GMatrix):
    """Assemble the full T matrix for the grid described by ``dataObject``.

    Dimension i (i = 1..3) is included when ``dataObject.Ni`` is positive,
    with ``dataObject.Li`` as its length. The matrix is split into square
    NValue[0] x NValue[0] blocks. TBlockCalc computes them in parallel on a
    multiprocessing pool of ``dataObject.cores_amount`` workers, and they are
    then copied into an FGHMatrixObj.

    Args:
        dataObject: input object providing N1..N3, L1..L3 and cores_amount.
        GMatrix: G data forwarded unchanged to TBlockCalc/Tab.

    Returns:
        FGHMatrixObj holding the T matrix.

    Raises:
        ValueError: if no dimension has a positive point count, or the
            selected T approximation is not one Tab supports.
    """
    # Establish the active dimensions.
    NValue = []
    LValue = []
    for nName, lName in (("N1", "L1"), ("N2", "L2"), ("N3", "L3")):
        count = int(getattr(dataObject, nName))
        if count > 0:
            NValue.append(count)
            LValue.append(float(getattr(dataObject, lName)))
    if not NValue:
        raise ValueError("At least one of N1, N2, N3 must be positive")
    D = len(NValue)
    mu = []  # Not used by Tab; passed only to keep the TBlockCalc/Tab signatures.
    Tapprox = 2
    if Tapprox not in (2, 3, 4):
        raise ValueError("Unsupported T approximation: " + str(Tapprox))

    # Per-dimension matrices are stored in reverse dimension order; Tab relies
    # on this (e.g. b_matrix_insert[2] is paired with NValue[0]).
    c_matrix = [cmatrixgen(NValue[x], LValue[x]) for x in reversed(range(D))]
    # Approximation 4 does not use B matrices.
    if Tapprox < 4:
        b_matrix = [bmatrixgen(NValue[x], LValue[x]) for x in reversed(range(D))]
    else:
        b_matrix = None

    blockWidth = NValue[0]
    blocksPerSide = int(np.prod(NValue)) // blockWidth
    blockCoords = [(x, y) for x in range(blocksPerSide) for y in range(blocksPerSide)]
    paramz = [(D, NValue, LValue, mu, c_matrix, Tapprox, x, y, b_matrix, GMatrix)
              for x, y in blockCoords]

    # The context manager shuts the workers down even if a block raises.
    with mp.Pool(dataObject.cores_amount) as pool:
        blocks = pool.starmap(TBlockCalc, paramz)

    tmatrix = FGHMatrixObj(NValue, float)
    for (x, y), block in zip(blockCoords, blocks):
        rowStart = blockWidth * x
        colStart = blockWidth * y
        # Each block uses local indices 0..blockWidth-1; shift to global ones.
        for a in range(blockWidth):
            for b in range(blockWidth):
                tmatrix.setValueByPt(rowStart + a, colStart + b, block[a, b])

    return tmatrix




