diff --git a/examples/_login.py b/examples/_login.py index 6891b19..b602400 100644 --- a/examples/_login.py +++ b/examples/_login.py @@ -1,15 +1,13 @@ -# ruff: noqa: ASYNC230 +from __future__ import annotations from findmy.reports import ( AppleAccount, AsyncAppleAccount, - BaseAnisetteProvider, LoginState, SmsSecondFactorMethod, TrustedDeviceSecondFactorMethod, ) - -ACCOUNT_STORE = "account.json" +from findmy.reports.anisette import LocalAnisetteProvider, RemoteAnisetteProvider def _login_sync(account: AppleAccount) -> None: @@ -66,27 +64,45 @@ async def _login_async(account: AsyncAppleAccount) -> None: await method.submit(code) -def get_account_sync(anisette: BaseAnisetteProvider) -> AppleAccount: +def get_account_sync( + store_path: str, + anisette_url: str | None, + libs_path: str | None, +) -> AppleAccount: """Tries to restore a saved Apple account, or prompts the user for login otherwise. (sync)""" - acc = AppleAccount(anisette=anisette) - acc_store = "account.json" try: - acc.from_json(acc_store) + acc = AppleAccount.from_json(store_path, anisette_libs_path=libs_path) except FileNotFoundError: + ani = ( + LocalAnisetteProvider(libs_path=libs_path) + if anisette_url is None + else RemoteAnisetteProvider(anisette_url) + ) + acc = AppleAccount(ani) _login_sync(acc) - acc.to_json(acc_store) + + acc.to_json(store_path) return acc -async def get_account_async(anisette: BaseAnisetteProvider) -> AsyncAppleAccount: +async def get_account_async( + store_path: str, + anisette_url: str | None, + libs_path: str | None, +) -> AsyncAppleAccount: """Tries to restore a saved Apple account, or prompts the user for login otherwise. (async)""" - acc = AsyncAppleAccount(anisette=anisette) - acc_store = "account.json" try: - acc.from_json(acc_store) + acc = AsyncAppleAccount.from_json(store_path, anisette_libs_path=libs_path) except FileNotFoundError: + ani = ( + LocalAnisetteProvider(libs_path=libs_path) + if anisette_url is None + else RemoteAnisetteProvider(anisette_url) + ) + acc = AsyncAppleAccount(ani) await _login_async(acc) - acc.to_json(acc_store) + + acc.to_json(store_path) return acc diff --git a/examples/fetch_reports.py b/examples/fetch_reports.py index 3758417..21a6e2b 100644 --- a/examples/fetch_reports.py +++ b/examples/fetch_reports.py @@ -4,28 +4,47 @@ import sys from _login import get_account_sync from findmy import KeyPair -from findmy.reports import RemoteAnisetteProvider -# URL to (public or local) anisette server -ANISETTE_SERVER = "http://localhost:6969" +# Path where login session will be stored. +# This is necessary to avoid generating a new session every time we log in. +STORE_PATH = "account.json" + +# URL to LOCAL anisette server. Set to None to use built-in Anisette generator instead (recommended) +# IF YOU USE A PUBLIC SERVER, DO NOT COMPLAIN THAT YOU KEEP RUNNING INTO AUTHENTICATION ERRORS! +# If you change this value, make sure to remove the account store file. +ANISETTE_SERVER = None + +# Path where Anisette libraries will be stored. +# This is only relevant when using the built-in Anisette server. +# It can be omitted (set to None) to avoid saving to disk, +# but specifying a path is highly recommended to avoid downloading the bundle on every run. +ANISETTE_LIBS_PATH = "ani_libs.bin" logging.basicConfig(level=logging.INFO) def fetch_reports(priv_key: str) -> int: - key = KeyPair.from_b64(priv_key) - acc = get_account_sync( - RemoteAnisetteProvider(ANISETTE_SERVER), - ) + # Step 0: construct an account instance + # We use a helper for this to simplify interactive authentication + acc = get_account_sync(STORE_PATH, ANISETTE_SERVER, ANISETTE_LIBS_PATH) print(f"Logged in as: {acc.account_name} ({acc.first_name} {acc.last_name})") - # It's that simple! + # Step 1: construct a key object and get its location reports + key = KeyPair.from_b64(priv_key) reports = acc.fetch_last_reports(key) + + # Step 2: print the reports! for report in sorted(reports): print(report) - return 1 + # We can save the report to a file if we want + report.to_json("last_report.json") + + # Step 3: Make sure to save account state when you're done! + acc.to_json(STORE_PATH) + + return 0 if __name__ == "__main__": diff --git a/examples/fetch_reports_async.py b/examples/fetch_reports_async.py index d267a6d..169faa9 100644 --- a/examples/fetch_reports_async.py +++ b/examples/fetch_reports_async.py @@ -5,30 +5,49 @@ import sys from _login import get_account_async from findmy import KeyPair -from findmy.reports import RemoteAnisetteProvider -# URL to (public or local) anisette server -ANISETTE_SERVER = "http://localhost:6969" +# Path where login session will be stored. +# This is necessary to avoid generating a new session every time we log in. +STORE_PATH = "account.json" + +# URL to LOCAL anisette server. Set to None to use built-in Anisette generator instead (recommended) +# IF YOU USE A PUBLIC SERVER, DO NOT COMPLAIN THAT YOU KEEP RUNNING INTO AUTHENTICATION ERRORS! +# If you change this value, make sure to remove the account store file. +ANISETTE_SERVER = None + +# Path where Anisette libraries will be stored. +# This is only relevant when using the built-in Anisette server. +# It can be omitted (set to None) to avoid saving to disk, +# but specifying a path is highly recommended to avoid downloading the bundle on every run. +ANISETTE_LIBS_PATH = "ani_libs.bin" logging.basicConfig(level=logging.INFO) async def fetch_reports(priv_key: str) -> int: - key = KeyPair.from_b64(priv_key) - acc = await get_account_async( - RemoteAnisetteProvider(ANISETTE_SERVER), - ) + # Step 0: construct an account instance + # We use a helper for this to simplify interactive authentication + acc = await get_account_async(STORE_PATH, ANISETTE_SERVER, ANISETTE_LIBS_PATH) try: print(f"Logged in as: {acc.account_name} ({acc.first_name} {acc.last_name})") - # It's that simple! + # Step 1: construct a key object and get its location reports + key = KeyPair.from_b64(priv_key) reports = await acc.fetch_last_reports(key) + + # Step 2: print the reports! for report in sorted(reports): print(report) + + # We can save the report to a file if we want + report.to_json("last_report.json") finally: await acc.close() + # Make sure to save account state when you're done! + acc.to_json(STORE_PATH) + return 0 diff --git a/examples/real_airtag.py b/examples/real_airtag.py index 5eeb858..ff91378 100644 --- a/examples/real_airtag.py +++ b/examples/real_airtag.py @@ -11,10 +11,21 @@ from pathlib import Path from _login import get_account_sync from findmy import FindMyAccessory -from findmy.reports import RemoteAnisetteProvider -# URL to (public or local) anisette server -ANISETTE_SERVER = "http://localhost:6969" +# Path where login session will be stored. +# This is necessary to avoid generating a new session every time we log in. +STORE_PATH = "account.json" + +# URL to LOCAL anisette server. Set to None to use built-in Anisette generator instead (recommended) +# IF YOU USE A PUBLIC SERVER, DO NOT COMPLAIN THAT YOU KEEP RUNNING INTO AUTHENTICATION ERRORS! +# If you change this value, make sure to remove the account store file. +ANISETTE_SERVER = None + +# Path where Anisette libraries will be stored. +# This is only relevant when using the built-in Anisette server. +# It can be omitted (set to None) to avoid saving to disk, +# but specifying a path is highly recommended to avoid downloading the bundle on every run. +ANISETTE_LIBS_PATH = "ani_libs.bin" logging.basicConfig(level=logging.INFO) @@ -26,8 +37,7 @@ def main(plist_path: str) -> int: # Step 1: log into an Apple account print("Logging into account") - anisette = RemoteAnisetteProvider(ANISETTE_SERVER) - acc = get_account_sync(anisette) + acc = get_account_sync(STORE_PATH, ANISETTE_SERVER, ANISETTE_LIBS_PATH) # step 2: fetch reports! print("Fetching reports") @@ -39,6 +49,9 @@ def main(plist_path: str) -> int: for report in sorted(reports): print(f" - {report}") + # step 4: save current account state to disk + acc.to_json(STORE_PATH) + return 0 diff --git a/findmy/accessory.py b/findmy/accessory.py index d44ef8f..544170a 100644 --- a/findmy/accessory.py +++ b/findmy/accessory.py @@ -6,25 +6,40 @@ Accessories could be anything ranging from AirTags to iPhones. from __future__ import annotations -import json import logging import plistlib from abc import ABC, abstractmethod from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import IO, TYPE_CHECKING, overload +from typing import IO, TYPE_CHECKING, Literal, TypedDict, overload from typing_extensions import override +from findmy.util.abc import Serializable +from findmy.util.files import read_data_json, save_and_return_json + from .keys import KeyGenerator, KeyPair, KeyType from .util import crypto if TYPE_CHECKING: - from collections.abc import Generator, Mapping + from collections.abc import Generator logger = logging.getLogger(__name__) +class FindMyAccessoryMapping(TypedDict): + """JSON mapping representing state of a FindMyAccessory instance.""" + + type: Literal["accessory"] + master_key: str + skn: str + sks: str + paired_at: str + name: str | None + model: str | None + identifier: str | None + + class RollingKeyPairSource(ABC): """A class that generates rolling `KeyPair`s.""" @@ -67,7 +82,7 @@ class RollingKeyPairSource(ABC): return keys -class FindMyAccessory(RollingKeyPairSource): +class FindMyAccessory(RollingKeyPairSource, Serializable[FindMyAccessoryMapping]): """A findable Find My-accessory using official key rollover.""" def __init__( # noqa: PLR0913 @@ -242,9 +257,10 @@ class FindMyAccessory(RollingKeyPairSource): identifier=identifier, ) - def to_json(self, path: str | Path | None = None) -> dict[str, str | int | None]: - """Convert the accessory to a JSON-serializable dictionary.""" - d = { + @override + def to_json(self, path: str | Path | None = None, /) -> FindMyAccessoryMapping: + res: FindMyAccessoryMapping = { + "type": "accessory", "master_key": self._primary_gen.master_key.hex(), "skn": self.skn.hex(), "sks": self.sks.hex(), @@ -253,23 +269,32 @@ class FindMyAccessory(RollingKeyPairSource): "model": self.model, "identifier": self.identifier, } - if path is not None: - Path(path).write_text(json.dumps(d, indent=4)) - return d + + return save_and_return_json(res, path) @classmethod - def from_json(cls, json_: str | Path | Mapping, /) -> FindMyAccessory: - """Create a FindMyAccessory from a JSON file.""" - data = json.loads(Path(json_).read_text()) if isinstance(json_, (str, Path)) else json_ - return cls( - master_key=bytes.fromhex(data["master_key"]), - skn=bytes.fromhex(data["skn"]), - sks=bytes.fromhex(data["sks"]), - paired_at=datetime.fromisoformat(data["paired_at"]), - name=data["name"], - model=data["model"], - identifier=data["identifier"], - ) + @override + def from_json( + cls, + val: str | Path | FindMyAccessoryMapping, + /, + ) -> FindMyAccessory: + val = read_data_json(val) + assert val["type"] == "accessory" + + try: + return cls( + master_key=bytes.fromhex(val["master_key"]), + skn=bytes.fromhex(val["skn"]), + sks=bytes.fromhex(val["sks"]), + paired_at=datetime.fromisoformat(val["paired_at"]), + name=val["name"], + model=val["model"], + identifier=val["identifier"], + ) + except KeyError as e: + msg = f"Failed to restore account data: {e}" + raise ValueError(msg) from None class AccessoryKeyGenerator(KeyGenerator[KeyPair]): diff --git a/findmy/keys.py b/findmy/keys.py index bc2644e..76b849f 100644 --- a/findmy/keys.py +++ b/findmy/keys.py @@ -7,15 +7,19 @@ import hashlib import secrets from abc import ABC, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar, overload from cryptography.hazmat.primitives.asymmetric import ec from typing_extensions import override +from findmy.util.abc import Serializable +from findmy.util.files import read_data_json, save_and_return_json + from .util import crypto, parsers if TYPE_CHECKING: from collections.abc import Generator + from pathlib import Path class KeyType(Enum): @@ -26,6 +30,16 @@ class KeyType(Enum): SECONDARY = 2 +class KeyPairMapping(TypedDict): + """JSON mapping representing a KeyPair.""" + + type: Literal["keypair"] + + private_key: str + key_type: int + name: str | None + + class HasHashedPublicKey(ABC): """ ABC for anything that has a public, hashed FindMy-key. @@ -113,7 +127,7 @@ class HasPublicKey(HasHashedPublicKey, ABC): ) -class KeyPair(HasPublicKey): +class KeyPair(HasPublicKey, Serializable[KeyPairMapping]): """A private-public keypair for a trackable FindMy accessory.""" def __init__( @@ -182,6 +196,34 @@ class KeyPair(HasPublicKey): key_bytes = self._priv_key.public_key().public_numbers().x return int.to_bytes(key_bytes, 28, "big") + @override + def to_json(self, dst: str | Path | None = None, /) -> KeyPairMapping: + return save_and_return_json( + { + "type": "keypair", + "private_key": base64.b64encode(self.private_key_bytes).decode("ascii"), + "key_type": self._key_type.value, + "name": self.name, + }, + dst, + ) + + @classmethod + @override + def from_json(cls, val: str | Path | KeyPairMapping, /) -> KeyPair: + val = read_data_json(val) + assert val["type"] == "keypair" + + try: + return cls( + private_key=base64.b64decode(val["private_key"]), + key_type=KeyType(val["key_type"]), + name=val["name"], + ) + except KeyError as e: + msg = f"Failed to restore KeyPair data: {e}" + raise ValueError(msg) from None + def dh_exchange(self, other_pub_key: ec.EllipticCurvePublicKey) -> bytes: """Do a Diffie-Hellman key exchange using another EC public key.""" return self._priv_key.exchange(ec.ECDH(), other_pub_key) diff --git a/findmy/reports/account.py b/findmy/reports/account.py index 95da61b..6aef1d1 100644 --- a/findmy/reports/account.py +++ b/findmy/reports/account.py @@ -11,11 +11,11 @@ import uuid from abc import ABC, abstractmethod from datetime import datetime, timedelta, timezone from functools import wraps -from pathlib import Path from typing import ( TYPE_CHECKING, Any, Callable, + Literal, TypedDict, TypeVar, cast, @@ -32,8 +32,10 @@ from findmy.errors import ( UnauthorizedError, UnhandledProtocolError, ) +from findmy.reports.anisette import AnisetteMapping, get_provider_from_mapping from findmy.util import crypto -from findmy.util.abc import Closable +from findmy.util.abc import Closable, Serializable +from findmy.util.files import read_data_json, save_and_return_json from findmy.util.http import HttpResponse, HttpSession, decode_plist from .reports import LocationReport, LocationReportsFetcher @@ -49,7 +51,8 @@ from .twofactor import ( ) if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Sequence + from pathlib import Path from findmy.accessory import RollingKeyPairSource from findmy.keys import HasHashedPublicKey @@ -70,6 +73,33 @@ class _AccountInfo(TypedDict): trusted_device_2fa: bool +class _AccountStateMappingIds(TypedDict): + uid: str + devid: str + + +class _AccountStateMappingAccount(TypedDict): + username: str | None + password: str | None + info: _AccountInfo | None + + +class _AccountStateMappingLoginState(TypedDict): + state: int + data: dict # TODO: make typed # noqa: TD002, TD003 + + +class AccountStateMapping(TypedDict): + """JSON mapping representing state of an Apple account instance.""" + + type: Literal["account"] + + ids: _AccountStateMappingIds + account: _AccountStateMappingAccount + login: _AccountStateMappingLoginState + anisette: AnisetteMapping + + _P = ParamSpec("_P") _R = TypeVar("_R") _A = TypeVar("_A", bound="BaseAppleAccount") @@ -111,7 +141,7 @@ def _extract_phone_numbers(html: str) -> list[dict]: return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", []) -class BaseAppleAccount(Closable, ABC): +class BaseAppleAccount(Closable, Serializable[AccountStateMapping], ABC): """Base class for an Apple account.""" @property @@ -151,33 +181,6 @@ class BaseAppleAccount(Closable, ABC): """ raise NotImplementedError - @abstractmethod - def to_json(self, path: str | Path | None = None) -> dict: - """ - Export the current state of the account as a JSON-serializable dictionary. - - If `path` is provided, the output will also be written to that file. - - The output of this method is guaranteed to be JSON-serializable, and passing - the return value of this function as an argument to `BaseAppleAccount.from_json` - will always result in an exact copy of the internal state as it was when exported. - - This method is especially useful to avoid having to keep going through the login flow. - """ - raise NotImplementedError - - @abstractmethod - def from_json(self, json_: str | Path | Mapping, /) -> None: - """ - Restore the state from a previous `BaseAppleAccount.to_json` export. - - If given a str or Path, it must point to a json file from `BaseAppleAccount.to_json`. - Otherwise it should be the Mapping itself. - - See `BaseAppleAccount.to_json` for more information. - """ - raise NotImplementedError - @abstractmethod def login(self, username: str, password: str) -> MaybeCoro[LoginState]: """Log in to an Apple account using a username and password.""" @@ -347,31 +350,33 @@ class AsyncAppleAccount(BaseAppleAccount): def __init__( self, - *, anisette: BaseAnisetteProvider, - user_id: str | None = None, - device_id: str | None = None, + *, + state_info: AccountStateMapping | None = None, ) -> None: """ Initialize the apple account. :param anisette: An instance of `AsyncAnisetteProvider`. - :param user_id: An optional user ID to use. Will be auto-generated if missing. - :param device_id: An optional device ID to use. Will be auto-generated if missing. """ super().__init__() self._anisette: BaseAnisetteProvider = anisette - self._uid: str = user_id or str(uuid.uuid4()) - self._devid: str = device_id or str(uuid.uuid4()) + self._uid: str = state_info["ids"]["uid"] if state_info else str(uuid.uuid4()) + self._devid: str = state_info["ids"]["devid"] if state_info else str(uuid.uuid4()) - self._username: str | None = None - self._password: str | None = None + # TODO: combine, user/pass should be "all or nothing" # noqa: TD002, TD003 + self._username: str | None = state_info["account"]["username"] if state_info else None + self._password: str | None = state_info["account"]["password"] if state_info else None - self._login_state: LoginState = LoginState.LOGGED_OUT - self._login_state_data: dict = {} + self._login_state: LoginState = ( + LoginState(state_info["login"]["state"]) if state_info else LoginState.LOGGED_OUT + ) + self._login_state_data: dict = state_info["login"]["data"] if state_info else {} - self._account_info: _AccountInfo | None = None + self._account_info: _AccountInfo | None = ( + state_info["account"]["info"] if state_info else None + ) self._http: HttpSession = HttpSession() self._reports: LocationReportsFetcher = LocationReportsFetcher(self) @@ -433,36 +438,39 @@ class AsyncAppleAccount(BaseAppleAccount): return self._account_info["last_name"] if self._account_info else None @override - def to_json(self, path: str | Path | None = None) -> dict: - result = { + def to_json(self, path: str | Path | None = None, /) -> AccountStateMapping: + res: AccountStateMapping = { + "type": "account", "ids": {"uid": self._uid, "devid": self._devid}, "account": { "username": self._username, "password": self._password, "info": self._account_info, }, - "login_state": { + "login": { "state": self._login_state.value, "data": self._login_state_data, }, + "anisette": self._anisette.to_json(), } - if path is not None: - Path(path).write_text(json.dumps(result, indent=4)) - return result + return save_and_return_json(res, path) + + @classmethod @override - def from_json(self, json_: str | Path | Mapping, /) -> None: - data = json.loads(Path(json_).read_text()) if isinstance(json_, (str, Path)) else json_ + def from_json( + cls, + val: str | Path | AccountStateMapping, + /, + *, + anisette_libs_path: str | Path | None = None, + ) -> AsyncAppleAccount: + val = read_data_json(val) + assert val["type"] == "account" + try: - self._uid = data["ids"]["uid"] - self._devid = data["ids"]["devid"] - - self._username = data["account"]["username"] - self._password = data["account"]["password"] - self._account_info = data["account"]["info"] - - self._login_state = LoginState(data["login_state"]["state"]) - self._login_state_data = data["login_state"]["data"] + ani_provider = get_provider_from_mapping(val["anisette"], libs_path=anisette_libs_path) + return cls(ani_provider, state_info=val) except KeyError as e: msg = f"Failed to restore account data: {e}" raise ValueError(msg) from None @@ -976,13 +984,12 @@ class AppleAccount(BaseAppleAccount): def __init__( self, - *, anisette: BaseAnisetteProvider, - user_id: str | None = None, - device_id: str | None = None, + *, + state_info: AccountStateMapping | None = None, ) -> None: """See `AsyncAppleAccount.__init__`.""" - self._asyncacc = AsyncAppleAccount(anisette=anisette, user_id=user_id, device_id=device_id) + self._asyncacc = AsyncAppleAccount(anisette=anisette, state_info=state_info) try: self._evt_loop = asyncio.get_running_loop() @@ -1022,12 +1029,25 @@ class AppleAccount(BaseAppleAccount): return self._asyncacc.last_name @override - def to_json(self, path: str | Path | None = None) -> dict: - return self._asyncacc.to_json(path) + def to_json(self, dst: str | Path | None = None, /) -> AccountStateMapping: + return self._asyncacc.to_json(dst) + @classmethod @override - def from_json(self, json_: str | Path | Mapping, /) -> None: - return self._asyncacc.from_json(json_) + def from_json( + cls, + val: str | Path | AccountStateMapping, + /, + *, + anisette_libs_path: str | Path | None = None, + ) -> AppleAccount: + val = read_data_json(val) + try: + ani_provider = get_provider_from_mapping(val["anisette"], libs_path=anisette_libs_path) + return cls(ani_provider, state_info=val) + except KeyError as e: + msg = f"Failed to restore account data: {e}" + raise ValueError(msg) from None @override def login(self, username: str, password: str) -> LoginState: diff --git a/findmy/reports/anisette.py b/findmy/reports/anisette.py index 146a388..240533c 100644 --- a/findmy/reports/anisette.py +++ b/findmy/reports/anisette.py @@ -10,17 +10,49 @@ from abc import ABC, abstractmethod from datetime import datetime, timezone from io import BytesIO from pathlib import Path -from typing import BinaryIO +from typing import BinaryIO, Literal, TypedDict, Union from anisette import Anisette, AnisetteHeaders from typing_extensions import override from findmy.util.abc import Closable, Serializable +from findmy.util.files import read_data_json, save_and_return_json from findmy.util.http import HttpSession logger = logging.getLogger(__name__) +class RemoteAnisetteMapping(TypedDict): + """JSON mapping representing state of a remote Anisette provider.""" + + type: Literal["aniRemote"] + url: str + + +class LocalAnisetteMapping(TypedDict): + """JSON mapping representing state of a local Anisette provider.""" + + type: Literal["aniLocal"] + prov_data: str + + +AnisetteMapping = Union[RemoteAnisetteMapping, LocalAnisetteMapping] + + +def get_provider_from_mapping( + mapping: AnisetteMapping, + *, + libs_path: str | Path | None = None, +) -> RemoteAnisetteProvider | LocalAnisetteProvider: + """Get the correct Anisette provider instance from saved JSON data.""" + if mapping["type"] == "aniRemote": + return RemoteAnisetteProvider.from_json(mapping) + if mapping["type"] == "aniLocal": + return LocalAnisetteProvider.from_json(mapping, libs_path=libs_path) + msg = f"Unknown anisette type: {mapping['type']}" + raise ValueError(msg) + + class BaseAnisetteProvider(Closable, Serializable, ABC): """ Abstract base class for Anisette providers. @@ -156,7 +188,7 @@ class BaseAnisetteProvider(Closable, Serializable, ABC): return cpd -class RemoteAnisetteProvider(BaseAnisetteProvider): +class RemoteAnisetteProvider(BaseAnisetteProvider, Serializable[RemoteAnisetteMapping]): """Anisette provider. Fetches headers from a remote Anisette server.""" _ANISETTE_DATA_VALID_FOR = 30 @@ -174,20 +206,25 @@ class RemoteAnisetteProvider(BaseAnisetteProvider): self._closed = False @override - def serialize(self) -> dict: + def to_json(self, dst: str | Path | None = None, /) -> RemoteAnisetteMapping: """See `BaseAnisetteProvider.serialize`.""" - return { - "type": "aniRemote", - "url": self._server_url, - } + return save_and_return_json( + { + "type": "aniRemote", + "url": self._server_url, + }, + dst, + ) @classmethod @override - def deserialize(cls, data: dict) -> RemoteAnisetteProvider: + def from_json(cls, val: str | Path | RemoteAnisetteMapping) -> RemoteAnisetteProvider: """See `BaseAnisetteProvider.deserialize`.""" - assert data["type"] == "aniRemote" + val = read_data_json(val) - server_url = data["url"] + assert val["type"] == "aniRemote" + + server_url = val["url"] return cls(server_url) @@ -245,7 +282,7 @@ class RemoteAnisetteProvider(BaseAnisetteProvider): logger.warning("Error closing anisette HTTP session: %s", e) -class LocalAnisetteProvider(BaseAnisetteProvider): +class LocalAnisetteProvider(BaseAnisetteProvider, Serializable[LocalAnisetteMapping]): """Anisette provider. Generates headers without a remote server using the `anisette` library.""" def __init__( @@ -265,6 +302,7 @@ class LocalAnisetteProvider(BaseAnisetteProvider): "The Anisette engine will download libraries required for operation, " "this may take a few seconds...", ) + if libs_path is None: logger.info( "To speed up future local Anisette initializations, " "provide a filesystem path to load the libraries from.", @@ -289,24 +327,34 @@ class LocalAnisetteProvider(BaseAnisetteProvider): ) @override - def serialize(self) -> dict: + def to_json(self, dst: str | Path | None = None, /) -> LocalAnisetteMapping: """See `BaseAnisetteProvider.serialize`.""" with BytesIO() as buf: self._ani.save_provisioning(buf) prov_data = base64.b64encode(buf.getvalue()).decode("utf-8") - return { - "type": "aniLocal", - "prov_data": prov_data, - } + return save_and_return_json( + { + "type": "aniLocal", + "prov_data": prov_data, + }, + dst, + ) @classmethod @override - def deserialize(cls, data: dict, libs_path: str | Path | None = None) -> LocalAnisetteProvider: + def from_json( + cls, + val: str | Path | LocalAnisetteMapping, + *, + libs_path: str | Path | None = None, + ) -> LocalAnisetteProvider: """See `BaseAnisetteProvider.deserialize`.""" - assert data["type"] == "aniLocal" + val = read_data_json(val) - state_blob = BytesIO(base64.b64decode(data["prov_data"])) + assert val["type"] == "aniLocal" + + state_blob = BytesIO(base64.b64decode(val["prov_data"])) return cls(state_blob=state_blob, libs_path=libs_path) diff --git a/findmy/reports/reports.py b/findmy/reports/reports.py index df532fc..68c97f3 100644 --- a/findmy/reports/reports.py +++ b/findmy/reports/reports.py @@ -8,7 +8,7 @@ import logging import struct from collections import defaultdict from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, cast, overload +from typing import TYPE_CHECKING, Literal, TypedDict, Union, cast, overload from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec @@ -16,17 +16,42 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from typing_extensions import override from findmy.accessory import RollingKeyPairSource -from findmy.keys import HasHashedPublicKey, KeyPair +from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping +from findmy.util.abc import Serializable +from findmy.util.files import read_data_json, save_and_return_json if TYPE_CHECKING: from collections.abc import Sequence + from pathlib import Path from .account import AsyncAppleAccount logger = logging.getLogger(__name__) -class LocationReport(HasHashedPublicKey): +class LocationReportEncryptedMapping(TypedDict): + """JSON mapping representing an encrypted location report.""" + + type: Literal["locReportEncrypted"] + + payload: str + hashed_adv_key: str + + +class LocationReportDecryptedMapping(TypedDict): + """JSON mapping representing a decrypted location report.""" + + type: Literal["locReportDecrypted"] + + payload: str + hashed_adv_key: str + key: KeyPairMapping + + +LocationReportMapping = Union[LocationReportEncryptedMapping, LocationReportDecryptedMapping] + + +class LocationReport(HasHashedPublicKey, Serializable[LocationReportMapping]): """Location report corresponding to a certain `HasHashedPublicKey`.""" def __init__( @@ -66,9 +91,13 @@ class LocationReport(HasHashedPublicKey): """Whether the report is currently decrypted.""" return self._decrypted_data is not None + def can_decrypt(self, key: KeyPair, /) -> bool: + """Whether the report can be decrypted using the given key.""" + return key.hashed_adv_key_bytes == self._hashed_adv_key + def decrypt(self, key: KeyPair) -> None: """Decrypt the report using its corresponding `KeyPair`.""" - if key.hashed_adv_key_bytes != self._hashed_adv_key: + if not self.can_decrypt(key): msg = "Cannot decrypt with this key!" raise ValueError(msg) @@ -163,6 +192,86 @@ class LocationReport(HasHashedPublicKey): status_bytes = self._decrypted_data[1][9:10] return int.from_bytes(status_bytes, "big") + @overload + def to_json( + self, + dst: str | Path | None = None, + /, + *, + include_key: Literal[True], + ) -> LocationReportEncryptedMapping: + pass + + @overload + def to_json( + self, + dst: str | Path | None = None, + /, + *, + include_key: Literal[False], + ) -> LocationReportDecryptedMapping: + pass + + @overload + def to_json( + self, + dst: str | Path | None = None, + /, + *, + include_key: None = None, + ) -> LocationReportMapping: + pass + + @override + def to_json( + self, + dst: str | Path | None = None, + /, + *, + include_key: bool | None = None, + ) -> LocationReportMapping: + if include_key is None: + include_key = self.is_decrypted + + if include_key: + return save_and_return_json( + { + "type": "locReportDecrypted", + "payload": base64.b64encode(self._payload).decode("utf-8"), + "hashed_adv_key": base64.b64encode(self._hashed_adv_key).decode("utf-8"), + "key": self.key.to_json(), + }, + dst, + ) + return save_and_return_json( + { + "type": "locReportEncrypted", + "payload": base64.b64encode(self._payload).decode("utf-8"), + "hashed_adv_key": base64.b64encode(self._hashed_adv_key).decode("utf-8"), + }, + dst, + ) + + @classmethod + @override + def from_json(cls, val: str | Path | LocationReportMapping, /) -> LocationReport: + val = read_data_json(val) + assert val["type"] == "locReportEncrypted" or val["type"] == "locReportDecrypted" + + try: + report = cls( + payload=base64.b64decode(val["payload"]), + hashed_adv_key=base64.b64decode(val["hashed_adv_key"]), + ) + if val["type"] == "locReportDecrypted": + key = KeyPair.from_json(val["key"]) + report.decrypt(key) + except KeyError as e: + msg = f"Failed to restore account data: {e}" + raise ValueError(msg) from None + else: + return report + @override def __eq__(self, other: object) -> bool: """ diff --git a/findmy/util/abc.py b/findmy/util/abc.py index 101e11a..a88da3b 100644 --- a/findmy/util/abc.py +++ b/findmy/util/abc.py @@ -5,8 +5,13 @@ from __future__ import annotations import asyncio import logging from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import TYPE_CHECKING, Generic, Self, TypeVar -logging.getLogger(__name__) +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) class Closable(ABC): @@ -38,16 +43,37 @@ class Closable(ABC): pass -class Serializable(ABC): +_T = TypeVar("_T", bound=Mapping) + + +class Serializable(Generic[_T], ABC): """ABC for serializable classes.""" @abstractmethod - def serialize(self) -> dict: - """Serialize the object to a JSON-serializable dictionary.""" + def to_json(self, dst: str | Path | None = None, /) -> _T: + """ + Export the current state of the object as a JSON-serializable dictionary. + + If an argument is provided, the output will also be written to that file. + + The output of this method is guaranteed to be JSON-serializable, and passing + the return value of this function as an argument to `Serializable.from_json` + will always result in an exact copy of the internal state as it was when exported. + + You are encouraged to save and load object states to and from disk whenever possible, + to prevent unnecessary API calls or otherwise unexpected behavior. + """ raise NotImplementedError @classmethod @abstractmethod - def deserialize(cls, data: dict) -> Serializable: - """Deserialize the object from a JSON-serializable dictionary.""" + def from_json(cls, val: str | Path | _T, /) -> Self: + """ + Restore state from a previous `Closable.to_json` export. + + If given a str or Path, it must point to a json file from `Serializable.to_json`. + Otherwise, it should be the Mapping itself. + + See `Serializable.to_json` for more information. + """ raise NotImplementedError diff --git a/findmy/util/files.py b/findmy/util/files.py new file mode 100644 index 0000000..e58bfd9 --- /dev/null +++ b/findmy/util/files.py @@ -0,0 +1,34 @@ +"""Utilities to simplify reading and writing data from and to files.""" + +from __future__ import annotations + +import json +from collections.abc import Mapping +from pathlib import Path +from typing import TypeVar, cast + +T = TypeVar("T", bound=Mapping) + + +def save_and_return_json(data: T, dst: str | Path | None) -> T: + """Save and return a JSON-serializable data structure.""" + if dst is None: + return data + + if isinstance(dst, str): + dst = Path(dst) + + dst.write_text(json.dumps(data, indent=4)) + + return data + + +def read_data_json(val: str | Path | T) -> T: + """Read JSON data from a file if a path is passed, or return the argument itself.""" + if isinstance(val, str): + val = Path(val) + + if isinstance(val, Path): + val = cast("T", json.loads(val.read_text())) + + return val diff --git a/findmy/util/session.py b/findmy/util/session.py new file mode 100644 index 0000000..a545292 --- /dev/null +++ b/findmy/util/session.py @@ -0,0 +1,144 @@ +"""Logic related to serializable classes.""" + +from __future__ import annotations + +import random +from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar, Union + +from findmy.util.abc import Closable, Serializable + +if TYPE_CHECKING: + from pathlib import Path + from types import TracebackType + +_S = TypeVar("_S", bound=Serializable) +_SC = TypeVar("_SC", bound=Union[Serializable, Closable]) + + +class _BaseSessionManager(Generic[_SC]): + """Base class for session managers.""" + + def __init__(self) -> None: + self._sessions: dict[_SC, str | Path | None] = {} + + def _add(self, obj: _SC, path: str | Path | None) -> None: + self._sessions[obj] = path + + def remove(self, obj: _SC) -> None: + self._sessions.pop(obj, None) + + def save(self) -> None: + for obj, path in self._sessions.items(): + if isinstance(obj, Serializable): + obj.to_json(path) + + async def close(self) -> None: + for obj in self._sessions: + if isinstance(obj, Closable): + await obj.close() + + async def save_and_close(self) -> None: + for obj, path in self._sessions.items(): + if isinstance(obj, Serializable): + obj.to_json(path) + if isinstance(obj, Closable): + await obj.close() + + def get_random(self) -> _SC: + if not self._sessions: + msg = "No objects in the session manager." + raise ValueError(msg) + return random.choice(list(self._sessions.keys())) # noqa: S311 + + def __len__(self) -> int: + return len(self._sessions) + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + _exc_type: type[BaseException] | None, + _exc_val: BaseException | None, + _exc_tb: TracebackType | None, + ) -> None: + self.save() + + +class MixedSessionManager(_BaseSessionManager[Union[Serializable, Closable]]): + """Allows any Serializable or Closable object.""" + + def new( + self, + c_type: type[_SC], + path: str | Path | None = None, + /, + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> _SC: + """Add an object to the manager by instantiating it using its constructor.""" + obj = c_type(*args, **kwargs) + if isinstance(obj, Serializable) and path is not None: + obj.to_json(path) + self._add(obj, path) + return obj + + def add_from_json( + self, + c_type: type[_S], + path: str | Path, + /, + **kwargs: Any, # noqa: ANN401 + ) -> _S: + """Add an object to the manager by deserializing it from its JSON representation.""" + obj = c_type.from_json(path, **kwargs) + self._add(obj, path) + return obj + + def add(self, obj: Serializable | Closable, path: str | Path | None = None, /) -> None: + """Add an object to the session manager.""" + self._add(obj, path) + + +class UniformSessionManager(Generic[_SC], _BaseSessionManager[_SC]): + """Only allows a single type of Serializable object.""" + + def __init__(self, obj_type: type[_SC]) -> None: + """Create a new session manager.""" + super().__init__() + self._obj_type = obj_type + + def new( + self, + path: str | Path | None = None, + /, + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> _SC: + """Add an object to the manager by instantiating it using its constructor.""" + obj = self._obj_type(*args, **kwargs) + if isinstance(obj, Serializable) and path is not None: + obj.to_json(path) + self._add(obj, path) + return obj + + def add_from_json( + self, + path: str | Path, + /, + **kwargs: Any, # noqa: ANN401 + ) -> _SC: + """Add an object to the manager by deserializing it from its JSON representation.""" + if not issubclass(self._obj_type, Serializable): + msg = "Can only add objects of type Serializable." + raise TypeError(msg) + obj = self._obj_type.from_json(path, **kwargs) + self._add(obj, path) + return obj + + def add(self, obj: _SC, path: str | Path | None = None, /) -> None: + """Add an object to the session manager.""" + if not isinstance(obj, self._obj_type): + msg = f"Object must be of type {self._obj_type.__name__}" + raise TypeError(msg) + self._add(obj, path)