import pdb
import numpy as np
import time

from .util_ import procrustes, issparse, sparse_matrix, nearest_neighbors

import scipy
from scipy.sparse import linalg as slinalg
from scipy.sparse import csr_matrix, block_diag, vstack
from sklearn.neighbors import NearestNeighbors
from scipy.spatial.distance import pdist, squareform

import multiprocess as mp
from multiprocess import shared_memory

# Computes Z_s for the case when to_tear is True.
# Input Z_s is the Z_s for the case when to_tear is False.
# Output Z_s is a subset of input Z_s.
def compute_Z_s_to_tear(y, s, Z_s, C, c, k):
    """Keep only the candidate views in Z_s that still touch view s in the embedding.

    Parameters
    ----------
    y : array, one row of embedding coordinates per point.
    s : index of the view being aligned.
    Z_s : integer array of candidate view indices (it is fancy-indexed and
        ``.tolist()`` is called on the result, so a plain list will not work).
    C : sparse views x points cluster-membership matrix.
    c : unused here; kept for interface compatibility.
    k : upper bound on the number of neighbours used per point.

    Returns
    -------
    list
        The entries of Z_s whose neighbourhood-expanded cluster shares at
        least one point with that of view s.
    """
    # Column indices of the points lying in cluster s or in the cluster of
    # any candidate view.
    pts = np.where(C[s,:] + C[Z_s,:].sum(axis=0))[1]
    # Never ask for more neighbours than there are other points.
    n_nbrs = min(k, pts.shape[0] - 1)
    # Neighbourhood structure of those points in the current embedding
    # (nearest_neighbors / sparse_matrix are project helpers; presumably a
    # k-NN search and a points x points boolean adjacency -- confirm there).
    _, nbr_ind = nearest_neighbors(y[pts,:], n_nbrs, 'euclidean')
    nbr_graph = sparse_matrix(nbr_ind, np.ones(nbr_ind.shape, dtype=bool))
    # One row per candidate view, followed by a final row for view s itself:
    # cluster membership pushed through the neighbourhood graph.
    candidate_rows = C[np.ix_(Z_s, pts)].dot(nbr_graph)
    own_row = C[s, pts].dot(nbr_graph)
    expanded = vstack([candidate_rows, own_row])
    # Pairwise overlaps between the expanded rows, self-overlap removed.
    overlap = expanded.dot(expanded.T)
    overlap.setdiag(False)
    # Last row = view s; every column but the last = the candidates.
    keep = overlap[-1,:-1].nonzero()[1]
    return Z_s[keep].tolist()

