package qwen3next

import (
	"errors"
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
// with the attention layer requirements.
var ErrUnsupportedBatchLayout = errors.New("qwen3next: unsupported batch layout")

// FullAttention implements gated attention with QK normalization and sigmoid-gated output.
// Key differences from standard attention:
// - Q projection outputs 2x size (Q + gate interleaved)
// - Both Q and K have RMSNorm
// - Output is gated: attn * sigmoid(gate)
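// The gguf struct tags name the tensor each field is read from, so the
// projection and norm weights are populated automatically when the model
// file is loaded.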
type FullAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"` // outputs [n_embd_head * 2, n_head]
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
	// Use Dim() instead of Shape() for consistent behavior during graph construction
	hiddenDim := hiddenStates.Dim(0)
	batchSize := hiddenStates.Dim(1)
	nSeqs := hiddenStates.Dim(2) // 0 if 2D tensor
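
	// The hybrid cache reports a uniform batch layout of numSeqs sequences with
	// seqTokens tokens each; reject hidden states whose shape disagrees before
	// reshaping, so tokens from different sequences are never mixed.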
	if cache != nil && cache.IsSupportedForBatch() {
		seqTokens := cache.seqTokens()
		seqs := cache.numSeqs()
		if seqTokens > 0 && seqs > 0 {
			if nSeqs > 0 {
				// 3D tensor: [hiddenDim, seqTokens, nSeqs]
				if batchSize != seqTokens || nSeqs != seqs {
					return nil, ErrUnsupportedBatchLayout
				}
				hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
				batchSize = seqTokens * seqs
			} else if batchSize != seqTokens*seqs {
				return nil, ErrUnsupportedBatchLayout
			}
		}
	}
	headDim := opts.headDim()
	numHeads := opts.numHeads

	// Q projection outputs query + gate interleaved
	qFull := sa.Query.Forward(ctx, hiddenStates)

	// Reshape to [headDim * 2, numHeads, batchSize]
	qFull = qFull.Reshape(ctx, headDim*2, numHeads, batchSize)

	// Split Q and gate along dimension 0
	// Q: first headDim elements, gate: second headDim elements
	query := qFull.Slice(ctx, 0, 0, headDim, 1)
	gate := qFull.Slice(ctx, 0, headDim, headDim*2, 1)

	// Make query contiguous for further operations
	query = query.Contiguous(ctx, headDim, numHeads, batchSize)

	// K and V projections
	key := sa.Key.Forward(ctx, hiddenStates)
	value := sa.Value.Forward(ctx, hiddenStates)

	// Derive numKVHeads from tensor dimensions (per-layer value)
	numKVHeads := key.Dim(0) / headDim

	key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
	value = value.Reshape(ctx, headDim, numKVHeads, batchSize)

	// Apply QK normalization
	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
	key = sa.KeyNorm.Forward(ctx, key, opts.eps)

	// Apply RoPE
	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

	// Standard attention computation
	scale := opts.attentionScale
	if scale == 0 {
		scale = 1.0 / math.Sqrt(float64(headDim))
	}
	attention := nn.Attention(ctx, query, key, value, scale, cache)

	// Flatten heads
	attention = attention.Reshape(ctx, headDim*numHeads, batchSize)

	// Apply sigmoid gate
	// gate shape: [headDim, numHeads, batchSize] -> [headDim*numHeads, batchSize]
	gate = gate.Contiguous(ctx, headDim*numHeads, batchSize)
	gateSigmoid := gate.Sigmoid(ctx)
	attention = attention.Mul(ctx, gateSigmoid)

	return sa.Output.Forward(ctx, attention), nil
}
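
// Usage sketch (illustrative; the surrounding block type and its field names
// are assumptions, not part of this file): a decoder block would typically call
// Forward with its hidden states, the RoPE positions for the batch, and the
// shared HybridCache, then feed the result into its residual connection:
//
//	attnOut, err := blk.SelfAttention.Forward(ctx, hiddenStates, positions, cache, opts)
//	if err != nil {
//		return nil, err
//	}
//	hiddenStates = hiddenStates.Add(ctx, attnOut)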