import asyncio
import base64
import io
import json
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from urllib.parse import quote

import aiohttp
from aiohttp.payload import AsyncIterablePayload

from mineru_flow.internal.storage.base import StorageFile, StorageOperator


@dataclass
class MineruParsedResult:
    """Parsed output of one document as produced by the MinerU clients."""

    # Name of the document this result belongs to.
    filename: str
    # Markdown text of the document; left as "" when no Markdown was returned.
    content: str
    # Decoded JSON content list; left as [] when none was returned.
    content_list: List[Dict[str, Any]]
    # Maps an image key (path relative to the "images/" folder) to the
    # task-file API URL under which the stored image can be fetched.
    images: Dict[str, str]
    # Optional HTML rendition of the document.
    html: Optional[str] = None
    # Optional LaTeX rendition of the document.
    latex: Optional[str] = None


TASK_FILE_API_PREFIX = "/api/v1/tasks/file?path="


def build_task_file_url(file_path: str) -> str:
    """Build the API path under which a stored task file can be fetched.

    Args:
        file_path: Location of the file; it is converted with ``str`` first,
            so path-like objects are accepted as well.

    Returns:
        ``TASK_FILE_API_PREFIX`` followed by the fully percent-encoded path
        (``/`` is encoded too, because no character is treated as safe).
    """
    encoded_path = quote(str(file_path), safe="")
    return TASK_FILE_API_PREFIX + encoded_path


class MineruClientError(Exception):
    """Raised when the MinerU client fails to complete an operation.

    Base class of the client-specific errors defined in this module.
    """


class MineruClientHTTPError(MineruClientError):
    """Error signalling that the MinerU service answered with an HTTP error.

    Attributes:
        status_code: HTTP status code of the failing response.
    """

    def __init__(self, status_code: int, message: str):
        """Store the HTTP status code and use ``message`` as the error text."""
        super().__init__(message)
        self.status_code = status_code


def prepare_result_media_dir(
    media_output_dir: Optional[Path], filename: str
) -> Path:
    """Create and return the media output directory for one parsed document.

    The directory is named after the stem of ``filename`` (falling back to
    its final path component, then to ``"result"``) and is created directly
    below ``media_output_dir``.

    Args:
        media_output_dir: Base directory for stored media.
        filename: Name of the parsed document.

    Returns:
        The created (or already existing) per-document directory.

    Raises:
        ValueError: If ``media_output_dir`` is ``None``.
    """
    if media_output_dir is None:
        raise ValueError("media_output_dir must be provided to store images")

    name_path = Path(filename)
    safe_name = name_path.stem or name_path.name
    # ".." survives both .stem and .name; joining it would make the target
    # the parent of media_output_dir, so treat it like an empty name.
    if safe_name in ("", ".", ".."):
        safe_name = "result"

    target_dir = media_output_dir / safe_name
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir


def extract_relative_image_path(image_name: str) -> Path:
    """Return a safe relative path for an image, without the ``images/`` prefix.

    The name originates from an archive entry or a service response, so it is
    sanitised before being used to build an on-disk location: a root/drive
    anchor and any ``.`` / ``..`` components are dropped, which guarantees the
    returned path is relative and cannot climb out of the directory it is
    later joined to.

    Args:
        image_name: Image name as reported by the service, e.g.
            ``"images/figure1.png"``.

    Returns:
        The relative path with a leading ``images`` component removed. A name
        that consists only of ``images`` is returned unchanged.

    Raises:
        ValueError: If no usable path component remains after sanitising.
    """
    image_path = Path(image_name)
    # The anchor ("/" or a drive) is "" for relative names, which never
    # matches a real component, so including it in the set is harmless.
    rejected = {image_path.anchor, ".", ".."}
    parts = [part for part in image_path.parts if part not in rejected]

    if parts and parts[0] == "images":
        # Keep the lone "images" component when nothing follows it.
        parts = parts[1:] or parts

    if not parts:
        raise ValueError(
            f"Image name has no usable path component: {image_name!r}"
        )
    return Path(*parts)


