#!/usr/bin/env python
#
# Script to run doppler.fit() on a spectrum

from __future__ import print_function

import importlib as imp
import logging
import os
import subprocess
import sys
import time
import traceback
from argparse import ArgumentParser

import numpy as np
from astropy.io import fits
from astropy.table import Table
from dlnpyutils import utils as dln

import doppler

try:
    import __builtin__ as builtins # Python 2
except ImportError:
    import builtins # Python 3

# Main command-line program
def build_parser():
    """Create the command-line parser for the Doppler fitting script.

    Returns
    -------
    parser : argparse.ArgumentParser
        Parser holding every option of the script.  Options declared with
        ``nargs=1`` come back from argparse as one-element lists, which is
        why ``main`` unwraps them with ``dln.first_el``.
    """
    parser = ArgumentParser(description='Run Doppler fitting on spectra')
    parser.add_argument('files', type=str, nargs='+', help='Spectrum FITS files or list')
    parser.add_argument('--outfile', type=str, nargs=1, default='', help='Output filename')
    parser.add_argument('--payne', action='store_true', help='Fit a Payne model')
    parser.add_argument('--fitpars', type=str, nargs=1, default='', help='Payne labels to fit (e.g. TEFF,LOGG,FE_H')
    parser.add_argument('--fixpars', type=str, nargs=1, default='', help='Payne labels to hold fixed (e.g. TEFF:5500,LOGG:2.3')
    parser.add_argument('--figfile', type=str, nargs=1, default='', help='Figure filename')
    parser.add_argument('-d','--outdir', type=str, nargs=1, default='', help='Output directory')
    parser.add_argument('-l','--list', action='store_true', help='Input is a list of FITS files')
    parser.add_argument('-j','--joint', action='store_true', help='Joint fit all the spectra')
    parser.add_argument('--snrcut', type=float, default=10.0, help='S/N threshold to fit spectrum separately')
    parser.add_argument('-p','--plot', action='store_true', help='Save the plots')
    parser.add_argument('-c','--corner', action='store_true', help='Make corner plot with MCMC results')
    parser.add_argument('-m','--mcmc', action='store_true', help='Run MCMC when fitting spectra individually')
    parser.add_argument('-r','--reader', type=str, nargs=1, default='', help='The spectral reader to use')
    parser.add_argument('-v','--verbose', action='store_true', help='Verbose output')
    parser.add_argument('-t','--timestamp', action='store_true', help='Add timestamp to Verbose output')
    parser.add_argument('-nth','--nthreads', type=int, nargs=1, default=0, help='Number of threads to use')
    parser.add_argument('--notweak', action='store_true', help='Do not tweak the continuum using the model')
    parser.add_argument('--tpoly', action='store_true', help='Use low-order polynomial for tweaking')
    parser.add_argument('--tpolyorder', type=int, nargs=1, default=3, help='Polynomial order to use for tweaking')
    return parser


def parse_fitparams(fitpars):
    """Turn the ``--fitpars`` string into a list of upper-case label names.

    Parameters
    ----------
    fitpars : str
        Comma-separated label names (e.g. ``'teff,logg'``), or ``''`` when
        the option was not given.

    Returns
    -------
    fitparams : list of str or None
        Upper-cased, whitespace-stripped names; ``None`` when no names were
        supplied.
    """
    names = [name.strip().upper() for name in fitpars.split(',') if name.strip()]
    return names if names else None


def parse_fixparams(fixpars):
    """Turn the ``--fixpars`` string into a ``{LABEL: value}`` dictionary.

    Parameters
    ----------
    fixpars : str
        Comma-separated ``key:value`` or ``key=value`` pairs, or ``''`` when
        the option was not given.

    Returns
    -------
    fixparams : dict
        Upper-cased label names mapped to float values (empty if no input).

    Raises
    ------
    ValueError
        If an entry contains neither ``:`` nor ``=``, or its value is not a
        valid float.
    """
    fixparams = {}
    if fixpars == '':
        return fixparams
    for item in fixpars.split(','):
        # ':' takes precedence over '=' when both appear in one entry
        if ':' in item:
            pieces = item.split(':')
        elif '=' in item:
            pieces = item.split('=')
        else:
            raise ValueError('Use format key=value or key:value')
        fixparams[str(pieces[0]).strip().upper()] = float(pieces[1])
    return fixparams


def add_list_directory(files, listdir):
    """Prefix bare filenames with the directory of the list file.

    Parameters
    ----------
    files : sequence of str
        Filenames read from the list file.
    listdir : str
        Directory component of the list file path (``''`` if none).

    Returns
    -------
    files : list of str
        Entries that have no directory component get ``listdir`` prepended;
        all other entries are returned unchanged.
    """
    if listdir == '':
        return list(files)
    return [listdir+'/'+f if os.path.dirname(f) == '' else f for f in files]


