diff --git a/examples/fetch_reports.py b/examples/fetch_reports.py index a6ea8ab..6270cf8 100644 --- a/examples/fetch_reports.py +++ b/examples/fetch_reports.py @@ -2,8 +2,13 @@ import json import logging import os -from findmy import AppleAccount, LoginState, SmsSecondFactor, RemoteAnisetteProvider -from findmy import keys +from findmy import ( + AppleAccount, + LoginState, + RemoteAnisetteProvider, + SmsSecondFactor, + keys, +) # URL to (public or local) anisette server ANISETTE_SERVER = "http://localhost:6969" @@ -50,7 +55,7 @@ def fetch_reports(lookup_key): # Save / restore account logic if os.path.isfile("account.json"): - with open("account.json", "r") as f: + with open("account.json") as f: acc.restore(json.load(f)) else: login(acc) diff --git a/examples/fetch_reports_async.py b/examples/fetch_reports_async.py index 8c1ec8e..6cdd69f 100644 --- a/examples/fetch_reports_async.py +++ b/examples/fetch_reports_async.py @@ -6,10 +6,10 @@ import os from findmy import ( AsyncAppleAccount, LoginState, - SmsSecondFactor, RemoteAnisetteProvider, + SmsSecondFactor, + keys, ) -from findmy import keys # URL to (public or local) anisette server ANISETTE_SERVER = "http://localhost:6969" @@ -57,7 +57,7 @@ async def fetch_reports(lookup_key): try: # Save / restore account logic if os.path.isfile("account.json"): - with open("account.json", "r") as f: + with open("account.json") as f: acc.restore(json.load(f)) else: await login(acc) diff --git a/findmy/__init__.py b/findmy/__init__.py index d779ac9..64d6efa 100644 --- a/findmy/__init__.py +++ b/findmy/__init__.py @@ -1,10 +1,11 @@ +"""A package providing everything you need to query Apple's FindMy network.""" from .account import AppleAccount, AsyncAppleAccount, LoginState, SmsSecondFactor from .anisette import RemoteAnisetteProvider __all__ = ( - AppleAccount, - AsyncAppleAccount, - LoginState, - SmsSecondFactor, - RemoteAnisetteProvider, + "AppleAccount", + "AsyncAppleAccount", + "LoginState", + "SmsSecondFactor", + "RemoteAnisetteProvider", ) diff --git a/findmy/account.py b/findmy/account.py index d5a49cb..88a3033 100644 --- a/findmy/account.py +++ b/findmy/account.py @@ -1,3 +1,6 @@ +"""Module containing most of the code necessary to interact with an Apple account.""" +from __future__ import annotations + import asyncio import base64 import hashlib @@ -6,22 +9,23 @@ import json import logging import plistlib import uuid -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from functools import wraps -from typing import Optional, TypedDict, Any -from typing import Sequence +from typing import TYPE_CHECKING, Any, Callable, Sequence, TypedDict, TypeVar import bs4 import srp._pysrp as srp -from cryptography.hazmat.primitives import padding, hashes +from cryptography.hazmat.primitives import hashes, padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC -from .anisette import AnisetteProvider from .base import BaseAppleAccount, BaseSecondFactorMethod, LoginState from .http import HttpSession -from .keys import KeyPair -from .reports import fetch_reports +from .reports import KeyReport, fetch_reports + +if TYPE_CHECKING: + from .anisette import BaseAnisetteProvider + from .keys import KeyPair logging.getLogger(__name__) @@ -29,25 +33,28 @@ srp.rfc5054_enable() srp.no_username_in_x() -class AccountInfo(TypedDict): +class _AccountInfo(TypedDict): account_name: str first_name: str last_name: str -class LoginException(Exception): - pass +class LoginError(Exception): + """Raised when an error occurs during login, such as when the password is incorrect.""" -class InvalidStateException(RuntimeError): - pass +class InvalidStateError(RuntimeError): + """Raised when a method is used that is in conflict with the internal account state. + + For example: calling `BaseAppleAccount.login` while already logged in. + """ class ExportRestoreError(ValueError): - pass + """Raised when an error occurs while exporting or restoring the account's current state.""" -def _load_plist(data: bytes) -> Any: +def _load_plist(data: bytes) -> Any: # noqa: ANN401 plist_header = ( b"" b"" @@ -89,20 +96,26 @@ def _extract_phone_numbers(html: str) -> list[dict]: soup = bs4.BeautifulSoup(html, features="html.parser") data_elem = soup.find("script", **{"class": "boot_args"}) if not data_elem: - raise RuntimeError("Could not find HTML element containing phone numbers") + msg = "Could not find HTML element containing phone numbers" + raise RuntimeError(msg) data = json.loads(data_elem.text) return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", []) -def _require_login_state(*states: LoginState): - def decorator(func): +F = TypeVar("F", bound=Callable[[BaseAppleAccount, ...], Any]) + + +def _require_login_state(*states: LoginState) -> Callable[[F], F]: + def decorator(func: F) -> F: @wraps(func) - def wrapper(acc: "BaseAppleAccount", *args, **kwargs): + def wrapper(acc: BaseAppleAccount, *args, **kwargs): if acc.login_state not in states: - raise InvalidStateException( - f"Invalid login state! Currently: {acc.login_state} but should be one of: {states}" + msg = ( + f"Invalid login state! Currently: {acc.login_state}" + f" but should be one of: {states}" ) + raise InvalidStateError(msg) return func(acc, *args, **kwargs) @@ -112,93 +125,164 @@ def _require_login_state(*states: LoginState): class AsyncSmsSecondFactor(BaseSecondFactorMethod): - def __init__(self, account: "AsyncAppleAccount", number_id: int, phone_number: str): + """An async implementation of a second-factor method.""" + + def __init__( + self, + account: AsyncAppleAccount, + number_id: int, + phone_number: str, + ) -> None: + """Initialize the second factor method. + + Should not be done manually; use `BaseAppleAccount.get_2fa_methods` instead. + """ super().__init__(account) self._phone_number_id: int = number_id self._phone_number: str = phone_number @property - def phone_number_id(self): + def phone_number_id(self) -> int: + """The phone number's ID. You most likely don't need this.""" return self._phone_number_id @property - def phone_number(self): + def phone_number(self) -> str: + """The 2FA method's phone number. + + May be masked using unicode characters; should only be used for identification purposes. + """ return self._phone_number - async def request(self): + async def request(self) -> None: + """Request an SMS to the corresponding phone number containing a 2FA code.""" return await self.account.sms_2fa_request(self._phone_number_id) async def submit(self, code: str) -> LoginState: + """See `BaseSecondFactorMethod.submit`.""" return await self.account.sms_2fa_submit(self._phone_number_id, code) class SmsSecondFactor(BaseSecondFactorMethod): - def __init__(self, account: "AppleAccount", number_id: int, phone_number: str): + """A sync implementation of `BaseSecondFactorMethod`. + + Uses `AsyncSmsSecondFactor` internally. + """ + + def __init__( + self, + account: AppleAccount, + number_id: int, + phone_number: str, + ) -> None: + """See `AsyncSmsSecondFactor.__init__`.""" super().__init__(account) self._phone_number_id: int = number_id self._phone_number: str = phone_number @property - def phone_number(self): + def phone_number_id(self) -> int: + """See `AsyncSmsSecondFactor.phone_number_id`.""" + return self._phone_number_id + + @property + def phone_number(self) -> str: + """See `AsyncSmsSecondFactor.phone_number`.""" return self._phone_number def request(self) -> None: + """See `AsyncSmsSecondFactor.request`.""" return self.account.sms_2fa_request(self._phone_number_id) def submit(self, code: str) -> LoginState: + """See `AsyncSmsSecondFactor.submit`.""" return self.account.sms_2fa_submit(self._phone_number_id, code) class AsyncAppleAccount(BaseAppleAccount): - def __init__(self, anisette: AnisetteProvider, user_id: str = None, device_id: str = None): - self._anisette: AnisetteProvider = anisette + """An async implementation of `BaseAppleAccount`.""" + + def __init__( + self, + anisette: BaseAnisetteProvider, + user_id: str | None = None, + device_id: str | None = None, + ) -> None: + """Initialize the apple account. + + :param anisette: An instance of `AsyncAnisetteProvider`. + :param user_id: An optional user ID to use. Will be auto-generated if missing. + :param device_id: An optional device ID to use. Will be auto-generated if missing. + """ + self._anisette: BaseAnisetteProvider = anisette self._uid: str = user_id or str(uuid.uuid4()) self._devid: str = device_id or str(uuid.uuid4()) - self._username: Optional[str] = None - self._password: Optional[str] = None + self._username: str | None = None + self._password: str | None = None self._login_state: LoginState = LoginState.LOGGED_OUT self._login_state_data: dict = {} - self._account_info: Optional[AccountInfo] = None + self._account_info: _AccountInfo | None = None self._http = HttpSession() - def _set_login_state(self, state: LoginState, data: Optional[dict] = None) -> LoginState: + def _set_login_state( + self, + state: LoginState, + data: dict | None = None, + ) -> LoginState: # clear account info if downgrading state (e.g. LOGGED_IN -> LOGGED_OUT) if state < self._login_state: logging.debug("Clearing cached account information") self._account_info = None - logging.info(f"Transitioning login state: {self._login_state} -> {state}") + logging.info("Transitioning login state: %s -> %s", self._login_state, state) self._login_state = state self._login_state_data = data or {} return state @property - def login_state(self): + def login_state(self) -> LoginState: + """See `BaseAppleAccount.login_state`.""" return self._login_state @property - @_require_login_state(LoginState.LOGGED_IN, LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA) - def account_name(self): + @_require_login_state( + LoginState.LOGGED_IN, + LoginState.AUTHENTICATED, + LoginState.REQUIRE_2FA, + ) + def account_name(self) -> str | None: + """See `BaseAppleAccount.account_name`.""" return self._account_info["account_name"] if self._account_info else None @property - @_require_login_state(LoginState.LOGGED_IN, LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA) - def first_name(self): + @_require_login_state( + LoginState.LOGGED_IN, + LoginState.AUTHENTICATED, + LoginState.REQUIRE_2FA, + ) + def first_name(self) -> str | None: + """See `BaseAppleAccount.first_name`.""" return self._account_info["first_name"] if self._account_info else None @property - @_require_login_state(LoginState.LOGGED_IN, LoginState.AUTHENTICATED, LoginState.REQUIRE_2FA) - def last_name(self): + @_require_login_state( + LoginState.LOGGED_IN, + LoginState.AUTHENTICATED, + LoginState.REQUIRE_2FA, + ) + def last_name(self) -> str | None: + """See `BaseAppleAccount.last_name`.""" return self._account_info["last_name"] if self._account_info else None def export(self) -> dict: + """See `BaseAppleAccount.export`.""" return { "ids": {"uid": self._uid, "devid": self._devid}, "account": { @@ -212,7 +296,8 @@ class AsyncAppleAccount(BaseAppleAccount): }, } - def restore(self, data: dict): + def restore(self, data: dict) -> None: + """See `BaseAppleAccount.restore`.""" try: self._uid = data["ids"]["uid"] self._devid = data["ids"]["devid"] @@ -224,14 +309,20 @@ class AsyncAppleAccount(BaseAppleAccount): self._login_state = LoginState(data["login_state"]["state"]) self._login_state_data = data["login_state"]["data"] except KeyError as e: - raise ExportRestoreError(f"Failed to restore account data: {e}") + msg = f"Failed to restore account data: {e}" + raise ExportRestoreError(msg) from None - async def close(self): + async def close(self) -> None: + """Close any sessions or other resources in use by this object. + + Should be called when the object will no longer be used. + """ await self._anisette.close() await self._http.close() @_require_login_state(LoginState.LOGGED_OUT) async def login(self, username: str, password: str) -> LoginState: + """See `BaseAppleAccount.login`.""" # LOGGED_OUT -> (REQUIRE_2FA or AUTHENTICATED) new_state = await self._gsa_authenticate(username, password) if new_state == LoginState.REQUIRE_2FA: # pass control back to handle 2FA @@ -242,30 +333,40 @@ class AsyncAppleAccount(BaseAppleAccount): @_require_login_state(LoginState.REQUIRE_2FA) async def get_2fa_methods(self) -> list[BaseSecondFactorMethod]: + """See `BaseAppleAccount.get_2fa_methods`.""" methods: list[BaseSecondFactorMethod] = [] # sms auth_page = await self._sms_2fa_request("GET", "https://gsa.apple.com/auth") try: phone_numbers = _extract_phone_numbers(auth_page) + methods.extend( + AsyncSmsSecondFactor( + self, + number.get("id"), + number.get("numberWithDialCode"), + ) + for number in phone_numbers + ) except RuntimeError: logging.warning("Unable to extract phone numbers from login page") - methods.extend( - AsyncSmsSecondFactor(self, number.get("id"), number.get("numberWithDialCode")) - for number in phone_numbers - ) - return methods @_require_login_state(LoginState.REQUIRE_2FA) - async def sms_2fa_request(self, phone_number_id: int): + async def sms_2fa_request(self, phone_number_id: int) -> None: + """See `BaseAppleAccount.sms_2fa_request`.""" data = {"phoneNumber": {"id": phone_number_id}, "mode": "sms"} - await self._sms_2fa_request("PUT", "https://gsa.apple.com/auth/verify/phone", data) + await self._sms_2fa_request( + "PUT", + "https://gsa.apple.com/auth/verify/phone", + data, + ) @_require_login_state(LoginState.REQUIRE_2FA) async def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState: + """See `BaseAppleAccount.sms_2fa_submit`.""" data = { "phoneNumber": {"id": phone_number_id}, "securityCode": {"code": str(code)}, @@ -273,19 +374,28 @@ class AsyncAppleAccount(BaseAppleAccount): } await self._sms_2fa_request( - "POST", "https://gsa.apple.com/auth/verify/phone/securitycode", data + "POST", + "https://gsa.apple.com/auth/verify/phone/securitycode", + data, ) # REQUIRE_2FA -> AUTHENTICATED new_state = await self._gsa_authenticate() if new_state != LoginState.AUTHENTICATED: - raise LoginException(f"Unexpected state after submitting 2FA: {new_state}") + msg = f"Unexpected state after submitting 2FA: {new_state}" + raise LoginError(msg) # AUTHENTICATED -> LOGGED_IN return await self._login_mobileme() @_require_login_state(LoginState.LOGGED_IN) - async def fetch_reports(self, keys: Sequence[KeyPair], date_from: datetime, date_to: datetime): + async def fetch_reports( + self, + keys: Sequence[KeyPair], + date_from: datetime, + date_to: datetime, + ) -> dict[KeyPair, list[KeyReport]]: + """See `BaseAppleAccount.fetch_reports`.""" anisette_headers = await self.get_anisette_headers() return await fetch_reports( @@ -302,57 +412,67 @@ class AsyncAppleAccount(BaseAppleAccount): self, keys: Sequence[KeyPair], hours: int = 7 * 24, - ): - end = datetime.now() + ) -> dict[KeyPair, list[KeyReport]]: + """See `BaseAppleAccount.fetch_last_reports`.""" + end = datetime.now(tz=timezone.utc) start = end - timedelta(hours=hours) return await self.fetch_reports(keys, start, end) @_require_login_state(LoginState.LOGGED_OUT, LoginState.REQUIRE_2FA) async def _gsa_authenticate( - self, username: Optional[str] = None, password: Optional[str] = None + self, + username: str | None = None, + password: str | None = None, ) -> LoginState: + # use stored values for re-authentication self._username = username or self._username self._password = password or self._password - logging.info(f"Attempting authentication for user {self._username}") + logging.info("Attempting authentication for user %s", self._username) if not self._username or not self._password: - raise ValueError("No username or password to log in") + msg = "No username or password specified" + raise ValueError(msg) logging.debug("Starting authentication with username") usr = srp.User(self._username, b"", hash_alg=srp.SHA256, ng_type=srp.NG_2048) _, a2k = usr.start_authentication() r = await self._gsa_request( - {"A2k": a2k, "u": self._username, "ps": ["s2k", "s2k_fo"], "o": "init"} + {"A2k": a2k, "u": self._username, "ps": ["s2k", "s2k_fo"], "o": "init"}, ) logging.debug("Verifying response to auth request") if r["Status"].get("ec") != 0: - message = r["Status"].get("em") - raise LoginException(f"Email verify failed: {message}") + msg = "Email verify failed: " + r["Status"].get("em") + raise LoginError(msg) sp = r.get("sp") if sp != "s2k": - raise LoginException(f"This implementation only supports s2k. Server returned {sp}") + msg = f"This implementation only supports s2k. Server returned {sp}" + raise LoginError(msg) logging.debug("Attempting password challenge") usr.p = _encrypt_password(self._password, r["s"], r["i"]) m1 = usr.process_challenge(r["s"], r["B"]) if m1 is None: - raise LoginException("Failed to process challenge") - r = await self._gsa_request({"c": r["c"], "M1": m1, "u": self._username, "o": "complete"}) + msg = "Failed to process challenge" + raise LoginError(msg) + r = await self._gsa_request( + {"c": r["c"], "M1": m1, "u": self._username, "o": "complete"}, + ) logging.debug("Verifying password challenge response") if r["Status"].get("ec") != 0: - message = r["Status"].get("em") - raise LoginException(f"Password authentication failed: {message}") + msg = "Password authentication failed: " + r["Status"].get("em") + raise LoginError(msg) usr.verify_session(r.get("M2")) if not usr.authenticated(): - raise LoginException("Failed to verify session") + msg = "Failed to verify session" + raise LoginError(msg) logging.debug("Decrypting SPD data in response") @@ -360,33 +480,36 @@ class AsyncAppleAccount(BaseAppleAccount): spd = _load_plist(spd) logging.debug("Received account information") - self._account_info: AccountInfo = { + self._account_info: _AccountInfo = { "account_name": spd.get("acname"), "first_name": spd.get("fn"), "last_name": spd.get("ln"), } - # TODO: support trusted device auth (need account to test) + # TODO(malmeloo): support trusted device auth (need account to test) + # https://github.com/malmeloo/FindMy.py/issues/1 au = r["Status"].get("au") if au in ("secondaryAuth",): - logging.info(f"Detected 2FA requirement: {au}") + logging.info("Detected 2FA requirement: %s", au) return self._set_login_state( LoginState.REQUIRE_2FA, {"adsid": spd["adsid"], "idms_token": spd["GsIdmsToken"]}, ) - elif au is not None: - raise LoginException(f"Unknown auth value: {au}") + if au is not None: + msg = f"Unknown auth value: {au}" + raise LoginError(msg) logging.info("GSA authentication successful") idms_pet = spd.get("t", {}).get("com.apple.gs.idms.pet", {}).get("token", "") return self._set_login_state( - LoginState.AUTHENTICATED, {"idms_pet": idms_pet, "adsid": spd["adsid"]} + LoginState.AUTHENTICATED, + {"idms_pet": idms_pet, "adsid": spd["adsid"]}, ) @_require_login_state(LoginState.AUTHENTICATED) - async def _login_mobileme(self): + async def _login_mobileme(self) -> LoginState: logging.info("Logging into com.apple.mobileme") data = plistlib.dumps( { @@ -394,7 +517,7 @@ class AsyncAppleAccount(BaseAppleAccount): "delegates": {"com.apple.mobileme": {}}, "password": self._login_state_data["idms_pet"], "client-id": self._uid, - } + }, ) headers = { @@ -417,15 +540,21 @@ class AsyncAppleAccount(BaseAppleAccount): mobileme_data = resp.get("delegates", {}).get("com.apple.mobileme", {}) status = mobileme_data.get("status") if status != 0: - message = mobileme_data.get("status-message") - raise LoginException(f"com.apple.mobileme login failed with status {status}: {message}") + status_message = mobileme_data.get("status-message") + msg = f"com.apple.mobileme login failed with status {status}: {status_message}" + raise LoginError(msg) return self._set_login_state( LoginState.LOGGED_IN, {"dsid": resp["dsid"], "mobileme_data": mobileme_data["service-data"]}, ) - async def _sms_2fa_request(self, method: str, url: str, data: Optional[dict] = None) -> str: + async def _sms_2fa_request( + self, + method: str, + url: str, + data: dict | None = None, + ) -> str: adsid = self._login_state_data["adsid"] idms_token = self._login_state_data["idms_token"] identity_token = base64.b64encode((adsid + ":" + idms_token).encode()).decode() @@ -441,13 +570,19 @@ class AsyncAppleAccount(BaseAppleAccount): } headers.update(await self.get_anisette_headers()) - async with await self._http.request(method, url, json=data, headers=headers) as r: + async with await self._http.request( + method, + url, + json=data, + headers=headers, + ) as r: if not r.ok: - raise LoginException(f"HTTP request failed: {r.status_code}") + msg = f"HTTP request failed: {r.status_code}" + raise LoginError(msg) return await r.text() - async def _gsa_request(self, params): + async def _gsa_request(self, params: dict[str, Any]) -> Any: request_data = { "cpd": { "bootstrap": True, @@ -455,7 +590,7 @@ class AsyncAppleAccount(BaseAppleAccount): "pbe": False, "prkgen": True, "svct": "iCloud", - } + }, } request_data["cpd"].update(await self.get_anisette_headers()) request_data.update(params) @@ -482,11 +617,23 @@ class AsyncAppleAccount(BaseAppleAccount): return _load_plist(content)["Response"] async def get_anisette_headers(self, serial: str = "0") -> dict[str, str]: + """See `BaseAppleAccount.get_anisette_headers`.""" return await self._anisette.get_headers(self._uid, self._devid, serial) class AppleAccount(BaseAppleAccount): - def __init__(self, anisette: AnisetteProvider, user_id: str = None, device_id: str = None): + """A sync implementation of `BaseappleAccount`. + + Uses `AsyncappleAccount` internally. + """ + + def __init__( + self, + anisette: BaseAnisetteProvider, + user_id: str | None = None, + device_id: str | None = None, + ) -> None: + """See `AsyncAppleAccount.__init__`.""" self._asyncacc = AsyncAppleAccount(anisette, user_id, device_id) try: @@ -496,36 +643,45 @@ class AppleAccount(BaseAppleAccount): asyncio.set_event_loop(self._loop) def __del__(self) -> None: + """Gracefully close the async instance's session when garbage collected.""" coro = self._asyncacc.close() return self._loop.run_until_complete(coro) @property - def login_state(self): + def login_state(self) -> LoginState: + """See `AsyncAppleAccount.login_state`.""" return self._asyncacc.login_state @property - def account_name(self): + def account_name(self) -> str: + """See `AsyncAppleAccount.login_state`.""" return self._asyncacc.account_name @property - def first_name(self): + def first_name(self) -> str | None: + """See `AsyncAppleAccount.first_name`.""" return self._asyncacc.first_name @property - def last_name(self): + def last_name(self) -> str | None: + """See `AsyncAppleAccount.last_name`.""" return self._asyncacc.last_name def export(self) -> dict: + """See `AsyncAppleAccount.export`.""" return self._asyncacc.export() - def restore(self, data: dict): + def restore(self, data: dict) -> None: + """See `AsyncAppleAccount.restore`.""" return self._asyncacc.restore(data) def login(self, username: str, password: str) -> LoginState: + """See `AsyncAppleAccount.login`.""" coro = self._asyncacc.login(username, password) return self._loop.run_until_complete(coro) def get_2fa_methods(self) -> list[BaseSecondFactorMethod]: + """See `AsyncAppleAccount.get_2fa_methods`.""" coro = self._asyncacc.get_2fa_methods() methods = self._loop.run_until_complete(coro) @@ -534,28 +690,44 @@ class AppleAccount(BaseAppleAccount): if isinstance(m, AsyncSmsSecondFactor): res.append(SmsSecondFactor(self, m.phone_number_id, m.phone_number)) else: - raise RuntimeError( - f"Failed to cast 2FA object to sync alternative: {m}. This is a bug, please report it." + msg = ( + f"Failed to cast 2FA object to sync alternative: {m}." + f" This is a bug, please report it." ) + raise TypeError(msg) return res - def sms_2fa_request(self, phone_number_id: int): + def sms_2fa_request(self, phone_number_id: int) -> None: + """See `AsyncAppleAccount.sms_2fa_request`.""" coro = self._asyncacc.sms_2fa_request(phone_number_id) return self._loop.run_until_complete(coro) def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState: + """See `AsyncAppleAccount.sms_2fa_submit`.""" coro = self._asyncacc.sms_2fa_submit(phone_number_id, code) return self._loop.run_until_complete(coro) - def fetch_reports(self, keys: Sequence[KeyPair], date_from: datetime, date_to: datetime): + def fetch_reports( + self, + keys: Sequence[KeyPair], + date_from: datetime, + date_to: datetime, + ) -> dict[KeyPair, list[KeyReport]]: + """See `AsyncAppleAccount.fetch_reports`.""" coro = self._asyncacc.fetch_reports(keys, date_from, date_to) return self._loop.run_until_complete(coro) - def fetch_last_reports(self, keys: Sequence[KeyPair], hours: int = 7 * 24): + def fetch_last_reports( + self, + keys: Sequence[KeyPair], + hours: int = 7 * 24, + ) -> dict[KeyPair, list[KeyReport]]: + """See `AsyncAppleAccount.fetch_last_reports`.""" coro = self._asyncacc.fetch_last_reports(keys, hours) return self._loop.run_until_complete(coro) def get_anisette_headers(self, serial: str = "0") -> dict[str, str]: - coro = self._asyncacc.get_anisette_headers() + """See `AsyncAppleAccount.get_anisette_headers`.""" + coro = self._asyncacc.get_anisette_headers(serial) return self._loop.run_until_complete(coro) diff --git a/findmy/anisette.py b/findmy/anisette.py index bac2975..a05f84a 100644 --- a/findmy/anisette.py +++ b/findmy/anisette.py @@ -1,14 +1,21 @@ +"""Module for Anisette header providers.""" +from __future__ import annotations + import base64 import locale import logging from abc import ABC, abstractmethod -from datetime import datetime +from datetime import datetime, timezone from .http import HttpSession -def _gen_meta_headers(user_id: str, device_id: str, serial: str = "0"): - now = datetime.utcnow() +def _gen_meta_headers( + user_id: str, + device_id: str, + serial: str = "0", +) -> dict[str, str]: + now = datetime.now(tz=timezone.utc) locale_str = locale.getdefaultlocale()[0] or "en_US" return { @@ -23,29 +30,44 @@ def _gen_meta_headers(user_id: str, device_id: str, serial: str = "0"): } -class AnisetteProvider(ABC): +class BaseAnisetteProvider(ABC): + """Abstract base class for Anisette providers.""" + @abstractmethod async def _get_base_headers(self) -> dict[str, str]: - return NotImplemented + raise NotImplementedError @abstractmethod - async def close(self): - return NotImplemented + async def close(self) -> None: + """Close any underlying sessions. Call when the provider will no longer be used.""" + raise NotImplementedError - async def get_headers(self, user_id: str, device_id: str, serial: str = "0") -> dict[str, str]: + async def get_headers( + self, + user_id: str, + device_id: str, + serial: str = "0", + ) -> dict[str, str]: + """Retrieve a complete dictionary of Anisette headers. + + Consider using `BaseAppleAccount.get_anisette_headers` instead. + """ base_headers = await self._get_base_headers() base_headers.update(_gen_meta_headers(user_id, device_id, serial)) return base_headers -class RemoteAnisetteProvider(AnisetteProvider): - def __init__(self, server_url: str): +class RemoteAnisetteProvider(BaseAnisetteProvider): + """Anisette provider. Fetches headers from a remote Anisette server.""" + + def __init__(self, server_url: str) -> None: + """Initialize the provider with URL to te remote server.""" self._server_url = server_url self._http = HttpSession() - logging.info(f"Using remote anisette server: {self._server_url}") + logging.info("Using remote anisette server: %s", self._server_url) async def _get_base_headers(self) -> dict[str, str]: async with await self._http.get(self._server_url) as r: @@ -56,17 +78,21 @@ class RemoteAnisetteProvider(AnisetteProvider): "X-Apple-I-MD-M": headers["X-Apple-I-MD-M"], } - async def close(self): + async def close(self) -> None: + """See `AnisetteProvider.close`.""" await self._http.close() -# TODO: implement using pyprovision -class LocalAnisetteProvider(AnisetteProvider): - def __init__(self): - pass +# TODO(malmeloo): implement using pyprovision +# https://github.com/malmeloo/FindMy.py/issues/2 +class LocalAnisetteProvider(BaseAnisetteProvider): + """Anisette provider. Generates headers without a remote server using pyprovision.""" + + def __init__(self) -> None: + """Initialize the provider.""" async def _get_base_headers(self) -> dict[str, str]: return NotImplemented - async def close(self): - pass + async def close(self) -> None: + """See `AnisetteProvider.close`.""" diff --git a/findmy/base.py b/findmy/base.py index 5a762d0..50224d4 100644 --- a/findmy/base.py +++ b/findmy/base.py @@ -1,101 +1,185 @@ -from abc import ABC, abstractmethod -from datetime import datetime -from enum import Enum -from typing import Sequence +"""Module that contains base classes for various other modules. For internal use only.""" +from __future__ import annotations -from .keys import KeyPair +from abc import ABC, abstractmethod +from enum import Enum +from typing import TYPE_CHECKING, Sequence, TypeVar + +if TYPE_CHECKING: + from datetime import datetime + + from .keys import KeyPair + from .reports import KeyReport class LoginState(Enum): + """Enum of possible login states. Used for `AppleAccount`'s internal state machine.""" + LOGGED_OUT = 0 REQUIRE_2FA = 1 AUTHENTICATED = 2 LOGGED_IN = 3 - def __lt__(self, other): + def __lt__(self, other: LoginState) -> bool: + """Compare against another `LoginState`. + + A `LoginState` is said to be "less than" another `LoginState` iff it is in + an "earlier" stage of the login process, going from LOGGED_OUT to LOGGED_IN. + """ if isinstance(other, LoginState): return self.value < other.value return NotImplemented - def __repr__(self): + def __repr__(self) -> str: + """Human-readable string representation of the state.""" return self.__str__() +T = TypeVar("T", bound="BaseAppleAccount") + + class BaseSecondFactorMethod(ABC): - def __init__(self, account: "BaseAppleAccount"): - self._account = account + """Base class for a second-factor authentication method for an Apple account.""" + + def __init__(self, account: T) -> None: + """Initialize the second-factor method.""" + self._account: T = account @property - def account(self): + def account(self) -> T: + """The account associated with the second-factor method.""" return self._account @abstractmethod def request(self) -> None: - raise NotImplementedError() + """Put in a request for the second-factor challenge. + + Exact meaning is up to the implementing class. + """ + raise NotImplementedError @abstractmethod def submit(self, code: str) -> LoginState: - raise NotImplementedError() + """Submit a code to complete the second-factor challenge.""" + raise NotImplementedError class BaseAppleAccount(ABC): - @property - @abstractmethod - def login_state(self): - return NotImplemented + """Base class for an Apple account.""" @property @abstractmethod - def account_name(self): - return NotImplemented + def login_state(self) -> LoginState: + """The current login state of the account.""" + raise NotImplementedError @property @abstractmethod - def first_name(self): - return NotImplemented + def account_name(self) -> str: + """The name of the account as reported by Apple. + + This is usually an e-mail address. + May be None in some cases, such as when not logged in. + """ + raise NotImplementedError @property @abstractmethod - def last_name(self): - return NotImplemented + def first_name(self) -> str | None: + """First name of the account holder as reported by Apple. + + May be None in some cases, such as when not logged in. + """ + raise NotImplementedError + + @property + @abstractmethod + def last_name(self) -> str | None: + """Last name of the account holder as reported by Apple. + + May be None in some cases, such as when not logged in. + """ + raise NotImplementedError @abstractmethod def export(self) -> dict: - return NotImplemented + """Export a representation of the current state of the account as a dictionary. + + The output of this method is guaranteed to be JSON-serializable, and passing + the return value of this function as an argument to `BaseAppleAccount.restore` + will always result in an exact copy of the internal state as it was when exported. + + This method is especially useful to avoid having to keep going through the login flow. + """ + raise NotImplementedError @abstractmethod - def restore(self, data: dict): - return NotImplemented + def restore(self, data: dict) -> None: + """Restore a previous export of the internal state of the account. + + See `BaseAppleAccount.export` for more information. + """ + raise NotImplementedError @abstractmethod def login(self, username: str, password: str) -> LoginState: - return NotImplemented + """Log in to an Apple account using a username and password.""" + raise NotImplementedError @abstractmethod def get_2fa_methods(self) -> list[BaseSecondFactorMethod]: - return NotImplemented + """Get a list of 2FA methods that can be used as a secondary challenge. + + Currently, only SMS-based 2FA methods are supported. + """ + raise NotImplementedError @abstractmethod - def sms_2fa_request(self, phone_number_id: int): - return NotImplemented + def sms_2fa_request(self, phone_number_id: int) -> None: + """Request a 2FA code to be sent to a specific phone number ID. + + Consider using `BaseSecondFactorMethod.request` instead. + """ + raise NotImplementedError @abstractmethod def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState: - return NotImplemented + """Submit a 2FA code that was sent to a specific phone number ID. + + Consider using `BaseSecondFactorMethod.submit` instead. + """ + raise NotImplementedError @abstractmethod - def fetch_reports(self, keys: Sequence[KeyPair], date_from: datetime, date_to: datetime): - return NotImplemented + def fetch_reports( + self, + keys: Sequence[KeyPair], + date_from: datetime, + date_to: datetime, + ) -> dict[KeyPair, list[KeyReport]]: + """Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`. + + Returns a dictionary mapping `KeyPair`s to a list of their location reports. + """ + raise NotImplementedError @abstractmethod def fetch_last_reports( self, keys: Sequence[KeyPair], hours: int = 7 * 24, - ): - return NotImplemented + ) -> dict[KeyPair, list[KeyReport]]: + """Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours. + + Utility method as an alternative to using `BaseAppleAccount.fetch_reports` directly. + """ + raise NotImplementedError @abstractmethod def get_anisette_headers(self, serial: str = "0") -> dict[str, str]: - return NotImplemented + """Retrieve a complete dictionary of Anisette headers. + + Utility method for `AnisetteProvider.get_headers` using this account's user and device ID. + """ + raise NotImplementedError diff --git a/findmy/http.py b/findmy/http.py index bdaeb1f..72a7f5d 100644 --- a/findmy/http.py +++ b/findmy/http.py @@ -1,45 +1,78 @@ -import logging -from typing import Optional -import asyncio +"""Module to simplify asynchronous HTTP calls. For internal use only.""" +from __future__ import annotations -from aiohttp import ClientSession, BasicAuth, ClientTimeout +import asyncio +import logging +from typing import Any + +from aiohttp import BasicAuth, ClientResponse, ClientSession, ClientTimeout logging.getLogger(__name__) -class HttpSession: - def __init__(self): - self._session: Optional[ClientSession] = None +class HttpResponse: + """Response of a request made by `HttpSession`.""" - async def _ensure_session(self): + def __init__(self, resp: ClientResponse) -> None: + """Initialize the response.""" + self._resp: ClientResponse = resp + + +class HttpSession: + """Asynchronous HTTP session manager. For internal use only.""" + + def __init__(self) -> None: # noqa: D107 + self._session: ClientSession | None = None + + async def _ensure_session(self) -> None: if self._session is None: logging.debug("Creating aiohttp session") self._session = ClientSession(timeout=ClientTimeout(total=5)) - async def close(self): + async def close(self) -> None: + """Close the underlying session. Should be called when session will no longer be used.""" if self._session is not None: logging.debug("Closing aiohttp session") await self._session.close() self._session = None def __del__(self) -> None: + """Attempt to gracefully close the session. + + Ideally this should be done by manually calling close(). + """ + if self._session is None: + return + try: loop = asyncio.get_running_loop() loop.call_soon_threadsafe(loop.create_task, self.close()) except RuntimeError: # cannot await closure pass - async def request(self, method: str, url: str, auth: tuple[str] = None, **kwargs): + async def request( + self, + method: str, + url: str, + auth: tuple[str] | None = None, + **kwargs: Any, + ) -> ClientResponse: + """Make an HTTP request. + + Keyword arguments will directly be passed to `aiohttp.ClientSession.request`. + """ await self._ensure_session() basic_auth = None if auth is not None: basic_auth = BasicAuth(auth[0], auth[1]) - return self._session.request(method, url, auth=basic_auth, ssl=False, **kwargs) + return await self._session.request(method, url, auth=basic_auth, ssl=False, **kwargs) - async def get(self, url: str, **kwargs): + async def get(self, url: str, **kwargs: Any) -> ClientResponse: + """Alias for `HttpSession.request("GET", ...)`.""" return await self.request("GET", url, **kwargs) - async def post(self, url: str, **kwargs): + async def post(self, url: str, **kwargs: Any) -> ClientResponse: + """Alias for `HttpSession.request("POST", ...)`.""" return await self.request("POST", url, **kwargs) diff --git a/findmy/keys.py b/findmy/keys.py index 1831170..d489122 100644 --- a/findmy/keys.py +++ b/findmy/keys.py @@ -1,49 +1,73 @@ +"""Module to work with private and public keys as used in FindMy accessories.""" + import base64 +import hashlib import secrets from cryptography.hazmat.backends import default_backend -import hashlib from cryptography.hazmat.primitives.asymmetric import ec class KeyPair: - def __init__(self, private_key: bytes): + """A private-public keypair for a trackable FindMy accessory.""" + + def __init__(self, private_key: bytes) -> None: + """Initialize the `KeyPair` with the private key bytes.""" priv_int = int.from_bytes(private_key, "big") - self._priv_key = ec.derive_private_key(priv_int, ec.SECP224R1(), default_backend()) + self._priv_key = ec.derive_private_key( + priv_int, + ec.SECP224R1(), + default_backend(), + ) @classmethod def generate(cls) -> "KeyPair": + """Generate a new random `KeyPair`.""" return cls(secrets.token_bytes(28)) @classmethod def from_b64(cls, key_b64: str) -> "KeyPair": + """Import an existing `KeyPair` from its base64-encoded representation. + + Same format as returned by `KeyPair.private_key_b64`. + """ return cls(base64.b64decode(key_b64)) @property def private_key_bytes(self) -> bytes: + """Return the private key as bytes.""" key_bytes = self._priv_key.private_numbers().private_value return int.to_bytes(key_bytes, 28, "big") @property def private_key_b64(self) -> str: + """Return the private key as a base64-encoded string. + + Can be re-imported using `KeyPair.from_b64`. + """ return base64.b64encode(self.private_key_bytes).decode("ascii") @property def adv_key_bytes(self) -> bytes: + """Return the advertised (public) key as bytes.""" key_bytes = self._priv_key.public_key().public_numbers().x return int.to_bytes(key_bytes, 28, "big") @property def adv_key_b64(self) -> str: + """Return the advertised (public) key as a base64-encoded string.""" return base64.b64encode(self.adv_key_bytes).decode("ascii") @property def hashed_adv_key_bytes(self) -> bytes: + """Return the hashed advertised (public) key as bytes.""" return hashlib.sha256(self.adv_key_bytes).digest() @property def hashed_adv_key_b64(self) -> str: + """Return the hashed advertised (public) key as a base64-encoded string.""" return base64.b64encode(self.hashed_adv_key_bytes).decode("ascii") def dh_exchange(self, other_pub_key: ec.EllipticCurvePublicKey) -> bytes: + """Do a Diffie-Hellman key exchange using another EC public key.""" return self._priv_key.exchange(ec.ECDH(), other_pub_key) diff --git a/findmy/reports.py b/findmy/reports.py index a778d08..f7765ca 100644 --- a/findmy/reports.py +++ b/findmy/reports.py @@ -1,27 +1,37 @@ +"""Module providing functionality to look up location reports.""" +from __future__ import annotations + import base64 import hashlib import struct -from datetime import datetime -from typing import Sequence +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Sequence from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from .keys import KeyPair from .http import HttpSession +if TYPE_CHECKING: + from .keys import KeyPair + _session = HttpSession() class ReportsError(RuntimeError): - pass + """Raised when an error occurs while looking up reports.""" def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes: - eph_key = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP224R1(), payload[5:62]) + eph_key = ec.EllipticCurvePublicKey.from_encoded_point( + ec.SECP224R1(), + payload[5:62], + ) shared_key = key.dh_exchange(eph_key) - symmetric_key = hashlib.sha256(shared_key + b"\x00\x00\x00\x01" + payload[5:62]).digest() + symmetric_key = hashlib.sha256( + shared_key + b"\x00\x00\x00\x01" + payload[5:62], + ).digest() decryption_key = symmetric_key[:16] iv = symmetric_key[16:] @@ -29,13 +39,17 @@ def _decrypt_payload(payload: bytes, key: KeyPair) -> bytes: tag = payload[72:] decryptor = Cipher( - algorithms.AES(decryption_key), modes.GCM(iv, tag), default_backend() + algorithms.AES(decryption_key), + modes.GCM(iv, tag), + default_backend(), ).decryptor() return decryptor.update(enc_data) + decryptor.finalize() class KeyReport: - def __init__( + """Location report corresponding to a certain `KeyPair`.""" + + def __init__( # noqa: PLR0913 self, key: KeyPair, publish_date: datetime, @@ -45,7 +59,8 @@ class KeyReport: lng: float, confidence: int, status: int, - ): + ) -> None: + """Initialize a `KeyReport`. You should probably use `KeyReport.from_payload` instead.""" self._key = key self._publish_date = publish_date self._timestamp = timestamp @@ -58,43 +73,59 @@ class KeyReport: self._status = status @property - def key(self): + def key(self) -> KeyPair: + """The `KeyPair` corresponding to this location report.""" return self._key @property - def published_at(self): + def published_at(self) -> datetime: + """The `datetime` when this report was published by a device.""" return self._publish_date @property - def timestamp(self): + def timestamp(self) -> datetime: + """The `datetime` when this report was recorded by a device.""" return self._timestamp @property - def description(self): + def description(self) -> str: + """Description of the location report as published by Apple.""" return self._description @property - def latitude(self): + def latitude(self) -> float: + """Latitude of the location of this report.""" return self._lat @property - def longitude(self): + def longitude(self) -> float: + """Longitude of the location of this report.""" return self._lng @property - def confidence(self): + def confidence(self) -> int: + """Confidence of the location of this report.""" return self._confidence @property - def status(self): + def status(self) -> int: + """Status byte of the accessory as recorded by a device, as an integer.""" return self._status @classmethod def from_payload( - cls, key: KeyPair, publish_date: datetime, description: str, payload: bytes - ) -> "KeyReport": + cls, + key: KeyPair, + publish_date: datetime, + description: str, + payload: bytes, + ) -> KeyReport: + """Create a `KeyReport` from fields and a payload as reported by Apple. + + Requires a `KeyPair` to decrypt the report's payload. + """ timestamp_int = int.from_bytes(payload[0:4], "big") + (60 * 60 * 24 * 11323) - timestamp = datetime.utcfromtimestamp(timestamp_int) + timestamp = datetime.fromtimestamp(timestamp_int, tz=timezone.utc) data = _decrypt_payload(payload, key) latitude = struct.unpack(">i", data[0:4])[0] / 10000000 @@ -113,33 +144,40 @@ class KeyReport: status, ) - def __lt__(self, other): + def __lt__(self, other: KeyReport) -> bool: + """Compare against another `KeyReport`. + + A `KeyReport` is said to be "less than" another `KeyReport` iff its recorded + timestamp is strictly less than the other report. + """ if isinstance(other, KeyReport): return self.timestamp < other.timestamp return NotImplemented - def __repr__(self): + def __repr__(self) -> str: + """Human-readable string representation of the location report.""" return ( f"" ) -async def fetch_reports( +async def fetch_reports( # noqa: PLR0913 dsid: str, search_party_token: str, anisette_headers: dict[str, str], date_from: datetime, date_to: datetime, keys: Sequence[KeyPair], -): +) -> dict[KeyPair, list[KeyReport]]: + """Look up reports for given `KeyPair`s.""" start_date = date_from.timestamp() * 1000 end_date = date_to.timestamp() * 1000 ids = [key.hashed_adv_key_b64 for key in keys] data = {"search": [{"startDate": start_date, "endDate": end_date, "ids": ids}]} - # TODO: do not create a new session every time - # probably needs a wrapper class to allow closing the connections + # TODO(malmeloo): do not create a new session every time + # https://github.com/malmeloo/FindMy.py/issues/3 async with await _session.post( "https://gateway.icloud.com/acsnservice/fetch", auth=(dsid, search_party_token), @@ -148,7 +186,8 @@ async def fetch_reports( ) as r: resp = await r.json() if not r.ok or resp["statusCode"] != "200": - raise ReportsError(f"Failed to fetch reports: {resp['statusCode']}") + msg = f"Failed to fetch reports: {resp['statusCode']}" + raise ReportsError(msg) await _session.close() reports: dict[KeyPair, list[KeyReport]] = {key: [] for key in keys} @@ -156,11 +195,14 @@ async def fetch_reports( for report in resp.get("results", []): key = id_to_key[report["id"]] - date_published = datetime.utcfromtimestamp(report.get("datePublished", 0) / 1000) + date_published = datetime.fromtimestamp( + report.get("datePublished", 0) / 1000, + tz=timezone.utc, + ) description = report.get("description", "") payload = base64.b64decode(report["payload"]) - report = KeyReport.from_payload(key, date_published, description, payload) - reports[key].append(report) + r = KeyReport.from_payload(key, date_published, description, payload) + reports[key].append(r) return {key: sorted(reps) for key, reps in reports.items()} diff --git a/pyproject.toml b/pyproject.toml index cbbdb70..918551a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,19 @@ aiohttp = "^3.9.1" pre-commit = "^3.6.0" [tool.ruff] +exclude = [ + "examples/" +] + +select = [ + "ALL", +] +ignore = [ + "ANN101", # annotations on `self` + "ANN102", # annotations on `cls` + "FIX002", # resolving TODOs +] + line-length = 100 [build-system]