mlxrunner: support partial match on pure transformer caches

Previously, a partial match within a node's edge would truncate the path
to the parent snapshot - effectively making all cache types behave as
recurrent caches. Caches with only transformer layers can rewind to
arbitrary boundaries, so this change restores that capability to improve
cache hits.
This commit is contained in:
Jesse Gross
2026-03-19 11:20:50 -07:00
parent b166b36cd2
commit 77491439c2
5 changed files with 140 additions and 99 deletions

View File

@@ -93,21 +93,8 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
}
// Check for partial match within a node's edge — truncate path
// to the parent boundary. snapshot() will split the node and
// create the branch point during prefill when caches are ready.
partialMatch := false
if len(matchPath) > 1 {
lastNode := matchPath[len(matchPath)-1]
matchedInEdge := matched - lastNode.startOffset()
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
matchPath = matchPath[:len(matchPath)-1]
partialMatch = true
}
}
// Switch to the matched path, paging in/out as needed.
c.switchToPath(matchPath)
c.switchToPath(matchPath, matched)
// switchToPath aligns caches to a common offset
prefix := c.minCacheOffset()
@@ -116,7 +103,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
// Schedule a snapshot at the branch point during prefill so future
// requests diverging here can restore instead of re-evaluating.
var snapshotAt int
if partialMatch || (prefix == 0 && matched > 0) {
if prefix < matched {
snapshotAt = matched
}
@@ -142,7 +129,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
// switchToPath transitions from the current active path to a new path,
// paging out diverging segments and paging in the new path.
func (c *kvCache) switchToPath(newPath []*trieNode) {
func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
defer c.enforceEvictionPolicy()
// Find common ancestor index.
@@ -167,7 +154,10 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
// non-leaf nodes here would produce wrong results for non-rewindable
// caches (e.g. RecurrentCache) whose state reflects the leaf, not
// the intermediate boundary.
if leaf := len(c.activePath) - 1; leaf >= commonLen {
leaf := len(c.activePath) - 1
leafDiverges := leaf >= commonLen
leafNeedsRewind := matched < c.activePath[leaf].endOffset
if leafDiverges || leafNeedsRewind {
node := c.activePath[leaf]
if !node.hasAllSnapshots() {
fromOffset := node.startOffset()
@@ -184,14 +174,16 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
}
}
// Rewind each cache to the ancestor offset or free it. Freed
// caches (e.g. RecurrentCache that can't rewind) will be restored
// from snapshots during page-in.
// Rewind each cache to the target offset or free it. When matched
// falls within the ancestor's range (same-path case), we rewind
// directly to the match point. Otherwise we rewind to the ancestor
// and let page-in bring us forward to matched.
rewindTarget := min(ancestorOffset, matched)
for _, kv := range c.caches {
if kv == nil {
continue
}
if !kv.Restore(nil, ancestorOffset) {
if !kv.Restore(nil, rewindTarget) {
kv.Free()
}
}
@@ -199,10 +191,12 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
// Page in — walk the full new path, restoring from snapshots.
// Freed caches naturally pick up the first available snapshot.
// Caches already past a node skip it via offset check.
pageIn:
for _, node := range newPath {
if len(node.snapshots) == 0 {
if !node.hasSnapshots() {
continue
}
nodeTarget := min(node.endOffset, matched)
for j, kv := range c.caches {
if kv == nil {
continue
@@ -210,19 +204,18 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
if j >= len(node.snapshots) || node.snapshots[j] == nil {
continue
}
if kv.Offset() >= node.endOffset {
if kv.Offset() >= nodeTarget {
continue
}
if !kv.Restore(node.snapshots[j], node.endOffset) {
slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset())
c.freeAll()
c.activePath = []*trieNode{c.root}
return
if !kv.Restore(node.snapshots[j], nodeTarget) {
// Restore failed — stop page-in and let alignment
// bring all caches to a consistent offset.
break pageIn
}
}
if node.endOffset > ancestorOffset {
pageInCount++
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset))
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget))
}
}

View File

@@ -17,7 +17,8 @@ type Cache interface {
Snapshot(fromOffset int) Snapshot
// Restore brings the cache to target. If snapshot is nil, rewinds
// using the cache's own live state.
// using the cache's own live state. Returns false if the target is
// unreachable (e.g. target > current offset, or negative).
Restore(snapshot Snapshot, target int) bool
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
@@ -122,17 +123,21 @@ func (c *KVCache) Snapshot(fromOffset int) Snapshot {
}
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil {
// Rewind using live state — just clamp offset.
target = max(0, min(target, c.offset))
if target > c.offset {
return false
}
c.offset = target
return true
}
snap := snapshot.(*kvSnapshot)
// Check that the cache has data up to the snapshot's starting point.
if c.offset < snap.fromOffset {
if target > snap.toOffset || c.offset < snap.fromOffset {
return false
}
@@ -354,7 +359,14 @@ func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
}
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil {
if target >= c.offset {
return target == c.offset
}
// Live rewind is only safe when the buffer hasn't filled yet
// (offset <= maxSize). Once the window has shifted, rewinding
// leaves fewer than maxSize trailing tokens to attend to —
@@ -362,7 +374,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
if c.offset > c.maxSize {
return false
}
target = max(0, min(target, c.offset))
c.offset = target
c.idx = target
return true
@@ -370,6 +381,10 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
snap := snapshot.(*rotatingSnapshot)
if target > snap.toOffset {
return false
}
// Reject if clamping would leave an incomplete window.
if target < snap.toOffset && snap.toOffset > c.maxSize {
return false
@@ -388,7 +403,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
// Clamp to target if needed.
if target < c.offset {
target = max(0, target)
c.offset = target
c.idx = target
}

View File

@@ -150,10 +150,10 @@ func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
snap := snapshot.(*recurrentSnapshot)
// Recurrent state encodes all tokens up to snap.offset. Restoring
// to a target before that would leave stale state from tokens
// [target, snap.offset) baked in. Only allow restoring forward.
if target < snap.offset {
// Recurrent snapshots encode cumulative state up to exactly
// snap.offset. Target must match — rewinding would leave stale
// state, and advancing isn't possible without feeding tokens.
if target != snap.offset {
return false
}

View File

@@ -6,39 +6,35 @@ import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only
// allows restoring forward (target >= snapshot offset), never backward.
func TestRecurrentCacheRestoreDirectionality(t *testing.T) {
// TestRecurrentCacheRestoreExactOffset verifies that RecurrentCache restore
// only succeeds when target exactly matches the snapshot's offset. Recurrent
// state is cumulative, so it can't be rewound or fast-forwarded.
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
skipIfNoMLX(t)
c := NewRecurrentCache(3, 12, 4, 8, 8)
_ = c.ConvState(1, mlx.DTypeFloat16)
_ = c.DeltaState(1, mlx.DTypeFloat16)
c.Advance(10)
snap := c.Snapshot(0)
snap := c.Snapshot(0) // snap.offset == 10
c.Advance(5) // now at 15
c.Advance(5) // cache now at 15
// Restore backward should fail.
// target < snap.offset: fails (can't rewind past snapshot)
if c.Restore(snap, 5) {
t.Fatal("Restore(snap, 5) should fail — target < snap.offset")
t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
}
// Restore to exact snap offset should succeed.
// target > snap.offset: fails (can't advance without feeding tokens)
if c.Restore(snap, 15) {
t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
}
// target == snap.offset: succeeds
if !c.Restore(snap, 10) {
t.Fatal("Restore(snap, 10) should succeed")
t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
}
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c.Offset())
}
// Restore forward (target > snap offset) should succeed, offset = snap.offset.
snap2 := c.Snapshot(0)
if !c.Restore(snap2, 15) {
t.Fatal("Restore(snap, 15) should succeed")
}
// Recurrent state is at snap.offset (10), not target (15).
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset())
}
}

View File

@@ -79,20 +79,20 @@ func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
}
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil {
// Rewind live state.
if target < 0 {
target = 0
}
if target > len(c.tokens) {
target = len(c.tokens)
return false
}
c.tokens = c.tokens[:target]
return true
}
s := snapshot.(*fakeSnapshot)
if len(c.tokens) < s.from {
return false // don't have base data up to snapshot start
if target > s.to || len(c.tokens) < s.from {
return false
}
c.tokens = append(c.tokens[:s.from], s.tokens...)
if target < len(c.tokens) {
@@ -196,9 +196,13 @@ func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
}
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil {
if target == len(c.tokens) {
return true
if target >= len(c.tokens) {
return target == len(c.tokens)
}
// Live rewind only works when buffer hasn't filled (offset <= maxSize).
if len(c.tokens) > c.maxSize {
@@ -208,6 +212,14 @@ func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bo
return true
}
s := snapshot.(*fakeSnapshot)
if target > s.to {
return false
}
// Reject if clamping would leave an incomplete window
// (matches RotatingKVCache behavior).
if target < s.to && s.to > c.maxSize {
return false
}
c.tokens = slices.Clone(s.tokens)
if target < len(c.tokens) {
c.tokens = c.tokens[:target]
@@ -268,8 +280,8 @@ func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
return target == len(c.tokens) // can only no-op
}
s := snapshot.(*fakeSnapshot)
if target < s.to {
return false // can't go backward
if target != s.to {
return false // cumulative state requires exact match
}
c.tokens = slices.Clone(s.tokens)
return true
@@ -294,9 +306,10 @@ type feedableCache interface {
// testEnv encapsulates a kvCache and its fake caches for a test scenario.
type testEnv struct {
kvc *kvCache
caches []cache.Cache // typed references for assertions
tracker *snapshotTracker
kvc *kvCache
caches []cache.Cache // typed references for assertions
tracker *snapshotTracker
rewindable bool // true when all caches support arbitrary Restore(nil, target)
}
// newTransformerEnv creates a test environment with a single rewindable cache
@@ -305,23 +318,28 @@ func newTransformerEnv() *testEnv {
tracker := &snapshotTracker{}
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
return &testEnv{
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tracker,
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tracker,
rewindable: true,
}
}
// newSlidingWindowEnv creates a test environment with one rewindable cache and
// one sliding window cache (Mistral-style architecture).
// one sliding window cache (Mistral-style architecture). The sliding window
// maxSize is set small enough that test sequences fill it, making
// Restore(nil, target) fail — the same behavior as production models where
// the window fills after a few turns.
func newSlidingWindowEnv() *testEnv {
tr := &snapshotTracker{}
rc := &fakeRewindableCache{tracker: tr}
sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr}
sw := &fakeSlidingWindowCache{maxSize: 4, tracker: tr}
caches := []cache.Cache{rc, sw}
return &testEnv{
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tr,
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tr,
rewindable: false,
}
}
@@ -333,9 +351,10 @@ func newRecurrentEnv() *testEnv {
nrc := &fakeRecurrentCache{tracker: tr}
caches := []cache.Cache{rc, nrc}
return &testEnv{
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tr,
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tr,
rewindable: false,
}
}
@@ -590,15 +609,24 @@ func TestBranchCreationAndReuse(t *testing.T) {
}
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
// Partial match in A's edge triggers snapshotOffset.
// For rewindable caches, switchToPath rewinds to the match point
// so only the non-matching suffix needs evaluation. For non-rewindable
// caches (RecurrentCache), the rewind fails and freeAll fires.
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
if resB.snapshotOffset != 5 {
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
}
// Cache was rewound to 0 (partial match truncates path to root),
// so all tokens were re-evaluated.
if len(resB.remaining) != 8 {
t.Fatalf("B: remaining = %d, want 8", len(resB.remaining))
if env.rewindable {
if resB.snapshotOffset != 0 {
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
}
if len(resB.remaining) != 3 {
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
}
} else {
if resB.snapshotOffset != 5 {
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
}
if len(resB.remaining) != 8 {
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
}
}
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
@@ -635,14 +663,24 @@ func TestExactMatchSeedBehavior(t *testing.T) {
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
// Request B: identical prompt. Holdback means matched=4, partial in
// the 5-token edge, so path truncates to root and all tokens are
// re-evaluated. snapshotOffset should be set at the holdback point.
// the 5-token edge. For rewindable caches, switchToPath rewinds to
// offset 4, so only the held-back token needs re-evaluation. For
// non-rewindable caches, the rewind fails and freeAll fires.
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
if len(resB.remaining) != 5 {
t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining))
}
if resB.snapshotOffset != 4 {
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
if env.rewindable {
if len(resB.remaining) != 1 {
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
}
if resB.snapshotOffset != 0 {
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
}
} else {
if len(resB.remaining) != 5 {
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
}
if resB.snapshotOffset != 4 {
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
}
}
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})