#!/usr/bin/env python
import numpy as np
import matplotlib.pyplot as plt
from time import time as clock
from os.path import join, exists, abspath, basename


from astropy.table import Table
from astropy.io import fits
from astropy.stats import sigma_clip

import scipy.optimize as optimize

from halophot.halo_tools import *

from argparse import ArgumentParser
import warnings ## eeek!
warnings.filterwarnings("ignore")

import matplotlib as mpl

# Matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8-*'
# (the old names were deprecated and later removed), so prefer the new name
# and fall back to the legacy one on older installations.
_colorblind_style = 'seaborn-v0_8-colorblind'
if _colorblind_style not in mpl.style.available:
    _colorblind_style = 'seaborn-colorblind'
mpl.style.use(_colorblind_style)

#To make sure we have always the same matplotlib settings
#(the ones in comments are the ipython notebook settings)

mpl.rcParams['figure.figsize']=(8.0,6.0)    #(6.0,4.0)
mpl.rcParams['font.size']=18               #10 
mpl.rcParams['savefig.dpi']= 200             #72 
mpl.rcParams['axes.labelsize'] = 16
mpl.rcParams['xtick.labelsize'] = 12
mpl.rcParams['ytick.labelsize'] = 12


'''-----------------------------------------------------------------
halo

This executable Python script allows you to detrend any single object using halo
photometry.

An example call is 

halo ktwo200007768-c04_lpd-targ.fits.gz --data-dir ../data/ --name atlas -c 4 --do-plot
-----------------------------------------------------------------'''

import os

# Time windows dropped for campaign 13, in the units of the 'time' column of
# the loaded time series. Each pair is (start, end); cadences with
# start <= time <= end are removed.
C13_EXCLUDED_WINDOWS = (
    (2988.2553329814764, 2988.494),
    (3064.016165412577, 3064.75),
    (3065.9776255118923, 3066.2225),
)

# Campaigns for which only cadences strictly later than this time are kept.
CAMPAIGN_START_TIME = {7: 2470, 10: 2760}

# Campaign 12: cadences strictly inside this (start, end) window are dropped.
C12_EXCLUDED_WINDOW = (2937.75, 2938.8)


def parse_float_list(text):
    """Parse a comma-separated list such as ``"[550,2200]"`` into a float array.

    Surrounding square brackets and whitespace are optional and empty entries
    are skipped.  Used as the argparse ``type`` for ``--splits`` and ``--rr``.

    Raises
    ------
    ValueError
        If an entry is not a number (argparse turns this into a usage error).
    """
    entries = text.strip().strip('[]').split(',')
    return np.array([float(entry) for entry in entries if entry.strip()],
                    dtype=float)


def campaign_keep_mask(time, campaign):
    """Return the boolean mask of cadences to keep for ``campaign``.

    Parameters
    ----------
    time : array-like
        Cadence times (the 'time' column of the time series).
    campaign : int
        Campaign number.

    Returns
    -------
    Boolean array with True for cadences that are kept, or None when the
    campaign has no cut defined here.
    """
    if campaign == 13:
        keep = np.ones(np.shape(time), dtype=bool)
        for lo, hi in C13_EXCLUDED_WINDOWS:
            # Written as (t < lo) | (t > hi) so that NaN times are dropped,
            # exactly like the chained logical_or/logical_and it replaces.
            keep = np.logical_and(keep, np.logical_or(time < lo, time > hi))
        return keep
    if campaign in CAMPAIGN_START_TIME:
        return time > CAMPAIGN_START_TIME[campaign]
    if campaign == 12:
        lo, hi = C12_EXCLUDED_WINDOW
        return ~np.logical_and(time > lo, time < hi)
    return None


