package mlxrunner

import (
	"slices"
	"testing"

	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

// snapshotTracker records every fakeSnapshot created by the fake caches so
// tests can detect leaked (created but never closed) or double-closed
// snapshots. Close calls themselves are counted on each fakeSnapshot.
type snapshotTracker struct {
	all []*fakeSnapshot
}

// track registers s with the tracker and links s back to it.
//
// It is a no-op when either the tracker or the snapshot is nil, so a fake
// cache constructed without a tracker can still produce snapshots instead of
// panicking on a nil receiver.
func (tr *snapshotTracker) track(s *fakeSnapshot) {
	if tr == nil || s == nil {
		return
	}
	s.tracker = tr
	tr.all = append(tr.all, s)
}

// Fake caches that store actual token sequences so tests can verify the right
// data was restored, not just the right offset.

// fakeSnapshot stores a copy of the token sub-sequence it covers.
type fakeSnapshot struct {
	tokens     []int32
	from, to   int // offsets covered by tokens: [from, to)
	byteSize   int // configurable for eviction tests
	tracker    *snapshotTracker
	closeCount int // number of Close calls observed
}

// Size reports the configured byte size of the snapshot.
func (s *fakeSnapshot) Size() int { return s.byteSize }

// Close only counts the call so leak checks can inspect it later.
func (s *fakeSnapshot) Close() { s.closeCount++ }

// fakeRewindableCache tracks the full token sequence and supports
// arbitrary rewind via Restore(nil, target).
type fakeRewindableCache struct {
	tokens  []int32
	tracker *snapshotTracker
}

// feed appends tokens to the recorded sequence.
func (c *fakeRewindableCache) feed(tokens []int32) {
	c.tokens = append(c.tokens, tokens...)
}

// Update is a stub; tests drive this fake through feed instead.
func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	return nil, nil
}

// State is a stub that returns no arrays.
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }

// Offset reports how many tokens have been recorded.
func (c *fakeRewindableCache) Offset() int { return len(c.tokens) }

// Free drops the recorded tokens.
func (c *fakeRewindableCache) Free() { c.tokens = nil }

// Snapshot copies the tokens from fromOffset (clamped to 0) up to the current
// end into a tracked fakeSnapshot. It returns nil when fromOffset is at or
// past the end of the recorded sequence.
func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
	if fromOffset >= len(c.tokens) {
		return nil
	}
	from := max(fromOffset, 0)
	s := &fakeSnapshot{
		tokens: slices.Clone(c.tokens[from:]),
		from:   from,
		to:     len(c.tokens),
	}
	c.tracker.track(s)
	return s
}

// Restore rewinds the live sequence (nil snapshot) or splices the snapshot's
// tokens back in at its starting offset and then truncates to target. It
// returns false only when the live sequence is shorter than the snapshot's
// starting offset.
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
	if snapshot == nil {
		// Rewind live state, clamping target into [0, len(tokens)].
		target = min(max(target, 0), len(c.tokens))
		c.tokens = c.tokens[:target]
		return true
	}
	s := snapshot.(*fakeSnapshot)
	if len(c.tokens) < s.from {
		return false // don't have base data up to snapshot start
	}
	c.tokens = append(c.tokens[:s.from], s.tokens...)
	if target < len(c.tokens) {
		c.tokens = c.tokens[:target]
	}
	return true
}

// Merge concatenates parent and child into a new tracked snapshot and closes
// both inputs. If either input is nil, the other one (if any) is closed and
// nil is returned.
func (c *fakeRewindableCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
	if parent == nil || child == nil {
		if parent != nil {
			parent.Close()
		}
		if child != nil {
			child.Close()
		}
		return nil
	}
	p := parent.(*fakeSnapshot)
	ch := child.(*fakeSnapshot)
	merged := make([]int32, len(p.tokens)+len(ch.tokens))
	copy(merged, p.tokens)
	copy(merged[len(p.tokens):], ch.tokens)
	s := &fakeSnapshot{
		tokens:   merged,
		from:     p.from,
		to:       ch.to,
		byteSize: p.byteSize + ch.byteSize,
	}
	c.tracker.track(s)
	p.Close()
	ch.Close()
	return s
}

