mirror of
https://github.com/malmeloo/FindMy.py.git
synced 2026-04-17 21:53:57 +02:00
feat: make LocationReport serializable
This commit is contained in:
@@ -8,7 +8,7 @@ import logging
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, cast, overload
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Union, cast, overload
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
@@ -16,17 +16,42 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from typing_extensions import override
|
||||
|
||||
from findmy.accessory import RollingKeyPairSource
|
||||
from findmy.keys import HasHashedPublicKey, KeyPair
|
||||
from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping
|
||||
from findmy.util.abc import Serializable
|
||||
from findmy.util.files import read_data_json, save_and_return_json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from .account import AsyncAppleAccount
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocationReport(HasHashedPublicKey):
|
||||
class LocationReportEncryptedMapping(TypedDict):
|
||||
"""JSON mapping representing an encrypted location report."""
|
||||
|
||||
type: Literal["locReportEncrypted"]
|
||||
|
||||
payload: str
|
||||
hashed_adv_key: str
|
||||
|
||||
|
||||
class LocationReportDecryptedMapping(TypedDict):
|
||||
"""JSON mapping representing a decrypted location report."""
|
||||
|
||||
type: Literal["locReportDecrypted"]
|
||||
|
||||
payload: str
|
||||
hashed_adv_key: str
|
||||
key: KeyPairMapping
|
||||
|
||||
|
||||
LocationReportMapping = Union[LocationReportEncryptedMapping, LocationReportDecryptedMapping]
|
||||
|
||||
|
||||
class LocationReport(HasHashedPublicKey, Serializable):
|
||||
"""Location report corresponding to a certain `HasHashedPublicKey`."""
|
||||
|
||||
def __init__(
|
||||
@@ -66,9 +91,13 @@ class LocationReport(HasHashedPublicKey):
|
||||
"""Whether the report is currently decrypted."""
|
||||
return self._decrypted_data is not None
|
||||
|
||||
def can_decrypt(self, key: KeyPair, /) -> bool:
|
||||
"""Whether the report can be decrypted using the given key."""
|
||||
return key.hashed_adv_key_bytes == self._hashed_adv_key
|
||||
|
||||
def decrypt(self, key: KeyPair) -> None:
|
||||
"""Decrypt the report using its corresponding `KeyPair`."""
|
||||
if key.hashed_adv_key_bytes != self._hashed_adv_key:
|
||||
if not self.can_decrypt(key):
|
||||
msg = "Cannot decrypt with this key!"
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -163,6 +192,86 @@ class LocationReport(HasHashedPublicKey):
|
||||
status_bytes = self._decrypted_data[1][9:10]
|
||||
return int.from_bytes(status_bytes, "big")
|
||||
|
||||
@overload
|
||||
def to_json(
|
||||
self,
|
||||
dst: str | Path | None = None,
|
||||
/,
|
||||
*,
|
||||
include_key: Literal[True],
|
||||
) -> LocationReportEncryptedMapping:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def to_json(
|
||||
self,
|
||||
dst: str | Path | None = None,
|
||||
/,
|
||||
*,
|
||||
include_key: Literal[False],
|
||||
) -> LocationReportDecryptedMapping:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def to_json(
|
||||
self,
|
||||
dst: str | Path | None = None,
|
||||
/,
|
||||
*,
|
||||
include_key: None = None,
|
||||
) -> LocationReportMapping:
|
||||
pass
|
||||
|
||||
@override
|
||||
def to_json(
|
||||
self,
|
||||
dst: str | Path | None = None,
|
||||
/,
|
||||
*,
|
||||
include_key: bool | None = None,
|
||||
) -> LocationReportMapping:
|
||||
if include_key is None:
|
||||
include_key = self.is_decrypted
|
||||
|
||||
if include_key:
|
||||
return save_and_return_json(
|
||||
{
|
||||
"type": "locReportDecrypted",
|
||||
"payload": base64.b64encode(self._payload).decode("utf-8"),
|
||||
"hashed_adv_key": base64.b64encode(self._hashed_adv_key).decode("utf-8"),
|
||||
"key": self.key.to_json(),
|
||||
},
|
||||
dst,
|
||||
)
|
||||
return save_and_return_json(
|
||||
{
|
||||
"type": "locReportEncrypted",
|
||||
"payload": base64.b64encode(self._payload).decode("utf-8"),
|
||||
"hashed_adv_key": base64.b64encode(self._hashed_adv_key).decode("utf-8"),
|
||||
},
|
||||
dst,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_json(cls, val: str | Path | LocationReportMapping, /) -> LocationReport:
|
||||
val = read_data_json(val)
|
||||
assert val["type"] == "locReportEncrypted" or val["type"] == "locReportDecrypted"
|
||||
|
||||
try:
|
||||
report = cls(
|
||||
payload=base64.b64decode(val["payload"]),
|
||||
hashed_adv_key=base64.b64decode(val["hashed_adv_key"]),
|
||||
)
|
||||
if val["type"] == "locReportDecrypted":
|
||||
key = KeyPair.from_json(val["key"])
|
||||
report.decrypt(key)
|
||||
except KeyError as e:
|
||||
msg = f"Failed to restore account data: {e}"
|
||||
raise ValueError(msg) from None
|
||||
else:
|
||||
return report
|
||||
|
||||
@override
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user