from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import asc, desc, func
from sqlalchemy.orm import Session, joinedload

from mineru_flow.internal.schema.state import State
from mineru_flow.internal.api.v1.schema.request.config import (
    ConfigCreateRequest,
    ConfigUpdateRequest,
)
from mineru_flow.internal.api.v1.schema.request.job import (
    JobCreateRequest,
    JobUpdateRequest,
)
from mineru_flow.internal.api.v1.schema.request.task import (
    TaskCreateRequest,
    TaskUpdateRequest,
)
from mineru_flow.internal.crud.base import BaseCrud
from mineru_flow.internal.models.task import Config, Job, Task


class CRUDTask(BaseCrud[Task, TaskCreateRequest, TaskUpdateRequest]):
    """CRUD helpers for ``Task`` with job-aware listing and counting."""

    def get_with_job(self, db: Session, *, id: int) -> Optional[Task]:
        """Return the task with primary key ``id`` with its ``job`` eagerly loaded.

        Unlike the list/count helpers below, this lookup does not filter on
        ``state``, so disabled tasks are returned as well.

        Returns:
            The matching ``Task``, or ``None`` if no row has that id.
        """
        return (
            db.query(self.model)
            .options(joinedload(Task.job))
            .filter(Task.id == id)
            .first()
        )

    def _apply_filters(
        self,
        query: Any,
        *,
        name: Optional[str],
        phase: Optional[str],
        status: Optional[str],
    ) -> Any:
        """Apply the name/phase/status filters shared by listing and counting."""
        if name:
            # Case-insensitive substring match on the task name.
            query = query.filter(Task.name.ilike(f"%{name}%"))

        if phase or status:
            # Join Job exactly once. Calling ``join(Job)`` separately for the
            # phase filter and the status filter would add the ``job`` table
            # to the FROM clause twice, producing an invalid query when both
            # filters are supplied.
            query = query.join(Job)
            if phase:
                query = query.filter(Job.phase == phase)
            if status:
                query = query.filter(Job.status == status)

        return query

    def get_multi_with_jobs(
        self,
        db: Session,
        *,
        skip: int = 0,
        limit: int = 100,
        name: Optional[str] = None,
        phase: Optional[str] = None,
        status: Optional[str] = None,
        sort_by: str = "created_at",
        sort_order: str = "desc",
    ) -> List[Task]:
        """Return one page of enabled tasks with their jobs eagerly loaded.

        Args:
            db: Active database session.
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            name: Optional case-insensitive substring filter on ``Task.name``.
            phase: Optional exact-match filter on the related ``Job.phase``.
            status: Optional exact-match filter on the related ``Job.status``.
            sort_by: Name of a mapped ``Task`` column attribute to order by.
                Any other value falls back to ``created_at`` descending.
            sort_order: ``"desc"`` sorts descending; any other value sorts
                ascending.

        Returns:
            The matching tasks, ordered and paginated.
        """
        from sqlalchemy import inspect as sa_inspect

        query = (
            db.query(self.model)
            .options(joinedload(Task.job))
            .filter(self.model.state == State.ENABLED)
        )
        query = self._apply_filters(query, name=name, phase=phase, status=status)

        # Only real column attributes are orderable. A plain ``hasattr`` check
        # would also accept relationships, methods and dunder names, which
        # blow up when passed to ORDER BY.
        if sort_by in sa_inspect(Task).column_attrs:
            column = getattr(Task, sort_by)
            order_by = desc(column) if sort_order == "desc" else asc(column)
        else:
            order_by = desc(Task.created_at)

        # ``Task.id`` is a tie-breaker so rows with equal sort keys keep a
        # stable order across pages (no skipped or repeated rows).
        return query.order_by(order_by, Task.id).offset(skip).limit(limit).all()

    def count_with_filters(
        self,
        db: Session,
        *,
        name: Optional[str] = None,
        phase: Optional[str] = None,
        status: Optional[str] = None,
    ) -> int:
        """Count enabled tasks matching the same filters as ``get_multi_with_jobs``.

        Args:
            db: Active database session.
            name: Optional case-insensitive substring filter on ``Task.name``.
            phase: Optional exact-match filter on the related ``Job.phase``.
            status: Optional exact-match filter on the related ``Job.status``.

        Returns:
            The number of matching tasks.
        """
        query = db.query(self.model).filter(self.model.state == State.ENABLED)
        query = self._apply_filters(query, name=name, phase=phase, status=status)
        return query.count()


