mlxrunner: Cache.Update takes ForwardBatch and returns KVHistory

Signature changes from Update(k, v) to Update(batch, k, v) returning
(k, v, KVHistory). KVCache returns a real page table mapping positions
to buffer slots. RecurrentCache returns empty KVHistory from Update.

Replace Cache.Offset() with Offsets() returning per-sequence offsets.
Add KVHistory type to mlx package.
This commit is contained in:
Jesse Gross
2026-04-02 12:05:35 -07:00
parent 987f74c8a5
commit b7b2aa5d4e
12 changed files with 109 additions and 69 deletions

View File

@@ -2,15 +2,16 @@ package cache
import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Cache interface {
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
Update(b *batch.ForwardBatch, keys, values *mlx.Array) (newKeys, newValues *mlx.Array, kv mlx.KVHistory)
// State returns the cache-owned state roots that should be kept/evaluated.
State() []*mlx.Array
Free()
Offset() int
Offsets() []int32
// Snapshot copies cache state from fromOffset to current offset into
// pinned VRAM arrays. The active cache is unchanged.
@@ -49,7 +50,7 @@ func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
func (c *KVCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
@@ -77,8 +78,17 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
pt := make([]int32, c.offset)
for i := range pt {
pt[i] = int32(i)
}
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
mlx.KVHistory{
PageTable: mlx.NewArrayInt32(pt, []int32{1, int32(c.offset)}),
SeqLens: []int{c.offset},
}
}
func (c *KVCache) State() []*mlx.Array {
@@ -143,7 +153,7 @@ func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
// Rewind to snapshot start, then feed snapshot data through Update.
c.offset = snap.fromOffset
c.Update(snap.keys, snap.values)
c.Update(nil, snap.keys, snap.values)
// Clamp to target if needed (target may be less than full snapshot).
if target < c.offset {
@@ -226,7 +236,7 @@ func (c *KVCache) Free() {
c.offset = 0
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Offsets() []int32 { return []int32{int32(c.offset)} }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
@@ -240,11 +250,24 @@ func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
}
func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
func (c *RotatingKVCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
var k, v *mlx.Array
if keys.Dim(2) > 1 {
return c.concat(keys, values)
k, v = c.concat(keys, values)
} else {
k, v = c.update(keys, values)
}
visibleLen := min(c.offset, c.maxSize)
pt := make([]int32, visibleLen)
for i := range visibleLen {
pt[i] = int32(i)
}
return k, v, mlx.KVHistory{
PageTable: mlx.NewArrayInt32(pt, []int32{1, int32(visibleLen)}),
SeqLens: []int{visibleLen},
}
return c.update(keys, values)
}
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {

View File

@@ -20,7 +20,7 @@ func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
// Snapshot [5, 10).
@@ -44,7 +44,7 @@ func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
snap := c.Snapshot(0)
@@ -57,8 +57,8 @@ func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
if !c2.Restore(snap, 10) {
t.Fatal("Restore failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c2.Offset())
if int(c2.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
}
// Verify State() returns arrays with correct sequence dimension.
@@ -84,7 +84,7 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
snap := c.Snapshot(0)
@@ -98,8 +98,8 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
if !c2.Restore(parent, 5) {
t.Fatal("Restore(parent) failed")
}
if c2.Offset() != 5 {
t.Fatalf("offset after parent = %d, want 5", c2.Offset())
if int(c2.Offsets()[0]) != 5 {
t.Fatalf("offset after parent = %d, want 5", int(c2.Offsets()[0]))
}
state := c2.State()
if state[0].Dim(2) != 5 {
@@ -110,8 +110,8 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
if !c2.Restore(child, 10) {
t.Fatal("Restore(child) failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset after child = %d, want 10", c2.Offset())
if int(c2.Offsets()[0]) != 10 {
t.Fatalf("offset after child = %d, want 10", int(c2.Offsets()[0]))
}
state = c2.State()
if state[0].Dim(2) != 10 {
@@ -128,7 +128,7 @@ func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
snap := c.Snapshot(0)
@@ -142,8 +142,8 @@ func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
if !c2.Restore(merged, 10) {
t.Fatal("Restore(merged) failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c2.Offset())
if int(c2.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
}
state := c2.State()
@@ -163,7 +163,7 @@ func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
// Offset 3 is outside the window.
@@ -182,7 +182,7 @@ func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
snap := c.Snapshot(0)
@@ -194,15 +194,15 @@ func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
for range 5 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
// Restore to offset 10.
if !c.Restore(snap, 10) {
t.Fatal("Restore failed")
}
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c.Offset())
if int(c.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
}
state := c.State()
@@ -228,10 +228,10 @@ func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
for range 6 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
if c.Offset() != 6 {
t.Fatalf("offset = %d, want 6", c.Offset())
if int(c.Offsets()[0]) != 6 {
t.Fatalf("offset = %d, want 6", int(c.Offsets()[0]))
}
snap := c.Snapshot(0)
@@ -240,25 +240,25 @@ func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
for range 3 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
}
// Restore to snapshot state.
if !c.Restore(snap, 6) {
t.Fatal("Restore failed")
}
if c.Offset() != 6 {
t.Fatalf("offset after restore = %d, want 6", c.Offset())
if int(c.Offsets()[0]) != 6 {
t.Fatalf("offset after restore = %d, want 6", int(c.Offsets()[0]))
}
// Feed one more token. If idx was restored correctly, this should
// produce a valid window of size 4 at offset 7.
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
c.Update(nil, k, v)
if c.Offset() != 7 {
t.Fatalf("offset after post-restore update = %d, want 7", c.Offset())
if int(c.Offsets()[0]) != 7 {
t.Fatalf("offset after post-restore update = %d, want 7", int(c.Offsets()[0]))
}
state := c.State()
if len(state) != 2 {

View File

@@ -1,6 +1,9 @@
package cache
import "github.com/ollama/ollama/x/mlxrunner/mlx"
import (
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// RecurrentCache stores state for linear-recurrent layers.
//
@@ -87,8 +90,8 @@ func (c *RecurrentCache) Advance(n int) {
c.offset += n
}
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return keys, values
func (c *RecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
return keys, values, mlx.KVHistory{}
}
func (c *RecurrentCache) State() []*mlx.Array {
@@ -162,4 +165,4 @@ func (c *RecurrentCache) Free() {
c.offset = 0
}
func (c *RecurrentCache) Offset() int { return c.offset }
func (c *RecurrentCache) Offsets() []int32 { return []int32{int32(c.offset)} }

View File

@@ -34,7 +34,7 @@ func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
if !c.Restore(snap, 10) {
t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
}
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c.Offset())
if int(c.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
}
}