#!python

"""Runs a batch of ``dms2_bcsubamp`` programs, summarizes results."""


import os
import glob
import sys
import re
import logging
import functools
import subprocess
import multiprocessing
import multiprocessing.dummy
import pandas
import dms_tools2.parseargs
import dms_tools2.utils
import dms_tools2.plot


def _bcsubamp_command(row, args):
    """Build the ``dms2_bcsubamp`` command line for one sample.

    Args:
        row: One row of the batch file. Must support ``in`` and ``[]``
            lookup by column name and have 'name' and 'R1' entries.
        args (dict): Parsed command-line arguments of this batch program,
            keyed by argument name.

    Returns:
        list of str: ``'dms2_bcsubamp'`` followed by its arguments.

    Raises:
        AssertionError: An argument is given both as a batch-file column
            and (with a value other than `None`) on the command line.
    """
    # arguments that only make sense for the batch program itself
    batch_only = ('batchfile', 'ncpus', 'summaryprefix')
    # arguments passed as a bare `--flag` without a value
    flag_args = ('bcinfo', 'bcinfo_csv')
    # batch-file columns holding whitespace-separated lists of values
    list_columns = ('R1trim', 'R2trim')

    cmd = ['dms2_bcsubamp', '--name', row['name'], '--R1', row['R1']]
    for (arg, val) in args.items():
        if arg in batch_only:
            continue
        option = '--{0}'.format(arg)
        if arg in row:
            # per-sample value from the batch file takes the place of
            # the command-line value, which therefore must not be set
            assert val is None, ("`{0}` specified in `batchfile` and"
                    " as command-line argument ({1}).".format(arg, val))
            val = row[arg]
            if arg in list_columns:
                val = val.split()
            cmd.append(option)
            if isinstance(val, list):
                cmd.extend(map(str, val))
            else:
                cmd.append(str(val))
        elif arg in flag_args:
            # `bool` is a subclass of `int`, so these must be handled
            # before the numeric test below; otherwise a `False` flag
            # would still be forwarded to `dms2_bcsubamp`
            if val:
                cmd.append(option)
        elif val or isinstance(val, (int, float)):
            # the `isinstance` test keeps numeric values of zero
            cmd.append(option)
            if isinstance(val, list):
                cmd.extend(map(str, val))
            else:
                cmd.append(str(val))
    return cmd


