diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index 750d556b4..0d858d91b 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -9,59 +9,99 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model/base" ) -// CacheEntry stores a single sequence -type CacheEntry struct { - Tokens []int32 - Caches []cache.Cache +type kvCache struct { + // For now we only support a single entry, so this is just one sequence + tokens []int32 + caches []cache.Cache } -// FindNearestCache finds the longest common prefix between tokens and the cached sequence -func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) { - if r.cache == nil { - slog.Info("Cache miss", "left", len(tokens)) - return nil, tokens +// cacheSession manages caches for a single pipeline run. +// Callers should append generated tokens to outputs and +// defer close to save the cache state. +type cacheSession struct { + cache *kvCache + inputs []int32 + outputs []int32 + + caches []cache.Cache + remaining []int32 +} + +// begin prepares caches for a new request. It finds the nearest +// matching cache or creates new caches if none match. +func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { + if len(c.caches) == 0 { + if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok { + c.caches = cacheFactory.NewCaches() + } else { + c.caches = make([]cache.Cache, m.NumLayers()) + for i := range c.caches { + c.caches[i] = cache.NewKVCache() + } + } } - // Find longest common prefix + remaining := c.findRemaining(inputs) + + return &cacheSession{ + cache: c, + inputs: inputs, + caches: c.caches, + remaining: remaining, + } +} + +// close saves the token state if the forward pass ran. +func (s *cacheSession) close() { + if offset := s.caches[0].Offset(); offset > 0 { + // Ensure that if we have run the forward pass and set the metadata + // that we also actually have the data + arrays := make([]*mlx.Array, 0, 2*len(s.caches)) + for _, c := range s.caches { + k, v := c.State() + arrays = append(arrays, k, v) + } + mlx.AsyncEval(arrays...) + + s.cache.tokens = append(s.inputs, s.outputs...)[:offset] + } +} + +// findRemaining finds the longest common prefix between tokens and the cached +// sequence, trims stale cache entries, and returns the remaining tokens. +func (c *kvCache) findRemaining(tokens []int32) []int32 { prefix := 0 - for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] { + for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] { prefix++ } - switch { - case prefix == 0: - for _, c := range r.cache.Caches { - c.Free() + if prefix < len(c.tokens) { + trim := len(c.tokens) - prefix + for _, kv := range c.caches { + kv.Trim(trim) } - r.cache = nil + c.tokens = c.tokens[:prefix] + } + + if prefix == 0 { slog.Info("Cache miss", "left", len(tokens)) - return nil, tokens - case prefix < len(r.cache.Tokens): - trim := len(r.cache.Tokens) - prefix - for _, c := range r.cache.Caches { - c.Trim(trim) - } - r.cache.Tokens = r.cache.Tokens[:prefix] + } else { + slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:])) } - - slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:])) - return r.cache.Caches, tokens[prefix:] + return tokens[prefix:] } -func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) { - r.cache = &CacheEntry{ - Tokens: tokens, - Caches: caches, +func (c *kvCache) log() { + if len(c.caches) == 0 { + return } -} - -func (c *CacheEntry) LogCache() { var totalBytes int - for _, kv := range c.Caches { + for _, kv := range c.caches { k, v := kv.State() totalBytes += k.NumBytes() + v.NumBytes() } - logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes))) + logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes))) } diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index e16a6c9a6..901b25e89 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -10,7 +10,6 @@ import ( "time" "github.com/ollama/ollama/logutil" - "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" ) @@ -19,6 +18,23 @@ func (r *Runner) TextGenerationPipeline(request Request) error { return errors.New("model not loaded") } + var ( + sample, logprobs *mlx.Array + nextSample, nextLogprobs *mlx.Array + ) + + defer func() { + mlx.Unpin(sample, logprobs) + mlx.Unpin(nextSample, nextLogprobs) + mlx.Sweep() + mlx.ClearCache() + + if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) { + mlx.LogArrays() + r.cache.log() + } + }() + enableCompile := true if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok { enableCompile = modelCompile.EnableCompile() @@ -30,18 +46,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } inputs := r.Tokenizer.Encode(request.Prompt, true) + session := r.cache.begin(r.Model, inputs) + defer session.close() - caches, tokens := r.FindNearestCache(inputs) - if len(caches) == 0 { - if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok { - caches = cacheFactory.NewCaches() - } else { - caches = make([]cache.Cache, r.Model.NumLayers()) - for i := range caches { - caches[i] = cache.NewKVCache() - } - } - } + caches := session.caches + tokens := session.remaining total, processed := len(tokens), 0 slog.Info("Prompt processing progress", "processed", processed, "total", total) @@ -76,15 +85,14 @@ func (r *Runner) TextGenerationPipeline(request Request) error { return sample, logprobs } - sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed)) + sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed)) var b bytes.Buffer now := time.Now() final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1} - outputs := make([]int32, 0, request.Options.MaxTokens) for i := range request.Options.MaxTokens { - nextSample, nextLogprobs := step(sample) + nextSample, nextLogprobs = step(sample) if i == 0 { slog.Info("Prompt processing progress", "processed", total, "total", total) @@ -94,10 +102,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } output := int32(sample.Int()) - outputs = append(outputs, output) + session.outputs = append(session.outputs, output) if r.Tokenizer.IsEOS(output) { - mlx.Unpin(nextSample, nextLogprobs) final.Token = int(output) final.DoneReason = 0 final.CompletionTokens = i @@ -110,26 +117,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } mlx.Unpin(sample, logprobs) + sample, logprobs = nextSample, nextLogprobs + nextSample, nextLogprobs = nil, nil + if i%256 == 0 { mlx.ClearCache() } - - sample, logprobs = nextSample, nextLogprobs } - mlx.Unpin(sample, logprobs) final.CompletionTokensDuration = time.Since(now) request.Responses <- final - r.InsertCache(append(inputs, outputs...), caches) - mlx.Sweep() - - if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) { - mlx.LogArrays() - if r.cache != nil { - r.cache.LogCache() - } - } - return nil } diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index effaf0847..353b98d8d 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -61,7 +61,7 @@ type Runner struct { Model base.Model Tokenizer *tokenizer.Tokenizer Requests chan Request - cache *CacheEntry + cache kvCache } func (r *Runner) Load(modelName string) error {