import argparse
import logging
import pdb
import sys
from pathlib import Path

import pandas as pd
import pyadcirc.io as pyio
import xarray as xa

# Initialize Logging
logger = logging.getLogger("adcirc_post_process")

def merge_output(
    output_dir: str,
    stations: bool = True,
    globs: bool = False,
    minmax: bool = True,
    nodals: bool = True,
    partmesh: bool = True,
):
    """Merge ADCIRC output files found in ``output_dir`` into one Dataset.

    Parameters
    ----------
    output_dir : str
        Directory holding the ADCIRC output files. Paths are built as
        ``f"{output_dir}/<file>"``.
    stations : bool
        Read ``fort.{61,62,71,72,91}.nc``. The first file seeds the dataset
        and the rest are merged with ``compat="override"``. Afterwards every
        data variable except ``station_name`` is renamed ``<name>-station``.
    globs : bool
        Merge in ``fort.{63,64,73,74,93}.nc``.
    minmax : bool
        Merge in ``{maxele,maxvel,maxwvel,minpr}.63.nc``, each with its
        ``time`` variable dropped first.
    nodals : bool
        Load ``fort.13`` nodal attributes via ``pyadcirc.io.read_fort13``.
    partmesh : bool
        Read ``partmesh.txt`` (headerless CSV) and store it, flattened, as a
        ``partition`` variable along the ``node`` dimension.

    Returns
    -------
    xarray.Dataset
        The merged dataset. Source files are opened with
        ``xarray.open_dataset`` and never explicitly closed, so the result
        may still read lazily from them (e.g. when written out).
    """
    ds = xa.Dataset()

    if stations:
        station_files = [
            f"{output_dir}/fort.{idx}.nc" for idx in (61, 62, 71, 72, 91)
        ]
        for i, sf in enumerate(station_files):
            logger.info("Reading station data %s", sf)
            station_data = xa.open_dataset(sf)
            # The first file seeds the dataset. Later files override
            # conflicting non-index variables rather than raising.
            ds = (
                station_data
                if i == 0
                else xa.merge([ds, station_data], compat="override")
            )

        # Only station data has been loaded so far, so this tags exactly the
        # station variables, keeping them distinct from the global ones
        # merged in below.
        ds = ds.rename(
            {
                name: f"{name}-station"
                for name in ds.data_vars
                if name != "station_name"
            }
        )

    if globs:
        for idx in (63, 64, 73, 74, 93):
            gf = f"{output_dir}/fort.{idx}.nc"
            logger.info("Reading global data %s", gf)
            ds = xa.merge([ds, xa.open_dataset(gf)])

    if minmax:
        # Separate local name so the boolean `minmax` argument is not shadowed.
        minmax_names = ("maxele", "maxvel", "maxwvel", "minpr")
        for name in minmax_names:
            mf = f"{output_dir}/{name}.63.nc"
            logger.info("Reading min/max data %s", mf)
            # `drop_vars` is the non-deprecated spelling of `Dataset.drop`
            # for removing a variable by name.
            minmax_data = xa.open_dataset(mf).drop_vars("time")
            ds = xa.merge([ds, minmax_data])

    if nodals:
        # Load f13 nodal attribute data into the dataset.
        ds = pyio.read_fort13(f"{output_dir}/fort.13", ds)

    if partmesh:
        # Load partition mesh data: one value per node, flattened to 1-D.
        partition = (
            pd.read_csv(f"{output_dir}/partmesh.txt", header=None)
            .to_numpy()
            .reshape(-1)
        )
        ds["partition"] = (["node"], partition)

    return ds


if __name__ == "__main__":
    # Parse command line options. Each BooleanOptionalAction switch also gets
    # a matching --no-<flag> form (requires Python 3.9+).
    parser = argparse.ArgumentParser(
        description="Merge ADCIRC output files into a single netCDF file."
    )
    parser.add_argument(
        "output_dir", type=str, help="Directory containing ADCIRC output."
    )
    parser.add_argument(
        "output_file", type=str, help="Path of the merged netCDF file to write."
    )
    parser.add_argument(
        "--stations", action=argparse.BooleanOptionalAction, default=True
    )
    parser.add_argument("--globs", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--minmax", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--nodals", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument(
        "--partmesh", action=argparse.BooleanOptionalAction, default=True
    )
    args = parser.parse_args()

    logformat = "%(asctime)s | %(name)s | %(levelname)s | %(message)s"
    logging.basicConfig(
        level=logging.DEBUG,
        stream=sys.stdout,
        format=logformat,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    logger.info("Starting to merge output at %s", args.output_dir)
    # Keyword arguments so the flags cannot be mis-wired if the signature's
    # parameter order ever changes.
    output = merge_output(
        args.output_dir,
        stations=args.stations,
        globs=args.globs,
        minmax=args.minmax,
        nodals=args.nodals,
        partmesh=args.partmesh,
    )
    logger.info("Done merging output")
    logger.info("Writing output netcdf file %s", args.output_file)
    output.to_netcdf(args.output_file)
    logger.info("Done writing output netcdf file %s", args.output_file)
