#!/usr/bin/env python3

"""
Genotype-free demultiplexing of pooled scRNA-seq
Jon Xu (jun.xu@uq.edu.au)
Lachlan Coin
Aug 2018
"""

import numpy as np
import pysam as ps
import pandas as pd
import statistics as stat
from scipy.stats import binom
from scipy.sparse import csr_matrix
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import sys, io, vcf, csv, math, datetime, pickle, argparse

class mixed_VCF:

    """
    Genotypes from the mixed VCF
    """

    def build_base_calls_matrix(self, file_s, filtered_vcf, barcodes, tag):
        """
        Build REF/ALT allele count matrices from a pooled alignment file

        Parameters:
            file_s(str): Path to the alignment file, opened in 'rb' mode and queried by region
            filtered_vcf(Dataframe): one row per SNV with columns 'CHROM', 'POS', 'REF' and 'ALT';
                POS is converted to the 0-based coordinate used for fetching by subtracting 1
            barcodes(list): cell barcodes to be counted (become the matrix columns)
            tag(str): name of the read tag which holds the cell barcode

        Returns:
            tuple: (ref_base_calls_mtx, alt_base_calls_mtx), two SNV x barcode DataFrames of read counts
        """

        ref_base_calls_mtx = pd.DataFrame(0, index=filtered_vcf.index, columns=barcodes)
        alt_base_calls_mtx = pd.DataFrame(0, index=filtered_vcf.index, columns=barcodes)
        print('Num Pos:', len(filtered_vcf.index), ', Num barcodes:', len(barcodes))

        # set membership is O(1); the list would be scanned once per read otherwise
        known_barcodes = set(barcodes)

        # the context manager closes the alignment file even if counting fails half way
        with ps.AlignmentFile(file_s, 'rb') as in_sam:
            for position, row in filtered_vcf.iterrows():
                snv_pos = row['POS'] - 1
                for read in in_sam.fetch(row['CHROM'], snv_pos, row['POS'] + 1):
                    if read.flag >= 256:   # only valid reads
                        continue
                    try:
                        barcode = read.get_tag(tag)
                    except KeyError:
                        # read carries no barcode tag, so it cannot be attributed to a cell
                        continue
                    if barcode not in known_barcodes:
                        continue
                    # query offset of the base aligned to the SNV; None if the read
                    # spans the position without an aligned base (deletion, splice, clipping)
                    query_pos = next((q for q, r in read.get_aligned_pairs(True) if r == snv_pos), None)
                    if query_pos is None:
                        continue
                    base = read.query_sequence[query_pos]
                    if base == row['REF']:
                        ref_base_calls_mtx.loc[position, barcode] += 1
                    if base == row['ALT']:
                        alt_base_calls_mtx.loc[position, barcode] += 1

        ref_base_calls_mtx.index.name = alt_base_calls_mtx.index.name = 'SNV'

        return (ref_base_calls_mtx, alt_base_calls_mtx)


