Merge pull request #135 from ivanik7/fetch-reports-list-type

Update fetch_reports to accept lists with HasHashedPublicKey and RolingKeyPairSource at the same time
This commit is contained in:
Mike Almeloo
2025-06-14 16:22:46 +02:00
committed by GitHub
2 changed files with 42 additions and 119 deletions

View File

@@ -237,15 +237,6 @@ class BaseAppleAccount(Closable, ABC):
date_to: datetime | None,
) -> MaybeCoro[list[LocationReport]]: ...
@overload
@abstractmethod
def fetch_reports(
self,
keys: Sequence[HasHashedPublicKey],
date_from: datetime,
date_to: datetime | None,
) -> MaybeCoro[dict[HasHashedPublicKey, list[LocationReport]]]: ...
@overload
@abstractmethod
def fetch_reports(
@@ -256,26 +247,24 @@ class BaseAppleAccount(Closable, ABC):
) -> MaybeCoro[list[LocationReport]]: ...
@overload
@abstractmethod
def fetch_reports(
self,
keys: Sequence[RollingKeyPairSource],
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ...
) -> MaybeCoro[dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]]: ...
@abstractmethod
def fetch_reports(
self,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
| Sequence[HasHashedPublicKey | RollingKeyPairSource]
| RollingKeyPairSource,
date_from: datetime,
date_to: datetime | None,
) -> MaybeCoro[
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
]:
"""
Fetch location reports for `HasHashedPublicKey`s between `date_from` and `date_end`.
@@ -292,14 +281,6 @@ class BaseAppleAccount(Closable, ABC):
hours: int = 7 * 24,
) -> MaybeCoro[list[LocationReport]]: ...
@overload
@abstractmethod
def fetch_last_reports(
self,
keys: Sequence[HasHashedPublicKey],
hours: int = 7 * 24,
) -> MaybeCoro[dict[HasHashedPublicKey, list[LocationReport]]]: ...
@overload
@abstractmethod
def fetch_last_reports(
@@ -312,22 +293,19 @@ class BaseAppleAccount(Closable, ABC):
@abstractmethod
def fetch_last_reports(
self,
keys: Sequence[RollingKeyPairSource],
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
hours: int = 7 * 24,
) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ...
) -> MaybeCoro[dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]]: ...
@abstractmethod
def fetch_last_reports(
self,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
| Sequence[HasHashedPublicKey | RollingKeyPairSource],
hours: int = 7 * 24,
) -> MaybeCoro[
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
]:
"""
Fetch location reports for a sequence of `HasHashedPublicKey`s for the last `hours` hours.
@@ -665,14 +643,6 @@ class AsyncAppleAccount(BaseAppleAccount):
date_to: datetime | None,
) -> list[LocationReport]: ...
@overload
async def fetch_reports(
self,
keys: Sequence[HasHashedPublicKey],
date_from: datetime,
date_to: datetime | None,
) -> dict[HasHashedPublicKey, list[LocationReport]]: ...
@overload
async def fetch_reports(
self,
@@ -684,25 +654,22 @@ class AsyncAppleAccount(BaseAppleAccount):
@overload
async def fetch_reports(
self,
keys: Sequence[RollingKeyPairSource],
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ...
@require_login_state(LoginState.LOGGED_IN)
@override
async def fetch_reports(
self,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
| Sequence[HasHashedPublicKey | RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
):
"""See `BaseAppleAccount.fetch_reports`."""
date_to = date_to or datetime.now().astimezone()
@@ -720,13 +687,6 @@ class AsyncAppleAccount(BaseAppleAccount):
hours: int = 7 * 24,
) -> list[LocationReport]: ...
@overload
async def fetch_last_reports(
self,
keys: Sequence[HasHashedPublicKey],
hours: int = 7 * 24,
) -> dict[HasHashedPublicKey, list[LocationReport]]: ...
@overload
async def fetch_last_reports(
self,
@@ -737,23 +697,20 @@ class AsyncAppleAccount(BaseAppleAccount):
@overload
async def fetch_last_reports(
self,
keys: Sequence[RollingKeyPairSource],
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
hours: int = 7 * 24,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ...
@require_login_state(LoginState.LOGGED_IN)
@override
async def fetch_last_reports(
self,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
| Sequence[HasHashedPublicKey | RollingKeyPairSource],
hours: int = 7 * 24,
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
):
"""See `BaseAppleAccount.fetch_last_reports`."""
end = datetime.now(tz=timezone.utc)
@@ -1093,14 +1050,6 @@ class AppleAccount(BaseAppleAccount):
date_to: datetime | None,
) -> list[LocationReport]: ...
@overload
def fetch_reports(
self,
keys: Sequence[HasHashedPublicKey],
date_from: datetime,
date_to: datetime | None,
) -> dict[HasHashedPublicKey, list[LocationReport]]: ...
@overload
def fetch_reports(
self,
@@ -1112,24 +1061,21 @@ class AppleAccount(BaseAppleAccount):
@overload
def fetch_reports(
self,
keys: Sequence[RollingKeyPairSource],
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ...
@override
def fetch_reports(
self,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
| Sequence[HasHashedPublicKey | RollingKeyPairSource]
| RollingKeyPairSource,
date_from: datetime,
date_to: datetime | None,
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
):
"""See `AsyncAppleAccount.fetch_reports`."""
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
@@ -1142,13 +1088,6 @@ class AppleAccount(BaseAppleAccount):
hours: int = 7 * 24,
) -> list[LocationReport]: ...
@overload
def fetch_last_reports(
self,
keys: Sequence[HasHashedPublicKey],
hours: int = 7 * 24,
) -> dict[HasHashedPublicKey, list[LocationReport]]: ...
@overload
def fetch_last_reports(
self,
@@ -1159,22 +1098,19 @@ class AppleAccount(BaseAppleAccount):
@overload
def fetch_last_reports(
self,
keys: Sequence[RollingKeyPairSource],
keys: Sequence[HasHashedPublicKey | RollingKeyPairSource],
hours: int = 7 * 24,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ...
@override
def fetch_last_reports(
self,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
| Sequence[HasHashedPublicKey | RollingKeyPairSource],
hours: int = 7 * 24,
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
):
"""See `AsyncAppleAccount.fetch_last_reports`."""
coro = self._asyncacc.fetch_last_reports(keys, hours)

View File

@@ -245,14 +245,6 @@ class LocationReportsFetcher:
device: HasHashedPublicKey,
) -> list[LocationReport]: ...
@overload
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: Sequence[HasHashedPublicKey],
) -> dict[HasHashedPublicKey, list[LocationReport]]: ...
@overload
async def fetch_reports(
self,
@@ -266,43 +258,34 @@ class LocationReportsFetcher:
self,
date_from: datetime,
date_to: datetime,
device: Sequence[RollingKeyPairSource],
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
device: Sequence[HasHashedPublicKey | RollingKeyPairSource],
) -> dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]: ...
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
| Sequence[HasHashedPublicKey | RollingKeyPairSource],
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
list[LocationReport] | dict[HasHashedPublicKey | RollingKeyPairSource, list[LocationReport]]
):
"""
Fetch location reports for a certain device.
When ``device`` is a single :class:`.HasHashedPublicKey`, this method will return
a list of location reports corresponding to that key.
When ``device`` is a sequence of :class:`.HasHashedPublicKey`s, it will return a dictionary
with the :class:`.HasHashedPublicKey` as key, and a list of location reports as value.
When ``device`` is a :class:`.RollingKeyPairSource`, it will return a list of
location reports corresponding to that source.
When ``device`` is a sequence of :class:`.HasHashedPublicKey`s or RollingKeyPairSource's,
it will return a dictionary with the :class:`.HasHashedPublicKey` or `.RollingKeyPairSource`
as key, and a list of location reports as value.
"""
key_devs: (
dict[HasHashedPublicKey, HasHashedPublicKey]
| dict[HasHashedPublicKey, RollingKeyPairSource]
) = {}
key_devs: dict[HasHashedPublicKey, HasHashedPublicKey | RollingKeyPairSource] = {}
if isinstance(device, HasHashedPublicKey):
# single key
key_devs = {device: device}
elif isinstance(device, list) and all(isinstance(x, HasHashedPublicKey) for x in device):
# multiple static keys
device = cast("list[HasHashedPublicKey]", device)
key_devs = {key: key for key in device}
elif isinstance(device, RollingKeyPairSource):
# key generator
# add 12h margin to the generator
@@ -313,13 +296,17 @@ class LocationReportsFetcher:
date_to + timedelta(hours=12),
)
}
elif isinstance(device, list) and all(isinstance(x, RollingKeyPairSource) for x in device):
elif isinstance(device, list) and all(
isinstance(x, HasHashedPublicKey | RollingKeyPairSource) for x in device
):
# multiple key generators
# add 12h margin to each generator
device = cast("list[RollingKeyPairSource]", device)
key_devs = {
device = cast("list[HasHashedPublicKey | RollingKeyPairSource]", device)
key_devs = {key: key for key in device if isinstance(key, HasHashedPublicKey)} | {
key: dev
for dev in device
if isinstance(dev, RollingKeyPairSource)
for key in dev.keys_between(
date_from - timedelta(hours=12),
date_to + timedelta(hours=12),