"""Console abstraction and friendly printing helpers for Textforge."""

from __future__ import annotations

import sys
import threading
import time
from collections.abc import Callable, Iterable, Sequence
from contextlib import redirect_stderr, redirect_stdout
from contextvars import ContextVar
from dataclasses import dataclass, field
from io import StringIO
from typing import TYPE_CHECKING, Any, Protocol, TextIO, runtime_checkable

from ..markup import MarkupEngine
from ..style.colors import Color
from ..style.themes import ThemeManager
from ..text_engine import visible_width
from ..utils.logging import get_logger

if TYPE_CHECKING:
    from .events import EventBus
    from .rendering_cli.terminal_app import TerminalApp

# Public API of this module. ``use_console`` is a documented public helper, so
# it is exported alongside ``tfprint`` (which consults the console it installs).
__all__ = [
    "Console",
    "Renderable",
    "Measure",
    "RenderableLike",
    "CallableRenderable",
    "CompositeRenderable",
    "LazyRenderable",
    "render_call",
    "composite",
    "lazy",
    "tfprint",
    "use_console",
]


@dataclass(slots=True)
class Measure:
    """Dimensions for rendered text (width/height in terminal cells)."""

    width: int
    height: int

    @staticmethod
    def from_text(text: str) -> Measure:
        lines = text.splitlines() or [""]
        width = 0
        for line in lines:
            width = max(width, visible_width(line, ignore_ansi=True, ignore_markup=True))
        return Measure(width=width, height=len(lines))


@runtime_checkable
class Renderable(Protocol):
    """Structural type for objects that a :class:`Console` can render.

    Any object exposing both ``render`` and ``measure`` methods satisfies
    ``isinstance(obj, Renderable)`` at runtime.
    """

    def render(self, console: Console) -> str | Iterable[str] | None:
        """Return rendered text: a string, an iterable of string chunks, or None."""

    def measure(self, console: Console) -> Measure | None:
        """Return precomputed dimensions, or None to let the caller render and measure."""


# Type for streams that Console can write to. Console calls ``write`` and
# ``flush`` on it, may read its ``encoding``/``buffer`` attributes when a
# UnicodeEncodeError occurs, and calls ``close`` when it owns the stream.
# The ``| Any`` keeps the alias permissive for duck-typed stream objects.
StreamLike = TextIO | Any


@runtime_checkable
class LiveSessionImpl(Protocol):
    """Structural interface for live-session backends.

    ``Console.LiveSession`` forwards its ``update`` and ``close`` calls to a
    backend object of this shape.
    """

    def update(self, text: str) -> None:
        """Replace the displayed content with ``text``."""

    def close(self) -> None:
        """End the session."""


# Inputs Console._coerce understands: markup strings, Renderable objects, or
# callables invoked with no arguments whose printed output is captured.
# (At runtime ``_coerce`` also tolerates None and falls back to ``str(obj)``.)
RenderableLike = str | Renderable | Callable[..., Any]


@dataclass(slots=True)
class CallableRenderable:
    """
    Lightweight adapter that turns legacy print-based render functions into renderables.
    """

    func: Callable[..., Any]
    args: tuple[Any, ...] = ()
    kwargs: dict[str, Any] | None = None
    _cache_text: str | None = field(default=None, init=False, repr=False)

    def _invoke(self, console: Console) -> str:
        text = console.capture(self.func, *self.args, **(self.kwargs or {}))
        self._cache_text = text
        return text

    def render(self, console: Console) -> str:
        if self._cache_text is not None:
            text = self._cache_text
            self._cache_text = None
            return text
        return self._invoke(console)

    def measure(self, console: Console) -> Measure:
        if self._cache_text is None:
            self._cache_text = self._invoke(console)
        text = self._cache_text
        self._cache_text = None
        return Measure.from_text(text)


@dataclass(slots=True)
class CompositeRenderable:
    """Renderable that composes child renderables or strings."""

    children: Sequence[RenderableLike]
    separator: str = ""
    markup: bool = True

    def render(self, console: Console) -> str:
        parts: list[str] = []
        for child in self.children:
            chunk = console._coerce(child, markup=self.markup)
            if self.separator:
                parts.append(chunk)
                continue
            # Default behavior: insert a single newline between adjacent chunks
            # when the previous chunk does not end with one and the next does
            # not start with one, to preserve visual separation without
            # accumulating blank lines.
            if parts and (not parts[-1].endswith("\n")) and chunk and (not chunk.startswith("\n")):
                parts.append("\n")
            parts.append(chunk)
        if self.separator:
            return self.separator.join(parts)
        return "".join(parts)

    def measure(self, console: Console) -> Measure:
        rendered = self.render(console)
        return Measure.from_text(rendered)


