From 4e57d2094e127fcb32ed40d68289da7e41a83264 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 24 Feb 2026 14:19:12 -0800 Subject: [PATCH] mlxrunner: Simplify pipeline memory and cache management Particularly in error cases, it can be difficult to ensure that all pinned memory is unpinned, MLX buffers are released and cache state is consistent. This encapsulates those pieces and sets up proper deferrals so that this happens automatically on exit. --- x/mlxrunner/cache.go | 110 +++++++++++++++++++++++++++------------- x/mlxrunner/pipeline.go | 57 ++++++++++----------- x/mlxrunner/runner.go | 2 +- 3 files changed, 103 insertions(+), 66 deletions(-) diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index 750d556b4..0d858d91b 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -9,59 +9,99 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model/base" ) -// CacheEntry stores a single sequence -type CacheEntry struct { - Tokens []int32 - Caches []cache.Cache +type kvCache struct { + // For now we only support a single entry, so this is just one sequence + tokens []int32 + caches []cache.Cache } -// FindNearestCache finds the longest common prefix between tokens and the cached sequence -func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) { - if r.cache == nil { - slog.Info("Cache miss", "left", len(tokens)) - return nil, tokens +// cacheSession manages caches for a single pipeline run. +// Callers should append generated tokens to outputs and +// defer close to save the cache state. +type cacheSession struct { + cache *kvCache + inputs []int32 + outputs []int32 + + caches []cache.Cache + remaining []int32 +} + +// begin prepares caches for a new request. It finds the nearest +// matching cache or creates new caches if none match. +func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { + if len(c.caches) == 0 { + if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok { + c.caches = cacheFactory.NewCaches() + } else { + c.caches = make([]cache.Cache, m.NumLayers()) + for i := range c.caches { + c.caches[i] = cache.NewKVCache() + } + } } - // Find longest common prefix + remaining := c.findRemaining(inputs) + + return &cacheSession{ + cache: c, + inputs: inputs, + caches: c.caches, + remaining: remaining, + } +} + +// close saves the token state if the forward pass ran. +func (s *cacheSession) close() { + if offset := s.caches[0].Offset(); offset > 0 { + // Ensure that if we have run the forward pass and set the metadata + // that we also actually have the data + arrays := make([]*mlx.Array, 0, 2*len(s.caches)) + for _, c := range s.caches { + k, v := c.State() + arrays = append(arrays, k, v) + } + mlx.AsyncEval(arrays...) + + s.cache.tokens = append(s.inputs, s.outputs...)[:offset] + } +} + +// findRemaining finds the longest common prefix between tokens and the cached +// sequence, trims stale cache entries, and returns the remaining tokens. +func (c *kvCache) findRemaining(tokens []int32) []int32 { prefix := 0 - for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] { + for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] { prefix++ } - switch { - case prefix == 0: - for _, c := range r.cache.Caches { - c.Free() + if prefix < len(c.tokens) { + trim := len(c.tokens) - prefix + for _, kv := range c.caches { + kv.Trim(trim) } - r.cache = nil + c.tokens = c.tokens[:prefix] + } + + if prefix == 0 { slog.Info("Cache miss", "left", len(tokens)) - return nil, tokens - case prefix < len(r.cache.Tokens): - trim := len(r.cache.Tokens) - prefix - for _, c := range r.cache.Caches { - c.Trim(trim) - } - r.cache.Tokens = r.cache.Tokens[:prefix] + } else { + slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:])) } - - slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:])) - return r.cache.Caches, tokens[prefix:] + return tokens[prefix:] } -func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) { - r.cache = &CacheEntry{ - Tokens: tokens, - Caches: caches, +func (c *kvCache) log() { + if len(c.caches) == 0 { + return } -} - -func (c *CacheEntry) LogCache() { var totalBytes int - for _, kv := range c.Caches { + for _, kv := range c.caches { k, v := kv.State() totalBytes += k.NumBytes() + v.NumBytes() } - logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes))) + logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes))) } diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index e16a6c9a6..901b25e89 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -10,7 +10,6 @@ import ( "time" "github.com/ollama/ollama/logutil" - "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" ) @@ -19,6 +18,23 @@ func (r *Runner) TextGenerationPipeline(request Request) error { return errors.New("model not loaded") } + var ( + sample, logprobs *mlx.Array + nextSample, nextLogprobs *mlx.Array + ) + + defer func() { + mlx.Unpin(sample, logprobs) + mlx.Unpin(nextSample, nextLogprobs) + mlx.Sweep() + mlx.ClearCache() + + if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) { + mlx.LogArrays() + r.cache.log() + } + }() + enableCompile := true if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok { enableCompile = modelCompile.EnableCompile() @@ -30,18 +46,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } inputs := r.Tokenizer.Encode(request.Prompt, true) + session := r.cache.begin(r.Model, inputs) + defer session.close() - caches, tokens := r.FindNearestCache(inputs) - if len(caches) == 0 { - if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok { - caches = cacheFactory.NewCaches() - } else { - caches = make([]cache.Cache, r.Model.NumLayers()) - for i := range caches { - caches[i] = cache.NewKVCache() - } - } - } + caches := session.caches + tokens := session.remaining total, processed := len(tokens), 0 slog.Info("Prompt processing progress", "processed", processed, "total", total) @@ -76,15 +85,14 @@ func (r *Runner) TextGenerationPipeline(request Request) error { return sample, logprobs } - sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed)) + sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed)) var b bytes.Buffer now := time.Now() final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1} - outputs := make([]int32, 0, request.Options.MaxTokens) for i := range request.Options.MaxTokens { - nextSample, nextLogprobs := step(sample) + nextSample, nextLogprobs = step(sample) if i == 0 { slog.Info("Prompt processing progress", "processed", total, "total", total) @@ -94,10 +102,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } output := int32(sample.Int()) - outputs = append(outputs, output) + session.outputs = append(session.outputs, output) if r.Tokenizer.IsEOS(output) { - mlx.Unpin(nextSample, nextLogprobs) final.Token = int(output) final.DoneReason = 0 final.CompletionTokens = i @@ -110,26 +117,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } mlx.Unpin(sample, logprobs) + sample, logprobs = nextSample, nextLogprobs + nextSample, nextLogprobs = nil, nil + if i%256 == 0 { mlx.ClearCache() } - - sample, logprobs = nextSample, nextLogprobs } - mlx.Unpin(sample, logprobs) final.CompletionTokensDuration = time.Since(now) request.Responses <- final - r.InsertCache(append(inputs, outputs...), caches) - mlx.Sweep() - - if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) { - mlx.LogArrays() - if r.cache != nil { - r.cache.LogCache() - } - } - return nil } diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index effaf0847..353b98d8d 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -61,7 +61,7 @@ type Runner struct { Model base.Model Tokenizer *tokenizer.Tokenizer Requests chan Request - cache *CacheEntry + cache kvCache } func (r *Runner) Load(modelName string) error {