diff --git a/media_manager/auth/db.py b/media_manager/auth/db.py index f7b7d00..eedf6d5 100644 --- a/media_manager/auth/db.py +++ b/media_manager/auth/db.py @@ -11,8 +11,8 @@ from sqlalchemy import String from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import Mapped, relationship, mapped_column -from media_manager.database import Base -from media_manager.database import db_url +from media_manager.database import Base, build_db_url +from media_manager.config import AllEncompassingConfig class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): @@ -29,7 +29,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base): ) -engine = create_async_engine(db_url, echo=False) +engine = create_async_engine( + build_db_url(**AllEncompassingConfig().database.model_dump()), echo=False +) async_session_maker = async_sessionmaker(engine, expire_on_commit=False) diff --git a/media_manager/database/__init__.py b/media_manager/database/__init__.py index d7ccfcd..a7c16d4 100644 --- a/media_manager/database/__init__.py +++ b/media_manager/database/__init__.py @@ -1,44 +1,82 @@ +# media_manager/database/__init__.py import logging +import os from contextvars import ContextVar -from typing import Annotated, Any, Generator +from typing import Annotated, Any, Generator, Optional from fastapi import Depends from sqlalchemy import create_engine +from sqlalchemy.engine import Engine from sqlalchemy.orm import Session, declarative_base, sessionmaker -from media_manager.config import AllEncompassingConfig - log = logging.getLogger(__name__) -config = AllEncompassingConfig().database -db_url = ( - "postgresql+psycopg" - + "://" - + config.user - + ":" - + config.password - + "@" - + config.host - + ":" - + str(config.port) - + "/" - + config.dbname -) - -engine = create_engine( - db_url, - echo=False, - pool_size=10, - max_overflow=10, - pool_timeout=30, - pool_recycle=1800, -) -log.debug("initializing sqlalchemy declarative base") Base = declarative_base() -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +engine: Optional[Engine] = None +SessionLocal: Optional[sessionmaker] = None + + +def build_db_url( + user: str, + password: str, + host: str, + port: int | str, + dbname: str, +) -> str: + return f"postgresql+psycopg://{user}:{password}@{host}:{port}/{dbname}" + + +def init_engine( + db_config: Any | None = None, + url: str | None = None, +) -> Engine: + """ + Initialize the global SQLAlchemy engine and session factory. + Pass either a DbConfig-like object or a full URL. Only initializes once. + """ + global engine, SessionLocal + if engine is not None: + return engine + + if url is None: + if db_config is None: + url = os.getenv("DATABASE_URL") + if not url: + raise RuntimeError("DB config or `DATABASE_URL` must be provided") + else: + url = build_db_url( + db_config.user, + db_config.password, + db_config.host, + db_config.port, + db_config.dbname, + ) + + engine = create_engine( + url, + echo=False, + pool_size=10, + max_overflow=10, + pool_timeout=30, + pool_recycle=1800, + ) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + log.debug("SQLAlchemy engine initialized") + return engine + + +def get_engine() -> Engine: + if engine is None: + raise RuntimeError("Engine not initialized. Call init_engine(...) first.") + return engine def get_session() -> Generator[Session, Any, None]: + if SessionLocal is None: + raise RuntimeError( + "Session factory not initialized. Call init_engine(...) first." + ) db = SessionLocal() try: yield db @@ -46,12 +84,10 @@ def get_session() -> Generator[Session, Any, None]: except Exception as e: db.rollback() log.critical(f"error occurred: {e}") - raise e + raise finally: db.close() db_session: ContextVar[Session] = ContextVar("db_session") - - DbSessionDependency = Annotated[Session, Depends(get_session)] diff --git a/media_manager/main.py b/media_manager/main.py index 9675b09..07a07c3 100644 --- a/media_manager/main.py +++ b/media_manager/main.py @@ -106,6 +106,7 @@ from datetime import datetime # noqa: E402 from contextlib import asynccontextmanager # noqa: E402 from apscheduler.schedulers.background import BackgroundScheduler # noqa: E402 from apscheduler.triggers.cron import CronTrigger # noqa: E402 +from media_manager.database import init_engine # noqa: E402 config = AllEncompassingConfig() @@ -128,6 +129,8 @@ def weekly_tasks(): update_all_movies_metadata() +init_engine(config.database) + jobstores = {"default": SQLAlchemyJobStore(engine=media_manager.database.engine)} scheduler = BackgroundScheduler(jobstores=jobstores)