def store_image_file(
    image_name: str, content: bytes, media_base_dir: Path
) -> Tuple[str, str]:
    """Persist an image below ``media_base_dir`` and report where it went.

    Args:
        image_name: Image name as reported by the service.
        content: Raw image bytes to write.
        media_base_dir: Per-document media directory; the image is written
            into its ``images`` subfolder.

    Returns:
        A ``(key, stored_path)`` tuple: the image's relative path and the
        path of the written file, both in POSIX form.
    """
    relative = extract_relative_image_path(image_name)
    destination = media_base_dir / "images" / relative

    # Nested image names need their parent folders to exist first.
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(content)

    return relative.as_posix(), destination.as_posix()


def replace_markdown_image_urls(
    content: str, images: Dict[str, str]
) -> str:
    """Replace ``images/<key>`` references in Markdown with their new URLs.

    All keys are substituted in a single pass over ``content``. Compared with
    chained ``str.replace`` calls this means that text inserted by one
    replacement is never rewritten by another, and when one key is a prefix
    of another the longest key wins.

    Args:
        content: Markdown text that refers to images as ``images/<key>``.
        images: Mapping from image key to the URL that should replace it.

    Returns:
        The rewritten Markdown; ``content`` is returned unchanged when it or
        ``images`` is empty.
    """
    if not content or not images:
        return content

    import re

    # Longest keys first so the alternation prefers the most specific match.
    ordered_keys = sorted(images, key=len, reverse=True)
    pattern = re.compile(
        "images/(" + "|".join(re.escape(key) for key in ordered_keys) + ")"
    )
    return pattern.sub(lambda match: images[match.group(1)], content)


def decode_image_content(image_value: Any) -> bytes:
    """Turn an image payload into raw bytes.

    Accepted payloads:
      * ``bytes`` - returned unchanged;
      * ``str`` - either a ``data:`` URL (everything after the first comma is
        base64-decoded) or a bare base64 string;
      * ``dict`` - the first of ``content``, ``data``, ``base64``, ``b64``,
        ``value`` that is present and not ``None`` is decoded recursively.

    Raises:
        ValueError: For a ``data:`` URL without a comma, or for any payload
            that matches none of the shapes above.
    """
    if isinstance(image_value, bytes):
        return image_value

    if isinstance(image_value, str):
        encoded = image_value
        if image_value.startswith("data:"):
            try:
                _, encoded = image_value.split(",", 1)
            except ValueError as exc:
                raise ValueError("Invalid data URL format") from exc
        return base64.b64decode(encoded)

    if isinstance(image_value, dict):
        for field in ("content", "data", "base64", "b64", "value"):
            if field not in image_value:
                continue
            nested = image_value[field]
            if nested is not None:
                return decode_image_content(nested)

    raise ValueError(f"Unsupported image payload type: {type(image_value)}")


def iter_image_payload(images_payload: Any) -> List[Tuple[str, Any]]:
    """Normalise an images payload into a list of ``(name, payload)`` pairs.

    * A falsy payload, or one that is neither a dict nor a list, gives ``[]``.
    * A dict gives one pair per item, with the key converted to ``str``.
    * A list is walked in order. A dict element takes its name from the first
      truthy value among ``name``/``filename``/``path``/``key`` (falling back
      to the element's index) and its data from the first truthy value among
      ``content``/``data``/``base64``/``b64``/``value``; elements whose data
      resolves to ``None`` are skipped. Any other element is paired with its
      index as the name.
    """
    if not images_payload:
        return []

    if isinstance(images_payload, dict):
        return [(str(name), payload) for name, payload in images_payload.items()]

    if not isinstance(images_payload, list):
        return []

    name_fields = ("name", "filename", "path", "key")
    data_fields = ("content", "data", "base64", "b64", "value")

    pairs: List[Tuple[str, Any]] = []
    for position, entry in enumerate(images_payload):
        if not isinstance(entry, dict):
            pairs.append((f"{position}", entry))
            continue

        label = next(
            (found for field in name_fields if (found := entry.get(field))),
            f"{position}",
        )
        # When no field is truthy the raw "value" entry is used, which keeps
        # falsy-but-present data such as "" while None still means "absent".
        data = next(
            (found for field in data_fields if (found := entry.get(field))),
            entry.get("value"),
        )
        if data is None:
            continue
        pairs.append((str(label), data))

    return pairs