// Split divides snapshot at absolute offset at. When at falls at or before
// the snapshot's start, or at or past its end, the snapshot is returned whole
// on the matching side and nothing is closed. Otherwise two new tracked
// snapshots are returned (each keeping the original byteSize) and the input
// is closed.
func (c *fakeRewindableCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
	if snapshot == nil {
		return nil, nil
	}
	s := snapshot.(*fakeSnapshot)
	relAt := at - s.from
	if relAt <= 0 {
		return nil, snapshot
	}
	if relAt >= len(s.tokens) {
		return snapshot, nil
	}
	p := &fakeSnapshot{
		tokens:   slices.Clone(s.tokens[:relAt]),
		from:     s.from,
		to:       at,
		byteSize: s.byteSize,
	}
	ch := &fakeSnapshot{
		tokens:   slices.Clone(s.tokens[relAt:]),
		from:     at,
		to:       s.to,
		byteSize: s.byteSize,
	}
	c.tracker.track(p)
	c.tracker.track(ch)
	s.Close()
	return p, ch
}

// fakeSlidingWindowCache models RotatingKVCache semantics: stores the full
// token sequence but only the trailing maxSize tokens are "live" in the window.
// Once the window fills, live rewind is impossible without a snapshot.
type fakeSlidingWindowCache struct {
	tokens  []int32
	maxSize int
	tracker *snapshotTracker
}

// feed appends tokens to the recorded sequence.
func (c *fakeSlidingWindowCache) feed(tokens []int32) {
	c.tokens = append(c.tokens, tokens...)
}

// Update is a stub; tests drive this fake through feed instead.
func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	return nil, nil
}

// State is a stub that returns no arrays.
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }

// Offset reports how many tokens have been recorded.
func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) }

// Free drops the recorded tokens.
func (c *fakeSlidingWindowCache) Free() { c.tokens = nil }

// Snapshot returns nil when the cache is empty or fromOffset is at or past
// the end; otherwise it copies the whole recorded sequence, starting at 0.
func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
	if len(c.tokens) == 0 || len(c.tokens) <= fromOffset {
		return nil
	}
	// Snapshot captures the full window state (like RotatingKVCache.Snapshot).
	s := &fakeSnapshot{
		tokens: slices.Clone(c.tokens),
		from:   0,
		to:     len(c.tokens),
	}
	c.tracker.track(s)
	return s
}

// Restore replaces the sequence with the snapshot's tokens truncated to
// target, or, without a snapshot, rewinds in place as long as the window has
// not filled.
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
	if snapshot == nil {
		if target == len(c.tokens) {
			return true
		}
		// Live rewind only works when buffer hasn't filled (offset <= maxSize).
		if len(c.tokens) > c.maxSize {
			return false
		}
		// NOTE: target is not clamped here, unlike fakeRewindableCache.Restore.
		c.tokens = c.tokens[:target]
		return true
	}
	s := snapshot.(*fakeSnapshot)
	c.tokens = slices.Clone(s.tokens)
	if target < len(c.tokens) {
		c.tokens = c.tokens[:target]
	}
	return true
}

// Merge closes parent (if any) and returns child unchanged.
func (c *fakeSlidingWindowCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
	// Child supersedes parent for sliding window (full window state).
	if parent != nil {
		parent.Close()
	}
	return child
}

// Split never splits: the snapshot is returned whole as the second result.
func (c *fakeSlidingWindowCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
	// Can't split a ring buffer at an arbitrary point.
	return nil, snapshot
}

// fakeRecurrentCache models RecurrentCache semantics: stores tokens
// but cannot rewind without a snapshot.
type fakeRecurrentCache struct {
	tokens  []int32
	tracker *snapshotTracker
}

