mirror of
https://github.com/ManiMatter/decluttarr.git
synced 2026-04-20 02:54:20 +02:00
Code Rewrite to support multi instances
This commit is contained in:
83
src/settings/_config_as_yaml.py
Normal file
83
src/settings/_config_as_yaml.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import yaml
|
||||
|
||||
def mask_sensitive_value(value, key, sensitive_attributes):
    """Return a masked placeholder when *key* is sensitive, else *value* unchanged."""
    if key in sensitive_attributes:
        return "*****"
    return value
|
||||
|
||||
|
||||
def filter_internal_attributes(data, internal_attributes, hide_internal_attr):
    """Return a copy of *data* without internal keys (when hiding is enabled).

    When *hide_internal_attr* is False, all entries are kept.
    """
    if not hide_internal_attr:
        return dict(data)
    return {key: val for key, val in data.items() if key not in internal_attributes}
|
||||
|
||||
|
||||
def clean_dict(data, sensitive_attributes, internal_attributes, hide_internal_attr):
    """Mask sensitive values in *data*, then drop internal attributes."""
    masked = {}
    for key, val in data.items():
        masked[key] = mask_sensitive_value(val, key, sensitive_attributes)
    return filter_internal_attributes(masked, internal_attributes, hide_internal_attr)
|
||||
|
||||
|
||||
def clean_list(obj, sensitive_attributes, internal_attributes, hide_internal_attr):
    """Clean a list whose entries may be dicts, class instances, or plain values."""

    def _clean(entry):
        # Dicts and objects exposing __dict__ get the full cleaning treatment;
        # any other value passes through untouched.
        if isinstance(entry, dict):
            return clean_dict(entry, sensitive_attributes, internal_attributes, hide_internal_attr)
        if hasattr(entry, "__dict__"):
            return clean_dict(vars(entry), sensitive_attributes, internal_attributes, hide_internal_attr)
        return entry

    return [_clean(entry) for entry in obj]
|
||||
|
||||
|
||||
def clean_object(obj, sensitive_attributes, internal_attributes, hide_internal_attr):
    """Clean a single object: a dict, a class instance, or a scalar."""
    if isinstance(obj, dict):
        source = obj
    elif hasattr(obj, "__dict__"):
        source = vars(obj)
    else:
        # Scalars have no key of their own, so only masking (with an empty
        # key) can apply; filtering is meaningless here.
        return mask_sensitive_value(obj, "", sensitive_attributes)
    return clean_dict(source, sensitive_attributes, internal_attributes, hide_internal_attr)
|
||||
|
||||
|
||||
def get_config_as_yaml(
    data,
    sensitive_attributes=None,
    internal_attributes=None,
    hide_internal_attr=True,
):
    """Render *data* as a YAML string with sensitive values masked.

    Top-level keys starting with "_" are skipped. List values are cleaned
    entry-by-entry; everything else is cleaned as a single object. Cleaned
    values are only included when truthy — NOTE: this also drops falsy
    scalars such as 0, "" and False.
    """
    if sensitive_attributes is None:
        sensitive_attributes = set()
    if internal_attributes is None:
        internal_attributes = set()

    config_output = {}

    for key, obj in data.items():
        # Private top-level entries are never exported.
        if key.startswith("_"):
            continue

        if isinstance(obj, list):
            cleaned = clean_list(
                obj, sensitive_attributes, internal_attributes, hide_internal_attr
            )
        else:
            cleaned = clean_object(
                obj, sensitive_attributes, internal_attributes, hide_internal_attr
            )

        if cleaned:
            config_output[key] = cleaned

    return yaml.dump(config_output, indent=2, default_flow_style=False, sort_keys=False)
|
||||
61
src/settings/_constants.py
Normal file
61
src/settings/_constants.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
|
||||
|
||||
class Envs:
    """Runtime flags and build metadata derived from environment variables."""

    def __init__(self):
        env = os.environ
        # "true" (any casing) marks a containerized run; anything else is False.
        self.in_docker = env.get("IN_DOCKER", "").lower() == "true"
        self.image_tag = env.get("IMAGE_TAG") or "Local"
        self.short_commit_id = env.get("SHORT_COMMIT_ID") or "n/a"
        # Flipped to True later if a config file is found.
        self.use_config_yaml = False

    def config_as_yaml(self):
        """Render these environment settings as a YAML string."""
        return get_config_as_yaml(self.__dict__)
|
||||
|
||||
|
||||
class Paths:
    """File-system locations used by the application (relative to the working dir)."""

    logs = "./temp/log.txt"
    tracker = "./temp/tracker.txt"
    config_file = "./config/config.yaml"
|
||||
|
||||
|
||||
class ApiEndpoints:
    """Base API path segment per supported application (appended to its base URL)."""

    radarr = "/api/v3"
    sonarr = "/api/v3"
    lidarr = "/api/v1"
    readarr = "/api/v1"
    whisparr = "/api/v3"
    qbittorrent = "/api/v2"
|
||||
|
||||
|
||||
class MinVersions:
    """Minimum application versions enforced during the setup version checks."""

    radarr = "5.10.3.9171"
    sonarr = "4.0.9.2332"
    lidarr = "2.11.1.4621"
    readarr = "0.4.15.2787"
    whisparr = "2.0.0.548"
    qbittorrent = "4.3.0"
|
||||
|
||||
|
||||
class FullQueueParameter:
    """Per-app query parameter name used to include unknown items in queue requests."""

    radarr = "includeUnknownMovieItems"
    sonarr = "includeUnknownSeriesItems"
    lidarr = "includeUnknownArtistItems"
    readarr = "includeUnknownAuthorItems"
    whisparr = "includeUnknownSeriesItems"
|
||||
|
||||
|
||||
class DetailItemKey:
    """Per-app name of the media detail item (used to derive Id/Ids key names)."""

    radarr = "movie"
    sonarr = "episode"
    lidarr = "album"
    readarr = "book"
    whisparr = "episode"
|
||||
|
||||
|
||||
class DetailItemSearchCommand:
    """Per-app command name used to trigger a re-search of a detail item."""

    radarr = "MoviesSearch"
    sonarr = "EpisodeSearch"
    lidarr = "BookSearch"
    readarr = "BookSearch"
    # No search command defined for whisparr — presumably searches are not
    # triggered for it; TODO confirm against the consuming code.
    whisparr = None
