#!/usr/bin/env python3
import numpy as np
import re
import os
import sys
from scipy.interpolate import make_interp_spline
from astropy import units as u

import ldc.io.yml as ymlio
from ldc.lisa.projection import ProjectedStrain, from_file, to_file
from ldc.utils.logging import init_logger, close_logger
from ldc.common.series import TimeSeries, FrequencySeries

def get_trange(cfg):
    """Return the ``(t_min, t_max, dt)`` time-grid parameters stored in *cfg*.

    Any of the three entries that is an astropy ``Quantity`` is converted to
    its numeric value in seconds and written back into *cfg* (the mapping is
    modified in place). Entries that are not quantities are left untouched.
    """
    keys = ("t_min", "t_max", "dt")
    for key in keys:
        value = cfg[key]
        if isinstance(value, u.Quantity):
            cfg[key] = value.to(u.s).value
    return tuple(cfg[key] for key in keys)


if __name__ == "__main__":

    import argparse

    # Margin appended to t_max so the output grid extends past the configured
    # end time (same units as t_max returned by get_trange).
    T_MAX_MARGIN = 500

    parser = argparse.ArgumentParser(
        description="Merge projected strain files onto a common time grid")
    # required=True: without it, a missing -i crashes later on len(None).
    parser.add_argument('-i', '--in', dest='hdf5', nargs="+", required=True,
                        help="Path to hdf5 sources file(s)")
    parser.add_argument('-c', '--config', required=True,
                        help="Path to configuration file")
    parser.add_argument('-o', '--out', default="./strain.hdf5", help="Output strain")
    parser.add_argument('-l', '--log', type=str, default="", help="Log file")
    parser.add_argument('--fft-upsampling', action='store_true',
                        help="Upsampling in Fourier domain using 0-padding")

    args = parser.parse_args()
    logger = init_logger(args.log)

    # Load instrument config. TODO: use a better format.
    cfg = ymlio.load_config(args.config)
    t_min, t_max, dt = get_trange(cfg)
    t_max += T_MAX_MARGIN
    t1 = np.arange(t_min, t_max, dt)  # common output time grid
    order = cfg['interp_order']

    logger.info("Merging %s strains", len(args.hdf5))

    # Output buffer: one row per sample of t1, 6 link columns.
    merged = np.zeros((len(t1), 6))
    lst = []    # source names accumulated over all non-empty inputs
    links = []  # link labels; overwritten by each non-empty input file read
    # Number of bins in the zero-padded spectrum (FFT path only). It depends
    # only on the output grid, so compute it once.
    n_freq = int(np.rint(len(t1) / 2))

    for hdf5file in args.hdf5:
        logger.info(f"processing {hdf5file}")
        if os.path.getsize(hdf5file) == 0:
            logger.info("Skipping empty file %s", hdf5file)
            continue

        yArm, source_names, links, t_min_, _, dt_ = from_file(hdf5file)
        n_links = yArm.shape[1]
        if dt == dt_:
            # Same sampling step: added directly, which assumes yArm already
            # lies on the t1 grid with matching shape — TODO confirm upstream.
            merged = merged + yArm
        else:
            logger.info(f"Interpolating from {hdf5file}: original dt={dt_}, target dt={dt}")
            if args.fft_upsampling:
                logger.info("Padding FFT with 0")
                for ilink in range(n_links):
                    # The input series starts at its own t_min_, not at the
                    # output grid's t_min.
                    y = TimeSeries(yArm[:, ilink], dt=dt_, t0=t_min_)
                    yfd = y.ts.fft()
                    yfd2 = FrequencySeries(np.zeros(n_freq, dtype=yfd.data.dtype),
                                           df=yfd.df)
                    # Zero-pad: copy the original bins, leave higher ones at 0.
                    yfd2.data[0:len(yfd)] = yfd.data
                    merged[:, ilink] += yfd2.ts.ifft(dt=dt).data[0:len(t1)]
            else:
                # Derive the input time axis from the actual sample count:
                # a float np.arange(t_min_, t_max_, dt_) may yield one sample
                # more or fewer than yArm has, which breaks the spline fit.
                t2 = t_min_ + dt_ * np.arange(yArm.shape[0])
                # Fit all links at once along the time axis.
                # NOTE(review): the spline extrapolates outside [t2[0], t2[-1]]
                # (scipy default); confirm that is intended for the t1 margin.
                bs = make_interp_spline(t2, yArm, k=order, axis=0)
                merged[:, :n_links] += bs(t1)

        lst += source_names

    to_file(args.out, merged, lst, links, t_min, t_max, dt)
    logger.info("Saved to disk")