@dataclass(slots=True)
class LazyRenderable:
    """Renderable that defers construction until render/measure time."""

    factory: Callable[[], RenderableLike]
    cache: bool = True
    _resolved: RenderableLike | None = field(default=None, init=False, repr=False)
    _has_cache: bool = field(default=False, init=False, repr=False)

    def _resolve(self) -> RenderableLike:
        if self.cache and self._has_cache:
            assert self._resolved is not None
            return self._resolved
        value = self.factory()
        if self.cache:
            self._resolved = value
            self._has_cache = True
        return value

    def _coerce_value(self, console: Console, value: RenderableLike) -> str:
        if isinstance(value, Renderable):
            return console._coerce(value, markup=True)
        return console._coerce(value, markup=True)

    def render(self, console: Console) -> str:
        value = self._resolve()
        if isinstance(value, Renderable):
            rendered = value.render(console)
            if rendered is None:
                return ""
            if isinstance(rendered, str):
                return rendered
            if isinstance(rendered, Iterable):
                return "".join(rendered)
            return str(rendered)
        return self._coerce_value(console, value)

    def measure(self, console: Console) -> Measure:
        value = self._resolve()
        if isinstance(value, Renderable):
            measure = value.measure(console)
            if measure is not None:
                return measure
            rendered = value.render(console) or ""
            if isinstance(rendered, Iterable) and not isinstance(rendered, str):
                rendered = "".join(rendered)
            return Measure.from_text(str(rendered))
        rendered = self._coerce_value(console, value)
        return Measure.from_text(rendered)


