reports: attempt to reauthenticate on 401

This commit is contained in:
Mike A.
2024-07-11 21:02:34 +02:00
parent b631d0b8bd
commit 740bbf059c
2 changed files with 41 additions and 15 deletions

View File

@@ -5,6 +5,10 @@ class InvalidCredentialsError(Exception):
"""Raised when credentials are incorrect.""" """Raised when credentials are incorrect."""
class UnauthorizedError(Exception):
"""Raised when an authorization error occurs."""
class UnhandledProtocolError(RuntimeError): class UnhandledProtocolError(RuntimeError):
""" """
Raised when an unexpected error occurs while communicating with Apple servers. Raised when an unexpected error occurs while communicating with Apple servers.

View File

@@ -27,10 +27,15 @@ import bs4
import srp._pysrp as srp import srp._pysrp as srp
from typing_extensions import override from typing_extensions import override
from findmy.errors import InvalidCredentialsError, InvalidStateError, UnhandledProtocolError from findmy.errors import (
InvalidCredentialsError,
InvalidStateError,
UnauthorizedError,
UnhandledProtocolError,
)
from findmy.util import crypto from findmy.util import crypto
from findmy.util.closable import Closable from findmy.util.closable import Closable
from findmy.util.http import HttpSession, decode_plist from findmy.util.http import HttpResponse, HttpSession, decode_plist
from .reports import LocationReport, LocationReportsFetcher from .reports import LocationReport, LocationReportsFetcher
from .state import LoginState from .state import LoginState
@@ -585,18 +590,35 @@ class AsyncAppleAccount(BaseAppleAccount):
) )
data = {"search": [{"startDate": start, "endDate": end, "ids": ids}]} data = {"search": [{"startDate": start, "endDate": end, "ids": ids}]}
r = await self._http.post( async def _do_request() -> HttpResponse:
self._ENDPOINT_REPORTS_FETCH, return await self._http.post(
auth=auth, self._ENDPOINT_REPORTS_FETCH,
headers=await self.get_anisette_headers(), auth=auth,
json=data, headers=await self.get_anisette_headers(),
) json=data,
resp = r.json() )
if not r.ok or resp["statusCode"] != "200":
msg = f"Failed to fetch reports: {resp['statusCode']}" r = await _do_request()
if r.status_code == 401:
logging.info("Got 401 while fetching reports, redoing login")
new_state = await self._gsa_authenticate()
if new_state != LoginState.AUTHENTICATED:
msg = f"Unexpected login state after reauth: {new_state}. Please log in again."
raise UnauthorizedError(msg)
await self._login_mobileme()
r = await _do_request()
if r.status_code == 401:
msg = "Not authorized to fetch reports."
raise UnauthorizedError(msg)
if not r.ok or r.json()["statusCode"] != "200":
msg = f"Failed to fetch reports: {r.json()['statusCode']}"
raise UnhandledProtocolError(msg) raise UnhandledProtocolError(msg)
return resp return r.json()
@overload @overload
async def fetch_reports( async def fetch_reports(
@@ -679,7 +701,7 @@ class AsyncAppleAccount(BaseAppleAccount):
return await self.fetch_reports(keys, start, end) return await self.fetch_reports(keys, start, end)
@require_login_state(LoginState.LOGGED_OUT, LoginState.REQUIRE_2FA) @require_login_state(LoginState.LOGGED_OUT, LoginState.REQUIRE_2FA, LoginState.LOGGED_IN)
async def _gsa_authenticate( async def _gsa_authenticate(
self, self,
username: str | None = None, username: str | None = None,
@@ -805,9 +827,9 @@ class AsyncAppleAccount(BaseAppleAccount):
data = resp.plist() data = resp.plist()
mobileme_data = data.get("delegates", {}).get("com.apple.mobileme", {}) mobileme_data = data.get("delegates", {}).get("com.apple.mobileme", {})
status = mobileme_data.get("status") status = mobileme_data.get("status") or data.get("status")
if status != 0: if status != 0:
status_message = mobileme_data.get("status-message") status_message = mobileme_data.get("status-message") or data.get("status-message")
msg = f"com.apple.mobileme login failed with status {status}: {status_message}" msg = f"com.apple.mobileme login failed with status {status}: {status_message}"
raise UnhandledProtocolError(msg) raise UnhandledProtocolError(msg)