"""Data presentation components (tables, charts, comparisons)."""
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

from ..core import render_call
from ..highlight import highlight_text
from ..style.colors import Color
from ..symbols import Symbols
from ..utils import get_visible_length

if TYPE_CHECKING:
    from collections.abc import Sequence

__all__ = [
    "Table",
    "Chart",
    "Comparison",
    "KeyValue",
    "CodeBlock",
    "table",
    "chart",
    "comparison",
    "key_value",
    "code_block",
]


@dataclass(slots=True)
class TableCell:
    """Normalized representation of a table cell."""

    text: str
    colspan: int = 1
    rowspan: int = 1
    align: Literal["left", "center", "right"] | None = None
    style: str | None = None

    def normalized_align(self, fallback: Literal["left", "center", "right"]) -> Literal["left", "center", "right"]:
        if self.align in ("left", "center", "right"):
            return self.align
        return fallback


@dataclass(slots=True)
class _RenderSlot:
    """One column position of an expanded table row, as built by ``Table.render``."""

    # Cell that occupies (or covers) this column position.
    cell: TableCell
    # True for the slot where the cell is placed; False for the extra columns of a
    # colspan and for rows continued by a rowspan. Only master slots show the text.
    master: bool
    # Index of the column where the covering cell begins.
    start_col: int
    # Distance of this slot's column from ``start_col``; the renderer skips slots
    # whose offset is greater than zero.
    span_offset: int


@dataclass(slots=True)
class _SpanTracker:
    """Bookkeeping for a cell whose rowspan still covers upcoming rows."""

    # The spanning cell itself.
    cell: TableCell
    # Number of further rows the cell still covers; ``Table.render`` starts it at
    # ``rowspan - 1`` and decrements it once per continued row.
    rows_left: int
    # Index of the column where the spanning cell begins.
    start_col: int


def table(*args: object, **kwargs: object):
    """Return a renderable table component.

    All positional and keyword arguments are forwarded unchanged to
    ``Table.render`` through ``render_call``.
    """
    renderer = Table.render
    return render_call(renderer, *args, **kwargs)


def chart(*args: object, **kwargs: object):
    """Return a renderable chart component.

    All positional and keyword arguments are forwarded unchanged to
    ``Chart.bar_chart`` through ``render_call``.
    """
    renderer = Chart.bar_chart
    return render_call(renderer, *args, **kwargs)


def comparison(*args: object, **kwargs: object):
    """Return a renderable comparison component.

    All positional and keyword arguments are forwarded unchanged to
    ``Comparison.render`` through ``render_call``.
    """
    renderer = Comparison.render
    return render_call(renderer, *args, **kwargs)


def key_value(*args: object, **kwargs: object):
    """Return a renderable key-value listing.

    All positional and keyword arguments are forwarded unchanged to
    ``KeyValue.render`` through ``render_call``.
    """
    renderer = KeyValue.render
    return render_call(renderer, *args, **kwargs)


def code_block(*args: object, **kwargs: object):
    """Return a renderable code block with lightweight syntax highlighting.

    All positional and keyword arguments are forwarded unchanged to
    ``CodeBlock.render`` through ``render_call``.
    """
    renderer = CodeBlock.render
    return render_call(renderer, *args, **kwargs)


