# AUTOGENERATED! DO NOT EDIT! File to edit: 03_LDmatrix.ipynb (unless otherwise specified).

# Public API: the names exported by `from <this module> import *`.
__all__ = ['geno_corr', 'dask_corr', 'dict2mat', 'dask_corr_pair', 'dict2mat_pair']

# Cell
import torch
import numpy as np
import pandas as pd
import dask.array as da

# Cell
def geno_corr(x, y=None, step=100):
    '''Compute an LD (correlation) matrix and return it as a dask array.

    With only ``x`` given, the result is the symmetric SNP-by-SNP correlation
    matrix of ``x``, built by ``dask_corr`` and assembled by ``dict2mat``.
    With ``y`` also given, the result is the cross-correlation matrix whose
    rows are the SNPs of ``x`` and whose columns are the SNPs of ``y``, built
    by ``dask_corr_pair`` and assembled by ``dict2mat_pair``.

    Inputs are laid out samples x SNPs (rows = samples). ``step`` is the
    number of SNP columns processed per block.
    '''
    if y is not None:
        blocks = dask_corr_pair(x, y, step)
        return dict2mat_pair(blocks)
    blocks = dask_corr(x, step)
    return dict2mat(blocks)

def dask_corr(genos, step=100):
    '''Block-wise correlation of the SNP columns of ``genos`` with themselves.

    ``genos`` is laid out samples x SNPs (rows = samples, columns = SNPs).
    Columns are processed in blocks of ``step``.

    Each block is standardized per column with a NaN-aware mean/std
    (``np.nanmean`` / ``np.nanstd``, ddof=0). Remaining NaNs are then set to 0,
    which amounts to mean imputation on the standardized scale. This also
    zeroes columns whose standardization produced 0/0, e.g. zero-variance
    SNPs.

    The block product ``X_i^T X_j / nsample`` is formed only for ``j >= i``.
    The lower triangle is meant to be recovered by transposition, as
    ``dict2mat`` does.

    NOTE: the denominator is always ``nsample``. For a SNP with missing
    entries the diagonal element is therefore ``n_observed / nsample`` (< 1),
    and its correlations are shrunk accordingly.

    Parameters
    ----------
    genos : dask array or array-like, shape (nsample, nsnp)
        Genotype matrix. Objects exposing ``.compute()`` (e.g. dask arrays)
        are materialized one column block at a time. Plain NumPy arrays are
        also accepted.
    step : int
        Number of SNP columns per block; must be >= 1.

    Returns
    -------
    dict[int, dict[int, dask.array.Array]]
        ``da_corr[i][j]`` is the correlation block between the column blocks
        starting at ``i`` and ``j``. Only ``j >= i`` is stored.

    Raises
    ------
    ValueError
        If ``step < 1``.
    '''
    if step < 1:
        raise ValueError(f"step must be a positive integer, got {step}")
    nsample, nsnp = genos.shape[0], genos.shape[1]

    def _load_std(start):
        # Materialize columns [start, start+step) as float64 and standardize them.
        block = genos[:, start:min(start + step, nsnp)]
        if hasattr(block, 'compute'):
            block = block.compute()
        block = np.asarray(block, dtype=np.float64)
        # Zero-variance columns produce 0/0 = NaN here. Those NaNs are zeroed
        # just below, so the floating-point warnings are expected noise.
        with np.errstate(divide='ignore', invalid='ignore'):
            block = (block - np.nanmean(block, axis=0)[None, :]) / np.nanstd(block, axis=0)[None, :]
        block = torch.from_numpy(block)
        block[torch.isnan(block)] = 0
        return block

    da_corr = {}
    for i in range(0, nsnp, step):
        geno_i = _load_std(i)
        # Diagonal block.
        da_corr[i] = {i: da.from_array((torch.matmul(geno_i.T, geno_i) / nsample).numpy())}
        # Upper-triangular blocks only. Each j block is re-loaded and
        # re-standardized for every i, so at most two genotype blocks are
        # materialized at once.
        for j in range(i + step, nsnp, step):
            geno_j = _load_std(j)
            da_corr[i][j] = da.from_array((torch.matmul(geno_i.T, geno_j) / nsample).numpy())
    return da_corr