def read_filelist(listfile):
    """Read the spectrum filenames contained in a list file.

    Parameters
    ----------
    listfile : str
        Path of the list file; its lines are read with ``dln.readlines``.

    Returns
    -------
    files : list of str
        Filenames, with the list file's directory added to bare names.

    Raises
    ------
    ValueError
        If ``listfile`` does not exist.
    """
    if not os.path.exists(listfile):
        raise ValueError(listfile+' NOT FOUND')
    return add_list_directory(dln.readlines(listfile), os.path.dirname(listfile))


def default_path(specfile, suffix, outdir=None):
    """Build a default product filename for a spectrum.

    The name is the base name reported by ``doppler.utils.splitfilename``
    plus ``suffix``.  It is placed in ``outdir`` when one is given, otherwise
    next to the spectrum (or in the current directory if the spectrum path
    has no directory component).
    """
    fdir, base, ext = doppler.utils.splitfilename(specfile)
    name = base+suffix
    if outdir is not None:
        return outdir+'/'+name
    if fdir != '':
        return fdir+'/'+name
    return name


def setup_timestamp_logger():
    """Install a timestamped stdout logger as ``builtins.logger``.

    The logger comes from ``dln.basiclogger()``; its first handler is given
    a timestamped format and pointed at stdout.  Storing it on ``builtins``
    makes it visible from every module.
    """
    logger = dln.basiclogger()
    handler = logger.handlers[0]
    handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)-5.5s]  %(message)s"))
    handler.setStream(sys.stdout)
    builtins.logger = logger   # make it available globally across all modules
    return logger


def register_reader(reader):
    """Register a user-supplied spectral reader with ``doppler.reader``.

    Parameters
    ----------
    reader : str or None
        Either ``None``, a reader/format name, or the path of a Python file
        defining a function named like the file (without extension).

    Returns
    -------
    reader : str or None
        The name to pass as ``format`` to ``doppler.read``.  When ``reader``
        is not an existing file it is returned unchanged.
    """
    if reader is None or not os.path.exists(reader):
        return reader
    # Imported here: ``import importlib`` alone does not guarantee that the
    # ``util`` submodule is loaded, and this keeps start-up unaffected on
    # interpreters lacking it when no custom reader is requested.
    import importlib.util
    readerfile = reader
    name = os.path.splitext(os.path.basename(readerfile))[0]
    modspec = importlib.util.spec_from_file_location(name, readerfile)
    readermodule = importlib.util.module_from_spec(modspec)
    modspec.loader.exec_module(readermodule)
    doppler.reader.register(name, getattr(readermodule, name))
    return name


def write_output(outfile, tables, models, verbose=False):
    """Write fit results to a FITS file.

    Parameters
    ----------
    outfile : str
        Output filename; an existing file is overwritten.
    tables : sequence
        Table-like results.  The first is written with ``Table.write``; any
        further ones are appended as table HDUs.
    models : sequence
        Best-fit models; the ``flux`` of each is appended as its own HDU.
    verbose : bool, optional
        Print the output filename before writing.
    """
    if verbose:
        print('Writing output to '+outfile)
    Table(tables[0]).write(outfile, overwrite=True)
    # Re-open to append the remaining products; the context manager makes
    # sure the file handle is released even if appending/writing fails.
    with fits.open(outfile) as hdulist:
        for tab in tables[1:]:
            hdulist.append(fits.table_to_hdu(Table(tab)))
        for mod in models:
            hdulist.append(fits.PrimaryHDU(mod.flux))
        hdulist.writeto(outfile, overwrite=True)


def fit_individual(files, fitkwargs, reader, inpoutfile, inpfigfile, outdir,
                   saveplot, corner, verbose):
    """Run ``doppler.fit`` on each spectrum separately and save the results.

    Missing files are reported and skipped.  A failure on one spectrum is
    reported (with a traceback when ``verbose``) and does not stop the loop.

    Parameters
    ----------
    files : list of str
        Spectrum filenames.
    fitkwargs : dict
        Keyword arguments forwarded unchanged to ``doppler.fit``.
    reader : str or None
        ``format`` passed to ``doppler.read``.
    inpoutfile, inpfigfile : str
        User-requested output / figure filenames (``''`` for the defaults).
    outdir : str or None
        Directory for default-named products.
    saveplot, corner : bool
        Whether to save the fit figure / the corner plot.
    verbose : bool
        Print progress information.
    """
    nfiles = len(files)
    if verbose and nfiles > 1:
        print('--- Running Doppler Fit on %d spectra ---' % nfiles)

    for i, f in enumerate(files):
        if not os.path.exists(f):
            print(f+' NOT FOUND')
            continue

        try:
            spec = doppler.read(f, format=reader)

            if verbose:
                if nfiles > 1:
                    if i > 0:
                        print('')
                    print('Spectrum %3d:  %s  S/N=%6.1f ' % (i+1, f, spec.snr))
                else:
                    print('%s  S/N=%6.1f ' % (f, spec.snr))

            # Figure name: an explicit --figfile only applies to a single file
            figfile = None
            if nfiles == 1 and inpfigfile != '':
                figfile = inpfigfile
            if inpfigfile == '' and saveplot:
                figfile = default_path(f, '_dopfit.png', outdir)
            # Corner plot name
            cornername = None
            if corner:
                cornername = default_path(f, '_doppler_corner.png', outdir)

            # Third return value (specm) is currently not saved
            out, model, specm = doppler.fit(spec, figfile=figfile,
                                            cornername=cornername, **fitkwargs)

            if inpoutfile != '':
                outfile = inpoutfile
            else:
                outfile = default_path(f, '_doppler.fits', outdir)
            write_output(outfile, [out], [model], verbose)

        except Exception as e:
            # Keep going with the remaining spectra, but never hide the failure
            print('Doppler failed on '+f+' '+str(e))
            if verbose:
                traceback.print_exc()


