feat: cache more intermediate accessory keys

should improve performance for accessories that were paired a long time ago
This commit is contained in:
Mike A.
2025-10-21 00:16:26 +02:00
parent 2f4b969577
commit 997cf8233c

View File

@@ -377,6 +377,11 @@ class FindMyAccessory(RollingKeyPairSource, util.abc.Serializable[FindMyAccessor
class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
# cache enough keys for an entire week.
# every interval'th key is cached.
_CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
_CACHE_INTERVAL = 10
def __init__(
self,
master_key: bytes,
@@ -401,8 +406,7 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
self._initial_sk = initial_sk
self._key_type = key_type
self._cur_sk = initial_sk
self._cur_sk_ind = 0
self._sk_cache: dict[int, bytes] = {}
self._iter_ind = 0
@@ -426,14 +430,33 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
msg = "The key index must be non-negative"
raise ValueError(msg)
if ind < self._cur_sk_ind: # behind us; need to reset :(
self._cur_sk = self._initial_sk
self._cur_sk_ind = 0
# retrieve from cache
cached_sk = self._sk_cache.get(ind)
if cached_sk is not None:
return cached_sk
for _ in range(self._cur_sk_ind, ind):
self._cur_sk = crypto.x963_kdf(self._cur_sk, b"update", 32)
self._cur_sk_ind += 1
return self._cur_sk
# not in cache: find largest cached index smaller than ind (if exists)
start_ind: int = 0
cur_sk: bytes = self._initial_sk
for cached_ind in self._sk_cache:
if cached_ind < ind and cached_ind > start_ind:
start_ind = cached_ind
cur_sk = self._sk_cache[cached_ind]
# compute and update cache
for cur_ind in range(start_ind, ind):
cur_sk = crypto.x963_kdf(cur_sk, b"update", 32)
# insert intermediate result into cache and evict oldest entry if necessary
if cur_ind % self._CACHE_INTERVAL == 0:
self._sk_cache[cur_ind] = cur_sk
if len(self._sk_cache) > self._CACHE_SIZE:
# evict oldest entry
oldest_ind = min(self._sk_cache.keys())
del self._sk_cache[oldest_ind]
return cur_sk
def _get_keypair(self, ind: int) -> KeyPair:
sk = self._get_sk(ind)
@@ -449,14 +472,14 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
@override
def __iter__(self) -> KeyGenerator:
self._iter_ind = -1
return self
@override
def __next__(self) -> KeyPair:
key = self._get_keypair(self._iter_ind)
self._iter_ind += 1
return self._get_keypair(self._iter_ind)
return key
@overload
def __getitem__(self, val: int) -> KeyPair: ...