switch to the SQLModel ORM

This commit is contained in:
maxDorninger
2025-03-02 21:14:07 +01:00
parent b890b9e8dc
commit b88cb1b042
8 changed files with 126 additions and 241 deletions

View File

@@ -1,6 +1,5 @@
import logging
from datetime import datetime, timedelta, timezone
from typing import Annotated
import jwt
from fastapi import Depends, HTTPException, status, APIRouter
@@ -8,12 +7,12 @@ from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
from pydantic import BaseModel
import database
import database.users
from config import AuthConfig
from database import SessionDependency
from database.users import UserInternal
class Token(BaseModel):
access_token: str
token_type: str
@@ -30,7 +29,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/token")
router = APIRouter()
async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal:
async def get_current_user(db: SessionDependency, token: str = Depends(oauth2_scheme)) -> UserInternal:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
@@ -38,6 +37,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal:
)
config = AuthConfig()
log.debug("token: " + token)
try:
payload = jwt.decode(token, config.jwt_signing_key, algorithms=[config.jwt_signing_algorithm])
log.debug("jwt payload: " + payload.__str__())
@@ -49,10 +49,13 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal:
except InvalidTokenError:
log.warning("received invalid token: " + token)
raise credentials_exception
user = database.users.get_user(uid=token_data.uid)
user: UserInternal | None = db.get(UserInternal, token_data.uid)
if user is None:
log.debug("user not found")
raise credentials_exception
log.debug("received user: " + user.__str__())
return user

View File

