#!/usr/bin/env python3

"""
Reference free AF-based demultiplexing on pooled scRNA-seq
Jon Xu (jun.xu@uq.edu.au)
Lachlan Coin
Aug 2018
"""

import numpy as np
import pysam as ps
import pandas as pd
import statistics as stat
from scipy.stats import binom
from scipy.sparse import csr_matrix
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import sys, vcf, csv, math, datetime, pickle, argparse

class SNV_data:
    """
    Plain record for one SNV: chromosome, position, reference and alternative allele
    """

    def __init__(self, chrom, pos, ref, alt):

        # upper-case attribute names follow the VCF column naming
        self.CHROM, self.POS = chrom, pos
        self.REF, self.ALT = ref, alt


class SNV_list(list):
    """
    List of SNV_data objects
    """

    def add_SNV(self, *args, **kwargs):
        """
        Append a new SNV_data built from the given arguments (chrom, pos, ref, alt)
        """
        self.append(SNV_data(*args, **kwargs))


    def build_base_calls_matrix(self, file_s, barcodes):
        """
        Count REF and ALT base calls per SNV and per cell barcode
        Parameters:
            file_s(str): Path to the indexed alignment file, opened in 'rb' mode (pysam positions are 0-based)
            barcodes(list): cell barcodes; only reads whose 'CB' tag is in this list are counted
        Returns:
            tuple of two DataFrames (REF counts, ALT counts), rows indexed by 'CHROM:POS', columns by barcode
        """

        # unique SNV positions in first-seen order; the set keeps the duplicate check O(1)
        all_POS, seen = [], set()
        for entry in self:
            pos = str(entry.CHROM) + ':' + str(entry.POS)
            if pos not in seen:
                seen.add(pos)
                all_POS.append(pos)

        # set lookup instead of scanning the whole barcode list for every read
        known_barcodes = set(barcodes)

        # the context manager closes the alignment file even if counting fails
        with ps.AlignmentFile(file_s, 'rb') as in_sam:
            ref_base_calls_mtx = pd.DataFrame(0, index=all_POS, columns=barcodes)
            alt_base_calls_mtx = pd.DataFrame(0, index=all_POS, columns=barcodes)
            print('Num Pos:', len(all_POS), ', Num barcodes:', len(barcodes))

            for snv in self:
                position = str(snv.CHROM) + ':' + str(snv.POS)
                target = snv.POS - 1   # 0-based coordinate of the SNV
                for read in in_sam.fetch(snv.CHROM, target, snv.POS+1):
                    if read.flag >= 256:   # only valid reads
                        continue
                    if target not in read.get_reference_positions():
                        # the read's aligned positions do not cover the SNV position
                        continue
                    try:
                        barcode = read.get_tag('CB')
                    except KeyError:
                        # read carries no cell barcode tag
                        barcode = ''
                    if barcode not in known_barcodes:
                        continue
                    # read the base from the query position which is aligned to the SNV
                    query_pos = next(q for q, r in read.get_aligned_pairs(True) if r == target)
                    base = read.query_sequence[query_pos]
                    if base == snv.REF:
                        ref_base_calls_mtx.loc[position, barcode] += 1
                    if base == snv.ALT:
                        alt_base_calls_mtx.loc[position, barcode] += 1

        ref_base_calls_mtx.index.name = alt_base_calls_mtx.index.name = 'SNV'

        return (ref_base_calls_mtx, alt_base_calls_mtx)


