mirror of
https://github.com/malmeloo/FindMy.py.git
synced 2026-04-17 23:53:57 +02:00
Merge pull request #216 from malmeloo/feat/better-key-cache
fix: more efficient key caching system
This commit is contained in:
@@ -6,8 +6,10 @@ Accessories could be anything ranging from AirTags to iPhones.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import bisect
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict, overload
|
from typing import TYPE_CHECKING, Literal, TypedDict, overload
|
||||||
|
|
||||||
@@ -390,13 +392,24 @@ class FindMyAccessory(RollingKeyPairSource, util.abc.Serializable[FindMyAccessor
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _CacheTier:
|
||||||
|
"""Configuration for a cache tier."""
|
||||||
|
|
||||||
|
interval: int # Cache every n'th key
|
||||||
|
max_size: int | None # Maximum number of keys to cache in this tier (None = unlimited)
|
||||||
|
|
||||||
|
|
||||||
class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
||||||
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
|
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
|
||||||
|
|
||||||
# cache enough keys for an entire week.
|
# Define cache tiers: (interval, max_size)
|
||||||
# every interval'th key is cached.
|
# Tier 1: Cache every 4th key (1 hour), keep up to 672 keys (2 weeks at 15min intervals)
|
||||||
_CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
|
# Tier 2: Cache every 672nd key (1 week), unlimited
|
||||||
_CACHE_INTERVAL = 1 # cache every key
|
_CACHE_TIERS = (
|
||||||
|
_CacheTier(interval=4, max_size=672),
|
||||||
|
_CacheTier(interval=672, max_size=None),
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -422,7 +435,9 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
|||||||
self._initial_sk = initial_sk
|
self._initial_sk = initial_sk
|
||||||
self._key_type = key_type
|
self._key_type = key_type
|
||||||
|
|
||||||
self._sk_cache: dict[int, bytes] = {}
|
# Multi-tier cache: dict + sorted indices per tier
|
||||||
|
self._sk_caches: list[dict[int, bytes]] = [{} for _ in self._CACHE_TIERS]
|
||||||
|
self._cache_indices: list[list[int]] = [[] for _ in self._CACHE_TIERS]
|
||||||
|
|
||||||
self._iter_ind = 0
|
self._iter_ind = 0
|
||||||
|
|
||||||
@@ -441,36 +456,68 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
|||||||
"""The type of key this generator produces."""
|
"""The type of key this generator produces."""
|
||||||
return self._key_type
|
return self._key_type
|
||||||
|
|
||||||
|
def _find_best_cached_sk(self, ind: int) -> tuple[int, bytes]:
|
||||||
|
"""Find the largest cached index smaller than ind across all tiers."""
|
||||||
|
best_ind = 0
|
||||||
|
best_sk = self._initial_sk
|
||||||
|
|
||||||
|
for indices, cache in zip(self._cache_indices, self._sk_caches, strict=True):
|
||||||
|
if not indices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Use bisect to find the largest index < ind in O(log n)
|
||||||
|
pos = bisect.bisect_left(indices, ind)
|
||||||
|
if pos == 0: # No cached index less than ind
|
||||||
|
continue
|
||||||
|
|
||||||
|
cached_ind = indices[pos - 1]
|
||||||
|
if cached_ind > best_ind:
|
||||||
|
best_ind = cached_ind
|
||||||
|
best_sk = cache[cached_ind]
|
||||||
|
|
||||||
|
return best_ind, best_sk
|
||||||
|
|
||||||
|
def _update_caches(self, ind: int, sk: bytes) -> None:
|
||||||
|
"""Update all applicable cache tiers with the computed key."""
|
||||||
|
for tier_idx, tier in enumerate(self._CACHE_TIERS):
|
||||||
|
if ind % tier.interval != 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
cache = self._sk_caches[tier_idx]
|
||||||
|
indices = self._cache_indices[tier_idx]
|
||||||
|
|
||||||
|
# Add to cache if not already present
|
||||||
|
if ind in cache:
|
||||||
|
continue
|
||||||
|
cache[ind] = sk
|
||||||
|
bisect.insort(indices, ind)
|
||||||
|
|
||||||
|
# Evict if cache exceeds size limit
|
||||||
|
if tier.max_size is not None and len(cache) > tier.max_size:
|
||||||
|
# If adding a historical key, evict smallest index
|
||||||
|
# If adding a future key, evict largest
|
||||||
|
evict_ind = indices.pop(0 if indices and ind > indices[0] else -1)
|
||||||
|
|
||||||
|
del cache[evict_ind]
|
||||||
|
|
||||||
def _get_sk(self, ind: int) -> bytes:
    """Derive the ``ind``'th secret key, reusing cached intermediates.

    Raises ValueError for a negative index. On a cache miss, ratchets the
    KDF forward from the nearest cached predecessor, caching every
    intermediate key along the way.
    """
    if ind < 0:
        msg = "The key index must be non-negative"
        raise ValueError(msg)

    # An exact hit in any tier short-circuits the derivation entirely.
    for tier_cache in self._sk_caches:
        hit = tier_cache.get(ind)
        if hit is not None:
            return hit

    # Otherwise start from the best cached predecessor (or the initial key).
    step, sk = self._find_best_cached_sk(ind)
    while step < ind:
        step += 1
        sk = crypto.x963_kdf(sk, b"update", 32)
        self._update_caches(step, sk)

    return sk
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user