mlxrunner: Cache.Update takes ForwardBatch and returns KVHistory

Signature changes from Update(k, v) to Update(batch, k, v) returning
(k, v, KVHistory). KVCache returns a real page table mapping positions
to buffer slots. RecurrentCache returns empty KVHistory from Update.

Replace Cache.Offset() with Offsets() returning per-sequence offsets.
Add KVHistory type to mlx package.
This commit is contained in:
Jesse Gross
2026-04-02 12:05:35 -07:00
parent 987f74c8a5
commit b7b2aa5d4e
12 changed files with 109 additions and 69 deletions

View File

@@ -20,7 +20,7 @@ func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
// Snapshot [5, 10).
@@ -44,7 +44,7 @@ func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
snap := c.Snapshot(0)
@@ -57,8 +57,8 @@ func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
if !c2.Restore(snap, 10) {
t.Fatal("Restore failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c2.Offset())
if int(c2.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
}
// Verify State() returns arrays with correct sequence dimension.
@@ -84,7 +84,7 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
snap := c.Snapshot(0)
@@ -98,8 +98,8 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
if !c2.Restore(parent, 5) {
t.Fatal("Restore(parent) failed")
}
if c2.Offset() != 5 {
t.Fatalf("offset after parent = %d, want 5", c2.Offset())
if int(c2.Offsets()[0]) != 5 {
t.Fatalf("offset after parent = %d, want 5", int(c2.Offsets()[0]))
}
state := c2.State()
if state[0].Dim(2) != 5 {
@@ -110,8 +110,8 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
if !c2.Restore(child, 10) {
t.Fatal("Restore(child) failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset after child = %d, want 10", c2.Offset())
if int(c2.Offsets()[0]) != 10 {
t.Fatalf("offset after child = %d, want 10", int(c2.Offsets()[0]))
}
state = c2.State()
if state[0].Dim(2) != 10 {
@@ -128,7 +128,7 @@ func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
snap := c.Snapshot(0)
@@ -142,8 +142,8 @@ func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
if !c2.Restore(merged, 10) {
t.Fatal("Restore(merged) failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c2.Offset())
if int(c2.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
}
state := c2.State()
@@ -163,7 +163,7 @@ func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
// Offset 3 is outside the window.
@@ -182,7 +182,7 @@ func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
snap := c.Snapshot(0)
@@ -194,15 +194,15 @@ func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
for range 5 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
// Restore to offset 10.
if !c.Restore(snap, 10) {
t.Fatal("Restore failed")
}
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c.Offset())
if int(c.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
}
state := c.State()
@@ -228,10 +228,10 @@ func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
for range 6 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
if c.Offset() != 6 {
t.Fatalf("offset = %d, want 6", c.Offset())
if int(c.Offsets()[0]) != 6 {
t.Fatalf("offset = %d, want 6", int(c.Offsets()[0]))
}
snap := c.Snapshot(0)
@@ -240,25 +240,25 @@ func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
for range 3 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
// Restore to snapshot state.
if !c.Restore(snap, 6) {
t.Fatal("Restore failed")
}
if c.Offset() != 6 {
t.Fatalf("offset after restore = %d, want 6", c.Offset())
if int(c.Offsets()[0]) != 6 {
t.Fatalf("offset after restore = %d, want 6", int(c.Offsets()[0]))
}
// Feed one more token. If idx was restored correctly, this should
// produce a valid window of size 4 at offset 7.
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
if c.Offset() != 7 {
t.Fatalf("offset after post-restore update = %d, want 7", c.Offset())
if int(c.Offsets()[0]) != 7 {
t.Fatalf("offset after post-restore update = %d, want 7", int(c.Offsets()[0]))
}
state := c.State()
if len(state) != 2 {