mlxrunner: multi-sequence KVCache, RotatingKVCache, and RecurrentCache

KVCache: AddSeq/RemoveSeq register sequences with contiguous buffer
regions separated by gaps. Update scatters new K/V via PutAlongAxis.
Rebuild reallocates on sequence add/remove. SnapshotSeq extracts one
sequence's K/V. Single-sequence mode (no AddSeq) unchanged.

RotatingKVCache: multi-sequence delegates to KVCache regions with
windowKVHistory limiting visibility to trailing maxSize tokens per
sequence. Single-sequence retains existing ring buffer.

RecurrentCache: per-sequence conv/delta state in seqStates map.
ConvState/DeltaState gather states for batch, SetConvState/SetDeltaState
scatter back. AddSeqWithSnapshot restores recurrent state from snapshots.
This commit is contained in:
Jesse Gross
2026-04-02 15:11:32 -07:00
parent 02fe50c90c
commit d8067801c3
12 changed files with 1229 additions and 491 deletions

View File

@@ -67,11 +67,17 @@ func (c *kvCache) ensureCaches(m base.Model) {
}
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
c.caches = cacheFactory.NewCaches()
return
} else {
c.caches = make([]cache.Cache, m.NumLayers())
for i := range c.caches {
c.caches[i] = cache.NewKVCache()
}
}
c.caches = make([]cache.Cache, m.NumLayers())
for i := range c.caches {
c.caches[i] = cache.NewKVCache()
// Register the default sequence for single-sequence prefill.
for _, kv := range c.caches {
if kv != nil {
kv.SetSeqs([]int{0})
}
}
}
@@ -167,7 +173,7 @@ func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
if kv == nil {
continue
}
snaps[j] = kv.Snapshot(fromOffset)
snaps[j] = kv.Snapshot(0, fromOffset)
}
node.setSnapshots(snaps, &c.pagedOutBytes)
pageOutCount++
@@ -184,7 +190,7 @@ func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
if kv == nil {
continue
}
if !kv.Restore(nil, rewindTarget) {
if !kv.Restore(0, nil, rewindTarget) {
kv.Free()
}
}
@@ -205,10 +211,10 @@ pageIn:
if j >= len(node.snapshots) || node.snapshots[j] == nil {
continue
}
if int(kv.Offsets()[0]) >= nodeTarget {
if int(kv.Offsets(0)[0]) >= nodeTarget {
continue
}
if !kv.Restore(node.snapshots[j], nodeTarget) {
if !kv.Restore(0, node.snapshots[j], nodeTarget) {
// Restore failed — stop page-in and let alignment
// bring all caches to a consistent offset.
break pageIn
@@ -224,8 +230,8 @@ pageIn:
c.activePath = newPath
minOff := c.minCacheOffset()
for _, kv := range c.caches {
if kv != nil && int(kv.Offsets()[0]) != minOff {
if !kv.Restore(nil, minOff) {
if kv != nil && int(kv.Offsets(0)[0]) != minOff {
if !kv.Restore(0, nil, minOff) {
slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
c.freeAll()
break
@@ -390,10 +396,10 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
snaps := make([]cache.Snapshot, len(c.caches))
for i, kv := range c.caches {
if kv != nil {
if int(kv.Offsets()[0]) != cacheOffset {
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, int(kv.Offsets()[0])))
if int(kv.Offsets(0)[0]) != cacheOffset {
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, int(kv.Offsets(0)[0])))
}
snaps[i] = kv.Snapshot(node.startOffset())
snaps[i] = kv.Snapshot(0, node.startOffset())
}
}
node.setSnapshots(snaps, &c.pagedOutBytes)
@@ -418,7 +424,7 @@ func (c *kvCache) minCacheOffset() int {
if kv == nil {
continue
}
if off := int(kv.Offsets()[0]); !found || off < offset {
if off := int(kv.Offsets(0)[0]); !found || off < offset {
offset = off
found = true
}

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@ package cache
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -13,259 +14,351 @@ func skipIfNoMLX(t *testing.T) {
}
}
// singleTokenBatch is a shared one-sequence (seq 0), one-token ForwardBatch
// used by the single-sequence decode-style Update calls in these tests.
var singleTokenBatch = &batch.ForwardBatch{SeqIDs: []int{0}, SeqLens: []int{1}}
// newKVCacheWithSeq builds a KVCache with the default sequence 0 registered,
// matching the single-sequence setup used by most tests here.
func newKVCacheWithSeq() *KVCache {
	kv := NewKVCache()
	kv.SetSeqs([]int{0})
	return kv
}
// newRotatingKVCacheWithSeq builds a RotatingKVCache with the given window
// size and the default sequence 0 registered.
func newRotatingKVCacheWithSeq(maxSize int) *RotatingKVCache {
	rc := NewRotatingKVCache(maxSize)
	rc.SetSeqs([]int{0})
	return rc
}
func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
c := newKVCacheWithSeq()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(nil, k, v)
c.Update(singleTokenBatch, k, v)
}
// Snapshot [5, 10).
snap := c.Snapshot(5)
snap := c.Snapshot(0, 5)
// Free the cache completely — offset is now 0.
c.Free()
// Restore should fail because cache doesn't have data up to fromOffset=5.
if c.Restore(snap, 10) {
if c.Restore(0, snap, 10) {
t.Fatal("expected Restore to fail with no base data")
}
}
// TestKVCacheDataSurvivesSnapshotRestore verifies that actual array data
// is preserved through a snapshot→free→restore cycle.
func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
c := newKVCacheWithSeq()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(nil, k, v)
c.Update(singleTokenBatch, k, v)
}
snap := c.Snapshot(0)
snap := c.Snapshot(0, 0)
if snap == nil {
t.Fatal("Snapshot returned nil")
}
// Free and restore to a fresh cache.
c2 := NewKVCache()
if !c2.Restore(snap, 10) {
c2 := newKVCacheWithSeq()
if !c2.Restore(0, snap, 10) {
t.Fatal("Restore failed")
}
if int(c2.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
if int(c2.Offsets(0)[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c2.Offsets(0)[0]))
}
// Verify State() returns arrays with correct sequence dimension.
state := c2.State()
if len(state) != 2 {
t.Fatalf("State() returned %d arrays, want 2", len(state))
}
// keys shape: [B, H, seqLen, Dk]
if state[0].Dim(2) != 10 {
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
}
if state[1].Dim(2) != 10 {
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
if state[0].Dim(2) < 10 {
t.Fatalf("keys seq dim = %d, want >= 10", state[0].Dim(2))
}
}
// TestKVCacheSplitPreservesData verifies that split produces two snapshots
// that can be sequentially restored to rebuild the original cache state.
func TestKVCacheSplitPreservesData(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
c := newKVCacheWithSeq()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(nil, k, v)
c.Update(singleTokenBatch, k, v)
}
snap := c.Snapshot(0)
snap := c.Snapshot(0, 0)
parent, child := c.Split(snap, 5)
if parent == nil || child == nil {
t.Fatal("Split returned nil")
}
// Restore parent → offset=5, seq dim=5.
c2 := NewKVCache()
if !c2.Restore(parent, 5) {
c2 := newKVCacheWithSeq()
if !c2.Restore(0, parent, 5) {
t.Fatal("Restore(parent) failed")
}
if int(c2.Offsets()[0]) != 5 {
t.Fatalf("offset after parent = %d, want 5", int(c2.Offsets()[0]))
}
state := c2.State()
if state[0].Dim(2) != 5 {
t.Fatalf("keys seq dim after parent = %d, want 5", state[0].Dim(2))
if int(c2.Offsets(0)[0]) != 5 {
t.Fatalf("offset after parent = %d, want 5", int(c2.Offsets(0)[0]))
}
// Restore child on top → offset=10, seq dim=10.
if !c2.Restore(child, 10) {
if !c2.Restore(0, child, 10) {
t.Fatal("Restore(child) failed")
}
if int(c2.Offsets()[0]) != 10 {
t.Fatalf("offset after child = %d, want 10", int(c2.Offsets()[0]))
}
state = c2.State()
if state[0].Dim(2) != 10 {
t.Fatalf("keys seq dim after child = %d, want 10", state[0].Dim(2))
if int(c2.Offsets(0)[0]) != 10 {
t.Fatalf("offset after child = %d, want 10", int(c2.Offsets(0)[0]))
}
}
// TestKVCacheSplitMergeRoundTripData verifies that splitting and merging back
// produces a snapshot equivalent to the original.
func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
c := newKVCacheWithSeq()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(nil, k, v)
c.Update(singleTokenBatch, k, v)
}
snap := c.Snapshot(0)
snap := c.Snapshot(0, 0)
parent, child := c.Split(snap, 6)
merged := c.Merge(parent, child)
if merged == nil {
t.Fatal("Merge returned nil")
}
c2 := NewKVCache()
if !c2.Restore(merged, 10) {
c2 := newKVCacheWithSeq()
if !c2.Restore(0, merged, 10) {
t.Fatal("Restore(merged) failed")
}
if int(c2.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
}
state := c2.State()
if state[0].Dim(2) != 10 {
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
}
if state[1].Dim(2) != 10 {
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
if int(c2.Offsets(0)[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c2.Offsets(0)[0]))
}
}
func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
func TestRotatingKVCacheRewindOutsideWindow(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
c := newRotatingKVCacheWithSeq(4)
// Feed 10 tokens (window size 4, so positions 0-5 are evicted).
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(nil, k, v)
c.Update(singleTokenBatch, k, v)
}
// Offset 3 is outside the window.
if c.Restore(nil, 3) {
if c.Restore(0, nil, 3) {
t.Fatal("Restore(nil, 3) should fail when outside window")
}
}
// TestRotatingKVCacheSnapshotPreservesWindow verifies that after restoring
// from a snapshot, the rotating cache has the correct window of data.
func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
func TestRotatingKVCacheWindowedHistory(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
c := newRotatingKVCacheWithSeq(4)
// Feed 10 tokens one at a time. Window size 4, so only last 4 are kept.
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(nil, k, v)
c.Update(singleTokenBatch, k, v)
}
snap := c.Snapshot(0)
if snap == nil {
t.Fatal("Snapshot returned nil")
}
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
_, _, kv := c.Update(singleTokenBatch, k, v)
// Feed 5 more tokens.
for range 5 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(nil, k, v)
if len(kv.SeqLens) != 1 {
t.Fatalf("SeqLens length = %d, want 1", len(kv.SeqLens))
}
// Restore to offset 10.
if !c.Restore(snap, 10) {
t.Fatal("Restore failed")
}
if int(c.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
}
state := c.State()
if len(state) != 2 {
t.Fatalf("State() returned %d arrays, want 2", len(state))
}
// Seq dim should be min(offset, maxSize) = min(10, 4) = 4.
seqDim := state[0].Dim(2)
if seqDim != 4 {
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
if kv.SeqLens[0] != 4 {
t.Fatalf("SeqLens[0] = %d, want 4 (window size)", kv.SeqLens[0])
}
}
// TestRotatingKVCacheRestoreFromSnapshot verifies that restoring from a
// snapshot correctly preserves the write position (idx), so subsequent
// single-token updates land in the right buffer slot.
func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
c := newRotatingKVCacheWithSeq(8)
// Fill the window: 6 tokens into a size-4 window.
// After this, idx has wrapped and the buffer has rotated.
for range 6 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(nil, k, v)
}
if int(c.Offsets()[0]) != 6 {
t.Fatalf("offset = %d, want 6", int(c.Offsets()[0]))
}
snap := c.Snapshot(0)
// Mutate the cache further so live state diverges from snapshot.
for range 3 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(nil, k, v)
c.Update(singleTokenBatch, k, v)
}
if int(c.Offsets(0)[0]) != 3 {
t.Fatalf("offset = %d, want 3", int(c.Offsets(0)[0]))
}
// Restore to snapshot state.
if !c.Restore(snap, 6) {
t.Fatal("Restore failed")
// Rewind before wrap should succeed
if !c.Restore(0, nil, 1) {
t.Fatal("Restore(nil, 1) should succeed before wrap")
}
if int(c.Offsets()[0]) != 6 {
t.Fatalf("offset after restore = %d, want 6", int(c.Offsets()[0]))
}
// Feed one more token. If idx was restored correctly, this should
// produce a valid window of size 4 at offset 7.
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(nil, k, v)
if int(c.Offsets()[0]) != 7 {
t.Fatalf("offset after post-restore update = %d, want 7", int(c.Offsets()[0]))
}
state := c.State()
if len(state) != 2 {
t.Fatalf("State() returned %d arrays, want 2", len(state))
}
seqDim := state[0].Dim(2)
if seqDim != 4 {
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
if int(c.Offsets(0)[0]) != 1 {
t.Fatalf("offset after restore = %d, want 1", int(c.Offsets(0)[0]))
}
}
// TestKVCacheMultiSeqUpdate checks that per-sequence offsets advance
// independently through a mixed-length prefill followed by a decode step.
func TestKVCacheMultiSeqUpdate(t *testing.T) {
	skipIfNoMLX(t)
	kv := NewKVCache()
	kv.SetSeqs([]int{0, 1})
	// Prefill: seq 0 gets 3 tokens, seq 1 gets 5 tokens
	prefill := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{3, 5}}
	kv.Update(prefill,
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 8, 8),
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 8, 8))
	if got := int(kv.Offsets(0)[0]); got != 3 {
		t.Fatalf("seq 0 offset = %d, want 3", got)
	}
	if got := int(kv.Offsets(1)[0]); got != 5 {
		t.Fatalf("seq 1 offset = %d, want 5", got)
	}
	// Decode: each seq gets 1 token
	decode := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{1, 1}}
	kv.Update(decode,
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8),
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8))
	if got := int(kv.Offsets(0)[0]); got != 4 {
		t.Fatalf("seq 0 offset after decode = %d, want 4", got)
	}
	if got := int(kv.Offsets(1)[0]); got != 6 {
		t.Fatalf("seq 1 offset after decode = %d, want 6", got)
	}
}
// TestKVCacheSetSeqsAndUpdate verifies that a sequence surviving a SetSeqs
// removal keeps its history and continues to advance on Update.
func TestKVCacheSetSeqsAndUpdate(t *testing.T) {
	skipIfNoMLX(t)
	kv := NewKVCache()
	kv.SetSeqs([]int{0, 1})
	prefill := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{3, 3}}
	kv.Update(prefill,
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 6, 8),
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 6, 8))
	// Drop sequence 0; only sequence 1 remains registered.
	kv.SetSeqs([]int{1})
	// Update surviving sequence
	decode := &batch.ForwardBatch{SeqIDs: []int{1}, SeqLens: []int{1}}
	kv.Update(decode,
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8),
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8))
	if got := int(kv.Offsets(1)[0]); got != 4 {
		t.Fatalf("seq 1 offset = %d, want 4", got)
	}
}
// TestKVCacheRebuildWithOldLengths fills a single sequence to the capacity
// boundary and confirms the first over-capacity token triggers a rebuild
// without losing the accumulated offset.
func TestKVCacheRebuildWithOldLengths(t *testing.T) {
	skipIfNoMLX(t)
	kv := NewKVCache()
	kv.SetSeqs([]int{0})
	// Fill to capacity boundary
	for i := 0; i < 256; i++ {
		kv.Update(singleTokenBatch,
			mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8),
			mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8))
	}
	if got := int(kv.Offsets(0)[0]); got != 256 {
		t.Fatalf("offset = %d, want 256", got)
	}
	// Next token triggers rebuild (exceeds capacity)
	kv.Update(singleTokenBatch,
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8),
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8))
	if got := int(kv.Offsets(0)[0]); got != 257 {
		t.Fatalf("offset after rebuild = %d, want 257", got)
	}
}
// TestRotatingKVCacheMultiSeqWindowedHistory drives two sequences past the
// rotating window and checks that visible history is clamped to the window
// while offsets keep counting total tokens.
func TestRotatingKVCacheMultiSeqWindowedHistory(t *testing.T) {
	skipIfNoMLX(t)
	rc := NewRotatingKVCache(4)
	rc.SetSeqs([]int{0, 1})
	// Fill both sequences past the window
	for i := 0; i < 6; i++ {
		step := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{1, 1}}
		_, _, kv := rc.Update(step,
			mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8),
			mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8))
		// After enough tokens, SeqLens should be clamped to window
		if kv.SeqLens[0] > 4 || kv.SeqLens[1] > 4 {
			t.Fatalf("SeqLens %v exceed window size 4", kv.SeqLens)
		}
	}
	offsets := rc.Offsets(0, 1)
	if int(offsets[0]) != 6 || int(offsets[1]) != 6 {
		t.Fatalf("offsets = %v, want [6 6]", offsets)
	}
}
// TestKVCacheSetSeqsAfterMaterialized registers a second sequence after the
// buffer already holds data and verifies both sequences advance correctly.
func TestKVCacheSetSeqsAfterMaterialized(t *testing.T) {
	skipIfNoMLX(t)
	kv := NewKVCache()
	kv.SetSeqs([]int{0})
	// Materialize with some tokens
	for i := 0; i < 5; i++ {
		kv.Update(singleTokenBatch,
			mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8),
			mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8))
	}
	if got := int(kv.Offsets(0)[0]); got != 5 {
		t.Fatalf("seq 0 offset = %d, want 5", got)
	}
	// Add a new sequence after buffer already exists
	kv.SetSeqs([]int{0, 1})
	// Update both sequences
	both := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{1, 1}}
	kv.Update(both,
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8),
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8))
	if got := int(kv.Offsets(0)[0]); got != 6 {
		t.Fatalf("seq 0 offset = %d, want 6", got)
	}
	if got := int(kv.Offsets(1)[0]); got != 1 {
		t.Fatalf("seq 1 offset = %d, want 1", got)
	}
}
// TestRotatingKVCacheSetSeqsAfterMaterialized adds a second sequence to a
// rotating cache that already holds data and checks both offsets advance.
func TestRotatingKVCacheSetSeqsAfterMaterialized(t *testing.T) {
	skipIfNoMLX(t)
	rc := NewRotatingKVCache(4)
	rc.SetSeqs([]int{0})
	for i := 0; i < 3; i++ {
		rc.Update(singleTokenBatch,
			mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8),
			mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8))
	}
	// Add after materialized
	rc.SetSeqs([]int{0, 1})
	both := &batch.ForwardBatch{SeqIDs: []int{0, 1}, SeqLens: []int{1, 1}}
	rc.Update(both,
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8),
		mlx.Zeros(mlx.DTypeFloat16, 1, 4, 2, 8))
	if got := int(rc.Offsets(0)[0]); got != 4 {
		t.Fatalf("seq 0 offset = %d, want 4", got)
	}
	if got := int(rc.Offsets(1)[0]); got != 1 {
		t.Fatalf("seq 1 offset = %d, want 1", got)
	}
}