def procrustes_init(seq, rho, y, is_visited_view, d, Utilde, n_Utilde_Utilde,
                    C, c, intermed_param, global_opts, print_freq=1000):
    """Sequentially align the views in ``seq`` with already-visited views.

    Starting from the second entry of ``seq``, each view ``s`` is aligned
    (through ``procrustes``) with the centroid, over a set ``Z_s`` of views,
    of the embeddings of the points it shares with them. ``intermed_param.T``,
    ``intermed_param.v``, ``is_visited_view`` and ``y`` are updated in place.

    Parameters
    ----------
    seq : array of view indices in traversal order; ``seq[0]`` is skipped.
    rho : indexable such that ``rho[s]`` is the parent view of view ``s``.
    y : array of global embedding coordinates, one row per point.
    is_visited_view : per-view flags; entry ``s`` is set to True once view
        ``s`` has been aligned.
    d : embedding dimension.
    Utilde : sparse views x points matrix marking the points of each view.
    n_Utilde_Utilde : sparse views x views matrix; the non-zeros of row ``s``
        are the views considered to overlap view ``s``.
    C : sparse views x points matrix; ``C[s,:].indices`` are the points whose
        global embedding is written from view ``s``.
    c : forwarded unchanged to ``compute_Z_s_to_tear``.
    intermed_param : object providing ``eval_(opts)`` and the arrays ``T``
        and ``v`` (project type; assumed to hold one transform per view).
    global_opts : dict; the keys ``'to_tear'``,
        ``'init_algo_align_w_parent_only'`` and ``'k'`` are read.
    print_freq : a progress line is printed every ``print_freq`` views;
        a falsy value disables printing.

    Returns
    -------
    tuple
        ``(y, is_visited_view)``, the same objects that were passed in.
    """
    n = Utilde.shape[1]
    # Traverse views from 2nd view
    for m in range(1,seq.shape[0]):
        if print_freq and np.mod(m, print_freq)==0:
            print('Initial alignment of %d views completed' % m, flush=True)
        s = seq[m]
        # pth view is the parent of sth view
        p = rho[s]
        Utilde_s = Utilde[s,:]

        # If to tear apart closed manifolds
        if global_opts['to_tear']:
            if global_opts['init_algo_align_w_parent_only']:
                Z_s = [p]
            else:
                # Compute T_s and v_s by aligning
                # the embedding of the overlap Utilde_{sp}
                # due to sth view with that of the pth view
                Utilde_s_p = Utilde_s.multiply(Utilde[p,:]).nonzero()[1]
                V_s_p = intermed_param.eval_({'view_index': s, 'data_mask': Utilde_s_p})
                V_p_s = intermed_param.eval_({'view_index': p, 'data_mask': Utilde_s_p})
                intermed_param.T[s,:,:], intermed_param.v[s,:] = procrustes(V_s_p, V_p_s)

                # Compute temporary global embedding of point in sth cluster
                C_s = C[s,:].indices
                y[C_s,:] = intermed_param.eval_({'view_index': s, 'data_mask': C_s})
                # Find more views to align sth view with
                Z_s_all = n_Utilde_Utilde[s,:].multiply(is_visited_view).nonzero()[1]
                Z_s = compute_Z_s_to_tear(y, s, Z_s_all, C, c, global_opts['k'])
        # otherwise
        else:
            # Align sth view with all the views which have
            # an overlap with sth view in the ambient space
            Z_s = n_Utilde_Utilde[s,:].multiply(is_visited_view).nonzero()[1].tolist()

        # Whatever the branch, an empty Z_s would hand empty point sets to
        # procrustes below; fall back to aligning with the parent alone.
        if len(Z_s) == 0:
            Z_s = [p]

        # Compute centroid mu_s
        # n_Utilde_s_Z_s[k] = #views in Z_s which contain
        # kth point if kth point is in the sth view, else zero
        n_Utilde_s_Z_s = np.zeros(n, dtype=int)
        mu_s = np.zeros((n,d))
        for z in Z_s:
            Utilde_s_z = Utilde_s.multiply(Utilde[z,:]).nonzero()[1]
            n_Utilde_s_Z_s[Utilde_s_z] += 1
            mu_s[Utilde_s_z,:] += intermed_param.eval_({'view_index': z,
                                                        'data_mask': Utilde_s_z})

        # Compute T_s and v_s by aligning the embedding of the overlap
        # between sth view and the views in Z_s, with the centroid mu_s
        temp = n_Utilde_s_Z_s > 0
        mu_s = mu_s[temp,:] / n_Utilde_s_Z_s[temp,np.newaxis]
        V_s_Z_s = intermed_param.eval_({'view_index': s, 'data_mask': temp})

        T_s, v_s = procrustes(V_s_Z_s, mu_s)

        # Update T_s, v_s by composing the new transform with the stored one
        intermed_param.T[s,:,:] = np.matmul(intermed_param.T[s,:,:], T_s)
        intermed_param.v[s,:] = np.matmul(intermed_param.v[s,:][np.newaxis,:], T_s) + v_s

        # Mark sth view as visited
        is_visited_view[s] = True

        # Compute global embedding of point in sth cluster
        C_s = C[s,:].indices
        y[C_s,:] = intermed_param.eval_({'view_index': s, 'data_mask': C_s})
    return y, is_visited_view
    
