Mirror of https://github.com/ollama/ollama.git

Compare commits: pdevine/ml ... v0.18.3-rc (6 commits)
Commits (author and date columns were empty in this view):

- de5cb7311f
- 95ee7fbd29
- ec55536734
- 77491439c2
- b166b36cd2
- c2b0bb7a52
Dockerfile:

@@ -157,7 +157,7 @@ COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum .
-COPY MLX_VERSION MLX_CORE_VERSION .
+COPY MLX_VERSION MLX_C_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
MLX_CORE_VERSION (deleted):

@@ -1 +0,0 @@
-v0.30.6
MLX_C_VERSION (new file, 1 line):
@@ -0,0 +1 @@
+0726ca922fc902c4c61ef9c27d94132be418e945

MLX_VERSION:

@@ -1 +1 @@
-v0.5.0
+38ad257088fb2193ad47e527cf6534a689f30943
Documentation (Claude Code integration):

@@ -96,6 +96,18 @@ The `/loop` command runs a prompt or slash command on a recurring schedule inside
/loop 1h Remind me to review the deploy status
```

+## Telegram
+
+Chat with Claude Code from Telegram by connecting a bot to your session. Install the [Telegram plugin](https://github.com/anthropics/claude-plugins-official), create a bot via [@BotFather](https://t.me/BotFather), then launch with the channel flag:
+
+```shell
+ollama launch claude -- --channels plugin:telegram@claude-plugins-official
+```
+
+Claude Code will prompt for permission on most actions. To allow the bot to work autonomously, configure [permission rules](https://code.claude.com/docs/en/permissions) or pass `--dangerously-skip-permissions` in isolated environments.
+
+See the [plugin README](https://github.com/anthropics/claude-plugins-official/tree/main/external_plugins/telegram) for full setup instructions including pairing and access control.
+
## Manual setup

Claude Code connects to Ollama using the Anthropic-compatible API.
@@ -109,7 +109,7 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
type CreateOptions struct {
    ModelName string
    ModelDir  string
-   Quantize  string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
+   Quantize  string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
    Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
}
@@ -280,7 +280,7 @@ func newPackedTensorLayerCreator() create.PackedTensorLayerCreator {
        if !QuantizeSupported() {
            return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
        }
-       blobData, err := quantizePackedGroup(tensors)
+       blobData, err := quantizePackedGroup(groupName, tensors)
        if err != nil {
            return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
        }
@@ -7,29 +7,27 @@ import (
    "io"
    "os"
    "path/filepath"
    "regexp"
    "sort"
    "strconv"
    "strings"

    "github.com/ollama/ollama/x/create"
-   "github.com/ollama/ollama/x/imagegen/mlx"
+   "github.com/ollama/ollama/x/mlxrunner/mlx"
+   "github.com/ollama/ollama/x/mlxrunner/model"
)

-// quantizeParams maps quantization type names to MLX quantize parameters.
-var quantizeParams = map[string]struct {
-   groupSize int
-   bits      int
-   mode      string
-}{
-   "int4":  {64, 4, "affine"},
-   "nvfp4": {16, 4, "nvfp4"},
-   "int8":  {64, 8, "affine"},
-   "mxfp8": {32, 8, "mxfp8"},
-}

// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
// to the provided maps. If quantize is empty, the tensor is kept as-is.
// Returns any temp file paths created (caller must clean up) and arrays needing eval.
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
+   if quantize != "" {
+       if gs, _, _ := model.QuantizationParams(quantize); gs == 0 {
+           return "", nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
+       }
+   }

    tmpDir := ensureTempDir()

    tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
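For readers skimming the diff: the table that the removed `quantizeParams` map encoded now lives behind `model.QuantizationParams`, whose `(groupSize, bits, mode)` return shape can be read off the new call sites. A minimal, self-contained sketch of that convention — the parameter values are copied from the removed map, while the function shape and the "group size zero means unsupported" rule are inferred from the call sites above, not from the `model` package itself:

```go
package main

import "fmt"

// quantParams mirrors the removed quantizeParams table:
// group size, bit width, and MLX quantization mode per type name.
var quantParams = map[string]struct {
	groupSize, bits int
	mode            string
}{
	"int4":  {64, 4, "affine"},
	"nvfp4": {16, 4, "nvfp4"},
	"int8":  {64, 8, "affine"},
	"mxfp8": {32, 8, "mxfp8"},
}

// params returns (0, 0, "") for unknown types, matching the
// "group size == 0 means unsupported" check used in the diff.
func params(quantize string) (int, int, string) {
	p, ok := quantParams[quantize]
	if !ok {
		return 0, 0, ""
	}
	return p.groupSize, p.bits, p.mode
}

func main() {
	for _, q := range []string{"int4", "mxfp8", "q4"} {
		if gs, _, _ := params(q); gs == 0 {
			fmt.Printf("%s: unsupported\n", q)
			continue
		}
		fmt.Printf("%s: ok\n", q)
	}
}
```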
@@ -50,11 +48,16 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
    }

    // Find the tensor key (may differ from name for single-tensor blobs)
-   inputKey, err := findSafetensorsKey(tmpPath)
+   header, err := readSafetensorsHeader(tmpPath)
    if err != nil {
        st.Free()
        return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
    }
+   inputKey, err := safetensorsKey(name, header)
+   if err != nil {
+       st.Free()
+       return tmpPath, nil, nil, fmt.Errorf("failed to resolve tensor key for %s: %w", name, err)
+   }

    arr := st.Get(inputKey)
    if arr == nil {
@@ -62,34 +65,46 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
        return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
    }

+   // Decode FP8 source encoding before checking quantize, so that callers
+   // requesting decode-only (quantize="") receive usable float data.
+   if info, ok := header[inputKey]; ok && info.Dtype == "F8_E4M3" {
+       scaleKey := inputKey + ".scale_inv"
+       scaleInv := st.Get(scaleKey)
+       if scaleInv == nil {
+           st.Free()
+           return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q for fp8 source tensor %q", scaleKey, inputKey)
+       }
+       arr, err = decodeSourceFP8Tensor(arr, scaleInv)
+       if err != nil {
+           st.Free()
+           return tmpPath, nil, nil, fmt.Errorf("failed to decode fp8 tensor %s: %w", inputKey, err)
+       }
+       mlx.Eval(arr)
+   }

    if quantize == "" {
-       arr = mlx.Contiguous(arr)
+       arr = mlx.Contiguous(arr, false)
        arrays[name] = arr
        return tmpPath, []*mlx.Array{arr}, st, nil
    }

-   // Convert to float type if needed (quantize expects float)
-   if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
-       arr = mlx.AsType(arr, mlx.DtypeBFloat16)
+   if arr.DType() != mlx.DTypeBFloat16 && arr.DType() != mlx.DTypeFloat32 && arr.DType() != mlx.DTypeFloat16 {
+       // Convert to float type if needed (quantize expects float)
+       arr = arr.AsType(mlx.DTypeBFloat16)
+       mlx.Eval(arr)
    }

-   params, ok := quantizeParams[quantize]
-   if !ok {
-       st.Free()
-       return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
-   }
-
-   qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode)
+   groupSize, bits, mode := model.QuantizationParams(quantize)
+   qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)

-   qweight = mlx.Contiguous(qweight)
-   scales = mlx.Contiguous(scales)
+   qweight = mlx.Contiguous(qweight, false)
+   scales = mlx.Contiguous(scales, false)
    arrays[name] = qweight
    arrays[name+".scale"] = scales
    toEval = append(toEval, qweight, scales)

    if qbiases != nil {
-       qbiases = mlx.Contiguous(qbiases)
+       qbiases = mlx.Contiguous(qbiases, false)
        arrays[name+".bias"] = qbiases
        toEval = append(toEval, qbiases)
    }
@@ -101,27 +116,45 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
// Tensor keys use the original tensor name: name, name.scale, name.bias.
// The blob includes __metadata__ with quant_type and group_size.
-// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8".
+// Supported quantization types: "int4", "nvfp4", "mxfp4", "int8", "mxfp8".
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
    arrays := make(map[string]*mlx.Array)
    tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
    if tmpPath != "" {
        defer os.Remove(tmpPath)
    }
-   if st != nil {
-       defer st.Free()
-   }
    if err != nil {
        return nil, err
    }

+   finalArrays := make([]*mlx.Array, 0, len(arrays))
+   for _, arr := range arrays {
+       if arr != nil {
+           finalArrays = append(finalArrays, arr)
+       }
+   }
+   mlx.Pin(finalArrays...)
+   defer func() {
+       if st != nil {
+           st.Free()
+       }
+       mlx.Unpin(finalArrays...)
+       mlx.Sweep()
+   }()

    mlx.Eval(toEval...)
+   mlx.Sweep()
+   // Free early to release mmap; defer guard handles error paths
+   if st != nil {
+       st.Free()
+       st = nil
+   }

    // Build metadata for single-tensor blobs
-   params := quantizeParams[quantize]
+   groupSize, _, _ := model.QuantizationParams(quantize)
    metadata := map[string]string{
        "quant_type": quantize,
-       "group_size": strconv.Itoa(params.groupSize),
+       "group_size": strconv.Itoa(groupSize),
    }

    tmpDir := ensureTempDir()
@@ -135,48 +168,81 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti

// quantizePackedGroup quantizes multiple tensors and saves them all into a single
// combined safetensors blob. Used for packing expert groups.
+// When the inputs are per-expert 2D tensors (e.g., experts.0.gate_proj.weight),
+// they are stacked into 3D switch_mlp tensors before quantization.
// Each tensor may have a different quantization type (mixed-precision).
-// Returns the blob bytes. No __metadata__ is added because different tensors
-// may use different quantization types.
-func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
+// Returns the blob bytes.
+func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
+   // Check if inputs are per-expert tensors that should be stacked into 3D
+   if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
+       return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
+   }

    allArrays := make(map[string]*mlx.Array)
    var allToEval []*mlx.Array
    var tmpPaths []string
    var handles []*mlx.SafetensorsFile
+   var pinned []*mlx.Array

+   var metadata map[string]string
+   uniformQuantize := ""
+   hasQuantized := false
+   mixedQuantize := false
+   for _, input := range inputs {
+       if input.Quantize == "" {
+           if hasQuantized {
+               mixedQuantize = true
+           }
+           continue
+       }
+       if !hasQuantized {
+           hasQuantized = true
+           uniformQuantize = input.Quantize
+           continue
+       }
+       if input.Quantize != uniformQuantize {
+           mixedQuantize = true
+       }
+   }
+   if hasQuantized && !mixedQuantize {
+       if groupSize, _, _ := model.QuantizationParams(uniformQuantize); groupSize > 0 {
+           metadata = map[string]string{
+               "quant_type": uniformQuantize,
+               "group_size": strconv.Itoa(groupSize),
+           }
+       }
+   }

    for _, input := range inputs {
        tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
        if tmpPath != "" {
            tmpPaths = append(tmpPaths, tmpPath)
        }
        if st != nil {
            handles = append(handles, st)
        }
        if err != nil {
            // Cleanup on error
            for _, h := range handles {
                h.Free()
            }
            for _, p := range tmpPaths {
                os.Remove(p)
            }
+           mlx.Unpin(pinned...)
+           mlx.Sweep()
            return nil, err
        }
        allToEval = append(allToEval, toEval...)

+       mlx.Eval(toEval...)

+       finalArrays := arraysForPackedInput(allArrays, input)
+       mlx.Pin(finalArrays...)
+       pinned = append(pinned, finalArrays...)

+       if st != nil {
+           st.Free()
+       }
+       if tmpPath != "" {
+           os.Remove(tmpPath)
+       }
+       mlx.Sweep()
    }
+   defer func() {
+       mlx.Unpin(pinned...)
+       mlx.Sweep()
+   }()

    mlx.Eval(allToEval...)

    // Free native handles after eval
    for _, h := range handles {
        h.Free()
    }

-   // Save combined blob (no global metadata for mixed-precision packed blobs)
+   // Save combined blob. Add global metadata only when every packed tensor uses
+   // the same quantization mode and group size.
    tmpDir := ensureTempDir()
    outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
    defer os.Remove(outPath)
-   if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil {
+   if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
        return nil, fmt.Errorf("failed to save packed blob: %w", err)
    }
@@ -185,17 +251,193 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
        return nil, fmt.Errorf("failed to read packed blob: %w", err)
    }

-   for _, p := range tmpPaths {
-       os.Remove(p)
-   }
    return blobData, nil
}

func arraysForPackedInput(allArrays map[string]*mlx.Array, input create.PackedTensorInput) []*mlx.Array {
    keys := []string{input.Name}
    if input.Quantize != "" {
        keys = append(keys, input.Name+".scale", input.Name+".bias")
    }

    out := make([]*mlx.Array, 0, len(keys))
    for _, key := range keys {
        if arr := allArrays[key]; arr != nil {
            out = append(out, arr)
        }
    }
    return out
}

// perExpertSuffix matches ".{index}.{proj_and_suffix}" after the group prefix.
var perExpertSuffix = regexp.MustCompile(`^\.(\d+)\.(.+)$`)

type expertTensorInfo struct {
    index int
    proj  string // e.g., "gate_proj.weight"
    input create.PackedTensorInput
}

// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
// and returns the uniform quantization type shared by all inputs.
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
// or if the inputs have mixed quantization types.
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
    if !strings.HasSuffix(groupName, ".experts") {
        return nil, ""
    }

    quantize := inputs[0].Quantize
    groups := make(map[string][]expertTensorInfo)
    for _, input := range inputs {
        if input.Quantize != quantize {
            return nil, "" // mixed quantization types
        }
        suffix := strings.TrimPrefix(input.Name, groupName)
        m := perExpertSuffix.FindStringSubmatch(suffix)
        if m == nil {
            return nil, "" // not a per-expert pattern
        }
        index, err := strconv.Atoi(m[1])
        if err != nil {
            return nil, ""
        }
        groups[m[2]] = append(groups[m[2]], expertTensorInfo{
            index: index,
            proj:  m[2],
            input: input,
        })
    }
    if len(groups) == 0 {
        return nil, ""
    }
    return groups, quantize
}

// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
    groupBase := strings.TrimSuffix(groupName, ".experts")

    allArrays := make(map[string]*mlx.Array)
    var pinned []*mlx.Array

    var metadata map[string]string
    if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
        metadata = map[string]string{
            "quant_type": quantize,
            "group_size": strconv.Itoa(groupSize),
        }
    }

    // Sort projection names for deterministic output
    projNames := make([]string, 0, len(projGroups))
    for proj := range projGroups {
        projNames = append(projNames, proj)
    }
    sort.Strings(projNames)

    cleanup := func() {
        mlx.Unpin(pinned...)
        mlx.Sweep()
    }

    for _, proj := range projNames {
        experts := projGroups[proj]

        // Sort by expert index
        sort.Slice(experts, func(i, j int) bool {
            return experts[i].index < experts[j].index
        })

        // Load and decode each expert tensor
        var decoded []*mlx.Array
        for _, expert := range experts {
            dummyArrays := make(map[string]*mlx.Array)
            tmpPath, toEval, st, err := loadAndQuantizeArray(expert.input.Reader, expert.input.Name, "", dummyArrays)
            if err != nil {
                cleanup()
                return nil, fmt.Errorf("failed to decode expert tensor %s: %w", expert.input.Name, err)
            }
            mlx.Eval(toEval...)

            arr := dummyArrays[expert.input.Name]
            mlx.Pin(arr)
            pinned = append(pinned, arr)
            decoded = append(decoded, arr)

            if st != nil {
                st.Free()
            }
            if tmpPath != "" {
                os.Remove(tmpPath)
            }
            mlx.Sweep()
        }

        // Stack into 3D along axis 0: [numExperts, rows, cols]
        stacked := mlx.Stack(decoded, 0)
        mlx.Eval(stacked)
        mlx.Pin(stacked)
        pinned = append(pinned, stacked)

        // Free individual decoded arrays
        mlx.Unpin(decoded...)
        mlx.Sweep()

        stackedName := groupBase + ".switch_mlp." + proj

        // Quantize the stacked tensor
        if quantize != "" {
            groupSize, bits, mode := model.QuantizationParams(quantize)

            qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)

            qweight = mlx.Contiguous(qweight, false)
            scales = mlx.Contiguous(scales, false)
            allArrays[stackedName] = qweight
            allArrays[stackedName+".scale"] = scales

            toEval := []*mlx.Array{qweight, scales}
            if qbiases != nil {
                qbiases = mlx.Contiguous(qbiases, false)
                allArrays[stackedName+".bias"] = qbiases
                toEval = append(toEval, qbiases)
            }
            mlx.Eval(toEval...)
            mlx.Pin(toEval...)
            pinned = append(pinned, toEval...)

            // Free stacked source array
            mlx.Unpin(stacked)
            mlx.Sweep()
        } else {
            stacked = mlx.Contiguous(stacked, false)
            mlx.Eval(stacked)
            allArrays[stackedName] = stacked
        }
    }

    defer cleanup()

    tmpDir := ensureTempDir()
    outPath := filepath.Join(tmpDir, "stacked-combined.safetensors")
    defer os.Remove(outPath)
    if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
        return nil, fmt.Errorf("failed to save stacked blob: %w", err)
    }

    blobData, err := os.ReadFile(outPath)
    if err != nil {
        return nil, fmt.Errorf("failed to read stacked blob: %w", err)
    }
    return blobData, nil
}

// QuantizeSupported returns true if quantization is supported (MLX library available)
func QuantizeSupported() bool {
-   mlx.InitMLX()
-   return mlx.IsMLXAvailable()
+   return mlx.CheckInit() == nil
}

// ensureTempDir creates the temp directory for quantization if it doesn't exist
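The per-expert grouping hinges on `perExpertSuffix` splitting `.{index}.{projection}` off a tensor name once the group prefix has been trimmed. A runnable sketch of just that parsing step, reusing the regexp verbatim (the tensor names are illustrative):

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
	"strings"
)

// Same pattern as perExpertSuffix in the diff: ".{index}.{proj_and_suffix}".
var perExpertSuffix = regexp.MustCompile(`^\.(\d+)\.(.+)$`)

func main() {
	groupName := "model.layers.1.mlp.experts"
	for _, name := range []string{
		"model.layers.1.mlp.experts.0.gate_proj.weight",
		"model.layers.1.mlp.experts.17.down_proj.weight",
		"model.layers.1.mlp.experts.gate_up_proj", // already stacked: no index
	} {
		m := perExpertSuffix.FindStringSubmatch(strings.TrimPrefix(name, groupName))
		if m == nil {
			fmt.Printf("%s: not a per-expert tensor\n", name)
			continue
		}
		index, _ := strconv.Atoi(m[1])
		fmt.Printf("%s: expert %d, projection %q\n", name, index, m[2])
	}
}
```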
@@ -205,32 +447,97 @@ func ensureTempDir() string {
    return tmpDir
}

-// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file.
-func findSafetensorsKey(path string) (string, error) {
+type safetensorsHeaderEntry struct {
+   Dtype string  `json:"dtype"`
+   Shape []int32 `json:"shape"`
+}

+func readSafetensorsHeader(path string) (map[string]safetensorsHeaderEntry, error) {
    f, err := os.Open(path)
    if err != nil {
-       return "", err
+       return nil, err
    }
    defer f.Close()

    var headerSize uint64
    if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
-       return "", err
+       return nil, err
    }
    headerBytes := make([]byte, headerSize)
    if _, err := io.ReadFull(f, headerBytes); err != nil {
-       return "", err
+       return nil, err
    }

-   var header map[string]json.RawMessage
+   var header map[string]safetensorsHeaderEntry
    if err := json.Unmarshal(headerBytes, &header); err != nil {
-       return "", err
+       return nil, err
    }
+   return header, nil
+}

-   for k := range header {
-       if k != "__metadata__" {
-           return k, nil
-       }
-   }
-   return "", fmt.Errorf("no tensor found in safetensors header")
+// safetensorsKey resolves the primary tensor key from a header.
+func safetensorsKey(preferred string, header map[string]safetensorsHeaderEntry) (string, error) {
+   if preferred != "" {
+       if _, ok := header[preferred]; ok {
+           return preferred, nil
+       }
+   }
+
+   keys := make([]string, 0, len(header))
+   for k := range header {
+       if k == "__metadata__" || strings.HasSuffix(k, ".scale_inv") {
+           continue
+       }
+       keys = append(keys, k)
+   }
+   sort.Strings(keys)
+   if len(keys) == 0 {
+       return "", fmt.Errorf("no tensor found in safetensors header")
+   }
+   return keys[0], nil
}

func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
    if weight == nil || scaleInv == nil {
        return nil, fmt.Errorf("fp8 weight and scale tensors are required")
    }

    weightShape := weight.Dims()
    scaleShape := scaleInv.Dims()
    if len(weightShape) != 2 || len(scaleShape) != 2 {
        return nil, fmt.Errorf("expected 2D fp8 weight and scale tensors, got %v and %v", weightShape, scaleShape)
    }

    // These must match the block size validated by resolveEffectiveQuantization
    // in create.go, which rejects any source model with a different block size.
    const blockRows = 128
    const blockCols = 128
    rows, cols := weightShape[0], weightShape[1]
    expectedScaleRows := (rows + blockRows - 1) / blockRows
    expectedScaleCols := (cols + blockCols - 1) / blockCols
    if scaleShape[0] != expectedScaleRows || scaleShape[1] != expectedScaleCols {
        return nil, fmt.Errorf(
            "unexpected fp8 scale shape %v for weight shape %v; want [%d %d]",
            scaleShape,
            weightShape,
            expectedScaleRows,
            expectedScaleCols,
        )
    }

    decoded := mlx.FromFP8(weight, mlx.DTypeBFloat16)
    padBottom := blockRows*scaleShape[0] - rows
    padSide := blockCols*scaleShape[1] - cols
    if padBottom > 0 || padSide > 0 {
        decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
    }

    decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
    decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scaleInv, 1), 3))
    decoded = mlx.Reshape(decoded, int32(rows+padBottom), int32(cols+padSide))
    if padBottom > 0 || padSide > 0 {
        decoded = mlx.SliceStartStop(decoded, []int32{0, 0}, []int32{int32(rows), int32(cols)})
    }

    return decoded, nil
}
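`decodeSourceFP8Tensor` expresses block-wise dequantization through MLX reshape and broadcast ops: each 128x128 tile of the decoded weight is multiplied by one entry of `scale_inv`. The same arithmetic in plain Go, shrunk to a toy 2x2 block size so it runs without MLX:

```go
package main

import "fmt"

// dequantBlocks multiplies each [blockRows x blockCols] tile of w by
// the matching entry of scaleInv, mirroring the reshape-multiply-reshape
// sequence in decodeSourceFP8Tensor (padding omitted: shapes divide evenly).
func dequantBlocks(w, scaleInv [][]float64, blockRows, blockCols int) [][]float64 {
	out := make([][]float64, len(w))
	for i := range w {
		out[i] = make([]float64, len(w[i]))
		for j := range w[i] {
			out[i][j] = w[i][j] * scaleInv[i/blockRows][j/blockCols]
		}
	}
	return out
}

func main() {
	w := [][]float64{
		{1, 1, 2, 2},
		{1, 1, 2, 2},
	}
	// One scale per 2x2 block: two blocks across, one down.
	scaleInv := [][]float64{{0.5, 0.25}}
	fmt.Println(dequantBlocks(w, scaleInv, 2, 2))
	// [[0.5 0.5 0.5 0.5] [0.5 0.5 0.5 0.5]]
}
```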
@@ -267,13 +267,13 @@ func ShouldQuantize(name, component string) bool {

// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions.
-// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "int8", "mxfp8").
+// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "mxfp4", "int8", "mxfp8").
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
    return GetTensorQuantization(name, shape, quantize) != ""
}

// normalizeQuantType converts various quantization type aliases to canonical forms.
-// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp8/MXFP8
+// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp4/MXFP4, mxfp8/MXFP8
func normalizeQuantType(quantize string) string {
    switch strings.ToUpper(quantize) {
    case "Q4", "INT4", "FP4":
@@ -282,6 +282,8 @@ func normalizeQuantType(quantize string) string {
        return "int8"
    case "NVFP4":
        return "nvfp4"
+   case "MXFP4":
+       return "mxfp4"
    case "MXFP8":
        return "mxfp8"
    default:
@@ -335,7 +337,7 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
    quantNorm := normalizeQuantType(quantize)

    // MLX quantization requires last dimension to be divisible by group size
-   // nvfp4: 16, mxfp8: 32, int4/int8: 64
+   // nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
    groupSize := int32(32)
    switch quantNorm {
    case "nvfp4":
@@ -353,8 +355,8 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
        return ""
    }

-   // For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
-   if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
+   // For non-affine modes, use the same quantization for all eligible tensors.
+   if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
        return quantNorm
    }
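The eligibility rule stated in the comment above — the last dimension must divide evenly by the mode's group size — can be checked in isolation. A sketch of just that check, using the group sizes from the comment (the real `GetTensorQuantization` also filters by tensor name and element count, which this omits):

```go
package main

import "fmt"

// eligible reports whether MLX can quantize a tensor of the given
// shape in the given (already normalized) mode: the last dimension
// must be divisible by the mode's group size
// (nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64).
func eligible(shape []int32, quantNorm string) bool {
	groupSize := int32(32)
	switch quantNorm {
	case "nvfp4":
		groupSize = 16
	case "int4", "int8":
		groupSize = 64
	}
	if len(shape) == 0 {
		return false
	}
	return shape[len(shape)-1]%groupSize == 0
}

func main() {
	fmt.Println(eligible([]int32{128, 48}, "mxfp4")) // false: 48 % 32 != 0
	fmt.Println(eligible([]int32{128, 48}, "nvfp4")) // true:  48 % 16 == 0
	fmt.Println(eligible([]int32{128, 64}, "int4"))  // true:  64 % 64 == 0
}
```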
@@ -391,23 +393,39 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
    return quantNorm
}

-// expertGroupRegexp matches expert tensor names and captures the group prefix.
-// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
-// Captures: model.layers.{L}.mlp.experts
-var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
+var expertLayerPrefixRegexp = regexp.MustCompile(`^(?:model\.language_model\.|language_model(?:\.model)?\.|model\.)?layers\.\d+$`)

// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
// For example:
//   - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts"
//   - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts"
+//   - "language_model.model.layers.1.mlp.switch_mlp.down_proj.weight" -> "language_model.model.layers.1.mlp.switch_mlp"
//   - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts)
//   - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert)
func ExpertGroupPrefix(tensorName string) string {
-   m := expertGroupRegexp.FindStringSubmatch(tensorName)
-   if m == nil {
+   if !strings.HasSuffix(tensorName, ".weight") {
        return ""
    }
-   return m[1]

+   for _, marker := range []string{
+       ".mlp.experts.",
+       ".mlp.shared_experts.",
+       ".mlp.switch_mlp.",
+   } {
+       idx := strings.Index(tensorName, marker)
+       if idx == -1 {
+           continue
+       }
+
+       layerPrefix := tensorName[:idx]
+       if !expertLayerPrefixRegexp.MatchString(layerPrefix) {
+           continue
+       }
+
+       return layerPrefix + strings.TrimSuffix(marker, ".")
+   }
+
+   return ""
}

// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob.
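The rewritten `ExpertGroupPrefix` accepts several layer-prefix spellings via `expertLayerPrefixRegexp`. A small demo of which prefixes the pattern admits, reusing the regexp verbatim (the `vision_tower` case is an invented negative example):

```go
package main

import (
	"fmt"
	"regexp"
)

// The same layer-prefix pattern used by ExpertGroupPrefix in the diff.
var expertLayerPrefixRegexp = regexp.MustCompile(`^(?:model\.language_model\.|language_model(?:\.model)?\.|model\.)?layers\.\d+$`)

func main() {
	for _, prefix := range []string{
		"model.layers.1",                // true
		"language_model.model.layers.3", // true
		"model.language_model.layers.4", // true
		"layers.0",                      // true: prefix is optional
		"vision_tower.layers.2",         // false: unknown prefix
	} {
		fmt.Printf("%-30s %v\n", prefix, expertLayerPrefixRegexp.MatchString(prefix))
	}
}
```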
@@ -424,9 +442,11 @@ type PackedTensorInput struct {
type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput) (LayerInfo, error)

type sourceQuantization struct {
-   Bits      int    `json:"bits"`
-   GroupSize int    `json:"group_size"`
-   Mode      string `json:"mode"`
+   Bits            int     `json:"bits"`
+   GroupSize       int     `json:"group_size"`
+   Mode            string  `json:"mode"`
+   QuantMethod     string  `json:"quant_method"`
+   WeightBlockSize []int32 `json:"weight_block_size"`
}

type sourceModelConfig struct {
@@ -493,6 +513,98 @@ func (cfg sourceModelConfig) QuantMetadata() map[string]string {
    return metadata
}

type sourceQuantizedKind string

const (
    sourceQuantizedKindNone         sourceQuantizedKind = ""
    sourceQuantizedKindPrequantized sourceQuantizedKind = "prequantized"
    sourceQuantizedKindHFFP8        sourceQuantizedKind = "hf_fp8"
)

func (cfg sourceModelConfig) quantizationConfigs() []sourceQuantization {
    return []sourceQuantization{
        cfg.Quantization,
        cfg.QuantizationConfig,
        cfg.TextConfig.Quantization,
        cfg.TextConfig.QuantizationConfig,
    }
}

func (cfg sourceModelConfig) HFFP8WeightBlockSize() (rows, cols int32, ok bool) {
    for _, q := range cfg.quantizationConfigs() {
        if !strings.EqualFold(q.QuantMethod, "fp8") || len(q.WeightBlockSize) != 2 {
            continue
        }
        return q.WeightBlockSize[0], q.WeightBlockSize[1], true
    }
    return 0, 0, false
}

func inspectSourceQuantization(modelDir string, cfg sourceModelConfig) (sourceQuantizedKind, error) {
    entries, err := os.ReadDir(modelDir)
    if err != nil {
        return sourceQuantizedKindNone, err
    }

    hasScaleInv := false
    for _, entry := range entries {
        if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
            continue
        }

        extractor, err := safetensors.OpenForExtraction(filepath.Join(modelDir, entry.Name()))
        if err != nil {
            return sourceQuantizedKindNone, err
        }

        for _, name := range extractor.ListTensors() {
            switch {
            case strings.HasSuffix(name, ".scales"):
                extractor.Close()
                return sourceQuantizedKindPrequantized, nil
            case strings.HasSuffix(name, ".weight_scale_inv"):
                hasScaleInv = true
            }
        }

        extractor.Close()
    }

    if hasScaleInv {
        if _, _, ok := cfg.HFFP8WeightBlockSize(); ok {
            return sourceQuantizedKindHFFP8, nil
        }
    }

    return sourceQuantizedKindNone, nil
}

func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested string) (string, error) {
    switch sourceKind {
    case sourceQuantizedKindNone:
        return requested, nil
    case sourceQuantizedKindPrequantized:
        if requested != "" {
            return "", fmt.Errorf("cannot requantize already-quantized source model with --quantize %q", requested)
        }
        return "", nil
    case sourceQuantizedKindHFFP8:
        if requested != "" {
            return "", fmt.Errorf("cannot requantize already-quantized fp8 source model with --quantize %q", requested)
        }
        rows, cols, ok := cfg.HFFP8WeightBlockSize()
        if !ok {
            return "", fmt.Errorf("fp8 source model missing weight_block_size metadata")
        }
        if rows != 128 || cols != 128 {
            return "", fmt.Errorf("unsupported fp8 source block size %dx%d", rows, cols)
        }
        return "mxfp8", nil
    default:
        return "", fmt.Errorf("unsupported source quantization kind %q", sourceKind)
    }
}

type tensorImportTransform interface {
    skipTensor(name string) bool
    transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error)
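The HF fp8 detection above keys off `config.json`. A sketch of decoding the relevant fragment with the same JSON tags as `sourceQuantization`; the sample document mirrors the test fixtures later in this diff:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// sourceQuantization uses the same JSON tags as the struct in the diff,
// so the HF-style quantization_config fragment decodes directly into it.
type sourceQuantization struct {
	Bits            int     `json:"bits"`
	GroupSize       int     `json:"group_size"`
	Mode            string  `json:"mode"`
	QuantMethod     string  `json:"quant_method"`
	WeightBlockSize []int32 `json:"weight_block_size"`
}

func main() {
	raw := `{"quant_method": "fp8", "weight_block_size": [128, 128]}`
	var q sourceQuantization
	if err := json.Unmarshal([]byte(raw), &q); err != nil {
		panic(err)
	}
	// An fp8 source with a 128x128 block size is auto-converted to mxfp8.
	fmt.Println(q.QuantMethod, q.WeightBlockSize) // fp8 [128 128]
}
```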
@@ -546,6 +658,14 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
    if err != nil {
        return fmt.Errorf("failed to read source config.json: %w", err)
    }
+   sourceQuantKind, err := inspectSourceQuantization(modelDir, sourceConfig)
+   if err != nil {
+       return fmt.Errorf("failed to inspect source quantization: %w", err)
+   }
+   effectiveQuantize, err := resolveEffectiveQuantization(sourceConfig, sourceQuantKind, quantize)
+   if err != nil {
+       return err
+   }
    sourceQuantMetadata := sourceConfig.QuantMetadata()
    importTransform, err := newTensorImportTransform(modelDir, sourceConfig)
    if err != nil {
@@ -557,7 +677,6 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
    if len(createPackedLayer) > 0 {
        packedCreator = createPackedLayer[0]
    }

    // Accumulate expert tensors by group prefix for packing.
    // Readers reference file-backed SectionReaders, so we keep extractors
    // open until each group is flushed to avoid buffering tensor data in memory.
@@ -600,8 +719,8 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
        tensorSet[name] = struct{}{}
    }
    quantizeMsg := ""
-   if quantize != "" {
-       quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
+   if effectiveQuantize != "" {
+       quantizeMsg = fmt.Sprintf(", quantizing to %s", effectiveQuantize)
    }
    fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
@@ -612,9 +731,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
        if importTransform.skipTensor(tensorName) {
            continue
        }
-       if shouldSkipPrequantizedCompanion(tensorName, tensorSet) {
+       if shouldSkipSourceCompanion(tensorName, tensorSet) {
            continue
        }
+       sourceFP8ScaleName, hasSourceFP8Scale := sourceFP8Companion(tensorName, tensorSet)

        td, err := extractor.GetTensor(tensorName)
        if err != nil {
@@ -623,7 +743,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
            return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
        }

-       if quantize == "" {
+       if effectiveQuantize == "" {
            layer, ok, err := createPrequantizedLayer(extractor, td, tensorName, tensorSet, sourceQuantMetadata, createLayer)
            if err != nil {
                extractor.Close()
@@ -647,8 +767,33 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
            // Determine quantization type for this tensor (empty string if not quantizing)
            // GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
            quantizeType := ""
-           if quantize != "" {
-               quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, quantize)
+           switch {
+           case sourceQuantKind == sourceQuantizedKindHFFP8 && hasSourceFP8Scale:
+               quantizeType = "mxfp8"
+           case sourceQuantKind == sourceQuantizedKindHFFP8:
+               quantizeType = ""
+           case effectiveQuantize != "":
+               quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
            }
+           reader := outTD.SafetensorsReader()
+           if hasSourceFP8Scale {
+               if len(outputTensors) != 1 {
+                   extractor.Close()
+                   closeExtractors()
+                   return fmt.Errorf("source fp8 tensor %s rewrote into %d tensors; only 1:1 rewrites are supported", tensorName, len(outputTensors))
+               }
+               if quantizeType == "" {
+                   extractor.Close()
+                   closeExtractors()
+                   return fmt.Errorf("source fp8 tensor %s was not scheduled for mxfp8 conversion", tensorName)
+               }
+               scaleTD, err := extractor.GetTensor(sourceFP8ScaleName)
+               if err != nil {
+                   extractor.Close()
+                   closeExtractors()
+                   return fmt.Errorf("failed to get fp8 scale tensor %s: %w", sourceFP8ScaleName, err)
+               }
+               reader = buildSourceFP8Reader(outTD, scaleTD.WithName(outTD.Name+".scale_inv"))
+           }

            // Check if this tensor belongs to an expert group for packing
@@ -670,13 +815,13 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
                    Dtype:    outTD.Dtype,
                    Shape:    outTD.Shape,
                    Quantize: quantizeType,
-                   Reader:   outTD.SafetensorsReader(),
+                   Reader:   reader,
                })
            } else {
                // Store as minimal safetensors format (88 bytes header overhead)
                // This enables native mmap loading via mlx_load_safetensors
                // createTensorLayer returns multiple layers if quantizing (weight + scales)
-               newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
+               newLayers, err := createTensorLayer(reader, outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
                if err != nil {
                    extractor.Close()
                    closeExtractors()
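The "minimal safetensors format" mentioned in the comment is an 8-byte little-endian header length, a JSON header mapping each tensor name to its dtype/shape/data offsets, then the raw bytes. A sketch of building such a single-tensor blob; real writers may pad the JSON header for alignment, which this omits:

```go
package main

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"fmt"
)

// buildSafetensors assembles a one-tensor safetensors blob:
// [8-byte LE header size][JSON header][raw tensor data].
func buildSafetensors(name, dtype string, shape []int32, data []byte) []byte {
	header := map[string]any{
		name: map[string]any{
			"dtype":        dtype,
			"shape":        shape,
			"data_offsets": []int{0, len(data)},
		},
	}
	hdr, _ := json.Marshal(header)
	var buf bytes.Buffer
	binary.Write(&buf, binary.LittleEndian, uint64(len(hdr)))
	buf.Write(hdr)
	buf.Write(data)
	return buf.Bytes()
}

func main() {
	blob := buildSafetensors("norm.weight", "BF16", []int32{2}, []byte{0, 0, 0, 0})
	headerSize := binary.LittleEndian.Uint64(blob[:8])
	fmt.Println(headerSize, string(blob[8:8+headerSize]))
}
```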
@@ -760,7 +905,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
    return nil
}

-func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{}) bool {
+func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}) bool {
    switch {
    case strings.HasSuffix(name, ".scales"):
        _, ok := tensorSet[strings.TrimSuffix(name, ".scales")+".weight"]
@@ -768,11 +913,28 @@ func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}) bool
    case strings.HasSuffix(name, ".biases"):
        _, ok := tensorSet[strings.TrimSuffix(name, ".biases")+".weight"]
        return ok
+   case strings.HasSuffix(name, ".weight_scale_inv"):
+       _, ok := tensorSet[strings.TrimSuffix(name, "_scale_inv")]
+       return ok
    default:
        return false
    }
}

+func sourceFP8Companion(weightName string, tensorSet map[string]struct{}) (scaleName string, ok bool) {
+   if !strings.HasSuffix(weightName, ".weight") {
+       return "", false
+   }
+
+   scaleName = weightName + "_scale_inv"
+   _, ok = tensorSet[scaleName]
+   return scaleName, ok
+}

+func buildSourceFP8Reader(weightTD, scaleTD *safetensors.TensorData) io.Reader {
+   return safetensors.BuildPackedSafetensorsReader([]*safetensors.TensorData{weightTD, scaleTD})
+}

func createPrequantizedLayer(
    extractor *safetensors.TensorExtractor,
    td *safetensors.TensorData,
@@ -246,6 +246,30 @@ func readSingleTensorRaw(t *testing.T, data []byte) []byte {
    return nil
}

func readSafetensorsHeaderNames(t *testing.T, data []byte) []string {
    t.Helper()

    var headerSize uint64
    if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
        t.Fatalf("failed to read header size: %v", err)
    }

    var header map[string]json.RawMessage
    if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
        t.Fatalf("failed to parse header: %v", err)
    }

    names := make([]string, 0, len(header))
    for name := range header {
        if name == "__metadata__" {
            continue
        }
        names = append(names, name)
    }
    slices.Sort(names)
    return names
}

func TestCreateSafetensorsModel(t *testing.T) {
    dir := t.TempDir()
@@ -546,6 +570,215 @@ func TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets(t *testing.T) {
    }
}

func TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8(t *testing.T) {
    dir := t.TempDir()

    configJSON := `{
        "model_type": "test",
        "architectures": ["TestModel"],
        "quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
    }`
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
        t.Fatalf("failed to write config.json: %v", err)
    }

    createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
        st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
        st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("dense.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
        st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{2}, make([]byte, 4)),
    })

    quantizeByName := make(map[string]string)
    headerNamesByName := make(map[string][]string)

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        _, err := io.ReadAll(r)
        if err != nil {
            return LayerInfo{}, err
        }
        return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
    }

    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        data, err := io.ReadAll(r)
        if err != nil {
            return nil, err
        }
        quantizeByName[name] = quantize
        headerNamesByName[name] = readSafetensorsHeaderNames(t, data)
        return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
    }

    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

    if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }

    if got := quantizeByName["linear.weight"]; got != "mxfp8" {
        t.Fatalf("linear.weight quantization = %q, want %q", got, "mxfp8")
    }

    if got := quantizeByName["norm.weight"]; got != "" {
        t.Fatalf("norm.weight quantization = %q, want empty", got)
    }
    if got := quantizeByName["dense.weight"]; got != "" {
        t.Fatalf("dense.weight quantization = %q, want empty", got)
    }

    if _, ok := quantizeByName["linear.weight_scale_inv"]; ok {
        t.Fatal("linear.weight_scale_inv should not be imported as a standalone tensor")
    }

    if got := headerNamesByName["linear.weight"]; !slices.Equal(got, []string{"linear.weight", "linear.weight.scale_inv"}) {
        t.Fatalf("linear.weight blob tensors = %v, want %v", got, []string{"linear.weight", "linear.weight.scale_inv"})
    }

    if got := headerNamesByName["norm.weight"]; !slices.Equal(got, []string{"norm.weight"}) {
        t.Fatalf("norm.weight blob tensors = %v, want %v", got, []string{"norm.weight"})
    }
    if got := headerNamesByName["dense.weight"]; !slices.Equal(got, []string{"dense.weight"}) {
        t.Fatalf("dense.weight blob tensors = %v, want %v", got, []string{"dense.weight"})
    }
}

func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T) {
    tests := []struct {
        name       string
        configJSON string
        tensors    []*st.TensorData
        wantErr    string
    }{
        {
            name:       "prequantized affine",
            configJSON: `{"model_type": "test", "architectures": ["TestModel"]}`,
            tensors: []*st.TensorData{
                st.NewTensorDataFromBytes("linear.weight", "U32", []int32{4, 4}, make([]byte, 16)),
                st.NewTensorDataFromBytes("linear.scales", "BF16", []int32{4, 1}, make([]byte, 8)),
            },
            wantErr: `cannot requantize already-quantized source model with --quantize "int4"`,
        },
        {
            name: "hf fp8 source",
            configJSON: `{
                "model_type": "test",
                "architectures": ["TestModel"],
                "quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
            }`,
            tensors: []*st.TensorData{
                st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
                st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
            },
            wantErr: `cannot requantize already-quantized fp8 source model with --quantize "int4"`,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            dir := t.TempDir()
            if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644); err != nil {
                t.Fatalf("failed to write config.json: %v", err)
            }
            createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), tt.tensors)

            createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
                return LayerInfo{}, nil
            }
            createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
                return nil, nil
            }
            writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

            err := CreateSafetensorsModel("test-model", dir, "int4", createLayer, createTensorLayer, writeManifest, func(string) {})
            if err == nil {
                t.Fatal("expected error, got nil")
            }
            if !strings.Contains(err.Error(), tt.wantErr) {
                t.Fatalf("error = %q, want substring %q", err, tt.wantErr)
            }
        })
    }
}

func TestCreateSafetensorsModel_HFFP8PacksExperts(t *testing.T) {
    dir := t.TempDir()

    configJSON := `{
        "model_type": "test",
        "architectures": ["Qwen3_5MoeForConditionalGeneration"],
        "quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
    }`
    if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
        t.Fatalf("failed to write config.json: %v", err)
    }

    // Create 2 experts so stacking produces a [2, 128, 128] tensor
    createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
        st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
    })

    var packedLayerNames []string
    var packedLayerTensors [][]PackedTensorInput

    createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
        if _, err := io.ReadAll(r); err != nil {
            return LayerInfo{}, err
        }
        return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
    }

    createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
        if _, err := io.ReadAll(r); err != nil {
            return nil, err
        }
        return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
    }

    createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
        packedLayerNames = append(packedLayerNames, groupName)
        packedLayerTensors = append(packedLayerTensors, tensors)
        return LayerInfo{Name: groupName, Digest: "sha256:packed_" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
    }

    writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }

    if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
        t.Fatalf("CreateSafetensorsModel failed: %v", err)
    }

    if len(packedLayerNames) != 1 {
        t.Fatalf("expected 1 packed layer, got %d: %v", len(packedLayerNames), packedLayerNames)
    }
    if packedLayerNames[0] != "language_model.model.layers.0.mlp.experts" {
        t.Fatalf("unexpected packed layer name: %s", packedLayerNames[0])
    }

    // Verify all 6 expert tensors (2 experts × 3 proj types) were accumulated
    tensors := packedLayerTensors[0]
    if len(tensors) != 6 {
        t.Fatalf("expected 6 tensors in packed group, got %d", len(tensors))
    }

    // All should be marked for mxfp8 quantization
    for _, tensor := range tensors {
        if tensor.Quantize != "mxfp8" {
            t.Fatalf("expected mxfp8 quantize for %s, got %q", tensor.Name, tensor.Quantize)
        }
    }
}

func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
    dir := t.TempDir()
@@ -693,6 +926,113 @@ func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
    }
}

func TestCreateSafetensorsModel_Qwen35DirectNonAffineKeepsSensitiveWeightsBF16(t *testing.T) {
    for _, quantize := range []string{"nvfp4", "mxfp8", "mxfp4"} {
        t.Run(quantize, func(t *testing.T) {
            dir := t.TempDir()

            configJSON := `{
                "model_type": "test",
                "architectures": ["Qwen3_5MoeForConditionalGeneration"],
                "text_config": {"dtype": "bfloat16"}
            }`
            if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
                t.Fatalf("failed to write config.json: %v", err)
            }

            gateUpValues := make([]float32, 2*128*64)
            for expert := range 2 {
                base := expert * 128 * 64
                for i := range 64 * 64 {
                    gateUpValues[base+i] = 1
                    gateUpValues[base+64*64+i] = 2
                }
            }

            createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
                st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
                st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
                st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_a.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
                st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_b.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
                st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
                st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.shared_expert_gate.weight", "BF16", []int32{1, 64}, make([]byte, 64*2)),
                st.NewTensorDataFromBytes("model.language_model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
                st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
                st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
            })

            type tensorCall struct {
                quantize string
            }
            type packedTensorCall struct {
                Name     string
                Quantize string
            }

            tensorCalls := make(map[string]tensorCall)
            packedCalls := make(map[string][]packedTensorCall)

            createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
                _, _ = io.ReadAll(r)
                return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
            }

            createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantizeType string) ([]LayerInfo, error) {
                _, _ = io.ReadAll(r)
                tensorCalls[name] = tensorCall{quantize: quantizeType}
                return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
            }

            createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
                group := make([]packedTensorCall, 0, len(tensors))
                for _, tensor := range tensors {
                    group = append(group, packedTensorCall{
                        Name:     tensor.Name,
                        Quantize: tensor.Quantize,
                    })
                }
                packedCalls[groupName] = group
                return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
            }

            writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
                return nil
            }

            if err := CreateSafetensorsModel("test-model", dir, quantize, createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
                t.Fatalf("CreateSafetensorsModel failed: %v", err)
            }

            for _, name := range []string{
                "language_model.model.embed_tokens.weight",
                "language_model.lm_head.weight",
                "language_model.model.layers.0.linear_attn.in_proj_a.weight",
                "language_model.model.layers.0.linear_attn.in_proj_b.weight",
                "language_model.model.layers.0.mlp.gate.weight",
                "language_model.model.layers.0.mlp.shared_expert_gate.weight",
            } {
                if got := tensorCalls[name].quantize; got != "" {
                    t.Fatalf("%s quantize = %q, want empty", name, got)
                }
            }

            if got := tensorCalls["language_model.model.layers.0.self_attn.q_proj.weight"].quantize; got != quantize {
                t.Fatalf("q_proj quantize = %q, want %q", got, quantize)
            }

            group := packedCalls["language_model.model.layers.0.mlp.switch_mlp"]
            if len(group) != 3 {
                t.Fatalf("packed switch_mlp tensor count = %d, want 3", len(group))
            }
            for _, tensor := range group {
                if tensor.Quantize != quantize {
                    t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, quantize)
                }
            }
        })
    }
}

func TestResolveManifestPath(t *testing.T) {
    tests := []struct {
        name string
@@ -865,6 +1205,7 @@ func TestShouldQuantizeTensor(t *testing.T) {
        {"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
        {"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
        {"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
+       {"large 2D weight mxfp4", "q_proj.weight", []int32{4096, 4096}, "mxfp4", true},

        // Small tensors should not be quantized (< 1024 elements)
        {"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
@@ -891,9 +1232,11 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
||||
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
|
||||
|
||||
// Group size divisibility tests
|
||||
// FP8/FP4 require divisible by 32
|
||||
// FP8/FP4/MXFP4 require divisible by 32
|
||||
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
|
||||
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
|
||||
{"not divisible by 32 mxfp4", "proj.weight", []int32{128, 48}, "mxfp4", false},
|
||||
{"divisible by 32 mxfp4", "proj.weight", []int32{128, 64}, "mxfp4", true},
|
||||
// NVFP4 requires divisible by 16
|
||||
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
|
||||
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
|
||||
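The divisibility cases above encode the quantization group sizes: fp8, mxfp4, and mxfp8 group values along the last axis in blocks of 32, while nvfp4 uses blocks of 16, so the last dimension must divide evenly. A minimal sketch of such a predicate — the helper name and exact thresholds are illustrative reconstructions from the test table, not the repository's implementation:

```go
package main

import (
	"fmt"
	"strings"
)

// shouldQuantizeSketch mirrors the rules the table exercises: skip tiny
// tensors and biases, and require the last dimension to be a multiple of
// the quantization group size (32 for fp8/mxfp4/mxfp8, 16 for nvfp4).
func shouldQuantizeSketch(name string, shape []int32, quantize string) bool {
	elems := int64(1)
	for _, d := range shape {
		elems *= int64(d)
	}
	if elems < 1024 || strings.HasSuffix(name, ".bias") {
		return false
	}
	group := int32(32) // fp8, mxfp4, mxfp8
	if quantize == "nvfp4" {
		group = 16
	}
	return shape[len(shape)-1]%group == 0
}

func main() {
	fmt.Println(shouldQuantizeSketch("proj.weight", []int32{128, 48}, "fp8"))   // false: 48 % 32 != 0
	fmt.Println(shouldQuantizeSketch("proj.weight", []int32{128, 48}, "nvfp4")) // true:  48 % 16 == 0
}
```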
@@ -919,10 +1262,20 @@ func TestExpertGroupPrefix(t *testing.T) {
		{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
		{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},

		// Expert tensors with language_model prefix should also match
		{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
		{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},

		// Shared expert tensors should return their own group prefix
		{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
		{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},

		// Rewritten Qwen switch_mlp tensors should also be packed per-layer.
		{"model.layers.1.mlp.switch_mlp.down_proj.weight", "model.layers.1.mlp.switch_mlp"},
		{"language_model.layers.2.mlp.switch_mlp.gate_proj.weight", "language_model.layers.2.mlp.switch_mlp"},
		{"language_model.model.layers.3.mlp.switch_mlp.up_proj.weight", "language_model.model.layers.3.mlp.switch_mlp"},
		{"model.language_model.layers.4.mlp.switch_mlp.gate_proj.weight", "model.language_model.layers.4.mlp.switch_mlp"},

		// Non-expert tensors should return empty string
		{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
		{"model.layers.1.mlp.gate.weight", ""},      // routing gate, not an expert
@@ -978,6 +1331,161 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
	if combinedDown != "int8" {
		t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
	}

	nvfp4GateUp := GetTensorQuantization(
		"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
		[]int32{64, 11008, 4096},
		"nvfp4",
	)
	if nvfp4GateUp != "nvfp4" {
		t.Fatalf("nvfp4 gate_proj quantization = %q, want %q", nvfp4GateUp, "nvfp4")
	}

	nvfp4Down := GetTensorQuantization(
		"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
		[]int32{64, 4096, 11008},
		"nvfp4",
	)
	if nvfp4Down != "nvfp4" {
		t.Fatalf("nvfp4 down_proj quantization = %q, want %q", nvfp4Down, "nvfp4")
	}

	mxfp4GateUp := GetTensorQuantization(
		"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
		[]int32{64, 11008, 4096},
		"mxfp4",
	)
	if mxfp4GateUp != "mxfp4" {
		t.Fatalf("mxfp4 gate_proj quantization = %q, want %q", mxfp4GateUp, "mxfp4")
	}

	mxfp4Down := GetTensorQuantization(
		"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
		[]int32{64, 4096, 11008},
		"mxfp4",
	)
	if mxfp4Down != "mxfp4" {
		t.Fatalf("mxfp4 down_proj quantization = %q, want %q", mxfp4Down, "mxfp4")
	}
}

func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
	dir := t.TempDir()

	configJSON := `{
		"model_type": "test",
		"architectures": ["Qwen3_5MoeForConditionalGeneration"],
		"text_config": {"dtype": "bfloat16"}
	}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
		t.Fatalf("failed to write config.json: %v", err)
	}

	gateUpValues := make([]float32, 2*128*64)
	for expert := range 2 {
		base := expert * 128 * 64
		for i := range 64 * 64 {
			gateUpValues[base+i] = 1
			gateUpValues[base+64*64+i] = 2
		}
	}

	createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
		st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
		st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
	})

	type tensorCall struct {
		quantize string
	}
	type packedTensorCall struct {
		Name     string
		Dtype    string
		Shape    []int32
		Quantize string
	}

	tensorCalls := make(map[string]tensorCall)
	packedCalls := make(map[string][]packedTensorCall)

	createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		_, _ = io.ReadAll(r)
		return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
	}

	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		_, _ = io.ReadAll(r)
		tensorCalls[name] = tensorCall{quantize: quantize}
		return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
	}

	createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
		group := make([]packedTensorCall, 0, len(tensors))
		for _, tensor := range tensors {
			group = append(group, packedTensorCall{
				Name:     tensor.Name,
				Dtype:    tensor.Dtype,
				Shape:    append([]int32(nil), tensor.Shape...),
				Quantize: tensor.Quantize,
			})
		}
		packedCalls[groupName] = group
		return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
	}

	writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
		return nil
	}

	if err := CreateSafetensorsModel("test-model", dir, "nvfp4", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
		t.Fatalf("CreateSafetensorsModel failed: %v", err)
	}

	groupName := "language_model.model.layers.0.mlp.switch_mlp"
	group, ok := packedCalls[groupName]
	if !ok {
		t.Fatalf("missing packed group %q: %v", groupName, packedCalls)
	}

	if len(group) != 3 {
		t.Fatalf("packed group %q has %d tensors, want 3", groupName, len(group))
	}

	gotNames := make([]string, 0, len(group))
	for _, tensor := range group {
		gotNames = append(gotNames, tensor.Name)
		if tensor.Quantize != "nvfp4" {
			t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, "nvfp4")
		}
		if tensor.Dtype != "BF16" {
			t.Fatalf("packed tensor %q dtype = %q, want %q", tensor.Name, tensor.Dtype, "BF16")
		}
	}
	slices.Sort(gotNames)

	wantNames := []string{
		"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
		"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
		"language_model.model.layers.0.mlp.switch_mlp.up_proj.weight",
	}
	if !slices.Equal(gotNames, wantNames) {
		t.Fatalf("packed tensor names = %v, want %v", gotNames, wantNames)
	}

	for _, name := range wantNames {
		if _, ok := tensorCalls[name]; ok {
			t.Fatalf("packed expert tensor %q unexpectedly handled by createTensorLayer", name)
		}
	}

	if got := tensorCalls["language_model.model.embed_tokens.weight"].quantize; got != "" {
		t.Fatalf("embed_tokens quantize = %q, want empty", got)
	}
	if got := tensorCalls["language_model.model.layers.0.mlp.gate.weight"].quantize; got != "" {
		t.Fatalf("mlp.gate quantize = %q, want empty", got)
	}
}

func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {

@@ -87,6 +87,27 @@ func (t qwen35ImportTransform) skipTensor(name string) bool {
	return strings.Contains(name, "mtp.")
}

func qwen35ShouldKeepBF16ForDirectNonAffine(name string) bool {
	switch {
	case strings.HasSuffix(name, "embed_tokens.weight"):
		return true
	case strings.HasSuffix(name, "lm_head.weight"):
		return true
	case strings.HasSuffix(name, ".linear_attn.in_proj_a.weight"):
		return true
	case strings.HasSuffix(name, ".linear_attn.in_proj_b.weight"):
		return true
	case strings.HasSuffix(name, ".linear_attn.in_proj_ba.weight"):
		return true
	case strings.HasSuffix(name, ".mlp.gate.weight") && !strings.Contains(name, "_proj"):
		return true
	case strings.HasSuffix(name, ".mlp.shared_expert_gate.weight"):
		return true
	default:
		return false
	}
}

func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
	if strings.HasPrefix(name, "vision_tower.") {
		return ""
@@ -127,6 +148,13 @@ func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quan
		return ""
	}

	// Match the working HF-FP8 import policy for direct NVFP4/MXFP4/MXFP8 imports:
	// keep embeddings, LM head, low-rank linear_attn projections, and routing
	// gates in BF16 rather than forcing them into a non-affine quantized format.
	if (quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8") && qwen35ShouldKeepBF16ForDirectNonAffine(name) {
		return ""
	}

	return quantNorm
}
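To make the policy above concrete, here is a small illustrative driver (hypothetical, written as if inside the same package) showing which tensor names the predicate keeps in BF16 during a direct nvfp4/mxfp4/mxfp8 import:

```go
func main() {
	// Representative names and how the diff's predicate classifies them.
	for _, name := range []string{
		"model.embed_tokens.weight",                   // kept BF16 (embeddings)
		"model.layers.0.linear_attn.in_proj_a.weight", // kept BF16 (low-rank projection)
		"model.layers.0.mlp.gate.weight",              // kept BF16 (routing gate)
		"model.layers.0.self_attn.q_proj.weight",      // quantized
	} {
		fmt.Printf("%-48s keepBF16=%v\n", name, qwen35ShouldKeepBF16ForDirectNonAffine(name))
	}
}
```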
@@ -1,11 +1,11 @@
include(FetchContent)

# Read MLX version from top-level file (shared with Dockerfile)
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
# Read MLX-C version from top-level file (shared with Dockerfile)
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)

# Read MLX core version from top-level file
file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG)
# Read MLX version from top-level file
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_GIT_TAG)
string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)

set(MLX_C_BUILD_EXAMPLES OFF)
@@ -98,6 +98,15 @@ FetchContent_MakeAvailable(mlx-c)
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")

# Regenerate Go/C shim wrappers from the (possibly updated) headers.
find_program(GO_EXECUTABLE go REQUIRED)
message(STATUS "Regenerating MLX Go wrappers")
execute_process(
    COMMAND ${GO_EXECUTABLE} generate ./x/...
    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
    COMMAND_ERROR_IS_FATAL ANY
)

# For local dev builds, override MLX_VERSION with git describe output
if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
execute_process(

@@ -165,8 +165,8 @@ int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_ptr)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL;
bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL;
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL;
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
@@ -319,10 +319,12 @@ int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
@@ -348,7 +350,7 @@ int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
@@ -375,6 +377,8 @@ int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w,
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
@@ -434,8 +438,8 @@ int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, siz
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) = NULL;
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) = NULL;
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
@@ -2101,6 +2105,11 @@ int mlx_load_functions(void* handle) {
    fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
    return -1;
  }
  mlx_bartlett_ptr = GET_SYM(handle, "mlx_bartlett");
  if (mlx_bartlett_ptr == NULL) {
    fprintf(stderr, "MLX: Failed to load symbol: mlx_bartlett\n");
    return -1;
  }
  mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
  if (mlx_bitwise_and_ptr == NULL) {
    fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
@@ -2121,6 +2130,11 @@ int mlx_load_functions(void* handle) {
    fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
    return -1;
  }
  mlx_blackman_ptr = GET_SYM(handle, "mlx_blackman");
  if (mlx_blackman_ptr == NULL) {
    fprintf(stderr, "MLX: Failed to load symbol: mlx_blackman\n");
    return -1;
  }
  mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
  if (mlx_block_masked_mm_ptr == NULL) {
    fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
@@ -2381,6 +2395,16 @@ int mlx_load_functions(void* handle) {
    fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
    return -1;
  }
  mlx_hamming_ptr = GET_SYM(handle, "mlx_hamming");
  if (mlx_hamming_ptr == NULL) {
    fprintf(stderr, "MLX: Failed to load symbol: mlx_hamming\n");
    return -1;
  }
  mlx_hanning_ptr = GET_SYM(handle, "mlx_hanning");
  if (mlx_hanning_ptr == NULL) {
    fprintf(stderr, "MLX: Failed to load symbol: mlx_hanning\n");
    return -1;
  }
  mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
  if (mlx_identity_ptr == NULL) {
    fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
@@ -4132,12 +4156,12 @@ mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, i
  return mlx_distributed_group_split_ptr(group, color, key);
}

bool mlx_distributed_is_available(void) {
  return mlx_distributed_is_available_ptr();
bool mlx_distributed_is_available(const char* bk) {
  return mlx_distributed_is_available_ptr(bk);
}

mlx_distributed_group mlx_distributed_init(bool strict) {
  return mlx_distributed_init_ptr(strict);
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) {
  return mlx_distributed_init_ptr(strict, bk);
}

void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
@@ -4748,6 +4772,10 @@ int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
  return mlx_atleast_3d_ptr(res, a, s);
}

int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
  return mlx_bartlett_ptr(res, M, s);
}

int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
  return mlx_bitwise_and_ptr(res, a, b, s);
}
@@ -4764,6 +4792,10 @@ int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const
  return mlx_bitwise_xor_ptr(res, a, b, s);
}

int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
  return mlx_blackman_ptr(res, M, s);
}

int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
  return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
}
@@ -4864,8 +4896,8 @@ int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_
  return mlx_depends_ptr(res, inputs, dependencies);
}

int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) {
  return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s);
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) {
  return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
}

int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
@@ -4972,6 +5004,14 @@ int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float
  return mlx_hadamard_transform_ptr(res, a, scale, s);
}

int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
  return mlx_hamming_ptr(res, M, s);
}

int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
  return mlx_hanning_ptr(res, M, s);
}

int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
  return mlx_identity_ptr(res, n, dtype, s);
}
@@ -5208,12 +5248,12 @@ int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indice
  return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
}

int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
  return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s);
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) {
  return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
}

int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
  return mlx_quantize_ptr(res, w, group_size, bits, mode, s);
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) {
  return mlx_quantize_ptr(res, w, group_size, bits, mode, global_scale, s);
}

int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {

@@ -2125,7 +2125,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
	res := C.mlx_vector_array_new()
	C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
	var globalScale C.mlx_array
	C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, globalScale, C.default_stream())

	// Result is a vector of arrays: [weights, scales, biases?]
	// mxfp8 mode returns only 2 elements (no biases)
@@ -2161,7 +2162,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
	}

	res := C.mlx_array_new()
	C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream())
	var globalScale C.mlx_array
	C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, globalScale, optDtype, C.default_stream())
	return newArray(res)
}
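The regenerated shims thread the new `global_scale` parameter through quantize, dequantize, and qqmm; the Go wrappers above pass a zero-valued `mlx_array` handle, which the C layer appears to treat as "no global scale", keeping existing modes behaving as before. A hedged sketch of the unchanged caller-facing API — `w` is assumed to be an existing `*mlx.Array`, and this fragment is illustrative, not taken from the diff:

```go
// Quantize keeps its Go signature even though the underlying C call
// gained a parameter: the zero global-scale handle preserves the
// previous behavior for affine/mxfp8-style modes.
weights, scales, biases := mlx.Quantize(w, 32, 4, "affine")

// Dequantize round-trips with the same group size, bit width, and mode.
restored := mlx.Dequantize(weights, scales, biases, 32, 4, "affine")
_ = restored
```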
@@ -309,10 +309,12 @@
#undef mlx_atleast_1d
#undef mlx_atleast_2d
#undef mlx_atleast_3d
#undef mlx_bartlett
#undef mlx_bitwise_and
#undef mlx_bitwise_invert
#undef mlx_bitwise_or
#undef mlx_bitwise_xor
#undef mlx_blackman
#undef mlx_block_masked_mm
#undef mlx_broadcast_arrays
#undef mlx_broadcast_to
@@ -365,6 +367,8 @@
#undef mlx_greater
#undef mlx_greater_equal
#undef mlx_hadamard_transform
#undef mlx_hamming
#undef mlx_hanning
#undef mlx_identity
#undef mlx_imag
#undef mlx_inner
@@ -751,8 +755,8 @@ extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x,
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
extern bool (*mlx_distributed_is_available_ptr)(void);
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict);
extern bool (*mlx_distributed_is_available_ptr)(const char* bk);
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk);
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
@@ -905,10 +909,12 @@ extern int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype,
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
@@ -934,7 +940,7 @@ extern int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
@@ -961,6 +967,8 @@ extern int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_ar
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
extern int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
@@ -1020,8 +1028,8 @@ extern int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* ax
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -1492,9 +1500,9 @@ int mlx_distributed_group_size(mlx_distributed_group group);

mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);

bool mlx_distributed_is_available(void);
bool mlx_distributed_is_available(const char* bk);

mlx_distributed_group mlx_distributed_init(bool strict);
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk);

void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));

@@ -1800,6 +1808,8 @@ int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);

int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);

int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);

int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);

int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -1808,6 +1818,8 @@ int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const m

int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);

int mlx_blackman(mlx_array* res, int M, const mlx_stream s);

int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);

int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
@@ -1858,7 +1870,7 @@ int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s);

int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);

int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);

int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);

@@ -1912,6 +1924,10 @@ int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, cons

int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);

int mlx_hamming(mlx_array* res, int M, const mlx_stream s);

int mlx_hanning(mlx_array* res, int M, const mlx_stream s);

int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);

int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -2030,9 +2046,9 @@ int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream

int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);

int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);

int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);

int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
@@ -93,21 +93,8 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
		matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
	}

	// Check for partial match within a node's edge — truncate path
	// to the parent boundary. snapshot() will split the node and
	// create the branch point during prefill when caches are ready.
	partialMatch := false
	if len(matchPath) > 1 {
		lastNode := matchPath[len(matchPath)-1]
		matchedInEdge := matched - lastNode.startOffset()
		if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
			matchPath = matchPath[:len(matchPath)-1]
			partialMatch = true
		}
	}

	// Switch to the matched path, paging in/out as needed.
	c.switchToPath(matchPath)
	c.switchToPath(matchPath, matched)

	// switchToPath aligns caches to a common offset
	prefix := c.minCacheOffset()
@@ -116,7 +103,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
	// Schedule a snapshot at the branch point during prefill so future
	// requests diverging here can restore instead of re-evaluating.
	var snapshotAt int
	if partialMatch || (prefix == 0 && matched > 0) {
	if prefix < matched {
		snapshotAt = matched
	}

@@ -142,7 +129,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {

// switchToPath transitions from the current active path to a new path,
// paging out diverging segments and paging in the new path.
func (c *kvCache) switchToPath(newPath []*trieNode) {
func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
	defer c.enforceEvictionPolicy()

	// Find common ancestor index.
@@ -167,7 +154,10 @@ func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
	// non-leaf nodes here would produce wrong results for non-rewindable
	// caches (e.g. RecurrentCache) whose state reflects the leaf, not
	// the intermediate boundary.
	if leaf := len(c.activePath) - 1; leaf >= commonLen {
	leaf := len(c.activePath) - 1
	leafDiverges := leaf >= commonLen
	leafNeedsRewind := matched < c.activePath[leaf].endOffset
	if leafDiverges || leafNeedsRewind {
		node := c.activePath[leaf]
		if !node.hasAllSnapshots() {
			fromOffset := node.startOffset()
@@ -184,14 +174,16 @@ func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
		}
	}

	// Rewind each cache to the ancestor offset or free it. Freed
	// caches (e.g. RecurrentCache that can't rewind) will be restored
	// from snapshots during page-in.
	// Rewind each cache to the target offset or free it. When matched
	// falls within the ancestor's range (same-path case), we rewind
	// directly to the match point. Otherwise we rewind to the ancestor
	// and let page-in bring us forward to matched.
	rewindTarget := min(ancestorOffset, matched)
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		if !kv.Restore(nil, ancestorOffset) {
		if !kv.Restore(nil, rewindTarget) {
			kv.Free()
		}
	}
@@ -199,10 +191,12 @@ func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
	// Page in — walk the full new path, restoring from snapshots.
	// Freed caches naturally pick up the first available snapshot.
	// Caches already past a node skip it via offset check.
pageIn:
	for _, node := range newPath {
		if len(node.snapshots) == 0 {
		if !node.hasSnapshots() {
			continue
		}
		nodeTarget := min(node.endOffset, matched)
		for j, kv := range c.caches {
			if kv == nil {
				continue
@@ -210,19 +204,18 @@ func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
			if j >= len(node.snapshots) || node.snapshots[j] == nil {
				continue
			}
			if kv.Offset() >= node.endOffset {
			if kv.Offset() >= nodeTarget {
				continue
			}
			if !kv.Restore(node.snapshots[j], node.endOffset) {
				slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset())
				c.freeAll()
				c.activePath = []*trieNode{c.root}
				return
			if !kv.Restore(node.snapshots[j], nodeTarget) {
				// Restore failed — stop page-in and let alignment
				// bring all caches to a consistent offset.
				break pageIn
			}
		}
		if node.endOffset > ancestorOffset {
			pageInCount++
			logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset))
			logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget))
		}
	}

@@ -536,6 +529,9 @@ func (c *kvCache) dumpTree() {
	if nodeBytes > 0 {
		label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
	}
	if !n.lastUsed.IsZero() {
		label += fmt.Sprintf(" %s ago", time.Since(n.lastUsed).Truncate(time.Millisecond))
	}
	var flags []string
	if n.user {
		flags = append(flags, "user")
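The net effect of threading `matched` through `switchToPath`: the rewind target is the smaller of the ancestor offset and the match point, so a rewindable cache lands exactly on the reusable prefix instead of being freed. A worked fragment with hypothetical numbers:

```go
// Hypothetical values illustrating the rewindTarget computation above.
ancestorOffset := 8 // old and new paths share nodes covering [0, 8)
matched := 5        // the new request matches only the first 5 tokens
rewindTarget := min(ancestorOffset, matched) // = 5

// A fully rewindable KVCache satisfies Restore(nil, 5), so prefill only
// re-evaluates tokens [5, n). A RecurrentCache returns false, is freed,
// and is rebuilt from the nearest snapshot during page-in.
_ = rewindTarget
```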
x/mlxrunner/cache/cache.go
@@ -17,7 +17,8 @@ type Cache interface {
	Snapshot(fromOffset int) Snapshot

	// Restore brings the cache to target. If snapshot is nil, rewinds
	// using the cache's own live state.
	// using the cache's own live state. Returns false if the target is
	// unreachable (e.g. target > current offset, or negative).
	Restore(snapshot Snapshot, target int) bool

	// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
@@ -122,17 +123,21 @@ func (c *KVCache) Snapshot(fromOffset int) Snapshot {
}

func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
	if target < 0 {
		return false
	}

	if snapshot == nil {
		// Rewind using live state — just clamp offset.
		target = max(0, min(target, c.offset))
		if target > c.offset {
			return false
		}
		c.offset = target
		return true
	}

	snap := snapshot.(*kvSnapshot)

	// Check that the cache has data up to the snapshot's starting point.
	if c.offset < snap.fromOffset {
	if target > snap.toOffset || c.offset < snap.fromOffset {
		return false
	}

@@ -354,7 +359,14 @@ func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
}

func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
	if target < 0 {
		return false
	}

	if snapshot == nil {
		if target >= c.offset {
			return target == c.offset
		}
		// Live rewind is only safe when the buffer hasn't filled yet
		// (offset <= maxSize). Once the window has shifted, rewinding
		// leaves fewer than maxSize trailing tokens to attend to —
@@ -362,7 +374,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
		if c.offset > c.maxSize {
			return false
		}
		target = max(0, min(target, c.offset))
		c.offset = target
		c.idx = target
		return true
@@ -370,6 +381,10 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {

	snap := snapshot.(*rotatingSnapshot)

	if target > snap.toOffset {
		return false
	}

	// Reject if clamping would leave an incomplete window.
	if target < snap.toOffset && snap.toOffset > c.maxSize {
		return false
@@ -388,7 +403,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {

	// Clamp to target if needed.
	if target < c.offset {
		target = max(0, target)
		c.offset = target
		c.idx = target
	}
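Across the implementations, `Restore` returning false uniformly means "target unreachable from this state" rather than an error: KVCache accepts any backward rewind, RotatingKVCache only before its window wraps, and RecurrentCache (below) only an exact snapshot offset. A toy model of the nil-snapshot rewind path, patterned on the fake caches used in the tests later in this diff:

```go
package main

import "fmt"

// toyKV models the KVCache nil-snapshot branch above: a pure rewind that
// succeeds only for 0 <= target <= current offset.
type toyKV struct{ offset int }

func (c *toyKV) Restore(target int) bool {
	if target < 0 || target > c.offset {
		return false
	}
	c.offset = target
	return true
}

func main() {
	c := &toyKV{offset: 10}
	fmt.Println(c.Restore(4)) // true — rewound to 4
	fmt.Println(c.Restore(9)) // false — can't advance without new tokens
	fmt.Println(c.offset)     // 4
}
```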
x/mlxrunner/cache/recurrent.go
@@ -22,14 +22,9 @@ func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
	if v == nil || !v.Valid() {
		return old
	}
	if old == v {
		return old
	}

	mlx.Pin(v)
	if old != nil && old != v {
		mlx.Unpin(old)
	}
	mlx.Unpin(old)

	return v
}
@@ -38,9 +33,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
	if v == nil || !v.Valid() {
		return old
	}
	if old == v {
		return old
	}

	root := v
	if ensureContiguous {
@@ -49,9 +41,7 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
	detached := root.Clone()

	mlx.Pin(detached)
	if old != nil && old != detached {
		mlx.Unpin(old)
	}
	mlx.Unpin(old)

	return detached
}
@@ -150,10 +140,10 @@ func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {

	snap := snapshot.(*recurrentSnapshot)

	// Recurrent state encodes all tokens up to snap.offset. Restoring
	// to a target before that would leave stale state from tokens
	// [target, snap.offset) baked in. Only allow restoring forward.
	if target < snap.offset {
	// Recurrent snapshots encode cumulative state up to exactly
	// snap.offset. Target must match — rewinding would leave stale
	// state, and advancing isn't possible without feeding tokens.
	if target != snap.offset {
		return false
	}
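The deleted equality checks are safe to drop because the handoff pins the incoming array before unpinning the old one; when both are the same array, the count nets out to zero instead of transiently dropping the last reference. A toy counter demonstrating the ordering (not the real mlx package):

```go
package main

import "fmt"

type arr struct{ pinned int }

func pin(a *arr) { a.pinned++ }

func unpin(a *arr) {
	if a == nil {
		return
	}
	a.pinned--
	if a.pinned < 0 {
		panic("negative pin count") // mirrors the new guard in mlx.Unpin
	}
}

func main() {
	state := &arr{pinned: 1} // currently held by the cache
	incoming := state        // setState called with the same array

	// Pin-then-unpin keeps the count at 1 throughout; the reverse order
	// would briefly hit zero and could free the array mid-handoff.
	pin(incoming)
	unpin(state)
	fmt.Println(incoming.pinned) // 1
}
```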
x/mlxrunner/cache/recurrent_test.go
@@ -6,39 +6,35 @@ import (
	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only
// allows restoring forward (target >= snapshot offset), never backward.
func TestRecurrentCacheRestoreDirectionality(t *testing.T) {
// TestRecurrentCacheRestoreExactOffset verifies that RecurrentCache restore
// only succeeds when target exactly matches the snapshot's offset. Recurrent
// state is cumulative, so it can't be rewound or fast-forwarded.
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
	skipIfNoMLX(t)
	c := NewRecurrentCache(3, 12, 4, 8, 8)
	_ = c.ConvState(1, mlx.DTypeFloat16)
	_ = c.DeltaState(1, mlx.DTypeFloat16)
	c.Advance(10)

	snap := c.Snapshot(0)
	snap := c.Snapshot(0) // snap.offset == 10

	c.Advance(5) // now at 15
	c.Advance(5) // cache now at 15

	// Restore backward should fail.
	// target < snap.offset: fails (can't rewind past snapshot)
	if c.Restore(snap, 5) {
		t.Fatal("Restore(snap, 5) should fail — target < snap.offset")
		t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
	}

	// Restore to exact snap offset should succeed.
	// target > snap.offset: fails (can't advance without feeding tokens)
	if c.Restore(snap, 15) {
		t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
	}

	// target == snap.offset: succeeds
	if !c.Restore(snap, 10) {
		t.Fatal("Restore(snap, 10) should succeed")
		t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
	}
	if c.Offset() != 10 {
		t.Fatalf("offset = %d, want 10", c.Offset())
	}

	// Restore forward (target > snap offset) should succeed, offset = snap.offset.
	snap2 := c.Snapshot(0)
	if !c.Restore(snap2, 15) {
		t.Fatal("Restore(snap, 15) should succeed")
	}
	// Recurrent state is at snap.offset (10), not target (15).
	if c.Offset() != 10 {
		t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset())
	}
}
@@ -79,20 +79,20 @@ func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
}
|
||||
|
||||
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||
if target < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
// Rewind live state.
|
||||
if target < 0 {
|
||||
target = 0
|
||||
}
|
||||
if target > len(c.tokens) {
|
||||
target = len(c.tokens)
|
||||
return false
|
||||
}
|
||||
c.tokens = c.tokens[:target]
|
||||
return true
|
||||
}
|
||||
s := snapshot.(*fakeSnapshot)
|
||||
if len(c.tokens) < s.from {
|
||||
return false // don't have base data up to snapshot start
|
||||
if target > s.to || len(c.tokens) < s.from {
|
||||
return false
|
||||
}
|
||||
c.tokens = append(c.tokens[:s.from], s.tokens...)
|
||||
if target < len(c.tokens) {
|
||||
@@ -196,9 +196,13 @@ func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
}
|
||||
|
||||
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||
if target < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
if target == len(c.tokens) {
|
||||
return true
|
||||
if target >= len(c.tokens) {
|
||||
return target == len(c.tokens)
|
||||
}
|
||||
// Live rewind only works when buffer hasn't filled (offset <= maxSize).
|
||||
if len(c.tokens) > c.maxSize {
|
||||
@@ -208,6 +212,14 @@ func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bo
|
||||
return true
|
||||
}
|
||||
s := snapshot.(*fakeSnapshot)
|
||||
if target > s.to {
|
||||
return false
|
||||
}
|
||||
// Reject if clamping would leave an incomplete window
|
||||
// (matches RotatingKVCache behavior).
|
||||
if target < s.to && s.to > c.maxSize {
|
||||
return false
|
||||
}
|
||||
c.tokens = slices.Clone(s.tokens)
|
||||
if target < len(c.tokens) {
|
||||
c.tokens = c.tokens[:target]
|
||||
@@ -268,8 +280,8 @@ func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||
return target == len(c.tokens) // can only no-op
|
||||
}
|
||||
s := snapshot.(*fakeSnapshot)
|
||||
if target < s.to {
|
||||
return false // can't go backward
|
||||
if target != s.to {
|
||||
return false // cumulative state requires exact match
|
||||
}
|
||||
c.tokens = slices.Clone(s.tokens)
|
||||
return true
|
||||
@@ -294,9 +306,10 @@ type feedableCache interface {
|
||||
|
||||
// testEnv encapsulates a kvCache and its fake caches for a test scenario.
|
||||
type testEnv struct {
|
||||
kvc *kvCache
|
||||
caches []cache.Cache // typed references for assertions
|
||||
tracker *snapshotTracker
|
||||
kvc *kvCache
|
||||
caches []cache.Cache // typed references for assertions
|
||||
tracker *snapshotTracker
|
||||
rewindable bool // true when all caches support arbitrary Restore(nil, target)
|
||||
}
|
||||
|
||||
// newTransformerEnv creates a test environment with a single rewindable cache
|
||||
@@ -305,23 +318,28 @@ func newTransformerEnv() *testEnv {
|
||||
tracker := &snapshotTracker{}
|
||||
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
|
||||
return &testEnv{
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tracker,
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tracker,
|
||||
rewindable: true,
|
||||
}
|
||||
}
|
||||
|
||||
// newSlidingWindowEnv creates a test environment with one rewindable cache and
|
||||
// one sliding window cache (Mistral-style architecture).
|
||||
// one sliding window cache (Mistral-style architecture). The sliding window
|
||||
// maxSize is set small enough that test sequences fill it, making
|
||||
// Restore(nil, target) fail — the same behavior as production models where
|
||||
// the window fills after a few turns.
|
||||
func newSlidingWindowEnv() *testEnv {
|
||||
tr := &snapshotTracker{}
|
||||
rc := &fakeRewindableCache{tracker: tr}
|
||||
sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr}
|
||||
sw := &fakeSlidingWindowCache{maxSize: 4, tracker: tr}
|
||||
caches := []cache.Cache{rc, sw}
|
||||
return &testEnv{
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tr,
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tr,
|
||||
rewindable: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -333,9 +351,10 @@ func newRecurrentEnv() *testEnv {
|
||||
nrc := &fakeRecurrentCache{tracker: tr}
|
||||
caches := []cache.Cache{rc, nrc}
|
||||
return &testEnv{
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tr,
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tr,
|
||||
rewindable: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -590,15 +609,24 @@ func TestBranchCreationAndReuse(t *testing.T) {
|
||||
}
|
||||
|
||||
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
|
||||
// Partial match in A's edge triggers snapshotOffset.
|
||||
// For rewindable caches, switchToPath rewinds to the match point
|
||||
// so only the non-matching suffix needs evaluation. For non-rewindable
|
||||
// caches (RecurrentCache), the rewind fails and freeAll fires.
|
||||
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
|
||||
if resB.snapshotOffset != 5 {
|
||||
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
|
||||
}
|
||||
// Cache was rewound to 0 (partial match truncates path to root),
|
||||
// so all tokens were re-evaluated.
|
||||
if len(resB.remaining) != 8 {
|
||||
t.Fatalf("B: remaining = %d, want 8", len(resB.remaining))
|
||||
if env.rewindable {
|
||||
if resB.snapshotOffset != 0 {
|
||||
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
|
||||
}
|
||||
if len(resB.remaining) != 3 {
|
||||
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
|
||||
}
|
||||
} else {
|
||||
if resB.snapshotOffset != 5 {
|
||||
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
|
||||
}
|
||||
if len(resB.remaining) != 8 {
|
||||
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
|
||||
}
|
||||
}
|
||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
|
||||
|
||||
@@ -635,14 +663,24 @@ func TestExactMatchSeedBehavior(t *testing.T) {
	simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})

	// Request B: identical prompt. Holdback means matched=4, partial in
	// the 5-token edge, so path truncates to root and all tokens are
	// re-evaluated. snapshotOffset should be set at the holdback point.
	// the 5-token edge. For rewindable caches, switchToPath rewinds to
	// offset 4, so only the held-back token needs re-evaluation. For
	// non-rewindable caches, the rewind fails and freeAll fires.
	resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
	if len(resB.remaining) != 5 {
		t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining))
	}
	if resB.snapshotOffset != 4 {
		t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
	if env.rewindable {
		if len(resB.remaining) != 1 {
			t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
		}
		if resB.snapshotOffset != 0 {
			t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
		}
	} else {
		if len(resB.remaining) != 5 {
			t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
		}
		if resB.snapshotOffset != 4 {
			t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
		}
	}
	env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})

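The two tests above pivot on a single decision, sketched below with hypothetical names (`rewinder`, `switchToPath`); this is an illustration of the behavior under test, not the actual kvCache code:

```go
// rewinder is the capability probed at branch-switch time; sliding-window
// and recurrent caches fail it once their window or state can't be restored.
type rewinder interface {
	Rewind(offset int) error
}

// switchToPath returns the position from which tokens must be re-evaluated.
func switchToPath(caches []rewinder, matchOffset int) int {
	for _, c := range caches {
		if c == nil || c.Rewind(matchOffset) != nil {
			return 0 // freeAll: one failed rewind forces a full re-eval
		}
	}
	return matchOffset // all caches rewound; only the suffix runs again
}
```

With a 5-token prefix match this yields `remaining = 3` on the rewindable path, and with one held-back token `remaining = 1`, matching the assertions above.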
@@ -230,6 +230,9 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f

	resp, err := c.client.Do(httpReq)
	if err != nil {
		if errMsg := c.status.getLastErr(); errMsg != "" {
			return fmt.Errorf("mlx runner failed: %s", errMsg)
		}
		return err
	}
	defer resp.Body.Close()
@@ -267,7 +270,13 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
		}
	}

	return scanner.Err()
	if err := scanner.Err(); err != nil {
		if errMsg := c.status.getLastErr(); errMsg != "" {
			return fmt.Errorf("mlx runner failed: %s", errMsg)
		}
		return err
	}
	return nil
}

func (c *Client) ContextLength() int {

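Both failure paths now consult `c.status.getLastErr()`. A minimal sketch of the kind of status tracker assumed here (hypothetical names; the real one would watch the runner's output):

```go
// statusWriter remembers the last error-looking line the runner printed,
// so an opaque transport failure can be replaced with the real cause.
type statusWriter struct {
	mu      sync.Mutex
	lastErr string
}

func (s *statusWriter) Write(p []byte) (int, error) {
	if line := strings.TrimSpace(string(p)); strings.Contains(strings.ToLower(line), "error") {
		s.mu.Lock()
		s.lastErr = line
		s.mu.Unlock()
	}
	return len(p), nil
}

func (s *statusWriter) getLastErr() string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.lastErr
}
```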
@@ -15,7 +15,9 @@ set(CMAKE_INSTALL_RPATH "@loader_path")

include(FetchContent)

set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
# Read MLX-C version from top-level file (shared with imagegen CMakeLists)
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)

FetchContent_Declare(
  mlx-c

@@ -137,6 +137,9 @@ func Unpin(s ...*Array) {
	for _, t := range s {
		if t != nil {
			t.pinned--
			if t.pinned < 0 {
				panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
			}
		}
	}
}
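The new guard turns an unbalanced Unpin into an immediate panic instead of a silently corrupted pin count. A usage sketch (assuming `Pin` is the variadic mirror of `Unpin`, which this hunk does not show):

```go
// withPinned keeps arrs alive across fn and releases them exactly once.
func withPinned(arrs []*mlx.Array, fn func()) {
	mlx.Pin(arrs...)
	defer mlx.Unpin(arrs...) // a second Unpin of the same arrays would now panic
	fn()
}
```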
@@ -261,7 +264,7 @@ func LogArrays() {

	for _, t := range arrays {
		nb := t.NumBytes()
		logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims()))
		logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
	}
	logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
}

@@ -13,6 +13,10 @@ var (
	gatedDeltaMetalKernelOnce sync.Once
	gatedDeltaMetalKernel     C.mlx_fast_metal_kernel
	gatedDeltaMetalDisabled   bool

	gatedDeltaCUDAKernelOnce sync.Once
	gatedDeltaCUDAKernel     C.mlx_fast_cuda_kernel
	gatedDeltaCUDADisabled   bool
)

const gatedDeltaMetalKernelSource = `
@@ -83,6 +87,86 @@ for (int i = 0; i < n_per_t; ++i) {
}
`

const gatedDeltaCUDAKernelSource = `
auto tid_x = threadIdx.x;
auto tid_y = threadIdx.y;
auto grid_y = blockIdx.y * blockDim.y + tid_y;
auto grid_z = blockIdx.z;

int T_val = static_cast<int>(*T);

auto n = grid_z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;

// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk;

// v, y: [B, T, Hv, Dv]
auto dv_idx = grid_y;
auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv;
y += b_idx * T_val * Hv * Dv + hv_idx * Dv;

auto dk_idx = tid_x;

// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;

float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
  auto s_idx = n_per_t * dk_idx + i;
  state[i] = static_cast<float>(i_state[s_idx]);
}

// g: [B, T, Hv]
auto g_ = g + b_idx * T_val * Hv;
auto beta_ = beta + b_idx * T_val * Hv;

for (int t = 0; t < T_val; ++t) {
  float kv_mem = 0.0f;
  for (int i = 0; i < n_per_t; ++i) {
    auto s_idx = n_per_t * dk_idx + i;
    state[i] = state[i] * static_cast<float>(g_[hv_idx]);
    kv_mem += state[i] * static_cast<float>(k_[s_idx]);
  }
  // Warp reduction (full warp, 32 threads in x)
  for (int offset = 16; offset > 0; offset >>= 1)
    kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset);
  kv_mem = __shfl_sync(0xffffffff, kv_mem, 0);

  auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]);

  float out = 0.0f;
  for (int i = 0; i < n_per_t; ++i) {
    auto s_idx = n_per_t * dk_idx + i;
    state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta;
    out += state[i] * static_cast<float>(q_[s_idx]);
  }
  // Warp reduction
  for (int offset = 16; offset > 0; offset >>= 1)
    out += __shfl_down_sync(0xffffffff, out, offset);
  if (tid_x == 0) {
    y[dv_idx] = static_cast<InT>(out);
  }

  q_ += Hk * Dk;
  k_ += Hk * Dk;
  v_ += Hv * Dv;
  y += Hv * Dv;
  g_ += Hv;
  beta_ += Hv;
}

for (int i = 0; i < n_per_t; ++i) {
  auto s_idx = n_per_t * dk_idx + i;
  o_state[s_idx] = static_cast<InT>(state[i]);
}
`

func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
	vec := C.mlx_vector_string_new()
	ok := true
@@ -352,11 +436,184 @@ func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
	return Concatenate(outs, 1), nextState
}

func initGatedDeltaCUDAKernel() {
	var cudaAvail C.bool
	if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) {
		gatedDeltaCUDADisabled = true
		return
	}

	inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
	if !ok {
		gatedDeltaCUDADisabled = true
		freeInputs()
		return
	}
	defer freeInputs()

	outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
	if !ok {
		gatedDeltaCUDADisabled = true
		freeOutputs()
		return
	}
	defer freeOutputs()

	cName := C.CString("gated_delta_step")
	defer C.free(unsafe.Pointer(cName))
	cSource := C.CString(gatedDeltaCUDAKernelSource)
	defer C.free(unsafe.Pointer(cSource))
	cHeader := C.CString("")
	defer C.free(unsafe.Pointer(cHeader))

	gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new(
		cName,
		inputs,
		outputs,
		cSource,
		cHeader,
		C.bool(true),
		C.int(0),
	)
}

func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
	if gatedDeltaCUDADisabled {
		return nil, nil, false
	}
	if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
		return nil, nil, false
	}

	qd := q.Dims()
	kd := k.Dims()
	vd := v.Dims()
	gd := g.Dims()
	bd := beta.Dims()
	sd := state.Dims()
	if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
		return nil, nil, false
	}

	B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
	if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
		return nil, nil, false
	}
	if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
		return nil, nil, false
	}
	Hv, Dv := vd[2], vd[3]
	if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
		return nil, nil, false
	}
	if gd[0] != B || gd[1] != T || gd[2] != Hv {
		return nil, nil, false
	}
	if bd[0] != B || bd[1] != T || bd[2] != Hv {
		return nil, nil, false
	}
	if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
		return nil, nil, false
	}

	dtype := q.DType()
	if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
		return nil, nil, false
	}

	gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
	if gatedDeltaCUDADisabled {
		return nil, nil, false
	}

	cfg := C.mlx_fast_cuda_kernel_config_new()
	defer C.mlx_fast_cuda_kernel_config_free(cfg)

	cInT := C.CString("InT")
	defer C.free(unsafe.Pointer(cInT))
	if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
		gatedDeltaCUDADisabled = true
		return nil, nil, false
	}
	for _, tpl := range []struct {
		name  string
		value int
	}{
		{name: "Dk", value: Dk},
		{name: "Dv", value: Dv},
		{name: "Hk", value: Hk},
		{name: "Hv", value: Hv},
	} {
		cn := C.CString(tpl.name)
		rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
		C.free(unsafe.Pointer(cn))
		if rc != 0 {
			gatedDeltaCUDADisabled = true
			return nil, nil, false
		}
	}

	yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
	stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
	if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
		gatedDeltaCUDADisabled = true
		return nil, nil, false
	}
	if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
		gatedDeltaCUDADisabled = true
		return nil, nil, false
	}
	if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
		gatedDeltaCUDADisabled = true
		return nil, nil, false
	}
	threadY := Dv
	if threadY > 4 {
		threadY = 4
	}
	if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
		gatedDeltaCUDADisabled = true
		return nil, nil, false
	}

	tScalar := FromValue(T)
	inputs := []C.mlx_array{
		q.ctx,
		k.ctx,
		v.ctx,
		g.ctx,
		beta.ctx,
		state.ctx,
		tScalar.ctx,
	}
	inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
	defer C.mlx_vector_array_free(inVec)

	outVec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(outVec)
	if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, DefaultStream().ctx) != 0 {
		gatedDeltaCUDADisabled = true
		return nil, nil, false
	}
	if int(C.mlx_vector_array_size(outVec)) < 2 {
		return nil, nil, false
	}

	y = New("GATED_DELTA_CUDA_Y")
	nextState = New("GATED_DELTA_CUDA_STATE")
	C.mlx_vector_array_get(&y.ctx, outVec, 0)
	C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
	return y, nextState, true
}

// GatedDelta runs the recurrent update operation.
//
// It uses the fused Metal kernel when available and otherwise falls back to a
// It tries the fused CUDA kernel first, then Metal, then falls back to a
// backend-agnostic MLX implementation with identical inputs/outputs.
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
	if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok {
		return y, nextState
	}
	if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
		return y, nextState
	}

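A hedged usage sketch of the cascade above, documenting the shape contract that `gatedDeltaCUDAKernelApply` checks (function name and the idea of a wrapper are illustrative):

```go
// q, k:    [B, T, Hk, Dk]   with Dk % 32 == 0 for the CUDA fast path
// v:       [B, T, Hv, Dv]   with Hv % Hk == 0
// g, beta: [B, T, Hv]
// state:   [B, Hv, Dv, Dk]
func gatedDeltaStep(q, k, v, g, beta, state *mlx.Array) (*mlx.Array, *mlx.Array) {
	// Falls back CUDA -> Metal -> generic MLX ops; outputs are identical.
	y, next := mlx.GatedDelta(q, k, v, g, beta, state)
	mlx.Eval(y, next)
	return y, next
}
```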
@@ -326,8 +326,10 @@ int (*mlx_distributed_sum_scatter_)(
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(
    bool strict,
    const char* bk /* may be null */) = NULL;
void (*mlx_set_error_handler_)(
    mlx_error_handler_func handler,
    void* data,
@@ -924,6 +926,7 @@ int (*mlx_astype_)(
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_bitwise_and_)(
    mlx_array* res,
    const mlx_array a,
@@ -940,6 +943,7 @@ int (*mlx_bitwise_xor_)(
    const mlx_array a,
    const mlx_array b,
    const mlx_stream s) = NULL;
int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_block_masked_mm_)(
    mlx_array* res,
    const mlx_array a,
@@ -1120,6 +1124,7 @@ int (*mlx_dequantize_)(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale /* may be null */,
    mlx_optional_dtype dtype,
    const mlx_stream s) = NULL;
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
@@ -1256,6 +1261,8 @@ int (*mlx_hadamard_transform_)(
    const mlx_array a,
    mlx_optional_float scale,
    const mlx_stream s) = NULL;
int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_inner_)(
@@ -1548,6 +1555,8 @@ int (*mlx_qqmm_)(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale_x /* may be null */,
    const mlx_array global_scale_w /* may be null */,
    const mlx_stream s) = NULL;
int (*mlx_quantize_)(
    mlx_vector_array* res,
@@ -1555,6 +1564,7 @@ int (*mlx_quantize_)(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale /* may be null */,
    const mlx_stream s) = NULL;
int (*mlx_quantized_matmul_)(
    mlx_array* res,
@@ -2550,10 +2560,12 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
  CHECK_LOAD(handle, mlx_atleast_1d);
  CHECK_LOAD(handle, mlx_atleast_2d);
  CHECK_LOAD(handle, mlx_atleast_3d);
  CHECK_LOAD(handle, mlx_bartlett);
  CHECK_LOAD(handle, mlx_bitwise_and);
  CHECK_LOAD(handle, mlx_bitwise_invert);
  CHECK_LOAD(handle, mlx_bitwise_or);
  CHECK_LOAD(handle, mlx_bitwise_xor);
  CHECK_LOAD(handle, mlx_blackman);
  CHECK_LOAD(handle, mlx_block_masked_mm);
  CHECK_LOAD(handle, mlx_broadcast_arrays);
  CHECK_LOAD(handle, mlx_broadcast_to);
@@ -2606,6 +2618,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
  CHECK_LOAD(handle, mlx_greater);
  CHECK_LOAD(handle, mlx_greater_equal);
  CHECK_LOAD(handle, mlx_hadamard_transform);
  CHECK_LOAD(handle, mlx_hamming);
  CHECK_LOAD(handle, mlx_hanning);
  CHECK_LOAD(handle, mlx_identity);
  CHECK_LOAD(handle, mlx_imag);
  CHECK_LOAD(handle, mlx_inner);

@@ -300,10 +300,12 @@
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
#define mlx_bartlett mlx_bartlett_mlx_gen_orig_
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
#define mlx_blackman mlx_blackman_mlx_gen_orig_
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
@@ -356,6 +358,8 @@
#define mlx_greater mlx_greater_mlx_gen_orig_
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
#define mlx_hamming mlx_hamming_mlx_gen_orig_
#define mlx_hanning mlx_hanning_mlx_gen_orig_
#define mlx_identity mlx_identity_mlx_gen_orig_
#define mlx_imag mlx_imag_mlx_gen_orig_
#define mlx_inner mlx_inner_mlx_gen_orig_
@@ -889,10 +893,12 @@
#undef mlx_atleast_1d
#undef mlx_atleast_2d
#undef mlx_atleast_3d
#undef mlx_bartlett
#undef mlx_bitwise_and
#undef mlx_bitwise_invert
#undef mlx_bitwise_or
#undef mlx_bitwise_xor
#undef mlx_blackman
#undef mlx_block_masked_mm
#undef mlx_broadcast_arrays
#undef mlx_broadcast_to
@@ -945,6 +951,8 @@
#undef mlx_greater
#undef mlx_greater_equal
#undef mlx_hadamard_transform
#undef mlx_hamming
#undef mlx_hanning
#undef mlx_identity
#undef mlx_imag
#undef mlx_inner
@@ -1501,8 +1509,10 @@ extern int (*mlx_distributed_sum_scatter_)(
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
extern bool (*mlx_distributed_is_available_)(void);
extern mlx_distributed_group (*mlx_distributed_init_)(bool strict);
extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */);
extern mlx_distributed_group (*mlx_distributed_init_)(
    bool strict,
    const char* bk /* may be null */);
extern void (*mlx_set_error_handler_)(
    mlx_error_handler_func handler,
    void* data,
@@ -2099,6 +2109,7 @@ extern int (*mlx_astype_)(
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_bitwise_and_)(
    mlx_array* res,
    const mlx_array a,
@@ -2115,6 +2126,7 @@ extern int (*mlx_bitwise_xor_)(
    const mlx_array a,
    const mlx_array b,
    const mlx_stream s);
extern int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_block_masked_mm_)(
    mlx_array* res,
    const mlx_array a,
@@ -2295,6 +2307,7 @@ extern int (*mlx_dequantize_)(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale /* may be null */,
    mlx_optional_dtype dtype,
    const mlx_stream s);
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -2431,6 +2444,8 @@ extern int (*mlx_hadamard_transform_)(
    const mlx_array a,
    mlx_optional_float scale,
    const mlx_stream s);
extern int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_inner_)(
@@ -2723,6 +2738,8 @@ extern int (*mlx_qqmm_)(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale_x /* may be null */,
    const mlx_array global_scale_w /* may be null */,
    const mlx_stream s);
extern int (*mlx_quantize_)(
    mlx_vector_array* res,
@@ -2730,6 +2747,7 @@ extern int (*mlx_quantize_)(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale /* may be null */,
    const mlx_stream s);
extern int (*mlx_quantized_matmul_)(
    mlx_array* res,
@@ -4033,11 +4051,13 @@ static inline int mlx_distributed_group_size(mlx_distributed_group group) {
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
  return mlx_distributed_group_split_(group, color, key);
}
static inline bool mlx_distributed_is_available(void) {
  return mlx_distributed_is_available_();
static inline bool mlx_distributed_is_available(const char* bk /* may be null */) {
  return mlx_distributed_is_available_(bk);
}
static inline mlx_distributed_group mlx_distributed_init(bool strict) {
  return mlx_distributed_init_(strict);
static inline mlx_distributed_group mlx_distributed_init(
    bool strict,
    const char* bk /* may be null */) {
  return mlx_distributed_init_(strict, bk);
}
static inline void mlx_set_error_handler(
    mlx_error_handler_func handler,
@@ -4939,6 +4959,9 @@ static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_st
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
  return mlx_atleast_3d_(res, a, s);
}
static inline int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
  return mlx_bartlett_(res, M, s);
}
static inline int mlx_bitwise_and(
    mlx_array* res,
    const mlx_array a,
@@ -4963,6 +4986,9 @@ static inline int mlx_bitwise_xor(
    const mlx_stream s) {
  return mlx_bitwise_xor_(res, a, b, s);
}
static inline int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
  return mlx_blackman_(res, M, s);
}
static inline int mlx_block_masked_mm(
    mlx_array* res,
    const mlx_array a,
@@ -5193,9 +5219,10 @@ static inline int mlx_dequantize(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale /* may be null */,
    mlx_optional_dtype dtype,
    const mlx_stream s) {
  return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s);
  return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
}
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
  return mlx_diag_(res, a, k, s);
@@ -5383,6 +5410,12 @@ static inline int mlx_hadamard_transform(
    const mlx_stream s) {
  return mlx_hadamard_transform_(res, a, scale, s);
}
static inline int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
  return mlx_hamming_(res, M, s);
}
static inline int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
  return mlx_hanning_(res, M, s);
}
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
  return mlx_identity_(res, n, dtype, s);
}
@@ -5793,8 +5826,10 @@ static inline int mlx_qqmm(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale_x /* may be null */,
    const mlx_array global_scale_w /* may be null */,
    const mlx_stream s) {
  return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s);
  return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
}
static inline int mlx_quantize(
    mlx_vector_array* res,
@@ -5802,8 +5837,9 @@ static inline int mlx_quantize(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale /* may be null */,
    const mlx_stream s) {
  return mlx_quantize_(res, w, group_size, bits, mode, s);
  return mlx_quantize_(res, w, group_size, bits, mode, global_scale, s);
}
static inline int mlx_quantized_matmul(
    mlx_array* res,

@@ -1,7 +1,7 @@
# Vendored MLX-C Headers

These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
The pinned version is in `MLX_VERSION` at the repo root.
The pinned version is in `MLX_C_VERSION` at the repo root.

Headers are automatically refreshed when you run a CMake build:

@@ -42,12 +42,14 @@ mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
/**
 * Check if distributed is available.
 */
bool mlx_distributed_is_available(void);
bool mlx_distributed_is_available(const char* bk /* may be null */);

/**
 * Initialize distributed.
 */
mlx_distributed_group mlx_distributed_init(bool strict);
mlx_distributed_group mlx_distributed_init(
    bool strict,
    const char* bk /* may be null */);

/**@}*/

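The header change above threads an optional backend name through the availability probe. A hedged cgo sketch of calling it from Go (the wrapper name and the `"ring"` backend string are illustrative assumptions, not part of this change):

```go
// distributedAvailable reports whether the named distributed backend can be
// used; pass "" (NULL on the C side) to ask about any backend, e.g. "ring".
func distributedAvailable(backend string) bool {
	var cbk *C.char
	if backend != "" {
		cbk = C.CString(backend)
		defer C.free(unsafe.Pointer(cbk))
	}
	return bool(C.mlx_distributed_is_available(cbk))
}
```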
@@ -166,6 +166,7 @@ int mlx_astype(
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
int mlx_bitwise_and(
    mlx_array* res,
    const mlx_array a,
@@ -182,6 +183,7 @@ int mlx_bitwise_xor(
    const mlx_array a,
    const mlx_array b,
    const mlx_stream s);
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
int mlx_block_masked_mm(
    mlx_array* res,
    const mlx_array a,
@@ -362,6 +364,7 @@ int mlx_dequantize(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale /* may be null */,
    mlx_optional_dtype dtype,
    const mlx_stream s);
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -498,6 +501,8 @@ int mlx_hadamard_transform(
    const mlx_array a,
    mlx_optional_float scale,
    const mlx_stream s);
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_inner(
@@ -790,6 +795,8 @@ int mlx_qqmm(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale_x /* may be null */,
    const mlx_array global_scale_w /* may be null */,
    const mlx_stream s);
int mlx_quantize(
    mlx_vector_array* res,
@@ -797,6 +804,7 @@ int mlx_quantize(
    mlx_optional_int group_size,
    mlx_optional_int bits,
    const char* mode,
    const mlx_array global_scale /* may be null */,
    const mlx_stream s);
int mlx_quantized_matmul(
    mlx_array* res,

@@ -4,35 +4,91 @@ package mlx
import "C"

import (
	"fmt"
	"iter"
	"runtime"
	"unsafe"
)

// SafetensorsFile represents a loaded safetensors file.
type SafetensorsFile struct {
	arrays   C.mlx_map_string_to_array
	metadata C.mlx_map_string_to_string
}

func loadSafetensorsStream() C.mlx_stream {
	if runtime.GOOS == "darwin" {
		return C.mlx_default_cpu_stream_new()
	}
	return C.mlx_default_gpu_stream_new()
}

// LoadSafetensorsNative loads a safetensors file using MLX's native loader.
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
	var arrays C.mlx_map_string_to_array
	var metadata C.mlx_map_string_to_string

	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))

	stream := loadSafetensorsStream()
	defer C.mlx_stream_free(stream)

	if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
		return nil, fmt.Errorf("failed to load safetensors: %s", path)
	}

	return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
}

// Get retrieves a tensor by name.
func (s *SafetensorsFile) Get(name string) *Array {
	cName := C.CString(name)
	defer C.free(unsafe.Pointer(cName))

	value := C.mlx_array_new()
	if C.mlx_map_string_to_array_get(&value, s.arrays, cName) != 0 {
		return nil
	}
	if value.ctx == nil {
		return nil
	}

	arr := New(name)
	arr.ctx = value
	return arr
}

// GetMetadata retrieves a metadata value by key.
func (s *SafetensorsFile) GetMetadata(key string) string {
	cKey := C.CString(key)
	defer C.free(unsafe.Pointer(cKey))

	var cValue *C.char
	if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
		return ""
	}
	return C.GoString(cValue)
}

// Free releases the loaded safetensors maps.
func (s *SafetensorsFile) Free() {
	if s == nil {
		return
	}
	C.mlx_map_string_to_array_free(s.arrays)
	C.mlx_map_string_to_string_free(s.metadata)
}

func Load(path string) iter.Seq2[string, *Array] {
	return func(yield func(string, *Array) bool) {
		string2array := C.mlx_map_string_to_array_new()
		defer C.mlx_map_string_to_array_free(string2array)

		string2string := C.mlx_map_string_to_string_new()
		defer C.mlx_map_string_to_string_free(string2string)

		cPath := C.CString(path)
		defer C.free(unsafe.Pointer(cPath))

		// Use GPU stream so tensors load directly to GPU memory (CUDA has Load::eval_gpu).
		// macOS Metal doesn't implement eval_gpu for Load, so fall back to CPU stream.
		var stream C.mlx_stream
		if runtime.GOOS == "darwin" {
			stream = C.mlx_default_cpu_stream_new()
		} else {
			stream = C.mlx_default_gpu_stream_new()
		sf, err := LoadSafetensorsNative(path)
		if err != nil {
			return
		}
		defer C.mlx_stream_free(stream)
		defer sf.Free()

		C.mlx_load_safetensors(&string2array, &string2string, cPath, stream)

		it := C.mlx_map_string_to_array_iterator_new(string2array)
		it := C.mlx_map_string_to_array_iterator_new(sf.arrays)
		defer C.mlx_map_string_to_array_iterator_free(it)

		for {
@@ -51,3 +107,43 @@ func Load(path string) iter.Seq2[string, *Array] {
		}
	}
}

// SaveSafetensors saves arrays to a safetensors file without metadata.
func SaveSafetensors(path string, arrays map[string]*Array) error {
	return SaveSafetensorsWithMetadata(path, arrays, nil)
}

// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata.
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))

	cArrays := C.mlx_map_string_to_array_new()
	defer C.mlx_map_string_to_array_free(cArrays)

	for name, arr := range arrays {
		if arr == nil {
			continue
		}
		cName := C.CString(name)
		C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
		C.free(unsafe.Pointer(cName))
	}

	cMetadata := C.mlx_map_string_to_string_new()
	defer C.mlx_map_string_to_string_free(cMetadata)

	for key, value := range metadata {
		cKey := C.CString(key)
		cValue := C.CString(value)
		C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
		C.free(unsafe.Pointer(cKey))
		C.free(unsafe.Pointer(cValue))
	}

	if C.mlx_save_safetensors(cPath, cArrays, cMetadata) != 0 {
		return fmt.Errorf("failed to save safetensors: %s", path)
	}

	return nil
}

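A short usage sketch of the native loader and saver above ("model.safetensors" and the tensor name are placeholders; assumes `Get` returns an independently usable handle while the file is still open):

```go
func roundTrip() error {
	sf, err := mlx.LoadSafetensorsNative("model.safetensors")
	if err != nil {
		return err
	}
	defer sf.Free()

	if format := sf.GetMetadata("format"); format != "" {
		fmt.Println("format:", format)
	}
	w := sf.Get("embed.weight") // nil if the tensor is absent
	if w == nil {
		return fmt.Errorf("tensor not found")
	}
	// Write a single-tensor copy before the deferred Free runs.
	return mlx.SaveSafetensorsWithMetadata("copy.safetensors",
		map[string]*mlx.Array{"embed.weight": w},
		map[string]string{"format": "mlx"})
}
```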
@@ -7,8 +7,44 @@ package mlx
// #cgo LDFLAGS: -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h"
// #include <string.h>
//
// static char _mlx_last_error_msg[1024] = {0};
// static int _mlx_last_error_flag = 0;
//
// static void _mlx_capture_error_handler(const char* msg, void* data) {
//   (void)data;
//   strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1);
//   _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0';
//   _mlx_last_error_flag = 1;
// }
//
// static void mlx_install_capture_handler(void) {
//   if (mlx_set_error_handler_) {
//     mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL);
//   }
// }
//
// static void mlx_clear_last_error(void) {
//   _mlx_last_error_flag = 0;
//   _mlx_last_error_msg[0] = '\0';
// }
//
// static int mlx_had_last_error(void) {
//   return _mlx_last_error_flag;
// }
//
// static const char* mlx_get_last_error(void) {
//   return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
// }
import "C"

func init() {
	// Replace the default exit(-1) error handler with one that captures
	// the error message so we can surface it in Go.
	C.mlx_install_capture_handler()
}

// Version returns the MLX core library version string.
func Version() string {
	str := C.mlx_string_new()
@@ -31,10 +67,19 @@ func doEval(outputs []*Array, async bool) {
		}
	}

	C.mlx_clear_last_error()
	var rc C.int
	if async {
		C.mlx_async_eval(vector)
		rc = C.mlx_async_eval(vector)
	} else {
		C.mlx_eval(vector)
		rc = C.mlx_eval(vector)
	}
	if rc != 0 {
		msg := "mlx eval failed"
		if C.mlx_had_last_error() != 0 {
			msg = C.GoString(C.mlx_get_last_error())
		}
		panic("mlx: " + msg)
	}
}

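With the capture handler installed, a failed eval panics with the captured MLX message instead of exiting the process. A sketch of converting that panic into a Go error at an API boundary (assuming `Eval` is the variadic wrapper over `doEval` used elsewhere in this package):

```go
func safeEval(outputs ...*mlx.Array) (err error) {
	defer func() {
		// Recover the "mlx: ..." panic raised by doEval on a nonzero rc.
		if r := recover(); r != nil {
			err = fmt.Errorf("eval: %v", r)
		}
	}()
	mlx.Eval(outputs...)
	return nil
}
```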
@@ -17,7 +17,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
	res := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(res)
	C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)
	var globalScale C.mlx_array
	C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)

	vecSize := int(C.mlx_vector_array_size(res))
	w0 := New("QUANTIZE_W")
@@ -32,6 +33,18 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
	return w0, w1, nil
}

func FromFP8(x *Array, dtype DType) *Array {
	out := New("FROM_FP8")
	C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
	return out
}

func ToFP8(x *Array) *Array {
	out := New("TO_FP8")
	C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx)
	return out
}

func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
@@ -45,7 +58,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
	}

	out := New("DEQUANTIZE")
	C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
	var globalScale C.mlx_array
	C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
	return out
}

@@ -135,6 +149,40 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
	return out
}

func Pad(a *Array, paddings []int32) *Array {
	numAxes := len(paddings) / 2
	axes := make([]C.int, numAxes)
	lowPad := make([]C.int, numAxes)
	highPad := make([]C.int, numAxes)
	for i := range numAxes {
		axes[i] = C.int(i)
		lowPad[i] = C.int(paddings[i*2])
		highPad[i] = C.int(paddings[i*2+1])
	}

	padValue := C.mlx_array_new_float(C.float(0))
	defer C.mlx_array_free(padValue)

	cMode := C.CString("constant")
	defer C.free(unsafe.Pointer(cMode))

	out := New("PAD")
	C.mlx_pad(
		&out.ctx,
		a.ctx,
		unsafe.SliceData(axes),
		C.size_t(len(axes)),
		unsafe.SliceData(lowPad),
		C.size_t(len(lowPad)),
		unsafe.SliceData(highPad),
		C.size_t(len(highPad)),
		padValue,
		cMode,
		DefaultStream().ctx,
	)
	return out
}

func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
	groups := int32(x.Dim(x.NumDims() - 1))
	return Conv1d(x, weight, bias, 1, 0, 1, groups)

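`Pad` takes one interleaved `[low0, high0, low1, high1, ...]` pair per axis. A small sketch, zero-padding the last axis of a `[2, 30]` array up to `[2, 32]`:

```go
func padExample() *mlx.Array {
	x := mlx.FromValues(make([]float32, 2*30), 2, 30)
	padded := mlx.Pad(x, []int32{0, 0, 0, 2}) // axis 0: (0, 0); axis 1: (0, 2)
	mlx.Eval(padded)                          // padded.Dims() == [2, 32]
	return padded
}
```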
@@ -11,8 +11,10 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string)
	switch strings.ToUpper(quantization) {
	case "NVFP4":
		return 16, 4, "nvfp4"
	case "MXFP4":
		return 32, 4, "mxfp4"
	case "FP4", "Q4", "INT4":
		return 32, 4, "affine"
		return 64, 4, "affine"
	case "MXFP8":
		return 32, 8, "mxfp8"
	case "FP8", "Q8", "INT8":

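The mapping in use (a sketch; `quantizeAs` is a hypothetical helper, and `QuantizationParams` and `mlx.Quantize` live in different packages, so imports are elided):

```go
func quantizeAs(w *mlx.Array, quantization string) (weights, scales, biases *mlx.Array) {
	groupSize, bits, mode := QuantizationParams(quantization) // "MXFP4" -> 32, 4, "mxfp4"
	return mlx.Quantize(w, groupSize, bits, mode)             // biases is nil for mx/nv modes
}
```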
@@ -144,3 +144,44 @@ func TestLayerNormDefaultEps(t *testing.T) {
		}
	}
}

func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
	skipIfNoMLX(t)

	weightVals := make([]float32, 3*32)
	for i := range weightVals {
		weightVals[i] = float32((i%11)-5) / 7
	}
	inputVals := make([]float32, 2*32)
	for i := range inputVals {
		inputVals[i] = float32((i%7)-3) / 5
	}

	weight := mlx.FromValues(weightVals, 3, 32).AsType(mlx.DTypeBFloat16)
	input := mlx.FromValues(inputVals, 2, 32).AsType(mlx.DTypeBFloat16)
	mlx.Eval(weight, input)

	ql := NewQuantizedLinear(weight, nil, 32, 4, "mxfp4")
	if ql.QBiases != nil {
		t.Fatalf("mxfp4 qbiases = %v, want nil", ql.QBiases)
	}

	dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
	mlx.Eval(dequantizedWeight)

	qOut := ql.Forward(input)
	dOut := NewLinear(dequantizedWeight, nil).Forward(input)
	mlx.Eval(qOut, dOut)

	got := qOut.Floats()
	want := dOut.Floats()
	if len(got) != len(want) {
		t.Fatalf("output length = %d, want %d", len(got), len(want))
	}

	for i := range got {
		if !approxEqual(got[i], want[i], 1e-3) {
			t.Fatalf("output[%d] = %.6f, want %.6f", i, got[i], want[i])
		}
	}
}

@@ -420,7 +420,16 @@ func tensorByBase(tensors map[string]*mlx.Array, base string) (*mlx.Array, strin
}

func supportsGatherQMM(mode string, bits int) bool {
	return mode == "affine" && (bits == 4 || bits == 8)
	switch mode {
	case "affine":
		return bits == 4 || bits == 8
	case "mxfp8":
		return bits == 8
	case "nvfp4", "mxfp4":
		return bits == 4
	default:
		return false
	}
}

func freeTensorKeys(tensors map[string]*mlx.Array, keys ...string) {

@@ -83,6 +83,28 @@ func TestLayerSelectionHelpers(t *testing.T) {
	}
}

func TestSupportsGatherQMM(t *testing.T) {
	tests := []struct {
		mode string
		bits int
		want bool
	}{
		{mode: "affine", bits: 4, want: true},
		{mode: "affine", bits: 8, want: true},
		{mode: "mxfp8", bits: 8, want: true},
		{mode: "nvfp4", bits: 4, want: true},
		{mode: "mxfp4", bits: 4, want: true},
		{mode: "mxfp8", bits: 4, want: false},
		{mode: "affine", bits: 3, want: false},
	}

	for _, tt := range tests {
		if got := supportsGatherQMM(tt.mode, tt.bits); got != tt.want {
			t.Fatalf("supportsGatherQMM(%q, %d) = %v, want %v", tt.mode, tt.bits, got, tt.want)
		}
	}
}

func TestResolveTensorPathLayout(t *testing.T) {
	dummy := mlx.New("dummy")