"""Parameter reparameterisation utilities for optimisation workflows."""

from __future__ import annotations

import logging
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import cast

import numpy as np
from numpy.typing import NDArray

# Module logger; Reparameterizer uses it to warn when non-positive θ values are
# clamped before inverting the exponential map.
logger = logging.getLogger(__name__)

# Float64 NumPy array used for θ/φ vectors and dense bound vectors.
Array = NDArray[np.float64]
# Boolean NumPy array used for the per-parameter "vary" mask.
BoolArray = NDArray[np.bool_]
# Return type of the ``bounds()`` helpers: dense (lower, upper) arrays, a
# caller-supplied sequence of (low, high) pairs, or None when nothing is bounded.
BoundsPayload = tuple[Array, Array] | Sequence[tuple[float, float]] | None


@dataclass(frozen=True)
class Parameter:
    """Scalar optimisation parameter metadata.

    Attributes:
        name: Non-empty identifier for the parameter.
        theta0: Finite initial/reference value in θ-space.
        vary: Whether the optimiser may change this parameter.
        bounds: ``(low, high)`` θ-space bounds; ``None`` on either side means
            unbounded on that side. ``bounds=None`` is accepted and normalised
            to ``(None, None)``; any other iterable is normalised to a tuple so
            the (frozen) instance stays hashable.

    Raises:
        ValueError: If ``name`` is empty, ``theta0`` is not finite, ``bounds``
            does not have exactly two entries, a bound is NaN, or
            ``low > high``.
    """

    name: str
    theta0: float
    vary: bool = True
    bounds: tuple[float | None, float | None] = (None, None)

    def __post_init__(self) -> None:
        if not self.name:
            raise ValueError("Parameter name may not be empty.")
        if not math.isfinite(self.theta0):
            raise ValueError("theta0 must be finite.")

        bounds = (None, None) if self.bounds is None else tuple(self.bounds)
        if len(bounds) != 2:
            raise ValueError("bounds must be a (low, high) tuple.")
        low, high = bounds
        if any(b is not None and math.isnan(b) for b in bounds):
            raise ValueError(f"bounds for parameter '{self.name}' may not be NaN.")
        # Inverted bounds would make clamping silently collapse every value to `high`.
        if low is not None and high is not None and low > high:
            raise ValueError(
                f"Lower bound {low!r} exceeds upper bound {high!r} "
                f"for parameter '{self.name}'."
            )
        # The dataclass is frozen, so bypass its __setattr__ to store the normalised tuple.
        object.__setattr__(self, "bounds", bounds)


