#!python

"""Makes sequence logo plots.

Written by Jesse Bloom."""


import sys
import math
import os
import re
import time
import logging
import natsort
import numpy
import pandas
import phydmslib.weblogo
import dms_tools2.parseargs
import dms_tools2.utils
import dms_tools2.prefs


def omegabysiteOverlay(omegabysitefile, minP=1e-4):
    """Creates an omegabysite overlay.

    Each site gets the base-10 log of its P-value, floored at
    the base-10 log of `minP`. The value is kept negative for sites
    with omega < 1 and flipped to positive for all other sites, so the
    sign encodes the direction of selection.

    Args:
        `omegabysitefile` (str)
            Name of ``*omegabysite.txt`` file created by ``phydms``.
            Read as tab-separated text with ``#`` comment lines; must
            have `site`, `omega`, and `P` columns.
        `minP` (float)
            Show all P values less than this as <= this.

    Returns:
        The 2-tuple `(overlaytup, fix_limits_val)` where
        `overlaytup` is entry to be added to `overlay` list,
        and `fix_limits_val` is value to be added to 
        `fix_limits` under key `shortname`.

    Raises:
        ValueError
            If a site is listed more than once in `omegabysitefile`.
    """
    omegas = pandas.read_csv(omegabysitefile, comment='#', sep='\t')

    duplicated = omegas['site'][omegas['site'].duplicated()]
    if len(duplicated):
        raise ValueError("Duplicate sites in {0}: {1}".format(
                omegabysitefile, sorted(set(map(str, duplicated)))))

    # raw strings: the LaTeX backslashes are not Python escape sequences
    shortname = r'$\omega_r$'
    longname = (r'{0} $<1 \; \longleftarrow$ $\log_{{10}} P$ '
                r'$\longrightarrow \;$ {0} $>1$'.format(shortname))

    prop_d = {}
    for (site, omegar, p) in zip(omegas['site'], omegas['omega'],
                                 omegas['P']):
        # floor P at `minP` *before* taking the log so that P == 0
        # is plotted at the floor rather than raising a domain error
        logp = math.log10(max(minP, float(p)))
        prop_d[str(site)] = logp if float(omegar) < 1 else -logp
    overlaytup = (prop_d, shortname, longname)

    # integer ticks symmetric about zero; labels are the (negative)
    # log P value on both sides since the sign only encodes direction
    logminP = int(math.log10(minP))
    ticklocs = list(range(logminP, 1 - logminP))
    ticknames = [-abs(itick) for itick in ticklocs]
    fix_limits_val = (ticklocs, ticknames)
    return (overlaytup, fix_limits_val)


