import math
import sys
from pathlib import Path

import torch

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from limen.core import Atom, Constant, FormulaNode, Predicate, TruthAssignment, Operator
from limen.semantics import (
    evaluate_formula,
    evaluate_formula_batch,
    evaluate_formula_torch,
    lukasiewicz_and,
    lukasiewicz_implication,
    lukasiewicz_not,
    lukasiewicz_or,
)


def test_basic_operators():
    """Check each Lukasiewicz connective against hand-computed values.

    Expected values are written as literals rather than by re-typing the
    formula, so the test does not simply mirror the implementation. Results
    are compared with a tolerance so the test does not depend on the exact
    float rounding of the implementation's arithmetic (e.g. 0.9 + 0.7 - 1.0
    is 0.6000000000000001). The extra cases pin the saturation of AND, OR
    and IMPLIES to the [0, 1] interval.
    """
    # Interior (unclamped) values.
    assert math.isclose(lukasiewicz_and(0.9, 0.7), 0.6, abs_tol=1e-9)  # 0.9 + 0.7 - 1
    assert math.isclose(lukasiewicz_or(0.2, 0.5), 0.7, abs_tol=1e-9)  # 0.2 + 0.5
    assert math.isclose(lukasiewicz_not(0.3), 0.7, abs_tol=1e-9)  # 1 - 0.3
    assert math.isclose(lukasiewicz_implication(0.8, 0.4), 0.6, abs_tol=1e-9)  # 1 - 0.8 + 0.4

    # Saturated values: a + b - 1 < 0, a + b > 1, 1 - a + b > 1.
    assert math.isclose(lukasiewicz_and(0.2, 0.3), 0.0, abs_tol=1e-9)
    assert math.isclose(lukasiewicz_or(0.6, 0.7), 1.0, abs_tol=1e-9)
    assert math.isclose(lukasiewicz_implication(0.3, 0.8), 1.0, abs_tol=1e-9)


def test_evaluate_formula_atom_and_const():
    """Atom nodes take their value from the assignment; constant nodes yield their literal."""
    suspicious_atom = Atom(
        predicate=Predicate(name="suspicious", arity=1),
        arguments=(Constant("alice"),),
    )
    truth = TruthAssignment({suspicious_atom: 0.6})

    atom_value = evaluate_formula(FormulaNode.atom_node(suspicious_atom), truth)
    assert atom_value == 0.6

    # The assignment is passed again but the constant node should not consult it.
    constant_value = evaluate_formula(FormulaNode.constant_node(0.8), truth)
    assert constant_value == 0.8


def test_evaluate_formula_composed():
    """Evaluate (failedLogin(alice) AND 0.7) -> raiseAlert(alice) and check the result."""
    alice = Constant("alice")
    failed_login = Atom(predicate=Predicate(name="failedLogin", arity=1), arguments=(alice,))
    raise_alert = Atom(predicate=Predicate(name="raiseAlert", arity=1), arguments=(alice,))

    antecedent = FormulaNode(
        operator=Operator.AND,
        children=(FormulaNode.atom_node(failed_login), FormulaNode.constant_node(0.7)),
    )
    formula = FormulaNode(
        operator=Operator.IMPLIES,
        children=(antecedent, FormulaNode.atom_node(raise_alert)),
    )
    truth = TruthAssignment({failed_login: 0.9, raise_alert: 0.3})

    # Worked by hand:
    #   antecedent = max(0, min(1, 0.9 + 0.7 - 1)) = 0.6
    #   formula    = min(1, 1 - 0.6 + 0.3)        = 0.7
    expected = 0.7
    assert abs(evaluate_formula(formula, truth) - expected) < 1e-9


def test_evaluate_formula_batch_matches_scalar():
    """Check that batch evaluation matches scalar evaluation, assignment by assignment.

    Each assignment gives the atom a distinct truth value. A batch
    implementation that reorders its results, drops entries, or reuses a
    single assignment would therefore fail. The original version passed the
    same assignment twice and never called the scalar evaluator, so it could
    not detect such bugs.
    """
    atom = Atom(predicate=Predicate(name="flag", arity=0), arguments=())
    node = FormulaNode.atom_node(atom)
    values = (0.4, 0.9, 0.1)
    assignments = [TruthAssignment({atom: value}) for value in values]

    batch = evaluate_formula_batch(node, assignments)
    scalar = [evaluate_formula(node, assignment) for assignment in assignments]

    # One result per assignment, in input order.
    assert len(batch) == len(assignments)
    for got, want, literal in zip(batch, scalar, values):
        assert math.isclose(got, want, abs_tol=1e-9)
        assert math.isclose(got, literal, abs_tol=1e-9)


def test_evaluate_formula_torch_supports_gradients():
    """Check that backprop through an atom-only formula reaches the atom's leaf tensor with grad 1."""
    sensor_atom = Atom(predicate=Predicate(name="sensor", arity=0), arguments=())
    leaf = torch.tensor(0.7, requires_grad=True)

    def lookup(_):
        # Every atom maps to the same differentiable leaf tensor.
        return leaf

    output = evaluate_formula_torch(FormulaNode.atom_node(sensor_atom), lookup)
    output.backward()

    assert torch.allclose(output, leaf)
    assert torch.allclose(leaf.grad, torch.ones_like(leaf))
