mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
mlx: fix RotatingKVCache.concat() dropping context on mid-rotation (#15591)
After the rotating buffer has wrapped (c.offset > c.maxSize) a subsequent L>1 Update() went through a slice-to-[0, c.idx) path that discarded all slots in [c.idx, Dim), losing the older-but-still-in-window tokens the first Q of the new batch needs for its sliding-window attention. Linearize the circular buffer to logical order in that wrapped case so the existing trim + concat preserves the last (maxSize - 1) old tokens. When the buffer has not yet wrapped (c.offset <= c.maxSize), slots [c.idx, Dim) are grow padding or stale post-rewind data, so keep dropping them.
This commit is contained in:
19
x/mlxrunner/cache/cache.go
vendored
19
x/mlxrunner/cache/cache.go
vendored
@@ -254,8 +254,23 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV
|
|||||||
mlx.Pin(c.keys, c.values)
|
mlx.Pin(c.keys, c.values)
|
||||||
} else {
|
} else {
|
||||||
if c.idx < c.keys.Dim(2) {
|
if c.idx < c.keys.Dim(2) {
|
||||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
if c.offset <= c.maxSize {
|
||||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
// Not yet wrapped: slots [c.idx, Dim) are grow padding
|
||||||
|
// or stale post-rewind data, not live window content.
|
||||||
|
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||||
|
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||||
|
} else {
|
||||||
|
// Wrapped: logical order is slots[idx..Dim) then slots[0..idx).
|
||||||
|
// Linearize so the trim + concat below operate on contiguous
|
||||||
|
// positions and preserve the last (maxSize - 1) old tokens.
|
||||||
|
tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice())
|
||||||
|
tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice())
|
||||||
|
headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||||
|
headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||||
|
c.keys.Set(tailK.Concatenate(2, headK))
|
||||||
|
c.values.Set(tailV.Concatenate(2, headV))
|
||||||
|
c.idx = c.keys.Dim(2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trim to max_size to maintain sliding window
|
// Trim to max_size to maintain sliding window
|
||||||
|
|||||||
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// singleTokenKV and multiTokenKV fabricate [B=1, H=1, L, D=2] key/value
|
||||||
|
// tensors whose channel value is the token id, so stateIDs can recover
|
||||||
|
// which ids survived in the cache.
|
||||||
|
func singleTokenKV(id float32) (*mlx.Array, *mlx.Array) {
|
||||||
|
k := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||||
|
v := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||||
|
return k, v
|
||||||
|
}
|
||||||
|
|
||||||
|
func multiTokenKV(ids []float32) (*mlx.Array, *mlx.Array) {
|
||||||
|
data := make([]float32, 0, 2*len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
data = append(data, id, id)
|
||||||
|
}
|
||||||
|
k := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||||
|
v := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||||
|
return k, v
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateIDs returns the ids currently in the cache in slot order (logical
|
||||||
|
// after a concat, physical/rotated after a single-token update).
|
||||||
|
func stateIDs(t *testing.T, c *RotatingKVCache) []float32 {
|
||||||
|
t.Helper()
|
||||||
|
state := c.State()
|
||||||
|
if state == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mlx.Eval(state[0])
|
||||||
|
flat := state[0].Floats()
|
||||||
|
n := state[0].Dim(2)
|
||||||
|
out := make([]float32, n)
|
||||||
|
for i := range n {
|
||||||
|
out[i] = flat[i*2]
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// equalSlice reports whether a and b have the same length and every pair
// of elements at the same index compares equal with ==.
func equalSlice(a, b []float32) bool {
	same := len(a) == len(b)
	// The loop stops at the first mismatch because same turns false.
	for i := 0; same && i < len(a); i++ {
		same = a[i] == b[i]
	}
	return same
}
|
||||||
|
|
||||||
|
func feedMulti(c *RotatingKVCache, startID float32, n int) float32 {
|
||||||
|
ids := make([]float32, n)
|
||||||
|
for i := range ids {
|
||||||
|
ids[i] = startID + float32(i)
|
||||||
|
}
|
||||||
|
k, v := multiTokenKV(ids)
|
||||||
|
c.Update(k, v)
|
||||||
|
return startID + float32(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func feedSingle(c *RotatingKVCache, id float32) {
|
||||||
|
k, v := singleTokenKV(id)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatMidRotationPreservesContext: after the buffer
|
||||||
|
// has wrapped, a multi-token concat must keep the (maxSize-1) most recent
|
||||||
|
// pre-existing tokens in logical order so the first Q of the new batch
|
||||||
|
// has a full sliding window.
|
||||||
|
func TestRotatingKVCacheConcatMidRotationPreservesContext(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
nextID := feedMulti(c, 1, 3)
|
||||||
|
for range 6 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
}
|
||||||
|
if c.Offset() != 9 {
|
||||||
|
t.Fatalf("setup: offset=%d want 9", c.Offset())
|
||||||
|
}
|
||||||
|
if c.idx >= c.maxSize {
|
||||||
|
t.Fatalf("setup: expected mid-rotation idx (<%d), got %d", c.maxSize, c.idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 10, 2)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{7, 8, 9, 10, 11}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("post-concat window=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
if c.Offset() != 11 {
|
||||||
|
t.Fatalf("offset=%d want 11", c.Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatAlignedInvariant: with an aligned buffer
|
||||||
|
// (c.idx == Dim), an L>1 concat keeps the last (maxSize-1) pre-existing
|
||||||
|
// tokens plus the full new batch. This is the chunked-prefill contract
|
||||||
|
// x/mlxrunner/pipeline.go relies on.
|
||||||
|
func TestRotatingKVCacheConcatAlignedInvariant(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
// Chunk 1 fills past maxSize, leaving Dim == maxSize aligned.
|
||||||
|
feedMulti(c, 1, 6)
|
||||||
|
// Chunk 2: the buffer is intentionally oversized to (maxSize-1) + L
|
||||||
|
// so the first new Q has its full window in scope for this forward.
|
||||||
|
feedMulti(c, 7, 3)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{4, 5, 6, 7, 8, 9}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("post-chunk-2 buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The next decode trims oversize back to maxSize; order may be
|
||||||
|
// physical (rotated), so check as a set.
|
||||||
|
feedSingle(c, 10)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
if len(got) != window {
|
||||||
|
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||||
|
}
|
||||||
|
seen := map[float32]bool{}
|
||||||
|
for _, v := range got {
|
||||||
|
seen[v] = true
|
||||||
|
}
|
||||||
|
for _, w := range []float32{7, 8, 9, 10} {
|
||||||
|
if !seen[w] {
|
||||||
|
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatAfterDecodeGrowsBuffer: update() grows the
|
||||||
|
// underlying buffer by `step` slots via mlx.Zeros before writing, so
|
||||||
|
// after one decode on a short prefill c.idx < Dim even though the cache
|
||||||
|
// has not wrapped. Those trailing slots are zero padding and must not
|
||||||
|
// be pulled back into the live window on the next concat.
|
||||||
|
func TestRotatingKVCacheConcatAfterDecodeGrowsBuffer(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 512
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
feedMulti(c, 1, 3)
|
||||||
|
feedSingle(c, 4)
|
||||||
|
feedMulti(c, 5, 3)
|
||||||
|
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{1, 2, 3, 4, 5, 6, 7}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("growing-buffer concat=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatAfterLiveRewind: x/mlxrunner/cache.go calls
|
||||||
|
// Restore(nil, target) between conversation turns to rewind the cache to
|
||||||
|
// the matched prefix. Restore moves c.offset/c.idx without trimming the
|
||||||
|
// underlying buffer, so slots [c.idx, Dim) still hold stale pre-rewind
|
||||||
|
// tokens. A subsequent concat must drop those, not treat them as wrapped
|
||||||
|
// window content.
|
||||||
|
func TestRotatingKVCacheConcatAfterLiveRewind(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 8
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
// Grow the buffer to exactly maxSize without wrapping.
|
||||||
|
feedMulti(c, 1, 2)
|
||||||
|
for id := float32(3); id <= 8; id++ {
|
||||||
|
feedSingle(c, id)
|
||||||
|
}
|
||||||
|
if c.Offset() != window {
|
||||||
|
t.Fatalf("setup: offset=%d want %d", c.Offset(), window)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.Restore(nil, 2) {
|
||||||
|
t.Fatalf("live rewind to 2 failed")
|
||||||
|
}
|
||||||
|
if c.Offset() != 2 {
|
||||||
|
t.Fatalf("post-rewind offset=%d want 2", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 9, 3)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{1, 2, 9, 10, 11}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("post-rewind concat=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
if c.Offset() != 5 {
|
||||||
|
t.Fatalf("offset=%d want 5", c.Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatGrowingBuffer: when oldLen < maxSize the trim
|
||||||
|
// formula drops to non-positive and all pre-existing tokens are kept.
|
||||||
|
func TestRotatingKVCacheConcatGrowingBuffer(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
feedMulti(c, 1, 2)
|
||||||
|
feedMulti(c, 3, 2)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{1, 2, 3, 4}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("growing buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheRunnerChunkedPrefill mirrors the
|
||||||
|
// x/mlxrunner/pipeline.go prefill loop: a long prompt fed through
|
||||||
|
// repeated L>1 Update() calls on a single cache. Scaled-down proxy for
|
||||||
|
// the Gemma 4 26B case (sliding_window=1024, prefillChunkSize=2048).
|
||||||
|
func TestRotatingKVCacheRunnerChunkedPrefill(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
feedMulti(c, 1, 8)
|
||||||
|
if c.Offset() != 8 {
|
||||||
|
t.Fatalf("chunk 1: offset=%d want 8", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 9, 8)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("chunk 2: buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 17, 4)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
want = []float32{14, 15, 16, 17, 18, 19, 20}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("chunk 3: buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode trims oversize back to maxSize; order may be physical.
|
||||||
|
feedSingle(c, 21)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
if len(got) != window {
|
||||||
|
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||||
|
}
|
||||||
|
seen := map[float32]bool{}
|
||||||
|
for _, v := range got {
|
||||||
|
seen[v] = true
|
||||||
|
}
|
||||||
|
for _, w := range []float32{18, 19, 20, 21} {
|
||||||
|
if !seen[w] {
|
||||||
|
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheMultiTurnChatSimulation walks a prefill → decode →
|
||||||
|
// prefill sequence and checks that each new prefill retains the last
|
||||||
|
// (maxSize-1) pre-existing tokens in logical order.
|
||||||
|
func TestRotatingKVCacheMultiTurnChatSimulation(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
nextID := feedMulti(c, 1, 2)
|
||||||
|
for range 5 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
}
|
||||||
|
if c.Offset() != 7 {
|
||||||
|
t.Fatalf("turn 1: offset=%d want 7", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, nextID, 3)
|
||||||
|
nextID += 3
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{5, 6, 7, 8, 9, 10}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("turn 2 prefill buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range 4 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
}
|
||||||
|
if c.Offset() != 14 {
|
||||||
|
t.Fatalf("turn 2 decode: offset=%d want 14", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, nextID, 2)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
want = []float32{12, 13, 14, 15, 16}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("turn 3 prefill buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheOffsetTracking: Offset() is the monotonic logical
|
||||||
|
// token count through any mix of Update() calls — Gemma 4 uses
|
||||||
|
// donorEntry.Offset - L for the consumer's RoPE offset.
|
||||||
|
func TestRotatingKVCacheOffsetTracking(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
c := NewRotatingKVCache(4)
|
||||||
|
nextID := feedMulti(c, 1, 3)
|
||||||
|
if c.Offset() != 3 {
|
||||||
|
t.Fatalf("after prefill 3: offset=%d want 3", c.Offset())
|
||||||
|
}
|
||||||
|
for i := range 5 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
if c.Offset() != 3+i+1 {
|
||||||
|
t.Fatalf("after decode %d: offset=%d want %d", i, c.Offset(), 3+i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nextID = feedMulti(c, nextID, 2)
|
||||||
|
if c.Offset() != 10 {
|
||||||
|
t.Fatalf("after turn-2 prefill: offset=%d want 10", c.Offset())
|
||||||
|
}
|
||||||
|
// L > maxSize concat.
|
||||||
|
feedMulti(c, nextID, 7)
|
||||||
|
if c.Offset() != 17 {
|
||||||
|
t.Fatalf("after large prefill: offset=%d want 17", c.Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user