class ParameterSpace:
    """Mapping between φ (optimiser space) and θ (physical parameters).

    Each parameter θ_i is related to its optimisation counterpart φ_i via::

        θ_i = θ0_i * ξ**φ_i

    The exponential reparameterisation keeps θ positive (for positive θ0_i) while
    allowing unconstrained optimisation in φ-space; φ_i = 0 corresponds to θ_i = θ0_i.
    Parameters can be supplied either through the legacy constructor arguments or
    incrementally via :meth:`add_parameter`.

    Dense caches (``_names``, ``_theta0_vec``, ``_vary_vec``, ``_th_lo``, ``_th_hi``)
    are rebuilt whenever a parameter is registered. ``_th_lo`` / ``_th_hi`` are
    ``None`` when no parameter declares a bound on that side; otherwise missing
    bounds are filled with ``-inf`` / ``+inf``.
    """

    def __init__(
        self,
        names: Sequence[str] | None = None,
        theta0: Mapping[str, float] | None = None,
        *,
        xi: float = 10.0,
        vary: Mapping[str, bool] | None = None,
        theta_bounds: Mapping[str, tuple[float | None, float | None]] | None = None,
        parameters: Iterable[Parameter] | None = None,
    ):
        """Create a parameter space using legacy args or explicit Parameter objects.

        Args:
            names: Ordered parameter names for the legacy constructor path.
            theta0: Initial θ values keyed by name; required together with ``names``.
            xi: Base ξ of the exponential map; must be finite, > 0 and != 1.
            vary: Optional per-name vary flags (missing names default to True).
            theta_bounds: Optional per-name θ-space ``(low, high)`` bounds.
            parameters: Parameter specs, registered before the legacy ones.

        Raises:
            ValueError: For an invalid ``xi``, ``names``/``theta0`` supplied without
                the other, a missing ``theta0`` entry, a duplicate name, or
                malformed bounds.
        """
        if xi <= 0.0:
            raise ValueError("xi must be > 0.")
        if xi == 1.0:
            raise ValueError("xi == 1 makes dθ/dφ = 0.")
        if not math.isfinite(xi):
            # NaN slips past both comparisons above, and +inf would make ln(ξ) infinite.
            raise ValueError("xi must be finite.")

        self.xi = float(xi)
        self._ln_xi = math.log(self.xi)

        self._parameters: list[Parameter] = []
        self._theta0_map: dict[str, float] = {}
        self._vary_map: dict[str, bool] = {}
        self._bounds_map: dict[str, tuple[float | None, float | None]] = {}

        # Explicit Parameter objects are registered first, then the legacy ones.
        if parameters is not None:
            for spec in parameters:
                self._append_parameter(spec, rebuild=False)

        if names is not None or theta0 is not None:
            if names is None or theta0 is None:
                raise ValueError("names and theta0 must be provided together.")
            for name in names:
                if name not in theta0:
                    raise ValueError(f"Missing theta0 entry for parameter '{name}'.")
                spec = Parameter(
                    name=name,
                    theta0=float(theta0[name]),
                    vary=True if vary is None else bool(vary.get(name, True)),
                    bounds=(
                        theta_bounds.get(name, (None, None))
                        if theta_bounds is not None
                        else (None, None)
                    ),
                )
                self._append_parameter(spec, rebuild=False)

        # Build the dense caches once, after all constructor-time registrations.
        self._rebuild_cache()

    # ---------- parameter management ----------
    def add_parameter(
        self,
        parameter: Parameter | None = None,
        *,
        name: str | None = None,
        theta0: float | None = None,
        vary: bool = True,
        bounds: tuple[float | None, float | None] | None = None,
    ) -> Parameter:
        """Register a new optimisation parameter.

        Parameters can be supplied either as a :class:`Parameter` instance or through
        the keyword arguments ``name``/``theta0``/``vary``/``bounds``. When a
        ``parameter`` instance is given, the keyword arguments are ignored.

        Returns:
            The registered Parameter instance.

        Raises:
            ValueError: If neither a Parameter nor a name/theta0 pair is given, the
                name is already registered, or the bounds are malformed.
            TypeError: If ``parameter`` is not a Parameter instance.
        """
        if parameter is None:
            if name is None or theta0 is None:
                raise ValueError("Provide a Parameter instance or name/theta0 pair.")
            parameter = Parameter(
                name=name,
                theta0=float(theta0),
                vary=bool(vary),
                bounds=bounds if bounds is not None else (None, None),
            )
        elif not isinstance(parameter, Parameter):
            raise TypeError("parameter must be a Parameter instance.")

        self._append_parameter(parameter, rebuild=True)
        return parameter

    def parameters(self) -> list[Parameter]:
        """Return a copy of the registered parameter specifications.

        Returns:
            List of Parameter definitions in registration order.
        """
        return list(self._parameters)

    # ---------- masks and packing ----------
    def active_mask(self) -> BoolArray:
        """Return a boolean mask describing which parameters vary.

        Returns:
            Boolean array aligned with ``names`` (a copy; safe to mutate).
        """
        return self._vary_vec.copy()

    def pack_dict(self, d: Mapping[str, float]) -> Array:
        """Pack a parameter dictionary into a vector ordered by ``names``.

        Returns:
            θ vector ordered to match ``names``.

        Raises:
            KeyError: If ``d`` lacks an entry for a registered name.
        """
        return np.asarray([float(d[k]) for k in self._names], dtype=float)

    def unpack_vec(self, v: Sequence[float]) -> dict[str, float]:
        """Convert a vector into a parameter dictionary.

        Returns:
            Mapping from parameter name to θ value.

        Raises:
            ValueError: If ``v`` does not have one entry per registered name.
        """
        return {
            k: float(x) for k, x in zip(self._names, np.asarray(v, dtype=float), strict=True)
        }

    # ---------- compatibility helpers ----------
    @property
    def names(self) -> list[str]:
        """Names of all registered parameters."""
        return list(self._names)

    @property
    def theta0(self) -> dict[str, float]:
        """Initial θ values keyed by name."""
        return dict(self._theta0_map)

    @property
    def theta_bounds(self) -> dict[str, tuple[float | None, float | None]]:
        """Return θ-space bounds keyed by parameter name."""
        return dict(self._bounds_map)

    @property
    def vary(self) -> dict[str, bool]:
        """Flags indicating whether each parameter varies."""
        return dict(self._vary_map)

    # ---------- internal helpers ----------
    def _append_parameter(self, spec: Parameter, rebuild: bool) -> None:
        """Validate *spec* and register it, optionally refreshing the dense caches.

        All validation runs before any internal state is touched, so a rejected
        spec leaves the space exactly as it was (no half-registered name).
        """
        name = spec.name
        if name in self._theta0_map:
            raise ValueError(f"Parameter '{name}' already registered.")

        bounds = (None, None) if spec.bounds is None else tuple(spec.bounds)
        if len(bounds) != 2:
            raise ValueError("bounds must be a (low, high) tuple.")
        low, high = bounds
        # Convert up front so the float() calls in _rebuild_cache cannot fail later.
        low_f = None if low is None else float(low)
        high_f = None if high is None else float(high)
        if low_f is not None and high_f is not None and low_f > high_f:
            raise ValueError(
                f"Lower bound {low_f!r} exceeds upper bound {high_f!r} "
                f"for parameter '{name}'."
            )

        self._parameters.append(spec)
        self._theta0_map[name] = float(spec.theta0)
        self._vary_map[name] = bool(spec.vary)
        self._bounds_map[name] = bounds  # type: ignore[assignment]

        if rebuild:
            self._rebuild_cache()

    def _rebuild_cache(self) -> None:
        """Recompute the dense name/θ0/vary/bounds vectors from ``_parameters``."""
        if not self._parameters:
            self._names = []
            self._theta0_vec = cast(Array, np.asarray([], dtype=float))
            self._vary_vec = cast(BoolArray, np.asarray([], dtype=bool))
            self._th_lo = None
            self._th_hi = None
            return

        self._names = [p.name for p in self._parameters]
        self._theta0_vec = cast(
            Array,
            np.asarray([float(p.theta0) for p in self._parameters], dtype=float),
        )
        self._vary_vec = cast(
            BoolArray,
            np.asarray([bool(p.vary) for p in self._parameters], dtype=bool),
        )

        lo_vals = []
        hi_vals = []
        any_lo = False
        any_hi = False
        for p in self._parameters:
            lo, hi = p.bounds if p.bounds is not None else (None, None)
            # Missing bounds become ±inf so the dense vectors stay aligned with names.
            lo_vals.append(-np.inf if lo is None else float(lo))
            hi_vals.append(+np.inf if hi is None else float(hi))
            any_lo = any_lo or (lo is not None)
            any_hi = any_hi or (hi is not None)

        # Use None (not an all-inf vector) when no parameter is bounded on that side.
        self._th_lo = cast(Array, np.asarray(lo_vals, dtype=float)) if any_lo else None
        self._th_hi = cast(Array, np.asarray(hi_vals, dtype=float)) if any_hi else None


