Formatting issues

This commit is contained in:
Benjamin Harder
2025-10-01 18:43:38 +02:00
parent 1a4bd8f4be
commit 8d9a64798d
41 changed files with 725 additions and 387 deletions

23
main.py
View File

@@ -1,21 +1,24 @@
import asyncio
import signal
import types
import datetime
import signal
import sys
import types
from src.deletion_handler.deletion_handler import WatcherManager
from src.job_manager import JobManager
from src.settings.settings import Settings
from src.utils.log_setup import logger
from src.utils.startup import launch_steps
from src.deletion_handler.deletion_handler import WatcherManager
settings = Settings()
job_manager = JobManager(settings)
watch_manager = WatcherManager(settings)
def terminate(sigterm: signal.SIGTERM, frame: types.FrameType) -> None: # noqa: ARG001, pylint: disable=unused-argument
def terminate(
sigterm: signal.SIGTERM, # noqa: ARG001, pylint: disable=unused-argument
frame: types.FrameType, # noqa: ARG001, pylint: disable=unused-argument
) -> None:
"""Terminate cleanly. Needed for respecting 'docker stop'.
Args:
@@ -25,13 +28,18 @@ def terminate(sigterm: signal.SIGTERM, frame: types.FrameType) -> None: # noqa:
"""
logger.info(f"Termination signal received at {datetime.datetime.now()}.") # noqa: DTZ005
logger.info(
f"Termination signal received at {datetime.datetime.now()}."
) # noqa: DTZ005
watch_manager.stop()
sys.exit(0)
async def wait_next_run():
# Calculate next run time dynamically (to display)
next_run = datetime.datetime.now() + datetime.timedelta(minutes=settings.general.timer)
# Calculate next run time dynamically (to display)
next_run = datetime.datetime.now() + datetime.timedelta(
minutes=settings.general.timer
)
formatted_next_run = next_run.strftime("%Y-%m-%d %H:%M")
logger.verbose(f"*** Done - Next run at {formatted_next_run} ****")
@@ -39,6 +47,7 @@ async def wait_next_run():
# Wait for the next run
await asyncio.sleep(settings.general.timer * 60)
# Main function
async def main():
await launch_steps(settings)

View File

@@ -1,9 +1,10 @@
import asyncio
from pathlib import Path
from collections import defaultdict
from pathlib import Path
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
from src.utils.log_setup import logger
@@ -24,7 +25,6 @@ class DeletionHandler(FileSystemEventHandler):
deleted_file = event.src_path
asyncio.run_coroutine_threadsafe(self._queue_delete(deleted_file), self.loop)
async def _queue_delete(self, deleted_file):
async with self._lock:
self.deleted_files.add(deleted_file)
@@ -40,14 +40,18 @@ class DeletionHandler(FileSystemEventHandler):
async with self._lock:
# Copy and clear the deleted files set
files_to_process = self.deleted_files.copy()
logger.debug(f"deletion_handler.py/_process_deletes_after_delay: Deleted files: {' '.join(files_to_process)}")
logger.debug(
f"deletion_handler.py/_process_deletes_after_delay: Deleted files: {' '.join(files_to_process)}"
)
for handler in logger.handlers:
handler.flush()
self.deleted_files.clear()
# Extract parent folder paths, deduplicate them
deletions = self._group_deletions_by_folder(files_to_process)
logger.debug(f"deletion_handler.py/_process_deletes_after_delay: Folders with deletes: {' '.join(deletions.keys())}")
logger.debug(
f"deletion_handler.py/_process_deletes_after_delay: Folders with deletes: {' '.join(deletions.keys())}"
)
await self._handle_folders(deletions)
@@ -66,7 +70,7 @@ class DeletionHandler(FileSystemEventHandler):
logger.info(
f"Job 'detect_deletions' triggered media refresh on {self.arr.name} ({self.arr.base_url}): {refresh_item['title']}"
)
await self.arr.refresh_item(refresh_item['id'])
await self.arr.refresh_item(refresh_item["id"])
else:
logger.verbose(
f"Job 'detect_deletions' detected a deleted file, but couldn't find a corresponding media item on {self.arr.name} ({self.arr.base_url})"
@@ -80,6 +84,7 @@ class DeletionHandler(FileSystemEventHandler):
if self._process_task:
await self._process_task
class WatcherManager:
# Checks which folders are set up on arr and sets a watcher on them for deletes
def __init__(self, settings):
@@ -117,24 +122,25 @@ class WatcherManager:
f"Job 'detect_deletions' on {arr.name} ({arr.base_url}) does not have access to this path and will not monitor it: '{path}'"
)
logger.info(
'>>> 💡 Tip: Make sure that the paths in decluttarr and in your arr instance are identical.'
">>> 💡 Tip: Make sure that the paths in decluttarr and in your arr instance are identical."
)
if self.settings.envs.in_docker:
logger.info(
'>>> 💡 Tip: Make sure decluttarr and your arr instance have the same mount points'
">>> 💡 Tip: Make sure decluttarr and your arr instance have the same mount points"
)
return folders_to_watch
def set_watcher(self, arr, folder_to_watch):
"""Adds a file deletion watcher for the specified folder and arr instance, creating an event handler to process deletion events and an observer to monitor the filesystem; starts the observer and stores both the handler and observer for later management
"""
"""Adds a file deletion watcher for the specified folder and arr instance, creating an event handler to process deletion events and an observer to monitor the filesystem; starts the observer and stores both the handler and observer for later management"""
event_handler = DeletionHandler(arr, self.loop)
observer = Observer()
observer.schedule(event_handler, folder_to_watch, recursive=True)
observer.start()
self.handlers.append(event_handler)
logger.verbose(f"Job 'detect_deletions' started monitoring folder on {arr.name} ({arr.base_url}): {folder_to_watch}")
logger.verbose(
f"Job 'detect_deletions' started monitoring folder on {arr.name} ({arr.base_url}): {folder_to_watch}"
)
self.observers.append(observer)
def stop(self):

View File

@@ -102,7 +102,9 @@ class JobManager:
f"job_manager.py/_check_client_connection_status: Checking if {client.name} is connected"
)
if not await client.check_connected():
logger.warning(f">>> {client.name} is disconnected. Skipping queue cleaning on {self.arr.name}.")
logger.warning(
f">>> {client.name} is disconnected. Skipping queue cleaning on {self.arr.name}."
)
return False
return True

View File

@@ -10,9 +10,10 @@ class RemovalHandler:
async def remove_downloads(self, affected_downloads, blocklist):
for download_id in list(affected_downloads.keys()):
affected_download = affected_downloads[download_id]
handling_method = await self._get_handling_method(download_id, affected_download)
handling_method = await self._get_handling_method(
download_id, affected_download
)
if download_id in self.arr.tracker.deleted or handling_method == "skip":
del affected_downloads[download_id]
@@ -30,28 +31,36 @@ class RemovalHandler:
self.arr.tracker.deleted.append(download_id)
async def _remove_download(self, affected_download, download_id, blocklist):
queue_id = affected_download["queue_ids"][0]
logger.info(f"Job '{self.job_name}' triggered removal: {affected_download['title']}")
logger.info(
f"Job '{self.job_name}' triggered removal: {affected_download['title']}"
)
logger.debug(f"remove_handler.py/_remove_download: download_id={download_id}")
await self.arr.remove_queue_item(queue_id=queue_id, blocklist=blocklist)
async def _tag_as_obsolete(self, affected_download, download_id):
logger.info(f"Job '{self.job_name}' triggered obsolete-tagging: {affected_download['title']}")
logger.info(
f"Job '{self.job_name}' triggered obsolete-tagging: {affected_download['title']}"
)
for qbit in self.settings.download_clients.qbittorrent:
await qbit.set_tag(tags=[self.settings.general.obsolete_tag], hashes=[download_id])
await qbit.set_tag(
tags=[self.settings.general.obsolete_tag], hashes=[download_id]
)
async def _get_handling_method(self, download_id, affected_download):
if affected_download['protocol'] != 'torrent':
return "remove" # handling is only implemented for torrent
if affected_download["protocol"] != "torrent":
return "remove" # handling is only implemented for torrent
download_client_name = affected_download["downloadClient"]
_, download_client_type = self.settings.download_clients.get_download_client_by_name(download_client_name)
_, download_client_type = (
self.settings.download_clients.get_download_client_by_name(
download_client_name
)
)
if download_client_type != "qbittorrent":
return "remove" # handling is only implemented for qbit
return "remove" # handling is only implemented for qbit
if len(self.settings.download_clients.qbittorrent) == 0:
return "remove" # qbit not configured, thus can't tag

View File

