Merge pull request #216 from malmeloo/feat/better-key-cache

fix: more efficient key caching system
This commit is contained in:
Mike Almeloo
2026-01-08 22:32:46 +01:00
committed by GitHub

View File

@@ -6,8 +6,10 @@ Accessories could be anything ranging from AirTags to iPhones.
from __future__ import annotations from __future__ import annotations
import bisect
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Literal, TypedDict, overload from typing import TYPE_CHECKING, Literal, TypedDict, overload
@@ -390,13 +392,24 @@ class FindMyAccessory(RollingKeyPairSource, util.abc.Serializable[FindMyAccessor
) )
@dataclass(frozen=True)
class _CacheTier:
"""Configuration for a cache tier."""
interval: int # Cache every n'th key
max_size: int | None # Maximum number of keys to cache in this tier (None = unlimited)
class _AccessoryKeyGenerator(KeyGenerator[KeyPair]): class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do.""" """KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
# cache enough keys for an entire week. # Define cache tiers: (interval, max_size)
# every interval'th key is cached. # Tier 1: Cache every 4th key (1 hour), keep up to 672 keys (2 weeks at 15min intervals)
_CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour # Tier 2: Cache every 672nd key (1 week), unlimited
_CACHE_INTERVAL = 1 # cache every key _CACHE_TIERS = (
_CacheTier(interval=4, max_size=672),
_CacheTier(interval=672, max_size=None),
)
def __init__( def __init__(
self, self,
@@ -422,7 +435,9 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
self._initial_sk = initial_sk self._initial_sk = initial_sk
self._key_type = key_type self._key_type = key_type
self._sk_cache: dict[int, bytes] = {} # Multi-tier cache: dict + sorted indices per tier
self._sk_caches: list[dict[int, bytes]] = [{} for _ in self._CACHE_TIERS]
self._cache_indices: list[list[int]] = [[] for _ in self._CACHE_TIERS]
self._iter_ind = 0 self._iter_ind = 0
@@ -441,36 +456,68 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
"""The type of key this generator produces.""" """The type of key this generator produces."""
return self._key_type return self._key_type
def _find_best_cached_sk(self, ind: int) -> tuple[int, bytes]:
"""Find the largest cached index smaller than ind across all tiers."""
best_ind = 0
best_sk = self._initial_sk
for indices, cache in zip(self._cache_indices, self._sk_caches, strict=True):
if not indices:
continue
# Use bisect to find the largest index < ind in O(log n)
pos = bisect.bisect_left(indices, ind)
if pos == 0: # No cached index less than ind
continue
cached_ind = indices[pos - 1]
if cached_ind > best_ind:
best_ind = cached_ind
best_sk = cache[cached_ind]
return best_ind, best_sk
def _update_caches(self, ind: int, sk: bytes) -> None:
"""Update all applicable cache tiers with the computed key."""
for tier_idx, tier in enumerate(self._CACHE_TIERS):
if ind % tier.interval != 0:
continue
cache = self._sk_caches[tier_idx]
indices = self._cache_indices[tier_idx]
# Add to cache if not already present
if ind in cache:
continue
cache[ind] = sk
bisect.insort(indices, ind)
# Evict if cache exceeds size limit
if tier.max_size is not None and len(cache) > tier.max_size:
# If adding a historical key, evict smallest index
# If adding a future key, evict largest
evict_ind = indices.pop(0 if indices and ind > indices[0] else -1)
del cache[evict_ind]
def _get_sk(self, ind: int) -> bytes: def _get_sk(self, ind: int) -> bytes:
if ind < 0: if ind < 0:
msg = "The key index must be non-negative" msg = "The key index must be non-negative"
raise ValueError(msg) raise ValueError(msg)
# retrieve from cache # Check all caches for exact match
cached_sk = self._sk_cache.get(ind) for cache in self._sk_caches:
if cached_sk is not None: cached_sk = cache.get(ind)
return cached_sk if cached_sk is not None:
return cached_sk
# not in cache: find largest cached index smaller than ind (if exists) # Find best starting point across all tiers
start_ind: int = 0 start_ind, cur_sk = self._find_best_cached_sk(ind)
cur_sk: bytes = self._initial_sk
for cached_ind in self._sk_cache:
if cached_ind < ind and cached_ind > start_ind:
start_ind = cached_ind
cur_sk = self._sk_cache[cached_ind]
# compute and update cache # Compute from best cached position to target
for cur_ind in range(start_ind + 1, ind + 1): for cur_ind in range(start_ind + 1, ind + 1):
cur_sk = crypto.x963_kdf(cur_sk, b"update", 32) cur_sk = crypto.x963_kdf(cur_sk, b"update", 32)
self._update_caches(cur_ind, cur_sk)
# insert intermediate result into cache and evict oldest entry if necessary
if cur_ind % self._CACHE_INTERVAL == 0:
self._sk_cache[cur_ind] = cur_sk
if len(self._sk_cache) > self._CACHE_SIZE:
# evict oldest entry
oldest_ind = min(self._sk_cache.keys())
del self._sk_cache[oldest_ind]
return cur_sk return cur_sk