#!python

"""Phylogenetic inference using deep mutational scanning data.

Written by Jesse Bloom."""


import os
import re
import logging
import random
import time
import operator
import multiprocessing
import warnings
import numpy
import scipy.stats
import statsmodels.sandbox.stats.multicomp
import Bio.Phylo
import phydmslib
from phydmslib.constants import INDEX_TO_AA, N_AA
import phydmslib.file_io
import phydmslib.parsearguments
import phydmslib.models
import phydmslib.treelikelihood
# Report every warning each time it is triggered rather than only once per
# location, so repeated problems stay visible.
warnings.simplefilter('always')
# ImportWarning is silenced; this filter is added after the one above, so it
# is the one consulted first for that category.
warnings.simplefilter('ignore', ImportWarning)


def treelikForDiffPrefsBySite(site, tl, prefslist, prior, modeltype,
                              othermodeltype, nograd):
    """Build single-site `TreeLikelihood` objects for fitting diffprefs.

    `site` is a 1-based codon index into `tl.alignment` and `prefslist`.
    Two likelihood objects are built on the tree of `tl`, restricted to the
    codon at `site`: one using `modeltype` (stored under ``'primary'``) and
    one using `othermodeltype` (stored under ``'secondary'``). Both take
    kappa, omega, phi and beta from the model in `tl` (or from its
    `basemodel` when `tl.model` is a `GammaDistributedOmegaModel`), with
    `prior` passed through to the model constructor.

    Returns the tuple `(site, treeliks, nograd)`, which can be passed
    directly to `fitDiffPrefsBySite`."""
    codon = slice(3 * (site - 1), 3 * site)
    sitealignment = [(head, seq[codon]) for (head, seq) in tl.alignment]

    # Decide once where the shared parameters come from.
    usegamma = isinstance(tl.model,
                          phydmslib.models.GammaDistributedOmegaModel)
    if usegamma:
        source = tl.model.basemodel
        sitefree = ['zeta', 'omega']
    else:
        source = tl.model
        sitefree = ['zeta']
    # exact type match is intended here, so subclasses are rejected
    assert type(source) in [phydmslib.models.ExpCM,
                            phydmslib.models.ExpCM_empirical_phi]

    treeliks = {}
    for (mname, mtype) in (('primary', modeltype),
                           ('secondary', othermodeltype)):
        sitemodel = mtype([prefslist[site - 1]], prior, source.kappa,
                          source.omega, source.phi,
                          origbeta=source.beta,
                          freeparams=list(sitefree))
        if usegamma:
            # re-wrap so omega is again drawn from the gamma categories,
            # with no additional free parameters on the wrapper itself
            sitemodel = phydmslib.models.GammaDistributedOmegaModel(
                    sitemodel, tl.model.ncats, tl.model.alpha_omega,
                    tl.model.beta_omega, freeparams=[])
        assert sitemodel.freeparams == ['zeta']
        treeliks[mname] = phydmslib.treelikelihood.TreeLikelihood(
                tl.tree, sitealignment, sitemodel,
                branchScale=tl.model.branchScale)
    return (site, treeliks, nograd)


def fitDiffPrefsBySite(argtup):
    """Fit the differential preferences of one site (pool-friendly).

    `argtup` is the `(site, treeliks, nograd)` tuple returned by
    `treelikForDiffPrefsBySite`. The ``'primary'`` likelihood is maximized
    first; if that raises `RuntimeError`, a warning is issued and the
    ``'secondary'`` likelihood is maximized instead.

    Returns `(halfabssum, line)`, where `halfabssum` is half the summed
    absolute differential preference at the site and `line` is a
    tab-separated string holding the site, the `N_AA` differential
    preferences, and `halfabssum`."""
    site, treeliks, nograd = argtup
    fitkwargs = {'optimize_brlen': False, 'approx_grad': nograd}
    try:
        treelik = treeliks['primary']
        treelik.maximizeLikelihood(**fitkwargs)
    except RuntimeError as err:
        warnings.warn("Optimization failed with primary model; "
                      "trying secondary model.\n"
                      "Error message:\n{0}".format(str(err)))
        treelik = treeliks['secondary']
        treelik.maximizeLikelihood(**fitkwargs)

    # difference between the fitted and the original preferences
    fitted = treelik.model
    diffprefs = fitted.pi[0] - fitted.origpi[0]
    halfabssum = 0.5 * numpy.absolute(diffprefs).sum()

    fields = [str(site)]
    fields.extend('{0:.4f}'.format(diffprefs[a]) for a in range(N_AA))
    fields.append('{0:.4f}'.format(halfabssum))
    return (halfabssum, '\t'.join(fields))