class Console:
    """Coordinates rendering, markup processing, and output."""

    def __init__(
        self,
        stream: StreamLike | None = None,
        *,
        markup: MarkupEngine | None = None,
        theme_overrides: dict[str, str] | None = None,
    ) -> None:
        self.stream = stream or sys.stdout
        self.markup_engine = markup or MarkupEngine()
        self._should_close = False
        self._lock = threading.Lock()
        self._theme_overrides = theme_overrides or None
        # Register a simple [time:FORMAT] custom tag for convenience
        try:
            import datetime as _dt

            def _time_tag(fmt: str) -> str:
                try:
                    return _dt.datetime.now().strftime(fmt)
                except Exception:
                    return _dt.datetime.now().isoformat(timespec="seconds")

            self.markup_engine.register_tag("time", _time_tag)
            # i18n tag: [t:key]
            from ..utils.i18n import t as _t

            def _t_tag(payload: str) -> str:
                return _t(payload)

            self.markup_engine.register_tag("t", _t_tag)
        except Exception as e:
            get_logger().debug(f"Failed to register optional markup tags: {e}")

    class LiveSession:
        """A minimal live rendering session with per-line diff updates.

        Delegates to the CLI renderer implementation in
        ``textforge.core.rendering_cli`` so renderer concerns stay separate
        from Console.
        """

        def __init__(self, stream: StreamLike, *, renderer: str | None = None) -> None:
            self._impl: Any
            # Import lazily to avoid circular imports during module init
            selected = renderer
            if selected is None:
                try:
                    from ..utils.config import get_default_renderer as _get_renderer
                    selected = _get_renderer()
                except Exception as e:
                    get_logger().debug(f"Failed to get default renderer from config: {e}")
                    selected = None
            selected = (selected or "tty").lower()

            if selected in ("gui",):
                try:
                    from .rendering_gui import LiveSession as _GuiLiveSession
                    self._impl = _GuiLiveSession(stream)
                    return
                except Exception as e:
                    get_logger().debug(f"Failed to initialize GUI renderer, falling back: {e}")
            if selected in ("terminal", "term", "custom_terminal"):
                try:
                    from .rendering_cli import TerminalSession as _TerminalSession
                    self._impl = _TerminalSession(stream)
                    return
                except Exception as e:
                    get_logger().debug(f"Failed to initialize terminal session renderer, falling back: {e}")
            from .rendering_cli import LiveSession as _CliLiveSession
            self._impl = _CliLiveSession(stream)

        def update(self, text: str) -> None:
            self._impl.update(text)

        def close(self) -> None:
            self._impl.close()

        # Optional APIs for richer sessions
        def run(self, app: TerminalApp, *, fps: float = 30.0) -> None:
            """Run an application with input handling (terminal sessions only)."""
            if hasattr(self._impl, "run"):
                return self._impl.run(app, fps=fps)
            raise NotImplementedError("run() is only available for terminal sessions")

        @property
        def events(self) -> EventBus | None:
            """Access the event bus for subscribing to input and timer events (terminal sessions only)."""
            if hasattr(self._impl, "events"):
                return self._impl.events
            return None

        def __enter__(self) -> Console.LiveSession:
            return self

        def __exit__(self, exc_type, exc, tb) -> None:
            self.close()

    def live(self, *, renderer: str | None = None) -> Console.LiveSession:
        return Console.LiveSession(self.stream, renderer=renderer)

    class Scheduler:
        """A simple frame scheduler running a callback at target FPS in a thread.

        Exposes a cancellation/join handle and captures exceptions from the
        callback, making them available to callers.
        """

        @dataclass(slots=True)
        class Handle:
            _scheduler: Console.Scheduler

            def cancel(self) -> None:
                self._scheduler.stop()

            def join(self, timeout: float | None = None) -> None:
                self._scheduler._join(timeout)

            @property
            def exception(self) -> BaseException | None:
                return self._scheduler._exception

        def __init__(self, fps: float = 30.0) -> None:
            self.fps = fps
            self._thread: threading.Thread | None = None
            self._running = False
            self._exception: BaseException | None = None

        def start(self, callback: Callable[[float], None]) -> Handle:
            if self._running:
                return Console.Scheduler.Handle(self)
            self._running = True
            self._exception = None

            def _runner() -> None:
                prev = time.perf_counter()
                interval = 1.0 / max(1e-6, self.fps)
                while self._running:
                    now = time.perf_counter()
                    dt = now - prev
                    prev = now
                    try:
                        callback(dt)
                    except BaseException as exc:  # capture, stop, and store
                        self._exception = exc
                        self._running = False
                        break
                    sleep_for = interval - (time.perf_counter() - now)
                    if sleep_for > 0:
                        time.sleep(sleep_for)

            self._thread = threading.Thread(target=_runner, daemon=True)
            self._thread.start()
            return Console.Scheduler.Handle(self)

        def _join(self, timeout: float | None = None) -> None:
            if self._thread and self._thread.is_alive():
                self._thread.join(timeout=timeout)

        def stop(self) -> None:
            self._running = False
            self._join(timeout=0.2)

    @dataclass(slots=True)
    class ThreadHandle:
        thread: threading.Thread
        _exception: BaseException | None = None

        def join(self, timeout: float | None = None) -> None:
            if self.thread.is_alive():
                self.thread.join(timeout=timeout)

        @property
        def exception(self) -> BaseException | None:
            return self._exception

        def cancel(self) -> None:
            # Cooperative cancellation is not supported for print_async
            pass

    def print_async(self, *objects: Any, sep: str = " ", end: str = "\n", style: str | None = None, markup: bool = True) -> ThreadHandle:  # Any: accepts any renderable-like objects
        """Print objects asynchronously on a background thread.

        Args:
            objects: Objects to render and print.
            sep: String to join objects with.
            end: String to append after printing.
            style: Optional style name to apply.
            markup: Whether to parse markup in string objects.

        Returns:
            A ThreadHandle for monitoring the async operation.
        """
        def _target() -> None:
            try:
                self.print(*objects, sep=sep, end=end, style=style, markup=markup)
            except BaseException as exc:  # capture, don't re-raise across thread boundary
                handle._exception = exc

        t = threading.Thread(target=_target, daemon=True)
        handle = Console.ThreadHandle(thread=t)
        t.start()
        return handle

    def capture(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
        """Execute `func` while redirecting stdout and stderr; return captured text."""
        buffer = StringIO()
        with redirect_stdout(buffer), redirect_stderr(buffer):
            func(*args, **kwargs)
        return buffer.getvalue()

    def measure(self, obj: Any, *, markup: bool = True) -> Measure:  # Any: accepts any renderable-like object
        """Compute rendered dimensions for any renderable-compatible object."""
        if isinstance(obj, Renderable):
            measure = obj.measure(self)
            if measure is not None:
                return measure
            rendered = obj.render(self) or ""
            if isinstance(rendered, Iterable) and not isinstance(rendered, str):
                rendered = "".join(rendered)
            return Measure.from_text(str(rendered))
        text = self._coerce(obj, markup=markup)
        return Measure.from_text(text)

    def _coerce(self, obj: Any, *, markup: bool) -> str:  # Any: accepts any renderable-like object
        if obj is None:
            return ""
        if isinstance(obj, str):
            return self.markup_engine.render(obj) if markup else obj

        if isinstance(obj, Renderable):
            rendered = obj.render(self)
            if rendered is None:
                return ""
            if isinstance(rendered, str):
                return rendered
            if isinstance(rendered, Iterable):
                return "".join(rendered)
            return str(rendered)

        if callable(obj):
            return self.capture(obj)

        text = str(obj)
        return self.markup_engine.render(text) if markup else text

    def print(
        self,
        *objects: Any,  # Any: accepts any renderable-like objects
        sep: str = " ",
        end: str = "\n",
        style: str | None = None,
        markup: bool = True,
    ) -> None:
        """Render objects and write them to the configured stream."""
        # Apply per-thread theme overrides while rendering this call
        ctx = ThemeManager.override(self._theme_overrides) if self._theme_overrides else None
        if ctx is None:
            # Fast path when no overrides are provided
            pieces = [self._coerce(obj, markup=markup) for obj in objects]
            text = sep.join(pieces)
            if style:
                style_code = Color.get_color(style)
                if style_code:
                    text = f"{style_code}{text}{Color.RESET}"
            with self._lock:
                try:
                    self.stream.write(text + end)
                    self.stream.flush()
                except UnicodeEncodeError:
                    encoding = getattr(self.stream, "encoding", None) or "utf-8"
                    data = (text + end).encode(encoding, errors="replace")
                    buffer = getattr(self.stream, "buffer", None)
                    if buffer is not None:
                        buffer.write(data)
                        buffer.flush()
                    else:
                        self.stream.write(data.decode(encoding, errors="ignore"))
            return
        # Slow path: with overrides context manager
        with ctx:
            pieces = [self._coerce(obj, markup=markup) for obj in objects]
            text = sep.join(pieces)
            if style:
                style_code = Color.get_color(style)
                if style_code:
                    text = f"{style_code}{text}{Color.RESET}"
            with self._lock:
                try:
                    self.stream.write(text + end)
                    self.stream.flush()
                except UnicodeEncodeError:
                    encoding = getattr(self.stream, "encoding", None) or "utf-8"
                    data = (text + end).encode(encoding, errors="replace")
                    buffer = getattr(self.stream, "buffer", None)
                    if buffer is not None:
                        buffer.write(data)
                        buffer.flush()
                    else:
                        self.stream.write(data.decode(encoding, errors="ignore"))

    def close(self) -> None:
        """Close the underlying stream if this console owns it."""
        if self._should_close and self.stream:
            try:
                if hasattr(self.stream, "close"):
                    self.stream.close()
            finally:
                self._should_close = False

    @classmethod
    def for_stdout(cls) -> Console:
        """Create a console bound to the process stdout stream."""
        return cls(stream=sys.stdout)

    @classmethod
    def for_file(
        cls,
        path: str,
        *,
        mode: str = "w",
        encoding: str = "utf-8",
    ) -> Console:
        """Create a console that writes to a file."""
        stream = open(path, mode, encoding=encoding)
        console = cls(stream=stream)
        console._should_close = True
        return console

    @classmethod
    def for_notebook(cls) -> Console:
        """Create a console that renders output inside Jupyter notebooks."""

        class _NotebookStream:
            encoding = "utf-8"

            def __init__(self) -> None:
                self._buffer: list[str] = []

            def write(self, text: str) -> int:
                self._buffer.append(text)
                return len(text)

            def flush(self) -> None:
                if not self._buffer:
                    return
                from ..utils.jupyter import display_ansi_html, strip_ansi

                combined = "".join(self._buffer)
                try:
                    display_ansi_html(combined)
                except Exception:
                    print(strip_ansi(combined))
                self._buffer.clear()

        return cls(stream=_NotebookStream())

    @classmethod
    def for_backend(cls, name: str = "tty") -> Console:
        """Create a console bound to a named runtime backend (tty/gui)."""
        try:
            from importlib import import_module
            mod = import_module("textforge.core.backends")
            Backends = mod.Backends
            # Support auto-selection via config when name is "auto" or empty
            if not name or name.lower() == "auto":
                def _resolve_default() -> str:
                    try:
                        from ..utils.config import get_default_renderer as _get_renderer
                    except Exception:
                        return "tty"
                    try:
                        value = (_get_renderer() or "").strip().lower()
                    except Exception:
                        value = ""
                    aliases = {
                        "": "tty",
                        "auto": "tty",
                        "tty": "tty",
                        "cli": "tty",
                        "console": "tty",
                        "gui": "gui",
                        "window": "gui",
                        "desktop": "gui",
                        "terminal": "terminal",
                        "term": "terminal",
                        "custom_terminal": "terminal",
                    }
                    return aliases.get(value, value or "tty")

                selected = _resolve_default()
                stream = Backends.create_stream(selected)
            else:
                stream = Backends.create_stream(name.lower())
        except Exception:
            stream = sys.stdout
        return cls(stream=stream)


# Context-based default console for dependency injection.
# ``use_console`` sets it for the duration of a ``with`` block, and ``tfprint``
# consults it when no explicit console is passed. Because it is a ContextVar,
# the value is scoped to the current context rather than shared process-wide.
_CURRENT_CONSOLE: ContextVar[Console | None] = ContextVar("textforge_current_console", default=None)


def tfprint(
    *objects: Any,
    sep: str = " ",
    end: str = "\n",
    style: str | None = None,
    markup: bool = True,
    console: Console | None = None,
) -> Any | None:
    """
    Print like the builtin ``print``, but with Textforge renderable support.

    The target console is, in order of preference: the ``console`` argument,
    the one installed by ``use_console``, or a fresh console bound to the
    current ``sys.stdout``.
    """
    target = console
    if not target:
        target = _CURRENT_CONSOLE.get() or Console.for_stdout()
    target.print(*objects, sep=sep, end=end, style=style, markup=markup)


class _ConsoleUse:
    """Context manager that makes a console the current default for ``tfprint``.

    The same instance may be entered again while already active (nested
    ``with`` blocks). Each ``__exit__`` restores the value that was current at
    the matching ``__enter__``, and an unmatched ``__exit__`` is a no-op.
    """

    def __init__(self, console: Console) -> None:
        self.console = console
        # One ContextVar token per active ``with``, popped LIFO on exit, so
        # nested reuse of this instance never resets with a used token.
        self._tokens: list[Any] = []

    def __enter__(self) -> Console:
        self._tokens.append(_CURRENT_CONSOLE.set(self.console))
        return self.console

    def __exit__(self, exc_type, exc, tb) -> None:
        if self._tokens:
            _CURRENT_CONSOLE.reset(self._tokens.pop())


def use_console(console: Console) -> _ConsoleUse:
    """Context manager that installs *console* as the current default console.

    The default is stored in a ``ContextVar``, so it is scoped to the current
    context rather than the whole process. ``tfprint`` falls back to it when
    no explicit console is passed.
    """
    manager = _ConsoleUse(console)
    return manager


def render_call(func: Callable[..., Any], *args: Any, **kwargs: Any) -> CallableRenderable:  # Any: generic wrapper for legacy print functions
    """Wrap a renderer for use with `tfprint` and exporters.

    Whatever ``func`` prints to stdout is captured. A return value is also
    echoed to stdout so capture sees it. A string is written as-is, an
    iterable is concatenated, and any other non-None value goes through
    ``str``. This lets component renderers migrate to returning strings
    without breaking callers.
    """

    def _emit_result(*call_args: Any, **call_kwargs: Any) -> None:
        outcome = func(*call_args, **call_kwargs)
        if outcome is None:
            return
        if isinstance(outcome, str):
            text = outcome
        elif isinstance(outcome, Iterable):
            text = "".join(outcome)
        else:
            text = str(outcome)
        print(text, end="")

    return CallableRenderable(_emit_result, args=args, kwargs=kwargs)


def composite(children: Sequence[RenderableLike], *, separator: str = "", markup: bool = True) -> CompositeRenderable:
    """Build a :class:`CompositeRenderable` over *children*.

    ``separator`` and ``markup`` are passed straight through to the composite.
    """
    node = CompositeRenderable(children, separator, markup)
    return node


def lazy(factory: Callable[[], RenderableLike], *, cache: bool = True) -> LazyRenderable:
    """Build a :class:`LazyRenderable` that calls *factory* at render/measure time.

    With ``cache`` enabled, the factory's result is computed once and reused.
    """
    deferred = LazyRenderable(factory, cache)
    return deferred
