#!/usr/bin/env python
#
# Copyright (C) 2017 Prayush Kumar
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Calculate the fitting factors of simulated signals with a template bank."""

import sys
import os, logging
# Route all log records to stdout, prefixed with a timestamp and the level.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s : %(message)s',
)
import time
# Wall-clock time at start-up; the "Time taken" message logged at the end of
# the script is measured from here.
_itime = time.time()

import argparse
import numpy as np

import gwnrtools.analysis as DA
import gwnrtools.waveform as WF

import lal
from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils

from pycbc.waveform import get_td_waveform, get_fd_waveform
from pycbc.waveform import td_approximants, fd_approximants
from pycbc.types import FrequencySeries, TimeSeries, zeros
from pycbc.filter import make_frequency_series, match, sigma
from pycbc.detector import overhead_antenna_pattern
import pycbc.psd, pycbc.strain, pycbc.scheme


__author__ = "Prayush Kumar <prayush.kumar@gmail.com>"
# Absolute path of this script; used to identify the program in error messages.
PROGRAM_NAME = os.path.abspath(sys.argv[0])
#########################################################################
####################       Input parsing     #####################
#########################################################################
#{{{
# NOTE: argparse substitutes "%(prog)s" in the usage string. The optparse-style
# "%prog" placeholder is not expanded and would be printed literally.
parser = argparse.ArgumentParser(
    usage="%(prog)s [OPTIONS]",
    description="""
