diff --git a/README.md b/README.md index 7af0842..8e54249 100644 --- a/README.md +++ b/README.md @@ -78,9 +78,11 @@ before committing it. There are several other cool projects based on this library! Some of them have been listed below, make sure to check them out as well. -* [OfflineFindRecovery](https://github.com/hajekj/OfflineFindRecovery) - Set of scripts to be able to precisely locate your lost MacBook via Apple's Offline Find through Bluetooth Low Energy. -* [SwiftFindMy](https://github.com/airy10/SwiftFindMy) - Swift port of FindMy.py -* [FindMy Home Assistant Integration](github.com/krmax44/homeassistant-findmy) +* [OfflineFindRecovery](https://github.com/hajekj/OfflineFindRecovery) - Set of scripts to precisely locate your lost MacBook. +* [SwiftFindMy](https://github.com/airy10/SwiftFindMy) - Swift port of FindMy.py. +* [FindMy Home Assistant (1)](https://github.com/malmeloo/hass-FindMy) - Home Assistant integration made by the author of FindMy.py. +* [FindMy Home Assistant (2)](github.com/krmax44/homeassistant-findmy) - Home Assistant integration made by [krmax44](https://github.com/krmax44). +* [OpenTagViewer](https://github.com/parawanderer/OpenTagViewer) - Android App to locate your AirTags. ## Credits diff --git a/findmy/reports/account.py b/findmy/reports/account.py index 4c3b10d..264d01c 100644 --- a/findmy/reports/account.py +++ b/findmy/reports/account.py @@ -249,13 +249,28 @@ class BaseAppleAccount(Closable, ABC): date_to: datetime | None, ) -> MaybeCoro[list[LocationReport]]: ... + @overload + def fetch_reports( + self, + keys: Sequence[RollingKeyPairSource], + date_from: datetime, + date_to: datetime | None, + ) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ... + @abstractmethod def fetch_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], date_from: datetime, date_to: datetime | None, - ) -> MaybeCoro[list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]]: + ) -> MaybeCoro[ + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ]: """ Fetch location reports for `HasHashedPublicKey`s between `date_from` and `date_end`. @@ -287,12 +302,27 @@ class BaseAppleAccount(Closable, ABC): hours: int = 7 * 24, ) -> MaybeCoro[list[LocationReport]]: ... + @overload @abstractmethod def fetch_last_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: Sequence[RollingKeyPairSource], hours: int = 7 * 24, - ) -> MaybeCoro[list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]]: + ) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ... + + @abstractmethod + def fetch_last_reports( + self, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], + hours: int = 7 * 24, + ) -> MaybeCoro[ + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ]: """ Fetch location reports for a sequence of `HasHashedPublicKey`s for the last `hours` hours. @@ -642,14 +672,29 @@ class AsyncAppleAccount(BaseAppleAccount): date_to: datetime | None, ) -> list[LocationReport]: ... + @overload + async def fetch_reports( + self, + keys: Sequence[RollingKeyPairSource], + date_from: datetime, + date_to: datetime | None, + ) -> dict[RollingKeyPairSource, list[LocationReport]]: ... + @require_login_state(LoginState.LOGGED_IN) @override async def fetch_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], date_from: datetime, date_to: datetime | None, - ) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]: + ) -> ( + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ): """See `BaseAppleAccount.fetch_reports`.""" date_to = date_to or datetime.now().astimezone() @@ -680,13 +725,27 @@ class AsyncAppleAccount(BaseAppleAccount): hours: int = 7 * 24, ) -> list[LocationReport]: ... + @overload + async def fetch_last_reports( + self, + keys: Sequence[RollingKeyPairSource], + hours: int = 7 * 24, + ) -> dict[RollingKeyPairSource, list[LocationReport]]: ... + @require_login_state(LoginState.LOGGED_IN) @override async def fetch_last_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], hours: int = 7 * 24, - ) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]: + ) -> ( + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ): """See `BaseAppleAccount.fetch_last_reports`.""" end = datetime.now(tz=timezone.utc) start = end - timedelta(hours=hours) @@ -1042,13 +1101,28 @@ class AppleAccount(BaseAppleAccount): date_to: datetime | None, ) -> list[LocationReport]: ... + @overload + def fetch_reports( + self, + keys: Sequence[RollingKeyPairSource], + date_from: datetime, + date_to: datetime | None, + ) -> dict[RollingKeyPairSource, list[LocationReport]]: ... + @override def fetch_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], date_from: datetime, date_to: datetime | None, - ) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]: + ) -> ( + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ): """See `AsyncAppleAccount.fetch_reports`.""" coro = self._asyncacc.fetch_reports(keys, date_from, date_to) return self._evt_loop.run_until_complete(coro) @@ -1074,12 +1148,26 @@ class AppleAccount(BaseAppleAccount): hours: int = 7 * 24, ) -> list[LocationReport]: ... + @overload + def fetch_last_reports( + self, + keys: Sequence[RollingKeyPairSource], + hours: int = 7 * 24, + ) -> dict[RollingKeyPairSource, list[LocationReport]]: ... + @override def fetch_last_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], hours: int = 7 * 24, - ) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]: + ) -> ( + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ): """See `AsyncAppleAccount.fetch_last_reports`.""" coro = self._asyncacc.fetch_last_reports(keys, hours) return self._evt_loop.run_until_complete(coro) diff --git a/findmy/reports/anisette.py b/findmy/reports/anisette.py index bcf2bf1..d56f52d 100644 --- a/findmy/reports/anisette.py +++ b/findmy/reports/anisette.py @@ -205,7 +205,7 @@ class RemoteAnisetteProvider(BaseAnisetteProvider): if self._anisette_data is None or time.time() >= self._anisette_data_expires_at: logging.info("Fetching anisette data from %s", self._server_url) - r = await self._http.get(self._server_url) + r = await self._http.get(self._server_url, auto_retry=True) self._anisette_data = r.json() self._anisette_data_expires_at = time.time() + self._ANISETTE_DATA_VALID_FOR diff --git a/findmy/reports/reports.py b/findmy/reports/reports.py index 76af893..e910545 100644 --- a/findmy/reports/reports.py +++ b/findmy/reports/reports.py @@ -6,8 +6,9 @@ import base64 import hashlib import logging import struct +from collections import defaultdict from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, cast, overload from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec @@ -260,12 +261,27 @@ class LocationReportsFetcher: device: RollingKeyPairSource, ) -> list[LocationReport]: ... + @overload async def fetch_reports( self, date_from: datetime, date_to: datetime, - device: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, - ) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]: + device: Sequence[RollingKeyPairSource], + ) -> dict[RollingKeyPairSource, list[LocationReport]]: ... + + async def fetch_reports( + self, + date_from: datetime, + date_to: datetime, + device: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], + ) -> ( + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ): """ Fetch location reports for a certain device. @@ -276,45 +292,71 @@ class LocationReportsFetcher: When ``device`` is a :class:`.RollingKeyPairSource`, it will return a list of location reports corresponding to that source. """ - # single key + key_devs: ( + dict[HasHashedPublicKey, HasHashedPublicKey] + | dict[HasHashedPublicKey, RollingKeyPairSource] + ) = {} if isinstance(device, HasHashedPublicKey): - return await self._fetch_reports(date_from, date_to, [device]) - - # key generator - # add 12h margin to the generator - if isinstance(device, RollingKeyPairSource): - keys = list( - device.keys_between( + # single key + key_devs = {device: device} + elif isinstance(device, list) and all(isinstance(x, HasHashedPublicKey) for x in device): + # multiple static keys + device = cast(list[HasHashedPublicKey], device) + key_devs = {key: key for key in device} + elif isinstance(device, RollingKeyPairSource): + # key generator + # add 12h margin to the generator + key_devs = { + key: device + for key in device.keys_between( date_from - timedelta(hours=12), date_to + timedelta(hours=12), - ), - ) + ) + } + elif isinstance(device, list) and all(isinstance(x, RollingKeyPairSource) for x in device): + # multiple key generators + # add 12h margin to each generator + device = cast(list[RollingKeyPairSource], device) + key_devs = { + key: dev + for dev in device + for key in dev.keys_between( + date_from - timedelta(hours=12), + date_to + timedelta(hours=12), + ) + } else: - keys = device + msg = "Unknown device type: %s" + raise ValueError(msg, type(device)) # sequence of keys (fetch 256 max at a time) - reports: list[LocationReport] = [] + key_reports: dict[HasHashedPublicKey, list[LocationReport]] = {} + keys = list(key_devs.keys()) for key_offset in range(0, len(keys), 256): - chunk = keys[key_offset : key_offset + 256] - reports.extend(await self._fetch_reports(date_from, date_to, chunk)) + chunk_keys = keys[key_offset : key_offset + 256] + chunk_reports = await self._fetch_reports(date_from, date_to, chunk_keys) + key_reports |= chunk_reports - if isinstance(device, RollingKeyPairSource): - return reports + # combine (key -> list[report]) and (key -> device) into (device -> list[report]) + device_reports = defaultdict(list) + for key, reports in key_reports.items(): + device_reports[key_devs[key]].extend(reports) + for dev in device_reports: + device_reports[dev] = sorted(device_reports[dev]) - res: dict[HasHashedPublicKey, list[LocationReport]] = {key: [] for key in keys} - for report in reports: - for key, reports in res.items(): - if key.hashed_adv_key_bytes == report.hashed_adv_key_bytes: - reports.append(report) - break - return res + # result + if isinstance(device, (HasHashedPublicKey, RollingKeyPairSource)): + # single key or generator + return device_reports[device] + # multiple static keys or key generators + return device_reports async def _fetch_reports( self, date_from: datetime, date_to: datetime, keys: Sequence[HasHashedPublicKey], - ) -> list[LocationReport]: + ) -> dict[HasHashedPublicKey, list[LocationReport]]: logging.debug("Fetching reports for %s keys", len(keys)) # lock requested time range to the past 7 days, +- 12 hours, then filter the response. @@ -327,7 +369,7 @@ class LocationReportsFetcher: data = await self._account.fetch_raw_reports(start_date, end_date, ids) id_to_key: dict[bytes, HasHashedPublicKey] = {key.hashed_adv_key_bytes: key for key in keys} - reports: list[LocationReport] = [] + reports: dict[HasHashedPublicKey, list[LocationReport]] = defaultdict(list) for report in data.get("results", []): payload = base64.b64decode(report["payload"]) hashed_adv_key = base64.b64decode(report["id"]) @@ -347,6 +389,6 @@ class LocationReportsFetcher: if loc_report.timestamp < date_from or loc_report.timestamp > date_to: continue - reports.append(loc_report) + reports[key].append(loc_report) return reports diff --git a/findmy/util/http.py b/findmy/util/http.py index a6da8ac..f954e49 100644 --- a/findmy/util/http.py +++ b/findmy/util/http.py @@ -2,10 +2,12 @@ from __future__ import annotations +import asyncio import json import logging from typing import Any, TypedDict, cast +import aiohttp from aiohttp import BasicAuth, ClientSession, ClientTimeout from typing_extensions import Unpack, override @@ -18,6 +20,7 @@ logging.getLogger(__name__) class _RequestOptions(TypedDict, total=False): json: dict[str, Any] | None headers: dict[str, str] + auto_retry: bool data: bytes @@ -108,13 +111,32 @@ class HttpSession(Closable): kwargs["auth"] = BasicAuth(auth[0], auth[1]) options = cast(_AiohttpRequestOptions, kwargs) - async with await session.request( - method, - url, - ssl=False, - **options, - ) as r: - return HttpResponse(r.status, await r.content.read()) + auto_retry = kwargs.pop("auto_retry", False) + + retry_count = 1 + while True: # if auto_retry is set, raise for status and retry on error + try: + async with await session.request( + method, + url, + ssl=False, + raise_for_status=auto_retry, + **options, + ) as r: + return HttpResponse(r.status, await r.content.read()) + except aiohttp.ClientError as e: # noqa: PERF203 + if not auto_retry or retry_count > 3: + raise e from None + + retry_after = 5 * retry_count + logging.warning( + "Error while making HTTP request; retrying after %i seconds. %s", + retry_after, + e, + ) + await asyncio.sleep(retry_after) + + retry_count += 1 async def get(self, url: str, **kwargs: Unpack[_HttpRequestOptions]) -> HttpResponse: """Alias for `HttpSession.request("GET", ...)`."""