diff --git a/findmy/reports/reports.py b/findmy/reports/reports.py index df532fc..cd3dd59 100644 --- a/findmy/reports/reports.py +++ b/findmy/reports/reports.py @@ -8,7 +8,7 @@ import logging import struct from collections import defaultdict from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, cast, overload +from typing import TYPE_CHECKING, Literal, TypedDict, Union, cast, overload from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec @@ -16,17 +16,42 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from typing_extensions import override from findmy.accessory import RollingKeyPairSource -from findmy.keys import HasHashedPublicKey, KeyPair +from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping +from findmy.util.abc import Serializable +from findmy.util.files import read_data_json, save_and_return_json if TYPE_CHECKING: from collections.abc import Sequence + from pathlib import Path from .account import AsyncAppleAccount logger = logging.getLogger(__name__) -class LocationReport(HasHashedPublicKey): +class LocationReportEncryptedMapping(TypedDict): + """JSON mapping representing an encrypted location report.""" + + type: Literal["locReportEncrypted"] + + payload: str + hashed_adv_key: str + + +class LocationReportDecryptedMapping(TypedDict): + """JSON mapping representing a decrypted location report.""" + + type: Literal["locReportDecrypted"] + + payload: str + hashed_adv_key: str + key: KeyPairMapping + + +LocationReportMapping = Union[LocationReportEncryptedMapping, LocationReportDecryptedMapping] + + +class LocationReport(HasHashedPublicKey, Serializable): """Location report corresponding to a certain `HasHashedPublicKey`.""" def __init__( @@ -66,9 +91,13 @@ class LocationReport(HasHashedPublicKey): """Whether the report is currently decrypted.""" return self._decrypted_data is not None + def can_decrypt(self, key: KeyPair, /) -> bool: + """Whether the report can be decrypted using the given key.""" + return key.hashed_adv_key_bytes == self._hashed_adv_key + def decrypt(self, key: KeyPair) -> None: """Decrypt the report using its corresponding `KeyPair`.""" - if key.hashed_adv_key_bytes != self._hashed_adv_key: + if not self.can_decrypt(key): msg = "Cannot decrypt with this key!" raise ValueError(msg) @@ -163,6 +192,86 @@ class LocationReport(HasHashedPublicKey): status_bytes = self._decrypted_data[1][9:10] return int.from_bytes(status_bytes, "big") + @overload + def to_json( + self, + dst: str | Path | None = None, + /, + *, + include_key: Literal[True], + ) -> LocationReportEncryptedMapping: + pass + + @overload + def to_json( + self, + dst: str | Path | None = None, + /, + *, + include_key: Literal[False], + ) -> LocationReportDecryptedMapping: + pass + + @overload + def to_json( + self, + dst: str | Path | None = None, + /, + *, + include_key: None = None, + ) -> LocationReportMapping: + pass + + @override + def to_json( + self, + dst: str | Path | None = None, + /, + *, + include_key: bool | None = None, + ) -> LocationReportMapping: + if include_key is None: + include_key = self.is_decrypted + + if include_key: + return save_and_return_json( + { + "type": "locReportDecrypted", + "payload": base64.b64encode(self._payload).decode("utf-8"), + "hashed_adv_key": base64.b64encode(self._hashed_adv_key).decode("utf-8"), + "key": self.key.to_json(), + }, + dst, + ) + return save_and_return_json( + { + "type": "locReportEncrypted", + "payload": base64.b64encode(self._payload).decode("utf-8"), + "hashed_adv_key": base64.b64encode(self._hashed_adv_key).decode("utf-8"), + }, + dst, + ) + + @classmethod + @override + def from_json(cls, val: str | Path | LocationReportMapping, /) -> LocationReport: + val = read_data_json(val) + assert val["type"] == "locReportEncrypted" or val["type"] == "locReportDecrypted" + + try: + report = cls( + payload=base64.b64decode(val["payload"]), + hashed_adv_key=base64.b64decode(val["hashed_adv_key"]), + ) + if val["type"] == "locReportDecrypted": + key = KeyPair.from_json(val["key"]) + report.decrypt(key) + except KeyError as e: + msg = f"Failed to restore account data: {e}" + raise ValueError(msg) from None + else: + return report + @override def __eq__(self, other: object) -> bool: """