From c88fb286ec87d3ecdbe59f2ef2316a1e6024b43d Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 13 Apr 2026 11:43:24 -0700 Subject: [PATCH] mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA (#14913) * mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA Add Conv2d, flexible Pad (with axes/mode), PadConstant, Maximum, Minimum, Softplus, ReLU, GLU, Clamp, Sin, Cos, Clip, ScaledDotProductAttentionMasked, and RoPEWithFreqs. Refactor RoPEWithBase to delegate to RoPEWithFreqs. * review comments * mlx: fix ScaledDotProductAttentionMasked to consult the mask argument --- x/create/client/quantize.go | 2 +- x/mlxrunner/mlx/act.go | 61 +++++++++-- x/mlxrunner/mlx/ops_extra.go | 192 ++++++++++++++++++++++++++++++----- 3 files changed, 216 insertions(+), 39 deletions(-) diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go index 893252b7e..d425f72e9 100644 --- a/x/create/client/quantize.go +++ b/x/create/client/quantize.go @@ -592,7 +592,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) { padBottom := blockRows*scaleShape[0] - rows padSide := blockCols*scaleShape[1] - cols if padBottom > 0 || padSide > 0 { - decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)}) + decoded = mlx.PadConstant(decoded, []int{0, 1}, []int{0, 0}, []int{padBottom, padSide}) } decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols)) diff --git a/x/mlxrunner/mlx/act.go b/x/mlxrunner/mlx/act.go index ce0e48eda..50801352f 100644 --- a/x/mlxrunner/mlx/act.go +++ b/x/mlxrunner/mlx/act.go @@ -4,16 +4,57 @@ package mlx import "C" import "math" -func GELUApprox(t *Array) *Array { - return t.Multiply( - FromValue[float32](0.5), - ).Multiply( - t.Add( - t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)), - ).Multiply( - FromValue(float32(math.Sqrt(2 / math.Pi))), - ).Tanh().Add(FromValue[float32](1.0)), - ).AsType(t.DType()) +var geluCoeff = float32(math.Sqrt(2 / math.Pi)) + +// GELUApprox matches mlx.nn.gelu_approx: +// +// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) +func GELUApprox(x *Array) *Array { + // Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs. + half := scalarWithDtype(0.5, x) + defer C.mlx_array_free(half) + coeff := scalarWithDtype(geluCoeff, x) + defer C.mlx_array_free(coeff) + c := scalarWithDtype(0.044715, x) + defer C.mlx_array_free(c) + + // x^3 via x*x*x (avoids general Power which is slower) + x3 := New("GELU_X3") + C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx) + tmp := New("GELU_X3b") + C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx) + x3 = tmp + + // 0.044715 * x^3 + cx3 := New("GELU_CX3") + C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx) + + // x + 0.044715 * x^3 + inner := New("GELU_INNER") + C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx) + + // sqrt(2/pi) * (x + 0.044715 * x^3) + scaled := New("GELU_SCALED") + C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx) + + // tanh(...) + th := New("GELU_TANH") + C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx) + + // 1 + tanh(...) + one := scalarWithDtype(1.0, x) + defer C.mlx_array_free(one) + onePlusTanh := New("GELU_1PT") + C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx) + + // 0.5 * x + halfX := New("GELU_HALFX") + C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx) + + // 0.5 * x * (1 + tanh(...)) + out := New("GELU_APPROX") + C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx) + return out } func SILU(t *Array) *Array { diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index 9de0037da..a898a89a4 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -149,45 +149,132 @@ func Contiguous(a *Array, allowColMajor bool) *Array { return out } -func Pad(a *Array, paddings []int32) *Array { - numAxes := len(paddings) / 2 - axes := make([]C.int, numAxes) - lowPad := make([]C.int, numAxes) - highPad := make([]C.int, numAxes) - for i := range numAxes { - axes[i] = C.int(i) - lowPad[i] = C.int(paddings[i*2]) - highPad[i] = C.int(paddings[i*2+1]) +// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in]. +// MLX uses NHWC layout. +func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array { + out := New("CONV2D") + C.mlx_conv2d( + &out.ctx, + x.ctx, + weight.ctx, + C.int(strideH), C.int(strideW), + C.int(padH), C.int(padW), + C.int(dilationH), C.int(dilationW), + C.int(groups), + DefaultStream().ctx, + ) + return out +} + +// Pad pads array a along the given axes with specified low/high pad sizes. +// mode should be "constant", "edge", or "reflect". +func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array { + cAxes := make([]C.int, len(axes)) + cLow := make([]C.int, len(lowPad)) + cHigh := make([]C.int, len(highPad)) + for i := range axes { + cAxes[i] = C.int(axes[i]) + cLow[i] = C.int(lowPad[i]) + cHigh[i] = C.int(highPad[i]) } - - padValue := C.mlx_array_new_float(C.float(0)) - defer C.mlx_array_free(padValue) - - cMode := C.CString("constant") + cMode := C.CString(mode) defer C.free(unsafe.Pointer(cMode)) - out := New("PAD") C.mlx_pad( &out.ctx, a.ctx, - unsafe.SliceData(axes), - C.size_t(len(axes)), - unsafe.SliceData(lowPad), - C.size_t(len(lowPad)), - unsafe.SliceData(highPad), - C.size_t(len(highPad)), - padValue, + unsafe.SliceData(cAxes), C.size_t(len(cAxes)), + unsafe.SliceData(cLow), C.size_t(len(cLow)), + unsafe.SliceData(cHigh), C.size_t(len(cHigh)), + padValue.ctx, cMode, DefaultStream().ctx, ) return out } +// PadConstant pads with zeros along the given axes. +func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array { + zero := NewScalarArray(float32(0)) + return Pad(a, axes, lowPad, highPad, zero, "constant") +} + func DepthwiseConv1d(x, weight *Array, bias *Array) *Array { groups := int32(x.Dim(x.NumDims() - 1)) return Conv1d(x, weight, bias, 1, 0, 1, groups) } +// Maximum returns element-wise maximum of two arrays. +func Maximum(a, b *Array) *Array { + out := New("MAXIMUM") + C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// Minimum returns element-wise minimum of two arrays. +func Minimum(a, b *Array) *Array { + out := New("MINIMUM") + C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// Softplus computes log(1 + exp(x)) using logaddexp for numerical stability. +func Softplus(a *Array) *Array { + return Logaddexp(a, Zeros(a.DType(), a.Dims()...)) +} + +// ReLU computes max(0, x). +func ReLU(a *Array) *Array { + return Maximum(a, NewScalarArray(float32(0))) +} + +// GLU applies Gated Linear Unit: splits x along last dim into two halves, +// returns first * sigmoid(second). +func GLU(a *Array) *Array { + lastDim := a.NumDims() - 1 + halfSize := a.Dim(lastDim) / 2 + first := SliceStartStop(a, + make([]int32, lastDim+1), // all zeros for start + appendDims(a, lastDim, int32(halfSize)), + ) + second := SliceStartStop(a, + appendDimsStart(a, lastDim, int32(halfSize)), + appendDims(a, lastDim, int32(a.Dim(lastDim))), + ) + return first.Multiply(second.Sigmoid()) +} + +// helper: builds stop array for SliceStartStop where the target axis = val +func appendDims(a *Array, targetAxis int, val int32) []int32 { + n := a.NumDims() + out := make([]int32, n) + for i := range n { + if i == targetAxis { + out[i] = val + } else { + out[i] = int32(a.Dim(i)) + } + } + return out +} + +// helper: builds start array for SliceStartStop where the target axis = val +func appendDimsStart(a *Array, targetAxis int, val int32) []int32 { + n := a.NumDims() + out := make([]int32, n) + for i := range n { + if i == targetAxis { + out[i] = val + } + } + return out +} + +// Clamp clamps array values to [min, max]. +func Clamp(a *Array, minVal, maxVal float32) *Array { + return Minimum(Maximum(a, NewScalarArray(minVal)), NewScalarArray(maxVal)) +} + // Convenience wrappers (function-style for the model code) func Stack(arrays []*Array, axis int) *Array { @@ -323,20 +410,37 @@ func SiLU(a *Array) *Array { } func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array { - freqs := New("") + return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil) +} + +// RoPEWithFreqs applies RoPE with optional custom frequencies. +// When freqs is non-nil, it is used instead of computing from base. +// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass +// the actual frequencies (base^(2i/dim)), not the inverse frequencies. +func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offset int, freqs *Array) *Array { + var freqsCtx C.mlx_array + var optBase C.mlx_optional_float + if freqs != nil { + freqsCtx = freqs.ctx + optBase = C.mlx_optional_float{has_value: C.bool(false)} + } else { + empty := New("") + freqsCtx = empty.ctx + optBase = C.mlx_optional_float{ + value: C.float(base), + has_value: C.bool(func() bool { return base != 0 }()), + } + } out := New("FAST_ROPE") C.mlx_fast_rope( &out.ctx, x.ctx, C.int(dims), C.bool(traditional), - C.mlx_optional_float{ - value: C.float(base), - has_value: C.bool(func() bool { return base != 0 }()), - }, + optBase, C.float(scale), C.int(offset), - freqs.ctx, + freqsCtx, DefaultStream().ctx, ) return out @@ -358,6 +462,24 @@ func Log(a *Array) *Array { return out } +func Sin(a *Array) *Array { + out := New("SIN") + C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +func Cos(a *Array) *Array { + out := New("COS") + C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +func Clip(a, aMin, aMax *Array) *Array { + out := New("CLIP") + C.mlx_clip(&out.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx) + return out +} + func Logaddexp(a, b *Array) *Array { out := New("LOGADDEXP") C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) @@ -385,6 +507,20 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b return out } +// ScaledDotProductAttentionMasked runs the fast SDPA kernel with an explicit +// additive mask. The mask is broadcast to [B, H, Q, K] and added to scores +// before softmax. Pass mode="array" so MLX actually consults mask_arr; the +// empty string is "no mask" and silently ignores the array argument. +func ScaledDotProductAttentionMasked(q, k, v *Array, scale float32, mask *Array) *Array { + sinks := New("") + cMode := C.CString("array") + defer C.free(unsafe.Pointer(cMode)) + + out := New("FAST_SDPA") + C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx) + return out +} + func LayerNormFn(x, weight, bias *Array, eps float32) *Array { out := New("FAST_LAYERNORM") var w, b C.mlx_array