#!/usr/bin/env python
""""
ACTION
 - Reads in ILE output and writes a flat ASCII file compatible with tools that parse lalinference output.
EXAMPLES
  gsiscp ldas-jobs.ligo.caltech.edu:~pankow/param_est/data/zero_noise_mdc/unpin_single/zero_noise_tref_unpinned.xml.gz 
  python convert_output_format_ile2inference zero_noise_tref_unpinned.xml.gz  | more



"""


import sys
from optparse import OptionParser
import numpy as np
from ligo.lw import utils, table, lsctables, ligolw
try:
    import h5py
except:
    print(" - no h5py - ")


# Content handlers
#   - http://software.ligo.org/docs/glue/
# Hook the lsctables definitions into the content handler class that is later
# passed to utils.load_filename().
# NOTE(review): presumably this is what makes the loader build lsctables table
# objects (e.g. SimInspiralTable) -- confirm against the installed ligo.lw version.
lsctables.use_in(ligolw.LIGOLWContentHandler)

def mc(m1,m2):
    """Return (m1*m2)**(3/5) / (m1+m2)**(1/5) for scalar or array masses."""
    mass_product = m1*m2
    mass_sum = m1+m2
    numerator = np.power(mass_product, 3./5.)
    denominator = np.power(mass_sum, 1./5.)
    return numerator/denominator
def eta(m1,m2):
    """Return m1*m2 / (m1+m2)**2 for scalar or array masses."""
    mass_sum = m1+m2
    mass_sum_squared = np.power(mass_sum, 2)
    return m1*m2/mass_sum_squared


# Command-line interface.  Positional arguments (collected in `args`) are the
# input files to convert; the flags select which optional columns are written.
optp = OptionParser()
optp.add_option("--fref",default=20,type=float,help="Reference frequency. Depending on approximant and age of implementation, may be ignored")
# Help text below previously duplicated the --fref description.
optp.add_option("--export-extra-spins",action='store_true',help="Include system-frame spin quantities (theta_jn, phi_jl, tilts, phi12, spin magnitudes, psiJ)")
optp.add_option("--export-tides",action='store_true',help="Include tidal parameters")
optp.add_option("--export-eos",action='store_true',help="Include EOS index parameter")
optp.add_option("--export-cosmology",action='store_true',help="Include source frame masses and redshift")
optp.add_option("--export-weights",action='store_true',help="Include a field 'weights' equal to L p/ps")
optp.add_option("--export-eccentricity", action="store_true", help="Include eccentricity")
optp.add_option("--with-cosmology",default="Planck15",help="Specific cosmology to use")
# Help text below previously duplicated the --with-cosmology description.
optp.add_option("--use-interpolated-cosmology",action='store_true',help="Convert distance to redshift with a precomputed interpolation table (redshift < 20) instead of solving for each sample")
optp.add_option("--convention",default="RIFT",help="RIFT|LI")
opts, args = optp.parse_args()
# Optional setup for the cosmology columns: after this block, cosmo_func maps a
# luminosity distance (a number, in Mpc) to a redshift.
if opts.export_cosmology:
    import astropy, astropy.cosmology
    from astropy.cosmology import default_cosmology
    import astropy.units as u
    default_cosmology.set(opts.with_cosmology)
    cosmo = default_cosmology.get()

    if opts.use_interpolated_cosmology:
        # Fast path: tabulate distance on a redshift grid once, then invert by interpolation.
        from scipy import interpolate
        zvals = np.arange(0,20,0.0025)   # Note hardcoded maximum redshift
        dvals = np.array([float(cosmo.luminosity_distance(z_grid)/u.Mpc) for z_grid in zvals], dtype=float)
        # Interpolate. should use monotonic spline code, but ...
        interp_cosmo_func = interpolate.interp1d(dvals,zvals)
        cosmo_func = interp_cosmo_func
    else:
        # Default path: solve luminosity_distance(z) = distance for every call.
        def cosmo_func(dist_Mpc):
            return astropy.cosmology.z_at_value(cosmo.luminosity_distance, dist_Mpc*u.Mpc)


# lalsimutils is only needed up front when the extra spin columns were requested.
if opts.export_extra_spins:
    import RIFT.lalsimutils as lalsimutils


