mirror of https://github.com/ollama/ollama.git (synced 2026-04-18 04:54:08 +02:00)

Compare commits: pdevine/sa...jmorganca/ (1 commit)

Commit c330ea33ed
@@ -39,12 +39,15 @@ func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
 		if nSeqs > 0 {
 			// 3D tensor: [hiddenDim, seqTokens, nSeqs]
 			if batchSize != seqTokens || nSeqs != seqs {
-				return nil, ErrUnsupportedBatchLayout
+				// Fallback: treat as flat batch if layout doesn't match.
+				hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize*nSeqs)
+				batchSize = batchSize * nSeqs
+			} else {
+				hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
+				batchSize = seqTokens * seqs
 			}
-			hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
-			batchSize = seqTokens * seqs
 		} else if batchSize != seqTokens*seqs {
-			return nil, ErrUnsupportedBatchLayout
+			// Layout mismatch; proceed with flat batch.
 		}
 	}
 }
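
The hunk above drops the hard ErrUnsupportedBatchLayout error in favor of flattening the batch. A minimal, self-contained sketch of that decision in plain Go (illustrative names only, not the real ml API):

package main

import "fmt"

// flatBatchSize sketches the fallback above: when the 3D layout
// [hiddenDim, seqTokens, nSeqs] does not match what the cache reports,
// the tensor is flattened to [hiddenDim, batchSize*nSeqs] instead of
// returning an error. Illustrative only.
func flatBatchSize(batchSize, nSeqs, seqTokens, seqs int) int {
    if batchSize != seqTokens || nSeqs != seqs {
        return batchSize * nSeqs // layout mismatch: treat as flat batch
    }
    return seqTokens * seqs // expected grid layout
}

func main() {
    fmt.Println(flatBatchSize(7, 1, 4, 2)) // 7: mismatched, fall back to flat
    fmt.Println(flatBatchSize(4, 2, 4, 2)) // 8: matches the grid layout
}
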
@@ -64,6 +64,8 @@ type HybridCache struct {
 	curSlots      []int
 	curSlotsInput ml.Tensor
 	curSeqTokens  int
+	// token indices per sequence in batch order
+	curSeqTokenIdxs [][]int32

 	// track if EnsureWritable has been called for this forward pass
 	writableEnsured bool
@@ -168,19 +170,44 @@ func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bo
 	}

 	if len(c.curSeqs) == 0 {
+		c.curSeqTokenIdxs = c.curSeqTokenIdxs[:0]
 		return nil
 	}

+	if cap(c.curSeqTokenIdxs) < len(c.curSeqs) {
+		c.curSeqTokenIdxs = make([][]int32, len(c.curSeqs))
+	} else {
+		c.curSeqTokenIdxs = c.curSeqTokenIdxs[:len(c.curSeqs)]
+	}
+	for i := range c.curSeqTokenIdxs {
+		c.curSeqTokenIdxs[i] = c.curSeqTokenIdxs[i][:0]
+	}
+
+	seqIndex := make(map[int]int, len(c.curSeqs))
+	for i, s := range c.curSeqs {
+		seqIndex[s] = i
+	}
+	for i, s := range batch.Sequences {
+		c.curSeqTokenIdxs[seqIndex[s]] = append(c.curSeqTokenIdxs[seqIndex[s]], int32(i))
+	}
+
 	nTokens := len(batch.Sequences)
 	nSeqs := len(c.curSeqs)
 	want := nTokens / nSeqs
+	uniform := true
 	for _, s := range c.curSeqs {
 		if seqCounts[s] != want {
-			return kvcache.ErrNotSupported
+			uniform = false
+			break
 		}
 	}

-	c.curSeqTokens = want
+	if uniform {
+		c.curSeqTokens = want
+	} else {
+		// Mixed batch: recurrent layers will process sequences independently.
+		c.curSeqTokens = 0
+	}

 	// When reserving memory for estimation, use fake slot assignments
 	if reserve {
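
The StartForward change above buckets token indices per sequence and only treats the batch as uniform when every sequence contributes the same number of tokens. A standalone sketch of that bookkeeping, using plain slices rather than the real HybridCache fields:

package main

import "fmt"

// groupTokenIdxs mirrors the hunk above: token positions are grouped per
// sequence in batch order, and the batch counts as "uniform" only when each
// sequence has the same number of tokens. Simplified, illustrative types.
func groupTokenIdxs(batchSeqs []int, curSeqs []int) (idxs [][]int32, uniform bool) {
    seqIndex := make(map[int]int, len(curSeqs))
    for i, s := range curSeqs {
        seqIndex[s] = i
    }
    idxs = make([][]int32, len(curSeqs))
    for i, s := range batchSeqs {
        idxs[seqIndex[s]] = append(idxs[seqIndex[s]], int32(i))
    }
    want := len(batchSeqs) / len(curSeqs)
    uniform = true
    for _, list := range idxs {
        if len(list) != want {
            uniform = false
            break
        }
    }
    return idxs, uniform
}

func main() {
    // sequence 3 has two tokens, sequence 5 has one -> mixed batch
    idxs, uniform := groupTokenIdxs([]int{3, 3, 5}, []int{3, 5})
    fmt.Println(idxs, uniform) // [[0 1] [2]] false
}
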
@@ -585,7 +612,101 @@ func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Te
 	c.captureDeltaCheckpoint(ctx, layer, srcF32)
 }

-// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
+// convStateForSlot returns the conv state for a single slot as [convDim, convChannels, 1].
+func (c *HybridCache) convStateForSlot(ctx ml.Context, layer int, slot int) (ml.Tensor, error) {
+	c.ensureWritableOnce(ctx)
+	if c.writableError != nil {
+		return nil, c.writableError
+	}
+	buf := c.convBuffer(ctx, layer)
+	slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
+	cur := buf.Rows(ctx, slotIdx)
+	return cur.Reshape(ctx, c.convDim, c.convChannels, 1), nil
+}
+
+// updateConvStateForSlot writes a new conv state for a single slot.
+func (c *HybridCache) updateConvStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) {
+	buf := c.convBuffer(ctx, layer)
+	src := newState.Reshape(ctx, c.convDim*c.convChannels, 1)
+	srcF32 := src.Cast(ctx, ml.DTypeF32)
+	slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
+	ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx))
+	c.captureConvCheckpointForSeq(ctx, layer, seqIndex, srcF32)
+}
+
+// deltaStateForSlot returns the delta state for a single slot as [headVDim, headVDim*numVHeads, 1].
+func (c *HybridCache) deltaStateForSlot(ctx ml.Context, layer int, slot int, headVDim, numVHeads int) (ml.Tensor, error) {
+	c.ensureWritableOnce(ctx)
+	if c.writableError != nil {
+		return nil, c.writableError
+	}
+	buf := c.deltaBuffer(ctx, layer)
+	slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
+	cur := buf.Rows(ctx, slotIdx)
+	return cur.Reshape(ctx, headVDim, headVDim*numVHeads, 1), nil
+}
+
+// updateDeltaStateForSlot writes a new delta state for a single slot.
+func (c *HybridCache) updateDeltaStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) {
+	buf := c.deltaBuffer(ctx, layer)
+	src := newState.Reshape(ctx, c.deltaStateSize, 1)
+	srcF32 := src.Cast(ctx, ml.DTypeF32)
+	slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
+	ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx))
+	c.captureDeltaCheckpointForSeq(ctx, layer, seqIndex, srcF32)
+}
+
+func (c *HybridCache) captureConvCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) {
+	if c.checkpointCount == 0 {
+		return
+	}
+	if c.reserveCheckpoints {
+		c.reserveCheckpointConv(layer)
+		return
+	}
+	if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) {
+		return
+	}
+	pos := c.curCheckpointPos[seqIndex]
+	if pos < 0 {
+		return
+	}
+	slot := c.curSlots[seqIndex]
+	idx := c.checkpointIndexForSlot(slot, pos)
+	if idx < 0 {
+		return
+	}
+	entry := &c.checkpoints[slot].entries[idx]
+	dst := c.ensureCheckpointConv(layer, entry)
+	ctx.Forward(src.Copy(ctx, dst))
+}
+
+func (c *HybridCache) captureDeltaCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) {
+	if c.checkpointCount == 0 {
+		return
+	}
+	if c.reserveCheckpoints {
+		c.reserveCheckpointDelta(layer)
+		return
+	}
+	if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) {
+		return
+	}
+	pos := c.curCheckpointPos[seqIndex]
+	if pos < 0 {
+		return
+	}
+	slot := c.curSlots[seqIndex]
+	idx := c.checkpointIndexForSlot(slot, pos)
+	if idx < 0 {
+		return
+	}
+	entry := &c.checkpoints[slot].entries[idx]
+	dst := c.ensureCheckpointDelta(layer, entry)
+	ctx.Forward(src.Copy(ctx, dst))
+}
+
+// IsSupportedForBatch returns true if the current batch layout supports grid-style recurrent processing.
 func (c *HybridCache) IsSupportedForBatch() bool {
 	return c.curSeqTokens > 0 && len(c.curSeqs) > 0
 }
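
The new per-slot helpers read and write a single row of a layer's state buffer, keyed by slot. As a rough picture of that row gather/scatter with plain slices (the real code goes through ml.Tensor Rows/SetRows on the compute graph):

package main

import "fmt"

// slotBuffer is an illustrative stand-in for a layer's conv or delta buffer:
// one flattened state per cache slot. Reading a slot is a row gather,
// updating it is a row scatter.
type slotBuffer struct {
    rows [][]float32
}

func (b *slotBuffer) stateForSlot(slot int) []float32 {
    return b.rows[slot]
}

func (b *slotBuffer) updateStateForSlot(slot int, newState []float32) {
    copy(b.rows[slot], newState)
}

func main() {
    buf := &slotBuffer{rows: [][]float32{{0, 0}, {0, 0}}}
    buf.updateStateForSlot(1, []float32{3, 4})
    fmt.Println(buf.stateForSlot(1)) // [3 4]
}
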
@@ -48,6 +48,13 @@ type GatedDeltaNet struct {
 	Layer int
 }

+type stateAccessors struct {
+	convState   func() (ml.Tensor, error)
+	updateConv  func(ml.Tensor)
+	deltaState  func() (ml.Tensor, error)
+	updateDelta func(ml.Tensor)
+}
+
 // createMasks builds the constant mask tensors (called once, reused for all chunks)
 func createMasks(ctx ml.Context) *Masks {
 	ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
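
The stateAccessors struct lets the same recurrent math run against either the whole-batch cache state or one slot's state, depending on which closures the caller wires in. A hypothetical, simplified illustration with float64 standing in for ml.Tensor:

package main

import "fmt"

// accessors mimics stateAccessors above: the compute step only ever calls
// the closures, so the caller decides where the state actually lives.
type accessors struct {
    state  func() (float64, error)
    update func(float64)
}

func step(a accessors) error {
    s, err := a.state()
    if err != nil {
        return err
    }
    a.update(s + 1) // stand-in for the real recurrence
    return nil
}

func main() {
    store := map[int]float64{7: 41}
    slot := 7
    a := accessors{
        state:  func() (float64, error) { return store[slot], nil },
        update: func(v float64) { store[slot] = v },
    }
    _ = step(a)
    fmt.Println(store[slot]) // 42
}
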
@@ -68,7 +75,6 @@ func createMasks(ctx ml.Context) *Masks {
 }

 func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
-	layer := gdn.Layer
 	nSeqTokens := hiddenStates.Dim(1)
 	nSeqs := hiddenStates.Dim(2)
 	if cache != nil && cache.IsSupportedForBatch() {
@@ -77,34 +83,140 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
 		if seqTokens > 0 && seqs > 0 {
 			if nSeqs > 1 {
 				if nSeqTokens != seqTokens || nSeqs != seqs {
-					return nil, ErrUnsupportedBatchLayout
+					return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
 				}
 			} else {
 				if nSeqTokens != seqTokens*seqs {
-					return nil, ErrUnsupportedBatchLayout
+					return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
 				}
 				hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
-				nSeqTokens = seqTokens
-				nSeqs = seqs
 			}
 		}
+
+		numVHeads := opts.ssmDtRank
+		headVDim := opts.ssmDInner / numVHeads
+		layer := gdn.Layer
+		access := stateAccessors{
+			convState: func() (ml.Tensor, error) {
+				return cache.ConvState(ctx, layer)
+			},
+			updateConv: func(newState ml.Tensor) {
+				cache.UpdateConvState(ctx, layer, newState)
+			},
+			deltaState: func() (ml.Tensor, error) {
+				return cache.DeltaState(ctx, layer, headVDim, numVHeads)
+			},
+			updateDelta: func(newState ml.Tensor) {
+				cache.UpdateDeltaState(ctx, layer, newState)
+			},
+		}
+
+		return gdn.forwardWithAccessors(ctx, hiddenStates, opts, access)
 	}

+	if cache == nil {
+		return nil, ErrUnsupportedBatchLayout
+	}
+
+	return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
+}
+
+func (gdn *GatedDeltaNet) forwardMixed(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
+	if hiddenStates.Dim(2) > 0 {
+		hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2))
+	}
+
+	if len(cache.curSeqs) == 0 {
+		return hiddenStates, nil
+	}
+
+	// Ensure any shared slots are detached once for this forward pass.
+	cache.ensureWritableOnce(ctx)
+
+	layer := gdn.Layer
+	numVHeads := opts.ssmDtRank
+	headVDim := opts.ssmDInner / numVHeads
+
+	if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
+		return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
+	}
+
+	// Precompute projections for the full batch and slice per sequence.
+	mixedBAFull := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
+	qkvMixedFull := gdn.SSMQKV.Forward(ctx, hiddenStates)
+	zFull := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
+
+	out := hiddenStates
+	for seqIndex := range cache.curSeqs {
+		idxs := cache.curSeqTokenIdxs[seqIndex]
+		if len(idxs) == 0 {
+			continue
+		}
+		idxTensor := ctx.Input().FromInts(idxs, len(idxs))
+
+		mixedBA := mixedBAFull.Rows(ctx, idxTensor)
+		qkvMixed := qkvMixedFull.Rows(ctx, idxTensor)
+		z := zFull.Rows(ctx, idxTensor)
+
+		slot := cache.curSlots[seqIndex]
+		access := stateAccessors{
+			convState: func() (ml.Tensor, error) {
+				return cache.convStateForSlot(ctx, layer, slot)
+			},
+			updateConv: func(newState ml.Tensor) {
+				cache.updateConvStateForSlot(ctx, layer, slot, seqIndex, newState)
+			},
+			deltaState: func() (ml.Tensor, error) {
+				return cache.deltaStateForSlot(ctx, layer, slot, headVDim, numVHeads)
+			},
+			updateDelta: func(newState ml.Tensor) {
+				cache.updateDeltaStateForSlot(ctx, layer, slot, seqIndex, newState)
+			},
+		}
+
+		seqOut, err := gdn.forwardProjected(ctx, len(idxs), 1, mixedBA, qkvMixed, z, opts, access)
+		if err != nil {
+			return nil, err
+		}
+		out = out.SetRows(ctx, seqOut, idxTensor)
+	}
+
+	return out, nil
+}
+
+func (gdn *GatedDeltaNet) forwardWithAccessors(ctx ml.Context, hiddenStates ml.Tensor, opts *Options, access stateAccessors) (ml.Tensor, error) {
+	nSeqTokens := hiddenStates.Dim(1)
+	nSeqs := hiddenStates.Dim(2)
+
+	mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
+
+	if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
+		return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
+	}
+	// Optimized path: pre-split QKV and gate
+	qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates)
+	z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
+
+	return gdn.forwardProjected(ctx, nSeqTokens, nSeqs, mixedBA, qkvMixed, z, opts, access)
+}
+
+func (gdn *GatedDeltaNet) forwardProjected(
+	ctx ml.Context,
+	nSeqTokens, nSeqs int,
+	mixedBA, qkvMixed, z ml.Tensor,
+	opts *Options,
+	access stateAccessors,
+) (ml.Tensor, error) {
+	layer := gdn.Layer
+
 	headKDim := opts.ssmDState
 	numKHeads := opts.ssmNGroup
 	numVHeads := opts.ssmDtRank
 	headVDim := opts.ssmDInner / numVHeads
 	convKernelSize := opts.convKernelSize

-	mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
 	qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads

-	if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
-		return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
-	}
-	// Optimized path: pre-split QKV and gate
-	qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
-	z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
+	qkvMixed = qkvMixed.Reshape(ctx, qkvDim, nSeqTokens, nSeqs)

 	baNewDim := 2 * numVHeads / numKHeads
 	mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
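
forwardMixed gathers each sequence's token rows, runs the recurrent step for that sequence alone, and scatters the per-sequence output back into the batch. The same gather/compute/scatter shape in a self-contained sketch (plain slices and a hypothetical step function, not the real tensor code):

package main

import "fmt"

// processMixed applies step to each sequence's rows independently and writes
// the results back at the original batch positions.
func processMixed(batch []float64, perSeqIdxs [][]int32, step func([]float64) []float64) []float64 {
    out := make([]float64, len(batch))
    copy(out, batch)
    for _, idxs := range perSeqIdxs {
        if len(idxs) == 0 {
            continue
        }
        rows := make([]float64, len(idxs)) // gather this sequence's tokens
        for i, idx := range idxs {
            rows[i] = batch[idx]
        }
        res := step(rows) // per-sequence compute
        for i, idx := range idxs {
            out[idx] = res[i] // scatter back into batch order
        }
    }
    return out
}

func main() {
    double := func(xs []float64) []float64 {
        ys := make([]float64, len(xs))
        for i, x := range xs {
            ys[i] = 2 * x
        }
        return ys
    }
    fmt.Println(processMixed([]float64{1, 2, 3}, [][]int32{{0, 1}, {2}}, double)) // [2 4 6]
}
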
@@ -127,7 +239,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
 	qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)

 	// Get conv state from cache
-	convStates, err := cache.ConvState(ctx, layer)
+	convStates, err := access.convState()
 	if err != nil {
 		// Log this - if it happens, short-term context will be lost
 		slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
@@ -142,7 +254,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac

 	// Save new conv state (last convKernelSize-1 tokens)
 	lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
-	cache.UpdateConvState(ctx, layer, lastConvStates)
+	access.updateConv(lastConvStates)

 	// Apply SSM convolution (kernel must be F32 for Metal)
 	convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
@@ -162,7 +274,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
 	vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)

 	// Get delta state from cache
-	state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
+	state, err := access.deltaState()
 	if err != nil {
 		// Log this - if it happens frequently, context will degrade
 		slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
@@ -185,14 +297,19 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
 	}

 	// Choose computation mode based on sequence length
-	var attnOut ml.Tensor
+	var (
+		attnOut  ml.Tensor
+		newState ml.Tensor
+	)
 	if nSeqTokens == 1 {
-		attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
+		attnOut, newState = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts)
 	} else {
 		// Use pre-computed masks from opts (created once in Model.Forward)
-		attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
+		attnOut, newState = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts)
 	}

+	access.updateDelta(newState)
+
 	// Apply gated normalization
 	attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
 	z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
@@ -215,9 +332,7 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive(
 	ctx ml.Context,
 	q, k, v, gate, beta, state ml.Tensor,
 	opts *Options,
-	layer int,
-	cache *HybridCache,
-) ml.Tensor {
+) (ml.Tensor, ml.Tensor) {
 	numVHeads := v.Dim(1)
 	headVDim := v.Dim(0)
 	nSeqs := q.Dim(3)
@@ -273,10 +388,8 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive(
 	coreAttnOut := stateQ.SumRows(ctx)
 	coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)

-	// Update delta state in cache
-	cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
-
-	return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
+	newState := state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)
+	return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs), newState
 }

 // deltaNetChunked implements chunked computation for prefill.
@@ -286,9 +399,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
 	q, k, v, gate, beta, state ml.Tensor,
 	masks *Masks,
 	opts *Options,
-	layer int,
-	cache *HybridCache,
-) ml.Tensor {
+) (ml.Tensor, ml.Tensor) {
 	headKDim := q.Dim(0)
 	numVHeads := v.Dim(1)
 	headVDim := v.Dim(0)
@@ -465,8 +576,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
 		coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
 	}

-	// Update delta state in cache
-	cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
-
-	return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
+	newStateFlat := newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)
+	return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs), newStateFlat
 }