Make key fetcher stateful

Prevents unnecessary HTTP session closing
This commit is contained in:
Mike A
2024-02-10 21:26:55 +01:00
parent 55d6c9f8ff
commit 128ce84749
2 changed files with 131 additions and 64 deletions

View File

@@ -27,7 +27,7 @@ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from findmy.util import HttpSession, decode_plist from findmy.util import HttpSession, decode_plist
from findmy.util.errors import InvalidCredentialsError, UnhandledProtocolError from findmy.util.errors import InvalidCredentialsError, UnhandledProtocolError
from .reports import KeyReport, fetch_reports from .reports import LocationReport, LocationReportsFetcher
from .state import LoginState, require_login_state from .state import LoginState, require_login_state
from .twofactor import ( from .twofactor import (
AsyncSecondFactorMethod, AsyncSecondFactorMethod,
@@ -190,8 +190,8 @@ class BaseAppleAccount(ABC):
self, self,
keys: Sequence[KeyPair], keys: Sequence[KeyPair],
date_from: datetime, date_from: datetime,
date_to: datetime, date_to: datetime | None,
) -> dict[KeyPair, list[KeyReport]]: ) -> dict[KeyPair, list[LocationReport]]:
""" """
Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`. Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`.
@@ -204,7 +204,7 @@ class BaseAppleAccount(ABC):
self, self,
keys: Sequence[KeyPair], keys: Sequence[KeyPair],
hours: int = 7 * 24, hours: int = 7 * 24,
) -> dict[KeyPair, list[KeyReport]]: ) -> dict[KeyPair, list[LocationReport]]:
""" """
Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours. Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours.
@@ -251,6 +251,7 @@ class AsyncAppleAccount(BaseAppleAccount):
self._account_info: _AccountInfo | None = None self._account_info: _AccountInfo | None = None
self._http: HttpSession = HttpSession() self._http: HttpSession = HttpSession()
self._reports: LocationReportsFetcher = LocationReportsFetcher(self)
def _set_login_state( def _set_login_state(
self, self,
@@ -411,20 +412,38 @@ class AsyncAppleAccount(BaseAppleAccount):
# AUTHENTICATED -> LOGGED_IN # AUTHENTICATED -> LOGGED_IN
return await self._login_mobileme() return await self._login_mobileme()
@require_login_state(LoginState.LOGGED_IN)
async def fetch_raw_reports(self, start: int, end: int, ids: list[str]) -> dict[str, Any]:
"""Make a request for location reports, returning raw data."""
auth = (
self._login_state_data["dsid"],
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
)
data = {"search": [{"startDate": start, "endDate": end, "ids": ids}]}
r = await self._http.post(
"https://gateway.icloud.com/acsnservice/fetch",
auth=auth,
headers=await self.get_anisette_headers(),
json=data,
)
resp = r.json()
if not r.ok or resp["statusCode"] != "200":
msg = f"Failed to fetch reports: {resp['statusCode']}"
raise UnhandledProtocolError(msg)
return resp
@require_login_state(LoginState.LOGGED_IN) @require_login_state(LoginState.LOGGED_IN)
async def fetch_reports( async def fetch_reports(
self, self,
keys: Sequence[KeyPair], keys: Sequence[KeyPair],
date_from: datetime, date_from: datetime,
date_to: datetime, date_to: datetime | None,
) -> dict[KeyPair, list[KeyReport]]: ) -> dict[KeyPair, list[LocationReport]]:
"""See `BaseAppleAccount.fetch_reports`.""" """See `BaseAppleAccount.fetch_reports`."""
anisette_headers = await self.get_anisette_headers() date_to = date_to or datetime.now().astimezone()
return await fetch_reports( return await self._reports.fetch_reports(
self._login_state_data["dsid"],
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
anisette_headers,
date_from, date_from,
date_to, date_to,
keys, keys,
@@ -435,7 +454,7 @@ class AsyncAppleAccount(BaseAppleAccount):
self, self,
keys: Sequence[KeyPair], keys: Sequence[KeyPair],
hours: int = 7 * 24, hours: int = 7 * 24,
) -> dict[KeyPair, list[KeyReport]]: ) -> dict[KeyPair, list[LocationReport]]:
"""See `BaseAppleAccount.fetch_last_reports`.""" """See `BaseAppleAccount.fetch_last_reports`."""
end = datetime.now(tz=timezone.utc) end = datetime.now(tz=timezone.utc)
start = end - timedelta(hours=hours) start = end - timedelta(hours=hours)
@@ -737,8 +756,8 @@ class AppleAccount(BaseAppleAccount):
self, self,
keys: Sequence[KeyPair], keys: Sequence[KeyPair],
date_from: datetime, date_from: datetime,
date_to: datetime, date_to: datetime | None,
) -> dict[KeyPair, list[KeyReport]]: ) -> dict[KeyPair, list[LocationReport]]:
"""See `AsyncAppleAccount.fetch_reports`.""" """See `AsyncAppleAccount.fetch_reports`."""
coro = self._asyncacc.fetch_reports(keys, date_from, date_to) coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
return self._loop.run_until_complete(coro) return self._loop.run_until_complete(coro)
@@ -747,7 +766,7 @@ class AppleAccount(BaseAppleAccount):
self, self,
keys: Sequence[KeyPair], keys: Sequence[KeyPair],
hours: int = 7 * 24, hours: int = 7 * 24,
) -> dict[KeyPair, list[KeyReport]]: ) -> dict[KeyPair, list[LocationReport]]:
"""See `AsyncAppleAccount.fetch_last_reports`.""" """See `AsyncAppleAccount.fetch_last_reports`."""
coro = self._asyncacc.fetch_last_reports(keys, hours) coro = self._asyncacc.fetch_last_reports(keys, hours)
return self._loop.run_until_complete(coro) return self._loop.run_until_complete(coro)

