"""Unified optimisation engine for FEBio parameter fitting."""

from __future__ import annotations

from collections.abc import Callable, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, cast

import numpy as np
from numpy.typing import NDArray

from ..Log import Log
from .cases import CaseEvaluator, CaseJob, SimulationCase
from .jacobian import JacobianComputer
from .optimizers import BoundsLike, OptimizerAdapter
from .options import EngineOptions
from .parameters import ParameterMapper, ParameterSpace
from .reporting import (
    CompositeReporter,
    ConsoleReporter,
    LoggerProtocol,
    MonitorReporter,
    NullReporter,
    Reporter,
)
from .runners import LocalParallelRunner, LocalSerialRunner, Runner
from .storage import StorageWorkspace

# Shared alias for float64 NumPy arrays; used in annotations and ``cast`` calls below.
Array = NDArray[np.float64]


@dataclass
class OptimizeResult:
    """Outcome handed back to the caller once an optimisation run finishes.

    Instances are built by ``Engine.run`` from the optimiser's final point.
    """

    # Optimal parameter vector in the optimiser's (phi) coordinates.
    phi: Array
    # The same optimum expressed as a name -> value mapping (theta).
    theta: dict[str, float]
    # Metadata returned by the optimiser adapter, passed through untouched.
    metadata: dict[str, Any]


@dataclass
class IterationState:
    """Track iteration bookkeeping and cached evaluations.

    The state stores the most recent residual evaluation (inputs, outputs and
    the directory it ran in) plus at most one Jacobian tied to a specific phi
    vector.  All arrays are stored as private copies so later in-place
    modification by callers cannot corrupt the cache.
    """

    # Index reported for the next iteration; advanced by ``next_index``.
    progress_index: int = 0
    # True until the first residual evaluation of a run has been reported.
    pending_initial_log: bool = True
    # Whether progress output was requested for the current run.
    log_progress: bool = True
    # Inputs/outputs of the most recent residual evaluation.
    last_phi: Array | None = None
    last_theta_vec: Array | None = None
    last_residual: Array | None = None
    last_iter_dir: Path | None = None
    last_metrics: dict[str, Any] = field(default_factory=dict)
    series_latest: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Jacobian cache: the phi it was computed at, and the matrix itself.
    cached_jac_phi: Array | None = None
    cached_jacobian: Array | None = None

    def reset(self, *, log_progress: bool) -> None:
        """Clear state between optimisation runs.

        Args:
            log_progress: Value stored in ``log_progress`` for the new run.
        """
        self.progress_index = 0
        self.pending_initial_log = True
        self.log_progress = log_progress
        self.last_phi = None
        self.last_theta_vec = None
        self.last_residual = None
        self.last_iter_dir = None
        self.last_metrics = {}
        self.series_latest = {}
        self.cached_jac_phi = None
        self.cached_jacobian = None

    def cache_evaluation(
        self,
        phi_vec: Array,
        theta_vec: Array,
        residual: Array,
        iter_dir: Path,
        metrics: Mapping[str, Any],
        series: Mapping[str, dict[str, Any]],
    ) -> None:
        """Persist the latest evaluation payload.

        Every array is copied, ``metrics`` is shallow-copied and each entry of
        ``series`` is shallow-copied.  Any cached Jacobian is dropped because a
        new evaluation has replaced the one it was computed against.
        """
        self.last_phi = phi_vec.copy()
        # Snapshot theta like phi/residual so the cache never aliases an
        # array the caller may still modify.
        self.last_theta_vec = np.array(theta_vec, copy=True)
        self.last_residual = residual.copy()
        self.last_iter_dir = iter_dir
        self.last_metrics = dict(metrics)
        self.series_latest = {k: dict(v) for k, v in series.items()}
        self.cached_jac_phi = None
        self.cached_jacobian = None

    def next_index(self) -> int:
        """Return the current iteration index, then increment it."""
        idx = self.progress_index
        self.progress_index += 1
        return idx

    def cache_jacobian(self, phi_vec: Array, J: Array) -> None:
        """Store a copy of Jacobian ``J`` associated with ``phi_vec``."""
        self.cached_jac_phi = phi_vec.copy()
        self.cached_jacobian = J.copy()

    def cached_jac(self, phi_vec: Array) -> Array | None:
        """Return a copy of the cached Jacobian if it was stored for ``phi_vec``.

        Matching is exact (``np.array_equal``); ``None`` is returned when
        nothing is cached or the phi vectors differ.
        """
        if (
            self.cached_jac_phi is not None
            and self.cached_jacobian is not None
            and np.array_equal(phi_vec, self.cached_jac_phi)
        ):
            return self.cached_jacobian.copy()
        return None


