#!/usr/bin/env python
#
# Script to run chronos on a photometry catalog

from __future__ import print_function

import os
import time
import numpy as np
import chronos
from astropy.io import fits
from astropy.table import Table
from argparse import ArgumentParser
from dlnpyutils import utils as dln
import subprocess
import traceback
import importlib as imp
try:
    import __builtin__ as builtins # Python 2
except ImportError:
    import builtins # Python 3

# Helpers for parsing the string-valued command-line options

def strip_brackets(text):
    """Return ``text`` without one leading '[' and one trailing ']', if present."""
    if text.startswith('['):
        text = text[1:]
    if text.endswith(']'):
        text = text[:-1]
    return text


def parse_name_list(text, dtype=None):
    """Parse a comma-separated list such as 'a,b,c' or '[a,b,c]'.

    Parameters
    ----------
    text : str
        The raw option value.  An empty string means "option not given".
    dtype : type, optional
        If given, the elements are converted to this numpy dtype
        (e.g. ``float`` for --initpar).

    Returns
    -------
    numpy array of the elements, or None if ``text`` is empty.
    """
    if text == '':
        return None
    items = strip_brackets(text).split(',')
    if dtype is None:
        return np.array(items)
    return np.array(items, dtype)


def parse_keyval_dict(text):
    """Parse 'KEY:val,KEY2=val2' (optionally wrapped in []) into a dict.

    Keys are upper-cased and values are converted to float.  Each pair may
    use either ':' or '=' as the separator (':' is checked first).

    Returns None if ``text`` is empty.  Raises ValueError if a pair has no
    separator or its value is not a number.
    """
    if text == '':
        return None
    result = {}
    for item in strip_brackets(text).split(','):
        if ':' in item:
            sep = ':'
        elif '=' in item:
            sep = '='
        else:
            raise ValueError('Use format key=value or key:value')
        key, value = item.split(sep)[:2]
        result[str(key).upper()] = float(value)
    return result


def default_output_path(fdir, base, suffix, outdir=None):
    """Build the default output filename ``base+suffix``.

    The file goes in ``outdir`` if one was requested, otherwise in the
    directory of the input file (``fdir``), otherwise in the current
    directory.
    """
    name = base + suffix
    if outdir is not None:
        return os.path.join(outdir, name)
    if fdir != '':
        return os.path.join(fdir, name)
    return name


