mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 23:54:05 +02:00
Add periodic snapshots every 8k tokens and near the end of the prompt so that long prompts can be partially restored and thinking/generation can be retried without full reprocessing.
939 lines
29 KiB
Go
939 lines
29 KiB
Go
package mlxrunner
|
|
|
|
import (
|
|
"slices"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
// snapshotTracker records every fakeSnapshot created so tests can detect
// leaked (created but never closed) or double-closed snapshots. Close calls
// themselves are counted on each snapshot (see fakeSnapshot.closeCount).
type snapshotTracker struct {
	all []*fakeSnapshot // every snapshot registered via track, in registration order
}

// track registers s with the tracker and stamps s with a back-reference to it.
//
// A nil snapshot is ignored. A nil tracker is tolerated as well (the call is a
// no-op), so a fake cache built without a tracker can still create snapshots
// instead of panicking on the nil receiver.
func (tr *snapshotTracker) track(s *fakeSnapshot) {
	if tr == nil || s == nil {
		return
	}
	s.tracker = tr
	tr.all = append(tr.all, s)
}

// Fake caches that store actual token sequences so tests can verify the right
// data was restored, not just the right offset.

// fakeSnapshot stores a copy of the token sub-sequence it covers.
type fakeSnapshot struct {
	tokens   []int32 // copy of the tokens captured by this snapshot
	from, to int     // offsets bounding the captured range
	byteSize int     // configurable for eviction tests

	tracker    *snapshotTracker // set by snapshotTracker.track
	closeCount int              // number of Close calls observed so far
}

// Size reports the configured byte size of the snapshot.
func (s *fakeSnapshot) Size() int { return s.byteSize }

// Close only counts the call; the leak checker later inspects closeCount to
// flag snapshots that were never closed or were closed more than once.
func (s *fakeSnapshot) Close() { s.closeCount++ }
|
|
|
|
// fakeRewindableCache tracks the full token sequence and supports
|
|
// arbitrary rewind via Restore(nil, target).
|
|
type fakeRewindableCache struct {
|
|
tokens []int32
|
|
tracker *snapshotTracker
|
|
}
|
|
|
|
func (c *fakeRewindableCache) feed(tokens []int32) {
|
|
c.tokens = append(c.tokens, tokens...)
|
|
}
|
|
|
|
func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
return nil, nil
|
|
}
|
|
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
|
|
func (c *fakeRewindableCache) Offset() int { return len(c.tokens) }
|
|
|
|
func (c *fakeRewindableCache) Free() {
|
|
c.tokens = nil
|
|
}
|
|
|
|
func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
|
|
if fromOffset >= len(c.tokens) {
|
|
return nil
|
|
}
|
|
from := fromOffset
|
|
if from < 0 {
|
|
from = 0
|
|
}
|
|
s := &fakeSnapshot{
|
|
tokens: slices.Clone(c.tokens[from:]),
|
|
from: from,
|
|
to: len(c.tokens),
|
|
}
|
|
c.tracker.track(s)
|
|
return s
|
|
}
|
|
|
|
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
|
|
if target < 0 {
|
|
return false
|
|
}
|
|
|
|
if snapshot == nil {
|
|
if target > len(c.tokens) {
|
|
return false
|
|
}
|
|
c.tokens = c.tokens[:target]
|
|
return true
|
|
}
|
|
s := snapshot.(*fakeSnapshot)
|
|
if target > s.to || len(c.tokens) < s.from {
|
|
return false
|
|
}
|
|
c.tokens = append(c.tokens[:s.from], s.tokens...)
|
|
if target < len(c.tokens) {
|
|
c.tokens = c.tokens[:target]
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (c *fakeRewindableCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
|
|
if parent == nil || child == nil {
|
|
if parent != nil {
|
|
parent.Close()
|
|
}
|
|
if child != nil {
|
|
child.Close()
|
|
}
|
|
return nil
|
|
}
|
|
p := parent.(*fakeSnapshot)
|
|
ch := child.(*fakeSnapshot)
|
|
merged := make([]int32, len(p.tokens)+len(ch.tokens))
|
|
copy(merged, p.tokens)
|
|
copy(merged[len(p.tokens):], ch.tokens)
|
|
s := &fakeSnapshot{
|
|
tokens: merged,
|
|
from: p.from,
|
|
to: ch.to,
|
|
byteSize: p.byteSize + ch.byteSize,
|
|
}
|
|
c.tracker.track(s)
|
|
p.Close()
|
|
ch.Close()
|
|
return s
|
|
}
|
|
|
|
func (c *fakeRewindableCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
|
|
if snapshot == nil {
|
|
return nil, nil
|
|
}
|
|
s := snapshot.(*fakeSnapshot)
|
|
relAt := at - s.from
|
|
if relAt <= 0 {
|
|
return nil, snapshot
|
|
}
|
|
if relAt >= len(s.tokens) {
|
|
return snapshot, nil
|
|
}
|
|
p := &fakeSnapshot{
|
|
tokens: slices.Clone(s.tokens[:relAt]),
|
|
from: s.from,
|
|
to: at,
|
|
byteSize: s.byteSize,
|
|
}
|
|
ch := &fakeSnapshot{
|
|
tokens: slices.Clone(s.tokens[relAt:]),
|
|
from: at,
|
|
to: s.to,
|
|
byteSize: s.byteSize,
|
|
}
|
|
c.tracker.track(p)
|
|
c.tracker.track(ch)
|
|
s.Close()
|
|
return p, ch
|
|
}
|
|
|
|
// fakeSlidingWindowCache models RotatingKVCache semantics: stores the full
|
|
// token sequence but only the trailing maxSize tokens are "live" in the window.
|
|
// Once the window fills, live rewind is impossible without a snapshot.
|
|
type fakeSlidingWindowCache struct {
|
|
tokens []int32
|
|
maxSize int
|
|
tracker *snapshotTracker
|
|
}
|
|
|
|
func (c *fakeSlidingWindowCache) feed(tokens []int32) {
|
|
c.tokens = append(c.tokens, tokens...)
|
|
}
|
|
|
|
func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
return nil, nil
|
|
}
|
|
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
|
|
func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) }
|
|
|
|
func (c *fakeSlidingWindowCache) Free() {
|
|
c.tokens = nil
|
|
}
|
|
|
|
func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
|
|
if len(c.tokens) == 0 || len(c.tokens) <= fromOffset {
|
|
return nil
|
|
}
|
|
// Snapshot captures the full window state (like RotatingKVCache.Snapshot).
|
|
s := &fakeSnapshot{
|
|
tokens: slices.Clone(c.tokens),
|
|
from: 0,
|
|
to: len(c.tokens),
|
|
}
|
|
c.tracker.track(s)
|
|
return s
|
|
}
|
|
|
|
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
|
|
if target < 0 {
|
|
return false
|
|
}
|
|
|
|
if snapshot == nil {
|
|
if target >= len(c.tokens) {
|
|
return target == len(c.tokens)
|
|
}
|
|
// Live rewind only works when buffer hasn't filled (offset <= maxSize).
|
|
if len(c.tokens) > c.maxSize {
|
|
return false
|
|
}
|
|
c.tokens = c.tokens[:target]
|
|
return true
|
|
}
|
|
s := snapshot.(*fakeSnapshot)
|
|
if target > s.to {
|
|
return false
|
|
}
|
|
// Reject if clamping would leave an incomplete window
|
|
// (matches RotatingKVCache behavior).
|
|
if target < s.to && s.to > c.maxSize {
|
|
return false
|
|
}
|
|
c.tokens = slices.Clone(s.tokens)
|
|
if target < len(c.tokens) {
|
|
c.tokens = c.tokens[:target]
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (c *fakeSlidingWindowCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
|
|
// Child supersedes parent for sliding window (full window state).
|
|
if parent != nil {
|
|
parent.Close()
|
|
}
|
|
return child
|
|
}
|
|
|
|
func (c *fakeSlidingWindowCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
|
|
// Can't split a ring buffer at an arbitrary point.
|
|
return nil, snapshot
|
|
}
|
|
|
|
// fakeRecurrentCache models RecurrentCache semantics: stores tokens
|
|
// but cannot rewind without a snapshot.
|
|
type fakeRecurrentCache struct {
|
|
tokens []int32
|
|
tracker *snapshotTracker
|
|
}
|
|
|
|
func (c *fakeRecurrentCache) feed(tokens []int32) {
|
|
c.tokens = append(c.tokens, tokens...)
|
|
}
|
|
|
|
func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
return nil, nil
|
|
}
|
|
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
|
|
func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) }
|
|
|
|
func (c *fakeRecurrentCache) Free() {
|
|
c.tokens = nil
|
|
}
|
|
|
|
func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
|
|
// Recurrent state is cumulative; snapshot captures the full state.
|
|
if len(c.tokens) == 0 {
|
|
return nil
|
|
}
|
|
s := &fakeSnapshot{
|
|
tokens: slices.Clone(c.tokens),
|
|
from: 0,
|
|
to: len(c.tokens),
|
|
}
|
|
c.tracker.track(s)
|
|
return s
|
|
}
|
|
|
|
func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
|
|
if snapshot == nil {
|
|
return target == len(c.tokens) // can only no-op
|
|
}
|
|
s := snapshot.(*fakeSnapshot)
|
|
if target != s.to {
|
|
return false // cumulative state requires exact match
|
|
}
|
|
c.tokens = slices.Clone(s.tokens)
|
|
return true
|
|
}
|
|
|
|
func (c *fakeRecurrentCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
|
|
// Child supersedes parent for cumulative state.
|
|
if parent != nil {
|
|
parent.Close()
|
|
}
|
|
return child
|
|
}
|
|
|
|
func (c *fakeRecurrentCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
|
|
return nil, snapshot // can't split cumulative state
|
|
}
|
|
|
|
type feedableCache interface {
|
|
cache.Cache
|
|
feed(tokens []int32)
|
|
}
|
|
|
|
// testEnv encapsulates a kvCache and its fake caches for a test scenario.
|
|
type testEnv struct {
|
|
kvc *kvCache
|
|
caches []cache.Cache // typed references for assertions
|
|
tracker *snapshotTracker
|
|
rewindable bool // true when all caches support arbitrary Restore(nil, target)
|
|
}
|
|
|
|
// newTransformerEnv creates a test environment with a single rewindable cache
|
|
// (pure transformer model).
|
|
func newTransformerEnv() *testEnv {
|
|
tracker := &snapshotTracker{}
|
|
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
|
|
return &testEnv{
|
|
kvc: &kvCache{caches: caches},
|
|
caches: caches,
|
|
tracker: tracker,
|
|
rewindable: true,
|
|
}
|
|
}
|
|
|
|
// newSlidingWindowEnv creates a test environment with one rewindable cache and
|
|
// one sliding window cache (Mistral-style architecture). The sliding window
|
|
// maxSize is set small enough that test sequences fill it, making
|
|
// Restore(nil, target) fail — the same behavior as production models where
|
|
// the window fills after a few turns.
|
|
func newSlidingWindowEnv() *testEnv {
|
|
tr := &snapshotTracker{}
|
|
rc := &fakeRewindableCache{tracker: tr}
|
|
sw := &fakeSlidingWindowCache{maxSize: 4, tracker: tr}
|
|
caches := []cache.Cache{rc, sw}
|
|
return &testEnv{
|
|
kvc: &kvCache{caches: caches},
|
|
caches: caches,
|
|
tracker: tr,
|
|
rewindable: false,
|
|
}
|
|
}
|
|
|
|
// newRecurrentEnv creates a test environment with one rewindable cache and one
|
|
// non-rewindable cache (Jamba-style architecture).
|
|
func newRecurrentEnv() *testEnv {
|
|
tr := &snapshotTracker{}
|
|
rc := &fakeRewindableCache{tracker: tr}
|
|
nrc := &fakeRecurrentCache{tracker: tr}
|
|
caches := []cache.Cache{rc, nrc}
|
|
return &testEnv{
|
|
kvc: &kvCache{caches: caches},
|
|
caches: caches,
|
|
tracker: tr,
|
|
rewindable: false,
|
|
}
|
|
}
|
|
|
|
// assertAllTokens checks that every cache in the environment contains exactly
|
|
// the expected token sequence.
|
|
func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32) {
|
|
t.Helper()
|
|
for i, c := range e.caches {
|
|
assertTokens(t, label, c, expected)
|
|
// Verify all caches report the same offset.
|
|
if i > 0 && c.Offset() != e.caches[0].Offset() {
|
|
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
|
|
label, i, c.Offset(), e.caches[0].Offset())
|
|
}
|
|
}
|
|
}
|
|
|
|
// simulateRequest (defined below) mirrors the production pipeline lifecycle:
// begin -> prefill with snapshot(false) at branch points -> generate -> close.
|
|
|
|
// requestResult is what simulateRequest reports back about a request: values
// read from the session before any tokens are fed.
type requestResult struct {
	remaining        []int32 // copy of the session's remaining tokens
	pendingSnapshots int     // number of snapshots the session had pending
}
|
|
|
|
// simulateRequest runs a request through the harness. If userSnapshotAt > 0,
|
|
// a user snapshot is requested at that offset during prefill.
|
|
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
|
|
t.Helper()
|
|
|
|
session := kvc.begin(nil, inputs)
|
|
for _, at := range userSnapshotAt {
|
|
if at > 0 {
|
|
session.requestSnapshot(at)
|
|
}
|
|
}
|
|
|
|
result := requestResult{
|
|
remaining: slices.Clone(session.remaining),
|
|
pendingSnapshots: len(session.pendingSnapshots),
|
|
}
|
|
|
|
assertCacheOffsetAlignment(t, kvc, "after begin")
|
|
|
|
baseOffset := kvc.minCacheOffset()
|
|
remaining := inputs[baseOffset:]
|
|
|
|
// Prefill: feed tokens, pausing at each pending snapshot.
|
|
for len(session.pendingSnapshots) > 0 {
|
|
sp := session.pendingSnapshots[0]
|
|
count := sp.offset - baseOffset
|
|
if count > len(remaining) {
|
|
break
|
|
}
|
|
if count > 0 {
|
|
feedAll(kvc.caches, remaining[:count])
|
|
remaining = remaining[count:]
|
|
baseOffset = sp.offset
|
|
}
|
|
assertCacheOffsetAlignment(t, kvc, "at snapshot point")
|
|
session.snapshot()
|
|
}
|
|
|
|
// Feed rest of input tokens.
|
|
if len(remaining) > 0 {
|
|
feedAll(kvc.caches, remaining)
|
|
}
|
|
|
|
assertCacheOffsetAlignment(t, kvc, "after prefill")
|
|
|
|
// Generate tokens.
|
|
if len(generated) > 0 {
|
|
session.outputs = generated
|
|
feedAll(kvc.caches, generated)
|
|
}
|
|
|
|
assertCacheOffsetAlignment(t, kvc, "before close")
|
|
session.close()
|
|
return result
|
|
}
|
|
|
|
func feedAll(caches []cache.Cache, tokens []int32) {
|
|
for _, c := range caches {
|
|
if fc, ok := c.(feedableCache); ok {
|
|
fc.feed(tokens)
|
|
}
|
|
}
|
|
}
|
|
|
|
// assertCacheOffsetAlignment verifies all caches report the same offset.
|
|
func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
|
|
t.Helper()
|
|
if len(kvc.caches) < 2 {
|
|
return
|
|
}
|
|
expected := kvc.caches[0].Offset()
|
|
for i := 1; i < len(kvc.caches); i++ {
|
|
if got := kvc.caches[i].Offset(); got != expected {
|
|
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
// assertTokens checks that a feedable cache contains the expected token sequence.
|
|
// For sliding window caches, only the trailing maxSize tokens are checked.
|
|
func assertTokens(t *testing.T, label string, c cache.Cache, expected []int32) {
|
|
t.Helper()
|
|
switch fc := c.(type) {
|
|
case *fakeRewindableCache:
|
|
if !slices.Equal(fc.tokens, expected) {
|
|
t.Errorf("%s: rewindable tokens = %v, want %v", label, fc.tokens, expected)
|
|
}
|
|
case *fakeSlidingWindowCache:
|
|
// Sliding window stores full history but only trailing maxSize are live.
|
|
// Verify the full token sequence matches (the window semantics are
|
|
// enforced by Snapshot/Restore, not by the token log).
|
|
if !slices.Equal(fc.tokens, expected) {
|
|
t.Errorf("%s: sliding window tokens = %v, want %v", label, fc.tokens, expected)
|
|
}
|
|
case *fakeRecurrentCache:
|
|
if !slices.Equal(fc.tokens, expected) {
|
|
t.Errorf("%s: non-rewindable tokens = %v, want %v", label, fc.tokens, expected)
|
|
}
|
|
default:
|
|
t.Fatalf("%s: unknown cache type %T", label, c)
|
|
}
|
|
}
|
|
|
|
// checkTrieInvariants walks the trie and checks structural invariants.
|
|
func checkTrieInvariants(t *testing.T, root *trieNode) {
|
|
t.Helper()
|
|
walkNodes(root, func(n *trieNode) bool {
|
|
if n.parent != nil {
|
|
if n.startOffset() != n.parent.endOffset {
|
|
t.Errorf("node [%d,%d): startOffset %d != parent endOffset %d",
|
|
n.startOffset(), n.endOffset, n.startOffset(), n.parent.endOffset)
|
|
}
|
|
}
|
|
if len(n.tokens) != n.endOffset-n.startOffset() {
|
|
t.Errorf("node [%d,%d): token count %d != offset span %d",
|
|
n.startOffset(), n.endOffset, len(n.tokens), n.endOffset-n.startOffset())
|
|
}
|
|
for _, c := range n.children {
|
|
if c.parent != n {
|
|
t.Errorf("child [%d,%d) parent mismatch", c.startOffset(), c.endOffset)
|
|
}
|
|
}
|
|
// No two siblings should start with the same token.
|
|
seen := make(map[int32]bool)
|
|
for _, c := range n.children {
|
|
if len(c.tokens) > 0 {
|
|
first := c.tokens[0]
|
|
if seen[first] {
|
|
t.Errorf("node [%d,%d): duplicate sibling first token %d",
|
|
n.startOffset(), n.endOffset, first)
|
|
}
|
|
seen[first] = true
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
// checkSnapshotLeaks verifies that every tracked snapshot is either still live
|
|
// in the trie (closeCount == 0) or has been closed exactly once. It reports
|
|
// leaked snapshots (not in trie, never closed) and double-closes.
|
|
func checkSnapshotLeaks(t *testing.T, tracker *snapshotTracker, root *trieNode) {
|
|
t.Helper()
|
|
if tracker == nil {
|
|
return
|
|
}
|
|
|
|
// Collect all live snapshots still referenced by trie nodes.
|
|
live := make(map[*fakeSnapshot]bool)
|
|
walkNodes(root, func(n *trieNode) bool {
|
|
for _, s := range n.snapshots {
|
|
if s != nil {
|
|
if fs, ok := s.(*fakeSnapshot); ok {
|
|
live[fs] = true
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
for i, s := range tracker.all {
|
|
if live[s] {
|
|
if s.closeCount != 0 {
|
|
t.Errorf("snapshot #%d [%d,%d) is still in trie but was closed %d time(s)",
|
|
i, s.from, s.to, s.closeCount)
|
|
}
|
|
} else {
|
|
if s.closeCount == 0 {
|
|
t.Errorf("snapshot #%d [%d,%d) leaked: created but never closed and not in trie",
|
|
i, s.from, s.to)
|
|
} else if s.closeCount > 1 {
|
|
t.Errorf("snapshot #%d [%d,%d) double-closed: closed %d times",
|
|
i, s.from, s.to, s.closeCount)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// forEachEnv runs fn as subtests for three realistic model configurations:
|
|
// pure transformer, transformer + sliding window (Mistral-style), and
|
|
// transformer + recurrent (Jamba-style). Leak checking runs automatically
|
|
// at the end of each subtest.
|
|
func forEachEnv(t *testing.T, fn func(t *testing.T, env *testEnv)) {
|
|
t.Helper()
|
|
run := func(t *testing.T, env *testEnv) {
|
|
t.Cleanup(func() {
|
|
checkSnapshotLeaks(t, env.tracker, env.kvc.root)
|
|
})
|
|
fn(t, env)
|
|
}
|
|
t.Run("Transformer", func(t *testing.T) { run(t, newTransformerEnv()) })
|
|
t.Run("SlidingWindow", func(t *testing.T) { run(t, newSlidingWindowEnv()) })
|
|
t.Run("Recurrent", func(t *testing.T) { run(t, newRecurrentEnv()) })
|
|
}
|
|
|
|
// TestBranchCreationAndReuse exercises the core multi-conversation lifecycle:
|
|
// two conversations share a prefix and diverge, creating a branch point.
|
|
// A third conversation extends the first. Verifies trie structure, cache
|
|
// hit lengths, and that semantic caches contain the correct token sequences.
|
|
func TestBranchCreationAndReuse(t *testing.T) {
|
|
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
|
kvc := env.kvc
|
|
|
|
// Request A: [1,2,3,4,5,6,7,8] + generate [20,21] — full miss.
|
|
resA := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8}, []int32{20, 21})
|
|
if len(resA.remaining) != 8 {
|
|
t.Fatalf("A: remaining = %d, want 8 (full miss)", len(resA.remaining))
|
|
}
|
|
env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
|
|
|
|
// Verify trie was populated by close().
|
|
_, mA := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
|
|
if mA != 10 {
|
|
t.Fatalf("A findable: expected 10 matched, got %d", mA)
|
|
}
|
|
|
|
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
|
|
// For rewindable caches, switchToPath rewinds to the match point
|
|
// so only the non-matching suffix needs evaluation. For non-rewindable
|
|
// caches (RecurrentCache), the rewind fails and freeAll fires.
|
|
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
|
|
if env.rewindable {
|
|
if resB.pendingSnapshots != 0 {
|
|
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
|
|
}
|
|
if len(resB.remaining) != 3 {
|
|
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
|
|
}
|
|
} else {
|
|
if resB.pendingSnapshots != 1 {
|
|
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
|
|
}
|
|
if len(resB.remaining) != 8 {
|
|
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
|
|
}
|
|
}
|
|
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
|
|
|
|
// Both A and B should be findable in the trie.
|
|
_, mA2 := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
|
|
if mA2 < 5 {
|
|
t.Fatalf("A still findable: expected >= 5 matched, got %d", mA2)
|
|
}
|
|
_, mB := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
|
|
if mB < 5 {
|
|
t.Fatalf("B findable: expected >= 5 matched, got %d", mB)
|
|
}
|
|
|
|
// Request C: [1,2,3,4,5,6,7,8,40,41] — extends A's prefix.
|
|
// Should get a cache hit for the shared prefix.
|
|
resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41}, nil)
|
|
if len(resC.remaining) >= 10 {
|
|
t.Fatalf("C: remaining = %d, want < 10 (should get cache hit)", len(resC.remaining))
|
|
}
|
|
env.assertAllTokens(t, "after C", []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41})
|
|
|
|
checkTrieInvariants(t, kvc.root)
|
|
})
|
|
}
|
|
|
|
// TestExactMatchSeedBehavior verifies the holdback mechanism: when the exact
|
|
// same prompt is requested twice, the cache does not overclaim cached work.
|
|
// The last token must be re-evaluated to seed generation.
|
|
func TestExactMatchSeedBehavior(t *testing.T) {
|
|
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
|
kvc := env.kvc
|
|
|
|
// Request A: first time.
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
|
|
|
// Request B: identical prompt. Holdback means matched=4, partial in
|
|
// the 5-token edge. For rewindable caches, switchToPath rewinds to
|
|
// offset 4, so only the held-back token needs re-evaluation. For
|
|
// non-rewindable caches, the rewind fails and freeAll fires.
|
|
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
|
|
if env.rewindable {
|
|
if len(resB.remaining) != 1 {
|
|
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
|
|
}
|
|
if resB.pendingSnapshots != 0 {
|
|
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
|
|
}
|
|
} else {
|
|
if len(resB.remaining) != 5 {
|
|
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
|
|
}
|
|
if resB.pendingSnapshots != 1 {
|
|
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
|
|
}
|
|
}
|
|
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
|
|
|
|
checkTrieInvariants(t, kvc.root)
|
|
})
|
|
}
|
|
|
|
// TestConversationResumption tests the most common pattern: user sends a message,
|
|
// gets a response, then sends a follow-up. The follow-up should reuse the cached
|
|
// prefix (system prompt + first turn + assistant response).
|
|
func TestConversationResumption(t *testing.T) {
|
|
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
|
kvc := env.kvc
|
|
|
|
// Turn 1: system prompt + user message, assistant generates response.
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11, 12})
|
|
env.assertAllTokens(t, "turn 1", []int32{1, 2, 3, 4, 5, 10, 11, 12})
|
|
|
|
// Turn 2: full history + new user message. Should get a cache hit on
|
|
// the prefix [1,2,3,4,5,10,11,12] and only need to evaluate [20,21].
|
|
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21}, []int32{30})
|
|
if len(resB.remaining) > 5 {
|
|
t.Fatalf("turn 2: remaining = %d, want <= 5 (should reuse most of history)", len(resB.remaining))
|
|
}
|
|
env.assertAllTokens(t, "turn 2", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30})
|
|
|
|
// Turn 3: even longer history.
|
|
resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41}, nil)
|
|
if len(resC.remaining) > 5 {
|
|
t.Fatalf("turn 3: remaining = %d, want <= 5", len(resC.remaining))
|
|
}
|
|
env.assertAllTokens(t, "turn 3", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41})
|
|
|
|
checkTrieInvariants(t, kvc.root)
|
|
})
|
|
}
|
|
|
|
// TestEvictionPreservesActiveConversations creates multiple conversations sharing
|
|
// a system prompt, triggers eviction via large snapshot sizes, and verifies the
|
|
// active path and shared prefix survive while memory stays bounded.
|
|
func TestEvictionPreservesActiveConversations(t *testing.T) {
|
|
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
|
kvc := env.kvc
|
|
systemPrompt := []int32{1, 2, 3, 4, 5}
|
|
|
|
// Create 5 conversations with unique suffixes.
|
|
for i := range 5 {
|
|
suffix := []int32{int32(100 + i*10), int32(101 + i*10), int32(102 + i*10)}
|
|
inputs := append(slices.Clone(systemPrompt), suffix...)
|
|
simulateRequest(t, kvc, inputs, []int32{int32(200 + i)})
|
|
}
|
|
|
|
// Inflate snapshot sizes to trigger eviction.
|
|
walkNodes(kvc.root, func(n *trieNode) bool {
|
|
if !n.hasSnapshots() {
|
|
return true
|
|
}
|
|
snaps := make([]cache.Snapshot, len(n.snapshots))
|
|
for i, s := range n.snapshots {
|
|
if s != nil {
|
|
snaps[i] = &fakeSnapshot{byteSize: 2 * 1024 * 1024 * 1024} // 2 GiB per snapshot
|
|
}
|
|
}
|
|
n.setSnapshots(snaps, &kvc.pagedOutBytes)
|
|
return true
|
|
})
|
|
|
|
// Run eviction.
|
|
kvc.enforceEvictionPolicy()
|
|
|
|
// Memory should be within limits.
|
|
if kvc.pagedOutBytes > maxPagedOutBytes {
|
|
t.Fatalf("pagedOutBytes = %d, want <= %d", kvc.pagedOutBytes, maxPagedOutBytes)
|
|
}
|
|
|
|
// Active path should be untouched.
|
|
if len(kvc.activePath) < 2 {
|
|
t.Fatalf("activePath should have >= 2 nodes, got %d", len(kvc.activePath))
|
|
}
|
|
|
|
// System prompt prefix should still be findable (multi-child
|
|
// branch points are protected from eviction entirely).
|
|
_, matched := findBestMatch(kvc.root, systemPrompt)
|
|
if matched < len(systemPrompt) {
|
|
t.Fatalf("system prompt match = %d, want %d", matched, len(systemPrompt))
|
|
}
|
|
|
|
checkTrieInvariants(t, kvc.root)
|
|
})
|
|
}
|
|
|
|
// TestUserSnapshotPreservesRestorePoint verifies that user-created snapshots
|
|
// (snapshot(true)) resist structural changes that would destroy them:
|
|
// - A user node forces new tokens into a child instead of extending in-place
|
|
// - The snapshot remains restorable after other branches are added
|
|
func TestUserSnapshotPreservesRestorePoint(t *testing.T) {
|
|
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
|
kvc := env.kvc
|
|
|
|
// Request A: user snapshot at offset 5, then generate.
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}, 5)
|
|
|
|
assertUserNodeExists(t, kvc, "after A")
|
|
|
|
// Request B: extends A's prefix. The user node at offset 5 should
|
|
// force tokens into a child rather than extending in-place.
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21}, nil)
|
|
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21})
|
|
assertUserNodeExists(t, kvc, "after B")
|
|
|
|
// Request C: diverge from the user node.
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 30, 31}, []int32{40})
|
|
|
|
// Request D: switch back to A's branch — user snapshot still restorable.
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50}, nil)
|
|
env.assertAllTokens(t, "back to A", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50})
|
|
|
|
checkTrieInvariants(t, kvc.root)
|
|
})
|
|
}
|
|
|
|
// TestUserSnapshotResistsAutoMerge verifies that when a sibling leaf is evicted,
|
|
// a user-marked parent node is not auto-merged with its remaining single child.
|
|
func TestUserSnapshotResistsAutoMerge(t *testing.T) {
|
|
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
|
kvc := env.kvc
|
|
|
|
// Request A: user snapshot at offset 3, then continue to offset 5.
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10}, 3)
|
|
|
|
// Request B: diverges at the user node, creating a second child.
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20})
|
|
|
|
userNode := findUserNode(t, kvc)
|
|
if len(userNode.children) != 2 {
|
|
t.Fatalf("user node children = %d, want 2", len(userNode.children))
|
|
}
|
|
|
|
// Inflate snapshot sizes and evict. The non-active branch should be
|
|
// evicted, leaving the user node with one child.
|
|
walkNodes(kvc.root, func(n *trieNode) bool {
|
|
if !n.hasSnapshots() {
|
|
return true
|
|
}
|
|
snaps := make([]cache.Snapshot, len(n.snapshots))
|
|
for i, s := range n.snapshots {
|
|
if s != nil {
|
|
snaps[i] = &fakeSnapshot{byteSize: 5 * 1024 * 1024 * 1024}
|
|
}
|
|
}
|
|
n.setSnapshots(snaps, &kvc.pagedOutBytes)
|
|
return true
|
|
})
|
|
kvc.enforceEvictionPolicy()
|
|
|
|
// The user node should still exist (not auto-merged) even with one child.
|
|
assertUserNodeExists(t, kvc, "after eviction")
|
|
|
|
checkTrieInvariants(t, kvc.root)
|
|
})
|
|
}
|
|
|
|
func findUserNode(t *testing.T, kvc *kvCache) *trieNode {
|
|
t.Helper()
|
|
var found *trieNode
|
|
walkNodes(kvc.root, func(n *trieNode) bool {
|
|
if n.user {
|
|
found = n
|
|
}
|
|
return true
|
|
})
|
|
if found == nil {
|
|
t.Fatal("no user-marked node found")
|
|
}
|
|
return found
|
|
}
|
|
|
|
func assertUserNodeExists(t *testing.T, kvc *kvCache, label string) {
|
|
t.Helper()
|
|
var exists bool
|
|
walkNodes(kvc.root, func(n *trieNode) bool {
|
|
if n.user {
|
|
exists = true
|
|
}
|
|
return true
|
|
})
|
|
if !exists {
|
|
t.Fatalf("%s: no user-marked node found", label)
|
|
}
|
|
}
|
|
|
|
// TestBranchSwitchRestoresCorrectState exercises switching back to an older
|
|
// branch after working on a different one, verifying that the restored cache
|
|
// state contains the correct token sequence for both rewindable and
|
|
// non-rewindable caches.
|
|
func TestBranchSwitchRestoresCorrectState(t *testing.T) {
|
|
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
|
kvc := env.kvc
|
|
|
|
// Request A: [1,2,3,4,5] + generate [10,11]
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
|
env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 10, 11})
|
|
|
|
// Request B: [1,2,3,6,7] — diverges at token 4
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{12, 13})
|
|
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 6, 7, 12, 13})
|
|
|
|
// Request C: switch back to A's branch [1,2,3,4,5,10,11,20]
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20}, nil)
|
|
env.assertAllTokens(t, "after C (back to A)", []int32{1, 2, 3, 4, 5, 10, 11, 20})
|
|
|
|
checkTrieInvariants(t, kvc.root)
|
|
})
|
|
}
|
|
|
|
// TestLRUOnlyUpdatesUsedNodes verifies that intermediate nodes on the active
|
|
// path whose snapshots were not actually restored don't get their lastUsed
|
|
// refreshed, allowing them to age out and collapse.
|
|
func TestLRUOnlyUpdatesUsedNodes(t *testing.T) {
|
|
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
|
kvc := env.kvc
|
|
|
|
// Request A: creates path [1,2,3,4,5] + generate [10,11]
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
|
|
|
// Request B: diverges at token 4, creating a branch point at offset 3
|
|
// with a split snapshot.
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20, 21})
|
|
|
|
// Set all lastUsed to a known old time.
|
|
oldTime := time.Now().Add(-1 * time.Hour)
|
|
walkNodes(kvc.root, func(n *trieNode) bool {
|
|
n.lastUsed = oldTime
|
|
return true
|
|
})
|
|
|
|
// Request C: continue on B's branch. This will match B's path
|
|
// and extend it. The branch point's snapshot may be paged in
|
|
// for some cache types but not others.
|
|
beforeRequest := time.Now()
|
|
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7, 20, 21, 30}, nil)
|
|
|
|
// The path must have enough depth to exercise intermediate nodes.
|
|
if len(kvc.activePath) < 3 {
|
|
t.Fatalf("activePath too short to test intermediate nodes: got %d nodes", len(kvc.activePath))
|
|
}
|
|
|
|
// The frontier (deepest node on the active path) must be updated.
|
|
frontier := kvc.activePath[len(kvc.activePath)-1]
|
|
if frontier.lastUsed.Before(beforeRequest) {
|
|
t.Errorf("frontier lastUsed was not updated: got %v, want >= %v",
|
|
frontier.lastUsed, beforeRequest)
|
|
}
|
|
|
|
// Every non-frontier node on the active path (including root)
|
|
// should retain its old lastUsed — only the frontier gets refreshed.
|
|
for i, node := range kvc.activePath[:len(kvc.activePath)-1] {
|
|
if !node.lastUsed.Before(beforeRequest) {
|
|
t.Errorf("activePath[%d] (endOffset=%d) lastUsed was refreshed: got %v, want < %v",
|
|
i, node.endOffset, node.lastUsed, beforeRequest)
|
|
}
|
|
}
|
|
|
|
checkTrieInvariants(t, kvc.root)
|
|
})
|
|
}
|