package mlxrunner

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

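// prefillChunkSize is the maximum number of prompt tokens processed per
// forward pass during prefill (2 << 10 == 2048).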
func prefillChunkSize() int {
	return 2 << 10
}

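// tokenizeRequest converts the request prompt into token IDs, along with any
// model-specific prompt state produced by multimodal tokenization.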
func (r *Runner) tokenizeRequest(request Request) ([]int32, any, error) {
	if len(request.Images) > 0 {
		// The shared trie cache keys only on token IDs today, so multimodal
		// prompts must not reuse snapshots across distinct image inputs.
		r.cache.clear()
	}

	if multimodalTokenizer, ok := r.Model.(base.MultimodalPromptTokenizerWithState); ok && len(request.Images) > 0 {
		images := make([]base.ImageInput, len(request.Images))
		for i := range request.Images {
			images[i] = base.ImageInput{
				ID:   request.Images[i].ID,
				Data: request.Images[i].Data,
			}
		}

		out, err := multimodalTokenizer.TokenizePromptWithImagesState(request.Prompt, images)
		if err != nil {
			return nil, nil, err
		}
		if out == nil {
			return nil, nil, errors.New("empty multimodal tokenization result")
		}
		return out.Tokens, out.State, nil
	}

	return r.Tokenizer.Encode(request.Prompt, true), nil, nil
}

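// TextGenerationPipeline runs a single completion request end to end:
// tokenize the prompt, prefill the KV caches in chunks, then decode tokens
// one at a time and stream them over request.Responses.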
func (r *Runner) TextGenerationPipeline(request Request) error {
	if r.Model == nil {
		return errors.New("model not loaded")
	}

	enableCompile := true
	if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
		enableCompile = modelCompile.EnableCompile()
	}
	if enableCompile {
		mlx.EnableCompile()
	} else {
		mlx.DisableCompile()
	}
	mlx.ResetPeakMemory()
	ctx := request.Ctx
	var (
		sample, logprobs         *mlx.Array
		nextSample, nextLogprobs *mlx.Array
	)

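	// Release per-request resources regardless of how the pipeline exits:
	// free the sampler, unpin any in-flight arrays, reclaim memory, and log
	// peak usage (plus live arrays and the cache tree at trace level).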
	defer func() {
		if request.Sampler != nil {
			request.Sampler.Free()
		}
		mlx.Unpin(sample, logprobs)
		mlx.Unpin(nextSample, nextLogprobs)
		mlx.Sweep()
		mlx.ClearCache()

		if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
			mlx.LogArrays()
			r.cache.dumpTree()
		}
		slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
	}()

	inputs, promptState, err := r.tokenizeRequest(request)
	if err != nil {
		return err
	}
	if len(inputs) == 0 {
		return errors.New("empty prompt")
	}

	if len(inputs) >= r.contextLength {
		return api.StatusError{
			StatusCode:   http.StatusBadRequest,
			ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
		}
	}

	// Cap generation to stay within the model's context length
	maxGenerate := r.contextLength - len(inputs)
	if request.Options.MaxTokens <= 0 {
		request.Options.MaxTokens = maxGenerate
	} else {
		request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
	}

	request.Sampler.ResetHistory(inputs)

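	// Begin a cache session for this prompt: session.caches back the forward
	// passes, while session.remaining is the portion of the prompt not
	// already covered by the shared trie cache and still needing prefill.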
	session := r.cache.begin(r.Model, inputs)
	defer session.close()
	caches := session.caches
	tokens := session.remaining
	prefillChunk := prefillChunkSize()

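	// modelForward prefers the stateful forward pass when the model consumes
	// multimodal prompt state; otherwise it falls back to the plain Forward.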
	modelForward := func(tokens *mlx.Array) *mlx.Array {
		if withState, ok := r.Model.(base.ForwardWithStateModel); ok {
			return withState.ForwardWithState(tokens, caches, promptState)
		}
		return r.Model.Forward(tokens, caches)
	}

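	// materializeCaches evaluates every cache's state arrays so each prefill
	// chunk's work is realized before the next chunk is enqueued.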
	materializeCaches := func() {
		state := make([]*mlx.Array, 0, 2*len(caches))
		for _, c := range caches {
			state = append(state, c.State()...)
		}
		if len(state) == 0 {
			return
		}
		mlx.Eval(state...)
	}

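	// Prefill: feed the prompt through the model in chunks, keeping the last
	// token back so the first decode step below produces the first sample.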
	now := time.Now()
	total, processed := len(tokens), 0
	for total-processed > 1 {
		if err := ctx.Err(); err != nil {
			return err
		}

		n := min(prefillChunk, total-processed-1)

		// If there's a pending intermediate snapshot, split the batch
		// so we can capture it at the exact offset. The cache offset
		// after this batch will be: baseOffset + processed + n.
		if session.snapshotOffset > 0 {
			baseOffset := len(session.inputs) - len(tokens)
			tokensUntilSnapshot := session.snapshotOffset - (baseOffset + processed)
			if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
				n = tokensUntilSnapshot
			}
		}

		modelForward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0))
		mlx.Sweep()
		materializeCaches()
		processed += n
		slog.Info("Prompt processing progress", "processed", processed, "total", total)

		// Create snapshot at branch point for future diverging requests.
		if session.snapshotOffset > 0 {
			baseOffset := len(session.inputs) - len(tokens)
			if baseOffset+processed >= session.snapshotOffset {
				session.snapshot(false)
			}
		}

		mlx.ClearCache()
	}

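	// step runs one decode iteration: forward the tokens, unembed the last
	// position into logits, convert them to log-probabilities, sample the
	// next token, and start evaluating both arrays asynchronously.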
	step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
		fwd := modelForward(token.ExpandDims(0))
		logits := r.Model.Unembed(fwd)
		logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

		logprobs := logits.Subtract(logits.Logsumexp(true))
		sample := request.Sampler.Sample(logprobs)

		mlx.Pin(sample, logprobs)
		mlx.Sweep()
		mlx.AsyncEval(sample, logprobs)

		return sample, logprobs
	}

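	// Run the remaining prompt suffix through the model to sample the first
	// output token.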
	sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))

	var b bytes.Buffer

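	// Decode loop: kick off the next step before streaming the current token
	// so evaluation of step i+1 overlaps with decoding and sending token i.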
	final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
	for i := range request.Options.MaxTokens {
		if err := ctx.Err(); err != nil {
			return err
		}

		request.Sampler.AppendToken(sample)
		nextSample, nextLogprobs = step(sample)

		if i == 0 {
			mlx.Eval(sample)
			final.PromptEvalDuration = time.Since(now)
			now = time.Now()
		}

		output := int32(sample.Int())
		session.outputs = append(session.outputs, output)

		if r.Tokenizer.IsEOS(output) {
			final.DoneReason = 0
			final.EvalCount = i
			break
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case request.Responses <- CompletionResponse{
			Content: r.Decode(output, &b),
		}:
		}

		mlx.Unpin(sample, logprobs)
		sample, logprobs = nextSample, nextLogprobs
		nextSample, nextLogprobs = nil, nil

		if i%256 == 0 {
			mlx.ClearCache()
		}
	}

	final.EvalDuration = time.Since(now)
	select {
	case <-ctx.Done():
		return ctx.Err()
	case request.Responses <- final:
		return nil
	}
}

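// Decode converts a sampled token ID into text, buffering partial byte
// sequences until they form valid UTF-8 before emitting them.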
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
	token := r.Tokenizer.Decode([]int32{sample})

	if _, err := b.WriteString(token); err != nil {
		slog.Error("Failed to write token to buffer", "error", err)
		return ""
	}

	return flushValidUTF8Prefix(b)
}