import json
import mimetypes
import shutil
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import uuid4

from fastapi import APIRouter, Body, Depends, HTTPException, Query, status

from mineru_flow.internal.api.v1.schema.request.task import (
    TaskCreateRequest,
    TaskRetryRequest,
    TaskUpdateRequest,
)
from mineru_flow.internal.api.v1.schema.response.task import (
    TaskCreateResponse,
    TaskDetailResponse,
    TaskDeleteResponse,
    TaskSummaryResponse,
)
from mineru_flow.internal.common.database import get_db_context
from mineru_flow.internal.common.exceptions import DatabaseError, ValidationError
from mineru_flow.internal.common.logging import get_api_logger
from mineru_flow.internal.config.worker import get_worker_config
from mineru_flow.internal.crud.task import job as job_crud
from mineru_flow.internal.crud.task import task as task_crud
from mineru_flow.internal.models import Task
from mineru_flow.internal.schema.state import JobStatus, Phase
from mineru_flow.internal.storage.s3 import S3StorageOperator

# Router holding every task endpoint in this module; routes are grouped under the "tasks" tag.
router = APIRouter(tags=["tasks"])


def get_config():
    """FastAPI dependency returning the worker configuration."""
    worker_config = get_worker_config()
    return worker_config


def get_logger():
    """FastAPI dependency returning the API logger."""
    api_logger = get_api_logger()
    return api_logger


class TaskService:
    """Business helpers shared by the task endpoints, kept out of the route handlers."""

    def __init__(self, config, logger):
        """Store the worker config and logger.

        Args:
            config: Worker configuration; ``config.artifacts_dir`` is read here.
            logger: Structured logger; callers pass context as keyword arguments.
        """
        self.config = config
        self.logger = logger
        self.artifacts_dir = Path(config.artifacts_dir)

    def format_timestamp(self, dt) -> Optional[int]:
        """Convert a datetime to integer milliseconds since the epoch.

        Returns None when ``dt`` is None.
        NOTE(review): ``datetime.timestamp()`` interprets naive datetimes in the
        server's local timezone; confirm how the database stores them.
        """
        if dt is None:
            return None
        return int(dt.timestamp() * 1000)

    async def format_task_response(self, task: "Task") -> Dict[str, Any]:
        """Serialize a task and its job (if any) into the API response dict.

        Timestamps are converted with ``format_timestamp``; ``"job"`` is None
        when the task has no associated job. Declared ``async`` because callers
        ``await`` it, although the body itself contains no ``await``.
        """
        # "Task" is a forward-reference string: the model is only needed by type checkers.
        job = task.job
        return {
            "id": task.id,
            "name": task.name,
            "created_at": self.format_timestamp(task.created_at),
            "updated_at": self.format_timestamp(task.updated_at),
            "job": (
                {
                    "id": job.id,
                    "errorMessage": job.error_message,
                    "phase": job.phase,
                    "phases": job.phases,
                    "status": job.status,
                    "created_at": self.format_timestamp(job.created_at),
                    "updated_at": self.format_timestamp(job.updated_at),
                    "config": job.config,
                    "results": job.results,
                }
                if job
                else None
            ),
        }

    def generate_task_name(self, file_path: str) -> str:
        """Build a unique task name ``<8 hex chars>-<basename of file_path>``.

        Falls back to "untitled" when ``file_path`` is empty or has no basename
        (e.g. "/"), so a name never ends with a bare dash.
        """
        filename = Path(file_path).name if file_path else ""
        # First 8 hex digits of a random UUID (same as its first dash-separated group).
        unique_prefix = uuid4().hex[:8]
        return f"{unique_prefix}-{filename or 'untitled'}"

    def extract_phases_from_config(
        self, config: Optional[Dict[str, Any]]
    ) -> List[str]:
        """Derive the ordered processing-phase list from a job config.

        * "parse"  if a "mineru" key is present (even with an empty value);
        * "chunk"  if ``config["chunk"]`` is truthy;
        * "import" if ``config["knowledgeBases"]`` is truthy.

        Returns an empty list for an empty or None config.
        """
        if not config:
            return []

        phases: List[str] = []
        if "mineru" in config:
            phases.append("parse")
        if config.get("chunk"):
            phases.append("chunk")
        if config.get("knowledgeBases"):
            phases.append("import")
        return phases

    def get_task_summary(self) -> Dict[str, Any]:
        """Return ``{"total": int, "statusCounts": {...}}`` built from job status counts."""
        with get_db_context() as db:
            status_counts = job_crud.get_status_summary(db)

        return {"total": sum(status_counts.values()), "statusCounts": status_counts}


