#!/usr/bin/env python

"""
Combine the output of several scrinet_gen_wf_data runs into a single directory.

e.g. first run waveform generation
./scrinet_gen_wf_data --grid random --npts 10 -v --n-cores 4 --output-dir train_wf_data_0
./scrinet_gen_wf_data --grid random --npts 10 -v --n-cores 4 --output-dir train_wf_data_1


then combine them to a single directory

./scrinet_combine_wf_data -v --output-dir train_wf_data --wf-dirs train_wf_data_0 train_wf_data_1

or with wildcards

./scrinet_combine_wf_data -v --output-dir train_wf_data --wf-dirs train_wf_data_*
"""

import numpy as np
import os
import argparse
import h5py
import glob

from scrinet.workflow.pipe_utils import init_logger, load_data

def _build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--data-to-save", type=str, nargs='+',
                        default=['amp', 'phase'],
                        help="list of data to save.",
                        choices=['amp', 'phase', 'freq', 'real', 'imag', 'alpha', 'beta', 'gamma'])

    parser.add_argument("--output-dir", type=str, default='train_wf_data',
                        help="directory to save data")

    # required: without it args.wf_dirs would be None and len() would fail
    parser.add_argument("--wf-dirs", type=str, nargs='+', required=True,
                        help="list of directories to combine. if one then will glob.")

    parser.add_argument("-v", "--verbose", help="increase output verbosity",
                        action="store_true")
    return parser


def _resolve_wf_dirs(wf_dirs, output_dir, logger, verbose):
    """Return the list of input directories to combine.

    Multiple entries are returned as given. A single entry is expanded
    with ``glob``. The matches are sorted so the row order of the combined
    output is reproducible, because ``glob`` returns them in arbitrary
    order. Any match that is the output directory itself is dropped, so
    that a previously combined result is not folded back in.
    """
    if len(wf_dirs) != 1:
        return list(wf_dirs)

    if verbose:
        logger.info("length of --wf-dirs is 1. Using glob")
    output_real = os.path.realpath(output_dir)
    resolved = []
    for path in sorted(glob.glob(wf_dirs[0])):
        if os.path.realpath(path) == output_real:
            logger.warning(f"skipping {path}: it is the output directory")
            continue
        resolved.append(path)
    return resolved


def _load_all(names, wf_dirs):
    """Load every dataset in ``names`` from every directory in ``wf_dirs``.

    The x grid and coords are collected only from the ``load_data`` calls
    for the first name. Those returned for the other names are ignored.

    Returns
    -------
    x_grids : list
        One x grid per directory, in ``wf_dirs`` order.
    coords_list : list
        One coords array per directory, in ``wf_dirs`` order.
    data_by_name : dict
        Maps each name to a list of per-directory data arrays.
    """
    first, *rest = names
    x_grids, coords_list, first_data = [], [], []
    for wf_dir in wf_dirs:
        x, data, coords = load_data(first, wf_dir)
        x_grids.append(x)
        coords_list.append(coords)
        first_data.append(data)

    data_by_name = {first: first_data}
    for name in rest:
        data_by_name[name] = [load_data(name, wf_dir)[1] for wf_dir in wf_dirs]
    return x_grids, coords_list, data_by_name


def _save_h5(filename, key, data):
    """Write ``data`` as dataset ``key`` in a new HDF5 file (overwrites)."""
    with h5py.File(filename, "w") as f:
        f.create_dataset(key, data=data)


def main():
    """Combine per-run waveform data directories into ``--output-dir``.

    Writes ``coords.h5`` (dataset "data"), ``times.h5`` (dataset "times")
    and one ``<name>.h5`` (dataset "data") per entry of ``--data-to-save``.
    The per-directory arrays are stacked row-wise, in directory order.
    """
    parser = _build_parser()
    args = parser.parse_args()

    logger = init_logger()

    if args.verbose:
        logger.info(f"current file: {__file__}")
        logger.info("verbosity turned on")

    # NOTE(review): numpy is imported before this runs, so any OpenMP/BLAS
    # runtime it already initialised may ignore this. Confirm it has the
    # intended effect.
    os.environ['OMP_NUM_THREADS'] = str(1)
    if args.verbose:
        logger.info("OMP_NUM_THREADS: 1")

    args.wf_dirs = _resolve_wf_dirs(
        args.wf_dirs, args.output_dir, logger, args.verbose)
    if not args.wf_dirs:
        parser.error("--wf-dirs did not match any directory to combine")

    if args.verbose:
        logger.info(f"wf dirs: {args.wf_dirs}")

    if args.verbose:
        logger.info("making outdir tree")
        logger.info(f"making dir: {args.output_dir}")
    os.makedirs(args.output_dir, exist_ok=True)

    if args.verbose:
        logger.info(f"data to save: {args.data_to_save}")

    x_grids, coords_list, data_by_name = _load_all(
        args.data_to_save, args.wf_dirs)

    # every directory must share the same x grid; only the first is saved
    if args.verbose:
        logger.info("checking x grids are the same")
    for wf_dir, x in zip(args.wf_dirs[1:], x_grids[1:]):
        np.testing.assert_array_equal(
            x_grids[0], x,
            err_msg=f"x grid of {wf_dir} differs from that of {args.wf_dirs[0]}")
    ts_x = x_grids[0]

    # stack everything before writing, so a shape mismatch aborts
    # before any output file is touched (np.row_stack is deprecated)
    ts_arrays = {name: np.vstack(data_by_name[name])
                 for name in args.data_to_save}
    ts_coords = np.vstack(coords_list)

    if args.verbose:
        logger.info("saving coords")
    _save_h5(os.path.join(args.output_dir, 'coords.h5'), "data", ts_coords)

    if args.verbose:
        logger.info("saving times")
    _save_h5(os.path.join(args.output_dir, 'times.h5'), "times", ts_x)

    for name in args.data_to_save:
        if args.verbose:
            logger.info(f"saving {name}")
        _save_h5(os.path.join(args.output_dir, f'{name}.h5'), "data",
                 ts_arrays[name])


if __name__ == "__main__":
    main()