"""Sampling-based inference with explanation traces for LIMEN-AI."""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

import torch

from .core import KnowledgeBase, TruthAssignment, WeightedFormula
from .energy import compute_energy, compute_energy_torch
from .sampling import make_uniform_proposal, rule_activation_trace


@dataclass()
class SampleTrace:
    """One weighted sample plus the explanation data recorded alongside it.

    Fields are positional in the order declared below, so a trace can be
    built as ``SampleTrace(assignment, weight, activations)``.
    """

    # Truth assignment returned by the sampler's proposal.
    assignment: TruthAssignment
    # Weight the sampler attached to ``assignment``.
    weight: float
    # Activation values for ``assignment`` (as produced by
    # ``rule_activation_trace``), keyed by string.
    activations: Dict[str, float]


class ImportanceSampler:
    """Generic importance sampler that records explanation traces.

    Each sample ``x`` is drawn from ``proposal`` and weighted by the value
    ``compute_energy(kb, x)``. When ``log_proposal_prob`` is supplied, that
    value is additionally divided by the proposal probability, i.e.
    ``weight = compute_energy(kb, x) * exp(-log q(x))``; otherwise the raw
    ``compute_energy`` value is used as the weight.
    """

    def __init__(
        self,
        kb: KnowledgeBase,
        proposal: Optional[Callable[[], TruthAssignment]] = None,
        log_proposal_prob: Optional[Callable[[TruthAssignment], float]] = None,
    ) -> None:
        """Create a sampler over ``kb``.

        Args:
            kb: Knowledge base passed to ``compute_energy`` and to the
                rule-activation tracer.
            proposal: Zero-argument callable returning a fresh assignment.
                Defaults to ``make_uniform_proposal(kb)``.
            log_proposal_prob: Optional callable returning ``log q(x)`` for an
                assignment produced by ``proposal``; enables the ``1 / q(x)``
                importance correction in :meth:`draw`.
        """
        self.kb = kb
        self.proposal = (
            proposal if proposal is not None else make_uniform_proposal(kb)
        )
        self.log_proposal_prob = log_proposal_prob

    def draw(self, num_samples: int) -> List[SampleTrace]:
        """Draw ``num_samples`` weighted samples with their rule activations.

        Returns:
            One :class:`SampleTrace` per sample, in draw order. A
            non-positive ``num_samples`` yields an empty list.
        """
        traces: List[SampleTrace] = []
        for _ in range(num_samples):
            assignment = self.proposal()
            weight = compute_energy(self.kb, assignment)
            if self.log_proposal_prob is not None:
                weight = self._importance_weight(
                    weight, self.log_proposal_prob(assignment)
                )
            activations = self._rule_activations(assignment)
            traces.append(SampleTrace(assignment, weight, activations))
        return traces

    @staticmethod
    def _importance_weight(raw_weight: float, log_q: float) -> float:
        """Return ``raw_weight / q`` given ``log_q == log q``.

        Evaluated in float64 with ``math.exp``. A float32 ``exp`` (the torch
        default dtype) overflows to ``inf`` as soon as ``-log_q`` exceeds
        about 88.7, i.e. for any proposal probability below roughly 2**-128.
        Overflow beyond the float64 range saturates to ``inf`` instead of
        raising.
        """
        try:
            inv_q = math.exp(-float(log_q))
        except OverflowError:
            inv_q = math.inf
        return raw_weight * inv_q

    def estimate(
        self,
        evaluator: Callable[[TruthAssignment], float],
        num_samples: int,
    ) -> Tuple[float, List[SampleTrace]]:
        """Self-normalised importance-sampling estimate of ``evaluator``.

        Returns:
            ``(sum(w * evaluator(x)) / sum(w), traces)`` over the drawn
            samples. The estimate is ``0.0`` when the total weight is zero,
            which includes the case where no samples are drawn.
        """
        traces = self.draw(num_samples)
        numerator = sum(t.weight * evaluator(t.assignment) for t in traces)
        denominator = sum(t.weight for t in traces)
        estimate = numerator / denominator if denominator else 0.0
        return estimate, traces

    def _rule_activations(self, assignment: TruthAssignment) -> Dict[str, float]:
        """Return the rule-activation trace of ``assignment`` under ``self.kb``."""
        return rule_activation_trace(self.kb, assignment)


class PowerSampler(ImportanceSampler):
    """Importance sampler whose weights are ``compute_energy(kb, x) ** exponent``.

    No proposal-density correction is applied: the base class is initialised
    without ``log_proposal_prob``, and :meth:`draw` uses only the tempered
    energy value as the weight.
    """

    def __init__(
        self,
        kb: KnowledgeBase,
        proposal: Callable[[], TruthAssignment],
        exponent: float = 0.5,
    ) -> None:
        """Create a tempered sampler.

        Args:
            kb: Knowledge base used for energies and activation traces.
            proposal: Zero-argument callable returning a fresh assignment.
            exponent: Power applied to each sample's energy value.
        """
        super().__init__(kb, proposal)
        self.exponent = exponent

    def draw(self, num_samples: int) -> List[SampleTrace]:
        """Draw ``num_samples`` samples weighted by the tempered energy."""

        def make_trace() -> SampleTrace:
            sample = self.proposal()
            tempered = compute_energy(self.kb, sample) ** self.exponent
            return SampleTrace(sample, tempered, self._rule_activations(sample))

        return [make_trace() for _ in range(num_samples)]


class TorchEnergyWrapper(torch.nn.Module):
    """Expose the knowledge-base energy as a ``torch.nn.Module``.

    Calling the module evaluates ``compute_energy_torch`` on the stored
    knowledge base and formula evaluator, so the result can take part in
    autograd / optimiser steps.
    """

    def __init__(
        self,
        kb: KnowledgeBase,
        evaluate_formula_torch: Callable[[WeightedFormula], torch.Tensor],
    ) -> None:
        """Store ``kb`` and the per-formula torch evaluator for :meth:`forward`."""
        super().__init__()
        # Assigned in this order through nn.Module.__setattr__.
        self.kb, self.evaluate_formula_torch = kb, evaluate_formula_torch

    def forward(self) -> torch.Tensor:
        """Return the energy tensor computed by ``compute_energy_torch``."""
        energy = compute_energy_torch(self.kb, self.evaluate_formula_torch)
        return energy
