"""Snapshot testing helpers for rendered output."""

from __future__ import annotations

import difflib
import os
import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Iterable

# Environment variable that, when set to "1", "true" or "yes" (case-insensitive),
# makes snapshot_assert rewrite the snapshot file instead of comparing against it.
UPDATE_ENV_VAR: str = "TEXTFORGE_SNAPSHOT_UPDATE"
# Environment variable naming the snapshot base directory; snapshot_path falls
# back to it when no (non-empty) root argument is given.
DIR_ENV_VAR: str = "TEXTFORGE_SNAPSHOT_DIR"


def snapshot_path(name: str, root: str | None = None) -> str:
    """Return the file path for snapshot *name*, creating its parent directory.

    The base directory is *root* if it is non-empty, otherwise the directory
    named by the ``TEXTFORGE_SNAPSHOT_DIR`` environment variable, otherwise
    ``<cwd>/tests/snapshots``. In *name*, colons are replaced with ``_``, the
    result is normalized with :func:`os.path.normpath`, and ``.txt`` is
    appended unless already present. *name* may contain directory separators;
    any intermediate directories are created.

    Args:
        name: Snapshot identifier, relative to the base directory.
        root: Optional explicit base directory.

    Returns:
        The snapshot file path (``base`` joined with the normalized name).

    Raises:
        ValueError: If *name* is empty, or if it resolves to a location outside
            the base directory (e.g. ``"../x"`` or an absolute path).
    """
    if not name:
        raise ValueError("snapshot name must not be empty")
    base = root or os.getenv(DIR_ENV_VAR) or os.path.join(os.getcwd(), "tests", "snapshots")
    rel = os.path.normpath(name.replace(":", "_"))
    if not rel.endswith(".txt"):
        rel = f"{rel}.txt"
    full = os.path.join(base, rel)
    # Refuse names that escape the snapshot root ("../x", "/etc/x"): os.path.join
    # silently drops *base* when *rel* is absolute, and ".." climbs out of it.
    base_abs = os.path.abspath(base)
    try:
        contained = os.path.commonpath([base_abs, os.path.abspath(full)]) == base_abs
    except ValueError:  # e.g. paths on different Windows drives
        contained = False
    if not contained:
        raise ValueError(f"snapshot name {name!r} resolves outside {base!r}")
    os.makedirs(os.path.dirname(full), exist_ok=True)
    return full


# Matches ANSI SGR escape sequences ("ESC [ params m", e.g. "\x1b[1;31m").
# Other CSI sequences such as cursor movement or erase-line are not matched.
_ANSI_RE: re.Pattern[str] = re.compile(r"\x1b\[[0-9;]*m")


def _normalize(value: str, *, strip_ansi: bool = False) -> str:
    r"""Return *value* in canonical snapshot form.

    - ``\r\n`` and lone ``\r`` are converted to ``\n``;
    - ANSI SGR sequences are removed when *strip_ansi* is true;
    - trailing spaces and tabs are stripped from every line;
    - trailing blank lines are dropped and exactly one ``\n`` is appended.

    Leading blank lines and leading indentation are preserved; an empty
    input yields ``"\n"``.
    """
    unified = value.replace("\r\n", "\n")
    unified = unified.replace("\r", "\n")
    if strip_ansi:
        unified = _ANSI_RE.sub("", unified)
    trimmed_rows = (row.rstrip(" \t") for row in unified.split("\n"))
    # Rows contain no "\n" after the split, so trailing newlines in the joined
    # text correspond exactly to trailing empty rows.
    return "\n".join(trimmed_rows).rstrip("\n") + "\n"


# Lines containing any of these substrings are dropped from both the rendered
# value and the stored snapshot before comparison, unless the caller supplies
# its own ``ignore_substrings``. NOTE(review): these look project-specific
# (temp files, bytecode caches) — confirm they belong in a generic default.
_DEFAULT_IGNORED_SUBSTRINGS: tuple[str, ...] = ("combined.txt", "__pycache__", "_tmp_")