// feed appends tokens to the recorded sequence.
func (c *fakeRecurrentCache) feed(tokens []int32) {
	c.tokens = append(c.tokens, tokens...)
}

// Update is a stub; tests drive this fake through feed instead.
func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	return nil, nil
}

// State is a stub that returns no arrays.
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }

// Offset reports how many tokens have been recorded.
func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) }

// Free drops the recorded tokens.
func (c *fakeRecurrentCache) Free() { c.tokens = nil }

// Snapshot copies the whole recorded sequence regardless of fromOffset and
// returns nil only when nothing has been recorded.
func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
	// Recurrent state is cumulative; snapshot captures the full state.
	if len(c.tokens) == 0 {
		return nil
	}
	s := &fakeSnapshot{
		tokens: slices.Clone(c.tokens),
		from:   0,
		to:     len(c.tokens),
	}
	c.tracker.track(s)
	return s
}

// Restore without a snapshot succeeds only when target already equals the
// current offset. With a snapshot it adopts the snapshot's tokens unless
// target lies before the snapshot's end.
func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
	if snapshot == nil {
		return target == len(c.tokens) // can only no-op
	}
	s := snapshot.(*fakeSnapshot)
	if target < s.to {
		return false // can't go backward
	}
	c.tokens = slices.Clone(s.tokens)
	return true
}

// Merge closes parent (if any) and returns child unchanged.
func (c *fakeRecurrentCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
	// Child supersedes parent for cumulative state.
	if parent != nil {
		parent.Close()
	}
	return child
}

// Split never splits: the snapshot is returned whole as the second result.
func (c *fakeRecurrentCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
	return nil, snapshot // can't split cumulative state
}

// feedableCache is a cache.Cache that tests can push tokens into directly.
type feedableCache interface {
	cache.Cache
	feed(tokens []int32)
}

// testEnv encapsulates a kvCache and its fake caches for a test scenario.
type testEnv struct {
	kvc     *kvCache
	caches  []cache.Cache // typed references for assertions
	tracker *snapshotTracker
}

// newTransformerEnv creates a test environment with a single rewindable cache
// (pure transformer model).
func newTransformerEnv() *testEnv {
	tracker := &snapshotTracker{}
	caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
	return &testEnv{
		kvc:     &kvCache{caches: caches},
		caches:  caches,
		tracker: tracker,
	}
}

