"""
Utility functions for command-line operations
"""
# TODO: clean up so usage of these bits and pieces is clearer
import datetime as dt
import logging
import os
import os.path
import shutil
import sys
from time import gmtime, strftime

import numpy as np
import pymagicc
import scmdata.units
from pymagicc.io import MAGICCData
from scmdata import df_append

from . import __version__
from .io import load_scmdataframe
from .iris_cube_wrappers import (
    CMIP6Input4MIPsCube,
    CMIP6OutputCube,
    MarbleCMIP5Cube,
    SCMCube,
)

# Module-level logger which the helpers below write their messages to
logger = logging.getLogger("netcdf_scm")

# TODO: use openscm_units
_ureg = scmdata.units.ScmUnitRegistry()
"""
Unit registry for miscellaneous unit checking
"""
# NOTE(review): presumably loads scmdata's standard unit definitions into the
# registry — confirm against scmdata's documentation
_ureg.add_standards()

# Maps the ``drs`` names accepted by the helpers below to the cube class which
# knows how to load data and parse paths for that data reference syntax
_CUBES = {
    "Scm": SCMCube,
    "MarbleCMIP5": MarbleCMIP5Cube,
    "CMIP6Input4MIPs": CMIP6Input4MIPsCube,
    "CMIP6Output": CMIP6OutputCube,
}

# Each value is ``(variable name written into the output file, short name used in
# the .IN output filename)``
_MAGICC_VARIABLE_MAP = {"tas": ("Surface Temperature", "SURFACE_TEMP")}
"""Mapping from CMOR variable names to MAGICC variables"""


def _init_logging(logger, params, out_filename=None):
    """
    Attach handlers to ``logger`` and record the run parameters

    Messages at INFO level and above are sent to stderr. If ``out_filename`` is
    given, a file handler (opened in append mode) additionally records all messages
    at DEBUG level and above. Python warnings are routed through logging.

    # TODO: make level of logging customisable

    Parameters
    ----------
    logger : :obj:`logging.Logger`
        Logger to configure. Handlers are added to it and its level is set to DEBUG.

    params : iterable of (key, value) pairs
        Values to record at the start of the log, one INFO message per pair

    out_filename : str
        Path of the log file to append to. If falsy, no file handler is added.
    """
    formatter = logging.Formatter(
        "{process} {asctime} {levelname}:{name}:{message}", style="{"
    )

    new_handlers = []
    if out_filename:
        file_handler = logging.FileHandler(out_filename, "a")
        file_handler.setLevel(logging.DEBUG)
        new_handlers.append(file_handler)

    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.INFO)
    new_handlers.append(stderr_handler)

    for handler in new_handlers:
        # only fill in our formatter where the handler does not already have one
        if handler.formatter is None:
            handler.setFormatter(formatter)
        logger.addHandler(handler)

    # DEBUG is the default until the logging level is customisable
    logger.setLevel(logging.DEBUG)

    logging.captureWarnings(True)
    logger.info("netcdf-scm: %s", __version__)
    for key, value in params:
        logger.info("%s: %s", key, value)


def _get_timestamp():
    return strftime("%Y%m%d %H%M%S", gmtime())


def _make_path_if_not_exists(path_to_check):
    if not os.path.exists(path_to_check):
        logger.info("Making output directory: %s", path_to_check)
        os.makedirs(path_to_check)


def _get_scmcube_helper(drs):
    if drs == "None":
        raise NotImplementedError(
            "`drs` == 'None' is not supported yet. Please raise an issue at "
            "gitlab.com/znicholls/netcdf-scm/ with your use case if you need this "
            "feature."
        )

    return _CUBES[drs]()


def _load_scm_cube(drs, dirpath, filenames):
    """
    Load the data in ``dirpath`` into a cube for ``drs``

    A single file is loaded directly from its path; any other number of files is
    handed to the cube's directory loader.

    Parameters
    ----------
    drs : str
        Data reference syntax name (see ``_get_scmcube_helper``)

    dirpath : str
        Directory containing the data

    filenames : list of str
        Names of the files in ``dirpath``

    Returns
    -------
    The loaded cube
    """
    cube = _get_scmcube_helper(drs)
    if len(filenames) != 1:
        cube.load_data_in_directory(dirpath)
    else:
        cube.load_data_from_path(os.path.join(dirpath, filenames[0]))

    return cube


