#!/usr/bin/env python

"""
combine output from scrinet_gen_wf_data type files

e.g. first run waveform generation
./scrinet_gen_wf_data --grid random --npts 10 -v --n-cores 4 --output-dir train_wf_data_0
./scrinet_gen_wf_data --grid random --npts 10 -v --n-cores 4 --output-dir train_wf_data_1


then combine them to a single directory

./scrinet_combine_wf_data -v --output-dir train_wf_data --wf-dirs train_wf_data_0 train_wf_data_1

or with wildcards

./scrinet_combine_wf_data -v --output-dir train_wf_data --wf-dirs train_wf_data_*
"""

import numpy as np
import os
import argparse
import h5py
import glob

from scrinet.workflow.pipe_utils import init_logger, load_data

if __name__ == "__main__":
    # Command-line entry point.
    #
    # Reads the x grid, the coordinates and each requested data set from every
    # directory given via --wf-dirs, checks that all directories share the
    # same x grid, stacks coordinates and data row-wise in the order of the
    # directory list, and writes the combined arrays as HDF5 files
    # (coords.h5, times.h5 and one <name>.h5 per data set) into --output-dir.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--data-to-save", type=str, nargs='+',
                        default=['amp', 'phase'],
                        help="list of data to save.",
                        choices=['amp', 'phase', 'freq', 'real', 'imag', 'alpha', 'beta', 'gamma'])

    parser.add_argument("--output-dir", type=str, default='train_wf_data',
                        help="directory to save data")

    # Required: without it args.wf_dirs would be None and the length check
    # further down would fail with an unhelpful TypeError.
    parser.add_argument("--wf-dirs", type=str, nargs='+', required=True,
                        help="list of directories to combine. if one then will glob.")

    parser.add_argument("-v", "--verbose", help="increase output verbosity",
                        action="store_true")

    args = parser.parse_args()

    logger = init_logger()

    if args.verbose:
        logger.info(f"current file: {__file__}")
        logger.info("verbosity turned on")

    # NOTE(review): numpy is already imported at this point, so this setting
    # may come too late for some threading backends - confirm it has the
    # intended effect.
    os.environ['OMP_NUM_THREADS'] = str(1)
    if args.verbose:
        logger.info("OMP_NUM_THREADS: 1")

    # A single entry is treated as a glob pattern (e.g. a quoted wildcard that
    # the shell did not expand).
    if len(args.wf_dirs) == 1:
        pattern = args.wf_dirs[0]
        if args.verbose:
            logger.info("length of --wf-dirs is 1. Using glob")
        # glob.glob returns matches in arbitrary order; sort so that the row
        # order of the combined output is reproducible.
        args.wf_dirs = sorted(glob.glob(pattern))
        if not args.wf_dirs:
            # Fail early with a clear message instead of an IndexError later.
            parser.error(
                f"--wf-dirs pattern matched no directories: {pattern}")

    if args.verbose:
        logger.info(f"wf dirs: {args.wf_dirs}")

    if args.verbose:
        logger.info("making outdir tree")
        logger.info(f"making dir: {args.output_dir}")
    os.makedirs(args.output_dir, exist_ok=True)

    if args.verbose:
        logger.info(f"data to save: {args.data_to_save}")

    # load time and coords
    # Only the x grid and coordinates are needed here, so the (potentially
    # large) data arrays are skipped via return_data=False.
    x_l = []
    coords_l = []
    name0 = args.data_to_save[0]
    for wf_dir in args.wf_dirs:
        x, coords = load_data(name0, wf_dir, return_data=False)
        x_l.append(x)
        coords_l.append(coords)

    # the x grid should be the same
    if args.verbose:
        logger.info("checking x grids are the same")
    ts_x = x_l[0]
    for wf_dir, x in zip(args.wf_dirs[1:], x_l[1:]):
        np.testing.assert_array_equal(
            ts_x, x,
            err_msg=f"x grid in {wf_dir} differs from {args.wf_dirs[0]}")

    # np.vstack replaces the deprecated alias np.row_stack.
    ts_coords = np.vstack(coords_l)

    del x_l
    del coords_l

    if args.verbose:
        logger.info("saving coords")
    filename = os.path.join(args.output_dir, 'coords.h5')
    with h5py.File(filename, "w") as f:
        f.create_dataset("data", data=ts_coords)

    if args.verbose:
        logger.info("saving times")
    filename = os.path.join(args.output_dir, 'times.h5')
    with h5py.File(filename, "w") as f:
        f.create_dataset("times", data=ts_x)

    # Release the references before loading the large data sets.
    del ts_x
    del ts_coords

    # load data one at a time because they can be very large

    for name in args.data_to_save:
        if args.verbose:
            logger.info(f"loading {name}")
        data_l = []
        for wf_dir in args.wf_dirs:
            # Only the middle element (the data array) is needed here.
            _, data, _ = load_data(name, wf_dir)
            data_l.append(data)

        # Same directory order as for the coordinates above, so the rows of
        # <name>.h5 line up with the rows of coords.h5.
        combined = np.vstack(data_l)

        # Drop the per-directory pieces before writing to keep memory down.
        del data_l

        if args.verbose:
            logger.info(f"saving {name}")
        filename = os.path.join(args.output_dir, f'{name}.h5')
        with h5py.File(filename, "w") as f:
            f.create_dataset("data", data=combined)

        del combined

    if args.verbose:
        logger.info("done!")

