mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 21:54:08 +02:00
Converts SiLU/GELUApprox to compiled kernels and adds SwiGLU, matching upstream mlx/mlx_lm's activations pattern. Routes llama, qwen3, qwen3_5 (dense + MoE), and glm4_moe_lite MLP paths through mlx.SwiGLU so each MLP invocation runs as one fused Metal/CUDA kernel rather than a chain of per-op launches.
45 lines
1.1 KiB
Go
package mlx

import "math"

var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
|
|
|
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
|
// as a fused kernel. Matches mlx.nn.gelu_approx.
|
|
var GELUApprox = Compile1(
|
|
"GELUApprox",
|
|
func(x *Array) *Array {
|
|
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
|
|
dt := x.DType()
|
|
half := FromValue[float32](0.5).AsType(dt)
|
|
coeff := FromValue(geluCoeff).AsType(dt)
|
|
c := FromValue[float32](0.044715).AsType(dt)
|
|
one := FromValue[float32](1.0).AsType(dt)
|
|
|
|
// x^3 via x*x*x (avoids general Power which is slower).
|
|
x3 := x.Multiply(x).Multiply(x)
|
|
inner := x.Add(c.Multiply(x3))
|
|
tanh := coeff.Multiply(inner).Tanh()
|
|
return half.Multiply(x).Multiply(one.Add(tanh))
|
|
},
|
|
Shapeless(),
|
|
)
|
|
|
|
// SiLU returns a * sigmoid(a) as a fused kernel. Matches mlx.nn.silu.
|
|
var SiLU = Compile1(
|
|
"SiLU",
|
|
func(a *Array) *Array {
|
|
return a.Multiply(a.Sigmoid())
|
|
},
|
|
Shapeless(),
|
|
)
|
|
|
|
// SwiGLU returns silu(gate) * up as a fused kernel. Matches mlx_lm's swiglu.
|
|
var SwiGLU = Compile2(
|
|
"SwiGLU",
|
|
func(gate, up *Array) *Array {
|
|
return SiLU(gate).Multiply(up)
|
|
},
|
|
Shapeless(),
|
|
)
|