"""Utilities for selecting evaluation grids and aligning simulation data."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, cast

import numpy as np
from numpy.typing import NDArray
from scipy.interpolate import interp1d

# Closed set of policy names accepted by ``EvaluationGrid.policy``.
GridPolicy = Literal["exp_to_sim", "fixed_user"]
# Float64 NumPy array alias used in the annotations of this module.
Array = NDArray[np.float64]


@dataclass
class EvaluationGrid:
    """Choose the abscissa on which residuals are evaluated.

    Supported policies:

    * ``"exp_to_sim"`` -- keep the experimental sampling points as the grid,
      so simulated curves are compared at the experimental abscissa.
    * ``"fixed_user"`` -- use the caller-supplied ``common_grid``, which lets
      several experiments share one abscissa.
    """

    policy: GridPolicy
    # Grid used by the "fixed_user" policy; required for that policy only.
    common_grid: Array | None = None

    def select_grid(self, x_exp: Array, _x_sim: Array) -> Array:
        """Return the evaluation grid dictated by ``policy``.

        Args:
            x_exp: Experimental sampling points.
            _x_sim: Simulation sampling points; not used by either policy.

        Returns:
            The selected grid converted to a float array.

        Raises:
            ValueError: If ``policy`` is ``"fixed_user"`` while ``common_grid``
                is ``None``, or if ``policy`` is not a recognized name.
        """
        if self.policy not in ("exp_to_sim", "fixed_user"):
            raise ValueError(f"Unknown grid policy: {self.policy}")

        if self.policy == "fixed_user":
            if self.common_grid is None:
                raise ValueError("common_grid must be provided for fixed_user policy")
            source = self.common_grid
        else:
            source = x_exp

        return cast(Array, np.asarray(source, dtype=float))


class Aligner:
    """Thin wrapper around :func:`scipy.interpolate.interp1d` for 1-D data.

    Source samples are linearly interpolated onto a target abscissa, with
    linear extrapolation outside the source range.
    """

    def map(self, x_src: Array, y_src: Array, x_tgt: Array) -> Array:
        """Project the ``(x_src, y_src)`` samples onto ``x_tgt``.

        All inputs are flattened to 1-D before use; the source samples do not
        need to be sorted.

        Args:
            x_src: Source abscissa.
            y_src: Source ordinates; must have as many elements as ``x_src``.
            x_tgt: Target abscissa.

        Returns:
            A 1-D float array with one value per element of ``x_tgt``.
            Multi-dimensional targets are flattened and not reshaped back.
            An empty source yields all zeros; a single source sample yields
            that sample's value at every target point.

        Raises:
            ValueError: If ``x_src`` and ``y_src`` differ in length.
        """
        x_src_arr: Array = np.asarray(x_src, dtype=float).reshape(-1)
        y_src_arr: Array = np.asarray(y_src, dtype=float).reshape(-1)
        x_tgt_arr: Array = np.asarray(x_tgt, dtype=float).reshape(-1)
        # Validate lengths before any shortcut so mismatched inputs are
        # reported even when ``x_src`` is empty.
        if x_src_arr.shape != y_src_arr.shape:
            raise ValueError(
                f"x and y must have the same length; got {x_src_arr.shape} and {y_src_arr.shape}"
            )
        if x_src_arr.size == 0:
            # Nothing to interpolate from: fall back to zeros.
            return np.zeros_like(x_tgt_arr)
        if x_src_arr.size == 1:
            # interp1d needs at least two points; hold a lone sample constant.
            return np.full_like(x_tgt_arr, y_src_arr[0])
        # A stable sort keeps tied abscissae in input order, so results for
        # duplicate x values are deterministic.
        order = np.argsort(x_src_arr, kind="stable")
        interpolator = interp1d(
            x_src_arr[order],
            y_src_arr[order],
            kind="linear",
            fill_value="extrapolate",
            bounds_error=False,
            assume_sorted=True,
        )
        return np.asarray(interpolator(x_tgt_arr), dtype=float)
