#!python

"""Comprehensive model comparison and selection-detection with ``phydms``.

Written by Jesse Bloom and Sarah Hilton."""


import os
import re
import time
import copy
import logging
import multiprocessing
import subprocess
import signal
import glob
import phydmslib
import phydmslib.file_io
import phydmslib.parsearguments


def RunCmds(cmds):
    """Runs the command line arguments in *cmds* using *subprocess*.

    The child's standard error is merged into its standard output; the
    combined output is captured and discarded. Nothing is returned.

    If the command cannot be launched at all, the exception raised by
    ``subprocess.Popen`` (for example ``OSError`` for a missing executable)
    propagates to the caller. If an exception occurs while waiting for an
    already-launched child, the child is asked to terminate and the
    exception is suppressed (best-effort cleanup).
    """
    # Launch outside the ``try``: if the launch itself fails there is no
    # child to clean up, and the real error should surface rather than an
    # unbound-variable error raised from the cleanup handler.
    p = subprocess.Popen(cmds, stdout=subprocess.PIPE,
                         stderr=subprocess.STDOUT)
    try:
        p.communicate()
    except Exception:
        # only signal the child if it has not already exited
        if p.poll() is None:
            p.terminate()


def createModelComparisonFile(models, outprefix, modelcomparisonfile):
    """Creates `modelcomparisonfile` in Markdown format.

    `models` is an iterable of the model names used for each run. For each
    name, the files ``outprefix + name + '_loglikelihood.txt'`` and
    ``outprefix + name + '_modelparams.txt'`` are read.

    The Markdown table written to `modelcomparisonfile` has the columns
    'Model', 'deltaAIC', 'LogLikelihood', 'nParams', and 'ParamValues':

    * 'LogLikelihood' is the last whitespace-delimited token of the
      log likelihood file, converted to a float.
    * 'nParams' is the number of ``name = value`` lines in the model
      parameters file (blank lines are ignored).
    * 'ParamValues' lists, sorted by name, the parameters whose name
      does not contain 'phi'.
    * 'deltaAIC' is ``2 * (nParams - LogLikelihood)`` minus the smallest
      such value among all models.

    Rows are sorted by increasing 'deltaAIC'. The file is written without
    a trailing newline.
    """
    columns = ['Model', 'deltaAIC', 'LogLikelihood', 'nParams', 'ParamValues']
    stats = {col: [] for col in columns}
    for modelName in models:
        stats['Model'].append(modelName)
        with open(outprefix + modelName + '_loglikelihood.txt') as f:
            stats['LogLikelihood'].append(float(f.read().split()[-1]))
        with open(outprefix + modelName + '_modelparams.txt') as f:
            # Skip blank lines (e.g. a trailing empty line) so they are
            # neither counted as parameters nor break the unpacking below.
            params = [line.split('=') for line in f if line.strip()]
        stats['nParams'].append(len(params))
        paramvalues = ['{0}={1:.2f}'.format(name.strip(), float(value))
                       for (name, value) in sorted(params)
                       if 'phi' not in name]
        stats['ParamValues'].append(', '.join(paramvalues))
        # raw AIC for now; converted to a difference from the minimum below
        stats['deltaAIC'].append(2 * (stats['nParams'][-1] -
                                      stats['LogLikelihood'][-1]))
    minAIC = min(stats['deltaAIC'])
    stats['deltaAIC'] = [x - minAIC for x in stats['deltaAIC']]

    # row order: increasing deltaAIC, ties broken by original position
    nrows = len(stats['Model'])
    indexorder = [tup[1] for tup in sorted(zip(stats['deltaAIC'],
                                               range(nrows)))]

    # convert every column to strings so the column widths can be computed
    for col in ['deltaAIC', 'LogLikelihood']:
        stats[col] = ['{0:.2f}'.format(x) for x in stats[col]]
    stats['nParams'] = [str(x) for x in stats['nParams']]
    colwidths = [max(map(len, stats[col] + [col])) for col in columns]

    # build a format string that pads each cell to its column width
    formatline = ('| ' + ' | '.join(['{' + str(i) + ':' + str(w) + '}'
                                     for (i, w) in enumerate(colwidths)])
                  + ' |')
    sepline = '|-' + '-|-'.join(['-' * w for w in colwidths]) + '-|'
    text = [formatline.format(*columns), sepline]
    for i in indexorder:
        text.append(formatline.format(*[stats[col][i] for col in columns]))
    with open(modelcomparisonfile, 'w') as f:
        f.write('\n'.join(text))