Takes in a template bank and signal / proposal points (as XML). Computes the
match between every template and every signal, and appends one line per
(template, signal) pair to the file given by --match-file.
""",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Every approximant name known to PyCBC (time or frequency domain), sorted
# so that the choices are listed in a stable order in the help output.
aprs = sorted(set(td_approximants() + fd_approximants()))

# IO related inputs
parser.add_argument("--template-file", required=True,
                    dest="bank_file_name", help="The bank file")
parser.add_argument("--signal-file", required=True,
                    dest="prop_file_name", help="The new points file")
parser.add_argument("--match-file", required=True,
                    dest="match_file_name", help="The file to store matches")
parser.add_argument("--template-batch-size", dest="bank_batch_size",
                    default=100, type=int,
                    help="Number of templates processed together in one batch")
parser.add_argument("--signal-batch-size", dest="proposal_batch_size",
                    default=100, type=int,
                    help="Number of signals processed together in one batch")

# Physics related inputs
parser.add_argument("--template-approximant", dest="bank_approximant",
                    choices=aprs, default="SpinTaylorT4",
                    help="Waveform Approximant for bank templates")
parser.add_argument("--proposal-approximant", choices=aprs,
                    default="SpinTaylorT4",
                    help="Waveform Approximant for signal / proposal templates")

# Inputs that reduce the number of matches to compute. A window of 0
# disables the corresponding check; skipped pairs are recorded with match -1.
parser.add_argument("--mchirp-window", default=0.1, type=float,
                    help="Chirp-mass window outside which a template/signal "
                         "pair is skipped; 0 disables the check")
parser.add_argument("--tau0-window", default=0.0, type=float,
                    help="tau0 window outside which a template/signal "
                         "pair is skipped; 0 disables the check")
# parser.add_argument("--eccentricity-window", dest="ecc_window", default=0.0,
#                     type=float)

# add PSD options
pycbc.psd.insert_psd_option_group(parser, output=False)

# Insert the data reading options
pycbc.strain.insert_strain_option_group(parser)

# add filtering options
parser.add_argument('-f', '--low-frequency-cutoff', metavar='FREQ',
                    dest="f_min", default=15.0, type=float,
                    help='low frequency cutoff of matched filter')
parser.add_argument("-l", "--signal-length", dest="signal_length",
                    default=128, type=int, help="The length of the signal (s)")

#hardware support
pycbc.scheme.insert_processing_option_group(parser)
pycbc.fft.insert_fft_option_group(parser)

parser.add_argument("--do-not-flush", action="store_true", default=False,
                    help="If enabled, output will be written when all "
                         "calculation is complete.")

# Miscellaneous
parser.add_argument("--tolerate-waveform-failures", action="store_true",
                    default=False,
                    help="Skip signal / bank waveforms that fail to generate")
parser.add_argument("-V", "--verbose", action="store_true", default=False,
                    help="print extra debugging information")
parser.add_argument("-C", "--comment", metavar="STRING", default='',
                    help="add the optional STRING as the process:comment")

options = parser.parse_args()

pycbc.psd.verify_psd_options(options, parser)

# Strain options only need to be consistent when a PSD is estimated from data.
if options.psd_estimation:
    pycbc.strain.verify_strain_options(options, parser)

#}}}



#########################################################################
####################### Functions to generate waveform ##################
#########################################################################
# Short module-level aliases for helpers implemented in other packages.
# NOTE(review): `generate_fplus_fcross` is not referenced anywhere else in the
# visible part of this script.
generate_fplus_fcross    = overhead_antenna_pattern
# Called in `get_waveform` with a table row and its plus / cross polarizations.
generate_detector_strain = WF.generate_detector_strain
# Used in the main loop to skip template / signal pairs without computing a
# match.
outside_mchirp_window    = DA.outside_mchirp_window
outside_tau0_window      = DA.outside_tau0_window
# Used to fill in `simulation_id` on bank rows when the bank file is loaded.
get_sim_hash             = DA.get_sim_hash

#############################
# def get_tag(wav): return str(wav.simulation_id.column_name)

def get_tag(wav):
    """Return the string tag ``<file name>:<row index>`` identifying `wav`.

    The row is looked up in the proposal table first; if it is not found
    there, it is looked up in the bank table. A ``ValueError`` propagates
    if the row is in neither table.
    """
    try:
        position = prop_table.index(wav)
    except ValueError:
        # Not a proposal row, so it has to be a bank row.
        return options.bank_file_name + ':{0}'.format(bank_table.index(wav))
    return options.prop_file_name + ':{0}'.format(position)

def waveform_exists(wav, waves):
    """Tell whether the tag of `wav` is already a key of `waves`."""
    tag = get_tag(wav)
    return tag in waves

def outside_ecc_window(bank, sim, w, key = 'alpha'):
    """Tell whether `bank` and `sim` differ by more than `w` in attribute `key`.

    Returns True when the absolute difference between ``getattr(sim, key)``
    and ``getattr(bank, key)`` is strictly greater than `w`.
    """
    # Read the bank row first, then the signal row.
    values = [getattr(row, key) for row in (bank, sim)]
    return abs(values[1] - values[0]) > w

def _report_failed_waveform(wav):
    """Print the attributes of `wav` to help debug a failed generation."""
    for name in dir(wav):
        # Skip dunder attributes and get/set style accessors.
        if "__" in name or "get" in name or "set" in name:
            continue
        if hasattr(wav, name):
            print(name, getattr(wav, name))


def _zero_padded_fd(series, n_freq, df):
    """Copy `series` into a complex128 FrequencySeries with `n_freq` bins."""
    series = FrequencySeries(series, delta_f=df,
                             dtype=np.complex128, copy=True)
    padded = FrequencySeries(zeros(n_freq), delta_f=df,
                             dtype=np.complex128, copy=True)
    padded[0:len(series)] = series
    return padded


def _zero_padded_td(series, n_samples, dt):
    """Copy `series` into a zero-initialized TimeSeries of `n_samples`."""
    padded = TimeSeries(zeros(n_samples), delta_t=dt,
                        dtype=series.dtype, copy=True)
    padded[:len(series)] = series
    return padded


def get_waveform(wav, approximant, f_min, dt, N):
    """Generate the frequency-domain detector strain for the table row `wav`.

    Parameters
    ----------
    wav : table row
        Row providing `mass1`, `mass2`, the six spin components, `coa_phase`,
        `inclination` and `distance`. The eccentricity, mean periastron
        anomaly and longitude of ascending nodes are read from `alpha`,
        `alpha1` and `alpha2` respectively.
    approximant : str
        Name of the approximant. Frequency-domain generation is used when
        the name is a frequency-domain approximant, otherwise time-domain.
    f_min : float
        Lower frequency passed to the waveform generator.
    dt : float
        Sample spacing in time.
    N : int
        Number of time samples; the frequency resolution is ``1 / (dt * N)``
        and frequency-domain waveforms are padded to ``N // 2 + 1`` bins.

    Returns
    -------
    The value returned by `generate_detector_strain` (frequency-domain
    approximants) or by `make_frequency_series` applied to it (time-domain
    approximants), or None if generation failed and
    --tolerate-waveform-failures was specified.

    Raises
    ------
    RuntimeError
        If waveform generation fails and failures are not tolerated. The
        original exception is re-raised so its traceback is preserved.
    ValueError
        If `approximant` is neither a time- nor a frequency-domain
        approximant.
    """
    #{{{
    # Array lengths must be integers: with true division `N / 2 + 1` would be
    # a float and could not be used as a length.
    N = int(N)

    m1 = wav.mass1
    m2 = wav.mass2
    params = dict(approximant=approximant,
                  mass1=m1, mass2=m2,
                  spin1x=wav.spin1x, spin1y=wav.spin1y, spin1z=wav.spin1z,
                  spin2x=wav.spin2x, spin2y=wav.spin2y, spin2z=wav.spin2z,
                  eccentricity=wav.alpha,
                  mean_per_ano=wav.alpha1,
                  long_asc_nodes=wav.alpha2,
                  coa_phase=wav.coa_phase,
                  inclination=wav.inclination, distance=wav.distance)

    df = 1./(dt * N)
    # choose a sensible max frequency: put a threshold at Mf = 0.15
    f_max = min(1./(2.*dt), 0.15 / ((m1 + m2)*lal.MTSUN_SI))

    use_fd = approximant in fd_approximants()
    if not use_fd and approximant not in td_approximants():
        raise ValueError(
            "Unknown waveform approximant: {}".format(approximant))

    try:
        if use_fd:
            hplus, hcross = get_fd_waveform(f_lower=f_min, f_final=f_max,
                                            delta_f=df, **params)
        else:
            hplus, hcross = get_td_waveform(f_lower=f_min, delta_t=dt,
                                            **params)
    except RuntimeError:
        _report_failed_waveform(wav)
        if options.tolerate_waveform_failures:
            return None
        raise

    if use_fd:
        n_freq = N // 2 + 1
        hp_padded = _zero_padded_fd(hplus, n_freq, df)
        hc_padded = _zero_padded_fd(hcross, n_freq, df)
        return generate_detector_strain(wav, hp_padded, hc_padded)

    hp_padded = _zero_padded_td(hplus, N, dt)
    hc_padded = _zero_padded_td(hcross, N, dt)
    strain_td = generate_detector_strain(wav, hp_padded, hc_padded)
    return make_frequency_series(strain_td)
    #}}}

def append_one_match(bank, sim, mval, norm_b = -1.0, norm_s = -1.0):
    """Record one result line for the pair (`bank`, `sim`).

    The line holds the two tags followed by `mval`, `norm_b` and `norm_s`,
    tab-separated. With --do-not-flush the line is kept in the in-memory
    `output` list; otherwise it is appended to the match file immediately.
    """
    fields = (get_tag(bank), get_tag(sim), mval, norm_b, norm_s)
    line = "{0}\t{1}\t{2:.12e}\t{3:.12e}\t{4:.12e}\n".format(*fields)
    if not options.do_not_flush:
        with open(options.match_file_name, "a") as myfile:
            myfile.write(line)
        return
    output.append(line)

#########################################################################
#################### Opening input/output files/tables ##################
#########################################################################
# Open the input sub-bank file and get the table
# Refuse to continue without a usable bank file.
if not options.bank_file_name:
    logging.info("No bank file-name given!")
    raise IOError("No bank file-name given to " + PROGRAM_NAME)

if not os.path.exists(options.bank_file_name):
    logging.info("This bank file does not exist !")
    raise IOError(
        "The bank file {} does not exist".format(options.bank_file_name))

if options.verbose:
    logging.info("Opening bank file %s" % options.bank_file_name)

bank_doc = ligolw_utils.load_filename(
    options.bank_file_name,
    contenthandler=table.use_in(ligolw.LIGOLWContentHandler),
    verbose=options.verbose)
try:
    # Preferred layout: the bank is stored as a sim_inspiral table.
    bank_table = lsctables.SimInspiralTable.get_table(bank_doc)
    for row in bank_table:
        if hasattr(row, 'simulation_id'):
            continue
        row.simulation_id = get_sim_hash()
except ValueError:
    try:
        # Fallback layout: a sngl_inspiral table. Every row gets a fresh
        # `simulation_id`, plus fixed values for the fields that
        # `get_waveform` and `generate_detector_strain` are given but that
        # are assigned here rather than read from the file.
        bank_table = lsctables.SnglInspiralTable.get_table(bank_doc)
        for row in bank_table:
            row.simulation_id = get_sim_hash()
            for field, value in (("inclination", 0),
                                 ("distance", 1.e6),
                                 ("latitude", 0),
                                 ("longitude", 0),
                                 ("polarization", 0)):
                setattr(row, field, value)
    except ValueError:
        raise IOError("Only sngl/sim_inspiral tables are understood for banks")

# Same checks for the injections / proposals file.
if not options.prop_file_name:
    logging.info("No injection/proposal points file-name given!")
    raise ValueError(
        "No injection/proposal points file-name given to " + PROGRAM_NAME)

if not os.path.exists(options.prop_file_name):
    logging.info("This injections/proposals file does not exist (IO ERROR)")
    raise IOError(
        "The injection/proposal points file {} does not exist !".format(
            options.prop_file_name))

logging.info("Opening injections/proposals file %s" % options.prop_file_name)
prop_doc = ligolw_utils.load_filename(
    options.prop_file_name,
    contenthandler=table.use_in(ligolw.LIGOLWContentHandler),
    verbose=options.verbose)
try:
    prop_table = lsctables.SimInspiralTable.get_table(prop_doc)
except ValueError:
    raise IOError(
        "Only sim_inspiral tables are understood for injections/proposals..")

logging.info(
    "We have {0} templates and {1} signals / proposals".format(
        len(bank_table), len(prop_table)))
#########################################################################
#############################   Compute Overlaps   ######################
#########################################################################
# Initialize quantities for overlap calculations
f_min         = options.f_min
signal_length = options.signal_length
sample_rate   = options.sample_rate

# The sample rate fixes every array length below, so fail with a clear
# message instead of a TypeError / ZeroDivisionError when it is missing.
if not sample_rate:
    parser.error("--sample-rate must be given and be non-zero")

dt            = 1. / float(sample_rate)
# Number of time samples. Forced to int so that it stays a valid array
# length even if the sample rate was parsed as a float.
N             = int(signal_length * sample_rate)
# Number of frequency bins of the one-sided spectrum. Floor division keeps
# this an integer under true division as well.
n             = N // 2 + 1
df            = 1. / float(signal_length)
if options.verbose:
    logging.info("f_min={}, sig_len={}, sample_rate={}, dt={}, N={}".format(
        f_min, signal_length, sample_rate, dt, N))

# Get hardware context
ctx = pycbc.scheme.from_cli(options)

# If we are going to use h(t) to estimate a PSD we need h(t)
if options.psd_estimation:
    logging.info("Obtaining h(t) for PSD generation")
    strain = pycbc.strain.from_cli(options, pycbc.DYN_RANGE_FAC)
else:
    strain = None

# Get the PSD with `n` frequency bins spaced by `df`
psd = pycbc.psd.from_cli(options, n, df, f_min, strain = strain)

##########################################################
### Note on the algorithm followed below:
## 0) Skip as many match calculations as possible:
##    - template / signal pairs outside the mchirp (or tau0) window,
##    - a waveform paired with itself (identical tags),
##    - empty template or signal tables (exit early).
## 1) Generate a waveform only if it is actually needed for a match.
## 2) Once generated, store every waveform in a single dict keyed by its tag
##    (see `get_tag`). No separate bank and proposal dicts are needed, as
##    tags must be unique!
## 3) Compute matches.

##########################################################
# Storage
# Cache of generated waveforms, keyed by the tag returned by `get_tag`.
waveforms = {}

# With --do-not-flush, result lines are collected here and written at the end.
if options.do_not_flush:
    output = []


def _touch_file(path):
    """Create `path` as an empty file if it does not exist (best effort)."""
    try:
        # Append mode creates the file without truncating existing content.
        with open(path, "a"):
            pass
    except (IOError, OSError) as err:
        logging.warning("Could not create {}: {}".format(path, err))


def _split_into_batches(rows, batch_size, what):
    """Split `rows` into consecutive batches of at most `batch_size` rows."""
    if batch_size <= 0:
        raise ValueError(
            "The {} batch size must be positive, got {}".format(
                what, batch_size))
    if len(rows) <= batch_size:
        return [rows]
    return [rows[start:start + batch_size]
            for start in range(0, len(rows), batch_size)]


##########################################################
# Split templates into batches
bank_batch_size = options.bank_batch_size
if len(bank_table) == 0:
    # Nothing to compute: leave an empty match file behind and stop.
    _touch_file(options.match_file_name)
    logging.info("Bank file {} is empty. Exiting!".format(options.bank_file_name))
    sys.exit(0)
bank_batches = _split_into_batches(bank_table, bank_batch_size, "template")

##########################################################
# Split injections / proposals into batches
prop_batch_size = options.proposal_batch_size
if len(prop_table) == 0:
    _touch_file(options.match_file_name)
    logging.info("Proposal file {} is empty. Exiting!".format(options.prop_file_name))
    sys.exit(0)
prop_batches = _split_into_batches(prop_table, prop_batch_size, "signal")

##########################################################
cnt_bank_generations  = 0
cnt_test_generations  = 0
cnt_match_evaluations = 0
# Norm of every generated waveform, keyed by tag, so that `sigma` is computed
# once per waveform rather than once per (template, signal) pair.
norms = {}
# NOTE(review): `waveforms` is never pruned, so memory use grows with the
# total number of generated waveforms regardless of the batch sizes.
#
# Values written to the match file:
#   match -1           : pair skipped by the mchirp / tau0 window,
#   match 1, norms 1   : template and signal carry the same tag,
#   match -2, norm -1  : a waveform failed to generate (and failures are
#                        tolerated).
with ctx:
    for i, bank_batch in enumerate(bank_batches):
        if options.verbose:
            logging.info("\t Processing bank batch {} of {} (size {})".format(
                i+1, len(bank_batches), len(bank_batch)))
        for j, prop_batch in enumerate(prop_batches):
            if options.verbose:
                logging.info("\t\t Processing proposal batch {} of {} (size {})".format(
                    j+1, len(prop_batches), len(prop_batch)))
            for k, pb in enumerate(bank_batch):
                # The tag lookup is a linear search: do it once per template.
                tag_b = get_tag(pb)
                for l, pp in enumerate(prop_batch):
                    ## Avoid computing match as much as possible!
                    if options.mchirp_window and \
                        outside_mchirp_window(pp, pb, options.mchirp_window):
                        if options.verbose:
                            logging.warning(
                                "\t Skipped ({}, {}) due to mchirp".format(k, l))
                        append_one_match(pb, pp, -1)
                        continue
                    if options.tau0_window and \
                        outside_tau0_window(pp, pb, options.tau0_window, f_min):
                        if options.verbose:
                            logging.warning(
                                "\t Skipped ({}, {}) due to tau0".format(k, l))
                        append_one_match(pb, pp, -1)
                        continue
                    tag_p = get_tag(pp)
                    if tag_p == tag_b:
                        if options.verbose:
                            logging.warning(
                                "\t Skipped ({}, {}) due to TAG".format(k, l))
                        append_one_match(pb, pp, 1, 1, 1)
                        continue

                    ## Now, we really need to get both of these waveforms!
                    # first the template
                    if tag_b in waveforms:
                        stilde = waveforms[tag_b]
                    else:
                        cnt_bank_generations += 1
                        if options.verbose:
                            logging.info(
                                "\t Computing waves for ({}, o)".format(k))
                        stilde = get_waveform(pb, options.bank_approximant,
                                              f_min, dt, N)
                        waveforms[tag_b] = stilde
                    # then the signal / injection / proposal
                    if tag_p in waveforms:
                        htilde = waveforms[tag_p]
                    else:
                        cnt_test_generations += 1
                        if options.verbose:
                            logging.info("\t Computing waves for (o, {})".format(l))
                        htilde = get_waveform(pp, options.proposal_approximant,
                                              f_min, dt, N)
                        waveforms[tag_p] = htilde

                    ## Compute match!
                    if tag_b not in norms:
                        if stilde is not None:
                            norms[tag_b] = sigma(stilde, psd = psd,
                                                 low_frequency_cutoff = f_min)
                        else:
                            norms[tag_b] = -1
                    norm_s = norms[tag_b]
                    if tag_p not in norms:
                        if htilde is not None:
                            norms[tag_p] = sigma(htilde, psd = psd,
                                                 low_frequency_cutoff = f_min)
                        else:
                            norms[tag_p] = -1
                    norm_h = norms[tag_p]
                    if stilde is not None and htilde is not None:
                        mval, _ = match(stilde, htilde, psd=psd,
                                        low_frequency_cutoff=f_min)
                    else:
                        mval = -2
                    append_one_match(pb, pp, mval, norm_s, norm_h)
                    cnt_match_evaluations += 1

# With --do-not-flush nothing has been written yet: append all collected
# result lines to the match file now.
if options.do_not_flush:
    with open(options.match_file_name, "a") as myfile:
        myfile.writelines(output)

# Final summary, only in verbose mode.
if options.verbose:
    logging.info("Written results to file: {}".format(options.match_file_name))
    summary = "Total {}+{} waves generated, {} matches evaluated.".format(
        cnt_bank_generations, cnt_test_generations, cnt_match_evaluations)
    logging.info(summary)
    logging.info("Time taken: {} seconds".format(time.time() - _itime))
