create: Clean up experimental paths, fix create from existing safetensors model (#14679)

* create:  Clean up experimental paths

This cleans up the experimental features, and adds both unit and integration test coverage to verify no regressions.

* create: preserve config and layer names when creating from safetensors models

When creating a model FROM an existing safetensors model, ModelFormat,
Capabilities, and layer Name fields were lost. ModelFormat stayed empty
because it's only set from GGML layers (which safetensors models lack),
and layer names weren't copied in parseFromModel. This caused derived
models to fail loading ("config.json not found in manifest").

* review comments
This commit is contained in:
Daniel Hiltgen
2026-04-07 08:12:57 -07:00
committed by GitHub
parent e823bff873
commit 30fdd229a4
15 changed files with 1292 additions and 22 deletions

1
.gitignore vendored
View File

@@ -15,3 +15,4 @@ __debug_bin*
llama/build llama/build
llama/vendor llama/vendor
/ollama /ollama
integration/testdata/models/

View File

@@ -54,7 +54,6 @@ import (
"github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd" xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
xcreateclient "github.com/ollama/ollama/x/create/client" xcreateclient "github.com/ollama/ollama/x/create/client"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
) )
@@ -164,11 +163,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
// Check for --experimental flag for safetensors model creation // Check for --experimental flag for safetensors model creation
// This gates both safetensors LLM and imagegen model creation
experimental, _ := cmd.Flags().GetBool("experimental") experimental, _ := cmd.Flags().GetBool("experimental")
if experimental { if experimental {
if !isLocalhost() { if !isLocalhost() {
return errors.New("remote safetensor model creation not yet supported") return errors.New("remote safetensor model creation not yet supported")
} }
// Get Modelfile content - either from -f flag or default to "FROM ." // Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader var reader io.Reader
filename, err := getModelfileName(cmd) filename, err := getModelfileName(cmd)
@@ -211,23 +212,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}, p) }, p)
} }
// Standard Modelfile + API path
var reader io.Reader var reader io.Reader
filename, err := getModelfileName(cmd) filename, err := getModelfileName(cmd)
if os.IsNotExist(err) { if os.IsNotExist(err) {
if filename == "" { if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if create.IsTensorModelDir(".") {
if !isLocalhost() {
return errors.New("remote safetensor model creation not yet supported")
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: ".",
Quantize: quantize,
}, p)
}
reader = strings.NewReader("FROM .\n") reader = strings.NewReader("FROM .\n")
} else { } else {
return errModelfileNotFound return errModelfileNotFound

View File

@@ -0,0 +1,107 @@
//go:build integration && imagegen
package integration
import (
"context"
"encoding/base64"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// TestCreateImageGen exercises the experimental imagegen create path end to
// end: obtain the Z-Image-Turbo diffusers model, create it via the CLI with
// --experimental, then generate an image and sanity-check the result bytes.
func TestCreateImageGen(t *testing.T) {
	skipIfRemote(t)
	skipUnderMinVRAM(t, 13)
	// Allow overriding the model directory via env var for local testing,
	// since the model is ~33GB and may already be downloaded elsewhere.
	modelDir := os.Getenv("OLLAMA_TEST_IMAGEGEN_MODEL_DIR")
	if modelDir == "" {
		modelDir = filepath.Join(testdataModelsDir, "Z-Image-Turbo")
		downloadHFModel(t, "Tongyi-MAI/Z-Image-Turbo", modelDir)
	} else {
		t.Logf("Using existing imagegen model at %s", modelDir)
	}
	// Verify it looks like a valid imagegen model directory
	if _, err := os.Stat(filepath.Join(modelDir, "model_index.json")); err != nil {
		t.Fatalf("model_index.json not found in %s — not a valid imagegen model directory", modelDir)
	}
	ensureMLXLibraryPath(t)
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
	modelName := "test-z-image-turbo-create"
	// The Modelfile needs an absolute path: the CLI runs from a temp dir.
	absModelDir, err := filepath.Abs(modelDir)
	if err != nil {
		t.Fatalf("Failed to get absolute path: %v", err)
	}
	// Create a Modelfile pointing to the diffusers model directory
	tmpModelfile := filepath.Join(t.TempDir(), "Modelfile")
	if err := os.WriteFile(tmpModelfile, []byte("FROM "+absModelDir+"\n"), 0o644); err != nil {
		t.Fatalf("Failed to write Modelfile: %v", err)
	}
	t.Logf("Creating imagegen model from %s", absModelDir)
	runOllamaCreate(ctx, t, modelName, "--experimental", "-f", tmpModelfile)
	// Verify model exists via show
	showReq := &api.ShowRequest{Name: modelName}
	showResp, err := client.Show(ctx, showReq)
	if err != nil {
		t.Fatalf("Model show failed after create: %v", err)
	}
	t.Logf("Created model details: %+v", showResp.Details)
	// Generate an image to verify the model isn't corrupted
	t.Log("Generating test image...")
	imageBase64, err := generateImage(ctx, client, modelName, "A red circle on a white background")
	if err != nil {
		// Environment limitations are skips, not failures.
		if strings.Contains(err.Error(), "image generation not available") {
			t.Skip("Target system does not support image generation")
		} else if strings.Contains(err.Error(), "insufficient memory for image generation") {
			t.Skip("insufficient memory for image generation")
		} else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
			t.Skip("unsupported architecture")
		}
		t.Fatalf("Image generation failed: %v", err)
	}
	// Verify we got valid image data
	imageData, err := base64.StdEncoding.DecodeString(imageBase64)
	if err != nil {
		t.Fatalf("Failed to decode base64 image: %v", err)
	}
	t.Logf("Generated image: %d bytes", len(imageData))
	if len(imageData) < 1000 {
		t.Fatalf("Generated image suspiciously small (%d bytes), likely corrupted", len(imageData))
	}
	// Check for PNG or JPEG magic bytes
	isPNG := len(imageData) >= 4 && imageData[0] == 0x89 && imageData[1] == 'P' && imageData[2] == 'N' && imageData[3] == 'G'
	isJPEG := len(imageData) >= 2 && imageData[0] == 0xFF && imageData[1] == 0xD8
	if !isPNG && !isJPEG {
		t.Fatalf("Generated image is neither PNG nor JPEG (first bytes: %x)", imageData[:min(8, len(imageData))])
	}
	t.Logf("Image format validated (PNG=%v, JPEG=%v)", isPNG, isJPEG)
	// Cleanup: delete the model
	deleteReq := &api.DeleteRequest{Model: modelName}
	if err := client.Delete(ctx, deleteReq); err != nil {
		t.Logf("Warning: failed to delete test model: %v", err)
	}
}

350
integration/create_test.go Normal file
View File

@@ -0,0 +1,350 @@
//go:build integration
package integration
import (
"context"
"io"
"net"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
const testdataModelsDir = "testdata/models"
// skipIfRemote skips the test if OLLAMA_HOST points to a non-local server.
// Safetensors/imagegen creation requires localhost since it reads model files
// from disk and uses the --experimental CLI path.
func skipIfRemote(t *testing.T) {
	t.Helper()
	raw := os.Getenv("OLLAMA_HOST")
	if raw == "" {
		// Unset means the client default, which is localhost.
		return
	}
	// Drop a leading scheme ("http://", "https://", ...) when present.
	addr := raw
	if _, rest, found := strings.Cut(raw, "://"); found {
		addr = rest
	}
	// Separate host from port; a bare host without port is also accepted.
	hostname := addr
	if h, _, err := net.SplitHostPort(addr); err == nil {
		hostname = h
	}
	switch hostname {
	case "", "localhost":
		return
	}
	if ip := net.ParseIP(hostname); ip != nil && (ip.IsLoopback() || ip.IsUnspecified()) {
		return
	}
	t.Skipf("safetensors/imagegen creation requires a local server (OLLAMA_HOST=%s)", raw)
}
// findHFCLI returns the path to the HuggingFace CLI, or "" if not found.
// Both the legacy "huggingface-cli" name and the newer "hf" name are tried.
func findHFCLI() string {
	candidates := []string{"huggingface-cli", "hf"}
	for _, candidate := range candidates {
		path, err := exec.LookPath(candidate)
		if err != nil {
			continue
		}
		return path
	}
	return ""
}
// downloadHFModel idempotently downloads a HuggingFace model to destDir.
// Skips the test if CLI is missing and model isn't already present.
// extraArgs are passed through to the CLI (e.g. "--include" filters).
func downloadHFModel(t *testing.T, repo, destDir string, extraArgs ...string) {
	t.Helper()
	// Check if model already exists: any non-empty destDir counts as present.
	if _, err := os.Stat(destDir); err == nil {
		entries, err := os.ReadDir(destDir)
		if err == nil && len(entries) > 0 {
			t.Logf("Model %s already present at %s", repo, destDir)
			return
		}
	}
	cli := findHFCLI()
	if cli == "" {
		t.Skipf("HuggingFace CLI not found and model %s not present at %s", repo, destDir)
	}
	t.Logf("Downloading %s to %s", repo, destDir)
	// Fail fast if the destination can't be created, rather than letting the
	// CLI produce a confusing downstream error (previously this error was
	// silently discarded).
	if err := os.MkdirAll(destDir, 0o755); err != nil {
		t.Fatalf("Failed to create %s: %v", destDir, err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
	defer cancel()
	args := []string{"download", repo, "--local-dir", destDir}
	args = append(args, extraArgs...)
	cmd := exec.CommandContext(ctx, cli, args...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		t.Fatalf("Failed to download %s: %v", repo, err)
	}
}
// ollamaBin returns the path to the ollama binary to use for tests.
// Prefers OLLAMA_BIN env, then falls back to the built binary at ../ollama
// (same binary the integration test server uses).
func ollamaBin() string {
	if override := os.Getenv("OLLAMA_BIN"); override != "" {
		return override
	}
	abs, err := filepath.Abs("../ollama")
	if err != nil {
		// Could not resolve the repo-root path; rely on PATH lookup.
		return "ollama"
	}
	if _, statErr := os.Stat(abs); statErr != nil {
		// No built binary at the repo root; rely on PATH lookup.
		return "ollama"
	}
	return abs
}
// ensureMLXLibraryPath sets OLLAMA_LIBRARY_PATH so the MLX dynamic library
// is discoverable. Integration tests run from integration/ dir, so the
// default CWD-based search won't find the library at the repo root.
func ensureMLXLibraryPath(t *testing.T) {
	t.Helper()
	libPath, err := filepath.Abs("../build/lib/ollama")
	if err != nil {
		return
	}
	if _, err := os.Stat(libPath); err != nil {
		// Nothing built at the repo root; leave the environment untouched.
		return
	}
	value := libPath
	if existing := os.Getenv("OLLAMA_LIBRARY_PATH"); existing != "" {
		// Append rather than replace so caller-provided entries still apply.
		value = existing + string(filepath.ListSeparator) + libPath
	}
	t.Setenv("OLLAMA_LIBRARY_PATH", value)
}
// runOllamaCreate runs "ollama create" as a subprocess. Skips the test if
// the error indicates the server is remote.
func runOllamaCreate(ctx context.Context, t *testing.T, args ...string) {
	t.Helper()
	createCmd := exec.CommandContext(ctx, ollamaBin(), append([]string{"create"}, args...)...)
	// Tee stderr: echo it for the test log while also capturing it so the
	// failure text can be inspected below.
	var createStderr strings.Builder
	createCmd.Stdout = os.Stdout
	createCmd.Stderr = io.MultiWriter(os.Stderr, &createStderr)
	if err := createCmd.Run(); err != nil {
		// A "remote" error means the target server can't do local
		// safetensors creation — an environment limitation, not a failure.
		if strings.Contains(createStderr.String(), "remote") {
			t.Skip("safetensors creation requires a local server")
		}
		t.Fatalf("ollama create failed: %v", err)
	}
}
// TestCreateSafetensorsLLM verifies that "ollama create --experimental" can
// import a safetensors LLM from disk and that the resulting model produces
// coherent chat output.
func TestCreateSafetensorsLLM(t *testing.T) {
	skipIfRemote(t)
	modelDir := filepath.Join(testdataModelsDir, "TinyLlama-1.1B")
	downloadHFModel(t, "TinyLlama/TinyLlama-1.1B-Chat-v1.0", modelDir)
	ensureMLXLibraryPath(t)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
	modelName := "test-tinyllama-safetensors"
	absModelDir, err := filepath.Abs(modelDir)
	if err != nil {
		t.Fatalf("Failed to get absolute path: %v", err)
	}
	// Create a Modelfile pointing to the model directory.
	// Include a chat template since the safetensors importer doesn't extract
	// chat_template from tokenizer_config.json yet.
	modelfileContent := "FROM " + absModelDir + "\n" +
		"TEMPLATE \"{{ if .System }}<|system|>\n{{ .System }}</s>\n{{ end }}" +
		"{{ if .Prompt }}<|user|>\n{{ .Prompt }}</s>\n{{ end }}" +
		"<|assistant|>\n{{ .Response }}</s>\n\"\n"
	tmpModelfile := filepath.Join(t.TempDir(), "Modelfile")
	if err := os.WriteFile(tmpModelfile, []byte(modelfileContent), 0o644); err != nil {
		t.Fatalf("Failed to write Modelfile: %v", err)
	}
	runOllamaCreate(ctx, t, modelName, "--experimental", "-f", tmpModelfile)
	// Verify model exists via show
	showReq := &api.ShowRequest{Name: modelName}
	showResp, err := client.Show(ctx, showReq)
	if err != nil {
		t.Fatalf("Model show failed after create: %v", err)
	}
	t.Logf("Created model details: %+v", showResp.Details)
	// Use the chat API for proper template application.
	chatReq := &api.ChatRequest{
		Model: modelName,
		Messages: []api.Message{
			{Role: "user", Content: "Write a short sentence about the weather."},
		},
		// Deterministic, short completion keeps the test fast and stable.
		Options: map[string]interface{}{
			"num_predict": 20,
			"temperature": 0.0,
		},
	}
	var output strings.Builder
	err = client.Chat(ctx, chatReq, func(resp api.ChatResponse) error {
		output.WriteString(resp.Message.Content)
		return nil
	})
	if err != nil {
		t.Fatalf("Chat failed: %v", err)
	}
	text := output.String()
	t.Logf("Generated output: %q", text)
	assertCoherentOutput(t, text)
	// Cleanup: delete the model
	deleteReq := &api.DeleteRequest{Model: modelName}
	if err := client.Delete(ctx, deleteReq); err != nil {
		t.Logf("Warning: failed to delete test model: %v", err)
	}
}
// TestCreateGGUF exercises the standard (non-experimental) create path with a
// GGUF file, then verifies the created model generates coherent output.
func TestCreateGGUF(t *testing.T) {
	modelDir := filepath.Join(testdataModelsDir, "Llama-3.2-1B-GGUF")
	downloadHFModel(t, "bartowski/Llama-3.2-1B-Instruct-GGUF", modelDir,
		"--include", "Llama-3.2-1B-Instruct-IQ3_M.gguf")
	// Find the GGUF file
	entries, err := os.ReadDir(modelDir)
	if err != nil {
		t.Fatalf("Failed to read model dir: %v", err)
	}
	var ggufPath string
	for _, e := range entries {
		if filepath.Ext(e.Name()) == ".gguf" {
			ggufPath = filepath.Join(modelDir, e.Name())
			break
		}
	}
	if ggufPath == "" {
		t.Skip("No GGUF file found in model directory")
	}
	absGGUF, err := filepath.Abs(ggufPath)
	if err != nil {
		t.Fatalf("Failed to get absolute path: %v", err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
	modelName := "test-llama32-gguf"
	// Create a Modelfile and use the CLI
	tmpModelfile := filepath.Join(t.TempDir(), "Modelfile")
	if err := os.WriteFile(tmpModelfile, []byte("FROM "+absGGUF+"\n"), 0o644); err != nil {
		t.Fatalf("Failed to write Modelfile: %v", err)
	}
	// No --experimental flag here: this is the classic GGUF import path.
	createCmd := exec.CommandContext(ctx, ollamaBin(), "create", modelName, "-f", tmpModelfile)
	createCmd.Stdout = os.Stdout
	createCmd.Stderr = os.Stderr
	if err := createCmd.Run(); err != nil {
		t.Fatalf("ollama create failed: %v", err)
	}
	// Verify model exists
	showReq := &api.ShowRequest{Name: modelName}
	_, err = client.Show(ctx, showReq)
	if err != nil {
		t.Fatalf("Model show failed after create: %v", err)
	}
	// Generate and verify output is coherent
	genReq := &api.GenerateRequest{
		Model:  modelName,
		Prompt: "Write a short sentence about the weather.",
		// Deterministic, short completion keeps the test fast and stable.
		Options: map[string]interface{}{
			"num_predict": 20,
			"temperature": 0.0,
		},
	}
	var output strings.Builder
	err = client.Generate(ctx, genReq, func(resp api.GenerateResponse) error {
		output.WriteString(resp.Response)
		return nil
	})
	if err != nil {
		t.Fatalf("Generate failed: %v", err)
	}
	text := output.String()
	t.Logf("Generated output: %q", text)
	assertCoherentOutput(t, text)
	// Cleanup
	deleteReq := &api.DeleteRequest{Model: modelName}
	if err := client.Delete(ctx, deleteReq); err != nil {
		t.Logf("Warning: failed to delete test model: %v", err)
	}
}
// assertCoherentOutput checks that model output looks like real language, not
// garbled binary or repeated garbage. This catches corrupted model creation
// where inference "works" but produces nonsense.
func assertCoherentOutput(t *testing.T, text string) {
	t.Helper()
	switch {
	case len(text) == 0:
		t.Fatal("model produced empty output")
	case len(text) < 5:
		// 20 tokens should produce at least a few words.
		t.Fatalf("model output suspiciously short (%d bytes): %q", len(text), text)
	}
	// Garbled models often emit high ratios of control characters or
	// Unicode replacement characters; count both.
	var bad int
	for _, r := range text {
		if r < 0x20 && r != '\n' && r != '\r' && r != '\t' {
			bad++
		}
		if r == '\ufffd' { // Unicode replacement character
			bad++
		}
	}
	if frac := float64(bad) / float64(len([]rune(text))); frac > 0.3 {
		t.Fatalf("model output is %.0f%% unprintable characters (likely garbled): %q", frac*100, text)
	}
	// Real language has word boundaries.
	if !strings.Contains(text, " ") {
		t.Fatalf("model output contains no spaces (likely garbled): %q", text)
	}
	// A broken model might repeat one token over and over; only meaningful
	// with at least a handful of words.
	words := strings.Fields(text)
	if len(words) < 4 {
		return
	}
	counts := make(map[string]int, len(words))
	for _, w := range words {
		counts[strings.ToLower(w)]++
	}
	for w, c := range counts {
		if c > len(words)*3/4 {
			t.Fatalf("model output is excessively repetitive (%q appears %d/%d times): %q", w, c, len(words), text)
		}
	}
}

View File

@@ -141,7 +141,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "" || len(config.Capabilities) == 0) { if err == nil && !remote {
mf, mErr := manifest.ParseNamedManifest(fromName) mf, mErr := manifest.ParseNamedManifest(fromName)
if mErr == nil && mf.Config.Digest != "" { if mErr == nil && mf.Config.Digest != "" {
configPath, pErr := manifest.BlobsPath(mf.Config.Digest) configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
@@ -158,6 +158,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
if config.Requires == "" { if config.Requires == "" {
config.Requires = baseConfig.Requires config.Requires = baseConfig.Requires
} }
if config.ModelFormat == "" {
config.ModelFormat = baseConfig.ModelFormat
}
if len(config.Capabilities) == 0 { if len(config.Capabilities) == 0 {
config.Capabilities = baseConfig.Capabilities config.Capabilities = baseConfig.Capabilities
} }

View File

@@ -41,11 +41,12 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
return nil, err return nil, err
} }
for _, layer := range m.Layers { for _, srcLayer := range m.Layers {
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest()) layer, err := manifest.NewLayerFromLayer(srcLayer.Digest, srcLayer.MediaType, name.DisplayShortest())
if err != nil { if err != nil {
return nil, err return nil, err
} }
layer.Name = srcLayer.Name
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.model", case "application/vnd.ollama.image.model",

View File

@@ -1024,3 +1024,272 @@ func TestDetectModelTypeFromFiles(t *testing.T) {
} }
}) })
} }
// createTestBlob creates a blob in the blobs directory and returns its digest.
// Blobs are content-addressed: the digest is "sha256:" plus the hex sha256 of
// the data, and the file is written at the path manifest.BlobsPath maps it to.
func createTestBlob(t *testing.T, data []byte) string {
	t.Helper()
	digest := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
	blobPath, err := manifest.BlobsPath(digest)
	if err != nil {
		t.Fatal(err)
	}
	// Ensure the blobs directory exists (fresh OLLAMA_MODELS temp dirs won't).
	if err := os.MkdirAll(filepath.Dir(blobPath), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(blobPath, data, 0o644); err != nil {
		t.Fatal(err)
	}
	return digest
}
// createSafetensorsTestModel creates a minimal safetensors model manifest for
// testing: one fake named tensor blob, plus any extraLayers, wrapped in a
// config layer and written as a named manifest.
func createSafetensorsTestModel(t *testing.T, modelName string, config model.ConfigV2, extraLayers []manifest.Layer) {
	t.Helper()
	// Create a fake tensor blob — contents don't matter, only that the layer
	// exists, is content-addressed, and carries a tensor Name.
	tensorData := []byte("fake-tensor-data-for-testing")
	tensorDigest := createTestBlob(t, tensorData)
	layers := []manifest.Layer{
		{
			MediaType: manifest.MediaTypeImageTensor,
			Digest:    tensorDigest,
			Size:      int64(len(tensorData)),
			Name:      "model.embed_tokens.weight",
		},
	}
	layers = append(layers, extraLayers...)
	configLayer, err := createConfigLayer(layers, config)
	if err != nil {
		t.Fatalf("failed to create config layer: %v", err)
	}
	name := model.ParseName(modelName)
	if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
		t.Fatalf("failed to write manifest: %v", err)
	}
}
// TestCreateFromSafetensorsModel_PreservesConfig verifies that creating a
// model FROM an existing safetensors model carries over ModelFormat,
// Capabilities, Requires, Renderer, Parser, and the tensor layer names
// (the regression this PR fixes: derived models previously lost these).
func TestCreateFromSafetensorsModel_PreservesConfig(t *testing.T) {
	gin.SetMode(gin.TestMode)
	p := t.TempDir()
	t.Setenv("OLLAMA_MODELS", p)
	var s Server
	// Create a source safetensors model with specific config fields
	createSafetensorsTestModel(t, "source-model", model.ConfigV2{
		ModelFormat:  "safetensors",
		Capabilities: []string{"completion"},
		Requires:     "0.14.0",
		Renderer:     "gemma3",
		Parser:       "gemma3",
	}, nil)
	// Create a derived model FROM the source
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:  "derived-model",
		From:   "source-model",
		System: "You are a pirate.",
		Stream: &stream,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
	}
	// Read the derived model's config
	derivedName := model.ParseName("derived-model")
	mf, err := manifest.ParseNamedManifest(derivedName)
	if err != nil {
		t.Fatalf("failed to parse derived manifest: %v", err)
	}
	configBlobPath, err := manifest.BlobsPath(mf.Config.Digest)
	if err != nil {
		t.Fatalf("failed to get config blob path: %v", err)
	}
	configBlob, err := os.ReadFile(configBlobPath)
	if err != nil {
		t.Fatalf("failed to read config blob: %v", err)
	}
	var cfg model.ConfigV2
	if err := json.Unmarshal(configBlob, &cfg); err != nil {
		t.Fatalf("failed to unmarshal config: %v", err)
	}
	// Verify safetensors-specific config fields are preserved
	if cfg.ModelFormat != "safetensors" {
		t.Errorf("ModelFormat = %q, want %q", cfg.ModelFormat, "safetensors")
	}
	if !slices.Contains(cfg.Capabilities, "completion") {
		t.Errorf("Capabilities = %v, want to contain %q", cfg.Capabilities, "completion")
	}
	if cfg.Requires != "0.14.0" {
		t.Errorf("Requires = %q, want %q", cfg.Requires, "0.14.0")
	}
	if cfg.Renderer != "gemma3" {
		t.Errorf("Renderer = %q, want %q", cfg.Renderer, "gemma3")
	}
	if cfg.Parser != "gemma3" {
		t.Errorf("Parser = %q, want %q", cfg.Parser, "gemma3")
	}
	// Verify system prompt was added
	var hasSystem bool
	for _, l := range mf.Layers {
		if l.MediaType == "application/vnd.ollama.image.system" {
			hasSystem = true
			break
		}
	}
	if !hasSystem {
		t.Error("expected system prompt layer in derived model")
	}
	// Verify tensor layers were copied with names preserved
	var tensorNames []string
	for _, l := range mf.Layers {
		if l.MediaType == manifest.MediaTypeImageTensor {
			tensorNames = append(tensorNames, l.Name)
		}
	}
	if len(tensorNames) == 0 {
		t.Error("expected tensor layers in derived model")
	}
	for _, name := range tensorNames {
		if name == "" {
			t.Error("tensor layer has empty name — names must be preserved from source")
		}
	}
}
// TestCreateFromSafetensorsModel_OverrideSystem verifies that overriding the
// system prompt on a derived model still preserves the source's ModelFormat.
func TestCreateFromSafetensorsModel_OverrideSystem(t *testing.T) {
	gin.SetMode(gin.TestMode)
	p := t.TempDir()
	t.Setenv("OLLAMA_MODELS", p)
	var s Server
	// Create source with a system prompt
	createSafetensorsTestModel(t, "source-with-system", model.ConfigV2{
		ModelFormat:  "safetensors",
		Capabilities: []string{"completion"},
	}, nil)
	// First create with system prompt
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:  "source-with-system",
		From:   "source-with-system",
		System: "Original system prompt",
		Stream: &stream,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
	}
	// Now create a derived model with a different system prompt
	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:  "derived-new-system",
		From:   "source-with-system",
		System: "New system prompt",
		Stream: &stream,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
	}
	// Verify ModelFormat is preserved even after override
	derivedName := model.ParseName("derived-new-system")
	mf, err := manifest.ParseNamedManifest(derivedName)
	if err != nil {
		t.Fatalf("failed to parse derived manifest: %v", err)
	}
	// Check each step explicitly: the previous version discarded these errors,
	// so a missing blob would surface as a zero-value config and a misleading
	// ModelFormat mismatch instead of the real failure.
	configBlobPath, err := manifest.BlobsPath(mf.Config.Digest)
	if err != nil {
		t.Fatalf("failed to get config blob path: %v", err)
	}
	configBlob, err := os.ReadFile(configBlobPath)
	if err != nil {
		t.Fatalf("failed to read config blob: %v", err)
	}
	var cfg model.ConfigV2
	if err := json.Unmarshal(configBlob, &cfg); err != nil {
		t.Fatalf("failed to unmarshal config: %v", err)
	}
	if cfg.ModelFormat != "safetensors" {
		t.Errorf("ModelFormat = %q, want %q", cfg.ModelFormat, "safetensors")
	}
}
// TestCreateFromSafetensorsModel_PreservesLayerNames verifies that named JSON
// layers (config.json, tokenizer.json) and tensor layers keep their Name
// fields when a model is derived FROM a safetensors source — loading fails
// with "config.json not found in manifest" if names are dropped.
func TestCreateFromSafetensorsModel_PreservesLayerNames(t *testing.T) {
	gin.SetMode(gin.TestMode)
	p := t.TempDir()
	t.Setenv("OLLAMA_MODELS", p)
	var s Server
	// Create JSON config blobs to include as layers
	configJSON := []byte(`{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}`)
	configDigest := createTestBlob(t, configJSON)
	tokenizerJSON := []byte(`{"version": "1.0"}`)
	tokenizerDigest := createTestBlob(t, tokenizerJSON)
	extraLayers := []manifest.Layer{
		{
			MediaType: "application/vnd.ollama.image.json",
			Digest:    configDigest,
			Size:      int64(len(configJSON)),
			Name:      "config.json",
		},
		{
			MediaType: "application/vnd.ollama.image.json",
			Digest:    tokenizerDigest,
			Size:      int64(len(tokenizerJSON)),
			Name:      "tokenizer.json",
		},
	}
	createSafetensorsTestModel(t, "source-named-layers", model.ConfigV2{
		ModelFormat:  "safetensors",
		Capabilities: []string{"completion"},
	}, extraLayers)
	// Create derived model
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:  "derived-named-layers",
		From:   "source-named-layers",
		Stream: &stream,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
	}
	derivedName := model.ParseName("derived-named-layers")
	mf, err := manifest.ParseNamedManifest(derivedName)
	if err != nil {
		t.Fatalf("failed to parse derived manifest: %v", err)
	}
	// Check tensor layer names are preserved
	for _, l := range mf.Layers {
		if l.MediaType == manifest.MediaTypeImageTensor && l.Name == "" {
			t.Error("tensor layer has empty name — names must be preserved from source")
		}
	}
	// Check JSON layer names are preserved
	jsonNames := make(map[string]bool)
	for _, l := range mf.Layers {
		if l.MediaType == "application/vnd.ollama.image.json" && l.Name != "" {
			jsonNames[l.Name] = true
		}
	}
	if !jsonNames["config.json"] {
		t.Error("config.json layer name not preserved in derived model")
	}
	if !jsonNames["tokenizer.json"] {
		t.Error("tokenizer.json layer name not preserved in derived model")
	}
}