|
||||
69
src/settings/_download_clients.py
Normal file
69
src/settings/_download_clients.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
from src.settings._download_clients_qBit import QbitClients
|
||||
|
||||
class DownloadClients:
    """Represents all download clients.

    Currently qBittorrent is the only supported client type; the
    `download_client_types` list exists so further types can be added.
    """

    # Populated in __init__ with a QbitClients list (or left None/empty).
    qbittorrent = None
    download_client_types = [
        "qbittorrent",
    ]

    def __init__(self, config, settings):
        self._set_qbit_clients(config, settings)
        self.check_unique_download_client_types()

    def _set_qbit_clients(self, config, settings):
        """Build the qBittorrent clients from config; clear qbit-only settings if none."""
        download_clients = config.get("download_clients", {})
        if isinstance(download_clients, dict):
            self.qbittorrent = QbitClients(config, settings)
        # NOTE(review): reconstructed indentation — this unset step is assumed
        # to run whenever no qbit client exists, per the original comment.
        if not self.qbittorrent:  # Unsets settings in general section needed for qbit (if no qbit is defined)
            for key in [
                "private_tracker_handling",
                "public_tracker_handling",
                "obsolete_tag",
                "protected_tag",
            ]:
                setattr(settings.general, key, None)

    def config_as_yaml(self):
        """Returns all download clients rendered as a YAML string."""
        # "cookie" appears in both sets: it would be masked if shown, but is
        # hidden as internal anyway when hide_internal_attr is True.
        return get_config_as_yaml(
            {"qbittorrent": self.qbittorrent},
            sensitive_attributes={"username", "password", "cookie"},
            internal_attributes={"api_url", "cookie", "settings", "min_version"},
            hide_internal_attr=True,
        )

    def check_unique_download_client_types(self):
        """Ensures that all download client names are unique.

        This is important since downloadClient in arr goes by name, and
        this is needed to link it to the right IP set up in the yaml config
        (which may be different to the one configured in arr).

        Raises:
            ValueError: if a client has no name or a (case-insensitive)
                duplicate name is found.
        """
        seen = set()
        for download_client_type in self.download_client_types:
            download_clients = getattr(self, download_client_type, [])

            # Check each client in the list
            for client in download_clients:
                name = getattr(client, "name", None)
                if name is None:
                    raise ValueError(f'{download_client_type} client does not have a name ({client.base_url}).\nMake sure that the name corresponds with the name set in your *arr app for that download client.')

                if name.lower() in seen:
                    raise ValueError(f"Download client names must be unique. Duplicate name found: '{name}'\nMake sure that the name corresponds with the name set in your *arr app for that download client.")
                else:
                    seen.add(name.lower())

    def get_download_client_by_name(self, name: str):
        """Retrieve the download client and its type by its name.

        Returns:
            tuple: (client, download_client_type), or (None, None) if no
            client matches (case-insensitive comparison).
        """
        name_lower = name.lower()
        for download_client_type in self.download_client_types:
            download_clients = getattr(self, download_client_type, [])

            # Check each client in the list
            for client in download_clients:
                if client.name.lower() == name_lower:
                    return client, download_client_type

        return None, None
|
||||
347
src/settings/_download_clients_qBit.py
Normal file
347
src/settings/_download_clients_qBit.py
Normal file
@@ -0,0 +1,347 @@
|
||||
from packaging import version
|
||||
from src.utils.common import make_request, wait_and_exit
|
||||
from src.settings._constants import ApiEndpoints, MinVersions
|
||||
from src.utils.log_setup import logger
|
||||
|
||||
|
||||
class QbitError(Exception):
    """Raised for qBittorrent-specific failures (login, version checks, ...)."""
    pass
|
||||
|
||||
class QbitClients(list):
    """Represents all qBittorrent clients (a list of QbitClient objects)."""

    def __init__(self, config, settings):
        super().__init__()
        self._set_qbit_clients(config, settings)

    def _set_qbit_clients(self, config, settings):
        """Parse the 'download_clients.qbittorrent' config list into QbitClient entries."""
        qbit_config = config.get("download_clients", {}).get("qbittorrent", [])

        if not isinstance(qbit_config, list):
            logger.error(
                "Invalid config format for qbittorrent clients. Expected a list."
            )
            return

        for client_config in qbit_config:
            try:
                # TypeError surfaces unknown or non-keyword entries in the
                # YAML mapping (QbitClient only accepts its declared kwargs).
                self.append(QbitClient(settings, **client_config))
            except TypeError as e:
                logger.error(f"Error parsing qbittorrent client config: {e}")
|
||||
|
||||
|
||||
|
||||
class QbitClient:
    """Represents a single qBittorrent client."""

    # Session cookie ({"SID": ...}) set by refresh_cookie(); removed while None.
    cookie: dict = None
    # qBittorrent application version string, populated by fetch_version().
    version: str = None

    def __init__(
        self,
        settings,
        base_url: str = None,
        username: str = None,
        password: str = None,
        name: str = None
    ):
        """Initialize a qBittorrent client from config values.

        Args:
            settings: Global settings object (used for requests and flags).
            base_url: Required root URL of the qBittorrent WebUI.
            username / password: Optional WebUI credentials.
            name: Client name; must match the name configured in the *arr app.

        Raises:
            ValueError: if base_url is missing.
        """
        self.settings = settings
        if not base_url:
            logger.error("Skipping qBittorrent client entry: 'base_url' is required.")
            raise ValueError("qBittorrent client must have a 'base_url'.")

        self.base_url = base_url.rstrip("/")
        self.api_url = self.base_url + getattr(ApiEndpoints, "qbittorrent")
        self.min_version = getattr(MinVersions, "qbittorrent")
        self.username = username
        self.password = password
        self.name = name
        if not self.name:
            logger.verbose("No name provided for qbittorrent client, assuming 'qBitorrent'. If the name used in your *arr is different, please correct either the name in your *arr, or set the name in your config")
            self.name = "qBittorrent"

        self._remove_none_attributes()

    def _remove_none_attributes(self):
        """Removes attributes that are None to keep the object clean."""
        for attr in list(vars(self)):
            if getattr(self, attr) is None:
                delattr(self, attr)

    async def refresh_cookie(self):
        """Refresh the qBittorrent session cookie.

        Raises:
            QbitError: wrapping any login/connection failure (cookie is
                reset to {} before raising).
        """
        try:
            endpoint = f"{self.api_url}/auth/login"
            # Credentials may have been removed by _remove_none_attributes,
            # hence the getattr defaults.
            data = {"username": getattr(self, 'username', ''), "password": getattr(self, 'password', '')}
            headers = {"content-type": "application/x-www-form-urlencoded"}
            response = await make_request(
                "post", endpoint, self.settings, data=data, headers=headers
            )

            # qBittorrent answers a rejected login with the literal body "Fails."
            if response.text == "Fails.":
                raise ConnectionError("Login failed.")

            self.cookie = {"SID": response.cookies["SID"]}
            logger.debug("qBit cookie refreshed!")
        except Exception as e:
            logger.error(f"Error refreshing qBit cookie: {e}")
            self.cookie = {}
            raise QbitError(e) from e

    async def fetch_version(self):
        """Fetch the current qBittorrent version."""
        endpoint = f"{self.api_url}/app/version"
        response = await make_request("get", endpoint, self.settings, cookies=self.cookie)
        # Strips the first character of the response (presumably a leading
        # 'v', e.g. 'v4.6.0' -> '4.6.0') — TODO confirm against the API.
        self.version = response.text[1:]
        logger.debug(f"qBit version for client qBittorrent: {self.version}")

    async def validate_version(self):
        """Check if the qBittorrent version meets minimum and recommended requirements.

        Raises:
            QbitError: if the running version is below the configured minimum.
        """
        min_version = self.settings.min_versions.qbittorrent

        if version.parse(self.version) < version.parse(min_version):
            logger.error(
                f"Please update qBittorrent to at least version {min_version}. Current version: {self.version}"
            )
            raise QbitError(
                f"qBittorrent version {self.version} is too old. Please update."
            )
        if version.parse(self.version) < version.parse("5.0.0"):
            logger.info(
                f"[Tip!] Consider upgrading to qBittorrent v5.0.0 or newer to reduce network overhead."
            )

    async def create_tag(self):
        """Create the protection tag (and, if needed, the obsolete tag) in qBittorrent."""
        url = f"{self.api_url}/torrents/tags"
        response = await make_request("get", url, self.settings, cookies=self.cookie)

        current_tags = response.json()
        if self.settings.general.protected_tag not in current_tags:
            logger.verbose(f"Creating protection tag: {self.settings.general.protected_tag}")
            # test_run mode only logs; no write requests are sent.
            if not self.settings.general.test_run:
                data = {"tags": self.settings.general.protected_tag}
                await make_request(
                    "post",
                    self.api_url + "/torrents/createTags",
                    self.settings,
                    data=data,
                    cookies=self.cookie,
                )

        # The obsolete tag is only needed when either tracker-handling mode
        # tags items instead of removing them.
        if (
            self.settings.general.public_tracker_handling == "tag_as_obsolete"
            or self.settings.general.private_tracker_handling == "tag_as_obsolete"
        ):
            if self.settings.general.obsolete_tag not in current_tags:
                logger.verbose(f"Creating obsolete tag: {self.settings.general.obsolete_tag}")
                if not self.settings.general.test_run:
                    data = {"tags": self.settings.general.obsolete_tag}
                    await make_request(
                        "post",
                        self.api_url + "/torrents/createTags",
                        self.settings,
                        data=data,
                        cookies=self.cookie,
                    )

    async def set_unwanted_folder(self):
        """Set the 'unwanted folder' setting in qBittorrent if needed."""
        # Only relevant when the remove_bad_files job is active.
        if self.settings.jobs.remove_bad_files:
            endpoint = f"{self.api_url}/app/preferences"
            response = await make_request(
                "get", endpoint, self.settings, cookies=self.cookie
            )
            qbit_settings = response.json()

            if not qbit_settings.get("use_unwanted_folder"):
                logger.info(
                    "Enabling 'Keep unselected files in .unwanted folder' in qBittorrent."
                )
                if not self.settings.general.test_run:
                    # setPreferences expects the preferences as a JSON string
                    # under the "json" form field.
                    data = {"json": '{"use_unwanted_folder": true}'}
                    await make_request(
                        "post",
                        self.api_url + "/app/setPreferences",
                        self.settings,
                        data=data,
                        cookies=self.cookie,
                    )

    async def check_qbit_reachability(self):
        """Check if the qBittorrent URL is reachable; exit the app if not."""
        try:
            endpoint = f"{self.api_url}/auth/login"
            data = {"username": getattr(self, 'username', ''), "password": getattr(self, 'password', '')}
            headers = {"content-type": "application/x-www-form-urlencoded"}
            await make_request(
                "post", endpoint, self.settings, data=data, headers=headers, log_error=False
            )

        except Exception as e:
            tip = "💡 Tip: Did you specify the URL (and username/password if required) correctly?"
            logger.error(f"-- | qBittorrent\n❗️ {e}\n{tip}\n")
            wait_and_exit()

    async def check_qbit_connected(self):
        """Check if the qBittorrent is connected to internet.

        Returns:
            bool: False when the server reports "disconnected", True otherwise.
        """
        qbit_connection_status = ((
            await make_request(
                "get",
                self.api_url + "/sync/maindata",
                self.settings,
                cookies=self.cookie,
            )
        ).json())["server_state"]["connection_status"]
        if qbit_connection_status == "disconnected":
            return False
        else:
            return True

    async def setup(self):
        """Perform the qBittorrent setup by calling relevant managers."""
        # Check reachabilty
        await self.check_qbit_reachability()

        # Refresh the qBittorrent cookie first
        await self.refresh_cookie()

        try:
            # Fetch version and validate it
            await self.fetch_version()
            await self.validate_version()
            logger.info(f"OK | qBittorrent ({self.base_url})")
        except QbitError as e:
            logger.error(f"qBittorrent version check failed: {e}")
            wait_and_exit()  # Exit if version check fails

        # Continue with other setup tasks regardless of version check result
        await self.create_tag()
        await self.set_unwanted_folder()

    async def get_protected_and_private(self):
        """Fetches torrents from qBittorrent and checks for protected and private status.

        Returns:
            tuple[list, list]: (protected hashes, private hashes), both upper-cased.
        """
        protected_downloads = []
        private_downloads = []

        # Fetch all torrents
        qbit_items = await self.get_qbit_items()

        for qbit_item in qbit_items:
            # Fetch protected torrents (by tag)
            if self.settings.general.protected_tag in qbit_item.get("tags", []):
                protected_downloads.append(qbit_item["hash"].upper())

            # Fetch private torrents — only when neither handling mode is
            # "remove". NOTE(review): confirm this condition; it skips the
            # private check as soon as either mode is "remove".
            if not (self.settings.general.private_tracker_handling == "remove" or self.settings.general.public_tracker_handling == "remove"):
                if version.parse(self.version) >= version.parse("5.0.0"):
                    # qBit >= 5.0.0 exposes "private" directly on the torrent list.
                    if qbit_item.get("private"):
                        private_downloads.append(qbit_item["hash"].upper())
                else:
                    # Older qBit: per-torrent properties call is required.
                    qbit_item_props = await make_request(
                        "get",
                        self.api_url + "/torrents/properties",
                        self.settings,
                        params={"hash": qbit_item["hash"]},
                        cookies=self.cookie,
                    )
                    if not qbit_item_props:
                        logger.error(
                            "Torrent %s not found on qBittorrent - potentially removed while checking if private. "
                            "Consider upgrading qBit to v5.0.4 or newer to avoid this problem.",
                            qbit_item["hash"],
                        )
                        continue
                    # NOTE(review): qbit_item_props is used like a dict here,
                    # but other call sites .json() the make_request result —
                    # confirm make_request returns parsed JSON in this path.
                    if qbit_item_props.get("is_private", False):
                        private_downloads.append(qbit_item["hash"].upper())
                    # Cache the looked-up flag on the item for later consumers.
                    qbit_item["private"] = qbit_item_props.get("is_private", None)

        return protected_downloads, private_downloads

    async def set_tag(self, tags, hashes):
        """
        Sets tags to one or more torrents in qBittorrent.

        Args:
            tags (list): A list of tag names to be added.
            hashes (list): A list of torrent hashes to which the tags should be applied.
        """
        # Ensure hashes are provided as a string separated by '|'
        hashes_str = "|".join(hashes)

        # Ensure tags are provided as a string separated by ',' (comma)
        tags_str = ",".join(tags)

        # Prepare the data for the request
        data = {
            "hashes": hashes_str,
            "tags": tags_str
        }

        # Perform the request to add the tag(s) to the torrents
        await make_request(
            "post",
            self.api_url + "/torrents/addTags",
            self.settings,
            data=data,
            cookies=self.cookie,
        )

    async def get_download_progress(self, download_id):
        """Return the completed byte count for a single torrent (by hash)."""
        items = await self.get_qbit_items(download_id)
        return items[0]["completed"]

    async def get_qbit_items(self, hashes=None):
        """Fetch torrent info items, optionally filtered by one or more hashes.

        Args:
            hashes: A single hash string or a list of hashes; None fetches all.

        Returns:
            list: Parsed JSON list of torrent entries.
        """
        params = None
        if hashes:
            if isinstance(hashes, str):
                hashes = [hashes]
            params = {"hashes": "|".join(hashes).lower()}  # Join and make lowercase

        response = await make_request(
            method="get",
            endpoint=self.api_url + "/torrents/info",
            settings=self.settings,
            params=params,
            cookies=self.cookie,
        )
        return response.json()

    async def get_torrent_files(self, download_id):
        """Fetch the file list of a torrent (by hash)."""
        # this may not work if the wrong qbit
        response = await make_request(
            method="get",
            endpoint=self.api_url + "/torrents/files",
            settings=self.settings,
            params={"hash": download_id.lower()},
            cookies=self.cookie,
        )
        return response.json()

    async def set_torrent_file_priority(self, download_id, file_id, priority = 0):
        """Set the download priority of a file within a torrent (0 = do not download)."""
        data = {
            "hash": download_id.lower(),
            "id": file_id,
            "priority": priority,
        }
        await make_request(
            "post",
            self.api_url + "/torrents/filePrio",
            self.settings,
            data=data,
            cookies=self.cookie,
        )
