"""Mesh parsing helpers for XPLT reader."""

from __future__ import annotations

import struct
from dataclasses import dataclass, field
from os import SEEK_SET

import numpy as np

from pyfebiopt.mesh.mesh import ElementArray, Mesh, NodeArray, SurfaceArray

from .binary_reader import BinaryReader
from .enums import Elem_Type, nodesPerElementClass


def _norm_node_ids(ids: np.ndarray, N: int) -> np.ndarray:
    """Normalize node ids to 0-based indices without modifying the input.

    Negative entries are left untouched and excluded from every check.
    The non-negative entries are interpreted with a simple heuristic: if
    their largest value equals ``N`` they are shifted down by one; if it is
    smaller than ``N`` they are returned unshifted.

    Args:
        ids: Node ids; any array-like convertible to ``int64``.
        N: Total number of nodes in the mesh.

    Returns:
        np.ndarray: A new C-contiguous ``int64`` array of normalized ids.

    Raises:
        ValueError: If the largest id exceeds ``N``, or if a non-negative id
            falls outside ``[0, N)`` after the shift (e.g. ids ``0`` and
            ``N`` appearing together).
    """
    # np.array always copies (np.asarray would alias an int64 C-contiguous
    # input), so the in-place shift below can neither mutate the caller's
    # array nor fail on a read-only buffer such as np.frombuffer output.
    a = np.array(ids, dtype=np.int64, order="C")
    if a.size == 0:
        return a
    valid = a >= 0
    if not np.any(valid):
        return a
    vmax = int(a[valid].max())
    if vmax == N:
        # Largest id equal to the node count: shift the non-negative ids.
        a[valid] -= 1
    elif vmax > N:
        raise ValueError(f"node id {vmax} exceeds node count {N}")
    # ``valid`` was computed before the shift, so an id of 0 that was shifted
    # to -1 is caught here.
    checked = a[valid]
    if np.any(checked < 0) or np.any(checked >= N):
        raise ValueError("normalized node ids out of range")
    return a


@dataclass
class MeshParseResult:
    """Container for parsed mesh pieces and lookup maps.

    Only ``mesh`` and ``domain_nodes`` are required; every lookup map
    defaults to a fresh empty dict.
    """

    # Assembled mesh (nodes, elements, parts, surfaces, nodesets).
    mesh: Mesh
    # Part name -> unique node ids used by that part's elements, in first-seen order.
    domain_nodes: dict[str, np.ndarray]
    # Part name -> indices of the elements belonging to that part.
    parts_map: dict[str, list[int]] = field(default_factory=dict)
    # Surface name -> list of per-face node id arrays.
    surfaces_map: dict[str, list[np.ndarray]] = field(default_factory=dict)
    # Nodeset name -> array of node ids.
    nodesets_map: dict[str, np.ndarray] = field(default_factory=dict)
    # Part id (file value minus one) -> name.
    part_id2name: dict[int, str] = field(default_factory=dict)
    # Surface id (file value minus one) -> surface name.
    surf_id2name: dict[int, str] = field(default_factory=dict)
    # Position of a domain in file order (0, 1, ...) -> its name.
    dom_idx2name: dict[int, str] = field(default_factory=dict)
    # Position of a surface in file order (0, 1, ...) -> its name.
    surf_idx2name: dict[int, str] = field(default_factory=dict)


def _read_u32(reader: BinaryReader) -> int:
    """Read one native-endian unsigned 32-bit integer from *reader*.

    Returns:
        int: The decoded value.
    """
    return int(struct.unpack("I", reader.read(4))[0])


