Files
ollama/x/mlxrunner/cache_trie_test.go
Jesse Gross 96e36c0d90 mlxrunner: share KV cache across conversations with common prefixes
Enable multiple conversations to reuse cached computations when they
share token prefixes (e.g. the same system prompt). A prefix trie
tracks shared regions so switching between conversations only
recomputes tokens that diverge. Inactive conversation state is paged
from active GPU memory to other memory and restored on demand, with LRU
eviction to keep memory usage bounded.
2026-03-18 16:06:33 -07:00

456 lines
12 KiB
Go

package mlxrunner
import (
"slices"
"testing"
"time"
"github.com/ollama/ollama/x/mlxrunner/cache"
)
// newTestTrie returns a fresh root node for tests. When tokens is non-empty the
// root gets exactly one child that holds a copy of tokens and ends at offset
// len(tokens); otherwise the root has no children.
func newTestTrie(tokens []int32) *trieNode {
	root := &trieNode{lastUsed: time.Now()}
	if len(tokens) == 0 {
		return root
	}

	// Clone so later mutation of the caller's slice cannot leak into the trie.
	root.children = []*trieNode{{
		tokens:    slices.Clone(tokens),
		endOffset: len(tokens),
		parent:    root,
		lastUsed:  time.Now(),
	}}
	return root
}
// TestFindBestMatchMultipleBranches builds a root with two disjoint children
// and checks that findBestMatch follows the child whose tokens prefix the
// input, and reports zero matched tokens when neither child applies.
func TestFindBestMatchMultipleBranches(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}

	// newBranch creates a direct child of root covering the given tokens.
	newBranch := func(tokens ...int32) *trieNode {
		return &trieNode{
			tokens:    tokens,
			endOffset: len(tokens),
			parent:    root,
			lastUsed:  time.Now(),
		}
	}
	first := newBranch(1, 2, 3)
	second := newBranch(4, 5, 6)
	root.children = []*trieNode{first, second}

	// Inputs that extend one of the branches by a single unknown token.
	hits := []struct {
		input   []int32
		want    *trieNode
		failMsg string
	}{
		{[]int32{1, 2, 3, 7}, first, "expected to match branch1"},
		{[]int32{4, 5, 6, 8}, second, "expected to match branch2"},
	}
	for _, tc := range hits {
		path, n := findBestMatch(root, tc.input)
		if n != 3 {
			t.Fatalf("expected 3 matched, got %d", n)
		}
		if len(path) != 2 || path[1] != tc.want {
			t.Fatal(tc.failMsg)
		}
	}

	// An input sharing no prefix with either branch matches nothing.
	if _, n := findBestMatch(root, []int32{7, 8, 9}); n != 0 {
		t.Fatalf("expected 0 matched, got %d", n)
	}
}
// TestFindBestMatchPrefersFullEdge checks that, between two sibling edges that
// share the same leading tokens, findBestMatch picks the one whose tokens are
// matched in full by the input rather than the longer one matched only in part.
func TestFindBestMatchPrefersFullEdge(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	shared := &trieNode{tokens: []int32{1, 2, 3}, endOffset: 3, parent: root, lastUsed: time.Now()}
	root.children = []*trieNode{shared}

	partial := &trieNode{
		tokens:    []int32{10, 11, 12, 13, 14},
		endOffset: 8,
		parent:    shared,
		lastUsed:  time.Now(),
	}
	full := &trieNode{
		tokens:    []int32{10, 11, 12},
		endOffset: 6,
		parent:    shared,
		lastUsed:  time.Now(),
	}
	// The partially matching edge goes first so that a naive first-match
	// strategy would select it.
	shared.children = []*trieNode{partial, full}

	path, matched := findBestMatch(root, []int32{1, 2, 3, 10, 11, 12, 99, 100})

	// Cases are evaluated top to bottom, so path[2] is only read once the
	// path length has been confirmed.
	switch {
	case matched != 6:
		t.Fatalf("expected 6 matched, got %d", matched)
	case len(path) != 3:
		t.Fatalf("expected 3 nodes in path, got %d", len(path))
	case path[2] != full:
		t.Fatal("expected findBestMatch to pick shorter (full edge match), not longer (partial)")
	}
}
// TestFindBestMatchPrefersLongerPartial checks that when two sibling edges
// both match the input only partially, findBestMatch selects the edge that
// shares the longer prefix with the input, even though the edge with the
// shorter shared prefix is listed first among the children.
func TestFindBestMatchPrefersLongerPartial(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	// child1 shares three leading tokens with the input below.
	child1 := &trieNode{
		tokens:    []int32{1, 2, 3, 4, 5},
		endOffset: 5,
		parent:    root,
		lastUsed:  time.Now(),
	}
	// child2 shares only two leading tokens with the input below.
	child2 := &trieNode{
		tokens:    []int32{1, 2, 9},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	root.children = []*trieNode{child2, child1}

	input := []int32{1, 2, 3, 7, 8}
	path, matched := findBestMatch(root, input)
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	// Guard the index below so a short path is reported as a test failure
	// instead of an index-out-of-range panic.
	if len(path) < 2 {
		t.Fatalf("expected at least 2 nodes in path, got %d", len(path))
	}
	if path[1] != child1 {
		t.Fatal("expected findBestMatch to pick child1 (longer partial match)")
	}
}
// TestSplitNodeWithSnapshots splits a user-flagged node that carries a snapshot
// and checks that both halves end up with snapshots while only the original
// (now lower) node keeps the user flag.
func TestSplitNodeWithSnapshots(t *testing.T) {
	seq := []int32{1, 2, 3, 4, 5}
	root := newTestTrie(seq)
	lower := root.children[0]

	rc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: seq}
	lower.snapshots = []cache.Snapshot{rc.Snapshot(0)}
	lower.user = true

	upper := splitNode(lower, 3, []cache.Cache{rc}, nil)

	// Cases run in order and stop at the first failure.
	switch {
	case !upper.hasSnapshots():
		t.Fatal("newParent should have snapshots after split")
	case upper.user:
		t.Fatal("newParent should not be a user snapshot after splitNode")
	case !lower.hasSnapshots():
		t.Fatal("child should have snapshots after split")
	case !lower.user:
		t.Fatal("child should remain a user snapshot")
	}
}
// TestFindSplitAppendSequence exercises the find -> split -> append sequence:
// it looks up a diverging input, splits the last matched node at the
// divergence point, appends the new suffix, and then checks the resulting
// trie shape and lookups for the old branch, the new branch and an unrelated
// suffix.
func TestFindSplitAppendSequence(t *testing.T) {
	root := newTestTrie([]int32{1, 2, 3, 4, 5})

	path, matched := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}

	// Translate the absolute match length into an offset within the last edge.
	tail := path[len(path)-1]
	branchPoint := splitNode(tail, matched-tail.startOffset(), nil, nil)
	branchPoint.appendTokens(root, []int32{6, 7}, 5)

	if n := len(root.children); n != 1 {
		t.Fatalf("root should have 1 child, got %d", n)
	}
	prefix := root.children[0]
	if !slices.Equal(prefix.tokens, []int32{1, 2, 3}) {
		t.Fatalf("shared tokens = %v, want [1,2,3]", prefix.tokens)
	}
	if n := len(prefix.children); n != 2 {
		t.Fatalf("shared should have 2 children, got %d", n)
	}

	lookups := []struct {
		input  []int32
		want   int
		format string
	}{
		{[]int32{1, 2, 3, 4, 5}, 5, "original branch: expected 5 matched, got %d"},
		{[]int32{1, 2, 3, 6, 7}, 5, "new branch: expected 5 matched, got %d"},
		{[]int32{1, 2, 3, 9, 9}, 3, "unrelated input: expected 3 matched, got %d"},
	}
	for _, lk := range lookups {
		if _, got := findBestMatch(root, lk.input); got != lk.want {
			t.Fatalf(lk.format, got)
		}
	}
}
// TestRepeatedBranching grows a trie from sequence A, branches B off it at
// offset 3, then branches C off at offset 2 by splitting the node produced by
// the first split. It checks that all three sequences are then matched in full
// and that the trie invariants hold.
func TestRepeatedBranching(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	root.appendTokens(root, []int32{1, 2, 3, 4, 5}, 5)

	// Branch B diverges from A after three tokens.
	if _, got := findBestMatch(root, []int32{1, 2, 3, 6, 7}); got != 3 {
		t.Fatalf("B: expected 3 matched, got %d", got)
	}
	atThree := splitNode(root.children[0], 3, nil, nil)
	atThree.appendTokens(root, []int32{6, 7}, 5)

	// Branch C diverges after two tokens, inside the node created above.
	if _, got := findBestMatch(root, []int32{1, 2, 8, 9}); got != 2 {
		t.Fatalf("C: expected 2 matched, got %d", got)
	}
	atTwo := splitNode(atThree, 2, nil, nil)
	atTwo.appendTokens(root, []int32{8, 9}, 4)

	full := []struct {
		input  []int32
		want   int
		format string
	}{
		{[]int32{1, 2, 3, 4, 5}, 5, "A: expected 5 matched, got %d"},
		{[]int32{1, 2, 3, 6, 7}, 5, "B: expected 5 matched, got %d"},
		{[]int32{1, 2, 8, 9}, 4, "C: expected 4 matched, got %d"},
	}
	for _, tc := range full {
		if _, got := findBestMatch(root, tc.input); got != tc.want {
			t.Fatalf(tc.format, got)
		}
	}

	checkTrieInvariants(t, root)
}
// TestMergeWithChild covers mergeWithChild: token/offset concatenation,
// reparenting of grandchildren, detachment of the absorbed child, snapshot
// merging, propagation of the user flag and lastUsed time, and the panic when
// the node does not have exactly one child.
func TestMergeWithChild(t *testing.T) {
	t.Run("Basic", func(t *testing.T) {
		// root -> A[1,2,3] -> B[4,5] -> {C[6], D[7]}
		now := time.Now()
		root := &trieNode{lastUsed: now}
		a := &trieNode{
			tokens:    []int32{1, 2, 3},
			endOffset: 3,
			parent:    root,
			lastUsed:  now,
			snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3}, from: 0, to: 3}},
		}
		b := &trieNode{
			tokens:    []int32{4, 5},
			endOffset: 5,
			parent:    a,
			lastUsed:  now,
			snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{4, 5}, from: 3, to: 5}},
		}
		c := &trieNode{tokens: []int32{6}, endOffset: 6, parent: b, lastUsed: now}
		d := &trieNode{tokens: []int32{7}, endOffset: 6, parent: b, lastUsed: now}
		root.children = []*trieNode{a}
		a.children = []*trieNode{b}
		b.children = []*trieNode{c, d}

		mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
		mergeWithChild(a, []cache.Cache{mc}, nil)

		// Tokens concatenated.
		if !slices.Equal(a.tokens, []int32{1, 2, 3, 4, 5}) {
			t.Fatalf("merged tokens = %v, want [1,2,3,4,5]", a.tokens)
		}
		if a.endOffset != 5 {
			t.Fatalf("merged endOffset = %d, want 5", a.endOffset)
		}

		// Grandchildren reparented.
		if len(a.children) != 2 {
			t.Fatalf("merged children count = %d, want 2", len(a.children))
		}
		if c.parent != a || d.parent != a {
			t.Fatal("grandchildren should be reparented to merged node")
		}

		// B detached.
		if b.parent != nil || b.children != nil || b.snapshots != nil {
			t.Fatal("child B should be fully detached after merge")
		}

		// Merged snapshot should cover [0,5). The length and type are checked
		// explicitly so an unexpected result fails the test instead of
		// panicking on the index or the type assertion.
		if !a.hasSnapshots() || len(a.snapshots) == 0 {
			t.Fatal("merged node should have snapshots")
		}
		ms, ok := a.snapshots[0].(*fakeSnapshot)
		if !ok || ms == nil {
			t.Fatalf("merged snapshot has type %T, want non-nil *fakeSnapshot", a.snapshots[0])
		}
		if ms.from != 0 || ms.to != 5 {
			t.Fatalf("merged snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
		}

		checkTrieInvariants(t, root)
	})

	t.Run("UserFlag", func(t *testing.T) {
		root := &trieNode{lastUsed: time.Now()}
		parent := &trieNode{
			tokens: []int32{1, 2}, endOffset: 2, parent: root,
			lastUsed: time.Now(), user: false,
		}
		child := &trieNode{
			tokens: []int32{3, 4}, endOffset: 4, parent: parent,
			lastUsed: time.Now(), user: true,
		}
		root.children = []*trieNode{parent}
		parent.children = []*trieNode{child}

		mergeWithChild(parent, nil, nil)

		if !parent.user {
			t.Fatal("merged node should inherit user=true from child")
		}
	})

	t.Run("LastUsed", func(t *testing.T) {
		now := time.Now()
		root := &trieNode{lastUsed: now}
		// The child is newer than the parent by two hours.
		parent := &trieNode{
			tokens: []int32{1}, endOffset: 1, parent: root,
			lastUsed: now.Add(-1 * time.Hour),
		}
		child := &trieNode{
			tokens: []int32{2}, endOffset: 2, parent: parent,
			lastUsed: now.Add(1 * time.Hour),
		}
		root.children = []*trieNode{parent}
		parent.children = []*trieNode{child}

		mergeWithChild(parent, nil, nil)

		if !parent.lastUsed.Equal(now.Add(1 * time.Hour)) {
			t.Fatal("merged node should pick the more recent lastUsed")
		}
	})

	t.Run("PanicOnMultipleChildren", func(t *testing.T) {
		// The deferred recover turns a missing panic into a test failure.
		defer func() {
			if r := recover(); r == nil {
				t.Fatal("expected panic on node with 2 children")
			}
		}()
		root := &trieNode{lastUsed: time.Now()}
		node := &trieNode{
			tokens: []int32{1}, endOffset: 1, parent: root, lastUsed: time.Now(),
			children: []*trieNode{
				{tokens: []int32{2}, endOffset: 2, lastUsed: time.Now()},
				{tokens: []int32{3}, endOffset: 2, lastUsed: time.Now()},
			},
		}
		root.children = []*trieNode{node}
		mergeWithChild(node, nil, nil)
	})
}
// TestSplitMergeRoundTrip splits a five-token leaf at offset 3 and then merges
// the two halves back together, checking tokens, offsets, children, the merged
// snapshot range and the trie invariants along the way.
func TestSplitMergeRoundTrip(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	leaf := &trieNode{
		tokens:    []int32{1, 2, 3, 4, 5},
		endOffset: 5,
		parent:    root,
		lastUsed:  time.Now(),
		snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3, 4, 5}, from: 0, to: 5}},
	}
	root.children = []*trieNode{leaf}

	mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
	caches := []cache.Cache{mc}

	// Split at 3: [1,2,3] -> [4,5]
	newParent := splitNode(leaf, 3, caches, nil)
	if !slices.Equal(newParent.tokens, []int32{1, 2, 3}) {
		t.Fatalf("after split: parent tokens = %v, want [1,2,3]", newParent.tokens)
	}
	if !slices.Equal(leaf.tokens, []int32{4, 5}) {
		t.Fatalf("after split: child tokens = %v, want [4,5]", leaf.tokens)
	}
	checkTrieInvariants(t, root)

	// Merge back: should restore [1,2,3,4,5]
	mergeWithChild(newParent, caches, nil)
	if !slices.Equal(newParent.tokens, []int32{1, 2, 3, 4, 5}) {
		t.Fatalf("after merge: tokens = %v, want [1,2,3,4,5]", newParent.tokens)
	}
	if newParent.endOffset != 5 {
		t.Fatalf("after merge: endOffset = %d, want 5", newParent.endOffset)
	}
	if len(newParent.children) != 0 {
		t.Fatalf("after merge: children count = %d, want 0", len(newParent.children))
	}

	// Merged snapshot should cover [0,5). The length and type are checked
	// explicitly so an unexpected result fails the test instead of panicking
	// on the index or the type assertion.
	if !newParent.hasSnapshots() || len(newParent.snapshots) == 0 {
		t.Fatal("after merge: should have snapshots")
	}
	ms, ok := newParent.snapshots[0].(*fakeSnapshot)
	if !ok || ms == nil {
		t.Fatalf("after merge: snapshot has type %T, want non-nil *fakeSnapshot", newParent.snapshots[0])
	}
	if ms.from != 0 || ms.to != 5 {
		t.Fatalf("after merge: snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
	}
	checkTrieInvariants(t, root)
}
// TestRemoveNode covers removeNode: removing one of two sibling leaves, and
// the panics expected when asked to remove a parentless node or a node that
// still has children.
func TestRemoveNode(t *testing.T) {
	// mustPanic runs fn and fails the test with msg unless fn panics.
	mustPanic := func(t *testing.T, msg string, fn func()) {
		defer func() {
			if recover() == nil {
				t.Fatal(msg)
			}
		}()
		fn()
	}

	t.Run("Leaf", func(t *testing.T) {
		root := &trieNode{lastUsed: time.Now()}
		shared := &trieNode{
			tokens:    []int32{1, 2, 3},
			endOffset: 3,
			parent:    root,
			lastUsed:  time.Now(),
		}
		// newLeaf creates a two-token leaf under shared with one snapshot
		// spanning [3,5).
		newLeaf := func(tokens ...int32) *trieNode {
			return &trieNode{
				tokens:    tokens,
				endOffset: 5,
				parent:    shared,
				lastUsed:  time.Now(),
				snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
			}
		}
		victim := newLeaf(4, 5)
		survivor := newLeaf(6, 7)
		root.children = []*trieNode{shared}
		shared.children = []*trieNode{victim, survivor}

		removeNode(victim, nil)

		// Cases run in order and stop at the first failure.
		switch {
		case len(shared.children) != 1:
			t.Fatalf("parent should have 1 child, got %d", len(shared.children))
		case shared.children[0] != survivor:
			t.Fatal("remaining child should be leafB")
		case victim.parent != nil:
			t.Fatal("removed node parent should be nil")
		case victim.snapshots != nil:
			t.Fatal("removed node snapshots should be nil")
		}
		checkTrieInvariants(t, root)
	})

	t.Run("PanicOnRoot", func(t *testing.T) {
		mustPanic(t, "expected panic when removing root", func() {
			removeNode(&trieNode{}, nil)
		})
	})

	t.Run("PanicOnNonLeaf", func(t *testing.T) {
		mustPanic(t, "expected panic when removing non-leaf", func() {
			inner := &trieNode{parent: &trieNode{}}
			inner.children = []*trieNode{{}}
			removeNode(inner, nil)
		})
	})
}