mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
models: fuse MLP activation functions via mlx_compile
Converts SiLU/GELUApprox to compiled kernels and adds SwiGLU, matching upstream mlx/mlx_lm's activations pattern. Routes llama, qwen3, qwen3_5 (dense + MoE), and glm4_moe_lite MLP paths through mlx.SwiGLU so each MLP invocation runs as one fused Metal/CUDA kernel rather than a chain of per-op launches.
This commit is contained in:
@@ -1,62 +1,44 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
import "math"
|
||||
|
||||
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
||||
|
||||
// GELUApprox matches mlx.nn.gelu_approx:
|
||||
//
|
||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
func GELUApprox(x *Array) *Array {
|
||||
// Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs.
|
||||
half := scalarWithDtype(0.5, x)
|
||||
defer C.mlx_array_free(half)
|
||||
coeff := scalarWithDtype(geluCoeff, x)
|
||||
defer C.mlx_array_free(coeff)
|
||||
c := scalarWithDtype(0.044715, x)
|
||||
defer C.mlx_array_free(c)
|
||||
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
// as a fused kernel. Matches mlx.nn.gelu_approx.
|
||||
var GELUApprox = Compile1(
|
||||
"GELUApprox",
|
||||
func(x *Array) *Array {
|
||||
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
|
||||
dt := x.DType()
|
||||
half := FromValue[float32](0.5).AsType(dt)
|
||||
coeff := FromValue(geluCoeff).AsType(dt)
|
||||
c := FromValue[float32](0.044715).AsType(dt)
|
||||
one := FromValue[float32](1.0).AsType(dt)
|
||||
|
||||
// x^3 via x*x*x (avoids general Power which is slower)
|
||||
x3 := New("GELU_X3")
|
||||
C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx)
|
||||
tmp := New("GELU_X3b")
|
||||
C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx)
|
||||
x3 = tmp
|
||||
// x^3 via x*x*x (avoids general Power which is slower).
|
||||
x3 := x.Multiply(x).Multiply(x)
|
||||
inner := x.Add(c.Multiply(x3))
|
||||
tanh := coeff.Multiply(inner).Tanh()
|
||||
return half.Multiply(x).Multiply(one.Add(tanh))
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// 0.044715 * x^3
|
||||
cx3 := New("GELU_CX3")
|
||||
C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx)
|
||||
// SiLU returns a * sigmoid(a) as a fused kernel. Matches mlx.nn.silu.
|
||||
var SiLU = Compile1(
|
||||
"SiLU",
|
||||
func(a *Array) *Array {
|
||||
return a.Multiply(a.Sigmoid())
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// x + 0.044715 * x^3
|
||||
inner := New("GELU_INNER")
|
||||
C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx)
|
||||
|
||||
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
||||
scaled := New("GELU_SCALED")
|
||||
C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx)
|
||||
|
||||
// tanh(...)
|
||||
th := New("GELU_TANH")
|
||||
C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx)
|
||||
|
||||
// 1 + tanh(...)
|
||||
one := scalarWithDtype(1.0, x)
|
||||
defer C.mlx_array_free(one)
|
||||
onePlusTanh := New("GELU_1PT")
|
||||
C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx)
|
||||
|
||||
// 0.5 * x
|
||||
halfX := New("GELU_HALFX")
|
||||
C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx)
|
||||
|
||||
// 0.5 * x * (1 + tanh(...))
|
||||
out := New("GELU_APPROX")
|
||||
C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SILU(t *Array) *Array {
|
||||
return t.Multiply(t.Sigmoid()).AsType(t.DType())
|
||||
}
|
||||
// SwiGLU returns silu(gate) * up as a fused kernel. Matches mlx_lm's swiglu.
|
||||
var SwiGLU = Compile2(
|
||||
"SwiGLU",
|
||||
func(gate, up *Array) *Array {
|
||||
return SiLU(gate).Multiply(up)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
@@ -404,11 +404,6 @@ func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *A
|
||||
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
|
||||
}
|
||||
|
||||
func SiLU(a *Array) *Array {
|
||||
sig := a.Sigmoid()
|
||||
return a.Multiply(sig)
|
||||
}
|
||||
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
|
||||
}
|
||||
|
||||
@@ -148,9 +148,7 @@ type DenseMLP struct {
|
||||
|
||||
// Forward applies the SwiGLU MLP
|
||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// MoEGate implements the expert gating mechanism
|
||||
@@ -242,7 +240,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
||||
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
||||
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
||||
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
hidden = mlx.SwiGLU(gate, up)
|
||||
|
||||
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
||||
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
||||
@@ -250,7 +248,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
||||
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
hidden = mlx.SwiGLU(gate, up)
|
||||
|
||||
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
}
|
||||
@@ -273,9 +271,7 @@ type SharedExperts struct {
|
||||
|
||||
// Forward applies the shared expert MLP
|
||||
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(s.GateProj.Forward(x))
|
||||
up := s.UpProj.Forward(x)
|
||||
return s.DownProj.Forward(mlx.Mul(gate, up))
|
||||
return s.DownProj.Forward(mlx.SwiGLU(s.GateProj.Forward(x), s.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// MoE implements the full Mixture of Experts layer
|
||||
|
||||
@@ -314,5 +314,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
@@ -333,5 +333,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
@@ -1253,7 +1253,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
||||
}
|
||||
|
||||
func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
||||
@@ -1283,13 +1283,13 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
||||
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
|
||||
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
||||
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
hidden = mlx.SwiGLU(gate, up)
|
||||
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
||||
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
||||
} else {
|
||||
gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
|
||||
up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
|
||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
||||
hidden = mlx.SwiGLU(gate, up)
|
||||
down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user