work on getting the season number by asking an llm

This commit is contained in:
maxDorninger
2025-03-17 21:08:49 +01:00
parent 1d532b1f08
commit 0659e62ed2
7 changed files with 141 additions and 40 deletions

View File

@@ -1,33 +0,0 @@
services:
app:
build:
context: ..
dockerfile: .devcontainer/Dockerfile
# volumes:
# - ../..:/workspaces:cached
# Overrides default command so things don't shut down after the process ends.
command: sleep infinity
# Runs app on the same network as the database container, allows "forwardPorts" in devcontainer.json function.
network_mode: service:db
# Use "forwardPorts" in **devcontainer.json** to forward an app port locally.
# (Adding the "ports" property to this file will not forward from a Codespace.)
db:
image: postgres:latest
restart: unless-stopped
volumes:
- postgres-data:/var/lib/postgresql/data
environment:
POSTGRES_USER: postgres
POSTGRES_DB: postgres
POSTGRES_PASSWORD: postgres
# Add "forwardPorts": ["5432"] to **devcontainer.json** to forward PostgreSQL locally.
# (Adding the "ports" property to this file will not forward from a Codespace.)
volumes:
postgres-data:

3
.gitignore vendored
View File

@@ -2,4 +2,5 @@
venv
MediaManager.iml
MediaManager/res
MediaManager/res/.env
MediaManager/res/.env
docker-compose.yml

View File

@@ -0,0 +1,46 @@
import json
import logging
from collections import Counter
from typing import List
from ollama import ChatResponse
from ollama import chat
from pydantic import BaseModel
log = logging.getLogger(__name__)
class NFO(BaseModel):
season: int
def get_season(nfo: str) -> int | None:
responses: List[ChatResponse] = []
parsed_responses: List[int] = []
for i in range(0, 4):
responses.append(chat(
model='qwen2.5:0.5b',
format=NFO.model_json_schema(),
messages=[
{
'role': 'user',
'content':
"which season does a torrent with the following NFO contain, output a season number, the season number is an integer? output in json please" +
nfo
},
]))
for response in responses:
season_number: int
try:
season_number: int = json.loads(response.message.content)['season']
except Exception as e:
log.warning(f"failed to parse season number: {e}")
break
parsed_responses.append(season_number)
most_common = Counter(parsed_responses).most_common(1)
log.debug(f"extracted season number: {most_common} from nfo: {nfo}")
return most_common[0][0]

View File

@@ -0,0 +1,46 @@
import json
from datetime import datetime, timedelta
from ollama import ChatResponse
from ollama import chat
from pydantic import BaseModel
class NFO(BaseModel):
season: int
# or access fields directly from the response object
start_time = datetime.now() + timedelta(seconds=300)
i = 0
failed_prompts = 0
while start_time > datetime.now():
response: ChatResponse = chat(model='qwen2.5:0.5b',
format=NFO.model_json_schema()
, messages=[
{
'role': 'user',
'content':
"which season does a torrent with the following NFO contain? output the season number, which is an integer in json please\n" +
"The.Big.Bang.Theory.(2007).Season.9.S09.(1080p.BluRay.x265.HEVC.10bit.AAC.5.1.Vyndros)"
},
])
i += 1
print("prompt #", i)
print("remaining time: ", start_time - datetime.now())
try:
json2 = json.loads(response.message.content)
print(json2)
except Exception as e:
print("prompt failed", e)
print(response.message.content)
failed_prompts += 1
if json2['season'] != 9:
failed_prompts += 1
print("prompts: ", i, " total time: 120s")
print("failed prompts: ", failed_prompts)
print("average time per prompt: ", 300 / i)
print("average time per successful prompt: ", 300 / (i - failed_prompts))
print("ratio successful/failed prompts: ", failed_prompts / (i - failed_prompts))

View File

