models: fuse MLP activation functions via mlx_compile

Converts SiLU/GELUApprox to compiled kernels and adds SwiGLU,
matching upstream mlx/mlx_lm's activations pattern. Routes llama,
qwen3, qwen3_5 (dense + MoE), and glm4_moe_lite MLP paths through
mlx.SwiGLU so each MLP invocation runs as one fused Metal/CUDA
kernel rather than a chain of per-op launches.
This commit is contained in:
Jesse Gross
2026-04-13 12:20:39 -07:00
parent d3e67e305c
commit e1e3cec8d0
6 changed files with 44 additions and 71 deletions

View File

@@ -1,62 +1,44 @@
package mlx
// #include "generated.h"
import "C"
import "math"
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
// GELUApprox matches mlx.nn.gelu_approx:
//
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
func GELUApprox(x *Array) *Array {
// Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs.
half := scalarWithDtype(0.5, x)
defer C.mlx_array_free(half)
coeff := scalarWithDtype(geluCoeff, x)
defer C.mlx_array_free(coeff)
c := scalarWithDtype(0.044715, x)
defer C.mlx_array_free(c)
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// as a fused kernel.
var GELUApprox = Compile1(
"GELUApprox",
func(x *Array) *Array {
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
dt := x.DType()
half := FromValue[float32](0.5).AsType(dt)
coeff := FromValue(geluCoeff).AsType(dt)
c := FromValue[float32](0.044715).AsType(dt)
one := FromValue[float32](1.0).AsType(dt)
// x^3 via x*x*x (avoids general Power which is slower)
x3 := New("GELU_X3")
C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx)
tmp := New("GELU_X3b")
C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx)
x3 = tmp
// x^3 via x*x*x (avoids general Power which is slower).
x3 := x.Multiply(x).Multiply(x)
inner := x.Add(c.Multiply(x3))
tanh := coeff.Multiply(inner).Tanh()
return half.Multiply(x).Multiply(one.Add(tanh))
},
Shapeless(),
)
// 0.044715 * x^3
cx3 := New("GELU_CX3")
C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx)
// SiLU computes the sigmoid-weighted linear unit, x * sigmoid(x),
// compiled once so each invocation runs as a single fused kernel.
var SiLU = Compile1(
	"SiLU",
	func(x *Array) *Array {
		sig := x.Sigmoid()
		return x.Multiply(sig)
	},
	Shapeless(),
)
// x + 0.044715 * x^3
inner := New("GELU_INNER")
C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx)
// sqrt(2/pi) * (x + 0.044715 * x^3)
scaled := New("GELU_SCALED")
C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx)
// tanh(...)
th := New("GELU_TANH")
C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx)
// 1 + tanh(...)
one := scalarWithDtype(1.0, x)
defer C.mlx_array_free(one)
onePlusTanh := New("GELU_1PT")
C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx)
// 0.5 * x
halfX := New("GELU_HALFX")
C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx)
// 0.5 * x * (1 + tanh(...))
out := New("GELU_APPROX")
C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx)
return out
}
// SILU applies x * sigmoid(x) and casts the result back to the input's
// dtype, so callers always receive the same dtype they passed in.
func SILU(t *Array) *Array {
	sig := t.Sigmoid()
	out := t.Multiply(sig)
	return out.AsType(t.DType())
}
// SwiGLU is the gated-linear activation silu(gate) * up, compiled via
// Compile2 so the gate activation and the elementwise product fuse
// into one kernel launch.
var SwiGLU = Compile2(
	"SwiGLU",
	func(g, u *Array) *Array {
		activated := SiLU(g)
		return activated.Multiply(u)
	},
	Shapeless(),
)

View File

@@ -404,11 +404,6 @@ func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *A
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
}
// SiLU computes a * sigmoid(a) (the sigmoid-weighted linear unit).
func SiLU(a *Array) *Array {
	return a.Multiply(a.Sigmoid())
}
// RoPEWithBase applies rotary position embeddings with an explicit
// frequency base, delegating to RoPEWithFreqs with a nil custom
// frequency array.
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
	return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
}

View File

@@ -148,9 +148,7 @@ type DenseMLP struct {
// Forward applies the SwiGLU MLP
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(m.GateProj.Forward(x))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
}
// MoEGate implements the expert gating mechanism
@@ -242,7 +240,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
hidden = mlx.SwiGLU(gate, up)
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
@@ -250,7 +248,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
hidden = mlx.SwiGLU(gate, up)
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
}
@@ -273,9 +271,7 @@ type SharedExperts struct {
// Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(s.GateProj.Forward(x))
up := s.UpProj.Forward(x)
return s.DownProj.Forward(mlx.Mul(gate, up))
return s.DownProj.Forward(mlx.SwiGLU(s.GateProj.Forward(x), s.UpProj.Forward(x)))
}
// MoE implements the full Mixture of Experts layer

View File

@@ -314,5 +314,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
}

View File

@@ -333,5 +333,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
}

View File

@@ -1253,7 +1253,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
}
func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
}
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
@@ -1283,13 +1283,13 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
hidden = mlx.SwiGLU(gate, up)
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
} else {
gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up)
hidden = mlx.SwiGLU(gate, up)
down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
}