mlx: mixed-precision quant and capability detection improvements (#15409)

Improve the MLX model creation pipeline with several model-agnostic changes:

- Rewrite supportsVision to use vision_config instead of architecture name
- Add supportsAudio for audio encoder detection
- Add alignment checking (isAligned) for quantization group sizes
- Support per-projection mixed quantization in MoE expert packing
- Record per-tensor quant metadata in safetensors blobs
- Parse per-tensor quant metadata at model load time
- Validate quantize output is non-empty before storing
- Fix pin/unpin cleanup in expert group quantization
- Promote v_proj/k_proj/down_proj to INT8 for INT4 base quant
- Add MetalIsAvailable() utility
- Skip audio encoder tensors from quantization
This commit is contained in:
Daniel Hiltgen
2026-04-13 11:43:07 -07:00
committed by GitHub
parent 1b70bb8a10
commit d3da29cbfc
7 changed files with 368 additions and 87 deletions

View File

@@ -97,6 +97,20 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
groupSize, bits, mode := model.QuantizationParams(quantize)
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
// Validate quantization produced non-empty output. MLX quantize may return
// empty arrays for unsupported mode/bits combinations without raising an error.
mlx.Eval(qweight, scales)
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
name, quantize, groupSize, bits, mode)
}
if len(scales.Dims()) == 0 || scales.Dims()[0] == 0 {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty scales for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
name, quantize, groupSize, bits, mode)
}
qweight = mlx.Contiguous(qweight, false)
scales = mlx.Contiguous(scales, false)
arrays[name] = qweight
@@ -174,8 +188,8 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
// Returns the blob bytes.
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
// Check if inputs are per-expert tensors that should be stacked into 3D
if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
if projGroups, projQuantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
return stackAndQuantizeExpertGroup(groupName, projGroups, projQuantize)
}
allArrays := make(map[string]*mlx.Array)
@@ -224,6 +238,17 @@ func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([
mlx.Pin(finalArrays...)
pinned = append(pinned, finalArrays...)
// Record per-tensor quant type so the model can resolve params at load time.
if input.Quantize != "" {
if groupSize, _, _ := model.QuantizationParams(input.Quantize); groupSize > 0 {
if metadata == nil {
metadata = make(map[string]string)
}
metadata[input.Name+".quant_type"] = input.Quantize
metadata[input.Name+".group_size"] = strconv.Itoa(groupSize)
}
}
if st != nil {
st.Free()
}
@@ -279,57 +304,60 @@ type expertTensorInfo struct {
}
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
// and returns the uniform quantization type shared by all inputs.
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
// or if the inputs have mixed quantization types.
// and returns per-projection quantization types. Different projections may use
// different quant types (e.g., gate_up=int4, down=int8) but all experts within
// a projection must share the same type.
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D).
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, map[string]string) {
if !strings.HasSuffix(groupName, ".experts") {
return nil, ""
return nil, nil
}
quantize := inputs[0].Quantize
groups := make(map[string][]expertTensorInfo)
projQuantize := make(map[string]string) // projection -> quant type
for _, input := range inputs {
if input.Quantize != quantize {
return nil, "" // mixed quantization types
}
suffix := strings.TrimPrefix(input.Name, groupName)
m := perExpertSuffix.FindStringSubmatch(suffix)
if m == nil {
return nil, "" // not a per-expert pattern
return nil, nil // not a per-expert pattern
}
index, err := strconv.Atoi(m[1])
if err != nil {
return nil, ""
return nil, nil
}
groups[m[2]] = append(groups[m[2]], expertTensorInfo{
proj := m[2]
if existing, ok := projQuantize[proj]; ok {
if input.Quantize != existing {
return nil, nil // mixed quant within same projection
}
} else {
projQuantize[proj] = input.Quantize
}
groups[proj] = append(groups[proj], expertTensorInfo{
index: index,
proj: m[2],
proj: proj,
input: input,
})
}
if len(groups) == 0 {
return nil, ""
return nil, nil
}
return groups, quantize
return groups, projQuantize
}
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
// projQuantize maps projection name to its quantization type (may differ per projection).
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, projQuantize map[string]string) ([]byte, error) {
groupBase := strings.TrimSuffix(groupName, ".experts")
allArrays := make(map[string]*mlx.Array)
var pinned []*mlx.Array
var metadata map[string]string
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
metadata = map[string]string{
"quant_type": quantize,
"group_size": strconv.Itoa(groupSize),
}
}
// Build metadata: if all projections use the same quant type, set global metadata.
// Otherwise record per-tensor quant info.
metadata := make(map[string]string)
// Sort projection names for deterministic output
projNames := make([]string, 0, len(projGroups))
@@ -339,7 +367,11 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
sort.Strings(projNames)
cleanup := func() {
mlx.Unpin(pinned...)
for _, p := range pinned {
if p != nil {
mlx.Unpin(p)
}
}
mlx.Sweep()
}
@@ -382,11 +414,27 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
mlx.Pin(stacked)
pinned = append(pinned, stacked)
// Free individual decoded arrays
// Free individual decoded arrays (remove from pinned to avoid double-unpin in cleanup)
for i, p := range pinned {
for _, d := range decoded {
if p == d {
pinned[i] = nil
}
}
}
mlx.Unpin(decoded...)
mlx.Sweep()
stackedName := groupBase + ".switch_mlp." + proj
quantize := projQuantize[proj]
// Record per-tensor quant metadata so the model can resolve params at load time.
if quantize != "" {
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 {
metadata[stackedName+".quant_type"] = quantize
metadata[stackedName+".group_size"] = strconv.Itoa(groupSize)
}
}
// Quantize the stacked tensor
if quantize != "" {
@@ -394,6 +442,14 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
// Validate quantization produced non-empty output.
mlx.Eval(qweight, scales)
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
cleanup()
return nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
stackedName, quantize, groupSize, bits, mode)
}
qweight = mlx.Contiguous(qweight, false)
scales = mlx.Contiguous(scales, false)
allArrays[stackedName] = qweight
@@ -409,12 +465,19 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
mlx.Pin(toEval...)
pinned = append(pinned, toEval...)
// Free stacked source array
// Free stacked source array (remove from pinned to avoid double-unpin in cleanup)
for i, p := range pinned {
if p == stacked {
pinned[i] = nil
}
}
mlx.Unpin(stacked)
mlx.Sweep()
} else {
stacked = mlx.Contiguous(stacked, false)
mlx.Eval(stacked)
mlx.Pin(stacked)
pinned = append(pinned, stacked)
allArrays[stackedName] = stacked
}
}