Files
ollama/x/mlxrunner/mlx/act.go
Jesse Gross 27d7bd37a7 models: fuse MLP activation functions via mlx_compile
Converts SiLU/GELUApprox to compiled kernels and adds SwiGLU,
matching upstream mlx/mlx_lm's activations pattern. Routes llama,
qwen3, qwen3_5 (dense + MoE), and glm4_moe_lite MLP paths through
mlx.SwiGLU so each MLP invocation runs as one fused Metal/CUDA
kernel rather than a chain of per-op launches.
2026-04-14 09:40:21 -07:00

45 lines
1.1 KiB
Go

package mlx
import "math"
// geluCoeff is sqrt(2/pi), the scaling constant inside the tanh-based
// GELU approximation below.
var geluCoeff = float32(math.Sqrt(2 / math.Pi))

// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// as a fused kernel. Matches mlx.nn.gelu_approx.
var GELUApprox = Compile1(
	"GELUApprox",
	func(x *Array) *Array {
		// Cast every scalar to the input dtype so bf16 inputs are not
		// implicitly upcast inside the compiled graph.
		dt := x.DType()
		half := FromValue[float32](0.5).AsType(dt)
		coeff := FromValue(geluCoeff).AsType(dt)
		cubicCoeff := FromValue[float32](0.044715).AsType(dt)
		one := FromValue[float32](1.0).AsType(dt)
		// Cube via repeated multiplication; the general Power op is slower.
		cubed := x.Multiply(x).Multiply(x)
		arg := coeff.Multiply(x.Add(cubicCoeff.Multiply(cubed)))
		gate := one.Add(arg.Tanh())
		return half.Multiply(x).Multiply(gate)
	},
	Shapeless(),
)
// SiLU returns a * sigmoid(a) as a fused kernel. Matches mlx.nn.silu.
var SiLU = Compile1(
	"SiLU",
	func(x *Array) *Array {
		sig := x.Sigmoid()
		return x.Multiply(sig)
	},
	Shapeless(),
)
// SwiGLU returns silu(gate) * up as a fused kernel. Matches mlx_lm's swiglu.
var SwiGLU = Compile2(
	"SwiGLU",
	func(gate, up *Array) *Array {
		activated := SiLU(gate)
		return activated.Multiply(up)
	},
	Shapeless(),
)