Files
ollama/x/mlxrunner/pipeline.go
Daniel Hiltgen 10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support are verified.

* ci: harden for flaky choco repo servers

CI sometimes fails due to choco not actually installing cache.  Since it just speeds up the build, we can proceed without.

* review comments
2026-03-09 17:24:45 -07:00

192 lines
4.3 KiB
Go

package mlxrunner
import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func prefillChunkSize() int {
return 2 << 10
}
func (r *Runner) TextGenerationPipeline(request Request) error {
if r.Model == nil {
return errors.New("model not loaded")
}
enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
}
if enableCompile {
mlx.EnableCompile()
} else {
mlx.DisableCompile()
}
mlx.ResetPeakMemory()
ctx := request.Ctx
var (
sample, logprobs *mlx.Array
nextSample, nextLogprobs *mlx.Array
)
defer func() {
if request.Sampler != nil {
request.Sampler.Free()
}
mlx.Unpin(sample, logprobs)
mlx.Unpin(nextSample, nextLogprobs)
mlx.Sweep()
mlx.ClearCache()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
r.cache.log()
}
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
}()
inputs := r.Tokenizer.Encode(request.Prompt, true)
if len(inputs) == 0 {
return errors.New("empty prompt")
}
if len(inputs) >= r.contextLength {
return api.StatusError{
StatusCode: http.StatusBadRequest,
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
}
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
}
request.Sampler.ResetHistory(inputs)
session := r.cache.begin(r.Model, inputs)
defer session.close()
caches := session.caches
tokens := session.remaining
prefillChunk := prefillChunkSize()
materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
state = appendCacheState(state, c)
}
if len(state) == 0 {
return
}
mlx.Eval(state...)
}
now := time.Now()
total, processed := len(tokens), 0
for total-processed > 1 {
if err := ctx.Err(); err != nil {
return err
}
n := min(prefillChunk, total-processed-1)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep()
materializeCaches()
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
mlx.ClearCache()
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true))
sample := request.Sampler.Sample(logprobs)
mlx.Pin(sample, logprobs)
mlx.Sweep()
mlx.AsyncEval(sample, logprobs)
return sample, logprobs
}
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
var b bytes.Buffer
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens {
if err := ctx.Err(); err != nil {
return err
}
request.Sampler.AppendToken(sample)
nextSample, nextLogprobs = step(sample)
if i == 0 {
mlx.Eval(sample)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
output := int32(sample.Int())
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
final.DoneReason = 0
final.EvalCount = i
break
}
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- CompletionResponse{
Content: r.Decode(output, &b),
}:
}
mlx.Unpin(sample, logprobs)
sample, logprobs = nextSample, nextLogprobs
nextSample, nextLogprobs = nil, nil
if i%256 == 0 {
mlx.ClearCache()
}
}
final.EvalDuration = time.Since(now)
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- final:
return nil
}
}
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
token := r.Tokenizer.Decode([]int32{sample})
if _, err := b.WriteString(token); err != nil {
slog.Error("Failed to write token to buffer", "error", err)
return ""
}
return flushValidUTF8Prefix(b)
}