from abc import ABC, abstractmethod

import numpy as np
import skfmm
from numpy.typing import NDArray

from machwave.models.propulsion.grain import GrainGeometryError, GrainSegment
from machwave.services.decorators import validate_assertions


class FMMGrainSegment(GrainSegment, ABC):
    """
    Fast Marching Method (FMM) implementation of a grain segment.

    The grain cross-section is rasterized onto a grid with ``map_dim`` cells
    spanning the outer diameter. Subclasses supply the coordinate maps, the
    initial burning face and the mask of cells to exclude; this base class
    then uses scikit-fmm to compute, for every cell, the regression distance
    from the initial face.

    This class was inspired by Andrew Reilley's openMotor software, in
    particular its fmm module.
    openMotor's repository can be accessed at:
    https://github.com/reilleya/openMotor
    """

    # Smallest map resolution accepted by ``validate``. Subclasses may
    # override it if they need a finer minimum grid.
    MIN_MAP_DIM: int = 100

    def __init__(
        self,
        map_dim: int,
        length: float,
        outer_diameter: float,
        spacing: float,
        inhibited_ends: int = 0,
    ) -> None:
        """
        Args:
            map_dim: Number of grid cells spanning the outer diameter.
            length: Segment length, forwarded to ``GrainSegment``.
            outer_diameter: Segment outer diameter, forwarded to
                ``GrainSegment``; also used to convert map units.
            spacing: Forwarded to ``GrainSegment``.
            inhibited_ends: Forwarded to ``GrainSegment``.
        """
        # Assigned before super().__init__(), presumably because the parent
        # constructor may call validate(), which reads map_dim — confirm.
        self.map_dim: int = map_dim

        # Lazily computed "cache" variables. masked_face and regression_map
        # are filled by the getters below; maps and mask are not read by this
        # base class (presumably used by subclass implementations — confirm).
        self.maps: tuple | None = None
        self.mask: np.ndarray | None = None
        self.masked_face: np.ma.MaskedArray | None = None
        self.regression_map: np.ndarray | None = None

        super().__init__(
            length=length,
            outer_diameter=outer_diameter,
            spacing=spacing,
            inhibited_ends=inhibited_ends,
        )

    @abstractmethod
    def get_initial_face_map(self) -> NDArray[np.int_]:
        """
        Return the initial face map of the grain.

        Must be implemented by every concrete geometry. The result is combined
        with ``get_mask()`` and fed to the fast marching solver.
        """
        pass

    @abstractmethod
    def get_maps(self) -> tuple:
        """
        Return the coordinate maps of the grid.

        Returns:
            - 2D: (map_x, map_y)
            - 3D: (map_x, map_y, map_z)
        """
        pass

    @abstractmethod
    def get_mask(self) -> np.ndarray:
        """
        Return the mask of cells excluded from the computation.

        Truthy entries are masked out. The implementation varies depending on
        whether the geometry is 2D or 3D.
        """
        pass

    @validate_assertions(exception=GrainGeometryError)
    def validate(self) -> None:
        """
        Validate the internal geometry of the grain.

        Runs the parent validation, then ensures the grain map dimension meets
        the minimum required size (``MIN_MAP_DIM``). If validation fails, a
        GrainGeometryError is raised.

        Raises:
            GrainGeometryError: If the grain map dimension is below the valid
                threshold.
        """
        super().validate()
        # NOTE(review): ``assert`` statements are stripped under ``python -O``,
        # which would silently disable this check.
        assert self.map_dim >= self.MIN_MAP_DIM

    def normalize(self, value: int | float) -> float:
        """
        Convert a dimensional value into units of the outer radius.

        Args:
            value: The dimensional value (e.g., a length) to normalize.

        Returns:
            ``value`` expressed as a fraction of half the outer diameter.
        """
        return value / (0.5 * self.outer_diameter)

    def denormalize(self, value: int | float) -> float:
        """
        Convert a value in units of the outer radius back into a dimension.

        This is the inverse of ``normalize``.

        Args:
            value: A value expressed as a fraction of half the outer diameter.

        Returns:
            ``value`` multiplied by half the outer diameter.
        """
        return (value / 2) * self.outer_diameter

    def map_to_area(self, value: float) -> float:
        """
        Convert a pixel-area value into square meters.

        The conversion is based on the ratio of this object's outer diameter
        (squared) to the map dimension (squared).

        Args:
            value: The area in pixel units.

        Returns:
            The corresponding area in square meters.
        """
        return (self.outer_diameter**2) * (value / (self.map_dim**2))

    def map_to_length(self, value: float) -> float:
        """
        Convert a pixel-distance value into meters.

        The conversion is based on the ratio of this object's outer diameter
        to the map dimension.

        Args:
            value: The distance in pixel units.

        Returns:
            The corresponding distance in meters.
        """
        return self.outer_diameter * (value / self.map_dim)

    def get_empty_face_map(self) -> np.ndarray:
        """
        Return a new face map consisting entirely of ones.

        The array has the same shape and dtype as the first coordinate map
        returned by ``get_maps()``.
        """
        return np.ones_like(self.get_maps()[0])

    def get_masked_face(self) -> np.ma.MaskedArray:
        """
        Return the initial face map masked by ``get_mask()``.

        The result is computed on first call and cached in
        ``self.masked_face``.
        """
        if self.masked_face is None:
            self.masked_face = np.ma.MaskedArray(
                self.get_initial_face_map(), mask=self.get_mask()
            )
        return self.masked_face

    def get_cell_size(self) -> float:
        """
        Return the size of one grid cell, with the full map spanning 1.0.

        The value is 1 divided by the map dimension.
        """
        return 1 / self.map_dim

    def get_regression_map(self) -> np.ndarray:
        """
        Return the regression distance map, computing and caching it once.

        Uses the fast marching method (``skfmm.distance``) on the masked face.
        Each value is the distance from the initial face, expressed in units
        of the outer radius (the same units produced by ``normalize``).
        """
        if self.regression_map is None:
            # With dx = 1 / map_dim, distances come out as a fraction of the
            # full map width (the outer diameter, per map_to_length). The
            # factor 2 rescales them to fractions of the outer radius so they
            # are comparable with normalize()d web distances.
            self.regression_map = (
                skfmm.distance(self.get_masked_face(), dx=self.get_cell_size()) * 2
            )
        return self.regression_map

    def get_web_thickness(self) -> float:
        """
        Return the web thickness of the grain segment in real units.

        This is the largest regression distance from the initial burning face
        found in the regression map, converted from outer-radius units back
        into a dimension via ``denormalize``.
        """
        return self.denormalize(np.amax(self.get_regression_map()))

    @abstractmethod
    def get_contours(
        self, web_distance: float, *args, **kwargs
    ) -> list[NDArray[np.float64]]:
        """
        Return the contours of the regression map after a given web distance.

        Must be implemented by a subclass to compute the contour data based on
        the given web distance and any additional parameters.

        Args:
            web_distance: The depth of regression into the grain web.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            A list of arrays, one per contour of the regressed grain face.
        """
        pass

    def get_face_map(self, web_distance: float) -> NDArray[np.int64]:
        """
        Return a matrix representing the grain face at the given web distance.

        The returned array contains:
        -1 for masked or invalid points,
        0 for points whose regression distance is <= the web distance,
        1 for points whose regression distance is > the web distance.

        Args:
            web_distance: The distance traveled into the grain web.

        Returns:
            A NumPy array with -1, 0, or 1 indicating the grain face at the
            specified web distance.
        """
        web_distance_normalized = self.normalize(web_distance)
        regression_map = self.get_regression_map()

        # 1 where the cell's regression distance exceeds the web distance
        # (not yet reached by the burning front), 0 otherwise.
        above_threshold = (regression_map > web_distance_normalized).astype(np.int64)

        # Mask out excluded cells. If the regression map is itself a masked
        # array, its mask is kept and merged with this one (keep_mask default).
        excluded = np.asarray(self.get_mask(), dtype=bool)
        face = np.ma.MaskedArray(above_threshold, mask=excluded)

        # Masked entries become -1; valid entries remain 1 or 0.
        return face.filled(-1)
