mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 18:54:15 +02:00
create: Clean up experimental paths, fix create from existing safetensor model (#14679)
* create: Clean up experimental paths
This cleans up the experimental features, and adds both unit and integration test coverage to verify no regressions.
* create: preserve config and layer names when creating from safetensors models
When creating a model FROM an existing safetensors model, ModelFormat,
Capabilities, and layer Name fields were lost. ModelFormat stayed empty
because it's only set from GGML layers (which safetensors models lack),
and layer names weren't copied in parseFromModel. This caused derived
models to fail loading ("config.json not found in manifest").
* review comments
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -15,3 +15,4 @@ __debug_bin*
|
||||
llama/build
|
||||
llama/vendor
|
||||
/ollama
|
||||
integration/testdata/models/
|
||||
|
||||
16
cmd/cmd.go
16
cmd/cmd.go
@@ -54,7 +54,6 @@ import (
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
xcreateclient "github.com/ollama/ollama/x/create/client"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
@@ -164,11 +163,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
// Check for --experimental flag for safetensors model creation
|
||||
// This gates both safetensors LLM and imagegen model creation
|
||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||
if experimental {
|
||||
if !isLocalhost() {
|
||||
return errors.New("remote safetensor model creation not yet supported")
|
||||
}
|
||||
|
||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||
var reader io.Reader
|
||||
filename, err := getModelfileName(cmd)
|
||||
@@ -211,23 +212,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}, p)
|
||||
}
|
||||
|
||||
// Standard Modelfile + API path
|
||||
var reader io.Reader
|
||||
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if create.IsTensorModelDir(".") {
|
||||
if !isLocalhost() {
|
||||
return errors.New("remote safetensor model creation not yet supported")
|
||||
}
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: ".",
|
||||
Quantize: quantize,
|
||||
}, p)
|
||||
}
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
return errModelfileNotFound
|
||||
|
||||
107
integration/create_imagegen_test.go
Normal file
107
integration/create_imagegen_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//go:build integration && imagegen
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestCreateImageGen(t *testing.T) {
|
||||
skipIfRemote(t)
|
||||
skipUnderMinVRAM(t, 13)
|
||||
|
||||
// Allow overriding the model directory via env var for local testing,
|
||||
// since the model is ~33GB and may already be downloaded elsewhere.
|
||||
modelDir := os.Getenv("OLLAMA_TEST_IMAGEGEN_MODEL_DIR")
|
||||
if modelDir == "" {
|
||||
modelDir = filepath.Join(testdataModelsDir, "Z-Image-Turbo")
|
||||
downloadHFModel(t, "Tongyi-MAI/Z-Image-Turbo", modelDir)
|
||||
} else {
|
||||
t.Logf("Using existing imagegen model at %s", modelDir)
|
||||
}
|
||||
|
||||
// Verify it looks like a valid imagegen model directory
|
||||
if _, err := os.Stat(filepath.Join(modelDir, "model_index.json")); err != nil {
|
||||
t.Fatalf("model_index.json not found in %s — not a valid imagegen model directory", modelDir)
|
||||
}
|
||||
|
||||
ensureMLXLibraryPath(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
modelName := "test-z-image-turbo-create"
|
||||
|
||||
absModelDir, err := filepath.Abs(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get absolute path: %v", err)
|
||||
}
|
||||
|
||||
// Create a Modelfile pointing to the diffusers model directory
|
||||
tmpModelfile := filepath.Join(t.TempDir(), "Modelfile")
|
||||
if err := os.WriteFile(tmpModelfile, []byte("FROM "+absModelDir+"\n"), 0o644); err != nil {
|
||||
t.Fatalf("Failed to write Modelfile: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Creating imagegen model from %s", absModelDir)
|
||||
runOllamaCreate(ctx, t, modelName, "--experimental", "-f", tmpModelfile)
|
||||
|
||||
// Verify model exists via show
|
||||
showReq := &api.ShowRequest{Name: modelName}
|
||||
showResp, err := client.Show(ctx, showReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Model show failed after create: %v", err)
|
||||
}
|
||||
t.Logf("Created model details: %+v", showResp.Details)
|
||||
|
||||
// Generate an image to verify the model isn't corrupted
|
||||
t.Log("Generating test image...")
|
||||
imageBase64, err := generateImage(ctx, client, modelName, "A red circle on a white background")
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "image generation not available") {
|
||||
t.Skip("Target system does not support image generation")
|
||||
} else if strings.Contains(err.Error(), "insufficient memory for image generation") {
|
||||
t.Skip("insufficient memory for image generation")
|
||||
} else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
|
||||
t.Skip("unsupported architecture")
|
||||
}
|
||||
t.Fatalf("Image generation failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify we got valid image data
|
||||
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode base64 image: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Generated image: %d bytes", len(imageData))
|
||||
|
||||
if len(imageData) < 1000 {
|
||||
t.Fatalf("Generated image suspiciously small (%d bytes), likely corrupted", len(imageData))
|
||||
}
|
||||
|
||||
// Check for PNG or JPEG magic bytes
|
||||
isPNG := len(imageData) >= 4 && imageData[0] == 0x89 && imageData[1] == 'P' && imageData[2] == 'N' && imageData[3] == 'G'
|
||||
isJPEG := len(imageData) >= 2 && imageData[0] == 0xFF && imageData[1] == 0xD8
|
||||
if !isPNG && !isJPEG {
|
||||
t.Fatalf("Generated image is neither PNG nor JPEG (first bytes: %x)", imageData[:min(8, len(imageData))])
|
||||
}
|
||||
t.Logf("Image format validated (PNG=%v, JPEG=%v)", isPNG, isJPEG)
|
||||
|
||||
// Cleanup: delete the model
|
||||
deleteReq := &api.DeleteRequest{Model: modelName}
|
||||
if err := client.Delete(ctx, deleteReq); err != nil {
|
||||
t.Logf("Warning: failed to delete test model: %v", err)
|
||||
}
|
||||
}
|
||||
350
integration/create_test.go
Normal file
350
integration/create_test.go
Normal file
@@ -0,0 +1,350 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const testdataModelsDir = "testdata/models"
|
||||
|
||||
// skipIfRemote skips the test if OLLAMA_HOST points to a non-local server.
|
||||
// Safetensors/imagegen creation requires localhost since it reads model files
|
||||
// from disk and uses the --experimental CLI path.
|
||||
func skipIfRemote(t *testing.T) {
|
||||
t.Helper()
|
||||
host := os.Getenv("OLLAMA_HOST")
|
||||
if host == "" {
|
||||
return // default is localhost
|
||||
}
|
||||
// Strip scheme if present
|
||||
_, hostport, ok := strings.Cut(host, "://")
|
||||
if !ok {
|
||||
hostport = host
|
||||
}
|
||||
h, _, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
h = hostport
|
||||
}
|
||||
if h == "" || h == "localhost" {
|
||||
return
|
||||
}
|
||||
ip := net.ParseIP(h)
|
||||
if ip != nil && (ip.IsLoopback() || ip.IsUnspecified()) {
|
||||
return
|
||||
}
|
||||
t.Skipf("safetensors/imagegen creation requires a local server (OLLAMA_HOST=%s)", host)
|
||||
}
|
||||
|
||||
// findHFCLI returns the path to the HuggingFace CLI, or "" if not found.
|
||||
func findHFCLI() string {
|
||||
for _, name := range []string{"huggingface-cli", "hf"} {
|
||||
if p, err := exec.LookPath(name); err == nil {
|
||||
return p
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// downloadHFModel idempotently downloads a HuggingFace model to destDir.
|
||||
// Skips the test if CLI is missing and model isn't already present.
|
||||
func downloadHFModel(t *testing.T, repo, destDir string, extraArgs ...string) {
|
||||
t.Helper()
|
||||
|
||||
// Check if model already exists
|
||||
if _, err := os.Stat(destDir); err == nil {
|
||||
entries, err := os.ReadDir(destDir)
|
||||
if err == nil && len(entries) > 0 {
|
||||
t.Logf("Model %s already present at %s", repo, destDir)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cli := findHFCLI()
|
||||
if cli == "" {
|
||||
t.Skipf("HuggingFace CLI not found and model %s not present at %s", repo, destDir)
|
||||
}
|
||||
|
||||
t.Logf("Downloading %s to %s", repo, destDir)
|
||||
os.MkdirAll(destDir, 0o755)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
args := []string{"download", repo, "--local-dir", destDir}
|
||||
args = append(args, extraArgs...)
|
||||
cmd := exec.CommandContext(ctx, cli, args...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Fatalf("Failed to download %s: %v", repo, err)
|
||||
}
|
||||
}
|
||||
|
||||
// ollamaBin returns the path to the ollama binary to use for tests.
|
||||
// Prefers OLLAMA_BIN env, then falls back to the built binary at ../ollama
|
||||
// (same binary the integration test server uses).
|
||||
func ollamaBin() string {
|
||||
if bin := os.Getenv("OLLAMA_BIN"); bin != "" {
|
||||
return bin
|
||||
}
|
||||
if abs, err := filepath.Abs("../ollama"); err == nil {
|
||||
if _, err := os.Stat(abs); err == nil {
|
||||
return abs
|
||||
}
|
||||
}
|
||||
return "ollama"
|
||||
}
|
||||
|
||||
// ensureMLXLibraryPath sets OLLAMA_LIBRARY_PATH so the MLX dynamic library
|
||||
// is discoverable. Integration tests run from integration/ dir, so the
|
||||
// default CWD-based search won't find the library at the repo root.
|
||||
func ensureMLXLibraryPath(t *testing.T) {
|
||||
t.Helper()
|
||||
if libPath, err := filepath.Abs("../build/lib/ollama"); err == nil {
|
||||
if _, err := os.Stat(libPath); err == nil {
|
||||
if existing := os.Getenv("OLLAMA_LIBRARY_PATH"); existing != "" {
|
||||
t.Setenv("OLLAMA_LIBRARY_PATH", existing+string(filepath.ListSeparator)+libPath)
|
||||
} else {
|
||||
t.Setenv("OLLAMA_LIBRARY_PATH", libPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runOllamaCreate runs "ollama create" as a subprocess. Skips the test if
|
||||
// the error indicates the server is remote.
|
||||
func runOllamaCreate(ctx context.Context, t *testing.T, args ...string) {
|
||||
t.Helper()
|
||||
createCmd := exec.CommandContext(ctx, ollamaBin(), append([]string{"create"}, args...)...)
|
||||
var createStderr strings.Builder
|
||||
createCmd.Stdout = os.Stdout
|
||||
createCmd.Stderr = io.MultiWriter(os.Stderr, &createStderr)
|
||||
if err := createCmd.Run(); err != nil {
|
||||
if strings.Contains(createStderr.String(), "remote") {
|
||||
t.Skip("safetensors creation requires a local server")
|
||||
}
|
||||
t.Fatalf("ollama create failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsLLM(t *testing.T) {
|
||||
skipIfRemote(t)
|
||||
|
||||
modelDir := filepath.Join(testdataModelsDir, "TinyLlama-1.1B")
|
||||
downloadHFModel(t, "TinyLlama/TinyLlama-1.1B-Chat-v1.0", modelDir)
|
||||
|
||||
ensureMLXLibraryPath(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
modelName := "test-tinyllama-safetensors"
|
||||
|
||||
absModelDir, err := filepath.Abs(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get absolute path: %v", err)
|
||||
}
|
||||
|
||||
// Create a Modelfile pointing to the model directory.
|
||||
// Include a chat template since the safetensors importer doesn't extract
|
||||
// chat_template from tokenizer_config.json yet.
|
||||
modelfileContent := "FROM " + absModelDir + "\n" +
|
||||
"TEMPLATE \"{{ if .System }}<|system|>\n{{ .System }}</s>\n{{ end }}" +
|
||||
"{{ if .Prompt }}<|user|>\n{{ .Prompt }}</s>\n{{ end }}" +
|
||||
"<|assistant|>\n{{ .Response }}</s>\n\"\n"
|
||||
tmpModelfile := filepath.Join(t.TempDir(), "Modelfile")
|
||||
if err := os.WriteFile(tmpModelfile, []byte(modelfileContent), 0o644); err != nil {
|
||||
t.Fatalf("Failed to write Modelfile: %v", err)
|
||||
}
|
||||
|
||||
runOllamaCreate(ctx, t, modelName, "--experimental", "-f", tmpModelfile)
|
||||
|
||||
// Verify model exists via show
|
||||
showReq := &api.ShowRequest{Name: modelName}
|
||||
showResp, err := client.Show(ctx, showReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Model show failed after create: %v", err)
|
||||
}
|
||||
t.Logf("Created model details: %+v", showResp.Details)
|
||||
|
||||
// Use the chat API for proper template application.
|
||||
chatReq := &api.ChatRequest{
|
||||
Model: modelName,
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Write a short sentence about the weather."},
|
||||
},
|
||||
Options: map[string]interface{}{
|
||||
"num_predict": 20,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
err = client.Chat(ctx, chatReq, func(resp api.ChatResponse) error {
|
||||
output.WriteString(resp.Message.Content)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat failed: %v", err)
|
||||
}
|
||||
|
||||
text := output.String()
|
||||
t.Logf("Generated output: %q", text)
|
||||
assertCoherentOutput(t, text)
|
||||
|
||||
// Cleanup: delete the model
|
||||
deleteReq := &api.DeleteRequest{Model: modelName}
|
||||
if err := client.Delete(ctx, deleteReq); err != nil {
|
||||
t.Logf("Warning: failed to delete test model: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateGGUF(t *testing.T) {
|
||||
modelDir := filepath.Join(testdataModelsDir, "Llama-3.2-1B-GGUF")
|
||||
downloadHFModel(t, "bartowski/Llama-3.2-1B-Instruct-GGUF", modelDir,
|
||||
"--include", "Llama-3.2-1B-Instruct-IQ3_M.gguf")
|
||||
|
||||
// Find the GGUF file
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read model dir: %v", err)
|
||||
}
|
||||
|
||||
var ggufPath string
|
||||
for _, e := range entries {
|
||||
if filepath.Ext(e.Name()) == ".gguf" {
|
||||
ggufPath = filepath.Join(modelDir, e.Name())
|
||||
break
|
||||
}
|
||||
}
|
||||
if ggufPath == "" {
|
||||
t.Skip("No GGUF file found in model directory")
|
||||
}
|
||||
|
||||
absGGUF, err := filepath.Abs(ggufPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get absolute path: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
modelName := "test-llama32-gguf"
|
||||
|
||||
// Create a Modelfile and use the CLI
|
||||
tmpModelfile := filepath.Join(t.TempDir(), "Modelfile")
|
||||
if err := os.WriteFile(tmpModelfile, []byte("FROM "+absGGUF+"\n"), 0o644); err != nil {
|
||||
t.Fatalf("Failed to write Modelfile: %v", err)
|
||||
}
|
||||
|
||||
createCmd := exec.CommandContext(ctx, ollamaBin(), "create", modelName, "-f", tmpModelfile)
|
||||
createCmd.Stdout = os.Stdout
|
||||
createCmd.Stderr = os.Stderr
|
||||
if err := createCmd.Run(); err != nil {
|
||||
t.Fatalf("ollama create failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify model exists
|
||||
showReq := &api.ShowRequest{Name: modelName}
|
||||
_, err = client.Show(ctx, showReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Model show failed after create: %v", err)
|
||||
}
|
||||
|
||||
// Generate and verify output is coherent
|
||||
genReq := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: "Write a short sentence about the weather.",
|
||||
Options: map[string]interface{}{
|
||||
"num_predict": 20,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
err = client.Generate(ctx, genReq, func(resp api.GenerateResponse) error {
|
||||
output.WriteString(resp.Response)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate failed: %v", err)
|
||||
}
|
||||
|
||||
text := output.String()
|
||||
t.Logf("Generated output: %q", text)
|
||||
assertCoherentOutput(t, text)
|
||||
|
||||
// Cleanup
|
||||
deleteReq := &api.DeleteRequest{Model: modelName}
|
||||
if err := client.Delete(ctx, deleteReq); err != nil {
|
||||
t.Logf("Warning: failed to delete test model: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// assertCoherentOutput checks that model output looks like real language, not
|
||||
// garbled binary or repeated garbage. This catches corrupted model creation
|
||||
// where inference "works" but produces nonsense.
|
||||
func assertCoherentOutput(t *testing.T, text string) {
|
||||
t.Helper()
|
||||
|
||||
if len(text) == 0 {
|
||||
t.Fatal("model produced empty output")
|
||||
}
|
||||
|
||||
// Check minimum length — 20 tokens should produce at least a few words
|
||||
if len(text) < 5 {
|
||||
t.Fatalf("model output suspiciously short (%d bytes): %q", len(text), text)
|
||||
}
|
||||
|
||||
// Check for mostly-printable ASCII/Unicode — garbled models often emit
|
||||
// high ratios of control characters or replacement characters
|
||||
unprintable := 0
|
||||
for _, r := range text {
|
||||
if r < 0x20 && r != '\n' && r != '\r' && r != '\t' {
|
||||
unprintable++
|
||||
}
|
||||
if r == '\ufffd' { // Unicode replacement character
|
||||
unprintable++
|
||||
}
|
||||
}
|
||||
ratio := float64(unprintable) / float64(len([]rune(text)))
|
||||
if ratio > 0.3 {
|
||||
t.Fatalf("model output is %.0f%% unprintable characters (likely garbled): %q", ratio*100, text)
|
||||
}
|
||||
|
||||
// Check it contains at least one space — real language has word boundaries
|
||||
if !strings.Contains(text, " ") {
|
||||
t.Fatalf("model output contains no spaces (likely garbled): %q", text)
|
||||
}
|
||||
|
||||
// Check for excessive repetition — a broken model might repeat one token
|
||||
words := strings.Fields(text)
|
||||
if len(words) >= 4 {
|
||||
counts := map[string]int{}
|
||||
for _, w := range words {
|
||||
counts[strings.ToLower(w)]++
|
||||
}
|
||||
for w, c := range counts {
|
||||
if c > len(words)*3/4 {
|
||||
t.Fatalf("model output is excessively repetitive (%q appears %d/%d times): %q", w, c, len(words), text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -141,7 +141,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "" || len(config.Capabilities) == 0) {
|
||||
if err == nil && !remote {
|
||||
mf, mErr := manifest.ParseNamedManifest(fromName)
|
||||
if mErr == nil && mf.Config.Digest != "" {
|
||||
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
|
||||
@@ -158,6 +158,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
if config.Requires == "" {
|
||||
config.Requires = baseConfig.Requires
|
||||
}
|
||||
if config.ModelFormat == "" {
|
||||
config.ModelFormat = baseConfig.ModelFormat
|
||||
}
|
||||
if len(config.Capabilities) == 0 {
|
||||
config.Capabilities = baseConfig.Capabilities
|
||||
}
|
||||
|
||||
@@ -41,11 +41,12 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
|
||||
for _, srcLayer := range m.Layers {
|
||||
layer, err := manifest.NewLayerFromLayer(srcLayer.Digest, srcLayer.MediaType, name.DisplayShortest())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer.Name = srcLayer.Name
|
||||
|
||||
switch layer.MediaType {
|
||||
case "application/vnd.ollama.image.model",
|
||||
|
||||
@@ -1024,3 +1024,272 @@ func TestDetectModelTypeFromFiles(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// createTestBlob creates a blob in the blobs directory and returns its digest.
|
||||
func createTestBlob(t *testing.T, data []byte) string {
|
||||
t.Helper()
|
||||
digest := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
|
||||
blobPath, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(blobPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(blobPath, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return digest
|
||||
}
|
||||
|
||||
// createSafetensorsTestModel creates a minimal safetensors model manifest for testing.
|
||||
func createSafetensorsTestModel(t *testing.T, modelName string, config model.ConfigV2, extraLayers []manifest.Layer) {
|
||||
t.Helper()
|
||||
|
||||
// Create a fake tensor blob
|
||||
tensorData := []byte("fake-tensor-data-for-testing")
|
||||
tensorDigest := createTestBlob(t, tensorData)
|
||||
|
||||
layers := []manifest.Layer{
|
||||
{
|
||||
MediaType: manifest.MediaTypeImageTensor,
|
||||
Digest: tensorDigest,
|
||||
Size: int64(len(tensorData)),
|
||||
Name: "model.embed_tokens.weight",
|
||||
},
|
||||
}
|
||||
layers = append(layers, extraLayers...)
|
||||
|
||||
configLayer, err := createConfigLayer(layers, config)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config layer: %v", err)
|
||||
}
|
||||
|
||||
name := model.ParseName(modelName)
|
||||
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
|
||||
t.Fatalf("failed to write manifest: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFromSafetensorsModel_PreservesConfig(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
var s Server
|
||||
|
||||
// Create a source safetensors model with specific config fields
|
||||
createSafetensorsTestModel(t, "source-model", model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"completion"},
|
||||
Requires: "0.14.0",
|
||||
Renderer: "gemma3",
|
||||
Parser: "gemma3",
|
||||
}, nil)
|
||||
|
||||
// Create a derived model FROM the source
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "derived-model",
|
||||
From: "source-model",
|
||||
System: "You are a pirate.",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Read the derived model's config
|
||||
derivedName := model.ParseName("derived-model")
|
||||
mf, err := manifest.ParseNamedManifest(derivedName)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse derived manifest: %v", err)
|
||||
}
|
||||
|
||||
configBlobPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get config blob path: %v", err)
|
||||
}
|
||||
|
||||
configBlob, err := os.ReadFile(configBlobPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read config blob: %v", err)
|
||||
}
|
||||
|
||||
var cfg model.ConfigV2
|
||||
if err := json.Unmarshal(configBlob, &cfg); err != nil {
|
||||
t.Fatalf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
// Verify safetensors-specific config fields are preserved
|
||||
if cfg.ModelFormat != "safetensors" {
|
||||
t.Errorf("ModelFormat = %q, want %q", cfg.ModelFormat, "safetensors")
|
||||
}
|
||||
|
||||
if !slices.Contains(cfg.Capabilities, "completion") {
|
||||
t.Errorf("Capabilities = %v, want to contain %q", cfg.Capabilities, "completion")
|
||||
}
|
||||
|
||||
if cfg.Requires != "0.14.0" {
|
||||
t.Errorf("Requires = %q, want %q", cfg.Requires, "0.14.0")
|
||||
}
|
||||
|
||||
if cfg.Renderer != "gemma3" {
|
||||
t.Errorf("Renderer = %q, want %q", cfg.Renderer, "gemma3")
|
||||
}
|
||||
|
||||
if cfg.Parser != "gemma3" {
|
||||
t.Errorf("Parser = %q, want %q", cfg.Parser, "gemma3")
|
||||
}
|
||||
|
||||
// Verify system prompt was added
|
||||
var hasSystem bool
|
||||
for _, l := range mf.Layers {
|
||||
if l.MediaType == "application/vnd.ollama.image.system" {
|
||||
hasSystem = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasSystem {
|
||||
t.Error("expected system prompt layer in derived model")
|
||||
}
|
||||
|
||||
// Verify tensor layers were copied with names preserved
|
||||
var tensorNames []string
|
||||
for _, l := range mf.Layers {
|
||||
if l.MediaType == manifest.MediaTypeImageTensor {
|
||||
tensorNames = append(tensorNames, l.Name)
|
||||
}
|
||||
}
|
||||
if len(tensorNames) == 0 {
|
||||
t.Error("expected tensor layers in derived model")
|
||||
}
|
||||
for _, name := range tensorNames {
|
||||
if name == "" {
|
||||
t.Error("tensor layer has empty name — names must be preserved from source")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFromSafetensorsModel_OverrideSystem(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
var s Server
|
||||
|
||||
// Create source with a system prompt
|
||||
createSafetensorsTestModel(t, "source-with-system", model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"completion"},
|
||||
}, nil)
|
||||
|
||||
// First create with system prompt
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "source-with-system",
|
||||
From: "source-with-system",
|
||||
System: "Original system prompt",
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Now create a derived model with a different system prompt
|
||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "derived-new-system",
|
||||
From: "source-with-system",
|
||||
System: "New system prompt",
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify ModelFormat is preserved even after override
|
||||
derivedName := model.ParseName("derived-new-system")
|
||||
mf, err := manifest.ParseNamedManifest(derivedName)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse derived manifest: %v", err)
|
||||
}
|
||||
|
||||
configBlobPath, _ := manifest.BlobsPath(mf.Config.Digest)
|
||||
configBlob, _ := os.ReadFile(configBlobPath)
|
||||
|
||||
var cfg model.ConfigV2
|
||||
json.Unmarshal(configBlob, &cfg)
|
||||
|
||||
if cfg.ModelFormat != "safetensors" {
|
||||
t.Errorf("ModelFormat = %q, want %q", cfg.ModelFormat, "safetensors")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFromSafetensorsModel_PreservesLayerNames(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
var s Server
|
||||
|
||||
// Create JSON config blobs to include as layers
|
||||
configJSON := []byte(`{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}`)
|
||||
configDigest := createTestBlob(t, configJSON)
|
||||
tokenizerJSON := []byte(`{"version": "1.0"}`)
|
||||
tokenizerDigest := createTestBlob(t, tokenizerJSON)
|
||||
|
||||
extraLayers := []manifest.Layer{
|
||||
{
|
||||
MediaType: "application/vnd.ollama.image.json",
|
||||
Digest: configDigest,
|
||||
Size: int64(len(configJSON)),
|
||||
Name: "config.json",
|
||||
},
|
||||
{
|
||||
MediaType: "application/vnd.ollama.image.json",
|
||||
Digest: tokenizerDigest,
|
||||
Size: int64(len(tokenizerJSON)),
|
||||
Name: "tokenizer.json",
|
||||
},
|
||||
}
|
||||
|
||||
createSafetensorsTestModel(t, "source-named-layers", model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"completion"},
|
||||
}, extraLayers)
|
||||
|
||||
// Create derived model
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "derived-named-layers",
|
||||
From: "source-named-layers",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
derivedName := model.ParseName("derived-named-layers")
|
||||
mf, err := manifest.ParseNamedManifest(derivedName)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse derived manifest: %v", err)
|
||||
}
|
||||
|
||||
// Check tensor layer names are preserved
|
||||
for _, l := range mf.Layers {
|
||||
if l.MediaType == manifest.MediaTypeImageTensor && l.Name == "" {
|
||||
t.Error("tensor layer has empty name — names must be preserved from source")
|
||||
}
|
||||
}
|
||||
|
||||
// Check JSON layer names are preserved
|
||||
jsonNames := make(map[string]bool)
|
||||
for _, l := range mf.Layers {
|
||||
if l.MediaType == "application/vnd.ollama.image.json" && l.Name != "" {
|
||||
jsonNames[l.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
if !jsonNames["config.json"] {
|
||||
t.Error("config.json layer name not preserved in derived model")
|
||||
}
|
||||
if !jsonNames["tokenizer.json"] {
|
||||
t.Error("tokenizer.json layer name not preserved in derived model")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
|
||||
|
||||
@@ -429,3 +429,159 @@ func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) {
|
||||
t.Fatalf("FileType = %q, want %q", cfg.FileType, "mxfp8")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsThinking(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configJSON string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "qwen3 architecture",
|
||||
configJSON: `{"architectures": ["Qwen3ForCausalLM"], "model_type": "qwen3"}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "deepseek architecture",
|
||||
configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "glm4moe architecture",
|
||||
configJSON: `{"architectures": ["GLM4MoeForCausalLM"]}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "llama architecture (no thinking)",
|
||||
configJSON: `{"architectures": ["LlamaForCausalLM"], "model_type": "llama"}`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "gemma architecture (no thinking)",
|
||||
configJSON: `{"architectures": ["Gemma3ForCausalLM"], "model_type": "gemma3"}`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "model_type only",
|
||||
configJSON: `{"model_type": "deepseek"}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty config",
|
||||
configJSON: `{}`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
configJSON: `not json`,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644)
|
||||
|
||||
if got := supportsThinking(dir); got != tt.want {
|
||||
t.Errorf("supportsThinking() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsThinking_NoConfig(t *testing.T) {
|
||||
if supportsThinking(t.TempDir()) {
|
||||
t.Error("supportsThinking should return false for missing config.json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetParserName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configJSON string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "qwen3 model",
|
||||
configJSON: `{"architectures": ["Qwen3ForCausalLM"]}`,
|
||||
want: "qwen3",
|
||||
},
|
||||
{
|
||||
name: "deepseek model",
|
||||
configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`,
|
||||
want: "deepseek3",
|
||||
},
|
||||
{
|
||||
name: "glm4 model",
|
||||
configJSON: `{"architectures": ["GLM4ForCausalLM"]}`,
|
||||
want: "glm-4.7",
|
||||
},
|
||||
{
|
||||
name: "llama model (no parser)",
|
||||
configJSON: `{"architectures": ["LlamaForCausalLM"]}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "qwen3 via model_type",
|
||||
configJSON: `{"model_type": "qwen3"}`,
|
||||
want: "qwen3",
|
||||
},
|
||||
{
|
||||
name: "no config",
|
||||
configJSON: `{}`,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644)
|
||||
|
||||
if got := getParserName(dir); got != tt.want {
|
||||
t.Errorf("getParserName() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRendererName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configJSON string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "qwen3 model",
|
||||
configJSON: `{"architectures": ["Qwen3ForCausalLM"]}`,
|
||||
want: "qwen3-coder",
|
||||
},
|
||||
{
|
||||
name: "deepseek model",
|
||||
configJSON: `{"architectures": ["DeepseekV3ForCausalLM"]}`,
|
||||
want: "deepseek3",
|
||||
},
|
||||
{
|
||||
name: "glm4 model",
|
||||
configJSON: `{"architectures": ["GLM4ForCausalLM"]}`,
|
||||
want: "glm-4.7",
|
||||
},
|
||||
{
|
||||
name: "llama model (no renderer)",
|
||||
configJSON: `{"architectures": ["LlamaForCausalLM"]}`,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644)
|
||||
|
||||
if got := getRendererName(dir); got != tt.want {
|
||||
t.Errorf("getRendererName() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
// ModelConfig represents the config blob stored with a model.
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
st "github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
st "github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
func TestIsTensorModelDir(t *testing.T) {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
// CreateImageGenModel imports an image generation model from a directory.
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
type qwen35ImportTransform struct {
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
)
|
||||
|
||||
// tensorInfo holds tensor metadata from safetensors headers.
|
||||
// This avoids depending on safetensors.go which requires the mlx tag.
|
||||
type tensorInfo struct {
|
||||
Dtype string `json:"dtype"`
|
||||
Shape []int32 `json:"shape"`
|
||||
394
x/safetensors/extractor_test.go
Normal file
394
x/safetensors/extractor_test.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// createTestSafetensors creates a minimal valid safetensors file with the given tensors.
|
||||
func createTestSafetensors(t *testing.T, path string, tensors map[string]struct {
|
||||
dtype string
|
||||
shape []int32
|
||||
data []byte
|
||||
},
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
header := make(map[string]tensorInfo)
|
||||
var offset int
|
||||
var allData []byte
|
||||
|
||||
// Sort names for deterministic file layout
|
||||
names := make([]string, 0, len(tensors))
|
||||
for name := range tensors {
|
||||
names = append(names, name)
|
||||
}
|
||||
slices.Sort(names)
|
||||
|
||||
for _, name := range names {
|
||||
info := tensors[name]
|
||||
header[name] = tensorInfo{
|
||||
Dtype: info.dtype,
|
||||
Shape: info.shape,
|
||||
DataOffsets: [2]int{offset, offset + len(info.data)},
|
||||
}
|
||||
allData = append(allData, info.data...)
|
||||
offset += len(info.data)
|
||||
}
|
||||
|
||||
headerJSON, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal header: %v", err)
|
||||
}
|
||||
|
||||
// Pad to 8-byte alignment
|
||||
padding := (8 - len(headerJSON)%8) % 8
|
||||
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create file: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
|
||||
t.Fatalf("failed to write header size: %v", err)
|
||||
}
|
||||
if _, err := f.Write(headerJSON); err != nil {
|
||||
t.Fatalf("failed to write header: %v", err)
|
||||
}
|
||||
if _, err := f.Write(allData); err != nil {
|
||||
t.Fatalf("failed to write data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenForExtraction(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.safetensors")
|
||||
|
||||
// 4 float32 values = 16 bytes
|
||||
data := make([]byte, 16)
|
||||
binary.LittleEndian.PutUint32(data[0:4], 0x3f800000) // 1.0
|
||||
binary.LittleEndian.PutUint32(data[4:8], 0x40000000) // 2.0
|
||||
binary.LittleEndian.PutUint32(data[8:12], 0x40400000) // 3.0
|
||||
binary.LittleEndian.PutUint32(data[12:16], 0x40800000) // 4.0
|
||||
|
||||
createTestSafetensors(t, path, map[string]struct {
|
||||
dtype string
|
||||
shape []int32
|
||||
data []byte
|
||||
}{
|
||||
"test_tensor": {dtype: "F32", shape: []int32{2, 2}, data: data},
|
||||
})
|
||||
|
||||
ext, err := OpenForExtraction(path)
|
||||
if err != nil {
|
||||
t.Fatalf("OpenForExtraction failed: %v", err)
|
||||
}
|
||||
defer ext.Close()
|
||||
|
||||
if ext.TensorCount() != 1 {
|
||||
t.Errorf("TensorCount() = %d, want 1", ext.TensorCount())
|
||||
}
|
||||
|
||||
names := ext.ListTensors()
|
||||
if len(names) != 1 || names[0] != "test_tensor" {
|
||||
t.Errorf("ListTensors() = %v, want [test_tensor]", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensor(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.safetensors")
|
||||
|
||||
data := make([]byte, 16)
|
||||
for i := range 4 {
|
||||
binary.LittleEndian.PutUint32(data[i*4:], uint32(i+1))
|
||||
}
|
||||
|
||||
createTestSafetensors(t, path, map[string]struct {
|
||||
dtype string
|
||||
shape []int32
|
||||
data []byte
|
||||
}{
|
||||
"weight": {dtype: "F32", shape: []int32{2, 2}, data: data},
|
||||
})
|
||||
|
||||
ext, err := OpenForExtraction(path)
|
||||
if err != nil {
|
||||
t.Fatalf("OpenForExtraction failed: %v", err)
|
||||
}
|
||||
defer ext.Close()
|
||||
|
||||
td, err := ext.GetTensor("weight")
|
||||
if err != nil {
|
||||
t.Fatalf("GetTensor failed: %v", err)
|
||||
}
|
||||
|
||||
if td.Name != "weight" {
|
||||
t.Errorf("Name = %q, want %q", td.Name, "weight")
|
||||
}
|
||||
if td.Dtype != "F32" {
|
||||
t.Errorf("Dtype = %q, want %q", td.Dtype, "F32")
|
||||
}
|
||||
if td.Size != 16 {
|
||||
t.Errorf("Size = %d, want 16", td.Size)
|
||||
}
|
||||
if len(td.Shape) != 2 || td.Shape[0] != 2 || td.Shape[1] != 2 {
|
||||
t.Errorf("Shape = %v, want [2 2]", td.Shape)
|
||||
}
|
||||
|
||||
// Read the raw data
|
||||
rawData, err := io.ReadAll(td.Reader())
|
||||
if err != nil {
|
||||
t.Fatalf("Reader() read failed: %v", err)
|
||||
}
|
||||
if len(rawData) != 16 {
|
||||
t.Errorf("raw data length = %d, want 16", len(rawData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensor_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.safetensors")
|
||||
|
||||
createTestSafetensors(t, path, map[string]struct {
|
||||
dtype string
|
||||
shape []int32
|
||||
data []byte
|
||||
}{
|
||||
"exists": {dtype: "F32", shape: []int32{1}, data: make([]byte, 4)},
|
||||
})
|
||||
|
||||
ext, err := OpenForExtraction(path)
|
||||
if err != nil {
|
||||
t.Fatalf("OpenForExtraction failed: %v", err)
|
||||
}
|
||||
defer ext.Close()
|
||||
|
||||
_, err = ext.GetTensor("missing")
|
||||
if err == nil {
|
||||
t.Error("expected error for missing tensor, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafetensorsReaderRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.safetensors")
|
||||
|
||||
data := make([]byte, 16)
|
||||
for i := range 4 {
|
||||
binary.LittleEndian.PutUint32(data[i*4:], uint32(0x3f800000+i))
|
||||
}
|
||||
|
||||
createTestSafetensors(t, path, map[string]struct {
|
||||
dtype string
|
||||
shape []int32
|
||||
data []byte
|
||||
}{
|
||||
"tensor_a": {dtype: "F32", shape: []int32{2, 2}, data: data},
|
||||
})
|
||||
|
||||
ext, err := OpenForExtraction(path)
|
||||
if err != nil {
|
||||
t.Fatalf("OpenForExtraction failed: %v", err)
|
||||
}
|
||||
defer ext.Close()
|
||||
|
||||
td, err := ext.GetTensor("tensor_a")
|
||||
if err != nil {
|
||||
t.Fatalf("GetTensor failed: %v", err)
|
||||
}
|
||||
|
||||
// Wrap as safetensors and extract back
|
||||
stReader := td.SafetensorsReader()
|
||||
stData, err := io.ReadAll(stReader)
|
||||
if err != nil {
|
||||
t.Fatalf("SafetensorsReader read failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify size
|
||||
if int64(len(stData)) != td.SafetensorsSize() {
|
||||
t.Errorf("SafetensorsSize() = %d, actual = %d", td.SafetensorsSize(), len(stData))
|
||||
}
|
||||
|
||||
// Extract raw data back
|
||||
raw, err := ExtractRawFromSafetensors(bytes.NewReader(stData))
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractRawFromSafetensors failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(raw, data) {
|
||||
t.Errorf("round-trip data mismatch: got %v, want %v", raw, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTensorDataFromBytes(t *testing.T) {
|
||||
data := []byte{1, 2, 3, 4}
|
||||
td := NewTensorDataFromBytes("test", "U8", []int32{4}, data)
|
||||
|
||||
if td.Name != "test" {
|
||||
t.Errorf("Name = %q, want %q", td.Name, "test")
|
||||
}
|
||||
if td.Size != 4 {
|
||||
t.Errorf("Size = %d, want 4", td.Size)
|
||||
}
|
||||
|
||||
rawData, err := io.ReadAll(td.Reader())
|
||||
if err != nil {
|
||||
t.Fatalf("Reader() failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(rawData, data) {
|
||||
t.Errorf("data mismatch: got %v, want %v", rawData, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPackedSafetensorsReader(t *testing.T) {
|
||||
data1 := []byte{1, 2, 3, 4}
|
||||
data2 := []byte{5, 6, 7, 8, 9, 10, 11, 12}
|
||||
|
||||
td1 := NewTensorDataFromBytes("a", "U8", []int32{4}, data1)
|
||||
td2 := NewTensorDataFromBytes("b", "U8", []int32{8}, data2)
|
||||
|
||||
packed := BuildPackedSafetensorsReader([]*TensorData{td1, td2})
|
||||
packedBytes, err := io.ReadAll(packed)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPackedSafetensorsReader read failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it's a valid safetensors file by parsing the header
|
||||
var headerSize uint64
|
||||
if err := binary.Read(bytes.NewReader(packedBytes), binary.LittleEndian, &headerSize); err != nil {
|
||||
t.Fatalf("failed to read header size: %v", err)
|
||||
}
|
||||
|
||||
headerJSON := packedBytes[8 : 8+headerSize]
|
||||
var header map[string]tensorInfo
|
||||
if err := json.Unmarshal(headerJSON, &header); err != nil {
|
||||
t.Fatalf("failed to parse header: %v", err)
|
||||
}
|
||||
|
||||
if len(header) != 2 {
|
||||
t.Errorf("header has %d entries, want 2", len(header))
|
||||
}
|
||||
|
||||
infoA, ok := header["a"]
|
||||
if !ok {
|
||||
t.Fatal("tensor 'a' not found in header")
|
||||
}
|
||||
if infoA.Dtype != "U8" {
|
||||
t.Errorf("tensor 'a' dtype = %q, want %q", infoA.Dtype, "U8")
|
||||
}
|
||||
|
||||
infoB, ok := header["b"]
|
||||
if !ok {
|
||||
t.Fatal("tensor 'b' not found in header")
|
||||
}
|
||||
|
||||
// Verify data region contains both tensors
|
||||
dataStart := 8 + int(headerSize)
|
||||
dataRegion := packedBytes[dataStart:]
|
||||
if infoA.DataOffsets[0] == 0 {
|
||||
// a comes first
|
||||
if !bytes.Equal(dataRegion[:4], data1) {
|
||||
t.Error("tensor 'a' data mismatch")
|
||||
}
|
||||
if !bytes.Equal(dataRegion[infoB.DataOffsets[0]:infoB.DataOffsets[1]], data2) {
|
||||
t.Error("tensor 'b' data mismatch")
|
||||
}
|
||||
} else {
|
||||
// b comes first
|
||||
if !bytes.Equal(dataRegion[:8], data2) {
|
||||
t.Error("tensor 'b' data mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractAll(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.safetensors")
|
||||
|
||||
createTestSafetensors(t, path, map[string]struct {
|
||||
dtype string
|
||||
shape []int32
|
||||
data []byte
|
||||
}{
|
||||
"alpha": {dtype: "F32", shape: []int32{2}, data: make([]byte, 8)},
|
||||
"beta": {dtype: "F16", shape: []int32{4}, data: make([]byte, 8)},
|
||||
})
|
||||
|
||||
ext, err := OpenForExtraction(path)
|
||||
if err != nil {
|
||||
t.Fatalf("OpenForExtraction failed: %v", err)
|
||||
}
|
||||
defer ext.Close()
|
||||
|
||||
tensors, err := ext.ExtractAll()
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractAll failed: %v", err)
|
||||
}
|
||||
|
||||
if len(tensors) != 2 {
|
||||
t.Errorf("ExtractAll returned %d tensors, want 2", len(tensors))
|
||||
}
|
||||
|
||||
// Verify sorted order
|
||||
if tensors[0].Name != "alpha" || tensors[1].Name != "beta" {
|
||||
t.Errorf("tensors not in sorted order: %s, %s", tensors[0].Name, tensors[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRawFromSafetensors_InvalidInput(t *testing.T) {
|
||||
// Empty reader
|
||||
_, err := ExtractRawFromSafetensors(bytes.NewReader(nil))
|
||||
if err == nil {
|
||||
t.Error("expected error for empty reader")
|
||||
}
|
||||
|
||||
// Truncated header size
|
||||
_, err = ExtractRawFromSafetensors(bytes.NewReader([]byte{1, 2, 3}))
|
||||
if err == nil {
|
||||
t.Error("expected error for truncated header size")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenForExtraction_MetadataIgnored(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.safetensors")
|
||||
|
||||
// Manually create a safetensors file with __metadata__
|
||||
header := map[string]any{
|
||||
"__metadata__": map[string]string{"format": "pt"},
|
||||
"weight": tensorInfo{
|
||||
Dtype: "F32",
|
||||
Shape: []int32{2},
|
||||
DataOffsets: [2]int{0, 8},
|
||||
},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
padding := (8 - len(headerJSON)%8) % 8
|
||||
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
|
||||
|
||||
f, _ := os.Create(path)
|
||||
binary.Write(f, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
f.Write(headerJSON)
|
||||
f.Write(make([]byte, 8))
|
||||
f.Close()
|
||||
|
||||
ext, err := OpenForExtraction(path)
|
||||
if err != nil {
|
||||
t.Fatalf("OpenForExtraction failed: %v", err)
|
||||
}
|
||||
defer ext.Close()
|
||||
|
||||
// __metadata__ should be stripped
|
||||
if ext.TensorCount() != 1 {
|
||||
t.Errorf("TensorCount() = %d, want 1 (metadata should be stripped)", ext.TensorCount())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user