// cache.go manages a shared KV cache across conversations using a compressed // prefix trie. Each trie node stores a token sequence (edge) and optional // per-layer snapshots that can be paged in/out of the live MLX cache arrays. // // Key properties: // - Only one path through the trie is "active" (backed by live MLX arrays) // at a time. Switching paths pages out the frontier node and pages in the // new path. // - Snapshots are only captured at the frontier (end) of the active path. // Intermediate node snapshots come from split prefill. // - All cache layers must stay at the same token offset. // - Sibling edges must not share a common token prefix (compressed trie // invariant). // - begin() always re-evaluates at least one token so the pipeline can seed // generation, even on a full prefix match. package mlxrunner import ( "cmp" "fmt" "log/slog" "time" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model/base" ) const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory type kvCache struct { root *trieNode // root of the prefix trie activePath []*trieNode // current root→leaf path with live MLX arrays caches []cache.Cache pagedOutBytes int64 // total bytes in paged-out snapshots across the trie } // cacheSession manages caches for a single pipeline run. // Callers should append generated tokens to outputs and // defer close to save the cache state. type cacheSession struct { cache *kvCache inputs []int32 outputs []int32 caches []cache.Cache remaining []int32 // snapshotOffset, if > 0, is a trie node boundary where we need to // capture a snapshot during prefill. This enables future requests // branching at this point to restore non-rewindable caches (e.g. // RecurrentCache) instead of re-evaluating from scratch. snapshotOffset int } func (c *kvCache) ensureCaches(m base.Model) { if len(c.caches) != 0 { return } if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok { c.caches = cacheFactory.NewCaches() return } c.caches = make([]cache.Cache, m.NumLayers()) for i := range c.caches { c.caches[i] = cache.NewKVCache() } } func (c *kvCache) ensureRoot() { if c.root == nil { c.root = &trieNode{ lastUsed: time.Now(), } c.activePath = []*trieNode{c.root} } } // begin prepares caches for a new request. It finds the nearest // matching cache or creates new caches if none match. func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { c.ensureCaches(m) c.ensureRoot() matchPath, matched := findBestMatch(c.root, inputs) originalMatched := matched // Always keep at least one token to re-evaluate so the // pipeline can seed token generation from it. if matched == len(inputs) && matched > 0 { matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1]) } // Check for partial match within a node's edge — truncate path // to the parent boundary. snapshot() will split the node and // create the branch point during prefill when caches are ready. partialMatch := false if len(matchPath) > 1 { lastNode := matchPath[len(matchPath)-1] matchedInEdge := matched - lastNode.startOffset() if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) { matchPath = matchPath[:len(matchPath)-1] partialMatch = true } } // Switch to the matched path, paging in/out as needed. c.switchToPath(matchPath) // switchToPath aligns caches to a common offset prefix := c.minCacheOffset() remaining := inputs[prefix:] // Schedule a snapshot at the branch point during prefill so future // requests diverging here can restore instead of re-evaluating. var snapshotAt int if partialMatch || (prefix == 0 && matched > 0) { snapshotAt = matched } args := []any{"total", len(inputs), "matched", originalMatched} args = append(args, "cached", prefix, "left", len(remaining)) if snapshotAt > 0 { args = append(args, "pending_snapshot", snapshotAt) } if prefix == 0 { slog.Info("cache miss", args...) } else { slog.Info("cache hit", args...) } return &cacheSession{ cache: c, inputs: inputs, snapshotOffset: snapshotAt, caches: c.caches, remaining: remaining, } } // switchToPath transitions from the current active path to a new path, // paging out diverging segments and paging in the new path. func (c *kvCache) switchToPath(newPath []*trieNode) { defer c.enforceEvictionPolicy() // Find common ancestor index. commonLen := 0 for commonLen < len(c.activePath) && commonLen < len(newPath) { if c.activePath[commonLen] != newPath[commonLen] { break } commonLen++ } ancestorOffset := 0 if commonLen > 0 { ancestorOffset = c.activePath[commonLen-1].endOffset } var pageOutCount, pageInCount int // Page out the leaf of the old path. Only the leaf's live cache // state is correct — intermediate nodes already have snapshots // captured during their creation (splitNode + prefill). Snapshotting // non-leaf nodes here would produce wrong results for non-rewindable // caches (e.g. RecurrentCache) whose state reflects the leaf, not // the intermediate boundary. if leaf := len(c.activePath) - 1; leaf >= commonLen { node := c.activePath[leaf] if !node.hasAllSnapshots() { fromOffset := node.startOffset() snaps := make([]cache.Snapshot, len(c.caches)) for j, kv := range c.caches { if kv == nil { continue } snaps[j] = kv.Snapshot(fromOffset) } node.setSnapshots(snaps, &c.pagedOutBytes) pageOutCount++ logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset)) } } // Rewind each cache to the ancestor offset or free it. Freed // caches (e.g. RecurrentCache that can't rewind) will be restored // from snapshots during page-in. for _, kv := range c.caches { if kv == nil { continue } if !kv.Restore(nil, ancestorOffset) { kv.Free() } } // Page in — walk the full new path, restoring from snapshots. // Freed caches naturally pick up the first available snapshot. // Caches already past a node skip it via offset check. for _, node := range newPath { if len(node.snapshots) == 0 { continue } for j, kv := range c.caches { if kv == nil { continue } if j >= len(node.snapshots) || node.snapshots[j] == nil { continue } if kv.Offset() >= node.endOffset { continue } if !kv.Restore(node.snapshots[j], node.endOffset) { slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset()) c.freeAll() c.activePath = []*trieNode{c.root} return } } if node.endOffset > ancestorOffset { pageInCount++ logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset)) } } // Align all caches to the minimum offset. c.activePath = newPath minOff := c.minCacheOffset() for _, kv := range c.caches { if kv != nil && kv.Offset() != minOff { if !kv.Restore(nil, minOff) { slog.Warn("failed to restore cache, freeing all caches", "offset", minOff) c.freeAll() break } } } for i := len(c.activePath) - 1; i >= 0; i-- { if c.activePath[i].endOffset <= minOff { c.activePath = c.activePath[:i+1] break } } if pageOutCount > 0 || pageInCount > 0 { slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount) } } // snapshot creates a snapshot at the current cache position. During prefill, // it is called at branch points (user=false) to create restore points for // future diverging requests and with user=true to mark an explicit reusable // restore point. func (s *cacheSession) snapshot(user bool) { c := s.cache cacheOffset := c.minCacheOffset() if cacheOffset <= 0 { return } // Clear pending intermediate snapshot if we've reached or passed it. if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset { s.snapshotOffset = 0 } // The last node in activePath is the frontier where caches are advancing. // cacheOffset is always >= its endOffset: begin() restores caches to this // boundary and prefill advances monotonically forward. frontier := c.activePath[len(c.activePath)-1] // If the frontier already ends at cacheOffset, just ensure it has snapshots. if frontier.endOffset == cacheOffset { if user { frontier.user = true } if !frontier.hasAllSnapshots() { s.attachSnapshots(frontier, cacheOffset) } return } if frontier.endOffset > cacheOffset { slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset) return } // Advance the trie to cacheOffset — find or create a node there. edgeTokens := append(s.inputs, s.outputs...)[frontier.endOffset:cacheOffset] frontier = c.advancePath(frontier, edgeTokens, cacheOffset) // Attach fresh snapshots from the live caches. Always use fresh // snapshots even if the node already has some (e.g. from splitNode's // Cache.Split which may be incomplete for non-splittable caches // like RecurrentCache). if user { frontier.user = true } s.attachSnapshots(frontier, cacheOffset) } // advancePath advances the active path from the current frontier by matching // tokens against existing trie children, splitting partial matches, and // appending any remaining tokens as new nodes. Returns the new frontier. func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode { // Check if existing children already cover some or all of tokens. // tokens may span multiple trie nodes when extending a previous run's // leaf and this snapshot now overlaps that same range. matchPath, matched := findBestMatch(frontier, tokens) // matchPath[0] is frontier itself; the rest are newly traversed nodes. remaining := tokens[matched:] // Check for a partial match within the last node's edge — if so, split it. if len(matchPath) > 1 { lastNode := matchPath[len(matchPath)-1] matchedInEdge := frontier.endOffset + matched - lastNode.startOffset() if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) { matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes) } } // Append traversed nodes (excluding frontier) to the active path. c.activePath = append(c.activePath, matchPath[1:]...) dest := matchPath[len(matchPath)-1] if len(remaining) > 0 { // Drop non-user snapshots so appendTokens can extend in-place // rather than creating a new child node. if len(dest.children) == 0 && !dest.user { dest.setSnapshots(nil, &c.pagedOutBytes) } newDest := dest.appendTokens(c.root, remaining, endOffset) if newDest != dest { c.activePath = append(c.activePath, newDest) } dest = newDest } return dest } // attachSnapshots attaches cache snapshots to a trie node at the given offset. // The node must be on the active path (and thus protected from eviction; // lastUsed is updated in close()). All non-nil caches must be at the same // offset (cacheOffset); a mismatch indicates a bug in the caller. func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) { c := s.cache if c.activePath[len(c.activePath)-1] != node { slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset) return } snaps := make([]cache.Snapshot, len(c.caches)) for i, kv := range c.caches { if kv != nil { if kv.Offset() != cacheOffset { panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset())) } snaps[i] = kv.Snapshot(node.startOffset()) } } node.setSnapshots(snaps, &c.pagedOutBytes) slog.Debug("created snapshot", "offset", cacheOffset) c.enforceEvictionPolicy() } // freeAll releases all cache layers. func (c *kvCache) freeAll() { for _, kv := range c.caches { if kv != nil { kv.Free() } } } func (c *kvCache) minCacheOffset() int { offset := 0 found := false for _, kv := range c.caches { if kv == nil { continue } if off := kv.Offset(); !found || off < offset { offset = off found = true } } return offset } // close saves the token state if the forward pass ran. func (s *cacheSession) close() { offset := s.cache.minCacheOffset() if offset <= 0 { return } arrays := make([]*mlx.Array, 0, 2*len(s.caches)) for _, kv := range s.caches { if kv == nil { continue } arrays = append(arrays, kv.State()...) } // Ensure that if we have run the forward pass and set the metadata // that we also actually have the data. mlx.AsyncEval(arrays...) // Advance the trie frontier with any newly generated tokens. c := s.cache if len(c.activePath) > 0 { frontier := c.activePath[len(c.activePath)-1] stored := append(s.inputs, s.outputs...) if offset > frontier.endOffset { newTokens := stored[frontier.endOffset:offset] c.advancePath(frontier, newTokens, offset) } now := time.Now() for _, node := range c.activePath { node.lastUsed = now } } } // enforceEvictionPolicy evicts eligible nodes until paged-out memory is within limits. func (c *kvCache) enforceEvictionPolicy() { if c.pagedOutBytes <= maxPagedOutBytes { return } activeSet := make(map[*trieNode]bool, len(c.activePath)) for _, n := range c.activePath { activeSet[n] = true } for c.pagedOutBytes > maxPagedOutBytes { var best *trieNode walkNodes(c.root, func(n *trieNode) bool { if n == c.root || activeSet[n] || !n.hasSnapshots() { return true } // Evict: oldest, then deepest, then largest. if best == nil || cmp.Or( n.lastUsed.Compare(best.lastUsed), cmp.Compare(best.endOffset, n.endOffset), cmp.Compare(best.snapshotBytes(), n.snapshotBytes()), ) < 0 { best = n } return true }) if best == nil { break } c.evictNode(best) } } // evictNode evicts a single node from the trie, freeing its snapshot memory. func (c *kvCache) evictNode(node *trieNode) { if len(node.children) == 0 { // Leaf: remove entirely. parent := node.parent kind := "evicting leaf" if node.user { kind = "evicting user snapshot" } slog.Debug(kind, "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) removeNode(node, &c.pagedOutBytes) // If parent is a regular (non-user-snapshot) node with one remaining child, auto-merge. if parent != nil && !parent.user && len(parent.children) == 1 && parent != c.root { logutil.Trace(fmt.Sprintf("auto-merging parent at offset %d with single child", parent.endOffset)) mergeWithChild(parent, c.caches, &c.pagedOutBytes) } } else if len(node.children) == 1 { // Interior snapshot node with one child: merge with child. slog.Debug("evicting snapshot node", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) mergeWithChild(node, c.caches, &c.pagedOutBytes) } else { // Multi-child branch point: drop snapshots but keep the node. slog.Debug("evicting branch snapshot", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) node.setSnapshots(nil, &c.pagedOutBytes) } } func (c *kvCache) dumpTree() { // Summary stats var cacheBytes int for _, kv := range c.caches { if kv == nil { continue } for _, a := range kv.State() { if a != nil { cacheBytes += a.NumBytes() } } } // Build active path set for marking. active := make(map[*trieNode]bool, len(c.activePath)) for _, n := range c.activePath { active[n] = true } var nodeCount, snapshotCount int var pagedBytes int64 var lines []string var dump func(n *trieNode, prefix string, isLast bool) dump = func(n *trieNode, prefix string, isLast bool) { if n == nil { return } nodeCount++ // Build connector var connector string if n.parent == nil { connector = "" } else if isLast { connector = prefix + "`-- " } else { connector = prefix + "|-- " } // Node label nodeBytes := n.snapshotBytes() pagedBytes += nodeBytes label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens)) if nodeBytes > 0 { label += " " + mlx.PrettyBytes(int(nodeBytes)).String() } var flags []string if n.user { flags = append(flags, "user") } if n.hasAllSnapshots() { snapshotCount++ flags = append(flags, "snap") } if active[n] { flags = append(flags, "active") } if len(flags) > 0 { label += " (" + flags[0] for _, f := range flags[1:] { label += ", " + f } label += ")" } lines = append(lines, connector+label) // Recurse children childPrefix := prefix if n.parent != nil { if isLast { childPrefix += " " } else { childPrefix += "| " } } for i, child := range n.children { dump(child, childPrefix, i == len(n.children)-1) } } dump(c.root, "", true) offset := c.minCacheOffset() logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d", offset, mlx.PrettyBytes(cacheBytes), mlx.PrettyBytes(int(pagedBytes)), nodeCount, snapshotCount)) for i, l := range lines { if i == 0 { logutil.Trace("cache trie: " + l) } else { logutil.Trace(" " + l) } } }