Fix all typing issues

This commit is contained in:
Mike A
2024-01-02 20:44:47 +01:00
parent cf4213c8bb
commit 58d7dabd21
2 changed files with 26 additions and 11 deletions

View File

@@ -11,7 +11,16 @@ import plistlib
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Sequence, TypedDict, TypeVar from typing import (
TYPE_CHECKING,
Any,
Callable,
Concatenate,
ParamSpec,
Sequence,
TypedDict,
TypeVar,
)
import bs4 import bs4
import srp._pysrp as srp import srp._pysrp as srp
@@ -91,13 +100,16 @@ def _extract_phone_numbers(html: str) -> list[dict]:
return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", []) return data.get("direct", {}).get("phoneNumberVerification", {}).get("trustedPhoneNumbers", [])
F = TypeVar("F", bound=Callable[[BaseAppleAccount, ...], Any]) P = ParamSpec("P")
R = TypeVar("R")
A = TypeVar("A", bound=BaseAppleAccount)
F = Callable[Concatenate[A, P], R]
def _require_login_state(*states: LoginState) -> Callable[[F], F]: def _require_login_state(*states: LoginState) -> Callable[[F], F]:
def decorator(func: F) -> F: def decorator(func: F) -> F:
@wraps(func) @wraps(func)
def wrapper(acc: BaseAppleAccount, *args, **kwargs): def wrapper(acc: A, *args: P.args, **kwargs: P.kwargs) -> R:
if acc.login_state not in states: if acc.login_state not in states:
msg = ( msg = (
f"Invalid login state! Currently: {acc.login_state}" f"Invalid login state! Currently: {acc.login_state}"
@@ -320,9 +332,9 @@ class AsyncAppleAccount(BaseAppleAccount):
return await self._login_mobileme() return await self._login_mobileme()
@_require_login_state(LoginState.REQUIRE_2FA) @_require_login_state(LoginState.REQUIRE_2FA)
async def get_2fa_methods(self) -> list[BaseSecondFactorMethod]: async def get_2fa_methods(self) -> list[AsyncSmsSecondFactor]:
"""See `BaseAppleAccount.get_2fa_methods`.""" """See `BaseAppleAccount.get_2fa_methods`."""
methods: list[BaseSecondFactorMethod] = [] methods: list[AsyncSmsSecondFactor] = []
# sms # sms
auth_page = await self._sms_2fa_request("GET", "https://gsa.apple.com/auth") auth_page = await self._sms_2fa_request("GET", "https://gsa.apple.com/auth")
@@ -569,7 +581,7 @@ class AsyncAppleAccount(BaseAppleAccount):
return r.text() return r.text()
async def _gsa_request(self, params: dict[str, Any]) -> Any: async def _gsa_request(self, params: dict[str, Any]) -> dict[Any, Any]:
request_data = { request_data = {
"cpd": { "cpd": {
"bootstrap": True, "bootstrap": True,
@@ -666,7 +678,7 @@ class AppleAccount(BaseAppleAccount):
coro = self._asyncacc.login(username, password) coro = self._asyncacc.login(username, password)
return self._loop.run_until_complete(coro) return self._loop.run_until_complete(coro)
def get_2fa_methods(self) -> list[BaseSecondFactorMethod]: def get_2fa_methods(self) -> list[SmsSecondFactor]:
"""See `AsyncAppleAccount.get_2fa_methods`.""" """See `AsyncAppleAccount.get_2fa_methods`."""
coro = self._asyncacc.get_2fa_methods() coro = self._asyncacc.get_2fa_methods()
methods = self._loop.run_until_complete(coro) methods = self._loop.run_until_complete(coro)

View File

@@ -5,7 +5,7 @@ import asyncio
import json import json
import logging import logging
import plistlib import plistlib
from typing import Any from typing import Any, ParamSpec
from aiohttp import BasicAuth, ClientSession, ClientTimeout from aiohttp import BasicAuth, ClientSession, ClientTimeout
@@ -61,6 +61,9 @@ class HttpResponse:
return data return data
P = ParamSpec("P")
class HttpSession: class HttpSession:
"""Asynchronous HTTP session manager. For internal use only.""" """Asynchronous HTTP session manager. For internal use only."""
@@ -98,7 +101,7 @@ class HttpSession:
method: str, method: str,
url: str, url: str,
auth: tuple[str] | None = None, auth: tuple[str] | None = None,
**kwargs: Any, **kwargs: P.kwargs,
) -> HttpResponse: ) -> HttpResponse:
"""Make an HTTP request. """Make an HTTP request.
@@ -119,10 +122,10 @@ class HttpSession:
) as r: ) as r:
return HttpResponse(r.status, await r.content.read()) return HttpResponse(r.status, await r.content.read())
async def get(self, url: str, **kwargs: Any) -> HttpResponse: async def get(self, url: str, **kwargs: P.kwargs) -> HttpResponse:
"""Alias for `HttpSession.request("GET", ...)`.""" """Alias for `HttpSession.request("GET", ...)`."""
return await self.request("GET", url, **kwargs) return await self.request("GET", url, **kwargs)
async def post(self, url: str, **kwargs: Any) -> HttpResponse: async def post(self, url: str, **kwargs: P.kwargs) -> HttpResponse:
"""Alias for `HttpSession.request("POST", ...)`.""" """Alias for `HttpSession.request("POST", ...)`."""
return await self.request("POST", url, **kwargs) return await self.request("POST", url, **kwargs)