Files
ollama/x/create/client/quantize.go
Patrick Devine de5cb7311f mlx: add mxfp4/mxfp8/nvfp4 importing (#15015)
This change allows importing bf16 and converting to mxfp4/mxfp8/nvfp4
and also importing fp8 and converting directly to mxfp8.
2026-03-24 13:45:44 -07:00

544 lines
16 KiB
Go

package client
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"github.com/ollama/ollama/x/create"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
)
// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
// to the provided maps. If quantize is empty, the tensor is kept as-is.
// Returns any temp file paths created (caller must clean up) and arrays needing eval.
// On error, tmpPath may still be non-empty; the caller is responsible for removing it.
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
	// Validate the quantization type up front, before creating any temp files.
	if quantize != "" {
		if gs, _, _ := model.QuantizationParams(quantize); gs == 0 {
			return "", nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
		}
	}
	tmpDir := ensureTempDir()
	tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
	if err != nil {
		return "", nil, nil, fmt.Errorf("failed to create temp file: %w", err)
	}
	tmpPath = tmpFile.Name()
	if _, err := io.Copy(tmpFile, r); err != nil {
		tmpFile.Close()
		return tmpPath, nil, nil, fmt.Errorf("failed to write temp file for %s: %w", name, err)
	}
	// Check the Close error: a failed flush here would otherwise surface later
	// as a confusing "failed to load safetensors" on a truncated file.
	if err := tmpFile.Close(); err != nil {
		return tmpPath, nil, nil, fmt.Errorf("failed to finalize temp file for %s: %w", name, err)
	}
	st, err := mlx.LoadSafetensorsNative(tmpPath)
	if err != nil {
		return tmpPath, nil, nil, fmt.Errorf("failed to load safetensors for %s: %w", name, err)
	}
	// Find the tensor key (may differ from name for single-tensor blobs)
	header, err := readSafetensorsHeader(tmpPath)
	if err != nil {
		st.Free()
		return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
	}
	inputKey, err := safetensorsKey(name, header)
	if err != nil {
		st.Free()
		return tmpPath, nil, nil, fmt.Errorf("failed to resolve tensor key for %s: %w", name, err)
	}
	arr := st.Get(inputKey)
	if arr == nil {
		st.Free()
		return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
	}
	// Decode FP8 source encoding before checking quantize, so that callers
	// requesting decode-only (quantize="") receive usable float data.
	if info, ok := header[inputKey]; ok && info.Dtype == "F8_E4M3" {
		scaleKey := inputKey + ".scale_inv"
		scaleInv := st.Get(scaleKey)
		if scaleInv == nil {
			st.Free()
			return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q for fp8 source tensor %q", scaleKey, inputKey)
		}
		arr, err = decodeSourceFP8Tensor(arr, scaleInv)
		if err != nil {
			st.Free()
			return tmpPath, nil, nil, fmt.Errorf("failed to decode fp8 tensor %s: %w", inputKey, err)
		}
		mlx.Eval(arr)
	}
	if quantize == "" {
		arr = mlx.Contiguous(arr, false)
		arrays[name] = arr
		return tmpPath, []*mlx.Array{arr}, st, nil
	}
	if arr.DType() != mlx.DTypeBFloat16 && arr.DType() != mlx.DTypeFloat32 && arr.DType() != mlx.DTypeFloat16 {
		// Convert to float type if needed (quantize expects float)
		arr = arr.AsType(mlx.DTypeBFloat16)
		mlx.Eval(arr)
	}
	groupSize, bits, mode := model.QuantizationParams(quantize)
	qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
	qweight = mlx.Contiguous(qweight, false)
	scales = mlx.Contiguous(scales, false)
	arrays[name] = qweight
	arrays[name+".scale"] = scales
	toEval = append(toEval, qweight, scales)
	// Some quantization modes (e.g. affine int4/int8) produce a bias tensor.
	if qbiases != nil {
		qbiases = mlx.Contiguous(qbiases, false)
		arrays[name+".bias"] = qbiases
		toEval = append(toEval, qbiases)
	}
	return tmpPath, toEval, st, nil
}
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
// Tensor keys use the original tensor name: name, name.scale, name.bias.
// The blob includes __metadata__ with quant_type and group_size.
// Supported quantization types: "int4", "nvfp4", "mxfp4", "int8", "mxfp8".
//
// dtype and shape are currently unused; they are kept for interface
// compatibility with callers.
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
	arrays := make(map[string]*mlx.Array)
	tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
	if tmpPath != "" {
		defer os.Remove(tmpPath)
	}
	if err != nil {
		return nil, err
	}
	finalArrays := make([]*mlx.Array, 0, len(arrays))
	for _, arr := range arrays {
		if arr != nil {
			finalArrays = append(finalArrays, arr)
		}
	}
	// Pin results so Sweep cannot reclaim them before the blob is written.
	mlx.Pin(finalArrays...)
	defer func() {
		if st != nil {
			st.Free()
		}
		mlx.Unpin(finalArrays...)
		mlx.Sweep()
	}()
	mlx.Eval(toEval...)
	mlx.Sweep()
	// Free early to release mmap; defer guard handles error paths
	if st != nil {
		st.Free()
		st = nil
	}
	// Build metadata for single-tensor blobs
	groupSize, _, _ := model.QuantizationParams(quantize)
	metadata := map[string]string{
		"quant_type": quantize,
		"group_size": strconv.Itoa(groupSize),
	}
	tmpDir := ensureTempDir()
	// Use a unique temp file: a fixed name would be clobbered by concurrent
	// quantizeTensor calls sharing the same temp directory.
	outFile, err := os.CreateTemp(tmpDir, "combined-*.safetensors")
	if err != nil {
		return nil, fmt.Errorf("failed to create output file: %w", err)
	}
	outPath := outFile.Name()
	outFile.Close()
	defer os.Remove(outPath)
	if err := mlx.SaveSafetensorsWithMetadata(outPath, arrays, metadata); err != nil {
		return nil, fmt.Errorf("failed to save combined blob: %w", err)
	}
	return os.ReadFile(outPath)
}
// quantizePackedGroup quantizes multiple tensors and saves them all into a single
// combined safetensors blob. Used for packing expert groups.
// When the inputs are per-expert 2D tensors (e.g., experts.0.gate_proj.weight),
// they are stacked into 3D switch_mlp tensors before quantization.
// Each tensor may have a different quantization type (mixed-precision).
// Returns the blob bytes.
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
	// Check if inputs are per-expert tensors that should be stacked into 3D
	if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
		return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
	}
	allArrays := make(map[string]*mlx.Array)
	var pinned []*mlx.Array
	var metadata map[string]string
	// Determine whether all quantized inputs share a single quantization type;
	// only then is it valid to stamp global quant_type/group_size metadata.
	uniformQuantize := ""
	hasQuantized := false
	mixedQuantize := false
	for _, input := range inputs {
		if input.Quantize == "" {
			if hasQuantized {
				mixedQuantize = true
			}
			continue
		}
		if !hasQuantized {
			hasQuantized = true
			uniformQuantize = input.Quantize
			continue
		}
		if input.Quantize != uniformQuantize {
			mixedQuantize = true
		}
	}
	if hasQuantized && !mixedQuantize {
		if groupSize, _, _ := model.QuantizationParams(uniformQuantize); groupSize > 0 {
			metadata = map[string]string{
				"quant_type": uniformQuantize,
				"group_size": strconv.Itoa(groupSize),
			}
		}
	}
	for _, input := range inputs {
		tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
		if err != nil {
			// loadAndQuantizeArray may have created a temp file even on
			// failure; remove it so error paths don't leak scratch files.
			if tmpPath != "" {
				os.Remove(tmpPath)
			}
			mlx.Unpin(pinned...)
			mlx.Sweep()
			return nil, err
		}
		mlx.Eval(toEval...)
		finalArrays := arraysForPackedInput(allArrays, input)
		mlx.Pin(finalArrays...)
		pinned = append(pinned, finalArrays...)
		if st != nil {
			st.Free()
		}
		if tmpPath != "" {
			os.Remove(tmpPath)
		}
		mlx.Sweep()
	}
	defer func() {
		mlx.Unpin(pinned...)
		mlx.Sweep()
	}()
	// Save combined blob. Add global metadata only when every packed tensor uses
	// the same quantization mode and group size.
	tmpDir := ensureTempDir()
	// Use a unique temp file: a fixed name would collide across concurrent
	// packing jobs sharing the same temp directory.
	outFile, err := os.CreateTemp(tmpDir, "packed-*.safetensors")
	if err != nil {
		return nil, fmt.Errorf("failed to create output file: %w", err)
	}
	outPath := outFile.Name()
	outFile.Close()
	defer os.Remove(outPath)
	if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
		return nil, fmt.Errorf("failed to save packed blob: %w", err)
	}
	blobData, err := os.ReadFile(outPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read packed blob: %w", err)
	}
	return blobData, nil
}
// arraysForPackedInput collects the output arrays produced for a single packed
// input: the weight itself, plus the .scale and .bias companions when the
// input was quantized. Missing or nil entries are skipped.
func arraysForPackedInput(allArrays map[string]*mlx.Array, input create.PackedTensorInput) []*mlx.Array {
	candidates := []string{input.Name}
	if input.Quantize != "" {
		candidates = append(candidates,
			input.Name+".scale",
			input.Name+".bias",
		)
	}
	found := make([]*mlx.Array, 0, len(candidates))
	for _, candidate := range candidates {
		arr, ok := allArrays[candidate]
		if !ok || arr == nil {
			continue
		}
		found = append(found, arr)
	}
	return found
}
// perExpertSuffix matches ".{index}.{proj_and_suffix}" after the group prefix,
// e.g. ".0.gate_proj.weight" -> index "0", proj "gate_proj.weight".
// Compiled once at package scope so it isn't rebuilt on every call.
var perExpertSuffix = regexp.MustCompile(`^\.(\d+)\.(.+)$`)

