Files
ollama/x/mlxrunner/cache.go
2026-03-19 16:35:08 -07:00

602 lines
17 KiB
Go

// cache.go manages a shared KV cache across conversations using a compressed
// prefix trie. Each trie node stores a token sequence (edge) and optional
// per-layer snapshots that can be paged in/out of the live MLX cache arrays.
//
// Key properties:
//   - Only one path through the trie is "active" (backed by live MLX arrays)
//     at a time. Switching paths pages out the frontier node and pages in the
//     new path.
//   - Snapshots are only captured at the frontier (end) of the active path.
//     Intermediate node snapshots come from split prefill.
//   - All cache layers must stay at the same token offset.
//   - Sibling edges must not share a common token prefix (compressed trie
//     invariant).
//   - begin() always re-evaluates at least one token so the pipeline can seed
//     generation, even on a full prefix match.
package mlxrunner
import (
"cmp"
"fmt"
"log/slog"
"time"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// maxPagedOutBytes is the ceiling on kvCache.pagedOutBytes. Once the tracked
// snapshot bytes exceed it, enforceEvictionPolicy evicts trie nodes until the
// total is back within the limit or nothing evictable remains.
const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory
// kvCache is the shared cache state: a prefix trie of token sequences plus
// the live caches that back the currently active path through that trie.
type kvCache struct {
	// root of the prefix trie; created lazily by ensureRoot and dropped by clear.
	root *trieNode
	// current root→leaf path with live MLX arrays
	activePath []*trieNode
	// live caches, created lazily by ensureCaches. Nil entries are tolerated
	// and skipped everywhere the slice is iterated.
	caches []cache.Cache
	// total bytes in paged-out snapshots across the trie; compared against
	// maxPagedOutBytes by enforceEvictionPolicy.
	pagedOutBytes int64
}
// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
	// cache is the shared kvCache this session was created from by begin.
	cache *kvCache
	// inputs are the prompt tokens that were passed to begin.
	inputs []int32
	// outputs are the tokens generated during the run, appended by the caller.
	outputs []int32
	// caches is the kvCache's live cache slice (begin stores c.caches here).
	caches []cache.Cache
	// remaining is the tail of inputs not covered by the cached prefix,
	// i.e. inputs[prefix:] as computed in begin.
	remaining []int32

	// snapshotOffset, if > 0, is a trie node boundary where we need to
	// capture a snapshot during prefill. This enables future requests
	// branching at this point to restore non-rewindable caches (e.g.
	// RecurrentCache) instead of re-evaluating from scratch.
	snapshotOffset int
}
// ensureCaches lazily builds the live cache set. It does nothing once
// c.caches is non-empty. A model that implements NewCaches() supplies its
// own caches; otherwise one default KV cache is created per model layer.
func (c *kvCache) ensureCaches(m base.Model) {
	if len(c.caches) > 0 {
		return
	}

	// Optional hook: models may construct their own cache set.
	type cacheBuilder interface {
		NewCaches() []cache.Cache
	}
	if builder, ok := m.(cacheBuilder); ok {
		c.caches = builder.NewCaches()
		return
	}

	layers := make([]cache.Cache, m.NumLayers())
	for idx := range layers {
		layers[idx] = cache.NewKVCache()
	}
	c.caches = layers
}
// ensureRoot creates the trie root on first use and makes it the sole
// element of the active path. It is a no-op when a root already exists.
func (c *kvCache) ensureRoot() {
	if c.root != nil {
		return
	}
	root := &trieNode{lastUsed: time.Now()}
	c.root = root
	c.activePath = []*trieNode{root}
}
// begin prepares the shared caches for a new request and returns the session
// that tracks it. It looks up the longest cached prefix of inputs in the
// trie, switches the live caches onto that path, and reports how many input
// tokens still have to be evaluated (cacheSession.remaining).
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
	c.ensureCaches(m)
	c.ensureRoot()

	path, matched := findBestMatch(c.root, inputs)
	// Keep the raw match length for logging; matched may shrink below.
	fullMatch := matched

	// A complete match would leave nothing to evaluate, but the pipeline
	// needs at least one token to seed generation, so back off by one.
	if n := len(inputs); n > 0 && matched == n {
		path, matched = findBestMatch(c.root, inputs[:n-1])
	}

	// If the match ends inside the last node's edge, fall back to the
	// parent boundary. snapshot() splits the node and creates the branch
	// point later, during prefill, once the caches are ready.
	midEdge := false
	if depth := len(path); depth > 1 {
		tail := path[depth-1]
		if into := matched - tail.startOffset(); into > 0 && into < len(tail.tokens) {
			path = path[:depth-1]
			midEdge = true
		}
	}

	// Page out/in as needed; afterwards all caches share a common offset.
	c.switchToPath(path)

	cached := c.minCacheOffset()
	todo := inputs[cached:]

	// Ask prefill to snapshot at the branch point so later requests that
	// diverge here can restore instead of re-evaluating.
	snapshotAt := 0
	if midEdge || (cached == 0 && matched > 0) {
		snapshotAt = matched
	}

	logArgs := []any{
		"total", len(inputs),
		"matched", fullMatch,
		"cached", cached,
		"left", len(todo),
	}
	if snapshotAt > 0 {
		logArgs = append(logArgs, "pending_snapshot", snapshotAt)
	}
	msg := "cache hit"
	if cached == 0 {
		msg = "cache miss"
	}
	slog.Info(msg, logArgs...)

	return &cacheSession{
		cache:          c,
		inputs:         inputs,
		remaining:      todo,
		caches:         c.caches,
		snapshotOffset: snapshotAt,
	}
}
// switchToPath transitions from the current active path to newPath,
// paging out the diverging leaf and paging in the new path.
//
// The steps are:
//  1. Find the common ancestor (longest shared prefix) of both paths.
//  2. If the old leaf lies beyond that ancestor and lacks a complete
//     snapshot set, snapshot the live caches into it.
//  3. Rewind every cache to the ancestor offset, freeing caches that
//     cannot rewind.
//  4. Walk newPath and restore caches from the nodes' snapshots.
//  5. Align all caches to their minimum offset and truncate activePath to
//     the deepest node that the caches actually cover.
//
// On a restore failure all caches are freed and the active path shrinks to
// match. The eviction policy is enforced on every return path.
func (c *kvCache) switchToPath(newPath []*trieNode) {
	defer c.enforceEvictionPolicy()

	// Find common ancestor index.
	commonLen := 0
	for commonLen < len(c.activePath) && commonLen < len(newPath) {
		if c.activePath[commonLen] != newPath[commonLen] {
			break
		}
		commonLen++
	}
	ancestorOffset := 0
	if commonLen > 0 {
		ancestorOffset = c.activePath[commonLen-1].endOffset
	}

	var pageOutCount, pageInCount int

	// Page out the leaf of the old path. Only the leaf's live cache
	// state is correct — intermediate nodes already have snapshots
	// captured during their creation (splitNode + prefill). Snapshotting
	// non-leaf nodes here would produce wrong results for non-rewindable
	// caches (e.g. RecurrentCache) whose state reflects the leaf, not
	// the intermediate boundary.
	if leaf := len(c.activePath) - 1; leaf >= commonLen {
		node := c.activePath[leaf]
		if !node.hasAllSnapshots() {
			fromOffset := node.startOffset()
			snaps := make([]cache.Snapshot, len(c.caches))
			for j, kv := range c.caches {
				if kv == nil {
					continue
				}
				snaps[j] = kv.Snapshot(fromOffset)
			}
			node.setSnapshots(snaps, &c.pagedOutBytes)
			pageOutCount++
			logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset))
		}
	}

	// Rewind each cache to the ancestor offset or free it. Freed
	// caches (e.g. RecurrentCache that can't rewind) will be restored
	// from snapshots during page-in.
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		if !kv.Restore(nil, ancestorOffset) {
			kv.Free()
		}
	}

	// Page in — walk the full new path, restoring from snapshots.
	// Freed caches naturally pick up the first available snapshot.
	// Caches already past a node skip it via offset check.
	for _, node := range newPath {
		if len(node.snapshots) == 0 {
			continue
		}
		for j, kv := range c.caches {
			if kv == nil {
				continue
			}
			if j >= len(node.snapshots) || node.snapshots[j] == nil {
				continue
			}
			if kv.Offset() >= node.endOffset {
				continue
			}
			if !kv.Restore(node.snapshots[j], node.endOffset) {
				slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset())
				c.freeAll()
				c.activePath = []*trieNode{c.root}
				return
			}
		}
		if node.endOffset > ancestorOffset {
			pageInCount++
			logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset))
		}
	}

	// Align all caches to the minimum offset.
	c.activePath = newPath
	minOff := c.minCacheOffset()
	for _, kv := range c.caches {
		if kv != nil && kv.Offset() != minOff {
			if !kv.Restore(nil, minOff) {
				slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
				c.freeAll()
				// The caches were just freed, so the minOff computed
				// above is stale. Re-read it so the truncation below
				// trims activePath to what the caches really cover,
				// instead of leaving a path that claims tokens the
				// caches no longer hold. This mirrors the page-in
				// failure path above, which resets activePath to the
				// root after freeAll.
				minOff = c.minCacheOffset()
				break
			}
		}
	}

	// Drop trailing path nodes that end beyond the aligned cache offset.
	for i := len(c.activePath) - 1; i >= 0; i-- {
		if c.activePath[i].endOffset <= minOff {
			c.activePath = c.activePath[:i+1]
			break
		}
	}

	if pageOutCount > 0 || pageInCount > 0 {
		slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount)
	}
}
// snapshot creates a snapshot at the current cache position. During prefill,
// it is called at branch points (user=false) to create restore points for
// future diverging requests and with user=true to mark an explicit reusable
// restore point.
//
// It is a no-op when the caches are empty, and it logs and returns when the
// cache offset is behind the frontier of the active path. Otherwise the trie
// is advanced to the cache offset (if needed) and fresh snapshots from the
// live caches are attached to the frontier node.
func (s *cacheSession) snapshot(user bool) {
	c := s.cache
	cacheOffset := c.minCacheOffset()
	if cacheOffset <= 0 {
		return
	}

	// Clear pending intermediate snapshot if we've reached or passed it.
	if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset {
		s.snapshotOffset = 0
	}

	// The last node in activePath is the frontier where caches are advancing.
	// cacheOffset is normally >= its endOffset: begin() restores caches to
	// this boundary and prefill advances monotonically forward.
	frontier := c.activePath[len(c.activePath)-1]

	// If the frontier already ends at cacheOffset, just ensure it has snapshots.
	if frontier.endOffset == cacheOffset {
		if user {
			frontier.user = true
		}
		if !frontier.hasAllSnapshots() {
			s.attachSnapshots(frontier, cacheOffset)
		}
		return
	}

	if frontier.endOffset > cacheOffset {
		slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset)
		return
	}

	// Advance the trie to cacheOffset — find or create a node there.
	// The full slice expression caps s.inputs at its length so append must
	// copy instead of writing generated tokens into spare capacity of the
	// caller's backing array.
	stored := append(s.inputs[:len(s.inputs):len(s.inputs)], s.outputs...)
	edgeTokens := stored[frontier.endOffset:cacheOffset]
	frontier = c.advancePath(frontier, edgeTokens, cacheOffset)

	// Attach fresh snapshots from the live caches. Always use fresh
	// snapshots even if the node already has some (e.g. from splitNode's
	// Cache.Split which may be incomplete for non-splittable caches
	// like RecurrentCache).
	if user {
		frontier.user = true
	}
	s.attachSnapshots(frontier, cacheOffset)
}
// advancePath extends the active path from frontier so that it covers
// tokens. Existing children are followed as far as they match, a partially
// matched edge is split at the match point, and whatever is left over is
// appended as new trie content. The returned node is the new frontier.
func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode {
	// tokens may overlap nodes that already exist below the frontier, e.g.
	// when this snapshot falls inside a range stored by a previous run.
	// walked[0] is the frontier itself; later entries were traversed.
	walked, consumed := findBestMatch(frontier, tokens)
	last := len(walked) - 1

	// A match that stops inside the final node's edge splits that node.
	if last > 0 {
		tail := walked[last]
		inEdge := frontier.endOffset + consumed - tail.startOffset()
		if inEdge > 0 && inEdge < len(tail.tokens) {
			walked[last] = splitNode(tail, inEdge, c.caches, &c.pagedOutBytes)
		}
	}

	// Everything traversed past the frontier joins the active path.
	c.activePath = append(c.activePath, walked[1:]...)
	dest := walked[last]

	rest := tokens[consumed:]
	if len(rest) == 0 {
		return dest
	}

	// Drop non-user snapshots on a childless node so appendTokens can
	// extend it in place rather than creating a new child node.
	if len(dest.children) == 0 && !dest.user {
		dest.setSnapshots(nil, &c.pagedOutBytes)
	}

	extended := dest.appendTokens(c.root, rest, endOffset)
	if extended != dest {
		c.activePath = append(c.activePath, extended)
	}
	return extended
}
// attachSnapshots captures a snapshot of every live cache and stores the set
// on node. node must be the frontier of the active path (which also shields
// it from eviction; lastUsed is refreshed in close()); otherwise the call is
// logged and skipped. Every non-nil cache must sit exactly at cacheOffset —
// a mismatch is a caller bug and panics.
func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
	kvc := s.cache

	if tip := kvc.activePath[len(kvc.activePath)-1]; tip != node {
		slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset)
		return
	}

	captured := make([]cache.Snapshot, len(kvc.caches))
	for layer, kv := range kvc.caches {
		if kv == nil {
			continue
		}
		if got := kv.Offset(); got != cacheOffset {
			panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", layer, cacheOffset, got))
		}
		captured[layer] = kv.Snapshot(node.startOffset())
	}

	node.setSnapshots(captured, &kvc.pagedOutBytes)
	slog.Debug("created snapshot", "offset", cacheOffset)
	kvc.enforceEvictionPolicy()
}
// clear releases live caches and drops the trie so future requests cannot
// reuse prompt state keyed only by token IDs.
//
// Every snapshot stored in the trie is closed, the root and active path are
// discarded (ensureRoot recreates them on the next begin), and the paged-out
// byte accounting is reset to match the now-empty trie.
func (c *kvCache) clear() {
	c.freeAll()

	// Nothing to walk if no trie was ever built (or it was already cleared).
	if c.root != nil {
		walkNodes(c.root, func(n *trieNode) bool {
			for _, s := range n.snapshots {
				if s != nil {
					s.Close()
				}
			}
			n.snapshots = nil
			return true
		})
	}

	c.root = nil
	c.activePath = nil

	// The snapshots above were dropped directly rather than through
	// setSnapshots, so the running total has to be reset here. Leaving it
	// stale would make enforceEvictionPolicy charge the next trie for
	// memory that has already been released.
	c.pagedOutBytes = 0
}
// freeAll calls Free on every non-nil cache layer.
func (c *kvCache) freeAll() {
	for i := range c.caches {
		if kv := c.caches[i]; kv != nil {
			kv.Free()
		}
	}
}
// minCacheOffset returns the smallest Offset() among the non-nil caches, or
// 0 when there are none.
func (c *kvCache) minCacheOffset() int {
	lowest, seen := 0, false
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		off := kv.Offset()
		if seen && off >= lowest {
			continue
		}
		lowest, seen = off, true
	}
	return lowest
}
// close saves the token state if the forward pass ran.
//
// When the caches hold at least one token it schedules evaluation of the
// cache state arrays, extends the trie's active path to cover any tokens the
// caches hold beyond the current frontier, and stamps every node on the
// active path with the current time. It does nothing when the caches are
// empty.
func (s *cacheSession) close() {
	offset := s.cache.minCacheOffset()
	if offset <= 0 {
		return
	}

	arrays := make([]*mlx.Array, 0, 2*len(s.caches))
	for _, kv := range s.caches {
		if kv == nil {
			continue
		}
		arrays = append(arrays, kv.State()...)
	}

	// Ensure that if we have run the forward pass and set the metadata
	// that we also actually have the data.
	mlx.AsyncEval(arrays...)

	// Advance the trie frontier with any newly generated tokens.
	c := s.cache
	if len(c.activePath) > 0 {
		frontier := c.activePath[len(c.activePath)-1]
		if offset > frontier.endOffset {
			// The full slice expression caps s.inputs at its length so
			// append must copy instead of writing generated tokens into
			// spare capacity of the caller's backing array.
			stored := append(s.inputs[:len(s.inputs):len(s.inputs)], s.outputs...)
			newTokens := stored[frontier.endOffset:offset]
			c.advancePath(frontier, newTokens, offset)
		}

		// Refresh recency on the whole active path (advancePath may have
		// extended it) for the eviction ordering.
		now := time.Now()
		for _, node := range c.activePath {
			node.lastUsed = now
		}
	}
}
// enforceEvictionPolicy evicts trie nodes one at a time until pagedOutBytes
// is no larger than maxPagedOutBytes or no candidate is left. The root,
// nodes on the active path, and nodes without snapshots are never chosen.
func (c *kvCache) enforceEvictionPolicy() {
	if c.pagedOutBytes <= maxPagedOutBytes {
		return
	}

	pinned := make(map[*trieNode]struct{}, len(c.activePath))
	for _, node := range c.activePath {
		pinned[node] = struct{}{}
	}

	for c.pagedOutBytes > maxPagedOutBytes {
		var victim *trieNode
		walkNodes(c.root, func(n *trieNode) bool {
			_, active := pinned[n]
			if n == c.root || active || !n.hasSnapshots() {
				return true
			}
			if victim == nil {
				victim = n
				return true
			}
			// Preference order: oldest, then deepest, then largest.
			order := cmp.Or(
				n.lastUsed.Compare(victim.lastUsed),
				cmp.Compare(victim.endOffset, n.endOffset),
				cmp.Compare(victim.snapshotBytes(), n.snapshotBytes()),
			)
			if order < 0 {
				victim = n
			}
			return true
		})
		if victim == nil {
			// Nothing evictable remains; stop even though we are over budget.
			break
		}
		c.evictNode(victim)
	}
}
// evictNode evicts one node, freeing its snapshot memory. What happens
// depends on how many children the node has:
//   - none: the node is removed; a non-user, non-root parent left with a
//     single child is then merged with that child.
//   - one: the node is merged with its child.
//   - several: the node stays as a branch point but loses its snapshots.
func (c *kvCache) evictNode(node *trieNode) {
	switch len(node.children) {
	case 0:
		// Leaf: remove entirely.
		parent := node.parent
		msg := "evicting leaf"
		if node.user {
			msg = "evicting user snapshot"
		}
		slog.Debug(msg, "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
		removeNode(node, &c.pagedOutBytes)

		// Only a regular (non-user-snapshot), non-root parent that is
		// left with exactly one child gets auto-merged.
		if parent == nil || parent.user || len(parent.children) != 1 || parent == c.root {
			return
		}
		logutil.Trace(fmt.Sprintf("auto-merging parent at offset %d with single child", parent.endOffset))
		mergeWithChild(parent, c.caches, &c.pagedOutBytes)

	case 1:
		// Interior snapshot node with one child: merge with child.
		slog.Debug("evicting snapshot node", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
		mergeWithChild(node, c.caches, &c.pagedOutBytes)

	default:
		// Multi-child branch point: drop snapshots but keep the node.
		slog.Debug("evicting branch snapshot", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
		node.setSnapshots(nil, &c.pagedOutBytes)
	}
}
// dumpTree writes a trace-level rendering of the cache: one summary line
// (active token count, live cache size, paged-out size, node and snapshot
// counts) followed by one line per trie node drawn as an ASCII tree. Each
// node line shows its token range, token count, snapshot size if any, and
// flags for user snapshots, complete snapshot sets, and active-path
// membership.
func (c *kvCache) dumpTree() {
	// Size of the live cache state arrays.
	liveBytes := 0
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		for _, arr := range kv.State() {
			if arr != nil {
				liveBytes += arr.NumBytes()
			}
		}
	}

	// Membership set used to flag nodes on the active path.
	onActivePath := make(map[*trieNode]bool, len(c.activePath))
	for _, n := range c.activePath {
		onActivePath[n] = true
	}

	var (
		nodes, snapshots int
		pagedOut         int64
		rendered         []string
	)

	var visit func(n *trieNode, indent string, last bool)
	visit = func(n *trieNode, indent string, last bool) {
		if n == nil {
			return
		}
		nodes++

		isRoot := n.parent == nil

		// Tree connector in front of the label; the root has none.
		branch := ""
		switch {
		case isRoot:
		case last:
			branch = indent + "`-- "
		default:
			branch = indent + "|-- "
		}

		// Label: token range, token count, optional snapshot size.
		size := n.snapshotBytes()
		pagedOut += size
		text := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens))
		if size > 0 {
			text += " " + mlx.PrettyBytes(int(size)).String()
		}

		var tags []string
		if n.user {
			tags = append(tags, "user")
		}
		if n.hasAllSnapshots() {
			snapshots++
			tags = append(tags, "snap")
		}
		if onActivePath[n] {
			tags = append(tags, "active")
		}
		for i, tag := range tags {
			if i == 0 {
				text += " (" + tag
			} else {
				text += ", " + tag
			}
		}
		if len(tags) > 0 {
			text += ")"
		}

		rendered = append(rendered, branch+text)

		// Children of a non-root node are indented one level deeper.
		nextIndent := indent
		if !isRoot {
			if last {
				nextIndent += " "
			} else {
				nextIndent += "| "
			}
		}
		final := len(n.children) - 1
		for i, child := range n.children {
			visit(child, nextIndent, i == final)
		}
	}
	visit(c.root, "", true)

	logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d",
		c.minCacheOffset(), mlx.PrettyBytes(liveBytes), mlx.PrettyBytes(int(pagedOut)), nodes, snapshots))

	for i, line := range rendered {
		if i == 0 {
			logutil.Trace("cache trie: " + line)
			continue
		}
		logutil.Trace(" " + line)
	}
}