class JacobianHelper:
    """Handle Jacobian scheduling/finalisation to keep Engine slim."""

    def __init__(
        self,
        jacobian: JacobianComputer,
        case_evaluator: CaseEvaluator,
        mapper: ParameterMapper,
        workspace: StorageWorkspace,
        param_names: Sequence[str],
    ):
        """Coordinate Jacobian evaluations, optionally in parallel."""
        self.jacobian = jacobian
        self.case_evaluator = case_evaluator
        self.mapper = mapper
        self.workspace = workspace
        self.param_names = list(param_names)

    def compute(self, phi_vec: Array, state: IterationState) -> Array:
        """Return a forward-difference Jacobian for ``phi_vec``."""
        phi_vec = cast(Array, np.asarray(phi_vec, dtype=float))
        base_residual = state.last_residual
        iter_dir = state.last_iter_dir or self.workspace.next_iter_dir()
        if base_residual is None:
            raise RuntimeError("Residuals must be evaluated before computing the Jacobian.")

        if self.jacobian.parallel:
            cached = state.cached_jac(phi_vec)
            if cached is not None:
                return cached
            jobs = self._schedule_jobs(phi_vec, iter_dir)
            J = self._finalize_jobs(jobs, base_residual)
            state.cache_jacobian(phi_vec, J)
            return J

        def residual_with_label(theta_vec: Array, lbl: str | None) -> Array:
            theta_vec = cast(Array, np.asarray(theta_vec, dtype=float))
            theta_vec = self.mapper.clamp_theta(theta_vec)
            theta_dict = self.mapper.theta_dict(theta_vec)
            result = self.case_evaluator.evaluate(
                theta_dict, iter_dir, label=lbl, track_series=False
            )
            return result.residual

        label_fn = self._label_fn()
        _, J = self.jacobian.compute(
            phi_vec,
            self.mapper.phi_to_theta,
            residual_with_label,
            label_fn=label_fn,
            base_residual=base_residual,
        )
        return J

    def _schedule_jobs(
        self,
        phi_vec: Array,
        iter_dir: Path,
    ) -> dict[int, list[CaseJob]]:
        """Kick off perturbed simulations for each parameter column.

        Returns:
            Mapping of column index to launched case jobs.
        """
        jobs_by_index: dict[int, list[CaseJob]] = {}
        for idx in range(len(self.param_names)):
            phi = phi_vec.copy()
            phi[idx] += float(self.jacobian.perturbation)
            theta_vec = self.mapper.phi_to_theta(phi)
            theta_dict = self.mapper.theta_dict(theta_vec)
            label = self._label_fn()(idx)
            jobs_by_index[idx] = self.case_evaluator.case_runner.launch_cases(
                theta_dict, iter_dir, label
            )
        return jobs_by_index

    def _finalize_jobs(
        self,
        jobs_by_index: dict[int, list[CaseJob]],
        base_residual: Array,
    ) -> Array:
        """Collect results from scheduled jobs and form the Jacobian matrix.

        Returns:
            Dense Jacobian matrix aligned with ``param_names`` order.
        """
        J = cast(Array, np.zeros((base_residual.size, len(self.param_names)), dtype=float))
        for idx, jobs in jobs_by_index.items():
            result = self.case_evaluator.case_runner.finalize_cases(
                jobs,
                self.case_evaluator.residual_assembler,
                self.case_evaluator.preparer,
            )
            residual_arrays, _ = result
            residual = (
                cast(Array, np.concatenate(residual_arrays))
                if residual_arrays
                else cast(Array, np.array([], dtype=float))
            )
            if residual.shape != base_residual.shape:
                raise RuntimeError("Residual size mismatch while forming the Jacobian.")
            J[:, idx] = (residual - base_residual) / float(self.jacobian.perturbation)
        return J

    def _label_fn(self) -> Callable[[int], str | None]:
        """Provide a label formatter for perturbed runs.

        Returns:
            Callable that maps perturbation column to suffix strings.
        """
        names = self.param_names

        def label(idx: int) -> str | None:
            if idx < 0:
                return "_base"
            if idx < len(names):
                return f"_{names[idx]}"
            return f"_col_{idx}"

        return label


