diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index c1c53b668..49bf56fca 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -205,7 +205,7 @@ pageIn: if j >= len(node.snapshots) || node.snapshots[j] == nil { continue } - if kv.Offset() >= nodeTarget { + if int(kv.Offsets()[0]) >= nodeTarget { continue } if !kv.Restore(node.snapshots[j], nodeTarget) { @@ -224,7 +224,7 @@ pageIn: c.activePath = newPath minOff := c.minCacheOffset() for _, kv := range c.caches { - if kv != nil && kv.Offset() != minOff { + if kv != nil && int(kv.Offsets()[0]) != minOff { if !kv.Restore(nil, minOff) { slog.Warn("failed to restore cache, freeing all caches", "offset", minOff) c.freeAll() @@ -390,8 +390,8 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) { snaps := make([]cache.Snapshot, len(c.caches)) for i, kv := range c.caches { if kv != nil { - if kv.Offset() != cacheOffset { - panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset())) + if int(kv.Offsets()[0]) != cacheOffset { + panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, int(kv.Offsets()[0]))) } snaps[i] = kv.Snapshot(node.startOffset()) } @@ -418,7 +418,7 @@ func (c *kvCache) minCacheOffset() int { if kv == nil { continue } - if off := kv.Offset(); !found || off < offset { + if off := int(kv.Offsets()[0]); !found || off < offset { offset = off found = true } diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 39f5c1f5a..d0e72f8f0 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -2,15 +2,16 @@ package cache import ( "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/mlx" ) type Cache interface { - Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array) + Update(b *batch.ForwardBatch, keys, values *mlx.Array) (newKeys, newValues *mlx.Array, kv mlx.KVHistory) // State returns the cache-owned state roots that should be kept/evaluated. State() []*mlx.Array Free() - Offset() int + Offsets() []int32 // Snapshot copies cache state from fromOffset to current offset into // pinned VRAM arrays. The active cache is unchanged. @@ -49,7 +50,7 @@ func NewKVCache() *KVCache { return &KVCache{step: 256} } -func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { +func (c *KVCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) { B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3) prev := c.offset @@ -77,8 +78,17 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice())) c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice())) + pt := make([]int32, c.offset) + for i := range pt { + pt[i] = int32(i) + } + return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), - c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) + c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), + mlx.KVHistory{ + PageTable: mlx.NewArrayInt32(pt, []int32{1, int32(c.offset)}), + SeqLens: []int{c.offset}, + } } func (c *KVCache) State() []*mlx.Array { @@ -143,7 +153,7 @@ func (c *KVCache) Restore(snapshot Snapshot, target int) bool { // Rewind to snapshot start, then feed snapshot data through Update. c.offset = snap.fromOffset - c.Update(snap.keys, snap.values) + c.Update(nil, snap.keys, snap.values) // Clamp to target if needed (target may be less than full snapshot). if target < c.offset { @@ -226,7 +236,7 @@ func (c *KVCache) Free() { c.offset = 0 } -func (c *KVCache) Offset() int { return c.offset } +func (c *KVCache) Offsets() []int32 { return []int32{int32(c.offset)} } // RotatingKVCache implements sliding window attention with bounded memory type RotatingKVCache struct { @@ -240,11 +250,24 @@ func NewRotatingKVCache(maxSize int) *RotatingKVCache { return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()} } -func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { +func (c *RotatingKVCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) { + var k, v *mlx.Array if keys.Dim(2) > 1 { - return c.concat(keys, values) + k, v = c.concat(keys, values) + } else { + k, v = c.update(keys, values) + } + + visibleLen := min(c.offset, c.maxSize) + pt := make([]int32, visibleLen) + for i := range visibleLen { + pt[i] = int32(i) + } + + return k, v, mlx.KVHistory{ + PageTable: mlx.NewArrayInt32(pt, []int32{1, int32(visibleLen)}), + SeqLens: []int{visibleLen}, } - return c.update(keys, values) } func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) { diff --git a/x/mlxrunner/cache/cache_test.go b/x/mlxrunner/cache/cache_test.go index 86c26004a..1c81e828c 100644 --- a/x/mlxrunner/cache/cache_test.go +++ b/x/mlxrunner/cache/cache_test.go @@ -20,7 +20,7 @@ func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) { for range 10 { k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) - c.Update(k, v) + c.Update(nil, k, v) } // Snapshot [5, 10). @@ -44,7 +44,7 @@ func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) { for range 10 { k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) - c.Update(k, v) + c.Update(nil, k, v) } snap := c.Snapshot(0) @@ -57,8 +57,8 @@ func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) { if !c2.Restore(snap, 10) { t.Fatal("Restore failed") } - if c2.Offset() != 10 { - t.Fatalf("offset = %d, want 10", c2.Offset()) + if int(c2.Offsets()[0]) != 10 { + t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0])) } // Verify State() returns arrays with correct sequence dimension. @@ -84,7 +84,7 @@ func TestKVCacheSplitPreservesData(t *testing.T) { for range 10 { k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) - c.Update(k, v) + c.Update(nil, k, v) } snap := c.Snapshot(0) @@ -98,8 +98,8 @@ func TestKVCacheSplitPreservesData(t *testing.T) { if !c2.Restore(parent, 5) { t.Fatal("Restore(parent) failed") } - if c2.Offset() != 5 { - t.Fatalf("offset after parent = %d, want 5", c2.Offset()) + if int(c2.Offsets()[0]) != 5 { + t.Fatalf("offset after parent = %d, want 5", int(c2.Offsets()[0])) } state := c2.State() if state[0].Dim(2) != 5 { @@ -110,8 +110,8 @@ func TestKVCacheSplitPreservesData(t *testing.T) { if !c2.Restore(child, 10) { t.Fatal("Restore(child) failed") } - if c2.Offset() != 10 { - t.Fatalf("offset after child = %d, want 10", c2.Offset()) + if int(c2.Offsets()[0]) != 10 { + t.Fatalf("offset after child = %d, want 10", int(c2.Offsets()[0])) } state = c2.State() if state[0].Dim(2) != 10 { @@ -128,7 +128,7 @@ func TestKVCacheSplitMergeRoundTripData(t *testing.T) { for range 10 { k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) - c.Update(k, v) + c.Update(nil, k, v) } snap := c.Snapshot(0) @@ -142,8 +142,8 @@ func TestKVCacheSplitMergeRoundTripData(t *testing.T) { if !c2.Restore(merged, 10) { t.Fatal("Restore(merged) failed") } - if c2.Offset() != 10 { - t.Fatalf("offset = %d, want 10", c2.Offset()) + if int(c2.Offsets()[0]) != 10 { + t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0])) } state := c2.State() @@ -163,7 +163,7 @@ func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) { for range 10 { k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) - c.Update(k, v) + c.Update(nil, k, v) } // Offset 3 is outside the window. @@ -182,7 +182,7 @@ func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) { for range 10 { k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) - c.Update(k, v) + c.Update(nil, k, v) } snap := c.Snapshot(0) @@ -194,15 +194,15 @@ func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) { for range 5 { k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) - c.Update(k, v) + c.Update(nil, k, v) } // Restore to offset 10. if !c.Restore(snap, 10) { t.Fatal("Restore failed") } - if c.Offset() != 10 { - t.Fatalf("offset = %d, want 10", c.Offset()) + if int(c.Offsets()[0]) != 10 { + t.Fatalf("offset = %d, want 10", int(c.Offsets()[0])) } state := c.State() @@ -228,10 +228,10 @@ func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) { for range 6 { k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) - c.Update(k, v) + c.Update(nil, k, v) } - if c.Offset() != 6 { - t.Fatalf("offset = %d, want 6", c.Offset()) + if int(c.Offsets()[0]) != 6 { + t.Fatalf("offset = %d, want 6", int(c.Offsets()[0])) } snap := c.Snapshot(0) @@ -240,25 +240,25 @@ func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) { for range 3 { k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) - c.Update(k, v) + c.Update(nil, k, v) } // Restore to snapshot state. if !c.Restore(snap, 6) { t.Fatal("Restore failed") } - if c.Offset() != 6 { - t.Fatalf("offset after restore = %d, want 6", c.Offset()) + if int(c.Offsets()[0]) != 6 { + t.Fatalf("offset after restore = %d, want 6", int(c.Offsets()[0])) } // Feed one more token. If idx was restored correctly, this should // produce a valid window of size 4 at offset 7. k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8) - c.Update(k, v) + c.Update(nil, k, v) - if c.Offset() != 7 { - t.Fatalf("offset after post-restore update = %d, want 7", c.Offset()) + if int(c.Offsets()[0]) != 7 { + t.Fatalf("offset after post-restore update = %d, want 7", int(c.Offsets()[0])) } state := c.State() if len(state) != 2 { diff --git a/x/mlxrunner/cache/recurrent.go b/x/mlxrunner/cache/recurrent.go index 0f0016539..357bf9959 100644 --- a/x/mlxrunner/cache/recurrent.go +++ b/x/mlxrunner/cache/recurrent.go @@ -1,6 +1,9 @@ package cache -import "github.com/ollama/ollama/x/mlxrunner/mlx" +import ( + "github.com/ollama/ollama/x/mlxrunner/batch" + "github.com/ollama/ollama/x/mlxrunner/mlx" +) // RecurrentCache stores state for linear-recurrent layers. // @@ -87,8 +90,8 @@ func (c *RecurrentCache) Advance(n int) { c.offset += n } -func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { - return keys, values +func (c *RecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) { + return keys, values, mlx.KVHistory{} } func (c *RecurrentCache) State() []*mlx.Array { @@ -162,4 +165,4 @@ func (c *RecurrentCache) Free() { c.offset = 0 } -func (c *RecurrentCache) Offset() int { return c.offset } +func (c *RecurrentCache) Offsets() []int32 { return []int32{int32(c.offset)} } diff --git a/x/mlxrunner/cache/recurrent_test.go b/x/mlxrunner/cache/recurrent_test.go index ef8b7f7a3..9c5428b2b 100644 --- a/x/mlxrunner/cache/recurrent_test.go +++ b/x/mlxrunner/cache/recurrent_test.go @@ -34,7 +34,7 @@ func TestRecurrentCacheRestoreExactOffset(t *testing.T) { if !c.Restore(snap, 10) { t.Fatal("Restore(snap, 10) should succeed — target == snap.offset") } - if c.Offset() != 10 { - t.Fatalf("offset = %d, want 10", c.Offset()) + if int(c.Offsets()[0]) != 10 { + t.Fatalf("offset = %d, want 10", int(c.Offsets()[0])) } } diff --git a/x/mlxrunner/cache_test.go b/x/mlxrunner/cache_test.go index fba0d4fbd..68d069f19 100644 --- a/x/mlxrunner/cache_test.go +++ b/x/mlxrunner/cache_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" ) @@ -52,11 +53,11 @@ func (c *fakeRewindableCache) feed(tokens []int32) { c.tokens = append(c.tokens, tokens...) } -func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { - return nil, nil +func (c *fakeRewindableCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) { + return nil, nil, mlx.KVHistory{} } func (c *fakeRewindableCache) State() []*mlx.Array { return nil } -func (c *fakeRewindableCache) Offset() int { return len(c.tokens) } +func (c *fakeRewindableCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} } func (c *fakeRewindableCache) Free() { c.tokens = nil @@ -172,11 +173,11 @@ func (c *fakeSlidingWindowCache) feed(tokens []int32) { c.tokens = append(c.tokens, tokens...) } -func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { - return nil, nil +func (c *fakeSlidingWindowCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) { + return nil, nil, mlx.KVHistory{} } func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil } -func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) } +func (c *fakeSlidingWindowCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} } func (c *fakeSlidingWindowCache) Free() { c.tokens = nil @@ -252,11 +253,11 @@ func (c *fakeRecurrentCache) feed(tokens []int32) { c.tokens = append(c.tokens, tokens...) } -func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { - return nil, nil +func (c *fakeRecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) { + return nil, nil, mlx.KVHistory{} } func (c *fakeRecurrentCache) State() []*mlx.Array { return nil } -func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) } +func (c *fakeRecurrentCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} } func (c *fakeRecurrentCache) Free() { c.tokens = nil @@ -366,9 +367,9 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32) for i, c := range e.caches { assertTokens(t, label, c, expected) // Verify all caches report the same offset. - if i > 0 && c.Offset() != e.caches[0].Offset() { + if i > 0 && int(c.Offsets()[0]) != int(e.caches[0].Offsets()[0]) { t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", - label, i, c.Offset(), e.caches[0].Offset()) + label, i, int(c.Offsets()[0]), int(e.caches[0].Offsets()[0])) } } } @@ -451,9 +452,9 @@ func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) { if len(kvc.caches) < 2 { return } - expected := kvc.caches[0].Offset() + expected := int(kvc.caches[0].Offsets()[0]) for i := 1; i < len(kvc.caches); i++ { - if got := kvc.caches[i].Offset(); got != expected { + if got := int(kvc.caches[i].Offsets()[0]); got != expected { t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected) } } diff --git a/x/mlxrunner/mlx/sdpa.go b/x/mlxrunner/mlx/sdpa.go new file mode 100644 index 000000000..aa922ac6a --- /dev/null +++ b/x/mlxrunner/mlx/sdpa.go @@ -0,0 +1,13 @@ +package mlx + +// KVHistory carries sequence metadata alongside K/V buffers for SDPA. +// Page table and seq lens travel together — SDPA always needs both. +type KVHistory struct { + // PageTable maps (seqIdx, position) → slot index in the K/V buffer. + // Shape: [numSeqs, maxSeqLen], int32. Unused entries are 0. + PageTable *Array + + // SeqLens is the history length per sequence (number of valid + // entries in each row of PageTable). + SeqLens []int +} diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go index 0da2355fc..f297810ba 100644 --- a/x/models/gemma3/gemma3.go +++ b/x/models/gemma3/gemma3.go @@ -494,13 +494,13 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b offset := 0 if c != nil { - offset = c.Offset() + offset = int(c.Offsets()[0]) } q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset) k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset) if c != nil { - k, v = c.Update(k, v) + k, v, _ = c.Update(nil, k, v) } // MLX SDPA supports grouped-query attention directly (Q heads can be a diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index 8732bc418..9d9c7e3da 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -112,7 +112,7 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con offset := 0 if c != nil { - offset = c.Offset() + offset = int(c.Offsets()[0]) } qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset) kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset) @@ -124,7 +124,7 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con cachedL := L if c != nil { placeholderValues := mlx.ZerosF32([]int32{B, 1, L, 0}) - keys, _ = c.Update(keys, placeholderValues) + keys, _, _ = c.Update(nil, keys, placeholderValues) cachedL = int32(keys.Dim(2)) } diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go index 22a65d452..c1d53cb44 100644 --- a/x/models/llama/llama.go +++ b/x/models/llama/llama.go @@ -298,13 +298,13 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config offset := 0 if c != nil { - offset = c.Offset() + offset = int(c.Offsets()[0]) } q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) if c != nil { - k, v = c.Update(k, v) + k, v, _ = c.Update(nil, k, v) } // MLX SDPA supports grouped-query attention directly (Q heads can be a diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go index 60a5ca501..022eef4ce 100644 --- a/x/models/qwen3/qwen3.go +++ b/x/models/qwen3/qwen3.go @@ -317,13 +317,13 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config offset := 0 if c != nil { - offset = c.Offset() + offset = int(c.Offsets()[0]) } q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) if c != nil { - k, v = c.Update(k, v) + k, v, _ = c.Update(nil, k, v) } // MLX SDPA supports grouped-query attention directly (Q heads can be a diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go index 30740588f..9b08fc029 100644 --- a/x/models/qwen3_5/qwen3_5.go +++ b/x/models/qwen3_5/qwen3_5.go @@ -1148,13 +1148,13 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co offset := 0 if c != nil { - offset = c.Offset() + offset = int(c.Offsets()[0]) } q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) if c != nil { - k, v = c.Update(k, v) + k, v, _ = c.Update(nil, k, v) } out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)