# Add LI-style export
#
# When --convention LI is given, write one header line followed by one row per
# sample for every input XML file, then exit so that the RIFT-style export at
# the end of the script is skipped.
if opts.convention == 'LI':
    # Needed unconditionally in this branch: chi_eff and chi_p are always exported.
    import RIFT.lalsimutils as lalsimutils

    # Header.  Every fragment is printed with end=' ' so the header stays on a
    # single line; only the final eccentricity/else print terminates it.
    print("# m1 m2 a1x a1y a1z a2x a2y a2z mc eta  ra dec time phiorb incl psi  distance Npts lnL p ps neff  mtotal q chi_eff chi_p",end=' ')
    if opts.export_extra_spins:
        print( 'theta_jn phi_jl tilt1 tilt2 costilt1 costilt2 phi12 a1 a2 psiJ',end=' ')
    if opts.export_tides:
        print( "lambda1 lambda2 lam_tilde",end=' ')
    if opts.export_eos:
        print( "eos_indx",end=' ')
    if opts.export_cosmology:
        print( " m1_source m2_source mc_source mtotal_source redshift ",end=' ')
    if opts.export_weights:
        # Previously printed without end=' ', which split the header over two lines.
        print( " weights ",end=' ')
    if opts.export_eccentricity:
        print( "eccentricity ")
    else:
        print('')

    for fname in args:
        # get_table takes only the document, matching the call used by the RIFT-style export.
        xmldoc = utils.load_filename(fname,contenthandler=ligolw.LIGOLWContentHandler)
        points = lsctables.SimInspiralTable.get_table(xmldoc)

        like = [row.alpha1 for row in points]  # hardcoded name
        p = [row.alpha2 for row in points]
        ps = [row.alpha3 for row in points]
        Nmax = np.max([int(row.simulation_id) for row in points])+1    # Nmax. Assumes NOT mixed samples.

        if opts.export_eccentricity:
            ecc = [row.alpha4 for row in points]

        # Per-sample weight exp(lnL) * p / ps, and the effective sample count it implies.
        wt = np.exp(like)*p/ps
        neff_here = np.sum(wt)/np.max(wt)  # neff for this file.  Assumes NOT mixed samples: dangerous

        for indx, pt in enumerate(points):
            if not(hasattr(pt,'spin1x')):  # no spins were provided. That means zero spin. Initialize to avoid an error
                pt.spin1x = pt.spin1y=pt.spin1z = 0
                pt.spin2x = pt.spin2y=pt.spin2z = 0

            # Compute derived quantities
            P=lalsimutils.ChooseWaveformParams()
            P.m1 =pt.mass1
            P.m2 =pt.mass2
            P.s1x = pt.spin1x
            P.s1y = pt.spin1y
            P.s1z = pt.spin1z
            P.s2x = pt.spin2x
            P.s2y = pt.spin2y
            P.s2z = pt.spin2z
            P.fmin=opts.fref
            if hasattr(pt, 'alpha'):
                P.eos_table_index = pt.alpha
            # Prefer the per-sample f_lower when the row carries one; otherwise keep --fref.
            try:
                P.fmin = pt.f_lower  # should use this first
            except AttributeError:
                pass
            chieff_here =P.extract_param('xi')   # written to the chi_eff column
            chip_here = P.extract_param('chi_p')
            mc_here = mc(pt.mass1,pt.mass2)
            eta_here = eta(pt.mass1,pt.mass2)
            mtot_here = pt.mass1 + pt.mass2

            # Base columns, in the same order as the header.
            print( pt.mass1, pt.mass2, pt.spin1x, pt.spin1y, pt.spin1z, pt.spin2x, pt.spin2y, pt.spin2z, mc_here, eta_here, \
                pt.longitude, \
                pt.latitude, \
                pt.geocent_end_time + 1e-9* pt.geocent_end_time_ns, \
                pt.coa_phase,  \
                pt.inclination, \
                pt.polarization, \
                pt.distance, \
                Nmax, like[indx], p[indx],ps[indx], neff_here, \
                mtot_here, pt.mass2/pt.mass1, \
                chieff_here,  \
                chip_here,end=' ')
            if opts.export_extra_spins:
                P.incl = pt.inclination  # need inclination to calculate theta_jn
                P.phiref=pt.coa_phase  # need coa_phase to calculate theta_jn ... this determines viewing angle
                thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, psiJ = P.extract_system_frame()
                print( thetaJN, phiJL, theta1, theta2, np.cos(theta1), np.cos(theta2), phi12, chi1, chi2, psiJ,end=' ')
            if opts.export_tides:
                P.lambda1 = pt.alpha5
                P.lambda2 = pt.alpha6
                lam_tilde = P.extract_param("LambdaTilde")
                print( pt.alpha5, pt.alpha6, lam_tilde,end=' ')
            if opts.export_eos:
                eos_indx = P.eos_table_index
                print(eos_indx, end=' ')
            if opts.export_cosmology:
                z = cosmo_func(pt.distance)
                m1_source = pt.mass1/(1+z)
                m2_source = pt.mass2/(1+z)
                print( m1_source, m2_source, mc_here/(1+z), mtot_here/(1+z), float(z), end=' ')
            if opts.export_weights:
                print(wt[indx],end=' ')
            # Terminate the row exactly once.
            if opts.export_eccentricity:
                print(ecc[indx])
            else:
                print('')

    # Exit only after ALL files have been written.  This used to sit inside the
    # file loop, so only the first file was converted, and with no input files
    # the script fell through to the RIFT-style export.
    sys.exit(0)

# RIFT-style export (default convention).
#
# Writes one header line, then one row per sample for every input file.
# Files whose name contains ".hdf5" are read with h5py from the
# "waveform_parameters" dataset; everything else is read as a LIGO_LW XML
# sim_inspiral table.

