mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 12:54:12 +02:00
Enable multiple conversations to reuse cached computations when they share token prefixes (e.g. the same system prompt). A prefix trie tracks shared regions so switching between conversations only recomputes tokens that diverge. Inactive conversation state is paged from active GPU memory to other memory and restored on demand, with LRU eviction to keep memory usage bounded.
585 lines
17 KiB
Go
585 lines
17 KiB
Go
// cache.go manages a shared KV cache across conversations using a compressed
// prefix trie. Each trie node stores a token sequence (edge) and optional
// per-layer snapshots that can be paged in/out of the live MLX cache arrays.
//
// Key properties:
// - Only one path through the trie is "active" (backed by live MLX arrays)
//   at a time. Switching paths pages out the frontier node and pages in the
//   new path.
// - Snapshots are only captured at the frontier (end) of the active path.
//   Intermediate node snapshots come from split prefill.
// - All cache layers must stay at the same token offset.
// - Sibling edges must not share a common token prefix (compressed trie
//   invariant).
// - begin() always re-evaluates at least one token so the pipeline can seed
//   generation, even on a full prefix match.
package mlxrunner
|
|
|
|
import (
	"cmp"
	"fmt"
	"log/slog"
	"strings"
	"time"

	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)
|
|
|
|
// maxPagedOutBytes caps the total memory held by paged-out snapshots across
// the trie; enforceEvictionPolicy evicts LRU-eligible nodes above this limit.
const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory
|
|
|
|
// kvCache is the shared, cross-conversation KV cache: a compressed prefix
// trie of token sequences plus the live per-layer cache arrays backing the
// currently active root→leaf path.
type kvCache struct {
	root       *trieNode   // root of the prefix trie
	activePath []*trieNode // current root→leaf path with live MLX arrays
	// caches holds one cache per model layer; entries may be nil.
	caches []cache.Cache
	pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
}
|
|
|
|
// cacheSession manages caches for a single pipeline run.
|
|
// Callers should append generated tokens to outputs and
|
|
// defer close to save the cache state.
|
|
// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
	cache   *kvCache // shared cache this session runs against
	inputs  []int32  // prompt tokens for this request
	outputs []int32  // tokens generated during this run; appended by the caller

	caches    []cache.Cache // per-layer caches (aliases cache.caches)
	remaining []int32       // suffix of inputs still to be evaluated (prefill work)

	// snapshotOffset, if > 0, is a trie node boundary where we need to
	// capture a snapshot during prefill. This enables future requests
	// branching at this point to restore non-rewindable caches (e.g.
	// RecurrentCache) instead of re-evaluating from scratch.
	snapshotOffset int
}
|
|
|
|
func (c *kvCache) ensureCaches(m base.Model) {
|
|
if len(c.caches) != 0 {
|
|
return
|
|
}
|
|
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
|
c.caches = cacheFactory.NewCaches()
|
|
return
|
|
}
|
|
c.caches = make([]cache.Cache, m.NumLayers())
|
|
for i := range c.caches {
|
|
c.caches[i] = cache.NewKVCache()
|
|
}
|
|
}
|
|
|
|
func (c *kvCache) ensureRoot() {
|
|
if c.root == nil {
|
|
c.root = &trieNode{
|
|
lastUsed: time.Now(),
|
|
}
|
|
c.activePath = []*trieNode{c.root}
|
|
}
|
|
}
|
|
|
|
// begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match.
// It switches the live caches onto the best-matching trie path and returns a
// session whose remaining slice holds the suffix of inputs still to evaluate.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
	c.ensureCaches(m)
	c.ensureRoot()

	matchPath, matched := findBestMatch(c.root, inputs)
	originalMatched := matched // preserved for logging only

	// Always keep at least one token to re-evaluate so the
	// pipeline can seed token generation from it.
	if matched == len(inputs) && matched > 0 {
		matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
	}

	// Check for partial match within a node's edge — truncate path
	// to the parent boundary. snapshot() will split the node and
	// create the branch point during prefill when caches are ready.
	partialMatch := false
	if len(matchPath) > 1 {
		lastNode := matchPath[len(matchPath)-1]
		matchedInEdge := matched - lastNode.startOffset()
		if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
			matchPath = matchPath[:len(matchPath)-1]
			partialMatch = true
		}
	}

	// Switch to the matched path, paging in/out as needed.
	c.switchToPath(matchPath)

	// switchToPath aligns caches to a common offset
	prefix := c.minCacheOffset()
	remaining := inputs[prefix:]

	// Schedule a snapshot at the branch point during prefill so future
	// requests diverging here can restore instead of re-evaluating.
	var snapshotAt int
	if partialMatch || (prefix == 0 && matched > 0) {
		snapshotAt = matched
	}

	args := []any{"total", len(inputs), "matched", originalMatched}
	args = append(args, "cached", prefix, "left", len(remaining))
	if snapshotAt > 0 {
		args = append(args, "pending_snapshot", snapshotAt)
	}
	if prefix == 0 {
		slog.Info("cache miss", args...)
	} else {
		slog.Info("cache hit", args...)
	}

	return &cacheSession{
		cache:          c,
		inputs:         inputs,
		snapshotOffset: snapshotAt,
		caches:         c.caches,
		remaining:      remaining,
	}
}
|
|
|
|
// switchToPath transitions from the current active path to a new path,
// paging out diverging segments and paging in the new path.
// On restore failure all caches are freed and the active path is reset to
// the root, so the caller will simply re-evaluate from scratch.
func (c *kvCache) switchToPath(newPath []*trieNode) {
	defer c.enforceEvictionPolicy()

	// Find common ancestor index.
	commonLen := 0
	for commonLen < len(c.activePath) && commonLen < len(newPath) {
		if c.activePath[commonLen] != newPath[commonLen] {
			break
		}
		commonLen++
	}

	ancestorOffset := 0
	if commonLen > 0 {
		ancestorOffset = c.activePath[commonLen-1].endOffset
	}

	var pageOutCount, pageInCount int

	// Page out the leaf of the old path. Only the leaf's live cache
	// state is correct — intermediate nodes already have snapshots
	// captured during their creation (splitNode + prefill). Snapshotting
	// non-leaf nodes here would produce wrong results for non-rewindable
	// caches (e.g. RecurrentCache) whose state reflects the leaf, not
	// the intermediate boundary.
	if leaf := len(c.activePath) - 1; leaf >= commonLen {
		node := c.activePath[leaf]
		if !node.hasAllSnapshots() {
			fromOffset := node.startOffset()
			snaps := make([]cache.Snapshot, len(c.caches))
			for j, kv := range c.caches {
				if kv == nil {
					continue
				}
				snaps[j] = kv.Snapshot(fromOffset)
			}
			node.setSnapshots(snaps, &c.pagedOutBytes)
			pageOutCount++
			logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset))
		}
	}

	// Rewind each cache to the ancestor offset or free it. Freed
	// caches (e.g. RecurrentCache that can't rewind) will be restored
	// from snapshots during page-in.
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		if !kv.Restore(nil, ancestorOffset) {
			kv.Free()
		}
	}

	// Page in — walk the full new path, restoring from snapshots.
	// Freed caches naturally pick up the first available snapshot.
	// Caches already past a node skip it via offset check.
	for _, node := range newPath {
		if len(node.snapshots) == 0 {
			continue
		}
		for j, kv := range c.caches {
			if kv == nil {
				continue
			}
			if j >= len(node.snapshots) || node.snapshots[j] == nil {
				continue
			}
			if kv.Offset() >= node.endOffset {
				continue
			}
			if !kv.Restore(node.snapshots[j], node.endOffset) {
				// Unrecoverable: drop everything and start over at the root.
				slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset())
				c.freeAll()
				c.activePath = []*trieNode{c.root}
				return
			}
		}
		// Only count segments beyond the common ancestor as real page-ins.
		if node.endOffset > ancestorOffset {
			pageInCount++
			logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset))
		}
	}

	// Align all caches to the minimum offset.
	c.activePath = newPath
	minOff := c.minCacheOffset()
	for _, kv := range c.caches {
		if kv != nil && kv.Offset() != minOff {
			if !kv.Restore(nil, minOff) {
				slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
				c.freeAll()
				break
			}
		}
	}
	// Trim the active path to the deepest node fully covered by minOff.
	for i := len(c.activePath) - 1; i >= 0; i-- {
		if c.activePath[i].endOffset <= minOff {
			c.activePath = c.activePath[:i+1]
			break
		}
	}

	if pageOutCount > 0 || pageInCount > 0 {
		slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount)
	}
}
|
|
|
|
// snapshot creates a snapshot at the current cache position. During prefill,
|
|
// it is called at branch points (user=false) to create restore points for
|
|
// future diverging requests and with user=true to mark an explicit reusable
|
|
// restore point.
|
|
func (s *cacheSession) snapshot(user bool) {
|
|
c := s.cache
|
|
cacheOffset := c.minCacheOffset()
|
|
if cacheOffset <= 0 {
|
|
return
|
|
}
|
|
|
|
// Clear pending intermediate snapshot if we've reached or passed it.
|
|
if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset {
|
|
s.snapshotOffset = 0
|
|
}
|
|
|
|
// The last node in activePath is the frontier where caches are advancing.
|
|
// cacheOffset is always >= its endOffset: begin() restores caches to this
|
|
// boundary and prefill advances monotonically forward.
|
|
frontier := c.activePath[len(c.activePath)-1]
|
|
|
|
// If the frontier already ends at cacheOffset, just ensure it has snapshots.
|
|
if frontier.endOffset == cacheOffset {
|
|
if user {
|
|
frontier.user = true
|
|
}
|
|
if !frontier.hasAllSnapshots() {
|
|
s.attachSnapshots(frontier, cacheOffset)
|
|
}
|
|
return
|
|
}
|
|
|
|
if frontier.endOffset > cacheOffset {
|
|
slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset)
|
|
return
|
|
}
|
|
|
|
// Advance the trie to cacheOffset — find or create a node there.
|
|
edgeTokens := append(s.inputs, s.outputs...)[frontier.endOffset:cacheOffset]
|
|
frontier = c.advancePath(frontier, edgeTokens, cacheOffset)
|
|
|
|
// Attach fresh snapshots from the live caches. Always use fresh
|
|
// snapshots even if the node already has some (e.g. from splitNode's
|
|
// Cache.Split which may be incomplete for non-splittable caches
|
|
// like RecurrentCache).
|
|
if user {
|
|
frontier.user = true
|
|
}
|
|
s.attachSnapshots(frontier, cacheOffset)
|
|
}
|
|
|
|
// advancePath advances the active path from the current frontier by matching
// tokens against existing trie children, splitting partial matches, and
// appending any remaining tokens as new nodes. Returns the new frontier.
func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode {
	// Check if existing children already cover some or all of tokens.
	// tokens may span multiple trie nodes when extending a previous run's
	// leaf and this snapshot now overlaps that same range.
	matchPath, matched := findBestMatch(frontier, tokens)
	// matchPath[0] is frontier itself; the rest are newly traversed nodes.
	remaining := tokens[matched:]

	// Check for a partial match within the last node's edge — if so, split it.
	if len(matchPath) > 1 {
		lastNode := matchPath[len(matchPath)-1]
		// Convert the absolute match count into an offset within lastNode's edge.
		matchedInEdge := frontier.endOffset + matched - lastNode.startOffset()
		if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
			matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes)
		}
	}

	// Append traversed nodes (excluding frontier) to the active path.
	c.activePath = append(c.activePath, matchPath[1:]...)
	dest := matchPath[len(matchPath)-1]

	if len(remaining) > 0 {
		// Drop non-user snapshots so appendTokens can extend in-place
		// rather than creating a new child node.
		if len(dest.children) == 0 && !dest.user {
			dest.setSnapshots(nil, &c.pagedOutBytes)
		}
		newDest := dest.appendTokens(c.root, remaining, endOffset)
		if newDest != dest {
			c.activePath = append(c.activePath, newDest)
		}
		dest = newDest
	}
	return dest
}
|
|
|
|
// attachSnapshots attaches cache snapshots to a trie node at the given offset.
|
|
// The node must be on the active path (and thus protected from eviction;
|
|
// lastUsed is updated in close()). All non-nil caches must be at the same
|
|
// offset (cacheOffset); a mismatch indicates a bug in the caller.
|
|
func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
|
|
c := s.cache
|
|
|
|
if c.activePath[len(c.activePath)-1] != node {
|
|
slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset)
|
|
return
|
|
}
|
|
|
|
snaps := make([]cache.Snapshot, len(c.caches))
|
|
for i, kv := range c.caches {
|
|
if kv != nil {
|
|
if kv.Offset() != cacheOffset {
|
|
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset()))
|
|
}
|
|
snaps[i] = kv.Snapshot(node.startOffset())
|
|
}
|
|
}
|
|
node.setSnapshots(snaps, &c.pagedOutBytes)
|
|
slog.Debug("created snapshot", "offset", cacheOffset)
|
|
c.enforceEvictionPolicy()
|
|
}
|
|
|
|
// freeAll releases all cache layers.
|
|
func (c *kvCache) freeAll() {
|
|
for _, kv := range c.caches {
|
|
if kv != nil {
|
|
kv.Free()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *kvCache) minCacheOffset() int {
|
|
offset := 0
|
|
found := false
|
|
for _, kv := range c.caches {
|
|
if kv == nil {
|
|
continue
|
|
}
|
|
if off := kv.Offset(); !found || off < offset {
|
|
offset = off
|
|
found = true
|
|
}
|
|
}
|
|
return offset
|
|
}
|
|
|
|
// close saves the token state if the forward pass ran.
|
|
func (s *cacheSession) close() {
|
|
offset := s.cache.minCacheOffset()
|
|
if offset <= 0 {
|
|
return
|
|
}
|
|
|
|
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
|
for _, kv := range s.caches {
|
|
if kv == nil {
|
|
continue
|
|
}
|
|
arrays = append(arrays, kv.State()...)
|
|
}
|
|
|
|
// Ensure that if we have run the forward pass and set the metadata
|
|
// that we also actually have the data.
|
|
mlx.AsyncEval(arrays...)
|
|
|
|
// Advance the trie frontier with any newly generated tokens.
|
|
c := s.cache
|
|
if len(c.activePath) > 0 {
|
|
frontier := c.activePath[len(c.activePath)-1]
|
|
stored := append(s.inputs, s.outputs...)
|
|
|
|
if offset > frontier.endOffset {
|
|
newTokens := stored[frontier.endOffset:offset]
|
|
c.advancePath(frontier, newTokens, offset)
|
|
}
|
|
now := time.Now()
|
|
for _, node := range c.activePath {
|
|
node.lastUsed = now
|
|
}
|
|
}
|
|
}
|
|
|
|
// enforceEvictionPolicy evicts eligible nodes until paged-out memory is within limits.
// Nodes on the active path and the root are never evicted. Candidates are
// ranked oldest-first, breaking ties by depth (deeper first) then snapshot
// size (larger first).
func (c *kvCache) enforceEvictionPolicy() {
	if c.pagedOutBytes <= maxPagedOutBytes {
		return
	}

	// Nodes on the active path are pinned.
	activeSet := make(map[*trieNode]bool, len(c.activePath))
	for _, n := range c.activePath {
		activeSet[n] = true
	}

	for c.pagedOutBytes > maxPagedOutBytes {
		var best *trieNode
		walkNodes(c.root, func(n *trieNode) bool {
			if n == c.root || activeSet[n] || !n.hasSnapshots() {
				return true
			}
			// Evict: oldest, then deepest, then largest.
			// cmp.Or picks the first non-zero comparison; < 0 means n
			// outranks the current best candidate.
			if best == nil || cmp.Or(
				n.lastUsed.Compare(best.lastUsed),
				cmp.Compare(best.endOffset, n.endOffset),
				cmp.Compare(best.snapshotBytes(), n.snapshotBytes()),
			) < 0 {
				best = n
			}
			return true
		})
		if best == nil {
			// Nothing evictable remains; stop even if still over budget.
			break
		}
		c.evictNode(best)
	}
}
|
|
|
|
// evictNode evicts a single node from the trie, freeing its snapshot memory.
|
|
func (c *kvCache) evictNode(node *trieNode) {
|
|
if len(node.children) == 0 {
|
|
// Leaf: remove entirely.
|
|
parent := node.parent
|
|
kind := "evicting leaf"
|
|
if node.user {
|
|
kind = "evicting user snapshot"
|
|
}
|
|
slog.Debug(kind, "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
|
removeNode(node, &c.pagedOutBytes)
|
|
|
|
// If parent is a regular (non-user-snapshot) node with one remaining child, auto-merge.
|
|
if parent != nil && !parent.user && len(parent.children) == 1 && parent != c.root {
|
|
logutil.Trace(fmt.Sprintf("auto-merging parent at offset %d with single child", parent.endOffset))
|
|
mergeWithChild(parent, c.caches, &c.pagedOutBytes)
|
|
}
|
|
} else if len(node.children) == 1 {
|
|
// Interior snapshot node with one child: merge with child.
|
|
slog.Debug("evicting snapshot node", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
|
mergeWithChild(node, c.caches, &c.pagedOutBytes)
|
|
} else {
|
|
// Multi-child branch point: drop snapshots but keep the node.
|
|
slog.Debug("evicting branch snapshot", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
|
node.setSnapshots(nil, &c.pagedOutBytes)
|
|
}
|
|
}
|
|
|
|
func (c *kvCache) dumpTree() {
|
|
// Summary stats
|
|
var cacheBytes int
|
|
for _, kv := range c.caches {
|
|
if kv == nil {
|
|
continue
|
|
}
|
|
for _, a := range kv.State() {
|
|
if a != nil {
|
|
cacheBytes += a.NumBytes()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Build active path set for marking.
|
|
active := make(map[*trieNode]bool, len(c.activePath))
|
|
for _, n := range c.activePath {
|
|
active[n] = true
|
|
}
|
|
|
|
var nodeCount, snapshotCount int
|
|
var pagedBytes int64
|
|
var lines []string
|
|
var dump func(n *trieNode, prefix string, isLast bool)
|
|
dump = func(n *trieNode, prefix string, isLast bool) {
|
|
if n == nil {
|
|
return
|
|
}
|
|
nodeCount++
|
|
|
|
// Build connector
|
|
var connector string
|
|
if n.parent == nil {
|
|
connector = ""
|
|
} else if isLast {
|
|
connector = prefix + "`-- "
|
|
} else {
|
|
connector = prefix + "|-- "
|
|
}
|
|
|
|
// Node label
|
|
nodeBytes := n.snapshotBytes()
|
|
pagedBytes += nodeBytes
|
|
|
|
label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens))
|
|
if nodeBytes > 0 {
|
|
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
|
|
}
|
|
var flags []string
|
|
if n.user {
|
|
flags = append(flags, "user")
|
|
}
|
|
if n.hasAllSnapshots() {
|
|
snapshotCount++
|
|
flags = append(flags, "snap")
|
|
}
|
|
if active[n] {
|
|
flags = append(flags, "active")
|
|
}
|
|
if len(flags) > 0 {
|
|
label += " (" + flags[0]
|
|
for _, f := range flags[1:] {
|
|
label += ", " + f
|
|
}
|
|
label += ")"
|
|
}
|
|
lines = append(lines, connector+label)
|
|
|
|
// Recurse children
|
|
childPrefix := prefix
|
|
if n.parent != nil {
|
|
if isLast {
|
|
childPrefix += " "
|
|
} else {
|
|
childPrefix += "| "
|
|
}
|
|
}
|
|
for i, child := range n.children {
|
|
dump(child, childPrefix, i == len(n.children)-1)
|
|
}
|
|
}
|
|
dump(c.root, "", true)
|
|
|
|
offset := c.minCacheOffset()
|
|
logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d",
|
|
offset, mlx.PrettyBytes(cacheBytes), mlx.PrettyBytes(int(pagedBytes)), nodeCount, snapshotCount))
|
|
for i, l := range lines {
|
|
if i == 0 {
|
|
logutil.Trace("cache trie: " + l)
|
|
} else {
|
|
logutil.Trace(" " + l)
|
|
}
|
|
}
|
|
}
|