import numpy as np
import lisaglitch # import the LISAGlitch module
from lisaglitch.lpfdistribution import SampleLPFpar, Poisson
import ldc.io.yml as ymlio
from ldc.utils.logging import init_logger, close_logger


if __name__ == "__main__":
    # Command-line entry point: read the glitch and pipeline configuration
    # files, build the requested glitch object(s) with lisaglitch and write
    # them to the glitch output file.

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--glitch', default="", help= "Glitch output file")
    parser.add_argument('--config', default="", help= "Pipeline config file")
    parser.add_argument('--glitch-config', default="", help= "Glitch config file")
    parser.add_argument('-l', '--log', type=str, default="", help="Log file")
    args = parser.parse_args()
    logger = init_logger(args.log, name='lisaglitch.glitch')

    cfg = ymlio.load_config(args.glitch_config) # source config file
    pipe_cfg = ymlio.load_config(args.config) # pipeline config file

    # Time grid shared by every glitch; all durations are converted to seconds.
    t0 = cfg["t_min"].to("s").value
    t_max = cfg["t_max"].to("s").value
    dt = pipe_cfg["dt_instrument"].to('s').value/pipe_cfg["physic_upsampling"]
    size = t_max / dt

    glitch_type = cfg["glitch_type"]
    if glitch_type == 'short':
        # One double-exponential glitch per (injection point, time, level)
        # triplet; rise and fall times are shared by all of them.
        t_rise, t_fall = cfg["glitch_deltat"]
        inj_points = cfg["glitch_inj_point"]
        t_injs = cfg["glitch_t_inj"]
        t_levels = cfg["glitch_level"]
        for inj_point, t_inj, level in zip(inj_points, t_injs, t_levels):
            g = lisaglitch.OneSidedDoubleExpGlitch(inj_point=inj_point, t_inj=t_inj.to("s").value,
                                                   dt=dt, size=size, t0=t0,
                                                   t_rise=t_rise, t_fall=t_fall, level=level)
            g.write(path=args.glitch)

    elif glitch_type == 'long':
        # Single glitch taken from a library; "glitch_class" is "<run>/<glitch>".
        run, glitch = cfg["glitch_class"].split("/")
        path = cfg["glitch_lib"]
        inj_point = cfg["glitch_inj_point"]
        t_inj = cfg["glitch_t_inj"]
        g = lisaglitch.LPFLibraryModelGlitch(inj_point=inj_point, t_inj=t_inj.to("s").value,
                                             dt=dt, size=size, t0=t0,
                                             path=path, run=int(run), glitch=int(glitch))
        g.write(path=args.glitch)

    elif glitch_type == 'poisson':
        # Population of shapelet glitches: injection times come from a Poisson
        # sampler, (beta, amp) pairs from the trained LPF parameter model.
        path = cfg["glitch_training"]
        max_glitch = cfg["max_glitch"]
        seed = cfg["glitch_seed"]
        intervals = cfg["glitch_interval"]

        lpf_param = SampleLPFpar("cpu", cfg["min_beta"], cfg["max_beta"],
                                 cfg['min_amp'], cfg['max_amp'])
        lpf_param.define_model()
        lpf_param.load_model(path)

        # List of the injection points to sample from
        inj_points = ['tm_12', 'tm_23', 'tm_31', 'tm_13', 'tm_32', 'tm_21']

        poisson_sampler = Poisson()
        poisson_sampler.estimate_lambda(intervals)
        np.random.seed(seed) # setting the random seed
        timesarr = t0 + poisson_sampler.glitch_times_array(max_glitch)

        # Keep only the injection times below size*dt.
        timesarr = timesarr[timesarr<size*dt]
        number_samples = len(timesarr)

        # Draw twice as many parameter pairs as needed so that enough of them
        # are expected to survive the cut on beta and amplitude below.
        beta, amp = lpf_param.generate(2*number_samples)
        good = beta<cfg['beta_range'][1]
        good &= beta>cfg['beta_range'][0]
        good &= amp<cfg['amp_range'][1]
        good &= amp>cfg['amp_range'][0]
        logger.info(f"Before/After cut in ampl and beta: {2*number_samples}/{good.sum()}.")
        beta = beta[good][0:number_samples]
        amp = amp[good][0:number_samples]

        # The cut may leave fewer pairs than injection times. Without this
        # guard the sign mask and the loop below would index past the end of
        # beta/amp. Shrink the population to what is actually available.
        n_available = len(beta)
        if n_available < number_samples:
            logger.warning(f"Only {n_available} (beta, amp) samples passed the cut "
                           f"for {number_samples} injection times: "
                           f"keeping the first {n_available} glitches.")
            number_samples = n_available
            timesarr = timesarr[:number_samples]

        sign = np.random.random(size=number_samples) < 0.5  # Random choice of the glitch amplitude sign.
        amp[sign] *= -1

        for j in range(number_samples):

            g = lisaglitch.ShapeletGlitch(inj_point=np.random.choice(inj_points),
                                          t0=t0, size=size, dt=dt, t_inj=timesarr[j],
                                          beta=beta[j], amp=amp[j])
            g.write(path=args.glitch)
            if j%100==0:
                logger.info(f"wrote {j} over {number_samples}")

    else:
        # Make a mistyped or unsupported glitch type visible instead of
        # silently producing no output.
        logger.warning(f"Unknown glitch_type {glitch_type!r}: no glitch written "
                       "(expected 'short', 'long' or 'poisson').")
        
    
