Enforce strict typing

This commit is contained in:
Mike A
2024-01-01 21:47:06 +01:00
parent d989af9243
commit 1b18e0dfef
10 changed files with 604 additions and 204 deletions

View File

@@ -2,8 +2,13 @@ import json
import logging
import os
from findmy import AppleAccount, LoginState, SmsSecondFactor, RemoteAnisetteProvider
from findmy import keys
from findmy import (
AppleAccount,
LoginState,
RemoteAnisetteProvider,
SmsSecondFactor,
keys,
)
# URL to (public or local) anisette server
ANISETTE_SERVER = "http://localhost:6969"
@@ -50,7 +55,7 @@ def fetch_reports(lookup_key):
# Save / restore account logic
if os.path.isfile("account.json"):
with open("account.json", "r") as f:
with open("account.json") as f:
acc.restore(json.load(f))
else:
login(acc)

View File

@@ -6,10 +6,10 @@ import os
from findmy import (
AsyncAppleAccount,
LoginState,
SmsSecondFactor,
RemoteAnisetteProvider,
SmsSecondFactor,
keys,
)
from findmy import keys
# URL to (public or local) anisette server
ANISETTE_SERVER = "http://localhost:6969"
@@ -57,7 +57,7 @@ async def fetch_reports(lookup_key):
try:
# Save / restore account logic
if os.path.isfile("account.json"):
with open("account.json", "r") as f:
with open("account.json") as f:
acc.restore(json.load(f))
else:
await login(acc)

View File

@@ -1,10 +1,11 @@
"""A package providing everything you need to query Apple's FindMy network."""
from .account import AppleAccount, AsyncAppleAccount, LoginState, SmsSecondFactor
from .anisette import RemoteAnisetteProvider
__all__ = (
AppleAccount,
AsyncAppleAccount,
LoginState,
SmsSecondFactor,
RemoteAnisetteProvider,
"AppleAccount",
"AsyncAppleAccount",
"LoginState",
"SmsSecondFactor",
"RemoteAnisetteProvider",
)

View File

