mlx: add mxfp4/mxfp8/nvfp4 importing (#15015)

This change allows importing bf16 models and converting them to mxfp4/mxfp8/nvfp4,
and also importing fp8 models and converting them directly to mxfp8.
Author: Patrick Devine
Date: 2026-03-24 13:45:44 -07:00
Committed by: GitHub
Parent: 95ee7fbd29
Commit: de5cb7311f
11 changed files with 1349 additions and 128 deletions
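For orientation, here is a minimal sketch (not part of the diff) of how the new targets are requested through the existing CreateOptions; the model name and directory below are placeholders.

    // Illustrative only: request MXFP4 quantization when importing a bf16 model.
    // "mxfp4" joins the existing "int4", "int8", "nvfp4", and "mxfp8" options.
    opts := CreateOptions{
        ModelName: "my-model",          // placeholder name
        ModelDir:  "/path/to/hf-model", // placeholder path
        Quantize:  "mxfp4",
    }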


@@ -109,7 +109,7 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
}
@@ -280,7 +280,7 @@ func newPackedTensorLayerCreator() create.PackedTensorLayerCreator {
if !QuantizeSupported() {
return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
}
blobData, err := quantizePackedGroup(tensors)
blobData, err := quantizePackedGroup(groupName, tensors)
if err != nil {
return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
}


@@ -7,29 +7,27 @@ import (
"io"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"github.com/ollama/ollama/x/create"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
)
// quantizeParams maps quantization type names to MLX quantize parameters.
var quantizeParams = map[string]struct {
groupSize int
bits int
mode string
}{
"int4": {64, 4, "affine"},
"nvfp4": {16, 4, "nvfp4"},
"int8": {64, 8, "affine"},
"mxfp8": {32, 8, "mxfp8"},
}
// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
// to the provided map. If quantize is empty, the tensor is kept as-is.
// Returns the temp file path created (caller must clean up) and arrays needing eval.
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
if quantize != "" {
if gs, _, _ := model.QuantizationParams(quantize); gs == 0 {
return "", nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
}
tmpDir := ensureTempDir()
tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
@@ -50,11 +48,16 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
}
// Find the tensor key (may differ from name for single-tensor blobs)
inputKey, err := findSafetensorsKey(tmpPath)
header, err := readSafetensorsHeader(tmpPath)
if err != nil {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
}
inputKey, err := safetensorsKey(name, header)
if err != nil {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("failed to resolve tensor key for %s: %w", name, err)
}
arr := st.Get(inputKey)
if arr == nil {
@@ -62,34 +65,46 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
}
// Decode FP8 source encoding before checking quantize, so that callers
// requesting decode-only (quantize="") receive usable float data.
if info, ok := header[inputKey]; ok && info.Dtype == "F8_E4M3" {
scaleKey := inputKey + ".scale_inv"
scaleInv := st.Get(scaleKey)
if scaleInv == nil {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q for fp8 source tensor %q", scaleKey, inputKey)
}
arr, err = decodeSourceFP8Tensor(arr, scaleInv)
if err != nil {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("failed to decode fp8 tensor %s: %w", inputKey, err)
}
mlx.Eval(arr)
}
if quantize == "" {
arr = mlx.Contiguous(arr)
arr = mlx.Contiguous(arr, false)
arrays[name] = arr
return tmpPath, []*mlx.Array{arr}, st, nil
}
// Convert to float type if needed (quantize expects float)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
if arr.DType() != mlx.DTypeBFloat16 && arr.DType() != mlx.DTypeFloat32 && arr.DType() != mlx.DTypeFloat16 {
// Convert to float type if needed (quantize expects float)
arr = arr.AsType(mlx.DTypeBFloat16)
mlx.Eval(arr)
}
params, ok := quantizeParams[quantize]
if !ok {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
groupSize, bits, mode := model.QuantizationParams(quantize)
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode)
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
qweight = mlx.Contiguous(qweight, false)
scales = mlx.Contiguous(scales, false)
arrays[name] = qweight
arrays[name+".scale"] = scales
toEval = append(toEval, qweight, scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
qbiases = mlx.Contiguous(qbiases, false)
arrays[name+".bias"] = qbiases
toEval = append(toEval, qbiases)
}
@@ -101,27 +116,45 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
// Tensor keys use the original tensor name: name, name.scale, name.bias.
// The blob includes __metadata__ with quant_type and group_size.
// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8".
// Supported quantization types: "int4", "nvfp4", "mxfp4", "int8", "mxfp8".
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
arrays := make(map[string]*mlx.Array)
tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
if tmpPath != "" {
defer os.Remove(tmpPath)
}
if st != nil {
defer st.Free()
}
if err != nil {
return nil, err
}
finalArrays := make([]*mlx.Array, 0, len(arrays))
for _, arr := range arrays {
if arr != nil {
finalArrays = append(finalArrays, arr)
}
}
mlx.Pin(finalArrays...)
defer func() {
if st != nil {
st.Free()
}
mlx.Unpin(finalArrays...)
mlx.Sweep()
}()
mlx.Eval(toEval...)
mlx.Sweep()
// Free early to release mmap; defer guard handles error paths
if st != nil {
st.Free()
st = nil
}
// Build metadata for single-tensor blobs
params := quantizeParams[quantize]
groupSize, _, _ := model.QuantizationParams(quantize)
metadata := map[string]string{
"quant_type": quantize,
"group_size": strconv.Itoa(params.groupSize),
"group_size": strconv.Itoa(groupSize),
}
tmpDir := ensureTempDir()
@@ -135,48 +168,81 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
// quantizePackedGroup quantizes multiple tensors and saves them all into a single
// combined safetensors blob. Used for packing expert groups.
// When the inputs are per-expert 2D tensors (e.g., experts.0.gate_proj.weight),
// they are stacked into 3D switch_mlp tensors before quantization.
// Each tensor may have a different quantization type (mixed-precision).
// Returns the blob bytes. No __metadata__ is added because different tensors
// may use different quantization types.
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
// Returns the blob bytes.
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
// Check if inputs are per-expert tensors that should be stacked into 3D
if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
}
allArrays := make(map[string]*mlx.Array)
var allToEval []*mlx.Array
var tmpPaths []string
var handles []*mlx.SafetensorsFile
var pinned []*mlx.Array
var metadata map[string]string
uniformQuantize := ""
hasQuantized := false
mixedQuantize := false
for _, input := range inputs {
if input.Quantize == "" {
if hasQuantized {
mixedQuantize = true
}
continue
}
if !hasQuantized {
hasQuantized = true
uniformQuantize = input.Quantize
continue
}
if input.Quantize != uniformQuantize {
mixedQuantize = true
}
}
if hasQuantized && !mixedQuantize {
if groupSize, _, _ := model.QuantizationParams(uniformQuantize); groupSize > 0 {
metadata = map[string]string{
"quant_type": uniformQuantize,
"group_size": strconv.Itoa(groupSize),
}
}
}
for _, input := range inputs {
tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
if tmpPath != "" {
tmpPaths = append(tmpPaths, tmpPath)
}
if st != nil {
handles = append(handles, st)
}
if err != nil {
// Cleanup on error
for _, h := range handles {
h.Free()
}
for _, p := range tmpPaths {
os.Remove(p)
}
mlx.Unpin(pinned...)
mlx.Sweep()
return nil, err
}
allToEval = append(allToEval, toEval...)
mlx.Eval(toEval...)
finalArrays := arraysForPackedInput(allArrays, input)
mlx.Pin(finalArrays...)
pinned = append(pinned, finalArrays...)
if st != nil {
st.Free()
}
if tmpPath != "" {
os.Remove(tmpPath)
}
mlx.Sweep()
}
defer func() {
mlx.Unpin(pinned...)
mlx.Sweep()
}()
mlx.Eval(allToEval...)
// Free native handles after eval
for _, h := range handles {
h.Free()
}
// Save combined blob (no global metadata for mixed-precision packed blobs)
// Save combined blob. Add global metadata only when every packed tensor uses
// the same quantization mode and group size.
tmpDir := ensureTempDir()
outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
defer os.Remove(outPath)
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil {
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
return nil, fmt.Errorf("failed to save packed blob: %w", err)
}
@@ -185,17 +251,193 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
return nil, fmt.Errorf("failed to read packed blob: %w", err)
}
for _, p := range tmpPaths {
os.Remove(p)
return blobData, nil
}
func arraysForPackedInput(allArrays map[string]*mlx.Array, input create.PackedTensorInput) []*mlx.Array {
keys := []string{input.Name}
if input.Quantize != "" {
keys = append(keys, input.Name+".scale", input.Name+".bias")
}
out := make([]*mlx.Array, 0, len(keys))
for _, key := range keys {
if arr := allArrays[key]; arr != nil {
out = append(out, arr)
}
}
return out
}
// perExpertSuffix matches ".{index}.{proj_and_suffix}" after the group prefix.
var perExpertSuffix = regexp.MustCompile(`^\.(\d+)\.(.+)$`)
type expertTensorInfo struct {
index int
proj string // e.g., "gate_proj.weight"
input create.PackedTensorInput
}
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
// and returns the uniform quantization type shared by all inputs.
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
// or if the inputs have mixed quantization types.
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
if !strings.HasSuffix(groupName, ".experts") {
return nil, ""
}
quantize := inputs[0].Quantize
groups := make(map[string][]expertTensorInfo)
for _, input := range inputs {
if input.Quantize != quantize {
return nil, "" // mixed quantization types
}
suffix := strings.TrimPrefix(input.Name, groupName)
m := perExpertSuffix.FindStringSubmatch(suffix)
if m == nil {
return nil, "" // not a per-expert pattern
}
index, err := strconv.Atoi(m[1])
if err != nil {
return nil, ""
}
groups[m[2]] = append(groups[m[2]], expertTensorInfo{
index: index,
proj: m[2],
input: input,
})
}
if len(groups) == 0 {
return nil, ""
}
return groups, quantize
}
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
groupBase := strings.TrimSuffix(groupName, ".experts")
allArrays := make(map[string]*mlx.Array)
var pinned []*mlx.Array
var metadata map[string]string
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
metadata = map[string]string{
"quant_type": quantize,
"group_size": strconv.Itoa(groupSize),
}
}
// Sort projection names for deterministic output
projNames := make([]string, 0, len(projGroups))
for proj := range projGroups {
projNames = append(projNames, proj)
}
sort.Strings(projNames)
cleanup := func() {
mlx.Unpin(pinned...)
mlx.Sweep()
}
for _, proj := range projNames {
experts := projGroups[proj]
// Sort by expert index
sort.Slice(experts, func(i, j int) bool {
return experts[i].index < experts[j].index
})
// Load and decode each expert tensor
var decoded []*mlx.Array
for _, expert := range experts {
dummyArrays := make(map[string]*mlx.Array)
tmpPath, toEval, st, err := loadAndQuantizeArray(expert.input.Reader, expert.input.Name, "", dummyArrays)
if err != nil {
cleanup()
return nil, fmt.Errorf("failed to decode expert tensor %s: %w", expert.input.Name, err)
}
mlx.Eval(toEval...)
arr := dummyArrays[expert.input.Name]
mlx.Pin(arr)
pinned = append(pinned, arr)
decoded = append(decoded, arr)
if st != nil {
st.Free()
}
if tmpPath != "" {
os.Remove(tmpPath)
}
mlx.Sweep()
}
// Stack into 3D along axis 0: [numExperts, rows, cols]
stacked := mlx.Stack(decoded, 0)
mlx.Eval(stacked)
mlx.Pin(stacked)
pinned = append(pinned, stacked)
// Free individual decoded arrays
mlx.Unpin(decoded...)
mlx.Sweep()
stackedName := groupBase + ".switch_mlp." + proj
// Quantize the stacked tensor
if quantize != "" {
groupSize, bits, mode := model.QuantizationParams(quantize)
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
qweight = mlx.Contiguous(qweight, false)
scales = mlx.Contiguous(scales, false)
allArrays[stackedName] = qweight
allArrays[stackedName+".scale"] = scales
toEval := []*mlx.Array{qweight, scales}
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases, false)
allArrays[stackedName+".bias"] = qbiases
toEval = append(toEval, qbiases)
}
mlx.Eval(toEval...)
mlx.Pin(toEval...)
pinned = append(pinned, toEval...)
// Free stacked source array
mlx.Unpin(stacked)
mlx.Sweep()
} else {
stacked = mlx.Contiguous(stacked, false)
mlx.Eval(stacked)
allArrays[stackedName] = stacked
}
}
defer cleanup()
tmpDir := ensureTempDir()
outPath := filepath.Join(tmpDir, "stacked-combined.safetensors")
defer os.Remove(outPath)
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
return nil, fmt.Errorf("failed to save stacked blob: %w", err)
}
blobData, err := os.ReadFile(outPath)
if err != nil {
return nil, fmt.Errorf("failed to read stacked blob: %w", err)
}
return blobData, nil
}
// QuantizeSupported returns true if quantization is supported (MLX library available)
func QuantizeSupported() bool {
mlx.InitMLX()
return mlx.IsMLXAvailable()
return mlx.CheckInit() == nil
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist
@@ -205,32 +447,97 @@ func ensureTempDir() string {
return tmpDir
}
// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file.
func findSafetensorsKey(path string) (string, error) {
type safetensorsHeaderEntry struct {
Dtype string `json:"dtype"`
Shape []int32 `json:"shape"`
}
func readSafetensorsHeader(path string) (map[string]safetensorsHeaderEntry, error) {
f, err := os.Open(path)
if err != nil {
return "", err
return nil, err
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return "", err
return nil, err
}
headerBytes := make([]byte, headerSize)
if _, err := io.ReadFull(f, headerBytes); err != nil {
return "", err
return nil, err
}
var header map[string]json.RawMessage
var header map[string]safetensorsHeaderEntry
if err := json.Unmarshal(headerBytes, &header); err != nil {
return "", err
return nil, err
}
return header, nil
}
for k := range header {
if k != "__metadata__" {
return k, nil
// safetensorsKey resolves the primary tensor key from a header.
func safetensorsKey(preferred string, header map[string]safetensorsHeaderEntry) (string, error) {
if preferred != "" {
if _, ok := header[preferred]; ok {
return preferred, nil
}
}
return "", fmt.Errorf("no tensor found in safetensors header")
keys := make([]string, 0, len(header))
for k := range header {
if k == "__metadata__" || strings.HasSuffix(k, ".scale_inv") {
continue
}
keys = append(keys, k)
}
sort.Strings(keys)
if len(keys) == 0 {
return "", fmt.Errorf("no tensor found in safetensors header")
}
return keys[0], nil
}
func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
if weight == nil || scaleInv == nil {
return nil, fmt.Errorf("fp8 weight and scale tensors are required")
}
weightShape := weight.Dims()
scaleShape := scaleInv.Dims()
if len(weightShape) != 2 || len(scaleShape) != 2 {
return nil, fmt.Errorf("expected 2D fp8 weight and scale tensors, got %v and %v", weightShape, scaleShape)
}
// These must match the block size validated by resolveEffectiveQuantization
// in create.go, which rejects any source model with a different block size.
const blockRows = 128
const blockCols = 128
rows, cols := weightShape[0], weightShape[1]
expectedScaleRows := (rows + blockRows - 1) / blockRows
expectedScaleCols := (cols + blockCols - 1) / blockCols
if scaleShape[0] != expectedScaleRows || scaleShape[1] != expectedScaleCols {
return nil, fmt.Errorf(
"unexpected fp8 scale shape %v for weight shape %v; want [%d %d]",
scaleShape,
weightShape,
expectedScaleRows,
expectedScaleCols,
)
}
decoded := mlx.FromFP8(weight, mlx.DTypeBFloat16)
padBottom := blockRows*scaleShape[0] - rows
padSide := blockCols*scaleShape[1] - cols
if padBottom > 0 || padSide > 0 {
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
}
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scaleInv, 1), 3))
decoded = mlx.Reshape(decoded, int32(rows+padBottom), int32(cols+padSide))
if padBottom > 0 || padSide > 0 {
decoded = mlx.SliceStartStop(decoded, []int32{0, 0}, []int32{int32(rows), int32(cols)})
}
return decoded, nil
}
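Element-wise, the decode above multiplies each fp8 value by the inverse scale of its 128x128 block. A scalar reference sketch follows (illustrative only; fp8ToFloat is a hypothetical helper converting one F8_E4M3 byte to float32 and is not part of this change):

    // Scalar equivalent of decodeSourceFP8Tensor:
    // decoded[i][j] = fp8(w[i][j]) * scaleInv[i/128][j/128].
    // fp8ToFloat is a hypothetical helper, not part of this change.
    func dequantFP8Blockwise(w [][]byte, scaleInv [][]float32) [][]float32 {
        const block = 128
        out := make([][]float32, len(w))
        for i := range w {
            out[i] = make([]float32, len(w[i]))
            for j := range w[i] {
                out[i][j] = fp8ToFloat(w[i][j]) * scaleInv[i/block][j/block]
            }
        }
        return out
    }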


@@ -267,13 +267,13 @@ func ShouldQuantize(name, component string) bool {
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions.
// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "int8", "mxfp8").
// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "mxfp4", "int8", "mxfp8").
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
return GetTensorQuantization(name, shape, quantize) != ""
}
// normalizeQuantType converts various quantization type aliases to canonical forms.
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp8/MXFP8
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp4/MXFP4, mxfp8/MXFP8
func normalizeQuantType(quantize string) string {
switch strings.ToUpper(quantize) {
case "Q4", "INT4", "FP4":
@@ -282,6 +282,8 @@ func normalizeQuantType(quantize string) string {
return "int8"
case "NVFP4":
return "nvfp4"
case "MXFP4":
return "mxfp4"
case "MXFP8":
return "mxfp8"
default:
@@ -335,7 +337,7 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
quantNorm := normalizeQuantType(quantize)
// MLX quantization requires last dimension to be divisible by group size
// nvfp4: 16, mxfp8: 32, int4/int8: 64
// nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
groupSize := int32(32)
switch quantNorm {
case "nvfp4":
@@ -353,8 +355,8 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
return ""
}
// For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
// For non-affine modes, use the same quantization for all eligible tensors.
if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
return quantNorm
}
@@ -391,23 +393,39 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
return quantNorm
}
// expertGroupRegexp matches expert tensor names and captures the group prefix.
// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
// Captures: model.layers.{L}.mlp.experts
var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
var expertLayerPrefixRegexp = regexp.MustCompile(`^(?:model\.language_model\.|language_model(?:\.model)?\.|model\.)?layers\.\d+$`)
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
// For example:
// - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts"
// - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts"
// - "language_model.model.layers.1.mlp.switch_mlp.down_proj.weight" -> "language_model.model.layers.1.mlp.switch_mlp"
// - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts)
// - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert)
func ExpertGroupPrefix(tensorName string) string {
m := expertGroupRegexp.FindStringSubmatch(tensorName)
if m == nil {
if !strings.HasSuffix(tensorName, ".weight") {
return ""
}
return m[1]
for _, marker := range []string{
".mlp.experts.",
".mlp.shared_experts.",
".mlp.switch_mlp.",
} {
idx := strings.Index(tensorName, marker)
if idx == -1 {
continue
}
layerPrefix := tensorName[:idx]
if !expertLayerPrefixRegexp.MatchString(layerPrefix) {
continue
}
return layerPrefix + strings.TrimSuffix(marker, ".")
}
return ""
}
// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob.
@@ -424,9 +442,11 @@ type PackedTensorInput struct {
type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput) (LayerInfo, error)
type sourceQuantization struct {
Bits int `json:"bits"`
GroupSize int `json:"group_size"`
Mode string `json:"mode"`
Bits int `json:"bits"`
GroupSize int `json:"group_size"`
Mode string `json:"mode"`
QuantMethod string `json:"quant_method"`
WeightBlockSize []int32 `json:"weight_block_size"`
}
type sourceModelConfig struct {
@@ -493,6 +513,98 @@ func (cfg sourceModelConfig) QuantMetadata() map[string]string {
return metadata
}
type sourceQuantizedKind string
const (
sourceQuantizedKindNone sourceQuantizedKind = ""
sourceQuantizedKindPrequantized sourceQuantizedKind = "prequantized"
sourceQuantizedKindHFFP8 sourceQuantizedKind = "hf_fp8"
)
func (cfg sourceModelConfig) quantizationConfigs() []sourceQuantization {
return []sourceQuantization{
cfg.Quantization,
cfg.QuantizationConfig,
cfg.TextConfig.Quantization,
cfg.TextConfig.QuantizationConfig,
}
}
func (cfg sourceModelConfig) HFFP8WeightBlockSize() (rows, cols int32, ok bool) {
for _, q := range cfg.quantizationConfigs() {
if !strings.EqualFold(q.QuantMethod, "fp8") || len(q.WeightBlockSize) != 2 {
continue
}
return q.WeightBlockSize[0], q.WeightBlockSize[1], true
}
return 0, 0, false
}
func inspectSourceQuantization(modelDir string, cfg sourceModelConfig) (sourceQuantizedKind, error) {
entries, err := os.ReadDir(modelDir)
if err != nil {
return sourceQuantizedKindNone, err
}
hasScaleInv := false
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
extractor, err := safetensors.OpenForExtraction(filepath.Join(modelDir, entry.Name()))
if err != nil {
return sourceQuantizedKindNone, err
}
for _, name := range extractor.ListTensors() {
switch {
case strings.HasSuffix(name, ".scales"):
extractor.Close()
return sourceQuantizedKindPrequantized, nil
case strings.HasSuffix(name, ".weight_scale_inv"):
hasScaleInv = true
}
}
extractor.Close()
}
if hasScaleInv {
if _, _, ok := cfg.HFFP8WeightBlockSize(); ok {
return sourceQuantizedKindHFFP8, nil
}
}
return sourceQuantizedKindNone, nil
}
func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested string) (string, error) {
switch sourceKind {
case sourceQuantizedKindNone:
return requested, nil
case sourceQuantizedKindPrequantized:
if requested != "" {
return "", fmt.Errorf("cannot requantize already-quantized source model with --quantize %q", requested)
}
return "", nil
case sourceQuantizedKindHFFP8:
if requested != "" {
return "", fmt.Errorf("cannot requantize already-quantized fp8 source model with --quantize %q", requested)
}
rows, cols, ok := cfg.HFFP8WeightBlockSize()
if !ok {
return "", fmt.Errorf("fp8 source model missing weight_block_size metadata")
}
if rows != 128 || cols != 128 {
return "", fmt.Errorf("unsupported fp8 source block size %dx%d", rows, cols)
}
return "mxfp8", nil
default:
return "", fmt.Errorf("unsupported source quantization kind %q", sourceKind)
}
}
type tensorImportTransform interface {
skipTensor(name string) bool
transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error)
@@ -546,6 +658,14 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
if err != nil {
return fmt.Errorf("failed to read source config.json: %w", err)
}
sourceQuantKind, err := inspectSourceQuantization(modelDir, sourceConfig)
if err != nil {
return fmt.Errorf("failed to inspect source quantization: %w", err)
}
effectiveQuantize, err := resolveEffectiveQuantization(sourceConfig, sourceQuantKind, quantize)
if err != nil {
return err
}
sourceQuantMetadata := sourceConfig.QuantMetadata()
importTransform, err := newTensorImportTransform(modelDir, sourceConfig)
if err != nil {
@@ -557,7 +677,6 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
if len(createPackedLayer) > 0 {
packedCreator = createPackedLayer[0]
}
// Accumulate expert tensors by group prefix for packing.
// Readers reference file-backed SectionReaders, so we keep extractors
// open until each group is flushed to avoid buffering tensor data in memory.
@@ -600,8 +719,8 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
tensorSet[name] = struct{}{}
}
quantizeMsg := ""
if quantize != "" {
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
if effectiveQuantize != "" {
quantizeMsg = fmt.Sprintf(", quantizing to %s", effectiveQuantize)
}
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
@@ -612,9 +731,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
if importTransform.skipTensor(tensorName) {
continue
}
if shouldSkipPrequantizedCompanion(tensorName, tensorSet) {
if shouldSkipSourceCompanion(tensorName, tensorSet) {
continue
}
sourceFP8ScaleName, hasSourceFP8Scale := sourceFP8Companion(tensorName, tensorSet)
td, err := extractor.GetTensor(tensorName)
if err != nil {
@@ -623,7 +743,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
if quantize == "" {
if effectiveQuantize == "" {
layer, ok, err := createPrequantizedLayer(extractor, td, tensorName, tensorSet, sourceQuantMetadata, createLayer)
if err != nil {
extractor.Close()
@@ -647,8 +767,33 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
// Determine quantization type for this tensor (empty string if not quantizing)
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
quantizeType := ""
if quantize != "" {
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, quantize)
switch {
case sourceQuantKind == sourceQuantizedKindHFFP8 && hasSourceFP8Scale:
quantizeType = "mxfp8"
case sourceQuantKind == sourceQuantizedKindHFFP8:
quantizeType = ""
case effectiveQuantize != "":
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
}
reader := outTD.SafetensorsReader()
if hasSourceFP8Scale {
if len(outputTensors) != 1 {
extractor.Close()
closeExtractors()
return fmt.Errorf("source fp8 tensor %s rewrote into %d tensors; only 1:1 rewrites are supported", tensorName, len(outputTensors))
}
if quantizeType == "" {
extractor.Close()
closeExtractors()
return fmt.Errorf("source fp8 tensor %s was not scheduled for mxfp8 conversion", tensorName)
}
scaleTD, err := extractor.GetTensor(sourceFP8ScaleName)
if err != nil {
extractor.Close()
closeExtractors()
return fmt.Errorf("failed to get fp8 scale tensor %s: %w", sourceFP8ScaleName, err)
}
reader = buildSourceFP8Reader(outTD, scaleTD.WithName(outTD.Name+".scale_inv"))
}
// Check if this tensor belongs to an expert group for packing
@@ -670,13 +815,13 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
Dtype: outTD.Dtype,
Shape: outTD.Shape,
Quantize: quantizeType,
Reader: outTD.SafetensorsReader(),
Reader: reader,
})
} else {
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
newLayers, err := createTensorLayer(reader, outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
if err != nil {
extractor.Close()
closeExtractors()
@@ -760,7 +905,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return nil
}
func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{}) bool {
func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}) bool {
switch {
case strings.HasSuffix(name, ".scales"):
_, ok := tensorSet[strings.TrimSuffix(name, ".scales")+".weight"]
@@ -768,11 +913,28 @@ func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{})
case strings.HasSuffix(name, ".biases"):
_, ok := tensorSet[strings.TrimSuffix(name, ".biases")+".weight"]
return ok
case strings.HasSuffix(name, ".weight_scale_inv"):
_, ok := tensorSet[strings.TrimSuffix(name, "_scale_inv")]
return ok
default:
return false
}
}
func sourceFP8Companion(weightName string, tensorSet map[string]struct{}) (scaleName string, ok bool) {
if !strings.HasSuffix(weightName, ".weight") {
return "", false
}
scaleName = weightName + "_scale_inv"
_, ok = tensorSet[scaleName]
return scaleName, ok
}
func buildSourceFP8Reader(weightTD, scaleTD *safetensors.TensorData) io.Reader {
return safetensors.BuildPackedSafetensorsReader([]*safetensors.TensorData{weightTD, scaleTD})
}
func createPrequantizedLayer(
extractor *safetensors.TensorExtractor,
td *safetensors.TensorData,


@@ -246,6 +246,30 @@ func readSingleTensorRaw(t *testing.T, data []byte) []byte {
return nil
}
func readSafetensorsHeaderNames(t *testing.T, data []byte) []string {
t.Helper()
var headerSize uint64
if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
t.Fatalf("failed to read header size: %v", err)
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
t.Fatalf("failed to parse header: %v", err)
}
names := make([]string, 0, len(header))
for name := range header {
if name == "__metadata__" {
continue
}
names = append(names, name)
}
slices.Sort(names)
return names
}
func TestCreateSafetensorsModel(t *testing.T) {
dir := t.TempDir()
@@ -546,6 +570,215 @@ func TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets(t *testing.T) {
}
}
func TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["TestModel"],
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("dense.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{2}, make([]byte, 4)),
})
quantizeByName := make(map[string]string)
headerNamesByName := make(map[string][]string)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
_, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
quantizeByName[name] = quantize
headerNamesByName[name] = readSafetensorsHeaderNames(t, data)
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if got := quantizeByName["linear.weight"]; got != "mxfp8" {
t.Fatalf("linear.weight quantization = %q, want %q", got, "mxfp8")
}
if got := quantizeByName["norm.weight"]; got != "" {
t.Fatalf("norm.weight quantization = %q, want empty", got)
}
if got := quantizeByName["dense.weight"]; got != "" {
t.Fatalf("dense.weight quantization = %q, want empty", got)
}
if _, ok := quantizeByName["linear.weight_scale_inv"]; ok {
t.Fatal("linear.weight_scale_inv should not be imported as a standalone tensor")
}
if got := headerNamesByName["linear.weight"]; !slices.Equal(got, []string{"linear.weight", "linear.weight.scale_inv"}) {
t.Fatalf("linear.weight blob tensors = %v, want %v", got, []string{"linear.weight", "linear.weight.scale_inv"})
}
if got := headerNamesByName["norm.weight"]; !slices.Equal(got, []string{"norm.weight"}) {
t.Fatalf("norm.weight blob tensors = %v, want %v", got, []string{"norm.weight"})
}
if got := headerNamesByName["dense.weight"]; !slices.Equal(got, []string{"dense.weight"}) {
t.Fatalf("dense.weight blob tensors = %v, want %v", got, []string{"dense.weight"})
}
}
func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T) {
tests := []struct {
name string
configJSON string
tensors []*st.TensorData
wantErr string
}{
{
name: "prequantized affine",
configJSON: `{"model_type": "test", "architectures": ["TestModel"]}`,
tensors: []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight", "U32", []int32{4, 4}, make([]byte, 16)),
st.NewTensorDataFromBytes("linear.scales", "BF16", []int32{4, 1}, make([]byte, 8)),
},
wantErr: `cannot requantize already-quantized source model with --quantize "int4"`,
},
{
name: "hf fp8 source",
configJSON: `{
"model_type": "test",
"architectures": ["TestModel"],
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
}`,
tensors: []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
},
wantErr: `cannot requantize already-quantized fp8 source model with --quantize "int4"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), tt.tensors)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
return LayerInfo{}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
return nil, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
err := CreateSafetensorsModel("test-model", dir, "int4", createLayer, createTensorLayer, writeManifest, func(string) {})
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("error = %q, want substring %q", err, tt.wantErr)
}
})
}
}
func TestCreateSafetensorsModel_HFFP8PacksExperts(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create 2 experts so stacking produces a [2, 128, 128] tensor
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
})
var packedLayerNames []string
var packedLayerTensors [][]PackedTensorInput
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return LayerInfo{}, err
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return nil, err
}
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
packedLayerNames = append(packedLayerNames, groupName)
packedLayerTensors = append(packedLayerTensors, tensors)
return LayerInfo{Name: groupName, Digest: "sha256:packed_" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if len(packedLayerNames) != 1 {
t.Fatalf("expected 1 packed layer, got %d: %v", len(packedLayerNames), packedLayerNames)
}
if packedLayerNames[0] != "language_model.model.layers.0.mlp.experts" {
t.Fatalf("unexpected packed layer name: %s", packedLayerNames[0])
}
// Verify all 6 expert tensors (2 experts × 3 proj types) were accumulated
tensors := packedLayerTensors[0]
if len(tensors) != 6 {
t.Fatalf("expected 6 tensors in packed group, got %d", len(tensors))
}
// All should be marked for mxfp8 quantization
for _, tensor := range tensors {
if tensor.Quantize != "mxfp8" {
t.Fatalf("expected mxfp8 quantize for %s, got %q", tensor.Name, tensor.Quantize)
}
}
}
func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
dir := t.TempDir()
@@ -693,6 +926,113 @@ func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
}
}
func TestCreateSafetensorsModel_Qwen35DirectNonAffineKeepsSensitiveWeightsBF16(t *testing.T) {
for _, quantize := range []string{"nvfp4", "mxfp8", "mxfp4"} {
t.Run(quantize, func(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"text_config": {"dtype": "bfloat16"}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
gateUpValues := make([]float32, 2*128*64)
for expert := range 2 {
base := expert * 128 * 64
for i := range 64 * 64 {
gateUpValues[base+i] = 1
gateUpValues[base+64*64+i] = 2
}
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_a.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_b.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.shared_expert_gate.weight", "BF16", []int32{1, 64}, make([]byte, 64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
})
type tensorCall struct {
quantize string
}
type packedTensorCall struct {
Name string
Quantize string
}
tensorCalls := make(map[string]tensorCall)
packedCalls := make(map[string][]packedTensorCall)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
_, _ = io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantizeType string) ([]LayerInfo, error) {
_, _ = io.ReadAll(r)
tensorCalls[name] = tensorCall{quantize: quantizeType}
return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
group := make([]packedTensorCall, 0, len(tensors))
for _, tensor := range tensors {
group = append(group, packedTensorCall{
Name: tensor.Name,
Quantize: tensor.Quantize,
})
}
packedCalls[groupName] = group
return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
if err := CreateSafetensorsModel("test-model", dir, quantize, createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
for _, name := range []string{
"language_model.model.embed_tokens.weight",
"language_model.lm_head.weight",
"language_model.model.layers.0.linear_attn.in_proj_a.weight",
"language_model.model.layers.0.linear_attn.in_proj_b.weight",
"language_model.model.layers.0.mlp.gate.weight",
"language_model.model.layers.0.mlp.shared_expert_gate.weight",
} {
if got := tensorCalls[name].quantize; got != "" {
t.Fatalf("%s quantize = %q, want empty", name, got)
}
}
if got := tensorCalls["language_model.model.layers.0.self_attn.q_proj.weight"].quantize; got != quantize {
t.Fatalf("q_proj quantize = %q, want %q", got, quantize)
}
group := packedCalls["language_model.model.layers.0.mlp.switch_mlp"]
if len(group) != 3 {
t.Fatalf("packed switch_mlp tensor count = %d, want 3", len(group))
}
for _, tensor := range group {
if tensor.Quantize != quantize {
t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, quantize)
}
}
})
}
}
func TestResolveManifestPath(t *testing.T) {
tests := []struct {
name string
@@ -865,6 +1205,7 @@ func TestShouldQuantizeTensor(t *testing.T) {
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
{"large 2D weight mxfp4", "q_proj.weight", []int32{4096, 4096}, "mxfp4", true},
// Small tensors should not be quantized (< 1024 elements)
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
@@ -891,9 +1232,11 @@ func TestShouldQuantizeTensor(t *testing.T) {
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
// Group size divisibility tests
// FP8/FP4 require divisible by 32
// FP8/FP4/MXFP4 require divisible by 32
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
{"not divisible by 32 mxfp4", "proj.weight", []int32{128, 48}, "mxfp4", false},
{"divisible by 32 mxfp4", "proj.weight", []int32{128, 64}, "mxfp4", true},
// NVFP4 requires divisible by 16
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
@@ -919,10 +1262,20 @@ func TestExpertGroupPrefix(t *testing.T) {
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
// Expert tensors with language_model prefix should also match
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
// Shared expert tensors should return their own group prefix
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
// Rewritten Qwen switch_mlp tensors should also be packed per-layer.
{"model.layers.1.mlp.switch_mlp.down_proj.weight", "model.layers.1.mlp.switch_mlp"},
{"language_model.layers.2.mlp.switch_mlp.gate_proj.weight", "language_model.layers.2.mlp.switch_mlp"},
{"language_model.model.layers.3.mlp.switch_mlp.up_proj.weight", "language_model.model.layers.3.mlp.switch_mlp"},
{"model.language_model.layers.4.mlp.switch_mlp.gate_proj.weight", "model.language_model.layers.4.mlp.switch_mlp"},
// Non-expert tensors should return empty string
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
{"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert
@@ -978,6 +1331,161 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
if combinedDown != "int8" {
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
}
nvfp4GateUp := GetTensorQuantization(
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
[]int32{64, 11008, 4096},
"nvfp4",
)
if nvfp4GateUp != "nvfp4" {
t.Fatalf("nvfp4 gate_proj quantization = %q, want %q", nvfp4GateUp, "nvfp4")
}
nvfp4Down := GetTensorQuantization(
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
[]int32{64, 4096, 11008},
"nvfp4",
)
if nvfp4Down != "nvfp4" {
t.Fatalf("nvfp4 down_proj quantization = %q, want %q", nvfp4Down, "nvfp4")
}
mxfp4GateUp := GetTensorQuantization(
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
[]int32{64, 11008, 4096},
"mxfp4",
)
if mxfp4GateUp != "mxfp4" {
t.Fatalf("mxfp4 gate_proj quantization = %q, want %q", mxfp4GateUp, "mxfp4")
}
mxfp4Down := GetTensorQuantization(
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
[]int32{64, 4096, 11008},
"mxfp4",
)
if mxfp4Down != "mxfp4" {
t.Fatalf("mxfp4 down_proj quantization = %q, want %q", mxfp4Down, "mxfp4")
}
}
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"text_config": {"dtype": "bfloat16"}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
gateUpValues := make([]float32, 2*128*64)
for expert := range 2 {
base := expert * 128 * 64
for i := range 64 * 64 {
gateUpValues[base+i] = 1
gateUpValues[base+64*64+i] = 2
}
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
})
type tensorCall struct {
quantize string
}
type packedTensorCall struct {
Name string
Dtype string
Shape []int32
Quantize string
}
tensorCalls := make(map[string]tensorCall)
packedCalls := make(map[string][]packedTensorCall)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
_, _ = io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
_, _ = io.ReadAll(r)
tensorCalls[name] = tensorCall{quantize: quantize}
return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
group := make([]packedTensorCall, 0, len(tensors))
for _, tensor := range tensors {
group = append(group, packedTensorCall{
Name: tensor.Name,
Dtype: tensor.Dtype,
Shape: append([]int32(nil), tensor.Shape...),
Quantize: tensor.Quantize,
})
}
packedCalls[groupName] = group
return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
if err := CreateSafetensorsModel("test-model", dir, "nvfp4", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
groupName := "language_model.model.layers.0.mlp.switch_mlp"
group, ok := packedCalls[groupName]
if !ok {
t.Fatalf("missing packed group %q: %v", groupName, packedCalls)
}
if len(group) != 3 {
t.Fatalf("packed group %q has %d tensors, want 3", groupName, len(group))
}
gotNames := make([]string, 0, len(group))
for _, tensor := range group {
gotNames = append(gotNames, tensor.Name)
if tensor.Quantize != "nvfp4" {
t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, "nvfp4")
}
if tensor.Dtype != "BF16" {
t.Fatalf("packed tensor %q dtype = %q, want %q", tensor.Name, tensor.Dtype, "BF16")
}
}
slices.Sort(gotNames)
wantNames := []string{
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
"language_model.model.layers.0.mlp.switch_mlp.up_proj.weight",
}
if !slices.Equal(gotNames, wantNames) {
t.Fatalf("packed tensor names = %v, want %v", gotNames, wantNames)
}
for _, name := range wantNames {
if _, ok := tensorCalls[name]; ok {
t.Fatalf("packed expert tensor %q unexpectedly handled by createTensorLayer", name)
}
}
if got := tensorCalls["language_model.model.embed_tokens.weight"].quantize; got != "" {
t.Fatalf("embed_tokens quantize = %q, want empty", got)
}
if got := tensorCalls["language_model.model.layers.0.mlp.gate.weight"].quantize; got != "" {
t.Fatalf("mlp.gate quantize = %q, want empty", got)
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {


@@ -87,6 +87,27 @@ func (t qwen35ImportTransform) skipTensor(name string) bool {
return strings.Contains(name, "mtp.")
}
func qwen35ShouldKeepBF16ForDirectNonAffine(name string) bool {
switch {
case strings.HasSuffix(name, "embed_tokens.weight"):
return true
case strings.HasSuffix(name, "lm_head.weight"):
return true
case strings.HasSuffix(name, ".linear_attn.in_proj_a.weight"):
return true
case strings.HasSuffix(name, ".linear_attn.in_proj_b.weight"):
return true
case strings.HasSuffix(name, ".linear_attn.in_proj_ba.weight"):
return true
case strings.HasSuffix(name, ".mlp.gate.weight") && !strings.Contains(name, "_proj"):
return true
case strings.HasSuffix(name, ".mlp.shared_expert_gate.weight"):
return true
default:
return false
}
}
func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
if strings.HasPrefix(name, "vision_tower.") {
return ""
@@ -127,6 +148,13 @@ func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quan
return ""
}
// Match the working HF-FP8 import policy for direct NVFP4/MXFP4/MXFP8 imports:
// keep embeddings, LM head, low-rank linear_attn projections, and routing
// gates in BF16 rather than forcing them into a non-affine quantized format.
if (quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8") && qwen35ShouldKeepBF16ForDirectNonAffine(name) {
return ""
}
return quantNorm
}


@@ -4,35 +4,91 @@ package mlx
import "C"
import (
"fmt"
"iter"
"runtime"
"unsafe"
)
// SafetensorsFile represents a loaded safetensors file.
type SafetensorsFile struct {
arrays C.mlx_map_string_to_array
metadata C.mlx_map_string_to_string
}
func loadSafetensorsStream() C.mlx_stream {
if runtime.GOOS == "darwin" {
return C.mlx_default_cpu_stream_new()
}
return C.mlx_default_gpu_stream_new()
}
// LoadSafetensorsNative loads a safetensors file using MLX's native loader.
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
var arrays C.mlx_map_string_to_array
var metadata C.mlx_map_string_to_string
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
stream := loadSafetensorsStream()
defer C.mlx_stream_free(stream)
if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
return nil, fmt.Errorf("failed to load safetensors: %s", path)
}
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
}
// Get retrieves a tensor by name.
func (s *SafetensorsFile) Get(name string) *Array {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
value := C.mlx_array_new()
if C.mlx_map_string_to_array_get(&value, s.arrays, cName) != 0 {
return nil
}
if value.ctx == nil {
return nil
}
arr := New(name)
arr.ctx = value
return arr
}
// GetMetadata retrieves a metadata value by key.
func (s *SafetensorsFile) GetMetadata(key string) string {
cKey := C.CString(key)
defer C.free(unsafe.Pointer(cKey))
var cValue *C.char
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
return ""
}
return C.GoString(cValue)
}
// Free releases the loaded safetensors maps.
func (s *SafetensorsFile) Free() {
if s == nil {
return
}
C.mlx_map_string_to_array_free(s.arrays)
C.mlx_map_string_to_string_free(s.metadata)
}
func Load(path string) iter.Seq2[string, *Array] {
return func(yield func(string, *Array) bool) {
string2array := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(string2array)
string2string := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(string2string)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
// Use GPU stream so tensors load directly to GPU memory (CUDA has Load::eval_gpu).
// macOS Metal doesn't implement eval_gpu for Load, so fall back to CPU stream.
var stream C.mlx_stream
if runtime.GOOS == "darwin" {
stream = C.mlx_default_cpu_stream_new()
} else {
stream = C.mlx_default_gpu_stream_new()
sf, err := LoadSafetensorsNative(path)
if err != nil {
return
}
defer C.mlx_stream_free(stream)
defer sf.Free()
C.mlx_load_safetensors(&string2array, &string2string, cPath, stream)
it := C.mlx_map_string_to_array_iterator_new(string2array)
it := C.mlx_map_string_to_array_iterator_new(sf.arrays)
defer C.mlx_map_string_to_array_iterator_free(it)
for {
@@ -51,3 +107,43 @@ func Load(path string) iter.Seq2[string, *Array] {
}
}
}
// SaveSafetensors saves arrays to a safetensors file without metadata.
func SaveSafetensors(path string, arrays map[string]*Array) error {
return SaveSafetensorsWithMetadata(path, arrays, nil)
}
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata.
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
for name, arr := range arrays {
if arr == nil {
continue
}
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
C.free(unsafe.Pointer(cName))
}
cMetadata := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMetadata)
for key, value := range metadata {
cKey := C.CString(key)
cValue := C.CString(value)
C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
C.free(unsafe.Pointer(cKey))
C.free(unsafe.Pointer(cValue))
}
if C.mlx_save_safetensors(cPath, cArrays, cMetadata) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}
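A minimal round-trip sketch (illustrative, not from this change) tying the helpers above together inside the mlx package; the paths and the tensor name "linear.weight" are placeholders.

    // Hypothetical example: load a safetensors file natively, fetch one tensor,
    // and re-save it with quantization metadata.
    func resaveWithQuantMetadata() error {
        sf, err := LoadSafetensorsNative("/tmp/in.safetensors")
        if err != nil {
            return err
        }
        defer sf.Free()
        arr := sf.Get("linear.weight")
        if arr == nil {
            return fmt.Errorf("tensor %q not found", "linear.weight")
        }
        Eval(arr)
        return SaveSafetensorsWithMetadata("/tmp/out.safetensors",
            map[string]*Array{"linear.weight": arr},
            map[string]string{"quant_type": "mxfp8", "group_size": "32"})
    }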


@@ -33,6 +33,18 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
return w0, w1, nil
}
func FromFP8(x *Array, dtype DType) *Array {
out := New("FROM_FP8")
C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
func ToFP8(x *Array) *Array {
out := New("TO_FP8")
C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx)
return out
}
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
@@ -137,6 +149,40 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
return out
}
func Pad(a *Array, paddings []int32) *Array {
numAxes := len(paddings) / 2
axes := make([]C.int, numAxes)
lowPad := make([]C.int, numAxes)
highPad := make([]C.int, numAxes)
for i := range numAxes {
axes[i] = C.int(i)
lowPad[i] = C.int(paddings[i*2])
highPad[i] = C.int(paddings[i*2+1])
}
padValue := C.mlx_array_new_float(C.float(0))
defer C.mlx_array_free(padValue)
cMode := C.CString("constant")
defer C.free(unsafe.Pointer(cMode))
out := New("PAD")
C.mlx_pad(
&out.ctx,
a.ctx,
unsafe.SliceData(axes),
C.size_t(len(axes)),
unsafe.SliceData(lowPad),
C.size_t(len(lowPad)),
unsafe.SliceData(highPad),
C.size_t(len(highPad)),
padValue,
cMode,
DefaultStream().ctx,
)
return out
}
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
groups := int32(x.Dim(x.NumDims() - 1))
return Conv1d(x, weight, bias, 1, 0, 1, groups)


@@ -11,8 +11,10 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string)
switch strings.ToUpper(quantization) {
case "NVFP4":
return 16, 4, "nvfp4"
case "MXFP4":
return 32, 4, "mxfp4"
case "FP4", "Q4", "INT4":
return 32, 4, "affine"
return 64, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8":


@@ -144,3 +144,44 @@ func TestLayerNormDefaultEps(t *testing.T) {
}
}
}
func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
skipIfNoMLX(t)
weightVals := make([]float32, 3*32)
for i := range weightVals {
weightVals[i] = float32((i%11)-5) / 7
}
inputVals := make([]float32, 2*32)
for i := range inputVals {
inputVals[i] = float32((i%7)-3) / 5
}
weight := mlx.FromValues(weightVals, 3, 32).AsType(mlx.DTypeBFloat16)
input := mlx.FromValues(inputVals, 2, 32).AsType(mlx.DTypeBFloat16)
mlx.Eval(weight, input)
ql := NewQuantizedLinear(weight, nil, 32, 4, "mxfp4")
if ql.QBiases != nil {
t.Fatalf("mxfp4 qbiases = %v, want nil", ql.QBiases)
}
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
mlx.Eval(dequantizedWeight)
qOut := ql.Forward(input)
dOut := NewLinear(dequantizedWeight, nil).Forward(input)
mlx.Eval(qOut, dOut)
got := qOut.Floats()
want := dOut.Floats()
if len(got) != len(want) {
t.Fatalf("output length = %d, want %d", len(got), len(want))
}
for i := range got {
if !approxEqual(got[i], want[i], 1e-3) {
t.Fatalf("output[%d] = %.6f, want %.6f", i, got[i], want[i])
}
}
}


@@ -420,7 +420,16 @@ func tensorByBase(tensors map[string]*mlx.Array, base string) (*mlx.Array, strin
}
func supportsGatherQMM(mode string, bits int) bool {
return mode == "affine" && (bits == 4 || bits == 8)
switch mode {
case "affine":
return bits == 4 || bits == 8
case "mxfp8":
return bits == 8
case "nvfp4", "mxfp4":
return bits == 4
default:
return false
}
}
func freeTensorKeys(tensors map[string]*mlx.Array, keys ...string) {


@@ -83,6 +83,28 @@ func TestLayerSelectionHelpers(t *testing.T) {
}
}
func TestSupportsGatherQMM(t *testing.T) {
tests := []struct {
mode string
bits int
want bool
}{
{mode: "affine", bits: 4, want: true},
{mode: "affine", bits: 8, want: true},
{mode: "mxfp8", bits: 8, want: true},
{mode: "nvfp4", bits: 4, want: true},
{mode: "mxfp4", bits: 4, want: true},
{mode: "mxfp8", bits: 4, want: false},
{mode: "affine", bits: 3, want: false},
}
for _, tt := range tests {
if got := supportsGatherQMM(tt.mode, tt.bits); got != tt.want {
t.Fatalf("supportsGatherQMM(%q, %d) = %v, want %v", tt.mode, tt.bits, got, tt.want)
}
}
}
func TestResolveTensorPathLayout(t *testing.T) {
dummy := mlx.New("dummy")