|
||||
|
||||
74
src/settings/_general.py
Normal file
74
src/settings/_general.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import yaml
|
||||
from src.utils.log_setup import logger
|
||||
from src.settings._validate_data_types import validate_data_types
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
|
||||
class General:
    """Represents general settings for the application.

    Values come from the 'general' section of the config; the class
    attributes below provide the defaults. After loading, tracker-handling
    values are validated, data types are checked, and attributes that ended
    up as None are removed from the instance.
    """

    # Allowed values for private/public tracker handling.
    VALID_TRACKER_HANDLING = {"remove", "skip", "obsolete_tag"}

    log_level: str = "INFO"
    test_run: bool = False
    ssl_verification: bool = True
    timer: float = 10.0
    ignored_download_clients: list = []
    private_tracker_handling: str = "remove"
    public_tracker_handling: str = "remove"
    obsolete_tag: str = None
    protected_tag: str = "Keep"

    def __init__(self, config):
        general_config = config.get("general", {})
        # FIX: .upper() was previously applied only to the *default*, so a
        # user-supplied lowercase value (e.g. "info") passed through
        # un-normalized. Normalize whatever value we end up with.
        self.log_level = general_config.get("log_level", self.log_level).upper()
        self.test_run = general_config.get("test_run", self.test_run)
        self.timer = general_config.get("timer", self.timer)
        self.ssl_verification = general_config.get("ssl_verification", self.ssl_verification)
        # Copy so the instance never aliases the shared class-level default list.
        self.ignored_download_clients = list(
            general_config.get("ignored_download_clients", self.ignored_download_clients)
        )

        self.private_tracker_handling = general_config.get("private_tracker_handling", self.private_tracker_handling)
        self.public_tracker_handling = general_config.get("public_tracker_handling", self.public_tracker_handling)
        self.obsolete_tag = general_config.get("obsolete_tag", self.obsolete_tag)
        self.protected_tag = general_config.get("protected_tag", self.protected_tag)

        # Validate tracker handling settings
        self.private_tracker_handling = self._validate_tracker_handling(self.private_tracker_handling, "private_tracker_handling")
        self.public_tracker_handling = self._validate_tracker_handling(self.public_tracker_handling, "public_tracker_handling")
        self.obsolete_tag = self._determine_obsolete_tag(self.obsolete_tag)

        validate_data_types(self)
        self._remove_none_attributes()

    def _remove_none_attributes(self):
        """Removes attributes that are None to keep the object clean."""
        for attr in list(vars(self)):
            if getattr(self, attr) is None:
                delattr(self, attr)

    def _validate_tracker_handling(self, value, field_name):
        """Validates tracker handling options. Defaults to 'remove' if invalid.

        Args:
            value: The configured handling mode.
            field_name: Name of the setting (used in the error message only).

        Returns:
            str: *value* if valid, otherwise 'remove'.
        """
        if value not in self.VALID_TRACKER_HANDLING:
            logger.error(
                f"Invalid value '{value}' for {field_name}. Defaulting to 'remove'."
            )
            return "remove"
        return value

    def _determine_obsolete_tag(self, obsolete_tag):
        """Defaults obsolete tag to "Obsolete", only if none is provided and the tag is needed for handling."""
        if obsolete_tag is None and (
            self.private_tracker_handling == "obsolete_tag"
            or self.public_tracker_handling == "obsolete_tag"
        ):
            return "Obsolete"
        return obsolete_tag

    def config_as_yaml(self):
        """Returns all general settings rendered as a YAML string."""
        return get_config_as_yaml(
            vars(self),
        )
|
||||
296
src/settings/_instances.py
Normal file
296
src/settings/_instances.py
Normal file
@@ -0,0 +1,296 @@
|
||||
import requests
|
||||
from packaging import version
|
||||
|
||||
from src.utils.log_setup import logger
|
||||
from src.settings._constants import (
|
||||
ApiEndpoints,
|
||||
MinVersions,
|
||||
FullQueueParameter,
|
||||
DetailItemKey,
|
||||
DetailItemSearchCommand,
|
||||
)
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
from src.utils.common import make_request, wait_and_exit
|
||||
|
||||
|
||||
class Tracker:
    """Mutable bookkeeping of download states shared across processing runs."""

    def __init__(self):
        self.protected = []
        self.private = []
        self.defective = {}
        self.download_progress = {}
        self.deleted = []
        self.extension_checked = []

    async def refresh_private_and_protected(self, settings):
        """Re-query every qBittorrent client and rebuild the protected/private lists."""
        protected_acc, private_acc = [], []

        for client in settings.download_clients.qbittorrent:
            protected, private = await client.get_protected_and_private()
            protected_acc.extend(protected)
            private_acc.extend(private)

        self.protected = protected_acc
        self.private = private_acc
|
||||
|
||||
|
||||
class ArrError(Exception):
    """Raised for *arr-instance-specific failures (reachability, version, type checks)."""
    pass
