#!python

"""Runs a batch of ``dms2_prefs`` programs, summarizes results."""


import os
import glob
import sys
import re
import logging
import functools
import subprocess
import multiprocessing
import multiprocessing.dummy
import pandas
import dms_tools2.parseargs
import dms_tools2.utils
import dms_tools2.plot
import dms_tools2.prefs


def _run_dms2_prefs(cmd):
    """Run one ``dms2_prefs`` command, reporting failure instead of raising.

    Args:
        cmd: Argument list (program name first) handed to
            ``subprocess.check_output``.

    Returns:
        ``None`` if the command exited with status 0, otherwise a string
        describing why the run failed.
    """
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as err:
        # output is bytes because check_output is not run in text mode
        output = err.output.decode(errors='replace') if err.output else ''
        tail = ''.join(output.splitlines(True)[-25:])
        return "{0}\nHere is end of its output:\n{1}".format(err, tail)
    except (OSError, ValueError, subprocess.SubprocessError) as err:
        # e.g. the executable could not be found or started
        return "Could not run {0}: {1}".format(cmd, err)
    return None


def main():
    """Main body of script.

    Parses the command line with
    ``dms_tools2.parseargs.batch_prefsParser``, runs ``dms2_prefs`` once
    for each row of the batch file (concurrently, in a thread pool),
    checks that every run created its ``*_prefs.csv`` file, and then,
    unless disabled by the ``no_corr`` / ``no_avg`` arguments, plots the
    correlations among the samples and writes their average preferences.

    If ``use_existing`` is ``'yes'`` and all summary files already exist,
    exits with status 0 without doing anything.

    Raises:
        Any error encountered after the logger is set up is logged, the
        summary files other than the log are deleted, and the error is
        then re-raised so the process ends with a non-zero exit status.
    """

    parser = dms_tools2.parseargs.batch_prefsParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # set up names of output files
    dms_tools2.parseargs.checkName(args['summaryprefix'], 'summaryprefix')
    if args['outdir']:
        # makedirs handles nested directories and tolerates the directory
        # already existing, without a separate (racy) isdir check
        os.makedirs(args['outdir'], exist_ok=True)
    else:
        args['outdir'] = '.'
    filesuffixes = {
            'log':'.log',
            }
    if not args['no_corr']:
        filesuffixes['corr'] = '_prefscorr.pdf'
    if not args['no_avg']:
        filesuffixes['avgprefs'] = '_avgprefs.csv'
    files = {
            ftype: os.path.join(args['outdir'], '{0}{1}'.format(
                    args['summaryprefix'], suffix))
            for (ftype, suffix) in filesuffixes.items()
            }

    # do we need to proceed
    if args['use_existing'] == 'yes' and all(map(
                os.path.isfile, files.values())):
        print("Output summary files already exist and '--use_existing' "
                "is 'yes', so exiting with no further action.")
        sys.exit(0)

    logger = dms_tools2.utils.initLogger(files['log'], prog, args)

    # log in try / except / finally loop
    try:

        for (ftype, f) in files.items():
            if os.path.isfile(f) and ftype != 'log':
                logger.info("Removing existing file {0}".format(f))
                os.remove(f)

        # read batchfile, strip any whitespace from strings
        # (validation raises explicitly rather than using assert, since
        # asserts are skipped when Python runs with -O)
        logger.info("Parsing info from {0}".format(args['batchfile']))
        if not os.path.isfile(args['batchfile']):
            raise FileNotFoundError("no batchfile")
        batchruns = pandas.read_csv(args['batchfile'], na_filter=False)
        batchruns.columns = batchruns.columns.str.strip()
        colnames = set(['name', 'pre', 'post'])
        if not set(batchruns.columns) >= colnames:
            raise ValueError("batchfile lacks "
                    "required column names: {0}".format(colnames))
        for c in batchruns.columns:
            batchruns[c] = batchruns[c].map(str).map(str.strip)
        logger.info("Read the following sample information:\n{0}\n".format(
                batchruns.to_csv(index=False)))
        if len(batchruns.index) == 0:
            # otherwise the per-run CPU computation below divides by zero
            raise ValueError("batchfile has no samples")
        for name in batchruns['name']:
            if not dms_tools2.parseargs.checkName(name, 'name'):
                raise ValueError("Invalid name {0}".format(name))
        if len(batchruns['name']) != len(set(batchruns['name'].values)):
            raise ValueError("Duplicated name")

        # determine how many cpus to use
        if args['ncpus'] == -1:
            ncpus = multiprocessing.cpu_count()
        elif args['ncpus'] > 0:
            ncpus = min(args['ncpus'], multiprocessing.cpu_count())
        else:
            raise ValueError("--ncpus must be -1 or > 0")
        ncpus_per_run = max(1, ncpus // len(batchruns.index))

        # decide how error controls are specified in the batchfile
        if 'err' in batchruns.columns:
            error_model = 'same'
        elif {'errpre', 'errpost'} <= set(batchruns.columns):
            error_model = 'different'
        else:
            if 'errpre' in batchruns.columns:
                raise ValueError("errpre but not errpost")
            if 'errpost' in batchruns.columns:
                raise ValueError("errpost but not errpre")
            error_model = 'none'

        # run dms2_prefs for each sample in batchfile
        logger.info("Running dms2_prefs on all samples...")
        # arguments that only apply to this batch program
        not_passed = {'batchfile', 'ncpus', 'summaryprefix',
                      'no_avg', 'no_corr'}
        argslist = []
        for (_, row) in batchruns.iterrows():
            # define newargs to pass to dms2_prefs
            newargs = ['dms2_prefs', '--name', row['name'],
                    '--pre', row['pre'], '--post', row['post'],
                    '--ncpus', str(ncpus_per_run)]
            if error_model == 'same':
                newargs += ['--err', row['err'], row['err']]
            elif error_model == 'different':
                newargs += ['--err', row['errpre'], row['errpost']]
            for (arg, val) in args.items():
                if arg in not_passed or not val:
                    continue
                newargs.append('--{0}'.format(arg))
                if isinstance(val, list):
                    newargs += [str(v) for v in val]
                else:
                    newargs.append(str(val))
            argslist.append(newargs)

        # multiprocessing.dummy.Pool is a thread pool; each task just
        # waits on a dms2_prefs subprocess
        pool = multiprocessing.dummy.Pool(ncpus)
        try:
            failures = pool.map(_run_dms2_prefs, argslist)
        finally:
            pool.close()
            pool.join()
        # record failed runs; whether the run produced its output is
        # still decided by the file check below
        for failure in failures:
            if failure is not None:
                logger.warning("dms2_prefs run failed: {0}".format(failure))
        logger.info("Completed runs of dms2_prefs.\n")

        # define dms2_prefs output files and make sure they exist
        for (filename, filesuffix) in [
                ('prefs', '_prefs.csv')
                ]:
            batchruns[filename] = (args['outdir'] + '/' + batchruns['name'] +
                    filesuffix)
            for f in batchruns[filename]:
                if not os.path.isfile(f):
                    # swap only the trailing suffix for '.log'
                    flog = f[: -len(filesuffix)] + '.log'
                    if not os.path.isfile(flog):
                        raise RuntimeError("Didn't create {0}".format(flog))
                    with open(flog) as flog_f:
                        lines = flog_f.readlines()
                    raise RuntimeError("Failed to create {0}.\nHere is end of "
                            "{1}:\n{2}".format(f, flog, ''.join(lines[-25 : ])))

        if not args['no_corr']:
            logger.info("Plotting correlations to {0}\n".format(files['corr']))
            dms_tools2.plot.plotCorrMatrix(batchruns['name'],
                    batchruns['prefs'], files['corr'], datatype='prefs')

        if not args['no_avg']:
            logger.info("Writing average preferences to {0}\n".format(
                   files['avgprefs']))
            avgprefs = dms_tools2.prefs.avgPrefs(batchruns['prefs'])
            avgprefs.to_csv(files['avgprefs'], index=False)

    except BaseException:
        # log the traceback once, remove partial summary files, then
        # re-raise so the failure is visible in the exit status
        logger.exception('Terminating {0} with ERROR'.format(prog))
        for (fname, fpath) in files.items():
            if fname != 'log' and os.path.isfile(fpath):
                logger.error("Deleting file {0}".format(fpath))
                os.remove(fpath)
        raise

    else:
        logger.info('Successful completion of {0}'.format(prog))

    finally:
        logging.shutdown()


if __name__ == "__main__":
    # executed as a script rather than imported: run the batch analysis
    main()
