"""Sliceable result views for XPLT data."""

from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Self

import numpy as np

from pyfebiopt.mesh.mesh import Mesh

from .enums import FEDataDim, FEDataType, Storage_Fmt
from .types import Index, MultBlock, MultLike

_VEC3_ORDER = ("x", "y", "z")
_MAT3FD_ORDER = ("xx", "yy", "zz")
_MAT3FS_ORDER = ("xx", "yy", "zz", "xy", "yz", "xz")
_MAT3F_ORDER = ("xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz")
_VOIGT6 = {"xx": 0, "yy": 1, "zz": 2, "yz": 3, "xz": 4, "xy": 5}


class _FieldMeta:
    """Lightweight metadata for a result variable."""

    __slots__ = ("dtype", "fmt", "name", "ncomp")

    def __init__(self, name: str, fmt: Storage_Fmt, dtype: FEDataType):
        self.name = name
        self.fmt = fmt
        self.dtype = dtype
        self.ncomp = int(FEDataDim[dtype.name].value)

    def __repr__(self) -> str:
        return (
            f"_FieldMeta(name={self.name!r}, fmt={self.fmt.name}, "
            f"dtype={self.dtype.name}, ncomp={self.ncomp})"
        )

    __str__ = __repr__


def _normalize_comp_token(s: str) -> str:
    return s.strip().lower().replace(" ", "")


def _comp_names_for_dtype(dtype: FEDataType) -> tuple[str, ...] | None:
    if dtype == FEDataType.VEC3F:
        return _VEC3_ORDER
    if dtype == FEDataType.MAT3FD:
        return _MAT3FD_ORDER
    if dtype == FEDataType.MAT3FS:
        return _MAT3FS_ORDER
    if dtype == FEDataType.MAT3F:
        return _MAT3F_ORDER
    return None


def _tens4fs_pair_index(p: str) -> int:
    """Map a pair token such as ``"xxyy"`` to its TENS4FS storage slot.

    The token is normalized, split into two 2-character halves, and each half
    is looked up in ``_VOIGT6``. The pair is symmetric: the smaller position is
    the row and the larger the column. The slot is then ``row + col*(col+1)//2``,
    which is an upper-triangular, column-by-column packing.

    Raises:
        KeyError: Either half is not a key of ``_VOIGT6``.
    """
    p = _normalize_comp_token(p)
    first, second = p[:2], p[2:]
    if not (first in _VOIGT6 and second in _VOIGT6):
        raise KeyError(f"invalid pair '{p}'")
    row, col = sorted((_VOIGT6[first], _VOIGT6[second]))
    return row + (col * (col + 1)) // 2


def _comp_spec_to_index(meta: _FieldMeta, spec: Index | str | Iterable[str]) -> Index:
    if spec is None:
        return None
    if isinstance(spec, (int, slice, np.ndarray)):
        return spec
    if isinstance(spec, str):
        s = _normalize_comp_token(spec)
        if s == ":":
            return slice(None)
        names = _comp_names_for_dtype(meta.dtype)
        if names is not None:
            try:
                return names.index(s)
            except ValueError as e:
                raise KeyError(
                    f"component '{spec}' not valid for {meta.dtype.name}. Valid: {names}"
                ) from e
        if meta.dtype == FEDataType.TENS4FS:
            return _tens4fs_pair_index(s)
        raise KeyError(f"component strings not supported for {meta.dtype.name}")
    lst = list(spec)
    out: list[int] = []
    for item in lst:
        if not isinstance(item, str):
            raise TypeError("component list must be strings")
        idx_item = _comp_spec_to_index(meta, item)
        if isinstance(idx_item, int):
            out.append(idx_item)
        else:
            raise TypeError("component list does not accept slices")
    return np.asarray(out, dtype=np.int64)


def _as_index_list(sel: Index, T: int) -> list[int]:
    if sel is None:
        return list(range(T))
    if isinstance(sel, int):
        return [sel]
    if isinstance(sel, slice):
        return list(range(T))[sel]
    if isinstance(sel, np.ndarray) and sel.dtype == bool:
        return list(np.nonzero(sel)[0].tolist())
    return list(sel)


