"""Container for XPLT result views."""

from __future__ import annotations

from collections.abc import Iterable

import numpy as np

from pyfebiopt.mesh.mesh import Mesh

from .types import MultLike
from .views import (
    ItemResultView,
    MultResultView,
    NodeRegionResultView,
    NodeResultView,
    RegionResultView,
    _FieldMeta,
)


class Results:
    """Holds all result views keyed by variable name.

    Views are grouped by storage location: global nodal, per-domain nodal,
    element/face item, element/face per-element-node ("mult"), and
    element/face per-region. Every registered view is built with the same
    time vector held by this container.

    Notes:
        ``len(results)`` is the number of time steps, whereas
        ``results[name]`` and ``name in results`` operate on variable names.
        Field metadata is tracked by name only, so registering the same name
        in two locations overwrites the earlier metadata entry. Lookup by name
        returns the first location holding it, in the order given by
        ``_stores``.
    """

    def __init__(self, times: Iterable[float]):
        """Initialize empty result collections for each storage kind.

        Args:
            times: Time values, one per time step; copied into a float array.
        """
        self._times = np.asarray(list(times), float)
        self.node: dict[str, NodeResultView] = {}  # global nodal
        self.node_region: dict[str, NodeRegionResultView] = {}  # nodal per-domain
        self.elem_item: dict[str, ItemResultView] = {}
        self.face_item: dict[str, ItemResultView] = {}
        self.elem_mult: dict[str, MultResultView] = {}
        self.face_mult: dict[str, MultResultView] = {}
        self.elem_region: dict[str, RegionResultView] = {}
        self.face_region: dict[str, RegionResultView] = {}
        self._meta: dict[str, _FieldMeta] = {}

    def __len__(self) -> int:
        """Return number of time steps."""
        return int(self._times.shape[0])

    def times(self) -> np.ndarray:
        """Return the time vector.

        The array is returned without copying. It is the same object passed to
        every registered view, so callers should treat it as read-only.
        """
        return self._times

    def register_node_global(
        self, name: str, meta: _FieldMeta, per_t: list[np.ndarray | None], mesh: Mesh
    ) -> None:
        """Register a global nodal variable.

        Args:
            name: Variable name used as the lookup key.
            meta: Field metadata; stored under ``name`` (overwrites any prior entry).
            per_t: One entry per time step; ``None`` marks a missing state.
            mesh: Mesh handed through to the view.
        """
        self._meta[name] = meta
        self.node[name] = NodeResultView(meta, self._times, per_t, mesh)

    def register_node_region(
        self,
        name: str,
        meta: _FieldMeta,
        per_name: dict[str, list[np.ndarray | None]],
        region_nodes: dict[str, np.ndarray],
    ) -> None:
        """Register a per-domain nodal variable.

        Args:
            name: Variable name used as the lookup key.
            meta: Field metadata; stored under ``name`` (overwrites any prior entry).
            per_name: Per-region lists of per-time-step data.
            region_nodes: Per-region node index arrays handed through to the view.
        """
        self._meta[name] = meta
        self.node_region[name] = NodeRegionResultView(meta, self._times, per_name, region_nodes)

    def register_item(
        self,
        where: str,
        name: str,
        meta: _FieldMeta,
        per_name: dict[str, list[np.ndarray | None]],
    ) -> None:
        """Register an item variable for elements or faces.

        Args:
            where: ``"elem"`` selects element storage; any other value selects
                face storage.
            name: Variable name used as the lookup key.
            meta: Field metadata; stored under ``name`` (overwrites any prior entry).
            per_name: Per-region lists of per-time-step data.
        """
        self._meta[name] = meta
        store = self.elem_item if where == "elem" else self.face_item
        store[name] = ItemResultView(meta, self._times, per_name)

    def register_mult(
        self,
        where: str,
        name: str,
        meta: _FieldMeta,
        per_name: dict[str, list[MultLike | None]],
    ) -> None:
        """Register a per-element-node variable for elements or faces.

        Args:
            where: ``"elem"`` selects element storage; any other value selects
                face storage.
            name: Variable name used as the lookup key.
            meta: Field metadata; stored under ``name`` (overwrites any prior entry).
            per_name: Per-region lists of per-time-step data.
        """
        self._meta[name] = meta
        store = self.elem_mult if where == "elem" else self.face_mult
        store[name] = MultResultView(meta, self._times, per_name)

    def register_region(
        self,
        where: str,
        name: str,
        meta: _FieldMeta,
        per_name: dict[str, list[np.ndarray | None]],
    ) -> None:
        """Register a per-region vector variable for elements or faces.

        Args:
            where: ``"elem"`` selects element storage; any other value selects
                face storage.
            name: Variable name used as the lookup key.
            meta: Field metadata; stored under ``name`` (overwrites any prior entry).
            per_name: Per-region lists of per-time-step data.
        """
        self._meta[name] = meta
        store = self.elem_region if where == "elem" else self.face_region
        store[name] = RegionResultView(meta, self._times, per_name)

    def _stores(self) -> tuple[dict, ...]:
        """Return the per-location view dicts in name-lookup priority order."""
        return (
            self.node,
            self.node_region,
            self.elem_item,
            self.face_item,
            self.elem_mult,
            self.face_mult,
            self.elem_region,
            self.face_region,
        )

    def __contains__(self, key: object) -> bool:
        """Return True if any location holds a view named ``key``.

        Without this method Python would fall back to probing
        ``self[0], self[1], ...`` for ``in`` tests. That probe raises
        ``KeyError(0)`` instead of returning a bool.
        """
        return any(key in store for store in self._stores())

    def __getitem__(
        self, key: str
    ) -> (
        ItemResultView
        | MultResultView
        | NodeResultView
        | NodeRegionResultView
        | RegionResultView
    ):
        """Look up a variable by name across all locations.

        Locations are searched in this order: node, node_region, elem_item,
        face_item, elem_mult, face_mult, elem_region, face_region. The first
        match wins.

        Returns:
            Result view matching the variable name.

        Raises:
            KeyError: If no location holds a view named ``key``.
        """
        for store in self._stores():
            if key in store:
                return store[key]
        raise KeyError(key)

    def __repr__(self) -> str:
        """Summarize the time-step count and the number of views per location."""
        return (
            f"{type(self).__name__}(ntimes={len(self)}, "
            f"node={len(self.node)}, node_region={len(self.node_region)}, "
            f"elem_item={len(self.elem_item)}, elem_mult={len(self.elem_mult)}, "
            f"elem_region={len(self.elem_region)}, "
            f"face_item={len(self.face_item)}, face_mult={len(self.face_mult)}, "
            f"face_region={len(self.face_region)})"
        )

    __str__ = __repr__


# Public API of this module: only the ``Results`` container is exported.
__all__ = ["Results"]
