Files
ollama/x/mlxrunner/pipeline.go
Patrick Devine 97323d1c68 consolidate the tokenizer (#14327)
This change adds a new x/tokenizer package which includes:
  * New BPE and SentencePiece tokenizers
  * Removing the dependency on the imagegen tokenizers
  * Fixes to multibyte decoding in the pipeline
  * Various correctness and benchmark tests

Not included in this PR is the WordPiece tokenizer for BERT models which will be
added when we add embedding models. The imagegen tokenizers will also be removed in
a follow-up PR.
2026-02-19 15:55:45 -08:00

130 lines
3.3 KiB
Go

//go:build mlx
package mlxrunner
import (
"bytes"
"errors"
"log/slog"
"time"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func (r *Runner) TextGenerationPipeline(request Request) error {
if r.Model == nil {
return errors.New("model not loaded")
}
enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
}
if enableCompile {
mlx.EnableCompile()
} else {
mlx.DisableCompile()
}
inputs := r.Tokenizer.Encode(request.Prompt, true)
caches, tokens := r.FindNearestCache(inputs)
if len(caches) == 0 {
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
caches = cacheFactory.NewCaches()
} else {
caches = make([]cache.Cache, r.Model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
}
}
total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
n := min(2<<10, total-processed-1)
temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
defer mlx.Free(temp)
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
s[2*i], s[2*i+1] = c.State()
}
return s
}()...)
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
mlx.ClearCache()
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true))
return request.Sample(logprobs), logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
mlx.AsyncEval(sample, logprobs)
var b bytes.Buffer
now := time.Now()
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)
mlx.AsyncEval(nextSample, nextLogprobs)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample)
final.PromptTokensDuration = time.Since(now)
now = time.Now()
}
output := int32(sample.Int())
outputs = append(outputs, output)
if r.Tokenizer.IsEOS(output) {
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
break
}
request.Responses <- Response{
Text: r.Decode(output, &b),
Token: int(output),
}
mlx.Free(sample, logprobs)
if i%256 == 0 {
mlx.ClearCache()
}
sample, logprobs = nextSample, nextLogprobs
}
mlx.Free(sample, logprobs)
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
r.InsertCache(append(inputs, outputs...), caches)
return nil
}
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
token := r.Tokenizer.Decode([]int32{sample})
if _, err := b.WriteString(token); err != nil {
slog.Error("Failed to write token to buffer", "error", err)
return ""
}
return flushValidUTF8Prefix(b)
}