mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
mlx: Improve gemma4 performance with fused operations (#15587)
* mlx: Improve gemma4 performance with fused operations * review comments
This commit is contained in:
@@ -42,3 +42,23 @@ var SwiGLU = Compile2(
|
|||||||
},
|
},
|
||||||
Shapeless(),
|
Shapeless(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
|
||||||
|
// geglu, used by Gemma-family MLP and MoE paths.
|
||||||
|
var GeGLU = Compile2(
|
||||||
|
"GeGLU",
|
||||||
|
func(gate, up *Array) *Array {
|
||||||
|
return GELUApprox(gate).Multiply(up)
|
||||||
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
|
||||||
|
// mlx_lm's logit_softcap. cap must have the same dtype as x.
|
||||||
|
var LogitSoftcap = Compile2(
|
||||||
|
"LogitSoftcap",
|
||||||
|
func(x, cap *Array) *Array {
|
||||||
|
return x.Divide(cap).Tanh().Multiply(cap)
|
||||||
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|||||||
@@ -80,7 +80,6 @@ type TextConfig struct {
|
|||||||
PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
||||||
PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071...
|
PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071...
|
||||||
RouterScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
RouterScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
||||||
SoftcapInv float32 `json:"-"` // 1/final_logit_softcapping
|
|
||||||
|
|
||||||
// KV sharing: maps shared layer index -> donor layer index.
|
// KV sharing: maps shared layer index -> donor layer index.
|
||||||
KVShareMap map[int32]int32 `json:"-"`
|
KVShareMap map[int32]int32 `json:"-"`
|
||||||
@@ -455,9 +454,6 @@ func parseTextConfig(configData []byte) (TextConfig, error) {
|
|||||||
cfg.PLECombineScale = float32(math.Pow(2.0, -0.5))
|
cfg.PLECombineScale = float32(math.Pow(2.0, -0.5))
|
||||||
}
|
}
|
||||||
cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize)))
|
cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize)))
|
||||||
if cfg.FinalLogitSoftcapping > 0 {
|
|
||||||
cfg.SoftcapInv = 1.0 / cfg.FinalLogitSoftcapping
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute KV sharing map.
|
// Compute KV sharing map.
|
||||||
cfg.KVShareMap = make(map[int32]int32)
|
cfg.KVShareMap = make(map[int32]int32)
|
||||||
@@ -1114,9 +1110,8 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
|||||||
logits := m.LMHead.Forward(x)
|
logits := m.LMHead.Forward(x)
|
||||||
|
|
||||||
if m.FinalLogitSoftcapping > 0 {
|
if m.FinalLogitSoftcapping > 0 {
|
||||||
logits = mlx.MulScalar(logits, m.SoftcapInv)
|
cap := mlx.FromValue(m.FinalLogitSoftcapping).AsType(logits.DType())
|
||||||
logits = logits.Tanh()
|
logits = mlx.LogitSoftcap(logits, cap)
|
||||||
logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
@@ -1231,8 +1226,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
|
|||||||
// PLE injection (after MLP residual).
|
// PLE injection (after MLP residual).
|
||||||
if l.PLE != nil && pleInput != nil {
|
if l.PLE != nil && pleInput != nil {
|
||||||
residual := h
|
residual := h
|
||||||
gate := mlx.GELUApprox(l.PLE.InputGate.Forward(h))
|
gated := mlx.GeGLU(l.PLE.InputGate.Forward(h), pleInput)
|
||||||
gated := mlx.Mul(gate, pleInput)
|
|
||||||
projected := l.PLE.Projection.Forward(gated)
|
projected := l.PLE.Projection.Forward(gated)
|
||||||
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
|
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
|
||||||
h = mlx.Add(residual, projected)
|
h = mlx.Add(residual, projected)
|
||||||
@@ -1375,9 +1369,9 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
gate := mlx.GELUApprox(m.GateProj.Forward(x))
|
gate := m.GateProj.Forward(x)
|
||||||
up := m.UpProj.Forward(x)
|
up := m.UpProj.Forward(x)
|
||||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
return m.DownProj.Forward(mlx.GeGLU(gate, up))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward runs the router to select top-k experts per token.
|
// Forward runs the router to select top-k experts per token.
|
||||||
@@ -1457,13 +1451,13 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
|
|||||||
up := mlx.SliceStartStop(gateUp,
|
up := mlx.SliceStartStop(gateUp,
|
||||||
[]int32{0, 0, 0, mid},
|
[]int32{0, 0, 0, mid},
|
||||||
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
} else {
|
} else {
|
||||||
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
|
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
|
||||||
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
|
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
|
||||||
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
|
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
|
||||||
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
|
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
}
|
}
|
||||||
downMode := m.DownQuantMode
|
downMode := m.DownQuantMode
|
||||||
if downMode == "" {
|
if downMode == "" {
|
||||||
@@ -1482,11 +1476,11 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
|
|||||||
up := mlx.SliceStartStop(gateUp,
|
up := mlx.SliceStartStop(gateUp,
|
||||||
[]int32{0, 0, 0, mid},
|
[]int32{0, 0, 0, mid},
|
||||||
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
} else {
|
} else {
|
||||||
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
|
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
|
||||||
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
|
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
}
|
}
|
||||||
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
|
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user