From 297edb3884d5342cde4f5758928deea3364bc1be Mon Sep 17 00:00:00 2001 From: maxDorninger <97409287+maxDorninger@users.noreply.github.com> Date: Wed, 26 Feb 2025 19:22:09 +0100 Subject: [PATCH] working on getting the config --- MediaManager/src/auth/__init__.py | 12 ++++++------ MediaManager/src/config/__init__.py | 20 +++++++++++--------- MediaManager/src/database/__init__.py | 15 ++++++++------- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/MediaManager/src/auth/__init__.py b/MediaManager/src/auth/__init__.py index fff8e0c..d684c51 100644 --- a/MediaManager/src/auth/__init__.py +++ b/MediaManager/src/auth/__init__.py @@ -7,9 +7,9 @@ from fastapi.security import OAuth2PasswordBearer from jwt.exceptions import InvalidTokenError from pydantic import BaseModel -import config import database import database.users +from config import AuthConfig from database.users import UserInternal @@ -29,7 +29,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/token") router = APIRouter() -async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal: +async def get_current_user(token: str = Depends(oauth2_scheme), config = Depends(AuthConfig)) -> UserInternal: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -37,7 +37,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal: ) log.debug("token: " + token) try: - payload = jwt.decode(token, config.auth.jwt_signing_key, algorithms=[config.auth.jwt_signing_algorithm]) + payload = jwt.decode(token, config.jwt_signing_key, algorithms=[config.jwt_signing_algorithm]) log.debug("jwt payload: " + payload.__str__()) user_uid: str = payload.get("sub") log.debug("jwt payload sub (user uid): " + user_uid) @@ -55,12 +55,12 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal: return user -def create_access_token(data: dict, expires_delta: timedelta | None = None): +def create_access_token(data: dict, expires_delta: timedelta | None = None, config = Depends(AuthConfig)): to_encode = data.copy() if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.now(timezone.utc) + timedelta(minutes=config.auth.jwt_access_token_lifetime) + expire = datetime.now(timezone.utc) + timedelta(minutes=config.jwt_access_token_lifetime) to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, config.auth.jwt_signing_key, algorithm=config.auth.jwt_signing_algorithm) + encoded_jwt = jwt.encode(to_encode, config.jwt_signing_key, algorithm=config.jwt_signing_algorithm) return encoded_jwt diff --git a/MediaManager/src/config/__init__.py b/MediaManager/src/config/__init__.py index 82c28be..31507d0 100644 --- a/MediaManager/src/config/__init__.py +++ b/MediaManager/src/config/__init__.py @@ -33,22 +33,24 @@ class AuthConfig(BaseModel): def jwt_signing_key(self): return self._jwt_signing_key -db: DbConfig = DbConfig() -indexer: IndexerConfig = IndexerConfig() -auth: AuthConfig = AuthConfig() +def get_db_config() -> DbConfig: + return DbConfig() + + log = logging.getLogger(__name__) def load_config(): - db = DbConfig() - log.info(f"loaded config: DbConfig: {db.__str__()}") - indexer = IndexerConfig() - log.info(f"loaded config: IndexerConfig: {indexer.__str__()}") - auth = AuthConfig() - log.info(f"loaded config: AuthConfig: {auth.__str__()}") + log.info(f"loaded config: DbConfig: {DbConfig().__str__()}") + log.info(f"loaded config: IndexerConfig: {IndexerConfig().__str__()}") + log.info(f"loaded config: AuthConfig: {AuthConfig().__str__()}") if __name__ == "__main__": + db: DbConfig = DbConfig() + indexer: IndexerConfig = IndexerConfig() + auth: AuthConfig = AuthConfig() + print(db.__str__()) print(indexer.__str__()) print(auth.__str__()) diff --git a/MediaManager/src/database/__init__.py b/MediaManager/src/database/__init__.py index 0dcf40b..799e653 100644 --- a/MediaManager/src/database/__init__.py +++ b/MediaManager/src/database/__init__.py @@ -1,9 +1,10 @@ import logging import psycopg +from fastapi import Depends from psycopg.rows import dict_row -import config +from config import DbConfig, get_db_config log = logging.getLogger(__name__) @@ -14,14 +15,14 @@ class PgDatabase: def __init__(self) -> None: self.driver = psycopg - def connect_to_database(self): + def connect_to_database(self, config: DbConfig = DbConfig()): return self.driver.connect( autocommit=True, - host=config.db.host, - port=config.db.port, - user=config.db.user, - password=config.db.password, - dbname=config.db.dbname, + host=config.host, + port=config.port, + user=config.user, + password=config.password, + dbname=config.dbname, row_factory=dict_row )