class models:
    """
    Main Class including all model parameters
    """

    def __init__(self, base_calls_mtx, num):
        """
        Model class containing SNVs, matrices counts, barcodes, model allele fraction with assigned cells

        Parameters:
             base_calls_mtx(list): [REF count matrix, ALT count matrix, SNV labels, barcodes];
                                   the two matrices are used through .toarray()/.todense(), i.e. as sparse matrices,
                                   and the last two items must provide .tolist()
             num(int): number of total samples

        Attributes set here:
             ref_bc_mtx: SNV-barcode matrix for reference allele counts
             alt_bc_mtx: SNV-barcode matrix for alternative allele counts
             all_POS(list): list of SNVs positions
             barcodes(list): list of cell barcodes
             P_s_c(DataFrame): barcode/sample matrix containing the probability of seeing sample s with observation of barcode c
             lP_c_s(DataFrame): barcode/sample matrix containing the log likelihood of seeing barcode c under sample s, whose sum should increase by each iteration
             lP_s(list): log2 prior of each sample, initialised uniformly to log2(1/num)
             assigned(list): lists of cell/barcode assigned to each cluster/model
             reassigned(list): final assignment adjusted by doublets proportion
             model_af(Dataframe): Dataframe of model allele frequencies P(A) for each SNV and state
             pseudo(int): pseudo count (1) added to the allele totals
             k_alt: background ALT proportion of each SNV over all barcodes
             initial(list): barcodes used to seed each state; only set when the initialisation succeeds
        """
        self.ref_bc_mtx, self.alt_bc_mtx = base_calls_mtx[0], base_calls_mtx[1]
        self.all_POS, self.barcodes = base_calls_mtx[2].tolist(), base_calls_mtx[3].tolist()
        self.num = num
        self.P_s_c = pd.DataFrame(0, index = self.barcodes, columns = range(self.num))
        self.lP_c_s = pd.DataFrame(0, index = self.barcodes, columns = range(self.num))
        self.lP_s = [np.log2(1/self.num)] * self.num
        self.assigned, self.reassigned = [], []
        for _ in range(self.num):
            self.assigned.append([])
            self.reassigned.append([])
        # stays all-zero when the initialisation below cannot be completed
        self.model_af = pd.DataFrame(0, index=self.all_POS, columns=range(self.num))
        self.pseudo = 1

        # set background alt count proportion as allele fraction for each SNVs of doublet state, with pseudo count added for 0 counts on multi-base SNPs
        N_alt = self.alt_bc_mtx.sum(axis=1) + self.pseudo
        N_ref = self.ref_bc_mtx.sum(axis=1) + self.pseudo
        self.k_alt = N_alt / (N_ref + N_alt)

        # find barcodes for state initialization, using subsetting/PCA/K-means
        base_mtx = (self.alt_bc_mtx + self.ref_bc_mtx).toarray()
        rows, cols = [*range(base_mtx.shape[0])], [*range(base_mtx.shape[1])]
        # irows/icols map the rows/columns of the shrinking base_mtx back to the original indices
        irows, icols = np.array(rows), np.array(cols)
        nrows, ncols = len(rows), len(cols)
        mrows = mcols = 0
        # do not shrink further once fewer than this many barcodes are left
        minimum = int(self.num / 0.9 + 0.5) + 1
        # iterate until every remaining column is non-zero in at least 90% of the rows
        # and every remaining row is non-zero in at least 90% of the columns
        while (mrows < 0.9 * nrows or mcols < 0.9 * ncols) and ncols >= minimum:
            # ranks to remove: about 10% draws from Beta(1,10) scaled to the current size,
            # which favours the low ranks, i.e. the rows/columns with the fewest non-zero entries
            rbrows = np.sort(np.unique(list(map(int, np.random.beta(1,10,int(0.1*nrows + 0.5))*nrows))))
            rbcols = np.sort(np.unique(list(map(int, np.random.beta(1,10,int(0.1*ncols + 0.5))*ncols))))
            if len(rbrows) == 0 and len(rbcols) == 0:   # if no more rows or cols be removed
                break
            # order rows/columns by ascending number of non-zero entries, then drop the drawn ranks
            rows = np.count_nonzero(base_mtx, axis=1).argsort().tolist()
            cols = np.count_nonzero(base_mtx, axis=0).argsort().tolist()
            rows = [item for index, item in enumerate(rows) if index not in rbrows]
            cols = [item for index, item in enumerate(cols) if index not in rbcols]
            irows, icols = irows[rows], icols[cols]
            nrows, ncols = len(rows), len(cols)
            base_mtx = base_mtx[rows][:,cols]
            # mrows: smallest non-zero count found in any column; mcols: smallest found in any row
            mrows = min(np.count_nonzero(base_mtx, axis=0))
            mcols = min(np.count_nonzero(base_mtx, axis=1))
        alt_subset = self.alt_bc_mtx[irows][:, icols].todense()
        ref_subset = self.ref_bc_mtx[irows][:, icols].todense()
        # smoothed ALT fraction of the retained SNV/barcode subset; the offsets avoid division by zero
        alt_prop = (alt_subset + 0.01) / (alt_subset + ref_subset + 0.02)
        # transpose so that barcodes are the observations for scaling, PCA and clustering
        alt_pca = StandardScaler().fit_transform(alt_prop.T)
        pca = PCA(n_components=min(nrows, ncols, 20))
        pca_alt = pca.fit_transform(alt_pca)
        if pca_alt.shape[0] < self.num:
            # fewer barcodes than states: only log it; model_af stays all-zero and self.initial is not set
            with open('scSplit.log', 'a') as logfile: logfile.write('Not enough informative cells to support model initialization. \n')
        else:
            kmeans = KMeans(n_clusters=self.num, random_state=0).fit(pca_alt)
            # initialise allele frequency for model states
            self.initial = []
            for n in range(self.num):
                self.initial.append([])
                for index, col in enumerate(icols):
                    if kmeans.labels_[index] == n:
                        self.initial[n].append(self.barcodes[col])
                # allele counts over all SNVs for the barcodes clustered into state n
                barcode_alt = np.array(self.alt_bc_mtx[:, icols[kmeans.labels_==n]].sum(axis=1))
                barcode_ref = np.array(self.ref_bc_mtx[:, icols[kmeans.labels_==n]].sum(axis=1))
                self.model_af.loc[:, n] = (barcode_alt + self.k_alt) / (barcode_alt + barcode_ref + self.pseudo)


    def run_EM(self):
        """
        Alternate the E step and the M step until the model likelihood of two
        consecutive iterations is identical, for at most 1000 iterations
        """

        # two different placeholders, so the very first comparison cannot signal convergence
        self.sum_log_likelihood = [1, 2]
        for iteration in range(1, 1001):
            previous, latest = self.sum_log_likelihood[-2], self.sum_log_likelihood[-1]
            if previous == latest:
                break
            stamp = str(datetime.datetime.now())
            with open('scSplit.log', 'a') as logfile:
                logfile.write('E-M Iteration ' + str(iteration) + '   ' + stamp + '\n')
            self.calculate_cell_likelihood()    # E step, also refreshes self.lP_c_m
            self.calculate_model_af()           # M step
            self.sum_log_likelihood.append(self.lP_c_m)


    def calculate_cell_likelihood(self):
        """
        Calculate cell|sample likelihood P(c|s) and derive sample|cell probability P(s|c)
        P(c|s_v) = P(N(A),N(R)|s) = P(g_A|s)^N(A) * (1-P(g_A|s))^N(R)
        log(P(c|s)) = sum_v{(N(A)_c,v*log(P(g_A|s)) + N(R)_c,v*log(1-P(g_A|s)))}
        P(s_n|c) = P(c|s_n) / [P(c|s_1) + P(c|s_2) + ... + P(c|s_n)]
        e.g. log(P(s1|c) = log{1/[1+P(c|s2)/P(c|s1)]} = -log[1+P(c|s2)/P(c|s1)] = -log[1+2^(logP(c|s2)-logP(c|s1))]
        """

        # calculate likelihood P(c|s) based on allele probability
        for n in range(self.num):
            matcalc = self.alt_bc_mtx.T.multiply(self.model_af.loc[:, n].apply(np.log2)).T \
                    + self.ref_bc_mtx.T.multiply((1 - self.model_af.loc[:, n]).apply(np.log2)).T
            self.lP_c_s.loc[:, n] = matcalc.sum(axis=0).tolist()[0]

        # transform to cell sample probability P(s|c) using Baysian rule
        for i in range(self.num):
            denom = 0
            for j in range(self.num):
                denom += 2 ** (self.lP_c_s.loc[:, j] + self.lP_s[j] - self.lP_c_s.loc[:, i] - self.lP_s[i])
            self.P_s_c.loc[:, i] = 1 / denom

        # calculate model likelihood: logP(x|theta) = log{Sigma_y[P(x,y|theta)]}
        self.lP_c_m = ((self.lP_c_s.subtract(self.lP_c_s.min(axis=1), axis=0).pow(2).sum(axis=1) + 1).apply(np.log2) + self.lP_c_s.min(axis=1)).sum(axis=0)


    def calculate_model_af(self):
        """
        Update the model allele fraction by distributing the alt and total counts of each barcode on a certain snv to the model based on P(s|c)
        """

        self.model_af = pd.DataFrame((self.alt_bc_mtx.dot(self.P_s_c) + self.k_alt) / ((self.alt_bc_mtx + self.ref_bc_mtx).dot(self.P_s_c) + self.pseudo),
                                        index = self.all_POS, columns = range(self.num))
        self.lP_s = np.log2(self.P_s_c.sum(axis=0))


    def assign_cells(self):
        """
        Final assignment of cells according to P(s|c) >= 0.9
        """

        for n in range(self.num):
            self.assigned[n] = sorted(self.P_s_c.loc[self.P_s_c[n] >= 0.99].index.values.tolist())


    def define_doublet(self):
        """
        Locate the doublet state
        """
        cross_state = pd.DataFrame(0, index = range(self.num), columns = range(self.num))
        for i in range(self.num):
            for j in range(self.num):
                index = []
                # transform barcode assignments to indices
                for item in self.assigned[j]:
                    index.append(self.barcodes.index(item))
                matcalc = self.alt_bc_mtx.T.multiply(self.model_af.loc[:, i].apply(np.log2)).T \
                        + self.ref_bc_mtx.T.multiply((1 - self.model_af.loc[:, i]).apply(np.log2)).T
                cross_state.loc[j, i] = matcalc[:, index].sum()
        result = cross_state.sum(axis=0).tolist()
        self.doublet = result.index(max(result))


    def refine_doublets(self, doublets):
        """
        Optimize assignments based on doublet expectations
        """
        N_ref_mtx, N_alt_mtx = pd.DataFrame(0, index=self.all_POS, columns=range(self.num)), pd.DataFrame(0, index=self.all_POS, columns=range(self.num))
        found = []
        self.reassigned = self.assigned.copy()
        if doublets == 0:
            self.doublet = -1
        elif doublets > 0: # if user has set expectation on doublet proportion, otherwise go with default doublet detection
            for n in range(self.num):
                bc_idx = [i for i, e in enumerate(self.barcodes) if e in self.assigned[n]]
                # REF/ALT alleles counts from cells assigned to state n
                N_ref_mtx.loc[:, n], N_alt_mtx.loc[:, n] = self.ref_bc_mtx[:, bc_idx].sum(axis=1), self.alt_bc_mtx[:, bc_idx].sum(axis=1)
            # get total non zero variants per state
            rps = self.alt_bc_mtx.T.dot(1 - self.model_af) * (self.P_s_c >= 0.99)
            rpc = rps.drop(self.doublet, axis=1)
            lack = len(self.barcodes) * doublets - len(self.assigned[self.doublet])
            if lack > 0:
                found = rpc.index[rpc.sum(axis=1).argsort()[int(len(rpc.index) - lack):len(rpc.index)]]
                for n in range(self.num):
                    if n != self.doublet:
                        self.reassigned[n] = [x for x in self.assigned[n] if x not in found]
                    else:
                        self.reassigned[n] += found.tolist()


    def distinguishing_alleles(self, pos=[]):
        """
        Locate the distinguishing alleles
        N_ref_mtx, N_alt_mtx: SNV-state matrix for ref/alt counts in each state
        """
        
        # build SNV-state matrices for ref and alt counts
        self.dist_variants, ncols = [], self.num - 1
        if len(pos) != 0:
            snv = [self.all_POS[i] for i in pos]
            N_ref_mtx, N_alt_mtx = pd.DataFrame(0, index=snv, columns=range(self.num)), pd.DataFrame(0, index=snv, columns=range(self.num))
        else:
            N_ref_mtx, N_alt_mtx = pd.DataFrame(0, index=self.all_POS, columns=range(self.num)), pd.DataFrame(0, index=self.all_POS, columns=range(self.num))

        for n in range(self.num):
            bc_idx = [i for i, e in enumerate(self.barcodes) if e in self.reassigned[n]]
            # REF/ALT alleles counts from cells assigned to state n
            if len(pos) == 0:
                N_ref_mtx.loc[:, n], N_alt_mtx.loc[:, n] = self.ref_bc_mtx[:, bc_idx].sum(axis=1), self.alt_bc_mtx[:, bc_idx].sum(axis=1)
            else:
                N_ref_mtx.loc[:, n], N_alt_mtx.loc[:, n] = self.ref_bc_mtx[pos][:, bc_idx].sum(axis=1), self.alt_bc_mtx[pos][:, bc_idx].sum(axis=1)

        # judge N(A) or N(R) for each cluster
        if self.doublet == 0:
            alt_or_ref = ((N_alt_mtx >= 10) * 1 - ((N_ref_mtx >= 10) & (N_alt_mtx == 0)) * 1).astype(np.int8)
        else:
            alt_or_ref = ((N_alt_mtx >= 10) * 1 - ((N_ref_mtx >= 10) & (N_alt_mtx == 0)) * 1).drop(self.doublet, axis=1).astype(np.int8)
        alt_or_ref[alt_or_ref == 0], alt_or_ref[alt_or_ref == -1] = float('NaN'), 0     # formatting data for further analysis
        alt_or_ref = alt_or_ref.ix[[x for x in alt_or_ref.index if x[0] not in ['X','Y','MT']]]
        submatrix = alt_or_ref.copy()
        
        # find unique alleles for each column window and then merge
        while len(self.dist_variants) < self.num - 1:                          
            found, todo = [], []
            while len(found) < ncols - 1:
                selected = []
                # informative alleles for the selected clusters with no NAs
                informative_sub = submatrix[(submatrix.var(axis=1) > 0) & (submatrix.count(axis=1) == ncols)]
                if informative_sub.index.values.size > 0:
                    patt = informative_sub.astype(str).values.sum(axis=1)
                    unq = np.unique(patt, return_inverse=True)                  
                    for j in range(len(unq[0])):
                        # first informative alleles for each unique pattern in the cluster screen which has maximum non-NA values in original information matrix
                        selected.append(alt_or_ref.loc[informative_sub.iloc[[i for i, x in enumerate(unq[1]) if x == j]].index].count(axis=1).idxmax())
                    subt = submatrix.reindex(np.unique(selected)).iloc[:, 0:ncols]
                    subt[subt.isna()]=-1
                    d = np.linalg.svd(subt, full_matrices=False)[1]
                    if sum(d>1e-10) >= ncols:
                        found = [subt[subt.sum(axis=1) == min(subt.sum(axis=1))].index[0]]
                        while len(found) < ncols:            # Gram-Schmidt Process
                            svd = np.linalg.svd(subt.reindex(found).transpose(), full_matrices=False)
                            U, V = svd[0], svd[2]
                            if len(found) == 1: D=np.asmatrix(svd[1])
                            else: D = np.diag(svd[1])
                            colU, Vinv, Dinv = U.shape[1], np.linalg.solve(V, np.diag([1]*len(V))), np.linalg.solve(D, np.diag([1]*len(D)))
                            proj = np.asmatrix(np.zeros((colU,subt.shape[0])))
                            for j in range(subt.shape[0]):
                                for i in range(colU):
                                    proj[i, j] = np.dot(U[:,i], subt.iloc[j])
                            W = np.dot(np.dot(Vinv, Dinv), proj)
                            R = np.dot(subt.reindex(found).transpose(), W)
                            diff = (subt.transpose() - R).transpose()
                            subt1 = subt.reindex(diff[diff.var(axis=1) > (0.5 * max(diff.var(axis=1)))].index)
                            found += [subt1[subt1.sum(axis=1) == min(subt1.sum(axis=1))].index[0]]
                ncols -= 1
                submatrix = submatrix.iloc[:, 0:ncols]
            self.dist_variants += found
            if len(found) == 0:
                with open('scSplit_dist_variants.txt', 'w') as logfile:
                    logfile.write('Not all clusters can be distinguished. \n')
                break
            # find indistinguishable clusters to form new submatrix
            for i in range(1, len(alt_or_ref.columns)):
                for j in range(i):
                    if (alt_or_ref.reindex(self.dist_variants).iloc[:, [i, j]].var(axis=1)).sum() == 0:
                        todo.extend([i, j])
            submatrix = alt_or_ref.iloc[:, sorted(set(todo))]
            ncols = len(submatrix.columns)

        self.dist_variants = list(set(self.dist_variants))
        self.dist_matrix = alt_or_ref.reindex(self.dist_variants)
        self.pa_matrix = alt_or_ref


