mirror of
https://github.com/ollama/ollama.git
synced 2026-04-24 09:46:01 +02:00
mlxrunner: Cache.Update takes ForwardBatch and returns KVHistory
Signature changes from Update(k, v) to Update(batch, k, v) returning (k, v, KVHistory). KVCache returns a real page table mapping positions to buffer slots. RecurrentCache returns empty KVHistory from Update. Replace Cache.Offset() with Offsets() returning per-sequence offsets. Add KVHistory type to mlx package.
This commit is contained in:
41
x/mlxrunner/cache/cache.go
vendored
41
x/mlxrunner/cache/cache.go
vendored
@@ -2,15 +2,16 @@ package cache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||
Update(b *batch.ForwardBatch, keys, values *mlx.Array) (newKeys, newValues *mlx.Array, kv mlx.KVHistory)
|
||||
// State returns the cache-owned state roots that should be kept/evaluated.
|
||||
State() []*mlx.Array
|
||||
Free()
|
||||
Offset() int
|
||||
Offsets() []int32
|
||||
|
||||
// Snapshot copies cache state from fromOffset to current offset into
|
||||
// pinned VRAM arrays. The active cache is unchanged.
|
||||
@@ -49,7 +50,7 @@ func NewKVCache() *KVCache {
|
||||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
func (c *KVCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
|
||||
|
||||
prev := c.offset
|
||||
@@ -77,8 +78,17 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
|
||||
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
|
||||
|
||||
pt := make([]int32, c.offset)
|
||||
for i := range pt {
|
||||
pt[i] = int32(i)
|
||||
}
|
||||
|
||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
mlx.KVHistory{
|
||||
PageTable: mlx.NewArrayInt32(pt, []int32{1, int32(c.offset)}),
|
||||
SeqLens: []int{c.offset},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *KVCache) State() []*mlx.Array {
|
||||
@@ -143,7 +153,7 @@ func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
|
||||
|
||||
// Rewind to snapshot start, then feed snapshot data through Update.
|
||||
c.offset = snap.fromOffset
|
||||
c.Update(snap.keys, snap.values)
|
||||
c.Update(nil, snap.keys, snap.values)
|
||||
|
||||
// Clamp to target if needed (target may be less than full snapshot).
|
||||
if target < c.offset {
|
||||
@@ -226,7 +236,7 @@ func (c *KVCache) Free() {
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
func (c *KVCache) Offsets() []int32 { return []int32{int32(c.offset)} }
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
|
||||
type RotatingKVCache struct {
|
||||
@@ -240,11 +250,24 @@ func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
||||
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
func (c *RotatingKVCache) Update(_ *batch.ForwardBatch, keys, values *mlx.Array) (*mlx.Array, *mlx.Array, mlx.KVHistory) {
|
||||
var k, v *mlx.Array
|
||||
if keys.Dim(2) > 1 {
|
||||
return c.concat(keys, values)
|
||||
k, v = c.concat(keys, values)
|
||||
} else {
|
||||
k, v = c.update(keys, values)
|
||||
}
|
||||
|
||||
visibleLen := min(c.offset, c.maxSize)
|
||||
pt := make([]int32, visibleLen)
|
||||
for i := range visibleLen {
|
||||
pt[i] = int32(i)
|
||||
}
|
||||
|
||||
return k, v, mlx.KVHistory{
|
||||
PageTable: mlx.NewArrayInt32(pt, []int32{1, int32(visibleLen)}),
|
||||
SeqLens: []int{visibleLen},
|
||||
}
|
||||
return c.update(keys, values)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
|
||||
|
||||
Reference in New Issue
Block a user