View File

@@ -5,22 +5,27 @@ import base64
import hashlib import hashlib
import struct import struct
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Sequence from typing import TYPE_CHECKING, Sequence, TypedDict, overload
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from typing_extensions import Unpack
from findmy.util import HttpSession from findmy.keys import KeyPair
from findmy.util.http import HttpSession
if TYPE_CHECKING: if TYPE_CHECKING:
from findmy.keys import KeyPair from .account import AsyncAppleAccount
_session = HttpSession() _session = HttpSession()
class ReportsError(RuntimeError): class _FetcherConfig(TypedDict):
"""Raised when an error occurs while looking up reports.""" user_id: str
device_id: str
dsid: str
search_party_token: str
def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes: def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
@@ -46,7 +51,7 @@ def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
return decryptor.update(enc_data) + decryptor.finalize() return decryptor.update(enc_data) + decryptor.finalize()
class KeyReport: class LocationReport:
"""Location report corresponding to a certain `KeyPair`.""" """Location report corresponding to a certain `KeyPair`."""
def __init__( # noqa: PLR0913 def __init__( # noqa: PLR0913
@@ -119,7 +124,7 @@ class KeyReport:
publish_date: datetime, publish_date: datetime,
description: str, description: str,
payload: bytes, payload: bytes,
) -> KeyReport: ) -> LocationReport:
""" """
Create a `KeyReport` from fields and a payload as reported by Apple. Create a `KeyReport` from fields and a payload as reported by Apple.
@@ -134,7 +139,7 @@ class KeyReport:
confidence = int.from_bytes(data[8:9], "big") confidence = int.from_bytes(data[8:9], "big")
status = int.from_bytes(data[9:10], "big") status = int.from_bytes(data[9:10], "big")
return KeyReport( return cls(
key, key,
publish_date, publish_date,
timestamp, timestamp,
@@ -145,14 +150,14 @@ class KeyReport:
status, status,
) )
def __lt__(self, other: KeyReport) -> bool: def __lt__(self, other: LocationReport) -> bool:
""" """
Compare against another `KeyReport`. Compare against another `KeyReport`.
A `KeyReport` is said to be "less than" another `KeyReport` iff its recorded A `KeyReport` is said to be "less than" another `KeyReport` iff its recorded
timestamp is strictly less than the other report. timestamp is strictly less than the other report.
""" """
if isinstance(other, KeyReport): if isinstance(other, LocationReport):
return self.timestamp < other.timestamp return self.timestamp < other.timestamp
return NotImplemented return NotImplemented
@@ -164,47 +169,90 @@ class KeyReport:
) )
async def fetch_reports( # noqa: PLR0913 class LocationReportsFetcher:
dsid: str, """Fetcher class to retrieve location reports."""
search_party_token: str,
anisette_headers: dict[str, str],
date_from: datetime,
date_to: datetime,
keys: Sequence[KeyPair],
) -> dict[KeyPair, list[KeyReport]]:
"""Look up reports for given `KeyPair`s."""
start_date = date_from.timestamp() * 1000
end_date = date_to.timestamp() * 1000
ids = [key.hashed_adv_key_b64 for key in keys]
data = {"search": [{"startDate": start_date, "endDate": end_date, "ids": ids}]}
# TODO(malmeloo): do not create a new session every time def __init__(self, account: AsyncAppleAccount) -> None:
# https://github.com/malmeloo/FindMy.py/issues/3 """
r = await _session.post( Initialize the fetcher.
"https://gateway.icloud.com/acsnservice/fetch",
auth=(dsid, search_party_token),
headers=anisette_headers,
json=data,
)
resp = r.json()
if not r.ok or resp["statusCode"] != "200":
msg = f"Failed to fetch reports: {resp['statusCode']}"
raise ReportsError(msg)
await _session.close()
reports: dict[KeyPair, list[KeyReport]] = {key: [] for key in keys} :param account: Apple account.
id_to_key: dict[str, KeyPair] = {key.hashed_adv_key_b64: key for key in keys} """
self._account: AsyncAppleAccount = account
for report in resp.get("results", []): self._http: HttpSession = HttpSession()
key = id_to_key[report["id"]]
date_published = datetime.fromtimestamp(
report.get("datePublished", 0) / 1000,
tz=timezone.utc,
)
description = report.get("description", "")
payload = base64.b64decode(report["payload"])
r = KeyReport.from_payload(key, date_published, description, payload) self._config: _FetcherConfig | None = None
reports[key].append(r)
return {key: sorted(reps) for key, reps in reports.items()} def apply_config(self, **conf: Unpack[_FetcherConfig]) -> None:
"""Configure internal variables necessary to make reports fetching calls."""
self._config = conf
@overload
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: KeyPair,
) -> list[LocationReport]:
...
@overload
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: Sequence[KeyPair],
) -> dict[KeyPair, list[LocationReport]]:
...
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: KeyPair | Sequence[KeyPair],
) -> list[LocationReport] | dict[KeyPair, list[LocationReport]]:
"""
Fetch location reports for a certain device.
When ``device`` is a single :class:`.KeyPair`, this method will return
a list of location reports corresponding to that pair.
When ``device`` is a sequence of :class:`.KeyPair`s, it will return a dictionary
with the :class:`.KeyPair` as key, and a list of location reports as value.
"""
# single KeyPair
if isinstance(device, KeyPair):
return await self._fetch_reports(date_from, date_to, [device])
# sequence of KeyPairs
reports = await self._fetch_reports(date_from, date_to, device)
res: dict[KeyPair, list[LocationReport]] = {key: [] for key in device}
for report in reports:
res[report.key].append(report)
return res
async def _fetch_reports(
self,
date_from: datetime,
date_to: datetime,
keys: Sequence[KeyPair],
) -> list[LocationReport]:
start_date = int(date_from.timestamp() * 1000)
end_date = int(date_to.timestamp() * 1000)
ids = [key.hashed_adv_key_b64 for key in keys]
data = await self._account.fetch_raw_reports(start_date, end_date, ids)
id_to_key: dict[str, KeyPair] = {key.hashed_adv_key_b64: key for key in keys}
reports: list[LocationReport] = []
for report in data.get("results", []):
key = id_to_key[report["id"]]
date_published = datetime.fromtimestamp(
report.get("datePublished", 0) / 1000,
tz=timezone.utc,
)
description = report.get("description", "")
payload = base64.b64decode(report["payload"])
reports.append(LocationReport.from_payload(key, date_published, description, payload))
return reports