mirror of https://github.com/ollama/ollama.git (synced 2026-04-23 17:29:54 +02:00)
When the entire prompt was already cached (e.g. a repeated prompt), findRemaining returned an empty slice, causing FromValues to panic with an index-out-of-range error on a zero-length byte slice. Fix by always keeping at least one token to re-evaluate so the pipeline can seed token generation from it. Also reject empty prompts early rather than panicking.
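As a minimal standalone sketch of the failure mode and the fix (hypothetical names, not the runner's API; the real logic lives in findRemaining in the file below):

// prefixsketch illustrates the prefix arithmetic on plain slices.
package main

import "fmt"

// remaining returns the suffix of tokens that still needs to be evaluated,
// given the tokens already held in the cache.
func remaining(cached, tokens []int32) []int32 {
	prefix := 0
	for prefix < len(tokens) && prefix < len(cached) && tokens[prefix] == cached[prefix] {
		prefix++
	}
	// On a full cache hit, keep one token to re-evaluate; without this the
	// result would be empty and downstream code would have nothing to seed
	// generation with (the source of the FromValues panic).
	if prefix == len(tokens) && prefix > 0 {
		prefix--
	}
	return tokens[prefix:]
}

func main() {
	prompt := []int32{7, 8, 9}
	fmt.Println(remaining(nil, prompt))    // cache miss: [7 8 9]
	fmt.Println(remaining(prompt, prompt)) // full hit: [9], never empty
}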
114 lines
2.8 KiB
Go
//go:build mlx

package mlxrunner

import (
	"fmt"
	"log/slog"

	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

type kvCache struct {
	// For now we only support a single entry, so this is just one sequence
	tokens []int32
	caches []cache.Cache
}

// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
	cache   *kvCache
	inputs  []int32
	outputs []int32

	caches    []cache.Cache
	remaining []int32
}

// begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
	if len(c.caches) == 0 {
		if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
			c.caches = cacheFactory.NewCaches()
		} else {
			c.caches = make([]cache.Cache, m.NumLayers())
			for i := range c.caches {
				c.caches[i] = cache.NewKVCache()
			}
		}
	}

	remaining := c.findRemaining(inputs)

	return &cacheSession{
		cache:     c,
		inputs:    inputs,
		caches:    c.caches,
		remaining: remaining,
	}
}

// close saves the token state if the forward pass ran.
func (s *cacheSession) close() {
	if offset := s.caches[0].Offset(); offset > 0 {
		// Ensure that if we have run the forward pass and set the metadata
		// that we also actually have the data
		arrays := make([]*mlx.Array, 0, 2*len(s.caches))
		for _, c := range s.caches {
			k, v := c.State()
			arrays = append(arrays, k, v)
		}
		mlx.AsyncEval(arrays...)

		s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
	}
}

// findRemaining finds the longest common prefix between tokens and the cached
// sequence, trims stale cache entries, and returns the remaining tokens.
func (c *kvCache) findRemaining(tokens []int32) []int32 {
	prefix := 0
	for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
		prefix++
	}

	// Always keep at least one token to re-evaluate so the
	// pipeline can seed token generation from it.
	if prefix == len(tokens) && prefix > 0 {
		prefix--
	}

	if prefix < len(c.tokens) {
		trim := len(c.tokens) - prefix
		for _, kv := range c.caches {
			kv.Trim(trim)
		}
		c.tokens = c.tokens[:prefix]
	}

	if prefix == 0 {
		slog.Info("Cache miss", "left", len(tokens))
	} else {
		slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
	}
	return tokens[prefix:]
}

func (c *kvCache) log() {
	if len(c.caches) == 0 {
		return
	}
	var totalBytes int
	for _, kv := range c.caches {
		k, v := kv.State()
		totalBytes += k.NumBytes() + v.NumBytes()
	}
	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
}
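For context, here is a hedged, self-contained sketch of the session contract described in the cacheSession doc comment (begin a session, append generated tokens to outputs, defer close to save the state). The fakeCache and session types below are illustrative stand-ins, not the runner's actual types or decode loop:

package main

import "fmt"

// fakeCache stands in for kvCache: it only records the cached token sequence.
type fakeCache struct {
	tokens []int32
}

// session mirrors the cacheSession contract: inputs are the prompt,
// outputs collect generated tokens, and close persists the state.
type session struct {
	cache     *fakeCache
	inputs    []int32
	outputs   []int32
	remaining []int32
}

func (c *fakeCache) begin(inputs []int32) *session {
	// Sketch assumption: nothing is cached yet, so everything remains.
	return &session{cache: c, inputs: inputs, remaining: inputs}
}

// close saves prompt + generated tokens as the new cached sequence,
// mirroring "defer close to save the cache state".
func (s *session) close() {
	s.cache.tokens = append(append([]int32{}, s.inputs...), s.outputs...)
}

func run(c *fakeCache, prompt []int32) {
	s := c.begin(prompt)
	defer s.close()

	// Stand-in for the decode loop: pretend the model emits two tokens.
	for _, tok := range []int32{42, 43} {
		s.outputs = append(s.outputs, tok)
	}
}

func main() {
	c := &fakeCache{}
	run(c, []int32{1, 2, 3})
	fmt.Println("cached after run:", c.tokens) // [1 2 3 42 43]
}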