Added sigterm handling to exit cleanly when running in Docker.

Fixed typos and various linting issues such as PEP violations.
Added ruff and fixed common issues and linting issues.
This commit is contained in:
NaruZosa
2025-05-25 16:54:51 +10:00
parent 22bdc9ab43
commit 2e6973bea4
56 changed files with 704 additions and 658 deletions

0
src/settings/__init__.py Normal file
View File

View File

@@ -1,5 +1,6 @@
import yaml
def mask_sensitive_value(value, key, sensitive_attributes):
    """Return a masked placeholder if *key* is sensitive, otherwise the value.

    Args:
        value: The raw value to potentially mask.
        key: The attribute name associated with the value.
        sensitive_attributes: Collection of attribute names to mask.

    Returns:
        The string "*****" when key is sensitive, else the original value.
    """
    if key in sensitive_attributes:
        return "*****"
    return value
@@ -40,19 +41,19 @@ def clean_object(obj, sensitive_attributes, internal_attributes, hide_internal_a
"""Clean an object (either a dict, class instance, or other types)."""
if isinstance(obj, dict):
return clean_dict(obj, sensitive_attributes, internal_attributes, hide_internal_attr)
elif hasattr(obj, "__dict__"):
if hasattr(obj, "__dict__"):
return clean_dict(vars(obj), sensitive_attributes, internal_attributes, hide_internal_attr)
else:
return mask_sensitive_value(obj, "", sensitive_attributes)
return mask_sensitive_value(obj, "", sensitive_attributes)
def get_config_as_yaml(
data,
sensitive_attributes=None,
internal_attributes=None,
*,
hide_internal_attr=True,
):
"""Main function to process the configuration into YAML format."""
"""Process the configuration into YAML format."""
if sensitive_attributes is None:
sensitive_attributes = set()
if internal_attributes is None:
@@ -67,7 +68,7 @@ def get_config_as_yaml(
# Process list-based config
if isinstance(obj, list):
cleaned_list = clean_list(
obj, sensitive_attributes, internal_attributes, hide_internal_attr
obj, sensitive_attributes, internal_attributes, hide_internal_attr,
)
if cleaned_list:
config_output[key] = cleaned_list
@@ -75,7 +76,7 @@ def get_config_as_yaml(
# Process dict or class-like object config
else:
cleaned_obj = clean_object(
obj, sensitive_attributes, internal_attributes, hide_internal_attr
obj, sensitive_attributes, internal_attributes, hide_internal_attr,
)
if cleaned_obj:
config_output[key] = cleaned_obj

View File

@@ -1,4 +1,5 @@
import os
from src.settings._config_as_yaml import get_config_as_yaml

View File

@@ -1,12 +1,14 @@
from src.settings._config_as_yaml import get_config_as_yaml
from src.settings._download_clients_qBit import QbitClients
from src.settings._download_clients_qbit import QbitClients
download_client_types = ["qbittorrent"]
class DownloadClients:
"""Represents all download clients."""
qbittorrent = None
download_client_types = [
"qbittorrent",
]
def __init__(self, config, settings):
self._set_qbit_clients(config, settings)
self.check_unique_download_client_types()
@@ -15,7 +17,7 @@ class DownloadClients:
download_clients = config.get("download_clients", {})
if isinstance(download_clients, dict):
self.qbittorrent = QbitClients(config, settings)
if not self.qbittorrent: # Unsets settings in general section needed for qbit (if no qbit is defined)
if not self.qbittorrent: # Unsets settings in general section needed for qbit (if no qbit is defined)
for key in [
"private_tracker_handling",
"public_tracker_handling",
@@ -25,40 +27,42 @@ class DownloadClients:
setattr(settings.general, key, None)
def config_as_yaml(self):
"""Logs all download clients."""
"""Log all download clients."""
return get_config_as_yaml(
{"qbittorrent": self.qbittorrent},
sensitive_attributes={"username", "password", "cookie"},
internal_attributes={ "api_url", "cookie", "settings", "min_version"},
hide_internal_attr=True
internal_attributes={"api_url", "cookie", "settings", "min_version"},
hide_internal_attr=True,
)
def check_unique_download_client_types(self):
"""Ensures that all download client names are unique.
This is important since downloadClient in arr goes by name, and
this is needed to link it to the right IP set up in the yaml config
(which may be different to the one donfigured in arr)"""
"""
Ensure that all download client names are unique.
This is important since downloadClient in arr goes by name, and
this is needed to link it to the right IP set up in the yaml config
(which may be different to the one configured in arr)
"""
seen = set()
for download_client_type in self.download_client_types:
for download_client_type in download_client_types:
download_clients = getattr(self, download_client_type, [])
# Check each client in the list
for client in download_clients:
name = getattr(client, "name", None)
if name is None:
raise ValueError(f'{download_client_type} client does not have a name ({client.base_url}).\nMake sure that the name corresponds with the name set in your *arr app for that download client.')
error = f"{download_client_type} client does not have a name ({client.base_url}).\nMake sure that the name corresponds with the name set in your *arr app for that download client."
raise ValueError(error)
if name.lower() in seen:
raise ValueError(f"Download client names must be unique. Duplicate name found: '{name}'\nMake sure that the name corresponds with the name set in your *arr app for that download client.")
else:
seen.add(name.lower())
error = f"Download client names must be unique. Duplicate name found: '{name}'\nMake sure that the name corresponds with the name set in your *arr app for that download client."
raise ValueError(error)
seen.add(name.lower())
def get_download_client_by_name(self, name: str):
"""Retrieve the download client and its type by its name."""
name_lower = name.lower()
for download_client_type in self.download_client_types:
for download_client_type in download_client_types:
download_clients = getattr(self, download_client_type, [])
# Check each client in the list

View File

@@ -1,14 +1,16 @@
from packaging import version
from src.utils.common import make_request, wait_and_exit
from src.settings._constants import ApiEndpoints, MinVersions
from src.utils.common import make_request, wait_and_exit
from src.utils.log_setup import logger
class QbitError(Exception):
    """Raised when a qBittorrent client operation fails."""
class QbitClients(list):
"""Represents all qBittorrent clients"""
"""Represents all qBittorrent clients."""
def __init__(self, config, settings):
super().__init__()
@@ -19,7 +21,7 @@ class QbitClients(list):
if not isinstance(qbit_config, list):
logger.error(
"Invalid config format for qbittorrent clients. Expected a list."
"Invalid config format for qbittorrent clients. Expected a list.",
)
return
@@ -30,29 +32,29 @@ class QbitClients(list):
logger.error(f"Error parsing qbittorrent client config: {e}")
class QbitClient:
"""Represents a single qBittorrent client."""
cookie: str = None
cookie: dict[str, str] = None
version: str = None
def __init__(
self,
settings,
base_url: str = None,
username: str = None,
password: str = None,
name: str = None
self,
settings,
base_url: str | None = None,
username: str | None = None,
password: str | None = None,
name: str | None = None,
):
self.settings = settings
if not base_url:
logger.error("Skipping qBittorrent client entry: 'base_url' is required.")
raise ValueError("qBittorrent client must have a 'base_url'.")
error = "qBittorrent client must have a 'base_url'."
raise ValueError(error)
self.base_url = base_url.rstrip("/")
self.api_url = self.base_url + getattr(ApiEndpoints, "qbittorrent")
self.min_version = getattr(MinVersions, "qbittorrent")
self.api_url = self.base_url + ApiEndpoints.qbittorrent
self.min_version = MinVersions.qbittorrent
self.username = username
self.password = password
self.name = name
@@ -63,24 +65,28 @@ class QbitClient:
self._remove_none_attributes()
def _remove_none_attributes(self):
"""Removes attributes that are None to keep the object clean."""
"""Remove attributes that are None to keep the object clean."""
for attr in list(vars(self)):
if getattr(self, attr) is None:
delattr(self, attr)
async def refresh_cookie(self):
"""Refresh the qBittorrent session cookie."""
def _connection_error():
error = "Login failed."
raise ConnectionError(error)
try:
endpoint = f"{self.api_url}/auth/login"
data = {"username": getattr(self, 'username', ''), "password": getattr(self, 'password', '')}
data = {"username": getattr(self, "username", ""), "password": getattr(self, "password", "")}
headers = {"content-type": "application/x-www-form-urlencoded"}
response = await make_request(
"post", endpoint, self.settings, data=data, headers=headers
"post", endpoint, self.settings, data=data, headers=headers,
)
if response.text == "Fails.":
raise ConnectionError("Login failed.")
_connection_error()
self.cookie = {"SID": response.cookies["SID"]}
logger.debug("qBit cookie refreshed!")
@@ -89,8 +95,6 @@ class QbitClient:
self.cookie = {}
raise QbitError(e) from e
async def fetch_version(self):
"""Fetch the current qBittorrent version."""
endpoint = f"{self.api_url}/app/version"
@@ -98,24 +102,21 @@ class QbitClient:
self.version = response.text[1:] # Remove the '_v' prefix
logger.debug(f"qBit version for client qBittorrent: {self.version}")
async def validate_version(self):
"""Check if the qBittorrent version meets minimum and recommended requirements."""
min_version = self.settings.min_versions.qbittorrent
if version.parse(self.version) < version.parse(min_version):
logger.error(
f"Please update qBittorrent to at least version {min_version}. Current version: {self.version}"
)
raise QbitError(
f"qBittorrent version {self.version} is too old. Please update."
f"Please update qBittorrent to at least version {min_version}. Current version: {self.version}",
)
error = f"qBittorrent version {self.version} is too old. Please update."
raise QbitError(error)
if version.parse(self.version) < version.parse("5.0.0"):
logger.info(
f"[Tip!] Consider upgrading to qBittorrent v5.0.0 or newer to reduce network overhead."
"[Tip!] Consider upgrading to qBittorrent v5.0.0 or newer to reduce network overhead.",
)
async def create_tag(self):
"""Create the protection tag in qBittorrent if it doesn't exist."""
url = f"{self.api_url}/torrents/tags"
@@ -134,34 +135,32 @@ class QbitClient:
cookies=self.cookie,
)
if (
self.settings.general.public_tracker_handling == "tag_as_obsolete"
or self.settings.general.private_tracker_handling == "tag_as_obsolete"
):
if self.settings.general.obsolete_tag not in current_tags:
logger.verbose(f"Creating obsolete tag: {self.settings.general.obsolete_tag}")
if not self.settings.general.test_run:
data = {"tags": self.settings.general.obsolete_tag}
await make_request(
"post",
self.api_url + "/torrents/createTags",
self.settings,
data=data,
cookies=self.cookie,
)
if ((self.settings.general.public_tracker_handling == "tag_as_obsolete"
or self.settings.general.private_tracker_handling == "tag_as_obsolete")
and self.settings.general.obsolete_tag not in current_tags):
logger.verbose(f"Creating obsolete tag: {self.settings.general.obsolete_tag}")
if not self.settings.general.test_run:
data = {"tags": self.settings.general.obsolete_tag}
await make_request(
"post",
self.api_url + "/torrents/createTags",
self.settings,
data=data,
cookies=self.cookie,
)
async def set_unwanted_folder(self):
"""Set the 'unwanted folder' setting in qBittorrent if needed."""
if self.settings.jobs.remove_bad_files:
endpoint = f"{self.api_url}/app/preferences"
response = await make_request(
"get", endpoint, self.settings, cookies=self.cookie
"get", endpoint, self.settings, cookies=self.cookie,
)
qbit_settings = response.json()
if not qbit_settings.get("use_unwanted_folder"):
logger.info(
"Enabling 'Keep unselected files in .unwanted folder' in qBittorrent."
"Enabling 'Keep unselected files in .unwanted folder' in qBittorrent.",
)
if not self.settings.general.test_run:
data = {"json": '{"use_unwanted_folder": true}'}
@@ -173,39 +172,32 @@ class QbitClient:
cookies=self.cookie,
)
async def check_qbit_reachability(self):
"""Check if the qBittorrent URL is reachable."""
try:
endpoint = f"{self.api_url}/auth/login"
data = {"username": getattr(self, 'username', ''), "password": getattr(self, 'password', '')}
data = {"username": getattr(self, "username", ""), "password": getattr(self, "password", "")}
headers = {"content-type": "application/x-www-form-urlencoded"}
await make_request(
"post", endpoint, self.settings, data=data, headers=headers, log_error=False
"post", endpoint, self.settings, data=data, headers=headers, log_error=False,
)
except Exception as e:
except Exception as e: # noqa: BLE001
tip = "💡 Tip: Did you specify the URL (and username/password if required) correctly?"
logger.error(f"-- | qBittorrent\n❗️ {e}\n{tip}\n")
wait_and_exit()
async def check_qbit_connected(self):
"""Check if the qBittorrent is connected to internet."""
qbit_connection_status = ((
await make_request(
"get",
self.api_url + "/sync/maindata",
self.settings,
cookies=self.cookie,
)
).json())["server_state"]["connection_status"]
if qbit_connection_status == "disconnected":
return False
else:
return True
await make_request(
"get",
self.api_url + "/sync/maindata",
self.settings,
cookies=self.cookie,
)
).json())["server_state"]["connection_status"]
return qbit_connection_status != "disconnected"
async def setup(self):
"""Perform the qBittorrent setup by calling relevant managers."""
@@ -228,9 +220,8 @@ class QbitClient:
await self.create_tag()
await self.set_unwanted_folder()
async def get_protected_and_private(self):
"""Fetches torrents from qBittorrent and checks for protected and private status."""
"""Fetch torrents from qBittorrent and checks for protected and private status."""
protected_downloads = []
private_downloads = []
@@ -270,11 +261,12 @@ class QbitClient:
async def set_tag(self, tags, hashes):
"""
Sets tags to one or more torrents in qBittorrent.
Set tags to one or more torrents in qBittorrent.
Args:
tags (list): A list of tag names to be added.
hashes (list): A list of torrent hashes to which the tags should be applied.
"""
# Ensure hashes are provided as a string separated by '|'
hashes_str = "|".join(hashes)
@@ -285,7 +277,7 @@ class QbitClient:
# Prepare the data for the request
data = {
"hashes": hashes_str,
"tags": tags_str
"tags": tags_str,
}
# Perform the request to add the tag(s) to the torrents
@@ -294,15 +286,13 @@ class QbitClient:
self.api_url + "/torrents/addTags",
self.settings,
data=data,
cookies=self.cookie,
cookies=self.cookie,
)
async def get_download_progress(self, download_id):
    """Return the downloaded ("completed") byte count for a torrent.

    Args:
        download_id: Torrent hash used to look up the item via get_qbit_items.

    Returns:
        The "completed" field of the first matching qBittorrent item.
        NOTE(review): assumes at least one item matches download_id —
        raises IndexError otherwise; confirm callers guarantee this.
    """
    items = await self.get_qbit_items(download_id)
    return items[0]["completed"]