@@ -1,3 +1,6 @@
"""Module containing most of the code necessary to interact with an Apple account."""
from __future__ import annotations
import asyncio
import base64
import hashlib
@@ -6,22 +9,23 @@ import json
import logging
import plistlib
import uuid
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import Optional, TypedDict, Any
from typing import Sequence
from typing import TYPE_CHECKING, Any, Callable, Sequence, TypedDict, TypeVar
import bs4
import srp._pysrp as srp
from cryptography.hazmat.primitives import padding, hashes
from cryptography.hazmat.primitives import hashes, padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from .anisette import AnisetteProvider
from .base import BaseAppleAccount, BaseSecondFactorMethod, LoginState
from .http import HttpSession
from .keys import KeyPair
from .reports import fetch_reports
from .reports import KeyReport, fetch_reports
if TYPE_CHECKING:
from .anisette import BaseAnisetteProvider
from .keys import KeyPair
logging.getLogger(__name__)
@@ -29,25 +33,28 @@ srp.rfc5054_enable()
srp.no_username_in_x()
class AccountInfo(TypedDict):
class _AccountInfo(TypedDict):
account_name: str
first_name: str
last_name: str
class LoginException(Exception):
pass
class LoginError(Exception):
"""Raised when an error occurs during login, such as when the password is incorrect."""
class InvalidStateException(RuntimeError):
pass
class InvalidStateError(RuntimeError):
"""Raised when a method is used that is in conflict with the internal account state.
For example: calling `BaseAppleAccount.login` while already logged in.
"""
class ExportRestoreError(ValueError):
pass
"""Raised when an error occurs while exporting or restoring the account's current state."""
def _load_plist(data: bytes) -> Any:
def _load_plist(data: bytes) -> Any: # noqa: ANN401
plist_header = (
b"<?xml version='1.0' encoding='UTF-8'?>"
b"<!DOCTYPE plist PUBLIC '-//Apple//DTD PLIST 1.0//EN' 'http://www.apple.com/DTDs/PropertyList-1.0.dtd'>"
@@ -89,20 +96,26 @@ def _extract_phone_numbers(html: str) -> list[dict]:
soup = bs4.BeautifulSoup(html, features="html.parser")
data_elem = soup.find("script", **{"class": "boot_args"})
if not data_elem:
raise RuntimeError("Could not find HTML element containing phone numbers")
msg = "Could not find HTML element containing phone numbers"
raise RuntimeError(msg)
data = json.loads(data_elem.text)
return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", [])
def _require_login_state(*states: LoginState):
def decorator(func):
F = TypeVar("F", bound=Callable[[BaseAppleAccount, ...], Any])
def _require_login_state(*states: LoginState) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
def wrapper(acc: "BaseAppleAccount", *args, **kwargs):
def wrapper(acc: BaseAppleAccount, *args, **kwargs):
if acc.login_state not in states:
raise InvalidStateException(
f"Invalid login state! Currently: {acc.login_state} but should be one of: {states}"
msg = (
f"Invalid login state! Currently: {acc.login_state}"
f" but should be one of: {states}"
)
raise InvalidStateError(msg)
return func(acc, *args, **kwargs)
@@ -112,93 +125,164 @@ def _require_login_state(*states: LoginState):
class AsyncSmsSecondFactor(BaseSecondFactorMethod):
def __init__(self, account: "AsyncAppleAccount", number_id: int, phone_number: str):
"""An async implementation of a second-factor method."""
def __init__(
self,
account: AsyncAppleAccount,
number_id: int,
phone_number: str,
) -> None:
"""Initialize the second factor method.
Should not be done manually; use `BaseAppleAccount.get_2fa_methods` instead.
"""
super().__init__(account)
self._phone_number_id: int = number_id
self._phone_number: str = phone_number
@property
def phone_number_id(self):
def phone_number_id(self) -> int:
"""The phone number's ID. You most likely don't need this."""
return self._phone_number_id
@property
def phone_number(self):
def phone_number(self) -> str:
"""The 2FA method's phone number.
May be masked using unicode characters; should only be used for identification purposes.
"""
return self._phone_number
async def request(self):
async def request(self) -> None:
"""Request an SMS to the corresponding phone number containing a 2FA code."""
return await self.account.sms_2fa_request(self._phone_number_id)
async def submit(self, code: str) -> LoginState:
"""See `BaseSecondFactorMethod.submit`."""
return await self.account.sms_2fa_submit(self._phone_number_id, code)
class SmsSecondFactor(BaseSecondFactorMethod):
def __init__(self, account: "AppleAccount", number_id: int, phone_number: str):
"""A sync implementation of `BaseSecondFactorMethod`.
Uses `AsyncSmsSecondFactor` internally.
"""
def __init__(
self,
account: AppleAccount,
number_id: int,
phone_number: str,
) -> None:
"""See `AsyncSmsSecondFactor.__init__`."""
super().__init__(account)
self._phone_number_id: int = number_id
self._phone_number: str = phone_number
@property
def phone_number(self):
def phone_number_id(self) -> int:
"""See `AsyncSmsSecondFactor.phone_number_id`."""
return self._phone_number_id
@property
def phone_number(self) -> str:
"""See `AsyncSmsSecondFactor.phone_number`."""
return self._phone_number
def request(self) -> None:
"""See `AsyncSmsSecondFactor.request`."""
return self.account.sms_2fa_request(self._phone_number_id)
def submit(self, code: str) -> LoginState:
"""See `AsyncSmsSecondFactor.submit`."""
return self.account.sms_2fa_submit(self._phone_number_id, code)
class AsyncAppleAccount(BaseAppleAccount):
def __init__(self, anisette: AnisetteProvider, user_id: str = None, device_id: str = None):
self._anisette: AnisetteProvider = anisette
"""An async implementation of `BaseAppleAccount`."""
def __init__(
self,
anisette: BaseAnisetteProvider,
user_id: str | None = None,
device_id: str | None = None,
) -> None:
"""Initialize the apple account.
:param anisette: An instance of `AsyncAnisetteProvider`.
:param user_id: An optional user ID to use. Will be auto-generated if missing.
:param device_id: An optional device ID to use. Will be auto-generated if missing.
"""
self._anisette: BaseAnisetteProvider = anisette
self._uid: str = user_id or str(uuid.uuid4())
self._devid: str = device_id or str(uuid.uuid4())
self._username: Optional[str] = None
self._password: Optional[str] = None
self._username: str | None = None
self._password: str | None = None
self._login_state: LoginState = LoginState.LOGGED_OUT
self._login_state_data: dict = {}
self._account_info: Optional[AccountInfo] = None
self._account_info: _AccountInfo | None = None
self._http = HttpSession()
def _set_login_state(self, state: LoginState, data: Optional[dict] = None) -> LoginState:
def _set_login_state(
self,
state: LoginState,
data: dict | None = None,
) -> LoginState:
# clear account info if downgrading state (e.g. LOGGED_IN -> LOGGED_OUT)
if state < self._login_state:
logging.debug("Clearing cached account information")
self._account_info = None
logging.info(f"Transitioning login state: {self._login_state} -> {state}")
logging.info("Transitioning login state: %s -> %s", self._login_state, state)
self._login_state = state
self._login_state_data = data or {}
return state
@property
def login_state(self):
def login_state(self) -> LoginState:
"""See `BaseAppleAccount.login_state`."""
return self._login_state
@property
@_require_login_state(LoginState.LOGGED_IN, LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA)
def account_name(self):
@_require_login_state(
LoginState.LOGGED_IN,
LoginState.AUTHENTICATED,
LoginState.REQUIRE_2FA,
)
def account_name(self) -> str | None:
"""See `BaseAppleAccount.account_name`."""
return self._account_info["account_name"] if self._account_info else None
@property
@_require_login_state(LoginState.LOGGED_IN, LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA)
def first_name(self):
@_require_login_state(
LoginState.LOGGED_IN,
LoginState.AUTHENTICATED,
LoginState.REQUIRE_2FA,
)
def first_name(self) -> str | None:
"""See `BaseAppleAccount.first_name`."""
return self._account_info["first_name"] if self._account_info else None
@property
@_require_login_state(LoginState.LOGGED_IN, LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA)
def last_name(self):
@_require_login_state(
LoginState.LOGGED_IN,
LoginState.AUTHENTICATED,
LoginState.REQUIRE_2FA,
)
def last_name(self) -> str | None:
"""See `BaseAppleAccount.last_name`."""
return self._account_info["last_name"] if self._account_info else None
def export(self) -> dict:
"""See `BaseAppleAccount.export`."""
return {
"ids": {"uid": self._uid, "devid": self._devid},
"account": {
@@ -212,7 +296,8 @@ class AsyncAppleAccount(BaseAppleAccount):
},
}
def restore(self, data: dict):
def restore(self, data: dict) -> None:
"""See `BaseAppleAccount.restore`."""
try:
self._uid = data["ids"]["uid"]
self._devid = data["ids"]["devid"]
@@ -224,14 +309,20 @@ class AsyncAppleAccount(BaseAppleAccount):
self._login_state = LoginState(data["login_state"]["state"])
self._login_state_data = data["login_state"]["data"]
except KeyError as e:
raise ExportRestoreError(f"Failed to restore account data: {e}")
msg = f"Failed to restore account data: {e}"
raise ExportRestoreError(msg) from None
async def close(self):
async def close(self) -> None:
"""Close any sessions or other resources in use by this object.
Should be called when the object will no longer be used.
"""
await self._anisette.close()
await self._http.close()
@_require_login_state(LoginState.LOGGED_OUT)
async def login(self, username: str, password: str) -> LoginState:
"""See `BaseAppleAccount.login`."""
# LOGGED_OUT -> (REQUIRE_2FA or AUTHENTICATED)
new_state = await self._gsa_authenticate(username, password)
if new_state == LoginState.REQUIRE_2FA: # pass control back to handle 2FA
@@ -242,30 +333,40 @@ class AsyncAppleAccount(BaseAppleAccount):
@_require_login_state(LoginState.REQUIRE_2FA)
async def get_2fa_methods(self) -> list[BaseSecondFactorMethod]:
"""See `BaseAppleAccount.get_2fa_methods`."""
methods: list[BaseSecondFactorMethod] = []
# sms
auth_page = await self._sms_2fa_request("GET", "https://gsa.apple.com/auth")
try:
phone_numbers = _extract_phone_numbers(auth_page)
methods.extend(
AsyncSmsSecondFactor(
self,
number.get("id"),
number.get("numberWithDialCode"),
)
for number in phone_numbers
)
except RuntimeError:
logging.warning("Unable to extract phone numbers from login page")
methods.extend(
AsyncSmsSecondFactor(self, number.get("id"), number.get("numberWithDialCode"))
for number in phone_numbers
)
return methods
@_require_login_state(LoginState.REQUIRE_2FA)
async def sms_2fa_request(self, phone_number_id: int):
async def sms_2fa_request(self, phone_number_id: int) -> None:
"""See `BaseAppleAccount.sms_2fa_request`."""
data = {"phoneNumber": {"id": phone_number_id}, "mode": "sms"}
await self._sms_2fa_request("PUT", "https://gsa.apple.com/auth/verify/phone", data)
await self._sms_2fa_request(
"PUT",
"https://gsa.apple.com/auth/verify/phone",
data,
)
@_require_login_state(LoginState.REQUIRE_2FA)
async def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState:
"""See `BaseAppleAccount.sms_2fa_submit`."""
data = {
"phoneNumber": {"id": phone_number_id},
"securityCode": {"code": str(code)},
@@ -273,19 +374,28 @@ class AsyncAppleAccount(BaseAppleAccount):
}
await self._sms_2fa_request(
"POST", "https://gsa.apple.com/auth/verify/phone/securitycode", data
"POST",
"https://gsa.apple.com/auth/verify/phone/securitycode",
data,
)
# REQUIRE_2FA -> AUTHENTICATED
new_state = await self._gsa_authenticate()
if new_state != LoginState.AUTHENTICATED:
raise LoginException(f"Unexpected state after submitting 2FA: {new_state}")
msg = f"Unexpected state after submitting 2FA: {new_state}"
raise LoginError(msg)
# AUTHENTICATED -> LOGGED_IN
return await self._login_mobileme()
@_require_login_state(LoginState.LOGGED_IN)
async def fetch_reports(self, keys: Sequence[KeyPair], date_from: datetime, date_to: datetime):
async def fetch_reports(
self,
keys: Sequence[KeyPair],
date_from: datetime,
date_to: datetime,
) -> dict[KeyPair, list[KeyReport]]:
"""See `BaseAppleAccount.fetch_reports`."""
anisette_headers = await self.get_anisette_headers()
return await fetch_reports(
@@ -302,57 +412,67 @@ class AsyncAppleAccount(BaseAppleAccount):
self,
keys: Sequence[KeyPair],
hours: int = 7 * 24,
):
end = datetime.now()
) -> dict[KeyPair, list[KeyReport]]:
"""See `BaseAppleAccount.fetch_last_reports`."""
end = datetime.now(tz=timezone.utc)
start = end - timedelta(hours=hours)
return await self.fetch_reports(keys, start, end)
@_require_login_state(LoginState.LOGGED_OUT, LoginState.REQUIRE_2FA)
async def _gsa_authenticate(
self, username: Optional[str] = None, password: Optional[str] = None
self,
username: str | None = None,
password: str | None = None,
) -> LoginState:
# use stored values for re-authentication
self._username = username or self._username
self._password = password or self._password
logging.info(f"Attempting authentication for user {self._username}")
logging.info("Attempting authentication for user %s", self._username)
if not self._username or not self._password:
raise ValueError("No username or password to log in")
msg = "No username or password specified"
raise ValueError(msg)
logging.debug("Starting authentication with username")
usr = srp.User(self._username, b"", hash_alg=srp.SHA256, ng_type=srp.NG_2048)
_, a2k = usr.start_authentication()
r = await self._gsa_request(
{"A2k": a2k, "u": self._username, "ps": ["s2k", "s2k_fo"], "o": "init"}
{"A2k": a2k, "u": self._username, "ps": ["s2k", "s2k_fo"], "o": "init"},
)
logging.debug("Verifying response to auth request")
if r["Status"].get("ec") != 0:
message = r["Status"].get("em")
raise LoginException(f"Email verify failed: {message}")
msg = "Email verify failed: " + r["Status"].get("em")
raise LoginError(msg)
sp = r.get("sp")
if sp != "s2k":
raise LoginException(f"This implementation only supports s2k. Server returned {sp}")
msg = f"This implementation only supports s2k. Server returned {sp}"
raise LoginError(msg)
logging.debug("Attempting password challenge")
usr.p = _encrypt_password(self._password, r["s"], r["i"])
m1 = usr.process_challenge(r["s"], r["B"])
if m1 is None:
raise LoginException("Failed to process challenge")
r = await self._gsa_request({"c": r["c"], "M1": m1, "u": self._username, "o": "complete"})
msg = "Failed to process challenge"
raise LoginError(msg)
r = await self._gsa_request(
{"c": r["c"], "M1": m1, "u": self._username, "o": "complete"},
)
logging.debug("Verifying password challenge response")
if r["Status"].get("ec") != 0:
message = r["Status"].get("em")
raise LoginException(f"Password authentication failed: {message}")
msg = "Password authentication failed: " + r["Status"].get("em")
raise LoginError(msg)
usr.verify_session(r.get("M2"))
if not usr.authenticated():
raise LoginException("Failed to verify session")
msg = "Failed to verify session"
raise LoginError(msg)
logging.debug("Decrypting SPD data in response")
@@ -360,33 +480,36 @@ class AsyncAppleAccount(BaseAppleAccount):
spd = _load_plist(spd)
logging.debug("Received account information")
self._account_info: AccountInfo = {
self._account_info: _AccountInfo = {
"account_name": spd.get("acname"),
"first_name": spd.get("fn"),
"last_name": spd.get("ln"),
}
# TODO: support trusted device auth (need account to test)
# TODO(malmeloo): support trusted device auth (need account to test)
# https://github.com/malmeloo/FindMy.py/issues/1
au = r["Status"].get("au")
if au in ("secondaryAuth",):
logging.info(f"Detected 2FA requirement: {au}")
logging.info("Detected 2FA requirement: %s", au)
return self._set_login_state(
LoginState.REQUIRE_2FA,
{"adsid": spd["adsid"], "idms_token": spd["GsIdmsToken"]},
)
elif au is not None:
raise LoginException(f"Unknown auth value: {au}")
if au is not None:
msg = f"Unknown auth value: {au}"
raise LoginError(msg)
logging.info("GSA authentication successful")
idms_pet = spd.get("t", {}).get("com.apple.gs.idms.pet", {}).get("token", "")
return self._set_login_state(
LoginState.AUTHENTICATED, {"idms_pet": idms_pet, "adsid": spd["adsid"]}
LoginState.AUTHENTICATED,
{"idms_pet": idms_pet, "adsid": spd["adsid"]},
)
@_require_login_state(LoginState.AUTHENTICATED)
async def _login_mobileme(self):
async def _login_mobileme(self) -> LoginState:
logging.info("Logging into com.apple.mobileme")
data = plistlib.dumps(
{
@@ -394,7 +517,7 @@ class AsyncAppleAccount(BaseAppleAccount):
"delegates": {"com.apple.mobileme": {}},
"password": self._login_state_data["idms_pet"],
"client-id": self._uid,
}
},
)
headers = {
@@ -417,15 +540,21 @@ class AsyncAppleAccount(BaseAppleAccount):
mobileme_data = resp.get("delegates", {}).get("com.apple.mobileme", {})
status = mobileme_data.get("status")
if status != 0:
message = mobileme_data.get("status-message")
raise LoginException(f"com.apple.mobileme login failed with status {status}: {message}")
status_message = mobileme_data.get("status-message")
msg = f"com.apple.mobileme login failed with status {status}: {status_message}"
raise LoginError(msg)
return self._set_login_state(
LoginState.LOGGED_IN,
{"dsid": resp["dsid"], "mobileme_data": mobileme_data["service-data"]},
)
async def _sms_2fa_request(self, method: str, url: str, data: Optional[dict] = None) -> str:
async def _sms_2fa_request(
self,
method: str,
url: str,
data: dict | None = None,
) -> str:
adsid = self._login_state_data["adsid"]
idms_token = self._login_state_data["idms_token"]
identity_token = base64.b64encode((adsid + ":" + idms_token).encode()).decode()
@@ -441,13 +570,19 @@ class AsyncAppleAccount(BaseAppleAccount):
}
headers.update(await self.get_anisette_headers())
async with await self._http.request(method, url, json=data, headers=headers) as r:
async with await self._http.request(
method,
url,
json=data,
headers=headers,
) as r:
if not r.ok:
raise LoginException(f"HTTP request failed: {r.status_code}")
msg = f"HTTP request failed: {r.status_code}"
raise LoginError(msg)
return await r.text()
async def _gsa_request(self, params):
async def _gsa_request(self, params: dict[str, Any]) -> Any:
request_data = {
"cpd": {
"bootstrap": True,
@@ -455,7 +590,7 @@ class AsyncAppleAccount(BaseAppleAccount):
"pbe": False,
"prkgen": True,
"svct": "iCloud",
}
},
}
request_data["cpd"].update(await self.get_anisette_headers())
request_data.update(params)
@@ -482,11 +617,23 @@ class AsyncAppleAccount(BaseAppleAccount):
return _load_plist(content)["Response"]
async def get_anisette_headers(self, serial: str = "0") -> dict[str, str]:
"""See `BaseAppleAccount.get_anisette_headers`."""
return await self._anisette.get_headers(self._uid, self._devid, serial)
class AppleAccount(BaseAppleAccount):
def __init__(self, anisette: AnisetteProvider, user_id: str = None, device_id: str = None):
"""A sync implementation of `BaseappleAccount`.
Uses `AsyncappleAccount` internally.
"""
def __init__(
self,
anisette: BaseAnisetteProvider,
user_id: str | None = None,
device_id: str | None = None,
) -> None:
"""See `AsyncAppleAccount.__init__`."""
self._asyncacc = AsyncAppleAccount(anisette, user_id, device_id)
try:
@@ -496,36 +643,45 @@ class AppleAccount(BaseAppleAccount):
asyncio.set_event_loop(self._loop)
def __del__(self) -> None:
"""Gracefully close the async instance's session when garbage collected."""
coro = self._asyncacc.close()
return self._loop.run_until_complete(coro)
@property
def login_state(self):
def login_state(self) -> LoginState:
"""See `AsyncAppleAccount.login_state`."""
return self._asyncacc.login_state
@property
def account_name(self):
def account_name(self) -> str:
"""See `AsyncAppleAccount.login_state`."""
return self._asyncacc.account_name
@property
def first_name(self):
def first_name(self) -> str | None:
"""See `AsyncAppleAccount.first_name`."""
return self._asyncacc.first_name
@property
def last_name(self):
def last_name(self) -> str | None:
"""See `AsyncAppleAccount.last_name`."""
return self._asyncacc.last_name
def export(self) -> dict:
"""See `AsyncAppleAccount.export`."""
return self._asyncacc.export()
def restore(self, data: dict):
def restore(self, data: dict) -> None:
"""See `AsyncAppleAccount.restore`."""
return self._asyncacc.restore(data)
def login(self, username: str, password: str) -> LoginState:
"""See `AsyncAppleAccount.login`."""
coro = self._asyncacc.login(username, password)
return self._loop.run_until_complete(coro)
def get_2fa_methods(self) -> list[BaseSecondFactorMethod]:
"""See `AsyncAppleAccount.get_2fa_methods`."""
coro = self._asyncacc.get_2fa_methods()
methods = self._loop.run_until_complete(coro)
@@ -534,28 +690,44 @@ class AppleAccount(BaseAppleAccount):
if isinstance(m, AsyncSmsSecondFactor):
res.append(SmsSecondFactor(self, m.phone_number_id, m.phone_number))
else:
raise RuntimeError(
f"Failed to cast 2FA object to sync alternative: {m}. This is a bug, please report it."
msg = (
f"Failed to cast 2FA object to sync alternative: {m}."
f" This is a bug, please report it."
)
raise TypeError(msg)
return res
def sms_2fa_request(self, phone_number_id: int):
def sms_2fa_request(self, phone_number_id: int) -> None:
"""See `AsyncAppleAccount.sms_2fa_request`."""
coro = self._asyncacc.sms_2fa_request(phone_number_id)
return self._loop.run_until_complete(coro)
def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState:
"""See `AsyncAppleAccount.sms_2fa_submit`."""
coro = self._asyncacc.sms_2fa_submit(phone_number_id, code)
return self._loop.run_until_complete(coro)
def fetch_reports(self, keys: Sequence[KeyPair], date_from: datetime, date_to: datetime):
def fetch_reports(
self,
keys: Sequence[KeyPair],
date_from: datetime,
date_to: datetime,
) -> dict[KeyPair, list[KeyReport]]:
"""See `AsyncAppleAccount.fetch_reports`."""
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
return self._loop.run_until_complete(coro)
def fetch_last_reports(self, keys: Sequence[KeyPair], hours: int = 7 * 24):
def fetch_last_reports(
self,
keys: Sequence[KeyPair],
hours: int = 7 * 24,
) -> dict[KeyPair, list[KeyReport]]:
"""See `AsyncAppleAccount.fetch_last_reports`."""
coro = self._asyncacc.fetch_last_reports(keys, hours)
return self._loop.run_until_complete(coro)
def get_anisette_headers(self, serial: str = "0") -> dict[str, str]:
coro = self._asyncacc.get_anisette_headers()
"""See `AsyncAppleAccount.get_anisette_headers`."""
coro = self._asyncacc.get_anisette_headers(serial)
return self._loop.run_until_complete(coro)

View File

@@ -1,14 +1,21 @@
"""Module for Anisette header providers."""
from __future__ import annotations
import base64
import locale
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from datetime import datetime, timezone
from .http import HttpSession
def _gen_meta_headers(user_id: str, device_id: str, serial: str = "0"):
now = datetime.utcnow()
def _gen_meta_headers(
user_id: str,
device_id: str,
serial: str = "0",
) -> dict[str, str]:
now = datetime.now(tz=timezone.utc)
locale_str = locale.getdefaultlocale()[0] or "en_US"
return {
@@ -23,29 +30,44 @@ def _gen_meta_headers(user_id: str, device_id: str, serial: str = "0"):
}
class AnisetteProvider(ABC):
class BaseAnisetteProvider(ABC):
"""Abstract base class for Anisette providers."""
@abstractmethod
async def _get_base_headers(self) -> dict[str, str]:
return NotImplemented
raise NotImplementedError
@abstractmethod
async def close(self):
return NotImplemented
async def close(self) -> None:
"""Close any underlying sessions. Call when the provider will no longer be used."""
raise NotImplementedError
async def get_headers(self, user_id: str, device_id: str, serial: str = "0") -> dict[str, str]:
async def get_headers(
self,
user_id: str,
device_id: str,
serial: str = "0",
) -> dict[str, str]:
"""Retrieve a complete dictionary of Anisette headers.
Consider using `BaseAppleAccount.get_anisette_headers` instead.
"""
base_headers = await self._get_base_headers()
base_headers.update(_gen_meta_headers(user_id, device_id, serial))
return base_headers
class RemoteAnisetteProvider(AnisetteProvider):
def __init__(self, server_url: str):
class RemoteAnisetteProvider(BaseAnisetteProvider):
"""Anisette provider. Fetches headers from a remote Anisette server."""
def __init__(self, server_url: str) -> None:
"""Initialize the provider with URL to te remote server."""
self._server_url = server_url
self._http = HttpSession()
logging.info(f"Using remote anisette server: {self._server_url}")
logging.info("Using remote anisette server: %s", self._server_url)
async def _get_base_headers(self) -> dict[str, str]:
async with await self._http.get(self._server_url) as r:
@@ -56,17 +78,21 @@ class RemoteAnisetteProvider(AnisetteProvider):
"X-Apple-I-MD-M": headers["X-Apple-I-MD-M"],
}
async def close(self):
async def close(self) -> None:
"""See `AnisetteProvider.close`."""
await self._http.close()
# TODO: implement using pyprovision
class LocalAnisetteProvider(AnisetteProvider):
def __init__(self):
pass
# TODO(malmeloo): implement using pyprovision
# https://github.com/malmeloo/FindMy.py/issues/2
class LocalAnisetteProvider(BaseAnisetteProvider):
"""Anisette provider. Generates headers without a remote server using pyprovision."""
def __init__(self) -> None:
"""Initialize the provider."""
async def _get_base_headers(self) -> dict[str, str]:
return NotImplemented
async def close(self):
pass
async def close(self) -> None:
"""See `AnisetteProvider.close`."""

