diff --git a/findmy/accessory.py b/findmy/accessory.py index dd20902..29c54f2 100644 --- a/findmy/accessory.py +++ b/findmy/accessory.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: from collections.abc import Generator from pathlib import Path + from findmy.reports.reports import LocationReport + logger = logging.getLogger(__name__) @@ -48,37 +50,38 @@ class RollingKeyPairSource(ABC): @abstractmethod def interval(self) -> timedelta: """KeyPair rollover interval.""" - - @abstractmethod - def keys_at(self, ind: int | datetime) -> set[KeyPair]: - """Generate potential key(s) occurring at a certain index or timestamp.""" raise NotImplementedError - @overload + @abstractmethod + def get_min_index(self, dt: datetime) -> int: + """Get the minimum key index that the accessory could be broadcasting at a specific time.""" + raise NotImplementedError + + @abstractmethod + def get_max_index(self, dt: datetime) -> int: + """Get the maximum key index that the accessory could be broadcasting at a specific time.""" + raise NotImplementedError + + @abstractmethod + def update_alignment(self, report: LocationReport, index: int) -> None: + """ + Update alignment of the accessory. + + Alignment can be updated based on a LocationReport that was observed at a specific index. + """ + raise NotImplementedError + + @abstractmethod + def keys_at(self, ind: int) -> set[KeyPair]: + """Generate potential key(s) occurring at a certain index.""" + raise NotImplementedError + def keys_between(self, start: int, end: int) -> set[KeyPair]: - pass - - @overload - def keys_between(self, start: datetime, end: datetime) -> set[KeyPair]: - pass - - def keys_between(self, start: int | datetime, end: int | datetime) -> set[KeyPair]: - """Generate potential key(s) occurring between two indices or timestamps.""" + """Generate potential key(s) occurring between two indices.""" keys: set[KeyPair] = set() - if isinstance(start, int) and isinstance(end, int): - while start < end: - keys.update(self.keys_at(start)) - - start += 1 - elif isinstance(start, datetime) and isinstance(end, datetime): - while start < end: - keys.update(self.keys_at(start)) - - start += self.interval - else: - msg = "Invalid start/end type" - raise TypeError(msg) + for ind in range(start, end + 1): + keys.update(self.keys_at(ind)) return keys @@ -174,53 +177,82 @@ class FindMyAccessory(RollingKeyPairSource, Serializable[FindMyAccessoryMapping] return timedelta(minutes=15) @override - def keys_at(self, ind: int | datetime) -> set[KeyPair]: - """Get the potential primary and secondary keys active at a certain time or index.""" - if isinstance(ind, datetime) and ind < self._paired_at: - return set() - if isinstance(ind, int) and ind < 0: - return set() - - secondary_offset = 0 - - if isinstance(ind, datetime): - # number of 15-minute slots since alignment - slots_since_alignment = ( - int( - (ind - self._alignment_date).total_seconds() / (15 * 60), - ) - + 1 - ) - ind = self._alignment_index + slots_since_alignment - - # number of slots until first 4 am - first_rollover = self._alignment_date.astimezone().replace( - hour=4, - minute=0, - second=0, - microsecond=0, - ) - if first_rollover < self._alignment_date: # we rolled backwards, so increment the day - first_rollover += timedelta(days=1) - secondary_offset = ( - int( - (first_rollover - self._alignment_date).total_seconds() / (15 * 60), - ) - + 1 + def get_min_index(self, dt: datetime) -> int: + if dt.tzinfo is None: + end = dt.astimezone() + logger.warning( + "Datetime is timezone-naive. Assuming system tz: %s.", + end.tzname(), ) - possible_keys = set() - # primary key can always be determined - possible_keys.add(self._primary_gen[ind]) + if dt >= self._alignment_date: + # in the worst case, the accessory has not rolled over at all since alignment + return self._alignment_index + # the accessory key will rollover AT MOST once every 15 minutes, so + # this is the minimum index for which we will need to generate keys. + # it's possible that rollover has progressed slower or not at all. + ind_before_alignment = (self._alignment_date - dt) // self.interval + return self._alignment_index - ind_before_alignment + + @override + def get_max_index(self, dt: datetime) -> int: + if dt.tzinfo is None: + end = dt.astimezone() + logger.warning( + "Datetime is timezone-naive. Assuming system tz: %s.", + end.tzname(), + ) + + if dt <= self._alignment_date: + # in the worst case, the accessory has not rolled over at all since `dt`, + # in which case it was at the alignment index. We can't go lower than that. + return self._alignment_index + + # the accessory key will rollover AT MOST once every 15 minutes, so + # this is the maximum index for which we will need to generate keys. + # it's possible that rollover has progressed slower or not at all. + ind_since_alignment = (dt - self._alignment_date) // self.interval + return self._alignment_index + ind_since_alignment + + @override + def update_alignment(self, report: LocationReport, index: int) -> None: + if report.timestamp < self._alignment_date: + # we only care about the most recent report + return + + logger.info("Updating alignment based on report observed at index %i", index) + + self._alignment_date = report.timestamp + self._alignment_index = index + + def _primary_key_at(self, ind: int) -> KeyPair: + """Get the primary key at a certain index.""" + return self._primary_gen[ind] + + def _secondary_keys_at(self, ind: int) -> tuple[KeyPair, KeyPair]: + """Get possible secondary keys at a certain primary index.""" # when the accessory has been rebooted, it will use the following secondary key - possible_keys.add(self._secondary_gen[ind // 96 + 1]) + key_1 = self._secondary_gen[ind // 96 + 1] - if ind > secondary_offset: - # after the first 4 am after pairing, we need to account for the first day - possible_keys.add(self._secondary_gen[(ind - secondary_offset) // 96 + 2]) + # in some cases, the secondary index may not be at primary_ind // 96 + 1, but at +2 instead. + # example: if we paired at 3:00 am, the first secondary key will be used until 4:00 am, + # at which point the second secondary key will be used. The primary index at 4:00 am is 4, + # but the 'second' secondary key is used. + # however, since we don't know the exact index rollover pattern, we just take a guess here + # and return both keys. for alignment, it's better to underestimate progression of the index + # than to overestimate it. + key_2 = self._secondary_gen[ind // 96 + 2] - return possible_keys + return key_1, key_2 + + @override + def keys_at(self, ind: int) -> set[KeyPair]: + """Get the primary and secondary keys that might be active at a certain index.""" + if ind < 0: + return set() + + return {self._primary_key_at(ind), *self._secondary_keys_at(ind)} @classmethod def from_plist( @@ -377,6 +409,10 @@ class AccessoryKeyGenerator(KeyGenerator[KeyPair]): return self._key_type def _get_sk(self, ind: int) -> bytes: + if ind < 0: + msg = "The key index must be non-negative" + raise ValueError(msg) + if ind < self._cur_sk_ind: # behind us; need to reset :( self._cur_sk = self._initial_sk self._cur_sk_ind = 0 diff --git a/findmy/reports/account.py b/findmy/reports/account.py index 53f1ab1..05f9ed5 100644 --- a/findmy/reports/account.py +++ b/findmy/reports/account.py @@ -233,90 +233,45 @@ class BaseAppleAccount(Closable, Serializable[AccountStateMapping], ABC): @overload @abstractmethod - def fetch_reports( + def fetch_location( self, keys: HasHashedPublicKey, - date_from: datetime, - date_to: datetime | None, - ) -> MaybeCoro[list[LocationReport]]: ... + ) -> MaybeCoro[LocationReport | None]: ... @overload @abstractmethod - def fetch_reports( + def fetch_location( self, keys: RollingKeyPairSource, - date_from: datetime, - date_to: datetime | None, - ) -> MaybeCoro[list[LocationReport]]: ... + ) -> MaybeCoro[LocationReport | None]: ... @overload @abstractmethod - def fetch_reports( + def fetch_location( self, keys: Sequence[HasHashedPublicKey | RollingKeyPairSource], - date_from: datetime, - date_to: datetime | None, - ) -> MaybeCoro[dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]]: ... + ) -> MaybeCoro[ + dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None] | None + ]: ... @abstractmethod - def fetch_reports( + def fetch_location( self, keys: HasHashedPublicKey | Sequence[HasHashedPublicKey | RollingKeyPairSource] | RollingKeyPairSource, - date_from: datetime, - date_to: datetime | None, ) -> MaybeCoro[ - list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]] + LocationReport + | dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None] + | None ]: """ - Fetch location reports for :class:`HasHashedPublicKey`s between `date_from` and `date_end`. + Fetch location for :class:`HasHashedPublicKey`s. Returns a dictionary mapping :class:`HasHashedPublicKey`s to their location reports. """ raise NotImplementedError - @overload - @abstractmethod - def fetch_last_reports( - self, - keys: HasHashedPublicKey, - hours: int = 7 * 24, - ) -> MaybeCoro[list[LocationReport]]: ... - - @overload - @abstractmethod - def fetch_last_reports( - self, - keys: RollingKeyPairSource, - hours: int = 7 * 24, - ) -> MaybeCoro[list[LocationReport]]: ... - - @overload - @abstractmethod - def fetch_last_reports( - self, - keys: Sequence[HasHashedPublicKey | RollingKeyPairSource], - hours: int = 7 * 24, - ) -> MaybeCoro[dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]]: ... - - @abstractmethod - def fetch_last_reports( - self, - keys: HasHashedPublicKey - | RollingKeyPairSource - | Sequence[HasHashedPublicKey | RollingKeyPairSource], - hours: int = 7 * 24, - ) -> MaybeCoro[ - list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]] - ]: - """ - Fetch location reports for :class:`HasHashedPublicKey`s for the last `hours` hours. - - Utility method as an alternative to using :meth:`BaseAppleAccount.fetch_reports` directly. - """ - raise NotImplementedError - @abstractmethod def get_anisette_headers( self, @@ -617,17 +572,19 @@ class AsyncAppleAccount(BaseAppleAccount): @require_login_state(LoginState.LOGGED_IN) async def fetch_raw_reports( self, - start: datetime, - end: datetime, - devices: list[list[str]], - ) -> dict[str, Any]: + devices: list[tuple[list[str], list[str]]], + ) -> list[LocationReport]: """Make a request for location reports, returning raw data.""" + logger.debug("Fetching raw reports for %d device(s)", len(devices)) + + now = datetime.now(tz=timezone.utc) + start_ts = int((now - timedelta(days=7)).timestamp()) * 1000 + end_ts = int(now.timestamp()) * 1000 + auth = ( self._login_state_data["dsid"], self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"], ) - start_ts = int(start.timestamp() * 1000) - end_ts = int(end.timestamp() * 1000) data = { "clientContext": { "clientBundleIdentifier": "com.apple.icloud.searchpartyuseragent", @@ -640,8 +597,8 @@ class AsyncAppleAccount(BaseAppleAccount): "startDate": start_ts, "startDateSecondary": start_ts, "endDate": end_ts, - # passing all keys as primary seems to work fine - "primaryIds": device_keys, + "primaryIds": device_keys[0], + "secondaryIds": device_keys[1], } for device_keys in devices ], @@ -679,90 +636,51 @@ class AsyncAppleAccount(BaseAppleAccount): msg = f"Failed to fetch reports: {resp.get('statusCode')}" raise UnhandledProtocolError(msg) - return resp["acsnLocations"] + # parse reports + reports: list[LocationReport] = [] + for key_reports in resp.get("acsnLocations", {}).get("locationPayload", []): + hashed_adv_key_bytes = base64.b64decode(key_reports["id"]) + + for report in key_reports.get("locationInfo", []): + payload = base64.b64decode(report) + loc_report = LocationReport(payload, hashed_adv_key_bytes) + + reports.append(loc_report) + + return reports @overload - async def fetch_reports( + async def fetch_location( self, keys: HasHashedPublicKey, - date_from: datetime, - date_to: datetime | None, - ) -> list[LocationReport]: ... + ) -> LocationReport | None: ... @overload - async def fetch_reports( + async def fetch_location( self, keys: RollingKeyPairSource, - date_from: datetime, - date_to: datetime | None, - ) -> list[LocationReport]: ... + ) -> LocationReport | None: ... @overload - async def fetch_reports( + async def fetch_location( self, keys: Sequence[HasHashedPublicKey | RollingKeyPairSource], - date_from: datetime, - date_to: datetime | None, - ) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ... + ) -> dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None]: ... @require_login_state(LoginState.LOGGED_IN) @override - async def fetch_reports( + async def fetch_location( self, keys: HasHashedPublicKey | RollingKeyPairSource | Sequence[HasHashedPublicKey | RollingKeyPairSource], - date_from: datetime, - date_to: datetime | None, ) -> ( - list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]] + LocationReport + | dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None] + | None ): """See :meth:`BaseAppleAccount.fetch_reports`.""" - date_to = date_to or datetime.now().astimezone() - - return await self._reports.fetch_reports( - date_from, - date_to, - keys, - ) - - @overload - async def fetch_last_reports( - self, - keys: HasHashedPublicKey, - hours: int = 7 * 24, - ) -> list[LocationReport]: ... - - @overload - async def fetch_last_reports( - self, - keys: RollingKeyPairSource, - hours: int = 7 * 24, - ) -> list[LocationReport]: ... - - @overload - async def fetch_last_reports( - self, - keys: Sequence[HasHashedPublicKey | RollingKeyPairSource], - hours: int = 7 * 24, - ) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ... - - @require_login_state(LoginState.LOGGED_IN) - @override - async def fetch_last_reports( - self, - keys: HasHashedPublicKey - | RollingKeyPairSource - | Sequence[HasHashedPublicKey | RollingKeyPairSource], - hours: int = 7 * 24, - ) -> ( - list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]] - ): - """See :meth:`BaseAppleAccount.fetch_last_reports`.""" - end = datetime.now(tz=timezone.utc) - start = end - timedelta(hours=hours) - - return await self.fetch_reports(keys, start, end) + return await self._reports.fetch_location(keys) @require_login_state(LoginState.LOGGED_OUT, LoginState.REQUIRE_2FA, LoginState.LOGGED_IN) async def _gsa_authenticate( @@ -1101,77 +1019,36 @@ class AppleAccount(BaseAppleAccount): return self._evt_loop.run_until_complete(coro) @overload - def fetch_reports( + def fetch_location( self, keys: HasHashedPublicKey, - date_from: datetime, - date_to: datetime | None, - ) -> list[LocationReport]: ... + ) -> LocationReport | None: ... @overload - def fetch_reports( + def fetch_location( self, keys: RollingKeyPairSource, - date_from: datetime, - date_to: datetime | None, - ) -> list[LocationReport]: ... + ) -> LocationReport | None: ... @overload - def fetch_reports( + def fetch_location( self, keys: Sequence[HasHashedPublicKey | RollingKeyPairSource], - date_from: datetime, - date_to: datetime | None, - ) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ... + ) -> dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None]: ... @override - def fetch_reports( + def fetch_location( self, keys: HasHashedPublicKey | Sequence[HasHashedPublicKey | RollingKeyPairSource] | RollingKeyPairSource, - date_from: datetime, - date_to: datetime | None, ) -> ( - list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]] + LocationReport + | dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None] + | None ): - """See :meth:`AsyncAppleAccount.fetch_reports`.""" - coro = self._asyncacc.fetch_reports(keys, date_from, date_to) - return self._evt_loop.run_until_complete(coro) - - @overload - def fetch_last_reports( - self, - keys: HasHashedPublicKey, - hours: int = 7 * 24, - ) -> list[LocationReport]: ... - - @overload - def fetch_last_reports( - self, - keys: RollingKeyPairSource, - hours: int = 7 * 24, - ) -> list[LocationReport]: ... - - @overload - def fetch_last_reports( - self, - keys: Sequence[HasHashedPublicKey | RollingKeyPairSource], - hours: int = 7 * 24, - ) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ... - - @override - def fetch_last_reports( - self, - keys: HasHashedPublicKey - | RollingKeyPairSource - | Sequence[HasHashedPublicKey | RollingKeyPairSource], - hours: int = 7 * 24, - ) -> ( - list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]] - ): - """See :meth:`AsyncAppleAccount.fetch_last_reports`.""" - coro = self._asyncacc.fetch_last_reports(keys, hours) + """See :meth:`AsyncAppleAccount.fetch_location`.""" + coro = self._asyncacc.fetch_location(keys) return self._evt_loop.run_until_complete(coro) @override diff --git a/findmy/reports/reports.py b/findmy/reports/reports.py index a8062b8..a8d051d 100644 --- a/findmy/reports/reports.py +++ b/findmy/reports/reports.py @@ -8,7 +8,7 @@ import logging import struct from collections import defaultdict from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Literal, TypedDict, Union, cast, overload +from typing import TYPE_CHECKING, Literal, TypedDict, Union, overload from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec @@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from typing_extensions import override from findmy.accessory import RollingKeyPairSource -from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping +from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping, KeyType from findmy.util.abc import Serializable from findmy.util.files import read_data_json, save_and_return_json @@ -337,144 +337,187 @@ class LocationReportsFetcher: self._account: AsyncAppleAccount = account @overload - async def fetch_reports( + async def fetch_location( self, - date_from: datetime, - date_to: datetime, device: HasHashedPublicKey, - ) -> list[LocationReport]: ... + ) -> LocationReport | None: ... @overload - async def fetch_reports( + async def fetch_location( self, - date_from: datetime, - date_to: datetime, device: RollingKeyPairSource, - ) -> list[LocationReport]: ... + ) -> LocationReport | None: ... @overload - async def fetch_reports( + async def fetch_location( self, - date_from: datetime, - date_to: datetime, device: Sequence[HasHashedPublicKey | RollingKeyPairSource], - ) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ... + ) -> dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None]: ... - async def fetch_reports( # noqa: C901 + async def fetch_location( self, - date_from: datetime, - date_to: datetime, device: HasHashedPublicKey | RollingKeyPairSource | Sequence[HasHashedPublicKey | RollingKeyPairSource], ) -> ( - list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]] + LocationReport + | dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None] + | None ): """ - Fetch location reports for a certain device. + Fetch location for a certain device or multipel devices. When `device` is a single :class:`HasHashedPublicKey`, this method will return - a list of location reports corresponding to that key. - When `device` is a :class:`RollingKeyPairSource`, it will return a list of - location reports corresponding to that source. + a location report corresponding to that key, or None if unavailable. + When `device` is a :class:`RollingKeyPairSource`, it will return a location + report corresponding to that source, or None if unavailable. When `device` is a sequence of :class:`HasHashedPublicKey`s or RollingKeyPairSource's, - it will return a dictionary with the provided object - as key, and a list of location reports as value. + it will return a dictionary with the provided objects + as keys, and a location report (or None) as value. """ - key_devs: dict[HasHashedPublicKey, HasHashedPublicKey | RollingKeyPairSource] = {} - key_batches: list[list[HasHashedPublicKey]] = [] if isinstance(device, HasHashedPublicKey): # single key - key_devs = {device: device} - key_batches.append([device]) - elif isinstance(device, RollingKeyPairSource): + key_reports = await self._fetch_key_reports([device]) + return key_reports.get(device, None) + + if isinstance(device, RollingKeyPairSource): # key generator - # add 12h margin to the generator - keys = device.keys_between( - date_from - timedelta(hours=12), - date_to + timedelta(hours=12), - ) - key_devs = dict.fromkeys(keys, device) - key_batches.append(list(keys)) - elif isinstance(device, list) and all( + return await self._fetch_accessory_report(device) + + if not isinstance(device, list) or not all( isinstance(x, HasHashedPublicKey | RollingKeyPairSource) for x in device ): - # multiple key generators - # add 12h margin to each generator - device = cast("list[HasHashedPublicKey | RollingKeyPairSource]", device) - for dev in device: - if isinstance(dev, HasHashedPublicKey): - key_devs[dev] = dev - key_batches.append([dev]) - elif isinstance(dev, RollingKeyPairSource): - keys = dev.keys_between( - date_from - timedelta(hours=12), - date_to + timedelta(hours=12), - ) - for key in keys: - key_devs[key] = dev - key_batches.append(list(keys)) - else: - msg = "Unknown device type: %s" - raise ValueError(msg, type(device)) + # unsupported type + msg = "Device must be a HasHashedPublicKey, RollingKeyPairSource, or list thereof." + raise ValueError(msg) - # sequence of keys (fetch 256 max at a time) - key_reports: dict[HasHashedPublicKey, list[LocationReport]] = await self._fetch_reports( - date_from, - date_to, - key_batches, - ) + # multiple key generators + # we can batch static keys in a single request, + # but key generators need to be queried separately + static_keys: list[HasHashedPublicKey] = [] + reports: dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None] = {} + for dev in device: + if isinstance(dev, HasHashedPublicKey): + # save for later batch request + static_keys.append(dev) + elif isinstance(dev, RollingKeyPairSource): + # query immediately + reports[dev] = await self._fetch_accessory_report(dev) - # combine (key -> list[report]) and (key -> device) into (device -> list[report]) - device_reports = defaultdict(list) - for key, reports in key_reports.items(): - device_reports[key_devs[key]].extend(reports) - for dev in device_reports: - device_reports[dev] = sorted(device_reports[dev]) - - # result - if isinstance(device, (HasHashedPublicKey, RollingKeyPairSource)): - # single key or generator - return device_reports[device] - # multiple static keys or key generators - return device_reports - - async def _fetch_reports( - self, - date_from: datetime, - date_to: datetime, - device_keys: Sequence[Sequence[HasHashedPublicKey]], - ) -> dict[HasHashedPublicKey, list[LocationReport]]: - logger.debug("Fetching reports for %s device(s)", len(device_keys)) - - # lock requested time range to the past 7 days, +- 12 hours, then filter the response. - # this is due to an Apple backend bug where the time range is not respected. - # More info: https://github.com/biemster/FindMy/issues/7 - now = datetime.now().astimezone() - start_date = now - timedelta(days=7, hours=12) - end_date = now + timedelta(hours=12) - ids = [[key.hashed_adv_key_b64 for key in keys] for keys in device_keys] - data = await self._account.fetch_raw_reports(start_date, end_date, ids) - - id_to_key: dict[bytes, HasHashedPublicKey] = { - key.hashed_adv_key_bytes: key for keys in device_keys for key in keys - } - reports: dict[HasHashedPublicKey, list[LocationReport]] = defaultdict(list) - for key_reports in data.get("locationPayload", []): - hashed_adv_key_bytes = base64.b64decode(key_reports["id"]) - key = id_to_key[hashed_adv_key_bytes] - - for report in key_reports.get("locationInfo", []): - payload = base64.b64decode(report) - loc_report = LocationReport(payload, hashed_adv_key_bytes) - - if loc_report.timestamp < date_from or loc_report.timestamp > date_to: - continue - - # pre-decrypt if possible - if isinstance(key, KeyPair): - loc_report.decrypt(key) - - reports[key].append(loc_report) + if static_keys: # batch request for static keys + key_reports = await self._fetch_key_reports(static_keys) + reports.update(dict(key_reports.items())) + + return reports + + async def _fetch_accessory_report( + self, + accessory: RollingKeyPairSource, + ) -> LocationReport | None: + logger.debug("Fetching location report for accessory") + + now = datetime.now().astimezone() + start_date = now - timedelta(days=7) + end_date = now + + # mappings + key_to_ind: dict[KeyPair, set[int]] = defaultdict(set) + id_to_key: dict[bytes, KeyPair] = {} + + # state variables + cur_keys_primary: set[str] = set() + cur_keys_secondary: set[str] = set() + cur_index = accessory.get_min_index(start_date) + ret: LocationReport | None = None + + async def _fetch() -> LocationReport | None: + """Fetch current keys and add them to final reports.""" + new_reports: list[LocationReport] = await self._account.fetch_raw_reports( + [(list(cur_keys_primary), (list(cur_keys_secondary)))] + ) + logger.info("Fetched %d new reports (index %i)", len(new_reports), cur_index) + + if new_reports: + report = sorted(new_reports)[-1] + + key = id_to_key[report.hashed_adv_key_bytes] + report.decrypt(key) + + # update alignment data on every report + # if a key maps to multiple indices, only feed it the maximum index, + # since apple only returns the latest reports per request. + # This makes the value more likely to be stable. + accessory.update_alignment(report, max(key_to_ind[key])) + else: + report = None + + cur_keys_primary.clear() + cur_keys_secondary.clear() + + return report + + while cur_index <= accessory.get_max_index(end_date): + key_batch = accessory.keys_at(cur_index) + + # split into primary and secondary keys + # (UNKNOWN keys are filed as primary) + new_keys_primary: set[str] = { + key.hashed_adv_key_b64 for key in key_batch if key.key_type == KeyType.PRIMARY + } + new_keys_secondary: set[str] = { + key.hashed_adv_key_b64 for key in key_batch if key.key_type != KeyType.PRIMARY + } + + # 290 seems to be the maximum number of keys that Apple accepts in a single request, + # so if adding the new keys would exceed that, fire a request first + if ( + len(cur_keys_primary | new_keys_primary) > 290 + or len(cur_keys_secondary | new_keys_secondary) > 290 + ): + report = await _fetch() + if ret is None or (report is not None and report.timestamp > ret.timestamp): + ret = report + + # build mappings before adding to current keys + for key in key_batch: + key_to_ind[key].add(cur_index) + id_to_key[key.hashed_adv_key_bytes] = key + cur_keys_primary |= new_keys_primary + cur_keys_secondary |= new_keys_secondary + + cur_index += 1 + + if cur_keys_primary or cur_keys_secondary: + # fetch remaining keys + report = await _fetch() + if ret is None or (report is not None and report.timestamp > ret.timestamp): + ret = report + + # filter duplicate reports (can happen since key batches may overlap) + return ret + + async def _fetch_key_reports( + self, + keys: Sequence[HasHashedPublicKey], + ) -> dict[HasHashedPublicKey, LocationReport | None]: + logger.debug("Fetching reports for %s key(s)", len(keys)) + + # fetch all as primary keys + ids = [([key.hashed_adv_key_b64], []) for key in keys] + encrypted_reports: list[LocationReport] = await self._account.fetch_raw_reports(ids) + + id_to_key: dict[bytes, HasHashedPublicKey] = {key.hashed_adv_key_bytes: key for key in keys} + reports: dict[HasHashedPublicKey, LocationReport | None] = dict.fromkeys(keys) + for report in encrypted_reports: + key = id_to_key[report.hashed_adv_key_bytes] + + cur_report = reports[key] + if cur_report is None or report.timestamp > cur_report.timestamp: + # more recent report, replace + reports[key] = report + + # pre-decrypt report if possible + if isinstance(key, KeyPair): + report.decrypt(key) return reports