ollama/x/mlxrunner/pipeline.go
Jesse Gross 987f74c8a5 mlxrunner: introduce ForwardBatch for model forward pass
Replace the raw *mlx.Array token input with a ForwardBatch struct
that carries InputIDs alongside sequence metadata (SeqIDs, SeqLens).
InputIDs remain [1, N] shaped — all model code is unchanged beyond
the signature.
2026-04-03 19:50:41 -07:00
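
For reference, a sketch of the batch.ForwardBatch fields as used by this
file (types inferred from the call sites below; the real definition lives
in x/mlxrunner/batch):

	// Sketch only, reconstructed from usage in this file.
	type ForwardBatch struct {
		InputIDs *mlx.Array // token IDs, shaped [1, N]
		SeqIDs   []int      // sequence ID for each sequence in the batch
		SeqLens  []int      // number of tokens per sequence
	}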

package mlxrunner

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"log/slog"
	"time"

	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/x/mlxrunner/batch"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// prefillChunkSize returns the number of prompt tokens processed per
// forward pass during prefill: 2 << 10 = 2048.
func prefillChunkSize() int {
	return 2 << 10
}
// Prepare tokenizes the prompt and validates it against the model's
// context length. It is safe to call from any goroutine. On success it
// populates request.Tokens and adjusts request.Options.MaxTokens.
func (r *Runner) Prepare(request *Request) error {
	if r.Model == nil {
		return errors.New("model not loaded")
	}

	tokens := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
	if len(tokens) == 0 {
		return errors.New("empty prompt")
	}
	if len(tokens) >= r.contextLength {
		return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(tokens), r.contextLength)
	}

	// Cap generation to stay within the model's context length.
	maxGenerate := r.contextLength - len(tokens)
	if request.Options.MaxTokens <= 0 {
		request.Options.MaxTokens = maxGenerate
	} else {
		request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
	}

	request.Tokens = tokens
	return nil
}
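
// For illustration, a hypothetical caller wires the two entry points
// together like this (only the Request fields this file touches are
// shown):
//
//	req := Request{Prompt: prompt, Ctx: ctx, Sampler: sampler, Responses: responses}
//	if err := runner.Prepare(&req); err != nil {
//		return err
//	}
//	return runner.TextGenerationPipeline(req)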
// TextGenerationPipeline runs prefill over the prompt and then decodes up
// to request.Options.MaxTokens tokens, streaming each one on
// request.Responses. It owns request.Sampler and frees it on return.
func (r *Runner) TextGenerationPipeline(request Request) error {
	// Let the model opt out of kernel compilation; the default is on.
	enableCompile := true
	if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
		enableCompile = modelCompile.EnableCompile()
	}
	if enableCompile {
		mlx.EnableCompile()
	} else {
		mlx.DisableCompile()
	}
	mlx.ResetPeakMemory()

	ctx := request.Ctx

	var (
		sample, logprobs         *mlx.Array
		nextSample, nextLogprobs *mlx.Array
	)
	defer func() {
		if request.Sampler != nil {
			request.Sampler.Free()
		}
		mlx.Unpin(sample, logprobs)
		mlx.Unpin(nextSample, nextLogprobs)
		mlx.Sweep()
		mlx.ClearCache()
		if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
			mlx.LogArrays()
			r.cache.dumpTree()
		}
		slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
	}()

	inputs := request.Tokens
	request.Sampler.ResetHistory(inputs)

	// begin reuses any cached prefix for these inputs; session.remaining
	// holds the suffix of the prompt that still needs to be processed.
	session := r.cache.begin(r.Model, inputs)
	defer session.close()
	caches := session.caches
	tokens := session.remaining

	prefillChunk := prefillChunkSize()

	// Request periodic snapshots during prefill and near the end of the
	// prompt so that long prompts can be partially restored and
	// thinking/generation can be retried without full reprocessing.
	const snapshotInterval = 8192
	for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
		session.requestSnapshot(offset)
	}
	const preThinking = 4
	if end := len(inputs) - preThinking; end > 0 {
		session.requestSnapshot(end)
	}
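
	// For example, a 20000-token prompt gets snapshots requested at
	// offsets 8192 and 16384, plus one at 19996 (len(inputs) -
	// preThinking), so a retry of thinking/generation only reprocesses
	// the last few prompt tokens.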
	// materializeCaches forces evaluation of the KV cache state produced
	// by the preceding forward pass (the capacity hint assumes two arrays
	// per cache) so it isn't recomputed lazily later.
	materializeCaches := func() {
		state := make([]*mlx.Array, 0, 2*len(caches))
		for _, c := range caches {
			state = append(state, c.State()...)
		}
		if len(state) == 0 {
			return
		}
		mlx.Eval(state...)
	}
	now := time.Now()
	total, processed := len(tokens), 0
	for total-processed > 1 {
		if err := ctx.Err(); err != nil {
			return err
		}
		// Hold back the last prompt token: it is fed to the first decode
		// step below, which produces the first sampled token.
		n := min(prefillChunk, total-processed-1)

		// If there's a pending snapshot, split the batch so we can
		// capture it at the exact offset.
		if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
			baseOffset := len(session.inputs) - len(tokens)
			tokensUntilSnapshot := snapOffset - (baseOffset + processed)
			if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
				n = tokensUntilSnapshot
			}
		}

		r.Model.Forward(&batch.ForwardBatch{
			InputIDs: mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0),
			SeqIDs:   []int{0},
			SeqLens:  []int{n},
		}, caches)
		mlx.Sweep()
		materializeCaches()
		processed += n
		slog.Info("Prompt processing progress", "processed", processed, "total", total)

		// Create a snapshot if we've reached a pending offset.
		if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
			baseOffset := len(session.inputs) - len(tokens)
			if baseOffset+processed >= snapOffset {
				session.snapshot()
			}
		}

		mlx.ClearCache()
	}
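
	// Worked example of the chunking above: with prefillChunk = 2048, no
	// cached prefix, a 5000-token prompt, and a pending snapshot at
	// offset 3000, the loop processes chunks of 2048, then 952 (truncated
	// so the snapshot lands exactly at 3000), then 1999, leaving one
	// final token for the first decode step.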
	// step runs a single decode forward pass for token and returns the
	// next sample and its log-probabilities. Evaluation is started
	// asynchronously so the device works ahead of the Go loop.
	step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
		fwd := r.Model.Forward(&batch.ForwardBatch{
			InputIDs: token.ExpandDims(0),
			SeqIDs:   []int{0},
			SeqLens:  []int{1},
		}, caches)
		logits := r.Model.Unembed(fwd)
		// Keep only the last position: [1, T, vocab] -> [1, vocab].
		logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
		// Log-softmax over the vocabulary.
		logprobs := logits.Subtract(logits.Logsumexp(true))
		sample := request.Sampler.Sample(logprobs)
		mlx.Pin(sample, logprobs)
		mlx.Sweep()
		mlx.AsyncEval(sample, logprobs)
		return sample, logprobs
	}
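
	// The log-softmax in step uses the identity
	//
	//	log softmax(x)_i = x_i - log(sum_j exp(x_j))
	//
	// with Logsumexp keeping its reduced dimension (assuming its boolean
	// argument is keepdims) so the subtraction broadcasts over the vocab.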
	// The first decode step consumes the held-back tail of the prompt.
	sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))

	var b bytes.Buffer
	// DoneReason defaults to 1 (stopped at the token limit) and is reset
	// to 0 below when an EOS token is sampled.
	final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
	for i := range request.Options.MaxTokens {
		if err := ctx.Err(); err != nil {
			return err
		}

		request.Sampler.AppendToken(sample)
		// Kick off step i+1 before reading step i's result so the device
		// stays busy while this token is decoded and streamed.
		nextSample, nextLogprobs = step(sample)

		if i == 0 {
			// The first sample becoming available marks the end of
			// prompt processing for timing purposes.
			mlx.Eval(sample)
			final.PromptEvalDuration = time.Since(now)
			now = time.Now()
		}

		output := int32(sample.Int())
		session.outputs = append(session.outputs, output)
		if r.Tokenizer.IsEOS(output) {
			final.DoneReason = 0
			final.EvalCount = i
			break
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case request.Responses <- CompletionResponse{
			Content: r.Decode(output, &b),
		}:
		}

		mlx.Unpin(sample, logprobs)
		sample, logprobs = nextSample, nextLogprobs
		nextSample, nextLogprobs = nil, nil

		// Periodically release cached device buffers.
		if i%256 == 0 {
			mlx.ClearCache()
		}
	}
	final.EvalDuration = time.Since(now)

	select {
	case <-ctx.Done():
		return ctx.Err()
	case request.Responses <- final:
		return nil
	}
}
// Decode appends the token's text to b and returns the longest valid
// UTF-8 prefix buffered so far, so runes split across tokens are not
// emitted half-formed.
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
	token := r.Tokenizer.Decode([]int32{sample})
	if _, err := b.WriteString(token); err != nil {
		slog.Error("Failed to write token to buffer", "error", err)
		return ""
	}
	return flushValidUTF8Prefix(b)
}
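
// For example, assuming flushValidUTF8Prefix returns complete runes and
// keeps any trailing partial byte sequence buffered: if the 4-byte UTF-8
// encoding of "🦙" is split across two tokens, the first Decode call
// returns "" while its leading bytes wait in b, and the second call
// returns the whole rune once the remaining bytes arrive.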