#!/usr/bin/python3 -Wignore
"""StarbugII Matching 
usage: starbug2-match [-BCDfGhv] [-o output] [-p file.param] [-s KEY=VAL] table.fits ...
    -B  --band               : match in "BAND" mode (does not preserve a column for every frame)
    -C  --cascade            : match in "CASCADE" mode (left justify columns)
    -D  --dither             : match in "DITHER" mode (preserves a column for every frame)
    -f  --full               : export full catalogue
    -G  --generic            : match in "GENERIC" mode
    -h  --help               : show help message
    -o  --output  file.fits  : output matched catalogue
    -p  --param   file.param : load starbug parameter file
    -s  --set     option     : set value in parameter file at runtime (-s MATCH_THRESH=1)
    -v  --verbose            : verbose output (and full help text with -h)

    --> typical runs
       $~ starbug2-match -Gfo outfile.fits tab1.fits tab2.fits
       $~ starbug2-match -sMATCH_THRESH=0.2 -sBRIDGE_COL=F444W -Bo out.fits F*W.fits
"""

import os,sys,getopt,glob
import numpy as np
import pkg_resources
from astropy.io import fits
from astropy.table import Table, hstack, vstack
import starbug2
from starbug2 import utils
from starbug2 import matching

# Bit flags combined into the global `options` integer.
VERBOSE =1<<0
KILLPROC=1<<1

# Matching modes (selected with -B/-D/-G/-C).
BANDMATCH   =1<<2
DITHERMATCH =1<<3
GENERICMATCH=1<<4
CASCADEMATCH=1<<5

# Export behaviour (-f).
EXPFULL=1<<8

options=0

# Module-level runtime state; pfile/output/setopt are filled from the command line.
# NOTE(review): `parameter` appears unused here (the script works with `parameters`).
parameter=dict()
pfile=None
output=None
setopt=dict()

def usage():
    """Print help text and exit the program.

    When the VERBOSE bit is set in the global ``options`` the whole module
    docstring is shown; otherwise only its one-line usage summary.
    """
    helptext=__doc__ if options & VERBOSE else __doc__.split("\n")[1]
    quit(helptext)

# Parse the command line into `options`, `output`, `pfile` and `setopt`.
try:
    opts,args=getopt.getopt(sys.argv[1:], "BCDfGhvo:p:s:", ("band","cascade","dither","full", "generic", "help","verbose",
                                                    "output=", "param=", "set="))
except getopt.GetoptError as err:
    # Unknown option or missing option argument: report it and show usage
    # instead of dumping a traceback.
    utils.perror("%s\n"%err)
    usage()

# Help is deferred until every flag has been read so that "-hv" and "-vh"
# both honour -v (verbose help).
showhelp=False
for opt,optarg in opts:
    if opt in ("-h", "--help"): showhelp=True
    elif opt in ("-v", "--verbose"): options|=VERBOSE
    elif opt in ("-o", "--output"): output=optarg
    elif opt in ("-p", "--param"): pfile=optarg

    elif opt in ("-f","--full"): options|=EXPFULL

    elif opt in ("-s","--set"):
        if '=' in optarg:
            # Split on the first '=' only so values may themselves contain '='.
            key,val=optarg.split('=',1)
            # Numeric-looking values are stored as floats, anything else as str.
            try: val=float(val)
            except ValueError: pass
            setopt[key]=val
        else: utils.perror("unable to set parameter, use syntax -s KEY=VALUE\n")

    elif opt in ("-B","--band"): options|=BANDMATCH
    elif opt in ("-C","--cascade"): options|=CASCADEMATCH
    elif opt in ("-D","--dither"): options|=DITHERMATCH
    elif opt in ("-G","--generic"): options|=GENERICMATCH

if showhelp or not args: usage()

# Choose the parameter file in priority order: -p argument, ./starbug.param in
# the working directory, then the default shipped with the starbug2 package.
if pfile:
    paramfile=pfile
elif os.path.exists("./starbug.param"):
    paramfile="./starbug.param"
else:
    paramfile="%s/default.param"%pkg_resources.resource_filename("starbug2", "param/")
parameters=utils.load_params(paramfile)

if not parameters:
    utils.perror("failed to load parameter file\n")
    quit("..quitting :(")

# Values given with -s KEY=VAL override the file.
parameters.update(setopt)

####################################
# FILE IO ETC
################
tables=[]   # successfully imported catalogues
filters=[]  # filter name for each entry of `tables` (None if undetermined)
fnames=[]   # input filename for each entry of `tables`; `args` may hold files that failed to load
for fname in args:
    tab=utils.import_table(fname)
    if tab is None: continue

    utils.printf("-> loading \"%s\""%fname)
    tables.append(tab)
    fnames.append(fname)

    # Prefer the FILTER header keyword, otherwise fall back to a column named
    # after a known filter. Checking for an empty intersection avoids the
    # IndexError that list.pop() raises when no such column exists.
    fltr=tab.meta.get("FILTER")
    if not fltr:
        known=list(set(tab.colnames)&set(starbug2.filters.keys()))
        fltr=known.pop() if known else None
    filters.append(fltr)

    if fltr is not None: utils.printf(" (%s)\n"%fltr)
    else: utils.puts()

if not tables:
    utils.perror("No tables loaded for matching\n")
    quit()

