#!/usr/bin/env python3
from astropy import units as u
import ldc.io.hdf5 as hdf5io
import ldc.io.npz as npzio

# Copy an LDC HDF5 file while adding the units that are missing from its
# configurations and sky catalogs. Intended for the non-blind file only.

if __name__ == "__main__":
    # Read every config, sky catalog and TDI array from --input-file and
    # append them to --output-file, attaching units to the "obs/config"
    # entries and to the attributes of the galactic-binary catalogs.

    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-file', type=str, required=True,
                        help="Input hdf5 file without units")
    parser.add_argument('--output-file', type=str, required=True,
                        help="Output hdf5 file with units")
    args = parser.parse_args()

    # Every dataset is re-saved under the same name with mode='a'; targeting
    # the input file would write over the data being read.
    if os.path.abspath(args.input_file) == os.path.abspath(args.output_file):
        parser.error("--input-file and --output-file must be different files")

    configs = ["header", "instru/config", "obs/config",
               "sky/vgb/config", "sky/dgb/config", "sky/igb/config",
               "sky/mbhb/config"]

    # Units attached to the matching keys of "obs/config".
    units_obs = {'dt': 's', 'initial_position': 'rad', 'initial_rotation': 'rad',
                 'nominal_arm_length': 'm', 't_max': 's', 't_min': 's'}

    # Column units written as attributes of the galactic-binary catalogs.
    u_gb = {'EclipticLatitude':         'rad',
            'EclipticLongitude':        'rad',
            'Amplitude':                '1',
            'Frequency':                'Hz',
            'FrequencyDerivative':      'Hz2',
            'Inclination':              'rad',
            'Polarization':             'rad',
            'InitialPhase':             'rad'}
    # Column units for the MBHB catalog.
    # NOTE(review): these are never written to the output; the MBHB catalog
    # is copied unchanged and its attributes only printed (see below).
    u_mbhb = {'EclipticLatitude':       'rad',
              'EclipticLongitude':      'rad',
              'PolarAngleOfSpin1':      'rad',
              'PolarAngleOfSpin2':      'rad',
              'Spin1':                  '1',
              'Spin2':                  '1',
              'Mass1':                  'Msun',
              'Mass2':                  'Msun',
              'CoalescenceTime':        's',
              'PhaseAtCoalescence':     'rad',
              'InitialPolarAngleL':     'rad',
              'InitialAzimuthalAngleL': 'rad',
              'Cadence':                's',
              'Redshift':               '1',
              'Distance':               'Gpc',
              'ObservationDuration':    's'}

    cats = ["sky/vgb/cat", "sky/dgb/cat", "sky/igb/cat", "sky/mbhb/cat"]
    units_cats = [u_gb, u_gb, u_gb, u_mbhb]
    # Catalogs whose attributes receive units; the others are copied as-is.
    gb_cats = {"sky/vgb/cat", "sky/dgb/cat", "sky/igb/cat"}

    tdi = ["obs/tdi"]

    for c in configs:
        cfg = hdf5io.load_config(args.input_file, name=c)
        for k, v in cfg.items():
            if v is None:
                # None values are stored as the string "None". This check
                # comes first so a None entry is never multiplied by a unit.
                cfg[k] = "None"
            elif c == "obs/config" and k in units_obs:
                cfg[k] = v * u.Unit(units_obs[k])
        hdf5io.save_config(args.output_file, cfg, name=c, mode='a')

    for c, units in zip(cats, units_cats):
        print(c)
        arr, attr = hdf5io.load_array(args.input_file, name=c)
        if c in gb_cats:
            for k, v in units.items():
                attr[k] = v
        else:
            print(attr)
        hdf5io.save_array(args.output_file, arr, name=c, mode='a', **attr)

    for c in tdi:
        arr, attr = hdf5io.load_array(args.input_file, name=c)
        hdf5io.save_array(args.output_file, arr, name=c, mode='a', **attr)

        

