import asyncio
from dataclasses import dataclass
from typing import Any, List, Optional

from ragflow_sdk import RAGFlow


@dataclass(init=True, repr=True, eq=True)
class ImportDocument:
    """A single document queued for import into a knowledge base.

    Attributes:
        filename: Base name of the document, without a file extension.
        content: Full text body of the document.
    """

    filename: str
    content: str


@dataclass(init=True, repr=True, eq=True)
class RagflowConfig:
    """Connection settings for a Ragflow deployment.

    Attributes:
        base_url: Root URL of the Ragflow server.
        api_key: API key used to authenticate against the server.
        dataset_id: Identifier of the dataset (knowledge base) to work with.
        import_as_chunk: Flag defaulting to ``False``; presumably selects
            chunk-level import over whole-document upload -- confirm against
            the caller that reads it.
    """

    base_url: str
    api_key: str
    dataset_id: str
    import_as_chunk: bool = False


class RagflowClientError(Exception):
    """Base class of the Ragflow client's exception hierarchy.

    Catching this type also catches its more specific subclasses.
    """


class RagflowDatasetNotFoundError(RagflowClientError):
    """Raised when no dataset matches the configured dataset id."""


class RagflowDocumentNotFoundError(RagflowClientError):
    """Raised when no document matches the requested document id."""


class RagflowClient:
    """Asynchronous client for one Ragflow knowledge base (dataset).

    Every ``ragflow_sdk`` call is dispatched through ``asyncio.to_thread`` so
    that it runs in a worker thread instead of on the event loop.
    """

    def __init__(self, config: RagflowConfig):
        """Build the SDK handle; the dataset itself is resolved lazily.

        Args:
            config: Connection settings and the id of the target dataset.
        """
        self.config = config
        self.base_url = config.base_url.rstrip("/")
        # Give the SDK the normalized URL as well, so a trailing slash in the
        # configuration cannot produce "//" in the request paths it builds.
        self.ragflow = RAGFlow(api_key=config.api_key, base_url=self.base_url)
        self._dataset: Optional[Any] = None
        self._dataset_lock = asyncio.Lock()

    async def __aenter__(self):
        """Enter the async context; returns the client itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Leave the async context without suppressing any exception."""
        return False

    async def _get_dataset(self) -> Any:
        """Return the configured dataset, fetching and caching it on first use.

        Returns:
            The first dataset the SDK reports for ``config.dataset_id``.

        Raises:
            RagflowDatasetNotFoundError: If the SDK returns no dataset for
                the configured id.
        """
        # Compare against None explicitly: the cached SDK object should count
        # as "present" regardless of how its truthiness is defined.
        if self._dataset is not None:
            return self._dataset

        async with self._dataset_lock:
            # Another task may have filled the cache while we were waiting
            # for the lock.
            if self._dataset is not None:
                return self._dataset

            datasets = await asyncio.to_thread(
                self.ragflow.list_datasets, id=self.config.dataset_id
            )

            if not datasets:
                raise RagflowDatasetNotFoundError(
                    f"Ragflow: datasetId -> {self.config.dataset_id} 未找到对应数据集"
                )

            self._dataset = datasets[0]
            return self._dataset

    async def import_chunks(self, chunks: List[str], document_id: str) -> List[Any]:
        """Append text chunks to an existing document, one at a time.

        Args:
            chunks: Chunk texts to add, in order.
            document_id: Id of the document that receives the chunks.

        Returns:
            One JSON-style dict per created chunk, with the ``content`` entry
            removed. Empty when ``chunks`` is empty.

        Raises:
            RagflowDatasetNotFoundError: If the configured dataset is missing.
            RagflowDocumentNotFoundError: If ``document_id`` matches no
                document in the dataset.
        """
        dataset = await self._get_dataset()

        documents = await asyncio.to_thread(dataset.list_documents, id=document_id)
        if not documents:
            raise RagflowDocumentNotFoundError(
                f"Ragflow: documentId -> {document_id} 未找到对应文档"
            )

        doc = documents[0]

        result = []
        for chunk in chunks:
            res_chunk = await asyncio.to_thread(doc.add_chunk, content=chunk)
            chunk_info = res_chunk.to_json()
            # The caller already holds the text; drop it from the summary.
            # Use a default so a response without "content" is not an error.
            chunk_info.pop("content", None)
            result.append(chunk_info)

        return result

    async def import_documents(self, documents: List[ImportDocument]) -> List[Any]:
        """Upload documents to the dataset as UTF-8 encoded ``.md`` files.

        Args:
            documents: Documents to upload; ``.md`` is appended to each
                ``filename`` to form the display name.

        Returns:
            The JSON form of every document returned by the upload. Empty
            when ``documents`` is empty.

        Raises:
            RagflowDatasetNotFoundError: If the configured dataset is missing.
        """
        dataset = await self._get_dataset()

        payload = [
            {"display_name": f"{doc.filename}.md", "blob": doc.content.encode("utf-8")}
            for doc in documents
        ]
        if not payload:
            # Nothing to send: skip the upload request altogether.
            return []

        res_documents = await asyncio.to_thread(
            dataset.upload_documents,
            payload,
        )

        return [item.to_json() for item in res_documents]