def get_task_service(
    config=Depends(get_config), logger=Depends(get_logger)
) -> TaskService:
    """FastAPI dependency building a TaskService from the injected config and logger.

    A new TaskService instance is constructed on every call.
    """
    service = TaskService(config, logger)
    return service


@router.post(
    "/tasks", response_model=TaskCreateResponse, status_code=status.HTTP_201_CREATED
)
async def create_task(
    request: TaskCreateRequest, task_service: TaskService = Depends(get_task_service)
):
    """Create a new task together with its job in the WAITING state.

    The job starts at ``Phase.PARSE`` with its phase list derived from the
    request config. The task is named after the basename of
    ``config.source.storagePath``.

    Raises:
        ValidationError: the config is empty, or ``config.source.storagePath``
            is missing or ``config.source`` is not an object.
        DatabaseError: persisting the job or task failed.
    """
    if not request.config:
        raise ValidationError("Configuration cannot be empty", "config")

    source_config = request.config.get("source")
    # A non-dict "source" used to crash on .get() with an uncaught AttributeError.
    if not isinstance(source_config, dict) or not source_config.get("storagePath"):
        raise ValidationError(
            "config.source.storagePath must be provided",
            "config.source.storagePath",
        )
    storage_path = source_config["storagePath"]

    # Never log request.dict(): config.source carries S3 access/secret keys.
    task_service.logger.info(
        "创建新任务",
        storage_path=storage_path,
        config_sections=sorted(request.config.keys()),
    )

    name = task_service.generate_task_name(storage_path)

    try:
        with get_db_context() as db:
            job_data = {
                "phase": Phase.PARSE,
                "status": JobStatus.WAITING,
                "config": request.config,
                "phases": task_service.extract_phases_from_config(request.config),
            }
            job = job_crud.create(db=db, obj_in=job_data)

            # Create the task row pointing at the new job.
            task_data = {"name": name, "job_id": job.id}
            task = task_crud.create(db=db, obj_in=task_data)

            task_service.logger.info("任务创建成功", task_id=task.id, name=name)

            return TaskCreateResponse(
                data={"id": task.id}, message="任务创建成功，等待处理"
            )

    except Exception as e:
        task_service.logger.error("Failed to create task", error=str(e))
        raise DatabaseError("Failed to create task", "create_task") from e


@router.get("/tasks/{id}/s3")
async def get_task_s3_file(
    path: str,
    id: int,
    task_service: TaskService = Depends(get_task_service),
):
    """Return a presigned URL for an S3 object, signed with the task's source credentials.

    Args:
        path: S3 object path; must start with "s3".
        id: Task id whose ``job.config["source"]`` supplies the access key,
            secret key and endpoint.

    Raises:
        HTTPException 400: ``path`` is not an S3 path, or the task has no S3
            source configuration.
        HTTPException 404: the task does not exist.
        DatabaseError: any other failure while loading the task or presigning.
    """
    if not path.startswith("s3"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Not s3 path",
        )

    try:
        with get_db_context() as db:
            task = task_crud.get_with_job(db=db, id=id)

            if not task:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
                )

            # Read the config directly: a missing job/config/source previously
            # raised TypeError here and surfaced as a generic DatabaseError.
            job_config = task.job.config if task.job else None
            s3_config = (
                job_config.get("source") if isinstance(job_config, dict) else None
            )
            if not isinstance(s3_config, dict):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Task has no S3 source configuration",
                )

            s3_storage = S3StorageOperator(
                aws_access_key_id=s3_config.get("accessKeyId"),
                aws_secret_access_key=s3_config.get("secretKeyId"),
                endpoint=s3_config.get("endpoint"),
            )

            return await s3_storage.presigned(path)

    except HTTPException:
        raise
    except Exception as e:
        task_service.logger.error("Failed to retrieve s3 file", task_id=id, error=str(e))
        raise DatabaseError("Failed to retrieve S3 file", "get_task_s3_file") from e


def _is_path_within(candidate: Path, root: Path) -> bool:
    """Return True if ``candidate`` is ``root`` itself or lies beneath it.

    Both arguments are expected to be already resolved (absolute, symlinks followed).
    """
    try:
        candidate.relative_to(root)
    except ValueError:
        return False
    return True


