"""Reporting helpers for optimisation runs."""

from __future__ import annotations

import logging
from collections.abc import Iterable, Mapping, Sequence
from datetime import datetime
from pathlib import Path
from typing import Any, Protocol, cast

import numpy as np
from numpy.typing import NDArray
from prettytable import PrettyTable

from ..monitoring.client import MonitorConfig, OptimizationMonitorClient
from .options import EngineOptions, MonitorOptions
from .parameters import BoundsPayload, ParameterSpace
from .storage import StorageWorkspace

# Float64 NumPy array type used for the phi/theta vectors in the reporter signatures below.
Array = NDArray[np.float64]


class Reporter(Protocol):
    """Structural interface for sinks that observe an optimisation run.

    A run emits ``run_started`` once, ``record_iteration`` after each
    evaluation, then ``completed`` or ``failed``; ``close`` lets the sink
    release whatever it holds.
    """

    def run_started(
        self,
        phi0_vec: Array,
        theta0_vec: Array,
        bounds: BoundsPayload,
        optimizer_name: str,
        runner_jobs: int | None,
    ) -> None:
        """Handle the one-off event sent before optimisation begins."""
        ...

    def record_iteration(
        self,
        index: int,
        phi_vec: Array,
        theta_vec: Array,
        cost: float,
        metrics: Mapping[str, Any],
        series: Mapping[str, dict[str, Any]],
        log_output: bool,
    ) -> None:
        """Handle the progress snapshot produced after one evaluation."""
        ...

    def completed(
        self,
        phi_vec: Array | None,
        theta_opt: Mapping[str, float],
        optimizer_meta: Mapping[str, Any],
        metrics: Mapping[str, Any],
    ) -> None:
        """Handle successful termination of the optimisation."""
        ...

    def failed(self, reason: str) -> None:
        """Handle termination of the optimisation with an error."""
        ...

    def close(self) -> None:
        """Release resources owned by the reporter."""
        ...