class CRUDJob(BaseCrud[Job, JobCreateRequest, JobUpdateRequest]):
    """CRUD helpers for ``Job`` covering queue claiming and status/phase updates.

    Every lookup by id only considers jobs whose ``state`` is ``State.ENABLED``;
    the update methods return ``None`` when no such job exists.
    """

    def _get_enabled(self, db: Session, job_id: int) -> Optional[Job]:
        """Return the enabled job with ``job_id``, or ``None`` if absent or disabled."""
        return (
            db.query(self.model)
            .filter(Job.id == job_id, Job.state == State.ENABLED)
            .first()
        )

    @staticmethod
    def _save(db: Session, job: Job, **fields: Any) -> Job:
        """Assign ``fields`` onto ``job`` in order, commit, and reload it from the DB."""
        for attr, value in fields.items():
            setattr(job, attr, value)
        db.commit()
        db.refresh(job)
        return job

    def _update_enabled(
        self, db: Session, job_id: int, **fields: Any
    ) -> Optional[Job]:
        """Set ``fields`` on enabled job ``job_id`` and commit; ``None`` if not found."""
        job = self._get_enabled(db, job_id)
        if job is None:
            return None
        return self._save(db, job, **fields)

    def get_and_lock_waiting_job(self, db: Session) -> Optional[Job]:
        """Claim the oldest enabled job in ``WAITING`` status with a row lock.

        Uses ``SELECT ... FOR UPDATE SKIP LOCKED`` on backends that support
        it, so concurrent workers skip rows another transaction has already
        locked. The lock lasts until the caller's transaction ends, so the
        caller is expected to update and commit the returned job.

        Returns:
            The claimed job, or ``None`` if no unlocked waiting job exists.
        """
        from mineru_flow.internal.schema.state import JobStatus

        return (
            db.query(self.model)
            .filter(Job.status == JobStatus.WAITING, Job.state == State.ENABLED)
            # ``Job.id`` breaks ties between jobs created at the same instant
            # so the claim order is deterministic.
            .order_by(Job.created_at, Job.id)
            .with_for_update(skip_locked=True)
            .first()
        )

    def update_phase(self, db: Session, *, job_id: int, phase: str) -> Optional[Job]:
        """Set the job's current phase and commit."""
        return self._update_enabled(db, job_id, phase=phase)

    def update_phase_result(
        self, db: Session, *, job_id: int, phase: str, result: Dict[str, Any]
    ) -> Optional[Job]:
        """Store ``result`` under ``str(phase)`` in ``job.results`` and set the phase.

        Returns:
            The updated job, or ``None`` if no enabled job has ``job_id``.
        """
        job = self._get_enabled(db, job_id)
        if job is None:
            return None

        # Build a new dict and reassign it: in-place mutation of a plain JSON
        # column is not detected by the ORM and would not be persisted.
        updated_results = dict(job.results or {})
        updated_results[str(phase)] = result
        return self._save(db, job, phase=phase, results=updated_results)

    def update_status(self, db: Session, *, job_id: int, status: str) -> Optional[Job]:
        """Set the job's status and commit."""
        return self._update_enabled(db, job_id, status=status)

    def update_phase_and_status(
        self, db: Session, *, job_id: int, phase: str, status: str
    ) -> Optional[Job]:
        """Set the job's phase and status together in one commit."""
        return self._update_enabled(db, job_id, phase=phase, status=status)

    def mark_as_failed(
        self, db: Session, *, job_id: int, error_message: str
    ) -> Optional[Job]:
        """Mark the job ``FAILED`` and record ``error_message``."""
        from mineru_flow.internal.schema.state import JobStatus

        return self._update_enabled(
            db, job_id, status=JobStatus.FAILED, error_message=error_message
        )

    def mark_as_success(self, db: Session, *, job_id: int) -> Optional[Job]:
        """Mark the job ``SUCCESS`` and clear any previous error message."""
        from mineru_flow.internal.schema.state import JobStatus

        return self._update_enabled(
            db, job_id, status=JobStatus.SUCCESS, error_message=None
        )

    def get_status_summary(self, db: Session) -> Dict[str, int]:
        """Count enabled jobs per status.

        Returns:
            A mapping from status value to count. It contains every
            ``JobStatus`` value, with 0 for statuses that have no jobs.
        """
        from mineru_flow.internal.schema.state import JobStatus

        rows = (
            db.query(Job.status, func.count(Job.id))
            .filter(Job.state == State.ENABLED)
            .group_by(Job.status)
            .all()
        )

        summary: Dict[str, int] = {status.value: 0 for status in JobStatus}
        for status_value, count in rows:
            # The column may come back as an enum member or as a raw value
            # depending on the column type; normalise to the plain value.
            key = (
                status_value.value
                if hasattr(status_value, "value")
                else str(status_value)
            )
            summary[key] = count

        return summary

    def mark_as_running(self, db: Session, *, job_id: int) -> Optional[Job]:
        """Mark the job ``RUNNING``."""
        from mineru_flow.internal.schema.state import JobStatus

        return self._update_enabled(db, job_id, status=JobStatus.RUNNING)

    def retry_job(
        self,
        db: Session,
        *,
        job_id: int,
        restart_phase: Optional[str] = None,
        phases: Optional[List[str]] = None,
    ) -> Optional[Job]:
        """Reset a job to ``WAITING`` so it is picked up again.

        Args:
            db: Active database session.
            job_id: Id of the job to retry.
            restart_phase: Phase to resume from. Defaults to the job's current
                phase, then to the first entry of the phase list.
            phases: Replacement phase list. Defaults to the job's stored
                ``phases``. Entries are de-duplicated in order and empty
                entries are dropped; the restart phase is appended if missing.

        Stored results for the restart phase and every later phase in the
        list are discarded. If there is no restart phase, results for all
        listed phases are discarded (all results if the list is empty).

        Returns:
            The updated job, or ``None`` if no enabled job has ``job_id``.
        """
        from mineru_flow.internal.schema.state import JobStatus

        job = self._get_enabled(db, job_id)
        if job is None:
            return None

        job.status = JobStatus.WAITING
        job.error_message = None

        # Order-preserving de-duplication that also drops falsy entries.
        normalized_phases: List[str] = []
        for phase_key in phases or list(job.phases or []):
            if phase_key and phase_key not in normalized_phases:
                normalized_phases.append(phase_key)

        target_phase = restart_phase or job.phase
        if not target_phase and normalized_phases:
            target_phase = normalized_phases[0]

        if target_phase and target_phase not in normalized_phases:
            normalized_phases.append(target_phase)

        if normalized_phases:
            job.phases = normalized_phases

        if job.results is None:
            job.results = {}
        elif isinstance(job.results, dict):
            # Edit a copy and reassign it. Popping keys in place on a plain
            # JSON column is invisible to change tracking, so the commit would
            # skip it and the refresh below would reload the stale results.
            remaining = dict(job.results)
            phase_sequence = normalized_phases or list(job.phases or [])
            if target_phase and target_phase in phase_sequence:
                restart_index = phase_sequence.index(target_phase)
                for phase_key in phase_sequence[restart_index:]:
                    remaining.pop(phase_key, None)
            else:
                for phase_key in list(remaining):
                    if not normalized_phases or phase_key in normalized_phases:
                        remaining.pop(phase_key, None)
            job.results = remaining

        if target_phase:
            job.phase = target_phase

        return self._save(db, job)


class CRUDConfig(BaseCrud[Config, ConfigCreateRequest, ConfigUpdateRequest]):
    """CRUD helpers for ``Config``; lookups only consider enabled rows."""

    def get_by_type_and_name(
        self, db: Session, *, type: str, name: str
    ) -> Optional[Config]:
        """Return the enabled config with the given ``type`` and ``name``, if any."""
        return (
            db.query(self.model)
            .filter(
                Config.type == type, Config.name == name, Config.state == State.ENABLED
            )
            .first()
        )

    def get_by_type(self, db: Session, *, type: str) -> Tuple[List[Config], int]:
        """Return every enabled config of ``type`` together with their count.

        Returns:
            A ``(configs, total)`` tuple.
        """
        result = (
            db.query(self.model)
            .filter(Config.type == type, Config.state == State.ENABLED)
            .all()
        )
        # The query is not paginated, so the total is simply the number of
        # rows fetched. This avoids a second COUNT round-trip and guarantees
        # the total agrees with the rows returned.
        return result, len(result)


# Shared module-level CRUD instances, one per model.
task, job, config = CRUDTask(Task), CRUDJob(Job), CRUDConfig(Config)