async def get_qbit_items(self, hashes=None):
params = None
if hashes:
@@ -319,7 +309,6 @@ class QbitClient:
)
return response.json()
async def get_torrent_files(self, download_id):
# this may not work if the wrong qbit
response = await make_request(
@@ -331,8 +320,8 @@ class QbitClient:
)
return response.json()
async def set_torrent_file_priority(self, download_id, file_id, priority = 0):
data={
async def set_torrent_file_priority(self, download_id, file_id, priority=0):
data = {
"hash": download_id.lower(),
"id": file_id,
"priority": priority,
@@ -344,4 +333,3 @@ class QbitClient:
data=data,
cookies=self.cookie,
)

View File

@@ -1,11 +1,12 @@
import yaml
from src.utils.log_setup import logger
from src.settings._validate_data_types import validate_data_types
from src.settings._config_as_yaml import get_config_as_yaml
from src.settings._validate_data_types import validate_data_types
from src.utils.log_setup import logger
VALID_TRACKER_HANDLING = {"remove", "skip", "obsolete_tag"}
class General:
"""Represents general settings for the application."""
VALID_TRACKER_HANDLING = {"remove", "skip", "obsolete_tag"}
log_level: str = "INFO"
test_run: bool = False
@@ -17,7 +18,6 @@ class General:
obsolete_tag: str = None
protected_tag: str = "Keep"
def __init__(self, config):
general_config = config.get("general", {})
self.log_level = general_config.get("log_level", self.log_level.upper())
@@ -32,31 +32,31 @@ class General:
self.protected_tag = general_config.get("protected_tag", self.protected_tag)
# Validate tracker handling settings
self.private_tracker_handling = self._validate_tracker_handling( self.private_tracker_handling, "private_tracker_handling" )
self.public_tracker_handling = self._validate_tracker_handling( self.public_tracker_handling, "public_tracker_handling" )
self.private_tracker_handling = self._validate_tracker_handling(self.private_tracker_handling, "private_tracker_handling")
self.public_tracker_handling = self._validate_tracker_handling(self.public_tracker_handling, "public_tracker_handling")
self.obsolete_tag = self._determine_obsolete_tag(self.obsolete_tag)
validate_data_types(self)
self._remove_none_attributes()
def _remove_none_attributes(self):
"""Removes attributes that are None to keep the object clean."""
"""Remove attributes that are None to keep the object clean."""
for attr in list(vars(self)):
if getattr(self, attr) is None:
delattr(self, attr)
def _validate_tracker_handling(self, value, field_name):
"""Validates tracker handling options. Defaults to 'remove' if invalid."""
if value not in self.VALID_TRACKER_HANDLING:
@staticmethod
def _validate_tracker_handling(value, field_name) -> str:
"""Validate tracker handling options. Defaults to 'remove' if invalid."""
if value not in VALID_TRACKER_HANDLING:
logger.error(
f"Invalid value '{value}' for {field_name}. Defaulting to 'remove'."
f"Invalid value '{value}' for {field_name}. Defaulting to 'remove'.",
)
return "remove"
return value
def _determine_obsolete_tag(self, obsolete_tag):
"""Defaults obsolete tag to "obsolete", only if none is provided and the tag is needed for handling """
"""Set obsolete tag to "obsolete", only if none is provided and the tag is needed for handling."""
if obsolete_tag is None and (
self.private_tracker_handling == "obsolete_tag"
or self.public_tracker_handling == "obsolete_tag"
@@ -65,10 +65,7 @@ class General:
return obsolete_tag
def config_as_yaml(self):
"""Logs all general settings."""
# yaml_output = yaml.dump(vars(self), indent=2, default_flow_style=False, sort_keys=False)
# logger.info(f"General Settings:\n{yaml_output}")
"""Log all general settings."""
return get_config_as_yaml(
vars(self),
)
)

View File

@@ -1,16 +1,16 @@
import requests
from packaging import version
from src.utils.log_setup import logger
from src.settings._config_as_yaml import get_config_as_yaml
from src.settings._constants import (
ApiEndpoints,
MinVersions,
FullQueueParameter,
DetailItemKey,
DetailItemSearchCommand,
FullQueueParameter,
MinVersions,
)
from src.settings._config_as_yaml import get_config_as_yaml
from src.utils.common import make_request, wait_and_exit
from src.utils.log_setup import logger
class Tracker:
@@ -52,9 +52,9 @@ class Instances:
"""Return a list of arr instances matching the given arr_type."""
return [arr for arr in self.arrs if arr.arr_type == arr_type]
def config_as_yaml(self, hide_internal_attr=True):
"""Logs all configured Arr instances while masking sensitive attributes."""
internal_attributes={
def config_as_yaml(self, *, hide_internal_attr=True):
"""Log all configured Arr instances while masking sensitive attributes."""
internal_attributes = {
"settings",
"api_url",
"min_version",
@@ -65,7 +65,7 @@ class Instances:
"detail_item_id_key",
"detail_item_ids_key",
"detail_item_search_command",
}
}
outputs = []
for arr_type in ["sonarr", "radarr", "readarr", "lidarr", "whisparr"]:
@@ -81,8 +81,6 @@ class Instances:
return "\n".join(outputs)
def check_any_arrs(self):
"""Check if there are any ARR instances."""
if not self.arrs:
@@ -117,12 +115,11 @@ class ArrInstances(list):
arr_type=arr_type,
base_url=client_config["base_url"],
api_key=client_config["api_key"],
)
),
)
except KeyError as e:
logger.error(
f"Missing required key {e} in {arr_type} client config."
)
error = f"Missing required key {e} in {arr_type} client config."
logger.error(error)
class ArrInstance:
@@ -135,11 +132,13 @@ class ArrInstance:
def __init__(self, settings, arr_type: str, base_url: str, api_key: str):
if not base_url:
logger.error(f"Skipping {arr_type} client entry: 'base_url' is required.")
raise ValueError(f"{arr_type} client must have a 'base_url'.")
error = f"{arr_type} client must have a 'base_url'."
raise ValueError(error)
if not api_key:
logger.error(f"Skipping {arr_type} client entry: 'api_key' is required.")
raise ValueError(f"{arr_type} client must have an 'api_key'.")
error = f"{arr_type} client must have an 'api_key'."
raise ValueError(error)
self.settings = settings
self.arr_type = arr_type
@@ -151,8 +150,8 @@ class ArrInstance:
self.detail_item_key = getattr(DetailItemKey, arr_type)
self.detail_item_id_key = self.detail_item_key + "Id"
self.detail_item_ids_key = self.detail_item_key + "Ids"
self.detail_item_search_command = getattr(DetailItemSearchCommand, arr_type)
self.detail_item_search_command = getattr(DetailItemSearchCommand, arr_type)
async def _check_ui_language(self):
"""Check if the UI language is set to English."""
endpoint = self.api_url + "/config/ui"
@@ -162,25 +161,26 @@ class ArrInstance:
if ui_language > 1: # Not English
logger.error("!! %s Error: !!", self.name)
logger.error(
f"> Decluttarr only works correctly if UI language is set to English (under Settings/UI in {self.name})"
f"> Decluttarr only works correctly if UI language is set to English (under Settings/UI in {self.name})",
)
logger.error(
"> Details: https://github.com/ManiMatter/decluttarr/issues/132)"
"> Details: https://github.com/ManiMatter/decluttarr/issues/132)",
)
raise ArrError("Not English")
error = "Not English"
raise ArrError(error)
def _check_min_version(self, status):
"""Check if ARR instance meets minimum version requirements."""
self.version = status["version"]
min_version = getattr(self.settings.min_versions, self.arr_type)
if min_version:
if version.parse(self.version) < version.parse(min_version):
logger.error("!! %s Error: !!", self.name)
logger.error(
f"> Please update {self.name} ({self.base_url}) to at least version {min_version}. Current version: {self.version}"
)
raise ArrError("Not meeting minimum version requirements")
if min_version and version.parse(self.version) < version.parse(min_version):
logger.error("!! %s Error: !!", self.name)
logger.error(
f"> Please update {self.name} ({self.base_url}) to at least version {min_version}. Current version: {self.version}",
)
error = f"Not meeting minimum version requirements: {min_version}"
logger.error(error)
def _check_arr_type(self, status):
"""Check if the ARR instance is of the correct type."""
@@ -188,9 +188,10 @@ class ArrInstance:
if actual_arr_type.lower() != self.arr_type:
logger.error("!! %s Error: !!", self.name)
logger.error(
f"> Your {self.name} ({self.base_url}) points to a {actual_arr_type} instance, rather than {self.arr_type}. Did you specify the wrong IP?"
f"> Your {self.name} ({self.base_url}) points to a {actual_arr_type} instance, rather than {self.arr_type}. Did you specify the wrong IP?",
)
raise ArrError("Wrong Arr Type")
error = "Wrong Arr Type"
logger.error(error)
async def _check_reachability(self):
"""Check if ARR instance is reachable."""
@@ -198,14 +199,13 @@ class ArrInstance:
endpoint = self.api_url + "/system/status"
headers = {"X-Api-Key": self.api_key}
response = await make_request(
"get", endpoint, self.settings, headers=headers, log_error=False
"get", endpoint, self.settings, headers=headers, log_error=False,
)
status = response.json()
return status
return response.json()
except Exception as e:
if isinstance(e, requests.exceptions.HTTPError):
response = getattr(e, "response", None)
if response is not None and response.status_code == 401:
if response is not None and response.status_code == 401: # noqa: PLR2004
tip = "💡 Tip: Have you configured the API_KEY correctly?"
else:
tip = f"💡 Tip: HTTP error occurred. Status: {getattr(response, 'status_code', 'unknown')}"
@@ -218,7 +218,7 @@ class ArrInstance:
raise ArrError(e) from e
async def setup(self):
"""Checks on specific ARR instance"""
"""Check on specific ARR instance."""
try:
status = await self._check_reachability()
self.name = status.get("instanceName", self.arr_type)
@@ -230,7 +230,7 @@ class ArrInstance:
logger.info(f"OK | {self.name} ({self.base_url})")
logger.debug(f"Current version of {self.name}: {self.version}")
except Exception as e:
except Exception as e: # noqa: BLE001
if not isinstance(e, ArrError):
logger.error(f"Unhandled error: {e}", exc_info=True)
wait_and_exit()
@@ -253,17 +253,19 @@ class ArrInstance:
return client.get("implementation", None)
return None
async def remove_queue_item(self, queue_id, blocklist=False):
async def remove_queue_item(self, queue_id, *, blocklist=False):
"""
Remove a specific queue item from the queue by its qeue id.
Remove a specific queue item from the queue by its queue id.
Sends a delete request to the API to remove the item.
Args:
queue_id (str): The quueue ID of the queue item to be removed.
queue_id (str): The queue ID of the queue item to be removed.
blocklist (bool): Whether to add the item to the blocklist. Default is False.
Returns:
bool: Returns True if the removal was successful, False otherwise.
"""
endpoint = f"{self.api_url}/queue/{queue_id}"
headers = {"X-Api-Key": self.api_key}
@@ -271,17 +273,14 @@ class ArrInstance:
# Send the request to remove the download from the queue
response = await make_request(
"delete", endpoint, self.settings, headers=headers, json=json_payload
"delete", endpoint, self.settings, headers=headers, json=json_payload,
)
# If the response is successful, return True, else return False
if response.status_code == 200:
return True
else:
return False
return response.status_code == 200 # noqa: PLR2004
async def is_monitored(self, detail_id):
"""Check if detail item (like a book, series, etc) is monitored."""
"""Check if detail item (like a book, series, etc.) is monitored."""
endpoint = f"{self.api_url}/{self.detail_item_key}/{detail_id}"
headers = {"X-Api-Key": self.api_key}

