#!python

"""Computes differential selection.

Written by Jesse Bloom."""


import sys
import os
import re
import math
import logging
import natsort
import pandas
import dms_tools2
import dms_tools2.utils
import dms_tools2.parseargs
import dms_tools2.diffsel


def main():
    """Run the differential-selection command-line program.

    Command-line arguments are those defined by
    ``dms_tools2.parseargs.diffselParser``. The selected (``--sel``),
    mock-selected (``--mock``) and optional error-control (``--err``)
    counts are read from CSV files, and the following files are written,
    each named ``<outdir>/<name><suffix>``:

    * ``.log``: log of the run.
    * ``_mutdiffsel.csv``: result of
      ``dms_tools2.diffsel.computeMutDiffSel``, sorted by decreasing
      ``mutdiffsel``.
    * ``_sitediffsel.csv``: result of
      ``dms_tools2.diffsel.mutToSiteDiffSel``, sorted by decreasing
      ``abs_diffsel``.

    If ``--use_existing`` is ``'yes'`` and all three files already exist,
    the program exits immediately with status 0.

    Raises:
        ValueError: for an invalid ``--name``, ``--pseudocount``,
            ``--chartype`` or ``--indir``, or a missing counts file.
        Exception: any other error from reading the counts or computing
            the selection. In every failure case the error is logged,
            partially written CSV outputs are deleted, and the exception
            is re-raised so the process exits with a non-zero status.
    """

    # Parse command line arguments
    parser = dms_tools2.parseargs.diffselParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # set up names of output files
    if args['outdir']:
        # makedirs handles nested paths and an already-existing directory
        os.makedirs(args['outdir'], exist_ok=True)
    else:
        args['outdir'] = ''
    filesuffixes = {
            'log':'.log',
            'mutdiffsel':'_mutdiffsel.csv',
            'sitediffsel':'_sitediffsel.csv',
            }
    files = {
            ftype: os.path.join(args['outdir'], '{0}{1}'.format(
                    args['name'], suffix))
            for (ftype, suffix) in filesuffixes.items()}

    # do we need to proceed?
    if args['use_existing'] == 'yes' and all(map(
                os.path.isfile, files.values())):
        print("Output files already exist and '--use_existing' is 'yes', "
              "so exiting with no further action.")
        sys.exit(0)

    logger = dms_tools2.utils.initLogger(files['log'], prog, args)

    # log in try / except / finally statement
    try:

        # Validate with explicit raises rather than `assert`, which is
        # stripped (together with the calls inside it) under `python -O`.
        if not dms_tools2.parseargs.checkName(args['name'], 'name'):
            raise ValueError("Invalid --name {0}".format(args['name']))
        if not (args['pseudocount'] > 0):
            raise ValueError("--pseudocount must be > 0, got {0}".format(
                    args['pseudocount']))

        # remove expected output files if they already exist
        for (ftype, f) in files.items():
            if os.path.isfile(f) and ftype != 'log':
                logger.info("Removing existing file {0}".format(f))
                os.remove(f)

        if args['chartype'] == 'codon_to_aa':
            countcharacters = dms_tools2.CODONS
            translate_to_aa = True
            countsuffix = '_codoncounts.csv'
        else:
            raise ValueError("Bad chartype {0}".format(args['chartype']))

        counts = {}
        if not args['indir']:
            args['indir'] = ''
        elif not os.path.isdir(args['indir']):
            raise ValueError("No --indir {0}".format(args['indir']))
        for (arg, desc) in [
                ('sel', 'selected sample'),
                ('mock', 'mock-selected sample'),
                ('err', 'error-control sample'),
                ]:
            # the error control is optional; pass None downstream if absent
            if arg == 'err' and not args['err']:
                counts[arg] = None
                continue
            logger.info("Reading {0} counts from {1}".format(desc, args[arg]))
            # the argument may be a full file name or a prefix to which
            # the counts suffix still has to be appended
            fname = os.path.join(args['indir'], args[arg])
            if not os.path.isfile(fname):
                if os.path.isfile(fname + countsuffix):
                    fname = fname + countsuffix
                else:
                    # report the actual option name (e.g. --sel)
                    raise ValueError("Missing file for --{0}:\n{1}"
                            .format(arg, fname))
            counts[arg] = pandas.read_csv(fname)

        logger.info("Computing mutdiffsel...")
        mutdiffsel = dms_tools2.diffsel.computeMutDiffSel(
                counts['sel'], counts['mock'], countcharacters,
                args['pseudocount'], translate_to_aa, counts['err'],
                args['mincount'])
        if args['excludestop'] == 'yes':
            # drop rows where either wildtype or mutant is a stop ("*")
            mutdiffsel = mutdiffsel.query(
                    '(mutation != "*") & (wildtype != "*")')
        mutdiffsel = mutdiffsel.sort_values('mutdiffsel', ascending=False)
        logger.info("Mutations with largest mutdiffsel:\n{0}\n".format(
                mutdiffsel.head(10).to_string(index=False,
                float_format='{:.2f}'.format)))
        logger.info("Writing to {0}".format(files['mutdiffsel']))
        mutdiffsel.to_csv(files['mutdiffsel'], index=False, na_rep='NaN')

        logger.info("Computing sitediffsel...")
        sitediffsel = (dms_tools2.diffsel.mutToSiteDiffSel(mutdiffsel)
                       .sort_values('abs_diffsel', ascending=False))
        logger.info("Sites with largest absolute sitediffsel:\n{0}\n"
                .format(sitediffsel.head(10).to_string(index=False,
                float_format='{:.2f}'.format)))
        logger.info("Writing to {0}".format(files['sitediffsel']))
        sitediffsel.to_csv(files['sitediffsel'], index=False, na_rep='NaN')

    except BaseException:
        # Catch everything (including KeyboardInterrupt) so that partial
        # outputs are always cleaned up, then re-raise so the failure is
        # visible to the caller as a non-zero exit status.
        logger.exception('Terminating {0} with ERROR'.format(prog))
        for (ftype, fpath) in files.items():
            if ftype != 'log' and os.path.isfile(fpath):
                # plain info message: the traceback was already logged above
                logger.info("Deleting file {0}".format(fpath))
                os.remove(fpath)
        raise

    else:
        logger.info('Successful completion of {0}'.format(prog))

    finally:
        logging.shutdown()


if __name__ == "__main__":
    # Executed as a script (not imported): run the program.
    main()
