"""Grapheme cluster iteration utilities.

This module provides Unicode-aware grapheme processing primitives used by
measurement, rendering, and bidi handling.
"""

from __future__ import annotations

import re
import unicodedata
from functools import lru_cache
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Iterator

# Public API of this module; everything else is an internal helper.
__all__ = ["iter_graphemes", "split_graphemes", "next_grapheme", "grapheme_width"]

# Matches SGR escape sequences ("ESC [ params m"), e.g. "\x1b[1;31m".
_ANSI_RE = re.compile(r"\033\[[0-9;]*m")

# U+200D ZERO WIDTH JOINER.
_ZWJ = "\u200d"

# Inclusive code point ranges of the variation selectors:
# VS1..VS16 and the supplementary VS17..VS256.
_VARIATION_SELECTOR_RANGES: tuple[tuple[int, int], ...] = (
    (0xFE00, 0xFE0F),
    (0xE0100, 0xE01EF),
)

# The five emoji skin-tone modifiers U+1F3FB..U+1F3FF (range end is exclusive).
_SKIN_TONE_MODIFIERS = set(range(0x1F3FB, 0x1F400))

# Inclusive bounds of REGIONAL INDICATOR SYMBOL LETTER A..Z.
_REGIONAL_INDICATOR_RANGE: tuple[int, int] = (0x1F1E6, 0x1F1FF)

# Inclusive code point ranges that `_is_emoji` treats as emoji. These are whole
# Unicode blocks, so they also include symbols that are not emoji.
# NOTE(review): e.g. Alchemical Symbols and Symbols for Legacy Computing end up
# measured two cells wide via `_is_emoji` — confirm this is intended.
_EMOJI_RANGES: tuple[tuple[int, int], ...] = (
    (0x1F300, 0x1F5FF),  # Miscellaneous Symbols and Pictographs
    (0x1F600, 0x1F64F),  # Emoticons
    (0x1F680, 0x1F6FF),  # Transport and Map Symbols
    (0x1F700, 0x1F77F),  # Alchemical Symbols
    (0x1F780, 0x1F7FF),  # Geometric Shapes Extended
    (0x1F800, 0x1F8FF),  # Supplemental Arrows-C
    (0x1F900, 0x1F9FF),  # Supplemental Symbols and Pictographs
    (0x1FA00, 0x1FAFF),  # Chess Symbols; Symbols and Pictographs Extended-A
    (0x1FB00, 0x1FBFF),  # Symbols for Legacy Computing
)


def _is_variation_selector(codepoint: int) -> bool:
    """Tell whether `codepoint` lies inside any variation-selector range.

    Unlike the sibling predicates, this takes an integer code point rather
    than a one-character string.
    """
    return any(low <= codepoint <= high for low, high in _VARIATION_SELECTOR_RANGES)


def _is_regional_indicator(ch: str) -> bool:
    """Return True if `ch` is a regional indicator symbol (flag letter)."""
    first, last = _REGIONAL_INDICATOR_RANGE
    return first <= ord(ch) <= last


def _is_skin_modifier(ch: str) -> bool:
    """Return True if `ch` is one of the emoji skin-tone modifier characters."""
    code = ord(ch)
    return code in _SKIN_TONE_MODIFIERS


def _is_emoji(ch: str) -> bool:
    """Heuristically decide whether `ch` should be measured as an emoji.

    A character qualifies when its code point falls in one of the
    `_EMOJI_RANGES` blocks or, failing that, when its Unicode name contains
    the word "EMOJI".
    """
    code = ord(ch)
    for low, high in _EMOJI_RANGES:
        if low <= code <= high:
            return True
    name = unicodedata.name(ch, "")
    return "EMOJI" in name


def _is_zero_width_mark(ch: str) -> bool:
    """Return True if `ch` takes no cell of its own and extends a cluster.

    Covered: ZWJ, skin-tone modifiers, variation selectors, every character
    in general categories Mn, Me or Cf, and any character with a non-zero
    canonical combining class.
    """
    if ch == _ZWJ or _is_skin_modifier(ch) or _is_variation_selector(ord(ch)):
        return True
    if unicodedata.category(ch) in ("Mn", "Me", "Cf"):
        return True
    return unicodedata.combining(ch) != 0


def _count_trailing_regional_indicators(cluster: list[str]) -> int:
    """Count the regional indicators forming the tail of `cluster`.

    The scan walks right-to-left and stops at the first character that is
    not a regional indicator.
    """
    boundary = len(cluster)
    while boundary and _is_regional_indicator(cluster[boundary - 1]):
        boundary -= 1
    return len(cluster) - boundary


