mlxrunner: Cache.Update takes ForwardBatch and returns KVHistory

Signature changes from Update(k, v) to Update(batch, k, v) returning
(k, v, KVHistory). KVCache returns a real page table mapping positions
to buffer slots. RecurrentCache returns empty KVHistory from Update.

Replace Cache.Offset() with Offsets() returning per-sequence offsets.
Add KVHistory type to mlx package.
This commit is contained in:
Jesse Gross
2026-04-02 12:05:35 -07:00
parent 987f74c8a5
commit b7b2aa5d4e
12 changed files with 109 additions and 69 deletions

View File

@@ -205,7 +205,7 @@ pageIn:
if j >= len(node.snapshots) || node.snapshots[j] == nil {
continue
}
if kv.Offset() >= nodeTarget {
if int(kv.Offsets()[0]) >= nodeTarget {
continue
}
if !kv.Restore(node.snapshots[j], nodeTarget) {
@@ -224,7 +224,7 @@ pageIn:
c.activePath = newPath
minOff := c.minCacheOffset()
for _, kv := range c.caches {
if kv != nil && kv.Offset() != minOff {
if kv != nil && int(kv.Offsets()[0]) != minOff {
if !kv.Restore(nil, minOff) {
slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
c.freeAll()
@@ -390,8 +390,8 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
snaps := make([]cache.Snapshot, len(c.caches))
for i, kv := range c.caches {
if kv != nil {
if kv.Offset() != cacheOffset {
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset()))
if int(kv.Offsets()[0]) != cacheOffset {
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, int(kv.Offsets()[0])))
}
snaps[i] = kv.Snapshot(node.startOffset())
}
@@ -418,7 +418,7 @@ func (c *kvCache) minCacheOffset() int {
if kv == nil {
continue
}
if off := kv.Offset(); !found || off < offset {
if off := int(kv.Offsets()[0]); !found || off < offset {
offset = off
found = true
}