#!/usr/bin/env python -W ignore

from __future__ import absolute_import, print_function, division
from six.moves import range

import os.path
import sys
from time import sleep

import math
import numpy as np
try:
    import matplotlib.mlab as mlab
    from pylab import figure, show, text, close, pause
except ImportError:
    print('You need to install the python matplotlib package')
    raise

from MDPlus.analysis import pca

# Module-level viewer state, shared with the key-press handler press():
#   i, j   -- zero-based indices of the PCs being displayed
#   option -- active screen (0 = help text, 1..4 = the plot types)
#   c      -- toggle: heat map vs. contours on screen 4, animate flag on screen 3
#   resol  -- histogram resolution level
#   subset -- index of the currently selected frame subset
i, j = 0, 1
option, c = 1, 0
resol, subset = 1, 0

# Navigation key bindings interpreted by press().
# NB: ``next`` shadows the Python builtin at module scope; the name is kept
# because press() refers to it.
up, down = 'i', 'k'        # change the y-axis PC (the only PC on screen 1)
left, right = 'j', 'l'     # change the x-axis PC (the only PC on screen 2)
plus, minus = '+', '-'     # change the histogram resolution
next, prev = 'o', 'u'      # select the next / previous subset

# Text shown on the help screen ('h'). Keep it in sync with the key bindings
# defined in this module and the per-screen handling in press().
helptext=('PCZPLOT v0.2\n\n'
          'Press \'1\' to plot the projection of a PC vs. snapshot number.\n'
          '    - use \'i\' and \'k\' keys to choose PC to plot.\n\n'
          'Press \'2\' to plot the histogram of a PC\'s projection.\n'
          '    - use \'j\' and \'l\' keys to choose PC to plot.\n'
          '    - use \'+\' and \'-\' keys to change the resolution.\n\n'
          'Press \'3\' to plot the track of one PC against another.\n'
          '    - use \'j\' and \'l\' keys to choose PC to plot on the x-axis.\n'
          '    - use \'i\' and \'k\' keys to choose PC to plot on the y-axis.\n'
          '    - use the \'a\' key to animate the plot.\n\n'
          'Press \'4\' to plot 2D histograms of PCs.\n'
          '    - use \'j\' and \'l\' keys to choose PC to plot on the x-axis.\n'
          '    - use \'i\' and \'k\' keys to choose PC to plot on the y-axis.\n'
          '    - use the \'c\' key to swap between heat map and contours.\n'
          '    - use \'+\' and \'-\' keys to change the resolution.\n\n'
          'Use \'u\' and \'o\' keys to change subset on any screen.\n'
          'Press \'q\' to quit, and \'h\' for this help text again.')

def incr(oldvalue, increment, minvalue=None, maxvalue=None, skipvalue=None):
    """Step a counter by ``increment``, clamped to a range and avoiding one value.

    Used to move the PC, subset and resolution selectors on key presses.

    Parameters
    ----------
    oldvalue : int
        Current value.
    increment : int
        Step to apply (the callers use +1 or -1).
    minvalue, maxvalue : int or None
        Inclusive bounds of the result; ``None`` leaves that side unbounded.
    skipvalue : int or None
        A value the result should avoid (e.g. the PC already shown on the
        other axis). If the step lands on it, the step is taken once more; if
        clamping then lands on it, the result backs off by one step.

    Returns
    -------
    int
        The new value.
    """
    newvalue = oldvalue + increment
    if skipvalue is not None and newvalue == skipvalue:
        newvalue += increment
    # Explicit None checks: comparing an int with None raises TypeError on
    # Python 3, which made the default (unbounded) arguments unusable.
    if maxvalue is not None and newvalue > maxvalue:
        newvalue = maxvalue
    if minvalue is not None and newvalue < minvalue:
        newvalue = minvalue
    if skipvalue is not None and newvalue == skipvalue:
        newvalue -= increment

    return newvalue