View File

@@ -1,101 +1,185 @@
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Sequence
"""Module that contains base classes for various other modules. For internal use only."""
from __future__ import annotations
from .keys import KeyPair
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Sequence, TypeVar
if TYPE_CHECKING:
from datetime import datetime
from .keys import KeyPair
from .reports import KeyReport
class LoginState(Enum):
"""Enum of possible login states. Used for `AppleAccount`'s internal state machine."""
LOGGED_OUT = 0
REQUIRE_2FA = 1
AUTHENTICATED = 2
LOGGED_IN = 3
def __lt__(self, other):
def __lt__(self, other: LoginState) -> bool:
"""Compare against another `LoginState`.
A `LoginState` is said to be "less than" another `LoginState` iff it is in
an "earlier" stage of the login process, going from LOGGED_OUT to LOGGED_IN.
"""
if isinstance(other, LoginState):
return self.value < other.value
return NotImplemented
def __repr__(self):
def __repr__(self) -> str:
"""Human-readable string representation of the state."""
return self.__str__()
T = TypeVar("T", bound="BaseAppleAccount")
class BaseSecondFactorMethod(ABC):
def __init__(self, account: "BaseAppleAccount"):
self._account = account
"""Base class for a second-factor authentication method for an Apple account."""
def __init__(self, account: T) -> None:
"""Initialize the second-factor method."""
self._account: T = account
@property
def account(self):
def account(self) -> T:
"""The account associated with the second-factor method."""
return self._account
@abstractmethod
def request(self) -> None:
raise NotImplementedError()
"""Put in a request for the second-factor challenge.
Exact meaning is up to the implementing class.
"""
raise NotImplementedError
@abstractmethod
def submit(self, code: str) -> LoginState:
raise NotImplementedError()
"""Submit a code to complete the second-factor challenge."""
raise NotImplementedError
class BaseAppleAccount(ABC):
@property
@abstractmethod
def login_state(self):
return NotImplemented
"""Base class for an Apple account."""
@property
@abstractmethod
def account_name(self):
return NotImplemented
def login_state(self) -> LoginState:
"""The current login state of the account."""
raise NotImplementedError
@property
@abstractmethod
def first_name(self):
return NotImplemented
def account_name(self) -> str:
"""The name of the account as reported by Apple.
This is usually an e-mail address.
May be None in some cases, such as when not logged in.
"""
raise NotImplementedError
@property
@abstractmethod
def last_name(self):
return NotImplemented
def first_name(self) -> str | None:
"""First name of the account holder as reported by Apple.
May be None in some cases, such as when not logged in.
"""
raise NotImplementedError
@property
@abstractmethod
def last_name(self) -> str | None:
"""Last name of the account holder as reported by Apple.
May be None in some cases, such as when not logged in.
"""
raise NotImplementedError
@abstractmethod
def export(self) -> dict:
return NotImplemented
"""Export a representation of the current state of the account as a dictionary.
The output of this method is guaranteed to be JSON-serializable, and passing
the return value of this function as an argument to `BaseAppleAccount.restore`
will always result in an exact copy of the internal state as it was when exported.
This method is especially useful to avoid having to keep going through the login flow.
"""
raise NotImplementedError
@abstractmethod
def restore(self, data: dict):
return NotImplemented
def restore(self, data: dict) -> None:
"""Restore a previous export of the internal state of the account.
See `BaseAppleAccount.export` for more information.
"""
raise NotImplementedError
@abstractmethod
def login(self, username: str, password: str) -> LoginState:
return NotImplemented
"""Log in to an Apple account using a username and password."""
raise NotImplementedError
@abstractmethod
def get_2fa_methods(self) -> list[BaseSecondFactorMethod]:
return NotImplemented
"""Get a list of 2FA methods that can be used as a secondary challenge.
Currently, only SMS-based 2FA methods are supported.
"""
raise NotImplementedError
@abstractmethod
def sms_2fa_request(self, phone_number_id: int):
return NotImplemented
def sms_2fa_request(self, phone_number_id: int) -> None:
"""Request a 2FA code to be sent to a specific phone number ID.
Consider using `BaseSecondFactorMethod.request` instead.
"""
raise NotImplementedError
@abstractmethod
def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState:
return NotImplemented
"""Submit a 2FA code that was sent to a specific phone number ID.
Consider using `BaseSecondFactorMethod.submit` instead.
"""
raise NotImplementedError
@abstractmethod
def fetch_reports(self, keys: Sequence[KeyPair], date_from: datetime, date_to: datetime):
return NotImplemented
def fetch_reports(
self,
keys: Sequence[KeyPair],
date_from: datetime,
date_to: datetime,
) -> dict[KeyPair, list[KeyReport]]:
"""Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`.
Returns a dictionary mapping `KeyPair`s to a list of their location reports.
"""
raise NotImplementedError
@abstractmethod
def fetch_last_reports(
self,
keys: Sequence[KeyPair],
hours: int = 7 * 24,
):
return NotImplemented
) -> dict[KeyPair, list[KeyReport]]:
"""Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours.
Utility method as an alternative to using `BaseAppleAccount.fetch_reports` directly.
"""
raise NotImplementedError
@abstractmethod
def get_anisette_headers(self, serial: str = "0") -> dict[str, str]:
return NotImplemented
"""Retrieve a complete dictionary of Anisette headers.
Utility method for `AnisetteProvider.get_headers` using this account's user and device ID.
"""
raise NotImplementedError

