mlx: fix gemma4 cache to use logical view (#15617)

This commit is contained in:
Daniel Hiltgen
2026-04-16 11:54:30 -07:00
committed by GitHub
parent 031baef094
commit b9cb535407

View File

@@ -1061,14 +1061,12 @@ func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
}
}
h = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
var donorKV *sharedKVEntry
h, donorKV = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
// If this layer is a donor, store its cached KV for later shared layers.
if layer.IsDonor && c != nil {
state := c.State()
if len(state) >= 2 && state[0] != nil && state[1] != nil {
sharedKV[layer.LayerIdx] = sharedKVEntry{K: state[0], V: state[1], Offset: c.Offset()}
}
if layer.IsDonor && donorKV != nil {
sharedKV[layer.LayerIdx] = *donorKV
}
}
@@ -1190,9 +1188,9 @@ func sliceLayerDim(combined *mlx.Array, layerIdx, B, L, pleDim int32) *mlx.Array
return mlx.Squeeze(sliced, 2)
}
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
attnOut, donorKV := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
@@ -1237,10 +1235,10 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
h = mlx.Mul(h, l.LayerScalar)
}
return h
return h, donorKV
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
// Determine head dim and scale based on layer type.
headDim := cfg.HeadDim
scale := cfg.SlidingScale
@@ -1274,6 +1272,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
var k, v *mlx.Array
var donorKV *sharedKVEntry
if donorEntry != nil {
// Shared layer: use donor's cached K/V.
@@ -1312,6 +1311,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
// Update cache.
if c != nil {
k, v = c.Update(k, v)
donorKV = &sharedKVEntry{K: k, V: v, Offset: c.Offset()}
}
}
@@ -1365,7 +1365,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
// strided views differently. Metal handles them natively.
out = mlx.Contiguous(out, false)
}
return a.OProj.Forward(out)
return a.OProj.Forward(out), donorKV
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {