From e1e3cec8d019b468aae7f5b9c5b3258821cf7cfa Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 13 Apr 2026 12:20:39 -0700 Subject: [PATCH] models: fuse MLP activation functions via mlx_compile Converts SiLU/GELUApprox to compiled kernels and adds SwiGLU, matching upstream mlx/mlx_lm's activations pattern. Routes llama, qwen3, qwen3_5 (dense + MoE), and glm4_moe_lite MLP paths through mlx.SwiGLU so each MLP invocation runs as one fused Metal/CUDA kernel rather than a chain of per-op launches. --- x/mlxrunner/mlx/act.go | 88 ++++++++++--------------- x/mlxrunner/mlx/ops_extra.go | 5 -- x/models/glm4_moe_lite/glm4_moe_lite.go | 12 ++-- x/models/llama/llama.go | 2 +- x/models/qwen3/qwen3.go | 2 +- x/models/qwen3_5/qwen3_5.go | 6 +- 6 files changed, 44 insertions(+), 71 deletions(-) diff --git a/x/mlxrunner/mlx/act.go b/x/mlxrunner/mlx/act.go index 50801352f..659b797c4 100644 --- a/x/mlxrunner/mlx/act.go +++ b/x/mlxrunner/mlx/act.go @@ -1,62 +1,44 @@ package mlx -// #include "generated.h" -import "C" import "math" var geluCoeff = float32(math.Sqrt(2 / math.Pi)) -// GELUApprox matches mlx.nn.gelu_approx: -// -// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) -func GELUApprox(x *Array) *Array { - // Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs. - half := scalarWithDtype(0.5, x) - defer C.mlx_array_free(half) - coeff := scalarWithDtype(geluCoeff, x) - defer C.mlx_array_free(coeff) - c := scalarWithDtype(0.044715, x) - defer C.mlx_array_free(c) +// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) +// as a fused kernel. +var GELUApprox = Compile1( + "GELUApprox", + func(x *Array) *Array { + // Dtype-matched scalars avoid implicit upcasts on bf16 inputs. + dt := x.DType() + half := FromValue[float32](0.5).AsType(dt) + coeff := FromValue(geluCoeff).AsType(dt) + c := FromValue[float32](0.044715).AsType(dt) + one := FromValue[float32](1.0).AsType(dt) - // x^3 via x*x*x (avoids general Power which is slower) - x3 := New("GELU_X3") - C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx) - tmp := New("GELU_X3b") - C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx) - x3 = tmp + // x^3 via x*x*x (avoids general Power which is slower). + x3 := x.Multiply(x).Multiply(x) + inner := x.Add(c.Multiply(x3)) + tanh := coeff.Multiply(inner).Tanh() + return half.Multiply(x).Multiply(one.Add(tanh)) + }, + Shapeless(), +) - // 0.044715 * x^3 - cx3 := New("GELU_CX3") - C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx) +// SiLU returns a * sigmoid(a) as a fused kernel. +var SiLU = Compile1( + "SiLU", + func(a *Array) *Array { + return a.Multiply(a.Sigmoid()) + }, + Shapeless(), +) - // x + 0.044715 * x^3 - inner := New("GELU_INNER") - C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx) - - // sqrt(2/pi) * (x + 0.044715 * x^3) - scaled := New("GELU_SCALED") - C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx) - - // tanh(...) - th := New("GELU_TANH") - C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx) - - // 1 + tanh(...) - one := scalarWithDtype(1.0, x) - defer C.mlx_array_free(one) - onePlusTanh := New("GELU_1PT") - C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx) - - // 0.5 * x - halfX := New("GELU_HALFX") - C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx) - - // 0.5 * x * (1 + tanh(...)) - out := New("GELU_APPROX") - C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx) - return out -} - -func SILU(t *Array) *Array { - return t.Multiply(t.Sigmoid()).AsType(t.DType()) -} +// SwiGLU returns silu(gate) * up as a fused kernel. +var SwiGLU = Compile2( + "SwiGLU", + func(gate, up *Array) *Array { + return SiLU(gate).Multiply(up) + }, + Shapeless(), +) diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index a898a89a4..409f71263 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -404,11 +404,6 @@ func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *A return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices) } -func SiLU(a *Array) *Array { - sig := a.Sigmoid() - return a.Multiply(sig) -} - func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array { return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil) } diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index a0a37a3f0..8b40c9348 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -148,9 +148,7 @@ type DenseMLP struct { // Forward applies the SwiGLU MLP func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array { - gate := mlx.SiLU(m.GateProj.Forward(x)) - up := m.UpProj.Forward(x) - return m.DownProj.Forward(mlx.Mul(gate, up)) + return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x))) } // MoEGate implements the expert gating mechanism @@ -242,7 +240,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx. up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases, nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort) - hidden = mlx.Mul(mlx.SiLU(gate), up) + hidden = mlx.SwiGLU(gate, up) down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases, nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort) @@ -250,7 +248,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx. gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort) up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort) - hidden = mlx.Mul(mlx.SiLU(gate), up) + hidden = mlx.SwiGLU(gate, up) down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort) } @@ -273,9 +271,7 @@ type SharedExperts struct { // Forward applies the shared expert MLP func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array { - gate := mlx.SiLU(s.GateProj.Forward(x)) - up := s.UpProj.Forward(x) - return s.DownProj.Forward(mlx.Mul(gate, up)) + return s.DownProj.Forward(mlx.SwiGLU(s.GateProj.Forward(x), s.UpProj.Forward(x))) } // MoE implements the full Mixture of Experts layer diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go index ca99d9148..4f4da05a7 100644 --- a/x/models/llama/llama.go +++ b/x/models/llama/llama.go @@ -314,5 +314,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config } func (m *MLP) Forward(x *mlx.Array) *mlx.Array { - return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x))) + return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x))) } diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go index 6773f0eb5..a1b31af0d 100644 --- a/x/models/qwen3/qwen3.go +++ b/x/models/qwen3/qwen3.go @@ -333,5 +333,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config } func (m *MLP) Forward(x *mlx.Array) *mlx.Array { - return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x))) + return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x))) } diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go index 5dbb59dce..f29563f88 100644 --- a/x/models/qwen3_5/qwen3_5.go +++ b/x/models/qwen3_5/qwen3_5.go @@ -1253,7 +1253,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co } func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array { - return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x))) + return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x))) } func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array { @@ -1283,13 +1283,13 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx. nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort) up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases, nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort) - hidden = mlx.Mul(mlx.SiLU(gate), up) + hidden = mlx.SwiGLU(gate, up) down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases, nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort) } else { gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort) up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort) - hidden = mlx.Mul(mlx.SiLU(gate), up) + hidden = mlx.SwiGLU(gate, up) down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort) }