mlx: Improve gemma4 performance with fused operations (#15587)

* mlx: Improve gemma4 performance with fused operations

* review comments
This commit is contained in:
Daniel Hiltgen
2026-04-14 18:04:04 -07:00
committed by GitHub
parent e1e3cec8d0
commit 48ad7085c4
2 changed files with 29 additions and 15 deletions

View File

@@ -42,3 +42,23 @@ var SwiGLU = Compile2(
}, },
Shapeless(), Shapeless(),
) )
// GeGLU is the fused GELU-gated linear unit: gelu_approx(gate) * up.
// It mirrors mlx_lm's geglu and backs the Gemma-family MLP and MoE paths.
var GeGLU = Compile2(
	"GeGLU",
	func(gate, up *Array) *Array {
		activated := GELUApprox(gate)
		return activated.Multiply(up)
	},
	Shapeless(),
)
// LogitSoftcap is the fused logit soft-capping kernel: tanh(x / cap) * cap.
// It mirrors mlx_lm's logit_softcap; the cap argument must have the same
// dtype as x.
var LogitSoftcap = Compile2(
	"LogitSoftcap",
	func(x, limit *Array) *Array {
		scaled := x.Divide(limit)
		return scaled.Tanh().Multiply(limit)
	},
	Shapeless(),
)

View File

@@ -80,7 +80,6 @@ type TextConfig struct {
PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size) PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size)
PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071... PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071...
RouterScale float32 `json:"-"` // 1/sqrt(hidden_size) RouterScale float32 `json:"-"` // 1/sqrt(hidden_size)
SoftcapInv float32 `json:"-"` // 1/final_logit_softcapping
// KV sharing: maps shared layer index -> donor layer index. // KV sharing: maps shared layer index -> donor layer index.
KVShareMap map[int32]int32 `json:"-"` KVShareMap map[int32]int32 `json:"-"`
@@ -455,9 +454,6 @@ func parseTextConfig(configData []byte) (TextConfig, error) {
cfg.PLECombineScale = float32(math.Pow(2.0, -0.5)) cfg.PLECombineScale = float32(math.Pow(2.0, -0.5))
} }
cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize))) cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize)))
if cfg.FinalLogitSoftcapping > 0 {
cfg.SoftcapInv = 1.0 / cfg.FinalLogitSoftcapping
}
// Compute KV sharing map. // Compute KV sharing map.
cfg.KVShareMap = make(map[int32]int32) cfg.KVShareMap = make(map[int32]int32)
@@ -1114,9 +1110,8 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
logits := m.LMHead.Forward(x) logits := m.LMHead.Forward(x)
if m.FinalLogitSoftcapping > 0 { if m.FinalLogitSoftcapping > 0 {
logits = mlx.MulScalar(logits, m.SoftcapInv) cap := mlx.FromValue(m.FinalLogitSoftcapping).AsType(logits.DType())
logits = logits.Tanh() logits = mlx.LogitSoftcap(logits, cap)
logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping)
} }
return logits return logits
@@ -1231,8 +1226,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
// PLE injection (after MLP residual). // PLE injection (after MLP residual).
if l.PLE != nil && pleInput != nil { if l.PLE != nil && pleInput != nil {
residual := h residual := h
gate := mlx.GELUApprox(l.PLE.InputGate.Forward(h)) gated := mlx.GeGLU(l.PLE.InputGate.Forward(h), pleInput)
gated := mlx.Mul(gate, pleInput)
projected := l.PLE.Projection.Forward(gated) projected := l.PLE.Projection.Forward(gated)
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps) projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
h = mlx.Add(residual, projected) h = mlx.Add(residual, projected)
@@ -1375,9 +1369,9 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
} }
func (m *MLP) Forward(x *mlx.Array) *mlx.Array { func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.GELUApprox(m.GateProj.Forward(x)) gate := m.GateProj.Forward(x)
up := m.UpProj.Forward(x) up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up)) return m.DownProj.Forward(mlx.GeGLU(gate, up))
} }
// Forward runs the router to select top-k experts per token. // Forward runs the router to select top-k experts per token.
@@ -1457,13 +1451,13 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
up := mlx.SliceStartStop(gateUp, up := mlx.SliceStartStop(gateUp,
[]int32{0, 0, 0, mid}, []int32{0, 0, 0, mid},
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])}) []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
hidden = mlx.Mul(mlx.GELUApprox(gate), up) hidden = mlx.GeGLU(gate, up)
} else { } else {
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases, gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort) nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases, up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort) nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
hidden = mlx.Mul(mlx.GELUApprox(gate), up) hidden = mlx.GeGLU(gate, up)
} }
downMode := m.DownQuantMode downMode := m.DownQuantMode
if downMode == "" { if downMode == "" {
@@ -1482,11 +1476,11 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
up := mlx.SliceStartStop(gateUp, up := mlx.SliceStartStop(gateUp,
[]int32{0, 0, 0, mid}, []int32{0, 0, 0, mid},
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])}) []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
hidden = mlx.Mul(mlx.GELUApprox(gate), up) hidden = mlx.GeGLU(gate, up)
} else { } else {
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort) gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort) up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.GELUApprox(gate), up) hidden = mlx.GeGLU(gate, up)
} }
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort) down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
} }