#!python

"""Runs a batch of ``dms2_diffsel`` programs, summarizes results."""


import os
import glob
import sys
import re
import logging
import functools
import subprocess
import multiprocessing
import multiprocessing.dummy
import natsort
import pandas
import dms_tools2.parseargs
import dms_tools2.utils
import dms_tools2.plot
import dms_tools2.diffsel


def main():
    """Run ``dms2_diffsel`` on every sample in a batch file and summarize.

    Command-line arguments come from
    ``dms_tools2.parseargs.batch_diffselParser``. The steps are:

    1. Read and validate the batch file. The columns ``name``, ``sel`` and
       ``mock`` are required; ``group``, ``grouplabel`` and ``err`` are
       used if present.
    2. Exit with status 0 if ``--use_existing`` is ``'yes'`` and every
       expected per-run and summary file already exists.
    3. Run one ``dms2_diffsel`` subprocess per sample via a thread pool.
    4. For each group, plot correlations among its samples and write the
       mean and median mutation / site differential selection to CSV.
    5. Plot the averaged site differential selection across all groups.

    Any exception other than ``SystemExit`` is logged, the summary files
    listed in ``files`` (except the log) are deleted, and the exception is
    *not* re-raised. A ``SystemExit`` with a non-zero code is re-raised.
    """

    parser = dms_tools2.parseargs.batch_diffselParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # set up names of output files
    dms_tools2.parseargs.checkName(args['summaryprefix'], 
            'summaryprefix')
    if args['outdir']:
        if not os.path.isdir(args['outdir']):
            os.mkdir(args['outdir'])
    else:
        args['outdir'] = '.'
    filesuffixes = {
            'log':'.log',
            }
    lineplottypes = ['positive', 'total', 'max', 'minmax']
    for pt in lineplottypes:
        filesuffixes['mean' + pt] = '_mean{0}diffsel.pdf'.format(pt)
        filesuffixes['median' + pt] = '_median{0}diffsel.pdf'.format(pt)
    # `files` is built here (before the group-specific names are known) so
    # that the log path exists and the `except` handler can always use it
    files = dict([(f, os.path.join(args['outdir'], '{0}{1}'.format(
            args['summaryprefix'], s))) for (f, s) in 
            filesuffixes.items()])

    logger = dms_tools2.utils.initLogger(files['log'], prog, args)

    # everything below runs inside try / except / else / finally so that
    # failures are logged and partial output is cleaned up
    try:

        # read batchfile, strip any whitespace from strings
        logger.info("Parsing info from {0}".format(args['batchfile']))
        assert os.path.isfile(args['batchfile']), \
                "no batchfile {0}".format(args['batchfile'])
        batchruns = pandas.read_csv(args['batchfile'], na_filter=False)
        batchruns.columns = batchruns.columns.str.strip()
        colnames = set(['name', 'sel', 'mock'])
        assert set(batchruns.columns) >= colnames, ("batchfile lacks "
                "required column names: {0}".format(colnames))
        for c in batchruns.columns:
            batchruns[c] = batchruns[c].map(str).map(str.strip)
        logger.info("Read the following sample information:\n{0}\n".format(
                batchruns.to_csv(index=False)))
        assert all([dms_tools2.parseargs.checkName(name, 'name') for name
                in batchruns['name']])
        if 'group' not in batchruns.columns:
            # no groups: treat all samples as one group named ''
            batchruns['group'] = ''
            batchruns['outname'] = batchruns['name']
            groups = groupprefixes = sorted(set(batchruns['group']))
        else:
            assert all([dms_tools2.parseargs.checkName(group, 'group') for
                    group in batchruns['group']])
            batchruns['outname'] = batchruns['group'] + '-' + batchruns['name']
            groups = sorted(set(batchruns['group']))
            groupprefixes = [g + '-' for g in groups]
        assert (len(set(batchruns['outname'].values)) ==
                len(batchruns['outname'])), "Duplicated name"

        # add more expected output files based on group
        for (g, gprefix) in zip(groups, groupprefixes):
            for seltype in ['mutdiffsel', 'absolutesitediffsel', 
                    'positivesitediffsel', 'maxmutdiffsel']:
                filesuffixes[g + seltype + 'corr'] = '_{0}{1}corr.pdf'.format(
                        gprefix, seltype)
            for avgtype in ['mean', 'median']:
                for seltype in ['mutdiffsel', 'sitediffsel']:
                    filesuffixes[g + avgtype + seltype] = \
                            '_{0}{1}{2}.csv'.format(
                            gprefix, avgtype, seltype)
        files = dict([(f, os.path.join(args['outdir'], '{0}{1}'.format(
                args['summaryprefix'], s))) for (f, s) in 
                filesuffixes.items()])

        # files from individual runs
        runfiles = {'mutdiffsel': '_mutdiffsel.csv',
                    'sitediffsel': '_sitediffsel.csv',
                    }
        for filename, filesuffix in runfiles.items():
            batchruns[filename] = (args['outdir'] + '/' + batchruns['outname'] +
                                   filesuffix)
        found_all_runfiles = all(all(map(os.path.isfile, batchruns[f].values))
                                         for f in runfiles.keys())

        # do we need to proceed
        if (
                args['use_existing'] == 'yes' and
                found_all_runfiles and
                all(map(os.path.isfile, files.values()))
                ):
            logger.info("Output summary files already exist and "
                    "'--use_existing' is 'yes', so exiting with no "
                    "further action.")
            sys.exit(0)

        # remove stale summary files (but keep the log being written)
        for (ftype, f) in files.items():
            if os.path.isfile(f) and ftype != 'log':
                logger.info("Removing existing file {0}".format(f))
                os.remove(f)

        # determine how many cpus to use
        if args['ncpus'] == -1:
            ncpus = multiprocessing.cpu_count()
        elif args['ncpus'] > 0:
            ncpus = min(args['ncpus'], multiprocessing.cpu_count())
        else:
            raise ValueError("--ncpus must be -1 or > 0")
        ncpus_per_run = max(1, ncpus // len(batchruns.index))

        # run dms2_diffsel for each sample in batchfile
        logger.info("Running dms2_diffsel on all samples...")
        argslist = []
        for (i, row) in batchruns.iterrows():
            # define newargs to pass to dms2_diffsel
            newargs = ['dms2_diffsel', '--name', row['outname'], 
                    '--sel', row['sel'], '--mock', row['mock'],
                    '--ncpus', str(ncpus_per_run)]
            if 'err' in batchruns.columns:
                newargs += ['--err', row['err']]
            # forward every other truthy option of this script unchanged
            for (arg, val) in args.items():
                if arg in ['batchfile', 'ncpus', 'summaryprefix']:
                    continue
                elif val:
                    newargs.append('--{0}'.format(arg))
                    if isinstance(val, list):
                        newargs += list(map(str, val))
                    else:
                        newargs.append(str(val))
            argslist.append(newargs)
        # threads suffice here because each task just waits on a subprocess
        pool = multiprocessing.dummy.Pool(ncpus)
        # The imap result is never iterated, so a CalledProcessError from a
        # failed run is not raised here; failed runs are detected just below
        # by checking for their missing output files.
        pool.imap(functools.partial(subprocess.check_output, 
                stderr=subprocess.STDOUT), argslist)
        pool.close()
        pool.join()
        logger.info("Completed runs of dms2_diffsel.\n")

        # make sure the dms2_diffsel output files exist
        for (filename, filesuffix) in runfiles.items():
            for f in batchruns[filename]:
                if not os.path.isfile(f):
                    # Build the run's log path by swapping this output
                    # type's own suffix (at the end of `f`) for '.log'.
                    flog = f[: -len(filesuffix)] + '.log'
                    assert os.path.isfile(flog), "Didn't create {0}".format(flog)
                    with open(flog) as flog_f:
                        lines = flog_f.readlines()
                    raise RuntimeError("Failed to create {0}.\nHere is end of "
                            "{1}:\n{2}".format(f, flog, ''.join(lines[-25 : ])))

        # summarize for each group; groups are recomputed in order of first
        # appearance in the batch file rather than sorted order
        groups = batchruns['group'].unique()
        groupprefixes = ['{0}{1}'.format(g, {True:'', False:'-'}[g == ''])
                for g in groups]

        # process diffsel each group
        avgsitediffsel = {'mean':[], 'median':[]}
        for (g, gprefix) in zip(groups, groupprefixes):

            samples = batchruns.query('group == @g')
            logger.info("Analyzing the diffsel values for the "
                    "{0} samples{1}.".format(len(samples),
                    {True:'', False:' in group {0}'.format(g)}[g == '']))

            for seltype in ['mutdiffsel', 'absolutesitediffsel', 
                    'positivesitediffsel', 'maxmutdiffsel']:
                plotfile = files[g + seltype + 'corr']
                logger.info("Plotting {0} correlations to {1}".format(
                        seltype, plotfile))
                if seltype == 'mutdiffsel':
                    datatype = 'mutdiffsel'
                    infiles = samples['mutdiffsel']
                else:
                    # site-level selection types are all read from the
                    # per-run sitediffsel file
                    datatype = {'absolutesitediffsel':'abs_diffsel',
                                'positivesitediffsel':'positive_diffsel',
                                'maxmutdiffsel':'max_diffsel'}[seltype]
                    infiles = samples['sitediffsel']
                dms_tools2.plot.plotCorrMatrix(samples['name'], 
                        infiles, plotfile, datatype=datatype, 
                        title=g.replace('-', ' '))

            for avgtype in ['mean', 'median']:
                f = files[g + avgtype + 'mutdiffsel']
                logger.info("Writing {0} mutdiffsel to {1}".format(avgtype, f))
                avgmutdiffsel = dms_tools2.diffsel.avgMutDiffSel(
                        samples['mutdiffsel'], avgtype)
                avgmutdiffsel.to_csv(f, index=False)
                f = files[g + avgtype + 'sitediffsel']
                logger.info("Writing {0} sitediffsel to {1}".format(avgtype, f))
                (dms_tools2.diffsel.mutToSiteDiffSel(avgmutdiffsel)
                        .sort_values('abs_diffsel', ascending=False)
                        .to_csv(f, index=False)
                        )
                # one averaged site file per group, in the order of `groups`
                avgsitediffsel[avgtype].append(f)

        # plots of diffsel for all groups
        if 'grouplabel' in batchruns.columns:
            grouplabels = batchruns['grouplabel'].unique()
            assert (len(groups) == len(grouplabels) ==
                    len(batchruns.groupby(['group', 'grouplabel']))),\
                     "not a unique pairing of `group` and `grouplabel`"
        else:
            grouplabels = [g.replace('-', ' ') for g in groups]
        for avgtype in ['mean', 'median']:
            for pt in lineplottypes:
                f = files[avgtype + pt]
                logger.info("Plotting {0} {1} diffsel to {2}".format(
                        avgtype, pt, f))
                dms_tools2.plot.plotSiteDiffSel(grouplabels,
                        avgsitediffsel[avgtype], f, pt)

    except SystemExit as e:
        # a zero exit code is the deliberate early exit above; swallow it
        if e.code != 0:
            raise

    except:
        logger.exception('Terminating {0} with ERROR'.format(prog))
        # remove partial summary output, keeping the log for diagnosis
        for (fname, fpath) in files.items():
            if fname != 'log' and os.path.isfile(fpath):
                logger.exception("Deleting file {0}".format(fpath))
                os.remove(fpath)

    else:
        logger.info('Successful completion of {0}'.format(prog))

    finally:
        logging.shutdown()



# Only run the batch analysis when executed as a script, not on import.
if __name__ == "__main__":
    main()
