import math
import sys
from pathlib import Path

import torch

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from limen.core import Atom, Constant, FormulaNode, KnowledgeBase, Predicate, WeightedFormula
from limen.truth_functions import ConstantTruthFunction, LinearFeatureTruthFunction


def test_constant_truth_function_integration():
    """A registered ConstantTruthFunction supplies the value of its predicate's atom.

    Builds a knowledge base with one unary predicate (``failedLogin``), one
    constant (``alice``) and one weighted atomic formula. It then registers a
    constant truth function of 0.8 for the predicate and checks that the
    assignment built from the registered truth functions maps
    ``failedLogin(alice)`` to 0.8.
    """
    kb = KnowledgeBase()
    pred = Predicate(name="failedLogin", arity=1)
    user = Constant("alice")
    kb.add_predicate(pred)
    kb.add_constant(user)

    atom = Atom(predicate=pred, arguments=(user,))
    kb.add_formula(WeightedFormula(formula=FormulaNode.atom_node(atom), weight=1.0))

    kb.register_truth_function("failedLogin", ConstantTruthFunction(0.8))
    assignment = kb.build_assignment_from_truth_functions()

    value = assignment.get(atom)
    # Report a missing entry explicitly instead of failing obscurely below.
    assert value is not None, "assignment has no entry for failedLogin(alice)"
    # Compare with a tolerance: if the value round-trips through float32 (e.g. a
    # tensor converted with .item()), it will not equal the double 0.8 exactly.
    # float() accepts both a Python float and a 0-dim tensor.
    assert math.isclose(float(value), 0.8, rel_tol=1e-6)


def test_linear_truth_function_sigmoid():
    """The output of LinearFeatureTruthFunction lies in the closed interval [0, 1].

    Uses a single feature (the character count of the first argument's name),
    a weight of 0.5 and a bias of 0.1, evaluated on the grounding ``(alice,)``.
    """

    def name_length(arguments):
        # One-element feature vector: length of the first constant's name.
        first = arguments[0]
        return [float(len(first.name))]

    truth_fn = LinearFeatureTruthFunction(
        feature_map=name_length,
        weights=torch.tensor([0.5]),
        bias=torch.tensor(0.1),
    )
    grounding = (Constant("alice"),)
    score = truth_fn(grounding).item()
    assert 0.0 <= score <= 1.0
