Compare commits

...

6 Commits

Author SHA1 Message Date
Patrick Devine
de5cb7311f mlx: add mxfp4/mxfp8/nvfp4 importing (#15015)
This change allows importing bf16 and converting to mxfp4/mxfp8/nvfp4
and also importing fp8 and converting directly to mxfp8.
2026-03-24 13:45:44 -07:00
Jesse Gross
95ee7fbd29 mlxrunner: panic on double unpin 2026-03-23 17:44:19 -07:00
Jesse Gross
ec55536734 mlxrunner: show time since last used in cache dump tree 2026-03-23 17:44:19 -07:00
Jesse Gross
77491439c2 mlxrunner: support partial match on pure transformer caches
Previously, a partial match within a node's edge would truncate the path
to the parent snapshot - effectively making all cache types behave as
recurrent caches. Caches with only transformer layers can rewind to
arbitrary boundary so this restores this capability to improve cache
hits
2026-03-23 17:44:19 -07:00
Parth Sareen
b166b36cd2 docs: update Claude Code with Telegram guide (#15026) 2026-03-23 16:31:21 -07:00
Daniel Hiltgen
c2b0bb7a52 mlx: update as of 3/23 (#14789)
* mlx: update to HEAD on 3/23

Also fixes a few misc vendoring bugs uncovered with this first update.
This also renames the version files to make them clearer.

* CUDA Fast Gated Delta kernel

* mlx: detect eval errors and panic

On model errors or missing kernels, don't mask the error, bubble it up.
2026-03-23 11:28:44 -07:00
35 changed files with 2007 additions and 295 deletions

View File

@@ -157,7 +157,7 @@ COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/imagegen/mlx x/imagegen/mlx COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum . COPY go.mod go.sum .
COPY MLX_VERSION MLX_CORE_VERSION . COPY MLX_VERSION MLX_C_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download RUN go mod download

View File

@@ -1 +0,0 @@
v0.30.6

1
MLX_C_VERSION Normal file
View File

@@ -0,0 +1 @@
0726ca922fc902c4c61ef9c27d94132be418e945

View File

@@ -1 +1 @@
v0.5.0 38ad257088fb2193ad47e527cf6534a689f30943

View File

@@ -96,6 +96,18 @@ The `/loop` command runs a prompt or slash command on a recurring schedule insid
/loop 1h Remind me to review the deploy status /loop 1h Remind me to review the deploy status
``` ```
## Telegram
Chat with Claude Code from Telegram by connecting a bot to your session. Install the [Telegram plugin](https://github.com/anthropics/claude-plugins-official), create a bot via [@BotFather](https://t.me/BotFather), then launch with the channel flag:
```shell
ollama launch claude -- --channels plugin:telegram@claude-plugins-official
```
Claude Code will prompt for permission on most actions. To allow the bot to work autonomously, configure [permission rules](https://code.claude.com/docs/en/permissions) or pass `--dangerously-skip-permissions` in isolated environments.
See the [plugin README](https://github.com/anthropics/claude-plugins-official/tree/main/external_plugins/telegram) for full setup instructions including pairing and access control.
## Manual setup ## Manual setup
Claude Code connects to Ollama using the Anthropic-compatible API. Claude Code connects to Ollama using the Anthropic-compatible API.

View File

@@ -109,7 +109,7 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
type CreateOptions struct { type CreateOptions struct {
ModelName string ModelName string
ModelDir string ModelDir string
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
} }
@@ -280,7 +280,7 @@ func newPackedTensorLayerCreator() create.PackedTensorLayerCreator {
if !QuantizeSupported() { if !QuantizeSupported() {
return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support") return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
} }
blobData, err := quantizePackedGroup(tensors) blobData, err := quantizePackedGroup(groupName, tensors)
if err != nil { if err != nil {
return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err) return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
} }

View File

@@ -7,29 +7,27 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"sort"
"strconv" "strconv"
"strings"
"github.com/ollama/ollama/x/create" "github.com/ollama/ollama/x/create"
"github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
) )
// quantizeParams maps quantization type names to MLX quantize parameters.
var quantizeParams = map[string]struct {
groupSize int
bits int
mode string
}{
"int4": {64, 4, "affine"},
"nvfp4": {16, 4, "nvfp4"},
"int8": {64, 8, "affine"},
"mxfp8": {32, 8, "mxfp8"},
}
// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX, // loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias) // quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
// to the provided maps. If quantize is empty, the tensor is kept as-is. // to the provided maps. If quantize is empty, the tensor is kept as-is.
// Returns any temp file paths created (caller must clean up) and arrays needing eval. // Returns any temp file paths created (caller must clean up) and arrays needing eval.
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) { func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
if quantize != "" {
if gs, _, _ := model.QuantizationParams(quantize); gs == 0 {
return "", nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
}
tmpDir := ensureTempDir() tmpDir := ensureTempDir()
tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors") tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
@@ -50,11 +48,16 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
} }
// Find the tensor key (may differ from name for single-tensor blobs) // Find the tensor key (may differ from name for single-tensor blobs)
inputKey, err := findSafetensorsKey(tmpPath) header, err := readSafetensorsHeader(tmpPath)
if err != nil { if err != nil {
st.Free() st.Free()
return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err) return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
} }
inputKey, err := safetensorsKey(name, header)
if err != nil {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("failed to resolve tensor key for %s: %w", name, err)
}
arr := st.Get(inputKey) arr := st.Get(inputKey)
if arr == nil { if arr == nil {
@@ -62,34 +65,46 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey) return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
} }
// Decode FP8 source encoding before checking quantize, so that callers
// requesting decode-only (quantize="") receive usable float data.
if info, ok := header[inputKey]; ok && info.Dtype == "F8_E4M3" {
scaleKey := inputKey + ".scale_inv"
scaleInv := st.Get(scaleKey)
if scaleInv == nil {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q for fp8 source tensor %q", scaleKey, inputKey)
}
arr, err = decodeSourceFP8Tensor(arr, scaleInv)
if err != nil {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("failed to decode fp8 tensor %s: %w", inputKey, err)
}
mlx.Eval(arr)
}
if quantize == "" { if quantize == "" {
arr = mlx.Contiguous(arr) arr = mlx.Contiguous(arr, false)
arrays[name] = arr arrays[name] = arr
return tmpPath, []*mlx.Array{arr}, st, nil return tmpPath, []*mlx.Array{arr}, st, nil
} }
if arr.DType() != mlx.DTypeBFloat16 && arr.DType() != mlx.DTypeFloat32 && arr.DType() != mlx.DTypeFloat16 {
// Convert to float type if needed (quantize expects float) // Convert to float type if needed (quantize expects float)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 { arr = arr.AsType(mlx.DTypeBFloat16)
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr) mlx.Eval(arr)
} }
params, ok := quantizeParams[quantize] groupSize, bits, mode := model.QuantizationParams(quantize)
if !ok { qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
st.Free()
return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode) qweight = mlx.Contiguous(qweight, false)
scales = mlx.Contiguous(scales, false)
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
arrays[name] = qweight arrays[name] = qweight
arrays[name+".scale"] = scales arrays[name+".scale"] = scales
toEval = append(toEval, qweight, scales) toEval = append(toEval, qweight, scales)
if qbiases != nil { if qbiases != nil {
qbiases = mlx.Contiguous(qbiases) qbiases = mlx.Contiguous(qbiases, false)
arrays[name+".bias"] = qbiases arrays[name+".bias"] = qbiases
toEval = append(toEval, qbiases) toEval = append(toEval, qbiases)
} }
@@ -101,27 +116,45 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias. // and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
// Tensor keys use the original tensor name: name, name.scale, name.bias. // Tensor keys use the original tensor name: name, name.scale, name.bias.
// The blob includes __metadata__ with quant_type and group_size. // The blob includes __metadata__ with quant_type and group_size.
// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8". // Supported quantization types: "int4", "nvfp4", "mxfp4", "int8", "mxfp8".
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) { func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
arrays := make(map[string]*mlx.Array) arrays := make(map[string]*mlx.Array)
tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays) tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
if tmpPath != "" { if tmpPath != "" {
defer os.Remove(tmpPath) defer os.Remove(tmpPath)
} }
if st != nil {
defer st.Free()
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
finalArrays := make([]*mlx.Array, 0, len(arrays))
for _, arr := range arrays {
if arr != nil {
finalArrays = append(finalArrays, arr)
}
}
mlx.Pin(finalArrays...)
defer func() {
if st != nil {
st.Free()
}
mlx.Unpin(finalArrays...)
mlx.Sweep()
}()
mlx.Eval(toEval...) mlx.Eval(toEval...)
mlx.Sweep()
// Free early to release mmap; defer guard handles error paths
if st != nil {
st.Free()
st = nil
}
// Build metadata for single-tensor blobs // Build metadata for single-tensor blobs
params := quantizeParams[quantize] groupSize, _, _ := model.QuantizationParams(quantize)
metadata := map[string]string{ metadata := map[string]string{
"quant_type": quantize, "quant_type": quantize,
"group_size": strconv.Itoa(params.groupSize), "group_size": strconv.Itoa(groupSize),
} }
tmpDir := ensureTempDir() tmpDir := ensureTempDir()
@@ -135,48 +168,81 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
// quantizePackedGroup quantizes multiple tensors and saves them all into a single // quantizePackedGroup quantizes multiple tensors and saves them all into a single
// combined safetensors blob. Used for packing expert groups. // combined safetensors blob. Used for packing expert groups.
// When the inputs are per-expert 2D tensors (e.g., experts.0.gate_proj.weight),
// they are stacked into 3D switch_mlp tensors before quantization.
// Each tensor may have a different quantization type (mixed-precision). // Each tensor may have a different quantization type (mixed-precision).
// Returns the blob bytes. No __metadata__ is added because different tensors // Returns the blob bytes.
// may use different quantization types. func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) { // Check if inputs are per-expert tensors that should be stacked into 3D
if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
}
allArrays := make(map[string]*mlx.Array) allArrays := make(map[string]*mlx.Array)
var allToEval []*mlx.Array var pinned []*mlx.Array
var tmpPaths []string
var handles []*mlx.SafetensorsFile var metadata map[string]string
uniformQuantize := ""
hasQuantized := false
mixedQuantize := false
for _, input := range inputs {
if input.Quantize == "" {
if hasQuantized {
mixedQuantize = true
}
continue
}
if !hasQuantized {
hasQuantized = true
uniformQuantize = input.Quantize
continue
}
if input.Quantize != uniformQuantize {
mixedQuantize = true
}
}
if hasQuantized && !mixedQuantize {
if groupSize, _, _ := model.QuantizationParams(uniformQuantize); groupSize > 0 {
metadata = map[string]string{
"quant_type": uniformQuantize,
"group_size": strconv.Itoa(groupSize),
}
}
}
for _, input := range inputs { for _, input := range inputs {
tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays) tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
if tmpPath != "" {
tmpPaths = append(tmpPaths, tmpPath)
}
if st != nil {
handles = append(handles, st)
}
if err != nil { if err != nil {
// Cleanup on error mlx.Unpin(pinned...)
for _, h := range handles { mlx.Sweep()
h.Free()
}
for _, p := range tmpPaths {
os.Remove(p)
}
return nil, err return nil, err
} }
allToEval = append(allToEval, toEval...)
mlx.Eval(toEval...)
finalArrays := arraysForPackedInput(allArrays, input)
mlx.Pin(finalArrays...)
pinned = append(pinned, finalArrays...)
if st != nil {
st.Free()
} }
if tmpPath != "" {
mlx.Eval(allToEval...) os.Remove(tmpPath)
// Free native handles after eval
for _, h := range handles {
h.Free()
} }
mlx.Sweep()
}
defer func() {
mlx.Unpin(pinned...)
mlx.Sweep()
}()
// Save combined blob (no global metadata for mixed-precision packed blobs) // Save combined blob. Add global metadata only when every packed tensor uses
// the same quantization mode and group size.
tmpDir := ensureTempDir() tmpDir := ensureTempDir()
outPath := filepath.Join(tmpDir, "packed-combined.safetensors") outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
defer os.Remove(outPath) defer os.Remove(outPath)
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil { if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
return nil, fmt.Errorf("failed to save packed blob: %w", err) return nil, fmt.Errorf("failed to save packed blob: %w", err)
} }
@@ -185,17 +251,193 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
return nil, fmt.Errorf("failed to read packed blob: %w", err) return nil, fmt.Errorf("failed to read packed blob: %w", err)
} }
for _, p := range tmpPaths { return blobData, nil
os.Remove(p) }
func arraysForPackedInput(allArrays map[string]*mlx.Array, input create.PackedTensorInput) []*mlx.Array {
keys := []string{input.Name}
if input.Quantize != "" {
keys = append(keys, input.Name+".scale", input.Name+".bias")
} }
out := make([]*mlx.Array, 0, len(keys))
for _, key := range keys {
if arr := allArrays[key]; arr != nil {
out = append(out, arr)
}
}
return out
}
// perExpertSuffix matches ".{index}.{proj_and_suffix}" after the group prefix.
var perExpertSuffix = regexp.MustCompile(`^\.(\d+)\.(.+)$`)
type expertTensorInfo struct {
index int
proj string // e.g., "gate_proj.weight"
input create.PackedTensorInput
}
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
// and returns the uniform quantization type shared by all inputs.
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
// or if the inputs have mixed quantization types.
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
if !strings.HasSuffix(groupName, ".experts") {
return nil, ""
}
quantize := inputs[0].Quantize
groups := make(map[string][]expertTensorInfo)
for _, input := range inputs {
if input.Quantize != quantize {
return nil, "" // mixed quantization types
}
suffix := strings.TrimPrefix(input.Name, groupName)
m := perExpertSuffix.FindStringSubmatch(suffix)
if m == nil {
return nil, "" // not a per-expert pattern
}
index, err := strconv.Atoi(m[1])
if err != nil {
return nil, ""
}
groups[m[2]] = append(groups[m[2]], expertTensorInfo{
index: index,
proj: m[2],
input: input,
})
}
if len(groups) == 0 {
return nil, ""
}
return groups, quantize
}
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
groupBase := strings.TrimSuffix(groupName, ".experts")
allArrays := make(map[string]*mlx.Array)
var pinned []*mlx.Array
var metadata map[string]string
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
metadata = map[string]string{
"quant_type": quantize,
"group_size": strconv.Itoa(groupSize),
}
}
// Sort projection names for deterministic output
projNames := make([]string, 0, len(projGroups))
for proj := range projGroups {
projNames = append(projNames, proj)
}
sort.Strings(projNames)
cleanup := func() {
mlx.Unpin(pinned...)
mlx.Sweep()
}
for _, proj := range projNames {
experts := projGroups[proj]
// Sort by expert index
sort.Slice(experts, func(i, j int) bool {
return experts[i].index < experts[j].index
})
// Load and decode each expert tensor
var decoded []*mlx.Array
for _, expert := range experts {
dummyArrays := make(map[string]*mlx.Array)
tmpPath, toEval, st, err := loadAndQuantizeArray(expert.input.Reader, expert.input.Name, "", dummyArrays)
if err != nil {
cleanup()
return nil, fmt.Errorf("failed to decode expert tensor %s: %w", expert.input.Name, err)
}
mlx.Eval(toEval...)
arr := dummyArrays[expert.input.Name]
mlx.Pin(arr)
pinned = append(pinned, arr)
decoded = append(decoded, arr)
if st != nil {
st.Free()
}
if tmpPath != "" {
os.Remove(tmpPath)
}
mlx.Sweep()
}
// Stack into 3D along axis 0: [numExperts, rows, cols]
stacked := mlx.Stack(decoded, 0)
mlx.Eval(stacked)
mlx.Pin(stacked)
pinned = append(pinned, stacked)
// Free individual decoded arrays
mlx.Unpin(decoded...)
mlx.Sweep()
stackedName := groupBase + ".switch_mlp." + proj
// Quantize the stacked tensor
if quantize != "" {
groupSize, bits, mode := model.QuantizationParams(quantize)
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
qweight = mlx.Contiguous(qweight, false)
scales = mlx.Contiguous(scales, false)
allArrays[stackedName] = qweight
allArrays[stackedName+".scale"] = scales
toEval := []*mlx.Array{qweight, scales}
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases, false)
allArrays[stackedName+".bias"] = qbiases
toEval = append(toEval, qbiases)
}
mlx.Eval(toEval...)
mlx.Pin(toEval...)
pinned = append(pinned, toEval...)
// Free stacked source array
mlx.Unpin(stacked)
mlx.Sweep()
} else {
stacked = mlx.Contiguous(stacked, false)
mlx.Eval(stacked)
allArrays[stackedName] = stacked
}
}
defer cleanup()
tmpDir := ensureTempDir()
outPath := filepath.Join(tmpDir, "stacked-combined.safetensors")
defer os.Remove(outPath)
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
return nil, fmt.Errorf("failed to save stacked blob: %w", err)
}
blobData, err := os.ReadFile(outPath)
if err != nil {
return nil, fmt.Errorf("failed to read stacked blob: %w", err)
}
return blobData, nil return blobData, nil
} }
// QuantizeSupported returns true if quantization is supported (MLX library available) // QuantizeSupported returns true if quantization is supported (MLX library available)
func QuantizeSupported() bool { func QuantizeSupported() bool {
mlx.InitMLX() return mlx.CheckInit() == nil
return mlx.IsMLXAvailable()
} }
// ensureTempDir creates the temp directory for quantization if it doesn't exist // ensureTempDir creates the temp directory for quantization if it doesn't exist
@@ -205,32 +447,97 @@ func ensureTempDir() string {
return tmpDir return tmpDir
} }
// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file. type safetensorsHeaderEntry struct {
func findSafetensorsKey(path string) (string, error) { Dtype string `json:"dtype"`
Shape []int32 `json:"shape"`
}
func readSafetensorsHeader(path string) (map[string]safetensorsHeaderEntry, error) {
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer f.Close() defer f.Close()
var headerSize uint64 var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil { if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return "", err return nil, err
} }
headerBytes := make([]byte, headerSize) headerBytes := make([]byte, headerSize)
if _, err := io.ReadFull(f, headerBytes); err != nil { if _, err := io.ReadFull(f, headerBytes); err != nil {
return "", err return nil, err
} }
var header map[string]json.RawMessage var header map[string]safetensorsHeaderEntry
if err := json.Unmarshal(headerBytes, &header); err != nil { if err := json.Unmarshal(headerBytes, &header); err != nil {
return "", err return nil, err
}
return header, nil
}
// safetensorsKey resolves the primary tensor key from a header.
func safetensorsKey(preferred string, header map[string]safetensorsHeaderEntry) (string, error) {
if preferred != "" {
if _, ok := header[preferred]; ok {
return preferred, nil
}
} }
keys := make([]string, 0, len(header))
for k := range header { for k := range header {
if k != "__metadata__" { if k == "__metadata__" || strings.HasSuffix(k, ".scale_inv") {
return k, nil continue
} }
keys = append(keys, k)
} }
sort.Strings(keys)
if len(keys) == 0 {
return "", fmt.Errorf("no tensor found in safetensors header") return "", fmt.Errorf("no tensor found in safetensors header")
}
return keys[0], nil
}
func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
if weight == nil || scaleInv == nil {
return nil, fmt.Errorf("fp8 weight and scale tensors are required")
}
weightShape := weight.Dims()
scaleShape := scaleInv.Dims()
if len(weightShape) != 2 || len(scaleShape) != 2 {
return nil, fmt.Errorf("expected 2D fp8 weight and scale tensors, got %v and %v", weightShape, scaleShape)
}
// These must match the block size validated by resolveEffectiveQuantization
// in create.go, which rejects any source model with a different block size.
const blockRows = 128
const blockCols = 128
rows, cols := weightShape[0], weightShape[1]
expectedScaleRows := (rows + blockRows - 1) / blockRows
expectedScaleCols := (cols + blockCols - 1) / blockCols
if scaleShape[0] != expectedScaleRows || scaleShape[1] != expectedScaleCols {
return nil, fmt.Errorf(
"unexpected fp8 scale shape %v for weight shape %v; want [%d %d]",
scaleShape,
weightShape,
expectedScaleRows,
expectedScaleCols,
)
}
decoded := mlx.FromFP8(weight, mlx.DTypeBFloat16)
padBottom := blockRows*scaleShape[0] - rows
padSide := blockCols*scaleShape[1] - cols
if padBottom > 0 || padSide > 0 {
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
}
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scaleInv, 1), 3))
decoded = mlx.Reshape(decoded, int32(rows+padBottom), int32(cols+padSide))
if padBottom > 0 || padSide > 0 {
decoded = mlx.SliceStartStop(decoded, []int32{0, 0}, []int32{int32(rows), int32(cols)})
}
return decoded, nil
} }