@router.get("/tasks/{id}/file")
async def get_task_file(
    path: str,
    id: int,
    task_service: TaskService = Depends(get_task_service),
):
    """Serve a task file; everything except JSON is served inline for the browser.

    * S3 paths (starting with "s3") return a presigned URL signed with the
      task's ``config.source`` credentials.
    * Local ``.json`` files return ``{"data", "size", "modified"}``: the parsed
      content, the byte size and the mtime in milliseconds.
    * Other local files are returned as an inline ``FileResponse``.

    ``path`` comes straight from the query string, so local files are only
    served from the artifacts directory or the task's own local source path.
    Anything else gets a 403 instead of exposing arbitrary host files.

    Raises:
        HTTPException 400: S3 path requested but the task has no S3 source config.
        HTTPException 403: local path outside the allowed roots.
        HTTPException 404: task or file not found.
        HTTPException 500: a JSON file could not be read or parsed.
        DatabaseError: loading the task (or presigning) failed.
    """
    try:
        with get_db_context() as db:
            task = task_crud.get_with_job(db=db, id=id)

            if not task:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
                )

            job_config = task.job.config if task.job else None
            source_config = (
                job_config.get("source") if isinstance(job_config, dict) else None
            )
            if not isinstance(source_config, dict):
                source_config = {}

            if path.startswith("s3"):
                if not source_config:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="Task has no S3 source configuration",
                    )
                s3_storage = S3StorageOperator(
                    aws_access_key_id=source_config.get("accessKeyId"),
                    aws_secret_access_key=source_config.get("secretKeyId"),
                    endpoint=source_config.get("endpoint"),
                )
                return await s3_storage.presigned(path)

    except HTTPException:
        raise
    except Exception as e:
        task_service.logger.error("Failed to retrieve task", task_id=id, error=str(e))
        raise DatabaseError("Failed to retrieve task", "get_task_file") from e

    requested = Path(path)
    # resolve() normalises ".." and follows symlinks, so neither can escape the roots.
    resolved = requested.resolve()

    # NOTE(review): if results can be written outside artifacts_dir, add that
    # root here rather than loosening the check.
    allowed_roots = [task_service.artifacts_dir.resolve()]
    storage_path = source_config.get("storagePath")
    if isinstance(storage_path, str) and storage_path and not storage_path.startswith("s3"):
        allowed_roots.append(Path(storage_path).resolve())

    if not any(_is_path_within(resolved, root) for root in allowed_roots):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access to this path is not allowed",
        )

    # is_file() rather than exists(): directories must not reach FileResponse/open().
    if not resolved.is_file():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Result file not found",
        )

    if not path.endswith(".json"):
        from fastapi.responses import FileResponse

        media_type, _ = mimetypes.guess_type(path)
        return FileResponse(
            resolved,
            media_type=media_type or "application/octet-stream",
            filename=requested.name,
            content_disposition_type="inline",
        )

    try:
        file_stat = resolved.stat()
        with resolved.open("r", encoding="utf-8") as f:
            content = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        task_service.logger.error(
            "Failed to read artifacts file",
            artifacts_path=str(path),
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to read result file",
        ) from e

    return {
        "data": content,
        "size": file_stat.st_size,
        "modified": int(file_stat.st_mtime * 1000),
    }


@router.get("/tasks/summary", response_model=TaskSummaryResponse)
async def get_task_summary(
    task_service: TaskService = Depends(get_task_service),
):
    """Return the total job count plus per-status counts.

    Keep this route registered before ``/tasks/{task_id}``: FastAPI matches
    routes in declaration order, so "summary" would otherwise be parsed as a
    task id.
    """
    return TaskSummaryResponse(data=task_service.get_task_summary())


@router.get("/tasks/{task_id}")
async def get_task(task_id: int, task_service: TaskService = Depends(get_task_service)):
    """Return a single task together with its job details as ``{"data": ...}``.

    Raises:
        HTTPException 404: no task has ``task_id``.
        DatabaseError: the lookup or formatting failed for any other reason.
    """
    try:
        with get_db_context() as db:
            found = task_crud.get_with_job(db=db, id=task_id)
            if not found:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
                )
            return {"data": await task_service.format_task_response(found)}

    except HTTPException:
        raise
    except Exception as exc:
        task_service.logger.error(
            "Failed to retrieve task", task_id=task_id, error=str(exc)
        )
        raise DatabaseError("Failed to retrieve task", "get_task") from exc


