"""FEBio .xplt reader with regional FMT_NODE support and sliceable result views.

Overview
--------
This module reads FEBio ``.xplt`` files and exposes results through *views*.
Views let you slice by time, region, item (element/face), node, and component
without copying unless necessary.

Workflow
^^^^^^^^
1) Header and dictionary

   - Validate magic/version/compression.
   - Parse the dictionary (dtype + storage format FMT_NODE/FMT_ITEM/FMT_MULT/FMT_REGION).
   - Record which stream a variable belongs to (node, element, face).

2) Mesh

   - Parse nodes, domains (parts), surfaces, and nodesets.
   - Pad connectivity to fixed width for simple slicing.
   - Build :class:`Mesh` with node/element/surface tables and preserved names.

3) States

   - Read all states with ``readAllStates()`` or selected ones via ``readSteps([...])``.
   - Each state stores time plus node/element/face streams.
   - Format and region id are detected per stream; missing data is stored as ``None``.

4) Finalization

   - Per-variable buffers become *views* stored on :class:`Results` at ``xplt.results``.
   - Shapes are left as written; no interpolation is applied.

Views concept
-------------
Selectors share a consistent API:

- ``time(idx)`` picks states by index or slice.
- ``comp(idx)`` picks components by index, slice, or name (when available).
- Regions: ``region(name)``/``domain(name)`` or ``surface(name)``; ``regions()`` lists names.
- Items/nodes: ``items(ids)``/``elems(ids)``/``faces(ids)``, ``enodes(ids)``, ``nodes(ids)``.
- Shorthand ``__getitem__`` works on many views.

Selectors accept integers, slices, lists/NumPy index arrays, or the string token ``":"`` to
select everything along that axis. Returned array shapes follow the view's ``dims()``
description.

View types
^^^^^^^^^^
- :class:`NodeResultView` (FMT_NODE global): dims = (time, node, component).
- :class:`NodeRegionResultView` (FMT_NODE per-domain): dims = (time, node_in_region, component).
- :class:`ItemResultView` (FMT_ITEM domains/surfaces): dims = (time, item, component).
- :class:`MultResultView` (FMT_MULT per element-node): dims = (time, item, enode, component).
- :class:`RegionResultView` (FMT_REGION): dims = (time, component).

Component names
^^^^^^^^^^^^^^^
- VEC3F: ``"x" "y" "z"``
- MAT3FD: ``"xx" "yy" "zz"``
- MAT3FS: ``"xx" "yy" "zz" "xy" "yz" "xz"``
- MAT3F: full 3x3 row-major tokens
- TENS4FS: Voigt-6 style tokens like ``"xxyy"``

Examples
^^^^^^^^
Read all states and grab displacement:

.. code-block:: python

   from pyfebiopt.xplt import xplt

   xp = xplt("run.xplt")
   xp.readAllStates()
   r = xp.results
   U = r["displacement"]                  # NodeResultView or NodeRegionResultView
   U0 = U.time(0).comp("x")               # first time, x-component, all nodes
   Uset = U.nodes(xp.mesh.nodesets["base"]).comp("z")

Direct slicing with ``__getitem__``:

.. code-block:: python

   Ux0 = r["displacement"][0, :, "x"]     # time 0, all nodes, x-comp

Per-domain nodal results (FMT_NODE per region):

.. code-block:: python

   Ur = r["displacement"].region("arteria")
   Ur_vals = Ur.time(":").nodes(slice(0, 10)).comp("y")

Element item results:

.. code-block:: python

   S = r["von Mises"].domain("arteria")   # ItemResultView
   S_last = S.time(-1).items([0, 5, 9]).comp(":")

Per-element-node results:

.. code-block:: python

   Q = r["strain"].domain("arteria")      # MultResultView
   Qpick = Q.time(0).items(0).enodes(":").comp("xx")

Region vector results:

.. code-block:: python

   R = r["domain volume"].domain("arteria")  # RegionResultView
   R_all_t = R.time(":").comp(":")

Missing data
^^^^^^^^^^^^
Missing variables for a time/region return arrays of NaN with preserved shapes.
All views convert outputs to float32 to keep memory usage low.

Implementation notes
^^^^^^^^^^^^^^^^^^^^
- Domain and surface names are taken from the file when available.
- Node ids in the file may be 1-based or 0-based; they are normalised.
- Connectivity and facet lists are right-padded with ``-1`` for uniform width.
"""

