diff --git a/findmy/account.py b/findmy/account.py index 9342b82..43215bc 100644 --- a/findmy/account.py +++ b/findmy/account.py @@ -11,7 +11,16 @@ import plistlib import uuid from datetime import datetime, timedelta, timezone from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Sequence, TypedDict, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Concatenate, + ParamSpec, + Sequence, + TypedDict, + TypeVar, +) import bs4 import srp._pysrp as srp @@ -91,13 +100,16 @@ def _extract_phone_numbers(html: str) -> list[dict]: return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", []) -F = TypeVar("F", bound=Callable[[BaseAppleAccount, ...], Any]) +P = ParamSpec("P") +R = TypeVar("R") +A = TypeVar("A", bound=BaseAppleAccount) +F = Callable[Concatenate[A, P], R] def _require_login_state(*states: LoginState) -> Callable[[F], F]: def decorator(func: F) -> F: @wraps(func) - def wrapper(acc: BaseAppleAccount, *args, **kwargs): + def wrapper(acc: A, *args: P.args, **kwargs: P.kwargs) -> R: if acc.login_state not in states: msg = ( f"Invalid login state! Currently: {acc.login_state}" @@ -320,9 +332,9 @@ class AsyncAppleAccount(BaseAppleAccount): return await self._login_mobileme() @_require_login_state(LoginState.REQUIRE_2FA) - async def get_2fa_methods(self) -> list[BaseSecondFactorMethod]: + async def get_2fa_methods(self) -> list[AsyncSmsSecondFactor]: """See `BaseAppleAccount.get_2fa_methods`.""" - methods: list[BaseSecondFactorMethod] = [] + methods: list[AsyncSmsSecondFactor] = [] # sms auth_page = await self._sms_2fa_request("GET", "https://gsa.apple.com/auth") @@ -569,7 +581,7 @@ class AsyncAppleAccount(BaseAppleAccount): return r.text() - async def _gsa_request(self, params: dict[str, Any]) -> Any: + async def _gsa_request(self, params: dict[str, Any]) -> dict[Any, Any]: request_data = { "cpd": { "bootstrap": True, @@ -666,7 +678,7 @@ class AppleAccount(BaseAppleAccount): coro = self._asyncacc.login(username, password) return self._loop.run_until_complete(coro) - def get_2fa_methods(self) -> list[BaseSecondFactorMethod]: + def get_2fa_methods(self) -> list[SmsSecondFactor]: """See `AsyncAppleAccount.get_2fa_methods`.""" coro = self._asyncacc.get_2fa_methods() methods = self._loop.run_until_complete(coro) diff --git a/findmy/http.py b/findmy/http.py index d31d2d9..06d274e 100644 --- a/findmy/http.py +++ b/findmy/http.py @@ -5,7 +5,7 @@ import asyncio import json import logging import plistlib -from typing import Any +from typing import Any, ParamSpec from aiohttp import BasicAuth, ClientSession, ClientTimeout @@ -61,6 +61,9 @@ class HttpResponse: return data +P = ParamSpec("P") + + class HttpSession: """Asynchronous HTTP session manager. For internal use only.""" @@ -98,7 +101,7 @@ class HttpSession: method: str, url: str, auth: tuple[str] | None = None, - **kwargs: Any, + **kwargs: P.kwargs, ) -> HttpResponse: """Make an HTTP request. @@ -119,10 +122,10 @@ class HttpSession: ) as r: return HttpResponse(r.status, await r.content.read()) - async def get(self, url: str, **kwargs: Any) -> HttpResponse: + async def get(self, url: str, **kwargs: P.kwargs) -> HttpResponse: """Alias for `HttpSession.request("GET", ...)`.""" return await self.request("GET", url, **kwargs) - async def post(self, url: str, **kwargs: Any) -> HttpResponse: + async def post(self, url: str, **kwargs: P.kwargs) -> HttpResponse: """Alias for `HttpSession.request("POST", ...)`.""" return await self.request("POST", url, **kwargs)