mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* create: Clean up experimental paths
This cleans up the experimental features and adds both unit and integration test coverage to guard against regressions.
* create: preserve config and layer names when creating from safetensors models
When creating a model FROM an existing safetensors model, ModelFormat,
Capabilities, and layer Name fields were lost. ModelFormat stayed empty
because it's only set from GGML layers (which safetensors models lack),
and layer names weren't copied in parseFromModel. This caused derived
models to fail loading ("config.json not found in manifest").
* review comments
395 lines
9.6 KiB
Go
395 lines
9.6 KiB
Go
package safetensors
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"testing"
|
|
)
|
|
|
|
// createTestSafetensors creates a minimal valid safetensors file with the given tensors.
|
|
func createTestSafetensors(t *testing.T, path string, tensors map[string]struct {
|
|
dtype string
|
|
shape []int32
|
|
data []byte
|
|
},
|
|
) {
|
|
t.Helper()
|
|
|
|
header := make(map[string]tensorInfo)
|
|
var offset int
|
|
var allData []byte
|
|
|
|
// Sort names for deterministic file layout
|
|
names := make([]string, 0, len(tensors))
|
|
for name := range tensors {
|
|
names = append(names, name)
|
|
}
|
|
slices.Sort(names)
|
|
|
|
for _, name := range names {
|
|
info := tensors[name]
|
|
header[name] = tensorInfo{
|
|
Dtype: info.dtype,
|
|
Shape: info.shape,
|
|
DataOffsets: [2]int{offset, offset + len(info.data)},
|
|
}
|
|
allData = append(allData, info.data...)
|
|
offset += len(info.data)
|
|
}
|
|
|
|
headerJSON, err := json.Marshal(header)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal header: %v", err)
|
|
}
|
|
|
|
// Pad to 8-byte alignment
|
|
padding := (8 - len(headerJSON)%8) % 8
|
|
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
|
|
|
|
f, err := os.Create(path)
|
|
if err != nil {
|
|
t.Fatalf("failed to create file: %v", err)
|
|
}
|
|
defer f.Close()
|
|
|
|
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
|
|
t.Fatalf("failed to write header size: %v", err)
|
|
}
|
|
if _, err := f.Write(headerJSON); err != nil {
|
|
t.Fatalf("failed to write header: %v", err)
|
|
}
|
|
if _, err := f.Write(allData); err != nil {
|
|
t.Fatalf("failed to write data: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestOpenForExtraction(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "test.safetensors")
|
|
|
|
// 4 float32 values = 16 bytes
|
|
data := make([]byte, 16)
|
|
binary.LittleEndian.PutUint32(data[0:4], 0x3f800000) // 1.0
|
|
binary.LittleEndian.PutUint32(data[4:8], 0x40000000) // 2.0
|
|
binary.LittleEndian.PutUint32(data[8:12], 0x40400000) // 3.0
|
|
binary.LittleEndian.PutUint32(data[12:16], 0x40800000) // 4.0
|
|
|
|
createTestSafetensors(t, path, map[string]struct {
|
|
dtype string
|
|
shape []int32
|
|
data []byte
|
|
}{
|
|
"test_tensor": {dtype: "F32", shape: []int32{2, 2}, data: data},
|
|
})
|
|
|
|
ext, err := OpenForExtraction(path)
|
|
if err != nil {
|
|
t.Fatalf("OpenForExtraction failed: %v", err)
|
|
}
|
|
defer ext.Close()
|
|
|
|
if ext.TensorCount() != 1 {
|
|
t.Errorf("TensorCount() = %d, want 1", ext.TensorCount())
|
|
}
|
|
|
|
names := ext.ListTensors()
|
|
if len(names) != 1 || names[0] != "test_tensor" {
|
|
t.Errorf("ListTensors() = %v, want [test_tensor]", names)
|
|
}
|
|
}
|
|
|
|
func TestGetTensor(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "test.safetensors")
|
|
|
|
data := make([]byte, 16)
|
|
for i := range 4 {
|
|
binary.LittleEndian.PutUint32(data[i*4:], uint32(i+1))
|
|
}
|
|
|
|
createTestSafetensors(t, path, map[string]struct {
|
|
dtype string
|
|
shape []int32
|
|
data []byte
|
|
}{
|
|
"weight": {dtype: "F32", shape: []int32{2, 2}, data: data},
|
|
})
|
|
|
|
ext, err := OpenForExtraction(path)
|
|
if err != nil {
|
|
t.Fatalf("OpenForExtraction failed: %v", err)
|
|
}
|
|
defer ext.Close()
|
|
|
|
td, err := ext.GetTensor("weight")
|
|
if err != nil {
|
|
t.Fatalf("GetTensor failed: %v", err)
|
|
}
|
|
|
|
if td.Name != "weight" {
|
|
t.Errorf("Name = %q, want %q", td.Name, "weight")
|
|
}
|
|
if td.Dtype != "F32" {
|
|
t.Errorf("Dtype = %q, want %q", td.Dtype, "F32")
|
|
}
|
|
if td.Size != 16 {
|
|
t.Errorf("Size = %d, want 16", td.Size)
|
|
}
|
|
if len(td.Shape) != 2 || td.Shape[0] != 2 || td.Shape[1] != 2 {
|
|
t.Errorf("Shape = %v, want [2 2]", td.Shape)
|
|
}
|
|
|
|
// Read the raw data
|
|
rawData, err := io.ReadAll(td.Reader())
|
|
if err != nil {
|
|
t.Fatalf("Reader() read failed: %v", err)
|
|
}
|
|
if len(rawData) != 16 {
|
|
t.Errorf("raw data length = %d, want 16", len(rawData))
|
|
}
|
|
}
|
|
|
|
func TestGetTensor_NotFound(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "test.safetensors")
|
|
|
|
createTestSafetensors(t, path, map[string]struct {
|
|
dtype string
|
|
shape []int32
|
|
data []byte
|
|
}{
|
|
"exists": {dtype: "F32", shape: []int32{1}, data: make([]byte, 4)},
|
|
})
|
|
|
|
ext, err := OpenForExtraction(path)
|
|
if err != nil {
|
|
t.Fatalf("OpenForExtraction failed: %v", err)
|
|
}
|
|
defer ext.Close()
|
|
|
|
_, err = ext.GetTensor("missing")
|
|
if err == nil {
|
|
t.Error("expected error for missing tensor, got nil")
|
|
}
|
|
}
|
|
|
|
func TestSafetensorsReaderRoundTrip(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "test.safetensors")
|
|
|
|
data := make([]byte, 16)
|
|
for i := range 4 {
|
|
binary.LittleEndian.PutUint32(data[i*4:], uint32(0x3f800000+i))
|
|
}
|
|
|
|
createTestSafetensors(t, path, map[string]struct {
|
|
dtype string
|
|
shape []int32
|
|
data []byte
|
|
}{
|
|
"tensor_a": {dtype: "F32", shape: []int32{2, 2}, data: data},
|
|
})
|
|
|
|
ext, err := OpenForExtraction(path)
|
|
if err != nil {
|
|
t.Fatalf("OpenForExtraction failed: %v", err)
|
|
}
|
|
defer ext.Close()
|
|
|
|
td, err := ext.GetTensor("tensor_a")
|
|
if err != nil {
|
|
t.Fatalf("GetTensor failed: %v", err)
|
|
}
|
|
|
|
// Wrap as safetensors and extract back
|
|
stReader := td.SafetensorsReader()
|
|
stData, err := io.ReadAll(stReader)
|
|
if err != nil {
|
|
t.Fatalf("SafetensorsReader read failed: %v", err)
|
|
}
|
|
|
|
// Verify size
|
|
if int64(len(stData)) != td.SafetensorsSize() {
|
|
t.Errorf("SafetensorsSize() = %d, actual = %d", td.SafetensorsSize(), len(stData))
|
|
}
|
|
|
|
// Extract raw data back
|
|
raw, err := ExtractRawFromSafetensors(bytes.NewReader(stData))
|
|
if err != nil {
|
|
t.Fatalf("ExtractRawFromSafetensors failed: %v", err)
|
|
}
|
|
|
|
if !bytes.Equal(raw, data) {
|
|
t.Errorf("round-trip data mismatch: got %v, want %v", raw, data)
|
|
}
|
|
}
|
|
|
|
func TestNewTensorDataFromBytes(t *testing.T) {
|
|
data := []byte{1, 2, 3, 4}
|
|
td := NewTensorDataFromBytes("test", "U8", []int32{4}, data)
|
|
|
|
if td.Name != "test" {
|
|
t.Errorf("Name = %q, want %q", td.Name, "test")
|
|
}
|
|
if td.Size != 4 {
|
|
t.Errorf("Size = %d, want 4", td.Size)
|
|
}
|
|
|
|
rawData, err := io.ReadAll(td.Reader())
|
|
if err != nil {
|
|
t.Fatalf("Reader() failed: %v", err)
|
|
}
|
|
if !bytes.Equal(rawData, data) {
|
|
t.Errorf("data mismatch: got %v, want %v", rawData, data)
|
|
}
|
|
}
|
|
|
|
func TestBuildPackedSafetensorsReader(t *testing.T) {
|
|
data1 := []byte{1, 2, 3, 4}
|
|
data2 := []byte{5, 6, 7, 8, 9, 10, 11, 12}
|
|
|
|
td1 := NewTensorDataFromBytes("a", "U8", []int32{4}, data1)
|
|
td2 := NewTensorDataFromBytes("b", "U8", []int32{8}, data2)
|
|
|
|
packed := BuildPackedSafetensorsReader([]*TensorData{td1, td2})
|
|
packedBytes, err := io.ReadAll(packed)
|
|
if err != nil {
|
|
t.Fatalf("BuildPackedSafetensorsReader read failed: %v", err)
|
|
}
|
|
|
|
// Verify it's a valid safetensors file by parsing the header
|
|
var headerSize uint64
|
|
if err := binary.Read(bytes.NewReader(packedBytes), binary.LittleEndian, &headerSize); err != nil {
|
|
t.Fatalf("failed to read header size: %v", err)
|
|
}
|
|
|
|
headerJSON := packedBytes[8 : 8+headerSize]
|
|
var header map[string]tensorInfo
|
|
if err := json.Unmarshal(headerJSON, &header); err != nil {
|
|
t.Fatalf("failed to parse header: %v", err)
|
|
}
|
|
|
|
if len(header) != 2 {
|
|
t.Errorf("header has %d entries, want 2", len(header))
|
|
}
|
|
|
|
infoA, ok := header["a"]
|
|
if !ok {
|
|
t.Fatal("tensor 'a' not found in header")
|
|
}
|
|
if infoA.Dtype != "U8" {
|
|
t.Errorf("tensor 'a' dtype = %q, want %q", infoA.Dtype, "U8")
|
|
}
|
|
|
|
infoB, ok := header["b"]
|
|
if !ok {
|
|
t.Fatal("tensor 'b' not found in header")
|
|
}
|
|
|
|
// Verify data region contains both tensors
|
|
dataStart := 8 + int(headerSize)
|
|
dataRegion := packedBytes[dataStart:]
|
|
if infoA.DataOffsets[0] == 0 {
|
|
// a comes first
|
|
if !bytes.Equal(dataRegion[:4], data1) {
|
|
t.Error("tensor 'a' data mismatch")
|
|
}
|
|
if !bytes.Equal(dataRegion[infoB.DataOffsets[0]:infoB.DataOffsets[1]], data2) {
|
|
t.Error("tensor 'b' data mismatch")
|
|
}
|
|
} else {
|
|
// b comes first
|
|
if !bytes.Equal(dataRegion[:8], data2) {
|
|
t.Error("tensor 'b' data mismatch")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestExtractAll(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "test.safetensors")
|
|
|
|
createTestSafetensors(t, path, map[string]struct {
|
|
dtype string
|
|
shape []int32
|
|
data []byte
|
|
}{
|
|
"alpha": {dtype: "F32", shape: []int32{2}, data: make([]byte, 8)},
|
|
"beta": {dtype: "F16", shape: []int32{4}, data: make([]byte, 8)},
|
|
})
|
|
|
|
ext, err := OpenForExtraction(path)
|
|
if err != nil {
|
|
t.Fatalf("OpenForExtraction failed: %v", err)
|
|
}
|
|
defer ext.Close()
|
|
|
|
tensors, err := ext.ExtractAll()
|
|
if err != nil {
|
|
t.Fatalf("ExtractAll failed: %v", err)
|
|
}
|
|
|
|
if len(tensors) != 2 {
|
|
t.Errorf("ExtractAll returned %d tensors, want 2", len(tensors))
|
|
}
|
|
|
|
// Verify sorted order
|
|
if tensors[0].Name != "alpha" || tensors[1].Name != "beta" {
|
|
t.Errorf("tensors not in sorted order: %s, %s", tensors[0].Name, tensors[1].Name)
|
|
}
|
|
}
|
|
|
|
func TestExtractRawFromSafetensors_InvalidInput(t *testing.T) {
|
|
// Empty reader
|
|
_, err := ExtractRawFromSafetensors(bytes.NewReader(nil))
|
|
if err == nil {
|
|
t.Error("expected error for empty reader")
|
|
}
|
|
|
|
// Truncated header size
|
|
_, err = ExtractRawFromSafetensors(bytes.NewReader([]byte{1, 2, 3}))
|
|
if err == nil {
|
|
t.Error("expected error for truncated header size")
|
|
}
|
|
}
|
|
|
|
func TestOpenForExtraction_MetadataIgnored(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "test.safetensors")
|
|
|
|
// Manually create a safetensors file with __metadata__
|
|
header := map[string]any{
|
|
"__metadata__": map[string]string{"format": "pt"},
|
|
"weight": tensorInfo{
|
|
Dtype: "F32",
|
|
Shape: []int32{2},
|
|
DataOffsets: [2]int{0, 8},
|
|
},
|
|
}
|
|
headerJSON, _ := json.Marshal(header)
|
|
padding := (8 - len(headerJSON)%8) % 8
|
|
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
|
|
|
|
f, _ := os.Create(path)
|
|
binary.Write(f, binary.LittleEndian, uint64(len(headerJSON)))
|
|
f.Write(headerJSON)
|
|
f.Write(make([]byte, 8))
|
|
f.Close()
|
|
|
|
ext, err := OpenForExtraction(path)
|
|
if err != nil {
|
|
t.Fatalf("OpenForExtraction failed: %v", err)
|
|
}
|
|
defer ext.Close()
|
|
|
|
// __metadata__ should be stripped
|
|
if ext.TensorCount() != 1 {
|
|
t.Errorf("TensorCount() = %d, want 1 (metadata should be stripped)", ext.TensorCount())
|
|
}
|
|
}
|