Mirror of https://github.com/ollama/ollama.git (synced 2026-04-24 09:46:01 +02:00)
mlxrunner: Cache.Update takes ForwardBatch and returns KVHistory
The signature changes from Update(k, v) to Update(batch, k, v), which now returns (k, v, KVHistory). KVCache returns a real page table mapping positions to buffer slots, while RecurrentCache returns an empty KVHistory from Update. Cache.Offset() is replaced with Offsets(), which returns per-sequence offsets. A KVHistory type is added to the mlx package.
This commit is contained in:
52
x/mlxrunner/cache/cache_test.go
vendored
52
x/mlxrunner/cache/cache_test.go
vendored
@@ -20,7 +20,7 @@ func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// Snapshot [5, 10).
|
||||
@@ -44,7 +44,7 @@ func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
@@ -57,8 +57,8 @@ func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
|
||||
if !c2.Restore(snap, 10) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c2.Offset())
|
||||
if int(c2.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
|
||||
}
|
||||
|
||||
// Verify State() returns arrays with correct sequence dimension.
|
||||
@@ -84,7 +84,7 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
@@ -98,8 +98,8 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
|
||||
if !c2.Restore(parent, 5) {
|
||||
t.Fatal("Restore(parent) failed")
|
||||
}
|
||||
if c2.Offset() != 5 {
|
||||
t.Fatalf("offset after parent = %d, want 5", c2.Offset())
|
||||
if int(c2.Offsets()[0]) != 5 {
|
||||
t.Fatalf("offset after parent = %d, want 5", int(c2.Offsets()[0]))
|
||||
}
|
||||
state := c2.State()
|
||||
if state[0].Dim(2) != 5 {
|
||||
@@ -110,8 +110,8 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
|
||||
if !c2.Restore(child, 10) {
|
||||
t.Fatal("Restore(child) failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset after child = %d, want 10", c2.Offset())
|
||||
if int(c2.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset after child = %d, want 10", int(c2.Offsets()[0]))
|
||||
}
|
||||
state = c2.State()
|
||||
if state[0].Dim(2) != 10 {
|
||||
@@ -128,7 +128,7 @@ func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
@@ -142,8 +142,8 @@ func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
|
||||
if !c2.Restore(merged, 10) {
|
||||
t.Fatal("Restore(merged) failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c2.Offset())
|
||||
if int(c2.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
|
||||
}
|
||||
|
||||
state := c2.State()
|
||||
@@ -163,7 +163,7 @@ func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// Offset 3 is outside the window.
|
||||
@@ -182,7 +182,7 @@ func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
@@ -194,15 +194,15 @@ func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
|
||||
for range 5 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// Restore to offset 10.
|
||||
if !c.Restore(snap, 10) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||
if int(c.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
|
||||
}
|
||||
|
||||
state := c.State()
|
||||
@@ -228,10 +228,10 @@ func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
|
||||
for range 6 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
if c.Offset() != 6 {
|
||||
t.Fatalf("offset = %d, want 6", c.Offset())
|
||||
if int(c.Offsets()[0]) != 6 {
|
||||
t.Fatalf("offset = %d, want 6", int(c.Offsets()[0]))
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
@@ -240,25 +240,25 @@ func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
|
||||
for range 3 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// Restore to snapshot state.
|
||||
if !c.Restore(snap, 6) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c.Offset() != 6 {
|
||||
t.Fatalf("offset after restore = %d, want 6", c.Offset())
|
||||
if int(c.Offsets()[0]) != 6 {
|
||||
t.Fatalf("offset after restore = %d, want 6", int(c.Offsets()[0]))
|
||||
}
|
||||
|
||||
// Feed one more token. If idx was restored correctly, this should
|
||||
// produce a valid window of size 4 at offset 7.
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
|
||||
if c.Offset() != 7 {
|
||||
t.Fatalf("offset after post-restore update = %d, want 7", c.Offset())
|
||||
if int(c.Offsets()[0]) != 7 {
|
||||
t.Fatalf("offset after post-restore update = %d, want 7", int(c.Offsets()[0]))
|
||||
}
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
|
||||
Reference in New Issue
Block a user