View File

@@ -1,6 +1,6 @@
from src.utils.log_setup import logger
from src.settings._validate_data_types import validate_data_types
from src.settings._config_as_yaml import get_config_as_yaml
from src.settings._validate_data_types import validate_data_types
from src.utils.log_setup import logger
class JobParams:
@@ -33,7 +33,7 @@ class JobParams:
self._remove_none_attributes()
def _remove_none_attributes(self):
"""Removes attributes that are None to keep the object clean."""
"""Remove attributes that are None to keep the object clean."""
for attr in list(vars(self)):
if getattr(self, attr) is None:
delattr(self, attr)
@@ -52,16 +52,16 @@ class JobDefaults:
job_defaults_config = config.get("job_defaults", {})
self.max_strikes = job_defaults_config.get("max_strikes", self.max_strikes)
self.max_concurrent_searches = job_defaults_config.get(
"max_concurrent_searches", self.max_concurrent_searches
"max_concurrent_searches", self.max_concurrent_searches,
)
self.min_days_between_searches = job_defaults_config.get(
"min_days_between_searches", self.min_days_between_searches
"min_days_between_searches", self.min_days_between_searches,
)
validate_data_types(self)
class Jobs:
"""Represents all jobs explicitly"""
"""Represent all jobs explicitly."""
def __init__(self, config):
self.job_defaults = JobDefaults(config)
@@ -73,10 +73,10 @@ class Jobs:
self.remove_bad_files = JobParams()
self.remove_failed_downloads = JobParams()
self.remove_failed_imports = JobParams(
message_patterns=self.job_defaults.message_patterns
message_patterns=self.job_defaults.message_patterns,
)
self.remove_metadata_missing = JobParams(
max_strikes=self.job_defaults.max_strikes
max_strikes=self.job_defaults.max_strikes,
)
self.remove_missing_files = JobParams()
self.remove_orphans = JobParams()
@@ -102,8 +102,7 @@ class Jobs:
self._set_job_settings(job_name, config["jobs"][job_name])
def _set_job_settings(self, job_name, job_config):
"""Sets per-job config settings"""
"""Set per-job config settings."""
job = getattr(self, job_name, None)
if (
job_config is None
@@ -128,8 +127,8 @@ class Jobs:
setattr(self, job_name, job)
validate_data_types(
job, self.job_defaults
) # Validates and applies defauls from job_defaults
job, self.job_defaults,
) # Validates and applies defaults from job_defaults
def log_status(self):
job_strings = []
@@ -152,7 +151,7 @@ class Jobs:
)
def list_job_status(self):
"""Returns a string showing each job and whether it's enabled or not using emojis."""
"""Return a string showing each job and whether it's enabled or not using emojis."""
lines = []
for name, obj in vars(self).items():
if hasattr(obj, "enabled"):