def _set_crunch_contact_in_results(res, crunch_contact):
    for _, c in res.items():
        if "crunch_contact" in c.cube.attributes:
            logger.warning(
                "Overwriting `crunch_contact` attribute"
            )  # pragma: no cover # emergency valve
        c.cube.attributes["crunch_contact"] = crunch_contact

    return res


def _get_outfile_dir_flat_dir(dpath, drs, dst):
    """
    Get, creating them if needed, the output directory and flat duplicate directory

    The output directory is ``dpath`` with its root directory (as parsed for
    ``drs``) swapped for ``dst``. The duplicate directory is ``dst/flat``.

    Parameters
    ----------
    dpath : str
        Source data directory

    drs : str
        Data reference syntax used to parse ``dpath``

    dst : str
        Root of the output tree

    Returns
    -------
    tuple of str
        ``(outfile_dir, duplicate_dir)``
    """
    scmcube = _get_scmcube_helper(drs)
    root_dir = scmcube.process_path(dpath)["root_dir"]
    # Replace only the first occurrence. The root directory is expected to be a
    # prefix of ``dpath``, and a plain ``replace`` would also rewrite any later
    # repeat of the same string further down the path.
    outfile_dir = dpath.replace(root_dir, dst, 1)
    _make_path_if_not_exists(outfile_dir)

    duplicate_dir = os.path.join(dst, "flat")
    _make_path_if_not_exists(duplicate_dir)

    return outfile_dir, duplicate_dir


def _convert_units(openscmdf, target_units_specs):
    for variable in openscmdf["variable"].unique():
        if variable in target_units_specs["variable"].tolist():
            target_unit = target_units_specs[
                target_units_specs["variable"] == variable
            ]["unit"].values[0]
            current_unit = openscmdf.filter(variable=variable)["unit"].values[0]

            logger.info(
                "Converting units of %s from %s to %s",
                variable,
                current_unit,
                target_unit,
            )

            target_length = _ureg(target_unit).dimensionality["[length]"]
            current_length = _ureg(current_unit).dimensionality["[length]"]
            if np.equal(target_length - current_length, 2):
                openscmdf = _take_area_sum(openscmdf, current_unit)

            openscmdf = openscmdf.convert_unit(target_unit, variable=variable)

    return openscmdf


def _find_dirs_meeting_func(src, check_func):
    """
    Walk ``src`` and collect the directories with files which pass ``check_func``

    Directories without files are ignored. Errors raised by ``check_func`` are
    logged rather than propagated.

    Parameters
    ----------
    src : str
        Root directory to walk

    check_func : callable
        Called with each directory path; a truthy return value means the directory
        is kept

    Returns
    -------
    tuple
        ``(matching_dirs, failures)``: a list of ``(dirpath, filenames)`` tuples and
        a bool which is ``True`` if ``check_func`` raised for any directory
    """
    found = []
    any_failed = False
    logger.info("Finding directories with files")

    for current_dir, _, files_here in os.walk(src):
        logger.debug("Entering %s", current_dir)
        if not files_here:
            continue

        try:
            is_match = check_func(current_dir)
        except Exception as exc:  # pylint:disable=broad-except
            logger.error(
                "Directory checking failed on %s with error %s", current_dir, exc
            )
            any_failed = True
            continue

        if is_match:
            found.append((current_dir, files_here))

    logger.info("Found %s directories with files", len(found))
    return found, any_failed


def _skip_file(out_file, force, duplicate_dir):
    if not force and os.path.isfile(out_file):
        logger.info("Skipped (already exists, not overwriting) %s", out_file)
        return True

    if os.path.isfile(out_file):
        os.remove(out_file)

    duplicate_file = os.path.join(duplicate_dir, os.path.basename(out_file))
    if os.path.isfile(duplicate_file):
        os.remove(duplicate_file)

    return False


def _get_path_bits(inpath, drs):
    """
    Parse the directory containing ``inpath`` using the cube helper for ``drs``

    Parameters
    ----------
    inpath : str
        Path to a file; only its directory is parsed

    drs : str
        Data reference syntax used to parse the path

    Returns
    -------
    The output of the helper's ``process_path``
    """
    cube_helper = _get_scmcube_helper(drs)
    containing_dir = os.path.dirname(inpath)
    return cube_helper.process_path(containing_dir)


