mirror of
https://github.com/maxdorninger/MediaManager.git
synced 2026-04-17 21:54:00 +02:00
switch to the SQLModel ORM
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, status, APIRouter
|
||||
@@ -8,12 +7,12 @@ from fastapi.security import OAuth2PasswordBearer
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
from pydantic import BaseModel
|
||||
|
||||
import database
|
||||
import database.users
|
||||
from config import AuthConfig
|
||||
from database import SessionDependency
|
||||
from database.users import UserInternal
|
||||
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
@@ -30,7 +29,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/token")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal:
|
||||
async def get_current_user(db: SessionDependency, token: str = Depends(oauth2_scheme)) -> UserInternal:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
@@ -38,6 +37,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal:
|
||||
)
|
||||
config = AuthConfig()
|
||||
log.debug("token: " + token)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, config.jwt_signing_key, algorithms=[config.jwt_signing_algorithm])
|
||||
log.debug("jwt payload: " + payload.__str__())
|
||||
@@ -49,10 +49,13 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInternal:
|
||||
except InvalidTokenError:
|
||||
log.warning("received invalid token: " + token)
|
||||
raise credentials_exception
|
||||
user = database.users.get_user(uid=token_data.uid)
|
||||
|
||||
user: UserInternal | None = db.get(UserInternal, token_data.uid)
|
||||
|
||||
if user is None:
|
||||
log.debug("user not found")
|
||||
raise credentials_exception
|
||||
|
||||
log.debug("received user: " + user.__str__())
|
||||
return user
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from typing import Annotated
|
||||
|
||||
import hashlib
|
||||
|
||||
import bcrypt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlmodel import Session, select
|
||||
|
||||
import database
|
||||
from auth import create_access_token, Token, router
|
||||
from database import users
|
||||
from database import users, SessionDependency
|
||||
from database.users import UserInternal
|
||||
|
||||
|
||||
@@ -17,21 +20,18 @@ def verify_password(plain_password, hashed_password):
|
||||
)
|
||||
|
||||
|
||||
def get_password_hash(password):
|
||||
return bcrypt.hashpw(
|
||||
bytes(password, encoding="utf-8"),
|
||||
bcrypt.gensalt(),
|
||||
)
|
||||
def get_password_hash(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def authenticate_user(email: str, password: str) -> bool | UserInternal:
|
||||
def authenticate_user(db: SessionDependency, email: str, password: str) -> bool | UserInternal:
|
||||
"""
|
||||
|
||||
:param email: email of the user
|
||||
:param password: password of the user
|
||||
:return: if authentication succeeds, returns the user object with added name and lastname, otherwise or if the user doesn't exist returns False
|
||||
"""
|
||||
user = database.users.get_user(email=email)
|
||||
user: UserInternal | None = db.exec(select(UserInternal).where(UserInternal.email == email)).first()
|
||||
if not user:
|
||||
return False
|
||||
if not verify_password(password, user.hashed_password):
|
||||
@@ -42,13 +42,15 @@ def authenticate_user(email: str, password: str) -> bool | UserInternal:
|
||||
@router.post("/token")
|
||||
async def login_for_access_token(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: SessionDependency,
|
||||
) -> Token:
|
||||
user = authenticate_user(form_data.username, form_data.password)
|
||||
user = authenticate_user(db,form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token = create_access_token(data={"sub": user.id})
|
||||
# id needs to be converted because a UUID object isn't json serializable
|
||||
access_token = create_access_token(data={"sub": user.id.__str__()})
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
@@ -1,43 +1,28 @@
|
||||
import logging
|
||||
from typing import Any, Generator, Annotated
|
||||
|
||||
import psycopg
|
||||
from psycopg.rows import dict_row
|
||||
from fastapi import Depends
|
||||
from sqlmodel import create_engine, SQLModel, Session
|
||||
|
||||
import config
|
||||
from config import DbConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
config: DbConfig = config.get_db_config()
|
||||
|
||||
db_url = "postgresql+psycopg" + "://" + config.user + ":" + config.password + "@" + config.host + ":" + str(
|
||||
config.port) + "/" + config.dbname
|
||||
|
||||
engine = create_engine(db_url, echo=True)
|
||||
|
||||
|
||||
class PgDatabase:
|
||||
"""PostgreSQL Database context manager using psycopg"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.driver = psycopg
|
||||
|
||||
def connect_to_database(self, config: DbConfig = DbConfig()):
|
||||
return self.driver.connect(
|
||||
autocommit=True,
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
user=config.user,
|
||||
password=config.password,
|
||||
dbname=config.dbname,
|
||||
row_factory=dict_row
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
self.connection = self.connect_to_database()
|
||||
return self
|
||||
|
||||
def __exit__(self, exception_type, exc_val, traceback):
|
||||
self.connection.close()
|
||||
def init_db() -> None:
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
|
||||
def init_db():
|
||||
log.info("Initializing database")
|
||||
def get_session() -> Generator[Session, Any, None]:
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
from database import tv, users
|
||||
users.init_table()
|
||||
tv.init_table()
|
||||
SessionDependency = Annotated[Session, Depends(get_session)]
|
||||
|
||||
log.info("Tables initialized successfully")
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from database import PgDatabase, log
|
||||
|
||||
# TODO: add NOT NULL and default values to DB
|
||||
|
||||
def init_table():
|
||||
with PgDatabase() as db:
|
||||
db.connection.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tv_show (
|
||||
id UUID PRIMARY KEY,
|
||||
external_id TEXT,
|
||||
metadata_provider TEXT,
|
||||
name TEXT,
|
||||
episode_count INTEGER,
|
||||
season_count INTEGER,
|
||||
UNIQUE (external_id, metadata_provider)
|
||||
|
||||
);""")
|
||||
log.info("tv_show Table initialized successfully")
|
||||
db.connection.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tv_season (
|
||||
show_id UUID REFERENCES tv_show(id),
|
||||
season_number INTEGER,
|
||||
episode_count INTEGER,
|
||||
CONSTRAINT PK_season PRIMARY KEY (show_id,season_number)
|
||||
);""")
|
||||
log.info("tv_seasonTable initialized successfully")
|
||||
db.connection.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tv_episode (
|
||||
season_number INTEGER,
|
||||
show_id uuid,
|
||||
episode_number INTEGER,
|
||||
title TEXT,
|
||||
CONSTRAINT PK_episode PRIMARY KEY (season_number,show_id,episode_number),
|
||||
FOREIGN KEY (season_number, show_id) REFERENCES tv_season(season_number,show_id)
|
||||
|
||||
);""")
|
||||
log.info("tv_episode Table initialized successfully")
|
||||
@@ -1,91 +1,19 @@
|
||||
from uuid import uuid4
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
import psycopg
|
||||
from pydantic import BaseModel
|
||||
|
||||
from database import PgDatabase, log
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""
|
||||
User model
|
||||
"""
|
||||
name: str
|
||||
class UserBase(SQLModel):
|
||||
name: str = Field()
|
||||
lastname: str
|
||||
email: str
|
||||
email: str = Field(unique=True)
|
||||
|
||||
class UserPublic(UserBase):
|
||||
id: UUID = Field(primary_key=True, default_factory=uuid.uuid4)
|
||||
|
||||
class UserInternal(User):
|
||||
""""
|
||||
Internal user model, assumes the password is already hashed, when a new instance is created
|
||||
"""
|
||||
id: str = str(uuid4())
|
||||
class UserInternal(UserPublic, table=True):
|
||||
hashed_password: str
|
||||
|
||||
|
||||
def create_user(user: UserInternal) -> bool:
|
||||
"""
|
||||
|
||||
:param user: user to create, password must already be hashed
|
||||
:return: True if user was created, False otherwise
|
||||
"""
|
||||
with PgDatabase() as db:
|
||||
try:
|
||||
db.connection.execute(
|
||||
"""
|
||||
INSERT INTO users (id, name, lastname, email, hashed_password)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
""",
|
||||
(user.id, user.name, user.lastname, user.email, user.hashed_password)
|
||||
)
|
||||
except psycopg.errors.UniqueViolation as e:
|
||||
log.error(e)
|
||||
return False
|
||||
|
||||
log.info("User inserted successfully")
|
||||
log.debug(f"Inserted User: " + user.model_dump().__str__())
|
||||
return True
|
||||
|
||||
|
||||
def get_user(email: str = None, uid: str = None) -> UserInternal | None:
|
||||
"""
|
||||
either specify the email or uid to search for the user, if both parameters are provided the uid will be used
|
||||
|
||||
|
||||
:param email: the users email address
|
||||
:param uid: the users id
|
||||
:return: if user was found it a UserInternal object is returned, otherwise None
|
||||
"""
|
||||
with PgDatabase() as db:
|
||||
if email is not None and uid is None:
|
||||
result = db.connection.execute(
|
||||
"SELECT id, name, lastname, email, hashed_password FROM users WHERE email=%s",
|
||||
(email,)
|
||||
).fetchone()
|
||||
if uid is not None:
|
||||
result = db.connection.execute(
|
||||
"SELECT id, name, lastname, email, hashed_password FROM users WHERE id=%s",
|
||||
(uid,)
|
||||
).fetchone()
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
user = UserInternal(id=result["id"].__str__(), name=result["name"], lastname=result["lastname"],
|
||||
email=result["email"],
|
||||
hashed_password=result["hashed_password"])
|
||||
log.debug(f"Retrieved User successfully: {user.model_dump()} ")
|
||||
return user
|
||||
|
||||
|
||||
def init_table():
|
||||
with PgDatabase() as db:
|
||||
db.connection.execute("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id UUID NOT NULL PRIMARY KEY,
|
||||
lastname TEXT,
|
||||
name TEXT NOT NULL,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
hashed_password TEXT NOT NULL
|
||||
);
|
||||
""")
|
||||
log.info("users Table initialized successfully")
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
@@ -6,7 +6,8 @@ from fastapi import FastAPI
|
||||
|
||||
import config
|
||||
import database
|
||||
import tv.router
|
||||
import database.users
|
||||
#import tv.router
|
||||
from auth import password
|
||||
from routers import users
|
||||
|
||||
@@ -19,7 +20,7 @@ database.init_db()
|
||||
app = FastAPI(root_path="/api/v1")
|
||||
app.include_router(users.router, tags=["users"])
|
||||
app.include_router(password.router, tags=["authentication"])
|
||||
app.include_router(tv.router.router, tags=["tv"])
|
||||
#app.include_router(tv.router.router, tags=["tv"])
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
import database
|
||||
from auth import get_current_user
|
||||
from auth.password import get_password_hash
|
||||
from database.users import UserInternal, User
|
||||
from database import SessionDependency, get_session
|
||||
from database.users import UserInternal, UserCreate, UserPublic
|
||||
from routers import log
|
||||
|
||||
router = APIRouter(
|
||||
@@ -18,28 +19,26 @@ class Message(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class CreateUser(User):
|
||||
""""
|
||||
The Usermodel, but with an additional non-hashed password. attribute
|
||||
"""
|
||||
password: str
|
||||
|
||||
|
||||
@router.post("/", status_code=201, responses={
|
||||
409: {"model": Message, "description": "User with provided email already exists"},
|
||||
201: {"model": UserInternal, "description": "User created successfully"}
|
||||
201: {"model": UserPublic, "description": "User created successfully"}
|
||||
})
|
||||
async def create_user(
|
||||
user: CreateUser = Depends(CreateUser),
|
||||
db: SessionDependency,
|
||||
user: UserCreate = Depends(UserCreate),
|
||||
):
|
||||
internal_user = UserInternal(name=user.name, lastname=user.lastname, email=user.email,
|
||||
hashed_password=get_password_hash(user.password))
|
||||
if database.users.create_user(internal_user):
|
||||
log.info("Created new user", internal_user.model_dump())
|
||||
return internal_user
|
||||
else:
|
||||
log.warning("Failed to create new user, User with this email already exists,", internal_user.model_dump())
|
||||
db.add(internal_user)
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as e:
|
||||
log.debug(e)
|
||||
log.warning("Failed to create new user, User with this email already exists "+internal_user.model_dump().__str__())
|
||||
return JSONResponse(status_code=409, content={"message": "User with this email already exists"})
|
||||
log.info("Created new user "+internal_user.email)
|
||||
return UserPublic(**internal_user.model_dump())
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
@@ -51,6 +50,6 @@ async def read_users_me(
|
||||
|
||||
@router.get("/me/items")
|
||||
async def read_own_items(
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: UserInternal = Depends(get_current_user),
|
||||
):
|
||||
return [{"item_id": "Foo", "owner": current_user.name}]
|
||||
|
||||
@@ -1,90 +1,92 @@
|
||||
import logging
|
||||
import pprint
|
||||
from typing import Literal, List, Any
|
||||
from typing import List
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import requests
|
||||
import tmdbsimple as tmdb
|
||||
from pydantic import BaseModel
|
||||
|
||||
from config import TvConfig
|
||||
from database import PgDatabase
|
||||
import tmdbsimple as tmdb
|
||||
|
||||
|
||||
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
|
||||
# NOTE: use tmdbsimple for api calls
|
||||
|
||||
class Episode(BaseModel):
|
||||
class Episode(SQLModel):
|
||||
show_id: int = Field(foreign_key="show.id")
|
||||
season_number: int = Field(foreign_key="season.number")
|
||||
number: int
|
||||
title: str
|
||||
|
||||
|
||||
class Season(BaseModel):
|
||||
class Season(SQLModel, table=True):
|
||||
show_id: UUID = Field(foreign_key="show.id")
|
||||
number: int
|
||||
episodes: List[Episode]
|
||||
|
||||
def get_episode_count(self) -> int:
|
||||
return self.episodes.__len__()
|
||||
|
||||
|
||||
class Show(BaseModel):
|
||||
id: UUID = uuid4()
|
||||
class Show(SQLModel, table=True):
|
||||
id: UUID = Field(primary_key=True)
|
||||
external_id: int
|
||||
metadata_provider: str
|
||||
name: str
|
||||
seasons: List[Season] = []
|
||||
|
||||
def get_season_count(self) -> int:
|
||||
return self.seasons.__len__()
|
||||
# def get_season_count(self) -> int:
|
||||
# return self.seasons.__len__()
|
||||
|
||||
def get_episode_count(self) -> int:
|
||||
episode_count = 0
|
||||
for season in self.seasons:
|
||||
episode_count += season.get_episode_count()
|
||||
return episode_count
|
||||
# def get_episode_count(self) -> int:
|
||||
# episode_count = 0
|
||||
# for season in self.seasons:
|
||||
# episode_count += season.get_episode_count()
|
||||
# return episode_count
|
||||
|
||||
def save_show(self) -> None:
|
||||
with PgDatabase() as db:
|
||||
db.connection.execute("""
|
||||
INSERT INTO tv_show (
|
||||
id,
|
||||
external_id,
|
||||
metadata_provider,
|
||||
name,
|
||||
episode_count,
|
||||
season_count
|
||||
)VALUES(%s,%s,%s,%s,%s,%s);
|
||||
""",
|
||||
(self.id,
|
||||
self.external_id,
|
||||
self.metadata_provider,
|
||||
self.name,
|
||||
self.get_episode_count(),
|
||||
self.get_season_count(),
|
||||
)
|
||||
)
|
||||
log.info("added show: " + self.__str__())
|
||||
# def save_show(self) -> None:
|
||||
# with PgDatabase() as db:
|
||||
# db.connection.execute("""
|
||||
# INSERT INTO tv_show (
|
||||
# id,
|
||||
# external_id,
|
||||
# metadata_provider,
|
||||
# name,
|
||||
# episode_count,
|
||||
# season_count
|
||||
# )VALUES(%s,%s,%s,%s,%s,%s);
|
||||
# """,
|
||||
# (self.id,
|
||||
# self.external_id,
|
||||
# self.metadata_provider,
|
||||
# self.name,
|
||||
# self.get_episode_count(),
|
||||
# self.get_season_count(),
|
||||
# )
|
||||
# )
|
||||
# log.info("added show: " + self.__str__())
|
||||
|
||||
def get_data_from_tmdb(self) -> None:
|
||||
data = tmdb.TV(self.external_id).info()
|
||||
log.debug("data from tmdb: " + pprint.pformat(data))
|
||||
self.name = data["original_name"]
|
||||
self.metadata_provider = "tmdb"
|
||||
# def get_data_from_tmdb(self) -> None:
|
||||
# data = tmdb.TV(self.external_id).info()
|
||||
# log.debug("data from tmdb: " + pprint.pformat(data))
|
||||
# self.name = data["original_name"]
|
||||
# self.metadata_provider = "tmdb"
|
||||
|
||||
def add_season(self, season_number: int) -> None:
|
||||
data = tmdb.TV_Seasons(self.external_id, season_number).info()
|
||||
log.debug("data from tmdb: " + pprint.pformat(data))
|
||||
# def add_season(self, season_number: int) -> None:
|
||||
# data = tmdb.TV_Seasons(self.external_id, season_number).info()
|
||||
# log.debug("data from tmdb: " + pprint.pformat(data))
|
||||
|
||||
episodes: List[Episode] = []
|
||||
for episode in data["episodes"]:
|
||||
episodes.append(Episode(title=episode["name"],number=episode["episode_number"]))
|
||||
# episodes: List[Episode] = []
|
||||
# for episode in data["episodes"]:
|
||||
# episodes.append(Episode(title=episode["name"], number=episode["episode_number"]))
|
||||
|
||||
season = Season(number=season_number, episodes=episodes)
|
||||
# season = Season(number=season_number, episodes=episodes)
|
||||
|
||||
self.seasons.append(season)
|
||||
# self.seasons.append(season)
|
||||
|
||||
# def add_seasons(self, season_numbers: List[int]) -> None:
|
||||
# for season_number in season_numbers:
|
||||
# self.add_season(season_number)
|
||||
|
||||
def add_seasons(self, season_numbers: List[int]) -> None:
|
||||
for season_number in season_numbers:
|
||||
self.add_season(season_number)
|
||||
|
||||
def get_all_shows() -> List[Show]:
|
||||
with PgDatabase() as db:
|
||||
@@ -93,12 +95,14 @@ def get_all_shows() -> List[Show]:
|
||||
""").fetchall()
|
||||
return result
|
||||
|
||||
|
||||
def get_show(id: UUID) -> Show:
|
||||
with PgDatabase() as db:
|
||||
result = db.connection.execute("""
|
||||
SELECT * FROM tv_show WHERE id = %s
|
||||
""", (id,)).fetchone()
|
||||
return Show(**result)
|
||||
return Show(**result)
|
||||
|
||||
|
||||
config = TvConfig()
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
Reference in New Issue
Block a user