mlxrunner: Cache.Update takes ForwardBatch and returns KVHistory

Signature changes from Update(k, v) to Update(batch, k, v) returning
(k, v, KVHistory). KVCache returns a real page table mapping positions
to buffer slots. RecurrentCache returns empty KVHistory from Update.

Replace Cache.Offset() with Offsets() returning per-sequence offsets.
Add KVHistory type to mlx package.
This commit is contained in:
Jesse Gross
2026-04-02 12:05:35 -07:00
parent 987f74c8a5
commit b7b2aa5d4e
12 changed files with 109 additions and 69 deletions

View File

@@ -317,13 +317,13 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
offset := 0
if c != nil {
offset = c.Offset()
offset = int(c.Offsets()[0])
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
if c != nil {
k, v = c.Update(k, v)
k, v, _ = c.Update(nil, k, v)
}
// MLX SDPA supports grouped-query attention directly (Q heads can be a