# Main command-line program
if __name__ == "__main__":
    parser = ArgumentParser(description='Run Chronos on a photometry catalog')
    parser.add_argument('files', type=str, nargs='+', help='Photometry catalog FITS files or list')
    parser.add_argument('--catnames', type=str, nargs=1, default='', help='List of photometry column names')
    parser.add_argument('--isonames', type=str, nargs=1, default='', help='List of isochrone photometry column names')
    parser.add_argument('--outfile', type=str, nargs=1, default='', help='Output filename')
    parser.add_argument('--initpar', type=str, nargs=1, default='', help='Initial estimates for [age,metal,ext,distmod]')
    parser.add_argument('--caterrnames', type=str, nargs=1, default='', help='List of photometric uncertainty column names')
    parser.add_argument('--fixed', type=str, nargs=1, default='', help='Parameters to hold fixed (e.g. AGE:5e9,METAL:-0.75)')
    parser.add_argument('-m','--mcmc', action='store_true', help='Run MCMC')
    parser.add_argument('--reject', action='store_true', help='Reject outliers')
    parser.add_argument('--nsigrej', type=float, nargs=1, default=3.0, help='Outlier rejection Nsigma')
    parser.add_argument('--extdict', type=str, nargs=1, default='', help='Extinction coefficients (e.g. [SDSS_UMAG:1.57465,SDSS_GMAG:1.22651])')
    parser.add_argument('--figfile', type=str, nargs=1, default='', help='Figure filename')
    parser.add_argument('-d','--outdir', type=str, nargs=1, default='', help='Output directory')
    parser.add_argument('-l','--list', action='store_true', help='Input is a list of FITS files')
    parser.add_argument('-p','--plot', action='store_true', help='Save the plots')
    parser.add_argument('-c','--corner', action='store_true', help='Make corner plot with MCMC results')
    parser.add_argument('-v','--verbose', action='store_true', help='Verbose output')
    parser.add_argument('-t','--timestamp', action='store_true', help='Add timestamp to Verbose output')
    args = parser.parse_args()

    files = args.files

    # Options given with nargs=1 arrive as one-element lists (or the ''
    # default); first_el() reduces both cases to a plain value.
    catnames = parse_name_list(dln.first_el(args.catnames))
    isonames = parse_name_list(dln.first_el(args.isonames))
    caterrnames = parse_name_list(dln.first_el(args.caterrnames))
    # Initial parameter estimates; None (not '') when the option is absent,
    # consistent with the other optional inputs.
    initpar = parse_name_list(dln.first_el(args.initpar), float)
    # Parameters to hold fixed
    fixed = parse_keyval_dict(dln.first_el(args.fixed))
    # Extinction coefficients
    extdict = parse_keyval_dict(dln.first_el(args.extdict))

    inpoutfile = dln.first_el(args.outfile)
    inpfigfile = dln.first_el(args.figfile)
    mcmc = args.mcmc
    reject = args.reject
    nsigrej = dln.first_el(args.nsigrej)
    verbose = args.verbose
    timestamp = args.timestamp
    saveplot = args.plot
    corner = args.corner
    inlist = args.list

    # Output directory, created (including parents) if it does not exist
    outdir = dln.first_el(args.outdir)
    if outdir == '':
        outdir = None
    elif not os.path.exists(outdir):
        os.makedirs(outdir)

    # Timestamp requested, set up logger
    if timestamp and verbose:
        # Imported here because the file-level imports do not provide them
        import logging
        import sys
        logger = dln.basiclogger()
        logger.handlers[0].setFormatter(logging.Formatter("%(asctime)s [%(levelname)-5.5s]  %(message)s"))
        logger.handlers[0].setStream(sys.stdout)
        builtins.logger = logger   # make it available globally across all modules

    # Load files from a list
    if inlist:
        listfile = files[0]
        # Check that file exists
        if not os.path.exists(listfile):
            raise ValueError(listfile+' NOT FOUND')
        # Read in the list
        files = dln.readlines(listfile)
        # If the filenames are relative, add the list directory
        listdir = os.path.dirname(listfile)
        if listdir != '':
            files = [os.path.join(listdir, name) if os.path.dirname(name) == '' else name
                     for name in files]
    nfiles = len(files)

    if verbose and nfiles > 1:
        print('--- Running Chronos on %d photometry catalogs ---' % nfiles)

    # Loop over the files
    for i, f in enumerate(files):
        # Check that the file exists
        if not os.path.exists(f):
            print(f+' NOT FOUND')
            continue

        try:
            # Load the catalog
            cat = Table.read(f)

            if verbose:
                if nfiles > 1:
                    if i > 0:
                        print('')
                    print('Image %3d:  %s  ' % (i+1, f))
                else:
                    print('%s  ' % (f))

            # Figure filename: explicit name wins, otherwise derive one from
            # the catalog name when plots were requested
            figfile = None
            if inpfigfile != '':
                figfile = inpfigfile
            elif saveplot:
                fdir, base, ext = chronos.utils.splitfilename(f)
                figfile = default_output_path(fdir, base, '.png', outdir)

            # Need to add inputs for grid search arrays (age, metal, ext, distmod).

            # Run Chronos
            out, bestiso = chronos.chronos.fit(cat, catnames, isonames, caterrnames=caterrnames,
                                               initpar=initpar, fixed=fixed, extdict=extdict, cornername=corner,
                                               mcmc=mcmc, reject=reject, nsigrej=nsigrej, figfile=figfile,
                                               verbose=verbose)

            # Save the output
            if inpoutfile != '':
                outfile = inpoutfile
            else:
                fdir, base, ext = chronos.utils.splitfilename(f)
                outfile = default_output_path(fdir, base, '_chronos.fits', outdir)
            if verbose:
                print('Writing output to '+outfile)
            if os.path.exists(outfile):
                os.remove(outfile)
            Table(out).write(outfile)
            # Append the best model as an extra HDU; the context manager makes
            # sure the file handle is released even if the write fails
            with fits.open(outfile) as hdulist:
                hdulist.append(fits.PrimaryHDU(bestiso.data))
                hdulist.writeto(outfile, overwrite=True)

        except Exception as e:
            # Keep going with the remaining catalogs, but never fail silently:
            # the one-line message is always shown, the traceback only with -v
            print('Chronos failed on '+f+' '+str(e))
            if verbose:
                traceback.print_exc()


