import sys
from pathlib import Path

import torch

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from limen import Atom, FormulaNode, KnowledgeBase, Predicate, WeightedFormula
from limen.energy import compute_energy_torch


def test_compute_energy_torch_gradient_matches_weight():
    """Gradient of the torch energy w.r.t. an atom's value equals the formula weight.

    Builds a knowledge base holding a single weighted formula: a bare atom over
    the predicate ``risk`` with no arguments, weighted by ``weight``. The
    evaluator maps that atom to a leaf tensor that tracks gradients. The test
    then back-propagates through ``compute_energy_torch`` and expects
    ``d(energy)/d(value)`` to equal ``weight``.
    """
    weight = 1.7
    kb = KnowledgeBase()
    predicate = Predicate("risk", 0)
    kb.add_predicate(predicate)
    atom = Atom(predicate, ())
    kb.add_formula(WeightedFormula(FormulaNode.atom_node(atom), weight, name="risk_rule"))

    value = torch.tensor(0.6, requires_grad=True)

    def evaluator(wf):
        # Only the atom formula is tied to `value`; any other formula evaluates
        # to a constant that carries no gradient.
        if wf.formula.operator == "atom":
            return value
        return torch.tensor(0.0)

    energy = compute_energy_torch(kb, evaluator)
    energy.backward()

    # If the energy never touched `value`, autograd leaves `.grad` as None.
    # Fail with a clear message rather than a TypeError inside allclose.
    assert value.grad is not None, "energy did not propagate a gradient to the atom value"
    assert torch.allclose(value.grad, torch.tensor(weight))

