"""Document ingestion pipeline that feeds facts into LIMEN-AI via an LLM."""

from __future__ import annotations

import json
import re
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple

from ..core import Atom, Constant, KnowledgeBase, TruthAssignment
from .llm_client import LLMClient
from .prompts import build_extraction_prompt
from .schema import SchemaRegistry


@dataclass
class ParsedFact:
    """One fact as parsed out of an LLM completion.

    Attributes:
        predicate: Name of the predicate the fact asserts.
        args: Argument names of the fact, in positional order.
        confidence: Confidence attached to the fact; 1.0 when not supplied.
        provenance: Optional provenance value carried along with the fact.
    """

    predicate: str
    args: Tuple[str, ...]
    confidence: float = field(default=1.0)
    provenance: Optional[str] = field(default=None)


@dataclass
class IngestionResult:
    """Accumulated outcome of an ingestion run.

    Every field defaults to its own fresh empty list, so separate results
    never share state.
    """

    # Facts that were accepted.
    facts_added: List[ParsedFact] = field(default_factory=list)
    # Text of chunks for which no fact could be parsed.
    rejected_chunks: List[str] = field(default_factory=list)
    # Messages describing facts that failed validation.
    errors: List[str] = field(default_factory=list)


def _chunk_text(text: str, chunk_size: int = 800, overlap: int = 100) -> List[str]:
    """Split *text* into overlapping chunks of whitespace-separated tokens.

    Tokens are obtained with ``str.split()`` and re-joined with single
    spaces, so the original whitespace is not preserved.

    Args:
        text: The text to split.
        chunk_size: Maximum number of tokens per chunk; must be positive.
        overlap: Number of tokens shared by consecutive chunks. Negative
            values are treated as 0 so that no token is ever skipped; values
            of ``chunk_size`` or more fall back to a one-token stride.

    Returns:
        The chunks in document order; an empty list when *text* contains no
        tokens.

    Raises:
        ValueError: If *chunk_size* is not positive.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    tokens = text.split()
    # Always advance by at least one token so the loop terminates, and never
    # by more than chunk_size so no token falls between two chunks.
    step = max(chunk_size - max(overlap, 0), 1)
    chunks: List[str] = []
    for start in range(0, len(tokens), step):
        chunks.append(" ".join(tokens[start : start + chunk_size]))
        # Once a chunk reaches the final token, any further chunk would
        # contain only tokens already covered by this one.
        if start + chunk_size >= len(tokens):
            break
    return chunks


class DocumentIngestionPipeline:
    """Coordinates chunking, prompting, parsing, and KB insertion.

    Text is split into overlapping token chunks; each chunk is turned into an
    extraction prompt, sent to the LLM, and the completion is parsed into
    ``ParsedFact`` records. Facts that pass schema validation and the
    confidence threshold have their predicate and constants registered with
    the knowledge base.
    """

    def __init__(
        self,
        registry: SchemaRegistry,
        llm_client: LLMClient,
        chunk_size: int = 800,
        overlap: int = 80,
        min_confidence: float = 0.4,
    ) -> None:
        """Store the collaborators and tuning parameters.

        Args:
            registry: Schema registry used to build prompts, validate facts
                and register predicates in the knowledge base.
            llm_client: Client whose ``complete(prompt)`` result is parsed as
                the completion text.
            chunk_size: Maximum number of whitespace-separated tokens per
                chunk.
            overlap: Number of tokens shared by consecutive chunks.
            min_confidence: Facts whose confidence is below this value are
                dropped.
        """
        self.registry = registry
        self.llm = llm_client
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.min_confidence = min_confidence

    def ingest(
        self,
        text: str,
        kb: KnowledgeBase,
        *,
        assignment: Optional[TruthAssignment] = None,
    ) -> IngestionResult:
        """Extract facts from *text* and register them with *kb*.

        Args:
            text: Raw document text.
            kb: Knowledge base that receives predicates and constants.
            assignment: When given, each accepted fact's atom is set to the
                fact's confidence in this assignment.

        Returns:
            An ``IngestionResult`` listing accepted facts, the chunks that
            produced no parseable fact, and validation error messages.
            Facts below the confidence threshold are dropped without being
            reported.
        """
        result = IngestionResult()
        for chunk in _chunk_text(text, self.chunk_size, self.overlap):
            prompt = build_extraction_prompt(chunk, self.registry)
            completion = self.llm.complete(prompt)
            parsed = self._parse_completion(completion)
            if not parsed:
                result.rejected_chunks.append(chunk)
                continue
            for fact in parsed:
                ok, message = self.registry.validate_fact(fact.predicate, fact.args)
                if not ok:
                    result.errors.append(message)
                    continue
                # Written as "not >=" rather than "<" so that a NaN
                # confidence (which json.loads accepts) is rejected too:
                # every ordering comparison with NaN is False.
                if not fact.confidence >= self.min_confidence:
                    continue
                predicate = self.registry.ensure_in_kb(fact.predicate, kb)
                constants = tuple(self._ensure_constant(kb, name) for name in fact.args)
                atom = Atom(predicate, constants)
                if assignment is not None:
                    assignment.set(atom, fact.confidence)
                result.facts_added.append(fact)
        return result

    def _ensure_constant(self, kb: KnowledgeBase, name: str) -> Constant:
        """Return the constant called *name* from *kb*, adding it if absent."""
        if name in kb.constants:
            return kb.constants[name]
        constant = Constant(name)
        kb.add_constant(constant)
        return constant

    def _parse_completion(self, completion: str) -> List[ParsedFact]:
        """Parse an LLM completion into a list of facts.

        Accepted shapes are a JSON array of fact objects or a JSON object
        whose ``"facts"`` key holds such an array, optionally surrounded by
        non-JSON text. Each fact object needs a truthy ``"predicate"`` and a
        list under ``"args"`` (default ``[]``); ``"confidence"`` defaults to
        1.0 and ``"provenance"`` to ``None``.

        Malformed items (non-objects, missing predicate, non-list args,
        confidence that cannot be converted to ``float``) are skipped instead
        of raising, so one bad item cannot abort a whole ingestion run.

        Returns:
            The parsed facts; an empty list when nothing usable is found.
        """
        completion = completion.strip()
        if not completion:
            return []
        facts: List[ParsedFact] = []
        for item in self._extract_fact_items(completion):
            if not isinstance(item, dict):
                continue
            predicate = item.get("predicate")
            args: Sequence[str] = item.get("args", [])
            if not predicate or not isinstance(args, list):
                continue
            try:
                confidence = float(item.get("confidence", 1.0))
            except (TypeError, ValueError):
                # e.g. "confidence": "high" or null - the value is unusable,
                # so drop the item rather than guess a confidence.
                continue
            facts.append(
                ParsedFact(
                    predicate=predicate,
                    args=tuple(str(arg) for arg in args),
                    confidence=confidence,
                    provenance=item.get("provenance"),
                )
            )
        return facts

    @staticmethod
    def _extract_fact_items(completion: str) -> list:
        """Return the raw list of fact items found in *completion*.

        The whole completion is tried as JSON first, so an object such as
        ``{"facts": [...], "notes": [...]}`` is read correctly. If that does
        not yield a list, the span from the first ``[`` to the last ``]`` is
        tried, which copes with prose or code fences around an array.
        """
        candidates = [completion]
        bracketed = re.search(r"(\[.*\])", completion, re.DOTALL)
        if bracketed and bracketed.group(1) != completion:
            candidates.append(bracketed.group(1))
        for payload in candidates:
            try:
                data = json.loads(payload)
            except json.JSONDecodeError:
                continue
            if isinstance(data, dict) and "facts" in data:
                data = data["facts"]
            if isinstance(data, list):
                return data
        return []