# Copy before extending: extending the alias directly would append MATCH_COLS
# entries to the shared starbug2.match_cols list itself.
colnames=list(starbug2.match_cols)
for name in parameters["MATCH_COLS"].split():
    if name not in colnames: colnames.append(name)
dthreshold=parameters["MATCH_THRESH"]
nthreshold=parameters["NEXP_THRESH"]
snthresh=parameters["SN_THRESH"]

#################
# SN RATIO CUTS #
#################
if snthresh>0:
    utils.puts("SN Ratio Cuts")
    for tab,fltr,fname in zip(tables, filters, fnames):
        if not fltr:
            utils.perror("Unable to determine filter of \"%s\"\n"%fname)
            continue

        efltr="e%s"%fltr
        if fltr not in tab.colnames or efltr not in tab.colnames:
            utils.perror("Missing \"%s\" or \"%s\" column in \"%s\", skipping SN cut\n"%(fltr, efltr, fname))
            continue

        # Drop, in place, rows whose value/error ratio is below SN_THRESH.
        mask=((tab[fltr]/tab[efltr])<snthresh)
        utils.printf("-> %s: Removing %d sources\n"%(fltr, np.count_nonzero(mask)))
        tab.remove_rows(mask)


if options & BANDMATCH:
    # BAND mode: one value/error column pair per filter rather than per frame.
    fname=output if output else "out.fits"

    # Tables are grouped by the instrument of their filter, so every filter
    # must be recognised; otherwise starbug2.filters[fltr] raises a bare KeyError.
    unknown=[str(fltr) for fltr in filters if fltr not in starbug2.filters.keys()]
    if unknown:
        utils.perror("BAND matching needs a known filter for every table, got: %s\n"%", ".join(unknown))
        quit("..quitting :(")

    tomatch={ starbug2.NIRCAM:[], starbug2.MIRI:[] }
    _colnames=["RA","DEC","flag"]
    for tab,fltr in zip(tables,filters):
        tomatch[starbug2.filters[fltr].instr].append(tab)
        _colnames+=[fltr,"e%s"%fltr]

    if tomatch[starbug2.NIRCAM] and tomatch[starbug2.MIRI]:
        utils.printf("Detected NIRCam to MIRI matching\n")
        base=os.path.splitext(fname)[0]

        # Band-match each instrument separately and export the intermediates.
        nircam_matched=matching.band_match(tomatch[starbug2.NIRCAM], colnames=_colnames)
        utils.printf("\n--> %s\n"%(_fname:="%s-nircam.fits"%base))
        utils.export_table(nircam_matched, fname=_fname)

        miri_matched=matching.band_match(tomatch[starbug2.MIRI], colnames=_colnames)
        utils.printf("\n--> %s\n"%(_fname:="%s-miri.fits"%base))
        utils.export_table(miri_matched, fname=_fname)

        load=utils.loading(len(miri_matched), msg="Combining NIRCAM-MIRI(%.2g\")"%dthreshold)
        # With BRIDGE_COL set, only NIRCam rows whose BRIDGE_COL value is not
        # NaN are cross-matched to MIRI; the NaN rows are appended unmatched.
        if (bridgecol:=parameters.get("BRIDGE_COL")):
            if bridgecol not in nircam_matched.colnames:
                utils.perror("BRIDGE_COL \"%s\" not found in the NIRCam catalogue\n"%bridgecol)
                quit("..quitting :(")
            mask= np.isnan(nircam_matched[bridgecol])
            utils.printf("-> bridging catalogues with %s\n"%bridgecol)
        else: mask=np.full(len(nircam_matched), False)

        matched,_=matching.generic_match((nircam_matched[~mask],miri_matched), threshold=dthreshold, add_src=True, load=load)
        matched.remove_column("NUM")
        matched=vstack((matched, nircam_matched[mask]))
    else:
        matched=matching.band_match(tables, colnames=_colnames)

    utils.printf("--> %s\n"%fname)
    utils.export_table(matched,fname=fname)

else:
    # DITHER takes precedence over CASCADE; GENERIC is the fallback mode
    # (with or without -G) and always exports the full catalogue.
    if options & DITHERMATCH: av,full=matching.dither_match(tables, threshold=dthreshold, colnames=colnames)
    elif options & CASCADEMATCH: av,full=matching.cascade_match(tables, threshold=dthreshold, colnames=colnames)
    else:
        options|=EXPFULL
        av,full=matching.generic_match(tables,threshold=dthreshold, add_src=True, average=True, load=options&VERBOSE)

    # Force column dtypes (Catalogue_Number str, flag uint16, others float)
    # and fill empty values with NaN.
    special_dtypes={"Catalogue_Number":str, "flag":np.uint16}
    dtypes=[special_dtypes.get(name, float) for name in full.colnames]
    full=Table(full,dtype=dtypes).filled(np.nan)

    if av:
        av.meta.update(tables[0].meta)
        # Keep rows with NUM >= NEXP_THRESH (NUM presumably counts matched
        # detections); a threshold of -1 disables the cut.
        if nthreshold!=-1:
            av=av[av["NUM"]>=nthreshold]

    if output is None:
        output=utils.combine_fnames(list(args), ntrys=100)
    dname,fname,ext=utils.split_fname(output)

    utils.printf("-> %s/%s*\n"%(dname,fname))
    if options&EXPFULL: utils.export_table(full,fname="%s/%sfull.fits"%(dname,fname))
    if av: utils.export_table(av,"%s/%smatch.fits"%(dname,fname))
