diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go
index 49ddd04b6..750d556b4 100644
--- a/x/mlxrunner/cache.go
+++ b/x/mlxrunner/cache.go
@@ -3,94 +3,65 @@ package mlxrunner
 import (
+	"fmt"
 	"log/slog"
 
+	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
+	"github.com/ollama/ollama/x/mlxrunner/mlx"
 )
 
+// CacheEntry stores a single sequence
 type CacheEntry struct {
-	Caches  []cache.Cache
-	Count   int
-	Entries map[int32]*CacheEntry
+	Tokens []int32
+	Caches []cache.Cache
 }
 
-func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
-	current := &CacheEntry{Entries: s.CacheEntries}
-	index, cacheIndex := 0, -1
-	for _, token := range tokens {
-		if _, ok := current.Entries[token]; !ok {
-			break
-		}
-
-		current = current.Entries[token]
-		if len(current.Caches) > 0 {
-			cacheIndex = index
-		}
-
-		index += 1
+// FindNearestCache finds the longest common prefix between tokens and the cached sequence
+func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
+	if r.cache == nil {
+		slog.Info("Cache miss", "left", len(tokens))
+		return nil, tokens
 	}
 
-	if cacheIndex == len(tokens)-1 {
-		slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
-		return current.Caches, []int32{}
-	} else if cacheIndex > 1 {
-		slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
-		return current.Caches, tokens[cacheIndex+1:]
-	} else if index > 0 && cacheIndex < 0 {
-		type stackItem struct {
-			entry  *CacheEntry
-			tokens []int32
-		}
-
-		var best, item stackItem
-		stack := []stackItem{{entry: current, tokens: []int32{}}}
-		for len(stack) > 0 {
-			item, stack = stack[len(stack)-1], stack[:len(stack)-1]
-			if len(item.entry.Caches) > 0 {
-				if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
-					best = item
-				}
-			} else {
-				for token, entry := range item.entry.Entries {
-					stack = append(stack, stackItem{
-						entry:  entry,
-						tokens: append(item.tokens, token),
-					})
-				}
-			}
-		}
-
-		prefix := min(len(tokens)-1, index)
-		caches := make([]cache.Cache, len(best.entry.Caches))
-		trim := len(best.tokens)+1
-		for i := range caches {
-			caches[i] = best.entry.Caches[i].Clone()
-			caches[i].Trim(trim)
-		}
-
-		slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
-		return caches, tokens[prefix:]
+	// Find longest common prefix
+	prefix := 0
+	for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
+		prefix++
 	}
 
-	slog.Info("Cache miss", "left", len(tokens))
-	return nil, tokens
+	switch {
+	case prefix == 0:
+		for _, c := range r.cache.Caches {
+			c.Free()
+		}
+		r.cache = nil
+		slog.Info("Cache miss", "left", len(tokens))
+		return nil, tokens
+	case prefix < len(r.cache.Tokens):
+		trim := len(r.cache.Tokens) - prefix
+		for _, c := range r.cache.Caches {
+			c.Trim(trim)
+		}
+		r.cache.Tokens = r.cache.Tokens[:prefix]
+	}
+
+	slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
+	return r.cache.Caches, tokens[prefix:]
 }
 
-func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
-	current := &CacheEntry{Entries: s.CacheEntries}
-	for _, token := range tokens {
-		if _, ok := current.Entries[token]; !ok {
-			current.Entries[token] = &CacheEntry{
-				Entries: make(map[int32]*CacheEntry),
-			}
-		}
-
-		current = current.Entries[token]
-	}
-
-	if len(current.Caches) > 0 {
-		current.Count += 1
-	} else {
-		current.Caches = caches
+func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
+	r.cache = &CacheEntry{
+		Tokens: tokens,
+		Caches: caches,
 	}
 }
+
+func (c *CacheEntry) LogCache() {
+	var totalBytes int
+	for _, kv := range c.Caches {
+		k, v := kv.State()
+		totalBytes += k.NumBytes() + v.NumBytes()
+	}
+	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
+}
diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go
index 3196b9e2a..274bdffe1 100644
--- a/x/mlxrunner/cache/cache.go
+++ b/x/mlxrunner/cache/cache.go
@@ -13,6 +13,7 @@ type Cache interface {
 	State() (keys, values *mlx.Array)
 	Trim(int) int
 	Clone() Cache
+	Free()
 	Offset() int
 	Len() int
 }
@@ -84,6 +85,11 @@ func (c *KVCache) Clone() Cache {
 	return clone
 }
 
+func (c *KVCache) Free() {
+	mlx.Unpin(c.keys, c.values)
+	c.keys, c.values = nil, nil
+}
+
 func (c *KVCache) Offset() int { return c.offset }
 
 func (c *KVCache) Len() int { return c.offset }
diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go
index 618d7ec9e..e16a6c9a6 100644
--- a/x/mlxrunner/pipeline.go
+++ b/x/mlxrunner/pipeline.go
@@ -125,6 +125,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 
 	if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
 		mlx.LogArrays()
+		if r.cache != nil {
+			r.cache.LogCache()
+		}
 	}
 
 	return nil
diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go
index 0b24fdb3d..effaf0847 100644
--- a/x/mlxrunner/runner.go
+++ b/x/mlxrunner/runner.go
@@ -58,10 +58,10 @@ type Response struct {
 }
 
 type Runner struct {
-	Model        base.Model
-	Tokenizer    *tokenizer.Tokenizer
-	Requests     chan Request
-	CacheEntries map[int32]*CacheEntry
+	Model     base.Model
+	Tokenizer *tokenizer.Tokenizer
+	Requests  chan Request
+	cache     *CacheEntry
 }
 
 func (r *Runner) Load(modelName string) error {
diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go
index ef1e0dd1c..09b71f3c8 100644
--- a/x/mlxrunner/server.go
+++ b/x/mlxrunner/server.go
@@ -40,8 +40,7 @@ func Execute(args []string) error {
 	flagSet.Parse(args)
 
 	runner := Runner{
-		Requests:     make(chan Request),
-		CacheEntries: make(map[int32]*CacheEntry),
+		Requests: make(chan Request),
 	}
 
 	if err := runner.Load(modelName); err != nil {