#!/usr/bin/env python3
import numpy as np
from scipy import signal
import ldc.io.yml as ymlio
from ldc.utils.logging import init_logger, close_logger
from ldc.common.series import TimeSeries, TDI
import h5py
import scipy

def kaiser_filtering(tdi, d, logger):
    """Apply a zero-phase Kaiser-window FIR low-pass filter to X, Y, Z in place.

    The filter length and Kaiser ``beta`` come from ``scipy.signal.kaiserord``.
    Its inputs are the requested stop-band attenuation and the transition
    width between passband and stopband, expressed as a fraction of the
    Nyquist frequency of the instrument sampling.

    Args:
        tdi: mapping whose "X", "Y", "Z" entries expose a writable ``.data``
            attribute; each channel's samples are overwritten with the
            filtered samples.
        d: configuration mapping providing ``dt_instrument`` (an object with
            ``.to('s').value``), ``tdi_kaiser_attenuation``,
            ``tdi_kaiser_passband``, ``tdi_kaiser_stopband``,
            ``tdi_kaiser_cutoff`` and ``tdi_kaiser_delay``.
        logger: logger used to report the resulting filter design.

    Returns:
        The configured ``d["tdi_kaiser_delay"]`` value, unchanged.
    """
    dt_i = d["dt_instrument"].to('s').value
    attenuation = d["tdi_kaiser_attenuation"]
    freq1 = d['tdi_kaiser_passband']
    freq2 = d['tdi_kaiser_stopband']
    cutoff = d['tdi_kaiser_cutoff']
    fs = 1 / dt_i  # sampling frequency of the instrument data
    # kaiserord expects the transition width as a fraction of Nyquist (fs/2).
    numtaps, beta = signal.kaiserord(attenuation, (freq2 - freq1) / (fs / 2.))
    # NOTE(review): firwin is called with its default fs=2, so `cutoff` is
    # interpreted as a fraction of Nyquist -- confirm the config expresses it so.
    filter_coeff = signal.firwin(numtaps, cutoff, window=('kaiser', beta))
    logger.info("Kaiser filter design: numtaps=%d, beta=%.4f", numtaps, beta)

    for k in ["X", "Y", "Z"]:
        # filtfilt runs the FIR forward and backward, which yields zero phase.
        # NOTE(review): because of that, the returned tdi_kaiser_delay is
        # presumably meant to be 0 (a causal lfilter would delay by
        # (numtaps-1)/2 samples) -- confirm against the config.
        tdi[k].data = signal.filtfilt(filter_coeff, 1, tdi[k].data)

    return d["tdi_kaiser_delay"]

def read_tdi(filename, time_origin):
    """Load the X, Y, Z TDI channels from an HDF5 file into a ``TDI`` object.

    Two file layouts are recognised:

    * LISANode output (a top-level "X" dataset exists): one dataset per
      channel, with time in column 0 and the channel value in column 1.
      ``t0`` and ``dt`` are taken from the "X" dataset, and ``t0`` is shifted
      by ``time_origin``.
    * LISA Instrument output (otherwise): a single "XYZ" dataset whose
      column 0 is time and columns 1-3 are X, Y, Z. ``time_origin`` is not
      applied to this layout.

    Args:
        filename: path to the HDF5 file to read.
        time_origin: offset added to ``t0`` for the LISANode layout.

    Returns:
        ``TDI`` built from a dict mapping "X", "Y", "Z" to ``TimeSeries``.
    """
    data = {}
    # Use the caller's filename (the original read the global `args.hdf5`).
    with h5py.File(filename, "r") as f:
        if "X" in f:  # from lisanode
            # Read the first two time samples once; all channels share them.
            tcol = f["X"][:2, 0]
            t0, dt = tcol[0] + time_origin, tcol[1] - tcol[0]
            for k in ["X", "Y", "Z"]:
                data[k] = TimeSeries(f[k][:, 1], t0=t0, dt=dt)
        else:  # from lisa instrument
            xyz = f["XYZ"]
            tcol = xyz[:2, 0]
            t0, dt = tcol[0], tcol[1] - tcol[0]
            for i, k in enumerate(["X", "Y", "Z"]):
                data[k] = TimeSeries(xyz[:, i + 1], t0=t0, dt=dt)
    return TDI(data)
    
            
if __name__ == "__main__":

    import argparse
    parser = argparse.ArgumentParser()
    # --in is read unconditionally below, so make argparse enforce it.
    parser.add_argument('-i', '--in', dest='hdf5', required=True,
                        help="Path to hdf5 TDI file")
    parser.add_argument('-c', '--config', required=True, help="Path to configuration file")
    parser.add_argument('--source-config', help="Path to additional configuration file")
    parser.add_argument('-o', '--out', default="./strain.hdf5", help="Output TDI")
    parser.add_argument('-l', '--log', type=str, default="", help="Log file")
    parser.add_argument('--noisefree', action='store_true', help="Disable all sources of noise")
    parser.add_argument('--lasernoise', action='store_true', help="Include laser noise")
    parser.add_argument('--time-origin', type=float, default=0,
                        help="Shift time vector (needed to postprocess lisanode output)")
    parser.add_argument('--dt', type=float, default=-1,
                        help="Target dt")

    # `args` stays module-global: read_tdi may refer to it.
    args = parser.parse_args()
    logger = init_logger(args.log)

    try:
        cfg = ymlio.load_config(args.config)
        # NOTE(review): `order` is read but never used; the interp call below
        # passes no method argument.
        order = cfg["interp_order"]
        aafilter_deltat = cfg["sim_kaiser_delay"]

        tdi = read_tdi(args.hdf5, args.time_origin)

        ## Apply delays and filters
        # Accumulate the delays introduced by the filters so the time axis
        # can be shifted back before resampling.
        time_delay = 0
        if cfg["physic_upsampling"] != 1:
            time_delay += aafilter_deltat
        if args.lasernoise or not args.noisefree:
            time_delay += kaiser_filtering(tdi, cfg, logger)

        # The target time grid comes from the source config when one is given.
        cfg2 = ymlio.load_config(args.source_config) if args.source_config else cfg

        tmin = cfg2["t_min"].to('s').value
        tmax = cfg2["t_max"].to('s').value
        dt_out = cfg2["dt"].to('s').value if args.dt < 0 else args.dt
        logger.info(f"using target trange {tmin}, {tmax}, {dt_out}")
        tvec1 = np.arange(tmin, tmax, dt_out)
        if tvec1.size == 0:
            raise ValueError(
                f"empty target time grid: t_min={tmin}, t_max={tmax}, dt={dt_out}")

        tdi = tdi.assign_coords(t=(tdi.t - time_delay))
        tdi_downsampled = {}
        for k in ["X", "Y", "Z"]:
            resampled = tdi[k].interp(t=tvec1)
            resampled.attrs["dt"] = dt_out
            resampled.attrs["t0"] = tvec1[0]
            tdi_downsampled[k] = resampled

        tdi = TDI(tdi_downsampled)
        tdi.save(args.out)
    finally:
        # close_logger was imported but never called; release log handlers.
        close_logger(logger)
        
