diff --git a/.gitignore b/.gitignore index eabf94c28..4e0b8974f 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ __debug_bin* llama/build llama/vendor /ollama +integration/testdata/models/ diff --git a/cmd/cmd.go b/cmd/cmd.go index a9bf3d4dc..250b63f40 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -54,7 +54,6 @@ import ( "github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/version" xcmd "github.com/ollama/ollama/x/cmd" - "github.com/ollama/ollama/x/create" xcreateclient "github.com/ollama/ollama/x/create/client" "github.com/ollama/ollama/x/imagegen" ) @@ -164,11 +163,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } // Check for --experimental flag for safetensors model creation + // This gates both safetensors LLM and imagegen model creation experimental, _ := cmd.Flags().GetBool("experimental") if experimental { if !isLocalhost() { return errors.New("remote safetensor model creation not yet supported") } + // Get Modelfile content - either from -f flag or default to "FROM ." var reader io.Reader filename, err := getModelfileName(cmd) @@ -211,23 +212,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error { }, p) } + // Standard Modelfile + API path var reader io.Reader filename, err := getModelfileName(cmd) if os.IsNotExist(err) { if filename == "" { - // No Modelfile found - check if current directory is an image gen model - if create.IsTensorModelDir(".") { - if !isLocalhost() { - return errors.New("remote safetensor model creation not yet supported") - } - quantize, _ := cmd.Flags().GetString("quantize") - return xcreateclient.CreateModel(xcreateclient.CreateOptions{ - ModelName: modelName, - ModelDir: ".", - Quantize: quantize, - }, p) - } reader = strings.NewReader("FROM .\n") } else { return errModelfileNotFound diff --git a/integration/create_imagegen_test.go b/integration/create_imagegen_test.go new file mode 100644 index 000000000..53903e7a6 --- /dev/null +++ b/integration/create_imagegen_test.go @@ -0,0 +1,107 @@ +//go:build integration && imagegen + +package integration + +import ( + "context" + "encoding/base64" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/ollama/ollama/api" +) + +func TestCreateImageGen(t *testing.T) { + skipIfRemote(t) + skipUnderMinVRAM(t, 13) + + // Allow overriding the model directory via env var for local testing, + // since the model is ~33GB and may already be downloaded elsewhere. + modelDir := os.Getenv("OLLAMA_TEST_IMAGEGEN_MODEL_DIR") + if modelDir == "" { + modelDir = filepath.Join(testdataModelsDir, "Z-Image-Turbo") + downloadHFModel(t, "Tongyi-MAI/Z-Image-Turbo", modelDir) + } else { + t.Logf("Using existing imagegen model at %s", modelDir) + } + + // Verify it looks like a valid imagegen model directory + if _, err := os.Stat(filepath.Join(modelDir, "model_index.json")); err != nil { + t.Fatalf("model_index.json not found in %s — not a valid imagegen model directory", modelDir) + } + + ensureMLXLibraryPath(t) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + modelName := "test-z-image-turbo-create" + + absModelDir, err := filepath.Abs(modelDir) + if err != nil { + t.Fatalf("Failed to get absolute path: %v", err) + } + + // Create a Modelfile pointing to the diffusers model directory + tmpModelfile := filepath.Join(t.TempDir(), "Modelfile") + if err := os.WriteFile(tmpModelfile, []byte("FROM "+absModelDir+"\n"), 0o644); err != nil { + t.Fatalf("Failed to write Modelfile: %v", err) + } + + t.Logf("Creating imagegen model from %s", absModelDir) + runOllamaCreate(ctx, t, modelName, "--experimental", "-f", tmpModelfile) + + // Verify model exists via show + showReq := &api.ShowRequest{Name: modelName} + showResp, err := client.Show(ctx, showReq) + if err != nil { + t.Fatalf("Model show failed after create: %v", err) + } + t.Logf("Created model details: %+v", showResp.Details) + + // Generate an image to verify the model isn't corrupted + t.Log("Generating test image...") + imageBase64, err := generateImage(ctx, client, modelName, "A red circle on a white background") + if err != nil { + if strings.Contains(err.Error(), "image generation not available") { + t.Skip("Target system does not support image generation") + } else if strings.Contains(err.Error(), "insufficient memory for image generation") { + t.Skip("insufficient memory for image generation") + } else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") { + t.Skip("unsupported architecture") + } + t.Fatalf("Image generation failed: %v", err) + } + + // Verify we got valid image data + imageData, err := base64.StdEncoding.DecodeString(imageBase64) + if err != nil { + t.Fatalf("Failed to decode base64 image: %v", err) + } + + t.Logf("Generated image: %d bytes", len(imageData)) + + if len(imageData) < 1000 { + t.Fatalf("Generated image suspiciously small (%d bytes), likely corrupted", len(imageData)) + } + + // Check for PNG or JPEG magic bytes + isPNG := len(imageData) >= 4 && imageData[0] == 0x89 && imageData[1] == 'P' && imageData[2] == 'N' && imageData[3] == 'G' + isJPEG := len(imageData) >= 2 && imageData[0] == 0xFF && imageData[1] == 0xD8 + if !isPNG && !isJPEG { + t.Fatalf("Generated image is neither PNG nor JPEG (first bytes: %x)", imageData[:min(8, len(imageData))]) + } + t.Logf("Image format validated (PNG=%v, JPEG=%v)", isPNG, isJPEG) + + // Cleanup: delete the model + deleteReq := &api.DeleteRequest{Model: modelName} + if err := client.Delete(ctx, deleteReq); err != nil { + t.Logf("Warning: failed to delete test model: %v", err) + } +} diff --git a/integration/create_test.go b/integration/create_test.go new file mode 100644 index 000000000..60a1269b9 --- /dev/null +++ b/integration/create_test.go @@ -0,0 +1,350 @@ +//go:build integration + +package integration + +import ( + "context" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/ollama/ollama/api" +) + +const testdataModelsDir = "testdata/models" + +// skipIfRemote skips the test if OLLAMA_HOST points to a non-local server. +// Safetensors/imagegen creation requires localhost since it reads model files +// from disk and uses the --experimental CLI path. +func skipIfRemote(t *testing.T) { + t.Helper() + host := os.Getenv("OLLAMA_HOST") + if host == "" { + return // default is localhost + } + // Strip scheme if present + _, hostport, ok := strings.Cut(host, "://") + if !ok { + hostport = host + } + h, _, err := net.SplitHostPort(hostport) + if err != nil { + h = hostport + } + if h == "" || h == "localhost" { + return + } + ip := net.ParseIP(h) + if ip != nil && (ip.IsLoopback() || ip.IsUnspecified()) { + return + } + t.Skipf("safetensors/imagegen creation requires a local server (OLLAMA_HOST=%s)", host) +} + +// findHFCLI returns the path to the HuggingFace CLI, or "" if not found. +func findHFCLI() string { + for _, name := range []string{"huggingface-cli", "hf"} { + if p, err := exec.LookPath(name); err == nil { + return p + } + } + return "" +} + +// downloadHFModel idempotently downloads a HuggingFace model to destDir. +// Skips the test if CLI is missing and model isn't already present. +func downloadHFModel(t *testing.T, repo, destDir string, extraArgs ...string) { + t.Helper() + + // Check if model already exists + if _, err := os.Stat(destDir); err == nil { + entries, err := os.ReadDir(destDir) + if err == nil && len(entries) > 0 { + t.Logf("Model %s already present at %s", repo, destDir) + return + } + } + + cli := findHFCLI() + if cli == "" { + t.Skipf("HuggingFace CLI not found and model %s not present at %s", repo, destDir) + } + + t.Logf("Downloading %s to %s", repo, destDir) + os.MkdirAll(destDir, 0o755) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + args := []string{"download", repo, "--local-dir", destDir} + args = append(args, extraArgs...) + cmd := exec.CommandContext(ctx, cli, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + t.Fatalf("Failed to download %s: %v", repo, err) + } +} + +// ollamaBin returns the path to the ollama binary to use for tests. +// Prefers OLLAMA_BIN env, then falls back to the built binary at ../ollama +// (same binary the integration test server uses). +func ollamaBin() string { + if bin := os.Getenv("OLLAMA_BIN"); bin != "" { + return bin + } + if abs, err := filepath.Abs("../ollama"); err == nil { + if _, err := os.Stat(abs); err == nil { + return abs + } + } + return "ollama" +} + +// ensureMLXLibraryPath sets OLLAMA_LIBRARY_PATH so the MLX dynamic library +// is discoverable. Integration tests run from integration/ dir, so the +// default CWD-based search won't find the library at the repo root. +func ensureMLXLibraryPath(t *testing.T) { + t.Helper() + if libPath, err := filepath.Abs("../build/lib/ollama"); err == nil { + if _, err := os.Stat(libPath); err == nil { + if existing := os.Getenv("OLLAMA_LIBRARY_PATH"); existing != "" { + t.Setenv("OLLAMA_LIBRARY_PATH", existing+string(filepath.ListSeparator)+libPath) + } else { + t.Setenv("OLLAMA_LIBRARY_PATH", libPath) + } + } + } +} + +// runOllamaCreate runs "ollama create" as a subprocess. Skips the test if +// the error indicates the server is remote. +func runOllamaCreate(ctx context.Context, t *testing.T, args ...string) { + t.Helper() + createCmd := exec.CommandContext(ctx, ollamaBin(), append([]string{"create"}, args...)...) + var createStderr strings.Builder + createCmd.Stdout = os.Stdout + createCmd.Stderr = io.MultiWriter(os.Stderr, &createStderr) + if err := createCmd.Run(); err != nil { + if strings.Contains(createStderr.String(), "remote") { + t.Skip("safetensors creation requires a local server") + } + t.Fatalf("ollama create failed: %v", err) + } +} + +func TestCreateSafetensorsLLM(t *testing.T) { + skipIfRemote(t) + + modelDir := filepath.Join(testdataModelsDir, "TinyLlama-1.1B") + downloadHFModel(t, "TinyLlama/TinyLlama-1.1B-Chat-v1.0", modelDir) + + ensureMLXLibraryPath(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + modelName := "test-tinyllama-safetensors" + + absModelDir, err := filepath.Abs(modelDir) + if err != nil { + t.Fatalf("Failed to get absolute path: %v", err) + } + + // Create a Modelfile pointing to the model directory. + // Include a chat template since the safetensors importer doesn't extract + // chat_template from tokenizer_config.json yet. + modelfileContent := "FROM " + absModelDir + "\n" + + "TEMPLATE \"{{ if .System }}<|system|>\n{{ .System }}\n{{ end }}" + + "{{ if .Prompt }}<|user|>\n{{ .Prompt }}\n{{ end }}" + + "<|assistant|>\n{{ .Response }}\n\"\n" + tmpModelfile := filepath.Join(t.TempDir(), "Modelfile") + if err := os.WriteFile(tmpModelfile, []byte(modelfileContent), 0o644); err != nil { + t.Fatalf("Failed to write Modelfile: %v", err) + } + + runOllamaCreate(ctx, t, modelName, "--experimental", "-f", tmpModelfile) + + // Verify model exists via show + showReq := &api.ShowRequest{Name: modelName} + showResp, err := client.Show(ctx, showReq) + if err != nil { + t.Fatalf("Model show failed after create: %v", err) + } + t.Logf("Created model details: %+v", showResp.Details) + + // Use the chat API for proper template application. + chatReq := &api.ChatRequest{ + Model: modelName, + Messages: []api.Message{ + {Role: "user", Content: "Write a short sentence about the weather."}, + }, + Options: map[string]interface{}{ + "num_predict": 20, + "temperature": 0.0, + }, + } + + var output strings.Builder + err = client.Chat(ctx, chatReq, func(resp api.ChatResponse) error { + output.WriteString(resp.Message.Content) + return nil + }) + if err != nil { + t.Fatalf("Chat failed: %v", err) + } + + text := output.String() + t.Logf("Generated output: %q", text) + assertCoherentOutput(t, text) + + // Cleanup: delete the model + deleteReq := &api.DeleteRequest{Model: modelName} + if err := client.Delete(ctx, deleteReq); err != nil { + t.Logf("Warning: failed to delete test model: %v", err) + } +} + +func TestCreateGGUF(t *testing.T) { + modelDir := filepath.Join(testdataModelsDir, "Llama-3.2-1B-GGUF") + downloadHFModel(t, "bartowski/Llama-3.2-1B-Instruct-GGUF", modelDir, + "--include", "Llama-3.2-1B-Instruct-IQ3_M.gguf") + + // Find the GGUF file + entries, err := os.ReadDir(modelDir) + if err != nil { + t.Fatalf("Failed to read model dir: %v", err) + } + + var ggufPath string + for _, e := range entries { + if filepath.Ext(e.Name()) == ".gguf" { + ggufPath = filepath.Join(modelDir, e.Name()) + break + } + } + if ggufPath == "" { + t.Skip("No GGUF file found in model directory") + } + + absGGUF, err := filepath.Abs(ggufPath) + if err != nil { + t.Fatalf("Failed to get absolute path: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + modelName := "test-llama32-gguf" + + // Create a Modelfile and use the CLI + tmpModelfile := filepath.Join(t.TempDir(), "Modelfile") + if err := os.WriteFile(tmpModelfile, []byte("FROM "+absGGUF+"\n"), 0o644); err != nil { + t.Fatalf("Failed to write Modelfile: %v", err) + } + + createCmd := exec.CommandContext(ctx, ollamaBin(), "create", modelName, "-f", tmpModelfile) + createCmd.Stdout = os.Stdout + createCmd.Stderr = os.Stderr + if err := createCmd.Run(); err != nil { + t.Fatalf("ollama create failed: %v", err) + } + + // Verify model exists + showReq := &api.ShowRequest{Name: modelName} + _, err = client.Show(ctx, showReq) + if err != nil { + t.Fatalf("Model show failed after create: %v", err) + } + + // Generate and verify output is coherent + genReq := &api.GenerateRequest{ + Model: modelName, + Prompt: "Write a short sentence about the weather.", + Options: map[string]interface{}{ + "num_predict": 20, + "temperature": 0.0, + }, + } + + var output strings.Builder + err = client.Generate(ctx, genReq, func(resp api.GenerateResponse) error { + output.WriteString(resp.Response) + return nil + }) + if err != nil { + t.Fatalf("Generate failed: %v", err) + } + + text := output.String() + t.Logf("Generated output: %q", text) + assertCoherentOutput(t, text) + + // Cleanup + deleteReq := &api.DeleteRequest{Model: modelName} + if err := client.Delete(ctx, deleteReq); err != nil { + t.Logf("Warning: failed to delete test model: %v", err) + } +} + +// assertCoherentOutput checks that model output looks like real language, not +// garbled binary or repeated garbage. This catches corrupted model creation +// where inference "works" but produces nonsense. +func assertCoherentOutput(t *testing.T, text string) { + t.Helper() + + if len(text) == 0 { + t.Fatal("model produced empty output") + } + + // Check minimum length — 20 tokens should produce at least a few words + if len(text) < 5 { + t.Fatalf("model output suspiciously short (%d bytes): %q", len(text), text) + } + + // Check for mostly-printable ASCII/Unicode — garbled models often emit + // high ratios of control characters or replacement characters + unprintable := 0 + for _, r := range text { + if r < 0x20 && r != '\n' && r != '\r' && r != '\t' { + unprintable++ + } + if r == '\ufffd' { // Unicode replacement character + unprintable++ + } + } + ratio := float64(unprintable) / float64(len([]rune(text))) + if ratio > 0.3 { + t.Fatalf("model output is %.0f%% unprintable characters (likely garbled): %q", ratio*100, text) + } + + // Check it contains at least one space — real language has word boundaries + if !strings.Contains(text, " ") { + t.Fatalf("model output contains no spaces (likely garbled): %q", text) + } + + // Check for excessive repetition — a broken model might repeat one token + words := strings.Fields(text) + if len(words) >= 4 { + counts := map[string]int{} + for _, w := range words { + counts[strings.ToLower(w)]++ + } + for w, c := range counts { + if c > len(words)*3/4 { + t.Fatalf("model output is excessively repetitive (%q appears %d/%d times): %q", w, c, len(words), text) + } + } + } +} diff --git a/server/create.go b/server/create.go index 9c48b5186..01fbe5738 100644 --- a/server/create.go +++ b/server/create.go @@ -141,7 +141,7 @@ func (s *Server) CreateHandler(c *gin.Context) { ch <- gin.H{"error": err.Error()} } - if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "" || len(config.Capabilities) == 0) { + if err == nil && !remote { mf, mErr := manifest.ParseNamedManifest(fromName) if mErr == nil && mf.Config.Digest != "" { configPath, pErr := manifest.BlobsPath(mf.Config.Digest) @@ -158,6 +158,9 @@ func (s *Server) CreateHandler(c *gin.Context) { if config.Requires == "" { config.Requires = baseConfig.Requires } + if config.ModelFormat == "" { + config.ModelFormat = baseConfig.ModelFormat + } if len(config.Capabilities) == 0 { config.Capabilities = baseConfig.Capabilities } diff --git a/server/model.go b/server/model.go index 57190ffe0..06bde52b6 100644 --- a/server/model.go +++ b/server/model.go @@ -41,11 +41,12 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return nil, err } - for _, layer := range m.Layers { - layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest()) + for _, srcLayer := range m.Layers { + layer, err := manifest.NewLayerFromLayer(srcLayer.Digest, srcLayer.MediaType, name.DisplayShortest()) if err != nil { return nil, err } + layer.Name = srcLayer.Name switch layer.MediaType { case "application/vnd.ollama.image.model", diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 401f98d9d..75bdac73b 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -1024,3 +1024,272 @@ func TestDetectModelTypeFromFiles(t *testing.T) { } }) } + +// createTestBlob creates a blob in the blobs directory and returns its digest. +func createTestBlob(t *testing.T, data []byte) string { + t.Helper() + digest := fmt.Sprintf("sha256:%x", sha256.Sum256(data)) + blobPath, err := manifest.BlobsPath(digest) + if err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Dir(blobPath), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(blobPath, data, 0o644); err != nil { + t.Fatal(err) + } + return digest +} + +// createSafetensorsTestModel creates a minimal safetensors model manifest for testing. +func createSafetensorsTestModel(t *testing.T, modelName string, config model.ConfigV2, extraLayers []manifest.Layer) { + t.Helper() + + // Create a fake tensor blob + tensorData := []byte("fake-tensor-data-for-testing") + tensorDigest := createTestBlob(t, tensorData) + + layers := []manifest.Layer{ + { + MediaType: manifest.MediaTypeImageTensor, + Digest: tensorDigest, + Size: int64(len(tensorData)), + Name: "model.embed_tokens.weight", + }, + } + layers = append(layers, extraLayers...) + + configLayer, err := createConfigLayer(layers, config) + if err != nil { + t.Fatalf("failed to create config layer: %v", err) + } + + name := model.ParseName(modelName) + if err := manifest.WriteManifest(name, *configLayer, layers); err != nil { + t.Fatalf("failed to write manifest: %v", err) + } +} + +func TestCreateFromSafetensorsModel_PreservesConfig(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + // Create a source safetensors model with specific config fields + createSafetensorsTestModel(t, "source-model", model.ConfigV2{ + ModelFormat: "safetensors", + Capabilities: []string{"completion"}, + Requires: "0.14.0", + Renderer: "gemma3", + Parser: "gemma3", + }, nil) + + // Create a derived model FROM the source + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "derived-model", + From: "source-model", + System: "You are a pirate.", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Read the derived model's config + derivedName := model.ParseName("derived-model") + mf, err := manifest.ParseNamedManifest(derivedName) + if err != nil { + t.Fatalf("failed to parse derived manifest: %v", err) + } + + configBlobPath, err := manifest.BlobsPath(mf.Config.Digest) + if err != nil { + t.Fatalf("failed to get config blob path: %v", err) + } + + configBlob, err := os.ReadFile(configBlobPath) + if err != nil { + t.Fatalf("failed to read config blob: %v", err) + } + + var cfg model.ConfigV2 + if err := json.Unmarshal(configBlob, &cfg); err != nil { + t.Fatalf("failed to unmarshal config: %v", err) + } + + // Verify safetensors-specific config fields are preserved + if cfg.ModelFormat != "safetensors" { + t.Errorf("ModelFormat = %q, want %q", cfg.ModelFormat, "safetensors") + } + + if !slices.Contains(cfg.Capabilities, "completion") { + t.Errorf("Capabilities = %v, want to contain %q", cfg.Capabilities, "completion") + } + + if cfg.Requires != "0.14.0" { + t.Errorf("Requires = %q, want %q", cfg.Requires, "0.14.0") + } + + if cfg.Renderer != "gemma3" { + t.Errorf("Renderer = %q, want %q", cfg.Renderer, "gemma3") + } + + if cfg.Parser != "gemma3" { + t.Errorf("Parser = %q, want %q", cfg.Parser, "gemma3") + } + + // Verify system prompt was added + var hasSystem bool + for _, l := range mf.Layers { + if l.MediaType == "application/vnd.ollama.image.system" { + hasSystem = true + break + } + } + if !hasSystem { + t.Error("expected system prompt layer in derived model") + } + + // Verify tensor layers were copied with names preserved + var tensorNames []string + for _, l := range mf.Layers { + if l.MediaType == manifest.MediaTypeImageTensor { + tensorNames = append(tensorNames, l.Name) + } + } + if len(tensorNames) == 0 { + t.Error("expected tensor layers in derived model") + } + for _, name := range tensorNames { + if name == "" { + t.Error("tensor layer has empty name — names must be preserved from source") + } + } +} + +func TestCreateFromSafetensorsModel_OverrideSystem(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + // Create source with a system prompt + createSafetensorsTestModel(t, "source-with-system", model.ConfigV2{ + ModelFormat: "safetensors", + Capabilities: []string{"completion"}, + }, nil) + + // First create with system prompt + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "source-with-system", + From: "source-with-system", + System: "Original system prompt", + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Now create a derived model with a different system prompt + w = createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "derived-new-system", + From: "source-with-system", + System: "New system prompt", + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify ModelFormat is preserved even after override + derivedName := model.ParseName("derived-new-system") + mf, err := manifest.ParseNamedManifest(derivedName) + if err != nil { + t.Fatalf("failed to parse derived manifest: %v", err) + } + + configBlobPath, _ := manifest.BlobsPath(mf.Config.Digest) + configBlob, _ := os.ReadFile(configBlobPath) + + var cfg model.ConfigV2 + json.Unmarshal(configBlob, &cfg) + + if cfg.ModelFormat != "safetensors" { + t.Errorf("ModelFormat = %q, want %q", cfg.ModelFormat, "safetensors") + } +} + +func TestCreateFromSafetensorsModel_PreservesLayerNames(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + // Create JSON config blobs to include as layers + configJSON := []byte(`{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}`) + configDigest := createTestBlob(t, configJSON) + tokenizerJSON := []byte(`{"version": "1.0"}`) + tokenizerDigest := createTestBlob(t, tokenizerJSON) + + extraLayers := []manifest.Layer{ + { + MediaType: "application/vnd.ollama.image.json", + Digest: configDigest, + Size: int64(len(configJSON)), + Name: "config.json", + }, + { + MediaType: "application/vnd.ollama.image.json", + Digest: tokenizerDigest, + Size: int64(len(tokenizerJSON)), + Name: "tokenizer.json", + }, + } + + createSafetensorsTestModel(t, "source-named-layers", model.ConfigV2{ + ModelFormat: "safetensors", + Capabilities: []string{"completion"}, + }, extraLayers) + + // Create derived model + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "derived-named-layers", + From: "source-named-layers", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + derivedName := model.ParseName("derived-named-layers") + mf, err := manifest.ParseNamedManifest(derivedName) + if err != nil { + t.Fatalf("failed to parse derived manifest: %v", err) + } + + // Check tensor layer names are preserved + for _, l := range mf.Layers { + if l.MediaType == manifest.MediaTypeImageTensor && l.Name == "" { + t.Error("tensor layer has empty name — names must be preserved from source") + } + } + + // Check JSON layer names are preserved + jsonNames := make(map[string]bool) + for _, l := range mf.Layers { + if l.MediaType == "application/vnd.ollama.image.json" && l.Name != "" { + jsonNames[l.Name] = true + } + } + + if !jsonNames["config.json"] { + t.Error("config.json layer name not preserved in derived model") + } + if !jsonNames["tokenizer.json"] { + t.Error("tokenizer.json layer name not preserved in derived model") + } +} diff --git a/x/create/client/create.go b/x/create/client/create.go index d380ae718..74abb865e 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -22,7 +22,7 @@ import ( "github.com/ollama/ollama/progress" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/x/create" - "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/safetensors" ) // MinOllamaVersion is the minimum Ollama version required for safetensors models. diff --git a/x/create/client/create_test.go b/x/create/client/create_test.go index 539a5360b..286ea2208 100644 --- a/x/create/client/create_test.go +++ b/x/create/client/create_test.go @@ -429,3 +429,159 @@ func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) { t.Fatalf("FileType = %q, want %q", cfg.FileType, "mxfp8") } } + +func TestSupportsThinking(t *testing.T) { + tests := []struct { + name string + configJSON string + want bool + }{ + { + name: "qwen3 architecture", + configJSON: `{"architectures": ["Qwen3ForCausalLM"], "model_type": "qwen3"}`, + want: true, + }, + { + name: "deepseek architecture", + configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`, + want: true, + }, + { + name: "glm4moe architecture", + configJSON: `{"architectures": ["GLM4MoeForCausalLM"]}`, + want: true, + }, + { + name: "llama architecture (no thinking)", + configJSON: `{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}`, + want: false, + }, + { + name: "gemma architecture (no thinking)", + configJSON: `{"architectures": ["Gemma3ForCausalLM"], "model_type": "gemma3"}`, + want: false, + }, + { + name: "model_type only", + configJSON: `{"model_type": "deepseek"}`, + want: true, + }, + { + name: "empty config", + configJSON: `{}`, + want: false, + }, + { + name: "invalid json", + configJSON: `not json`, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644) + + if got := supportsThinking(dir); got != tt.want { + t.Errorf("supportsThinking() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSupportsThinking_NoConfig(t *testing.T) { + if supportsThinking(t.TempDir()) { + t.Error("supportsThinking should return false for missing config.json") + } +} + +func TestGetParserName(t *testing.T) { + tests := []struct { + name string + configJSON string + want string + }{ + { + name: "qwen3 model", + configJSON: `{"architectures": ["Qwen3ForCausalLM"]}`, + want: "qwen3", + }, + { + name: "deepseek model", + configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`, + want: "deepseek3", + }, + { + name: "glm4 model", + configJSON: `{"architectures": ["GLM4ForCausalLM"]}`, + want: "glm-4.7", + }, + { + name: "llama model (no parser)", + configJSON: `{"architectures": ["LlamaForCausalLM"]}`, + want: "", + }, + { + name: "qwen3 via model_type", + configJSON: `{"model_type": "qwen3"}`, + want: "qwen3", + }, + { + name: "no config", + configJSON: `{}`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644) + + if got := getParserName(dir); got != tt.want { + t.Errorf("getParserName() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGetRendererName(t *testing.T) { + tests := []struct { + name string + configJSON string + want string + }{ + { + name: "qwen3 model", + configJSON: `{"architectures": ["Qwen3ForCausalLM"]}`, + want: "qwen3-coder", + }, + { + name: "deepseek model", + configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`, + want: "deepseek3", + }, + { + name: "glm4 model", + configJSON: `{"architectures": ["GLM4ForCausalLM"]}`, + want: "glm-4.7", + }, + { + name: "llama model (no renderer)", + configJSON: `{"architectures": ["LlamaForCausalLM"]}`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644) + + if got := getRendererName(dir); got != tt.want { + t.Errorf("getRendererName() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/x/create/create.go b/x/create/create.go index 5e234904f..da544e6a4 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -13,7 +13,7 @@ import ( "strings" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/safetensors" ) // ModelConfig represents the config blob stored with a model. diff --git a/x/create/create_test.go b/x/create/create_test.go index 6fd419cef..f5a68ba1a 100644 --- a/x/create/create_test.go +++ b/x/create/create_test.go @@ -12,7 +12,7 @@ import ( "testing" "github.com/d4l3k/go-bfloat16" - st "github.com/ollama/ollama/x/imagegen/safetensors" + st "github.com/ollama/ollama/x/safetensors" ) func TestIsTensorModelDir(t *testing.T) { diff --git a/x/create/imagegen.go b/x/create/imagegen.go index 6dbbcbfcc..09c832137 100644 --- a/x/create/imagegen.go +++ b/x/create/imagegen.go @@ -9,7 +9,7 @@ import ( "path/filepath" "strings" - "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/safetensors" ) // CreateImageGenModel imports an image generation model from a directory. diff --git a/x/create/qwen35.go b/x/create/qwen35.go index 0431fc452..a8ac98848 100644 --- a/x/create/qwen35.go +++ b/x/create/qwen35.go @@ -7,7 +7,7 @@ import ( "path/filepath" "strings" - "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/safetensors" ) type qwen35ImportTransform struct { diff --git a/x/imagegen/safetensors/extractor.go b/x/safetensors/extractor.go similarity index 99% rename from x/imagegen/safetensors/extractor.go rename to x/safetensors/extractor.go index 549222eab..f4f7e5d87 100644 --- a/x/imagegen/safetensors/extractor.go +++ b/x/safetensors/extractor.go @@ -11,7 +11,6 @@ import ( ) // tensorInfo holds tensor metadata from safetensors headers. -// This avoids depending on safetensors.go which requires the mlx tag. type tensorInfo struct { Dtype string `json:"dtype"` Shape []int32 `json:"shape"` diff --git a/x/safetensors/extractor_test.go b/x/safetensors/extractor_test.go new file mode 100644 index 000000000..0b9e1efe6 --- /dev/null +++ b/x/safetensors/extractor_test.go @@ -0,0 +1,394 @@ +package safetensors + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "io" + "os" + "path/filepath" + "slices" + "testing" +) + +// createTestSafetensors creates a minimal valid safetensors file with the given tensors. +func createTestSafetensors(t *testing.T, path string, tensors map[string]struct { + dtype string + shape []int32 + data []byte +}, +) { + t.Helper() + + header := make(map[string]tensorInfo) + var offset int + var allData []byte + + // Sort names for deterministic file layout + names := make([]string, 0, len(tensors)) + for name := range tensors { + names = append(names, name) + } + slices.Sort(names) + + for _, name := range names { + info := tensors[name] + header[name] = tensorInfo{ + Dtype: info.dtype, + Shape: info.shape, + DataOffsets: [2]int{offset, offset + len(info.data)}, + } + allData = append(allData, info.data...) + offset += len(info.data) + } + + headerJSON, err := json.Marshal(header) + if err != nil { + t.Fatalf("failed to marshal header: %v", err) + } + + // Pad to 8-byte alignment + padding := (8 - len(headerJSON)%8) % 8 + headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...) + + f, err := os.Create(path) + if err != nil { + t.Fatalf("failed to create file: %v", err) + } + defer f.Close() + + if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil { + t.Fatalf("failed to write header size: %v", err) + } + if _, err := f.Write(headerJSON); err != nil { + t.Fatalf("failed to write header: %v", err) + } + if _, err := f.Write(allData); err != nil { + t.Fatalf("failed to write data: %v", err) + } +} + +func TestOpenForExtraction(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + // 4 float32 values = 16 bytes + data := make([]byte, 16) + binary.LittleEndian.PutUint32(data[0:4], 0x3f800000) // 1.0 + binary.LittleEndian.PutUint32(data[4:8], 0x40000000) // 2.0 + binary.LittleEndian.PutUint32(data[8:12], 0x40400000) // 3.0 + binary.LittleEndian.PutUint32(data[12:16], 0x40800000) // 4.0 + + createTestSafetensors(t, path, map[string]struct { + dtype string + shape []int32 + data []byte + }{ + "test_tensor": {dtype: "F32", shape: []int32{2, 2}, data: data}, + }) + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + if ext.TensorCount() != 1 { + t.Errorf("TensorCount() = %d, want 1", ext.TensorCount()) + } + + names := ext.ListTensors() + if len(names) != 1 || names[0] != "test_tensor" { + t.Errorf("ListTensors() = %v, want [test_tensor]", names) + } +} + +func TestGetTensor(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + data := make([]byte, 16) + for i := range 4 { + binary.LittleEndian.PutUint32(data[i*4:], uint32(i+1)) + } + + createTestSafetensors(t, path, map[string]struct { + dtype string + shape []int32 + data []byte + }{ + "weight": {dtype: "F32", shape: []int32{2, 2}, data: data}, + }) + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + td, err := ext.GetTensor("weight") + if err != nil { + t.Fatalf("GetTensor failed: %v", err) + } + + if td.Name != "weight" { + t.Errorf("Name = %q, want %q", td.Name, "weight") + } + if td.Dtype != "F32" { + t.Errorf("Dtype = %q, want %q", td.Dtype, "F32") + } + if td.Size != 16 { + t.Errorf("Size = %d, want 16", td.Size) + } + if len(td.Shape) != 2 || td.Shape[0] != 2 || td.Shape[1] != 2 { + t.Errorf("Shape = %v, want [2 2]", td.Shape) + } + + // Read the raw data + rawData, err := io.ReadAll(td.Reader()) + if err != nil { + t.Fatalf("Reader() read failed: %v", err) + } + if len(rawData) != 16 { + t.Errorf("raw data length = %d, want 16", len(rawData)) + } +} + +func TestGetTensor_NotFound(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + createTestSafetensors(t, path, map[string]struct { + dtype string + shape []int32 + data []byte + }{ + "exists": {dtype: "F32", shape: []int32{1}, data: make([]byte, 4)}, + }) + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + _, err = ext.GetTensor("missing") + if err == nil { + t.Error("expected error for missing tensor, got nil") + } +} + +func TestSafetensorsReaderRoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + data := make([]byte, 16) + for i := range 4 { + binary.LittleEndian.PutUint32(data[i*4:], uint32(0x3f800000+i)) + } + + createTestSafetensors(t, path, map[string]struct { + dtype string + shape []int32 + data []byte + }{ + "tensor_a": {dtype: "F32", shape: []int32{2, 2}, data: data}, + }) + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + td, err := ext.GetTensor("tensor_a") + if err != nil { + t.Fatalf("GetTensor failed: %v", err) + } + + // Wrap as safetensors and extract back + stReader := td.SafetensorsReader() + stData, err := io.ReadAll(stReader) + if err != nil { + t.Fatalf("SafetensorsReader read failed: %v", err) + } + + // Verify size + if int64(len(stData)) != td.SafetensorsSize() { + t.Errorf("SafetensorsSize() = %d, actual = %d", td.SafetensorsSize(), len(stData)) + } + + // Extract raw data back + raw, err := ExtractRawFromSafetensors(bytes.NewReader(stData)) + if err != nil { + t.Fatalf("ExtractRawFromSafetensors failed: %v", err) + } + + if !bytes.Equal(raw, data) { + t.Errorf("round-trip data mismatch: got %v, want %v", raw, data) + } +} + +func TestNewTensorDataFromBytes(t *testing.T) { + data := []byte{1, 2, 3, 4} + td := NewTensorDataFromBytes("test", "U8", []int32{4}, data) + + if td.Name != "test" { + t.Errorf("Name = %q, want %q", td.Name, "test") + } + if td.Size != 4 { + t.Errorf("Size = %d, want 4", td.Size) + } + + rawData, err := io.ReadAll(td.Reader()) + if err != nil { + t.Fatalf("Reader() failed: %v", err) + } + if !bytes.Equal(rawData, data) { + t.Errorf("data mismatch: got %v, want %v", rawData, data) + } +} + +func TestBuildPackedSafetensorsReader(t *testing.T) { + data1 := []byte{1, 2, 3, 4} + data2 := []byte{5, 6, 7, 8, 9, 10, 11, 12} + + td1 := NewTensorDataFromBytes("a", "U8", []int32{4}, data1) + td2 := NewTensorDataFromBytes("b", "U8", []int32{8}, data2) + + packed := BuildPackedSafetensorsReader([]*TensorData{td1, td2}) + packedBytes, err := io.ReadAll(packed) + if err != nil { + t.Fatalf("BuildPackedSafetensorsReader read failed: %v", err) + } + + // Verify it's a valid safetensors file by parsing the header + var headerSize uint64 + if err := binary.Read(bytes.NewReader(packedBytes), binary.LittleEndian, &headerSize); err != nil { + t.Fatalf("failed to read header size: %v", err) + } + + headerJSON := packedBytes[8 : 8+headerSize] + var header map[string]tensorInfo + if err := json.Unmarshal(headerJSON, &header); err != nil { + t.Fatalf("failed to parse header: %v", err) + } + + if len(header) != 2 { + t.Errorf("header has %d entries, want 2", len(header)) + } + + infoA, ok := header["a"] + if !ok { + t.Fatal("tensor 'a' not found in header") + } + if infoA.Dtype != "U8" { + t.Errorf("tensor 'a' dtype = %q, want %q", infoA.Dtype, "U8") + } + + infoB, ok := header["b"] + if !ok { + t.Fatal("tensor 'b' not found in header") + } + + // Verify data region contains both tensors + dataStart := 8 + int(headerSize) + dataRegion := packedBytes[dataStart:] + if infoA.DataOffsets[0] == 0 { + // a comes first + if !bytes.Equal(dataRegion[:4], data1) { + t.Error("tensor 'a' data mismatch") + } + if !bytes.Equal(dataRegion[infoB.DataOffsets[0]:infoB.DataOffsets[1]], data2) { + t.Error("tensor 'b' data mismatch") + } + } else { + // b comes first + if !bytes.Equal(dataRegion[:8], data2) { + t.Error("tensor 'b' data mismatch") + } + } +} + +func TestExtractAll(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + createTestSafetensors(t, path, map[string]struct { + dtype string + shape []int32 + data []byte + }{ + "alpha": {dtype: "F32", shape: []int32{2}, data: make([]byte, 8)}, + "beta": {dtype: "F16", shape: []int32{4}, data: make([]byte, 8)}, + }) + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + tensors, err := ext.ExtractAll() + if err != nil { + t.Fatalf("ExtractAll failed: %v", err) + } + + if len(tensors) != 2 { + t.Errorf("ExtractAll returned %d tensors, want 2", len(tensors)) + } + + // Verify sorted order + if tensors[0].Name != "alpha" || tensors[1].Name != "beta" { + t.Errorf("tensors not in sorted order: %s, %s", tensors[0].Name, tensors[1].Name) + } +} + +func TestExtractRawFromSafetensors_InvalidInput(t *testing.T) { + // Empty reader + _, err := ExtractRawFromSafetensors(bytes.NewReader(nil)) + if err == nil { + t.Error("expected error for empty reader") + } + + // Truncated header size + _, err = ExtractRawFromSafetensors(bytes.NewReader([]byte{1, 2, 3})) + if err == nil { + t.Error("expected error for truncated header size") + } +} + +func TestOpenForExtraction_MetadataIgnored(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.safetensors") + + // Manually create a safetensors file with __metadata__ + header := map[string]any{ + "__metadata__": map[string]string{"format": "pt"}, + "weight": tensorInfo{ + Dtype: "F32", + Shape: []int32{2}, + DataOffsets: [2]int{0, 8}, + }, + } + headerJSON, _ := json.Marshal(header) + padding := (8 - len(headerJSON)%8) % 8 + headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...) + + f, _ := os.Create(path) + binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))) + f.Write(headerJSON) + f.Write(make([]byte, 8)) + f.Close() + + ext, err := OpenForExtraction(path) + if err != nil { + t.Fatalf("OpenForExtraction failed: %v", err) + } + defer ext.Close() + + // __metadata__ should be stripped + if ext.TensorCount() != 1 { + t.Errorf("TensorCount() = %d, want 1 (metadata should be stripped)", ext.TensorCount()) + } +}