View File

@@ -267,13 +267,13 @@ func ShouldQuantize(name, component string) bool {
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type. // ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions. // This is a more detailed check that also considers tensor dimensions.
// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "int8", "mxfp8"). // The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "mxfp4", "int8", "mxfp8").
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool { func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
return GetTensorQuantization(name, shape, quantize) != "" return GetTensorQuantization(name, shape, quantize) != ""
} }
// normalizeQuantType converts various quantization type aliases to canonical forms. // normalizeQuantType converts various quantization type aliases to canonical forms.
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp8/MXFP8 // Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp4/MXFP4, mxfp8/MXFP8
func normalizeQuantType(quantize string) string { func normalizeQuantType(quantize string) string {
switch strings.ToUpper(quantize) { switch strings.ToUpper(quantize) {
case "Q4", "INT4", "FP4": case "Q4", "INT4", "FP4":
@@ -282,6 +282,8 @@ func normalizeQuantType(quantize string) string {
return "int8" return "int8"
case "NVFP4": case "NVFP4":
return "nvfp4" return "nvfp4"
case "MXFP4":
return "mxfp4"
case "MXFP8": case "MXFP8":
return "mxfp8" return "mxfp8"
default: default:
@@ -335,7 +337,7 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
quantNorm := normalizeQuantType(quantize) quantNorm := normalizeQuantType(quantize)
// MLX quantization requires last dimension to be divisible by group size // MLX quantization requires last dimension to be divisible by group size
// nvfp4: 16, mxfp8: 32, int4/int8: 64 // nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
groupSize := int32(32) groupSize := int32(32)
switch quantNorm { switch quantNorm {
case "nvfp4": case "nvfp4":
@@ -353,8 +355,8 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
return "" return ""
} }
// For NVFP4 or MXFP8, use the same quantization for all (no mixed precision) // For non-affine modes, use the same quantization for all eligible tensors.
if quantNorm == "nvfp4" || quantNorm == "mxfp8" { if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
return quantNorm return quantNorm
} }
@@ -391,23 +393,39 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
return quantNorm return quantNorm
} }
// expertGroupRegexp matches expert tensor names and captures the group prefix. var expertLayerPrefixRegexp = regexp.MustCompile(`^(?:model\.language_model\.|language_model(?:\.model)?\.|model\.)?layers\.\d+$`)
// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
// Captures: model.layers.{L}.mlp.experts
var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together. // ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
// For example: // For example:
// - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts" // - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts"
// - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts" // - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts"
// - "language_model.model.layers.1.mlp.switch_mlp.down_proj.weight" -> "language_model.model.layers.1.mlp.switch_mlp"
// - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts) // - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts)
// - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert) // - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert)
func ExpertGroupPrefix(tensorName string) string { func ExpertGroupPrefix(tensorName string) string {
m := expertGroupRegexp.FindStringSubmatch(tensorName) if !strings.HasSuffix(tensorName, ".weight") {
if m == nil {
return "" return ""
} }
return m[1]
for _, marker := range []string{
".mlp.experts.",
".mlp.shared_experts.",
".mlp.switch_mlp.",
} {
idx := strings.Index(tensorName, marker)
if idx == -1 {
continue
}
layerPrefix := tensorName[:idx]
if !expertLayerPrefixRegexp.MatchString(layerPrefix) {
continue
}
return layerPrefix + strings.TrimSuffix(marker, ".")
}
return ""
} }
// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob. // PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob.
@@ -427,6 +445,8 @@ type sourceQuantization struct {
Bits int `json:"bits"` Bits int `json:"bits"`
GroupSize int `json:"group_size"` GroupSize int `json:"group_size"`
Mode string `json:"mode"` Mode string `json:"mode"`
QuantMethod string `json:"quant_method"`
WeightBlockSize []int32 `json:"weight_block_size"`
} }
type sourceModelConfig struct { type sourceModelConfig struct {
@@ -493,6 +513,98 @@ func (cfg sourceModelConfig) QuantMetadata() map[string]string {
return metadata return metadata
} }
type sourceQuantizedKind string
const (
sourceQuantizedKindNone sourceQuantizedKind = ""
sourceQuantizedKindPrequantized sourceQuantizedKind = "prequantized"
sourceQuantizedKindHFFP8 sourceQuantizedKind = "hf_fp8"
)
func (cfg sourceModelConfig) quantizationConfigs() []sourceQuantization {
return []sourceQuantization{
cfg.Quantization,
cfg.QuantizationConfig,
cfg.TextConfig.Quantization,
cfg.TextConfig.QuantizationConfig,
}
}
func (cfg sourceModelConfig) HFFP8WeightBlockSize() (rows, cols int32, ok bool) {
for _, q := range cfg.quantizationConfigs() {
if !strings.EqualFold(q.QuantMethod, "fp8") || len(q.WeightBlockSize) != 2 {
continue
}
return q.WeightBlockSize[0], q.WeightBlockSize[1], true
}
return 0, 0, false
}
func inspectSourceQuantization(modelDir string, cfg sourceModelConfig) (sourceQuantizedKind, error) {
entries, err := os.ReadDir(modelDir)
if err != nil {
return sourceQuantizedKindNone, err
}
hasScaleInv := false
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
extractor, err := safetensors.OpenForExtraction(filepath.Join(modelDir, entry.Name()))
if err != nil {
return sourceQuantizedKindNone, err
}
for _, name := range extractor.ListTensors() {
switch {
case strings.HasSuffix(name, ".scales"):
extractor.Close()
return sourceQuantizedKindPrequantized, nil
case strings.HasSuffix(name, ".weight_scale_inv"):
hasScaleInv = true
}
}
extractor.Close()
}
if hasScaleInv {
if _, _, ok := cfg.HFFP8WeightBlockSize(); ok {
return sourceQuantizedKindHFFP8, nil
}
}
return sourceQuantizedKindNone, nil
}
func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested string) (string, error) {
switch sourceKind {
case sourceQuantizedKindNone:
return requested, nil
case sourceQuantizedKindPrequantized:
if requested != "" {
return "", fmt.Errorf("cannot requantize already-quantized source model with --quantize %q", requested)
}
return "", nil
case sourceQuantizedKindHFFP8:
if requested != "" {
return "", fmt.Errorf("cannot requantize already-quantized fp8 source model with --quantize %q", requested)
}
rows, cols, ok := cfg.HFFP8WeightBlockSize()
if !ok {
return "", fmt.Errorf("fp8 source model missing weight_block_size metadata")
}
if rows != 128 || cols != 128 {
return "", fmt.Errorf("unsupported fp8 source block size %dx%d", rows, cols)
}
return "mxfp8", nil
default:
return "", fmt.Errorf("unsupported source quantization kind %q", sourceKind)
}
}
type tensorImportTransform interface { type tensorImportTransform interface {
skipTensor(name string) bool skipTensor(name string) bool
transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error)
@@ -546,6 +658,14 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
if err != nil { if err != nil {
return fmt.Errorf("failed to read source config.json: %w", err) return fmt.Errorf("failed to read source config.json: %w", err)
} }
sourceQuantKind, err := inspectSourceQuantization(modelDir, sourceConfig)
if err != nil {
return fmt.Errorf("failed to inspect source quantization: %w", err)
}
effectiveQuantize, err := resolveEffectiveQuantization(sourceConfig, sourceQuantKind, quantize)
if err != nil {
return err
}
sourceQuantMetadata := sourceConfig.QuantMetadata() sourceQuantMetadata := sourceConfig.QuantMetadata()
importTransform, err := newTensorImportTransform(modelDir, sourceConfig) importTransform, err := newTensorImportTransform(modelDir, sourceConfig)
if err != nil { if err != nil {
@@ -557,7 +677,6 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
if len(createPackedLayer) > 0 { if len(createPackedLayer) > 0 {
packedCreator = createPackedLayer[0] packedCreator = createPackedLayer[0]
} }
// Accumulate expert tensors by group prefix for packing. // Accumulate expert tensors by group prefix for packing.
// Readers reference file-backed SectionReaders, so we keep extractors // Readers reference file-backed SectionReaders, so we keep extractors
// open until each group is flushed to avoid buffering tensor data in memory. // open until each group is flushed to avoid buffering tensor data in memory.
@@ -600,8 +719,8 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
tensorSet[name] = struct{}{} tensorSet[name] = struct{}{}
} }
quantizeMsg := "" quantizeMsg := ""
if quantize != "" { if effectiveQuantize != "" {
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize) quantizeMsg = fmt.Sprintf(", quantizing to %s", effectiveQuantize)
} }
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg)) fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
@@ -612,9 +731,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
if importTransform.skipTensor(tensorName) { if importTransform.skipTensor(tensorName) {
continue continue
} }
if shouldSkipPrequantizedCompanion(tensorName, tensorSet) { if shouldSkipSourceCompanion(tensorName, tensorSet) {
continue continue
} }
sourceFP8ScaleName, hasSourceFP8Scale := sourceFP8Companion(tensorName, tensorSet)
td, err := extractor.GetTensor(tensorName) td, err := extractor.GetTensor(tensorName)
if err != nil { if err != nil {
@@ -623,7 +743,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err) return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
} }
if quantize == "" { if effectiveQuantize == "" {
layer, ok, err := createPrequantizedLayer(extractor, td, tensorName, tensorSet, sourceQuantMetadata, createLayer) layer, ok, err := createPrequantizedLayer(extractor, td, tensorName, tensorSet, sourceQuantMetadata, createLayer)
if err != nil { if err != nil {
extractor.Close() extractor.Close()
@@ -647,8 +767,33 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
// Determine quantization type for this tensor (empty string if not quantizing) // Determine quantization type for this tensor (empty string if not quantizing)
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN) // GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
quantizeType := "" quantizeType := ""
if quantize != "" { switch {
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, quantize) case sourceQuantKind == sourceQuantizedKindHFFP8 && hasSourceFP8Scale:
quantizeType = "mxfp8"
case sourceQuantKind == sourceQuantizedKindHFFP8:
quantizeType = ""
case effectiveQuantize != "":
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
}
reader := outTD.SafetensorsReader()
if hasSourceFP8Scale {
if len(outputTensors) != 1 {
extractor.Close()
closeExtractors()
return fmt.Errorf("source fp8 tensor %s rewrote into %d tensors; only 1:1 rewrites are supported", tensorName, len(outputTensors))
}
if quantizeType == "" {
extractor.Close()
closeExtractors()
return fmt.Errorf("source fp8 tensor %s was not scheduled for mxfp8 conversion", tensorName)
}
scaleTD, err := extractor.GetTensor(sourceFP8ScaleName)
if err != nil {
extractor.Close()
closeExtractors()
return fmt.Errorf("failed to get fp8 scale tensor %s: %w", sourceFP8ScaleName, err)
}
reader = buildSourceFP8Reader(outTD, scaleTD.WithName(outTD.Name+".scale_inv"))
} }
// Check if this tensor belongs to an expert group for packing // Check if this tensor belongs to an expert group for packing
@@ -670,13 +815,13 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
Dtype: outTD.Dtype, Dtype: outTD.Dtype,
Shape: outTD.Shape, Shape: outTD.Shape,
Quantize: quantizeType, Quantize: quantizeType,
Reader: outTD.SafetensorsReader(), Reader: reader,
}) })
} else { } else {
// Store as minimal safetensors format (88 bytes header overhead) // Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors // This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales) // createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType) newLayers, err := createTensorLayer(reader, outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
if err != nil { if err != nil {
extractor.Close() extractor.Close()
closeExtractors() closeExtractors()
@@ -760,7 +905,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return nil return nil
} }
func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{}) bool { func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}) bool {
switch { switch {
case strings.HasSuffix(name, ".scales"): case strings.HasSuffix(name, ".scales"):
_, ok := tensorSet[strings.TrimSuffix(name, ".scales")+".weight"] _, ok := tensorSet[strings.TrimSuffix(name, ".scales")+".weight"]
@@ -768,11 +913,28 @@ func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{})
case strings.HasSuffix(name, ".biases"): case strings.HasSuffix(name, ".biases"):
_, ok := tensorSet[strings.TrimSuffix(name, ".biases")+".weight"] _, ok := tensorSet[strings.TrimSuffix(name, ".biases")+".weight"]
return ok return ok
case strings.HasSuffix(name, ".weight_scale_inv"):
_, ok := tensorSet[strings.TrimSuffix(name, "_scale_inv")]
return ok
default: default:
return false return false
} }
} }
func sourceFP8Companion(weightName string, tensorSet map[string]struct{}) (scaleName string, ok bool) {
if !strings.HasSuffix(weightName, ".weight") {
return "", false
}
scaleName = weightName + "_scale_inv"
_, ok = tensorSet[scaleName]
return scaleName, ok
}
func buildSourceFP8Reader(weightTD, scaleTD *safetensors.TensorData) io.Reader {
return safetensors.BuildPackedSafetensorsReader([]*safetensors.TensorData{weightTD, scaleTD})
}
func createPrequantizedLayer( func createPrequantizedLayer(
extractor *safetensors.TensorExtractor, extractor *safetensors.TensorExtractor,
td *safetensors.TensorData, td *safetensors.TensorData,

View File

@@ -246,6 +246,30 @@ func readSingleTensorRaw(t *testing.T, data []byte) []byte {
return nil return nil
} }
func readSafetensorsHeaderNames(t *testing.T, data []byte) []string {
t.Helper()
var headerSize uint64
if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
t.Fatalf("failed to read header size: %v", err)
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
t.Fatalf("failed to parse header: %v", err)
}
names := make([]string, 0, len(header))
for name := range header {
if name == "__metadata__" {
continue
}
names = append(names, name)
}
slices.Sort(names)
return names
}
func TestCreateSafetensorsModel(t *testing.T) { func TestCreateSafetensorsModel(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
@@ -546,6 +570,215 @@ func TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets(t *testing.T) {
} }
} }
func TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["TestModel"],
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("dense.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{2}, make([]byte, 4)),
})
quantizeByName := make(map[string]string)
headerNamesByName := make(map[string][]string)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
_, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
quantizeByName[name] = quantize
headerNamesByName[name] = readSafetensorsHeaderNames(t, data)
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if got := quantizeByName["linear.weight"]; got != "mxfp8" {
t.Fatalf("linear.weight quantization = %q, want %q", got, "mxfp8")
}
if got := quantizeByName["norm.weight"]; got != "" {
t.Fatalf("norm.weight quantization = %q, want empty", got)
}
if got := quantizeByName["dense.weight"]; got != "" {
t.Fatalf("dense.weight quantization = %q, want empty", got)
}
if _, ok := quantizeByName["linear.weight_scale_inv"]; ok {
t.Fatal("linear.weight_scale_inv should not be imported as a standalone tensor")
}
if got := headerNamesByName["linear.weight"]; !slices.Equal(got, []string{"linear.weight", "linear.weight.scale_inv"}) {
t.Fatalf("linear.weight blob tensors = %v, want %v", got, []string{"linear.weight", "linear.weight.scale_inv"})
}
if got := headerNamesByName["norm.weight"]; !slices.Equal(got, []string{"norm.weight"}) {
t.Fatalf("norm.weight blob tensors = %v, want %v", got, []string{"norm.weight"})
}
if got := headerNamesByName["dense.weight"]; !slices.Equal(got, []string{"dense.weight"}) {
t.Fatalf("dense.weight blob tensors = %v, want %v", got, []string{"dense.weight"})
}
}
func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T) {
tests := []struct {
name string
configJSON string
tensors []*st.TensorData
wantErr string
}{
{
name: "prequantized affine",
configJSON: `{"model_type": "test", "architectures": ["TestModel"]}`,
tensors: []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight", "U32", []int32{4, 4}, make([]byte, 16)),
st.NewTensorDataFromBytes("linear.scales", "BF16", []int32{4, 1}, make([]byte, 8)),
},
wantErr: `cannot requantize already-quantized source model with --quantize "int4"`,
},
{
name: "hf fp8 source",
configJSON: `{
"model_type": "test",
"architectures": ["TestModel"],
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
}`,
tensors: []*st.TensorData{
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
},
wantErr: `cannot requantize already-quantized fp8 source model with --quantize "int4"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), tt.tensors)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
return LayerInfo{}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
return nil, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
err := CreateSafetensorsModel("test-model", dir, "int4", createLayer, createTensorLayer, writeManifest, func(string) {})
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("error = %q, want substring %q", err, tt.wantErr)
}
})
}
}
func TestCreateSafetensorsModel_HFFP8PacksExperts(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create 2 experts so stacking produces a [2, 128, 128] tensor
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
})
var packedLayerNames []string
var packedLayerTensors [][]PackedTensorInput
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return LayerInfo{}, err
}
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
if _, err := io.ReadAll(r); err != nil {
return nil, err
}
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
packedLayerNames = append(packedLayerNames, groupName)
packedLayerTensors = append(packedLayerTensors, tensors)
return LayerInfo{Name: groupName, Digest: "sha256:packed_" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
if len(packedLayerNames) != 1 {
t.Fatalf("expected 1 packed layer, got %d: %v", len(packedLayerNames), packedLayerNames)
}
if packedLayerNames[0] != "language_model.model.layers.0.mlp.experts" {
t.Fatalf("unexpected packed layer name: %s", packedLayerNames[0])
}
// Verify all 6 expert tensors (2 experts × 3 proj types) were accumulated
tensors := packedLayerTensors[0]
if len(tensors) != 6 {
t.Fatalf("expected 6 tensors in packed group, got %d", len(tensors))
}
// All should be marked for mxfp8 quantization
for _, tensor := range tensors {
if tensor.Quantize != "mxfp8" {
t.Fatalf("expected mxfp8 quantize for %s, got %q", tensor.Name, tensor.Quantize)
}
}
}
func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) { func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
@@ -693,6 +926,113 @@ func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
} }
} }
func TestCreateSafetensorsModel_Qwen35DirectNonAffineKeepsSensitiveWeightsBF16(t *testing.T) {
for _, quantize := range []string{"nvfp4", "mxfp8", "mxfp4"} {
t.Run(quantize, func(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"text_config": {"dtype": "bfloat16"}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
gateUpValues := make([]float32, 2*128*64)
for expert := range 2 {
base := expert * 128 * 64
for i := range 64 * 64 {
gateUpValues[base+i] = 1
gateUpValues[base+64*64+i] = 2
}
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_a.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_b.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.shared_expert_gate.weight", "BF16", []int32{1, 64}, make([]byte, 64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
})
type tensorCall struct {
quantize string
}
type packedTensorCall struct {
Name string
Quantize string
}
tensorCalls := make(map[string]tensorCall)
packedCalls := make(map[string][]packedTensorCall)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
_, _ = io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantizeType string) ([]LayerInfo, error) {
_, _ = io.ReadAll(r)
tensorCalls[name] = tensorCall{quantize: quantizeType}
return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
group := make([]packedTensorCall, 0, len(tensors))
for _, tensor := range tensors {
group = append(group, packedTensorCall{
Name: tensor.Name,
Quantize: tensor.Quantize,
})
}
packedCalls[groupName] = group
return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
if err := CreateSafetensorsModel("test-model", dir, quantize, createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
for _, name := range []string{
"language_model.model.embed_tokens.weight",
"language_model.lm_head.weight",
"language_model.model.layers.0.linear_attn.in_proj_a.weight",
"language_model.model.layers.0.linear_attn.in_proj_b.weight",
"language_model.model.layers.0.mlp.gate.weight",
"language_model.model.layers.0.mlp.shared_expert_gate.weight",
} {
if got := tensorCalls[name].quantize; got != "" {
t.Fatalf("%s quantize = %q, want empty", name, got)
}
}
if got := tensorCalls["language_model.model.layers.0.self_attn.q_proj.weight"].quantize; got != quantize {
t.Fatalf("q_proj quantize = %q, want %q", got, quantize)
}
group := packedCalls["language_model.model.layers.0.mlp.switch_mlp"]
if len(group) != 3 {
t.Fatalf("packed switch_mlp tensor count = %d, want 3", len(group))
}
for _, tensor := range group {
if tensor.Quantize != quantize {
t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, quantize)
}
}
})
}
}
func TestResolveManifestPath(t *testing.T) { func TestResolveManifestPath(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -865,6 +1205,7 @@ func TestShouldQuantizeTensor(t *testing.T) {
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true}, {"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true}, {"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true}, {"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
{"large 2D weight mxfp4", "q_proj.weight", []int32{4096, 4096}, "mxfp4", true},
// Small tensors should not be quantized (< 1024 elements) // Small tensors should not be quantized (< 1024 elements)
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false}, {"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
@@ -891,9 +1232,11 @@ func TestShouldQuantizeTensor(t *testing.T) {
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false}, {"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
// Group size divisibility tests // Group size divisibility tests
// FP8/FP4 require divisible by 32 // FP8/FP4/MXFP4 require divisible by 32
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false}, {"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true}, {"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
{"not divisible by 32 mxfp4", "proj.weight", []int32{128, 48}, "mxfp4", false},
{"divisible by 32 mxfp4", "proj.weight", []int32{128, 64}, "mxfp4", true},
// NVFP4 requires divisible by 16 // NVFP4 requires divisible by 16
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false}, {"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true}, {"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
@@ -919,10 +1262,20 @@ func TestExpertGroupPrefix(t *testing.T) {
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"}, {"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"}, {"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
// Expert tensors with language_model prefix should also match
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
// Shared expert tensors should return their own group prefix // Shared expert tensors should return their own group prefix
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"}, {"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"}, {"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
// Rewritten Qwen switch_mlp tensors should also be packed per-layer.
{"model.layers.1.mlp.switch_mlp.down_proj.weight", "model.layers.1.mlp.switch_mlp"},
{"language_model.layers.2.mlp.switch_mlp.gate_proj.weight", "language_model.layers.2.mlp.switch_mlp"},
{"language_model.model.layers.3.mlp.switch_mlp.up_proj.weight", "language_model.model.layers.3.mlp.switch_mlp"},
{"model.language_model.layers.4.mlp.switch_mlp.gate_proj.weight", "model.language_model.layers.4.mlp.switch_mlp"},
// Non-expert tensors should return empty string // Non-expert tensors should return empty string
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts {"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
{"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert {"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert
@@ -978,6 +1331,161 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
if combinedDown != "int8" { if combinedDown != "int8" {
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8") t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
} }
nvfp4GateUp := GetTensorQuantization(
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
[]int32{64, 11008, 4096},
"nvfp4",
)
if nvfp4GateUp != "nvfp4" {
t.Fatalf("nvfp4 gate_proj quantization = %q, want %q", nvfp4GateUp, "nvfp4")
}
nvfp4Down := GetTensorQuantization(
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
[]int32{64, 4096, 11008},
"nvfp4",
)
if nvfp4Down != "nvfp4" {
t.Fatalf("nvfp4 down_proj quantization = %q, want %q", nvfp4Down, "nvfp4")
}
mxfp4GateUp := GetTensorQuantization(
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
[]int32{64, 11008, 4096},
"mxfp4",
)
if mxfp4GateUp != "mxfp4" {
t.Fatalf("mxfp4 gate_proj quantization = %q, want %q", mxfp4GateUp, "mxfp4")
}
mxfp4Down := GetTensorQuantization(
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
[]int32{64, 4096, 11008},
"mxfp4",
)
if mxfp4Down != "mxfp4" {
t.Fatalf("mxfp4 down_proj quantization = %q, want %q", mxfp4Down, "mxfp4")
}
}
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
dir := t.TempDir()
configJSON := `{
"model_type": "test",
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
"text_config": {"dtype": "bfloat16"}
}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
gateUpValues := make([]float32, 2*128*64)
for expert := range 2 {
base := expert * 128 * 64
for i := range 64 * 64 {
gateUpValues[base+i] = 1
gateUpValues[base+64*64+i] = 2
}
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
})
type tensorCall struct {
quantize string
}
type packedTensorCall struct {
Name string
Dtype string
Shape []int32
Quantize string
}
tensorCalls := make(map[string]tensorCall)
packedCalls := make(map[string][]packedTensorCall)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
_, _ = io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
_, _ = io.ReadAll(r)
tensorCalls[name] = tensorCall{quantize: quantize}
return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
}
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
group := make([]packedTensorCall, 0, len(tensors))
for _, tensor := range tensors {
group = append(group, packedTensorCall{
Name: tensor.Name,
Dtype: tensor.Dtype,
Shape: append([]int32(nil), tensor.Shape...),
Quantize: tensor.Quantize,
})
}
packedCalls[groupName] = group
return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
if err := CreateSafetensorsModel("test-model", dir, "nvfp4", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
groupName := "language_model.model.layers.0.mlp.switch_mlp"
group, ok := packedCalls[groupName]
if !ok {
t.Fatalf("missing packed group %q: %v", groupName, packedCalls)
}
if len(group) != 3 {
t.Fatalf("packed group %q has %d tensors, want 3", groupName, len(group))
}
gotNames := make([]string, 0, len(group))
for _, tensor := range group {
gotNames = append(gotNames, tensor.Name)
if tensor.Quantize != "nvfp4" {
t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, "nvfp4")
}
if tensor.Dtype != "BF16" {
t.Fatalf("packed tensor %q dtype = %q, want %q", tensor.Name, tensor.Dtype, "BF16")
}
}
slices.Sort(gotNames)
wantNames := []string{
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
"language_model.model.layers.0.mlp.switch_mlp.up_proj.weight",
}
if !slices.Equal(gotNames, wantNames) {
t.Fatalf("packed tensor names = %v, want %v", gotNames, wantNames)
}
for _, name := range wantNames {
if _, ok := tensorCalls[name]; ok {
t.Fatalf("packed expert tensor %q unexpectedly handled by createTensorLayer", name)
}
}
if got := tensorCalls["language_model.model.embed_tokens.weight"].quantize; got != "" {
t.Fatalf("embed_tokens quantize = %q, want empty", got)
}
if got := tensorCalls["language_model.model.layers.0.mlp.gate.weight"].quantize; got != "" {
t.Fatalf("mlp.gate quantize = %q, want empty", got)
}
} }
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) { func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {

View File

@@ -87,6 +87,27 @@ func (t qwen35ImportTransform) skipTensor(name string) bool {
return strings.Contains(name, "mtp.") return strings.Contains(name, "mtp.")
} }
func qwen35ShouldKeepBF16ForDirectNonAffine(name string) bool {
switch {
case strings.HasSuffix(name, "embed_tokens.weight"):
return true
case strings.HasSuffix(name, "lm_head.weight"):
return true
case strings.HasSuffix(name, ".linear_attn.in_proj_a.weight"):
return true
case strings.HasSuffix(name, ".linear_attn.in_proj_b.weight"):
return true
case strings.HasSuffix(name, ".linear_attn.in_proj_ba.weight"):
return true
case strings.HasSuffix(name, ".mlp.gate.weight") && !strings.Contains(name, "_proj"):
return true
case strings.HasSuffix(name, ".mlp.shared_expert_gate.weight"):
return true
default:
return false
}
}
func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quantize string) string { func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
if strings.HasPrefix(name, "vision_tower.") { if strings.HasPrefix(name, "vision_tower.") {
return "" return ""
@@ -127,6 +148,13 @@ func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quan
return "" return ""
} }
// Match the working HF-FP8 import policy for direct NVFP4/MXFP4/MXFP8 imports:
// keep embeddings, LM head, low-rank linear_attn projections, and routing
// gates in BF16 rather than forcing them into a non-affine quantized format.
if (quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8") && qwen35ShouldKeepBF16ForDirectNonAffine(name) {
return ""
}
return quantNorm return quantNorm
} }