View File

@@ -22,7 +22,7 @@ import (
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create" "github.com/ollama/ollama/x/create"
"github.com/ollama/ollama/x/imagegen/safetensors" "github.com/ollama/ollama/x/safetensors"
) )
// MinOllamaVersion is the minimum Ollama version required for safetensors models. // MinOllamaVersion is the minimum Ollama version required for safetensors models.

View File

@@ -429,3 +429,159 @@ func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) {
t.Fatalf("FileType = %q, want %q", cfg.FileType, "mxfp8") t.Fatalf("FileType = %q, want %q", cfg.FileType, "mxfp8")
} }
} }
// TestSupportsThinking verifies thinking-capability detection from the
// architectures / model_type fields of config.json.
func TestSupportsThinking(t *testing.T) {
	tests := []struct {
		name       string
		configJSON string
		want       bool
	}{
		{
			name:       "qwen3 architecture",
			configJSON: `{"architectures": ["Qwen3ForCausalLM"], "model_type": "qwen3"}`,
			want:       true,
		},
		{
			name:       "deepseek architecture",
			configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`,
			want:       true,
		},
		{
			name:       "glm4moe architecture",
			configJSON: `{"architectures": ["GLM4MoeForCausalLM"]}`,
			want:       true,
		},
		{
			name:       "llama architecture (no thinking)",
			configJSON: `{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}`,
			want:       false,
		},
		{
			name:       "gemma architecture (no thinking)",
			configJSON: `{"architectures": ["Gemma3ForCausalLM"], "model_type": "gemma3"}`,
			want:       false,
		},
		{
			name:       "model_type only",
			configJSON: `{"model_type": "deepseek"}`,
			want:       true,
		},
		{
			name:       "empty config",
			configJSON: `{}`,
			want:       false,
		},
		{
			name:       "invalid json",
			configJSON: `not json`,
			want:       false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			dir := t.TempDir()
			// Fail loudly if the fixture can't be written (previously the
			// error was ignored, which would misattribute a write failure
			// to the detector under test).
			if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644); err != nil {
				t.Fatalf("failed to write config.json: %v", err)
			}
			if got := supportsThinking(dir); got != tt.want {
				t.Errorf("supportsThinking() = %v, want %v", got, tt.want)
			}
		})
	}
}
// TestSupportsThinking_NoConfig verifies the detector fails closed (returns
// false) when config.json is absent entirely.
func TestSupportsThinking_NoConfig(t *testing.T) {
	if supportsThinking(t.TempDir()) {
		t.Error("supportsThinking should return false for missing config.json")
	}
}
// TestGetParserName verifies parser selection from the architectures /
// model_type fields of config.json.
func TestGetParserName(t *testing.T) {
	tests := []struct {
		name       string
		configJSON string
		want       string
	}{
		{
			name:       "qwen3 model",
			configJSON: `{"architectures": ["Qwen3ForCausalLM"]}`,
			want:       "qwen3",
		},
		{
			name:       "deepseek model",
			configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`,
			want:       "deepseek3",
		},
		{
			name:       "glm4 model",
			configJSON: `{"architectures": ["GLM4ForCausalLM"]}`,
			want:       "glm-4.7",
		},
		{
			name:       "llama model (no parser)",
			configJSON: `{"architectures": ["LlamaForCausalLM"]}`,
			want:       "",
		},
		{
			name:       "qwen3 via model_type",
			configJSON: `{"model_type": "qwen3"}`,
			want:       "qwen3",
		},
		{
			name:       "no config",
			configJSON: `{}`,
			want:       "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			dir := t.TempDir()
			// Fail loudly if the fixture can't be written (previously the
			// error was ignored).
			if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644); err != nil {
				t.Fatalf("failed to write config.json: %v", err)
			}
			if got := getParserName(dir); got != tt.want {
				t.Errorf("getParserName() = %q, want %q", got, tt.want)
			}
		})
	}
}
// TestGetRendererName verifies renderer selection from the architectures
// field of config.json.
func TestGetRendererName(t *testing.T) {
	tests := []struct {
		name       string
		configJSON string
		want       string
	}{
		{
			name:       "qwen3 model",
			configJSON: `{"architectures": ["Qwen3ForCausalLM"]}`,
			want:       "qwen3-coder",
		},
		{
			name:       "deepseek model",
			configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`,
			want:       "deepseek3",
		},
		{
			name:       "glm4 model",
			configJSON: `{"architectures": ["GLM4ForCausalLM"]}`,
			want:       "glm-4.7",
		},
		{
			name:       "llama model (no renderer)",
			configJSON: `{"architectures": ["LlamaForCausalLM"]}`,
			want:       "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			dir := t.TempDir()
			// Fail loudly if the fixture can't be written (previously the
			// error was ignored).
			if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644); err != nil {
				t.Fatalf("failed to write config.json: %v", err)
			}
			if got := getRendererName(dir); got != tt.want {
				t.Errorf("getRendererName() = %q, want %q", got, tt.want)
			}
		})
	}
}

View File

@@ -13,7 +13,7 @@ import (
"strings" "strings"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/safetensors" "github.com/ollama/ollama/x/safetensors"
) )
// ModelConfig represents the config blob stored with a model. // ModelConfig represents the config blob stored with a model.

View File

@@ -12,7 +12,7 @@ import (
"testing" "testing"
"github.com/d4l3k/go-bfloat16" "github.com/d4l3k/go-bfloat16"
st "github.com/ollama/ollama/x/imagegen/safetensors" st "github.com/ollama/ollama/x/safetensors"
) )
func TestIsTensorModelDir(t *testing.T) { func TestIsTensorModelDir(t *testing.T) {

View File

@@ -9,7 +9,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/ollama/ollama/x/imagegen/safetensors" "github.com/ollama/ollama/x/safetensors"
) )
// CreateImageGenModel imports an image generation model from a directory. // CreateImageGenModel imports an image generation model from a directory.

View File

@@ -7,7 +7,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/ollama/ollama/x/imagegen/safetensors" "github.com/ollama/ollama/x/safetensors"
) )
type qwen35ImportTransform struct { type qwen35ImportTransform struct {

View File

@@ -11,7 +11,6 @@ import (
) )
// tensorInfo holds tensor metadata from safetensors headers. // tensorInfo holds tensor metadata from safetensors headers.
// This avoids depending on safetensors.go which requires the mlx tag.
type tensorInfo struct { type tensorInfo struct {
Dtype string `json:"dtype"` Dtype string `json:"dtype"`
Shape []int32 `json:"shape"` Shape []int32 `json:"shape"`

View File

@@ -0,0 +1,394 @@
package safetensors
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"slices"
"testing"
)
// createTestSafetensors creates a minimal valid safetensors file with the
// given tensors. File layout: an 8-byte little-endian header length, the JSON
// header (space-padded to an 8-byte boundary), then the raw tensor bytes
// concatenated in sorted name order.
func createTestSafetensors(t *testing.T, path string, tensors map[string]struct {
	dtype string
	shape []int32
	data  []byte
},
) {
	t.Helper()
	header := make(map[string]tensorInfo)
	var offset int
	var allData []byte
	// Sort names for deterministic file layout
	names := make([]string, 0, len(tensors))
	for name := range tensors {
		names = append(names, name)
	}
	slices.Sort(names)
	// Assign each tensor a contiguous [begin, end) byte range in the data
	// section and record it in the header.
	for _, name := range names {
		info := tensors[name]
		header[name] = tensorInfo{
			Dtype:       info.dtype,
			Shape:       info.shape,
			DataOffsets: [2]int{offset, offset + len(info.data)},
		}
		allData = append(allData, info.data...)
		offset += len(info.data)
	}
	headerJSON, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("failed to marshal header: %v", err)
	}
	// Pad to 8-byte alignment
	padding := (8 - len(headerJSON)%8) % 8
	headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
	f, err := os.Create(path)
	if err != nil {
		t.Fatalf("failed to create file: %v", err)
	}
	defer f.Close()
	// Header length prefix is a little-endian uint64 per the format.
	if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
		t.Fatalf("failed to write header size: %v", err)
	}
	if _, err := f.Write(headerJSON); err != nil {
		t.Fatalf("failed to write header: %v", err)
	}
	if _, err := f.Write(allData); err != nil {
		t.Fatalf("failed to write data: %v", err)
	}
}
// TestOpenForExtraction verifies that a freshly written single-tensor file
// opens cleanly and reports the expected tensor count and names.
func TestOpenForExtraction(t *testing.T) {
	file := filepath.Join(t.TempDir(), "test.safetensors")

	// Four little-endian float32 values 1.0..4.0 (16 bytes total).
	raw := make([]byte, 16)
	for i, bits := range []uint32{0x3f800000, 0x40000000, 0x40400000, 0x40800000} {
		binary.LittleEndian.PutUint32(raw[i*4:], bits)
	}

	createTestSafetensors(t, file, map[string]struct {
		dtype string
		shape []int32
		data  []byte
	}{
		"test_tensor": {dtype: "F32", shape: []int32{2, 2}, data: raw},
	})

	ext, err := OpenForExtraction(file)
	if err != nil {
		t.Fatalf("OpenForExtraction failed: %v", err)
	}
	defer ext.Close()

	if got := ext.TensorCount(); got != 1 {
		t.Errorf("TensorCount() = %d, want 1", got)
	}
	if got := ext.ListTensors(); len(got) != 1 || got[0] != "test_tensor" {
		t.Errorf("ListTensors() = %v, want [test_tensor]", got)
	}
}
func TestGetTensor(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.safetensors")
data := make([]byte, 16)
for i := range 4 {
binary.LittleEndian.PutUint32(data[i*4:], uint32(i+1))
}
createTestSafetensors(t, path, map[string]struct {
dtype string
shape []int32
data []byte
}{
"weight": {dtype: "F32", shape: []int32{2, 2}, data: data},
})
ext, err := OpenForExtraction(path)
if err != nil {
t.Fatalf("OpenForExtraction failed: %v", err)
}
defer ext.Close()
td, err := ext.GetTensor("weight")
if err != nil {
t.Fatalf("GetTensor failed: %v", err)
}
if td.Name != "weight" {
t.Errorf("Name = %q, want %q", td.Name, "weight")
}
if td.Dtype != "F32" {
t.Errorf("Dtype = %q, want %q", td.Dtype, "F32")
}
if td.Size != 16 {
t.Errorf("Size = %d, want 16", td.Size)
}
if len(td.Shape) != 2 || td.Shape[0] != 2 || td.Shape[1] != 2 {
t.Errorf("Shape = %v, want [2 2]", td.Shape)
}
// Read the raw data
rawData, err := io.ReadAll(td.Reader())
if err != nil {
t.Fatalf("Reader() read failed: %v", err)
}
if len(rawData) != 16 {
t.Errorf("raw data length = %d, want 16", len(rawData))
}
}
// TestGetTensor_NotFound verifies that requesting an unknown tensor name
// returns an error rather than succeeding silently.
func TestGetTensor_NotFound(t *testing.T) {
	file := filepath.Join(t.TempDir(), "test.safetensors")
	createTestSafetensors(t, file, map[string]struct {
		dtype string
		shape []int32
		data  []byte
	}{
		"exists": {dtype: "F32", shape: []int32{1}, data: make([]byte, 4)},
	})

	ext, err := OpenForExtraction(file)
	if err != nil {
		t.Fatalf("OpenForExtraction failed: %v", err)
	}
	defer ext.Close()

	if _, err := ext.GetTensor("missing"); err == nil {
		t.Error("expected error for missing tensor, got nil")
	}
}
func TestSafetensorsReaderRoundTrip(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.safetensors")
data := make([]byte, 16)
for i := range 4 {
binary.LittleEndian.PutUint32(data[i*4:], uint32(0x3f800000+i))
}
createTestSafetensors(t, path, map[string]struct {
dtype string
shape []int32
data []byte
}{
"tensor_a": {dtype: "F32", shape: []int32{2, 2}, data: data},
})
ext, err := OpenForExtraction(path)
if err != nil {
t.Fatalf("OpenForExtraction failed: %v", err)
}
defer ext.Close()
td, err := ext.GetTensor("tensor_a")
if err != nil {
t.Fatalf("GetTensor failed: %v", err)
}
// Wrap as safetensors and extract back
stReader := td.SafetensorsReader()
stData, err := io.ReadAll(stReader)
if err != nil {
t.Fatalf("SafetensorsReader read failed: %v", err)
}
// Verify size
if int64(len(stData)) != td.SafetensorsSize() {
t.Errorf("SafetensorsSize() = %d, actual = %d", td.SafetensorsSize(), len(stData))
}
// Extract raw data back
raw, err := ExtractRawFromSafetensors(bytes.NewReader(stData))
if err != nil {
t.Fatalf("ExtractRawFromSafetensors failed: %v", err)
}
if !bytes.Equal(raw, data) {
t.Errorf("round-trip data mismatch: got %v, want %v", raw, data)
}
}
// TestNewTensorDataFromBytes verifies that an in-memory tensor carries the
// supplied name, dtype, shape, and size, and that Reader() returns the bytes.
func TestNewTensorDataFromBytes(t *testing.T) {
	data := []byte{1, 2, 3, 4}
	td := NewTensorDataFromBytes("test", "U8", []int32{4}, data)

	if td.Name != "test" {
		t.Errorf("Name = %q, want %q", td.Name, "test")
	}
	// Previously unchecked: dtype and shape must be preserved too.
	if td.Dtype != "U8" {
		t.Errorf("Dtype = %q, want %q", td.Dtype, "U8")
	}
	if len(td.Shape) != 1 || td.Shape[0] != 4 {
		t.Errorf("Shape = %v, want [4]", td.Shape)
	}
	if td.Size != 4 {
		t.Errorf("Size = %d, want 4", td.Size)
	}

	rawData, err := io.ReadAll(td.Reader())
	if err != nil {
		t.Fatalf("Reader() failed: %v", err)
	}
	if !bytes.Equal(rawData, data) {
		t.Errorf("data mismatch: got %v, want %v", rawData, data)
	}
}
// TestBuildPackedSafetensorsReader verifies that packing two tensors yields a
// valid safetensors file: parseable header, correct dtypes, and each tensor's
// payload recoverable via its recorded data offsets.
//
// The original test only verified tensor 'a' payloads when 'a' happened to be
// packed first, and never checked tensor 'b' dtype. This version validates
// both tensors via their offsets regardless of packing order, and bounds-checks
// the header size before slicing.
func TestBuildPackedSafetensorsReader(t *testing.T) {
	data1 := []byte{1, 2, 3, 4}
	data2 := []byte{5, 6, 7, 8, 9, 10, 11, 12}
	td1 := NewTensorDataFromBytes("a", "U8", []int32{4}, data1)
	td2 := NewTensorDataFromBytes("b", "U8", []int32{8}, data2)

	packed := BuildPackedSafetensorsReader([]*TensorData{td1, td2})
	packedBytes, err := io.ReadAll(packed)
	if err != nil {
		t.Fatalf("BuildPackedSafetensorsReader read failed: %v", err)
	}

	// Verify it's a valid safetensors file by parsing the header.
	var headerSize uint64
	if err := binary.Read(bytes.NewReader(packedBytes), binary.LittleEndian, &headerSize); err != nil {
		t.Fatalf("failed to read header size: %v", err)
	}
	if 8+headerSize > uint64(len(packedBytes)) {
		t.Fatalf("header size %d exceeds file size %d", headerSize, len(packedBytes))
	}
	headerJSON := packedBytes[8 : 8+headerSize]
	var header map[string]tensorInfo
	if err := json.Unmarshal(headerJSON, &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}
	if len(header) != 2 {
		t.Errorf("header has %d entries, want 2", len(header))
	}

	// Verify each tensor's metadata and payload via its recorded offsets,
	// regardless of which tensor was packed first.
	dataRegion := packedBytes[8+headerSize:]
	for _, want := range []struct {
		name string
		data []byte
	}{
		{"a", data1},
		{"b", data2},
	} {
		info, ok := header[want.name]
		if !ok {
			t.Fatalf("tensor %q not found in header", want.name)
		}
		if info.Dtype != "U8" {
			t.Errorf("tensor %q dtype = %q, want %q", want.name, info.Dtype, "U8")
		}
		lo, hi := info.DataOffsets[0], info.DataOffsets[1]
		if lo < 0 || hi > len(dataRegion) || lo > hi {
			t.Fatalf("tensor %q offsets [%d,%d] out of range (data region %d bytes)", want.name, lo, hi, len(dataRegion))
		}
		if !bytes.Equal(dataRegion[lo:hi], want.data) {
			t.Errorf("tensor %q data mismatch", want.name)
		}
	}
}
// TestExtractAll verifies that ExtractAll returns every tensor in the file,
// sorted by name.
func TestExtractAll(t *testing.T) {
	file := filepath.Join(t.TempDir(), "test.safetensors")
	createTestSafetensors(t, file, map[string]struct {
		dtype string
		shape []int32
		data  []byte
	}{
		"alpha": {dtype: "F32", shape: []int32{2}, data: make([]byte, 8)},
		"beta":  {dtype: "F16", shape: []int32{4}, data: make([]byte, 8)},
	})

	ext, err := OpenForExtraction(file)
	if err != nil {
		t.Fatalf("OpenForExtraction failed: %v", err)
	}
	defer ext.Close()

	all, err := ext.ExtractAll()
	if err != nil {
		t.Fatalf("ExtractAll failed: %v", err)
	}
	if len(all) != 2 {
		t.Errorf("ExtractAll returned %d tensors, want 2", len(all))
	}
	// Results must come back sorted by tensor name.
	if all[0].Name != "alpha" || all[1].Name != "beta" {
		t.Errorf("tensors not in sorted order: %s, %s", all[0].Name, all[1].Name)
	}
}
// TestExtractRawFromSafetensors_InvalidInput verifies that malformed input
// (empty stream, or fewer than 8 bytes for the header-size prefix) errors out.
func TestExtractRawFromSafetensors_InvalidInput(t *testing.T) {
	// An empty stream cannot even supply the u64 header-size prefix.
	if _, err := ExtractRawFromSafetensors(bytes.NewReader(nil)); err == nil {
		t.Error("expected error for empty reader")
	}
	// Three bytes is too short for the 8-byte header-size prefix.
	if _, err := ExtractRawFromSafetensors(bytes.NewReader([]byte{1, 2, 3})); err == nil {
		t.Error("expected error for truncated header size")
	}
}
// TestOpenForExtraction_MetadataIgnored verifies that a "__metadata__" entry
// in the header is stripped and not counted as a tensor.
//
// The original test discarded every error from os.Create, binary.Write,
// f.Write, and f.Close, so a failed setup would surface later as a confusing
// OpenForExtraction failure; all setup errors are now checked.
func TestOpenForExtraction_MetadataIgnored(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "test.safetensors")

	// Manually create a safetensors file with __metadata__ alongside one tensor.
	header := map[string]any{
		"__metadata__": map[string]string{"format": "pt"},
		"weight": tensorInfo{
			Dtype:       "F32",
			Shape:       []int32{2},
			DataOffsets: [2]int{0, 8},
		},
	}
	headerJSON, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("failed to marshal header: %v", err)
	}
	// Space-pad so the data section starts 8-byte aligned.
	padding := (8 - len(headerJSON)%8) % 8
	headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)

	f, err := os.Create(path)
	if err != nil {
		t.Fatalf("failed to create file: %v", err)
	}
	if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
		t.Fatalf("failed to write header size: %v", err)
	}
	if _, err := f.Write(headerJSON); err != nil {
		t.Fatalf("failed to write header: %v", err)
	}
	if _, err := f.Write(make([]byte, 8)); err != nil {
		t.Fatalf("failed to write data: %v", err)
	}
	if err := f.Close(); err != nil {
		t.Fatalf("failed to close file: %v", err)
	}

	ext, err := OpenForExtraction(path)
	if err != nil {
		t.Fatalf("OpenForExtraction failed: %v", err)
	}
	defer ext.Close()

	// __metadata__ should be stripped
	if ext.TensorCount() != 1 {
		t.Errorf("TensorCount() = %d, want 1 (metadata should be stripped)", ext.TensorCount())
	}
}