def _get_id_in_path(path_id, fullpath, drs):
    """
    Get a single item from the information parsed from ``fullpath``'s directory

    Delegates the parsing to ``_get_path_bits`` so that the logic lives in one
    place.

    Parameters
    ----------
    path_id : str
        Key to look up in the parsed path information

    fullpath : str
        Path to a file; only its directory is parsed

    drs : str
        Data reference syntax used to parse the path

    Returns
    -------
    The value stored under ``path_id``
    """
    return _get_path_bits(fullpath, drs)[path_id]


def _get_timestamp_str(fullpath, drs):
    """
    Get the timestamp string embedded in the filename of ``fullpath``

    Parameters
    ----------
    fullpath : str
        Path to a file

    drs : str
        Data reference syntax used to parse the filename

    Returns
    -------
    The ``timestamp_str`` item parsed from the filename
    """
    # TODO: rename to _get_timestamp_str_from_path
    # TODO: change to _get_id_in_path("timestamp_str", fullpath, drs)
    # NOTE(review): ``_get_id_in_path`` parses the *directory* whereas this parses
    # the *filename*; confirm the directory parse provides ``timestamp_str`` before
    # making the change above
    cube_helper = _get_scmcube_helper(drs)
    filename = os.path.basename(fullpath)
    timestamp_bits = cube_helper._get_timestamp_bits_from_filename(  # pylint:disable=protected-access
        filename
    )
    return timestamp_bits["timestamp_str"]


def _get_meta(inscmdf, meta_col, expected_unique=True):
    # TODO: remove and instead use scmdataframe.get_unique_meta(raise=True) or whatever it is
    vals = inscmdf[meta_col].unique()
    if expected_unique:
        if len(vals) != 1:
            raise AssertionError("{} is not unique: {}".format(meta_col, vals))
        return vals[0]

    return vals


def _write_ascii_file(  # pylint:disable=too-many-arguments
    openscmdf,
    metadata,
    header,
    outfile_dir,
    duplicate_dir,
    fnames,
    force,
    out_format,
    drs,
    prefix=None,
):
    if out_format in ("mag-files",):
        _write_mag_file(
            openscmdf,
            metadata,
            header,
            outfile_dir,
            duplicate_dir,
            fnames,
            force,
            prefix,
            drs,
        )
    elif out_format in (
        "mag-files-average-year-start-year",
        "mag-files-average-year-mid-year",
        "mag-files-average-year-end-year",
        "mag-files-point-start-year",
        "mag-files-point-mid-year",
        "mag-files-point-end-year",
    ):
        _write_mag_file_with_operation(
            openscmdf,
            metadata,
            header,
            outfile_dir,
            duplicate_dir,
            fnames,
            force,
            out_format,
            prefix,
            drs,
        )
    elif out_format in ("magicc-input-files",):
        _write_magicc_input_file(
            openscmdf,
            metadata,
            header,
            outfile_dir,
            duplicate_dir,
            fnames,
            force,
            prefix,
        )
    elif out_format in (
        "magicc-input-files-average-year-start-year",
        "magicc-input-files-average-year-mid-year",
        "magicc-input-files-average-year-end-year",
        "magicc-input-files-point-start-year",
        "magicc-input-files-point-mid-year",
        "magicc-input-files-point-end-year",
    ):
        _write_magicc_input_file_with_operation(
            openscmdf,
            metadata,
            header,
            outfile_dir,
            duplicate_dir,
            fnames,
            force,
            out_format,
            prefix,
        )
    else:
        raise AssertionError("how did we get here?")  # pragma: no cover


