import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

from mineru_flow.internal.common.database import DatabaseManager
from mineru_flow.internal.common.exceptions import PhaseExecutionError
from mineru_flow.internal.common.logging import get_worker_logger
from mineru_flow.internal.config.worker import get_worker_config
from mineru_flow.internal.crud.task import job as job_crud
from mineru_flow.internal.processor.base import TaskContext
from mineru_flow.internal.processor.registry import ProcessorRegistry


@dataclass
class WorkerStats:
    """Mutable counters describing one worker's activity.

    The field order defines the positional constructor signature, so new
    fields should only ever be appended.
    """

    # Identifier of the worker these counters belong to.
    worker_id: int
    # Number of jobs the worker is processing right now.
    active_jobs: int = 0
    # Jobs that finished every phase successfully.
    total_processed: int = 0
    # Jobs that failed, counted over the worker's lifetime.
    total_errors: int = 0
    # Failures since the last successful job; reset to 0 on success.
    consecutive_errors: int = 0
    # Message of the most recent failure, or None if none has happened.
    last_error: Optional[str] = None


class ImprovedJobWorker:
    """Poll for waiting jobs and run each job's phases in order.

    Every phase is dispatched to the processor that ``ProcessorRegistry``
    returns for the phase name. Progress and outcomes are written back via
    ``job_crud``; per-worker counters live in ``self.stats``.
    """

    def __init__(self, worker_id: int):
        self.worker_id = worker_id
        self.config = get_worker_config()
        self.logger = get_worker_logger()
        self.stats = WorkerStats(worker_id=worker_id)
        self.processor_registry = ProcessorRegistry()
        # Loop flag for start(); cleared by stop() or after too many consecutive errors.
        self._running = False

    @asynccontextmanager
    async def _database_session(self):
        """Expose ``DatabaseManager.session_scope()`` as an async context manager.

        NOTE(review): the wrapped session scope and the ``job_crud`` calls made
        with it are synchronous, so they block the event loop while running.
        """
        with DatabaseManager.session_scope() as db:
            yield db

    async def _get_next_job(self) -> Optional[Dict[str, Any]]:
        """Fetch and lock the next waiting job.

        Returns:
            A plain-dict snapshot of the job, or ``None`` if no job is waiting.
        """
        async with self._database_session() as db:
            job = job_crud.get_and_lock_waiting_job(db)
            if job:
                # Copy the ORM object into a dict so it stays usable after the
                # session is closed (avoids session-binding problems).
                return {
                    "id": job.id,
                    "phase": job.phase,
                    "status": job.status,
                    "config": job.config,
                    "phases": job.phases,
                    "created_at": job.created_at,
                    "updated_at": job.updated_at,
                    "error_message": job.error_message,
                }
            return None

    async def _update_job_phase(self, job_id: int, phase: str) -> None:
        """Persist ``phase`` as the job's current phase."""
        async with self._database_session() as db:
            job_crud.update_phase(db=db, job_id=job_id, phase=phase)

    async def _update_phase_result(
        self, job_id: int, phase: str, result: Dict[str, Any]
    ) -> None:
        """Persist the processor result produced for ``phase``."""
        async with self._database_session() as db:
            job_crud.update_phase_result(
                db=db, job_id=job_id, phase=phase, result=result
            )

    async def _mark_job_running(self, job_id: int) -> None:
        """Mark the job as running."""
        async with self._database_session() as db:
            job_crud.mark_as_running(db=db, job_id=job_id)

    async def _mark_job_success(self, job_id: int) -> None:
        """Mark the job as successfully completed."""
        async with self._database_session() as db:
            job_crud.mark_as_success(db=db, job_id=job_id)

    async def _mark_job_failed(self, job_id: int, error_message: str) -> None:
        """Mark the job as failed with ``error_message``."""
        async with self._database_session() as db:
            job_crud.mark_as_failed(db=db, job_id=job_id, error_message=error_message)

    async def _process_job_phase(self, job_data: Dict[str, Any], context: TaskContext):
        """Run the processor for ``context.current_phase`` and store its result.

        Args:
            job_data: Job snapshot (currently unused; kept for interface stability).
            context: Execution context for the phase being run.

        Raises:
            PhaseExecutionError: If no processor is registered for the phase.
        """
        processor = self.processor_registry.get_processor(context.current_phase)
        if not processor:
            raise PhaseExecutionError(
                f"No processor registered for phase '{context.current_phase}'",
                context.current_phase,
                context.job_id,
            )

        result = await processor.process(context)

        await self._update_phase_result(
            job_id=context.job_id, phase=context.current_phase, result=result
        )

    @staticmethod
    def _locate_phase(
        phases: Optional[List[str]], phase: str, job_id: int
    ) -> int:
        """Return the position of ``phase`` in ``phases``.

        Raises:
            PhaseExecutionError: If ``phases`` is empty/None or lacks ``phase``.
        """
        # list.index raises ValueError rather than returning -1, so membership
        # must be checked explicitly to produce the intended domain error.
        if phases and phase in phases:
            return phases.index(phase)
        raise PhaseExecutionError(f"Unknown phase '{phase}'", phase, job_id)

    async def _process_job(self, job_data: Dict[str, Any]) -> None:
        """Run every remaining phase of one job and record the outcome.

        Execution starts at ``job_data["phase"]`` and continues through the
        following entries of ``job_data["phases"]``. It stops at the end of the
        list or at the first falsy entry. On success the job is marked
        successful. On any exception the job is marked failed and the error
        counters are updated. The worker stops itself once
        ``config.max_consecutive_errors`` consecutive failures are reached.

        Args:
            job_data: Job snapshot as produced by ``_get_next_job``.
        """
        # Read the id before bumping the counter so a malformed job cannot
        # leave active_jobs permanently incremented.
        job_id = job_data["id"]
        self.stats.active_jobs += 1

        try:
            await self._mark_job_running(job_id)

            artifacts_dir = Path(self.config.artifacts_dir) / str(job_id)
            artifacts_dir.mkdir(parents=True, exist_ok=True)

            current_phase = job_data["phase"]
            phases = job_data["phases"]

            # Locate the starting phase once, then advance by position. Searching
            # by name on every step would jump back to the first occurrence of a
            # duplicated phase name and loop forever.
            phase_index = (
                self._locate_phase(phases, current_phase, job_id)
                if current_phase
                else 0
            )

            while current_phase:
                next_index = phase_index + 1
                next_phase = phases[next_index] if next_index < len(phases) else None

                context = TaskContext(
                    job_id=job_id,
                    current_phase=current_phase,
                    next_phase=next_phase,
                    config=job_data["config"] or {},
                    artifacts_dir=artifacts_dir,
                    worker_id=self.worker_id,
                )

                self.logger.info(
                    "开始处理阶段",
                    job_id=job_id,
                    phase=current_phase,
                    worker_id=self.worker_id,
                )

                await self._update_job_phase(job_id, current_phase)
                await self._process_job_phase(job_data, context)

                self.logger.info(
                    "阶段处理完成",
                    job_id=job_id,
                    phase=current_phase,
                    worker_id=self.worker_id,
                )

                current_phase = next_phase
                phase_index = next_index

            # Every phase finished: record success and reset the error streak.
            await self._mark_job_success(job_id)

            self.stats.total_processed += 1
            self.stats.consecutive_errors = 0
            self.logger.info("任务处理完成", job_id=job_id, worker_id=self.worker_id)

        except Exception as e:
            error_message = str(e)
            self.stats.total_errors += 1
            self.stats.consecutive_errors += 1
            self.stats.last_error = error_message

            self.logger.error(
                "任务处理失败", job_id=job_id, worker_id=self.worker_id, error=error_message
            )

            # Evaluate the stop condition before the DB write below, so a
            # failing _mark_job_failed cannot skip it.
            if self.stats.consecutive_errors >= self.config.max_consecutive_errors:
                self.logger.error("连续错误过多，停止 worker", worker_id=self.worker_id)
                self._running = False

            await self._mark_job_failed(job_id, error_message)

        finally:
            self.stats.active_jobs -= 1

    async def start(self) -> None:
        """Run the polling loop until ``stop()`` is called or the worker stops itself.

        Sleeps ``polling_interval_ms`` milliseconds after an empty poll and
        ``error_backoff_seconds`` seconds after an unexpected exception.
        """
        self._running = True
        self.logger.info("Worker 启动", worker_id=self.worker_id)

        while self._running:
            try:
                job_data = await self._get_next_job()

                if not job_data:
                    # Nothing to do yet; wait before polling again.
                    await asyncio.sleep(self.config.polling_interval_ms / 1000)
                    continue

                await self._process_job(job_data)

            except Exception as e:
                self.logger.error(
                    "Worker 主循环异常", worker_id=self.worker_id, error=str(e)
                )
                await asyncio.sleep(self.config.error_backoff_seconds)

    def stop(self) -> None:
        """Ask the polling loop to exit after its current iteration."""
        self._running = False
        self.logger.info("Worker 停止", worker_id=self.worker_id)

    def get_stats(self) -> WorkerStats:
        """Return this worker's live (mutable) statistics object."""
        return self.stats


