mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 14:54:11 +02:00
This change adds a tensorImportTransform interface for model-specific tensor transformations during safetensors import. It allows both the standard HF-based weights and the mlx-community pre-quantized safetensors repos to be modified as needed and imported directly with `ollama create`. Currently only Qwen3.5 import uses it; the transform renames tensors, shifts norm weights (adding +1 to each value of the norm vectors), transposes conv1d weights, and casts F32-based vectors to BF16.
1157 lines
37 KiB
Go
1157 lines
37 KiB
Go
package create
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/d4l3k/go-bfloat16"
|
|
st "github.com/ollama/ollama/x/imagegen/safetensors"
|
|
)
|
|
|
|
func TestIsTensorModelDir(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setup func(dir string) error
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "valid diffusers model with model_index.json",
|
|
setup: func(dir string) error {
|
|
return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644)
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "empty directory",
|
|
setup: func(dir string) error {
|
|
return nil
|
|
},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "directory with other files but no model_index.json",
|
|
setup: func(dir string) error {
|
|
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
|
|
},
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
dir := t.TempDir()
|
|
if err := tt.setup(dir); err != nil {
|
|
t.Fatalf("setup failed: %v", err)
|
|
}
|
|
|
|
got := IsTensorModelDir(dir)
|
|
if got != tt.expected {
|
|
t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsSafetensorsModelDir(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setup func(dir string) error
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "valid safetensors model with config.json and .safetensors file",
|
|
setup: func(dir string) error {
|
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil {
|
|
return err
|
|
}
|
|
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "config.json only, no safetensors files",
|
|
setup: func(dir string) error {
|
|
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
|
|
},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "safetensors file only, no config.json",
|
|
setup: func(dir string) error {
|
|
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
|
|
},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "empty directory",
|
|
setup: func(dir string) error {
|
|
return nil
|
|
},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "multiple safetensors files with config.json",
|
|
setup: func(dir string) error {
|
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
|
return err
|
|
}
|
|
if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil {
|
|
return err
|
|
}
|
|
return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644)
|
|
},
|
|
expected: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
dir := t.TempDir()
|
|
if err := tt.setup(dir); err != nil {
|
|
t.Fatalf("setup failed: %v", err)
|
|
}
|
|
|
|
got := IsSafetensorsModelDir(dir)
|
|
if got != tt.expected {
|
|
t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
|
|
got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
|
|
if got != false {
|
|
t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
|
|
}
|
|
}
|
|
|
|
// createMinimalSafetensors writes a minimal valid safetensors file to path
// containing a single 2x2 F32 tensor named "test_tensor" whose data is all
// zeros.
//
// File layout: an 8-byte little-endian header length, the JSON header padded
// with spaces to an 8-byte boundary, then the 16 bytes of tensor data. The
// blob is assembled in memory and written with os.WriteFile, so any write or
// close error fails the test instead of being lost in a deferred Close.
func createMinimalSafetensors(t *testing.T, path string) {
	t.Helper()

	header := map[string]any{
		"test_tensor": map[string]any{
			"dtype":        "F32",
			"shape":        []int{2, 2},
			"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
		},
	}
	headerJSON, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("failed to marshal header: %v", err)
	}

	// Pad header to 8-byte alignment
	padding := (8 - len(headerJSON)%8) % 8
	headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)

	var buf bytes.Buffer
	// Header size (8 bytes, little endian)
	if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
		t.Fatalf("failed to write header size: %v", err)
	}
	buf.Write(headerJSON)
	// Tensor data: 16 bytes of zeros for 4 float32 values
	buf.Write(make([]byte, 16))

	if err := os.WriteFile(path, buf.Bytes(), 0o644); err != nil {
		t.Fatalf("failed to write safetensors file: %v", err)
	}
}
|
|
|
|
func createTestSafetensors(t *testing.T, path string, tensors []*st.TensorData) {
|
|
t.Helper()
|
|
|
|
data, err := io.ReadAll(st.BuildPackedSafetensorsReader(tensors))
|
|
if err != nil {
|
|
t.Fatalf("failed to build packed safetensors: %v", err)
|
|
}
|
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
|
t.Fatalf("failed to write safetensors: %v", err)
|
|
}
|
|
}
|
|
|
|
// readSingleTensorHeader parses a safetensors-style blob and returns the dtype
// and shape of the first non-"__metadata__" entry in its JSON header.
//
// The expected layout is an 8-byte little-endian header length followed by
// that many bytes of JSON. The helper is meant for blobs that hold exactly one
// tensor entry. If there are several, which one is returned is unspecified,
// because Go map iteration order is random. Truncated or inconsistent input
// fails the test with a descriptive message instead of panicking on a slice
// bound.
func readSingleTensorHeader(t *testing.T, data []byte) (string, []int32) {
	t.Helper()

	if len(data) < 8 {
		t.Fatalf("safetensors blob too short for header size: %d bytes", len(data))
	}
	headerSize := binary.LittleEndian.Uint64(data[:8])
	if headerSize > uint64(len(data)-8) {
		t.Fatalf("header size %d exceeds remaining %d bytes", headerSize, len(data)-8)
	}

	var header map[string]struct {
		Dtype string  `json:"dtype"`
		Shape []int32 `json:"shape"`
	}
	if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}

	for name, info := range header {
		if name == "__metadata__" {
			continue
		}
		return info.Dtype, info.Shape
	}

	t.Fatal("no tensor entry found in header")
	return "", nil
}
|
|
|
|
// readSingleTensorRaw parses a safetensors-style blob and returns the raw data
// bytes of the first non-"__metadata__" entry in its JSON header.
//
// The expected layout is an 8-byte little-endian header length, that many
// bytes of JSON, then the tensor payload. Each entry's data_offsets are
// relative to the start of that payload. The helper is meant for blobs with
// exactly one tensor entry (map iteration order is random otherwise).
// Truncated headers and out-of-range offsets fail the test instead of
// panicking. The returned slice aliases data.
func readSingleTensorRaw(t *testing.T, data []byte) []byte {
	t.Helper()

	if len(data) < 8 {
		t.Fatalf("safetensors blob too short for header size: %d bytes", len(data))
	}
	headerSize := binary.LittleEndian.Uint64(data[:8])
	if headerSize > uint64(len(data)-8) {
		t.Fatalf("header size %d exceeds remaining %d bytes", headerSize, len(data)-8)
	}

	var header map[string]struct {
		DataOffsets [2]int `json:"data_offsets"`
	}
	if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}

	payload := data[8+headerSize:]
	for name, info := range header {
		if name == "__metadata__" {
			continue
		}
		start, end := info.DataOffsets[0], info.DataOffsets[1]
		if start < 0 || end < start || end > len(payload) {
			t.Fatalf("tensor %q data_offsets [%d, %d] out of range for %d payload bytes", name, start, end, len(payload))
		}
		return payload[start:end]
	}

	t.Fatal("no tensor entry found in header")
	return nil
}
|
|
|
|
func TestCreateSafetensorsModel(t *testing.T) {
|
|
dir := t.TempDir()
|
|
|
|
// Create config.json
|
|
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
|
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
|
t.Fatalf("failed to write config.json: %v", err)
|
|
}
|
|
|
|
// Create a minimal safetensors file
|
|
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
|
|
|
// Track what was created
|
|
var createdLayers []LayerInfo
|
|
var manifestWritten bool
|
|
var manifestModelName string
|
|
var manifestConfigLayer LayerInfo
|
|
var manifestLayers []LayerInfo
|
|
var statusMessages []string
|
|
|
|
// Mock callbacks
|
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
|
data, err := io.ReadAll(r)
|
|
if err != nil {
|
|
return LayerInfo{}, err
|
|
}
|
|
layer := LayerInfo{
|
|
Digest: "sha256:test",
|
|
Size: int64(len(data)),
|
|
MediaType: mediaType,
|
|
Name: name,
|
|
}
|
|
createdLayers = append(createdLayers, layer)
|
|
return layer, nil
|
|
}
|
|
|
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
|
data, err := io.ReadAll(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
layer := LayerInfo{
|
|
Digest: "sha256:tensor_" + name,
|
|
Size: int64(len(data)),
|
|
MediaType: "application/vnd.ollama.image.tensor",
|
|
Name: name,
|
|
}
|
|
createdLayers = append(createdLayers, layer)
|
|
return []LayerInfo{layer}, nil
|
|
}
|
|
|
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
|
manifestWritten = true
|
|
manifestModelName = modelName
|
|
manifestConfigLayer = config
|
|
manifestLayers = layers
|
|
return nil
|
|
}
|
|
|
|
progressFn := func(status string) {
|
|
statusMessages = append(statusMessages, status)
|
|
}
|
|
|
|
// Run CreateSafetensorsModel
|
|
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
|
if err != nil {
|
|
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
|
}
|
|
|
|
// Verify manifest was written
|
|
if !manifestWritten {
|
|
t.Error("manifest was not written")
|
|
}
|
|
|
|
if manifestModelName != "test-model" {
|
|
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
|
|
}
|
|
|
|
// Verify config layer was set
|
|
if manifestConfigLayer.Name != "config.json" {
|
|
t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
|
|
}
|
|
|
|
// Verify we have at least one tensor and one config layer
|
|
hasTensor := false
|
|
hasConfig := false
|
|
for _, layer := range manifestLayers {
|
|
if layer.Name == "test_tensor" {
|
|
hasTensor = true
|
|
}
|
|
if layer.Name == "config.json" {
|
|
hasConfig = true
|
|
}
|
|
}
|
|
|
|
if !hasTensor {
|
|
t.Error("no tensor layer found in manifest")
|
|
}
|
|
if !hasConfig {
|
|
t.Error("no config layer found in manifest")
|
|
}
|
|
|
|
// Verify status messages were sent
|
|
if len(statusMessages) == 0 {
|
|
t.Error("no status messages received")
|
|
}
|
|
}
|
|
|
|
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
|
|
dir := t.TempDir()
|
|
|
|
// Create only a safetensors file, no config.json
|
|
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
|
|
|
// Mock callbacks (minimal)
|
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
return LayerInfo{Name: name}, nil
|
|
}
|
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
return []LayerInfo{{Name: name}}, nil
|
|
}
|
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
|
return nil
|
|
}
|
|
progressFn := func(status string) {}
|
|
|
|
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
|
if err == nil {
|
|
t.Error("expected error for missing config.json, got nil")
|
|
}
|
|
}
|
|
|
|
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
|
|
dir := t.TempDir()
|
|
|
|
// Mock callbacks
|
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
|
return LayerInfo{}, nil
|
|
}
|
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
|
return []LayerInfo{{}}, nil
|
|
}
|
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
|
return nil
|
|
}
|
|
progressFn := func(status string) {}
|
|
|
|
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
|
if err == nil {
|
|
t.Error("expected error for empty directory, got nil")
|
|
}
|
|
}
|
|
|
|
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
|
|
dir := t.TempDir()
|
|
|
|
// Create config.json
|
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
|
t.Fatalf("failed to write config.json: %v", err)
|
|
}
|
|
|
|
// Create model.safetensors.index.json (should be skipped)
|
|
indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
|
|
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
|
|
t.Fatalf("failed to write index.json: %v", err)
|
|
}
|
|
|
|
// Create a minimal safetensors file
|
|
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
|
|
|
var configNames []string
|
|
|
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
configNames = append(configNames, name)
|
|
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
|
}
|
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
return []LayerInfo{{Name: name}}, nil
|
|
}
|
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
|
return nil
|
|
}
|
|
progressFn := func(status string) {}
|
|
|
|
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
|
if err != nil {
|
|
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
|
}
|
|
|
|
// Verify model.safetensors.index.json was not included
|
|
for _, name := range configNames {
|
|
if name == "model.safetensors.index.json" {
|
|
t.Error("model.safetensors.index.json should have been skipped")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets(t *testing.T) {
|
|
dir := t.TempDir()
|
|
|
|
configJSON := `{
|
|
"model_type": "test",
|
|
"architectures": ["TestModel"],
|
|
"quantization": {"group_size": 64, "bits": 4, "mode": "affine"}
|
|
}`
|
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
|
t.Fatalf("failed to write config.json: %v", err)
|
|
}
|
|
|
|
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
|
st.NewTensorDataFromBytes("linear.weight", "U32", []int32{4, 4}, make([]byte, 16)),
|
|
st.NewTensorDataFromBytes("linear.scales", "BF16", []int32{4, 1}, make([]byte, 8)),
|
|
st.NewTensorDataFromBytes("linear.biases", "BF16", []int32{4, 1}, make([]byte, 8)),
|
|
st.NewTensorDataFromBytes("plain.weight", "F32", []int32{2, 2}, make([]byte, 16)),
|
|
})
|
|
|
|
var packedHeader map[string]json.RawMessage
|
|
var tensorLayerNames []string
|
|
var createTensorLayerNames []string
|
|
|
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
|
data, err := io.ReadAll(r)
|
|
if err != nil {
|
|
return LayerInfo{}, err
|
|
}
|
|
if mediaType == "application/vnd.ollama.image.tensor" && name == "linear.weight" {
|
|
var headerSize uint64
|
|
if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
|
|
return LayerInfo{}, err
|
|
}
|
|
if err := json.Unmarshal(data[8:8+headerSize], &packedHeader); err != nil {
|
|
return LayerInfo{}, err
|
|
}
|
|
}
|
|
tensorLayerNames = append(tensorLayerNames, name)
|
|
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType, Size: int64(len(data))}, nil
|
|
}
|
|
|
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
|
if _, err := io.ReadAll(r); err != nil {
|
|
return nil, err
|
|
}
|
|
createTensorLayerNames = append(createTensorLayerNames, name)
|
|
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
|
}
|
|
|
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
|
return nil
|
|
}
|
|
|
|
progressFn := func(status string) {}
|
|
|
|
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn); err != nil {
|
|
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
|
}
|
|
|
|
if packedHeader == nil {
|
|
t.Fatal("expected packed quantized header for linear.weight")
|
|
}
|
|
if _, ok := packedHeader["linear.weight"]; !ok {
|
|
t.Fatalf("packed header missing linear.weight: %v", packedHeader)
|
|
}
|
|
if _, ok := packedHeader["linear.weight.scale"]; !ok {
|
|
t.Fatalf("packed header missing linear.weight.scale: %v", packedHeader)
|
|
}
|
|
if _, ok := packedHeader["linear.weight.bias"]; !ok {
|
|
t.Fatalf("packed header missing linear.weight.bias: %v", packedHeader)
|
|
}
|
|
|
|
var metadata map[string]string
|
|
if metaRaw, ok := packedHeader["__metadata__"]; ok {
|
|
if err := json.Unmarshal(metaRaw, &metadata); err != nil {
|
|
t.Fatalf("failed to parse packed metadata: %v", err)
|
|
}
|
|
}
|
|
if metadata["quant_type"] != "int4" {
|
|
t.Fatalf("quant_type = %q, want %q", metadata["quant_type"], "int4")
|
|
}
|
|
if metadata["group_size"] != "64" {
|
|
t.Fatalf("group_size = %q, want %q", metadata["group_size"], "64")
|
|
}
|
|
|
|
if slices.Contains(createTensorLayerNames, "linear.weight") {
|
|
t.Fatalf("linear.weight unexpectedly handled by createTensorLayer: %v", createTensorLayerNames)
|
|
}
|
|
if slices.Contains(createTensorLayerNames, "linear.scales") || slices.Contains(createTensorLayerNames, "linear.biases") {
|
|
t.Fatalf("quantized companions unexpectedly handled separately: %v", createTensorLayerNames)
|
|
}
|
|
if !slices.Contains(createTensorLayerNames, "plain.weight") {
|
|
t.Fatalf("plain.weight missing from createTensorLayer calls: %v", createTensorLayerNames)
|
|
}
|
|
if slices.Contains(tensorLayerNames, "linear.scales") || slices.Contains(tensorLayerNames, "linear.biases") {
|
|
t.Fatalf("quantized companions unexpectedly emitted as layers: %v", tensorLayerNames)
|
|
}
|
|
}
|
|
|
|
func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
|
|
dir := t.TempDir()
|
|
|
|
configJSON := `{
|
|
"model_type": "test",
|
|
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
|
|
"text_config": {"dtype": "bfloat16"}
|
|
}`
|
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
|
t.Fatalf("failed to write config.json: %v", err)
|
|
}
|
|
|
|
gateUpValues := make([]float32, 2*128*64)
|
|
for expert := range 2 {
|
|
base := expert * 128 * 64
|
|
for i := range 64 * 64 {
|
|
gateUpValues[base+i] = 1
|
|
gateUpValues[base+64*64+i] = 2
|
|
}
|
|
}
|
|
|
|
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
|
st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
|
st.NewTensorDataFromBytes("model.language_model.layers.0.input_layernorm.weight", "F32", []int32{64}, make([]byte, 64*4)),
|
|
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.A_log", "F32", []int32{32}, make([]byte, 32*4)),
|
|
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.conv1d.weight", "BF16", []int32{64, 1, 4}, make([]byte, 64*1*4*2)),
|
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
|
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
|
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.shared_expert.down_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
|
st.NewTensorDataFromBytes("model.visual.blocks.0.attn.proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
|
st.NewTensorDataFromBytes("mtp.layers.0.foo.weight", "F32", []int32{64, 64}, make([]byte, 64*64*4)),
|
|
})
|
|
|
|
type tensorCall struct {
|
|
dtype string
|
|
shape []int32
|
|
quantize string
|
|
raw []byte
|
|
}
|
|
calls := make(map[string]tensorCall)
|
|
|
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
|
_, _ = io.ReadAll(r)
|
|
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
|
}
|
|
|
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
|
data, err := io.ReadAll(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
headerDType, headerShape := readSingleTensorHeader(t, data)
|
|
calls[name] = tensorCall{
|
|
dtype: headerDType,
|
|
shape: headerShape,
|
|
quantize: quantize,
|
|
raw: readSingleTensorRaw(t, data),
|
|
}
|
|
return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
|
}
|
|
|
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
|
return nil
|
|
}
|
|
|
|
if err := CreateSafetensorsModel("test-model", dir, "int4", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
|
|
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
|
}
|
|
|
|
if _, ok := calls["mtp.layers.0.foo.weight"]; ok {
|
|
t.Fatal("mtp tensor should have been dropped")
|
|
}
|
|
|
|
layerNorm := calls["language_model.model.layers.0.input_layernorm.weight"]
|
|
if layerNorm.dtype != "BF16" {
|
|
t.Fatalf("input_layernorm dtype = %q, want %q", layerNorm.dtype, "BF16")
|
|
}
|
|
if layerNorm.quantize != "" {
|
|
t.Fatalf("input_layernorm quantize = %q, want empty", layerNorm.quantize)
|
|
}
|
|
layerNormValues := bfloat16.DecodeFloat32(layerNorm.raw)
|
|
if len(layerNormValues) == 0 || layerNormValues[0] != 1.0 {
|
|
t.Fatalf("input_layernorm first value = %v, want 1.0 after +1 shift", layerNormValues[0])
|
|
}
|
|
|
|
alog := calls["language_model.model.layers.0.linear_attn.A_log"]
|
|
if alog.dtype != "F32" {
|
|
t.Fatalf("A_log dtype = %q, want %q", alog.dtype, "F32")
|
|
}
|
|
|
|
conv := calls["language_model.model.layers.0.linear_attn.conv1d.weight"]
|
|
if !slices.Equal(conv.shape, []int32{64, 4, 1}) {
|
|
t.Fatalf("conv1d shape = %v, want %v", conv.shape, []int32{64, 4, 1})
|
|
}
|
|
|
|
if got := calls["language_model.model.embed_tokens.weight"].quantize; got != "int4" {
|
|
t.Fatalf("embed_tokens quantize = %q, want %q", got, "int4")
|
|
}
|
|
if got := calls["language_model.model.layers.0.mlp.gate.weight"].quantize; got != "int4" {
|
|
t.Fatalf("mlp.gate quantize = %q, want %q", got, "int4")
|
|
}
|
|
if got := calls["language_model.model.layers.0.mlp.shared_expert.down_proj.weight"].quantize; got != "int4" {
|
|
t.Fatalf("down_proj quantize = %q, want %q", got, "int4")
|
|
}
|
|
|
|
if _, ok := calls["language_model.model.layers.0.mlp.experts.gate_up_proj"]; ok {
|
|
t.Fatal("combined gate_up_proj tensor should have been rewritten")
|
|
}
|
|
if _, ok := calls["language_model.model.layers.0.mlp.experts.down_proj"]; ok {
|
|
t.Fatal("combined down_proj tensor should have been rewritten")
|
|
}
|
|
|
|
gateProj := calls["language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight"]
|
|
if !slices.Equal(gateProj.shape, []int32{2, 64, 64}) {
|
|
t.Fatalf("gate_proj shape = %v, want %v", gateProj.shape, []int32{2, 64, 64})
|
|
}
|
|
gateProjValues := bfloat16.DecodeFloat32(gateProj.raw)
|
|
if len(gateProjValues) == 0 || gateProjValues[0] != 1.0 {
|
|
t.Fatalf("gate_proj first value = %v, want 1.0", gateProjValues[0])
|
|
}
|
|
|
|
upProj := calls["language_model.model.layers.0.mlp.switch_mlp.up_proj.weight"]
|
|
if !slices.Equal(upProj.shape, []int32{2, 64, 64}) {
|
|
t.Fatalf("up_proj shape = %v, want %v", upProj.shape, []int32{2, 64, 64})
|
|
}
|
|
upProjValues := bfloat16.DecodeFloat32(upProj.raw)
|
|
if len(upProjValues) == 0 || upProjValues[0] != 2.0 {
|
|
t.Fatalf("up_proj first value = %v, want 2.0", upProjValues[0])
|
|
}
|
|
|
|
if got := calls["language_model.model.layers.0.mlp.switch_mlp.down_proj.weight"].quantize; got != "int4" {
|
|
t.Fatalf("switch_mlp down_proj quantize = %q, want %q", got, "int4")
|
|
}
|
|
|
|
vision := calls["vision_tower.blocks.0.attn.proj.weight"]
|
|
if vision.dtype != "BF16" {
|
|
t.Fatalf("vision weight dtype = %q, want %q", vision.dtype, "BF16")
|
|
}
|
|
if vision.quantize != "" {
|
|
t.Fatalf("vision weight quantize = %q, want empty", vision.quantize)
|
|
}
|
|
if _, ok := calls["language_model.model.visual.blocks.0.attn.proj.weight"]; ok {
|
|
t.Fatal("vision tensor should have been rewritten to vision_tower.*")
|
|
}
|
|
}
|
|
|
|
func TestResolveManifestPath(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
modelName string
|
|
wantParts []string // Parts that should appear in the path
|
|
}{
|
|
{
|
|
name: "simple model name",
|
|
modelName: "llama2",
|
|
wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
|
|
},
|
|
{
|
|
name: "model name with tag",
|
|
modelName: "llama2:7b",
|
|
wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
|
|
},
|
|
{
|
|
name: "model name with namespace",
|
|
modelName: "myuser/mymodel",
|
|
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
|
|
},
|
|
{
|
|
name: "model name with namespace and tag",
|
|
modelName: "myuser/mymodel:v1",
|
|
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
|
|
},
|
|
{
|
|
name: "fully qualified model name",
|
|
modelName: "registry.example.com/namespace/model:tag",
|
|
wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := resolveManifestPath(tt.modelName)
|
|
|
|
for _, part := range tt.wantParts {
|
|
if !strings.Contains(got, part) {
|
|
t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLayerInfo(t *testing.T) {
|
|
layer := LayerInfo{
|
|
Digest: "sha256:abc123",
|
|
Size: 1024,
|
|
MediaType: "application/vnd.ollama.image.tensor",
|
|
Name: "model.weight",
|
|
}
|
|
|
|
if layer.Digest != "sha256:abc123" {
|
|
t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
|
|
}
|
|
if layer.Size != 1024 {
|
|
t.Errorf("Size = %d, want %d", layer.Size, 1024)
|
|
}
|
|
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
|
t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
|
|
}
|
|
if layer.Name != "model.weight" {
|
|
t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
|
|
}
|
|
}
|
|
|
|
func TestModelConfig(t *testing.T) {
|
|
config := ModelConfig{
|
|
ModelFormat: "safetensors",
|
|
Capabilities: []string{"completion", "chat"},
|
|
}
|
|
|
|
if config.ModelFormat != "safetensors" {
|
|
t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
|
|
}
|
|
if len(config.Capabilities) != 2 {
|
|
t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
|
|
}
|
|
}
|
|
|
|
func TestManifest(t *testing.T) {
|
|
manifest := Manifest{
|
|
SchemaVersion: 2,
|
|
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
|
Config: ManifestLayer{
|
|
MediaType: "application/vnd.docker.container.image.v1+json",
|
|
Digest: "sha256:config",
|
|
Size: 100,
|
|
},
|
|
Layers: []ManifestLayer{
|
|
{
|
|
MediaType: "application/vnd.ollama.image.tensor",
|
|
Digest: "sha256:layer1",
|
|
Size: 1000,
|
|
Name: "weight.bin",
|
|
},
|
|
},
|
|
}
|
|
|
|
if manifest.SchemaVersion != 2 {
|
|
t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
|
|
}
|
|
if manifest.Config.Digest != "sha256:config" {
|
|
t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
|
|
}
|
|
if len(manifest.Layers) != 1 {
|
|
t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
|
|
}
|
|
if manifest.Layers[0].Name != "weight.bin" {
|
|
t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
|
|
}
|
|
}
|
|
|
|
func TestShouldQuantize(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tensor string
|
|
component string
|
|
want bool
|
|
}{
|
|
// VAE component should never be quantized
|
|
{"vae weight", "decoder.weight", "vae", false},
|
|
{"vae bias", "decoder.bias", "vae", false},
|
|
|
|
// Embeddings should not be quantized
|
|
{"embedding weight", "embed_tokens.weight", "", false},
|
|
{"embedding in name", "token_embedding.weight", "", false},
|
|
|
|
// Norms should not be quantized
|
|
{"layer norm", "layer_norm.weight", "", false},
|
|
{"rms norm", "rms_norm.weight", "", false},
|
|
{"ln prefix", "ln_1.weight", "", false},
|
|
{"layernorm in name", "input_layernorm.weight", "", false},
|
|
|
|
// Biases should not be quantized
|
|
{"bias tensor", "attention.bias", "", false},
|
|
{"proj bias", "o_proj.bias", "", false},
|
|
|
|
// Linear weights should be quantized
|
|
{"linear weight", "q_proj.weight", "", true},
|
|
{"attention weight", "self_attn.weight", "", true},
|
|
{"mlp weight", "mlp.gate_proj.weight", "", true},
|
|
|
|
// Transformer component weights should be quantized
|
|
{"transformer weight", "layers.0.weight", "transformer", true},
|
|
{"text_encoder weight", "encoder.weight", "text_encoder", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := ShouldQuantize(tt.tensor, tt.component)
|
|
if got != tt.want {
|
|
t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestShouldQuantizeTensor(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tensor string
|
|
shape []int32
|
|
quantize string
|
|
want bool
|
|
}{
|
|
// 2D tensors with sufficient size should be quantized
|
|
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
|
|
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
|
|
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
|
|
|
|
// Small tensors should not be quantized (< 1024 elements)
|
|
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
|
|
{"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false},
|
|
|
|
// 1D tensors should not be quantized
|
|
{"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false},
|
|
|
|
// 3D+ tensors should not be quantized
|
|
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
|
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
|
{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
|
|
{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
|
|
{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
|
|
{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
|
|
|
|
// Embeddings should not be quantized regardless of shape
|
|
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
|
|
|
// Norms should not be quantized regardless of shape
|
|
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false},
|
|
|
|
// Biases should not be quantized
|
|
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
|
|
|
|
// Group size divisibility tests
|
|
// FP8/FP4 require divisible by 32
|
|
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
|
|
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
|
|
// NVFP4 requires divisible by 16
|
|
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
|
|
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := ShouldQuantizeTensor(tt.tensor, tt.shape, tt.quantize)
|
|
if got != tt.want {
|
|
t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tt.tensor, tt.shape, tt.quantize, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExpertGroupPrefix(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
want string
|
|
}{
|
|
// Expert tensors should return the group prefix
|
|
{"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"},
|
|
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
|
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
|
|
|
// Shared expert tensors should return their own group prefix
|
|
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
|
|
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
|
|
|
|
// Non-expert tensors should return empty string
|
|
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
|
|
{"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert
|
|
{"model.embed_tokens.weight", ""}, // embedding
|
|
{"model.layers.0.self_attn.q_proj.weight", ""}, // attention
|
|
{"model.norm.weight", ""}, // norm
|
|
{"lm_head.weight", ""}, // output head
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := ExpertGroupPrefix(tt.name)
|
|
if got != tt.want {
|
|
t.Errorf("ExpertGroupPrefix(%q) = %q, want %q", tt.name, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
|
gateUp := GetTensorQuantization(
|
|
"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
|
|
[]int32{64, 22016, 4096},
|
|
"int4",
|
|
)
|
|
if gateUp != "int4" {
|
|
t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
|
|
}
|
|
|
|
down := GetTensorQuantization(
|
|
"model.layers.1.mlp.experts.down_proj.weight",
|
|
[]int32{64, 4096, 14336},
|
|
"int4",
|
|
)
|
|
if down != "int8" {
|
|
t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
|
|
}
|
|
|
|
combinedGateUp := GetTensorQuantization(
|
|
"model.language_model.layers.0.mlp.experts.gate_up_proj",
|
|
[]int32{256, 1024, 2048},
|
|
"int8",
|
|
)
|
|
if combinedGateUp != "int8" {
|
|
t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
|
|
}
|
|
|
|
combinedDown := GetTensorQuantization(
|
|
"model.language_model.layers.0.mlp.experts.down_proj",
|
|
[]int32{256, 2048, 512},
|
|
"int4",
|
|
)
|
|
if combinedDown != "int8" {
|
|
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
|
|
}
|
|
}
|
|
|
|
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
|
dir := t.TempDir()
|
|
|
|
// Create config.json
|
|
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
|
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
|
t.Fatalf("failed to write config.json: %v", err)
|
|
}
|
|
|
|
// Create a minimal safetensors file
|
|
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
|
|
|
var quantizeRequested []string
|
|
|
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
|
}
|
|
|
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
quantizeRequested = append(quantizeRequested, quantize)
|
|
return []LayerInfo{{Name: name}}, nil
|
|
}
|
|
|
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
|
return nil
|
|
}
|
|
|
|
progressFn := func(status string) {}
|
|
|
|
// Run with quantize enabled
|
|
err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
|
if err != nil {
|
|
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
|
}
|
|
|
|
// Verify quantize was passed to callback (will be false for small test tensor)
|
|
if len(quantizeRequested) == 0 {
|
|
t.Error("no tensors processed")
|
|
}
|
|
}
|
|
|
|
// createMinimalImageGenModel creates a minimal diffusers-style model directory
|
|
func createMinimalImageGenModel(t *testing.T, dir string) {
|
|
t.Helper()
|
|
|
|
// Create model_index.json
|
|
modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
|
|
if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0o644); err != nil {
|
|
t.Fatalf("failed to write model_index.json: %v", err)
|
|
}
|
|
|
|
// Create transformer directory with a safetensors file
|
|
transformerDir := filepath.Join(dir, "transformer")
|
|
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
|
|
t.Fatalf("failed to create transformer dir: %v", err)
|
|
}
|
|
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
|
|
|
|
// Create transformer config
|
|
transformerConfig := `{"hidden_size": 3072}`
|
|
if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0o644); err != nil {
|
|
t.Fatalf("failed to write transformer config: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestCreateImageGenModel(t *testing.T) {
|
|
dir := t.TempDir()
|
|
createMinimalImageGenModel(t, dir)
|
|
|
|
var manifestWritten bool
|
|
var manifestModelName string
|
|
var statusMessages []string
|
|
|
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
|
}
|
|
|
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
|
|
}
|
|
|
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
|
manifestWritten = true
|
|
manifestModelName = modelName
|
|
return nil
|
|
}
|
|
|
|
progressFn := func(status string) {
|
|
statusMessages = append(statusMessages, status)
|
|
}
|
|
|
|
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
|
if err != nil {
|
|
t.Fatalf("CreateImageGenModel failed: %v", err)
|
|
}
|
|
|
|
if !manifestWritten {
|
|
t.Error("manifest was not written")
|
|
}
|
|
|
|
if manifestModelName != "test-imagegen" {
|
|
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
|
|
}
|
|
|
|
if len(statusMessages) == 0 {
|
|
t.Error("no status messages received")
|
|
}
|
|
}
|
|
|
|
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
|
|
dir := t.TempDir()
|
|
|
|
// Create only transformer without model_index.json
|
|
transformerDir := filepath.Join(dir, "transformer")
|
|
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
|
|
t.Fatalf("failed to create transformer dir: %v", err)
|
|
}
|
|
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
|
|
|
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
return LayerInfo{Name: name}, nil
|
|
}
|
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
return []LayerInfo{{Name: name}}, nil
|
|
}
|
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
|
return nil
|
|
}
|
|
progressFn := func(status string) {}
|
|
|
|
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
|
if err == nil {
|
|
t.Error("expected error for missing model_index.json, got nil")
|
|
}
|
|
}
|
|
|
|
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
|
|
dir := t.TempDir()
|
|
createMinimalImageGenModel(t, dir)
|
|
|
|
var quantizeRequested []string
|
|
|
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
|
}
|
|
|
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
|
io.ReadAll(r)
|
|
quantizeRequested = append(quantizeRequested, quantize)
|
|
return []LayerInfo{{Name: name}}, nil
|
|
}
|
|
|
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
|
return nil
|
|
}
|
|
|
|
progressFn := func(status string) {}
|
|
|
|
err := CreateImageGenModel("test-imagegen", dir, "int8", createLayer, createTensorLayer, writeManifest, progressFn)
|
|
if err != nil {
|
|
t.Fatalf("CreateImageGenModel failed: %v", err)
|
|
}
|
|
|
|
if len(quantizeRequested) == 0 {
|
|
t.Error("no tensors processed")
|
|
}
|
|
}
|