// expertTensorInfo associates one per-expert input tensor with its expert
// index and projection name, so experts can be sorted and stacked by index.
type expertTensorInfo struct {
	index int    // expert index parsed from the tensor name
	proj  string // e.g., "gate_proj.weight"
	input create.PackedTensorInput
}
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
// and returns the uniform quantization type shared by all inputs.
// Returns nil if the inputs are empty, not per-expert tensors (e.g., already
// stacked 3D), or have mixed quantization types.
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
	if !strings.HasSuffix(groupName, ".experts") {
		return nil, ""
	}
	// Guard before indexing: an empty input slice would otherwise panic on
	// inputs[0] below.
	if len(inputs) == 0 {
		return nil, ""
	}
	quantize := inputs[0].Quantize
	groups := make(map[string][]expertTensorInfo)
	for _, input := range inputs {
		if input.Quantize != quantize {
			return nil, "" // mixed quantization types
		}
		suffix := strings.TrimPrefix(input.Name, groupName)
		m := perExpertSuffix.FindStringSubmatch(suffix)
		if m == nil {
			return nil, "" // not a per-expert pattern
		}
		index, err := strconv.Atoi(m[1])
		if err != nil {
			return nil, ""
		}
		groups[m[2]] = append(groups[m[2]], expertTensorInfo{
			index: index,
			proj:  m[2],
			input: input,
		})
	}
	if len(groups) == 0 {
		return nil, ""
	}
	return groups, quantize
}
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
	groupBase := strings.TrimSuffix(groupName, ".experts")
	allArrays := make(map[string]*mlx.Array)
	var pinned []*mlx.Array
	var metadata map[string]string
	if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
		metadata = map[string]string{
			"quant_type": quantize,
			"group_size": strconv.Itoa(groupSize),
		}
	}
	// Sort projection names for deterministic output
	projNames := make([]string, 0, len(projGroups))
	for proj := range projGroups {
		projNames = append(projNames, proj)
	}
	sort.Strings(projNames)
	cleanup := func() {
		// NOTE(review): arrays unpinned mid-loop below remain in pinned and get
		// unpinned again here — presumably Unpin tolerates that; verify.
		mlx.Unpin(pinned...)
		mlx.Sweep()
	}
	for _, proj := range projNames {
		experts := projGroups[proj]
		// Sort by expert index
		sort.Slice(experts, func(i, j int) bool {
			return experts[i].index < experts[j].index
		})
		// Load and decode each expert tensor (quantize="" means decode-only)
		var decoded []*mlx.Array
		for _, expert := range experts {
			dummyArrays := make(map[string]*mlx.Array)
			tmpPath, toEval, st, err := loadAndQuantizeArray(expert.input.Reader, expert.input.Name, "", dummyArrays)
			if err != nil {
				// Remove any temp file created before the failure so error
				// paths don't leak scratch files.
				if tmpPath != "" {
					os.Remove(tmpPath)
				}
				cleanup()
				return nil, fmt.Errorf("failed to decode expert tensor %s: %w", expert.input.Name, err)
			}
			mlx.Eval(toEval...)
			arr := dummyArrays[expert.input.Name]
			mlx.Pin(arr)
			pinned = append(pinned, arr)
			decoded = append(decoded, arr)
			if st != nil {
				st.Free()
			}
			if tmpPath != "" {
				os.Remove(tmpPath)
			}
			mlx.Sweep()
		}
		// Stack into 3D along axis 0: [numExperts, rows, cols]
		stacked := mlx.Stack(decoded, 0)
		mlx.Eval(stacked)
		mlx.Pin(stacked)
		pinned = append(pinned, stacked)
		// Free individual decoded arrays
		mlx.Unpin(decoded...)
		mlx.Sweep()
		stackedName := groupBase + ".switch_mlp." + proj
		// Quantize the stacked tensor
		if quantize != "" {
			groupSize, bits, mode := model.QuantizationParams(quantize)
			qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
			qweight = mlx.Contiguous(qweight, false)
			scales = mlx.Contiguous(scales, false)
			allArrays[stackedName] = qweight
			allArrays[stackedName+".scale"] = scales
			toEval := []*mlx.Array{qweight, scales}
			if qbiases != nil {
				qbiases = mlx.Contiguous(qbiases, false)
				allArrays[stackedName+".bias"] = qbiases
				toEval = append(toEval, qbiases)
			}
			mlx.Eval(toEval...)
			mlx.Pin(toEval...)
			pinned = append(pinned, toEval...)
			// Free stacked source array
			mlx.Unpin(stacked)
			mlx.Sweep()
		} else {
			stacked = mlx.Contiguous(stacked, false)
			mlx.Eval(stacked)
			allArrays[stackedName] = stacked
		}
	}
	defer cleanup()
	tmpDir := ensureTempDir()
	// Use a unique temp file: a fixed name would collide across concurrent
	// expert-stacking jobs sharing the same temp directory.
	outFile, err := os.CreateTemp(tmpDir, "stacked-*.safetensors")
	if err != nil {
		return nil, fmt.Errorf("failed to create output file: %w", err)
	}
	outPath := outFile.Name()
	outFile.Close()
	defer os.Remove(outPath)
	if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
		return nil, fmt.Errorf("failed to save stacked blob: %w", err)
	}
	blobData, err := os.ReadFile(outPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read stacked blob: %w", err)
	}
	return blobData, nil
}
// QuantizeSupported reports whether quantization is available, i.e. whether
// the MLX library initializes successfully on this machine.
func QuantizeSupported() bool {
	err := mlx.CheckInit()
	return err == nil
}
// ensureTempDir returns the scratch directory used for quantization, creating
// it if needed. If the directory cannot be created (e.g. permissions), it
// falls back to the system temp dir so callers still get a usable path
// instead of silently writing into a nonexistent directory.
func ensureTempDir() string {
	tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
	if err := os.MkdirAll(tmpDir, 0755); err != nil {
		return os.TempDir()
	}
	return tmpDir
}
// safetensorsHeaderEntry is one tensor's entry in a safetensors JSON header.
// Only the fields needed here are decoded; extra fields (e.g. data_offsets)
// are ignored by encoding/json.
type safetensorsHeaderEntry struct {
	Dtype string  `json:"dtype"` // e.g. "BF16", "F8_E4M3"
	Shape []int32 `json:"shape"`
}
// readSafetensorsHeader parses the JSON header of a safetensors file: an
// 8-byte little-endian header length followed by that many bytes of JSON
// mapping tensor names to their metadata.
func readSafetensorsHeader(path string) (map[string]safetensorsHeaderEntry, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	var headerSize uint64
	if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
		return nil, err
	}
	// Sanity-check the declared header length against the actual file size so
	// a corrupt or truncated blob cannot trigger a huge allocation below.
	// (fi.Size() >= 8 here because the 8-byte length read above succeeded.)
	if fi, err := f.Stat(); err == nil {
		if headerSize > uint64(fi.Size())-8 {
			return nil, fmt.Errorf("safetensors header size %d exceeds file size %d", headerSize, fi.Size())
		}
	}
	headerBytes := make([]byte, headerSize)
	if _, err := io.ReadFull(f, headerBytes); err != nil {
		return nil, err
	}
	var header map[string]safetensorsHeaderEntry
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		return nil, err
	}
	return header, nil
}
// safetensorsKey resolves the primary tensor key from a header. It returns
// preferred when that key exists; otherwise it returns the lexicographically
// smallest key, skipping the __metadata__ entry and .scale_inv companions.
func safetensorsKey(preferred string, header map[string]safetensorsHeaderEntry) (string, error) {
	if preferred != "" {
		if _, found := header[preferred]; found {
			return preferred, nil
		}
	}
	var candidates []string
	for key := range header {
		if key == "__metadata__" {
			continue
		}
		if strings.HasSuffix(key, ".scale_inv") {
			continue
		}
		candidates = append(candidates, key)
	}
	if len(candidates) == 0 {
		return "", fmt.Errorf("no tensor found in safetensors header")
	}
	sort.Strings(candidates)
	return candidates[0], nil
}
// decodeSourceFP8Tensor converts a block-quantized FP8 (E4M3) weight back to
// bfloat16 by multiplying each 128x128 block by its inverse-scale factor.
// scaleInv holds one scale per block: shape [ceil(rows/128), ceil(cols/128)].
func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
	if weight == nil || scaleInv == nil {
		return nil, fmt.Errorf("fp8 weight and scale tensors are required")
	}
	weightShape := weight.Dims()
	scaleShape := scaleInv.Dims()
	if len(weightShape) != 2 || len(scaleShape) != 2 {
		return nil, fmt.Errorf("expected 2D fp8 weight and scale tensors, got %v and %v", weightShape, scaleShape)
	}
	// These must match the block size validated by resolveEffectiveQuantization
	// in create.go, which rejects any source model with a different block size.
	const blockRows = 128
	const blockCols = 128
	rows, cols := weightShape[0], weightShape[1]
	// Ceiling division: partial edge blocks still get one scale each.
	expectedScaleRows := (rows + blockRows - 1) / blockRows
	expectedScaleCols := (cols + blockCols - 1) / blockCols
	if scaleShape[0] != expectedScaleRows || scaleShape[1] != expectedScaleCols {
		return nil, fmt.Errorf(
			"unexpected fp8 scale shape %v for weight shape %v; want [%d %d]",
			scaleShape,
			weightShape,
			expectedScaleRows,
			expectedScaleCols,
		)
	}
	decoded := mlx.FromFP8(weight, mlx.DTypeBFloat16)
	// Pad to whole blocks so the tensor reshapes cleanly into block form.
	padBottom := blockRows*scaleShape[0] - rows
	padSide := blockCols*scaleShape[1] - cols
	if padBottom > 0 || padSide > 0 {
		// NOTE(review): assumes mlx.Pad takes (before, after) pairs per axis,
		// i.e. pad only below and to the right — confirm against mlx.Pad docs.
		decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
	}
	// Reshape to [scaleRows, blockRows, scaleCols, blockCols] so the per-block
	// scales broadcast across each block's interior via the expanded axes 1, 3.
	decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
	decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scaleInv, 1), 3))
	// Flatten back to 2D (still padded), then slice off the padding.
	decoded = mlx.Reshape(decoded, int32(rows+padBottom), int32(cols+padSide))
	if padBottom > 0 || padSide > 0 {
		decoded = mlx.SliceStartStop(decoded, []int32{0, 0}, []int32{int32(rows), int32(cols)})
	}
	return decoded, nil
}