Mirror of https://github.com/ollama/ollama.git, synced 2026-04-23 09:15:44 +02:00
This change adds support for qwen3.5-next-moe models (qwen3-next/qwen3.5-next/qwen3-coder) to the MLX runner. It also:

* introduces recurrent cache support and related MLX ops
* updates pipeline/runner integration and adds tests
* properly quantizes stacked expert tensors
* adds a Gated Delta Metal kernel for fast SSM inference
* adds new MLX calls for Conv1d, DepthwiseConv1d, Contiguous, Exp, Log, SoftmaxAxis
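As context for the cache code below, here is a minimal sketch of how a recurrent (SSM-style) cache entry could satisfy the cache.Cache surface this file relies on (State, Offset, CanTrim, Trim, Free). The recurrentCache type and its fields are illustrative assumptions, not the actual implementation in x/mlxrunner/cache:

// recurrentCache is a hypothetical sketch: it keeps one rolling state tensor
// per layer instead of growing key/value tensors.
type recurrentCache struct {
	state  *mlx.Array // assumed field: fixed-size recurrent state
	offset int        // number of tokens folded into the state
}

func (c *recurrentCache) State() (*mlx.Array, *mlx.Array) { return c.state, nil }
func (c *recurrentCache) Offset() int                     { return c.offset }

// A recurrent state mixes every past token into a single tensor, so it cannot
// be rolled back token-by-token. Reporting CanTrim as false makes the runner
// free the cache and re-prefill on divergence instead of trimming.
func (c *recurrentCache) CanTrim() bool { return false }
func (c *recurrentCache) Trim(n int)    {} // nothing to trim

func (c *recurrentCache) Free() {
	c.state = nil // the real implementation would also release the MLX buffer
	c.offset = 0
}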
204 lines
4.3 KiB
Go
//go:build mlx

package mlxrunner

import (
	"fmt"
	"log/slog"

	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

type kvCache struct {
	// For now we only support a single entry, so this is just one sequence
	tokens []int32
	caches []cache.Cache
}

// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
	cache   *kvCache
	inputs  []int32
	outputs []int32

	caches    []cache.Cache
	remaining []int32
}
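
// Usage sketch of a session (hypothetical caller; decode is illustrative and
// not a function in this package):
//
//	session := kv.begin(model, inputs)
//	defer session.close()
//	for _, tok := range decode(session.remaining) {
//		session.outputs = append(session.outputs, tok)
//	}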

func appendCacheState(dst []*mlx.Array, c cache.Cache) []*mlx.Array {
	if c == nil {
		return dst
	}

	keys, values := c.State()
	if keys != nil && keys.Valid() {
		dst = append(dst, keys)
	}
	if values != nil && values.Valid() {
		dst = append(dst, values)
	}

	return dst
}

func (c *kvCache) free() {
	for i, kv := range c.caches {
		if kv == nil {
			continue
		}
		kv.Free()
		c.caches[i] = nil
	}
	c.caches = nil
	c.tokens = nil
}

func (c *kvCache) cachesCanTrim() bool {
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		if !kv.CanTrim() {
			return false
		}
	}
	return true
}

func (c *kvCache) trimToPrefix(prefix int) {
	for _, kv := range c.caches {
		if kv == nil || !kv.CanTrim() {
			continue
		}
		if trim := kv.Offset() - prefix; trim > 0 {
			kv.Trim(trim)
		}
	}
	if prefix < len(c.tokens) {
		c.tokens = c.tokens[:prefix]
	}
}
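
// For example, with a cache offset of 10 and a matched prefix of 7, each
// trimmable cache drops trim = 10 - 7 = 3 tokens and c.tokens is cut back
// to its first 7 entries.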

// begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
	ensureCaches := func() {
		if len(c.caches) != 0 {
			return
		}
		if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
			c.caches = cacheFactory.NewCaches()
			return
		}
		c.caches = make([]cache.Cache, m.NumLayers())
		for i := range c.caches {
			c.caches[i] = cache.NewKVCache()
		}
	}
	ensureCaches()

	remaining := c.findRemaining(inputs)
	// findRemaining frees non-trimmable caches on divergence, so the caches
	// may need to be rebuilt before the session starts.
	ensureCaches()

	return &cacheSession{
		cache:     c,
		inputs:    inputs,
		caches:    c.caches,
		remaining: remaining,
	}
}

// close saves the token state if the forward pass ran.
func (s *cacheSession) close() {
	if len(s.caches) == 0 {
		return
	}

	offset := -1
	arrays := make([]*mlx.Array, 0, 2*len(s.caches))
	for _, kv := range s.caches {
		if kv == nil {
			continue
		}
		// Mixed cache types (e.g. recurrent + KV) can transiently report different
		// offsets, so use the minimum as the safe reusable token prefix.
		if off := kv.Offset(); offset < 0 || off < offset {
			offset = off
		}
		arrays = appendCacheState(arrays, kv)
	}
	if offset <= 0 {
		return
	}

	// Ensure that if we have run the forward pass and set the metadata
	// that we also actually have the data.
	mlx.AsyncEval(arrays...)

	stored := append(s.inputs, s.outputs...)
	if offset > len(stored) {
		offset = len(stored)
	}
	s.cache.tokens = stored[:offset]
}

// findRemaining finds the longest common prefix between tokens and the cached
// sequence, trims stale cache entries, and returns the remaining tokens.
func (c *kvCache) findRemaining(tokens []int32) []int32 {
	prefix := 0
	for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
		prefix++
	}

	// Always keep at least one token to re-evaluate so the
	// pipeline can seed token generation from it.
	if prefix == len(tokens) && prefix > 0 {
		prefix--
	}

	if prefix < len(c.tokens) {
		if c.cachesCanTrim() {
			c.trimToPrefix(prefix)
		} else {
			c.free()
			slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
			return tokens
		}
	}

	if prefix == 0 {
		slog.Info("Cache miss", "left", len(tokens))
	} else {
		slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
	}
	return tokens[prefix:]
}
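
// Example: with c.tokens = [1 2 3 4] and tokens = [1 2 3 4 5], the prefix is
// 4, so only [5] is left to evaluate. If tokens = [1 2 3 4] exactly, the
// prefix backs off to 3 so one token remains to seed generation. If tokens =
// [1 2 9], the prefix is 2: trimmable caches roll back to offset 2, while a
// non-trimmable cache is freed and the full sequence is re-run.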

func (c *kvCache) log() {
	if len(c.caches) == 0 {
		return
	}
	offset := -1
	var totalBytes int
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		if off := kv.Offset(); offset < 0 || off < offset {
			offset = off
		}
		for _, a := range appendCacheState(nil, kv) {
			totalBytes += a.NumBytes()
		}
	}
	if offset < 0 {
		return
	}
	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
}