def _write_mag_file(  # pylint:disable=too-many-arguments,too-many-locals
    openscmdf, metadata, header, outfile_dir, duplicate_dir, fnames, force, prefix, drs
):
    """
    Write ``openscmdf`` to a monthly ``.MAG`` file and copy it to ``duplicate_dir``

    The output filename is ``fnames[0]`` with its timestamp (as parsed for ``drs``)
    replaced by the ``YYYYMM-YYYYMM`` span of the data. It is optionally prefixed
    with ``prefix``, and its extension is replaced by ``.MAG``. Nothing is written
    if ``_skip_file`` reports that the file should be skipped.

    Parameters
    ----------
    openscmdf
        Data to write
    metadata : dict
        Metadata for the file. It is copied before ``timeseriestype`` and
        ``header`` are added, so the passed dict is not modified.
    header : str
        Header to write at the top of the file
    outfile_dir : str
        Directory in which to write the file
    duplicate_dir : str
        Directory in which to place a copy of the written file
    fnames : list of str
        Source filenames; the first is used to build the output filename
    force : bool
        Overwrite existing files
    prefix : str or None
        Prefix to prepend to the output filename
    drs : str
        Data reference syntax used to parse the timestamp in ``fnames[0]``

    Raises
    ------
    ValueError
        The timesteps of the data are not monthly
    """
    ts = openscmdf.timeseries()

    src_time_points = ts.columns

    time_id = "{:04d}{:02d}-{:04d}{:02d}".format(
        src_time_points[0].year,
        src_time_points[0].month,
        src_time_points[-1].year,
        src_time_points[-1].month,
    )
    old_time_id = _get_timestamp_str(fnames[0], drs)

    out_file_base = fnames[0].replace(old_time_id, time_id)
    if prefix is not None:
        out_file_base = "{}_{}".format(prefix, out_file_base)

    out_file = os.path.join(outfile_dir, out_file_base)
    out_file = "{}.MAG".format(os.path.splitext(out_file)[0])

    if _skip_file(out_file, force, duplicate_dir):
        return

    writer = MAGICCData(openscmdf)
    writer["todo"] = "SET"
    _check_timesteps_are_monthly(writer)

    # Copy so that the caller's dict (often ``openscmdf.metadata`` itself) is not
    # modified by the keys added below
    writer.metadata = dict(metadata)
    writer.metadata["timeseriestype"] = "MONTHLY"
    writer.metadata["header"] = header

    logger.info("Writing file to %s", out_file)
    writer.write(out_file, magicc_version=7)

    duplicate_file = os.path.join(duplicate_dir, os.path.basename(out_file))
    logger.info("Duplicating file as %s", duplicate_file)
    shutil.copyfile(out_file, duplicate_file)


def _check_timesteps_are_monthly(scmdf):
    time_steps = scmdf["time"][1:].values - scmdf["time"][:-1].values
    step_upper = np.timedelta64(32, "D")  # pylint:disable=too-many-function-args
    step_lower = np.timedelta64(28, "D")  # pylint:disable=too-many-function-args
    if any((time_steps > step_upper) | (time_steps < step_lower)):
        raise ValueError(
            "Please raise an issue at "
            "gitlab.com/znicholls/netcdf-scm/issues "
            "to discuss how to handle non-monthly data"
        )


def _write_mag_file_with_operation(  # pylint:disable=too-many-arguments
    openscmdf,
    metadata,
    header,
    outfile_dir,
    duplicate_dir,
    fnames,
    force,
    out_format,
    prefix,
    drs,
):  # pylint:disable=too-many-locals
    if len(fnames) > 1:
        raise AssertionError(
            "more than one file to wrangle?"
        )  # pragma: no cover # emergency valve

    ts = openscmdf.timeseries()

    src_time_points = ts.columns
    original_years = ts.columns.map(lambda x: x.year).unique()

    time_id = "{:04d}-{:04d}".format(src_time_points[0].year, src_time_points[-1].year)
    old_time_id = _get_timestamp_str(fnames[0], drs)

    out_file_base = fnames[0].replace(old_time_id, time_id)
    if prefix is not None:
        out_file_base = "{}_{}".format(prefix, out_file_base)

    out_file = os.path.join(outfile_dir, out_file_base)
    out_file = "{}.MAG".format(os.path.splitext(out_file)[0])

    if _skip_file(out_file, force, duplicate_dir):
        return

    writer = MAGICCData(_do_timeseriestype_operation(openscmdf, out_format)).filter(
        year=original_years
    )

    writer["todo"] = "SET"
    writer.metadata = metadata
    writer.metadata["timeseriestype"] = (
        out_format.replace("mag-files-", "").replace("-", "_").upper()
    )

    writer.metadata["header"] = header

    logger.info("Writing file to %s", out_file)
    writer.write(out_file, magicc_version=7)

    duplicate_file = os.path.join(duplicate_dir, os.path.basename(out_file))
    logger.info("Duplicating file as %s", duplicate_file)
    shutil.copyfile(out_file, duplicate_file)