def main():
    """Main body of script.

    Parses the command-line arguments, sets up logging to a log file,
    builds the set of YNGKP and ExpCM models to fit, checks the alignment,
    infers a tree with RAxML if none was given, runs one ``phydms`` process
    per model, and finally writes the model comparison file.

    Exceptions raised while running the models or writing the comparison
    file are logged rather than propagated, and any model processes that
    are still alive are terminated before returning.
    """

    # Parse command line arguments
    parser = phydmslib.parsearguments.PhyDMSComprehensiveParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # create output directory if needed
    outdir = os.path.dirname(args['outprefix'])
    if outdir:
        if not os.path.isdir(outdir):
            # a regular file with the directory's name is replaced
            if os.path.isfile(outdir):
                os.remove(outdir)
            os.mkdir(outdir)

    # setup files
    # file names slightly different depending on
    # whether outprefix is directory or file
    if args['outprefix'][-1] == '/':
        logfile = "{0}log.log".format(args['outprefix'])
    else:
        logfile = "{0}.log".format(args['outprefix'])
        args['outprefix'] = '{0}_'.format(args['outprefix'])
    modelcomparisonfile = '{0}modelcomparison.md'.format(args['outprefix'])
    raxmlStandardOutputFile = '{0}raxml_output.txt'.format(args['outprefix'])

    # Set up to log everything to logfile.
    if os.path.isfile(logfile):
        os.remove(logfile)
    logging.shutdown()
    logging.captureWarnings(True)
    versionstring = phydmslib.file_io.Versions()
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        level=logging.INFO)
    logger = logging.getLogger(prog)
    logfile_handler = logging.FileHandler(logfile)
    logger.addHandler(logfile_handler)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    logfile_handler.setFormatter(formatter)

    # print some basic information
    logger.info('Beginning execution of {0} in directory {1}\n'
                .format(prog, os.getcwd()))
    logger.info('Progress is being logged to {0}\n'.format(logfile))
    logger.info('{0}\n'.format(versionstring))
    logger.info('Parsed the following command-line arguments:\n{0}\n'
                .format('\n'.join(['\t{0} = {1}'.format(key, args[key])
                                   for key in args.keys()])))

    # setup models
    # `models` maps model name -> (phydms model string, extra arguments);
    # `filesuffixes` maps model name -> suffixes of the expected output files
    filesuffixlist = ['_log.log', '_tree.newick', '_loglikelihood.txt',
                      '_modelparams.txt']
    filesuffixes = {}  # keyed by model, values are list of suffixes
    models = {}
    additionalcmds = ["--brlen", args["brlen"]]
    if args['omegabysite']:
        additionalcmds.append('--omegabysite')
        filesuffixlist = copy.deepcopy(filesuffixlist) + ['_omegabysite.txt']

    # set up the YNGKP models
    models = {'YNGKP_M0': ('YNGKP_M0', additionalcmds)}
    filesuffixes['YNGKP_M0'] = filesuffixlist
    models['YNGKP_M5'] = ('YNGKP_M5', additionalcmds)
    filesuffixes['YNGKP_M5'] = filesuffixlist

    # set up the ExpCM
    # (new list objects are made here so the YNGKP entries above do not
    # pick up the ExpCM-only options and suffixes)
    additionalcmds = copy.deepcopy(additionalcmds)
    if args['diffprefsbysite']:
        additionalcmds = (copy.deepcopy(additionalcmds)
                          + ['--diffprefsbysite'])
        filesuffixlist = (copy.deepcopy(filesuffixlist)
                          + ['_diffprefsbysite.txt'])
    for prefsfile in args['prefsfiles']:
        if re.search(r'\s', prefsfile):
            raise ValueError("There is a space in the preferences file "
                             "name: {0}".format(prefsfile))
        prefsfilebase = os.path.splitext(os.path.basename(prefsfile))[0]
        modelname = 'ExpCM_{0}'.format(prefsfilebase)
        assert modelname not in filesuffixes, "Duplicate preferences file"\
            " base name {0} for {1}; make names unique even after "\
            "removing directory and extension".format(modelname, prefsfile)
        filesuffixes[modelname] = filesuffixlist
        models[modelname] = ('ExpCM_{0}'.format(prefsfile), additionalcmds)
        if args["gammaomega"]:
            gammaomegamodelname = '{0}_gammaomega'.format(modelname)
            models[gammaomegamodelname] = ('ExpCM_{0}'.format(prefsfile),
                                           additionalcmds + ['--gammaomega'])
            filesuffixes[gammaomegamodelname] = filesuffixlist
        if args["gammabeta"]:
            gammabetamodelname = '{0}_gammabeta'.format(modelname)
            models[gammabetamodelname] = ('ExpCM_{0}'.format(prefsfile),
                                          additionalcmds + ['--gammabeta'])
            filesuffixes[gammabetamodelname] = filesuffixlist
        # NOTE: the averaged / randomized entries carry a third tuple
        # element (0) that is dropped when the '--ncpus' option is
        # appended to every model further below.
        if not args['noavgprefs']:
            avgmodelname = 'averaged_{0}'.format(modelname)
            models[avgmodelname] = ('ExpCM_{0}'.format(prefsfile),
                                    additionalcmds + ['--avgprefs'], 0)
            filesuffixes[avgmodelname] = filesuffixlist
            if args["gammaomega"]:
                avgmodelname = 'averaged_{0}_gammaomega'.format(modelname)
                models[avgmodelname] = ('ExpCM_{0}'.format(prefsfile),
                                        additionalcmds + ['--gammaomega']
                                        + ['--avgprefs'], 0)
                filesuffixes[avgmodelname] = filesuffixlist
            if args["gammabeta"]:
                avgmodelname = 'averaged_{0}_gammabeta'.format(modelname)
                models[avgmodelname] = ('ExpCM_{0}'.format(prefsfile),
                                        additionalcmds + ['--gammabeta']
                                        + ['--avgprefs'], 0)
                filesuffixes[avgmodelname] = filesuffixlist
        if args['randprefs']:
            randmodelname = 'randomized_{0}'.format(modelname)
            models[randmodelname] = ('ExpCM_{0}'.format(prefsfile),
                                     additionalcmds + ['--randprefs'], 0)
            filesuffixes[randmodelname] = filesuffixlist
            if args["gammaomega"]:
                randmodelname = 'randomized_{0}_gammaomega'.format(modelname)
                models[randmodelname] = ('ExpCM_{0}'.format(prefsfile),
                                         additionalcmds + ['--gammaomega']
                                         + ['--randprefs'], 0)
                filesuffixes[randmodelname] = filesuffixlist
            if args["gammabeta"]:
                randmodelname = 'randomized_{0}_gammabeta'.format(modelname)
                models[randmodelname] = ('ExpCM_{0}'.format(prefsfile),
                                         additionalcmds + ['--gammabeta']
                                         + ['--randprefs'], 0)
                filesuffixes[randmodelname] = filesuffixlist

    # check alignment
    logger.info('Checking that the alignment {0} is valid...'
                .format(args['alignment']))
    alignment = phydmslib.file_io.ReadCodonAlignment(args['alignment'],
                                                     checknewickvalid=True)
    logger.info('Valid alignment specifying {0} sequences of length {1}.\n'
                .format(len(alignment), len(alignment[0][1])))

    # read or build a tree
    if args["tree"]:
        logger.info("Reading tree from {0}".format(args['tree']))
    else:
        logger.info("Tree not specified.")
        try:
            logger.info("Inferring tree with RAxML using command {0}"
                        .format(args['raxml']))

            # remove pre-existing RAxML files
            raxmlOutputName = os.path.splitext(os.path.basename(
                    args["alignment"]))[0]
            raxmlOutputFiles = []
            for raxmlFile in glob.glob("RAxML_*{0}".format(raxmlOutputName)):
                if os.path.isfile(raxmlFile):
                    raxmlOutputFiles.append(raxmlFile)
                    os.remove(raxmlFile)
            if len(raxmlOutputFiles) > 0:
                logger.info('Removed the following RAxML files:\n{0}\n'
                            .format('\n'.join(['\t{0}'.format(fname) for
                                               fname in raxmlOutputFiles])))

            # run RAxmL
            raxmlCMD = [args['raxml'], '-s', args['alignment'], '-n',
                        raxmlOutputName, '-m', 'GTRCAT', '-p1', '-T', '2']
            with open(raxmlStandardOutputFile, 'w') as f:
                subprocess.check_call(raxmlCMD, stdout=f)

            # move RAxML tree to output directory and remove all other files
            for raxmlFile in glob.glob("RAxML_*{0}".format(raxmlOutputName)):
                if "bestTree" in raxmlFile:
                    args["tree"] = args["outprefix"] + 'RAxML_tree.newick'
                    os.rename(raxmlFile, args['tree'])
                    logger.info("RAxML inferred tree is now named {0}".format(
                            args['tree']))
                else:
                    os.remove(raxmlFile)
        except OSError:
            raise ValueError("The raxml command of {0} is not valid. Is raxml "
                             "installed at this path?".format(args['raxml']))

    # get number of available CPUs and assign to each model
    if args['ncpus'] == -1:
        try:
            args['ncpus'] = multiprocessing.cpu_count()
        except Exception:
            raise RuntimeError("Encountered a problem trying to dynamically"
                               " determine the number of available CPUs. "
                               "Please manually specify the number of "
                               "desired CPUs with '--ncpus' and try again.")
        # logged only once the CPU count was determined successfully
        logger.info('Will use all %d available CPUs.\n' % args['ncpus'])
    assert args['ncpus'] >= 1, "Failed to specify valid number of CPUs"

    # YNGKP models get one CPU
    # ExpCM get more than one if excess over number of models
    expcm_modelnames = [modelname for modelname in models.keys()
                        if 'ExpCM' in modelname]
    yngkp_modelnames = [modelname for modelname in models.keys()
                        if 'YNGKP' in modelname]
    assert len(models.keys())\
           == (len(expcm_modelnames) + len(yngkp_modelnames)),\
           ("not ExpCM or YNGKP:\n{0}".format(str(models.keys())))
    ncpus_per_model = dict([(modelname, 1) for modelname in yngkp_modelnames])
    # two CPUs are set aside for the two YNGKP models; the divisor is
    # guarded so that having no ExpCM models cannot divide by zero
    nperexpcm = max(1, (args['ncpus'] - 2) // max(1, len(expcm_modelnames)))
    for modelname in expcm_modelnames:
        ncpus_per_model[modelname] = nperexpcm
    # append '--ncpus' to every model; this also normalizes every entry
    # of `models` to a 2-tuple
    for modelname in models.keys():
        mtup = models[modelname]
        models[modelname] = (mtup[0],
                             mtup[1] + ['--ncpus',
                                        str(ncpus_per_model[modelname])])

    pool = {}  # holds process for model name
    started = {}  # holds whether process started for model name
    completed = {}  # holds whether process completed for model name
    outprefixes = {}  # holds outprefix for model name

    # rest of execution in try / finally
    try:

        # remove existing output files
        outfiles = [modelcomparisonfile]
        removed = []
        for modelname in models.keys():
            for suffix in filesuffixes[modelname]:
                fname = "{0}{1}{2}".format(args['outprefix'], modelname,
                                           suffix)
                outfiles.append(fname)
        for fname in outfiles:
            if os.path.isfile(fname):
                os.remove(fname)
                removed.append(fname)
        if removed:
            logger.info('Removed the following existing files that have names'
                        " that match the names of output files that will be "
                        'created: {0}\n'.format(', '.join(removed)))

        # now run the models
        # first create (but do not start) one process per model
        for modelname in list(models.keys()):
            (model, additionalcmds) = models[modelname]
            outprefix = "{0}{1}".format(args['outprefix'], modelname)
            cmds = ['phydms', args['alignment'], args["tree"], model,
                    outprefix] + additionalcmds
            logger.info('Starting analysis to optimize tree in {0} using '
                        'model {1}. The command is: {2}\n'
                        .format(args["tree"], modelname, ' '.join(cmds)))
            pool[modelname] = multiprocessing.Process(target=RunCmds,
                                                      args=(cmds,))
            outprefixes[modelname] = outprefix
            completed[modelname] = False
            started[modelname] = False
        # polling loop: each pass starts at most one new process if fewer
        # than `ncpus` are running, then checks for finished processes
        while not all(completed.values()):
            nrunning = list(started.values()).count(True) - \
                    list(completed.values()).count(True)
            if nrunning < args['ncpus']:
                for (modelname, p) in pool.items():
                    if not started[modelname]:
                        p.start()
                        started[modelname] = True
                        break
            for (modelname, p) in pool.items():
                if started[modelname] and (not completed[modelname]) and \
                        (not p.is_alive()):  # process just completed
                    completed[modelname] = True
                    logger.info('Analysis completed for {0}'.format(modelname))
                    for fname in [outprefixes[modelname] + suffix for suffix
                                  in filesuffixes[modelname]]:
                        if not os.path.isfile(fname):
                            raise RuntimeError(
                                "phydms failed to created expected output "
                                "file {0}.".format(fname))
                        logger.info("Found expected output file {0}"
                                    .format(fname))
                    logger.info('Analysis successful for {0}\n'
                                .format(modelname))
            time.sleep(1)

        createModelComparisonFile(models.keys(), args["outprefix"],
                                  modelcomparisonfile)

        # make sure all expected output files are there
        for fname in outfiles:
            if not os.path.isfile(fname):
                raise RuntimeError("Cannot find expected output file {0}"
                                   .format(fname))

    except Exception:
        # NOTE(review): the error is logged with its traceback but not
        # re-raised, so the function returns normally after a failure.
        logger.exception('Terminating {0} at {1} with ERROR'
                         .format(prog, time.asctime()))
    else:
        logger.info('Successful completion of {0}'.format(prog))
    finally:
        logging.shutdown()
        # do not leave model processes running after an error
        for p in pool.values():
            if p.is_alive():
                p.terminate()


# Call main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()  # run the script