View File

@@ -1,45 +1,78 @@
import logging
from typing import Optional
import asyncio
"""Module to simplify asynchronous HTTP calls. For internal use only."""
from __future__ import annotations
from aiohttp import ClientSession, BasicAuth, ClientTimeout
import asyncio
import logging
from typing import Any
from aiohttp import BasicAuth, ClientResponse, ClientSession, ClientTimeout
logging.getLogger(__name__)
class HttpSession:
def __init__(self):
self._session: Optional[ClientSession] = None
class HttpResponse:
"""Response of a request made by `HttpSession`."""
async def _ensure_session(self):
def __init__(self, resp: ClientResponse) -> None:
"""Initialize the response."""
self._resp: ClientResponse = resp
class HttpSession:
"""Asynchronous HTTP session manager. For internal use only."""
def __init__(self) -> None: # noqa: D107
self._session: ClientSession | None = None
async def _ensure_session(self) -> None:
if self._session is None:
logging.debug("Creating aiohttp session")
self._session = ClientSession(timeout=ClientTimeout(total=5))
async def close(self):
async def close(self) -> None:
"""Close the underlying session. Should be called when session will no longer be used."""
if self._session is not None:
logging.debug("Closing aiohttp session")
await self._session.close()
self._session = None
def __del__(self) -> None:
"""Attempt to gracefully close the session.
Ideally this should be done by manually calling close().
"""
if self._session is None:
return
try:
loop = asyncio.get_running_loop()
loop.call_soon_threadsafe(loop.create_task, self.close())
except RuntimeError: # cannot await closure
pass
async def request(self, method: str, url: str, auth: tuple[str] = None, **kwargs):
async def request(
self,
method: str,
url: str,
auth: tuple[str] | None = None,
**kwargs: Any,
) -> ClientResponse:
"""Make an HTTP request.
Keyword arguments will directly be passed to `aiohttp.ClientSession.request`.
"""
await self._ensure_session()
basic_auth = None
if auth is not None:
basic_auth = BasicAuth(auth[0], auth[1])
return self._session.request(method, url, auth=basic_auth, ssl=False, **kwargs)
return await self._session.request(method, url, auth=basic_auth, ssl=False, **kwargs)
async def get(self, url: str, **kwargs):
async def get(self, url: str, **kwargs: Any) -> ClientResponse:
"""Alias for `HttpSession.request("GET", ...)`."""
return await self.request("GET", url, **kwargs)
async def post(self, url: str, **kwargs):
async def post(self, url: str, **kwargs: Any) -> ClientResponse:
"""Alias for `HttpSession.request("POST", ...)`."""
return await self.request("POST", url, **kwargs)

