"""Helpers for translating FEBio meshes to PyVista geometries and fields."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import numpy as np
import pyvista as pv

from pyfebiopt.mesh.mesh import Mesh, SurfaceArray
from pyfebiopt.xplt.enums import FEDataType
from pyfebiopt.xplt.views import (
    ItemResultView,
    MultResultView,
    NodeResultView,
)

# -------------------------- internal helpers --------------------------

# (canonical label, node count) -> PyVista/VTK cell type. Lookups pair the
# label returned by ``_norm_label`` with the element's actual node count, so a
# label combined with an unexpected node count finds no entry.
_VTK_MAP: dict[tuple[str, int], int] = {
    (label, n_nodes): cell_type
    for label, n_nodes, cell_type in (
        # volumetric elements
        ("TET4", 4, pv.CellType.TETRA),
        ("TET10", 10, pv.CellType.QUADRATIC_TETRA),
        ("HEX8", 8, pv.CellType.HEXAHEDRON),
        ("HEX20", 20, pv.CellType.QUADRATIC_HEXAHEDRON),
        ("WEDGE", 6, pv.CellType.WEDGE),
        ("QUADRATIC_WEDGE", 15, pv.CellType.QUADRATIC_WEDGE),
        ("PYRA5", 5, pv.CellType.PYRAMID),
        # surface elements
        ("QUAD4", 4, pv.CellType.QUAD),
        ("QUAD8", 8, pv.CellType.QUADRATIC_QUAD),
        ("TRI3", 3, pv.CellType.TRIANGLE),
        ("QUADRATIC_TRIANGLE", 6, pv.CellType.QUADRATIC_TRIANGLE),
        # line elements
        ("LINE2", 2, pv.CellType.LINE),
        ("LINE3", 3, pv.CellType.QUADRATIC_EDGE),
    )
}


def _norm_label(raw: str, k: int) -> str:
    """Normalize FE labels when mapping to VTK cell types.

    The label is upper-cased, ``ELEM_`` is removed and a leading ``FE_`` is
    stripped. Family names (``HEX``, ``TET``, ``PENTA``, ...) are resolved to
    the canonical label for the given node count. Node-count-suffixed names
    such as ``PENTA6``, ``PENTA15``, ``TRI6`` and ``TRUSS2`` are folded into
    their families as well, so they reach a ``_VTK_MAP`` entry instead of
    falling through unrecognized. Any other label is returned normalized but
    otherwise unchanged (e.g. ``HEX8`` stays ``HEX8``).

    Args:
        raw: Element label as read from the file.
        k: Node count for the element row.

    Returns:
        str: Canonical label used to pick a VTK cell type.
    """
    s = str(raw).upper().replace("ELEM_", "")
    if s.startswith("FE_"):
        s = s[3:]
    if s in {"HEX", "HEXA", "BRICK"}:
        return "HEX20" if k >= 20 else "HEX8"
    if s in {"TET", "TETRA"}:
        return "TET10" if k >= 10 else "TET4"
    if s in {"WEDGE", "PENTA", "PENTA6", "PENTA15"}:
        return "QUADRATIC_WEDGE" if k >= 15 else "WEDGE"
    if s in {"PYRA", "PYRAMID"}:
        return "PYRA5"
    if s in {"QUAD", "QUAD4"}:
        return "QUAD8" if k >= 8 else "QUAD4"
    if s in {"TRI", "TRIANGLE", "TRI6"}:
        return "QUADRATIC_TRIANGLE" if k >= 6 else "TRI3"
    # Truss elements are 2-node line segments.
    if s in {"LINE", "TRUSS", "TRUSS2"}:
        return "LINE3" if k >= 3 else "LINE2"
    return s


def _vtk_cell_type(etype: str, k: int) -> int | None:
    """Resolve the PyVista cell type for one FEBio element.

    The raw label is canonicalized with ``_norm_label`` and then looked up in
    ``_VTK_MAP`` together with the node count.

    Args:
        etype: FEBio element label (e.g., ``HEX8``).
        k: Number of nodes in the element.

    Returns:
        int | None: PyVista cell type id, or ``None`` when the
        (label, node count) pair has no VTK equivalent.
    """
    canonical = _norm_label(etype, k)
    lookup_key = (canonical, k)
    return _VTK_MAP.get(lookup_key)


def _six_to_nine(voigt6: np.ndarray) -> np.ndarray:
    """Expand symmetric tensor components to full row-major 3x3 entries.

    Args:
        voigt6: Array whose last axis holds 6 components in Voigt order
            ``(xx, yy, zz, yz, xz, xy)``; typically shaped (N, 6). Any
            leading shape (including none) is accepted.

    Returns:
        np.ndarray: float32 array with the last axis expanded to the 9
        row-major entries ``(xx, xy, xz, xy, yy, yz, xz, yz, zz)``, i.e.
        shaped (N, 9) for an (N, 6) input.

    Raises:
        ValueError: If the input is a scalar or its last axis is not 6 long.

    NOTE(review): FEBio plot files may write symmetric tensors as
    (xx, yy, zz, xy, yz, xz) rather than Voigt order; confirm the upstream
    reader reorders them before they reach this helper.
    """
    v = np.asarray(voigt6, dtype=np.float32)
    if v.ndim == 0 or v.shape[-1] != 6:
        raise ValueError(f"expected last axis of length 6, got shape {v.shape}")
    # For each row-major 3x3 slot, the Voigt column it is copied from.
    voigt_index = [0, 5, 4, 5, 1, 3, 4, 3, 2]
    return np.ascontiguousarray(v[..., voigt_index])


def _attach_point(ds: pv.DataSet, name: str, arr: np.ndarray) -> None:
    """Store ``arr`` on ``ds`` as the float32 point array called ``name``."""
    ds.point_data[name] = np.asarray(arr, dtype=np.float32)


def _attach_cell(ds: pv.DataSet, name: str, arr: np.ndarray) -> None:
    """Store ``arr`` on ``ds`` as the float32 cell array called ``name``."""
    ds.cell_data[name] = np.asarray(arr, dtype=np.float32)


# -------------------------- public bridge --------------------------


@dataclass
class PVBridge:
    """Build per-domain grids and attach pre-sliced results to PyVista.

    Geometry comes from ``mesh``. The ``add_*`` methods do not slice results:
    callers pass arrays already aligned with the dataset's points or cells.
    Each one stores a single float32 array on the dataset, optionally makes it
    the active scalars/vectors/tensors, and returns the stored name.
    """

    mesh: Mesh

    # ---------- grids ----------

    def domain_grid(self, domain: str) -> Any:
        """Build the unstructured grid for the requested domain.

        An element is skipped if any of the following holds:
        - its node count is not positive;
        - its type has no VTK equivalent;
        - its connectivity row is shorter than its node count;
        - it references a negative node id.
        ``n_cells`` can therefore be smaller than the domain's element count.
        The grid is passed through ``clean(tolerance=0.0)``. Per PyVista, this
        merges coincident points and removes unused ones, so point data must be
        sized to the returned grid.

        Args:
            domain: Domain/part name from the mesh.

        Returns:
            Any: PyVista ``UnstructuredGrid`` for the domain.

        Raises:
            KeyError: If ``domain`` is not a part of the mesh.
            ValueError: If the domain is empty or none of its cells is usable.

        Example:
            ``grid = PVBridge(mesh).domain_grid("artery")``
        """
        m = self.mesh
        if domain not in m.parts:
            raise KeyError(f"unknown domain '{domain}'")
        rows = np.asarray(m.parts[domain], dtype=np.int64)
        if rows.size == 0:
            raise ValueError(f"domain '{domain}' has no elements")

        xyz = np.asarray(m.nodes.xyz, dtype=np.float64, order="C")
        conn = np.asarray(m.elements.conn, dtype=np.int64, order="C")
        nper = np.asarray(m.elements.nper, dtype=np.int64, order="C")
        etype = np.asarray(m.elements.etype, dtype=object)

        cells_list: list[np.ndarray] = []
        ctype_list: list[int] = []
        for e in rows:
            k = int(nper[e])
            if k <= 0:
                continue
            vtk_id = _vtk_cell_type(str(etype[e]), k)
            if vtk_id is None:
                continue
            ids = conn[e, :k]
            # Skip a row shorter than its declared node count (k wider than
            # the connectivity table) or padded with negative ids. Either
            # would corrupt the flat ``[n, id0, id1, ...]`` VTK cell layout.
            if ids.size != k or ids.min() < 0:
                continue
            cells_list.append(np.concatenate(([k], ids)))
            ctype_list.append(vtk_id)

        if not ctype_list:
            raise ValueError(f"domain '{domain}' has no supported cells")

        cells = np.ascontiguousarray(np.concatenate(cells_list), dtype=np.int64)
        ctypes = np.ascontiguousarray(ctype_list, dtype=np.uint8)
        return pv.UnstructuredGrid(cells, ctypes, xyz).clean(tolerance=0.0)

    def surface_mesh(self, surface: str) -> Any:
        """Build the surface mesh for the requested surface.

        A face is skipped if its node count is not positive, its face row is
        shorter than that count, or it references a negative node id. These
        are the same rules ``domain_grid`` applies to elements. The result is
        passed through ``PolyData.clean()``.

        Args:
            surface: Surface name from the mesh.

        Returns:
            Any: PyVista ``PolyData`` for the surface.

        Raises:
            KeyError: If ``surface`` is not defined on the mesh.
            ValueError: If no usable face remains.
        """
        m = self.mesh
        if surface not in m.surfaces:
            raise KeyError(f"unknown surface '{surface}'")
        sa: SurfaceArray = m.surfaces[surface]
        face_rows = np.asarray(sa.faces, dtype=np.int64)
        face_nper = np.asarray(sa.nper, dtype=np.int64)

        faces_list: list[np.ndarray] = []
        for f in range(face_rows.shape[0]):
            k = int(face_nper[f])
            if k <= 0:
                continue
            ids = face_rows[f, :k]
            # Same guard as domain_grid: never emit a malformed polygon record.
            if ids.size != k or ids.min() < 0:
                continue
            faces_list.append(np.concatenate(([k], ids)))
        if not faces_list:
            raise ValueError(f"surface '{surface}' has no faces")

        faces = np.ascontiguousarray(np.concatenate(faces_list), dtype=np.int64)
        pts = np.asarray(m.nodes.xyz, dtype=np.float64, order="C")
        return pv.PolyData(pts, faces).clean()

    # ---------- shared attach helpers ----------

    @staticmethod
    def _check_rows(arr: np.ndarray, expected: int, what: str) -> None:
        """Raise ``ValueError`` unless ``arr`` has exactly ``expected`` rows.

        Args:
            arr: Candidate data array.
            expected: Required length of axis 0 (point or cell count).
            what: Label used in the error message (``"points"``/``"cells"``).
        """
        if arr.ndim == 0:
            raise ValueError(f"expected array data for {expected} {what}, got a scalar")
        if arr.shape[0] != expected:
            raise ValueError(f"{what} {expected} != data {arr.shape[0]}")

    @staticmethod
    def _attach_field(
        ds: pv.DataSet,
        name: str,
        arr: np.ndarray,
        view: Any,
        *,
        on_points: bool,
        set_active: bool,
    ) -> None:
        """Attach ``arr`` under ``name`` with a layout chosen from ``view.meta.dtype``.

        Layout rules (``view.meta.dtype`` is only read for 2-D input):
        - 1-D arrays are stored as scalars.
        - ``VEC3F`` data with 3 columns is stored as vectors.
        - ``MAT3FS`` data with 6 columns is expanded to 9 and stored as tensors.
        - ``MAT3FS``/``MAT3F`` data with 9 columns is stored as tensors.
        - Anything else keeps only column 0, stored as scalars.

        Args:
            ds: Target dataset.
            name: Field name to store.
            arr: float32 data with one row per point/cell.
            view: Result view whose ``meta.dtype`` selects the layout.
            on_points: Store as point data when true, else as cell data.
            set_active: Whether to make the stored field active.
        """
        attach = _attach_point if on_points else _attach_cell
        if arr.ndim == 1:
            payload, kind = arr, "scalars"
        else:
            n_cols = arr.shape[1]
            dtype = view.meta.dtype
            if dtype == FEDataType.VEC3F and n_cols == 3:
                payload, kind = arr, "vectors"
            elif dtype == FEDataType.MAT3FS and n_cols == 6:
                payload, kind = _six_to_nine(arr), "tensors"
            elif dtype in (FEDataType.MAT3FS, FEDataType.MAT3F) and n_cols == 9:
                payload, kind = arr, "tensors"
            else:
                # Unrecognized layouts keep only the first component.
                payload, kind = arr[:, 0], "scalars"
        attach(ds, name, payload)
        if set_active:
            # Dispatches to set_active_scalars / _vectors / _tensors.
            getattr(ds, f"set_active_{kind}")(name)

    # ---------- node data (array provided) ----------

    def add_node_result_array(
        self,
        ds: pv.DataSet,
        *,
        view: NodeResultView,
        data: np.ndarray,
        name: str | None = None,
        set_active: bool = True,
    ) -> str:
        """Attach nodal results to the dataset with appropriate structure.

        1-D data becomes scalars. 2-D data follows ``view.meta.dtype``:
        - ``VEC3F`` with 3 columns becomes vectors;
        - ``MAT3FS`` with 6 columns is expanded to 9 and becomes tensors;
        - 9-column matrices become tensors;
        - anything else keeps column 0 as scalars.

        Args:
            ds: Target PyVista dataset.
            view: Source view that describes the nodal result.
            data: Nodal array (N or N x C) with N == ``ds.n_points``.
            name: Optional field name override (defaults to ``view.meta.name``).
            set_active: Whether to set scalars/vectors/tensors active.

        Returns:
            str: Name of the attached field.

        Raises:
            ValueError: If ``data`` is a scalar or its row count differs from
                ``ds.n_points``.

        Example:
            ``bridge.add_node_result_array(grid, view=U_view, data=U, name="U")``
        """
        a = np.asarray(data, dtype=np.float32)
        self._check_rows(a, ds.n_points, "points")
        nm = name or view.meta.name
        self._attach_field(ds, nm, a, view, on_points=True, set_active=set_active)
        return nm

    # ---------- element FMT_ITEM (array provided) ----------

    def add_elem_item_array(
        self,
        ds: pv.UnstructuredGrid,
        *,
        view: ItemResultView,
        domain: str,
        data: np.ndarray,
        name: str | None = None,
        set_active: bool = True,
    ) -> str:
        """Attach the provided element array to the cell data.

        Uses the same dtype-driven layout rules as ``add_node_result_array``.

        Args:
            ds: Target unstructured grid.
            view: Item view describing the result.
            domain: Domain label (unused; kept for symmetry with callers).
            data: Element result array with one row per cell of ``ds``.
            name: Optional field name override (defaults to ``view.meta.name``).
            set_active: Whether to set the field active.

        Returns:
            str: Name of the attached field.

        Raises:
            ValueError: If ``data`` is a scalar or its row count differs from
                ``ds.n_cells``.
        """
        _ = domain  # accepted for API symmetry; not used yet
        a = np.asarray(data, dtype=np.float32)
        self._check_rows(a, ds.n_cells, "cells")
        nm = name or view.meta.name
        self._attach_field(ds, nm, a, view, on_points=False, set_active=set_active)
        return nm

    # ---------- surface FMT_ITEM (array provided) ----------

    def add_face_item_array(
        self,
        ds: pv.PolyData | pv.UnstructuredGrid,
        *,
        view: ItemResultView,
        surface: str,
        data: np.ndarray,
        name: str | None = None,
        set_active: bool = True,
    ) -> str:
        """Attach the provided face array to the surface or cell data.

        Uses the same dtype-driven layout rules as ``add_node_result_array``.

        Args:
            ds: Target polydata or unstructured grid.
            view: Item view describing the result.
            surface: Surface name (unused; for API symmetry).
            data: Face result array with one row per cell of ``ds``.
            name: Optional field name override (defaults to ``view.meta.name``).
            set_active: Whether to set the field active.

        Returns:
            str: Name of the attached field.

        Raises:
            ValueError: If ``data`` is a scalar or its row count differs from
                ``ds.n_cells``.
        """
        _ = surface  # accepted for API symmetry; not used yet
        a = np.asarray(data, dtype=np.float32)
        self._check_rows(a, ds.n_cells, "cells")
        nm = name or view.meta.name
        self._attach_field(ds, nm, a, view, on_points=False, set_active=set_active)
        return nm

    # ---------- FMT_MULT reduction (array provided) ----------

    @staticmethod
    def _reduce_mult_block(
        blk: np.ndarray, counts: np.ndarray, reducer: str
    ) -> np.ndarray:
        """Reduce an (R, Kmax, C) block over the first ``counts[r]`` slots of each row.

        Rows with no valid slot (``counts[r] <= 0`` or ``Kmax == 0``) are NaN.
        """
        n_rows, kmax, n_comp = blk.shape
        out = np.full((n_rows, n_comp), np.nan, dtype=np.float32)
        # valid[r, k] is True for node slots element r actually has. Slots
        # beyond its node count are padding and must not enter the reduction.
        valid = np.arange(kmax)[None, :] < counts[:, None]
        has_nodes = valid.any(axis=1)
        if not has_nodes.any():
            return out
        if reducer == "first":
            out[has_nodes] = blk[has_nodes, 0, :]
            return out
        reduce_fn = {"mean": np.nanmean, "max": np.nanmax, "min": np.nanmin}[reducer]
        # Padding slots become NaN so the nan-aware reducers ignore them.
        masked = np.where(
            valid[has_nodes][:, :, None], blk[has_nodes], np.float32(np.nan)
        )
        out[has_nodes] = reduce_fn(masked, axis=1)
        return out

    def add_elem_mult_reduced_array(
        self,
        ds: pv.UnstructuredGrid,
        *,
        view: MultResultView,
        domain: str,
        data: np.ndarray,  # shape (R, Kmax) or (R, Kmax, C)
        reducer: str = "mean",
        name: str | None = None,
        set_active: bool = True,
    ) -> str:
        """Reduce node-wise element arrays to one value per cell then attach.

        Row ``r`` of ``data`` belongs to the ``r``-th element of ``domain``.
        Only its first ``nper`` entries (that element's node count) are
        reduced; the remaining entries are padding. Rows whose element has no
        nodes become NaN. The reduced values are attached as cell data using
        the same dtype-driven layout rules as ``add_node_result_array``.

        Args:
            ds: Target unstructured grid.
            view: Mult result view.
            domain: Domain label; its element rows supply each row's node count.
            data: Block array shaped (R, Kmax[, C]).
            reducer: Reduction method. ``mean``, ``max`` and ``min`` ignore
                NaNs; ``first`` takes the first node's value.
            name: Optional field name override (defaults to
                ``f"{view.meta.name}_{reducer}"``).
            set_active: Whether to set the field active.

        Returns:
            str: Name of the attached field.

        Raises:
            ValueError: If any of the following holds:
                - ``data`` is not 2-D or 3-D;
                - its row count differs from ``ds.n_cells`` or from the
                  number of elements in ``domain``;
                - ``reducer`` is unknown.
            KeyError: If ``domain`` is not a part of the mesh.
        """
        blk = np.asarray(data, dtype=np.float32)
        if blk.ndim == 2:
            blk = blk[:, :, None]  # scalar data -> single component
        if blk.ndim != 3:
            raise ValueError(f"expected data shaped (R, Kmax[, C]), got {blk.shape}")
        n_rows, _kmax, n_comp = blk.shape
        if ds.n_cells != n_rows:
            raise ValueError(f"cells {ds.n_cells} != data rows {n_rows}")

        elem_rows = np.asarray(self.mesh.parts[domain], dtype=np.int64)
        # Each data row must line up with one domain element; otherwise node
        # counts would be read from the wrong elements.
        if elem_rows.size != n_rows:
            raise ValueError(
                f"domain '{domain}' has {elem_rows.size} elements != data rows {n_rows}"
            )
        counts = np.asarray(self.mesh.elements.nper, dtype=np.int64)[elem_rows]

        if reducer not in ("mean", "max", "min", "first"):
            raise ValueError("reducer must be one of {'mean','max','min','first'}")
        out = self._reduce_mult_block(blk, counts, reducer)

        nm = name or f"{view.meta.name}_{reducer}"
        # Single-component results are stored as a flat scalar array.
        payload = out.reshape(n_rows) if n_comp == 1 else out
        self._attach_field(ds, nm, payload, view, on_points=False, set_active=set_active)
        return nm

    # ---------- region vectors (array provided) ----------

    def region_series_array(self, *, data: np.ndarray) -> np.ndarray:
        """Normalize region time series data to float32 arrays.

        Args:
            data: Array shaped (T, C...); all axes after the first are flattened.

        Returns:
            np.ndarray: Array shaped (T, -1) in float32.

        Example:
            ``flattened = bridge.region_series_array(data=series_block)``
        """
        a = np.asarray(data, dtype=np.float32)
        return a.reshape(a.shape[0], -1)
