mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
mlx: fix RotatingKVCache.concat() dropping context on mid-rotation (#15591)
After the rotating buffer has wrapped (c.offset > c.maxSize) a subsequent L>1 Update() went through a slice-to-[0, c.idx) path that discarded all slots in [c.idx, Dim), losing the older-but-still-in-window tokens the first Q of the new batch needs for its sliding-window attention. Linearize the circular buffer to logical order in that wrapped case so the existing trim + concat preserves the last (maxSize - 1) old tokens. When the buffer has not yet wrapped (c.offset <= c.maxSize), slots [c.idx, Dim) are grow padding or stale post-rewind data, so keep dropping them.
This commit is contained in:
19
x/mlxrunner/cache/cache.go
vendored
19
x/mlxrunner/cache/cache.go
vendored
@@ -254,8 +254,23 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV
|
||||
mlx.Pin(c.keys, c.values)
|
||||
} else {
|
||||
if c.idx < c.keys.Dim(2) {
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
if c.offset <= c.maxSize {
|
||||
// Not yet wrapped: slots [c.idx, Dim) are grow padding
|
||||
// or stale post-rewind data, not live window content.
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
} else {
|
||||
// Wrapped: logical order is slots[idx..Dim) then slots[0..idx).
|
||||
// Linearize so the trim + concat below operate on contiguous
|
||||
// positions and preserve the last (maxSize - 1) old tokens.
|
||||
tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice())
|
||||
tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice())
|
||||
headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||
headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||
c.keys.Set(tailK.Concatenate(2, headK))
|
||||
c.values.Set(tailV.Concatenate(2, headV))
|
||||
c.idx = c.keys.Dim(2)
|
||||
}
|
||||
}
|
||||
|
||||
// Trim to max_size to maintain sliding window
|
||||
|
||||
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
@@ -0,0 +1,338 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// singleTokenKV and multiTokenKV fabricate [B=1, H=1, L, D=2] key/value
|
||||
// tensors whose channel value is the token id, so stateIDs can recover
|
||||
// which ids survived in the cache.
|
||||
func singleTokenKV(id float32) (*mlx.Array, *mlx.Array) {
|
||||
k := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||
v := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||
return k, v
|
||||
}
|
||||
|
||||
func multiTokenKV(ids []float32) (*mlx.Array, *mlx.Array) {
|
||||
data := make([]float32, 0, 2*len(ids))
|
||||
for _, id := range ids {
|
||||
data = append(data, id, id)
|
||||
}
|
||||
k := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||
v := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||
return k, v
|
||||
}
|
||||
|
||||
// stateIDs returns the ids currently in the cache in slot order (logical
|
||||
// after a concat, physical/rotated after a single-token update).
|
||||
func stateIDs(t *testing.T, c *RotatingKVCache) []float32 {
|
||||
t.Helper()
|
||||
state := c.State()
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
mlx.Eval(state[0])
|
||||
flat := state[0].Floats()
|
||||
n := state[0].Dim(2)
|
||||
out := make([]float32, n)
|
||||
for i := range n {
|
||||
out[i] = flat[i*2]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func equalSlice(a, b []float32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func feedMulti(c *RotatingKVCache, startID float32, n int) float32 {
|
||||
ids := make([]float32, n)
|
||||
for i := range ids {
|
||||
ids[i] = startID + float32(i)
|
||||
}
|
||||
k, v := multiTokenKV(ids)
|
||||
c.Update(k, v)
|
||||
return startID + float32(n)
|
||||
}
|
||||
|
||||
func feedSingle(c *RotatingKVCache, id float32) {
|
||||
k, v := singleTokenKV(id)
|
||||
c.Update(k, v)
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatMidRotationPreservesContext: after the buffer
|
||||
// has wrapped, a multi-token concat must keep the (maxSize-1) most recent
|
||||
// pre-existing tokens in logical order so the first Q of the new batch
|
||||
// has a full sliding window.
|
||||
func TestRotatingKVCacheConcatMidRotationPreservesContext(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
nextID := feedMulti(c, 1, 3)
|
||||
for range 6 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
}
|
||||
if c.Offset() != 9 {
|
||||
t.Fatalf("setup: offset=%d want 9", c.Offset())
|
||||
}
|
||||
if c.idx >= c.maxSize {
|
||||
t.Fatalf("setup: expected mid-rotation idx (<%d), got %d", c.maxSize, c.idx)
|
||||
}
|
||||
|
||||
feedMulti(c, 10, 2)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{7, 8, 9, 10, 11}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("post-concat window=%v want %v", got, want)
|
||||
}
|
||||
if c.Offset() != 11 {
|
||||
t.Fatalf("offset=%d want 11", c.Offset())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatAlignedInvariant: with an aligned buffer
|
||||
// (c.idx == Dim), an L>1 concat keeps the last (maxSize-1) pre-existing
|
||||
// tokens plus the full new batch. This is the chunked-prefill contract
|
||||
// x/mlxrunner/pipeline.go relies on.
|
||||
func TestRotatingKVCacheConcatAlignedInvariant(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
// Chunk 1 fills past maxSize, leaving Dim == maxSize aligned.
|
||||
feedMulti(c, 1, 6)
|
||||
// Chunk 2: the buffer is intentionally oversized to (maxSize-1) + L
|
||||
// so the first new Q has its full window in scope for this forward.
|
||||
feedMulti(c, 7, 3)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{4, 5, 6, 7, 8, 9}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("post-chunk-2 buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
// The next decode trims oversize back to maxSize; order may be
|
||||
// physical (rotated), so check as a set.
|
||||
feedSingle(c, 10)
|
||||
got = stateIDs(t, c)
|
||||
if len(got) != window {
|
||||
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||
}
|
||||
seen := map[float32]bool{}
|
||||
for _, v := range got {
|
||||
seen[v] = true
|
||||
}
|
||||
for _, w := range []float32{7, 8, 9, 10} {
|
||||
if !seen[w] {
|
||||
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatAfterDecodeGrowsBuffer: update() grows the
|
||||
// underlying buffer by `step` slots via mlx.Zeros before writing, so
|
||||
// after one decode on a short prefill c.idx < Dim even though the cache
|
||||
// has not wrapped. Those trailing slots are zero padding and must not
|
||||
// be pulled back into the live window on the next concat.
|
||||
func TestRotatingKVCacheConcatAfterDecodeGrowsBuffer(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 512
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
feedMulti(c, 1, 3)
|
||||
feedSingle(c, 4)
|
||||
feedMulti(c, 5, 3)
|
||||
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{1, 2, 3, 4, 5, 6, 7}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("growing-buffer concat=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatAfterLiveRewind: x/mlxrunner/cache.go calls
|
||||
// Restore(nil, target) between conversation turns to rewind the cache to
|
||||
// the matched prefix. Restore moves c.offset/c.idx without trimming the
|
||||
// underlying buffer, so slots [c.idx, Dim) still hold stale pre-rewind
|
||||
// tokens. A subsequent concat must drop those, not treat them as wrapped
|
||||
// window content.
|
||||
func TestRotatingKVCacheConcatAfterLiveRewind(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 8
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
// Grow the buffer to exactly maxSize without wrapping.
|
||||
feedMulti(c, 1, 2)
|
||||
for id := float32(3); id <= 8; id++ {
|
||||
feedSingle(c, id)
|
||||
}
|
||||
if c.Offset() != window {
|
||||
t.Fatalf("setup: offset=%d want %d", c.Offset(), window)
|
||||
}
|
||||
|
||||
if !c.Restore(nil, 2) {
|
||||
t.Fatalf("live rewind to 2 failed")
|
||||
}
|
||||
if c.Offset() != 2 {
|
||||
t.Fatalf("post-rewind offset=%d want 2", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, 9, 3)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{1, 2, 9, 10, 11}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("post-rewind concat=%v want %v", got, want)
|
||||
}
|
||||
if c.Offset() != 5 {
|
||||
t.Fatalf("offset=%d want 5", c.Offset())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatGrowingBuffer: when oldLen < maxSize the trim
|
||||
// formula drops to non-positive and all pre-existing tokens are kept.
|
||||
func TestRotatingKVCacheConcatGrowingBuffer(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
feedMulti(c, 1, 2)
|
||||
feedMulti(c, 3, 2)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{1, 2, 3, 4}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("growing buffer=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheRunnerChunkedPrefill mirrors the
|
||||
// x/mlxrunner/pipeline.go prefill loop: a long prompt fed through
|
||||
// repeated L>1 Update() calls on a single cache. Scaled-down proxy for
|
||||
// the Gemma 4 26B case (sliding_window=1024, prefillChunkSize=2048).
|
||||
func TestRotatingKVCacheRunnerChunkedPrefill(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
feedMulti(c, 1, 8)
|
||||
if c.Offset() != 8 {
|
||||
t.Fatalf("chunk 1: offset=%d want 8", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, 9, 8)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("chunk 2: buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
feedMulti(c, 17, 4)
|
||||
got = stateIDs(t, c)
|
||||
want = []float32{14, 15, 16, 17, 18, 19, 20}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("chunk 3: buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
// Decode trims oversize back to maxSize; order may be physical.
|
||||
feedSingle(c, 21)
|
||||
got = stateIDs(t, c)
|
||||
if len(got) != window {
|
||||
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||
}
|
||||
seen := map[float32]bool{}
|
||||
for _, v := range got {
|
||||
seen[v] = true
|
||||
}
|
||||
for _, w := range []float32{18, 19, 20, 21} {
|
||||
if !seen[w] {
|
||||
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheMultiTurnChatSimulation walks a prefill → decode →
|
||||
// prefill sequence and checks that each new prefill retains the last
|
||||
// (maxSize-1) pre-existing tokens in logical order.
|
||||
func TestRotatingKVCacheMultiTurnChatSimulation(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
nextID := feedMulti(c, 1, 2)
|
||||
for range 5 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
}
|
||||
if c.Offset() != 7 {
|
||||
t.Fatalf("turn 1: offset=%d want 7", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, nextID, 3)
|
||||
nextID += 3
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{5, 6, 7, 8, 9, 10}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("turn 2 prefill buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
for range 4 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
}
|
||||
if c.Offset() != 14 {
|
||||
t.Fatalf("turn 2 decode: offset=%d want 14", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, nextID, 2)
|
||||
got = stateIDs(t, c)
|
||||
want = []float32{12, 13, 14, 15, 16}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("turn 3 prefill buffer=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheOffsetTracking: Offset() is the monotonic logical
|
||||
// token count through any mix of Update() calls — Gemma 4 uses
|
||||
// donorEntry.Offset - L for the consumer's RoPE offset.
|
||||
func TestRotatingKVCacheOffsetTracking(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
c := NewRotatingKVCache(4)
|
||||
nextID := feedMulti(c, 1, 3)
|
||||
if c.Offset() != 3 {
|
||||
t.Fatalf("after prefill 3: offset=%d want 3", c.Offset())
|
||||
}
|
||||
for i := range 5 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
if c.Offset() != 3+i+1 {
|
||||
t.Fatalf("after decode %d: offset=%d want %d", i, c.Offset(), 3+i+1)
|
||||
}
|
||||
}
|
||||
nextID = feedMulti(c, nextID, 2)
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("after turn-2 prefill: offset=%d want 10", c.Offset())
|
||||
}
|
||||
// L > maxSize concat.
|
||||
feedMulti(c, nextID, 7)
|
||||
if c.Offset() != 17 {
|
||||
t.Fatalf("after large prefill: offset=%d want 17", c.Offset())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user