def _shape_from_slice(n: int, s: Index) -> int:
    if s is None:
        return n
    if isinstance(s, int):
        return 1
    if isinstance(s, slice):
        start, stop, step = s.indices(n)
        return max(0, (stop - start + (step - 1)) // step)
    return len(list(s))


class _BaseView:
    """Common behavior for result views."""

    __slots__ = ("_comp_idx", "_t_idx", "_times", "meta")

    def __init__(self, meta: _FieldMeta, times: np.ndarray):
        self.meta = meta
        self._times = np.asarray(times, float)
        self._comp_idx: Index = None
        self._t_idx: Index = None

    def time(self, idx: Index | str = None) -> Self:
        """Select time indices to apply on next evaluation.

        Args:
            idx: Integer, slice, list/array, boolean mask, or ``":"``.

        Returns:
            Self: View for chaining.

        Example:
            ``view.time(0).comp(":")`` selects the first state.
        """
        if idx == ":":
            idx = slice(None)
        self._t_idx = idx  # type: ignore[assignment]
        return self

    def comp(self, idx: Index | str | Iterable[str]) -> np.ndarray:
        """Select components and immediately evaluate the array.

        Args:
            idx: Integer/slice, iterable, or component name(s) like ``"x"``.

        Returns:
            np.ndarray: Result array with selections applied.
        """
        self._comp_idx = _comp_spec_to_index(self.meta, idx)
        return self.eval()

    def dims(self) -> tuple[str, ...]:
        raise NotImplementedError

    def eval(self) -> np.ndarray:
        raise NotImplementedError

    def __len__(self) -> int:
        return int(self._times.shape[0])


class NodeResultView(_BaseView):
    """Global nodal results (FMT_NODE).

    Selections made through :meth:`time`, :meth:`nodes`, :meth:`nodeset`,
    :meth:`comp` and ``view[...]`` are stored on the view and applied by
    :meth:`eval`; they persist until overwritten.
    """

    __slots__ = ("_mesh", "_node_idx", "_per_t")

    def __init__(
        self,
        meta: _FieldMeta,
        times: np.ndarray,
        per_t: list[np.ndarray | None],
        mesh: Mesh,
    ):
        """Create a nodal view backed by per-time arrays.

        Args:
            meta: Field metadata; ``meta.ncomp`` sizes NaN placeholders.
            times: State times; their count is ``len(view)``.
            per_t: One array per state (rows are nodes, last axis is the
                component), or ``None`` for a state without data.
            mesh: Mesh whose ``nodesets`` mapping backs :meth:`nodeset`.
        """
        super().__init__(meta, times)
        self._per_t = per_t
        self._mesh = mesh
        self._node_idx: Index = None

    def nodes(self, ids: Index) -> Self:
        """Select node rows by indices, slices, masks, or lists.

        Returns:
            Self: View for chaining.
        """
        self._node_idx = ids
        return self

    def nodeset(self, name: str) -> Self:
        """Select by nodeset name from the mesh.

        Returns:
            Self: View for chaining.

        Example:
            ``view.nodeset("base").comp("z")`` selects nodeset ``base``.
        """
        self._node_idx = self._mesh.nodesets[name]
        return self

    def dims(self) -> tuple[str, ...]:
        """Describe array axes.

        Returns:
            tuple[str, ...]: Axis labels.
        """
        return ("time", "node", "component")

    def eval(self) -> np.ndarray:
        """Return an array with the current selections applied.

        Returns:
            np.ndarray: ``float32`` array of shape ``(T, N, C)``, or ``(N, C)``
            when the time selector is a single integer. Scalar node/component
            selectors keep their axis with length 1. States stored as ``None``
            are NaN-filled; an empty time selection yields a zero-length first
            axis.
        """
        t_idx = self._t_idx
        if isinstance(t_idx, np.integer):
            # NumPy integer scalars select (and squeeze) exactly like ints.
            t_idx = int(t_idx)
        T_sel = _as_index_list(t_idx, len(self))
        N0 = next((a.shape[0] for a in self._per_t if a is not None), 0)
        N_sel = _shape_from_slice(N0, self._node_idx)
        C_sel = _shape_from_slice(self.meta.ncomp, self._comp_idx)

        def _one(k: int) -> np.ndarray:
            a = self._per_t[k]
            if a is None:
                return np.full((N_sel, C_sel), np.nan, dtype=np.float32)
            out = a
            if self._node_idx is not None:
                out = out[self._node_idx, :]  # type: ignore[index]
            if self._comp_idx is not None:
                out = out[..., self._comp_idx]  # type: ignore[index]
            out = np.asarray(out, dtype=np.float32)
            if out.ndim == 0:
                # Scalar node AND scalar component: keep the 2-D (1, 1) shape
                # so it matches the NaN placeholder and stacks with it.
                out = out.reshape(1, 1)
            elif out.ndim == 1:
                out = out.reshape(1, -1) if N_sel == 1 else out.reshape(-1, 1)
            return out

        if len(T_sel) == 1 and isinstance(t_idx, int):
            return _one(T_sel[0])
        if not T_sel:
            # np.stack refuses an empty sequence.
            return np.empty((0, N_sel, C_sel), dtype=np.float32)
        return np.stack([_one(k) for k in T_sel], axis=0)

    def __getitem__(self, key: Any) -> np.ndarray:
        """Shorthand selection: ``[time, nodes, comp]``.

        Args:
            key: Tuple of up to three selectors, or a single time selector
                (which also clears the node and component selections).
                Component selectors may be names or lists of names.

        Returns:
            np.ndarray: Array with selections applied.

        Raises:
            IndexError: More than three selectors were given.
        """
        if not isinstance(key, tuple):
            t = key
            if isinstance(t, str) and t == ":":
                t = slice(None)
            self._t_idx = t
            self._node_idx = None
            self._comp_idx = None
            return self.eval()

        if len(key) > 3:
            raise IndexError(
                f"too many indices: expected at most 3 (time, node, comp), got {len(key)}"
            )
        t, n, c = (key + (slice(None),) * 3)[:3]
        if isinstance(t, str) and t == ":":
            t = slice(None)
        if isinstance(n, str) and n == ":":
            n = slice(None)
        # Names (including ":") and lists of names go through the same
        # translation as comp(); integer/slice/array selectors pass through.
        if isinstance(c, str) or (
            isinstance(c, (list, tuple)) and any(isinstance(x, str) for x in c)
        ):
            c = _comp_spec_to_index(self.meta, c)

        self._t_idx = t
        self._node_idx = n
        self._comp_idx = c
        return self.eval()

    def __repr__(self) -> str:
        N = next((a.shape[0] for a in self._per_t if a is not None), 0)
        C = next((a.shape[1] for a in self._per_t if a is not None), self.meta.ncomp)
        missing = sum(1 for a in self._per_t if a is None)
        return (
            "NodeResultView("
            f"name={self.meta.name!r}, T={len(self)}, "
            f"N={N}, C={C}, missing={missing})"
        )

    __str__ = __repr__


class NodeRegionResultView(_BaseView):
    """Per-domain nodal results (FMT_NODE split by region).

    Each region maps to one array per state (``None`` when missing) plus the
    node ids of that region. Evaluation needs a single region: either the view
    holds only one, or ``eval(region=...)`` names it.
    """

    __slots__ = ("_node_idx", "_per_name", "_region_nodes")

    def __init__(
        self,
        meta: _FieldMeta,
        times: np.ndarray,
        per_name: dict[str, list[np.ndarray | None]],
        region_nodes: dict[str, np.ndarray],
    ):
        """Create a per-region nodal view.

        Args:
            meta: Field metadata; ``meta.ncomp`` sizes NaN placeholders.
            times: State times; their count is ``len(view)``.
            per_name: Region name -> per-state arrays (``None`` if missing).
            region_nodes: Region name -> node ids belonging to that region.
        """
        super().__init__(meta, times)
        self._per_name = per_name
        self._region_nodes = region_nodes
        self._node_idx: Index = None

    def regions(self) -> list[str]:
        """Return available region names, sorted."""
        return sorted(self._per_name.keys())

    domains = regions

    def region(self, name: str) -> NodeRegionResultView:
        """Restrict the view to a single region.

        The returned view is new and starts without any selections.

        Returns:
            NodeRegionResultView: View for the selected region.

        Raises:
            KeyError: ``name`` is not a known region.
        """
        if name not in self._per_name:
            raise KeyError(name)
        return NodeRegionResultView(
            self.meta,
            self._times,
            {name: self._per_name[name]},
            {name: self._region_nodes[name]},
        )

    domain = region

    def region_nodes(self) -> np.ndarray:
        """Return node ids for the selected region.

        Raises:
            ValueError: More than one region is present.
        """
        if len(self._region_nodes) != 1:
            raise ValueError("multiple regions present; select region() first")
        return next(iter(self._region_nodes.values()))

    def nodes(self, ids: Index) -> Self:
        """Select nodes by indices relative to the region list.

        Returns:
            Self: View for chaining.
        """
        self._node_idx = ids
        return self

    def dims(self) -> tuple[str, ...]:
        """Describe array axes.

        Returns:
            tuple[str, ...]: Axis labels.
        """
        return ("time", "node_in_region", "component")

    def _pick_per_t(self) -> list[np.ndarray | None]:
        """Return the per-state list of the only region in the view."""
        if len(self._per_name) != 1:
            raise ValueError("multiple regions present; select region() first")
        return next(iter(self._per_name.values()))

    def eval(self, *, region: str | None = None) -> np.ndarray:
        """Return array with current selections.

        Args:
            region: Region to read; defaults to the single region in the view.

        Returns:
            np.ndarray: ``float32`` array of shape ``(T, N, C)``, or ``(N, C)``
            when the time selector is a single integer. Missing states are
            NaN-filled; an empty time selection yields a zero-length first axis.

        Raises:
            ValueError: ``region`` is omitted and several regions are present.
            KeyError: ``region`` is not a known region.
        """
        per_t = self._pick_per_t() if region is None else self._per_name[region]
        t_idx = self._t_idx
        if isinstance(t_idx, np.integer):
            # NumPy integer scalars select (and squeeze) exactly like ints.
            t_idx = int(t_idx)
        T_sel = _as_index_list(t_idx, len(self))
        N0 = next((a.shape[0] for a in per_t if a is not None), 0)
        N_sel = _shape_from_slice(N0, self._node_idx)
        C_sel = _shape_from_slice(self.meta.ncomp, self._comp_idx)

        def _one(k: int) -> np.ndarray:
            a = per_t[k]
            if a is None:
                return np.full((N_sel, C_sel), np.nan, dtype=np.float32)
            out = a
            if self._node_idx is not None:
                out = out[self._node_idx, :]  # type: ignore[index]
            if self._comp_idx is not None:
                out = out[..., self._comp_idx]  # type: ignore[index]
            out = np.asarray(out, dtype=np.float32)
            if out.ndim == 0:
                # Scalar node AND scalar component: keep (1, 1) to match the
                # NaN placeholder shape.
                out = out.reshape(1, 1)
            elif out.ndim == 1:
                out = out.reshape(1, -1) if N_sel == 1 else out.reshape(-1, 1)
            return out

        if len(T_sel) == 1 and isinstance(t_idx, int):
            return _one(T_sel[0])
        if not T_sel:
            # np.stack refuses an empty sequence.
            return np.empty((0, N_sel, C_sel), dtype=np.float32)
        return np.stack([_one(k) for k in T_sel], axis=0)

    def __repr__(self) -> str:
        names = list(self._per_name.keys())
        T = len(self)
        if len(names) == 1:
            name = names[0]
            per_t = self._per_name[name]
            N = next((a.shape[0] for a in per_t if a is not None), 0)
            C = next((a.shape[1] for a in per_t if a is not None), self.meta.ncomp)
            missing = sum(1 for a in per_t if a is None)
            return (
                "NodeRegionResultView("
                f"name={self.meta.name!r}, region={name!r}, "
                f"T={T}, N={N}, C={C}, missing={missing})"
            )
        parts = []
        for d in sorted(names):
            per_t = self._per_name[d]
            N = next((a.shape[0] for a in per_t if a is not None), 0)
            C = next((a.shape[1] for a in per_t if a is not None), self.meta.ncomp)
            missing = sum(1 for a in per_t if a is None)
            parts.append(f"{d}:N={N},C={C},missing={missing}")
        return (
            f"NodeRegionResultView(name={self.meta.name!r}, T={T}, regions={len(names)} | "
            + "; ".join(parts)
            + ")"
        )

    __str__ = __repr__


class ItemResultView(_BaseView):
    """Per-item results (FMT_ITEM).

    Each region maps to one array per state (``None`` when missing) whose rows
    are items. Evaluation needs a single region: either the view holds only
    one, or ``eval(region=...)`` names it.
    """

    __slots__ = ("_item_idx", "_per_name")

    def __init__(
        self,
        meta: _FieldMeta,
        times: np.ndarray,
        per_name: dict[str, list[np.ndarray | None]],
    ):
        """Create an item view backed by per-region arrays.

        Args:
            meta: Field metadata; ``meta.ncomp`` sizes NaN placeholders.
            times: State times; their count is ``len(view)``.
            per_name: Region name -> per-state arrays (``None`` if missing).
        """
        super().__init__(meta, times)
        self._per_name = per_name
        self._item_idx: Index = None

    def regions(self) -> list[str]:
        """Return available regions, sorted."""
        return sorted(self._per_name.keys())

    domains = regions
    surfaces = regions

    def region(self, name: str) -> ItemResultView:
        """Restrict the view to a single region.

        The returned view is new and starts without any selections.

        Returns:
            ItemResultView: View limited to ``name``.

        Raises:
            KeyError: ``name`` is not a known region.
        """
        if name not in self._per_name:
            raise KeyError(name)
        return ItemResultView(self.meta, self._times, {name: self._per_name[name]})

    domain = region

    def items(self, idx: Index) -> ItemResultView:
        """Select items by indices, slices, masks, or lists.

        Returns:
            ItemResultView: View for chaining.
        """
        self._item_idx = idx
        return self

    elems = items
    faces = items

    def dims(self) -> tuple[str, ...]:
        """Describe array axes.

        Returns:
            tuple[str, ...]: Axis labels.
        """
        return ("time", "item", "component")

    def _pick_per_t(self) -> list[np.ndarray | None]:
        """Return the per-state list of the only region in the view."""
        if len(self._per_name) != 1:
            raise ValueError("multiple regions present; select region() first")
        return next(iter(self._per_name.values()))

    def eval(self, *, region: str | None = None) -> np.ndarray:
        """Return array with current selections.

        Args:
            region: Region to read; defaults to the single region in the view.

        Returns:
            np.ndarray: ``float32`` array of shape ``(T, I, C)``, or ``(I, C)``
            when the time selector is a single integer. Missing states are
            NaN-filled; an empty time selection yields a zero-length first axis.

        Raises:
            ValueError: ``region`` is omitted and several regions are present.
            KeyError: ``region`` is not a known region.
        """
        per_t = self._pick_per_t() if region is None else self._per_name[region]
        t_idx = self._t_idx
        if isinstance(t_idx, np.integer):
            # NumPy integer scalars select (and squeeze) exactly like ints.
            t_idx = int(t_idx)
        T_sel = _as_index_list(t_idx, len(self))
        I0 = next((a.shape[0] for a in per_t if a is not None), 0)
        I_sel = _shape_from_slice(I0, self._item_idx)
        C_sel = _shape_from_slice(self.meta.ncomp, self._comp_idx)

        def _one(k: int) -> np.ndarray:
            a = per_t[k]
            if a is None:
                return np.full((I_sel, C_sel), np.nan, dtype=np.float32)
            out = a
            if self._item_idx is not None:
                out = out[self._item_idx, :]  # type: ignore[index]
            if self._comp_idx is not None:
                out = out[..., self._comp_idx]  # type: ignore[index]
            out = np.asarray(out, dtype=np.float32)
            if out.ndim == 0:
                # Scalar item AND scalar component: keep (1, 1) to match the
                # NaN placeholder shape.
                out = out.reshape(1, 1)
            elif out.ndim == 1:
                out = out.reshape(1, -1) if I_sel == 1 else out.reshape(-1, 1)
            return out

        if len(T_sel) == 1 and isinstance(t_idx, int):
            return _one(T_sel[0])
        if not T_sel:
            # np.stack refuses an empty sequence.
            return np.empty((0, I_sel, C_sel), dtype=np.float32)
        return np.stack([_one(k) for k in T_sel], axis=0)

    def __repr__(self) -> str:
        names = list(self._per_name.keys())
        T = len(self)
        parts = []
        for d in sorted(names):
            per_t = self._per_name[d]
            n_items = next((a.shape[0] for a in per_t if a is not None), 0)
            C = next((a.shape[1] for a in per_t if a is not None), self.meta.ncomp)
            missing = sum(1 for a in per_t if a is None)
            parts.append(f"{d}:I={n_items},C={C},missing={missing}")
        return (
            f"ItemResultView(name={self.meta.name!r}, T={T}, regions={len(names)} | "
            + "; ".join(parts)
            + ")"
        )

    __str__ = __repr__


class MultResultView(_BaseView):
    """Per-item per-element-node results (FMT_MULT).

    Each region maps to one block per state (``None`` when missing). A block
    is either a ``MultBlock`` (its ``.data`` is used) or an array, indexed as
    ``[item, enode, component]``. Evaluation needs a single region: either the
    view holds only one, or ``eval(region=...)`` names it.
    """

    __slots__ = ("_enode_idx", "_item_idx", "_per_name")

    def __init__(
        self,
        meta: _FieldMeta,
        times: np.ndarray,
        per_name: dict[str, list[MultLike | None]],
    ):
        """Create a per-item per-enode view backed by region blocks.

        Args:
            meta: Field metadata; ``meta.ncomp`` is the fallback component count.
            times: State times; their count is ``len(view)``.
            per_name: Region name -> per-state blocks (``None`` if missing).
        """
        super().__init__(meta, times)
        self._per_name = per_name
        self._item_idx: Index = None
        self._enode_idx: Index = None

    def regions(self) -> list[str]:
        """Return available regions, sorted."""
        return sorted(self._per_name.keys())

    domains = regions
    surfaces = regions

    def region(self, name: str) -> MultResultView:
        """Restrict the view to a single region.

        The returned view is new and starts without any selections.

        Returns:
            MultResultView: View limited to ``name``.

        Raises:
            KeyError: ``name`` is not a known region.
        """
        if name not in self._per_name:
            raise KeyError(name)
        return MultResultView(self.meta, self._times, {name: self._per_name[name]})

    domain = region

    def items(self, idx: Index) -> MultResultView:
        """Select items by indices, slices, masks, or lists.

        Returns:
            MultResultView: View for chaining.
        """
        self._item_idx = idx
        return self

    elems = items
    faces = items

    def enodes(self, idx: Index) -> MultResultView:
        """Select element-node positions by indices or slices.

        Returns:
            MultResultView: View for chaining.
        """
        self._enode_idx = idx
        return self

    nodes = enodes

    def dims(self) -> tuple[str, ...]:
        """Describe array axes.

        Returns:
            tuple[str, ...]: Axis labels.
        """
        return ("time", "item", "enode", "component")

    def _pick_per_t(self) -> list[MultLike | None]:
        """Return the per-state list of the only region in the view."""
        if len(self._per_name) != 1:
            raise ValueError("multiple regions present; select region() first")
        return next(iter(self._per_name.values()))

    @staticmethod
    def _block_array(block: MultLike) -> np.ndarray:
        """Return the raw array behind ``block``, unwrapping a ``MultBlock``."""
        return block.data if isinstance(block, MultBlock) else block

    @staticmethod
    def _keep_axis(idx: Index) -> Index:
        """Wrap a scalar integer selector in a list so indexing keeps its axis."""
        if isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_)):
            return [int(idx)]
        return idx

    def _shape(self, per_t: list[MultLike | None]) -> tuple[int, int, int]:
        """Infer item/enode/component counts from a region block list.

        The first non-missing block decides; with no data the counts are
        ``(0, 0, meta.ncomp)``.

        Returns:
            tuple[int, int, int]: (items, enodes, components).
        """
        first = next((self._block_array(a) for a in per_t if a is not None), None)
        if first is None:
            return 0, 0, self.meta.ncomp
        return first.shape[0], first.shape[1], first.shape[2]

    def eval(self, *, region: str | None = None) -> np.ndarray:
        """Return array with current selections.

        Scalar item/enode/component selectors keep their axis with length 1,
        so every state has rank 3 and matches the NaN placeholder.

        Args:
            region: Region to read; defaults to the single region in the view.

        Returns:
            np.ndarray: ``float32`` array of shape ``(T, I, K, C)``, or
            ``(I, K, C)`` when the time selector is a single integer. Missing
            states are NaN-filled; an empty time selection yields a zero-length
            first axis.

        Raises:
            ValueError: ``region`` is omitted and several regions are present.
            KeyError: ``region`` is not a known region.
        """
        per_t = self._pick_per_t() if region is None else self._per_name[region]
        t_idx = self._t_idx
        if isinstance(t_idx, np.integer):
            # NumPy integer scalars select (and squeeze) exactly like ints.
            t_idx = int(t_idx)
        T_sel = _as_index_list(t_idx, len(self))
        n_items, k_count, comp_count = self._shape(per_t)
        I_sel = _shape_from_slice(n_items, self._item_idx)
        K_sel = _shape_from_slice(k_count, self._enode_idx)
        C_sel = _shape_from_slice(comp_count, self._comp_idx)
        # Applying a bare int would drop an axis, which breaks the subsequent
        # positional indexing (e.g. items(0).enodes(1)) and the NaN shape.
        item_sel = self._keep_axis(self._item_idx)
        enode_sel = self._keep_axis(self._enode_idx)
        comp_sel = self._keep_axis(self._comp_idx)

        def _one(k: int) -> np.ndarray:
            block = per_t[k]
            if block is None:
                return np.full((I_sel, K_sel, C_sel), np.nan, dtype=np.float32)
            out = self._block_array(block)
            if item_sel is not None:
                out = out[item_sel, :, :]  # type: ignore[index]
            if enode_sel is not None:
                out = out[:, enode_sel, :]  # type: ignore[index]
            if comp_sel is not None:
                out = out[..., comp_sel]  # type: ignore[index]
            return np.asarray(out, dtype=np.float32)

        if len(T_sel) == 1 and isinstance(t_idx, int):
            return _one(T_sel[0])
        if not T_sel:
            # np.stack refuses an empty sequence.
            return np.empty((0, I_sel, K_sel, C_sel), dtype=np.float32)
        return np.stack([_one(k) for k in T_sel], axis=0)

    def __repr__(self) -> str:
        names = list(self._per_name.keys())
        T = len(self)
        parts = []
        for d in sorted(names):
            per_t = self._per_name[d]
            n_items, K, C = self._shape(per_t)
            missing = sum(1 for a in per_t if a is None)
            parts.append(f"{d}:I={n_items},K={K},C={C},missing={missing}")
        return (
            f"MultResultView(name={self.meta.name!r}, T={T}, regions={len(names)} | "
            + "; ".join(parts)
            + ")"
        )

    __str__ = __repr__


