Fix all typing issues

This commit is contained in:
Mike A
2024-01-02 20:44:47 +01:00
parent cf4213c8bb
commit 58d7dabd21
2 changed files with 26 additions and 11 deletions

View File

@@ -11,7 +11,16 @@ import plistlib
import uuid
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Sequence, TypedDict, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Callable,
Concatenate,
ParamSpec,
Sequence,
TypedDict,
TypeVar,
)
import bs4
import srp._pysrp as srp
@@ -91,13 +100,16 @@ def _extract_phone_numbers(html: str) -> list[dict]:
return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", [])
F = TypeVar("F", bound=Callable[[BaseAppleAccount, ...], Any])
P = ParamSpec("P")
R = TypeVar("R")
A = TypeVar("A", bound=BaseAppleAccount)
F = Callable[Concatenate[A, P], R]
def _require_login_state(*states: LoginState) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
def wrapper(acc: BaseAppleAccount, *args, **kwargs):
def wrapper(acc: A, *args: P.args, **kwargs: P.kwargs) -> R:
if acc.login_state not in states:
msg = (
f"Invalid login state! Currently: {acc.login_state}"
@@ -320,9 +332,9 @@ class AsyncAppleAccount(BaseAppleAccount):
return await self._login_mobileme()
@_require_login_state(LoginState.REQUIRE_2FA)
async def get_2fa_methods(self) -> list[BaseSecondFactorMethod]:
async def get_2fa_methods(self) -> list[AsyncSmsSecondFactor]:
"""See `BaseAppleAccount.get_2fa_methods`."""
methods: list[BaseSecondFactorMethod] = []
methods: list[AsyncSmsSecondFactor] = []
# sms
auth_page = await self._sms_2fa_request("GET", "https://gsa.apple.com/auth")
@@ -569,7 +581,7 @@ class AsyncAppleAccount(BaseAppleAccount):
return r.text()
async def _gsa_request(self, params: dict[str, Any]) -> Any:
async def _gsa_request(self, params: dict[str, Any]) -> dict[Any, Any]:
request_data = {
"cpd": {
"bootstrap": True,
@@ -666,7 +678,7 @@ class AppleAccount(BaseAppleAccount):
coro = self._asyncacc.login(username, password)
return self._loop.run_until_complete(coro)
def get_2fa_methods(self) -> list[BaseSecondFactorMethod]:
def get_2fa_methods(self) -> list[SmsSecondFactor]:
"""See `AsyncAppleAccount.get_2fa_methods`."""
coro = self._asyncacc.get_2fa_methods()
methods = self._loop.run_until_complete(coro)

View File

@@ -5,7 +5,7 @@ import asyncio
import json
import logging
import plistlib
from typing import Any
from typing import Any, ParamSpec
from aiohttp import BasicAuth, ClientSession, ClientTimeout
@@ -61,6 +61,9 @@ class HttpResponse:
return data
P = ParamSpec("P")
class HttpSession:
"""Asynchronous HTTP session manager. For internal use only."""
@@ -98,7 +101,7 @@ class HttpSession:
method: str,
url: str,
auth: tuple[str] | None = None,
**kwargs: Any,
**kwargs: P.kwargs,
) -> HttpResponse:
"""Make an HTTP request.
@@ -119,10 +122,10 @@ class HttpSession:
) as r:
return HttpResponse(r.status, await r.content.read())
async def get(self, url: str, **kwargs: Any) -> HttpResponse:
async def get(self, url: str, **kwargs: P.kwargs) -> HttpResponse:
"""Alias for `HttpSession.request("GET", ...)`."""
return await self.request("GET", url, **kwargs)
async def post(self, url: str, **kwargs: Any) -> HttpResponse:
async def post(self, url: str, **kwargs: P.kwargs) -> HttpResponse:
"""Alias for `HttpSession.request("POST", ...)`."""
return await self.request("POST", url, **kwargs)