View File

@@ -1,49 +1,73 @@
"""Module to work with private and public keys as used in FindMy accessories."""
import base64
import hashlib
import secrets
from cryptography.hazmat.backends import default_backend
import hashlib
from cryptography.hazmat.primitives.asymmetric import ec
class KeyPair:
def __init__(self, private_key: bytes):
"""A private-public keypair for a trackable FindMy accessory."""
def __init__(self, private_key: bytes) -> None:
"""Initialize the `KeyPair` with the private key bytes."""
priv_int = int.from_bytes(private_key, "big")
self._priv_key = ec.derive_private_key(priv_int, ec.SECP224R1(), default_backend())
self._priv_key = ec.derive_private_key(
priv_int,
ec.SECP224R1(),
default_backend(),
)
@classmethod
def generate(cls) -> "KeyPair":
"""Generate a new random `KeyPair`."""
return cls(secrets.token_bytes(28))
@classmethod
def from_b64(cls, key_b64: str) -> "KeyPair":
"""Import an existing `KeyPair` from its base64-encoded representation.
Same format as returned by `KeyPair.private_key_b64`.
"""
return cls(base64.b64decode(key_b64))
@property
def private_key_bytes(self) -> bytes:
"""Return the private key as bytes."""
key_bytes = self._priv_key.private_numbers().private_value
return int.to_bytes(key_bytes, 28, "big")
@property
def private_key_b64(self) -> str:
"""Return the private key as a base64-encoded string.
Can be re-imported using `KeyPair.from_b64`.
"""
return base64.b64encode(self.private_key_bytes).decode("ascii")
@property
def adv_key_bytes(self) -> bytes:
"""Return the advertised (public) key as bytes."""
key_bytes = self._priv_key.public_key().public_numbers().x
return int.to_bytes(key_bytes, 28, "big")
@property
def adv_key_b64(self) -> str:
"""Return the advertised (public) key as a base64-encoded string."""
return base64.b64encode(self.adv_key_bytes).decode("ascii")
@property
def hashed_adv_key_bytes(self) -> bytes:
"""Return the hashed advertised (public) key as bytes."""
return hashlib.sha256(self.adv_key_bytes).digest()
@property
def hashed_adv_key_b64(self) -> str:
"""Return the hashed advertised (public) key as a base64-encoded string."""
return base64.b64encode(self.hashed_adv_key_bytes).decode("ascii")
def dh_exchange(self, other_pub_key: ec.EllipticCurvePublicKey) -> bytes:
"""Do a Diffie-Hellman key exchange using another EC public key."""
return self._priv_key.exchange(ec.ECDH(), other_pub_key)

