mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
mlxrunner: multi-sequence KVCache, RotatingKVCache, and RecurrentCache
KVCache: AddSeq/RemoveSeq register sequences with contiguous buffer regions separated by gaps. Update scatters new K/V via PutAlongAxis. Rebuild reallocates on sequence add/remove. SnapshotSeq extracts one sequence's K/V. Single-sequence mode (no AddSeq) is unchanged.

RotatingKVCache: multi-sequence mode delegates to KVCache regions, with windowKVHistory limiting visibility to the trailing maxSize tokens per sequence. Single-sequence mode retains the existing ring buffer.

RecurrentCache: per-sequence conv/delta state is kept in the seqStates map. ConvState/DeltaState gather states for the batch; SetConvState/SetDeltaState scatter them back. AddSeqWithSnapshot restores recurrent state from snapshots.
This commit is contained in:
@@ -67,11 +67,17 @@ func (c *kvCache) ensureCaches(m base.Model) {
|
||||
}
|
||||
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
c.caches = cacheFactory.NewCaches()
|
||||
return
|
||||
} else {
|
||||
c.caches = make([]cache.Cache, m.NumLayers())
|
||||
for i := range c.caches {
|
||||
c.caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
c.caches = make([]cache.Cache, m.NumLayers())
|
||||
for i := range c.caches {
|
||||
c.caches[i] = cache.NewKVCache()
|
||||
// Register the default sequence for single-sequence prefill.
|
||||
for _, kv := range c.caches {
|
||||
if kv != nil {
|
||||
kv.SetSeqs([]int{0})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,7 +173,7 @@ func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
snaps[j] = kv.Snapshot(fromOffset)
|
||||
snaps[j] = kv.Snapshot(0, fromOffset)
|
||||
}
|
||||
node.setSnapshots(snaps, &c.pagedOutBytes)
|
||||
pageOutCount++
|
||||
@@ -184,7 +190,7 @@ func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if !kv.Restore(nil, rewindTarget) {
|
||||
if !kv.Restore(0, nil, rewindTarget) {
|
||||
kv.Free()
|
||||
}
|
||||
}
|
||||
@@ -205,10 +211,10 @@ pageIn:
|
||||
if j >= len(node.snapshots) || node.snapshots[j] == nil {
|
||||
continue
|
||||
}
|
||||
if int(kv.Offsets()[0]) >= nodeTarget {
|
||||
if int(kv.Offsets(0)[0]) >= nodeTarget {
|
||||
continue
|
||||
}
|
||||
if !kv.Restore(node.snapshots[j], nodeTarget) {
|
||||
if !kv.Restore(0, node.snapshots[j], nodeTarget) {
|
||||
// Restore failed — stop page-in and let alignment
|
||||
// bring all caches to a consistent offset.
|
||||
break pageIn
|
||||
@@ -224,8 +230,8 @@ pageIn:
|
||||
c.activePath = newPath
|
||||
minOff := c.minCacheOffset()
|
||||
for _, kv := range c.caches {
|
||||
if kv != nil && int(kv.Offsets()[0]) != minOff {
|
||||
if !kv.Restore(nil, minOff) {
|
||||
if kv != nil && int(kv.Offsets(0)[0]) != minOff {
|
||||
if !kv.Restore(0, nil, minOff) {
|
||||
slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
|
||||
c.freeAll()
|
||||
break
|
||||
@@ -390,10 +396,10 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
|
||||
snaps := make([]cache.Snapshot, len(c.caches))
|
||||
for i, kv := range c.caches {
|
||||
if kv != nil {
|
||||
if int(kv.Offsets()[0]) != cacheOffset {
|
||||
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, int(kv.Offsets()[0])))
|
||||
if int(kv.Offsets(0)[0]) != cacheOffset {
|
||||
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, int(kv.Offsets(0)[0])))
|
||||
}
|
||||
snaps[i] = kv.Snapshot(node.startOffset())
|
||||
snaps[i] = kv.Snapshot(0, node.startOffset())
|
||||
}
|
||||
}
|
||||
node.setSnapshots(snaps, &c.pagedOutBytes)
|
||||
@@ -418,7 +424,7 @@ func (c *kvCache) minCacheOffset() int {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if off := int(kv.Offsets()[0]); !found || off < offset {
|
||||
if off := int(kv.Offsets(0)[0]); !found || off < offset {
|
||||
offset = off
|
||||
found = true
|
||||
}
|
||||
|
||||
912
x/mlxrunner/cache/cache.go
vendored
912
x/mlxrunner/cache/cache.go
vendored
File diff suppressed because it is too large
Load Diff
379
x/mlxrunner/cache/cache_test.go
vendored
379
x/mlxrunner/cache/cache_test.go
vendored
@@ -3,6 +3,7 @@ package cache
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
@@ -13,259 +14,351 @@ func skipIfNoMLX(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var singleTokenBatch = &batch.ForwardBatch{SeqIDs: []int{0}, SeqLens: []int{1}}
|
||||
|
||||
func newKVCacheWithSeq() *KVCache {
|
||||
c := NewKVCache()
|
||||
c.SetSeqs([]int{0})
|
||||
return c
|
||||
}
|
||||
|
||||
func newRotatingKVCacheWithSeq(maxSize int) *RotatingKVCache {
|
||||
c := NewRotatingKVCache(maxSize)
|
||||
c.SetSeqs([]int{0})
|
||||
return c
|
||||
}
|
||||
|
||||
func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
c := newKVCacheWithSeq()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(nil, k, v)
|
||||
c.Update(singleTokenBatch, k, v)
|
||||
}
|
||||
|
||||
// Snapshot [5, 10).
|
||||
snap := c.Snapshot(5)
|
||||
snap := c.Snapshot(0, 5)
|
||||
|
||||
// Free the cache completely — offset is now 0.
|
||||
c.Free()
|
||||
|
||||
// Restore should fail because cache doesn't have data up to fromOffset=5.
|
||||
if c.Restore(snap, 10) {
|
||||
if c.Restore(0, snap, 10) {
|
||||
t.Fatal("expected Restore to fail with no base data")
|
||||
}
|
||||
}
|
||||
|
||||
// TestKVCacheDataSurvivesSnapshotRestore verifies that actual array data
|
||||
// is preserved through a snapshot→free→restore cycle.
|
||||
func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
c := newKVCacheWithSeq()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(nil, k, v)
|
||||
c.Update(singleTokenBatch, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
snap := c.Snapshot(0, 0)
|
||||
if snap == nil {
|
||||
t.Fatal("Snapshot returned nil")
|
||||
}
|
||||
|
||||
// Free and restore to a fresh cache.
|
||||
c2 := NewKVCache()
|
||||
if !c2.Restore(snap, 10) {
|
||||
c2 := newKVCacheWithSeq()
|
||||
if !c2.Restore(0, snap, 10) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if int(c2.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
|
||||
if int(c2.Offsets(0)[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c2.Offsets(0)[0]))
|
||||
}
|
||||
|
||||
// Verify State() returns arrays with correct sequence dimension.
|
||||
state := c2.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||
}
|
||||
// keys shape: [B, H, seqLen, Dk]
|
||||
if state[0].Dim(2) != 10 {
|
||||
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
|
||||
}
|
||||
if state[1].Dim(2) != 10 {
|
||||
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
|
||||
if state[0].Dim(2) < 10 {
|
||||
t.Fatalf("keys seq dim = %d, want >= 10", state[0].Dim(2))
|
||||
}
|
||||
}
|
||||
|
||||
// TestKVCacheSplitPreservesData verifies that split produces two snapshots
|
||||
// that can be sequentially restored to rebuild the original cache state.
|
||||
func TestKVCacheSplitPreservesData(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
c := newKVCacheWithSeq()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(nil, k, v)
|
||||
c.Update(singleTokenBatch, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
snap := c.Snapshot(0, 0)
|
||||
parent, child := c.Split(snap, 5)
|
||||
if parent == nil || child == nil {
|
||||
t.Fatal("Split returned nil")
|
||||
}
|
||||
|
||||
// Restore parent → offset=5, seq dim=5.
|
||||
c2 := NewKVCache()
|
||||
if !c2.Restore(parent, 5) {
|
||||
c2 := newKVCacheWithSeq()
|
||||
if !c2.Restore(0, parent, 5) {
|
||||
t.Fatal("Restore(parent) failed")
|
||||
}
|
||||
if int(c2.Offsets()[0]) != 5 {
|
||||
t.Fatalf("offset after parent = %d, want 5", int(c2.Offsets()[0]))
|
||||
}
|
||||
state := c2.State()
|
||||
if state[0].Dim(2) != 5 {
|
||||
t.Fatalf("keys seq dim after parent = %d, want 5", state[0].Dim(2))
|
||||
if int(c2.Offsets(0)[0]) != 5 {
|
||||
t.Fatalf("offset after parent = %d, want 5", int(c2.Offsets(0)[0]))
|
||||
}
|
||||
|
||||
// Restore child on top → offset=10, seq dim=10.
|
||||
if !c2.Restore(child, 10) {
|
||||
if !c2.Restore(0, child, 10) {
|
||||
t.Fatal("Restore(child) failed")
|
||||
}
|
||||
if int(c2.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset after child = %d, want 10", int(c2.Offsets()[0]))
|
||||
}
|
||||
state = c2.State()
|
||||
if state[0].Dim(2) != 10 {
|
||||
t.Fatalf("keys seq dim after child = %d, want 10", state[0].Dim(2))
|
||||
if int(c2.Offsets(0)[0]) != 10 {
|
||||
t.Fatalf("offset after child = %d, want 10", int(c2.Offsets(0)[0]))
|
||||
}
|
||||
}
|
||||
|
||||
// TestKVCacheSplitMergeRoundTripData verifies that splitting and merging back
|
||||
// produces a snapshot equivalent to the original.
|
||||
func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
c := newKVCacheWithSeq()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(nil, k, v)
|
||||
c.Update(singleTokenBatch, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
snap := c.Snapshot(0, 0)
|
||||
parent, child := c.Split(snap, 6)
|
||||
merged := c.Merge(parent, child)
|
||||
if merged == nil {
|
||||
t.Fatal("Merge returned nil")
|
||||
}
|
||||
|
||||
c2 := NewKVCache()
|
||||
if !c2.Restore(merged, 10) {
|
||||
c2 := newKVCacheWithSeq()
|
||||
if !c2.Restore(0, merged, 10) {
|
||||
t.Fatal("Restore(merged) failed")
|
||||
}
|
||||
if int(c2.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
|
||||
}
|
||||
|
||||
state := c2.State()
|
||||
if state[0].Dim(2) != 10 {
|
||||
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
|
||||
}
|
||||
if state[1].Dim(2) != 10 {
|
||||
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
|
||||
if int(c2.Offsets(0)[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c2.Offsets(0)[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
|
||||
func TestRotatingKVCacheRewindOutsideWindow(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
c := newRotatingKVCacheWithSeq(4)
|
||||
|
||||
// Feed 10 tokens (window size 4, so positions 0-5 are evicted).
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(nil, k, v)
|
||||
c.Update(singleTokenBatch, k, v)
|
||||
}
|
||||
|
||||
// Offset 3 is outside the window.
|
||||
if c.Restore(nil, 3) {
|
||||
if c.Restore(0, nil, 3) {
|
||||
t.Fatal("Restore(nil, 3) should fail when outside window")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheSnapshotPreservesWindow verifies that after restoring
|
||||
// from a snapshot, the rotating cache has the correct window of data.
|
||||
func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
|
||||
func TestRotatingKVCacheWindowedHistory(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
c := newRotatingKVCacheWithSeq(4)
|
||||
|
||||
// Feed 10 tokens one at a time. Window size 4, so only last 4 are kept.
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(nil, k, v)
|
||||
c.Update(singleTokenBatch, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
if snap == nil {
|
||||
t.Fatal("Snapshot returned nil")
|
||||
}
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
_, _, kv := c.Update(singleTokenBatch, k, v)
|
||||
|
||||
// Feed 5 more tokens.
|
||||
for range 5 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(nil, k, v)
|
||||
if len(kv.SeqLens) != 1 {
|
||||
t.Fatalf("SeqLens length = %d, want 1", len(kv.SeqLens))
|
||||
}
|
||||
|
||||
// Restore to offset 10.
|
||||
if !c.Restore(snap, 10) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if int(c.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
|
||||
}
|
||||
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||
}
|
||||
// Seq dim should be min(offset, maxSize) = min(10, 4) = 4.
|
||||
seqDim := state[0].Dim(2)
|
||||
if seqDim != 4 {
|
||||
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
|
||||
if kv.SeqLens[0] != 4 {
|
||||
t.Fatalf("SeqLens[0] = %d, want 4 (window size)", kv.SeqLens[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheRestoreFromSnapshot verifies that restoring from a
|
||||
// snapshot correctly preserves the write position (idx), so subsequent
|
||||
// single-token updates land in the right buffer slot.
|
||||
func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
c := newRotatingKVCacheWithSeq(8)
|
||||
|
||||
// Fill the window: 6 tokens into a size-4 window.
|
||||
// After this, idx has wrapped and the buffer has rotated.
|
||||
for range 6 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
if int(c.Offsets()[0]) != 6 {
|
||||
t.Fatalf("offset = %d, want 6", int(c.Offsets()[0]))
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
|
||||
// Mutate the cache further so live state diverges from snapshot.
|
||||
for range 3 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(nil, k, v)
|
||||
c.Update(singleTokenBatch, k, v)
|
||||
}
|
||||
if int(c.Offsets(0)[0]) != 3 {
|
||||
t.Fatalf("offset = %d, want 3", int(c.Offsets(0)[0]))
|
||||
}
|
||||
|
||||
// Restore to snapshot state.
|
||||
if !c.Restore(snap, 6) {
|
||||
t.Fatal("Restore failed")
|
||||
// Rewind before wrap should succeed
|
||||
if !c.Restore(0, nil, 1) {
|
||||
t.Fatal("Restore(nil, 1) should succeed before wrap")
|
||||
}
|
||||
if int(c.Offsets()[0]) != 6 {
|
||||
t.Fatalf("offset after restore = %d, want 6", int(c.Offsets()[0]))
|
||||
}
|
||||
|
||||
// Feed one more token. If idx was restored correctly, this should
|
||||
// produce a valid window of size 4 at offset 7.
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(nil, k, v)
|
||||
|
||||
if int(c.Offsets()[0]) != 7 {
|
||||
t.Fatalf("offset after post-restore update = %d, want 7", int(c.Offsets()[0]))
|
||||
}
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||
}
|
||||
seqDim := state[0].Dim(2)
|
||||
if seqDim != 4 {
|
||||
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
|
||||
if int(c.Offsets(0)[0]) != 1 {
|
||||
t.Fatalf("offset after restore = %d, want 1", int(c.Offsets(0)[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCacheMultiSeqUpdate(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
c.SetSeqs([]int{0, 1})
|
||||
|
||||
// Prefill: seq 0 gets 3 tokens, seq 1 gets 5 tokens
|
||||
b := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{3, 5}}
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 8, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 8, 8)
|
||||
c.Update(b, k, v)
|
||||
|
||||
if int(c.Offsets(0)[0]) != 3 {
|
||||
t.Fatalf("seq 0 offset = %d, want 3", int(c.Offsets(0)[0]))
|
||||
}
|
||||
if int(c.Offsets(1)[0]) != 5 {
|
||||
t.Fatalf("seq 1 offset = %d, want 5", int(c.Offsets(1)[0]))
|
||||
}
|
||||
|
||||
// Decode: each seq gets 1 token
|
||||
b2 := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{1, 1}}
|
||||
k2 := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8)
|
||||
v2 := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8)
|
||||
c.Update(b2, k2, v2)
|
||||
|
||||
if int(c.Offsets(0)[0]) != 4 {
|
||||
t.Fatalf("seq 0 offset after decode = %d, want 4", int(c.Offsets(0)[0]))
|
||||
}
|
||||
if int(c.Offsets(1)[0]) != 6 {
|
||||
t.Fatalf("seq 1 offset after decode = %d, want 6", int(c.Offsets(1)[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCacheSetSeqsAndUpdate(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
c.SetSeqs([]int{0, 1})
|
||||
|
||||
b := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{3, 3}}
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 6, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 6, 8)
|
||||
c.Update(b, k, v)
|
||||
|
||||
c.SetSeqs([]int{1})
|
||||
|
||||
// Update surviving sequence
|
||||
b2 := &batch.ForwardBatch{SeqIDs: []int{1}, SeqLens: []int{1}}
|
||||
k2 := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v2 := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(b2, k2, v2)
|
||||
|
||||
if int(c.Offsets(1)[0]) != 4 {
|
||||
t.Fatalf("seq 1 offset = %d, want 4", int(c.Offsets(1)[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCacheRebuildWithOldLengths(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
c.SetSeqs([]int{0})
|
||||
|
||||
// Fill to capacity boundary
|
||||
for range 256 {
|
||||
b := singleTokenBatch
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(b, k, v)
|
||||
}
|
||||
|
||||
if int(c.Offsets(0)[0]) != 256 {
|
||||
t.Fatalf("offset = %d, want 256", int(c.Offsets(0)[0]))
|
||||
}
|
||||
|
||||
// Next token triggers rebuild (exceeds capacity)
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(singleTokenBatch, k, v)
|
||||
|
||||
if int(c.Offsets(0)[0]) != 257 {
|
||||
t.Fatalf("offset after rebuild = %d, want 257", int(c.Offsets(0)[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCacheMultiSeqWindowedHistory(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
c.SetSeqs([]int{0, 1})
|
||||
|
||||
// Fill both sequences past the window
|
||||
for range 6 {
|
||||
b := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{1, 1}}
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8)
|
||||
_, _, kv := c.Update(b, k, v)
|
||||
|
||||
// After enough tokens, SeqLens should be clamped to window
|
||||
if kv.SeqLens[0] > 4 || kv.SeqLens[1] > 4 {
|
||||
t.Fatalf("SeqLens %v exceed window size 4", kv.SeqLens)
|
||||
}
|
||||
}
|
||||
|
||||
offsets := c.Offsets(0, 1)
|
||||
if int(offsets[0]) != 6 || int(offsets[1]) != 6 {
|
||||
t.Fatalf("offsets = %v, want [6 6]", offsets)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCacheSetSeqsAfterMaterialized(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
c.SetSeqs([]int{0})
|
||||
|
||||
// Materialize with some tokens
|
||||
for range 5 {
|
||||
b := singleTokenBatch
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(b, k, v)
|
||||
}
|
||||
if int(c.Offsets(0)[0]) != 5 {
|
||||
t.Fatalf("seq 0 offset = %d, want 5", int(c.Offsets(0)[0]))
|
||||
}
|
||||
|
||||
// Add a new sequence after buffer already exists
|
||||
c.SetSeqs([]int{0, 1})
|
||||
|
||||
// Update both sequences
|
||||
b := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{1, 1}}
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8)
|
||||
c.Update(b, k, v)
|
||||
|
||||
if int(c.Offsets(0)[0]) != 6 {
|
||||
t.Fatalf("seq 0 offset = %d, want 6", int(c.Offsets(0)[0]))
|
||||
}
|
||||
if int(c.Offsets(1)[0]) != 1 {
|
||||
t.Fatalf("seq 1 offset = %d, want 1", int(c.Offsets(1)[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCacheSetSeqsAfterMaterialized(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
c.SetSeqs([]int{0})
|
||||
|
||||
for range 3 {
|
||||
b := singleTokenBatch
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(b, k, v)
|
||||
}
|
||||
|
||||
// Add after materialized
|
||||
c.SetSeqs([]int{0, 1})
|
||||
|
||||
b := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{1, 1}}
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8)
|
||||
c.Update(b, k, v)
|
||||
|
||||
if int(c.Offsets(0)[0]) != 4 {
|
||||
t.Fatalf("seq 0 offset = %d, want 4", int(c.Offsets(0)[0]))
|
||||
}
|
||||
if int(c.Offsets(1)[0]) != 1 {
|
||||
t.Fatalf("seq 1 offset = %d, want 1", int(c.Offsets(1)[0]))
|
||||
}
|
||||
}
|
||||
|
||||
334
x/mlxrunner/cache/recurrent.go
vendored
334
x/mlxrunner/cache/recurrent.go
vendored
@@ -1,18 +1,25 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// RecurrentCache stores state for linear-recurrent layers.
|
||||
// RecurrentCache stores state for linear-recurrent layers using pool tensors.
|
||||
//
|
||||
// Conv state shape: [B, convTail, convDim]
|
||||
// Delta state shape: [B, numVHeads, headVDim, headKDim]
|
||||
// convState: [poolSize, convTail, convDim]
|
||||
// deltaState: [poolSize, numVHeads, headVDim, headKDim]
|
||||
//
|
||||
// Row i in the pool belongs to seqOrder[i]. seqOffsets[i] tracks
|
||||
// how many tokens sequence i has processed.
|
||||
type RecurrentCache struct {
|
||||
convState *mlx.Array
|
||||
deltaState *mlx.Array
|
||||
offset int
|
||||
|
||||
seqOffsets []int
|
||||
seqOrder []int
|
||||
|
||||
convTail int
|
||||
convDim int
|
||||
@@ -21,22 +28,6 @@ type RecurrentCache struct {
|
||||
headKDim int
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setState(old, v *mlx.Array, contiguous bool) *mlx.Array {
|
||||
if v == nil || !v.Valid() {
|
||||
return old
|
||||
}
|
||||
|
||||
if contiguous {
|
||||
v = mlx.Contiguous(v, false)
|
||||
}
|
||||
v = v.Clone()
|
||||
|
||||
mlx.Pin(v)
|
||||
mlx.Unpin(old)
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||
return &RecurrentCache{
|
||||
convTail: int(convTail),
|
||||
@@ -47,53 +38,151 @@ func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
|
||||
if batch <= 0 {
|
||||
batch = 1
|
||||
func (c *RecurrentCache) setState(old, v *mlx.Array) *mlx.Array {
|
||||
if v == nil || !v.Valid() {
|
||||
return old
|
||||
}
|
||||
v = v.Clone()
|
||||
mlx.Pin(v)
|
||||
mlx.Unpin(old)
|
||||
return v
|
||||
}
|
||||
|
||||
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
|
||||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
|
||||
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
|
||||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
|
||||
if !needConv && !needDelta {
|
||||
// seqIndex returns the pool row index for a seqID, or -1 if not found.
|
||||
func (c *RecurrentCache) seqIndex(seqID int) int {
|
||||
for i, id := range c.seqOrder {
|
||||
if id == seqID {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// ensure grows pool tensors if needed to match poolSize, preserving existing rows.
|
||||
func (c *RecurrentCache) ensure(poolSize int, dtype mlx.DType) {
|
||||
if poolSize <= 0 {
|
||||
return
|
||||
}
|
||||
if c.convState != nil && c.convState.Valid() && c.convState.DType() == dtype && c.convState.Dim(0) == poolSize {
|
||||
return
|
||||
}
|
||||
|
||||
if needConv {
|
||||
c.convState = c.setState(c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim), false)
|
||||
grow := poolSize
|
||||
if c.convState != nil && c.convState.Valid() {
|
||||
if c.convState.DType() != dtype {
|
||||
// Dtype changed — replace entire pool
|
||||
c.convState = c.setState(c.convState, mlx.Zeros(dtype, poolSize, c.convTail, c.convDim))
|
||||
c.deltaState = c.setState(c.deltaState, mlx.Zeros(dtype, poolSize, c.numVHeads, c.headVDim, c.headKDim))
|
||||
return
|
||||
}
|
||||
grow = poolSize - c.convState.Dim(0)
|
||||
}
|
||||
if needDelta {
|
||||
c.deltaState = c.setState(c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim), false)
|
||||
|
||||
if grow <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
newConvRows := mlx.Zeros(dtype, grow, c.convTail, c.convDim)
|
||||
newDeltaRows := mlx.Zeros(dtype, grow, c.numVHeads, c.headVDim, c.headKDim)
|
||||
|
||||
if c.convState != nil && c.convState.Valid() {
|
||||
c.convState = c.setState(c.convState, c.convState.Concatenate(0, newConvRows))
|
||||
c.deltaState = c.setState(c.deltaState, c.deltaState.Concatenate(0, newDeltaRows))
|
||||
} else {
|
||||
c.convState = c.setState(c.convState, newConvRows)
|
||||
c.deltaState = c.setState(c.deltaState, newDeltaRows)
|
||||
}
|
||||
}
|
||||
|
||||
// batchExtent returns the smallest [start, end) range of pool rows covering all batch sequences.
|
||||
func (c *RecurrentCache) batchExtent(b *batch.ForwardBatch) (int, int) {
|
||||
minIdx := -1
|
||||
maxIdx := 0
|
||||
for _, seqID := range b.SeqIDs {
|
||||
idx := c.seqIndex(seqID)
|
||||
if idx < 0 {
|
||||
panic(fmt.Sprintf("RecurrentCache.batchExtent: sequence %d not found in cache", seqID))
|
||||
}
|
||||
if minIdx < 0 || idx < minIdx {
|
||||
minIdx = idx
|
||||
}
|
||||
if idx+1 > maxIdx {
|
||||
maxIdx = idx + 1
|
||||
}
|
||||
}
|
||||
if minIdx < 0 {
|
||||
minIdx = 0
|
||||
}
|
||||
return minIdx, maxIdx
|
||||
}
|
||||
|
||||
// stateHistory builds KVHistory mapping batch positions to pool rows,
|
||||
// remapped relative to sliceStart.
|
||||
func (c *RecurrentCache) stateHistory(b *batch.ForwardBatch, sliceStart int) mlx.KVHistory {
|
||||
n := len(b.SeqIDs)
|
||||
indices := make([]int32, n)
|
||||
seqLens := make([]int, n)
|
||||
for i, seqID := range b.SeqIDs {
|
||||
idx := c.seqIndex(seqID)
|
||||
if idx < 0 {
|
||||
panic(fmt.Sprintf("RecurrentCache.stateHistory: sequence %d not found in cache", seqID))
|
||||
}
|
||||
indices[i] = int32(idx - sliceStart)
|
||||
seqLens[i] = c.seqOffsets[idx]
|
||||
}
|
||||
return mlx.KVHistory{
|
||||
PageTable: mlx.NewArrayInt32(indices, []int32{int32(n), 1}),
|
||||
SeqLens: seqLens,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ConvState(b *batch.ForwardBatch, dtype mlx.DType) (*mlx.Array, mlx.KVHistory) {
|
||||
c.ensure(1, dtype)
|
||||
return c.convState, mlx.KVHistory{
|
||||
PageTable: mlx.NewArrayInt32([]int32{0}, []int32{1, 1}),
|
||||
SeqLens: []int{c.offset},
|
||||
}
|
||||
c.ensure(len(c.seqOrder), dtype)
|
||||
sliceStart, sliceEnd := c.batchExtent(b)
|
||||
return c.convState.Slice(mlx.Slice(sliceStart, sliceEnd), mlx.Slice(), mlx.Slice()),
|
||||
c.stateHistory(b, sliceStart)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetConvState(b *batch.ForwardBatch, v *mlx.Array) {
|
||||
c.convState = c.setState(c.convState, v, true)
|
||||
n := int32(len(b.SeqIDs))
|
||||
indices := c.batchIndices(b)
|
||||
// Reshape to [N, 1, 1] for broadcasting with [poolSize, convTail, convDim]
|
||||
indices = mlx.Reshape(indices, n, 1, 1)
|
||||
c.convState.Set(c.convState.PutAlongAxis(indices, v, 0))
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) DeltaState(b *batch.ForwardBatch, dtype mlx.DType) (*mlx.Array, mlx.KVHistory) {
|
||||
c.ensure(1, dtype)
|
||||
return c.deltaState, mlx.KVHistory{
|
||||
PageTable: mlx.NewArrayInt32([]int32{0}, []int32{1, 1}),
|
||||
SeqLens: []int{c.offset},
|
||||
}
|
||||
c.ensure(len(c.seqOrder), dtype)
|
||||
sliceStart, sliceEnd := c.batchExtent(b)
|
||||
return c.deltaState.Slice(mlx.Slice(sliceStart, sliceEnd), mlx.Slice(), mlx.Slice(), mlx.Slice()),
|
||||
c.stateHistory(b, sliceStart)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetDeltaState(b *batch.ForwardBatch, v *mlx.Array) {
|
||||
c.deltaState = c.setState(c.deltaState, v, false)
|
||||
n := int32(len(b.SeqIDs))
|
||||
indices := c.batchIndices(b)
|
||||
// Reshape to [N, 1, 1, 1] for broadcasting with [poolSize, numVHeads, headVDim, headKDim]
|
||||
indices = mlx.Reshape(indices, n, 1, 1, 1)
|
||||
c.deltaState.Set(c.deltaState.PutAlongAxis(indices, v, 0))
|
||||
}
|
||||
|
||||
// batchIndices returns an int32 tensor mapping each batch position to its
|
||||
// pool row index, for use with PutAlongAxis scatter.
|
||||
func (c *RecurrentCache) batchIndices(b *batch.ForwardBatch) *mlx.Array {
|
||||
idx := make([]int32, len(b.SeqIDs))
|
||||
for i, seqID := range b.SeqIDs {
|
||||
idx[i] = int32(c.seqIndex(seqID))
|
||||
}
|
||||
return mlx.NewArrayInt32(idx, []int32{int32(len(idx))})
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Advance(b *batch.ForwardBatch) {
|
||||
c.offset += b.TotalLen()
|
||||
for i, seqID := range b.SeqIDs {
|
||||
idx := c.seqIndex(seqID)
|
||||
if idx >= 0 {
|
||||
c.seqOffsets[idx] += b.SeqLens[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
@@ -104,55 +193,79 @@ func (c *RecurrentCache) State() []*mlx.Array {
|
||||
return []*mlx.Array{c.convState, c.deltaState}
|
||||
}
|
||||
|
||||
// recurrentSnapshot holds paged-out recurrent state. Self-contained —
|
||||
// does not depend on any parent state.
|
||||
// recurrentSnapshot holds paged-out recurrent state for one sequence.
|
||||
type recurrentSnapshot struct {
|
||||
convState, deltaState *mlx.Array
|
||||
offset int
|
||||
}
|
||||
|
||||
func (s *recurrentSnapshot) Size() int { return s.convState.NumBytes() + s.deltaState.NumBytes() }
|
||||
func (s *recurrentSnapshot) Close() { mlx.Unpin(s.convState, s.deltaState) }
|
||||
func (s *recurrentSnapshot) Size() int {
|
||||
n := 0
|
||||
if s.convState != nil {
|
||||
n += s.convState.NumBytes()
|
||||
}
|
||||
if s.deltaState != nil {
|
||||
n += s.deltaState.NumBytes()
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Snapshot(fromOffset int) Snapshot {
|
||||
// Recurrent state is not position-sliceable — always snapshot the full state.
|
||||
if c.convState == nil && c.deltaState == nil {
|
||||
func (s *recurrentSnapshot) Close() { mlx.Unpin(s.convState, s.deltaState) }
|
||||
|
||||
func (c *RecurrentCache) Snapshot(seqID int, fromOffset int) Snapshot {
|
||||
idx := c.seqIndex(seqID)
|
||||
if idx < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
snap := &recurrentSnapshot{offset: c.offset}
|
||||
snap.convState = c.convState.Clone()
|
||||
snap.deltaState = c.deltaState.Clone()
|
||||
mlx.Pin(snap.convState, snap.deltaState)
|
||||
|
||||
snap := &recurrentSnapshot{offset: c.seqOffsets[idx]}
|
||||
if c.convState != nil && c.convState.Valid() {
|
||||
row := c.convState.Slice(mlx.Slice(idx, idx+1), mlx.Slice(), mlx.Slice())
|
||||
snap.convState = mlx.Contiguous(row, false)
|
||||
mlx.Pin(snap.convState)
|
||||
}
|
||||
if c.deltaState != nil && c.deltaState.Valid() {
|
||||
row := c.deltaState.Slice(mlx.Slice(idx, idx+1), mlx.Slice(), mlx.Slice(), mlx.Slice())
|
||||
snap.deltaState = mlx.Contiguous(row, false)
|
||||
mlx.Pin(snap.deltaState)
|
||||
}
|
||||
mlx.AsyncEval(snap.convState, snap.deltaState)
|
||||
return snap
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
|
||||
func (c *RecurrentCache) Restore(seqID int, snapshot Snapshot, target int) bool {
|
||||
idx := c.seqIndex(seqID)
|
||||
if idx < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
// Recurrent state is cumulative and can't rewind. Only succeed
|
||||
// if we're already at the target (no-op).
|
||||
return target == c.offset
|
||||
return target == c.seqOffsets[idx]
|
||||
}
|
||||
|
||||
snap := snapshot.(*recurrentSnapshot)
|
||||
|
||||
// Recurrent snapshots encode cumulative state up to exactly
|
||||
// snap.offset. Target must match — rewinding would leave stale
|
||||
// state, and advancing isn't possible without feeding tokens.
|
||||
if target != snap.offset {
|
||||
return false
|
||||
}
|
||||
|
||||
c.convState = c.setState(c.convState, snap.convState, false)
|
||||
c.deltaState = c.setState(c.deltaState, snap.deltaState, false)
|
||||
c.offset = snap.offset
|
||||
|
||||
if snap.convState != nil {
|
||||
if c.convState == nil {
|
||||
c.ensure(len(c.seqOrder), snap.convState.DType())
|
||||
}
|
||||
c.convState.Set(c.convState.SliceUpdate(snap.convState,
|
||||
mlx.Slice(idx, idx+1), mlx.Slice(), mlx.Slice()))
|
||||
}
|
||||
if snap.deltaState != nil {
|
||||
if c.deltaState == nil {
|
||||
c.ensure(len(c.seqOrder), snap.deltaState.DType())
|
||||
}
|
||||
c.deltaState.Set(c.deltaState.SliceUpdate(snap.deltaState,
|
||||
mlx.Slice(idx, idx+1), mlx.Slice(), mlx.Slice(), mlx.Slice()))
|
||||
}
|
||||
c.seqOffsets[idx] = snap.offset
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Merge(parent, child Snapshot) Snapshot {
|
||||
// Recurrent snapshots are self-contained — child supersedes parent.
|
||||
if parent != nil {
|
||||
parent.Close()
|
||||
}
|
||||
@@ -160,15 +273,86 @@ func (c *RecurrentCache) Merge(parent, child Snapshot) Snapshot {
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||
// Recurrent state is cumulative and not position-sliceable.
|
||||
// Cannot recover intermediate state at the split point.
|
||||
return nil, snapshot
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Free() {
|
||||
mlx.Unpin(c.convState, c.deltaState)
|
||||
c.convState, c.deltaState = nil, nil
|
||||
c.offset = 0
|
||||
// Preserve sequence registration — callers may Restore after Free.
|
||||
for i := range c.seqOffsets {
|
||||
c.seqOffsets[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Offsets() []int32 { return []int32{int32(c.offset)} }
|
||||
func (c *RecurrentCache) Offsets(seqIDs ...int) []int32 {
|
||||
offsets := make([]int32, len(seqIDs))
|
||||
for i, seqID := range seqIDs {
|
||||
idx := c.seqIndex(seqID)
|
||||
if idx >= 0 {
|
||||
offsets[i] = int32(c.seqOffsets[idx])
|
||||
}
|
||||
}
|
||||
return offsets
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetSeqs(seqIDs []int) {
|
||||
wanted := make(map[int]bool, len(seqIDs))
|
||||
for _, id := range seqIDs {
|
||||
wanted[id] = true
|
||||
}
|
||||
|
||||
changed := len(seqIDs) != len(c.seqOrder)
|
||||
if !changed {
|
||||
for _, id := range c.seqOrder {
|
||||
if !wanted[id] {
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
|
||||
// Build new order: preserve existing order for survivors, then append new
|
||||
newOrder := make([]int, 0, len(seqIDs))
|
||||
newOffsets := make([]int, 0, len(seqIDs))
|
||||
var survivingRows []int32
|
||||
for i, id := range c.seqOrder {
|
||||
if wanted[id] {
|
||||
survivingRows = append(survivingRows, int32(i))
|
||||
newOrder = append(newOrder, id)
|
||||
newOffsets = append(newOffsets, c.seqOffsets[i])
|
||||
}
|
||||
}
|
||||
added := make(map[int]bool, len(newOrder))
|
||||
for _, id := range newOrder {
|
||||
added[id] = true
|
||||
}
|
||||
for _, id := range seqIDs {
|
||||
if !added[id] {
|
||||
newOrder = append(newOrder, id)
|
||||
newOffsets = append(newOffsets, 0)
|
||||
}
|
||||
}
|
||||
|
||||
c.seqOrder = newOrder
|
||||
c.seqOffsets = newOffsets
|
||||
|
||||
// Rebuild pool tensor if it exists
|
||||
if c.convState != nil && c.convState.Valid() {
|
||||
dtype := c.convState.DType()
|
||||
if len(survivingRows) == 0 {
|
||||
mlx.Unpin(c.convState, c.deltaState)
|
||||
c.convState, c.deltaState = nil, nil
|
||||
} else if len(survivingRows) != c.convState.Dim(0) {
|
||||
takeIdx := mlx.NewArrayInt32(survivingRows, []int32{int32(len(survivingRows))})
|
||||
c.convState = c.setState(c.convState, c.convState.TakeAxis(takeIdx, 0))
|
||||
c.deltaState = c.setState(c.deltaState, c.deltaState.TakeAxis(takeIdx, 0))
|
||||
}
|
||||
if len(c.seqOrder) > 0 {
|
||||
c.ensure(len(c.seqOrder), dtype)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
14
x/mlxrunner/cache/recurrent_test.go
vendored
14
x/mlxrunner/cache/recurrent_test.go
vendored
@@ -13,30 +13,32 @@ import (
|
||||
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRecurrentCache(3, 12, 4, 8, 8)
|
||||
c.SetSeqs([]int{0})
|
||||
|
||||
b := &batch.ForwardBatch{SeqIDs: []int{0}, SeqLens: []int{1}}
|
||||
_, _ = c.ConvState(b, mlx.DTypeFloat16)
|
||||
_, _ = c.DeltaState(b, mlx.DTypeFloat16)
|
||||
c.Advance(&batch.ForwardBatch{SeqIDs: []int{0}, SeqLens: []int{10}})
|
||||
|
||||
snap := c.Snapshot(0) // snap.offset == 10
|
||||
snap := c.Snapshot(0, 0) // snap.offset == 10
|
||||
|
||||
c.Advance(&batch.ForwardBatch{SeqIDs: []int{0}, SeqLens: []int{5}}) // cache now at 15
|
||||
|
||||
// target < snap.offset: fails (can't rewind past snapshot)
|
||||
if c.Restore(snap, 5) {
|
||||
if c.Restore(0, snap, 5) {
|
||||
t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
|
||||
}
|
||||
|
||||
// target > snap.offset: fails (can't advance without feeding tokens)
|
||||
if c.Restore(snap, 15) {
|
||||
if c.Restore(0, snap, 15) {
|
||||
t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
|
||||
}
|
||||
|
||||
// target == snap.offset: succeeds
|
||||
if !c.Restore(snap, 10) {
|
||||
if !c.Restore(0, snap, 10) {
|
||||
t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
|
||||
}
|
||||
if int(c.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
|
||||
if int(c.Offsets(0)[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c.Offsets(0)[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,14 +56,15 @@ func (c *fakeRewindableCache) feed(tokens []int32) {
|
||||
func (c *fakeRewindableCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
return nil, nil, mlx.KVHistory{}
|
||||
}
|
||||
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeRewindableCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
|
||||
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeRewindableCache) Offsets(_ ...int) []int32 { return []int32{int32(len(c.tokens))} }
|
||||
|
||||
func (c *fakeRewindableCache) Free() {
|
||||
c.tokens = nil
|
||||
}
|
||||
func (c *fakeRewindableCache) SetSeqs(seqIDs []int) {}
|
||||
|
||||
func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
func (c *fakeRewindableCache) Snapshot(seqID int, fromOffset int) cache.Snapshot {
|
||||
if fromOffset >= len(c.tokens) {
|
||||
return nil
|
||||
}
|
||||
@@ -80,7 +81,7 @@ func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||
func (c *fakeRewindableCache) Restore(seqID int, snapshot cache.Snapshot, target int) bool {
|
||||
if target < 0 {
|
||||
return false
|
||||
}
|
||||
@@ -176,14 +177,15 @@ func (c *fakeSlidingWindowCache) feed(tokens []int32) {
|
||||
func (c *fakeSlidingWindowCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
return nil, nil, mlx.KVHistory{}
|
||||
}
|
||||
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeSlidingWindowCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
|
||||
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeSlidingWindowCache) Offsets(_ ...int) []int32 { return []int32{int32(len(c.tokens))} }
|
||||
|
||||
func (c *fakeSlidingWindowCache) Free() {
|
||||
c.tokens = nil
|
||||
}
|
||||
func (c *fakeSlidingWindowCache) SetSeqs(seqIDs []int) {}
|
||||
|
||||
func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
func (c *fakeSlidingWindowCache) Snapshot(seqID int, fromOffset int) cache.Snapshot {
|
||||
if len(c.tokens) == 0 || len(c.tokens) <= fromOffset {
|
||||
return nil
|
||||
}
|
||||
@@ -197,7 +199,7 @@ func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||
func (c *fakeSlidingWindowCache) Restore(seqID int, snapshot cache.Snapshot, target int) bool {
|
||||
if target < 0 {
|
||||
return false
|
||||
}
|
||||
@@ -256,14 +258,15 @@ func (c *fakeRecurrentCache) feed(tokens []int32) {
|
||||
func (c *fakeRecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
return nil, nil, mlx.KVHistory{}
|
||||
}
|
||||
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeRecurrentCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
|
||||
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeRecurrentCache) Offsets(_ ...int) []int32 { return []int32{int32(len(c.tokens))} }
|
||||
|
||||
func (c *fakeRecurrentCache) Free() {
|
||||
c.tokens = nil
|
||||
}
|
||||
func (c *fakeRecurrentCache) SetSeqs(seqIDs []int) {}
|
||||
|
||||
func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
func (c *fakeRecurrentCache) Snapshot(seqID int, fromOffset int) cache.Snapshot {
|
||||
// Recurrent state is cumulative; snapshot captures the full state.
|
||||
if len(c.tokens) == 0 {
|
||||
return nil
|
||||
@@ -277,7 +280,7 @@ func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||
func (c *fakeRecurrentCache) Restore(seqID int, snapshot cache.Snapshot, target int) bool {
|
||||
if snapshot == nil {
|
||||
return target == len(c.tokens) // can only no-op
|
||||
}
|
||||
@@ -367,9 +370,9 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32)
|
||||
for i, c := range e.caches {
|
||||
assertTokens(t, label, c, expected)
|
||||
// Verify all caches report the same offset.
|
||||
if i > 0 && int(c.Offsets()[0]) != int(e.caches[0].Offsets()[0]) {
|
||||
if i > 0 && int(c.Offsets(0)[0]) != int(e.caches[0].Offsets(0)[0]) {
|
||||
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
|
||||
label, i, int(c.Offsets()[0]), int(e.caches[0].Offsets()[0]))
|
||||
label, i, int(c.Offsets(0)[0]), int(e.caches[0].Offsets(0)[0]))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -452,9 +455,9 @@ func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
|
||||
if len(kvc.caches) < 2 {
|
||||
return
|
||||
}
|
||||
expected := int(kvc.caches[0].Offsets()[0])
|
||||
expected := int(kvc.caches[0].Offsets(0)[0])
|
||||
for i := 1; i < len(kvc.caches); i++ {
|
||||
if got := int(kvc.caches[i].Offsets()[0]); got != expected {
|
||||
if got := int(kvc.caches[i].Offsets(0)[0]); got != expected {
|
||||
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ func TestSplitNodeWithSnapshots(t *testing.T) {
|
||||
child := root.children[0]
|
||||
|
||||
rc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
|
||||
child.snapshots = []cache.Snapshot{rc.Snapshot(0)}
|
||||
child.snapshots = []cache.Snapshot{rc.Snapshot(0, 0)}
|
||||
child.user = true
|
||||
|
||||
caches := []cache.Cache{rc}
|
||||
|
||||
@@ -492,7 +492,7 @@ func (a *Attention) Forward(x *mlx.Array, b *batch.ForwardBatch, c cache.Cache,
|
||||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
|
||||
positions := batch.SequentialPositions(b, c.Offsets())
|
||||
positions := batch.SequentialPositions(b, c.Offsets(b.SeqIDs...))
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
|
||||
|
||||
|
||||
@@ -707,7 +707,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets())
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets(b.SeqIDs...))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
|
||||
@@ -242,7 +242,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets())
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets(b.SeqIDs...))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
|
||||
@@ -259,7 +259,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets())
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets(b.SeqIDs...))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
|
||||
@@ -1321,7 +1321,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets())
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets(b.SeqIDs...))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
|
||||
Reference in New Issue
Block a user