# Ngoc-Diep Ho, Paul Van Dooren, On the pseudo-inverse of the Laplacian of a bipartite graph
def compute_Lpinv_BT(Utilde, B):
    """Apply the pseudo-inverse of a bipartite-graph Laplacian to ``B.T``.

    The graph has one node per column of ``Utilde`` (points) and one per row
    (views), with an edge wherever ``Utilde`` is non-zero. Following the
    reference cited above, the product is assembled from an SVD of the
    degree-normalised biadjacency matrix instead of forming the
    (n+M) x (n+M) Laplacian itself.

    Parameters
    ----------
    Utilde : sparse M x n matrix (views x points).
    B : sparse matrix with n+M columns; the first n columns pair with the
        points and the remaining M with the views.

    Returns
    -------
    numpy.ndarray
        Array with n+M rows and one column per row of ``B``; every column
        has zero mean.

    Notes
    -----
    Degrees are inverted without a guard, so a point contained in no view
    (or an empty view) yields non-finite values.
    """
    n = Utilde.shape[1]
    # Biadjacency matrix (points x views) and the degree of each side.
    adj = Utilde.copy().transpose().astype('int')
    deg_pts = np.asarray(adj.sum(axis=1))      # column vector, one per point
    deg_views = np.asarray(adj.sum(axis=0))    # row vector, one per view
    inv_sqrt_pts = np.sqrt(1/deg_pts)
    inv_sqrt_views = np.sqrt(1/deg_views)
    adj_normed = adj.multiply(inv_sqrt_views).multiply(inv_sqrt_pts)
    # TODO: the left singular vectors are dense, of size n x M
    left, sing, right_T = scipy.linalg.svd(adj_normed.todense(), full_matrices=False)
    # Alternative: slinalg.svds(adj_normed, k=M, solver='propack')
    right = right_T.T
    # Singular values equal to one are handled apart from the rest: for
    # them 1/(1 - sigma^2) is undefined.
    n_unit = np.sum(np.abs(sing-1)<1e-6)
    sig = np.expand_dims(sing[n_unit:], 1)
    gain = 1/(1-sig**2)
    cross_gain = sig*gain
    U1, U2 = left[:,:n_unit], left[:,n_unit:]
    V1, V2 = right[:,:n_unit], right[:,n_unit:]

    # Remove the mean of every row of B, then split the transpose into its
    # point part and its view part, each scaled by the matching degrees.
    B_centered = np.asarray(B - B.mean(axis=1))
    rhs_pts = inv_sqrt_pts * (B_centered[:, :n].T)
    rhs_views = inv_sqrt_views.T * (B_centered[:, n:].T)

    U1t_pts = U1.T @ rhs_pts
    U2t_pts = U2.T @ rhs_pts
    V1t_views = V1.T @ rhs_views
    V2t_views = V2.T @ rhs_views

    # Rows of the result associated with the points.
    out_pts = -0.75*(U1 @ U1t_pts)-0.25*(U1 @ V1t_views) +\
              U2 @ ((gain-1)*U2t_pts) + U2 @ (cross_gain*V2t_views) + rhs_pts
    out_pts = out_pts * inv_sqrt_pts

    # Rows of the result associated with the views.
    out_views = -0.25*(V1 @ U1t_pts) + 0.25*(V1 @ V1t_views) +\
                V2 @ (cross_gain*U2t_pts) + V2 @ (gain*V2t_views)
    out_views = out_views * inv_sqrt_views.T

    stacked = np.concatenate((out_pts, out_views), axis=0)
    # Centre every column over all n+M nodes.
    return stacked - np.mean(stacked, axis=0, keepdims=True)

def compute_CC(D, B, Lpinv_BT):
    """Return the symmetric part of ``D - B.dot(Lpinv_BT)``.

    ``D`` and ``B`` are used as given (sparse matrices in this file) and
    ``Lpinv_BT`` is the dense right-hand factor. Averaging the difference
    with its transpose removes any asymmetry in the computed product.
    """
    unsym = D - B.dot(Lpinv_BT)
    unsym_T = unsym.T
    return 0.5*(unsym + unsym_T)

