package create

import (
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// ModelConfig represents the config blob stored with a model.
type ModelConfig struct {
	// ModelFormat is the storage format recorded for the model, e.g. "safetensors".
	ModelFormat string `json:"model_format"`
	// Capabilities lists what the model can do, e.g. "completion" or "image".
	Capabilities []string `json:"capabilities"`
}

// Manifest represents the manifest JSON structure.
type Manifest struct {
	SchemaVersion int             `json:"schemaVersion"`
	MediaType     string          `json:"mediaType"`
	Config        ManifestLayer   `json:"config"`
	Layers        []ManifestLayer `json:"layers"`
}

// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
	MediaType string `json:"mediaType"`
	Digest    string `json:"digest"`
	Size      int64  `json:"size"`
	Name      string `json:"name,omitempty"`
}

// defaultManifestDir returns the manifest storage directory.
func defaultManifestDir() string {
	return filepath.Join(envconfig.Models(), "manifests")
}

// defaultBlobDir returns the blob storage directory.
func defaultBlobDir() string {
	return filepath.Join(envconfig.Models(), "blobs")
}

// resolveManifestPath converts a model name to a manifest file path.
//
// The name has the form [host/][namespace/]name[:tag]. Missing parts default
// to host "registry.ollama.ai", namespace "library" and tag "latest". The
// result is <manifest dir>/<host>/<namespace>/<name>/<tag>.
func resolveManifestPath(modelName string) string {
	host := "registry.ollama.ai"
	namespace := "library"
	name := modelName
	tag := "latest"

	// A ":" introduces a tag only when it comes after the last "/". An earlier
	// ":" belongs to a host:port prefix (e.g. "localhost:5000/library/model")
	// and must stay part of the host instead of being parsed as the tag.
	if idx := strings.LastIndex(name, ":"); idx > strings.LastIndex(name, "/") {
		tag = name[idx+1:]
		name = name[:idx]
	}

	parts := strings.Split(name, "/")
	switch len(parts) {
	case 3:
		host = parts[0]
		namespace = parts[1]
		name = parts[2]
	case 2:
		namespace = parts[0]
		name = parts[1]
	}

	return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
}

