#! /usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Resampling of orbit files.
"""

import logging
import h5py
import numpy
import scipy.interpolate
import ldc.io.yml as ymlio
from ldc.utils.logging import init_logger, close_logger

def upsample_orbits(path, output, dt, size, t0=0, logger=None):
    """Upsample orbit file.

    A new orbit file is created from the resampling of the original orbit file.
    Note that the new sampling must be identical for TCB and TPS (for t and tau).

    Every column is interpolated with a degree-5 spline built on the original
    time vector ('tcb/t' for TCB datasets, 'tps/tau' for TPS datasets) and
    evaluated on the new time vector.

    An exception is raised if the new time vector lies beyond the original time vector.

    Args:
        path: path to original orbit file
        output: path to resampled orbit file (opened in exclusive-creation
            mode, so it must not already exist)
        dt: target sampling time (TCB and TPS) [s]
        size: target number of samples (TCB and TPS)
        t0: initial time of the new time vector (TCB and TPS), same unit as dt
        logger: logger used for progress messages; defaults to this module's
            logger when None

    Raises:
        ValueError: if the original file has no 'dt' or 'dtau' attribute, or
            if the new time vector lies outside the original one.
        OSError: if a file cannot be opened or the output file already exists.
    """
    if logger is None:
        # The signature advertises logger=None, so it must be usable as such.
        logger = logging.getLogger(__name__)

    sc_indices = ('1', '2', '3')
    link_indices = ('12', '23', '31', '13', '32', '21')
    tcb_sc_columns = ['x', 'y', 'z', 'vx', 'vy', 'vz', 'tau', 'd_tau']
    tps_sc_columns = ['t']
    # 'nx' and 'ny' are not resampled.
    tcb_link_columns = ['tt', 'ppr', 'd_tt', 'd_ppr']
    tps_link_columns = ['ppr', 'd_ppr']

    # (dataset name, columns) pairs, in the order they are written.
    datasets = []
    for sc_index in sc_indices:
        datasets.append((f'tcb/sc_{sc_index}', tcb_sc_columns))
        datasets.append((f'tps/sc_{sc_index}', tps_sc_columns))
    for link_index in link_indices:
        datasets.append((f'tcb/l_{link_index}', tcb_link_columns))
        datasets.append((f'tps/l_{link_index}', tps_link_columns))

    logger.info("Opening orbit file '%s'", path)
    # Context managers guarantee both files are closed even when resampling
    # fails (e.g. when the spline raises on out-of-range times).
    with h5py.File(path, 'r') as oldf:

        # .get() so that a missing attribute reaches the check below instead
        # of failing earlier with a KeyError.
        original_dt = oldf.attrs.get('dt')
        original_dtau = oldf.attrs.get('dtau')
        if original_dt is None or original_dtau is None:
            raise ValueError("Cannot resample irregularly sampled orbit file")
        logger.info("Original orbit file sampled at dt=%s s, dtau=%s s",
                    original_dt, original_dtau)

        logger.info("Computing new time vector (t0=%s, dt=%s, size=%s)", t0, dt, size)
        dt = float(dt)
        size = int(size)
        t0 = float(t0)
        new_t = t0 + numpy.arange(size) * dt

        # Read each original time vector once instead of once per column.
        old_times = {
            'tcb': oldf['tcb/t'][:],
            'tps': oldf['tps/tau'][:],
        }

        logger.info("Creating resampled orbit file '%s'", output)
        with h5py.File(output, 'x') as newf:

            logger.info("Writing metadata")
            # Keep the original attributes under an 'original_' prefix.
            for k in oldf.attrs:
                newf.attrs[f'original_{k}'] = oldf.attrs[k]
            newf.attrs['generator'] = 'upsampleorbits'
            newf.attrs['dt'] = dt
            newf.attrs['dtau'] = dt
            newf.attrs['t0'] = t0
            newf.attrs['tau0'] = t0
            newf.attrs['tsize'] = size
            newf.attrs['tausize'] = size
            newf.attrs['tduration'] = size * dt
            newf.attrs['tauduration'] = size * dt
            newf.attrs['t'] = str(new_t)
            newf.attrs['tau'] = str(new_t)

            logger.info("Writing time vectors")
            newf['tcb/t'] = new_t
            newf['tps/tau'] = new_t

            for dname, columns in datasets:
                logger.info("Resampling dataset '%s'", dname)
                dtype = numpy.dtype({
                    'names': columns,
                    'formats': [numpy.float64] * len(columns),
                })
                new_dset = newf.create_dataset(dname, (size,), dtype=dtype)
                old_dset = oldf[dname]
                # The group name ('tcb' or 'tps') selects the time vector.
                old_t = old_times[dname.split('/', 1)[0]]
                for column in columns:
                    # ext='raise' forbids extrapolation beyond the original times.
                    spline = scipy.interpolate.InterpolatedUnivariateSpline(
                        old_t, old_dset[column], k=5, ext='raise')
                    new_dset[column] = spline(new_t)

            logger.info("Closing original and resampled orbit files")


if __name__ == '__main__':
    # Command-line entry point: read the sampling parameters from a pipeline
    # configuration file and upsample the given orbit file accordingly.

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--output', default="",
                        help="Output path")
    parser.add_argument('-i', '--input', dest='input_orbits', default="",
                        help="Input path")
    parser.add_argument('-c', '--config', default="",
                        help="Pipeline config file")
    parser.add_argument('-l', '--log', type=str, default="", help="Log file")
    args = parser.parse_args()
    logger = init_logger(args.log)

    d = ymlio.load_config(args.config)
    # Config entries are presumably unit-aware quantities (they expose
    # .to("s").value) -- TODO confirm against ldc.io.yml.
    # Target sampling time: instrument sampling time divided by the
    # upsampling factor.
    dt = d["dt_instrument"].to("s").value / d["physic_upsampling"]
    # Number of samples covering [t_min, t_max], plus 1500 extra samples.
    # NOTE(review): the reason for the 1500-sample margin is not visible
    # here -- confirm with the pipeline authors.
    size = int((d["t_max"].to("s").value - d["t_min"].to("s").value) / dt) + 1500
    upsample_orbits(args.input_orbits, args.output, dt, size, logger=logger)
