Files
MediaManager-maxdorninger/media_manager/tv/repository.py
2025-05-29 15:36:35 +02:00

292 lines
9.0 KiB
Python

from sqlalchemy import select, delete
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload
from media_manager.torrent.models import Torrent
from media_manager.torrent.schemas import TorrentId, Torrent as TorrentSchema
from media_manager.tv import log
from media_manager.tv.models import Season, Show, Episode, SeasonRequest, SeasonFile
from media_manager.tv.schemas import (
Season as SeasonSchema,
SeasonId,
Show as ShowSchema,
ShowId,
SeasonRequest as SeasonRequestSchema,
SeasonFile as SeasonFileSchema,
SeasonNumber,
SeasonRequestId,
RichSeasonRequest as RichSeasonRequestSchema,
)
def get_show(show_id: ShowId, db: Session) -> ShowSchema | None:
    """
    Look up a single show by primary key, eagerly loading its seasons and
    each season's episodes.

    :param show_id: The ID of the show to look up.
    :param db: The database session.
    :return: The matching show as a ShowSchema, or None when no row matches.
    """
    eager_seasons = joinedload(Show.seasons).joinedload(Season.episodes)
    query = select(Show).options(eager_seasons).where(Show.id == show_id)
    # unique() is needed when joined-eager-loading collections, because the
    # joins repeat the parent row once per child row.
    db_show = db.execute(query).unique().scalar_one_or_none()
    return ShowSchema.model_validate(db_show) if db_show else None
def get_show_by_external_id(
    external_id: int, db: Session, metadata_provider: str
) -> ShowSchema | None:
    """
    Retrieve a show by its external (metadata-provider) ID, including nested
    seasons and episodes.

    :param external_id: The show's ID at the metadata provider.
    :param db: The database session.
    :param metadata_provider: The metadata provider the external ID belongs to.
    :return: A ShowSchema object if found, otherwise None.
    """
    stmt = (
        select(Show)
        .where(Show.external_id == external_id)
        .where(Show.metadata_provider == metadata_provider)
        .options(joinedload(Show.seasons).joinedload(Season.episodes))
    )
    # unique() is needed because the collection joins repeat the show row.
    result = db.execute(stmt).unique().scalar_one_or_none()
    if result is None:
        return None
    # Validate from ORM attributes, as the other getters in this module do.
    # Unpacking ``result.__dict__`` would also pass SQLAlchemy's internal
    # ``_sa_instance_state`` and would miss any attribute not currently loaded.
    return ShowSchema.model_validate(result)
def get_shows(db: Session) -> list[ShowSchema]:
    """
    Retrieve all shows from the database, including nested seasons and episodes.

    Seasons and episodes are joined-eager-loaded in the same query. Without
    this, converting each show to a ShowSchema would lazy-load its seasons,
    and then each season's episodes, one query at a time (the N+1 pattern).

    :param db: The database session.
    :return: A list of ShowSchema objects.
    """
    stmt = select(Show).options(
        joinedload(Show.seasons).joinedload(Season.episodes)
    )
    # unique() de-duplicates show rows that the collection joins multiply.
    results = db.execute(stmt).scalars().unique().all()
    return [ShowSchema.model_validate(show) for show in results]
def save_show(show: ShowSchema, db: Session) -> ShowSchema:
    """
    Save a new show to the database, including its seasons and episodes.

    The whole graph (show -> seasons -> episodes) is built as ORM objects and
    persisted with a single commit. The show is then refreshed from the
    database before conversion.

    :param show: The ShowSchema object to save.
    :param db: The database session.
    :return: The saved show as a ShowSchema.
    :raises ValueError: If the commit raises IntegrityError, e.g. because a
        show with the same primary key already exists. Any integrity violation
        maps to this error. The session is rolled back first, and the original
        IntegrityError is chained as ``__cause__``.
    """
    db_show = Show(
        id=show.id,
        external_id=show.external_id,
        metadata_provider=show.metadata_provider,
        name=show.name,
        overview=show.overview,
        year=show.year,
        seasons=[
            Season(
                id=season.id,
                show_id=show.id,
                number=season.number,
                external_id=season.external_id,
                name=season.name,
                overview=season.overview,
                episodes=[
                    Episode(
                        id=episode.id,
                        season_id=season.id,
                        number=episode.number,
                        external_id=episode.external_id,
                        title=episode.title,
                    )
                    for episode in season.episodes
                ],
            )
            for season in show.seasons
        ],
    )
    db.add(db_show)
    # Only the commit can raise IntegrityError; keep the try body minimal.
    try:
        db.commit()
    except IntegrityError as e:
        db.rollback()
        raise ValueError("Show with this primary key already exists.") from e
    db.refresh(db_show)
    return ShowSchema.model_validate(db_show)