View File

@@ -1,11 +1,11 @@
include(FetchContent) include(FetchContent)
# Read MLX version from top-level file (shared with Dockerfile) # Read MLX-C version from top-level file (shared with Dockerfile)
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG) file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG) string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
# Read MLX core version from top-level file # Read MLX version from top-level file
file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG) file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_GIT_TAG)
string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG) string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)
set(MLX_C_BUILD_EXAMPLES OFF) set(MLX_C_BUILD_EXAMPLES OFF)
@@ -98,6 +98,15 @@ FetchContent_MakeAvailable(mlx-c)
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h") file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/") file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")
# Regenerate Go/C shim wrappers from the (possibly updated) headers.
find_program(GO_EXECUTABLE go REQUIRED)
message(STATUS "Regenerating MLX Go wrappers")
execute_process(
COMMAND ${GO_EXECUTABLE} generate ./x/...
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
COMMAND_ERROR_IS_FATAL ANY
)
# For local dev builds, override MLX_VERSION with git describe output # For local dev builds, override MLX_VERSION with git describe output
if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX) if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
execute_process( execute_process(

View File

@@ -165,8 +165,8 @@ int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL; mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_ptr)(void) = NULL; bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL;
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL; mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL;
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL; void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL; void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL; int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
@@ -319,10 +319,12 @@ int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL; int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL; int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL; int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
@@ -348,7 +350,7 @@ int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL; int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL; int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL; int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL; int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
@@ -375,6 +377,8 @@ int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w,
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL; int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
@@ -434,8 +438,8 @@ int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, siz
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL; int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) = NULL;
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) = NULL;
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
@@ -2101,6 +2105,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
return -1; return -1;
} }
mlx_bartlett_ptr = GET_SYM(handle, "mlx_bartlett");
if (mlx_bartlett_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bartlett\n");
return -1;
}
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and"); mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
if (mlx_bitwise_and_ptr == NULL) { if (mlx_bitwise_and_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
@@ -2121,6 +2130,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
return -1; return -1;
} }
mlx_blackman_ptr = GET_SYM(handle, "mlx_blackman");
if (mlx_blackman_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_blackman\n");
return -1;
}
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm"); mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
if (mlx_block_masked_mm_ptr == NULL) { if (mlx_block_masked_mm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
@@ -2381,6 +2395,16 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
return -1; return -1;
} }
mlx_hamming_ptr = GET_SYM(handle, "mlx_hamming");
if (mlx_hamming_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hamming\n");
return -1;
}
mlx_hanning_ptr = GET_SYM(handle, "mlx_hanning");
if (mlx_hanning_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hanning\n");
return -1;
}
mlx_identity_ptr = GET_SYM(handle, "mlx_identity"); mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
if (mlx_identity_ptr == NULL) { if (mlx_identity_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
@@ -4132,12 +4156,12 @@ mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, i
return mlx_distributed_group_split_ptr(group, color, key); return mlx_distributed_group_split_ptr(group, color, key);
} }
bool mlx_distributed_is_available(void) { bool mlx_distributed_is_available(const char* bk) {
return mlx_distributed_is_available_ptr(); return mlx_distributed_is_available_ptr(bk);
} }
mlx_distributed_group mlx_distributed_init(bool strict) { mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) {
return mlx_distributed_init_ptr(strict); return mlx_distributed_init_ptr(strict, bk);
} }
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) { void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
@@ -4748,6 +4772,10 @@ int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
return mlx_atleast_3d_ptr(res, a, s); return mlx_atleast_3d_ptr(res, a, s);
} }
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
return mlx_bartlett_ptr(res, M, s);
}
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
return mlx_bitwise_and_ptr(res, a, b, s); return mlx_bitwise_and_ptr(res, a, b, s);
} }
@@ -4764,6 +4792,10 @@ int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const
return mlx_bitwise_xor_ptr(res, a, b, s); return mlx_bitwise_xor_ptr(res, a, b, s);
} }
int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
return mlx_blackman_ptr(res, M, s);
}
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) { int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s); return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
} }
@@ -4864,8 +4896,8 @@ int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_
return mlx_depends_ptr(res, inputs, dependencies); return mlx_depends_ptr(res, inputs, dependencies);
} }
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) { int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) {
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s); return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
} }
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
@@ -4972,6 +5004,14 @@ int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float
return mlx_hadamard_transform_ptr(res, a, scale, s); return mlx_hadamard_transform_ptr(res, a, scale, s);
} }
int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
return mlx_hamming_ptr(res, M, s);
}
int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
return mlx_hanning_ptr(res, M, s);
}
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
return mlx_identity_ptr(res, n, dtype, s); return mlx_identity_ptr(res, n, dtype, s);
} }
@@ -5208,12 +5248,12 @@ int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indice
return mlx_put_along_axis_ptr(res, a, indices, values, axis, s); return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
} }
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) {
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s); return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
} }
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) {
return mlx_quantize_ptr(res, w, group_size, bits, mode, s); return mlx_quantize_ptr(res, w, group_size, bits, mode, global_scale, s);
} }
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {

View File

@@ -2125,7 +2125,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
res := C.mlx_vector_array_new() res := C.mlx_vector_array_new()
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream()) var globalScale C.mlx_array
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, globalScale, C.default_stream())
// Result is a vector of arrays: [weights, scales, biases?] // Result is a vector of arrays: [weights, scales, biases?]
// mxfp8 mode returns only 2 elements (no biases) // mxfp8 mode returns only 2 elements (no biases)
@@ -2161,7 +2162,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
} }
res := C.mlx_array_new() res := C.mlx_array_new()
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream()) var globalScale C.mlx_array
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, globalScale, optDtype, C.default_stream())
return newArray(res) return newArray(res)
} }

