mirror of
https://github.com/ollama/ollama.git
synced 2026-04-26 18:55:53 +02:00
Compare commits
6 Commits
pdevine/ml
...
v0.18.3-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de5cb7311f | ||
|
|
95ee7fbd29 | ||
|
|
ec55536734 | ||
|
|
77491439c2 | ||
|
|
b166b36cd2 | ||
|
|
c2b0bb7a52 |
@@ -157,7 +157,7 @@ COPY CMakeLists.txt CMakePresets.json .
|
|||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
COPY x/imagegen/mlx x/imagegen/mlx
|
COPY x/imagegen/mlx x/imagegen/mlx
|
||||||
COPY go.mod go.sum .
|
COPY go.mod go.sum .
|
||||||
COPY MLX_VERSION MLX_CORE_VERSION .
|
COPY MLX_VERSION MLX_C_VERSION .
|
||||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
v0.30.6
|
|
||||||
1
MLX_C_VERSION
Normal file
1
MLX_C_VERSION
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0726ca922fc902c4c61ef9c27d94132be418e945
|
||||||
@@ -1 +1 @@
|
|||||||
v0.5.0
|
38ad257088fb2193ad47e527cf6534a689f30943
|
||||||
|
|||||||
@@ -96,6 +96,18 @@ The `/loop` command runs a prompt or slash command on a recurring schedule insid
|
|||||||
/loop 1h Remind me to review the deploy status
|
/loop 1h Remind me to review the deploy status
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Telegram
|
||||||
|
|
||||||
|
Chat with Claude Code from Telegram by connecting a bot to your session. Install the [Telegram plugin](https://github.com/anthropics/claude-plugins-official), create a bot via [@BotFather](https://t.me/BotFather), then launch with the channel flag:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch claude -- --channels plugin:telegram@claude-plugins-official
|
||||||
|
```
|
||||||
|
|
||||||
|
Claude Code will prompt for permission on most actions. To allow the bot to work autonomously, configure [permission rules](https://code.claude.com/docs/en/permissions) or pass `--dangerously-skip-permissions` in isolated environments.
|
||||||
|
|
||||||
|
See the [plugin README](https://github.com/anthropics/claude-plugins-official/tree/main/external_plugins/telegram) for full setup instructions including pairing and access control.
|
||||||
|
|
||||||
## Manual setup
|
## Manual setup
|
||||||
|
|
||||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
|
|||||||
type CreateOptions struct {
|
type CreateOptions struct {
|
||||||
ModelName string
|
ModelName string
|
||||||
ModelDir string
|
ModelDir string
|
||||||
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
|
||||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
|
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,7 +280,7 @@ func newPackedTensorLayerCreator() create.PackedTensorLayerCreator {
|
|||||||
if !QuantizeSupported() {
|
if !QuantizeSupported() {
|
||||||
return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
|
return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
|
||||||
}
|
}
|
||||||
blobData, err := quantizePackedGroup(tensors)
|
blobData, err := quantizePackedGroup(groupName, tensors)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
|
return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,29 +7,27 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/create"
|
"github.com/ollama/ollama/x/create"
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// quantizeParams maps quantization type names to MLX quantize parameters.
|
|
||||||
var quantizeParams = map[string]struct {
|
|
||||||
groupSize int
|
|
||||||
bits int
|
|
||||||
mode string
|
|
||||||
}{
|
|
||||||
"int4": {64, 4, "affine"},
|
|
||||||
"nvfp4": {16, 4, "nvfp4"},
|
|
||||||
"int8": {64, 8, "affine"},
|
|
||||||
"mxfp8": {32, 8, "mxfp8"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
|
// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
|
||||||
// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
|
// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
|
||||||
// to the provided maps. If quantize is empty, the tensor is kept as-is.
|
// to the provided maps. If quantize is empty, the tensor is kept as-is.
|
||||||
// Returns any temp file paths created (caller must clean up) and arrays needing eval.
|
// Returns any temp file paths created (caller must clean up) and arrays needing eval.
|
||||||
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
|
func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
|
||||||
|
if quantize != "" {
|
||||||
|
if gs, _, _ := model.QuantizationParams(quantize); gs == 0 {
|
||||||
|
return "", nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tmpDir := ensureTempDir()
|
tmpDir := ensureTempDir()
|
||||||
|
|
||||||
tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
|
tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
|
||||||
@@ -50,11 +48,16 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Find the tensor key (may differ from name for single-tensor blobs)
|
// Find the tensor key (may differ from name for single-tensor blobs)
|
||||||
inputKey, err := findSafetensorsKey(tmpPath)
|
header, err := readSafetensorsHeader(tmpPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
st.Free()
|
st.Free()
|
||||||
return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
|
return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
|
||||||
}
|
}
|
||||||
|
inputKey, err := safetensorsKey(name, header)
|
||||||
|
if err != nil {
|
||||||
|
st.Free()
|
||||||
|
return tmpPath, nil, nil, fmt.Errorf("failed to resolve tensor key for %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
arr := st.Get(inputKey)
|
arr := st.Get(inputKey)
|
||||||
if arr == nil {
|
if arr == nil {
|
||||||
@@ -62,34 +65,46 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
|
|||||||
return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
|
return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Decode FP8 source encoding before checking quantize, so that callers
|
||||||
|
// requesting decode-only (quantize="") receive usable float data.
|
||||||
|
if info, ok := header[inputKey]; ok && info.Dtype == "F8_E4M3" {
|
||||||
|
scaleKey := inputKey + ".scale_inv"
|
||||||
|
scaleInv := st.Get(scaleKey)
|
||||||
|
if scaleInv == nil {
|
||||||
|
st.Free()
|
||||||
|
return tmpPath, nil, nil, fmt.Errorf("missing companion tensor %q for fp8 source tensor %q", scaleKey, inputKey)
|
||||||
|
}
|
||||||
|
arr, err = decodeSourceFP8Tensor(arr, scaleInv)
|
||||||
|
if err != nil {
|
||||||
|
st.Free()
|
||||||
|
return tmpPath, nil, nil, fmt.Errorf("failed to decode fp8 tensor %s: %w", inputKey, err)
|
||||||
|
}
|
||||||
|
mlx.Eval(arr)
|
||||||
|
}
|
||||||
|
|
||||||
if quantize == "" {
|
if quantize == "" {
|
||||||
arr = mlx.Contiguous(arr)
|
arr = mlx.Contiguous(arr, false)
|
||||||
arrays[name] = arr
|
arrays[name] = arr
|
||||||
return tmpPath, []*mlx.Array{arr}, st, nil
|
return tmpPath, []*mlx.Array{arr}, st, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if arr.DType() != mlx.DTypeBFloat16 && arr.DType() != mlx.DTypeFloat32 && arr.DType() != mlx.DTypeFloat16 {
|
||||||
// Convert to float type if needed (quantize expects float)
|
// Convert to float type if needed (quantize expects float)
|
||||||
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
|
arr = arr.AsType(mlx.DTypeBFloat16)
|
||||||
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
|
|
||||||
mlx.Eval(arr)
|
mlx.Eval(arr)
|
||||||
}
|
}
|
||||||
|
|
||||||
params, ok := quantizeParams[quantize]
|
groupSize, bits, mode := model.QuantizationParams(quantize)
|
||||||
if !ok {
|
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
|
||||||
st.Free()
|
|
||||||
return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
|
||||||
}
|
|
||||||
|
|
||||||
qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode)
|
qweight = mlx.Contiguous(qweight, false)
|
||||||
|
scales = mlx.Contiguous(scales, false)
|
||||||
qweight = mlx.Contiguous(qweight)
|
|
||||||
scales = mlx.Contiguous(scales)
|
|
||||||
arrays[name] = qweight
|
arrays[name] = qweight
|
||||||
arrays[name+".scale"] = scales
|
arrays[name+".scale"] = scales
|
||||||
toEval = append(toEval, qweight, scales)
|
toEval = append(toEval, qweight, scales)
|
||||||
|
|
||||||
if qbiases != nil {
|
if qbiases != nil {
|
||||||
qbiases = mlx.Contiguous(qbiases)
|
qbiases = mlx.Contiguous(qbiases, false)
|
||||||
arrays[name+".bias"] = qbiases
|
arrays[name+".bias"] = qbiases
|
||||||
toEval = append(toEval, qbiases)
|
toEval = append(toEval, qbiases)
|
||||||
}
|
}
|
||||||
@@ -101,27 +116,45 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
|
|||||||
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
|
// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
|
||||||
// Tensor keys use the original tensor name: name, name.scale, name.bias.
|
// Tensor keys use the original tensor name: name, name.scale, name.bias.
|
||||||
// The blob includes __metadata__ with quant_type and group_size.
|
// The blob includes __metadata__ with quant_type and group_size.
|
||||||
// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8".
|
// Supported quantization types: "int4", "nvfp4", "mxfp4", "int8", "mxfp8".
|
||||||
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
|
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
|
||||||
arrays := make(map[string]*mlx.Array)
|
arrays := make(map[string]*mlx.Array)
|
||||||
tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
|
tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
|
||||||
if tmpPath != "" {
|
if tmpPath != "" {
|
||||||
defer os.Remove(tmpPath)
|
defer os.Remove(tmpPath)
|
||||||
}
|
}
|
||||||
if st != nil {
|
|
||||||
defer st.Free()
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
finalArrays := make([]*mlx.Array, 0, len(arrays))
|
||||||
|
for _, arr := range arrays {
|
||||||
|
if arr != nil {
|
||||||
|
finalArrays = append(finalArrays, arr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mlx.Pin(finalArrays...)
|
||||||
|
defer func() {
|
||||||
|
if st != nil {
|
||||||
|
st.Free()
|
||||||
|
}
|
||||||
|
mlx.Unpin(finalArrays...)
|
||||||
|
mlx.Sweep()
|
||||||
|
}()
|
||||||
|
|
||||||
mlx.Eval(toEval...)
|
mlx.Eval(toEval...)
|
||||||
|
mlx.Sweep()
|
||||||
|
// Free early to release mmap; defer guard handles error paths
|
||||||
|
if st != nil {
|
||||||
|
st.Free()
|
||||||
|
st = nil
|
||||||
|
}
|
||||||
|
|
||||||
// Build metadata for single-tensor blobs
|
// Build metadata for single-tensor blobs
|
||||||
params := quantizeParams[quantize]
|
groupSize, _, _ := model.QuantizationParams(quantize)
|
||||||
metadata := map[string]string{
|
metadata := map[string]string{
|
||||||
"quant_type": quantize,
|
"quant_type": quantize,
|
||||||
"group_size": strconv.Itoa(params.groupSize),
|
"group_size": strconv.Itoa(groupSize),
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := ensureTempDir()
|
tmpDir := ensureTempDir()
|
||||||
@@ -135,48 +168,81 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
|
|||||||
|
|
||||||
// quantizePackedGroup quantizes multiple tensors and saves them all into a single
|
// quantizePackedGroup quantizes multiple tensors and saves them all into a single
|
||||||
// combined safetensors blob. Used for packing expert groups.
|
// combined safetensors blob. Used for packing expert groups.
|
||||||
|
// When the inputs are per-expert 2D tensors (e.g., experts.0.gate_proj.weight),
|
||||||
|
// they are stacked into 3D switch_mlp tensors before quantization.
|
||||||
// Each tensor may have a different quantization type (mixed-precision).
|
// Each tensor may have a different quantization type (mixed-precision).
|
||||||
// Returns the blob bytes. No __metadata__ is added because different tensors
|
// Returns the blob bytes.
|
||||||
// may use different quantization types.
|
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
|
||||||
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
// Check if inputs are per-expert tensors that should be stacked into 3D
|
||||||
|
if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
|
||||||
|
return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
|
||||||
|
}
|
||||||
|
|
||||||
allArrays := make(map[string]*mlx.Array)
|
allArrays := make(map[string]*mlx.Array)
|
||||||
var allToEval []*mlx.Array
|
var pinned []*mlx.Array
|
||||||
var tmpPaths []string
|
|
||||||
var handles []*mlx.SafetensorsFile
|
var metadata map[string]string
|
||||||
|
uniformQuantize := ""
|
||||||
|
hasQuantized := false
|
||||||
|
mixedQuantize := false
|
||||||
|
for _, input := range inputs {
|
||||||
|
if input.Quantize == "" {
|
||||||
|
if hasQuantized {
|
||||||
|
mixedQuantize = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !hasQuantized {
|
||||||
|
hasQuantized = true
|
||||||
|
uniformQuantize = input.Quantize
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if input.Quantize != uniformQuantize {
|
||||||
|
mixedQuantize = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasQuantized && !mixedQuantize {
|
||||||
|
if groupSize, _, _ := model.QuantizationParams(uniformQuantize); groupSize > 0 {
|
||||||
|
metadata = map[string]string{
|
||||||
|
"quant_type": uniformQuantize,
|
||||||
|
"group_size": strconv.Itoa(groupSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, input := range inputs {
|
for _, input := range inputs {
|
||||||
tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
|
tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
|
||||||
if tmpPath != "" {
|
|
||||||
tmpPaths = append(tmpPaths, tmpPath)
|
|
||||||
}
|
|
||||||
if st != nil {
|
|
||||||
handles = append(handles, st)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Cleanup on error
|
mlx.Unpin(pinned...)
|
||||||
for _, h := range handles {
|
mlx.Sweep()
|
||||||
h.Free()
|
|
||||||
}
|
|
||||||
for _, p := range tmpPaths {
|
|
||||||
os.Remove(p)
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
allToEval = append(allToEval, toEval...)
|
|
||||||
|
mlx.Eval(toEval...)
|
||||||
|
|
||||||
|
finalArrays := arraysForPackedInput(allArrays, input)
|
||||||
|
mlx.Pin(finalArrays...)
|
||||||
|
pinned = append(pinned, finalArrays...)
|
||||||
|
|
||||||
|
if st != nil {
|
||||||
|
st.Free()
|
||||||
}
|
}
|
||||||
|
if tmpPath != "" {
|
||||||
mlx.Eval(allToEval...)
|
os.Remove(tmpPath)
|
||||||
|
|
||||||
// Free native handles after eval
|
|
||||||
for _, h := range handles {
|
|
||||||
h.Free()
|
|
||||||
}
|
}
|
||||||
|
mlx.Sweep()
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
mlx.Unpin(pinned...)
|
||||||
|
mlx.Sweep()
|
||||||
|
}()
|
||||||
|
|
||||||
// Save combined blob (no global metadata for mixed-precision packed blobs)
|
// Save combined blob. Add global metadata only when every packed tensor uses
|
||||||
|
// the same quantization mode and group size.
|
||||||
tmpDir := ensureTempDir()
|
tmpDir := ensureTempDir()
|
||||||
outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
|
outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
|
||||||
defer os.Remove(outPath)
|
defer os.Remove(outPath)
|
||||||
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil {
|
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
|
||||||
return nil, fmt.Errorf("failed to save packed blob: %w", err)
|
return nil, fmt.Errorf("failed to save packed blob: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,17 +251,193 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("failed to read packed blob: %w", err)
|
return nil, fmt.Errorf("failed to read packed blob: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range tmpPaths {
|
return blobData, nil
|
||||||
os.Remove(p)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func arraysForPackedInput(allArrays map[string]*mlx.Array, input create.PackedTensorInput) []*mlx.Array {
|
||||||
|
keys := []string{input.Name}
|
||||||
|
if input.Quantize != "" {
|
||||||
|
keys = append(keys, input.Name+".scale", input.Name+".bias")
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*mlx.Array, 0, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
if arr := allArrays[key]; arr != nil {
|
||||||
|
out = append(out, arr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// perExpertSuffix matches the remainder of a tensor name once the group
// prefix has been removed: ".{index}.{proj_and_suffix}". Capture group 1 is
// the decimal expert index and capture group 2 is everything after it.
var perExpertSuffix = regexp.MustCompile(`^\.(\d+)\.(.+)$`)
|
||||||
|
|
||||||
|
type expertTensorInfo struct {
|
||||||
|
index int
|
||||||
|
proj string // e.g., "gate_proj.weight"
|
||||||
|
input create.PackedTensorInput
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
|
||||||
|
// and returns the uniform quantization type shared by all inputs.
|
||||||
|
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
|
||||||
|
// or if the inputs have mixed quantization types.
|
||||||
|
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
|
||||||
|
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
|
||||||
|
if !strings.HasSuffix(groupName, ".experts") {
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
quantize := inputs[0].Quantize
|
||||||
|
groups := make(map[string][]expertTensorInfo)
|
||||||
|
for _, input := range inputs {
|
||||||
|
if input.Quantize != quantize {
|
||||||
|
return nil, "" // mixed quantization types
|
||||||
|
}
|
||||||
|
suffix := strings.TrimPrefix(input.Name, groupName)
|
||||||
|
m := perExpertSuffix.FindStringSubmatch(suffix)
|
||||||
|
if m == nil {
|
||||||
|
return nil, "" // not a per-expert pattern
|
||||||
|
}
|
||||||
|
index, err := strconv.Atoi(m[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
groups[m[2]] = append(groups[m[2]], expertTensorInfo{
|
||||||
|
index: index,
|
||||||
|
proj: m[2],
|
||||||
|
input: input,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(groups) == 0 {
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
return groups, quantize
|
||||||
|
}
|
||||||
|
|
||||||
|
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
|
||||||
|
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
|
||||||
|
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
|
||||||
|
groupBase := strings.TrimSuffix(groupName, ".experts")
|
||||||
|
|
||||||
|
allArrays := make(map[string]*mlx.Array)
|
||||||
|
var pinned []*mlx.Array
|
||||||
|
|
||||||
|
var metadata map[string]string
|
||||||
|
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
|
||||||
|
metadata = map[string]string{
|
||||||
|
"quant_type": quantize,
|
||||||
|
"group_size": strconv.Itoa(groupSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort projection names for deterministic output
|
||||||
|
projNames := make([]string, 0, len(projGroups))
|
||||||
|
for proj := range projGroups {
|
||||||
|
projNames = append(projNames, proj)
|
||||||
|
}
|
||||||
|
sort.Strings(projNames)
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
mlx.Unpin(pinned...)
|
||||||
|
mlx.Sweep()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, proj := range projNames {
|
||||||
|
experts := projGroups[proj]
|
||||||
|
|
||||||
|
// Sort by expert index
|
||||||
|
sort.Slice(experts, func(i, j int) bool {
|
||||||
|
return experts[i].index < experts[j].index
|
||||||
|
})
|
||||||
|
|
||||||
|
// Load and decode each expert tensor
|
||||||
|
var decoded []*mlx.Array
|
||||||
|
for _, expert := range experts {
|
||||||
|
dummyArrays := make(map[string]*mlx.Array)
|
||||||
|
tmpPath, toEval, st, err := loadAndQuantizeArray(expert.input.Reader, expert.input.Name, "", dummyArrays)
|
||||||
|
if err != nil {
|
||||||
|
cleanup()
|
||||||
|
return nil, fmt.Errorf("failed to decode expert tensor %s: %w", expert.input.Name, err)
|
||||||
|
}
|
||||||
|
mlx.Eval(toEval...)
|
||||||
|
|
||||||
|
arr := dummyArrays[expert.input.Name]
|
||||||
|
mlx.Pin(arr)
|
||||||
|
pinned = append(pinned, arr)
|
||||||
|
decoded = append(decoded, arr)
|
||||||
|
|
||||||
|
if st != nil {
|
||||||
|
st.Free()
|
||||||
|
}
|
||||||
|
if tmpPath != "" {
|
||||||
|
os.Remove(tmpPath)
|
||||||
|
}
|
||||||
|
mlx.Sweep()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stack into 3D along axis 0: [numExperts, rows, cols]
|
||||||
|
stacked := mlx.Stack(decoded, 0)
|
||||||
|
mlx.Eval(stacked)
|
||||||
|
mlx.Pin(stacked)
|
||||||
|
pinned = append(pinned, stacked)
|
||||||
|
|
||||||
|
// Free individual decoded arrays
|
||||||
|
mlx.Unpin(decoded...)
|
||||||
|
mlx.Sweep()
|
||||||
|
|
||||||
|
stackedName := groupBase + ".switch_mlp." + proj
|
||||||
|
|
||||||
|
// Quantize the stacked tensor
|
||||||
|
if quantize != "" {
|
||||||
|
groupSize, bits, mode := model.QuantizationParams(quantize)
|
||||||
|
|
||||||
|
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
|
||||||
|
|
||||||
|
qweight = mlx.Contiguous(qweight, false)
|
||||||
|
scales = mlx.Contiguous(scales, false)
|
||||||
|
allArrays[stackedName] = qweight
|
||||||
|
allArrays[stackedName+".scale"] = scales
|
||||||
|
|
||||||
|
toEval := []*mlx.Array{qweight, scales}
|
||||||
|
if qbiases != nil {
|
||||||
|
qbiases = mlx.Contiguous(qbiases, false)
|
||||||
|
allArrays[stackedName+".bias"] = qbiases
|
||||||
|
toEval = append(toEval, qbiases)
|
||||||
|
}
|
||||||
|
mlx.Eval(toEval...)
|
||||||
|
mlx.Pin(toEval...)
|
||||||
|
pinned = append(pinned, toEval...)
|
||||||
|
|
||||||
|
// Free stacked source array
|
||||||
|
mlx.Unpin(stacked)
|
||||||
|
mlx.Sweep()
|
||||||
|
} else {
|
||||||
|
stacked = mlx.Contiguous(stacked, false)
|
||||||
|
mlx.Eval(stacked)
|
||||||
|
allArrays[stackedName] = stacked
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tmpDir := ensureTempDir()
|
||||||
|
outPath := filepath.Join(tmpDir, "stacked-combined.safetensors")
|
||||||
|
defer os.Remove(outPath)
|
||||||
|
if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, metadata); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save stacked blob: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
blobData, err := os.ReadFile(outPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read stacked blob: %w", err)
|
||||||
|
}
|
||||||
return blobData, nil
|
return blobData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuantizeSupported returns true if quantization is supported (MLX library available)
|
// QuantizeSupported returns true if quantization is supported (MLX library available)
|
||||||
func QuantizeSupported() bool {
|
func QuantizeSupported() bool {
|
||||||
mlx.InitMLX()
|
return mlx.CheckInit() == nil
|
||||||
return mlx.IsMLXAvailable()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureTempDir creates the temp directory for quantization if it doesn't exist
|
// ensureTempDir creates the temp directory for quantization if it doesn't exist
|
||||||
@@ -205,32 +447,97 @@ func ensureTempDir() string {
|
|||||||
return tmpDir
|
return tmpDir
|
||||||
}
|
}
|
||||||
|
|
||||||
// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file.
|
type safetensorsHeaderEntry struct {
|
||||||
func findSafetensorsKey(path string) (string, error) {
|
Dtype string `json:"dtype"`
|
||||||
|
Shape []int32 `json:"shape"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSafetensorsHeader(path string) (map[string]safetensorsHeaderEntry, error) {
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
var headerSize uint64
|
var headerSize uint64
|
||||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
headerBytes := make([]byte, headerSize)
|
headerBytes := make([]byte, headerSize)
|
||||||
if _, err := io.ReadFull(f, headerBytes); err != nil {
|
if _, err := io.ReadFull(f, headerBytes); err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var header map[string]json.RawMessage
|
var header map[string]safetensorsHeaderEntry
|
||||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
|
}
|
||||||
|
return header, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// safetensorsKey resolves the primary tensor key from a header.
|
||||||
|
func safetensorsKey(preferred string, header map[string]safetensorsHeaderEntry) (string, error) {
|
||||||
|
if preferred != "" {
|
||||||
|
if _, ok := header[preferred]; ok {
|
||||||
|
return preferred, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, 0, len(header))
|
||||||
for k := range header {
|
for k := range header {
|
||||||
if k != "__metadata__" {
|
if k == "__metadata__" || strings.HasSuffix(k, ".scale_inv") {
|
||||||
return k, nil
|
continue
|
||||||
}
|
}
|
||||||
|
keys = append(keys, k)
|
||||||
}
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
if len(keys) == 0 {
|
||||||
return "", fmt.Errorf("no tensor found in safetensors header")
|
return "", fmt.Errorf("no tensor found in safetensors header")
|
||||||
}
|
}
|
||||||
|
return keys[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
|
||||||
|
if weight == nil || scaleInv == nil {
|
||||||
|
return nil, fmt.Errorf("fp8 weight and scale tensors are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
weightShape := weight.Dims()
|
||||||
|
scaleShape := scaleInv.Dims()
|
||||||
|
if len(weightShape) != 2 || len(scaleShape) != 2 {
|
||||||
|
return nil, fmt.Errorf("expected 2D fp8 weight and scale tensors, got %v and %v", weightShape, scaleShape)
|
||||||
|
}
|
||||||
|
|
||||||
|
// These must match the block size validated by resolveEffectiveQuantization
|
||||||
|
// in create.go, which rejects any source model with a different block size.
|
||||||
|
const blockRows = 128
|
||||||
|
const blockCols = 128
|
||||||
|
rows, cols := weightShape[0], weightShape[1]
|
||||||
|
expectedScaleRows := (rows + blockRows - 1) / blockRows
|
||||||
|
expectedScaleCols := (cols + blockCols - 1) / blockCols
|
||||||
|
if scaleShape[0] != expectedScaleRows || scaleShape[1] != expectedScaleCols {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"unexpected fp8 scale shape %v for weight shape %v; want [%d %d]",
|
||||||
|
scaleShape,
|
||||||
|
weightShape,
|
||||||
|
expectedScaleRows,
|
||||||
|
expectedScaleCols,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded := mlx.FromFP8(weight, mlx.DTypeBFloat16)
|
||||||
|
padBottom := blockRows*scaleShape[0] - rows
|
||||||
|
padSide := blockCols*scaleShape[1] - cols
|
||||||
|
if padBottom > 0 || padSide > 0 {
|
||||||
|
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
|
||||||
|
decoded = mlx.Mul(decoded, mlx.ExpandDims(mlx.ExpandDims(scaleInv, 1), 3))
|
||||||
|
decoded = mlx.Reshape(decoded, int32(rows+padBottom), int32(cols+padSide))
|
||||||
|
if padBottom > 0 || padSide > 0 {
|
||||||
|
decoded = mlx.SliceStartStop(decoded, []int32{0, 0}, []int32{int32(rows), int32(cols)})
|
||||||
|
}
|
||||||
|
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -267,13 +267,13 @@ func ShouldQuantize(name, component string) bool {
|
|||||||
|
|
||||||
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
|
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
|
||||||
// This is a more detailed check that also considers tensor dimensions.
|
// This is a more detailed check that also considers tensor dimensions.
|
||||||
// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "int8", "mxfp8").
|
// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "mxfp4", "int8", "mxfp8").
|
||||||
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
|
func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
|
||||||
return GetTensorQuantization(name, shape, quantize) != ""
|
return GetTensorQuantization(name, shape, quantize) != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeQuantType converts various quantization type aliases to canonical forms.
|
// normalizeQuantType converts various quantization type aliases to canonical forms.
|
||||||
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp8/MXFP8
|
// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp4/MXFP4, mxfp8/MXFP8
|
||||||
func normalizeQuantType(quantize string) string {
|
func normalizeQuantType(quantize string) string {
|
||||||
switch strings.ToUpper(quantize) {
|
switch strings.ToUpper(quantize) {
|
||||||
case "Q4", "INT4", "FP4":
|
case "Q4", "INT4", "FP4":
|
||||||
@@ -282,6 +282,8 @@ func normalizeQuantType(quantize string) string {
|
|||||||
return "int8"
|
return "int8"
|
||||||
case "NVFP4":
|
case "NVFP4":
|
||||||
return "nvfp4"
|
return "nvfp4"
|
||||||
|
case "MXFP4":
|
||||||
|
return "mxfp4"
|
||||||
case "MXFP8":
|
case "MXFP8":
|
||||||
return "mxfp8"
|
return "mxfp8"
|
||||||
default:
|
default:
|
||||||
@@ -335,7 +337,7 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
|||||||
quantNorm := normalizeQuantType(quantize)
|
quantNorm := normalizeQuantType(quantize)
|
||||||
|
|
||||||
// MLX quantization requires last dimension to be divisible by group size
|
// MLX quantization requires last dimension to be divisible by group size
|
||||||
// nvfp4: 16, mxfp8: 32, int4/int8: 64
|
// nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
|
||||||
groupSize := int32(32)
|
groupSize := int32(32)
|
||||||
switch quantNorm {
|
switch quantNorm {
|
||||||
case "nvfp4":
|
case "nvfp4":
|
||||||
@@ -353,8 +355,8 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
|
// For non-affine modes, use the same quantization for all eligible tensors.
|
||||||
if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
|
if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
|
||||||
return quantNorm
|
return quantNorm
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -391,23 +393,39 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
|||||||
return quantNorm
|
return quantNorm
|
||||||
}
|
}
|
||||||
|
|
||||||
// expertGroupRegexp matches expert tensor names and captures the group prefix.
|
var expertLayerPrefixRegexp = regexp.MustCompile(`^(?:model\.language_model\.|language_model(?:\.model)?\.|model\.)?layers\.\d+$`)
|
||||||
// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
|
|
||||||
// Captures: model.layers.{L}.mlp.experts
|
|
||||||
var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
|
|
||||||
|
|
||||||
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
|
// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
|
||||||
// For example:
|
// For example:
|
||||||
// - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts"
|
// - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts"
|
||||||
// - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts"
|
// - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts"
|
||||||
|
// - "language_model.model.layers.1.mlp.switch_mlp.down_proj.weight" -> "language_model.model.layers.1.mlp.switch_mlp"
|
||||||
// - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts)
|
// - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts)
|
||||||
// - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert)
|
// - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert)
|
||||||
func ExpertGroupPrefix(tensorName string) string {
|
func ExpertGroupPrefix(tensorName string) string {
|
||||||
m := expertGroupRegexp.FindStringSubmatch(tensorName)
|
if !strings.HasSuffix(tensorName, ".weight") {
|
||||||
if m == nil {
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return m[1]
|
|
||||||
|
for _, marker := range []string{
|
||||||
|
".mlp.experts.",
|
||||||
|
".mlp.shared_experts.",
|
||||||
|
".mlp.switch_mlp.",
|
||||||
|
} {
|
||||||
|
idx := strings.Index(tensorName, marker)
|
||||||
|
if idx == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
layerPrefix := tensorName[:idx]
|
||||||
|
if !expertLayerPrefixRegexp.MatchString(layerPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return layerPrefix + strings.TrimSuffix(marker, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob.
|
// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob.
|
||||||
@@ -427,6 +445,8 @@ type sourceQuantization struct {
|
|||||||
Bits int `json:"bits"`
|
Bits int `json:"bits"`
|
||||||
GroupSize int `json:"group_size"`
|
GroupSize int `json:"group_size"`
|
||||||
Mode string `json:"mode"`
|
Mode string `json:"mode"`
|
||||||
|
QuantMethod string `json:"quant_method"`
|
||||||
|
WeightBlockSize []int32 `json:"weight_block_size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type sourceModelConfig struct {
|
type sourceModelConfig struct {
|
||||||
@@ -493,6 +513,98 @@ func (cfg sourceModelConfig) QuantMetadata() map[string]string {
|
|||||||
return metadata
|
return metadata
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type sourceQuantizedKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
sourceQuantizedKindNone sourceQuantizedKind = ""
|
||||||
|
sourceQuantizedKindPrequantized sourceQuantizedKind = "prequantized"
|
||||||
|
sourceQuantizedKindHFFP8 sourceQuantizedKind = "hf_fp8"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (cfg sourceModelConfig) quantizationConfigs() []sourceQuantization {
|
||||||
|
return []sourceQuantization{
|
||||||
|
cfg.Quantization,
|
||||||
|
cfg.QuantizationConfig,
|
||||||
|
cfg.TextConfig.Quantization,
|
||||||
|
cfg.TextConfig.QuantizationConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg sourceModelConfig) HFFP8WeightBlockSize() (rows, cols int32, ok bool) {
|
||||||
|
for _, q := range cfg.quantizationConfigs() {
|
||||||
|
if !strings.EqualFold(q.QuantMethod, "fp8") || len(q.WeightBlockSize) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return q.WeightBlockSize[0], q.WeightBlockSize[1], true
|
||||||
|
}
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func inspectSourceQuantization(modelDir string, cfg sourceModelConfig) (sourceQuantizedKind, error) {
|
||||||
|
entries, err := os.ReadDir(modelDir)
|
||||||
|
if err != nil {
|
||||||
|
return sourceQuantizedKindNone, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hasScaleInv := false
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
extractor, err := safetensors.OpenForExtraction(filepath.Join(modelDir, entry.Name()))
|
||||||
|
if err != nil {
|
||||||
|
return sourceQuantizedKindNone, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range extractor.ListTensors() {
|
||||||
|
switch {
|
||||||
|
case strings.HasSuffix(name, ".scales"):
|
||||||
|
extractor.Close()
|
||||||
|
return sourceQuantizedKindPrequantized, nil
|
||||||
|
case strings.HasSuffix(name, ".weight_scale_inv"):
|
||||||
|
hasScaleInv = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extractor.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasScaleInv {
|
||||||
|
if _, _, ok := cfg.HFFP8WeightBlockSize(); ok {
|
||||||
|
return sourceQuantizedKindHFFP8, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sourceQuantizedKindNone, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested string) (string, error) {
|
||||||
|
switch sourceKind {
|
||||||
|
case sourceQuantizedKindNone:
|
||||||
|
return requested, nil
|
||||||
|
case sourceQuantizedKindPrequantized:
|
||||||
|
if requested != "" {
|
||||||
|
return "", fmt.Errorf("cannot requantize already-quantized source model with --quantize %q", requested)
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
case sourceQuantizedKindHFFP8:
|
||||||
|
if requested != "" {
|
||||||
|
return "", fmt.Errorf("cannot requantize already-quantized fp8 source model with --quantize %q", requested)
|
||||||
|
}
|
||||||
|
rows, cols, ok := cfg.HFFP8WeightBlockSize()
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("fp8 source model missing weight_block_size metadata")
|
||||||
|
}
|
||||||
|
if rows != 128 || cols != 128 {
|
||||||
|
return "", fmt.Errorf("unsupported fp8 source block size %dx%d", rows, cols)
|
||||||
|
}
|
||||||
|
return "mxfp8", nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unsupported source quantization kind %q", sourceKind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type tensorImportTransform interface {
|
type tensorImportTransform interface {
|
||||||
skipTensor(name string) bool
|
skipTensor(name string) bool
|
||||||
transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error)
|
transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error)
|
||||||
@@ -546,6 +658,14 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read source config.json: %w", err)
|
return fmt.Errorf("failed to read source config.json: %w", err)
|
||||||
}
|
}
|
||||||
|
sourceQuantKind, err := inspectSourceQuantization(modelDir, sourceConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to inspect source quantization: %w", err)
|
||||||
|
}
|
||||||
|
effectiveQuantize, err := resolveEffectiveQuantization(sourceConfig, sourceQuantKind, quantize)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
sourceQuantMetadata := sourceConfig.QuantMetadata()
|
sourceQuantMetadata := sourceConfig.QuantMetadata()
|
||||||
importTransform, err := newTensorImportTransform(modelDir, sourceConfig)
|
importTransform, err := newTensorImportTransform(modelDir, sourceConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -557,7 +677,6 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
|||||||
if len(createPackedLayer) > 0 {
|
if len(createPackedLayer) > 0 {
|
||||||
packedCreator = createPackedLayer[0]
|
packedCreator = createPackedLayer[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate expert tensors by group prefix for packing.
|
// Accumulate expert tensors by group prefix for packing.
|
||||||
// Readers reference file-backed SectionReaders, so we keep extractors
|
// Readers reference file-backed SectionReaders, so we keep extractors
|
||||||
// open until each group is flushed to avoid buffering tensor data in memory.
|
// open until each group is flushed to avoid buffering tensor data in memory.
|
||||||
@@ -600,8 +719,8 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
|||||||
tensorSet[name] = struct{}{}
|
tensorSet[name] = struct{}{}
|
||||||
}
|
}
|
||||||
quantizeMsg := ""
|
quantizeMsg := ""
|
||||||
if quantize != "" {
|
if effectiveQuantize != "" {
|
||||||
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
|
quantizeMsg = fmt.Sprintf(", quantizing to %s", effectiveQuantize)
|
||||||
}
|
}
|
||||||
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
|
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
|
||||||
|
|
||||||
@@ -612,9 +731,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
|||||||
if importTransform.skipTensor(tensorName) {
|
if importTransform.skipTensor(tensorName) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if shouldSkipPrequantizedCompanion(tensorName, tensorSet) {
|
if shouldSkipSourceCompanion(tensorName, tensorSet) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
sourceFP8ScaleName, hasSourceFP8Scale := sourceFP8Companion(tensorName, tensorSet)
|
||||||
|
|
||||||
td, err := extractor.GetTensor(tensorName)
|
td, err := extractor.GetTensor(tensorName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -623,7 +743,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
|||||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if quantize == "" {
|
if effectiveQuantize == "" {
|
||||||
layer, ok, err := createPrequantizedLayer(extractor, td, tensorName, tensorSet, sourceQuantMetadata, createLayer)
|
layer, ok, err := createPrequantizedLayer(extractor, td, tensorName, tensorSet, sourceQuantMetadata, createLayer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
extractor.Close()
|
extractor.Close()
|
||||||
@@ -647,8 +767,33 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
|||||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||||
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
|
// GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
|
||||||
quantizeType := ""
|
quantizeType := ""
|
||||||
if quantize != "" {
|
switch {
|
||||||
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, quantize)
|
case sourceQuantKind == sourceQuantizedKindHFFP8 && hasSourceFP8Scale:
|
||||||
|
quantizeType = "mxfp8"
|
||||||
|
case sourceQuantKind == sourceQuantizedKindHFFP8:
|
||||||
|
quantizeType = ""
|
||||||
|
case effectiveQuantize != "":
|
||||||
|
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
|
||||||
|
}
|
||||||
|
reader := outTD.SafetensorsReader()
|
||||||
|
if hasSourceFP8Scale {
|
||||||
|
if len(outputTensors) != 1 {
|
||||||
|
extractor.Close()
|
||||||
|
closeExtractors()
|
||||||
|
return fmt.Errorf("source fp8 tensor %s rewrote into %d tensors; only 1:1 rewrites are supported", tensorName, len(outputTensors))
|
||||||
|
}
|
||||||
|
if quantizeType == "" {
|
||||||
|
extractor.Close()
|
||||||
|
closeExtractors()
|
||||||
|
return fmt.Errorf("source fp8 tensor %s was not scheduled for mxfp8 conversion", tensorName)
|
||||||
|
}
|
||||||
|
scaleTD, err := extractor.GetTensor(sourceFP8ScaleName)
|
||||||
|
if err != nil {
|
||||||
|
extractor.Close()
|
||||||
|
closeExtractors()
|
||||||
|
return fmt.Errorf("failed to get fp8 scale tensor %s: %w", sourceFP8ScaleName, err)
|
||||||
|
}
|
||||||
|
reader = buildSourceFP8Reader(outTD, scaleTD.WithName(outTD.Name+".scale_inv"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this tensor belongs to an expert group for packing
|
// Check if this tensor belongs to an expert group for packing
|
||||||
@@ -670,13 +815,13 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
|||||||
Dtype: outTD.Dtype,
|
Dtype: outTD.Dtype,
|
||||||
Shape: outTD.Shape,
|
Shape: outTD.Shape,
|
||||||
Quantize: quantizeType,
|
Quantize: quantizeType,
|
||||||
Reader: outTD.SafetensorsReader(),
|
Reader: reader,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
// Store as minimal safetensors format (88 bytes header overhead)
|
// Store as minimal safetensors format (88 bytes header overhead)
|
||||||
// This enables native mmap loading via mlx_load_safetensors
|
// This enables native mmap loading via mlx_load_safetensors
|
||||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||||
newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
|
newLayers, err := createTensorLayer(reader, outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
extractor.Close()
|
extractor.Close()
|
||||||
closeExtractors()
|
closeExtractors()
|
||||||
@@ -760,7 +905,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{}) bool {
|
func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}) bool {
|
||||||
switch {
|
switch {
|
||||||
case strings.HasSuffix(name, ".scales"):
|
case strings.HasSuffix(name, ".scales"):
|
||||||
_, ok := tensorSet[strings.TrimSuffix(name, ".scales")+".weight"]
|
_, ok := tensorSet[strings.TrimSuffix(name, ".scales")+".weight"]
|
||||||
@@ -768,11 +913,28 @@ func shouldSkipPrequantizedCompanion(name string, tensorSet map[string]struct{})
|
|||||||
case strings.HasSuffix(name, ".biases"):
|
case strings.HasSuffix(name, ".biases"):
|
||||||
_, ok := tensorSet[strings.TrimSuffix(name, ".biases")+".weight"]
|
_, ok := tensorSet[strings.TrimSuffix(name, ".biases")+".weight"]
|
||||||
return ok
|
return ok
|
||||||
|
case strings.HasSuffix(name, ".weight_scale_inv"):
|
||||||
|
_, ok := tensorSet[strings.TrimSuffix(name, "_scale_inv")]
|
||||||
|
return ok
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sourceFP8Companion(weightName string, tensorSet map[string]struct{}) (scaleName string, ok bool) {
|
||||||
|
if !strings.HasSuffix(weightName, ".weight") {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
scaleName = weightName + "_scale_inv"
|
||||||
|
_, ok = tensorSet[scaleName]
|
||||||
|
return scaleName, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSourceFP8Reader(weightTD, scaleTD *safetensors.TensorData) io.Reader {
|
||||||
|
return safetensors.BuildPackedSafetensorsReader([]*safetensors.TensorData{weightTD, scaleTD})
|
||||||
|
}
|
||||||
|
|
||||||
func createPrequantizedLayer(
|
func createPrequantizedLayer(
|
||||||
extractor *safetensors.TensorExtractor,
|
extractor *safetensors.TensorExtractor,
|
||||||
td *safetensors.TensorData,
|
td *safetensors.TensorData,
|
||||||
|
|||||||
@@ -246,6 +246,30 @@ func readSingleTensorRaw(t *testing.T, data []byte) []byte {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readSafetensorsHeaderNames(t *testing.T, data []byte) []string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var headerSize uint64
|
||||||
|
if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
|
||||||
|
t.Fatalf("failed to read header size: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var header map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
|
||||||
|
t.Fatalf("failed to parse header: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
names := make([]string, 0, len(header))
|
||||||
|
for name := range header {
|
||||||
|
if name == "__metadata__" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
slices.Sort(names)
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateSafetensorsModel(t *testing.T) {
|
func TestCreateSafetensorsModel(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
|
||||||
@@ -546,6 +570,215 @@ func TestCreateSafetensorsModel_PacksPrequantizedTensorTriplets(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateSafetensorsModel_HFFP8AutoConvertsToMXFP8(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
configJSON := `{
|
||||||
|
"model_type": "test",
|
||||||
|
"architectures": ["TestModel"],
|
||||||
|
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
|
||||||
|
}`
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config.json: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||||
|
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
|
||||||
|
st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||||
|
st.NewTensorDataFromBytes("dense.weight", "BF16", []int32{128, 128}, make([]byte, 128*128*2)),
|
||||||
|
st.NewTensorDataFromBytes("norm.weight", "BF16", []int32{2}, make([]byte, 4)),
|
||||||
|
})
|
||||||
|
|
||||||
|
quantizeByName := make(map[string]string)
|
||||||
|
headerNamesByName := make(map[string][]string)
|
||||||
|
|
||||||
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||||
|
_, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return LayerInfo{}, err
|
||||||
|
}
|
||||||
|
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
quantizeByName[name] = quantize
|
||||||
|
headerNamesByName[name] = readSafetensorsHeaderNames(t, data)
|
||||||
|
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
|
||||||
|
|
||||||
|
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}); err != nil {
|
||||||
|
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := quantizeByName["linear.weight"]; got != "mxfp8" {
|
||||||
|
t.Fatalf("linear.weight quantization = %q, want %q", got, "mxfp8")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := quantizeByName["norm.weight"]; got != "" {
|
||||||
|
t.Fatalf("norm.weight quantization = %q, want empty", got)
|
||||||
|
}
|
||||||
|
if got := quantizeByName["dense.weight"]; got != "" {
|
||||||
|
t.Fatalf("dense.weight quantization = %q, want empty", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := quantizeByName["linear.weight_scale_inv"]; ok {
|
||||||
|
t.Fatal("linear.weight_scale_inv should not be imported as a standalone tensor")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := headerNamesByName["linear.weight"]; !slices.Equal(got, []string{"linear.weight", "linear.weight.scale_inv"}) {
|
||||||
|
t.Fatalf("linear.weight blob tensors = %v, want %v", got, []string{"linear.weight", "linear.weight.scale_inv"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := headerNamesByName["norm.weight"]; !slices.Equal(got, []string{"norm.weight"}) {
|
||||||
|
t.Fatalf("norm.weight blob tensors = %v, want %v", got, []string{"norm.weight"})
|
||||||
|
}
|
||||||
|
if got := headerNamesByName["dense.weight"]; !slices.Equal(got, []string{"dense.weight"}) {
|
||||||
|
t.Fatalf("dense.weight blob tensors = %v, want %v", got, []string{"dense.weight"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateSafetensorsModel_RejectsRequantizingQuantizedSources(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configJSON string
|
||||||
|
tensors []*st.TensorData
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "prequantized affine",
|
||||||
|
configJSON: `{"model_type": "test", "architectures": ["TestModel"]}`,
|
||||||
|
tensors: []*st.TensorData{
|
||||||
|
st.NewTensorDataFromBytes("linear.weight", "U32", []int32{4, 4}, make([]byte, 16)),
|
||||||
|
st.NewTensorDataFromBytes("linear.scales", "BF16", []int32{4, 1}, make([]byte, 8)),
|
||||||
|
},
|
||||||
|
wantErr: `cannot requantize already-quantized source model with --quantize "int4"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hf fp8 source",
|
||||||
|
configJSON: `{
|
||||||
|
"model_type": "test",
|
||||||
|
"architectures": ["TestModel"],
|
||||||
|
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
|
||||||
|
}`,
|
||||||
|
tensors: []*st.TensorData{
|
||||||
|
st.NewTensorDataFromBytes("linear.weight", "F8_E4M3", []int32{2, 2}, []byte{1, 2, 3, 4}),
|
||||||
|
st.NewTensorDataFromBytes("linear.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||||
|
},
|
||||||
|
wantErr: `cannot requantize already-quantized fp8 source model with --quantize "int4"`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config.json: %v", err)
|
||||||
|
}
|
||||||
|
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), tt.tensors)
|
||||||
|
|
||||||
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||||
|
return LayerInfo{}, nil
|
||||||
|
}
|
||||||
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
|
||||||
|
|
||||||
|
err := CreateSafetensorsModel("test-model", dir, "int4", createLayer, createTensorLayer, writeManifest, func(string) {})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||||
|
t.Fatalf("error = %q, want substring %q", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateSafetensorsModel_HFFP8PacksExperts(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
configJSON := `{
|
||||||
|
"model_type": "test",
|
||||||
|
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
|
||||||
|
"quantization_config": {"quant_method": "fp8", "weight_block_size": [128, 128]}
|
||||||
|
}`
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config.json: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 2 experts so stacking produces a [2, 128, 128] tensor
|
||||||
|
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.0.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.gate_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.up_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight", "F8_E4M3", []int32{128, 128}, make([]byte, 128*128)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.1.down_proj.weight_scale_inv", "BF16", []int32{1, 1}, make([]byte, 2)),
|
||||||
|
})
|
||||||
|
|
||||||
|
var packedLayerNames []string
|
||||||
|
var packedLayerTensors [][]PackedTensorInput
|
||||||
|
|
||||||
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||||
|
if _, err := io.ReadAll(r); err != nil {
|
||||||
|
return LayerInfo{}, err
|
||||||
|
}
|
||||||
|
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||||
|
if _, err := io.ReadAll(r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []LayerInfo{{Name: name, Digest: "sha256:tensor_" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
|
||||||
|
packedLayerNames = append(packedLayerNames, groupName)
|
||||||
|
packedLayerTensors = append(packedLayerTensors, tensors)
|
||||||
|
return LayerInfo{Name: groupName, Digest: "sha256:packed_" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error { return nil }
|
||||||
|
|
||||||
|
if err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
|
||||||
|
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(packedLayerNames) != 1 {
|
||||||
|
t.Fatalf("expected 1 packed layer, got %d: %v", len(packedLayerNames), packedLayerNames)
|
||||||
|
}
|
||||||
|
if packedLayerNames[0] != "language_model.model.layers.0.mlp.experts" {
|
||||||
|
t.Fatalf("unexpected packed layer name: %s", packedLayerNames[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all 6 expert tensors (2 experts × 3 proj types) were accumulated
|
||||||
|
tensors := packedLayerTensors[0]
|
||||||
|
if len(tensors) != 6 {
|
||||||
|
t.Fatalf("expected 6 tensors in packed group, got %d", len(tensors))
|
||||||
|
}
|
||||||
|
|
||||||
|
// All should be marked for mxfp8 quantization
|
||||||
|
for _, tensor := range tensors {
|
||||||
|
if tensor.Quantize != "mxfp8" {
|
||||||
|
t.Fatalf("expected mxfp8 quantize for %s, got %q", tensor.Name, tensor.Quantize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
|
func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
|
||||||
@@ -693,6 +926,113 @@ func TestCreateSafetensorsModel_Qwen35Transforms(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateSafetensorsModel_Qwen35DirectNonAffineKeepsSensitiveWeightsBF16(t *testing.T) {
|
||||||
|
for _, quantize := range []string{"nvfp4", "mxfp8", "mxfp4"} {
|
||||||
|
t.Run(quantize, func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
configJSON := `{
|
||||||
|
"model_type": "test",
|
||||||
|
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
|
||||||
|
"text_config": {"dtype": "bfloat16"}
|
||||||
|
}`
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config.json: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gateUpValues := make([]float32, 2*128*64)
|
||||||
|
for expert := range 2 {
|
||||||
|
base := expert * 128 * 64
|
||||||
|
for i := range 64 * 64 {
|
||||||
|
gateUpValues[base+i] = 1
|
||||||
|
gateUpValues[base+64*64+i] = 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||||
|
st.NewTensorDataFromBytes("lm_head.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_a.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.linear_attn.in_proj_b.weight", "BF16", []int32{32, 64}, make([]byte, 32*64*2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.shared_expert_gate.weight", "BF16", []int32{1, 64}, make([]byte, 64*2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
|
||||||
|
})
|
||||||
|
|
||||||
|
type tensorCall struct {
|
||||||
|
quantize string
|
||||||
|
}
|
||||||
|
type packedTensorCall struct {
|
||||||
|
Name string
|
||||||
|
Quantize string
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorCalls := make(map[string]tensorCall)
|
||||||
|
packedCalls := make(map[string][]packedTensorCall)
|
||||||
|
|
||||||
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||||
|
_, _ = io.ReadAll(r)
|
||||||
|
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantizeType string) ([]LayerInfo, error) {
|
||||||
|
_, _ = io.ReadAll(r)
|
||||||
|
tensorCalls[name] = tensorCall{quantize: quantizeType}
|
||||||
|
return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
|
||||||
|
group := make([]packedTensorCall, 0, len(tensors))
|
||||||
|
for _, tensor := range tensors {
|
||||||
|
group = append(group, packedTensorCall{
|
||||||
|
Name: tensor.Name,
|
||||||
|
Quantize: tensor.Quantize,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
packedCalls[groupName] = group
|
||||||
|
return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := CreateSafetensorsModel("test-model", dir, quantize, createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
|
||||||
|
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range []string{
|
||||||
|
"language_model.model.embed_tokens.weight",
|
||||||
|
"language_model.lm_head.weight",
|
||||||
|
"language_model.model.layers.0.linear_attn.in_proj_a.weight",
|
||||||
|
"language_model.model.layers.0.linear_attn.in_proj_b.weight",
|
||||||
|
"language_model.model.layers.0.mlp.gate.weight",
|
||||||
|
"language_model.model.layers.0.mlp.shared_expert_gate.weight",
|
||||||
|
} {
|
||||||
|
if got := tensorCalls[name].quantize; got != "" {
|
||||||
|
t.Fatalf("%s quantize = %q, want empty", name, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := tensorCalls["language_model.model.layers.0.self_attn.q_proj.weight"].quantize; got != quantize {
|
||||||
|
t.Fatalf("q_proj quantize = %q, want %q", got, quantize)
|
||||||
|
}
|
||||||
|
|
||||||
|
group := packedCalls["language_model.model.layers.0.mlp.switch_mlp"]
|
||||||
|
if len(group) != 3 {
|
||||||
|
t.Fatalf("packed switch_mlp tensor count = %d, want 3", len(group))
|
||||||
|
}
|
||||||
|
for _, tensor := range group {
|
||||||
|
if tensor.Quantize != quantize {
|
||||||
|
t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, quantize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestResolveManifestPath(t *testing.T) {
|
func TestResolveManifestPath(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -865,6 +1205,7 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
|||||||
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
|
{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
|
||||||
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
|
{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
|
||||||
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
|
{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
|
||||||
|
{"large 2D weight mxfp4", "q_proj.weight", []int32{4096, 4096}, "mxfp4", true},
|
||||||
|
|
||||||
// Small tensors should not be quantized (< 1024 elements)
|
// Small tensors should not be quantized (< 1024 elements)
|
||||||
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
|
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
|
||||||
@@ -891,9 +1232,11 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
|||||||
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
|
{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
|
||||||
|
|
||||||
// Group size divisibility tests
|
// Group size divisibility tests
|
||||||
// FP8/FP4 require divisible by 32
|
// FP8/FP4/MXFP4 require divisible by 32
|
||||||
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
|
{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
|
||||||
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
|
{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
|
||||||
|
{"not divisible by 32 mxfp4", "proj.weight", []int32{128, 48}, "mxfp4", false},
|
||||||
|
{"divisible by 32 mxfp4", "proj.weight", []int32{128, 64}, "mxfp4", true},
|
||||||
// NVFP4 requires divisible by 16
|
// NVFP4 requires divisible by 16
|
||||||
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
|
{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
|
||||||
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
|
{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
|
||||||
@@ -919,10 +1262,20 @@ func TestExpertGroupPrefix(t *testing.T) {
|
|||||||
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
||||||
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
||||||
|
|
||||||
|
// Expert tensors with language_model prefix should also match
|
||||||
|
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
|
||||||
|
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
|
||||||
|
|
||||||
// Shared expert tensors should return their own group prefix
|
// Shared expert tensors should return their own group prefix
|
||||||
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
|
{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
|
||||||
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
|
{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
|
||||||
|
|
||||||
|
// Rewritten Qwen switch_mlp tensors should also be packed per-layer.
|
||||||
|
{"model.layers.1.mlp.switch_mlp.down_proj.weight", "model.layers.1.mlp.switch_mlp"},
|
||||||
|
{"language_model.layers.2.mlp.switch_mlp.gate_proj.weight", "language_model.layers.2.mlp.switch_mlp"},
|
||||||
|
{"language_model.model.layers.3.mlp.switch_mlp.up_proj.weight", "language_model.model.layers.3.mlp.switch_mlp"},
|
||||||
|
{"model.language_model.layers.4.mlp.switch_mlp.gate_proj.weight", "model.language_model.layers.4.mlp.switch_mlp"},
|
||||||
|
|
||||||
// Non-expert tensors should return empty string
|
// Non-expert tensors should return empty string
|
||||||
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
|
{"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
|
||||||
{"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert
|
{"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert
|
||||||
@@ -978,6 +1331,161 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
|||||||
if combinedDown != "int8" {
|
if combinedDown != "int8" {
|
||||||
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
|
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nvfp4GateUp := GetTensorQuantization(
|
||||||
|
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
|
||||||
|
[]int32{64, 11008, 4096},
|
||||||
|
"nvfp4",
|
||||||
|
)
|
||||||
|
if nvfp4GateUp != "nvfp4" {
|
||||||
|
t.Fatalf("nvfp4 gate_proj quantization = %q, want %q", nvfp4GateUp, "nvfp4")
|
||||||
|
}
|
||||||
|
|
||||||
|
nvfp4Down := GetTensorQuantization(
|
||||||
|
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
|
||||||
|
[]int32{64, 4096, 11008},
|
||||||
|
"nvfp4",
|
||||||
|
)
|
||||||
|
if nvfp4Down != "nvfp4" {
|
||||||
|
t.Fatalf("nvfp4 down_proj quantization = %q, want %q", nvfp4Down, "nvfp4")
|
||||||
|
}
|
||||||
|
|
||||||
|
mxfp4GateUp := GetTensorQuantization(
|
||||||
|
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
|
||||||
|
[]int32{64, 11008, 4096},
|
||||||
|
"mxfp4",
|
||||||
|
)
|
||||||
|
if mxfp4GateUp != "mxfp4" {
|
||||||
|
t.Fatalf("mxfp4 gate_proj quantization = %q, want %q", mxfp4GateUp, "mxfp4")
|
||||||
|
}
|
||||||
|
|
||||||
|
mxfp4Down := GetTensorQuantization(
|
||||||
|
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
|
||||||
|
[]int32{64, 4096, 11008},
|
||||||
|
"mxfp4",
|
||||||
|
)
|
||||||
|
if mxfp4Down != "mxfp4" {
|
||||||
|
t.Fatalf("mxfp4 down_proj quantization = %q, want %q", mxfp4Down, "mxfp4")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
configJSON := `{
|
||||||
|
"model_type": "test",
|
||||||
|
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
|
||||||
|
"text_config": {"dtype": "bfloat16"}
|
||||||
|
}`
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config.json: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gateUpValues := make([]float32, 2*128*64)
|
||||||
|
for expert := range 2 {
|
||||||
|
base := expert * 128 * 64
|
||||||
|
for i := range 64 * 64 {
|
||||||
|
gateUpValues[base+i] = 1
|
||||||
|
gateUpValues[base+64*64+i] = 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.gate.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.gate_up_proj", "BF16", []int32{2, 128, 64}, bfloat16.EncodeFloat32(gateUpValues)),
|
||||||
|
st.NewTensorDataFromBytes("model.language_model.layers.0.mlp.experts.down_proj", "BF16", []int32{2, 64, 64}, bfloat16.EncodeFloat32(make([]float32, 2*64*64))),
|
||||||
|
})
|
||||||
|
|
||||||
|
type tensorCall struct {
|
||||||
|
quantize string
|
||||||
|
}
|
||||||
|
type packedTensorCall struct {
|
||||||
|
Name string
|
||||||
|
Dtype string
|
||||||
|
Shape []int32
|
||||||
|
Quantize string
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorCalls := make(map[string]tensorCall)
|
||||||
|
packedCalls := make(map[string][]packedTensorCall)
|
||||||
|
|
||||||
|
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||||
|
_, _ = io.ReadAll(r)
|
||||||
|
return LayerInfo{Name: name, Digest: "sha256:" + name, MediaType: mediaType}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||||
|
_, _ = io.ReadAll(r)
|
||||||
|
tensorCalls[name] = tensorCall{quantize: quantize}
|
||||||
|
return []LayerInfo{{Name: name, Digest: "sha256:" + name, MediaType: "application/vnd.ollama.image.tensor"}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
createPackedLayer := func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) {
|
||||||
|
group := make([]packedTensorCall, 0, len(tensors))
|
||||||
|
for _, tensor := range tensors {
|
||||||
|
group = append(group, packedTensorCall{
|
||||||
|
Name: tensor.Name,
|
||||||
|
Dtype: tensor.Dtype,
|
||||||
|
Shape: append([]int32(nil), tensor.Shape...),
|
||||||
|
Quantize: tensor.Quantize,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
packedCalls[groupName] = group
|
||||||
|
return LayerInfo{Name: groupName, Digest: "sha256:" + groupName, MediaType: "application/vnd.ollama.image.tensor"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := CreateSafetensorsModel("test-model", dir, "nvfp4", createLayer, createTensorLayer, writeManifest, func(string) {}, createPackedLayer); err != nil {
|
||||||
|
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
groupName := "language_model.model.layers.0.mlp.switch_mlp"
|
||||||
|
group, ok := packedCalls[groupName]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("missing packed group %q: %v", groupName, packedCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(group) != 3 {
|
||||||
|
t.Fatalf("packed group %q has %d tensors, want 3", groupName, len(group))
|
||||||
|
}
|
||||||
|
|
||||||
|
gotNames := make([]string, 0, len(group))
|
||||||
|
for _, tensor := range group {
|
||||||
|
gotNames = append(gotNames, tensor.Name)
|
||||||
|
if tensor.Quantize != "nvfp4" {
|
||||||
|
t.Fatalf("packed tensor %q quantize = %q, want %q", tensor.Name, tensor.Quantize, "nvfp4")
|
||||||
|
}
|
||||||
|
if tensor.Dtype != "BF16" {
|
||||||
|
t.Fatalf("packed tensor %q dtype = %q, want %q", tensor.Name, tensor.Dtype, "BF16")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(gotNames)
|
||||||
|
|
||||||
|
wantNames := []string{
|
||||||
|
"language_model.model.layers.0.mlp.switch_mlp.down_proj.weight",
|
||||||
|
"language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight",
|
||||||
|
"language_model.model.layers.0.mlp.switch_mlp.up_proj.weight",
|
||||||
|
}
|
||||||
|
if !slices.Equal(gotNames, wantNames) {
|
||||||
|
t.Fatalf("packed tensor names = %v, want %v", gotNames, wantNames)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range wantNames {
|
||||||
|
if _, ok := tensorCalls[name]; ok {
|
||||||
|
t.Fatalf("packed expert tensor %q unexpectedly handled by createTensorLayer", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := tensorCalls["language_model.model.embed_tokens.weight"].quantize; got != "" {
|
||||||
|
t.Fatalf("embed_tokens quantize = %q, want empty", got)
|
||||||
|
}
|
||||||
|
if got := tensorCalls["language_model.model.layers.0.mlp.gate.weight"].quantize; got != "" {
|
||||||
|
t.Fatalf("mlp.gate quantize = %q, want empty", got)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||||
|
|||||||
@@ -87,6 +87,27 @@ func (t qwen35ImportTransform) skipTensor(name string) bool {
|
|||||||
return strings.Contains(name, "mtp.")
|
return strings.Contains(name, "mtp.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func qwen35ShouldKeepBF16ForDirectNonAffine(name string) bool {
|
||||||
|
switch {
|
||||||
|
case strings.HasSuffix(name, "embed_tokens.weight"):
|
||||||
|
return true
|
||||||
|
case strings.HasSuffix(name, "lm_head.weight"):
|
||||||
|
return true
|
||||||
|
case strings.HasSuffix(name, ".linear_attn.in_proj_a.weight"):
|
||||||
|
return true
|
||||||
|
case strings.HasSuffix(name, ".linear_attn.in_proj_b.weight"):
|
||||||
|
return true
|
||||||
|
case strings.HasSuffix(name, ".linear_attn.in_proj_ba.weight"):
|
||||||
|
return true
|
||||||
|
case strings.HasSuffix(name, ".mlp.gate.weight") && !strings.Contains(name, "_proj"):
|
||||||
|
return true
|
||||||
|
case strings.HasSuffix(name, ".mlp.shared_expert_gate.weight"):
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
|
func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
|
||||||
if strings.HasPrefix(name, "vision_tower.") {
|
if strings.HasPrefix(name, "vision_tower.") {
|
||||||
return ""
|
return ""
|
||||||
@@ -127,6 +148,13 @@ func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quan
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Match the working HF-FP8 import policy for direct NVFP4/MXFP4/MXFP8 imports:
|
||||||
|
// keep embeddings, LM head, low-rank linear_attn projections, and routing
|
||||||
|
// gates in BF16 rather than forcing them into a non-affine quantized format.
|
||||||
|
if (quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8") && qwen35ShouldKeepBF16ForDirectNonAffine(name) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
return quantNorm
|
return quantNorm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
# Read MLX version from top-level file (shared with Dockerfile)
|
# Read MLX-C version from top-level file (shared with Dockerfile)
|
||||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
|
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
|
||||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||||
|
|
||||||
# Read MLX core version from top-level file
|
# Read MLX version from top-level file
|
||||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG)
|
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_GIT_TAG)
|
||||||
string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)
|
string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)
|
||||||
|
|
||||||
set(MLX_C_BUILD_EXAMPLES OFF)
|
set(MLX_C_BUILD_EXAMPLES OFF)
|
||||||
@@ -98,6 +98,15 @@ FetchContent_MakeAvailable(mlx-c)
|
|||||||
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
|
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
|
||||||
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")
|
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")
|
||||||
|
|
||||||
|
# Regenerate Go/C shim wrappers from the (possibly updated) headers.
|
||||||
|
find_program(GO_EXECUTABLE go REQUIRED)
|
||||||
|
message(STATUS "Regenerating MLX Go wrappers")
|
||||||
|
execute_process(
|
||||||
|
COMMAND ${GO_EXECUTABLE} generate ./x/...
|
||||||
|
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
|
||||||
|
COMMAND_ERROR_IS_FATAL ANY
|
||||||
|
)
|
||||||
|
|
||||||
# For local dev builds, override MLX_VERSION with git describe output
|
# For local dev builds, override MLX_VERSION with git describe output
|
||||||
if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
|
if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
|
||||||
execute_process(
|
execute_process(
|
||||||
|
|||||||
@@ -165,8 +165,8 @@ int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const
|
|||||||
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
|
||||||
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
|
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
|
||||||
bool (*mlx_distributed_is_available_ptr)(void) = NULL;
|
bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL;
|
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL;
|
||||||
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
|
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
|
||||||
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
|
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
|
||||||
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
|
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
|
||||||
@@ -319,10 +319,12 @@ int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const
|
|||||||
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
|
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
|
||||||
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
|
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
|
||||||
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
|
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
|
||||||
@@ -348,7 +350,7 @@ int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse
|
|||||||
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
||||||
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
|
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
|
||||||
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL;
|
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) = NULL;
|
||||||
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
||||||
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
|
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
|
||||||
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
@@ -375,6 +377,8 @@ int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w,
|
|||||||
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
|
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
||||||
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
@@ -434,8 +438,8 @@ int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, siz
|
|||||||
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
||||||
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
||||||
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
|
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
|
||||||
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
||||||
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
@@ -2101,6 +2105,11 @@ int mlx_load_functions(void* handle) {
|
|||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
mlx_bartlett_ptr = GET_SYM(handle, "mlx_bartlett");
|
||||||
|
if (mlx_bartlett_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bartlett\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
|
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
|
||||||
if (mlx_bitwise_and_ptr == NULL) {
|
if (mlx_bitwise_and_ptr == NULL) {
|
||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
|
||||||
@@ -2121,6 +2130,11 @@ int mlx_load_functions(void* handle) {
|
|||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
mlx_blackman_ptr = GET_SYM(handle, "mlx_blackman");
|
||||||
|
if (mlx_blackman_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_blackman\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
|
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
|
||||||
if (mlx_block_masked_mm_ptr == NULL) {
|
if (mlx_block_masked_mm_ptr == NULL) {
|
||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
|
||||||
@@ -2381,6 +2395,16 @@ int mlx_load_functions(void* handle) {
|
|||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
mlx_hamming_ptr = GET_SYM(handle, "mlx_hamming");
|
||||||
|
if (mlx_hamming_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_hamming\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
mlx_hanning_ptr = GET_SYM(handle, "mlx_hanning");
|
||||||
|
if (mlx_hanning_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_hanning\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
|
mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
|
||||||
if (mlx_identity_ptr == NULL) {
|
if (mlx_identity_ptr == NULL) {
|
||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
|
||||||
@@ -4132,12 +4156,12 @@ mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, i
|
|||||||
return mlx_distributed_group_split_ptr(group, color, key);
|
return mlx_distributed_group_split_ptr(group, color, key);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mlx_distributed_is_available(void) {
|
bool mlx_distributed_is_available(const char* bk) {
|
||||||
return mlx_distributed_is_available_ptr();
|
return mlx_distributed_is_available_ptr(bk);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlx_distributed_group mlx_distributed_init(bool strict) {
|
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) {
|
||||||
return mlx_distributed_init_ptr(strict);
|
return mlx_distributed_init_ptr(strict, bk);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
|
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
|
||||||
@@ -4748,6 +4772,10 @@ int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|||||||
return mlx_atleast_3d_ptr(res, a, s);
|
return mlx_atleast_3d_ptr(res, a, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_bartlett_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
||||||
return mlx_bitwise_and_ptr(res, a, b, s);
|
return mlx_bitwise_and_ptr(res, a, b, s);
|
||||||
}
|
}
|
||||||
@@ -4764,6 +4792,10 @@ int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const
|
|||||||
return mlx_bitwise_xor_ptr(res, a, b, s);
|
return mlx_bitwise_xor_ptr(res, a, b, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_blackman_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
|
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
|
||||||
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
|
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
|
||||||
}
|
}
|
||||||
@@ -4864,8 +4896,8 @@ int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_
|
|||||||
return mlx_depends_ptr(res, inputs, dependencies);
|
return mlx_depends_ptr(res, inputs, dependencies);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) {
|
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) {
|
||||||
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s);
|
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
||||||
@@ -4972,6 +5004,14 @@ int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float
|
|||||||
return mlx_hadamard_transform_ptr(res, a, scale, s);
|
return mlx_hadamard_transform_ptr(res, a, scale, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hamming_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hanning_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
||||||
return mlx_identity_ptr(res, n, dtype, s);
|
return mlx_identity_ptr(res, n, dtype, s);
|
||||||
}
|
}
|
||||||
@@ -5208,12 +5248,12 @@ int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indice
|
|||||||
return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
|
return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) {
|
||||||
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s);
|
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) {
|
||||||
return mlx_quantize_ptr(res, w, group_size, bits, mode, s);
|
return mlx_quantize_ptr(res, w, group_size, bits, mode, global_scale, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
||||||
|
|||||||
@@ -2125,7 +2125,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
|||||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||||
res := C.mlx_vector_array_new()
|
res := C.mlx_vector_array_new()
|
||||||
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, globalScale, C.default_stream())
|
||||||
|
|
||||||
// Result is a vector of arrays: [weights, scales, biases?]
|
// Result is a vector of arrays: [weights, scales, biases?]
|
||||||
// mxfp8 mode returns only 2 elements (no biases)
|
// mxfp8 mode returns only 2 elements (no biases)
|
||||||
@@ -2161,7 +2162,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
|||||||
}
|
}
|
||||||
|
|
||||||
res := C.mlx_array_new()
|
res := C.mlx_array_new()
|
||||||
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream())
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, globalScale, optDtype, C.default_stream())
|
||||||
return newArray(res)
|
return newArray(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -309,10 +309,12 @@
|
|||||||
#undef mlx_atleast_1d
|
#undef mlx_atleast_1d
|
||||||
#undef mlx_atleast_2d
|
#undef mlx_atleast_2d
|
||||||
#undef mlx_atleast_3d
|
#undef mlx_atleast_3d
|
||||||
|
#undef mlx_bartlett
|
||||||
#undef mlx_bitwise_and
|
#undef mlx_bitwise_and
|
||||||
#undef mlx_bitwise_invert
|
#undef mlx_bitwise_invert
|
||||||
#undef mlx_bitwise_or
|
#undef mlx_bitwise_or
|
||||||
#undef mlx_bitwise_xor
|
#undef mlx_bitwise_xor
|
||||||
|
#undef mlx_blackman
|
||||||
#undef mlx_block_masked_mm
|
#undef mlx_block_masked_mm
|
||||||
#undef mlx_broadcast_arrays
|
#undef mlx_broadcast_arrays
|
||||||
#undef mlx_broadcast_to
|
#undef mlx_broadcast_to
|
||||||
@@ -365,6 +367,8 @@
|
|||||||
#undef mlx_greater
|
#undef mlx_greater
|
||||||
#undef mlx_greater_equal
|
#undef mlx_greater_equal
|
||||||
#undef mlx_hadamard_transform
|
#undef mlx_hadamard_transform
|
||||||
|
#undef mlx_hamming
|
||||||
|
#undef mlx_hanning
|
||||||
#undef mlx_identity
|
#undef mlx_identity
|
||||||
#undef mlx_imag
|
#undef mlx_imag
|
||||||
#undef mlx_inner
|
#undef mlx_inner
|
||||||
@@ -751,8 +755,8 @@ extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x,
|
|||||||
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
|
||||||
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
|
||||||
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
|
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
|
||||||
extern bool (*mlx_distributed_is_available_ptr)(void);
|
extern bool (*mlx_distributed_is_available_ptr)(const char* bk);
|
||||||
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict);
|
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk);
|
||||||
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
||||||
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
|
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
|
||||||
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
|
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
|
||||||
@@ -905,10 +909,12 @@ extern int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype,
|
|||||||
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
extern int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
|
extern int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
||||||
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
||||||
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
|
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
|
||||||
@@ -934,7 +940,7 @@ extern int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool
|
|||||||
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
|
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
|
||||||
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
||||||
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
|
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
|
||||||
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
|
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
|
||||||
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
@@ -961,6 +967,8 @@ extern int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_ar
|
|||||||
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
||||||
|
extern int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
extern int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
@@ -1020,8 +1028,8 @@ extern int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* ax
|
|||||||
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
|
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
|
||||||
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
|
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
|
||||||
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
||||||
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
|
||||||
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
|
||||||
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||||
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
@@ -1492,9 +1500,9 @@ int mlx_distributed_group_size(mlx_distributed_group group);
|
|||||||
|
|
||||||
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
||||||
|
|
||||||
bool mlx_distributed_is_available(void);
|
bool mlx_distributed_is_available(const char* bk);
|
||||||
|
|
||||||
mlx_distributed_group mlx_distributed_init(bool strict);
|
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk);
|
||||||
|
|
||||||
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
||||||
|
|
||||||
@@ -1800,6 +1808,8 @@ int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
|||||||
|
|
||||||
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
@@ -1808,6 +1818,8 @@ int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const m
|
|||||||
|
|
||||||
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
||||||
|
|
||||||
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
||||||
@@ -1858,7 +1870,7 @@ int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s);
|
|||||||
|
|
||||||
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
||||||
|
|
||||||
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
|
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
|
|
||||||
@@ -1912,6 +1924,10 @@ int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, cons
|
|||||||
|
|
||||||
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
@@ -2030,9 +2046,9 @@ int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream
|
|||||||
|
|
||||||
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
|
||||||
|
|
||||||
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
|
||||||
|
|
||||||
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||||
|
|
||||||
|
|||||||
@@ -93,21 +93,8 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|||||||
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
|
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for partial match within a node's edge — truncate path
|
|
||||||
// to the parent boundary. snapshot() will split the node and
|
|
||||||
// create the branch point during prefill when caches are ready.
|
|
||||||
partialMatch := false
|
|
||||||
if len(matchPath) > 1 {
|
|
||||||
lastNode := matchPath[len(matchPath)-1]
|
|
||||||
matchedInEdge := matched - lastNode.startOffset()
|
|
||||||
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
|
|
||||||
matchPath = matchPath[:len(matchPath)-1]
|
|
||||||
partialMatch = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Switch to the matched path, paging in/out as needed.
|
// Switch to the matched path, paging in/out as needed.
|
||||||
c.switchToPath(matchPath)
|
c.switchToPath(matchPath, matched)
|
||||||
|
|
||||||
// switchToPath aligns caches to a common offset
|
// switchToPath aligns caches to a common offset
|
||||||
prefix := c.minCacheOffset()
|
prefix := c.minCacheOffset()
|
||||||
@@ -116,7 +103,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|||||||
// Schedule a snapshot at the branch point during prefill so future
|
// Schedule a snapshot at the branch point during prefill so future
|
||||||
// requests diverging here can restore instead of re-evaluating.
|
// requests diverging here can restore instead of re-evaluating.
|
||||||
var snapshotAt int
|
var snapshotAt int
|
||||||
if partialMatch || (prefix == 0 && matched > 0) {
|
if prefix < matched {
|
||||||
snapshotAt = matched
|
snapshotAt = matched
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,7 +129,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|||||||
|
|
||||||
// switchToPath transitions from the current active path to a new path,
|
// switchToPath transitions from the current active path to a new path,
|
||||||
// paging out diverging segments and paging in the new path.
|
// paging out diverging segments and paging in the new path.
|
||||||
func (c *kvCache) switchToPath(newPath []*trieNode) {
|
func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
|
||||||
defer c.enforceEvictionPolicy()
|
defer c.enforceEvictionPolicy()
|
||||||
|
|
||||||
// Find common ancestor index.
|
// Find common ancestor index.
|
||||||
@@ -167,7 +154,10 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
|
|||||||
// non-leaf nodes here would produce wrong results for non-rewindable
|
// non-leaf nodes here would produce wrong results for non-rewindable
|
||||||
// caches (e.g. RecurrentCache) whose state reflects the leaf, not
|
// caches (e.g. RecurrentCache) whose state reflects the leaf, not
|
||||||
// the intermediate boundary.
|
// the intermediate boundary.
|
||||||
if leaf := len(c.activePath) - 1; leaf >= commonLen {
|
leaf := len(c.activePath) - 1
|
||||||
|
leafDiverges := leaf >= commonLen
|
||||||
|
leafNeedsRewind := matched < c.activePath[leaf].endOffset
|
||||||
|
if leafDiverges || leafNeedsRewind {
|
||||||
node := c.activePath[leaf]
|
node := c.activePath[leaf]
|
||||||
if !node.hasAllSnapshots() {
|
if !node.hasAllSnapshots() {
|
||||||
fromOffset := node.startOffset()
|
fromOffset := node.startOffset()
|
||||||
@@ -184,14 +174,16 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewind each cache to the ancestor offset or free it. Freed
|
// Rewind each cache to the target offset or free it. When matched
|
||||||
// caches (e.g. RecurrentCache that can't rewind) will be restored
|
// falls within the ancestor's range (same-path case), we rewind
|
||||||
// from snapshots during page-in.
|
// directly to the match point. Otherwise we rewind to the ancestor
|
||||||
|
// and let page-in bring us forward to matched.
|
||||||
|
rewindTarget := min(ancestorOffset, matched)
|
||||||
for _, kv := range c.caches {
|
for _, kv := range c.caches {
|
||||||
if kv == nil {
|
if kv == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !kv.Restore(nil, ancestorOffset) {
|
if !kv.Restore(nil, rewindTarget) {
|
||||||
kv.Free()
|
kv.Free()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -199,10 +191,12 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
|
|||||||
// Page in — walk the full new path, restoring from snapshots.
|
// Page in — walk the full new path, restoring from snapshots.
|
||||||
// Freed caches naturally pick up the first available snapshot.
|
// Freed caches naturally pick up the first available snapshot.
|
||||||
// Caches already past a node skip it via offset check.
|
// Caches already past a node skip it via offset check.
|
||||||
|
pageIn:
|
||||||
for _, node := range newPath {
|
for _, node := range newPath {
|
||||||
if len(node.snapshots) == 0 {
|
if !node.hasSnapshots() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
nodeTarget := min(node.endOffset, matched)
|
||||||
for j, kv := range c.caches {
|
for j, kv := range c.caches {
|
||||||
if kv == nil {
|
if kv == nil {
|
||||||
continue
|
continue
|
||||||
@@ -210,19 +204,18 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
|
|||||||
if j >= len(node.snapshots) || node.snapshots[j] == nil {
|
if j >= len(node.snapshots) || node.snapshots[j] == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if kv.Offset() >= node.endOffset {
|
if kv.Offset() >= nodeTarget {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !kv.Restore(node.snapshots[j], node.endOffset) {
|
if !kv.Restore(node.snapshots[j], nodeTarget) {
|
||||||
slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset())
|
// Restore failed — stop page-in and let alignment
|
||||||
c.freeAll()
|
// bring all caches to a consistent offset.
|
||||||
c.activePath = []*trieNode{c.root}
|
break pageIn
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if node.endOffset > ancestorOffset {
|
if node.endOffset > ancestorOffset {
|
||||||
pageInCount++
|
pageInCount++
|
||||||
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset))
|
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -536,6 +529,9 @@ func (c *kvCache) dumpTree() {
|
|||||||
if nodeBytes > 0 {
|
if nodeBytes > 0 {
|
||||||
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
|
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
|
||||||
}
|
}
|
||||||
|
if !n.lastUsed.IsZero() {
|
||||||
|
label += fmt.Sprintf(" %s ago", time.Since(n.lastUsed).Truncate(time.Millisecond))
|
||||||
|
}
|
||||||
var flags []string
|
var flags []string
|
||||||
if n.user {
|
if n.user {
|
||||||
flags = append(flags, "user")
|
flags = append(flags, "user")
|
||||||
|
|||||||
28
x/mlxrunner/cache/cache.go
vendored
28
x/mlxrunner/cache/cache.go
vendored
@@ -17,7 +17,8 @@ type Cache interface {
|
|||||||
Snapshot(fromOffset int) Snapshot
|
Snapshot(fromOffset int) Snapshot
|
||||||
|
|
||||||
// Restore brings the cache to target. If snapshot is nil, rewinds
|
// Restore brings the cache to target. If snapshot is nil, rewinds
|
||||||
// using the cache's own live state.
|
// using the cache's own live state. Returns false if the target is
|
||||||
|
// unreachable (e.g. target > current offset, or negative).
|
||||||
Restore(snapshot Snapshot, target int) bool
|
Restore(snapshot Snapshot, target int) bool
|
||||||
|
|
||||||
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
|
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
|
||||||
@@ -122,17 +123,21 @@ func (c *KVCache) Snapshot(fromOffset int) Snapshot {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
|
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
|
||||||
|
if target < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if snapshot == nil {
|
if snapshot == nil {
|
||||||
// Rewind using live state — just clamp offset.
|
if target > c.offset {
|
||||||
target = max(0, min(target, c.offset))
|
return false
|
||||||
|
}
|
||||||
c.offset = target
|
c.offset = target
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
snap := snapshot.(*kvSnapshot)
|
snap := snapshot.(*kvSnapshot)
|
||||||
|
|
||||||
// Check that the cache has data up to the snapshot's starting point.
|
if target > snap.toOffset || c.offset < snap.fromOffset {
|
||||||
if c.offset < snap.fromOffset {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -354,7 +359,14 @@ func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
||||||
|
if target < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if snapshot == nil {
|
if snapshot == nil {
|
||||||
|
if target >= c.offset {
|
||||||
|
return target == c.offset
|
||||||
|
}
|
||||||
// Live rewind is only safe when the buffer hasn't filled yet
|
// Live rewind is only safe when the buffer hasn't filled yet
|
||||||
// (offset <= maxSize). Once the window has shifted, rewinding
|
// (offset <= maxSize). Once the window has shifted, rewinding
|
||||||
// leaves fewer than maxSize trailing tokens to attend to —
|
// leaves fewer than maxSize trailing tokens to attend to —
|
||||||
@@ -362,7 +374,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
|||||||
if c.offset > c.maxSize {
|
if c.offset > c.maxSize {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
target = max(0, min(target, c.offset))
|
|
||||||
c.offset = target
|
c.offset = target
|
||||||
c.idx = target
|
c.idx = target
|
||||||
return true
|
return true
|
||||||
@@ -370,6 +381,10 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
|||||||
|
|
||||||
snap := snapshot.(*rotatingSnapshot)
|
snap := snapshot.(*rotatingSnapshot)
|
||||||
|
|
||||||
|
if target > snap.toOffset {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Reject if clamping would leave an incomplete window.
|
// Reject if clamping would leave an incomplete window.
|
||||||
if target < snap.toOffset && snap.toOffset > c.maxSize {
|
if target < snap.toOffset && snap.toOffset > c.maxSize {
|
||||||
return false
|
return false
|
||||||
@@ -388,7 +403,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
|||||||
|
|
||||||
// Clamp to target if needed.
|
// Clamp to target if needed.
|
||||||
if target < c.offset {
|
if target < c.offset {
|
||||||
target = max(0, target)
|
|
||||||
c.offset = target
|
c.offset = target
|
||||||
c.idx = target
|
c.idx = target
|
||||||
}
|
}
|
||||||
|
|||||||
18
x/mlxrunner/cache/recurrent.go
vendored
18
x/mlxrunner/cache/recurrent.go
vendored
@@ -22,14 +22,9 @@ func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
|
|||||||
if v == nil || !v.Valid() {
|
if v == nil || !v.Valid() {
|
||||||
return old
|
return old
|
||||||
}
|
}
|
||||||
if old == v {
|
|
||||||
return old
|
|
||||||
}
|
|
||||||
|
|
||||||
mlx.Pin(v)
|
mlx.Pin(v)
|
||||||
if old != nil && old != v {
|
|
||||||
mlx.Unpin(old)
|
mlx.Unpin(old)
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
@@ -38,9 +33,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
|
|||||||
if v == nil || !v.Valid() {
|
if v == nil || !v.Valid() {
|
||||||
return old
|
return old
|
||||||
}
|
}
|
||||||
if old == v {
|
|
||||||
return old
|
|
||||||
}
|
|
||||||
|
|
||||||
root := v
|
root := v
|
||||||
if ensureContiguous {
|
if ensureContiguous {
|
||||||
@@ -49,9 +41,7 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
|
|||||||
detached := root.Clone()
|
detached := root.Clone()
|
||||||
|
|
||||||
mlx.Pin(detached)
|
mlx.Pin(detached)
|
||||||
if old != nil && old != detached {
|
|
||||||
mlx.Unpin(old)
|
mlx.Unpin(old)
|
||||||
}
|
|
||||||
|
|
||||||
return detached
|
return detached
|
||||||
}
|
}
|
||||||
@@ -150,10 +140,10 @@ func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
|
|||||||
|
|
||||||
snap := snapshot.(*recurrentSnapshot)
|
snap := snapshot.(*recurrentSnapshot)
|
||||||
|
|
||||||
// Recurrent state encodes all tokens up to snap.offset. Restoring
|
// Recurrent snapshots encode cumulative state up to exactly
|
||||||
// to a target before that would leave stale state from tokens
|
// snap.offset. Target must match — rewinding would leave stale
|
||||||
// [target, snap.offset) baked in. Only allow restoring forward.
|
// state, and advancing isn't possible without feeding tokens.
|
||||||
if target < snap.offset {
|
if target != snap.offset {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
34
x/mlxrunner/cache/recurrent_test.go
vendored
34
x/mlxrunner/cache/recurrent_test.go
vendored
@@ -6,39 +6,35 @@ import (
|
|||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only
|
// TestRecurrentCacheRestoreExactOffset verifies that RecurrentCache restore
|
||||||
// allows restoring forward (target >= snapshot offset), never backward.
|
// only succeeds when target exactly matches the snapshot's offset. Recurrent
|
||||||
func TestRecurrentCacheRestoreDirectionality(t *testing.T) {
|
// state is cumulative, so it can't be rewound or fast-forwarded.
|
||||||
|
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
|
||||||
skipIfNoMLX(t)
|
skipIfNoMLX(t)
|
||||||
c := NewRecurrentCache(3, 12, 4, 8, 8)
|
c := NewRecurrentCache(3, 12, 4, 8, 8)
|
||||||
_ = c.ConvState(1, mlx.DTypeFloat16)
|
_ = c.ConvState(1, mlx.DTypeFloat16)
|
||||||
_ = c.DeltaState(1, mlx.DTypeFloat16)
|
_ = c.DeltaState(1, mlx.DTypeFloat16)
|
||||||
c.Advance(10)
|
c.Advance(10)
|
||||||
|
|
||||||
snap := c.Snapshot(0)
|
snap := c.Snapshot(0) // snap.offset == 10
|
||||||
|
|
||||||
c.Advance(5) // now at 15
|
c.Advance(5) // cache now at 15
|
||||||
|
|
||||||
// Restore backward should fail.
|
// target < snap.offset: fails (can't rewind past snapshot)
|
||||||
if c.Restore(snap, 5) {
|
if c.Restore(snap, 5) {
|
||||||
t.Fatal("Restore(snap, 5) should fail — target < snap.offset")
|
t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore to exact snap offset should succeed.
|
// target > snap.offset: fails (can't advance without feeding tokens)
|
||||||
|
if c.Restore(snap, 15) {
|
||||||
|
t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
|
||||||
|
}
|
||||||
|
|
||||||
|
// target == snap.offset: succeeds
|
||||||
if !c.Restore(snap, 10) {
|
if !c.Restore(snap, 10) {
|
||||||
t.Fatal("Restore(snap, 10) should succeed")
|
t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
|
||||||
}
|
}
|
||||||
if c.Offset() != 10 {
|
if c.Offset() != 10 {
|
||||||
t.Fatalf("offset = %d, want 10", c.Offset())
|
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore forward (target > snap offset) should succeed, offset = snap.offset.
|
|
||||||
snap2 := c.Snapshot(0)
|
|
||||||
if !c.Restore(snap2, 15) {
|
|
||||||
t.Fatal("Restore(snap, 15) should succeed")
|
|
||||||
}
|
|
||||||
// Recurrent state is at snap.offset (10), not target (15).
|
|
||||||
if c.Offset() != 10 {
|
|
||||||
t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,20 +79,20 @@ func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
|
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||||
if snapshot == nil {
|
|
||||||
// Rewind live state.
|
|
||||||
if target < 0 {
|
if target < 0 {
|
||||||
target = 0
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if snapshot == nil {
|
||||||
if target > len(c.tokens) {
|
if target > len(c.tokens) {
|
||||||
target = len(c.tokens)
|
return false
|
||||||
}
|
}
|
||||||
c.tokens = c.tokens[:target]
|
c.tokens = c.tokens[:target]
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
s := snapshot.(*fakeSnapshot)
|
s := snapshot.(*fakeSnapshot)
|
||||||
if len(c.tokens) < s.from {
|
if target > s.to || len(c.tokens) < s.from {
|
||||||
return false // don't have base data up to snapshot start
|
return false
|
||||||
}
|
}
|
||||||
c.tokens = append(c.tokens[:s.from], s.tokens...)
|
c.tokens = append(c.tokens[:s.from], s.tokens...)
|
||||||
if target < len(c.tokens) {
|
if target < len(c.tokens) {
|
||||||
@@ -196,9 +196,13 @@ func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
|
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||||
|
if target < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if snapshot == nil {
|
if snapshot == nil {
|
||||||
if target == len(c.tokens) {
|
if target >= len(c.tokens) {
|
||||||
return true
|
return target == len(c.tokens)
|
||||||
}
|
}
|
||||||
// Live rewind only works when buffer hasn't filled (offset <= maxSize).
|
// Live rewind only works when buffer hasn't filled (offset <= maxSize).
|
||||||
if len(c.tokens) > c.maxSize {
|
if len(c.tokens) > c.maxSize {
|
||||||
@@ -208,6 +212,14 @@ func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bo
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
s := snapshot.(*fakeSnapshot)
|
s := snapshot.(*fakeSnapshot)
|
||||||
|
if target > s.to {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Reject if clamping would leave an incomplete window
|
||||||
|
// (matches RotatingKVCache behavior).
|
||||||
|
if target < s.to && s.to > c.maxSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
c.tokens = slices.Clone(s.tokens)
|
c.tokens = slices.Clone(s.tokens)
|
||||||
if target < len(c.tokens) {
|
if target < len(c.tokens) {
|
||||||
c.tokens = c.tokens[:target]
|
c.tokens = c.tokens[:target]
|
||||||
@@ -268,8 +280,8 @@ func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
|
|||||||
return target == len(c.tokens) // can only no-op
|
return target == len(c.tokens) // can only no-op
|
||||||
}
|
}
|
||||||
s := snapshot.(*fakeSnapshot)
|
s := snapshot.(*fakeSnapshot)
|
||||||
if target < s.to {
|
if target != s.to {
|
||||||
return false // can't go backward
|
return false // cumulative state requires exact match
|
||||||
}
|
}
|
||||||
c.tokens = slices.Clone(s.tokens)
|
c.tokens = slices.Clone(s.tokens)
|
||||||
return true
|
return true
|
||||||
@@ -297,6 +309,7 @@ type testEnv struct {
|
|||||||
kvc *kvCache
|
kvc *kvCache
|
||||||
caches []cache.Cache // typed references for assertions
|
caches []cache.Cache // typed references for assertions
|
||||||
tracker *snapshotTracker
|
tracker *snapshotTracker
|
||||||
|
rewindable bool // true when all caches support arbitrary Restore(nil, target)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTransformerEnv creates a test environment with a single rewindable cache
|
// newTransformerEnv creates a test environment with a single rewindable cache
|
||||||
@@ -308,20 +321,25 @@ func newTransformerEnv() *testEnv {
|
|||||||
kvc: &kvCache{caches: caches},
|
kvc: &kvCache{caches: caches},
|
||||||
caches: caches,
|
caches: caches,
|
||||||
tracker: tracker,
|
tracker: tracker,
|
||||||
|
rewindable: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// newSlidingWindowEnv creates a test environment with one rewindable cache and
|
// newSlidingWindowEnv creates a test environment with one rewindable cache and
|
||||||
// one sliding window cache (Mistral-style architecture).
|
// one sliding window cache (Mistral-style architecture). The sliding window
|
||||||
|
// maxSize is set small enough that test sequences fill it, making
|
||||||
|
// Restore(nil, target) fail — the same behavior as production models where
|
||||||
|
// the window fills after a few turns.
|
||||||
func newSlidingWindowEnv() *testEnv {
|
func newSlidingWindowEnv() *testEnv {
|
||||||
tr := &snapshotTracker{}
|
tr := &snapshotTracker{}
|
||||||
rc := &fakeRewindableCache{tracker: tr}
|
rc := &fakeRewindableCache{tracker: tr}
|
||||||
sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr}
|
sw := &fakeSlidingWindowCache{maxSize: 4, tracker: tr}
|
||||||
caches := []cache.Cache{rc, sw}
|
caches := []cache.Cache{rc, sw}
|
||||||
return &testEnv{
|
return &testEnv{
|
||||||
kvc: &kvCache{caches: caches},
|
kvc: &kvCache{caches: caches},
|
||||||
caches: caches,
|
caches: caches,
|
||||||
tracker: tr,
|
tracker: tr,
|
||||||
|
rewindable: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,6 +354,7 @@ func newRecurrentEnv() *testEnv {
|
|||||||
kvc: &kvCache{caches: caches},
|
kvc: &kvCache{caches: caches},
|
||||||
caches: caches,
|
caches: caches,
|
||||||
tracker: tr,
|
tracker: tr,
|
||||||
|
rewindable: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -590,15 +609,24 @@ func TestBranchCreationAndReuse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
|
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
|
||||||
// Partial match in A's edge triggers snapshotOffset.
|
// For rewindable caches, switchToPath rewinds to the match point
|
||||||
|
// so only the non-matching suffix needs evaluation. For non-rewindable
|
||||||
|
// caches (RecurrentCache), the rewind fails and freeAll fires.
|
||||||
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
|
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
|
||||||
|
if env.rewindable {
|
||||||
|
if resB.snapshotOffset != 0 {
|
||||||
|
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
|
||||||
|
}
|
||||||
|
if len(resB.remaining) != 3 {
|
||||||
|
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if resB.snapshotOffset != 5 {
|
if resB.snapshotOffset != 5 {
|
||||||
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
|
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
|
||||||
}
|
}
|
||||||
// Cache was rewound to 0 (partial match truncates path to root),
|
|
||||||
// so all tokens were re-evaluated.
|
|
||||||
if len(resB.remaining) != 8 {
|
if len(resB.remaining) != 8 {
|
||||||
t.Fatalf("B: remaining = %d, want 8", len(resB.remaining))
|
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
|
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
|
||||||
|
|
||||||
@@ -635,15 +663,25 @@ func TestExactMatchSeedBehavior(t *testing.T) {
|
|||||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
||||||
|
|
||||||
// Request B: identical prompt. Holdback means matched=4, partial in
|
// Request B: identical prompt. Holdback means matched=4, partial in
|
||||||
// the 5-token edge, so path truncates to root and all tokens are
|
// the 5-token edge. For rewindable caches, switchToPath rewinds to
|
||||||
// re-evaluated. snapshotOffset should be set at the holdback point.
|
// offset 4, so only the held-back token needs re-evaluation. For
|
||||||
|
// non-rewindable caches, the rewind fails and freeAll fires.
|
||||||
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
|
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
|
||||||
|
if env.rewindable {
|
||||||
|
if len(resB.remaining) != 1 {
|
||||||
|
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
|
||||||
|
}
|
||||||
|
if resB.snapshotOffset != 0 {
|
||||||
|
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if len(resB.remaining) != 5 {
|
if len(resB.remaining) != 5 {
|
||||||
t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining))
|
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
|
||||||
}
|
}
|
||||||
if resB.snapshotOffset != 4 {
|
if resB.snapshotOffset != 4 {
|
||||||
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
|
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
|
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
|
||||||
|
|
||||||
checkTrieInvariants(t, kvc.root)
|
checkTrieInvariants(t, kvc.root)
|
||||||
|
|||||||
@@ -230,6 +230,9 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
|
|
||||||
resp, err := c.client.Do(httpReq)
|
resp, err := c.client.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errMsg := c.status.getLastErr(); errMsg != "" {
|
||||||
|
return fmt.Errorf("mlx runner failed: %s", errMsg)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@@ -267,7 +270,13 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return scanner.Err()
|
if err := scanner.Err(); err != nil {
|
||||||
|
if errMsg := c.status.getLastErr(); errMsg != "" {
|
||||||
|
return fmt.Errorf("mlx runner failed: %s", errMsg)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ContextLength() int {
|
func (c *Client) ContextLength() int {
|
||||||
|
|||||||
@@ -15,7 +15,9 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
|
|||||||
|
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
|
# Read MLX-C version from top-level file (shared with imagegen CMakeLists)
|
||||||
|
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
|
||||||
|
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||||
|
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
mlx-c
|
mlx-c
|
||||||
|
|||||||
@@ -137,6 +137,9 @@ func Unpin(s ...*Array) {
|
|||||||
for _, t := range s {
|
for _, t := range s {
|
||||||
if t != nil {
|
if t != nil {
|
||||||
t.pinned--
|
t.pinned--
|
||||||
|
if t.pinned < 0 {
|
||||||
|
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -261,7 +264,7 @@ func LogArrays() {
|
|||||||
|
|
||||||
for _, t := range arrays {
|
for _, t := range arrays {
|
||||||
nb := t.NumBytes()
|
nb := t.NumBytes()
|
||||||
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims()))
|
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
|
||||||
}
|
}
|
||||||
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
|
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ var (
|
|||||||
gatedDeltaMetalKernelOnce sync.Once
|
gatedDeltaMetalKernelOnce sync.Once
|
||||||
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||||
gatedDeltaMetalDisabled bool
|
gatedDeltaMetalDisabled bool
|
||||||
|
|
||||||
|
gatedDeltaCUDAKernelOnce sync.Once
|
||||||
|
gatedDeltaCUDAKernel C.mlx_fast_cuda_kernel
|
||||||
|
gatedDeltaCUDADisabled bool
|
||||||
)
|
)
|
||||||
|
|
||||||
const gatedDeltaMetalKernelSource = `
|
const gatedDeltaMetalKernelSource = `
|
||||||
@@ -83,6 +87,86 @@ for (int i = 0; i < n_per_t; ++i) {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
|
const gatedDeltaCUDAKernelSource = `
|
||||||
|
auto tid_x = threadIdx.x;
|
||||||
|
auto tid_y = threadIdx.y;
|
||||||
|
auto grid_y = blockIdx.y * blockDim.y + tid_y;
|
||||||
|
auto grid_z = blockIdx.z;
|
||||||
|
|
||||||
|
int T_val = static_cast<int>(*T);
|
||||||
|
|
||||||
|
auto n = grid_z;
|
||||||
|
auto b_idx = n / Hv;
|
||||||
|
auto hv_idx = n % Hv;
|
||||||
|
auto hk_idx = hv_idx / (Hv / Hk);
|
||||||
|
constexpr int n_per_t = Dk / 32;
|
||||||
|
|
||||||
|
// q, k: [B, T, Hk, Dk]
|
||||||
|
auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk;
|
||||||
|
auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk;
|
||||||
|
|
||||||
|
// v, y: [B, T, Hv, Dv]
|
||||||
|
auto dv_idx = grid_y;
|
||||||
|
auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv;
|
||||||
|
y += b_idx * T_val * Hv * Dv + hv_idx * Dv;
|
||||||
|
|
||||||
|
auto dk_idx = tid_x;
|
||||||
|
|
||||||
|
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||||
|
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||||
|
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||||
|
|
||||||
|
float state[n_per_t];
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = static_cast<float>(i_state[s_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// g: [B, T, Hv]
|
||||||
|
auto g_ = g + b_idx * T_val * Hv;
|
||||||
|
auto beta_ = beta + b_idx * T_val * Hv;
|
||||||
|
|
||||||
|
for (int t = 0; t < T_val; ++t) {
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = state[i] * static_cast<float>(g_[hv_idx]);
|
||||||
|
kv_mem += state[i] * static_cast<float>(k_[s_idx]);
|
||||||
|
}
|
||||||
|
// Warp reduction (full warp, 32 threads in x)
|
||||||
|
for (int offset = 16; offset > 0; offset >>= 1)
|
||||||
|
kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset);
|
||||||
|
kv_mem = __shfl_sync(0xffffffff, kv_mem, 0);
|
||||||
|
|
||||||
|
auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]);
|
||||||
|
|
||||||
|
float out = 0.0f;
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta;
|
||||||
|
out += state[i] * static_cast<float>(q_[s_idx]);
|
||||||
|
}
|
||||||
|
// Warp reduction
|
||||||
|
for (int offset = 16; offset > 0; offset >>= 1)
|
||||||
|
out += __shfl_down_sync(0xffffffff, out, offset);
|
||||||
|
if (tid_x == 0) {
|
||||||
|
y[dv_idx] = static_cast<InT>(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
q_ += Hk * Dk;
|
||||||
|
k_ += Hk * Dk;
|
||||||
|
v_ += Hv * Dv;
|
||||||
|
y += Hv * Dv;
|
||||||
|
g_ += Hv;
|
||||||
|
beta_ += Hv;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
o_state[s_idx] = static_cast<InT>(state[i]);
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||||
vec := C.mlx_vector_string_new()
|
vec := C.mlx_vector_string_new()
|
||||||
ok := true
|
ok := true
|
||||||
@@ -352,11 +436,184 @@ func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
|||||||
return Concatenate(outs, 1), nextState
|
return Concatenate(outs, 1), nextState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initGatedDeltaCUDAKernel() {
|
||||||
|
var cudaAvail C.bool
|
||||||
|
if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||||
|
if !ok {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
freeInputs()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer freeInputs()
|
||||||
|
|
||||||
|
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||||
|
if !ok {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
freeOutputs()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer freeOutputs()
|
||||||
|
|
||||||
|
cName := C.CString("gated_delta_step")
|
||||||
|
defer C.free(unsafe.Pointer(cName))
|
||||||
|
cSource := C.CString(gatedDeltaCUDAKernelSource)
|
||||||
|
defer C.free(unsafe.Pointer(cSource))
|
||||||
|
cHeader := C.CString("")
|
||||||
|
defer C.free(unsafe.Pointer(cHeader))
|
||||||
|
|
||||||
|
gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new(
|
||||||
|
cName,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
cSource,
|
||||||
|
cHeader,
|
||||||
|
C.bool(true),
|
||||||
|
C.int(0),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||||
|
if gatedDeltaCUDADisabled {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
qd := q.Dims()
|
||||||
|
kd := k.Dims()
|
||||||
|
vd := v.Dims()
|
||||||
|
gd := g.Dims()
|
||||||
|
bd := beta.Dims()
|
||||||
|
sd := state.Dims()
|
||||||
|
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||||
|
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
Hv, Dv := vd[2], vd[3]
|
||||||
|
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
dtype := q.DType()
|
||||||
|
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
|
||||||
|
if gatedDeltaCUDADisabled {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := C.mlx_fast_cuda_kernel_config_new()
|
||||||
|
defer C.mlx_fast_cuda_kernel_config_free(cfg)
|
||||||
|
|
||||||
|
cInT := C.CString("InT")
|
||||||
|
defer C.free(unsafe.Pointer(cInT))
|
||||||
|
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
for _, tpl := range []struct {
|
||||||
|
name string
|
||||||
|
value int
|
||||||
|
}{
|
||||||
|
{name: "Dk", value: Dk},
|
||||||
|
{name: "Dv", value: Dv},
|
||||||
|
{name: "Hk", value: Hk},
|
||||||
|
{name: "Hv", value: Hv},
|
||||||
|
} {
|
||||||
|
cn := C.CString(tpl.name)
|
||||||
|
rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||||
|
C.free(unsafe.Pointer(cn))
|
||||||
|
if rc != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||||
|
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
threadY := Dv
|
||||||
|
if threadY > 4 {
|
||||||
|
threadY = 4
|
||||||
|
}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
tScalar := FromValue(T)
|
||||||
|
inputs := []C.mlx_array{
|
||||||
|
q.ctx,
|
||||||
|
k.ctx,
|
||||||
|
v.ctx,
|
||||||
|
g.ctx,
|
||||||
|
beta.ctx,
|
||||||
|
state.ctx,
|
||||||
|
tScalar.ctx,
|
||||||
|
}
|
||||||
|
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||||
|
defer C.mlx_vector_array_free(inVec)
|
||||||
|
|
||||||
|
outVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(outVec)
|
||||||
|
if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
y = New("GATED_DELTA_CUDA_Y")
|
||||||
|
nextState = New("GATED_DELTA_CUDA_STATE")
|
||||||
|
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||||
|
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||||
|
return y, nextState, true
|
||||||
|
}
|
||||||
|
|
||||||
// GatedDelta runs the recurrent update operation.
|
// GatedDelta runs the recurrent update operation.
|
||||||
//
|
//
|
||||||
// It uses the fused Metal kernel when available and otherwise falls back to a
|
// It tries the fused CUDA kernel first, then Metal, then falls back to a
|
||||||
// backend-agnostic MLX implementation with identical inputs/outputs.
|
// backend-agnostic MLX implementation with identical inputs/outputs.
|
||||||
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
||||||
|
if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok {
|
||||||
|
return y, nextState
|
||||||
|
}
|
||||||
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
|
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
|
||||||
return y, nextState
|
return y, nextState
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -326,8 +326,10 @@ int (*mlx_distributed_sum_scatter_)(
|
|||||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
mlx_distributed_group (*mlx_distributed_init_)(
|
||||||
|
bool strict,
|
||||||
|
const char* bk /* may be null */) = NULL;
|
||||||
void (*mlx_set_error_handler_)(
|
void (*mlx_set_error_handler_)(
|
||||||
mlx_error_handler_func handler,
|
mlx_error_handler_func handler,
|
||||||
void* data,
|
void* data,
|
||||||
@@ -924,6 +926,7 @@ int (*mlx_astype_)(
|
|||||||
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_and_)(
|
int (*mlx_bitwise_and_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -940,6 +943,7 @@ int (*mlx_bitwise_xor_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
const mlx_array b,
|
const mlx_array b,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_block_masked_mm_)(
|
int (*mlx_block_masked_mm_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -1120,6 +1124,7 @@ int (*mlx_dequantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
||||||
@@ -1256,6 +1261,8 @@ int (*mlx_hadamard_transform_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
mlx_optional_float scale,
|
mlx_optional_float scale,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
||||||
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_inner_)(
|
int (*mlx_inner_)(
|
||||||
@@ -1548,6 +1555,8 @@ int (*mlx_qqmm_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantize_)(
|
int (*mlx_quantize_)(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -1555,6 +1564,7 @@ int (*mlx_quantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantized_matmul_)(
|
int (*mlx_quantized_matmul_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
@@ -2550,10 +2560,12 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
|||||||
CHECK_LOAD(handle, mlx_atleast_1d);
|
CHECK_LOAD(handle, mlx_atleast_1d);
|
||||||
CHECK_LOAD(handle, mlx_atleast_2d);
|
CHECK_LOAD(handle, mlx_atleast_2d);
|
||||||
CHECK_LOAD(handle, mlx_atleast_3d);
|
CHECK_LOAD(handle, mlx_atleast_3d);
|
||||||
|
CHECK_LOAD(handle, mlx_bartlett);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_and);
|
CHECK_LOAD(handle, mlx_bitwise_and);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_invert);
|
CHECK_LOAD(handle, mlx_bitwise_invert);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_or);
|
CHECK_LOAD(handle, mlx_bitwise_or);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_xor);
|
CHECK_LOAD(handle, mlx_bitwise_xor);
|
||||||
|
CHECK_LOAD(handle, mlx_blackman);
|
||||||
CHECK_LOAD(handle, mlx_block_masked_mm);
|
CHECK_LOAD(handle, mlx_block_masked_mm);
|
||||||
CHECK_LOAD(handle, mlx_broadcast_arrays);
|
CHECK_LOAD(handle, mlx_broadcast_arrays);
|
||||||
CHECK_LOAD(handle, mlx_broadcast_to);
|
CHECK_LOAD(handle, mlx_broadcast_to);
|
||||||
@@ -2606,6 +2618,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
|||||||
CHECK_LOAD(handle, mlx_greater);
|
CHECK_LOAD(handle, mlx_greater);
|
||||||
CHECK_LOAD(handle, mlx_greater_equal);
|
CHECK_LOAD(handle, mlx_greater_equal);
|
||||||
CHECK_LOAD(handle, mlx_hadamard_transform);
|
CHECK_LOAD(handle, mlx_hadamard_transform);
|
||||||
|
CHECK_LOAD(handle, mlx_hamming);
|
||||||
|
CHECK_LOAD(handle, mlx_hanning);
|
||||||
CHECK_LOAD(handle, mlx_identity);
|
CHECK_LOAD(handle, mlx_identity);
|
||||||
CHECK_LOAD(handle, mlx_imag);
|
CHECK_LOAD(handle, mlx_imag);
|
||||||
CHECK_LOAD(handle, mlx_inner);
|
CHECK_LOAD(handle, mlx_inner);
|
||||||
|
|||||||
@@ -300,10 +300,12 @@
|
|||||||
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
|
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
|
||||||
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
|
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
|
||||||
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
|
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
|
||||||
|
#define mlx_bartlett mlx_bartlett_mlx_gen_orig_
|
||||||
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
|
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
|
||||||
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
|
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
|
||||||
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
|
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
|
||||||
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
|
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
|
||||||
|
#define mlx_blackman mlx_blackman_mlx_gen_orig_
|
||||||
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
|
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
|
||||||
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
|
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
|
||||||
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
|
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
|
||||||
@@ -356,6 +358,8 @@
|
|||||||
#define mlx_greater mlx_greater_mlx_gen_orig_
|
#define mlx_greater mlx_greater_mlx_gen_orig_
|
||||||
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
|
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
|
||||||
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
|
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
|
||||||
|
#define mlx_hamming mlx_hamming_mlx_gen_orig_
|
||||||
|
#define mlx_hanning mlx_hanning_mlx_gen_orig_
|
||||||
#define mlx_identity mlx_identity_mlx_gen_orig_
|
#define mlx_identity mlx_identity_mlx_gen_orig_
|
||||||
#define mlx_imag mlx_imag_mlx_gen_orig_
|
#define mlx_imag mlx_imag_mlx_gen_orig_
|
||||||
#define mlx_inner mlx_inner_mlx_gen_orig_
|
#define mlx_inner mlx_inner_mlx_gen_orig_
|
||||||
@@ -889,10 +893,12 @@
|
|||||||
#undef mlx_atleast_1d
|
#undef mlx_atleast_1d
|
||||||
#undef mlx_atleast_2d
|
#undef mlx_atleast_2d
|
||||||
#undef mlx_atleast_3d
|
#undef mlx_atleast_3d
|
||||||
|
#undef mlx_bartlett
|
||||||
#undef mlx_bitwise_and
|
#undef mlx_bitwise_and
|
||||||
#undef mlx_bitwise_invert
|
#undef mlx_bitwise_invert
|
||||||
#undef mlx_bitwise_or
|
#undef mlx_bitwise_or
|
||||||
#undef mlx_bitwise_xor
|
#undef mlx_bitwise_xor
|
||||||
|
#undef mlx_blackman
|
||||||
#undef mlx_block_masked_mm
|
#undef mlx_block_masked_mm
|
||||||
#undef mlx_broadcast_arrays
|
#undef mlx_broadcast_arrays
|
||||||
#undef mlx_broadcast_to
|
#undef mlx_broadcast_to
|
||||||
@@ -945,6 +951,8 @@
|
|||||||
#undef mlx_greater
|
#undef mlx_greater
|
||||||
#undef mlx_greater_equal
|
#undef mlx_greater_equal
|
||||||
#undef mlx_hadamard_transform
|
#undef mlx_hadamard_transform
|
||||||
|
#undef mlx_hamming
|
||||||
|
#undef mlx_hanning
|
||||||
#undef mlx_identity
|
#undef mlx_identity
|
||||||
#undef mlx_imag
|
#undef mlx_imag
|
||||||
#undef mlx_inner
|
#undef mlx_inner
|
||||||
@@ -1501,8 +1509,10 @@ extern int (*mlx_distributed_sum_scatter_)(
|
|||||||
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
|
||||||
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
|
||||||
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
|
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
|
||||||
extern bool (*mlx_distributed_is_available_)(void);
|
extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */);
|
||||||
extern mlx_distributed_group (*mlx_distributed_init_)(bool strict);
|
extern mlx_distributed_group (*mlx_distributed_init_)(
|
||||||
|
bool strict,
|
||||||
|
const char* bk /* may be null */);
|
||||||
extern void (*mlx_set_error_handler_)(
|
extern void (*mlx_set_error_handler_)(
|
||||||
mlx_error_handler_func handler,
|
mlx_error_handler_func handler,
|
||||||
void* data,
|
void* data,
|
||||||
@@ -2099,6 +2109,7 @@ extern int (*mlx_astype_)(
|
|||||||
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
extern int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_and_)(
|
extern int (*mlx_bitwise_and_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -2115,6 +2126,7 @@ extern int (*mlx_bitwise_xor_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
const mlx_array b,
|
const mlx_array b,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
extern int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_block_masked_mm_)(
|
extern int (*mlx_block_masked_mm_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -2295,6 +2307,7 @@ extern int (*mlx_dequantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
@@ -2431,6 +2444,8 @@ extern int (*mlx_hadamard_transform_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
mlx_optional_float scale,
|
mlx_optional_float scale,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
extern int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
extern int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_inner_)(
|
extern int (*mlx_inner_)(
|
||||||
@@ -2723,6 +2738,8 @@ extern int (*mlx_qqmm_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
extern int (*mlx_quantize_)(
|
extern int (*mlx_quantize_)(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -2730,6 +2747,7 @@ extern int (*mlx_quantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
extern int (*mlx_quantized_matmul_)(
|
extern int (*mlx_quantized_matmul_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
@@ -4033,11 +4051,13 @@ static inline int mlx_distributed_group_size(mlx_distributed_group group) {
|
|||||||
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
|
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
|
||||||
return mlx_distributed_group_split_(group, color, key);
|
return mlx_distributed_group_split_(group, color, key);
|
||||||
}
|
}
|
||||||
static inline bool mlx_distributed_is_available(void) {
|
static inline bool mlx_distributed_is_available(const char* bk /* may be null */) {
|
||||||
return mlx_distributed_is_available_();
|
return mlx_distributed_is_available_(bk);
|
||||||
}
|
}
|
||||||
static inline mlx_distributed_group mlx_distributed_init(bool strict) {
|
static inline mlx_distributed_group mlx_distributed_init(
|
||||||
return mlx_distributed_init_(strict);
|
bool strict,
|
||||||
|
const char* bk /* may be null */) {
|
||||||
|
return mlx_distributed_init_(strict, bk);
|
||||||
}
|
}
|
||||||
static inline void mlx_set_error_handler(
|
static inline void mlx_set_error_handler(
|
||||||
mlx_error_handler_func handler,
|
mlx_error_handler_func handler,
|
||||||
@@ -4939,6 +4959,9 @@ static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_st
|
|||||||
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
||||||
return mlx_atleast_3d_(res, a, s);
|
return mlx_atleast_3d_(res, a, s);
|
||||||
}
|
}
|
||||||
|
/* Forwards to the mlx_bartlett_ function pointer, which mlx_dynamic_load_symbols resolves via CHECK_LOAD; returns that call's int result unchanged. */
static inline int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_bartlett_(res, M, s);
|
||||||
|
}
|
||||||
static inline int mlx_bitwise_and(
|
static inline int mlx_bitwise_and(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -4963,6 +4986,9 @@ static inline int mlx_bitwise_xor(
|
|||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_bitwise_xor_(res, a, b, s);
|
return mlx_bitwise_xor_(res, a, b, s);
|
||||||
}
|
}
|
||||||
|
/* Forwards to the mlx_blackman_ function pointer, which mlx_dynamic_load_symbols resolves via CHECK_LOAD; returns that call's int result unchanged. */
static inline int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_blackman_(res, M, s);
|
||||||
|
}
|
||||||
static inline int mlx_block_masked_mm(
|
static inline int mlx_block_masked_mm(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -5193,9 +5219,10 @@ static inline int mlx_dequantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s);
|
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
|
||||||
}
|
}
|
||||||
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
||||||
return mlx_diag_(res, a, k, s);
|
return mlx_diag_(res, a, k, s);
|
||||||
@@ -5383,6 +5410,12 @@ static inline int mlx_hadamard_transform(
|
|||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_hadamard_transform_(res, a, scale, s);
|
return mlx_hadamard_transform_(res, a, scale, s);
|
||||||
}
|
}
|
||||||
|
/* Forwards to the mlx_hamming_ function pointer, which mlx_dynamic_load_symbols resolves via CHECK_LOAD; returns that call's int result unchanged. */
static inline int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hamming_(res, M, s);
|
||||||
|
}
|
||||||
|
/* Forwards to the mlx_hanning_ function pointer, which mlx_dynamic_load_symbols resolves via CHECK_LOAD; returns that call's int result unchanged. */
static inline int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hanning_(res, M, s);
|
||||||
|
}
|
||||||
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
||||||
return mlx_identity_(res, n, dtype, s);
|
return mlx_identity_(res, n, dtype, s);
|
||||||
}
|
}
|
||||||
@@ -5793,8 +5826,10 @@ static inline int mlx_qqmm(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s);
|
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
|
||||||
}
|
}
|
||||||
static inline int mlx_quantize(
|
static inline int mlx_quantize(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -5802,8 +5837,9 @@ static inline int mlx_quantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_quantize_(res, w, group_size, bits, mode, s);
|
return mlx_quantize_(res, w, group_size, bits, mode, global_scale, s);
|
||||||
}
|
}
|
||||||
static inline int mlx_quantized_matmul(
|
static inline int mlx_quantized_matmul(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Vendored MLX-C Headers
|
# Vendored MLX-C Headers
|
||||||
|
|
||||||
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
|
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
|
||||||
The pinned version is in `MLX_VERSION` at the repo root.
|
The pinned version is in `MLX_C_VERSION` at the repo root.
|
||||||
|
|
||||||
Headers are automatically refreshed when you run a CMake build:
|
Headers are automatically refreshed when you run a CMake build:
|
||||||
|
|
||||||
|
|||||||
@@ -42,12 +42,14 @@ mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
|||||||
/**
|
/**
|
||||||
* Check if distributed is available.
|
* Check if distributed is available.
|
||||||
*/
|
*/
|
||||||
bool mlx_distributed_is_available(void);
|
bool mlx_distributed_is_available(const char* bk /* may be null */);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize distributed.
|
* Initialize distributed.
|
||||||
*/
|
*/
|
||||||
mlx_distributed_group mlx_distributed_init(bool strict);
|
mlx_distributed_group mlx_distributed_init(
|
||||||
|
bool strict,
|
||||||
|
const char* bk /* may be null */);
|
||||||
|
|
||||||
/**@}*/
|
/**@}*/
|
||||||
|
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ int mlx_astype(
|
|||||||
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
|
||||||
int mlx_bitwise_and(
|
int mlx_bitwise_and(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -182,6 +183,7 @@ int mlx_bitwise_xor(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
const mlx_array b,
|
const mlx_array b,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
|
||||||
int mlx_block_masked_mm(
|
int mlx_block_masked_mm(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -362,6 +364,7 @@ int mlx_dequantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
@@ -498,6 +501,8 @@ int mlx_hadamard_transform(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
mlx_optional_float scale,
|
mlx_optional_float scale,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
|
||||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
int mlx_inner(
|
int mlx_inner(
|
||||||
@@ -790,6 +795,8 @@ int mlx_qqmm(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
int mlx_quantize(
|
int mlx_quantize(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -797,6 +804,7 @@ int mlx_quantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
int mlx_quantized_matmul(
|
int mlx_quantized_matmul(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
|
|||||||
@@ -4,35 +4,91 @@ package mlx
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"iter"
|
"iter"
|
||||||
"runtime"
|
"runtime"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Load(path string) iter.Seq2[string, *Array] {
|
// SafetensorsFile represents a loaded safetensors file.
|
||||||
return func(yield func(string, *Array) bool) {
|
type SafetensorsFile struct {
|
||||||
string2array := C.mlx_map_string_to_array_new()
|
arrays C.mlx_map_string_to_array
|
||||||
defer C.mlx_map_string_to_array_free(string2array)
|
metadata C.mlx_map_string_to_string
|
||||||
|
}
|
||||||
|
|
||||||
string2string := C.mlx_map_string_to_string_new()
|
func loadSafetensorsStream() C.mlx_stream {
|
||||||
defer C.mlx_map_string_to_string_free(string2string)
|
if runtime.GOOS == "darwin" {
|
||||||
|
return C.mlx_default_cpu_stream_new()
|
||||||
|
}
|
||||||
|
return C.mlx_default_gpu_stream_new()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSafetensorsNative loads a safetensors file using MLX's native loader.
|
||||||
|
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
|
||||||
|
var arrays C.mlx_map_string_to_array
|
||||||
|
var metadata C.mlx_map_string_to_string
|
||||||
|
|
||||||
cPath := C.CString(path)
|
cPath := C.CString(path)
|
||||||
defer C.free(unsafe.Pointer(cPath))
|
defer C.free(unsafe.Pointer(cPath))
|
||||||
|
|
||||||
// Use GPU stream so tensors load directly to GPU memory (CUDA has Load::eval_gpu).
|
stream := loadSafetensorsStream()
|
||||||
// macOS Metal doesn't implement eval_gpu for Load, so fall back to CPU stream.
|
|
||||||
var stream C.mlx_stream
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
stream = C.mlx_default_cpu_stream_new()
|
|
||||||
} else {
|
|
||||||
stream = C.mlx_default_gpu_stream_new()
|
|
||||||
}
|
|
||||||
defer C.mlx_stream_free(stream)
|
defer C.mlx_stream_free(stream)
|
||||||
|
|
||||||
C.mlx_load_safetensors(&string2array, &string2string, cPath, stream)
|
if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
|
||||||
|
return nil, fmt.Errorf("failed to load safetensors: %s", path)
|
||||||
|
}
|
||||||
|
|
||||||
it := C.mlx_map_string_to_array_iterator_new(string2array)
|
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a tensor by name.
|
||||||
|
func (s *SafetensorsFile) Get(name string) *Array {
|
||||||
|
cName := C.CString(name)
|
||||||
|
defer C.free(unsafe.Pointer(cName))
|
||||||
|
|
||||||
|
value := C.mlx_array_new()
|
||||||
|
if C.mlx_map_string_to_array_get(&value, s.arrays, cName) != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if value.ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
arr := New(name)
|
||||||
|
arr.ctx = value
|
||||||
|
return arr
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetadata retrieves a metadata value by key.
|
||||||
|
func (s *SafetensorsFile) GetMetadata(key string) string {
|
||||||
|
cKey := C.CString(key)
|
||||||
|
defer C.free(unsafe.Pointer(cKey))
|
||||||
|
|
||||||
|
var cValue *C.char
|
||||||
|
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return C.GoString(cValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free releases the loaded safetensors maps.
|
||||||
|
func (s *SafetensorsFile) Free() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
C.mlx_map_string_to_array_free(s.arrays)
|
||||||
|
C.mlx_map_string_to_string_free(s.metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Load(path string) iter.Seq2[string, *Array] {
|
||||||
|
return func(yield func(string, *Array) bool) {
|
||||||
|
sf, err := LoadSafetensorsNative(path)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer sf.Free()
|
||||||
|
|
||||||
|
it := C.mlx_map_string_to_array_iterator_new(sf.arrays)
|
||||||
defer C.mlx_map_string_to_array_iterator_free(it)
|
defer C.mlx_map_string_to_array_iterator_free(it)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -51,3 +107,43 @@ func Load(path string) iter.Seq2[string, *Array] {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveSafetensors saves arrays to a safetensors file without metadata.
|
||||||
|
func SaveSafetensors(path string, arrays map[string]*Array) error {
|
||||||
|
return SaveSafetensorsWithMetadata(path, arrays, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata.
|
||||||
|
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
|
||||||
|
cPath := C.CString(path)
|
||||||
|
defer C.free(unsafe.Pointer(cPath))
|
||||||
|
|
||||||
|
cArrays := C.mlx_map_string_to_array_new()
|
||||||
|
defer C.mlx_map_string_to_array_free(cArrays)
|
||||||
|
|
||||||
|
for name, arr := range arrays {
|
||||||
|
if arr == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cName := C.CString(name)
|
||||||
|
C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
|
||||||
|
C.free(unsafe.Pointer(cName))
|
||||||
|
}
|
||||||
|
|
||||||
|
cMetadata := C.mlx_map_string_to_string_new()
|
||||||
|
defer C.mlx_map_string_to_string_free(cMetadata)
|
||||||
|
|
||||||
|
for key, value := range metadata {
|
||||||
|
cKey := C.CString(key)
|
||||||
|
cValue := C.CString(value)
|
||||||
|
C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
|
||||||
|
C.free(unsafe.Pointer(cKey))
|
||||||
|
C.free(unsafe.Pointer(cValue))
|
||||||
|
}
|
||||||
|
|
||||||
|
if C.mlx_save_safetensors(cPath, cArrays, cMetadata) != 0 {
|
||||||
|
return fmt.Errorf("failed to save safetensors: %s", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,8 +7,44 @@ package mlx
|
|||||||
// #cgo LDFLAGS: -lstdc++
|
// #cgo LDFLAGS: -lstdc++
|
||||||
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
||||||
// #include "generated.h"
|
// #include "generated.h"
|
||||||
|
// #include <string.h>
|
||||||
|
//
|
||||||
|
// static char _mlx_last_error_msg[1024] = {0};
|
||||||
|
// static int _mlx_last_error_flag = 0;
|
||||||
|
//
|
||||||
|
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
||||||
|
// (void)data;
|
||||||
|
// strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1);
|
||||||
|
// _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0';
|
||||||
|
// _mlx_last_error_flag = 1;
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static void mlx_install_capture_handler(void) {
|
||||||
|
// if (mlx_set_error_handler_) {
|
||||||
|
// mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static void mlx_clear_last_error(void) {
|
||||||
|
// _mlx_last_error_flag = 0;
|
||||||
|
// _mlx_last_error_msg[0] = '\0';
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static int mlx_had_last_error(void) {
|
||||||
|
// return _mlx_last_error_flag;
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static const char* mlx_get_last_error(void) {
|
||||||
|
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
|
||||||
|
// }
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Replace the default exit(-1) error handler with one that captures
|
||||||
|
// the error message so we can surface it in Go.
|
||||||
|
C.mlx_install_capture_handler()
|
||||||
|
}
|
||||||
|
|
||||||
// Version returns the MLX core library version string.
|
// Version returns the MLX core library version string.
|
||||||
func Version() string {
|
func Version() string {
|
||||||
str := C.mlx_string_new()
|
str := C.mlx_string_new()
|
||||||
@@ -31,10 +67,19 @@ func doEval(outputs []*Array, async bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
C.mlx_clear_last_error()
|
||||||
|
var rc C.int
|
||||||
if async {
|
if async {
|
||||||
C.mlx_async_eval(vector)
|
rc = C.mlx_async_eval(vector)
|
||||||
} else {
|
} else {
|
||||||
C.mlx_eval(vector)
|
rc = C.mlx_eval(vector)
|
||||||
|
}
|
||||||
|
if rc != 0 {
|
||||||
|
msg := "mlx eval failed"
|
||||||
|
if C.mlx_had_last_error() != 0 {
|
||||||
|
msg = C.GoString(C.mlx_get_last_error())
|
||||||
|
}
|
||||||
|
panic("mlx: " + msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
|||||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||||
res := C.mlx_vector_array_new()
|
res := C.mlx_vector_array_new()
|
||||||
defer C.mlx_vector_array_free(res)
|
defer C.mlx_vector_array_free(res)
|
||||||
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
|
||||||
|
|
||||||
vecSize := int(C.mlx_vector_array_size(res))
|
vecSize := int(C.mlx_vector_array_size(res))
|
||||||
w0 := New("QUANTIZE_W")
|
w0 := New("QUANTIZE_W")
|
||||||
@@ -32,6 +33,18 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
|||||||
return w0, w1, nil
|
return w0, w1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func FromFP8(x *Array, dtype DType) *Array {
|
||||||
|
out := New("FROM_FP8")
|
||||||
|
C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToFP8(x *Array) *Array {
|
||||||
|
out := New("TO_FP8")
|
||||||
|
C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
|
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
|
||||||
cMode := C.CString(mode)
|
cMode := C.CString(mode)
|
||||||
defer C.free(unsafe.Pointer(cMode))
|
defer C.free(unsafe.Pointer(cMode))
|
||||||
@@ -45,7 +58,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := New("DEQUANTIZE")
|
out := New("DEQUANTIZE")
|
||||||
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,6 +149,40 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Pad(a *Array, paddings []int32) *Array {
|
||||||
|
numAxes := len(paddings) / 2
|
||||||
|
axes := make([]C.int, numAxes)
|
||||||
|
lowPad := make([]C.int, numAxes)
|
||||||
|
highPad := make([]C.int, numAxes)
|
||||||
|
for i := range numAxes {
|
||||||
|
axes[i] = C.int(i)
|
||||||
|
lowPad[i] = C.int(paddings[i*2])
|
||||||
|
highPad[i] = C.int(paddings[i*2+1])
|
||||||
|
}
|
||||||
|
|
||||||
|
padValue := C.mlx_array_new_float(C.float(0))
|
||||||
|
defer C.mlx_array_free(padValue)
|
||||||
|
|
||||||
|
cMode := C.CString("constant")
|
||||||
|
defer C.free(unsafe.Pointer(cMode))
|
||||||
|
|
||||||
|
out := New("PAD")
|
||||||
|
C.mlx_pad(
|
||||||
|
&out.ctx,
|
||||||
|
a.ctx,
|
||||||
|
unsafe.SliceData(axes),
|
||||||
|
C.size_t(len(axes)),
|
||||||
|
unsafe.SliceData(lowPad),
|
||||||
|
C.size_t(len(lowPad)),
|
||||||
|
unsafe.SliceData(highPad),
|
||||||
|
C.size_t(len(highPad)),
|
||||||
|
padValue,
|
||||||
|
cMode,
|
||||||
|
DefaultStream().ctx,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||||
groups := int32(x.Dim(x.NumDims() - 1))
|
groups := int32(x.Dim(x.NumDims() - 1))
|
||||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||||
|
|||||||
@@ -11,8 +11,10 @@ func QuantizationParams(quantization string) (groupSize, bits int, mode string)
|
|||||||
switch strings.ToUpper(quantization) {
|
switch strings.ToUpper(quantization) {
|
||||||
case "NVFP4":
|
case "NVFP4":
|
||||||
return 16, 4, "nvfp4"
|
return 16, 4, "nvfp4"
|
||||||
|
case "MXFP4":
|
||||||
|
return 32, 4, "mxfp4"
|
||||||
case "FP4", "Q4", "INT4":
|
case "FP4", "Q4", "INT4":
|
||||||
return 32, 4, "affine"
|
return 64, 4, "affine"
|
||||||
case "MXFP8":
|
case "MXFP8":
|
||||||
return 32, 8, "mxfp8"
|
return 32, 8, "mxfp8"
|
||||||
case "FP8", "Q8", "INT8":
|
case "FP8", "Q8", "INT8":
|
||||||
|
|||||||
@@ -144,3 +144,44 @@ func TestLayerNormDefaultEps(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
weightVals := make([]float32, 3*32)
|
||||||
|
for i := range weightVals {
|
||||||
|
weightVals[i] = float32((i%11)-5) / 7
|
||||||
|
}
|
||||||
|
inputVals := make([]float32, 2*32)
|
||||||
|
for i := range inputVals {
|
||||||
|
inputVals[i] = float32((i%7)-3) / 5
|
||||||
|
}
|
||||||
|
|
||||||
|
weight := mlx.FromValues(weightVals, 3, 32).AsType(mlx.DTypeBFloat16)
|
||||||
|
input := mlx.FromValues(inputVals, 2, 32).AsType(mlx.DTypeBFloat16)
|
||||||
|
mlx.Eval(weight, input)
|
||||||
|
|
||||||
|
ql := NewQuantizedLinear(weight, nil, 32, 4, "mxfp4")
|
||||||
|
if ql.QBiases != nil {
|
||||||
|
t.Fatalf("mxfp4 qbiases = %v, want nil", ql.QBiases)
|
||||||
|
}
|
||||||
|
|
||||||
|
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
|
||||||
|
mlx.Eval(dequantizedWeight)
|
||||||
|
|
||||||
|
qOut := ql.Forward(input)
|
||||||
|
dOut := NewLinear(dequantizedWeight, nil).Forward(input)
|
||||||
|
mlx.Eval(qOut, dOut)
|
||||||
|
|
||||||
|
got := qOut.Floats()
|
||||||
|
want := dOut.Floats()
|
||||||
|
if len(got) != len(want) {
|
||||||
|
t.Fatalf("output length = %d, want %d", len(got), len(want))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range got {
|
||||||
|
if !approxEqual(got[i], want[i], 1e-3) {
|
||||||
|
t.Fatalf("output[%d] = %.6f, want %.6f", i, got[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -420,7 +420,16 @@ func tensorByBase(tensors map[string]*mlx.Array, base string) (*mlx.Array, strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func supportsGatherQMM(mode string, bits int) bool {
|
func supportsGatherQMM(mode string, bits int) bool {
|
||||||
return mode == "affine" && (bits == 4 || bits == 8)
|
switch mode {
|
||||||
|
case "affine":
|
||||||
|
return bits == 4 || bits == 8
|
||||||
|
case "mxfp8":
|
||||||
|
return bits == 8
|
||||||
|
case "nvfp4", "mxfp4":
|
||||||
|
return bits == 4
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func freeTensorKeys(tensors map[string]*mlx.Array, keys ...string) {
|
func freeTensorKeys(tensors map[string]*mlx.Array, keys ...string) {
|
||||||
|
|||||||
@@ -83,6 +83,28 @@ func TestLayerSelectionHelpers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSupportsGatherQMM(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
mode string
|
||||||
|
bits int
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{mode: "affine", bits: 4, want: true},
|
||||||
|
{mode: "affine", bits: 8, want: true},
|
||||||
|
{mode: "mxfp8", bits: 8, want: true},
|
||||||
|
{mode: "nvfp4", bits: 4, want: true},
|
||||||
|
{mode: "mxfp4", bits: 4, want: true},
|
||||||
|
{mode: "mxfp8", bits: 4, want: false},
|
||||||
|
{mode: "affine", bits: 3, want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := supportsGatherQMM(tt.mode, tt.bits); got != tt.want {
|
||||||
|
t.Fatalf("supportsGatherQMM(%q, %d) = %v, want %v", tt.mode, tt.bits, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestResolveTensorPathLayout(t *testing.T) {
|
func TestResolveTensorPathLayout(t *testing.T) {
|
||||||
dummy := mlx.New("dummy")
|
dummy := mlx.New("dummy")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user