package client

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"

	"github.com/ollama/ollama/x/create"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model"
)

// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
// to the provided maps. If quantize is empty, the tensor is kept as-is (an FP8-encoded
// source is still decoded to float so decode-only callers receive usable data).
//
// Returns the temp file path created (caller must clean up — the path may be non-empty
// even when err is non-nil), the arrays that still need mlx.Eval, and the native
// safetensors handle backing the returned arrays. The handle is nil on every error
// path; on success the caller must Free it after evaluating the arrays.
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
	// Reject unknown quantization types up front, before any temp file is created.
	if quantize != "" {
		if gs, _, _ := model.QuantizationParams(quantize); gs == 0 {
			return "", nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
		}
	}

	// Spool the reader to a file on disk so MLX can load it natively.
	tmpDir := ensureTempDir()
	tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
	if err != nil {
		return "", nil, nil, fmt.Errorf("failed to create temp file: %w", err)
	}
	tmpPath = tmpFile.Name()
	if _, err := io.Copy(tmpFile, r); err != nil {
		tmpFile.Close()
		return tmpPath, nil, nil, fmt.Errorf("failed to write temp file for %s: %w", name, err)
	}
	tmpFile.Close()

	st, err := mlx.LoadSafetensorsNative(tmpPath)
	if err != nil {
		return tmpPath, nil, nil, fmt.Errorf("failed to load safetensors for %s: %w", name, err)
	}

	// Find the tensor key (may differ from name for single-tensor blobs).
	header, err := readSafetensorsHeader(tmpPath)
	if err != nil {
		st.Free()
		return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
	}
	inputKey, err := safetensorsKey(name, header)
	if err != nil {
		st.Free()
		return tmpPath, nil, nil, fmt.Errorf("failed to resolve tensor key for %s: %w", name, err)
	}
	arr := st.Get(inputKey)
	if arr == nil {
		st.Free()
		return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
	}

	// Decode FP8 source encoding before checking quantize, so that callers
	// requesting decode-only (quantize="") receive usable float data.
	if info, ok := header[inputKey]; ok && info.Dtype == "F8_E4M3" {
		// FP8 sources carry a companion block-scale tensor under "<name>.scale_inv".
		scaleKey := inputKey + ".scale_inv"
		scaleInv := st.Get(scaleKey)
		if scaleInv == nil {
			st.Free()
			return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q for fp8 source tensor %q", scaleKey, inputKey)
		}
		arr, err = decodeSourceFP8Tensor(arr, scaleInv)
		if err != nil {
			st.Free()
			return tmpPath, nil, nil, fmt.Errorf("failed to decode fp8 tensor %s: %w", inputKey, err)
		}
		mlx.Eval(arr)
	}

	if quantize == "" {
		// Decode-only path: hand back the (possibly fp8-decoded) tensor unchanged.
		arr = mlx.Contiguous(arr, false)
		arrays[name] = arr
		return tmpPath, []*mlx.Array{arr}, st, nil
	}

	if arr.DType() != mlx.DTypeBFloat16 && arr.DType() != mlx.DTypeFloat32 && arr.DType() != mlx.DTypeFloat16 {
		// Convert to float type if needed (quantize expects float).
		arr = arr.AsType(mlx.DTypeBFloat16)
		mlx.Eval(arr)
	}

	groupSize, bits, mode := model.QuantizationParams(quantize)
	qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
	qweight = mlx.Contiguous(qweight, false)
	scales = mlx.Contiguous(scales, false)
	arrays[name] = qweight
	arrays[name+".scale"] = scales
	toEval = append(toEval, qweight, scales)
	if qbiases != nil {
		// mlx.Quantize may also produce a bias tensor depending on the mode.
		qbiases = mlx.Contiguous(qbiases, false)
		arrays[name+".bias"] = qbiases
		toEval = append(toEval, qbiases)
	}
	return tmpPath, toEval, st, nil
}

// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
// Tensor keys use the original tensor name: name, name.scale, name.bias.
// The blob includes __metadata__ with quant_type and group_size.
// Supported quantization types: "int4", "nvfp4", "mxfp4", "int8", "mxfp8".
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) { arrays := make(map[string]*mlx.Array) tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays) if tmpPath != "" { defer os.Remove(tmpPath) } if err != nil { return nil, err } finalArrays := make([]*mlx.Array, 0, len(arrays)) for _, arr := range arrays { if arr != nil { finalArrays = append(finalArrays, arr) } } mlx.Pin(finalArrays...) defer func() { if st != nil { st.Free() } mlx.Unpin(finalArrays...) mlx.Sweep() }() mlx.Eval(toEval...) mlx.Sweep() // Free early to release mmap; defer guard handles error paths if st != nil { st.Free() st = nil } // Build metadata for single-tensor blobs groupSize, _, _ := model.QuantizationParams(quantize) metadata := map[string]string{ "quant_type": quantize, "group_size": strconv.Itoa(groupSize), } tmpDir := ensureTempDir() outPath := filepath.Join(tmpDir, "combined.safetensors") defer os.Remove(outPath) if err := mlx.SaveSafetensorsWithMetadata(outPath, arrays, metadata); err != nil { return nil, fmt.Errorf("failed to save combined blob: %w", err) } return os.ReadFile(outPath) } // quantizePackedGroup quantizes multiple tensors and saves them all into a single // combined safetensors blob. Used for packing expert groups. // When the inputs are per-expert 2D tensors (e.g., experts.0.gate_proj.weight), // they are stacked into 3D switch_mlp tensors before quantization. // Each tensor may have a different quantization type (mixed-precision). // Returns the blob bytes. 
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) { // Check if inputs are per-expert tensors that should be stacked into 3D if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil { return stackAndQuantizeExpertGroup(groupName, projGroups, quantize) } allArrays := make(map[string]*mlx.Array) var pinned []*mlx.Array var metadata map[string]string uniformQuantize := "" hasQuantized := false mixedQuantize := false for _, input := range inputs { if input.Quantize == "" { if hasQuantized { mixedQuantize = true } continue } if !hasQuantized { hasQuantized = true uniformQuantize = input.Quantize continue } if input.Quantize != uniformQuantize { mixedQuantize = true } } if hasQuantized && !mixedQuantize { if groupSize, _, _ := model.QuantizationParams(uniformQuantize); groupSize > 0 { metadata = map[string]string{ "quant_type": uniformQuantize, "group_size": strconv.Itoa(groupSize), } } } for _, input := range inputs { tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays) if err != nil { mlx.Unpin(pinned...) mlx.Sweep() return nil, err } mlx.Eval(toEval...) finalArrays := arraysForPackedInput(allArrays, input) mlx.Pin(finalArrays...) pinned = append(pinned, finalArrays...) if st != nil { st.Free() } if tmpPath != "" { os.Remove(tmpPath) } mlx.Sweep() } defer func() { mlx.Unpin(pinned...) mlx.Sweep() }() // Save combined blob. Add global metadata only when every packed tensor uses // the same quantization mode and group size. 
tmpDir := ensureTempDir() outPath := filepath.Join(tmpDir, "packed-combined.safetensors") defer os.Remove(outPath) if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil { return nil, fmt.Errorf("failed to save packed blob: %w", err) } blobData, err := os.ReadFile(outPath) if err != nil { return nil, fmt.Errorf("failed to read packed blob: %w", err) } return blobData, nil } func arraysForPackedInput(allArrays map[string]*mlx.Array, input create.PackedTensorInput) []*mlx.Array { keys := []string{input.Name} if input.Quantize != "" { keys = append(keys, input.Name+".scale", input.Name+".bias") } out := make([]*mlx.Array, 0, len(keys)) for _, key := range keys { if arr := allArrays[key]; arr != nil { out = append(out, arr) } } return out } // perExpertSuffix matches ".{index}.{proj_and_suffix}" after the group prefix. var perExpertSuffix = regexp.MustCompile(`^\.(\d+)\.(.+)$`) type expertTensorInfo struct { index int proj string // e.g., "gate_proj.weight" input create.PackedTensorInput } // parsePerExpertInputs groups per-expert 2D tensor inputs by projection type // and returns the uniform quantization type shared by all inputs. // Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D) // or if the inputs have mixed quantization types. // Only handles ".experts" groups; ".shared_experts" groups are left unpacked. 
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) { if !strings.HasSuffix(groupName, ".experts") { return nil, "" } quantize := inputs[0].Quantize groups := make(map[string][]expertTensorInfo) for _, input := range inputs { if input.Quantize != quantize { return nil, "" // mixed quantization types } suffix := strings.TrimPrefix(input.Name, groupName) m := perExpertSuffix.FindStringSubmatch(suffix) if m == nil { return nil, "" // not a per-expert pattern } index, err := strconv.Atoi(m[1]) if err != nil { return nil, "" } groups[m[2]] = append(groups[m[2]], expertTensorInfo{ index: index, proj: m[2], input: input, }) } if len(groups) == 0 { return nil, "" } return groups, quantize } // stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D // switch_mlp tensors, quantizes, and returns the combined safetensors blob. func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) { groupBase := strings.TrimSuffix(groupName, ".experts") allArrays := make(map[string]*mlx.Array) var pinned []*mlx.Array var metadata map[string]string if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" { metadata = map[string]string{ "quant_type": quantize, "group_size": strconv.Itoa(groupSize), } } // Sort projection names for deterministic output projNames := make([]string, 0, len(projGroups)) for proj := range projGroups { projNames = append(projNames, proj) } sort.Strings(projNames) cleanup := func() { mlx.Unpin(pinned...) 
mlx.Sweep() } for _, proj := range projNames { experts := projGroups[proj] // Sort by expert index sort.Slice(experts, func(i, j int) bool { return experts[i].index < experts[j].index }) // Load and decode each expert tensor var decoded []*mlx.Array for _, expert := range experts { dummyArrays := make(map[string]*mlx.Array) tmpPath, toEval, st, err := loadAndQuantizeArray(expert.input.Reader, expert.input.Name, "", dummyArrays) if err != nil { cleanup() return nil, fmt.Errorf("failed to decode expert tensor %s: %w", expert.input.Name, err) } mlx.Eval(toEval...) arr := dummyArrays[expert.input.Name] mlx.Pin(arr) pinned = append(pinned, arr) decoded = append(decoded, arr) if st != nil { st.Free() } if tmpPath != "" { os.Remove(tmpPath) } mlx.Sweep() } // Stack into 3D along axis 0: [numExperts, rows, cols] stacked := mlx.Stack(decoded, 0) mlx.Eval(stacked) mlx.Pin(stacked) pinned = append(pinned, stacked) // Free individual decoded arrays mlx.Unpin(decoded...) mlx.Sweep() stackedName := groupBase + ".switch_mlp." + proj // Quantize the stacked tensor if quantize != "" { groupSize, bits, mode := model.QuantizationParams(quantize) qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode) qweight = mlx.Contiguous(qweight, false) scales = mlx.Contiguous(scales, false) allArrays[stackedName] = qweight allArrays[stackedName+".scale"] = scales toEval := []*mlx.Array{qweight, scales} if qbiases != nil { qbiases = mlx.Contiguous(qbiases, false) allArrays[stackedName+".bias"] = qbiases toEval = append(toEval, qbiases) } mlx.Eval(toEval...) mlx.Pin(toEval...) pinned = append(pinned, toEval...) 
// Free stacked source array mlx.Unpin(stacked) mlx.Sweep() } else { stacked = mlx.Contiguous(stacked, false) mlx.Eval(stacked) allArrays[stackedName] = stacked } } defer cleanup() tmpDir := ensureTempDir() outPath := filepath.Join(tmpDir, "stacked-combined.safetensors") defer os.Remove(outPath) if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil { return nil, fmt.Errorf("failed to save stacked blob: %w", err) } blobData, err := os.ReadFile(outPath) if err != nil { return nil, fmt.Errorf("failed to read stacked blob: %w", err) } return blobData, nil } // QuantizeSupported returns true if quantization is supported (MLX library available) func QuantizeSupported() bool { return mlx.CheckInit() == nil } // ensureTempDir creates the temp directory for quantization if it doesn't exist func ensureTempDir() string { tmpDir := filepath.Join(os.TempDir(), "ollama-quantize") os.MkdirAll(tmpDir, 0755) return tmpDir } type safetensorsHeaderEntry struct { Dtype string `json:"dtype"` Shape []int32 `json:"shape"` } func readSafetensorsHeader(path string) (map[string]safetensorsHeaderEntry, error) { f, err := os.Open(path) if err != nil { return nil, err } defer f.Close() var headerSize uint64 if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil { return nil, err } headerBytes := make([]byte, headerSize) if _, err := io.ReadFull(f, headerBytes); err != nil { return nil, err } var header map[string]safetensorsHeaderEntry if err := json.Unmarshal(headerBytes, &header); err != nil { return nil, err } return header, nil } // safetensorsKey resolves the primary tensor key from a header. 
func safetensorsKey(preferred string, header map[string]safetensorsHeaderEntry) (string, error) {
	// Use the caller's preferred name when the header actually contains it.
	if preferred != "" {
		if _, ok := header[preferred]; ok {
			return preferred, nil
		}
	}
	// Otherwise fall back to the lexicographically first real tensor key,
	// skipping the metadata entry and fp8 companion scale tensors.
	keys := make([]string, 0, len(header))
	for k := range header {
		if k == "__metadata__" || strings.HasSuffix(k, ".scale_inv") {
			continue
		}
		keys = append(keys, k)
	}
	sort.Strings(keys)
	if len(keys) == 0 {
		return "", fmt.Errorf("no tensor found in safetensors header")
	}
	return keys[0], nil
}

// decodeSourceFP8Tensor dequantizes a block-scaled fp8 (E4M3) weight back to
// bfloat16 using its companion inverse-scale tensor. scaleInv holds one scale
// per 128x128 block of the weight: the weight is padded up to whole blocks,
// each block is multiplied by its broadcast scale, and the result is sliced
// back to the weight's true shape.
func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
	if weight == nil || scaleInv == nil {
		return nil, fmt.Errorf("fp8 weight and scale tensors are required")
	}
	weightShape := weight.Dims()
	scaleShape := scaleInv.Dims()
	if len(weightShape) != 2 || len(scaleShape) != 2 {
		return nil, fmt.Errorf("expected 2D fp8 weight and scale tensors, got %v and %v", weightShape, scaleShape)
	}
	// These must match the block size validated by resolveEffectiveQuantization
	// in create.go, which rejects any source model with a different block size.
	const blockRows = 128
	const blockCols = 128
	rows, cols := weightShape[0], weightShape[1]
	// One scale per block, rounding up to cover partial edge blocks.
	expectedScaleRows := (rows + blockRows - 1) / blockRows
	expectedScaleCols := (cols + blockCols - 1) / blockCols
	if scaleShape[0] != expectedScaleRows || scaleShape[1] != expectedScaleCols {
		return nil, fmt.Errorf(
			"unexpected fp8 scale shape %v for weight shape %v; want [%d %d]",
			scaleShape, weightShape, expectedScaleRows, expectedScaleCols,
		)
	}
	decoded := mlx.FromFP8(weight, mlx.DTypeBFloat16)
	// Pad the bottom/right edges so both dimensions become whole multiples of
	// the block size.
	// NOTE(review): assumes mlx.Pad interprets the slice as (low, high) pairs
	// per axis — confirm against the mlx bindings.
	padBottom := blockRows*scaleShape[0] - rows
	padSide := blockCols*scaleShape[1] - cols
	if padBottom > 0 || padSide > 0 {
		decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
	}
	// Reshape to [scaleRows, blockRows, scaleCols, blockCols] so each scale
	// (expanded to [scaleRows, 1, scaleCols, 1]) broadcasts over its block,
	// then flatten back to 2D.
	decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
	decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scaleInv, 1), 3))
	decoded = mlx.Reshape(decoded, int32(rows+padBottom), int32(cols+padSide))
	// Trim any padding so the result matches the original weight shape.
	if padBottom > 0 || padSide > 0 {
		decoded = mlx.SliceStartStop(decoded, []int32{0, 0}, []int32{int32(rows), int32(cols)})
	}
	return decoded, nil
}