diff --git a/x/mlxrunner/mlx/act.go b/x/mlxrunner/mlx/act.go index 659b797c4..1563de60a 100644 --- a/x/mlxrunner/mlx/act.go +++ b/x/mlxrunner/mlx/act.go @@ -42,3 +42,23 @@ var SwiGLU = Compile2( }, Shapeless(), ) + +// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's +// geglu, used by Gemma-family MLP and MoE paths. +var GeGLU = Compile2( + "GeGLU", + func(gate, up *Array) *Array { + return GELUApprox(gate).Multiply(up) + }, + Shapeless(), +) + +// LogitSoftcap returns tanh(x / softcap) * softcap as a fused kernel. Matches +// mlx_lm's logit_softcap. softcap must have the same dtype as x. +var LogitSoftcap = Compile2( + "LogitSoftcap", + func(x, softcap *Array) *Array { + return x.Divide(softcap).Tanh().Multiply(softcap) + }, + Shapeless(), +) diff --git a/x/models/gemma4/gemma4.go b/x/models/gemma4/gemma4.go index 90737f813..1800100bf 100644 --- a/x/models/gemma4/gemma4.go +++ b/x/models/gemma4/gemma4.go @@ -80,7 +80,6 @@ type TextConfig struct { PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size) PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071... RouterScale float32 `json:"-"` // 1/sqrt(hidden_size) - SoftcapInv float32 `json:"-"` // 1/final_logit_softcapping // KV sharing: maps shared layer index -> donor layer index. KVShareMap map[int32]int32 `json:"-"` @@ -455,9 +454,6 @@ func parseTextConfig(configData []byte) (TextConfig, error) { cfg.PLECombineScale = float32(math.Pow(2.0, -0.5)) } cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize))) - if cfg.FinalLogitSoftcapping > 0 { - cfg.SoftcapInv = 1.0 / cfg.FinalLogitSoftcapping - } // Compute KV sharing map. 
cfg.KVShareMap = make(map[int32]int32) @@ -1114,9 +1110,8 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array { logits := m.LMHead.Forward(x) if m.FinalLogitSoftcapping > 0 { - logits = mlx.MulScalar(logits, m.SoftcapInv) - logits = logits.Tanh() - logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping) + softcap := mlx.FromValue(m.FinalLogitSoftcapping).AsType(logits.DType()) // same dtype as logits, as LogitSoftcap requires + logits = mlx.LogitSoftcap(logits, softcap) } return logits @@ -1231,8 +1226,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex // PLE injection (after MLP residual). if l.PLE != nil && pleInput != nil { residual := h - gate := mlx.GELUApprox(l.PLE.InputGate.Forward(h)) - gated := mlx.Mul(gate, pleInput) + gated := mlx.GeGLU(l.PLE.InputGate.Forward(h), pleInput) projected := l.PLE.Projection.Forward(gated) projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps) h = mlx.Add(residual, projected) @@ -1375,9 +1369,9 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b } func (m *MLP) Forward(x *mlx.Array) *mlx.Array { - gate := mlx.GELUApprox(m.GateProj.Forward(x)) + gate := m.GateProj.Forward(x) up := m.UpProj.Forward(x) - return m.DownProj.Forward(mlx.Mul(gate, up)) + return m.DownProj.Forward(mlx.GeGLU(gate, up)) } // Forward runs the router to select top-k experts per token. 
@@ -1457,13 +1451,13 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi up := mlx.SliceStartStop(gateUp, []int32{0, 0, 0, mid}, []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])}) - hidden = mlx.Mul(mlx.GELUApprox(gate), up) + hidden = mlx.GeGLU(gate, up) } else { gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases, nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort) up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases, nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort) - hidden = mlx.Mul(mlx.GELUApprox(gate), up) + hidden = mlx.GeGLU(gate, up) } downMode := m.DownQuantMode if downMode == "" { @@ -1482,11 +1476,11 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi up := mlx.SliceStartStop(gateUp, []int32{0, 0, 0, mid}, []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])}) - hidden = mlx.Mul(mlx.GELUApprox(gate), up) + hidden = mlx.GeGLU(gate, up) } else { gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort) up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort) - hidden = mlx.Mul(mlx.GELUApprox(gate), up) + hidden = mlx.GeGLU(gate, up) } down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort) }