Files
ollama/x/mlxrunner/mlx/ops_extra.go
Jesse Gross 5daf59cc66 mlxrunner: Fix memory leaks with pin/sweep lifecycle management
The previous approach tracked array lifecycles through reference
counting, where each array recorded its inputs and a reference count
that was decremented as dependents were freed. This is not really
necessary as MLX tracks references internally. It is also error
prone as it is easy to create new arrays and forget to free them
when the Go variable goes out of scope.

Instead, we can pin just the arrays we want (typically outputs and
specific intermediates, like the cache). All other arrays are freed
by default when we run sweep. This avoids most causes of memory leaks
while still giving the freedom to save what we want.
2026-02-23 09:50:07 -08:00

443 lines
10 KiB
Go

//go:build mlx
package mlx
// #include "generated.h"
import "C"
import (
"reflect"
"unsafe"
)
// Quantization operations
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
res := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(res)
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)
vecSize := int(C.mlx_vector_array_size(res))
w0 := New("QUANTIZE_W")
C.mlx_vector_array_get(&w0.ctx, res, 0)
w1 := New("QUANTIZE_S")
C.mlx_vector_array_get(&w1.ctx, res, 1)
if vecSize >= 3 {
w2 := New("QUANTIZE_B")
C.mlx_vector_array_get(&w2.ctx, res, 2)
return w0, w1, w2
}
return w0, w1, nil
}
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
optDtype := C.mlx_optional_dtype{has_value: false}
var b C.mlx_array
if biases != nil {
b = biases.ctx
}
out := New("DEQUANTIZE")
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
return out
}
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
var b C.mlx_array
if biases != nil {
b = biases.ctx
}
out := New("QUANTIZED_MATMUL")
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
return out
}
func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
var b, lhs, rhs C.mlx_array
if biases != nil {
b = biases.ctx
}
if lhsIndices != nil {
lhs = lhsIndices.ctx
}
if rhsIndices != nil {
rhs = rhsIndices.ctx
}
out := New("GATHER_QMM")
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
return out
}
// Missing tensor ops
func Tile(a *Array, reps []int32) *Array {
cReps := make([]C.int, len(reps))
for i, r := range reps {
cReps[i] = C.int(r)
}
out := New("TILE")
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
return out
}
func Tri(n, m int32, k int) *Array {
out := New("TRI")
C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx)
return out
}
func Where(condition, a, b *Array) *Array {
out := New("WHERE")
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array {
vectorData := make([]C.mlx_array, len(arrays))
for i := range arrays {
vectorData[i] = arrays[i].ctx
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func Neg(a *Array) *Array {
return a.Negative()
}
func Sum(a *Array, axis int, keepDims bool) *Array {
return a.SumAxis(axis, keepDims)
}
func Argsort(a *Array, axis int) *Array {
return a.ArgsortAxis(axis)
}
func Take(a *Array, indices *Array, axis int) *Array {
return a.TakeAxis(indices, axis)
}
func RSqrt(a *Array) *Array {
out := New("RSQRT")
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Mean(a *Array, axis int, keepDims bool) *Array {
out := New("MEAN_AXIS")
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func Argpartition(a *Array, kth int, axis int) *Array {
return a.ArgpartitionAxis(kth, axis)
}
func TakeAlongAxis(a, indices *Array, axis int) *Array {
return a.TakeAlongAxis(indices, axis)
}
// Function-style wrappers matching imagegen API
func Add(a, b *Array) *Array {
return a.Add(b)
}
func Sub(a, b *Array) *Array {
return a.Subtract(b)
}
func Mul(a, b *Array) *Array {
return a.Multiply(b)
}
func Div(a, b *Array) *Array {
return a.Divide(b)
}
func Matmul(a, b *Array) *Array {
return a.Matmul(b)
}
func Reshape(a *Array, shape ...int32) *Array {
axes := make([]int, len(shape))
for i, s := range shape {
axes[i] = int(s)
}
return a.Reshape(axes...)
}
func Transpose(a *Array, axes ...int) *Array {
return a.Transpose(axes...)
}
func ExpandDims(a *Array, axis int) *Array {
return a.ExpandDims(axis)
}
func Squeeze(a *Array, axis int) *Array {
return a.Squeeze(axis)
}
func Flatten(a *Array) *Array {
return a.Flatten(0, -1)
}
func Concatenate(arrays []*Array, axis int) *Array {
if len(arrays) == 0 {
return nil
}
return arrays[0].Concatenate(axis, arrays[1:]...)
}
func SliceStartStop(a *Array, start, stop []int32) *Array {
n := len(start)
cStart := make([]C.int, n)
cStop := make([]C.int, n)
cStrides := make([]C.int, n)
for i := 0; i < n; i++ {
cStart[i] = C.int(start[i])
cStop[i] = C.int(stop[i])
cStrides[i] = 1
}
out := New("SLICE")
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
return out
}
func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
if lhsIndices == nil {
lhsIndices = New("")
}
if rhsIndices == nil {
rhsIndices = New("")
}
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
}
func SiLU(a *Array) *Array {
sig := a.Sigmoid()
return a.Multiply(sig)
}
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
x.ctx,
C.int(dims),
C.bool(traditional),
C.mlx_optional_float{
value: C.float(base),
has_value: C.bool(func() bool { return base != 0 }()),
},
C.float(scale),
C.int(offset),
freqs.ctx,
DefaultStream().ctx,
)
return out
}
func Sigmoid(a *Array) *Array {
return a.Sigmoid()
}
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
mask := New("")
sinks := New("")
mode := ""
if causalMask {
mode = "causal"
}
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
func AddMM(c, a, b *Array, alpha, beta float32) *Array {
return c.Addmm(a, b, alpha, beta)
}
// Scalar helpers
// scalarWithDtype creates a scalar array matching the dtype of a.
// Matching dtype is important for graph fusion and avoiding implicit casts.
func scalarWithDtype(s float32, a *Array) C.mlx_array {
f32 := C.mlx_array_new_float(C.float(s))
dtype := a.DType()
if dtype == DTypeFloat32 {
return f32
}
casted := C.mlx_array_new()
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
C.mlx_array_free(f32)
return casted
}
func AddScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("ADD_SCALAR")
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func MulScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("MUL_SCALAR")
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func DivScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("DIV_SCALAR")
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func FloorDivideScalar(a *Array, s int32) *Array {
scalar := FromValue(int(s))
return a.FloorDivide(scalar)
}
// Array constructors
func NewArrayInt32(data []int32, shape []int32) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
out := New("NEW_ARRAY_INT32")
out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32))
return out
}
func NewScalarArray(value float32) *Array {
out := New("SCALAR")
out.ctx = C.mlx_array_new_float32(C.float(value))
return out
}
func ZerosF32(shape []int32) *Array {
return Zeros(DTypeFloat32, func() []int {
ints := make([]int, len(shape))
for i, s := range shape {
ints[i] = int(s)
}
return ints
}()...)
}
// Utility
func Collect(v any) []*Array {
var arrays []*Array
seen := make(map[uintptr]bool)
collect(reflect.ValueOf(v), &arrays, seen)
return arrays
}
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() {
return
}
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return
}
ptr := v.Pointer()
if seen[ptr] {
return
}
seen[ptr] = true
if arr, ok := v.Interface().(*Array); ok {
if arr != nil && arr.Valid() {
*arrays = append(*arrays, arr)
}
return
}
collect(v.Elem(), arrays, seen)
return
}
switch v.Kind() {
case reflect.Struct:
// Check if this struct IS an Array (not a pointer to one)
if arr, ok := v.Addr().Interface().(*Array); ok {
if arr != nil && arr.Valid() {
*arrays = append(*arrays, arr)
}
return
}
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.CanInterface() {
collect(field, arrays, seen)
}
}
case reflect.Slice:
for i := 0; i < v.Len(); i++ {
collect(v.Index(i), arrays, seen)
}
case reflect.Map:
for _, key := range v.MapKeys() {
collect(v.MapIndex(key), arrays, seen)
}
case reflect.Interface:
if !v.IsNil() {
collect(v.Elem(), arrays, seen)
}
}
}
func EnableCompile() {
C.mlx_enable_compile()
}
func DisableCompile() {
C.mlx_disable_compile()
}