from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union

from pydantic import BaseModel
from sqlalchemy.orm import Session

from mineru_flow.internal.schema.state import State

# Type of the model class a BaseCrud instance operates on.
# NOTE(review): this is bound to pydantic BaseModel, yet BaseCrud passes the
# model to ``db.query(...)`` and reads ORM columns (``id``, ``state``) from it.
# That is only consistent if models subclass both (e.g. SQLModel-style
# classes) — confirm; otherwise an ORM declarative base is the correct bound.
T = TypeVar("T", bound=BaseModel)
# Pydantic schema type accepted as ``obj_in`` by BaseCrud.create.
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
# Pydantic schema type accepted as ``obj_in`` by BaseCrud.update
# (update also accepts a plain dict).
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


class BaseCrud(Generic[T, CreateSchemaType, UpdateSchemaType]):
    """Generic CRUD helper bound to a single model class.

    The list methods (``get_all`` / ``get_multi``) only return rows whose
    ``state`` equals ``State.ENABLED``. ``get`` looks a row up by ``id``
    regardless of its state. ``delete`` is a soft delete that sets ``state``
    to ``State.DISABLED``.
    """

    def __init__(self, model: Type[T]):
        # ``model`` is passed to ``db.query(...)`` and must expose ``id`` and
        # ``state`` attributes usable in filter expressions.
        self.model = model

    def _enabled_query(self, db: Session):
        """Return a query over this model restricted to enabled rows."""
        return db.query(self.model).filter(self.model.state == State.ENABLED)

    @staticmethod
    def _is_blank_filter(value: Any) -> bool:
        """Return True when a filter value means "no condition".

        ``None`` and empty strings/collections are blank. Falsy scalars such
        as ``0`` and ``False`` are real conditions and are NOT blank.
        """
        if value is None:
            return True
        if isinstance(value, (str, bytes, list, tuple, set, frozenset, dict)):
            return len(value) == 0
        return False

    def get(self, db: Session, id: int) -> T | None:
        """
        Fetch a single object by primary key.

        Args:
            db: database session.
            id: value compared against the model's ``id`` column.

        Returns:
            The first matching row, or ``None`` if there is none. The row's
            ``state`` is not checked, so soft-deleted rows are returned too.
        """
        return db.query(self.model).filter(self.model.id == id).first()

    def get_all(
        self, db: Session, filters: Optional[Dict[str, Any]] = None
    ) -> Tuple[List[T], int]:
        """
        Fetch all enabled objects, optionally filtered by column equality.

        Args:
            db: database session.
            filters: mapping of model attribute name to required value.
                Entries whose value is ``None`` or an empty string/collection
                are ignored. ``0`` and ``False`` are applied as conditions.

        Returns:
            ``(rows, total)`` where ``total == len(rows)``. Rows are ordered by
            ``created_at`` descending when the model has that attribute.

        Raises:
            AttributeError: if a non-blank filter key is not an attribute of
                the model.
        """
        query = self._enabled_query(db)
        if filters:
            # Skip blank conditions, but keep meaningful falsy values
            # (e.g. False / 0) that a plain truthiness check would drop.
            conditions = [
                getattr(self.model, key) == value
                for key, value in filters.items()
                if not self._is_blank_filter(value)
            ]
            if conditions:
                query = query.filter(*conditions)
        if hasattr(self.model, "created_at"):
            query = query.order_by(self.model.created_at.desc())
        result = query.all()
        return result, len(result)

    def get_multi(
        self, db: Session, *, skip: int = 0, limit: int = 100
    ) -> Tuple[List[T], int]:
        """
        Fetch one page of enabled objects.

        Args:
            db: database session.
            skip: number of rows to skip (offset).
            limit: maximum number of rows to return.

        Returns:
            ``(rows, total)`` where ``rows`` is the requested page and
            ``total`` is the count of ALL enabled rows (not just this page).
        """
        query = self._enabled_query(db)

        # Order before paginating so pages are stable and newest-first.
        if hasattr(self.model, "created_at"):
            query = query.order_by(self.model.created_at.desc())

        result = query.offset(skip).limit(limit).all()
        total = self._enabled_query(db).count()
        return result, total

    def create(
        self, db: Session, *, obj_in: CreateSchemaType, auto_commit: bool = True
    ) -> T:
        """
        Create and persist a new object.

        Args:
            db: database session.
            obj_in: input schema. Objects with ``model_dump`` are dumped to a
                dict; anything else is used as keyword arguments directly
                (e.g. a plain dict).
            auto_commit: commit and refresh immediately when True (default).
                Pass False to manage the transaction externally; the object
                is then only added to the session.

        Returns:
            The new model instance.
        """
        obj_dict = obj_in.model_dump() if hasattr(obj_in, "model_dump") else obj_in
        db_obj = self.model(**obj_dict)

        db.add(db_obj)
        if auto_commit:
            db.commit()
            db.refresh(db_obj)

        return db_obj

    def update(
        self,
        db: Session,
        *,
        id: int,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]],
        auto_commit: bool | None = True,
    ) -> T | None:
        """
        Update an existing object in place.

        Args:
            db: database session.
            id: primary key of the object to update.
            obj_in: update schema or dict. For a schema, only explicitly set
                fields are used (``model_dump(exclude_unset=True)``). Keys that
                are not attributes of the object, and keys whose value is
                ``None``, are ignored, so a field cannot be cleared to NULL here.
            auto_commit: commit and refresh immediately when truthy (default
                True). Pass False/None to manage the transaction externally.

        Returns:
            The updated instance, or ``None`` if no object has this ``id``.
        """
        db_obj = self.get(db=db, id=id)
        if db_obj is None:
            return None

        update_data = (
            obj_in
            if isinstance(obj_in, dict)
            else obj_in.model_dump(exclude_unset=True)
        )

        for field, value in update_data.items():
            if hasattr(db_obj, field) and value is not None:
                setattr(db_obj, field, value)

        db.add(db_obj)
        if auto_commit:
            db.commit()
            db.refresh(db_obj)

        return db_obj

    def delete(self, db: Session, *, id: Any, auto_commit: bool = True) -> Optional[T]:
        """
        Soft-delete an object.

        Sets the object's ``state`` to ``State.DISABLED`` instead of removing
        the row. If the model has no ``state`` attribute, nothing is changed
        and the object is returned as-is.

        Args:
            db: database session.
            id: primary key of the object to delete.
            auto_commit: commit and refresh immediately when True (default).
                Pass False to manage the transaction externally.

        Returns:
            The (possibly updated) instance, or ``None`` if no object has this
            ``id``.
        """
        db_obj = self.get(db=db, id=id)
        if db_obj is None:
            return None

        if hasattr(db_obj, "state"):
            db_obj.state = State.DISABLED
            db.add(db_obj)
            if auto_commit:
                db.commit()
                db.refresh(db_obj)

        return db_obj