def detrend_polynomial(time, flux, degree=15):
    """Fit a polynomial trend to the finite fluxes and whiten the light curve.

    Parameters
    ----------
    time, flux : array-like of equal length
    degree : int
        Degree of the polynomial passed to ``np.polyfit``.

    Returns
    -------
    trend, whitened : arrays shaped like ``flux``
        ``trend`` is the fitted polynomial; ``whitened`` is the flux minus the
        trend, shifted by the trend's median and divided by its own median.
        Both are NaN wherever ``flux`` is not finite.
    """
    finite = np.isfinite(flux)
    fit = np.poly1d(np.polyfit(time[finite], flux[finite], degree))(time[finite])

    residual = flux[finite] - fit + np.nanmedian(fit)
    residual = residual/np.nanmedian(residual)

    trend = np.nan*np.zeros_like(flux)
    trend[finite] = fit
    whitened = np.nan*np.zeros_like(flux)
    whitened[finite] = residual
    return trend, whitened


if __name__ == '__main__':
    ap = ArgumentParser(description='halophot: K2 halo photometry with total variation.')
    ap.add_argument('fname', type=str, help='Input target pixel file name.')
    ap.add_argument('--data-dir', default='', type=str)
    ap.add_argument('--name', default='test',type=str,help='Target name')
    ap.add_argument('-c', '--campaign', metavar='C',default=4, type=int, 
        help='Campaign number')
    ap.add_argument('--objective', default='tv',type=str,help='Objective function: can be tv, tv_o2, l2v, or l3v')
    ap.add_argument('--bitmask', default='default',type=str,help='Quality bitmask: can be default, hard, or hardest')
    ap.add_argument('-sub',  type=int,default=1, help='Subsampling parameter')
    ap.add_argument('--lag',  type=int,default=1, help='Lag parameter')
    ap.add_argument('-maxiter',  type=int,default=151, help='Maximum # iterations')
    ap.add_argument('--splits', default=None, type=parse_float_list,
        help='List of time values for kernel splits')
    ap.add_argument('--rr', default=None, type=parse_float_list,
        help='rmin, rmax (pix)')
    ap.add_argument('--quiet', action='store_true', default=False, 
        help='suppress messages')
    ap.add_argument('--save-dir', default='.', 
        help='The directory to save the output file in')
    ap.add_argument('--do-plot', action = 'store_true', default = False, \
                    help = 'produce plots')
    ap.add_argument('--do-split', action = 'store_true', default = False, \
                    help = 'fit the light curve in three segments divided at the split points and stitch them')
    ap.add_argument('--random-init', action = 'store_true', default = False, \
                    help = 'initialize search with random seed')
    ap.add_argument('--minflux', type=float,default=100., help='Minimum flux to include')
    ap.add_argument('--thresh', type=float,default=0.8, help='What fraction of saturation to throw away')
    # --analytic and --sigclip default to True, so on their own they are
    # no-ops; the --no-* counterparts are the only way to switch them off.
    ap.add_argument('--analytic', action = 'store_true', default = True, \
                    help = 'use analytic derivatives; orders of magnitude faster')
    ap.add_argument('--no-analytic', dest = 'analytic', action = 'store_false', \
                    help = 'do not use analytic derivatives')
    ap.add_argument('--sigclip', action = 'store_true', default = True, \
                    help = 'sigma-clip the final light curve')
    ap.add_argument('--no-sigclip', dest = 'sigclip', action = 'store_false', \
                    help = 'do not sigma-clip the final light curve')
    ap.add_argument('--deathstar', action = 'store_true', default = False, \
                    help = 'remove background star pixels')

    args = ap.parse_args()

    # Built-in split points per campaign; only campaign 4 has any.
    csplits = {4: [550,2200]}
    splits = csplits.get(args.campaign) if args.splits is None else args.splits

    # Validate everything up front so a bad invocation fails before the
    # (potentially long) optimisation instead of after it.
    if args.do_split and (splits is None or len(splits) < 2):
        ap.error('--do-split needs two split points; pass them with --splits a,b')

    if args.rr is not None and len(args.rr) != 2:
        ap.error('--rr needs exactly two values: rmin,rmax')

    if not os.path.isdir(args.save_dir):
        ap.error("the save directory {:s} doesn't exist".format(args.save_dir))

    ### first load your data
    # os.path.join rather than string concatenation, so that --data-dir
    # works with or without a trailing separator.
    fname = os.path.join(args.data_dir, args.fname)
    tpf, ts = read_tpf(fname)

    # drop the cadences excluded for this campaign, if any
    keep = campaign_keep_mask(ts['time'], args.campaign)
    if keep is not None:
        tpf, ts = tpf[keep,:,:], ts[keep]

    print('Loaded %s!'  % args.fname)

    start = clock()

    # get annulus if necessary
    if args.rr is not None:
        rmin, rmax = args.rr
        print('Getting annulus from',rmin,'to',rmax)
        tpf = get_annulus(tpf,rmin,rmax)
        print('Using',np.sum(np.isfinite(tpf[0,:,:])),'pixels')

    # destroy background stars
    if args.deathstar:
        print('Removing background stars')
        tpf = remove_stars(tpf)

    # Keyword arguments shared by every do_lc call.  Collecting them once
    # keeps the calls consistent: the middle segment used to omit `analytic`
    # and so fell back to whatever do_lc's own default is.
    lc_kwargs = dict(maxiter=args.maxiter, random_init=args.random_init,
        thresh=args.thresh, minflux=args.minflux, analytic=args.analytic,
        sigclip=args.sigclip, objective=args.objective, lag=args.lag,
        bitmask=args.bitmask)

    if args.do_split:
        print('First doing one run to establish weights')

        # the weights from the full run seed each per-segment run below
        tpf, _, weights, _, _ = do_lc(tpf, ts, (None,None), args.sub, **lc_kwargs)

        print('Splitting at',splits)
        # do first segment
        _, ts1, _, _, _ = do_lc(tpf, ts, (None,splits[0]), args.sub,
            w_init=weights, **lc_kwargs)

        # do others
        _, ts2, _, _, _ = do_lc(tpf, ts, (splits[0],splits[1]), args.sub,
            w_init=weights, **lc_kwargs)

        # the weight map that gets saved is the one from the last segment
        _, ts3, _, wmap, pixelvector = do_lc(tpf, ts, (splits[1],None), args.sub,
            w_init=weights, **lc_kwargs)

        ## now stitch these

        newts = stitch([ts1,ts2,ts3])
    else:
        print('Not splitting')
        tpf, newts, weights, wmap, pixelvector = do_lc(tpf, ts, (None,None),
            args.sub, **lc_kwargs)

    weightmap = wmap['weightmap']
    print_time(clock()-start)

    opt_lc = newts['corr_flux'][:]
    finite_lc = opt_lc[np.isfinite(opt_lc)]

    tv1 = diff_1(finite_lc/np.nanmedian(opt_lc))/float(np.size(finite_lc))

    print('Total variation per point (first order): %f ' % tv1)

    ### save your new light curve!

    # A Savitzky-Golay trend used to be computed here as well, but it was
    # immediately overwritten by the polynomial fit, so only the latter is kept.
    newts['trend'], newts['whitened'] = detrend_polynomial(newts['time'],
                                                           newts['corr_flux'])

    out_file = '%s/%s_halo_lc_%s.fits' % (args.save_dir,args.name,args.objective)

    hdu = fits.PrimaryHDU(weightmap.T) # can't save a masked array yet so just using pixelmap
    cols = [fits.Column(name=key,format="D",array=newts[key]) for key in newts.keys()]
    tab = fits.BinTableHDU.from_columns(cols)

    hdul = fits.HDUList([hdu, tab])
    hdul.writeto(out_file,overwrite=True)

    print('Saved halo-corrected light curve to %s' % out_file)

    if args.do_plot:
        plot_file = '%s/%s_all_%s.png' % (args.save_dir,args.name,args.objective)

        formal_name = translate_greek(args.name).replace('_',' ')
        image = np.nansum(tpf,axis=0)
        # mask NaN weights for plotting
        masked_weightmap = np.ma.array(weightmap,mask=np.isnan(weightmap))

        plt.figure(0)

        plot_all(newts,image,masked_weightmap,formal_name=formal_name+' - C%02d' % args.campaign,save_file=plot_file)
        print('Everything saved to %s' % plot_file)