class Table:
    """Build formatted tables with headers, merging, and scrolling."""

    @staticmethod
    def _coerce_cell(value: Any) -> TableCell:  # Any: table cells accept heterogeneous input types
        """Normalize one raw row entry into a ``TableCell``.

        Accepted forms:
            * ``TableCell`` - returned unchanged.
            * ``dict`` - keys ``text``, ``colspan``, ``rowspan``, ``align``, ``style``.
            * ``tuple`` - ``(text, colspan, rowspan, align_or_style, style)``; trailing
              items are optional and ignored when they have an unexpected type.
            * anything else - converted with ``str()``.
        """
        if isinstance(value, TableCell):
            return value
        if isinstance(value, dict):
            return TableCell(
                text=str(value.get("text", "")),
                colspan=max(1, int(value.get("colspan", 1) or 1)),
                rowspan=max(1, int(value.get("rowspan", 1) or 1)),
                align=value.get("align"),
                style=value.get("style"),
            )
        if isinstance(value, tuple):
            text = str(value[0]) if value else ""
            colspan = 1
            rowspan = 1
            align: Literal["left", "center", "right"] | None = None
            style: str | None = None
            if len(value) > 1 and isinstance(value[1], int):
                colspan = value[1]
            if len(value) > 2 and isinstance(value[2], int):
                rowspan = value[2]
            if len(value) > 3 and isinstance(value[3], str):
                # The fourth item doubles as a style name when it is not an alignment.
                if value[3] in ("left", "center", "right"):
                    align = value[3]
                else:
                    style = value[3]
            if len(value) > 4 and isinstance(value[4], str):
                style = value[4]
            return TableCell(text=text, colspan=max(1, colspan), rowspan=max(1, rowspan), align=align, style=style)
        return TableCell(text=str(value), colspan=1, rowspan=1)

    @staticmethod
    def _pad(text: str, width: int, align: str) -> str:
        """Pad ``text`` with spaces up to ``width`` visible columns according to ``align``."""
        slack = max(0, width - get_visible_length(text))
        if align == "center":
            left = slack // 2
        elif align == "right":
            left = slack
        else:
            left = 0
        return (" " * left) + text + (" " * (slack - left))

    @staticmethod
    def render(
        headers: Sequence[str],
        rows: Sequence[Sequence[Any]],
        width: int | None = None,
        column_alignments: Sequence[Literal["left", "center", "right"]] | None = None,
        column_widths: list[int] | None = None,
        header_color: str | None = None,
        border_style: str = "box",
        show_header: bool = True,
        zebra_stripes: bool = False,
        zebra_color: str = "gray",
        scroll_offset: int = 0,
        max_rows: int | None = None,
        blank_lines_before: int = 0,
        blank_lines_after: int = 0,
        ascii_mode: bool | None = None,
    ) -> str:
        """Render a formatted table with support for merged cells and scrolling.

        Args:
            headers: Column titles; their count defines the number of columns.
                An empty sequence renders the placeholder ``(empty table)``.
            rows: Body rows. Each entry is normalized by ``_coerce_cell``.
                Rows that are too short are padded with blank cells.
            width: Accepted for API compatibility; not used by the current layout.
            column_alignments: Per-column alignment. Missing or unrecognised
                entries fall back to ``"left"``; extra entries are ignored.
            column_widths: Starting widths per column (the list is copied, missing
                entries default to 3). When omitted, widths start at the header
                length plus two, with a minimum of 3. Columns grow to fit content.
            header_color: Optional colour name applied to header cells.
            border_style: Symbol set name passed to ``Symbols.get_symbols``.
            show_header: Whether to emit the header row and its separator.
            zebra_stripes: Colour every second body row (odd absolute index).
            zebra_color: Colour name used for zebra rows.
            scroll_offset: Index of the first body row to show, clamped to the
                available rows.
            max_rows: Maximum number of body rows to show; ``None`` or a negative
                value shows everything from ``scroll_offset`` on.
            blank_lines_before: Number of empty lines emitted before the table.
            blank_lines_after: Number of empty lines emitted after the table.
            ascii_mode: Passed through to ``Symbols.get_symbols``.

        Returns:
            The table as a single newline-joined string.

        Raises:
            ValueError: If the colspans of a row add up to more than the number
                of columns.
        """
        out: list[str] = [""] * blank_lines_before

        if not headers:
            out.append("(empty table)")
            out.extend([""] * blank_lines_after)
            return "\n".join(out)

        symbols = Symbols.get_symbols(border_style, ascii_mode)
        vertical = symbols["vertical"]
        num_cols = len(headers)

        # Resolve one valid alignment per column.
        alignments: list[str] = ["left"] * num_cols
        if column_alignments is not None:
            for idx, requested in enumerate(list(column_alignments)[:num_cols]):
                if requested in ("left", "center", "right"):
                    alignments[idx] = requested

        # Normalize every row into TableCell objects and validate its total span.
        normalized_rows: list[list[TableCell]] = []
        for row in rows:
            cells = [Table._coerce_cell(entry) for entry in row]
            total_span = sum(max(1, cell.colspan) for cell in cells)
            if total_span > num_cols:
                raise ValueError(f"Row spans exceed column count ({total_span} > {num_cols})")
            cells.extend(TableCell("") for _ in range(num_cols - total_span))
            normalized_rows.append(cells)

        # Starting widths; a caller-supplied list is copied, never mutated.
        if column_widths is None:
            widths = [
                max(3, get_visible_length(Color.apply_inline_markup(str(header))) + 2)
                for header in headers
            ]
        else:
            widths = list(column_widths)[:num_cols]
            widths.extend([3] * (num_cols - len(widths)))

        def _grow_width(start_col: int, span: int, required: int) -> None:
            """Widen columns ``start_col .. start_col+span-1`` so they total ``required``."""
            current = sum(widths[start_col : start_col + span])
            if current >= required:
                return
            base, remainder = divmod(required - current, span)
            for offset in range(span):
                widths[start_col + offset] += base + (1 if offset < remainder else 0)

        # active[c] is the rowspan tracker currently covering column c, if any.
        active: list[_SpanTracker | None] = [None] * num_cols

        def _advance(trackers: list[_SpanTracker]) -> None:
            """Consume one row from each tracker and free the columns of finished ones."""
            for tracker in trackers:
                tracker.rows_left -= 1
                if tracker.rows_left <= 0:
                    # Release by identity so the columns are freed whatever the
                    # cell's recorded colspan is.
                    for idx, occupant in enumerate(active):
                        if occupant is tracker:
                            active[idx] = None

        def _continuation(tracker: _SpanTracker, col_idx: int) -> _RenderSlot:
            """Slot for a column that is covered by a rowspan started in an earlier row."""
            return _RenderSlot(
                cell=tracker.cell,
                master=False,
                start_col=tracker.start_col,
                span_offset=col_idx - tracker.start_col,
            )

        expanded_rows: list[list[_RenderSlot]] = []

        for row in normalized_rows:
            row_slots: list[_RenderSlot] = []
            continued: list[_SpanTracker] = []
            pending = iter(row)
            col_idx = 0

            while col_idx < num_cols:
                tracker = active[col_idx]
                if tracker is not None:
                    row_slots.append(_continuation(tracker, col_idx))
                    # A tracker's columns are contiguous, so it is first met at start_col.
                    if col_idx == tracker.start_col:
                        continued.append(tracker)
                    col_idx += 1
                    continue

                cell = next(pending, None)
                if cell is None:
                    cell = TableCell("")

                # Clamp the colspan to the free run of columns starting here, so a
                # cell never runs past the table edge or into an active rowspan.
                wanted = max(1, cell.colspan)
                span = 1
                while span < wanted and col_idx + span < num_cols and active[col_idx + span] is None:
                    span += 1
                if cell.colspan != span:
                    # Work on a copy so the caller's object is left untouched and the
                    # renderer sees the span that was actually laid out.
                    cell = TableCell(
                        text=cell.text,
                        colspan=span,
                        rowspan=cell.rowspan,
                        align=cell.align,
                        style=cell.style,
                    )

                # Size the columns the cell really occupies (rowspans above may have
                # shifted it to the right of its position in the input row).
                content = Color.apply_inline_markup(cell.text)
                _grow_width(col_idx, span, get_visible_length(content) + 2)

                row_slots.append(_RenderSlot(cell=cell, master=True, start_col=col_idx, span_offset=0))
                for offset in range(1, span):
                    row_slots.append(
                        _RenderSlot(cell=cell, master=False, start_col=col_idx, span_offset=offset)
                    )

                if cell.rowspan > 1:
                    tracker = _SpanTracker(cell=cell, rows_left=cell.rowspan - 1, start_col=col_idx)
                    for offset in range(span):
                        active[col_idx + offset] = tracker

                col_idx += span

            expanded_rows.append(row_slots)
            _advance(continued)

        # Rowspans that outlive the supplied rows get blank filler rows.
        while any(tracker is not None for tracker in active):
            row_slots = []
            continued = []
            for col_idx in range(num_cols):
                tracker = active[col_idx]
                if tracker is not None:
                    row_slots.append(_continuation(tracker, col_idx))
                    if col_idx == tracker.start_col:
                        continued.append(tracker)
                else:
                    row_slots.append(
                        _RenderSlot(cell=TableCell(""), master=True, start_col=col_idx, span_offset=0)
                    )
            expanded_rows.append(row_slots)
            _advance(continued)

        # Scrolling window over the expanded body rows.
        total_rows = len(expanded_rows)
        scroll_offset = max(0, min(scroll_offset, max(0, total_rows - 1))) if total_rows else 0
        if max_rows is not None and max_rows >= 0:
            visible_rows = expanded_rows[scroll_offset : min(total_rows, scroll_offset + max_rows)]
        else:
            visible_rows = expanded_rows[scroll_offset:]

        def _rule(left_key: str, joint_key: str, right_key: str) -> str:
            """Build a horizontal border line from the named corner/joint symbols."""
            segments = (symbols["horizontal"] * column_width for column_width in widths)
            return symbols[left_key] + symbols[joint_key].join(segments) + symbols[right_key]

        out.append(_rule("top_left", "t_down", "top_right"))

        if show_header:
            hcolor = Color.get_color(header_color) if header_color else None
            header_line = vertical
            for col_idx, header in enumerate(headers):
                shown = Table._pad(
                    Color.apply_inline_markup(str(header)), widths[col_idx], alignments[col_idx]
                )
                if hcolor:
                    shown = hcolor + shown + Color.RESET
                header_line += shown + vertical
            out.append(header_line)
            out.append(_rule("t_right", "cross", "t_left"))

        for relative_idx, slots in enumerate(visible_rows):
            # Zebra striping follows the absolute row index, so it is stable while scrolling.
            zebra_active = zebra_stripes and (scroll_offset + relative_idx) % 2 == 1
            zcolor = Color.get_color(zebra_color) if zebra_active else None
            row_output = vertical
            col_idx = 0
            while col_idx < num_cols:
                slot = slots[col_idx]
                if slot.span_offset > 0:
                    col_idx += 1
                    continue
                cell = slot.cell
                span = max(1, cell.colspan)
                width_span = sum(widths[col_idx : col_idx + span])

                # Rows continued by a rowspan keep the cell's geometry but show no text.
                rendered = Color.apply_inline_markup(cell.text if slot.master else "")
                if cell.style:
                    style_code = Color.get_color(cell.style)
                    if style_code:
                        rendered = style_code + rendered + Color.RESET

                cell_render = Table._pad(rendered, width_span, cell.normalized_align(alignments[col_idx]))
                if zcolor:
                    cell_render = zcolor + cell_render + Color.RESET

                row_output += cell_render + vertical
                col_idx += span

            out.append(row_output)

        out.append(_rule("bottom_left", "t_up", "bottom_right"))

        out.extend([""] * blank_lines_after)
        return "\n".join(out)


