mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
mlx: Improve gemma4 performance with fused operations (#15587)
* mlx: Improve gemma4 performance with fused operations * review comments
This commit is contained in:
@@ -42,3 +42,23 @@ var SwiGLU = Compile2(
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
|
||||
// geglu, used by Gemma-family MLP and MoE paths.
|
||||
var GeGLU = Compile2(
|
||||
"GeGLU",
|
||||
func(gate, up *Array) *Array {
|
||||
return GELUApprox(gate).Multiply(up)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
|
||||
// mlx_lm's logit_softcap. cap must have the same dtype as x.
|
||||
var LogitSoftcap = Compile2(
|
||||
"LogitSoftcap",
|
||||
func(x, cap *Array) *Array {
|
||||
return x.Divide(cap).Tanh().Multiply(cap)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
@@ -80,7 +80,6 @@ type TextConfig struct {
|
||||
PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
||||
PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071...
|
||||
RouterScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
||||
SoftcapInv float32 `json:"-"` // 1/final_logit_softcapping
|
||||
|
||||
// KV sharing: maps shared layer index -> donor layer index.
|
||||
KVShareMap map[int32]int32 `json:"-"`
|
||||
@@ -455,9 +454,6 @@ func parseTextConfig(configData []byte) (TextConfig, error) {
|
||||
cfg.PLECombineScale = float32(math.Pow(2.0, -0.5))
|
||||
}
|
||||
cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize)))
|
||||
if cfg.FinalLogitSoftcapping > 0 {
|
||||
cfg.SoftcapInv = 1.0 / cfg.FinalLogitSoftcapping
|
||||
}
|
||||
|
||||
// Compute KV sharing map.
|
||||
cfg.KVShareMap = make(map[int32]int32)
|
||||
@@ -1114,9 +1110,8 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
logits := m.LMHead.Forward(x)
|
||||
|
||||
if m.FinalLogitSoftcapping > 0 {
|
||||
logits = mlx.MulScalar(logits, m.SoftcapInv)
|
||||
logits = logits.Tanh()
|
||||
logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping)
|
||||
cap := mlx.FromValue(m.FinalLogitSoftcapping).AsType(logits.DType())
|
||||
logits = mlx.LogitSoftcap(logits, cap)
|
||||
}
|
||||
|
||||
return logits
|
||||
@@ -1231,8 +1226,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
|
||||
// PLE injection (after MLP residual).
|
||||
if l.PLE != nil && pleInput != nil {
|
||||
residual := h
|
||||
gate := mlx.GELUApprox(l.PLE.InputGate.Forward(h))
|
||||
gated := mlx.Mul(gate, pleInput)
|
||||
gated := mlx.GeGLU(l.PLE.InputGate.Forward(h), pleInput)
|
||||
projected := l.PLE.Projection.Forward(gated)
|
||||
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
|
||||
h = mlx.Add(residual, projected)
|
||||
@@ -1375,9 +1369,9 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.GELUApprox(m.GateProj.Forward(x))
|
||||
gate := m.GateProj.Forward(x)
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
return m.DownProj.Forward(mlx.GeGLU(gate, up))
|
||||
}
|
||||
|
||||
// Forward runs the router to select top-k experts per token.
|
||||
@@ -1457,13 +1451,13 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
|
||||
up := mlx.SliceStartStop(gateUp,
|
||||
[]int32{0, 0, 0, mid},
|
||||
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
||||
hidden = mlx.GeGLU(gate, up)
|
||||
} else {
|
||||
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
|
||||
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
|
||||
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
|
||||
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
|
||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
||||
hidden = mlx.GeGLU(gate, up)
|
||||
}
|
||||
downMode := m.DownQuantMode
|
||||
if downMode == "" {
|
||||
@@ -1482,11 +1476,11 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
|
||||
up := mlx.SliceStartStop(gateUp,
|
||||
[]int32{0, 0, 0, mid},
|
||||
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
||||
hidden = mlx.GeGLU(gate, up)
|
||||
} else {
|
||||
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
|
||||
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
|
||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
||||
hidden = mlx.GeGLU(gate, up)
|
||||
}
|
||||
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user