def main():
    """Main body of script.

    Parses the command-line arguments, reads the one data file given
    (`prefs`, `diffsel`, `fracsurvive`, `diffprefs`, or `muteffects`),
    reshapes it into a site-keyed dict, reads up to three optional
    overlays, and draws the logo plot with `phydmslib.weblogo.LogoPlot`.
    Progress is written to a log file in the output directory.

    If any step inside the logged section fails, the error is written
    to the log and any output file other than the log is deleted.
    """

    # Parse command line arguments
    parser = dms_tools2.parseargs.logoplotParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # what type of data are we plotting?
    datatype = [x for x in ['prefs', 'diffsel', 'fracsurvive',
                'diffprefs', 'muteffects'] if args[x]]
    if len(datatype) != 1:
        raise ValueError("Did not recognize a valid data type")
    datatype = datatype[0]

    # define output file names
    if args['outdir']:
        if not os.path.isdir(args['outdir']):
            # `makedirs` rather than `mkdir` so that a nested output
            # directory can be created in one step
            os.makedirs(args['outdir'])
    else:
        args['outdir'] = ''
    filesuffixes = {
            'log':'.log',
            'logo':'_{0}.pdf'.format(datatype),
            }
    files = {f: os.path.join(args['outdir'], '{0}{1}'.format(
            args['name'], s)) for (f, s) in filesuffixes.items()}

    # do we need to proceed?
    if args['use_existing'] == 'yes' and all(map(
            os.path.isfile, files.values())):
        print("Output files already exist and '--use_existing' is 'yes', "
                "so exiting with no further action.")
        sys.exit(0)

    logger = dms_tools2.utils.initLogger(files['log'], prog, args)

    # log in try / except / finally loop
    try:
        # remove expected output files if they already exist
        for (ftype, f) in files.items():
            if os.path.isfile(f) and ftype != 'log':
                logger.info("Removing existing file {0}".format(f))
                os.remove(f)

        # some checking on arguments
        assert dms_tools2.parseargs.checkName(args['name'], 'name')
        assert args['nperline'] >= 1
        assert args['numberevery'] >= 1
        assert args['stringency'] >= 0

        # read data
        logger.info("Reading {0} from file {1}...".format(datatype, 
                args[datatype]))
        assert os.path.isfile(args[datatype]), "Can't find {0}".format(
                args[datatype])
        data = pandas.read_csv(args[datatype])
        assert 'site' in data.columns, "no 'site' column"
        # sites are handled as strings throughout (they key the dicts
        # passed to the plotting function)
        data['site'] = data['site'].astype(str)
        sites = data['site'].unique()
        logger.info("Read data for {0} sites.\n".format(len(sites)))
        if args['sortsites'] == 'yes':
            sites = natsort.realsorted(sites)
        elif args['sortsites'] != 'no':
            raise ValueError("invalid --sortsites")

        # data type passed to the plotting function; several input
        # types below are drawn using the `diffsel` plot style
        plotdatatype = datatype

        scalebar = False
        if datatype == 'prefs':
            if args['stringency'] != 1:
                logger.info("Re-scaling preferences by stringency "
                    "parameter {0}".format(args['stringency']))
                data = dms_tools2.prefs.rescalePrefs(data, 
                        args['stringency'])
            nosepline = True
            ylimits = None
            assert not args['scalebar'], "can't use `scalebar` with `prefs`"

        elif datatype == 'fracsurvive':
            plotdatatype = 'diffsel'
            # long -> wide: one row per site, one column per mutation
            data = (data.pivot_table(index='site', values='mutfracsurvive',
                                columns='mutation')
                        .fillna(0)
                        .reindex(sites, axis='rows')
                        .reset_index()
                        )
            args['letterheight'] *= 2 # taller letter stacks for diffsel
            nosepline = True
            # tallest stack is the largest per-site sum of positive values
            ymax = (data.drop('site', axis=1)
                    .clip(0, None)
                    .sum(axis=1)
                    .max()
                    )
            ymin = -ymax / 100.0
            if args['fracsurvivemax'] is not None:
                assert args['fracsurvivemax'] > 0, 'fracsurvivemax <= 0'
                assert args['fracsurvivemax'] >= ymax, \
                        "range of data exceed `--fracsurvivemax`"
                ylimits = (ymin, args['fracsurvivemax'])
            else:
                assert ymax > 0, "no fracsurvive values > 0"
                ylimits = (ymin, 1.02 * ymax)
            if args['scalebar']:
                scalebar = (float(args['scalebar'][0]),
                        args['scalebar'][1])

        elif datatype == 'diffsel':
            # long -> wide: one row per site, one column per mutation
            data = (data.pivot_table(index='site', values='mutdiffsel',
                                columns='mutation')
                        .fillna(0)
                        .reindex(sites, axis='rows')
                        )
            if args['restrictdiffsel'] == 'positive':
                data = data.clip(0, None)
            elif args['restrictdiffsel'] == 'negative':
                data = data.clip(None, 0)
            elif args['restrictdiffsel'] != 'all':
                raise ValueError('invalid restrictdiffsel')
            data = data.reset_index()

            args['letterheight'] *= 2 # taller letter stacks for diffsel
            nosepline = {'no':True, 'yes':False}[args['sepline']]

            # data range: most negative and most positive per-site stacks
            ymin = (data.drop('site', axis=1)
                        .clip(None, 0)
                        .sum(axis=1)
                        .min()
                        )
            ymax = (data.drop('site', axis=1)
                        .clip(0, None)
                        .sum(axis=1)
                        .max()
                        )
            if args['diffselrange']:
                diffselrange = args['diffselrange']
                if (diffselrange[0] > ymin) or (diffselrange[1] < ymax):
                    raise ValueError("Invalid diffselrange does not "
                            "fully include data range of {0} to {1}"
                            .format(ymin, ymax))
                ymin = diffselrange[0]
                ymax = diffselrange[1]
            # pad the y-limits by 2% of the range on each side
            dy = ymax - ymin
            ylimits = (ymin - 0.02 * dy, ymax + 0.02 * dy)
            if args['scalebar']:
                scalebar = (float(args['scalebar'][0]),
                            args['scalebar'][1])

        elif datatype == 'muteffects':
            plotdatatype = 'diffsel'
            # long -> wide: one row per site, one column per mutant
            data = (data.pivot_table(index='site', values='log2effect',
                                columns='mutant')
                        .fillna(0)
                        .reindex(sites, axis='rows')
                        .reset_index()
                        )

            nosepline = {'no':True, 'yes':False}[args['sepline']]
            args['letterheight'] *= 2 # taller letter stacks

            # data range: most negative and most positive per-site stacks
            ymin = (data.drop('site', axis=1)
                        .clip(None, 0)
                        .sum(axis=1)
                        .min()
                        )
            ymax = (data.drop('site', axis=1)
                        .clip(0, None)
                        .sum(axis=1)
                        .max()
                        )

            if args['muteffectrange']:
                muteffectrange = args['muteffectrange']
                # lower bound must not exceed the data minimum and the
                # upper bound (index 1) must not fall below the maximum
                if (muteffectrange[0] > ymin) or (muteffectrange[1] < ymax):
                    raise ValueError("Invalid `muteffectrange` does not "
                            "fully include data range of {0} to {1}"
                            .format(ymin, ymax))
                ymin = muteffectrange[0]
                ymax = muteffectrange[1]

            # pad the y-limits by 2% of the range on each side
            dy = ymax - ymin
            ylimits = (ymin - 0.02 * dy, ymax + 0.02 * dy)

            if args['scalebar']:
                scalebar = (float(args['scalebar'][0]),
                            args['scalebar'][1])

        elif datatype == 'diffprefs':
            nosepline = {'no':True, 'yes':False}[args['sepline']]
            ylimits = None
            assert not args['scalebar'], "can't use `scalebar` with `diffprefs`"

        else:
            raise ValueError("Invalid datatype {0}".format(datatype))

        # check columns
        assert (args['ignore_extracols'] == 'yes') or (set(data.columns)
                <= set(['site'] + dms_tools2.AAS_WITHSTOP)), 'extra columns'
        # exclude stop codons if specified
        cols = ['site'] + dms_tools2.AAS
        assert set(cols) <= set(data.columns), 'missing amino-acids in data'
        if '*' in data.columns and args['excludestop'] == 'no':
            cols.append('*')
        data = data[cols]
        if datatype == 'prefs':
            # re-norm after excludestop, we pass stringency of 1
            # because we have already re-scaled and are just using
            # function to re-normalize
            data = dms_tools2.prefs.rescalePrefs(data, 1)
            
        # convert data from wide data frame to dict for logo plot
        data = data.set_index('site').to_dict('index')

        # read any overlays
        overlay = []
        fix_limits = {}
        for i in range(3): # loop over possible overlays
            overlayarg = 'overlay{0}'.format(i + 1)
            if not args[overlayarg]:
                continue
            (overlayfile, shortname, longname) = args[overlayarg]
            logger.info("Reading overlay for {0} from {1}...".format(
                    shortname, overlayfile))
            if shortname == longname == 'omegabysite':
                # special case: overlay built from an omegabysite file,
                # which also supplies fixed tick locations and labels
                (overlaytup, fix_limits_val) = omegabysiteOverlay(
                        overlayfile)
                overlay.append(overlaytup)
                fix_limits[overlaytup[1]] = fix_limits_val
                continue
            assert (len(shortname) < 6) or (shortname == 'wildtype'), \
                "{0} SHORTNAME too long".format(overlayarg)
            overlaydf = pandas.read_csv(overlayfile)
            assert {'site', shortname} <= set(overlaydf.columns), \
                    "No 'site' and {0} columns in {1} FILE".format(
                    shortname, overlayarg)
            overlaydf = overlaydf[['site', shortname]].drop_duplicates()
            overlaydf['site'] = overlaydf['site'].astype('str')
            assert len(overlaydf['site']) == len(set(overlaydf['site'])),\
                    "Duplicate sites in {0} FILE".format(overlayarg)
            extrasites = set(overlaydf['site']) - set(sites)
            assert not extrasites, "Extra sites in {0}:\n{1}".format(
                    overlayarg, extrasites)
            overlay.append((
                    dict(zip(overlaydf['site'], overlaydf[shortname])),
                    shortname, longname))
            logger.info("Read overlay for {0} sites.\n".format(
                    len(overlaydf['site'])))
        if args['underlay'] == 'yes':
            overlay.reverse()
        shortnames = [tup[1] for tup in overlay]
        assert len(set(shortnames)) == len(shortnames), ("Duplicate "
                "SHORTNAME in overlay arguments")

        # make logo plot
        logger.info("Making logo plot {0}...".format(files['logo']))
        phydmslib.weblogo.LogoPlot(
                    sites=sites, 
                    datatype=plotdatatype, 
                    data=data, 
                    plotfile=files['logo'],
                    nperline=args['nperline'],
                    numberevery=args['numberevery'],
                    allowunsorted=True,
                    ydatamax=1.01, # no meaning for prefs or diffsel
                    overlay=overlay,
                    fix_limits=fix_limits,
                    fixlongname=False,
                    overlay_cmap=args['overlaycolormap'],
                    ylimits=ylimits,
                    relativestackheight=args['letterheight'],
                    custom_cmap=args['colormap'],
                    map_metric=args['mapmetric'],
                    noseparator=nosepline,
                    underlay={'no':False, 'yes':True}[args['underlay']],
                    scalebar=scalebar,
                    )
        logger.info("Successfully created logo plot.\n")

    except:
        # NOTE(review): the error is logged but not re-raised, so the
        # process still exits with status 0 after a failure; confirm
        # whether callers rely on this before changing it.
        logger.exception('Terminating {0} with ERROR.'.format(prog))
        # do not leave a partial / stale plot behind; keep the log
        for (fname, fpath) in files.items():
            if fname != 'log' and os.path.isfile(fpath):
                logger.exception("Deleting file {0}".format(fpath))
                os.remove(fpath)

    else:
        logger.info('Successful completion of {0}'.format(prog))

    finally:
        logging.shutdown()



# Run the program only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main() # run the script