class Chart:
    """Build simple bar charts and visualizations."""

    @staticmethod
    def bar_chart(
        data: dict[str, int | float],
        width: int = 60,
        max_bar_width: int = 40,
        color: str = "cyan",
        show_values: bool = True,
        blank_lines_before: int = 0,
        blank_lines_after: int = 0
    ) -> str:
        """
        Render a horizontal bar chart.

        Args:
            data: Dictionary of labels to values
            width: Total chart width
            max_bar_width: Maximum width of bars
            color: Bar color
            show_values: Show numeric values
            blank_lines_before: Blank lines before chart
            blank_lines_after: Blank lines after chart
        """
        out: list[str] = []
        for _ in range(blank_lines_before):
            out.append("")

        if not data:
            out.append("No data to display")
            return "\n".join(out)

        max_value = max(data.values())
        max_label_len = max(len(str(k)) for k in data.keys())
        bar_color = Color.get_color(color)

        for label, value in data.items():
            # Calculate bar length
            if max_value > 0:
                bar_length = int((value / max_value) * max_bar_width)
            else:
                bar_length = 0

            # Build bar
            bar = Symbols.BAR_FILLED() * bar_length
            colored_bar = bar_color + bar + Color.RESET

            # Format label
            padded_label = str(label).ljust(max_label_len)

            # Build line
            if show_values:
                line = f"{padded_label} │ {colored_bar} {value}"
            else:
                line = f"{padded_label} │ {colored_bar}"

            out.append(line)

        for _ in range(blank_lines_after):
            out.append("")
        return "\n".join(out)

    @staticmethod
    def line_chart(
        series: dict[str, list[float]],
        width: int = 60,
        height: int = 10,
        color: str = "cyan",
    ) -> str:
        if not series:
            return "No data"
        # Normalize values into grid
        all_vals = [v for vals in series.values() for v in vals]
        if not all_vals:
            return "No data"
        vmin, vmax = min(all_vals), max(all_vals)
        span = (vmax - vmin) or 1
        cols = max(len(next(iter(series.values()))), 1)
        grid = [[" " for _ in range(cols)] for _ in range(height)]
        for _si, (_name, vals) in enumerate(series.items()):
            for x, v in enumerate(vals):
                y = height - 1 - int((v - vmin) / span * (height - 1))
                grid[y][x % cols] = "●"
        return "\n".join(Color.get_color(color) + "".join(row) + Color.RESET for row in grid)

    @staticmethod
    def area_chart(
        values: list[float],
        width: int = 60,
        height: int = 10,
        color: str = "accent",
    ) -> str:
        if not values:
            return "No data"
        vmin, vmax = min(values), max(values)
        span = (vmax - vmin) or 1
        cols = max(len(values), 1)
        grid = [[" " for _ in range(cols)] for _ in range(height)]
        for x, v in enumerate(values):
            top_y = height - 1 - int((v - vmin) / span * (height - 1))
            for y in range(top_y, height):
                grid[y][x % cols] = "█"
        col = Color.get_color(color)
        return "\n".join(col + "".join(row) + Color.RESET for row in grid)

    @staticmethod
    def pie_chart(data: dict[str, float]) -> str:
        total = sum(data.values()) or 1
        lines: list[str] = []
        for label, value in data.items():
            pct = value / total * 100
            bar = Symbols.BAR_FILLED() * int(pct / 4)
            lines.append(f"{label:<10} {bar} {pct:.1f}%")
        return "\n".join(lines)

    @staticmethod
    def heatmap(matrix: list[list[float]], shades: str = " .:-=+*#%@") -> str:
        if not matrix:
            return ""
        flat = [v for row in matrix for v in row]
        vmin, vmax = min(flat), max(flat)
        span = (vmax - vmin) or 1
        out_lines: list[str] = []
        for row in matrix:
            line = []
            for v in row:
                idx = int((v - vmin) / span * (len(shades) - 1))
                line.append(shades[idx])
            out_lines.append("".join(line))
        return "\n".join(out_lines)

    @staticmethod
    def radar_chart(axes: list[str], values: list[float]) -> str:
        # Simple textual representation
        max_label = max(len(a) for a in axes) if axes else 0
        max_val = max(values) if values else 1
        lines: list[str] = []
        for a, v in zip(axes, values, strict=False):
            bar = Symbols.BAR_FILLED() * int((v / max_val) * 20)
            lines.append(f"{a:<{max_label}} | {bar}")
        return "\n".join(lines)

    @staticmethod
    def sparkline(
        values: list[int | float],
        color: str = "accent",
        height: int = 1,
        blank_lines_before: int = 0,
        blank_lines_after: int = 0,
    ) -> str:
        """Render a sparkline using Unicode block characters."""
        out: list[str] = []
        for _ in range(blank_lines_before):
            out.append("")

        if not values:
            out.append("")
            return "\n".join(out)

        ticks = "▁▂▃▄▅▆▇█"
        vmin = min(values)
        vmax = max(values)
        span = (vmax - vmin) or 1
        scaled = [int((v - vmin) / span * (len(ticks) - 1)) for v in values]
        col = Color.get_color(color)
        line = col + "".join(ticks[s] for s in scaled) + Color.RESET
        for _ in range(height):
            out.append(line)

        for _ in range(blank_lines_after):
            out.append("")
        return "\n".join(out)

    @staticmethod
    def histogram(
        values: list[int | float],
        bins: int = 10,
        width: int = 40,
        color: str = "cyan",
        blank_lines_before: int = 0,
        blank_lines_after: int = 0,
    ) -> str:
        """Render a simple horizontal histogram."""
        lines: list[str] = []
        for _ in range(blank_lines_before):
            lines.append("")

        if not values or bins <= 0:
            lines.append("No data")
            return "\n".join(lines)

        vmin = min(values)
        vmax = max(values)
        span = (vmax - vmin) or 1
        step = span / bins
        # Count bins
        counts = [0 for _ in range(bins)]
        for v in values:
            idx = int((v - vmin) / span * bins)
            if idx == bins:
                idx -= 1
            counts[idx] += 1

        max_count = max(counts) or 1
        col = Color.get_color(color)
        for i, c in enumerate(counts):
            bar_len = int(c / max_count * width)
            low = vmin + i * step
            high = low + step
            label = f"[{low:.2f},{high:.2f})"
            lines.append(f"{label} {col}{Symbols.BAR_FILLED() * bar_len}{Color.RESET} {c}")

        for _ in range(blank_lines_after):
            lines.append("")
        return "\n".join(lines)


