#!python

"""Makes logo plots showing prefs or diffprefs.

Written by Jesse Bloom."""


import os
import time
import math
import natsort
import pandas
import phydmslib.weblogo
import phydmslib.file_io
import phydmslib.parsearguments
from phydmslib.constants import AA_TO_INDEX


def main():
    """Parse the command line and write a logo plot of prefs or diffprefs.

    The steps are:

    1. Parse arguments with ``PhyDMSLogoPlotParser`` and echo them.
    2. Check that ``outfile`` ends in ``.pdf`` and lies in an existing
       directory; an existing file of that name is removed.
    3. Read per-site data from ``prefs`` (re-scaled by ``stringency``
       when that is not 1) or else from ``diffprefs``.
    4. If ``omegabysite`` is given, build an overlay holding, for each
       site, log10 of the P-value clamped from below at ``minP``; the
       value is kept negative when omega < 1 and negated otherwise.
    5. Call ``phydmslib.weblogo.LogoPlot`` and check the file was created.

    Raises:
        AssertionError: If ``outfile`` is not a ``.pdf`` path in an
            existing directory, if there are no sites or duplicated
            sites, if the ``omegabysite`` sites differ from the plotted
            sites, or if the plot file is missing after plotting.
        ValueError: If neither ``prefs`` nor ``diffprefs`` was given, if
            the diffprefs at a site reach beyond ``diffprefheight``, or
            if ``omegabysite`` holds a negative P-value.
    """

    # Parse command line arguments
    parser = phydmslib.parsearguments.PhyDMSLogoPlotParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # print some basic information
    print('\nBeginning execution of {0} in directory {1} at time {2}\n'.format(
            prog, os.getcwd(), time.asctime()))
    print("{0}\n".format(phydmslib.file_io.Versions()))
    print('Parsed the following command-line arguments:\n{0}\n'.format(
            '\n'.join(['\t%s = %s' % tup for tup in args.items()])))

    # check on outfile we will create
    outfile = args['outfile']
    assert os.path.splitext(outfile)[1].lower() == '.pdf', (
            "outfile {0} does not end in '.pdf'".format(outfile))
    outdir = os.path.dirname(outfile)
    assert (not outdir) or os.path.isdir(outdir), (
            "outfile {0} includes non-existent directory".format(outfile))
    if os.path.isfile(outfile):
        print("Removing existing outfile of {0}\n".format(outfile))
        os.remove(outfile)

    aas = AA_TO_INDEX.keys()

    if args['prefs']:
        datatype = 'prefs'
        print("Reading preferences from {0}".format(args['prefs']))
        prefs = phydmslib.file_io.readPrefs(args['prefs'],
                                            sites_as_strings=True)
        sites = natsort.realsorted(prefs.keys())
        print("Read preferences for {0} sites.".format(len(sites)))
        beta = args['stringency']
        if beta != 1:
            print("Re-scaling by stringency parameter {0}".format(beta))
            for r in sites:
                # raise each preference to the power beta, then
                # renormalize so the site's preferences sum to one again
                scaled = {a: prefs[r][a]**beta for a in aas}
                prefsum = sum(scaled.values())
                prefs[r] = {a: scaled[a] / prefsum for a in aas}
        data = prefs
        ydatamax = None

    elif args['diffprefs']:
        datatype = 'diffprefs'
        print("Reading differential preferences from {0}".format(
                args['diffprefs']))
        diffprefs = pandas.read_csv(args['diffprefs'], sep='\t', comment='#')
        site_labels = [str(site) for site in diffprefs['site'].values]
        sites = natsort.realsorted(site_labels)
        print("Read differential preferences for {0} sites."
              .format(len(sites)))
        # Check for duplicates before the per-site lookup below, which
        # needs exactly one row per site to be meaningful.
        assert len(set(site_labels)) == len(site_labels), (
                "Duplicate site numbers.")
        # Build the site -> row lookup and pull out the columns once,
        # rather than filtering the whole table for every site. Reading
        # scalars out of the column arrays also avoids calling float() on
        # a one-element Series, which recent pandas versions deprecate.
        row_of_site = {site: irow for irow, site in enumerate(site_labels)}
        dpi_columns = {a: diffprefs['dpi_{0}'.format(a)].values for a in aas}
        data = {}
        ydatamax = args['diffprefheight'] * 1.01
        for r in sites:
            irow = row_of_site[r]
            data[r] = {a: float(dpi_columns[a][irow]) for a in aas}
            # total stack height above and below zero at this site
            sumpos = sum([dpi for dpi in data[r].values() if dpi > 0])
            sumneg = sum([-dpi for dpi in data[r].values() if dpi < 0])
            if max(sumpos, sumneg) >= ydatamax:
                raise ValueError("diffprefs extend beyond --diffprefheight "
                                 "{0}. Increase --diffprefheight."
                                 .format(args['diffprefheight']))

    else:
        raise ValueError("Didn't specify either --prefs or --diffprefs")

    assert sites, "No sites specified"
    assert len(set(sites)) == len(sites), "Duplicate site numbers."

    fix_limits = {}
    if args['omegabysite']:
        print("\nWe will make an overlay with the site-specific omega "
              "values in {0}.".format(args['omegabysite']))
        omegas = pandas.read_csv(args['omegabysite'], comment='#', sep='\t')
        omega_sites = [str(site) for site in omegas['site'].values]
        assert set(sites) == set(omega_sites), (
                "sites in {0} don't match those in prefs or diffprefs"
                .format(args['omegabysite']))
        assert len(set(omega_sites)) == len(omega_sites), (
                "Duplicate sites in {0}".format(args['omegabysite']))
        omega_of_site = dict(zip(omega_sites, omegas['omega'].values))
        p_of_site = dict(zip(omega_sites, omegas['P'].values))
        # Raw strings: the backslashes are LaTeX markup for the plot
        # labels, not Python escape sequences.
        shortname = r'$\omega_r$'
        longname = (r'{0} $<1 \; \longleftarrow$ $\log_{{10}} P$ '
                    r'$\longrightarrow \;$ {0} $>1$'.format(shortname))
        log_minp = math.log10(args['minP'])
        prop_d = {}
        for r in sites:
            omegar = float(omega_of_site[r])
            p = float(p_of_site[r])
            if p < 0:
                raise ValueError("Negative P-value {0} for site {1} in {2}"
                                 .format(p, r, args['omegabysite']))
            # log10(P) clamped from below at log10(minP). A P-value of
            # exactly zero has no logarithm, so it takes the clamp value.
            logp = max(log_minp, math.log10(p)) if p > 0 else log_minp
            # sign encodes the direction: negative for omega < 1,
            # positive otherwise
            prop_d[r] = logp if omegar < 1 else -logp
        overlay = [(prop_d, shortname, longname)]
        # integer ticks symmetric about zero, all labelled as -|tick|
        ticklocs = list(range(int(log_minp), 1 - int(log_minp)))
        ticknames = [-abs(itick) for itick in ticklocs]
        fix_limits[shortname] = (ticklocs, ticknames)
    else:
        overlay = None

    # make plot
    print("\nNow making plot {0}...".format(outfile))
    phydmslib.weblogo.LogoPlot(
                sites=sites,
                datatype=datatype,
                data=data,
                plotfile=outfile,
                nperline=min(len(sites), args['nperline']),
                numberevery=args['numberevery'],
                allowunsorted=True,
                ydatamax=ydatamax,
                overlay=overlay,
                fix_limits=fix_limits,
                fixlongname=True,
                overlay_cmap='bwr',
                custom_cmap=args['colormap'],
                map_metric=args['mapmetric'],
            )
    assert os.path.isfile(outfile), "Failed to create plot"
    print("Created plot {0}".format(outfile))

    print('\nSuccessful completion of %s' % prog)


# Entry point: run the plotting workflow only when this file is executed
# as a script, not when it is imported.
if __name__ == "__main__":
    main()