from __future__ import annotations

import struct
from os import SEEK_SET

import numpy as np

from pyfebiopt.mesh.mesh import Mesh

from .binary_reader import BinaryReader
from .dictionary import BufferStore, prepare_field_meta_and_buffers, read_dictionary
from .enums import Storage_Fmt
from .mesh_parser import MeshParseResult, parse_mesh
from .results import Results
from .types import (
    MultBlock,
    MultLike,
)
from .views import _FieldMeta

# Views are implemented in the sibling ``views`` module and imported above.


# The Results container lives in the sibling ``results`` module.


# --------------------------- reader ---------------------------


class xplt:
    """FEBio ``.xplt`` reader.

    Features
    --------
    - Global and per-domain FMT_NODE.
    - FMT_ITEM, FMT_MULT, and FMT_REGION for domains and surfaces.
    - Builds a :class:`Mesh` with parts, surfaces, and nodesets.

    Typical use
    -----------
    >>> xp = xplt("run.xplt")
    >>> xp.readAllStates()
    >>> r = xp.results
    >>> disp = r["displacement"][0, ":", "x"]

    States are read once per instance, with either :meth:`readAllStates` or
    :meth:`readSteps` (not both). Either call populates :attr:`results` and
    closes the underlying file.
    """

    # File signature checked by _read_xplt (0x00464542, i.e. ASCII "FEB").
    _MAGIC = 4605250

    def __init__(self, filename: str):
        """Open ``filename`` and parse its header, dictionary, and mesh.

        No states are read here; call :meth:`readAllStates` or
        :meth:`readSteps` afterwards.

        Raises:
            RuntimeError: If the file signature is wrong or no mesh could be
                parsed. If parsing raises any exception, the file is closed
                before the exception propagates.
        """
        self._reader = BinaryReader(filename)
        self._time: list[float] = []
        self._readMode = ""  # "", "readAllStates" or "readSteps"

        # dictionary
        self.dictionary: dict[str, dict[str, str]] = {}
        self._field_meta: dict[str, _FieldMeta] = {}
        self._dict_order: dict[str, list[str]] = {"node": [], "elem": [], "face": []}
        self._where: dict[str, str] = {}  # var -> "node" | "elem" | "face"

        # ids and ordinals
        self._part_id2name: dict[int, str] = {}
        self._surf_id2name: dict[int, str] = {}
        self._dom_idx2name: dict[int, str] = {}  # PLT_DOMAIN ordinal -> domain name
        self._surf_idx2name: dict[int, str] = {}  # PLT_SURFACE ordinal -> surface name

        # mesh maps
        self._parts_map: dict[str, list[int]] = {}
        self._surfaces_map: dict[str, list[np.ndarray]] = {}
        self._nodesets_map: dict[str, np.ndarray] = {}
        self._domain_nodes: dict[str, np.ndarray] = {}
        self._mesh_result: MeshParseResult | None = None

        # result buffers: one entry per read state (None when absent)
        self._buf_node_global: dict[str, list[np.ndarray | None]] = {}
        self._buf_node_region: dict[str, dict[str, list[np.ndarray | None]]] = {}
        self._buf_elem_item: dict[str, dict[str, list[np.ndarray | None]]] = {}
        self._buf_elem_mult: dict[str, dict[str, list[MultLike | None]]] = {}
        self._buf_elem_region: dict[str, dict[str, list[np.ndarray | None]]] = {}
        self._buf_face_item: dict[str, dict[str, list[np.ndarray | None]]] = {}
        self._buf_face_mult: dict[str, dict[str, list[MultLike | None]]] = {}
        self._buf_face_region: dict[str, dict[str, list[np.ndarray | None]]] = {}

        # per-state scratch, committed by _flush_state_missing
        self._state_node_global_seen: dict[str, np.ndarray] = {}
        self._state_node_region_seen: dict[str, dict[str, np.ndarray]] = {}
        self._state_elem_item_seen: dict[str, dict[str, np.ndarray]] = {}
        self._state_elem_mult_seen: dict[str, dict[str, MultBlock]] = {}
        self._state_elem_region_seen: dict[str, dict[str, np.ndarray]] = {}
        self._state_face_item_seen: dict[str, dict[str, np.ndarray]] = {}
        self._state_face_mult_seen: dict[str, dict[str, MultBlock]] = {}
        self._state_face_region_seen: dict[str, dict[str, np.ndarray]] = {}

        # parse header+dict+mesh; never leak the open file when this fails
        try:
            self._read_xplt()
            if self._mesh_result is None:
                raise RuntimeError("Failed to parse mesh")
            self.mesh: Mesh = self._mesh_result.mesh
            self._domain_nodes = self._mesh_result.domain_nodes
        except BaseException:
            self._reader.close()
            raise

        self.version: int
        self.compression: int
        self.results: Results = Results([])

    def __repr__(self) -> str:
        """Summarise version, compression, number of read states and variables."""
        return (
            "xplt("
            f"v={getattr(self, 'version', 'NA')}, "
            f"comp={getattr(self, 'compression', 'NA')}, "
            f"ntimes={len(self._time)}, vars={len(self.dictionary)})"
        )

    __str__ = __repr__

    def _assign_buffers(self, buffers: BufferStore) -> None:
        """Attach freshly prepared buffers to this reader."""
        self._buf_node_global = buffers.node_global
        self._buf_node_region = buffers.node_region
        self._buf_elem_item = buffers.elem_item
        self._buf_elem_mult = buffers.elem_mult
        self._buf_elem_region = buffers.elem_region
        self._buf_face_item = buffers.face_item
        self._buf_face_mult = buffers.face_mult
        self._buf_face_region = buffers.face_region

    # states

    @staticmethod
    def _pack_mult(block_flat: np.ndarray, nper: np.ndarray, C: int) -> np.ndarray:
        """Pack a flat FMT_MULT block into a padded 3D array.

        Args:
            block_flat: Rows of all items concatenated, shaped (>= sum(nper), C).
            nper: Number of rows belonging to each item, shaped (R,).
            C: Number of components per row.

        Returns:
            float32 array shaped (R, Kmax, C) with ``Kmax = max(nper)``. Slots
            beyond an item's own row count are NaN-padded.

        Raises:
            ValueError: If ``block_flat`` has fewer rows than ``nper`` requires.
        """
        n_items = int(nper.size)
        kmax = int(nper.max()) if n_items > 0 else 0
        out = np.full((n_items, kmax, C), np.nan, dtype=np.float32)
        needed = int(np.maximum(nper, 0).sum()) if n_items > 0 else 0
        if block_flat.shape[0] < needed:
            raise ValueError(
                f"MULT block has {block_flat.shape[0]} rows but the item counts need {needed}"
            )
        off = 0
        for r, k in enumerate(nper.tolist()):
            k = int(k)
            if k > 0:
                out[r, :k, :] = block_flat[off : off + k, :]
                off += k
        return out

    @staticmethod
    def _append_per_region(
        buffers: dict[str, dict],
        seen: dict[str, dict],
        region_names: list[str],
    ) -> None:
        """Append one entry per (variable, region) for the current state.

        The region keys of a variable are initialised from ``region_names`` the
        first time it is flushed while empty. Afterwards every existing key
        receives exactly one entry: the scratch value, or ``None``. Scratch
        entries for regions that are not buffer keys are dropped.
        """
        for var, per_region in buffers.items():
            if not per_region:
                for name in region_names:
                    per_region[name] = []
            state_seen = seen.get(var, {})
            for name, series in per_region.items():
                series.append(state_seen.get(name))

    def _flush_state_missing(self) -> None:
        """Commit the current state's scratch data to the per-variable buffers.

        Every buffered variable gets exactly one entry for the state. The entry
        is ``None`` when the variable (or region) was not written in this state.
        The per-state scratch dictionaries are cleared afterwards.
        """
        # node global
        for var, series in self._buf_node_global.items():
            series.append(self._state_node_global_seen.get(var))

        dom_names = list(self._dom_idx2name.values())
        surf_names = list(self._surf_idx2name.values())
        append = self._append_per_region
        # node per-region, then elem item/mult/region, then face item/mult/region
        append(self._buf_node_region, self._state_node_region_seen, dom_names)
        append(self._buf_elem_item, self._state_elem_item_seen, dom_names)
        append(self._buf_elem_mult, self._state_elem_mult_seen, dom_names)
        append(self._buf_elem_region, self._state_elem_region_seen, dom_names)
        append(self._buf_face_item, self._state_face_item_seen, surf_names)
        append(self._buf_face_mult, self._state_face_mult_seen, surf_names)
        append(self._buf_face_region, self._state_face_region_seen, surf_names)

        # clear scratch
        self._state_node_global_seen.clear()
        self._state_node_region_seen.clear()
        self._state_elem_item_seen.clear()
        self._state_elem_mult_seen.clear()
        self._state_elem_region_seen.clear()
        self._state_face_item_seen.clear()
        self._state_face_mult_seen.clear()
        self._state_face_region_seen.clear()

    def _read_rows(self, nrows: int, ncomp: int) -> np.ndarray:
        """Read ``nrows * ncomp`` float32 values and return them shaped (nrows, ncomp)."""
        raw = self._reader.read(4 * ncomp * nrows)
        return np.frombuffer(raw, dtype=np.float32).reshape(nrows, ncomp)

    def _region_nper(self, is_elem: bool, rname: str) -> np.ndarray:
        """Return the per-item node counts (int64) of a domain or surface.

        An empty array is returned when the region is unknown to the mesh maps.
        """
        if is_elem:
            elem_ids = np.asarray(self._parts_map.get(rname, []), dtype=np.int64)
            if elem_ids.size == 0:
                return np.zeros((0,), dtype=np.int64)
            return np.asarray(self.mesh.elements.nper[elem_ids], dtype=np.int64)
        if rname in self.mesh.surfaces:
            return np.asarray(self.mesh.surfaces[rname].nper, dtype=np.int64)
        return np.zeros((0,), dtype=np.int64)

    def _store_record(
        self, kind: str, var_name: str, meta: _FieldMeta, reg_raw: int, nrows: int
    ) -> None:
        """Read one ``(nrows, ncomp)`` record and stash it in the per-state scratch.

        The record is left unread when its storage format does not fit the
        stream ``kind``. The caller repositions the reader after every record.
        """
        ncomp = meta.ncomp
        fmt = meta.fmt
        rid = reg_raw - 1  # record region ids are treated as 1-based ordinals

        # NODAL: a known domain ordinal -> per-domain data, otherwise global
        if kind == "node":
            if fmt != Storage_Fmt.FMT_NODE:
                return
            block = self._read_rows(nrows, ncomp)
            node_region = self._dom_idx2name.get(rid)
            if node_region is None:
                self._state_node_global_seen[var_name] = block
            else:
                self._state_node_region_seen.setdefault(var_name, {})[node_region] = block
            return

        if fmt not in (Storage_Fmt.FMT_ITEM, Storage_Fmt.FMT_MULT, Storage_Fmt.FMT_REGION):
            return

        is_elem = kind == "elem"
        if is_elem:
            rname = self._dom_idx2name.get(rid, f"domain_{rid}")
        else:
            # NOTE(review): the fallback queries the surface *id* map with the
            # 0-based ordinal; confirm the key convention of surf_id2name.
            rname = self._surf_idx2name.get(
                rid, self._surf_id2name.get(rid, f"surface_{rid}")
            )
        block = self._read_rows(nrows, ncomp)

        if fmt == Storage_Fmt.FMT_ITEM:
            item_table = self._state_elem_item_seen if is_elem else self._state_face_item_seen
            item_table.setdefault(var_name, {})[rname] = block
        elif fmt == Storage_Fmt.FMT_MULT:
            nper = self._region_nper(is_elem, rname)
            mb = MultBlock(self._pack_mult(block, nper, ncomp), nper)
            mult_table = self._state_elem_mult_seen if is_elem else self._state_face_mult_seen
            mult_table.setdefault(var_name, {})[rname] = mb
        else:  # FMT_REGION: one vector per region, first row kept
            region_table = (
                self._state_elem_region_seen if is_elem else self._state_face_region_seen
            )
            region_table.setdefault(var_name, {})[rname] = block[0, :]

    def _readResultStream(self, tag: str, kind: str) -> None:
        """Read one result stream (node, elem, or face) of the current state.

        Each ``PLT_STATE_VARIABLE`` carries a variable id and a data payload made
        of ``(region id, byte size, float32 rows)`` records. The parsed rows go
        into the per-state scratch dictionaries.

        Args:
            tag: Stream block tag, e.g. ``"PLT_NODE_DATA"``.
            kind: ``"node"``, ``"elem"``, or ``"face"``. It selects the dictionary
                order used to resolve variable ids and the scratch tables.
        """
        order = self._dict_order[kind]
        _ = self._reader.search_block(tag)
        while self._reader.check_block("PLT_STATE_VARIABLE"):
            self._reader.search_block("PLT_STATE_VARIABLE")
            self._reader.search_block("PLT_STATE_VAR_ID")
            var_id_raw = int(struct.unpack("I", self._reader.read(4))[0])
            dlen = self._reader.search_block("PLT_STATE_VAR_DATA")
            endp = self._reader.tell() + dlen

            # FEBio historically writes 1-based ids; fall back to zero-based
            # when the 1-based interpretation is out of range.
            idx = var_id_raw - 1
            if not 0 <= idx < len(order):
                idx = var_id_raw
            if not 0 <= idx < len(order):
                self._reader.seek(endp, SEEK_SET)
                continue

            var_name = order[idx]
            meta = self._field_meta.get(var_name)
            if meta is None or meta.ncomp <= 0:
                self._reader.seek(endp, SEEK_SET)
                continue

            row_bytes = 4 * meta.ncomp
            while self._reader.tell() < endp:
                reg_raw = int(struct.unpack("I", self._reader.read(4))[0])
                size_b = int(struct.unpack("I", self._reader.read(4))[0])
                payload_start = self._reader.tell()
                nrows = size_b // row_bytes
                if nrows > 0:
                    self._store_record(kind, var_name, meta, reg_raw, nrows)
                # Always resume right after this record's payload, so that
                # skipped, short, or oddly sized records cannot desync parsing.
                self._reader.seek(payload_start + size_b, SEEK_SET)

    def _readState(self) -> int:
        """Read one ``PLT_STATE`` into the per-variable buffers.

        The state's time is appended to ``self._time``, and every buffered
        variable receives exactly one entry (``None`` when absent). This holds
        even when reading the state's data fails part-way, so time and buffers
        stay aligned.

        Returns:
            Always 0. The end of data and parse failures are signalled by
            exceptions.

        Raises:
            RuntimeError: If no further ``PLT_STATE`` block is found.
        """
        size = self._reader.search_block("PLT_STATE")
        if size < 0:
            raise RuntimeError("No further PLT_STATE found")
        self._reader.search_block("PLT_STATE_HEADER")
        self._reader.search_block("PLT_STATE_HDR_TIME")
        t = float(struct.unpack("f", self._reader.read(4))[0])
        self._reader.search_block("PLT_STATE_STATUS")
        _ = int(struct.unpack("I", self._reader.read(4))[0])  # status flag (ignored)
        self._time.append(t)
        try:
            # inside the try: once the time is recorded, buffers must be flushed
            self._reader.search_block("PLT_STATE_DATA")
            self._readResultStream("PLT_NODE_DATA", "node")
            self._readResultStream("PLT_ELEMENT_DATA", "elem")
            self._readResultStream("PLT_FACE_DATA", "face")
        finally:
            self._flush_state_missing()
        return 0

    def _skipState(self) -> None:
        """Skip one PLT_STATE block.

        Raises:
            RuntimeError: If no further ``PLT_STATE`` block is found.
        """
        size = self._reader.search_block("PLT_STATE")
        if size < 0:
            raise RuntimeError("No further PLT_STATE found while skipping")
        self._reader.skip(size)

    def readAllStates(self) -> None:
        """Read every state in order, finalize :attr:`results`, and close the file.

        Reading stops at the first ``RuntimeError`` raised while reading a
        state. Normally this is the end-of-data signal from ``_readState``, but a
        corrupt state ends reading the same way. Calling this again after a
        successful call is a no-op.

        Raises:
            RuntimeError: If :meth:`readSteps` was already used on this reader.
        """
        if self._readMode == "readSteps":
            raise RuntimeError("readAllStates incompatible with readSteps")
        if self._readMode == "readAllStates":
            # buffers were consumed by _finalize_results; keep existing results
            return
        while True:
            try:
                self._readState()
            except RuntimeError:
                break  # likely no further PLT_STATE blocks
        self._readMode = "readAllStates"
        self._finalize_results()
        self._reader.close()

    def readSteps(self, stepList: list[int]) -> None:
        """Read the given 0-based state indices, finalize results, and close the file.

        The file is read sequentially: the states between requested steps are
        skipped. ``stepList`` must therefore be strictly increasing and
        non-negative.

        Raises:
            RuntimeError: If states were already read from this reader.
            ValueError: If ``stepList`` contains a negative index or is not
                strictly increasing.
        """
        if self._readMode == "readAllStates":
            raise RuntimeError("readSteps incompatible with readAllStates")
        if self._readMode == "readSteps":
            raise RuntimeError("readSteps can only be called once per reader")

        steps = [int(s) for s in stepList]
        if any(s < 0 for s in steps):
            raise ValueError(f"step indices must be non-negative, got {steps}")
        if any(b <= a for a, b in zip(steps, steps[1:])):
            raise ValueError(f"step indices must be strictly increasing, got {steps}")

        next_step = 0  # index of the state the reader is positioned at
        for s in steps:
            for _ in range(s - next_step):
                self._skipState()
            self._readState()
            next_step = s + 1
        self._readMode = "readSteps"
        self._finalize_results()
        self._reader.close()

    # finalize

    def _finalize_results(self) -> None:
        """Turn the buffers into views registered on a fresh :class:`Results`.

        Global and per-region nodal variables are registered only when at least
        one state holds data. Element and face variables are always registered.
        All buffer dictionaries are emptied afterwards, so this runs once per
        reader.

        Raises:
            ValueError: If a global nodal variable changes its row count over time.
        """
        times = np.asarray(self._time, float)
        self.results = Results(times)

        # node global
        for name, per_t in self._buf_node_global.items():
            if any(a is not None for a in per_t):
                sizes = {a.shape[0] for a in per_t if a is not None}
                if len(sizes) > 1:
                    raise ValueError(f"{name}: inconsistent nodal row count over time")
                self.results.register_node_global(name, self._field_meta[name], per_t, self.mesh)

        # node region
        for name, per_name in self._buf_node_region.items():
            has_any = any(any(a is not None for a in lst) for lst in per_name.values())
            if has_any:
                self.results.register_node_region(
                    name, self._field_meta[name], per_name, self._domain_nodes
                )

        # elem
        for name, per_item in self._buf_elem_item.items():
            self.results.register_item("elem", name, self._field_meta[name], per_item)
        for name, per_mult in self._buf_elem_mult.items():
            self.results.register_mult("elem", name, self._field_meta[name], per_mult)
        for name, per_region in self._buf_elem_region.items():
            self.results.register_region("elem", name, self._field_meta[name], per_region)

        # face
        for name, per_item in self._buf_face_item.items():
            self.results.register_item("face", name, self._field_meta[name], per_item)
        for name, per_mult in self._buf_face_mult.items():
            self.results.register_mult("face", name, self._field_meta[name], per_mult)
        for name, per_region in self._buf_face_region.items():
            self.results.register_region("face", name, self._field_meta[name], per_region)

        # clear
        self._buf_node_global.clear()
        self._buf_node_region.clear()
        self._buf_elem_item.clear()
        self._buf_elem_mult.clear()
        self._buf_elem_region.clear()
        self._buf_face_item.clear()
        self._buf_face_mult.clear()
        self._buf_face_region.clear()

    # header

    def _read_xplt(self) -> None:
        """Read the file header, then the dictionary and mesh sections.

        Raises:
            RuntimeError: If the file does not start with the xplt signature.
        """
        magic = int(struct.unpack("I", self._reader.read(4))[0])
        if magic != self._MAGIC:
            raise RuntimeError("Not a valid xplt")
        root_size = self._reader.search_block("PLT_ROOT")
        root_start = self._reader.tell()
        root_end = root_start + root_size if root_size > 0 else self._reader.filesize

        self._reader.search_block("PLT_HEADER")
        self._reader.search_block("PLT_HDR_VERSION")
        self.version = int(struct.unpack("I", self._reader.read(4))[0])
        self._reader.search_block("PLT_HDR_COMPRESSION")
        self.compression = int(struct.unpack("I", self._reader.read(4))[0])
        self.dictionary, self._dict_order, self._where = read_dictionary(self._reader)
        self._field_meta, buffers = prepare_field_meta_and_buffers(self.dictionary, self._where)
        self._assign_buffers(buffers)
        # ensure we exit the ROOT payload before parsing mesh/state siblings
        self._reader.seek(root_end)

        self._mesh_result = parse_mesh(self._reader)
        self._part_id2name = self._mesh_result.part_id2name
        self._surf_id2name = self._mesh_result.surf_id2name
        self._dom_idx2name = self._mesh_result.dom_idx2name
        self._surf_idx2name = self._mesh_result.surf_idx2name
        self._parts_map = self._mesh_result.parts_map
        self._surfaces_map = self._mesh_result.surfaces_map
        self._nodesets_map = self._mesh_result.nodesets_map
