feat: Support fetching of multiple accessories at once

This commit is contained in:
Mike A.
2025-01-23 16:27:28 +01:00
parent f01ad6b15b
commit 10aee19e93
2 changed files with 173 additions and 42 deletions

View File

@@ -15,7 +15,6 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Sequence,
TypedDict,
TypeVar,
cast,
@@ -49,6 +48,8 @@ from .twofactor import (
)
if TYPE_CHECKING:
from collections.abc import Sequence
from findmy.accessory import RollingKeyPairSource
from findmy.keys import HasHashedPublicKey
from findmy.util.types import MaybeCoro
@@ -248,13 +249,28 @@ class BaseAppleAccount(Closable, ABC):
date_to: datetime | None,
) -> MaybeCoro[list[LocationReport]]: ...
@overload
def fetch_reports(
self,
keys: Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ...
@abstractmethod
def fetch_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> MaybeCoro[list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]]:
) -> MaybeCoro[
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
]:
"""
Fetch location reports for `HasHashedPublicKey`s between `date_from` and `date_end`.
@@ -286,12 +302,27 @@ class BaseAppleAccount(Closable, ABC):
hours: int = 7 * 24,
) -> MaybeCoro[list[LocationReport]]: ...
@overload
@abstractmethod
def fetch_last_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> MaybeCoro[list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]]:
) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ...
@abstractmethod
def fetch_last_reports(
self,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> MaybeCoro[
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
]:
"""
Fetch location reports for a sequence of `HasHashedPublicKey`s for the last `hours` hours.
@@ -641,14 +672,29 @@ class AsyncAppleAccount(BaseAppleAccount):
date_to: datetime | None,
) -> list[LocationReport]: ...
@overload
async def fetch_reports(
self,
keys: Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
@require_login_state(LoginState.LOGGED_IN)
@override
async def fetch_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
):
"""See `BaseAppleAccount.fetch_reports`."""
date_to = date_to or datetime.now().astimezone()
@@ -679,13 +725,27 @@ class AsyncAppleAccount(BaseAppleAccount):
hours: int = 7 * 24,
) -> list[LocationReport]: ...
@overload
async def fetch_last_reports(
self,
keys: Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
@require_login_state(LoginState.LOGGED_IN)
@override
async def fetch_last_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
):
"""See `BaseAppleAccount.fetch_last_reports`."""
end = datetime.now(tz=timezone.utc)
start = end - timedelta(hours=hours)
@@ -1041,13 +1101,28 @@ class AppleAccount(BaseAppleAccount):
date_to: datetime | None,
) -> list[LocationReport]: ...
@overload
def fetch_reports(
self,
keys: Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
@override
def fetch_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
):
"""See `AsyncAppleAccount.fetch_reports`."""
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
return self._evt_loop.run_until_complete(coro)
@@ -1073,12 +1148,26 @@ class AppleAccount(BaseAppleAccount):
hours: int = 7 * 24,
) -> list[LocationReport]: ...
@overload
def fetch_last_reports(
self,
keys: Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
@override
def fetch_last_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
):
"""See `AsyncAppleAccount.fetch_last_reports`."""
coro = self._asyncacc.fetch_last_reports(keys, hours)
return self._evt_loop.run_until_complete(coro)

View File

@@ -6,8 +6,9 @@ import base64
import hashlib
import logging
import struct
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, cast, overload
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
@@ -260,12 +261,27 @@ class LocationReportsFetcher:
device: RollingKeyPairSource,
) -> list[LocationReport]: ...
@overload
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
device: Sequence[RollingKeyPairSource],
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
):
"""
Fetch location reports for a certain device.
@@ -276,45 +292,71 @@ class LocationReportsFetcher:
When ``device`` is a :class:`.RollingKeyPairSource`, it will return a list of
location reports corresponding to that source.
"""
# single key
key_devs: (
dict[HasHashedPublicKey, HasHashedPublicKey]
| dict[HasHashedPublicKey, RollingKeyPairSource]
) = {}
if isinstance(device, HasHashedPublicKey):
return await self._fetch_reports(date_from, date_to, [device])
# key generator
# add 12h margin to the generator
if isinstance(device, RollingKeyPairSource):
keys = list(
device.keys_between(
# single key
key_devs = {device: device}
elif isinstance(device, list) and all(isinstance(x, HasHashedPublicKey) for x in device):
# multiple static keys
device = cast(list[HasHashedPublicKey], device)
key_devs = {key: key for key in device}
elif isinstance(device, RollingKeyPairSource):
# key generator
# add 12h margin to the generator
key_devs = {
key: device
for key in device.keys_between(
date_from - timedelta(hours=12),
date_to + timedelta(hours=12),
),
)
)
}
elif isinstance(device, list) and all(isinstance(x, RollingKeyPairSource) for x in device):
# multiple key generators
# add 12h margin to each generator
device = cast(list[RollingKeyPairSource], device)
key_devs = {
key: dev
for dev in device
for key in dev.keys_between(
date_from - timedelta(hours=12),
date_to + timedelta(hours=12),
)
}
else:
keys = device
msg = "Unknown device type: %s"
raise ValueError(msg, type(device))
# sequence of keys (fetch 256 max at a time)
reports: list[LocationReport] = []
key_reports: dict[HasHashedPublicKey, list[LocationReport]] = {}
keys = list(key_devs.keys())
for key_offset in range(0, len(keys), 256):
chunk = keys[key_offset : key_offset + 256]
reports.extend(await self._fetch_reports(date_from, date_to, chunk))
chunk_keys = keys[key_offset : key_offset + 256]
chunk_reports = await self._fetch_reports(date_from, date_to, chunk_keys)
key_reports |= chunk_reports
if isinstance(device, RollingKeyPairSource):
return reports
# combine (key -> list[report]) and (key -> device) into (device -> list[report])
device_reports = defaultdict(list)
for key, reports in key_reports.items():
device_reports[key_devs[key]].extend(reports)
for dev in device_reports:
device_reports[dev] = sorted(device_reports[dev])
res: dict[HasHashedPublicKey, list[LocationReport]] = {key: [] for key in keys}
for report in reports:
for key in res:
if key.hashed_adv_key_bytes == report.hashed_adv_key_bytes:
res[key].append(report)
break
return res
# result
if isinstance(device, (HasHashedPublicKey, RollingKeyPairSource)):
# single key or generator
return device_reports[device]
# multiple static keys or key generators
return device_reports
async def _fetch_reports(
self,
date_from: datetime,
date_to: datetime,
keys: Sequence[HasHashedPublicKey],
) -> list[LocationReport]:
) -> dict[HasHashedPublicKey, list[LocationReport]]:
logging.debug("Fetching reports for %s keys", len(keys))
# lock requested time range to the past 7 days, +- 12 hours, then filter the response.
@@ -327,7 +369,7 @@ class LocationReportsFetcher:
data = await self._account.fetch_raw_reports(start_date, end_date, ids)
id_to_key: dict[bytes, HasHashedPublicKey] = {key.hashed_adv_key_bytes: key for key in keys}
reports: list[LocationReport] = []
reports: dict[HasHashedPublicKey, list[LocationReport]] = defaultdict(list)
for report in data.get("results", []):
payload = base64.b64decode(report["payload"])
hashed_adv_key = base64.b64decode(report["id"])
@@ -347,6 +389,6 @@ class LocationReportsFetcher:
if loc_report.timestamp < date_from or loc_report.timestamp > date_to:
continue
reports.append(loc_report)
reports[key].append(loc_report)
return reports