def treeliksForOmegaBySite(site, tl, prefslist, fixsyn, nograd):
    """Build the single-site `TreeLikelihood` pair for fitting omega.

    `site` is a 1-based codon index. Two likelihood objects are built on
    the tree of `tl`, restricted to the codon at `site`: ``'fix'`` with
    omega held at 1, and ``'fit'`` with omega free. Unless `fixsyn` is
    true, ``'mu'`` is also free in both. The site model is an `ExpCM`
    (requires `prefslist`) or a `YNGKP_M0` (requires `prefslist` to be
    `None`), chosen from the type of the model in `tl`; any other model
    type raises `ValueError`.

    Returns the tuple `(site, treeliks, nograd)`, which can be passed
    directly to `fitOmegaBySite`."""
    # For a distribution model the type is decided by the wrapped base model.
    if isinstance(tl.model, phydmslib.models.DistributionModel):
        assert isinstance(tl.model,
                          phydmslib.models.GammaDistributedOmegaModel)
        modeltotype = tl.model.basemodel
    else:
        modeltotype = tl.model

    codon = slice(3 * (site - 1), 3 * site)
    sitealignment = [(head, seq[codon]) for (head, seq) in tl.alignment]
    sharedfree = [] if fixsyn else ['mu']

    def makesitemodel(free):
        """Return the one-site model with free parameters `free`."""
        if isinstance(modeltotype, phydmslib.models.ExpCM):
            assert len(prefslist) == tl.nsites
            return phydmslib.models.ExpCM(
                    [prefslist[site - 1]],
                    kappa=tl.model.kappa, omega=1.0,
                    beta=tl.model.beta,
                    mu=1.0, phi=modeltotype.phi,
                    freeparams=free)
        if isinstance(modeltotype, phydmslib.models.YNGKP_M0):
            assert prefslist is None
            return phydmslib.models.YNGKP_M0(
                    modeltotype.e_pw,
                    1, kappa=tl.model.kappa,
                    omega=1.0, mu=1.0,
                    freeparams=free)
        raise ValueError("Invalid modeltotype {0}".format(modeltotype))

    treeliks = {}
    for (name, extra) in (('fix', []), ('fit', ['omega'])):
        sitemodel = makesitemodel(sharedfree + extra)
        treeliks[name] = phydmslib.treelikelihood.TreeLikelihood(
                tl.tree, sitealignment, sitemodel,
                branchScale=tl.model.branchScale)
    return (site, treeliks, nograd)