def main():
    """Run ``dms2_bcsubamp`` on each sample in the batch file, then summarize.

    Parses the command-line arguments, reads the batch file, runs one
    ``dms2_bcsubamp`` subprocess per row using a thread pool, checks that
    each run created its output files, and then writes the summary plots.

    If ``--use_existing`` is 'yes' and all summary and per-sample output
    files already exist, returns without doing anything further.

    Raises:
        Any exception raised while processing is logged, the summary
        files other than the log are deleted, and the exception is then
        re-raised so the process exits with a non-zero status.
    """

    parser = dms_tools2.parseargs.batch_bcsubampParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # set up names of output files
    assert dms_tools2.parseargs.checkName(args['summaryprefix'],
            'summaryprefix')
    if args['outdir']:
        # `makedirs` also creates missing parent directories
        os.makedirs(args['outdir'], exist_ok=True)
    else:
        args['outdir'] = '.'
    filesuffixes = {
            'log':'.log',
            'readstats':'_readstats.pdf',
            'bcstats':'_bcstats.pdf',
            'readsperbc':'_readsperbc.pdf',
            'depth':'_depth.pdf',
            'mutfreq':'_mutfreq.pdf',
            'codonmuttypes':'_codonmuttypes.pdf',
            'codonmuttypes_csv':'_codonmuttypes.csv',
            'codonntchanges':'_codonntchanges.pdf',
            'singlentchanges':'_singlentchanges.pdf',
            'cumulmutcounts':'_cumulmutcounts.pdf',
            }
    files = {f: os.path.join(args['outdir'], '{0}{1}'.format(
             args['summaryprefix'], s)) for (f, s) in filesuffixes.items()}

    logger = dms_tools2.utils.initLogger(files['log'], prog, args)

    # log in try / except / finally block
    intentional_exit = False  # are we intentionally exiting program?
    try:

        # read batchfile, strip any whitespace from strings
        logger.info("Parsing sample info from {0}".format(args['batchfile']))
        assert os.path.isfile(args['batchfile']), "no batchfile"
        batchruns = pandas.read_csv(args['batchfile'], na_filter=False)
        batchruns.columns = batchruns.columns.str.strip()
        colnames = set(['name', 'R1'])
        assert set(batchruns.columns) >= colnames, ("batchfile lacks "
                "required column names of: {0}".format(', '.join(colnames)))
        for c in batchruns.columns:
            batchruns[c] = batchruns[c].map(str).map(str.strip)
        logger.info("Read the following sample information:\n{0}\n".format(
                batchruns.to_csv(index=False)))
        assert all([dms_tools2.parseargs.checkName(name, 'name') for name
                in batchruns['name']])
        assert len(batchruns['name']) == len(set(batchruns['name'].values)),\
                "Duplicated name"

        # files created for each run; a column holding the expected path
        # of each is added to `batchruns`
        runfiles = {'counts': '_{0}counts.csv'.format(args['chartype']),
                    'readstats': '_readstats.csv',
                    'readsperbc': '_readsperbc.csv',
                    'bcstats': '_bcstats.csv',
                    }
        for filename, filesuffix in runfiles.items():
            batchruns[filename] = (args['outdir'] + '/' + batchruns['name'] +
                                   filesuffix)
        found_all_runfiles = all(all(map(os.path.isfile, batchruns[f].values))
                                 for f in runfiles)

        # do we need to proceed
        if (args['use_existing'] == 'yes' and
                all(map(os.path.isfile, files.values())) and
                found_all_runfiles):
            logger.info("Output files already exist and '--use_existing' "
                        "is 'yes', so exiting with no further action.")
            intentional_exit = True
            sys.exit(0)

        # remove stale summary files (but keep the log being written)
        for fname, fpath in files.items():
            if os.path.isfile(fpath) and fname != 'log':
                logger.info("Removing existing file {0}".format(fpath))
                os.remove(fpath)

        # determine how many cpus to use
        if args['ncpus'] == -1:
            ncpus = multiprocessing.cpu_count()
        elif args['ncpus'] > 0:
            ncpus = min(args['ncpus'], multiprocessing.cpu_count())
        else:
            raise ValueError("--ncpus must be -1 or > 0")

        # run dms2_bcsubamp for each sample in batchfile
        logger.info("Running dms2_bcsubamp on all samples using "
                "{0} CPUs...".format(ncpus))
        argslist = [_bcsubamp_command(row, args)
                    for (_, row) in batchruns.iterrows()]
        # NOTE: the iterator returned by `imap` is never consumed, so an
        # exception raised by `subprocess.check_output` (such as a non-zero
        # exit status) is not propagated here; failed runs are detected
        # by the check for output files below.
        pool = multiprocessing.dummy.Pool(ncpus)
        pool.imap(functools.partial(subprocess.check_output,
                stderr=subprocess.STDOUT), argslist)
        pool.close()
        pool.join()
        logger.info("Completed runs of dms2_bcsubamp.\n")

        # define dms2_bcsubamp output files and make sure they exist
        logfiles = args['outdir'] + '/' + batchruns['name'] + '.log'
        for filename in runfiles:
            assert all(map(os.path.isfile, batchruns[filename].values)), (
                    "Did not create all these files:\n{0}\n\nLook in following "
                    "log files for details of what went wrong:\n{1}".format(
                    '\n'.join(batchruns[filename].values),
                    '\n'.join(logfiles.values)))

        logger.info("Plotting read stats to {0}".format(files['readstats']))
        dms_tools2.plot.plotReadStats(batchruns['name'],
                batchruns['readstats'], files['readstats'])

        logger.info("Plotting barcode stats to {0}".format(files['bcstats']))
        dms_tools2.plot.plotBCStats(batchruns['name'],
                batchruns['bcstats'], files['bcstats'])

        logger.info("Plotting reads per barcode to {0}".format(
                files['readsperbc']))
        dms_tools2.plot.plotReadsPerBC(batchruns['name'],
                batchruns['readsperbc'], files['readsperbc'])

        logger.info("Plotting count depth to {0}".format(files['depth']))
        dms_tools2.plot.plotDepth(batchruns['name'],
                batchruns['counts'], files['depth'])

        logger.info("Plotting mutation frequencies to {0}".format(
                files['mutfreq']))
        dms_tools2.plot.plotMutFreq(batchruns['name'],
                batchruns['counts'], files['mutfreq'])

        logger.info("Plotting average frequencies of codon mutation "
                "types to {0}, writing the data to {1}".format(
                files['codonmuttypes'], files['codonmuttypes_csv']))
        dms_tools2.plot.plotCodonMutTypes(batchruns['name'],
                batchruns['counts'], files['codonmuttypes'],
                classification='aachange',
                csvfile=files['codonmuttypes_csv'])

        logger.info("Plotting average frequencies of nucleotide changes per "
                "codon mutation to {0}".format(files['codonntchanges']))
        dms_tools2.plot.plotCodonMutTypes(batchruns['name'],
                batchruns['counts'], files['codonntchanges'],
                classification='n_ntchanges')

        logger.info("Plotting frequencies of nucleotide changes in 1-"
                "nucleotide mutations to {0}".format(files['singlentchanges']))
        dms_tools2.plot.plotCodonMutTypes(batchruns['name'],
                batchruns['counts'], files['singlentchanges'],
                classification='singlentchanges')

        logger.info("Plotting fraction of mutations seen less than some "
                "number of times to {0}".format(files['cumulmutcounts']))
        dms_tools2.plot.plotCumulMutCounts(batchruns['name'],
                batchruns['counts'], files['cumulmutcounts'], 'codon')

    except BaseException:
        # `BaseException` so that the `SystemExit` of the intentional early
        # exit above is caught here too; that one is not an error and is
        # deliberately not re-raised
        if not intentional_exit:
            logger.exception('Terminating {0} with ERROR'.format(prog))
            # do not leave partial summary files behind
            for fname, fpath in files.items():
                if os.path.isfile(fpath) and fname != 'log':
                    logger.info("Deleting file {0}".format(fpath))
                    os.remove(fpath)
            # propagate the failure so the exit status is non-zero
            raise

    else:
        logger.info('Successful completion of {0}'.format(prog))

    finally:
        logging.shutdown()



# Run the batch program only when this file is executed as a script.
if __name__ == "__main__":
    main()