@@ -25,38 +25,45 @@ class RemovalJob(ABC):
self.queue_manager = QueueManager(self.arr, self.settings)
self.max_strikes = getattr(self.job, "max_strikes", None)
if self.max_strikes:
self.strikes_handler = StrikesHandler(job_name=self.job_name, arr=self.arr, max_strikes=self.max_strikes)
self.strikes_handler = StrikesHandler(
job_name=self.job_name, arr=self.arr, max_strikes=self.max_strikes
)
async def run(self) -> int:
if not self.job.enabled:
return 0
logger.debug(f"removal_job.py/run: Launching job '{self.job_name}', and checking if any items in queue (queue_scope='{self.queue_scope}').")
self.queue = await self.queue_manager.get_queue_items(queue_scope=self.queue_scope)
logger.debug(
f"removal_job.py/run: Launching job '{self.job_name}', and checking if any items in queue (queue_scope='{self.queue_scope}')."
)
self.queue = await self.queue_manager.get_queue_items(
queue_scope=self.queue_scope
)
# Handle empty queue
if not self.queue:
return 0
self.affected_items = await self._find_affected_items()
self.affected_downloads = self.queue_manager.group_by_download_id(self.affected_items)
self.affected_downloads = self.queue_manager.group_by_download_id(
self.affected_items
)
# -- Checks --
self._ignore_protected()
if self.max_strikes:
self.affected_downloads = self.strikes_handler.filter_strike_exceeds(self.affected_downloads, self.queue)
self.affected_downloads = self.strikes_handler.filter_strike_exceeds(
self.affected_downloads, self.queue
)
# -- Removal --
await RemovalHandler(
arr=self.arr,
settings=self.settings,
job_name=self.job_name,
).remove_downloads(self.affected_downloads, self.blocklist)
arr=self.arr,
settings=self.settings,
job_name=self.job_name,
).remove_downloads(self.affected_downloads, self.blocklist)
return len(self.affected_downloads)
def _ignore_protected(self):
"""
Filter out downloads that are in the protected tracker.

View File

@@ -3,7 +3,6 @@ from pathlib import Path
from src.jobs.removal_job import RemovalJob
from src.utils.log_setup import logger
# fmt: off
STANDARD_EXTENSIONS = [
# Movies, TV Shows (Radarr, Sonarr, Whisparr)
@@ -61,7 +60,11 @@ class RemoveBadFiles(RemovalJob):
if not download_client_name:
continue
download_client, download_client_type = self.settings.download_clients.get_download_client_by_name(download_client_name)
download_client, download_client_type = (
self.settings.download_clients.get_download_client_by_name(
download_client_name
)
)
if not download_client or not download_client_type:
continue
@@ -69,14 +72,16 @@ class RemoveBadFiles(RemovalJob):
if download_client_type != "qbittorrent":
continue
result.setdefault(download_client, {
"download_client_type": download_client_type,
"download_ids": set(),
})["download_ids"].add(item["downloadId"])
result.setdefault(
download_client,
{
"download_client_type": download_client_type,
"download_ids": set(),
},
)["download_ids"].add(item["downloadId"])
return result
async def _handle_qbit(self, qbit_client, hashes):
"""Handle qBittorrent-specific logic for marking files as 'Do Not Download'."""
affected_items = []
@@ -86,7 +91,9 @@ class RemoveBadFiles(RemovalJob):
self.arr.tracker.extension_checked.append(qbit_item["hash"])
if qbit_item["hash"].upper() in self.arr.tracker.protected: # Do not stop files in protected torrents
if (
qbit_item["hash"].upper() in self.arr.tracker.protected
): # Do not stop files in protected torrents
continue
torrent_files = await self._get_active_files(qbit_client, qbit_item["hash"])
@@ -95,11 +102,15 @@ class RemoveBadFiles(RemovalJob):
if not stoppable_files:
continue
await self._mark_files_as_stopped(qbit_client, qbit_item["hash"], stoppable_files)
await self._mark_files_as_stopped(
qbit_client, qbit_item["hash"], stoppable_files
)
self._log_stopped_files(stoppable_files, qbit_item["name"])
if self._all_files_stopped(torrent_files, stoppable_files):
logger.verbose(">>> All files in this torrent have been marked as 'Do not Download'. Removing torrent.")
logger.verbose(
">>> All files in this torrent have been marked as 'Do not Download'. Removing torrent."
)
affected_items.extend(self._match_queue_items(qbit_item["hash"]))
return affected_items
@@ -113,28 +124,36 @@ class RemoveBadFiles(RemovalJob):
Additionally, each download should be checked at least once (for bad extensions), and thereafter only if availability drops to less than 100%
"""
return [
item for item in qbit_items
item
for item in qbit_items
if (
item.get("has_metadata")
and item["state"] in {"downloading", "forcedDL", "stalledDL"}
and (
item["hash"] not in self.arr.tracker.extension_checked
or item["availability"] < 1
)
item.get("has_metadata")
and item["state"] in {"downloading", "forcedDL", "stalledDL"}
and (
item["hash"] not in self.arr.tracker.extension_checked
or item["availability"] < 1
)
)
]
@staticmethod
async def _get_active_files(qbit_client, torrent_hash) -> list[dict]:
"""Return only files from the torrent that are still set to download, with file extension and name."""
files = await qbit_client.get_torrent_files(torrent_hash) # Await the async method
files = await qbit_client.get_torrent_files(
torrent_hash
) # Await the async method
return [
{
**f, # Include all original file properties
"file_name": Path(f["name"]).name, # Add proper filename (without folder)
"file_extension": Path(f["name"]).suffix, # Add file_extension (e.g., .mp3)
"file_name": Path(
f["name"]
).name, # Add proper filename (without folder)
"file_extension": Path(
f["name"]
).suffix, # Add file_extension (e.g., .mp3)
}
for f in files if f["priority"] > 0
for f in files
if f["priority"] > 0
]
def _log_stopped_files(self, stopped_files, torrent_name) -> None:
@@ -164,7 +183,9 @@ class RemoveBadFiles(RemovalJob):
# Check if the file has low availability
if self._is_complete_partial(file):
reasons.append(f"Low availability: {file['availability'] * 100:.1f}%")
reasons.append(
f"Low availability: {file['availability'] * 100:.1f}%"
)
# Only add to stoppable_files if there are reasons to stop the file
if reasons:
@@ -188,8 +209,8 @@ class RemoveBadFiles(RemovalJob):
file_size_mb = file.get("size", 0) / 1024 / 1024
return (
any(keyword.lower() in file_path for keyword in BAD_KEYWORDS)
and file_size_mb <= BAD_KEYWORD_LIMIT
any(keyword.lower() in file_path for keyword in BAD_KEYWORDS)
and file_size_mb <= BAD_KEYWORD_LIMIT
)
@staticmethod
@@ -206,11 +227,15 @@ class RemoveBadFiles(RemovalJob):
def _all_files_stopped(torrent_files, stoppable_files) -> bool:
"""Check if all files are either stopped (priority 0) or in the stoppable files list."""
stoppable_file_indexes = {file[0]["index"] for file in stoppable_files}
return all(f["priority"] == 0 or f["index"] in stoppable_file_indexes for f in torrent_files)
return all(
f["priority"] == 0 or f["index"] in stoppable_file_indexes
for f in torrent_files
)
def _match_queue_items(self, download_hash) -> list:
"""Find matching queue item(s) by downloadId (uppercase)."""
return [
item for item in self.queue
item
for item in self.queue
if item["downloadId"].upper() == download_hash.upper()
]

View File

@@ -26,7 +26,12 @@ class RemoveFailedImports(RemovalJob):
def _is_valid_item(item) -> bool:
"""Check if item has the necessary fields and is in a valid state."""
# Required fields that must be present in the item
required_fields = {"status", "trackedDownloadStatus", "trackedDownloadState", "statusMessages"}
required_fields = {
"status",
"trackedDownloadStatus",
"trackedDownloadState",
"statusMessages",
}
# Check if all required fields are present
if not all(field in item for field in required_fields):
@@ -38,7 +43,10 @@ class RemoveFailedImports(RemovalJob):
# Check if the tracked download state is one of the allowed states
# If all checks pass, the item is valid
return not (item["trackedDownloadState"] not in {"importPending", "importFailed", "importBlocked"})
return not (
item["trackedDownloadState"]
not in {"importPending", "importFailed", "importBlocked"}
)
def _prepare_removal_messages(self, item, patterns) -> list[str]:
"""Prepare removal messages, adding the tracked download state and matching messages."""
@@ -49,11 +57,10 @@ class RemoveFailedImports(RemovalJob):
removal_messages = [
f"↳ Tracked Download State: {item['trackedDownloadState']}",
f"↳ Status Messages:",
*[f" - {msg}" for msg in messages]
*[f" - {msg}" for msg in messages],
]
return removal_messages
@staticmethod
def _get_matching_messages(status_messages, patterns) -> list[str]:
"""Extract unique messages matching the provided patterns (or all messages if no pattern)."""
@@ -61,6 +68,7 @@ class RemoveFailedImports(RemovalJob):
msg
for status_message in status_messages
for msg in status_message.get("messages", [])
if not patterns or any(fnmatch.fnmatch(msg, pattern) for pattern in patterns)
if not patterns
or any(fnmatch.fnmatch(msg, pattern) for pattern in patterns)
]
return list(dict.fromkeys(messages))

View File

@@ -3,6 +3,7 @@ from src.utils.log_setup import logger
DISABLE_OVER_BANDWIDTH_USAGE = 0.8
class RemoveSlow(RemovalJob):
queue_scope = "normal"
blocklist = True
@@ -62,16 +63,23 @@ class RemoveSlow(RemovalJob):
def _checked_before(item, checked_ids):
download_id = item.get("downloadId", "None")
if download_id in checked_ids:
return True # One downloadId may occur in multiple items - only check once for all of them per iteration
return True # One downloadId may occur in multiple items - only check once for all of them per iteration
checked_ids.add(download_id)
return False
@staticmethod
def _missing_keys(item) -> bool:
required_keys = {"downloadId", "size", "sizeleft", "status", "protocol", "download_client", "download_client_type"}
required_keys = {
"downloadId",
"size",
"sizeleft",
"status",
"protocol",
"download_client",
"download_client_type",
}
return not required_keys.issubset(item)
@staticmethod
def _not_downloading(item) -> bool:
return item.get("status") != "downloading"
@@ -88,13 +96,16 @@ class RemoveSlow(RemovalJob):
download_progress = await self._get_download_progress(item, download_id)
previous_progress, increment, speed = self._compute_increment_and_speed(
download_id, download_progress,
download_id,
download_progress,
)
# For SABnzbd, use calculated speed from API data
if item["download_client_type"] == "sabnzbd":
try:
api_speed = await item["download_client"].get_item_download_speed(download_id)
api_speed = await item["download_client"].get_item_download_speed(
download_id
)
if api_speed is not None:
speed = api_speed
logger.debug(f"SABnzbd API speed for {item['title']}: {speed} KB/s")
@@ -103,19 +114,22 @@ class RemoveSlow(RemovalJob):
self.arr.tracker.download_progress[download_id] = download_progress
return download_progress, previous_progress, increment, speed
async def _get_download_progress(self, item, download_id):
# Grabs the progress from qbit or SABnzbd if possible, else calculates it based on progress (imprecise)
if item["download_client_type"] == "qbittorrent":
try:
progress = await item["download_client"].fetch_download_progress(download_id)
progress = await item["download_client"].fetch_download_progress(
download_id
)
if progress is not None:
return progress
except Exception: # noqa: BLE001
pass # fall back below
elif item["download_client_type"] == "sabnzbd":
try:
progress = await item["download_client"].fetch_download_progress(download_id)
progress = await item["download_client"].fetch_download_progress(
download_id
)
if progress is not None:
return progress
except Exception: # noqa: BLE001
@@ -152,11 +166,14 @@ class RemoveSlow(RemovalJob):
# Adds the download client to the queue item
for item in self.queue:
download_client_name = item["downloadClient"]
download_client, download_client_type = self.settings.download_clients.get_download_client_by_name(download_client_name)
download_client, download_client_type = (
self.settings.download_clients.get_download_client_by_name(
download_client_name
)
)
item["download_client"] = download_client
item["download_client_type"] = download_client_type
async def update_bandwidth_usage(self):
# Refreshes the current bandwidth usage for each client
processed_clients = set()