@@ -4,7 +4,7 @@ from uuid import UUID
import tmdbsimple as tmdb
from sqlalchemy import UniqueConstraint, ForeignKeyConstraint
from sqlmodel import Field, SQLModel
from sqlmodel import Field, SQLModel, Relationship
from config import TvConfig
@@ -16,17 +16,22 @@ class Show(SQLModel, table=True):
name: str
overview: str
seasons: list["Season"] = Relationship(back_populates="show", cascade_delete=True)
class Season(SQLModel, table=True):
show_id: UUID = Field(foreign_key="show.id", primary_key=True, default_factory=uuid.uuid4)
show_id: UUID = Field(foreign_key="show.id", primary_key=True, default_factory=uuid.uuid4, ondelete="CASCADE")
number: int = Field(primary_key=True)
requested: bool = Field(default=False)
external_id: int
name: str
overview: str
show: Show = Relationship(back_populates="seasons")
episodes: list["Episode"] = Relationship(back_populates="season", cascade_delete=True)
class Episode(SQLModel, table=True):
__table_args__ = (
ForeignKeyConstraint(['show_id', 'season_number'], ['season.show_id', 'season.number']),
ForeignKeyConstraint(['show_id', 'season_number'], ['season.show_id', 'season.number'], ondelete="CASCADE"),
)
show_id: UUID = Field(primary_key=True)
season_number: int = Field( primary_key=True)
@@ -34,6 +39,8 @@ class Episode(SQLModel, table=True):
external_id: int
title: str
season: Season = Relationship(back_populates="episodes")
config = TvConfig()
log = logging.getLogger(__name__)

View File

@@ -3,7 +3,6 @@ from typing import List
from uuid import UUID
import psycopg.errors
import sqlalchemy
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from sqlmodel import select
@@ -67,12 +66,15 @@ def add_show(db: SessionDependency, show_id: int, metadata_provider: str = "tmdb
return show
@router.delete("/{show_id}", status_code=status.HTTP_200_OK)
def delete_show(db: SessionDependency, show_id: UUID):
db.delete(db.get(Show, show_id))
db.commit()
@router.patch("/{show_id}/{season}", status_code=status.HTTP_200_OK, dependencies=[Depends(auth.get_current_user)], response_model=Show)
@router.patch("/{show_id}/{season}", status_code=status.HTTP_200_OK, dependencies=[Depends(auth.get_current_user)],
response_model=Show)
def add_season(db: SessionDependency, show_id: UUID, season: int):
"""
adds requested flag to a season
@@ -84,7 +86,9 @@ def add_season(db: SessionDependency, show_id: UUID, season: int):
db.refresh(season)
return season
@router.delete("/{show_id}/{season}", status_code=status.HTTP_200_OK, dependencies=[Depends(auth.get_current_user)], response_model=Show)
@router.delete("/{show_id}/{season}", status_code=status.HTTP_200_OK, dependencies=[Depends(auth.get_current_user)],
response_model=Show)
def delete_season(db: SessionDependency, show_id: UUID, season: int):
"""
removes requested flag from a season
@@ -96,6 +100,7 @@ def delete_season(db: SessionDependency, show_id: UUID, season: int):
db.refresh(season)
return season
@router.get("/show", dependencies=[Depends(auth.get_current_user)], response_model=List[Show])
def get_shows(db: SessionDependency):
return db.exec(select(Show)).unique().fetchall()

29
docker-compose.yml Normal file
View File

@@ -0,0 +1,29 @@
services:
db:
image: postgres:latest
restart: unless-stopped
volumes:
- .\MediaManager\res\postgres:/var/lib/postgresql/data
environment:
POSTGRES_USER: MediaManager
POSTGRES_DB: MediaManager
POSTGRES_PASSWORD: MediaManager
ports:
- "5432:5432"
prowlarr:
image: lscr.io/linuxserver/prowlarr:latest
container_name: prowlarr
environment:
- PUID=1000
- PGID=1000
- TZ=Etc/UTC
volumes:
- .\MediaManager\res\prowlarr:/config
ports:
- "9696:9696"
ollama:
image: ollama/ollama
volumes:
- .\MediaManager\res\ollama:/root/.ollama
ports:
- "11434:11434"