mlx: Improve gemma4 performance with fused operations (#15587)

* mlx: Improve gemma4 performance with fused operations

* review comments
This commit is contained in:
Daniel Hiltgen
2026-04-14 18:04:04 -07:00
committed by GitHub
parent e1e3cec8d0
commit 48ad7085c4
2 changed files with 29 additions and 15 deletions

View File

@@ -42,3 +42,23 @@ var SwiGLU = Compile2(
},
Shapeless(),
)
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
// geglu, used by Gemma-family MLP and MoE paths.
var GeGLU = Compile2(
"GeGLU",
func(gate, up *Array) *Array {
return GELUApprox(gate).Multiply(up)
},
Shapeless(),
)
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
// mlx_lm's logit_softcap. cap must have the same dtype as x.
var LogitSoftcap = Compile2(
"LogitSoftcap",
func(x, cap *Array) *Array {
return x.Divide(cap).Tanh().Multiply(cap)
},
Shapeless(),
)

View File

@@ -80,7 +80,6 @@ type TextConfig struct {
PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size)
PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071...
RouterScale float32 `json:"-"` // 1/sqrt(hidden_size)
SoftcapInv float32 `json:"-"` // 1/final_logit_softcapping
// KV sharing: maps shared layer index -> donor layer index.
KVShareMap map[int32]int32 `json:"-"`
@@ -455,9 +454,6 @@ func parseTextConfig(configData []byte) (TextConfig, error) {
cfg.PLECombineScale = float32(math.Pow(2.0, -0.5))
}
cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize)))
if cfg.FinalLogitSoftcapping > 0 {
cfg.SoftcapInv = 1.0 / cfg.FinalLogitSoftcapping
}
// Compute KV sharing map.
cfg.KVShareMap = make(map[int32]int32)
@@ -1114,9 +1110,8 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
logits := m.LMHead.Forward(x)
if m.FinalLogitSoftcapping > 0 {
logits = mlx.MulScalar(logits, m.SoftcapInv)
logits = logits.Tanh()
logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping)
cap := mlx.FromValue(m.FinalLogitSoftcapping).AsType(logits.DType())
logits = mlx.LogitSoftcap(logits, cap)
}
return logits
@@ -1231,8 +1226,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
// PLE injection (after MLP residual).
if l.PLE != nil && pleInput != nil {
residual := h
gate := mlx.GELUApprox(l.PLE.InputGate.Forward(h))
gated := mlx.Mul(gate, pleInput)
gated := mlx.GeGLU(l.PLE.InputGate.Forward(h), pleInput)
projected := l.PLE.Projection.Forward(gated)
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
h = mlx.Add(residual, projected)
@@ -1375,9 +1369,9 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.GELUApprox(m.GateProj.Forward(x))
gate := m.GateProj.Forward(x)
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
return m.DownProj.Forward(mlx.GeGLU(gate, up))
}
// Forward runs the router to select top-k experts per token.
@@ -1457,13 +1451,13 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
up := mlx.SliceStartStop(gateUp,
[]int32{0, 0, 0, mid},
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
hidden = mlx.GeGLU(gate, up)
} else {
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
hidden = mlx.GeGLU(gate, up)
}
downMode := m.DownQuantMode
if downMode == "" {
@@ -1482,11 +1476,11 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
up := mlx.SliceStartStop(gateUp,
[]int32{0, 0, 0, mid},
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
hidden = mlx.GeGLU(gate, up)
} else {
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
hidden = mlx.GeGLU(gate, up)
}
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
}