class Comparison:
    """Build comparison displays."""

    @staticmethod
    def render(
        left_title: str,
        left_data: dict[str, Any],
        right_title: str,
        right_data: dict[str, Any],
        width: int = 80,
        vs_text: str = "VS",
        blank_lines_before: int = 0,
        blank_lines_after: int = 0,
        ascii_mode: bool | None = None
    ) -> str:
        """
        Render a side-by-side comparison.

        Rows follow the insertion order of ``left_data``, followed by the keys
        that only exist in ``right_data``, so the output is the same on every
        run. A key missing on one side is shown as ``-``.

        Args:
            left_title: Title for left side
            left_data: Data for left side
            right_title: Title for right side
            right_data: Data for right side
            width: Total width
            vs_text: Text between sides
            blank_lines_before: Blank lines before comparison
            blank_lines_after: Blank lines after comparison
            ascii_mode: When true, draw the divider with ``-`` instead of ``─``

        Returns:
            The comparison as a single newline-joined string.
        """
        out: list[str] = [""] * blank_lines_before

        # The middle gap holds the "VS" marker with two spaces of margin per side;
        # the remaining width is split evenly between the two columns.
        vs_width = len(vs_text) + 4
        col_width = (width - vs_width) // 2
        gap = " " * vs_width

        # Titles
        left_title_formatted = Color.BOLD + left_title.center(col_width) + Color.RESET
        right_title_formatted = Color.BOLD + right_title.center(col_width) + Color.RESET
        vs_formatted = Color.get_color("gold") + Color.BOLD + vs_text.center(vs_width) + Color.RESET
        out.append(left_title_formatted + vs_formatted + right_title_formatted)

        # Divider
        rule = ("-" if ascii_mode else "─") * col_width
        out.append(rule + gap + rule)

        # Data rows. dict.fromkeys de-duplicates while keeping first-seen order;
        # a set would iterate in an order that varies between interpreter runs.
        for key in dict.fromkeys([*left_data, *right_data]):
            left_val = str(left_data.get(key, "-"))
            right_val = str(right_data.get(key, "-"))

            key_str = f"{key}:"
            left_str = f"{key_str:<15} {left_val}".ljust(col_width)
            right_str = f"{key_str:<15} {right_val}".ljust(col_width)

            out.append(left_str + gap + right_str)

        out.extend([""] * blank_lines_after)
        return "\n".join(out)