class MinerUSaasClient:
    """Client for the MinerU SaaS service."""

    # Fallback descriptions for service error codes, used when a response
    # carries a non-zero "code" but no "msg" of its own.
    ERROR_CODE_MAP = {
        "A0202": "Token error",
        "A0211": "Token expired",
        "-500": "Invalid parameters",
        "-10001": "Service exception",
        "-10002": "Invalid request parameters",
        "-60001": "Failed to generate upload URL",
        "-60002": "Failed to match a supported file format",
        "-60003": "Failed to read file",
        "-60004": "Empty file",
        "-60005": "File size exceeds limit",
        "-60006": "File page count exceeds limit",
        "-60007": "Model service temporarily unavailable",
        "-60008": "File read timed out",
        "-60009": "Task submission queue is full",
        "-60010": "Parsing failed",
        "-60011": "Failed to obtain valid file",
        "-60012": "Task not found",
        "-60013": "No permission to access task",
        "-60014": "Cannot delete a running task",
        "-60015": "File conversion failed",
        "-60016": "Failed to convert file to target format",
    }

    def __init__(
        self,
        base_url: str,
        api_key: str,
        storage_operator: StorageOperator,
        media_output_dir: Optional[str | Path] = None,
    ):
        """Configure the client; no network session is opened here.

        Args:
            base_url: Root URL of the service; trailing slashes are removed.
            api_key: Token sent as a bearer credential with API requests.
            storage_operator: Storage backend used to read the files that
                are uploaded.
            media_output_dir: Directory for extracted images. When given it
                is resolved to an absolute path; otherwise it stays ``None``.
        """
        resolved_media_dir: Optional[Path] = None
        if media_output_dir is not None:
            resolved_media_dir = Path(media_output_dir).resolve()

        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.storage_operator = storage_operator
        # Created in __aenter__.
        self.session: Optional[aiohttp.ClientSession] = None
        self.media_output_dir = resolved_media_dir

    async def __aenter__(self):
        """Open a new aiohttp session and return the client itself."""
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session, if one is open, and drop the reference.

        Clearing ``self.session`` makes a later ``parse_documents`` call fail
        on its "Client not initialized" guard instead of trying to use a
        session that has already been closed.
        """
        session, self.session = self.session, None
        if session is not None:
            await session.close()

    async def parse_documents(
        self,
        files: List[StorageFile],
        enable_formula: bool = True,
        enable_table: bool = True,
        language: str = "ch",
        model_version: str = "pipeline",
        extra_formats: Optional[List[str]] = None,
        enable_ocr: bool = False,
    ) -> List[MineruParsedResult]:
        """Parse documents: create a batch, upload the files, poll for results.

        Args:
            files: Files to parse.
            enable_formula: Forwarded as ``enable_formula`` in the batch request.
            enable_table: Forwarded as ``enable_table`` in the batch request.
            language: Forwarded as ``language`` in the batch request.
            model_version: Forwarded as ``model_version`` in the batch request.
            extra_formats: Forwarded as ``extra_formats`` (``[]`` when omitted).
            enable_ocr: Forwarded as each file's ``is_ocr`` flag.

        Returns:
            The parsed results collected while polling the batch.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            ValueError: If ``files`` is empty or no media output directory
                was configured.
            MineruClientError: If every upload failed, or a later step fails.
        """
        if not self.session:
            raise RuntimeError("Client not initialized. Use async with statement.")
        if not files:
            raise ValueError("No files to parse")
        if self.media_output_dir is None:
            raise ValueError("media_output_dir must be provided to store images")

        # Step 1: request one upload URL per file.
        batch = await self._create_batch(
            files=files,
            enable_formula=enable_formula,
            enable_table=enable_table,
            language=language,
            model_version=model_version,
            extra_formats=extra_formats,
            enable_ocr=enable_ocr,
        )

        # Step 2: upload; files that fail are reported and left out.
        uploaded = await self._upload_files(batch["file_urls"], files)
        if not uploaded:
            raise MineruClientError("All files failed to upload")

        # Step 3: poll until the uploaded files have been processed.
        return await self._poll_extract_results(batch["batch_id"], uploaded)

    async def _create_batch(
        self,
        files: List[StorageFile],
        enable_formula: bool,
        enable_table: bool,
        language: str,
        model_version: str,
        extra_formats: Optional[List[str]],
        enable_ocr: bool,
    ) -> Dict[str, Any]:
        """Create a batch task and return the ``data`` object of the response.

        ``parse_documents`` reads ``file_urls`` and ``batch_id`` from the
        returned mapping.

        Raises:
            MineruClientError: If the response ``code`` is not ``0``. The
                message uses the response's ``msg`` or, failing that, the
                ``ERROR_CODE_MAP`` entry for the code.
        """
        endpoint = f"{self.base_url}/api/v4/file-urls/batch"
        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        file_entries = [{"name": f.name, "is_ocr": enable_ocr} for f in files]
        body = {
            "enable_formula": enable_formula,
            "enable_table": enable_table,
            "language": language,
            "model_version": model_version,
            "extra_formats": extra_formats or [],
            "files": file_entries,
        }

        async with self.session.post(
            endpoint, headers=request_headers, json=body
        ) as response:
            data = await response.json()

            code = data.get("code")
            if code == 0:
                return data["data"]

            error_msg = data.get("msg") or self.ERROR_CODE_MAP.get(
                str(code), "Unknown error"
            )
            raise MineruClientError(f"Failed to create batch: {error_msg}")

    async def _upload_files(
        self, upload_urls: List[str], files: List[StorageFile]
    ) -> List[str]:
        """Upload each file to its URL and return the URLs that succeeded.

        URLs and files are paired positionally. Uploading is best effort: any
        error for one file (including a non-200 response) is printed, the
        file is recorded as failed, and the remaining files are still tried.
        """
        succeeded: List[str] = []
        failed_names: List[str] = []

        for upload_url, storage_file in zip(upload_urls, files):
            try:
                # Stream the file from storage instead of loading it whole.
                body = AsyncIterablePayload(
                    self.storage_operator.stream_read(storage_file.path),
                    size=storage_file.size,
                )

                async with self.session.put(
                    upload_url, data=body, skip_auto_headers=["Content-Type"]
                ) as response:
                    if response.status != 200:
                        detail = await response.text()
                        error_msg = (
                            f"File {storage_file.name} upload failed: HTTP {response.status}"
                        )
                        if detail:
                            error_msg += f" - {detail}"
                        raise MineruClientHTTPError(response.status, error_msg)

                    succeeded.append(upload_url)
                    print(f"File {storage_file.name} uploaded successfully")

            except Exception as e:
                print(f"Failed to upload {storage_file.name}: {e}")
                failed_names.append(storage_file.name)

        if failed_names:
            print(f"Failed to upload the following files: {', '.join(failed_names)}")

        return succeeded

    async def _poll_extract_results(
        self, batch_id: str, uploaded_files: List[str]
    ) -> List[MineruParsedResult]:
        """Poll the batch until every uploaded file has been extracted.

        The batch endpoint is queried up to 100 times, five seconds apart.
        Each item reported as ``done`` has its result archive downloaded and
        unpacked exactly once.

        Error handling:
          * Request errors, error responses and archive download/unpack
            errors are printed and retried on the next poll; the error of the
            final attempt is re-raised.
          * An item in state ``failed`` is terminal, so it raises immediately
            instead of being retried until the attempts run out.
          * A file only counts as completed after its archive was unpacked,
            so a failed download is attempted again rather than dropped.

        Args:
            batch_id: Identifier of the batch to poll.
            uploaded_files: Upload URLs that succeeded; only their count is
                used, as the number of results to wait for.

        Returns:
            The results collected so far. This may be fewer than expected if
            the attempts are exhausted, in which case a message is printed.

        Raises:
            MineruClientError: If any item failed extraction, or if the
                final attempt ends in a service error.
        """
        url = f"{self.base_url}/api/v4/extract-results/batch/{batch_id}"
        headers = {"Authorization": f"Bearer {self.api_key}"}

        max_retries = 100
        expected_count = len(uploaded_files)
        # NOTE(review): completion is tracked by file name, so two files with
        # the same name in one batch are treated as a single file.
        completed_files = set()
        results = []

        for attempt in range(1, max_retries + 1):
            try:
                async with self.session.get(url, headers=headers) as response:
                    data = await response.json()

                if data.get("code") != 0:
                    error_msg = data.get("msg") or self.ERROR_CODE_MAP.get(
                        str(data.get("code")), "Unknown error"
                    )
                    raise MineruClientError(
                        f"Failed to fetch extract results: {error_msg}"
                    )

                extract_results = data["data"]["extract_result"]
                failed_items = [
                    item for item in extract_results if item["state"] == "failed"
                ]

                if not failed_items:
                    for item in extract_results:
                        file_name = item["file_name"]
                        if item["state"] != "done" or file_name in completed_files:
                            continue
                        result = await self._extract_from_zip(
                            item["full_zip_url"], file_name
                        )
                        # Record completion only once the archive has been
                        # unpacked, so a failure here is retried next poll.
                        completed_files.add(file_name)
                        results.append(result)

            except Exception as e:
                if attempt == max_retries:
                    raise
                print(f"Retry {attempt} failed: {e}")
            else:
                # Raised outside the try block on purpose: a failed
                # extraction will not recover, so it must not be retried.
                if failed_items:
                    error_messages = [
                        str(item.get("err_msg") or "Extract failed")
                        for item in failed_items
                    ]
                    raise MineruClientError(
                        f"Extraction failed: {'; '.join(error_messages)}"
                    )

            if completed_files and len(completed_files) >= expected_count:
                break

            # No point in waiting after the last attempt.
            if attempt < max_retries:
                await asyncio.sleep(5)

        if len(completed_files) < expected_count:
            print(
                f"Maximum retries reached. Completed {len(completed_files)}/{len(uploaded_files)} files."
            )

        return results

    async def _extract_from_zip(
        self, zip_url: str, filename: str
    ) -> MineruParsedResult:
        """Download a result archive and build a ``MineruParsedResult`` from it.

        Archive members are handled by name:
          * ``images/...`` is written to the media directory and registered
            in ``images`` under its task-file URL;
          * ``*.md`` becomes ``content``;
          * ``*.json`` other than ``layout.json`` becomes ``content_list``.
            A member ending in ``content_list.json`` takes precedence over
            any other JSON member, regardless of archive order;
          * ``*.html`` becomes ``html`` and ``*.tex`` becomes ``latex``.

        Only members that are actually used are read from the archive.
        Image references in the Markdown are rewritten to the stored URLs.

        Args:
            zip_url: URL of the result archive.
            filename: Name of the source document.

        Raises:
            MineruClientHTTPError: If the download does not return HTTP 200.
        """
        async with self.session.get(zip_url) as response:
            if response.status != 200:
                raise MineruClientHTTPError(
                    response.status, f"Failed to download ZIP file: HTTP {response.status}"
                )

            zip_data = await response.read()

        result = MineruParsedResult(
            filename=filename, content="", content_list=[], images={}
        )
        media_base_dir = prepare_result_media_dir(self.media_output_dir, filename)
        # NOTE(review): the archive appears to be able to carry other JSON
        # files (e.g. model output) next to the content list - confirm
        # against the MinerU API documentation.
        found_content_list = False

        with zipfile.ZipFile(io.BytesIO(zip_data)) as zip_file:
            for file_info in zip_file.infolist():
                if file_info.is_dir():
                    continue

                name = file_info.filename
                if name.startswith("images/"):
                    key, stored_path = store_image_file(
                        name, zip_file.read(file_info), media_base_dir
                    )
                    result.images[key] = build_task_file_url(stored_path)
                elif name.endswith(".md"):
                    result.content = zip_file.read(file_info).decode("utf-8")
                elif name.endswith(".json") and name != "layout.json":
                    is_content_list = name.endswith("content_list.json")
                    # Other JSON members are only a fallback and must not
                    # overwrite a real content list that was already loaded.
                    if found_content_list and not is_content_list:
                        continue
                    result.content_list = json.loads(
                        zip_file.read(file_info).decode("utf-8")
                    )
                    found_content_list = found_content_list or is_content_list
                elif name.endswith(".html"):
                    result.html = zip_file.read(file_info).decode("utf-8")
                elif name.endswith(".tex"):
                    result.latex = zip_file.read(file_info).decode("utf-8")

        if result.images and result.content:
            result.content = replace_markdown_image_urls(result.content, result.images)

        return result


class MinerUSelfhostClient:
    """Client for a self-hosted MinerU service."""

    def __init__(
        self,
        base_url: str,
        storage_operator: StorageOperator,
        media_output_dir: Optional[str | Path] = None,
    ):
        """Configure the client; no network session is opened here.

        Args:
            base_url: Root URL of the service; trailing slashes are removed.
            storage_operator: Storage backend used to read the files that
                are sent for parsing.
            media_output_dir: Directory for returned images. When given it
                is resolved to an absolute path; otherwise it stays ``None``.
        """
        resolved_media_dir: Optional[Path] = None
        if media_output_dir is not None:
            resolved_media_dir = Path(media_output_dir).resolve()

        self.base_url = base_url.rstrip("/")
        self.storage_operator = storage_operator
        # Created in __aenter__.
        self.session: Optional[aiohttp.ClientSession] = None
        self.media_output_dir = resolved_media_dir

    async def __aenter__(self):
        """Open a new aiohttp session and return the client itself."""
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session, if one is open, and drop the reference.

        Clearing ``self.session`` makes a later ``parse_documents`` call fail
        on its "Client not initialized" guard instead of trying to use a
        session that has already been closed.
        """
        session, self.session = self.session, None
        if session is not None:
            await session.close()

    async def parse_documents(
        self,
        files: List[StorageFile],
        server_url: Optional[str] = None,
        backend_type: str = "pipeline",
        parse_method: str = "auto",
        return_images: bool = True,
        language: str = "ch",
        enable_formula: bool = True,
        enable_table: bool = True,
        out_dir: Optional[str] = None,
        return_content_list: bool = True,
        return_md: bool = True,
        start_page: Optional[int] = None,
        end_page: Optional[int] = None,
    ) -> List[MineruParsedResult]:
        """Send files to the service's ``/file_parse`` endpoint and parse the reply.

        All files are posted in one multipart form. Every non-file value is
        sent as text because aiohttp cannot serialise ``bool`` or ``None``
        values inside a multipart form (it raises ``TypeError`` when the
        request is sent).

        Args:
            files: Files to parse; each is read in full from storage.
            server_url: Sent as ``server_url`` when truthy.
            backend_type: Sent as ``backend``.
            parse_method: Sent as ``parse_method``.
            return_images: Sent as ``return_images``.
            language: Comma-separated language codes, sent as one
                ``lang_list`` field each; blank entries are skipped.
            enable_formula: Sent as ``formula_enable``.
            enable_table: Sent as ``table_enable``.
            out_dir: Sent as ``out_dir``; omitted when ``None``.
            return_content_list: Sent as ``return_content_list``.
            return_md: Sent as ``return_md``.
            start_page: Sent as ``start_page`` when not ``None``.
            end_page: Sent as ``end_page`` when not ``None``.

        Returns:
            One ``MineruParsedResult`` per entry in the response's ``results``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            ValueError: If ``files`` is empty or no media output directory
                was configured.
            MineruClientHTTPError: If the service does not answer HTTP 200.
            MineruClientError: If the response contains no results.
        """
        if not self.session:
            raise RuntimeError("Client not initialized. Use async with statement.")

        if not files:
            raise ValueError("No files to parse")
        if self.media_output_dir is None:
            raise ValueError("media_output_dir must be provided to store images")

        url = f"{self.base_url}/file_parse"

        # NOTE(review): the field names below must match the form parameters
        # of the deployed server's /file_parse handler - confirm.
        form_data = aiohttp.FormData()

        for file_info in files:
            file_data = await self.storage_operator.read(file_info.path)
            form_data.add_field("files", file_data, filename=file_info.name)

        if server_url:
            form_data.add_field("server_url", server_url)

        # One lang_list field per language; "ch,,en" must not yield "".
        for lang in (part.strip() for part in language.split(",")):
            if lang:
                form_data.add_field("lang_list", lang)

        form_data.add_field("backend", backend_type)
        form_data.add_field("parse_method", parse_method)

        # Booleans are sent as the text "true"/"false".
        # NOTE(review): assumes the server parses these textual booleans.
        flag_fields = (
            ("formula_enable", enable_formula),
            ("table_enable", enable_table),
            ("return_images", return_images),
            ("return_content_list", return_content_list),
            ("return_md", return_md),
        )
        for field_name, flag in flag_fields:
            form_data.add_field(field_name, "true" if flag else "false")

        if out_dir is not None:
            form_data.add_field("out_dir", str(out_dir))

        if start_page is not None:
            form_data.add_field("start_page", str(start_page))
        if end_page is not None:
            form_data.add_field("end_page", str(end_page))

        # Parsing can be slow, so allow up to 30 minutes for the request.
        # TODO: a local deployment could perhaps watch the out_dir directory
        # for results instead of waiting on the response.
        timeout = aiohttp.ClientTimeout(total=1800)
        async with self.session.post(url, data=form_data, timeout=timeout) as response:
            if response.status != 200:
                error_text = await response.text()
                raise MineruClientHTTPError(
                    response.status,
                    f"Parsing failed: HTTP {response.status} - {error_text}",
                )

            data = await response.json()

            if not data.get("results"):
                raise MineruClientError("Parsing result is empty")

            return self._process_results(data["results"])

    def _process_results(
        self, results_data: Dict[str, Any]
    ) -> List[MineruParsedResult]:
        """Convert the ``results`` mapping of a response into parsed results.

        For every ``filename -> item`` entry:
          * images from ``item["images"]`` are decoded, stored in the media
            directory and registered under their task-file URL;
          * ``item["content_list"]`` is used as is, or JSON-decoded when it
            is a string;
          * ``item["md_content"]`` becomes the Markdown content, with image
            references rewritten when images were stored.

        Raises:
            MineruClientError: If an image payload cannot be decoded or
                ``content_list`` is a string that is not valid JSON.
        """
        parsed: List[MineruParsedResult] = []

        for filename, result_item in results_data.items():
            result = MineruParsedResult(
                filename=filename, content="", content_list=[], images={}
            )

            # Images: decode, store on disk, remember the access URL.
            image_entries = iter_image_payload(result_item.get("images"))
            if image_entries:
                media_base_dir = prepare_result_media_dir(
                    self.media_output_dir, filename
                )
                stored_images: Dict[str, str] = {}
                for image_name, payload in image_entries:
                    try:
                        image_bytes = decode_image_content(payload)
                    except Exception as exc:
                        raise MineruClientError(
                            f"Failed to decode image payload: {image_name}"
                        ) from exc

                    stored_key, stored_path = store_image_file(
                        image_name, image_bytes, media_base_dir
                    )
                    stored_images[stored_key] = build_task_file_url(stored_path)
                result.images = stored_images

            # Content list: either already structured or a JSON string.
            raw_content_list = result_item.get("content_list")
            if raw_content_list:
                if isinstance(raw_content_list, str):
                    try:
                        result.content_list = json.loads(raw_content_list)
                    except json.JSONDecodeError as exc:
                        raise MineruClientError(
                            "content_list is not a valid JSON string"
                        ) from exc
                else:
                    result.content_list = raw_content_list

            # Markdown: point image references at the stored files.
            markdown = result_item.get("md_content")
            if markdown:
                result.content = (
                    replace_markdown_image_urls(markdown, result.images)
                    if result.images
                    else markdown
                )

            parsed.append(result)

        return parsed