View File

@@ -309,10 +309,12 @@
#undef mlx_atleast_1d #undef mlx_atleast_1d
#undef mlx_atleast_2d #undef mlx_atleast_2d
#undef mlx_atleast_3d #undef mlx_atleast_3d
#undef mlx_bartlett
#undef mlx_bitwise_and #undef mlx_bitwise_and
#undef mlx_bitwise_invert #undef mlx_bitwise_invert
#undef mlx_bitwise_or #undef mlx_bitwise_or
#undef mlx_bitwise_xor #undef mlx_bitwise_xor
#undef mlx_blackman
#undef mlx_block_masked_mm #undef mlx_block_masked_mm
#undef mlx_broadcast_arrays #undef mlx_broadcast_arrays
#undef mlx_broadcast_to #undef mlx_broadcast_to
@@ -365,6 +367,8 @@
#undef mlx_greater #undef mlx_greater
#undef mlx_greater_equal #undef mlx_greater_equal
#undef mlx_hadamard_transform #undef mlx_hadamard_transform
#undef mlx_hamming
#undef mlx_hanning
#undef mlx_identity #undef mlx_identity
#undef mlx_imag #undef mlx_imag
#undef mlx_inner #undef mlx_inner
@@ -751,8 +755,8 @@ extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x,
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group); extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group); extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key); extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
extern bool (*mlx_distributed_is_available_ptr)(void); extern bool (*mlx_distributed_is_available_ptr)(const char* bk);
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict); extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk);
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)); extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...); extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless); extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
@@ -905,10 +909,12 @@ extern int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype,
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s); extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s); extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s); extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
@@ -934,7 +940,7 @@ extern int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies); extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s); extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s); extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
@@ -961,6 +967,8 @@ extern int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_ar
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s); extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
extern int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
@@ -1020,8 +1028,8 @@ extern int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* ax
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s); extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -1492,9 +1500,9 @@ int mlx_distributed_group_size(mlx_distributed_group group);
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key); mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
bool mlx_distributed_is_available(void); bool mlx_distributed_is_available(const char* bk);
mlx_distributed_group mlx_distributed_init(bool strict); mlx_distributed_group mlx_distributed_init(bool strict, const char* bk);
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)); void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
@@ -1800,6 +1808,8 @@ int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -1808,6 +1818,8 @@ int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const m
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s); int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s); int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
@@ -1858,7 +1870,7 @@ int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies); int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s); int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -1912,6 +1924,10 @@ int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, cons
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s); int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -2030,9 +2046,9 @@ int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s); int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);

