import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Protocol

from mineru_flow.internal.common.logging import get_worker_logger


@dataclass
class TaskContext:
    """Execution context handed to a phase processor for one task."""

    # Job identifier; included as ``job_id`` in log records.
    job_id: int
    # Name of the phase being run; used to name the default artifacts file
    # and the per-phase media directory.
    current_phase: str
    # Presumably the name of the phase that follows, or None when there is
    # none -- not used in this module, confirm against the scheduler.
    next_phase: Optional[str]
    # Free-form configuration mapping; its schema is not defined here.
    config: Dict[str, Any]
    # Base directory under which artifacts for this task are written.
    artifacts_dir: Path
    # Presumably the id of the worker executing the task -- TODO confirm.
    worker_id: int


class PhaseProcessor(Protocol):
    """Structural protocol for phase processors.

    Any object exposing an async ``process(context)`` method with this
    signature satisfies the protocol; no inheritance is required.
    """

    async def process(self, context: TaskContext) -> Dict[str, Any]:
        """Process the task for a specific phase and return a result mapping."""
        ...


class BasePhaseProcessor(ABC):
    """Abstract base class for phase processors.

    Subclasses must implement :meth:`process`. The ``validate_input`` and
    ``cleanup`` hooks have permissive defaults, and the helper methods write
    their output underneath ``context.artifacts_dir``.
    """

    def __init__(self, phase_name: str):
        """Remember the phase name and obtain the worker logger."""
        self.phase_name = phase_name
        self.logger = get_worker_logger()

    @abstractmethod
    async def process(self, context: TaskContext) -> Dict[str, Any]:
        """Process the task for this phase and return a result mapping."""

    async def validate_input(self, context: TaskContext) -> bool:
        """Validate the input; the default implementation accepts everything."""
        return True

    async def cleanup(self, context: TaskContext) -> None:
        """Release resources; the default implementation does nothing."""

    def prepare_media_dir(self, context: TaskContext, filename: str) -> Path:
        """Create and return the media output directory for one parsed file.

        The directory is ``<artifacts_dir>/<current_phase>/<name>``, where
        ``name`` is the stem of ``filename`` (falling back to its final path
        component, then to ``"result"``).

        Args:
            context: Task context supplying ``artifacts_dir`` and
                ``current_phase``.
            filename: Name (or path) of the file the media belongs to.

        Returns:
            The directory path, which exists when this method returns.
        """
        safe_name = Path(filename).stem or Path(filename).name or "result"
        # A final component of ".." (or ".") would make the target resolve
        # outside the phase directory, so treat it like an empty name.
        if safe_name in (".", ".."):
            safe_name = "result"

        target_dir = context.artifacts_dir / context.current_phase / safe_name
        target_dir.mkdir(parents=True, exist_ok=True)

        return target_dir

    async def save_artifacts(
        self,
        context: TaskContext,
        result: Dict[str, Any] | List[Any],
        target_path: Path | None = None,
    ) -> None:
        """Serialize ``result`` as UTF-8 JSON into the artifacts area.

        Args:
            context: Task context supplying ``artifacts_dir``,
                ``current_phase`` and ``job_id``.
            result: JSON-serializable mapping or list to persist.
            target_path: Destination file. Defaults to
                ``<artifacts_dir>/<current_phase>-artifacts.json``.

        Raises:
            TypeError: If ``result`` contains values ``json`` cannot encode.
            OSError: If the directories or the file cannot be written.
        """
        if target_path is None:
            target_path = (
                context.artifacts_dir / f"{context.current_phase}-artifacts.json"
            )

        # Make sure the artifacts directory exists.
        context.artifacts_dir.mkdir(parents=True, exist_ok=True)
        # A caller-supplied target may live in a subdirectory that has not
        # been created yet; without this, open() raises FileNotFoundError.
        Path(target_path).parent.mkdir(parents=True, exist_ok=True)

        with open(target_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)

        self.logger.info(
            "处理结果已保存到 artifacts",
            job_id=context.job_id,
            phase=context.current_phase,
            path=str(target_path),
        )
