diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 39f5c1f5a..a513b5717 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -254,8 +254,23 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV mlx.Pin(c.keys, c.values) } else { if c.idx < c.keys.Dim(2) { - c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())) - c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())) + if c.offset <= c.maxSize { + // Not yet wrapped: slots [c.idx, Dim) are grow padding + // or stale post-rewind data, not live window content. + c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())) + c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())) + } else { + // Wrapped: logical order is slots[idx..Dim) then slots[0..idx). + // Linearize so the trim + concat below operate on contiguous + // positions and preserve the last (maxSize - 1) old tokens. + tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice()) + tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice()) + headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()) + headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()) + c.keys.Set(tailK.Concatenate(2, headK)) + c.values.Set(tailV.Concatenate(2, headV)) + c.idx = c.keys.Dim(2) + } } // Trim to max_size to maintain sliding window diff --git a/x/mlxrunner/cache/rotating_multiturn_test.go b/x/mlxrunner/cache/rotating_multiturn_test.go new file mode 100644 index 000000000..9914e1e86 --- /dev/null +++ b/x/mlxrunner/cache/rotating_multiturn_test.go @@ -0,0 +1,338 @@ +package cache + +import ( + "testing" + + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +// singleTokenKV and multiTokenKV fabricate [B=1, H=1, L, D=2] key/value +// tensors whose channel value is the token id, so stateIDs can recover +// which ids survived in the cache. +func singleTokenKV(id float32) (*mlx.Array, *mlx.Array) { + k := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2) + v := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2) + return k, v +} + +func multiTokenKV(ids []float32) (*mlx.Array, *mlx.Array) { + data := make([]float32, 0, 2*len(ids)) + for _, id := range ids { + data = append(data, id, id) + } + k := mlx.FromValues(data, 1, 1, len(ids), 2) + v := mlx.FromValues(data, 1, 1, len(ids), 2) + return k, v +} + +// stateIDs returns the ids currently in the cache in slot order (logical +// after a concat, physical/rotated after a single-token update). +func stateIDs(t *testing.T, c *RotatingKVCache) []float32 { + t.Helper() + state := c.State() + if state == nil { + return nil + } + mlx.Eval(state[0]) + flat := state[0].Floats() + n := state[0].Dim(2) + out := make([]float32, n) + for i := range n { + out[i] = flat[i*2] + } + return out +} + +func equalSlice(a, b []float32) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func feedMulti(c *RotatingKVCache, startID float32, n int) float32 { + ids := make([]float32, n) + for i := range ids { + ids[i] = startID + float32(i) + } + k, v := multiTokenKV(ids) + c.Update(k, v) + return startID + float32(n) +} + +func feedSingle(c *RotatingKVCache, id float32) { + k, v := singleTokenKV(id) + c.Update(k, v) +} + +// TestRotatingKVCacheConcatMidRotationPreservesContext: after the buffer +// has wrapped, a multi-token concat must keep the (maxSize-1) most recent +// pre-existing tokens in logical order so the first Q of the new batch +// has a full sliding window. +func TestRotatingKVCacheConcatMidRotationPreservesContext(t *testing.T) { + skipIfNoMLX(t) + + const window = 4 + c := NewRotatingKVCache(window) + + nextID := feedMulti(c, 1, 3) + for range 6 { + feedSingle(c, nextID) + nextID++ + } + if c.Offset() != 9 { + t.Fatalf("setup: offset=%d want 9", c.Offset()) + } + if c.idx >= c.maxSize { + t.Fatalf("setup: expected mid-rotation idx (<%d), got %d", c.maxSize, c.idx) + } + + feedMulti(c, 10, 2) + got := stateIDs(t, c) + want := []float32{7, 8, 9, 10, 11} + if !equalSlice(got, want) { + t.Fatalf("post-concat window=%v want %v", got, want) + } + if c.Offset() != 11 { + t.Fatalf("offset=%d want 11", c.Offset()) + } +} + +// TestRotatingKVCacheConcatAlignedInvariant: with an aligned buffer +// (c.idx == Dim), an L>1 concat keeps the last (maxSize-1) pre-existing +// tokens plus the full new batch. This is the chunked-prefill contract +// x/mlxrunner/pipeline.go relies on. +func TestRotatingKVCacheConcatAlignedInvariant(t *testing.T) { + skipIfNoMLX(t) + + const window = 4 + c := NewRotatingKVCache(window) + + // Chunk 1 fills past maxSize, leaving Dim == maxSize aligned. + feedMulti(c, 1, 6) + // Chunk 2: the buffer is intentionally oversized to (maxSize-1) + L + // so the first new Q has its full window in scope for this forward. + feedMulti(c, 7, 3) + got := stateIDs(t, c) + want := []float32{4, 5, 6, 7, 8, 9} + if !equalSlice(got, want) { + t.Fatalf("post-chunk-2 buffer=%v want %v", got, want) + } + + // The next decode trims oversize back to maxSize; order may be + // physical (rotated), so check as a set. + feedSingle(c, 10) + got = stateIDs(t, c) + if len(got) != window { + t.Fatalf("post-decode Dim=%d want %d", len(got), window) + } + seen := map[float32]bool{} + for _, v := range got { + seen[v] = true + } + for _, w := range []float32{7, 8, 9, 10} { + if !seen[w] { + t.Fatalf("post-decode window missing %v (got %v)", w, got) + } + } +} + +// TestRotatingKVCacheConcatAfterDecodeGrowsBuffer: update() grows the +// underlying buffer by `step` slots via mlx.Zeros before writing, so +// after one decode on a short prefill c.idx < Dim even though the cache +// has not wrapped. Those trailing slots are zero padding and must not +// be pulled back into the live window on the next concat. +func TestRotatingKVCacheConcatAfterDecodeGrowsBuffer(t *testing.T) { + skipIfNoMLX(t) + + const window = 512 + c := NewRotatingKVCache(window) + + feedMulti(c, 1, 3) + feedSingle(c, 4) + feedMulti(c, 5, 3) + + got := stateIDs(t, c) + want := []float32{1, 2, 3, 4, 5, 6, 7} + if !equalSlice(got, want) { + t.Fatalf("growing-buffer concat=%v want %v", got, want) + } +} + +// TestRotatingKVCacheConcatAfterLiveRewind: x/mlxrunner/cache.go calls +// Restore(nil, target) between conversation turns to rewind the cache to +// the matched prefix. Restore moves c.offset/c.idx without trimming the +// underlying buffer, so slots [c.idx, Dim) still hold stale pre-rewind +// tokens. A subsequent concat must drop those, not treat them as wrapped +// window content. +func TestRotatingKVCacheConcatAfterLiveRewind(t *testing.T) { + skipIfNoMLX(t) + + const window = 8 + c := NewRotatingKVCache(window) + + // Grow the buffer to exactly maxSize without wrapping. + feedMulti(c, 1, 2) + for id := float32(3); id <= 8; id++ { + feedSingle(c, id) + } + if c.Offset() != window { + t.Fatalf("setup: offset=%d want %d", c.Offset(), window) + } + + if !c.Restore(nil, 2) { + t.Fatalf("live rewind to 2 failed") + } + if c.Offset() != 2 { + t.Fatalf("post-rewind offset=%d want 2", c.Offset()) + } + + feedMulti(c, 9, 3) + got := stateIDs(t, c) + want := []float32{1, 2, 9, 10, 11} + if !equalSlice(got, want) { + t.Fatalf("post-rewind concat=%v want %v", got, want) + } + if c.Offset() != 5 { + t.Fatalf("offset=%d want 5", c.Offset()) + } +} + +// TestRotatingKVCacheConcatGrowingBuffer: when oldLen < maxSize the trim +// formula drops to non-positive and all pre-existing tokens are kept. +func TestRotatingKVCacheConcatGrowingBuffer(t *testing.T) { + skipIfNoMLX(t) + + const window = 4 + c := NewRotatingKVCache(window) + + feedMulti(c, 1, 2) + feedMulti(c, 3, 2) + got := stateIDs(t, c) + want := []float32{1, 2, 3, 4} + if !equalSlice(got, want) { + t.Fatalf("growing buffer=%v want %v", got, want) + } +} + +// TestRotatingKVCacheRunnerChunkedPrefill mirrors the +// x/mlxrunner/pipeline.go prefill loop: a long prompt fed through +// repeated L>1 Update() calls on a single cache. Scaled-down proxy for +// the Gemma 4 26B case (sliding_window=1024, prefillChunkSize=2048). +func TestRotatingKVCacheRunnerChunkedPrefill(t *testing.T) { + skipIfNoMLX(t) + + const window = 4 + c := NewRotatingKVCache(window) + + feedMulti(c, 1, 8) + if c.Offset() != 8 { + t.Fatalf("chunk 1: offset=%d want 8", c.Offset()) + } + + feedMulti(c, 9, 8) + got := stateIDs(t, c) + want := []float32{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + if !equalSlice(got, want) { + t.Fatalf("chunk 2: buffer=%v want %v", got, want) + } + + feedMulti(c, 17, 4) + got = stateIDs(t, c) + want = []float32{14, 15, 16, 17, 18, 19, 20} + if !equalSlice(got, want) { + t.Fatalf("chunk 3: buffer=%v want %v", got, want) + } + + // Decode trims oversize back to maxSize; order may be physical. + feedSingle(c, 21) + got = stateIDs(t, c) + if len(got) != window { + t.Fatalf("post-decode Dim=%d want %d", len(got), window) + } + seen := map[float32]bool{} + for _, v := range got { + seen[v] = true + } + for _, w := range []float32{18, 19, 20, 21} { + if !seen[w] { + t.Fatalf("post-decode window missing %v (got %v)", w, got) + } + } +} + +// TestRotatingKVCacheMultiTurnChatSimulation walks a prefill → decode → +// prefill sequence and checks that each new prefill retains the last +// (maxSize-1) pre-existing tokens in logical order. +func TestRotatingKVCacheMultiTurnChatSimulation(t *testing.T) { + skipIfNoMLX(t) + + const window = 4 + c := NewRotatingKVCache(window) + + nextID := feedMulti(c, 1, 2) + for range 5 { + feedSingle(c, nextID) + nextID++ + } + if c.Offset() != 7 { + t.Fatalf("turn 1: offset=%d want 7", c.Offset()) + } + + feedMulti(c, nextID, 3) + nextID += 3 + got := stateIDs(t, c) + want := []float32{5, 6, 7, 8, 9, 10} + if !equalSlice(got, want) { + t.Fatalf("turn 2 prefill buffer=%v want %v", got, want) + } + + for range 4 { + feedSingle(c, nextID) + nextID++ + } + if c.Offset() != 14 { + t.Fatalf("turn 2 decode: offset=%d want 14", c.Offset()) + } + + feedMulti(c, nextID, 2) + got = stateIDs(t, c) + want = []float32{12, 13, 14, 15, 16} + if !equalSlice(got, want) { + t.Fatalf("turn 3 prefill buffer=%v want %v", got, want) + } +} + +// TestRotatingKVCacheOffsetTracking: Offset() is the monotonic logical +// token count through any mix of Update() calls — Gemma 4 uses +// donorEntry.Offset - L for the consumer's RoPE offset. +func TestRotatingKVCacheOffsetTracking(t *testing.T) { + skipIfNoMLX(t) + + c := NewRotatingKVCache(4) + nextID := feedMulti(c, 1, 3) + if c.Offset() != 3 { + t.Fatalf("after prefill 3: offset=%d want 3", c.Offset()) + } + for i := range 5 { + feedSingle(c, nextID) + nextID++ + if c.Offset() != 3+i+1 { + t.Fatalf("after decode %d: offset=%d want %d", i, c.Offset(), 3+i+1) + } + } + nextID = feedMulti(c, nextID, 2) + if c.Offset() != 10 { + t.Fatalf("after turn-2 prefill: offset=%d want 10", c.Offset()) + } + // L > maxSize concat. + feedMulti(c, nextID, 7) + if c.Offset() != 17 { + t.Fatalf("after large prefill: offset=%d want 17", c.Offset()) + } +}