diff --git a/x/create/client/create.go b/x/create/client/create.go index 74abb865e..5ada0c23b 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -191,6 +191,10 @@ func inferSafetensorsCapabilities(modelDir string) []string { capabilities = append(capabilities, "vision") } + if supportsAudio(modelDir) { + capabilities = append(capabilities, "audio") + } + if supportsThinking(modelDir) { capabilities = append(capabilities, "thinking") } @@ -496,32 +500,38 @@ func supportsThinking(modelDir string) bool { return false } -// supportsVision checks if the model supports image input based on its architecture. -// Qwen3.5 multimodal checkpoints are published as ConditionalGeneration architectures. +// supportsVision checks if the model has a vision encoder by looking for +// vision_config in config.json. func supportsVision(modelDir string) bool { - configPath := filepath.Join(modelDir, "config.json") - data, err := os.ReadFile(configPath) + data, err := os.ReadFile(filepath.Join(modelDir, "config.json")) if err != nil { return false } var cfg struct { - Architectures []string `json:"architectures"` - ModelType string `json:"model_type"` + VisionConfig *map[string]any `json:"vision_config"` } if err := json.Unmarshal(data, &cfg); err != nil { return false } - for _, arch := range cfg.Architectures { - archLower := strings.ToLower(arch) - if strings.Contains(archLower, "qwen3") && strings.Contains(archLower, "conditionalgeneration") { - return true - } + return cfg.VisionConfig != nil +} + +func supportsAudio(modelDir string) bool { + data, err := os.ReadFile(filepath.Join(modelDir, "config.json")) + if err != nil { + return false } - typeLower := strings.ToLower(cfg.ModelType) - return strings.Contains(typeLower, "qwen3") && strings.Contains(typeLower, "conditionalgeneration") + var cfg struct { + AudioConfig *map[string]any `json:"audio_config"` + } + if err := json.Unmarshal(data, &cfg); err != nil { + return false + } + + return cfg.AudioConfig != nil } // getParserName returns the parser name for a model based on its architecture. diff --git a/x/create/client/create_test.go b/x/create/client/create_test.go index 286ea2208..a8a6fc4a8 100644 --- a/x/create/client/create_test.go +++ b/x/create/client/create_test.go @@ -311,10 +311,30 @@ func TestInferSafetensorsCapabilities(t *testing.T) { name: "qwen3.5 multimodal model", configJSON: `{ "architectures": ["Qwen3_5ForConditionalGeneration"], - "model_type": "qwen3" + "model_type": "qwen3", + "vision_config": {"hidden_size": 1024} }`, want: []string{"completion", "vision", "thinking"}, }, + { + name: "model with audio config", + configJSON: `{ + "architectures": ["Gemma4ForConditionalGeneration"], + "model_type": "gemma4", + "vision_config": {"hidden_size": 1024}, + "audio_config": {"num_mel_bins": 128} + }`, + want: []string{"completion", "vision", "audio"}, + }, + { + name: "model with audio but no vision", + configJSON: `{ + "architectures": ["SomeAudioModel"], + "model_type": "other", + "audio_config": {"num_mel_bins": 128} + }`, + want: []string{"completion", "audio"}, + }, { name: "non-qwen conditional generation model", configJSON: `{ @@ -339,6 +359,74 @@ func TestInferSafetensorsCapabilities(t *testing.T) { } } +func TestParsePerExpertInputs(t *testing.T) { + makeInput := func(name, quantize string) create.PackedTensorInput { + return create.PackedTensorInput{Name: name, Quantize: quantize} + } + + t.Run("uniform quant across projections", func(t *testing.T) { + inputs := []create.PackedTensorInput{ + makeInput("layer.moe.experts.0.gate_proj.weight", "int4"), + makeInput("layer.moe.experts.1.gate_proj.weight", "int4"), + makeInput("layer.moe.experts.0.down_proj.weight", "int4"), + makeInput("layer.moe.experts.1.down_proj.weight", "int4"), + } + groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs) + if groups == nil { + t.Fatal("expected non-nil groups") + } + if len(groups) != 2 { + t.Fatalf("expected 2 projection groups, got %d", len(groups)) + } + if projQ["gate_proj.weight"] != "int4" { + t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"]) + } + if projQ["down_proj.weight"] != "int4" { + t.Errorf("down_proj quant = %q, want int4", projQ["down_proj.weight"]) + } + }) + + t.Run("mixed quant across projections", func(t *testing.T) { + inputs := []create.PackedTensorInput{ + makeInput("layer.moe.experts.0.gate_proj.weight", "int4"), + makeInput("layer.moe.experts.1.gate_proj.weight", "int4"), + makeInput("layer.moe.experts.0.down_proj.weight", "int8"), + makeInput("layer.moe.experts.1.down_proj.weight", "int8"), + } + groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs) + if groups == nil { + t.Fatal("expected non-nil groups for mixed cross-projection quant") + } + if projQ["gate_proj.weight"] != "int4" { + t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"]) + } + if projQ["down_proj.weight"] != "int8" { + t.Errorf("down_proj quant = %q, want int8", projQ["down_proj.weight"]) + } + }) + + t.Run("mixed quant within same projection rejected", func(t *testing.T) { + inputs := []create.PackedTensorInput{ + makeInput("layer.moe.experts.0.down_proj.weight", "int4"), + makeInput("layer.moe.experts.1.down_proj.weight", "int8"), + } + groups, _ := parsePerExpertInputs("layer.moe.experts", inputs) + if groups != nil { + t.Fatal("expected nil for mixed quant within same projection") + } + }) + + t.Run("non-experts group rejected", func(t *testing.T) { + inputs := []create.PackedTensorInput{ + makeInput("layer.mlp.gate_proj.weight", "int4"), + } + groups, _ := parsePerExpertInputs("layer.mlp", inputs) + if groups != nil { + t.Fatal("expected nil for non-experts group") + } + }) +} + func TestQuantizeSupported(t *testing.T) { // This just verifies the function exists and returns a boolean // The actual value depends on build tags (mlx vs non-mlx) diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go index c5dcb9a95..893252b7e 100644 --- a/x/create/client/quantize.go +++ b/x/create/client/quantize.go @@ -97,6 +97,20 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string] groupSize, bits, mode := model.QuantizationParams(quantize) qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode) + // Validate quantization produced non-empty output. MLX quantize may return + // empty arrays for unsupported mode/bits combinations without raising an error. + mlx.Eval(qweight, scales) + if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 { + st.Free() + return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)", + name, quantize, groupSize, bits, mode) + } + if len(scales.Dims()) == 0 || scales.Dims()[0] == 0 { + st.Free() + return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty scales for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)", + name, quantize, groupSize, bits, mode) + } + qweight = mlx.Contiguous(qweight, false) scales = mlx.Contiguous(scales, false) arrays[name] = qweight @@ -174,8 +188,8 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti // Returns the blob bytes. func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) { // Check if inputs are per-expert tensors that should be stacked into 3D - if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil { - return stackAndQuantizeExpertGroup(groupName, projGroups, quantize) + if projGroups, projQuantize := parsePerExpertInputs(groupName, inputs); projGroups != nil { + return stackAndQuantizeExpertGroup(groupName, projGroups, projQuantize) } allArrays := make(map[string]*mlx.Array) @@ -224,6 +238,17 @@ func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([ mlx.Pin(finalArrays...) pinned = append(pinned, finalArrays...) + // Record per-tensor quant type so the model can resolve params at load time. + if input.Quantize != "" { + if groupSize, _, _ := model.QuantizationParams(input.Quantize); groupSize > 0 { + if metadata == nil { + metadata = make(map[string]string) + } + metadata[input.Name+".quant_type"] = input.Quantize + metadata[input.Name+".group_size"] = strconv.Itoa(groupSize) + } + } + if st != nil { st.Free() } @@ -279,57 +304,60 @@ type expertTensorInfo struct { } // parsePerExpertInputs groups per-expert 2D tensor inputs by projection type -// and returns the uniform quantization type shared by all inputs. -// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D) -// or if the inputs have mixed quantization types. +// and returns per-projection quantization types. Different projections may use +// different quant types (e.g., gate_up=int4, down=int8) but all experts within +// a projection must share the same type. +// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D). // Only handles ".experts" groups; ".shared_experts" groups are left unpacked. -func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) { +func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, map[string]string) { if !strings.HasSuffix(groupName, ".experts") { - return nil, "" + return nil, nil } - quantize := inputs[0].Quantize groups := make(map[string][]expertTensorInfo) + projQuantize := make(map[string]string) // projection -> quant type for _, input := range inputs { - if input.Quantize != quantize { - return nil, "" // mixed quantization types - } suffix := strings.TrimPrefix(input.Name, groupName) m := perExpertSuffix.FindStringSubmatch(suffix) if m == nil { - return nil, "" // not a per-expert pattern + return nil, nil // not a per-expert pattern } index, err := strconv.Atoi(m[1]) if err != nil { - return nil, "" + return nil, nil } - groups[m[2]] = append(groups[m[2]], expertTensorInfo{ + proj := m[2] + if existing, ok := projQuantize[proj]; ok { + if input.Quantize != existing { + return nil, nil // mixed quant within same projection + } + } else { + projQuantize[proj] = input.Quantize + } + groups[proj] = append(groups[proj], expertTensorInfo{ index: index, - proj: m[2], + proj: proj, input: input, }) } if len(groups) == 0 { - return nil, "" + return nil, nil } - return groups, quantize + return groups, projQuantize } // stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D // switch_mlp tensors, quantizes, and returns the combined safetensors blob. -func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) { +// projQuantize maps projection name to its quantization type (may differ per projection). +func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, projQuantize map[string]string) ([]byte, error) { groupBase := strings.TrimSuffix(groupName, ".experts") allArrays := make(map[string]*mlx.Array) var pinned []*mlx.Array - var metadata map[string]string - if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" { - metadata = map[string]string{ - "quant_type": quantize, - "group_size": strconv.Itoa(groupSize), - } - } + // Build metadata: if all projections use the same quant type, set global metadata. + // Otherwise record per-tensor quant info. + metadata := make(map[string]string) // Sort projection names for deterministic output projNames := make([]string, 0, len(projGroups)) @@ -339,7 +367,11 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper sort.Strings(projNames) cleanup := func() { - mlx.Unpin(pinned...) + for _, p := range pinned { + if p != nil { + mlx.Unpin(p) + } + } mlx.Sweep() } @@ -382,11 +414,27 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper mlx.Pin(stacked) pinned = append(pinned, stacked) - // Free individual decoded arrays + // Free individual decoded arrays (remove from pinned to avoid double-unpin in cleanup) + for i, p := range pinned { + for _, d := range decoded { + if p == d { + pinned[i] = nil + } + } + } mlx.Unpin(decoded...) mlx.Sweep() stackedName := groupBase + ".switch_mlp." + proj + quantize := projQuantize[proj] + + // Record per-tensor quant metadata so the model can resolve params at load time. + if quantize != "" { + if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 { + metadata[stackedName+".quant_type"] = quantize + metadata[stackedName+".group_size"] = strconv.Itoa(groupSize) + } + } // Quantize the stacked tensor if quantize != "" { @@ -394,6 +442,14 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode) + // Validate quantization produced non-empty output. + mlx.Eval(qweight, scales) + if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 { + cleanup() + return nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)", + stackedName, quantize, groupSize, bits, mode) + } + qweight = mlx.Contiguous(qweight, false) scales = mlx.Contiguous(scales, false) allArrays[stackedName] = qweight @@ -409,12 +465,19 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper mlx.Pin(toEval...) pinned = append(pinned, toEval...) - // Free stacked source array + // Free stacked source array (remove from pinned to avoid double-unpin in cleanup) + for i, p := range pinned { + if p == stacked { + pinned[i] = nil + } + } mlx.Unpin(stacked) mlx.Sweep() } else { stacked = mlx.Contiguous(stacked, false) mlx.Eval(stacked) + mlx.Pin(stacked) + pinned = append(pinned, stacked) allArrays[stackedName] = stacked } } diff --git a/x/create/create.go b/x/create/create.go index da544e6a4..88747ebed 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -246,6 +246,11 @@ func ShouldQuantize(name, component string) bool { return false } + // Skip audio encoder tensors (highly sensitive to quantization) + if strings.Contains(name, "audio_tower") || strings.Contains(name, "embed_audio") { + return false + } + // Skip embeddings if strings.Contains(name, "embed") { return false @@ -291,6 +296,22 @@ func normalizeQuantType(quantize string) string { } } +// isAligned checks if a tensor's last dimension is divisible by the +// group size required for the given quantization type. +func isAligned(shape []int32, quantType string) bool { + if len(shape) == 0 { + return false + } + groupSize := int32(32) + switch normalizeQuantType(quantType) { + case "nvfp4": + groupSize = 16 + case "int4", "int8": + groupSize = 64 + } + return shape[len(shape)-1]%groupSize == 0 +} + func isStackedExpertWeight(name string) bool { // Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert) // or "...proj" (pre-stacked packed tensor). @@ -300,16 +321,16 @@ func isStackedExpertWeight(name string) bool { return strings.Contains(name, ".mlp.switch_mlp.") || strings.Contains(name, ".mlp.experts.") || - strings.Contains(name, ".mlp.shared_experts.") + strings.Contains(name, ".mlp.shared_experts.") || + strings.Contains(name, ".moe.experts.") } // GetTensorQuantization returns the appropriate quantization type for a tensor. // Returns "" if the tensor should not be quantized. // This implements mixed-precision quantization: -// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive) -// - Output projection, gate/up weights: int4 (less sensitive) -// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel) +// - v_proj, k_proj, down_proj: promoted to INT8 when base is INT4 // - Norms, embeddings, biases, routing gates: no quantization +// - All other eligible weights: use requested quantization type func GetTensorQuantization(name string, shape []int32, quantize string) string { stackedExpert := isStackedExpertWeight(name) @@ -336,60 +357,35 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string { // Normalize quantization type to canonical form quantNorm := normalizeQuantType(quantize) - // MLX quantization requires last dimension to be divisible by group size - // nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64 - groupSize := int32(32) - switch quantNorm { - case "nvfp4": - groupSize = 16 - case "int4", "int8": - groupSize = 64 - } - if shape[len(shape)-1]%groupSize != 0 { - return "" - } - // Skip routing gate weights (should stay high precision) // In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight) if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") { return "" } + // MLX quantization requires last dimension to be divisible by group size. + if !isAligned(shape, quantNorm) { + return "" + } + // For non-affine modes, use the same quantization for all eligible tensors. if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" { return quantNorm } - // Attention MLA weights - keep unquantized (bf16) - // These are highly sensitive: errors accumulate in the KV cache over time - // q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj - if strings.Contains(name, "q_a_proj") || - strings.Contains(name, "q_b_proj") || - strings.Contains(name, "kv_a_proj") || - strings.Contains(name, "kv_b_proj") { - return "" // No quantization - keep bf16 + // Value projection weights directly determine attention output quality. + // Down projection weights feed directly into the residual stream where + // errors accumulate across layers. Both benefit from higher precision. + // Promote to INT8 when base is INT4 (same affine mode, compatible with + // GatherQMM for MoE expert tensors). + if quantNorm == "int4" { + if strings.Contains(name, ".v_proj") || strings.Contains(name, ".k_proj") || strings.Contains(name, "down_proj") { + if isAligned(shape, "int8") { + return "int8" + } + } } - // Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel) - // mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj - if strings.Contains(name, "down_proj") { - return "int8" - } - - // Output projection, gate/up weights - use requested quantization (INT4) - // o_proj, gate_proj, up_proj - if strings.Contains(name, "o_proj") || - strings.Contains(name, "gate_proj") || - strings.Contains(name, "up_proj") { - return quantNorm - } - - // LM head - use requested quantization - if strings.Contains(name, "lm_head") { - return quantNorm - } - - // Default to requested quantization for other weights return quantNorm } @@ -411,6 +407,7 @@ func ExpertGroupPrefix(tensorName string) string { ".mlp.experts.", ".mlp.shared_experts.", ".mlp.switch_mlp.", + ".moe.experts.", } { idx := strings.Index(tensorName, marker) if idx == -1 { diff --git a/x/create/create_test.go b/x/create/create_test.go index f5a68ba1a..0acbb6613 100644 --- a/x/create/create_test.go +++ b/x/create/create_test.go @@ -1169,6 +1169,11 @@ func TestShouldQuantize(t *testing.T) { {"ln prefix", "ln_1.weight", "", false}, {"layernorm in name", "input_layernorm.weight", "", false}, + // Audio encoder tensors should not be quantized + {"audio tower weight", "model.audio_tower.layers.0.weight", "", false}, + {"audio tower norm", "model.audio_tower.norm.weight", "", false}, + {"embed audio weight", "embed_audio.weight", "", false}, + // Biases should not be quantized {"bias tensor", "attention.bias", "", false}, {"proj bias", "o_proj.bias", "", false}, @@ -1262,6 +1267,11 @@ func TestExpertGroupPrefix(t *testing.T) { {"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"}, {"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"}, + // MoE expert tensors (Gemma-style .moe.experts.) + {"model.layers.0.moe.experts.0.gate_proj.weight", "model.layers.0.moe.experts"}, + {"model.layers.1.moe.experts.42.down_proj.weight", "model.layers.1.moe.experts"}, + {"language_model.model.layers.2.moe.experts.127.up_proj.weight", "language_model.model.layers.2.moe.experts"}, + // Expert tensors with language_model prefix should also match {"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"}, {"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"}, @@ -1369,6 +1379,94 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) { } } +func TestIsAligned(t *testing.T) { + tests := []struct { + name string + shape []int32 + quantType string + want bool + }{ + // int4/int8: group_size=64 + {"int4 aligned", []int32{1024, 4096}, "int4", true}, + {"int4 unaligned", []int32{1024, 48}, "int4", false}, + {"int8 aligned", []int32{1024, 128}, "int8", true}, + {"int8 unaligned", []int32{1024, 32}, "int8", false}, + + // nvfp4: group_size=16 + {"nvfp4 aligned", []int32{1024, 48}, "nvfp4", true}, + {"nvfp4 unaligned", []int32{1024, 24}, "nvfp4", false}, + {"nvfp4 aligned 16", []int32{1024, 16}, "nvfp4", true}, + + // mxfp4/mxfp8: group_size=32 + {"mxfp4 aligned", []int32{1024, 64}, "mxfp4", true}, + {"mxfp4 unaligned", []int32{1024, 48}, "mxfp4", false}, + {"mxfp8 aligned", []int32{1024, 32}, "mxfp8", true}, + {"mxfp8 unaligned", []int32{1024, 24}, "mxfp8", false}, + + // Edge cases + {"empty shape", []int32{}, "int4", false}, + {"1D tensor", []int32{4096}, "int4", true}, + {"3D stacked expert", []int32{128, 4096, 2816}, "int4", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isAligned(tt.shape, tt.quantType) + if got != tt.want { + t.Errorf("isAligned(%v, %q) = %v, want %v", tt.shape, tt.quantType, got, tt.want) + } + }) + } +} + +func TestGetTensorQuantization_MixedPrecisionPromotion(t *testing.T) { + aligned := []int32{4096, 4096} // divisible by 64 + + tests := []struct { + name string + tensor string + shape []int32 + quantize string + want string + }{ + // int4 → int8 promotion for sensitive tensors + {"v_proj int4 promoted", "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"}, + {"k_proj int4 promoted", "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"}, + {"down_proj int4 promoted", "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"}, + + // Non-sensitive int4 tensors stay int4 + {"q_proj int4 stays", "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"}, + {"o_proj int4 stays", "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"}, + {"gate_proj int4 stays", "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"}, + {"up_proj int4 stays", "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"}, + + // nvfp4/mxfp4/mxfp8: no promotion (uniform quantization) + {"v_proj nvfp4 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"}, + {"down_proj mxfp4 uniform", "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"}, + {"v_proj mxfp8 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"}, + + // int8: already 8-bit, no promotion + {"v_proj int8 stays", "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"}, + + // Expert tensors: down_proj also promoted for int4 + {"expert down_proj int4", "model.layers.0.mlp.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"}, + {"moe expert down_proj int4", "model.layers.0.moe.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"}, + + // Unaligned: falls back to bf16 (empty string) + {"v_proj int4 unaligned", "model.layers.0.self_attn.v_proj.weight", []int32{1024, 48}, "int4", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetTensorQuantization(tt.tensor, tt.shape, tt.quantize) + if got != tt.want { + t.Errorf("GetTensorQuantization(%q, %v, %q) = %q, want %q", + tt.tensor, tt.shape, tt.quantize, got, tt.want) + } + }) + } +} + func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) { dir := t.TempDir() diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go index 5ec3fc850..a03488158 100644 --- a/x/mlxrunner/mlx/mlx.go +++ b/x/mlxrunner/mlx/mlx.go @@ -90,3 +90,10 @@ func AsyncEval(outputs ...*Array) { func Eval(outputs ...*Array) { doEval(outputs, false) } + +// MetalIsAvailable returns true if a Metal GPU is available. +func MetalIsAvailable() bool { + var available C._Bool + C.mlx_metal_is_available(&available) + return bool(available) +} diff --git a/x/mlxrunner/model/root.go b/x/mlxrunner/model/root.go index 1c05ee6a8..43f6426fb 100644 --- a/x/mlxrunner/model/root.go +++ b/x/mlxrunner/model/root.go @@ -131,6 +131,12 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header) globalQuantType = strings.ToUpper(globalQuantType) + // Parse full metadata for per-tensor quant info + var metaMap map[string]string + if metaRaw, ok := header["__metadata__"]; ok { + json.Unmarshal(metaRaw, &metaMap) + } + mainNames := mainTensorNames(header) infos := make(map[string]*TensorQuantInfo) for _, name := range mainNames { @@ -141,6 +147,18 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, quantType := globalQuantType groupSize := globalGroupSize + // Check per-tensor metadata (e.g. from packed expert blobs with mixed precision) + if metaMap != nil { + if qt, ok := metaMap[name+".quant_type"]; ok && qt != "" { + quantType = strings.ToUpper(qt) + } + if gs, ok := metaMap[name+".group_size"]; ok && gs != "" { + if v, err := strconv.Atoi(gs); err == nil { + groupSize = v + } + } + } + inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType) if quantType == "" { quantType = inferredType