diff --git a/findmy/accessory.py b/findmy/accessory.py index 2864243..4da1145 100644 --- a/findmy/accessory.py +++ b/findmy/accessory.py @@ -7,7 +7,7 @@ from __future__ import annotations import logging from datetime import datetime, timedelta -from typing import Generator +from typing import Generator, overload from .keys import KeyGenerator, KeyPair, KeyType from .util import crypto @@ -150,6 +150,14 @@ class AccessoryKeyGenerator(KeyGenerator[KeyPair]): return self._get_keypair(self._iter_ind) + @overload + def __getitem__(self, val: int) -> KeyPair: + ... + + @overload + def __getitem__(self, val: slice) -> Generator[KeyPair, None, None]: + ... + def __getitem__(self, val: int | slice) -> KeyPair | Generator[KeyPair, None, None]: if isinstance(val, int): if val < 0: diff --git a/findmy/keys.py b/findmy/keys.py index 42d5ce5..72739af 100644 --- a/findmy/keys.py +++ b/findmy/keys.py @@ -52,7 +52,7 @@ class HasPublicKey(ABC): def __hash__(self) -> int: return crypto.bytes_to_int(self.adv_key_bytes) - def __eq__(self, other: HasPublicKey) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, HasPublicKey): return NotImplemented @@ -135,11 +135,13 @@ class KeyGenerator(ABC, Generic[K]): return NotImplemented @overload + @abstractmethod def __getitem__(self, val: int) -> K: ... @overload - def __getitem__(self, slc: slice) -> Generator[K, None, None]: + @abstractmethod + def __getitem__(self, val: slice) -> Generator[K, None, None]: ... @abstractmethod diff --git a/findmy/reports/account.py b/findmy/reports/account.py index 168a494..f06737d 100644 --- a/findmy/reports/account.py +++ b/findmy/reports/account.py @@ -11,12 +11,7 @@ import plistlib import uuid from abc import ABC, abstractmethod from datetime import datetime, timedelta, timezone -from typing import ( - TYPE_CHECKING, - Any, - Sequence, - TypedDict, -) +from typing import TYPE_CHECKING, Any, Sequence, TypedDict import bs4 import srp._pysrp as srp @@ -39,6 +34,7 @@ from .twofactor import ( if TYPE_CHECKING: from findmy.keys import KeyPair + from findmy.util.types import MaybeCoro from .anisette import BaseAnisetteProvider @@ -82,7 +78,7 @@ def _decrypt_cbc(session_key: bytes, data: bytes) -> bytes: def _extract_phone_numbers(html: str) -> list[dict]: soup = bs4.BeautifulSoup(html, features="html.parser") - data_elem = soup.find("script", **{"class": "boot_args"}) + data_elem = soup.find("script", {"class": "boot_args"}) if not data_elem: msg = "Could not find HTML element containing phone numbers" raise RuntimeError(msg) @@ -102,7 +98,7 @@ class BaseAppleAccount(ABC): @property @abstractmethod - def account_name(self) -> str: + def account_name(self) -> str | None: """ The name of the account as reported by Apple. @@ -154,12 +150,12 @@ class BaseAppleAccount(ABC): raise NotImplementedError @abstractmethod - def login(self, username: str, password: str) -> LoginState: + def login(self, username: str, password: str) -> MaybeCoro[LoginState]: """Log in to an Apple account using a username and password.""" raise NotImplementedError @abstractmethod - def get_2fa_methods(self) -> list[BaseSecondFactorMethod]: + def get_2fa_methods(self) -> MaybeCoro[Sequence[BaseSecondFactorMethod]]: """ Get a list of 2FA methods that can be used as a secondary challenge. @@ -168,7 +164,7 @@ class BaseAppleAccount(ABC): raise NotImplementedError @abstractmethod - def sms_2fa_request(self, phone_number_id: int) -> None: + def sms_2fa_request(self, phone_number_id: int) -> MaybeCoro[None]: """ Request a 2FA code to be sent to a specific phone number ID. @@ -177,7 +173,7 @@ class BaseAppleAccount(ABC): raise NotImplementedError @abstractmethod - def sms_2fa_submit(self, phone_number_id: int, code: str) -> LoginState: + def sms_2fa_submit(self, phone_number_id: int, code: str) -> MaybeCoro[LoginState]: """ Submit a 2FA code that was sent to a specific phone number ID. @@ -191,7 +187,7 @@ class BaseAppleAccount(ABC): keys: Sequence[KeyPair], date_from: datetime, date_to: datetime, - ) -> dict[KeyPair, list[KeyReport]]: + ) -> MaybeCoro[dict[KeyPair, list[KeyReport]]]: """ Fetch location reports for a sequence of `KeyPair`s between `date_from` and `date_end`. @@ -204,7 +200,7 @@ class BaseAppleAccount(ABC): self, keys: Sequence[KeyPair], hours: int = 7 * 24, - ) -> dict[KeyPair, list[KeyReport]]: + ) -> MaybeCoro[dict[KeyPair, list[KeyReport]]]: """ Fetch location reports for a sequence of `KeyPair`s for the last `hours` hours. @@ -213,7 +209,7 @@ class BaseAppleAccount(ABC): raise NotImplementedError @abstractmethod - def get_anisette_headers(self, serial: str = "0") -> dict[str, str]: + def get_anisette_headers(self, serial: str = "0") -> MaybeCoro[dict[str, str]]: """ Retrieve a complete dictionary of Anisette headers. @@ -355,7 +351,7 @@ class AsyncAppleAccount(BaseAppleAccount): return await self._login_mobileme() @require_login_state(LoginState.REQUIRE_2FA) - async def get_2fa_methods(self) -> list[AsyncSecondFactorMethod]: + async def get_2fa_methods(self) -> Sequence[AsyncSecondFactorMethod]: """See `BaseAppleAccount.get_2fa_methods`.""" methods: list[AsyncSecondFactorMethod] = [] @@ -366,8 +362,8 @@ class AsyncAppleAccount(BaseAppleAccount): methods.extend( AsyncSmsSecondFactor( self, - number.get("id"), - number.get("numberWithDialCode"), + number.get("id") or -1, + number.get("numberWithDialCode") or "-", ) for number in phone_numbers ) @@ -499,11 +495,11 @@ class AsyncAppleAccount(BaseAppleAccount): logging.debug("Decrypting SPD data in response") - spd = _decrypt_cbc(usr.get_session_key(), r["spd"]) + spd = _decrypt_cbc(usr.get_session_key() or b"", r["spd"]) spd = decode_plist(spd) logging.debug("Received account information") - self._account_info: _AccountInfo = { + self._account_info = { "account_name": spd.get("acname"), "first_name": spd.get("fn"), "last_name": spd.get("ln"), @@ -553,7 +549,7 @@ class AsyncAppleAccount(BaseAppleAccount): resp = await self._http.post( "https://setup.icloud.com/setup/iosbuddy/loginDelegates", - auth=(self._username, self._login_state_data["idms_pet"]), + auth=(self._username or "", self._login_state_data["idms_pet"]), data=data, headers=headers, ) @@ -575,7 +571,7 @@ class AsyncAppleAccount(BaseAppleAccount): self, method: str, url: str, - data: dict | None = None, + data: dict[str, Any] | None = None, ) -> str: adsid = self._login_state_data["adsid"] idms_token = self._login_state_data["idms_token"] @@ -595,7 +591,7 @@ class AsyncAppleAccount(BaseAppleAccount): r = await self._http.request( method, url, - json=data, + json=data or {}, headers=headers, ) if not r.ok: @@ -678,7 +674,7 @@ class AppleAccount(BaseAppleAccount): return self._asyncacc.login_state @property - def account_name(self) -> str: + def account_name(self) -> str | None: """See `AsyncAppleAccount.login_state`.""" return self._asyncacc.account_name @@ -705,7 +701,7 @@ class AppleAccount(BaseAppleAccount): coro = self._asyncacc.login(username, password) return self._loop.run_until_complete(coro) - def get_2fa_methods(self) -> list[SyncSecondFactorMethod]: + def get_2fa_methods(self) -> Sequence[SyncSecondFactorMethod]: """See `AsyncAppleAccount.get_2fa_methods`.""" coro = self._asyncacc.get_2fa_methods() methods = self._loop.run_until_complete(coro) diff --git a/findmy/reports/state.py b/findmy/reports/state.py index 5cf751e..96b9556 100644 --- a/findmy/reports/state.py +++ b/findmy/reports/state.py @@ -1,18 +1,11 @@ """Code related to internal account state handling.""" from enum import Enum from functools import wraps -from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar +from typing import Callable, Concatenate, ParamSpec, TypeVar from findmy.util.errors import InvalidStateError -if TYPE_CHECKING: - # noinspection PyUnresolvedReferences - from .account import BaseAppleAccount - -P = ParamSpec("P") -R = TypeVar("R") -A = TypeVar("A", bound="BaseAppleAccount") -F = Callable[Concatenate[A, P], R] +from .account import BaseAppleAccount class LoginState(Enum): @@ -40,12 +33,22 @@ class LoginState(Enum): return self.__str__() -def require_login_state(*states: LoginState) -> Callable[[F], F]: +_P = ParamSpec("_P") +_R = TypeVar("_R") +_A = TypeVar("_A", bound="BaseAppleAccount") +_F = Callable[Concatenate[_A, _P], _R] + + +def require_login_state(*states: LoginState) -> Callable[[_F], _F]: """Enforce a login state as precondition for a method.""" - def decorator(func: F) -> F: + def decorator(func: _F) -> _F: @wraps(func) - def wrapper(acc: A, *args: P.args, **kwargs: P.kwargs) -> R: + def wrapper(acc: _A, *args: _P.args, **kwargs: _P.kwargs) -> _R: # pyright: ignore [reportInvalidTypeVarUse] + if not isinstance(args[0], BaseAppleAccount): + msg = "This decorator can only be used on instances of BaseAppleAccount." + raise TypeError(msg) + if acc.login_state not in states: msg = ( f"Invalid login state! Currently: {acc.login_state}" diff --git a/findmy/reports/twofactor.py b/findmy/reports/twofactor.py index c2ff81e..0441099 100644 --- a/findmy/reports/twofactor.py +++ b/findmy/reports/twofactor.py @@ -1,6 +1,8 @@ """Public classes related to handling two-factor authentication.""" -from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, TypeVar +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, TypeVar + +from findmy.util.types import MaybeCoro from .state import LoginState @@ -8,23 +10,23 @@ if TYPE_CHECKING: # noinspection PyUnresolvedReferences from .account import AppleAccount, AsyncAppleAccount, BaseAppleAccount -T = TypeVar("T", bound="BaseAppleAccount") +_AccType = TypeVar("_AccType", bound="BaseAppleAccount") -class BaseSecondFactorMethod(metaclass=ABCMeta): +class BaseSecondFactorMethod(ABC, Generic[_AccType]): """Base class for a second-factor authentication method for an Apple account.""" - def __init__(self, account: T) -> None: + def __init__(self, account: _AccType) -> None: """Initialize the second-factor method.""" - self._account: T = account + self._account: _AccType = account @property - def account(self) -> T: + def account(self) -> _AccType: """The account associated with the second-factor method.""" return self._account @abstractmethod - def request(self) -> None: + def request(self) -> MaybeCoro[None]: """ Put in a request for the second-factor challenge. @@ -33,12 +35,12 @@ class BaseSecondFactorMethod(metaclass=ABCMeta): raise NotImplementedError @abstractmethod - def submit(self, code: str) -> LoginState: + def submit(self, code: str) -> MaybeCoro[LoginState]: """Submit a code to complete the second-factor challenge.""" raise NotImplementedError -class AsyncSecondFactorMethod(BaseSecondFactorMethod, metaclass=ABCMeta): +class AsyncSecondFactorMethod(BaseSecondFactorMethod, ABC): """ An asynchronous implementation of a second-factor authentication method. @@ -54,8 +56,18 @@ class AsyncSecondFactorMethod(BaseSecondFactorMethod, metaclass=ABCMeta): """The account associated with the second-factor method.""" return self._account + @abstractmethod + async def request(self) -> None: + """See `BaseSecondFactorMethod.request`.""" + raise NotImplementedError -class SyncSecondFactorMethod(BaseSecondFactorMethod, metaclass=ABCMeta): + @abstractmethod + async def submit(self, code: str) -> LoginState: + """See `BaseSecondFactorMethod.submit`.""" + raise NotImplementedError + + +class SyncSecondFactorMethod(BaseSecondFactorMethod, ABC): """ A synchronous implementation of a second-factor authentication method. @@ -71,8 +83,18 @@ class SyncSecondFactorMethod(BaseSecondFactorMethod, metaclass=ABCMeta): """The account associated with the second-factor method.""" return self._account + @abstractmethod + def request(self) -> None: + """See `BaseSecondFactorMethod.request`.""" + raise NotImplementedError -class SmsSecondFactorMethod(BaseSecondFactorMethod, metaclass=ABCMeta): + @abstractmethod + def submit(self, code: str) -> LoginState: + """See `BaseSecondFactorMethod.submit`.""" + raise NotImplementedError + + +class SmsSecondFactorMethod(BaseSecondFactorMethod, ABC): """Base class for SMS-based two-factor authentication.""" @property diff --git a/findmy/scanner/scanner.py b/findmy/scanner/scanner.py index 044a01e..8d83fe7 100644 --- a/findmy/scanner/scanner.py +++ b/findmy/scanner/scanner.py @@ -4,12 +4,16 @@ from __future__ import annotations import asyncio import logging import time -from typing import Any, AsyncGenerator +from typing import TYPE_CHECKING, Any, AsyncGenerator -import bleak +from bleak import BleakScanner from findmy.keys import HasPublicKey +if TYPE_CHECKING: + from bleak.backends.device import BLEDevice + from bleak.backends.scanner import AdvertisementData + logging.getLogger(__name__) @@ -109,12 +113,6 @@ class OfflineFindingDevice(HasPublicKey): f" status={self.status}, hint={self.hint})" ) - def __eq__(self, other: OfflineFindingDevice) -> bool: - """Check if two OfflineFindingDevices are equal by comparing their MAC addresses.""" - if not isinstance(other, OfflineFindingDevice): - return False - return other.mac_address == self.mac_address - def __hash__(self) -> int: """Hash an OfflineFindingDevice. This is simply the MAC address as an integer.""" return int.from_bytes(self._mac_bytes, "big") @@ -134,12 +132,10 @@ class OfflineFindingScanner: You most likely do not want to use this yourself; check out `OfflineFindingScanner.create` instead. """ - self._scanner: bleak.BleakScanner = bleak.BleakScanner(self._scan_callback) + self._scanner: BleakScanner = BleakScanner(self._scan_callback) self._loop = loop - self._device_fut: asyncio.Future[ - (bleak.BLEDevice, bleak.AdvertisementData) - ] = loop.create_future() + self._device_fut: asyncio.Future[tuple[BLEDevice, AdvertisementData]] = loop.create_future() self._scanner_count: int = 0 @@ -165,8 +161,8 @@ class OfflineFindingScanner: async def _scan_callback( self, - device: bleak.BLEDevice, - data: bleak.AdvertisementData, + device: BLEDevice, + data: AdvertisementData, ) -> None: self._device_fut.set_result((device, data)) self._device_fut = self._loop.create_future() @@ -186,7 +182,7 @@ class OfflineFindingScanner: timeout: float = 10, *, extend_timeout: bool = False, - ) -> AsyncGenerator[OfflineFindingDevice]: + ) -> AsyncGenerator[OfflineFindingDevice, None]: """ Scan for `OfflineFindingDevice`s for up to `timeout` seconds. diff --git a/findmy/util/http.py b/findmy/util/http.py index 20c9d37..20f11a7 100644 --- a/findmy/util/http.py +++ b/findmy/util/http.py @@ -4,15 +4,23 @@ from __future__ import annotations import asyncio import json import logging -from typing import Any, ParamSpec +from typing import Any, TypedDict from aiohttp import BasicAuth, ClientSession, ClientTimeout +from typing_extensions import Unpack from .parsers import decode_plist logging.getLogger(__name__) +class _HttpRequestOptions(TypedDict, total=False): + json: dict[str, Any] + headers: dict[str, str] + auth: tuple[str, str] | BasicAuth + data: bytes + + class HttpResponse: """Response of a request made by `HttpSession`.""" @@ -49,19 +57,19 @@ class HttpResponse: return data -P = ParamSpec("P") - - class HttpSession: """Asynchronous HTTP session manager. For internal use only.""" def __init__(self) -> None: # noqa: D107 self._session: ClientSession | None = None - async def _ensure_session(self) -> None: - if self._session is None: - logging.debug("Creating aiohttp session") - self._session = ClientSession(timeout=ClientTimeout(total=5)) + async def _get_session(self) -> ClientSession: + if self._session is not None: + return self._session + + logging.debug("Creating aiohttp session") + self._session = ClientSession(timeout=ClientTimeout(total=5)) + return self._session async def close(self) -> None: """Close the underlying session. Should be called when session will no longer be used.""" @@ -89,33 +97,31 @@ class HttpSession: self, method: str, url: str, - auth: tuple[str] | None = None, - **kwargs: P.kwargs, + **kwargs: Unpack[_HttpRequestOptions], ) -> HttpResponse: """ Make an HTTP request. Keyword arguments will directly be passed to `aiohttp.ClientSession.request`. """ - await self._ensure_session() + session = await self._get_session() - basic_auth = None - if auth is not None: - basic_auth = BasicAuth(auth[0], auth[1]) + auth = kwargs.get("auth") + if isinstance(auth, tuple): + kwargs["auth"] = BasicAuth(auth[0], auth[1]) - async with await self._session.request( + async with await session.request( method, url, - auth=basic_auth, ssl=False, **kwargs, ) as r: return HttpResponse(r.status, await r.content.read()) - async def get(self, url: str, **kwargs: P.kwargs) -> HttpResponse: + async def get(self, url: str, **kwargs: Unpack[_HttpRequestOptions]) -> HttpResponse: """Alias for `HttpSession.request("GET", ...)`.""" return await self.request("GET", url, **kwargs) - async def post(self, url: str, **kwargs: P.kwargs) -> HttpResponse: + async def post(self, url: str, **kwargs: Unpack[_HttpRequestOptions]) -> HttpResponse: """Alias for `HttpSession.request("POST", ...)`.""" return await self.request("POST", url, **kwargs) diff --git a/findmy/util/types.py b/findmy/util/types.py new file mode 100644 index 0000000..8c754aa --- /dev/null +++ b/findmy/util/types.py @@ -0,0 +1,7 @@ +"""Utility types.""" + +from typing import Coroutine, TypeVar + +T = TypeVar("T") + +MaybeCoro = T | Coroutine[None, None, T] diff --git a/poetry.lock b/poetry.lock index 4af3779..c78d880 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1073,6 +1073,24 @@ files = [ [package.dependencies] pyobjc-core = ">=9.2" +[[package]] +name = "pyright" +version = "1.1.350" +description = "Command line wrapper for pyright" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyright-1.1.350-py3-none-any.whl", hash = "sha256:f1dde6bcefd3c90aedbe9dd1c573e4c1ddbca8c74bf4fa664dd3b1a599ac9a66"}, + {file = "pyright-1.1.350.tar.gz", hash = "sha256:a8ba676de3a3737ea4d8590604da548d4498cc5ee9ee00b1a403c6db987916c6"}, +] + +[package.dependencies] +nodeenv = ">=1.6.0" + +[package.extras] +all = ["twine (>=3.4.1)"] +dev = ["twine (>=3.4.1)"] + [[package]] name = "pyyaml" version = "6.0.1" @@ -1760,4 +1778,4 @@ scan = ["bleak"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "f0a6463477183b86d152b71aa14404fea80a3b28df20c43782eb56de00db8d91" +content-hash = "696a56ccbba231e3ec702aaee911977819b996d21074b37807c42a45d107c7ab" diff --git a/pyproject.toml b/pyproject.toml index c9e277a..474cd9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,10 @@ scan = ["bleak"] pre-commit = "^3.6.0" sphinx = "^7.2.6" sphinx-autoapi = "^3.0.0" +pyright = "^1.1.350" + +[tool.pyright] +typeCheckingMode = "standard" [tool.ruff] exclude = [