mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* prefer rocm v6 on windows Avoid building with v7 - more changes are needed * MLX: add header vendoring and remove go build tag This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified. * ci: harden for flaky choco repo servers CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without. * review comments
284 lines
8.1 KiB
Go
284 lines
8.1 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"image"
|
|
_ "image/jpeg"
|
|
_ "image/png"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime/pprof"
|
|
|
|
"github.com/ollama/ollama/x/imagegen"
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
|
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
|
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
)
|
|
|
|
// stringSlice is a flag type that accumulates multiple values
|
|
type stringSlice []string
|
|
|
|
func (s *stringSlice) String() string {
|
|
return fmt.Sprintf("%v", *s)
|
|
}
|
|
|
|
func (s *stringSlice) Set(value string) error {
|
|
*s = append(*s, value)
|
|
return nil
|
|
}
|
|
|
|
func main() {
|
|
modelPath := flag.String("model", "", "Model directory")
|
|
prompt := flag.String("prompt", "Hello", "Prompt")
|
|
|
|
// Text generation params
|
|
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
|
|
temperature := flag.Float64("temperature", 0.7, "Temperature")
|
|
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
|
|
topK := flag.Int("top-k", 40, "Top-k sampling")
|
|
imagePath := flag.String("image", "", "Image path for multimodal models")
|
|
|
|
// Image generation params
|
|
width := flag.Int("width", 0, "Image width (0 = auto from input or 1024)")
|
|
height := flag.Int("height", 0, "Image height (0 = auto from input or 1024)")
|
|
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
|
|
seed := flag.Int64("seed", 42, "Random seed")
|
|
out := flag.String("output", "output.png", "Output path")
|
|
|
|
// Utility flags
|
|
listTensors := flag.Bool("list", false, "List tensors only")
|
|
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
|
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
|
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
|
|
|
// Legacy mode flags
|
|
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
|
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
|
|
var inputImages stringSlice
|
|
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
|
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
|
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
|
|
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
|
|
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
|
|
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
|
|
|
|
flag.Parse()
|
|
|
|
if *modelPath == "" {
|
|
flag.Usage()
|
|
return
|
|
}
|
|
|
|
// Check if MLX initialized successfully
|
|
if !mlx.IsMLXAvailable() {
|
|
log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError())
|
|
}
|
|
|
|
// CPU profiling
|
|
if *cpuProfile != "" {
|
|
f, err := os.Create(*cpuProfile)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
if err := pprof.StartCPUProfile(f); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer pprof.StopCPUProfile()
|
|
}
|
|
|
|
var err error
|
|
|
|
// Handle legacy mode flags that aren't unified yet
|
|
switch {
|
|
case *zimageFlag:
|
|
m := &zimage.Model{}
|
|
if loadErr := m.Load(*modelPath); loadErr != nil {
|
|
log.Fatal(loadErr)
|
|
}
|
|
var img *mlx.Array
|
|
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
|
|
Prompt: *prompt,
|
|
NegativePrompt: *negativePrompt,
|
|
CFGScale: float32(*cfgScale),
|
|
Width: int32(*width),
|
|
Height: int32(*height),
|
|
Steps: *steps,
|
|
Seed: *seed,
|
|
CapturePath: *gpuCapture,
|
|
TeaCache: *teaCache,
|
|
TeaCacheThreshold: float32(*teaCacheThreshold),
|
|
FusedQKV: *fusedQKV,
|
|
})
|
|
if err == nil {
|
|
err = saveImageArray(img, *out)
|
|
}
|
|
case *flux2Flag:
|
|
m := &flux2.Model{}
|
|
if loadErr := m.Load(*modelPath); loadErr != nil {
|
|
log.Fatal(loadErr)
|
|
}
|
|
// Load input images with EXIF orientation correction
|
|
var loadedImages []image.Image
|
|
for _, path := range inputImages {
|
|
img, loadErr := loadImageWithEXIF(path)
|
|
if loadErr != nil {
|
|
log.Fatalf("Failed to load image %s: %v", path, loadErr)
|
|
}
|
|
loadedImages = append(loadedImages, img)
|
|
}
|
|
// When input images provided and user didn't override dimensions, use 0 to match input
|
|
fluxWidth := int32(*width)
|
|
fluxHeight := int32(*height)
|
|
if len(loadedImages) > 0 && *width == 0 && *height == 0 {
|
|
// Both unset, will auto-detect from input
|
|
} else if len(loadedImages) > 0 && *width == 0 {
|
|
fluxWidth = 0 // Compute from height + aspect ratio
|
|
} else if len(loadedImages) > 0 && *height == 0 {
|
|
fluxHeight = 0 // Compute from width + aspect ratio
|
|
}
|
|
var img *mlx.Array
|
|
img, err = m.GenerateFromConfig(context.Background(), &flux2.GenerateConfig{
|
|
Prompt: *prompt,
|
|
Width: fluxWidth,
|
|
Height: fluxHeight,
|
|
Steps: *steps,
|
|
GuidanceScale: float32(*cfgScale),
|
|
Seed: *seed,
|
|
CapturePath: *gpuCapture,
|
|
InputImages: loadedImages,
|
|
})
|
|
if err == nil {
|
|
err = saveImageArray(img, *out)
|
|
}
|
|
case *listTensors:
|
|
err = listModelTensors(*modelPath)
|
|
default:
|
|
// llm path
|
|
m, err := load(*modelPath)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
// Load image if provided and model supports it.
|
|
var image *mlx.Array
|
|
if *imagePath != "" {
|
|
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
|
image, err = imagegen.ProcessImage(*imagePath, mm.ImageSize())
|
|
if err != nil {
|
|
log.Fatal("load image:", err)
|
|
}
|
|
} else {
|
|
log.Fatal("model does not support image input")
|
|
}
|
|
}
|
|
|
|
err = generate(context.Background(), m, input{
|
|
Prompt: *prompt,
|
|
Image: image,
|
|
MaxTokens: *maxTokens,
|
|
Temperature: float32(*temperature),
|
|
TopP: float32(*topP),
|
|
TopK: *topK,
|
|
WiredLimitGB: *wiredLimitGB,
|
|
}, func(out output) {
|
|
if out.Text != "" {
|
|
fmt.Print(out.Text)
|
|
}
|
|
if out.Done {
|
|
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
|
|
}
|
|
})
|
|
}
|
|
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func listModelTensors(modelPath string) error {
|
|
weights, err := safetensors.LoadModelWeights(modelPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, name := range weights.ListTensors() {
|
|
info, _ := weights.GetTensorInfo(name)
|
|
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// loadModel builds and evaluates a model using the common load pattern.
|
|
// Release safetensors BEFORE eval - lazy arrays have captured their data,
|
|
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
|
|
func loadModel[T Model](build func() T, cleanup func()) T {
|
|
m := build()
|
|
weights := mlx.Collect(m)
|
|
cleanup()
|
|
mlx.Eval(weights...)
|
|
return m
|
|
}
|
|
|
|
func load(modelPath string) (Model, error) {
|
|
kind, err := detectModelKind(modelPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("detect model kind: %w", err)
|
|
}
|
|
|
|
switch kind {
|
|
default:
|
|
return nil, fmt.Errorf("model type %q is not supported by x/imagegen/cmd/engine", kind)
|
|
}
|
|
}
|
|
|
|
func detectModelKind(modelPath string) (string, error) {
|
|
indexPath := filepath.Join(modelPath, "model_index.json")
|
|
if _, err := os.Stat(indexPath); err == nil {
|
|
data, err := os.ReadFile(indexPath)
|
|
if err != nil {
|
|
return "zimage", nil
|
|
}
|
|
var index struct {
|
|
ClassName string `json:"_class_name"`
|
|
}
|
|
if err := json.Unmarshal(data, &index); err == nil {
|
|
switch index.ClassName {
|
|
case "FluxPipeline", "ZImagePipeline":
|
|
return "zimage", nil
|
|
case "Flux2KleinPipeline":
|
|
return "flux2", nil
|
|
}
|
|
}
|
|
return "zimage", nil
|
|
}
|
|
|
|
configPath := filepath.Join(modelPath, "config.json")
|
|
data, err := os.ReadFile(configPath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
|
|
}
|
|
|
|
var cfg struct {
|
|
ModelType string `json:"model_type"`
|
|
}
|
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
return "", fmt.Errorf("parse config.json: %w", err)
|
|
}
|
|
|
|
return cfg.ModelType, nil
|
|
}
|
|
|
|
// loadImageWithEXIF loads an image from a file path with EXIF orientation correction.
|
|
func loadImageWithEXIF(path string) (image.Image, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read file: %w", err)
|
|
}
|
|
return imagegen.DecodeImage(data)
|
|
}
|