class KeyValue:
    """Build key-value displays."""

    @staticmethod
    def render(
        data: dict[str, Any],
        key_color: str = "cyan",
        value_color: str = "white",
        separator: str = ": ",
        indent: int = 0,
        blank_lines_before: int = 0,
        blank_lines_after: int = 0
    ) -> str:
        """
        Render key-value pairs, one per line, with keys padded to equal width.

        Values are converted with ``str()`` and passed through inline-markup
        processing; keys are converted with ``str()`` only.

        Args:
            data: Dictionary of key-value pairs
            key_color: Color for keys
            value_color: Color for values
            separator: Separator between key and value
            indent: Left indentation
            blank_lines_before: Blank lines before display
            blank_lines_after: Blank lines after display
        """
        rendered: list[str] = [""] * blank_lines_before

        kcolor = Color.get_color(key_color)
        vcolor = Color.get_color(value_color)
        left_margin = " " * indent

        # Widest key decides the padding so that all separators line up.
        key_width = max((len(str(name)) for name in data.keys()), default=0)

        for name, raw_value in data.items():
            label = str(name).ljust(key_width)
            shown = Color.apply_inline_markup(str(raw_value))
            rendered.append(
                f"{left_margin}{kcolor}{label}{Color.RESET}{separator}{vcolor}{shown}{Color.RESET}"
            )

        rendered.extend([""] * blank_lines_after)
        return "\n".join(rendered)