class Reparameterizer:
    """Encapsulate θ↔φ transformations and bounds handling.

    When ``enabled`` is true the optimiser works in φ-space with
    ``θ = θ0 * ξ**φ``; otherwise φ and θ coincide and only clamping applies.

    θ0, the θ-space bounds, ln(ξ) and the parameter names are read from the
    wrapped :class:`ParameterSpace` on every access rather than snapshotted at
    construction. Parameters added to the space later are therefore seen
    consistently by every method.
    """

    def __init__(self, space: ParameterSpace, enabled: bool) -> None:
        """Store the parameter space and whether the exponential map is active.

        Args:
            space: Parameter space providing θ0, bounds and ξ.
            enabled: If True, optimiser vectors are φ (θ = θ0·ξ**φ); otherwise
                they are θ directly.
        """
        self._space = space
        self._enabled = bool(enabled)

    # ---------- live views of the space's dense caches ----------
    @property
    def _th_lo(self) -> Array | None:
        return self._space._th_lo

    @property
    def _th_hi(self) -> Array | None:
        return self._space._th_hi

    @property
    def _theta0_vec(self) -> Array:
        return self._space._theta0_vec

    @property
    def _ln_xi(self) -> float:
        return self._space._ln_xi

    @property
    def _names(self) -> list[str]:
        return self._space._names

    # ---------- public API ----------
    @property
    def enabled(self) -> bool:
        """Whether reparameterisation is active."""
        return self._enabled

    @property
    def names(self) -> Sequence[str]:
        """Parameter names (a fresh list)."""
        return list(self._space.names)

    def initial_phi(self) -> Array:
        """Return starting φ vector respecting reparameterisation.

        Returns:
            Zeros (θ = θ0) when enabled, otherwise a copy of the θ0 vector.
        """
        if self._enabled:
            return cast(Array, np.zeros(len(self._names), dtype=float))
        return cast(Array, np.array(self._theta0_vec, dtype=float, copy=True))

    def phi_to_theta(self, phi_vec: Array) -> Array:
        """Convert φ vector to θ, clamping to bounds.

        Returns:
            θ vector after applying bounds.
        """
        phi_arr = cast(Array, np.asarray(phi_vec, dtype=float))
        theta_vec = self.theta_from_phi(phi_arr) if self._enabled else phi_arr
        return cast(Array, np.asarray(self.clamp_theta(theta_vec), dtype=float))

    def bounds(self) -> BoundsPayload:
        """Return bounds appropriate for the current parameterisation."""
        return self.phi_bounds() if self._enabled else self.theta_bounds_array()

    def theta_bounds_array(self) -> tuple[Array, Array] | None:
        """Return θ-space bounds as dense arrays.

        Returns:
            Copies of (lower, upper) with ±inf for missing sides, or ``None``
            when no parameter is bounded.
        """
        if self._th_lo is None and self._th_hi is None:
            return None
        size = len(self._names)
        if self._th_lo is not None:
            lo = cast(Array, np.asarray(self._th_lo, dtype=float))
        else:
            lo = cast(Array, np.full(size, -np.inf, dtype=float))
        if self._th_hi is not None:
            hi = cast(Array, np.asarray(self._th_hi, dtype=float))
        else:
            hi = cast(Array, np.full(size, +np.inf, dtype=float))
        return lo.copy(), hi.copy()

    def phi_bounds(self) -> tuple[Array, Array] | None:
        """Transform θ-space bounds into φ-space bounds.

        Only finite θ bounds are inverted; unbounded sides stay at ±inf. This
        avoids pushing ``-inf`` placeholders through the positivity clamp, which
        would log spurious warnings.

        Returns:
            Tuple of lower/upper φ bounds or ``None`` when unbounded.

        Raises:
            ValueError: If a finite bound cannot be inverted (e.g. θ0 <= 0).
        """
        th_lo, th_hi = self._th_lo, self._th_hi
        if th_lo is None and th_hi is None:
            return None
        size = len(self._names)
        lo = cast(Array, np.full(size, -np.inf, dtype=float))
        hi = cast(Array, np.full(size, +np.inf, dtype=float))
        if th_lo is not None:
            self._fill_finite_phi(lo, th_lo)
        if th_hi is not None:
            self._fill_finite_phi(hi, th_hi)
        return lo, hi

    def phi_from_theta(self, theta_vec: Array) -> Array:
        """Map θ values back into φ-space, tolerating zero bounds via eps nudging.

        Non-positive θ entries are clamped to the smallest positive float (with a
        warning) before inversion. The mapping is applied regardless of
        ``enabled``.

        Returns:
            φ vector corresponding to supplied θ.

        Raises:
            ValueError: If θ/θ0 is non-positive for any entry (e.g. θ0 <= 0).
        """
        return self._invert(theta_vec, self._theta0_vec)

    def dtheta_dphi(self, phi_vec: Array) -> Array:
        """Return ∂θ/∂φ for the provided φ vector.

        This is the derivative of :meth:`theta_from_phi`, i.e. ``θ · ln ξ``. It
        ignores both ``enabled`` and the bounds clamp.

        Returns:
            θ-space gradient for each φ component.
        """
        theta = self.theta_from_phi(phi_vec)
        return theta * self._ln_xi

    def theta_from_phi(self, phi_vec: Array) -> Array:
        """Map φ values to θ-space with the raw exponential map (no clamping).

        Returns:
            θ vector produced from φ.
        """
        phi_vec = np.asarray(phi_vec, dtype=float)
        result: Array = self._theta0_vec * np.power(self._space.xi, phi_vec)
        return result

    def clamp_theta(self, theta_vec: Array) -> Array:
        """Clamp θ values according to stored bounds.

        Returns:
            Bounded θ vector (the input itself when no bounds are defined).
        """
        if self._th_lo is None and self._th_hi is None:
            return theta_vec
        out = np.asarray(theta_vec, dtype=float).copy()
        if self._th_lo is not None:
            out = np.maximum(out, self._th_lo)
        if self._th_hi is not None:
            out = np.minimum(out, self._th_hi)
        return out

    # ---------- internal helpers ----------
    def _fill_finite_phi(self, out: Array, theta_bound: Array) -> None:
        """Write φ-images of the finite entries of *theta_bound* into *out* in place."""
        mask = np.isfinite(theta_bound)
        if mask.any():
            out[mask] = self._invert(theta_bound[mask], self._theta0_vec[mask])

    def _invert(self, theta_vec: Array, theta0_vec: Array) -> Array:
        """Shared φ = log_ξ(θ/θ0) kernel used by phi_from_theta and phi_bounds."""
        theta_arr = self._ensure_positive_theta(theta_vec)
        ratio = theta_arr / theta0_vec
        if np.any(ratio <= 0.0):
            raise ValueError("theta must be > 0 elementwise to invert mapping.")
        return cast(Array, np.log(ratio) / self._ln_xi)

    def _ensure_positive_theta(self, theta_vec: Array) -> Array:
        """Return a float copy of θ with non-positive entries raised to ``tiny``."""
        theta_arr = np.asarray(theta_vec, dtype=float).copy()
        mask = theta_arr <= 0.0
        if mask.any():
            eps = np.finfo(theta_arr.dtype).tiny
            logger.warning(
                "theta values %s contain non-positive entries; clamping to %.3e",
                theta_arr[mask],
                eps,
            )
            theta_arr[mask] = eps
        return theta_arr