View File

@@ -84,7 +84,7 @@ class SearchHandler:
return items[: self.job.max_concurrent_searches]
def _filter_already_downloading(self, wanted_items, queue):
queue_ids = {q['detail_item_id'] for q in queue}
queue_ids = {q["detail_item_id"] for q in queue}
return [item for item in wanted_items if item["id"] not in queue_ids]
async def _trigger_search(self, items):

View File

@@ -1,6 +1,8 @@
import logging
from src.utils.log_setup import logger
class StrikesHandler:
def __init__(self, job_name, arr, max_strikes):
self.job_name = job_name
@@ -9,7 +11,9 @@ class StrikesHandler:
self.tracker.defective.setdefault(job_name, {})
def filter_strike_exceeds(self, affected_downloads, queue):
recovered, removed_from_queue, paused = self._recover_downloads(affected_downloads, queue)
recovered, removed_from_queue, paused = self._recover_downloads(
affected_downloads, queue
)
strike_exceeds = self._apply_strikes_and_filter(affected_downloads)
if logger.isEnabledFor(logging.DEBUG):
self.log_change(recovered, removed_from_queue, paused, strike_exceeds)
@@ -23,14 +27,21 @@ class StrikesHandler:
if entry:
entry["tracking_paused"] = True
entry["pause_reason"] = reason
logger.debug("strikes_handler.py/StrikesHandler/pause_entry: Paused tracking for %s due to: %s", download_id, reason)
logger.debug(
"strikes_handler.py/StrikesHandler/pause_entry: Paused tracking for %s due to: %s",
download_id,
reason,
)
def unpause_entry(self, download_id):
entry = self.get_entry(download_id)
if entry:
entry.pop("tracking_paused", None)
entry.pop("pause_reason", None)
logger.debug("strikes_handler.py/StrikesHandler/unpause_entry: Unpaused tracking for %s", download_id)
logger.debug(
"strikes_handler.py/StrikesHandler/unpause_entry: Unpaused tracking for %s",
download_id,
)
# pylint: disable=too-many-locals, too-many-branches
def log_change(self, recovered, removed_from_queue, paused, strike_exceeds):
@@ -57,7 +68,9 @@ class StrikesHandler:
strikes = entry.get("strikes")
if d_id in paused:
reason = entry.get("pause_reason", "unknown reason")
paused_entries.append(f"'{d_id}' [{strikes}/{self.max_strikes}, {reason}]")
paused_entries.append(
f"'{d_id}' [{strikes}/{self.max_strikes}, {reason}]"
)
elif d_id in strike_exceeds:
strike_exceeded.append(f"'{d_id}' [{strikes}/{self.max_strikes}]")
elif strikes == 1:
@@ -71,27 +84,37 @@ class StrikesHandler:
for d_id in removed_from_queue:
removed_entries.append(d_id)
log_lines = [f"strikes_handler.py/log_change/defective tracker '{self.job_name}':"]
log_lines = [
f"strikes_handler.py/log_change/defective tracker '{self.job_name}':"
]
if added:
log_lines.append(f"Added ({len(added)}): {', '.join(added)}")
if incremented:
log_lines.append(f"Incremented ({len(incremented)}) [strikes]: {', '.join(incremented)}")
log_lines.append(
f"Incremented ({len(incremented)}) [strikes]: {', '.join(incremented)}"
)
if paused_entries:
log_lines.append(f"Tracking Paused ({len(paused_entries)}) [strikes, reason]: {', '.join(paused_entries)}")
log_lines.append(
f"Tracking Paused ({len(paused_entries)}) [strikes, reason]: {', '.join(paused_entries)}"
)
if removed_entries:
log_lines.append(f"Removed from queue ({len(removed_entries)}): {', '.join(removed_entries)}")
log_lines.append(
f"Removed from queue ({len(removed_entries)}): {', '.join(removed_entries)}"
)
if recovered_entries:
log_lines.append(f"Recovered ({len(recovered_entries)}): {', '.join(recovered_entries)}")
log_lines.append(
f"Recovered ({len(recovered_entries)}): {', '.join(recovered_entries)}"
)
if strike_exceeded:
log_lines.append(f"Strikes Exceeded ({len(strike_exceeded)}): {', '.join(strike_exceeded)}")
log_lines.append(
f"Strikes Exceeded ({len(strike_exceeded)}): {', '.join(strike_exceeded)}"
)
logger.debug("\n".join(log_lines))
return added, incremented, paused, recovered, strike_exceeds, removed_from_queue
def _recover_downloads(self, affected_downloads, queue):
"""
Identifies downloads that were previously tracked and are now no longer affected as recovered.
@@ -130,7 +153,9 @@ class StrikesHandler:
log_level = logger.verbose
removed_from_queue.append(d_id)
log_level(f">>> Job '{self.job_name,}' no longer flagging download (download {recovery_reason}): {entry['title']}")
log_level(
f">>> Job '{self.job_name,}' no longer flagging download (download {recovery_reason}): {entry['title']}"
)
del job_tracker[d_id]
return recovered, removed_from_queue, paused
@@ -154,7 +179,6 @@ class StrikesHandler:
entry["strikes"] += 1
return entry["strikes"]
def _log_strike_status(self, title, strikes, strikes_left):
# -1 is the first time no strikes are remaining and thus removal will be triggered
# Since the removal itself sparks an appropriate message, we don't need to show the message again here on info-level
@@ -172,7 +196,7 @@ class StrikesHandler:
title,
)
if strikes_left <= -2: # noqa: PLR2004
if strikes_left <= -2: # noqa: PLR2004
logger.info(
'>>> 💡 Tip: Since this download should already have been removed in a previous iteration but keeps coming back, this indicates the blocking of the torrent does not work correctly. Consider turning on the option "Reject Blocklisted Torrent Hashes While Grabbing" on the indexer in the *arr app: %s',
title,

View File

@@ -18,8 +18,7 @@ def filter_internal_attributes(data, internal_attributes, hide_internal_attr):
def clean_dict(data, sensitive_attributes, internal_attributes, hide_internal_attr):
"""Clean a dictionary by masking sensitive attributes and filtering internal ones."""
cleaned = {
k: mask_sensitive_value(v, k, sensitive_attributes)
for k, v in data.items()
k: mask_sensitive_value(v, k, sensitive_attributes) for k, v in data.items()
}
return filter_internal_attributes(cleaned, internal_attributes, hide_internal_attr)
@@ -29,9 +28,20 @@ def clean_list(obj, sensitive_attributes, internal_attributes, hide_internal_att
cleaned_list = []
for entry in obj:
if isinstance(entry, dict):
cleaned_list.append(clean_dict(entry, sensitive_attributes, internal_attributes, hide_internal_attr))
cleaned_list.append(
clean_dict(
entry, sensitive_attributes, internal_attributes, hide_internal_attr
)
)
elif hasattr(entry, "__dict__"):
cleaned_list.append(clean_dict(vars(entry), sensitive_attributes, internal_attributes, hide_internal_attr))
cleaned_list.append(
clean_dict(
vars(entry),
sensitive_attributes,
internal_attributes,
hide_internal_attr,
)
)
else:
cleaned_list.append(entry)
return cleaned_list
@@ -40,9 +50,13 @@ def clean_list(obj, sensitive_attributes, internal_attributes, hide_internal_att
def clean_object(obj, sensitive_attributes, internal_attributes, hide_internal_attr):
"""Clean an object (either a dict, class instance, or other types)."""
if isinstance(obj, dict):
return clean_dict(obj, sensitive_attributes, internal_attributes, hide_internal_attr)
return clean_dict(
obj, sensitive_attributes, internal_attributes, hide_internal_attr
)
if hasattr(obj, "__dict__"):
return clean_dict(vars(obj), sensitive_attributes, internal_attributes, hide_internal_attr)
return clean_dict(
vars(obj), sensitive_attributes, internal_attributes, hide_internal_attr
)
return mask_sensitive_value(obj, "", sensitive_attributes)
@@ -68,7 +82,10 @@ def get_config_as_yaml(
# Process list-based config
if isinstance(obj, list):
cleaned_list = clean_list(
obj, sensitive_attributes, internal_attributes, hide_internal_attr,
obj,
sensitive_attributes,
internal_attributes,
hide_internal_attr,
)
if cleaned_list:
config_output[key] = cleaned_list
@@ -76,9 +93,14 @@ def get_config_as_yaml(
# Process dict or class-like object config
else:
cleaned_obj = clean_object(
obj, sensitive_attributes, internal_attributes, hide_internal_attr,
obj,
sensitive_attributes,
internal_attributes,
hide_internal_attr,
)
if cleaned_obj:
config_output[key] = cleaned_obj
return yaml.dump(config_output, indent=2, default_flow_style=False, sort_keys=False).strip()
return yaml.dump(
config_output, indent=2, default_flow_style=False, sort_keys=False
).strip()

View File

@@ -80,7 +80,6 @@ class DownloadClients:
raise ValueError(error)
seen.add(name.lower())
def get_download_client_by_name(
self, name: str, download_client_type: str | None = None
):
@@ -116,7 +115,6 @@ class DownloadClients:
download_client_type = mapping.get(arr_download_client_implementation)
return download_client_type
def list_download_clients(self) -> dict[str, list[str]]:
"""
Return a dict mapping download_client_type to list of client names

View File

@@ -1,7 +1,7 @@
from packaging import version
from src.settings._constants import ApiEndpoints, MinVersions
from src.utils.common import make_request, wait_and_exit, extract_json_from_response
from src.utils.common import extract_json_from_response, make_request, wait_and_exit
from src.utils.log_setup import logger

View File