def dict2mat(dd):
    '''Assemble a full symmetric matrix from an upper-triangular block dict.

    ``dd[i][j]`` must exist for every pair of keys with ``i <= j``, as produced
    by ``dask_corr``. Blocks below the diagonal are not stored; they are
    obtained by transposing their mirror ``dd[j][i]``. Rows and columns are
    ordered by the dict's key iteration order. Blocks are joined with
    ``da.concatenate``, so the result is a dask array.
    '''
    keys = list(dd)

    def _block(r, c):
        # Lower-triangle blocks are the transpose of the stored upper ones.
        return dd[c][r].T if r > c else dd[r][c]

    block_rows = [da.concatenate([_block(r, c) for c in keys], axis=1) for r in keys]
    return da.concatenate(block_rows, axis=0)

def dask_corr_pair(genos, pgenos, step=100):
    '''Block-wise cross-correlation between the SNPs of two genotype matrices.

    Both inputs are laid out samples x SNPs and must have the same number of
    rows (samples). Every column block of ``genos`` is correlated with every
    column block of ``pgenos``.

    Each block is standardized per column with a NaN-aware mean/std
    (``np.nanmean`` / ``np.nanstd``, ddof=0), and remaining NaNs are set to 0.
    The product is divided by ``nsample``.

    Parameters
    ----------
    genos, pgenos : dask array or array-like, shape (nsample, nsnp) / (nsample, psnp)
        Genotype matrices. Objects exposing ``.compute()`` (e.g. dask arrays)
        are materialized one column block at a time. Plain NumPy arrays are
        also accepted.
    step : int
        Number of SNP columns per block; must be >= 1.

    Returns
    -------
    dict[int, dict[int, dask.array.Array]]
        ``da_corr[i][j]`` is ``X_i^T Y_j / nsample``. Here ``X_i`` is the
        ``genos`` block starting at column ``i`` and ``Y_j`` is the ``pgenos``
        block starting at column ``j``. Every (i, j) pair is present.

    Raises
    ------
    ValueError
        If the two inputs have different sample counts, or if ``step < 1``.
    '''
    nsample, nsnp = genos.shape[0], genos.shape[1]
    psample, psnp = pgenos.shape[0], pgenos.shape[1]
    if nsample != psample:
        # Fail fast: the block products below need matching sample axes.
        raise ValueError(
            f"error: sample not match (genos has {nsample} samples, "
            f"pgenos has {psample})")
    if step < 1:
        raise ValueError(f"step must be a positive integer, got {step}")

    def _load_std(arr, start, ncol):
        # Materialize columns [start, start+step) of arr as float64 and standardize them.
        block = arr[:, start:min(start + step, ncol)]
        if hasattr(block, 'compute'):
            block = block.compute()
        block = np.asarray(block, dtype=np.float64)
        # Zero-variance columns produce 0/0 = NaN here. Those NaNs are zeroed
        # just below, so the floating-point warnings are expected noise.
        with np.errstate(divide='ignore', invalid='ignore'):
            block = (block - np.nanmean(block, axis=0)[None, :]) / np.nanstd(block, axis=0)[None, :]
        block = torch.from_numpy(block)
        block[torch.isnan(block)] = 0
        return block

    da_corr = {}
    for i in range(0, nsnp, step):
        geno_i = _load_std(genos, i, nsnp)
        da_corr[i] = {}
        # pgenos blocks are re-loaded and re-standardized for every i block.
        for j in range(0, psnp, step):
            geno_j = _load_std(pgenos, j, psnp)
            da_corr[i][j] = da.from_array((torch.matmul(geno_i.T, geno_j) / nsample).numpy())
    return da_corr

def dict2mat_pair(dd):
    '''Assemble a rectangular matrix from a nested block dict.

    ``dd[i][j]`` is the block in block-row ``i`` and block-column ``j``, as
    produced by ``dask_corr_pair``. Block-rows follow ``dd``'s key order.
    Block-columns follow the key order of the first row, and every row must
    provide those keys. Blocks are joined with ``da.concatenate``, so the
    result is a dask array.

    Raises
    ------
    ValueError
        If ``dd`` is empty.
    KeyError
        If a row lacks one of the first row's column keys.
    '''
    if not dd:
        raise ValueError("dict2mat_pair needs at least one block row")
    # Column order comes from the first row rather than a hard-coded key,
    # so row keys need not start at 0.
    col_keys = list(next(iter(dd.values())))
    block_rows = [da.concatenate([dd[i][j] for j in col_keys], axis=1) for i in dd]
    return da.concatenate(block_rows, axis=0)