def _is_grapheme_control(ch: str) -> bool:
    """Return True if `ch` must always form a grapheme cluster on its own.

    These are general categories Cc (C0/C1 controls such as CR, LF and TAB),
    Zl (LINE SEPARATOR) and Zp (PARAGRAPH SEPARATOR), which UAX #29 rules
    GB4/GB5 never join to a neighbour. Format characters (Cf) are left out on
    purpose: this module treats them as zero-width extenders.
    """
    return unicodedata.category(ch) in {"Cc", "Zl", "Zp"}


def next_grapheme(text: str, start: int = 0) -> tuple[str, int]:
    """Return the next grapheme cluster from `text` and the next index.

    Args:
        text: The string to scan.
        start: Index where the cluster begins; expected to be non-negative.

    Returns:
        A ``(cluster, end)`` pair with ``cluster == text[start:end]``. When
        ``start`` is at or beyond the end of ``text`` the result is
        ``("", len(text))``.

    The rules are a simplified subset of UAX #29:

    * CR LF is one cluster; any other control character (Cc, Zl, Zp) is
      always a cluster by itself and never absorbs following characters.
    * Zero-width marks (combining marks, format characters, skin-tone
      modifiers, variation selectors) extend the current cluster.
    * A ZWJ extends the cluster and also pulls in the next character unless
      that character is a control.
    * Regional indicators join in pairs.
    """
    length = len(text)
    if start >= length:
        return "", length

    first = text[start]
    i = start + 1

    # GB3: CR LF stays together as a single cluster.
    if first == "\r" and i < length and text[i] == "\n":
        return "\r\n", i + 1
    # GB4: nothing attaches to a control (LF, lone CR, TAB, ...); otherwise
    # a combining mark after "\n" would hide the line break inside a cluster.
    if _is_grapheme_control(first):
        return first, i

    cluster: list[str] = [first]
    join_with_next = False
    while i < length:
        c = text[i]

        # GB5: a control always starts a new cluster, even right after a ZWJ.
        if _is_grapheme_control(c):
            break

        if join_with_next:
            # The previous character was a ZWJ: take this one as well.
            cluster.append(c)
            i += 1
            join_with_next = False
            continue

        if c == _ZWJ:
            cluster.append(c)
            i += 1
            join_with_next = True
            continue

        # Combining marks, format characters, skin-tone modifiers and
        # variation selectors all extend the current cluster.
        if _is_zero_width_mark(c):
            cluster.append(c)
            i += 1
            continue

        # GB12/GB13: regional indicators pair up, so only join when the
        # cluster currently ends with an odd number of them.
        if (
            _is_regional_indicator(c)
            and _count_trailing_regional_indicators(cluster) % 2 == 1
        ):
            cluster.append(c)
            i += 1
            continue

        break

    return "".join(cluster), i


def iter_graphemes(text: str) -> Iterator[str]:
    """Yield the grapheme clusters of `text` from left to right.

    Clusters are produced lazily, one `next_grapheme` call per cluster.
    """
    position = 0
    end = len(text)
    while position < end:
        grapheme, position = next_grapheme(text, position)
        if grapheme == "":
            # Defensive stop: an empty cluster means nothing was consumed.
            return
        yield grapheme


def split_graphemes(text: str) -> list[str]:
    """Eagerly split `text` into a list of its grapheme clusters."""
    return [*iter_graphemes(text)]


@lru_cache(maxsize=4096)
def grapheme_width(grapheme: str) -> int:
    """Compute terminal cell width for a grapheme cluster.

    Tabs and soft hyphens are handled at the line-level in width calculations;
    this function considers only the intrinsic width of the grapheme itself.

    Args:
        grapheme: A single grapheme cluster (e.g. one from `next_grapheme`).

    Returns:
        0 for an empty cluster, a cluster made only of zero-width marks, or a
        cluster containing a newline, carriage return, tab or soft hyphen
        anywhere; 2 if any remaining character is East Asian Wide/Fullwidth
        or classified as an emoji; otherwise 1.
    """
    # Zero-width marks (ZWJ, modifiers, selectors, Mn/Me/Cf characters) never
    # contribute width on their own.
    base_chars = [ch for ch in grapheme if not _is_zero_width_mark(ch)]
    if not base_chars:
        return 0

    # Line-level characters are measured by the caller. Check the whole
    # cluster before the wide-character scan so the answer does not depend on
    # whether a wide character happens to precede them. U+00AD is category Cf
    # and is normally filtered out above already; it is listed for safety.
    if any(ch in ("\n", "\r", "\t", "\u00AD") for ch in base_chars):
        return 0

    for ch in base_chars:
        if unicodedata.east_asian_width(ch) in {"W", "F"} or _is_emoji(ch):
            return 2
    return 1
