Add Pyright as static type checker

This commit is contained in:
Mike A
2024-02-10 17:25:54 +01:00
parent 34ac209f0d
commit 7d09296565
10 changed files with 148 additions and 86 deletions

View File

@@ -4,15 +4,23 @@ from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, ParamSpec
from typing import Any, TypedDict
from aiohttp import BasicAuth, ClientSession, ClientTimeout
from typing_extensions import Unpack
from .parsers import decode_plist
logging.getLogger(__name__)
class _HttpRequestOptions(TypedDict, total=False):
json: dict[str, Any]
headers: dict[str, str]
auth: tuple[str, str] | BasicAuth
data: bytes
class HttpResponse:
"""Response of a request made by `HttpSession`."""
@@ -49,19 +57,19 @@ class HttpResponse:
return data
P = ParamSpec("P")
class HttpSession:
"""Asynchronous HTTP session manager. For internal use only."""
def __init__(self) -> None: # noqa: D107
self._session: ClientSession | None = None
async def _ensure_session(self) -> None:
if self._session is None:
logging.debug("Creating aiohttp session")
self._session = ClientSession(timeout=ClientTimeout(total=5))
async def _get_session(self) -> ClientSession:
if self._session is not None:
return self._session
logging.debug("Creating aiohttp session")
self._session = ClientSession(timeout=ClientTimeout(total=5))
return self._session
async def close(self) -> None:
"""Close the underlying session. Should be called when session will no longer be used."""
@@ -89,33 +97,31 @@ class HttpSession:
self,
method: str,
url: str,
auth: tuple[str] | None = None,
**kwargs: P.kwargs,
**kwargs: Unpack[_HttpRequestOptions],
) -> HttpResponse:
"""
Make an HTTP request.
Keyword arguments will directly be passed to `aiohttp.ClientSession.request`.
"""
await self._ensure_session()
session = await self._get_session()
basic_auth = None
if auth is not None:
basic_auth = BasicAuth(auth[0], auth[1])
auth = kwargs.get("auth")
if isinstance(auth, tuple):
kwargs["auth"] = BasicAuth(auth[0], auth[1])
async with await self._session.request(
async with await session.request(
method,
url,
auth=basic_auth,
ssl=False,
**kwargs,
) as r:
return HttpResponse(r.status, await r.content.read())
async def get(self, url: str, **kwargs: P.kwargs) -> HttpResponse:
async def get(self, url: str, **kwargs: Unpack[_HttpRequestOptions]) -> HttpResponse:
"""Alias for `HttpSession.request("GET", ...)`."""
return await self.request("GET", url, **kwargs)
async def post(self, url: str, **kwargs: P.kwargs) -> HttpResponse:
async def post(self, url: str, **kwargs: Unpack[_HttpRequestOptions]) -> HttpResponse:
"""Alias for `HttpSession.request("POST", ...)`."""
return await self.request("POST", url, **kwargs)