from __future__ import annotations

import threading
from typing import TextIO

from ..render_tree import diff_changes


class LiveSession:
    """A minimal live rendering session with per-line diff updates for TTY.

    This implementation uses ANSI escape sequences to move the cursor and
    rewrite only changed lines, clearing trailing lines when frames shrink.

    Between updates the cursor rests at the end of the last rendered line;
    each update first returns to the start of the frame's first line. Each
    frame is emitted with a single write/flush to limit flicker.

    NOTE(review): cursor arithmetic counts logical lines; a line wider than
    the terminal wraps onto extra rows and will desynchronize the frame.
    """

    def __init__(self, stream: TextIO) -> None:
        self.stream = stream
        # Lines of the most recently rendered frame; empty before the first
        # update() and after close().
        self._last_lines: list[str] = []
        # Serializes update()/close() so escape sequences from concurrent
        # callers cannot interleave.
        self._lock = threading.Lock()

    def _write(self, s: str) -> None:
        """Write *s* to the stream and flush, degrading unencodable text.

        If the stream's encoding cannot represent *s*, the text is encoded
        with ``errors="replace"`` and written to the stream's underlying
        binary ``buffer`` when it has one, otherwise back through the text
        layer.
        """
        try:
            self.stream.write(s)
            self.stream.flush()
        except UnicodeEncodeError:
            encoding = getattr(self.stream, "encoding", None) or "utf-8"
            data = s.encode(encoding, errors="replace")
            buffer = getattr(self.stream, "buffer", None)
            if buffer is not None:
                buffer.write(data)
                buffer.flush()
            else:
                self.stream.write(data.decode(encoding, errors="ignore"))
                self.stream.flush()

    def update(self, text: str) -> None:
        """Render *text* as the new frame, rewriting only lines that changed.

        *text* is split on ``"\\n"``. Rows that belonged to a taller previous
        frame are cleared. Afterwards the cursor is parked at the end of the
        new frame's last line.
        """
        with self._lock:
            new_lines = text.split("\n")
            old_len = len(self._last_lines)
            new_len = len(new_lines)  # always >= 1: str.split never returns []
            max_lines = max(old_len, new_len)
            # Compute the diff before emitting anything so a failure here
            # cannot leave the cursor half-moved.
            changed = {c.index for c in diff_changes(self._last_lines, new_lines)}

            parts: list[str] = []
            if old_len:
                # The cursor rests on the previous frame's last line, so the
                # first line is old_len - 1 rows up. CSI 0 A/F is treated as
                # 1 by terminals, hence the explicit guard.
                parts.append("\r")
                if old_len > 1:
                    parts.append(f"\x1b[{old_len - 1}A")

            for i in range(max_lines):
                if i >= new_len:
                    # Leftover row from a taller previous frame: blank it.
                    parts.append("\x1b[2K\n")
                elif i >= old_len or i in changed:
                    # Rows beyond the old frame were never drawn, so they are
                    # always written; "\n" (unlike CNL) scrolls at the bottom.
                    parts.append(f"\x1b[2K{new_lines[i]}\n")
                else:
                    # Unchanged row already on screen: just step over it.
                    parts.append("\x1b[1E")

            # The loop leaves the cursor at the start of row max_lines; move
            # back up to the new frame's last row and park at its end.
            parts.append(f"\x1b[{max_lines - new_len + 1}F")
            parts.append("\x1b[999C")

            self._write("".join(parts))
            self._last_lines = new_lines

    def close(self) -> None:
        """Finish the session, moving to a fresh line below the frame.

        Resets SGR attributes (``\\x1b[0m``) before the newline. Calling it
        again writes nothing until another update() renders a frame.
        """
        with self._lock:
            if self._last_lines:
                self._write("\x1b[0m\n")
            self._last_lines = []
