from contextlib import asynccontextmanager, contextmanager
from typing import AsyncGenerator, Generator

from sqlalchemy.orm import Session

from .db import SessionLocal
from .logging import get_db_logger

db_logger = get_db_logger()


@contextmanager
def get_db_context() -> Generator[Session, None, None]:
    """Synchronous database session context manager.

    Opens a session via ``SessionLocal()`` and yields it to the ``with`` body.

    * Body finishes without raising: the session is committed.
    * Body or commit raises an ``Exception``: the session is rolled back, the
      error is logged, and the original exception is re-raised.
    * In every case the session is closed on exit.

    Yields:
        Session: the session to use inside the ``with`` block.
    """
    session = SessionLocal()
    # Outer try only guarantees close(); the inner try decides the transaction
    # outcome. close() still runs for BaseException subclasses (e.g.
    # KeyboardInterrupt) that the inner ``except Exception`` does not catch.
    try:
        try:
            yield session
            # Reached only if the with-body did not raise.
            session.commit()
        except Exception as exc:
            # Covers failures from the with-body and from commit() itself.
            session.rollback()
            db_logger.error("数据库操作失败", error=str(exc))
            raise
    finally:
        session.close()


@asynccontextmanager
async def get_async_db_context() -> AsyncGenerator[Session, None]:
    """``async with``-compatible wrapper around ``get_db_context``.

    Yields a session created by ``SessionLocal()``. If the ``async with`` body
    completes without raising, the session is committed. If the body or the
    commit raises an ``Exception``, the session is rolled back, the error is
    logged, and the exception is re-raised. The session is always closed.

    The transaction handling is delegated to ``get_db_context`` rather than
    duplicated here, so the sync and async entry points cannot drift apart.

    Note: the yielded object is a synchronous ``sqlalchemy.orm.Session``. Its
    ``commit()``, ``rollback()`` and ``close()`` calls are plain (not awaited)
    calls, so they block the running event loop while they execute. Truly
    non-blocking database I/O would require SQLAlchemy's ``AsyncSession``.

    Yields:
        Session: the session to use inside the ``async with`` block.
    """
    # Exceptions raised in the ``async with`` body are thrown into this
    # generator at ``yield`` and propagate into the inner ``with``, which
    # performs the rollback/log/re-raise exactly as before.
    with get_db_context() as db:
        yield db


class DatabaseManager:
    """Database manager: static helpers for obtaining ``SessionLocal`` sessions."""

    @staticmethod
    def get_session() -> Session:
        """Create and return a new, unmanaged session.

        The caller owns the returned session and must commit/roll back and
        close it manually.
        """
        fresh_session = SessionLocal()
        return fresh_session

    @staticmethod
    @contextmanager
    def session_scope() -> Generator[Session, None, None]:
        """Provide a transactional scope for a ``with`` block.

        Commits when the block completes without raising. If the block or the
        commit raises an ``Exception``, rolls back, logs the error and
        re-raises it. The session is always closed on exit.

        Yields:
            Session: the session to use inside the ``with`` block.
        """
        scoped = SessionLocal()
        try:
            yield scoped
            # Only reached when the with-body did not raise.
            scoped.commit()
        except Exception as err:
            # Handles errors from both the with-body and commit().
            scoped.rollback()
            db_logger.error("数据库会话作用域异常", error=str(err))
            raise
        finally:
            scoped.close()