def _do_timeseriestype_operation(openscmdf, out_format):
    if out_format.endswith("average-year-start-year"):
        out = openscmdf.time_mean("AS")

    elif out_format.endswith("average-year-mid-year"):
        out = openscmdf.time_mean("AC")

    elif out_format.endswith("average-year-end-year"):
        out = openscmdf.time_mean("A")

    elif out_format.endswith("point-start-year"):
        out = openscmdf.resample("AS")

    elif out_format.endswith("point-mid-year"):
        out_time_points = [
            dt.datetime(y, 7, 1)
            for y in range(
                openscmdf["time"].min().year, openscmdf["time"].max().year + 1
            )
        ]
        out = openscmdf.interpolate(target_times=out_time_points)

    elif out_format.endswith("point-end-year"):
        out = openscmdf.resample("A")

    else:  # pragma: no cover # emergency valve
        raise NotImplementedError("Do not recognise out_format: {}".format(out_format))

    if out.timeseries().shape[1] == 1:
        error_msg = "We cannot yet write `{}` if the output data will have only one timestep".format(
            out_format
        )
        raise ValueError(error_msg)

    return out


def _write_magicc_input_file(  # pylint:disable=too-many-arguments
    openscmdf, metadata, header, outfile_dir, duplicate_dir, fnames, force, prefix
):
    if len(fnames) > 1:
        raise AssertionError(
            "more than one file to wrangle?"
        )  # pragma: no cover # emergency valve

    _write_magicc_input_files(
        openscmdf,
        outfile_dir,
        duplicate_dir,
        force,
        metadata,
        header,
        "MONTHLY",
        prefix,
    )


def _write_magicc_input_file_with_operation(  # pylint:disable=too-many-arguments
    openscmdf,
    metadata,
    header,
    outfile_dir,
    duplicate_dir,
    fnames,
    force,
    out_format,
    prefix,
):
    if len(fnames) > 1:
        raise AssertionError(
            "more than one file to wrangle?"
        )  # pragma: no cover # emergency valve

    ts = openscmdf.timeseries()

    original_years = ts.columns.map(lambda x: x.year).unique()

    openscmdf = _do_timeseriestype_operation(openscmdf, out_format).filter(
        year=original_years
    )

    _write_magicc_input_files(
        openscmdf,
        outfile_dir,
        duplicate_dir,
        force,
        metadata,
        header,
        out_format.replace("magicc-input-files-", "").replace("-", "_").upper(),
        prefix,
    )


def _write_magicc_input_files(  # pylint:disable=too-many-arguments,too-many-locals
    openscmdf,
    outfile_dir,
    duplicate_dir,
    force,
    metadata,
    header,
    timeseriestype,
    prefix,
):
    """
    Write ``openscmdf`` to MAGICC ``.IN`` files, one per region set

    Two files are written:

    - ``FOURBOX``: the hemispheric land and ocean regions
    - ``GLOBAL``: ``World`` only

    Each file is also copied into ``duplicate_dir``. If ``_skip_file`` says a
    file should be skipped, the remaining region sets are still written.

    Parameters
    ----------
    openscmdf
        Data to write (its first variable determines the MAGICC variable)
    outfile_dir : str
        Directory in which to write the files
    duplicate_dir : str
        Directory in which to place copies of the written files
    force : bool
        Overwrite existing files
    metadata : dict
        Metadata for the files. It is copied before ``header`` and
        ``timeseriestype`` are added, so the passed dict is not modified.
    header : str
        Header to write at the top of the files
    timeseriestype : str
        MAGICC timeseries type to record in the files
    prefix : str or None
        Prefix to prepend to the output filenames

    Raises
    ------
    KeyError
        The variable has no entry in ``_MAGICC_VARIABLE_MAP``
    """
    # looked up outside the ``try`` so that the error message below can always
    # refer to it
    var_to_write = openscmdf["variable"].unique()[0]
    try:
        magicc_names = _MAGICC_VARIABLE_MAP[var_to_write]
    except KeyError as exc:
        raise KeyError(
            "I don't know which MAGICC variable to use for input `{}`".format(
                var_to_write
            )
        ) from exc
    magicc_name = magicc_names[0]
    magicc_internal_name = magicc_names[1]

    # these identifiers are the same for every output file
    scenario = openscmdf["scenario"].unique()[0]
    climate_model = openscmdf["climate_model"].unique()[0]
    member_id = openscmdf["member_id"].unique()[0]

    region_filters = {
        "FOURBOX": [
            "World|Northern Hemisphere|Land",
            "World|Southern Hemisphere|Land",
            "World|Northern Hemisphere|Ocean",
            "World|Southern Hemisphere|Ocean",
        ],
        "GLOBAL": ["World"],
    }
    for region_key, regions_to_keep in region_filters.items():
        out_file_base = "{}_{}_{}_{}_{}_{}.IN".format(
            var_to_write,
            scenario,
            climate_model,
            member_id,
            region_key,
            magicc_internal_name,
        ).upper()
        if prefix is not None:
            out_file_base = "{}_{}".format(prefix, out_file_base)

        out_file = os.path.join(outfile_dir, out_file_base)

        if _skip_file(out_file, force, duplicate_dir):
            # skipping one region set must not stop the others being written
            continue

        writer = MAGICCData(openscmdf).filter(region=regions_to_keep)
        writer["todo"] = "SET"
        writer["variable"] = magicc_name
        # copy so that the caller's metadata dict is not modified
        writer.metadata = dict(metadata)
        writer.metadata["header"] = header
        writer.metadata["timeseriestype"] = timeseriestype

        logger.info("Writing file to %s", out_file)
        writer.write(out_file, magicc_version=7)

        duplicate_file = os.path.join(duplicate_dir, os.path.basename(out_file))
        logger.info("Duplicating file as %s", duplicate_file)
        shutil.copyfile(out_file, duplicate_file)


