mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* prefer rocm v6 on windows
  Avoid building with v7 - more changes are needed
* MLX: add header vendoring and remove go build tag
  This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified.
* ci: harden for flaky choco repo servers
  CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without.
* review comments
152 lines
3.8 KiB
Go
152 lines
3.8 KiB
Go
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
|
|
package imagegen
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/envconfig"
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// Execute is the entry point for the unified MLX runner subprocess.
|
|
func Execute(args []string) error {
|
|
// Set up logging with appropriate level from environment
|
|
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))
|
|
|
|
fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
|
|
modelName := fs.String("model", "", "path to model")
|
|
port := fs.Int("port", 0, "port to listen on")
|
|
|
|
if err := fs.Parse(args); err != nil {
|
|
return err
|
|
}
|
|
|
|
if *modelName == "" {
|
|
return fmt.Errorf("--model is required")
|
|
}
|
|
if *port == 0 {
|
|
return fmt.Errorf("--port is required")
|
|
}
|
|
|
|
// Detect model type from capabilities
|
|
mode := detectModelMode(*modelName)
|
|
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
|
|
|
|
if mode != ModeImageGen {
|
|
return fmt.Errorf("imagegen runner only supports image generation models")
|
|
}
|
|
|
|
// Initialize MLX only for image generation mode.
|
|
if err := mlx.InitMLX(); err != nil {
|
|
slog.Error("unable to initialize MLX", "error", err)
|
|
return err
|
|
}
|
|
slog.Info("MLX library initialized")
|
|
|
|
// Create and start server
|
|
server, err := newServer(*modelName, *port)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create server: %w", err)
|
|
}
|
|
|
|
// Set up HTTP handlers
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/health", server.healthHandler)
|
|
mux.HandleFunc("/completion", server.completionHandler)
|
|
|
|
httpServer := &http.Server{
|
|
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
|
|
Handler: mux,
|
|
}
|
|
|
|
// Handle shutdown
|
|
done := make(chan struct{})
|
|
go func() {
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
|
<-sigCh
|
|
slog.Info("shutting down mlx runner")
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
httpServer.Shutdown(ctx)
|
|
close(done)
|
|
}()
|
|
|
|
slog.Info("mlx runner listening", "addr", httpServer.Addr)
|
|
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
|
|
return err
|
|
}
|
|
|
|
<-done
|
|
return nil
|
|
}
|
|
|
|
// detectModelMode determines whether a model is an LLM or image generation model.
|
|
func detectModelMode(modelName string) ModelMode {
|
|
// Check for image generation model by looking at model_index.json
|
|
modelType := DetectModelType(modelName)
|
|
if modelType != "" {
|
|
// Known image generation model types
|
|
switch modelType {
|
|
case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
|
|
return ModeImageGen
|
|
}
|
|
}
|
|
|
|
// Default to LLM mode for safetensors models without known image gen types
|
|
return ModeLLM
|
|
}
|
|
|
|
// server holds the model and handles HTTP requests.
|
|
type server struct {
|
|
modelName string
|
|
port int
|
|
|
|
// Image generation model.
|
|
imageModel ImageModel
|
|
}
|
|
|
|
// newServer creates a new server instance for image generation models.
|
|
func newServer(modelName string, port int) (*server, error) {
|
|
s := &server{
|
|
modelName: modelName,
|
|
port: port,
|
|
}
|
|
|
|
if err := s.loadImageModel(); err != nil {
|
|
return nil, fmt.Errorf("failed to load image model: %w", err)
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
|
resp := HealthResponse{Status: "ok"}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}
|
|
|
|
func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
var req Request
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
s.handleImageCompletion(w, r, req)
|
|
}
|