mirror of
https://github.com/ManiMatter/decluttarr.git
synced 2026-04-20 07:54:18 +02:00
Added sigterm handling to exit cleanly when running in Docker.
Fixed typos and various linting issues such as PEP violations. Added ruff and fixed common issues and linting issues.
This commit is contained in:
0
src/settings/__init__.py
Normal file
0
src/settings/__init__.py
Normal file
@@ -1,5 +1,6 @@
|
||||
import yaml
|
||||
|
||||
|
||||
def mask_sensitive_value(value, key, sensitive_attributes):
|
||||
"""Mask the value if it's in the sensitive attributes."""
|
||||
return "*****" if key in sensitive_attributes else value
|
||||
@@ -40,19 +41,19 @@ def clean_object(obj, sensitive_attributes, internal_attributes, hide_internal_a
|
||||
"""Clean an object (either a dict, class instance, or other types)."""
|
||||
if isinstance(obj, dict):
|
||||
return clean_dict(obj, sensitive_attributes, internal_attributes, hide_internal_attr)
|
||||
elif hasattr(obj, "__dict__"):
|
||||
if hasattr(obj, "__dict__"):
|
||||
return clean_dict(vars(obj), sensitive_attributes, internal_attributes, hide_internal_attr)
|
||||
else:
|
||||
return mask_sensitive_value(obj, "", sensitive_attributes)
|
||||
return mask_sensitive_value(obj, "", sensitive_attributes)
|
||||
|
||||
|
||||
def get_config_as_yaml(
|
||||
data,
|
||||
sensitive_attributes=None,
|
||||
internal_attributes=None,
|
||||
*,
|
||||
hide_internal_attr=True,
|
||||
):
|
||||
"""Main function to process the configuration into YAML format."""
|
||||
"""Process the configuration into YAML format."""
|
||||
if sensitive_attributes is None:
|
||||
sensitive_attributes = set()
|
||||
if internal_attributes is None:
|
||||
@@ -67,7 +68,7 @@ def get_config_as_yaml(
|
||||
# Process list-based config
|
||||
if isinstance(obj, list):
|
||||
cleaned_list = clean_list(
|
||||
obj, sensitive_attributes, internal_attributes, hide_internal_attr
|
||||
obj, sensitive_attributes, internal_attributes, hide_internal_attr,
|
||||
)
|
||||
if cleaned_list:
|
||||
config_output[key] = cleaned_list
|
||||
@@ -75,7 +76,7 @@ def get_config_as_yaml(
|
||||
# Process dict or class-like object config
|
||||
else:
|
||||
cleaned_obj = clean_object(
|
||||
obj, sensitive_attributes, internal_attributes, hide_internal_attr
|
||||
obj, sensitive_attributes, internal_attributes, hide_internal_attr,
|
||||
)
|
||||
if cleaned_obj:
|
||||
config_output[key] = cleaned_obj
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
from src.settings._download_clients_qBit import QbitClients
|
||||
from src.settings._download_clients_qbit import QbitClients
|
||||
|
||||
download_client_types = ["qbittorrent"]
|
||||
|
||||
|
||||
class DownloadClients:
|
||||
"""Represents all download clients."""
|
||||
|
||||
qbittorrent = None
|
||||
download_client_types = [
|
||||
"qbittorrent",
|
||||
]
|
||||
|
||||
def __init__(self, config, settings):
|
||||
self._set_qbit_clients(config, settings)
|
||||
self.check_unique_download_client_types()
|
||||
@@ -15,7 +17,7 @@ class DownloadClients:
|
||||
download_clients = config.get("download_clients", {})
|
||||
if isinstance(download_clients, dict):
|
||||
self.qbittorrent = QbitClients(config, settings)
|
||||
if not self.qbittorrent: # Unsets settings in general section needed for qbit (if no qbit is defined)
|
||||
if not self.qbittorrent: # Unsets settings in general section needed for qbit (if no qbit is defined)
|
||||
for key in [
|
||||
"private_tracker_handling",
|
||||
"public_tracker_handling",
|
||||
@@ -25,40 +27,42 @@ class DownloadClients:
|
||||
setattr(settings.general, key, None)
|
||||
|
||||
def config_as_yaml(self):
|
||||
"""Logs all download clients."""
|
||||
"""Log all download clients."""
|
||||
return get_config_as_yaml(
|
||||
{"qbittorrent": self.qbittorrent},
|
||||
sensitive_attributes={"username", "password", "cookie"},
|
||||
internal_attributes={ "api_url", "cookie", "settings", "min_version"},
|
||||
hide_internal_attr=True
|
||||
internal_attributes={"api_url", "cookie", "settings", "min_version"},
|
||||
hide_internal_attr=True,
|
||||
)
|
||||
|
||||
|
||||
def check_unique_download_client_types(self):
|
||||
"""Ensures that all download client names are unique.
|
||||
This is important since downloadClient in arr goes by name, and
|
||||
this is needed to link it to the right IP set up in the yaml config
|
||||
(which may be different to the one donfigured in arr)"""
|
||||
"""
|
||||
Ensure that all download client names are unique.
|
||||
|
||||
This is important since downloadClient in arr goes by name, and
|
||||
this is needed to link it to the right IP set up in the yaml config
|
||||
(which may be different to the one configured in arr)
|
||||
"""
|
||||
seen = set()
|
||||
for download_client_type in self.download_client_types:
|
||||
for download_client_type in download_client_types:
|
||||
download_clients = getattr(self, download_client_type, [])
|
||||
|
||||
# Check each client in the list
|
||||
for client in download_clients:
|
||||
name = getattr(client, "name", None)
|
||||
if name is None:
|
||||
raise ValueError(f'{download_client_type} client does not have a name ({client.base_url}).\nMake sure that the name corresponds with the name set in your *arr app for that download client.')
|
||||
error = f"{download_client_type} client does not have a name ({client.base_url}).\nMake sure that the name corresponds with the name set in your *arr app for that download client."
|
||||
raise ValueError(error)
|
||||
|
||||
if name.lower() in seen:
|
||||
raise ValueError(f"Download client names must be unique. Duplicate name found: '{name}'\nMake sure that the name corresponds with the name set in your *arr app for that download client.")
|
||||
else:
|
||||
seen.add(name.lower())
|
||||
error = f"Download client names must be unique. Duplicate name found: '{name}'\nMake sure that the name corresponds with the name set in your *arr app for that download client."
|
||||
raise ValueError(error)
|
||||
seen.add(name.lower())
|
||||
|
||||
def get_download_client_by_name(self, name: str):
|
||||
"""Retrieve the download client and its type by its name."""
|
||||
name_lower = name.lower()
|
||||
for download_client_type in self.download_client_types:
|
||||
for download_client_type in download_client_types:
|
||||
download_clients = getattr(self, download_client_type, [])
|
||||
|
||||
# Check each client in the list
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from packaging import version
|
||||
from src.utils.common import make_request, wait_and_exit
|
||||
|
||||
from src.settings._constants import ApiEndpoints, MinVersions
|
||||
from src.utils.common import make_request, wait_and_exit
|
||||
from src.utils.log_setup import logger
|
||||
|
||||
|
||||
class QbitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class QbitClients(list):
|
||||
"""Represents all qBittorrent clients"""
|
||||
"""Represents all qBittorrent clients."""
|
||||
|
||||
def __init__(self, config, settings):
|
||||
super().__init__()
|
||||
@@ -19,7 +21,7 @@ class QbitClients(list):
|
||||
|
||||
if not isinstance(qbit_config, list):
|
||||
logger.error(
|
||||
"Invalid config format for qbittorrent clients. Expected a list."
|
||||
"Invalid config format for qbittorrent clients. Expected a list.",
|
||||
)
|
||||
return
|
||||
|
||||
@@ -30,29 +32,29 @@ class QbitClients(list):
|
||||
logger.error(f"Error parsing qbittorrent client config: {e}")
|
||||
|
||||
|
||||
|
||||
class QbitClient:
|
||||
"""Represents a single qBittorrent client."""
|
||||
|
||||
cookie: str = None
|
||||
cookie: dict[str, str] = None
|
||||
version: str = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings,
|
||||
base_url: str = None,
|
||||
username: str = None,
|
||||
password: str = None,
|
||||
name: str = None
|
||||
self,
|
||||
settings,
|
||||
base_url: str | None = None,
|
||||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
name: str | None = None,
|
||||
):
|
||||
self.settings = settings
|
||||
if not base_url:
|
||||
logger.error("Skipping qBittorrent client entry: 'base_url' is required.")
|
||||
raise ValueError("qBittorrent client must have a 'base_url'.")
|
||||
error = "qBittorrent client must have a 'base_url'."
|
||||
raise ValueError(error)
|
||||
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_url = self.base_url + getattr(ApiEndpoints, "qbittorrent")
|
||||
self.min_version = getattr(MinVersions, "qbittorrent")
|
||||
self.api_url = self.base_url + ApiEndpoints.qbittorrent
|
||||
self.min_version = MinVersions.qbittorrent
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.name = name
|
||||
@@ -63,24 +65,28 @@ class QbitClient:
|
||||
self._remove_none_attributes()
|
||||
|
||||
def _remove_none_attributes(self):
|
||||
"""Removes attributes that are None to keep the object clean."""
|
||||
"""Remove attributes that are None to keep the object clean."""
|
||||
for attr in list(vars(self)):
|
||||
if getattr(self, attr) is None:
|
||||
delattr(self, attr)
|
||||
|
||||
|
||||
async def refresh_cookie(self):
|
||||
"""Refresh the qBittorrent session cookie."""
|
||||
|
||||
def _connection_error():
|
||||
error = "Login failed."
|
||||
raise ConnectionError(error)
|
||||
|
||||
try:
|
||||
endpoint = f"{self.api_url}/auth/login"
|
||||
data = {"username": getattr(self, 'username', ''), "password": getattr(self, 'password', '')}
|
||||
data = {"username": getattr(self, "username", ""), "password": getattr(self, "password", "")}
|
||||
headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||
response = await make_request(
|
||||
"post", endpoint, self.settings, data=data, headers=headers
|
||||
"post", endpoint, self.settings, data=data, headers=headers,
|
||||
)
|
||||
|
||||
if response.text == "Fails.":
|
||||
raise ConnectionError("Login failed.")
|
||||
_connection_error()
|
||||
|
||||
self.cookie = {"SID": response.cookies["SID"]}
|
||||
logger.debug("qBit cookie refreshed!")
|
||||
@@ -89,8 +95,6 @@ class QbitClient:
|
||||
self.cookie = {}
|
||||
raise QbitError(e) from e
|
||||
|
||||
|
||||
|
||||
async def fetch_version(self):
|
||||
"""Fetch the current qBittorrent version."""
|
||||
endpoint = f"{self.api_url}/app/version"
|
||||
@@ -98,24 +102,21 @@ class QbitClient:
|
||||
self.version = response.text[1:] # Remove the '_v' prefix
|
||||
logger.debug(f"qBit version for client qBittorrent: {self.version}")
|
||||
|
||||
|
||||
async def validate_version(self):
|
||||
"""Check if the qBittorrent version meets minimum and recommended requirements."""
|
||||
min_version = self.settings.min_versions.qbittorrent
|
||||
|
||||
if version.parse(self.version) < version.parse(min_version):
|
||||
logger.error(
|
||||
f"Please update qBittorrent to at least version {min_version}. Current version: {self.version}"
|
||||
)
|
||||
raise QbitError(
|
||||
f"qBittorrent version {self.version} is too old. Please update."
|
||||
f"Please update qBittorrent to at least version {min_version}. Current version: {self.version}",
|
||||
)
|
||||
error = f"qBittorrent version {self.version} is too old. Please update."
|
||||
raise QbitError(error)
|
||||
if version.parse(self.version) < version.parse("5.0.0"):
|
||||
logger.info(
|
||||
f"[Tip!] Consider upgrading to qBittorrent v5.0.0 or newer to reduce network overhead."
|
||||
"[Tip!] Consider upgrading to qBittorrent v5.0.0 or newer to reduce network overhead.",
|
||||
)
|
||||
|
||||
|
||||
async def create_tag(self):
|
||||
"""Create the protection tag in qBittorrent if it doesn't exist."""
|
||||
url = f"{self.api_url}/torrents/tags"
|
||||
@@ -134,34 +135,32 @@ class QbitClient:
|
||||
cookies=self.cookie,
|
||||
)
|
||||
|
||||
if (
|
||||
self.settings.general.public_tracker_handling == "tag_as_obsolete"
|
||||
or self.settings.general.private_tracker_handling == "tag_as_obsolete"
|
||||
):
|
||||
if self.settings.general.obsolete_tag not in current_tags:
|
||||
logger.verbose(f"Creating obsolete tag: {self.settings.general.obsolete_tag}")
|
||||
if not self.settings.general.test_run:
|
||||
data = {"tags": self.settings.general.obsolete_tag}
|
||||
await make_request(
|
||||
"post",
|
||||
self.api_url + "/torrents/createTags",
|
||||
self.settings,
|
||||
data=data,
|
||||
cookies=self.cookie,
|
||||
)
|
||||
if ((self.settings.general.public_tracker_handling == "tag_as_obsolete"
|
||||
or self.settings.general.private_tracker_handling == "tag_as_obsolete")
|
||||
and self.settings.general.obsolete_tag not in current_tags):
|
||||
logger.verbose(f"Creating obsolete tag: {self.settings.general.obsolete_tag}")
|
||||
if not self.settings.general.test_run:
|
||||
data = {"tags": self.settings.general.obsolete_tag}
|
||||
await make_request(
|
||||
"post",
|
||||
self.api_url + "/torrents/createTags",
|
||||
self.settings,
|
||||
data=data,
|
||||
cookies=self.cookie,
|
||||
)
|
||||
|
||||
async def set_unwanted_folder(self):
|
||||
"""Set the 'unwanted folder' setting in qBittorrent if needed."""
|
||||
if self.settings.jobs.remove_bad_files:
|
||||
endpoint = f"{self.api_url}/app/preferences"
|
||||
response = await make_request(
|
||||
"get", endpoint, self.settings, cookies=self.cookie
|
||||
"get", endpoint, self.settings, cookies=self.cookie,
|
||||
)
|
||||
qbit_settings = response.json()
|
||||
|
||||
if not qbit_settings.get("use_unwanted_folder"):
|
||||
logger.info(
|
||||
"Enabling 'Keep unselected files in .unwanted folder' in qBittorrent."
|
||||
"Enabling 'Keep unselected files in .unwanted folder' in qBittorrent.",
|
||||
)
|
||||
if not self.settings.general.test_run:
|
||||
data = {"json": '{"use_unwanted_folder": true}'}
|
||||
@@ -173,39 +172,32 @@ class QbitClient:
|
||||
cookies=self.cookie,
|
||||
)
|
||||
|
||||
|
||||
async def check_qbit_reachability(self):
|
||||
"""Check if the qBittorrent URL is reachable."""
|
||||
try:
|
||||
endpoint = f"{self.api_url}/auth/login"
|
||||
data = {"username": getattr(self, 'username', ''), "password": getattr(self, 'password', '')}
|
||||
data = {"username": getattr(self, "username", ""), "password": getattr(self, "password", "")}
|
||||
headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||
await make_request(
|
||||
"post", endpoint, self.settings, data=data, headers=headers, log_error=False
|
||||
"post", endpoint, self.settings, data=data, headers=headers, log_error=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except Exception as e: # noqa: BLE001
|
||||
tip = "💡 Tip: Did you specify the URL (and username/password if required) correctly?"
|
||||
logger.error(f"-- | qBittorrent\n❗️ {e}\n{tip}\n")
|
||||
wait_and_exit()
|
||||
|
||||
|
||||
async def check_qbit_connected(self):
|
||||
"""Check if the qBittorrent is connected to internet."""
|
||||
qbit_connection_status = ((
|
||||
await make_request(
|
||||
"get",
|
||||
self.api_url + "/sync/maindata",
|
||||
self.settings,
|
||||
cookies=self.cookie,
|
||||
)
|
||||
).json())["server_state"]["connection_status"]
|
||||
if qbit_connection_status == "disconnected":
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
await make_request(
|
||||
"get",
|
||||
self.api_url + "/sync/maindata",
|
||||
self.settings,
|
||||
cookies=self.cookie,
|
||||
)
|
||||
).json())["server_state"]["connection_status"]
|
||||
return qbit_connection_status != "disconnected"
|
||||
|
||||
async def setup(self):
|
||||
"""Perform the qBittorrent setup by calling relevant managers."""
|
||||
@@ -228,9 +220,8 @@ class QbitClient:
|
||||
await self.create_tag()
|
||||
await self.set_unwanted_folder()
|
||||
|
||||
|
||||
async def get_protected_and_private(self):
|
||||
"""Fetches torrents from qBittorrent and checks for protected and private status."""
|
||||
"""Fetch torrents from qBittorrent and checks for protected and private status."""
|
||||
protected_downloads = []
|
||||
private_downloads = []
|
||||
|
||||
@@ -270,11 +261,12 @@ class QbitClient:
|
||||
|
||||
async def set_tag(self, tags, hashes):
|
||||
"""
|
||||
Sets tags to one or more torrents in qBittorrent.
|
||||
Set tags to one or more torrents in qBittorrent.
|
||||
|
||||
Args:
|
||||
tags (list): A list of tag names to be added.
|
||||
hashes (list): A list of torrent hashes to which the tags should be applied.
|
||||
|
||||
"""
|
||||
# Ensure hashes are provided as a string separated by '|'
|
||||
hashes_str = "|".join(hashes)
|
||||
@@ -285,7 +277,7 @@ class QbitClient:
|
||||
# Prepare the data for the request
|
||||
data = {
|
||||
"hashes": hashes_str,
|
||||
"tags": tags_str
|
||||
"tags": tags_str,
|
||||
}
|
||||
|
||||
# Perform the request to add the tag(s) to the torrents
|
||||
@@ -294,15 +286,13 @@ class QbitClient:
|
||||
self.api_url + "/torrents/addTags",
|
||||
self.settings,
|
||||
data=data,
|
||||
cookies=self.cookie,
|
||||
cookies=self.cookie,
|
||||
)
|
||||
|
||||
|
||||
async def get_download_progress(self, download_id):
|
||||
items = await self.get_qbit_items(download_id)
|
||||
return items[0]["completed"]
|
||||
|
||||
|
||||
async def get_qbit_items(self, hashes=None):
|
||||
params = None
|
||||
if hashes:
|
||||
@@ -319,7 +309,6 @@ class QbitClient:
|
||||
)
|
||||
return response.json()
|
||||
|
||||
|
||||
async def get_torrent_files(self, download_id):
|
||||
# this may not work if the wrong qbit
|
||||
response = await make_request(
|
||||
@@ -331,8 +320,8 @@ class QbitClient:
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def set_torrent_file_priority(self, download_id, file_id, priority = 0):
|
||||
data={
|
||||
async def set_torrent_file_priority(self, download_id, file_id, priority=0):
|
||||
data = {
|
||||
"hash": download_id.lower(),
|
||||
"id": file_id,
|
||||
"priority": priority,
|
||||
@@ -344,4 +333,3 @@ class QbitClient:
|
||||
data=data,
|
||||
cookies=self.cookie,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import yaml
|
||||
from src.utils.log_setup import logger
|
||||
from src.settings._validate_data_types import validate_data_types
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
from src.settings._validate_data_types import validate_data_types
|
||||
from src.utils.log_setup import logger
|
||||
|
||||
VALID_TRACKER_HANDLING = {"remove", "skip", "obsolete_tag"}
|
||||
|
||||
|
||||
class General:
|
||||
"""Represents general settings for the application."""
|
||||
VALID_TRACKER_HANDLING = {"remove", "skip", "obsolete_tag"}
|
||||
|
||||
log_level: str = "INFO"
|
||||
test_run: bool = False
|
||||
@@ -17,7 +18,6 @@ class General:
|
||||
obsolete_tag: str = None
|
||||
protected_tag: str = "Keep"
|
||||
|
||||
|
||||
def __init__(self, config):
|
||||
general_config = config.get("general", {})
|
||||
self.log_level = general_config.get("log_level", self.log_level.upper())
|
||||
@@ -32,31 +32,31 @@ class General:
|
||||
self.protected_tag = general_config.get("protected_tag", self.protected_tag)
|
||||
|
||||
# Validate tracker handling settings
|
||||
self.private_tracker_handling = self._validate_tracker_handling( self.private_tracker_handling, "private_tracker_handling" )
|
||||
self.public_tracker_handling = self._validate_tracker_handling( self.public_tracker_handling, "public_tracker_handling" )
|
||||
self.private_tracker_handling = self._validate_tracker_handling(self.private_tracker_handling, "private_tracker_handling")
|
||||
self.public_tracker_handling = self._validate_tracker_handling(self.public_tracker_handling, "public_tracker_handling")
|
||||
self.obsolete_tag = self._determine_obsolete_tag(self.obsolete_tag)
|
||||
|
||||
|
||||
validate_data_types(self)
|
||||
self._remove_none_attributes()
|
||||
|
||||
def _remove_none_attributes(self):
|
||||
"""Removes attributes that are None to keep the object clean."""
|
||||
"""Remove attributes that are None to keep the object clean."""
|
||||
for attr in list(vars(self)):
|
||||
if getattr(self, attr) is None:
|
||||
delattr(self, attr)
|
||||
|
||||
def _validate_tracker_handling(self, value, field_name):
|
||||
"""Validates tracker handling options. Defaults to 'remove' if invalid."""
|
||||
if value not in self.VALID_TRACKER_HANDLING:
|
||||
@staticmethod
|
||||
def _validate_tracker_handling(value, field_name) -> str:
|
||||
"""Validate tracker handling options. Defaults to 'remove' if invalid."""
|
||||
if value not in VALID_TRACKER_HANDLING:
|
||||
logger.error(
|
||||
f"Invalid value '{value}' for {field_name}. Defaulting to 'remove'."
|
||||
f"Invalid value '{value}' for {field_name}. Defaulting to 'remove'.",
|
||||
)
|
||||
return "remove"
|
||||
return value
|
||||
|
||||
def _determine_obsolete_tag(self, obsolete_tag):
|
||||
"""Defaults obsolete tag to "obsolete", only if none is provided and the tag is needed for handling """
|
||||
"""Set obsolete tag to "obsolete", only if none is provided and the tag is needed for handling."""
|
||||
if obsolete_tag is None and (
|
||||
self.private_tracker_handling == "obsolete_tag"
|
||||
or self.public_tracker_handling == "obsolete_tag"
|
||||
@@ -65,10 +65,7 @@ class General:
|
||||
return obsolete_tag
|
||||
|
||||
def config_as_yaml(self):
|
||||
"""Logs all general settings."""
|
||||
# yaml_output = yaml.dump(vars(self), indent=2, default_flow_style=False, sort_keys=False)
|
||||
# logger.info(f"General Settings:\n{yaml_output}")
|
||||
|
||||
"""Log all general settings."""
|
||||
return get_config_as_yaml(
|
||||
vars(self),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import requests
|
||||
from packaging import version
|
||||
|
||||
from src.utils.log_setup import logger
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
from src.settings._constants import (
|
||||
ApiEndpoints,
|
||||
MinVersions,
|
||||
FullQueueParameter,
|
||||
DetailItemKey,
|
||||
DetailItemSearchCommand,
|
||||
FullQueueParameter,
|
||||
MinVersions,
|
||||
)
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
from src.utils.common import make_request, wait_and_exit
|
||||
from src.utils.log_setup import logger
|
||||
|
||||
|
||||
class Tracker:
|
||||
@@ -52,9 +52,9 @@ class Instances:
|
||||
"""Return a list of arr instances matching the given arr_type."""
|
||||
return [arr for arr in self.arrs if arr.arr_type == arr_type]
|
||||
|
||||
def config_as_yaml(self, hide_internal_attr=True):
|
||||
"""Logs all configured Arr instances while masking sensitive attributes."""
|
||||
internal_attributes={
|
||||
def config_as_yaml(self, *, hide_internal_attr=True):
|
||||
"""Log all configured Arr instances while masking sensitive attributes."""
|
||||
internal_attributes = {
|
||||
"settings",
|
||||
"api_url",
|
||||
"min_version",
|
||||
@@ -65,7 +65,7 @@ class Instances:
|
||||
"detail_item_id_key",
|
||||
"detail_item_ids_key",
|
||||
"detail_item_search_command",
|
||||
}
|
||||
}
|
||||
|
||||
outputs = []
|
||||
for arr_type in ["sonarr", "radarr", "readarr", "lidarr", "whisparr"]:
|
||||
@@ -81,8 +81,6 @@ class Instances:
|
||||
|
||||
return "\n".join(outputs)
|
||||
|
||||
|
||||
|
||||
def check_any_arrs(self):
|
||||
"""Check if there are any ARR instances."""
|
||||
if not self.arrs:
|
||||
@@ -117,12 +115,11 @@ class ArrInstances(list):
|
||||
arr_type=arr_type,
|
||||
base_url=client_config["base_url"],
|
||||
api_key=client_config["api_key"],
|
||||
)
|
||||
),
|
||||
)
|
||||
except KeyError as e:
|
||||
logger.error(
|
||||
f"Missing required key {e} in {arr_type} client config."
|
||||
)
|
||||
error = f"Missing required key {e} in {arr_type} client config."
|
||||
logger.error(error)
|
||||
|
||||
|
||||
class ArrInstance:
|
||||
@@ -135,11 +132,13 @@ class ArrInstance:
|
||||
def __init__(self, settings, arr_type: str, base_url: str, api_key: str):
|
||||
if not base_url:
|
||||
logger.error(f"Skipping {arr_type} client entry: 'base_url' is required.")
|
||||
raise ValueError(f"{arr_type} client must have a 'base_url'.")
|
||||
error = f"{arr_type} client must have a 'base_url'."
|
||||
raise ValueError(error)
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"Skipping {arr_type} client entry: 'api_key' is required.")
|
||||
raise ValueError(f"{arr_type} client must have an 'api_key'.")
|
||||
error = f"{arr_type} client must have an 'api_key'."
|
||||
raise ValueError(error)
|
||||
|
||||
self.settings = settings
|
||||
self.arr_type = arr_type
|
||||
@@ -151,8 +150,8 @@ class ArrInstance:
|
||||
self.detail_item_key = getattr(DetailItemKey, arr_type)
|
||||
self.detail_item_id_key = self.detail_item_key + "Id"
|
||||
self.detail_item_ids_key = self.detail_item_key + "Ids"
|
||||
self.detail_item_search_command = getattr(DetailItemSearchCommand, arr_type)
|
||||
|
||||
self.detail_item_search_command = getattr(DetailItemSearchCommand, arr_type)
|
||||
|
||||
async def _check_ui_language(self):
|
||||
"""Check if the UI language is set to English."""
|
||||
endpoint = self.api_url + "/config/ui"
|
||||
@@ -162,25 +161,26 @@ class ArrInstance:
|
||||
if ui_language > 1: # Not English
|
||||
logger.error("!! %s Error: !!", self.name)
|
||||
logger.error(
|
||||
f"> Decluttarr only works correctly if UI language is set to English (under Settings/UI in {self.name})"
|
||||
f"> Decluttarr only works correctly if UI language is set to English (under Settings/UI in {self.name})",
|
||||
)
|
||||
logger.error(
|
||||
"> Details: https://github.com/ManiMatter/decluttarr/issues/132)"
|
||||
"> Details: https://github.com/ManiMatter/decluttarr/issues/132)",
|
||||
)
|
||||
raise ArrError("Not English")
|
||||
error = "Not English"
|
||||
raise ArrError(error)
|
||||
|
||||
def _check_min_version(self, status):
|
||||
"""Check if ARR instance meets minimum version requirements."""
|
||||
self.version = status["version"]
|
||||
min_version = getattr(self.settings.min_versions, self.arr_type)
|
||||
|
||||
if min_version:
|
||||
if version.parse(self.version) < version.parse(min_version):
|
||||
logger.error("!! %s Error: !!", self.name)
|
||||
logger.error(
|
||||
f"> Please update {self.name} ({self.base_url}) to at least version {min_version}. Current version: {self.version}"
|
||||
)
|
||||
raise ArrError("Not meeting minimum version requirements")
|
||||
if min_version and version.parse(self.version) < version.parse(min_version):
|
||||
logger.error("!! %s Error: !!", self.name)
|
||||
logger.error(
|
||||
f"> Please update {self.name} ({self.base_url}) to at least version {min_version}. Current version: {self.version}",
|
||||
)
|
||||
error = f"Not meeting minimum version requirements: {min_version}"
|
||||
logger.error(error)
|
||||
|
||||
def _check_arr_type(self, status):
|
||||
"""Check if the ARR instance is of the correct type."""
|
||||
@@ -188,9 +188,10 @@ class ArrInstance:
|
||||
if actual_arr_type.lower() != self.arr_type:
|
||||
logger.error("!! %s Error: !!", self.name)
|
||||
logger.error(
|
||||
f"> Your {self.name} ({self.base_url}) points to a {actual_arr_type} instance, rather than {self.arr_type}. Did you specify the wrong IP?"
|
||||
f"> Your {self.name} ({self.base_url}) points to a {actual_arr_type} instance, rather than {self.arr_type}. Did you specify the wrong IP?",
|
||||
)
|
||||
raise ArrError("Wrong Arr Type")
|
||||
error = "Wrong Arr Type"
|
||||
logger.error(error)
|
||||
|
||||
async def _check_reachability(self):
|
||||
"""Check if ARR instance is reachable."""
|
||||
@@ -198,14 +199,13 @@ class ArrInstance:
|
||||
endpoint = self.api_url + "/system/status"
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
response = await make_request(
|
||||
"get", endpoint, self.settings, headers=headers, log_error=False
|
||||
"get", endpoint, self.settings, headers=headers, log_error=False,
|
||||
)
|
||||
status = response.json()
|
||||
return status
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
if isinstance(e, requests.exceptions.HTTPError):
|
||||
response = getattr(e, "response", None)
|
||||
if response is not None and response.status_code == 401:
|
||||
if response is not None and response.status_code == 401: # noqa: PLR2004
|
||||
tip = "💡 Tip: Have you configured the API_KEY correctly?"
|
||||
else:
|
||||
tip = f"💡 Tip: HTTP error occurred. Status: {getattr(response, 'status_code', 'unknown')}"
|
||||
@@ -218,7 +218,7 @@ class ArrInstance:
|
||||
raise ArrError(e) from e
|
||||
|
||||
async def setup(self):
|
||||
"""Checks on specific ARR instance"""
|
||||
"""Check on specific ARR instance."""
|
||||
try:
|
||||
status = await self._check_reachability()
|
||||
self.name = status.get("instanceName", self.arr_type)
|
||||
@@ -230,7 +230,7 @@ class ArrInstance:
|
||||
logger.info(f"OK | {self.name} ({self.base_url})")
|
||||
logger.debug(f"Current version of {self.name}: {self.version}")
|
||||
|
||||
except Exception as e:
|
||||
except Exception as e: # noqa: BLE001
|
||||
if not isinstance(e, ArrError):
|
||||
logger.error(f"Unhandled error: {e}", exc_info=True)
|
||||
wait_and_exit()
|
||||
@@ -253,17 +253,19 @@ class ArrInstance:
|
||||
return client.get("implementation", None)
|
||||
return None
|
||||
|
||||
async def remove_queue_item(self, queue_id, blocklist=False):
|
||||
async def remove_queue_item(self, queue_id, *, blocklist=False):
|
||||
"""
|
||||
Remove a specific queue item from the queue by its qeue id.
|
||||
Remove a specific queue item from the queue by its queue id.
|
||||
|
||||
Sends a delete request to the API to remove the item.
|
||||
|
||||
Args:
|
||||
queue_id (str): The quueue ID of the queue item to be removed.
|
||||
queue_id (str): The queue ID of the queue item to be removed.
|
||||
blocklist (bool): Whether to add the item to the blocklist. Default is False.
|
||||
|
||||
Returns:
|
||||
bool: Returns True if the removal was successful, False otherwise.
|
||||
|
||||
"""
|
||||
endpoint = f"{self.api_url}/queue/{queue_id}"
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
@@ -271,17 +273,14 @@ class ArrInstance:
|
||||
|
||||
# Send the request to remove the download from the queue
|
||||
response = await make_request(
|
||||
"delete", endpoint, self.settings, headers=headers, json=json_payload
|
||||
"delete", endpoint, self.settings, headers=headers, json=json_payload,
|
||||
)
|
||||
|
||||
# If the response is successful, return True, else return False
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return response.status_code == 200 # noqa: PLR2004
|
||||
|
||||
async def is_monitored(self, detail_id):
|
||||
"""Check if detail item (like a book, series, etc) is monitored."""
|
||||
"""Check if detail item (like a book, series, etc.) is monitored."""
|
||||
endpoint = f"{self.api_url}/{self.detail_item_key}/{detail_id}"
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from src.utils.log_setup import logger
|
||||
from src.settings._validate_data_types import validate_data_types
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
from src.settings._validate_data_types import validate_data_types
|
||||
from src.utils.log_setup import logger
|
||||
|
||||
|
||||
class JobParams:
|
||||
@@ -33,7 +33,7 @@ class JobParams:
|
||||
self._remove_none_attributes()
|
||||
|
||||
def _remove_none_attributes(self):
|
||||
"""Removes attributes that are None to keep the object clean."""
|
||||
"""Remove attributes that are None to keep the object clean."""
|
||||
for attr in list(vars(self)):
|
||||
if getattr(self, attr) is None:
|
||||
delattr(self, attr)
|
||||
@@ -52,16 +52,16 @@ class JobDefaults:
|
||||
job_defaults_config = config.get("job_defaults", {})
|
||||
self.max_strikes = job_defaults_config.get("max_strikes", self.max_strikes)
|
||||
self.max_concurrent_searches = job_defaults_config.get(
|
||||
"max_concurrent_searches", self.max_concurrent_searches
|
||||
"max_concurrent_searches", self.max_concurrent_searches,
|
||||
)
|
||||
self.min_days_between_searches = job_defaults_config.get(
|
||||
"min_days_between_searches", self.min_days_between_searches
|
||||
"min_days_between_searches", self.min_days_between_searches,
|
||||
)
|
||||
validate_data_types(self)
|
||||
|
||||
|
||||
class Jobs:
|
||||
"""Represents all jobs explicitly"""
|
||||
"""Represent all jobs explicitly."""
|
||||
|
||||
def __init__(self, config):
|
||||
self.job_defaults = JobDefaults(config)
|
||||
@@ -73,10 +73,10 @@ class Jobs:
|
||||
self.remove_bad_files = JobParams()
|
||||
self.remove_failed_downloads = JobParams()
|
||||
self.remove_failed_imports = JobParams(
|
||||
message_patterns=self.job_defaults.message_patterns
|
||||
message_patterns=self.job_defaults.message_patterns,
|
||||
)
|
||||
self.remove_metadata_missing = JobParams(
|
||||
max_strikes=self.job_defaults.max_strikes
|
||||
max_strikes=self.job_defaults.max_strikes,
|
||||
)
|
||||
self.remove_missing_files = JobParams()
|
||||
self.remove_orphans = JobParams()
|
||||
@@ -102,8 +102,7 @@ class Jobs:
|
||||
self._set_job_settings(job_name, config["jobs"][job_name])
|
||||
|
||||
def _set_job_settings(self, job_name, job_config):
|
||||
"""Sets per-job config settings"""
|
||||
|
||||
"""Set per-job config settings."""
|
||||
job = getattr(self, job_name, None)
|
||||
if (
|
||||
job_config is None
|
||||
@@ -128,8 +127,8 @@ class Jobs:
|
||||
|
||||
setattr(self, job_name, job)
|
||||
validate_data_types(
|
||||
job, self.job_defaults
|
||||
) # Validates and applies defauls from job_defaults
|
||||
job, self.job_defaults,
|
||||
) # Validates and applies defaults from job_defaults
|
||||
|
||||
def log_status(self):
|
||||
job_strings = []
|
||||
@@ -152,7 +151,7 @@ class Jobs:
|
||||
)
|
||||
|
||||
def list_job_status(self):
|
||||
"""Returns a string showing each job and whether it's enabled or not using emojis."""
|
||||
"""Return a string showing each job and whether it's enabled or not using emojis."""
|
||||
lines = []
|
||||
for name, obj in vars(self).items():
|
||||
if hasattr(obj, "enabled"):
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from src.utils.log_setup import logger
|
||||
|
||||
CONFIG_MAPPING = {
|
||||
@@ -34,7 +37,8 @@ CONFIG_MAPPING = {
|
||||
|
||||
|
||||
def get_user_config(settings):
|
||||
"""Checks if data is read from enviornment variables, or from yaml file.
|
||||
"""
|
||||
Check if data is read from environment variables, or from yaml file.
|
||||
|
||||
Reads from environment variables if in docker, unless in docker-compose "USE_CONFIG_YAML" is set to true.
|
||||
Then the config file is read.
|
||||
@@ -53,7 +57,7 @@ def get_user_config(settings):
|
||||
|
||||
|
||||
def _parse_env_var(key: str) -> dict | list | str | int | None:
|
||||
"""Helper function to parse one setting input key"""
|
||||
"""Parse one setting input key."""
|
||||
raw_value = os.getenv(key)
|
||||
if raw_value is None:
|
||||
return None
|
||||
@@ -67,7 +71,7 @@ def _parse_env_var(key: str) -> dict | list | str | int | None:
|
||||
|
||||
|
||||
def _load_section(keys: list[str]) -> dict:
|
||||
"""Helper function to parse one section of expected config"""
|
||||
"""Parse one section of expected config."""
|
||||
section_config = {}
|
||||
for key in keys:
|
||||
parsed = _parse_env_var(key)
|
||||
@@ -76,14 +80,6 @@ def _load_section(keys: list[str]) -> dict:
|
||||
return section_config
|
||||
|
||||
|
||||
def _load_from_env() -> dict:
|
||||
"""Main function to load settings from env"""
|
||||
config = {}
|
||||
for section, keys in CONFIG_MAPPING.items():
|
||||
config[section] = _load_section(keys)
|
||||
return config
|
||||
|
||||
|
||||
def _load_from_env() -> dict:
|
||||
config = {}
|
||||
|
||||
@@ -100,7 +96,7 @@ def _load_from_env() -> dict:
|
||||
parsed_value = _lowercase(parsed_value)
|
||||
except yaml.YAMLError as e:
|
||||
logger.error(
|
||||
f"Failed to parse environment variable {key} as YAML:\n{e}"
|
||||
f"Failed to parse environment variable {key} as YAML:\n{e}",
|
||||
)
|
||||
parsed_value = {}
|
||||
section_config[key.lower()] = parsed_value
|
||||
@@ -111,28 +107,26 @@ def _load_from_env() -> dict:
|
||||
|
||||
|
||||
def _lowercase(data):
|
||||
"""Translates recevied keys (for instance setting-keys of jobs) to lower case"""
|
||||
"""Translate received keys (for instance setting-keys of jobs) to lower case."""
|
||||
if isinstance(data, dict):
|
||||
return {str(k).lower(): _lowercase(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
if isinstance(data, list):
|
||||
return [_lowercase(item) for item in data]
|
||||
else:
|
||||
# Leave strings and other types unchanged
|
||||
return data
|
||||
# Leave strings and other types unchanged
|
||||
return data
|
||||
|
||||
|
||||
def _config_file_exists(settings):
|
||||
config_path = settings.paths.config_file
|
||||
return os.path.exists(config_path)
|
||||
return Path(config_path).exists()
|
||||
|
||||
|
||||
def _load_from_yaml_file(settings):
|
||||
"""Reads config from YAML file and returns a dict."""
|
||||
"""Read config from YAML file and returns a dict."""
|
||||
config_path = settings.paths.config_file
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as file:
|
||||
config = yaml.safe_load(file) or {}
|
||||
return config
|
||||
with Path(config_path).open(encoding="utf-8") as file:
|
||||
return yaml.safe_load(file) or {}
|
||||
except yaml.YAMLError as e:
|
||||
logger.error("Error reading YAML file: %s", e)
|
||||
return {}
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
|
||||
|
||||
import inspect
|
||||
|
||||
from src.utils.log_setup import logger
|
||||
|
||||
|
||||
def validate_data_types(cls, default_cls=None):
|
||||
"""Ensures all attributes match expected types dynamically.
|
||||
"""
|
||||
Ensure all attributes match expected types dynamically.
|
||||
|
||||
If default_cls is provided, the default key is taken from this class rather than the own class
|
||||
If the attribute doesn't exist in `default_cls`, fall back to `cls.__class__`.
|
||||
|
||||
"""
|
||||
|
||||
def _unhandled_conversion():
|
||||
error = f"Unhandled type conversion for '{attr}': {expected_type}"
|
||||
raise TypeError(error)
|
||||
|
||||
annotations = inspect.get_annotations(cls.__class__) # Extract type hints
|
||||
|
||||
for attr, expected_type in annotations.items():
|
||||
@@ -17,7 +24,7 @@ def validate_data_types(cls, default_cls=None):
|
||||
|
||||
value = getattr(cls, attr)
|
||||
default_source = default_cls if default_cls and hasattr(default_cls, attr) else cls.__class__
|
||||
default_value = getattr(default_source, attr, None)
|
||||
default_value = getattr(default_source, attr, None)
|
||||
|
||||
if value == default_value:
|
||||
continue
|
||||
@@ -37,22 +44,20 @@ def validate_data_types(cls, default_cls=None):
|
||||
elif expected_type is dict:
|
||||
value = convert_to_dict(value)
|
||||
else:
|
||||
raise TypeError(f"Unhandled type conversion for '{attr}': {expected_type}")
|
||||
except Exception as e:
|
||||
|
||||
_unhandled_conversion()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(
|
||||
f"❗️ Invalid type for '{attr}': Expected {expected_type.__name__}, but got {type(value).__name__}. "
|
||||
f"Error: {e}. Using default value: {default_value}"
|
||||
f"Error: {e}. Using default value: {default_value}",
|
||||
)
|
||||
value = default_value
|
||||
|
||||
setattr(cls, attr, value)
|
||||
|
||||
|
||||
|
||||
# --- Helper Functions ---
|
||||
def convert_to_bool(raw_value):
|
||||
"""Converts strings like 'yes', 'no', 'true', 'false' into boolean values."""
|
||||
"""Convert strings like 'yes', 'no', 'true', 'false' into boolean values."""
|
||||
if isinstance(raw_value, bool):
|
||||
return raw_value
|
||||
|
||||
@@ -64,28 +69,29 @@ def convert_to_bool(raw_value):
|
||||
|
||||
if raw_value in true_values:
|
||||
return True
|
||||
elif raw_value in false_values:
|
||||
if raw_value in false_values:
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Invalid boolean value: '{raw_value}'")
|
||||
error = f"Invalid boolean value: '{raw_value}'"
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def convert_to_str(raw_value):
|
||||
"""Ensures a string and trims whitespace."""
|
||||
"""Ensure a string and trims whitespace."""
|
||||
if isinstance(raw_value, str):
|
||||
return raw_value.strip()
|
||||
return str(raw_value).strip()
|
||||
|
||||
|
||||
def convert_to_list(raw_value):
|
||||
"""Ensures a value is a list."""
|
||||
"""Ensure a value is a list."""
|
||||
if isinstance(raw_value, list):
|
||||
return [convert_to_str(item) for item in raw_value]
|
||||
return [convert_to_str(raw_value)] # Wrap single values in a list
|
||||
|
||||
|
||||
def convert_to_dict(raw_value):
|
||||
"""Ensures a value is a dictionary."""
|
||||
"""Ensure a value is a dictionary."""
|
||||
if isinstance(raw_value, dict):
|
||||
return {convert_to_str(k): v for k, v in raw_value.items()}
|
||||
raise TypeError(f"Expected dict but got {type(raw_value).__name__}")
|
||||
error = f"Expected dict but got {type(raw_value).__name__}"
|
||||
raise TypeError(error)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from src.utils.log_setup import configure_logging
|
||||
from src.settings._constants import Envs, MinVersions, Paths
|
||||
# from src.settings._migrate_legacy import migrate_legacy
|
||||
from src.settings._general import General
|
||||
from src.settings._jobs import Jobs
|
||||
from src.settings._download_clients import DownloadClients
|
||||
from src.settings._general import General
|
||||
from src.settings._instances import Instances
|
||||
from src.settings._jobs import Jobs
|
||||
from src.settings._user_config import get_user_config
|
||||
from src.utils.log_setup import configure_logging
|
||||
|
||||
|
||||
class Settings:
|
||||
|
||||
|
||||
min_versions = MinVersions()
|
||||
paths = Paths()
|
||||
|
||||
@@ -21,7 +21,6 @@ class Settings:
|
||||
self.instances = Instances(config, self)
|
||||
configure_logging(self)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
sections = [
|
||||
("ENVIRONMENT SETTINGS", "envs"),
|
||||
@@ -30,11 +29,8 @@ class Settings:
|
||||
("JOB SETTINGS", "jobs"),
|
||||
("INSTANCE SETTINGS", "instances"),
|
||||
("DOWNLOAD CLIENT SETTINGS", "download_clients"),
|
||||
]
|
||||
messages = []
|
||||
messages.append("🛠️ Decluttarr - Settings 🛠️")
|
||||
messages.append("-"*80)
|
||||
# messages.append("")
|
||||
]
|
||||
messages = ["🛠️ Decluttarr - Settings 🛠️", "-" * 80]
|
||||
for title, attr_name in sections:
|
||||
section = getattr(self, attr_name, None)
|
||||
section_content = section.config_as_yaml()
|
||||
@@ -44,11 +40,11 @@ class Settings:
|
||||
elif section_content != "{}":
|
||||
messages.append(self._format_section_title(title))
|
||||
messages.append(section_content)
|
||||
messages.append("") # Extra linebreak after section
|
||||
messages.append("") # Extra linebreak after section
|
||||
return "\n".join(messages)
|
||||
|
||||
|
||||
def _format_section_title(self, name, border_length=50, symbol="="):
|
||||
@staticmethod
|
||||
def _format_section_title(name, border_length=50, symbol="=") -> str:
|
||||
"""Format section title with centered name and hash borders."""
|
||||
padding = max(border_length - len(name) - 2, 0) # 4 for spaces
|
||||
left_hashes = right_hashes = padding // 2
|
||||
|
||||
Reference in New Issue
Block a user