Files
ollama/x/mlxrunner/server.go
Daniel Hiltgen 10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support are verified.

* ci: harden for flaky choco repo servers

CI sometimes fails due to choco not actually installing cache.  Since it just speeds up the build, we can proceed without.

* review comments
2026-03-09 17:24:45 -07:00

206 lines
5.1 KiB
Go

package mlxrunner
import (
"bytes"
"cmp"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strconv"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/sample"
)
// Execute is the entry point for the MLX runner. It installs the default
// logger, checks that MLX can be initialized, parses the -model and -port
// flags from args, loads the model, registers the HTTP routes, and hands the
// resulting handler to runner.Run together with 127.0.0.1 and the chosen port.
//
// It returns an error when the MLX check fails or the model fails to load;
// otherwise it returns whatever runner.Run returns. The flag set is created
// with flag.ExitOnError, so flag parsing failures are handled by the flag
// package rather than being returned from here.
func Execute(args []string) error {
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))

	if err := mlx.CheckInit(); err != nil {
		return fmt.Errorf("MLX not available: %w", err)
	}

	// Switch the default device to the GPU when one is reported; otherwise the
	// default device is left untouched and logged as "cpu".
	device := "cpu"
	if mlx.GPUIsAvailable() {
		mlx.SetDefaultDeviceGPU()
		device = "gpu"
	}
	slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", device)

	flags := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
	model := flags.String("model", "", "Model name")
	listenPort := flags.Int("port", 0, "Port to listen on")
	// -verbose is registered so that it is accepted on the command line; its
	// value is not read here.
	_ = flags.Bool("verbose", false, "Enable debug logging")
	flags.Parse(args)

	runner := Runner{Requests: make(chan Request)}
	if err := runner.Load(*model); err != nil {
		return err
	}

	// writeJSON encodes v onto rw. On failure it logs the error and answers
	// with a 500.
	writeJSON := func(rw http.ResponseWriter, v any) {
		if err := json.NewEncoder(rw).Encode(v); err != nil {
			slog.Error("Failed to encode response", "error", err)
			http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
		}
	}

	mux := http.NewServeMux()

	// Status: fixed status/progress values plus the runner's context length
	// and the sum of MLX active and cache memory.
	mux.HandleFunc("GET /v1/status", func(rw http.ResponseWriter, _ *http.Request) {
		writeJSON(rw, statusResponse{
			Status:        0,
			Progress:      100,
			ContextLength: runner.contextLength,
			Memory:        uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
		})
	})

	// Models: POST and GET both answer with a success object; DELETE is
	// accepted but currently does nothing. Other methods write no body.
	mux.HandleFunc("/v1/models", func(rw http.ResponseWriter, req *http.Request) {
		switch req.Method {
		case http.MethodPost, http.MethodGet:
			writeJSON(rw, map[string]any{"Success": true})
		case http.MethodDelete:
			// TODO: cleanup model and cache
		}
	})

	// Completions: decode the request, queue it on runner.Requests, then
	// stream each response as one JSON line until the response channel is
	// closed or the client's context is done.
	mux.HandleFunc("POST /v1/completions", func(rw http.ResponseWriter, req *http.Request) {
		job := Request{Responses: make(chan CompletionResponse)}
		if err := json.NewDecoder(req.Body).Decode(&job.TextCompletionsRequest); err != nil {
			slog.Error("Failed to decode request", "error", err)
			http.Error(rw, "Bad Request", http.StatusBadRequest)
			return
		}

		// MaxTokens falls back to NumPredict when it is the zero value.
		job.Options.MaxTokens = cmp.Or(job.Options.MaxTokens, job.Options.NumPredict)
		job.Pipeline = runner.TextGenerationPipeline
		job.Sampler = sample.New(
			job.Options.Temperature,
			job.Options.TopP,
			job.Options.MinP,
			job.Options.TopK,
			job.Options.RepeatLastN,
			job.Options.PresencePenalty,
		)

		// The job's context is cancelled when this handler returns.
		ctx, cancel := context.WithCancel(req.Context())
		defer cancel()
		job.Ctx = ctx

		clientGone := req.Context().Done()

		// Requests is unbuffered, so this waits until the job is received
		// or the client goes away.
		select {
		case <-clientGone:
			return
		case runner.Requests <- job:
		}

		rw.Header().Set("Content-Type", "application/jsonl")
		rw.WriteHeader(http.StatusOK)

		enc := json.NewEncoder(rw)
		flusher, canFlush := rw.(http.Flusher)
		for {
			select {
			case <-clientGone:
				return
			case chunk, open := <-job.Responses:
				if !open {
					return
				}
				if err := enc.Encode(chunk); err != nil {
					slog.Error("Failed to encode response", "error", err)
					return
				}
				// Flush after every line so it is sent without waiting
				// for later output.
				if canFlush {
					flusher.Flush()
				}
			}
		}
	})

	// Tokenize: the raw request body is the text to encode.
	mux.HandleFunc("POST /v1/tokenize", func(rw http.ResponseWriter, req *http.Request) {
		var body bytes.Buffer
		if _, err := io.Copy(&body, req.Body); err != nil {
			slog.Error("Failed to read request body", "error", err)
			http.Error(rw, "Bad Request", http.StatusBadRequest)
			return
		}
		encoded := runner.Tokenizer.Encode(body.String(), true)
		writeJSON(rw, encoded)
	})

	// Alias routes: each answers with a 308 redirect to its /v1 counterpart.
	aliases := []struct{ from, to string }{
		{"GET /health", "/v1/status"},
		{"POST /load", "/v1/models"},
		{"POST /completion", "/v1/completions"},
	}
	for _, alias := range aliases {
		mux.Handle(alias.from, http.RedirectHandler(alias.to, http.StatusPermanentRedirect))
	}

	// Wrap the mux so every response defaults to a JSON content type and each
	// request is logged with its status and duration. 3xx responses are not
	// logged; 4xx log at warn, 5xx at error, everything else at info.
	logged := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("Content-Type", "application/json")
		rec := &statusRecorder{ResponseWriter: rw, code: http.StatusOK}
		began := time.Now()
		mux.ServeHTTP(rec, req)

		level := slog.LevelInfo
		switch {
		case rec.code >= 500:
			level = slog.LevelError
		case rec.code >= 400:
			level = slog.LevelWarn
		case rec.code >= 300:
			return
		}
		slog.Log(req.Context(), level, "ServeHTTP", "method", req.Method, "path", req.URL.Path, "took", time.Since(began), "status", rec.Status())
	})

	return runner.Run("127.0.0.1", strconv.Itoa(*listenPort), logged)
}
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// passed to WriteHeader so it can be reported after the handler has run.
type statusRecorder struct {
	http.ResponseWriter

	// code is the most recent status handed to WriteHeader. WriteHeader is
	// the only method that updates it, so the creator supplies the initial
	// value.
	code int
}

// WriteHeader stores statusCode and then forwards it to the wrapped writer.
func (rec *statusRecorder) WriteHeader(statusCode int) {
	rec.code = statusCode
	rec.ResponseWriter.WriteHeader(statusCode)
}

// Status renders the recorded code as "<code> <reason phrase>", for example
// "404 Not Found". Codes unknown to http.StatusText yield an empty phrase.
func (rec *statusRecorder) Status() string {
	reason := http.StatusText(rec.code)
	return strconv.Itoa(rec.code) + " " + reason
}

// Flush forwards to the wrapped writer when it implements http.Flusher and
// does nothing otherwise.
func (rec *statusRecorder) Flush() {
	flusher, ok := rec.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	flusher.Flush()
}