class ParameterMapper:
    """Engine-facing facade over a :class:`Reparameterizer`.

    Owns one Reparameterizer built from the supplied space and forwards φ/θ
    conversions, initial vectors and bounds to it.
    """

    def __init__(self, space: ParameterSpace, use_reparam: bool) -> None:
        """Wrap *space*; *use_reparam* toggles the exponential φ-map."""
        self._space = space
        self._reparam = Reparameterizer(space, use_reparam)

    @property
    def names(self) -> Sequence[str]:
        """Parameter names, ordered as the optimiser sees them."""
        reparam = self._reparam
        return reparam.names

    @property
    def reparam_enabled(self) -> bool:
        """True when the optimiser works in φ-space."""
        reparam = self._reparam
        return reparam.enabled

    def initial_phi(self, phi0: Sequence[float] | None) -> Array:
        """Starting φ vector: the caller's *phi0* when given, else the default."""
        if phi0 is None:
            return self._reparam.initial_phi()
        override = np.asarray(phi0, dtype=float)
        return cast(Array, override)

    def bounds(self, bounds: Sequence[tuple[float, float]] | None) -> BoundsPayload:
        """Bounds payload; explicit *bounds* from the caller take precedence."""
        return self._reparam.bounds() if bounds is None else bounds

    def phi_to_theta(self, phi_vec: Array) -> Array:
        """Convert φ to a clamped θ vector.

        Returns:
            θ vector for the supplied φ values.
        """
        reparam = self._reparam
        return reparam.phi_to_theta(phi_vec)

    def theta_dict(self, theta_vec: Array) -> dict[str, float]:
        """Turn a θ vector into a ``{name: value}`` mapping.

        Returns:
            Parameter names mapped to their θ values.
        """
        flat = np.asarray(theta_vec, dtype=float).tolist()
        return self._space.unpack_vec(flat)

    def theta_from_phi(self, phi_vec: Array) -> Array:
        """Same as :meth:`phi_to_theta`; kept for call-site readability.

        Returns:
            Clamped θ vector for the supplied φ values.
        """
        reparam = self._reparam
        return reparam.phi_to_theta(phi_vec)

    def clamp_theta(self, theta_vec: Array) -> Array:
        """Clip θ to the space's bounds.

        Returns:
            θ vector restricted to the configured bounds.
        """
        reparam = self._reparam
        return reparam.clamp_theta(theta_vec)


# Public names exported by ``from <module> import *``.
__all__ = ["BoundsPayload", "Parameter", "ParameterMapper", "ParameterSpace", "Reparameterizer"]
