mirror of
https://github.com/ollama/ollama.git
synced 2026-04-24 01:35:49 +02:00
mlxrunner: Cache.Update takes ForwardBatch and returns KVHistory
The Cache.Update signature changes from Update(k, v) to Update(batch, k, v), which now returns (k, v, KVHistory). KVCache returns a real page table mapping positions to buffer slots, while RecurrentCache returns an empty KVHistory from Update. Cache.Offset() is replaced by Offsets(), which returns per-sequence offsets. A KVHistory type is added to the mlx package.
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
@@ -52,11 +53,11 @@ func (c *fakeRewindableCache) feed(tokens []int32) {
|
||||
c.tokens = append(c.tokens, tokens...)
|
||||
}
|
||||
|
||||
func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return nil, nil
|
||||
func (c *fakeRewindableCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
return nil, nil, mlx.KVHistory{}
|
||||
}
|
||||
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeRewindableCache) Offset() int { return len(c.tokens) }
|
||||
func (c *fakeRewindableCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
|
||||
|
||||
func (c *fakeRewindableCache) Free() {
|
||||
c.tokens = nil
|
||||
@@ -172,11 +173,11 @@ func (c *fakeSlidingWindowCache) feed(tokens []int32) {
|
||||
c.tokens = append(c.tokens, tokens...)
|
||||
}
|
||||
|
||||
func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return nil, nil
|
||||
func (c *fakeSlidingWindowCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
return nil, nil, mlx.KVHistory{}
|
||||
}
|
||||
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) }
|
||||
func (c *fakeSlidingWindowCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
|
||||
|
||||
func (c *fakeSlidingWindowCache) Free() {
|
||||
c.tokens = nil
|
||||
@@ -252,11 +253,11 @@ func (c *fakeRecurrentCache) feed(tokens []int32) {
|
||||
c.tokens = append(c.tokens, tokens...)
|
||||
}
|
||||
|
||||
func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return nil, nil
|
||||
func (c *fakeRecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
return nil, nil, mlx.KVHistory{}
|
||||
}
|
||||
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) }
|
||||
func (c *fakeRecurrentCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
|
||||
|
||||
func (c *fakeRecurrentCache) Free() {
|
||||
c.tokens = nil
|
||||
@@ -366,9 +367,9 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32)
|
||||
for i, c := range e.caches {
|
||||
assertTokens(t, label, c, expected)
|
||||
// Verify all caches report the same offset.
|
||||
if i > 0 && c.Offset() != e.caches[0].Offset() {
|
||||
if i > 0 && int(c.Offsets()[0]) != int(e.caches[0].Offsets()[0]) {
|
||||
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
|
||||
label, i, c.Offset(), e.caches[0].Offset())
|
||||
label, i, int(c.Offsets()[0]), int(e.caches[0].Offsets()[0]))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -451,9 +452,9 @@ func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
|
||||
if len(kvc.caches) < 2 {
|
||||
return
|
||||
}
|
||||
expected := kvc.caches[0].Offset()
|
||||
expected := int(kvc.caches[0].Offsets()[0])
|
||||
for i := 1; i < len(kvc.caches); i++ {
|
||||
if got := kvc.caches[i].Offset(); got != expected {
|
||||
if got := int(kvc.caches[i].Offsets()[0]); got != expected {
|
||||
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user