class WorkerManager:
    """Own a pool of ``ImprovedJobWorker`` instances and their asyncio tasks."""

    def __init__(self):
        self.config = get_worker_config()
        self.logger = get_worker_logger()
        # Workers and tasks of the most recent start_workers() run, index-aligned.
        self.workers: List[ImprovedJobWorker] = []
        self._tasks: List[asyncio.Task] = []

    async def start_workers(self) -> None:
        """Create ``config.concurrency`` workers and wait for all of them.

        Each call starts a fresh pool. Workers and tasks from a previous,
        finished run are discarded rather than re-awaited.

        Raises:
            RuntimeError: If workers from a previous call are still running.
            Exception: Any exception escaping a worker task is re-raised after
                all workers have been stopped.
        """
        if any(not task.done() for task in self._tasks):
            raise RuntimeError("Workers are already running")

        # Drop leftovers from an earlier run. Re-gathering cancelled tasks would
        # immediately raise CancelledError and orphan the new workers.
        self.workers = []
        self._tasks = []

        self.logger.info("启动 Worker 管理器", concurrency=self.config.concurrency)

        for i in range(self.config.concurrency):
            # Worker ids are 1-based.
            worker = ImprovedJobWorker(worker_id=i + 1)
            self.workers.append(worker)
            self._tasks.append(asyncio.create_task(worker.start()))

        try:
            await asyncio.gather(*self._tasks)
        except Exception as e:
            self.logger.error("Worker 管理器异常", error=str(e))
            await self.stop_workers()
            raise

    async def stop_workers(self) -> None:
        """Signal every worker to stop, cancel their tasks and wait for them."""
        self.logger.info("停止所有 Worker")

        for worker in self.workers:
            worker.stop()

        # Cancel as well: a worker may be sleeping or mid-job and would not
        # notice the cleared flag until its current await completes.
        for task in self._tasks:
            task.cancel()

        # return_exceptions=True so the expected CancelledErrors are collected
        # instead of propagating out of this method.
        await asyncio.gather(*self._tasks, return_exceptions=True)

    def get_all_stats(self) -> List[WorkerStats]:
        """Return the statistics object of every managed worker."""
        return [worker.get_stats() for worker in self.workers]


