mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
mlx: mixed-precision quant and capability detection improvements (#15409)
Improve the MLX model creation pipeline with several model-agnostic changes: - Rewrite supportsVision to use vision_config instead of architecture name - Add supportsAudio for audio encoder detection - Add alignment checking (isAligned) for quantization group sizes - Support per-projection mixed quantization in MoE expert packing - Record per-tensor quant metadata in safetensors blobs - Parse per-tensor quant metadata at model load time - Validate quantize output is non-empty before storing - Fix pin/unpin cleanup in expert group quantization - Promote v_proj/k_proj/down_proj to INT8 for INT4 base quant - Add MetalIsAvailable() utility - Skip audio encoder tensors from quantization
This commit is contained in:
@@ -191,6 +191,10 @@ func inferSafetensorsCapabilities(modelDir string) []string {
|
||||
capabilities = append(capabilities, "vision")
|
||||
}
|
||||
|
||||
if supportsAudio(modelDir) {
|
||||
capabilities = append(capabilities, "audio")
|
||||
}
|
||||
|
||||
if supportsThinking(modelDir) {
|
||||
capabilities = append(capabilities, "thinking")
|
||||
}
|
||||
@@ -496,32 +500,38 @@ func supportsThinking(modelDir string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// supportsVision checks if the model supports image input based on its architecture.
|
||||
// Qwen3.5 multimodal checkpoints are published as ConditionalGeneration architectures.
|
||||
// supportsVision checks if the model has a vision encoder by looking for
|
||||
// vision_config in config.json.
|
||||
func supportsVision(modelDir string) bool {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
VisionConfig *map[string]any `json:"vision_config"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "qwen3") && strings.Contains(archLower, "conditionalgeneration") {
|
||||
return true
|
||||
}
|
||||
return cfg.VisionConfig != nil
|
||||
}
|
||||
|
||||
func supportsAudio(modelDir string) bool {
|
||||
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
return strings.Contains(typeLower, "qwen3") && strings.Contains(typeLower, "conditionalgeneration")
|
||||
var cfg struct {
|
||||
AudioConfig *map[string]any `json:"audio_config"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return cfg.AudioConfig != nil
|
||||
}
|
||||
|
||||
// getParserName returns the parser name for a model based on its architecture.
|
||||
|
||||
@@ -311,10 +311,30 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
|
||||
name: "qwen3.5 multimodal model",
|
||||
configJSON: `{
|
||||
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
||||
"model_type": "qwen3"
|
||||
"model_type": "qwen3",
|
||||
"vision_config": {"hidden_size": 1024}
|
||||
}`,
|
||||
want: []string{"completion", "vision", "thinking"},
|
||||
},
|
||||
{
|
||||
name: "model with audio config",
|
||||
configJSON: `{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"model_type": "gemma4",
|
||||
"vision_config": {"hidden_size": 1024},
|
||||
"audio_config": {"num_mel_bins": 128}
|
||||
}`,
|
||||
want: []string{"completion", "vision", "audio"},
|
||||
},
|
||||
{
|
||||
name: "model with audio but no vision",
|
||||
configJSON: `{
|
||||
"architectures": ["SomeAudioModel"],
|
||||
"model_type": "other",
|
||||
"audio_config": {"num_mel_bins": 128}
|
||||
}`,
|
||||
want: []string{"completion", "audio"},
|
||||
},
|
||||
{
|
||||
name: "non-qwen conditional generation model",
|
||||
configJSON: `{
|
||||
@@ -339,6 +359,74 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePerExpertInputs(t *testing.T) {
|
||||
makeInput := func(name, quantize string) create.PackedTensorInput {
|
||||
return create.PackedTensorInput{Name: name, Quantize: quantize}
|
||||
}
|
||||
|
||||
t.Run("uniform quant across projections", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.moe.experts.0.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.0.down_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.down_proj.weight", "int4"),
|
||||
}
|
||||
groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||
if groups == nil {
|
||||
t.Fatal("expected non-nil groups")
|
||||
}
|
||||
if len(groups) != 2 {
|
||||
t.Fatalf("expected 2 projection groups, got %d", len(groups))
|
||||
}
|
||||
if projQ["gate_proj.weight"] != "int4" {
|
||||
t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"])
|
||||
}
|
||||
if projQ["down_proj.weight"] != "int4" {
|
||||
t.Errorf("down_proj quant = %q, want int4", projQ["down_proj.weight"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mixed quant across projections", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.moe.experts.0.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.0.down_proj.weight", "int8"),
|
||||
makeInput("layer.moe.experts.1.down_proj.weight", "int8"),
|
||||
}
|
||||
groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||
if groups == nil {
|
||||
t.Fatal("expected non-nil groups for mixed cross-projection quant")
|
||||
}
|
||||
if projQ["gate_proj.weight"] != "int4" {
|
||||
t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"])
|
||||
}
|
||||
if projQ["down_proj.weight"] != "int8" {
|
||||
t.Errorf("down_proj quant = %q, want int8", projQ["down_proj.weight"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mixed quant within same projection rejected", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.moe.experts.0.down_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.down_proj.weight", "int8"),
|
||||
}
|
||||
groups, _ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||
if groups != nil {
|
||||
t.Fatal("expected nil for mixed quant within same projection")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-experts group rejected", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.mlp.gate_proj.weight", "int4"),
|
||||
}
|
||||
groups, _ := parsePerExpertInputs("layer.mlp", inputs)
|
||||
if groups != nil {
|
||||
t.Fatal("expected nil for non-experts group")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestQuantizeSupported(t *testing.T) {
|
||||
// This just verifies the function exists and returns a boolean
|
||||
// The actual value depends on build tags (mlx vs non-mlx)
|
||||
|
||||
@@ -97,6 +97,20 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
|
||||
groupSize, bits, mode := model.QuantizationParams(quantize)
|
||||
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
|
||||
|
||||
// Validate quantization produced non-empty output. MLX quantize may return
|
||||
// empty arrays for unsupported mode/bits combinations without raising an error.
|
||||
mlx.Eval(qweight, scales)
|
||||
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||
name, quantize, groupSize, bits, mode)
|
||||
}
|
||||
if len(scales.Dims()) == 0 || scales.Dims()[0] == 0 {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty scales for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||
name, quantize, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
qweight = mlx.Contiguous(qweight, false)
|
||||
scales = mlx.Contiguous(scales, false)
|
||||
arrays[name] = qweight
|
||||
@@ -174,8 +188,8 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
|
||||
// Returns the blob bytes.
|
||||
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
// Check if inputs are per-expert tensors that should be stacked into 3D
|
||||
if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
|
||||
return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
|
||||
if projGroups, projQuantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
|
||||
return stackAndQuantizeExpertGroup(groupName, projGroups, projQuantize)
|
||||
}
|
||||
|
||||
allArrays := make(map[string]*mlx.Array)
|
||||
@@ -224,6 +238,17 @@ func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([
|
||||
mlx.Pin(finalArrays...)
|
||||
pinned = append(pinned, finalArrays...)
|
||||
|
||||
// Record per-tensor quant type so the model can resolve params at load time.
|
||||
if input.Quantize != "" {
|
||||
if groupSize, _, _ := model.QuantizationParams(input.Quantize); groupSize > 0 {
|
||||
if metadata == nil {
|
||||
metadata = make(map[string]string)
|
||||
}
|
||||
metadata[input.Name+".quant_type"] = input.Quantize
|
||||
metadata[input.Name+".group_size"] = strconv.Itoa(groupSize)
|
||||
}
|
||||
}
|
||||
|
||||
if st != nil {
|
||||
st.Free()
|
||||
}
|
||||
@@ -279,57 +304,60 @@ type expertTensorInfo struct {
|
||||
}
|
||||
|
||||
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
|
||||
// and returns the uniform quantization type shared by all inputs.
|
||||
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
|
||||
// or if the inputs have mixed quantization types.
|
||||
// and returns per-projection quantization types. Different projections may use
|
||||
// different quant types (e.g., gate_up=int4, down=int8) but all experts within
|
||||
// a projection must share the same type.
|
||||
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D).
|
||||
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
|
||||
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
|
||||
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, map[string]string) {
|
||||
if !strings.HasSuffix(groupName, ".experts") {
|
||||
return nil, ""
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
quantize := inputs[0].Quantize
|
||||
groups := make(map[string][]expertTensorInfo)
|
||||
projQuantize := make(map[string]string) // projection -> quant type
|
||||
for _, input := range inputs {
|
||||
if input.Quantize != quantize {
|
||||
return nil, "" // mixed quantization types
|
||||
}
|
||||
suffix := strings.TrimPrefix(input.Name, groupName)
|
||||
m := perExpertSuffix.FindStringSubmatch(suffix)
|
||||
if m == nil {
|
||||
return nil, "" // not a per-expert pattern
|
||||
return nil, nil // not a per-expert pattern
|
||||
}
|
||||
index, err := strconv.Atoi(m[1])
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
return nil, nil
|
||||
}
|
||||
groups[m[2]] = append(groups[m[2]], expertTensorInfo{
|
||||
proj := m[2]
|
||||
if existing, ok := projQuantize[proj]; ok {
|
||||
if input.Quantize != existing {
|
||||
return nil, nil // mixed quant within same projection
|
||||
}
|
||||
} else {
|
||||
projQuantize[proj] = input.Quantize
|
||||
}
|
||||
groups[proj] = append(groups[proj], expertTensorInfo{
|
||||
index: index,
|
||||
proj: m[2],
|
||||
proj: proj,
|
||||
input: input,
|
||||
})
|
||||
}
|
||||
if len(groups) == 0 {
|
||||
return nil, ""
|
||||
return nil, nil
|
||||
}
|
||||
return groups, quantize
|
||||
return groups, projQuantize
|
||||
}
|
||||
|
||||
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
|
||||
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
|
||||
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
|
||||
// projQuantize maps projection name to its quantization type (may differ per projection).
|
||||
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, projQuantize map[string]string) ([]byte, error) {
|
||||
groupBase := strings.TrimSuffix(groupName, ".experts")
|
||||
|
||||
allArrays := make(map[string]*mlx.Array)
|
||||
var pinned []*mlx.Array
|
||||
|
||||
var metadata map[string]string
|
||||
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
|
||||
metadata = map[string]string{
|
||||
"quant_type": quantize,
|
||||
"group_size": strconv.Itoa(groupSize),
|
||||
}
|
||||
}
|
||||
// Build metadata: if all projections use the same quant type, set global metadata.
|
||||
// Otherwise record per-tensor quant info.
|
||||
metadata := make(map[string]string)
|
||||
|
||||
// Sort projection names for deterministic output
|
||||
projNames := make([]string, 0, len(projGroups))
|
||||
@@ -339,7 +367,11 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
sort.Strings(projNames)
|
||||
|
||||
cleanup := func() {
|
||||
mlx.Unpin(pinned...)
|
||||
for _, p := range pinned {
|
||||
if p != nil {
|
||||
mlx.Unpin(p)
|
||||
}
|
||||
}
|
||||
mlx.Sweep()
|
||||
}
|
||||
|
||||
@@ -382,11 +414,27 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
mlx.Pin(stacked)
|
||||
pinned = append(pinned, stacked)
|
||||
|
||||
// Free individual decoded arrays
|
||||
// Free individual decoded arrays (remove from pinned to avoid double-unpin in cleanup)
|
||||
for i, p := range pinned {
|
||||
for _, d := range decoded {
|
||||
if p == d {
|
||||
pinned[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
mlx.Unpin(decoded...)
|
||||
mlx.Sweep()
|
||||
|
||||
stackedName := groupBase + ".switch_mlp." + proj
|
||||
quantize := projQuantize[proj]
|
||||
|
||||
// Record per-tensor quant metadata so the model can resolve params at load time.
|
||||
if quantize != "" {
|
||||
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 {
|
||||
metadata[stackedName+".quant_type"] = quantize
|
||||
metadata[stackedName+".group_size"] = strconv.Itoa(groupSize)
|
||||
}
|
||||
}
|
||||
|
||||
// Quantize the stacked tensor
|
||||
if quantize != "" {
|
||||
@@ -394,6 +442,14 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
|
||||
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
|
||||
|
||||
// Validate quantization produced non-empty output.
|
||||
mlx.Eval(qweight, scales)
|
||||
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
|
||||
cleanup()
|
||||
return nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||
stackedName, quantize, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
qweight = mlx.Contiguous(qweight, false)
|
||||
scales = mlx.Contiguous(scales, false)
|
||||
allArrays[stackedName] = qweight
|
||||
@@ -409,12 +465,19 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
mlx.Pin(toEval...)
|
||||
pinned = append(pinned, toEval...)
|
||||
|
||||
// Free stacked source array
|
||||
// Free stacked source array (remove from pinned to avoid double-unpin in cleanup)
|
||||
for i, p := range pinned {
|
||||
if p == stacked {
|
||||
pinned[i] = nil
|
||||
}
|
||||
}
|
||||
mlx.Unpin(stacked)
|
||||
mlx.Sweep()
|
||||
} else {
|
||||
stacked = mlx.Contiguous(stacked, false)
|
||||
mlx.Eval(stacked)
|
||||
mlx.Pin(stacked)
|
||||
pinned = append(pinned, stacked)
|
||||
allArrays[stackedName] = stacked
|
||||
}
|
||||
}
|
||||
|
||||
@@ -246,6 +246,11 @@ func ShouldQuantize(name, component string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip audio encoder tensors (highly sensitive to quantization)
|
||||
if strings.Contains(name, "audio_tower") || strings.Contains(name, "embed_audio") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip embeddings
|
||||
if strings.Contains(name, "embed") {
|
||||
return false
|
||||
@@ -291,6 +296,22 @@ func normalizeQuantType(quantize string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// isAligned checks if a tensor's last dimension is divisible by the
|
||||
// group size required for the given quantization type.
|
||||
func isAligned(shape []int32, quantType string) bool {
|
||||
if len(shape) == 0 {
|
||||
return false
|
||||
}
|
||||
groupSize := int32(32)
|
||||
switch normalizeQuantType(quantType) {
|
||||
case "nvfp4":
|
||||
groupSize = 16
|
||||
case "int4", "int8":
|
||||
groupSize = 64
|
||||
}
|
||||
return shape[len(shape)-1]%groupSize == 0
|
||||
}
|
||||
|
||||
func isStackedExpertWeight(name string) bool {
|
||||
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
|
||||
// or "...proj" (pre-stacked packed tensor).
|
||||
@@ -300,16 +321,16 @@ func isStackedExpertWeight(name string) bool {
|
||||
|
||||
return strings.Contains(name, ".mlp.switch_mlp.") ||
|
||||
strings.Contains(name, ".mlp.experts.") ||
|
||||
strings.Contains(name, ".mlp.shared_experts.")
|
||||
strings.Contains(name, ".mlp.shared_experts.") ||
|
||||
strings.Contains(name, ".moe.experts.")
|
||||
}
|
||||
|
||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||
// Returns "" if the tensor should not be quantized.
|
||||
// This implements mixed-precision quantization:
|
||||
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
|
||||
// - Output projection, gate/up weights: int4 (less sensitive)
|
||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||
// - v_proj, k_proj, down_proj: promoted to INT8 when base is INT4
|
||||
// - Norms, embeddings, biases, routing gates: no quantization
|
||||
// - All other eligible weights: use requested quantization type
|
||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
stackedExpert := isStackedExpertWeight(name)
|
||||
|
||||
@@ -336,60 +357,35 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
// Normalize quantization type to canonical form
|
||||
quantNorm := normalizeQuantType(quantize)
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size
|
||||
// nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
|
||||
groupSize := int32(32)
|
||||
switch quantNorm {
|
||||
case "nvfp4":
|
||||
groupSize = 16
|
||||
case "int4", "int8":
|
||||
groupSize = 64
|
||||
}
|
||||
if shape[len(shape)-1]%groupSize != 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip routing gate weights (should stay high precision)
|
||||
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
|
||||
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
|
||||
return ""
|
||||
}
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size.
|
||||
if !isAligned(shape, quantNorm) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// For non-affine modes, use the same quantization for all eligible tensors.
|
||||
if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// Attention MLA weights - keep unquantized (bf16)
|
||||
// These are highly sensitive: errors accumulate in the KV cache over time
|
||||
// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
|
||||
if strings.Contains(name, "q_a_proj") ||
|
||||
strings.Contains(name, "q_b_proj") ||
|
||||
strings.Contains(name, "kv_a_proj") ||
|
||||
strings.Contains(name, "kv_b_proj") {
|
||||
return "" // No quantization - keep bf16
|
||||
// Value projection weights directly determine attention output quality.
|
||||
// Down projection weights feed directly into the residual stream where
|
||||
// errors accumulate across layers. Both benefit from higher precision.
|
||||
// Promote to INT8 when base is INT4 (same affine mode, compatible with
|
||||
// GatherQMM for MoE expert tensors).
|
||||
if quantNorm == "int4" {
|
||||
if strings.Contains(name, ".v_proj") || strings.Contains(name, ".k_proj") || strings.Contains(name, "down_proj") {
|
||||
if isAligned(shape, "int8") {
|
||||
return "int8"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
|
||||
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
|
||||
if strings.Contains(name, "down_proj") {
|
||||
return "int8"
|
||||
}
|
||||
|
||||
// Output projection, gate/up weights - use requested quantization (INT4)
|
||||
// o_proj, gate_proj, up_proj
|
||||
if strings.Contains(name, "o_proj") ||
|
||||
strings.Contains(name, "gate_proj") ||
|
||||
strings.Contains(name, "up_proj") {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// LM head - use requested quantization
|
||||
if strings.Contains(name, "lm_head") {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// Default to requested quantization for other weights
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
@@ -411,6 +407,7 @@ func ExpertGroupPrefix(tensorName string) string {
|
||||
".mlp.experts.",
|
||||
".mlp.shared_experts.",
|
||||
".mlp.switch_mlp.",
|
||||
".moe.experts.",
|
||||
} {
|
||||
idx := strings.Index(tensorName, marker)
|
||||
if idx == -1 {
|
||||
|
||||
@@ -1169,6 +1169,11 @@ func TestShouldQuantize(t *testing.T) {
|
||||
{"ln prefix", "ln_1.weight", "", false},
|
||||
{"layernorm in name", "input_layernorm.weight", "", false},
|
||||
|
||||
// Audio encoder tensors should not be quantized
|
||||
{"audio tower weight", "model.audio_tower.layers.0.weight", "", false},
|
||||
{"audio tower norm", "model.audio_tower.norm.weight", "", false},
|
||||
{"embed audio weight", "embed_audio.weight", "", false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias tensor", "attention.bias", "", false},
|
||||
{"proj bias", "o_proj.bias", "", false},
|
||||
@@ -1262,6 +1267,11 @@ func TestExpertGroupPrefix(t *testing.T) {
|
||||
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
||||
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
||||
|
||||
// MoE expert tensors (Gemma-style .moe.experts.)
|
||||
{"model.layers.0.moe.experts.0.gate_proj.weight", "model.layers.0.moe.experts"},
|
||||
{"model.layers.1.moe.experts.42.down_proj.weight", "model.layers.1.moe.experts"},
|
||||
{"language_model.model.layers.2.moe.experts.127.up_proj.weight", "language_model.model.layers.2.moe.experts"},
|
||||
|
||||
// Expert tensors with language_model prefix should also match
|
||||
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
|
||||
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
|
||||
@@ -1369,6 +1379,94 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAligned(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shape []int32
|
||||
quantType string
|
||||
want bool
|
||||
}{
|
||||
// int4/int8: group_size=64
|
||||
{"int4 aligned", []int32{1024, 4096}, "int4", true},
|
||||
{"int4 unaligned", []int32{1024, 48}, "int4", false},
|
||||
{"int8 aligned", []int32{1024, 128}, "int8", true},
|
||||
{"int8 unaligned", []int32{1024, 32}, "int8", false},
|
||||
|
||||
// nvfp4: group_size=16
|
||||
{"nvfp4 aligned", []int32{1024, 48}, "nvfp4", true},
|
||||
{"nvfp4 unaligned", []int32{1024, 24}, "nvfp4", false},
|
||||
{"nvfp4 aligned 16", []int32{1024, 16}, "nvfp4", true},
|
||||
|
||||
// mxfp4/mxfp8: group_size=32
|
||||
{"mxfp4 aligned", []int32{1024, 64}, "mxfp4", true},
|
||||
{"mxfp4 unaligned", []int32{1024, 48}, "mxfp4", false},
|
||||
{"mxfp8 aligned", []int32{1024, 32}, "mxfp8", true},
|
||||
{"mxfp8 unaligned", []int32{1024, 24}, "mxfp8", false},
|
||||
|
||||
// Edge cases
|
||||
{"empty shape", []int32{}, "int4", false},
|
||||
{"1D tensor", []int32{4096}, "int4", true},
|
||||
{"3D stacked expert", []int32{128, 4096, 2816}, "int4", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isAligned(tt.shape, tt.quantType)
|
||||
if got != tt.want {
|
||||
t.Errorf("isAligned(%v, %q) = %v, want %v", tt.shape, tt.quantType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensorQuantization_MixedPrecisionPromotion(t *testing.T) {
|
||||
aligned := []int32{4096, 4096} // divisible by 64
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
shape []int32
|
||||
quantize string
|
||||
want string
|
||||
}{
|
||||
// int4 → int8 promotion for sensitive tensors
|
||||
{"v_proj int4 promoted", "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
|
||||
{"k_proj int4 promoted", "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
|
||||
{"down_proj int4 promoted", "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
|
||||
|
||||
// Non-sensitive int4 tensors stay int4
|
||||
{"q_proj int4 stays", "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
|
||||
{"o_proj int4 stays", "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
|
||||
{"gate_proj int4 stays", "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
|
||||
{"up_proj int4 stays", "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},
|
||||
|
||||
// nvfp4/mxfp4/mxfp8: no promotion (uniform quantization)
|
||||
{"v_proj nvfp4 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||
{"down_proj mxfp4 uniform", "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||
{"v_proj mxfp8 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},
|
||||
|
||||
// int8: already 8-bit, no promotion
|
||||
{"v_proj int8 stays", "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
|
||||
|
||||
// Expert tensors: down_proj also promoted for int4
|
||||
{"expert down_proj int4", "model.layers.0.mlp.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
|
||||
{"moe expert down_proj int4", "model.layers.0.moe.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
|
||||
|
||||
// Unaligned: falls back to bf16 (empty string)
|
||||
{"v_proj int4 unaligned", "model.layers.0.self_attn.v_proj.weight", []int32{1024, 48}, "int4", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetTensorQuantization(tt.tensor, tt.shape, tt.quantize)
|
||||
if got != tt.want {
|
||||
t.Errorf("GetTensorQuantization(%q, %v, %q) = %q, want %q",
|
||||
tt.tensor, tt.shape, tt.quantize, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
@@ -90,3 +90,10 @@ func AsyncEval(outputs ...*Array) {
|
||||
func Eval(outputs ...*Array) {
|
||||
doEval(outputs, false)
|
||||
}
|
||||
|
||||
// MetalIsAvailable returns true if a Metal GPU is available.
|
||||
func MetalIsAvailable() bool {
|
||||
var available C._Bool
|
||||
C.mlx_metal_is_available(&available)
|
||||
return bool(available)
|
||||
}
|
||||
|
||||
@@ -131,6 +131,12 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
|
||||
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
|
||||
globalQuantType = strings.ToUpper(globalQuantType)
|
||||
|
||||
// Parse full metadata for per-tensor quant info
|
||||
var metaMap map[string]string
|
||||
if metaRaw, ok := header["__metadata__"]; ok {
|
||||
json.Unmarshal(metaRaw, &metaMap)
|
||||
}
|
||||
|
||||
mainNames := mainTensorNames(header)
|
||||
infos := make(map[string]*TensorQuantInfo)
|
||||
for _, name := range mainNames {
|
||||
@@ -141,6 +147,18 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
|
||||
quantType := globalQuantType
|
||||
groupSize := globalGroupSize
|
||||
|
||||
// Check per-tensor metadata (e.g. from packed expert blobs with mixed precision)
|
||||
if metaMap != nil {
|
||||
if qt, ok := metaMap[name+".quant_type"]; ok && qt != "" {
|
||||
quantType = strings.ToUpper(qt)
|
||||
}
|
||||
if gs, ok := metaMap[name+".group_size"]; ok && gs != "" {
|
||||
if v, err := strconv.Atoi(gs); err == nil {
|
||||
groupSize = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
|
||||
if quantType == "" {
|
||||
quantType = inferredType
|
||||
|
||||
Reference in New Issue
Block a user