diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index faabcc51f..c1c53b668 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -20,6 +20,7 @@ import ( "cmp" "fmt" "log/slog" + "slices" "time" "github.com/ollama/ollama/logutil" @@ -37,6 +38,12 @@ type kvCache struct { pagedOutBytes int64 // total bytes in paged-out snapshots across the trie } +// pendingSnapshot is a snapshot scheduled to be taken during prefill. +type pendingSnapshot struct { + offset int + user bool +} + // cacheSession manages caches for a single pipeline run. // Callers should append generated tokens to outputs and // defer close to save the cache state. @@ -48,11 +55,10 @@ type cacheSession struct { caches []cache.Cache remaining []int32 - // snapshotOffset, if > 0, is a trie node boundary where we need to - // capture a snapshot during prefill. This enables future requests - // branching at this point to restore non-rewindable caches (e.g. - // RecurrentCache) instead of re-evaluating from scratch. - snapshotOffset int + // pendingSnapshots lists offsets where snapshots should be captured + // during prefill, sorted by offset. Entries are consumed as the + // cache advances past them. + pendingSnapshots []pendingSnapshot } func (c *kvCache) ensureCaches(m base.Model) { @@ -100,31 +106,26 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { prefix := c.minCacheOffset() remaining := inputs[prefix:] + session := &cacheSession{ + cache: c, + inputs: inputs, + caches: c.caches, + remaining: remaining, + } + // Schedule a snapshot at the branch point during prefill so future // requests diverging here can restore instead of re-evaluating. 
- var snapshotAt int if prefix < matched { - snapshotAt = matched + session.pendingSnapshots = append(session.pendingSnapshots, pendingSnapshot{offset: matched, user: false}) } - args := []any{"total", len(inputs), "matched", originalMatched} - args = append(args, "cached", prefix, "left", len(remaining)) - if snapshotAt > 0 { - args = append(args, "pending_snapshot", snapshotAt) - } + msg := "cache hit" if prefix == 0 { - slog.Info("cache miss", args...) - } else { - slog.Info("cache hit", args...) + msg = "cache miss" } + slog.Info(msg, "total", len(inputs), "matched", originalMatched, "cached", prefix, "left", len(remaining)) - return &cacheSession{ - cache: c, - inputs: inputs, - snapshotOffset: snapshotAt, - caches: c.caches, - remaining: remaining, - } + return session } // switchToPath transitions from the current active path to a new path, @@ -250,20 +251,54 @@ pageIn: } } -// snapshot creates a snapshot at the current cache position. During prefill, -// it is called at branch points (user=false) to create restore points for -// future diverging requests and with user=true to mark an explicit reusable -// restore point. -func (s *cacheSession) snapshot(user bool) { +// requestSnapshot schedules a user snapshot at the given absolute token +// offset. The snapshot will be captured during prefill when the cache +// reaches this offset. +func (s *cacheSession) requestSnapshot(offset int) { + baseOffset := len(s.inputs) - len(s.remaining) + if offset <= baseOffset || offset > len(s.inputs) { + return + } + // Deduplicate: if this offset already exists, upgrade to user. 
+ for i := range s.pendingSnapshots { + if s.pendingSnapshots[i].offset == offset { + s.pendingSnapshots[i].user = true + return + } + } + s.pendingSnapshots = append(s.pendingSnapshots, pendingSnapshot{offset: offset, user: true}) + slices.SortFunc(s.pendingSnapshots, func(a, b pendingSnapshot) int { + return a.offset - b.offset + }) +} + +// nextPendingSnapshot returns the offset of the next pending snapshot, +// or 0 if there are none. +func (s *cacheSession) nextPendingSnapshot() int { + if len(s.pendingSnapshots) == 0 { + return 0 + } + return s.pendingSnapshots[0].offset +} + +// snapshot creates a snapshot at the current cache position. It determines +// whether this is a user snapshot by consuming pending entries whose offset +// has been reached. +func (s *cacheSession) snapshot() { c := s.cache cacheOffset := c.minCacheOffset() if cacheOffset <= 0 { return } - // Clear pending intermediate snapshot if we've reached or passed it. - if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset { - s.snapshotOffset = 0 + // Consume pending snapshots up to the current offset and derive + // the user flag from them. + user := false + for len(s.pendingSnapshots) > 0 && cacheOffset >= s.pendingSnapshots[0].offset { + if s.pendingSnapshots[0].user { + user = true + } + s.pendingSnapshots = s.pendingSnapshots[1:] } // The last node in activePath is the frontier where caches are advancing. diff --git a/x/mlxrunner/cache_test.go b/x/mlxrunner/cache_test.go index 6c691c17b..fba0d4fbd 100644 --- a/x/mlxrunner/cache_test.go +++ b/x/mlxrunner/cache_test.go @@ -377,24 +377,25 @@ func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32) - // begin -> prefill with snapshot(false) at branch points -> generate -> close + // begin -> prefill with snapshot() at pending snapshot offsets -> generate -> close type requestResult struct { - remaining []int32 - snapshotOffset int + remaining []int32 + pendingSnapshots int } // simulateRequest runs a request through the harness. 
If userSnapshotAt > 0, -// a user snapshot (snapshot(true)) is created at that offset during prefill. +// a user snapshot is requested at that offset during prefill. func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult { t.Helper() - userSnapAt := 0 - if len(userSnapshotAt) > 0 { - userSnapAt = userSnapshotAt[0] + session := kvc.begin(nil, inputs) + for _, at := range userSnapshotAt { + if at > 0 { + session.requestSnapshot(at) + } } - session := kvc.begin(nil, inputs) result := requestResult{ - remaining: slices.Clone(session.remaining), - snapshotOffset: session.snapshotOffset, + remaining: slices.Clone(session.remaining), + pendingSnapshots: len(session.pendingSnapshots), } assertCacheOffsetAlignment(t, kvc, "after begin") @@ -402,22 +403,9 @@ func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, user baseOffset := kvc.minCacheOffset() remaining := inputs[baseOffset:] - // Collect snapshot points (offset -> user flag) in ascending order. - type snapPoint struct { - offset int - user bool - } - var points []snapPoint - if session.snapshotOffset > 0 && session.snapshotOffset > baseOffset { - points = append(points, snapPoint{session.snapshotOffset, false}) - } - if userSnapAt > 0 && userSnapAt > baseOffset { - points = append(points, snapPoint{userSnapAt, true}) - } - slices.SortFunc(points, func(a, b snapPoint) int { return a.offset - b.offset }) - - // Prefill: feed tokens, pausing at each snapshot point. - for _, sp := range points { + // Prefill: feed tokens, pausing at each pending snapshot. 
+ for len(session.pendingSnapshots) > 0 { + sp := session.pendingSnapshots[0] count := sp.offset - baseOffset if count > len(remaining) { break @@ -428,7 +416,7 @@ func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, user baseOffset = sp.offset } assertCacheOffsetAlignment(t, kvc, "at snapshot point") - session.snapshot(sp.user) + session.snapshot() } // Feed rest of input tokens. @@ -615,15 +603,15 @@ func TestBranchCreationAndReuse(t *testing.T) { // caches (RecurrentCache), the rewind fails and freeAll fires. resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31}) if env.rewindable { - if resB.snapshotOffset != 0 { - t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset) + if resB.pendingSnapshots != 0 { + t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots) } if len(resB.remaining) != 3 { t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining)) } } else { - if resB.snapshotOffset != 5 { - t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset) + if resB.pendingSnapshots != 1 { + t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots) } if len(resB.remaining) != 8 { t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining)) @@ -672,15 +660,15 @@ func TestExactMatchSeedBehavior(t *testing.T) { if len(resB.remaining) != 1 { t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining)) } - if resB.snapshotOffset != 0 { - t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset) + if resB.pendingSnapshots != 0 { + t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots) } } else { if len(resB.remaining) != 5 { t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining)) } - if resB.snapshotOffset != 4 { - t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset) + if resB.pendingSnapshots 
!= 1 { + t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots) } } env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21}) diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index ea7e12a30..d98d25ccd 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -79,10 +79,24 @@ func (r *Runner) TextGenerationPipeline(request Request) error { session := r.cache.begin(r.Model, inputs) defer session.close() + caches := session.caches tokens := session.remaining prefillChunk := prefillChunkSize() + // Request periodic snapshots during prefill and near the end of the + // prompt so that long prompts can be partially restored and + // thinking/generation can be retried without full reprocessing. + const snapshotInterval = 8192 + for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval { + session.requestSnapshot(offset) + } + + const preThinking = 4 + if end := len(inputs) - preThinking; end > 0 { + session.requestSnapshot(end) + } + materializeCaches := func() { state := make([]*mlx.Array, 0, 2*len(caches)) for _, c := range caches { @@ -103,12 +117,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error { n := min(prefillChunk, total-processed-1) - // If there's a pending intermediate snapshot, split the batch - // so we can capture it at the exact offset. The cache offset - // after this batch will be: baseOffset + processed + n. - if session.snapshotOffset > 0 { + // If there's a pending snapshot, split the batch so we can + // capture it at the exact offset. 
+ if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 { baseOffset := len(session.inputs) - len(tokens) - tokensUntilSnapshot := session.snapshotOffset - (baseOffset + processed) + tokensUntilSnapshot := snapOffset - (baseOffset + processed) if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n { n = tokensUntilSnapshot } @@ -120,11 +133,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error { processed += n slog.Info("Prompt processing progress", "processed", processed, "total", total) - // Create snapshot at branch point for future diverging requests. - if session.snapshotOffset > 0 { + // Create snapshot if we've reached a pending offset. + if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 { baseOffset := len(session.inputs) - len(tokens) - if baseOffset+processed >= session.snapshotOffset { - session.snapshot(false) + if baseOffset+processed >= snapOffset { + session.snapshot() } }