import sys
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from limen import (
    Constant,
    Atom,
    ChainTemplate,
    InductionConfig,
    KnowledgeBase,
    LabelSet,
    Predicate,
    TruthAssignment,
    run_induction_and_update_kb,
)


def _relation_data() -> Dict[str, List[Tuple[str, str, float]]]:
    return {
        "failedLogin": [
            ("alice", "srv3", 0.95),
            ("bob", "srv4", 0.35),
        ],
        "lateralMove": [
            ("srv3", "srv6", 0.92),
            ("srv4", "srv6", 0.25),
        ],
    }


def _build_assignment(kb: KnowledgeBase, relations: Dict[str, List[Tuple[str, str, float]]]) -> TruthAssignment:
    """Build a TruthAssignment containing every fact in *relations*.

    Predicate names and constant names are resolved through ``kb.get_predicate``
    and ``kb.get_constant``. The callers in this file register all of them on
    *kb* beforehand. Each ``(left, right, truth)`` triple becomes one binary
    Atom whose value is set to ``truth``.
    """
    result = TruthAssignment()
    for pred_name, triples in relations.items():
        # Resolve the predicate once per relation, not once per fact.
        pred_obj = kb.get_predicate(pred_name)
        for subject, obj, truth in triples:
            args = (kb.get_constant(subject), kb.get_constant(obj))
            result.set(Atom(pred_obj, args), truth)
    return result


def test_chain_template_induces_clause():
    """Chain-template induction should learn an ``escalates`` clause.

    The test registers the fixture constants and the three binary predicates on
    a fresh KnowledgeBase. It labels ``(alice, srv6)`` positive and
    ``(bob, srv6)`` negative for ``escalates``, then runs induction with a
    single ChainTemplate. It expects a clause whose body is
    ``(failedLogin, lateralMove)``.
    """
    kb = KnowledgeBase()
    relations = _relation_data()

    constant_names = set()
    for facts in relations.values():
        for fact in facts:
            constant_names.update(fact[:2])
    # srv6 is the target of both labels below; make sure it is registered.
    constant_names.add("srv6")
    for name in sorted(constant_names):
        kb.add_constant(Constant(name))

    predicates = [Predicate(name, 2) for name in ("failedLogin", "lateralMove", "escalates")]
    for predicate in predicates:
        kb.add_predicate(predicate)

    assignment = _build_assignment(kb, relations)

    escalation_labels = LabelSet(
        positives=[("alice", "srv6")],
        negatives=[("bob", "srv6")],
    )
    labels = {"escalates": escalation_labels}

    config = InductionConfig(train_steps=50, learning_rate=0.1, min_strength=0.5, min_positive_margin=0.1)
    # The induction entry point receives a plain dict here; the InductionConfig
    # instance only serves as the source of these option values.
    option_names = ("train_steps", "learning_rate", "min_strength", "min_positive_margin")
    config_options = {option: getattr(config, option) for option in option_names}

    run_induction_and_update_kb(
        kb,
        labels,
        assignment=assignment,
        templates=[ChainTemplate()],
        config=config_options,
    )

    assert kb.induced_clauses, "Expected at least one induced clause"
    escalates_bodies = {c.body for c in kb.induced_clauses if c.head == "escalates"}
    assert ("failedLogin", "lateralMove") in escalates_bodies


def test_update_labels_auto_induce_uses_assignment():
    """``update_labels(auto_induce=True)`` should run induction on the given assignment.

    The test builds the same KnowledgeBase and TruthAssignment fixture as the
    chain-template test. It passes ``escalates`` labels as raw ``pos``/``neg``
    lists to ``kb.update_labels`` and checks that at least one ``escalates``
    clause was induced.
    """
    kb = KnowledgeBase()
    relations = _relation_data()

    # srv6 is the target of both labels; seed it alongside the fixture constants.
    names = {"srv6"}
    names.update(
        name
        for facts in relations.values()
        for fact in facts
        for name in fact[:2]
    )
    for name in sorted(names):
        kb.add_constant(Constant(name))

    for predicate in [Predicate(p, 2) for p in ("failedLogin", "lateralMove", "escalates")]:
        kb.add_predicate(predicate)

    assignment = _build_assignment(kb, relations)

    escalates_labels = {"pos": [("alice", "srv6")], "neg": [("bob", "srv6")]}
    induction_options = dict(
        train_steps=40,
        learning_rate=0.1,
        min_strength=0.5,
        min_positive_margin=0.1,
    )
    kb.update_labels(
        {"escalates": escalates_labels},
        assignment=assignment,
        auto_induce=True,
        induction_config=induction_options,
    )

    escalates_clauses = [c for c in kb.induced_clauses if c.head == "escalates"]
    assert escalates_clauses, "Expected induction to run via update_labels"

