#!python

"""Test different diversifying pressure models with ``phydms``.

Written by Jesse Boom and Sarah Hilton.
"""


import os
import re
import time
import logging
import multiprocessing
import subprocess
import glob
import phydmslib
import phydmslib.file_io
import phydmslib.parsearguments
import pandas as pd
import random


def randomizeDivpressure(divpressureFile, numberRandomizations, outprefix):
    """
    Write shuffled copies of a diversifying pressure file.

    The values read from *divpressureFile* are shuffled
    *numberRandomizations* times (once per seed ``0, 1, 2, ...``) while the
    site order is kept fixed. The list is shuffled in place, so each
    randomization starts from the result of the previous one. Every shuffle
    is written to ``<name>_random_<seed>.csv`` in the directory
    ``<outprefix>randomizedFiles``.

    Returns a dictionary of the written file names keyed by the random seed.
    """
    baseName, _ = os.path.splitext(os.path.basename(divpressureFile))

    # make sure the output directory exists (replacing a same-named file)
    outdir = os.path.dirname(outprefix + "randomizedFiles/")
    if outdir and not os.path.isdir(outdir):
        if os.path.isfile(outdir):
            os.remove(outdir)
        os.mkdir(outdir)

    observed = phydmslib.file_io.readDivPressure(divpressureFile)
    sites = list(observed.keys())
    shuffled = list(observed.values())

    randomFiles = {}
    for seed in range(int(numberRandomizations)):
        random.seed(seed)
        random.shuffle(shuffled)
        path = '{0}/{1}_random_{2}.csv'.format(outdir, baseName, seed)
        rows = ['{0},{1}'.format(site, dp)
                for (site, dp) in zip(sites, shuffled)]
        with open(path, 'w') as f:
            f.write("#SITE,VALUE\n" + '\n'.join(rows))
        randomFiles[seed] = path
    return randomFiles


def createModelComparisonFile(models, outprefix):
    """
    Creates two summary files, a .csv and a .md

    For every model ID in *models* the files
    ``<outprefix>ExpCM_<id parts>_modelparams.txt`` and
    ``<outprefix>ExpCM_<id parts>_loglikelihood.txt`` are read, where
    ``<id parts>`` is the model ID joined by ``_`` with ``"None"`` entries
    dropped. Only the keys of *models* are used.

    *modelcomparison.csv* (long format, one row per model and variable)
        *Preferences*
            The prefs file used for `ExpCM`
        *DiversifyingPressureSet*
            The diversifying pressure file
        *DiversifyingPressureType*
            "True" if the values in the diversifying pressure set were used
            "None" if `ExpCM` was run without diversifying pressure
            "Random" if the true values were randomized
        *DiversifyingPressureID*
            Combination of the diversifying pressure set and the random
            seed. Unique identifier for each run of `phydms`
        *variable*
            All of the optimized parameters and "LogLikelihood"
        *value*
            The value of the variable
    *modelcomparison.md*
        Columns include 'DiversifyingPressureSet', 'LogLikelihood',
        and all optimized parameters.
        Only the non-randomized `phydms` runs are included.
    """
    # column order of the .csv (alphabetical)
    csvColumns = ["DiversifyingPressureID", "DiversifyingPressureSet",
                  "DiversifyingPressureType", "Preferences",
                  "value", "variable"]

    # modelcomparison.csv
    frames = []
    for modelID in models:
        outFilePrefix = "ExpCM_" + '_'.join(str(item) for item in
                                            modelID if item != "None")

        # optimized parameters, one "name = value" pair per line
        params = pd.read_csv(outprefix + outFilePrefix + "_modelparams.txt",
                             sep=" = ", header=None, engine='python')
        params.columns = ['variable', 'value']

        # The log likelihood is on the first line as "<label> = <number>".
        # Parse it as a number so the trailing newline of that line does not
        # end up inside the .csv field or split the .md table row.
        with open(outprefix + outFilePrefix + "_loglikelihood.txt") as f:
            firstLine = f.readline()
        logLikelihood = pd.DataFrame(
                {"variable": ["LogLikelihood"],
                 "value": [float(firstLine.split(" = ")[-1])]})

        modelDF = pd.concat([params, logLikelihood], ignore_index=True)
        modelDF["Preferences"] = modelID[0]
        modelDF["DiversifyingPressureSet"] = modelID[1]
        modelDF["DiversifyingPressureType"] = modelID[2]
        modelDF["DiversifyingPressureID"] = "{0}_{1}".format(modelID[1],
                                                             modelID[3])
        frames.append(modelDF)

    if frames:
        final = pd.concat(frames, ignore_index=True)[csvColumns]
    else:
        # no models: header-only table
        final = pd.DataFrame(columns=["Preferences",
                                      "DiversifyingPressureSet",
                                      "DiversifyingPressureType",
                                      "DiversifyingPressureID",
                                      "variable", "value"])
    final.to_csv(outprefix + "modelcomparison.csv", index=False)

    # modelcomparison.md
    # pivot the dataframe so the parameters are columns
    nonRandom = final[final["DiversifyingPressureType"] != "Random"]
    mdDF = nonRandom.pivot(index="DiversifyingPressureSet",
                           columns='variable', values='value')
    header = ["DiversifyingPressureSet"] + list(mdDF.columns.values)
    with open(outprefix + "modelcomparison.md", "w") as f:
        f.write("|".join(header) + "\n")
        f.write("|".join(["---"] * len(header)) + "\n")
        for index, row in mdDF.iterrows():
            f.write("|".join(str(x) for x in [index] + list(row)) + "\n")


