mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA (#14913)

* mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA

Add Conv2d, flexible Pad (with axes/mode), PadConstant, Maximum,
Minimum, Softplus, ReLU, GLU, Clamp, Sin, Cos, Clip,
ScaledDotProductAttentionMasked, and RoPEWithFreqs. Refactor
RoPEWithBase to delegate to RoPEWithFreqs.

* review comments

* mlx: fix ScaledDotProductAttentionMasked to consult the mask argument
This commit is contained in:
Daniel Hiltgen
2026-04-13 11:43:24 -07:00
committed by GitHub
parent d3da29cbfc
commit c88fb286ec
3 changed files with 216 additions and 39 deletions

View File

@@ -592,7 +592,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
padBottom := blockRows*scaleShape[0] - rows
padSide := blockCols*scaleShape[1] - cols
if padBottom > 0 || padSide > 0 {
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
decoded = mlx.PadConstant(decoded, []int{0, 1}, []int{0, 0}, []int{padBottom, padSide})
}
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))

View File

@@ -4,16 +4,57 @@ package mlx
import "C"
import "math"
func GELUApprox(t *Array) *Array {
return t.Multiply(
FromValue[float32](0.5),
).Multiply(
t.Add(
t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
).Multiply(
FromValue(float32(math.Sqrt(2 / math.Pi))),
).Tanh().Add(FromValue[float32](1.0)),
).AsType(t.DType())
// geluCoeff is sqrt(2/pi), the constant in the tanh GELU approximation.
var geluCoeff = float32(math.Sqrt(2 / math.Pi))

// GELUApprox matches mlx.nn.gelu_approx:
//
//	0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
func GELUApprox(x *Array) *Array {
	// Scalars are created with x's dtype so bf16 inputs are not
	// implicitly upcast by the arithmetic below.
	half := scalarWithDtype(0.5, x)
	defer C.mlx_array_free(half)
	coeff := scalarWithDtype(geluCoeff, x)
	defer C.mlx_array_free(coeff)
	cubeCoeff := scalarWithDtype(0.044715, x)
	defer C.mlx_array_free(cubeCoeff)

	// Cube via two multiplies (x*x, then *x) — avoids the slower
	// general Power kernel.
	square := New("GELU_X3")
	C.mlx_multiply(&square.ctx, x.ctx, x.ctx, DefaultStream().ctx)
	cube := New("GELU_X3b")
	C.mlx_multiply(&cube.ctx, square.ctx, x.ctx, DefaultStream().ctx)

	// 0.044715 * x^3
	scaledCube := New("GELU_CX3")
	C.mlx_multiply(&scaledCube.ctx, cubeCoeff, cube.ctx, DefaultStream().ctx)
	// x + 0.044715 * x^3
	inner := New("GELU_INNER")
	C.mlx_add(&inner.ctx, x.ctx, scaledCube.ctx, DefaultStream().ctx)
	// sqrt(2/pi) * (x + 0.044715 * x^3)
	scaled := New("GELU_SCALED")
	C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx)
	// tanh(...)
	hyperbolic := New("GELU_TANH")
	C.mlx_tanh(&hyperbolic.ctx, scaled.ctx, DefaultStream().ctx)

	// 1 + tanh(...)
	one := scalarWithDtype(1.0, x)
	defer C.mlx_array_free(one)
	gate := New("GELU_1PT")
	C.mlx_add(&gate.ctx, one, hyperbolic.ctx, DefaultStream().ctx)

	// 0.5 * x
	halfX := New("GELU_HALFX")
	C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx)
	// 0.5 * x * (1 + tanh(...))
	result := New("GELU_APPROX")
	C.mlx_multiply(&result.ctx, halfX.ctx, gate.ctx, DefaultStream().ctx)
	return result
}
func SILU(t *Array) *Array {

View File

@@ -149,45 +149,132 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
return out
}
func Pad(a *Array, paddings []int32) *Array {
numAxes := len(paddings) / 2
axes := make([]C.int, numAxes)
lowPad := make([]C.int, numAxes)
highPad := make([]C.int, numAxes)
for i := range numAxes {
axes[i] = C.int(i)
lowPad[i] = C.int(paddings[i*2])
highPad[i] = C.int(paddings[i*2+1])
// Conv2d performs a 2D convolution of x [N,H,W,C_in] with
// weight [C_out,kH,kW,C_in]. MLX uses NHWC layout.
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
	result := New("CONV2D")
	stream := DefaultStream()
	C.mlx_conv2d(&result.ctx, x.ctx, weight.ctx,
		C.int(strideH), C.int(strideW),
		C.int(padH), C.int(padW),
		C.int(dilationH), C.int(dilationW),
		C.int(groups),
		stream.ctx)
	return result
}
// Pad pads array a along the given axes with the specified low/high pad
// widths (one entry per axis). padValue fills new elements when mode is
// "constant"; mode should be "constant", "edge", or "reflect".
//
// Fix: drop stale diff residue — the duplicate cMode declaration, the
// shadowing zero-valued padValue, and the raw Go slices previously passed
// straight to C. The C.int copies below are what must cross the cgo boundary.
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
	cAxes := make([]C.int, len(axes))
	cLow := make([]C.int, len(lowPad))
	cHigh := make([]C.int, len(highPad))
	for i := range axes {
		cAxes[i] = C.int(axes[i])
		cLow[i] = C.int(lowPad[i])
		cHigh[i] = C.int(highPad[i])
	}
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	out := New("PAD")
	C.mlx_pad(
		&out.ctx,
		a.ctx,
		unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
		unsafe.SliceData(cLow), C.size_t(len(cLow)),
		unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
		padValue.ctx,
		cMode,
		DefaultStream().ctx,
	)
	return out
}
// PadConstant pads a with zeros along the given axes.
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
	return Pad(a, axes, lowPad, highPad, NewScalarArray(float32(0)), "constant")
}
// DepthwiseConv1d runs Conv1d with one group per input channel
// (channels are the last axis in MLX's layout).
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
	channels := int32(x.Dim(x.NumDims() - 1))
	return Conv1d(x, weight, bias, 1, 0, 1, channels)
}
// Maximum returns the element-wise maximum of a and b.
func Maximum(a, b *Array) *Array {
	result := New("MAXIMUM")
	C.mlx_maximum(&result.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return result
}
// Minimum returns the element-wise minimum of a and b.
func Minimum(a, b *Array) *Array {
	result := New("MINIMUM")
	C.mlx_minimum(&result.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return result
}
// Softplus computes log(1 + exp(x)), expressed as logaddexp(x, 0)
// for numerical stability.
func Softplus(a *Array) *Array {
	zeros := Zeros(a.DType(), a.Dims()...)
	return Logaddexp(a, zeros)
}
// ReLU computes max(0, x).
func ReLU(a *Array) *Array {
	zero := NewScalarArray(float32(0))
	return Maximum(a, zero)
}
// GLU applies a Gated Linear Unit: a is split along its last axis into
// two halves and the result is first * sigmoid(second).
func GLU(a *Array) *Array {
	axis := a.NumDims() - 1
	mid := int32(a.Dim(axis) / 2)
	// First half: [0 .. mid) on the last axis, full extent elsewhere.
	value := SliceStartStop(a,
		make([]int32, axis+1), // zero start on every axis
		appendDims(a, axis, mid),
	)
	// Second half: [mid .. dim) on the last axis.
	gate := SliceStartStop(a,
		appendDimsStart(a, axis, mid),
		appendDims(a, axis, int32(a.Dim(axis))),
	)
	return value.Multiply(gate.Sigmoid())
}
// appendDims builds a stop vector for SliceStartStop: every axis runs to
// its full extent except targetAxis, which is set to val.
func appendDims(a *Array, targetAxis int, val int32) []int32 {
	stops := make([]int32, a.NumDims())
	for i := range stops {
		if i == targetAxis {
			stops[i] = val
			continue
		}
		stops[i] = int32(a.Dim(i))
	}
	return stops
}
// appendDimsStart builds a start vector for SliceStartStop: zero on every
// axis except targetAxis, which is set to val.
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
	starts := make([]int32, a.NumDims())
	if 0 <= targetAxis && targetAxis < len(starts) {
		starts[targetAxis] = val
	}
	return starts
}
// Clamp clamps array values to the closed interval [minVal, maxVal].
func Clamp(a *Array, minVal, maxVal float32) *Array {
	lowered := Maximum(a, NewScalarArray(minVal))
	return Minimum(lowered, NewScalarArray(maxVal))
}
// Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array {
@@ -323,20 +410,37 @@ func SiLU(a *Array) *Array {
}
// RoPEWithBase applies rotary positional embedding with frequencies derived
// from base (a base of 0 means "no explicit base"). It delegates to
// RoPEWithFreqs with nil custom frequencies.
//
// Fix: drop the stale `freqs := New("")` residue line, which was dead code
// (an unused local is a compile error in Go).
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
	return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
}
// RoPEWithFreqs applies RoPE with optional custom frequencies.
// When freqs is non-nil, it is used instead of computing from base.
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offset int, freqs *Array) *Array {
var freqsCtx C.mlx_array
var optBase C.mlx_optional_float
if freqs != nil {
freqsCtx = freqs.ctx
optBase = C.mlx_optional_float{has_value: C.bool(false)}
} else {
empty := New("")
freqsCtx = empty.ctx
optBase = C.mlx_optional_float{
value: C.float(base),
has_value: C.bool(func() bool { return base != 0 }()),
}
}
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
x.ctx,
C.int(dims),
C.bool(traditional),
C.mlx_optional_float{
value: C.float(base),
has_value: C.bool(func() bool { return base != 0 }()),
},
optBase,
C.float(scale),
C.int(offset),
freqs.ctx,
freqsCtx,
DefaultStream().ctx,
)
return out
@@ -358,6 +462,24 @@ func Log(a *Array) *Array {
return out
}
// Sin returns the element-wise sine of a.
func Sin(a *Array) *Array {
	result := New("SIN")
	C.mlx_sin(&result.ctx, a.ctx, DefaultStream().ctx)
	return result
}
// Cos returns the element-wise cosine of a.
func Cos(a *Array) *Array {
	result := New("COS")
	C.mlx_cos(&result.ctx, a.ctx, DefaultStream().ctx)
	return result
}
// Clip limits the values of a element-wise to the range [aMin, aMax],
// where the bounds are themselves arrays.
func Clip(a, aMin, aMax *Array) *Array {
	result := New("CLIP")
	C.mlx_clip(&result.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
	return result
}
func Logaddexp(a, b *Array) *Array {
out := New("LOGADDEXP")
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
@@ -385,6 +507,20 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
return out
}
// ScaledDotProductAttentionMasked runs the fast SDPA kernel with an explicit
// additive mask. The mask is broadcast to [B, H, Q, K] and added to scores
// before softmax. Pass mode="array" so MLX actually consults the mask array;
// the empty string means "no mask" and silently ignores that argument.
func ScaledDotProductAttentionMasked(q, k, v *Array, scale float32, mask *Array) *Array {
	mode := C.CString("array")
	defer C.free(unsafe.Pointer(mode))
	sinks := New("")
	result := New("FAST_SDPA")
	C.mlx_fast_scaled_dot_product_attention(&result.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), mode, mask.ctx, sinks.ctx, DefaultStream().ctx)
	return result
}
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
out := New("FAST_LAYERNORM")
var w, b C.mlx_array