// newSlidingWindowEnv creates a test environment with one rewindable cache and
// one sliding window cache (Mistral-style architecture).
func newSlidingWindowEnv() *testEnv { tr := &snapshotTracker{} rc := &fakeRewindableCache{tracker: tr} sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr} caches := []cache.Cache{rc, sw} return &testEnv{ kvc: &kvCache{caches: caches}, caches: caches, tracker: tr, } } // newRecurrentEnv creates a test environment with one rewindable cache and one // non-rewindable cache (Jamba-style architecture). func newRecurrentEnv() *testEnv { tr := &snapshotTracker{} rc := &fakeRewindableCache{tracker: tr} nrc := &fakeRecurrentCache{tracker: tr} caches := []cache.Cache{rc, nrc} return &testEnv{ kvc: &kvCache{caches: caches}, caches: caches, tracker: tr, } } // assertAllTokens checks that every cache in the environment contains exactly // the expected token sequence. func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32) { t.Helper() for i, c := range e.caches { assertTokens(t, label, c, expected) // Verify all caches report the same offset. if i > 0 && c.Offset() != e.caches[0].Offset() { t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, c.Offset(), e.caches[0].Offset()) } } } // simulateRequest mirrors the production pipeline lifecycle: // begin -> prefill with snapshot(false) at branch points -> generate -> close type requestResult struct { remaining []int32 snapshotOffset int } // simulateRequest runs a request through the harness. If userSnapshotAt > 0, // a user snapshot (snapshot(true)) is created at that offset during prefill. 
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult { t.Helper() userSnapAt := 0 if len(userSnapshotAt) > 0 { userSnapAt = userSnapshotAt[0] } session := kvc.begin(nil, inputs) result := requestResult{ remaining: slices.Clone(session.remaining), snapshotOffset: session.snapshotOffset, } assertCacheOffsetAlignment(t, kvc, "after begin") baseOffset := kvc.minCacheOffset() remaining := inputs[baseOffset:] // Collect snapshot points (offset -> user flag) in ascending order. type snapPoint struct { offset int user bool } var points []snapPoint if session.snapshotOffset > 0 && session.snapshotOffset > baseOffset { points = append(points, snapPoint{session.snapshotOffset, false}) } if userSnapAt > 0 && userSnapAt > baseOffset { points = append(points, snapPoint{userSnapAt, true}) } slices.SortFunc(points, func(a, b snapPoint) int { return a.offset - b.offset }) // Prefill: feed tokens, pausing at each snapshot point. for _, sp := range points { count := sp.offset - baseOffset if count > len(remaining) { break } if count > 0 { feedAll(kvc.caches, remaining[:count]) remaining = remaining[count:] baseOffset = sp.offset } assertCacheOffsetAlignment(t, kvc, "at snapshot point") session.snapshot(sp.user) } // Feed rest of input tokens. if len(remaining) > 0 { feedAll(kvc.caches, remaining) } assertCacheOffsetAlignment(t, kvc, "after prefill") // Generate tokens. if len(generated) > 0 { session.outputs = generated feedAll(kvc.caches, generated) } assertCacheOffsetAlignment(t, kvc, "before close") session.close() return result } func feedAll(caches []cache.Cache, tokens []int32) { for _, c := range caches { if fc, ok := c.(feedableCache); ok { fc.feed(tokens) } } } // assertCacheOffsetAlignment verifies all caches report the same offset. 
func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) { t.Helper() if len(kvc.caches) < 2 { return } expected := kvc.caches[0].Offset() for i := 1; i < len(kvc.caches); i++ { if got := kvc.caches[i].Offset(); got != expected { t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected) } } } // assertTokens checks that a feedable cache contains the expected token sequence. // For sliding window caches, only the trailing maxSize tokens are checked. func assertTokens(t *testing.T, label string, c cache.Cache, expected []int32) { t.Helper() switch fc := c.(type) { case *fakeRewindableCache: if !slices.Equal(fc.tokens, expected) { t.Errorf("%s: rewindable tokens = %v, want %v", label, fc.tokens, expected) } case *fakeSlidingWindowCache: // Sliding window stores full history but only trailing maxSize are live. // Verify the full token sequence matches (the window semantics are // enforced by Snapshot/Restore, not by the token log). if !slices.Equal(fc.tokens, expected) { t.Errorf("%s: sliding window tokens = %v, want %v", label, fc.tokens, expected) } case *fakeRecurrentCache: if !slices.Equal(fc.tokens, expected) { t.Errorf("%s: non-rewindable tokens = %v, want %v", label, fc.tokens, expected) } default: t.Fatalf("%s: unknown cache type %T", label, c) } } // checkTrieInvariants walks the trie and checks structural invariants. 
func checkTrieInvariants(t *testing.T, root *trieNode) { t.Helper() walkNodes(root, func(n *trieNode) bool { if n.parent != nil { if n.startOffset() != n.parent.endOffset { t.Errorf("node [%d,%d): startOffset %d != parent endOffset %d", n.startOffset(), n.endOffset, n.startOffset(), n.parent.endOffset) } } if len(n.tokens) != n.endOffset-n.startOffset() { t.Errorf("node [%d,%d): token count %d != offset span %d", n.startOffset(), n.endOffset, len(n.tokens), n.endOffset-n.startOffset()) } for _, c := range n.children { if c.parent != n { t.Errorf("child [%d,%d) parent mismatch", c.startOffset(), c.endOffset) } } // No two siblings should start with the same token. seen := make(map[int32]bool) for _, c := range n.children { if len(c.tokens) > 0 { first := c.tokens[0] if seen[first] { t.Errorf("node [%d,%d): duplicate sibling first token %d", n.startOffset(), n.endOffset, first) } seen[first] = true } } return true }) } // checkSnapshotLeaks verifies that every tracked snapshot is either still live // in the trie (closeCount == 0) or has been closed exactly once. It reports // leaked snapshots (not in trie, never closed) and double-closes. func checkSnapshotLeaks(t *testing.T, tracker *snapshotTracker, root *trieNode) { t.Helper() if tracker == nil { return } // Collect all live snapshots still referenced by trie nodes. 
live := make(map[*fakeSnapshot]bool) walkNodes(root, func(n *trieNode) bool { for _, s := range n.snapshots { if s != nil { if fs, ok := s.(*fakeSnapshot); ok { live[fs] = true } } } return true }) for i, s := range tracker.all { if live[s] { if s.closeCount != 0 { t.Errorf("snapshot #%d [%d,%d) is still in trie but was closed %d time(s)", i, s.from, s.to, s.closeCount) } } else { if s.closeCount == 0 { t.Errorf("snapshot #%d [%d,%d) leaked: created but never closed and not in trie", i, s.from, s.to) } else if s.closeCount > 1 { t.Errorf("snapshot #%d [%d,%d) double-closed: closed %d times", i, s.from, s.to, s.closeCount) } } } } // forEachEnv runs fn as subtests for three realistic model configurations: // pure transformer, transformer + sliding window (Mistral-style), and // transformer + recurrent (Jamba-style). Leak checking runs automatically // at the end of each subtest. func forEachEnv(t *testing.T, fn func(t *testing.T, env *testEnv)) { t.Helper() run := func(t *testing.T, env *testEnv) { t.Cleanup(func() { checkSnapshotLeaks(t, env.tracker, env.kvc.root) }) fn(t, env) } t.Run("Transformer", func(t *testing.T) { run(t, newTransformerEnv()) }) t.Run("SlidingWindow", func(t *testing.T) { run(t, newSlidingWindowEnv()) }) t.Run("Recurrent", func(t *testing.T) { run(t, newRecurrentEnv()) }) } // TestBranchCreationAndReuse exercises the core multi-conversation lifecycle: // two conversations share a prefix and diverge, creating a branch point. // A third conversation extends the first. Verifies trie structure, cache // hit lengths, and that semantic caches contain the correct token sequences. func TestBranchCreationAndReuse(t *testing.T) { forEachEnv(t, func(t *testing.T, env *testEnv) { kvc := env.kvc // Request A: [1,2,3,4,5,6,7,8] + generate [20,21] — full miss. 
resA := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8}, []int32{20, 21}) if len(resA.remaining) != 8 { t.Fatalf("A: remaining = %d, want 8 (full miss)", len(resA.remaining)) } env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21}) // Verify trie was populated by close(). _, mA := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21}) if mA != 10 { t.Fatalf("A findable: expected 10 matched, got %d", mA) } // Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A. // Partial match in A's edge triggers snapshotOffset. resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31}) if resB.snapshotOffset != 5 { t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset) } // Cache was rewound to 0 (partial match truncates path to root), // so all tokens were re-evaluated. if len(resB.remaining) != 8 { t.Fatalf("B: remaining = %d, want 8", len(resB.remaining)) } env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31}) // Both A and B should be findable in the trie. _, mA2 := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21}) if mA2 < 5 { t.Fatalf("A still findable: expected >= 5 matched, got %d", mA2) } _, mB := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31}) if mB < 5 { t.Fatalf("B findable: expected >= 5 matched, got %d", mB) } // Request C: [1,2,3,4,5,6,7,8,40,41] — extends A's prefix. // Should get a cache hit for the shared prefix. resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41}, nil) if len(resC.remaining) >= 10 { t.Fatalf("C: remaining = %d, want < 10 (should get cache hit)", len(resC.remaining)) } env.assertAllTokens(t, "after C", []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41}) checkTrieInvariants(t, kvc.root) }) } // TestExactMatchSeedBehavior verifies the holdback mechanism: when the exact // same prompt is requested twice, the cache does not overclaim cached work. 
// The last token must be re-evaluated to seed generation. func TestExactMatchSeedBehavior(t *testing.T) { forEachEnv(t, func(t *testing.T, env *testEnv) { kvc := env.kvc // Request A: first time. simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}) // Request B: identical prompt. Holdback means matched=4, partial in // the 5-token edge, so path truncates to root and all tokens are // re-evaluated. snapshotOffset should be set at the holdback point. resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21}) if len(resB.remaining) != 5 { t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining)) } if resB.snapshotOffset != 4 { t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset) } env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21}) checkTrieInvariants(t, kvc.root) }) } // TestConversationResumption tests the most common pattern: user sends a message, // gets a response, then sends a follow-up. The follow-up should reuse the cached // prefix (system prompt + first turn + assistant response). func TestConversationResumption(t *testing.T) { forEachEnv(t, func(t *testing.T, env *testEnv) { kvc := env.kvc // Turn 1: system prompt + user message, assistant generates response. simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11, 12}) env.assertAllTokens(t, "turn 1", []int32{1, 2, 3, 4, 5, 10, 11, 12}) // Turn 2: full history + new user message. Should get a cache hit on // the prefix [1,2,3,4,5,10,11,12] and only need to evaluate [20,21]. resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21}, []int32{30}) if len(resB.remaining) > 5 { t.Fatalf("turn 2: remaining = %d, want <= 5 (should reuse most of history)", len(resB.remaining)) } env.assertAllTokens(t, "turn 2", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30}) // Turn 3: even longer history. 
resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41}, nil) if len(resC.remaining) > 5 { t.Fatalf("turn 3: remaining = %d, want <= 5", len(resC.remaining)) } env.assertAllTokens(t, "turn 3", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41}) checkTrieInvariants(t, kvc.root) }) } // TestEvictionPreservesActiveConversations creates multiple conversations sharing // a system prompt, triggers eviction via large snapshot sizes, and verifies the // active path and shared prefix survive while memory stays bounded. func TestEvictionPreservesActiveConversations(t *testing.T) { forEachEnv(t, func(t *testing.T, env *testEnv) { kvc := env.kvc systemPrompt := []int32{1, 2, 3, 4, 5} // Create 5 conversations with unique suffixes. for i := range 5 { suffix := []int32{int32(100 + i*10), int32(101 + i*10), int32(102 + i*10)} inputs := append(slices.Clone(systemPrompt), suffix...) simulateRequest(t, kvc, inputs, []int32{int32(200 + i)}) } // Inflate snapshot sizes to trigger eviction. walkNodes(kvc.root, func(n *trieNode) bool { if !n.hasSnapshots() { return true } snaps := make([]cache.Snapshot, len(n.snapshots)) for i, s := range n.snapshots { if s != nil { snaps[i] = &fakeSnapshot{byteSize: 2 * 1024 * 1024 * 1024} // 2 GiB per snapshot } } n.setSnapshots(snaps, &kvc.pagedOutBytes) return true }) // Run eviction. kvc.enforceEvictionPolicy() // Memory should be within limits. if kvc.pagedOutBytes > maxPagedOutBytes { t.Fatalf("pagedOutBytes = %d, want <= %d", kvc.pagedOutBytes, maxPagedOutBytes) } // Active path should be untouched. if len(kvc.activePath) < 2 { t.Fatalf("activePath should have >= 2 nodes, got %d", len(kvc.activePath)) } // System prompt prefix should still be findable (evicting a // multi-child branch point only drops snapshots, not the node). 
_, matched := findBestMatch(kvc.root, systemPrompt) if matched < len(systemPrompt) { t.Fatalf("system prompt match = %d, want %d", matched, len(systemPrompt)) } checkTrieInvariants(t, kvc.root) }) } // TestUserSnapshotPreservesRestorePoint verifies that user-created snapshots // (snapshot(true)) resist structural changes that would destroy them: // - A user node forces new tokens into a child instead of extending in-place // - The snapshot remains restorable after other branches are added func TestUserSnapshotPreservesRestorePoint(t *testing.T) { forEachEnv(t, func(t *testing.T, env *testEnv) { kvc := env.kvc // Request A: user snapshot at offset 5, then generate. simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}, 5) assertUserNodeExists(t, kvc, "after A") // Request B: extends A's prefix. The user node at offset 5 should // force tokens into a child rather than extending in-place. simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21}, nil) env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21}) assertUserNodeExists(t, kvc, "after B") // Request C: diverge from the user node. simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 30, 31}, []int32{40}) // Request D: switch back to A's branch — user snapshot still restorable. simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50}, nil) env.assertAllTokens(t, "back to A", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50}) checkTrieInvariants(t, kvc.root) }) } // TestUserSnapshotResistsAutoMerge verifies that when a sibling leaf is evicted, // a user-marked parent node is not auto-merged with its remaining single child. func TestUserSnapshotResistsAutoMerge(t *testing.T) { forEachEnv(t, func(t *testing.T, env *testEnv) { kvc := env.kvc // Request A: user snapshot at offset 3, then continue to offset 5. simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10}, 3) // Request B: diverges at the user node, creating a second child. 
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20}) userNode := findUserNode(t, kvc) if len(userNode.children) != 2 { t.Fatalf("user node children = %d, want 2", len(userNode.children)) } // Inflate snapshot sizes and evict. The non-active branch should be // evicted, leaving the user node with one child. walkNodes(kvc.root, func(n *trieNode) bool { if !n.hasSnapshots() { return true } snaps := make([]cache.Snapshot, len(n.snapshots)) for i, s := range n.snapshots { if s != nil { snaps[i] = &fakeSnapshot{byteSize: 5 * 1024 * 1024 * 1024} } } n.setSnapshots(snaps, &kvc.pagedOutBytes) return true }) kvc.enforceEvictionPolicy() // The user node should still exist (not auto-merged) even with one child. assertUserNodeExists(t, kvc, "after eviction") checkTrieInvariants(t, kvc.root) }) } func findUserNode(t *testing.T, kvc *kvCache) *trieNode { t.Helper() var found *trieNode walkNodes(kvc.root, func(n *trieNode) bool { if n.user { found = n } return true }) if found == nil { t.Fatal("no user-marked node found") } return found } func assertUserNodeExists(t *testing.T, kvc *kvCache, label string) { t.Helper() var exists bool walkNodes(kvc.root, func(n *trieNode) bool { if n.user { exists = true } return true }) if !exists { t.Fatalf("%s: no user-marked node found", label) } } // TestBranchSwitchRestoresCorrectState exercises switching back to an older // branch after working on a different one, verifying that the restored cache // state contains the correct token sequence for both rewindable and // non-rewindable caches. 
func TestBranchSwitchRestoresCorrectState(t *testing.T) { forEachEnv(t, func(t *testing.T, env *testEnv) { kvc := env.kvc // Request A: [1,2,3,4,5] + generate [10,11] simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}) env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 10, 11}) // Request B: [1,2,3,6,7] — diverges at token 4 simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{12, 13}) env.assertAllTokens(t, "after B", []int32{1, 2, 3, 6, 7, 12, 13}) // Request C: switch back to A's branch [1,2,3,4,5,10,11,20] simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20}, nil) env.assertAllTokens(t, "after C (back to A)", []int32{1, 2, 3, 4, 5, 10, 11, 20}) checkTrieInvariants(t, kvc.root) }) }