mirror of
https://github.com/malmeloo/FindMy.py.git
synced 2026-04-17 19:53:53 +02:00
Merge pull request #216 from malmeloo/feat/better-key-cache
fix: more efficient key caching system
This commit is contained in:
@@ -6,8 +6,10 @@ Accessories could be anything ranging from AirTags to iPhones.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, overload
|
||||
|
||||
@@ -390,13 +392,24 @@ class FindMyAccessory(RollingKeyPairSource, util.abc.Serializable[FindMyAccessor
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _CacheTier:
|
||||
"""Configuration for a cache tier."""
|
||||
|
||||
interval: int # Cache every n'th key
|
||||
max_size: int | None # Maximum number of keys to cache in this tier (None = unlimited)
|
||||
|
||||
|
||||
class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
||||
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
|
||||
|
||||
# cache enough keys for an entire week.
|
||||
# every interval'th key is cached.
|
||||
_CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
|
||||
_CACHE_INTERVAL = 1 # cache every key
|
||||
# Define cache tiers: (interval, max_size)
|
||||
# Tier 1: Cache every 4th key (1 hour), keep up to 672 keys (2 weeks at 15min intervals)
|
||||
# Tier 2: Cache every 672nd key (1 week), unlimited
|
||||
_CACHE_TIERS = (
|
||||
_CacheTier(interval=4, max_size=672),
|
||||
_CacheTier(interval=672, max_size=None),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -422,7 +435,9 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
||||
self._initial_sk = initial_sk
|
||||
self._key_type = key_type
|
||||
|
||||
self._sk_cache: dict[int, bytes] = {}
|
||||
# Multi-tier cache: dict + sorted indices per tier
|
||||
self._sk_caches: list[dict[int, bytes]] = [{} for _ in self._CACHE_TIERS]
|
||||
self._cache_indices: list[list[int]] = [[] for _ in self._CACHE_TIERS]
|
||||
|
||||
self._iter_ind = 0
|
||||
|
||||
@@ -441,36 +456,68 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
||||
"""The type of key this generator produces."""
|
||||
return self._key_type
|
||||
|
||||
def _find_best_cached_sk(self, ind: int) -> tuple[int, bytes]:
|
||||
"""Find the largest cached index smaller than ind across all tiers."""
|
||||
best_ind = 0
|
||||
best_sk = self._initial_sk
|
||||
|
||||
for indices, cache in zip(self._cache_indices, self._sk_caches, strict=True):
|
||||
if not indices:
|
||||
continue
|
||||
|
||||
# Use bisect to find the largest index < ind in O(log n)
|
||||
pos = bisect.bisect_left(indices, ind)
|
||||
if pos == 0: # No cached index less than ind
|
||||
continue
|
||||
|
||||
cached_ind = indices[pos - 1]
|
||||
if cached_ind > best_ind:
|
||||
best_ind = cached_ind
|
||||
best_sk = cache[cached_ind]
|
||||
|
||||
return best_ind, best_sk
|
||||
|
||||
def _update_caches(self, ind: int, sk: bytes) -> None:
|
||||
"""Update all applicable cache tiers with the computed key."""
|
||||
for tier_idx, tier in enumerate(self._CACHE_TIERS):
|
||||
if ind % tier.interval != 0:
|
||||
continue
|
||||
|
||||
cache = self._sk_caches[tier_idx]
|
||||
indices = self._cache_indices[tier_idx]
|
||||
|
||||
# Add to cache if not already present
|
||||
if ind in cache:
|
||||
continue
|
||||
cache[ind] = sk
|
||||
bisect.insort(indices, ind)
|
||||
|
||||
# Evict if cache exceeds size limit
|
||||
if tier.max_size is not None and len(cache) > tier.max_size:
|
||||
# If adding a historical key, evict smallest index
|
||||
# If adding a future key, evict largest
|
||||
evict_ind = indices.pop(0 if indices and ind > indices[0] else -1)
|
||||
|
||||
del cache[evict_ind]
|
||||
|
||||
def _get_sk(self, ind: int) -> bytes:
|
||||
if ind < 0:
|
||||
msg = "The key index must be non-negative"
|
||||
raise ValueError(msg)
|
||||
|
||||
# retrieve from cache
|
||||
cached_sk = self._sk_cache.get(ind)
|
||||
if cached_sk is not None:
|
||||
return cached_sk
|
||||
# Check all caches for exact match
|
||||
for cache in self._sk_caches:
|
||||
cached_sk = cache.get(ind)
|
||||
if cached_sk is not None:
|
||||
return cached_sk
|
||||
|
||||
# not in cache: find largest cached index smaller than ind (if exists)
|
||||
start_ind: int = 0
|
||||
cur_sk: bytes = self._initial_sk
|
||||
for cached_ind in self._sk_cache:
|
||||
if cached_ind < ind and cached_ind > start_ind:
|
||||
start_ind = cached_ind
|
||||
cur_sk = self._sk_cache[cached_ind]
|
||||
# Find best starting point across all tiers
|
||||
start_ind, cur_sk = self._find_best_cached_sk(ind)
|
||||
|
||||
# compute and update cache
|
||||
# Compute from best cached position to target
|
||||
for cur_ind in range(start_ind + 1, ind + 1):
|
||||
cur_sk = crypto.x963_kdf(cur_sk, b"update", 32)
|
||||
|
||||
# insert intermediate result into cache and evict oldest entry if necessary
|
||||
if cur_ind % self._CACHE_INTERVAL == 0:
|
||||
self._sk_cache[cur_ind] = cur_sk
|
||||
|
||||
if len(self._sk_cache) > self._CACHE_SIZE:
|
||||
# evict oldest entry
|
||||
oldest_ind = min(self._sk_cache.keys())
|
||||
del self._sk_cache[oldest_ind]
|
||||
self._update_caches(cur_ind, cur_sk)
|
||||
|
||||
return cur_sk
|
||||
|
||||
|
||||
Reference in New Issue
Block a user