View File

@@ -1,27 +1,37 @@
"""Module providing functionality to look up location reports."""
from __future__ import annotations
import base64
import hashlib
import struct
from datetime import datetime
from typing import Sequence
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Sequence
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .keys import KeyPair
from .http import HttpSession
if TYPE_CHECKING:
from .keys import KeyPair
_session = HttpSession()
class ReportsError(RuntimeError):
pass
"""Raised when an error occurs while looking up reports."""
def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
eph_key = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP224R1(), payload[5:62])
eph_key = ec.EllipticCurvePublicKey.from_encoded_point(
ec.SECP224R1(),
payload[5:62],
)
shared_key = key.dh_exchange(eph_key)
symmetric_key = hashlib.sha256(shared_key + b"\x00\x00\x00\x01" + payload[5:62]).digest()
symmetric_key = hashlib.sha256(
shared_key + b"\x00\x00\x00\x01" + payload[5:62],
).digest()
decryption_key = symmetric_key[:16]
iv = symmetric_key[16:]
@@ -29,13 +39,17 @@ def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes:
tag = payload[72:]
decryptor = Cipher(
algorithms.AES(decryption_key), modes.GCM(iv, tag), default_backend()
algorithms.AES(decryption_key),
modes.GCM(iv, tag),
default_backend(),
).decryptor()
return decryptor.update(enc_data) + decryptor.finalize()
class KeyReport:
def __init__(
"""Location report corresponding to a certain `KeyPair`."""
def __init__( # noqa: PLR0913
self,
key: KeyPair,
publish_date: datetime,
@@ -45,7 +59,8 @@ class KeyReport:
lng: float,
confidence: int,
status: int,
):
) -> None:
"""Initialize a `KeyReport`. You should probably use `KeyReport.from_payload` instead."""
self._key = key
self._publish_date = publish_date
self._timestamp = timestamp
@@ -58,43 +73,59 @@ class KeyReport:
self._status = status
@property
def key(self):
def key(self) -> KeyPair:
"""The `KeyPair` corresponding to this location report."""
return self._key
@property
def published_at(self):
def published_at(self) -> datetime:
"""The `datetime` when this report was published by a device."""
return self._publish_date
@property
def timestamp(self):
def timestamp(self) -> datetime:
"""The `datetime` when this report was recorded by a device."""
return self._timestamp
@property
def description(self):
def description(self) -> str:
"""Description of the location report as published by Apple."""
return self._description
@property
def latitude(self):
def latitude(self) -> float:
"""Latitude of the location of this report."""
return self._lat
@property
def longitude(self):
def longitude(self) -> float:
"""Longitude of the location of this report."""
return self._lng
@property
def confidence(self):
def confidence(self) -> int:
"""Confidence of the location of this report."""
return self._confidence
@property
def status(self):
def status(self) -> int:
"""Status byte of the accessory as recorded by a device, as an integer."""
return self._status
@classmethod
def from_payload(
cls, key: KeyPair, publish_date: datetime, description: str, payload: bytes
) -> "KeyReport":
cls,
key: KeyPair,
publish_date: datetime,
description: str,
payload: bytes,
) -> KeyReport:
"""Create a `KeyReport` from fields and a payload as reported by Apple.
Requires a `KeyPair` to decrypt the report's payload.
"""
timestamp_int = int.from_bytes(payload[0:4], "big") + (60 * 60 * 24 * 11323)
timestamp = datetime.utcfromtimestamp(timestamp_int)
timestamp = datetime.fromtimestamp(timestamp_int, tz=timezone.utc)
data = _decrypt_payload(payload, key)
latitude = struct.unpack(">i", data[0:4])[0] / 10000000
@@ -113,33 +144,40 @@ class KeyReport:
status,
)
def __lt__(self, other):
def __lt__(self, other: KeyReport) -> bool:
"""Compare against another `KeyReport`.
A `KeyReport` is said to be "less than" another `KeyReport` iff its recorded
timestamp is strictly less than the other report.
"""
if isinstance(other, KeyReport):
return self.timestamp < other.timestamp
return NotImplemented
def __repr__(self):
def __repr__(self) -> str:
"""Human-readable string representation of the location report."""
return (
f"<KeyReport(key={self._key.hashed_adv_key_b64}, timestamp={self._timestamp},"
f" lat={self._lat}, lng={self._lng})>"
)
async def fetch_reports(
async def fetch_reports( # noqa: PLR0913
dsid: str,
search_party_token: str,
anisette_headers: dict[str, str],
date_from: datetime,
date_to: datetime,
keys: Sequence[KeyPair],
):
) -> dict[KeyPair, list[KeyReport]]:
"""Look up reports for given `KeyPair`s."""
start_date = date_from.timestamp() * 1000
end_date = date_to.timestamp() * 1000
ids = [key.hashed_adv_key_b64 for key in keys]
data = {"search": [{"startDate": start_date, "endDate": end_date, "ids": ids}]}
# TODO: do not create a new session every time
# probably needs a wrapper class to allow closing the connections
# TODO(malmeloo): do not create a new session every time
# https://github.com/malmeloo/FindMy.py/issues/3
async with await _session.post(
"https://gateway.icloud.com/acsnservice/fetch",
auth=(dsid, search_party_token),
@@ -148,7 +186,8 @@ async def fetch_reports(
) as r:
resp = await r.json()
if not r.ok or resp["statusCode"] != "200":
raise ReportsError(f"Failed to fetch reports: {resp['statusCode']}")
msg = f"Failed to fetch reports: {resp['statusCode']}"
raise ReportsError(msg)
await _session.close()
reports: dict[KeyPair, list[KeyReport]] = {key: [] for key in keys}
@@ -156,11 +195,14 @@ async def fetch_reports(
for report in resp.get("results", []):
key = id_to_key[report["id"]]
date_published = datetime.utcfromtimestamp(report.get("datePublished", 0) / 1000)
date_published = datetime.fromtimestamp(
report.get("datePublished", 0) / 1000,
tz=timezone.utc,
)
description = report.get("description", "")
payload = base64.b64decode(report["payload"])
report = KeyReport.from_payload(key, date_published, description, payload)
reports[key].append(report)
r = KeyReport.from_payload(key, date_published, description, payload)
reports[key].append(r)
return {key: sorted(reps) for key, reps in reports.items()}

View File

@@ -17,6 +17,19 @@ aiohttp = "^3.9.1"
pre-commit = "^3.6.0"
[tool.ruff]
exclude = [
"examples/"
]
select = [
"ALL",
]
ignore = [
"ANN101", # annotations on `self`
"ANN102", # annotations on `cls`
"FIX002", # resolving TODOs
]
line-length = 100
[build-system]