// cache.go manages a shared KV cache across conversations using a compressed // prefix trie. Each trie node stores a token sequence (edge) and optional // per-layer snapshots that can be paged in/out of the live MLX cache arrays. // // Key properties: // - Only one path through the trie is "active" (backed by live MLX arrays) // at a time. Switching paths pages out the frontier node and pages in the // new path. // - Snapshots are only captured at the frontier (end) of the active path. // Intermediate node snapshots come from split prefill. // - All cache layers must stay at the same token offset. // - Sibling edges must not share a common token prefix (compressed trie // invariant). // - begin() always re-evaluates at least one token so the pipeline can seed // generation, even on a full prefix match. package mlxrunner import ( "cmp" "fmt" "log/slog" "slices" "time" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model/base" ) const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory type kvCache struct { root *trieNode // root of the prefix trie activePath []*trieNode // current root→leaf path with live MLX arrays caches []cache.Cache pagedOutBytes int64 // total bytes in paged-out snapshots across the trie } // pendingSnapshot is a snapshot scheduled to be taken during prefill. type pendingSnapshot struct { offset int user bool } // cacheSession manages caches for a single pipeline run. // Callers should append generated tokens to outputs and // defer close to save the cache state. type cacheSession struct { cache *kvCache inputs []int32 outputs []int32 caches []cache.Cache remaining []int32 // pendingSnapshots lists offsets where snapshots should be captured // during prefill, sorted by offset. Entries are consumed as the // cache advances past them. pendingSnapshots []pendingSnapshot } func (c *kvCache) ensureCaches(m base.Model) { if len(c.caches) != 0 { return } if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok { c.caches = cacheFactory.NewCaches() return } c.caches = make([]cache.Cache, m.NumLayers()) for i := range c.caches { c.caches[i] = cache.NewKVCache() } } func (c *kvCache) ensureRoot() { if c.root == nil { c.root = &trieNode{ lastUsed: time.Now(), } c.activePath = []*trieNode{c.root} } } // begin prepares caches for a new request. It finds the nearest // matching cache or creates new caches if none match. func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { c.ensureCaches(m) c.ensureRoot() matchPath, matched := findBestMatch(c.root, inputs) originalMatched := matched // Always keep at least one token to re-evaluate so the // pipeline can seed token generation from it. if matched == len(inputs) && matched > 0 { matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1]) } // Switch to the matched path, paging in/out as needed. c.switchToPath(matchPath, matched) // switchToPath aligns caches to a common offset prefix := c.minCacheOffset() remaining := inputs[prefix:] session := &cacheSession{ cache: c, inputs: inputs, caches: c.caches, remaining: remaining, } // Schedule a snapshot at the branch point during prefill so future // requests diverging here can restore instead of re-evaluating. if prefix < matched { session.pendingSnapshots = append(session.pendingSnapshots, pendingSnapshot{offset: matched, user: false}) } msg := "cache hit" if prefix == 0 { msg = "cache miss" } slog.Info(msg, "total", len(inputs), "matched", originalMatched, "cached", prefix, "left", len(remaining)) return session } // switchToPath transitions from the current active path to a new path, // paging out diverging segments and paging in the new path. func (c *kvCache) switchToPath(newPath []*trieNode, matched int) { defer c.enforceEvictionPolicy() // Find common ancestor index. commonLen := 0 for commonLen < len(c.activePath) && commonLen < len(newPath) { if c.activePath[commonLen] != newPath[commonLen] { break } commonLen++ } ancestorOffset := 0 if commonLen > 0 { ancestorOffset = c.activePath[commonLen-1].endOffset } var pageOutCount, pageInCount int // Page out the leaf of the old path. Only the leaf's live cache // state is correct — intermediate nodes already have snapshots // captured during their creation (splitNode + prefill). Snapshotting // non-leaf nodes here would produce wrong results for non-rewindable // caches (e.g. RecurrentCache) whose state reflects the leaf, not // the intermediate boundary. leaf := len(c.activePath) - 1 leafDiverges := leaf >= commonLen leafNeedsRewind := matched < c.activePath[leaf].endOffset if leafDiverges || leafNeedsRewind { node := c.activePath[leaf] if !node.hasAllSnapshots() { fromOffset := node.startOffset() snaps := make([]cache.Snapshot, len(c.caches)) for j, kv := range c.caches { if kv == nil { continue } snaps[j] = kv.Snapshot(fromOffset) } node.setSnapshots(snaps, &c.pagedOutBytes) pageOutCount++ logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset)) } } // Rewind each cache to the target offset or free it. When matched // falls within the ancestor's range (same-path case), we rewind // directly to the match point. Otherwise we rewind to the ancestor // and let page-in bring us forward to matched. rewindTarget := min(ancestorOffset, matched) for _, kv := range c.caches { if kv == nil { continue } if !kv.Restore(nil, rewindTarget) { kv.Free() } } // Page in — walk the full new path, restoring from snapshots. // Freed caches naturally pick up the first available snapshot. // Caches already past a node skip it via offset check. pageIn: for _, node := range newPath { if !node.hasSnapshots() { continue } nodeTarget := min(node.endOffset, matched) for j, kv := range c.caches { if kv == nil { continue } if j >= len(node.snapshots) || node.snapshots[j] == nil { continue } if kv.Offset() >= nodeTarget { continue } if !kv.Restore(node.snapshots[j], nodeTarget) { // Restore failed — stop page-in and let alignment // bring all caches to a consistent offset. break pageIn } } if node.endOffset > ancestorOffset { pageInCount++ logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget)) } } // Align all caches to the minimum offset. c.activePath = newPath minOff := c.minCacheOffset() for _, kv := range c.caches { if kv != nil && kv.Offset() != minOff { if !kv.Restore(nil, minOff) { slog.Warn("failed to restore cache, freeing all caches", "offset", minOff) c.freeAll() break } } } for i := len(c.activePath) - 1; i >= 0; i-- { if c.activePath[i].endOffset <= minOff { c.activePath = c.activePath[:i+1] break } } // Update last-used time on only the final used node. For recurrent // caches we don't need the intermediate snapshots and for KV caches // we can reslice the data out of merged edges. if len(c.activePath) > 0 { c.activePath[len(c.activePath)-1].lastUsed = time.Now() } if pageOutCount > 0 || pageInCount > 0 { slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount) } } // requestSnapshot schedules a user snapshot at the given absolute token // offset. The snapshot will be captured during prefill when the cache // reaches this offset. func (s *cacheSession) requestSnapshot(offset int) { baseOffset := len(s.inputs) - len(s.remaining) if offset <= baseOffset || offset > len(s.inputs) { return } // Deduplicate: if this offset already exists, upgrade to user. for i := range s.pendingSnapshots { if s.pendingSnapshots[i].offset == offset { s.pendingSnapshots[i].user = true return } } s.pendingSnapshots = append(s.pendingSnapshots, pendingSnapshot{offset: offset, user: true}) slices.SortFunc(s.pendingSnapshots, func(a, b pendingSnapshot) int { return a.offset - b.offset }) } // nextPendingSnapshot returns the offset of the next pending snapshot, // or 0 if there are none. func (s *cacheSession) nextPendingSnapshot() int { if len(s.pendingSnapshots) == 0 { return 0 } return s.pendingSnapshots[0].offset } // snapshot creates a snapshot at the current cache position. It determines // whether this is a user snapshot by consuming pending entries whose offset // has been reached. func (s *cacheSession) snapshot() { c := s.cache cacheOffset := c.minCacheOffset() if cacheOffset <= 0 { return } // Consume pending snapshots up to the current offset and derive // the user flag from them. user := false for len(s.pendingSnapshots) > 0 && cacheOffset >= s.pendingSnapshots[0].offset { if s.pendingSnapshots[0].user { user = true } s.pendingSnapshots = s.pendingSnapshots[1:] } // The last node in activePath is the frontier where caches are advancing. // cacheOffset is always >= its endOffset: begin() restores caches to this // boundary and prefill advances monotonically forward. frontier := c.activePath[len(c.activePath)-1] // If the frontier already ends at cacheOffset, just ensure it has snapshots. if frontier.endOffset == cacheOffset { if user { frontier.user = true } if !frontier.hasAllSnapshots() { s.attachSnapshots(frontier, cacheOffset) } return } if frontier.endOffset > cacheOffset { slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset) return } // Advance the trie to cacheOffset — find or create a node there. edgeTokens := append(s.inputs, s.outputs...)[frontier.endOffset:cacheOffset] frontier = c.advancePath(frontier, edgeTokens, cacheOffset) // Attach fresh snapshots from the live caches. Always use fresh // snapshots even if the node already has some (e.g. from splitNode's // Cache.Split which may be incomplete for non-splittable caches // like RecurrentCache). if user { frontier.user = true } s.attachSnapshots(frontier, cacheOffset) } // advancePath advances the active path from the current frontier by matching // tokens against existing trie children, splitting partial matches, and // appending any remaining tokens as new nodes. Returns the new frontier. func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode { // Check if existing children already cover some or all of tokens. // tokens may span multiple trie nodes when extending a previous run's // leaf and this snapshot now overlaps that same range. matchPath, matched := findBestMatch(frontier, tokens) // matchPath[0] is frontier itself; the rest are newly traversed nodes. remaining := tokens[matched:] // Check for a partial match within the last node's edge — if so, split it. if len(matchPath) > 1 { lastNode := matchPath[len(matchPath)-1] matchedInEdge := frontier.endOffset + matched - lastNode.startOffset() if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) { matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes) } } // Append traversed nodes (excluding frontier) to the active path. c.activePath = append(c.activePath, matchPath[1:]...) dest := matchPath[len(matchPath)-1] if len(remaining) > 0 { // Drop non-user snapshots so appendTokens can extend in-place // rather than creating a new child node. if len(dest.children) == 0 && !dest.user { dest.setSnapshots(nil, &c.pagedOutBytes) } newDest := dest.appendTokens(c.root, remaining, endOffset) if newDest != dest { c.activePath = append(c.activePath, newDest) } dest = newDest } return dest } // attachSnapshots attaches cache snapshots to a trie node at the given offset. // The node must be on the active path (and thus protected from eviction; // lastUsed is updated in close()). All non-nil caches must be at the same // offset (cacheOffset); a mismatch indicates a bug in the caller. func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) { c := s.cache if c.activePath[len(c.activePath)-1] != node { slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset) return } snaps := make([]cache.Snapshot, len(c.caches)) for i, kv := range c.caches { if kv != nil { if kv.Offset() != cacheOffset { panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset())) } snaps[i] = kv.Snapshot(node.startOffset()) } } node.setSnapshots(snaps, &c.pagedOutBytes) node.lastUsed = time.Now() slog.Debug("created snapshot", "offset", cacheOffset) c.enforceEvictionPolicy() } // freeAll releases all cache layers. func (c *kvCache) freeAll() { for _, kv := range c.caches { if kv != nil { kv.Free() } } } func (c *kvCache) minCacheOffset() int { offset := 0 found := false for _, kv := range c.caches { if kv == nil { continue } if off := kv.Offset(); !found || off < offset { offset = off found = true } } return offset } // close saves the token state if the forward pass ran. func (s *cacheSession) close() { offset := s.cache.minCacheOffset() if offset <= 0 { return } arrays := make([]*mlx.Array, 0, 2*len(s.caches)) for _, kv := range s.caches { if kv == nil { continue } arrays = append(arrays, kv.State()...) } // Ensure that if we have run the forward pass and set the metadata // that we also actually have the data. mlx.AsyncEval(arrays...) // Advance the trie frontier with any newly generated tokens. c := s.cache if len(c.activePath) > 0 { frontier := c.activePath[len(c.activePath)-1] stored := append(s.inputs, s.outputs...) if offset > frontier.endOffset { newTokens := stored[frontier.endOffset:offset] c.advancePath(frontier, newTokens, offset) } c.activePath[len(c.activePath)-1].lastUsed = time.Now() } } // enforceEvictionPolicy evicts eligible nodes until paged-out memory is within limits. func (c *kvCache) enforceEvictionPolicy() { if c.pagedOutBytes <= maxPagedOutBytes { return } activeSet := make(map[*trieNode]bool, len(c.activePath)) for _, n := range c.activePath { activeSet[n] = true } for c.pagedOutBytes > maxPagedOutBytes { var best *trieNode walkNodes(c.root, func(n *trieNode) bool { if n == c.root || activeSet[n] || len(n.children) > 1 { return true } // Evict: oldest, then deepest, then largest. if best == nil || cmp.Or( n.lastUsed.Compare(best.lastUsed), cmp.Compare(best.endOffset, n.endOffset), cmp.Compare(best.snapshotBytes(), n.snapshotBytes()), ) < 0 { best = n } return true }) if best == nil { break } c.evictNode(best) } } // evictNode evicts a single node from the trie, freeing its snapshot memory. func (c *kvCache) evictNode(node *trieNode) { if len(node.children) == 0 { // Leaf: remove entirely. slog.Debug("evicting leaf", "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes()))) removeNode(node, &c.pagedOutBytes) } else if len(node.children) == 1 { // Interior node with one child: merge with child. before := c.pagedOutBytes tokens := len(node.tokens) mergeWithChild(node, c.caches, &c.pagedOutBytes) slog.Debug("evicting interior node", "offset", node.startOffset(), "tokens", tokens, "freed", mlx.PrettyBytes(int(before-c.pagedOutBytes))) } else { panic("evictNode called on multi-child branch point") } } func (c *kvCache) dumpTree() { // Summary stats var cacheBytes int for _, kv := range c.caches { if kv == nil { continue } for _, a := range kv.State() { if a != nil { cacheBytes += a.NumBytes() } } } // Build active path set for marking. active := make(map[*trieNode]bool, len(c.activePath)) for _, n := range c.activePath { active[n] = true } var nodeCount, snapshotCount int var pagedBytes int64 var lines []string var dump func(n *trieNode, prefix string, isLast bool) dump = func(n *trieNode, prefix string, isLast bool) { if n == nil { return } nodeCount++ // Build connector var connector string if n.parent == nil { connector = "" } else if isLast { connector = prefix + "`-- " } else { connector = prefix + "|-- " } // Node label nodeBytes := n.snapshotBytes() pagedBytes += nodeBytes label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens)) if nodeBytes > 0 { label += " " + mlx.PrettyBytes(int(nodeBytes)).String() } if !n.lastUsed.IsZero() { label += fmt.Sprintf(" %s ago", time.Since(n.lastUsed).Truncate(time.Millisecond)) } var flags []string if n.user { flags = append(flags, "user") } if n.hasAllSnapshots() { snapshotCount++ flags = append(flags, "snap") } if active[n] { flags = append(flags, "active") } if len(flags) > 0 { label += " (" + flags[0] for _, f := range flags[1:] { label += ", " + f } label += ")" } lines = append(lines, connector+label) // Recurse children childPrefix := prefix if n.parent != nil { if isLast { childPrefix += " " } else { childPrefix += "| " } } for i, child := range n.children { dump(child, childPrefix, i == len(n.children)-1) } } dump(c.root, "", true) offset := c.minCacheOffset() logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d", offset, mlx.PrettyBytes(cacheBytes), mlx.PrettyBytes(int(pagedBytes)), nodeCount, snapshotCount)) for i, l := range lines { if i == 0 { logutil.Trace("cache trie: " + l) } else { logutil.Trace(" " + l) } } }