@router.get("/tasks")
async def list_tasks(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, alias="pageSize", description="每页大小"),
    name: Optional[str] = Query(None, description="任务名称过滤"),
    phase: Optional[str] = Query(None, description="阶段过滤"),
    status: Optional[str] = Query(None, description="状态过滤"),
    sort_by: str = Query("created_at", alias="sortBy", description="排序字段"),
    sort_order: str = Query("desc", alias="sortOrder", description="排序方向"),
    task_service: TaskService = Depends(get_task_service),
):
    """Page through tasks with optional name/phase/status filters and sorting.

    Note: inside this function ``status`` is the status-filter query value and
    shadows the ``fastapi.status`` module.

    Raises:
        ValidationError: ``sortOrder`` is neither "asc" nor "desc".
        DatabaseError: the query or response formatting failed.
    """
    if sort_order not in ("asc", "desc"):
        raise ValidationError("sortOrder must be 'asc' or 'desc'", "sortOrder")

    # The same filters feed both the page query and the total count.
    filters = {"name": name, "phase": phase, "status": status}
    # Fields copied from each formatted task into a list item.
    item_keys = ("id", "name", "created_at", "updated_at", "job")

    try:
        with get_db_context() as db:
            tasks = task_crud.get_multi_with_jobs(
                db=db,
                skip=(page - 1) * page_size,
                limit=page_size,
                sort_by=sort_by,
                sort_order=sort_order,
                **filters,
            )
            total = task_crud.count_with_filters(db=db, **filters)

            formatted_tasks = [
                await task_service.format_task_response(task) for task in tasks
            ]
            items = [
                {key: formatted[key] for key in item_keys}
                for formatted in formatted_tasks
            ]

            task_service.logger.debug(
                "Task list retrieved", count=len(items), total=total, page=page
            )

            return {
                "data": {
                    "items": items,
                    "total": total,
                    "page": page,
                    "pageSize": page_size,
                }
            }

    except Exception as e:
        task_service.logger.error("Failed to list tasks", error=str(e))
        raise DatabaseError("Failed to list tasks", "list_tasks") from e


@router.put("/tasks/{task_id}", response_model=TaskDetailResponse)
async def update_task(
    task_id: int,
    request: TaskUpdateRequest,
    task_service: TaskService = Depends(get_task_service),
):
    """Update a task's job configuration and/or its source file path.

    A truthy ``request.config`` replaces the stored config; otherwise the stored
    config is reused. A truthy ``request.file_path`` is written into
    ``config.source.storagePath``. The job's phase list is recomputed from the
    resulting config, and the refreshed task is returned.

    Raises:
        ValidationError: neither field was supplied, the task has no job, or
            the resulting config is empty.
        HTTPException 404: the task does not exist.
        DatabaseError: any other failure while updating.
    """
    nothing_to_update = request.config is None and request.file_path is None
    if nothing_to_update:
        raise ValidationError("Provide at least one field to update", "body")

    try:
        with get_db_context() as db:
            task = task_crud.get_with_job(db=db, id=task_id)
            if not task:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
                )
            if not task.job:
                raise ValidationError("Task has no associated job; cannot update", "job")

            base_config = request.config or task.job.config or {}
            # Deep-copy so the new config is a distinct object from the one the
            # request/job currently holds.
            new_config = (
                deepcopy(base_config) if isinstance(base_config, dict) else base_config
            )

            # file_path overrides source.storagePath in the new config.
            if request.file_path:
                source = dict(new_config.get("source") or {})
                source["storagePath"] = request.file_path
                new_config = dict(new_config, source=source)

            if not new_config:
                raise ValidationError("Task configuration cannot be empty", "config")

            job_crud.update(
                db=db,
                id=task.job.id,
                obj_in={
                    "config": new_config,
                    "phases": task_service.extract_phases_from_config(new_config),
                },
                auto_commit=False,
            )
            db.commit()
            db.refresh(task)
            db.refresh(task.job)

            return TaskDetailResponse(
                data=await task_service.format_task_response(task),
                message="任务更新成功",
            )

    except (HTTPException, ValidationError):
        raise
    except Exception as e:
        task_service.logger.error("Failed to update task", task_id=task_id, error=str(e))
        raise DatabaseError("Failed to update task", "update_task") from e


