#!python

"""Aligns and counts mutations in barcoded subamplicons."""


import os
import glob
import sys
import re
import logging
import gzip
import random
import pandas
import Bio.SeqIO
import dms_tools2.parseargs
import dms_tools2.utils



def bcInfo(bc, bcreads, retained, consensus, desc, to_csv):
    """Returns string for writing to `bcinfofile`.

    Creates a string summarizing the barcode.

    Args:
        `bc` (str):
            The barcode.
        `bcreads` (dict)
            dict of `{'R1':r1list, 'R2':r2list}` where each list
            holds the read strings for that barcode.
        `retained` (bool)
            Is the barcode retained?
        `consensus` (str or `None`)
            The consensus sequence for the barcode if created.
        `desc` (str)
            String describing the barcode and its fate.
        `to_csv` (bool)
            If `to_csv` is `True`, outputs the barcode info as a
            comma-delimited string. This string only retains the *number* of
            reads for R1 and R2, not the actual reads.

    Returns:
        A newline-terminated string summarizing the barcode. If `to_csv`
        is `True` this is a single comma-delimited line with the fields
        barcode, retained, description, consensus, number of R1 reads,
        number of R2 reads. Otherwise it is a multi-line block that
        lists every R1 and R2 read.
    """
    if to_csv:
        fields = [
                bc,
                str(retained),
                desc,
                str(consensus),
                str(len(bcreads['R1'])),
                str(len(bcreads['R2'])),
                ]
        return ','.join(fields) + '\n'

    return '\n'.join([
            'BARCODE: {0}'.format(bc),
            'RETAINED: {0}'.format(retained),
            'DESCRIPTION: {0}'.format(desc),
            'CONSENSUS: {0}'.format(consensus),
            'R1 READS:\n\t{0}'.format('\n\t'.join(bcreads['R1'])),
            'R2 READS:\n\t{0}'.format('\n\t'.join(bcreads['R2'])),
            '', # trailing empty entry gives the terminating newline
            ])