def build_ortho_optim(d, Utilde, intermed_param):
    """Assemble the matrices of the orthogonal alignment problem.

    For each view ``i`` the embedding ``X_`` of its points is fetched from
    ``intermed_param.eval_``. Its Gram matrix ``X_.T @ X_`` becomes the i-th
    diagonal block of ``D``. Rows ``d*i .. d*(i+1)-1`` of ``B`` receive
    ``X_.T`` in the columns of the view's points and minus its row sums in
    column ``n+i``, so every row of ``B`` sums to zero.

    Parameters
    ----------
    d : embedding dimension (rows of ``B`` contributed by each view).
    Utilde : sparse M x n matrix (views x points), indexable by row.
    intermed_param : object providing ``eval_(opts)``; assumed to return one
        row of length ``d`` per selected point.

    Returns
    -------
    tuple
        ``(CC, Lpinv_BT)`` as produced by ``compute_CC`` and
        ``compute_Lpinv_BT``.
    """
    M, n = Utilde.shape
    gram_blocks = []
    rows, cols, vals = [], [], []
    for i in range(M):
        view_pts = Utilde[i,:].indices
        X_ = intermed_param.eval_({'view_index': i,
                                   'data_mask': view_pts})
        gram_blocks.append(np.matmul(X_.T,X_))

        view_rows = list(range(d*i,d*(i+1)))
        pts_list = view_pts.tolist()
        # First the d entries in the view's own column n+i ...
        rows.extend(view_rows)
        cols.extend(np.repeat([n+i], d).tolist())
        vals.extend(np.sum(-X_.T, axis=1).tolist())
        # ... then one entry per (dimension, point) pair, row-major in X_.T.
        rows.extend(np.repeat(view_rows, len(pts_list)).tolist())
        cols.extend(np.tile(pts_list, d).tolist())
        vals.extend(X_.T.flatten().tolist())

    D = block_diag(gram_blocks, format='csr')
    B = csr_matrix((vals, (rows, cols)), shape=(M*d,n+M))

    print('Computing Pseudoinverse of a matrix of L of size', n+M, 'multiplied with B', flush=True)
    Lpinv_BT = compute_Lpinv_BT(Utilde, B)

    return compute_CC(D, B, Lpinv_BT), Lpinv_BT
    

def compute_alignment_err(d, Utilde, intermed_param):
    """Return the alignment error of the current per-view transforms.

    The matrix ``CC`` from ``build_ortho_optim`` is viewed as an M x M grid
    of d x d blocks; the error is the sum of the diagonal entries of every
    one of those blocks (off-diagonal blocks included).
    """
    CC, _ = build_ortho_optim(d, Utilde, intermed_param)
    n_views = Utilde.shape[0]
    # Boolean mask selecting the diagonal of each d x d block.
    block_diagonals = np.tile(np.eye(d, dtype=bool), (n_views, n_views))
    return np.sum(CC[block_diagonals])

# Kunal N Chaudhury, Yuehaw Khoo, and Amit Singer, Global registration
# of multiple point clouds using semidefinite programming
def spectral_alignment(y, is_visited_view, d, Utilde,
                      C, intermed_param, global_opts, 
                      seq_of_intermed_views_in_cluster):
    """Align all views at once from the bottom eigenvectors of ``CC``.

    The ``d`` eigenvectors of ``CC`` requested with ``which='SM'`` are cut
    into one d x d block per view and each block is replaced by the product
    of its singular-vector factors. Within every cluster the transforms are
    then expressed relative to the cluster's first view, whose own
    parameters are left untouched. ``intermed_param.T``,
    ``intermed_param.v``, ``y`` and ``is_visited_view`` are updated in place.

    Parameters
    ----------
    y : array of global embedding coordinates, one row per point.
    is_visited_view : per-view flags; set to 1 for every view in the
        supplied sequences.
    d : embedding dimension.
    Utilde : sparse M x n matrix (views x points).
    C : sparse views x points matrix; ``C[s,:].indices`` are the points
        whose global embedding is written from view ``s``.
    intermed_param : object providing ``eval_(opts)`` and arrays ``T``, ``v``.
    global_opts : dict; ``'init_algo_name'`` is read. Reflections are
        removed from the per-view blocks unless it equals ``'spectral'``.
    seq_of_intermed_views_in_cluster : one array of view indices per
        cluster; the first entry of each is the reference view.

    Returns
    -------
    tuple
        ``(y, Z, is_visited_view)`` where ``Z`` has one row per point and
        ``d`` columns.

    Notes
    -----
    Seeds NumPy's global random state with 42 as a side effect.
    """
    CC, Lpinv_BT = build_ortho_optim(d, Utilde, intermed_param)
    M, n = Utilde.shape
    print('Computing eigh(C,k=d)', flush=True)
    np.random.seed(42)
    start_vec = np.random.uniform(0,1,CC.shape[0])
    # The smallest eigenvalues could also be found in shift-invert mode
    # (sigma=0.0 with the default which='LM'); here which='SM' is used.
    _, eig_vecs = scipy.sparse.linalg.eigsh(CC, k=d, v0=start_vec, which='SM')
    print('Done.', flush=True)
    Wstar = np.sqrt(M)*eig_vecs.T

    Tstar = np.zeros((d, M*d))
    for i in range(M):
        cols = slice(d*i, d*(i+1))
        left, _, right_T = scipy.linalg.svd(Wstar[:,cols])
        block = np.matmul(left, right_T)
        if (global_opts['init_algo_name'] != 'spectral') and (np.linalg.det(block) < 0):
            # remove reflection: flip the last right singular vector
            right_T[-1,:] *= -1
            block = np.matmul(left, right_T)
        Tstar[:,cols] = block

    Zstar = Tstar.dot(Lpinv_BT.transpose())

    for seq in seq_of_intermed_views_in_cluster:
        # The first view of the cluster defines the reference frame.
        s0 = seq[0]
        T0T = Tstar[:,s0*d:(s0+1)*d]
        ref_shift = np.matmul(Zstar[:,n+s0][np.newaxis,:], T0T)
        is_visited_view[s0] = 1
        for s in seq[1:]:
            T_s = np.matmul(Tstar[:,s*d:(s+1)*d].T, T0T)
            v_s = np.matmul(Zstar[:,n+s][np.newaxis,:], T0T) - ref_shift
            intermed_param.T[s,:,:] = np.matmul(intermed_param.T[s,:,:], T_s)
            intermed_param.v[s,:] = np.matmul(intermed_param.v[s,:], T_s) + v_s
            C_s = C[s,:].indices
            y[C_s,:] = intermed_param.eval_({'view_index': s, 'data_mask': C_s})
            is_visited_view[s] = 1

    return y, Zstar[:,:n].T, is_visited_view