def RunCmds(cmds):
    """Runs the command line arguments in *cmds* using *subprocess*.

    *cmds* is the argument list passed to ``subprocess.Popen``. The child's
    standard error is merged into its standard output, which is captured and
    discarded. The call blocks until the child exits and returns ``None``.

    If the command cannot be started, the exception from ``Popen`` is
    raised unchanged (there is no child to clean up). If an error occurs
    while waiting for the child, the child is killed and reaped before the
    error is re-raised.
    """
    p = subprocess.Popen(cmds, stdout=subprocess.PIPE,
                         stderr=subprocess.STDOUT)
    try:
        p.communicate()
    except Exception:
        # do not leave an orphaned child behind
        p.kill()
        p.wait()
        raise


def main():
    """
    Main body of script.

    *modelID* tuple
        Each run of `phydms` is considered one model by this script.
        Each model has unique ID which describes the preferences and the
        diversifying pressures used in the `ExpCM`.
        The ID is a tuple of the following form:
            (pref file, diversifying pressure (name of file or "None"),
            diversifying pressure type ("None", "True", or "Random"), and the
            random seed (integer or "None"))
    *models* dictionary keyed by modelID
        each value is a tuple of the following form:
            (`phydms` model string 'ExpCM_<prefs file>',
            list of additional `phydms` arguments,
            number of empirical parameters, modelprefix)

    The branch lengths are optimized using `ExpCM` without diversifying
    pressure on either the tree specified by the user or the tree inferred
    under the *GTRCAT* with `RAxML`. The `ExpCM` with diversifying pressures
    are run with this optimized tree and the branch lengths are scaled but not
    individually optimized.

    Errors raised during the `phydms` runs and the summary step are logged
    with a traceback but not re-raised.
    """

    # Parse command line arguments
    parser = phydmslib.parsearguments.PhyDMSTestdivpressureParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # create output directory if needed
    outdir = os.path.dirname(args['outprefix'])
    if outdir:
        if not os.path.isdir(outdir):
            if os.path.isfile(outdir):
                os.remove(outdir)
            os.mkdir(outdir)

    # setup files
    # file names depend on whether outprefix is directory or file
    if args['outprefix'][-1] == '/':
        logfile = "{0}log.log".format(args['outprefix'])
    else:
        logfile = "{0}.log".format(args['outprefix'])
        args['outprefix'] = '{0}_'.format(args['outprefix'])
    raxmlStandardOutputFile = '{0}ramxl_output.txt'.format(args['outprefix'])

    # Set up to log everything to logfile.
    # This is done exactly once: attaching a second FileHandler for the same
    # file to the same logger would write every record to the file twice.
    if os.path.isfile(logfile):
        os.remove(logfile)
    logging.shutdown()
    logging.captureWarnings(True)
    versionstring = phydmslib.file_io.Versions()
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        level=logging.INFO)
    logger = logging.getLogger(prog)
    logfile_handler = logging.FileHandler(logfile)
    logger.addHandler(logfile_handler)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    logfile_handler.setFormatter(formatter)

    # print some basic information
    logger.info('Beginning execution of {0} in directory {1}\n'
                .format(prog, os.getcwd()))
    logger.info('Progress is being logged to {0}\n'.format(logfile))
    logger.info('{0}\n'.format(versionstring))
    logger.info('Parsed the following command-line arguments:\n{0}\n'
                .format('\n'.join(['\t{0} = {1}'
                        .format(key, args[key]) for key in args.keys()])))

    # setup models
    filesuffixlist = ['_log.log', '_tree.newick', '_loglikelihood.txt',
                      '_modelparams.txt']
    filesuffixes = {}  # keyed by model, values are list of suffixes
    models = {}

    # set up the ExpCM without diversifying pressure
    expcmadditionalcmds = []
    nempiricalExpCM = 3
    if re.search(r'\s', args['prefsfile']):
        raise ValueError("There is a space in the preferences file name: "
                         "{0}".format(args['prefsfile']))
    prefsfilebase = os.path.splitext(os.path.basename(args['prefsfile']))[0]
    modelID = (prefsfilebase, "None", "None", "None")
    assert modelID not in filesuffixes, "Duplicate model ID"
    filesuffixes[modelID] = filesuffixlist
    models[modelID] = ('ExpCM_{0}'.format(args['prefsfile']),
                       expcmadditionalcmds,
                       nempiricalExpCM,
                       "ExpCM_" + '_'.join(str(item) for item in
                                           modelID if item != "None"))

    # set up ExpCM with diversifying pressure
    for divpressure in args["divpressure"]:
        divpressurefilebase = (os.path
                               .splitext(os.path.basename(divpressure))[0])
        modelID = (prefsfilebase, divpressurefilebase, "True", "None")
        assert modelID not in filesuffixes, "Duplicate model ID (divpressure)"
        filesuffixes[modelID] = filesuffixlist
        models[modelID] = ('ExpCM_{0}'.format(args['prefsfile']),
                           expcmadditionalcmds + ["--divpressure",
                                                  divpressure],
                           nempiricalExpCM,
                           "ExpCM_" + '_'.join(str(item) for item
                                               in modelID if item != "None"))
        # set up ExpCM with diversifying pressure randomizations
        if args["randomizations"]:
            randomFiles = randomizeDivpressure(divpressure,
                                               args["randomizations"],
                                               args["outprefix"])
            for r in randomFiles.keys():
                modelID = (prefsfilebase, divpressurefilebase,
                           "Random", r)
                assert modelID not in filesuffixes, ("Duplicate model ID "
                                                     "(divpressure)")
                filesuffixes[modelID] = filesuffixlist
                models[modelID] = ('ExpCM_{0}'.format(args['prefsfile']),
                                   expcmadditionalcmds + ["--divpressure",
                                                          randomFiles[r]],
                                   nempiricalExpCM,
                                   "ExpCM_" + '_'.join(str(item) for item in
                                                       modelID if item !=
                                                       "None"))

    # check alignment
    logger.info('Checking that the alignment {0} is valid...'
                .format(args['alignment']))
    alignment = phydmslib.file_io.ReadCodonAlignment(args['alignment'],
                                                     checknewickvalid=True)
    assert len(set([align[1] for align in alignment])) \
           == len([align[1] for align in alignment]),\
           ("Remove duplicate sequences from {0}.".format(args["alignment"]))
    logger.info('Valid alignment specifying {0} sequences of length {1}.\n'
                .format(len(alignment), len(alignment[0][1])))

    # read or build a tree
    if args["tree"]:
        logger.info("Reading tree from {0}".format(args['tree']))
    else:
        logger.info("Tree not specified.")
        try:
            logger.info("Inferring tree with RAxML using command {0}"
                        .format(args['raxml']))

            # remove pre-existing RAxML files
            raxmlOutputName = os.path.splitext(os.path.basename(
                    args["alignment"]))[0]
            raxmlOutputFiles = []
            for raxmlFile in glob.glob("RAxML_*{0}".format(raxmlOutputName)):
                if os.path.isfile(raxmlFile):
                    raxmlOutputFiles.append(raxmlFile)
                    os.remove(raxmlFile)
            if len(raxmlOutputFiles) > 0:
                logger.info('Removed the following RAxML files:\n{0}\n'
                            .format('\n'.join(['\t{0}'.format(fname) for
                                               fname in raxmlOutputFiles])))

            # run RAxmL
            raxmlCMD = [args['raxml'], '-s', args['alignment'], '-n',
                        raxmlOutputName, '-m', 'GTRCAT', '-p1', '-T', '2']
            with open(raxmlStandardOutputFile, 'w') as f:
                subprocess.check_call(raxmlCMD, stdout=f)

            # move RAxML tree to output directory and remove all other files
            for raxmlFile in glob.glob("RAxML_*{0}".format(raxmlOutputName)):
                if "bestTree" in raxmlFile:
                    args["tree"] = args["outprefix"] + 'RAxML_tree.newick'
                    os.rename(raxmlFile, args['tree'])
                    logger.info("RAxML inferred tree is now named {0}".format(
                            args['tree']))
                else:
                    os.remove(raxmlFile)
        except OSError:
            raise ValueError("The raxml command of {0} is not valid."
                             " Is raxml installed at this path?"
                             .format(args['raxml']))

    # setting up cpus
    if args['ncpus'] == -1:
        try:
            args['ncpus'] = multiprocessing.cpu_count()
        except Exception:
            raise RuntimeError("Encountered a problem trying to dynamically "
                               "determine the number of available CPUs. "
                               "Please manually specify the number of "
                               "desired CPUs with '--ncpus' and try again.")
        logger.info('Will use all %d available CPUs.\n' % args['ncpus'])
    assert args['ncpus'] >= 1, "Failed to specify valid number of CPUs"

    # Each model gets at least 1 CPU
    ncpus_per_model = {}
    nperexpcm = max(1, (args['ncpus'] - 2) // len(models))
    for modelID in models.keys():
        ncpus_per_model[modelID] = nperexpcm
        mtup = models[modelID]
        assert len(mtup) == 4  # should be 4-tuple
        models[modelID] = (mtup[0],
                           mtup[1] + ['--ncpus',
                                      str(ncpus_per_model[modelID])],
                           mtup[2], mtup[3])

    pool = {}  # holds process for model name
    started = {}  # holds whether process started for model name
    completed = {}  # holds whether process completed for model name
    outprefixes = {}  # holds outprefix for model name

    # rest of execution in try / finally
    try:

        # remove existing output files
        outfiles = []
        removed = []
        for modelID in models.keys():
            for suffix in filesuffixes[modelID]:
                fname = "{0}{1}{2}".format(args['outprefix'],
                                           models[modelID][3], suffix)
                outfiles.append(fname)
        for fname in outfiles:
            if os.path.isfile(fname):
                os.remove(fname)
                removed.append(fname)
        if removed:
            logger.info('Removed the following existing files that have names'
                        " that match the names of output files that will be "
                        'created: {0}\n'.format(', '.join(removed)))

        # first run the `ExpCM` model without diversifying pressure to
        # optimize the branch lengths
        noDivPressure = [x for x in list(models.keys()) if x[1] == "None"]
        assert len(noDivPressure) == 1, ("More than one ExpCM without "
                                         "diversifying pressure was "
                                         "specified")
        noDivPressure = noDivPressure[0]
        (model, additionalcmds, nempirical, modelname) = models[noDivPressure]
        outprefix = "{0}{1}".format(args['outprefix'], modelname)
        cmds = ['phydms', args['alignment'], args["tree"], model,
                outprefix] + additionalcmds + ["--brlen", "optimize"]
        logger.info('Starting analysis. Optimizing the branch lengths. The '
                    'command is: %s\n' % (' '.join(cmds)))
        outprefixes[noDivPressure] = outprefix
        # standard output of this run goes to a scratch file that is
        # deleted once the run has finished
        subprocessOutput = '{0}subprocess_output.txt'.format(args['outprefix'])
        with open(subprocessOutput, 'w') as f:
            subprocess.check_call(cmds, stdout=f)
        if os.path.isfile(subprocessOutput):
            os.remove(subprocessOutput)
        for fname in [outprefixes[noDivPressure] + suffix
                      for suffix in filesuffixes[noDivPressure]]:
            if not os.path.isfile(fname):
                raise RuntimeError("phydms failed to created"
                                   " expected output file {0}.".format(fname))
        logger.info('Analysis successful for {0} without diversifying '
                    'pressure.\n'.format(noDivPressure[0]))
        treeFile = '{0}_tree.newick'.format(outprefix)
        assert os.path.isfile(treeFile)

        # now run the other models
        divPressureModels = [x for x in list(models.keys()) if x[1] != "None"]
        for modelID in divPressureModels:
            (model, additionalcmds, nempirical, modelname) = models[modelID]
            outprefix = "{0}{1}".format(args['outprefix'], modelname)
            cmds = ['phydms', args['alignment'], treeFile, model,
                    outprefix] + additionalcmds + ["--brlen", "scale"]
            pool[modelID] = multiprocessing.Process(target=RunCmds,
                                                    args=(cmds,))
            outprefixes[modelID] = outprefix
            completed[modelID] = False
            started[modelID] = False
        while not all(completed.values()):
            nrunning = list(started.values()).count(True) - \
                    list(completed.values()).count(True)
            # start at most one new process per pass through the loop
            if nrunning < args['ncpus']:
                for (modelID, p) in pool.items():
                    if not started[modelID]:
                        p.start()
                        started[modelID] = True
                        break
            for (modelID, p) in pool.items():
                if started[modelID] and (not completed[modelID]) and \
                        (not p.is_alive()):  # process just completed
                    completed[modelID] = True
                    for fname in [outprefixes[modelID] + suffix for suffix
                                  in filesuffixes[modelID]]:
                        if not os.path.isfile(fname):
                            raise RuntimeError("phydms failed to created"
                                               " expected output file {0}."
                                               .format(fname))
                    # every model in the pool has a diversifying pressure
                    if modelID[2] != "True":
                        logger.info('Analysis successful for {0} with '
                                    'diversifying pressure {1} and random '
                                    'seed {2}\n'
                                    .format(modelID[0], modelID[1],
                                            modelID[3]))
                    else:
                        logger.info('Analysis successful for {0} with '
                                    'diversifying pressure {1}\n'
                                    .format(modelID[0], modelID[1]))
            time.sleep(1)

        # make sure all expected output files are there
        for fname in outfiles:
            if not os.path.isfile(fname):
                raise RuntimeError("Cannot find expected output file {0}"
                                   .format(fname))
        createModelComparisonFile(models, args["outprefix"])
        # the per-model output files and the randomized input files are
        # removed once they have been summarized
        for fname in outfiles:
            if os.path.isfile(fname):
                os.remove(fname)
        if args["randomizations"]:
            for f in glob.glob(args["outprefix"] + "randomizedFiles/*"):
                os.remove(f)
            os.rmdir(args["outprefix"] + "randomizedFiles/")

    except Exception:
        # NOTE: the error is logged with its traceback but not re-raised
        logger.exception('Terminating {0} at {1} with ERROR'
                         .format(prog, time.asctime()))
    else:
        logger.info('Successful completion of {0}'.format(prog))
    finally:
        logging.shutdown()
        for p in pool.values():
            if p.is_alive():
                p.terminate()


if __name__ == "__main__":
    # execute the analysis only when this file is run as a script
    main()
