Merge pull request #148 from malmeloo/feat/better-serialization

Make more objects `Serializable`
This commit is contained in:
Mike Almeloo
2025-08-03 21:46:05 +02:00
committed by GitHub
12 changed files with 672 additions and 157 deletions

View File

@@ -6,25 +6,40 @@ Accessories could be anything ranging from AirTags to iPhones.
from __future__ import annotations
import json
import logging
import plistlib
from abc import ABC, abstractmethod
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import IO, TYPE_CHECKING, overload
from typing import IO, TYPE_CHECKING, Literal, TypedDict, overload
from typing_extensions import override
from findmy.util.abc import Serializable
from findmy.util.files import read_data_json, save_and_return_json
from .keys import KeyGenerator, KeyPair, KeyType
from .util import crypto
if TYPE_CHECKING:
from collections.abc import Generator, Mapping
from collections.abc import Generator
logger = logging.getLogger(__name__)
class FindMyAccessoryMapping(TypedDict):
"""JSON mapping representing state of a FindMyAccessory instance."""
type: Literal["accessory"]
master_key: str
skn: str
sks: str
paired_at: str
name: str | None
model: str | None
identifier: str | None
class RollingKeyPairSource(ABC):
"""A class that generates rolling `KeyPair`s."""
@@ -67,7 +82,7 @@ class RollingKeyPairSource(ABC):
return keys
class FindMyAccessory(RollingKeyPairSource):
class FindMyAccessory(RollingKeyPairSource, Serializable[FindMyAccessoryMapping]):
"""A findable Find My-accessory using official key rollover."""
def __init__( # noqa: PLR0913
@@ -242,9 +257,10 @@ class FindMyAccessory(RollingKeyPairSource):
identifier=identifier,
)
def to_json(self, path: str | Path | None = None) -> dict[str, str | int | None]:
"""Convert the accessory to a JSON-serializable dictionary."""
d = {
@override
def to_json(self, path: str | Path | None = None, /) -> FindMyAccessoryMapping:
res: FindMyAccessoryMapping = {
"type": "accessory",
"master_key": self._primary_gen.master_key.hex(),
"skn": self.skn.hex(),
"sks": self.sks.hex(),
@@ -253,23 +269,32 @@ class FindMyAccessory(RollingKeyPairSource):
"model": self.model,
"identifier": self.identifier,
}
if path is not None:
Path(path).write_text(json.dumps(d, indent=4))
return d
return save_and_return_json(res, path)
@classmethod
def from_json(cls, json_: str | Path | Mapping, /) -> FindMyAccessory:
"""Create a FindMyAccessory from a JSON file."""
data = json.loads(Path(json_).read_text()) if isinstance(json_, (str, Path)) else json_
return cls(
master_key=bytes.fromhex(data["master_key"]),
skn=bytes.fromhex(data["skn"]),
sks=bytes.fromhex(data["sks"]),
paired_at=datetime.fromisoformat(data["paired_at"]),
name=data["name"],
model=data["model"],
identifier=data["identifier"],
)
@override
def from_json(
cls,
val: str | Path | FindMyAccessoryMapping,
/,
) -> FindMyAccessory:
val = read_data_json(val)
assert val["type"] == "accessory"
try:
return cls(
master_key=bytes.fromhex(val["master_key"]),
skn=bytes.fromhex(val["skn"]),
sks=bytes.fromhex(val["sks"]),
paired_at=datetime.fromisoformat(val["paired_at"]),
name=val["name"],
model=val["model"],
identifier=val["identifier"],
)
except KeyError as e:
msg = f"Failed to restore account data: {e}"
raise ValueError(msg) from None
class AccessoryKeyGenerator(KeyGenerator[KeyPair]):

View File

@@ -7,15 +7,19 @@ import hashlib
import secrets
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Generic, TypeVar, overload
from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar, overload
from cryptography.hazmat.primitives.asymmetric import ec
from typing_extensions import override
from findmy.util.abc import Serializable
from findmy.util.files import read_data_json, save_and_return_json
from .util import crypto, parsers
if TYPE_CHECKING:
from collections.abc import Generator
from pathlib import Path
class KeyType(Enum):
@@ -26,6 +30,16 @@ class KeyType(Enum):
SECONDARY = 2
class KeyPairMapping(TypedDict):
"""JSON mapping representing a KeyPair."""
type: Literal["keypair"]
private_key: str
key_type: int
name: str | None
class HasHashedPublicKey(ABC):
"""
ABC for anything that has a public, hashed FindMy-key.
@@ -113,7 +127,7 @@ class HasPublicKey(HasHashedPublicKey, ABC):
)
class KeyPair(HasPublicKey):
class KeyPair(HasPublicKey, Serializable[KeyPairMapping]):
"""A private-public keypair for a trackable FindMy accessory."""
def __init__(
@@ -182,6 +196,34 @@ class KeyPair(HasPublicKey):
key_bytes = self._priv_key.public_key().public_numbers().x
return int.to_bytes(key_bytes, 28, "big")
@override
def to_json(self, dst: str | Path | None = None, /) -> KeyPairMapping:
return save_and_return_json(
{
"type": "keypair",
"private_key": base64.b64encode(self.private_key_bytes).decode("ascii"),
"key_type": self._key_type.value,
"name": self.name,
},
dst,
)
@classmethod
@override
def from_json(cls, val: str | Path | KeyPairMapping, /) -> KeyPair:
val = read_data_json(val)
assert val["type"] == "keypair"
try:
return cls(
private_key=base64.b64decode(val["private_key"]),
key_type=KeyType(val["key_type"]),
name=val["name"],
)
except KeyError as e:
msg = f"Failed to restore KeyPair data: {e}"
raise ValueError(msg) from None
def dh_exchange(self, other_pub_key: ec.EllipticCurvePublicKey) -> bytes:
"""Do a Diffie-Hellman key exchange using another EC public key."""
return self._priv_key.exchange(ec.ECDH(), other_pub_key)

View File

@@ -11,11 +11,11 @@ import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timedelta, timezone
from functools import wraps
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
TypedDict,
TypeVar,
cast,
@@ -32,8 +32,10 @@ from findmy.errors import (
UnauthorizedError,
UnhandledProtocolError,
)
from findmy.reports.anisette import AnisetteMapping, get_provider_from_mapping
from findmy.util import crypto
from findmy.util.abc import Closable
from findmy.util.abc import Closable, Serializable
from findmy.util.files import read_data_json, save_and_return_json
from findmy.util.http import HttpResponse, HttpSession, decode_plist
from .reports import LocationReport, LocationReportsFetcher
@@ -49,7 +51,8 @@ from .twofactor import (
)
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from collections.abc import Sequence
from pathlib import Path
from findmy.accessory import RollingKeyPairSource
from findmy.keys import HasHashedPublicKey
@@ -70,6 +73,33 @@ class _AccountInfo(TypedDict):
trusted_device_2fa: bool
class _AccountStateMappingIds(TypedDict):
uid: str
devid: str
class _AccountStateMappingAccount(TypedDict):
username: str | None
password: str | None
info: _AccountInfo | None
class _AccountStateMappingLoginState(TypedDict):
state: int
data: dict # TODO: make typed # noqa: TD002, TD003
class AccountStateMapping(TypedDict):
"""JSON mapping representing state of an Apple account instance."""
type: Literal["account"]
ids: _AccountStateMappingIds
account: _AccountStateMappingAccount
login: _AccountStateMappingLoginState
anisette: AnisetteMapping
_P = ParamSpec("_P")
_R = TypeVar("_R")
_A = TypeVar("_A", bound="BaseAppleAccount")
@@ -111,7 +141,7 @@ def _extract_phone_numbers(html: str) -> list[dict]:
return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", [])
class BaseAppleAccount(Closable, ABC):
class BaseAppleAccount(Closable, Serializable[AccountStateMapping], ABC):
"""Base class for an Apple account."""
@property
@@ -151,33 +181,6 @@ class BaseAppleAccount(Closable, ABC):
"""
raise NotImplementedError
@abstractmethod
def to_json(self, path: str | Path | None = None) -> dict:
"""
Export the current state of the account as a JSON-serializable dictionary.
If `path` is provided, the output will also be written to that file.
The output of this method is guaranteed to be JSON-serializable, and passing
the return value of this function as an argument to `BaseAppleAccount.from_json`
will always result in an exact copy of the internal state as it was when exported.
This method is especially useful to avoid having to keep going through the login flow.
"""
raise NotImplementedError
@abstractmethod
def from_json(self, json_: str | Path | Mapping, /) -> None:
"""
Restore the state from a previous `BaseAppleAccount.to_json` export.
If given a str or Path, it must point to a json file from `BaseAppleAccount.to_json`.
Otherwise it should be the Mapping itself.
See `BaseAppleAccount.to_json` for more information.
"""
raise NotImplementedError
@abstractmethod
def login(self, username: str, password: str) -> MaybeCoro[LoginState]:
"""Log in to an Apple account using a username and password."""
@@ -347,31 +350,33 @@ class AsyncAppleAccount(BaseAppleAccount):
def __init__(
self,
*,
anisette: BaseAnisetteProvider,
user_id: str | None = None,
device_id: str | None = None,
*,
state_info: AccountStateMapping | None = None,
) -> None:
"""
Initialize the apple account.
:param anisette: An instance of `AsyncAnisetteProvider`.
:param user_id: An optional user ID to use. Will be auto-generated if missing.
:param device_id: An optional device ID to use. Will be auto-generated if missing.
"""
super().__init__()
self._anisette: BaseAnisetteProvider = anisette
self._uid: str = user_id or str(uuid.uuid4())
self._devid: str = device_id or str(uuid.uuid4())
self._uid: str = state_info["ids"]["uid"] if state_info else str(uuid.uuid4())
self._devid: str = state_info["ids"]["devid"] if state_info else str(uuid.uuid4())
self._username: str | None = None
self._password: str | None = None
# TODO: combine, user/pass should be "all or nothing" # noqa: TD002, TD003
self._username: str | None = state_info["account"]["username"] if state_info else None
self._password: str | None = state_info["account"]["password"] if state_info else None
self._login_state: LoginState = LoginState.LOGGED_OUT
self._login_state_data: dict = {}
self._login_state: LoginState = (
LoginState(state_info["login"]["state"]) if state_info else LoginState.LOGGED_OUT
)
self._login_state_data: dict = state_info["login"]["data"] if state_info else {}
self._account_info: _AccountInfo | None = None
self._account_info: _AccountInfo | None = (
state_info["account"]["info"] if state_info else None
)
self._http: HttpSession = HttpSession()
self._reports: LocationReportsFetcher = LocationReportsFetcher(self)
@@ -433,36 +438,39 @@ class AsyncAppleAccount(BaseAppleAccount):
return self._account_info["last_name"] if self._account_info else None
@override
def to_json(self, path: str | Path | None = None) -> dict:
result = {
def to_json(self, path: str | Path | None = None, /) -> AccountStateMapping:
res: AccountStateMapping = {
"type": "account",
"ids": {"uid": self._uid, "devid": self._devid},
"account": {
"username": self._username,
"password": self._password,
"info": self._account_info,
},
"login_state": {
"login": {
"state": self._login_state.value,
"data": self._login_state_data,
},
"anisette": self._anisette.to_json(),
}
if path is not None:
Path(path).write_text(json.dumps(result, indent=4))
return result
return save_and_return_json(res, path)
@classmethod
@override
def from_json(self, json_: str | Path | Mapping, /) -> None:
data = json.loads(Path(json_).read_text()) if isinstance(json_, (str, Path)) else json_
def from_json(
cls,
val: str | Path | AccountStateMapping,
/,
*,
anisette_libs_path: str | Path | None = None,
) -> AsyncAppleAccount:
val = read_data_json(val)
assert val["type"] == "account"
try:
self._uid = data["ids"]["uid"]
self._devid = data["ids"]["devid"]
self._username = data["account"]["username"]
self._password = data["account"]["password"]
self._account_info = data["account"]["info"]
self._login_state = LoginState(data["login_state"]["state"])
self._login_state_data = data["login_state"]["data"]
ani_provider = get_provider_from_mapping(val["anisette"], libs_path=anisette_libs_path)
return cls(ani_provider, state_info=val)
except KeyError as e:
msg = f"Failed to restore account data: {e}"
raise ValueError(msg) from None
@@ -976,13 +984,12 @@ class AppleAccount(BaseAppleAccount):
def __init__(
self,
*,
anisette: BaseAnisetteProvider,
user_id: str | None = None,
device_id: str | None = None,
*,
state_info: AccountStateMapping | None = None,
) -> None:
"""See `AsyncAppleAccount.__init__`."""
self._asyncacc = AsyncAppleAccount(anisette=anisette, user_id=user_id, device_id=device_id)
self._asyncacc = AsyncAppleAccount(anisette=anisette, state_info=state_info)
try:
self._evt_loop = asyncio.get_running_loop()
@@ -1022,12 +1029,25 @@ class AppleAccount(BaseAppleAccount):
return self._asyncacc.last_name
@override
def to_json(self, path: str | Path | None = None) -> dict:
return self._asyncacc.to_json(path)
def to_json(self, dst: str | Path | None = None, /) -> AccountStateMapping:
return self._asyncacc.to_json(dst)
@classmethod
@override
def from_json(self, json_: str | Path | Mapping, /) -> None:
return self._asyncacc.from_json(json_)
def from_json(
cls,
val: str | Path | AccountStateMapping,
/,
*,
anisette_libs_path: str | Path | None = None,
) -> AppleAccount:
val = read_data_json(val)
try:
ani_provider = get_provider_from_mapping(val["anisette"], libs_path=anisette_libs_path)
return cls(ani_provider, state_info=val)
except KeyError as e:
msg = f"Failed to restore account data: {e}"
raise ValueError(msg) from None
@override
def login(self, username: str, password: str) -> LoginState:

View File

@@ -10,17 +10,49 @@ from abc import ABC, abstractmethod
from datetime import datetime, timezone
from io import BytesIO
from pathlib import Path
from typing import BinaryIO
from typing import BinaryIO, Literal, TypedDict, Union
from anisette import Anisette, AnisetteHeaders
from typing_extensions import override
from findmy.util.abc import Closable, Serializable
from findmy.util.files import read_data_json, save_and_return_json
from findmy.util.http import HttpSession
logger = logging.getLogger(__name__)
class RemoteAnisetteMapping(TypedDict):
"""JSON mapping representing state of a remote Anisette provider."""
type: Literal["aniRemote"]
url: str
class LocalAnisetteMapping(TypedDict):
"""JSON mapping representing state of a local Anisette provider."""
type: Literal["aniLocal"]
prov_data: str
AnisetteMapping = Union[RemoteAnisetteMapping, LocalAnisetteMapping]
def get_provider_from_mapping(
mapping: AnisetteMapping,
*,
libs_path: str | Path | None = None,
) -> RemoteAnisetteProvider | LocalAnisetteProvider:
"""Get the correct Anisette provider instance from saved JSON data."""
if mapping["type"] == "aniRemote":
return RemoteAnisetteProvider.from_json(mapping)
if mapping["type"] == "aniLocal":
return LocalAnisetteProvider.from_json(mapping, libs_path=libs_path)
msg = f"Unknown anisette type: {mapping['type']}"
raise ValueError(msg)
class BaseAnisetteProvider(Closable, Serializable, ABC):
"""
Abstract base class for Anisette providers.
@@ -156,7 +188,7 @@ class BaseAnisetteProvider(Closable, Serializable, ABC):
return cpd
class RemoteAnisetteProvider(BaseAnisetteProvider):
class RemoteAnisetteProvider(BaseAnisetteProvider, Serializable[RemoteAnisetteMapping]):
"""Anisette provider. Fetches headers from a remote Anisette server."""
_ANISETTE_DATA_VALID_FOR = 30
@@ -174,20 +206,25 @@ class RemoteAnisetteProvider(BaseAnisetteProvider):
self._closed = False
@override
def serialize(self) -> dict:
def to_json(self, dst: str | Path | None = None, /) -> RemoteAnisetteMapping:
"""See `BaseAnisetteProvider.serialize`."""
return {
"type": "aniRemote",
"url": self._server_url,
}
return save_and_return_json(
{
"type": "aniRemote",
"url": self._server_url,
},
dst,
)
@classmethod
@override
def deserialize(cls, data: dict) -> RemoteAnisetteProvider:
def from_json(cls, val: str | Path | RemoteAnisetteMapping) -> RemoteAnisetteProvider:
"""See `BaseAnisetteProvider.deserialize`."""
assert data["type"] == "aniRemote"
val = read_data_json(val)
server_url = data["url"]
assert val["type"] == "aniRemote"
server_url = val["url"]
return cls(server_url)
@@ -245,7 +282,7 @@ class RemoteAnisetteProvider(BaseAnisetteProvider):
logger.warning("Error closing anisette HTTP session: %s", e)
class LocalAnisetteProvider(BaseAnisetteProvider):
class LocalAnisetteProvider(BaseAnisetteProvider, Serializable[LocalAnisetteMapping]):
"""Anisette provider. Generates headers without a remote server using the `anisette` library."""
def __init__(
@@ -265,6 +302,7 @@ class LocalAnisetteProvider(BaseAnisetteProvider):
"The Anisette engine will download libraries required for operation, "
"this may take a few seconds...",
)
if libs_path is None:
logger.info(
"To speed up future local Anisette initializations, "
"provide a filesystem path to load the libraries from.",
@@ -289,24 +327,34 @@ class LocalAnisetteProvider(BaseAnisetteProvider):
)
@override
def serialize(self) -> dict:
def to_json(self, dst: str | Path | None = None, /) -> LocalAnisetteMapping:
"""See `BaseAnisetteProvider.serialize`."""
with BytesIO() as buf:
self._ani.save_provisioning(buf)
prov_data = base64.b64encode(buf.getvalue()).decode("utf-8")
return {
"type": "aniLocal",
"prov_data": prov_data,
}
return save_and_return_json(
{
"type": "aniLocal",
"prov_data": prov_data,
},
dst,
)
@classmethod
@override
def deserialize(cls, data: dict, libs_path: str | Path | None = None) -> LocalAnisetteProvider:
def from_json(
cls,
val: str | Path | LocalAnisetteMapping,
*,
libs_path: str | Path | None = None,
) -> LocalAnisetteProvider:
"""See `BaseAnisetteProvider.deserialize`."""
assert data["type"] == "aniLocal"
val = read_data_json(val)
state_blob = BytesIO(base64.b64decode(data["prov_data"]))
assert val["type"] == "aniLocal"
state_blob = BytesIO(base64.b64decode(val["prov_data"]))
return cls(state_blob=state_blob, libs_path=libs_path)

View File

@@ -8,7 +8,7 @@ import logging
import struct
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, cast, overload
from typing import TYPE_CHECKING, Literal, TypedDict, Union, cast, overload
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
@@ -16,17 +16,42 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from typing_extensions import override
from findmy.accessory import RollingKeyPairSource
from findmy.keys import HasHashedPublicKey, KeyPair
from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping
from findmy.util.abc import Serializable
from findmy.util.files import read_data_json, save_and_return_json
if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path
from .account import AsyncAppleAccount
logger = logging.getLogger(__name__)
class LocationReport(HasHashedPublicKey):
class LocationReportEncryptedMapping(TypedDict):
"""JSON mapping representing an encrypted location report."""
type: Literal["locReportEncrypted"]
payload: str
hashed_adv_key: str
class LocationReportDecryptedMapping(TypedDict):
"""JSON mapping representing a decrypted location report."""
type: Literal["locReportDecrypted"]
payload: str
hashed_adv_key: str
key: KeyPairMapping
LocationReportMapping = Union[LocationReportEncryptedMapping, LocationReportDecryptedMapping]
class LocationReport(HasHashedPublicKey, Serializable[LocationReportMapping]):
"""Location report corresponding to a certain `HasHashedPublicKey`."""
def __init__(
@@ -66,9 +91,13 @@ class LocationReport(HasHashedPublicKey):
"""Whether the report is currently decrypted."""
return self._decrypted_data is not None
def can_decrypt(self, key: KeyPair, /) -> bool:
"""Whether the report can be decrypted using the given key."""
return key.hashed_adv_key_bytes == self._hashed_adv_key
def decrypt(self, key: KeyPair) -> None:
"""Decrypt the report using its corresponding `KeyPair`."""
if key.hashed_adv_key_bytes != self._hashed_adv_key:
if not self.can_decrypt(key):
msg = "Cannot decrypt with this key!"
raise ValueError(msg)
@@ -163,6 +192,86 @@ class LocationReport(HasHashedPublicKey):
status_bytes = self._decrypted_data[1][9:10]
return int.from_bytes(status_bytes, "big")
@overload
def to_json(
self,
dst: str | Path | None = None,
/,
*,
include_key: Literal[True],
) -> LocationReportEncryptedMapping:
pass
@overload
def to_json(
self,
dst: str | Path | None = None,
/,
*,
include_key: Literal[False],
) -> LocationReportDecryptedMapping:
pass
@overload
def to_json(
self,
dst: str | Path | None = None,
/,
*,
include_key: None = None,
) -> LocationReportMapping:
pass
@override
def to_json(
self,
dst: str | Path | None = None,
/,
*,
include_key: bool | None = None,
) -> LocationReportMapping:
if include_key is None:
include_key = self.is_decrypted
if include_key:
return save_and_return_json(
{
"type": "locReportDecrypted",
"payload": base64.b64encode(self._payload).decode("utf-8"),
"hashed_adv_key": base64.b64encode(self._hashed_adv_key).decode("utf-8"),
"key": self.key.to_json(),
},
dst,
)
return save_and_return_json(
{
"type": "locReportEncrypted",
"payload": base64.b64encode(self._payload).decode("utf-8"),
"hashed_adv_key": base64.b64encode(self._hashed_adv_key).decode("utf-8"),
},
dst,
)
@classmethod
@override
def from_json(cls, val: str | Path | LocationReportMapping, /) -> LocationReport:
val = read_data_json(val)
assert val["type"] == "locReportEncrypted" or val["type"] == "locReportDecrypted"
try:
report = cls(
payload=base64.b64decode(val["payload"]),
hashed_adv_key=base64.b64decode(val["hashed_adv_key"]),
)
if val["type"] == "locReportDecrypted":
key = KeyPair.from_json(val["key"])
report.decrypt(key)
except KeyError as e:
msg = f"Failed to restore account data: {e}"
raise ValueError(msg) from None
else:
return report
@override
def __eq__(self, other: object) -> bool:
"""

View File

@@ -5,8 +5,13 @@ from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Generic, Self, TypeVar
logging.getLogger(__name__)
if TYPE_CHECKING:
from pathlib import Path
logger = logging.getLogger(__name__)
class Closable(ABC):
@@ -38,16 +43,37 @@ class Closable(ABC):
pass
class Serializable(ABC):
_T = TypeVar("_T", bound=Mapping)
class Serializable(Generic[_T], ABC):
"""ABC for serializable classes."""
@abstractmethod
def serialize(self) -> dict:
"""Serialize the object to a JSON-serializable dictionary."""
def to_json(self, dst: str | Path | None = None, /) -> _T:
"""
Export the current state of the object as a JSON-serializable dictionary.
If an argument is provided, the output will also be written to that file.
The output of this method is guaranteed to be JSON-serializable, and passing
the return value of this function as an argument to `Serializable.from_json`
will always result in an exact copy of the internal state as it was when exported.
You are encouraged to save and load object states to and from disk whenever possible,
to prevent unnecessary API calls or otherwise unexpected behavior.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def deserialize(cls, data: dict) -> Serializable:
"""Deserialize the object from a JSON-serializable dictionary."""
def from_json(cls, val: str | Path | _T, /) -> Self:
"""
Restore state from a previous `Closable.to_json` export.
If given a str or Path, it must point to a json file from `Serializable.to_json`.
Otherwise, it should be the Mapping itself.
See `Serializable.to_json` for more information.
"""
raise NotImplementedError

34
findmy/util/files.py Normal file
View File

@@ -0,0 +1,34 @@
"""Utilities to simplify reading and writing data from and to files."""
from __future__ import annotations
import json
from collections.abc import Mapping
from pathlib import Path
from typing import TypeVar, cast
T = TypeVar("T", bound=Mapping)
def save_and_return_json(data: T, dst: str | Path | None) -> T:
"""Save and return a JSON-serializable data structure."""
if dst is None:
return data
if isinstance(dst, str):
dst = Path(dst)
dst.write_text(json.dumps(data, indent=4))
return data
def read_data_json(val: str | Path | T) -> T:
"""Read JSON data from a file if a path is passed, or return the argument itself."""
if isinstance(val, str):
val = Path(val)
if isinstance(val, Path):
val = cast("T", json.loads(val.read_text()))
return val

144
findmy/util/session.py Normal file
View File

@@ -0,0 +1,144 @@
"""Logic related to serializable classes."""
from __future__ import annotations
import random
from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar, Union
from findmy.util.abc import Closable, Serializable
if TYPE_CHECKING:
from pathlib import Path
from types import TracebackType
_S = TypeVar("_S", bound=Serializable)
_SC = TypeVar("_SC", bound=Union[Serializable, Closable])
class _BaseSessionManager(Generic[_SC]):
"""Base class for session managers."""
def __init__(self) -> None:
self._sessions: dict[_SC, str | Path | None] = {}
def _add(self, obj: _SC, path: str | Path | None) -> None:
self._sessions[obj] = path
def remove(self, obj: _SC) -> None:
self._sessions.pop(obj, None)
def save(self) -> None:
for obj, path in self._sessions.items():
if isinstance(obj, Serializable):
obj.to_json(path)
async def close(self) -> None:
for obj in self._sessions:
if isinstance(obj, Closable):
await obj.close()
async def save_and_close(self) -> None:
for obj, path in self._sessions.items():
if isinstance(obj, Serializable):
obj.to_json(path)
if isinstance(obj, Closable):
await obj.close()
def get_random(self) -> _SC:
if not self._sessions:
msg = "No objects in the session manager."
raise ValueError(msg)
return random.choice(list(self._sessions.keys())) # noqa: S311
def __len__(self) -> int:
return len(self._sessions)
def __enter__(self) -> Self:
return self
def __exit__(
self,
_exc_type: type[BaseException] | None,
_exc_val: BaseException | None,
_exc_tb: TracebackType | None,
) -> None:
self.save()
class MixedSessionManager(_BaseSessionManager[Union[Serializable, Closable]]):
"""Allows any Serializable or Closable object."""
def new(
self,
c_type: type[_SC],
path: str | Path | None = None,
/,
*args: Any, # noqa: ANN401
**kwargs: Any, # noqa: ANN401
) -> _SC:
"""Add an object to the manager by instantiating it using its constructor."""
obj = c_type(*args, **kwargs)
if isinstance(obj, Serializable) and path is not None:
obj.to_json(path)
self._add(obj, path)
return obj
def add_from_json(
self,
c_type: type[_S],
path: str | Path,
/,
**kwargs: Any, # noqa: ANN401
) -> _S:
"""Add an object to the manager by deserializing it from its JSON representation."""
obj = c_type.from_json(path, **kwargs)
self._add(obj, path)
return obj
def add(self, obj: Serializable | Closable, path: str | Path | None = None, /) -> None:
"""Add an object to the session manager."""
self._add(obj, path)
class UniformSessionManager(Generic[_SC], _BaseSessionManager[_SC]):
"""Only allows a single type of Serializable object."""
def __init__(self, obj_type: type[_SC]) -> None:
"""Create a new session manager."""
super().__init__()
self._obj_type = obj_type
def new(
self,
path: str | Path | None = None,
/,
*args: Any, # noqa: ANN401
**kwargs: Any, # noqa: ANN401
) -> _SC:
"""Add an object to the manager by instantiating it using its constructor."""
obj = self._obj_type(*args, **kwargs)
if isinstance(obj, Serializable) and path is not None:
obj.to_json(path)
self._add(obj, path)
return obj
def add_from_json(
self,
path: str | Path,
/,
**kwargs: Any, # noqa: ANN401
) -> _SC:
"""Add an object to the manager by deserializing it from its JSON representation."""
if not issubclass(self._obj_type, Serializable):
msg = "Can only add objects of type Serializable."
raise TypeError(msg)
obj = self._obj_type.from_json(path, **kwargs)
self._add(obj, path)
return obj
def add(self, obj: _SC, path: str | Path | None = None, /) -> None:
"""Add an object to the session manager."""
if not isinstance(obj, self._obj_type):
msg = f"Object must be of type {self._obj_type.__name__}"
raise TypeError(msg)
self._add(obj, path)