def procrustes_final(y, d, Utilde, C, intermed_param, n_Utilde_Utilde, n_Utildeg_Utildeg,
                     seq_of_intermed_views_in_cluster, parents_of_intermed_views_in_cluster,
                     cluster_of_intermed_view, global_opts):
    """Refine the alignment by repeatedly re-fitting each view to its neighbours.

    The views are visited in one random order, repeated
    ``global_opts['refine_algo_max_internal_iter']`` times. Each view other
    than the first view of a cluster is re-aligned (through ``procrustes``)
    with the centroid, over the views in ``Z_s``, of the embeddings of the
    points it shares with them. ``intermed_param.T``, ``intermed_param.v``
    and ``y`` are updated in place.

    Parameters
    ----------
    y : array of global embedding coordinates, one row per point.
    d : embedding dimension.
    Utilde : sparse M x n matrix (views x points).
    C : sparse views x points matrix; ``C[s,:].indices`` are the points
        whose global embedding is written from view ``s``.
    intermed_param : object providing ``eval_(opts)`` and arrays ``T``, ``v``.
    n_Utilde_Utilde : sparse views x views matrix; the non-zeros of row
        ``s`` are the views overlapping view ``s``.
    n_Utildeg_Utildeg : sparse views x views matrix used to restrict those
        overlaps when ``global_opts['to_tear']`` is set.
    seq_of_intermed_views_in_cluster : one sequence of view indices per
        cluster; its first entry is never refined.
    parents_of_intermed_views_in_cluster : per cluster, indexable giving
        the parent of a view; used when a view has no overlapping view.
    cluster_of_intermed_view : cluster index of every view.
    global_opts : dict; ``'refine_algo_max_internal_iter'`` and
        ``'to_tear'`` are read.

    Returns
    -------
    The array ``y`` that was passed in.
    """
    M, n = Utilde.shape
    # Traverse over intermediate views in a random order
    seq = np.random.permutation(M)
    is_first_view_in_cluster = np.zeros(M, dtype=bool)
    for cluster_seq in seq_of_intermed_views_in_cluster:
        is_first_view_in_cluster[cluster_seq[0]] = True

    view_order = seq.tolist()
    # For a given seq, refine the global embedding
    for _ in range(global_opts['refine_algo_max_internal_iter']):
        for s in view_order:
            if is_first_view_in_cluster[s]:
                # Never refine s_0th intermediate view; only refresh the
                # embedding of its cluster.
                C_s = C[s,:].indices
                y[C_s,:] = intermed_param.eval_({'view_index': s, 'data_mask': C_s})
                continue

            Utilde_s = Utilde[s,:]

            # Views overlapping sth view in the ambient space ...
            overlaps = n_Utilde_Utilde[s,:]
            if global_opts['to_tear']:
                # ... restricted further when tearing apart closed manifolds
                overlaps = overlaps.multiply(n_Utildeg_Utildeg[s,:])
            Z_s = overlaps.nonzero()[1].tolist()

            if len(Z_s) == 0:
                # No overlapping view: align with the parent instead.
                Z_s = [parents_of_intermed_views_in_cluster[cluster_of_intermed_view[s]][s]]

            # Centroid of the shared points' embeddings. hits[k] = #views in
            # Z_s which contain kth point if kth point is in the sth view,
            # else zero.
            hits = np.zeros(n, dtype=int)
            centroid = np.zeros((n,d))
            for z in Z_s:
                shared = Utilde_s.multiply(Utilde[z,:]).nonzero()[1]
                hits[shared] += 1
                centroid[shared,:] += intermed_param.eval_({'view_index': z, 'data_mask': shared})

            covered = hits > 0
            centroid = centroid[covered,:] / hits[covered,np.newaxis]

            # Compute T_s and v_s by aligning the embedding of the overlap
            # between sth view and the views in Z_s, with the centroid
            V_s_Z_s = intermed_param.eval_({'view_index': s, 'data_mask': covered})
            T_s, v_s = procrustes(V_s_Z_s, centroid)

            # Compose the new transform with the stored one
            intermed_param.T[s,:,:] = np.matmul(intermed_param.T[s,:,:], T_s)
            intermed_param.v[s,:] = np.matmul(intermed_param.v[s,:][np.newaxis,:], T_s) + v_s

            # Compute global embedding of points in sth cluster
            C_s = C[s,:].indices
            y[C_s,:] = intermed_param.eval_({'view_index': s, 'data_mask': C_s})
    return y

