mirror of
https://github.com/malmeloo/FindMy.py.git
synced 2026-04-17 21:53:57 +02:00
Enforce strict typing
This commit is contained in:
@@ -2,8 +2,13 @@ import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from findmy import AppleAccount, LoginState, SmsSecondFactor, RemoteAnisetteProvider
|
||||
from findmy import keys
|
||||
from findmy import (
|
||||
AppleAccount,
|
||||
LoginState,
|
||||
RemoteAnisetteProvider,
|
||||
SmsSecondFactor,
|
||||
keys,
|
||||
)
|
||||
|
||||
# URL to (public or local) anisette server
|
||||
ANISETTE_SERVER = "http://localhost:6969"
|
||||
@@ -50,7 +55,7 @@ def fetch_reports(lookup_key):
|
||||
|
||||
# Save / restore account logic
|
||||
if os.path.isfile("account.json"):
|
||||
with open("account.json", "r") as f:
|
||||
with open("account.json") as f:
|
||||
acc.restore(json.load(f))
|
||||
else:
|
||||
login(acc)
|
||||
|
||||
@@ -6,10 +6,10 @@ import os
|
||||
from findmy import (
|
||||
AsyncAppleAccount,
|
||||
LoginState,
|
||||
SmsSecondFactor,
|
||||
RemoteAnisetteProvider,
|
||||
SmsSecondFactor,
|
||||
keys,
|
||||
)
|
||||
from findmy import keys
|
||||
|
||||
# URL to (public or local) anisette server
|
||||
ANISETTE_SERVER = "http://localhost:6969"
|
||||
@@ -57,7 +57,7 @@ async def fetch_reports(lookup_key):
|
||||
try:
|
||||
# Save / restore account logic
|
||||
if os.path.isfile("account.json"):
|
||||
with open("account.json", "r") as f:
|
||||
with open("account.json") as f:
|
||||
acc.restore(json.load(f))
|
||||
else:
|
||||
await login(acc)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""A package providing everything you need to query Apple's FindMy network."""
|
||||
from .account import AppleAccount, AsyncAppleAccount, LoginState, SmsSecondFactor
|
||||
from .anisette import RemoteAnisetteProvider
|
||||
|
||||
__all__ = (
|
||||
AppleAccount,
|
||||
AsyncAppleAccount,
|
||||
LoginState,
|
||||
SmsSecondFactor,
|
||||
RemoteAnisetteProvider,
|
||||
"AppleAccount",
|
||||
"AsyncAppleAccount",
|
||||
"LoginState",
|
||||
"SmsSecondFactor",
|
||||
"RemoteAnisetteProvider",
|
||||
)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
"""Module containing most of the code necessary to interact with an Apple account."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
@@ -6,22 +9,23 @@ import json
|
||||
import logging
|
||||
import plistlib
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import wraps
|
||||
from typing import Optional, TypedDict, Any
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Callable, Sequence, TypedDict, TypeVar
|
||||
|
||||
import bs4
|
||||
import srp._pysrp as srp
|
||||
from cryptography.hazmat.primitives import padding, hashes
|
||||
from cryptography.hazmat.primitives import hashes, padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from .anisette import AnisetteProvider
|
||||
from .base import BaseAppleAccount, BaseSecondFactorMethod, LoginState
|
||||
from .http import HttpSession
|
||||
from .keys import KeyPair
|
||||
from .reports import fetch_reports
|
||||
from .reports import KeyReport, fetch_reports
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .anisette import BaseAnisetteProvider
|
||||
from .keys import KeyPair
|
||||
|
||||
logging.getLogger(__name__)
|
||||
|
||||
@@ -29,25 +33,28 @@ srp.rfc5054_enable()
|
||||
srp.no_username_in_x()
|
||||
|
||||
|
||||
class AccountInfo(TypedDict):
|
||||
class _AccountInfo(TypedDict):
|
||||
account_name: str
|
||||
first_name: str
|
||||
last_name: str
|
||||
|
||||
|
||||
class LoginException(Exception):
|
||||
pass
|
||||
class LoginError(Exception):
|
||||
"""Raised when an error occurs during login, such as when the password is incorrect."""
|
||||
|
||||
|
||||
class InvalidStateException(RuntimeError):
|
||||
pass
|
||||
class InvalidStateError(RuntimeError):
|
||||
"""Raised when a method is used that is in conflict with the internal account state.
|
||||
|
||||
For example: calling `BaseAppleAccount.login` while already logged in.
|
||||
"""
|
||||
|
||||
|
||||
class ExportRestoreError(ValueError):
|
||||
pass
|
||||
"""Raised when an error occurs while exporting or restoring the account's current state."""
|
||||
|
||||
|
||||
def _load_plist(data: bytes) -> Any:
|
||||
def _load_plist(data: bytes) -> Any: # noqa: ANN401
|
||||
plist_header = (
|
||||
b"<?xml version='1.0' encoding='UTF-8'?>"
|
||||
b"<!DOCTYPE plist PUBLIC '-//Apple//DTD PLIST 1.0//EN' 'http://www.apple.com/DTDs/PropertyList-1.0.dtd'>"
|
||||
@@ -89,20 +96,26 @@ def _extract_phone_numbers(html: str) -> list[dict]:
|
||||
soup = bs4.BeautifulSoup(html, features="html.parser")
|
||||
data_elem = soup.find("script", **{"class": "boot_args"})
|
||||
if not data_elem:
|
||||
raise RuntimeError("Could not find HTML element containing phone numbers")
|
||||
msg = "Could not find HTML element containing phone numbers"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
data = json.loads(data_elem.text)
|
||||
return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", [])
|
||||
|
||||
|
||||
def _require_login_state(*states: LoginState):
|
||||
def decorator(func):
|
||||
F = TypeVar("F", bound=Callable[[BaseAppleAccount, ...], Any])
|
||||
|
||||
|
||||
def _require_login_state(*states: LoginState) -> Callable[[F], F]:
|
||||
def decorator(func: F) -> F:
|
||||
@wraps(func)
|
||||
def wrapper(acc: "BaseAppleAccount", *args, **kwargs):
|
||||
def wrapper(acc: BaseAppleAccount, *args, **kwargs):
|
||||
if acc.login_state not in states:
|
||||
raise InvalidStateException(
|
||||
f"Invalid login state! Currently: {acc.login_state} but should be one of: {states}"
|
||||
msg = (
|
||||
f"Invalid login state! Currently: {acc.login_state}"
|
||||
f" but should be one of: {states}"
|
||||
)
|
||||
raise InvalidStateError(msg)
|
||||
|
||||
return func(acc, *args, **kwargs)
|
||||
|
||||
@@ -112,93 +125,164 @@ def _require_login_state(*states: LoginState):
|
||||
|
||||
|
||||
class AsyncSmsSecondFactor(BaseSecondFactorMethod):
|
||||
def __init__(self, account: "AsyncAppleAccount", number_id: int, phone_number: str):
|
||||
"""An async implementation of a second-factor method."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: AsyncAppleAccount,
|
||||
number_id: int,
|
||||
phone_number: str,
|
||||
) -> None:
|
||||
"""Initialize the second factor method.
|
||||
|
||||
Should not be done manually; use `BaseAppleAccount.get_2fa_methods` instead.
|
||||
"""
|
||||
super().__init__(account)
|
||||
|
||||
self._phone_number_id: int = number_id
|
||||
self._phone_number: str = phone_number
|
||||
|
||||
@property
|
||||
def phone_number_id(self):
|
||||
def phone_number_id(self) -> int:
|
||||
"""The phone number's ID. You most likely don't need this."""
|
||||
return self._phone_number_id
|
||||
|
||||
@property
|
||||
def phone_number(self):
|
||||
def phone_number(self) -> str:
|
||||
"""The 2FA method's phone number.
|
||||
|
||||
May be masked using unicode characters; should only be used for identification purposes.
|
||||
"""
|
||||
return self._phone_number
|
||||
|
||||
async def request(self):
|
||||
async def request(self) -> None:
|
||||
"""Request an SMS to the corresponding phone number containing a 2FA code."""
|
||||
return await self.account.sms_2fa_request(self._phone_number_id)
|
||||
|
||||
async def submit(self, code: str) -> LoginState:
|
||||
"""See `BaseSecondFactorMethod.submit`."""
|
||||
return await self.account.sms_2fa_submit(self._phone_number_id, code)
|
||||
|
||||
|
||||
class SmsSecondFactor(BaseSecondFactorMethod):
|
||||
def __init__(self, account: "AppleAccount", number_id: int, phone_number: str):
|
||||
"""A sync implementation of `BaseSecondFactorMethod`.
|
||||
|
||||
Uses `AsyncSmsSecondFactor` internally.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: AppleAccount,
|
||||
number_id: int,
|
||||
phone_number: str,
|
||||
) -> None:
|
||||
"""See `AsyncSmsSecondFactor.__init__`."""
|
||||
super().__init__(account)
|
||||
|
||||
self._phone_number_id: int = number_id
|
||||
self._phone_number: str = phone_number
|
||||
|
||||
@property
|
||||
def phone_number(self):
|
||||
def phone_number_id(self) -> int:
|
||||
"""See `AsyncSmsSecondFactor.phone_number_id`."""
|
||||
return self._phone_number_id
|
||||
|
||||
@property
|
||||
def phone_number(self) -> str:
|
||||
"""See `AsyncSmsSecondFactor.phone_number`."""
|
||||
return self._phone_number
|
||||
|
||||
def request(self) -> None:
|
||||
"""See `AsyncSmsSecondFactor.request`."""
|
||||
return self.account.sms_2fa_request(self._phone_number_id)
|
||||
|
||||
def submit(self, code: str) -> LoginState:
|
||||
"""See `AsyncSmsSecondFactor.submit`."""
|
||||
return self.account.sms_2fa_submit(self._phone_number_id, code)
|
||||
|
||||
|
||||
class AsyncAppleAccount(BaseAppleAccount):
|
||||
def __init__(self, anisette: AnisetteProvider, user_id: str = None, device_id: str = None):
|
||||
self._anisette: AnisetteProvider = anisette
|
||||
"""An async implementation of `BaseAppleAccount`."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
anisette: BaseAnisetteProvider,
|
||||
user_id: str | None = None,
|
||||
device_id: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the apple account.
|
||||
|
||||
:param anisette: An instance of `AsyncAnisetteProvider`.
|
||||
:param user_id: An optional user ID to use. Will be auto-generated if missing.
|
||||
:param device_id: An optional device ID to use. Will be auto-generated if missing.
|
||||
"""
|
||||
self._anisette: BaseAnisetteProvider = anisette
|
||||
self._uid: str = user_id or str(uuid.uuid4())
|
||||
self._devid: str = device_id or str(uuid.uuid4())
|
||||
|
||||
self._username: Optional[str] = None
|
||||
self._password: Optional[str] = None
|
||||
self._username: str | None = None
|
||||
self._password: str | None = None
|
||||
|
||||
self._login_state: LoginState = LoginState.LOGGED_OUT
|
||||
self._login_state_data: dict = {}
|
||||
|
||||
self._account_info: Optional[AccountInfo] = None
|
||||
self._account_info: _AccountInfo | None = None
|
||||
|
||||
self._http = HttpSession()
|
||||
|
||||
def _set_login_state(self, state: LoginState, data: Optional[dict] = None) -> LoginState:
|
||||
def _set_login_state(
|
||||
self,
|
||||
state: LoginState,
|
||||
data: dict | None = None,
|
||||
) -> LoginState:
|
||||
# clear account info if downgrading state (e.g. LOGGED_IN -> LOGGED_OUT)
|
||||
if state < self._login_state:
|
||||
logging.debug("Clearing cached account information")
|
||||
self._account_info = None
|
||||
|
||||
logging.info(f"Transitioning login state: {self._login_state} -> {state}")
|
||||
logging.info("Transitioning login state: %s -> %s", self._login_state, state)
|
||||
self._login_state = state
|
||||
self._login_state_data = data or {}
|
||||
|
||||
return state
|
||||
|
||||
@property
|
||||
def login_state(self):
|
||||
def login_state(self) -> LoginState:
|
||||
"""See `BaseAppleAccount.login_state`."""
|
||||
return self._login_state
|
||||
|
||||
@property
|
||||
@_require_login_state(LoginState.LOGGED_IN, LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA)
|
||||
def account_name(self):
|
||||
@_require_login_state(
|
||||
LoginState.LOGGED_IN,
|
||||
LoginState.AUTHENTICATED,
|
||||
LoginState.REQUIRE_2FA,
|
||||
)
|
||||
def account_name(self) -> str | None:
|
||||
"""See `BaseAppleAccount.account_name`."""
|
||||
return self._account_info["account_name"] if self._account_info else None
|
||||
|
||||
@property
|
||||
@_require_login_state(LoginState.LOGGED_IN, LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA)
|
||||
def first_name(self):
|
||||
@_require_login_state(
|
||||
LoginState.LOGGED_IN,
|
||||
LoginState.AUTHENTICATED,
|
||||
LoginState.REQUIRE_2FA,
|
||||
)
|
||||
def first_name(self) -> str | None:
|
||||
"""See `BaseAppleAccount.first_name`."""
|
||||
return self._account_info["first_name"] if self._account_info else None
|
||||
|
||||
@property
|
||||
@_require_login_state(LoginState.LOGGED_IN, LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA)
|
||||
def last_name(self):
|
||||
@_require_login_state(
|
||||
LoginState.LOGGED_IN,
|
||||
LoginState.AUTHENTICATED,
|
||||
LoginState.REQUIRE_2FA,
|
||||
)
|
||||
def last_name(self) -> str | None:
|
||||
"""See `BaseAppleAccount.last_name`."""
|
||||
return self._account_info["last_name"] if self._account_info else None
|
||||
|
||||
def export(self) -> dict:
|
||||
"""See `BaseAppleAccount.export`."""
|
||||
return {
|
||||
"ids": {"uid": self._uid, "devid": self._devid},
|
||||
"account": {
|
||||
@@ -212,7 +296,8 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
},
|
||||
}
|
||||
|
||||
def restore(self, data: dict):
|
||||
def restore(self, data: dict) -> None:
|
||||
"""See `BaseAppleAccount.restore`."""
|
||||
try:
|
||||
self._uid = data["ids"]["uid"]
|
||||
self._devid = data["ids"]["devid"]
|
||||
@@ -224,14 +309,20 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
self._login_state = LoginState(data["login_state"]["state"])
|
||||
self._login_state_data = data["login_state"]["data"]
|
||||
except KeyError as e:
|
||||
raise ExportRestoreError(f"Failed to restore account data: {e}")
|
||||
msg = f"Failed to restore account data: {e}"
|
||||
raise ExportRestoreError(msg) from None
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""Close any sessions or other resources in use by this object.
|
||||
|
||||
Should be called when the object will no longer be used.
|
||||
"""
|
||||
await self._anisette.close()
|
||||
await self._http.close()
|
||||
|
||||
@_require_login_state(LoginState.LOGGED_OUT)
|
||||
async def login(self, username: str, password: str) -> LoginState:
|
||||
"""See `BaseAppleAccount.login`."""
|
||||
# LOGGED_OUT -> (REQUIRE_2FA or AUTHENTICATED)
|
||||
new_state = await self._gsa_authenticate(username, password)
|
||||
if new_state == LoginState.REQUIRE_2FA: # pass control back to handle 2FA
|
||||
@@ -242,30 +333,40 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
|
||||
@_require_login_state(LoginState.REQUIRE_2FA)
|
||||
async def get_2fa_methods(self) -> list[BaseSecondFactorMethod]:
|
||||
"""See `BaseAppleAccount.get_2fa_methods`."""
|
||||
methods: list[BaseSecondFactorMethod] = []
|
||||
|
||||
# sms
|
||||
auth_page = await self._sms_2fa_request("GET", "https://gsa.apple.com/auth")
|
||||
try:
|
||||
phone_numbers = _extract_phone_numbers(auth_page)
|
||||
methods.extend(
|
||||
AsyncSmsSecondFactor(
|
||||
self,
|
||||
number.get("id"),
|
||||
number.get("numberWithDialCode"),
|
||||
)
|
||||
for number in phone_numbers
|
||||
)
|
||||
except RuntimeError:
|
||||
logging.warning("Unable to extract phone numbers from login page")
|
||||
|
||||
methods.extend(
|
||||
AsyncSmsSecondFactor(self, number.get("id"), number.get("numberWithDialCode"))
|
||||
for number in phone_numbers
|
||||
)
|
||||
|
||||
return methods
|
||||
|
||||
@_require_login_state(LoginState.REQUIRE_2FA)
|
||||
async def sms_2fa_request(self, phone_number_id: int):
|
||||
async def sms_2fa_request(self, phone_number_id: int) -> None:
|
||||
"""See `BaseAppleAccount.sms_2fa_request`."""
|
||||
data = {"phoneNumber": {"id": phone_number_id}, "mode": "sms"}
|
||||
|
||||
await self._sms_2fa_request("PUT", "https://gsa.apple.com/auth/verify/phone", data)
|
||||
await self._sms_2fa_request(
|
||||
"PUT",
|
||||
"https://gsa.apple.com/auth/verify/phone",
|
||||
data,
|
||||
)
|
||||
|
||||
@_require_login_state(LoginState.REQUIRE_2FA)
|
||||
async def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState:
|
||||
"""See `BaseAppleAccount.sms_2fa_submit`."""
|
||||
data = {
|
||||
"phoneNumber": {"id": phone_number_id},
|
||||
"securityCode": {"code": str(code)},
|
||||
@@ -273,19 +374,28 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
}
|
||||
|
||||
await self._sms_2fa_request(
|
||||
"POST", "https://gsa.apple.com/auth/verify/phone/securitycode", data
|
||||
"POST",
|
||||
"https://gsa.apple.com/auth/verify/phone/securitycode",
|
||||
data,
|
||||
)
|
||||
|
||||
# REQUIRE_2FA -> AUTHENTICATED
|
||||
new_state = await self._gsa_authenticate()
|
||||
if new_state != LoginState.AUTHENTICATED:
|
||||
raise LoginException(f"Unexpected state after submitting 2FA: {new_state}")
|
||||
msg = f"Unexpected state after submitting 2FA: {new_state}"
|
||||
raise LoginError(msg)
|
||||
|
||||
# AUTHENTICATED -> LOGGED_IN
|
||||
return await self._login_mobileme()
|
||||
|
||||
@_require_login_state(LoginState.LOGGED_IN)
|
||||
async def fetch_reports(self, keys: Sequence[KeyPair], date_from: datetime, date_to: datetime):
|
||||
async def fetch_reports(
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
"""See `BaseAppleAccount.fetch_reports`."""
|
||||
anisette_headers = await self.get_anisette_headers()
|
||||
|
||||
return await fetch_reports(
|
||||
@@ -302,57 +412,67 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
hours: int = 7 * 24,
|
||||
):
|
||||
end = datetime.now()
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
"""See `BaseAppleAccount.fetch_last_reports`."""
|
||||
end = datetime.now(tz=timezone.utc)
|
||||
start = end - timedelta(hours=hours)
|
||||
|
||||
return await self.fetch_reports(keys, start, end)
|
||||
|
||||
@_require_login_state(LoginState.LOGGED_OUT, LoginState.REQUIRE_2FA)
|
||||
async def _gsa_authenticate(
|
||||
self, username: Optional[str] = None, password: Optional[str] = None
|
||||
self,
|
||||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
) -> LoginState:
|
||||
# use stored values for re-authentication
|
||||
self._username = username or self._username
|
||||
self._password = password or self._password
|
||||
|
||||
logging.info(f"Attempting authentication for user {self._username}")
|
||||
logging.info("Attempting authentication for user %s", self._username)
|
||||
|
||||
if not self._username or not self._password:
|
||||
raise ValueError("No username or password to log in")
|
||||
msg = "No username or password specified"
|
||||
raise ValueError(msg)
|
||||
|
||||
logging.debug("Starting authentication with username")
|
||||
|
||||
usr = srp.User(self._username, b"", hash_alg=srp.SHA256, ng_type=srp.NG_2048)
|
||||
_, a2k = usr.start_authentication()
|
||||
r = await self._gsa_request(
|
||||
{"A2k": a2k, "u": self._username, "ps": ["s2k", "s2k_fo"], "o": "init"}
|
||||
{"A2k": a2k, "u": self._username, "ps": ["s2k", "s2k_fo"], "o": "init"},
|
||||
)
|
||||
|
||||
logging.debug("Verifying response to auth request")
|
||||
|
||||
if r["Status"].get("ec") != 0:
|
||||
message = r["Status"].get("em")
|
||||
raise LoginException(f"Email verify failed: {message}")
|
||||
msg = "Email verify failed: " + r["Status"].get("em")
|
||||
raise LoginError(msg)
|
||||
sp = r.get("sp")
|
||||
if sp != "s2k":
|
||||
raise LoginException(f"This implementation only supports s2k. Server returned {sp}")
|
||||
msg = f"This implementation only supports s2k. Server returned {sp}"
|
||||
raise LoginError(msg)
|
||||
|
||||
logging.debug("Attempting password challenge")
|
||||
|
||||
usr.p = _encrypt_password(self._password, r["s"], r["i"])
|
||||
m1 = usr.process_challenge(r["s"], r["B"])
|
||||
if m1 is None:
|
||||
raise LoginException("Failed to process challenge")
|
||||
r = await self._gsa_request({"c": r["c"], "M1": m1, "u": self._username, "o": "complete"})
|
||||
msg = "Failed to process challenge"
|
||||
raise LoginError(msg)
|
||||
r = await self._gsa_request(
|
||||
{"c": r["c"], "M1": m1, "u": self._username, "o": "complete"},
|
||||
)
|
||||
|
||||
logging.debug("Verifying password challenge response")
|
||||
|
||||
if r["Status"].get("ec") != 0:
|
||||
message = r["Status"].get("em")
|
||||
raise LoginException(f"Password authentication failed: {message}")
|
||||
msg = "Password authentication failed: " + r["Status"].get("em")
|
||||
raise LoginError(msg)
|
||||
usr.verify_session(r.get("M2"))
|
||||
if not usr.authenticated():
|
||||
raise LoginException("Failed to verify session")
|
||||
msg = "Failed to verify session"
|
||||
raise LoginError(msg)
|
||||
|
||||
logging.debug("Decrypting SPD data in response")
|
||||
|
||||
@@ -360,33 +480,36 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
spd = _load_plist(spd)
|
||||
|
||||
logging.debug("Received account information")
|
||||
self._account_info: AccountInfo = {
|
||||
self._account_info: _AccountInfo = {
|
||||
"account_name": spd.get("acname"),
|
||||
"first_name": spd.get("fn"),
|
||||
"last_name": spd.get("ln"),
|
||||
}
|
||||
|
||||
# TODO: support trusted device auth (need account to test)
|
||||
# TODO(malmeloo): support trusted device auth (need account to test)
|
||||
# https://github.com/malmeloo/FindMy.py/issues/1
|
||||
au = r["Status"].get("au")
|
||||
if au in ("secondaryAuth",):
|
||||
logging.info(f"Detected 2FA requirement: {au}")
|
||||
logging.info("Detected 2FA requirement: %s", au)
|
||||
|
||||
return self._set_login_state(
|
||||
LoginState.REQUIRE_2FA,
|
||||
{"adsid": spd["adsid"], "idms_token": spd["GsIdmsToken"]},
|
||||
)
|
||||
elif au is not None:
|
||||
raise LoginException(f"Unknown auth value: {au}")
|
||||
if au is not None:
|
||||
msg = f"Unknown auth value: {au}"
|
||||
raise LoginError(msg)
|
||||
|
||||
logging.info("GSA authentication successful")
|
||||
|
||||
idms_pet = spd.get("t", {}).get("com.apple.gs.idms.pet", {}).get("token", "")
|
||||
return self._set_login_state(
|
||||
LoginState.AUTHENTICATED, {"idms_pet": idms_pet, "adsid": spd["adsid"]}
|
||||
LoginState.AUTHENTICATED,
|
||||
{"idms_pet": idms_pet, "adsid": spd["adsid"]},
|
||||
)
|
||||
|
||||
@_require_login_state(LoginState.AUTHENTICATED)
|
||||
async def _login_mobileme(self):
|
||||
async def _login_mobileme(self) -> LoginState:
|
||||
logging.info("Logging into com.apple.mobileme")
|
||||
data = plistlib.dumps(
|
||||
{
|
||||
@@ -394,7 +517,7 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
"delegates": {"com.apple.mobileme": {}},
|
||||
"password": self._login_state_data["idms_pet"],
|
||||
"client-id": self._uid,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
headers = {
|
||||
@@ -417,15 +540,21 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
mobileme_data = resp.get("delegates", {}).get("com.apple.mobileme", {})
|
||||
status = mobileme_data.get("status")
|
||||
if status != 0:
|
||||
message = mobileme_data.get("status-message")
|
||||
raise LoginException(f"com.apple.mobileme login failed with status {status}: {message}")
|
||||
status_message = mobileme_data.get("status-message")
|
||||
msg = f"com.apple.mobileme login failed with status {status}: {status_message}"
|
||||
raise LoginError(msg)
|
||||
|
||||
return self._set_login_state(
|
||||
LoginState.LOGGED_IN,
|
||||
{"dsid": resp["dsid"], "mobileme_data": mobileme_data["service-data"]},
|
||||
)
|
||||
|
||||
async def _sms_2fa_request(self, method: str, url: str, data: Optional[dict] = None) -> str:
|
||||
async def _sms_2fa_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
data: dict | None = None,
|
||||
) -> str:
|
||||
adsid = self._login_state_data["adsid"]
|
||||
idms_token = self._login_state_data["idms_token"]
|
||||
identity_token = base64.b64encode((adsid + ":" + idms_token).encode()).decode()
|
||||
@@ -441,13 +570,19 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
}
|
||||
headers.update(await self.get_anisette_headers())
|
||||
|
||||
async with await self._http.request(method, url, json=data, headers=headers) as r:
|
||||
async with await self._http.request(
|
||||
method,
|
||||
url,
|
||||
json=data,
|
||||
headers=headers,
|
||||
) as r:
|
||||
if not r.ok:
|
||||
raise LoginException(f"HTTP request failed: {r.status_code}")
|
||||
msg = f"HTTP request failed: {r.status_code}"
|
||||
raise LoginError(msg)
|
||||
|
||||
return await r.text()
|
||||
|
||||
async def _gsa_request(self, params):
|
||||
async def _gsa_request(self, params: dict[str, Any]) -> Any:
|
||||
request_data = {
|
||||
"cpd": {
|
||||
"bootstrap": True,
|
||||
@@ -455,7 +590,7 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
"pbe": False,
|
||||
"prkgen": True,
|
||||
"svct": "iCloud",
|
||||
}
|
||||
},
|
||||
}
|
||||
request_data["cpd"].update(await self.get_anisette_headers())
|
||||
request_data.update(params)
|
||||
@@ -482,11 +617,23 @@ class AsyncAppleAccount(BaseAppleAccount):
|
||||
return _load_plist(content)["Response"]
|
||||
|
||||
async def get_anisette_headers(self, serial: str = "0") -> dict[str, str]:
|
||||
"""See `BaseAppleAccount.get_anisette_headers`."""
|
||||
return await self._anisette.get_headers(self._uid, self._devid, serial)
|
||||
|
||||
|
||||
class AppleAccount(BaseAppleAccount):
|
||||
def __init__(self, anisette: AnisetteProvider, user_id: str = None, device_id: str = None):
|
||||
"""A sync implementation of `BaseappleAccount`.
|
||||
|
||||
Uses `AsyncappleAccount` internally.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
anisette: BaseAnisetteProvider,
|
||||
user_id: str | None = None,
|
||||
device_id: str | None = None,
|
||||
) -> None:
|
||||
"""See `AsyncAppleAccount.__init__`."""
|
||||
self._asyncacc = AsyncAppleAccount(anisette, user_id, device_id)
|
||||
|
||||
try:
|
||||
@@ -496,36 +643,45 @@ class AppleAccount(BaseAppleAccount):
|
||||
asyncio.set_event_loop(self._loop)
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Gracefully close the async instance's session when garbage collected."""
|
||||
coro = self._asyncacc.close()
|
||||
return self._loop.run_until_complete(coro)
|
||||
|
||||
@property
|
||||
def login_state(self):
|
||||
def login_state(self) -> LoginState:
|
||||
"""See `AsyncAppleAccount.login_state`."""
|
||||
return self._asyncacc.login_state
|
||||
|
||||
@property
|
||||
def account_name(self):
|
||||
def account_name(self) -> str:
|
||||
"""See `AsyncAppleAccount.login_state`."""
|
||||
return self._asyncacc.account_name
|
||||
|
||||
@property
|
||||
def first_name(self):
|
||||
def first_name(self) -> str | None:
|
||||
"""See `AsyncAppleAccount.first_name`."""
|
||||
return self._asyncacc.first_name
|
||||
|
||||
@property
|
||||
def last_name(self):
|
||||
def last_name(self) -> str | None:
|
||||
"""See `AsyncAppleAccount.last_name`."""
|
||||
return self._asyncacc.last_name
|
||||
|
||||
def export(self) -> dict:
|
||||
"""See `AsyncAppleAccount.export`."""
|
||||
return self._asyncacc.export()
|
||||
|
||||
def restore(self, data: dict):
|
||||
def restore(self, data: dict) -> None:
|
||||
"""See `AsyncAppleAccount.restore`."""
|
||||
return self._asyncacc.restore(data)
|
||||
|
||||
def login(self, username: str, password: str) -> LoginState:
|
||||
"""See `AsyncAppleAccount.login`."""
|
||||
coro = self._asyncacc.login(username, password)
|
||||
return self._loop.run_until_complete(coro)
|
||||
|
||||
def get_2fa_methods(self) -> list[BaseSecondFactorMethod]:
|
||||
"""See `AsyncAppleAccount.get_2fa_methods`."""
|
||||
coro = self._asyncacc.get_2fa_methods()
|
||||
methods = self._loop.run_until_complete(coro)
|
||||
|
||||
@@ -534,28 +690,44 @@ class AppleAccount(BaseAppleAccount):
|
||||
if isinstance(m, AsyncSmsSecondFactor):
|
||||
res.append(SmsSecondFactor(self, m.phone_number_id, m.phone_number))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Failed to cast 2FA object to sync alternative: {m}. This is a bug, please report it."
|
||||
msg = (
|
||||
f"Failed to cast 2FA object to sync alternative: {m}."
|
||||
f" This is a bug, please report it."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
|
||||
return res
|
||||
|
||||
def sms_2fa_request(self, phone_number_id: int):
|
||||
def sms_2fa_request(self, phone_number_id: int) -> None:
|
||||
"""See `AsyncAppleAccount.sms_2fa_request`."""
|
||||
coro = self._asyncacc.sms_2fa_request(phone_number_id)
|
||||
return self._loop.run_until_complete(coro)
|
||||
|
||||
def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState:
|
||||
"""See `AsyncAppleAccount.sms_2fa_submit`."""
|
||||
coro = self._asyncacc.sms_2fa_submit(phone_number_id, code)
|
||||
return self._loop.run_until_complete(coro)
|
||||
|
||||
def fetch_reports(self, keys: Sequence[KeyPair], date_from: datetime, date_to: datetime):
|
||||
def fetch_reports(
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
"""See `AsyncAppleAccount.fetch_reports`."""
|
||||
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
|
||||
return self._loop.run_until_complete(coro)
|
||||
|
||||
def fetch_last_reports(self, keys: Sequence[KeyPair], hours: int = 7 * 24):
|
||||
def fetch_last_reports(
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
hours: int = 7 * 24,
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
"""See `AsyncAppleAccount.fetch_last_reports`."""
|
||||
coro = self._asyncacc.fetch_last_reports(keys, hours)
|
||||
return self._loop.run_until_complete(coro)
|
||||
|
||||
def get_anisette_headers(self, serial: str = "0") -> dict[str, str]:
|
||||
coro = self._asyncacc.get_anisette_headers()
|
||||
"""See `AsyncAppleAccount.get_anisette_headers`."""
|
||||
coro = self._asyncacc.get_anisette_headers(serial)
|
||||
return self._loop.run_until_complete(coro)
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
"""Module for Anisette header providers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import locale
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from .http import HttpSession
|
||||
|
||||
|
||||
def _gen_meta_headers(user_id: str, device_id: str, serial: str = "0"):
|
||||
now = datetime.utcnow()
|
||||
def _gen_meta_headers(
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
serial: str = "0",
|
||||
) -> dict[str, str]:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
locale_str = locale.getdefaultlocale()[0] or "en_US"
|
||||
|
||||
return {
|
||||
@@ -23,29 +30,44 @@ def _gen_meta_headers(user_id: str, device_id: str, serial: str = "0"):
|
||||
}
|
||||
|
||||
|
||||
class AnisetteProvider(ABC):
|
||||
class BaseAnisetteProvider(ABC):
|
||||
"""Abstract base class for Anisette providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def _get_base_headers(self) -> dict[str, str]:
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
return NotImplemented
|
||||
async def close(self) -> None:
|
||||
"""Close any underlying sessions. Call when the provider will no longer be used."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_headers(self, user_id: str, device_id: str, serial: str = "0") -> dict[str, str]:
|
||||
async def get_headers(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
serial: str = "0",
|
||||
) -> dict[str, str]:
|
||||
"""Retrieve a complete dictionary of Anisette headers.
|
||||
|
||||
Consider using `BaseAppleAccount.get_anisette_headers` instead.
|
||||
"""
|
||||
base_headers = await self._get_base_headers()
|
||||
base_headers.update(_gen_meta_headers(user_id, device_id, serial))
|
||||
|
||||
return base_headers
|
||||
|
||||
|
||||
class RemoteAnisetteProvider(AnisetteProvider):
|
||||
def __init__(self, server_url: str):
|
||||
class RemoteAnisetteProvider(BaseAnisetteProvider):
|
||||
"""Anisette provider. Fetches headers from a remote Anisette server."""
|
||||
|
||||
def __init__(self, server_url: str) -> None:
|
||||
"""Initialize the provider with URL to te remote server."""
|
||||
self._server_url = server_url
|
||||
|
||||
self._http = HttpSession()
|
||||
|
||||
logging.info(f"Using remote anisette server: {self._server_url}")
|
||||
logging.info("Using remote anisette server: %s", self._server_url)
|
||||
|
||||
async def _get_base_headers(self) -> dict[str, str]:
|
||||
async with await self._http.get(self._server_url) as r:
|
||||
@@ -56,17 +78,21 @@ class RemoteAnisetteProvider(AnisetteProvider):
|
||||
"X-Apple-I-MD-M": headers["X-Apple-I-MD-M"],
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""See `AnisetteProvider.close`."""
|
||||
await self._http.close()
|
||||
|
||||
|
||||
# TODO: implement using pyprovision
|
||||
class LocalAnisetteProvider(AnisetteProvider):
|
||||
def __init__(self):
|
||||
pass
|
||||
# TODO(malmeloo): implement using pyprovision
|
||||
# https://github.com/malmeloo/FindMy.py/issues/2
|
||||
class LocalAnisetteProvider(BaseAnisetteProvider):
|
||||
"""Anisette provider. Generates headers without a remote server using pyprovision."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the provider."""
|
||||
|
||||
async def _get_base_headers(self) -> dict[str, str]:
|
||||
return NotImplemented
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
async def close(self) -> None:
|
||||
"""See `AnisetteProvider.close`."""
|
||||
|
||||
154
findmy/base.py
154
findmy/base.py
@@ -1,101 +1,185 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Sequence
|
||||
"""Module that contains base classes for various other modules. For internal use only."""
|
||||
from __future__ import annotations
|
||||
|
||||
from .keys import KeyPair
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Sequence, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
from .keys import KeyPair
|
||||
from .reports import KeyReport
|
||||
|
||||
|
||||
class LoginState(Enum):
|
||||
"""Enum of possible login states. Used for `AppleAccount`'s internal state machine."""
|
||||
|
||||
LOGGED_OUT = 0
|
||||
REQUIRE_2FA = 1
|
||||
AUTHENTICATED = 2
|
||||
LOGGED_IN = 3
|
||||
|
||||
def __lt__(self, other):
|
||||
def __lt__(self, other: LoginState) -> bool:
|
||||
"""Compare against another `LoginState`.
|
||||
|
||||
A `LoginState` is said to be "less than" another `LoginState` iff it is in
|
||||
an "earlier" stage of the login process, going from LOGGED_OUT to LOGGED_IN.
|
||||
"""
|
||||
if isinstance(other, LoginState):
|
||||
return self.value < other.value
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Human-readable string representation of the state."""
|
||||
return self.__str__()
|
||||
|
||||
|
||||
T = TypeVar("T", bound="BaseAppleAccount")
|
||||
|
||||
|
||||
class BaseSecondFactorMethod(ABC):
|
||||
def __init__(self, account: "BaseAppleAccount"):
|
||||
self._account = account
|
||||
"""Base class for a second-factor authentication method for an Apple account."""
|
||||
|
||||
def __init__(self, account: T) -> None:
|
||||
"""Initialize the second-factor method."""
|
||||
self._account: T = account
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
def account(self) -> T:
|
||||
"""The account associated with the second-factor method."""
|
||||
return self._account
|
||||
|
||||
@abstractmethod
|
||||
def request(self) -> None:
|
||||
raise NotImplementedError()
|
||||
"""Put in a request for the second-factor challenge.
|
||||
|
||||
Exact meaning is up to the implementing class.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def submit(self, code: str) -> LoginState:
|
||||
raise NotImplementedError()
|
||||
"""Submit a code to complete the second-factor challenge."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseAppleAccount(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def login_state(self):
|
||||
return NotImplemented
|
||||
"""Base class for an Apple account."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def account_name(self):
|
||||
return NotImplemented
|
||||
def login_state(self) -> LoginState:
|
||||
"""The current login state of the account."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def first_name(self):
|
||||
return NotImplemented
|
||||
def account_name(self) -> str:
|
||||
"""The name of the account as reported by Apple.
|
||||
|
||||
This is usually an e-mail address.
|
||||
May be None in some cases, such as when not logged in.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def last_name(self):
|
||||
return NotImplemented
|
||||
def first_name(self) -> str | None:
|
||||
"""First name of the account holder as reported by Apple.
|
||||
|
||||
May be None in some cases, such as when not logged in.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def last_name(self) -> str | None:
|
||||
"""Last name of the account holder as reported by Apple.
|
||||
|
||||
May be None in some cases, such as when not logged in.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def export(self) -> dict:
|
||||
return NotImplemented
|
||||
"""Export a representation of the current state of the account as a dictionary.
|
||||
|
||||
The output of this method is guaranteed to be JSON-serializable, and passing
|
||||
the return value of this function as an argument to `BaseAppleAccount.restore`
|
||||
will always result in an exact copy of the internal state as it was when exported.
|
||||
|
||||
This method is especially useful to avoid having to keep going through the login flow.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def restore(self, data: dict):
|
||||
return NotImplemented
|
||||
def restore(self, data: dict) -> None:
|
||||
"""Restore a previous export of the internal state of the account.
|
||||
|
||||
See `BaseAppleAccount.export` for more information.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def login(self, username: str, password: str) -> LoginState:
|
||||
return NotImplemented
|
||||
"""Log in to an Apple account using a username and password."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_2fa_methods(self) -> list[BaseSecondFactorMethod]:
|
||||
return NotImplemented
|
||||
"""Get a list of 2FA methods that can be used as a secondary challenge.
|
||||
|
||||
Currently, only SMS-based 2FA methods are supported.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def sms_2fa_request(self, phone_number_id: int):
|
||||
return NotImplemented
|
||||
def sms_2fa_request(self, phone_number_id: int) -> None:
|
||||
"""Request a 2FA code to be sent to a specific phone number ID.
|
||||
|
||||
Consider using `BaseSecondFactorMethod.request` instead.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState:
|
||||
return NotImplemented
|
||||
"""Submit a 2FA code that was sent to a specific phone number ID.
|
||||
|
||||
Consider using `BaseSecondFactorMethod.submit` instead.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def fetch_reports(self, keys: Sequence[KeyPair], date_from: datetime, date_to: datetime):
|
||||
return NotImplemented
|
||||
def fetch_reports(
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
"""Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`.
|
||||
|
||||
Returns a dictionary mapping `KeyPair`s to a list of their location reports.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def fetch_last_reports(
|
||||
self,
|
||||
keys: Sequence[KeyPair],
|
||||
hours: int = 7 * 24,
|
||||
):
|
||||
return NotImplemented
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
"""Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours.
|
||||
|
||||
Utility method as an alternative to using `BaseAppleAccount.fetch_reports` directly.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_anisette_headers(self, serial: str = "0") -> dict[str, str]:
|
||||
return NotImplemented
|
||||
"""Retrieve a complete dictionary of Anisette headers.
|
||||
|
||||
Utility method for `AnisetteProvider.get_headers` using this account's user and device ID.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,45 +1,78 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
"""Module to simplify asynchronous HTTP calls. For internal use only."""
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import ClientSession, BasicAuth, ClientTimeout
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import BasicAuth, ClientResponse, ClientSession, ClientTimeout
|
||||
|
||||
logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HttpSession:
|
||||
def __init__(self):
|
||||
self._session: Optional[ClientSession] = None
|
||||
class HttpResponse:
|
||||
"""Response of a request made by `HttpSession`."""
|
||||
|
||||
async def _ensure_session(self):
|
||||
def __init__(self, resp: ClientResponse) -> None:
|
||||
"""Initialize the response."""
|
||||
self._resp: ClientResponse = resp
|
||||
|
||||
|
||||
class HttpSession:
|
||||
"""Asynchronous HTTP session manager. For internal use only."""
|
||||
|
||||
def __init__(self) -> None: # noqa: D107
|
||||
self._session: ClientSession | None = None
|
||||
|
||||
async def _ensure_session(self) -> None:
|
||||
if self._session is None:
|
||||
logging.debug("Creating aiohttp session")
|
||||
self._session = ClientSession(timeout=ClientTimeout(total=5))
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying session. Should be called when session will no longer be used."""
|
||||
if self._session is not None:
|
||||
logging.debug("Closing aiohttp session")
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Attempt to gracefully close the session.
|
||||
|
||||
Ideally this should be done by manually calling close().
|
||||
"""
|
||||
if self._session is None:
|
||||
return
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.call_soon_threadsafe(loop.create_task, self.close())
|
||||
except RuntimeError: # cannot await closure
|
||||
pass
|
||||
|
||||
async def request(self, method: str, url: str, auth: tuple[str] = None, **kwargs):
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
auth: tuple[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ClientResponse:
|
||||
"""Make an HTTP request.
|
||||
|
||||
Keyword arguments will directly be passed to `aiohttp.ClientSession.request`.
|
||||
"""
|
||||
await self._ensure_session()
|
||||
|
||||
basic_auth = None
|
||||
if auth is not None:
|
||||
basic_auth = BasicAuth(auth[0], auth[1])
|
||||
|
||||
return self._session.request(method, url, auth=basic_auth, ssl=False, **kwargs)
|
||||
return await self._session.request(method, url, auth=basic_auth, ssl=False, **kwargs)
|
||||
|
||||
async def get(self, url: str, **kwargs):
|
||||
async def get(self, url: str, **kwargs: Any) -> ClientResponse:
|
||||
"""Alias for `HttpSession.request("GET", ...)`."""
|
||||
return await self.request("GET", url, **kwargs)
|
||||
|
||||
async def post(self, url: str, **kwargs):
|
||||
async def post(self, url: str, **kwargs: Any) -> ClientResponse:
|
||||
"""Alias for `HttpSession.request("POST", ...)`."""
|
||||
return await self.request("POST", url, **kwargs)
|
||||
|
||||
@@ -1,49 +1,73 @@
|
||||
"""Module to work with private and public keys as used in FindMy accessories."""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
import hashlib
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
|
||||
|
||||
class KeyPair:
|
||||
def __init__(self, private_key: bytes):
|
||||
"""A private-public keypair for a trackable FindMy accessory."""
|
||||
|
||||
def __init__(self, private_key: bytes) -> None:
|
||||
"""Initialize the `KeyPair` with the private key bytes."""
|
||||
priv_int = int.from_bytes(private_key, "big")
|
||||
self._priv_key = ec.derive_private_key(priv_int, ec.SECP224R1(), default_backend())
|
||||
self._priv_key = ec.derive_private_key(
|
||||
priv_int,
|
||||
ec.SECP224R1(),
|
||||
default_backend(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate(cls) -> "KeyPair":
|
||||
"""Generate a new random `KeyPair`."""
|
||||
return cls(secrets.token_bytes(28))
|
||||
|
||||
@classmethod
|
||||
def from_b64(cls, key_b64: str) -> "KeyPair":
|
||||
"""Import an existing `KeyPair` from its base64-encoded representation.
|
||||
|
||||
Same format as returned by `KeyPair.private_key_b64`.
|
||||
"""
|
||||
return cls(base64.b64decode(key_b64))
|
||||
|
||||
@property
|
||||
def private_key_bytes(self) -> bytes:
|
||||
"""Return the private key as bytes."""
|
||||
key_bytes = self._priv_key.private_numbers().private_value
|
||||
return int.to_bytes(key_bytes, 28, "big")
|
||||
|
||||
@property
|
||||
def private_key_b64(self) -> str:
|
||||
"""Return the private key as a base64-encoded string.
|
||||
|
||||
Can be re-imported using `KeyPair.from_b64`.
|
||||
"""
|
||||
return base64.b64encode(self.private_key_bytes).decode("ascii")
|
||||
|
||||
@property
|
||||
def adv_key_bytes(self) -> bytes:
|
||||
"""Return the advertised (public) key as bytes."""
|
||||
key_bytes = self._priv_key.public_key().public_numbers().x
|
||||
return int.to_bytes(key_bytes, 28, "big")
|
||||
|
||||
@property
|
||||
def adv_key_b64(self) -> str:
|
||||
"""Return the advertised (public) key as a base64-encoded string."""
|
||||
return base64.b64encode(self.adv_key_bytes).decode("ascii")
|
||||
|
||||
@property
|
||||
def hashed_adv_key_bytes(self) -> bytes:
|
||||
"""Return the hashed advertised (public) key as bytes."""
|
||||
return hashlib.sha256(self.adv_key_bytes).digest()
|
||||
|
||||
@property
|
||||
def hashed_adv_key_b64(self) -> str:
|
||||
"""Return the hashed advertised (public) key as a base64-encoded string."""
|
||||
return base64.b64encode(self.hashed_adv_key_bytes).decode("ascii")
|
||||
|
||||
def dh_exchange(self, other_pub_key: ec.EllipticCurvePublicKey) -> bytes:
|
||||
"""Do a Diffie-Hellman key exchange using another EC public key."""
|
||||
return self._priv_key.exchange(ec.ECDH(), other_pub_key)
|
||||
|
||||
@@ -1,27 +1,37 @@
|
||||
"""Module providing functionality to look up location reports."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import struct
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Sequence
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from .keys import KeyPair
|
||||
from .http import HttpSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .keys import KeyPair
|
||||
|
||||
_session = HttpSession()
|
||||
|
||||
|
||||
class ReportsError(RuntimeError):
|
||||
pass
|
||||
"""Raised when an error occurs while looking up reports."""
|
||||
|
||||
|
||||
def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
|
||||
eph_key = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP224R1(), payload[5:62])
|
||||
eph_key = ec.EllipticCurvePublicKey.from_encoded_point(
|
||||
ec.SECP224R1(),
|
||||
payload[5:62],
|
||||
)
|
||||
shared_key = key.dh_exchange(eph_key)
|
||||
symmetric_key = hashlib.sha256(shared_key + b"\x00\x00\x00\x01" + payload[5:62]).digest()
|
||||
symmetric_key = hashlib.sha256(
|
||||
shared_key + b"\x00\x00\x00\x01" + payload[5:62],
|
||||
).digest()
|
||||
|
||||
decryption_key = symmetric_key[:16]
|
||||
iv = symmetric_key[16:]
|
||||
@@ -29,13 +39,17 @@ def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
|
||||
tag = payload[72:]
|
||||
|
||||
decryptor = Cipher(
|
||||
algorithms.AES(decryption_key), modes.GCM(iv, tag), default_backend()
|
||||
algorithms.AES(decryption_key),
|
||||
modes.GCM(iv, tag),
|
||||
default_backend(),
|
||||
).decryptor()
|
||||
return decryptor.update(enc_data) + decryptor.finalize()
|
||||
|
||||
|
||||
class KeyReport:
|
||||
def __init__(
|
||||
"""Location report corresponding to a certain `KeyPair`."""
|
||||
|
||||
def __init__( # noqa: PLR0913
|
||||
self,
|
||||
key: KeyPair,
|
||||
publish_date: datetime,
|
||||
@@ -45,7 +59,8 @@ class KeyReport:
|
||||
lng: float,
|
||||
confidence: int,
|
||||
status: int,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize a `KeyReport`. You should probably use `KeyReport.from_payload` instead."""
|
||||
self._key = key
|
||||
self._publish_date = publish_date
|
||||
self._timestamp = timestamp
|
||||
@@ -58,43 +73,59 @@ class KeyReport:
|
||||
self._status = status
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
def key(self) -> KeyPair:
|
||||
"""The `KeyPair` corresponding to this location report."""
|
||||
return self._key
|
||||
|
||||
@property
|
||||
def published_at(self):
|
||||
def published_at(self) -> datetime:
|
||||
"""The `datetime` when this report was published by a device."""
|
||||
return self._publish_date
|
||||
|
||||
@property
|
||||
def timestamp(self):
|
||||
def timestamp(self) -> datetime:
|
||||
"""The `datetime` when this report was recorded by a device."""
|
||||
return self._timestamp
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
def description(self) -> str:
|
||||
"""Description of the location report as published by Apple."""
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def latitude(self):
|
||||
def latitude(self) -> float:
|
||||
"""Latitude of the location of this report."""
|
||||
return self._lat
|
||||
|
||||
@property
|
||||
def longitude(self):
|
||||
def longitude(self) -> float:
|
||||
"""Longitude of the location of this report."""
|
||||
return self._lng
|
||||
|
||||
@property
|
||||
def confidence(self):
|
||||
def confidence(self) -> int:
|
||||
"""Confidence of the location of this report."""
|
||||
return self._confidence
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
def status(self) -> int:
|
||||
"""Status byte of the accessory as recorded by a device, as an integer."""
|
||||
return self._status
|
||||
|
||||
@classmethod
|
||||
def from_payload(
|
||||
cls, key: KeyPair, publish_date: datetime, description: str, payload: bytes
|
||||
) -> "KeyReport":
|
||||
cls,
|
||||
key: KeyPair,
|
||||
publish_date: datetime,
|
||||
description: str,
|
||||
payload: bytes,
|
||||
) -> KeyReport:
|
||||
"""Create a `KeyReport` from fields and a payload as reported by Apple.
|
||||
|
||||
Requires a `KeyPair` to decrypt the report's payload.
|
||||
"""
|
||||
timestamp_int = int.from_bytes(payload[0:4], "big") + (60 * 60 * 24 * 11323)
|
||||
timestamp = datetime.utcfromtimestamp(timestamp_int)
|
||||
timestamp = datetime.fromtimestamp(timestamp_int, tz=timezone.utc)
|
||||
|
||||
data = _decrypt_payload(payload, key)
|
||||
latitude = struct.unpack(">i", data[0:4])[0] / 10000000
|
||||
@@ -113,33 +144,40 @@ class KeyReport:
|
||||
status,
|
||||
)
|
||||
|
||||
def __lt__(self, other):
|
||||
def __lt__(self, other: KeyReport) -> bool:
|
||||
"""Compare against another `KeyReport`.
|
||||
|
||||
A `KeyReport` is said to be "less than" another `KeyReport` iff its recorded
|
||||
timestamp is strictly less than the other report.
|
||||
"""
|
||||
if isinstance(other, KeyReport):
|
||||
return self.timestamp < other.timestamp
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Human-readable string representation of the location report."""
|
||||
return (
|
||||
f"<KeyReport(key={self._key.hashed_adv_key_b64}, timestamp={self._timestamp},"
|
||||
f" lat={self._lat}, lng={self._lng})>"
|
||||
)
|
||||
|
||||
|
||||
async def fetch_reports(
|
||||
async def fetch_reports( # noqa: PLR0913
|
||||
dsid: str,
|
||||
search_party_token: str,
|
||||
anisette_headers: dict[str, str],
|
||||
date_from: datetime,
|
||||
date_to: datetime,
|
||||
keys: Sequence[KeyPair],
|
||||
):
|
||||
) -> dict[KeyPair, list[KeyReport]]:
|
||||
"""Look up reports for given `KeyPair`s."""
|
||||
start_date = date_from.timestamp() * 1000
|
||||
end_date = date_to.timestamp() * 1000
|
||||
ids = [key.hashed_adv_key_b64 for key in keys]
|
||||
data = {"search": [{"startDate": start_date, "endDate": end_date, "ids": ids}]}
|
||||
|
||||
# TODO: do not create a new session every time
|
||||
# probably needs a wrapper class to allow closing the connections
|
||||
# TODO(malmeloo): do not create a new session every time
|
||||
# https://github.com/malmeloo/FindMy.py/issues/3
|
||||
async with await _session.post(
|
||||
"https://gateway.icloud.com/acsnservice/fetch",
|
||||
auth=(dsid, search_party_token),
|
||||
@@ -148,7 +186,8 @@ async def fetch_reports(
|
||||
) as r:
|
||||
resp = await r.json()
|
||||
if not r.ok or resp["statusCode"] != "200":
|
||||
raise ReportsError(f"Failed to fetch reports: {resp['statusCode']}")
|
||||
msg = f"Failed to fetch reports: {resp['statusCode']}"
|
||||
raise ReportsError(msg)
|
||||
await _session.close()
|
||||
|
||||
reports: dict[KeyPair, list[KeyReport]] = {key: [] for key in keys}
|
||||
@@ -156,11 +195,14 @@ async def fetch_reports(
|
||||
|
||||
for report in resp.get("results", []):
|
||||
key = id_to_key[report["id"]]
|
||||
date_published = datetime.utcfromtimestamp(report.get("datePublished", 0) / 1000)
|
||||
date_published = datetime.fromtimestamp(
|
||||
report.get("datePublished", 0) / 1000,
|
||||
tz=timezone.utc,
|
||||
)
|
||||
description = report.get("description", "")
|
||||
payload = base64.b64decode(report["payload"])
|
||||
|
||||
report = KeyReport.from_payload(key, date_published, description, payload)
|
||||
reports[key].append(report)
|
||||
r = KeyReport.from_payload(key, date_published, description, payload)
|
||||
reports[key].append(r)
|
||||
|
||||
return {key: sorted(reps) for key, reps in reports.items()}
|
||||
|
||||
@@ -17,6 +17,19 @@ aiohttp = "^3.9.1"
|
||||
pre-commit = "^3.6.0"
|
||||
|
||||
[tool.ruff]
|
||||
exclude = [
|
||||
"examples/"
|
||||
]
|
||||
|
||||
select = [
|
||||
"ALL",
|
||||
]
|
||||
ignore = [
|
||||
"ANN101", # annotations on `self`
|
||||
"ANN102", # annotations on `cls`
|
||||
"FIX002", # resolving TODOs
|
||||
]
|
||||
|
||||
line-length = 100
|
||||
|
||||
[build-system]
|
||||
|
||||
Reference in New Issue
Block a user