mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 08:13:29 +02:00
Enable multiple conversations to reuse cached computations when they share token prefixes (e.g. the same system prompt). A prefix trie tracks shared regions so switching between conversations only recomputes tokens that diverge. Inactive conversation state is paged from active GPU memory to other memory and restored on demand, with LRU eviction to keep memory usage bounded.
297 lines
8.2 KiB
Go
297 lines
8.2 KiB
Go
package mlxrunner
|
|
|
|
import (
|
|
"fmt"
|
|
"slices"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
)
|
|
|
|
// trieNode represents a node in the compressed prefix trie for KV cache branching.
// Each node stores a compressed edge (multiple tokens) and optional paged-out
// snapshot data per cache layer.
//
// The edge covers the token range [startOffset(), endOffset) measured from the
// root, where startOffset() is endOffset minus the number of tokens on the edge.
type trieNode struct {
	tokens    []int32          // compressed edge — multiple tokens per node
	endOffset int              // cumulative tokens from root to end of this node
	parent    *trieNode        // node this edge hangs off; removeNode treats a nil parent as the root
	children  []*trieNode      // nodes whose edges continue from the end of this one
	lastUsed  time.Time        // for LRU eviction
	snapshots []cache.Snapshot // per-layer paged-out snapshot data (nil if not paged out)
	user      bool             // true = explicit restore point (resist auto-merge)
}
|
|
|
|
// startOffset returns the cumulative token offset at the start of this node's edge.
|
|
func (n *trieNode) startOffset() int {
|
|
return n.endOffset - len(n.tokens)
|
|
}
|
|
|
|
// snapshotBytes returns the total bytes of paged-out snapshots on this node.
|
|
func (n *trieNode) snapshotBytes() int64 {
|
|
var total int64
|
|
for _, s := range n.snapshots {
|
|
if s != nil {
|
|
total += int64(s.Size())
|
|
}
|
|
}
|
|
return total
|
|
}
|
|
|
|
// setSnapshots replaces this node's snapshots with snaps and closes the old ones.
|
|
// If counter is non-nil, the net byte delta is applied to it.
|
|
func (n *trieNode) setSnapshots(snaps []cache.Snapshot, counter *int64) {
|
|
old := n.swapSnapshots(snaps, counter)
|
|
for _, s := range old {
|
|
if s != nil {
|
|
s.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
// swapSnapshots is like setSnapshots but returns the previous snapshots
|
|
// without closing them. Use this when the old snapshots will be consumed
|
|
// (e.g. by Split/Merge).
|
|
func (n *trieNode) swapSnapshots(snaps []cache.Snapshot, counter *int64) []cache.Snapshot {
|
|
old := n.snapshots
|
|
if counter != nil {
|
|
*counter -= n.snapshotBytes()
|
|
}
|
|
n.snapshots = snaps
|
|
if counter != nil {
|
|
*counter += n.snapshotBytes()
|
|
}
|
|
return old
|
|
}
|
|
|
|
// hasSnapshots returns true if any layer has snapshot data.
|
|
func (n *trieNode) hasSnapshots() bool {
|
|
return slices.ContainsFunc(n.snapshots, func(s cache.Snapshot) bool { return s != nil })
|
|
}
|
|
|
|
// hasAllSnapshots returns true if every layer has snapshot data.
|
|
func (n *trieNode) hasAllSnapshots() bool {
|
|
return len(n.snapshots) > 0 && !slices.Contains(n.snapshots, nil)
|
|
}
|
|
|
|
// findBestMatch walks the trie matching input tokens, returning the path of
|
|
// nodes traversed and the total number of tokens matched.
|
|
func findBestMatch(root *trieNode, tokens []int32) (path []*trieNode, matched int) {
|
|
if root == nil {
|
|
return nil, 0
|
|
}
|
|
|
|
path = []*trieNode{root}
|
|
pos := 0
|
|
|
|
node := root
|
|
for pos < len(tokens) {
|
|
// When multiple children share the same first token (e.g. after
|
|
// a split), prefer the child whose full edge matches over one
|
|
// that only partially matches. This is just being defensive - it
|
|
// shouldn't actually happen.
|
|
var best *trieNode
|
|
bestMatched := 0
|
|
bestFull := false
|
|
for _, child := range node.children {
|
|
edge := child.tokens
|
|
if len(edge) == 0 {
|
|
continue
|
|
}
|
|
if edge[0] != tokens[pos] {
|
|
continue
|
|
}
|
|
// Count matching tokens in this child's edge.
|
|
j := 0
|
|
for j < len(edge) && pos+j < len(tokens) && edge[j] == tokens[pos+j] {
|
|
j++
|
|
}
|
|
full := j == len(edge)
|
|
// Prefer full edge matches; among same type, prefer longer.
|
|
if best == nil || (full && !bestFull) || (full == bestFull && j > bestMatched) {
|
|
best = child
|
|
bestMatched = j
|
|
bestFull = full
|
|
}
|
|
}
|
|
if best == nil {
|
|
break
|
|
}
|
|
|
|
pos += bestMatched
|
|
path = append(path, best)
|
|
|
|
if !bestFull {
|
|
// Partial match within this edge
|
|
break
|
|
}
|
|
node = best
|
|
}
|
|
|
|
return path, pos
|
|
}
|
|
|
|
// appendTokens either creates a new child node or extends the leaf in place,
|
|
// returning the node that now holds the tokens.
|
|
func (n *trieNode) appendTokens(root *trieNode, tokens []int32, endOffset int) *trieNode {
|
|
if n == root || len(n.children) > 0 || n.hasSnapshots() {
|
|
child := &trieNode{
|
|
tokens: make([]int32, len(tokens)),
|
|
endOffset: endOffset,
|
|
parent: n,
|
|
lastUsed: n.lastUsed,
|
|
}
|
|
copy(child.tokens, tokens)
|
|
n.children = append(n.children, child)
|
|
return child
|
|
}
|
|
n.tokens = append(n.tokens, tokens...)
|
|
n.endOffset = endOffset
|
|
return n
|
|
}
|
|
|
|
// removeNode removes a leaf node from the trie.
|
|
func removeNode(node *trieNode, counter *int64) {
|
|
if node.parent == nil {
|
|
panic("removeNode called on root")
|
|
}
|
|
if len(node.children) != 0 {
|
|
panic("removeNode called on non-leaf node")
|
|
}
|
|
p := node.parent
|
|
for i, child := range p.children {
|
|
if child == node {
|
|
p.children = append(p.children[:i], p.children[i+1:]...)
|
|
break
|
|
}
|
|
}
|
|
node.parent = nil
|
|
node.setSnapshots(nil, counter)
|
|
}
|
|
|
|
// splitNode splits a node at the given token offset within its edge,
|
|
// creating a new parent node. Returns the new parent.
|
|
// `at` is relative to the node's edge (0-based index into node.tokens).
|
|
// If caches are provided, snapshots are split between parent and child
|
|
// using Cache.Split; otherwise snapshots are invalidated.
|
|
func splitNode(node *trieNode, at int, caches []cache.Cache, counter *int64) *trieNode {
|
|
if at <= 0 || at >= len(node.tokens) {
|
|
panic(fmt.Sprintf("splitNode: invalid split offset %d for node with %d tokens", at, len(node.tokens)))
|
|
}
|
|
|
|
// Create new parent with the prefix of the edge.
|
|
newParent := &trieNode{
|
|
tokens: make([]int32, at),
|
|
endOffset: node.startOffset() + at,
|
|
parent: node.parent,
|
|
children: []*trieNode{node},
|
|
lastUsed: node.lastUsed,
|
|
}
|
|
copy(newParent.tokens, node.tokens[:at])
|
|
|
|
// Update the original node to have only the suffix.
|
|
node.tokens = node.tokens[at:]
|
|
// endOffset stays the same for the original node.
|
|
|
|
// Split snapshots between parent and child using Cache.Split.
|
|
// Split consumes the old snapshots, so we remove them first (adjusting
|
|
// the counter), then assign the split halves (adjusting it back).
|
|
if node.hasSnapshots() {
|
|
oldSnaps := node.swapSnapshots(nil, counter)
|
|
parentSnaps := make([]cache.Snapshot, len(oldSnaps))
|
|
childSnaps := make([]cache.Snapshot, len(oldSnaps))
|
|
for i, snap := range oldSnaps {
|
|
if snap != nil {
|
|
parentSnaps[i], childSnaps[i] = caches[i].Split(snap, newParent.endOffset)
|
|
}
|
|
}
|
|
newParent.setSnapshots(parentSnaps, counter)
|
|
node.setSnapshots(childSnaps, counter)
|
|
}
|
|
|
|
// Reparent: replace node with newParent in the old parent's children.
|
|
if node.parent != nil {
|
|
for i, child := range node.parent.children {
|
|
if child == node {
|
|
node.parent.children[i] = newParent
|
|
break
|
|
}
|
|
}
|
|
}
|
|
node.parent = newParent
|
|
|
|
return newParent
|
|
}
|
|
|
|
// mergeWithChild merges a node with its single child: concatenates tokens,
|
|
// merges snapshot data via Cache.Merge, and removes the child.
|
|
func mergeWithChild(node *trieNode, caches []cache.Cache, counter *int64) {
|
|
if len(node.children) != 1 {
|
|
panic(fmt.Sprintf("mergeWithChild called on node with %d children", len(node.children)))
|
|
}
|
|
|
|
child := node.children[0]
|
|
|
|
// Concatenate tokens.
|
|
node.tokens = append(node.tokens, child.tokens...)
|
|
node.endOffset = child.endOffset
|
|
|
|
// Merge snapshots per layer. Merge consumes the old snapshots, so we
|
|
// remove them first (adjusting the counter), then assign the merged
|
|
// result (adjusting it back).
|
|
if len(node.snapshots) > 0 || len(child.snapshots) > 0 {
|
|
nodeSnaps := node.swapSnapshots(nil, counter)
|
|
childSnaps := child.swapSnapshots(nil, counter)
|
|
merged := make([]cache.Snapshot, len(caches))
|
|
for i := range caches {
|
|
var ps, cs cache.Snapshot
|
|
if nodeSnaps != nil {
|
|
ps = nodeSnaps[i]
|
|
}
|
|
if childSnaps != nil {
|
|
cs = childSnaps[i]
|
|
}
|
|
|
|
merged[i] = caches[i].Merge(ps, cs)
|
|
}
|
|
node.setSnapshots(merged, counter)
|
|
}
|
|
|
|
// Adopt grandchildren.
|
|
node.children = child.children
|
|
for _, gc := range node.children {
|
|
gc.parent = node
|
|
}
|
|
|
|
// Inherit user flag from child if child was a user-created snapshot node.
|
|
node.user = child.user
|
|
|
|
// Update lastUsed to the more recent of the two.
|
|
if child.lastUsed.After(node.lastUsed) {
|
|
node.lastUsed = child.lastUsed
|
|
}
|
|
|
|
child.parent = nil
|
|
child.children = nil
|
|
}
|
|
|
|
// walkNodes calls fn for every node in the trie (depth-first).
|
|
// If fn returns false, the walk stops.
|
|
func walkNodes(root *trieNode, fn func(*trieNode) bool) {
|
|
if root == nil {
|
|
return
|
|
}
|
|
var walk func(*trieNode) bool
|
|
walk = func(n *trieNode) bool {
|
|
if !fn(n) {
|
|
return false
|
|
}
|
|
for _, child := range n.children {
|
|
if !walk(child) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
walk(root)
|
|
}
|