mirror of
https://github.com/ollama/ollama.git
synced 2026-04-26 18:55:53 +02:00
mlxrunner: Cache.Update takes ForwardBatch and returns KVHistory
Signature changes from Update(k, v) to Update(batch, k, v) returning (k, v, KVHistory). KVCache returns a real page table mapping positions to buffer slots. RecurrentCache returns empty KVHistory from Update. Replace Cache.Offset() with Offsets() returning per-sequence offsets. Add KVHistory type to mlx package.
This commit is contained in:
@@ -205,7 +205,7 @@ pageIn:
|
||||
if j >= len(node.snapshots) || node.snapshots[j] == nil {
|
||||
continue
|
||||
}
|
||||
if kv.Offset() >= nodeTarget {
|
||||
if int(kv.Offsets()[0]) >= nodeTarget {
|
||||
continue
|
||||
}
|
||||
if !kv.Restore(node.snapshots[j], nodeTarget) {
|
||||
@@ -224,7 +224,7 @@ pageIn:
|
||||
c.activePath = newPath
|
||||
minOff := c.minCacheOffset()
|
||||
for _, kv := range c.caches {
|
||||
if kv != nil && kv.Offset() != minOff {
|
||||
if kv != nil && int(kv.Offsets()[0]) != minOff {
|
||||
if !kv.Restore(nil, minOff) {
|
||||
slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
|
||||
c.freeAll()
|
||||
@@ -390,8 +390,8 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
|
||||
snaps := make([]cache.Snapshot, len(c.caches))
|
||||
for i, kv := range c.caches {
|
||||
if kv != nil {
|
||||
if kv.Offset() != cacheOffset {
|
||||
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset()))
|
||||
if int(kv.Offsets()[0]) != cacheOffset {
|
||||
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, int(kv.Offsets()[0])))
|
||||
}
|
||||
snaps[i] = kv.Snapshot(node.startOffset())
|
||||
}
|
||||
@@ -418,7 +418,7 @@ func (c *kvCache) minCacheOffset() int {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if off := kv.Offset(); !found || off < offset {
|
||||
if off := int(kv.Offsets()[0]); !found || off < offset {
|
||||
offset = off
|
||||
found = true
|
||||
}
|
||||
|
||||
41
x/mlxrunner/cache/cache.go
vendored
41
x/mlxrunner/cache/cache.go
vendored
@@ -2,15 +2,16 @@ package cache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||
Update(b *batch.ForwardBatch, keys, values *mlx.Array) (newKeys, newValues *mlx.Array, kv mlx.KVHistory)
|
||||
// State returns the cache-owned state roots that should be kept/evaluated.
|
||||
State() []*mlx.Array
|
||||
Free()
|
||||
Offset() int
|
||||
Offsets() []int32
|
||||
|
||||
// Snapshot copies cache state from fromOffset to current offset into
|
||||
// pinned VRAM arrays. The active cache is unchanged.
|
||||
@@ -49,7 +50,7 @@ func NewKVCache() *KVCache {
|
||||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
func (c *KVCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
|
||||
|
||||
prev := c.offset
|
||||
@@ -77,8 +78,17 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
|
||||
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
|
||||
|
||||
pt := make([]int32, c.offset)
|
||||
for i := range pt {
|
||||
pt[i] = int32(i)
|
||||
}
|
||||
|
||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
mlx.KVHistory{
|
||||
PageTable: mlx.NewArrayInt32(pt, []int32{1, int32(c.offset)}),
|
||||
SeqLens: []int{c.offset},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *KVCache) State() []*mlx.Array {
|
||||
@@ -143,7 +153,7 @@ func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
|
||||
|
||||
// Rewind to snapshot start, then feed snapshot data through Update.
|
||||
c.offset = snap.fromOffset
|
||||
c.Update(snap.keys, snap.values)
|
||||
c.Update(nil, snap.keys, snap.values)
|
||||
|
||||
// Clamp to target if needed (target may be less than full snapshot).
|
||||
if target < c.offset {
|
||||
@@ -226,7 +236,7 @@ func (c *KVCache) Free() {
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
func (c *KVCache) Offsets() []int32 { return []int32{int32(c.offset)} }
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
|
||||
type RotatingKVCache struct {
|
||||
@@ -240,11 +250,24 @@ func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
||||
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
func (c *RotatingKVCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
var k, v *mlx.Array
|
||||
if keys.Dim(2) > 1 {
|
||||
return c.concat(keys, values)
|
||||
k, v = c.concat(keys, values)
|
||||
} else {
|
||||
k, v = c.update(keys, values)
|
||||
}
|
||||
|
||||
visibleLen := min(c.offset, c.maxSize)
|
||||
pt := make([]int32, visibleLen)
|
||||
for i := range visibleLen {
|
||||
pt[i] = int32(i)
|
||||
}
|
||||
|
||||
return k, v, mlx.KVHistory{
|
||||
PageTable: mlx.NewArrayInt32(pt, []int32{1, int32(visibleLen)}),
|
||||
SeqLens: []int{visibleLen},
|
||||
}
|
||||
return c.update(keys, values)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
|
||||
|
||||
52
x/mlxrunner/cache/cache_test.go
vendored
52
x/mlxrunner/cache/cache_test.go
vendored
@@ -20,7 +20,7 @@ func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// Snapshot [5, 10).
|
||||
@@ -44,7 +44,7 @@ func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
@@ -57,8 +57,8 @@ func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
|
||||
if !c2.Restore(snap, 10) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c2.Offset())
|
||||
if int(c2.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
|
||||
}
|
||||
|
||||
// Verify State() returns arrays with correct sequence dimension.
|
||||
@@ -84,7 +84,7 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
@@ -98,8 +98,8 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
|
||||
if !c2.Restore(parent, 5) {
|
||||
t.Fatal("Restore(parent) failed")
|
||||
}
|
||||
if c2.Offset() != 5 {
|
||||
t.Fatalf("offset after parent = %d, want 5", c2.Offset())
|
||||
if int(c2.Offsets()[0]) != 5 {
|
||||
t.Fatalf("offset after parent = %d, want 5", int(c2.Offsets()[0]))
|
||||
}
|
||||
state := c2.State()
|
||||
if state[0].Dim(2) != 5 {
|
||||
@@ -110,8 +110,8 @@ func TestKVCacheSplitPreservesData(t *testing.T) {
|
||||
if !c2.Restore(child, 10) {
|
||||
t.Fatal("Restore(child) failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset after child = %d, want 10", c2.Offset())
|
||||
if int(c2.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset after child = %d, want 10", int(c2.Offsets()[0]))
|
||||
}
|
||||
state = c2.State()
|
||||
if state[0].Dim(2) != 10 {
|
||||
@@ -128,7 +128,7 @@ func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
@@ -142,8 +142,8 @@ func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
|
||||
if !c2.Restore(merged, 10) {
|
||||
t.Fatal("Restore(merged) failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c2.Offset())
|
||||
if int(c2.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c2.Offsets()[0]))
|
||||
}
|
||||
|
||||
state := c2.State()
|
||||
@@ -163,7 +163,7 @@ func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// Offset 3 is outside the window.
|
||||
@@ -182,7 +182,7 @@ func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
@@ -194,15 +194,15 @@ func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
|
||||
for range 5 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// Restore to offset 10.
|
||||
if !c.Restore(snap, 10) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||
if int(c.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
|
||||
}
|
||||
|
||||
state := c.State()
|
||||
@@ -228,10 +228,10 @@ func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
|
||||
for range 6 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
if c.Offset() != 6 {
|
||||
t.Fatalf("offset = %d, want 6", c.Offset())
|
||||
if int(c.Offsets()[0]) != 6 {
|
||||
t.Fatalf("offset = %d, want 6", int(c.Offsets()[0]))
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
@@ -240,25 +240,25 @@ func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
|
||||
for range 3 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// Restore to snapshot state.
|
||||
if !c.Restore(snap, 6) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c.Offset() != 6 {
|
||||
t.Fatalf("offset after restore = %d, want 6", c.Offset())
|
||||
if int(c.Offsets()[0]) != 6 {
|
||||
t.Fatalf("offset after restore = %d, want 6", int(c.Offsets()[0]))
|
||||
}
|
||||
|
||||
// Feed one more token. If idx was restored correctly, this should
|
||||
// produce a valid window of size 4 at offset 7.
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
c.Update(nil, k, v)
|
||||
|
||||
if c.Offset() != 7 {
|
||||
t.Fatalf("offset after post-restore update = %d, want 7", c.Offset())
|
||||
if int(c.Offsets()[0]) != 7 {
|
||||
t.Fatalf("offset after post-restore update = %d, want 7", int(c.Offsets()[0]))
|
||||
}
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
|
||||
11
x/mlxrunner/cache/recurrent.go
vendored
11
x/mlxrunner/cache/recurrent.go
vendored
@@ -1,6 +1,9 @@
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// RecurrentCache stores state for linear-recurrent layers.
|
||||
//
|
||||
@@ -87,8 +90,8 @@ func (c *RecurrentCache) Advance(n int) {
|
||||
c.offset += n
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return keys, values
|
||||
func (c *RecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
return keys, values, mlx.KVHistory{}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) State() []*mlx.Array {
|
||||
@@ -162,4 +165,4 @@ func (c *RecurrentCache) Free() {
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||
func (c *RecurrentCache) Offsets() []int32 { return []int32{int32(c.offset)} }
|
||||
|
||||
4
x/mlxrunner/cache/recurrent_test.go
vendored
4
x/mlxrunner/cache/recurrent_test.go
vendored
@@ -34,7 +34,7 @@ func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
|
||||
if !c.Restore(snap, 10) {
|
||||
t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
|
||||
}
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||
if int(c.Offsets()[0]) != 10 {
|
||||
t.Fatalf("offset = %d, want 10", int(c.Offsets()[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
@@ -52,11 +53,11 @@ func (c *fakeRewindableCache) feed(tokens []int32) {
|
||||
c.tokens = append(c.tokens, tokens...)
|
||||
}
|
||||
|
||||
func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return nil, nil
|
||||
func (c *fakeRewindableCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
return nil, nil, mlx.KVHistory{}
|
||||
}
|
||||
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeRewindableCache) Offset() int { return len(c.tokens) }
|
||||
func (c *fakeRewindableCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
|
||||
|
||||
func (c *fakeRewindableCache) Free() {
|
||||
c.tokens = nil
|
||||
@@ -172,11 +173,11 @@ func (c *fakeSlidingWindowCache) feed(tokens []int32) {
|
||||
c.tokens = append(c.tokens, tokens...)
|
||||
}
|
||||
|
||||
func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return nil, nil
|
||||
func (c *fakeSlidingWindowCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
return nil, nil, mlx.KVHistory{}
|
||||
}
|
||||
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) }
|
||||
func (c *fakeSlidingWindowCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
|
||||
|
||||
func (c *fakeSlidingWindowCache) Free() {
|
||||
c.tokens = nil
|
||||
@@ -252,11 +253,11 @@ func (c *fakeRecurrentCache) feed(tokens []int32) {
|
||||
c.tokens = append(c.tokens, tokens...)
|
||||
}
|
||||
|
||||
func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return nil, nil
|
||||
func (c *fakeRecurrentCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
return nil, nil, mlx.KVHistory{}
|
||||
}
|
||||
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) }
|
||||
func (c *fakeRecurrentCache) Offsets() []int32 { return []int32{int32(len(c.tokens))} }
|
||||
|
||||
func (c *fakeRecurrentCache) Free() {
|
||||
c.tokens = nil
|
||||
@@ -366,9 +367,9 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32)
|
||||
for i, c := range e.caches {
|
||||
assertTokens(t, label, c, expected)
|
||||
// Verify all caches report the same offset.
|
||||
if i > 0 && c.Offset() != e.caches[0].Offset() {
|
||||
if i > 0 && int(c.Offsets()[0]) != int(e.caches[0].Offsets()[0]) {
|
||||
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
|
||||
label, i, c.Offset(), e.caches[0].Offset())
|
||||
label, i, int(c.Offsets()[0]), int(e.caches[0].Offsets()[0]))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -451,9 +452,9 @@ func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
|
||||
if len(kvc.caches) < 2 {
|
||||
return
|
||||
}
|
||||
expected := kvc.caches[0].Offset()
|
||||
expected := int(kvc.caches[0].Offsets()[0])
|
||||
for i := 1; i < len(kvc.caches); i++ {
|
||||
if got := kvc.caches[i].Offset(); got != expected {
|
||||
if got := int(kvc.caches[i].Offsets()[0]); got != expected {
|
||||
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
13
x/mlxrunner/mlx/sdpa.go
Normal file
13
x/mlxrunner/mlx/sdpa.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package mlx
|
||||
|
||||
// KVHistory carries sequence metadata alongside K/V buffers for SDPA.
|
||||
// Page table and seq lens travel together — SDPA always needs both.
|
||||
type KVHistory struct {
|
||||
// PageTable maps (seqIdx, position) → slot index in the K/V buffer.
|
||||
// Shape: [numSeqs, maxSeqLen], int32. Unused entries are 0.
|
||||
PageTable *Array
|
||||
|
||||
// SeqLens is the history length per sequence (number of valid
|
||||
// entries in each row of PageTable).
|
||||
SeqLens []int
|
||||
}
|
||||
@@ -494,13 +494,13 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
offset = int(c.Offsets()[0])
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
k, v, _ = c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// MLX SDPA supports grouped-query attention directly (Q heads can be a
|
||||
|
||||
@@ -112,7 +112,7 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
offset = int(c.Offsets()[0])
|
||||
}
|
||||
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
@@ -124,7 +124,7 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con
|
||||
cachedL := L
|
||||
if c != nil {
|
||||
placeholderValues := mlx.ZerosF32([]int32{B, 1, L, 0})
|
||||
keys, _ = c.Update(keys, placeholderValues)
|
||||
keys, _, _ = c.Update(nil, keys, placeholderValues)
|
||||
cachedL = int32(keys.Dim(2))
|
||||
}
|
||||
|
||||
|
||||
@@ -298,13 +298,13 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
offset = int(c.Offsets()[0])
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
k, v, _ = c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// MLX SDPA supports grouped-query attention directly (Q heads can be a
|
||||
|
||||
@@ -317,13 +317,13 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
offset = int(c.Offsets()[0])
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
k, v, _ = c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
// MLX SDPA supports grouped-query attention directly (Q heads can be a
|
||||
|
||||
@@ -1148,13 +1148,13 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
offset = int(c.Offsets()[0])
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
k, v, _ = c.Update(nil, k, v)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||
|
||||
Reference in New Issue
Block a user