def fit_joint(files, fitkwargs, reader, inpoutfile, outdir, saveplot, snrcut,
              verbose):
    """Run ``doppler.jointfit`` on all spectra together and save the results.

    Parameters
    ----------
    files : list of str
        Spectrum filenames; every one must exist.
    fitkwargs : dict
        Keyword arguments forwarded unchanged to ``doppler.jointfit``.
    reader : str or None
        ``format`` passed to ``doppler.read``.
    inpoutfile : str
        User-requested output filename; ``''`` derives it from ``files[0]``.
    outdir : str or None
        Output directory, also forwarded to ``doppler.jointfit``.
    saveplot : bool
        Forwarded to ``doppler.jointfit``.
    snrcut : float
        Forwarded to ``doppler.jointfit``.
    verbose : bool
        Print progress information.

    Raises
    ------
    ValueError
        If any input file does not exist.
    """
    if inpoutfile == '':
        outfile = default_path(files[0], '_doppler.fits', outdir)
    else:
        outfile = inpoutfile

    if verbose:
        print('--- Running Doppler Jointfit on '+str(len(files))+' spectra ---')
        print('')
        print('Loading the spectra')

    speclist = []
    for i, f in enumerate(files):
        if not os.path.exists(f):
            raise ValueError(f+' NOT FOUND')
        spec = doppler.read(f, format=reader)
        if verbose:
            print('Spectrum %3d:  %s  S/N=%6.1f ' % (i+1, f, spec.snr))
        speclist.append(spec)

    if verbose:
        print('')
    sumstr, final, model, specmlist = doppler.jointfit(speclist, saveplot=saveplot,
                                                       snrcut=snrcut, outdir=outdir,
                                                       **fitkwargs)

    # Summary table first, then the per-spectrum table, then one HDU per model
    write_output(outfile, [sumstr, final], model, verbose)


def main(argv=None):
    """Command-line entry point: parse the options and run the fitting.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; ``None`` makes argparse use ``sys.argv[1:]``.

    Raises
    ------
    ValueError
        For a missing list file, an empty file list, a malformed
        ``--fixpars`` entry, or ``--outfile``/``--figfile`` combined with
        several files that are fitted individually.
    """
    args = build_parser().parse_args(argv)

    files = args.files
    inpoutfile = dln.first_el(args.outfile)
    inpfigfile = dln.first_el(args.figfile)
    fitparams = parse_fitparams(dln.first_el(args.fitpars))
    fixparams = parse_fixparams(dln.first_el(args.fixpars))
    outdir = dln.first_el(args.outdir)
    if outdir == '':
        outdir = None
    elif not os.path.exists(outdir):
        # makedirs so that a nested output directory can be created as well
        os.makedirs(outdir)
    reader = dln.first_el(args.reader)
    if reader == '':
        reader = None
    nthreads = dln.first_el(args.nthreads)
    if nthreads == 0:
        nthreads = None
    tpolyorder = dln.first_el(args.tpolyorder)
    verbose = args.verbose
    joint = args.joint

    # Timestamp requested, set up logger
    if args.timestamp and verbose:
        setup_timestamp_logger()

    # Load files from a list
    if args.list:
        files = read_filelist(files[0])
    nfiles = len(files)
    if nfiles == 0:
        raise ValueError('No input files to process')

    # Outfile and figfile can ONLY be used with a SINGLE file or JOINT fitting
    if nfiles > 1 and not joint:
        if inpoutfile != '':
            raise ValueError('--outfile can only be used with a SINGLE input file or JOINT fitting')
        if inpfigfile != '':
            raise ValueError('--figfile can only be used with a SINGLE input file or JOINT fitting')

    # Register spectral reader
    reader = register_reader(reader)

    # Options shared by doppler.fit() and doppler.jointfit()
    fitkwargs = dict(payne=args.payne, fitparams=fitparams, fixparams=fixparams,
                     mcmc=args.mcmc, verbose=verbose, nthreads=nthreads,
                     notweak=args.notweak, tpoly=args.tpoly, tpolyorder=tpolyorder)

    if not joint or nfiles == 1:
        fit_individual(files, fitkwargs, reader, inpoutfile, inpfigfile, outdir,
                       args.plot, args.corner, verbose)
    else:
        fit_joint(files, fitkwargs, reader, inpoutfile, outdir, args.plot,
                  args.snrcut, verbose)


if __name__ == "__main__":
    main()
