# Kornpob Bhirombhakdi
# kbhirombhakdi@stsci.edu

from hstgrism.grismapcorr import GrismApCorr
from scipy.interpolate import interp1d
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

class Contamination:
    """
    Contamination is a class handling computation for grism contamination.
    - Object = object of interest
    - Contaminate = another object in the FoV of Object and contaminating it
    - trace_object_csv = trace.csv of Object (generated by HSTGRISM flow)
    - trace_contaminate_csv = trace.csv of Contaminate
     - Both are read with pandas.read_csv and must provide columns xh, yh, ww, and xyref
       (xyref[0] and xyref[1] are added to xh and yh to obtain pixel positions).
    - halfdy_object = extraction aperture half-width of Object in pixels, excluding the trace pixel
      (that is, total width 2*halfdy_object + 1).
     - Either a 1D integer array parallel to the Object trace, or a single integer used for every column.
     - if halfdy_object = None, the contamination region is not computed.
     - if Object was analyzed with HSTGRISM flow, tbox.csv has halfdy_object.
    - instrument = string for available instrument list in hstgrism.grismapcorr.GrismApCorr.
     - if not None, self.compute() will grab aperture correction values (apcorr) from GrismApCorr table given instrument.
     - self.apcorr to access apcorr as a dict with key as halfdy in pixel unit and value as 1D array parallel to self.combine_df.ww_contaminate
       - Note: where ww_contaminate is NaN, it is replaced with the minimum finite ww_contaminate before querying GrismApCorr.
         This causes no conflict because rows without overlap (halfdy None) get NaN apcorr in self.combine_df.
     - self.combine_df after running self.compute() will also have columns apcorr_contaminate_in and apcorr_contaminate_out.
    - contaminate_spectrum_wavelength, contaminate_spectrum_flam, sensitivity_curve_wavelength, and sensitivity_curve_value must be supplied to compute Contaminate cps inside the Object aperture.
     - Each is a 1D array.
     - contaminate_spectrum and sensitivity_curve do not need to be parallel.
     - Use contaminate_spectrum_interp_rule and sensitivity_curve_interp_rule (keyword dicts passed to scipy.interpolate.interp1d) to control interpolation.
       Defaults: contaminate spectrum {'kind':'linear','bounds_error':False,'fill_value':np.nan};
       sensitivity curve {'kind':'linear','bounds_error':False,'fill_value':0.}.
    - Set do_drzblot_apcorr = True, and input drzblot_apcorr_ww, drzblot_apcorr_value to incorporate drzblot apcorr.
     - Use drzblot_apcorr_interp_rule to control interpolation
       (default {'kind':'nearest','bounds_error':False,'fill_value':'extrapolate'}).
     - Note: drzblot apcorr must be computed given the same halfdy_object if following HSTGRISM flow.
    - cps_contaminate_nan_to_zero = True if non-finite values in cps_contaminate should be replaced by zero.
    Use self.compute() to start computation after properly specified information.
    - if halfdy_object is not None, contamination region would be computed.
    Use self.save(container) to save outputs to ./savefolder/saveprefix_savesuffix.extension
    - savefolder, saveprefix is controlled by Container class (see hstgrism.container.Container)
    - savesuffix.extension is preset:
     - contaminateTrace.csv = merge trace of Contaminate to Object trace frame, accessible by self.combine_df
    Use self.show(save,container) to show traces, extraction region of Object, and contamination region
    - if save=True, the plot is saved at ./savefolder/saveprefix_contamination.plotformat where savefolder, saveprefix, and plotformat are defined by container (see hstgrism.container.Container)
    """
    def __init__(self, trace_object_csv, trace_contaminate_csv, halfdy_object=None, instrument=None,
                 contaminate_spectrum_wavelength=None, contaminate_spectrum_flam=None,
                 sensitivity_curve_wavelength=None, sensitivity_curve_value=None,
                 do_drzblot_apcorr=False, drzblot_apcorr_ww=None, drzblot_apcorr_value=None,
                 cps_contaminate_nan_to_zero=True,
                 contaminate_spectrum_interp_rule=None,
                 sensitivity_curve_interp_rule=None,
                 drzblot_apcorr_interp_rule=None,
                ):
        """Read both trace CSVs and store all settings; no computation happens here.

        The *_interp_rule arguments default to None, meaning "use the documented default".
        A fresh dict is built per instance so that mutating one instance's rule never
        leaks into other instances (dict literals in a signature are shared objects).
        """
        if contaminate_spectrum_interp_rule is None:
            contaminate_spectrum_interp_rule = {'kind': 'linear', 'bounds_error': False, 'fill_value': np.nan}
        if sensitivity_curve_interp_rule is None:
            sensitivity_curve_interp_rule = {'kind': 'linear', 'bounds_error': False, 'fill_value': 0.}
        if drzblot_apcorr_interp_rule is None:
            drzblot_apcorr_interp_rule = {'kind': 'nearest', 'bounds_error': False, 'fill_value': 'extrapolate'}
        self.trace_object = pd.read_csv(trace_object_csv)
        self.trace_contaminate = pd.read_csv(trace_contaminate_csv)
        self.halfdy_object = halfdy_object
        self.instrument = instrument
        self.contaminate_spectrum_wavelength = contaminate_spectrum_wavelength
        self.contaminate_spectrum_flam = contaminate_spectrum_flam
        self.sensitivity_curve_wavelength = sensitivity_curve_wavelength
        self.sensitivity_curve_value = sensitivity_curve_value
        self.do_drzblot_apcorr = do_drzblot_apcorr
        self.drzblot_apcorr_ww = drzblot_apcorr_ww
        self.drzblot_apcorr_value = drzblot_apcorr_value
        self.cps_contaminate_nan_to_zero = cps_contaminate_nan_to_zero
        self.contaminate_spectrum_interp_rule = contaminate_spectrum_interp_rule
        self.sensitivity_curve_interp_rule = sensitivity_curve_interp_rule
        self.drzblot_apcorr_interp_rule = drzblot_apcorr_interp_rule

    @staticmethod
    def _trace_to_pixels(trace):
        """Return (xg, yg, ww): integer pixel positions (np.ceil of xh/yh + xyref) and wavelengths of a trace table."""
        xg = np.ceil(trace.xh.values + trace.xyref[0]).astype(int)
        yg = np.ceil(trace.yh.values + trace.xyref[1]).astype(int)
        return xg, yg, trace.ww.values

    def compute(self):
        """Merge the Contaminate trace onto the Object pixel columns and run every computable stage.

        Always sets self.combine_df with columns xg_object, yg_object, ww_object,
        yg_contaminate (object dtype, None where the Contaminate trace does not cover the column)
        and ww_contaminate (float, NaN where not covered).
        Then, depending on which inputs were supplied at construction:
        - halfdy_object            -> halfdy_contaminate_in/out columns
        - + instrument             -> self.apcorr and apcorr_contaminate_in/out columns
        - + spectrum & sensitivity -> wwperpix/sensitivity/flam/drzblot_apcorr/apcorr/cps_contaminate columns
        A message is printed for each stage that is skipped.
        If several Contaminate samples fall in the same pixel column, the first one is used.
        """
        xg_object, yg_object, ww_object = self._trace_to_pixels(self.trace_object)
        xg_contaminate, yg_contaminate, ww_contaminate = self._trace_to_pixels(self.trace_contaminate)
        # First Contaminate sample index per pixel column, for O(1) matching.
        first_index = {}
        for idx, x in enumerate(xg_contaminate.tolist()):
            first_index.setdefault(x, idx)
        n = len(xg_object)
        yg_contaminate_wrt_object = np.full(n, None, dtype=object)
        ww_contaminate_wrt_object = np.full(n, np.nan, dtype=float)
        yg_contaminate_list = yg_contaminate.tolist()
        for ii, x in enumerate(xg_object.tolist()):
            idx = first_index.get(x)
            if idx is None:
                continue
            yg_contaminate_wrt_object[ii] = yg_contaminate_list[idx]
            ww_contaminate_wrt_object[ii] = ww_contaminate[idx]
        self.combine_df = pd.DataFrame({'xg_object': xg_object,
                                        'yg_object': yg_object,
                                        'ww_object': ww_object,
                                        'yg_contaminate': yg_contaminate_wrt_object,
                                        'ww_contaminate': ww_contaminate_wrt_object,
                                       })
        ##### make halfdy_contaminate_in and out #####
        has_halfdy = self.halfdy_object is not None
        if has_halfdy:
            self._halfdy_contaminate()
        else:
            print('halfdy_object must be specified when instantiate Contamination class in order to compute halfdy_contaminate')
        ##### prepare apcorr #####
        has_apcorr = has_halfdy and self.instrument is not None
        if has_apcorr:
            self._apcorr_contaminate()
        else:
            print('instrument and halfdy_object must be specified when instantiate Contamination class in order to compute apcorr_contaminate')
        ##### compute cps_contaminate #####
        cps_inputs = (self.contaminate_spectrum_wavelength, self.contaminate_spectrum_flam,
                      self.sensitivity_curve_wavelength, self.sensitivity_curve_value)
        if has_apcorr and all(value is not None for value in cps_inputs):
            self._cps_contaminate()
        else:
            print('contaminate_spectrum, sensitivity_curve, instrument, and halfdy_object must be specified when instantiate Contamination class in order to compute cps_contaminate')

    @staticmethod
    def _diff_same_length(values):
        """np.diff padded by repeating its last element so the output matches len(values).

        Returns an all-NaN array when fewer than two values exist (no difference is defined).
        """
        diff = np.diff(values)
        if diff.size == 0:
            return np.full(np.shape(values), np.nan)
        return np.append(diff, diff[-1])

    def _cps_contaminate(self):
        """Compute Contaminate count rate falling inside the Object aperture and store it in combine_df.

        cps = flam * (d ww / d x) * sensitivity * drzblot_apcorr * 0.5*(apcorr_out - apcorr_in),
        all evaluated at ww_contaminate. Requires apcorr_contaminate_in/out columns (from _apcorr_contaminate).
        """
        contaminate_spectrum_model = interp1d(self.contaminate_spectrum_wavelength, self.contaminate_spectrum_flam, **self.contaminate_spectrum_interp_rule)
        sensitivity_curve_model = interp1d(self.sensitivity_curve_wavelength, self.sensitivity_curve_value, **self.sensitivity_curve_interp_rule)
        ww = self.combine_df.ww_contaminate.values
        ##### wwperpix_contaminate #####
        wwperpix_contaminate = self._diff_same_length(ww) / self._diff_same_length(self.combine_df.xg_object.values)
        ##### sensitivity_contaminate #####
        sensitivity_contaminate = sensitivity_curve_model(ww)
        ##### flam projecting to GRB frame #####
        flam_contaminate = contaminate_spectrum_model(ww)
        ##### drzblot_apcorr_contaminate #####
        if self.do_drzblot_apcorr:
            drzblot_apcorr_model = interp1d(self.drzblot_apcorr_ww, self.drzblot_apcorr_value, **self.drzblot_apcorr_interp_rule)
            drzblot_apcorr_contaminate = drzblot_apcorr_model(ww)
        else:
            drzblot_apcorr_contaminate = np.ones(len(ww), dtype=float)
        ##### apcorr_contaminate #####
        # out - in covers both sides of the Contaminate trace; halve it to keep only the Object side.
        apcorr_contaminate = 0.5 * (self.combine_df.apcorr_contaminate_out.values - self.combine_df.apcorr_contaminate_in.values)
        ##### cps_contaminate #####
        cps_contaminate = flam_contaminate * wwperpix_contaminate * sensitivity_contaminate * drzblot_apcorr_contaminate * apcorr_contaminate
        if self.cps_contaminate_nan_to_zero:
            cps_contaminate = np.where(np.isfinite(cps_contaminate), cps_contaminate, 0.)
        ##### output #####
        self.combine_df['wwperpix_contaminate'] = wwperpix_contaminate
        self.combine_df['sensitivity_contaminate'] = sensitivity_contaminate
        self.combine_df['flam_contaminate'] = flam_contaminate
        self.combine_df['drzblot_apcorr_contaminate'] = drzblot_apcorr_contaminate
        self.combine_df['apcorr_contaminate'] = apcorr_contaminate
        self.combine_df['cps_contaminate'] = cps_contaminate

    def _halfdy_contaminate(self):
        """Add halfdy_contaminate_in/out columns: Contaminate-centred half-widths bracketing the Object aperture.

        With dy = |yg_contaminate - yg_object| and h = halfdy_object:
        in = dy - h - 1 (stops just before the Object aperture), out = dy + h (reaches its far edge).
        Rows without Contaminate coverage get None.
        NOTE(review): when the Contaminate trace lies inside the Object aperture (dy <= h), `in` is negative;
        the downstream apcorr for such rows is presumably not meaningful — confirm handling.
        """
        yg_object = self.combine_df.yg_object.values
        yg_contaminate = self.combine_df.yg_contaminate.values
        n = len(yg_object)
        halfdy_object = np.asarray(self.halfdy_object)
        if halfdy_object.ndim == 0:
            # Scalar half-width: use the same value for every column.
            halfdy_object = np.full(n, halfdy_object)
        halfdy_contaminate_in = np.full(n, None, dtype=object)
        halfdy_contaminate_out = np.full(n, None, dtype=object)
        for ii in range(n):
            if yg_contaminate[ii] is None:
                continue
            dy = np.abs(yg_contaminate[ii] - yg_object[ii])
            halfdy_contaminate_in[ii] = dy - halfdy_object[ii] - 1
            halfdy_contaminate_out[ii] = dy + halfdy_object[ii]
        self.combine_df['halfdy_contaminate_in'] = halfdy_contaminate_in
        self.combine_df['halfdy_contaminate_out'] = halfdy_contaminate_out

    @staticmethod
    def _lookup_apcorr(halfdy_column, apcorr):
        """Row by row, pick apcorr[halfdy][row]; NaN where halfdy is None."""
        return np.array([np.nan if halfdy is None else apcorr[halfdy][ii]
                         for ii, halfdy in enumerate(halfdy_column)], dtype=float)

    def _apcorr_contaminate(self):
        """Query GrismApCorr for every distinct contaminate half-width; set self.apcorr and apcorr_contaminate_in/out."""
        halfdy_in = self.combine_df.halfdy_contaminate_in.values
        halfdy_out = self.combine_df.halfdy_contaminate_out.values
        # Distinct half-widths across in/out (first-seen order); None marks rows without overlap.
        halfdy_keys = [halfdy for halfdy in dict.fromkeys(np.concatenate((halfdy_in, halfdy_out))) if halfdy is not None]
        apcorr = {}
        if halfdy_keys:
            scale = GrismApCorr().table[self.instrument]['scale']
            tww = self.combine_df.ww_contaminate.values.astype(float)  # astype returns a copy
            nan_mask = np.isnan(tww)
            if nan_mask.any() and not nan_mask.all():
                # GrismApCorr needs valid wavelengths; these rows are masked back to NaN below.
                tww[nan_mask] = tww[~nan_mask].min()
            for halfdy in halfdy_keys:
                apsizepix = np.full_like(tww, halfdy * 2 + 1, dtype=float)
                apsizearcsec = apsizepix * scale
                apcorrobj = GrismApCorr(instrument=self.instrument, apsize=apsizearcsec, wave=tww.copy(), aptype='diameter', apunit='arcsec', waveunit='A')
                apcorrobj.compute()
                apcorr[halfdy] = apcorrobj.data['apcorr']
        self.apcorr = apcorr
        ##### place apcorr_contaminate_in and apcorr_contaminate_out into combine_df #####
        self.combine_df['apcorr_contaminate_in'] = self._lookup_apcorr(halfdy_in, apcorr)
        self.combine_df['apcorr_contaminate_out'] = self._lookup_apcorr(halfdy_out, apcorr)
    ##########
    ##########
    ##########
    def save(self, container=None):
        """Write self.combine_df to ./savefolder/saveprefix_contaminateTrace.csv (names from container.data).

        Raises ValueError if container is None.
        """
        if container is None:
            raise ValueError('container must be specified to save.')
        string = './{0}/{1}_contaminateTrace.csv'.format(container.data['savefolder'], container.data['saveprefix'])
        self.combine_df.to_csv(string)
        print('Save {0}'.format(string))
    ##########
    ##########
    ##########
    def show(self, figsize=(10, 10),
             object_color='black', object_ls=':', object_marker='x',
             contaminate_color='red', contaminate_ls=':', contaminate_marker='x',
             fontsize=12,
             save=False, container=None,
            ):
        """Plot Object/Contaminate traces and, when halfdy_object is set, both aperture boundaries.

        Requires self.compute() to have run. If save is True, the figure is written to
        ./savefolder/saveprefix_contamination.plotformat (from container.data); raises ValueError if container is None.
        """
        plt.figure(figsize=figsize)
        combine_df = self.combine_df
        halfdy_object = self.halfdy_object
        xg = combine_df.xg_object.values
        yg_object = combine_df.yg_object.values.astype(float)
        # Object-dtype column with None -> float with NaN so matplotlib and arithmetic behave.
        yg_contaminate = combine_df.yg_contaminate.values.astype(float)
        plt.plot(xg, yg_object, ls=object_ls, color=object_color, label='Object trace')
        if halfdy_object is not None:
            plt.plot(xg, yg_object + halfdy_object, color=object_color, marker=object_marker, label='Object aperture (inclusive)')
            plt.plot(xg, yg_object - halfdy_object, color=object_color, marker=object_marker)
        plt.plot(xg, yg_contaminate, color=contaminate_color, ls=contaminate_ls, label='Contaminate trace')
        if halfdy_object is not None:
            # Boundaries lie toward the Object trace: below the Contaminate when it is above the Object, and vice versa.
            direction = np.where(yg_contaminate >= yg_object, 1., -1.)
            halfdy_in = combine_df.halfdy_contaminate_in.values.astype(float)
            halfdy_out = combine_df.halfdy_contaminate_out.values.astype(float)
            plt.plot(xg, yg_contaminate - direction * halfdy_in, color=contaminate_color, marker=contaminate_marker, label='Contaminate aperture (inclusive)')
            plt.plot(xg, yg_contaminate - direction * halfdy_out, color=contaminate_color, marker=contaminate_marker)
        plt.xlabel('pixX', fontsize=fontsize)
        plt.ylabel('pixY', fontsize=fontsize)
        plt.title('Contamination', fontsize=fontsize)
        plt.legend(loc=(1.01, 0.))
        plt.tight_layout()
        if save:
            if container is None:
                raise ValueError('container must be specified to save')
            string = './{0}/{1}_contamination.{2}'.format(container.data['savefolder'], container.data['saveprefix'], container.data['plotformat'])
            # savefig's keyword is `format`; `plotformat` is rejected by modern matplotlib.
            plt.savefig(string, format=container.data['plotformat'], bbox_inches='tight')
            print('Save {0}'.format(string))
            