mirror of
https://github.com/ollama/ollama.git
synced 2026-04-25 02:06:11 +02:00
This change adds a new x/tokenizer package which includes: * New BPE and SentencePiece tokenizers * Removing the dependency on the imagegen tokenizers * Fixes to multibyte decoding in the pipeline * Various correctness and benchmark tests Not included in this PR is the WordPiece tokenizer for BERT models which will be added when we add embedding models. The imagegen tokenizers will also be removed in a follow-up PR.
130 lines
3.3 KiB
Go
//go:build mlx
|
|
|
|
package mlxrunner
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"log/slog"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
// TextGenerationPipeline runs one full generation request end to end: it
// tokenizes the prompt, reuses or builds per-layer KV caches, processes the
// prompt in chunks ("prefill"), then samples completion tokens one at a time
// until EOS or the token limit, streaming each decoded piece on
// request.Responses. It returns an error only when no model is loaded.
func (r *Runner) TextGenerationPipeline(request Request) error {
	if r.Model == nil {
		return errors.New("model not loaded")
	}

	// Models may opt out of MLX graph compilation via an optional
	// EnableCompile() method; compilation is on by default.
	enableCompile := true
	if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
		enableCompile = modelCompile.EnableCompile()
	}
	if enableCompile {
		mlx.EnableCompile()
	} else {
		mlx.DisableCompile()
	}

	inputs := r.Tokenizer.Encode(request.Prompt, true)

	// Reuse the longest matching cached prefix; tokens is the remaining
	// suffix that still needs to be fed through the model.
	caches, tokens := r.FindNearestCache(inputs)
	if len(caches) == 0 {
		// No cache hit: let the model construct its own caches if it knows
		// how, otherwise default to one standard KV cache per layer.
		if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
			caches = cacheFactory.NewCaches()
		} else {
			caches = make([]cache.Cache, r.Model.NumLayers())
			for i := range caches {
				caches[i] = cache.NewKVCache()
			}
		}
	}

	// Prefill: feed the prompt in chunks of up to 2048 tokens (2<<10),
	// always holding the final token back so the first sampling step below
	// produces the first completion token.
	total, processed := len(tokens), 0
	slog.Info("Prompt processing progress", "processed", processed, "total", total)
	for total-processed > 1 {
		n := min(2<<10, total-processed-1)
		temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
		// NOTE(review): this defer fires at function return, not per
		// iteration, so every chunk's forward output stays alive until the
		// whole request finishes — confirm this accumulation is intended
		// rather than freeing after the Eval below.
		defer mlx.Free(temp)
		// Force evaluation of every cache's state tensors (two arrays per
		// cache) so this chunk is materialized before the next one starts.
		mlx.Eval(func() []*mlx.Array {
			s := make([]*mlx.Array, 2*len(caches))
			for i, c := range caches {
				s[2*i], s[2*i+1] = c.State()
			}
			return s
		}()...)
		processed += n
		slog.Info("Prompt processing progress", "processed", processed, "total", total)
		mlx.ClearCache()
	}

	// step runs one forward pass over token(s), keeps only the logits for
	// the last sequence position, and returns the sampled next token along
	// with the full log-probability distribution it was drawn from.
	step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
		fwd := r.Model.Forward(token.ExpandDims(0), caches)
		logits := r.Model.Unembed(fwd)
		logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

		// log softmax: logits - logsumexp(logits) normalizes to logprobs.
		logprobs := logits.Subtract(logits.Logsumexp(true))
		return request.Sample(logprobs), logprobs
	}

	// Feed the held-back prompt tail and sample the first completion token.
	// Evaluation is started asynchronously so host-side work below overlaps
	// with compute.
	sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
	mlx.AsyncEval(sample, logprobs)

	// b accumulates raw token bytes so multibyte UTF-8 sequences split
	// across several tokens are only flushed once complete (see Decode).
	var b bytes.Buffer

	now := time.Now()
	// Assume the length limit is hit (DoneReason 1); overwritten with
	// DoneReason 0 and the actual count if an EOS token arrives first.
	final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
	outputs := make([]int32, 0, request.Options.MaxTokens)
	for i := range request.Options.MaxTokens {
		// Pipelining: launch the next step before consuming the current
		// sample so device compute overlaps with decoding/streaming.
		nextSample, nextLogprobs := step(sample)
		mlx.AsyncEval(nextSample, nextLogprobs)

		if i == 0 {
			// The first completion token marks the end of prompt
			// processing; record its duration and restart the clock for
			// the completion phase.
			slog.Info("Prompt processing progress", "processed", total, "total", total)
			mlx.Eval(sample)
			final.PromptTokensDuration = time.Since(now)
			now = time.Now()
		}

		output := int32(sample.Int())
		outputs = append(outputs, output)

		if r.Tokenizer.IsEOS(output) {
			// NOTE(review): on this path nextSample/nextLogprobs launched
			// above are never passed to mlx.Free (only the current
			// sample/logprobs are, after the loop) — confirm this does not
			// leak the in-flight arrays.
			final.Token = int(output)
			final.DoneReason = 0
			final.CompletionTokens = i
			break
		}

		request.Responses <- Response{
			Text:  r.Decode(output, &b),
			Token: int(output),
		}

		mlx.Free(sample, logprobs)
		// Periodically release MLX's internal buffer cache.
		if i%256 == 0 {
			mlx.ClearCache()
		}

		sample, logprobs = nextSample, nextLogprobs
	}

	mlx.Free(sample, logprobs)
	final.CompletionTokensDuration = time.Since(now)
	request.Responses <- final
	// Remember the full sequence (prompt + completion) so a follow-up
	// request sharing this prefix can skip prefill work.
	r.InsertCache(append(inputs, outputs...), caches)
	return nil
}
|
|
|
|
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
|
token := r.Tokenizer.Decode([]int32{sample})
|
|
|
|
if _, err := b.WriteString(token); err != nil {
|
|
slog.Error("Failed to write token to buffer", "error", err)
|
|
return ""
|
|
}
|
|
|
|
return flushValidUTF8Prefix(b)
|
|
}
|