View File

@@ -1,5 +1,8 @@
import os
from pathlib import Path
import yaml
from src.utils.log_setup import logger
CONFIG_MAPPING = {
@@ -34,7 +37,8 @@ CONFIG_MAPPING = {
def get_user_config(settings):
"""Checks if data is read from enviornment variables, or from yaml file.
"""
Check if data is read from environment variables, or from yaml file.
Reads from environment variables if in docker, unless in docker-compose "USE_CONFIG_YAML" is set to true.
Then the config file is read.
@@ -53,7 +57,7 @@ def get_user_config(settings):
def _parse_env_var(key: str) -> dict | list | str | int | None:
"""Helper function to parse one setting input key"""
"""Parse one setting input key."""
raw_value = os.getenv(key)
if raw_value is None:
return None
@@ -67,7 +71,7 @@ def _parse_env_var(key: str) -> dict | list | str | int | None:
def _load_section(keys: list[str]) -> dict:
"""Helper function to parse one section of expected config"""
"""Parse one section of expected config."""
section_config = {}
for key in keys:
parsed = _parse_env_var(key)
@@ -76,14 +80,6 @@ def _load_section(keys: list[str]) -> dict:
return section_config
def _load_from_env() -> dict:
"""Main function to load settings from env"""
config = {}
for section, keys in CONFIG_MAPPING.items():
config[section] = _load_section(keys)
return config
def _load_from_env() -> dict:
config = {}
@@ -100,7 +96,7 @@ def _load_from_env() -> dict:
parsed_value = _lowercase(parsed_value)
except yaml.YAMLError as e:
logger.error(
f"Failed to parse environment variable {key} as YAML:\n{e}"
f"Failed to parse environment variable {key} as YAML:\n{e}",
)
parsed_value = {}
section_config[key.lower()] = parsed_value
@@ -111,28 +107,26 @@ def _load_from_env() -> dict:
def _lowercase(data):
"""Translates recevied keys (for instance setting-keys of jobs) to lower case"""
"""Translate received keys (for instance setting-keys of jobs) to lower case."""
if isinstance(data, dict):
return {str(k).lower(): _lowercase(v) for k, v in data.items()}
elif isinstance(data, list):
if isinstance(data, list):
return [_lowercase(item) for item in data]
else:
# Leave strings and other types unchanged
return data
# Leave strings and other types unchanged
return data
def _config_file_exists(settings):
config_path = settings.paths.config_file
return os.path.exists(config_path)
return Path(config_path).exists()
def _load_from_yaml_file(settings):
"""Reads config from YAML file and returns a dict."""
"""Read config from YAML file and returns a dict."""
config_path = settings.paths.config_file
try:
with open(config_path, "r", encoding="utf-8") as file:
config = yaml.safe_load(file) or {}
return config
with Path(config_path).open(encoding="utf-8") as file:
return yaml.safe_load(file) or {}
except yaml.YAMLError as e:
logger.error("Error reading YAML file: %s", e)
return {}

View File

@@ -1,14 +1,21 @@
import inspect
from src.utils.log_setup import logger
def validate_data_types(cls, default_cls=None):
"""Ensures all attributes match expected types dynamically.
"""
Ensure all attributes match expected types dynamically.
If default_cls is provided, the default key is taken from this class rather than the own class
If the attribute doesn't exist in `default_cls`, fall back to `cls.__class__`.
"""
def _unhandled_conversion():
error = f"Unhandled type conversion for '{attr}': {expected_type}"
raise TypeError(error)
annotations = inspect.get_annotations(cls.__class__) # Extract type hints
for attr, expected_type in annotations.items():
@@ -17,7 +24,7 @@ def validate_data_types(cls, default_cls=None):
value = getattr(cls, attr)
default_source = default_cls if default_cls and hasattr(default_cls, attr) else cls.__class__
default_value = getattr(default_source, attr, None)
default_value = getattr(default_source, attr, None)
if value == default_value:
continue
@@ -37,22 +44,20 @@ def validate_data_types(cls, default_cls=None):
elif expected_type is dict:
value = convert_to_dict(value)
else:
raise TypeError(f"Unhandled type conversion for '{attr}': {expected_type}")
except Exception as e:
_unhandled_conversion()
except Exception as e: # noqa: BLE001
logger.error(
f"❗️ Invalid type for '{attr}': Expected {expected_type.__name__}, but got {type(value).__name__}. "
f"Error: {e}. Using default value: {default_value}"
f"Error: {e}. Using default value: {default_value}",
)
value = default_value
setattr(cls, attr, value)
# --- Helper Functions ---
def convert_to_bool(raw_value):
"""Converts strings like 'yes', 'no', 'true', 'false' into boolean values."""
"""Convert strings like 'yes', 'no', 'true', 'false' into boolean values."""
if isinstance(raw_value, bool):
return raw_value
@@ -64,28 +69,29 @@ def convert_to_bool(raw_value):
if raw_value in true_values:
return True
elif raw_value in false_values:
if raw_value in false_values:
return False
else:
raise ValueError(f"Invalid boolean value: '{raw_value}'")
error = f"Invalid boolean value: '{raw_value}'"
raise ValueError(error)
def convert_to_str(raw_value):
"""Ensures a string and trims whitespace."""
"""Ensure a string and trims whitespace."""
if isinstance(raw_value, str):
return raw_value.strip()
return str(raw_value).strip()
def convert_to_list(raw_value):
"""Ensures a value is a list."""
"""Ensure a value is a list."""
if isinstance(raw_value, list):
return [convert_to_str(item) for item in raw_value]
return [convert_to_str(raw_value)] # Wrap single values in a list
def convert_to_dict(raw_value):
"""Ensures a value is a dictionary."""
"""Ensure a value is a dictionary."""
if isinstance(raw_value, dict):
return {convert_to_str(k): v for k, v in raw_value.items()}
raise TypeError(f"Expected dict but got {type(raw_value).__name__}")
error = f"Expected dict but got {type(raw_value).__name__}"
raise TypeError(error)

View File

@@ -1,14 +1,14 @@
from src.utils.log_setup import configure_logging
from src.settings._constants import Envs, MinVersions, Paths
# from src.settings._migrate_legacy import migrate_legacy
from src.settings._general import General
from src.settings._jobs import Jobs
from src.settings._download_clients import DownloadClients
from src.settings._general import General
from src.settings._instances import Instances
from src.settings._jobs import Jobs
from src.settings._user_config import get_user_config
from src.utils.log_setup import configure_logging
class Settings:
min_versions = MinVersions()
paths = Paths()
@@ -21,7 +21,6 @@ class Settings:
self.instances = Instances(config, self)
configure_logging(self)
def __repr__(self):
sections = [
("ENVIRONMENT SETTINGS", "envs"),
@@ -30,11 +29,8 @@ class Settings:
("JOB SETTINGS", "jobs"),
("INSTANCE SETTINGS", "instances"),
("DOWNLOAD CLIENT SETTINGS", "download_clients"),
]
messages = []
messages.append("🛠️ Decluttarr - Settings 🛠️")
messages.append("-"*80)
# messages.append("")
]
messages = ["🛠️ Decluttarr - Settings 🛠️", "-" * 80]
for title, attr_name in sections:
section = getattr(self, attr_name, None)
section_content = section.config_as_yaml()
@@ -44,11 +40,11 @@ class Settings:
elif section_content != "{}":
messages.append(self._format_section_title(title))
messages.append(section_content)
messages.append("") # Extra linebreak after section
messages.append("") # Extra linebreak after section
return "\n".join(messages)
def _format_section_title(self, name, border_length=50, symbol="="):
@staticmethod
def _format_section_title(name, border_length=50, symbol="=") -> str:
"""Format section title with centered name and hash borders."""
padding = max(border_length - len(name) - 2, 0) # 4 for spaces
left_hashes = right_hashes = padding // 2