#!/usr/bin/env python3
import numpy as np
import h5py
import scipy.interpolate

from ldc.utils.logging import init_logger, close_logger
import ldc.io.yml as ymlio
import ldc.io.hdf5 as h5io

from ldc.common.series import FrequencySeries, TDI, TimeSeries

def _add_galaxy_background(tdi, galaxy_path, dt_gb=15):
    """Add a galactic background, read from *galaxy_path*, to *tdi* in place.

    Rows 0, 1, 2 of the ``"tdi"`` array loaded from *galaxy_path* are used as
    the A, E, T frequency series (with the file's ``df`` attribute). They are
    converted to X, Y, Z, inverse-FFT'd with ``dt=dt_gb``, interpolated onto
    ``tdi.t`` and added to the matching channel of *tdi*.
    """
    AET, attr = h5io.load_array(galaxy_path, name="tdi")
    df = attr["df"]
    galaxy = TDI({name: FrequencySeries(AET[i, :], df=df, kmin=0)
                  for i, name in enumerate(("A", "E", "T"))})
    galaxy.AET2XYZ()
    for k in ("X", "Y", "Z"):
        noise = galaxy[k].ts.ifft(dt=dt_gb).interp(t=tdi.t)  # time domain
        # NOTE(review): samples of tdi.t outside the galaxy's time span are
        # presumably NaN after interpolation -- confirm the galaxy covers them.
        # np.asarray keeps the in-place add a plain ndarray += ndarray instead
        # of relying on ufunc `out=` dispatch through the series object.
        tdi[k].data += np.asarray(noise)  # add background noise


def _draw_gap_offsets(seed, tmin, tmax, ngaps):
    """Return ``ngaps`` integers drawn uniformly from ``[tmin, tmax]`` (inclusive).

    A dedicated ``RandomState`` seeded with *seed* is used, so NumPy's global
    random state is left untouched. The draws are identical to the legacy
    ``np.random.seed(seed); np.random.random_integers(tmin, tmax, ngaps)``,
    which internally calls ``randint(tmin, tmax + 1, dtype='l')``. Gap
    positions for a given seed are therefore unchanged.
    """
    rng = np.random.RandomState(seed)
    return rng.randint(tmin, tmax + 1, size=ngaps, dtype="l")


def _build_gap_mask(n_samples, dt, offsets, duration):
    """Return a boolean mask of length *n_samples*, True inside the gaps.

    Each offset is measured from the start of the previous gap (the first one
    from sample 0). The start index of each gap is the running sum of
    ``int(offset / dt)``, and every gap spans ``int(duration / dt)`` samples.
    Gaps reaching past the end of the series are clipped by slicing.
    """
    mask = np.zeros(n_samples, dtype=bool)
    gap_len = int(duration / dt)
    istart = 0
    for offset in offsets:
        istart += int(offset / dt)
        mask[istart:istart + gap_len] = True
    return mask


if __name__ == "__main__":
    # Optionally add a galactic background to a TDI XYZ dataset, blank out
    # randomly placed gaps with NaN, and save the result to --output.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--tdi', default="", help= "TDI input file")
    parser.add_argument('--galaxy', default=None, help= "TDI of galaxy to be included as background noise")
    parser.add_argument('--pipe-config', default="", help= "Pipeline config file")
    parser.add_argument('--config', default="", help= "Dataset config file")
    parser.add_argument('-o', '--output', default="./", help= "Output path")
    parser.add_argument('-l', '--log', type=str, default="", help="Log file")

    args = parser.parse_args()
    logger = init_logger(args.log)
    try:
        cfg_pipe = ymlio.load_config(args.pipe_config)
        cfg = ymlio.load_config(args.config)

        # Offset (in s) between successive gap starts is drawn from
        # [gaps_tmin, gaps_tmax]; each gap lasts gaps_duration seconds.
        gaps_tmin, gaps_tmax = cfg_pipe["gaps_delta_t"]
        gaps_tmin = int(gaps_tmin.to("s").value)
        gaps_tmax = int(gaps_tmax.to("s").value)
        gaps_tmean = (gaps_tmin + gaps_tmax) / 2.
        gaps_duration = int(cfg_pipe["gaps_duration"].to("s").value)
        seed = cfg["gaps_seed"]

        tdi = TDI.load(args.tdi)
        t_min = float(tdi.t[0])
        t_max = float(tdi.t[-1])
        dt = float(tdi.t[1] - tdi.t[0])

        if args.galaxy:
            _add_galaxy_background(tdi, args.galaxy)

        # Number of gaps: observation span divided by the mean offset.
        ngaps = int((t_max - t_min) / gaps_tmean)
        logger.info(f"Will add {ngaps} gaps")

        tinj_list = _draw_gap_offsets(seed, gaps_tmin, gaps_tmax, ngaps)
        logger.info(f"will inject gap after {tinj_list}")

        mask = _build_gap_mask(len(tdi.t), dt, tinj_list, gaps_duration)
        for k in ("X", "Y", "Z"):
            tdi[k][mask] = np.nan  # apply gaps

        # Write output
        logger.info("Write XYZ to disk")
        tdi.save(args.output)
    finally:
        close_logger(logger)
