# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/02_Liftover.ipynb (unless otherwise specified).

__all__ = ['Liftover']

# Cell
import gzip
from liftover import get_lifter

# Cell
class Liftover:
    def __init__(self,fr='hg19',to='hg38'):
        self.fr,self.to = fr,to
        self.chainmap = get_lifter(fr, to)
        #self.fasta

    def variants_liftover(self,chrom,pos):
        if len(chrom) == 1:
            chrom = [chrom]*len(pos)
        lchr,lpos = [],[]
        for c,p in zip(chrom,pos):
            new_c,new_p = self.chrpos_liftover(c,p)
            lchr.append(new_c)
            lpos.append(new_p)
        return lchr,lpos

    def chrpos_liftover(self,chrom,pos):
        try:
            if str(chrom) in ['X','chrX','23']:
                new_c,new_p,_ = self.chainmap['X'][pos][0]
                return 23,new_p
            elif str(chrom) in ['Y','chrY','24']:
                new_c,new_p,_ = self.chainmap['Y'][pos][0]
                return 24,new_p
            elif str(chrom) in ['M','chrM','25','MT']:
                new_c,new_p,_ = self.chainmap['M'][pos][0]
                return 25,new_p
            else:
                new_c,new_p,_ = self.chainmap[int(chrom)][pos][0]
            return int(new_c[3:]),new_p
        except:
            return 0,0

        #The function to liftover bim
    def bim_liftover(self,bim):
        new_bim = bim.copy()
        lchr,lpos = self.variants_liftover(bim.chrom,bim.pos)
        new_bim.chrom =lchr
        new_bim.pos = lpos
        new_bim.snp = 'chr'+new_bim[['chrom','pos','a0','a1']].astype(str).agg(':'.join, axis=1)
        return new_bim


    def sumstat_liftover(self,ss):
        new_ss = ss.copy()
        lchr,lpos = self.variants_liftover(ss.CHR,ss.POS)
        new_ss.CHR =lchr
        new_ss.POS = lpos
        new_ss.SNP = 'chr'+new_ss[['CHR','POS','REF','ALT']].astype(str).agg(':'.join, axis=1)
        return new_ss

    def vcf_liftover(self,vcf,vcf_out=None,remove_missing = True):
        if vcf_out is None:
            vcf_out = vcf[:-7]+'_'+self.fr+'To'+self.to+vcf[-7:]
        count_fail,total= 0,0
        with gzip.open(vcf, 'rt') as ifile:
            with gzip.open(vcf_out,'wt') as ofile:
                for line in ifile:
                    if line.startswith("#"):
                        ofile.write(line)
                    else:
                        variant = [x for x in line.split('\t')]
                        new_c,new_p = self.chrpos_liftover(variant[0],int(variant[1]))
                        total +=1
                        if new_c == 0:
                            count_fail +=1
                            if remove_missing:
                                continue
                        variant[0] = str(new_c)
                        variant[1] = str(new_p)
                        variant[2] = 'chr'+':'.join(variant[:2]+variant[3:5])
                        ofile.write('\t'.join(variant))
            ofile.close()
        ifile.close()
        if remove_missing:
            print("Total number SNPs ",total,". Removing SNPs failed to liftover ", count_fail)
        else:
            print("Total number SNPs ",total,". The number of SNPs failed to liftover ", count_fail,". Their chr and pos is replaced with 0, 0")



    def region_liftover(self,region):
        imp_cs,imp_start = self.chrpos_liftover(region[0],region[1])
        imp_ce,imp_end = self.chrpos_liftover(region[0],region[2])
        if imp_cs !=imp_ce:
            raise ValueError('After liftover, the region is not in the same chromosome anymore.')
        return imp_cs,imp_start,imp_end

    def df_liftover(self):
        pass
