mirror of
https://github.com/malmeloo/FindMy.py.git
synced 2026-04-24 01:35:38 +02:00
Make key fetcher stateful
Prevents unnecessary HTTP session closing
This commit is contained in:
@@ -27,7 +27,7 @@ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|||||||
from findmy.util import HttpSession, decode_plist
|
from findmy.util import HttpSession, decode_plist
|
||||||
from findmy.util.errors import InvalidCredentialsError, UnhandledProtocolError
|
from findmy.util.errors import InvalidCredentialsError, UnhandledProtocolError
|
||||||
|
|
||||||
from .reports import KeyReport, fetch_reports
|
from .reports import LocationReport, LocationReportsFetcher
|
||||||
from .state import LoginState, require_login_state
|
from .state import LoginState, require_login_state
|
||||||
from .twofactor import (
|
from .twofactor import (
|
||||||
AsyncSecondFactorMethod,
|
AsyncSecondFactorMethod,
|
||||||
@@ -190,8 +190,8 @@ class BaseAppleAccount(ABC):
|
|||||||
self,
|
self,
|
||||||
keys: Sequence[KeyPair],
|
keys: Sequence[KeyPair],
|
||||||
date_from: datetime,
|
date_from: datetime,
|
||||||
date_to: datetime,
|
date_to: datetime | None,
|
||||||
) -> dict[KeyPair, list[KeyReport]]:
|
) -> dict[KeyPair, list[LocationReport]]:
|
||||||
"""
|
"""
|
||||||
Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`.
|
Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`.
|
||||||
|
|
||||||
@@ -204,7 +204,7 @@ class BaseAppleAccount(ABC):
|
|||||||
self,
|
self,
|
||||||
keys: Sequence[KeyPair],
|
keys: Sequence[KeyPair],
|
||||||
hours: int = 7 * 24,
|
hours: int = 7 * 24,
|
||||||
) -> dict[KeyPair, list[KeyReport]]:
|
) -> dict[KeyPair, list[LocationReport]]:
|
||||||
"""
|
"""
|
||||||
Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours.
|
Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours.
|
||||||
|
|
||||||
@@ -251,6 +251,7 @@ class AsyncAppleAccount(BaseAppleAccount):
|
|||||||
self._account_info: _AccountInfo | None = None
|
self._account_info: _AccountInfo | None = None
|
||||||
|
|
||||||
self._http: HttpSession = HttpSession()
|
self._http: HttpSession = HttpSession()
|
||||||
|
self._reports: LocationReportsFetcher = LocationReportsFetcher(self)
|
||||||
|
|
||||||
def _set_login_state(
|
def _set_login_state(
|
||||||
self,
|
self,
|
||||||
@@ -411,20 +412,38 @@ class AsyncAppleAccount(BaseAppleAccount):
|
|||||||
# AUTHENTICATED -> LOGGED_IN
|
# AUTHENTICATED -> LOGGED_IN
|
||||||
return await self._login_mobileme()
|
return await self._login_mobileme()
|
||||||
|
|
||||||
|
@require_login_state(LoginState.LOGGED_IN)
|
||||||
|
async def fetch_raw_reports(self, start: int, end: int, ids: list[str]) -> dict[str, Any]:
|
||||||
|
"""Make a request for location reports, returning raw data."""
|
||||||
|
auth = (
|
||||||
|
self._login_state_data["dsid"],
|
||||||
|
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
|
||||||
|
)
|
||||||
|
data = {"search": [{"startDate": start, "endDate": end, "ids": ids}]}
|
||||||
|
r = await self._http.post(
|
||||||
|
"https://gateway.icloud.com/acsnservice/fetch",
|
||||||
|
auth=auth,
|
||||||
|
headers=await self.get_anisette_headers(),
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
|
resp = r.json()
|
||||||
|
if not r.ok or resp["statusCode"] != "200":
|
||||||
|
msg = f"Failed to fetch reports: {resp['statusCode']}"
|
||||||
|
raise UnhandledProtocolError(msg)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
@require_login_state(LoginState.LOGGED_IN)
|
@require_login_state(LoginState.LOGGED_IN)
|
||||||
async def fetch_reports(
|
async def fetch_reports(
|
||||||
self,
|
self,
|
||||||
keys: Sequence[KeyPair],
|
keys: Sequence[KeyPair],
|
||||||
date_from: datetime,
|
date_from: datetime,
|
||||||
date_to: datetime,
|
date_to: datetime | None,
|
||||||
) -> dict[KeyPair, list[KeyReport]]:
|
) -> dict[KeyPair, list[LocationReport]]:
|
||||||
"""See `BaseAppleAccount.fetch_reports`."""
|
"""See `BaseAppleAccount.fetch_reports`."""
|
||||||
anisette_headers = await self.get_anisette_headers()
|
date_to = date_to or datetime.now().astimezone()
|
||||||
|
|
||||||
return await fetch_reports(
|
return await self._reports.fetch_reports(
|
||||||
self._login_state_data["dsid"],
|
|
||||||
self._login_state_data["mobileme_data"]["tokens"]["searchPartyToken"],
|
|
||||||
anisette_headers,
|
|
||||||
date_from,
|
date_from,
|
||||||
date_to,
|
date_to,
|
||||||
keys,
|
keys,
|
||||||
@@ -435,7 +454,7 @@ class AsyncAppleAccount(BaseAppleAccount):
|
|||||||
self,
|
self,
|
||||||
keys: Sequence[KeyPair],
|
keys: Sequence[KeyPair],
|
||||||
hours: int = 7 * 24,
|
hours: int = 7 * 24,
|
||||||
) -> dict[KeyPair, list[KeyReport]]:
|
) -> dict[KeyPair, list[LocationReport]]:
|
||||||
"""See `BaseAppleAccount.fetch_last_reports`."""
|
"""See `BaseAppleAccount.fetch_last_reports`."""
|
||||||
end = datetime.now(tz=timezone.utc)
|
end = datetime.now(tz=timezone.utc)
|
||||||
start = end - timedelta(hours=hours)
|
start = end - timedelta(hours=hours)
|
||||||
@@ -737,8 +756,8 @@ class AppleAccount(BaseAppleAccount):
|
|||||||
self,
|
self,
|
||||||
keys: Sequence[KeyPair],
|
keys: Sequence[KeyPair],
|
||||||
date_from: datetime,
|
date_from: datetime,
|
||||||
date_to: datetime,
|
date_to: datetime | None,
|
||||||
) -> dict[KeyPair, list[KeyReport]]:
|
) -> dict[KeyPair, list[LocationReport]]:
|
||||||
"""See `AsyncAppleAccount.fetch_reports`."""
|
"""See `AsyncAppleAccount.fetch_reports`."""
|
||||||
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
|
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
|
||||||
return self._loop.run_until_complete(coro)
|
return self._loop.run_until_complete(coro)
|
||||||
@@ -747,7 +766,7 @@ class AppleAccount(BaseAppleAccount):
|
|||||||
self,
|
self,
|
||||||
keys: Sequence[KeyPair],
|
keys: Sequence[KeyPair],
|
||||||
hours: int = 7 * 24,
|
hours: int = 7 * 24,
|
||||||
) -> dict[KeyPair, list[KeyReport]]:
|
) -> dict[KeyPair, list[LocationReport]]:
|
||||||
"""See `AsyncAppleAccount.fetch_last_reports`."""
|
"""See `AsyncAppleAccount.fetch_last_reports`."""
|
||||||
coro = self._asyncacc.fetch_last_reports(keys, hours)
|
coro = self._asyncacc.fetch_last_reports(keys, hours)
|
||||||
return self._loop.run_until_complete(coro)
|
return self._loop.run_until_complete(coro)
|
||||||
|
|||||||
@@ -5,22 +5,27 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import struct
|
import struct
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import TYPE_CHECKING, Sequence
|
from typing import TYPE_CHECKING, Sequence, TypedDict, overload
|
||||||
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.primitives.asymmetric import ec
|
from cryptography.hazmat.primitives.asymmetric import ec
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from findmy.util import HttpSession
|
from findmy.keys import KeyPair
|
||||||
|
from findmy.util.http import HttpSession
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from findmy.keys import KeyPair
|
from .account import AsyncAppleAccount
|
||||||
|
|
||||||
_session = HttpSession()
|
_session = HttpSession()
|
||||||
|
|
||||||
|
|
||||||
class ReportsError(RuntimeError):
|
class _FetcherConfig(TypedDict):
|
||||||
"""Raised when an error occurs while looking up reports."""
|
user_id: str
|
||||||
|
device_id: str
|
||||||
|
dsid: str
|
||||||
|
search_party_token: str
|
||||||
|
|
||||||
|
|
||||||
def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
|
def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
|
||||||
@@ -46,7 +51,7 @@ def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
|
|||||||
return decryptor.update(enc_data) + decryptor.finalize()
|
return decryptor.update(enc_data) + decryptor.finalize()
|
||||||
|
|
||||||
|
|
||||||
class KeyReport:
|
class LocationReport:
|
||||||
"""Location report corresponding to a certain `KeyPair`."""
|
"""Location report corresponding to a certain `KeyPair`."""
|
||||||
|
|
||||||
def __init__( # noqa: PLR0913
|
def __init__( # noqa: PLR0913
|
||||||
@@ -119,7 +124,7 @@ class KeyReport:
|
|||||||
publish_date: datetime,
|
publish_date: datetime,
|
||||||
description: str,
|
description: str,
|
||||||
payload: bytes,
|
payload: bytes,
|
||||||
) -> KeyReport:
|
) -> LocationReport:
|
||||||
"""
|
"""
|
||||||
Create a `KeyReport` from fields and a payload as reported by Apple.
|
Create a `KeyReport` from fields and a payload as reported by Apple.
|
||||||
|
|
||||||
@@ -134,7 +139,7 @@ class KeyReport:
|
|||||||
confidence = int.from_bytes(data[8:9], "big")
|
confidence = int.from_bytes(data[8:9], "big")
|
||||||
status = int.from_bytes(data[9:10], "big")
|
status = int.from_bytes(data[9:10], "big")
|
||||||
|
|
||||||
return KeyReport(
|
return cls(
|
||||||
key,
|
key,
|
||||||
publish_date,
|
publish_date,
|
||||||
timestamp,
|
timestamp,
|
||||||
@@ -145,14 +150,14 @@ class KeyReport:
|
|||||||
status,
|
status,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __lt__(self, other: KeyReport) -> bool:
|
def __lt__(self, other: LocationReport) -> bool:
|
||||||
"""
|
"""
|
||||||
Compare against another `KeyReport`.
|
Compare against another `KeyReport`.
|
||||||
|
|
||||||
A `KeyReport` is said to be "less than" another `KeyReport` iff its recorded
|
A `KeyReport` is said to be "less than" another `KeyReport` iff its recorded
|
||||||
timestamp is strictly less than the other report.
|
timestamp is strictly less than the other report.
|
||||||
"""
|
"""
|
||||||
if isinstance(other, KeyReport):
|
if isinstance(other, LocationReport):
|
||||||
return self.timestamp < other.timestamp
|
return self.timestamp < other.timestamp
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
@@ -164,38 +169,82 @@ class KeyReport:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def fetch_reports( # noqa: PLR0913
|
class LocationReportsFetcher:
|
||||||
dsid: str,
|
"""Fetcher class to retrieve location reports."""
|
||||||
search_party_token: str,
|
|
||||||
anisette_headers: dict[str, str],
|
def __init__(self, account: AsyncAppleAccount) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the fetcher.
|
||||||
|
|
||||||
|
:param account: Apple account.
|
||||||
|
"""
|
||||||
|
self._account: AsyncAppleAccount = account
|
||||||
|
|
||||||
|
self._http: HttpSession = HttpSession()
|
||||||
|
|
||||||
|
self._config: _FetcherConfig | None = None
|
||||||
|
|
||||||
|
def apply_config(self, **conf: Unpack[_FetcherConfig]) -> None:
|
||||||
|
"""Configure internal variables necessary to make reports fetching calls."""
|
||||||
|
self._config = conf
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def fetch_reports(
|
||||||
|
self,
|
||||||
|
date_from: datetime,
|
||||||
|
date_to: datetime,
|
||||||
|
device: KeyPair,
|
||||||
|
) -> list[LocationReport]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def fetch_reports(
|
||||||
|
self,
|
||||||
|
date_from: datetime,
|
||||||
|
date_to: datetime,
|
||||||
|
device: Sequence[KeyPair],
|
||||||
|
) -> dict[KeyPair, list[LocationReport]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def fetch_reports(
|
||||||
|
self,
|
||||||
|
date_from: datetime,
|
||||||
|
date_to: datetime,
|
||||||
|
device: KeyPair | Sequence[KeyPair],
|
||||||
|
) -> list[LocationReport] | dict[KeyPair, list[LocationReport]]:
|
||||||
|
"""
|
||||||
|
Fetch location reports for a certain device.
|
||||||
|
|
||||||
|
When ``device`` is a single :class:`.KeyPair`, this method will return
|
||||||
|
a list of location reports corresponding to that pair.
|
||||||
|
When ``device`` is a sequence of :class:`.KeyPair`s, it will return a dictionary
|
||||||
|
with the :class:`.KeyPair` as key, and a list of location reports as value.
|
||||||
|
"""
|
||||||
|
# single KeyPair
|
||||||
|
if isinstance(device, KeyPair):
|
||||||
|
return await self._fetch_reports(date_from, date_to, [device])
|
||||||
|
|
||||||
|
# sequence of KeyPairs
|
||||||
|
reports = await self._fetch_reports(date_from, date_to, device)
|
||||||
|
res: dict[KeyPair, list[LocationReport]] = {key: [] for key in device}
|
||||||
|
for report in reports:
|
||||||
|
res[report.key].append(report)
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def _fetch_reports(
|
||||||
|
self,
|
||||||
date_from: datetime,
|
date_from: datetime,
|
||||||
date_to: datetime,
|
date_to: datetime,
|
||||||
keys: Sequence[KeyPair],
|
keys: Sequence[KeyPair],
|
||||||
) -> dict[KeyPair, list[KeyReport]]:
|
) -> list[LocationReport]:
|
||||||
"""Look up reports for given `KeyPair`s."""
|
start_date = int(date_from.timestamp() * 1000)
|
||||||
start_date = date_from.timestamp() * 1000
|
end_date = int(date_to.timestamp() * 1000)
|
||||||
end_date = date_to.timestamp() * 1000
|
|
||||||
ids = [key.hashed_adv_key_b64 for key in keys]
|
ids = [key.hashed_adv_key_b64 for key in keys]
|
||||||
data = {"search": [{"startDate": start_date, "endDate": end_date, "ids": ids}]}
|
data = await self._account.fetch_raw_reports(start_date, end_date, ids)
|
||||||
|
|
||||||
# TODO(malmeloo): do not create a new session every time
|
|
||||||
# https://github.com/malmeloo/FindMy.py/issues/3
|
|
||||||
r = await _session.post(
|
|
||||||
"https://gateway.icloud.com/acsnservice/fetch",
|
|
||||||
auth=(dsid, search_party_token),
|
|
||||||
headers=anisette_headers,
|
|
||||||
json=data,
|
|
||||||
)
|
|
||||||
resp = r.json()
|
|
||||||
if not r.ok or resp["statusCode"] != "200":
|
|
||||||
msg = f"Failed to fetch reports: {resp['statusCode']}"
|
|
||||||
raise ReportsError(msg)
|
|
||||||
await _session.close()
|
|
||||||
|
|
||||||
reports: dict[KeyPair, list[KeyReport]] = {key: [] for key in keys}
|
|
||||||
id_to_key: dict[str, KeyPair] = {key.hashed_adv_key_b64: key for key in keys}
|
id_to_key: dict[str, KeyPair] = {key.hashed_adv_key_b64: key for key in keys}
|
||||||
|
reports: list[LocationReport] = []
|
||||||
for report in resp.get("results", []):
|
for report in data.get("results", []):
|
||||||
key = id_to_key[report["id"]]
|
key = id_to_key[report["id"]]
|
||||||
date_published = datetime.fromtimestamp(
|
date_published = datetime.fromtimestamp(
|
||||||
report.get("datePublished", 0) / 1000,
|
report.get("datePublished", 0) / 1000,
|
||||||
@@ -204,7 +253,6 @@ async def fetch_reports( # noqa: PLR0913
|
|||||||
description = report.get("description", "")
|
description = report.get("description", "")
|
||||||
payload = base64.b64decode(report["payload"])
|
payload = base64.b64decode(report["payload"])
|
||||||
|
|
||||||
r = KeyReport.from_payload(key, date_published, description, payload)
|
reports.append(LocationReport.from_payload(key, date_published, description, payload))
|
||||||
reports[key].append(r)
|
|
||||||
|
|
||||||
return {key: sorted(reps) for key, reps in reports.items()}
|
return reports
|
||||||
|
|||||||
Reference in New Issue
Block a user