View File

@@ -93,21 +93,8 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1]) matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
} }
// Check for partial match within a node's edge — truncate path
// to the parent boundary. snapshot() will split the node and
// create the branch point during prefill when caches are ready.
partialMatch := false
if len(matchPath) > 1 {
lastNode := matchPath[len(matchPath)-1]
matchedInEdge := matched - lastNode.startOffset()
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
matchPath = matchPath[:len(matchPath)-1]
partialMatch = true
}
}
// Switch to the matched path, paging in/out as needed. // Switch to the matched path, paging in/out as needed.
c.switchToPath(matchPath) c.switchToPath(matchPath, matched)
// switchToPath aligns caches to a common offset // switchToPath aligns caches to a common offset
prefix := c.minCacheOffset() prefix := c.minCacheOffset()
@@ -116,7 +103,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
// Schedule a snapshot at the branch point during prefill so future // Schedule a snapshot at the branch point during prefill so future
// requests diverging here can restore instead of re-evaluating. // requests diverging here can restore instead of re-evaluating.
var snapshotAt int var snapshotAt int
if partialMatch || (prefix == 0 && matched > 0) { if prefix < matched {
snapshotAt = matched snapshotAt = matched
} }
@@ -142,7 +129,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
// switchToPath transitions from the current active path to a new path, // switchToPath transitions from the current active path to a new path,
// paging out diverging segments and paging in the new path. // paging out diverging segments and paging in the new path.
func (c *kvCache) switchToPath(newPath []*trieNode) { func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
defer c.enforceEvictionPolicy() defer c.enforceEvictionPolicy()
// Find common ancestor index. // Find common ancestor index.
@@ -167,7 +154,10 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
// non-leaf nodes here would produce wrong results for non-rewindable // non-leaf nodes here would produce wrong results for non-rewindable
// caches (e.g. RecurrentCache) whose state reflects the leaf, not // caches (e.g. RecurrentCache) whose state reflects the leaf, not
// the intermediate boundary. // the intermediate boundary.
if leaf := len(c.activePath) - 1; leaf >= commonLen { leaf := len(c.activePath) - 1
leafDiverges := leaf >= commonLen
leafNeedsRewind := matched < c.activePath[leaf].endOffset
if leafDiverges || leafNeedsRewind {
node := c.activePath[leaf] node := c.activePath[leaf]
if !node.hasAllSnapshots() { if !node.hasAllSnapshots() {
fromOffset := node.startOffset() fromOffset := node.startOffset()
@@ -184,14 +174,16 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
} }
} }
// Rewind each cache to the ancestor offset or free it. Freed // Rewind each cache to the target offset or free it. When matched
// caches (e.g. RecurrentCache that can't rewind) will be restored // falls within the ancestor's range (same-path case), we rewind
// from snapshots during page-in. // directly to the match point. Otherwise we rewind to the ancestor
// and let page-in bring us forward to matched.
rewindTarget := min(ancestorOffset, matched)
for _, kv := range c.caches { for _, kv := range c.caches {
if kv == nil { if kv == nil {
continue continue
} }
if !kv.Restore(nil, ancestorOffset) { if !kv.Restore(nil, rewindTarget) {
kv.Free() kv.Free()
} }
} }
@@ -199,10 +191,12 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
// Page in — walk the full new path, restoring from snapshots. // Page in — walk the full new path, restoring from snapshots.
// Freed caches naturally pick up the first available snapshot. // Freed caches naturally pick up the first available snapshot.
// Caches already past a node skip it via offset check. // Caches already past a node skip it via offset check.
pageIn:
for _, node := range newPath { for _, node := range newPath {
if len(node.snapshots) == 0 { if !node.hasSnapshots() {
continue continue
} }
nodeTarget := min(node.endOffset, matched)
for j, kv := range c.caches { for j, kv := range c.caches {
if kv == nil { if kv == nil {
continue continue
@@ -210,19 +204,18 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
if j >= len(node.snapshots) || node.snapshots[j] == nil { if j >= len(node.snapshots) || node.snapshots[j] == nil {
continue continue
} }
if kv.Offset() >= node.endOffset { if kv.Offset() >= nodeTarget {
continue continue
} }
if !kv.Restore(node.snapshots[j], node.endOffset) { if !kv.Restore(node.snapshots[j], nodeTarget) {
slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset()) // Restore failed — stop page-in and let alignment
c.freeAll() // bring all caches to a consistent offset.
c.activePath = []*trieNode{c.root} break pageIn
return
} }
} }
if node.endOffset > ancestorOffset { if node.endOffset > ancestorOffset {
pageInCount++ pageInCount++
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset)) logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget))
} }
} }
@@ -536,6 +529,9 @@ func (c *kvCache) dumpTree() {
if nodeBytes > 0 { if nodeBytes > 0 {
label += " " + mlx.PrettyBytes(int(nodeBytes)).String() label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
} }
if !n.lastUsed.IsZero() {
label += fmt.Sprintf(" %s ago", time.Since(n.lastUsed).Truncate(time.Millisecond))
}
var flags []string var flags []string
if n.user { if n.user {
flags = append(flags, "user") flags = append(flags, "user")

View File

@@ -17,7 +17,8 @@ type Cache interface {
Snapshot(fromOffset int) Snapshot Snapshot(fromOffset int) Snapshot
// Restore brings the cache to target. If snapshot is nil, rewinds // Restore brings the cache to target. If snapshot is nil, rewinds
// using the cache's own live state. // using the cache's own live state. Returns false if the target is
// unreachable (e.g. target > current offset, or negative).
Restore(snapshot Snapshot, target int) bool Restore(snapshot Snapshot, target int) bool
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c). // Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
@@ -122,17 +123,21 @@ func (c *KVCache) Snapshot(fromOffset int) Snapshot {
} }
func (c *KVCache) Restore(snapshot Snapshot, target int) bool { func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil { if snapshot == nil {
// Rewind using live state — just clamp offset. if target > c.offset {
target = max(0, min(target, c.offset)) return false
}
c.offset = target c.offset = target
return true return true
} }
snap := snapshot.(*kvSnapshot) snap := snapshot.(*kvSnapshot)
// Check that the cache has data up to the snapshot's starting point. if target > snap.toOffset || c.offset < snap.fromOffset {
if c.offset < snap.fromOffset {
return false return false
} }
@@ -354,7 +359,14 @@ func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
} }
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool { func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil { if snapshot == nil {
if target >= c.offset {
return target == c.offset
}
// Live rewind is only safe when the buffer hasn't filled yet // Live rewind is only safe when the buffer hasn't filled yet
// (offset <= maxSize). Once the window has shifted, rewinding // (offset <= maxSize). Once the window has shifted, rewinding
// leaves fewer than maxSize trailing tokens to attend to — // leaves fewer than maxSize trailing tokens to attend to —
@@ -362,7 +374,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
if c.offset > c.maxSize { if c.offset > c.maxSize {
return false return false
} }
target = max(0, min(target, c.offset))
c.offset = target c.offset = target
c.idx = target c.idx = target
return true return true
@@ -370,6 +381,10 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
snap := snapshot.(*rotatingSnapshot) snap := snapshot.(*rotatingSnapshot)
if target > snap.toOffset {
return false
}
// Reject if clamping would leave an incomplete window. // Reject if clamping would leave an incomplete window.
if target < snap.toOffset && snap.toOffset > c.maxSize { if target < snap.toOffset && snap.toOffset > c.maxSize {
return false return false
@@ -388,7 +403,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
// Clamp to target if needed. // Clamp to target if needed.
if target < c.offset { if target < c.offset {
target = max(0, target)
c.offset = target c.offset = target
c.idx = target c.idx = target
} }

View File

@@ -22,14 +22,9 @@ func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
if v == nil || !v.Valid() { if v == nil || !v.Valid() {
return old return old
} }
if old == v {
return old
}
mlx.Pin(v) mlx.Pin(v)
if old != nil && old != v {
mlx.Unpin(old) mlx.Unpin(old)
}
return v return v
} }
@@ -38,9 +33,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
if v == nil || !v.Valid() { if v == nil || !v.Valid() {
return old return old
} }
if old == v {
return old
}
root := v root := v
if ensureContiguous { if ensureContiguous {
@@ -49,9 +41,7 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
detached := root.Clone() detached := root.Clone()
mlx.Pin(detached) mlx.Pin(detached)
if old != nil && old != detached {
mlx.Unpin(old) mlx.Unpin(old)
}
return detached return detached
} }
@@ -150,10 +140,10 @@ func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
snap := snapshot.(*recurrentSnapshot) snap := snapshot.(*recurrentSnapshot)
// Recurrent state encodes all tokens up to snap.offset. Restoring // Recurrent snapshots encode cumulative state up to exactly
// to a target before that would leave stale state from tokens // snap.offset. Target must match — rewinding would leave stale
// [target, snap.offset) baked in. Only allow restoring forward. // state, and advancing isn't possible without feeding tokens.
if target < snap.offset { if target != snap.offset {
return false return false
} }

View File

@@ -6,39 +6,35 @@ import (
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
) )
// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only // TestRecurrentCacheRestoreExactOffset verifies that RecurrentCache restore
// allows restoring forward (target >= snapshot offset), never backward. // only succeeds when target exactly matches the snapshot's offset. Recurrent
func TestRecurrentCacheRestoreDirectionality(t *testing.T) { // state is cumulative, so it can't be rewound or fast-forwarded.
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
skipIfNoMLX(t) skipIfNoMLX(t)
c := NewRecurrentCache(3, 12, 4, 8, 8) c := NewRecurrentCache(3, 12, 4, 8, 8)
_ = c.ConvState(1, mlx.DTypeFloat16) _ = c.ConvState(1, mlx.DTypeFloat16)
_ = c.DeltaState(1, mlx.DTypeFloat16) _ = c.DeltaState(1, mlx.DTypeFloat16)
c.Advance(10) c.Advance(10)
snap := c.Snapshot(0) snap := c.Snapshot(0) // snap.offset == 10
c.Advance(5) // now at 15 c.Advance(5) // cache now at 15
// Restore backward should fail. // target < snap.offset: fails (can't rewind past snapshot)
if c.Restore(snap, 5) { if c.Restore(snap, 5) {
t.Fatal("Restore(snap, 5) should fail — target < snap.offset") t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
} }
// Restore to exact snap offset should succeed. // target > snap.offset: fails (can't advance without feeding tokens)
if c.Restore(snap, 15) {
t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
}
// target == snap.offset: succeeds
if !c.Restore(snap, 10) { if !c.Restore(snap, 10) {
t.Fatal("Restore(snap, 10) should succeed") t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
} }
if c.Offset() != 10 { if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c.Offset()) t.Fatalf("offset = %d, want 10", c.Offset())
} }
// Restore forward (target > snap offset) should succeed, offset = snap.offset.
snap2 := c.Snapshot(0)
if !c.Restore(snap2, 15) {
t.Fatal("Restore(snap, 15) should succeed")
}
// Recurrent state is at snap.offset (10), not target (15).
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset())
}
} }