# Process-wide WorkerManager, created lazily by start_improved_workers().
# While it is None, get_worker_stats() reports that no workers exist.
_worker_manager: Optional[WorkerManager] = None


async def start_improved_workers():
    """Start the worker pool, creating the global manager on first use.

    Returns only when the manager's ``start_workers`` returns.
    """
    global _worker_manager
    manager = _worker_manager
    if manager is None:
        manager = _worker_manager = WorkerManager()
    await manager.start_workers()


def get_worker_stats() -> Dict[str, Any]:
    """Aggregate statistics across all workers of the global manager.

    Returns:
        A dict with summed ``active_jobs``, ``total_processed`` and
        ``total_errors``, the number of ``workers`` and a per-worker
        ``worker_details`` list. The same keys are returned (zeroed / empty)
        before any manager exists, so callers never hit a missing key.
    """
    all_stats = (
        _worker_manager.get_all_stats() if _worker_manager is not None else []
    )

    return {
        "active_jobs": sum(stats.active_jobs for stats in all_stats),
        "total_processed": sum(stats.total_processed for stats in all_stats),
        "total_errors": sum(stats.total_errors for stats in all_stats),
        "workers": len(all_stats),
        "worker_details": [
            {
                "worker_id": stats.worker_id,
                "active_jobs": stats.active_jobs,
                "total_processed": stats.total_processed,
                "total_errors": stats.total_errors,
                "consecutive_errors": stats.consecutive_errors,
                "last_error": stats.last_error,
            }
            for stats in all_stats
        ],
    }
