from typing import Generator

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from mineru_flow.internal.config.config import settings
from mineru_flow.internal.common.logging import get_db_logger

database_url = settings.DATABASE_URL

# SQLite's Python driver refuses to let a connection be used from a thread
# other than the one that created it unless check_same_thread is disabled.
# Other DBAPI drivers (e.g. psycopg2) reject this unknown connection argument,
# so it is only supplied when the URL targets SQLite.
_connect_args = (
    {"check_same_thread": False}
    if str(database_url).startswith("sqlite")
    else {}
)

engine = create_engine(
    database_url,
    connect_args=_connect_args,
    echo=False,
)

# Session factory: transactions are committed explicitly by callers, and
# pending changes are not flushed automatically before queries.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base that ORM models subclass; its metadata drives init_tables().
Base = declarative_base()

db_logger = get_db_logger().bind(module="session")


def init_tables() -> None:
    """Create every table registered on ``Base.metadata`` using ``engine``.

    ``MetaData.create_all`` checks for existing tables by default, so tables
    already present in the database are left untouched.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)


def get_db() -> Generator:
    """Yield a new session and commit it once the consumer is done.

    Written as a generator so it can be driven as a yield-style dependency
    (presumably FastAPI ``Depends`` — confirm against callers). After the
    consumer resumes the generator, the transaction is committed. If anything
    raises (including an exception thrown into the generator at the
    ``yield``, or the commit itself), the error is logged, the transaction
    is rolled back, and the original exception is re-raised. The session is
    closed in every case.

    Yields:
        A session created by ``SessionLocal``.

    Raises:
        Exception: whatever was raised during session creation, use, or
            commit, re-raised unchanged after logging and rollback.
    """
    db = None
    try:
        db = SessionLocal()
        yield db
        db.commit()
    except Exception as e:
        # Log message: "database session error".
        db_logger.error("数据库会话异常", error=str(e))
        # db stays None if SessionLocal() itself failed; nothing to roll back.
        if db is not None:
            db.rollback()
        # Bare ``raise`` re-raises the active exception without adding this
        # frame to its traceback (unlike ``raise e``).
        raise
    finally:
        if db is not None:
            db.close()


def get_db_session():
    """Return a fresh session from ``SessionLocal`` for use outside ``get_db``.

    Intended for worker code. Unlike ``get_db``, this function does not
    commit, roll back, or close the session; the caller owns its lifecycle.
    """
    session = SessionLocal()
    return session