View File

@@ -79,20 +79,20 @@ func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
} }
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool { func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
if snapshot == nil {
// Rewind live state.
if target < 0 { if target < 0 {
target = 0 return false
} }
if snapshot == nil {
if target > len(c.tokens) { if target > len(c.tokens) {
target = len(c.tokens) return false
} }
c.tokens = c.tokens[:target] c.tokens = c.tokens[:target]
return true return true
} }
s := snapshot.(*fakeSnapshot) s := snapshot.(*fakeSnapshot)
if len(c.tokens) < s.from { if target > s.to || len(c.tokens) < s.from {
return false // don't have base data up to snapshot start return false
} }
c.tokens = append(c.tokens[:s.from], s.tokens...) c.tokens = append(c.tokens[:s.from], s.tokens...)
if target < len(c.tokens) { if target < len(c.tokens) {
@@ -196,9 +196,13 @@ func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
} }
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool { func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil { if snapshot == nil {
if target == len(c.tokens) { if target >= len(c.tokens) {
return true return target == len(c.tokens)
} }
// Live rewind only works when buffer hasn't filled (offset <= maxSize). // Live rewind only works when buffer hasn't filled (offset <= maxSize).
if len(c.tokens) > c.maxSize { if len(c.tokens) > c.maxSize {
@@ -208,6 +212,14 @@ func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bo
return true return true
} }
s := snapshot.(*fakeSnapshot) s := snapshot.(*fakeSnapshot)
if target > s.to {
return false
}
// Reject if clamping would leave an incomplete window
// (matches RotatingKVCache behavior).
if target < s.to && s.to > c.maxSize {
return false
}
c.tokens = slices.Clone(s.tokens) c.tokens = slices.Clone(s.tokens)
if target < len(c.tokens) { if target < len(c.tokens) {
c.tokens = c.tokens[:target] c.tokens = c.tokens[:target]
@@ -268,8 +280,8 @@ func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
return target == len(c.tokens) // can only no-op return target == len(c.tokens) // can only no-op
} }
s := snapshot.(*fakeSnapshot) s := snapshot.(*fakeSnapshot)
if target < s.to { if target != s.to {
return false // can't go backward return false // cumulative state requires exact match
} }
c.tokens = slices.Clone(s.tokens) c.tokens = slices.Clone(s.tokens)
return true return true
@@ -297,6 +309,7 @@ type testEnv struct {
kvc *kvCache kvc *kvCache
caches []cache.Cache // typed references for assertions caches []cache.Cache // typed references for assertions
tracker *snapshotTracker tracker *snapshotTracker
rewindable bool // true when all caches support arbitrary Restore(nil, target)
} }
// newTransformerEnv creates a test environment with a single rewindable cache // newTransformerEnv creates a test environment with a single rewindable cache
@@ -308,20 +321,25 @@ func newTransformerEnv() *testEnv {
kvc: &kvCache{caches: caches}, kvc: &kvCache{caches: caches},
caches: caches, caches: caches,
tracker: tracker, tracker: tracker,
rewindable: true,
} }
} }
// newSlidingWindowEnv creates a test environment with one rewindable cache and // newSlidingWindowEnv creates a test environment with one rewindable cache and
// one sliding window cache (Mistral-style architecture). // one sliding window cache (Mistral-style architecture). The sliding window
// maxSize is set small enough that test sequences fill it, making
// Restore(nil, target) fail — the same behavior as production models where
// the window fills after a few turns.
func newSlidingWindowEnv() *testEnv { func newSlidingWindowEnv() *testEnv {
tr := &snapshotTracker{} tr := &snapshotTracker{}
rc := &fakeRewindableCache{tracker: tr} rc := &fakeRewindableCache{tracker: tr}
sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr} sw := &fakeSlidingWindowCache{maxSize: 4, tracker: tr}
caches := []cache.Cache{rc, sw} caches := []cache.Cache{rc, sw}
return &testEnv{ return &testEnv{
kvc: &kvCache{caches: caches}, kvc: &kvCache{caches: caches},
caches: caches, caches: caches,
tracker: tr, tracker: tr,
rewindable: false,
} }
} }
@@ -336,6 +354,7 @@ func newRecurrentEnv() *testEnv {
kvc: &kvCache{caches: caches}, kvc: &kvCache{caches: caches},
caches: caches, caches: caches,
tracker: tr, tracker: tr,
rewindable: false,
} }
} }
@@ -590,15 +609,24 @@ func TestBranchCreationAndReuse(t *testing.T) {
} }
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A. // Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
// Partial match in A's edge triggers snapshotOffset. // For rewindable caches, switchToPath rewinds to the match point
// so only the non-matching suffix needs evaluation. For non-rewindable
// caches (RecurrentCache), the rewind fails and freeAll fires.
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31}) resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
if env.rewindable {
if resB.snapshotOffset != 0 {
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
}
if len(resB.remaining) != 3 {
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
}
} else {
if resB.snapshotOffset != 5 { if resB.snapshotOffset != 5 {
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset) t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
} }
// Cache was rewound to 0 (partial match truncates path to root),
// so all tokens were re-evaluated.
if len(resB.remaining) != 8 { if len(resB.remaining) != 8 {
t.Fatalf("B: remaining = %d, want 8", len(resB.remaining)) t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
}
} }
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31}) env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
@@ -635,15 +663,25 @@ func TestExactMatchSeedBehavior(t *testing.T) {
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}) simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
// Request B: identical prompt. Holdback means matched=4, partial in // Request B: identical prompt. Holdback means matched=4, partial in
// the 5-token edge, so path truncates to root and all tokens are // the 5-token edge. For rewindable caches, switchToPath rewinds to
// re-evaluated. snapshotOffset should be set at the holdback point. // offset 4, so only the held-back token needs re-evaluation. For
// non-rewindable caches, the rewind fails and freeAll fires.
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21}) resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
if env.rewindable {
if len(resB.remaining) != 1 {
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
}
if resB.snapshotOffset != 0 {
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
}
} else {
if len(resB.remaining) != 5 { if len(resB.remaining) != 5 {
t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining)) t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
} }
if resB.snapshotOffset != 4 { if resB.snapshotOffset != 4 {
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset) t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
} }
}
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21}) env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
checkTrieInvariants(t, kvc.root) checkTrieInvariants(t, kvc.root)

View File

@@ -230,6 +230,9 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
resp, err := c.client.Do(httpReq) resp, err := c.client.Do(httpReq)
if err != nil { if err != nil {
if errMsg := c.status.getLastErr(); errMsg != "" {
return fmt.Errorf("mlx runner failed: %s", errMsg)
}
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -267,7 +270,13 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
} }
} }
return scanner.Err() if err := scanner.Err(); err != nil {
if errMsg := c.status.getLastErr(); errMsg != "" {
return fmt.Errorf("mlx runner failed: %s", errMsg)
}
return err
}
return nil
} }
func (c *Client) ContextLength() int { func (c *Client) ContextLength() int {

View File

@@ -15,7 +15,9 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent) include(FetchContent)
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "") # Read MLX-C version from top-level file (shared with imagegen CMakeLists)
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
FetchContent_Declare( FetchContent_Declare(
mlx-c mlx-c

View File

@@ -137,6 +137,9 @@ func Unpin(s ...*Array) {
for _, t := range s { for _, t := range s {
if t != nil { if t != nil {
t.pinned-- t.pinned--
if t.pinned < 0 {
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
}
} }
} }
} }
@@ -261,7 +264,7 @@ func LogArrays() {
for _, t := range arrays { for _, t := range arrays {
nb := t.NumBytes() nb := t.NumBytes()
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims())) logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
} }
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory()))) logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
} }

View File

@@ -13,6 +13,10 @@ var (
gatedDeltaMetalKernelOnce sync.Once gatedDeltaMetalKernelOnce sync.Once
gatedDeltaMetalKernel C.mlx_fast_metal_kernel gatedDeltaMetalKernel C.mlx_fast_metal_kernel
gatedDeltaMetalDisabled bool gatedDeltaMetalDisabled bool
gatedDeltaCUDAKernelOnce sync.Once
gatedDeltaCUDAKernel C.mlx_fast_cuda_kernel
gatedDeltaCUDADisabled bool
) )
const gatedDeltaMetalKernelSource = ` const gatedDeltaMetalKernelSource = `
@@ -83,6 +87,86 @@ for (int i = 0; i < n_per_t; ++i) {
} }
` `
const gatedDeltaCUDAKernelSource = `
auto tid_x = threadIdx.x;
auto tid_y = threadIdx.y;
auto grid_y = blockIdx.y * blockDim.y + tid_y;
auto grid_z = blockIdx.z;
int T_val = static_cast<int>(*T);
auto n = grid_z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto dv_idx = grid_y;
auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv;
y += b_idx * T_val * Hv * Dv + hv_idx * Dv;
auto dk_idx = tid_x;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T_val * Hv;
auto beta_ = beta + b_idx * T_val * Hv;
for (int t = 0; t < T_val; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * static_cast<float>(g_[hv_idx]);
kv_mem += state[i] * static_cast<float>(k_[s_idx]);
}
// Warp reduction (full warp, 32 threads in x)
for (int offset = 16; offset > 0; offset >>= 1)
kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset);
kv_mem = __shfl_sync(0xffffffff, kv_mem, 0);
auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]);
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta;
out += state[i] * static_cast<float>(q_[s_idx]);
}
// Warp reduction
for (int offset = 16; offset > 0; offset >>= 1)
out += __shfl_down_sync(0xffffffff, out, offset);
if (tid_x == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<InT>(state[i]);
}
`
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) { func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
vec := C.mlx_vector_string_new() vec := C.mlx_vector_string_new()
ok := true ok := true
@@ -352,11 +436,184 @@ func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
return Concatenate(outs, 1), nextState return Concatenate(outs, 1), nextState
} }
func initGatedDeltaCUDAKernel() {
var cudaAvail C.bool
if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) {
gatedDeltaCUDADisabled = true
return
}
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaCUDADisabled = true
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaCUDADisabled = true
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaCUDAKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.int(0),
)
}
func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaCUDADisabled {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
dtype := q.DType()
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
return nil, nil, false
}
gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
if gatedDeltaCUDADisabled {
return nil, nil, false
}
cfg := C.mlx_fast_cuda_kernel_config_new()
defer C.mlx_fast_cuda_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_CUDA_Y")
nextState = New("GATED_DELTA_CUDA_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}
// GatedDelta runs the recurrent update operation. // GatedDelta runs the recurrent update operation.
// //
// It uses the fused Metal kernel when available and otherwise falls back to a // It tries the fused CUDA kernel first, then Metal, then falls back to a
// backend-agnostic MLX implementation with identical inputs/outputs. // backend-agnostic MLX implementation with identical inputs/outputs.
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) { func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok {
return y, nextState
}
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok { if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
return y, nextState return y, nextState
} }

