#!python

"""Computes amino-acid preferences."""


import os
import re
import sys
import time
import logging
import multiprocessing
import natsort
import pandas
from dms_tools2 import CODONS, AAS, AAS_WITHSTOP, CODON_TO_AA
import dms_tools2.utils
import dms_tools2.parseargs
import dms_tools2.prefs



def _findCountsFile(indir, name, countsuffix, errmsg):
    """Return the path of counts file `name` inside `indir`.

    Tries ``indir/name`` first, then ``indir/name + countsuffix``.

    Raises `ValueError` with message `errmsg` if neither is a file.
    """
    fname = os.path.join(indir, name)
    if os.path.isfile(fname):
        return fname
    if os.path.isfile(fname + countsuffix):
        return fname + countsuffix
    raise ValueError(errmsg)


def _avgMutFreqs(df):
    """Return array of the column means of `mutfreq1nt`, `mutfreq2nt`,
    and `mutfreq3nt` (in that order) in data frame `df`."""
    cols = ['mutfreq{0}nt'.format(nmuts + 1) for nmuts in range(3)]
    return df[cols].mean().values


def _sitePriors(wt, charlist, excludestop, error_model, avgmu,
        avgepsilon, avgrho):
    """Build the unscaled prior parameters for a site with wildtype codon `wt`.

    Returns a dict keyed by `mur_prior_params`, `epsilonr_prior_params`,
    and `rhor_prior_params`; each value is a dict keyed by the characters
    in `charlist`. Every prior starts with 1 on the wildtype amino acid and
    0 elsewhere, and then mass is moved from the wildtype to the amino acid
    encoded by each mutant codon.

    `avgepsilon` is only read when `error_model` is `same` or `different`,
    and `avgrho` only when it is `different`; both are indexed by
    (number of nucleotide mutations - 1).
    """
    wtaa = CODON_TO_AA[wt]
    priors = {}
    for prior in ['mur_prior_params',
                  'epsilonr_prior_params',
                  'rhor_prior_params']:
        priors[prior] = dict([(aa, 0.0) for aa in charlist])
        priors[prior][wtaa] = 1.0
    avgmu_percodon = avgmu / float(len(CODONS))

    # number of codons that differ from `wt` at exactly m nucleotides
    nchars_with_m = dict([(m, 0) for m in range(4)])
    for x in CODONS:
        m = sum([xi != wti for (xi, wti) in zip(x, wt)])
        nchars_with_m[m] += 1

    for x in CODONS:
        if x == wt:
            continue
        aa = CODON_TO_AA[x]
        if aa == '*' and excludestop == 'yes':
            continue
        priors['mur_prior_params'][aa] += avgmu_percodon
        priors['mur_prior_params'][wtaa] -= avgmu_percodon
        # m is the number of nucleotide *differences*, so that `m - 1`
        # selects the rate for m-nucleotide mutations
        m = sum([xi != wti for (xi, wti) in zip(x, wt)])
        if error_model in ['different', 'same']:
            y = avgepsilon[m - 1] / nchars_with_m[m]
            priors['epsilonr_prior_params'][aa] += y
            priors['epsilonr_prior_params'][wtaa] -= y
        if error_model == 'different':
            y = avgrho[m - 1] / nchars_with_m[m]
            priors['rhor_prior_params'][aa] += y
            priors['rhor_prior_params'][wtaa] -= y
    return priors


