mirror of
https://github.com/malmeloo/FindMy.py.git
synced 2026-04-17 21:53:57 +02:00
feat!: implement key alignment algorithm for accessories
BREAKING: due to fundamental issues with Apple's API, this commit also DEPRECATES the `fetch_[last_]reports` methods on Apple account instances. It has been replaced by a method named `fetch_location`, which only returns a single location report (the latest one) and does not support setting a date range.
This commit is contained in:
@@ -23,6 +23,8 @@ if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
from findmy.reports.reports import LocationReport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -48,37 +50,38 @@ class RollingKeyPairSource(ABC):
|
||||
@abstractmethod
|
||||
def interval(self) -> timedelta:
|
||||
"""KeyPair rollover interval."""
|
||||
|
||||
@abstractmethod
|
||||
def keys_at(self, ind: int | datetime) -> set[KeyPair]:
|
||||
"""Generate potential key(s) occurring at a certain index or timestamp."""
|
||||
raise NotImplementedError
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def get_min_index(self, dt: datetime) -> int:
|
||||
"""Get the minimum key index that the accessory could be broadcasting at a specific time."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_max_index(self, dt: datetime) -> int:
|
||||
"""Get the maximum key index that the accessory could be broadcasting at a specific time."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_alignment(self, report: LocationReport, index: int) -> None:
|
||||
"""
|
||||
Update alignment of the accessory.
|
||||
|
||||
Alignment can be updated based on a LocationReport that was observed at a specific index.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def keys_at(self, ind: int) -> set[KeyPair]:
|
||||
"""Generate potential key(s) occurring at a certain index."""
|
||||
raise NotImplementedError
|
||||
|
||||
def keys_between(self, start: int, end: int) -> set[KeyPair]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def keys_between(self, start: datetime, end: datetime) -> set[KeyPair]:
|
||||
pass
|
||||
|
||||
def keys_between(self, start: int | datetime, end: int | datetime) -> set[KeyPair]:
|
||||
"""Generate potential key(s) occurring between two indices or timestamps."""
|
||||
"""Generate potential key(s) occurring between two indices."""
|
||||
keys: set[KeyPair] = set()
|
||||
|
||||
if isinstance(start, int) and isinstance(end, int):
|
||||
while start < end:
|
||||
keys.update(self.keys_at(start))
|
||||
|
||||
start += 1
|
||||
elif isinstance(start, datetime) and isinstance(end, datetime):
|
||||
while start < end:
|
||||
keys.update(self.keys_at(start))
|
||||
|
||||
start += self.interval
|
||||
else:
|
||||
msg = "Invalid start/end type"
|
||||
raise TypeError(msg)
|
||||
for ind in range(start, end + 1):
|
||||
keys.update(self.keys_at(ind))
|
||||
|
||||
return keys
|
||||
|
||||
@@ -174,53 +177,82 @@ class FindMyAccessory(RollingKeyPairSource, Serializable[FindMyAccessoryMapping]
|
||||
return timedelta(minutes=15)
|
||||
|
||||
@override
|
||||
def keys_at(self, ind: int | datetime) -> set[KeyPair]:
|
||||
"""Get the potential primary and secondary keys active at a certain time or index."""
|
||||
if isinstance(ind, datetime) and ind < self._paired_at:
|
||||
return set()
|
||||
if isinstance(ind, int) and ind < 0:
|
||||
return set()
|
||||
|
||||
secondary_offset = 0
|
||||
|
||||
if isinstance(ind, datetime):
|
||||
# number of 15-minute slots since alignment
|
||||
slots_since_alignment = (
|
||||
int(
|
||||
(ind - self._alignment_date).total_seconds() / (15 * 60),
|
||||
)
|
||||
+ 1
|
||||
)
|
||||
ind = self._alignment_index + slots_since_alignment
|
||||
|
||||
# number of slots until first 4 am
|
||||
first_rollover = self._alignment_date.astimezone().replace(
|
||||
hour=4,
|
||||
minute=0,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
if first_rollover < self._alignment_date: # we rolled backwards, so increment the day
|
||||
first_rollover += timedelta(days=1)
|
||||
secondary_offset = (
|
||||
int(
|
||||
(first_rollover - self._alignment_date).total_seconds() / (15 * 60),
|
||||
)
|
||||
+ 1
|
||||
def get_min_index(self, dt: datetime) -> int:
|
||||
if dt.tzinfo is None:
|
||||
end = dt.astimezone()
|
||||
logger.warning(
|
||||
"Datetime is timezone-naive. Assuming system tz: %s.",
|
||||
end.tzname(),
|
||||
)
|
||||
|
||||
possible_keys = set()
|
||||
# primary key can always be determined
|
||||
possible_keys.add(self._primary_gen[ind])
|
||||
if dt >= self._alignment_date:
|
||||
# in the worst case, the accessory has not rolled over at all since alignment
|
||||
return self._alignment_index
|
||||
|
||||
# the accessory key will rollover AT MOST once every 15 minutes, so
|
||||
# this is the minimum index for which we will need to generate keys.
|
||||
# it's possible that rollover has progressed slower or not at all.
|
||||
ind_before_alignment = (self._alignment_date - dt) // self.interval
|
||||
return self._alignment_index - ind_before_alignment
|
||||
|
||||
@override
|
||||
def get_max_index(self, dt: datetime) -> int:
|
||||
if dt.tzinfo is None:
|
||||
end = dt.astimezone()
|
||||
logger.warning(
|
||||
"Datetime is timezone-naive. Assuming system tz: %s.",
|
||||
end.tzname(),
|
||||
)
|
||||
|
||||
if dt <= self._alignment_date:
|
||||
# in the worst case, the accessory has not rolled over at all since `dt`,
|
||||
# in which case it was at the alignment index. We can't go lower than that.
|
||||
return self._alignment_index
|
||||
|
||||
# the accessory key will rollover AT MOST once every 15 minutes, so
|
||||
# this is the maximum index for which we will need to generate keys.
|
||||
# it's possible that rollover has progressed slower or not at all.
|
||||
ind_since_alignment = (dt - self._alignment_date) // self.interval
|
||||
return self._alignment_index + ind_since_alignment
|
||||
|
||||
@override
|
||||
def update_alignment(self, report: LocationReport, index: int) -> None:
|
||||
if report.timestamp < self._alignment_date:
|
||||
# we only care about the most recent report
|
||||
return
|
||||
|
||||
logger.info("Updating alignment based on report observed at index %i", index)
|
||||
|
||||
self._alignment_date = report.timestamp
|
||||
self._alignment_index = index
|
||||
|
||||
def _primary_key_at(self, ind: int) -> KeyPair:
|
||||
"""Get the primary key at a certain index."""
|
||||
return self._primary_gen[ind]
|
||||
|
||||
def _secondary_keys_at(self, ind: int) -> tuple[KeyPair, KeyPair]:
|
||||
"""Get possible secondary keys at a certain primary index."""
|
||||
# when the accessory has been rebooted, it will use the following secondary key
|
||||
possible_keys.add(self._secondary_gen[ind // 96 + 1])
|
||||
key_1 = self._secondary_gen[ind // 96 + 1]
|
||||
|
||||
if ind > secondary_offset:
|
||||
# after the first 4 am after pairing, we need to account for the first day
|
||||
possible_keys.add(self._secondary_gen[(ind - secondary_offset) // 96 + 2])
|
||||
# in some cases, the secondary index may not be at primary_ind // 96 + 1, but at +2 instead.
|
||||
# example: if we paired at 3:00 am, the first secondary key will be used until 4:00 am,
|
||||
# at which point the second secondary key will be used. The primary index at 4:00 am is 4,
|
||||
# but the 'second' secondary key is used.
|
||||
# however, since we don't know the exact index rollover pattern, we just take a guess here
|
||||
# and return both keys. for alignment, it's better to underestimate progression of the index
|
||||
# than to overestimate it.
|
||||
key_2 = self._secondary_gen[ind // 96 + 2]
|
||||
|
||||
return possible_keys
|
||||
return key_1, key_2
|
||||
|
||||
@override
|
||||
def keys_at(self, ind: int) -> set[KeyPair]:
|
||||
"""Get the primary and secondary keys that might be active at a certain index."""
|
||||
if ind < 0:
|
||||
return set()
|
||||
|
||||
return {self._primary_key_at(ind), *self._secondary_keys_at(ind)}
|
||||
|
||||
@classmethod
|
||||
def from_plist(
|
||||
@@ -377,6 +409,10 @@ class AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
||||
return self._key_type
|
||||
|
||||
def _get_sk(self, ind: int) -> bytes:
|
||||
if ind < 0:
|
||||
msg = "The key index must be non-negative"
|
||||
raise ValueError(msg)
|
||||
|
||||
if ind < self._cur_sk_ind: # behind us; need to reset :(
|
||||
self._cur_sk = self._initial_sk
|
||||
self._cur_sk_ind = 0
|
||||
|
||||
@@ -233,90 +233,45 @@ class BaseAppleAccount(Closable, Serializable[AccountStateMapping], ABC):
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def fetch_reports(
|
||||
def fetch_location(
|
||||
self,
|
||||
keys: HasHashedPublicKey,
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> MaybeCoro[list[LocationReport]]: ...
|
||||
) -> MaybeCoro[LocationReport | None]: ...
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def fetch_reports(
|
||||
def fetch_location(
|
||||
self,
|
||||
keys: RollingKeyPairSource,
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> MaybeCoro[list[LocationReport]]: ...
|
||||
) -> MaybeCoro[LocationReport | None]: ...
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def fetch_reports(
|
||||
def fetch_location(
|
||||
self,
|
||||
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> MaybeCoro[dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]]: ...
|
||||
) -> MaybeCoro[
|
||||
dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None] | None
|
||||
]: ...
|
||||
|
||||
@abstractmethod
|
||||
def fetch_reports(
|
||||
def fetch_location(
|
||||
self,
|
||||
keys: HasHashedPublicKey
|
||||
| Sequence[HasHashedPublicKey | RollingKeyPairSource]
|
||||
| RollingKeyPairSource,
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> MaybeCoro[
|
||||
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
|
||||
LocationReport
|
||||
| dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None]
|
||||
| None
|
||||
]:
|
||||
"""
|
||||
Fetch location reports for :class:`HasHashedPublicKey`s between `date_from` and `date_end`.
|
||||
Fetch location for :class:`HasHashedPublicKey`s.
|
||||
|
||||
Returns a dictionary mapping :class:`HasHashedPublicKey`s to their location reports.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def fetch_last_reports(
|
||||
self,
|
||||
keys: HasHashedPublicKey,
|
||||
hours: int = 7 * 24,
|
||||
) -> MaybeCoro[list[LocationReport]]: ...
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def fetch_last_reports(
|
||||
self,
|
||||
keys: RollingKeyPairSource,
|
||||
hours: int = 7 * 24,
|
||||
) -> MaybeCoro[list[LocationReport]]: ...
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def fetch_last_reports(
|
||||
self,
|
||||
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
hours: int = 7 * 24,
|
||||
) -> MaybeCoro[dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]]: ...
|
||||
|
||||
@abstractmethod
|
||||
def fetch_last_reports(
|
||||
self,
|
||||
keys: HasHashedPublicKey
|
||||
| RollingKeyPairSource
|
||||
| Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
hours: int = 7 * 24,
|
||||
) -> MaybeCoro[
|
||||
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
|
||||
]:
|
||||
"""
|
||||
Fetch location reports for :class:`HasHashedPublicKey`s for the last `hours` hours.
|
||||
|
||||
Utility method as an alternative to using :meth:`BaseAppleAccount.fetch_reports` directly.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_anisette_headers(
|
||||
self,
|
||||
@@ -617,17 +572,19 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
@require_login_state(LoginState.LOGGED_IN)
|
||||
async def fetch_raw_reports(
|
||||
self,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
devices: list[list[str]],
|
||||
) -> dict[str, Any]:
|
||||
devices: list[tuple[list[str], list[str]]],
|
||||
) -> list[LocationReport]:
|
||||
"""Make a request for location reports, returning raw data."""
|
||||
logger.debug("Fetching raw reports for %d device(s)", len(devices))
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
start_ts = int((now - timedelta(days=7)).timestamp()) * 1000
|
||||
end_ts = int(now.timestamp()) * 1000
|
||||
|
||||
auth = (
|
||||
self._login_state_data["dsid"],
|
||||
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
|
||||
)
|
||||
start_ts = int(start.timestamp() * 1000)
|
||||
end_ts = int(end.timestamp() * 1000)
|
||||
data = {
|
||||
"clientContext": {
|
||||
"clientBundleIdentifier": "com.apple.icloud.searchpartyuseragent",
|
||||
@@ -640,8 +597,8 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
"startDate": start_ts,
|
||||
"startDateSecondary": start_ts,
|
||||
"endDate": end_ts,
|
||||
# passing all keys as primary seems to work fine
|
||||
"primaryIds": device_keys,
|
||||
"primaryIds": device_keys[0],
|
||||
"secondaryIds": device_keys[1],
|
||||
}
|
||||
for device_keys in devices
|
||||
],
|
||||
@@ -679,90 +636,51 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
msg = f"Failed to fetch reports: {resp.get('statusCode')}"
|
||||
raise UnhandledProtocolError(msg)
|
||||
|
||||
return resp["acsnLocations"]
|
||||
# parse reports
|
||||
reports: list[LocationReport] = []
|
||||
for key_reports in resp.get("acsnLocations", {}).get("locationPayload", []):
|
||||
hashed_adv_key_bytes = base64.b64decode(key_reports["id"])
|
||||
|
||||
for report in key_reports.get("locationInfo", []):
|
||||
payload = base64.b64decode(report)
|
||||
loc_report = LocationReport(payload, hashed_adv_key_bytes)
|
||||
|
||||
reports.append(loc_report)
|
||||
|
||||
return reports
|
||||
|
||||
@overload
|
||||
async def fetch_reports(
|
||||
async def fetch_location(
|
||||
self,
|
||||
keys: HasHashedPublicKey,
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> list[LocationReport]: ...
|
||||
) -> LocationReport | None: ...
|
||||
|
||||
@overload
|
||||
async def fetch_reports(
|
||||
async def fetch_location(
|
||||
self,
|
||||
keys: RollingKeyPairSource,
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> list[LocationReport]: ...
|
||||
) -> LocationReport | None: ...
|
||||
|
||||
@overload
|
||||
async def fetch_reports(
|
||||
async def fetch_location(
|
||||
self,
|
||||
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ...
|
||||
) -> dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None]: ...
|
||||
|
||||
@require_login_state(LoginState.LOGGED_IN)
|
||||
@override
|
||||
async def fetch_reports(
|
||||
async def fetch_location(
|
||||
self,
|
||||
keys: HasHashedPublicKey
|
||||
| RollingKeyPairSource
|
||||
| Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> (
|
||||
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
|
||||
LocationReport
|
||||
| dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None]
|
||||
| None
|
||||
):
|
||||
"""See :meth:`BaseAppleAccount.fetch_reports`."""
|
||||
date_to = date_to or datetime.now().astimezone()
|
||||
|
||||
return await self._reports.fetch_reports(
|
||||
date_from,
|
||||
date_to,
|
||||
keys,
|
||||
)
|
||||
|
||||
@overload
|
||||
async def fetch_last_reports(
|
||||
self,
|
||||
keys: HasHashedPublicKey,
|
||||
hours: int = 7 * 24,
|
||||
) -> list[LocationReport]: ...
|
||||
|
||||
@overload
|
||||
async def fetch_last_reports(
|
||||
self,
|
||||
keys: RollingKeyPairSource,
|
||||
hours: int = 7 * 24,
|
||||
) -> list[LocationReport]: ...
|
||||
|
||||
@overload
|
||||
async def fetch_last_reports(
|
||||
self,
|
||||
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
hours: int = 7 * 24,
|
||||
) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ...
|
||||
|
||||
@require_login_state(LoginState.LOGGED_IN)
|
||||
@override
|
||||
async def fetch_last_reports(
|
||||
self,
|
||||
keys: HasHashedPublicKey
|
||||
| RollingKeyPairSource
|
||||
| Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
hours: int = 7 * 24,
|
||||
) -> (
|
||||
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
|
||||
):
|
||||
"""See :meth:`BaseAppleAccount.fetch_last_reports`."""
|
||||
end = datetime.now(tz=timezone.utc)
|
||||
start = end - timedelta(hours=hours)
|
||||
|
||||
return await self.fetch_reports(keys, start, end)
|
||||
return await self._reports.fetch_location(keys)
|
||||
|
||||
@require_login_state(LoginState.LOGGED_OUT, LoginState.REQUIRE_2FA, LoginState.LOGGED_IN)
|
||||
async def _gsa_authenticate(
|
||||
@@ -1101,77 +1019,36 @@ class AppleAccount(BaseAppleAccount):
|
||||
return self._evt_loop.run_until_complete(coro)
|
||||
|
||||
@overload
|
||||
def fetch_reports(
|
||||
def fetch_location(
|
||||
self,
|
||||
keys: HasHashedPublicKey,
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> list[LocationReport]: ...
|
||||
) -> LocationReport | None: ...
|
||||
|
||||
@overload
|
||||
def fetch_reports(
|
||||
def fetch_location(
|
||||
self,
|
||||
keys: RollingKeyPairSource,
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> list[LocationReport]: ...
|
||||
) -> LocationReport | None: ...
|
||||
|
||||
@overload
|
||||
def fetch_reports(
|
||||
def fetch_location(
|
||||
self,
|
||||
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ...
|
||||
) -> dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None]: ...
|
||||
|
||||
@override
|
||||
def fetch_reports(
|
||||
def fetch_location(
|
||||
self,
|
||||
keys: HasHashedPublicKey
|
||||
| Sequence[HasHashedPublicKey | RollingKeyPairSource]
|
||||
| RollingKeyPairSource,
|
||||
date_from: datetime,
|
||||
date_to: datetime | None,
|
||||
) -> (
|
||||
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
|
||||
LocationReport
|
||||
| dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None]
|
||||
| None
|
||||
):
|
||||
"""See :meth:`AsyncAppleAccount.fetch_reports`."""
|
||||
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
|
||||
return self._evt_loop.run_until_complete(coro)
|
||||
|
||||
@overload
|
||||
def fetch_last_reports(
|
||||
self,
|
||||
keys: HasHashedPublicKey,
|
||||
hours: int = 7 * 24,
|
||||
) -> list[LocationReport]: ...
|
||||
|
||||
@overload
|
||||
def fetch_last_reports(
|
||||
self,
|
||||
keys: RollingKeyPairSource,
|
||||
hours: int = 7 * 24,
|
||||
) -> list[LocationReport]: ...
|
||||
|
||||
@overload
|
||||
def fetch_last_reports(
|
||||
self,
|
||||
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
hours: int = 7 * 24,
|
||||
) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ...
|
||||
|
||||
@override
|
||||
def fetch_last_reports(
|
||||
self,
|
||||
keys: HasHashedPublicKey
|
||||
| RollingKeyPairSource
|
||||
| Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
hours: int = 7 * 24,
|
||||
) -> (
|
||||
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
|
||||
):
|
||||
"""See :meth:`AsyncAppleAccount.fetch_last_reports`."""
|
||||
coro = self._asyncacc.fetch_last_reports(keys, hours)
|
||||
"""See :meth:`AsyncAppleAccount.fetch_location`."""
|
||||
coro = self._asyncacc.fetch_location(keys)
|
||||
return self._evt_loop.run_until_complete(coro)
|
||||
|
||||
@override
|
||||
|
||||
@@ -8,7 +8,7 @@ import logging
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Union, cast, overload
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Union, overload
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
@@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from typing_extensions import override
|
||||
|
||||
from findmy.accessory import RollingKeyPairSource
|
||||
from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping
|
||||
from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping, KeyType
|
||||
from findmy.util.abc import Serializable
|
||||
from findmy.util.files import read_data_json, save_and_return_json
|
||||
|
||||
@@ -337,144 +337,187 @@ class LocationReportsFetcher:
|
||||
self._account: AsyncAppleAccount = account
|
||||
|
||||
@overload
|
||||
async def fetch_reports(
|
||||
async def fetch_location(
|
||||
self,
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
device: HasHashedPublicKey,
|
||||
) -> list[LocationReport]: ...
|
||||
) -> LocationReport | None: ...
|
||||
|
||||
@overload
|
||||
async def fetch_reports(
|
||||
async def fetch_location(
|
||||
self,
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
device: RollingKeyPairSource,
|
||||
) -> list[LocationReport]: ...
|
||||
) -> LocationReport | None: ...
|
||||
|
||||
@overload
|
||||
async def fetch_reports(
|
||||
async def fetch_location(
|
||||
self,
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
device: Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ...
|
||||
) -> dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None]: ...
|
||||
|
||||
async def fetch_reports( # noqa: C901
|
||||
async def fetch_location(
|
||||
self,
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
device: HasHashedPublicKey
|
||||
| RollingKeyPairSource
|
||||
| Sequence[HasHashedPublicKey | RollingKeyPairSource],
|
||||
) -> (
|
||||
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
|
||||
LocationReport
|
||||
| dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None]
|
||||
| None
|
||||
):
|
||||
"""
|
||||
Fetch location reports for a certain device.
|
||||
Fetch location for a certain device or multipel devices.
|
||||
|
||||
When `device` is a single :class:`HasHashedPublicKey`, this method will return
|
||||
a list of location reports corresponding to that key.
|
||||
When `device` is a :class:`RollingKeyPairSource`, it will return a list of
|
||||
location reports corresponding to that source.
|
||||
a location report corresponding to that key, or None if unavailable.
|
||||
When `device` is a :class:`RollingKeyPairSource`, it will return a location
|
||||
report corresponding to that source, or None if unavailable.
|
||||
When `device` is a sequence of :class:`HasHashedPublicKey`s or RollingKeyPairSource's,
|
||||
it will return a dictionary with the provided object
|
||||
as key, and a list of location reports as value.
|
||||
it will return a dictionary with the provided objects
|
||||
as keys, and a location report (or None) as value.
|
||||
"""
|
||||
key_devs: dict[HasHashedPublicKey, HasHashedPublicKey | RollingKeyPairSource] = {}
|
||||
key_batches: list[list[HasHashedPublicKey]] = []
|
||||
if isinstance(device, HasHashedPublicKey):
|
||||
# single key
|
||||
key_devs = {device: device}
|
||||
key_batches.append([device])
|
||||
elif isinstance(device, RollingKeyPairSource):
|
||||
key_reports = await self._fetch_key_reports([device])
|
||||
return key_reports.get(device, None)
|
||||
|
||||
if isinstance(device, RollingKeyPairSource):
|
||||
# key generator
|
||||
# add 12h margin to the generator
|
||||
keys = device.keys_between(
|
||||
date_from - timedelta(hours=12),
|
||||
date_to + timedelta(hours=12),
|
||||
)
|
||||
key_devs = dict.fromkeys(keys, device)
|
||||
key_batches.append(list(keys))
|
||||
elif isinstance(device, list) and all(
|
||||
return await self._fetch_accessory_report(device)
|
||||
|
||||
if not isinstance(device, list) or not all(
|
||||
isinstance(x, HasHashedPublicKey | RollingKeyPairSource) for x in device
|
||||
):
|
||||
# multiple key generators
|
||||
# add 12h margin to each generator
|
||||
device = cast("list[HasHashedPublicKey | RollingKeyPairSource]", device)
|
||||
for dev in device:
|
||||
if isinstance(dev, HasHashedPublicKey):
|
||||
key_devs[dev] = dev
|
||||
key_batches.append([dev])
|
||||
elif isinstance(dev, RollingKeyPairSource):
|
||||
keys = dev.keys_between(
|
||||
date_from - timedelta(hours=12),
|
||||
date_to + timedelta(hours=12),
|
||||
)
|
||||
for key in keys:
|
||||
key_devs[key] = dev
|
||||
key_batches.append(list(keys))
|
||||
else:
|
||||
msg = "Unknown device type: %s"
|
||||
raise ValueError(msg, type(device))
|
||||
# unsupported type
|
||||
msg = "Device must be a HasHashedPublicKey, RollingKeyPairSource, or list thereof."
|
||||
raise ValueError(msg)
|
||||
|
||||
# sequence of keys (fetch 256 max at a time)
|
||||
key_reports: dict[HasHashedPublicKey, list[LocationReport]] = await self._fetch_reports(
|
||||
date_from,
|
||||
date_to,
|
||||
key_batches,
|
||||
)
|
||||
# multiple key generators
|
||||
# we can batch static keys in a single request,
|
||||
# but key generators need to be queried separately
|
||||
static_keys: list[HasHashedPublicKey] = []
|
||||
reports: dict[HasHashedPublicKey | RollingKeyPairSource, LocationReport | None] = {}
|
||||
for dev in device:
|
||||
if isinstance(dev, HasHashedPublicKey):
|
||||
# save for later batch request
|
||||
static_keys.append(dev)
|
||||
elif isinstance(dev, RollingKeyPairSource):
|
||||
# query immediately
|
||||
reports[dev] = await self._fetch_accessory_report(dev)
|
||||
|
||||
# combine (key -> list[report]) and (key -> device) into (device -> list[report])
|
||||
device_reports = defaultdict(list)
|
||||
for key, reports in key_reports.items():
|
||||
device_reports[key_devs[key]].extend(reports)
|
||||
for dev in device_reports:
|
||||
device_reports[dev] = sorted(device_reports[dev])
|
||||
|
||||
# result
|
||||
if isinstance(device, (HasHashedPublicKey, RollingKeyPairSource)):
|
||||
# single key or generator
|
||||
return device_reports[device]
|
||||
# multiple static keys or key generators
|
||||
return device_reports
|
||||
|
||||
async def _fetch_reports(
|
||||
self,
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
device_keys: Sequence[Sequence[HasHashedPublicKey]],
|
||||
) -> dict[HasHashedPublicKey, list[LocationReport]]:
|
||||
logger.debug("Fetching reports for %s device(s)", len(device_keys))
|
||||
|
||||
# lock requested time range to the past 7 days, +- 12 hours, then filter the response.
|
||||
# this is due to an Apple backend bug where the time range is not respected.
|
||||
# More info: https://github.com/biemster/FindMy/issues/7
|
||||
now = datetime.now().astimezone()
|
||||
start_date = now - timedelta(days=7, hours=12)
|
||||
end_date = now + timedelta(hours=12)
|
||||
ids = [[key.hashed_adv_key_b64 for key in keys] for keys in device_keys]
|
||||
data = await self._account.fetch_raw_reports(start_date, end_date, ids)
|
||||
|
||||
id_to_key: dict[bytes, HasHashedPublicKey] = {
|
||||
key.hashed_adv_key_bytes: key for keys in device_keys for key in keys
|
||||
}
|
||||
reports: dict[HasHashedPublicKey, list[LocationReport]] = defaultdict(list)
|
||||
for key_reports in data.get("locationPayload", []):
|
||||
hashed_adv_key_bytes = base64.b64decode(key_reports["id"])
|
||||
key = id_to_key[hashed_adv_key_bytes]
|
||||
|
||||
for report in key_reports.get("locationInfo", []):
|
||||
payload = base64.b64decode(report)
|
||||
loc_report = LocationReport(payload, hashed_adv_key_bytes)
|
||||
|
||||
if loc_report.timestamp < date_from or loc_report.timestamp > date_to:
|
||||
continue
|
||||
|
||||
# pre-decrypt if possible
|
||||
if isinstance(key, KeyPair):
|
||||
loc_report.decrypt(key)
|
||||
|
||||
reports[key].append(loc_report)
|
||||
if static_keys: # batch request for static keys
|
||||
key_reports = await self._fetch_key_reports(static_keys)
|
||||
reports.update(dict(key_reports.items()))
|
||||
|
||||
return reports
|
||||
|
||||
async def _fetch_accessory_report(
|
||||
self,
|
||||
accessory: RollingKeyPairSource,
|
||||
) -> LocationReport | None:
|
||||
logger.debug("Fetching location report for accessory")
|
||||
|
||||
now = datetime.now().astimezone()
|
||||
start_date = now - timedelta(days=7)
|
||||
end_date = now
|
||||
|
||||
# mappings
|
||||
key_to_ind: dict[KeyPair, set[int]] = defaultdict(set)
|
||||
id_to_key: dict[bytes, KeyPair] = {}
|
||||
|
||||
# state variables
|
||||
cur_keys_primary: set[str] = set()
|
||||
cur_keys_secondary: set[str] = set()
|
||||
cur_index = accessory.get_min_index(start_date)
|
||||
ret: LocationReport | None = None
|
||||
|
||||
async def _fetch() -> LocationReport | None:
|
||||
"""Fetch current keys and add them to final reports."""
|
||||
new_reports: list[LocationReport] = await self._account.fetch_raw_reports(
|
||||
[(list(cur_keys_primary), (list(cur_keys_secondary)))]
|
||||
)
|
||||
logger.info("Fetched %d new reports (index %i)", len(new_reports), cur_index)
|
||||
|
||||
if new_reports:
|
||||
report = sorted(new_reports)[-1]
|
||||
|
||||
key = id_to_key[report.hashed_adv_key_bytes]
|
||||
report.decrypt(key)
|
||||
|
||||
# update alignment data on every report
|
||||
# if a key maps to multiple indices, only feed it the maximum index,
|
||||
# since apple only returns the latest reports per request.
|
||||
# This makes the value more likely to be stable.
|
||||
accessory.update_alignment(report, max(key_to_ind[key]))
|
||||
else:
|
||||
report = None
|
||||
|
||||
cur_keys_primary.clear()
|
||||
cur_keys_secondary.clear()
|
||||
|
||||
return report
|
||||
|
||||
while cur_index <= accessory.get_max_index(end_date):
|
||||
key_batch = accessory.keys_at(cur_index)
|
||||
|
||||
# split into primary and secondary keys
|
||||
# (UNKNOWN keys are filed as primary)
|
||||
new_keys_primary: set[str] = {
|
||||
key.hashed_adv_key_b64 for key in key_batch if key.key_type == KeyType.PRIMARY
|
||||
}
|
||||
new_keys_secondary: set[str] = {
|
||||
key.hashed_adv_key_b64 for key in key_batch if key.key_type != KeyType.PRIMARY
|
||||
}
|
||||
|
||||
# 290 seems to be the maximum number of keys that Apple accepts in a single request,
|
||||
# so if adding the new keys would exceed that, fire a request first
|
||||
if (
|
||||
len(cur_keys_primary | new_keys_primary) > 290
|
||||
or len(cur_keys_secondary | new_keys_secondary) > 290
|
||||
):
|
||||
report = await _fetch()
|
||||
if ret is None or (report is not None and report.timestamp > ret.timestamp):
|
||||
ret = report
|
||||
|
||||
# build mappings before adding to current keys
|
||||
for key in key_batch:
|
||||
key_to_ind[key].add(cur_index)
|
||||
id_to_key[key.hashed_adv_key_bytes] = key
|
||||
cur_keys_primary |= new_keys_primary
|
||||
cur_keys_secondary |= new_keys_secondary
|
||||
|
||||
cur_index += 1
|
||||
|
||||
if cur_keys_primary or cur_keys_secondary:
|
||||
# fetch remaining keys
|
||||
report = await _fetch()
|
||||
if ret is None or (report is not None and report.timestamp > ret.timestamp):
|
||||
ret = report
|
||||
|
||||
# filter duplicate reports (can happen since key batches may overlap)
|
||||
return ret
|
||||
|
||||
async def _fetch_key_reports(
|
||||
self,
|
||||
keys: Sequence[HasHashedPublicKey],
|
||||
) -> dict[HasHashedPublicKey, LocationReport | None]:
|
||||
logger.debug("Fetching reports for %s key(s)", len(keys))
|
||||
|
||||
# fetch all as primary keys
|
||||
ids = [([key.hashed_adv_key_b64], []) for key in keys]
|
||||
encrypted_reports: list[LocationReport] = await self._account.fetch_raw_reports(ids)
|
||||
|
||||
id_to_key: dict[bytes, HasHashedPublicKey] = {key.hashed_adv_key_bytes: key for key in keys}
|
||||
reports: dict[HasHashedPublicKey, LocationReport | None] = dict.fromkeys(keys)
|
||||
for report in encrypted_reports:
|
||||
key = id_to_key[report.hashed_adv_key_bytes]
|
||||
|
||||
cur_report = reports[key]
|
||||
if cur_report is None or report.timestamp > cur_report.timestamp:
|
||||
# more recent report, replace
|
||||
reports[key] = report
|
||||
|
||||
# pre-decrypt report if possible
|
||||
if isinstance(key, KeyPair):
|
||||
report.decrypt(key)
|
||||
|
||||
return reports
|
||||
|
||||
Reference in New Issue
Block a user