mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA (#14913)
* mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA Add Conv2d, flexible Pad (with axes/mode), PadConstant, Maximum, Minimum, Softplus, ReLU, GLU, Clamp, Sin, Cos, Clip, ScaledDotProductAttentionMasked, and RoPEWithFreqs. Refactor RoPEWithBase to delegate to RoPEWithFreqs. * review comments * mlx: fix ScaledDotProductAttentionMasked to consult the mask argument
This commit is contained in:
@@ -592,7 +592,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
|
||||
padBottom := blockRows*scaleShape[0] - rows
|
||||
padSide := blockCols*scaleShape[1] - cols
|
||||
if padBottom > 0 || padSide > 0 {
|
||||
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
|
||||
decoded = mlx.PadConstant(decoded, []int{0, 1}, []int{0, 0}, []int{padBottom, padSide})
|
||||
}
|
||||
|
||||
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
|
||||
|
||||
@@ -4,16 +4,57 @@ package mlx
|
||||
import "C"
|
||||
import "math"
|
||||
|
||||
func GELUApprox(t *Array) *Array {
|
||||
return t.Multiply(
|
||||
FromValue[float32](0.5),
|
||||
).Multiply(
|
||||
t.Add(
|
||||
t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
|
||||
).Multiply(
|
||||
FromValue(float32(math.Sqrt(2 / math.Pi))),
|
||||
).Tanh().Add(FromValue[float32](1.0)),
|
||||
).AsType(t.DType())
|
||||
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
||||
|
||||
// GELUApprox matches mlx.nn.gelu_approx:
|
||||
//
|
||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
func GELUApprox(x *Array) *Array {
|
||||
// Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs.
|
||||
half := scalarWithDtype(0.5, x)
|
||||
defer C.mlx_array_free(half)
|
||||
coeff := scalarWithDtype(geluCoeff, x)
|
||||
defer C.mlx_array_free(coeff)
|
||||
c := scalarWithDtype(0.044715, x)
|
||||
defer C.mlx_array_free(c)
|
||||
|
||||
// x^3 via x*x*x (avoids general Power which is slower)
|
||||
x3 := New("GELU_X3")
|
||||
C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx)
|
||||
tmp := New("GELU_X3b")
|
||||
C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx)
|
||||
x3 = tmp
|
||||
|
||||
// 0.044715 * x^3
|
||||
cx3 := New("GELU_CX3")
|
||||
C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx)
|
||||
|
||||
// x + 0.044715 * x^3
|
||||
inner := New("GELU_INNER")
|
||||
C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx)
|
||||
|
||||
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
||||
scaled := New("GELU_SCALED")
|
||||
C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx)
|
||||
|
||||
// tanh(...)
|
||||
th := New("GELU_TANH")
|
||||
C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx)
|
||||
|
||||
// 1 + tanh(...)
|
||||
one := scalarWithDtype(1.0, x)
|
||||
defer C.mlx_array_free(one)
|
||||
onePlusTanh := New("GELU_1PT")
|
||||
C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx)
|
||||
|
||||
// 0.5 * x
|
||||
halfX := New("GELU_HALFX")
|
||||
C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx)
|
||||
|
||||
// 0.5 * x * (1 + tanh(...))
|
||||
out := New("GELU_APPROX")
|
||||
C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SILU(t *Array) *Array {
|
||||
|
||||
@@ -149,45 +149,132 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Pad(a *Array, paddings []int32) *Array {
|
||||
numAxes := len(paddings) / 2
|
||||
axes := make([]C.int, numAxes)
|
||||
lowPad := make([]C.int, numAxes)
|
||||
highPad := make([]C.int, numAxes)
|
||||
for i := range numAxes {
|
||||
axes[i] = C.int(i)
|
||||
lowPad[i] = C.int(paddings[i*2])
|
||||
highPad[i] = C.int(paddings[i*2+1])
|
||||
// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in].
|
||||
// MLX uses NHWC layout.
|
||||
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
|
||||
out := New("CONV2D")
|
||||
C.mlx_conv2d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(strideH), C.int(strideW),
|
||||
C.int(padH), C.int(padW),
|
||||
C.int(dilationH), C.int(dilationW),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// Pad pads array a along the given axes with specified low/high pad sizes.
|
||||
// mode should be "constant", "edge", or "reflect".
|
||||
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
cLow := make([]C.int, len(lowPad))
|
||||
cHigh := make([]C.int, len(highPad))
|
||||
for i := range axes {
|
||||
cAxes[i] = C.int(axes[i])
|
||||
cLow[i] = C.int(lowPad[i])
|
||||
cHigh[i] = C.int(highPad[i])
|
||||
}
|
||||
|
||||
padValue := C.mlx_array_new_float(C.float(0))
|
||||
defer C.mlx_array_free(padValue)
|
||||
|
||||
cMode := C.CString("constant")
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("PAD")
|
||||
C.mlx_pad(
|
||||
&out.ctx,
|
||||
a.ctx,
|
||||
unsafe.SliceData(axes),
|
||||
C.size_t(len(axes)),
|
||||
unsafe.SliceData(lowPad),
|
||||
C.size_t(len(lowPad)),
|
||||
unsafe.SliceData(highPad),
|
||||
C.size_t(len(highPad)),
|
||||
padValue,
|
||||
unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
|
||||
unsafe.SliceData(cLow), C.size_t(len(cLow)),
|
||||
unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
|
||||
padValue.ctx,
|
||||
cMode,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// PadConstant pads with zeros along the given axes.
|
||||
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
|
||||
zero := NewScalarArray(float32(0))
|
||||
return Pad(a, axes, lowPad, highPad, zero, "constant")
|
||||
}
|
||||
|
||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||
groups := int32(x.Dim(x.NumDims() - 1))
|
||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||
}
|
||||
|
||||
// Maximum returns element-wise maximum of two arrays.
|
||||
func Maximum(a, b *Array) *Array {
|
||||
out := New("MAXIMUM")
|
||||
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Minimum returns element-wise minimum of two arrays.
|
||||
func Minimum(a, b *Array) *Array {
|
||||
out := New("MINIMUM")
|
||||
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Softplus computes log(1 + exp(x)) using logaddexp for numerical stability.
|
||||
func Softplus(a *Array) *Array {
|
||||
return Logaddexp(a, Zeros(a.DType(), a.Dims()...))
|
||||
}
|
||||
|
||||
// ReLU computes max(0, x).
|
||||
func ReLU(a *Array) *Array {
|
||||
return Maximum(a, NewScalarArray(float32(0)))
|
||||
}
|
||||
|
||||
// GLU applies Gated Linear Unit: splits x along last dim into two halves,
|
||||
// returns first * sigmoid(second).
|
||||
func GLU(a *Array) *Array {
|
||||
lastDim := a.NumDims() - 1
|
||||
halfSize := a.Dim(lastDim) / 2
|
||||
first := SliceStartStop(a,
|
||||
make([]int32, lastDim+1), // all zeros for start
|
||||
appendDims(a, lastDim, int32(halfSize)),
|
||||
)
|
||||
second := SliceStartStop(a,
|
||||
appendDimsStart(a, lastDim, int32(halfSize)),
|
||||
appendDims(a, lastDim, int32(a.Dim(lastDim))),
|
||||
)
|
||||
return first.Multiply(second.Sigmoid())
|
||||
}
|
||||
|
||||
// helper: builds stop array for SliceStartStop where the target axis = val
|
||||
func appendDims(a *Array, targetAxis int, val int32) []int32 {
|
||||
n := a.NumDims()
|
||||
out := make([]int32, n)
|
||||
for i := range n {
|
||||
if i == targetAxis {
|
||||
out[i] = val
|
||||
} else {
|
||||
out[i] = int32(a.Dim(i))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// helper: builds start array for SliceStartStop where the target axis = val
|
||||
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
|
||||
n := a.NumDims()
|
||||
out := make([]int32, n)
|
||||
for i := range n {
|
||||
if i == targetAxis {
|
||||
out[i] = val
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Clamp clamps array values to [min, max].
|
||||
func Clamp(a *Array, minVal, maxVal float32) *Array {
|
||||
return Minimum(Maximum(a, NewScalarArray(minVal)), NewScalarArray(maxVal))
|
||||
}
|
||||
|
||||
// Convenience wrappers (function-style for the model code)
|
||||
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
@@ -323,20 +410,37 @@ func SiLU(a *Array) *Array {
|
||||
}
|
||||
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||
freqs := New("")
|
||||
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
|
||||
}
|
||||
|
||||
// RoPEWithFreqs applies RoPE with optional custom frequencies.
|
||||
// When freqs is non-nil, it is used instead of computing from base.
|
||||
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
|
||||
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
|
||||
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offset int, freqs *Array) *Array {
|
||||
var freqsCtx C.mlx_array
|
||||
var optBase C.mlx_optional_float
|
||||
if freqs != nil {
|
||||
freqsCtx = freqs.ctx
|
||||
optBase = C.mlx_optional_float{has_value: C.bool(false)}
|
||||
} else {
|
||||
empty := New("")
|
||||
freqsCtx = empty.ctx
|
||||
optBase = C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C.bool(func() bool { return base != 0 }()),
|
||||
}
|
||||
}
|
||||
out := New("FAST_ROPE")
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
C.int(dims),
|
||||
C.bool(traditional),
|
||||
C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C.bool(func() bool { return base != 0 }()),
|
||||
},
|
||||
optBase,
|
||||
C.float(scale),
|
||||
C.int(offset),
|
||||
freqs.ctx,
|
||||
freqsCtx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
@@ -358,6 +462,24 @@ func Log(a *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Sin(a *Array) *Array {
|
||||
out := New("SIN")
|
||||
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Cos(a *Array) *Array {
|
||||
out := New("COS")
|
||||
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Clip(a, aMin, aMax *Array) *Array {
|
||||
out := New("CLIP")
|
||||
C.mlx_clip(&out.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Logaddexp(a, b *Array) *Array {
|
||||
out := New("LOGADDEXP")
|
||||
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
@@ -385,6 +507,20 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
||||
return out
|
||||
}
|
||||
|
||||
// ScaledDotProductAttentionMasked runs the fast SDPA kernel with an explicit
|
||||
// additive mask. The mask is broadcast to [B, H, Q, K] and added to scores
|
||||
// before softmax. Pass mode="array" so MLX actually consults mask_arr; the
|
||||
// empty string is "no mask" and silently ignores the array argument.
|
||||
func ScaledDotProductAttentionMasked(q, k, v *Array, scale float32, mask *Array) *Array {
|
||||
sinks := New("")
|
||||
cMode := C.CString("array")
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("FAST_SDPA")
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM")
|
||||
var w, b C.mlx_array
|
||||
|
||||
Reference in New Issue
Block a user