class LoggerProtocol(Protocol):
    """Subset of the ``logging.Logger`` API that the reporters call."""

    def info(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Emit an informational message (``%``-style ``args``, as in ``logging``)."""
        ...

    def error(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Emit an error message (``%``-style ``args``, as in ``logging``)."""
        ...

    def exception(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Emit an error message from an ``except`` block, mirroring ``Logger.exception``."""
        ...


class NullReporter:
    """Reporter that silently discards every lifecycle event.

    Intended for runs where neither console logging nor monitoring is wanted.

    NOTE(review): the parameters below carry a leading underscore, unlike the
    ``Reporter`` protocol, so keyword-argument calls made through the protocol
    (e.g. ``phi0_vec=...``) would fail here; positional calls work.
    """

    def run_started(
        self,
        _phi0_vec: Array,
        _theta0_vec: Array,
        _bounds: BoundsPayload,
        _optimizer_name: str,
        _runner_jobs: int | None,
    ) -> None:
        """Discard the run-start event."""

    def record_iteration(
        self,
        _index: int,
        _phi_vec: Array,
        _theta_vec: Array,
        _cost: float,
        _metrics: Mapping[str, Any],
        _series: Mapping[str, dict[str, Any]],
        _log_output: bool,
    ) -> None:
        """Discard the per-iteration event."""

    def completed(
        self,
        _phi_vec: Array | None,
        _theta_opt: Mapping[str, float],
        _optimizer_meta: Mapping[str, Any],
        _metrics: Mapping[str, Any],
    ) -> None:
        """Discard the completion event."""

    def failed(self, _reason: str) -> None:
        """Discard the failure event."""

    def close(self) -> None:
        """Nothing is held, so there is nothing to release."""


class CompositeReporter:
    """Fan-out reporter that forwards every event to several child reporters.

    Children are called in registration order. An exception raised by one child
    is logged and does not stop the remaining children from receiving the same
    event, which keeps console and monitor sinks independent.
    """

    def __init__(
        self, reporters: Iterable[Reporter], *, logger: LoggerProtocol | None = None
    ) -> None:
        """Collect the child reporters.

        Args:
            reporters: Reporters to forward events to (materialised into a list).
            logger: Where child failures are logged. Defaults to this module's
                ``logging`` logger so failures are never dropped without a trace.
        """
        self._reporters: list[Reporter] = list(reporters)
        self._logger: LoggerProtocol = (
            logger if logger is not None else logging.getLogger(__name__)
        )

    def _fanout(self, method: str, *args: object, **kwargs: object) -> None:
        """Invoke ``method`` on every child, logging instead of raising on failure."""
        for reporter in self._reporters:
            try:
                getattr(reporter, method)(*args, **kwargs)
            except Exception:
                # Deliberately best-effort: one broken sink must not break the
                # other sinks or the optimisation run itself.
                self._logger.exception(
                    "Reporter '%s' failed during '%s'", type(reporter).__name__, method
                )

    def run_started(
        self,
        phi0_vec: Array,
        theta0_vec: Array,
        bounds: BoundsPayload,
        optimizer_name: str,
        runner_jobs: int | None,
    ) -> None:
        """Relay run start to all reporters."""
        self._fanout("run_started", phi0_vec, theta0_vec, bounds, optimizer_name, runner_jobs)

    def record_iteration(
        self,
        index: int,
        phi_vec: Array,
        theta_vec: Array,
        cost: float,
        metrics: Mapping[str, Any],
        series: Mapping[str, dict[str, Any]],
        log_output: bool,
    ) -> None:
        """Relay iteration data to all reporters."""
        self._fanout(
            "record_iteration", index, phi_vec, theta_vec, cost, metrics, series, log_output
        )

    def completed(
        self,
        phi_vec: Array | None,
        theta_opt: Mapping[str, float],
        optimizer_meta: Mapping[str, Any],
        metrics: Mapping[str, Any],
    ) -> None:
        """Relay the completion event to all reporters."""
        self._fanout("completed", phi_vec, theta_opt, optimizer_meta, metrics)

    def failed(self, reason: str) -> None:
        """Relay the failure event to all reporters."""
        self._fanout("failed", reason)

    def close(self) -> None:
        """Close every reporter; a failing ``close`` does not skip the others."""
        self._fanout("close")


class ConsoleReporter:
    """Reporter that renders run progress as human-readable log messages.

    Besides the ``Reporter`` protocol methods it offers two convenience hooks,
    ``log_banner`` and ``log_configuration``, which callers invoke explicitly.
    """

    def __init__(
        self,
        logger: LoggerProtocol,
        parameter_space: ParameterSpace,
        case_descriptions: list[Mapping[str, Any]],
        workspace: StorageWorkspace,
        *,
        reparam_enabled: bool,
    ) -> None:
        """Store the context needed to format console output.

        Args:
            logger: Destination for every message emitted by this reporter.
            parameter_space: Provides ``names`` for the parameter tables and
                ``parameters()`` for the configuration summary.
            case_descriptions: Case entries; the configuration summary reads
                their ``"subfolder"`` and ``"experiments"`` keys.
            workspace: Summarised in the configuration via ``describe()``.
            reparam_enabled: When true, iteration tables include a ``phi`` column.
        """
        self.logger = logger
        self.parameter_space = parameter_space
        self.case_descriptions = case_descriptions
        self.workspace = workspace
        self.reparam_enabled = reparam_enabled

    @staticmethod
    def _format_finite(value: Any, spec: str) -> str:
        """Format ``value`` with format ``spec``; ``"nan"`` if it is None or non-finite."""
        if value is None or not np.isfinite(value):
            return "nan"
        return format(float(value), spec)

    # Convenience hooks (not part of Reporter protocol)
    def log_banner(self) -> None:
        """Log the ASCII-art start-up banner."""
        # Lines with backslashes are raw strings: "\ ", "\|" and "\_" are not
        # valid escape sequences (SyntaxWarning on Python 3.12+).
        banner_lines = [
            "",
            "=================================================",
            "             _____ _____ ____  _  ___        _   ",
            r" _ __  _   _|  ___| ____| __ )(_)/ _ \ _ __ | |_ ",
            r"| '_ \| | | | |_  |  _| |  _ \| | | | | '_ \| __|",
            "| |_) | |_| |  _| | |___| |_) | | |_| | |_) | |_ ",
            r"| .__/ \__, |_|   |_____|____/|_|\___/| .__/ \__|",
            "|_|    |___/                          |_|        ",
            "                                                 ",
            "               Optimization Engine               ",
            "=================================================",
            "",
        ]
        self.logger.info("\n%s", "\n".join(banner_lines))

    def log_configuration(
        self,
        *,
        options: EngineOptions,
        runner_command: tuple[str, ...],
        runner_env: Mapping[str, str] | None,
        optimizer_adapter: str,
    ) -> None:
        """Log a boxed, sectioned summary of the optimisation setup.

        Args:
            options: Engine options; the ``optimizer``, ``jacobian``, ``monitor``
                and ``runner`` sub-options are summarised.
            runner_command: Command tokens, shown space-joined.
            runner_env: Extra environment; only its sorted keys are logged.
            optimizer_adapter: Name of the optimizer adapter in use.
        """
        optimizer_opts = options.optimizer
        jacobian_opts = options.jacobian
        monitor_opts = options.monitor
        runner_opts = options.runner

        params = self.parameter_space.parameters()
        param_lines: list[str] = []
        if params:
            for param in params:
                lo, hi = param.bounds if param.bounds is not None else (None, None)
                bounds_str = f"({lo},{hi})"
                param_lines.append(
                    "• "
                    f"{param.name}: θ₀={param.theta0:.6g}, vary={param.vary}, "
                    f"bounds={bounds_str}"
                )
        else:
            param_lines.append("• (no parameters)")

        def _section(title: str, rows: Sequence[str]) -> list[str]:
            # One titled section, its rows (or a placeholder) and a spacer line.
            lines = [f"│ {title}"]
            if rows:
                lines.extend(f"│   {row}" for row in rows)
            else:
                lines.append("│   • (none)")
            lines.append("│")
            return lines

        lines: list[str] = ["┌────────────── ENGINE CONFIGURATION ────────────────"]
        lines += _section(
            "Optimizer",
            [
                f"• name: {optimizer_opts.name}",
                f"• adapter: {optimizer_adapter}",
                f"• reparametrize: {optimizer_opts.reparametrize}",
                f"• settings: {optimizer_opts.settings or {}}",
            ],
        )
        lines += _section(
            "Jacobian",
            [
                f"• enabled: {jacobian_opts.enabled}",
                f"• perturbation: {float(jacobian_opts.perturbation)}",
                f"• parallel: {bool(jacobian_opts.parallel)}",
            ],
        )
        lines += _section("Storage", self.workspace.describe())
        lines += _section(
            "Runner",
            [
                f"• jobs: {runner_opts.jobs}",
                f"• command: {' '.join(runner_command)}",
                f"• env keys: {sorted((runner_env or {}).keys())}",
            ],
        )
        lines += _section(
            "Monitor",
            [
                f"• enabled: {monitor_opts.enabled}",
                f"• socket: {monitor_opts.socket}",
                f"• label: {monitor_opts.label}",
            ],
        )
        case_lines: list[str] = []
        for entry in self.case_descriptions:
            subfolder = entry.get("subfolder") or "."
            experiments = entry.get("experiments") or []
            case_lines.append(f"• subfolder='{subfolder}'")
            case_lines.append(f"    experiments: {experiments or '(none)'}")
        lines += _section("Cases", case_lines)
        lines += _section("Parameters", param_lines)
        # Replace the trailing "│" spacer of the last section with the box footer.
        lines[-1] = "└─────────────────────────────────────────────────────"
        self.logger.info("\n%s", "\n".join(lines))

    # Reporter protocol
    def run_started(
        self,
        phi0_vec: Array,
        theta0_vec: Array,
        bounds: BoundsPayload,
        optimizer_name: str,
        runner_jobs: int | None,
    ) -> None:
        """Ignore the run-start event; the setup is logged via ``log_configuration``."""
        _ = (phi0_vec, theta0_vec, bounds, optimizer_name, runner_jobs)
        return None

    def record_iteration(
        self,
        index: int,
        phi_vec: Array,
        theta_vec: Array,
        cost: float,
        metrics: Mapping[str, Any],
        series: Mapping[str, dict[str, Any]],
        log_output: bool,
    ) -> None:
        """Log a boxed summary of one evaluation when ``log_output`` is true.

        The summary holds a parameter table (with a ``phi`` column only when
        re-parametrisation is enabled), the cost, the NRMSE and the per-key R²
        values; missing or non-finite metric values are shown as ``nan``.
        """
        if not log_output:
            return
        _ = series  # series data is not part of the console summary
        r_squared = metrics.get("r_squared", {})
        nrmse_val = metrics.get("nrmse")  # metrics assembled as fixed NRMSE/R²
        table = PrettyTable()
        if self.reparam_enabled:
            table.field_names = ["parameter", "phi", "theta"]
        else:
            table.field_names = ["parameter", "theta"]
        for name, phi_value, theta_value in zip(
            self.parameter_space.names, phi_vec, theta_vec, strict=True
        ):
            theta_float = float(theta_value)
            if self.reparam_enabled:
                table.add_row([name, f"{float(phi_value):.6e}", f"{theta_float:.6e}"])
            else:
                table.add_row([name, f"{theta_float:.6e}"])
        block_lines = [
            "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━",
            f"┃ Iteration #{index:03d}",
            "┣━━━━━━━━ Parameters ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━",
        ]
        block_lines.extend(f"┃   {line}" for line in table.get_string().splitlines())
        block_lines.append("┣━━━━━━━━ Results ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
        block_lines.append(f"┃   cost  : {cost:.6e}")
        block_lines.append(f"┃   nrmse : {self._format_finite(nrmse_val, '.6e')}")
        if r_squared:
            block_lines.append("┃   R² map:")
            block_lines.extend(
                f"┃     • {key}: {self._format_finite(r_squared[key], '.6f')}"
                for key in sorted(r_squared)
            )
        else:
            block_lines.append("┃   R² map: (no data)")
        block_lines.append("┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
        self.logger.info("\n%s", "\n".join(block_lines))

    def completed(
        self,
        phi_vec: Array | None,
        theta_opt: Mapping[str, float],
        optimizer_meta: Mapping[str, Any],
        metrics: Mapping[str, Any],
    ) -> None:
        """Log the final parameter table and NRMSE.

        Nothing is logged when ``phi_vec`` is None. Parameters missing from
        ``theta_opt`` or with non-finite values are shown as ``nan``.
        """
        if phi_vec is None:
            return
        _ = optimizer_meta  # optimizer metadata is not shown on the console
        table = PrettyTable()
        table.field_names = ["parameter", "phi", "theta"]
        phi_values = np.asarray(phi_vec, dtype=float).reshape(-1)
        theta_values = {k: float(v) for k, v in theta_opt.items()}
        for name, phi_value in zip(self.parameter_space.names, phi_values, strict=True):
            theta_display = self._format_finite(
                theta_values.get(name, float("nan")), "+.6e"
            )
            table.add_row([name, f"{float(phi_value):+.6e}", theta_display])
        # fixed metric from MetricsAssembler
        nrmse_display = self._format_finite(metrics.get("nrmse"), ".6e")
        self.logger.info(
            "\nOptimization complete nrmse=%s\n%s",
            nrmse_display,
            table.get_string(),
        )

    def failed(self, reason: str) -> None:
        """Log a failure message at error level."""
        self.logger.error("Optimization failed: %s", reason)

    def close(self) -> None:
        """No resources to release for console logging."""
        return None


class MonitorReporter:
    """Forward lifecycle events to the monitoring service when available.

    Every event handler is a no-op while no client is attached (monitoring
    disabled, client construction failed, or after ``close``). Errors raised
    while building or sending a payload are logged and never propagated, so a
    monitoring problem cannot abort the optimisation.
    """

    def __init__(
        self,
        monitor_opts: MonitorOptions,
        parameter_space: ParameterSpace,
        workspace: StorageWorkspace,
        case_descriptions: list[Mapping[str, Any]],
        logger: LoggerProtocol | None = None,
    ) -> None:
        """Store context and build the monitoring client when enabled.

        Args:
            monitor_opts: Monitor settings; ``enabled``, ``socket`` and ``label``
                are read.
            parameter_space: Supplies parameter ``names`` for payloads.
            workspace: Provides ``persist_root``/``workdir`` for run metadata
                and the fallback label.
            case_descriptions: Sent as-is with the run-start event.
            logger: Destination for failure logs; defaults to this module's logger.
        """
        self.parameter_space = parameter_space
        self.workspace = workspace
        self.case_descriptions = case_descriptions
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self._monitor_opts = monitor_opts
        self._client: OptimizationMonitorClient | None = None
        if monitor_opts.enabled:
            self._client = self._build_client(monitor_opts, workspace)

    @property
    def enabled(self) -> bool:
        """True when the monitoring client is active."""
        return self._client is not None

    def _build_client(
        self, monitor_opts: MonitorOptions, workspace: StorageWorkspace
    ) -> OptimizationMonitorClient | None:
        """Create the monitoring client; log and return ``None`` on any failure."""
        try:
            # Resolved inside the try so a bad socket value disables monitoring
            # instead of escaping from __init__.
            socket_path = (
                Path(monitor_opts.socket).expanduser() if monitor_opts.socket else None
            )
            config = MonitorConfig(
                socket_path=socket_path,
                label=monitor_opts.label or _default_label(workspace),
            )
            return OptimizationMonitorClient(config)
        except Exception:
            self.logger.exception("Monitor initialisation failed; disabling monitoring.")
            return None

    def run_started(
        self,
        phi0_vec: Array,
        theta0_vec: Array,
        bounds: BoundsPayload,
        optimizer_name: str,
        runner_jobs: int | None,
    ) -> None:
        """Emit a run start event to the monitoring service."""
        if self._client is None:
            return
        try:
            # Payload construction (e.g. bounds serialisation) is guarded too,
            # matching the other handlers.
            parameters = {
                "names": list(self.parameter_space.names),
                "phi0": np.asarray(phi0_vec, dtype=float).reshape(-1).tolist(),
                "theta0": np.asarray(theta0_vec, dtype=float).reshape(-1).tolist(),
                "bounds": _serialise_bounds(bounds),
            }
            meta = {
                "storage_root": str(self.workspace.persist_root),
                "runtime_root": str(self.workspace.workdir),
                "runner_jobs": runner_jobs,
            }
            self._client.run_started(
                parameters=parameters,
                cases=self.case_descriptions,
                optimizer={"adapter": optimizer_name},
                meta={k: v for k, v in meta.items() if v is not None},
            )
        except Exception:
            self.logger.exception("Failed to emit monitor start event.")

    def record_iteration(
        self,
        index: int,
        phi_vec: Array,
        theta_vec: Array,
        cost: float,
        metrics: Mapping[str, Any],
        series: Mapping[str, dict[str, Any]],
        log_output: bool,
    ) -> None:
        """Send one iteration's cost, theta values, metrics and series."""
        if self._client is None:
            return
        _ = (phi_vec, log_output)
        try:
            self._client.record_iteration(
                index=index,
                cost=float(cost),
                theta={
                    name: float(val)
                    for name, val in zip(self.parameter_space.names, theta_vec, strict=True)
                },
                metrics=_sanitize_metrics(metrics),  # fixed metrics (NRMSE, R²)
                series=series,
            )
        except Exception:
            self.logger.exception("Failed to emit monitor iteration event.")

    def completed(
        self,
        phi_vec: Array | None,
        theta_opt: Mapping[str, float],
        optimizer_meta: Mapping[str, Any],
        metrics: Mapping[str, Any],
    ) -> None:
        """Send the completion summary (theta, metrics, simplified optimizer meta)."""
        if self._client is None:
            return
        _ = phi_vec
        try:
            summary = {
                "theta_opt": {name: float(val) for name, val in theta_opt.items()},
                **_sanitize_metrics(metrics),
                "optimizer": _simplify_meta(dict(optimizer_meta or {})),
            }
            self._client.run_completed(summary=summary)
        except Exception:
            self.logger.exception("Failed to emit monitor completion event.")

    def failed(self, reason: str) -> None:
        """Notify the monitoring service of a failure."""
        if self._client is None:
            return
        try:
            self._client.run_failed(reason=reason)
        except Exception:
            self.logger.exception("Failed to emit monitor failure event.")

    def close(self) -> None:
        """Detach the monitoring client, closing it first if it exposes ``close()``.

        Safe to call more than once; afterwards every event is a no-op.
        """
        # Detach before closing so the reporter is disabled even if close() fails.
        client, self._client = self._client, None
        if client is None:
            return
        closer = getattr(client, "close", None)
        if not callable(closer):
            return
        try:
            closer()
        except Exception:
            self.logger.exception("Failed to close monitor client.")


# ---- shared helpers ---------------------------------------------------------


def _default_label(workspace: StorageWorkspace) -> str:
    for candidate in (
        workspace.persist_root.name,
        workspace.workdir.name,
    ):
        if candidate:
            return candidate
    return datetime.now().strftime("run-%Y%m%d%H%M%S")


def _sanitize_metrics(metrics: Mapping[str, Any]) -> dict[str, Any]:
    """Retain only fixed metrics (NRMSE, R²) in a JSON-safe shape.

    Returns:
        Dict with ``nrmse`` and ``r_squared`` entries suitable for JSON.
    """
    clean: dict[str, Any] = {}
    nrmse = metrics.get("nrmse")
    if isinstance(nrmse, (int, float)) and np.isfinite(nrmse):
        clean["nrmse"] = float(nrmse)
    r_squared = metrics.get("r_squared")
    if isinstance(r_squared, Mapping):
        clean["r_squared"] = {
            key: (
                float(value) if isinstance(value, (int, float)) and np.isfinite(value) else None
            )
            for key, value in r_squared.items()
        }
    return clean


def _simplify_meta(data: Mapping[str, Any]) -> dict[str, Any]:
    return {str(key): _simplify_value(value) for key, value in data.items()}


def _simplify_value(value: Any) -> Any:
    if isinstance(value, (int, float, str, bool)) or value is None:
        return value
    if isinstance(value, Mapping):
        limited_items = list(value.items())[:16]
        return {str(k): _simplify_value(v) for k, v in limited_items}
    if isinstance(value, (list, tuple, set)):
        limited = list(value)[:16]
        return [_simplify_value(item) for item in limited]
    if hasattr(value, "tolist"):
        arr = np.asarray(value)
        if arr.ndim == 0:
            scalar = arr.item()
            if isinstance(scalar, (int, float)):
                return float(scalar)
            return _simplify_value(scalar)
        flat = arr.reshape(-1)
        limited = flat[:16].tolist()
        return [float(x) for x in limited]
    return str(value)


def _serialise_bounds(bounds: BoundsPayload) -> list[tuple[float, float]]:
    if isinstance(bounds, tuple) and len(bounds) == 2:
        lo_arr, hi_arr = bounds
        lo_vec = np.asarray(lo_arr, dtype=float).reshape(-1)
        hi_vec = np.asarray(hi_arr, dtype=float).reshape(-1)
        return list(zip(lo_vec.tolist(), hi_vec.tolist(), strict=True))
    serialised: list[tuple[float, float]] = []
    seq_bounds = cast(Sequence[tuple[float, float]], bounds)
    for pair in seq_bounds:
        try:
            lo, hi = pair
        except Exception:
            continue
        serialised.append((float(lo), float(hi)))
    return serialised


# Backwards-compatible wrapper for existing callers (deprecated).
class RunReporter(CompositeReporter):
    """Deprecated facade bundling a console reporter and a monitor reporter.

    Kept for existing callers of the legacy ``RunReporter`` API. The console
    reporter always receives events; the monitor reporter is added to the
    fan-out only when its client is enabled. Both stay reachable through the
    ``console`` and ``monitor`` attributes.
    """

    def __init__(
        self,
        logger: LoggerProtocol,
        parameter_space: ParameterSpace,
        case_descriptions: list[Mapping[str, Any]],
        monitor_opts: MonitorOptions,
        workspace: StorageWorkspace,
        *,
        reparam_enabled: bool,
    ) -> None:
        """Create the console and monitor reporters and register the active ones."""
        console_reporter = ConsoleReporter(
            logger,
            parameter_space,
            case_descriptions,
            workspace,
            reparam_enabled=reparam_enabled,
        )
        monitor_reporter = MonitorReporter(
            monitor_opts,
            parameter_space,
            workspace,
            case_descriptions,
            logger=logger,
        )
        active: list[Reporter] = (
            [console_reporter, monitor_reporter]
            if monitor_reporter.enabled
            else [console_reporter]
        )
        super().__init__(active, logger=logger)
        self.console = console_reporter
        self.monitor = monitor_reporter

    def log_banner(self) -> None:
        """Log the start-up banner through the wrapped console reporter."""
        return self.console.log_banner()

    def log_configuration(
        self,
        *,
        options: EngineOptions,
        runner_command: tuple[str, ...],
        runner_env: Mapping[str, str] | None,
        optimizer_adapter: str,
    ) -> None:
        """Log the configuration summary through the wrapped console reporter."""
        self.console.log_configuration(
            optimizer_adapter=optimizer_adapter,
            runner_env=runner_env,
            runner_command=runner_command,
            options=options,
        )


# Public API of this module. ``LoggerProtocol`` is exported because it appears
# in the public constructor signatures of the reporter classes.
__all__ = [
    "CompositeReporter",
    "ConsoleReporter",
    "LoggerProtocol",
    "MonitorReporter",
    "NullReporter",
    "Reporter",
    "RunReporter",
]