def rgd_final(y, d, Utilde, C, intermed_param,
             n_Utilde_Utilde, n_Utildeg_Utildeg,
             parents_of_intermed_views_in_cluster,
             cluster_of_intermed_view,
             global_opts):
    """Refine the per-view transforms by block-wise descent run in parallel.

    A d x (M*d) matrix of d x d blocks, initialised to identities, is
    updated for ``global_opts['refine_algo_max_internal_iter']`` sweeps. In
    a sweep each block ``O_i`` is replaced by the Q factor (with the sign
    convention diag(R) >= 0) of ``O_i - alpha * skew(G_i O_i^T) O_i``, where
    ``G_i = 2 * O @ CC[:, block i]`` and ``skew(A) = (A - A.T) / 2``. The
    blocks are split into contiguous chunks, one per worker process; the
    workers share ``O`` and ``CC`` through shared memory and meet at a
    barrier after every sweep. The resulting blocks are then composed with
    ``intermed_param.T`` / ``intermed_param.v`` and ``y`` is recomputed.

    Parameters
    ----------
    y : array of global embedding coordinates, one row per point.
    d : embedding dimension.
    Utilde : sparse M x n matrix (views x points).
    C : sparse views x points matrix; ``C[s,:].indices`` are the points
        whose global embedding is written from view ``s``.
    intermed_param : object providing ``eval_(opts)`` and arrays ``T``, ``v``.
    n_Utilde_Utilde, n_Utildeg_Utildeg,
    parents_of_intermed_views_in_cluster, cluster_of_intermed_view :
        unused here; kept for interface compatibility.
    global_opts : dict; ``'n_proc'``, ``'refine_algo_alpha'`` and
        ``'refine_algo_max_internal_iter'`` are read.

    Returns
    -------
    The array ``y`` that was passed in.
    """
    CC, Lpinv_BT = build_ortho_optim(d, Utilde, intermed_param)
    M,n = Utilde.shape
    n_proc = min(M,global_opts['n_proc'])
    barrier = mp.Barrier(n_proc)

    def update(alpha, max_iter, shm_name_O, O_shape, O_dtype,
               shm_name_CC, CC_shape, CC_dtype, barrier):
        ###########################################
        # Parallel Updates
        ###########################################
        def target_proc(p_num, chunk_sz, barrier):
            def unique_qr(A):
                # QR with the signs fixed so that diag(R) >= 0.
                Q, R = np.linalg.qr(A)
                signs = 2 * (np.diag(R) >= 0) - 1
                Q = Q * signs[np.newaxis, :]
                R = R * signs[:, np.newaxis]
                return Q, R

            existing_shm_O = shared_memory.SharedMemory(name=shm_name_O)
            existing_shm_CC = None
            try:
                existing_shm_CC = shared_memory.SharedMemory(name=shm_name_CC)
                O = np.ndarray(O_shape, dtype=O_dtype, buffer=existing_shm_O.buf)
                CC = np.ndarray(CC_shape, dtype=CC_dtype, buffer=existing_shm_CC.buf)

                # Contiguous chunk of blocks owned by this worker; the last
                # worker also takes the remainder.
                start_ind = p_num*chunk_sz
                if p_num == (n_proc-1):
                    end_ind = M
                else:
                    end_ind = (p_num+1)*chunk_sz
                for _ in range(max_iter):
                    for i in range(start_ind, end_ind):
                        xi_ = 2*np.matmul(O, CC[:,i*d:(i+1)*d])
                        temp0 = O[:,i*d:(i+1)*d]
                        temp1 = np.matmul(xi_,temp0.T)
                        skew_temp1 = 0.5*(temp1-temp1.T)
                        Q_,R_ = unique_qr(temp0 - alpha*np.matmul(skew_temp1,temp0))
                        O[:,i*d:(i+1)*d] = Q_
                    # Every worker finishes the sweep before the next starts.
                    barrier.wait()
            finally:
                # Detach from the segments even if the sweep failed.
                existing_shm_O.close()
                if existing_shm_CC is not None:
                    existing_shm_CC.close()

        proc = []
        chunk_sz = M // n_proc
        for p_num in range(n_proc):
            proc.append(mp.Process(target=target_proc,
                                   args=(p_num,chunk_sz, barrier),
                                   daemon=True))
            proc[-1].start()

        for p_num in range(n_proc):
            proc[p_num].join()
        ###########################################

    alpha = global_opts['refine_algo_alpha']
    max_iter = global_opts['refine_algo_max_internal_iter']
    # One identity block per view.
    Tstar = np.zeros((d,M*d))
    for s in range(M):
        Tstar[:,s*d:(s+1)*d] = np.eye(d)

    print('Descent starts', flush=True)
    shm_Tstar = None
    shm_CC = None
    np_Tstar = None
    np_CC = None
    try:
        shm_Tstar = shared_memory.SharedMemory(create=True, size=Tstar.nbytes)
        np_Tstar = np.ndarray(Tstar.shape, dtype=Tstar.dtype, buffer=shm_Tstar.buf)
        np_Tstar[:] = Tstar[:]
        shm_CC = shared_memory.SharedMemory(create=True, size=CC.nbytes)
        np_CC = np.ndarray(CC.shape, dtype=CC.dtype, buffer=shm_CC.buf)
        np_CC[:] = CC[:]

        update(alpha, max_iter, shm_Tstar.name, Tstar.shape, Tstar.dtype,
               shm_CC.name, CC.shape, CC.dtype, barrier)

        # Copy the result out before the segment goes away.
        Tstar[:] = np_Tstar[:]
    finally:
        # Drop the array views before closing, then always unlink so that a
        # failure above cannot leave named segments behind.
        np_Tstar = None
        np_CC = None
        for shm in (shm_Tstar, shm_CC):
            if shm is not None:
                try:
                    shm.close()
                finally:
                    shm.unlink()

    Zstar = Tstar.dot(Lpinv_BT.transpose())

    for s in range(M):
        T_s = Tstar[:,s*d:(s+1)*d].T
        v_s = Zstar[:,n+s]
        intermed_param.T[s,:,:] = np.matmul(intermed_param.T[s,:,:], T_s)
        intermed_param.v[s,:] = np.matmul(intermed_param.v[s,:][np.newaxis,:], T_s) + v_s
        C_s = C[s,:].indices
        y[C_s,:] = intermed_param.eval_({'view_index': s, 'data_mask': C_s})

    return y