# AUTOGENERATED! DO NOT EDIT! File to edit: 00_Genodata.ipynb (unless otherwise specified).

# Names exported by a star-import of this module.
__all__ = ['read_bgen', 'bgen2dask', 'pybgen_region', 'extract_bed', 'Genodata']

# Cell
import numpy as np
import pandas as pd
import dask.array as da
from bgen_reader import open_bgen
from pandas_plink import read_plink
try:
    from pybgen.parallel import ParallelPyBGEN as PyBGEN
except:
    print('Can not import ParallelPyBGEN. import PyBGEN instead')
    from pybgen import PyBGEN

# Cell
def read_bgen(file, sample_file=None,pybgen=True):
    '''Read a BGEN file and, optionally, its sample file.

    Parameters
    ----------
    file : str
        Path to the ``.bgen`` file.
    sample_file : str or None
        Path to the matching sample file. If None, ``fam`` is returned as None.
    pybgen : bool
        True: open with ``PyBGEN(file, probs_only=True)``.
        False: open with ``bgen_reader.open_bgen(file, verbose=False)``.

    Returns
    -------
    bim : pandas.DataFrame
        One row per variant; ``snp`` is rebuilt as ``'chr<chrom>:<pos>:<a0>:<a1>'``.
        PyBGEN path columns: chrom, snp, cm (always 0.0), pos, a0, a1, i (the
        variant's enumeration order in the file). bgen_reader path columns:
        chrom, snp, pos, a0, a1.
    fam : pandas.DataFrame or None
        Sample table whose first four columns are renamed fid, iid, missing,
        sex; any additional columns are kept unchanged.
    bg
        The open reader object (PyBGEN or open_bgen).
    '''
    if pybgen:
        bg = PyBGEN(file,probs_only=True)
        # PyBGEN's a1/a2 are stored in the a0/a1 columns here.
        rows = [[int(t.chrom), t.name, 0.0, t.pos, t.a1, t.a2, i]
                for i, t in enumerate(bg.iter_variant_info())]
        bim = pd.DataFrame(rows, columns=['chrom', 'snp', 'cm', 'pos', 'a0', 'a1', 'i'])
        # int() above turns zero-padded chromosome labels such as '05' into 5.
        bim['snp'] = 'chr' + bim[['chrom', 'pos', 'a0', 'a1']].astype(str).agg(':'.join, axis=1)
    else:
        bg = open_bgen(file,verbose=False)
        snp, aa0, aa1 = [], [], []
        for c, p, alleles in zip(bg.chromosomes, bg.positions, bg.allele_ids):
            a0, a1 = alleles.split(',')
            aa0.append(a0)
            aa1.append(a1)
            # '05' is first converted to int, then back to str.
            snp.append(':'.join(['chr' + str(int(c)), str(p), a0, a1]))
        bim = pd.DataFrame({'chrom': bg.chromosomes.astype(int), 'snp': snp,
                            'pos': bg.positions, 'a0': aa0, 'a1': aa1})
    if sample_file is None:
        fam = None
    else:
        # skiprows=1 drops the first line and header=0 then consumes the next
        # one; those column names are replaced positionally below.
        # sep=r'\s+' is the documented replacement for the deprecated
        # delim_whitespace=True.
        fam = pd.read_csv(sample_file, header=0, sep=r'\s+', quotechar='"', skiprows=1)
        # Rename only the first four columns so extra columns do not make
        # the assignment fail with a length mismatch.
        fam.columns = ['fid', 'iid', 'missing', 'sex'] + list(fam.columns[4:])
    return bim, fam, bg

# Cell
def bgen2dask(bgen,index,step=500):
    '''Convert the selected BGEN variants to a dask array of hard genotype calls.

    The variants listed in ``index`` are read ``step`` at a time with
    ``bgen.read``. Each chunk is presumably samples x variants x probabilities
    (the code takes argmax over axis 2). Each chunk is reduced to the index of
    the most probable genotype, cast to int8, and wrapped as a dask array.
    Chunks are joined along axis 1 (variants) and the result is transposed,
    giving variants x samples.
    '''
    total = len(index)
    blocks = [
        da.from_array(
            bgen.read(index[start:min(total, start + step)]).argmax(axis=2).astype(np.int8)
        )
        for start in range(0, total, step)
    ]
    return da.concatenate(blocks, axis=1).T

