import asyncio
import json
from pathlib import Path
from typing import Any, Dict, List

import aiofiles

from mineru_flow.internal.common.exceptions import PhaseExecutionError
from mineru_flow.internal.processor.base import BasePhaseProcessor, TaskContext
from mineru_flow.internal.schema.state import Phase

from .ragflow_client import ImportDocument, RagflowClient, RagflowConfig


class KnowledgeBaseImportProcessor(BasePhaseProcessor):
    """Import-phase processor that pushes parse/chunk artifacts into knowledge bases.

    Target knowledge bases are listed in ``context.config["knowledgeBases"]``. Only
    entries whose ``type`` is ``"ragflow"`` are imported; other types pass validation
    (as long as they carry a ``type``) and are then logged and skipped.
    """

    # File name of the per-document result written by the parse and chunk phases.
    _RESULT_FILENAME = "result.json"

    def __init__(self):
        super().__init__(Phase.IMPORT)

        # Metadata ({"filename", "path", "size"}) of the artifact files read by this phase.
        self._input: List[Dict[str, Any]] = []
        # Paths of the artifact files written by this phase.
        self._output: List[str] = []

    async def validate_input(self, context: TaskContext) -> bool:
        """Validate the knowledge-base configuration and the upstream artifacts it needs.

        Returns:
            ``True`` when every check passes.

        Raises:
            PhaseExecutionError: if ``knowledgeBases`` is missing or empty; if an entry
                has no ``type``; if a ragflow entry lacks ``baseUrl``, ``apiKey`` or
                ``datasetId``; or if a ragflow entry with ``importAsChunk`` enabled has
                no parse results, no chunk-phase directory, or a parse result without a
                matching chunk result.
        """
        config = context.config

        # A non-empty list of knowledge-base configurations is mandatory.
        if "knowledgeBases" not in config or not config["knowledgeBases"]:
            self.logger.error("缺少知识库配置", job_id=context.job_id)
            raise PhaseExecutionError(
                "缺少知识库配置",
                context.current_phase,
                context.job_id,
            )

        for i, kb_config in enumerate(config["knowledgeBases"]):
            if "type" not in kb_config:
                raise PhaseExecutionError(
                    f"知识库配置 {i} 缺少 type 字段",
                    context.current_phase,
                    context.job_id,
                )

            # Only ragflow entries have further requirements.
            if kb_config["type"] != "ragflow":
                continue

            for field in ("baseUrl", "apiKey", "datasetId"):
                if not kb_config.get(field):
                    raise PhaseExecutionError(
                        f"Ragflow 配置缺少必要字段: {field}",
                        context.current_phase,
                        context.job_id,
                    )

            if not kb_config.get("importAsChunk", False):
                continue

            # Chunk import pairs every parse result with a chunk result of the same
            # document directory name, so both phases' outputs must be present.
            parse_result_files = self._list_parse_result_files(context)
            if not parse_result_files:
                raise PhaseExecutionError(
                    "分块导入需要解析阶段的结果文件",
                    context.current_phase,
                    context.job_id,
                )

            chunk_dir = context.artifacts_dir / Phase.CHUNK
            if not chunk_dir.exists():
                raise PhaseExecutionError(
                    "分块导入缺少分块阶段的结果目录",
                    context.current_phase,
                    context.job_id,
                )

            for parse_file in parse_result_files:
                if not self._chunk_file_for(chunk_dir, parse_file).is_file():
                    raise PhaseExecutionError(
                        f"解析结果 {parse_file.parent.name} 缺少对应的分块结果",
                        context.current_phase,
                        context.job_id,
                    )
        return True

    async def process(self, context: TaskContext) -> Dict[str, Any]:
        """Run the import phase for every configured knowledge base.

        Returns:
            A dict with ``input`` (metadata of the artifact files read),
            ``timestamp`` (the running event loop's clock) and ``output`` (paths of
            the artifact files written).

        Raises:
            PhaseExecutionError: if validation fails or any step of the import fails;
                the underlying exception is chained as ``__cause__``.
        """
        self.logger.info(
            "开始知识库导入", job_id=context.job_id, phase=context.current_phase
        )

        # These lists are instance attributes; reset them so a second call does not
        # report the files of a previous run.
        self._input = []
        self._output = []

        if not await self.validate_input(context):
            raise PhaseExecutionError(
                "输入验证失败", context.current_phase, context.job_id
            )

        try:
            await self._import_to_knowledge_bases(context)

            self.logger.info(
                "知识库导入完成",
                job_id=context.job_id,
                total_imported=len(self._output),
            )

            return {
                "input": self._input,
                # NOTE(review): loop.time() is a monotonic clock, not wall-clock epoch
                # time — confirm consumers of "timestamp" expect this.
                "timestamp": asyncio.get_running_loop().time(),
                "output": self._output,
            }

        except Exception as e:
            self.logger.error("知识库导入失败", job_id=context.job_id, error=str(e))
            raise PhaseExecutionError(
                f"知识库导入失败: {e}", context.current_phase, context.job_id
            ) from e

    async def _import_to_knowledge_bases(
        self,
        context: TaskContext,
    ):
        """Import into each configured knowledge base, in configuration order.

        Returns:
            A list of ``(kb_type, result)`` tuples for the entries that were handled
            (``result`` is whatever ``_import_to_ragflow`` returns, currently ``None``).
            Unsupported types are logged and skipped.
        """
        results = []

        for kb_config in context.config["knowledgeBases"]:
            kb_type = kb_config["type"]

            if kb_type != "ragflow":
                self.logger.warning(
                    f"不支持的知识库类型: {kb_type}", job_id=context.job_id
                )
                continue

            result = await self._import_to_ragflow(context, kb_config)
            results.append((kb_type, result))

        return results

    async def _import_to_ragflow(
        self,
        context: TaskContext,
        kb_config: Dict[str, Any],
    ) -> None:
        """Import parse results (and, optionally, their chunks) into one Ragflow dataset.

        Every parse result is uploaded as a document. When ``importAsChunk`` is set,
        each uploaded document then receives the chunks from its paired chunk result.
        Import responses are saved as artifacts and their paths appended to
        ``self._output``.

        Raises:
            PhaseExecutionError: if upstream artifacts are missing or malformed, or if
                Ragflow's document-import response does not line up with the inputs.
        """
        self.logger.info("开始导入到 Ragflow", job_id=context.job_id)

        ragflow_config = RagflowConfig(
            base_url=kb_config["baseUrl"],
            api_key=kb_config["apiKey"],
            dataset_id=kb_config["datasetId"],
            import_as_chunk=kb_config.get("importAsChunk", False),
        )

        parse_result_files = self._require_parse_result_files(context)
        chunk_result_files: List[Path] = []
        if ragflow_config.import_as_chunk:
            chunk_result_files = self._require_chunk_result_files(
                context, parse_result_files
            )

        await self._record_input_files(parse_result_files + chunk_result_files)

        documents: List[ImportDocument] = []
        for file_path in parse_result_files:
            documents.append(await self._load_document(context, file_path))

        async with RagflowClient(ragflow_config) as client:
            document_import_result = await client.import_documents(documents)

            # NOTE(review): every ragflow entry passes the same media-dir name and file
            # name here, so with several ragflow entries a later import presumably
            # overwrites this artifact — confirm whether per-entry names are needed.
            target_dir = self.prepare_media_dir(context, "ragflow")
            document_import_output = target_dir / "document_import_result.json"
            await self.save_artifacts(
                context=context,
                result=document_import_result,
                target_path=document_import_output,
            )
            self._output.append(str(document_import_output))

            if not ragflow_config.import_as_chunk:
                return

            # Documents were uploaded in parse-file order, and chunk files are paired
            # with parse files positionally, so the response must line up one-to-one.
            if len(document_import_result) != len(chunk_result_files):
                raise PhaseExecutionError(
                    "文档导入结果数量与分块文件数量不一致",
                    context.current_phase,
                    context.job_id,
                )

            for document_result, chunk_result_file in zip(
                document_import_result, chunk_result_files
            ):
                document_id = self._extract_document_id(context, document_result)
                chunks = await self._read_chunks(context, chunk_result_file)

                if not chunks:
                    self.logger.warning(
                        "未找到可导入的分块内容",
                        job_id=context.job_id,
                        file=str(chunk_result_file),
                        document_id=document_id,
                    )
                    continue

                chunk_import_result = await client.import_chunks(chunks, document_id)
                chunk_import_output = (
                    target_dir / f"chunk_import_result_{document_id}.json"
                )
                await self.save_artifacts(
                    context=context,
                    result=chunk_import_result,
                    target_path=chunk_import_output,
                )
                self._output.append(str(chunk_import_output))

    def _list_parse_result_files(self, context: TaskContext) -> List[Path]:
        """Return parse-phase result files sorted by document directory name ([] if none)."""
        parse_dir = context.artifacts_dir / Phase.PARSE
        if not parse_dir.exists():
            return []
        return sorted(
            (
                child / self._RESULT_FILENAME
                for child in parse_dir.iterdir()
                if child.is_dir() and (child / self._RESULT_FILENAME).is_file()
            ),
            key=lambda file_path: file_path.parent.name,
        )

    def _chunk_file_for(self, chunk_dir: Path, parse_file: Path) -> Path:
        """Return the chunk result path paired with ``parse_file`` (same document dir name)."""
        return chunk_dir / parse_file.parent.name / self._RESULT_FILENAME

    def _require_parse_result_files(self, context: TaskContext) -> List[Path]:
        """Return the sorted parse result files, raising if the directory or files are absent."""
        if not (context.artifacts_dir / Phase.PARSE).exists():
            raise PhaseExecutionError(
                "未找到解析阶段产物目录",
                context.current_phase,
                context.job_id,
            )

        parse_result_files = self._list_parse_result_files(context)
        if not parse_result_files:
            raise PhaseExecutionError(
                "未找到解析结果文件",
                context.current_phase,
                context.job_id,
            )
        return parse_result_files

    def _require_chunk_result_files(
        self, context: TaskContext, parse_result_files: List[Path]
    ) -> List[Path]:
        """Return the chunk result file for each parse file, in the same order.

        Every parse file either yields exactly one chunk file or raises, so the
        returned list always has the same length as ``parse_result_files``.
        """
        chunk_dir = context.artifacts_dir / Phase.CHUNK
        if not chunk_dir.exists():
            raise PhaseExecutionError(
                "未找到分块阶段产物目录",
                context.current_phase,
                context.job_id,
            )

        chunk_result_files: List[Path] = []
        for parse_file in parse_result_files:
            chunk_file = self._chunk_file_for(chunk_dir, parse_file)
            if not chunk_file.is_file():
                raise PhaseExecutionError(
                    f"解析结果 {parse_file.parent.name} 缺少对应的分块结果",
                    context.current_phase,
                    context.job_id,
                )
            chunk_result_files.append(chunk_file)
        return chunk_result_files

    async def _record_input_files(self, files: List[Path]) -> None:
        """Append name/path/size metadata for ``files`` to ``self._input``, skipping known paths."""
        stats = await asyncio.gather(*[asyncio.to_thread(f.stat) for f in files])

        # Several ragflow entries may read the same parse files; keep each path once.
        known_paths = {entry["path"] for entry in self._input}
        for file, stat in zip(files, stats):
            path = str(file)
            if path in known_paths:
                continue
            known_paths.add(path)
            self._input.append(
                {
                    "filename": file.name,
                    "path": path,
                    "size": stat.st_size,
                }
            )

    async def _load_document(
        self, context: TaskContext, file_path: Path
    ) -> ImportDocument:
        """Build an ``ImportDocument`` from one parse result JSON file.

        The file must hold a JSON object with a non-empty ``content``. The document
        name is the file's ``filename`` field, or its directory name when missing/empty.
        """
        async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
            data = json.loads(await f.read())

        if not isinstance(data, dict):
            raise PhaseExecutionError(
                f"解析文件 {file_path} 的内容格式不是对象（dict）",
                context.current_phase,
                context.job_id,
            )

        content = data.get("content")
        if not content:
            raise PhaseExecutionError(
                f"解析文件 {file_path} 缺少 content 字段或内容为空",
                context.current_phase,
                context.job_id,
            )

        filename = data.get("filename") or file_path.parent.name
        return ImportDocument(content=content, filename=filename)

    def _extract_document_id(self, context: TaskContext, document_result: Any) -> str:
        """Return the string ``id`` of one Ragflow document-import result entry."""
        if not isinstance(document_result, dict):
            raise PhaseExecutionError(
                "Ragflow 文档导入返回值格式异常",
                context.current_phase,
                context.job_id,
            )

        document_id = document_result.get("id")
        if not isinstance(document_id, str):
            raise PhaseExecutionError(
                "Ragflow 文档导入结果缺少 document_id",
                context.current_phase,
                context.job_id,
            )
        return document_id

    async def _read_chunks(self, context: TaskContext, chunk_result_file: Path) -> List[str]:
        """Return the non-blank ``content`` strings from a chunk result JSON list.

        Entries that are not dicts, or whose ``content`` is not a non-blank string,
        are logged and skipped; a top-level value that is not a list raises.
        """
        async with aiofiles.open(chunk_result_file, "r", encoding="utf-8") as f:
            chunk_data = json.loads(await f.read())

        if not isinstance(chunk_data, list):
            raise PhaseExecutionError(
                f"分块结果文件 {chunk_result_file} 的内容必须是列表（list）",
                context.current_phase,
                context.job_id,
            )

        chunks: List[str] = []
        for chunk_index, item in enumerate(chunk_data):
            if not isinstance(item, dict):
                self.logger.warning(
                    "跳过非字典结构的分块结果",
                    job_id=context.job_id,
                    file=str(chunk_result_file),
                    chunk_index=chunk_index,
                )
                continue

            content = item.get("content")
            if not isinstance(content, str) or not content.strip():
                self.logger.warning(
                    "跳过缺少内容的分块",
                    job_id=context.job_id,
                    file=str(chunk_result_file),
                    chunk_index=chunk_index,
                )
                continue

            chunks.append(content)
        return chunks