mirror of
https://github.com/ollama/ollama.git
synced 2026-04-23 01:05:47 +02:00
252 lines
7.3 KiB
Go
252 lines
7.3 KiB
Go
//go:build mlx
|
|
|
|
package manifest
|
|
|
|
import (
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// ManifestWeights provides fast weight loading from tensor blobs.
|
|
// Uses native mmap loading with synthetic safetensors headers for zero-copy.
|
|
type ManifestWeights struct {
|
|
manifest *ModelManifest
|
|
component string
|
|
tensors map[string]ManifestLayer // name -> layer
|
|
cache map[string]*mlx.Array // name -> loaded array
|
|
nativeCache []*mlx.SafetensorsFile // keep native handles alive
|
|
}
|
|
|
|
// LoadWeightsFromManifest creates a weight loader from manifest storage.
|
|
// If component is empty, loads all tensors (for LLM models).
|
|
// If component is specified, loads only tensors for that component and strips the prefix.
|
|
func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) {
|
|
layers := manifest.GetTensorLayers(component)
|
|
if len(layers) == 0 {
|
|
if component == "" {
|
|
return nil, fmt.Errorf("no tensor layers found in manifest")
|
|
}
|
|
return nil, fmt.Errorf("no tensor layers found for component %q", component)
|
|
}
|
|
|
|
// Strip component prefix from tensor names for model loading
|
|
// e.g., "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight"
|
|
tensors := make(map[string]ManifestLayer, len(layers))
|
|
for _, layer := range layers {
|
|
if component == "" {
|
|
tensors[layer.Name] = layer
|
|
} else {
|
|
tensorName := strings.TrimPrefix(layer.Name, component+"/")
|
|
tensors[tensorName] = layer
|
|
}
|
|
}
|
|
|
|
return &ManifestWeights{
|
|
manifest: manifest,
|
|
component: component,
|
|
tensors: tensors,
|
|
cache: make(map[string]*mlx.Array),
|
|
}, nil
|
|
}
|
|
|
|
// Load loads all tensor blobs using native mmap (zero-copy).
|
|
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
|
|
// If dtype is non-zero, tensors are converted to the specified dtype.
|
|
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
|
|
// Track native handles to free after batch eval
|
|
nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors))
|
|
arrays := make([]*mlx.Array, 0, len(mw.tensors))
|
|
|
|
for name, layer := range mw.tensors {
|
|
path := mw.manifest.BlobPath(layer.Digest)
|
|
|
|
// Load blob as safetensors (native mmap, zero-copy)
|
|
sf, err := mlx.LoadSafetensorsNative(path)
|
|
if err != nil {
|
|
// Free any handles we've accumulated
|
|
for _, h := range nativeHandles {
|
|
h.Free()
|
|
}
|
|
return fmt.Errorf("load %s: %w", name, err)
|
|
}
|
|
nativeHandles = append(nativeHandles, sf)
|
|
|
|
// Blob contains single tensor named "data"
|
|
arr := sf.Get("data")
|
|
if arr == nil {
|
|
for _, h := range nativeHandles {
|
|
h.Free()
|
|
}
|
|
return fmt.Errorf("tensor 'data' not found in blob for %s", name)
|
|
}
|
|
|
|
// Convert dtype if needed
|
|
if dtype != 0 && arr.Dtype() != dtype {
|
|
arr = mlx.AsType(arr, dtype)
|
|
}
|
|
// Make contiguous copy to ensure independence from mmap
|
|
arr = mlx.Contiguous(arr)
|
|
mw.cache[name] = arr
|
|
arrays = append(arrays, arr)
|
|
}
|
|
|
|
// Batch evaluate all tensors at once (much faster than one at a time)
|
|
mlx.Eval(arrays...)
|
|
|
|
// Now safe to free all native handles
|
|
for _, sf := range nativeHandles {
|
|
sf.Free()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetTensor returns a tensor from cache. Call Load() first.
|
|
func (mw *ManifestWeights) GetTensor(name string) (*mlx.Array, error) {
|
|
if mw.cache == nil {
|
|
return nil, fmt.Errorf("cache not initialized: call Load() first")
|
|
}
|
|
arr, ok := mw.cache[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("tensor %q not found", name)
|
|
}
|
|
return arr, nil
|
|
}
|
|
|
|
// ListTensors returns all tensor names in sorted order.
|
|
func (mw *ManifestWeights) ListTensors() []string {
|
|
names := make([]string, 0, len(mw.tensors))
|
|
for name := range mw.tensors {
|
|
names = append(names, name)
|
|
}
|
|
sort.Strings(names)
|
|
return names
|
|
}
|
|
|
|
// HasTensor checks if a tensor exists.
|
|
func (mw *ManifestWeights) HasTensor(name string) bool {
|
|
_, ok := mw.tensors[name]
|
|
return ok
|
|
}
|
|
|
|
// Quantization returns the model's quantization type from model_index.json.
|
|
// Returns empty string if not quantized.
|
|
// Falls back to detecting from tensor names and shapes if not in config.
|
|
func (mw *ManifestWeights) Quantization() string {
|
|
if mw.manifest == nil {
|
|
return ""
|
|
}
|
|
|
|
// Try to read from model_index.json first
|
|
var index struct {
|
|
Quantization string `json:"quantization"`
|
|
}
|
|
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.Quantization != "" {
|
|
return index.Quantization
|
|
}
|
|
|
|
// Fallback: detect from tensor names
|
|
// Check if any tensors have _scale suffix (indicates quantization)
|
|
hasScales := false
|
|
hasQBias := false
|
|
for name := range mw.tensors {
|
|
if strings.HasSuffix(name, ".weight_scale") {
|
|
hasScales = true
|
|
}
|
|
if strings.HasSuffix(name, ".weight_qbias") {
|
|
hasQBias = true
|
|
}
|
|
}
|
|
|
|
if !hasScales {
|
|
// No scales = not quantized
|
|
return ""
|
|
}
|
|
|
|
// Has scales but no qbias = NVFP4 (or other non-affine mode)
|
|
if !hasQBias {
|
|
return "NVFP4"
|
|
}
|
|
|
|
// Has both scales and qbias = affine mode
|
|
// Need to determine FP4 vs FP8 from tensor shapes
|
|
// FP4: weight last dim is 1/8 of scales last dim * group_size
|
|
// FP8: weight last dim is 1/4 of scales last dim * group_size
|
|
//
|
|
// For affine mode with group_size=32:
|
|
// - FP4 (4 bits): 8 elements packed per uint32, so weight_dim = orig_dim / 8
|
|
// - FP8 (8 bits): 4 elements packed per uint32, so weight_dim = orig_dim / 4
|
|
// scales_dim = orig_dim / group_size
|
|
// So: weight_dim / scales_dim = group_size / pack_factor
|
|
// FP4: ratio = 32/8 = 4
|
|
// FP8: ratio = 32/4 = 8
|
|
|
|
// Find a weight/scale pair to check the ratio
|
|
for name := range mw.tensors {
|
|
if !strings.HasSuffix(name, ".weight") || strings.Contains(name, "_scale") || strings.Contains(name, "_qbias") {
|
|
continue
|
|
}
|
|
scaleName := name + "_scale"
|
|
if _, ok := mw.tensors[scaleName]; !ok {
|
|
continue
|
|
}
|
|
|
|
// Load both tensors to check shapes
|
|
weightLayer := mw.tensors[name]
|
|
scaleLayer := mw.tensors[scaleName]
|
|
|
|
// Get shapes from manifest layer metadata if available
|
|
// For now, default to FP4 since it's more common
|
|
// The actual shape check would require loading the tensor
|
|
|
|
// Simple heuristic: check if scale tensor is ~4x smaller than weight
|
|
// FP4: weight is packed 8 per uint32, scales are 1 per group (32)
|
|
// So scale size should be ~weight_size * 8 / 32 = weight_size / 4
|
|
// FP8: weight is packed 4 per uint32, scales are 1 per group (32)
|
|
// So scale size should be ~weight_size * 4 / 32 = weight_size / 8
|
|
|
|
// Rough size heuristic (assuming float16 scales)
|
|
// Q4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8
|
|
// Q8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16
|
|
ratio := float64(weightLayer.Size) / float64(scaleLayer.Size)
|
|
if ratio < 12 {
|
|
// Closer to 8 = Q4
|
|
return "Q4"
|
|
}
|
|
// Closer to 16 = Q8
|
|
return "Q8"
|
|
}
|
|
|
|
// Default to Q4 for affine mode (most common)
|
|
return "Q4"
|
|
}
|
|
|
|
// GroupSize returns the quantization group size from model_index.json.
|
|
// Returns 0 if not specified (caller should use default based on quantization type).
|
|
func (mw *ManifestWeights) GroupSize() int {
|
|
if mw.manifest == nil {
|
|
return 0
|
|
}
|
|
|
|
var index struct {
|
|
GroupSize int `json:"group_size"`
|
|
}
|
|
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.GroupSize > 0 {
|
|
return index.GroupSize
|
|
}
|
|
|
|
return 0
|
|
}
|
|
|
|
// ReleaseAll frees all native handles and clears the tensor cache.
|
|
func (mw *ManifestWeights) ReleaseAll() {
|
|
for _, sf := range mw.nativeCache {
|
|
sf.Free()
|
|
}
|
|
mw.nativeCache = nil
|
|
mw.cache = nil
|
|
}
|