class Engine:
    """Coordinate FEBio simulations and optimisation loops."""

    def __init__(
        self,
        parameter_space: ParameterSpace,
        cases: Sequence[SimulationCase],
        *,
        options: EngineOptions | None = None,
    ) -> None:
        """Initialise the engine, wiring runner, reporter, mapper, and helpers.

        Args:
            parameter_space: Parameters to optimise; its ``names`` fix the
                parameter ordering used throughout the engine.
            cases: Simulation cases to evaluate; must not be empty.
            options: Engine configuration. A default ``EngineOptions()`` is
                used when omitted (or falsy).

        Raises:
            ValueError: If ``cases`` is empty.
        """
        if not cases:
            raise ValueError("At least one SimulationCase is required.")

        self.options = options = options or EngineOptions()
        self.parameter_space = parameter_space
        self._param_names = list(self.parameter_space.names)
        # Pull out the option sub-sections once for readability.
        optimizer_opts = options.optimizer
        jacobian_opts = options.jacobian
        runner_opts = options.runner
        storage_opts = options.storage
        cleanup_opts = options.cleanup
        monitor_opts = options.monitor
        self._reparam_enabled = bool(optimizer_opts.reparametrize)

        # The workspace is created first because the log is configured with
        # the workspace's log file.
        self.workspace = StorageWorkspace(storage_opts, cleanup_opts)
        self.workdir = self.workspace.workdir
        self.persist_root = self.workspace.persist_root
        log_instance = Log(log_file=self.workspace.log_file)
        self._logger: LoggerProtocol = log_instance.logger

        # A JacobianComputer only exists when Jacobians are enabled; otherwise
        # no Jacobian callable is handed to the optimiser in ``run``.
        self.jacobian: JacobianComputer | None = (
            JacobianComputer(
                perturbation=float(jacobian_opts.perturbation),
                parallel=bool(jacobian_opts.parallel),
            )
            if jacobian_opts.enabled
            else None
        )

        # Job count is clamped to at least 1; one job selects the serial
        # runner, more than one selects the parallel runner.
        runner_jobs = int(max(1, runner_opts.jobs))
        # Default command when none is configured.
        runner_command = tuple(runner_opts.command or ("febio4", "-i"))
        runner_env = dict(runner_opts.env) if runner_opts.env is not None else None
        self._runner_jobs = runner_jobs
        if runner_jobs <= 1:
            self.runner: Runner = LocalSerialRunner(command=runner_command, env=runner_env)
        else:
            self.runner = LocalParallelRunner(
                n_jobs=self._runner_jobs,
                command=runner_command,
                env=runner_env,
            )

        self.case_evaluator = CaseEvaluator(cases, self.runner, self._logger)
        self.parameter_mapper = ParameterMapper(parameter_space, self._reparam_enabled)
        # Re-read the flag from the mapper so reporting uses the mapper's
        # value — presumably it may differ from the requested one; confirm
        # against ParameterMapper.
        self._reparam_enabled = self.parameter_mapper.reparam_enabled
        case_descriptions = self.case_evaluator.describe_cases()
        console_reporter = ConsoleReporter(
            logger=self._logger,
            parameter_space=self.parameter_space,
            case_descriptions=case_descriptions,
            workspace=self.workspace,
            reparam_enabled=self._reparam_enabled,
        )
        monitor_reporter = MonitorReporter(
            monitor_opts,
            self.parameter_space,
            self.workspace,
            case_descriptions,
            logger=self._logger,
        )
        # The console reporter is always present; the monitor reporter is
        # added only when it reports itself as enabled.
        reporters: list[Reporter] = [console_reporter]
        if monitor_reporter.enabled:
            reporters.append(monitor_reporter)
        # NOTE: ``reporters`` is never empty here, so the NullReporter fallback
        # is not reachable from this constructor.
        self.reporter: Reporter = (
            CompositeReporter(reporters, logger=self._logger) if reporters else NullReporter()
        )
        # The banner is logged before the optimiser adapter is built.
        console_reporter.log_banner()
        self.optimizer_adapter = OptimizerAdapter.build(
            optimizer_opts.name, optimizer_opts.settings
        )

        console_reporter.log_configuration(
            options=options,
            runner_command=runner_command,
            runner_env=runner_env,
            optimizer_adapter=type(self.optimizer_adapter).__name__,
        )

        self.state = IterationState()
        # The helper mirrors ``self.jacobian``: present only when Jacobians
        # are enabled.
        self.jac_helper = (
            JacobianHelper(
                self.jacobian,
                self.case_evaluator,
                self.parameter_mapper,
                self.workspace,
                self._param_names,
            )
            if self.jacobian is not None
            else None
        )

    def run(
        self,
        *,
        phi0: Sequence[float] | None = None,
        bounds: Sequence[tuple[float, float]] | None = None,
        verbose: bool = True,
        callbacks: Iterable[Callable[[Array, float], None]] | None = None,
    ) -> OptimizeResult:
        """Execute the optimisation loop and return the best solution.

        Returns:
            OptimizeResult with optimal φ/θ and optimiser metadata.
        """
        try:
            self.state.reset(log_progress=bool(verbose))

            phi0_vec = self.parameter_mapper.initial_phi(phi0)
            bounds_input: BoundsLike = self.parameter_mapper.bounds(bounds)
            theta0_vec = self.parameter_mapper.phi_to_theta(phi0_vec)
            self.reporter.run_started(
                phi0_vec,
                theta0_vec,
                bounds_input,
                type(self.optimizer_adapter).__name__,
                getattr(self, "_runner_jobs", None),
            )

            def residual_phi(phi_vec: Array) -> Array:
                residual = self._evaluate_residual(phi_vec)
                if self.state.pending_initial_log:
                    initial_cost = 0.5 * float(np.dot(residual, residual))
                    self._record_iteration(
                        phi_vec, initial_cost, log_output=self.state.log_progress
                    )
                    self.state.pending_initial_log = False
                return residual

            callback_list: list[Callable[[Array, float], None]] = list(callbacks or [])
            if verbose:

                def _callback(phi_vec: Array, cost: float) -> None:
                    self._record_iteration(phi_vec, cost, log_output=True)

                callback_list.append(_callback)

            jacobian_fn = None
            if self.jac_helper is not None:
                jac_helper = self.jac_helper

                def _jacobian(phi: Array) -> Array:
                    return jac_helper.compute(phi, self.state)

                jacobian_fn = _jacobian

            phi_opt, meta = self.optimizer_adapter.minimize(
                residual_phi,
                jacobian_fn,
                phi0_vec,
                bounds_input,
                callback_list,
            )

            phi_opt_array: Array = np.asarray(phi_opt, dtype=float)
            theta_opt_vec = self.parameter_mapper.phi_to_theta(phi_opt_array)
            theta_opt = self.parameter_mapper.theta_dict(theta_opt_vec)
            self.reporter.completed(
                phi_opt_array, theta_opt, meta, self.state.last_metrics or {}
            )
            try:
                self.workspace.write_series(self.state.series_latest)
            except Exception:
                self._logger.exception("Failed to write series outputs.")
            return OptimizeResult(phi=phi_opt_array, theta=theta_opt, metadata=meta)
        except KeyboardInterrupt:
            self.reporter.failed("interrupted")
            raise
        except Exception as exc:
            self.reporter.failed(f"{exc.__class__.__name__}: {exc}")
            raise
        finally:
            try:
                self.workspace.final_cleanup(self.state.last_iter_dir)
            finally:
                self.reporter.close()
                self.close()

    def _record_iteration(self, phi_vec: Array, cost: float, *, log_output: bool) -> None:
        """Forward iteration data to the reporter and advance counters."""
        theta_vec = self.state.last_theta_vec
        if theta_vec is None:
            theta_vec = self.parameter_mapper.phi_to_theta(phi_vec)
        metrics = self.state.last_metrics or {}
        series = self.state.series_latest
        self.reporter.record_iteration(
            index=self.state.progress_index,
            phi_vec=phi_vec,
            theta_vec=theta_vec,
            cost=cost,
            metrics=metrics,
            series=series,
            log_output=log_output,
        )
        self.state.next_index()
        self.workspace.prune_old_iterations(self.state.last_iter_dir)

    def _evaluate_residual(self, phi_vec: Array) -> Array:
        """Compute residuals for ``phi_vec``, reusing cached values when possible.

        Returns:
            Residual vector for the supplied φ parameters.
        """
        phi_vec = np.asarray(phi_vec, dtype=float)
        if (
            self.state.last_phi is not None
            and self.state.last_residual is not None
            and np.array_equal(phi_vec, self.state.last_phi)
        ):
            return self.state.last_residual.copy()
        theta_vec = self.parameter_mapper.phi_to_theta(phi_vec)
        theta_dict = self.parameter_mapper.theta_dict(theta_vec)
        iter_dir = self.workspace.next_iter_dir()
        result = self.case_evaluator.evaluate(
            theta_dict, iter_dir, label="_base", track_series=True
        )
        self.state.cache_evaluation(
            phi_vec,
            theta_vec,
            result.residual,
            iter_dir,
            result.metrics,
            result.series,
        )
        return result.residual

    def close(self) -> None:
        """Cleanly stop the runner."""
        try:
            self.runner.shutdown()
        except Exception:
            pass