def parse_mesh(reader: BinaryReader) -> MeshParseResult:
    """Parse mesh blocks and build Mesh plus lookup maps.

    The blocks are consumed strictly in this order: node section, domain
    section, then the surface, nodeset and parts sections (each of the last
    three only when ``search_block`` reports a positive size). On return the
    reader is positioned at the end of the ``PLT_MESH`` block, or at
    ``reader.filesize`` when that block reports a non-positive size.

    Args:
        reader: Binary reader positioned so that ``PLT_MESH`` can be found.

    Returns:
        MeshParseResult: Parsed mesh and helper maps.

    Raises:
        ValueError: If a ``PLT_ELEMENT`` payload holds fewer values than the
            element type needs, or if node ids fail normalization.
        struct.error: If a fixed-size value or the node coordinate block is
            truncated.
    """
    mesh_size = reader.search_block("PLT_MESH")
    mesh_start = reader.tell()
    mesh_end = mesh_start + mesh_size if mesh_size > 0 else reader.filesize

    conn_list: list[np.ndarray] = []
    etype_list: list[str] = []
    parts_map: dict[str, list[int]] = {}
    surfaces_map: dict[str, list[np.ndarray]] = {}
    nodesets_map: dict[str, np.ndarray] = {}
    part_id2name: dict[int, str] = {}
    surf_id2name: dict[int, str] = {}
    dom_idx2name: dict[int, str] = {}
    surf_idx2name: dict[int, str] = {}

    # nodes
    reader.search_block("PLT_NODE_SECTION")
    reader.search_block("PLT_NODE_HEADER")
    reader.search_block("PLT_NODE_SIZE")
    n_nodes = _read_u32(reader)
    reader.search_block("PLT_NODE_DIM")
    node_dim = _read_u32(reader)
    reader.search_block("PLT_NODE_COORDS")
    # Always expose at least three columns; unused ones stay zero.
    nodes_xyz = np.zeros((n_nodes, max(3, node_dim)), dtype=float)
    if n_nodes > 0:
        # Each record is one 4-byte id followed by node_dim float32 values.
        # Read all records in one call instead of one Python-level unpack per
        # value.
        words_per_node = 1 + node_dim
        expected = 4 * words_per_node * n_nodes
        raw = reader.read(expected)
        if len(raw) != expected:
            # Same exception type a short per-value struct.unpack would raise.
            raise struct.error(
                f"PLT_NODE_COORDS truncated (need {expected} bytes, got {len(raw)})"
            )
        # The leading id word of every record is discarded, so viewing it as
        # float32 is harmless.
        records = np.frombuffer(raw, dtype=np.float32).reshape(n_nodes, words_per_node)
        nodes_xyz[:, :node_dim] = records[:, 1:]

    # domains
    reader.search_block("PLT_DOMAIN_SECTION")
    dom_ord = 0
    while reader.check_block("PLT_DOMAIN"):
        reader.search_block("PLT_DOMAIN")
        reader.search_block("PLT_DOMAIN_HDR")
        reader.search_block("PLT_DOM_ELEM_TYPE")
        etype = Elem_Type(_read_u32(reader)).name
        nne = nodesPerElementClass[etype]

        reader.search_block("PLT_DOM_PART_ID")
        # Treated as 1-based on disk; part_id2name is keyed 0-based.
        part_id = _read_u32(reader) - 1

        reader.search_block("PLT_DOM_ELEMS")
        _read_u32(reader)  # element count: consumed to advance, value unused

        nlen = reader.search_block("PLT_DOM_NAME")
        # Keep only the text after the last NUL byte of the name payload.
        part_name = reader.read(nlen).split(b"\x00")[-1].decode("utf-8", errors="ignore")
        part_id2name[part_id] = part_name
        dom_idx2name[dom_ord] = part_name
        dom_ord += 1

        reader.search_block("PLT_DOM_ELEM_LIST")
        while reader.check_block("PLT_ELEMENT"):
            sec_sz = reader.search_block("PLT_ELEMENT", print_tag=0)
            payload = reader.read(sec_sz)
            vals = np.frombuffer(payload, dtype=np.int32)
            if vals.size < nne:
                raise ValueError(
                    f"PLT_ELEMENT too short for {etype} (need {nne}, got {vals.size})"
                )
            # The connectivity is the trailing nne values; anything before it
            # in the payload is ignored.
            conn_list.append(_norm_node_ids(vals[-nne:], n_nodes))
            etype_list.append(etype)
            parts_map.setdefault(part_name, []).append(len(etype_list) - 1)

    # surfaces
    if reader.search_block("PLT_SURFACE_SECTION") > 0:
        surf_ord = 0
        while reader.check_block("PLT_SURFACE"):
            reader.search_block("PLT_SURFACE")
            reader.search_block("PLT_SURFACE_HDR")
            reader.search_block("PLT_SURFACE_ID")
            surf_id = _read_u32(reader) - 1  # file value minus one
            reader.search_block("PLT_SURFACE_FACES")
            _read_u32(reader)  # face count: consumed to advance, value unused
            # NOTE(review): this is the only block located with seek_block
            # instead of search_block -- confirm that is intentional.
            nlen = reader.seek_block("PLT_SURFACE_NAME")
            sname = reader.read(nlen).decode("utf-8", errors="ignore").split("\x00")[-1]
            surf_id2name[surf_id] = sname
            surf_idx2name[surf_ord] = sname
            surf_ord += 1

            reader.search_block("PLT_SURFACE_MAX_FACET_NODES")
            maxn = _read_u32(reader)

            if reader.check_block("PLT_FACE_LIST"):
                reader.search_block("PLT_FACE_LIST")
                face_list: list[np.ndarray] = []
                while reader.check_block("PLT_FACE"):
                    sec_size = reader.search_block("PLT_FACE")
                    face_start = reader.tell()
                    _read_u32(reader)  # first word of the face record, unused
                    # NOTE(review): the next 4 bytes are skipped uninterpreted
                    # and maxn ids are always read -- confirm whether this
                    # field carries a per-face node count.
                    reader.skip(4)
                    face = np.array(
                        [_read_u32(reader) for _ in range(maxn)], dtype=np.int64
                    )
                    face_list.append(_norm_node_ids(face, n_nodes))
                    # Jump to the end of the record regardless of how much of
                    # it was read.
                    reader.seek(face_start + sec_size, SEEK_SET)
                surfaces_map[sname] = face_list

    # nodesets
    if reader.search_block("PLT_NODESET_SECTION") > 0:
        while reader.check_block("PLT_NODESET"):
            reader.search_block("PLT_NODESET")
            reader.search_block("PLT_NODESET_HDR")
            reader.search_block("PLT_NODESET_ID")
            _read_u32(reader)  # nodeset id: consumed to advance, value unused
            reader.search_block("PLT_NODESET_SIZE")
            nsize = _read_u32(reader)
            nlen = reader.search_block("PLT_NODESET_NAME")
            nname = reader.read(nlen).decode("utf-8", errors="ignore").split("\x00")[-1]
            ids: list[int] = []
            if reader.check_block("PLT_NODESET_LIST"):
                reader.search_block("PLT_NODESET_LIST")
                ids = [_read_u32(reader) for _ in range(nsize)]
            nodesets_map[nname] = _norm_node_ids(np.asarray(ids, dtype=np.int64), n_nodes)

    # optional part renames
    if reader.search_block("PLT_PARTS_SECTION") > 0:
        while reader.check_block("PLT_PART"):
            reader.search_block("PLT_PART")
            reader.search_block("PLT_PART_ID")
            pid = _read_u32(reader) - 1
            nlen = reader.search_block("PLT_PART_NAME")
            # Unlike the names above, the text before the first NUL is kept.
            pname = reader.read(nlen).decode("utf-8", errors="ignore").split("\x00")[0]
            # Overrides any name recorded for this id while reading domains.
            part_id2name[pid] = pname

    # build Mesh
    nodes = NodeArray(nodes_xyz)
    kmax = max((len(c) for c in conn_list), default=0)
    n_elems = len(conn_list)
    # Rows shorter than kmax are padded with -1; nper holds the true lengths.
    conn = -np.ones((n_elems, kmax), dtype=np.int64)
    nper = np.zeros((n_elems,), dtype=np.int64)
    for i, c in enumerate(conn_list):
        conn[i, : c.size] = c
        nper[i] = c.size
    elements = ElementArray(conn=conn, nper=nper, etype=np.asarray(etype_list, dtype=object))
    parts = {name: np.asarray(idx, dtype=np.int64) for name, idx in parts_map.items()}
    surfaces: dict[str, SurfaceArray] = {}
    for surf_name, face_arrays in surfaces_map.items():
        if not face_arrays:
            continue  # surfaces with an empty face list are not added to the mesh
        width = max(len(a) for a in face_arrays)
        faces = -np.ones((len(face_arrays), width), dtype=np.int64)
        nps = np.zeros((len(face_arrays),), dtype=np.int64)
        for i, a in enumerate(face_arrays):
            faces[i, : a.size] = a
            nps[i] = a.size
        surfaces[surf_name] = SurfaceArray(faces=faces, nper=nps)
    nodesets = dict(nodesets_map)
    mesh = Mesh(
        nodes=nodes, elements=elements, parts=parts, surfaces=surfaces, nodesets=nodesets
    )

    # domain nodes (unique per part, first-seen order)
    domain_nodes: dict[str, np.ndarray] = {}
    elem_conn = mesh.elements.conn
    elem_nper = mesh.elements.nper
    for dname, eids in mesh.parts.items():
        ids_flat: list[int] = []
        for e in np.asarray(eids, dtype=np.int64):
            k = int(elem_nper[e])
            if k > 0:
                ids_flat.extend(elem_conn[e, :k].tolist())
        # dict.fromkeys de-duplicates while preserving insertion order.
        domain_nodes[dname] = np.asarray(list(dict.fromkeys(ids_flat)), dtype=np.int64)

    result = MeshParseResult(
        mesh=mesh,
        domain_nodes=domain_nodes,
        parts_map=parts_map,
        surfaces_map=surfaces_map,
        nodesets_map=nodesets_map,
        part_id2name=part_id2name,
        surf_id2name=surf_id2name,
        dom_idx2name=dom_idx2name,
        surf_idx2name=surf_idx2name,
    )

    # advance to end of mesh block to align for subsequent sections
    reader.seek(mesh_end)
    return result


# Public API: the names exported by a star-import of this module.
__all__ = ["MeshParseResult", "parse_mesh"]