def main():
    """Main body of script.

    Parses the command-line arguments, groups the paired FASTQ reads by
    barcode, builds a consensus for the R1 and R2 reads of each barcode,
    tries to align that consensus to the reference sequence using each
    of the `alignspecs` in turn, and tallies the resulting codon counts.

    The output files (all named `<outdir>/<name><suffix>`) are a log,
    the counts CSV, and CSV files with read statistics, reads-per-barcode
    statistics and barcode statistics. If `--bcinfo` is set, a gzipped
    file with per-barcode information is also written.

    Any exception raised during processing is logged, all output files
    except the log are deleted, and the function then returns normally
    (the exception is not re-raised).
    """

    parser = dms_tools2.parseargs.bcsubampParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # fixed seed so the --purgeread / --purgebc subsampling is reproducible
    random.seed(1)

    # set up names of output files
    if args['outdir']:
        os.makedirs(args['outdir'], exist_ok=True)
    else:
        args['outdir'] = ''
    filesuffixes = {
            'log':'.log',
            'counts':'_{0}counts.csv'.format(args['chartype']),
            'readstats':'_readstats.csv',
            'readsperbc':'_readsperbc.csv',
            'bcstats':'_bcstats.csv',
            }
    if args['bcinfo']:
        if args['bcinfo_csv']:
            filesuffixes['bcinfo'] = '_bcinfo.csv.gz'
        else:
            filesuffixes['bcinfo'] = '_bcinfo.txt.gz'
    files = {f: os.path.join(args['outdir'], '{0}{1}'.format(args['name'], s))
             for (f, s) in filesuffixes.items()}

    # do we need to proceed?
    if args['use_existing'] == 'yes' and all(map(
                os.path.isfile, files.values())):
        print("Output files already exist and '--use_existing' is 'yes', "
              "so exiting with no further action.")
        sys.exit(0)

    logger = dms_tools2.utils.initLogger(files['log'], prog, args)

    # handle to the optional barcode-info file; defined before the `try`
    # so the error handler can close it without relying on a NameError
    bcinfofile = None

    # log in try / except / finally loop
    try:

        assert dms_tools2.parseargs.checkName(args['name'], 'name')

        for (ftype, f) in files.items():
            if os.path.isfile(f) and ftype != 'log':
                logger.info("Removing existing file {0}".format(f))
                os.remove(f)

        assert not (args['purgeread'] and args['purgebc']), ("It does "
                "not make sense to use both --purgeread and --purgebc "
                "as they subsample the data in different ways.")

        # read refseq
        refseq = [s for s in Bio.SeqIO.parse(args['refseq'], 'fasta')]
        assert len(refseq) == 1, "refseq does not specify one sequence"
        refseq = str(refseq[0].seq).upper()
        if args['chartype'] == 'codon':
            assert re.search('^[{0}]+$'.format(''.join(dms_tools2.NTS)),
                    refseq), "refseq does not contain only DNA nts"
            assert len(refseq) % 3 == 0, "refseq length not multiple of 3"
            # report codons, not nucleotides
            logger.info('Read refseq of {0} codons from {1}'.format(
                    len(refseq) // 3, args['refseq']))
        else:
            raise ValueError("Invalid chartype")

        bclen1 = args['bclen']
        assert bclen1 >= 0, 'bclen not >= 0'
        if args['bclen2'] is None:
            bclen2 = bclen1
        else:
            bclen2 = args['bclen2']
            assert bclen2 >= 0, 'bclen2 not >= 0'

        assert 1 >= args['minconcur'] > 0.5
        assert args['minreads'] > 0
        assert 1 >= args['minfraccall'] > 0
        assert args['maxmuts'] >= 0

        # check validity of alignspecs
        alignspecs = []
        for s in args['alignspecs']:
            (refseqstart, refseqend, r1start, r2start) = map(int, s.split(','))
            assert r1start >= 1, "R1START must be >= 1"
            # R1START / R2START are 1-based positions in the full read, so
            # they must lie strictly past the barcode: otherwise the
            # barcode-adjusted start below is 0 and the `[start - 1 : ]`
            # slice used for alignment would become `[-1 : ]`
            assert r1start > bclen1, ("alignspecs has R1START {0}, which "
                    "doen't trim barcode of length {1}".format(r1start, bclen1))
            assert r2start >= 1, "R2START must be >= 1"
            assert r2start > bclen2, ("alignspecs has R2START {0}, which "
                    "doen't trim barcode of length {1}".format(r2start, bclen2))
            assert refseqend > refseqstart, "REFSEQEND <= REFSEQSTART"
            assert refseqstart >= 1, "REFSEQSTART < 1"
            assert refseqend <= len(refseq), "REFSEQEND > len(refseq)"
            maxN = (refseqend - refseqstart + 1) * (1 - args['minfraccall'])
            # subtract bclen because barcode trimmed before aligning
            alignspecs.append((refseqstart, refseqend,
                    r1start - bclen1, r2start - bclen2, maxN))

        # set up R1 and R2 trims based on alignspecs
        trims_d = {}
        maxtrim = {}
        for (r, bclenr) in [('R1', bclen1), ('R2', bclen2)]:
            trimarg = args['{0}trim'.format(r)]
            if trimarg is None:
                trimarg = [None]
            elif isinstance(trimarg, int):
                trimarg = [trimarg]
            if len(trimarg) == 1:
                # a single value applies to every alignspec
                trims_d[r] = [trimarg[0]] * len(alignspecs)
            else:
                trims_d[r] = trimarg
            assert len(trims_d[r]) == len(alignspecs), ("--{0}trim must be "
                    "one value or same length as --alignspecs".format(r))
            assert all([(t is None) or (t > bclenr) for t in trims_d[r]]), \
                    ("--{0}trim must all be greater than bclen of {1}"
                    .format(r, bclenr))
            if all([t is None for t in trims_d[r]]):
                maxtrim[r] = None
            else:
                maxtrim[r] = max(trims_d[r])
                # we remove barcodes from reads below, so adjust trims
                trims_d[r] = [t - bclenr for t in trims_d[r]]
        # one (R1 trim, R2 trim) pair per alignspec
        trims = list(zip(trims_d['R1'], trims_d['R2']))

        # check on read files
        if not args['fastqdir']:
            args['fastqdir'] = ''
        else:
            assert os.path.isdir(args['fastqdir']), "Invalid --fastqdir"
        r1files = []
        for f in args['R1']:
            r1files += sorted(glob.glob(os.path.join(args['fastqdir'], f)))
        assert r1files and all(map(os.path.isfile, r1files)), "Missing R1 files"
        if not args['R2']:
            # guess each R2 file name from its R1 file name
            r2files = []
            for r1 in r1files:
                assert r1.count('_R1') == 1, ("Can't guess R2 file for R1 "
                        "file {0}".format(r1))
                r2files.append(r1.replace('_R1', '_R2'))
        else:
            r2files = []
            for f in args['R2']:
                r2files += sorted(glob.glob(os.path.join(args['fastqdir'], f)))
            assert len(r1files) == len(r2files), "R1 and R2 not same length"
        assert all(map(os.path.isfile, r2files)), "Missing R2 files"
        logger.info("Reads are in these FASTQ pairs:\n\t{0}\n".format(
                '\n\t'.join(['{0} and {1}'.format(r1, r2) for (r1, r2) in
                zip(r1files, r2files)])))

        # collect reads by barcode while iterating over reads
        logger.info("Now parsing read pairs...")
        nreads = {
                'total':0,
                'fail filter':0,
                'low Q barcode':0,
            }
        if args['purgeread']:
            nreads['purged'] = 0
            logger.info("Purging read pairs with probability {0:.3f} to "
                    "subsample the data.".format(args['purgeread']))
        minqchar = chr(args['minq'] + 33) # character for Q score cutoff

        barcodes = {} # barcodes key {'R1':r1list, 'R2':r2list}

        for read_tup in dms_tools2.utils.iteratePairedFASTQ(r1files, r2files,
                    maxtrim['R1'], maxtrim['R2']):

            nreads['total'] += 1
            if nreads['total'] % 5e5 == 0:
                logger.info("Reads parsed so far: {0}".format(nreads['total']))

            if args['purgeread']:
                if random.random() < args['purgeread']:
                    nreads['purged'] += 1
                    continue

            (name, r1, r2, q1, q2, failfilter) = read_tup

            if failfilter:
                nreads['fail filter'] += 1
                continue

            r1 = dms_tools2.utils.lowQtoN(r1, q1, minqchar)
            r2 = dms_tools2.utils.lowQtoN(r2, q2, minqchar)

            # barcode is the start of R1 followed by the start of R2
            barcode = r1[ : bclen1] + r2[ : bclen2]
            if 'N' in barcode:
                nreads['low Q barcode'] += 1
                continue

            # store the reads with their barcode portion removed
            if barcode in barcodes:
                barcodes[barcode]['R1'].append(r1[bclen1 : ])
                barcodes[barcode]['R2'].append(r2[bclen2 : ])
            else:
                barcodes[barcode] = {'R1':[r1[bclen1 : ]],
                                     'R2':[r2[bclen2 : ]]}

        logger.info('Parsed {0} reads, found {1} unique barcodes.'.format(
                nreads['total'], len(barcodes)))
        readstats = pandas.DataFrame(nreads, index=[0])
        logger.info("Summary stats on reads:\n{0}".format(
                readstats.to_string(index=False)))
        logger.info("Writing these stats to {0}\n".format(files['readstats']))
        readstats.to_csv(files['readstats'], index=False)

        # collect stats on reads per barcode
        readsperbc = {}
        for bcreads in barcodes.values():
            nforbc = len(bcreads['R1'])
            assert nforbc == len(bcreads['R2'])
            readsperbc[nforbc] = readsperbc.get(nforbc, 0) + 1
        readsperbcstats = pandas.DataFrame(sorted(readsperbc.items()),
                columns=['number of reads', 'number of barcodes']
                ).set_index('number of reads')
        logger.info("Number of reads per barcode:\n{0}".format(
                readsperbcstats.to_string()))
        logger.info("Writing these stats to {0}\n".format(files['readsperbc']))
        readsperbcstats.to_csv(files['readsperbc'])

        # now loop over barcodes and build / align subamplicons
        nbcs = {
                'total':len(barcodes),
                'too few reads':0,
                'not alignable':0,
                'aligned':0,
               }

        if args['purgebc']:
            logger.info('Purging barcodes with probability {0:.3f} '
                    'to subsample the data.'.format(args['purgebc']))
            nbcs['purged'] = 0
            # iterate over a copy of the keys as entries are deleted
            for barcode in list(barcodes):
                if random.random() < args['purgebc']:
                    nbcs['purged'] += 1
                    del barcodes[barcode]
            if nbcs['total']:
                percentpurged = nbcs['purged'] / float(nbcs['total']) * 100
            else:
                percentpurged = 0.0 # no barcodes, avoid dividing by zero
            logger.info('Purged {0} of {1} barcodes ({2:.1f}%).\n'.format(
                    nbcs['purged'], nbcs['total'], percentpurged))

        logger.info('Examining the {0} barcodes to build and align '
                'subamplicons...'.format(len(barcodes)))

        counts = {} # dictionary to hold codon counts
        if args['chartype'] == 'codon':
            nsites = len(refseq) // 3
            counts['site'] = list(range(1, nsites + 1))
            counts['wildtype'] = [refseq[3 * i : 3 * i + 3]
                    for i in range(nsites)]
            for codon in dms_tools2.CODONS:
                counts[codon] = [0] * nsites
        else:
            raise ValueError("Invalid chartype")

        if args['bcinfo']:
            bcinfofile = gzip.open(files['bcinfo'], 'wt')
            # the header only makes sense if the file is actually written
            if args['bcinfo_csv']:
                bcinfofile.write("Barcode,Retained,Description,Consensus,"
                                 "R1_Count,R2_Count\n")

        for (ibc, (bc, bcreads)) in enumerate(barcodes.items()):

            if (ibc + 1) % 2e5 == 0:
                logger.info("Barcodes examined so far: {0}".format(
                        ibc + 1))

            if len(bcreads['R1']) < args['minreads']:
                nbcs['too few reads'] += 1
                if args['bcinfo']:
                    bcinfofile.write(bcInfo(bc, bcreads, retained=False,
                            consensus=None, desc='too few reads',
                            to_csv=args['bcinfo_csv']))
                continue

            consensus = {}
            for r in ['R1', 'R2']:
                consensus[r] = dms_tools2.utils.buildReadConsensus(
                        bcreads[r],
                        args['minreads'], args['minconcur'])

            # try each alignspec in turn, stop at the first that aligns
            for ((r1trim, r2trim), (refseqstart, refseqend,
                    r1start, r2start, maxN)) \
                    in zip(trims, alignspecs):

                # a trim of `None` leaves the consensus untrimmed, since
                # slicing with `[ : None]` keeps the whole string
                r1trimconsensus = consensus['R1'][ : r1trim]
                r2trimconsensus = consensus['R2'][ : r2trim]

                subamplicon = dms_tools2.utils.alignSubamplicon(refseq,
                        r1trimconsensus[r1start - 1 : ],
                        r2trimconsensus[r2start - 1 : ],
                        refseqstart, refseqend, args['maxmuts'],
                        maxN, args['chartype'])

                if subamplicon:
                    if args['bcinfo']:
                        bcinfofile.write(bcInfo(bc, bcreads, retained=True,
                            consensus=subamplicon, desc='aligned at '
                            'position {0}'.format(refseqstart),
                            to_csv=args['bcinfo_csv']))
                    nbcs['aligned'] += 1
                    dms_tools2.utils.incrementCounts(refseqstart,
                            subamplicon, args['chartype'], counts)
                    break

            else: # read did not align
                nbcs['not alignable'] += 1
                if args['bcinfo']:
                    bcinfofile.write(bcInfo(bc, bcreads, retained=False,
                            consensus=None, desc='could not align',
                            to_csv=args['bcinfo_csv']))

        if args['bcinfo']:
            bcinfofile.close()

        bcstats = pandas.DataFrame(nbcs, index=[0])
        logger.info("Examined all barcodes. Summary stats:\n{0}".format(
                bcstats.to_string(index=False)))
        logger.info("Writing these stats to {0}\n".format(files['bcstats']))
        bcstats.to_csv(files['bcstats'], index=False)

        counts = pandas.DataFrame.from_dict(counts).set_index('site')[
                ['wildtype'] + dms_tools2.CODONS]
        if args['sitemask']:
            logger.info('Filtering to only sites listed in sitemask {0}'
                    .format(args['sitemask']))
            assert os.path.isfile(args['sitemask']), \
                    'no file {0}'.format(args['sitemask'])
            sitemask = pandas.read_csv(args['sitemask'])
            assert 'site' in sitemask.columns, 'no `site` column in sitemask'
            # `sitestokeep` is referenced by name in the query string below
            sitestokeep = sitemask['site'].unique()
            norig = len(counts)
            counts = counts.query('site in @sitestokeep')
            logger.info('Filtered from {0} to {1} sites.'
                    .format(norig, len(counts)))
        logger.info("Writing the counts of each {0} identity at each "
                "site to {1}\n".format(args['chartype'], files['counts']))
        counts.to_csv(files['counts'])

    # NOTE(review): this deliberately catches everything (as the original
    # bare `except:` did) and does not re-raise, so the script still exits
    # with status 0 after an error -- confirm whether callers rely on that
    except BaseException:
        logger.exception('Terminating {0} with ERROR'.format(prog))
        if bcinfofile is not None:
            try:
                bcinfofile.close()
            except Exception:
                # best effort only, the file is deleted just below
                pass
        for (fname, fpath) in files.items():
            if fname != 'log' and os.path.isfile(fpath):
                # plain error message, the traceback was logged above
                logger.error("Deleting file {0}".format(fpath))
                os.remove(fpath)

    else:
        logger.info('Successful completion of {0}'.format(prog))

    finally:
        logging.shutdown()


# Entry point: call `main` only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main() # run the script