print( "# m1 m2 a1x a1y a1z a2x a2y a2z mc eta indx  Npts ra dec tref phiorb incl psi  dist p ps lnL mtotal q ",end=' ')
if opts.export_extra_spins:
    print( 'thetaJN phi_jl tilt1 tilt2 phi12 a1 a2 psiJ',end=' ')
if opts.export_tides:
    print( "lambda1 lambda2",end=' ')
if opts.export_eos:
    print( "eos_indx",end=' ')
if opts.export_cosmology:
    print( " m1_source m2_source redshift ",end=' ')
if opts.export_weights:
    print( " weights ",end=' ')
if opts.export_eccentricity:
    print( "eccentricity ")
else:
    print('')
for fname in args:
    if ".hdf5" in fname:
        # The option is named --export-eos; the former opts.export_eos_index
        # attribute does not exist and raised AttributeError for every hdf5 file.
        if opts.export_eos:
            raise Exception(" Not implemented for hdf5 export")
        # NOTE(review): the cosmology / weights / eccentricity columns announced in
        # the header are not written for hdf5 input -- confirm whether that is intended.
        # Load manually, to avoid problems with lnL, p, ps
        with h5py.File(fname, 'r') as f:   # closed even if a row fails to print
            arr = f["waveform_parameters"]
            for indx in range(len(arr)):
                line = arr[indx]
                # eta is computed from both masses (it was eta(m1, m1), i.e. always 0.25).
                print( line[0], line[1], line[2], line[3], line[4], line[5], line[6], line[7],  mc(line[0],line[1]), eta(line[0],line[1]), \
                    0, \
                    0, \
                    line[12], \
                    line[11], \
                    line[13], \
                    line[10],  \
                    line[14], \
                    line[9], \
                    line[8], \
                    line[-2],line[-1],line[-3], line[0]+line[1], line[1]/line[0],end=' ')
                if opts.export_extra_spins:
                    P = lalsimutils.ChooseWaveformParams(m1=line[0],m2=line[1], s1x=line[2], s1y=line[3], s1z=line[4], s2x=line[5], s2y=line[6],s2z=line[7])
                    P.fmin = line[-5]  # should use this first
                    thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, psiJ = P.extract_system_frame()
                    print( thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, psiJ,end=' ')
                if opts.export_tides:
                    print( line[-7], line[-6],end=' ')
                print('')
    else:
        points = lsctables.SimInspiralTable.get_table(utils.load_filename(fname,contenthandler=ligolw.LIGOLWContentHandler))
        like = [row.alpha1 for row in points]  # hardcoded name
        p = [row.alpha2 for row in points]
        ps = [row.alpha3 for row in points]
        Nmax = np.max([int(row.simulation_id) for row in points])+1
        sim_id = np.array([int(row.simulation_id) for row in points])+1

        if opts.export_eccentricity:
            ecc = [row.alpha4 for row in points]

        # Per-sample weight exp(lnL) * p / ps
        wt = np.exp(like)*p/ps

        for indx, pt in enumerate(points):
            if not(hasattr(pt,'spin1x')):  # no spins were provided. That means zero spin. Initialize to avoid an error
                pt.spin1x = pt.spin1y=pt.spin1z = 0
                pt.spin2x = pt.spin2y=pt.spin2z = 0
            print( pt.mass1, pt.mass2, pt.spin1x, pt.spin1y, pt.spin1z, pt.spin2x, pt.spin2y, pt.spin2z, mc(pt.mass1,pt.mass2), eta(pt.mass1,pt.mass2), \
                sim_id[indx], \
                Nmax, \
                pt.longitude, \
                pt.latitude, \
                pt.geocent_end_time + 1e-9* pt.geocent_end_time_ns, \
                pt.coa_phase,  \
                pt.inclination, \
                pt.polarization, \
                pt.distance, \
                p[indx],ps[indx],like[indx], pt.mass1+pt.mass2, pt.mass2/pt.mass1,end=' ')
            if opts.export_extra_spins:
                P = lalsimutils.ChooseWaveformParams(m1=pt.mass1,m2=pt.mass2, s1x=pt.spin1x, s1y=pt.spin1y, s1z=pt.spin1z, s2x=pt.spin2x, s2y=pt.spin2y,s2z=pt.spin2z)
                P.fmin=opts.fref
                # Prefer the per-sample f_lower when the row carries one; otherwise keep --fref.
                try:
                    P.fmin = pt.f_lower  # should use this first
                except AttributeError:
                    pass
                thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, psiJ = P.extract_system_frame()
                print( thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, psiJ,end=' ')
            if opts.export_tides:
                print(pt.alpha5, pt.alpha6,end=' ')
            if opts.export_eos:
                print(pt.alpha, end=' ')
            if opts.export_cosmology:
                # Convert to a plain float, as the LI-style export does, so the
                # redshift column is a bare number whatever type cosmo_func returns.
                z = float(cosmo_func(pt.distance))
                m1_source = pt.mass1/(1+z)
                m2_source = pt.mass2/(1+z)
                print( m1_source, m2_source, z,end=' ')
            if opts.export_weights:
                print(wt[indx],end=' ')
            # Terminate the row exactly once.
            if opts.export_eccentricity:
                print(ecc[indx])
            else:
                print('')
