From 30fdd229a434cfae409cc07456684315bf95a561 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 7 Apr 2026 08:12:57 -0700 Subject: [PATCH] create: Clean up experimental paths, fix create from existing safetensor model (#14679) * create: Clean up experimental paths This cleans up the experimental features, and adds both unit and integration test coverage to verify no regressions. * create: preserve config and layer names when creating from safetensors models When creating a model FROM an existing safetensors model, ModelFormat, Capabilities, and layer Name fields were lost. ModelFormat stayed empty because it's only set from GGML layers (which safetensors models lack), and layer names weren't copied in parseFromModel. This caused derived models to fail loading ("config.json not found in manifest"). * review comments --- .gitignore | 1 + cmd/cmd.go | 16 +- integration/create_imagegen_test.go | 107 ++++++ integration/create_test.go | 350 +++++++++++++++++++ server/create.go | 5 +- server/model.go | 5 +- server/routes_create_test.go | 269 +++++++++++++++ x/create/client/create.go | 2 +- x/create/client/create_test.go | 156 +++++++++ x/create/create.go | 2 +- x/create/create_test.go | 2 +- x/create/imagegen.go | 2 +- x/create/qwen35.go | 2 +- x/{imagegen => }/safetensors/extractor.go | 1 - x/safetensors/extractor_test.go | 394 ++++++++++++++++++++++ 15 files changed, 1292 insertions(+), 22 deletions(-) create mode 100644 integration/create_imagegen_test.go create mode 100644 integration/create_test.go rename x/{imagegen => }/safetensors/extractor.go (99%) create mode 100644 x/safetensors/extractor_test.go diff --git a/.gitignore b/.gitignore index eabf94c28..4e0b8974f 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ __debug_bin* llama/build llama/vendor /ollama +integration/testdata/models/ diff --git a/cmd/cmd.go b/cmd/cmd.go index a9bf3d4dc..250b63f40 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -54,7 +54,6 @@ import ( "github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/version" xcmd "github.com/ollama/ollama/x/cmd" - "github.com/ollama/ollama/x/create" xcreateclient "github.com/ollama/ollama/x/create/client" "github.com/ollama/ollama/x/imagegen" ) @@ -164,11 +163,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } // Check for --experimental flag for safetensors model creation + // This gates both safetensors LLM and imagegen model creation experimental, _ := cmd.Flags().GetBool("experimental") if experimental { if !isLocalhost() { return errors.New("remote safetensor model creation not yet supported") } + // Get Modelfile content - either from -f flag or default to "FROM ." var reader io.Reader filename, err := getModelfileName(cmd) @@ -211,23 +212,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error { }, p) } + // Standard Modelfile + API path var reader io.Reader filename, err := getModelfileName(cmd) if os.IsNotExist(err) { if filename == "" { - // No Modelfile found - check if current directory is an image gen model - if create.IsTensorModelDir(".") { - if !isLocalhost() { - return errors.New("remote safetensor model creation not yet supported") - } - quantize, _ := cmd.Flags().GetString("quantize") - return xcreateclient.CreateModel(xcreateclient.CreateOptions{ - ModelName: modelName, - ModelDir: ".", - Quantize: quantize, - }, p) - } reader = strings.NewReader("FROM .\n") } else { return errModelfileNotFound diff --git a/integration/create_imagegen_test.go b/integration/create_imagegen_test.go new file mode 100644 index 000000000..53903e7a6 --- /dev/null +++ b/integration/create_imagegen_test.go @@ -0,0 +1,107 @@ +//go:build integration && imagegen + +package integration + +import ( + "context" + "encoding/base64" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/ollama/ollama/api" +) + +func TestCreateImageGen(t *testing.T) { + skipIfRemote(t) + skipUnderMinVRAM(t, 13) + + // Allow overriding the model directory via env var for local testing, + // since the model is ~33GB and may already be downloaded elsewhere. + modelDir := os.Getenv("OLLAMA_TEST_IMAGEGEN_MODEL_DIR") + if modelDir == "" { + modelDir = filepath.Join(testdataModelsDir, "Z-Image-Turbo") + downloadHFModel(t, "Tongyi-MAI/Z-Image-Turbo", modelDir) + } else { + t.Logf("Using existing imagegen model at %s", modelDir) + } + + // Verify it looks like a valid imagegen model directory + if _, err := os.Stat(filepath.Join(modelDir, "model_index.json")); err != nil { + t.Fatalf("model_index.json not found in %s — not a valid imagegen model directory", modelDir) + } + + ensureMLXLibraryPath(t) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + modelName := "test-z-image-turbo-create" + + absModelDir, err := filepath.Abs(modelDir) + if err != nil { + t.Fatalf("Failed to get absolute path: %v", err) + } + + // Create a Modelfile pointing to the diffusers model directory + tmpModelfile := filepath.Join(t.TempDir(), "Modelfile") + if err := os.WriteFile(tmpModelfile, []byte("FROM "+absModelDir+"\n"), 0o644); err != nil { + t.Fatalf("Failed to write Modelfile: %v", err) + } + + t.Logf("Creating imagegen model from %s", absModelDir) + runOllamaCreate(ctx, t, modelName, "--experimental", "-f", tmpModelfile) + + // Verify model exists via show + showReq := &api.ShowRequest{Name: modelName} + showResp, err := client.Show(ctx, showReq) + if err != nil { + t.Fatalf("Model show failed after create: %v", err) + } + t.Logf("Created model details: %+v", showResp.Details) + + // Generate an image to verify the model isn't corrupted + t.Log("Generating test image...") + imageBase64, err := generateImage(ctx, client, modelName, "A red circle on a white background") + if err != nil { + if strings.Contains(err.Error(), "image generation not available") { + t.Skip("Target system does not support image generation") + } else if strings.Contains(err.Error(), "insufficient memory for image generation") { + t.Skip("insufficient memory for image generation") + } else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") { + t.Skip("unsupported architecture") + } + t.Fatalf("Image generation failed: %v", err) + } + + // Verify we got valid image data + imageData, err := base64.StdEncoding.DecodeString(imageBase64) + if err != nil { + t.Fatalf("Failed to decode base64 image: %v", err) + } + + t.Logf("Generated image: %d bytes", len(imageData)) + + if len(imageData) < 1000 { + t.Fatalf("Generated image suspiciously small (%d bytes), likely corrupted", len(imageData)) + } + + // Check for PNG or JPEG magic bytes + isPNG := len(imageData) >= 4 && imageData[0] == 0x89 && imageData[1] == 'P' && imageData[2] == 'N' && imageData[3] == 'G' + isJPEG := len(imageData) >= 2 && imageData[0] == 0xFF && imageData[1] == 0xD8 + if !isPNG && !isJPEG { + t.Fatalf("Generated image is neither PNG nor JPEG (first bytes: %x)", imageData[:min(8, len(imageData))]) + } + t.Logf("Image format validated (PNG=%v, JPEG=%v)", isPNG, isJPEG) + + // Cleanup: delete the model + deleteReq := &api.DeleteRequest{Model: modelName} + if err := client.Delete(ctx, deleteReq); err != nil { + t.Logf("Warning: failed to delete test model: %v", err) + } +} diff --git a/integration/create_test.go b/integration/create_test.go new file mode 100644 index 000000000..60a1269b9 --- /dev/null +++ b/integration/create_test.go @@ -0,0 +1,350 @@ +//go:build integration + +package integration + +import ( + "context" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/ollama/ollama/api" +) + +const testdataModelsDir = "testdata/models" + +// skipIfRemote skips the test if OLLAMA_HOST points to a non-local server. +// Safetensors/imagegen creation requires localhost since it reads model files +// from disk and uses the --experimental CLI path. +func skipIfRemote(t *testing.T) { + t.Helper() + host := os.Getenv("OLLAMA_HOST") + if host == "" { + return // default is localhost + } + // Strip scheme if present + _, hostport, ok := strings.Cut(host, "://") + if !ok { + hostport = host + } + h, _, err := net.SplitHostPort(hostport) + if err != nil { + h = hostport + } + if h == "" || h == "localhost" { + return + } + ip := net.ParseIP(h) + if ip != nil && (ip.IsLoopback() || ip.IsUnspecified()) { + return + } + t.Skipf("safetensors/imagegen creation requires a local server (OLLAMA_HOST=%s)", host) +} + +// findHFCLI returns the path to the HuggingFace CLI, or "" if not found. +func findHFCLI() string { + for _, name := range []string{"huggingface-cli", "hf"} { + if p, err := exec.LookPath(name); err == nil { + return p + } + } + return "" +} + +// downloadHFModel idempotently downloads a HuggingFace model to destDir. +// Skips the test if CLI is missing and model isn't already present. +func downloadHFModel(t *testing.T, repo, destDir string, extraArgs ...string) { + t.Helper() + + // Check if model already exists + if _, err := os.Stat(destDir); err == nil { + entries, err := os.ReadDir(destDir) + if err == nil && len(entries) > 0 { + t.Logf("Model %s already present at %s", repo, destDir) + return + } + } + + cli := findHFCLI() + if cli == "" { + t.Skipf("HuggingFace CLI not found and model %s not present at %s", repo, destDir) + } + + t.Logf("Downloading %s to %s", repo, destDir) + os.MkdirAll(destDir, 0o755) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + args := []string{"download", repo, "--local-dir", destDir} + args = append(args, extraArgs...) + cmd := exec.CommandContext(ctx, cli, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + t.Fatalf("Failed to download %s: %v", repo, err) + } +} + +// ollamaBin returns the path to the ollama binary to use for tests. +// Prefers OLLAMA_BIN env, then falls back to the built binary at ../ollama +// (same binary the integration test server uses). +func ollamaBin() string { + if bin := os.Getenv("OLLAMA_BIN"); bin != "" { + return bin + } + if abs, err := filepath.Abs("../ollama"); err == nil { + if _, err := os.Stat(abs); err == nil { + return abs + } + } + return "ollama" +} + +// ensureMLXLibraryPath sets OLLAMA_LIBRARY_PATH so the MLX dynamic library +// is discoverable. Integration tests run from integration/ dir, so the +// default CWD-based search won't find the library at the repo root. +func ensureMLXLibraryPath(t *testing.T) { + t.Helper() + if libPath, err := filepath.Abs("../build/lib/ollama"); err == nil { + if _, err := os.Stat(libPath); err == nil { + if existing := os.Getenv("OLLAMA_LIBRARY_PATH"); existing != "" { + t.Setenv("OLLAMA_LIBRARY_PATH", existing+string(filepath.ListSeparator)+libPath) + } else { + t.Setenv("OLLAMA_LIBRARY_PATH", libPath) + } + } + } +} + +// runOllamaCreate runs "ollama create" as a subprocess. Skips the test if +// the error indicates the server is remote. +func runOllamaCreate(ctx context.Context, t *testing.T, args ...string) { + t.Helper() + createCmd := exec.CommandContext(ctx, ollamaBin(), append([]string{"create"}, args...)...) + var createStderr strings.Builder + createCmd.Stdout = os.Stdout + createCmd.Stderr = io.MultiWriter(os.Stderr, &createStderr) + if err := createCmd.Run(); err != nil { + if strings.Contains(createStderr.String(), "remote") { + t.Skip("safetensors creation requires a local server") + } + t.Fatalf("ollama create failed: %v", err) + } +} + +func TestCreateSafetensorsLLM(t *testing.T) { + skipIfRemote(t) + + modelDir := filepath.Join(testdataModelsDir, "TinyLlama-1.1B") + downloadHFModel(t, "TinyLlama/TinyLlama-1.1B-Chat-v1.0", modelDir) + + ensureMLXLibraryPath(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + modelName := "test-tinyllama-safetensors" + + absModelDir, err := filepath.Abs(modelDir) + if err != nil { + t.Fatalf("Failed to get absolute path: %v", err) + } + + // Create a Modelfile pointing to the model directory. + // Include a chat template since the safetensors importer doesn't extract + // chat_template from tokenizer_config.json yet. + modelfileContent := "FROM " + absModelDir + "\n" + + "TEMPLATE \"{{ if .System }}<|system|>\n{{ .System }}\n{{ end }}" + + "{{ if .Prompt }}<|user|>\n{{ .Prompt }}\n{{ end }}" + + "<|assistant|>\n{{ .Response }}\n\"\n" + tmpModelfile := filepath.Join(t.TempDir(), "Modelfile") + if err := os.WriteFile(tmpModelfile, []byte(modelfileContent), 0o644); err != nil { + t.Fatalf("Failed to write Modelfile: %v", err) + } + + runOllamaCreate(ctx, t, modelName, "--experimental", "-f", tmpModelfile) + + // Verify model exists via show + showReq := &api.ShowRequest{Name: modelName} + showResp, err := client.Show(ctx, showReq) + if err != nil { + t.Fatalf("Model show failed after create: %v", err) + } + t.Logf("Created model details: %+v", showResp.Details) + + // Use the chat API for proper template application. + chatReq := &api.ChatRequest{ + Model: modelName, + Messages: []api.Message{ + {Role: "user", Content: "Write a short sentence about the weather."}, + }, + Options: map[string]interface{}{ + "num_predict": 20, + "temperature": 0.0, + }, + } + + var output strings.Builder + err = client.Chat(ctx, chatReq, func(resp api.ChatResponse) error { + output.WriteString(resp.Message.Content) + return nil + }) + if err != nil { + t.Fatalf("Chat failed: %v", err) + } + + text := output.String() + t.Logf("Generated output: %q", text) + assertCoherentOutput(t, text) + + // Cleanup: delete the model + deleteReq := &api.DeleteRequest{Model: modelName} + if err := client.Delete(ctx, deleteReq); err != nil { + t.Logf("Warning: failed to delete test model: %v", err) + } +} + +func TestCreateGGUF(t *testing.T) { + modelDir := filepath.Join(testdataModelsDir, "Llama-3.2-1B-GGUF") + downloadHFModel(t, "bartowski/Llama-3.2-1B-Instruct-GGUF", modelDir, + "--include", "Llama-3.2-1B-Instruct-IQ3_M.gguf") + + // Find the GGUF file + entries, err := os.ReadDir(modelDir) + if err != nil { + t.Fatalf("Failed to read model dir: %v", err) + } + + var ggufPath string + for _, e := range entries { + if filepath.Ext(e.Name()) == ".gguf" { + ggufPath = filepath.Join(modelDir, e.Name()) + break + } + } + if ggufPath == "" { + t.Skip("No GGUF file found in model directory") + } + + absGGUF, err := filepath.Abs(ggufPath) + if err != nil { + t.Fatalf("Failed to get absolute path: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + modelName := "test-llama32-gguf" + + // Create a Modelfile and use the CLI + tmpModelfile := filepath.Join(t.TempDir(), "Modelfile") + if err := os.WriteFile(tmpModelfile, []byte("FROM "+absGGUF+"\n"), 0o644); err != nil { + t.Fatalf("Failed to write Modelfile: %v", err) + } + + createCmd := exec.CommandContext(ctx, ollamaBin(), "create", modelName, "-f", tmpModelfile) + createCmd.Stdout = os.Stdout + createCmd.Stderr = os.Stderr + if err := createCmd.Run(); err != nil { + t.Fatalf("ollama create failed: %v", err) + } + + // Verify model exists + showReq := &api.ShowRequest{Name: modelName} + _, err = client.Show(ctx, showReq) + if err != nil { + t.Fatalf("Model show failed after create: %v", err) + } + + // Generate and verify output is coherent + genReq := &api.GenerateRequest{ + Model: modelName, + Prompt: "Write a short sentence about the weather.", + Options: map[string]interface{}{ + "num_predict": 20, + "temperature": 0.0, + }, + } + + var output strings.Builder + err = client.Generate(ctx, genReq, func(resp api.GenerateResponse) error { + output.WriteString(resp.Response) + return nil + }) + if err != nil { + t.Fatalf("Generate failed: %v", err) + } + + text := output.String() + t.Logf("Generated output: %q", text) + assertCoherentOutput(t, text) + + // Cleanup + deleteReq := &api.DeleteRequest{Model: modelName} + if err := client.Delete(ctx, deleteReq); err != nil { + t.Logf("Warning: failed to delete test model: %v", err) + } +} + +// assertCoherentOutput checks that model output looks like real language, not +// garbled binary or repeated garbage. This catches corrupted model creation +// where inference "works" but produces nonsense. +func assertCoherentOutput(t *testing.T, text string) { + t.Helper() + + if len(text) == 0 { + t.Fatal("model produced empty output") + } + + // Check minimum length — 20 tokens should produce at least a few words + if len(text) < 5 { + t.Fatalf("model output suspiciously short (%d bytes): %q", len(text), text) + } + + // Check for mostly-printable ASCII/Unicode — garbled models often emit + // high ratios of control characters or replacement characters + unprintable := 0 + for _, r := range text { + if r < 0x20 && r != '\n' && r != '\r' && r != '\t' { + unprintable++ + } + if r == '\ufffd' { // Unicode replacement character + unprintable++ + } + } + ratio := float64(unprintable) / float64(len([]rune(text))) + if ratio > 0.3 { + t.Fatalf("model output is %.0f%% unprintable characters (likely garbled): %q", ratio*100, text) + } + + // Check it contains at least one space — real language has word boundaries + if !strings.Contains(text, " ") { + t.Fatalf("model output contains no spaces (likely garbled): %q", text) + } + + // Check for excessive repetition — a broken model might repeat one token + words := strings.Fields(text) + if len(words) >= 4 { + counts := map[string]int{} + for _, w := range words { + counts[strings.ToLower(w)]++ + } + for w, c := range counts { + if c > len(words)*3/4 { + t.Fatalf("model output is excessively repetitive (%q appears %d/%d times): %q", w, c, len(words), text) + } + } + } +} diff --git a/server/create.go b/server/create.go index 9c48b5186..01fbe5738 100644 --- a/server/create.go +++ b/server/create.go @@ -141,7 +141,7 @@ func (s *Server) CreateHandler(c *gin.Context) { ch <- gin.H{"error": err.Error()} } - if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "" || len(config.Capabilities) == 0) { + if err == nil && !remote { mf, mErr := manifest.ParseNamedManifest(fromName) if mErr == nil && mf.Config.Digest != "" { configPath, pErr := manifest.BlobsPath(mf.Config.Digest) @@ -158,6 +158,9 @@ func (s *Server) CreateHandler(c *gin.Context) { if config.Requires == "" { config.Requires = baseConfig.Requires } + if config.ModelFormat == "" { + config.ModelFormat = baseConfig.ModelFormat + } if len(config.Capabilities) == 0 { config.Capabilities = baseConfig.Capabilities } diff --git a/server/model.go b/server/model.go index 57190ffe0..06bde52b6 100644 --- a/server/model.go +++ b/server/model.go @@ -41,11 +41,12 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return nil, err } - for _, layer := range m.Layers { - layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest()) + for _, srcLayer := range m.Layers { + layer, err := manifest.NewLayerFromLayer(srcLayer.Digest, srcLayer.MediaType, name.DisplayShortest()) if err != nil { return nil, err } + layer.Name = srcLayer.Name switch layer.MediaType { case "application/vnd.ollama.image.model", diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 401f98d9d..75bdac73b 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -1024,3 +1024,272 @@ func TestDetectModelTypeFromFiles(t *testing.T) { } }) } + +// createTestBlob creates a blob in the blobs directory and returns its digest. +func createTestBlob(t *testing.T, data []byte) string { + t.Helper() + digest := fmt.Sprintf("sha256:%x", sha256.Sum256(data)) + blobPath, err := manifest.BlobsPath(digest) + if err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Dir(blobPath), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(blobPath, data, 0o644); err != nil { + t.Fatal(err) + } + return digest +} + +// createSafetensorsTestModel creates a minimal safetensors model manifest for testing. +func createSafetensorsTestModel(t *testing.T, modelName string, config model.ConfigV2, extraLayers []manifest.Layer) { + t.Helper() + + // Create a fake tensor blob + tensorData := []byte("fake-tensor-data-for-testing") + tensorDigest := createTestBlob(t, tensorData) + + layers := []manifest.Layer{ + { + MediaType: manifest.MediaTypeImageTensor, + Digest: tensorDigest, + Size: int64(len(tensorData)), + Name: "model.embed_tokens.weight", + }, + } + layers = append(layers, extraLayers...) + + configLayer, err := createConfigLayer(layers, config) + if err != nil { + t.Fatalf("failed to create config layer: %v", err) + } + + name := model.ParseName(modelName) + if err := manifest.WriteManifest(name, *configLayer, layers); err != nil { + t.Fatalf("failed to write manifest: %v", err) + } +} + +func TestCreateFromSafetensorsModel_PreservesConfig(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + // Create a source safetensors model with specific config fields + createSafetensorsTestModel(t, "source-model", model.ConfigV2{ + ModelFormat: "safetensors", + Capabilities: []string{"completion"}, + Requires: "0.14.0", + Renderer: "gemma3", + Parser: "gemma3", + }, nil) + + // Create a derived model FROM the source + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "derived-model", + From: "source-model", + System: "You are a pirate.", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Read the derived model's config + derivedName := model.ParseName("derived-model") + mf, err := manifest.ParseNamedManifest(derivedName) + if err != nil { + t.Fatalf("failed to parse derived manifest: %v", err) + } + + configBlobPath, err := manifest.BlobsPath(mf.Config.Digest) + if err != nil { + t.Fatalf("failed to get config blob path: %v", err) + } + + configBlob, err := os.ReadFile(configBlobPath) + if err != nil { + t.Fatalf("failed to read config blob: %v", err) + } + + var cfg model.ConfigV2 + if err := json.Unmarshal(configBlob, &cfg); err != nil { + t.Fatalf("failed to unmarshal config: %v", err) + } + + // Verify safetensors-specific config fields are preserved + if cfg.ModelFormat != "safetensors" { + t.Errorf("ModelFormat = %q, want %q", cfg.ModelFormat, "safetensors") + } + + if !slices.Contains(cfg.Capabilities, "completion") { + t.Errorf("Capabilities = %v, want to contain %q", cfg.Capabilities, "completion") + } + + if cfg.Requires != "0.14.0" { + t.Errorf("Requires = %q, want %q", cfg.Requires, "0.14.0") + } + + if cfg.Renderer != "gemma3" { + t.Errorf("Renderer = %q, want %q", cfg.Renderer, "gemma3") + } + + if cfg.Parser != "gemma3" { + t.Errorf("Parser = %q, want %q", cfg.Parser, "gemma3") + } + + // Verify system prompt was added + var hasSystem bool + for _, l := range mf.Layers { + if l.MediaType == "application/vnd.ollama.image.system" { + hasSystem = true + break + } + } + if !hasSystem { + t.Error("expected system prompt layer in derived model") + } + + // Verify tensor layers were copied with names preserved + var tensorNames []string + for _, l := range mf.Layers { + if l.MediaType == manifest.MediaTypeImageTensor { + tensorNames = append(tensorNames, l.Name) + } + } + if len(tensorNames) == 0 { + t.Error("expected tensor layers in derived model") + } + for _, name := range tensorNames { + if name == "" { + t.Error("tensor layer has empty name — names must be preserved from source") + } + } +} + +func TestCreateFromSafetensorsModel_OverrideSystem(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + // Create source with a system prompt + createSafetensorsTestModel(t, "source-with-system", model.ConfigV2{ + ModelFormat: "safetensors", + Capabilities: []string{"completion"}, + }, nil) + + // First create with system prompt + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "source-with-system", + From: "source-with-system", + System: "Original system prompt", + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Now create a derived model with a different system prompt + w = createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "derived-new-system", + From: "source-with-system", + System: "New system prompt", + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify ModelFormat is preserved even after override + derivedName := model.ParseName("derived-new-system") + mf, err := manifest.ParseNamedManifest(derivedName) + if err != nil { + t.Fatalf("failed to parse derived manifest: %v", err) + } + + configBlobPath, _ := manifest.BlobsPath(mf.Config.Digest) + configBlob, _ := os.ReadFile(configBlobPath) + + var cfg model.ConfigV2 + json.Unmarshal(configBlob, &cfg) + + if cfg.ModelFormat != "safetensors" { + t.Errorf("ModelFormat = %q, want %q", cfg.ModelFormat, "safetensors") + } +} + +func TestCreateFromSafetensorsModel_PreservesLayerNames(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + // Create JSON config blobs to include as layers + configJSON := []byte(`{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}`) + configDigest := createTestBlob(t, configJSON) + tokenizerJSON := []byte(`{"version": "1.0"}`) + tokenizerDigest := createTestBlob(t, tokenizerJSON) + + extraLayers := []manifest.Layer{ + { + MediaType: "application/vnd.ollama.image.json", + Digest: configDigest, + Size: int64(len(configJSON)), + Name: "config.json", + }, + { + MediaType: "application/vnd.ollama.image.json", + Digest: tokenizerDigest, + Size: int64(len(tokenizerJSON)), + Name: "tokenizer.json", + }, + } + + createSafetensorsTestModel(t, "source-named-layers", model.ConfigV2{ + ModelFormat: "safetensors", + Capabilities: []string{"completion"}, + }, extraLayers) + + // Create derived model + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "derived-named-layers", + From: "source-named-layers", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + derivedName := model.ParseName("derived-named-layers") + mf, err := manifest.ParseNamedManifest(derivedName) + if err != nil { + t.Fatalf("failed to parse derived manifest: %v", err) + } + + // Check tensor layer names are preserved + for _, l := range mf.Layers { + if l.MediaType == manifest.MediaTypeImageTensor && l.Name == "" { + t.Error("tensor layer has empty name — names must be preserved from source") + } + } + + // Check JSON layer names are preserved + jsonNames := make(map[string]bool) + for _, l := range mf.Layers { + if l.MediaType == "application/vnd.ollama.image.json" && l.Name != "" { + jsonNames[l.Name] = true + } + } + + if !jsonNames["config.json"] { + t.Error("config.json layer name not preserved in derived model") + } + if !jsonNames["tokenizer.json"] { + t.Error("tokenizer.json layer name not preserved in derived model") + } +} diff --git a/x/create/client/create.go b/x/create/client/create.go index d380ae718..74abb865e 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -22,7 +22,7 @@ import ( "github.com/ollama/ollama/progress" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/x/create" - "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/safetensors" ) // MinOllamaVersion is the minimum Ollama version required for safetensors models. diff --git a/x/create/client/create_test.go b/x/create/client/create_test.go index 539a5360b..286ea2208 100644 --- a/x/create/client/create_test.go +++ b/x/create/client/create_test.go @@ -429,3 +429,159 @@ func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) { t.Fatalf("FileType = %q, want %q", cfg.FileType, "mxfp8") } } + +func TestSupportsThinking(t *testing.T) { + tests := []struct { + name string + configJSON string + want bool + }{ + { + name: "qwen3 architecture", + configJSON: `{"architectures": ["Qwen3ForCausalLM"], "model_type": "qwen3"}`, + want: true, + }, + { + name: "deepseek architecture", + configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`, + want: true, + }, + { + name: "glm4moe architecture", + configJSON: `{"architectures": ["GLM4MoeForCausalLM"]}`, + want: true, + }, + { + name: "llama architecture (no thinking)", + configJSON: `{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}`, + want: false, + }, + { + name: "gemma architecture (no thinking)", + configJSON: `{"architectures": ["Gemma3ForCausalLM"], "model_type": "gemma3"}`, + want: false, + }, + { + name: "model_type only", + configJSON: `{"model_type": "deepseek"}`, + want: true, + }, + { + name: "empty config", + configJSON: `{}`, + want: false, + }, + { + name: "invalid json", + configJSON: `not json`, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644) + + if got := supportsThinking(dir); got != tt.want { + t.Errorf("supportsThinking() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSupportsThinking_NoConfig(t *testing.T) { + if supportsThinking(t.TempDir()) { + t.Error("supportsThinking should return false for missing config.json") + } +} + +func TestGetParserName(t *testing.T) { + tests := []struct { + name string + configJSON string + want string + }{ + { + name: "qwen3 model", + configJSON: `{"architectures": ["Qwen3ForCausalLM"]}`, + want: "qwen3", + }, + { + name: "deepseek model", + configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`, + want: "deepseek3", + }, + { + name: "glm4 model", + configJSON: `{"architectures": ["GLM4ForCausalLM"]}`, + want: "glm-4.7", + }, + { + name: "llama model (no parser)", + configJSON: `{"architectures": ["LlamaForCausalLM"]}`, + want: "", + }, + { + name: "qwen3 via model_type", + configJSON: `{"model_type": "qwen3"}`, + want: "qwen3", + }, + { + name: "no config", + configJSON: `{}`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644) + + if got := getParserName(dir); got != tt.want { + t.Errorf("getParserName() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGetRendererName(t *testing.T) { + tests := []struct { + name string + configJSON string + want string + }{ + { + name: "qwen3 model", + configJSON: `{"architectures": ["Qwen3ForCausalLM"]}`, + want: "qwen3-coder", + }, + { + name: "deepseek model", + configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`, + want: "deepseek3", + }, + { + name: "glm4 model", + configJSON: `{"architectures": ["GLM4ForCausalLM"]}`, + want: "glm-4.7", + }, + { + name: "llama model (no renderer)", + configJSON: `{"architectures": ["LlamaForCausalLM"]}`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644) + + if got := getRendererName(dir); got != tt.want { + t.Errorf("getRendererName() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/x/create/create.go b/x/create/create.go index 5e234904f..da544e6a4 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -13,7 +13,7 @@ import ( "strings" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/safetensors" ) // ModelConfig represents the config blob stored with a model. diff --git a/x/create/create_test.go b/x/create/create_test.go index 6fd419cef..f5a68ba1a 100644 --- a/x/create/create_test.go +++ b/x/create/create_test.go @@ -12,7 +12,7 @@ import ( "testing" "github.com/d4l3k/go-bfloat16" - st "github.com/ollama/ollama/x/imagegen/safetensors" + st "github.com/ollama/ollama/x/safetensors" ) func TestIsTensorModelDir(t *testing.T) { diff --git a/x/create/imagegen.go b/x/create/imagegen.go index 6dbbcbfcc..09c832137 100644 --- a/x/create/imagegen.go +++ b/x/create/imagegen.go @@ -9,7 +9,7 @@ import ( "path/filepath" "strings" - "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/safetensors" ) // CreateImageGenModel imports an image generation model from a directory. diff --git a/x/create/qwen35.go b/x/create/qwen35.go index 0431fc452..a8ac98848 100644 --- a/x/create/qwen35.go +++ b/x/create/qwen35.go @@ -7,7 +7,7 @@ import ( "path/filepath" "strings" - "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/safetensors" ) type qwen35ImportTransform struct { diff --git a/x/imagegen/safetensors/extractor.go b/x/safetensors/extractor.go similarity index 99% rename from x/imagegen/safetensors/extractor.go rename to x/safetensors/extractor.go index 549222eab..f4f7e5d87 100644 --- a/x/imagegen/safetensors/extractor.go +++ b/x/safetensors/extractor.go @@ -11,7 +11,6 @@ import ( ) // tensorInfo holds tensor metadata from safetensors headers. -// This avoids depending on safetensors.go which requires the mlx tag. type tensorInfo struct { Dtype string `json:"dtype"` Shape []int32 `json:"shape"` diff --git a/x/safetensors/extractor_test.go b/x/safetensors/extractor_test.go new file mode 100644 index 000000000..0b9e1efe6 --- /dev/null +++ b/x/safetensors/extractor_test.go @@ -0,0 +1,394 @@ +package safetensors + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "io" + "os" + "path/filepath" + "slices" + "testing" +) + +// createTestSafetensors creates a minimal valid safetensors file with the given tensors. +func createTestSafetensors(t *testing.T, path string, tensors map[string]struct { + dtype string + shape []int32 + data []byte +}, +) { + t.Helper() + + header := make(map[string]tensorInfo) + var offset int + var allData []byte + + // Sort names for deterministic file layout + names := make([]string, 0, len(tensors)) + for name := range tensors { + names = append(names, name) + } + slices.Sort(names) + + for _, name := range names { + info := tensors[name] + header[name] = tensorInfo{ + Dtype: info.dtype, + Shape: info.shape, + DataOffsets: [2]int{offset, offset + len(info.data)}, + } + allData = append(allData, info.data...) + offset += len(info.data) + } + + headerJSON, err := json.Marshal(header) + if err != nil { + t.Fatalf("failed to marshal header: %v", err) + } + + // Pad to 8-byte alignment + padding := (8 - len(headerJSON)%8) % 8 + headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...) + + f, err := os.Create(path) + if err != nil { + t.Fatalf("failed to create file: %v", err) + } + defer f.Close() + + if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil { + t.Fatalf("failed to write header size: %v", err) + } + if _, err := f.Write(headerJSON); err != nil { + t.Fatalf("failed to write header: %v", err) + } + if _, err := f.Write(allData); err != nil { + t.Fatalf("failed to write data: %v", err) + } +} + +func TestOpenForExtraction(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + // 4 float32 values = 16 bytes + data := make([]byte, 16) + binary.LittleEndian.PutUint32(data[0:4], 0x3f800000) // 1.0 + binary.LittleEndian.PutUint32(data[4:8], 0x40000000) // 2.0 + binary.LittleEndian.PutUint32(data[8:12], 0x40400000) // 3.0 + binary.LittleEndian.PutUint32(data[12:16], 0x40800000) // 4.0 + + createTestSafetensors(t, path, map[string]struct { + dtype string + shape []int32 + data []byte + }{ + "test_tensor": {dtype: "F32", shape: []int32{2, 2}, data: data}, + }) + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + if ext.TensorCount() != 1 { + t.Errorf("TensorCount() = %d, want 1", ext.TensorCount()) + } + + names := ext.ListTensors() + if len(names) != 1 || names[0] != "test_tensor" { + t.Errorf("ListTensors() = %v, want [test_tensor]", names) + } +} + +func TestGetTensor(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + data := make([]byte, 16) + for i := range 4 { + binary.LittleEndian.PutUint32(data[i*4:], uint32(i+1)) + } + + createTestSafetensors(t, path, map[string]struct { + dtype string + shape []int32 + data []byte + }{ + "weight": {dtype: "F32", shape: []int32{2, 2}, data: data}, + }) + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + td, err := ext.GetTensor("weight") + if err != nil { + t.Fatalf("GetTensor failed: %v", err) + } + + if td.Name != "weight" { + t.Errorf("Name = %q, want %q", td.Name, "weight") + } + if td.Dtype != "F32" { + t.Errorf("Dtype = %q, want %q", td.Dtype, "F32") + } + if td.Size != 16 { + t.Errorf("Size = %d, want 16", td.Size) + } + if len(td.Shape) != 2 || td.Shape[0] != 2 || td.Shape[1] != 2 { + t.Errorf("Shape = %v, want [2 2]", td.Shape) + } + + // Read the raw data + rawData, err := io.ReadAll(td.Reader()) + if err != nil { + t.Fatalf("Reader() read failed: %v", err) + } + if len(rawData) != 16 { + t.Errorf("raw data length = %d, want 16", len(rawData)) + } +} + +func TestGetTensor_NotFound(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + createTestSafetensors(t, path, map[string]struct { + dtype string + shape []int32 + data []byte + }{ + "exists": {dtype: "F32", shape: []int32{1}, data: make([]byte, 4)}, + }) + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + _, err = ext.GetTensor("missing") + if err == nil { + t.Error("expected error for missing tensor, got nil") + } +} + +func TestSafetensorsReaderRoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + data := make([]byte, 16) + for i := range 4 { + binary.LittleEndian.PutUint32(data[i*4:], uint32(0x3f800000+i)) + } + + createTestSafetensors(t, path, map[string]struct { + dtype string + shape []int32 + data []byte + }{ + "tensor_a": {dtype: "F32", shape: []int32{2, 2}, data: data}, + }) + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + td, err := ext.GetTensor("tensor_a") + if err != nil { + t.Fatalf("GetTensor failed: %v", err) + } + + // Wrap as safetensors and extract back + stReader := td.SafetensorsReader() + stData, err := io.ReadAll(stReader) + if err != nil { + t.Fatalf("SafetensorsReader read failed: %v", err) + } + + // Verify size + if int64(len(stData)) != td.SafetensorsSize() { + t.Errorf("SafetensorsSize() = %d, actual = %d", td.SafetensorsSize(), len(stData)) + } + + // Extract raw data back + raw, err := ExtractRawFromSafetensors(bytes.NewReader(stData)) + if err != nil { + t.Fatalf("ExtractRawFromSafetensors failed: %v", err) + } + + if !bytes.Equal(raw, data) { + t.Errorf("round-trip data mismatch: got %v, want %v", raw, data) + } +} + +func TestNewTensorDataFromBytes(t *testing.T) { + data := []byte{1, 2, 3, 4} + td := NewTensorDataFromBytes("test", "U8", []int32{4}, data) + + if td.Name != "test" { + t.Errorf("Name = %q, want %q", td.Name, "test") + } + if td.Size != 4 { + t.Errorf("Size = %d, want 4", td.Size) + } + + rawData, err := io.ReadAll(td.Reader()) + if err != nil { + t.Fatalf("Reader() failed: %v", err) + } + if !bytes.Equal(rawData, data) { + t.Errorf("data mismatch: got %v, want %v", rawData, data) + } +} + +func TestBuildPackedSafetensorsReader(t *testing.T) { + data1 := []byte{1, 2, 3, 4} + data2 := []byte{5, 6, 7, 8, 9, 10, 11, 12} + + td1 := NewTensorDataFromBytes("a", "U8", []int32{4}, data1) + td2 := NewTensorDataFromBytes("b", "U8", []int32{8}, data2) + + packed := BuildPackedSafetensorsReader([]*TensorData{td1, td2}) + packedBytes, err := io.ReadAll(packed) + if err != nil { + t.Fatalf("BuildPackedSafetensorsReader read failed: %v", err) + } + + // Verify it's a valid safetensors file by parsing the header + var headerSize uint64 + if err := binary.Read(bytes.NewReader(packedBytes), binary.LittleEndian, &headerSize); err != nil { + t.Fatalf("failed to read header size: %v", err) + } + + headerJSON := packedBytes[8 : 8+headerSize] + var header map[string]tensorInfo + if err := json.Unmarshal(headerJSON, &header); err != nil { + t.Fatalf("failed to parse header: %v", err) + } + + if len(header) != 2 { + t.Errorf("header has %d entries, want 2", len(header)) + } + + infoA, ok := header["a"] + if !ok { + t.Fatal("tensor 'a' not found in header") + } + if infoA.Dtype != "U8" { + t.Errorf("tensor 'a' dtype = %q, want %q", infoA.Dtype, "U8") + } + + infoB, ok := header["b"] + if !ok { + t.Fatal("tensor 'b' not found in header") + } + + // Verify data region contains both tensors + dataStart := 8 + int(headerSize) + dataRegion := packedBytes[dataStart:] + if infoA.DataOffsets[0] == 0 { + // a comes first + if !bytes.Equal(dataRegion[:4], data1) { + t.Error("tensor 'a' data mismatch") + } + if !bytes.Equal(dataRegion[infoB.DataOffsets[0]:infoB.DataOffsets[1]], data2) { + t.Error("tensor 'b' data mismatch") + } + } else { + // b comes first + if !bytes.Equal(dataRegion[:8], data2) { + t.Error("tensor 'b' data mismatch") + } + } +} + +func TestExtractAll(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + createTestSafetensors(t, path, map[string]struct { + dtype string + shape []int32 + data []byte + }{ + "alpha": {dtype: "F32", shape: []int32{2}, data: make([]byte, 8)}, + "beta": {dtype: "F16", shape: []int32{4}, data: make([]byte, 8)}, + }) + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + tensors, err := ext.ExtractAll() + if err != nil { + t.Fatalf("ExtractAll failed: %v", err) + } + + if len(tensors) != 2 { + t.Errorf("ExtractAll returned %d tensors, want 2", len(tensors)) + } + + // Verify sorted order + if tensors[0].Name != "alpha" || tensors[1].Name != "beta" { + t.Errorf("tensors not in sorted order: %s, %s", tensors[0].Name, tensors[1].Name) + } +} + +func TestExtractRawFromSafetensors_InvalidInput(t *testing.T) { + // Empty reader + _, err := ExtractRawFromSafetensors(bytes.NewReader(nil)) + if err == nil { + t.Error("expected error for empty reader") + } + + // Truncated header size + _, err = ExtractRawFromSafetensors(bytes.NewReader([]byte{1, 2, 3})) + if err == nil { + t.Error("expected error for truncated header size") + } +} + +func TestOpenForExtraction_MetadataIgnored(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + // Manually create a safetensors file with __metadata__ + header := map[string]any{ + "__metadata__": map[string]string{"format": "pt"}, + "weight": tensorInfo{ + Dtype: "F32", + Shape: []int32{2}, + DataOffsets: [2]int{0, 8}, + }, + } + headerJSON, _ := json.Marshal(header) + padding := (8 - len(headerJSON)%8) % 8 + headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...) + + f, _ := os.Create(path) + binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))) + f.Write(headerJSON) + f.Write(make([]byte, 8)) + f.Close() + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + // __metadata__ should be stripped + if ext.TensorCount() != 1 { + t.Errorf("TensorCount() = %d, want 1 (metadata should be stripped)", ext.TensorCount()) + } +}