import asyncio
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Literal, Tuple

import aiofiles
from bs4 import BeautifulSoup
from langchain_text_splitters import RecursiveCharacterTextSplitter

from mineru_flow.internal.common.exceptions import PhaseExecutionError
from mineru_flow.internal.processor.base import BasePhaseProcessor, TaskContext
from mineru_flow.internal.schema.state import Phase


@dataclass
class SemanticBlock:
    """A semantically self-contained piece of content extracted from a parsed document."""

    # Text of the block.
    content: str
    # Metadata attached to the block; every instance gets its own fresh dict.
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class Chunk:
    """A final chunk, i.e. one unit destined for the vector database."""

    # Text of the chunk.
    content: str
    # Metadata attached to the chunk; every instance gets its own fresh dict.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialise this chunk into a plain dictionary.

        Returns:
            A dict with the keys ``"content"`` and ``"metadata"``. It is built
            with ``dataclasses.asdict``, so the metadata is copied recursively
            and does not alias the chunk's own dict.
        """
        serialized: Dict[str, Any] = asdict(self)
        return serialized


# Names of the chunking strategies the chunker dispatches on.
ChunkingStrategy = Literal["HYBRID", "MERGE", "RECURSIVE"]


class DefaultDocumentChunker(BasePhaseProcessor):
    """
    文档分块处理器
    """

    # Separator placed between stacked headings, and between the heading
    # prefix and the content that follows it.
    HEADING_SEPARATOR = "\n\n"

    def __init__(self) -> None:
        """Initialise the processor, passing ``Phase.CHUNK`` to the base class."""
        super().__init__(Phase.CHUNK)

    async def _prepare_blocks(
        self, context: TaskContext
    ) -> Tuple[List[Path], List[List[SemanticBlock]]]:
        """Locate every parse-phase ``result.json`` and load it into semantic blocks.

        Each immediate sub-directory of ``context.artifacts_dir / Phase.PARSE``
        that contains a ``result.json`` file is treated as one parsed document.
        All files are read concurrently.

        Args:
            context: Task context supplying ``artifacts_dir`` and ``job_id``.

        Returns:
            A pair ``(result_files, block_lists)`` of equal length, where
            ``block_lists[i]`` holds the blocks read from ``result_files[i]``.
            ``result_files`` is sorted. Individual block lists may be empty,
            but at least one of them is non-empty.

        Raises:
            PhaseExecutionError: If the parse directory is missing or is not a
                directory, if no ``result.json`` is found, or if no file
                yielded any block.
        """
        parse_dir = context.artifacts_dir / Phase.PARSE
        # is_dir() rather than exists(): a stray regular file at this path would
        # otherwise make iterdir() raise NotADirectoryError instead of the
        # phase error below.
        if not parse_dir.is_dir():
            raise PhaseExecutionError(
                "解析阶段结果目录不存在。", Phase.PARSE, context.job_id
            )

        # iterdir() yields entries in arbitrary order; sort so that the order of
        # the returned files (and everything derived from it) is reproducible.
        result_files = sorted(
            child / "result.json"
            for child in parse_dir.iterdir()
            if child.is_dir() and (child / "result.json").is_file()
        )
        if not result_files:
            raise PhaseExecutionError(
                "未找到任何解析结果 result.json。", Phase.PARSE, context.job_id
            )

        block_lists = await asyncio.gather(
            *(self._read_and_parse_result(path, context) for path in result_files)
        )

        if not any(block_lists):
            raise PhaseExecutionError(
                "所有解析文件的 content_list 均为空，无法生成语义块。",
                Phase.PARSE,
                context.job_id,
            )

        return result_files, block_lists

    async def _read_and_parse_result(
        self, result_path: Path, context: TaskContext
    ) -> List[SemanticBlock]:
        """Read one parse-phase ``result.json`` and turn it into semantic blocks.

        The file is expected to hold a JSON object with a ``content_list`` list
        and, optionally, a ``filename``. Items are processed in order:

        * A dict item of type ``"text"`` with a positive ``text_level`` and
          non-empty text is a heading. It is recorded in the heading stack
          (replacing the entry at its level and dropping all deeper levels)
          and produces no block of its own.
        * ``"text"`` and ``"equation"`` items use their ``text``; ``"image"``
          items use their caption(s); ``"table"`` items are expanded into one
          block per row by ``_parse_table_as_rows``.
        * Dict items of any other type carry no content and are skipped.
        * Non-dict items are stringified and treated as plain text.

        Every emitted block's content is prefixed with the current headings,
        ordered by level and joined with ``HEADING_SEPARATOR``.

        Block metadata consists of ``filename``, ``source`` (the directory of
        the result file) and ``item_index``, overlaid with every key of the
        item except ``text``, ``image_caption``, ``table_body`` and ``content``.

        Args:
            result_path: Path of the ``result.json`` file to read.
            context: Task context; only ``job_id`` is used here, for logging,
                and the context is forwarded to the table parser.

        Returns:
            The blocks extracted from the file. Empty when the file cannot be
            read or decoded, when its top level is not an object, or when
            ``content_list`` is not a list.
        """
        try:
            async with aiofiles.open(result_path, "r", encoding="utf-8") as f:
                content_str = await f.read()
            parsed = json.loads(content_str)
        except Exception as exc:
            # Deliberately best-effort: one unreadable file must not abort the
            # whole batch, so it is logged and skipped.
            self.logger.warning(
                "读取或解析结果文件失败，已跳过。",
                job_id=context.job_id,
                file=str(result_path),
                error=str(exc),
            )
            return []

        # Valid JSON whose top level is not an object has no content_list.
        if not isinstance(parsed, dict):
            return []
        content_list = parsed.get("content_list")
        if not isinstance(content_list, list):
            return []

        filename = parsed.get("filename") or result_path.parent.name
        base_metadata = {"filename": filename, "source": str(result_path.parent)}
        # Keys that carry the item's content and therefore stay out of metadata.
        content_keys = {"text", "image_caption", "table_body", "content"}
        # Heading level -> heading text,
        # e.g. {1: "2. Methods", 2: "2.1. Characterisation..."}
        heading_context_stack: Dict[int, str] = {}
        file_blocks: List[SemanticBlock] = []

        def heading_prefix() -> str:
            """Return the current headings, by level, each followed by the separator."""
            return "".join(
                heading_context_stack[level] + self.HEADING_SEPARATOR
                for level in sorted(heading_context_stack)
            )

        for index, item in enumerate(content_list):
            metadata: Dict[str, Any] = {**base_metadata, "item_index": index}

            if not isinstance(item, dict):
                # Bare values are treated as plain text; record that in the
                # metadata that is actually attached to the block.
                metadata["type"] = "text"
                content = str(item)
            else:
                metadata.update(
                    {k: v for k, v in item.items() if k not in content_keys}
                )
                content_type = item.get("type")
                item_text = self._text_of(item.get("text")).strip()
                # A missing, null or non-numeric level means "not a heading".
                raw_level = item.get("text_level")
                text_level = raw_level if isinstance(raw_level, (int, float)) else 0

                if content_type == "text" and text_level > 0 and item_text:
                    heading_context_stack[text_level] = item_text
                    # A new heading closes every deeper section.
                    for deeper in [k for k in heading_context_stack if k > text_level]:
                        del heading_context_stack[deeper]
                    continue

                if content_type in ("text", "equation"):
                    content = item_text
                elif content_type == "image":
                    captions = item.get("image_caption")
                    if isinstance(captions, list):
                        # Drop empty/None captions before stringifying so that
                        # None does not turn into the literal text "None".
                        content = "\n".join(str(c) for c in captions if c)
                    else:
                        content = str(captions or "")
                elif content_type == "table":
                    file_blocks.extend(
                        self._parse_table_as_rows(
                            item, metadata, context, heading_prefix()
                        )
                    )
                    continue
                else:
                    content = ""

            content = content.strip()
            if not content:
                continue

            file_blocks.append(
                SemanticBlock(content=heading_prefix() + content, metadata=metadata)
            )

        return file_blocks

    @staticmethod
    def _text_of(value: Any) -> str:
        """Return *value* as text: ``str`` unchanged, ``None`` as ``""``, else ``str(value)``."""
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    def _parse_table_as_rows(
        self,
        item: Dict[str, Any],
        base_metadata: Dict[str, Any],
        context: TaskContext,
        heading_prefix: str,
    ) -> List[SemanticBlock]:
        """Linearise an HTML table into one ``SemanticBlock`` per data row.

        The first ``<tr>`` supplies the column headers; every later ``<tr>`` is
        a data row. A row becomes
        ``Table '<caption>', Row <n>: <header>: <cell>, ... .`` where only
        pairs with both a non-empty header and a non-empty cell are listed and
        ``<n>`` is the 1-based position of the row after the header row. Rows
        whose cell count differs from the header count are logged and skipped;
        rows with no usable pair are skipped silently.

        The caption is taken, in order of preference, from the first entry of
        the item's ``table_caption``, from the item's ``text``, and finally
        from the ``filename`` in *base_metadata* (``"table"`` if absent).

        Args:
            item: The table item; ``table_body`` (HTML), ``table_caption`` and
                ``text`` are read from it.
            base_metadata: Metadata copied into every row block. The copy gets
                ``type="table_row"``, ``table_caption`` and ``row_index`` set
                and ``table_body`` / ``bbox`` removed.
            context: Task context; ``job_id`` is used for logging.
            heading_prefix: Text prepended to every row's content.

        Returns:
            One block per usable data row. Empty when there is no table body,
            when parsing fails, or when the table has fewer than two rows.
        """
        html_content = item.get("table_body", "")
        if not html_content:
            return []

        try:
            soup = BeautifulSoup(html_content, "html.parser")
            rows = soup.find_all("tr")
        except Exception as e:
            self.logger.warning(
                f"BeautifulSoup 解析表格失败: {e}",
                job_id=context.job_id,
                **base_metadata,
            )
            return []
        if len(rows) < 2:  # need at least a header row and one data row
            return []

        # Prefer the extracted caption. It may arrive as a list or as a bare
        # string; indexing a bare string would yield only its first character.
        raw_captions = item.get("table_caption")
        caption = ""
        if isinstance(raw_captions, str):
            caption = raw_captions.strip()
        elif isinstance(raw_captions, (list, tuple)) and raw_captions:
            first_caption = raw_captions[0]
            caption = "" if first_caption is None else str(first_caption).strip()
        if not caption:
            # Sometimes the caption is carried in the 'text' field instead.
            fallback_text = item.get("text")
            caption = fallback_text.strip() if isinstance(fallback_text, str) else ""
        if not caption:
            caption = base_metadata.get("filename", "table")  # last resort

        cell_tags = ["th", "td"]
        headers = [
            th.get_text(strip=True).replace("\n", " ")
            for th in rows[0].find_all(cell_tags)
        ]

        blocks: List[SemanticBlock] = []
        for row_number, row in enumerate(rows[1:], start=1):
            # Collect <th> as well as <td>, exactly as for the header row;
            # otherwise a row that starts with a <th> row header always comes
            # up one cell short and is skipped as mismatched.
            cells = [
                td.get_text(strip=True).replace("\n", " ")
                for td in row.find_all(cell_tags)
            ]

            if len(cells) != len(headers):
                self.logger.warning(
                    f"跳过不匹配的表格行: 表头 {len(headers)} 列, 此行 {len(cells)} 列",
                    job_id=context.job_id,
                    **base_metadata,
                )
                continue

            linearized_parts = [
                f"{header}: {cell}"
                for header, cell in zip(headers, cells)
                if header and cell
            ]
            if not linearized_parts:
                continue

            linearized_content = f"Table '{caption}', Row {row_number}: {', '.join(linearized_parts)}."

            row_metadata = base_metadata.copy()
            row_metadata["type"] = "table_row"
            row_metadata["table_caption"] = caption
            row_metadata["row_index"] = row_number
            row_metadata.pop("table_body", None)
            row_metadata.pop("bbox", None)

            blocks.append(
                SemanticBlock(
                    content=heading_prefix + linearized_content,
                    metadata=row_metadata,
                )
            )

        return blocks

    async def process(self, context: TaskContext) -> Dict[str, Any]:
        """Run the chunking phase for every parsed document of the task.

        Reads ``strategy``, ``chunk_overlap`` and ``chunk_size`` from
        ``context.config["chunk"]``, stores them on the instance together with
        a ``RecursiveCharacterTextSplitter`` built from them, loads the parsed
        blocks, and chunks and saves each document that has blocks.

        Args:
            context: Task context supplying the configuration, ``job_id`` and
                whatever the helper methods read from it.

        Returns:
            A dict with three keys: ``"input"`` (one entry with ``filename``,
            ``path`` and ``size`` per parse result file), ``"timestamp"`` and
            ``"output"`` (the saved chunk-file paths of the documents that had
            blocks).
        """
        chunk_cfg = context.config["chunk"]
        self.strategy = chunk_cfg["strategy"]
        self.chunk_overlap = chunk_cfg["chunk_overlap"]
        self.chunk_size = chunk_cfg["chunk_size"]

        self._fallback_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
            add_start_index=False,
        )

        files, block_lists = await self._prepare_blocks(context)

        self.logger.info(
            f"开始对 {len(block_lists)} 个文件进行分块...",
            job_id=context.job_id,
            strategy=self.strategy,
        )

        # Documents that produced no blocks are not chunked (but still appear
        # in the "input" listing below).
        pending = []
        for file, blocks in zip(files, block_lists):
            if blocks:
                pending.append(self._chunk_and_save_single_file(context, file, blocks))
        output_paths = await asyncio.gather(*pending)

        # stat() blocks, so it is run in worker threads.
        input_file_stats = await asyncio.gather(
            *(asyncio.to_thread(path.stat) for path in files)
        )

        inputs = []
        for file, stat in zip(files, input_file_stats):
            # NOTE(review): `file` is the parse result file itself, so
            # "filename" is its base name rather than the document's name.
            inputs.append(
                {
                    "filename": file.name,
                    "path": str(file),
                    "size": stat.st_size,
                }
            )

        # NOTE(review): loop.time() is the event loop's own clock, not
        # wall-clock time -- confirm that consumers of "timestamp" expect that.
        loop = asyncio.get_running_loop()
        return {
            "input": inputs,
            "timestamp": loop.time(),
            "output": output_paths,
        }

    async def _chunk_and_save_single_file(
        self, context: TaskContext, file: Path, blocks: List[SemanticBlock]
    ) -> str:
        """Chunk the blocks of one document and save the chunks as ``result.json``.

        Args:
            context: Task context, used for logging, for error reporting and
                by the directory/save helpers.
            file: The parse result file the blocks came from; the name of its
                parent directory names the output directory.
            blocks: Non-empty list of semantic blocks of the document.

        Returns:
            The path of the written chunk file, as a string.

        Raises:
            PhaseExecutionError: If ``self.strategy`` is not one of
                ``"HYBRID"``, ``"MERGE"`` or ``"RECURSIVE"``.
        """
        filename = blocks[0].metadata.get("filename", file.parent.name)
        self.logger.info(
            f"正在处理文件: {filename}，包含 {len(blocks)} 个语义块。",
            job_id=context.job_id,
        )

        strategy_handlers = (
            ("HYBRID", self._chunk_with_hybrid_strategy),
            ("MERGE", self._chunk_with_merge_strategy),
            ("RECURSIVE", self._chunk_with_recursive_strategy),
        )
        for strategy_name, handler in strategy_handlers:
            if self.strategy == strategy_name:
                chunks: List[Chunk] = handler(blocks)
                break
        else:
            raise PhaseExecutionError(
                f"未知的策略: {self.strategy}", context.current_phase, context.job_id
            )

        serialized_chunks = [chunk.to_dict() for chunk in chunks]

        file_path = self.prepare_media_dir(context, file.parent.name) / "result.json"
        await self.save_artifacts(
            context=context, result=serialized_chunks, target_path=file_path
        )

        self.logger.info(
            f"文件 '{filename}' 分块完成，生成 {len(chunks)} 个分块，已保存至 {file_path}",
            job_id=context.job_id,
        )
        return str(file_path)

    def _finalize_buffer(
        self, buffer_content_parts: list, buffer_metadata_list: list
    ) -> Chunk:
        """Merge buffered block contents and metadata into a single chunk.

        Args:
            buffer_content_parts: Block texts; joined with a blank line.
            buffer_metadata_list: Metadata dicts of the same blocks, in order.

        Returns:
            A chunk whose metadata is a copy of the first block's metadata,
            extended with ``source_item_indexes`` (the ``item_index`` of every
            block) and, when any block has a ``page_idx``, the sorted distinct
            ``source_page_indices``. Keys that only describe a single block are
            removed. With an empty metadata list the metadata is ``{}``.
        """
        merged_text = "\n\n".join(buffer_content_parts)
        if not buffer_metadata_list:
            return Chunk(content=merged_text, metadata={})

        merged_metadata = buffer_metadata_list[0].copy()
        merged_metadata["source_item_indexes"] = [
            entry.get("item_index") for entry in buffer_metadata_list
        ]

        distinct_pages = {
            page
            for page in (entry.get("page_idx") for entry in buffer_metadata_list)
            if page is not None
        }
        if distinct_pages:
            merged_metadata["source_page_indices"] = sorted(distinct_pages)

        # Drop metadata that describes one block and no longer applies to a
        # chunk merged from several.
        for single_block_key in (
            "item_index",
            "page_idx",
            "bbox",
            "text_level",
            "type",
            "table_caption",
            "row_index",
        ):
            merged_metadata.pop(single_block_key, None)

        return Chunk(content=merged_text, metadata=merged_metadata)

    def _chunk_with_hybrid_strategy(self, blocks: List[SemanticBlock]) -> List[Chunk]:
        """Strategy 1 (recommended hybrid): merge small blocks, split oversized ones.

        Consecutive blocks are buffered and merged into one chunk as long as
        the joined text (blocks separated by two characters) stays within
        ``self.chunk_size``. A block that is longer than ``self.chunk_size`` on
        its own first flushes the buffer and is then split with the fallback
        splitter; every piece becomes a chunk carrying that block's metadata.

        Args:
            blocks: Semantic blocks of one document, in document order.

        Returns:
            The resulting chunks, in document order.
        """
        self.logger.info("执行 HYBRID 策略...")
        chunks: List[Chunk] = []
        pending_parts: List[str] = []
        pending_metadata: List[Dict] = []
        pending_len = 0

        def flush() -> None:
            """Emit the pending buffer, if any, as one chunk and reset it."""
            nonlocal pending_parts, pending_metadata, pending_len
            if pending_parts:
                chunks.append(self._finalize_buffer(pending_parts, pending_metadata))
            pending_parts, pending_metadata, pending_len = [], [], 0

        for block in blocks:
            block_len = len(block.content)

            if block_len > self.chunk_size:
                flush()
                self.logger.info(
                    f"检测到超长语义块 (长度 {block_len})，进行内部分割..."
                )
                chunks.extend(
                    Chunk(content=piece, metadata=block.metadata)
                    for piece in self._fallback_splitter.split_text(block.content)
                )
                continue

            # +2 accounts for the separator that joining would add.
            overflows = pending_len + block_len + 2 > self.chunk_size
            if pending_len > 0 and overflows:
                flush()

            if pending_parts:
                pending_len += 2
            pending_parts.append(block.content)
            pending_metadata.append(block.metadata)
            pending_len += block_len

        flush()
        return chunks

    def _chunk_with_merge_strategy(self, blocks: List[SemanticBlock]) -> List[Chunk]:
        """Strategy 2: merge small blocks only; oversized blocks are never split.

        Consecutive blocks are buffered and merged into one chunk as long as
        the joined text (blocks separated by two characters) stays within
        ``self.chunk_size``. A block that is longer than ``self.chunk_size`` on
        its own first flushes the buffer and is then emitted unchanged as a
        chunk carrying its own metadata.

        Args:
            blocks: Semantic blocks of one document, in document order.

        Returns:
            The resulting chunks, in document order.
        """
        self.logger.info("执行 MERGE 策略...")
        chunks: List[Chunk] = []
        pending_parts: List[str] = []
        pending_metadata: List[Dict] = []
        pending_len = 0

        def flush() -> None:
            """Emit the pending buffer, if any, as one chunk and reset it."""
            nonlocal pending_parts, pending_metadata, pending_len
            if pending_parts:
                chunks.append(self._finalize_buffer(pending_parts, pending_metadata))
            pending_parts, pending_metadata, pending_len = [], [], 0

        for block in blocks:
            size = len(block.content)

            if size > self.chunk_size:
                flush()
                chunks.append(Chunk(content=block.content, metadata=block.metadata))
                continue

            # +2 accounts for the separator that joining would add.
            if pending_len > 0 and pending_len + size + 2 > self.chunk_size:
                flush()

            if pending_parts:
                pending_len += 2
            pending_parts.append(block.content)
            pending_metadata.append(block.metadata)
            pending_len += size

        flush()
        return chunks

    def _chunk_with_recursive_strategy(
        self, blocks: List[SemanticBlock]
    ) -> List[Chunk]:
        """Strategy 3: ignore block boundaries and split the concatenated text.

        The stripped, non-empty block contents are joined with a blank line and
        the result is split with the fallback splitter.

        Args:
            blocks: Semantic blocks of one document, in document order.

        Returns:
            One chunk per split. All chunks share a single metadata dict: a
            copy of the first block's metadata without the keys that only
            describe a single block. Empty when there is no text at all.
        """
        self.logger.info("执行 RECURSIVE 策略...")

        stripped_contents = (str(block.content).strip() for block in blocks)
        full_text = "\n\n".join(text for text in stripped_contents if text)
        if not full_text:
            return []

        split_texts = self._fallback_splitter.split_text(full_text)

        doc_metadata = blocks[0].metadata.copy() if blocks else {}
        for single_block_key in (
            "item_index",
            "page_idx",
            "bbox",
            "text_level",
            "type",
            "table_caption",
            "row_index",
        ):
            doc_metadata.pop(single_block_key, None)

        return [Chunk(content=text, metadata=doc_metadata) for text in split_texts]