@@ -76,15 +76,11 @@ class SabnzbdClient:
async def fetch_version(self):
"""Fetch the current SABnzbd version."""
logger.debug("_download_clients_sabnzbd.py/fetch_version: Getting SABnzbd Version")
params = {
"mode": "version",
"apikey": self.api_key,
"output": "json"
}
response = await make_request(
"get", self.api_url, self.settings, params=params
logger.debug(
"_download_clients_sabnzbd.py/fetch_version: Getting SABnzbd Version"
)
params = {"mode": "version", "apikey": self.api_key, "output": "json"}
response = await make_request("get", self.api_url, self.settings, params=params)
response_data = response.json()
self.version = response_data.get("version", "unknown")
logger.debug(
@@ -108,11 +104,7 @@ class SabnzbdClient:
logger.debug(
"_download_clients_sabnzbd.py/check_sabnzbd_reachability: Checking if SABnzbd is reachable"
)
params = {
"mode": "version",
"apikey": self.api_key,
"output": "json"
}
params = {"mode": "version", "apikey": self.api_key, "output": "json"}
await make_request(
"get",
self.api_url,
@@ -132,11 +124,7 @@ class SabnzbdClient:
logger.debug(
"_download_clients_sabnzbd.py/check_connected: Checking if SABnzbd is connected"
)
params = {
"mode": "status",
"apikey": self.api_key,
"output": "json"
}
params = {"mode": "status", "apikey": self.api_key, "output": "json"}
response = await make_request(
"get",
self.api_url,
@@ -164,12 +152,10 @@ class SabnzbdClient:
async def get_queue_items(self):
"""Fetch queue items from SABnzbd."""
logger.debug("_download_clients_sabnzbd.py/get_queue_items: Getting queue items")
params = {
"mode": "queue",
"apikey": self.api_key,
"output": "json"
}
logger.debug(
"_download_clients_sabnzbd.py/get_queue_items: Getting queue items"
)
params = {"mode": "queue", "apikey": self.api_key, "output": "json"}
response = await make_request(
"get",
self.api_url,
@@ -181,12 +167,10 @@ class SabnzbdClient:
async def get_history_items(self):
"""Fetch history items from SABnzbd."""
logger.debug("_download_clients_sabnzbd.py/get_history_items: Getting history items")
params = {
"mode": "history",
"apikey": self.api_key,
"output": "json"
}
logger.debug(
"_download_clients_sabnzbd.py/get_history_items: Getting history items"
)
params = {"mode": "history", "apikey": self.api_key, "output": "json"}
response = await make_request(
"get",
self.api_url,
@@ -198,13 +182,15 @@ class SabnzbdClient:
async def remove_download(self, nzo_id: str):
"""Remove a download from SABnzbd queue."""
logger.debug(f"_download_clients_sabnzbd.py/remove_download: Removing download {nzo_id}")
logger.debug(
f"_download_clients_sabnzbd.py/remove_download: Removing download {nzo_id}"
)
params = {
"mode": "queue",
"name": "delete",
"value": nzo_id,
"apikey": self.api_key,
"output": "json"
"output": "json",
}
await make_request(
"get",
@@ -215,13 +201,15 @@ class SabnzbdClient:
async def pause_download(self, nzo_id: str):
"""Pause a download in SABnzbd queue."""
logger.debug(f"_download_clients_sabnzbd.py/pause_download: Pausing download {nzo_id}")
logger.debug(
f"_download_clients_sabnzbd.py/pause_download: Pausing download {nzo_id}"
)
params = {
"mode": "queue",
"name": "pause",
"value": nzo_id,
"apikey": self.api_key,
"output": "json"
"output": "json",
}
await make_request(
"get",
@@ -232,13 +220,15 @@ class SabnzbdClient:
async def resume_download(self, nzo_id: str):
"""Resume a download in SABnzbd queue."""
logger.debug(f"_download_clients_sabnzbd.py/resume_download: Resuming download {nzo_id}")
logger.debug(
f"_download_clients_sabnzbd.py/resume_download: Resuming download {nzo_id}"
)
params = {
"mode": "queue",
"name": "resume",
"value": nzo_id,
"apikey": self.api_key,
"output": "json"
"output": "json",
}
await make_request(
"get",
@@ -249,12 +239,14 @@ class SabnzbdClient:
async def retry_download(self, nzo_id: str):
"""Retry a failed download from SABnzbd history."""
logger.debug(f"_download_clients_sabnzbd.py/retry_download: Retrying download {nzo_id}")
logger.debug(
f"_download_clients_sabnzbd.py/retry_download: Retrying download {nzo_id}"
)
params = {
"mode": "retry",
"value": nzo_id,
"apikey": self.api_key,
"output": "json"
"output": "json",
}
await make_request(
"get",
@@ -311,11 +303,7 @@ class SabnzbdClient:
async def get_download_speed(self):
"""Get current download speed from SABnzbd status."""
params = {
"mode": "status",
"apikey": self.api_key,
"output": "json"
}
params = {"mode": "status", "apikey": self.api_key, "output": "json"}
response = await make_request(
"get",
self.api_url,

View File

@@ -23,17 +23,29 @@ class General:
self.log_level = general_config.get("log_level", self.log_level.upper())
self.test_run = general_config.get("test_run", self.test_run)
self.timer = general_config.get("timer", self.timer)
self.ssl_verification = general_config.get("ssl_verification", self.ssl_verification)
self.ignored_download_clients = general_config.get("ignored_download_clients", self.ignored_download_clients)
self.ssl_verification = general_config.get(
"ssl_verification", self.ssl_verification
)
self.ignored_download_clients = general_config.get(
"ignored_download_clients", self.ignored_download_clients
)
self.private_tracker_handling = general_config.get("private_tracker_handling", self.private_tracker_handling)
self.public_tracker_handling = general_config.get("public_tracker_handling", self.public_tracker_handling)
self.private_tracker_handling = general_config.get(
"private_tracker_handling", self.private_tracker_handling
)
self.public_tracker_handling = general_config.get(
"public_tracker_handling", self.public_tracker_handling
)
self.obsolete_tag = general_config.get("obsolete_tag", self.obsolete_tag)
self.protected_tag = general_config.get("protected_tag", self.protected_tag)
# Validate tracker handling settings
self.private_tracker_handling = self._validate_tracker_handling(self.private_tracker_handling, "private_tracker_handling")
self.public_tracker_handling = self._validate_tracker_handling(self.public_tracker_handling, "public_tracker_handling")
self.private_tracker_handling = self._validate_tracker_handling(
self.private_tracker_handling, "private_tracker_handling"
)
self.public_tracker_handling = self._validate_tracker_handling(
self.public_tracker_handling, "public_tracker_handling"
)
self.obsolete_tag = self._determine_obsolete_tag(self.obsolete_tag)
validate_data_types(self)

View File

@@ -1,4 +1,5 @@
from pathlib import Path
import requests
from packaging import version
@@ -9,10 +10,10 @@ from src.settings._constants import (
DetailItemSearchCommand,
FullQueueParameter,
MinVersions,
RefreshItemKey,
RefreshItemCommand,
RefreshItemKey,
)
from src.utils.common import make_request, wait_and_exit, extract_json_from_response
from src.utils.common import extract_json_from_response, make_request, wait_and_exit
from src.utils.log_setup import logger
@@ -154,7 +155,7 @@ class ArrInstance:
self.detail_item_id_key = self.detail_item_key + "Id"
self.detail_item_ids_key = self.detail_item_key + "Ids"
self.detail_item_search_command = getattr(DetailItemSearchCommand, arr_type)
if self.arr_type in ('radarr','sonarr'):
if self.arr_type in ("radarr", "sonarr"):
self.refresh_item_key = getattr(RefreshItemKey, arr_type)
self.refresh_item_id_key = self.refresh_item_key + "Id"
self.refresh_item_command = getattr(RefreshItemCommand, arr_type)

View File

@@ -109,7 +109,6 @@ class Jobs:
)
self.detect_deletions = JobParams()
def _set_job_configs(self, config):
# Populate jobs from YAML config
for job_name in self.__dict__:

View File

@@ -1,5 +1,6 @@
import os
from pathlib import Path
import yaml
from src.utils.log_setup import logger

View File

@@ -23,7 +23,9 @@ def validate_data_types(cls, default_cls=None):
continue
value = getattr(cls, attr)
default_source = default_cls if default_cls and hasattr(default_cls, attr) else cls.__class__
default_source = (
default_cls if default_cls and hasattr(default_cls, attr) else cls.__class__
)
default_value = getattr(default_source, attr, None)
if value == default_value:

View File

@@ -1,8 +1,9 @@
import asyncio
import copy
import logging
import sys
import time
import logging
import copy
import requests
from src.utils.log_setup import logger
@@ -23,7 +24,11 @@ def sanitize_kwargs(data):
if isinstance(data, dict):
redacted = {}
for key, value in data.items():
if key.lower() in {"username", "password", "x-api-key", "apikey", "cookies"} and value:
if (
key.lower()
in {"username", "password", "x-api-key", "apikey", "cookies"}
and value
):
redacted[key] = "[**redacted**]"
else:
redacted[key] = sanitize_kwargs(value)
@@ -34,7 +39,13 @@ def sanitize_kwargs(data):
async def make_request(
method: str, endpoint: str, settings, timeout: int = 15, *, log_error=True, **kwargs,
method: str,
endpoint: str,
settings,
timeout: int = 15,
*,
log_error=True,
**kwargs,
) -> requests.Response:
"""
A utility function to make HTTP requests (GET, POST, DELETE, PUT).
@@ -50,7 +61,6 @@ async def make_request(
)
return DummyResponse(text="Test run - no actual call made", status_code=200)
try:
if logger.isEnabledFor(logging.DEBUG):
sanitized_kwargs = sanitize_kwargs(copy.deepcopy(kwargs))

View File

@@ -33,7 +33,9 @@ logger = logging.getLogger(__name__)
def set_handler_format(log_handler, *, long_format=True):
if long_format:
target_format = logging.Formatter("%(asctime)s | %(levelname)-7s | %(message)s", "%Y-%m-%d %H:%M:%S")
target_format = logging.Formatter(
"%(asctime)s | %(levelname)-7s | %(message)s", "%Y-%m-%d %H:%M:%S"
)
else:
target_format = logging.Formatter("%(levelname)-7s | %(message)s")
log_handler.setFormatter(target_format)
@@ -56,7 +58,9 @@ def configure_logging(settings):
Path(log_dir).mkdir(exist_ok=True, parents=True)
# File handler
file_handler = RotatingFileHandler(log_file, maxBytes=50 * 1024 * 1024, backupCount=2)
file_handler = RotatingFileHandler(
log_file, maxBytes=50 * 1024 * 1024, backupCount=2
)
set_handler_format(file_handler, long_format=True)
logger.addHandler(file_handler)

View File

@@ -1,6 +1,7 @@
import logging
from typing import Union
from src.utils.common import make_request, extract_json_from_response
from src.utils.common import extract_json_from_response, make_request
from src.utils.log_setup import logger
@@ -31,8 +32,12 @@ class QueueManager:
full_queue = await self._get_queue(full_queue=True)
normal_queue = await self._get_queue(full_queue=False)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"queue_manager.py/get_queue_items (full) to determine orphans: Current queue ({len(full_queue)} items) = {self.format_queue(full_queue)}")
logger.debug(f"queue_manager.py/get_queue_items (normal) to determine orphans: Current queue ({len(normal_queue)} items) = {self.format_queue(normal_queue)}")
logger.debug(
f"queue_manager.py/get_queue_items (full) to determine orphans: Current queue ({len(full_queue)} items) = {self.format_queue(full_queue)}"
)
logger.debug(
f"queue_manager.py/get_queue_items (normal) to determine orphans: Current queue ({len(normal_queue)} items) = {self.format_queue(normal_queue)}"
)
queue_items = [fq for fq in full_queue if fq not in normal_queue]
elif queue_scope == "full":
queue_items = await self._get_queue(full_queue=True)
@@ -40,7 +45,9 @@ class QueueManager:
error = f"Invalid queue_scope: {queue_scope}"
raise ValueError(error)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"queue_manager.py/get_queue_items ({queue_scope}): Current queue ({len(queue_items)} items) = {self.format_queue(queue_items)}")
logger.debug(
f"queue_manager.py/get_queue_items ({queue_scope}): Current queue ({len(queue_items)} items) = {self.format_queue(queue_items)}"
)
return queue_items
async def _get_queue(self, *, full_queue=False):
@@ -75,7 +82,9 @@ class QueueManager:
async def _get_total_records_count(self, full_queue):
# Get the total number of records from the queue using `arr.full_queue_parameter`
params = {self.arr.full_queue_parameter: full_queue}
logger.debug("queue_manager.py/_get_total_records_count: Getting Total Records Count")
logger.debug(
"queue_manager.py/_get_total_records_count: Getting Total Records Count"
)
total_records = await self.fetch_queue_field(params, key="totalRecords")
return total_records
@@ -88,11 +97,12 @@ class QueueManager:
if full_queue:
params |= {self.arr.full_queue_parameter: full_queue}
logger.debug(f"queue_manager.py/_get_arr_records: Getting queue records ({total_records_count} items)")
logger.debug(
f"queue_manager.py/_get_arr_records: Getting queue records ({total_records_count} items)"
)
records = await self.fetch_queue_field(params, key="records")
return records
async def fetch_queue_field(self, params, key: str | None = None):
# Gets the response of the /queue endpoint and extracts a specific field from the json response
response = await make_request(
@@ -104,7 +114,6 @@ class QueueManager:
)
return extract_json_from_response(response, key=key)
def _filter_out_ignored_statuses(
self, queue, ignored_statuses=("delay", "downloadClientUnavailable")
):
@@ -202,7 +211,6 @@ class QueueManager:
return grouped_dict
@staticmethod
def filter_queue(
queue: list[dict],

View File

@@ -4,45 +4,53 @@ from src.utils.log_setup import logger
def show_welcome(settings):
messages = ["🎉🎉🎉 Decluttarr - Application Started! 🎉🎉🎉",
"-" * 80,
"⭐️ Like this app?",
"Thanks for giving it a ⭐️ on GitHub!",
"https://github.com/ManiMatter/decluttarr/"]
messages = [
"🎉🎉🎉 Decluttarr - Application Started! 🎉🎉🎉",
"-" * 80,
"⭐️ Like this app?",
"Thanks for giving it a ⭐️ on GitHub!",
"https://github.com/ManiMatter/decluttarr/",
]
# Show welcome message
# Show info level tip
if settings.general.log_level == "INFO":
messages.extend([
"",
"💡 Tip: More logs?",
"If you want to know more about what's going on, switch log level to 'VERBOSE'",
])
messages.extend(
[
"",
"💡 Tip: More logs?",
"If you want to know more about what's going on, switch log level to 'VERBOSE'",
]
)
# Show bug report tip
messages.extend([
"",
"🐛 Found a bug?",
"Before reporting bugs on GitHub, please:",
"1) Check the readme on github",
"2) Check open and closed issues on github",
"3) Switch your logs to 'DEBUG' level",
"4) Turn off any features other than the one(s) causing it",
"5) Provide the full logs via pastebin on your GitHub issue",
"Once submitted, thanks for being responsive and helping debug / re-test",
])
messages.extend(
[
"",
"🐛 Found a bug?",
"Before reporting bugs on GitHub, please:",
"1) Check the readme on github",
"2) Check open and closed issues on github",
"3) Switch your logs to 'DEBUG' level",
"4) Turn off any features other than the one(s) causing it",
"5) Provide the full logs via pastebin on your GitHub issue",
"Once submitted, thanks for being responsive and helping debug / re-test",
]
)
# Show test mode tip
if settings.general.test_run:
messages.extend([
"",
"=================== IMPORTANT ====================",
" ⚠️ ⚠️ ⚠️ TEST MODE IS ACTIVE ⚠️ ⚠️ ⚠️",
"Decluttarr won't actually do anything for you...",
"You can change this via the setting 'test_run'",
"==================================================",
])
messages.extend(
[
"",
"=================== IMPORTANT ====================",
" ⚠️ ⚠️ ⚠️ TEST MODE IS ACTIVE ⚠️ ⚠️ ⚠️",
"Decluttarr won't actually do anything for you...",
"You can change this via the setting 'test_run'",
"==================================================",
]
)
messages.append("")
# Log all messages at once

View File

@@ -1,4 +1,4 @@
from src.utils.common import make_request, extract_json_from_response
from src.utils.common import extract_json_from_response, make_request
class WantedManager:
@@ -16,7 +16,9 @@ class WantedManager:
return await self._get_arr_records(missing_or_cutoff, total_records_count)
async def _get_total_records_count(self, missing_or_cutoff: str) -> int:
total_records = await self.fetch_wanted_field(missing_or_cutoff, key="totalRecords")
total_records = await self.fetch_wanted_field(
missing_or_cutoff, key="totalRecords"
)
return total_records
async def _get_arr_records(self, missing_or_cutoff, total_records_count):
@@ -27,10 +29,14 @@ class WantedManager:
sort_key = f"{self.arr.detail_item_key}s.lastSearchTime"
params = {"page": "1", "pageSize": total_records_count, "sortKey": sort_key}
records = await self.fetch_wanted_field(missing_or_cutoff, params=params, key="records")
records = await self.fetch_wanted_field(
missing_or_cutoff, params=params, key="records"
)
return records
async def fetch_wanted_field(self, missing_or_cutoff: str, params: dict | None = None, key: str | None = None):
async def fetch_wanted_field(
self, missing_or_cutoff: str, params: dict | None = None, key: str | None = None
):
# Gets the response of the /queue endpoint and extracts a specific field from the json response
response = await make_request(
method="GET",

View File

@@ -2,9 +2,10 @@
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.deletion_handler.deletion_handler import WatcherManager, DeletionHandler
from src.deletion_handler.deletion_handler import DeletionHandler, WatcherManager
@pytest.mark.asyncio
@@ -103,6 +104,7 @@ def test_group_deletions_by_folder():
# Also check no extra keys
assert set(deletions.keys()) == set(expected.keys())
@pytest.mark.asyncio
async def test_process_deletes_after_delay_clears_deleted_files(monkeypatch):
"""Tests that _process_deletes_after_delay clears deleted files and correctly processes their parent folders asynchronously."""
@@ -149,8 +151,6 @@ async def test_process_deletes_after_delay_clears_deleted_files(monkeypatch):
assert set(arr.called) == expected_folders
@pytest.mark.asyncio
async def test_file_deletion_triggers_handler_with_watchermanager(tmp_path):
"""Tests that when a file is deleted in a watched directory,

View File

@@ -1,4 +1,5 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from src.jobs.removal_handler import RemovalHandler
@@ -32,7 +33,10 @@ async def test_get_handling_method(
settings.download_clients.qbittorrent = ["dummy"] if qbittorrent_configured else []
# Simulate (client_name, client_type) return
settings.download_clients.get_download_client_by_name.return_value = ("client_name", client_type)
settings.download_clients.get_download_client_by_name.return_value = (
"client_name",
client_type,
)
settings.general.private_tracker_handling = "private_handling"
settings.general.public_tracker_handling = "public_handling"
@@ -44,7 +48,7 @@ async def test_get_handling_method(
"protocol": protocol,
}
result = await handler._get_handling_method( # pylint: disable=W0212
result = await handler._get_handling_method( # pylint: disable=W0212
"A", affected_download
)
assert result == expected

View File

@@ -1,9 +1,11 @@
from pathlib import Path
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
import pytest
from src.jobs.remove_bad_files import RemoveBadFiles
# Fixture for arr mock
@pytest.fixture(name="removal_job")
def fixture_removal_job():
@@ -18,12 +20,12 @@ def fixture_removal_job():
@pytest.mark.parametrize(
"file_name, expected_result, keep_archives",
[
("file.mp4", False, False), # Good extension
("file.mkv", False, False), # Good extension
("file.avi", False, False), # Good extension
("file.mp4", False, False), # Good extension
("file.mkv", False, False), # Good extension
("file.avi", False, False), # Good extension
("file.exe", True, False), # Bad extension
("file.jpg", True, False), # Bad extension
("file.zip", True, False), # Archive - Don't keep archives
("file.zip", True, False), # Archive - Don't keep archives
("file.zip", False, True), # Archive - Keep archives
],
)
@@ -44,16 +46,56 @@ def test_is_bad_extension(removal_job, file_name, expected_result, keep_archives
@pytest.mark.parametrize(
("name", "size_bytes", "expected_result"),
[
("My.Movie.2024.2160/Subfolder/sample.mkv", 100 * 1024, True), # 100 KB, 'sample' keyword in filename
("My.Movie.2024.2160/Subfolder/Sample.mkv", 100 * 1024, True), # 100 KB, case-insensitive match
("My.Movie.2024.2160/Subfolder/sample movie.mkv", 100 * 1024, True), # 100 KB, 'sample' keyword with space
("My.Movie.2024.2160/Subfolder/samplemovie.mkv", 100 * 1024, True), # 100 KB, 'sample' keyword concatenated
("My.Movie.2024.2160/Subfolder/Movie sample.mkv", 100 * 1024, True), # 100 KB, 'sample' keyword at end
("My.Movie.2024.2160/Sample/Movie.mkv", 100 * 1024, True), # 100 KB, 'sample' keyword in folder name
("My.Movie.2024.2160/sample/Movie.mkv", 100 * 1024, True), # 100 KB, lowercase folder name
("My.Movie.2024.2160/Samples/Movie.mkv", 100 * 1024, True), # 100 KB, plural form in folder name
("My.Movie.2024.2160/Big Samples/Movie.mkv", 700 * 1024 * 1024, False), # 700 MB, large file, should NOT be flagged
("My.Movie.2024.2160/Some Folder/Movie.mkv", 100 * 1024, False), # 100 KB, no 'sample' keyword, should not flag
(
"My.Movie.2024.2160/Subfolder/sample.mkv",
100 * 1024,
True,
), # 100 KB, 'sample' keyword in filename
(
"My.Movie.2024.2160/Subfolder/Sample.mkv",
100 * 1024,
True,
), # 100 KB, case-insensitive match
(
"My.Movie.2024.2160/Subfolder/sample movie.mkv",
100 * 1024,
True,
), # 100 KB, 'sample' keyword with space
(
"My.Movie.2024.2160/Subfolder/samplemovie.mkv",
100 * 1024,
True,
), # 100 KB, 'sample' keyword concatenated
(
"My.Movie.2024.2160/Subfolder/Movie sample.mkv",
100 * 1024,
True,
), # 100 KB, 'sample' keyword at end
(
"My.Movie.2024.2160/Sample/Movie.mkv",
100 * 1024,
True,
), # 100 KB, 'sample' keyword in folder name
(
"My.Movie.2024.2160/sample/Movie.mkv",
100 * 1024,
True,
), # 100 KB, lowercase folder name
(
"My.Movie.2024.2160/Samples/Movie.mkv",
100 * 1024,
True,
), # 100 KB, plural form in folder name
(
"My.Movie.2024.2160/Big Samples/Movie.mkv",
700 * 1024 * 1024,
False,
), # 700 MB, large file, should NOT be flagged
(
"My.Movie.2024.2160/Some Folder/Movie.mkv",
100 * 1024,
False,
), # 100 KB, no 'sample' keyword, should not flag
],
)
def test_contains_bad_keyword(removal_job, name, size_bytes, expected_result):
@@ -66,8 +108,6 @@ def test_contains_bad_keyword(removal_job, name, size_bytes, expected_result):
assert result == expected_result
@pytest.mark.parametrize(
("file", "is_incomplete_partial"),
[
@@ -179,7 +219,7 @@ async def test_get_items_to_process(qbit_item, expected_processed, removal_job):
removal_job.arr.tracker.extension_checked = {"checked-hash"}
# Act
processed_items = removal_job._get_items_to_process( # pylint: disable=W0212
processed_items = removal_job._get_items_to_process( # pylint: disable=W0212
[qbit_item]
)
@@ -320,9 +360,14 @@ def fixture_torrent_files():
],
)
def test_all_files_stopped(
removal_job, torrent_files, stoppable_indexes, all_files_stopped,
removal_job,
torrent_files,
stoppable_indexes,
all_files_stopped,
):
# Create stoppable_files using only the index for each file and a dummy reason
stoppable_files = [({"index": idx}, "some reason") for idx in stoppable_indexes]
result = removal_job._all_files_stopped(torrent_files, stoppable_files) # pylint: disable=W0212
result = removal_job._all_files_stopped( # pylint: disable=W0212
torrent_files, stoppable_files
)
assert result == all_files_stopped

View File

@@ -1,7 +1,7 @@
import pytest
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
from src.jobs.remove_failed_downloads import RemoveFailedDownloads
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
# Test to check if items with "failed" status are included in affected items with parameterized data
@@ -12,14 +12,20 @@ from src.jobs.remove_failed_downloads import RemoveFailedDownloads
(
[
{"downloadId": "1", "status": "failed"}, # Item with failed status
{"downloadId": "2", "status": "completed"}, # Item with completed status
{
"downloadId": "2",
"status": "completed",
}, # Item with completed status
{"downloadId": "3"}, # No status field
],
["1"], # Only the failed item should be affected
),
(
[
{"downloadId": "1", "status": "completed"}, # Item with completed status
{
"downloadId": "1",
"status": "completed",
}, # Item with completed status
{"downloadId": "2", "status": "completed"},
{"downloadId": "3", "status": "completed"},
],

View File

@@ -1,8 +1,9 @@
from unittest.mock import MagicMock
import pytest
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
from src.jobs.remove_failed_imports import RemoveFailedImports
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
@pytest.mark.asyncio
@@ -141,7 +142,9 @@ async def test_find_affected_items_with_patterns(
removal_job.job.message_patterns = patterns
# Act and Assert (Shared)
affected_items = await shared_test_affected_items(removal_job, expected_download_ids)
affected_items = await shared_test_affected_items(
removal_job, expected_download_ids
)
# Check if the correct downloadIds are in the affected items
affected_download_ids = [item["downloadId"] for item in affected_items]

View File

@@ -1,8 +1,7 @@
import pytest
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
from src.jobs.remove_metadata_missing import RemoveMetadataMissing
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
# Test to check if items with the specific error message are included in affected items with parameterized data

View File

@@ -1,7 +1,7 @@
import pytest
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
from src.jobs.remove_missing_files import RemoveMissingFiles
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
@pytest.mark.asyncio
@@ -10,15 +10,31 @@ from src.jobs.remove_missing_files import RemoveMissingFiles
[
(
[ # valid failed torrent (warning + matching errorMessage)
{"downloadId": "1", "status": "warning", "errorMessage": "DownloadClientQbittorrentTorrentStateMissingFiles"},
{"downloadId": "2", "status": "warning", "errorMessage": "The download is missing files"},
{"downloadId": "3", "status": "warning", "errorMessage": "qBittorrent is reporting missing files"},
{
"downloadId": "1",
"status": "warning",
"errorMessage": "DownloadClientQbittorrentTorrentStateMissingFiles",
},
{
"downloadId": "2",
"status": "warning",
"errorMessage": "The download is missing files",
},
{
"downloadId": "3",
"status": "warning",
"errorMessage": "qBittorrent is reporting missing files",
},
],
["1", "2", "3"],
),
(
[ # wrong status for errorMessage, should be ignored
{"downloadId": "1", "status": "failed", "errorMessage": "The download is missing files"},
{
"downloadId": "1",
"status": "failed",
"errorMessage": "The download is missing files",
},
],
[],
),
@@ -28,7 +44,11 @@ from src.jobs.remove_missing_files import RemoveMissingFiles
"downloadId": "1",
"status": "completed",
"statusMessages": [
{"messages": ["No files found are eligible for import in /some/path"]},
{
"messages": [
"No files found are eligible for import in /some/path"
]
},
],
},
{
@@ -54,11 +74,17 @@ from src.jobs.remove_missing_files import RemoveMissingFiles
),
(
[ # Mixed: one matching warning + one matching statusMessage
{"downloadId": "1", "status": "warning", "errorMessage": "The download is missing files"},
{
"downloadId": "1",
"status": "warning",
"errorMessage": "The download is missing files",
},
{
"downloadId": "2",
"status": "completed",
"statusMessages": [{"messages": ["No files found are eligible for import in foo"]}],
"statusMessages": [
{"messages": ["No files found are eligible for import in foo"]}
],
},
{"downloadId": "3", "status": "completed"},
],

View File

@@ -1,7 +1,7 @@
import pytest
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
from src.jobs.remove_orphans import RemoveOrphans
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
@pytest.mark.asyncio

View File

@@ -1,8 +1,9 @@
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
import pytest
from tests.jobs.utils import shared_fix_affected_items
from src.jobs.remove_slow import RemoveSlow
from tests.jobs.utils import shared_fix_affected_items
# pylint: disable=W0212
@@ -83,8 +84,6 @@ def test_not_downloading(item, expected_result):
assert result == expected_result
@pytest.mark.parametrize(
("item", "expected_result"),
[
@@ -238,7 +237,9 @@ async def test_get_progress_stats(
(4, "other_client", 0.9, False), # different client type
],
)
def test_high_bandwidth_usage(download_id, download_client_type, bandwidth_usage, expected):
def test_high_bandwidth_usage(
download_id, download_client_type, bandwidth_usage, expected
):
"""
Test RemoveSlow._high_bandwidth_usage method.
@@ -288,7 +289,6 @@ async def test_add_download_client_to_queue_items_simple():
assert item["download_client_type"] == download_client_type
@pytest.mark.asyncio
async def test_update_bandwidth_usage_calls_once_per_client():
"""
@@ -314,7 +314,10 @@ async def test_update_bandwidth_usage_calls_once_per_client():
"download_client_type": "qbittorrent",
}, # duplicate client
{"download_client": qb_client2, "download_client_type": "qbittorrent"},
{"download_client": sabnzbd_client, "download_client_type": "sabnzbd"}, # SABnzbd client
{
"download_client": sabnzbd_client,
"download_client_type": "sabnzbd",
}, # SABnzbd client
{"download_client": other_client, "download_client_type": "other"},
]
@@ -324,7 +327,10 @@ async def test_update_bandwidth_usage_calls_once_per_client():
qb_client1.set_bandwidth_usage.assert_awaited_once()
qb_client2.set_bandwidth_usage.assert_awaited_once()
# Verify SABnzbd and other client methods were not called (no bandwidth tracking for them)
assert not hasattr(sabnzbd_client, 'set_bandwidth_usage') or not sabnzbd_client.set_bandwidth_usage.called
assert (
not hasattr(sabnzbd_client, "set_bandwidth_usage")
or not sabnzbd_client.set_bandwidth_usage.called
)
other_client.set_bandwidth_usage.assert_not_awaited()
@@ -334,23 +340,16 @@ async def test_update_bandwidth_usage_calls_once_per_client():
[
# Already checked downloadId -> skip (simulate by repeating downloadId)
({"downloadId": "checked_before"}, False),
# Keys not present -> skip
({"downloadId": "keys_missing"}, False),
# Not Downloading -> skip
({"downloadId": "not_downloading"}, False),
# Completed but stuck -> skip
({"downloadId": "completed_but_stuck"}, False),
# High bandwidth usage -> skip
({"downloadId": "high_bandwidth"}, False),
# Not slow -> skip
({"downloadId": "not_slow"}, False),
# None of above, hence truly slow
({"downloadId": "good"}, True),
],
@@ -366,11 +365,20 @@ async def test_find_affected_items_simple(queue_item, should_be_affected):
removal_job._get_progress_stats = AsyncMock(return_value=(1000, 900, 100, 10))
# Setup checks to pass except in for the designated tests
removal_job._checked_before = lambda item, checked_ids: item.get("downloadId") == "checked_before"
removal_job._checked_before = (
lambda item, checked_ids: item.get("downloadId") == "checked_before"
)
removal_job._missing_keys = lambda item: item.get("downloadId") == "keys_missing"
removal_job._not_downloading = lambda item: item.get("downloadId") == "not_downloading"
removal_job._is_completed_but_stuck = lambda item: item.get("downloadId") == "completed_but_stuck"
removal_job._high_bandwidth_usage = lambda download_client, download_client_type=None: queue_item.get("downloadId") == "high_bandwidth"
removal_job._not_downloading = (
lambda item: item.get("downloadId") == "not_downloading"
)
removal_job._is_completed_but_stuck = (
lambda item: item.get("downloadId") == "completed_but_stuck"
)
removal_job._high_bandwidth_usage = (
lambda download_client, download_client_type=None: queue_item.get("downloadId")
== "high_bandwidth"
)
removal_job._not_slow = lambda speed: queue_item.get("downloadId") == "not_slow"
# Run the method under test
@@ -380,4 +388,6 @@ async def test_find_affected_items_simple(queue_item, should_be_affected):
assert affected_items, f"Item {queue_item.get('downloadId')} should be affected"
assert affected_items[0]["downloadId"] == queue_item["downloadId"]
else:
assert not affected_items, f"Item {queue_item.get('downloadId')} should NOT be affected"
assert (
not affected_items
), f"Item {queue_item.get('downloadId')} should NOT be affected"

View File

@@ -1,7 +1,7 @@
import pytest
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
from src.jobs.remove_stalled import RemoveStalled
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
# Test to check if items with the specific error message are included in affected items with parameterized data
@@ -11,31 +11,75 @@ from src.jobs.remove_stalled import RemoveStalled
[
(
[
{"downloadId": "1", "status": "warning", "errorMessage": "The download is stalled with no connections"}, # Valid item
{"downloadId": "2", "status": "completed", "errorMessage": "The download is stalled with no connections"}, # Wrong status
{"downloadId": "3", "status": "warning", "errorMessage": "Some other error"}, # Incorrect errorMessage
{
"downloadId": "1",
"status": "warning",
"errorMessage": "The download is stalled with no connections",
}, # Valid item
{
"downloadId": "2",
"status": "completed",
"errorMessage": "The download is stalled with no connections",
}, # Wrong status
{
"downloadId": "3",
"status": "warning",
"errorMessage": "Some other error",
}, # Incorrect errorMessage
],
["1"], # Only the item with "warning" status and the correct errorMessage should be affected
[
"1"
], # Only the item with "warning" status and the correct errorMessage should be affected
),
(
[
{"downloadId": "1", "status": "warning", "errorMessage": "Some other error"}, # Incorrect errorMessage
{"downloadId": "2", "status": "completed", "errorMessage": "The download is stalled with no connections"}, # Wrong status
{"downloadId": "3", "status": "warning", "errorMessage": "The download is stalled with no connections"}, # Correct item
{
"downloadId": "1",
"status": "warning",
"errorMessage": "Some other error",
}, # Incorrect errorMessage
{
"downloadId": "2",
"status": "completed",
"errorMessage": "The download is stalled with no connections",
}, # Wrong status
{
"downloadId": "3",
"status": "warning",
"errorMessage": "The download is stalled with no connections",
}, # Correct item
],
["3"], # Only the item with "warning" status and the correct errorMessage should be affected
[
"3"
], # Only the item with "warning" status and the correct errorMessage should be affected
),
(
[
{"downloadId": "1", "status": "warning", "errorMessage": "The download is stalled with no connections"}, # Valid item
{"downloadId": "2", "status": "warning", "errorMessage": "The download is stalled with no connections"}, # Another valid item
{
"downloadId": "1",
"status": "warning",
"errorMessage": "The download is stalled with no connections",
}, # Valid item
{
"downloadId": "2",
"status": "warning",
"errorMessage": "The download is stalled with no connections",
}, # Another valid item
],
["1", "2"], # Both items match the condition
),
(
[
{"downloadId": "1", "status": "completed", "errorMessage": "The download is stalled with no connections"}, # Wrong status
{"downloadId": "2", "status": "warning", "errorMessage": "Some other error"}, # Incorrect errorMessage
{
"downloadId": "1",
"status": "completed",
"errorMessage": "The download is stalled with no connections",
}, # Wrong status
{
"downloadId": "2",
"status": "warning",
"errorMessage": "Some other error",
}, # Incorrect errorMessage
],
[], # No items match the condition
),

View File

@@ -1,8 +1,9 @@
from unittest.mock import AsyncMock
import pytest
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
from src.jobs.remove_unmonitored import RemoveUnmonitored
from tests.jobs.utils import shared_fix_affected_items, shared_test_affected_items
@pytest.mark.asyncio
@@ -59,12 +60,12 @@ from src.jobs.remove_unmonitored import RemoveUnmonitored
[
{"downloadId": "1", "detail_item_id": 101},
{"downloadId": "2", "detail_item_id": 102},
{"downloadId": "3", "detail_item_id": None}
{"downloadId": "3", "detail_item_id": None},
],
{101: True, 102: False},
["2"]
["2"],
),
]
],
)
async def test_find_affected_items(queue_data, monitored_ids, expected_download_ids):
# Arrange

View File

@@ -2,8 +2,10 @@ import logging
from unittest.mock import MagicMock
import pytest
from src.jobs.strikes_handler import StrikesHandler
# pylint: disable=W0212
# pylint: disable=too-many-locals
@pytest.mark.parametrize(
@@ -15,21 +17,17 @@ from src.jobs.strikes_handler import StrikesHandler
"expected_in_tracker",
"expected_in_paused",
"expected_in_recovered",
"expected_in_removed_from_queue"
"expected_in_removed_from_queue",
),
[
# Not tracked previously, in queue, not affected → ignore
("HASH1", False, True, False, False, False, False, False),
# Previously tracked, no longer in queue and not affected → recover with reason "no longer in queue"
("HASH2", True, False, False, False, False, False, True),
# Previously tracked, still in queue but no longer affected → recover with reason "has recovered"
("HASH3", True, True, False, False, False, True, False),
# Previously tracked, still in queue and still affected → remain tracked, no pause, no recover
("HASH4", True, True, True, True, False, False, False),
# Previously tracked, still in queue, not affected but tracking paused → remain tracked in paused, no recover
("HASH5", True, True, False, True, True, False, False),
],
@@ -42,7 +40,7 @@ def test_recover_downloads(
expected_in_tracker,
expected_in_paused,
expected_in_recovered,
expected_in_removed_from_queue
expected_in_removed_from_queue,
):
# Setup mock tracker with or without the download
strikes = 1 if already_in_tracker else None
@@ -56,9 +54,13 @@ def test_recover_downloads(
tracker = MagicMock()
tracker.defective = {
"remove_stalled": {
download_id: defective_entry,
} if already_in_tracker else {}
"remove_stalled": (
{
download_id: defective_entry,
}
if already_in_tracker
else {}
)
}
arr = MagicMock()
@@ -75,23 +77,32 @@ def test_recover_downloads(
queue.append({"downloadId": download_id})
# Unpack all three returned values from _recover_downloads
recovered, removed_from_queue, paused = handler._recover_downloads(affected_downloads, queue=queue) # pylint: disable=W0212
recovered, removed_from_queue, paused = handler._recover_downloads(
affected_downloads, queue=queue
) # pylint: disable=W0212
is_in_tracker = download_id in tracker.defective["remove_stalled"]
assert is_in_tracker == expected_in_tracker, f"{download_id} tracker presence mismatch"
assert (
is_in_tracker == expected_in_tracker
), f"{download_id} tracker presence mismatch"
is_in_paused = download_id in paused
assert is_in_paused == expected_in_paused, f"{download_id} paused presence mismatch"
is_in_recovered = download_id in recovered
assert is_in_recovered == expected_in_recovered, f"{download_id} recovered presence mismatch"
assert (
is_in_recovered == expected_in_recovered
), f"{download_id} recovered presence mismatch"
is_in_recovered = download_id in recovered
assert is_in_recovered == expected_in_recovered, f"{download_id} recovered presence mismatch"
assert (
is_in_recovered == expected_in_recovered
), f"{download_id} recovered presence mismatch"
is_in_removed = download_id in removed_from_queue
assert is_in_removed == expected_in_removed_from_queue, f"{download_id} removed_from_queue presence mismatch"
assert (
is_in_removed == expected_in_removed_from_queue
), f"{download_id} removed_from_queue presence mismatch"
@pytest.mark.parametrize(
@@ -121,14 +132,13 @@ def test_apply_strikes_and_filter(
"HASH1": {"title": "dummy"},
}
result = handler._apply_strikes_and_filter(
affected_downloads
)
result = handler._apply_strikes_and_filter(affected_downloads)
if expected_in_affected_downloads:
assert "HASH1" in result
else:
assert "HASH1" not in result
def test_log_change_logs_expected_strike_changes(caplog):
handler = StrikesHandler(job_name="remove_stalled", arr=MagicMock(), max_strikes=3)
handler.tracker = MagicMock()
@@ -152,7 +162,14 @@ def test_log_change_logs_expected_strike_changes(caplog):
log_messages = "\n".join(record.message for record in caplog.records)
# Check category keywords exist
for keyword in ["Added", "Incremented", "Tracking Paused", "Removed from queue", "Recovered", "Strikes Exceeded"]:
for keyword in [
"Added",
"Incremented",
"Tracking Paused",
"Removed from queue",
"Recovered",
"Strikes Exceeded",
]:
assert keyword in log_messages
# Check actual IDs appear somewhere in the logged messages

View File

@@ -1,22 +1,25 @@
from unittest.mock import MagicMock
def shared_fix_affected_items(removal_class, queue_data=None):
# Arrange
removal_job = removal_class(arr=MagicMock(), settings=MagicMock(),job_name="test")
removal_job = removal_class(arr=MagicMock(), settings=MagicMock(), job_name="test")
if queue_data:
removal_job.queue = queue_data
return removal_job
async def shared_test_affected_items(removal_job, expected_download_ids):
# Act
affected_items = await removal_job._find_affected_items() # pylint: disable=W0212
affected_items = await removal_job._find_affected_items() # pylint: disable=W0212
# Assert
assert isinstance(affected_items, list)
# Assert that the affected items match the expected download IDs
affected_download_ids = [item["downloadId"] for item in affected_items]
assert sorted(affected_download_ids) == sorted(expected_download_ids), \
f"Expected affected items with downloadIds {expected_download_ids}, got {affected_download_ids}"
assert sorted(affected_download_ids) == sorted(
expected_download_ids
), f"Expected affected items with downloadIds {expected_download_ids}, got {affected_download_ids}"
return affected_items

View File

@@ -1,5 +1,7 @@
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.settings._instances import ArrInstance

View File

@@ -1,4 +1,5 @@
from unittest.mock import Mock, AsyncMock
from unittest.mock import AsyncMock, Mock
import pytest
from src.settings._download_clients_sabnzbd import SabnzbdClient, SabnzbdClients
@@ -11,9 +12,7 @@ class TestSabnzbdClient:
settings.min_versions = Mock()
settings.min_versions.sabnzbd = "4.0.0"
client = SabnzbdClient(
settings=settings,
base_url="http://sabnzbd:8080",
api_key="test_api_key"
settings=settings, base_url="http://sabnzbd:8080", api_key="test_api_key"
)
assert client.base_url == "http://sabnzbd:8080"
assert client.api_url == "http://sabnzbd:8080/api"
@@ -29,7 +28,7 @@ class TestSabnzbdClient:
settings=settings,
base_url="http://sabnzbd:8080/",
api_key="test_api_key",
name="Custom SABnzbd"
name="Custom SABnzbd",
)
assert client.base_url == "http://sabnzbd:8080"
assert client.api_url == "http://sabnzbd:8080/api"
@@ -55,23 +54,14 @@ class TestSabnzbdClient:
settings.min_versions = Mock()
settings.min_versions.sabnzbd = "4.0.0"
client = SabnzbdClient(
settings=settings,
base_url="http://sabnzbd:8080",
api_key="test_api_key"
settings=settings, base_url="http://sabnzbd:8080", api_key="test_api_key"
)
# Mock the get_queue_items method
client.get_queue_items = AsyncMock(return_value=[
{
"nzo_id": "test_id_1",
"mb": "1000",
"mbleft": "200"
},
{
"nzo_id": "test_id_2",
"mb": "2000",
"mbleft": "1000"
}
]
client.get_queue_items = AsyncMock(
return_value=[
{"nzo_id": "test_id_1", "mb": "1000", "mbleft": "200"},
{"nzo_id": "test_id_2", "mb": "2000", "mbleft": "1000"},
]
)
# Test getting progress for existing download
progress = await client.fetch_download_progress("test_id_1")
@@ -95,15 +85,12 @@ class TestSabnzbdClients:
config = {
"download_clients": {
"sabnzbd": [
{"base_url": "http://sabnzbd1:8080", "api_key": "api_key_1"},
{
"base_url": "http://sabnzbd1:8080",
"api_key": "api_key_1"
},
{
"base_url": "http://sabnzbd2:8080",
"base_url": "http://sabnzbd2:8080",
"api_key": "api_key_2",
"name": "SABnzbd 2"
}
"name": "SABnzbd 2",
},
]
}
}
@@ -121,11 +108,7 @@ class TestSabnzbdClients:
def test_init_invalid_config_format(self, caplog):
"""Test SabnzbdClients initialization with invalid config format."""
config = {
"download_clients": {
"sabnzbd": "not_a_list"
}
}
config = {"download_clients": {"sabnzbd": "not_a_list"}}
settings = Mock()
clients = SabnzbdClients(config, settings)
assert len(clients) == 0

View File

@@ -1,4 +1,5 @@
"""Test loading the user configuration from environment variables."""
import os
import textwrap
from unittest.mock import patch
@@ -16,47 +17,61 @@ TIMER_VALUE = "10"
SSL_VERIFICATION_VALUE = "true"
# List
ignored_download_clients_yaml = textwrap.dedent("""
ignored_download_clients_yaml = textwrap.dedent(
"""
- emulerr
- napster
""").strip()
"""
).strip()
# Job: No settings
remove_bad_files_yaml = "" # pylint: disable=C0103; empty string represents flag enabled with no config
remove_bad_files_yaml = ( # pylint: disable=C0103; empty string represents flag enabled with no config
""
)
# Job: One Setting
remove_slow_yaml = textwrap.dedent("""
remove_slow_yaml = textwrap.dedent(
"""
- max_strikes: 3
""").strip()
"""
).strip()
# Job: Multiple Setting
remove_stalled_yaml = textwrap.dedent("""
remove_stalled_yaml = textwrap.dedent(
"""
- min_speed: 100
- max_strikes: 3
- some_bool_upper: TRUE
- some_bool_lower: false
- some_bool_sentence: False
""").strip()
"""
).strip()
# Arr Instances
radarr_yaml = textwrap.dedent("""
radarr_yaml = textwrap.dedent(
"""
- base_url: "http://radarr:7878"
api_key: "radarr1_key"
""").strip()
"""
).strip()
sonarr_yaml = textwrap.dedent("""
sonarr_yaml = textwrap.dedent(
"""
- base_url: "sonarr_1_api_key"
api_key: "sonarr1_api_url"
- base_url: "sonarr_2_api_key"
api_key: "sonarr2_api_url"
""").strip()
"""
).strip()
# Qbit Instances
qbit_yaml = textwrap.dedent("""
qbit_yaml = textwrap.dedent(
"""
- base_url: "http://qbittorrent:8080"
username: "qbit_username1"
password: "qbit_password1"
""").strip()
"""
).strip()
@pytest.fixture(name="env_vars")
@@ -87,19 +102,28 @@ sonarr_expected = yaml.safe_load(sonarr_yaml)
qbit_expected = yaml.safe_load(qbit_yaml)
@pytest.mark.parametrize(("section", "key", "expected"), [
("general", "log_level", LOG_LEVEL_VALUE),
("general", "timer", int(TIMER_VALUE)),
("general", "ssl_verification", True),
("general", "ignored_download_clients", remove_ignored_download_clients_expected),
("jobs", "remove_bad_files", remove_bad_files_expected),
("jobs", "remove_slow", remove_slow_expected),
("jobs", "remove_stalled", remove_stalled_expected),
("instances", "radarr", radarr_expected),
("instances", "sonarr", sonarr_expected),
("download_clients", "qbittorrent", qbit_expected),
])
def test_env_loading_parametrized(env_vars, section, key, expected): # pylint: disable=unused-argument # noqa: ARG001
@pytest.mark.parametrize(
("section", "key", "expected"),
[
("general", "log_level", LOG_LEVEL_VALUE),
("general", "timer", int(TIMER_VALUE)),
("general", "ssl_verification", True),
(
"general",
"ignored_download_clients",
remove_ignored_download_clients_expected,
),
("jobs", "remove_bad_files", remove_bad_files_expected),
("jobs", "remove_slow", remove_slow_expected),
("jobs", "remove_stalled", remove_stalled_expected),
("instances", "radarr", radarr_expected),
("instances", "sonarr", sonarr_expected),
("download_clients", "qbittorrent", qbit_expected),
],
)
def test_env_loading_parametrized(
env_vars, section, key, expected
): # pylint: disable=unused-argument # noqa: ARG001
config = _load_from_env()
assert section in config
assert key in config[section]

View File

@@ -1,4 +1,5 @@
from unittest.mock import Mock
import pytest
from src.utils.queue_manager import QueueManager
@@ -11,11 +12,13 @@ def fixture_mock_queue_manager():
mock_settings = Mock()
return QueueManager(arr=mock_arr, settings=mock_settings)
# ---------- Tests ----------
def test_format_queue_empty(mock_queue_manager):
result = mock_queue_manager.format_queue([])
assert result == "empty"
def test_format_queue_single_item(mock_queue_manager):
queue_items = [
{
@@ -37,6 +40,7 @@ def test_format_queue_single_item(mock_queue_manager):
result = mock_queue_manager.format_queue(queue_items)
assert result == expected
def test_format_queue_multiple_same_download_id(mock_queue_manager):
queue_items = [
{
@@ -52,7 +56,7 @@ def test_format_queue_multiple_same_download_id(mock_queue_manager):
"protocol": "usenet",
"status": "downloading",
"id": 2,
}
},
]
expected = {
"xyz789": {
@@ -65,6 +69,7 @@ def test_format_queue_multiple_same_download_id(mock_queue_manager):
result = mock_queue_manager.format_queue(queue_items)
assert result == expected
def test_format_queue_multiple_different_download_ids(mock_queue_manager):
queue_items = [
{
@@ -80,21 +85,21 @@ def test_format_queue_multiple_different_download_ids(mock_queue_manager):
"protocol": "usenet",
"status": "completed",
"id": 20,
}
},
]
expected = {
'aaa111': {
'queue_ids': [10],
'title': 'Example Download Title A',
'protocol': 'torrent',
'status': 'queued'
"aaa111": {
"queue_ids": [10],
"title": "Example Download Title A",
"protocol": "torrent",
"status": "queued",
},
"bbb222": {
"queue_ids": [20],
"title": "Example Download Title B",
"protocol": "usenet",
"status": "completed",
},
'bbb222': {
'queue_ids': [20],
'title': 'Example Download Title B',
'protocol': 'usenet',
'status': 'completed'
}
}
result = mock_queue_manager.format_queue(queue_items)
assert result == expected