mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
mlx: add mxfp4/mxfp8/nvfp4 importing (#15015)
This change allows importing bf16 and converting to mxfp4/mxfp8/nvfp4 and also importing fp8 and converting directly to mxfp8.
This commit is contained in:
@@ -109,7 +109,7 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
||||
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
|
||||
}
|
||||
|
||||
@@ -280,7 +280,7 @@ func newPackedTensorLayerCreator() create.PackedTensorLayerCreator {
|
||||
if !QuantizeSupported() {
|
||||
return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
blobData, err := quantizePackedGroup(tensors)
|
||||
blobData, err := quantizePackedGroup(groupName, tensors)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
|
||||
}
|
||||
|
||||
@@ -7,29 +7,27 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/create"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
)
|
||||
|
||||
// quantizeParams maps quantization type names to MLX quantize parameters.
|
||||
var quantizeParams = map[string]struct {
|
||||
groupSize int
|
||||
bits int
|
||||
mode string
|
||||
}{
|
||||
"int4": {64, 4, "affine"},
|
||||
"nvfp4": {16, 4, "nvfp4"},
|
||||
"int8": {64, 8, "affine"},
|
||||
"mxfp8": {32, 8, "mxfp8"},
|
||||
}
|
||||
|
||||
// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
|
||||
// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
|
||||
// to the provided maps. If quantize is empty, the tensor is kept as-is.
|
||||
// Returns any temp file paths created (caller must clean up) and arrays needing eval.
|
||||
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
|
||||
if quantize != "" {
|
||||
if gs, _, _ := model.QuantizationParams(quantize); gs == 0 {
|
||||
return "", nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
||||
}
|
||||
}
|
||||
|
||||
tmpDir := ensureTempDir()
|
||||
|
||||
tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
|
||||
@@ -50,11 +48,16 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
|
||||
}
|
||||
|
||||
// Find the tensor key (may differ from name for single-tensor blobs)
|
||||
inputKey, err := findSafetensorsKey(tmpPath)
|
||||
header, err := readSafetensorsHeader(tmpPath)
|
||||
if err != nil {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
|
||||
}
|
||||
inputKey, err := safetensorsKey(name, header)
|
||||
if err != nil {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("failed to resolve tensor key for %s: %w", name, err)
|
||||
}
|
||||
|
||||
arr := st.Get(inputKey)
|
||||
if arr == nil {
|
||||
@@ -62,34 +65,46 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
|
||||
return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
|
||||
}
|
||||
|
||||
// Decode FP8 source encoding before checking quantize, so that callers
|
||||
// requesting decode-only (quantize="") receive usable float data.
|
||||
if info, ok := header[inputKey]; ok && info.Dtype == "F8_E4M3" {
|
||||
scaleKey := inputKey + ".scale_inv"
|
||||
scaleInv := st.Get(scaleKey)
|
||||
if scaleInv == nil {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q for fp8 source tensor %q", scaleKey, inputKey)
|
||||
}
|
||||
arr, err = decodeSourceFP8Tensor(arr, scaleInv)
|
||||
if err != nil {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("failed to decode fp8 tensor %s: %w", inputKey, err)
|
||||
}
|
||||
mlx.Eval(arr)
|
||||
}
|
||||
|
||||
if quantize == "" {
|
||||
arr = mlx.Contiguous(arr)
|
||||
arr = mlx.Contiguous(arr, false)
|
||||
arrays[name] = arr
|
||||
return tmpPath, []*mlx.Array{arr}, st, nil
|
||||
}
|
||||
|
||||
// Convert to float type if needed (quantize expects float)
|
||||
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
|
||||
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
|
||||
if arr.DType() != mlx.DTypeBFloat16 && arr.DType() != mlx.DTypeFloat32 && arr.DType() != mlx.DTypeFloat16 {
|
||||
// Convert to float type if needed (quantize expects float)
|
||||
arr = arr.AsType(mlx.DTypeBFloat16)
|
||||
mlx.Eval(arr)
|
||||
}
|
||||
|
||||
params, ok := quantizeParams[quantize]
|
||||
if !ok {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
||||
}
|
||||
groupSize, bits, mode := model.QuantizationParams(quantize)
|
||||
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
|
||||
|
||||
qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode)
|
||||
|
||||
qweight = mlx.Contiguous(qweight)
|
||||
scales = mlx.Contiguous(scales)
|
||||
qweight = mlx.Contiguous(qweight, false)
|
||||
scales = mlx.Contiguous(scales, false)
|
||||
arrays[name] = qweight
|
||||
arrays[name+".scale"] = scales
|
||||
toEval = append(toEval, qweight, scales)
|
||||
|
||||
if qbiases != nil {
|
||||
qbiases = mlx.Contiguous(qbiases)
|
||||
qbiases = mlx.Contiguous(qbiases, false)
|
||||
arrays[name+".bias"] = qbiases
|
||||
toEval = append(toEval, qbiases)
|
||||
}
|
||||
@@ -101,27 +116,45 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
|
||||
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
|
||||
// Tensor keys use the original tensor name: name, name.scale, name.bias.
|
||||
// The blob includes __metadata__ with quant_type and group_size.
|
||||
// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8".
|
||||
// Supported quantization types: "int4", "nvfp4", "mxfp4", "int8", "mxfp8".
|
||||
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
|
||||
arrays := make(map[string]*mlx.Array)
|
||||
tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
|
||||
if tmpPath != "" {
|
||||
defer os.Remove(tmpPath)
|
||||
}
|
||||
if st != nil {
|
||||
defer st.Free()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
finalArrays := make([]*mlx.Array, 0, len(arrays))
|
||||
for _, arr := range arrays {
|
||||
if arr != nil {
|
||||
finalArrays = append(finalArrays, arr)
|
||||
}
|
||||
}
|
||||
mlx.Pin(finalArrays...)
|
||||
defer func() {
|
||||
if st != nil {
|
||||
st.Free()
|
||||
}
|
||||
mlx.Unpin(finalArrays...)
|
||||
mlx.Sweep()
|
||||
}()
|
||||
|
||||
mlx.Eval(toEval...)
|
||||
mlx.Sweep()
|
||||
// Free early to release mmap; defer guard handles error paths
|
||||
if st != nil {
|
||||
st.Free()
|
||||
st = nil
|
||||
}
|
||||
|
||||
// Build metadata for single-tensor blobs
|
||||
params := quantizeParams[quantize]
|
||||
groupSize, _, _ := model.QuantizationParams(quantize)
|
||||
metadata := map[string]string{
|
||||
"quant_type": quantize,
|
||||
"group_size": strconv.Itoa(params.groupSize),
|
||||
"group_size": strconv.Itoa(groupSize),
|
||||
}
|
||||
|
||||
tmpDir := ensureTempDir()
|
||||
@@ -135,48 +168,81 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
|
||||
|
||||
// quantizePackedGroup quantizes multiple tensors and saves them all into a single
|
||||
// combined safetensors blob. Used for packing expert groups.
|
||||
// When the inputs are per-expert 2D tensors (e.g., experts.0.gate_proj.weight),
|
||||
// they are stacked into 3D switch_mlp tensors before quantization.
|
||||
// Each tensor may have a different quantization type (mixed-precision).
|
||||
// Returns the blob bytes. No __metadata__ is added because different tensors
|
||||
// may use different quantization types.
|
||||
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
// Returns the blob bytes.
|
||||
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
// Check if inputs are per-expert tensors that should be stacked into 3D
|
||||
if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
|
||||
return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
|
||||
}
|
||||
|
||||
allArrays := make(map[string]*mlx.Array)
|
||||
var allToEval []*mlx.Array
|
||||
var tmpPaths []string
|
||||
var handles []*mlx.SafetensorsFile
|
||||
var pinned []*mlx.Array
|
||||
|
||||
var metadata map[string]string
|
||||
uniformQuantize := ""
|
||||
hasQuantized := false
|
||||
mixedQuantize := false
|
||||
for _, input := range inputs {
|
||||
if input.Quantize == "" {
|
||||
if hasQuantized {
|
||||
mixedQuantize = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !hasQuantized {
|
||||
hasQuantized = true
|
||||
uniformQuantize = input.Quantize
|
||||
continue
|
||||
}
|
||||
if input.Quantize != uniformQuantize {
|
||||
mixedQuantize = true
|
||||
}
|
||||
}
|
||||
if hasQuantized && !mixedQuantize {
|
||||
if groupSize, _, _ := model.QuantizationParams(uniformQuantize); groupSize > 0 {
|
||||
metadata = map[string]string{
|
||||
"quant_type": uniformQuantize,
|
||||
"group_size": strconv.Itoa(groupSize),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
|
||||
if tmpPath != "" {
|
||||
tmpPaths = append(tmpPaths, tmpPath)
|
||||
}
|
||||
if st != nil {
|
||||
handles = append(handles, st)
|
||||
}
|
||||
if err != nil {
|
||||
// Cleanup on error
|
||||
for _, h := range handles {
|
||||
h.Free()
|
||||
}
|
||||
for _, p := range tmpPaths {
|
||||
os.Remove(p)
|
||||
}
|
||||
mlx.Unpin(pinned...)
|
||||
mlx.Sweep()
|
||||
return nil, err
|
||||
}
|
||||
allToEval = append(allToEval, toEval...)
|
||||
|
||||
mlx.Eval(toEval...)
|
||||
|
||||
finalArrays := arraysForPackedInput(allArrays, input)
|
||||
mlx.Pin(finalArrays...)
|
||||
pinned = append(pinned, finalArrays...)
|
||||
|
||||
if st != nil {
|
||||
st.Free()
|
||||
}
|
||||
if tmpPath != "" {
|
||||
os.Remove(tmpPath)
|
||||
}
|
||||
mlx.Sweep()
|
||||
}
|
||||
defer func() {
|
||||
mlx.Unpin(pinned...)
|
||||
mlx.Sweep()
|
||||
}()
|
||||
|
||||
mlx.Eval(allToEval...)
|
||||
|
||||
// Free native handles after eval
|
||||
for _, h := range handles {
|
||||
h.Free()
|
||||
}
|
||||
|
||||
// Save combined blob (no global metadata for mixed-precision packed blobs)
|
||||
// Save combined blob. Add global metadata only when every packed tensor uses
|
||||
// the same quantization mode and group size.
|
||||
tmpDir := ensureTempDir()
|
||||
outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
|
||||
defer os.Remove(outPath)
|
||||
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil {
|
||||
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to save packed blob: %w", err)
|
||||
}
|
||||
|
||||
@@ -185,17 +251,193 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
return nil, fmt.Errorf("failed to read packed blob: %w", err)
|
||||
}
|
||||
|
||||
for _, p := range tmpPaths {
|
||||
os.Remove(p)
|
||||
return blobData, nil
|
||||
}
|
||||
|
||||
func arraysForPackedInput(allArrays map[string]*mlx.Array, input create.PackedTensorInput) []*mlx.Array {
|
||||
keys := []string{input.Name}
|
||||
if input.Quantize != "" {
|
||||
keys = append(keys, input.Name+".scale", input.Name+".bias")
|
||||
}
|
||||
|
||||
out := make([]*mlx.Array, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
if arr := allArrays[key]; arr != nil {
|
||||
out = append(out, arr)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// perExpertSuffix matches ".{index}.{proj_and_suffix}" after the group prefix.
|
||||
var perExpertSuffix = regexp.MustCompile(`^\.(\d+)\.(.+)$`)
|
||||
|
||||
type expertTensorInfo struct {
|
||||
index int
|
||||
proj string // e.g., "gate_proj.weight"
|
||||
input create.PackedTensorInput
|
||||
}
|
||||
|
||||
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
|
||||
// and returns the uniform quantization type shared by all inputs.
|
||||
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
|
||||
// or if the inputs have mixed quantization types.
|
||||
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
|
||||
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
|
||||
if !strings.HasSuffix(groupName, ".experts") {
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
quantize := inputs[0].Quantize
|
||||
groups := make(map[string][]expertTensorInfo)
|
||||
for _, input := range inputs {
|
||||
if input.Quantize != quantize {
|
||||
return nil, "" // mixed quantization types
|
||||
}
|
||||
suffix := strings.TrimPrefix(input.Name, groupName)
|
||||
m := perExpertSuffix.FindStringSubmatch(suffix)
|
||||
if m == nil {
|
||||
return nil, "" // not a per-expert pattern
|
||||
}
|
||||
index, err := strconv.Atoi(m[1])
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
}
|
||||
groups[m[2]] = append(groups[m[2]], expertTensorInfo{
|
||||
index: index,
|
||||
proj: m[2],
|
||||
input: input,
|
||||
})
|
||||
}
|
||||
if len(groups) == 0 {
|
||||
return nil, ""
|
||||
}
|
||||
return groups, quantize
|
||||
}
|
||||
|
||||
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
|
||||
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
|
||||
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
|
||||
groupBase := strings.TrimSuffix(groupName, ".experts")
|
||||
|
||||
allArrays := make(map[string]*mlx.Array)
|
||||
var pinned []*mlx.Array
|
||||
|
||||
var metadata map[string]string
|
||||
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
|
||||
metadata = map[string]string{
|
||||
"quant_type": quantize,
|
||||
"group_size": strconv.Itoa(groupSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Sort projection names for deterministic output
|
||||
projNames := make([]string, 0, len(projGroups))
|
||||
for proj := range projGroups {
|
||||
projNames = append(projNames, proj)
|
||||
}
|
||||
sort.Strings(projNames)
|
||||
|
||||
cleanup := func() {
|
||||
mlx.Unpin(pinned...)
|
||||
mlx.Sweep()
|
||||
}
|
||||
|
||||
for _, proj := range projNames {
|
||||
experts := projGroups[proj]
|
||||
|
||||
// Sort by expert index
|
||||
sort.Slice(experts, func(i, j int) bool {
|
||||
return experts[i].index < experts[j].index
|
||||
})
|
||||
|
||||
// Load and decode each expert tensor
|
||||
var decoded []*mlx.Array
|
||||
for _, expert := range experts {
|
||||
dummyArrays := make(map[string]*mlx.Array)
|
||||
tmpPath, toEval, st, err := loadAndQuantizeArray(expert.input.Reader, expert.input.Name, "", dummyArrays)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return nil, fmt.Errorf("failed to decode expert tensor %s: %w", expert.input.Name, err)
|
||||
}
|
||||
mlx.Eval(toEval...)
|
||||
|
||||
arr := dummyArrays[expert.input.Name]
|
||||
mlx.Pin(arr)
|
||||
pinned = append(pinned, arr)
|
||||
decoded = append(decoded, arr)
|
||||
|
||||
if st != nil {
|
||||
st.Free()
|
||||
}
|
||||
if tmpPath != "" {
|
||||
os.Remove(tmpPath)
|
||||
}
|
||||
mlx.Sweep()
|
||||
}
|
||||
|
||||
// Stack into 3D along axis 0: [numExperts, rows, cols]
|
||||
stacked := mlx.Stack(decoded, 0)
|
||||
mlx.Eval(stacked)
|
||||
mlx.Pin(stacked)
|
||||
pinned = append(pinned, stacked)
|
||||
|
||||
// Free individual decoded arrays
|
||||
mlx.Unpin(decoded...)
|
||||
mlx.Sweep()
|
||||
|
||||
stackedName := groupBase + ".switch_mlp." + proj
|
||||
|
||||
// Quantize the stacked tensor
|
||||
if quantize != "" {
|
||||
groupSize, bits, mode := model.QuantizationParams(quantize)
|
||||
|
||||
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
|
||||
|
||||
qweight = mlx.Contiguous(qweight, false)
|
||||
scales = mlx.Contiguous(scales, false)
|
||||
allArrays[stackedName] = qweight
|
||||
allArrays[stackedName+".scale"] = scales
|
||||
|
||||
toEval := []*mlx.Array{qweight, scales}
|
||||
if qbiases != nil {
|
||||
qbiases = mlx.Contiguous(qbiases, false)
|
||||
allArrays[stackedName+".bias"] = qbiases
|
||||
toEval = append(toEval, qbiases)
|
||||
}
|
||||
mlx.Eval(toEval...)
|
||||
mlx.Pin(toEval...)
|
||||
pinned = append(pinned, toEval...)
|
||||
|
||||
// Free stacked source array
|
||||
mlx.Unpin(stacked)
|
||||
mlx.Sweep()
|
||||
} else {
|
||||
stacked = mlx.Contiguous(stacked, false)
|
||||
mlx.Eval(stacked)
|
||||
allArrays[stackedName] = stacked
|
||||
}
|
||||
}
|
||||
|
||||
defer cleanup()
|
||||
|
||||
tmpDir := ensureTempDir()
|
||||
outPath := filepath.Join(tmpDir, "stacked-combined.safetensors")
|
||||
defer os.Remove(outPath)
|
||||
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to save stacked blob: %w", err)
|
||||
}
|
||||
|
||||
blobData, err := os.ReadFile(outPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read stacked blob: %w", err)
|
||||
}
|
||||
return blobData, nil
|
||||
}
|
||||
|
||||
// QuantizeSupported returns true if quantization is supported (MLX library available)
|
||||
func QuantizeSupported() bool {
|
||||
mlx.InitMLX()
|
||||
return mlx.IsMLXAvailable()
|
||||
return mlx.CheckInit() == nil
|
||||
}
|
||||
|
||||
// ensureTempDir creates the temp directory for quantization if it doesn't exist
|
||||
@@ -205,32 +447,97 @@ func ensureTempDir() string {
|
||||
return tmpDir
|
||||
}
|
||||
|
||||
// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file.
|
||||
func findSafetensorsKey(path string) (string, error) {
|
||||
type safetensorsHeaderEntry struct {
|
||||
Dtype string `json:"dtype"`
|
||||
Shape []int32 `json:"shape"`
|
||||
}
|
||||
|
||||
func readSafetensorsHeader(path string) (map[string]safetensorsHeaderEntry, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
headerBytes := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(f, headerBytes); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var header map[string]json.RawMessage
|
||||
var header map[string]safetensorsHeaderEntry
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
return header, nil
|
||||
}
|
||||
|
||||
for k := range header {
|
||||
if k != "__metadata__" {
|
||||
return k, nil
|
||||
// safetensorsKey resolves the primary tensor key from a header.
|
||||
func safetensorsKey(preferred string, header map[string]safetensorsHeaderEntry) (string, error) {
|
||||
if preferred != "" {
|
||||
if _, ok := header[preferred]; ok {
|
||||
return preferred, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no tensor found in safetensors header")
|
||||
|
||||
keys := make([]string, 0, len(header))
|
||||
for k := range header {
|
||||
if k == "__metadata__" || strings.HasSuffix(k, ".scale_inv") {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
if len(keys) == 0 {
|
||||
return "", fmt.Errorf("no tensor found in safetensors header")
|
||||
}
|
||||
return keys[0], nil
|
||||
}
|
||||
|
||||
func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
|
||||
if weight == nil || scaleInv == nil {
|
||||
return nil, fmt.Errorf("fp8 weight and scale tensors are required")
|
||||
}
|
||||
|
||||
weightShape := weight.Dims()
|
||||
scaleShape := scaleInv.Dims()
|
||||
if len(weightShape) != 2 || len(scaleShape) != 2 {
|
||||
return nil, fmt.Errorf("expected 2D fp8 weight and scale tensors, got %v and %v", weightShape, scaleShape)
|
||||
}
|
||||
|
||||
// These must match the block size validated by resolveEffectiveQuantization
|
||||
// in create.go, which rejects any source model with a different block size.
|
||||
const blockRows = 128
|
||||
const blockCols = 128
|
||||
rows, cols := weightShape[0], weightShape[1]
|
||||
expectedScaleRows := (rows + blockRows - 1) / blockRows
|
||||
expectedScaleCols := (cols + blockCols - 1) / blockCols
|
||||
if scaleShape[0] != expectedScaleRows || scaleShape[1] != expectedScaleCols {
|
||||
return nil, fmt.Errorf(
|
||||
"unexpected fp8 scale shape %v for weight shape %v; want [%d %d]",
|
||||
scaleShape,
|
||||
weightShape,
|
||||
expectedScaleRows,
|
||||
expectedScaleCols,
|
||||
)
|
||||
}
|
||||
|
||||
decoded := mlx.FromFP8(weight, mlx.DTypeBFloat16)
|
||||
padBottom := blockRows*scaleShape[0] - rows
|
||||
padSide := blockCols*scaleShape[1] - cols
|
||||
if padBottom > 0 || padSide > 0 {
|
||||
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
|
||||
}
|
||||
|
||||
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
|
||||
decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scaleInv, 1), 3))
|
||||
decoded = mlx.Reshape(decoded, int32(rows+padBottom), int32(cols+padSide))
|
||||
if padBottom > 0 || padSide > 0 {
|
||||
decoded = mlx.SliceStartStop(decoded, []int32{0, 0}, []int32{int32(rows), int32(cols)})
|
||||
}
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
@@ -267,13 +267,13 @@ func ShouldQuantize(name, component string) bool {
|
||||
|
||||
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
|
||||
// This is a more detailed check that also considers tensor dimensions.
|
||||
// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "int8", "mxfp8").
|
||||
// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "mxfp4", "int8", "mxfp8").
|
||||
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
|
||||
return GetTensorQuantization(name, shape, quantize) != ""
|
||||
}
|
||||
|
||||
// normalizeQuantType converts various quantization type aliases to canonical forms.
|
||||
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp8/MXFP8
|
||||
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp4/MXFP4, mxfp8/MXFP8
|
||||
func normalizeQuantType(quantize string) string {
|
||||
switch strings.ToUpper(quantize) {
|
||||
case "Q4", "INT4", "FP4":
|
||||
@@ -282,6 +282,8 @@ func normalizeQuantType(quantize string) string {
|
||||
return "int8"
|
||||
case "NVFP4":
|
||||
return "nvfp4"
|
||||
case "MXFP4":
|
||||
return "mxfp4"
|
||||
case "MXFP8":
|
||||
return "mxfp8"
|
||||
default:
|
||||
@@ -335,7 +337,7 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
quantNorm := normalizeQuantType(quantize)
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size
|
||||
// nvfp4: 16, mxfp8: 32, int4/int8: 64
|
||||
// nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
|
||||
groupSize := int32(32)
|
||||
switch quantNorm {
|
||||
case "nvfp4":
|
||||
@@ -353,8 +355,8 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
|
||||
if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
|
||||
// For non-affine modes, use the same quantization for all eligible tensors.
|
||||
if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
@@ -391,23 +393,39 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// expertGroupRegexp matches expert tensor names and captures the group prefix.
|
||||
// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
|
||||
// Captures: model.layers.{L}.mlp.experts
|
||||
var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
|
||||
var expertLayerPrefixRegexp = regexp.MustCompile(`^(?:model\.language_model\.|language_model(?:\.model)?\.|model\.)?layers\.\d+$`)
|
||||
|
||||
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
|
||||
// For example:
|
||||
// - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts"
|
||||
// - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts"
|
||||
// - "language_model.model.layers.1.mlp.switch_mlp.down_proj.weight" -> "language_model.model.layers.1.mlp.switch_mlp"
|
||||
// - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts)
|
||||
// - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert)
|
||||
func ExpertGroupPrefix(tensorName string) string {
|
||||
m := expertGroupRegexp.FindStringSubmatch(tensorName)
|
||||
if m == nil {
|
||||
if !strings.HasSuffix(tensorName, ".weight") {
|
||||
return ""
|
||||
}
|
||||
return m[1]
|
||||
|
||||
for _, marker := range []string{
|
||||
".mlp.experts.",
|
||||
".mlp.shared_experts.",
|
||||
".mlp.switch_mlp.",
|
||||
} {
|
||||
idx := strings.Index(tensorName, marker)
|
||||
if idx == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
layerPrefix := tensorName[:idx]
|
||||
if !expertLayerPrefixRegexp.MatchString(layerPrefix) {
|
||||
continue
|
||||
}
|
||||
|
||||
return layerPrefix + strings.TrimSuffix(marker, ".")
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob.
|
||||
@@ -424,9 +442,11 @@ type PackedTensorInput struct {
|
||||
type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput) (LayerInfo, error)
|
||||
|
||||
type sourceQuantization struct {
|
||||
Bits int `json:"bits"`
|
||||
GroupSize int `json:"group_size"`
|
||||
Mode string `json:"mode"`
|
||||
Bits int `json:"bits"`
|
||||
GroupSize int `json:"group_size"`
|
||||
Mode string `json:"mode"`
|
||||
QuantMethod string `json:"quant_method"`
|
||||
WeightBlockSize []int32 `json:"weight_block_size"`
|
||||
}
|
||||
|
||||
type sourceModelConfig struct {
|
||||
@@ -493,6 +513,98 @@ func (cfg sourceModelConfig) QuantMetadata() map[string]string {
|
||||
return metadata
|
||||
}
|
||||
|
||||
type sourceQuantizedKind string
|
||||
|
||||
const (
|
||||
sourceQuantizedKindNone sourceQuantizedKind = ""
|
||||
sourceQuantizedKindPrequantized sourceQuantizedKind = "prequantized"
|
||||
sourceQuantizedKindHFFP8 sourceQuantizedKind = "hf_fp8"
|
||||
)
|
||||
|
||||
func (cfg sourceModelConfig) quantizationConfigs() []sourceQuantization {
|
||||
return []sourceQuantization{
|
||||
cfg.Quantization,
|
||||
cfg.QuantizationConfig,
|
||||
cfg.TextConfig.Quantization,
|
||||
cfg.TextConfig.QuantizationConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg sourceModelConfig) HFFP8WeightBlockSize() (rows, cols int32, ok bool) {
|
||||
for _, q := range cfg.quantizationConfigs() {
|
||||
if !strings.EqualFold(q.QuantMethod, "fp8") || len(q.WeightBlockSize) != 2 {
|
||||
continue
|
||||
}
|
||||
return q.WeightBlockSize[0], q.WeightBlockSize[1], true
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func inspectSourceQuantization(modelDir string, cfg sourceModelConfig) (sourceQuantizedKind, error) {
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
return sourceQuantizedKindNone, err
|
||||
}
|
||||
|
||||
hasScaleInv := false
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
continue
|
||||
}
|
||||
|
||||
extractor, err := safetensors.OpenForExtraction(filepath.Join(modelDir, entry.Name()))
|
||||
if err != nil {
|
||||
return sourceQuantizedKindNone, err
|
||||
}
|
||||
|
||||
for _, name := range extractor.ListTensors() {
|
||||
switch {
|
||||
case strings.HasSuffix(name, ".scales"):
|
||||
extractor.Close()
|
||||
return sourceQuantizedKindPrequantized, nil
|
||||
case strings.HasSuffix(name, ".weight_scale_inv"):
|
||||
hasScaleInv = true
|
||||
}
|
||||
}
|
||||
|
||||
extractor.Close()
|
||||
}
|
||||
|
||||
if hasScaleInv {
|
||||
if _, _, ok := cfg.HFFP8WeightBlockSize(); ok {
|
||||
return sourceQuantizedKindHFFP8, nil
|
||||
}
|
||||
}
|
||||
|
||||
return sourceQuantizedKindNone, nil
|
||||
}
|
||||
|
||||
func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested string) (string, error) {
|
||||
switch sourceKind {
|
||||
case sourceQuantizedKindNone:
|
||||
return requested, nil
|
||||
case sourceQuantizedKindPrequantized:
|
||||
if requested != "" {
|
||||
return "", fmt.Errorf("cannot requantize already-quantized source model with --quantize %q", requested)
|
||||
}
|
||||
return "", nil
|
||||
case sourceQuantizedKindHFFP8:
|
||||
if requested != "" {
|
||||
return "", fmt.Errorf("cannot requantize already-quantized fp8 source model with --quantize %q", requested)
|
||||
}
|
||||
rows, cols, ok := cfg.HFFP8WeightBlockSize()
|
||||
if !ok {
|
||||
return "", fmt.Errorf("fp8 source model missing weight_block_size metadata")
|
||||
}
|
||||
if rows != 128 || cols != 128 {
|
||||
return "", fmt.Errorf("unsupported fp8 source block size %dx%d", rows, cols)
|
||||
}
|
||||
return "mxfp8", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported source quantization kind %q", sourceKind)
|
||||
}
|
||||
}
|
||||
|
||||
type tensorImportTransform interface {
|
||||
skipTensor(name string) bool
|
||||
transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error)
|
||||
@@ -546,6 +658,14 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read source config.json: %w", err)
|
||||
}
|
||||
sourceQuantKind, err := inspectSourceQuantization(modelDir, sourceConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to inspect source quantization: %w", err)
|
||||
}
|
||||
effectiveQuantize, err := resolveEffectiveQuantization(sourceConfig, sourceQuantKind, quantize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sourceQuantMetadata := sourceConfig.QuantMetadata()
|
||||
importTransform, err := newTensorImportTransform(modelDir, sourceConfig)
|
||||
if err != nil {
|
||||
@@ -557,7 +677,6 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
if len(createPackedLayer) > 0 {
|
||||
packedCreator = createPackedLayer[0]
|
||||
}
|
||||
|
||||
// Accumulate expert tensors by group prefix for packing.
|
||||
// Readers reference file-backed SectionReaders, so we keep extractors
|
||||
// open until each group is flushed to avoid buffering tensor data in memory.
|
||||
@@ -600,8 +719,8 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
tensorSet[name] = struct{}{}
|
||||
}
|
||||
quantizeMsg := ""
|
||||
if quantize != "" {
|
||||
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
|
||||
if effectiveQuantize != "" {
|
||||
quantizeMsg = fmt.Sprintf(", quantizing to %s", effectiveQuantize)
|
||||
}
|
||||
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
|
||||
|
||||
@@ -612,9 +731,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
if importTransform.skipTensor(tensorName) {
|
||||
continue
|
||||
}
|
||||
if shouldSkipPrequantizedCompanion(tensorName, tensorSet) {
|
||||
if shouldSkipSourceCompanion(tensorName, tensorSet) {
|
||||
continue
|
||||
}
|
||||
sourceFP8ScaleName, hasSourceFP8Scale := sourceFP8Companion(tensorName, tensorSet)
|
||||
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
if err != nil {
|
||||
@@ -623,7 +743,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
if quantize == "" {
|
||||
if effectiveQuantize == "" {
|
||||
layer, ok, err := createPrequantizedLayer(extractor, td, tensorName, tensorSet, sourceQuantMetadata, createLayer)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
@@ -647,8 +767,33 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
|
||||
quantizeType := ""
|
||||
if quantize != "" {
|
||||
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, quantize)
|
||||
switch {
|
||||
case sourceQuantKind == sourceQuantizedKindHFFP8 && hasSourceFP8Scale:
|
||||
quantizeType = "mxfp8"
|
||||
case sourceQuantKind == sourceQuantizedKindHFFP8:
|
||||
quantizeType = ""
|
||||
case effectiveQuantize != "":
|
||||
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
|
||||
}
|
||||
reader := outTD.SafetensorsReader()
|
||||
if hasSourceFP8Scale {
|
||||
if len(outputTensors) != 1 {
|
||||
extractor.Close()
|
||||
closeExtractors()
|
||||
return fmt.Errorf("source fp8 tensor %s rewrote into %d tensors; only 1:1 rewrites are supported", tensorName, len(outputTensors))
|
||||
}
|
||||
if quantizeType == "" {
|
||||
extractor.Close()
|
||||
closeExtractors()
|
||||
return fmt.Errorf("source fp8 tensor %s was not scheduled for mxfp8 conversion", tensorName)
|
||||
}
|
||||
scaleTD, err := extractor.GetTensor(sourceFP8ScaleName)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
closeExtractors()
|
||||
return fmt.Errorf("failed to get fp8 scale tensor %s: %w", sourceFP8ScaleName, err)
|
||||
}
|
||||
reader = buildSourceFP8Reader(outTD, scaleTD.WithName(outTD.Name+".scale_inv"))
|
||||
}
|
||||
|
||||
// Check if this tensor belongs to an expert group for packing
|
||||
@@ -670,13 +815,13 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
Dtype: outTD.Dtype,
|
||||
Shape: outTD.Shape,
|
||||
Quantize: quantizeType,
|
||||
Reader: outTD.SafetensorsReader(),
|
||||
Reader: reader,
|
||||
})
|
||||
} else {
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
// This enables native mmap loading via mlx_load_safetensors
|
||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||
newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
|
||||
newLayers, err := createTensorLayer(reader, outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
closeExtractors()
|
||||
@@ -760,7 +905,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{}) bool {
|
||||
func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}) bool {
|
||||
switch {
|
||||
case strings.HasSuffix(name, ".scales"):
|
||||
_, ok := tensorSet[strings.TrimSuffix(name, ".scales")+".weight"]
|
||||
@@ -768,11 +913,28 @@ func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{})
|
||||
case strings.HasSuffix(name, ".biases"):
|
||||
_, ok := tensorSet[strings.TrimSuffix(name, ".biases")+".weight"]
|
||||
return ok
|
||||
case strings.HasSuffix(name, ".weight_scale_inv"):
|
||||
_, ok := tensorSet[strings.TrimSuffix(name, "_scale_inv")]
|
||||
return ok
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// sourceFP8Companion derives the name of the scale tensor that accompanies an
// fp8 source weight and reports whether that tensor is present in tensorSet.
//
// Only names ending in ".weight" are considered; any other name yields
// ("", false). For a ".weight" name the derived "<weightName>_scale_inv" is
// returned even when it is absent from tensorSet, in which case ok is false.
func sourceFP8Companion(weightName string, tensorSet map[string]struct{}) (scaleName string, ok bool) {
	if strings.HasSuffix(weightName, ".weight") {
		candidate := weightName + "_scale_inv"
		_, present := tensorSet[candidate]
		return candidate, present
	}
	return "", false
}
|
||||
|
||||
func buildSourceFP8Reader(weightTD, scaleTD *safetensors.TensorData) io.Reader {
|
||||
return safetensors.BuildPackedSafetensorsReader([]*safetensors.TensorData{weightTD, scaleTD})
|
||||
}
|
||||
|
||||
func createPrequantizedLayer(
|
||||
extractor *safetensors.TensorExtractor,
|
||||
td *safetensors.TensorData,
|
||||
|
||||
@@ -246,6 +246,30 @@ func readSingleTensorRaw(t *testing.T, data []byte) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// readSafetensorsHeaderNames returns the sorted tensor names declared in the
// JSON header of a safetensors blob, leaving out the "__metadata__" entry.
//
// The layout read here is an 8-byte little-endian header length followed by
// that many bytes of JSON. A blob that is too short for the length prefix, or
// whose declared header length runs past the end of data, fails the test with
// a message instead of panicking on an out-of-range slice.
func readSafetensorsHeaderNames(t *testing.T, data []byte) []string {
	t.Helper()

	const sizePrefixLen = 8
	if len(data) < sizePrefixLen {
		t.Fatalf("failed to read header size: blob is %d bytes, need at least %d", len(data), sizePrefixLen)
	}
	headerSize := binary.LittleEndian.Uint64(data[:sizePrefixLen])

	// Compare against the remaining length rather than adding to the offset so
	// a corrupt, very large header size cannot overflow the slice bound.
	remaining := uint64(len(data) - sizePrefixLen)
	if headerSize > remaining {
		t.Fatalf("failed to parse header: header size %d exceeds remaining %d bytes", headerSize, remaining)
	}

	var header map[string]json.RawMessage
	headerEnd := sizePrefixLen + int(headerSize)
	if err := json.Unmarshal(data[sizePrefixLen:headerEnd], &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}

	names := make([]string, 0, len(header))
	for name := range header {
		if name == "__metadata__" {
			continue
		}
		names = append(names, name)
	}
	// Map iteration order is random; sort so callers can compare directly.
	slices.Sort(names)
	return names
}
|
||||
|
||||
func TestCreateSafetensorsModel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
@@ -546,6 +570,215 @@ func TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
configJSON := `{
|
||||
"model_type": "test",
|
||||
"architectures": ["TestModel"],
|
||||
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
|
||||
st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
st.NewTensorDataFromBytes("dense.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||
st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{2}, make([]byte, 4)),
|
||||
})
|
||||
|
||||
quantizeByName := make(map[string]string)
|
||||
headerNamesByName := make(map[string][]string)
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
_, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quantizeByName[name] = quantize
|
||||
headerNamesByName[name] = readSafetensorsHeaderNames(t, data)
|
||||
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
|
||||
|
||||
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
if got := quantizeByName["linear.weight"]; got != "mxfp8" {
|
||||
t.Fatalf("linear.weight quantization = %q, want %q", got, "mxfp8")
|
||||
}
|
||||
|
||||
if got := quantizeByName["norm.weight"]; got != "" {
|
||||
t.Fatalf("norm.weight quantization = %q, want empty", got)
|
||||
}
|
||||
if got := quantizeByName["dense.weight"]; got != "" {
|
||||
t.Fatalf("dense.weight quantization = %q, want empty", got)
|
||||
}
|
||||
|
||||
if _, ok := quantizeByName["linear.weight_scale_inv"]; ok {
|
||||
t.Fatal("linear.weight_scale_inv should not be imported as a standalone tensor")
|
||||
}
|
||||
|
||||
if got := headerNamesByName["linear.weight"]; !slices.Equal(got, []string{"linear.weight", "linear.weight.scale_inv"}) {
|
||||
t.Fatalf("linear.weight blob tensors = %v, want %v", got, []string{"linear.weight", "linear.weight.scale_inv"})
|
||||
}
|
||||
|
||||
if got := headerNamesByName["norm.weight"]; !slices.Equal(got, []string{"norm.weight"}) {
|
||||
t.Fatalf("norm.weight blob tensors = %v, want %v", got, []string{"norm.weight"})
|
||||
}
|
||||
if got := headerNamesByName["dense.weight"]; !slices.Equal(got, []string{"dense.weight"}) {
|
||||
t.Fatalf("dense.weight blob tensors = %v, want %v", got, []string{"dense.weight"})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configJSON string
|
||||
tensors []*st.TensorData
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "prequantized affine",
|
||||
configJSON: `{"model_type": "test", "architectures": ["TestModel"]}`,
|
||||
tensors: []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("linear.weight", "U32", []int32{4, 4}, make([]byte, 16)),
|
||||
st.NewTensorDataFromBytes("linear.scales", "BF16", []int32{4, 1}, make([]byte, 8)),
|
||||
},
|
||||
wantErr: `cannot requantize already-quantized source model with --quantize "int4"`,
|
||||
},
|
||||
{
|
||||
name: "hf fp8 source",
|
||||
configJSON: `{
|
||||
"model_type": "test",
|
||||
"architectures": ["TestModel"],
|
||||
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
|
||||
}`,
|
||||
tensors: []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
|
||||
st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
},
|
||||
wantErr: `cannot requantize already-quantized fp8 source model with --quantize "int4"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), tt.tensors)
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
return LayerInfo{}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "int4", createLayer, createTensorLayer, writeManifest, func(string) {})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("error = %q, want substring %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_HFFP8PacksExperts(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
configJSON := `{
|
||||
"model_type": "test",
|
||||
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
|
||||
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create 2 experts so stacking produces a [2, 128, 128] tensor
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||
})
|
||||
|
||||
var packedLayerNames []string
|
||||
var packedLayerTensors [][]PackedTensorInput
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
if _, err := io.ReadAll(r); err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
if _, err := io.ReadAll(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
||||
}
|
||||
|
||||
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
|
||||
packedLayerNames = append(packedLayerNames, groupName)
|
||||
packedLayerTensors = append(packedLayerTensors, tensors)
|
||||
return LayerInfo{Name: groupName, Digest: "sha256:packed_" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
|
||||
|
||||
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
if len(packedLayerNames) != 1 {
|
||||
t.Fatalf("expected 1 packed layer, got %d: %v", len(packedLayerNames), packedLayerNames)
|
||||
}
|
||||
if packedLayerNames[0] != "language_model.model.layers.0.mlp.experts" {
|
||||
t.Fatalf("unexpected packed layer name: %s", packedLayerNames[0])
|
||||
}
|
||||
|
||||
// Verify all 6 expert tensors (2 experts × 3 proj types) were accumulated
|
||||
tensors := packedLayerTensors[0]
|
||||
if len(tensors) != 6 {
|
||||
t.Fatalf("expected 6 tensors in packed group, got %d", len(tensors))
|
||||
}
|
||||
|
||||
// All should be marked for mxfp8 quantization
|
||||
for _, tensor := range tensors {
|
||||
if tensor.Quantize != "mxfp8" {
|
||||
t.Fatalf("expected mxfp8 quantize for %s, got %q", tensor.Name, tensor.Quantize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
@@ -693,6 +926,113 @@ func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_Qwen35DirectNonAffineKeepsSensitiveWeightsBF16(t *testing.T) {
|
||||
for _, quantize := range []string{"nvfp4", "mxfp8", "mxfp4"} {
|
||||
t.Run(quantize, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
configJSON := `{
|
||||
"model_type": "test",
|
||||
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
|
||||
"text_config": {"dtype": "bfloat16"}
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
gateUpValues := make([]float32, 2*128*64)
|
||||
for expert := range 2 {
|
||||
base := expert * 128 * 64
|
||||
for i := range 64 * 64 {
|
||||
gateUpValues[base+i] = 1
|
||||
gateUpValues[base+64*64+i] = 2
|
||||
}
|
||||
}
|
||||
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||
st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_a.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_b.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.shared_expert_gate.weight", "BF16", []int32{1, 64}, make([]byte, 64*2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
|
||||
})
|
||||
|
||||
type tensorCall struct {
|
||||
quantize string
|
||||
}
|
||||
type packedTensorCall struct {
|
||||
Name string
|
||||
Quantize string
|
||||
}
|
||||
|
||||
tensorCalls := make(map[string]tensorCall)
|
||||
packedCalls := make(map[string][]packedTensorCall)
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
_, _ = io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantizeType string) ([]LayerInfo, error) {
|
||||
_, _ = io.ReadAll(r)
|
||||
tensorCalls[name] = tensorCall{quantize: quantizeType}
|
||||
return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
||||
}
|
||||
|
||||
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
|
||||
group := make([]packedTensorCall, 0, len(tensors))
|
||||
for _, tensor := range tensors {
|
||||
group = append(group, packedTensorCall{
|
||||
Name: tensor.Name,
|
||||
Quantize: tensor.Quantize,
|
||||
})
|
||||
}
|
||||
packedCalls[groupName] = group
|
||||
return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := CreateSafetensorsModel("test-model", dir, quantize, createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
for _, name := range []string{
|
||||
"language_model.model.embed_tokens.weight",
|
||||
"language_model.lm_head.weight",
|
||||
"language_model.model.layers.0.linear_attn.in_proj_a.weight",
|
||||
"language_model.model.layers.0.linear_attn.in_proj_b.weight",
|
||||
"language_model.model.layers.0.mlp.gate.weight",
|
||||
"language_model.model.layers.0.mlp.shared_expert_gate.weight",
|
||||
} {
|
||||
if got := tensorCalls[name].quantize; got != "" {
|
||||
t.Fatalf("%s quantize = %q, want empty", name, got)
|
||||
}
|
||||
}
|
||||
|
||||
if got := tensorCalls["language_model.model.layers.0.self_attn.q_proj.weight"].quantize; got != quantize {
|
||||
t.Fatalf("q_proj quantize = %q, want %q", got, quantize)
|
||||
}
|
||||
|
||||
group := packedCalls["language_model.model.layers.0.mlp.switch_mlp"]
|
||||
if len(group) != 3 {
|
||||
t.Fatalf("packed switch_mlp tensor count = %d, want 3", len(group))
|
||||
}
|
||||
for _, tensor := range group {
|
||||
if tensor.Quantize != quantize {
|
||||
t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, quantize)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveManifestPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -865,6 +1205,7 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
||||
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
|
||||
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
|
||||
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
|
||||
{"large 2D weight mxfp4", "q_proj.weight", []int32{4096, 4096}, "mxfp4", true},
|
||||
|
||||
// Small tensors should not be quantized (< 1024 elements)
|
||||
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
|
||||
@@ -891,9 +1232,11 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
||||
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
|
||||
|
||||
// Group size divisibility tests
|
||||
// FP8/FP4 require divisible by 32
|
||||
// FP8/FP4/MXFP4 require divisible by 32
|
||||
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
|
||||
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
|
||||
{"not divisible by 32 mxfp4", "proj.weight", []int32{128, 48}, "mxfp4", false},
|
||||
{"divisible by 32 mxfp4", "proj.weight", []int32{128, 64}, "mxfp4", true},
|
||||
// NVFP4 requires divisible by 16
|
||||
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
|
||||
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
|
||||
@@ -919,10 +1262,20 @@ func TestExpertGroupPrefix(t *testing.T) {
|
||||
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
||||
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
||||
|
||||
// Expert tensors with language_model prefix should also match
|
||||
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
|
||||
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
|
||||
|
||||
// Shared expert tensors should return their own group prefix
|
||||
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
|
||||
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
|
||||
|
||||
// Rewritten Qwen switch_mlp tensors should also be packed per-layer.
|
||||
{"model.layers.1.mlp.switch_mlp.down_proj.weight", "model.layers.1.mlp.switch_mlp"},
|
||||
{"language_model.layers.2.mlp.switch_mlp.gate_proj.weight", "language_model.layers.2.mlp.switch_mlp"},
|
||||
{"language_model.model.layers.3.mlp.switch_mlp.up_proj.weight", "language_model.model.layers.3.mlp.switch_mlp"},
|
||||
{"model.language_model.layers.4.mlp.switch_mlp.gate_proj.weight", "model.language_model.layers.4.mlp.switch_mlp"},
|
||||
|
||||
// Non-expert tensors should return empty string
|
||||
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
|
||||
{"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert
|
||||
@@ -978,6 +1331,161 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
||||
if combinedDown != "int8" {
|
||||
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
|
||||
}
|
||||
|
||||
nvfp4GateUp := GetTensorQuantization(
|
||||
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
|
||||
[]int32{64, 11008, 4096},
|
||||
"nvfp4",
|
||||
)
|
||||
if nvfp4GateUp != "nvfp4" {
|
||||
t.Fatalf("nvfp4 gate_proj quantization = %q, want %q", nvfp4GateUp, "nvfp4")
|
||||
}
|
||||
|
||||
nvfp4Down := GetTensorQuantization(
|
||||
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
|
||||
[]int32{64, 4096, 11008},
|
||||
"nvfp4",
|
||||
)
|
||||
if nvfp4Down != "nvfp4" {
|
||||
t.Fatalf("nvfp4 down_proj quantization = %q, want %q", nvfp4Down, "nvfp4")
|
||||
}
|
||||
|
||||
mxfp4GateUp := GetTensorQuantization(
|
||||
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
|
||||
[]int32{64, 11008, 4096},
|
||||
"mxfp4",
|
||||
)
|
||||
if mxfp4GateUp != "mxfp4" {
|
||||
t.Fatalf("mxfp4 gate_proj quantization = %q, want %q", mxfp4GateUp, "mxfp4")
|
||||
}
|
||||
|
||||
mxfp4Down := GetTensorQuantization(
|
||||
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
|
||||
[]int32{64, 4096, 11008},
|
||||
"mxfp4",
|
||||
)
|
||||
if mxfp4Down != "mxfp4" {
|
||||
t.Fatalf("mxfp4 down_proj quantization = %q, want %q", mxfp4Down, "mxfp4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
configJSON := `{
|
||||
"model_type": "test",
|
||||
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
|
||||
"text_config": {"dtype": "bfloat16"}
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
gateUpValues := make([]float32, 2*128*64)
|
||||
for expert := range 2 {
|
||||
base := expert * 128 * 64
|
||||
for i := range 64 * 64 {
|
||||
gateUpValues[base+i] = 1
|
||||
gateUpValues[base+64*64+i] = 2
|
||||
}
|
||||
}
|
||||
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
|
||||
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
|
||||
})
|
||||
|
||||
type tensorCall struct {
|
||||
quantize string
|
||||
}
|
||||
type packedTensorCall struct {
|
||||
Name string
|
||||
Dtype string
|
||||
Shape []int32
|
||||
Quantize string
|
||||
}
|
||||
|
||||
tensorCalls := make(map[string]tensorCall)
|
||||
packedCalls := make(map[string][]packedTensorCall)
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
_, _ = io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
_, _ = io.ReadAll(r)
|
||||
tensorCalls[name] = tensorCall{quantize: quantize}
|
||||
return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
||||
}
|
||||
|
||||
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
|
||||
group := make([]packedTensorCall, 0, len(tensors))
|
||||
for _, tensor := range tensors {
|
||||
group = append(group, packedTensorCall{
|
||||
Name: tensor.Name,
|
||||
Dtype: tensor.Dtype,
|
||||
Shape: append([]int32(nil), tensor.Shape...),
|
||||
Quantize: tensor.Quantize,
|
||||
})
|
||||
}
|
||||
packedCalls[groupName] = group
|
||||
return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := CreateSafetensorsModel("test-model", dir, "nvfp4", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
groupName := "language_model.model.layers.0.mlp.switch_mlp"
|
||||
group, ok := packedCalls[groupName]
|
||||
if !ok {
|
||||
t.Fatalf("missing packed group %q: %v", groupName, packedCalls)
|
||||
}
|
||||
|
||||
if len(group) != 3 {
|
||||
t.Fatalf("packed group %q has %d tensors, want 3", groupName, len(group))
|
||||
}
|
||||
|
||||
gotNames := make([]string, 0, len(group))
|
||||
for _, tensor := range group {
|
||||
gotNames = append(gotNames, tensor.Name)
|
||||
if tensor.Quantize != "nvfp4" {
|
||||
t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, "nvfp4")
|
||||
}
|
||||
if tensor.Dtype != "BF16" {
|
||||
t.Fatalf("packed tensor %q dtype = %q, want %q", tensor.Name, tensor.Dtype, "BF16")
|
||||
}
|
||||
}
|
||||
slices.Sort(gotNames)
|
||||
|
||||
wantNames := []string{
|
||||
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
|
||||
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
|
||||
"language_model.model.layers.0.mlp.switch_mlp.up_proj.weight",
|
||||
}
|
||||
if !slices.Equal(gotNames, wantNames) {
|
||||
t.Fatalf("packed tensor names = %v, want %v", gotNames, wantNames)
|
||||
}
|
||||
|
||||
for _, name := range wantNames {
|
||||
if _, ok := tensorCalls[name]; ok {
|
||||
t.Fatalf("packed expert tensor %q unexpectedly handled by createTensorLayer", name)
|
||||
}
|
||||
}
|
||||
|
||||
if got := tensorCalls["language_model.model.embed_tokens.weight"].quantize; got != "" {
|
||||
t.Fatalf("embed_tokens quantize = %q, want empty", got)
|
||||
}
|
||||
if got := tensorCalls["language_model.model.layers.0.mlp.gate.weight"].quantize; got != "" {
|
||||
t.Fatalf("mlp.gate quantize = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||
|
||||
@@ -87,6 +87,27 @@ func (t qwen35ImportTransform) skipTensor(name string) bool {
|
||||
return strings.Contains(name, "mtp.")
|
||||
}
|
||||
|
||||
// qwen35ShouldKeepBF16ForDirectNonAffine reports whether name matches one of
// the tensor-name patterns this helper singles out: the token embedding, the
// LM head, the linear_attn in_proj_a / in_proj_b / in_proj_ba weights, the
// mlp gate weight, or the mlp shared_expert_gate weight.
func qwen35ShouldKeepBF16ForDirectNonAffine(name string) bool {
	// The plain mlp gate is matched by suffix, but only when the name carries
	// no "_proj" segment anywhere.
	if strings.HasSuffix(name, ".mlp.gate.weight") && !strings.Contains(name, "_proj") {
		return true
	}

	// Every remaining pattern is a pure suffix match.
	keepSuffixes := [...]string{
		"embed_tokens.weight",
		"lm_head.weight",
		".linear_attn.in_proj_a.weight",
		".linear_attn.in_proj_b.weight",
		".linear_attn.in_proj_ba.weight",
		".mlp.shared_expert_gate.weight",
	}
	for _, suffix := range keepSuffixes {
		if strings.HasSuffix(name, suffix) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
|
||||
if strings.HasPrefix(name, "vision_tower.") {
|
||||
return ""
|
||||
@@ -127,6 +148,13 @@ func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quan
|
||||
return ""
|
||||
}
|
||||
|
||||
// Match the working HF-FP8 import policy for direct NVFP4/MXFP4/MXFP8 imports:
|
||||
// keep embeddings, LM head, low-rank linear_attn projections, and routing
|
||||
// gates in BF16 rather than forcing them into a non-affine quantized format.
|
||||
if (quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8") && qwen35ShouldKeepBF16ForDirectNonAffine(name) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
|
||||
@@ -4,35 +4,91 @@ package mlx
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// SafetensorsFile represents a loaded safetensors file.
|
||||
type SafetensorsFile struct {
|
||||
arrays C.mlx_map_string_to_array
|
||||
metadata C.mlx_map_string_to_string
|
||||
}
|
||||
|
||||
func loadSafetensorsStream() C.mlx_stream {
|
||||
if runtime.GOOS == "darwin" {
|
||||
return C.mlx_default_cpu_stream_new()
|
||||
}
|
||||
return C.mlx_default_gpu_stream_new()
|
||||
}
|
||||
|
||||
// LoadSafetensorsNative loads a safetensors file using MLX's native loader.
|
||||
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
|
||||
var arrays C.mlx_map_string_to_array
|
||||
var metadata C.mlx_map_string_to_string
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
stream := loadSafetensorsStream()
|
||||
defer C.mlx_stream_free(stream)
|
||||
|
||||
if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
|
||||
return nil, fmt.Errorf("failed to load safetensors: %s", path)
|
||||
}
|
||||
|
||||
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
|
||||
}
|
||||
|
||||
// Get retrieves a tensor by name.
|
||||
func (s *SafetensorsFile) Get(name string) *Array {
|
||||
cName := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
|
||||
value := C.mlx_array_new()
|
||||
if C.mlx_map_string_to_array_get(&value, s.arrays, cName) != 0 {
|
||||
return nil
|
||||
}
|
||||
if value.ctx == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
arr := New(name)
|
||||
arr.ctx = value
|
||||
return arr
|
||||
}
|
||||
|
||||
// GetMetadata retrieves a metadata value by key.
|
||||
func (s *SafetensorsFile) GetMetadata(key string) string {
|
||||
cKey := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(cKey))
|
||||
|
||||
var cValue *C.char
|
||||
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
|
||||
return ""
|
||||
}
|
||||
return C.GoString(cValue)
|
||||
}
|
||||
|
||||
// Free releases the loaded safetensors maps.
|
||||
func (s *SafetensorsFile) Free() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
C.mlx_map_string_to_array_free(s.arrays)
|
||||
C.mlx_map_string_to_string_free(s.metadata)
|
||||
}
|
||||
|
||||
func Load(path string) iter.Seq2[string, *Array] {
|
||||
return func(yield func(string, *Array) bool) {
|
||||
string2array := C.mlx_map_string_to_array_new()
|
||||
defer C.mlx_map_string_to_array_free(string2array)
|
||||
|
||||
string2string := C.mlx_map_string_to_string_new()
|
||||
defer C.mlx_map_string_to_string_free(string2string)
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
// Use GPU stream so tensors load directly to GPU memory (CUDA has Load::eval_gpu).
|
||||
// macOS Metal doesn't implement eval_gpu for Load, so fall back to CPU stream.
|
||||
var stream C.mlx_stream
|
||||
if runtime.GOOS == "darwin" {
|
||||
stream = C.mlx_default_cpu_stream_new()
|
||||
} else {
|
||||
stream = C.mlx_default_gpu_stream_new()
|
||||
sf, err := LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer C.mlx_stream_free(stream)
|
||||
defer sf.Free()
|
||||
|
||||
C.mlx_load_safetensors(&string2array, &string2string, cPath, stream)
|
||||
|
||||
it := C.mlx_map_string_to_array_iterator_new(string2array)
|
||||
it := C.mlx_map_string_to_array_iterator_new(sf.arrays)
|
||||
defer C.mlx_map_string_to_array_iterator_free(it)
|
||||
|
||||
for {
|
||||
@@ -51,3 +107,43 @@ func Load(path string) iter.Seq2[string, *Array] {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SaveSafetensors saves arrays to a safetensors file without metadata.
|
||||
func SaveSafetensors(path string, arrays map[string]*Array) error {
|
||||
return SaveSafetensorsWithMetadata(path, arrays, nil)
|
||||
}
|
||||
|
||||
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata.
|
||||
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
cArrays := C.mlx_map_string_to_array_new()
|
||||
defer C.mlx_map_string_to_array_free(cArrays)
|
||||
|
||||
for name, arr := range arrays {
|
||||
if arr == nil {
|
||||
continue
|
||||
}
|
||||
cName := C.CString(name)
|
||||
C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
|
||||
C.free(unsafe.Pointer(cName))
|
||||
}
|
||||
|
||||
cMetadata := C.mlx_map_string_to_string_new()
|
||||
defer C.mlx_map_string_to_string_free(cMetadata)
|
||||
|
||||
for key, value := range metadata {
|
||||
cKey := C.CString(key)
|
||||
cValue := C.CString(value)
|
||||
C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
|
||||
C.free(unsafe.Pointer(cKey))
|
||||
C.free(unsafe.Pointer(cValue))
|
||||
}
|
||||
|
||||
if C.mlx_save_safetensors(cPath, cArrays, cMetadata) != 0 {
|
||||
return fmt.Errorf("failed to save safetensors: %s", path)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -33,6 +33,18 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
||||
return w0, w1, nil
|
||||
}
|
||||
|
||||
func FromFP8(x *Array, dtype DType) *Array {
|
||||
out := New("FROM_FP8")
|
||||
C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func ToFP8(x *Array) *Array {
|
||||
out := New("TO_FP8")
|
||||
C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
@@ -137,6 +149,40 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Pad(a *Array, paddings []int32) *Array {
|
||||
numAxes := len(paddings) / 2
|
||||
axes := make([]C.int, numAxes)
|
||||
lowPad := make([]C.int, numAxes)
|
||||
highPad := make([]C.int, numAxes)
|
||||
for i := range numAxes {
|
||||
axes[i] = C.int(i)
|
||||
lowPad[i] = C.int(paddings[i*2])
|
||||
highPad[i] = C.int(paddings[i*2+1])
|
||||
}
|
||||
|
||||
padValue := C.mlx_array_new_float(C.float(0))
|
||||
defer C.mlx_array_free(padValue)
|
||||
|
||||
cMode := C.CString("constant")
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("PAD")
|
||||
C.mlx_pad(
|
||||
&out.ctx,
|
||||
a.ctx,
|
||||
unsafe.SliceData(axes),
|
||||
C.size_t(len(axes)),
|
||||
unsafe.SliceData(lowPad),
|
||||
C.size_t(len(lowPad)),
|
||||
unsafe.SliceData(highPad),
|
||||
C.size_t(len(highPad)),
|
||||
padValue,
|
||||
cMode,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||
groups := int32(x.Dim(x.NumDims() - 1))
|
||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||
|
||||
@@ -11,8 +11,10 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string)
|
||||
switch strings.ToUpper(quantization) {
|
||||
case "NVFP4":
|
||||
return 16, 4, "nvfp4"
|
||||
case "MXFP4":
|
||||
return 32, 4, "mxfp4"
|
||||
case "FP4", "Q4", "INT4":
|
||||
return 32, 4, "affine"
|
||||
return 64, 4, "affine"
|
||||
case "MXFP8":
|
||||
return 32, 8, "mxfp8"
|
||||
case "FP8", "Q8", "INT8":
|
||||
|
||||
@@ -144,3 +144,44 @@ func TestLayerNormDefaultEps(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
weightVals := make([]float32, 3*32)
|
||||
for i := range weightVals {
|
||||
weightVals[i] = float32((i%11)-5) / 7
|
||||
}
|
||||
inputVals := make([]float32, 2*32)
|
||||
for i := range inputVals {
|
||||
inputVals[i] = float32((i%7)-3) / 5
|
||||
}
|
||||
|
||||
weight := mlx.FromValues(weightVals, 3, 32).AsType(mlx.DTypeBFloat16)
|
||||
input := mlx.FromValues(inputVals, 2, 32).AsType(mlx.DTypeBFloat16)
|
||||
mlx.Eval(weight, input)
|
||||
|
||||
ql := NewQuantizedLinear(weight, nil, 32, 4, "mxfp4")
|
||||
if ql.QBiases != nil {
|
||||
t.Fatalf("mxfp4 qbiases = %v, want nil", ql.QBiases)
|
||||
}
|
||||
|
||||
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
|
||||
mlx.Eval(dequantizedWeight)
|
||||
|
||||
qOut := ql.Forward(input)
|
||||
dOut := NewLinear(dequantizedWeight, nil).Forward(input)
|
||||
mlx.Eval(qOut, dOut)
|
||||
|
||||
got := qOut.Floats()
|
||||
want := dOut.Floats()
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("output length = %d, want %d", len(got), len(want))
|
||||
}
|
||||
|
||||
for i := range got {
|
||||
if !approxEqual(got[i], want[i], 1e-3) {
|
||||
t.Fatalf("output[%d] = %.6f, want %.6f", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -420,7 +420,16 @@ func tensorByBase(tensors map[string]*mlx.Array, base string) (*mlx.Array, strin
|
||||
}
|
||||
|
||||
// supportsGatherQMM reports whether the given quantization mode and bit width
// form one of the accepted combinations:
//
//   - "affine" with 4 or 8 bits
//   - "mxfp8" with 8 bits
//   - "nvfp4" or "mxfp4" with 4 bits
//
// Any other mode, or a supported mode with a different bit width, yields false.
//
// NOTE(review): presumably this gates use of a gather quantized-matmul path;
// confirm against the callers.
func supportsGatherQMM(mode string, bits int) bool {
	// A leftover affine-only return previously sat ahead of this switch and
	// made it unreachable, so the non-affine modes were always rejected.
	switch mode {
	case "affine":
		return bits == 4 || bits == 8
	case "mxfp8":
		return bits == 8
	case "nvfp4", "mxfp4":
		return bits == 4
	default:
		return false
	}
}
|
||||
|
||||
func freeTensorKeys(tensors map[string]*mlx.Array, keys ...string) {
|
||||
|
||||
@@ -83,6 +83,28 @@ func TestLayerSelectionHelpers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsGatherQMM(t *testing.T) {
|
||||
tests := []struct {
|
||||
mode string
|
||||
bits int
|
||||
want bool
|
||||
}{
|
||||
{mode: "affine", bits: 4, want: true},
|
||||
{mode: "affine", bits: 8, want: true},
|
||||
{mode: "mxfp8", bits: 8, want: true},
|
||||
{mode: "nvfp4", bits: 4, want: true},
|
||||
{mode: "mxfp4", bits: 4, want: true},
|
||||
{mode: "mxfp8", bits: 4, want: false},
|
||||
{mode: "affine", bits: 3, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := supportsGatherQMM(tt.mode, tt.bits); got != tt.want {
|
||||
t.Fatalf("supportsGatherQMM(%q, %d) = %v, want %v", tt.mode, tt.bits, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTensorPathLayout(t *testing.T) {
|
||||
dummy := mlx.New("dummy")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user