View File

@@ -326,8 +326,10 @@ int (*mlx_distributed_sum_scatter_)(
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL; mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL; bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL; mlx_distributed_group (*mlx_distributed_init_)(
bool strict,
const char* bk /* may be null */) = NULL;
void (*mlx_set_error_handler_)( void (*mlx_set_error_handler_)(
mlx_error_handler_func handler, mlx_error_handler_func handler,
void* data, void* data,
@@ -924,6 +926,7 @@ int (*mlx_astype_)(
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_bitwise_and_)( int (*mlx_bitwise_and_)(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -940,6 +943,7 @@ int (*mlx_bitwise_xor_)(
const mlx_array a, const mlx_array a,
const mlx_array b, const mlx_array b,
const mlx_stream s) = NULL; const mlx_stream s) = NULL;
int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_block_masked_mm_)( int (*mlx_block_masked_mm_)(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -1120,6 +1124,7 @@ int (*mlx_dequantize_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype, mlx_optional_dtype dtype,
const mlx_stream s) = NULL; const mlx_stream s) = NULL;
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
@@ -1256,6 +1261,8 @@ int (*mlx_hadamard_transform_)(
const mlx_array a, const mlx_array a,
mlx_optional_float scale, mlx_optional_float scale,
const mlx_stream s) = NULL; const mlx_stream s) = NULL;
int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_inner_)( int (*mlx_inner_)(
@@ -1548,6 +1555,8 @@ int (*mlx_qqmm_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s) = NULL; const mlx_stream s) = NULL;
int (*mlx_quantize_)( int (*mlx_quantize_)(
mlx_vector_array* res, mlx_vector_array* res,
@@ -1555,6 +1564,7 @@ int (*mlx_quantize_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s) = NULL; const mlx_stream s) = NULL;
int (*mlx_quantized_matmul_)( int (*mlx_quantized_matmul_)(
mlx_array* res, mlx_array* res,
@@ -2550,10 +2560,12 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_atleast_1d); CHECK_LOAD(handle, mlx_atleast_1d);
CHECK_LOAD(handle, mlx_atleast_2d); CHECK_LOAD(handle, mlx_atleast_2d);
CHECK_LOAD(handle, mlx_atleast_3d); CHECK_LOAD(handle, mlx_atleast_3d);
CHECK_LOAD(handle, mlx_bartlett);
CHECK_LOAD(handle, mlx_bitwise_and); CHECK_LOAD(handle, mlx_bitwise_and);
CHECK_LOAD(handle, mlx_bitwise_invert); CHECK_LOAD(handle, mlx_bitwise_invert);
CHECK_LOAD(handle, mlx_bitwise_or); CHECK_LOAD(handle, mlx_bitwise_or);
CHECK_LOAD(handle, mlx_bitwise_xor); CHECK_LOAD(handle, mlx_bitwise_xor);
CHECK_LOAD(handle, mlx_blackman);
CHECK_LOAD(handle, mlx_block_masked_mm); CHECK_LOAD(handle, mlx_block_masked_mm);
CHECK_LOAD(handle, mlx_broadcast_arrays); CHECK_LOAD(handle, mlx_broadcast_arrays);
CHECK_LOAD(handle, mlx_broadcast_to); CHECK_LOAD(handle, mlx_broadcast_to);
@@ -2606,6 +2618,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_greater); CHECK_LOAD(handle, mlx_greater);
CHECK_LOAD(handle, mlx_greater_equal); CHECK_LOAD(handle, mlx_greater_equal);
CHECK_LOAD(handle, mlx_hadamard_transform); CHECK_LOAD(handle, mlx_hadamard_transform);
CHECK_LOAD(handle, mlx_hamming);
CHECK_LOAD(handle, mlx_hanning);
CHECK_LOAD(handle, mlx_identity); CHECK_LOAD(handle, mlx_identity);
CHECK_LOAD(handle, mlx_imag); CHECK_LOAD(handle, mlx_imag);
CHECK_LOAD(handle, mlx_inner); CHECK_LOAD(handle, mlx_inner);

View File

@@ -300,10 +300,12 @@
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_ #define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_ #define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_ #define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
#define mlx_bartlett mlx_bartlett_mlx_gen_orig_
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_ #define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_ #define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_ #define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_ #define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
#define mlx_blackman mlx_blackman_mlx_gen_orig_
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_ #define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_ #define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_ #define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
@@ -356,6 +358,8 @@
#define mlx_greater mlx_greater_mlx_gen_orig_ #define mlx_greater mlx_greater_mlx_gen_orig_
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_ #define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_ #define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
#define mlx_hamming mlx_hamming_mlx_gen_orig_
#define mlx_hanning mlx_hanning_mlx_gen_orig_
#define mlx_identity mlx_identity_mlx_gen_orig_ #define mlx_identity mlx_identity_mlx_gen_orig_
#define mlx_imag mlx_imag_mlx_gen_orig_ #define mlx_imag mlx_imag_mlx_gen_orig_
#define mlx_inner mlx_inner_mlx_gen_orig_ #define mlx_inner mlx_inner_mlx_gen_orig_
@@ -889,10 +893,12 @@
#undef mlx_atleast_1d #undef mlx_atleast_1d
#undef mlx_atleast_2d #undef mlx_atleast_2d
#undef mlx_atleast_3d #undef mlx_atleast_3d
#undef mlx_bartlett
#undef mlx_bitwise_and #undef mlx_bitwise_and
#undef mlx_bitwise_invert #undef mlx_bitwise_invert
#undef mlx_bitwise_or #undef mlx_bitwise_or
#undef mlx_bitwise_xor #undef mlx_bitwise_xor
#undef mlx_blackman
#undef mlx_block_masked_mm #undef mlx_block_masked_mm
#undef mlx_broadcast_arrays #undef mlx_broadcast_arrays
#undef mlx_broadcast_to #undef mlx_broadcast_to
@@ -945,6 +951,8 @@
#undef mlx_greater #undef mlx_greater
#undef mlx_greater_equal #undef mlx_greater_equal
#undef mlx_hadamard_transform #undef mlx_hadamard_transform
#undef mlx_hamming
#undef mlx_hanning
#undef mlx_identity #undef mlx_identity
#undef mlx_imag #undef mlx_imag
#undef mlx_inner #undef mlx_inner
@@ -1501,8 +1509,10 @@ extern int (*mlx_distributed_sum_scatter_)(
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key); extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
extern bool (*mlx_distributed_is_available_)(void); extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */);
extern mlx_distributed_group (*mlx_distributed_init_)(bool strict); extern mlx_distributed_group (*mlx_distributed_init_)(
bool strict,
const char* bk /* may be null */);
extern void (*mlx_set_error_handler_)( extern void (*mlx_set_error_handler_)(
mlx_error_handler_func handler, mlx_error_handler_func handler,
void* data, void* data,
@@ -2099,6 +2109,7 @@ extern int (*mlx_astype_)(
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_bitwise_and_)( extern int (*mlx_bitwise_and_)(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -2115,6 +2126,7 @@ extern int (*mlx_bitwise_xor_)(
const mlx_array a, const mlx_array a,
const mlx_array b, const mlx_array b,
const mlx_stream s); const mlx_stream s);
extern int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_block_masked_mm_)( extern int (*mlx_block_masked_mm_)(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -2295,6 +2307,7 @@ extern int (*mlx_dequantize_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype, mlx_optional_dtype dtype,
const mlx_stream s); const mlx_stream s);
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -2431,6 +2444,8 @@ extern int (*mlx_hadamard_transform_)(
const mlx_array a, const mlx_array a,
mlx_optional_float scale, mlx_optional_float scale,
const mlx_stream s); const mlx_stream s);
extern int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_inner_)( extern int (*mlx_inner_)(
@@ -2723,6 +2738,8 @@ extern int (*mlx_qqmm_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s); const mlx_stream s);
extern int (*mlx_quantize_)( extern int (*mlx_quantize_)(
mlx_vector_array* res, mlx_vector_array* res,
@@ -2730,6 +2747,7 @@ extern int (*mlx_quantize_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s); const mlx_stream s);
extern int (*mlx_quantized_matmul_)( extern int (*mlx_quantized_matmul_)(
mlx_array* res, mlx_array* res,
@@ -4033,11 +4051,13 @@ static inline int mlx_distributed_group_size(mlx_distributed_group group) {
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
return mlx_distributed_group_split_(group, color, key); return mlx_distributed_group_split_(group, color, key);
} }
static inline bool mlx_distributed_is_available(void) { static inline bool mlx_distributed_is_available(const char* bk /* may be null */) {
return mlx_distributed_is_available_(); return mlx_distributed_is_available_(bk);
} }
static inline mlx_distributed_group mlx_distributed_init(bool strict) { static inline mlx_distributed_group mlx_distributed_init(
return mlx_distributed_init_(strict); bool strict,
const char* bk /* may be null */) {
return mlx_distributed_init_(strict, bk);
} }
static inline void mlx_set_error_handler( static inline void mlx_set_error_handler(
mlx_error_handler_func handler, mlx_error_handler_func handler,
@@ -4939,6 +4959,9 @@ static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_st
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) { static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
return mlx_atleast_3d_(res, a, s); return mlx_atleast_3d_(res, a, s);
} }
static inline int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
return mlx_bartlett_(res, M, s);
}
static inline int mlx_bitwise_and( static inline int mlx_bitwise_and(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -4963,6 +4986,9 @@ static inline int mlx_bitwise_xor(
const mlx_stream s) { const mlx_stream s) {
return mlx_bitwise_xor_(res, a, b, s); return mlx_bitwise_xor_(res, a, b, s);
} }
static inline int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
return mlx_blackman_(res, M, s);
}
static inline int mlx_block_masked_mm( static inline int mlx_block_masked_mm(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -5193,9 +5219,10 @@ static inline int mlx_dequantize(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype, mlx_optional_dtype dtype,
const mlx_stream s) { const mlx_stream s) {
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s); return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
} }
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
return mlx_diag_(res, a, k, s); return mlx_diag_(res, a, k, s);
@@ -5383,6 +5410,12 @@ static inline int mlx_hadamard_transform(
const mlx_stream s) { const mlx_stream s) {
return mlx_hadamard_transform_(res, a, scale, s); return mlx_hadamard_transform_(res, a, scale, s);
} }
static inline int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
return mlx_hamming_(res, M, s);
}
static inline int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
return mlx_hanning_(res, M, s);
}
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
return mlx_identity_(res, n, dtype, s); return mlx_identity_(res, n, dtype, s);
} }
@@ -5793,8 +5826,10 @@ static inline int mlx_qqmm(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s) { const mlx_stream s) {
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s); return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
} }
static inline int mlx_quantize( static inline int mlx_quantize(
mlx_vector_array* res, mlx_vector_array* res,
@@ -5802,8 +5837,9 @@ static inline int mlx_quantize(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s) { const mlx_stream s) {
return mlx_quantize_(res, w, group_size, bits, mode, s); return mlx_quantize_(res, w, group_size, bits, mode, global_scale, s);
} }
static inline int mlx_quantized_matmul( static inline int mlx_quantized_matmul(
mlx_array* res, mlx_array* res,

View File

@@ -1,7 +1,7 @@
# Vendored MLX-C Headers # Vendored MLX-C Headers
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c). These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
The pinned version is in `MLX_VERSION` at the repo root. The pinned version is in `MLX_C_VERSION` at the repo root.
Headers are automatically refreshed when you run a CMake build: Headers are automatically refreshed when you run a CMake build:

View File

@@ -42,12 +42,14 @@ mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
/** /**
* Check if distributed is available. * Check if distributed is available.
*/ */
bool mlx_distributed_is_available(void); bool mlx_distributed_is_available(const char* bk /* may be null */);
/** /**
* Initialize distributed. * Initialize distributed.
*/ */
mlx_distributed_group mlx_distributed_init(bool strict); mlx_distributed_group mlx_distributed_init(
bool strict,
const char* bk /* may be null */);
/**@}*/ /**@}*/

View File

@@ -166,6 +166,7 @@ int mlx_astype(
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
int mlx_bitwise_and( int mlx_bitwise_and(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -182,6 +183,7 @@ int mlx_bitwise_xor(
const mlx_array a, const mlx_array a,
const mlx_array b, const mlx_array b,
const mlx_stream s); const mlx_stream s);
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
int mlx_block_masked_mm( int mlx_block_masked_mm(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -362,6 +364,7 @@ int mlx_dequantize(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype, mlx_optional_dtype dtype,
const mlx_stream s); const mlx_stream s);
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -498,6 +501,8 @@ int mlx_hadamard_transform(
const mlx_array a, const mlx_array a,
mlx_optional_float scale, mlx_optional_float scale,
const mlx_stream s); const mlx_stream s);
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_inner( int mlx_inner(
@@ -790,6 +795,8 @@ int mlx_qqmm(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s); const mlx_stream s);
int mlx_quantize( int mlx_quantize(
mlx_vector_array* res, mlx_vector_array* res,
@@ -797,6 +804,7 @@ int mlx_quantize(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s); const mlx_stream s);
int mlx_quantized_matmul( int mlx_quantized_matmul(
mlx_array* res, mlx_array* res,

View File

@@ -4,35 +4,91 @@ package mlx
import "C" import "C"
import ( import (
"fmt"
"iter" "iter"
"runtime" "runtime"
"unsafe" "unsafe"
) )
func Load(path string) iter.Seq2[string, *Array] { // SafetensorsFile represents a loaded safetensors file.
return func(yield func(string, *Array) bool) { type SafetensorsFile struct {
string2array := C.mlx_map_string_to_array_new() arrays C.mlx_map_string_to_array
defer C.mlx_map_string_to_array_free(string2array) metadata C.mlx_map_string_to_string
}
string2string := C.mlx_map_string_to_string_new() func loadSafetensorsStream() C.mlx_stream {
defer C.mlx_map_string_to_string_free(string2string) if runtime.GOOS == "darwin" {
return C.mlx_default_cpu_stream_new()
}
return C.mlx_default_gpu_stream_new()
}
// LoadSafetensorsNative loads a safetensors file using MLX's native loader.
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
var arrays C.mlx_map_string_to_array
var metadata C.mlx_map_string_to_string
cPath := C.CString(path) cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath)) defer C.free(unsafe.Pointer(cPath))
// Use GPU stream so tensors load directly to GPU memory (CUDA has Load::eval_gpu). stream := loadSafetensorsStream()
// macOS Metal doesn't implement eval_gpu for Load, so fall back to CPU stream.
var stream C.mlx_stream
if runtime.GOOS == "darwin" {
stream = C.mlx_default_cpu_stream_new()
} else {
stream = C.mlx_default_gpu_stream_new()
}
defer C.mlx_stream_free(stream) defer C.mlx_stream_free(stream)
C.mlx_load_safetensors(&string2array, &string2string, cPath, stream) if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
return nil, fmt.Errorf("failed to load safetensors: %s", path)
}
it := C.mlx_map_string_to_array_iterator_new(string2array) return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
}
// Get retrieves a tensor by name.
func (s *SafetensorsFile) Get(name string) *Array {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
value := C.mlx_array_new()
if C.mlx_map_string_to_array_get(&value, s.arrays, cName) != 0 {
return nil
}
if value.ctx == nil {
return nil
}
arr := New(name)
arr.ctx = value
return arr
}
// GetMetadata retrieves a metadata value by key.
func (s *SafetensorsFile) GetMetadata(key string) string {
cKey := C.CString(key)
defer C.free(unsafe.Pointer(cKey))
var cValue *C.char
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
return ""
}
return C.GoString(cValue)
}
// Free releases the loaded safetensors maps.
func (s *SafetensorsFile) Free() {
if s == nil {
return
}
C.mlx_map_string_to_array_free(s.arrays)
C.mlx_map_string_to_string_free(s.metadata)
}
func Load(path string) iter.Seq2[string, *Array] {
return func(yield func(string, *Array) bool) {
sf, err := LoadSafetensorsNative(path)
if err != nil {
return
}
defer sf.Free()
it := C.mlx_map_string_to_array_iterator_new(sf.arrays)
defer C.mlx_map_string_to_array_iterator_free(it) defer C.mlx_map_string_to_array_iterator_free(it)
for { for {
@@ -51,3 +107,43 @@ func Load(path string) iter.Seq2[string, *Array] {
} }
} }
} }
// SaveSafetensors saves arrays to a safetensors file without metadata.
func SaveSafetensors(path string, arrays map[string]*Array) error {
return SaveSafetensorsWithMetadata(path, arrays, nil)
}
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata.
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
for name, arr := range arrays {
if arr == nil {
continue
}
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
C.free(unsafe.Pointer(cName))
}
cMetadata := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMetadata)
for key, value := range metadata {
cKey := C.CString(key)
cValue := C.CString(value)
C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
C.free(unsafe.Pointer(cKey))
C.free(unsafe.Pointer(cValue))
}
if C.mlx_save_safetensors(cPath, cArrays, cMetadata) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}

View File

@@ -7,8 +7,44 @@ package mlx
// #cgo LDFLAGS: -lstdc++ // #cgo LDFLAGS: -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate // #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h" // #include "generated.h"
// #include <string.h>
//
// static char _mlx_last_error_msg[1024] = {0};
// static int _mlx_last_error_flag = 0;
//
// static void _mlx_capture_error_handler(const char* msg, void* data) {
// (void)data;
// strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1);
// _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0';
// _mlx_last_error_flag = 1;
// }
//
// static void mlx_install_capture_handler(void) {
// if (mlx_set_error_handler_) {
// mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL);
// }
// }
//
// static void mlx_clear_last_error(void) {
// _mlx_last_error_flag = 0;
// _mlx_last_error_msg[0] = '\0';
// }
//
// static int mlx_had_last_error(void) {
// return _mlx_last_error_flag;
// }
//
// static const char* mlx_get_last_error(void) {
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
// }
import "C" import "C"
func init() {
// Replace the default exit(-1) error handler with one that captures
// the error message so we can surface it in Go.
C.mlx_install_capture_handler()
}
// Version returns the MLX core library version string. // Version returns the MLX core library version string.
func Version() string { func Version() string {
str := C.mlx_string_new() str := C.mlx_string_new()
@@ -31,10 +67,19 @@ func doEval(outputs []*Array, async bool) {
} }
} }
C.mlx_clear_last_error()
var rc C.int
if async { if async {
C.mlx_async_eval(vector) rc = C.mlx_async_eval(vector)
} else { } else {
C.mlx_eval(vector) rc = C.mlx_eval(vector)
}
if rc != 0 {
msg := "mlx eval failed"
if C.mlx_had_last_error() != 0 {
msg = C.GoString(C.mlx_get_last_error())
}
panic("mlx: " + msg)
} }
} }