@@ -1,12 +1,15 @@
from typing import Annotated
import hashlib
import bcrypt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session, select
import database
from auth import create_access_token, Token, router
from database import users
from database import users, SessionDependency
from database.users import UserInternal
@@ -17,21 +20,18 @@ def verify_password(plain_password, hashed_password):
)
def get_password_hash(password):
return bcrypt.hashpw(
bytes(password, encoding="utf-8"),
bcrypt.gensalt(),
)
def get_password_hash(password: str) -> str:
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
def authenticate_user(email: str, password: str) -> bool | UserInternal:
def authenticate_user(db: SessionDependency, email: str, password: str) -> bool | UserInternal:
"""
:param email: email of the user
:param password: password of the user
:return: if authentication succeeds, returns the user object with added name and lastname, otherwise or if the user doesn't exist returns False
"""
user = database.users.get_user(email=email)
user: UserInternal | None = db.exec(select(UserInternal).where(UserInternal.email == email)).first()
if not user:
return False
if not verify_password(password, user.hashed_password):
@@ -42,13 +42,15 @@ def authenticate_user(email: str, password: str) -> bool | UserInternal:
@router.post("/token")
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: SessionDependency,
) -> Token:
user = authenticate_user(form_data.username, form_data.password)
user = authenticate_user(db,form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token = create_access_token(data={"sub": user.id})
# id needs to be converted because a UUID object isn't json serializable
access_token = create_access_token(data={"sub": user.id.__str__()})
return Token(access_token=access_token, token_type="bearer")

View File

@@ -1,43 +1,28 @@
import logging
from typing import Any, Generator, Annotated
import psycopg
from psycopg.rows import dict_row
from fastapi import Depends
from sqlmodel import create_engine, SQLModel, Session
import config
from config import DbConfig
log = logging.getLogger(__name__)
config: DbConfig = config.get_db_config()
db_url = "postgresql+psycopg" + "://" + config.user + ":" + config.password + "@" + config.host + ":" + str(
config.port) + "/" + config.dbname
engine = create_engine(db_url, echo=True)
class PgDatabase:
"""PostgreSQL Database context manager using psycopg"""
def __init__(self) -> None:
self.driver = psycopg
def connect_to_database(self, config: DbConfig = DbConfig()):
return self.driver.connect(
autocommit=True,
host=config.host,
port=config.port,
user=config.user,
password=config.password,
dbname=config.dbname,
row_factory=dict_row
)
def __enter__(self):
self.connection = self.connect_to_database()
return self
def __exit__(self, exception_type, exc_val, traceback):
self.connection.close()
def init_db() -> None:
SQLModel.metadata.create_all(engine)
def init_db():
log.info("Initializing database")
def get_session() -> Generator[Session, Any, None]:
with Session(engine) as session:
yield session
from database import tv, users
users.init_table()
tv.init_table()
SessionDependency = Annotated[Session, Depends(get_session)]
log.info("Tables initialized successfully")

View File

@@ -1,37 +0,0 @@
from database import PgDatabase, log
# TODO: add NOT NULL and default values to DB
def init_table():
with PgDatabase() as db:
db.connection.execute("""
CREATE TABLE IF NOT EXISTS tv_show (
id UUID PRIMARY KEY,
external_id TEXT,
metadata_provider TEXT,
name TEXT,
episode_count INTEGER,
season_count INTEGER,
UNIQUE (external_id, metadata_provider)
);""")
log.info("tv_show Table initialized successfully")
db.connection.execute("""
CREATE TABLE IF NOT EXISTS tv_season (
show_id UUID REFERENCES tv_show(id),
season_number INTEGER,
episode_count INTEGER,
CONSTRAINT PK_season PRIMARY KEY (show_id,season_number)
);""")
log.info("tv_seasonTable initialized successfully")
db.connection.execute("""
CREATE TABLE IF NOT EXISTS tv_episode (
season_number INTEGER,
show_id uuid,
episode_number INTEGER,
title TEXT,
CONSTRAINT PK_episode PRIMARY KEY (season_number,show_id,episode_number),
FOREIGN KEY (season_number, show_id) REFERENCES tv_season(season_number,show_id)
);""")
log.info("tv_episode Table initialized successfully")

View File

@@ -1,91 +1,19 @@
from uuid import uuid4
import uuid
from uuid import UUID
import psycopg
from pydantic import BaseModel
from database import PgDatabase, log
from sqlmodel import Field, SQLModel
class User(BaseModel):
"""
User model
"""
name: str
class UserBase(SQLModel):
name: str = Field()
lastname: str
email: str
email: str = Field(unique=True)
class UserPublic(UserBase):
id: UUID = Field(primary_key=True, default_factory=uuid.uuid4)
class UserInternal(User):
""""
Internal user model, assumes the password is already hashed, when a new instance is created
"""
id: str = str(uuid4())
class UserInternal(UserPublic, table=True):
hashed_password: str
def create_user(user: UserInternal) -> bool:
"""
:param user: user to create, password must already be hashed
:return: True if user was created, False otherwise
"""
with PgDatabase() as db:
try:
db.connection.execute(
"""
INSERT INTO users (id, name, lastname, email, hashed_password)
VALUES (%s, %s, %s, %s, %s)
""",
(user.id, user.name, user.lastname, user.email, user.hashed_password)
)
except psycopg.errors.UniqueViolation as e:
log.error(e)
return False
log.info("User inserted successfully")
log.debug(f"Inserted User: " + user.model_dump().__str__())
return True
def get_user(email: str = None, uid: str = None) -> UserInternal | None:
"""
either specify the email or uid to search for the user, if both parameters are provided the uid will be used
:param email: the users email address
:param uid: the users id
:return: if user was found it a UserInternal object is returned, otherwise None
"""
with PgDatabase() as db:
if email is not None and uid is None:
result = db.connection.execute(
"SELECT id, name, lastname, email, hashed_password FROM users WHERE email=%s",
(email,)
).fetchone()
if uid is not None:
result = db.connection.execute(
"SELECT id, name, lastname, email, hashed_password FROM users WHERE id=%s",
(uid,)
).fetchone()
if result is None:
return None
user = UserInternal(id=result["id"].__str__(), name=result["name"], lastname=result["lastname"],
email=result["email"],
hashed_password=result["hashed_password"])
log.debug(f"Retrieved User successfully: {user.model_dump()} ")
return user
def init_table():
with PgDatabase() as db:
db.connection.execute("""
CREATE TABLE IF NOT EXISTS users (
id UUID NOT NULL PRIMARY KEY,
lastname TEXT,
name TEXT NOT NULL,
email TEXT NOT NULL UNIQUE,
hashed_password TEXT NOT NULL
);
""")
log.info("users Table initialized successfully")
class UserCreate(UserBase):
password: str

View File

@@ -6,7 +6,8 @@ from fastapi import FastAPI
import config
import database
import tv.router
import database.users
#import tv.router
from auth import password
from routers import users
@@ -19,7 +20,7 @@ database.init_db()
app = FastAPI(root_path="/api/v1")
app.include_router(users.router, tags=["users"])
app.include_router(password.router, tags=["authentication"])
app.include_router(tv.router.router, tags=["tv"])
#app.include_router(tv.router.router, tags=["tv"])

View File

@@ -1,12 +1,13 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.exc import IntegrityError
from pydantic import BaseModel
from starlette.responses import JSONResponse
import database
from auth import get_current_user
from auth.password import get_password_hash
from database.users import UserInternal, User
from database import SessionDependency, get_session
from database.users import UserInternal, UserCreate, UserPublic
from routers import log
router = APIRouter(
@@ -18,28 +19,26 @@ class Message(BaseModel):
message: str
class CreateUser(User):
""""
The Usermodel, but with an additional non-hashed password. attribute
"""
password: str
@router.post("/", status_code=201, responses={
409: {"model": Message, "description": "User with provided email already exists"},
201: {"model": UserInternal, "description": "User created successfully"}
201: {"model": UserPublic, "description": "User created successfully"}
})
async def create_user(
user: CreateUser = Depends(CreateUser),
db: SessionDependency,
user: UserCreate = Depends(UserCreate),
):
internal_user = UserInternal(name=user.name, lastname=user.lastname, email=user.email,
hashed_password=get_password_hash(user.password))
if database.users.create_user(internal_user):
log.info("Created new user", internal_user.model_dump())
return internal_user
else:
log.warning("Failed to create new user, User with this email already exists,", internal_user.model_dump())
db.add(internal_user)
try:
db.commit()
except IntegrityError as e:
log.debug(e)
log.warning("Failed to create new user, User with this email already exists "+internal_user.model_dump().__str__())
return JSONResponse(status_code=409, content={"message": "User with this email already exists"})
log.info("Created new user "+internal_user.email)
return UserPublic(**internal_user.model_dump())
@router.get("/me")
@@ -51,6 +50,6 @@ async def read_users_me(
@router.get("/me/items")
async def read_own_items(
current_user: User = Depends(get_current_user),
current_user: UserInternal = Depends(get_current_user),
):
return [{"item_id": "Foo", "owner": current_user.name}]

View File

@@ -1,90 +1,92 @@
import logging
import pprint
from typing import Literal, List, Any
from typing import List
from uuid import UUID, uuid4
import requests
import tmdbsimple as tmdb
from pydantic import BaseModel
from config import TvConfig
from database import PgDatabase
import tmdbsimple as tmdb
from sqlmodel import Field, Session, SQLModel, create_engine, select
# NOTE: use tmdbsimple for api calls
class Episode(BaseModel):
class Episode(SQLModel):
show_id: int = Field(foreign_key="show.id")
season_number: int = Field(foreign_key="season.number")
number: int
title: str
class Season(BaseModel):
class Season(SQLModel, table=True):
show_id: UUID = Field(foreign_key="show.id")
number: int
episodes: List[Episode]
def get_episode_count(self) -> int:
return self.episodes.__len__()
class Show(BaseModel):
id: UUID = uuid4()
class Show(SQLModel, table=True):
id: UUID = Field(primary_key=True)
external_id: int
metadata_provider: str
name: str
seasons: List[Season] = []
def get_season_count(self) -> int:
return self.seasons.__len__()
# def get_season_count(self) -> int:
# return self.seasons.__len__()
def get_episode_count(self) -> int:
episode_count = 0
for season in self.seasons:
episode_count += season.get_episode_count()
return episode_count
# def get_episode_count(self) -> int:
# episode_count = 0
# for season in self.seasons:
# episode_count += season.get_episode_count()
# return episode_count
def save_show(self) -> None:
with PgDatabase() as db:
db.connection.execute("""
INSERT INTO tv_show (
id,
external_id,
metadata_provider,
name,
episode_count,
season_count
)VALUES(%s,%s,%s,%s,%s,%s);
""",
(self.id,
self.external_id,
self.metadata_provider,
self.name,
self.get_episode_count(),
self.get_season_count(),
)
)
log.info("added show: " + self.__str__())
# def save_show(self) -> None:
# with PgDatabase() as db:
# db.connection.execute("""
# INSERT INTO tv_show (
# id,
# external_id,
# metadata_provider,
# name,
# episode_count,
# season_count
# )VALUES(%s,%s,%s,%s,%s,%s);
# """,
# (self.id,
# self.external_id,
# self.metadata_provider,
# self.name,
# self.get_episode_count(),
# self.get_season_count(),
# )
# )
# log.info("added show: " + self.__str__())
def get_data_from_tmdb(self) -> None:
data = tmdb.TV(self.external_id).info()
log.debug("data from tmdb: " + pprint.pformat(data))
self.name = data["original_name"]
self.metadata_provider = "tmdb"
# def get_data_from_tmdb(self) -> None:
# data = tmdb.TV(self.external_id).info()
# log.debug("data from tmdb: " + pprint.pformat(data))
# self.name = data["original_name"]
# self.metadata_provider = "tmdb"
def add_season(self, season_number: int) -> None:
data = tmdb.TV_Seasons(self.external_id, season_number).info()
log.debug("data from tmdb: " + pprint.pformat(data))
# def add_season(self, season_number: int) -> None:
# data = tmdb.TV_Seasons(self.external_id, season_number).info()
# log.debug("data from tmdb: " + pprint.pformat(data))
episodes: List[Episode] = []
for episode in data["episodes"]:
episodes.append(Episode(title=episode["name"],number=episode["episode_number"]))
# episodes: List[Episode] = []
# for episode in data["episodes"]:
# episodes.append(Episode(title=episode["name"], number=episode["episode_number"]))
season = Season(number=season_number, episodes=episodes)
# season = Season(number=season_number, episodes=episodes)
self.seasons.append(season)
# self.seasons.append(season)
# def add_seasons(self, season_numbers: List[int]) -> None:
# for season_number in season_numbers:
# self.add_season(season_number)
def add_seasons(self, season_numbers: List[int]) -> None:
for season_number in season_numbers:
self.add_season(season_number)
def get_all_shows() -> List[Show]:
with PgDatabase() as db:
@@ -93,12 +95,14 @@ def get_all_shows() -> List[Show]:
""").fetchall()
return result
def get_show(id: UUID) -> Show:
with PgDatabase() as db:
result = db.connection.execute("""
SELECT * FROM tv_show WHERE id = %s
""", (id,)).fetchone()
return Show(**result)
return Show(**result)
config = TvConfig()
log = logging.getLogger(__name__)