def _canonicalize(text: str, ignore: tuple[str, ...], *, strip_ansi: bool) -> str:
    """Normalize *text* and drop every line containing an ignored substring."""
    normalized = _normalize(text, strip_ansi=strip_ansi)
    if not ignore:
        return normalized
    # Split on "\n" only: str.splitlines() also breaks on form feeds, U+2028
    # and similar characters, which would silently rewrite them as newlines.
    kept = [
        line for line in normalized.split("\n")
        if not any(marker in line for marker in ignore)
    ]
    # Normalize again so blank lines left at the end by the filter are trimmed;
    # otherwise a freshly written snapshot would not match itself next run.
    return _normalize("\n".join(kept))


def snapshot_assert(
    name: str,
    value: str,
    *,
    update: bool = False,
    root: str | None = None,
    strip_ansi: bool = False,
    ignore_substrings: Iterable[str] | None = None,
) -> None:
    """Compare *value* with the stored snapshot *name*, or record it.

    Both *value* and the stored snapshot are canonicalized the same way before
    comparison. Line endings are unified to ``\\n``, and trailing spaces/tabs
    and trailing blank lines are removed. ANSI SGR sequences are stripped when
    *strip_ansi* is true. Lines containing any of *ignore_substrings* are
    dropped.

    The canonical value is written to the snapshot file, and the call returns
    without comparing, when *update* is true, when the ``TEXTFORGE_SNAPSHOT_UPDATE``
    environment variable is "1"/"true"/"yes" (case-insensitive, surrounding
    whitespace ignored), or when the snapshot file does not exist yet.

    Args:
        name: Snapshot identifier, resolved via ``snapshot_path``.
        value: Rendered output to check.
        update: Force (re)writing the snapshot instead of comparing.
        root: Optional snapshot base directory, passed to ``snapshot_path``.
        strip_ansi: Remove ANSI SGR escape sequences before comparing.
        ignore_substrings: Lines containing any of these substrings are
            ignored. ``None`` (default) uses ``_DEFAULT_IGNORED_SUBSTRINGS``;
            pass an empty iterable to compare every line.

    Raises:
        AssertionError: If the stored snapshot differs from *value*; the
            message contains a unified diff from "expected" to "actual".
    """
    path = snapshot_path(name, root)
    env_flag = os.getenv(UPDATE_ENV_VAR, "").strip().lower()
    should_update = update or env_flag in {"1", "true", "yes"}
    ignore = (
        _DEFAULT_IGNORED_SUBSTRINGS if ignore_substrings is None else tuple(ignore_substrings)
    )
    actual = _canonicalize(value, ignore, strip_ansi=strip_ansi)
    if should_update or not os.path.exists(path):
        with open(path, "w", encoding="utf-8") as f:
            f.write(actual)
        return
    with open(path, encoding="utf-8") as f:
        expected = _canonicalize(f.read(), ignore, strip_ansi=strip_ansi)
    if expected == actual:
        return
    # Both strings end in exactly one "\n"; drop it so the diff does not show
    # a spurious empty final line.
    diff = "\n".join(
        difflib.unified_diff(
            expected[:-1].split("\n"),
            actual[:-1].split("\n"),
            fromfile="expected",
            tofile="actual",
            lineterm="",
        )
    )
    raise AssertionError(f"Snapshot mismatch for {name}:\n{diff}")


def snapshot_lines(
    name: str,
    lines: Iterable[str],
    *,
    update: bool = False,
    root: str | None = None,
    strip_ansi: bool = False,
) -> None:
    """Assert that *lines*, joined with ``"\\n"``, match snapshot *name*.

    Convenience wrapper around ``snapshot_assert``: the elements of *lines*
    are joined with newline separators and every keyword option is forwarded
    unchanged.
    """
    rendered = "\n".join(lines)
    snapshot_assert(
        name,
        rendered,
        root=root,
        update=update,
        strip_ansi=strip_ansi,
    )
