Particularly in error cases, it can be difficult to ensure that all pinned memory is unpinned, all MLX buffers are released, and the cache state is consistent. This change encapsulates those pieces and sets up the proper deferrals so that cleanup happens automatically on exit.
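
A sketch of the call pattern this enables (runner and decodeLoop below are illustrative stand-ins, not part of this change): the caller opens a session, defers close, and appends generated tokens as decoding proceeds.

	session := runner.cache.begin(runner.model, inputs)
	defer session.close()
	for _, token := range decodeLoop(session.remaining) {
		session.outputs = append(session.outputs, token)
	}

Because close runs on every exit path, the cached token metadata stays consistent with the cache contents even when decoding fails partway through.
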
//go:build mlx

package mlxrunner

import (
	"fmt"
	"log/slog"

	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

type kvCache struct {
	// For now we only support a single entry, so this is just one sequence
	tokens []int32
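	// caches holds the underlying KV caches, created lazily in begin
	// (typically one per model layer).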
	caches []cache.Cache
}

// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
	cache   *kvCache
	inputs  []int32
	outputs []int32

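	// caches aliases the owning kvCache's caches; remaining is the suffix
	// of inputs that is not already covered by the cached sequence.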
	caches    []cache.Cache
	remaining []int32
}

// begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
	if len(c.caches) == 0 {
		if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
			c.caches = cacheFactory.NewCaches()
		} else {
			c.caches = make([]cache.Cache, m.NumLayers())
			for i := range c.caches {
				c.caches[i] = cache.NewKVCache()
			}
		}
	}

	remaining := c.findRemaining(inputs)

	return &cacheSession{
		cache:     c,
		inputs:    inputs,
		caches:    c.caches,
		remaining: remaining,
	}
}

// close saves the token state if the forward pass ran.
func (s *cacheSession) close() {
	if offset := s.caches[0].Offset(); offset > 0 {
		// Ensure that if we have run the forward pass and set the metadata
		// that we also actually have the data
		arrays := make([]*mlx.Array, 0, 2*len(s.caches))
		for _, c := range s.caches {
			k, v := c.State()
			arrays = append(arrays, k, v)
		}
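		// AsyncEval kicks off evaluation of the lazy K/V arrays so their
		// data is actually computed, without blocking on completion.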
		mlx.AsyncEval(arrays...)
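
		// The cache offset can trail len(inputs)+len(outputs): the most
		// recently sampled token is appended to outputs before it has been
		// run through the model, so keep only what the cache has seen.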
		s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
	}
}

// findRemaining finds the longest common prefix between tokens and the cached
// sequence, trims stale cache entries, and returns the remaining tokens.
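// For example (illustrative values only): with c.tokens = [1 2 3 4] and
// tokens = [1 2 5], the common prefix has length 2, so two stale positions
// are trimmed from each cache and [5] is returned to be recomputed.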
func (c *kvCache) findRemaining(tokens []int32) []int32 {
	prefix := 0
	for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
		prefix++
	}

	if prefix < len(c.tokens) {
		trim := len(c.tokens) - prefix
		for _, kv := range c.caches {
			kv.Trim(trim)
		}
		c.tokens = c.tokens[:prefix]
	}

	if prefix == 0 {
		slog.Info("Cache miss", "left", len(tokens))
	} else {
		slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
	}
	return tokens[prefix:]
}

func (c *kvCache) log() {
	if len(c.caches) == 0 {
		return
	}
	var totalBytes int
	for _, kv := range c.caches {
		k, v := kv.State()
		totalBytes += k.NumBytes() + v.NumBytes()
	}
	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
}