mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
mlx: fix gemma4 cache to use logical view (#15617)
This commit is contained in:
@@ -1061,14 +1061,12 @@ func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
}
|
||||
}
|
||||
|
||||
h = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
|
||||
var donorKV *sharedKVEntry
|
||||
h, donorKV = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
|
||||
|
||||
// If this layer is a donor, store its cached KV for later shared layers.
|
||||
if layer.IsDonor && c != nil {
|
||||
state := c.State()
|
||||
if len(state) >= 2 && state[0] != nil && state[1] != nil {
|
||||
sharedKV[layer.LayerIdx] = sharedKVEntry{K: state[0], V: state[1], Offset: c.Offset()}
|
||||
}
|
||||
if layer.IsDonor && donorKV != nil {
|
||||
sharedKV[layer.LayerIdx] = *donorKV
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1190,9 +1188,9 @@ func sliceLayerDim(combined *mlx.Array, layerIdx, B, L, pleDim int32) *mlx.Array
|
||||
return mlx.Squeeze(sliced, 2)
|
||||
}
|
||||
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
|
||||
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
|
||||
attnOut, donorKV := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
|
||||
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
@@ -1237,10 +1235,10 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
|
||||
h = mlx.Mul(h, l.LayerScalar)
|
||||
}
|
||||
|
||||
return h
|
||||
return h, donorKV
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
|
||||
// Determine head dim and scale based on layer type.
|
||||
headDim := cfg.HeadDim
|
||||
scale := cfg.SlidingScale
|
||||
@@ -1274,6 +1272,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
||||
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
|
||||
|
||||
var k, v *mlx.Array
|
||||
var donorKV *sharedKVEntry
|
||||
|
||||
if donorEntry != nil {
|
||||
// Shared layer: use donor's cached K/V.
|
||||
@@ -1312,6 +1311,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
||||
// Update cache.
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
donorKV = &sharedKVEntry{K: k, V: v, Offset: c.Offset()}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1365,7 +1365,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
||||
// strided views differently. Metal handles them natively.
|
||||
out = mlx.Contiguous(out, false)
|
||||
}
|
||||
return a.OProj.Forward(out)
|
||||
return a.OProj.Forward(out), donorKV
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
Reference in New Issue
Block a user