View File

@@ -1,18 +1,25 @@
package cache
import (
"fmt"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// RecurrentCache stores state for linear-recurrent layers.
// RecurrentCache stores state for linear-recurrent layers using pool tensors.
//
// Conv state shape: [B, convTail, convDim]
// Delta state shape: [B, numVHeads, headVDim, headKDim]
// convState: [poolSize, convTail, convDim]
// deltaState: [poolSize, numVHeads, headVDim, headKDim]
//
// Row i in the pool belongs to seqOrder[i]. seqOffsets[i] tracks
// how many tokens sequence i has processed.
type RecurrentCache struct {
convState *mlx.Array
deltaState *mlx.Array
offset int
seqOffsets []int
seqOrder []int
convTail int
convDim int
@@ -21,22 +28,6 @@ type RecurrentCache struct {
headKDim int
}
func (c *RecurrentCache) setState(old, v *mlx.Array, contiguous bool) *mlx.Array {
if v == nil || !v.Valid() {
return old
}
if contiguous {
v = mlx.Contiguous(v, false)
}
v = v.Clone()
mlx.Pin(v)
mlx.Unpin(old)
return v
}
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
return &RecurrentCache{
convTail: int(convTail),
@@ -47,53 +38,151 @@ func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *
}
}
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
if batch <= 0 {
batch = 1
func (c *RecurrentCache) setState(old, v *mlx.Array) *mlx.Array {
if v == nil || !v.Valid() {
return old
}
v = v.Clone()
mlx.Pin(v)
mlx.Unpin(old)
return v
}
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
if !needConv && !needDelta {
// seqIndex returns the pool row index for a seqID, or -1 if not found.
func (c *RecurrentCache) seqIndex(seqID int) int {
for i, id := range c.seqOrder {
if id == seqID {
return i
}
}
return -1
}
// ensure grows pool tensors if needed to match poolSize, preserving existing rows.
func (c *RecurrentCache) ensure(poolSize int, dtype mlx.DType) {
if poolSize <= 0 {
return
}
if c.convState != nil && c.convState.Valid() && c.convState.DType() == dtype && c.convState.Dim(0) == poolSize {
return
}
if needConv {
c.convState = c.setState(c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim), false)
grow := poolSize
if c.convState != nil && c.convState.Valid() {
if c.convState.DType() != dtype {
// Dtype changed — replace entire pool
c.convState = c.setState(c.convState, mlx.Zeros(dtype, poolSize, c.convTail, c.convDim))
c.deltaState = c.setState(c.deltaState, mlx.Zeros(dtype, poolSize, c.numVHeads, c.headVDim, c.headKDim))
return
}
grow = poolSize - c.convState.Dim(0)
}
if needDelta {
c.deltaState = c.setState(c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim), false)
if grow <= 0 {
return
}
newConvRows := mlx.Zeros(dtype, grow, c.convTail, c.convDim)
newDeltaRows := mlx.Zeros(dtype, grow, c.numVHeads, c.headVDim, c.headKDim)
if c.convState != nil && c.convState.Valid() {
c.convState = c.setState(c.convState, c.convState.Concatenate(0, newConvRows))
c.deltaState = c.setState(c.deltaState, c.deltaState.Concatenate(0, newDeltaRows))
} else {
c.convState = c.setState(c.convState, newConvRows)
c.deltaState = c.setState(c.deltaState, newDeltaRows)
}
}
// batchExtent returns the smallest [start, end) range of pool rows covering all batch sequences.
func (c *RecurrentCache) batchExtent(b *batch.ForwardBatch) (int, int) {
	// minIdx == -1 means "no row seen yet"; maxIdx is one past the largest row seen.
	minIdx := -1
	maxIdx := 0
	for _, seqID := range b.SeqIDs {
		idx := c.seqIndex(seqID)
		if idx < 0 {
			// Every batch sequence must already be registered (via SetSeqs);
			// anything else is a programmer error, so fail loudly.
			panic(fmt.Sprintf("RecurrentCache.batchExtent: sequence %d not found in cache", seqID))
		}
		if minIdx < 0 || idx < minIdx {
			minIdx = idx
		}
		if idx+1 > maxIdx {
			maxIdx = idx + 1
		}
	}
	// Empty batch: normalize the sentinel to the empty range [0, 0).
	if minIdx < 0 {
		minIdx = 0
	}
	return minIdx, maxIdx
}
// stateHistory builds KVHistory mapping batch positions to pool rows,
// remapped relative to sliceStart.
func (c *RecurrentCache) stateHistory(b *batch.ForwardBatch, sliceStart int) mlx.KVHistory {
	n := len(b.SeqIDs)
	indices := make([]int32, n)
	seqLens := make([]int, n)
	for i, seqID := range b.SeqIDs {
		idx := c.seqIndex(seqID)
		if idx < 0 {
			// Sequences must be registered via SetSeqs before they appear
			// in a batch.
			panic(fmt.Sprintf("RecurrentCache.stateHistory: sequence %d not found in cache", seqID))
		}
		// Indices are relative to the row slice handed out alongside this
		// history, hence the sliceStart adjustment.
		indices[i] = int32(idx - sliceStart)
		// seqLens reports how many tokens each sequence has processed.
		seqLens[i] = c.seqOffsets[idx]
	}
	return mlx.KVHistory{
		// Single-column page table: one pool row per batch position.
		PageTable: mlx.NewArrayInt32(indices, []int32{int32(n), 1}),
		SeqLens:   seqLens,
	}
}
func (c *RecurrentCache) ConvState(b *batch.ForwardBatch, dtype mlx.DType) (*mlx.Array, mlx.KVHistory) {
c.ensure(1, dtype)
return c.convState, mlx.KVHistory{
PageTable: mlx.NewArrayInt32([]int32{0}, []int32{1, 1}),
SeqLens: []int{c.offset},
}
c.ensure(len(c.seqOrder), dtype)
sliceStart, sliceEnd := c.batchExtent(b)
return c.convState.Slice(mlx.Slice(sliceStart, sliceEnd), mlx.Slice(), mlx.Slice()),
c.stateHistory(b, sliceStart)
}
func (c *RecurrentCache) SetConvState(b *batch.ForwardBatch, v *mlx.Array) {
c.convState = c.setState(c.convState, v, true)
n := int32(len(b.SeqIDs))
indices := c.batchIndices(b)
// Reshape to [N, 1, 1] for broadcasting with [poolSize, convTail, convDim]
indices = mlx.Reshape(indices, n, 1, 1)
c.convState.Set(c.convState.PutAlongAxis(indices, v, 0))
}
func (c *RecurrentCache) DeltaState(b *batch.ForwardBatch, dtype mlx.DType) (*mlx.Array, mlx.KVHistory) {
c.ensure(1, dtype)
return c.deltaState, mlx.KVHistory{
PageTable: mlx.NewArrayInt32([]int32{0}, []int32{1, 1}),
SeqLens: []int{c.offset},
}
c.ensure(len(c.seqOrder), dtype)
sliceStart, sliceEnd := c.batchExtent(b)
return c.deltaState.Slice(mlx.Slice(sliceStart, sliceEnd), mlx.Slice(), mlx.Slice(), mlx.Slice()),
c.stateHistory(b, sliceStart)
}
func (c *RecurrentCache) SetDeltaState(b *batch.ForwardBatch, v *mlx.Array) {
c.deltaState = c.setState(c.deltaState, v, false)
n := int32(len(b.SeqIDs))
indices := c.batchIndices(b)
// Reshape to [N, 1, 1, 1] for broadcasting with [poolSize, numVHeads, headVDim, headKDim]
indices = mlx.Reshape(indices, n, 1, 1, 1)
c.deltaState.Set(c.deltaState.PutAlongAxis(indices, v, 0))
}
// batchIndices returns an int32 tensor mapping each batch position to its
// pool row index, for use with PutAlongAxis scatter.
func (c *RecurrentCache) batchIndices(b *batch.ForwardBatch) *mlx.Array {
	idx := make([]int32, len(b.SeqIDs))
	for i, seqID := range b.SeqIDs {
		// seqIndex yields -1 for unknown IDs; callers appear to rely on
		// prior validation (see batchExtent's panic) — TODO confirm.
		idx[i] = int32(c.seqIndex(seqID))
	}
	// Flat [N] index vector; callers reshape it to broadcast against the
	// pool tensor they are scattering into.
	return mlx.NewArrayInt32(idx, []int32{int32(len(idx))})
}
func (c *RecurrentCache) Advance(b *batch.ForwardBatch) {
c.offset += b.TotalLen()
for i, seqID := range b.SeqIDs {
idx := c.seqIndex(seqID)
if idx >= 0 {
c.seqOffsets[idx] += b.SeqLens[i]
}
}
}
func (c *RecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
@@ -104,55 +193,79 @@ func (c *RecurrentCache) State() []*mlx.Array {
return []*mlx.Array{c.convState, c.deltaState}
}
// recurrentSnapshot holds paged-out recurrent state. Self-contained —
// does not depend on any parent state.
// recurrentSnapshot holds paged-out recurrent state for one sequence.
type recurrentSnapshot struct {
	// convState and deltaState are pinned single-row copies of the pool
	// tensors; either may be nil if that pool was never materialized.
	convState, deltaState *mlx.Array
	// offset is the token count whose cumulative state this snapshot encodes.
	offset int
}
func (s *recurrentSnapshot) Size() int { return s.convState.NumBytes() + s.deltaState.NumBytes() }
func (s *recurrentSnapshot) Close() { mlx.Unpin(s.convState, s.deltaState) }
func (s *recurrentSnapshot) Size() int {
n := 0
if s.convState != nil {
n += s.convState.NumBytes()
}
if s.deltaState != nil {
n += s.deltaState.NumBytes()
}
return n
}
func (c *RecurrentCache) Snapshot(fromOffset int) Snapshot {
// Recurrent state is not position-sliceable — always snapshot the full state.
if c.convState == nil && c.deltaState == nil {
func (s *recurrentSnapshot) Close() { mlx.Unpin(s.convState, s.deltaState) }
func (c *RecurrentCache) Snapshot(seqID int, fromOffset int) Snapshot {
idx := c.seqIndex(seqID)
if idx < 0 {
return nil
}
snap := &recurrentSnapshot{offset: c.offset}
snap.convState = c.convState.Clone()
snap.deltaState = c.deltaState.Clone()
mlx.Pin(snap.convState, snap.deltaState)
snap := &recurrentSnapshot{offset: c.seqOffsets[idx]}
if c.convState != nil && c.convState.Valid() {
row := c.convState.Slice(mlx.Slice(idx, idx+1), mlx.Slice(), mlx.Slice())
snap.convState = mlx.Contiguous(row, false)
mlx.Pin(snap.convState)
}
if c.deltaState != nil && c.deltaState.Valid() {
row := c.deltaState.Slice(mlx.Slice(idx, idx+1), mlx.Slice(), mlx.Slice(), mlx.Slice())
snap.deltaState = mlx.Contiguous(row, false)
mlx.Pin(snap.deltaState)
}
mlx.AsyncEval(snap.convState, snap.deltaState)
return snap
}
func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
func (c *RecurrentCache) Restore(seqID int, snapshot Snapshot, target int) bool {
idx := c.seqIndex(seqID)
if idx < 0 {
return false
}
if snapshot == nil {
// Recurrent state is cumulative and can't rewind. Only succeed
// if we're already at the target (no-op).
return target == c.offset
return target == c.seqOffsets[idx]
}
snap := snapshot.(*recurrentSnapshot)
// Recurrent snapshots encode cumulative state up to exactly
// snap.offset. Target must match — rewinding would leave stale
// state, and advancing isn't possible without feeding tokens.
if target != snap.offset {
return false
}
c.convState = c.setState(c.convState, snap.convState, false)
c.deltaState = c.setState(c.deltaState, snap.deltaState, false)
c.offset = snap.offset
if snap.convState != nil {
if c.convState == nil {
c.ensure(len(c.seqOrder), snap.convState.DType())
}
c.convState.Set(c.convState.SliceUpdate(snap.convState,
mlx.Slice(idx, idx+1), mlx.Slice(), mlx.Slice()))
}
if snap.deltaState != nil {
if c.deltaState == nil {
c.ensure(len(c.seqOrder), snap.deltaState.DType())
}
c.deltaState.Set(c.deltaState.SliceUpdate(snap.deltaState,
mlx.Slice(idx, idx+1), mlx.Slice(), mlx.Slice(), mlx.Slice()))
}
c.seqOffsets[idx] = snap.offset
return true
}
func (c *RecurrentCache) Merge(parent, child Snapshot) Snapshot {
// Recurrent snapshots are self-contained — child supersedes parent.
if parent != nil {
parent.Close()
}
@@ -160,15 +273,86 @@ func (c *RecurrentCache) Merge(parent, child Snapshot) Snapshot {
}
// Split returns (nil, snapshot): recurrent snapshots cannot be divided at a
// position, so the caller receives no parent and the whole snapshot as child.
func (c *RecurrentCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
	// Recurrent state is cumulative and not position-sliceable.
	// Cannot recover intermediate state at the split point.
	return nil, snapshot
}
func (c *RecurrentCache) Free() {
mlx.Unpin(c.convState, c.deltaState)
c.convState, c.deltaState = nil, nil
c.offset = 0
// Preserve sequence registration — callers may Restore after Free.
for i := range c.seqOffsets {
c.seqOffsets[i] = 0
}
}
func (c *RecurrentCache) Offsets() []int32 { return []int32{int32(c.offset)} }
// Offsets returns the processed-token count for each requested sequence, in
// the order given. Unregistered sequence IDs report 0 (the slice's zero value).
func (c *RecurrentCache) Offsets(seqIDs ...int) []int32 {
	offsets := make([]int32, len(seqIDs))
	for i, seqID := range seqIDs {
		idx := c.seqIndex(seqID)
		if idx >= 0 {
			offsets[i] = int32(c.seqOffsets[idx])
		}
	}
	return offsets
}
// SetSeqs reconciles the cache's registered sequences with seqIDs: sequences
// absent from seqIDs are dropped (their pool rows discarded), new IDs are
// appended with a zero offset, and survivors keep their relative order and
// accumulated offsets. Pool tensors are compacted/regrown to match.
func (c *RecurrentCache) SetSeqs(seqIDs []int) {
	wanted := make(map[int]bool, len(seqIDs))
	for _, id := range seqIDs {
		wanted[id] = true
	}
	// Fast path: same set of IDs as currently registered — nothing to do.
	changed := len(seqIDs) != len(c.seqOrder)
	if !changed {
		for _, id := range c.seqOrder {
			if !wanted[id] {
				changed = true
				break
			}
		}
	}
	if !changed {
		return
	}
	// Build new order: preserve existing order for survivors, then append new
	newOrder := make([]int, 0, len(seqIDs))
	newOffsets := make([]int, 0, len(seqIDs))
	var survivingRows []int32
	for i, id := range c.seqOrder {
		if wanted[id] {
			survivingRows = append(survivingRows, int32(i))
			newOrder = append(newOrder, id)
			newOffsets = append(newOffsets, c.seqOffsets[i])
		}
	}
	added := make(map[int]bool, len(newOrder))
	for _, id := range newOrder {
		added[id] = true
	}
	for _, id := range seqIDs {
		if !added[id] {
			// Brand-new sequence: starts with no processed tokens.
			newOrder = append(newOrder, id)
			newOffsets = append(newOffsets, 0)
		}
	}
	c.seqOrder = newOrder
	c.seqOffsets = newOffsets
	// Rebuild pool tensor if it exists
	if c.convState != nil && c.convState.Valid() {
		dtype := c.convState.DType()
		if len(survivingRows) == 0 {
			// No survivors: release both pools entirely.
			mlx.Unpin(c.convState, c.deltaState)
			c.convState, c.deltaState = nil, nil
		} else if len(survivingRows) != c.convState.Dim(0) {
			// Gather surviving rows to the front, dropping removed ones.
			takeIdx := mlx.NewArrayInt32(survivingRows, []int32{int32(len(survivingRows))})
			c.convState = c.setState(c.convState, c.convState.TakeAxis(takeIdx, 0))
			c.deltaState = c.setState(c.deltaState, c.deltaState.TakeAxis(takeIdx, 0))
		}
		if len(c.seqOrder) > 0 {
			// Grow back to one pool row per registered sequence.
			c.ensure(len(c.seqOrder), dtype)
		}
	}
}

View File

@@ -13,30 +13,32 @@ import (
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
skipIfNoMLX(t)
c := NewRecurrentCache(3, 12, 4, 8, 8)
c.SetSeqs([]int{0})
b := &batch.ForwardBatch{SeqIDs: []int{0}, SeqLens: []int{1}}
_, _ = c.ConvState(b, mlx.DTypeFloat16)
_, _ = c.DeltaState(b, mlx.DTypeFloat16)
c.Advance(&batch.ForwardBatch{SeqIDs: []int{0}, SeqLens: []int{10}})
snap := c.Snapshot(0) // snap.offset == 10
snap := c.Snapshot(0, 0) // snap.offset == 10
c.Advance(&batch.ForwardBatch{SeqIDs: []int{0}, SeqLens: []int{5}}) // cache now at 15
// target < snap.offset: fails (can't rewind past snapshot)
if c.Restore(snap, 5) {
if c.Restore(0, snap, 5) {
t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
}
// target > snap.offset: fails (can't advance without feeding tokens)
if c.Restore(snap, 15) {
if c.Restore(0, snap, 15) {
t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
}
// target == snap.offset: succeeds
if !c.Restore(snap, 10) {
if !c.Restore(0, snap, 10) {
t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
}
if int(c.Offsets()[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
if int(c.Offsets(0)[0]) != 10 {
t.Fatalf("offset = %d, want 10", int(c.Offsets(0)[0]))
}
}

View File

@@ -56,14 +56,15 @@ func (c *fakeRewindableCache) feed(tokens []int32) {
func (c *fakeRewindableCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
return nil, nil, mlx.KVHistory{}
}
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
func (c *fakeRewindableCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
func (c *fakeRewindableCache) Offsets(_ ...int) []int32 { return []int32{int32(len(c.tokens))} }
func (c *fakeRewindableCache) Free() {
c.tokens = nil
}
func (c *fakeRewindableCache) SetSeqs(seqIDs []int) {}
func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
func (c *fakeRewindableCache) Snapshot(seqID int, fromOffset int) cache.Snapshot {
if fromOffset >= len(c.tokens) {
return nil
}
@@ -80,7 +81,7 @@ func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
return s
}
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
func (c *fakeRewindableCache) Restore(seqID int, snapshot cache.Snapshot, target int) bool {
if target < 0 {
return false
}
@@ -176,14 +177,15 @@ func (c *fakeSlidingWindowCache) feed(tokens []int32) {
func (c *fakeSlidingWindowCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
return nil, nil, mlx.KVHistory{}
}
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
func (c *fakeSlidingWindowCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
func (c *fakeSlidingWindowCache) Offsets(_ ...int) []int32 { return []int32{int32(len(c.tokens))} }
func (c *fakeSlidingWindowCache) Free() {
c.tokens = nil
}
func (c *fakeSlidingWindowCache) SetSeqs(seqIDs []int) {}
func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
func (c *fakeSlidingWindowCache) Snapshot(seqID int, fromOffset int) cache.Snapshot {
if len(c.tokens) == 0 || len(c.tokens) <= fromOffset {
return nil
}
@@ -197,7 +199,7 @@ func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
return s
}
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
func (c *fakeSlidingWindowCache) Restore(seqID int, snapshot cache.Snapshot, target int) bool {
if target < 0 {
return false
}
@@ -256,14 +258,15 @@ func (c *fakeRecurrentCache) feed(tokens []int32) {
func (c *fakeRecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
return nil, nil, mlx.KVHistory{}
}
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
func (c *fakeRecurrentCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
func (c *fakeRecurrentCache) Offsets(_ ...int) []int32 { return []int32{int32(len(c.tokens))} }
func (c *fakeRecurrentCache) Free() {
c.tokens = nil
}
func (c *fakeRecurrentCache) SetSeqs(seqIDs []int) {}
func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
func (c *fakeRecurrentCache) Snapshot(seqID int, fromOffset int) cache.Snapshot {
// Recurrent state is cumulative; snapshot captures the full state.
if len(c.tokens) == 0 {
return nil
@@ -277,7 +280,7 @@ func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
return s
}
func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
func (c *fakeRecurrentCache) Restore(seqID int, snapshot cache.Snapshot, target int) bool {
if snapshot == nil {
return target == len(c.tokens) // can only no-op
}
@@ -367,9 +370,9 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32)
for i, c := range e.caches {
assertTokens(t, label, c, expected)
// Verify all caches report the same offset.
if i > 0 && int(c.Offsets()[0]) != int(e.caches[0].Offsets()[0]) {
if i > 0 && int(c.Offsets(0)[0]) != int(e.caches[0].Offsets(0)[0]) {
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
label, i, int(c.Offsets()[0]), int(e.caches[0].Offsets()[0]))
label, i, int(c.Offsets(0)[0]), int(e.caches[0].Offsets(0)[0]))
}
}
}
@@ -452,9 +455,9 @@ func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
if len(kvc.caches) < 2 {
return
}
expected := int(kvc.caches[0].Offsets()[0])
expected := int(kvc.caches[0].Offsets(0)[0])
for i := 1; i < len(kvc.caches); i++ {
if got := int(kvc.caches[i].Offsets()[0]); got != expected {
if got := int(kvc.caches[i].Offsets(0)[0]); got != expected {
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected)
}
}

View File

@@ -137,7 +137,7 @@ func TestSplitNodeWithSnapshots(t *testing.T) {
child := root.children[0]
rc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
child.snapshots = []cache.Snapshot{rc.Snapshot(0)}
child.snapshots = []cache.Snapshot{rc.Snapshot(0, 0)}
child.user = true
caches := []cache.Cache{rc}

View File

@@ -492,7 +492,7 @@ func (a *Attention) Forward(x *mlx.Array, b *batch.ForwardBatch, c cache.Cache,
ropeTheta = cfg.RopeLocalBaseFreq
}
positions := batch.SequentialPositions(b, c.Offsets())
positions := batch.SequentialPositions(b, c.Offsets(b.SeqIDs...))
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)

View File

@@ -707,7 +707,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(b.InputIDs)
positions := batch.SequentialPositions(b, caches[0].Offsets())
positions := batch.SequentialPositions(b, caches[0].Offsets(b.SeqIDs...))
for i, layer := range m.Layers {
var c cache.Cache

View File

@@ -242,7 +242,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(b.InputIDs)
positions := batch.SequentialPositions(b, caches[0].Offsets())
positions := batch.SequentialPositions(b, caches[0].Offsets(b.SeqIDs...))
for i, layer := range m.Layers {
var c cache.Cache

View File

@@ -259,7 +259,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(b.InputIDs)
positions := batch.SequentialPositions(b, caches[0].Offsets())
positions := batch.SequentialPositions(b, caches[0].Offsets(b.SeqIDs...))
for i, layer := range m.Layers {
var c cache.Cache

View File

@@ -1321,7 +1321,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(b.InputIDs)
positions := batch.SequentialPositions(b, caches[0].Offsets())
positions := batch.SequentialPositions(b, caches[0].Offsets(b.SeqIDs...))
for i, layer := range m.Layers {
var c cache.Cache