From 128ce84749cb32ef41f36a0fb91e29f857f0c01a Mon Sep 17 00:00:00 2001 From: Mike A Date: Sat, 10 Feb 2024 21:26:55 +0100 Subject: [PATCH] Make key fetcher stateful Prevents unnecessary HTTP session closing --- findmy/reports/account.py | 49 +++++++++---- findmy/reports/reports.py | 146 +++++++++++++++++++++++++------------- 2 files changed, 131 insertions(+), 64 deletions(-) diff --git a/findmy/reports/account.py b/findmy/reports/account.py index 168a494..4b06b60 100644 --- a/findmy/reports/account.py +++ b/findmy/reports/account.py @@ -27,7 +27,7 @@ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from findmy.util import HttpSession, decode_plist from findmy.util.errors import InvalidCredentialsError, UnhandledProtocolError -from .reports import KeyReport, fetch_reports +from .reports import LocationReport, LocationReportsFetcher from .state import LoginState, require_login_state from .twofactor import ( AsyncSecondFactorMethod, @@ -190,8 +190,8 @@ class BaseAppleAccount(ABC): self, keys: Sequence[KeyPair], date_from: datetime, - date_to: datetime, - ) -> dict[KeyPair, list[KeyReport]]: + date_to: datetime | None, + ) -> dict[KeyPair, list[LocationReport]]: """ Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`. @@ -204,7 +204,7 @@ class BaseAppleAccount(ABC): self, keys: Sequence[KeyPair], hours: int = 7 * 24, - ) -> dict[KeyPair, list[KeyReport]]: + ) -> dict[KeyPair, list[LocationReport]]: """ Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours. @@ -251,6 +251,7 @@ class AsyncAppleAccount(BaseAppleAccount): self._account_info: _AccountInfo | None = None self._http: HttpSession = HttpSession() + self._reports: LocationReportsFetcher = LocationReportsFetcher(self) def _set_login_state( self, @@ -411,20 +412,38 @@ class AsyncAppleAccount(BaseAppleAccount): # AUTHENTICATED -> LOGGED_IN return await self._login_mobileme() + @require_login_state(LoginState.LOGGED_IN) + async def fetch_raw_reports(self, start: int, end: int, ids: list[str]) -> dict[str, Any]: + """Make a request for location reports, returning raw data.""" + auth = ( + self._login_state_data["dsid"], + self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"], + ) + data = {"search": [{"startDate": start, "endDate": end, "ids": ids}]} + r = await self._http.post( + "https://gateway.icloud.com/acsnservice/fetch", + auth=auth, + headers=await self.get_anisette_headers(), + json=data, + ) + resp = r.json() + if not r.ok or resp["statusCode"] != "200": + msg = f"Failed to fetch reports: {resp['statusCode']}" + raise UnhandledProtocolError(msg) + + return resp + @require_login_state(LoginState.LOGGED_IN) async def fetch_reports( self, keys: Sequence[KeyPair], date_from: datetime, - date_to: datetime, - ) -> dict[KeyPair, list[KeyReport]]: + date_to: datetime | None, + ) -> dict[KeyPair, list[LocationReport]]: """See `BaseAppleAccount.fetch_reports`.""" - anisette_headers = await self.get_anisette_headers() + date_to = date_to or datetime.now().astimezone() - return await fetch_reports( - self._login_state_data["dsid"], - self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"], - anisette_headers, + return await self._reports.fetch_reports( date_from, date_to, keys, @@ -435,7 +454,7 @@ class AsyncAppleAccount(BaseAppleAccount): self, keys: Sequence[KeyPair], hours: int = 7 * 24, - ) -> dict[KeyPair, list[KeyReport]]: + ) -> dict[KeyPair, list[LocationReport]]: """See `BaseAppleAccount.fetch_last_reports`.""" end = datetime.now(tz=timezone.utc) start = end - timedelta(hours=hours) @@ -737,8 +756,8 @@ class AppleAccount(BaseAppleAccount): self, keys: Sequence[KeyPair], date_from: datetime, - date_to: datetime, - ) -> dict[KeyPair, list[KeyReport]]: + date_to: datetime | None, + ) -> dict[KeyPair, list[LocationReport]]: """See `AsyncAppleAccount.fetch_reports`.""" coro = self._asyncacc.fetch_reports(keys, date_from, date_to) return self._loop.run_until_complete(coro) @@ -747,7 +766,7 @@ class AppleAccount(BaseAppleAccount): self, keys: Sequence[KeyPair], hours: int = 7 * 24, - ) -> dict[KeyPair, list[KeyReport]]: + ) -> dict[KeyPair, list[LocationReport]]: """See `AsyncAppleAccount.fetch_last_reports`.""" coro = self._asyncacc.fetch_last_reports(keys, hours) return self._loop.run_until_complete(coro) diff --git a/findmy/reports/reports.py b/findmy/reports/reports.py index b85b54b..e224bd7 100644 --- a/findmy/reports/reports.py +++ b/findmy/reports/reports.py @@ -5,22 +5,27 @@ import base64 import hashlib import struct from datetime import datetime, timezone -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Sequence, TypedDict, overload from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from typing_extensions import Unpack -from findmy.util import HttpSession +from findmy.keys import KeyPair +from findmy.util.http import HttpSession if TYPE_CHECKING: - from findmy.keys import KeyPair + from .account import AsyncAppleAccount _session = HttpSession() -class ReportsError(RuntimeError): - """Raised when an error occurs while looking up reports.""" +class _FetcherConfig(TypedDict): + user_id: str + device_id: str + dsid: str + search_party_token: str def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes: @@ -46,7 +51,7 @@ def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes: return decryptor.update(enc_data) + decryptor.finalize() -class KeyReport: +class LocationReport: """Location report corresponding to a certain `KeyPair`.""" def __init__( # noqa: PLR0913 @@ -119,7 +124,7 @@ class KeyReport: publish_date: datetime, description: str, payload: bytes, - ) -> KeyReport: + ) -> LocationReport: """ Create a `KeyReport` from fields and a payload as reported by Apple. @@ -134,7 +139,7 @@ class KeyReport: confidence = int.from_bytes(data[8:9], "big") status = int.from_bytes(data[9:10], "big") - return KeyReport( + return cls( key, publish_date, timestamp, @@ -145,14 +150,14 @@ class KeyReport: status, ) - def __lt__(self, other: KeyReport) -> bool: + def __lt__(self, other: LocationReport) -> bool: """ Compare against another `KeyReport`. A `KeyReport` is said to be "less than" another `KeyReport` iff its recorded timestamp is strictly less than the other report. """ - if isinstance(other, KeyReport): + if isinstance(other, LocationReport): return self.timestamp < other.timestamp return NotImplemented @@ -164,47 +169,90 @@ class KeyReport: ) -async def fetch_reports( # noqa: PLR0913 - dsid: str, - search_party_token: str, - anisette_headers: dict[str, str], - date_from: datetime, - date_to: datetime, - keys: Sequence[KeyPair], -) -> dict[KeyPair, list[KeyReport]]: - """Look up reports for given `KeyPair`s.""" - start_date = date_from.timestamp() * 1000 - end_date = date_to.timestamp() * 1000 - ids = [key.hashed_adv_key_b64 for key in keys] - data = {"search": [{"startDate": start_date, "endDate": end_date, "ids": ids}]} +class LocationReportsFetcher: + """Fetcher class to retrieve location reports.""" - # TODO(malmeloo): do not create a new session every time - # https://github.com/malmeloo/FindMy.py/issues/3 - r = await _session.post( - "https://gateway.icloud.com/acsnservice/fetch", - auth=(dsid, search_party_token), - headers=anisette_headers, - json=data, - ) - resp = r.json() - if not r.ok or resp["statusCode"] != "200": - msg = f"Failed to fetch reports: {resp['statusCode']}" - raise ReportsError(msg) - await _session.close() + def __init__(self, account: AsyncAppleAccount) -> None: + """ + Initialize the fetcher. - reports: dict[KeyPair, list[KeyReport]] = {key: [] for key in keys} - id_to_key: dict[str, KeyPair] = {key.hashed_adv_key_b64: key for key in keys} + :param account: Apple account. + """ + self._account: AsyncAppleAccount = account - for report in resp.get("results", []): - key = id_to_key[report["id"]] - date_published = datetime.fromtimestamp( - report.get("datePublished", 0) / 1000, - tz=timezone.utc, - ) - description = report.get("description", "") - payload = base64.b64decode(report["payload"]) + self._http: HttpSession = HttpSession() - r = KeyReport.from_payload(key, date_published, description, payload) - reports[key].append(r) + self._config: _FetcherConfig | None = None - return {key: sorted(reps) for key, reps in reports.items()} + def apply_config(self, **conf: Unpack[_FetcherConfig]) -> None: + """Configure internal variables necessary to make reports fetching calls.""" + self._config = conf + + @overload + async def fetch_reports( + self, + date_from: datetime, + date_to: datetime, + device: KeyPair, + ) -> list[LocationReport]: + ... + + @overload + async def fetch_reports( + self, + date_from: datetime, + date_to: datetime, + device: Sequence[KeyPair], + ) -> dict[KeyPair, list[LocationReport]]: + ... + + async def fetch_reports( + self, + date_from: datetime, + date_to: datetime, + device: KeyPair | Sequence[KeyPair], + ) -> list[LocationReport] | dict[KeyPair, list[LocationReport]]: + """ + Fetch location reports for a certain device. + + When ``device`` is a single :class:`.KeyPair`, this method will return + a list of location reports corresponding to that pair. + When ``device`` is a sequence of :class:`.KeyPair`s, it will return a dictionary + with the :class:`.KeyPair` as key, and a list of location reports as value. + """ + # single KeyPair + if isinstance(device, KeyPair): + return await self._fetch_reports(date_from, date_to, [device]) + + # sequence of KeyPairs + reports = await self._fetch_reports(date_from, date_to, device) + res: dict[KeyPair, list[LocationReport]] = {key: [] for key in device} + for report in reports: + res[report.key].append(report) + return res + + async def _fetch_reports( + self, + date_from: datetime, + date_to: datetime, + keys: Sequence[KeyPair], + ) -> list[LocationReport]: + start_date = int(date_from.timestamp() * 1000) + end_date = int(date_to.timestamp() * 1000) + ids = [key.hashed_adv_key_b64 for key in keys] + data = await self._account.fetch_raw_reports(start_date, end_date, ids) + + id_to_key: dict[str, KeyPair] = {key.hashed_adv_key_b64: key for key in keys} + reports: list[LocationReport] = [] + for report in data.get("results", []): + key = id_to_key[report["id"]] + date_published = datetime.fromtimestamp( + report.get("datePublished", 0) / 1000, + tz=timezone.utc, + ) + description = report.get("description", "") + payload = base64.b64decode(report["payload"]) + + reports.append(LocationReport.from_payload(key, date_published, description, payload)) + + return reports