mlx: mixed-precision quant and capability detection improvements (#15409)

Improve the MLX model creation pipeline with several model-agnostic changes:

- Rewrite supportsVision to use vision_config instead of architecture name
- Add supportsAudio for audio encoder detection
- Add alignment checking (isAligned) for quantization group sizes
- Support per-projection mixed quantization in MoE expert packing
- Record per-tensor quant metadata in safetensors blobs
- Parse per-tensor quant metadata at model load time
- Validate quantize output is non-empty before storing
- Fix pin/unpin cleanup in expert group quantization
- Promote v_proj/k_proj/down_proj to INT8 for INT4 base quant
- Add MetalIsAvailable() utility
- Skip audio encoder tensors from quantization
This commit is contained in:
Daniel Hiltgen
2026-04-13 11:43:07 -07:00
committed by GitHub
parent 1b70bb8a10
commit d3da29cbfc
7 changed files with 368 additions and 87 deletions

View File

@@ -311,10 +311,30 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
name: "qwen3.5 multimodal model",
configJSON: `{
"architectures": ["Qwen3_5ForConditionalGeneration"],
"model_type": "qwen3"
"model_type": "qwen3",
"vision_config": {"hidden_size": 1024}
}`,
want: []string{"completion", "vision", "thinking"},
},
{
name: "model with audio config",
configJSON: `{
"architectures": ["Gemma4ForConditionalGeneration"],
"model_type": "gemma4",
"vision_config": {"hidden_size": 1024},
"audio_config": {"num_mel_bins": 128}
}`,
want: []string{"completion", "vision", "audio"},
},
{
name: "model with audio but no vision",
configJSON: `{
"architectures": ["SomeAudioModel"],
"model_type": "other",
"audio_config": {"num_mel_bins": 128}
}`,
want: []string{"completion", "audio"},
},
{
name: "non-qwen conditional generation model",
configJSON: `{
@@ -339,6 +359,74 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
}
}
// TestParsePerExpertInputs exercises how parsePerExpertInputs partitions MoE
// expert tensors by projection and reports the quantization selected for each
// projection. Covered cases: a uniform quant across projections, different
// quants on different projections (accepted), different quants within a
// single projection (rejected), and a group whose path is not an experts
// prefix (rejected).
func TestParsePerExpertInputs(t *testing.T) {
	// newInput builds a minimal packed-tensor input for the scenarios below.
	newInput := func(name, quantize string) create.PackedTensorInput {
		return create.PackedTensorInput{Name: name, Quantize: quantize}
	}

	t.Run("uniform quant across projections", func(t *testing.T) {
		tensors := []create.PackedTensorInput{
			newInput("layer.moe.experts.0.gate_proj.weight", "int4"),
			newInput("layer.moe.experts.1.gate_proj.weight", "int4"),
			newInput("layer.moe.experts.0.down_proj.weight", "int4"),
			newInput("layer.moe.experts.1.down_proj.weight", "int4"),
		}

		grouped, quantByProj := parsePerExpertInputs("layer.moe.experts", tensors)
		if grouped == nil {
			t.Fatal("expected non-nil groups")
		}
		if got := len(grouped); got != 2 {
			t.Fatalf("expected 2 projection groups, got %d", got)
		}
		if q := quantByProj["gate_proj.weight"]; q != "int4" {
			t.Errorf("gate_proj quant = %q, want int4", q)
		}
		if q := quantByProj["down_proj.weight"]; q != "int4" {
			t.Errorf("down_proj quant = %q, want int4", q)
		}
	})

	t.Run("mixed quant across projections", func(t *testing.T) {
		tensors := []create.PackedTensorInput{
			newInput("layer.moe.experts.0.gate_proj.weight", "int4"),
			newInput("layer.moe.experts.1.gate_proj.weight", "int4"),
			newInput("layer.moe.experts.0.down_proj.weight", "int8"),
			newInput("layer.moe.experts.1.down_proj.weight", "int8"),
		}

		// Projections may disagree with each other; only intra-projection
		// disagreement is invalid (see the next subtest).
		grouped, quantByProj := parsePerExpertInputs("layer.moe.experts", tensors)
		if grouped == nil {
			t.Fatal("expected non-nil groups for mixed cross-projection quant")
		}
		if q := quantByProj["gate_proj.weight"]; q != "int4" {
			t.Errorf("gate_proj quant = %q, want int4", q)
		}
		if q := quantByProj["down_proj.weight"]; q != "int8" {
			t.Errorf("down_proj quant = %q, want int8", q)
		}
	})

	t.Run("mixed quant within same projection rejected", func(t *testing.T) {
		tensors := []create.PackedTensorInput{
			newInput("layer.moe.experts.0.down_proj.weight", "int4"),
			newInput("layer.moe.experts.1.down_proj.weight", "int8"),
		}

		if grouped, _ := parsePerExpertInputs("layer.moe.experts", tensors); grouped != nil {
			t.Fatal("expected nil for mixed quant within same projection")
		}
	})

	t.Run("non-experts group rejected", func(t *testing.T) {
		tensors := []create.PackedTensorInput{
			newInput("layer.mlp.gate_proj.weight", "int4"),
		}

		if grouped, _ := parsePerExpertInputs("layer.mlp", tensors); grouped != nil {
			t.Fatal("expected nil for non-experts group")
		}
	})
}
func TestQuantizeSupported(t *testing.T) {
// This just verifies the function exists and returns a boolean
// The actual value depends on build tags (mlx vs non-mlx)