View File

@@ -17,7 +17,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
res := C.mlx_vector_array_new() res := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(res) defer C.mlx_vector_array_free(res)
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx) var globalScale C.mlx_array
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
vecSize := int(C.mlx_vector_array_size(res)) vecSize := int(C.mlx_vector_array_size(res))
w0 := New("QUANTIZE_W") w0 := New("QUANTIZE_W")
@@ -32,6 +33,18 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
return w0, w1, nil return w0, w1, nil
} }
func FromFP8(x *Array, dtype DType) *Array {
out := New("FROM_FP8")
C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
func ToFP8(x *Array) *Array {
out := New("TO_FP8")
C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx)
return out
}
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array { func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode) cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode)) defer C.free(unsafe.Pointer(cMode))
@@ -45,7 +58,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
} }
out := New("DEQUANTIZE") out := New("DEQUANTIZE")
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx) var globalScale C.mlx_array
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
return out return out
} }
@@ -135,6 +149,40 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
return out return out
} }
func Pad(a *Array, paddings []int32) *Array {
numAxes := len(paddings) / 2
axes := make([]C.int, numAxes)
lowPad := make([]C.int, numAxes)
highPad := make([]C.int, numAxes)
for i := range numAxes {
axes[i] = C.int(i)
lowPad[i] = C.int(paddings[i*2])
highPad[i] = C.int(paddings[i*2+1])
}
padValue := C.mlx_array_new_float(C.float(0))
defer C.mlx_array_free(padValue)
cMode := C.CString("constant")
defer C.free(unsafe.Pointer(cMode))
out := New("PAD")
C.mlx_pad(
&out.ctx,
a.ctx,
unsafe.SliceData(axes),
C.size_t(len(axes)),
unsafe.SliceData(lowPad),
C.size_t(len(lowPad)),
unsafe.SliceData(highPad),
C.size_t(len(highPad)),
padValue,
cMode,
DefaultStream().ctx,
)
return out
}
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array { func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
groups := int32(x.Dim(x.NumDims() - 1)) groups := int32(x.Dim(x.NumDims() - 1))
return Conv1d(x, weight, bias, 1, 0, 1, groups) return Conv1d(x, weight, bias, 1, 0, 1, groups)

View File

@@ -11,8 +11,10 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string)
switch strings.ToUpper(quantization) { switch strings.ToUpper(quantization) {
case "NVFP4": case "NVFP4":
return 16, 4, "nvfp4" return 16, 4, "nvfp4"
case "MXFP4":
return 32, 4, "mxfp4"
case "FP4", "Q4", "INT4": case "FP4", "Q4", "INT4":
return 32, 4, "affine" return 64, 4, "affine"
case "MXFP8": case "MXFP8":
return 32, 8, "mxfp8" return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8": case "FP8", "Q8", "INT8":

View File

@@ -144,3 +144,44 @@ func TestLayerNormDefaultEps(t *testing.T) {
} }
} }
} }
func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
skipIfNoMLX(t)
weightVals := make([]float32, 3*32)
for i := range weightVals {
weightVals[i] = float32((i%11)-5) / 7
}
inputVals := make([]float32, 2*32)
for i := range inputVals {
inputVals[i] = float32((i%7)-3) / 5
}
weight := mlx.FromValues(weightVals, 3, 32).AsType(mlx.DTypeBFloat16)
input := mlx.FromValues(inputVals, 2, 32).AsType(mlx.DTypeBFloat16)
mlx.Eval(weight, input)
ql := NewQuantizedLinear(weight, nil, 32, 4, "mxfp4")
if ql.QBiases != nil {
t.Fatalf("mxfp4 qbiases = %v, want nil", ql.QBiases)
}
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
mlx.Eval(dequantizedWeight)
qOut := ql.Forward(input)
dOut := NewLinear(dequantizedWeight, nil).Forward(input)
mlx.Eval(qOut, dOut)
got := qOut.Floats()
want := dOut.Floats()
if len(got) != len(want) {
t.Fatalf("output length = %d, want %d", len(got), len(want))
}
for i := range got {
if !approxEqual(got[i], want[i], 1e-3) {
t.Fatalf("output[%d] = %.6f, want %.6f", i, got[i], want[i])
}
}
}

View File

@@ -420,7 +420,16 @@ func tensorByBase(tensors map[string]*mlx.Array, base string) (*mlx.Array, strin
} }
func supportsGatherQMM(mode string, bits int) bool { func supportsGatherQMM(mode string, bits int) bool {
return mode == "affine" && (bits == 4 || bits == 8) switch mode {
case "affine":
return bits == 4 || bits == 8
case "mxfp8":
return bits == 8
case "nvfp4", "mxfp4":
return bits == 4
default:
return false
}
} }
func freeTensorKeys(tensors map[string]*mlx.Array, keys ...string) { func freeTensorKeys(tensors map[string]*mlx.Array, keys ...string) {

View File

@@ -83,6 +83,28 @@ func TestLayerSelectionHelpers(t *testing.T) {
} }
} }
func TestSupportsGatherQMM(t *testing.T) {
tests := []struct {
mode string
bits int
want bool
}{
{mode: "affine", bits: 4, want: true},
{mode: "affine", bits: 8, want: true},
{mode: "mxfp8", bits: 8, want: true},
{mode: "nvfp4", bits: 4, want: true},
{mode: "mxfp4", bits: 4, want: true},
{mode: "mxfp8", bits: 4, want: false},
{mode: "affine", bits: 3, want: false},
}
for _, tt := range tests {
if got := supportsGatherQMM(tt.mode, tt.bits); got != tt.want {
t.Fatalf("supportsGatherQMM(%q, %d) = %v, want %v", tt.mode, tt.bits, got, tt.want)
}
}
}
func TestResolveTensorPathLayout(t *testing.T) { func TestResolveTensorPathLayout(t *testing.T) {
dummy := mlx.New("dummy") dummy := mlx.New("dummy")