def fitOmegaBySite(argtup):
    """Fits omega for a site, for use in multiprocessing pool.

    `argtup` is return value of `treeliksForOmegaBySite`, i.e. the tuple
    `(site, treeliks, nograd)` where `treeliks` has the keys ``'fix'``
    (omega held fixed) and ``'fit'`` (omega free).

    Returns tuple `(pval, omega, resultstr)` where `pval` is the P-value
    from a likelihood-ratio test (chi-squared, 1 degree of freedom) of the
    ``'fit'`` model against the ``'fix'`` model, `omega` is the fitted
    omega, and `resultstr` is tab-separated line with site, omega, P, dLnL.

    Raises `RuntimeError` if the ``'fit'`` log likelihood is more than 0.01
    below the ``'fix'`` log likelihood, which should not happen because the
    ``'fit'`` model has the extra free parameter.
    """
    (site, treeliks, nograd) = argtup
    fix = treeliks['fix']
    fit = treeliks['fit']
    if fix.model.freeparams:  # can be no params if using fixsyn
        fix.maximizeLikelihood(optimize_brlen=False, approx_grad=nograd)
    # starting fit from mu of fix should ensure things don't get worse
    # NOTE(review): when fixsyn is used 'mu' is not among the free
    # parameters of the 'fit' model; confirm `updateParams` accepts it.
    fit.updateParams({'mu': fix.model.mu})
    fit.maximizeLikelihood(optimize_brlen=False, approx_grad=nograd)
    omega = fit.model.omega
    dLnL = fit.loglik - fix.loglik
    if dLnL < -0.01:
        # the small tolerance absorbs optimizer noise; anything beyond it
        # means the less constrained model fit worse, so report the details
        msg = '\nsite {0}'.format(site)
        for name in ['fix', 'fit']:
            msg += '\nloglik_{0} {1}'.format(name, treeliks[name].loglik)
            msg += '\nmu_{0} {1}'.format(name, treeliks[name].model.mu)
            msg += '\nomega_{0} {1}'.format(name, treeliks[name].model.omega)
        raise RuntimeError("loglik lower for fit omega than for fixed "
                           "omega:{0}".format(msg))
    # clip tiny negative dLnL to zero so the test statistic is valid
    p = scipy.stats.chi2.sf(2.0 * max(0, dLnL), df=1)
    return (p, omega, '{0}\t{1:.3f}\t{2:.3g}\t{3:.3f}'.format(
            site, omega, p, dLnL))


