mirror of
https://github.com/malmeloo/FindMy.py.git
synced 2026-04-17 19:53:53 +02:00
Merge pull request #196 from malmeloo/fix/improve-keygen-performance
feat: cache more intermediate accessory keys
This commit is contained in:
@@ -377,6 +377,11 @@ class FindMyAccessory(RollingKeyPairSource, util.abc.Serializable[FindMyAccessor
|
||||
class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
||||
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
|
||||
|
||||
# cache enough keys for an entire week.
|
||||
# every interval'th key is cached.
|
||||
_CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
|
||||
_CACHE_INTERVAL = 10
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
master_key: bytes,
|
||||
@@ -401,8 +406,7 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
||||
self._initial_sk = initial_sk
|
||||
self._key_type = key_type
|
||||
|
||||
self._cur_sk = initial_sk
|
||||
self._cur_sk_ind = 0
|
||||
self._sk_cache: dict[int, bytes] = {}
|
||||
|
||||
self._iter_ind = 0
|
||||
|
||||
@@ -426,14 +430,33 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
||||
msg = "The key index must be non-negative"
|
||||
raise ValueError(msg)
|
||||
|
||||
if ind < self._cur_sk_ind: # behind us; need to reset :(
|
||||
self._cur_sk = self._initial_sk
|
||||
self._cur_sk_ind = 0
|
||||
# retrieve from cache
|
||||
cached_sk = self._sk_cache.get(ind)
|
||||
if cached_sk is not None:
|
||||
return cached_sk
|
||||
|
||||
for _ in range(self._cur_sk_ind, ind):
|
||||
self._cur_sk = crypto.x963_kdf(self._cur_sk, b"update", 32)
|
||||
self._cur_sk_ind += 1
|
||||
return self._cur_sk
|
||||
# not in cache: find largest cached index smaller than ind (if exists)
|
||||
start_ind: int = 0
|
||||
cur_sk: bytes = self._initial_sk
|
||||
for cached_ind in self._sk_cache:
|
||||
if cached_ind < ind and cached_ind > start_ind:
|
||||
start_ind = cached_ind
|
||||
cur_sk = self._sk_cache[cached_ind]
|
||||
|
||||
# compute and update cache
|
||||
for cur_ind in range(start_ind, ind):
|
||||
cur_sk = crypto.x963_kdf(cur_sk, b"update", 32)
|
||||
|
||||
# insert intermediate result into cache and evict oldest entry if necessary
|
||||
if cur_ind % self._CACHE_INTERVAL == 0:
|
||||
self._sk_cache[cur_ind] = cur_sk
|
||||
|
||||
if len(self._sk_cache) > self._CACHE_SIZE:
|
||||
# evict oldest entry
|
||||
oldest_ind = min(self._sk_cache.keys())
|
||||
del self._sk_cache[oldest_ind]
|
||||
|
||||
return cur_sk
|
||||
|
||||
def _get_keypair(self, ind: int) -> KeyPair:
|
||||
sk = self._get_sk(ind)
|
||||
@@ -449,14 +472,14 @@ class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
|
||||
|
||||
@override
|
||||
def __iter__(self) -> KeyGenerator:
|
||||
self._iter_ind = -1
|
||||
return self
|
||||
|
||||
@override
|
||||
def __next__(self) -> KeyPair:
|
||||
key = self._get_keypair(self._iter_ind)
|
||||
self._iter_ind += 1
|
||||
|
||||
return self._get_keypair(self._iter_ind)
|
||||
return key
|
||||
|
||||
@overload
|
||||
def __getitem__(self, val: int) -> KeyPair: ...
|
||||
|
||||
Reference in New Issue
Block a user