diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index 1ef52c5df..faabcc51f 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -238,6 +238,13 @@ pageIn: } } + // Update last-used time on only the final used node. For recurrent + // caches we don't need the intermediate snapshots and for KV caches + // we can reslice the data out of merged edges. + if len(c.activePath) > 0 { + c.activePath[len(c.activePath)-1].lastUsed = time.Now() + } + if pageOutCount > 0 || pageInCount > 0 { slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount) } @@ -355,6 +362,7 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) { } } node.setSnapshots(snaps, &c.pagedOutBytes) + node.lastUsed = time.Now() slog.Debug("created snapshot", "offset", cacheOffset) c.enforceEvictionPolicy() } @@ -412,10 +420,7 @@ func (s *cacheSession) close() { newTokens := stored[frontier.endOffset:offset] c.advancePath(frontier, newTokens, offset) } - now := time.Now() - for _, node := range c.activePath { - node.lastUsed = now - } + c.activePath[len(c.activePath)-1].lastUsed = time.Now() } } @@ -433,7 +438,7 @@ func (c *kvCache) enforceEvictionPolicy() { for c.pagedOutBytes > maxPagedOutBytes { var best *trieNode walkNodes(c.root, func(n *trieNode) bool { - if n == c.root || activeSet[n] || !n.hasSnapshots() { + if n == c.root || activeSet[n] || len(n.children) > 1 { return true } // Evict: oldest, then deepest, then largest. @@ -457,27 +462,16 @@ func (c *kvCache) enforceEvictionPolicy() { func (c *kvCache) evictNode(node *trieNode) { if len(node.children) == 0 { // Leaf: remove entirely. - parent := node.parent - kind := "evicting leaf" - if node.user { - kind = "evicting user snapshot" - } - slog.Debug(kind, "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) + slog.Debug("evicting leaf", "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) removeNode(node, &c.pagedOutBytes) - - // If parent is a regular (non-user-snapshot) node with one remaining child, auto-merge. - if parent != nil && !parent.user && len(parent.children) == 1 && parent != c.root { - logutil.Trace(fmt.Sprintf("auto-merging parent at offset %d with single child", parent.endOffset)) - mergeWithChild(parent, c.caches, &c.pagedOutBytes) - } } else if len(node.children) == 1 { - // Interior snapshot node with one child: merge with child. - slog.Debug("evicting snapshot node", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) + // Interior node with one child: merge with child. + before := c.pagedOutBytes + tokens := len(node.tokens) mergeWithChild(node, c.caches, &c.pagedOutBytes) + slog.Debug("evicting interior node", "offset", node.startOffset(), "tokens", tokens, "freed", mlx.PrettyBytes(int(before-c.pagedOutBytes))) } else { - // Multi-child branch point: drop snapshots but keep the node. - slog.Debug("evicting branch snapshot", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) - node.setSnapshots(nil, &c.pagedOutBytes) + panic("evictNode called on multi-child branch point") } } diff --git a/x/mlxrunner/cache_test.go b/x/mlxrunner/cache_test.go index 21e2f53ce..6c691c17b 100644 --- a/x/mlxrunner/cache_test.go +++ b/x/mlxrunner/cache_test.go @@ -3,6 +3,7 @@ package mlxrunner import ( "slices" "testing" + "time" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" @@ -761,8 +762,8 @@ func TestEvictionPreservesActiveConversations(t *testing.T) { t.Fatalf("activePath should have >= 2 nodes, got %d", len(kvc.activePath)) } - // System prompt prefix should still be findable (evicting a - // multi-child branch point only drops snapshots, not the node). + // System prompt prefix should still be findable (multi-child + // branch points are protected from eviction entirely). _, matched := findBestMatch(kvc.root, systemPrompt) if matched < len(systemPrompt) { t.Fatalf("system prompt match = %d, want %d", matched, len(systemPrompt)) @@ -895,3 +896,55 @@ func TestBranchSwitchRestoresCorrectState(t *testing.T) { checkTrieInvariants(t, kvc.root) }) } + +// TestLRUOnlyUpdatesUsedNodes verifies that intermediate nodes on the active +// path whose snapshots were not actually restored don't get their lastUsed +// refreshed, allowing them to age out and collapse. +func TestLRUOnlyUpdatesUsedNodes(t *testing.T) { + forEachEnv(t, func(t *testing.T, env *testEnv) { + kvc := env.kvc + + // Request A: creates path [1,2,3,4,5] + generate [10,11] + simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}) + + // Request B: diverges at token 4, creating a branch point at offset 3 + // with a split snapshot. + simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20, 21}) + + // Set all lastUsed to a known old time. + oldTime := time.Now().Add(-1 * time.Hour) + walkNodes(kvc.root, func(n *trieNode) bool { + n.lastUsed = oldTime + return true + }) + + // Request C: continue on B's branch. This will match B's path + // and extend it. The branch point's snapshot may be paged in + // for some cache types but not others. + beforeRequest := time.Now() + simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7, 20, 21, 30}, nil) + + // The path must have enough depth to exercise intermediate nodes. + if len(kvc.activePath) < 3 { + t.Fatalf("activePath too short to test intermediate nodes: got %d nodes", len(kvc.activePath)) + } + + // The frontier (deepest node on the active path) must be updated. + frontier := kvc.activePath[len(kvc.activePath)-1] + if frontier.lastUsed.Before(beforeRequest) { + t.Errorf("frontier lastUsed was not updated: got %v, want >= %v", + frontier.lastUsed, beforeRequest) + } + + // Every non-frontier node on the active path (including root) + // should retain its old lastUsed — only the frontier gets refreshed. + for i, node := range kvc.activePath[:len(kvc.activePath)-1] { + if !node.lastUsed.Before(beforeRequest) { + t.Errorf("activePath[%d] (endOffset=%d) lastUsed was refreshed: got %v, want < %v", + i, node.endOffset, node.lastUsed, beforeRequest) + } + } + + checkTrieInvariants(t, kvc.root) + }) +}