diff --git a/src/job_manager.py b/src/job_manager.py index e0dd29b..bd058a6 100644 --- a/src/job_manager.py +++ b/src/job_manager.py @@ -58,6 +58,7 @@ class JobManager: await SearchHandler(self.arr, self.settings).handle_search("cutoff") async def _queue_has_items(self): + logger.debug("job_manager.py/_queue_has_items (Before any removal jobs): Checking if any items in full queue") queue_manager = QueueManager(self.arr, self.settings) full_queue = await queue_manager.get_queue_items("full") if full_queue: @@ -72,6 +73,7 @@ class JobManager: async def _qbit_connected(self): for qbit in self.settings.download_clients.qbittorrent: + logger.debug("job_manager.py/_qbit_connected (Before any removal jobs): Checking if qbit is connected to the internet") # Check if any client is disconnected if not await qbit.check_qbit_connected(): logger.warning( diff --git a/src/jobs/removal_job.py b/src/jobs/removal_job.py index 3fd588c..531cad9 100644 --- a/src/jobs/removal_job.py +++ b/src/jobs/removal_job.py @@ -13,6 +13,7 @@ class RemovalJob(ABC): affected_downloads = None job = None max_strikes = None + queue = [] # Default class attributes (can be overridden in subclasses) def __init__(self, arr, settings, job_name): @@ -28,12 +29,15 @@ class RemovalJob(ABC): async def run(self): if not self.job.enabled: return 0 - if await self.is_queue_empty(self.job_name, self.queue_scope): + logger.debug(f"removal_job.py/run: Launching job '{self.job_name}', and checking if any items in {self.queue_scope} queue.") + self.queue = await self.queue_manager.get_queue_items(queue_scope=self.queue_scope) + + # Handle empty queue + if not self.queue: if self.max_strikes: self.strikes_handler.all_recovered() return 0 - - logger.debug(f"removal_job.py: Running job '{self.job_name}'") + self.affected_items = await self._find_affected_items() self.affected_downloads = self.queue_manager.group_by_download_id(self.affected_items) @@ -53,21 +57,6 @@ class RemovalJob(ABC): return 
len(self.affected_downloads) - - - async def is_queue_empty(self, job_name, queue_scope="normal"): - # Check if queue empty - queue_items = await self.queue_manager.get_queue_items(queue_scope) - logger.debug( - f"{job_name}/queue IN: %s", - self.queue_manager.format_queue(queue_items), - ) - # Early exit if no queue - if not queue_items: - return True - return False - - def _ignore_protected(self): """ Filters out downloads that are in the protected tracker. diff --git a/src/jobs/remove_bad_files.py b/src/jobs/remove_bad_files.py index e3facbd..2bf4e94 100644 --- a/src/jobs/remove_bad_files.py +++ b/src/jobs/remove_bad_files.py @@ -28,10 +28,8 @@ class RemoveBadFiles(RemovalJob): # fmt: on async def _find_affected_items(self): - queue = await self.queue_manager.get_queue_items(queue_scope="normal") - # Get in-scope download IDs - result = self._group_download_ids_by_client(queue) + result = self._group_download_ids_by_client() affected_items = [] for download_client, info in result.items(): @@ -39,17 +37,17 @@ class RemoveBadFiles(RemovalJob): download_ids = info["download_ids"] if download_client_type == "qbittorrent": - client_items = await self._handle_qbit(download_client, download_ids, queue) + client_items = await self._handle_qbit(download_client, download_ids) affected_items.extend(client_items) return affected_items - def _group_download_ids_by_client(self, queue): + def _group_download_ids_by_client(self): """Group all relevant download IDs by download client. 
Limited to qbittorrent currently, as no other download clients implemented""" result = {} - for item in queue: + for item in self.queue: download_client_name = item.get("downloadClient") if not download_client_name: continue @@ -70,7 +68,7 @@ class RemoveBadFiles(RemovalJob): return result - async def _handle_qbit(self, qbit_client, hashes, queue): + async def _handle_qbit(self, qbit_client, hashes): """Handle qBittorrent-specific logic for marking files as 'Do Not Download'.""" affected_items = [] qbit_items = await qbit_client.get_qbit_items(hashes=hashes) @@ -89,7 +87,7 @@ class RemoveBadFiles(RemovalJob): if self._all_files_stopped(torrent_files, stoppable_files): logger.verbose(">>> All files in this torrent have been marked as 'Do not Download'. Removing torrent.") - affected_items.extend(self._match_queue_items(queue, qbit_item["hash"])) + affected_items.extend(self._match_queue_items(qbit_item["hash"])) return affected_items @@ -134,10 +132,10 @@ class RemoveBadFiles(RemovalJob): """Check if no files remain with download priority.""" return all(f["priority"] == 0 for f in torrent_files) - def _match_queue_items(self, queue, download_hash): + def _match_queue_items(self, download_hash): """Find matching queue item(s) by downloadId (uppercase).""" return [ - item for item in queue + item for item in self.queue if item["downloadId"] == download_hash.upper() ] @@ -208,10 +206,10 @@ class RemoveBadFiles(RemovalJob): stoppable_file_indexes= {file[0]["index"] for file in stoppable_files} return all(f["priority"] == 0 or f["index"] in stoppable_file_indexes for f in torrent_files) - def _match_queue_items(self, queue, download_hash): + def _match_queue_items(self, download_hash): """Find matching queue item(s) by downloadId (uppercase).""" return [ - item for item in queue + item for item in self.queue if item["downloadId"].upper() == download_hash.upper() ] diff --git a/src/jobs/remove_failed_downloads.py b/src/jobs/remove_failed_downloads.py index 
0ebb2af..e757627 100644 --- a/src/jobs/remove_failed_downloads.py +++ b/src/jobs/remove_failed_downloads.py @@ -5,10 +5,9 @@ class RemoveFailedDownloads(RemovalJob): blocklist = False async def _find_affected_items(self): - queue = await self.queue_manager.get_queue_items(queue_scope="normal") affected_items = [] - for item in queue: + for item in self.queue: if "status" in item: if item["status"] == "failed": affected_items.append(item) diff --git a/src/jobs/remove_failed_imports.py b/src/jobs/remove_failed_imports.py index 2d3e300..ce97860 100644 --- a/src/jobs/remove_failed_imports.py +++ b/src/jobs/remove_failed_imports.py @@ -6,11 +6,10 @@ class RemoveFailedImports(RemovalJob): blocklist = True async def _find_affected_items(self): - queue = await self.queue_manager.get_queue_items(queue_scope="normal") affected_items = [] patterns = self.job.message_patterns - for item in queue: + for item in self.queue: if not self._is_valid_item(item): continue diff --git a/src/jobs/remove_metadata_missing.py b/src/jobs/remove_metadata_missing.py index 93f29d4..7be4427 100644 --- a/src/jobs/remove_metadata_missing.py +++ b/src/jobs/remove_metadata_missing.py @@ -6,10 +6,9 @@ class RemoveMetadataMissing(RemovalJob): blocklist = True async def _find_affected_items(self): - queue = await self.queue_manager.get_queue_items(queue_scope="normal") affected_items = [] - for item in queue: + for item in self.queue: if "errorMessage" in item and "status" in item: if ( item["status"] == "queued" diff --git a/src/jobs/remove_missing_files.py b/src/jobs/remove_missing_files.py index ef59c72..5d819f5 100644 --- a/src/jobs/remove_missing_files.py +++ b/src/jobs/remove_missing_files.py @@ -5,10 +5,9 @@ class RemoveMissingFiles(RemovalJob): blocklist = False async def _find_affected_items(self): - queue = await self.queue_manager.get_queue_items(queue_scope="normal") affected_items = [] - for item in queue: + for item in self.queue: if self._is_failed_torrent(item) or 
self._is_bad_nzb(item): affected_items.append(item) diff --git a/src/jobs/remove_orphans.py b/src/jobs/remove_orphans.py index ca21203..c7e3c6a 100644 --- a/src/jobs/remove_orphans.py +++ b/src/jobs/remove_orphans.py @@ -1,11 +1,10 @@ from src.jobs.removal_job import RemovalJob class RemoveOrphans(RemovalJob): - queue_scope = "full" + queue_scope = "orphans" blocklist = False async def _find_affected_items(self): - affected_items = await self.queue_manager.get_queue_items(queue_scope="orphans") - return affected_items + return self.queue diff --git a/src/jobs/remove_slow.py b/src/jobs/remove_slow.py index 2b60765..f7fc0e8 100644 --- a/src/jobs/remove_slow.py +++ b/src/jobs/remove_slow.py @@ -7,11 +7,10 @@ class RemoveSlow(RemovalJob): blocklist = True async def _find_affected_items(self): - queue = await self.queue_manager.get_queue_items(queue_scope=self.queue_scope) affected_items = [] checked_ids = set() - for item in queue: + for item in self.queue: if not self._is_valid_item(item): continue diff --git a/src/jobs/remove_stalled.py b/src/jobs/remove_stalled.py index e9bbdd1..13e68ae 100644 --- a/src/jobs/remove_stalled.py +++ b/src/jobs/remove_stalled.py @@ -6,9 +6,8 @@ class RemoveStalled(RemovalJob): blocklist = True async def _find_affected_items(self): - queue = await self.queue_manager.get_queue_items(queue_scope="normal") affected_items = [] - for item in queue: + for item in self.queue: if "errorMessage" in item and "status" in item: if ( item["status"] == "warning" diff --git a/src/jobs/remove_unmonitored.py b/src/jobs/remove_unmonitored.py index 4f6d0d2..5c8f09b 100644 --- a/src/jobs/remove_unmonitored.py +++ b/src/jobs/remove_unmonitored.py @@ -5,18 +5,16 @@ class RemoveUnmonitored(RemovalJob): blocklist = False async def _find_affected_items(self): - queue = await self.queue_manager.get_queue_items(queue_scope="normal") - # First pass: Check if items are monitored monitored_download_ids = [] - for item in queue: + for item in self.queue: 
detail_item_id = item["detail_item_id"] if await self.arr.is_monitored(detail_item_id): monitored_download_ids.append(item["downloadId"]) # Second pass: Append queue items none that depends on download id is monitored affected_items = [] - for queue_item in queue: + for queue_item in self.queue: if queue_item["downloadId"] not in monitored_download_ids: affected_items.append( queue_item diff --git a/src/jobs/search_handler.py b/src/jobs/search_handler.py index 5f0dd87..c56e840 100644 --- a/src/jobs/search_handler.py +++ b/src/jobs/search_handler.py @@ -17,10 +17,12 @@ class SearchHandler: logger.debug(f"search_handler.py: Running '{search_type}' search") self._initialize_job(search_type) + logger.debug(f"search_handler.py/handle_search: Getting the list of wanted items ({search_type})") wanted_items = await self._get_initial_wanted_items(search_type) if not wanted_items: return - + + logger.debug(f"search_handler.py/handle_search: Getting list of queue items to only search for items that are not already downloading.") queue = await QueueManager(self.arr, self.settings).get_queue_items( queue_scope="normal" ) @@ -29,6 +31,7 @@ class SearchHandler: return await self._log_items(wanted_items, search_type) + logger.debug(f"search_handler.py/handle_search: Triggering search for wanted items ({search_type})") await self._trigger_search(wanted_items) def _initialize_job(self, search_type): @@ -101,6 +104,7 @@ class SearchHandler: logger.verbose(f">>> - {title}") elif self.arr.arr_type == "sonarr": + logger.debug("search_handler.py/_log_items: Getting series information for better display in output") series = await self.arr.get_series() series_title = next( (s["title"] for s in series if s["id"] == item.get("seriesId")), diff --git a/src/settings/_download_clients_qBit.py b/src/settings/_download_clients_qBit.py index c337ccc..9be6565 100644 --- a/src/settings/_download_clients_qBit.py +++ b/src/settings/_download_clients_qBit.py @@ -72,6 +72,7 @@ class QbitClient: async 
def refresh_cookie(self): """Refresh the qBittorrent session cookie.""" try: + logger.debug("_download_clients_qBit.py/refresh_cookie: Refreshing qBit cookie") endpoint = f"{self.api_url}/auth/login" data = {"username": getattr(self, 'username', ''), "password": getattr(self, 'password', '')} headers = {"content-type": "application/x-www-form-urlencoded"} @@ -83,7 +84,6 @@ class QbitClient: raise ConnectionError("Login failed.") self.cookie = {"SID": response.cookies["SID"]} - logger.debug("qBit cookie refreshed!") except Exception as e: logger.error(f"Error refreshing qBit cookie: {e}") self.cookie = {} @@ -93,10 +93,11 @@ class QbitClient: async def fetch_version(self): """Fetch the current qBittorrent version.""" + logger.debug("_download_clients_qBit.py/fetch_version: Getting qBit Version") endpoint = f"{self.api_url}/app/version" response = await make_request("get", endpoint, self.settings, cookies=self.cookie) self.version = response.text[1:] # Remove the '_v' prefix - logger.debug(f"qBit version for client qBittorrent: {self.version}") + logger.debug(f"_download_clients_qBit.py/fetch_version: qBit version={self.version}") async def validate_version(self): @@ -115,16 +116,16 @@ class QbitClient: f"[Tip!] Consider upgrading to qBittorrent v5.0.0 or newer to reduce network overhead." 
) - - async def create_tag(self): - """Create the protection tag in qBittorrent if it doesn't exist.""" + async def create_tag(self, tag: str): + """Ensure a tag exists in qBittorrent; create it if it doesn't.""" + logger.debug(f"_download_clients_qBit.py/create_tag: Checking if tag '{tag}' exists (and creating it if not)") url = f"{self.api_url}/torrents/tags" response = await make_request("get", url, self.settings, cookies=self.cookie) - current_tags = response.json() - if self.settings.general.protected_tag not in current_tags: - logger.verbose(f"Creating protection tag: {self.settings.general.protected_tag}") - data = {"tags": self.settings.general.protected_tag} + current_tags = response.json() + if tag not in current_tags: + logger.verbose(f"Creating tag: {tag}") + data = {"tags": tag} await make_request( "post", self.api_url + "/torrents/createTags", self.settings, data=data, cookies=self.cookie, ) + async def create_required_tags(self): + """Ensure protection and obsolete tags exist in qBittorrent if needed.""" + await self.create_tag(self.settings.general.protected_tag) + if ( self.settings.general.public_tracker_handling == "tag_as_obsolete" or self.settings.general.private_tracker_handling == "tag_as_obsolete" ): - if self.settings.general.obsolete_tag not in current_tags: - logger.verbose(f"Creating obsolete tag: {self.settings.general.obsolete_tag}") - data = {"tags": self.settings.general.obsolete_tag} - await make_request( - "post", - self.api_url + "/torrents/createTags", - self.settings, - data=data, - cookies=self.cookie, - ) + await self.create_tag(self.settings.general.obsolete_tag) async def set_unwanted_folder(self): """Set the 'unwanted folder' setting in qBittorrent if needed.""" if self.settings.jobs.remove_bad_files: + logger.debug("_download_clients_qBit.py/set_unwanted_folder: Checking preferences and setting use_unwanted_folder if not already set") endpoint = f"{self.api_url}/app/preferences" response = await make_request( "get", endpoint, self.settings, 
cookies=self.cookie @@ -174,6 +171,7 @@ class QbitClient: async def check_qbit_reachability(self): """Check if the qBittorrent URL is reachable.""" try: + logger.debug("_download_clients_qBit.py/check_qbit_reachability: Checking if qbit is reachable") endpoint = f"{self.api_url}/auth/login" data = {"username": getattr(self, 'username', ''), "password": getattr(self, 'password', '')} headers = {"content-type": "application/x-www-form-urlencoded"} @@ -189,6 +187,7 @@ class QbitClient: async def check_qbit_connected(self): """Check if the qBittorrent is connected to internet.""" + logger.debug("_download_clients_qBit.py/check_qbit_connected: Checking if qbit is connected to the internet") qbit_connection_status = (( await make_request( "get", @@ -222,7 +221,7 @@ class QbitClient: wait_and_exit() # Exit if version check fails # Continue with other setup tasks regardless of version check result - await self.create_tag() + await self.create_required_tags() await self.set_unwanted_folder() @@ -232,6 +231,7 @@ class QbitClient: private_downloads = [] # Fetch all torrents + logger.debug("_download_clients_qBit/get_protected_and_private: Checking if torrents have protected tag") qbit_items = await self.get_qbit_items() for qbit_item in qbit_items: @@ -245,6 +245,7 @@ class QbitClient: if qbit_item.get("private"): private_downloads.append(qbit_item["hash"].upper()) else: + logger.debug("_download_clients_qBit/get_protected_and_private: Checking if torrents are private (only done for old qbit versions)") qbit_item_props = await make_request( "get", self.api_url + "/torrents/properties", @@ -279,6 +280,8 @@ class QbitClient: # Ensure tags are provided as a string separated by ',' (comma) tags_str = ",".join(tags) + logger.debug(f"_download_clients_qBit/set_tag: Setting tag(s) {tags_str} to {hashes_str}") + # Prepare the data for the request data = { "hashes": hashes_str, @@ -319,6 +322,7 @@ class QbitClient: async def get_torrent_files(self, download_id): # this may not work 
if the wrong qbit + logger.debug("_download_clients_qBit/get_torrent_files: Getting torrent files") response = await make_request( method="get", endpoint=self.api_url + "/torrents/files", @@ -329,6 +333,7 @@ class QbitClient: return response.json() async def set_torrent_file_priority(self, download_id, file_id, priority = 0): + logger.debug("_download_clients_qBit/set_torrent_file_priority: Setting download priority for torrent file") data={ "hash": download_id.lower(), "id": file_id, diff --git a/src/settings/_instances.py b/src/settings/_instances.py index c526b0c..d5a3066 100644 --- a/src/settings/_instances.py +++ b/src/settings/_instances.py @@ -195,6 +195,7 @@ class ArrInstance: async def _check_reachability(self): """Check if ARR instance is reachable.""" try: + logger.debug("_instances.py/_check_reachability: Checking if arr instance is reachable") endpoint = self.api_url + "/system/status" headers = {"X-Api-Key": self.api_key} response = await make_request( @@ -237,6 +238,7 @@ class ArrInstance: async def get_download_client_implementation(self, download_client_name): """Fetch download client information and return the implementation value.""" + logger.debug("_instances.py/get_download_client_implementation: Checking type of download client type by download client name") endpoint = self.api_url + "/downloadclient" headers = {"X-Api-Key": self.api_key} @@ -265,6 +267,7 @@ class ArrInstance: Returns: bool: Returns True if the removal was successful, False otherwise. 
""" + logger.debug(f"_instances.py/remove_queue_item: Removing queue item, blocklist: {blocklist}") endpoint = f"{self.api_url}/queue/{queue_id}" headers = {"X-Api-Key": self.api_key} json_payload = {"removeFromClient": True, "blocklist": blocklist} @@ -282,6 +285,7 @@ class ArrInstance: async def is_monitored(self, detail_id): """Check if detail item (like a book, series, etc) is monitored.""" + logger.debug(f"_instances.py/is_monitored: Checking if item is monitored") endpoint = f"{self.api_url}/{self.detail_item_key}/{detail_id}" headers = {"X-Api-Key": self.api_key} diff --git a/tests/jobs/test_remove_bad_files.py b/tests/jobs/test_remove_bad_files.py index 54c8208..d30eb22 100644 --- a/tests/jobs/test_remove_bad_files.py +++ b/tests/jobs/test_remove_bad_files.py @@ -1,8 +1,7 @@ from unittest.mock import MagicMock, AsyncMock +import os import pytest from src.jobs.remove_bad_files import RemoveBadFiles -from tests.jobs.test_utils import removal_job_fix -import os # Fixture for arr mock @pytest.fixture(name="arr") @@ -24,7 +23,7 @@ def fixture_qbit_client(): @pytest.fixture(name="removal_job") def fixture_removal_job(arr): - removal_job = removal_job_fix(RemoveBadFiles) + removal_job = RemoveBadFiles(arr=arr, settings=MagicMock(), job_name="test") removal_job.arr = arr removal_job.job = MagicMock() removal_job.job.keep_archives = False @@ -193,9 +192,9 @@ async def test_get_items_to_process(qbit_item, expected_processed, removal_job, arr.tracker.extension_checked = {"checked-hash"} # Act - processed_items = removal_job._get_items_to_process( + processed_items = removal_job._get_items_to_process( # pylint: disable=W0212 [qbit_item] - ) # pylint: disable=W0212 + ) # Extract the hash from the processed items processed_hashes = [item["hash"] for item in processed_items] diff --git a/tests/jobs/test_remove_failed_downloads.py b/tests/jobs/test_remove_failed_downloads.py index 9bc2b42..38e10a1 100644 --- a/tests/jobs/test_remove_failed_downloads.py +++ 
b/tests/jobs/test_remove_failed_downloads.py @@ -1,6 +1,6 @@ +from unittest.mock import MagicMock import pytest from src.jobs.remove_failed_downloads import RemoveFailedDownloads -from tests.jobs.test_utils import removal_job_fix # Test to check if items with "failed" status are included in affected items with parameterized data @pytest.mark.asyncio @@ -34,7 +34,8 @@ from tests.jobs.test_utils import removal_job_fix ) async def test_find_affected_items(queue_data, expected_download_ids): # Arrange - removal_job = removal_job_fix(RemoveFailedDownloads, queue_data=queue_data) + removal_job = RemoveFailedDownloads(arr=MagicMock(), settings=MagicMock(), job_name="test") + removal_job.queue = queue_data # Act affected_items = await removal_job._find_affected_items() # pylint: disable=W0212 diff --git a/tests/jobs/test_remove_failed_imports.py b/tests/jobs/test_remove_failed_imports.py index c4efe76..b834584 100644 --- a/tests/jobs/test_remove_failed_imports.py +++ b/tests/jobs/test_remove_failed_imports.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock import pytest from src.jobs.remove_failed_imports import RemoveFailedImports -from tests.jobs.test_utils import removal_job_fix @pytest.mark.asyncio @pytest.mark.parametrize( @@ -60,7 +59,7 @@ from tests.jobs.test_utils import removal_job_fix ) async def test_is_valid_item(item, expected_result): #Fix - removal_job = removal_job_fix(RemoveFailedImports) + removal_job = RemoveFailedImports(arr=MagicMock(), settings=MagicMock(), job_name="test") # Act result = removal_job._is_valid_item(item) # pylint: disable=W0212 @@ -113,7 +112,8 @@ def fixture_queue_data(): ) async def test_find_affected_items_with_patterns(queue_data, patterns, expected_download_ids, removal_messages_expected): # Arrange - removal_job = removal_job_fix(RemoveFailedImports, queue_data=queue_data) + removal_job = RemoveFailedImports(arr=MagicMock(), settings=MagicMock(),job_name="test") + removal_job.queue = queue_data # Mock the job settings for 
message patterns removal_job.job = MagicMock() diff --git a/tests/jobs/test_remove_metadata_missing.py b/tests/jobs/test_remove_metadata_missing.py index 6d098e3..4c737ee 100644 --- a/tests/jobs/test_remove_metadata_missing.py +++ b/tests/jobs/test_remove_metadata_missing.py @@ -1,6 +1,6 @@ import pytest +from unittest.mock import MagicMock from src.jobs.remove_metadata_missing import RemoveMetadataMissing -from tests.jobs.test_utils import removal_job_fix # Test to check if items with the specific error message are included in affected items with parameterized data @pytest.mark.asyncio @@ -41,7 +41,8 @@ from tests.jobs.test_utils import removal_job_fix ) async def test_find_affected_items(queue_data, expected_download_ids): # Arrange - removal_job = removal_job_fix(RemoveMetadataMissing, queue_data=queue_data) + removal_job = RemoveMetadataMissing(arr=MagicMock(), settings=MagicMock(),job_name="test") + removal_job.queue = queue_data # Act affected_items = await removal_job._find_affected_items() # pylint: disable=W0212 diff --git a/tests/jobs/test_remove_missing_files.py b/tests/jobs/test_remove_missing_files.py index a08aae1..b053ed5 100644 --- a/tests/jobs/test_remove_missing_files.py +++ b/tests/jobs/test_remove_missing_files.py @@ -1,6 +1,6 @@ import pytest +from unittest.mock import MagicMock from src.jobs.remove_missing_files import RemoveMissingFiles -from tests.jobs.test_utils import removal_job_fix @pytest.mark.asyncio @pytest.mark.parametrize( @@ -66,7 +66,8 @@ from tests.jobs.test_utils import removal_job_fix ) async def test_find_affected_items(queue_data, expected_download_ids): # Arrange - removal_job = removal_job_fix(RemoveMissingFiles, queue_data=queue_data) + removal_job = RemoveMissingFiles(arr=MagicMock(), settings=MagicMock(),job_name="test") + removal_job.queue = queue_data # Act affected_items = await removal_job._find_affected_items() # pylint: disable=W0212 diff --git a/tests/jobs/test_remove_orphans.py b/tests/jobs/test_remove_orphans.py 
index 487c81e..4215db2 100644 --- a/tests/jobs/test_remove_orphans.py +++ b/tests/jobs/test_remove_orphans.py @@ -35,6 +35,7 @@ def fixture_queue_data(): async def test_find_affected_items_returns_queue(queue_data): # Fix removal_job = removal_job_fix(RemoveOrphans, queue_data=queue_data) + removal_job.queue = queue_data # Act affected_items = await removal_job._find_affected_items() # pylint: disable=W0212 diff --git a/tests/jobs/test_remove_slow.py b/tests/jobs/test_remove_slow.py index ba62e90..00c7fd7 100644 --- a/tests/jobs/test_remove_slow.py +++ b/tests/jobs/test_remove_slow.py @@ -1,7 +1,6 @@ from unittest.mock import AsyncMock, MagicMock import pytest from src.jobs.remove_slow import RemoveSlow -from tests.jobs.test_utils import removal_job_fix @pytest.mark.asyncio @@ -57,13 +56,18 @@ from tests.jobs.test_utils import removal_job_fix ], ) async def test_is_valid_item(item, expected_result): - removal_job = removal_job_fix(RemoveSlow) + # Arrange + removal_job = RemoveSlow(arr=MagicMock(), settings=MagicMock(),job_name="test") + + # Act result = removal_job._is_valid_item(item) # pylint: disable=W0212 + + # Assert assert result == expected_result -@pytest.fixture(name="slow_queue_data") -def fixture_slow_queue_data(): +@pytest.fixture(name="queue_data") +def fixture_queue_data(): return [ { "downloadId": "usenet", @@ -129,9 +133,11 @@ def fixture_arr(): ], ) async def test_find_affected_items_with_varied_speeds( - slow_queue_data, min_speed, expected_ids, arr + queue_data, min_speed, expected_ids, arr ): - removal_job = removal_job_fix(RemoveSlow, queue_data=slow_queue_data) + # Arrange + removal_job = RemoveSlow(arr=MagicMock(), settings=MagicMock(),job_name="test") + removal_job.queue = queue_data # Set up job and timer removal_job.job = MagicMock() @@ -140,9 +146,9 @@ async def test_find_affected_items_with_varied_speeds( removal_job.settings.general.timer = 1 # 1 minute for speed calculation removal_job.arr = arr # Inject the mocked arr object 
removal_job._is_valid_item = MagicMock( return_value=True ) # Mock the _is_valid_item method to always return True # pylint: disable=W0212 - + # Inject size and sizeleft into each item in the queue - for item in slow_queue_data: + for item in queue_data: item["size"] = item["total_size"] * 1000000 # Inject total size as 'size' item["sizeleft"] = ( item["size"] - item["progress_now"] * 1000000 ) # Calculate sizeleft item["status"] = "downloading" @@ -151,13 +157,12 @@ async def test_find_affected_items_with_varied_speeds( # Mock the download progress in `arr.tracker.download_progress` removal_job.arr.tracker.download_progress = { item["downloadId"]: item["progress_previous"] * 1000000 - for item in slow_queue_data + for item in queue_data } - # Call the method we're testing + + # Act affected_items = await removal_job._find_affected_items() # pylint: disable=W0212 - - # Extract case identifiers of affected items affected_ids = [item["downloadId"] for item in affected_items] # Assert that the affected cases match the expected ones diff --git a/tests/jobs/test_remove_stalled.py b/tests/jobs/test_remove_stalled.py index c2c7539..ede6f70 100644 --- a/tests/jobs/test_remove_stalled.py +++ b/tests/jobs/test_remove_stalled.py @@ -1,6 +1,7 @@ import pytest from src.jobs.remove_stalled import RemoveStalled from tests.jobs.test_utils import removal_job_fix +from unittest.mock import AsyncMock # Test to check if items with the specific error message are included in affected items with parameterized data @pytest.mark.asyncio @@ -41,7 +42,8 @@ from tests.jobs.test_utils import removal_job_fix ) async def test_find_affected_items(queue_data, expected_download_ids): # Arrange - removal_job = removal_job_fix(RemoveStalled, queue_data=queue_data) + removal_job = RemoveStalled(arr=AsyncMock(), settings=AsyncMock(),job_name="test") + removal_job.queue = queue_data # Act affected_items = await removal_job._find_affected_items() # pylint: disable=W0212 diff --git 
a/tests/jobs/test_remove_unmonitored.py b/tests/jobs/test_remove_unmonitored.py index 9246b9f..b633911 100644 --- a/tests/jobs/test_remove_unmonitored.py +++ b/tests/jobs/test_remove_unmonitored.py @@ -1,13 +1,6 @@ from unittest.mock import AsyncMock, MagicMock import pytest from src.jobs.remove_unmonitored import RemoveUnmonitored -from tests.jobs.test_utils import removal_job_fix - -@pytest.fixture(name="arr") -def fixture_arr(): - mock = MagicMock() - mock.is_monitored = AsyncMock() - return mock @pytest.mark.asyncio @pytest.mark.parametrize( @@ -60,15 +53,13 @@ def fixture_arr(): ), ] ) -async def test_find_affected_items(queue_data, monitored_ids, expected_download_ids, arr): - # Patch arr mock with side_effect - async def mock_is_monitored(detail_item_id): - return monitored_ids[detail_item_id] - - arr.is_monitored = AsyncMock(side_effect=mock_is_monitored) +async def test_find_affected_items(queue_data, monitored_ids, expected_download_ids): # Arrange - removal_job = removal_job_fix(RemoveUnmonitored, queue_data=queue_data) - removal_job.arr = arr # Inject the mocked arr object + arr = MagicMock() + arr.is_monitored = AsyncMock(side_effect=lambda id_: monitored_ids[id_]) + + removal_job = RemoveUnmonitored(arr=arr, settings=MagicMock(), job_name="test") + removal_job.queue = queue_data # Act affected_items = await removal_job._find_affected_items() # pylint: disable=W0212 diff --git a/tests/jobs/test_strikes_handler.py b/tests/jobs/test_strikes_handler.py index d1be5c9..d8a062a 100644 --- a/tests/jobs/test_strikes_handler.py +++ b/tests/jobs/test_strikes_handler.py @@ -1,5 +1,5 @@ -import pytest from unittest.mock import MagicMock +import pytest from src.jobs.strikes_handler import StrikesHandler @pytest.mark.parametrize(