#!/usr/bin/env python

import time
import os
import sys

import argparse

from glue import gpstime
from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from glue.ligolw.utils import process as ligolw_process
from glue.ligolw import ilwd

from pycbc.waveform import parameters as default_parameters

import gwnrtools.analysis as DA

# Author / contact for this script.
__author__ = "Prayush Kumar <prayush.kumar@gmail.com>"
# Absolute path of the script as invoked (from sys.argv[0]); it is passed as
# the program name when the output document's process entry is registered.
PROGRAM_NAME = os.path.abspath(sys.argv[0])

#########################################################################
#########################################################################

####################       Functions     #####################
def get_sim_hash(N=1, num_digits=10):
    """Return a fresh row identifier built from a hex tag.

    Both arguments are forwarded unchanged, as keywords, to
    ``DA.get_unique_hex_tag``. The tag it returns is substituted into the
    template ``":%s:0"`` and the resulting string is wrapped in
    ``ilwd.ilwdchar``.
    """
    hex_tag = DA.get_unique_hex_tag(N=N, num_digits=num_digits)
    return ilwd.ilwdchar(":%s:0" % hex_tag)

####################       Input parsing     #####################
# Command-line interface.
#
# argparse %-formats the usage string with a ``prog`` key, so the program
# name placeholder must be spelled ``%(prog)s``. The optparse-style
# ``%%prog`` would be printed literally as "%prog" in the help output.
parser = argparse.ArgumentParser(usage="%(prog)s [OPTIONS]", description="""
Takes in an injection or template bank file, which primarily consists of
a SimInspiral or SnglInspiral table. This script replaces the file with
another that stores all data in a table of the other kind, ie if the input
is a SnglInspiral table, then the output file has a SimInspiral table, and
vice versa.
""", formatter_class=argparse.RawTextHelpFormatter)

# IO related inputs
parser.add_argument("--input", dest="input",
                    help="Input file with Sim/SnglInspiralTable")
parser.add_argument("--output", dest="output",
                    help="Output file with Sngl/SimInspiralTable")
# Behaviour switches (both default to off)
parser.add_argument("--add-hashes", action="store_true",
                    help="Add a unique hex hash to each row",
                    default=False)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="print extra debugging information",
                    default=False)

options = parser.parse_args()

################ Read in input XML ##################
# Load the document named by --input.
xml_handler = table.use_in(ligolw.LIGOLWContentHandler)
input_doc = ligolw_utils.load_filename(options.input,
                                       contenthandler=xml_handler,
                                       verbose=options.verbose)

# Each candidate is (table type to look for, table type to write, row type to
# write). They are tried in order: sim_inspiral first, then sngl_inspiral.
_conversion_candidates = (
    (lsctables.SimInspiralTable,
     lsctables.SnglInspiralTable, lsctables.SnglInspiral),
    (lsctables.SnglInspiralTable,
     lsctables.SimInspiralTable, lsctables.SimInspiral),
)
for input_table_type, output_table_type, output_row_type in _conversion_candidates:
    try:
        input_table = input_table_type.get_table(input_doc)
    except ValueError:
        # This kind of table was not found; try the next candidate.
        continue
    break
else:
    # Neither lookup succeeded.
    raise IOError("Only sngl/sim_inspiral tables are understood in XMLs..")

################ Define columns and their mappings ##################
# Last-resort values for output columns that are neither present in the input
# nor given a default elsewhere.
sensible_defaults = dict.fromkeys(
    ['alpha', 'alpha1', 'alpha2', 'latitude', 'longitude', 'polarization'],
    0.0,
)

# Leading columns common to both output layouts.
_shared_leading_columns = [
    'mass1', 'mass2', 'mchirp', 'eta',
    'spin1x', 'spin1y', 'spin1z', 'spin2x', 'spin2y', 'spin2z',
    # for eccentricity, mean anomaly, and long ascending node
    'alpha', 'alpha1', 'alpha2',
]
columns_needed_in_sim_inspiral = _shared_leading_columns + [
    'distance', 'coa_phase', 'inclination', 'polarization', 'latitude', 'longitude',
    'simulation_id', 'process_id',
]
columns_needed_in_sngl_inspiral = _shared_leading_columns + [
    'coa_phase', 'process_id',
    'event_id',
]

