mlxrunner: Cache.Update takes ForwardBatch and returns KVHistory

The signature changes from Update(k, v) to Update(batch, k, v), which returns
(k, v, KVHistory). KVCache returns a real page table mapping positions
to buffer slots. RecurrentCache returns an empty KVHistory from Update.

Replace Cache.Offset() with Cache.Offsets(), which returns per-sequence offsets.
Add the KVHistory type to the mlx package.
This commit is contained in:
Jesse Gross
2026-04-02 12:05:35 -07:00
parent 987f74c8a5
commit b7b2aa5d4e
12 changed files with 109 additions and 69 deletions

View File

@@ -5,6 +5,7 @@ import (
"testing"
"time"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -52,11 +53,11 @@ func (c *fakeRewindableCache) feed(tokens []int32) {
c.tokens = append(c.tokens, tokens...)
}
func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return nil, nil
func (c *fakeRewindableCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
return nil, nil, mlx.KVHistory{}
}
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
func (c *fakeRewindableCache) Offset() int { return len(c.tokens) }
func (c *fakeRewindableCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
func (c *fakeRewindableCache) Free() {
c.tokens = nil
@@ -172,11 +173,11 @@ func (c *fakeSlidingWindowCache) feed(tokens []int32) {
c.tokens = append(c.tokens, tokens...)
}
func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return nil, nil
func (c *fakeSlidingWindowCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
return nil, nil, mlx.KVHistory{}
}
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) }
func (c *fakeSlidingWindowCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
func (c *fakeSlidingWindowCache) Free() {
c.tokens = nil
@@ -252,11 +253,11 @@ func (c *fakeRecurrentCache) feed(tokens []int32) {
c.tokens = append(c.tokens, tokens...)
}
func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return nil, nil
func (c *fakeRecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
return nil, nil, mlx.KVHistory{}
}
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) }
func (c *fakeRecurrentCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
func (c *fakeRecurrentCache) Free() {
c.tokens = nil
@@ -366,9 +367,9 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32)
for i, c := range e.caches {
assertTokens(t, label, c, expected)
// Verify all caches report the same offset.
if i > 0 && c.Offset() != e.caches[0].Offset() {
if i > 0 && int(c.Offsets()[0]) != int(e.caches[0].Offsets()[0]) {
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
label, i, c.Offset(), e.caches[0].Offset())
label, i, int(c.Offsets()[0]), int(e.caches[0].Offsets()[0]))
}
}
}
@@ -451,9 +452,9 @@ func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
if len(kvc.caches) < 2 {
return
}
expected := kvc.caches[0].Offset()
expected := int(kvc.caches[0].Offsets()[0])
for i := 1; i < len(kvc.caches); i++ {
if got := kvc.caches[i].Offset(); got != expected {
if got := int(kvc.caches[i].Offsets()[0]); got != expected {
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected)
}
}