class RegionResultView(_BaseView):
    """Per-region vector results (FMT_REGION).

    Each region maps to one 1-D component vector per state (``None`` when
    missing). Evaluation needs a single region: either the view holds only
    one, or ``eval(region=...)`` names it.
    """

    __slots__ = ("_per_name", "_region_idx")

    def __init__(
        self,
        meta: _FieldMeta,
        times: np.ndarray,
        per_name: dict[str, list[np.ndarray | None]],
    ):
        """Create a per-region vector view.

        Args:
            meta: Field metadata; ``meta.ncomp`` sizes NaN placeholders.
            times: State times; their count is ``len(view)``.
            per_name: Region name -> per-state vectors (``None`` if missing).
        """
        super().__init__(meta, times)
        self._per_name = per_name
        self._region_idx: Index = None

    def regions(self) -> list[str]:
        """Return available regions, sorted."""
        return sorted(self._per_name.keys())

    domains = regions
    surfaces = regions

    def region(self, name: str) -> RegionResultView:
        """Restrict the view to a single region.

        The returned view is new and starts without any selections.

        Returns:
            RegionResultView: View limited to ``name``.

        Raises:
            KeyError: ``name`` is not a known region.
        """
        if name not in self._per_name:
            raise KeyError(name)
        return RegionResultView(self.meta, self._times, {name: self._per_name[name]})

    domain = region

    def dims(self) -> tuple[str, ...]:
        """Describe array axes.

        Returns:
            tuple[str, ...]: Axis labels.
        """
        return ("time", "component")

    def _pick_per_t(self) -> list[np.ndarray | None]:
        """Return the per-state list of the only region in the view."""
        if len(self._per_name) != 1:
            raise ValueError("multiple regions present; select region() first")
        return next(iter(self._per_name.values()))

    def eval(self, *, region: str | None = None) -> np.ndarray:
        """Return array with current selections.

        Args:
            region: Region to read; defaults to the single region in the view
                (same keyword as the other per-region views).

        Returns:
            np.ndarray: ``float32`` array of shape ``(T, C)``, or ``(C,)`` when
            the time selector is a single integer. Missing states are
            NaN-filled; an empty time selection yields a zero-length first axis.

        Raises:
            ValueError: ``region`` is omitted and several regions are present.
            KeyError: ``region`` is not a known region.
        """
        per_t = self._pick_per_t() if region is None else self._per_name[region]
        t_idx = self._t_idx
        if isinstance(t_idx, np.integer):
            # NumPy integer scalars select (and squeeze) exactly like ints.
            t_idx = int(t_idx)
        T_sel = _as_index_list(t_idx, len(self))
        C_sel = _shape_from_slice(self.meta.ncomp, self._comp_idx)

        def _one(k: int) -> np.ndarray:
            a = per_t[k]
            if a is None:
                return np.full((C_sel,), np.nan, dtype=np.float32)
            out = a
            if self._comp_idx is not None:
                out = out[self._comp_idx]  # type: ignore[index]
            out = np.asarray(out, dtype=np.float32)
            if out.ndim == 0:
                # A single component keeps a length-1 axis.
                out = out.reshape(1)
            return out

        if len(T_sel) == 1 and isinstance(t_idx, int):
            return _one(T_sel[0])
        if not T_sel:
            # np.stack refuses an empty sequence.
            return np.empty((0, C_sel), dtype=np.float32)
        return np.stack([_one(k) for k in T_sel], axis=0)

    def __repr__(self) -> str:
        T = len(self)
        names = list(self._per_name.keys())
        if len(names) == 1:
            name = names[0]
            per_t = self._per_name[name]
            C = next((a.shape[0] for a in per_t if a is not None), self.meta.ncomp)
            missing = sum(1 for a in per_t if a is None)
            return (
                "RegionResultView("
                f"name={self.meta.name!r}, region={name!r}, "
                f"T={T}, C={C}, missing={missing})"
            )
        parts = []
        for n in sorted(names):
            per_t = self._per_name[n]
            C = next((a.shape[0] for a in per_t if a is not None), self.meta.ncomp)
            missing = sum(1 for a in per_t if a is None)
            parts.append(f"{n}:C={C},missing={missing}")
        return (
            f"RegionResultView(name={self.meta.name!r}, T={T}, regions={len(names)} | "
            + "; ".join(parts)
            + ")"
        )

    __str__ = __repr__


# Names exported by ``from <module> import *``. The two underscore-prefixed
# names are listed explicitly because star-imports skip leading-underscore
# names unless ``__all__`` includes them. Kept as a literal list so static
# tools can read it.
__all__ = [
    "ItemResultView",
    "MultResultView",
    "NodeRegionResultView",
    "NodeResultView",
    "RegionResultView",
    "_FieldMeta",
    "_comp_spec_to_index",
]