class CodeBlock:
    """Render code with simple syntax highlighting and optional line numbers."""

    @staticmethod
    def render(
        code: str,
        language: str | None = None,
        width: int = 80,
        show_line_numbers: bool = True,
        number_color: str = "muted",
        code_color: str = "fg",
        border_style: str = "rounded",
        blank_lines_before: int = 0,
        blank_lines_after: int = 0,
        start_line_number: int = 1,
    ) -> str:
        """Render ``code`` inside a bordered box.

        Args:
            code: Source text; it is split on ``"\\n"`` into display lines.
            language: Passed through to ``highlight_text`` for each line.
            width: Total box width including both border characters. Longer
                lines are neither wrapped nor truncated.
            show_line_numbers: Prefix every line with a right-aligned number.
            number_color: Color name for the line numbers.
            code_color: Color name for the code text.
            border_style: Symbol set name passed to ``Symbols.get_symbols``.
            blank_lines_before: Blank lines before the box.
            blank_lines_after: Blank lines after the box.
            start_line_number: Number shown for the first line (default 1), so
                an excerpt can keep the numbering of the full source.

        Returns:
            The boxed code as a single newline-joined string.
        """
        out: list[str] = [""] * blank_lines_before

        num_col = Color.get_color(number_color)
        code_col = Color.get_color(code_color)
        symbols = Symbols.get_symbols(border_style)
        inner_width = width - 2  # space between the two vertical borders

        # Top border
        out.append(symbols["top_left"] + symbols["horizontal"] * inner_width + symbols["top_right"])

        for number, line in enumerate(code.split("\n"), start=start_line_number):
            # Basic escaping for braces so our markup does not parse code braces
            safe_line = line.replace("{", "{{").replace("}", "}}")
            safe_line = highlight_text(safe_line, language)

            # Build content line with optional numbers
            prefix = f"{num_col}{number:>4}{Color.RESET}  " if show_line_numbers else ""
            content = prefix + code_col + safe_line + Color.RESET
            pad = max(0, inner_width - get_visible_length(content))
            out.append(symbols["vertical"] + content + " " * pad + symbols["vertical"])

        # Bottom border
        out.append(symbols["bottom_left"] + symbols["horizontal"] * inner_width + symbols["bottom_right"])

        out.extend([""] * blank_lines_after)
        return "\n".join(out)

    @staticmethod
    def scroll(
        code: str,
        start_line: int,
        num_lines: int,
        width: int = 80,
    ) -> str:
        """Render a window of ``num_lines`` lines of ``code`` starting at ``start_line``.

        Args:
            code: Full source text.
            start_line: 1-based number of the first line to show; values below 1
                are treated as 1.
            num_lines: Number of lines in the window; negative values are
                treated as 0.
            width: Total box width, forwarded to ``render``.

        Returns:
            The rendered window, with line numbers that match the positions of
            the lines in ``code``.
        """
        first = max(1, start_line)
        lines = code.split("\n")
        window = lines[first - 1 : first - 1 + max(0, num_lines)]
        return CodeBlock.render(
            "\n".join(window),
            width=width,
            show_line_numbers=True,
            start_line_number=first,
        )
