Make key fetcher stateful

Prevents unnecessary HTTP session closing
This commit is contained in:
Mike A
2024-02-10 21:26:55 +01:00
parent 55d6c9f8ff
commit 128ce84749
2 changed files with 131 additions and 64 deletions

View File

@@ -27,7 +27,7 @@ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from findmy.util import HttpSession, decode_plist
from findmy.util.errors import InvalidCredentialsError, UnhandledProtocolError
from .reports import KeyReport, fetch_reports
from .reports import LocationReport, LocationReportsFetcher
from .state import LoginState, require_login_state
from .twofactor import (
AsyncSecondFactorMethod,
@@ -190,8 +190,8 @@ class BaseAppleAccount(ABC):
self,
keys: Sequence[KeyPair],
date_from: datetime,
date_to: datetime,
) -> dict[KeyPair, list[KeyReport]]:
date_to: datetime | None,
) -> dict[KeyPair, list[LocationReport]]:
"""
Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`.
@@ -204,7 +204,7 @@ class BaseAppleAccount(ABC):
self,
keys: Sequence[KeyPair],
hours: int = 7 * 24,
) -> dict[KeyPair, list[KeyReport]]:
) -> dict[KeyPair, list[LocationReport]]:
"""
Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours.
@@ -251,6 +251,7 @@ class AsyncAppleAccount(BaseAppleAccount):
self._account_info: _AccountInfo | None = None
self._http: HttpSession = HttpSession()
self._reports: LocationReportsFetcher = LocationReportsFetcher(self)
def _set_login_state(
self,
@@ -411,20 +412,38 @@ class AsyncAppleAccount(BaseAppleAccount):
# AUTHENTICATED -> LOGGED_IN
return await self._login_mobileme()
@require_login_state(LoginState.LOGGED_IN)
async def fetch_raw_reports(self, start: int, end: int, ids: list[str]) -> dict[str, Any]:
"""Make a request for location reports, returning raw data."""
auth = (
self._login_state_data["dsid"],
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
)
data = {"search": [{"startDate": start, "endDate": end, "ids": ids}]}
r = await self._http.post(
"https://gateway.icloud.com/acsnservice/fetch",
auth=auth,
headers=await self.get_anisette_headers(),
json=data,
)
resp = r.json()
if not r.ok or resp["statusCode"] != "200":
msg = f"Failed to fetch reports: {resp['statusCode']}"
raise UnhandledProtocolError(msg)
return resp
@require_login_state(LoginState.LOGGED_IN)
async def fetch_reports(
self,
keys: Sequence[KeyPair],
date_from: datetime,
date_to: datetime,
) -> dict[KeyPair, list[KeyReport]]:
date_to: datetime | None,
) -> dict[KeyPair, list[LocationReport]]:
"""See `BaseAppleAccount.fetch_reports`."""
anisette_headers = await self.get_anisette_headers()
date_to = date_to or datetime.now().astimezone()
return await fetch_reports(
self._login_state_data["dsid"],
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
anisette_headers,
return await self._reports.fetch_reports(
date_from,
date_to,
keys,
@@ -435,7 +454,7 @@ class AsyncAppleAccount(BaseAppleAccount):
self,
keys: Sequence[KeyPair],
hours: int = 7 * 24,
) -> dict[KeyPair, list[KeyReport]]:
) -> dict[KeyPair, list[LocationReport]]:
"""See `BaseAppleAccount.fetch_last_reports`."""
end = datetime.now(tz=timezone.utc)
start = end - timedelta(hours=hours)
@@ -737,8 +756,8 @@ class AppleAccount(BaseAppleAccount):
self,
keys: Sequence[KeyPair],
date_from: datetime,
date_to: datetime,
) -> dict[KeyPair, list[KeyReport]]:
date_to: datetime | None,
) -> dict[KeyPair, list[LocationReport]]:
"""See `AsyncAppleAccount.fetch_reports`."""
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
return self._loop.run_until_complete(coro)
@@ -747,7 +766,7 @@ class AppleAccount(BaseAppleAccount):
self,
keys: Sequence[KeyPair],
hours: int = 7 * 24,
) -> dict[KeyPair, list[KeyReport]]:
) -> dict[KeyPair, list[LocationReport]]:
"""See `AsyncAppleAccount.fetch_last_reports`."""
coro = self._asyncacc.fetch_last_reports(keys, hours)
return self._loop.run_until_complete(coro)

View File

