diff --git a/MediaManager/src/auth/__init__.py b/MediaManager/src/auth/__init__.py index c9267c6..30d6808 100644 --- a/MediaManager/src/auth/__init__.py +++ b/MediaManager/src/auth/__init__.py @@ -1,6 +1,5 @@ import logging from datetime import datetime, timedelta, timezone -from typing import Annotated import jwt from fastapi import Depends, HTTPException, status, APIRouter @@ -8,12 +7,12 @@ from fastapi.security import OAuth2PasswordBearer from jwt.exceptions import InvalidTokenError from pydantic import BaseModel -import database -import database.users from config import AuthConfig +from database import SessionDependency from database.users import UserInternal + class Token(BaseModel): access_token: str token_type: str @@ -30,7 +29,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/token") router = APIRouter() -async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal: +async def get_current_user(db: SessionDependency, token: str = Depends(oauth2_scheme)) -> UserInternal: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -38,6 +37,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal: ) config = AuthConfig() log.debug("token: " + token) + try: payload = jwt.decode(token, config.jwt_signing_key, algorithms=[config.jwt_signing_algorithm]) log.debug("jwt payload: " + payload.__str__()) @@ -49,10 +49,13 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal: except InvalidTokenError: log.warning("received invalid token: " + token) raise credentials_exception - user = database.users.get_user(uid=token_data.uid) + + user: UserInternal | None = db.get(UserInternal, token_data.uid) + if user is None: log.debug("user not found") raise credentials_exception + log.debug("received user: " + user.__str__()) return user diff --git a/MediaManager/src/auth/password.py b/MediaManager/src/auth/password.py index 9f4240f..85aafe7 100644 --- a/MediaManager/src/auth/password.py +++ b/MediaManager/src/auth/password.py @@ -1,12 +1,15 @@ from typing import Annotated +import hashlib + import bcrypt from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm +from sqlmodel import Session, select import database from auth import create_access_token, Token, router -from database import users +from database import users, SessionDependency from database.users import UserInternal @@ -17,21 +20,18 @@ def verify_password(plain_password, hashed_password): ) -def get_password_hash(password): - return bcrypt.hashpw( - bytes(password, encoding="utf-8"), - bcrypt.gensalt(), - ) +def get_password_hash(password: str) -> str: + return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") -def authenticate_user(email: str, password: str) -> bool | UserInternal: +def authenticate_user(db: SessionDependency, email: str, password: str) -> bool | UserInternal: """ :param email: email of the user :param password: password of the user :return: if authentication succeeds, returns the user object with added name and lastname, otherwise or if the user doesn't exist returns False """ - user = database.users.get_user(email=email) + user: UserInternal | None = db.exec(select(UserInternal).where(UserInternal.email == email)).first() if not user: return False if not verify_password(password, user.hashed_password): @@ -42,13 +42,15 @@ def authenticate_user(email: str, password: str) -> bool | UserInternal: @router.post("/token") async def login_for_access_token( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + db: SessionDependency, ) -> Token: - user = authenticate_user(form_data.username, form_data.password) + user = authenticate_user(db,form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password", headers={"WWW-Authenticate": "Bearer"}, ) - access_token = create_access_token(data={"sub": user.id}) + # id needs to be converted because a UUID object isn't json serializable + access_token = create_access_token(data={"sub": user.id.__str__()}) return Token(access_token=access_token, token_type="bearer") diff --git a/MediaManager/src/database/__init__.py b/MediaManager/src/database/__init__.py index c837dd4..2fb1dd9 100644 --- a/MediaManager/src/database/__init__.py +++ b/MediaManager/src/database/__init__.py @@ -1,43 +1,28 @@ import logging +from typing import Any, Generator, Annotated -import psycopg -from psycopg.rows import dict_row +from fastapi import Depends +from sqlmodel import create_engine, SQLModel, Session +import config from config import DbConfig log = logging.getLogger(__name__) +config: DbConfig = config.get_db_config() + +db_url = "postgresql+psycopg" + "://" + config.user + ":" + config.password + "@" + config.host + ":" + str( + config.port) + "/" + config.dbname + +engine = create_engine(db_url, echo=True) -class PgDatabase: - """PostgreSQL Database context manager using psycopg""" - - def __init__(self) -> None: - self.driver = psycopg - - def connect_to_database(self, config: DbConfig = DbConfig()): - return self.driver.connect( - autocommit=True, - host=config.host, - port=config.port, - user=config.user, - password=config.password, - dbname=config.dbname, - row_factory=dict_row - ) - - def __enter__(self): - self.connection = self.connect_to_database() - return self - - def __exit__(self, exception_type, exc_val, traceback): - self.connection.close() +def init_db() -> None: + SQLModel.metadata.create_all(engine) -def init_db(): - log.info("Initializing database") +def get_session() -> Generator[Session, Any, None]: + with Session(engine) as session: + yield session - from database import tv, users - users.init_table() - tv.init_table() +SessionDependency = Annotated[Session, Depends(get_session)] - log.info("Tables initialized successfully") diff --git a/MediaManager/src/database/tv.py b/MediaManager/src/database/tv.py deleted file mode 100644 index 41a28b4..0000000 --- a/MediaManager/src/database/tv.py +++ /dev/null @@ -1,37 +0,0 @@ -from database import PgDatabase, log - -# TODO: add NOT NULL and default values to DB - -def init_table(): - with PgDatabase() as db: - db.connection.execute(""" - CREATE TABLE IF NOT EXISTS tv_show ( - id UUID PRIMARY KEY, - external_id TEXT, - metadata_provider TEXT, - name TEXT, - episode_count INTEGER, - season_count INTEGER, - UNIQUE (external_id, metadata_provider) - - );""") - log.info("tv_show Table initialized successfully") - db.connection.execute(""" - CREATE TABLE IF NOT EXISTS tv_season ( - show_id UUID REFERENCES tv_show(id), - season_number INTEGER, - episode_count INTEGER, - CONSTRAINT PK_season PRIMARY KEY (show_id,season_number) - );""") - log.info("tv_seasonTable initialized successfully") - db.connection.execute(""" - CREATE TABLE IF NOT EXISTS tv_episode ( - season_number INTEGER, - show_id uuid, - episode_number INTEGER, - title TEXT, - CONSTRAINT PK_episode PRIMARY KEY (season_number,show_id,episode_number), - FOREIGN KEY (season_number, show_id) REFERENCES tv_season(season_number,show_id) - - );""") - log.info("tv_episode Table initialized successfully") diff --git a/MediaManager/src/database/users.py b/MediaManager/src/database/users.py index d3216b5..5d3a2d0 100644 --- a/MediaManager/src/database/users.py +++ b/MediaManager/src/database/users.py @@ -1,91 +1,19 @@ -from uuid import uuid4 +import uuid +from uuid import UUID -import psycopg -from pydantic import BaseModel - -from database import PgDatabase, log +from sqlmodel import Field, SQLModel -class User(BaseModel): - """ - User model - """ - name: str +class UserBase(SQLModel): + name: str = Field() lastname: str - email: str + email: str = Field(unique=True) +class UserPublic(UserBase): + id: UUID = Field(primary_key=True, default_factory=uuid.uuid4) -class UserInternal(User): - """" - Internal user model, assumes the password is already hashed, when a new instance is created - """ - id: str = str(uuid4()) +class UserInternal(UserPublic, table=True): hashed_password: str - -def create_user(user: UserInternal) -> bool: - """ - - :param user: user to create, password must already be hashed - :return: True if user was created, False otherwise - """ - with PgDatabase() as db: - try: - db.connection.execute( - """ - INSERT INTO users (id, name, lastname, email, hashed_password) - VALUES (%s, %s, %s, %s, %s) - """, - (user.id, user.name, user.lastname, user.email, user.hashed_password) - ) - except psycopg.errors.UniqueViolation as e: - log.error(e) - return False - - log.info("User inserted successfully") - log.debug(f"Inserted User: " + user.model_dump().__str__()) - return True - - -def get_user(email: str = None, uid: str = None) -> UserInternal | None: - """ - either specify the email or uid to search for the user, if both parameters are provided the uid will be used - - - :param email: the users email address - :param uid: the users id - :return: if user was found it a UserInternal object is returned, otherwise None - """ - with PgDatabase() as db: - if email is not None and uid is None: - result = db.connection.execute( - "SELECT id, name, lastname, email, hashed_password FROM users WHERE email=%s", - (email,) - ).fetchone() - if uid is not None: - result = db.connection.execute( - "SELECT id, name, lastname, email, hashed_password FROM users WHERE id=%s", - (uid,) - ).fetchone() - - if result is None: - return None - user = UserInternal(id=result["id"].__str__(), name=result["name"], lastname=result["lastname"], - email=result["email"], - hashed_password=result["hashed_password"]) - log.debug(f"Retrieved User successfully: {user.model_dump()} ") - return user - - -def init_table(): - with PgDatabase() as db: - db.connection.execute(""" - CREATE TABLE IF NOT EXISTS users ( - id UUID NOT NULL PRIMARY KEY, - lastname TEXT, - name TEXT NOT NULL, - email TEXT NOT NULL UNIQUE, - hashed_password TEXT NOT NULL - ); - """) - log.info("users Table initialized successfully") +class UserCreate(UserBase): + password: str \ No newline at end of file diff --git a/MediaManager/src/main.py b/MediaManager/src/main.py index b0d1086..ed9b0ad 100644 --- a/MediaManager/src/main.py +++ b/MediaManager/src/main.py @@ -6,7 +6,8 @@ from fastapi import FastAPI import config import database -import tv.router +import database.users +#import tv.router from auth import password from routers import users @@ -19,7 +20,7 @@ database.init_db() app = FastAPI(root_path="/api/v1") app.include_router(users.router, tags=["users"]) app.include_router(password.router, tags=["authentication"]) -app.include_router(tv.router.router, tags=["tv"]) +#app.include_router(tv.router.router, tags=["tv"]) diff --git a/MediaManager/src/routers/users.py b/MediaManager/src/routers/users.py index c30fa98..fe5ac33 100644 --- a/MediaManager/src/routers/users.py +++ b/MediaManager/src/routers/users.py @@ -1,12 +1,13 @@ from fastapi import APIRouter from fastapi import Depends +from sqlalchemy.exc import IntegrityError from pydantic import BaseModel from starlette.responses import JSONResponse -import database from auth import get_current_user from auth.password import get_password_hash -from database.users import UserInternal, User +from database import SessionDependency, get_session +from database.users import UserInternal, UserCreate, UserPublic from routers import log router = APIRouter( @@ -18,28 +19,26 @@ class Message(BaseModel): message: str -class CreateUser(User): - """" - The Usermodel, but with an additional non-hashed password. attribute - """ - password: str - @router.post("/", status_code=201, responses={ 409: {"model": Message, "description": "User with provided email already exists"}, - 201: {"model": UserInternal, "description": "User created successfully"} + 201: {"model": UserPublic, "description": "User created successfully"} }) async def create_user( - user: CreateUser = Depends(CreateUser), + db: SessionDependency, + user: UserCreate = Depends(UserCreate), ): internal_user = UserInternal(name=user.name, lastname=user.lastname, email=user.email, hashed_password=get_password_hash(user.password)) - if database.users.create_user(internal_user): - log.info("Created new user", internal_user.model_dump()) - return internal_user - else: - log.warning("Failed to create new user, User with this email already exists,", internal_user.model_dump()) + db.add(internal_user) + try: + db.commit() + except IntegrityError as e: + log.debug(e) + log.warning("Failed to create new user, User with this email already exists "+internal_user.model_dump().__str__()) return JSONResponse(status_code=409, content={"message": "User with this email already exists"}) + log.info("Created new user "+internal_user.email) + return UserPublic(**internal_user.model_dump()) @router.get("/me") @@ -51,6 +50,6 @@ async def read_users_me( @router.get("/me/items") async def read_own_items( - current_user: User = Depends(get_current_user), + current_user: UserInternal = Depends(get_current_user), ): return [{"item_id": "Foo", "owner": current_user.name}] diff --git a/MediaManager/src/tv/__init__.py b/MediaManager/src/tv/__init__.py index d81a7a5..bf8cca3 100644 --- a/MediaManager/src/tv/__init__.py +++ b/MediaManager/src/tv/__init__.py @@ -1,90 +1,92 @@ import logging import pprint -from typing import Literal, List, Any +from typing import List from uuid import UUID, uuid4 -import requests +import tmdbsimple as tmdb from pydantic import BaseModel from config import TvConfig from database import PgDatabase -import tmdbsimple as tmdb + + + +from sqlmodel import Field, Session, SQLModel, create_engine, select # NOTE: use tmdbsimple for api calls -class Episode(BaseModel): +class Episode(SQLModel): + show_id: int = Field(foreign_key="show.id") + season_number: int = Field(foreign_key="season.number") number: int title: str -class Season(BaseModel): +class Season(SQLModel, table=True): + show_id: UUID = Field(foreign_key="show.id") number: int - episodes: List[Episode] - - def get_episode_count(self) -> int: - return self.episodes.__len__() -class Show(BaseModel): - id: UUID = uuid4() +class Show(SQLModel, table=True): + id: UUID = Field(primary_key=True) external_id: int metadata_provider: str name: str - seasons: List[Season] = [] - def get_season_count(self) -> int: - return self.seasons.__len__() +# def get_season_count(self) -> int: +# return self.seasons.__len__() - def get_episode_count(self) -> int: - episode_count = 0 - for season in self.seasons: - episode_count += season.get_episode_count() - return episode_count +# def get_episode_count(self) -> int: +# episode_count = 0 +# for season in self.seasons: +# episode_count += season.get_episode_count() +# return episode_count - def save_show(self) -> None: - with PgDatabase() as db: - db.connection.execute(""" - INSERT INTO tv_show ( - id, - external_id, - metadata_provider, - name, - episode_count, - season_count - )VALUES(%s,%s,%s,%s,%s,%s); - """, - (self.id, - self.external_id, - self.metadata_provider, - self.name, - self.get_episode_count(), - self.get_season_count(), - ) - ) - log.info("added show: " + self.__str__()) +# def save_show(self) -> None: +# with PgDatabase() as db: +# db.connection.execute(""" +# INSERT INTO tv_show ( +# id, +# external_id, +# metadata_provider, +# name, +# episode_count, +# season_count +# )VALUES(%s,%s,%s,%s,%s,%s); +# """, +# (self.id, +# self.external_id, +# self.metadata_provider, +# self.name, +# self.get_episode_count(), +# self.get_season_count(), +# ) +# ) +# log.info("added show: " + self.__str__()) - def get_data_from_tmdb(self) -> None: - data = tmdb.TV(self.external_id).info() - log.debug("data from tmdb: " + pprint.pformat(data)) - self.name = data["original_name"] - self.metadata_provider = "tmdb" +# def get_data_from_tmdb(self) -> None: +# data = tmdb.TV(self.external_id).info() +# log.debug("data from tmdb: " + pprint.pformat(data)) +# self.name = data["original_name"] +# self.metadata_provider = "tmdb" - def add_season(self, season_number: int) -> None: - data = tmdb.TV_Seasons(self.external_id, season_number).info() - log.debug("data from tmdb: " + pprint.pformat(data)) +# def add_season(self, season_number: int) -> None: +# data = tmdb.TV_Seasons(self.external_id, season_number).info() +# log.debug("data from tmdb: " + pprint.pformat(data)) - episodes: List[Episode] = [] - for episode in data["episodes"]: - episodes.append(Episode(title=episode["name"],number=episode["episode_number"])) +# episodes: List[Episode] = [] +# for episode in data["episodes"]: +# episodes.append(Episode(title=episode["name"], number=episode["episode_number"])) - season = Season(number=season_number, episodes=episodes) +# season = Season(number=season_number, episodes=episodes) - self.seasons.append(season) +# self.seasons.append(season) + +# def add_seasons(self, season_numbers: List[int]) -> None: +# for season_number in season_numbers: +# self.add_season(season_number) - def add_seasons(self, season_numbers: List[int]) -> None: - for season_number in season_numbers: - self.add_season(season_number) def get_all_shows() -> List[Show]: with PgDatabase() as db: @@ -93,12 +95,14 @@ def get_all_shows() -> List[Show]: """).fetchall() return result + def get_show(id: UUID) -> Show: with PgDatabase() as db: result = db.connection.execute(""" SELECT * FROM tv_show WHERE id = %s """, (id,)).fetchone() - return Show(**result) + return Show(**result) + config = TvConfig() log = logging.getLogger(__name__)