class scSplit():
    """
    Class for commands
    """
    def __init__(self):
        """
        Read the subcommand from the first command line argument and run it

        Only 'count', 'run' and 'genotype' are accepted; any other value ends the program with the
        usage message and an error, so a typo no longer passes silently and no other attribute
        of the class can be invoked through the command line.
        """
        parser = argparse.ArgumentParser(
            description='Genotype-free demultiplexing of pooled single-cell RNA-Seq',
            usage='''scSplit <command> [<args>]

Commands:
   count     Generate REF/ALT count matrices from pooled BAM file
   run       Demultiplex the scRNA-Seq using REF/ALT count matrices
   genotype  Generate genotype information in VCF format
''')
        parser.add_argument('command', help='Subcommand to run')

        # only the first argument is parsed here, each subcommand parses the remaining ones itself
        args = parser.parse_args(sys.argv[1:2])
        commands = {'count': self.count, 'run': self.run, 'genotype': self.genotype}
        handler = commands.get(args.command)
        if handler is None:
            parser.error("unknown command '" + str(args.command) + "' (choose from count, run, genotype)")
        handler()


    def count(self):    # Generate REF/ALT count matrices from pooled BAM file
        """
        Subcommand 'count': build the REF/ALT count matrices and write them as CSV files

        Steps: read the VCF of the pooled sample, keep single-base SNVs (optionally only those listed
        in the common SNVs file) whose second GL value is above log10(1 - epsilon), count the REF/ALT
        reads per barcode with mixed_VCF.build_base_calls_matrix and write both matrices.

        Raises:
            ValueError: if no read was counted at all
        """

        # Process command line arguments
        parser = argparse.ArgumentParser(description='Generate REF/ALT count matrices from pooled BAM file',
            usage='''scSplit count [<args>]

Options:
   -v, --vcf        Genotype of mixed BAM in VCF format
   -i, --bam        Mixed sample BAM file
   -b, --bar        Barcodes whitelist
   -t, --tag        (Optional) Tag for barcode, default: CB
   -c, --com        (Optional) Common SNVs
   -r, --ref        REF count CSV output
   -a, --alt        ALT count CSV output
''')
        parser.add_argument('-v', '--vcf', required=True,  help='Genotype of mixed BAM in VCF format')
        parser.add_argument('-i', '--bam', required=True, help='Mixed sample BAM file')
        parser.add_argument('-b', '--bar', required=True,  help='Barcodes whitelist')
        parser.add_argument('-t', '--tag', required=False,  default='CB', help='Tag for barcode (default: "CB")')
        parser.add_argument('-c', '--com', required=False,  help='(Optional) Common SNVs')
        parser.add_argument('-r', '--ref', required=True,  help='REF count CSV output')
        parser.add_argument('-a', '--alt', required=True,  help='ALT count CSV output')
        args = parser.parse_args(sys.argv[2:])

        epsilon = 0.01

        def read_vcf(path):
            # load the VCF body as a DataFrame, skipping the '##' meta lines
            with open(path, 'r') as f:
                lines = [l for l in f if not l.startswith('##')]
            return pd.read_csv(
                io.StringIO(''.join(lines)),
                dtype={'#CHROM': str, 'POS': int, 'ID': str, 'REF': str, 'ALT': str,
                    'QUAL': str, 'FILTER': str, 'INFO': str},
                sep='\t'
            ).rename(columns={'#CHROM': 'CHROM'})

        mixed_vcf = read_vcf(args.vcf)

        # optional list of common SNVs; a file which was requested but cannot be read is an error,
        # it must not be ignored silently because that would change the SNV set
        whitelist = None
        if args.com is not None:
            try:
                whitelist = pd.read_csv(args.com, header=None)
            except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
                parser.error('cannot read common SNVs file ' + str(args.com) + ': ' + str(err))

        # only keep SNVs with heterozygous genotypes, and ignore SNV with multiple bases (e.g. GGGT/GGAT)

        mixed_vcf = mixed_vcf.set_index(mixed_vcf['CHROM'] + ':' + mixed_vcf['POS'].apply(str))
        # position of the GL field, taken from the FORMAT column of the first record
        item = mixed_vcf.iloc[0].loc['FORMAT'].split(':').index('GL')
        filtered_vcf = mixed_vcf[(mixed_vcf['REF'].apply(len)==1) & (mixed_vcf['ALT'].apply(len)==1)]
        if whitelist is not None:
            snv_list = sorted(set(whitelist[0].tolist()).intersection(filtered_vcf.index))
        else:
            snv_list = sorted(filtered_vcf.index)
        filtered_vcf = filtered_vcf.loc[snv_list]
        # second GL value of the last (sample) column must be above log10(1 - epsilon)
        het_threshold = np.log10(1 - epsilon)
        het_gl = [float(x.split(':')[item].split(',')[1]) for x in filtered_vcf.iloc[:, -1]]
        snv_idx = [idx for idx, gl in enumerate(het_gl) if gl > het_threshold]
        filtered_vcf = filtered_vcf.iloc[snv_idx][['CHROM','POS','REF','ALT']]

        # list of cell barcodes; blank lines are skipped so that no empty barcode gets a column
        with open(args.bar, 'r') as handle:
            barcodes = [line.strip() for line in handle if line.strip()]

        base_calls_mtx = mixed_VCF().build_base_calls_matrix(args.bam, filtered_vcf, barcodes, args.tag)

        if (base_calls_mtx[0] + base_calls_mtx[1]).sum().sum() == 0:
            raise ValueError('Empty matrices!')
        base_calls_mtx[0].to_csv(args.ref)
        base_calls_mtx[1].to_csv(args.alt)


    def run(self):      # Demultiplex the scRNA-Seq using REF/ALT count matrices
        """
        Subcommand 'run': demultiplex the pooled cells from the REF/ALT count matrices

        Steps: load both count matrices, repeat the E-M with fresh initialisations and keep the best
        model (with '-n 0' models for 2..sub subpopulations are compared and the elbow point of
        their likelihoods decides), locate and refine the doublet cluster, select the
        distinguishing variants and write the result files scSplit_result.csv, scSplit_P_s_c.csv,
        scSplit_dist_variants.txt, scSplit_dist_matrix.csv and scSplit_PA_matrix.csv.
        Progress is appended to scSplit.log.

        Raises:
            RuntimeError: if no repeat could initialise a model
        """

        def log(message):
            # append one message to the log file
            with open('scSplit.log', 'a') as logfile:
                logfile.write(message)

        # core E-M and cell assignment functions
        def core(num):
            # return the model with the highest likelihood out of args.ems independent repeats;
            # the model object is returned as a whole so all of its attributes belong to the same run
            best_model, max_likelihood = None, float('-inf')
            for i in range(args.ems):
                log('Repeat ' + str(i+1) + ' in demultiplexing ' + str(num) +
                    ' subpopulations (including doublets if expected)' + '\n')
                log('Building model... \n')
                model = models(base_calls_mtx, num)
                # an all-zero model_af means the initialisation was not possible
                if model.model_af.sum().sum() > 0:
                    model.run_EM()
                    model.assign_cells()
                    likelihood = model.lP_c_s.max(axis=1).sum()
                    if likelihood > max_likelihood:
                        max_likelihood, best_model = likelihood, model
            if best_model is None:
                raise RuntimeError('No model could be initialised for ' + str(num) + ' subpopulations')
            return best_model

        # function to find elbow point of a curve
        def elbow(points):
            # index of the point which is farthest from the line through the first and the last point
            points = [-x for x in points]
            p1, p2, m, e = np.array([0, points[0]]), np.array([len(points)-1, points[-1]]), 0, 0
            chord = p2 - p1
            for i in range(1, (len(points)-1)):
                offset = p1 - np.array([i, points[i]])
                # distance to the line: |2D cross product| / length of the chord
                d = abs(chord[0] * offset[1] - chord[1] * offset[0]) / np.linalg.norm(chord)
                if d > m:
                    m, e = d, i     # remember the largest distance seen so far
            return e

        # Process command line arguments
        parser = argparse.ArgumentParser(description='Demultiplex the scRNA-Seq using REF/ALT count matrices',
            usage='''scSplit run [<args>]

Options:
   -r, --ref    REF count CSV input
   -a, --alt    ALT count CSV input
   -n, --num    Number of mixed samples
   -s, --sub    (optional) Maximum number of subpopulations, default: 10
   -e, --ems    (Optional) Number of EM repeats to avoid local maximum, default: 30
   -d, --dbl    (Optional) Doublet proportion
   -v, --vcf    (Optional) VCF file for filtering distinguishing variants
''')
        parser.add_argument('-r', '--ref', required=True,  help='REF count CSV input')
        parser.add_argument('-a', '--alt', required=True,  help='ALT count CSV input')
        parser.add_argument('-n', '--num', type=int, required=True,  help='Number of mixed samples')
        parser.add_argument('-s', '--sub', type=int, required=False, default=10, help='(Optional) Largest possible number of subpopulations, default: 10')
        parser.add_argument('-e', '--ems', type=int, required=False, default=30, help='(Optional) Number of EM repeats to avoid local maximum, default: 30')
        parser.add_argument('-d', '--dbl', type=float, required=False, help='(Optional) Doublet proportion')
        parser.add_argument('-v', '--vcf', required=False, help='(Optional) VCF file for filtering distinguishing variants')
        args = parser.parse_args(sys.argv[2:])

        # -1 stands for "no expectation given": keep the doublet cluster found by the model
        doublets = -1 if args.dbl is None else float(args.dbl)

        if doublets == 0:
            num = args.num
        else:
            num = args.num + 1  # additional doublet state

        log('\n' + ' '.join(sys.argv) + '\n')
        log('Loading allele count matrices: ' + str(datetime.datetime.now()) + '\n')

        # Read in existing matrix from the csv files
        ref = pd.read_csv(args.ref, header=0, index_col=0)
        alt = pd.read_csv(args.alt, header=0, index_col=0)
        ref_s, alt_s = csr_matrix(ref.values), csr_matrix(alt.values)
        base_calls_mtx = [ref_s, alt_s, ref.index, ref.columns]
        log('Allele counts matrices loaded: ' + str(datetime.datetime.now()) + '\n')

        if args.num == 0:
            # number of samples unknown: try 2..sub subpopulations and take the elbow of the likelihoods
            all_models, LL = [], []
            for sub in range(2, args.sub + 1):
                model = core(sub)
                all_models.append(model)
                LL.append(model.lP_c_s.max(axis=1).sum())
            elb = elbow(LL)
            model = all_models[elb]
            num = elb + 2
        else:
            model = core(num)

        model.define_doublet()
        model.refine_doublets(doublets)

        # optional VCF: restrict the distinguishing variants to those with R2 > 0.9
        pos = []
        if args.vcf is not None:
            snv_index = {}
            for idx, snv_name in enumerate(model.all_POS):
                snv_index.setdefault(snv_name, idx)
            try:
                with open(args.vcf, 'r') as handle:
                    for record in vcf.Reader(handle):
                        # only keep high R2 variants
                        try:
                            if float(record.INFO['R2'][0]) > 0.9:
                                pos.append(snv_index[str(record.CHROM)+':'+str(record.POS)])
                        except (KeyError, IndexError, TypeError, ValueError):
                            # no usable R2 value, or the variant is not part of the count matrices
                            continue
            except Exception as err:
                # best effort: the demultiplexing result is complete at this point, so an unreadable
                # VCF is reported and the variants read so far are used instead of losing the run
                message = 'Warning: VCF ' + str(args.vcf) + ' could not be read completely: ' + str(err) + '\n'
                log(message)
                sys.stderr.write(message)

        model.distinguishing_alleles(pos)

        # generate outputs
        frames = []
        for n in range(num):
            if n != model.doublet:
                assignment = pd.DataFrame({'Barcode': model.reassigned[n]})
                assignment['Cluster'] = 'SNG-' + str(n)
                frames.append(assignment)
        if doublets != 0:   # if doublet cluster is expected
            assignment = pd.DataFrame({'Barcode': model.reassigned[model.doublet]})
            assignment['Cluster'] = 'DBL-' + str(model.doublet)
            frames.append(assignment)
        assignments = pd.concat(frames) if frames else pd.DataFrame()
        assignments.to_csv('scSplit_result.csv', sep='\t', index=False)
        model.P_s_c.to_csv('scSplit_P_s_c.csv')
        with open('scSplit_dist_variants.txt', 'w') as logfile:
            for item in model.dist_variants:
                logfile.write(str(item) + '\n')
        model.dist_matrix.to_csv('scSplit_dist_matrix.csv', float_format='%.0f')
        model.pa_matrix.to_csv('scSplit_PA_matrix.csv', float_format='%.0f')


    def genotype(self):     # Generate genotype information in VCF format
        """
        Subcommand 'genotype': write genotype likelihoods of every cluster to sc_split.vcf

        The barcodes with P(s|c) >= 0.9 are pooled per cluster; for every SNV and cluster the
        binomial likelihood of the ALT count given the total count is evaluated for the genotypes
        RR, RA and AA (ALT allele probability err, 0.5 and 1 - err). GL holds the log10
        likelihoods, GP the likelihoods normalised to a sum of 1. The likelihoods are computed in
        log space so that deeply covered SNVs do not underflow to -inf/NaN.
        """

        # Process command line arguments
        parser = argparse.ArgumentParser(description='Generate genotype information in VCF format',
            usage='''scSplit genotype [<args>]

Options:
   -r, --ref    REF count CSV input
   -a, --alt    ALT count CSV input
   -p, --psc    generated P(S|C)
''')
        parser.add_argument('-r', '--ref', required=True,  help='REF count CSV input')
        parser.add_argument('-a', '--alt', required=True,  help='ALT count CSV input')
        parser.add_argument('-p', '--psc', required=True, help='generated P(S|C)')
        args = parser.parse_args(sys.argv[2:])

        ref = pd.read_csv(args.ref, header=0, index_col=0)
        alt = pd.read_csv(args.alt, header=0, index_col=0)
        ref_s = csr_matrix(ref.values)
        alt_s = csr_matrix(alt.values)
        all_POS = ref.index

        # get cell assignment
        P_s_c = pd.read_csv(args.psc, header=0, index_col=0)
        A_s_c = ((P_s_c >= 0.9) * 1).astype('float64')
        num = len(P_s_c.columns)

        err = 0.01  # error rate assumption

        # ALT and total read counts per SNV and cluster, computed once for all three genotypes
        membership = A_s_c.to_numpy()
        alt_counts = np.asarray(alt_s.dot(membership))
        total_counts = np.asarray((alt_s + ref_s).dot(membership))

        def log10_likelihood(p_alt):
            # log10 binomial likelihood P(D|genotype) for the given ALT allele probability
            return pd.DataFrame(binom.logpmf(alt_counts, total_counts, p_alt) / np.log(10),
                                index=all_POS, columns=range(num))

        # genotype likelihood P(D|RR,RA,AA) with (err, 0.5, 1-err) as allele probability
        lp_d_rr = log10_likelihood(err)
        lp_d_ra = log10_likelihood(0.5)
        lp_d_aa = log10_likelihood(1 - err)

        # nine meta data columns first, followed by one column per cluster named by its number
        vcf_content = pd.DataFrame({
            '#CHROM': [item.split(':')[0] for item in all_POS],
            'POS': [item.split(':')[1] for item in all_POS],
            'ID': list(all_POS),
            'REF': '.',
            'ALT': '.',
            'QUAL': '.',
            'FILTER': '.',
            'INFO': '.',
            'FORMAT': 'GP:GL',
        }, index=all_POS)

        # normalise relative to the most likely genotype, which keeps at least one term at 1
        lp_max = np.maximum(np.maximum(lp_d_rr, lp_d_ra), lp_d_aa)
        GL_RR = 10 ** (lp_d_rr - lp_max)
        GL_RA = 10 ** (lp_d_ra - lp_max)
        GL_AA = 10 ** (lp_d_aa - lp_max)
        GL_nom = GL_RR + GL_RA + GL_AA

        # round to three decimal points
        GP_RR = (GL_RR / GL_nom).round(3).astype(str)
        GP_RA = (GL_RA / GL_nom).round(3).astype(str)
        GP_AA = (GL_AA / GL_nom).round(3).astype(str)
        lGL_RR = lp_d_rr.round(3).astype(str)
        lGL_RA = lp_d_ra.round(3).astype(str)
        lGL_AA = lp_d_aa.round(3).astype(str)

        for n in range(num):
            vcf_content[n] = GP_RR.loc[:, n] + ',' + GP_RA.loc[:, n] + ',' + GP_AA.loc[:, n] + \
                            ':' + lGL_RR.loc[:, n] + ',' + lGL_RA.loc[:, n] + ',' + lGL_AA.loc[:, n]

        # NOTE(review): reference and contig lengths are hard coded for hg19
        header = '##fileformat=VCFv4.2\n##fileDate=' + str(datetime.datetime.today()).split(' ')[0] + \
                '\n##source=sc_split\n##reference=hg19.fa\n##contig=<ID=1,length=249250621>\n' + \
                '##contig=<ID=10,length=135534747>\n##contig=<ID=11,length=135006516>\n##contig=<ID=12,length=133851895>\n' + \
                '##contig=<ID=13,length=115169878>\n##contig=<ID=14,length=107349540>\n##contig=<ID=15,length=102531392>\n' + \
                '##contig=<ID=16,length=90354753>\n##contig=<ID=17,length=81195210>\n##contig=<ID=18,length=78077248>\n' + \
                '##contig=<ID=19,length=59128983>\n##contig=<ID=2,length=243199373>\n##contig=<ID=20,length=63025520>\n' + \
                '##contig=<ID=21,length=48129895>\n##contig=<ID=22,length=51304566>\n##contig=<ID=3,length=198022430>\n' + \
                '##contig=<ID=4,length=191154276>\n##contig=<ID=5,length=180915260>\n##contig=<ID=6,length=171115067>\n' + \
                '##contig=<ID=7,length=159138663>\n##contig=<ID=8,length=146364022>\n##contig=<ID=9,length=141213431>\n' + \
                '##contig=<ID=MT,length=16569>\n##contig=<ID=X,length=155270560>\n##contig=<ID=Y,length=59373566>\n' + \
                '##FILTER=<ID=PASS,Description="All filters passed">\n##INFO=<ID=AN,Number=1,Type=Integer,Description="Total Allele Count">\n' + \
                '##INFO=<ID=AC,Number=A,Type=Integer,Description="Alternate Allele Count">\n' + \
                '##INFO=<ID=AF,Number=A,Type=Float,Description="Estimated Alternate Allele Frequency">\n' + \
                '##FORMAT=<ID=GP,Number=3,Type=Float,Description="Genotype Probability for RR/RA/AA">\n' + \
                '##FORMAT=<ID=GL,Number=3,Type=Float,Description="Genotype Likelihood for RR/RA/AA">\n'

        with open('sc_split.vcf', 'w') as vcf_file:
            vcf_file.write(header)
            vcf_content.to_csv(vcf_file, index=False, sep='\t', quoting=csv.QUOTE_NONE, escapechar='"')

# Script entry point: constructing scSplit parses sys.argv and runs the requested subcommand
if __name__ == '__main__':
    scSplit()
