"""Layout-oriented components (columns, flex layouts, trees, grids, and selectable/scrollable lists)."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

from .. import text_engine
from ..core import Console, render_call
from ..layout import LayoutNode, LayoutStyle, compute_layout
from ..layout.utils import align_text
from ..style.colors import Color
from ..symbols import Symbols
from ..utils import wrap_text
from ..utils.input import read_key

if TYPE_CHECKING:
    from collections.abc import Callable

# Public API: the renderer classes, followed by lowercase convenience
# functions that each pass the matching class's ``render`` to ``render_call``.
__all__ = [
    # Renderer classes
    "Columns",
    "Flex",
    "Tree",
    "Grid",
    "SelectableList",
    "ScrollableList",
    # Convenience functions
    "columns",
    "flex",
    "tree",
    "grid",
    "selectable_list",
    "scrollable_list",
]


def _append_segment(buffer: list[tuple[int, str]], x: int, text: str) -> None:
    buffer.append((x, text))


def _render_line(segments: list[tuple[int, str]]) -> str:
    if not segments:
        return ""
    segments.sort(key=lambda item: item[0])
    parts: list[str] = []
    current = 0
    for x_pos, text in segments:
        if x_pos > current:
            parts.append(" " * (x_pos - current))
            current = x_pos
        parts.append(text)
        current += text_engine.visible_width(text)
    return "".join(parts)


class Columns:
    """Build multi-column layouts."""

    @staticmethod
    def render(
        columns: list[list[str]],
        widths: list[int] | None = None,
        alignments: list[str] | None = None,
        gap: int = 2,
        blank_lines_before: int = 0,
        blank_lines_after: int = 0
    ) -> str:
        """
        Render text in multiple columns.

        Args:
            columns: List of columns (each column is a list of lines)
            widths: Width for each column (auto if None)
            alignments: Alignment for each column
            gap: Space between columns
            blank_lines_before: Blank lines before columns
            blank_lines_after: Blank lines after columns
        """
        out: list[str] = []
        for _ in range(blank_lines_before):
            out.append("")

        num_cols = len(columns)
        if num_cols == 0:
            return ""

        console = Console()

        # Auto-calculate widths (respect inline markup)
        measured_widths: list[int] = []
        processed_columns: list[list[str]] = []
        for col in columns:
            processed_lines = [console.markup_engine.render(line) for line in col]
            processed_columns.append(processed_lines)
            measured_widths.append(
                max((text_engine.visible_width(line) for line in processed_lines), default=0)
            )

        if widths is None:
            widths = measured_widths
        else:
            widths = [max(0, w) for w in widths]

        if alignments is None:
            alignments_seq: list[str] = ["left"] * num_cols
        else:
            alignments_seq = alignments

        root = LayoutNode(
            style=LayoutStyle(
                direction="row",
                gap=gap,
                justify="start",
                align="start",
            )
        )

        column_nodes: list[LayoutNode] = []
        for col_idx, lines in enumerate(processed_columns):
            column_height = len(lines)
            column_width = widths[col_idx] if col_idx < len(widths) else measured_widths[col_idx]

            def _make_measure(text_lines: list[str], width_hint: int) -> Callable[[int | None, int | None], tuple[int, int]]:
                def _measure(_: int | None, __: int | None) -> tuple[int, int]:
                    return width_hint, max(1, len(text_lines))

                return _measure

            node = LayoutNode(
                style=LayoutStyle(
                    width=column_width,
                    height=column_height,
                    direction="column",
                    gap=0,
                ),
                measure=_make_measure(lines, column_width),
            )
            column_nodes.append(node)
            root.add(node)

        container = compute_layout(root)
        total_height = max(container.height, max((len(col) for col in processed_columns), default=0))
        line_buffers: list[list[tuple[int, str]]] = [[] for _ in range(total_height)]

        for idx, (node, lines, alignment, width_value) in enumerate(
            zip(column_nodes, processed_columns, alignments_seq, widths, strict=True)
        ):
            x = node.layout.x
            y = node.layout.y
            width_value = max(width_value, measured_widths[idx])

            for row_offset in range(node.layout.height):
                target_row = y + row_offset
                if target_row >= len(line_buffers):
                    break
                if row_offset < len(lines):
                    line_text = lines[row_offset]
                else:
                    line_text = ""
                line_segment = align_text(line_text, width_value, alignment)
                _append_segment(line_buffers[target_row], x, line_segment)

        for line in line_buffers:
            out.append(_render_line(line))

        for _ in range(blank_lines_after):
            out.append("")
        return "\n".join(out)


class Flex:
    """Flexbox-inspired layout renderer for arbitrary text blocks."""

    @staticmethod
    def render(
        items: list[Any],
        *,
        direction: Literal["row", "column"] = "row",
        gap: int = 1,
        justify: Literal["start", "center", "end", "space-between", "space-around", "space-evenly"] = "start",
        align: Literal["start", "center", "end", "stretch"] = "start",
        width: int | None = None,
        height: int | None = None,
        wrap: bool = False,
        blank_lines_before: int = 0,
        blank_lines_after: int = 0,
    ) -> str:
        out: list[str] = []
        for _ in range(blank_lines_before):
            out.append("")

        if not items:
            return ""

        root_style = LayoutStyle(
            direction=direction,
            gap=gap,
            justify=justify,
            align=align,
            wrap=wrap,
        )
        if width is not None:
            root_style.width = width
        if height is not None:
            root_style.height = height

        root = LayoutNode(style=root_style)
        console = Console()

        child_data: list[tuple[LayoutNode, list[str], int, Literal["left", "center", "right", "start", "end"] | None]] = []

        for item in items:
            node_style = LayoutStyle()
            child_align: Literal["left", "center", "right", "start", "end"] | None = None

            if isinstance(item, dict):
                content = item.get("text", "")
                child_align = item.get("align")
                node_style.flex_grow = float(item.get("flex_grow", 0.0))
                node_style.flex_shrink = float(item.get("flex_shrink", 1.0))
                node_style.flex_basis = item.get("flex_basis")
                node_style.width = item.get("width")
                node_style.height = item.get("height")
                node_style.margin = item.get("margin", 0)
                node_style.padding = item.get("padding", 0)
            else:
                content = item

            if isinstance(content, str):
                raw_text = content
            else:
                raw_text = console._coerce(content, markup=True)

            width_hint = node_style.width
            lines_raw = raw_text.split("\n")
            if width_hint:
                wrapped: list[str] = []
                for raw_line in lines_raw:
                    wrapped.extend(wrap_text(raw_line, width_hint))
                lines_raw = wrapped or [""]

            lines = [console.markup_engine.render(line) for line in lines_raw]
            intrinsic_width = max((text_engine.visible_width(line) for line in lines), default=0)
            intrinsic_height = max(1, len(lines))

            width_est = width_hint if width_hint is not None else intrinsic_width
            height_est = node_style.height if node_style.height is not None else intrinsic_height

            def _measure_factory(
                lines_ref: list[str],
                width_value: int,
                height_value: int,
            ) -> Callable[[int | None, int | None], tuple[int, int]]:
                def _measure(_: int | None, __: int | None) -> tuple[int, int]:
                    return width_value, height_value

                return _measure

            node = LayoutNode(
                style=node_style,
                measure=_measure_factory(lines, width_est, height_est),
            )
            root.add(node)
            child_data.append((node, lines, width_est, child_align))

        container = compute_layout(root)
        total_height = max(1, container.height)
        line_buffers: list[list[tuple[int, str]]] = [[] for _ in range(total_height)]

        for node, lines, width_est, child_align in child_data:
            width_actual = max(width_est, node.layout.width)
            height_actual = max(1, node.layout.height)

            for row_offset in range(height_actual):
                target_row = node.layout.y + row_offset
                while target_row >= len(line_buffers):
                    line_buffers.append([])

                if row_offset < len(lines):
                    line_text = lines[row_offset]
                else:
                    line_text = ""

                align_mode = child_align or "start"
                segment = align_text(line_text, width_actual, align_mode)

                _append_segment(line_buffers[target_row], node.layout.x, segment)

        for line in line_buffers:
            out.append(_render_line(line))

        for _ in range(blank_lines_after):
            out.append("")
        return "\n".join(out)


def columns(*args: object, **kwargs: object):
    """Return a renderable column layout.

    All arguments are forwarded unchanged to ``Columns.render`` via ``render_call``.
    """
    renderer = Columns.render
    return render_call(renderer, *args, **kwargs)


def flex(*args: object, **kwargs: object):
    """Return a renderable flex layout.

    All arguments are forwarded unchanged to ``Flex.render`` via ``render_call``.
    """
    renderer = Flex.render
    return render_call(renderer, *args, **kwargs)


def tree(*args: object, **kwargs: object):
    """Return a renderable tree component.

    All arguments are forwarded unchanged to ``Tree.render`` via ``render_call``.
    """
    renderer = Tree.render
    return render_call(renderer, *args, **kwargs)


def grid(*args: object, **kwargs: object):
    """Return a renderable grid component.

    All arguments are forwarded unchanged to ``Grid.render`` via ``render_call``.
    """
    renderer = Grid.render
    return render_call(renderer, *args, **kwargs)


def selectable_list(*args: object, **kwargs: object):
    """Return a renderable selectable list.

    All arguments are forwarded unchanged to ``SelectableList.render`` via ``render_call``.
    """
    renderer = SelectableList.render
    return render_call(renderer, *args, **kwargs)


def scrollable_list(*args: object, **kwargs: object):
    """Return a renderable scrollable list window.

    All arguments are forwarded unchanged to ``ScrollableList.render`` via ``render_call``.
    """
    renderer = ScrollableList.render
    return render_call(renderer, *args, **kwargs)


class Tree:
    """Build tree structures for hierarchical data."""

    @staticmethod
    def render(
        data: dict[str, Any],
        indent: int = 0,
        is_last: bool = True,
        prefix: str = "",
        blank_lines_before: int = 0,
        blank_lines_after: int = 0,
        *,
        expanded_paths: set[tuple[str, ...]] | None = None,
        highlight_path: tuple[str, ...] | None = None,
        _path: tuple[str, ...] = (),
    ) -> None:
        """Render a tree structure with optional expansion and highlighting."""

        if indent == 0:
            for _ in range(blank_lines_before):
                print()

        items = list(data.items())
        for index, (key, value) in enumerate(items):
            is_last_item = index == len(items) - 1
            branch = ""
            if indent > 0:
                branch = "└── " if is_last_item else "├── "

            current_path = _path + (str(key),)
            has_children = isinstance(value, dict) and bool(value)

            label = str(key)
            if has_children:
                marker = " [-]" if (expanded_paths is None or current_path in expanded_paths) else " [+]"
                label += Color.DIM + marker + Color.RESET

            if highlight_path is not None and current_path == highlight_path:
                label = Color.get_color("cyan") + Color.BOLD + label + Color.RESET

            print(prefix + branch + label)

            if has_children:
                should_expand = expanded_paths is None or current_path in expanded_paths
                if should_expand:
                    next_prefix = prefix + ("    " if is_last_item else "│   ")
                    Tree.render(
                        value,
                        indent + 1,
                        is_last_item,
                        next_prefix,
                        0,
                        0,
                        expanded_paths=expanded_paths,
                        highlight_path=highlight_path,
                        _path=current_path,
                    )

        if indent == 0:
            for _ in range(blank_lines_after):
                print()

    @staticmethod
    def interactive(
        data: dict[str, Any],
        *,
        title: str | None = None,
        expanded_paths: set[tuple[str, ...]] | None = None,
    ) -> tuple[str, ...] | None:
        """Run an interactive tree viewer returning the selected path."""

        def _has_children(node: Any) -> bool:
            return isinstance(node, dict) and bool(node)

        def _flatten(
            node: dict[str, Any],
            path: tuple[str, ...] = (),
            depth: int = 0,
            collector: list[dict[str, Any]] | None = None,
        ) -> list[dict[str, Any]]:
            if collector is None:
                collector = []
            for key, value in node.items():
                child_path = path + (str(key),)
                has_children = _has_children(value)
                collector.append(
                    {
                        "depth": depth,
                        "path": child_path,
                        "value": value,
                        "has_children": has_children,
                    }
                )
                if has_children and (exp_set is None or child_path in exp_set):
                    _flatten(value, child_path, depth + 1, collector)
            return collector

        exp_set = set(expanded_paths) if expanded_paths else set()
        if not exp_set:
            for key, value in data.items():
                if _has_children(value):
                    exp_set.add((str(key),))

        visible = _flatten(data)
        selected = 0

        try:
            while True:
                print("\x1b[2J\x1b[H", end="")
                if title:
                    print(Color.BOLD + title + Color.RESET)
                highlight = visible[selected]["path"] if visible else None
                Tree.render(
                    data,
                    blank_lines_before=0,
                    blank_lines_after=0,
                    expanded_paths=exp_set if exp_set else set(),
                    highlight_path=highlight,
                )
                if not visible:
                    print("(empty)")
                else:
                    print("Use ?/? (j/k) to move, ? to expand, ? to collapse, Enter to select, q to cancel.")

                key = read_key()
                if not key:
                    continue
                if key in {"up", "k"} and visible:
                    selected = (selected - 1) % len(visible)
                elif key in {"down", "j"} and visible:
                    selected = (selected + 1) % len(visible)
                elif key in {"right", "l"} and visible:
                    node = visible[selected]
                    if node["has_children"]:
                        exp_set.add(node["path"])
                elif key in {"left", "h"} and visible:
                    node = visible[selected]
                    if node["path"] in exp_set:
                        exp_set.remove(node["path"])
                    elif node["path"][:-1]:
                        parent = node["path"][:-1]
                        if parent in exp_set:
                            exp_set.remove(parent)
                        for idx, item in enumerate(visible):
                            if item["path"] == parent:
                                selected = idx
                                break
                elif key in {"enter", "space"} and visible:
                    return visible[selected]["path"]
                elif key in {"escape", "q"}:
                    return None

                visible = _flatten(data)
                if visible:
                    selected = max(0, min(selected, len(visible) - 1))
                else:
                    selected = 0
        finally:
            print("\x1b[2J\x1b[H", end="")


class Grid:
    """Build grid layouts."""

    @staticmethod
    def render(
        cells: list[list[str]],
        cell_width: int = 20,
        cell_height: int = 3,
        border_style: str = "box",
        show_grid_lines: bool = True,
        blank_lines_before: int = 0,
        blank_lines_after: int = 0,
        ascii_mode: bool | None = None
    ) -> None:
        """
        Print a bordered grid of cells.

        Each cell is converted with ``str`` and wrapped to ``cell_width - 2``
        columns. Each wrapped line has inline markup applied and is centred in
        the cell. Wrapped lines beyond ``cell_height`` are dropped.

        Rows may have different lengths. The grid is as wide as the longest
        row, and missing cells are drawn blank so the borders always line up.

        Args:
            cells: 2D list of cell contents
            cell_width: Width of each cell
            cell_height: Height of each cell
            border_style: Border style
            show_grid_lines: Show internal grid lines
            blank_lines_before: Blank lines before grid
            blank_lines_after: Blank lines after grid
            ascii_mode: Force ASCII symbols (True), Unicode (False), or auto-detect (None)
        """
        for _ in range(blank_lines_before):
            print()

        symbols = Symbols.get_symbols(border_style, ascii_mode)
        vertical = symbols["vertical"]
        inner_width = cell_width - 2
        # Size by the longest row. The old code consulted only the first row,
        # so ragged rows produced misaligned borders.
        cols = max((len(row) for row in cells), default=0)
        segments = [symbols["horizontal"] * cell_width] * cols

        def _rule(left: str, joint: str, right: str) -> str:
            """Return a horizontal border line with *joint* between cells."""
            return left + joint.join(segments) + right

        def _cell_row(lines: list[str], h: int) -> str:
            """Return text row *h* of a cell, centred and padded to ``cell_width``."""
            if h >= len(lines):
                return " " * cell_width
            colored = Color.apply_inline_markup(lines[h])
            visible_len = text_engine.visible_width(colored)
            left_pad = (inner_width - visible_len) // 2
            right_pad = inner_width - visible_len - left_pad
            return " " + " " * left_pad + colored + " " * right_pad + " "

        # Joint symbols are only looked up when grid lines are drawn.
        top_joint = symbols["t_down"] if show_grid_lines else ""
        print(_rule(symbols["top_left"], top_joint, symbols["top_right"]))

        cell_separator = vertical if show_grid_lines else ""
        for i, row in enumerate(cells):
            # Wrap each cell once per row instead of once per text row.
            wrapped = [wrap_text(str(cell), inner_width) for cell in row]
            wrapped.extend([] for _ in range(cols - len(row)))
            for h in range(cell_height):
                print(vertical + cell_separator.join(_cell_row(lines, h) for lines in wrapped) + vertical)

            if i < len(cells) - 1 and show_grid_lines:
                print(_rule(symbols["t_right"], symbols["cross"], symbols["t_left"]))

        bottom_joint = symbols["t_up"] if show_grid_lines else ""
        print(_rule(symbols["bottom_left"], bottom_joint, symbols["bottom_right"]))

        for _ in range(blank_lines_after):
            print()


class SelectableList:
    """Print a list of items with one entry marked and highlighted."""

    @staticmethod
    def render(
        items: list[str],
        selected_index: int = 0,
        marker: str = "►",
        selected_color: str = "accent",
        width: int = 40,
    ) -> None:
        """
        Print *items* one per line, marking the entry at *selected_index*.

        The selected line reads ``"<marker> <item>"``. Other lines replace the
        marker with ``len(marker)`` spaces. Every line is truncated to *width*
        characters (negative widths are treated as 0). The selected line is
        then wrapped in ``Color.get_color(selected_color)`` plus bold.

        Truncation happens BEFORE the colour codes are added. That way the codes
        never count toward *width* and the trailing reset is never cut off.
        Previously a cut reset could leak the highlight into later output.

        Args:
            items: Lines to display.
            selected_index: Index of the highlighted entry.
            marker: Prefix shown before the selected entry.
            selected_color: Colour name passed to ``Color.get_color``.
            width: Maximum number of visible characters per line.
        """
        highlight = Color.get_color(selected_color)
        blank_marker = " " * len(marker)
        limit = max(0, width)
        for index, item in enumerate(items):
            is_selected = index == selected_index
            text = f"{marker if is_selected else blank_marker} {item}"[:limit]
            if is_selected:
                text = highlight + Color.BOLD + text + Color.RESET
            print(text)


class ScrollableList:
    """Print a fixed-height window onto a list of lines."""

    @staticmethod
    def render(
        items: list[str],
        start: int = 0,
        height: int = 5,
        width: int = 40,
    ) -> None:
        """
        Print up to *height* items starting at index *start*, one per line.

        Each line is truncated to *width* characters. Negative *start*,
        *height* and *width* are clamped to 0. Python's negative slice
        semantics would otherwise wrap around from the end, e.g. ``height=-1``
        used to print every item but the last.

        Args:
            items: Lines to scroll through.
            start: Index of the first visible item.
            height: Maximum number of items to print.
            width: Maximum number of characters per line.
        """
        first = max(0, start)
        stop = first + max(0, height)
        limit = max(0, width)
        for line in items[first:stop]:
            print(line[:limit])
