mirror of
https://github.com/malmeloo/FindMy.py.git
synced 2026-04-17 21:53:57 +02:00
Merge pull request #148 from malmeloo/feat/better-serialization
Make more objects `Serializable`
This commit is contained in:
@@ -6,25 +6,40 @@ Accessories could be anything ranging from AirTags to iPhones.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import plistlib
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import IO, TYPE_CHECKING, overload
|
||||
from typing import IO, TYPE_CHECKING, Literal, TypedDict, overload
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from findmy.util.abc import Serializable
|
||||
from findmy.util.files import read_data_json, save_and_return_json
|
||||
|
||||
from .keys import KeyGenerator, KeyPair, KeyType
|
||||
from .util import crypto
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindMyAccessoryMapping(TypedDict):
|
||||
"""JSON mapping representing state of a FindMyAccessory instance."""
|
||||
|
||||
type: Literal["accessory"]
|
||||
master_key: str
|
||||
skn: str
|
||||
sks: str
|
||||
paired_at: str
|
||||
name: str | None
|
||||
model: str | None
|
||||
identifier: str | None
|
||||
|
||||
|
||||
class RollingKeyPairSource(ABC):
|
||||
"""A class that generates rolling `KeyPair`s."""
|
||||
|
||||
@@ -67,7 +82,7 @@ class RollingKeyPairSource(ABC):
|
||||
return keys
|
||||
|
||||
|
||||
class FindMyAccessory(RollingKeyPairSource):
|
||||
class FindMyAccessory(RollingKeyPairSource, Serializable[FindMyAccessoryMapping]):
|
||||
"""A findable Find My-accessory using official key rollover."""
|
||||
|
||||
def __init__( # noqa: PLR0913
|
||||
@@ -242,9 +257,10 @@ class FindMyAccessory(RollingKeyPairSource):
|
||||
identifier=identifier,
|
||||
)
|
||||
|
||||
def to_json(self, path: str | Path | None = None) -> dict[str, str | int | None]:
|
||||
"""Convert the accessory to a JSON-serializable dictionary."""
|
||||
d = {
|
||||
@override
|
||||
def to_json(self, path: str | Path | None = None, /) -> FindMyAccessoryMapping:
|
||||
res: FindMyAccessoryMapping = {
|
||||
"type": "accessory",
|
||||
"master_key": self._primary_gen.master_key.hex(),
|
||||
"skn": self.skn.hex(),
|
||||
"sks": self.sks.hex(),
|
||||
@@ -253,23 +269,32 @@ class FindMyAccessory(RollingKeyPairSource):
|
||||
"model": self.model,
|
||||
"identifier": self.identifier,
|
||||
}
|
||||
if path is not None:
|
||||
Path(path).write_text(json.dumps(d, indent=4))
|
||||
return d
|
||||
|
||||
return save_and_return_json(res, path)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_: str | Path | Mapping, /) -> FindMyAccessory:
|
||||
"""Create a FindMyAccessory from a JSON file."""
|
||||
data = json.loads(Path(json_).read_text()) if isinstance(json_, (str, Path)) else json_
|
||||
return cls(
|
||||
master_key=bytes.fromhex(data["master_key"]),
|
||||
skn=bytes.fromhex(data["skn"]),
|
||||
sks=bytes.fromhex(data["sks"]),
|
||||
paired_at=datetime.fromisoformat(data["paired_at"]),
|
||||
name=data["name"],
|
||||
model=data["model"],
|
||||
identifier=data["identifier"],
|
||||
)
|
||||
@override
|
||||
def from_json(
|
||||
cls,
|
||||
val: str | Path | FindMyAccessoryMapping,
|
||||
/,
|
||||
) -> FindMyAccessory:
|
||||
val = read_data_json(val)
|
||||
assert val["type"] == "accessory"
|
||||
|
||||
try:
|
||||
return cls(
|
||||
master_key=bytes.fromhex(val["master_key"]),
|
||||
skn=bytes.fromhex(val["skn"]),
|
||||
sks=bytes.fromhex(val["sks"]),
|
||||
paired_at=datetime.fromisoformat(val["paired_at"]),
|
||||
name=val["name"],
|
||||
model=val["model"],
|
||||
identifier=val["identifier"],
|
||||
)
|
||||
except KeyError as e:
|
||||
msg = f"Failed to restore account data: {e}"
|
||||
raise ValueError(msg) from None
|
||||
|
||||
|
||||
class AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
||||
|
||||
@@ -7,15 +7,19 @@ import hashlib
|
||||
import secrets
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar, overload
|
||||
from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar, overload
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from typing_extensions import override
|
||||
|
||||
from findmy.util.abc import Serializable
|
||||
from findmy.util.files import read_data_json, save_and_return_json
|
||||
|
||||
from .util import crypto, parsers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class KeyType(Enum):
|
||||
@@ -26,6 +30,16 @@ class KeyType(Enum):
|
||||
SECONDARY = 2
|
||||
|
||||
|
||||
class KeyPairMapping(TypedDict):
|
||||
"""JSON mapping representing a KeyPair."""
|
||||
|
||||
type: Literal["keypair"]
|
||||
|
||||
private_key: str
|
||||
key_type: int
|
||||
name: str | None
|
||||
|
||||
|
||||
class HasHashedPublicKey(ABC):
|
||||
"""
|
||||
ABC for anything that has a public, hashed FindMy-key.
|
||||
@@ -113,7 +127,7 @@ class HasPublicKey(HasHashedPublicKey, ABC):
|
||||
)
|
||||
|
||||
|
||||
class KeyPair(HasPublicKey):
|
||||
class KeyPair(HasPublicKey, Serializable[KeyPairMapping]):
|
||||
"""A private-public keypair for a trackable FindMy accessory."""
|
||||
|
||||
def __init__(
|
||||
@@ -182,6 +196,34 @@ class KeyPair(HasPublicKey):
|
||||
key_bytes = self._priv_key.public_key().public_numbers().x
|
||||
return int.to_bytes(key_bytes, 28, "big")
|
||||
|
||||
@override
|
||||
def to_json(self, dst: str | Path | None = None, /) -> KeyPairMapping:
|
||||
return save_and_return_json(
|
||||
{
|
||||
"type": "keypair",
|
||||
"private_key": base64.b64encode(self.private_key_bytes).decode("ascii"),
|
||||
"key_type": self._key_type.value,
|
||||
"name": self.name,
|
||||
},
|
||||
dst,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_json(cls, val: str | Path | KeyPairMapping, /) -> KeyPair:
|
||||
val = read_data_json(val)
|
||||
assert val["type"] == "keypair"
|
||||
|
||||
try:
|
||||
return cls(
|
||||
private_key=base64.b64decode(val["private_key"]),
|
||||
key_type=KeyType(val["key_type"]),
|
||||
name=val["name"],
|
||||
)
|
||||
except KeyError as e:
|
||||
msg = f"Failed to restore KeyPair data: {e}"
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def dh_exchange(self, other_pub_key: ec.EllipticCurvePublicKey) -> bytes:
|
||||
"""Do a Diffie-Hellman key exchange using another EC public key."""
|
||||
return self._priv_key.exchange(ec.ECDH(), other_pub_key)
|
||||
|
||||
@@ -11,11 +11,11 @@ import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
cast,
|
||||
@@ -32,8 +32,10 @@ from findmy.errors import (
|
||||
UnauthorizedError,
|
||||
UnhandledProtocolError,
|
||||
)
|
||||
from findmy.reports.anisette import AnisetteMapping, get_provider_from_mapping
|
||||
from findmy.util import crypto
|
||||
from findmy.util.abc import Closable
|
||||
from findmy.util.abc import Closable, Serializable
|
||||
from findmy.util.files import read_data_json, save_and_return_json
|
||||
from findmy.util.http import HttpResponse, HttpSession, decode_plist
|
||||
|
||||
from .reports import LocationReport, LocationReportsFetcher
|
||||
@@ -49,7 +51,8 @@ from .twofactor import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from findmy.accessory import RollingKeyPairSource
|
||||
from findmy.keys import HasHashedPublicKey
|
||||
@@ -70,6 +73,33 @@ class _AccountInfo(TypedDict):
|
||||
trusted_device_2fa: bool
|
||||
|
||||
|
||||
class _AccountStateMappingIds(TypedDict):
|
||||
uid: str
|
||||
devid: str
|
||||
|
||||
|
||||
class _AccountStateMappingAccount(TypedDict):
|
||||
username: str | None
|
||||
password: str | None
|
||||
info: _AccountInfo | None
|
||||
|
||||
|
||||
class _AccountStateMappingLoginState(TypedDict):
|
||||
state: int
|
||||
data: dict # TODO: make typed # noqa: TD002, TD003
|
||||
|
||||
|
||||
class AccountStateMapping(TypedDict):
|
||||
"""JSON mapping representing state of an Apple account instance."""
|
||||
|
||||
type: Literal["account"]
|
||||
|
||||
ids: _AccountStateMappingIds
|
||||
account: _AccountStateMappingAccount
|
||||
login: _AccountStateMappingLoginState
|
||||
anisette: AnisetteMapping
|
||||
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
_A = TypeVar("_A", bound="BaseAppleAccount")
|
||||
@@ -111,7 +141,7 @@ def _extract_phone_numbers(html: str) -> list[dict]:
|
||||
return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", [])
|
||||
|
||||
|
||||
class BaseAppleAccount(Closable, ABC):
|
||||
class BaseAppleAccount(Closable, Serializable[AccountStateMapping], ABC):
|
||||
"""Base class for an Apple account."""
|
||||
|
||||
@property
|
||||
@@ -151,33 +181,6 @@ class BaseAppleAccount(Closable, ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_json(self, path: str | Path | None = None) -> dict:
|
||||
"""
|
||||
Export the current state of the account as a JSON-serializable dictionary.
|
||||
|
||||
If `path` is provided, the output will also be written to that file.
|
||||
|
||||
The output of this method is guaranteed to be JSON-serializable, and passing
|
||||
the return value of this function as an argument to `BaseAppleAccount.from_json`
|
||||
will always result in an exact copy of the internal state as it was when exported.
|
||||
|
||||
This method is especially useful to avoid having to keep going through the login flow.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def from_json(self, json_: str | Path | Mapping, /) -> None:
|
||||
"""
|
||||
Restore the state from a previous `BaseAppleAccount.to_json` export.
|
||||
|
||||
If given a str or Path, it must point to a json file from `BaseAppleAccount.to_json`.
|
||||
Otherwise it should be the Mapping itself.
|
||||
|
||||
See `BaseAppleAccount.to_json` for more information.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def login(self, username: str, password: str) -> MaybeCoro[LoginState]:
|
||||
"""Log in to an Apple account using a username and password."""
|
||||
@@ -347,31 +350,33 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
anisette: BaseAnisetteProvider,
|
||||
user_id: str | None = None,
|
||||
device_id: str | None = None,
|
||||
*,
|
||||
state_info: AccountStateMapping | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the apple account.
|
||||
|
||||
:param anisette: An instance of `AsyncAnisetteProvider`.
|
||||
:param user_id: An optional user ID to use. Will be auto-generated if missing.
|
||||
:param device_id: An optional device ID to use. Will be auto-generated if missing.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._anisette: BaseAnisetteProvider = anisette
|
||||
self._uid: str = user_id or str(uuid.uuid4())
|
||||
self._devid: str = device_id or str(uuid.uuid4())
|
||||
self._uid: str = state_info["ids"]["uid"] if state_info else str(uuid.uuid4())
|
||||
self._devid: str = state_info["ids"]["devid"] if state_info else str(uuid.uuid4())
|
||||
|
||||
self._username: str | None = None
|
||||
self._password: str | None = None
|
||||
# TODO: combine, user/pass should be "all or nothing" # noqa: TD002, TD003
|
||||
self._username: str | None = state_info["account"]["username"] if state_info else None
|
||||
self._password: str | None = state_info["account"]["password"] if state_info else None
|
||||
|
||||
self._login_state: LoginState = LoginState.LOGGED_OUT
|
||||
self._login_state_data: dict = {}
|
||||
self._login_state: LoginState = (
|
||||
LoginState(state_info["login"]["state"]) if state_info else LoginState.LOGGED_OUT
|
||||
)
|
||||
self._login_state_data: dict = state_info["login"]["data"] if state_info else {}
|
||||
|
||||
self._account_info: _AccountInfo | None = None
|
||||
self._account_info: _AccountInfo | None = (
|
||||
state_info["account"]["info"] if state_info else None
|
||||
)
|
||||
|
||||
self._http: HttpSession = HttpSession()
|
||||
self._reports: LocationReportsFetcher = LocationReportsFetcher(self)
|
||||
@@ -433,36 +438,39 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
return self._account_info["last_name"] if self._account_info else None
|
||||
|
||||
@override
|
||||
def to_json(self, path: str | Path | None = None) -> dict:
|
||||
result = {
|
||||
def to_json(self, path: str | Path | None = None, /) -> AccountStateMapping:
|
||||
res: AccountStateMapping = {
|
||||
"type": "account",
|
||||
"ids": {"uid": self._uid, "devid": self._devid},
|
||||
"account": {
|
||||
"username": self._username,
|
||||
"password": self._password,
|
||||
"info": self._account_info,
|
||||
},
|
||||
"login_state": {
|
||||
"login": {
|
||||
"state": self._login_state.value,
|
||||
"data": self._login_state_data,
|
||||
},
|
||||
"anisette": self._anisette.to_json(),
|
||||
}
|
||||
if path is not None:
|
||||
Path(path).write_text(json.dumps(result, indent=4))
|
||||
return result
|
||||
|
||||
return save_and_return_json(res, path)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_json(self, json_: str | Path | Mapping, /) -> None:
|
||||
data = json.loads(Path(json_).read_text()) if isinstance(json_, (str, Path)) else json_
|
||||
def from_json(
|
||||
cls,
|
||||
val: str | Path | AccountStateMapping,
|
||||
/,
|
||||
*,
|
||||
anisette_libs_path: str | Path | None = None,
|
||||
) -> AsyncAppleAccount:
|
||||
val = read_data_json(val)
|
||||
assert val["type"] == "account"
|
||||
|
||||
try:
|
||||
self._uid = data["ids"]["uid"]
|
||||
self._devid = data["ids"]["devid"]
|
||||
|
||||
self._username = data["account"]["username"]
|
||||
self._password = data["account"]["password"]
|
||||
self._account_info = data["account"]["info"]
|
||||
|
||||
self._login_state = LoginState(data["login_state"]["state"])
|
||||
self._login_state_data = data["login_state"]["data"]
|
||||
ani_provider = get_provider_from_mapping(val["anisette"], libs_path=anisette_libs_path)
|
||||
return cls(ani_provider, state_info=val)
|
||||
except KeyError as e:
|
||||
msg = f"Failed to restore account data: {e}"
|
||||
raise ValueError(msg) from None
|
||||
@@ -976,13 +984,12 @@ class AppleAccount(BaseAppleAccount):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
anisette: BaseAnisetteProvider,
|
||||
user_id: str | None = None,
|
||||
device_id: str | None = None,
|
||||
*,
|
||||
state_info: AccountStateMapping | None = None,
|
||||
) -> None:
|
||||
"""See `AsyncAppleAccount.__init__`."""
|
||||
self._asyncacc = AsyncAppleAccount(anisette=anisette, user_id=user_id, device_id=device_id)
|
||||
self._asyncacc = AsyncAppleAccount(anisette=anisette, state_info=state_info)
|
||||
|
||||
try:
|
||||
self._evt_loop = asyncio.get_running_loop()
|
||||
@@ -1022,12 +1029,25 @@ class AppleAccount(BaseAppleAccount):
|
||||
return self._asyncacc.last_name
|
||||
|
||||
@override
|
||||
def to_json(self, path: str | Path | None = None) -> dict:
|
||||
return self._asyncacc.to_json(path)
|
||||
def to_json(self, dst: str | Path | None = None, /) -> AccountStateMapping:
|
||||
return self._asyncacc.to_json(dst)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_json(self, json_: str | Path | Mapping, /) -> None:
|
||||
return self._asyncacc.from_json(json_)
|
||||
def from_json(
|
||||
cls,
|
||||
val: str | Path | AccountStateMapping,
|
||||
/,
|
||||
*,
|
||||
anisette_libs_path: str | Path | None = None,
|
||||
) -> AppleAccount:
|
||||
val = read_data_json(val)
|
||||
try:
|
||||
ani_provider = get_provider_from_mapping(val["anisette"], libs_path=anisette_libs_path)
|
||||
return cls(ani_provider, state_info=val)
|
||||
except KeyError as e:
|
||||
msg = f"Failed to restore account data: {e}"
|
||||
raise ValueError(msg) from None
|
||||
|
||||
@override
|
||||
def login(self, username: str, password: str) -> LoginState:
|
||||
|
||||
@@ -10,17 +10,49 @@ from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO
|
||||
from typing import BinaryIO, Literal, TypedDict, Union
|
||||
|
||||
from anisette import Anisette, AnisetteHeaders
|
||||
from typing_extensions import override
|
||||
|
||||
from findmy.util.abc import Closable, Serializable
|
||||
from findmy.util.files import read_data_json, save_and_return_json
|
||||
from findmy.util.http import HttpSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteAnisetteMapping(TypedDict):
|
||||
"""JSON mapping representing state of a remote Anisette provider."""
|
||||
|
||||
type: Literal["aniRemote"]
|
||||
url: str
|
||||
|
||||
|
||||
class LocalAnisetteMapping(TypedDict):
|
||||
"""JSON mapping representing state of a local Anisette provider."""
|
||||
|
||||
type: Literal["aniLocal"]
|
||||
prov_data: str
|
||||
|
||||
|
||||
AnisetteMapping = Union[RemoteAnisetteMapping, LocalAnisetteMapping]
|
||||
|
||||
|
||||
def get_provider_from_mapping(
|
||||
mapping: AnisetteMapping,
|
||||
*,
|
||||
libs_path: str | Path | None = None,
|
||||
) -> RemoteAnisetteProvider | LocalAnisetteProvider:
|
||||
"""Get the correct Anisette provider instance from saved JSON data."""
|
||||
if mapping["type"] == "aniRemote":
|
||||
return RemoteAnisetteProvider.from_json(mapping)
|
||||
if mapping["type"] == "aniLocal":
|
||||
return LocalAnisetteProvider.from_json(mapping, libs_path=libs_path)
|
||||
msg = f"Unknown anisette type: {mapping['type']}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class BaseAnisetteProvider(Closable, Serializable, ABC):
|
||||
"""
|
||||
Abstract base class for Anisette providers.
|
||||
@@ -156,7 +188,7 @@ class BaseAnisetteProvider(Closable, Serializable, ABC):
|
||||
return cpd
|
||||
|
||||
|
||||
class RemoteAnisetteProvider(BaseAnisetteProvider):
|
||||
class RemoteAnisetteProvider(BaseAnisetteProvider, Serializable[RemoteAnisetteMapping]):
|
||||
"""Anisette provider. Fetches headers from a remote Anisette server."""
|
||||
|
||||
_ANISETTE_DATA_VALID_FOR = 30
|
||||
@@ -174,20 +206,25 @@ class RemoteAnisetteProvider(BaseAnisetteProvider):
|
||||
self._closed = False
|
||||
|
||||
@override
|
||||
def serialize(self) -> dict:
|
||||
def to_json(self, dst: str | Path | None = None, /) -> RemoteAnisetteMapping:
|
||||
"""See `BaseAnisetteProvider.serialize`."""
|
||||
return {
|
||||
"type": "aniRemote",
|
||||
"url": self._server_url,
|
||||
}
|
||||
return save_and_return_json(
|
||||
{
|
||||
"type": "aniRemote",
|
||||
"url": self._server_url,
|
||||
},
|
||||
dst,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def deserialize(cls, data: dict) -> RemoteAnisetteProvider:
|
||||
def from_json(cls, val: str | Path | RemoteAnisetteMapping) -> RemoteAnisetteProvider:
|
||||
"""See `BaseAnisetteProvider.deserialize`."""
|
||||
assert data["type"] == "aniRemote"
|
||||
val = read_data_json(val)
|
||||
|
||||
server_url = data["url"]
|
||||
assert val["type"] == "aniRemote"
|
||||
|
||||
server_url = val["url"]
|
||||
|
||||
return cls(server_url)
|
||||
|
||||
@@ -245,7 +282,7 @@ class RemoteAnisetteProvider(BaseAnisetteProvider):
|
||||
logger.warning("Error closing anisette HTTP session: %s", e)
|
||||
|
||||
|
||||
class LocalAnisetteProvider(BaseAnisetteProvider):
|
||||
class LocalAnisetteProvider(BaseAnisetteProvider, Serializable[LocalAnisetteMapping]):
|
||||
"""Anisette provider. Generates headers without a remote server using the `anisette` library."""
|
||||
|
||||
def __init__(
|
||||
@@ -265,6 +302,7 @@ class LocalAnisetteProvider(BaseAnisetteProvider):
|
||||
"The Anisette engine will download libraries required for operation, "
|
||||
"this may take a few seconds...",
|
||||
)
|
||||
if libs_path is None:
|
||||
logger.info(
|
||||
"To speed up future local Anisette initializations, "
|
||||
"provide a filesystem path to load the libraries from.",
|
||||
@@ -289,24 +327,34 @@ class LocalAnisetteProvider(BaseAnisetteProvider):
|
||||
)
|
||||
|
||||
@override
|
||||
def serialize(self) -> dict:
|
||||
def to_json(self, dst: str | Path | None = None, /) -> LocalAnisetteMapping:
|
||||
"""See `BaseAnisetteProvider.serialize`."""
|
||||
with BytesIO() as buf:
|
||||
self._ani.save_provisioning(buf)
|
||||
prov_data = base64.b64encode(buf.getvalue()).decode("utf-8")
|
||||
|
||||
return {
|
||||
"type": "aniLocal",
|
||||
"prov_data": prov_data,
|
||||
}
|
||||
return save_and_return_json(
|
||||
{
|
||||
"type": "aniLocal",
|
||||
"prov_data": prov_data,
|
||||
},
|
||||
dst,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def deserialize(cls, data: dict, libs_path: str | Path | None = None) -> LocalAnisetteProvider:
|
||||
def from_json(
|
||||
cls,
|
||||
val: str | Path | LocalAnisetteMapping,
|
||||
*,
|
||||
libs_path: str | Path | None = None,
|
||||
) -> LocalAnisetteProvider:
|
||||
"""See `BaseAnisetteProvider.deserialize`."""
|
||||
assert data["type"] == "aniLocal"
|
||||
val = read_data_json(val)
|
||||
|
||||
state_blob = BytesIO(base64.b64decode(data["prov_data"]))
|
||||
assert val["type"] == "aniLocal"
|
||||
|
||||
state_blob = BytesIO(base64.b64decode(val["prov_data"]))
|
||||
|
||||
return cls(state_blob=state_blob, libs_path=libs_path)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import logging
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, cast, overload
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Union, cast, overload
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
@@ -16,17 +16,42 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from typing_extensions import override
|
||||
|
||||
from findmy.accessory import RollingKeyPairSource
|
||||
from findmy.keys import HasHashedPublicKey, KeyPair
|
||||
from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping
|
||||
from findmy.util.abc import Serializable
|
||||
from findmy.util.files import read_data_json, save_and_return_json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from .account import AsyncAppleAccount
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocationReport(HasHashedPublicKey):
|
||||
class LocationReportEncryptedMapping(TypedDict):
|
||||
"""JSON mapping representing an encrypted location report."""
|
||||
|
||||
type: Literal["locReportEncrypted"]
|
||||
|
||||
payload: str
|
||||
hashed_adv_key: str
|
||||
|
||||
|
||||
class LocationReportDecryptedMapping(TypedDict):
|
||||
"""JSON mapping representing a decrypted location report."""
|
||||
|
||||
type: Literal["locReportDecrypted"]
|
||||
|
||||
payload: str
|
||||
hashed_adv_key: str
|
||||
key: KeyPairMapping
|
||||
|
||||
|
||||
LocationReportMapping = Union[LocationReportEncryptedMapping, LocationReportDecryptedMapping]
|
||||
|
||||
|
||||
class LocationReport(HasHashedPublicKey, Serializable[LocationReportMapping]):
|
||||
"""Location report corresponding to a certain `HasHashedPublicKey`."""
|
||||
|
||||
def __init__(
|
||||
@@ -66,9 +91,13 @@ class LocationReport(HasHashedPublicKey):
|
||||
"""Whether the report is currently decrypted."""
|
||||
return self._decrypted_data is not None
|
||||
|
||||
def can_decrypt(self, key: KeyPair, /) -> bool:
|
||||
"""Whether the report can be decrypted using the given key."""
|
||||
return key.hashed_adv_key_bytes == self._hashed_adv_key
|
||||
|
||||
def decrypt(self, key: KeyPair) -> None:
|
||||
"""Decrypt the report using its corresponding `KeyPair`."""
|
||||
if key.hashed_adv_key_bytes != self._hashed_adv_key:
|
||||
if not self.can_decrypt(key):
|
||||
msg = "Cannot decrypt with this key!"
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -163,6 +192,86 @@ class LocationReport(HasHashedPublicKey):
|
||||
status_bytes = self._decrypted_data[1][9:10]
|
||||
return int.from_bytes(status_bytes, "big")
|
||||
|
||||
@overload
|
||||
def to_json(
|
||||
self,
|
||||
dst: str | Path | None = None,
|
||||
/,
|
||||
*,
|
||||
include_key: Literal[True],
|
||||
) -> LocationReportEncryptedMapping:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def to_json(
|
||||
self,
|
||||
dst: str | Path | None = None,
|
||||
/,
|
||||
*,
|
||||
include_key: Literal[False],
|
||||
) -> LocationReportDecryptedMapping:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def to_json(
|
||||
self,
|
||||
dst: str | Path | None = None,
|
||||
/,
|
||||
*,
|
||||
include_key: None = None,
|
||||
) -> LocationReportMapping:
|
||||
pass
|
||||
|
||||
@override
|
||||
def to_json(
|
||||
self,
|
||||
dst: str | Path | None = None,
|
||||
/,
|
||||
*,
|
||||
include_key: bool | None = None,
|
||||
) -> LocationReportMapping:
|
||||
if include_key is None:
|
||||
include_key = self.is_decrypted
|
||||
|
||||
if include_key:
|
||||
return save_and_return_json(
|
||||
{
|
||||
"type": "locReportDecrypted",
|
||||
"payload": base64.b64encode(self._payload).decode("utf-8"),
|
||||
"hashed_adv_key": base64.b64encode(self._hashed_adv_key).decode("utf-8"),
|
||||
"key": self.key.to_json(),
|
||||
},
|
||||
dst,
|
||||
)
|
||||
return save_and_return_json(
|
||||
{
|
||||
"type": "locReportEncrypted",
|
||||
"payload": base64.b64encode(self._payload).decode("utf-8"),
|
||||
"hashed_adv_key": base64.b64encode(self._hashed_adv_key).decode("utf-8"),
|
||||
},
|
||||
dst,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_json(cls, val: str | Path | LocationReportMapping, /) -> LocationReport:
|
||||
val = read_data_json(val)
|
||||
assert val["type"] == "locReportEncrypted" or val["type"] == "locReportDecrypted"
|
||||
|
||||
try:
|
||||
report = cls(
|
||||
payload=base64.b64decode(val["payload"]),
|
||||
hashed_adv_key=base64.b64decode(val["hashed_adv_key"]),
|
||||
)
|
||||
if val["type"] == "locReportDecrypted":
|
||||
key = KeyPair.from_json(val["key"])
|
||||
report.decrypt(key)
|
||||
except KeyError as e:
|
||||
msg = f"Failed to restore account data: {e}"
|
||||
raise ValueError(msg) from None
|
||||
else:
|
||||
return report
|
||||
|
||||
@override
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""
|
||||
|
||||
@@ -5,8 +5,13 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Generic, Self, TypeVar
|
||||
|
||||
logging.getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Closable(ABC):
|
||||
@@ -38,16 +43,37 @@ class Closable(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class Serializable(ABC):
|
||||
_T = TypeVar("_T", bound=Mapping)
|
||||
|
||||
|
||||
class Serializable(Generic[_T], ABC):
|
||||
"""ABC for serializable classes."""
|
||||
|
||||
@abstractmethod
|
||||
def serialize(self) -> dict:
|
||||
"""Serialize the object to a JSON-serializable dictionary."""
|
||||
def to_json(self, dst: str | Path | None = None, /) -> _T:
|
||||
"""
|
||||
Export the current state of the object as a JSON-serializable dictionary.
|
||||
|
||||
If an argument is provided, the output will also be written to that file.
|
||||
|
||||
The output of this method is guaranteed to be JSON-serializable, and passing
|
||||
the return value of this function as an argument to `Serializable.from_json`
|
||||
will always result in an exact copy of the internal state as it was when exported.
|
||||
|
||||
You are encouraged to save and load object states to and from disk whenever possible,
|
||||
to prevent unnecessary API calls or otherwise unexpected behavior.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def deserialize(cls, data: dict) -> Serializable:
|
||||
"""Deserialize the object from a JSON-serializable dictionary."""
|
||||
def from_json(cls, val: str | Path | _T, /) -> Self:
|
||||
"""
|
||||
Restore state from a previous `Closable.to_json` export.
|
||||
|
||||
If given a str or Path, it must point to a json file from `Serializable.to_json`.
|
||||
Otherwise, it should be the Mapping itself.
|
||||
|
||||
See `Serializable.to_json` for more information.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
34
findmy/util/files.py
Normal file
34
findmy/util/files.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Utilities to simplify reading and writing data from and to files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import TypeVar, cast
|
||||
|
||||
T = TypeVar("T", bound=Mapping)
|
||||
|
||||
|
||||
def save_and_return_json(data: T, dst: str | Path | None) -> T:
|
||||
"""Save and return a JSON-serializable data structure."""
|
||||
if dst is None:
|
||||
return data
|
||||
|
||||
if isinstance(dst, str):
|
||||
dst = Path(dst)
|
||||
|
||||
dst.write_text(json.dumps(data, indent=4))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def read_data_json(val: str | Path | T) -> T:
|
||||
"""Read JSON data from a file if a path is passed, or return the argument itself."""
|
||||
if isinstance(val, str):
|
||||
val = Path(val)
|
||||
|
||||
if isinstance(val, Path):
|
||||
val = cast("T", json.loads(val.read_text()))
|
||||
|
||||
return val
|
||||
144
findmy/util/session.py
Normal file
144
findmy/util/session.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Logic related to serializable classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar, Union
|
||||
|
||||
from findmy.util.abc import Closable, Serializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
|
||||
_S = TypeVar("_S", bound=Serializable)
|
||||
_SC = TypeVar("_SC", bound=Union[Serializable, Closable])
|
||||
|
||||
|
||||
class _BaseSessionManager(Generic[_SC]):
|
||||
"""Base class for session managers."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sessions: dict[_SC, str | Path | None] = {}
|
||||
|
||||
def _add(self, obj: _SC, path: str | Path | None) -> None:
|
||||
self._sessions[obj] = path
|
||||
|
||||
def remove(self, obj: _SC) -> None:
|
||||
self._sessions.pop(obj, None)
|
||||
|
||||
def save(self) -> None:
|
||||
for obj, path in self._sessions.items():
|
||||
if isinstance(obj, Serializable):
|
||||
obj.to_json(path)
|
||||
|
||||
async def close(self) -> None:
|
||||
for obj in self._sessions:
|
||||
if isinstance(obj, Closable):
|
||||
await obj.close()
|
||||
|
||||
async def save_and_close(self) -> None:
|
||||
for obj, path in self._sessions.items():
|
||||
if isinstance(obj, Serializable):
|
||||
obj.to_json(path)
|
||||
if isinstance(obj, Closable):
|
||||
await obj.close()
|
||||
|
||||
def get_random(self) -> _SC:
|
||||
if not self._sessions:
|
||||
msg = "No objects in the session manager."
|
||||
raise ValueError(msg)
|
||||
return random.choice(list(self._sessions.keys())) # noqa: S311
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._sessions)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
_exc_type: type[BaseException] | None,
|
||||
_exc_val: BaseException | None,
|
||||
_exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
self.save()
|
||||
|
||||
|
||||
class MixedSessionManager(_BaseSessionManager[Union[Serializable, Closable]]):
|
||||
"""Allows any Serializable or Closable object."""
|
||||
|
||||
def new(
|
||||
self,
|
||||
c_type: type[_SC],
|
||||
path: str | Path | None = None,
|
||||
/,
|
||||
*args: Any, # noqa: ANN401
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
) -> _SC:
|
||||
"""Add an object to the manager by instantiating it using its constructor."""
|
||||
obj = c_type(*args, **kwargs)
|
||||
if isinstance(obj, Serializable) and path is not None:
|
||||
obj.to_json(path)
|
||||
self._add(obj, path)
|
||||
return obj
|
||||
|
||||
def add_from_json(
|
||||
self,
|
||||
c_type: type[_S],
|
||||
path: str | Path,
|
||||
/,
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
) -> _S:
|
||||
"""Add an object to the manager by deserializing it from its JSON representation."""
|
||||
obj = c_type.from_json(path, **kwargs)
|
||||
self._add(obj, path)
|
||||
return obj
|
||||
|
||||
def add(self, obj: Serializable | Closable, path: str | Path | None = None, /) -> None:
|
||||
"""Add an object to the session manager."""
|
||||
self._add(obj, path)
|
||||
|
||||
|
||||
class UniformSessionManager(Generic[_SC], _BaseSessionManager[_SC]):
|
||||
"""Only allows a single type of Serializable object."""
|
||||
|
||||
def __init__(self, obj_type: type[_SC]) -> None:
|
||||
"""Create a new session manager."""
|
||||
super().__init__()
|
||||
self._obj_type = obj_type
|
||||
|
||||
def new(
|
||||
self,
|
||||
path: str | Path | None = None,
|
||||
/,
|
||||
*args: Any, # noqa: ANN401
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
) -> _SC:
|
||||
"""Add an object to the manager by instantiating it using its constructor."""
|
||||
obj = self._obj_type(*args, **kwargs)
|
||||
if isinstance(obj, Serializable) and path is not None:
|
||||
obj.to_json(path)
|
||||
self._add(obj, path)
|
||||
return obj
|
||||
|
||||
def add_from_json(
|
||||
self,
|
||||
path: str | Path,
|
||||
/,
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
) -> _SC:
|
||||
"""Add an object to the manager by deserializing it from its JSON representation."""
|
||||
if not issubclass(self._obj_type, Serializable):
|
||||
msg = "Can only add objects of type Serializable."
|
||||
raise TypeError(msg)
|
||||
obj = self._obj_type.from_json(path, **kwargs)
|
||||
self._add(obj, path)
|
||||
return obj
|
||||
|
||||
def add(self, obj: _SC, path: str | Path | None = None, /) -> None:
|
||||
"""Add an object to the session manager."""
|
||||
if not isinstance(obj, self._obj_type):
|
||||
msg = f"Object must be of type {self._obj_type.__name__}"
|
||||
raise TypeError(msg)
|
||||
self._add(obj, path)
|
||||
Reference in New Issue
Block a user