mirror of
https://github.com/ollama/ollama.git
synced 2026-04-24 09:46:01 +02:00
Signature changes from Update(k, v) to Update(batch, k, v) returning (k, v, KVHistory). KVCache returns a real page table mapping positions to buffer slots. RecurrentCache returns empty KVHistory from Update. Replace Cache.Offset() with Offsets() returning per-sequence offsets. Add KVHistory type to mlx package.
272 lines
6.9 KiB
Go
272 lines
6.9 KiB
Go
package cache
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
func skipIfNoMLX(t *testing.T) {
|
|
t.Helper()
|
|
if err := mlx.CheckInit(); err != nil {
|
|
t.Skipf("MLX not available: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
c := NewKVCache()
|
|
|
|
for range 10 {
|
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
c.Update(nil, k, v)
|
|
}
|
|
|
|
// Snapshot [5, 10).
|
|
snap := c.Snapshot(5)
|
|
|
|
// Free the cache completely — offset is now 0.
|
|
c.Free()
|
|
|
|
// Restore should fail because cache doesn't have data up to fromOffset=5.
|
|
if c.Restore(snap, 10) {
|
|
t.Fatal("expected Restore to fail with no base data")
|
|
}
|
|
}
|
|
|
|
// TestKVCacheDataSurvivesSnapshotRestore verifies that actual array data
|
|
// is preserved through a snapshot→free→restore cycle.
|
|
func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
c := NewKVCache()
|
|
|
|
for range 10 {
|
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
c.Update(nil, k, v)
|
|
}
|
|
|
|
snap := c.Snapshot(0)
|
|
if snap == nil {
|
|
t.Fatal("Snapshot returned nil")
|
|
}
|
|
|
|
// Free and restore to a fresh cache.
|
|
c2 := NewKVCache()
|
|
if !c2.Restore(snap, 10) {
|
|
t.Fatal("Restore failed")
|
|
}
|
|
if int(c2.Offsets()[0]) != 10 {
|
|
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
|
|
}
|
|
|
|
// Verify State() returns arrays with correct sequence dimension.
|
|
state := c2.State()
|
|
if len(state) != 2 {
|
|
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
|
}
|
|
// keys shape: [B, H, seqLen, Dk]
|
|
if state[0].Dim(2) != 10 {
|
|
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
|
|
}
|
|
if state[1].Dim(2) != 10 {
|
|
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
|
|
}
|
|
}
|
|
|
|
// TestKVCacheSplitPreservesData verifies that split produces two snapshots
|
|
// that can be sequentially restored to rebuild the original cache state.
|
|
func TestKVCacheSplitPreservesData(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
c := NewKVCache()
|
|
|
|
for range 10 {
|
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
c.Update(nil, k, v)
|
|
}
|
|
|
|
snap := c.Snapshot(0)
|
|
parent, child := c.Split(snap, 5)
|
|
if parent == nil || child == nil {
|
|
t.Fatal("Split returned nil")
|
|
}
|
|
|
|
// Restore parent → offset=5, seq dim=5.
|
|
c2 := NewKVCache()
|
|
if !c2.Restore(parent, 5) {
|
|
t.Fatal("Restore(parent) failed")
|
|
}
|
|
if int(c2.Offsets()[0]) != 5 {
|
|
t.Fatalf("offset after parent = %d, want 5", int(c2.Offsets()[0]))
|
|
}
|
|
state := c2.State()
|
|
if state[0].Dim(2) != 5 {
|
|
t.Fatalf("keys seq dim after parent = %d, want 5", state[0].Dim(2))
|
|
}
|
|
|
|
// Restore child on top → offset=10, seq dim=10.
|
|
if !c2.Restore(child, 10) {
|
|
t.Fatal("Restore(child) failed")
|
|
}
|
|
if int(c2.Offsets()[0]) != 10 {
|
|
t.Fatalf("offset after child = %d, want 10", int(c2.Offsets()[0]))
|
|
}
|
|
state = c2.State()
|
|
if state[0].Dim(2) != 10 {
|
|
t.Fatalf("keys seq dim after child = %d, want 10", state[0].Dim(2))
|
|
}
|
|
}
|
|
|
|
// TestKVCacheSplitMergeRoundTripData verifies that splitting and merging back
|
|
// produces a snapshot equivalent to the original.
|
|
func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
c := NewKVCache()
|
|
|
|
for range 10 {
|
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
c.Update(nil, k, v)
|
|
}
|
|
|
|
snap := c.Snapshot(0)
|
|
parent, child := c.Split(snap, 6)
|
|
merged := c.Merge(parent, child)
|
|
if merged == nil {
|
|
t.Fatal("Merge returned nil")
|
|
}
|
|
|
|
c2 := NewKVCache()
|
|
if !c2.Restore(merged, 10) {
|
|
t.Fatal("Restore(merged) failed")
|
|
}
|
|
if int(c2.Offsets()[0]) != 10 {
|
|
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
|
|
}
|
|
|
|
state := c2.State()
|
|
if state[0].Dim(2) != 10 {
|
|
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
|
|
}
|
|
if state[1].Dim(2) != 10 {
|
|
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
|
|
}
|
|
}
|
|
|
|
func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
c := NewRotatingKVCache(4)
|
|
|
|
// Feed 10 tokens (window size 4, so positions 0-5 are evicted).
|
|
for range 10 {
|
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
c.Update(nil, k, v)
|
|
}
|
|
|
|
// Offset 3 is outside the window.
|
|
if c.Restore(nil, 3) {
|
|
t.Fatal("Restore(nil, 3) should fail when outside window")
|
|
}
|
|
}
|
|
|
|
// TestRotatingKVCacheSnapshotPreservesWindow verifies that after restoring
|
|
// from a snapshot, the rotating cache has the correct window of data.
|
|
func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
c := NewRotatingKVCache(4)
|
|
|
|
// Feed 10 tokens one at a time. Window size 4, so only last 4 are kept.
|
|
for range 10 {
|
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
c.Update(nil, k, v)
|
|
}
|
|
|
|
snap := c.Snapshot(0)
|
|
if snap == nil {
|
|
t.Fatal("Snapshot returned nil")
|
|
}
|
|
|
|
// Feed 5 more tokens.
|
|
for range 5 {
|
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
c.Update(nil, k, v)
|
|
}
|
|
|
|
// Restore to offset 10.
|
|
if !c.Restore(snap, 10) {
|
|
t.Fatal("Restore failed")
|
|
}
|
|
if int(c.Offsets()[0]) != 10 {
|
|
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
|
|
}
|
|
|
|
state := c.State()
|
|
if len(state) != 2 {
|
|
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
|
}
|
|
// Seq dim should be min(offset, maxSize) = min(10, 4) = 4.
|
|
seqDim := state[0].Dim(2)
|
|
if seqDim != 4 {
|
|
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
|
|
}
|
|
}
|
|
|
|
// TestRotatingKVCacheRestoreFromSnapshot verifies that restoring from a
|
|
// snapshot correctly preserves the write position (idx), so subsequent
|
|
// single-token updates land in the right buffer slot.
|
|
func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
c := NewRotatingKVCache(4)
|
|
|
|
// Fill the window: 6 tokens into a size-4 window.
|
|
// After this, idx has wrapped and the buffer has rotated.
|
|
for range 6 {
|
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
c.Update(nil, k, v)
|
|
}
|
|
if int(c.Offsets()[0]) != 6 {
|
|
t.Fatalf("offset = %d, want 6", int(c.Offsets()[0]))
|
|
}
|
|
|
|
snap := c.Snapshot(0)
|
|
|
|
// Mutate the cache further so live state diverges from snapshot.
|
|
for range 3 {
|
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
c.Update(nil, k, v)
|
|
}
|
|
|
|
// Restore to snapshot state.
|
|
if !c.Restore(snap, 6) {
|
|
t.Fatal("Restore failed")
|
|
}
|
|
if int(c.Offsets()[0]) != 6 {
|
|
t.Fatalf("offset after restore = %d, want 6", int(c.Offsets()[0]))
|
|
}
|
|
|
|
// Feed one more token. If idx was restored correctly, this should
|
|
// produce a valid window of size 4 at offset 7.
|
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
|
c.Update(nil, k, v)
|
|
|
|
if int(c.Offsets()[0]) != 7 {
|
|
t.Fatalf("offset after post-restore update = %d, want 7", int(c.Offsets()[0]))
|
|
}
|
|
state := c.State()
|
|
if len(state) != 2 {
|
|
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
|
}
|
|
seqDim := state[0].Dim(2)
|
|
if seqDim != 4 {
|
|
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
|
|
}
|
|
}
|