@router.delete("/tasks/{task_id}", response_model=TaskDeleteResponse)
async def delete_task(
    task_id: int, task_service: TaskService = Depends(get_task_service)
):
    """Delete a task, its job, and the job's artifacts directory.

    The database rows are deleted and committed first. The artifacts directory
    ``artifacts_dir/<job_id>`` is removed only afterwards, so a failed commit
    never leaves a task whose files are already gone. Artifact cleanup is
    best-effort: failures are logged, not raised.

    Raises:
        HTTPException 404: the task does not exist.
        DatabaseError: deleting or committing failed.
    """
    try:
        with get_db_context() as db:
            task = task_crud.get_with_job(db=db, id=task_id)

            if not task:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
                )

            job = task.job
            job_id = task.job_id

            task_crud.delete(db=db, id=task_id, auto_commit=False)
            if job:
                job_crud.delete(db=db, id=job.id, auto_commit=False)

            # Both deletes defer committing; commit explicitly (as update_task does)
            # instead of relying on the context manager.
            db.commit()

    except HTTPException:
        raise
    except Exception as e:
        task_service.logger.error("Failed to delete task", task_id=task_id, error=str(e))
        raise DatabaseError("Failed to delete task", "delete_task") from e

    # Without a job there is no artifacts directory (avoids targeting ".../None").
    if job_id is not None:
        job_artifacts_dir = task_service.artifacts_dir / str(job_id)
        try:
            if job_artifacts_dir.exists():
                shutil.rmtree(job_artifacts_dir, ignore_errors=True)
        except OSError as cleanup_error:
            task_service.logger.warning(
                "Failed to clean artifacts directory during task deletion",
                task_id=task_id,
                error=str(cleanup_error),
            )

    task_service.logger.info("Task deleted", task_id=task_id)
    return TaskDeleteResponse(
        message="任务删除成功",
        data={"id": task_id},
    )


@router.post("/tasks/{task_id}/retry", status_code=status.HTTP_200_OK)
async def retry_task(
    task_id: int,
    request: TaskRetryRequest = Body(default=None),
    task_service: TaskService = Depends(get_task_service),
):
    """Put a task's job back into the waiting queue, optionally from a given phase.

    The restart phase is chosen as follows:

    * ``request.phase`` if given; it must belong to the job's phase list;
    * otherwise the phase a FAILED job stopped in;
    * otherwise the first non-empty configured phase;
    * otherwise ``Phase.PARSE``.

    Raises:
        HTTPException 404: the task does not exist.
        HTTPException 400: the job is currently running.
        ValidationError: the task has no job, or the requested phase is not
            part of the job's phase list.
        DatabaseError: any other failure while re-queuing.
    """
    try:
        with get_db_context() as db:
            task = task_crud.get_with_job(db=db, id=task_id)

            if not task:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
                )

            job = task.job
            # Previously an AttributeError (500); mirror update_task's handling.
            if not job:
                raise ValidationError("Task has no associated job; cannot retry", "job")

            job_status = job.status
            if job_status == JobStatus.RUNNING:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Task is currently running and cannot be retried",
                )

            job_phases = job.phases or task_service.extract_phases_from_config(
                job.config or {}
            )
            job_phases = job_phases or []
            # Order-preserving de-duplication that also drops empty entries.
            normalized_phases: List[str] = list(
                dict.fromkeys(phase_key for phase_key in job_phases if phase_key)
            )

            retry_phase = request.phase if (request and request.phase) else None

            if retry_phase:
                if retry_phase not in job_phases:
                    raise ValidationError(
                        "Specified phase is not in the task phase list", "phase"
                    )
            elif job_status == JobStatus.FAILED and job.phase:
                retry_phase = job.phase
            elif normalized_phases:
                # First non-empty phase (job_phases[0] could be an empty entry).
                retry_phase = normalized_phases[0]
            else:
                retry_phase = Phase.PARSE

            if retry_phase not in normalized_phases:
                normalized_phases.append(retry_phase)

            # Re-queue the job.
            job_crud.retry_job(
                db=db,
                job_id=job.id,
                restart_phase=retry_phase,
                phases=normalized_phases,
            )

            task_service.logger.info("Task retry succeeded", task_id=task_id)
            return {"message": "任务已重新进入等待队列"}

    except HTTPException:
        raise
    except ValidationError:
        raise
    except Exception as e:
        task_service.logger.error("Failed to retry task", task_id=task_id, error=str(e))
        raise DatabaseError("Failed to retry task", "retry_task") from e