class models:

    def __init__(self, base_calls_mtx, num, doublets):
        """
        Model class containing SNVs, matrices counts, barcodes, model allele fraction with assigned cells

        Parameters:
             base_calls_mtx(sequence of 4 items): [REF count matrix, ALT count matrix, SNV positions, cell barcodes];
                 the two matrices must support .sum(axis=1) and .toarray() (run() passes scipy csr_matrix objects),
                 rows are SNVs and columns are barcodes; the last two items must support .tolist()
             num(int): number of model states (run() adds one extra state for doublets unless the doublet proportion is 0)
             doublets: expected doublet proportion, stored as self.doublets

        Attributes set here:
             ref_bc_mtx: SNV-barcode matrix for reference allele counts
             alt_bc_mtx: SNV-barcode matrix for alternative allele counts
             all_POS(list): list of SNVs positions
             barcodes(list): list of cell barcodes
             P_s_c(DataFrame): barcode/sample matrix containing the probability of seeing sample s with observation of barcode c
             lP_c_s(DataFrame): barcode/sample matrix containing the log likelihood of seeing barcode c under sample s, whose sum should increase by each iteration
             assigned(list), reassigned(list): per-state lists of cell barcodes, initially empty
             model_af(Dataframe): Dataframe of model allele frequencies P(A) for each SNV and state
             pseudo(int): pseudo count (1)
             k_alt: per-SNV background ALT proportion over all barcodes
             initial(list): per-state barcodes of the k-means seed clusters (only set when initialization succeeds)
        """
        self.ref_bc_mtx, self.alt_bc_mtx = base_calls_mtx[0], base_calls_mtx[1]
        self.all_POS, self.barcodes = base_calls_mtx[2].tolist(), base_calls_mtx[3].tolist()
        self.doublets = doublets
        self.num = num
        self.P_s_c = pd.DataFrame(0, index = self.barcodes, columns = range(self.num))
        self.lP_c_s = pd.DataFrame(0, index = self.barcodes, columns = range(self.num))
        self.assigned, self.reassigned = [], []
        for _ in range(self.num):
            self.assigned.append([])
            self.reassigned.append([])
        self.model_af = pd.DataFrame(0, index=self.all_POS, columns=range(self.num))
        self.pseudo = 1

        # set background alt count proportion as allele fraction for each SNVs of doublet state, with pseudo count added for 0 counts on multi-base SNPs
        N_alt = self.alt_bc_mtx.sum(axis=1) + self.pseudo
        N_ref = self.ref_bc_mtx.sum(axis=1) + self.pseudo
        self.k_alt = N_alt / (N_ref + N_alt)

        # find barcodes for state initialization, using subsetting/PCA/K-mean
        # base_mtx holds the total (ALT + REF) coverage per SNV (row) and barcode (column)
        base_mtx = (self.alt_bc_mtx + self.ref_bc_mtx).toarray()
        rows, cols = [*range(base_mtx.shape[0])], [*range(base_mtx.shape[1])]
        # irows/icols map the rows/columns of the shrinking base_mtx back to indices of the full matrices
        irows, icols = np.array(rows), np.array(cols)
        nrows, ncols = len(rows), len(cols)
        mrows = mcols = 0
        # keep shrinking until every remaining column is non-zero in at least 90% of the remaining rows
        # and every remaining row is non-zero in at least 90% of the remaining columns,
        # or until both dimensions have dropped below 10
        while (mrows < 0.9 * nrows or mcols < 0.9 * ncols) and (nrows >= 10 or ncols >= 10):
            # random positions to discard (up to 10% per round); Beta(1,10) concentrates its mass near 0,
            # i.e. at the start of the ascending orderings built below
            rbrows = np.sort(np.unique(list(map(int, np.random.beta(1,10,int(0.1*nrows))*nrows))))
            rbcols = np.sort(np.unique(list(map(int, np.random.beta(1,10,int(0.1*ncols))*ncols))))
            # order rows/columns by their number of non-zero entries, fewest first
            rows = np.count_nonzero(base_mtx, axis=1).argsort().tolist()
            cols = np.count_nonzero(base_mtx, axis=0).argsort().tolist()
            # drop the entries sitting at the sampled positions of that ordering
            rows = [item for index, item in enumerate(rows) if index not in rbrows]
            cols = [item for index, item in enumerate(cols) if index not in rbcols]
            irows, icols = irows[rows], icols[cols]
            nrows, ncols = len(rows), len(cols)
            base_mtx = base_mtx[rows][:,cols]
            # mrows: smallest per-column count of non-zero rows; mcols: smallest per-row count of non-zero columns
            mrows = min(np.count_nonzero(base_mtx, axis=0))
            mcols = min(np.count_nonzero(base_mtx, axis=1))
        alt_subset = self.alt_bc_mtx[irows][:, icols].todense()
        ref_subset = self.ref_bc_mtx[irows][:, icols].todense()
        # smoothed ALT proportion; an entry without any coverage evaluates to 0.01 / 0.02 = 0.5
        alt_prop = (alt_subset + 0.01) / (alt_subset + ref_subset + 0.02)
        # transpose so that barcodes are the samples which get standardized, projected and clustered
        alt_pca = StandardScaler().fit_transform(alt_prop.T)
        pca = PCA(n_components=min(nrows, ncols, 20))
        pca_alt = pca.fit_transform(alt_pca)
        if pca_alt.shape[0] < self.num:
            # model_af stays all zero and self.initial is not created; run() checks model_af.sum().sum() > 0 before using the model
            print('not enough informative cells to support model initialization')
        else:
            kmeans = KMeans(n_clusters=self.num, random_state=0).fit(pca_alt)
            # initialise allele frequency for model states
            self.initial = []
            for n in range(self.num):
                self.initial.append([])
                # remember which barcodes seeded state n
                for index, col in enumerate(icols):
                    if kmeans.labels_[index] == n:
                        self.initial[n].append(self.barcodes[col])
                # ALT/REF counts of the seed barcodes of state n, summed over all SNVs of the full matrices
                barcode_alt = np.array(self.alt_bc_mtx[:, icols[kmeans.labels_==n]].sum(axis=1))
                barcode_ref = np.array(self.ref_bc_mtx[:, icols[kmeans.labels_==n]].sum(axis=1))
                # ALT fraction of the seed cluster, smoothed with the background proportion and the pseudo count
                self.model_af.loc[:, n] = (barcode_alt + self.k_alt) / (barcode_alt + barcode_ref + self.pseudo)


    def run_EM(self, tol=0.0, max_iter=None):
        """
        Run expectation-maximisation until the model log likelihood stops changing

        Each iteration appends a progress line to 'scSplit.log', recomputes P(s|c)
        (calculate_cell_likelihood) and the model allele fractions (calculate_model_af),
        and records the model log likelihood in self.sum_log_likelihood.

        Parameters:
            tol(float): stop once two consecutive likelihoods differ by at most tol;
                        the default 0.0 keeps the former exact-equality criterion
            max_iter(int or None): optional upper bound on the number of iterations (None = unbounded)
        """

        # commencing E-M
        iterations = 0
        # two distinct dummy entries so that the history always has a "previous" and a "current" value
        self.sum_log_likelihood = [1,2]
        while True:
            previous, current = self.sum_log_likelihood[-2], self.sum_log_likelihood[-1]
            if iterations > 0:
                # NaN never compares equal to itself and inf - inf is NaN, so a non-finite
                # likelihood could never satisfy the convergence test: stop instead of looping forever
                if not np.isfinite(current):
                    break
                if abs(current - previous) <= tol:
                    break
            if max_iter is not None and iterations >= max_iter:
                break
            iterations += 1
            progress = 'Iteration ' + str(iterations) + '   ' + str(datetime.datetime.now()) + '\n'
            with open('scSplit.log', 'a') as myfile:
                myfile.write(progress)
            self.calculate_cell_likelihood()
            self.calculate_model_af()
            self.sum_log_likelihood.append(self.lP_c_m)


    def calculate_cell_likelihood(self):
        """
        Calculate cell|sample likelihood P(c|s) and derive sample|cell probability P(s|c)
        P(c|s_v) = P(N(A),N(R)|s) = P(g_A|s)^N(A) * (1-P(g_A|s))^N(R)
        log(P(c|s)) = sum_v{(N(A)_c,v*log(P(g_A|s)) + N(R)_c,v*log(1-P(g_A|s)))}
        P(s_n|c) = P(c|s_n) / [P(c|s_1) + P(c|s_2) + ... + P(c|s_n)]
        log(P(s1|c) = log{1/[1+P(c|s2)/P(c|s1)]} = -log[1+P(c|s2)/P(c|s1)] = -log[1+2^(logP(c|s2)-logP(c|s1))]
        """

        # calculate likelihood P(c|s) based on allele probability
        for n in range(self.num):
            matcalc = self.alt_bc_mtx.T.multiply(self.model_af.loc[:, n].apply(np.log2)).T \
                    + self.ref_bc_mtx.T.multiply((1 - self.model_af.loc[:, n]).apply(np.log2)).T
            self.lP_c_s.loc[:, n] = matcalc.sum(axis=0).tolist()[0]

        # transform to cell sample probability P(s|c) using Baysian rule
        for i in range(self.num):
            denom = 0
            for j in range(self.num):
                denom += 2 ** (self.lP_c_s.loc[:, j] - self.lP_c_s.loc[:, i])
            self.P_s_c.loc[:, i] = 1 / denom

        # calculate model likelihood: logP(x|theta) = log{Sigma_yP(x,y|theta)}
        self.lP_c_m = ((self.lP_c_s.subtract(self.lP_c_s.min(axis=1), axis=0).pow(2).sum(axis=1) + 1).apply(np.log2) + self.lP_c_s.min(axis=1)).sum(axis=0)


    def calculate_model_af(self):
        """
        Update the model allele fraction by distributing the alt and total counts of each barcode on a certain snv to the model based on P(s|c)
        """

        self.model_af = pd.DataFrame((self.alt_bc_mtx.dot(self.P_s_c) + self.k_alt) / ((self.alt_bc_mtx + self.ref_bc_mtx).dot(self.P_s_c) + self.pseudo),
                                        index = self.all_POS, columns = range(self.num))


    def assign_cells(self):
        """
            Final assignment of cells according to P(s|c) >= 0.9
        """

        for n in range(self.num):
            self.assigned[n] = sorted(self.P_s_c.loc[self.P_s_c[n] >= 0.99].index.values.tolist())


    def define_doublet(self):
        """
            Locate the doublet state
        """
        cross_state = pd.DataFrame(0, index = range(self.num), columns = range(self.num))
        for i in range(self.num):
            for j in range(self.num):
                index = []
                # transform barcode assignments to indices
                for item in self.assigned[j]:
                    index.append(self.barcodes.index(item))
                matcalc = self.alt_bc_mtx.T.multiply(self.model_af.loc[:, i].apply(np.log2)).T \
                        + self.ref_bc_mtx.T.multiply((1 - self.model_af.loc[:, i]).apply(np.log2)).T
                cross_state.loc[j, i] = matcalc[:, index].sum()
        result = cross_state.sum(axis=0).tolist()
        self.doublet = result.index(max(result))


    def refine_doublets(self, doublets):
        """
            Find falsely assigned doublets
        """
        N_ref_mtx, N_alt_mtx = pd.DataFrame(0, index=self.all_POS, columns=range(self.num)), pd.DataFrame(0, index=self.all_POS, columns=range(self.num))
        found = []
        self.reassigned = self.assigned.copy()
        if doublets > 0: # if user has set expectation on doublet proportion, otherwise go with default doublet detection
            for n in range(self.num):
                bc_idx = [i for i, e in enumerate(self.barcodes) if e in self.assigned[n]]
                # REF/ALT alleles counts from cells assigned to state n
                N_ref_mtx.loc[:, n], N_alt_mtx.loc[:, n] = self.ref_bc_mtx[:, bc_idx].sum(axis=1), self.alt_bc_mtx[:, bc_idx].sum(axis=1)
            # get total non zero variants per state
            rps = self.alt_bc_mtx.T.dot(1 - self.model_af) * (self.P_s_c >= 0.99)
            rpc = rps.drop(self.doublet, axis=1)
            lack = len(self.barcodes) * doublets - len(self.assigned[self.doublet])
            if lack > 0:
                found = rpc.index[rpc.sum(axis=1).argsort()[int(len(rpc.index) - lack):len(rpc.index)]]
                for n in range(self.num):
                    if n != self.doublet:
                        self.reassigned[n] = [x for x in self.assigned[n] if x not in found]
                    else:
                        self.reassigned[n] += found.tolist()


    def distinguishing_alleles(self, pos=[]):
        """
            Locate the distinguishing alleles
            N_ref_mtx, N_alt_mtx: SNV-state matrix for ref/alt counts in each state
        """
        
        # build SNV-state matrices for ref and alt counts
        self.dist_variants, ncols = [], self.num - 1
        if len(pos) != 0:
            snv = [self.all_POS[i] for i in pos]
            N_ref_mtx, N_alt_mtx = pd.DataFrame(0, index=snv, columns=range(self.num)), pd.DataFrame(0, index=snv, columns=range(self.num))
        else:
            N_ref_mtx, N_alt_mtx = pd.DataFrame(0, index=self.all_POS, columns=range(self.num)), pd.DataFrame(0, index=self.all_POS, columns=range(self.num))

        for n in range(self.num):
            bc_idx = [i for i, e in enumerate(self.barcodes) if e in self.reassigned[n]]
            # REF/ALT alleles counts from cells assigned to state n
            if len(pos) == 0:
                N_ref_mtx.loc[:, n], N_alt_mtx.loc[:, n] = self.ref_bc_mtx[:, bc_idx].sum(axis=1), self.alt_bc_mtx[:, bc_idx].sum(axis=1)
            else:
                N_ref_mtx.loc[:, n], N_alt_mtx.loc[:, n] = self.ref_bc_mtx[pos][:, bc_idx].sum(axis=1), self.alt_bc_mtx[pos][:, bc_idx].sum(axis=1)

        # judge N(A) or N(R) for each cluster
        alt_or_ref = ((N_alt_mtx >= 10) * 1 - ((N_ref_mtx >= 10) & (N_alt_mtx == 0)) * 1).drop(self.doublet, axis=1).astype(np.int8)
        alt_or_ref[alt_or_ref == 0], alt_or_ref[alt_or_ref == -1] = float('NaN'), 0     # formatting data for further analysis
        alt_or_ref = alt_or_ref.ix[[x for x in alt_or_ref.index if x[0] not in ['X','Y','MT']]]
        submatrix = alt_or_ref.copy()
        
        # find unique alleles for each column window and then merge
        while len(self.dist_variants) < self.num - 1:                          
            found, todo = [], []
            while len(found) < ncols - 1:
                selected = []
                # informative alleles for the selected clusters with no NAs
                informative_sub = submatrix[(submatrix.var(axis=1) > 0) & (submatrix.count(axis=1) == ncols)]
                if informative_sub.index.values.size > 0:
                    patt = informative_sub.astype(str).values.sum(axis=1)
                    unq = np.unique(patt, return_inverse=True)                  
                    for j in range(len(unq[0])):
                        # first informative alleles for each unique pattern in the cluster screen which has maximum non-NA values in original information matrix
                        selected.append(alt_or_ref.loc[informative_sub.iloc[[i for i, x in enumerate(unq[1]) if x == j]].index].count(axis=1).idxmax())
                    subt = submatrix.reindex(np.unique(selected)).iloc[:, 0:ncols]
                    subt[subt.isna()]=-1
                    d = np.linalg.svd(subt, full_matrices=False)[1]
                    if sum(d>1e-10) >= ncols:
                        found = [subt[subt.sum(axis=1) == min(subt.sum(axis=1))].index[0]]
                        while len(found) < ncols:            # Gram-Schmidt Process
                            svd = np.linalg.svd(subt.reindex(found).transpose(), full_matrices=False)
                            U, V = svd[0], svd[2]
                            if len(found) == 1: D=np.asmatrix(svd[1])
                            else: D = np.diag(svd[1])
                            colU, Vinv, Dinv = U.shape[1], np.linalg.solve(V, np.diag([1]*len(V))), np.linalg.solve(D, np.diag([1]*len(D)))
                            proj = np.asmatrix(np.zeros((colU,subt.shape[0])))
                            for j in range(subt.shape[0]):
                                for i in range(colU):
                                    proj[i, j] = np.dot(U[:,i], subt.iloc[j])
                            W = np.dot(np.dot(Vinv, Dinv), proj)
                            R = np.dot(subt.reindex(found).transpose(), W)
                            diff = (subt.transpose() - R).transpose()
                            subt1 = subt.reindex(diff[diff.var(axis=1) > (0.5 * max(diff.var(axis=1)))].index)
                            found += [subt1[subt1.sum(axis=1) == min(subt1.sum(axis=1))].index[0]]
                ncols -= 1
                submatrix = submatrix.iloc[:, 0:ncols]
            self.dist_variants += found
            if len(found) == 0:
                with open('scSplit_dist_variants.txt', 'w') as myfile:
                    myfile.write('Not all clusters can be distinguished. \n')
                break
            # find indistinguishable clusters to form new submatrix
            for i in range(1, len(alt_or_ref.columns)):
                for j in range(i):
                    if (alt_or_ref.reindex(self.dist_variants).iloc[:, [i, j]].var(axis=1)).sum() == 0:
                        todo.extend([i, j])
            submatrix = alt_or_ref.iloc[:, sorted(set(todo))]
            ncols = len(submatrix.columns)

        self.dist_variants = list(set(self.dist_variants))
        self.dist_matrix = alt_or_ref.reindex(self.dist_variants)
        self.pa_matrix = alt_or_ref