# Column mappings: a column on one side that is stored under a different name
# on the other side. The sngl -> sim map is the inverse of the sim -> sngl one.
map_column_in_sim_to_sngl_inspiral = {
    'simulation_id': 'event_id',
}
map_column_in_sngl_to_sim_inspiral = dict(
    (sngl_name, sim_name)
    for sim_name, sngl_name in map_column_in_sim_to_sngl_inspiral.items()
)

# Assign for current case
# Pick the column list, the rename map and the id column that match the kind
# of table being written.
writing_sim_table = output_table_type is lsctables.SimInspiralTable
if writing_sim_table:
    columns_needed, map_columns, hash_field = (
        columns_needed_in_sim_inspiral,
        map_column_in_sngl_to_sim_inspiral,
        "simulation_id",
    )
else:
    columns_needed, map_columns, hash_field = (
        columns_needed_in_sngl_inspiral,
        map_column_in_sim_to_sngl_inspiral,
        "event_id",
    )

################ Create output document ##################
# Start from an empty document whose only child is a LIGO_LW root element.
new_points_doc = ligolw.Document()
xml_root = ligolw.LIGO_LW()
new_points_doc.appendChild(xml_root)

# Record this run (program path and parsed options) in the document.
# NOTE(review): out_proc_id is not referenced again in the visible code, so
# copied rows keep whatever process_id they had in the input -- confirm that
# this is intended.
registered_process = ligolw_process.register_to_xmldoc(
    new_points_doc, PROGRAM_NAME, vars(options))
out_proc_id = registered_process.process_id

# Empty output table of the target kind, attached under the root element.
new_points_table = lsctables.New(output_table_type, columns=columns_needed)
xml_root.appendChild(new_points_table)

############## Copy points over to the new table ##############
# Loop over input table preserving ordering
# The column layout is identical for every row, so decide once (rather than
# per row) which input columns are copied under their own name and which are
# copied under a different name in the output table.
input_columns = list(input_table.columnnames)
copied_columns = [name for name in input_columns if name in columns_needed]
renamed_columns = [(name, map_columns[name]) for name in input_columns
                   if name not in columns_needed and name in map_columns]

for old_row in input_table:
    new_row = output_row_type()

    # 1. Columns that exist on both sides under the same name.
    for name in copied_columns:
        setattr(new_row, name, getattr(old_row, name))

    # 2. Columns that are stored under another name in the output. These are
    #    applied after the direct copies, so a renamed value wins if both
    #    target the same output column.
    for old_name, new_name in renamed_columns:
        setattr(new_row, new_name, getattr(old_row, old_name))

    # 3. Fill whatever is still unset: prefer the `default` of the same-named
    #    entry in pycbc's waveform parameters, then fall back to
    #    sensible_defaults. Columns covered by neither are left unset.
    for name in columns_needed:
        if hasattr(new_row, name):
            continue
        parameter = getattr(default_parameters, name, None)
        fallback = getattr(parameter, "default", None)
        # Identity test: a default of 0.0 is a legitimate value.
        if fallback is not None:
            setattr(new_row, name, fallback)
        elif name in sensible_defaults:
            setattr(new_row, name, sensible_defaults[name])

    new_points_table.append(new_row)

# Now, loop over the new table and make sure hashes are unique
if options.add_hashes:
    # Inspect every row's id rather than only the first two: this also works
    # for tables with zero or one row (indexing row 1 would raise IndexError)
    # and catches repeats anywhere in the table. A row with no id set yet is
    # reported as None.
    current_ids = [getattr(row, hash_field, None) for row in new_points_table]
    some_id_missing = any(row_id is None for row_id in current_ids)
    some_id_repeated = len(set(current_ids)) < len(current_ids)
    if some_id_missing or some_id_repeated:
        # Replace the ids of all rows with freshly generated ones.
        for row in new_points_table:
            setattr(row, hash_field, get_sim_hash())

############## Write the new sample points to XML #############
# Report what is about to be written, flushing so the message is not held
# back in the stdout buffer.
status_line = "Writing {} new points to {}".format(len(new_points_table),
                                                   options.output)
print(status_line)
sys.stdout.flush()

# Stamp the first row of the document's process table with the current time
# as its end time, then write the document to the --output location.
new_points_proctable = lsctables.ProcessTable.get_table(new_points_doc)
finish_time = gpstime.GpsSecondsFromPyUTC(time.time())
new_points_proctable[0].end_time = finish_time
ligolw_utils.write_filename(new_points_doc, options.output)