|
||||
|
||||
|
||||
class Instances:
    """Represents all Arr instances."""

    def __init__(self, config, settings):
        self.arrs = ArrInstances(config, settings)
        # An empty instance list is fatal: the app has nothing to work on.
        if not self.arrs:
            logger.error("No valid Arr instances found in the config.")
            wait_and_exit()

    def get_by_arr_type(self, arr_type):
        """Return a list of arr instances matching the given arr_type."""
        return [arr for arr in self.arrs if arr.arr_type == arr_type]

    def config_as_yaml(self, hide_internal_attr=True):
        """Returns all configured Arr instances as YAML, masking sensitive attributes.

        Args:
            hide_internal_attr: When True, the attributes listed below are
                excluded from the output.

        Returns:
            str: One YAML document per configured arr type, joined by newlines.
        """
        internal_attributes = {
            "settings",
            "api_url",
            "min_version",
            "arr_type",
            "full_queue_parameter",
            "monitored_item",
            "detail_item_key",
            "detail_item_id_key",
            "detail_item_ids_key",
            "detail_item_search_command",
        }

        outputs = []
        for arr_type in ["sonarr", "radarr", "readarr", "lidarr", "whisparr"]:
            arrs = self.get_by_arr_type(arr_type)
            if arrs:
                output = get_config_as_yaml(
                    {arr_type.capitalize(): arrs},
                    sensitive_attributes={"api_key"},
                    internal_attributes=internal_attributes,
                    hide_internal_attr=hide_internal_attr,
                )
                outputs.append(output)

        return "\n".join(outputs)

    def check_any_arrs(self):
        """Check if there are any ARR instances; exits the app when none remain."""
        if not self.arrs:
            logger.warning("No ARR instances found.")
            wait_and_exit()
|
||||
|
||||
|
||||
class ArrInstances(list):
    """Represents all Arr clients (Sonarr, Radarr, etc.)."""

    def __init__(self, config, settings):
        super().__init__()
        self._load_clients(config, settings)

    def _load_clients(self, config, settings):
        """Build ArrInstance objects from the 'instances' section of the config.

        Invalid sections/entries are logged and skipped rather than raising,
        so one bad entry does not prevent the others from loading.
        """
        instances_config = config.get("instances", {})

        if not isinstance(instances_config, dict):
            logger.error("Invalid format for 'instances'. Expected a dictionary.")
            return

        for arr_type, clients in instances_config.items():
            if not isinstance(clients, list):
                logger.error(f"Invalid config format for {arr_type}. Expected a list.")
                continue

            for client_config in clients:
                try:
                    # KeyError below covers a missing 'base_url' or 'api_key'.
                    self.append(
                        ArrInstance(
                            settings,
                            arr_type=arr_type,
                            base_url=client_config["base_url"],
                            api_key=client_config["api_key"],
                        )
                    )
                except KeyError as e:
                    logger.error(
                        f"Missing required key {e} in {arr_type} client config."
                    )
|
||||
|
||||
|
||||
class ArrInstance:
|
||||
"""Represents an individual Arr instance (Sonarr, Radarr, etc.)."""
|
||||
|
||||
version: str = None
|
||||
name: str = None
|
||||
tracker = Tracker()
|
||||
|
||||
def __init__(self, settings, arr_type: str, base_url: str, api_key: str):
|
||||
if not base_url:
|
||||
logger.error(f"Skipping {arr_type} client entry: 'base_url' is required.")
|
||||
raise ValueError(f"{arr_type} client must have a 'base_url'.")
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"Skipping {arr_type} client entry: 'api_key' is required.")
|
||||
raise ValueError(f"{arr_type} client must have an 'api_key'.")
|
||||
|
||||
self.settings = settings
|
||||
self.arr_type = arr_type
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
self.api_url = self.base_url + getattr(ApiEndpoints, arr_type)
|
||||
self.min_version = getattr(MinVersions, arr_type)
|
||||
self.full_queue_parameter = getattr(FullQueueParameter, arr_type)
|
||||
self.detail_item_key = getattr(DetailItemKey, arr_type)
|
||||
self.detail_item_id_key = self.detail_item_key + "Id"
|
||||
self.detail_item_ids_key = self.detail_item_key + "Ids"
|
||||
self.detail_item_search_command = getattr(DetailItemSearchCommand, arr_type)
|
||||
|
||||
async def _check_ui_language(self):
|
||||
"""Check if the UI language is set to English."""
|
||||
endpoint = self.api_url + "/config/ui"
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
response = await make_request("get", endpoint, self.settings, headers=headers)
|
||||
ui_language = (response.json())["uiLanguage"]
|
||||
if ui_language > 1: # Not English
|
||||
logger.error("!! %s Error: !!", self.name)
|
||||
logger.error(
|
||||
f"> Decluttarr only works correctly if UI language is set to English (under Settings/UI in {self.name})"
|
||||
)
|
||||
logger.error(
|
||||
"> Details: https://github.com/ManiMatter/decluttarr/issues/132)"
|
||||
)
|
||||
raise ArrError("Not English")
|
||||
|
||||
def _check_min_version(self, status):
|
||||
"""Check if ARR instance meets minimum version requirements."""
|
||||
self.version = status["version"]
|
||||
min_version = getattr(self.settings.min_versions, self.arr_type)
|
||||
|
||||
if min_version:
|
||||
if version.parse(self.version) < version.parse(min_version):
|
||||
logger.error("!! %s Error: !!", self.name)
|
||||
logger.error(
|
||||
f"> Please update {self.name} ({self.base_url}) to at least version {min_version}. Current version: {self.version}"
|
||||
)
|
||||
raise ArrError("Not meeting minimum version requirements")
|
||||
|
||||
def _check_arr_type(self, status):
|
||||
"""Check if the ARR instance is of the correct type."""
|
||||
actual_arr_type = status["appName"]
|
||||
if actual_arr_type.lower() != self.arr_type:
|
||||
logger.error("!! %s Error: !!", self.name)
|
||||
logger.error(
|
||||
f"> Your {self.name} ({self.base_url}) points to a {actual_arr_type} instance, rather than {self.arr_type}. Did you specify the wrong IP?"
|
||||
)
|
||||
raise ArrError("Wrong Arr Type")
|
||||
|
||||
async def _check_reachability(self):
|
||||
"""Check if ARR instance is reachable."""
|
||||
try:
|
||||
endpoint = self.api_url + "/system/status"
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
response = await make_request(
|
||||
"get", endpoint, self.settings, headers=headers, log_error=False
|
||||
)
|
||||
status = response.json()
|
||||
return status
|
||||
except Exception as e:
|
||||
if isinstance(e, requests.exceptions.HTTPError):
|
||||
response = getattr(e, "response", None)
|
||||
if response is not None and response.status_code == 401:
|
||||
tip = "💡 Tip: Have you configured the API_KEY correctly?"
|
||||
else:
|
||||
tip = f"💡 Tip: HTTP error occurred. Status: {getattr(response, 'status_code', 'unknown')}"
|
||||
elif isinstance(e, requests.exceptions.RequestException):
|
||||
tip = "💡 Tip: Have you configured the URL correctly?"
|
||||
else:
|
||||
tip = ""
|
||||
|
||||
logger.error(f"-- | {self.arr_type} ({self.base_url})\n❗️ {e}\n{tip}\n")
|
||||
raise ArrError(e) from e
|
||||
|
||||
async def setup(self):
    """Run the startup checks for this ARR instance.

    Verifies reachability, application type, minimum version, and UI
    language, then logs the outcome. Any failure terminates the process
    via wait_and_exit(); unexpected (non-ArrError) exceptions are logged
    with a traceback first, since ArrError paths log their own details.
    """
    try:
        status = await self._check_reachability()
        # Fall back to the arr_type when the API does not report a name
        self.name = status.get("instanceName", self.arr_type)
        self._check_arr_type(status)
        self._check_min_version(status)  # also sets self.version
        await self._check_ui_language()

        # Display result
        logger.info(f"OK | {self.name} ({self.base_url})")
        logger.debug(f"Current version of {self.name}: {self.version}")

    except Exception as e:
        if not isinstance(e, ArrError):
            logger.error(f"Unhandled error: {e}", exc_info=True)
        wait_and_exit()
