#!/usr/bin/env python3
import scipy.interpolate
import h5py
import numpy as np
from scipy import signal
from lisainstrument import Instrument
from lisainstrument.containers import ForEachMOSA

from ldc.utils.logging import init_logger, close_logger
from ldc.lisa.projection import from_file
import ldc.io.yml as ymlio

import pytdi
from pytdi.michelson import X2, Y2, Z2
from pytdi import Data

import matplotlib.pyplot as plt

# Laser central frequency handed to the Instrument; the TDI outputs are divided
# by it. Presumably in Hz — confirm against the lisainstrument documentation.
CENTRAL_FREQ = 2.816E14


def _parse_args():
    """Parse the command-line options of the simulation script."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--strain', default="",
                        help= "GW strain input file")
    parser.add_argument('--pipe-config', default="",
                        help= "Pipeline config file")
    parser.add_argument('-o', '--output', default="./",
                        help= "Output path")
    parser.add_argument('-m', '--measurements', default="measurements.h5",
                        help= "Measurements file (currently unused)")
    parser.add_argument('--orbits', default="orbits.h5",
                        help= "Orbits file")
    parser.add_argument('--glitch', default=None,
                        help= "Glitch file")
    parser.add_argument('--noisefree',  action='store_true',
                        help= "Disable all sources of noise")
    parser.add_argument('--lasernoise',  action='store_true',
                        help= "Include laser noise")
    parser.add_argument('-l', '--log', type=str, default="", help="Log file")
    return parser.parse_args()


def _prepare_gw(yArm, links, tmin, tmax, dt, dt_physic, size):
    """Build the per-link GW dict passed to ``Instrument(gws=...)``.

    Column ``yArm[:, j]`` (the response of link ``links[j]``) is resampled
    from the strain-file step ``dt`` onto a grid of step ``dt_physic``
    starting at ``tmin``, using a degree-5 interpolating spline. The
    resampling is skipped when the two steps already agree. Every array,
    including the time axis stored under ``"t"``, is truncated to ``size``
    samples.

    Returns:
        dict mapping ``"t"`` and ``"<r><s>"`` link keys to 1-D arrays.
    """
    # Derive the input time axis from the sample count so it always has as
    # many points as yArm has rows; np.arange with float bounds can produce
    # one sample too many or too few, which makes the spline constructor fail.
    tvec = tmin + dt * np.arange(yArm.shape[0])
    t_physic = np.arange(tmin, tmax + dt_physic, dt_physic)
    # dt_physic comes from a division, so compare with a tolerance instead of
    # exact float inequality to avoid a needless spline resampling.
    same_step = np.isclose(dt, dt_physic, rtol=1e-9, atol=0.0)

    gw = {"t": t_physic}
    for j, link in enumerate(links):
        key = f"{int(link[0])}{int(link[-1])}"
        if same_step:
            gw[key] = yArm[:, j]
        else:
            spline = scipy.interpolate.InterpolatedUnivariateSpline(
                tvec, yArm[:, j], k=5)
            gw[key] = spline(t_physic)
    return {key: values[:size] for key, values in gw.items()}


def _configure_noises(instrument, noisefree, lasernoise):
    """Switch instrument noise sources on/off according to the CLI flags."""
    if noisefree:
        if lasernoise:
            instrument.disable_all_noises(but='laser')
        else:
            instrument.disable_all_noises()
    else:
        instrument.disable_clock_noises()
        instrument.modulation_asds = ForEachMOSA(0)
        instrument.disable_ranging_noises()
        if not lasernoise:
            instrument.laser_asds = ForEachMOSA(0)
    instrument.disable_dopplers()


def _compute_xyz(instrument, logger, central_freq):
    """Return the Michelson TDI combinations X2, Y2, Z2, each divided by ``central_freq``."""
    logger.info("Loading data into pytdi")
    data = Data.from_instrument(instrument)
    data.delay_derivative = None
    xyz = []
    for label, combination in (("X", X2), ("Y", Y2), ("Z", Z2)):
        logger.info(f"Building {label}")
        built = combination.build(**data.args_nodoppler)
        xyz.append(built(data.measurements) / central_freq)
    return xyz


def main():
    """Simulate the instrument response to a GW strain file and write TDI XYZ.

    Reads the link responses from ``--strain`` and the simulation settings
    from ``--pipe-config``, then runs ``Instrument``. Writes a dataset
    ``XYZ`` to ``--output``, with columns t, X, Y, Z (one row per sample).
    """
    args = _parse_args()
    logger = init_logger(args.log, name='lisainstrument.instrument')
    try:
        # Read data from the file with the gravitational waves
        yArm, _snames, links, tmin, tmax, dt = from_file(
            args.strain, nodata=False, ilink=None)
        config = ymlio.load_config(args.pipe_config)

        upsampling = config["physic_upsampling"]
        dt_instru = config["dt_instrument"].to('s').value
        dt_physic = dt_instru / upsampling

        # Number of instrument samples; int() truncates, so a trailing
        # partial instrument step of the time span is dropped.
        N = int((tmax - tmin) / dt_instru)
        # Need to interpolate GW signal to use upsampling
        gw = _prepare_gw(yArm, links, tmin, tmax, dt, dt_physic, N * upsampling)

        # Only pass a Kaiser filter specification when upsampling is used.
        if upsampling == 1:
            aafilter = None
        else:
            aafilter = ('kaiser', config["sim_kaiser_attenuation"],
                        config["sim_kaiser_passband"], config["sim_kaiser_stopband"])

        instrument = Instrument(physics_upsampling=upsampling, aafilter=aafilter,
                                size=N, dt=dt_instru, t0=tmin,
                                glitches=args.glitch,
                                gws=gw,
                                orbits=args.orbits,
                                central_freq=CENTRAL_FREQ,
                                backlink_asds=config["backlinknoise"],
                                testmass_asds=config["accnoise"],
                                )
        instrument.oms_isc_carrier_asds = ForEachMOSA(config["readoutnoise"])
        _configure_noises(instrument, args.noisefree, args.lasernoise)
        instrument.simulate()  # Run simulator

        ## TDI
        x_data, y_data, z_data = _compute_xyz(instrument, logger, CENTRAL_FREQ)

        # Write output
        logger.info("Write XYZ to disk")
        # NOTE(review): the default --output "./" is a directory, which
        # h5py.File will most likely refuse to open for writing; pass a file path.
        with h5py.File(args.output, 'w') as hdf5:
            hdf5.create_dataset(
                'XYZ', data=np.vstack((instrument.t, x_data, y_data, z_data)).T)
    finally:
        # Release the log handlers even if the simulation or TDI step fails.
        close_logger(logger)


if __name__ == "__main__":
    main()



# def elliptic_filtering(data, dt_i, logger):
#     """ Apply elliptic filtering.  
#     """

#     Fs = 1/dt_i # Sampling frequency in Hz 
#     fp = 0.1 # Pass band frequency in Hz 
#     fs = 0.15 # Stop band frequency in Hz 
#     Ap = 0.1 # Pass band ripple in dB 
#     As = 100 # Stop band attenuation in dB 
    
#     # Normalized passband edge frequencies w.r.t. Nyquist rate 
#     wp = fp/(Fs/2) 
#     ws = fs/(Fs/2) 
#     N, wc = signal.ellipord(wp, ws, Ap, As) 
#     logger.info(f'Order of the filter={N}') 
#     logger.info(f'Cut-off frequency={wc}') 
    
#     # Design digital elliptic lowpass filter using signal.ellip function
#     z, p = signal.ellip(N, Ap, As, wc, 'lowpass') 
#     logger.info(f'Numerator Coefficients: {z}') 
#     logger.info(f'Denominator Coefficients: {p}') 

#     sos = signal.ellip(N, Ap, As, wc, 'lowpass', output='sos')

#     for k in ["X", "Y", "Z"]:
#         data[k] = signal.sosfilt(sos, data[k])

#     delay = 9.9
#     return data, delay
