fix: more efficient key caching system

This commit is contained in:
Mike A.
2026-01-06 00:23:15 +01:00
parent 485b984b6b
commit 2afba5ed51

View File

@@ -6,8 +6,10 @@ Accessories could be anything ranging from AirTags to iPhones.
from __future__ import annotations
import bisect
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Literal, TypedDict, overload
@@ -390,13 +392,24 @@ class FindMyAccessory(RollingKeyPairSource, util.abc.Serializable[FindMyAccessor
)
@dataclass(frozen=True)
class _CacheTier:
"""Configuration for a cache tier."""
interval: int # Cache every n'th key
max_size: int | None # Maximum number of keys to cache in this tier (None = unlimited)
class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
# cache enough keys for an entire week.
# every interval'th key is cached.
_CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
_CACHE_INTERVAL = 1 # cache every key
# Define cache tiers: (interval, max_size)
# Tier 1: Cache every 4th key (1 hour), keep up to 672 keys (2 weeks at 15min intervals)
# Tier 2: Cache every 672nd key (1 week), unlimited
_CACHE_TIERS = (
_CacheTier(interval=4, max_size=672),
_CacheTier(interval=672, max_size=None),
)
def __init__(
self,
@@ -422,7 +435,9 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
self._initial_sk = initial_sk
self._key_type = key_type
self._sk_cache: dict[int, bytes] = {}
# Multi-tier cache: dict + sorted indices per tier
self._sk_caches: list[dict[int, bytes]] = [{} for _ in self._CACHE_TIERS]
self._cache_indices: list[list[int]] = [[] for _ in self._CACHE_TIERS]
self._iter_ind = 0
@@ -441,36 +456,66 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
"""The type of key this generator produces."""
return self._key_type
def _find_best_cached_sk(self, ind: int) -> tuple[int, bytes]:
"""Find the largest cached index smaller than ind across all tiers."""
best_ind = 0
best_sk = self._initial_sk
for indices, cache in zip(self._cache_indices, self._sk_caches, strict=True):
if not indices:
continue
# Use bisect to find the largest index < ind in O(log n)
pos = bisect.bisect_left(indices, ind)
if pos > 0:
cached_ind = indices[pos - 1]
if cached_ind > best_ind:
best_ind = cached_ind
best_sk = cache[cached_ind]
return best_ind, best_sk
def _update_caches(self, ind: int, sk: bytes) -> None:
"""Update all applicable cache tiers with the computed key."""
for tier_idx, tier in enumerate(self._CACHE_TIERS):
if ind % tier.interval == 0:
cache = self._sk_caches[tier_idx]
indices = self._cache_indices[tier_idx]
# Add to cache if not already present
if ind not in cache:
cache[ind] = sk
bisect.insort(indices, ind)
# Evict if cache exceeds size limit
if tier.max_size is not None and len(cache) > tier.max_size:
# If adding a historical key, evict smallest (oldest)
# If adding a future key, evict largest (newest old key)
if indices and ind > indices[0]:
evict_ind = indices.pop(0)
else:
evict_ind = indices.pop(-1)
del cache[evict_ind]
def _get_sk(self, ind: int) -> bytes:
if ind < 0:
msg = "The key index must be non-negative"
raise ValueError(msg)
# retrieve from cache
cached_sk = self._sk_cache.get(ind)
if cached_sk is not None:
return cached_sk
# Check all caches for exact match
for cache in self._sk_caches:
cached_sk = cache.get(ind)
if cached_sk is not None:
return cached_sk
# not in cache: find largest cached index smaller than ind (if exists)
start_ind: int = 0
cur_sk: bytes = self._initial_sk
for cached_ind in self._sk_cache:
if cached_ind < ind and cached_ind > start_ind:
start_ind = cached_ind
cur_sk = self._sk_cache[cached_ind]
# Find best starting point across all tiers
start_ind, cur_sk = self._find_best_cached_sk(ind)
# compute and update cache
# Compute from best cached position to target
for cur_ind in range(start_ind + 1, ind + 1):
cur_sk = crypto.x963_kdf(cur_sk, b"update", 32)
# insert intermediate result into cache and evict oldest entry if necessary
if cur_ind % self._CACHE_INTERVAL == 0:
self._sk_cache[cur_ind] = cur_sk
if len(self._sk_cache) > self._CACHE_SIZE:
# evict oldest entry
oldest_ind = min(self._sk_cache.keys())
del self._sk_cache[oldest_ind]
self._update_caches(cur_ind, cur_sk)
return cur_sk