package model import ( "strings" "github.com/ollama/ollama/x/mlxrunner/mlx" ) // QuantizationParams returns default groupSize, bits, and mode for a quantization type. func QuantizationParams(quantization string) (groupSize, bits int, mode string) { switch strings.ToUpper(quantization) { case "NVFP4": return 16, 4, "nvfp4" case "MXFP4": return 32, 4, "mxfp4" case "FP4", "Q4", "INT4": return 64, 4, "affine" case "MXFP8": return 32, 8, "mxfp8" case "FP8", "Q8", "INT8": return 64, 8, "affine" case "": return 0, 0, "" default: return 32, 8, "affine" } } // TensorQuantParams resolves quant params for a tensor using per-tensor metadata // when available, otherwise falling back to the provided model defaults. func TensorQuantParams( defaultGroupSize, defaultBits int, defaultMode string, tensorQuant map[string]*TensorQuantInfo, tensorName string, ) (groupSize, bits int, mode string, fromTensor bool) { if tensorQuant != nil { if tq := tensorQuant[tensorName]; tq != nil { groupSize, bits, mode = QuantizationParams(tq.QuantType) if tq.GroupSize > 0 { groupSize = tq.GroupSize } return groupSize, bits, mode, true } } return defaultGroupSize, defaultBits, defaultMode, false } // ResolveLinearQuantParams resolves quantization params for a quantized linear // tensor, preferring per-tensor metadata and falling back to shape-based // inference for affine packed tensors. func ResolveLinearQuantParams( defaultGroupSize, defaultBits int, defaultMode string, tensorQuant map[string]*TensorQuantInfo, tensorName string, weight, scales *mlx.Array, ) (groupSize, bits int, mode string) { groupSize, bits, mode, fromTensor := TensorQuantParams( defaultGroupSize, defaultBits, defaultMode, tensorQuant, tensorName, ) if mode == "affine" { if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok { if !fromTensor || groupSize == 0 || bits == 0 { groupSize = inferredGroupSize bits = inferredBits } } } return groupSize, bits, mode } // InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized // tensors from packed weight and scale shapes. func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) { if weight == nil || scales == nil { return 0, 0, false } weightShape := weight.Dims() scaleShape := scales.Dims() if len(weightShape) == 0 || len(scaleShape) == 0 { return 0, 0, false } weightCols := weightShape[len(weightShape)-1] scalesCols := scaleShape[len(scaleShape)-1] if weightCols <= 0 || scalesCols <= 0 { return 0, 0, false } groupSize4 := weightCols * 8 / scalesCols groupSize8 := weightCols * 4 / scalesCols switch { case groupSize4 == 32: return 32, 4, true case groupSize8 == 64: return 64, 8, true case groupSize4 == 64 && groupSize8 == 32: if hintBits == 8 { return 32, 8, true } if hintBits == 4 { return 64, 4, true } } if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) { return groupSize4, 4, true } if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) { return groupSize8, 8, true } return 0, 0, false } func isCommonGroupSize(v int) bool { switch v { case 16, 32, 64, 128: return true default: return false } }