mlxrunner: Cache.Update takes ForwardBatch and returns KVHistory

Signature changes from Update(k, v) to Update(batch, k, v) returning
(k, v, KVHistory). KVCache returns a real page table mapping positions
to buffer slots. RecurrentCache returns empty KVHistory from Update.

Replace Cache.Offset() with Offsets() returning per-sequence offsets.
Add KVHistory type to mlx package.
This commit is contained in:
Jesse Gross
2026-04-02 12:05:35 -07:00
parent 987f74c8a5
commit b7b2aa5d4e
12 changed files with 109 additions and 69 deletions

13
x/mlxrunner/mlx/sdpa.go Normal file
View File

@@ -0,0 +1,13 @@
package mlx
// KVHistory carries sequence metadata alongside K/V buffers for SDPA.
// Page table and seq lens travel together — SDPA always needs both.
type KVHistory struct {
// PageTable maps (seqIdx, position) → slot index in the K/V buffer.
// Shape: [numSeqs, maxSeqLen], int32. Unused entries are 0.
PageTable *Array
// SeqLens is the history length per sequence (number of valid
// entries in each row of PageTable).
SeqLens []int
}