#!/usr/bin/env python3
import numpy as np
import numpy.lib.recfunctions as recf
import re
import logging
import os
import sys
import yaml
from astropy import units as u
from pathlib import Path

from ldc.lisa.orbits import Orbits
from ldc.waveform.waveform.hphc import HpHc
from ldc.lisa.projection import ProjectedStrain, to_file
import ldc.io.yml as ymlio
import ldc.io.hdf5 as h5io
from ldc.utils.logging import init_logger, close_logger


def get_selection(select, nsource):
    """Return the ``(istart, iend)`` slice bounds of the sources to process.

    Parameters
    ----------
    select : str or None
        Selection of the form ``"isplit:nsplit"`` (batch index and number
        of batches). A falsy value selects every source.
    nsource : int
        Total number of sources available.

    Returns
    -------
    tuple
        ``(istart, iend)`` meant to be used as ``slice(istart, iend)``.
        ``(0, None)`` selects everything, ``(0, -1)`` is an empty slice
        when there is at most one source, and ``(0, 0)`` is returned for a
        batch index that has no source left (the script skips that case).
    """
    if not select:
        return 0, None

    idx, nbatch = get_range(select)

    # Zero or one source: only the first batch gets it.
    if nsource <= 1:
        if idx == 0:
            return 0, None
        return 0, -1  # [] empty list

    # Fewer sources than batches: one source per batch until they run out.
    if nsource < nbatch:
        if idx < nsource:
            return idx, idx + 1
        return 0, 0

    # Even split; the last batch also takes the remainder.
    per_batch, remainder = divmod(nsource, nbatch)
    first = per_batch * idx
    last = per_batch * (idx + 1)
    if idx + 1 == nbatch:
        last += remainder
    return first, last

def get_range(select):
    """Parse a ``"isplit:nsplit"`` selection string into two integers.

    Parameters
    ----------
    select : str
        Text containing two non-negative integers separated by a colon,
        like ``"0:2"``. The first ``<digits>:<digits>`` occurrence found in
        the string is used.

    Returns
    -------
    tuple of int
        ``(isplit, nsplit)``.

    Raises
    ------
    ValueError
        If ``select`` contains no ``<digits>:<digits>`` pattern.
    """
    # Require at least one digit on each side so an incomplete selection is
    # reported here instead of failing later in int('').
    found = re.search(r"([0-9]+):([0-9]+)", select)
    if found is None:
        raise ValueError(
            "Invalid selection %r: expected 'isplit:nsplit' (like 0:2)" % (select,))
    isplit, nsplit = found.groups()
    return int(isplit), int(nsplit)

def get_trange(cfg):
    """Return ``(t_min, t_max, dt)`` read from ``cfg``.

    Any of the three entries that is an astropy ``Quantity`` is converted
    to seconds and written back into ``cfg`` as a plain value, so ``cfg``
    is modified in place.

    Parameters
    ----------
    cfg : mapping
        Configuration providing the ``'t_min'``, ``'t_max'`` and ``'dt'``
        keys.

    Returns
    -------
    tuple
        ``(cfg['t_min'], cfg['t_max'], cfg['dt'])`` after conversion.
    """
    for key in ('t_min', 't_max', 'dt'):
        entry = cfg[key]
        if not isinstance(entry, u.Quantity):
            continue
        cfg[key] = entry.to(u.s).value
    start, stop, step = cfg["t_min"], cfg["t_max"], cfg["dt"]
    return start, stop, step

            