def press(event):
    """Handle a key press: update the viewer state and redraw the active screen.

    Keys '1'-'4' select the screen (1: projection vs. snapshot, 2: histogram
    of one projection, 3: track of one projection against another, 4: 2D
    histogram of two projections), 'h' shows ``helptext`` and 'q' closes all
    figures. Other keys are interpreted by the active screen. The viewer
    state (``i``, ``j``, ``c``, ``option``, ``resol``, ``subset``) lives in
    module-level globals so that it persists between key presses.

    Parameters
    ----------
    event : matplotlib key event
        Only ``event.key`` is read.
    """
    global i, j, c, option, resol, subset
    if event.key == '1':
        option = 1
    elif event.key == '2':
        option = 2
        resol = 1
    elif event.key == '3':
        option = 3
        resol = 1
    elif event.key == '4':
        option = 4
    elif event.key == 'h':
        option = 0
    elif event.key == 'q':
        close('all')
        # The figure has been closed, so there is nothing left to redraw.
        return

    last_pc = pcz.n_vecs - 1
    last_subset = nsub - 1

    if option == 0:
        # Help screen.
        ax.clear()
        text(0.5, 0.5, helptext, horizontalalignment='center',
             verticalalignment='center', transform=ax.transAxes)
        fig.canvas.draw()

    elif option == 1:
        # Projection on PC i vs. snapshot number.
        if event.key == down:
            i = incr(i, -1, 0, last_pc)
        elif event.key == up:
            i = incr(i, 1, 0, last_pc)
        elif event.key == prev:
            subset = incr(subset, -1, 0, last_subset)
        elif event.key == next:
            subset = incr(subset, 1, 0, last_subset)
        ax.clear()
        ax.set_aspect('auto')
        proj = pcz.projs[i]
        ax.set_ylim(prange[i, 0], prange[i, 1])
        ax.set_title(os.path.basename(testfile) + '\n' + 'Projection ' + str(i + 1)
                     + ', Subset ' + subid[subset])
        ax.set_xlabel('Snapshot')
        ax.set_ylabel('Proj ' + str(i + 1))
        ax.plot(proj[i1[subset]:i2[subset]])
        fig.canvas.draw()

    elif option == 4:
        # 2D histogram of PC i (x-axis) vs. PC j (y-axis), drawn as a heat map
        # when c == 1 and as contours otherwise.
        if i == j:
            # Screens 1 and 2 move i without avoiding j.
            i = incr(i, -1, 0, last_pc, j)
        if event.key == left:
            i = incr(i, -1, 0, last_pc, j)
        elif event.key == right:
            i = incr(i, 1, 0, last_pc, j)
        elif event.key == down:
            j = incr(j, -1, 0, last_pc, i)
        elif event.key == up:
            j = incr(j, 1, 0, last_pc, i)
        elif event.key == 'c':
            c = (c + 1) % 2
        elif event.key == plus:
            resol = incr(resol, 1, 1, 5)
        elif event.key == minus:
            resol = incr(resol, -1, 1, 5)
        elif event.key == prev:
            subset = incr(subset, -1, 0, last_subset)
        elif event.key == next:
            subset = incr(subset, 1, 0, last_subset)

        ax.clear()
        proj1 = pcz.projs[i][i1[subset]:i2[subset]]
        proj2 = pcz.projs[j][i1[subset]:i2[subset]]
        # histogram2d's first argument (proj2) indexes the rows of H, which
        # imshow/contour draw along the y-axis; hence the swapped extent.
        hrange = [prange[j, :], prange[i, :]]
        nb = 15 + 5 * resol
        H, xedges, yedges = np.histogram2d(proj2, proj1, bins=(nb, nb), range=hrange)
        extent = [yedges[0], yedges[-1], xedges[0], xedges[-1]]
        ax.set_title(os.path.basename(testfile) + '\n' +
                     'Proj ' + str(i + 1) + ' vs. ' + str(j + 1) + ' for subset ' + subid[subset])
        ax.set_xlabel('Proj ' + str(i + 1))
        ax.set_ylabel('Proj ' + str(j + 1))
        if c == 1:
            ax.imshow(H, extent=extent, interpolation='nearest', origin='lower')
        else:
            ax.contour(H, 10, extent=extent, origin='lower')
        ax.set_aspect('auto')
        fig.canvas.draw()

    elif option == 2:
        # Histogram of the projection on PC i, overlaid with a zero-mean
        # Gaussian whose variance is the PC's eigenvalue.
        if event.key == left:
            i = incr(i, -1, 0, last_pc)
        elif event.key == right:
            i = incr(i, 1, 0, last_pc)
        elif event.key == minus:
            resol = incr(resol, -1, 1, 5)
        elif event.key == plus:
            resol = incr(resol, 1, 1, 5)
        elif event.key == prev:
            subset = incr(subset, -1, 0, last_subset)
        elif event.key == next:
            subset = incr(subset, 1, 0, last_subset)
        ax.clear()
        ax.set_aspect('auto')
        ax.set_xlim(prange[i, 0], prange[i, 1])
        proj = pcz.projs[i][i1[subset]:i2[subset]]
        ax.set_title(os.path.basename(testfile) + '\n' +
                     'Histogram of projection ' + str(i + 1))
        ax.set_xlabel('Proj' + str(i + 1) + ', Subset ' + subid[subset])
        ax.set_ylabel('Frequency')
        # density=True replaces the 'normed' keyword removed in matplotlib 3.1.
        nl, bins, _ = ax.hist(proj, bins=10 * 2 ** (resol - 1), density=True,
                              histtype='stepfilled')
        sigma = math.sqrt(pcz.evals[i])
        mu = 0.0
        # Normal pdf evaluated at the bin edges (mlab.normpdf no longer exists
        # in matplotlib >= 3.1).
        y = np.exp(-0.5 * ((bins - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))
        ax.plot(bins, y, 'r--', linewidth=1)
        ymax = max(y.max(), nl.max())
        ax.set_ylim(0, ymax * 1.1)
        fig.canvas.draw()

    elif option == 3:
        # Track of PC i (x-axis) against PC j (y-axis). 'a' switches the
        # animation on; changing either PC switches it off again.
        if event.key == left:
            c = 0
            i = incr(i, -1, 0, last_pc, j)
        elif event.key == right:
            c = 0
            i = incr(i, 1, 0, last_pc, j)
        elif event.key == down:
            c = 0
            j = incr(j, -1, 0, last_pc, i)
        elif event.key == up:
            c = 0
            j = incr(j, 1, 0, last_pc, i)
        elif event.key == prev:
            subset = incr(subset, -1, 0, last_subset)
        elif event.key == next:
            subset = incr(subset, 1, 0, last_subset)
        elif event.key == 'a':
            c = 1

        ax.clear()
        if j == i:
            j = j + 1
        if j >= pcz.n_vecs:
            j = j - 2
        proj1 = pcz.projs[i][i1[subset]:i2[subset]]
        proj2 = pcz.projs[j][i1[subset]:i2[subset]]
        # Frames in the selected subset (not the whole file), so the end
        # marker and the animation cover exactly the plotted track.
        nf = len(proj1)
        last_mark = max(nf - 1, 1)  # markevery needs a positive stride
        ax.set_title(os.path.basename(testfile) + '\n' +
                     'Proj ' + str(i + 1) + ' vs. ' + str(j + 1) + ' for subset ' + subid[subset])
        ax.set_xlabel('Proj ' + str(i + 1))
        ax.set_ylabel('Proj ' + str(j + 1))
        line, = ax.plot(proj1, proj2)
        line.set_marker('*')
        line.set_markevery(last_mark)
        # c is also screen 4's heat-map flag, so merely selecting this screen
        # must not start an animation.
        if c == 0 or event.key == '3':
            fig.canvas.draw()
        else:
            # crude animation method, but good enough
            nfstep = max(nf // 200, 1)
            for k in range(1, nf, nfstep):
                line.set_xdata(proj1[0:k])
                line.set_ydata(proj2[0:k])
                if k > 1:
                    line.set_markevery(k - 1)
                pause(5.0 * nfstep / nf)
            # The loop stops short of the final frames; end on the full track.
            line.set_xdata(proj1)
            line.set_ydata(proj2)
            line.set_markevery(last_mark)
            fig.canvas.draw()

import argparse
from pcazip._version import __version__

parser = argparse.ArgumentParser()

parser.add_argument('-V', '--version', action='version', version=__version__)

mandatory = parser.add_argument_group('mandatory argument')
# required=True makes argparse report a missing input file immediately,
# rather than passing None on to pca.load().
mandatory.add_argument('-i', '--input', type=str, required=True,
                       help='pcz format file to plot')

optional = parser.add_argument_group('mutually exclusive optional arguments')
group = optional.add_mutually_exclusive_group()
group.add_argument('-n', '--index', type=str,
                   help='index file defining groups in the pcz file')
group.add_argument('-r', '--replicates', type=int,
                   help='number of replicates included in the pcz file')

args = parser.parse_args()
# The replicate count is used to size arrays and split the frames; reject
# values that cannot describe a real split.
if args.replicates is not None and args.replicates < 1:
    parser.error('the number of replicates must be a positive integer')
testfile = args.input

pcz = pca.load(testfile)
# Subset k covers frames i1[k] up to, but not including, i2[k]; it is plotted
# as pcz.projs[pc][i1[k]:i2[k]], so the bounds must be integers.
if args.index:
    subfile = args.index
    # Index file: one subset per line, name in column 0 and the first and
    # last (inclusive, hence the +1) frame in columns 1 and 2. ndmin=2 keeps
    # a one-line file from collapsing to scalars.
    i1, i2 = np.loadtxt(subfile, usecols=(1, 2), unpack=True, ndmin=2)
    i1 = i1.astype(int)
    i2 = i2.astype(int) + 1
    subid = []
    with open(subfile, 'r') as f:
        for row in f:
            # Skip blank and '#' comment lines, as np.loadtxt does, so the
            # names stay aligned with the bounds read above.
            fields = row.split('#', 1)[0].split()
            if fields:
                subid.append(fields[0])

    nsub = len(i1)
elif args.replicates:
    nrep = args.replicates
    chunksize = pcz.n_frames // nrep
    # One subset per replicate plus a final 'All' subset. The loop variable
    # must not be named i: that would clobber the global PC index used by
    # press().
    i1 = np.zeros(nrep + 1, dtype=int)
    i2 = np.zeros(nrep + 1, dtype=int)
    subid = []
    for rep in range(nrep):
        i1[rep] = rep * chunksize
        i2[rep] = (rep + 1) * chunksize  # exclusive end
        subid.append('rep {}'.format(rep + 1))
    i1[-1] = 0
    i2[-1] = pcz.n_frames
    subid.append('All')
    nsub = len(i1)
else:
    i1 = [(0)]
    i2 = [(pcz.n_frames)]
    subid = ['All']
    nsub = 1

evals = pcz.evals
# Fixed axis limits per PC: the full data range padded by 5% on each side, so
# every subset of a PC is drawn on the same scale.
prange = np.zeros((pcz.n_vecs, 2))
for k in range(pcz.n_vecs):
    proj = pcz.projs[k]
    pmin, pmax = proj.min(), proj.max()
    mar = (pmax - pmin) * 0.05
    if mar == 0:
        # Constant projection: avoid identical lower/upper limits.
        mar = 0.5
    prange[k, 0] = pmin - mar
    prange[k, 1] = pmax + mar

# matplotlib's default key bindings overlap the keys used by this tool (for
# example 'k'/'l' toggle log-scaled axes, 'o' selects zoom mode, 'c' goes back
# and 'h' returns home), so one press would both alter the view and reach
# press(). Remove this tool's keys from every 'keymap.*' rcParam before the
# figure is created; 'q' stays bound because quitting is wanted either way.
import matplotlib
_tool_keys = {up, down, left, right, plus, minus, next, prev,
              '1', '2', '3', '4', 'a', 'c', 'h'}
for _name in list(matplotlib.rcParams.keys()):
    if _name.startswith('keymap.'):
        matplotlib.rcParams[_name] = [_key for _key in matplotlib.rcParams[_name]
                                      if _key not in _tool_keys]

fig = figure()
ax = fig.add_subplot(111)
# begin by showing basic info over a plot of the eigenvalues
info = ('PCZPLOT v0.2\n\n'
        'File : '+testfile+'\n'
        +str(pcz.n_atoms)+' atoms, '+str(pcz.n_vecs)+ ' PCs, '
        +str(pcz.n_frames)+' frames\n'
        +str(nsub)+' subsets\n\n'
        '(press \'h\' for help from any screen)')
text(0.5, 0.5, info, horizontalalignment='center',
     verticalalignment='center', transform=ax.transAxes)
ax.plot(evals)
ax.set_title(os.path.basename(testfile)+'\n'+'Eigenvalues')
ax.set_xlabel('PC')
ax.set_ylabel('Eigenvalue')
# All further screens are drawn by press() in response to key presses.
fig.canvas.mpl_connect('key_press_event', press)

show()