# Cell
def pybgen_region(bgen,region,step=100):
    '''Read hard genotype calls for every variant of a region from a PyBGEN reader.

    Parameters
    ----------
    bgen
        Reader exposing ``iter_variants_in_region(chrom, start, end)`` that
        yields ``(info, probs)`` pairs.
    region : sequence
        ``(chrom, start, end)`` with an integer chromosome. Chromosomes below
        10 are zero-padded (e.g. 5 -> '05') before the lookup.
    step : int
        Maximum number of variants per dask chunk; must be >= 1.

    Returns
    -------
    dask.array.Array
        int8 array with one row per variant, where each row is
        ``probs.argmax(axis=1)``. An empty region yields shape (0, 0).

    Raises
    ------
    ValueError
        If ``step`` is smaller than 1.
    '''
    if step < 1:
        raise ValueError('step must be a positive integer')
    chrom = '0' + str(region[0]) if region[0] < 10 else str(region[0])
    chunks, rows = [], []
    for _, probs in bgen.iter_variants_in_region(chrom, region[1], region[2]):
        rows.append(probs.argmax(axis=1).astype(np.int8))
        # Flush after appending, so every chunk holds exactly `step` rows and
        # no empty chunk is ever created (the old pre-append check broke step=1).
        if len(rows) == step:
            chunks.append(da.from_array(np.array(rows)))
            rows = []
    if rows:
        chunks.append(da.from_array(np.array(rows)))
    if not chunks:
        return da.from_array(np.empty((0, 0), dtype=np.int8))
    return da.concatenate(chunks, axis=0)

# Cell
def extract_bed(geno,idx,row=True,step=500,region=None):  #row = True by variants, row = False by samples
    '''Subset genotypes by variants (``row=True``) or by samples (``row=False``).

    Dispatches on the type of ``geno``:

    * dask array (variants x samples): indexed directly with ``idx``.
    * PyBGEN reader: ``idx`` is ignored; the variants inside ``region`` are
      read with ``pybgen_region``.
    * otherwise assumed to be a bgen_reader ``open_bgen`` object:
        - row=True: boolean masks are converted to positions and the variants
          are read with ``bgen2dask``.
        - row=False: all probabilities are read, the samples in ``idx`` are
          kept, and the result is converted to hard calls shaped variants x
          samples, matching ``bgen2dask``.

    ``idx`` may be a boolean mask or a sequence of integer positions.
    Returns the subset genotypes (a dask array for every branch).
    '''
    if isinstance(geno, da.core.Array):
        return geno[idx, :] if row else geno[:, idx]
    if isinstance(geno, PyBGEN):
        return pybgen_region(geno, region, step)
    idx = list(idx)
    is_mask = len(idx) > 0 and isinstance(idx[0], (bool, np.bool_))
    if row:
        # bgen.read is fed positional indices, so turn a boolean mask into positions.
        if is_mask:
            idx = [pos for pos, keep in enumerate(idx) if keep]
        return bgen2dask(geno, idx, step)
    # open_bgen.read() is samples x variants x probabilities: select samples
    # on axis 0 (the old code indexed axis 1, i.e. variants), collapse to hard
    # calls like bgen2dask, and transpose to variants x samples.
    probs = geno.read()
    calls = probs[idx].argmax(axis=2).astype(np.int8)
    return da.from_array(calls.T)

