diff --git a/findmy/reports/account.py b/findmy/reports/account.py index 59dbe3d..264d01c 100644 --- a/findmy/reports/account.py +++ b/findmy/reports/account.py @@ -15,7 +15,6 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Sequence, TypedDict, TypeVar, cast, @@ -49,6 +48,8 @@ from .twofactor import ( ) if TYPE_CHECKING: + from collections.abc import Sequence + from findmy.accessory import RollingKeyPairSource from findmy.keys import HasHashedPublicKey from findmy.util.types import MaybeCoro @@ -248,13 +249,28 @@ class BaseAppleAccount(Closable, ABC): date_to: datetime | None, ) -> MaybeCoro[list[LocationReport]]: ... + @overload + def fetch_reports( + self, + keys: Sequence[RollingKeyPairSource], + date_from: datetime, + date_to: datetime | None, + ) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ... + @abstractmethod def fetch_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], date_from: datetime, date_to: datetime | None, - ) -> MaybeCoro[list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]]: + ) -> MaybeCoro[ + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ]: """ Fetch location reports for `HasHashedPublicKey`s between `date_from` and `date_end`. @@ -286,12 +302,27 @@ class BaseAppleAccount(Closable, ABC): hours: int = 7 * 24, ) -> MaybeCoro[list[LocationReport]]: ... + @overload @abstractmethod def fetch_last_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: Sequence[RollingKeyPairSource], hours: int = 7 * 24, - ) -> MaybeCoro[list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]]: + ) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ... + + @abstractmethod + def fetch_last_reports( + self, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], + hours: int = 7 * 24, + ) -> MaybeCoro[ + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ]: """ Fetch location reports for a sequence of `HasHashedPublicKey`s for the last `hours` hours. @@ -641,14 +672,29 @@ class AsyncAppleAccount(BaseAppleAccount): date_to: datetime | None, ) -> list[LocationReport]: ... + @overload + async def fetch_reports( + self, + keys: Sequence[RollingKeyPairSource], + date_from: datetime, + date_to: datetime | None, + ) -> dict[RollingKeyPairSource, list[LocationReport]]: ... + @require_login_state(LoginState.LOGGED_IN) @override async def fetch_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], date_from: datetime, date_to: datetime | None, - ) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]: + ) -> ( + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ): """See `BaseAppleAccount.fetch_reports`.""" date_to = date_to or datetime.now().astimezone() @@ -679,13 +725,27 @@ class AsyncAppleAccount(BaseAppleAccount): hours: int = 7 * 24, ) -> list[LocationReport]: ... + @overload + async def fetch_last_reports( + self, + keys: Sequence[RollingKeyPairSource], + hours: int = 7 * 24, + ) -> dict[RollingKeyPairSource, list[LocationReport]]: ... + @require_login_state(LoginState.LOGGED_IN) @override async def fetch_last_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], hours: int = 7 * 24, - ) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]: + ) -> ( + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ): """See `BaseAppleAccount.fetch_last_reports`.""" end = datetime.now(tz=timezone.utc) start = end - timedelta(hours=hours) @@ -1041,13 +1101,28 @@ class AppleAccount(BaseAppleAccount): date_to: datetime | None, ) -> list[LocationReport]: ... + @overload + def fetch_reports( + self, + keys: Sequence[RollingKeyPairSource], + date_from: datetime, + date_to: datetime | None, + ) -> dict[RollingKeyPairSource, list[LocationReport]]: ... + @override def fetch_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], date_from: datetime, date_to: datetime | None, - ) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]: + ) -> ( + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ): """See `AsyncAppleAccount.fetch_reports`.""" coro = self._asyncacc.fetch_reports(keys, date_from, date_to) return self._evt_loop.run_until_complete(coro) @@ -1073,12 +1148,26 @@ class AppleAccount(BaseAppleAccount): hours: int = 7 * 24, ) -> list[LocationReport]: ... + @overload + def fetch_last_reports( + self, + keys: Sequence[RollingKeyPairSource], + hours: int = 7 * 24, + ) -> dict[RollingKeyPairSource, list[LocationReport]]: ... + @override def fetch_last_reports( self, - keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, + keys: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], hours: int = 7 * 24, - ) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]: + ) -> ( + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ): """See `AsyncAppleAccount.fetch_last_reports`.""" coro = self._asyncacc.fetch_last_reports(keys, hours) return self._evt_loop.run_until_complete(coro) diff --git a/findmy/reports/reports.py b/findmy/reports/reports.py index 7124e5b..e910545 100644 --- a/findmy/reports/reports.py +++ b/findmy/reports/reports.py @@ -6,8 +6,9 @@ import base64 import hashlib import logging import struct +from collections import defaultdict from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, cast, overload from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec @@ -260,12 +261,27 @@ class LocationReportsFetcher: device: RollingKeyPairSource, ) -> list[LocationReport]: ... + @overload async def fetch_reports( self, date_from: datetime, date_to: datetime, - device: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource, - ) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]: + device: Sequence[RollingKeyPairSource], + ) -> dict[RollingKeyPairSource, list[LocationReport]]: ... + + async def fetch_reports( + self, + date_from: datetime, + date_to: datetime, + device: HasHashedPublicKey + | Sequence[HasHashedPublicKey] + | RollingKeyPairSource + | Sequence[RollingKeyPairSource], + ) -> ( + list[LocationReport] + | dict[HasHashedPublicKey, list[LocationReport]] + | dict[RollingKeyPairSource, list[LocationReport]] + ): """ Fetch location reports for a certain device. @@ -276,45 +292,71 @@ class LocationReportsFetcher: When ``device`` is a :class:`.RollingKeyPairSource`, it will return a list of location reports corresponding to that source. """ - # single key + key_devs: ( + dict[HasHashedPublicKey, HasHashedPublicKey] + | dict[HasHashedPublicKey, RollingKeyPairSource] + ) = {} if isinstance(device, HasHashedPublicKey): - return await self._fetch_reports(date_from, date_to, [device]) - - # key generator - # add 12h margin to the generator - if isinstance(device, RollingKeyPairSource): - keys = list( - device.keys_between( + # single key + key_devs = {device: device} + elif isinstance(device, list) and all(isinstance(x, HasHashedPublicKey) for x in device): + # multiple static keys + device = cast(list[HasHashedPublicKey], device) + key_devs = {key: key for key in device} + elif isinstance(device, RollingKeyPairSource): + # key generator + # add 12h margin to the generator + key_devs = { + key: device + for key in device.keys_between( date_from - timedelta(hours=12), date_to + timedelta(hours=12), - ), - ) + ) + } + elif isinstance(device, list) and all(isinstance(x, RollingKeyPairSource) for x in device): + # multiple key generators + # add 12h margin to each generator + device = cast(list[RollingKeyPairSource], device) + key_devs = { + key: dev + for dev in device + for key in dev.keys_between( + date_from - timedelta(hours=12), + date_to + timedelta(hours=12), + ) + } else: - keys = device + msg = "Unknown device type: %s" + raise ValueError(msg, type(device)) # sequence of keys (fetch 256 max at a time) - reports: list[LocationReport] = [] + key_reports: dict[HasHashedPublicKey, list[LocationReport]] = {} + keys = list(key_devs.keys()) for key_offset in range(0, len(keys), 256): - chunk = keys[key_offset : key_offset + 256] - reports.extend(await self._fetch_reports(date_from, date_to, chunk)) + chunk_keys = keys[key_offset : key_offset + 256] + chunk_reports = await self._fetch_reports(date_from, date_to, chunk_keys) + key_reports |= chunk_reports - if isinstance(device, RollingKeyPairSource): - return reports + # combine (key -> list[report]) and (key -> device) into (device -> list[report]) + device_reports = defaultdict(list) + for key, reports in key_reports.items(): + device_reports[key_devs[key]].extend(reports) + for dev in device_reports: + device_reports[dev] = sorted(device_reports[dev]) - res: dict[HasHashedPublicKey, list[LocationReport]] = {key: [] for key in keys} - for report in reports: - for key in res: - if key.hashed_adv_key_bytes == report.hashed_adv_key_bytes: - res[key].append(report) - break - return res + # result + if isinstance(device, (HasHashedPublicKey, RollingKeyPairSource)): + # single key or generator + return device_reports[device] + # multiple static keys or key generators + return device_reports async def _fetch_reports( self, date_from: datetime, date_to: datetime, keys: Sequence[HasHashedPublicKey], - ) -> list[LocationReport]: + ) -> dict[HasHashedPublicKey, list[LocationReport]]: logging.debug("Fetching reports for %s keys", len(keys)) # lock requested time range to the past 7 days, +- 12 hours, then filter the response. @@ -327,7 +369,7 @@ class LocationReportsFetcher: data = await self._account.fetch_raw_reports(start_date, end_date, ids) id_to_key: dict[bytes, HasHashedPublicKey] = {key.hashed_adv_key_bytes: key for key in keys} - reports: list[LocationReport] = [] + reports: dict[HasHashedPublicKey, list[LocationReport]] = defaultdict(list) for report in data.get("results", []): payload = base64.b64decode(report["payload"]) hashed_adv_key = base64.b64decode(report["id"]) @@ -347,6 +389,6 @@ class LocationReportsFetcher: if loc_report.timestamp < date_from or loc_report.timestamp > date_to: continue - reports.append(loc_report) + reports[key].append(loc_report) return reports