@@ -5,22 +5,27 @@ import base64
import hashlib
import struct
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Sequence, TypedDict, overload
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from typing_extensions import Unpack
from findmy.util import HttpSession
from findmy.keys import KeyPair
from findmy.util.http import HttpSession
if TYPE_CHECKING:
from findmy.keys import KeyPair
from .account import AsyncAppleAccount
_session = HttpSession()
class ReportsError(RuntimeError):
"""Raised when an error occurs while looking up reports."""
class _FetcherConfig(TypedDict):
user_id: str
device_id: str
dsid: str
search_party_token: str
def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
@@ -46,7 +51,7 @@ def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
return decryptor.update(enc_data) + decryptor.finalize()
class KeyReport:
class LocationReport:
"""Location report corresponding to a certain `KeyPair`."""
def __init__( # noqa: PLR0913
@@ -119,7 +124,7 @@ class KeyReport:
publish_date: datetime,
description: str,
payload: bytes,
) -> KeyReport:
) -> LocationReport:
"""
Create a `KeyReport` from fields and a payload as reported by Apple.
@@ -134,7 +139,7 @@ class KeyReport:
confidence = int.from_bytes(data[8:9], "big")
status = int.from_bytes(data[9:10], "big")
return KeyReport(
return cls(
key,
publish_date,
timestamp,
@@ -145,14 +150,14 @@ class KeyReport:
status,
)
def __lt__(self, other: KeyReport) -> bool:
def __lt__(self, other: LocationReport) -> bool:
"""
Compare against another `KeyReport`.
A `KeyReport` is said to be "less than" another `KeyReport` iff its recorded
timestamp is strictly less than the other report.
"""
if isinstance(other, KeyReport):
if isinstance(other, LocationReport):
return self.timestamp < other.timestamp
return NotImplemented
@@ -164,47 +169,90 @@ class KeyReport:
)
async def fetch_reports( # noqa: PLR0913
dsid: str,
search_party_token: str,
anisette_headers: dict[str, str],
date_from: datetime,
date_to: datetime,
keys: Sequence[KeyPair],
) -> dict[KeyPair, list[KeyReport]]:
"""Look up reports for given `KeyPair`s."""
start_date = date_from.timestamp() * 1000
end_date = date_to.timestamp() * 1000
ids = [key.hashed_adv_key_b64 for key in keys]
data = {"search": [{"startDate": start_date, "endDate": end_date, "ids": ids}]}
class LocationReportsFetcher:
"""Fetcher class to retrieve location reports."""
# TODO(malmeloo): do not create a new session every time
# https://github.com/malmeloo/FindMy.py/issues/3
r = await _session.post(
"https://gateway.icloud.com/acsnservice/fetch",
auth=(dsid, search_party_token),
headers=anisette_headers,
json=data,
)
resp = r.json()
if not r.ok or resp["statusCode"] != "200":
msg = f"Failed to fetch reports: {resp['statusCode']}"
raise ReportsError(msg)
await _session.close()
def __init__(self, account: AsyncAppleAccount) -> None:
"""
Initialize the fetcher.
reports: dict[KeyPair, list[KeyReport]] = {key: [] for key in keys}
id_to_key: dict[str, KeyPair] = {key.hashed_adv_key_b64: key for key in keys}
:param account: Apple account.
"""
self._account: AsyncAppleAccount = account
for report in resp.get("results", []):
key = id_to_key[report["id"]]
date_published = datetime.fromtimestamp(
report.get("datePublished", 0) / 1000,
tz=timezone.utc,
)
description = report.get("description", "")
payload = base64.b64decode(report["payload"])
self._http: HttpSession = HttpSession()
r = KeyReport.from_payload(key, date_published, description, payload)
reports[key].append(r)
self._config: _FetcherConfig | None = None
return {key: sorted(reps) for key, reps in reports.items()}
def apply_config(self, **conf: Unpack[_FetcherConfig]) -> None:
"""Configure internal variables necessary to make reports fetching calls."""
self._config = conf
@overload
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: KeyPair,
) -> list[LocationReport]:
...
@overload
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: Sequence[KeyPair],
) -> dict[KeyPair, list[LocationReport]]:
...
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: KeyPair | Sequence[KeyPair],
) -> list[LocationReport] | dict[KeyPair, list[LocationReport]]:
"""
Fetch location reports for a certain device.
When ``device`` is a single :class:`.KeyPair`, this method will return
a list of location reports corresponding to that pair.
When ``device`` is a sequence of :class:`.KeyPair`s, it will return a dictionary
with the :class:`.KeyPair` as key, and a list of location reports as value.
"""
# single KeyPair
if isinstance(device, KeyPair):
return await self._fetch_reports(date_from, date_to, [device])
# sequence of KeyPairs
reports = await self._fetch_reports(date_from, date_to, device)
res: dict[KeyPair, list[LocationReport]] = {key: [] for key in device}
for report in reports:
res[report.key].append(report)
return res
async def _fetch_reports(
self,
date_from: datetime,
date_to: datetime,
keys: Sequence[KeyPair],
) -> list[LocationReport]:
start_date = int(date_from.timestamp() * 1000)
end_date = int(date_to.timestamp() * 1000)
ids = [key.hashed_adv_key_b64 for key in keys]
data = await self._account.fetch_raw_reports(start_date, end_date, ids)
id_to_key: dict[str, KeyPair] = {key.hashed_adv_key_b64: key for key in keys}
reports: list[LocationReport] = []
for report in data.get("results", []):
key = id_to_key[report["id"]]
date_published = datetime.fromtimestamp(
report.get("datePublished", 0) / 1000,
tz=timezone.utc,
)
description = report.get("description", "")
payload = base64.b64decode(report["payload"])
reports.append(LocationReport.from_payload(key, date_published, description, payload))
return reports