mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
* create: Clean up experimental paths
This cleans up the experimental features, and adds both unit and integration test coverage to verify no regressions.
* create: preserve config and layer names when creating from safetensors models
When creating a model FROM an existing safetensors model, ModelFormat,
Capabilities, and layer Name fields were lost. ModelFormat stayed empty
because it's only set from GGML layers (which safetensors models lack),
and layer names weren't copied in parseFromModel. This caused derived
models to fail loading ("config.json not found in manifest").
* review comments
231 lines
6.8 KiB
Go
package create
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/ollama/ollama/x/safetensors"
|
|
)
|
|
|
|
// CreateImageGenModel imports an image generation model from a directory.
|
|
// Stores each tensor as a separate blob for fine-grained deduplication.
|
|
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
|
|
// Supported quantization types: int4, int8, nvfp4, mxfp8 (or empty for no quantization).
|
|
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
|
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
|
// Validate quantization type
|
|
switch quantize {
|
|
case "", "int4", "int8", "nvfp4", "mxfp8":
|
|
// valid
|
|
default:
|
|
return fmt.Errorf("unsupported quantization type %q: supported types are int4, int8, nvfp4, mxfp8", quantize)
|
|
}
|
|
|
|
var layers []LayerInfo
|
|
var configLayer LayerInfo
|
|
var totalParams int64 // Count parameters from original tensor shapes
|
|
var torchDtype string // Read from component config for quantization display
|
|
|
|
// Components to process - extract individual tensors from each
|
|
components := []string{"text_encoder", "transformer", "vae"}
|
|
|
|
for _, component := range components {
|
|
componentDir := filepath.Join(modelDir, component)
|
|
if _, err := os.Stat(componentDir); os.IsNotExist(err) {
|
|
continue
|
|
}
|
|
|
|
// Find all safetensors files in this component
|
|
entries, err := os.ReadDir(componentDir)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read %s: %w", component, err)
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
if !strings.HasSuffix(entry.Name(), ".safetensors") {
|
|
continue
|
|
}
|
|
|
|
stPath := filepath.Join(componentDir, entry.Name())
|
|
|
|
// Extract individual tensors from safetensors file
|
|
extractor, err := safetensors.OpenForExtraction(stPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open %s: %w", stPath, err)
|
|
}
|
|
|
|
tensorNames := extractor.ListTensors()
|
|
quantizeMsg := ""
|
|
if quantize != "" && component != "vae" {
|
|
quantizeMsg = ", quantizing to " + quantize
|
|
}
|
|
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
|
|
|
|
for _, tensorName := range tensorNames {
|
|
td, err := extractor.GetTensor(tensorName)
|
|
if err != nil {
|
|
extractor.Close()
|
|
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
|
}
|
|
|
|
// Count parameters from original tensor shape
|
|
if len(td.Shape) > 0 {
|
|
numElements := int64(1)
|
|
for _, dim := range td.Shape {
|
|
numElements *= int64(dim)
|
|
}
|
|
totalParams += numElements
|
|
}
|
|
|
|
// Store as minimal safetensors format (88 bytes header overhead)
|
|
// This enables native mmap loading via mlx_load_safetensors
|
|
// Use path-style name: "component/tensor_name"
|
|
fullName := component + "/" + tensorName
|
|
|
|
// Determine quantization type for this tensor (empty string if not quantizing)
|
|
quantizeType := ""
|
|
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) {
|
|
quantizeType = quantize
|
|
}
|
|
|
|
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
|
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType)
|
|
if err != nil {
|
|
extractor.Close()
|
|
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
|
|
}
|
|
layers = append(layers, newLayers...)
|
|
}
|
|
|
|
extractor.Close()
|
|
}
|
|
}
|
|
|
|
// Read torch_dtype from text_encoder config for quantization display
|
|
if torchDtype == "" {
|
|
textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json")
|
|
if data, err := os.ReadFile(textEncoderConfig); err == nil {
|
|
var cfg struct {
|
|
TorchDtype string `json:"torch_dtype"`
|
|
}
|
|
if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" {
|
|
torchDtype = cfg.TorchDtype
|
|
}
|
|
}
|
|
}
|
|
|
|
// Import config files
|
|
configFiles := []string{
|
|
"model_index.json",
|
|
"text_encoder/config.json",
|
|
"text_encoder/generation_config.json",
|
|
"transformer/config.json",
|
|
"vae/config.json",
|
|
"scheduler/scheduler_config.json",
|
|
"tokenizer/tokenizer.json",
|
|
"tokenizer/tokenizer_config.json",
|
|
"tokenizer/vocab.json",
|
|
}
|
|
|
|
for _, cfgPath := range configFiles {
|
|
fullPath := filepath.Join(modelDir, cfgPath)
|
|
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
|
continue
|
|
}
|
|
|
|
fn(fmt.Sprintf("importing config %s", cfgPath))
|
|
|
|
var r io.Reader
|
|
|
|
// For model_index.json, normalize to Ollama format and add metadata
|
|
if cfgPath == "model_index.json" {
|
|
data, err := os.ReadFile(fullPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read %s: %w", cfgPath, err)
|
|
}
|
|
|
|
var cfg map[string]any
|
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
return fmt.Errorf("failed to parse %s: %w", cfgPath, err)
|
|
}
|
|
|
|
// Rename _class_name to architecture, remove diffusers-specific fields
|
|
if className, ok := cfg["_class_name"]; ok {
|
|
cfg["architecture"] = className
|
|
delete(cfg, "_class_name")
|
|
}
|
|
delete(cfg, "_diffusers_version")
|
|
|
|
// Add parameter count (counted from tensor shapes during import)
|
|
cfg["parameter_count"] = totalParams
|
|
|
|
// Add quantization info - use quantize type if set, otherwise torch_dtype
|
|
if quantize != "" {
|
|
cfg["quantization"] = strings.ToUpper(quantize)
|
|
} else {
|
|
cfg["quantization"] = torchDtype
|
|
}
|
|
|
|
data, err = json.MarshalIndent(cfg, "", " ")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
|
|
}
|
|
r = bytes.NewReader(data)
|
|
} else {
|
|
f, err := os.Open(fullPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
|
|
}
|
|
defer f.Close()
|
|
r = f
|
|
}
|
|
|
|
layer, err := createLayer(r, "application/vnd.ollama.image.json", cfgPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
|
|
}
|
|
|
|
// Use model_index.json as the config layer
|
|
if cfgPath == "model_index.json" {
|
|
configLayer = layer
|
|
}
|
|
|
|
layers = append(layers, layer)
|
|
}
|
|
|
|
if configLayer.Digest == "" {
|
|
return fmt.Errorf("model_index.json not found in %s", modelDir)
|
|
}
|
|
|
|
fn(fmt.Sprintf("writing manifest for %s", modelName))
|
|
|
|
if err := writeManifest(modelName, configLayer, layers); err != nil {
|
|
return fmt.Errorf("failed to write manifest: %w", err)
|
|
}
|
|
|
|
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
|
return nil
|
|
}
|
|
|
|
// canQuantizeShape reports whether a tensor shape is compatible with MLX
// quantization: at least two dimensions, with the last dimension divisible
// by the quantization group size (nvfp4: 16, int8: 64, int4/mxfp8: 32).
func canQuantizeShape(shape []int32, quantize string) bool {
	if len(shape) < 2 {
		return false
	}
	var groupSize int32
	switch strings.ToUpper(quantize) {
	case "NVFP4":
		groupSize = 16
	case "INT8":
		groupSize = 64
	default:
		// int4, mxfp8, and anything unrecognized use the common group size.
		groupSize = 32
	}
	lastDim := shape[len(shape)-1]
	return lastDim%groupSize == 0
}
|