def delete_show(show_id: ShowId, db: Session) -> None:
    """
    Delete a show by its ID.

    The row is deleted through the ORM (``Session.delete``) rather than a bulk
    DELETE, so any cascades configured on the Show model's relationships
    apply. If no show with this ID exists, nothing is deleted and no commit is
    issued.

    :param show_id: The ID of the show to delete.
    :param db: The database session.
    :return: None.
    """
    show = db.get(Show, show_id)
    if show is None:
        # Session.delete(None) would raise UnmappedInstanceError. Treat a
        # missing show as a no-op instead, matching the bulk DELETE in
        # delete_season_request, which is also a no-op for unknown IDs.
        log.debug(f"Show {show_id} not found, nothing to delete")
        return
    db.delete(show)
    db.commit()
def get_season(season_id: SeasonId, db: Session) -> SeasonSchema:
    """
    Fetch one season by primary key.

    :param season_id: The ID of the season to fetch.
    :param db: The database session.
    :return: The season converted to a SeasonSchema.
    """
    # NOTE(review): Session.get returns None for an unknown ID. model_validate
    # then presumably raises a pydantic ValidationError rather than returning
    # None; confirm callers expect that.
    db_season = db.get(Season, season_id)
    return SeasonSchema.model_validate(db_season)
def add_season_request(season_request: SeasonRequestSchema, db: Session) -> None:
    """
    Mark a season as requested by inserting a SeasonRequest row and committing.

    The nested ``requested_by`` / ``authorized_by`` objects are reduced to
    their IDs. When either is absent (falsy), the corresponding foreign key is
    stored as None.

    :param season_request: The request to persist.
    :param db: The database session.
    """
    log.debug(f"Adding season request {season_request.model_dump()}")
    requester = season_request.requested_by
    authorizer = season_request.authorized_by
    row = SeasonRequest(
        id=season_request.id,
        season_id=season_request.season_id,
        wanted_quality=season_request.wanted_quality,
        min_quality=season_request.min_quality,
        requested_by_id=requester.id if requester else None,
        authorized=season_request.authorized,
        authorized_by_id=authorizer.id if authorizer else None,
    )
    db.add(row)
    db.commit()
def delete_season_request(season_request_id: SeasonRequestId, db: Session) -> None:
    """
    Drop a season from the 'requested' list by deleting its SeasonRequest row.

    Issues a bulk DELETE filtered by ID, then commits. If no row has that ID,
    the statement simply affects nothing.

    :param season_request_id: The ID of the season request to remove.
    :param db: The database session.
    """
    db.execute(delete(SeasonRequest).where(SeasonRequest.id == season_request_id))
    db.commit()
def get_season_by_number(
    db: Session, season_number: int, show_id: ShowId
) -> SeasonSchema:
    """
    Look up a season of a given show by its season number.

    The season's episodes and its parent show are joined-eager-loaded.

    :param db: The database session.
    :param season_number: The season number within the show.
    :param show_id: The ID of the show the season belongs to.
    :return: The matching season as a SeasonSchema.
    """
    # NOTE(review): when no season matches, None is passed to model_validate,
    # which presumably raises a pydantic ValidationError; confirm intended.
    filters = (Season.show_id == show_id, Season.number == season_number)
    stmt = (
        select(Season)
        .where(*filters)
        .options(joinedload(Season.episodes), joinedload(Season.show))
    )
    season = db.execute(stmt).unique().scalar_one_or_none()
    return SeasonSchema.model_validate(season)
def get_season_requests(db: Session) -> list[RichSeasonRequestSchema]:
    """
    List every season request together with its related objects.

    The requesting user, the authorizing user, and the requested season (with
    its show) are joined-eager-loaded, so those relationships are not
    lazy-loaded one request at a time.

    :param db: The database session.
    :return: One RichSeasonRequestSchema per SeasonRequest row.
    """
    stmt = select(SeasonRequest).options(
        joinedload(SeasonRequest.requested_by),
        joinedload(SeasonRequest.authorized_by),
        joinedload(SeasonRequest.season).joinedload(Season.show),
    )
    rich_requests: list[RichSeasonRequestSchema] = []
    for request in db.execute(stmt).scalars().unique().all():
        season = request.season
        rich_requests.append(
            RichSeasonRequestSchema(
                id=request.id,
                season_id=season.id,
                season=season,
                show=season.show,
                min_quality=request.min_quality,
                wanted_quality=request.wanted_quality,
                authorized=request.authorized,
                requested_by=request.requested_by,
                authorized_by=request.authorized_by,
            )
        )
    return rich_requests