// loadManifest loads a manifest for the given model name.
func loadManifest(modelName string) (*Manifest, error) { manifestPath := resolveManifestPath(modelName) data, err := os.ReadFile(manifestPath) if err != nil { return nil, err } var manifest Manifest if err := json.Unmarshal(data, &manifest); err != nil { return nil, err } return &manifest, nil } // loadModelConfig loads the config blob for a model. func loadModelConfig(modelName string) (*ModelConfig, error) { manifest, err := loadManifest(modelName) if err != nil { return nil, err } // Read the config blob blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1) blobPath := filepath.Join(defaultBlobDir(), blobName) data, err := os.ReadFile(blobPath) if err != nil { return nil, err } var config ModelConfig if err := json.Unmarshal(data, &config); err != nil { return nil, err } return &config, nil } // IsSafetensorsModel checks if a model was created with the experimental // safetensors builder by checking the model format in the config. func IsSafetensorsModel(modelName string) bool { config, err := loadModelConfig(modelName) if err != nil { return false } return config.ModelFormat == "safetensors" } // IsSafetensorsLLMModel checks if a model is a safetensors LLM model // (has completion capability, not image generation). func IsSafetensorsLLMModel(modelName string) bool { config, err := loadModelConfig(modelName) if err != nil { return false } return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion") } // IsImageGenModel checks if a model is an image generation model // (has image capability). func IsImageGenModel(modelName string) bool { config, err := loadModelConfig(modelName) if err != nil { return false } return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image") } // GetModelArchitecture returns the architecture from the model's config.json layer. 
func GetModelArchitecture(modelName string) (string, error) { manifest, err := loadManifest(modelName) if err != nil { return "", err } // Find the config.json layer for _, layer := range manifest.Layers { if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" { blobName := strings.Replace(layer.Digest, ":", "-", 1) blobPath := filepath.Join(defaultBlobDir(), blobName) data, err := os.ReadFile(blobPath) if err != nil { return "", err } var cfg struct { Architectures []string `json:"architectures"` ModelType string `json:"model_type"` } if err := json.Unmarshal(data, &cfg); err != nil { return "", err } // Prefer model_type, fall back to first architecture if cfg.ModelType != "" { return cfg.ModelType, nil } if len(cfg.Architectures) > 0 { return cfg.Architectures[0], nil } } } return "", fmt.Errorf("architecture not found in model config") } // IsTensorModelDir checks if the directory contains a diffusers-style tensor model // by looking for model_index.json, which is the standard diffusers pipeline config. func IsTensorModelDir(dir string) bool { _, err := os.Stat(filepath.Join(dir, "model_index.json")) return err == nil } // IsSafetensorsModelDir checks if the directory contains a standard safetensors model // by looking for config.json and at least one .safetensors file. func IsSafetensorsModelDir(dir string) bool { // Must have config.json if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil { return false } // Must have at least one .safetensors file entries, err := os.ReadDir(dir) if err != nil { return false } for _, entry := range entries { if strings.HasSuffix(entry.Name(), ".safetensors") { return true } } return false } // LayerInfo holds metadata for a created layer. type LayerInfo struct { Digest string Size int64 MediaType string Name string // Path-style name: "component/tensor" or "path/to/config.json" } // LayerCreator is called to create a blob layer. 
// name is the path-style name (e.g., "tokenizer/tokenizer.json") type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error) // TensorLayerCreator creates a tensor blob layer with metadata. // name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight") type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error) // QuantizingTensorLayerCreator creates tensor layers with optional quantization. // When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases). type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) // ManifestWriter writes the manifest file. type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error // ShouldQuantize returns true if a tensor should be quantized. // For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms. // For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors. func ShouldQuantize(name, component string) bool { // Image gen specific: skip VAE entirely if component == "vae" { return false } // Skip embeddings if strings.Contains(name, "embed") { return false } // Skip layer norms and RMS norms if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") { return false } // Skip biases if strings.HasSuffix(name, ".bias") { return false } // Only quantize weights return strings.HasSuffix(name, ".weight") } // ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type. // This is a more detailed check that also considers tensor dimensions. // The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8"). 
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
	assigned := GetTensorQuantization(name, shape, quantize)
	return assigned != ""
}

// normalizeQuantType maps a quantization type alias to its canonical
// lowercase form, ignoring case:
//
//	q4, int4, fp4 -> "q4"
//	q8, int8, fp8 -> "q8"
//	nvfp4         -> "nvfp4"
//	mxfp8         -> "mxfp8"
//
// Any other value is returned unchanged.
func normalizeQuantType(quantize string) string {
	upper := strings.ToUpper(quantize)
	switch upper {
	case "NVFP4", "MXFP8":
		return strings.ToLower(upper)
	case "Q4", "INT4", "FP4":
		return "q4"
	case "Q8", "INT8", "FP8":
		return "q8"
	}
	return quantize
}

// getQuantGroupSize returns the group size for a given quantization type.
// These must match the values used in quantize.go when creating quantized models.
func getQuantGroupSize(quantize string) int {
	switch normalizeQuantType(quantize) {
	case "nvfp4":
		return 16
	case "q8":
		return 64
	}
	// q4, mxfp8 and unrecognized types all use 32.
	return 32
}

// GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization: // - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive) // - Output projection, gate/up weights: q4 (less sensitive) // - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel) // - Norms, embeddings, biases, routing gates: no quantization func GetTensorQuantization(name string, shape []int32, quantize string) string { // Use basic name-based check first if !ShouldQuantize(name, "") { return "" } // Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any) if len(shape) != 2 { return "" } // Skip small tensors (less than 1024 elements) - not worth quantizing if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 { return "" } // Normalize quantization type to canonical form quantNorm := normalizeQuantType(quantize) // MLX quantization requires last dimension to be divisible by group size // nvfp4: 16, q4/mxfp8: 32, q8: 64 groupSize := int32(32) switch quantNorm { case "nvfp4": groupSize = 16 case "q8": groupSize = 64 } if shape[len(shape)-1]%groupSize != 0 { return "" } // Skip routing gate weights (should stay high precision) // In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight) if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") { return "" } // For NVFP4 or MXFP8, use the same quantization for all (no mixed precision) if quantNorm == "nvfp4" || quantNorm == "mxfp8" { return quantNorm } // Attention MLA weights - keep unquantized (bf16) // These are highly sensitive: errors accumulate in the KV cache over time // q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj if strings.Contains(name, "q_a_proj") || strings.Contains(name, "q_b_proj") || strings.Contains(name, "kv_a_proj") || strings.Contains(name, "kv_b_proj") { return "" // No quantization - keep bf16 } // Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel) // mlp.down_proj, 
mlp.experts.X.down_proj, mlp.shared_experts.down_proj if strings.Contains(name, "down_proj") { return "q8" } // Output projection, gate/up weights - use requested quantization (Q4) // o_proj, gate_proj, up_proj if strings.Contains(name, "o_proj") || strings.Contains(name, "gate_proj") || strings.Contains(name, "up_proj") { return quantNorm } // LM head - use requested quantization if strings.Contains(name, "lm_head") { return quantNorm } // Default to requested quantization for other weights return quantNorm } // CreateSafetensorsModel imports a standard safetensors model from a directory. // This handles Hugging Face style models with config.json and *.safetensors files. // Stores each tensor as a separate blob for fine-grained deduplication. // If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized. func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error { var layers []LayerInfo var configLayer LayerInfo entries, err := os.ReadDir(modelDir) if err != nil { return fmt.Errorf("failed to read directory: %w", err) } // Process all safetensors files for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") { continue } stPath := filepath.Join(modelDir, entry.Name()) // Extract individual tensors from safetensors file extractor, err := safetensors.OpenForExtraction(stPath) if err != nil { return fmt.Errorf("failed to open %s: %w", stPath, err) } tensorNames := extractor.ListTensors() quantizeMsg := "" if quantize != "" { quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize) } fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg)) for _, tensorName := range tensorNames { td, err := extractor.GetTensor(tensorName) if err != nil { extractor.Close() return fmt.Errorf("failed to get tensor %s: %w", tensorName, err) } // Determine 
quantization type for this tensor (empty string if not quantizing) // GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN) quantizeType := "" if quantize != "" { quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize) } // Store as minimal safetensors format (88 bytes header overhead) // This enables native mmap loading via mlx_load_safetensors // createTensorLayer returns multiple layers if quantizing (weight + scales) newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType) if err != nil { extractor.Close() return fmt.Errorf("failed to create layer for %s: %w", tensorName, err) } layers = append(layers, newLayers...) } extractor.Close() } // Process all JSON config files for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") { continue } // Skip the index file as we don't need it after extraction if entry.Name() == "model.safetensors.index.json" { continue } cfgPath := entry.Name() fullPath := filepath.Join(modelDir, cfgPath) fn(fmt.Sprintf("importing config %s", cfgPath)) f, err := os.Open(fullPath) if err != nil { return fmt.Errorf("failed to open %s: %w", cfgPath, err) } layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath) f.Close() if err != nil { return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err) } // Use config.json as the config layer if cfgPath == "config.json" { configLayer = layer } layers = append(layers, layer) } if configLayer.Digest == "" { return fmt.Errorf("config.json not found in %s", modelDir) } // Create model_index.json with quantization info if quantizing if quantize != "" { modelIndex := map[string]any{ "quantization": strings.ToUpper(quantize), "group_size": getQuantGroupSize(quantize), } indexData, err := json.MarshalIndent(modelIndex, "", " ") if err != nil { return fmt.Errorf("failed to marshal model_index.json: %w", err) } indexLayer, err := 
createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json") if err != nil { return fmt.Errorf("failed to create model_index.json layer: %w", err) } layers = append(layers, indexLayer) } fn(fmt.Sprintf("writing manifest for %s", modelName)) if err := writeManifest(modelName, configLayer, layers); err != nil { return fmt.Errorf("failed to write manifest: %w", err) } fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers))) return nil }