def _get_openscmdf_metadata_header(
    fnames, dpath, target_units_specs, wrangle_contact, out_format
):
    if len(fnames) > 1:
        raise AssertionError(
            "more than one file to wrangle?"
        )  # pragma: no cover # emergency valve

    openscmdf = df_append([load_scmdataframe(os.path.join(dpath, f)) for f in fnames])
    if openscmdf.timeseries().shape[1] == 1:
        error_msg = "We cannot yet write `{}` if the output data has only one timestep".format(
            out_format
        )
        raise ValueError(error_msg)

    if target_units_specs is not None:
        openscmdf = _convert_units(openscmdf, target_units_specs)

    metadata = openscmdf.metadata
    header = _get_openscmdf_header(
        wrangle_contact, metadata["crunch_netcdf_scm_version"]
    )

    return openscmdf, metadata, header


def _get_openscmdf_header(contact, netcdf_scm_version):
    """
    Build the header written at the top of output files

    The header records the current local date and time, ``contact``, the
    NetCDF-SCM version used for crunching and the pymagicc version used for
    writing.

    Parameters
    ----------
    contact : str
        Contact information

    netcdf_scm_version : str
        Version of NetCDF-SCM used to crunch the source data

    Returns
    -------
    str
        Newline-terminated header text
    """
    template = (
        "Date: {}\n"
        "Contact: {}\n"
        "Source data crunched with: NetCDF-SCM v{}\n"
        "File written with: pymagicc v{} (more info at "
        "github.com/openclimatedata/pymagicc)\n"
    )
    written_at = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return template.format(
        written_at, contact, netcdf_scm_version, pymagicc.__version__
    )


def _take_area_sum(openscmdf, current_unit):
    """
    Multiply each region's timeseries by that region's area

    For each region, the area is read from the first ``openscmdf.metadata`` key
    that contains ``"<area key> ("``, where ``<area key>`` comes from
    ``SCMCube._convert_region_to_area_key``. The area unit is the text between
    the last ``(`` and the following ``)`` of that key. Values are multiplied by
    the area magnitude, and the ``unit`` column becomes ``current_unit`` times
    the area unit.

    Regions without area information are left out of the result. A warning is
    logged for each one, so they are not dropped silently.

    Parameters
    ----------
    openscmdf
        Data to convert

    current_unit : str
        Current unit of the data

    Returns
    -------
    The converted data (as returned by ``df_append``) with
    ``openscmdf.metadata`` attached
    """
    converted_ts = []

    for region, df in openscmdf.timeseries().groupby("region"):
        rkey = SCMCube._convert_region_to_area_key(  # pylint:disable=protected-access
            region
        )
        area_key_start = "{} (".format(rkey)
        for k, v in openscmdf.metadata.items():
            if area_key_start in k:
                unit = k.split("(")[-1].split(")")[0]
                conv_factor = v * _ureg(unit)

                converted_region = (df * v).reset_index()
                converted_region["unit"] = str(
                    (1 * _ureg(current_unit) * conv_factor).units
                )
                converted_ts.append(converted_region)
                break
        else:
            # no area information: the region is excluded from the output
            logger.warning(
                "No area information found for region %s (no metadata key "
                "containing '%s'), it will be dropped from the area sum",
                region,
                area_key_start,
            )

    converted_ts = df_append(converted_ts)
    converted_ts.metadata = openscmdf.metadata
    return converted_ts