def main():
    """Main body of script.

    Parses the command line with `phydmslib.parsearguments.PhyDMSParser`,
    then:

    1. creates the output directory and removes stale output files;
    2. sets up logging to the console and to the log file;
    3. reads the codon alignment and builds the substitution model
       (a YNGKP variant or an ExpCM, optionally with gamma-distributed
       omega / beta or a diversifying-pressure term);
    4. reads the tree, midpoint roots it, and raises short branches to
       the `minbrlen` argument;
    5. maximizes the tree likelihood and writes the log likelihood,
       model parameters and optimized tree;
    6. optionally fits a site-specific omega and / or site-specific
       differential preferences, using a multiprocessing pool when more
       than one CPU is requested.

    Any exception is logged with its traceback and then re-raised.
    """

    # Parse command line arguments
    parser = phydmslib.parsearguments.PhyDMSParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # create output directory if needed
    outdir = os.path.dirname(args['outprefix'])
    if outdir:
        if not os.path.isdir(outdir):
            if os.path.isfile(outdir):
                os.remove(outdir)
            os.mkdir(outdir)

    # output files, remove if they already exist
    underscore = '' if args['outprefix'][-1] == '/' else '_'
    logfile = '{0}{1}log.log'.format(args['outprefix'], underscore)
    modelparamsfile = '{0}{1}modelparams.txt'.format(args['outprefix'],
                                                     underscore)
    loglikelihoodfile = '{0}{1}loglikelihood.txt'.format(args['outprefix'],
                                                         underscore)
    treefile = '{0}{1}tree.newick'.format(args['outprefix'], underscore)
    omegafile = '{0}{1}omegabysite.txt'.format(args['outprefix'], underscore)
    diffprefsfile = '{0}{1}diffprefsbysite.txt'.format(args['outprefix'],
                                                       underscore)
    to_remove = [modelparamsfile, loglikelihoodfile, treefile, omegafile,
                 logfile, diffprefsfile]
    for f in to_remove:
        if os.path.isfile(f):
            os.remove(f)

    # Set up to log everything to logfile.
    logging.shutdown()
    logging.captureWarnings(True)
    versionstring = phydmslib.file_io.Versions()
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        level=logging.INFO)
    logger = logging.getLogger(prog)
    warning_logger = logging.getLogger("py.warnings")
    logfile_handler = logging.FileHandler(logfile)
    logger.addHandler(logfile_handler)
    warning_logger.addHandler(logfile_handler)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    logfile_handler.setFormatter(formatter)

    # begin execution
    try:
        # print some basic information
        logger.info('Beginning execution of {0} in directory {1}\n'.format(
                prog, os.getcwd()))
        logger.info('Progress is being logged to {0}\n'.format(logfile))
        logger.info("{0}\n".format(versionstring))
        logger.info('Parsed the following arguments:\n{0}\n'.format(
                    '\n'.join(['\t{0} = {1}'
                              .format(*tup) for tup in args.items()])))
        logger.info('Random number seed: {0}\n'.format(args['seed']))
        random.seed(args['seed'])

        # read alignment
        logger.info('Reading alignment from {0}'.format(args['alignment']))
        alignment = phydmslib.file_io.ReadCodonAlignment(args['alignment'],
                                                         checknewickvalid=True)
        logger.info(('Read {0} aligned sequences from {1}, each consisting '
                     'of {2} codons.\n')
                    .format(len(alignment), args['alignment'],
                            len(alignment[0][1]) // 3))
        seqnames = set([head for (head, seq) in alignment])

        # process the substitution model
        # raw string so that `\d` reaches the regex engine unchanged
        yngkp_match = re.compile(r'^YNGKP_(?P<modelvariant>M\d+)$')
        if (isinstance(args['model'],
                       str) and yngkp_match.search(args['model'])):
            for argname in ['randprefs', 'avgprefs',
                            'divpressure', 'diffprefsbysite']:
                assert not args[argname], ("'--{0}' incompatible with YNKGP"
                                           .format(argname))
            assert not args['gammaomega'], ("Can't use --gammaomega with "
                                            "YNGKP model; use 'model' of "
                                            "YNGKP_M5 to achieve the same "
                                            "result.")
            assert not args['gammabeta'], ("Can't use --gammabeta with YNGKP "
                                           "model.")
            modelvariant = (yngkp_match
                            .search(args['model'])
                            .group('modelvariant'))
            logger.info('Codon substitution model with be {0} version '
                        '(Yang, Nielsen, Goldman, & Pederson. '
                        'Genetics. 155:431-449).'.format(modelvariant))
            # e_pw[p][w] counts nucleotide w at codon position p over the
            # whole alignment, then is turned into a frequency
            e_pw = numpy.ones((3, phydmslib.constants.N_NT), dtype='float')
            for p in range(3):
                for (w, nt) in phydmslib.constants.INDEX_TO_NT.items():
                    e_pw[p][w] = sum([list(seq)[p::3].count(nt) for (head, seq)
                                      in alignment])
            # normalize each codon position by its own total so that every
            # row sums to one even if the positions have different counts
            e_pw = e_pw / e_pw.sum(axis=1, keepdims=True)
            nsites = int(len(alignment[0][1]) / 3)
            model = phydmslib.models.YNGKP_M0(e_pw, nsites,
                                              freeparams=['mu', 'omega',
                                                          'kappa'])
            if modelvariant == 'M0':
                pass
            elif modelvariant == 'M5':
                logger.info('For this {0} model, omega will be drawn from '
                            '{1} gamma-distributed categories.'
                            .format(args['model'], args['ncats']))
                model = phydmslib.models.GammaDistributedOmegaModel(
                        model, args['ncats'])
            else:
                raise ValueError("Invalid variant {0} in {1}".format(
                        modelvariant, args['model']))
            prefslist = None
        elif (isinstance(args['model'], tuple) and len(args['model']) == 2 and
                args['model'][0] == 'ExpCM'):
            prefsfile = args['model'][1]
            logger.info(('The model will be an ExpCM informed by site-specific'
                         ' amino-acid preferences in {0}').format(prefsfile))
            for (argname, desc) in [('avgprefs', 'averaged'),
                                    ('randprefs', 'randomized')]:
                if args[argname]:
                    logger.info('Preferences will be {0} across sites.'
                                .format(desc))
            prefs = phydmslib.file_io.readPrefs(prefsfile,
                                                minpref=args['minpref'],
                                                avgprefs=args['avgprefs'],
                                                randprefs=args['randprefs'],
                                                seed=args['seed'])
            sites = sorted(prefs.keys())
            prefslist = [prefs[r] for r in sites]  # convert from dict to list
            assert (len(prefs) ==
                    len(alignment[0][1]) // 3), ("The number of preferences "
                                                 "in {0} does not match the "
                                                 "number of codon sites in "
                                                 "the alignment"
                                                 .format(prefsfile))
            logger.info('Successfully read site-specific amino-acid '
                        'preferences for all {0} sites.\n'.format(len(prefs)))
            freeparams = ['mu', 'kappa', 'omega', 'beta']
            if args['fitphi']:
                assert not args['divpressure'], (
                        "Can't use --divpressure and --fitphi")
                freeparams.append('eta')
                logger.info('Nucleotide frequency parameters phi will '
                            'be optimized by maximum likelihood.\n')
                model = phydmslib.models.ExpCM(prefslist,
                                               freeparams=freeparams)
            else:
                # g holds the alignment-wide nucleotide frequencies
                g = numpy.ndarray(phydmslib.constants.N_NT, dtype='float')
                for (w, nt) in phydmslib.constants.INDEX_TO_NT.items():
                    g[w] = sum([seq.count(nt) for (head, seq) in alignment])
                assert ((len(alignment) * len(prefs) * 3)
                        == (g.sum() + sum([seq.count('-') for (head, seq) in
                                           alignment]))), ("Alignment contains"
                                                           " invalid "
                                                           "nucleotide "
                                                           "characters")
                g /= g.sum()
                logger.info('Nucleotide frequency parameters phi will be set '
                            'so stationary state matches alignment nucleotide '
                            'frequencies of {0}\n'
                            .format(', '.join(['{0} = {1:.3f}'
                                    .format(nt, g[w]) for (w, nt) in
                                    phydmslib.constants.INDEX_TO_NT.items()])))
                if not args['divpressure']:
                    model = (phydmslib.models
                             .ExpCM_empirical_phi(prefslist, g,
                                                  freeparams=freeparams))
                else:
                    # key must be spelled 'gammaomega' to match the parsed
                    # argument, otherwise the check is silently skipped
                    for otherarg in ['omegabysite', 'diffprefsbysite',
                                     'gammaomega', 'gammabeta']:
                        if otherarg in args and args[otherarg]:
                            raise ValueError("Can't use --divpressure and "
                                             "--{0}".format(otherarg))
                    freeparams.append('omega2')
                    divpressure = phydmslib.file_io.readDivPressure(
                            args['divpressure'])
                    assert set(prefs.keys()) == set(divpressure.keys()), (
                            "The sites in {0} are different from {1}."
                            .format(args['divpressure'], prefsfile))
                    logger.info(('Read diversifying pressure from {0} '
                                 'for all sites.').format(args['divpressure']))
                    divPressureSites = list(divpressure.keys())
                    divpressure = numpy.array([divpressure[x] for
                                               x in sorted(divPressureSites)])
                    model = phydmslib.models.ExpCM_empirical_phi_divpressure(
                            prefslist, g, divpressure, freeparams=freeparams)
            if args['gammaomega']:
                assert not args['gammabeta'], ("Can't use --gammabeta with "
                                               "--gammaomega")
                logger.info(('Omega will be drawn from {0} gamma-distributed '
                            'categories.\n').format(args['ncats']))
                model = phydmslib.models.GammaDistributedOmegaModel(
                        model, args['ncats'])
            if args['gammabeta']:
                assert not args['gammaomega'], ("Can't use --gammabeta with "
                                                "--gammaomega")
                logger.info(('Beta will be drawn from {0} gamma-distributed '
                            'categories.\n').format(args['ncats']))
                model = phydmslib.models.GammaDistributedBetaModel(
                        model, args['ncats'])
        else:
            raise ValueError("Invalid model of {0}".format(args['model']))
        if args['initparams']:
            print("Initializing model parameters from {0}".format(
                    args['initparams']))
            # file holds one `param = value` pair per line; only parameters
            # that are free in the model are used
            with open(args['initparams']) as f:
                paramvalues = {}
                for line in f:
                    (param, value) = line.split('=')
                    param = param.strip()
                    if param in model.freeparams:
                        paramvalues[param.strip()] = float(value)
            assert paramvalues, "No values to initialize."
            print("Initializing the following values:\n\t{0}\n"
                  .format('\n\t'.join(['{0} = {1:.5f}'.format(param, value)
                                       for (param, value) in
                                       sorted(paramvalues.items())])))
            model.updateParams(paramvalues)

        # read tree
        logger.info("Reading tree from {0}".format(args['tree']))
        tree = Bio.Phylo.read(args['tree'], 'newick')
        tipnames = set([clade.name for clade in tree.get_terminals()])
        assert len(tipnames) == tree.count_terminals(), "non-unique tip names?"
        assert tipnames == seqnames, (
            "Names in alignment do not match those in tree.\nSequences in "
            "alignment but NOT in tree:\n\t{0}\n Sequences in tree but NOT in "
            "alignment:\n\t{1}".format('\n\t'.join(seqnames - tipnames),
                                       '\n\t'.join(tipnames - seqnames)))
        logger.info('Tree has {0} tips.'.format(len(tipnames)))
        tree.root_at_midpoint()
        assert tree.is_bifurcating(), "Tree is not bifurcating: cannot handle"
        nadjusted = 0
        for node in tree.get_terminals() + tree.get_nonterminals():
            if (node.branch_length is None) and (node == tree.root):
                node.branch_length = args['minbrlen']
            elif node.branch_length < args['minbrlen']:
                nadjusted += 1
                node.branch_length = args['minbrlen']
        logger.info('Adjusted {0} branch lengths up to minbrlen {1}\n'.format(
                nadjusted, args['minbrlen']))

        # set up tree likelihood
        logger.info('Initializing TreeLikelihood..')
        tl = phydmslib.treelikelihood.TreeLikelihood(tree, alignment, model)
        logger.info('TreeLikelihood initialized.')

        # maximize likelihood
        printfunc = None
        if args['opt_details']:
            printfunc = logger.info
        logger.info('Maximizing log likelihood (initially {0:.2f}).'.format(
                tl.loglik))
        if args['brlen'] == 'scale':
            optimize_brlen = False
            logger.info("Branch lengths will be scaled but not optimized "
                        "individually.")
        elif args['brlen'] == 'optimize':
            logger.info("Branch lengths will be optimized individually.")
            optimize_brlen = True
        else:
            raise ValueError("Invalid brlen {0}".format(args['brlen']))
        if args['profile']:
            import cProfile
            import pstats
            pstatsfile = '{0}_pstats'.format(args['outprefix'])
            logger.info('Maximizing with cProfile (probably slower).')
            logger.info('Profile stats will be in to {0}'.format(pstatsfile))
            maxresult = []

            def wrapper(maxresult):  # wrapper to get return value frm cProfile
                maxresult.append(tl.maximizeLikelihood(
                        optimize_brlen=optimize_brlen,
                        approx_grad=args['nograd'],
                        printfunc=printfunc))
            cProfile.runctx('wrapper(maxresult)', globals(), locals(),
                            pstatsfile)
            maxresult = maxresult[0]
            for psort in ['cumulative', 'tottime']:
                fname = '{0}_{1}.txt'.format(pstatsfile, psort)
                logger.info('Writing profile stats sorted by {0} to {1}'
                            .format(psort, fname))
                # context manager closes the report even if pstats raises
                with open(fname, 'w') as f:
                    s = pstats.Stats(pstatsfile, stream=f)
                    s.strip_dirs()
                    s.sort_stats(psort)
                    s.print_stats()
        else:
            maxresult = tl.maximizeLikelihood(optimize_brlen=optimize_brlen,
                                              approx_grad=args['nograd'],
                                              printfunc=printfunc)
        logger.info('Maximization complete:\n\t{0}'.format(
                maxresult.replace('\n', '\n\t')))
        logger.info('Optimized log likelihood is {0:.2f}.'.format(tl.loglik))
        logger.info('Writing log likelihood to {0}'.format(loglikelihoodfile))
        with open(loglikelihoodfile, 'w') as f:
            f.write('log likelihood = {0:.2f}'.format(tl.loglik))
        params = ('\n\t'.join(['{0} = {1:6g}'
                  .format(p, pvalue) for (p, pvalue) in
                          sorted(tl.model.paramsReport.items())]))
        logger.info('Model parameters after optimization:\n\t{0}'
                    .format(params))
        logger.info('Writing model parameters to {0}'.format(modelparamsfile))
        with open(modelparamsfile, 'w') as f:
            f.write(params.replace('\t', ''))
        logger.info('Writing the optimized tree to {0}\n'.format(treefile))
        Bio.Phylo.write(tl.tree, treefile, 'newick')

        # get number of cpus for multiprocessing
        if args['ncpus'] == -1:
            ncpus = multiprocessing.cpu_count()
        else:
            ncpus = args['ncpus']
        assert ncpus >= 1, "{0} CPUs specified".format(ncpus)

        # optimize a different omega for each site
        if args['omegabysite']:
            logger.info("\nFitting a different omega to each site to "
                        "detect diversifying selection.")
            if args['omegabysite_fixsyn']:
                fixsynstr = 'Synonymous rate will be fixed across sites.'
            else:
                fixsynstr = 'Will fit different synonymous rate for each site.'
            logger.info(fixsynstr)
            logger.info("Fitting with {0} CPUs...".format(ncpus))
            # We use a multiprocessing pool if more than 1 CPU.
            # To save memory, use pool sizes equal to number of CPUs
            omegaresults = []
            for firstsite in range(1, tl.nsites + 1, ncpus):
                currentsites = range(firstsite,
                                     min(tl.nsites + 1,
                                         firstsite + ncpus))
                if ncpus > 1:
                    pool = multiprocessing.Pool(ncpus)
                    mapfunc = pool.imap
                else:
                    pool = None
                    mapfunc = map
                try:
                    treelks = [treeliksForOmegaBySite(
                                    site,
                                    tl,
                                    prefslist,
                                    args['omegabysite_fixsyn'],
                                    args['nograd'])
                               for site in currentsites]
                    omegaresults += mapfunc(fitOmegaBySite, treelks)
                finally:
                    # release the workers even if a site fit raises
                    if pool is not None:
                        pool.close()
                        pool.join()
            pvalues = numpy.array([tup[0] for tup in omegaresults])
            omegas = numpy.array([tup[1] for tup in omegaresults])
            omega_strs = [tup[2] for tup in omegaresults]
            # compute Q-values separately for each sign of omega
            qvalues = numpy.ones(pvalues.shape)
            for op in [operator.ge, operator.le]:
                pvalues_sign = numpy.where(op(omegas, 1), pvalues, 1)
                qvalues_sign = (statsmodels.sandbox.stats.multicomp
                                .multipletests(pvalues_sign, method='fdr_bh')
                                [1])
                qvalues = numpy.minimum(qvalues, qvalues_sign)
            # report sites ordered by increasing Q-value
            sortindex = numpy.argsort(qvalues)
            qvalues = numpy.take(qvalues, sortindex)
            omega_strs = numpy.take(omega_strs, sortindex)
            omegaresults = ['{0}\t{1:.3g}'.format(omega_str, q) for
                            (omega_str, q) in zip(omega_strs, qvalues)]
            logger.info("Completed fitting the site-specific omega values.")
            logger.info("Writing results to {0}\n".format(omegafile))
            with open(omegafile, 'w') as f:
                f.write('# Omega fit to each site after fixing tree and '
                        'all other parameters.\n# Fits compared to null model '
                        'of omega = 1.\n# P-values NOT corrected for multiple '
                        'testing, so consider Q-values too.\n# Q-values '
                        'computed separately for omega > and < 1 # {0}\n#\n'
                        'site\tomega\tP\tdLnL\tQ\n{1}'
                        .format(fixsynstr, '\n'.join(omegaresults)))

        # optimize differential preferences for each site
        if args['diffprefsbysite']:
            logger.info("\nFitting differential preferences for each site to "
                        "detect differential selection.")
            # the other implementation is the fallback if the chosen one
            # fails to optimize (see `fitDiffPrefsBySite`)
            if args['fitprefsmethod'] == 1:
                modeltype = phydmslib.models.ExpCM_fitprefs
                othermodeltype = phydmslib.models.ExpCM_fitprefs2
            elif args['fitprefsmethod'] == 2:
                modeltype = phydmslib.models.ExpCM_fitprefs2
                othermodeltype = phydmslib.models.ExpCM_fitprefs
            else:
                raise ValueError("Invalid fitprefsmethod")
            logger.info("For the fitting, using {0} implementation "
                        "as specified by fitprefsmethod = {1}"
                        .format(modeltype.__name__, args['fitprefsmethod']))
            logger.info("Fitting with {0} CPUs...".format(ncpus))
            #  We use a multiprocessing pool if more than 1 CPU.
            #  To save memory, use pool sizes equal to number of CPUs
            diffprefresults = []
            for firstsite in range(1, tl.nsites + 1, ncpus):
                currentsites = range(firstsite, min(tl.nsites + 1,
                                     firstsite + ncpus))
                if ncpus > 1:
                    pool = multiprocessing.Pool(ncpus)
                    mapfunc = pool.imap
                else:
                    pool = None
                    mapfunc = map
                try:
                    diffprefresults += mapfunc(
                        fitDiffPrefsBySite, [
                            treelikForDiffPrefsBySite(site, tl, prefslist,
                                                      args['diffprefsprior'],
                                                      modeltype,
                                                      othermodeltype,
                                                      args['nograd'])
                            for site in currentsites])
                finally:
                    # release the workers even if a site fit raises
                    if pool is not None:
                        pool.close()
                        pool.join()
            logger.info("Completed fitting site-specific "
                        "differential preferences.")
            logger.info("Writing results to {0}\n".format(diffprefsfile))
            with open(diffprefsfile, 'w') as f:
                f.write(('# Differential preferences fit to each site.\n'
                         '# Regularizing prior: {0}\n#\n'
                         .format(args['diffprefsprior']))
                        + 'site\t'
                        + ('\t'.join(['dpi_{0}'
                           .format(INDEX_TO_AA[a]) for a in range(N_AA)]))
                        + '\thalf_sum_abs_dpi\n')
                # largest half-sum of absolute differential preference first
                f.write('\n'.join([tup[1] for tup in sorted(diffprefresults,
                        reverse=True)]))

    except Exception:
        logger.exception('Terminating {0} at {1} with '
                         'ERROR'.format(prog, time.asctime()))
        raise
    else:
        logger.info('Successful completion of {0}'.format(prog))
    finally:
        logging.shutdown()


# Run `main` only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()  # run the script