|
||||
|
||||
async def get_download_client_implementation(self, download_client_name):
    """Look up a download client by name and return its 'implementation' value.

    Queries the instance's /downloadclient endpoint and scans the returned
    list; returns None when no client matches the given name.
    """
    endpoint = self.api_url + "/downloadclient"
    headers = {"X-Api-Key": self.api_key}

    response = await make_request("get", endpoint, self.settings, headers=headers)
    download_clients = response.json()

    # First client whose name matches wins; None when nothing matches.
    return next(
        (
            client.get("implementation", None)
            for client in download_clients
            if client.get("name") == download_client_name
        ),
        None,
    )
|
||||
|
||||
async def remove_queue_item(self, queue_id, blocklist=False):
    """
    Remove a specific item from the queue by its queue ID.

    Sends a DELETE request to the API to remove the item, also removing
    the download from the download client.

    Args:
        queue_id (str): The queue ID of the queue item to be removed.
        blocklist (bool): Whether to add the item to the blocklist. Default is False.

    Returns:
        bool: True if the removal was successful (HTTP 200), False otherwise.
    """
    endpoint = f"{self.api_url}/queue/{queue_id}"
    headers = {"X-Api-Key": self.api_key}
    json_payload = {"removeFromClient": True, "blocklist": blocklist}

    # Send the request to remove the download from the queue
    response = await make_request(
        "delete", endpoint, self.settings, headers=headers, json=json_payload
    )

    # Only HTTP 200 counts as success
    return response.status_code == 200
|
||||
|
||||
async def is_monitored(self, detail_id):
    """Return whether the detail item (e.g. a book or series) is monitored.

    Fetches the item from the instance's detail endpoint and reads its
    'monitored' flag.
    """
    headers = {"X-Api-Key": self.api_key}
    endpoint = f"{self.api_url}/{self.detail_item_key}/{detail_id}"

    response = await make_request("get", endpoint, self.settings, headers=headers)
    payload = response.json()
    return payload["monitored"]
|
||||
|
||||
async def get_series(self):
    """Fetch all series from the instance's /series endpoint.

    Returns:
        list: The parsed JSON response (one dict per series).
    """
    endpoint = self.api_url + "/series"
    headers = {"X-Api-Key": self.api_key}
    response = await make_request("get", endpoint, self.settings, headers=headers)
    return response.json()
|
||||
161
src/settings/_jobs.py
Normal file
161
src/settings/_jobs.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from src.utils.log_setup import logger
|
||||
from src.settings._validate_data_types import validate_data_types
|
||||
from src.settings._config_as_yaml import get_config_as_yaml
|
||||
|
||||
|
||||
class JobParams:
    """Settings of a single job: an 'enabled' flag plus optional tuning parameters.

    Instance attributes passed as None are deleted after construction, so
    attribute lookups fall back to the class-level defaults (e.g. `enabled`
    falls back to False when not given).
    """

    # Class-level type hints / defaults; instance attributes left as None
    # are removed so these class attributes shine through.
    enabled: bool = False
    message_patterns: list
    max_strikes: int
    min_speed: int
    max_concurrent_searches: int
    min_days_between_searches: int

    def __init__(
        self,
        enabled=None,
        message_patterns=None,
        max_strikes=None,
        min_speed=None,
        max_concurrent_searches=None,
        min_days_between_searches=None,
    ):
        self.enabled = enabled
        self.message_patterns = message_patterns
        self.max_strikes = max_strikes
        self.min_speed = min_speed
        self.max_concurrent_searches = max_concurrent_searches
        self.min_days_between_searches = min_days_between_searches

        # Drop the attributes that were not provided, keeping the object clean
        self._remove_none_attributes()

    def _remove_none_attributes(self):
        """Delete every instance attribute currently holding None."""
        unset = [name for name in vars(self) if getattr(self, name) is None]
        for name in unset:
            delattr(self, name)
|
||||
|
||||
|
||||
class JobDefaults:
    """Default job settings, optionally overridden by the 'job_defaults' config section."""

    max_strikes: int = 3
    max_concurrent_searches: int = 3
    min_days_between_searches: int = 7
    min_speed: int = 100
    message_patterns = ["*"]

    # Only these keys may be overridden from config; min_speed and
    # message_patterns always keep their class defaults.
    _CONFIGURABLE = (
        "max_strikes",
        "max_concurrent_searches",
        "min_days_between_searches",
    )

    def __init__(self, config):
        defaults_cfg = config.get("job_defaults", {})
        for key in self._CONFIGURABLE:
            setattr(self, key, defaults_cfg.get(key, getattr(self, key)))
        validate_data_types(self)
|
||||
|
||||
|
||||
class Jobs:
    """Represents all jobs explicitly.

    Creates one JobParams attribute per supported job, pre-filled with
    defaults from JobDefaults, then overlays the user's per-job config.
    The temporary `job_defaults` attribute is deleted once construction
    is complete, so only JobParams attributes remain.
    """

    def __init__(self, config):
        self.job_defaults = JobDefaults(config)
        self._set_job_defaults()
        self._set_job_configs(config)
        # Only needed during construction; remove so vars(self) holds jobs only
        del self.job_defaults

    def _set_job_defaults(self):
        """Create every known job with its default (disabled) JobParams."""
        self.remove_bad_files = JobParams()
        self.remove_failed_downloads = JobParams()
        self.remove_failed_imports = JobParams(
            message_patterns=self.job_defaults.message_patterns
        )
        self.remove_metadata_missing = JobParams(
            max_strikes=self.job_defaults.max_strikes
        )
        self.remove_missing_files = JobParams()
        self.remove_orphans = JobParams()
        self.remove_slow = JobParams(
            max_strikes=self.job_defaults.max_strikes,
            min_speed=self.job_defaults.min_speed,
        )
        self.remove_stalled = JobParams(max_strikes=self.job_defaults.max_strikes)
        self.remove_unmonitored = JobParams()
        self.search_unmet_cutoff_content = JobParams(
            max_concurrent_searches=self.job_defaults.max_concurrent_searches,
            min_days_between_searches=self.job_defaults.min_days_between_searches,
        )
        self.search_missing_content = JobParams(
            max_concurrent_searches=self.job_defaults.max_concurrent_searches,
            min_days_between_searches=self.job_defaults.min_days_between_searches,
        )

    def _set_job_configs(self, config):
        """Overlay user-provided per-job settings onto the pre-created jobs."""
        # Populate jobs from YAML config; iterate the attributes created in
        # _set_job_defaults and only touch those present in the user's config.
        for job_name in self.__dict__:
            if job_name != "job_defaults" and job_name in config.get("jobs", {}):
                self._set_job_settings(job_name, config["jobs"][job_name])

    def _set_job_settings(self, job_name, job_config):
        """Sets per-job config settings.

        Accepts three config shapes: None (bare key in YAML -> enable),
        a bool (enable/disable), or a dict of parameter overrides
        (implicitly enabled unless 'enabled' is given).
        """

        job = getattr(self, job_name, None)
        if (
            job_config is None
        ):  # this triggers only when reading from yaml-file. for docker-compose, empty configs are not loaded, thus the entire job would not be parsed
            job.enabled = True
        elif isinstance(job_config, bool):
            if job:
                job.enabled = job_config
            else:
                job = JobParams(enabled=job_config)
        elif isinstance(job_config, dict):
            # A dict of settings implies the job is enabled unless stated otherwise
            job_config.setdefault("enabled", True)

            if job:
                for key, value in job_config.items():
                    setattr(job, key, value)
            else:
                job = JobParams(**job_config)

        else:
            # Unrecognized config shape: fall back to a disabled job
            job = JobParams(enabled=False)

        setattr(self, job_name, job)
        validate_data_types(
            job, self.job_defaults
        )  # Validates and applies defaults from job_defaults

    def log_status(self):
        """Log one line per job showing its enabled flag."""
        job_strings = []
        for job_name, job_obj in self.__dict__.items():
            if isinstance(job_obj, JobParams):
                job_strings.append(f"{job_name}: {job_obj.enabled}")
        status = "\n".join(job_strings)
        logger.info(status)

    def config_as_yaml(self):
        """Return a YAML string of enabled jobs, hiding the internal 'enabled' flag."""
        filtered = {
            k: v
            for k, v in vars(self).items()
            if not hasattr(v, "enabled") or v.enabled
        }
        return get_config_as_yaml(
            filtered,
            internal_attributes={"enabled"},
            hide_internal_attr=True,
        )

    def list_job_status(self):
        """Returns a string showing each job and whether it's enabled or not using emojis."""
        lines = []
        for name, obj in vars(self).items():
            if hasattr(obj, "enabled"):
                status = "🟢" if obj.enabled else "⚪️"
                lines.append(f"{status} {name}")
        return "\n".join(lines)