if __name__ == "__main__":
    # Script entry point: load the configuration and a slice of the source
    # catalog, accumulate the arm response of the selected sources and write
    # the accumulated result to the output file after every batch.

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--in', dest='cat', required=True,
                        help= "Path to hdf5/npy sources file")
    parser.add_argument('-c', '--config', required=True,
                        help= "Path to configuration file")
    parser.add_argument('--source-config', type=str, default="",
                        help= "Path to source configuration file")
    parser.add_argument('--approximant', type=str, default="",
                        help= "Approximant to be used, if not specified in the input catalog")
    parser.add_argument('-o', '--out', default="./strain.hdf5", help= "Output strain")
    parser.add_argument('-s', '--select', type=str,
                        help="Source selection isplit:nsplit (like 0:2 to project the first half)")
    parser.add_argument('--checkpoint', type=int, default=None,
                        help="Save projection on disk every N source")
    parser.add_argument('-l', '--log', type=str, default="",
                        help="Log file")
    parser.add_argument('--orbits', type=str, default="",  help="Orbits file")

    args = parser.parse_args()
    if args.checkpoint is not None and args.checkpoint <= 0:
        # A zero step makes range() raise, and a negative one yields no batch
        # at all, so nothing would be projected or saved.
        parser.error("--checkpoint must be a positive integer")

    logger = init_logger(args.log)
    try:
        cfg = ymlio.load_config(args.config)
        t_min = None
        source_cfg = None

        if args.source_config:
            source_cfg = ymlio.load_config(args.source_config)
            nsource = source_cfg["nsource"]
            # The source configuration may carry its own time range.
            if 't_min' in source_cfg.keys():
                t_min, t_max, dt = get_trange(source_cfg)
        else:
            # Without a source configuration, count the sources in the catalog.
            # NOTE(review): np.load is used here while the catalog is read
            # with h5io.load_array below -- confirm hdf5 input works on this
            # path.
            cat = np.load(args.cat)
            nsource = len(cat)
        if t_min is None:
            t_min, t_max, dt = get_trange(cfg)

        # add some margin to t_max
        t_max += 500

        # initialize the orbits
        if args.orbits:
            cfg["orbit_type"] = 'file'
            cfg["filename"] = args.orbits

        orbits = Orbits.type(cfg)
        nArms = orbits.number_of_arms

        istart, iend = get_selection(args.select, nsource)
        if (istart==0 and iend==0) or nsource==0:
            logger.info("Skipping %s (Less sources than jobs)"%(args.select))
            Path(args.out).touch()
            sys.exit(0) # Successful exit

        # Resolve the waveform settings before loading the catalog so that a
        # missing setting is reported clearly instead of failing inside the
        # projection loop.
        if source_cfg is None:
            parser.error("--source-config is required to get the source type")
        source_type = source_cfg["source_type"]
        if "approximant" in source_cfg.keys():
            approximant = source_cfg["approximant"]
        elif args.approximant:
            # Fall back on the command line when the configuration has none.
            approximant = args.approximant
        else:
            parser.error("no approximant found: set it in the source "
                         "configuration or with --approximant")

        cat, units = h5io.load_array(args.cat, sl=slice(istart, iend))
        if cat.size==1:
            cat = np.array([cat])
        name = os.path.basename(args.cat)

        logger.info("selec %s start %s end %s"%(args.select, str(istart), str(iend)))
        logger.info("will process %d sources"%len(cat))

        # Split in indices: one (start, stop) pair per checkpoint batch.
        ncat = len(cat)
        if args.checkpoint is None or args.checkpoint > ncat:
            indices = [(0, None)]
        else:
            indices = [(start, min(start + args.checkpoint, ncat))
                       for start in range(0, ncat, args.checkpoint)]

        # Initialize the projection on arm
        Proj = ProjectedStrain(orbits)
        nt = len(np.arange(t_min, t_max, dt))

        yArm = np.zeros((nt, nArms))
        source_names = name

        for i,j in indices:

            GWs = HpHc.type(name, source_type, approximant)
            GWs.set_param(cat[i:j], units=units)
            GWs = GWs.split()

            # Compute GW effect on links
            yArm += Proj.arm_response(t_min, t_max, dt, GWs, tt_order=cfg["travel_time_order"])
            logger.info("Projection %s:%s done"%(str(i),str(j)))

            # yArm is a running sum, so each write replaces the previous
            # checkpoint with the response accumulated so far.
            to_file(args.out, yArm, [source_names], Proj.links, t_min, t_max, dt)
            logger.info("Saved to disk")
    finally:
        # Also runs on the early exits above (sys.exit / parser.error).
        close_logger(logger)
