import torch

from limen import (
    Atom,
    Constant,
    FormulaNode,
    KnowledgeBase,
    Operator,
    Predicate,
    WeightedFormula,
)
from limen.training import TorchFormulaEvaluator, TruthFunctionTrainer
from limen.truth_functions import LinearFeatureTruthFunction, differentiable_truth_function


def test_truth_function_trainer_reduces_energy():
    """Training a learnable truth function should reduce the trainer's energy.

    Builds a knowledge base with a single zero-arity predicate ``Signal`` and
    one weighted rule (weight 2.0) asserting it. The predicate is backed by a
    ``LinearFeatureTruthFunction`` registered both on the knowledge base and
    with the ``TruthFunctionTrainer``. The test checks that the last value
    reported by ``trainer.train(steps=8)`` is lower than the energy measured
    before training.
    """
    kb = KnowledgeBase()
    predicate = Predicate("Signal", 0)
    kb.add_predicate(predicate)
    atom = Atom(predicate, tuple())
    kb.add_formula(WeightedFormula(FormulaNode.atom_node(atom), 2.0, name="signal_rule"))

    def feature_map(args):
        # Signal has arity 0, so every call sees the same constant feature
        # vector. The module's output can therefore only change through its
        # trainable parameters (the ones handed to the optimizer below).
        return [1.0]

    module = LinearFeatureTruthFunction(feature_map, feature_dim=1)
    kb.register_truth_function("Signal", differentiable_truth_function(module))

    # A large learning rate keeps the number of steps needed for a visible
    # energy drop small, so the test stays fast.
    optimizer = torch.optim.SGD(module.parameters(), lr=1.0)
    trainer = TruthFunctionTrainer(kb, {"Signal": module}, optimizer)

    initial_energy = trainer.current_energy()
    history = trainer.train(steps=8)

    # NOTE(review): assumes `history` is a non-empty sequence of per-step
    # energies whose last entry is the post-training energy — confirm against
    # TruthFunctionTrainer.train.
    assert history[-1] < initial_energy


def test_torch_formula_evaluator_matches_lukasiewicz_and():
    """The torch evaluator's AND must agree with the Łukasiewicz t-norm.

    Evaluates the weighted conjunction ``Score(a) AND Score(b)`` through
    ``TorchFormulaEvaluator``. It then recomputes
    ``clamp(s_a + s_b - 1, 0, 1)`` directly from the truth-function module's
    per-constant outputs and requires the two tensors to be close.
    """
    kb = KnowledgeBase()
    score = Predicate("Score", 1)
    kb.add_predicate(score)

    a, b = Constant("a"), Constant("b")
    score_a = Atom(score, (a,))
    score_b = Atom(score, (b,))

    conjunction = FormulaNode(
        operator=Operator.AND,
        children=tuple(FormulaNode.atom_node(node) for node in (score_a, score_b)),
    )
    rule = WeightedFormula(conjunction, 1.0, name="and_rule")

    def feature_map(args):
        # Constant "a" gets feature +1; every other constant gets -1.
        return [1.0 if args[0].name == "a" else -1.0]

    score_fn = LinearFeatureTruthFunction(
        feature_map=feature_map,
        feature_dim=1,
        weights=torch.tensor([2.0]),
        bias=torch.tensor([0.0]),
    )

    result = TorchFormulaEvaluator({"Score": score_fn})(rule)

    # Reference value: the Łukasiewicz conjunction max(0, s_a + s_b - 1),
    # clamped to [0, 1], computed without building an autograd graph.
    with torch.no_grad():
        total = score_fn((a,)) + score_fn((b,))
        lukasiewicz = (total - 1.0).clamp(min=0.0, max=1.0)

    assert torch.allclose(result, lukasiewicz)