def add_season_file(db: Session, season_file: SeasonFileSchema) -> SeasonFileSchema:
    """
    Persist a season file record and commit.

    :param db: The database session.
    :param season_file: The season file to store. Its dumped fields are passed
        as keyword arguments to a new SeasonFile row.
    :return: The same ``season_file`` object that was passed in; it is not
        re-read from the database.
    """
    row = SeasonFile(**season_file.model_dump())
    db.add(row)
    db.commit()
    return season_file
def remove_season_files_by_torrent_id(db: Session, torrent_id: TorrentId) -> None:
    """
    Delete every SeasonFile row that belongs to the given torrent, then commit.

    :param db: The database session.
    :param torrent_id: The ID of the torrent whose season files are removed.
    """
    stmt = delete(SeasonFile).where(SeasonFile.torrent_id == torrent_id)
    db.execute(stmt)
    # Commit like the other write helpers in this module. Otherwise the DELETE
    # is left pending and silently discarded if the caller never commits.
    db.commit()
def get_season_files_by_season_id(db: Session, season_id: SeasonId):
    """
    Return all SeasonFile records attached to one season.

    :param db: The database session.
    :param season_id: The ID of the season.
    :return: A list of SeasonFileSchema objects, empty if the season has none.
    """
    query = select(SeasonFile).where(SeasonFile.season_id == season_id)
    rows = db.execute(query).scalars().all()
    return list(map(SeasonFileSchema.model_validate, rows))
def get_torrents_by_show_id(db: Session, show_id: ShowId) -> list[TorrentSchema]:
    """
    Find every torrent that supplied at least one file for any season of a show.

    Torrent -> SeasonFile -> Season are inner-joined and filtered on the
    season's show. DISTINCT collapses torrents that cover several seasons or
    files.

    :param db: The database session.
    :param show_id: The ID of the show.
    :return: The distinct torrents as TorrentSchema objects.
    """
    file_on_torrent = SeasonFile.torrent_id == Torrent.id
    season_on_file = Season.id == SeasonFile.season_id
    stmt = (
        select(Torrent)
        .join(SeasonFile, file_on_torrent)
        .join(Season, season_on_file)
        .where(Season.show_id == show_id)
        .distinct()
    )
    torrents = db.execute(stmt).scalars().unique().all()
    return [TorrentSchema.model_validate(t) for t in torrents]
def get_all_shows_with_torrents(db: Session) -> list[ShowSchema]:
    """
    List, ordered by name, every show that has at least one season file
    linked to a torrent.

    The Season -> SeasonFile -> Torrent inner joins act only as an existence
    filter, and DISTINCT removes the duplicate show rows they produce. Seasons
    and episodes are populated through joinedload, independently of those
    filtering joins.

    :param db: The database session.
    :return: A list of ShowSchema objects sorted by show name.
    """
    eager_children = joinedload(Show.seasons).joinedload(Season.episodes)
    stmt = (
        select(Show)
        .join(Season, Show.id == Season.show_id)
        .join(SeasonFile, Season.id == SeasonFile.season_id)
        .join(Torrent, SeasonFile.torrent_id == Torrent.id)
        .distinct()
        .order_by(Show.name)
        .options(eager_children)
    )
    shows = db.execute(stmt).scalars().unique().all()
    return [ShowSchema.model_validate(s) for s in shows]
def get_seasons_by_torrent_id(db: Session, torrent_id: TorrentId) -> list[SeasonNumber]:
    """
    Return the season numbers covered by a torrent's files.

    Walks Torrent -> SeasonFile -> Season for the given torrent and selects
    each distinct season number.

    :param db: The database session.
    :param torrent_id: The ID of the torrent.
    :return: The distinct season numbers, each wrapped as SeasonNumber.
    """
    # select_from pins Torrent as the left side of the join chain; it is
    # listed first here purely for readability.
    stmt = (
        select(Season.number)
        .select_from(Torrent)
        .join(SeasonFile, SeasonFile.torrent_id == Torrent.id)
        .join(Season, Season.id == SeasonFile.season_id)
        .where(Torrent.id == torrent_id)
        .distinct()
    )
    numbers = db.execute(stmt).scalars().unique().all()
    return list(map(SeasonNumber, numbers))
def get_season_request(
    db: Session, season_request_id: SeasonRequestId
) -> SeasonRequestSchema:
    """
    Fetch one season request by primary key.

    :param db: The database session.
    :param season_request_id: The ID of the season request.
    :return: The request converted to a SeasonRequestSchema.
    """
    # NOTE(review): an unknown ID makes Session.get return None, which
    # model_validate presumably rejects with a ValidationError; confirm.
    request_row = db.get(SeasonRequest, season_request_id)
    return SeasonRequestSchema.model_validate(request_row)