import math
import random

import pytest

from limen import (
    Atom,
    Constant,
    ConstantTruthFunction,
    FormulaNode,
    KnowledgeBase,
    Predicate,
    TruthAssignment,
    WeightedFormula,
)
from limen.inference import ImportanceSampler, PowerSampler
from limen.sampling import (
    make_score_guided_proposal,
    make_tempered_proposal,
    make_uniform_proposal,
)


def _build_kb(value: float = 0.6, weight: float = 1.5):
    """Create a single-rule knowledge base together with a matching assignment.

    The knowledge base registers one nullary predicate ``Alert`` and a single
    weighted formula named ``alert_rule`` whose body is just the ``Alert()`` atom.

    Args:
        value: Truth value stored for the ``Alert()`` atom in the returned assignment.
        weight: Weight given to the ``alert_rule`` formula.

    Returns:
        A ``(kb, atom, assignment)`` tuple.
    """
    knowledge_base = KnowledgeBase()
    alert_predicate = Predicate("Alert", 0)
    knowledge_base.add_predicate(alert_predicate)
    alert_atom = Atom(alert_predicate, ())
    knowledge_base.add_formula(
        WeightedFormula(FormulaNode.atom_node(alert_atom), weight, name="alert_rule")
    )
    truth = TruthAssignment()
    truth.set(alert_atom, value)
    return knowledge_base, alert_atom, truth


def test_importance_sampler_traces_capture_weights_and_activations():
    """Traces carry the weighted energy and the per-formula activation of each draw."""
    kb, atom, assignment = _build_kb()
    sampler = ImportanceSampler(kb, lambda: assignment)

    estimate, traces = sampler.estimate(lambda a: a.get(atom), num_samples=4)

    # The proposal always returns the same assignment, so the estimate is that atom's value.
    # `approx` wraps the *expected* value so tolerance and failure output refer to it.
    assert estimate == pytest.approx(assignment.get(atom), rel=1e-5)
    assert len(traces) == 4
    trace = traces[0]
    assert trace.weight == pytest.approx(0.9, rel=1e-5)  # formula weight 1.5 * truth value 0.6
    assert trace.activations["alert_rule"] == pytest.approx(0.6, rel=1e-5)


def test_importance_sampler_log_proposal_adjusts_weights():
    """A supplied log proposal probability rescales each trace's weight."""
    value, formula_weight, proposal_prob = 0.4, 2.0, 0.5
    kb, _, assignment = _build_kb(value=value, weight=formula_weight)
    raw_energy = formula_weight * value  # 0.8
    sampler = ImportanceSampler(
        kb, lambda: assignment, log_proposal_prob=lambda _: math.log(proposal_prob)
    )

    traces = sampler.draw(2)

    # Expected weight is the raw energy divided by the proposal probability
    # (0.8 / 0.5 = 1.6). Named explicitly so the factor is not confused with the
    # formula weight, which happens to also be 2.0.
    assert traces[0].weight == pytest.approx(raw_energy / proposal_prob, rel=1e-6)


def test_power_sampler_applies_exponent_to_weights():
    """PowerSampler raises each trace's weighted energy to the configured exponent."""
    value, formula_weight, exponent = 0.7, 1.0, 0.25
    kb, atom, assignment = _build_kb(value=value, weight=formula_weight)
    sampler = PowerSampler(kb, lambda: assignment, exponent=exponent)

    traces = sampler.draw(3)

    # Energy is formula weight times truth value (1.0 * 0.7), then raised to the exponent.
    expected_weight = (formula_weight * value) ** exponent
    assert traces[0].weight == pytest.approx(expected_weight, rel=1e-6)
    # The activation is expected to be the raw truth value, not exponentiated.
    assert traces[0].activations["alert_rule"] == pytest.approx(assignment.get(atom), rel=1e-6)


def test_importance_sampler_defaults_to_uniform_proposal():
    """A sampler built from only a knowledge base still draws values in [0, 1]."""
    kb, atom, _ = _build_kb()

    first_trace = ImportanceSampler(kb).draw(1)[0]
    drawn = first_trace.assignment.get(atom)

    assert 0.0 <= drawn <= 1.0


def test_uniform_proposal_stays_within_bounds():
    """A seeded uniform proposal yields an atom value inside the unit interval."""
    kb, atom, _ = _build_kb()
    seeded_rng = random.Random(0)
    draw_uniform = make_uniform_proposal(kb, rng=seeded_rng)

    drawn = draw_uniform().get(atom)

    assert 0.0 <= drawn <= 1.0


def test_tempered_proposal_tracks_reference():
    """With temperature 0.0 the sample is expected to equal the reference value."""
    reference_value = 0.8
    kb, atom, assignment = _build_kb(value=reference_value)
    proposal = make_tempered_proposal(
        kb, reference=assignment, temperature=0.0, rng=random.Random(1)
    )

    sample = proposal()

    assert sample.get(atom) == pytest.approx(reference_value, rel=1e-9)


def test_score_guided_proposal_uses_truth_functions():
    """A registered constant truth function determines the zero-temperature sample."""
    truth_value = 0.95
    kb, atom, _ = _build_kb()
    kb.register_truth_function("Alert", ConstantTruthFunction(truth_value))
    proposal = make_score_guided_proposal(kb, temperature=0.0, rng=random.Random(2))

    sample = proposal()

    assert sample.get(atom) == pytest.approx(truth_value, rel=1e-9)