# Cell
class Genodata:
    '''Genotype dataset made of a variant table, a sample table and genotypes.

    Attributes
    ----------
    bim : pandas.DataFrame
        Variant table (uses columns ``snp``, ``chrom``, ``pos``).
    fam : pandas.DataFrame or None
        Sample table (uses column ``iid``).
    bed
        Genotypes: a dask array, or a BGEN reader until the first extraction.
    idx : list
        The index used by the most recent ``extractbyidx`` call.

    The ``extract*`` methods subset ``bim``/``fam`` and ``bed`` in place.
    '''
    def __init__(self,geno_path,sample_path=None):
        self.bim,self.fam,self.bed = self.read_geno(geno_path,sample_path)

    def __repr__(self):
        return "bim:% s \n fam:% s \n bed:%s" % (self.bim, self.fam, self.bed)

    def read_geno(self,geno_file,sample_file):
        '''Load ``(bim, fam, bed)`` from a PLINK ``.bed`` or a ``.bgen`` file.

        For ``.bgen`` input without ``sample_file``, the sample file is taken
        to be the same path with the ``.bgen`` suffix replaced by ``.sample``.

        Raises
        ------
        ValueError
            If the file has neither a ``.bed`` nor a ``.bgen`` suffix.
        '''
        if geno_file.endswith('.bed'):
            return read_plink(geno_file[:-4], verbose=False)
        if geno_file.endswith('.bgen'):
            if sample_file is None:
                # Swap only the trailing suffix; str.replace would also
                # rewrite '.bgen' occurring in directory names.
                sample_file = geno_file[:-len('.bgen')] + '.sample'
            return read_bgen(geno_file,sample_file)
        raise ValueError('Please provide the genotype files with PLINK binary format or BGEN format')

    def geno_in_stat(self,stat,notin=False):
        '''Keep (or, with notin=True, drop) the variants listed in ``stat.SNP``.'''
        variants = stat.SNP
        self.extractbyvariants(variants,notin)

    def geno_in_unr(self,unr):
        '''Keep the samples listed in ``unr.IID``.'''
        samples = unr.IID
        self.extractbysamples(samples)

    def extractbyregion(self,region):
        '''Keep variants with ``chrom == region[0]`` and ``region[1] <= pos <= region[2]``.

        Raises ValueError if no variant falls in the region.
        '''
        bim = self.bim
        idx = (bim.chrom == region[0]) & (bim.pos >= region[1]) & (bim.pos <= region[2])
        if not idx.any():
            raise ValueError('The extraction is empty')
        #update bim,bed
        self.extractbyidx(idx,row=True,region=region)

    def extractbyvariants(self,variants,notin=False):  #variants is list or pd.Series
        '''Keep (or, with notin=True, drop) variants whose ``snp`` is in ``variants``.

        Raises ValueError if nothing would remain.
        '''
        idx = self.bim.snp.isin(variants)
        if notin:
            idx = ~idx
        if not idx.any():
            raise ValueError('The extraction is empty')
        #update bim,bed
        self.extractbyidx(idx,row=True)

    def extractbysamples(self,samples,notin=False): #samples is list or pd.Series
        '''Keep (or, with notin=True, drop) samples whose ``iid`` is in ``samples``.

        IDs are compared as strings. Raises ValueError if nothing would remain.
        '''
        samples = pd.Series(samples,dtype=str)
        idx = self.fam.iid.astype(str).isin(samples)
        if notin:
            idx = ~idx
        if not idx.any():
            raise ValueError('The extraction is empty')
        #update fam,bed
        self.extractbyidx(idx,row=False)

    def extractbyidx(self,idx,row=True,region=None):
        '''Subset genodata by a boolean mask or by integer positions.

        With integer positions, the result follows the order of ``idx``.
        If ``row`` is True, extract by variants (updates bim); otherwise
        extract by samples (updates fam). ``bed`` is always updated.

        Raises ValueError if ``idx`` is empty.
        '''
        idx = list(idx)
        if not idx:
            raise ValueError('The extraction is empty')
        self.idx = idx
        # Python bools and numpy bools both mark a boolean mask.
        is_mask = isinstance(idx[0], (bool, np.bool_))
        if row:
            #update bim
            self.bim = self.bim[idx] if is_mask else self.bim.iloc[idx]
        else:
            #update fam
            self.fam = self.fam[idx] if is_mask else self.fam.iloc[idx]
        self.bed = extract_bed(self.bed,idx,row,region=region)