class scSplit():

    def __init__(self):
        """
        Parse the subcommand from sys.argv[1] and run the method of the same name.
        An unknown subcommand ends the program with the usage message and exit status 2.
        """
        parser = argparse.ArgumentParser(
            description='Genotype-free demultiplexing of pooled single-cell RNA-Seq',
            usage='''scSplit <command> [<args>]

Commands:
   count     Generate REF/ALT count matrices from pooled BAM file
   run       Demultiplex the scRNA-Seq using REF/ALT count matrices
   genotype  Generate genotype information in VCF format
''')
        parser.add_argument('command', help='Subcommand to run')

        # only the first argument is parsed here; each subcommand parses sys.argv[2:] itself
        args = parser.parse_args(sys.argv[1:2])
        # explicit whitelist: a plain hasattr() check would also accept names such as '__init__'
        # and would silently do nothing for a misspelt command
        if args.command not in ('count', 'run', 'genotype'):
            parser.error('unrecognized command: ' + args.command)
        getattr(self, args.command)()


    def count(self):    # Generate REF/ALT count matrices from pooled BAM file
        """
        'count' subcommand: read the SNVs from the VCF, count REF/ALT base calls per SNV and
        barcode in the BAM file and write the two count matrices as CSV files.
        """

        # Process command line arguments
        parser = argparse.ArgumentParser(description='Generate REF/ALT count matrices from pooled BAM file',
            usage='''scSplit count [<args>]

Options:
   -v, --vcf        Genotype of mixed BAM in VCF format
   -i, --bam        Mixed sample BAM file
   -b, --bar        Barcodes whitelist
   -r, --ref        REF count CSV output
   -a, --alt        ALT count CSV output
''')
        parser.add_argument('-v', '--vcf', required=True,  help='Genotype of mixed BAM in VCF format')
        parser.add_argument('-i', '--bam', required=True, help='Mixed sample BAM file')
        parser.add_argument('-b', '--bar', required=True,  help='Barcodes whitelist')
        parser.add_argument('-r', '--ref', required=True,  help='REF count CSV output')
        parser.add_argument('-a', '--alt', required=True,  help='ALT count CSV output')
        args = parser.parse_args(sys.argv[2:])

        epsilon = 0.01

        all_SNVs = SNV_list()  # list of SNV_data objects
        with open(args.vcf, 'r') as vcf_file:
            for record in vcf.Reader(vcf_file):
                # only keep SNVs with heterozygous genotypes, and ignore SNV with multiple bases (e.g. GGGT/GGAT)
                heterozygous = record.samples[0]['GL'][1] > np.log10(1-epsilon)
                if heterozygous and len(record.REF) == 1 and len(record.ALT) == 1:
                    all_SNVs.add_SNV(record.CHROM, record.POS, record.REF, record.ALT[0])

        # list of cell barcodes, one per line; the '--bar' option is stored as args.bar
        with open(args.bar, 'r') as barcode_file:
            barcodes = [line.strip() for line in barcode_file]

        base_calls_mtx = all_SNVs.build_base_calls_matrix(args.bam, barcodes)
        base_calls_mtx[0].to_csv('{}'.format(args.ref))
        base_calls_mtx[1].to_csv('{}'.format(args.alt))


    def run(self):      # Demultiplex the scRNA-Seq using REF/ALT count matrices
        """
        'run' subcommand: fit the model 30 times on the REF/ALT count matrices, keep the
        fit with the highest likelihood, locate and refine the doublet state, pick the
        distinguishing variants and write all result files to the working directory.
        """

        # Process command line arguments
        parser = argparse.ArgumentParser(description='Demultiplex the scRNA-Seq using REF/ALT count matrices',
            usage='''scSplit run [<args>]

Options:
   -r, --ref    REF count CSV input
   -a, --alt    ALT count CSV input
   -n, --num    Number of mixed samples
   -d, --dbl    Doublet proportion [0.05]
   -v, --vcf    VCF file for filtering distinguishing variants
''')
        parser.add_argument('-r', '--ref', required=True,  help='REF count CSV input')
        parser.add_argument('-a', '--alt', required=True,  help='ALT count CSV input')
        parser.add_argument('-n', '--num', required=True,  help='Number of mixed samples')
        parser.add_argument('-d', '--dbl', required=False, help='Doublet proportion [0.05]')
        parser.add_argument('-v', '--vcf', required=False, help='VCF file for filtering distinguishing variants')
        args = parser.parse_args(sys.argv[2:])

        # the '--dbl' option is stored as args.dbl; a missing or non-numeric value
        # falls back to -1, which selects the default doublet detection
        try:
            doublets = float(args.dbl)
        except (TypeError, ValueError):
            doublets = -1

        if doublets == 0:
            num = int(args.num)
        else:
            num = int(args.num) + 1  # additional doublet state

        progress = 'Starting data collection: ' + str(datetime.datetime.now()) + '\n'
        with open('scSplit.log', 'a') as myfile: myfile.write(progress)

        # Read in existing matrix from the csv files
        ref = pd.read_csv(args.ref, header=0, index_col=0)
        alt = pd.read_csv(args.alt, header=0, index_col=0)
        ref_s, alt_s = csr_matrix(ref.values), csr_matrix(alt.values)
        base_calls_mtx = [ref_s, alt_s, ref.index, ref.columns]
        progress = 'Allele counts matrices uploaded: ' + str(datetime.datetime.now()) + '\n'
        with open('scSplit.log', 'a') as myfile: myfile.write(progress)

        # -inf rather than a fixed floor, so that any finite likelihood can be selected
        max_likelihood = float('-inf')
        best = None
        for i in range(30):
            with open('scSplit.log', 'a') as myfile: myfile.write('round ' + str(i) + '\n')

            model = models(base_calls_mtx, int(num), doublets)
            # an all-zero model_af means the model could not be initialized
            if model.model_af.sum().sum() > 0:
                model.run_EM()
                model.assign_cells()
                likelihood = model.lP_c_s.max(axis=1).sum()
                if likelihood > max_likelihood:
                    max_likelihood = likelihood
                    best = (model.initial, model.assigned, model.model_af, model.P_s_c)
        if best is None:
            raise RuntimeError('no model could be initialized and fitted in any of the 30 rounds')
        # the results of the best round are put onto the model object of the last round
        model.initial, model.assigned, model.model_af, model.P_s_c = best

        model.define_doublet()
        model.refine_doublets(doublets)

        # optional restriction of the distinguishing variants to SNVs listed in the VCF
        pos = []
        if args.vcf is not None:
            try:
                with open(args.vcf, 'r') as vcf_file:
                    for record in vcf.Reader(vcf_file):
                        # only keep high R2 variants
                        try:
                            if float(record.INFO['R2'][0]) > 0.9:
                                pos.append(model.all_POS.index(str(record.CHROM)+':'+str(record.POS)))
                        except (KeyError, IndexError, TypeError, ValueError):
                            # record without a usable R2 value, or SNV not in the count matrices
                            continue
            except Exception as err:
                # the filter stays best effort, but the reason is recorded instead of being swallowed
                with open('scSplit.log', 'a') as myfile:
                    myfile.write('VCF filtering stopped early: ' + str(err) + '\n')

        model.distinguishing_alleles(pos)

        # generate outputs
        # NOTE(review): the pickled model can only be loaded safely from a trusted file
        with open('scSplit_model', 'wb') as f:
            pickle.dump(model, f, pickle.HIGHEST_PROTOCOL)
        for n in range(int(num)):
            with open('scSplit_barcodes_{}.csv'.format(n), 'w') as myfile:
                for item in model.reassigned[n]:
                    myfile.write(str(item) + '\n')
        model.P_s_c.to_csv('scSplit_P_s_c.csv')
        with open('scSplit_doublet.txt', 'w') as myfile:
            if doublets == 0:
                myfile.write('No doublet cluster expected. \n')
            else:
                myfile.write('Cluster ' + str(model.doublet) + ' is doublet. \n')
        with open('scSplit_dist_variants.txt', 'w') as myfile:
            for item in model.dist_variants:
                myfile.write(str(item) + '\n')
        model.dist_matrix.to_csv('scSplit_dist_matrix.csv')
        model.pa_matrix.to_csv('scSplit_PA_matrix.csv')


    def genotype(self):     # Generate genotype information in VCF format

        # Process command line arguments
        parser = argparse.ArgumentParser(description='Generate genotype information in VCF format',
            usage='''scSplit genotype [<args>]

Options:
   -r, --ref    REF count CSV input
   -a, --alt    ALT count CSV input
   -p, --psc    generated P(S|C)
''')
        parser.add_argument('-r', '--ref', required=True,  help='REF count CSV input')
        parser.add_argument('-a', '--alt', required=True,  help='ALT count CSV input')
        parser.add_argument('-p', '--psc', required=True, help='generated P(S|C)')
        args = parser.parse_args(sys.argv[2:])

        ref = pd.read_csv(args.ref, header=0, index_col=0)
        alt = pd.read_csv(args.alt, header=0, index_col=0)
        ref_s = csr_matrix(ref.values)
        alt_s = csr_matrix(alt.values)
        all_POS = ref.index

        # get cell assignment
        P_s_c = pd.read_csv(args.psc, header=0, index_col=0)
        A_s_c = ((P_s_c >= 0.9) * 1).astype('float64')
        num = len(P_s_c.columns)

        err = 0.01  # error rate assumption
        # binomial simulation for genotype likelihood P(D|AA,RA,RR) with the alt count vs total count condition and (err, 0.5, 1-err) as allele probability
        lp_d_rr = pd.DataFrame(binom.pmf(pd.DataFrame(alt_s.dot(A_s_c)), pd.DataFrame((alt_s + ref_s).dot(A_s_c)), err), index=all_POS, columns=range(num)).apply(np.log10)
        lp_d_ra = pd.DataFrame(binom.pmf(pd.DataFrame(alt_s.dot(A_s_c)), pd.DataFrame((alt_s + ref_s).dot(A_s_c)), 0.5), index=all_POS, columns=range(num)).apply(np.log10)
        lp_d_aa = pd.DataFrame(binom.pmf(pd.DataFrame(alt_s.dot(A_s_c)), pd.DataFrame((alt_s + ref_s).dot(A_s_c)), 1-err), index=all_POS, columns=range(num)).apply(np.log10)

        vcf_content = pd.DataFrame(index = all_POS, columns = range(-9, num))  # -9~-1: meta data, 0: doublet state, 1:num: samples
        names = vcf_content.columns.tolist()
        names[0] = '#CHROM'
        names[1] = 'POS'
        names[2] = 'ID'
        names[3] = 'REF'
        names[4] = 'ALT'
        names[5] = 'QUAL'
        names[6] = 'FILTER'
        names[7] = 'INFO'
        names[8] = 'FORMAT'
        vcf_content.columns = names
        vcf_content.loc[:,'#CHROM'] = [item.split(':')[0] for item in all_POS]
        vcf_content.loc[:,'POS'] = [item.split(':')[1] for item in all_POS]
        vcf_content.loc[:,'ID'] = all_POS
        vcf_content.loc[:,'REF'] = '.'
        vcf_content.loc[:,'ALT'] = '.'
        vcf_content.loc[:,'QUAL'] = '.'
        vcf_content.loc[:,'FILTER'] = '.'
        vcf_content.loc[:,'INFO'] = '.'
        vcf_content.loc[:,'FORMAT'] = 'GP:GL'

        # round to three decimal points
        GL_RR = 10 ** lp_d_rr.astype(float)
        GL_RA = 10 ** lp_d_ra.astype(float)
        GL_AA = 10 ** lp_d_aa.astype(float)
        GL_nom = GL_RR + GL_RA + GL_AA
        GP_RR = round(GL_RR / GL_nom, 3).astype(str)
        GP_RA = round(GL_RA / GL_nom, 3).astype(str)
        GP_AA = round(GL_AA / GL_nom, 3).astype(str)
        lGL_RR = round(lp_d_rr.astype(float), 3).astype(str)
        lGL_RA = round(lp_d_ra.astype(float), 3).astype(str)
        lGL_AA = round(lp_d_aa.astype(float), 3).astype(str)

        for n in range(num):
            vcf_content.loc[:, n] = GP_RR.loc[:, n] + ',' + GP_RA.loc[:, n] + ',' + GP_AA.loc[:, n] + \
                            ':' + lGL_RR.loc[:, n] + ',' + lGL_RA.loc[:, n] + ',' + lGL_AA.loc[:, n]

        header = '##fileformat=VCFv4.2\n##fileDate=' + str(datetime.datetime.today()).split(' ')[0] + \
                '\n##source=sc_split\n##reference=hg19.fa\n##contig=<ID=1,length=249250621>\n' + \
                '##contig=<ID=10,length=135534747>\n##contig=<ID=11,length=135006516>\n##contig=<ID=12,length=133851895>\n' + \
                '##contig=<ID=13,length=115169878>\n##contig=<ID=14,length=107349540>\n##contig=<ID=15,length=102531392>\n' + \
                '##contig=<ID=16,length=90354753>\n##contig=<ID=17,length=81195210>\n##contig=<ID=18,length=78077248>\n' + \
                '##contig=<ID=19,length=59128983>\n##contig=<ID=2,length=243199373>\n##contig=<ID=20,length=63025520>\n' + \
                '##contig=<ID=21,length=48129895>\n##contig=<ID=22,length=51304566>\n##contig=<ID=3,length=198022430>\n' + \
                '##contig=<ID=4,length=191154276>\n##contig=<ID=5,length=180915260>\n##contig=<ID=6,length=171115067>\n' + \
                '##contig=<ID=7,length=159138663>\n##contig=<ID=8,length=146364022>\n##contig=<ID=9,length=141213431>\n' + \
                '##contig=<ID=MT,length=16569>\n##contig=<ID=X,length=155270560>\n##contig=<ID=Y,length=59373566>\n' + \
                '##FILTER=<ID=PASS,Description="All filters passed">\n##INFO=<ID=AN,Number=1,Type=Integer,Description="Total Allele Count">\n' + \
                '##INFO=<ID=AC,Number=A,Type=Integer,Description="Alternate Allele Count">\n' + \
                '##INFO=<ID=AF,Number=A,Type=Float,Description="Estimated Alternate Allele Frequency">\n' + \
                '##FORMAT=<ID=GL,Number=3,Type=Float,Description="Genotype Likelihood for RR/RA/AA">\n'

        with open('sc_split.vcf', 'w+') as myfile:
            myfile.write(header)
            vcf_content.to_csv(myfile, index=False, sep='\t', quoting=csv.QUOTE_NONE, escapechar='"')

# Script entry point: constructing scSplit parses sys.argv[1] and runs the subcommand method of that name
if __name__ == '__main__':
    scSplit()
