mirror of
https://github.com/ollama/ollama.git
synced 2026-04-25 18:25:42 +02:00
mlxrunner: Simplify KV cache to single-entry prefix matching
The KV cache previously used a tree structure which could store multiple divergent sequences, which is good for cache reuse. However, this is typically used in conjunction with paged attention so each node in the tree can store just a chunk of the KV cache and they can be stitched together later. We don't currently do this, so the cache was storing copies of the full cache for each past sequence. This redundancy plus the lack of resource limits, caused significant memory use as a conversation grew. Instead, this changes to store a single entry for the cache, which can be prefix matched. Although it is less ideal for multiple users, it largely matches Ollama's current behavior. It can be improved as additional pieces are fleshed out.
This commit is contained in:
@@ -3,94 +3,65 @@
|
|||||||
package mlxrunner
|
package mlxrunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// CacheEntry stores a single sequence
|
||||||
type CacheEntry struct {
|
type CacheEntry struct {
|
||||||
Caches []cache.Cache
|
Tokens []int32
|
||||||
Count int
|
Caches []cache.Cache
|
||||||
Entries map[int32]*CacheEntry
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
|
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
|
||||||
current := &CacheEntry{Entries: s.CacheEntries}
|
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
|
||||||
index, cacheIndex := 0, -1
|
if r.cache == nil {
|
||||||
for _, token := range tokens {
|
slog.Info("Cache miss", "left", len(tokens))
|
||||||
if _, ok := current.Entries[token]; !ok {
|
return nil, tokens
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
current = current.Entries[token]
|
|
||||||
if len(current.Caches) > 0 {
|
|
||||||
cacheIndex = index
|
|
||||||
}
|
|
||||||
|
|
||||||
index += 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cacheIndex == len(tokens)-1 {
|
// Find longest common prefix
|
||||||
slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
|
prefix := 0
|
||||||
return current.Caches, []int32{}
|
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
|
||||||
} else if cacheIndex > 1 {
|
prefix++
|
||||||
slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
|
|
||||||
return current.Caches, tokens[cacheIndex+1:]
|
|
||||||
} else if index > 0 && cacheIndex < 0 {
|
|
||||||
type stackItem struct {
|
|
||||||
entry *CacheEntry
|
|
||||||
tokens []int32
|
|
||||||
}
|
|
||||||
|
|
||||||
var best, item stackItem
|
|
||||||
stack := []stackItem{{entry: current, tokens: []int32{}}}
|
|
||||||
for len(stack) > 0 {
|
|
||||||
item, stack = stack[len(stack)-1], stack[:len(stack)-1]
|
|
||||||
if len(item.entry.Caches) > 0 {
|
|
||||||
if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
|
|
||||||
best = item
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for token, entry := range item.entry.Entries {
|
|
||||||
stack = append(stack, stackItem{
|
|
||||||
entry: entry,
|
|
||||||
tokens: append(item.tokens, token),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := min(len(tokens)-1, index)
|
|
||||||
caches := make([]cache.Cache, len(best.entry.Caches))
|
|
||||||
trim := len(best.tokens)+1
|
|
||||||
for i := range caches {
|
|
||||||
caches[i] = best.entry.Caches[i].Clone()
|
|
||||||
caches[i].Trim(trim)
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
|
|
||||||
return caches, tokens[prefix:]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Cache miss", "left", len(tokens))
|
switch {
|
||||||
return nil, tokens
|
case prefix == 0:
|
||||||
|
for _, c := range r.cache.Caches {
|
||||||
|
c.Free()
|
||||||
|
}
|
||||||
|
r.cache = nil
|
||||||
|
slog.Info("Cache miss", "left", len(tokens))
|
||||||
|
return nil, tokens
|
||||||
|
case prefix < len(r.cache.Tokens):
|
||||||
|
trim := len(r.cache.Tokens) - prefix
|
||||||
|
for _, c := range r.cache.Caches {
|
||||||
|
c.Trim(trim)
|
||||||
|
}
|
||||||
|
r.cache.Tokens = r.cache.Tokens[:prefix]
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
|
||||||
|
return r.cache.Caches, tokens[prefix:]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
||||||
current := &CacheEntry{Entries: s.CacheEntries}
|
r.cache = &CacheEntry{
|
||||||
for _, token := range tokens {
|
Tokens: tokens,
|
||||||
if _, ok := current.Entries[token]; !ok {
|
Caches: caches,
|
||||||
current.Entries[token] = &CacheEntry{
|
|
||||||
Entries: make(map[int32]*CacheEntry),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
current = current.Entries[token]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(current.Caches) > 0 {
|
|
||||||
current.Count += 1
|
|
||||||
} else {
|
|
||||||
current.Caches = caches
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *CacheEntry) LogCache() {
|
||||||
|
var totalBytes int
|
||||||
|
for _, kv := range c.Caches {
|
||||||
|
k, v := kv.State()
|
||||||
|
totalBytes += k.NumBytes() + v.NumBytes()
|
||||||
|
}
|
||||||
|
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
||||||
|
}
|
||||||
|
|||||||
6
x/mlxrunner/cache/cache.go
vendored
6
x/mlxrunner/cache/cache.go
vendored
@@ -13,6 +13,7 @@ type Cache interface {
|
|||||||
State() (keys, values *mlx.Array)
|
State() (keys, values *mlx.Array)
|
||||||
Trim(int) int
|
Trim(int) int
|
||||||
Clone() Cache
|
Clone() Cache
|
||||||
|
Free()
|
||||||
Offset() int
|
Offset() int
|
||||||
Len() int
|
Len() int
|
||||||
}
|
}
|
||||||
@@ -84,6 +85,11 @@ func (c *KVCache) Clone() Cache {
|
|||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *KVCache) Free() {
|
||||||
|
mlx.Unpin(c.keys, c.values)
|
||||||
|
c.keys, c.values = nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *KVCache) Offset() int { return c.offset }
|
func (c *KVCache) Offset() int { return c.offset }
|
||||||
func (c *KVCache) Len() int { return c.offset }
|
func (c *KVCache) Len() int { return c.offset }
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
|
|
||||||
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
||||||
mlx.LogArrays()
|
mlx.LogArrays()
|
||||||
|
if r.cache != nil {
|
||||||
|
r.cache.LogCache()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -58,10 +58,10 @@ type Response struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Runner struct {
|
type Runner struct {
|
||||||
Model base.Model
|
Model base.Model
|
||||||
Tokenizer *tokenizer.Tokenizer
|
Tokenizer *tokenizer.Tokenizer
|
||||||
Requests chan Request
|
Requests chan Request
|
||||||
CacheEntries map[int32]*CacheEntry
|
cache *CacheEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Runner) Load(modelName string) error {
|
func (r *Runner) Load(modelName string) error {
|
||||||
|
|||||||
@@ -40,8 +40,7 @@ func Execute(args []string) error {
|
|||||||
flagSet.Parse(args)
|
flagSet.Parse(args)
|
||||||
|
|
||||||
runner := Runner{
|
runner := Runner{
|
||||||
Requests: make(chan Request),
|
Requests: make(chan Request),
|
||||||
CacheEntries: make(map[int32]*CacheEntry),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := runner.Load(modelName); err != nil {
|
if err := runner.Load(modelName); err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user