from netCDF4 import Dataset, num2date, date2num, date2index
from wwttoolbox.nc.nc_utils import (
    convert_nc_dtype_to_simple_dtype,
    real_datetimes_to_datetimes,
)
from datetime import datetime
import numpy as np
from .nc_mask_utils import fit_mask_to_shape, get_combined_mask


class NCWETTool:
    """Similar to the NCTool class, this one can be used to interact with NetCDF files.
    However, this class is specifically designed to work with the WET files that are generated by the GOTM model.
    This class can be used to query the 'output.nc' file and alter/query the 'restart.nc' file.
    """

    def __init__(self, path: str, mode: str) -> None:
        """Initializes the NCWETTool object with a file path and mode.

        Parameters:
        - path (str): The file path to the netCDF file to be opened.
        - mode (str): The mode in which to open the file. Valid modes are 'r' for read and 'a' for append/alter.

        Raises:
        - ValueError: If an invalid mode is provided.
        """
        self.TIME = "time"
        self.DEPTH = "depth"
        self.LAYER_THICKNESS = "h"

        self.path: str = path
        self.nc: Dataset = None
        self.mode: str = None
        self.masks: dict = None
        self.__set_mode(mode)

    def __set_mode(self, mode: str) -> None:
        """Validates the mode and sets it if valid."""
        if mode == "r" or mode == "a":
            self.mode = mode
        else:
            raise ValueError(f"Invalid mode '{mode}'. Please use 'r' or 'a'.")

    def __enter__(self):
        """Opens the netCDF file at 'self.path' using 'self.mode' and returns this object."""
        dataset = Dataset(self.path, self.mode)
        self.nc = dataset
        return self

    def __exit__(self, *args):
        """Closes the netCDF file."""
        if self.nc:
            self.nc.close()

    ###################################
    # Dimensions
    ###################################

    def get_all_dimensions(self) -> list[str]:
        """Returns a list of all the real dimensions in the netCDF file.

        Returns:
        - list[str]: A list of all the real dimensions in the netCDF file.
        """
        return list(self.nc.dimensions.keys())

    def get_queryable_dimensions(self) -> list[str]:
        """Returns a list of all the queryable dimensions in the netCDF file.

        Following dimensions are queryable:
        - time
        - depth (which is a pseudo dimension)

        Returns:
        - list[str]: A list of all the queryable dimensions in the netCDF file.
        """
        return [self.TIME, self.DEPTH]

    def get_variable_dimensions(self, variable_name: str) -> list[str]:
        """Returns a list of all the real dimensions of the specified variable.

        Parameters:
        - variable_name (str): The variable whose dimensions are to be queried.

        Returns:
        - list[str]: A list of all the real dimensions of the specified variable.
        """
        if variable_name not in self.nc.variables:
            raise ValueError(f"Variable '{variable_name}' does not exist in the file.")

        return list(self.nc.variables[variable_name].dimensions)

    def get_variable_queryable_dimensions(self, variable: str) -> list[str]:
        """Returns a list of all the queryable dimensions of the specified variable.

        See also: get_queryable_dimensions

        Parameters:
        - variable (str): The variable whose dimensions are to be queried.

        Returns:
        - list[str]: A list of all the queryable dimensions of the specified variable.
        """
        dimensions = self.get_variable_dimensions(variable)
        if "zi" in dimensions:
            raise ValueError(
                f"Variables based in the 'zi' dimension are not supported."
            )
        elif "z" in dimensions:
            return [self.TIME, self.DEPTH]
        elif len(dimensions) == 1 and (
            dimensions[0] == "lat" or dimensions[0] == "lon"
        ):
            return []
        else:
            return [self.TIME]

    def get_dimension_length(self, dimension: str) -> int:
        """Returns the length of the specified dimension.

        Parameters:
        - dimension (str): The dimension whose length is to be queried.

        Returns:
        - int: The length of the specified dimension.
        """
        if dimension not in self.nc.dimensions:
            raise ValueError(f"Dimension '{dimension}' does not exist in the file.")

        return len(self.nc.dimensions[dimension])

    ###################################
    # Variables
    ###################################

    def get_all_variables(self) -> list[str]:
        """Returns a list of all variables in the netCDF file.

        Returns:
        - list: A list of variable names in the netCDF file.
        """

        return list(self.nc.variables.keys())

    def get_all_queryable_variables(self) -> list[str]:
        """Returns a list of all queryable variables in the netCDF file.

        A Variable is queryable if it is not one of the following:
        - based on zi dimension

        Returns:
        - list: A list of queryable variable names in the netCDF file.
        """

        queryable_variables = []

        for variable in self.nc.variables:
            dim = self.nc.variables[variable].dimensions
            if "zi" in dim:
                continue
            queryable_variables.append(variable)

        return queryable_variables

    ###################################
    # Meta data
    ###################################

    def get_all_global_attributes(self) -> dict:
        """Returns a dictionary of all the global attributes in the netCDF file.

        Returns:
        - dict: A dictionary of all the global attributes in the netCDF file.
        """
        return self.nc.__dict__

    def get_variable_metadata(self, variable_name: str) -> dict:
        """Returns metadata for a variable in the netCDF file.

        Parameters:
        - variable_name (str): The name of the variable to get metadata for.

        Returns a dict with following keys:
        - variable_name (str): The name of the variable.
        - dimensions (list): A list of dimensions associated with the variable.
        - simple_dtype (str): The simple data type of the variable.
        - nc_dtype (str): The data type of the variable.
        - attributes (dict): A dictionary of attributes associated with the variable.

        Raises:
        - ValueError: If the variable does not exist in the file.
        """

        if variable_name not in self.nc.variables:
            raise ValueError(f"Variable '{variable_name}' does not exist in the file.")

        return {
            "variable_name": variable_name,
            "dimensions": list(self.nc.variables[variable_name].dimensions),
            "simple_dtype": convert_nc_dtype_to_simple_dtype(
                self.nc.variables[variable_name].dtype
            ),
            "nc_dtype": self.nc.variables[variable_name].dtype,
            "attributes": self.nc.variables[variable_name].__dict__,
        }

    ###################################
    # Set masks
    ###################################

    def set_default_masks(self):
        """Sets defaults masks for time and depth dimensions

        Will overwrite any existing masks

        Default masks allows all data to be returned.

        Must be executed before reading data!
        """

        self.masks = {
            self.TIME: slice(None),
            self.DEPTH: np.full(
                (self.get_dimension_length(self.TIME), self.get_dimension_length("z")),
                True,
                dtype=bool,
            ),
        }

    def set_time_mask_from(self, time: datetime):
        """Sets a mask to filter data from a given time.

        If default masks have not been set, they will be set.

        Parameters:
        - time (datetime): The time to filter data from given in UTC.
        """
        if self.masks is None:
            self.set_default_masks()

        time_num = self.datetime_to_num(time)
        variable = self.nc.variables[self.TIME]
        mask = np.ma.masked_greater_equal(variable[:], time_num).mask

        self.masks[self.TIME] = fit_mask_to_shape(mask, variable.shape)

    def set_time_mask_to(self, time: datetime):
        """Sets a mask to filter data to a given time (including specified time).

        If default masks have not been set, they will be set.

        Parameters:
        - time (datetime): The time to filter data to given in UTC.
        """
        if self.masks is None:
            self.set_default_masks()

        time_num = self.datetime_to_num(time)
        variable = self.nc.variables[self.TIME]
        mask = np.ma.masked_less_equal(variable[:], time_num).mask

        self.masks[self.TIME] = fit_mask_to_shape(mask, variable.shape)

    def set_time_mask_between(self, start_time: datetime, end_time: datetime):
        """Sets a mask to filter data between two times.

        If default masks have not been set, they will be set.

        Parameters:
        - start_time (datetime): The start time to filter data from given in UTC.
        - end_time (datetime): The end time to filter data to given in UTC.
        """
        if self.masks is None:
            self.set_default_masks()

        start_time_num = self.datetime_to_num(start_time)
        end_time_num = self.datetime_to_num(end_time)
        variable = self.nc.variables[self.TIME]
        mask = np.ma.masked_inside(variable[:], start_time_num, end_time_num).mask

        self.masks[self.TIME] = fit_mask_to_shape(mask, variable.shape)

    def set_depth_mask_from(self, depth_from: float):
        """Sets a mask to filter data from a given depth.

        If default masks have not been set, they will be set.

        Parameters:
        - depth (float): The depth to filter data from. The depth is in meters and is negative (positive is surface).
        """
        if self.masks is None:
            self.set_default_masks()

        depth_dim = self.get_pseudo_depth_dimension()
        mask = np.ma.masked_less_equal(depth_dim, depth_from).mask

        self.masks[self.DEPTH] = fit_mask_to_shape(mask, depth_dim.shape)

    def set_depth_mask_to(self, depth_to: float):
        """Sets a mask to filter data to a given depth (including specified depth).

        If default masks have not been set, they will be set.

        Parameters:
        - depth (float): The depth to filter data to. The depth is in meters and is negative (positive is surface).
        """
        if self.masks is None:
            self.set_default_masks()

        depth_dim = self.get_pseudo_depth_dimension()
        mask = np.ma.masked_greater_equal(depth_dim, depth_to).mask

        self.masks[self.DEPTH] = fit_mask_to_shape(mask, depth_dim.shape)

    def set_depth_mask_between(self, depth_from: float, depth_to: float):
        """Sets a mask to filter data between two depths.

        If default masks have not been set, they will be set.

        Parameters:
        - depth_from (float): The start depth to filter data from. The depth is in meters and is negative (positive is surface).
        - depth_to (float): The end depth to filter data to. The depth is in meters and is negative (positive is surface).
        """
        if self.masks is None:
            self.set_default_masks()

        depth_dim = self.get_pseudo_depth_dimension()
        mask = np.ma.masked_inside(depth_dim, depth_from, depth_to).mask

        self.masks[self.DEPTH] = fit_mask_to_shape(mask, depth_dim.shape)

    def set_depth_mask_exact(self, depth: float):
        """Sets a mask to filter data to an exact depth.

        If default masks have not been set, they will be set.

        Parameters:
        - depth (float): The depth to filter data to. The depth is in meters and is negative (positive is surface).
        """
        if self.masks is None:
            self.set_default_masks()

        depth_dim = self.get_pseudo_depth_dimension()
        mask = np.ma.masked_equal(depth_dim, depth).mask

        self.masks[self.DEPTH] = fit_mask_to_shape(mask, depth_dim.shape)

    def set_depth_mask_nearest(self, depth: float):
        """Sets the depth mask so that it selects, per time step, the single layer nearest to a given depth.

        If default masks have not been set, they will be set.

        The layer depths are derived from the layer thickness and can differ
        between time steps, so the nearest layer is determined for every time
        step on its own.

        Parameters:
        - depth (float): The depth to select. Depths are in meters relative to the surface (0), so values below the surface are negative.
        """
        if self.masks is None:
            self.set_default_masks()

        depth_dim = self.get_pseudo_depth_dimension()

        # Index of the layer with the smallest absolute distance, one per time step
        nearest_layers = np.argmin(np.abs(depth_dim - depth), axis=1)

        # Pair every time step with its own nearest layer. Indexing with
        # 'mask[:, nearest_layers]' would instead mark every layer that is
        # nearest at any time step for all time steps.
        mask = np.zeros(depth_dim.shape, dtype=bool)
        mask[np.arange(depth_dim.shape[0]), nearest_layers] = True

        self.masks[self.DEPTH] = fit_mask_to_shape(mask, depth_dim.shape)

    ###################################
    # Read data
    ###################################

    def read_time_data(self) -> list[datetime]:
        """Reads the 'time' variable from the netCDF file.

        All masks will be applied to the data and only data that is not masked will be returned.

        Returns:
        - list[datetime]: A list of datetime objects in UTC (no timezone offset).

        Raises:
        - ValueError: If no masks have been set - run set_default_masks() first.

        """
        if self.masks is None:
            raise ValueError(
                "No masks have been set. Please set masks before reading data."
            )

        mask = self.masks[self.TIME]

        time_nums = self.nc.variables[self.TIME][mask]

        attributes = self.get_variable_metadata("time")["attributes"]
        calender = attributes.get("calendar")
        units = attributes.get("units")

        real_datetimes = num2date(
            time_nums,
            units,
            calender,
            only_use_cftime_datetimes=False,
            only_use_python_datetimes=True,
        )

        return real_datetimes_to_datetimes(real_datetimes)

    def read_depth_data(self) -> list[list[float]]:
        """Constructs the pseudo depth dimension based on the layer thickness (variable h) and returns it.

        All masks will be applied to the data and only data that is not masked will be returned.

        The depth is reversed so the first item is the surface.

        Returns:
        - list[list[float]]: The pseudo depth dimension.

        Raises:
        - ValueError: If no masks have been set - run set_default_masks() first.
        """

        if self.masks is None:
            raise ValueError(
                "No masks have been set. Please set masks before reading data."
            )
        depth_data = self.get_pseudo_depth_dimension()

        time_mask = self.masks[self.TIME]
        depth_mask = self.masks[self.DEPTH][self.masks[self.TIME], :]

        depth_time_cutted = depth_data[time_mask, :]

        data_matrix = []
        for item_data, item_depth_mask in zip(depth_time_cutted, depth_mask):
            data_matrix.append(np.flip(item_data[item_depth_mask]).tolist())

        return data_matrix

    def read_variable_data(
        self, variable_name: str
    ) -> float | list[float] | list[list[float]]:
        """Reads the specified variable from the netCDF file.

        All masks will be applied to the data and only data that is not masked will be returned.

        Depending on the dimensionality of the variable, the return type will vary:
        - 0D: float (lat or lon)
        - 1D: list[float] (all variables only based on the queryable 'time' dimensions)
        - 2D: list[list[float]] (all variables based on the queryable 'time' and 'depth' dimensions)

        All 2D data are reversed in the depth dimension to have the surface as the first element - this is the opposite of the netCDF file.


        Parameters:
        - variable_name (str): The name of the variable to be read.

        Returns:
        - float | list[float] | list[list[float]]: The data of the specified variable.

        """
        if self.masks is None:
            raise ValueError(
                "No masks have been set. Please set masks before reading data."
            )

        dimensions = self.get_variable_queryable_dimensions(variable_name)

        if len(dimensions) == 0:
            return self._read_0D_variable_data(variable_name)
        elif len(dimensions) == 1 and dimensions[0] == self.TIME:
            return self._read_1D_variable_data(variable_name)
        elif (
            len(dimensions) == 2
            and dimensions[0] == self.TIME
            and dimensions[1] == self.DEPTH
        ):
            return self._read_2D_variable_data(variable_name)
        else:
            raise ValueError(
                f"Variable '{variable_name}' has unsupported dimensions: {dimensions}"
            )

    def _read_0D_variable_data(self, variable_name: str) -> float:
        """Reads a 0D variable from the netCDF file.

        Parameters:
        - variable_name (str): The name of the variable to be read.

        Returns:
        - float: The data of the specified variable.
        """
        data = self.nc.variables[variable_name][:]

        if len(data) == 1:
            return data.tolist()[0]

        raise ValueError(f"Variable '{variable_name}' is not 0D.")

    def _read_1D_variable_data(self, variable_name: str) -> list[float]:
        """Reads a 1D variable from the netCDF file.

        Parameters:
        - variable_name (str): The name of the variable to be read.

        Returns:
        - list[float]: The data of the specified variable.
        """

        data = self.nc.variables[variable_name][self.masks[self.TIME], 0, 0]

        if len(data.shape) == 1:
            return data.tolist()

        raise ValueError(f"Variable '{variable_name}' is not 1D.")

    def _read_2D_variable_data(self, variable_name: str) -> list[list[float]]:
        """Reads a 2D variable from the netCDF file.

        Reverses the depth dimension to have the surface as the first element.

        Parameters:
        - variable_name (str): The name of the variable to be read.

        Returns:
        - list[list[float]]: The data of the specified variable.
        """
        time_mask = self.masks[self.TIME]
        depth_mask = self.masks[self.DEPTH][self.masks[self.TIME], :]

        data_time_cutted = self.nc.variables[variable_name][time_mask, :, 0, 0]

        if depth_mask.shape != data_time_cutted.shape:
            raise ValueError(f"Variable '{variable_name}' is not 2D.")

        # Loop over the depth dimension, apply the mask and flip the data
        data_matrix = []
        for item_data, item_depth_mask in zip(data_time_cutted, depth_mask):
            data_matrix.append(np.flip(item_data[item_depth_mask]).tolist())

        return data_matrix

    ###################################
    # Pseudo dimensions
    ###################################

    def get_pseudo_depth_dimension(self) -> np.ndarray:
        """Calculates the pseudo depth based on layer thickness (variable h) and with the surface as reference 0.
        The depth is is the layer thickness cumulatively added.

        Returns a 2 level matrix, where the first level is the time dimensions and the second level is the depth dimension.

        The matrix is fixed in size.

        Returns:
        - list[list[float]]: A 2 level matrix with the pseudo depth dimension.

        """
        # Take the layer thickness, accumulate it and find the center of each layer.
        # Then reverse the order and invert the sign
        # With a constant thickness of 1, 2 time units and 2 depths: (the lat/lon dimensions are not shown)
        # h=[[1,1],[1,1]] -> [[1, 2], [1,2]] -> [[-2, -1], [-2, -1]] -> [[-1.5, -0.5], [-1.5, -0.5]]

        layer_thickness = self.nc.variables[self.LAYER_THICKNESS][:, :, 0, 0]

        layer_thickness_acc = np.cumsum(layer_thickness, axis=1)

        depth_crop = (
            layer_thickness_acc
            - layer_thickness_acc[:, -1, np.newaxis]
            - layer_thickness / 2
        )

        return depth_crop

    ###################################
    # Auxiliary functions
    ###################################

    def datetime_to_num(self, time: datetime) -> float:
        attributes = self.get_variable_metadata("time")["attributes"]
        calender = attributes.get("calendar")
        units = attributes.get("units")

        return date2num(time, units, calender)

    ###################################
    # Write data
    ###################################

    def write_1D_variable_data(self, variable_name: str, data: list[float]):
        """Writes data to a 1D variable in the netCDF file.

        Will write to variables that has following real dimensions:
        - time
        - lat
        - lon

        The dimensions lat and lon will automatically be applied.

        The provided data must match the existing shape (after applying lat and lon dimensions).

        Only float data types are supported.

        Parameters:
        - variable_name (str): The name of the variable to write data to.
        - data (list[float]): The data to write to the variable (without the lat and lon dimensions).

        Raises:
        - ValueError: If the variable does not exist in the file.
        - ValueError: If the variable is not 1D.
        - ValueError: If the data does not match the shape of the variable.
        """
        if variable_name not in self.nc.variables:
            raise ValueError(f"Variable '{variable_name}' does not exist in the file.")

        dimensions = self.get_variable_queryable_dimensions(variable_name)

        if dimensions != [self.TIME]:
            raise ValueError(f"Variable '{variable_name}' is not 1D.")

        data_extended = np.array(self.apply_lat_lon_dimensions(data))

        if data_extended.shape != self.nc.variables[variable_name].shape:
            raise ValueError(
                f"Data shape does not match the shape of variable '{variable_name}'. Expected shape {self.nc.variables[variable_name].shape}, got {data_extended.shape}."
            )

        self.nc.variables[variable_name][:] = data_extended

    def write_2D_variable_data(self, variable_name: str, data: list[list[float]]):
        """Writes data to a 2D variable in the netCDF file.

        Written data is not reversed in the depth dimension, hence the first element is the bottom.

        Will write to variables that has following real dimensions:
        - time
        - z
        - lat
        - lon

        The dimensions lat and lon will automatically be applied.

        The provided data must match the existing shape (after applying lat and lon dimensions).

        Only float data types are supported.

        Parameters:
        - variable_name (str): The name of the variable to write data to.
        - data (list[list[float]]): The data to write to the variable (without the lat and lon dimensions).

        Raises:
        - ValueError: If the variable does not exist in the file.
        - ValueError: If the variable is not 2D.
        - ValueError: If the data does not match the shape of the variable.
        """
        if variable_name not in self.nc.variables:
            raise ValueError(f"Variable '{variable_name}' does not exist in the file.")

        dimensions = self.get_variable_queryable_dimensions(variable_name)

        if dimensions != [self.TIME, self.DEPTH]:
            raise ValueError(f"Variable '{variable_name}' is not 2D.")

        data_extended = np.array(self.apply_lat_lon_dimensions(data))

        if data_extended.shape != self.nc.variables[variable_name].shape:
            raise ValueError(
                f"Data shape does not match the shape of variable '{variable_name}'. Expected shape {self.nc.variables[variable_name].shape}, got {data_extended.shape}."
            )

        self.nc.variables[variable_name][:] = data_extended

    def apply_lat_lon_dimensions(
        self, data: list[float] | list[list[float]]
    ) -> list[list[list[float]]] | list[list[list[list[float]]]]:
        """Applies the lat and lon dimensions to the data.

        Takes either a list of floats or a list of lists of floats and applies the lat and lon dimensions.

        Parameters:
        - data (list[float]|list[list[float]]): The data to apply the lat and lon dimensions to.

        Returns:
        - list[list[list[float]]]|list[list[list[list[float]]]: The data with the lat and lon dimensions applied.

        """
        data_array = np.array(data)

        if data_array.ndim == 1:
            return data_array[:, np.newaxis, np.newaxis].tolist()

        elif data_array.ndim == 2:
            return data_array[:, :, np.newaxis, np.newaxis].tolist()
        else:
            raise ValueError(f"Data has unsupported dimensions: {data_array.ndim}")