def main():
    """Main body of script.

    Parses the command-line arguments, reads the pre- and post-selection
    (and optionally error-control) counts files, infers the preferences
    with the method given by ``args['method']`` (`ratio` or `bayesian`),
    and writes them to ``<outdir>/<name>_prefs.csv``. Progress is logged
    to ``<outdir>/<name>.log``.

    If all output files already exist and ``--use_existing`` is `yes`,
    the script exits immediately with status 0.

    Any exception raised after the logger is created is written to the
    log, existing non-log output files are deleted, and the exception is
    not re-raised.
    """

    parser = dms_tools2.parseargs.prefsParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # set up names of output files
    if args['outdir']:
        if not os.path.isdir(args['outdir']):
            os.mkdir(args['outdir'])
    else:
        args['outdir'] = ''
    filesuffixes = {
            'log':'.log',
            'prefs':'_prefs.csv',
            }
    files = dict([(f, os.path.join(args['outdir'], '{0}{1}'.format(
            args['name'], s))) for (f, s) in filesuffixes.items()])

    # do we need to proceed?
    if args['use_existing'] == 'yes' and all(map(
                os.path.isfile, files.values())):
        print("Output files already exist and '--use_existing' is 'yes', "
              "so exiting with no further action.")
        sys.exit(0)

    logger = dms_tools2.utils.initLogger(files['log'], prog, args)

    # everything below runs inside try / except / else / finally so that
    # failures are recorded in the log file
    try:

        assert dms_tools2.parseargs.checkName(args['name'], 'name')

        # remove expected output files if they already exist
        for (ftype, f) in files.items():
            if os.path.isfile(f) and ftype != 'log':
                logger.info("Removing existing file {0}".format(f))
                os.remove(f)

        # read in the counts files
        if not args['indir']:
            args['indir'] = ''
        else:
            assert os.path.isdir(args['indir']), "No --indir {0}".format(
                    args['indir'])
        if args['chartype']:
            countsuffix = '_codoncounts.csv'
        else:
            raise ValueError("Invalid chartype")
        counts = {}
        for ctype in ['pre', 'post']:
            fname = _findCountsFile(args['indir'], args[ctype], countsuffix,
                    "Missing file for --{0}".format(ctype))
            logger.info("Reading {0}-selection counts from {1}".format(
                    ctype, fname))
            counts[ctype] = pandas.read_csv(fname)
        if args['err']:
            ferr = {}
            for (i, ctype) in enumerate(['pre', 'post']):
                ferr[ctype] = _findCountsFile(args['indir'], args['err'][i],
                        countsuffix,
                        "Missing file {0} for --err".format(i + 1))
            # both --err entries resolving to one file means a single
            # shared error control
            if len(set(map(os.path.realpath, ferr.values()))) == 1:
                error_model = 'same'
                logger.info("Reading error-control counts from {0}"
                        .format(ferr['pre']))
                counts['err'] = pandas.read_csv(ferr['pre'])
            else:
                error_model = 'different'
                for (ctype, f) in ferr.items():
                    logger.info("Reading {0}-selection error-control "
                            "counts from {1}".format(ctype, f))
                    counts['err{0}'.format(ctype)] = pandas.read_csv(f)
        else:
            error_model = 'none'

        # get sites and wildtype identities, sorted by site
        sites = wts = None
        for c in counts.values():
            (csites, cwts) = zip(*natsort.realsorted(zip(
                    c['site'].values, c['wildtype'].values)))
            if sites is None:
                sites = csites
                wts = cwts
            else:
                assert sites == csites, "different sets of sites"
                assert wts == cwts, "different wildtype identities"
        assert len(sites) == len(set(sites)), "non-unique sites"
        logger.info("Read counts for {0} sites.".format(len(sites)))
        logger.info("Here are sites and wildtype identities:\n\t{0}\n".format(
                '\n\t'.join(['{0}\t{1}'.format(r, wt) for (r, wt) in zip(
                sites, wts)])))
        if args['excludestop'] == 'yes' and args['chartype'] == 'codon_to_aa':
            if CODON_TO_AA[wts[-1]] == '*':
                sites = sites[ : -1]
                wts = wts[ : -1]
                logger.info("Excluding the last site as it a stop codon.\n")
            # `wts` holds codons here, so translate before looking for stops
            if any([CODON_TO_AA[wt] == '*' for wt in wts]):
                raise ValueError("wildtype of '*' for `--excludestop yes`")

        # get list of characters
        if args['chartype'] == 'codon_to_aa':
            if args['excludestop'] == 'yes':
                charlist = AAS
            elif args['excludestop'] == 'no':
                charlist = AAS_WITHSTOP
            else:
                raise ValueError("Invalid --excludestop")
        else:
            raise ValueError("Invalid chartype")

        if args['method'] == 'ratio':
            logger.info("Computing preferences as normalized enrichment "
                    "ratios...")
            if args['chartype'] == 'codon_to_aa':
                pre = dms_tools2.utils.codonToAACounts(counts['pre'])
                post = dms_tools2.utils.codonToAACounts(counts['post'])
                if error_model == 'same':
                    errpre = errpost = dms_tools2.utils.codonToAACounts(
                            counts['err'])
                elif error_model == 'different':
                    errpre = dms_tools2.utils.codonToAACounts(
                            counts['errpre'])
                    errpost = dms_tools2.utils.codonToAACounts(
                            counts['errpost'])
                else:
                    assert error_model == 'none'
                    errpre = errpost = None
                wts = [CODON_TO_AA[wt] for wt in wts]
            else:
                raise ValueError("invalid chartype")
            prefs = dms_tools2.prefs.inferPrefsByRatio(charlist, sites,
                    wts, pre, post, errpre, errpost, int(args['pseudocount']))
            logger.info("Writing preferences to {0}".format(files['prefs']))
            prefs.to_csv(files['prefs'], index=False)

        elif args['method'] == 'bayesian':
            logger.info("Setting up for Bayesian inference of the prefs")

            # compute mutation rates for priors
            avgepsilon = avgrho = None
            if args['chartype'] == 'codon_to_aa':
                for ctype in list(counts.keys()):
                    counts[ctype] = dms_tools2.utils.annotateCodonCounts(
                            counts[ctype])
                avgmu = _avgMutFreqs(counts['pre']).sum()
                if error_model == 'none':
                    logger.info("Average mutation rate:\n\t{0}".format(avgmu))
                elif error_model == 'same':
                    avgepsilon = _avgMutFreqs(counts['err'])
                    avgmu -= avgepsilon.sum()
                    logger.info("Avg mutation and error rates:\n\t{0}\n\t{1}"
                            .format(avgmu, avgepsilon))
                elif error_model == 'different':
                    avgepsilon = _avgMutFreqs(counts['errpre'])
                    avgmu -= avgepsilon.sum()
                    avgrho = _avgMutFreqs(counts['errpost'])
                    logger.info("Average mutation, pre-selection, and post-"
                            "selection error rates:\n\t{0}\n\t{1}\n\t{2}"
                            .format(avgmu, avgepsilon, avgrho))
                else:
                    raise ValueError("invalid error_model")
                assert 0 < avgmu < 1, "Invalid avg mut rate {0}".format(avgmu)
            else:
                raise ValueError("Invalid chartype")

            logger.info("Compiling ``pystan`` model...")
            pystan_error_model = ({
                    'none':dms_tools2.prefs.StanModelNoneErr,
                    'same':dms_tools2.prefs.StanModelSameErr,
                    'different':dms_tools2.prefs.StanModelDifferentErr,
                    }[error_model])()
            logger.info("Completed compiling ``pystan`` model.\n")

            # concentration parameters are the same for every site
            assert all([c > 0 for c in args['conc']])
            (cpi, cmu, cerr) = args['conc']

            # number of MCMC iterations
            niter = {'none':2500,
                     'same':10000,
                     'different':20000}[error_model]

            # begin inferring prefs in a multiprocessing pool
            if args['ncpus'] == -1:
                ncpus = multiprocessing.cpu_count()
            else:
                ncpus = min(args['ncpus'], multiprocessing.cpu_count())
            assert ncpus > 0
            pool = multiprocessing.Pool(ncpus)
            try:
                results = {} # async result for each site
                retry = {} # arguments for one more attempt at each site
                logged = {} # has the site's final result been recorded?
                pi_means = {}

                logger.info("Beginning MCMC runs...")
                # run inference for each site
                for (r, wt) in zip(sites, wts):

                    # build counts and priors to pass to `inferSitePrefs`
                    if args['chartype'] == 'codon_to_aa':
                        wtaa = CODON_TO_AA[wt]
                        aacounts = {}
                        for (ctype, df) in counts.items():
                            aacounts[ctype] = (
                                    dms_tools2.utils.codonToAACounts(
                                    df[df['site'] == r]))
                            aacounts[ctype]['wildtype'] = [wtaa]
                        # first (only) row of each single-site data frame
                        rcounts = dict([(ctype, dict(aa[charlist].iloc[0]))
                                for (ctype, aa) in aacounts.items()])
                        priors = _sitePriors(wt, charlist,
                                args['excludestop'], error_model, avgmu,
                                avgepsilon, avgrho)
                        wt = wtaa # set wildtype to amino acid for below
                    else:
                        raise ValueError("Invalid chartype")

                    # scale priors by concentration parameters
                    for x in charlist:
                        priors['mur_prior_params'][x] *= len(charlist) * cmu
                        priors['epsilonr_prior_params'][x] *= (
                                len(charlist) * cerr)
                        priors['rhor_prior_params'][x] *= len(charlist) * cerr
                    priors['pir_prior_params'] = dict(
                            [(x, cpi) for x in charlist])

                    # start MCMC for site
                    logged[r] = False
                    argslist = [charlist, wt, pystan_error_model, rcounts,
                            priors, 1, niter]
                    results[r] = pool.apply_async(
                            dms_tools2.prefs.inferSitePrefs, tuple(argslist))
                    # different seed for second try; index 5 is the `1`
                    # above (index 6 is `niter`)
                    # NOTE(review): confirm argument 5 of `inferSitePrefs`
                    # is the seed
                    argslist[5] += 1
                    retry[r] = tuple(argslist)

                # poll until a final result is recorded for every site
                while not all([logged[r] for r in sites]):
                    time.sleep(1)
                    for r in sites:
                        if logged[r] or not results[r].ready():
                            continue
                        logger.info("Getting results for site {0}...".format(r))
                        (converged, pi, pi95, logstring) = results[r].get()
                        if converged:
                            logged[r] = True
                            logger.info("Finished for site {0}:\n{1}\n"
                                    .format(r, logstring))
                            pi_means[r] = pi
                            assert abs(1 - sum(pi_means[r].values())) < 1e-4
                        elif r in retry:
                            logger.warning("Problems for site {0}, re-trying. "
                                    "Here is message from prior attempt:\n{1}\n"
                                    .format(r, logstring))
                            # resubmit once with the stored retry arguments
                            results[r] = pool.apply_async(
                                    dms_tools2.prefs.inferSitePrefs,
                                    retry.pop(r))
                        else:
                            raise RuntimeError("Failed for site {0}:\n{1}"
                                    .format(r, logstring))
            finally:
                # stop the workers whether or not inference succeeded
                pool.terminate()
            logger.info("Finished inferring the preferences.\n")

            # build up prefs and write to file
            prefs_d = dict([(key, []) for key in ['site'] + charlist])
            for r in sites:
                prefs_d['site'].append(r)
                for c in charlist:
                    prefs_d[c].append(pi_means[r][c])
            prefs = pandas.DataFrame(prefs_d)[['site'] + charlist]
            logger.info("Writing preferences to {0}".format(files['prefs']))
            prefs.to_csv(files['prefs'], index=False)

    except BaseException:
        # Deliberate catch-all at the top level of the script: record the
        # traceback in the log and remove possibly incomplete output.
        # NOTE(review): the exception is not re-raised, so the process
        # exit status does not reflect the failure.
        logger.exception('Terminating {0} with ERROR'.format(prog))
        for (fname, fpath) in files.items():
            if fname != 'log' and os.path.isfile(fpath):
                logger.error("Deleting file {0}".format(fpath))
                os.remove(fpath)

    else:
        logger.info('Successful completion of {0}'.format(prog))

    finally:
        logging.shutdown()


# Call `main` only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main() # run the script
