mirror of
https://github.com/malmeloo/FindMy.py.git
synced 2026-04-23 01:05:41 +02:00
Make key fetcher stateful
Prevents unnecessary HTTP session closing
This commit is contained in:
@@ -27,7 +27,7 @@ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from findmy.util import HttpSession, decode_plist
|
||||
from findmy.util.errors import InvalidCredentialsError, UnhandledProtocolError
|
||||
|
||||
from .reports import KeyReport, fetch_reports
|
||||
from .reports import LocationReport, LocationReportsFetcher
|
||||
from .state import LoginState, require_login_state
|
||||
from .twofactor import (
|
||||
AsyncSecondFactorMethod,
|
||||
@@ -190,8 +190,8 @@ class BaseAppleAccount(ABC):
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
date_to: datetime | None,
|
||||
) -> dict[KeyPair, list[LocationReport]]:
|
||||
"""
|
||||
Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`.
|
||||
|
||||
@@ -204,7 +204,7 @@ class BaseAppleAccount(ABC):
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
hours: int = 7 * 24,
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
) -> dict[KeyPair, list[LocationReport]]:
|
||||
"""
|
||||
Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours.
|
||||
|
||||
@@ -251,6 +251,7 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
self._account_info: _AccountInfo | None = None
|
||||
|
||||
self._http: HttpSession = HttpSession()
|
||||
self._reports: LocationReportsFetcher = LocationReportsFetcher(self)
|
||||
|
||||
def _set_login_state(
|
||||
self,
|
||||
@@ -411,20 +412,38 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
# AUTHENTICATED -> LOGGED_IN
|
||||
return await self._login_mobileme()
|
||||
|
||||
@require_login_state(LoginState.LOGGED_IN)
|
||||
async def fetch_raw_reports(self, start: int, end: int, ids: list[str]) -> dict[str, Any]:
|
||||
"""Make a request for location reports, returning raw data."""
|
||||
auth = (
|
||||
self._login_state_data["dsid"],
|
||||
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
|
||||
)
|
||||
data = {"search": [{"startDate": start, "endDate": end, "ids": ids}]}
|
||||
r = await self._http.post(
|
||||
"https://gateway.icloud.com/acsnservice/fetch",
|
||||
auth=auth,
|
||||
headers=await self.get_anisette_headers(),
|
||||
json=data,
|
||||
)
|
||||
resp = r.json()
|
||||
if not r.ok or resp["statusCode"] != "200":
|
||||
msg = f"Failed to fetch reports: {resp['statusCode']}"
|
||||
raise UnhandledProtocolError(msg)
|
||||
|
||||
return resp
|
||||
|
||||
@require_login_state(LoginState.LOGGED_IN)
|
||||
async def fetch_reports(
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
date_to: datetime | None,
|
||||
) -> dict[KeyPair, list[LocationReport]]:
|
||||
"""See `BaseAppleAccount.fetch_reports`."""
|
||||
anisette_headers = await self.get_anisette_headers()
|
||||
date_to = date_to or datetime.now().astimezone()
|
||||
|
||||
return await fetch_reports(
|
||||
self._login_state_data["dsid"],
|
||||
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
|
||||
anisette_headers,
|
||||
return await self._reports.fetch_reports(
|
||||
date_from,
|
||||
date_to,
|
||||
keys,
|
||||
@@ -435,7 +454,7 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
hours: int = 7 * 24,
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
) -> dict[KeyPair, list[LocationReport]]:
|
||||
"""See `BaseAppleAccount.fetch_last_reports`."""
|
||||
end = datetime.now(tz=timezone.utc)
|
||||
start = end - timedelta(hours=hours)
|
||||
@@ -737,8 +756,8 @@ class AppleAccount(BaseAppleAccount):
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
date_to: datetime | None,
|
||||
) -> dict[KeyPair, list[LocationReport]]:
|
||||
"""See `AsyncAppleAccount.fetch_reports`."""
|
||||
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
|
||||
return self._loop.run_until_complete(coro)
|
||||
@@ -747,7 +766,7 @@ class AppleAccount(BaseAppleAccount):
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
hours: int = 7 * 24,
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
) -> dict[KeyPair, list[LocationReport]]:
|
||||
"""See `AsyncAppleAccount.fetch_last_reports`."""
|
||||
coro = self._asyncacc.fetch_last_reports(keys, hours)
|
||||
return self._loop.run_until_complete(coro)
|
||||
|
||||
@@ -5,22 +5,27 @@ import base64
|
||||
import hashlib
|
||||
import struct
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Sequence
|
||||
from typing import TYPE_CHECKING, Sequence, TypedDict, overload
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from findmy.util import HttpSession
|
||||
from findmy.keys import KeyPair
|
||||
from findmy.util.http import HttpSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from findmy.keys import KeyPair
|
||||
from .account import AsyncAppleAccount
|
||||
|
||||
_session = HttpSession()
|
||||
|
||||
|
||||
class ReportsError(RuntimeError):
|
||||
"""Raised when an error occurs while looking up reports."""
|
||||
class _FetcherConfig(TypedDict):
|
||||
user_id: str
|
||||
device_id: str
|
||||
dsid: str
|
||||
search_party_token: str
|
||||
|
||||
|
||||
def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
|
||||
@@ -46,7 +51,7 @@ def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
|
||||
return decryptor.update(enc_data) + decryptor.finalize()
|
||||
|
||||
|
||||
class KeyReport:
|
||||
class LocationReport:
|
||||
"""Location report corresponding to a certain `KeyPair`."""
|
||||
|
||||
def __init__( # noqa: PLR0913
|
||||
@@ -119,7 +124,7 @@ class KeyReport:
|
||||
publish_date: datetime,
|
||||
description: str,
|
||||
payload: bytes,
|
||||
) -> KeyReport:
|
||||
) -> LocationReport:
|
||||
"""
|
||||
Create a `KeyReport` from fields and a payload as reported by Apple.
|
||||
|
||||
@@ -134,7 +139,7 @@ class KeyReport:
|
||||
confidence = int.from_bytes(data[8:9], "big")
|
||||
status = int.from_bytes(data[9:10], "big")
|
||||
|
||||
return KeyReport(
|
||||
return cls(
|
||||
key,
|
||||
publish_date,
|
||||
timestamp,
|
||||
@@ -145,14 +150,14 @@ class KeyReport:
|
||||
status,
|
||||
)
|
||||
|
||||
def __lt__(self, other: KeyReport) -> bool:
|
||||
def __lt__(self, other: LocationReport) -> bool:
|
||||
"""
|
||||
Compare against another `KeyReport`.
|
||||
|
||||
A `KeyReport` is said to be "less than" another `KeyReport` iff its recorded
|
||||
timestamp is strictly less than the other report.
|
||||
"""
|
||||
if isinstance(other, KeyReport):
|
||||
if isinstance(other, LocationReport):
|
||||
return self.timestamp < other.timestamp
|
||||
return NotImplemented
|
||||
|
||||
@@ -164,47 +169,90 @@ class KeyReport:
|
||||
)
|
||||
|
||||
|
||||
async def fetch_reports( # noqa: PLR0913
|
||||
dsid: str,
|
||||
search_party_token: str,
|
||||
anisette_headers: dict[str, str],
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
keys: Sequence[KeyPair],
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
"""Look up reports for given `KeyPair`s."""
|
||||
start_date = date_from.timestamp() * 1000
|
||||
end_date = date_to.timestamp() * 1000
|
||||
ids = [key.hashed_adv_key_b64 for key in keys]
|
||||
data = {"search": [{"startDate": start_date, "endDate": end_date, "ids": ids}]}
|
||||
class LocationReportsFetcher:
|
||||
"""Fetcher class to retrieve location reports."""
|
||||
|
||||
# TODO(malmeloo): do not create a new session every time
|
||||
# https://github.com/malmeloo/FindMy.py/issues/3
|
||||
r = await _session.post(
|
||||
"https://gateway.icloud.com/acsnservice/fetch",
|
||||
auth=(dsid, search_party_token),
|
||||
headers=anisette_headers,
|
||||
json=data,
|
||||
)
|
||||
resp = r.json()
|
||||
if not r.ok or resp["statusCode"] != "200":
|
||||
msg = f"Failed to fetch reports: {resp['statusCode']}"
|
||||
raise ReportsError(msg)
|
||||
await _session.close()
|
||||
def __init__(self, account: AsyncAppleAccount) -> None:
|
||||
"""
|
||||
Initialize the fetcher.
|
||||
|
||||
reports: dict[KeyPair, list[KeyReport]] = {key: [] for key in keys}
|
||||
id_to_key: dict[str, KeyPair] = {key.hashed_adv_key_b64: key for key in keys}
|
||||
:param account: Apple account.
|
||||
"""
|
||||
self._account: AsyncAppleAccount = account
|
||||
|
||||
for report in resp.get("results", []):
|
||||
key = id_to_key[report["id"]]
|
||||
date_published = datetime.fromtimestamp(
|
||||
report.get("datePublished", 0) / 1000,
|
||||
tz=timezone.utc,
|
||||
)
|
||||
description = report.get("description", "")
|
||||
payload = base64.b64decode(report["payload"])
|
||||
self._http: HttpSession = HttpSession()
|
||||
|
||||
r = KeyReport.from_payload(key, date_published, description, payload)
|
||||
reports[key].append(r)
|
||||
self._config: _FetcherConfig | None = None
|
||||
|
||||
return {key: sorted(reps) for key, reps in reports.items()}
|
||||
def apply_config(self, **conf: Unpack[_FetcherConfig]) -> None:
|
||||
"""Configure internal variables necessary to make reports fetching calls."""
|
||||
self._config = conf
|
||||
|
||||
@overload
|
||||
async def fetch_reports(
|
||||
self,
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
device: KeyPair,
|
||||
) -> list[LocationReport]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def fetch_reports(
|
||||
self,
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
device: Sequence[KeyPair],
|
||||
) -> dict[KeyPair, list[LocationReport]]:
|
||||
...
|
||||
|
||||
async def fetch_reports(
|
||||
self,
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
device: KeyPair | Sequence[KeyPair],
|
||||
) -> list[LocationReport] | dict[KeyPair, list[LocationReport]]:
|
||||
"""
|
||||
Fetch location reports for a certain device.
|
||||
|
||||
When ``device`` is a single :class:`.KeyPair`, this method will return
|
||||
a list of location reports corresponding to that pair.
|
||||
When ``device`` is a sequence of :class:`.KeyPair`s, it will return a dictionary
|
||||
with the :class:`.KeyPair` as key, and a list of location reports as value.
|
||||
"""
|
||||
# single KeyPair
|
||||
if isinstance(device, KeyPair):
|
||||
return await self._fetch_reports(date_from, date_to, [device])
|
||||
|
||||
# sequence of KeyPairs
|
||||
reports = await self._fetch_reports(date_from, date_to, device)
|
||||
res: dict[KeyPair, list[LocationReport]] = {key: [] for key in device}
|
||||
for report in reports:
|
||||
res[report.key].append(report)
|
||||
return res
|
||||
|
||||
async def _fetch_reports(
|
||||
self,
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
keys: Sequence[KeyPair],
|
||||
) -> list[LocationReport]:
|
||||
start_date = int(date_from.timestamp() * 1000)
|
||||
end_date = int(date_to.timestamp() * 1000)
|
||||
ids = [key.hashed_adv_key_b64 for key in keys]
|
||||
data = await self._account.fetch_raw_reports(start_date, end_date, ids)
|
||||
|
||||
id_to_key: dict[str, KeyPair] = {key.hashed_adv_key_b64: key for key in keys}
|
||||
reports: list[LocationReport] = []
|
||||
for report in data.get("results", []):
|
||||
key = id_to_key[report["id"]]
|
||||
date_published = datetime.fromtimestamp(
|
||||
report.get("datePublished", 0) / 1000,
|
||||
tz=timezone.utc,
|
||||
)
|
||||
description = report.get("description", "")
|
||||
payload = base64.b64decode(report["payload"])
|
||||
|
||||
reports.append(LocationReport.from_payload(key, date_published, description, payload))
|
||||
|
||||
return reports
|
||||
|
||||
Reference in New Issue
Block a user