|
||||
138
src/settings/_user_config.py
Normal file
138
src/settings/_user_config.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import os
|
||||
import yaml
|
||||
from src.utils.log_setup import logger
|
||||
|
||||
# Maps each top-level section of the config dict to the environment variable
# names that feed it when running in docker mode (keys are lower-cased when
# the values are parsed). The same section names are expected in the YAML
# config file.
CONFIG_MAPPING = {
    "general": [
        "LOG_LEVEL",
        "TEST_RUN",
        "TIMER",
        "SSL_VERIFICATION",
        "IGNORED_DOWNLOAD_CLIENTS",
    ],
    "job_defaults": [
        "MAX_STRIKES",
        "MIN_DAYS_BETWEEN_SEARCHES",
        "MAX_CONCURRENT_SEARCHES",
    ],
    "jobs": [
        "REMOVE_BAD_FILES",
        "REMOVE_FAILED_DOWNLOADS",
        "REMOVE_FAILED_IMPORTS",
        "REMOVE_METADATA_MISSING",
        "REMOVE_MISSING_FILES",
        "REMOVE_ORPHANS",
        "REMOVE_SLOW",
        "REMOVE_STALLED",
        "REMOVE_UNMONITORED",
        "SEARCH_UNMET_CUTOFF_CONTENT",
        "SEARCH_MISSING_CONTENT",
    ],
    "instances": ["SONARR", "RADARR", "READARR", "LIDARR", "WHISPARR"],
    "download_clients": ["QBITTORRENT"],
}
|
||||
|
||||
|
||||
def get_user_config(settings):
    """Load the user configuration, preferring the YAML file over environment variables.

    When the config file exists it is read (and `settings.envs.use_config_yaml`
    is flagged); otherwise, when running in docker, configuration is read from
    environment variables. In all cases every top-level section from
    CONFIG_MAPPING is guaranteed to exist in the returned dict.

    Args:
        settings: The Settings object providing `paths.config_file` and `envs`.

    Returns:
        dict: The assembled configuration, sectioned per CONFIG_MAPPING.
    """
    config = {}
    if _config_file_exists(settings):
        config = _load_from_yaml_file(settings)
        settings.envs.use_config_yaml = True
    elif settings.envs.in_docker:
        config = _load_from_env()
    # Ensure all top-level keys exist, even if empty
    for section in CONFIG_MAPPING:
        if config.get(section) is None:
            config[section] = {}
    return config
|
||||
|
||||
|
||||
def _parse_env_var(key: str) -> dict | list | str | int | None:
    """Read one environment variable and parse its value as YAML.

    Returns None when the variable is unset; on a YAML parse error, logs
    the problem and returns an empty dict. Parsed dict keys are lower-cased
    recursively via _lowercase.
    """
    raw = os.getenv(key)
    if raw is None:
        return None

    try:
        return _lowercase(yaml.safe_load(raw))
    except yaml.YAMLError as e:
        logger.error(f"Failed to parse environment variable {key} as YAML:\n{e}")
        return {}
|
||||
|
||||
|
||||
def _load_section(keys: list[str]) -> dict:
    """Collect the parsed values of one config section, keyed by lower-cased name.

    Environment variables that are unset (parse to None) are omitted.
    """
    return {
        key.lower(): parsed
        for key in keys
        if (parsed := _parse_env_var(key)) is not None
    }
|
||||
|
||||
|
||||
def _load_from_env() -> dict:
    """Main function to load settings from env.

    Builds one config sub-dict per CONFIG_MAPPING section via _load_section.

    NOTE(review): this definition is shadowed by a second `_load_from_env`
    defined later in this module, so it is currently dead code (as are the
    `_parse_env_var`/`_load_section` helpers it uses) — confirm which
    implementation should be kept.
    """
    config = {}
    for section, keys in CONFIG_MAPPING.items():
        config[section] = _load_section(keys)
    return config
|
||||
|
||||
|
||||
def _load_from_env() -> dict:
    """Build the config dict from environment variables, one section per CONFIG_MAPPING entry.

    This redefinition previously duplicated the logic of `_parse_env_var` /
    `_load_section` inline (while shadowing the earlier `_load_from_env`);
    it now delegates to those helpers so the parsing rules live in one place.
    Variables that are unset or parse to None are skipped, matching the
    expectation documented in `Jobs._set_job_settings` that empty
    docker-compose configs are simply not loaded.

    Returns:
        dict: One sub-dict per section, with lower-cased keys.
    """
    return {
        section: _load_section(keys) for section, keys in CONFIG_MAPPING.items()
    }
|
||||
|
||||
|
||||
def _lowercase(data):
|
||||
"""Translates recevied keys (for instance setting-keys of jobs) to lower case"""
|
||||
if isinstance(data, dict):
|
||||
return {str(k).lower(): _lowercase(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [_lowercase(item) for item in data]
|
||||
else:
|
||||
# Leave strings and other types unchanged
|
||||
return data
|
||||
|
||||
|
||||
def _config_file_exists(settings):
|
||||
config_path = settings.paths.config_file
|
||||
return os.path.exists(config_path)
|
||||
|
||||
|
||||
def _load_from_yaml_file(settings):
    """Parse the user's YAML config file and return it as a dict.

    An empty file yields {}; YAML syntax errors are logged and also
    yield {} rather than raising.
    """
    config_path = settings.paths.config_file
    try:
        with open(config_path, "r", encoding="utf-8") as handle:
            return yaml.safe_load(handle) or {}
    except yaml.YAMLError as e:
        logger.error("Error reading YAML file: %s", e)
        return {}
|
||||
91
src/settings/_validate_data_types.py
Normal file
91
src/settings/_validate_data_types.py
Normal file
@@ -0,0 +1,91 @@
|
||||
|
||||
|
||||
import inspect
|
||||
from src.utils.log_setup import logger
|
||||
|
||||
def validate_data_types(cls, default_cls=None):
    """Ensures all attributes match expected types dynamically.

    Walks the type annotations declared on the object's class and coerces
    each present attribute to its annotated type; on conversion failure the
    attribute is reset to its default. If default_cls is provided, the
    default value is taken from this object rather than the own class.
    If the attribute doesn't exist in `default_cls`, fall back to `cls.__class__`.

    Args:
        cls: The instance whose attributes are validated in place.
        default_cls: Optional object supplying fallback default values.
    """
    annotations = inspect.get_annotations(cls.__class__)  # Extract type hints

    for attr, expected_type in annotations.items():
        if not hasattr(cls, attr):  # Skip if attribute is missing
            continue

        value = getattr(cls, attr)
        # Prefer defaults from default_cls when it defines the attribute
        default_source = default_cls if default_cls and hasattr(default_cls, attr) else cls.__class__
        default_value = getattr(default_source, attr, None)

        # Values equal to the default need no validation
        if value == default_value:
            continue

        if not isinstance(value, expected_type):
            try:
                # Coerce with the matching converter for the annotated type
                if expected_type is bool:
                    value = convert_to_bool(value)
                elif expected_type is int:
                    value = int(value)
                elif expected_type is float:
                    value = float(value)
                elif expected_type is str:
                    value = convert_to_str(value)
                elif expected_type is list:
                    value = convert_to_list(value)
                elif expected_type is dict:
                    value = convert_to_dict(value)
                else:
                    raise TypeError(f"Unhandled type conversion for '{attr}': {expected_type}")
            except Exception as e:
                # Conversion failed: log and fall back to the default value
                logger.error(
                    f"❗️ Invalid type for '{attr}': Expected {expected_type.__name__}, but got {type(value).__name__}. "
                    f"Error: {e}. Using default value: {default_value}"
                )
                value = default_value

        # Write back the (possibly converted or defaulted) value
        setattr(cls, attr, value)
|
||||
|
||||
|
||||
|
||||
# --- Helper Functions ---
|
||||
def convert_to_bool(raw_value):
    """Coerce common truthy/falsy spellings ('yes'/'no', 'true'/'false', 'on'/'off', '1'/'0') to bool.

    Booleans pass through unchanged; strings are trimmed and lower-cased
    before matching. Anything unrecognized raises ValueError.
    """
    if isinstance(raw_value, bool):
        return raw_value

    candidate = raw_value.strip().lower() if isinstance(raw_value, str) else raw_value

    if candidate in {"1", "yes", "true", "on"}:
        return True
    if candidate in {"0", "no", "false", "off"}:
        return False
    raise ValueError(f"Invalid boolean value: '{candidate}'")
|
||||
|
||||
|
||||
def convert_to_str(raw_value):
    """Return raw_value as a string with surrounding whitespace stripped.

    str() is a no-op for values that are already strings, so a single
    conversion path covers both the string and non-string cases (the
    original isinstance branch was redundant).
    """
    return str(raw_value).strip()
|
||||
|
||||
|
||||
def convert_to_list(raw_value):
    """Coerce a value into a list of trimmed strings; single values are wrapped."""
    items = raw_value if isinstance(raw_value, list) else [raw_value]
    return [convert_to_str(item) for item in items]
|
||||
|
||||
|
||||
def convert_to_dict(raw_value):
    """Ensure a value is a dictionary, normalizing its keys to trimmed strings.

    Raises TypeError for any non-dict input.
    """
    if not isinstance(raw_value, dict):
        raise TypeError(f"Expected dict but got {type(raw_value).__name__}")
    return {convert_to_str(key): value for key, value in raw_value.items()}
|
||||
60
src/settings/settings.py
Normal file
60
src/settings/settings.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from src.utils.log_setup import configure_logging
|
||||
from src.settings._constants import Envs, MinVersions, Paths
|
||||
# from src.settings._migrate_legacy import migrate_legacy
|
||||
from src.settings._general import General
|
||||
from src.settings._jobs import Jobs
|
||||
from src.settings._download_clients import DownloadClients
|
||||
from src.settings._instances import Instances
|
||||
from src.settings._user_config import get_user_config
|
||||
|
||||
class Settings:
    """Top-level application settings.

    Aggregates constants (min versions, paths), environment detection,
    the parsed user config, job definitions, download clients, and ARR
    instances. Constructing it also configures logging.
    """

    # Shared constants, identical across instances
    min_versions = MinVersions()
    paths = Paths()

    def __init__(self):
        self.envs = Envs()
        # Config must be loaded first; all subsequent sections consume it
        config = get_user_config(self)
        self.general = General(config)
        self.jobs = Jobs(config)
        self.download_clients = DownloadClients(config, self)
        self.instances = Instances(config, self)
        configure_logging(self)


    def __repr__(self):
        """Render all settings sections as a human-readable, titled report."""
        # (title, attribute) pairs in display order; 'jobs' appears twice:
        # once as an enabled/disabled overview, once with full settings
        sections = [
            ("ENVIRONMENT SETTINGS", "envs"),
            ("GENERAL SETTINGS", "general"),
            ("ACTIVE JOBS", "jobs"),
            ("JOB SETTINGS", "jobs"),
            ("INSTANCE SETTINGS", "instances"),
            ("DOWNLOAD CLIENT SETTINGS", "download_clients"),
        ]
        messages = []
        messages.append("🛠️ Decluttarr - Settings 🛠️")
        messages.append("-"*80)
        messages.append("")
        for title, attr_name in sections:
            section = getattr(self, attr_name, None)
            section_content = section.config_as_yaml()
            if title == "ACTIVE JOBS":
                # The overview uses the emoji status list instead of YAML
                messages.append(self._format_section_title(title))
                messages.append(self.jobs.list_job_status() + "\n")
            elif section_content != "{}\n":
                # Skip sections whose YAML rendering is empty
                messages.append(self._format_section_title(title))
                messages.append(section_content + "\n")
        return "\n".join(messages)


    def _format_section_title(self, name, border_length=50, symbol="="):
        """Format section title with centered name and symbol borders."""
        padding = max(border_length - len(name) - 2, 0)  # 2 for the spaces around the name
        left_hashes = right_hashes = padding // 2
        # Odd padding: give the extra symbol to the right border
        if padding % 2 != 0:
            right_hashes += 1
        return f"{symbol * left_hashes} {name} {symbol * right_hashes}\n"
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user