package qwen3next

import (
	"errors"
	"log/slog"
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

const chunkSize = 64

// TriType constants for triangular matrix operations
const (
	TriTypeUpperDiag = 0
	TriTypeUpper     = 1
	TriTypeLowerDiag = 2
	TriTypeLower     = 3
)
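
// Note: as used in createMasks below, Tri with TriTypeLower produces a
// strictly lower-triangular mask (diagonal excluded); the diagonal is added
// back separately via an identity matrix.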

// convKernel wraps the 1D convolution kernel tensor
type convKernel struct {
	Weight ml.Tensor `gguf:"weight"`
}

// Masks holds pre-computed mask tensors for chunked attention
type Masks struct {
	Causal   ml.Tensor // Strictly lower triangular [chunkSize, chunkSize]
	Identity ml.Tensor // Identity (diagonal) [chunkSize, chunkSize]
	Diag     ml.Tensor // Causal + Identity (lower triangular including the diagonal)
}

// GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
// It implements the Operator interface directly.
type GatedDeltaNet struct {
	// Optimized path: pre-split QKV and gate
	SSMQKV       *nn.Linear  `gguf:"attn_qkv"`   // -> Q, K, V (concatenated)
	SSMQKVGate   *nn.Linear  `gguf:"attn_gate"`  // -> Z gate
	SSMBetaAlpha *nn.Linear  `gguf:"ssm_ba"`     // -> beta, alpha (legacy qwen3next)
	SSMBeta      *nn.Linear  `gguf:"ssm_beta"`   // -> beta (qwen35)
	SSMAlpha     *nn.Linear  `gguf:"ssm_alpha"`  // -> alpha (qwen35)
	SSMConv1D    *convKernel `gguf:"ssm_conv1d"`
	SSMDT        ml.Tensor   `gguf:"ssm_dt"` // alpha bias
	SSMA         ml.Tensor   `gguf:"ssm_a"`  // -A_log.exp()
	SSMNorm      *nn.RMSNorm `gguf:"ssm_norm"`
	SSMOut       *nn.Linear  `gguf:"ssm_out"`

	// Layer index for cache access (set during model construction)
	Layer int
}

// createMasks builds the constant mask tensors (called once, reused for all chunks)
func createMasks(ctx ml.Context) *Masks {
	ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
	ones = ones.Fill(ctx, 1.0)
	causalMask := ones.Tri(ctx, TriTypeLower)

	onesVec := ctx.Input().Zeros(ml.DTypeF32, chunkSize)
	onesVec = onesVec.Fill(ctx, 1.0)
	identity := onesVec.Diag(ctx)

	diagMask := causalMask.Add(ctx, identity)

	return &Masks{
		Causal:   causalMask,
		Identity: identity,
		Diag:     diagMask,
	}
}
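
// For illustration, with a hypothetical chunkSize of 4 the three masks are:
//
//	Causal      Identity    Diag
//	0 0 0 0     1 0 0 0     1 0 0 0
//	1 0 0 0     0 1 0 0     1 1 0 0
//	1 1 0 0     0 0 1 0     1 1 1 0
//	1 1 1 0     0 0 0 1     1 1 1 1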

func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
	layer := gdn.Layer
	nSeqTokens := hiddenStates.Dim(1)
	nSeqs := hiddenStates.Dim(2)
	if cache != nil && cache.IsSupportedForBatch() {
		seqTokens := cache.seqTokens()
		seqs := cache.numSeqs()
		if seqTokens > 0 && seqs > 0 {
			if nSeqs > 1 {
				if nSeqTokens != seqTokens || nSeqs != seqs {
					return nil, ErrUnsupportedBatchLayout
				}
			} else {
				if nSeqTokens != seqTokens*seqs {
					return nil, ErrUnsupportedBatchLayout
				}
				hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
				nSeqTokens = seqTokens
				nSeqs = seqs
			}
		}
	}
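
	// From here on hiddenStates is [hiddenSize, nSeqTokens, nSeqs]: either the
	// batch arrived already split per sequence, or a flattened batch was
	// reshaped above using the cache's per-sequence token count.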

	headKDim := opts.ssmDState
	numKHeads := opts.ssmNGroup
	numVHeads := opts.ssmDtRank
	headVDim := opts.ssmDInner / numVHeads
	convKernelSize := opts.convKernelSize

	qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads

	if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
		return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
	}
	// Optimized path: pre-split QKV and gate
	qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
	z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
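
	// qkvMixed packs Q, K, and V along dim 0:
	//	[0, headKDim*numKHeads)                    -> Q
	//	[headKDim*numKHeads, 2*headKDim*numKHeads) -> K
	//	[2*headKDim*numKHeads, qkvDim)             -> V
	// It is split at exactly these offsets after the convolution below.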

	var beta ml.Tensor
	var alpha ml.Tensor
	switch {
	case gdn.SSMBetaAlpha != nil:
		// Legacy qwen3next path: in_proj_ba packs beta/alpha grouped by K-head.
		mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
		baNewDim := 2 * numVHeads / numKHeads
		mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)

		betaSize := numVHeads / numKHeads
		alphaSize := numVHeads / numKHeads

		b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
		a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)

		// Keep beta layout consistent with qwen35 and llama.cpp:
		// [1, numVHeads, nSeqTokens, nSeqs]
		beta = b.Contiguous(ctx, 1, numVHeads, nSeqTokens, nSeqs)
		alpha = a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)

	case gdn.SSMBeta != nil && gdn.SSMAlpha != nil:
		// qwen35 path: beta/alpha are separate projections.
		beta = gdn.SSMBeta.Forward(ctx, hiddenStates).Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
		alpha = gdn.SSMAlpha.Forward(ctx, hiddenStates).Reshape(ctx, numVHeads, nSeqTokens, nSeqs)

	default:
		return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
	}
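
	// On both paths: beta is [1, numVHeads, nSeqTokens, nSeqs] and
	// alpha is [numVHeads, nSeqTokens, nSeqs].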

	// Compute gate: softplus(alpha + dt_bias) * -A
	alphaBiased := alpha.Add(ctx, gdn.SSMDT)
	alphaSoftplus := alphaBiased.Softplus(ctx)
	gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
	gate = gate.Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
	qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
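
	// Since SSMA holds -exp(A_log) and softplus(x) >= 0, gate <= 0 per head and
	// token; exp(gate), applied later in the delta-net kernels, is therefore a
	// forget factor in (0, 1].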

	// Get conv state from cache
	convStates, err := cache.ConvState(ctx, layer)
	if err != nil {
		// Fall back to zeros; if this happens, the rolling convolution context
		// for this layer is lost, so log it.
		slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
		convStates = ctx.Input().Zeros(ml.DTypeF32, convKernelSize-1, qkvDim, nSeqs)
	}

	// Reshape conv states
	convStates = convStates.Reshape(ctx, convKernelSize-1, qkvDim, nSeqs)

	// Concatenate with input for convolution
	convInput := convStates.Concat(ctx, qkvMixed, 0)

	// Save new conv state (last convKernelSize-1 tokens)
	lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
	cache.UpdateConvState(ctx, layer, lastConvStates)

	// Apply SSM convolution (kernel must be F32 for Metal)
	convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
	convOutput = convOutput.SILU(ctx)
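
	// Sketch of the rolling buffer: with a convKernelSize of 4, convInput holds
	// 3 cached timesteps followed by nSeqTokens new ones, so a width-4 causal
	// convolution yields exactly nSeqTokens outputs, and the last 3 timesteps
	// of convInput become the cached state for the next call.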

	// Reshape for extraction
	convQKVMix := convOutput.Contiguous(ctx, qkvDim, nSeqTokens*nSeqs)

	// Extract convolved Q, K, V
	qConv := convQKVMix.Slice(ctx, 0, 0, headKDim*numKHeads, 1)
	kConv := convQKVMix.Slice(ctx, 0, headKDim*numKHeads, 2*headKDim*numKHeads, 1)
	vConv := convQKVMix.Slice(ctx, 0, 2*headKDim*numKHeads, qkvDim, 1)

	// Reshape to 4D
	qConv = qConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
	kConv = kConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
	vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)

	// Get delta state from cache
	state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
	if err != nil {
		// Fall back to zeros; if this happens often, long-range context will
		// degrade, so log it.
		slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
		state = ctx.Input().Zeros(ml.DTypeF32, headVDim, headVDim*numVHeads, nSeqs)
	}
	state = state.Reshape(ctx, headVDim, headVDim*numVHeads, 1, nSeqs)
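
	// state holds one [headVDim, headVDim] recurrent matrix per V head (the
	// delta-net kernels below assume headKDim == headVDim).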

	// Repeat interleave Q and K if numKHeads != numVHeads
	if numKHeads != numVHeads {
		if opts.vHeadReordered {
			qConv = qConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
			kConv = kConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
		} else {
			repeatFactor := numVHeads / numKHeads
			qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
			kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)

			qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
			kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)

			qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
			kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
		}
	}
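
	// The two branches expand numKHeads to numVHeads with different head
	// orderings: the vHeadReordered path tiles the whole block of K heads
	// (k1..kn, k1..kn, ...), while the reshape/repeat path interleaves copies
	// per head (k1, k1, ..., k2, k2, ...), presumably matching how the
	// checkpoint ordered its V heads.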

	// Choose computation mode based on sequence length
	var attnOut ml.Tensor
	if nSeqTokens == 1 {
		attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
	} else {
		if opts.masks == nil {
			opts.masks = createMasks(ctx)
		}
		attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
	}

	// Apply gated normalization
	attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
	z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)

	// norm(attnOut, z) = RMSNorm(attnOut) * silu(z)
	attnOutNorm := gdn.SSMNorm.Forward(ctx, attnOut2D, opts.eps)
	zSilu := z2D.SILU(ctx)
	attnOutGated := attnOutNorm.Mul(ctx, zSilu)

	// Reshape for output projection
	finalOutput := attnOutGated.Reshape(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)

	out := gdn.SSMOut.Forward(ctx, finalOutput)
	return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
}
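
// In matrix form, the gated delta rule computed below is, per V head (a
// sketch; rows index the value dimension, columns the key dimension):
//
//	S_t = exp(g_t) * S_{t-1}          // decay the recurrent state
//	d_t = beta_t * (v_t - S_t @ k_t)  // delta against the state's memory of k_t
//	S_t = S_t + d_t @ k_t^T           // rank-1 state update
//	o_t = S_t @ q_t                   // readout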

// deltaNetAutoregressive implements single-token state update.
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
func (gdn *GatedDeltaNet) deltaNetAutoregressive(
	ctx ml.Context,
	q, k, v, gate, beta, state ml.Tensor,
	opts *Options,
	layer int,
	cache *HybridCache,
) ml.Tensor {
	numVHeads := v.Dim(1)
	headVDim := v.Dim(0)
	nSeqs := q.Dim(3)

	// L2 normalize Q and K
	q = q.L2Norm(ctx, opts.eps)
	k = k.L2Norm(ctx, opts.eps)

	// Scale Q
	scale := 1.0 / math.Sqrt(float64(headVDim))
	q = q.Scale(ctx, scale)

	// Sigmoid beta
	beta = beta.Sigmoid(ctx)

	// Reshape state: [headVDim, headVDim, numVHeads, nSeqs]
	state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)

	// Reshape gate and beta for broadcasting
	gT := gate.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
	betaT := beta.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)

	// Apply exponential to gate
	gT = gT.Exp(ctx)

	// state = state * g_t
	state = state.Mul(ctx, gT)

	// kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
	kTUnsqueezed := k.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
	kvMem := state.Mul(ctx, kTUnsqueezed)
	// Sum over dim=-2: permute it to dim 0, sum rows, then permute back
	kvMem = kvMem.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	kvMem = kvMem.SumRows(ctx)
	kvMem = kvMem.Permute(ctx, 1, 0, 2, 3)

	// v_t with singleton dimension
	vT := v.Reshape(ctx, headVDim, 1, numVHeads, nSeqs)

	// delta = (v_t - kv_mem) * beta_t
	vDiff := vT.Sub(ctx, kvMem)
	delta := vDiff.Mul(ctx, betaT)

	// state = state + k_t.unsqueeze(-1) * delta
	kTUnsqueezedBroad := kTUnsqueezed.Repeat4D(ctx, headVDim, headVDim, numVHeads, nSeqs)
	kTDelta := kTUnsqueezedBroad.Mul(ctx, delta)
	state = state.Add(ctx, kTDelta)

	// core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
	qTUnsqueezed := q.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
	stateQ := state.Mul(ctx, qTUnsqueezed)
	stateQ = stateQ.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	coreAttnOut := stateQ.SumRows(ctx)
	coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)

	// Update delta state in cache
	cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))

	return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
}
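
// Chunked prefill splits the sequence into chunkSize blocks: within a block
// the delta rule is evaluated with masked matrix products, and the recurrent
// state is updated once per block (see the per-chunk loop below), so the
// token-by-token sequential recurrence is replaced by large matmuls plus one
// state update per chunk.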

// deltaNetChunked implements chunked computation for prefill.
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
func (gdn *GatedDeltaNet) deltaNetChunked(
	ctx ml.Context,
	q, k, v, gate, beta, state ml.Tensor,
	masks *Masks,
	opts *Options,
	layer int,
	cache *HybridCache,
) ml.Tensor {
	headKDim := q.Dim(0)
	numVHeads := v.Dim(1)
	headVDim := v.Dim(0)
	nTokens := q.Dim(2)
	nSeqs := q.Dim(3)

	// L2 normalize Q and K
	q = q.L2Norm(ctx, opts.eps)
	k = k.L2Norm(ctx, opts.eps)

	// Scale Q
	scale := 1.0 / math.Sqrt(float64(headVDim))
	q = q.Scale(ctx, scale)

	// Sigmoid beta
	beta = beta.Sigmoid(ctx)

	// Permute tensors for chunked computation
	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
	v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
	// Match llama.cpp delta-net-base layout:
	// gate/beta: [1, numVHeads, nTokens, nSeqs] -> [1, nTokens, numVHeads, nSeqs]
	gate = gate.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
	beta = beta.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
	state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)

	// Compute padding
	pad := (chunkSize - nTokens%chunkSize) % chunkSize
	nChunks := (nTokens + pad) / chunkSize

	// Pad tensors
	if pad > 0 {
		q = q.Pad(ctx, 0, pad, 0, 0)
		k = k.Pad(ctx, 0, pad, 0, 0)
		v = v.Pad(ctx, 0, pad, 0, 0)
		gate = gate.Pad(ctx, 0, pad, 0, 0)
		beta = beta.Pad(ctx, 0, pad, 0, 0)
	}
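
	// Worked example: nTokens = 100 with chunkSize = 64 gives
	// pad = (64 - 100%64) % 64 = 28 and nChunks = (100 + 28) / 64 = 2;
	// the padding is sliced off again before returning.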

	// Use pre-computed masks (passed in, not recreated)
	causalMask := masks.Causal
	identity := masks.Identity
	diagMask := masks.Diag
	identity4D := identity.Reshape(ctx, chunkSize, chunkSize, 1, 1)

	// v_beta = v * beta, k_beta = k * beta
	vBeta := v.Mul(ctx, beta)
	kBeta := k.Mul(ctx, beta)

	// Reshape for chunked computation
	q = q.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
	k = k.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
	kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
	vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)

	// Reshape gate and cumsum over the chunk axis:
	// [1, chunkSize, nChunks, numVHeads*nSeqs] -> transpose -> [chunkSize, 1, nChunks, numVHeads*nSeqs]
	gate = gate.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)

	// g_cumsum = cumsum(gate)
	gCumsum := gate.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs).CumSum(ctx)

	// Compute decay mask
	gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
	gcsJ := gCumsum.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
	gcsBroadcast := gcsJ.Repeat4D(ctx, chunkSize, chunkSize, nChunks, numVHeads*nSeqs)
	decayMask := gcsBroadcast.Sub(ctx, gcsI)

	decayMask = decayMask.Mul(ctx, diagMask)
	decayMask = decayMask.Exp(ctx)
	decayMask = decayMask.Mul(ctx, diagMask)
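
	// For positions a >= b within a chunk, the surviving entries are
	// exp(g_cumsum[a] - g_cumsum[b]) = exp(sum of gates over (b, a]), the decay
	// accumulated between the two tokens; the second Mul re-zeros the entries
	// the first mask removed before Exp turned them into exp(0) = 1.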

	// k @ k_beta^T
	kMulKBeta := k.Mulmat(ctx, kBeta)

	// k_decay = k @ k_beta^T * decay_mask
	kDecay := kMulKBeta.Mul(ctx, decayMask)

	// attn = -k_decay * causal_mask
	attn := kDecay.Neg(ctx).Mul(ctx, causalMask)

	// Triangular solve: (I - attn_lower)^-1 @ attn
	attnLower := attn.Mul(ctx, causalMask)
	lhs := attnLower.Neg(ctx).Add(ctx, identity4D)
	linSolve := lhs.SolveTri(ctx, attn, true, true, false)
	attn = linSolve.Mul(ctx, causalMask)
	attn = attn.Add(ctx, identity4D)
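
	// With A the strictly lower-triangular matrix above, the solve yields
	// (I - A)^-1 @ A, and adding I afterwards gives
	// (I - A)^-1 @ A + I = (I - A)^-1 @ (A + I - A) = (I - A)^-1,
	// i.e. the full inverse without ever forming it explicitly.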

	// v = v_beta^T @ attn
	vBetaT := vBeta.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	v = vBetaT.Mulmat(ctx, attn)

	// Compute g_exp for state update
	gCumsumT := gCumsum.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	gExp := gCumsumT.Exp(ctx)

	// kbeta_gexp = k_beta * g_exp
	kBetaGExp := kBeta.Mul(ctx, gExp)

	// k_cumdecay = attn @ kbeta_gexp^T
	kBetaGExpT := kBetaGExp.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	kCumdecay := attn.Mulmat(ctx, kBetaGExpT)
	kCumdecay = kCumdecay.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	// Pre-compute attn_kq = (k @ q) * decay_mask * diag_mask
	attnKQ := k.Mulmat(ctx, q)
	attnKQ = attnKQ.Mul(ctx, decayMask)
	attnKQ = attnKQ.Mul(ctx, diagMask)

	// Pre-compute g_last and key_gdiff.
	// g_last is the last element of g_cumsum along the chunkSize dimension:
	// [chunkSize, 1, nChunks, numVHeads*nSeqs] -> [1, 1, nChunks, numVHeads*nSeqs]
	gLast := gCumsum.Slice(ctx, 0, chunkSize-1, chunkSize, 1).Contiguous(ctx, 1, 1, nChunks, numVHeads*nSeqs)
	gLastExp := gLast.Exp(ctx)

	// g_diff = -(g_cumsum - g_last) = g_last - g_cumsum
	gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
	gDiffExp := gDiff.Exp(ctx)

	// Reshape g_diff_exp to [1, chunkSize, nChunks, numVHeads*nSeqs]
	gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
	keyGDiff := k.Mul(ctx, gDiffExpReshaped)
	keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	// Process chunks and update state
	var coreAttnOut ml.Tensor
	newState := state

	for chunk := range nChunks {
		qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
		vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
		gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
		kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
		attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // pre-computed above

		// state^T: the permute needs a Contiguous copy before Mulmat
		stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)

		// v_prime = k_cumdecay @ state
		vPrime := stateT.Mulmat(ctx, kCumdecayChunk)

		// v_new = v - v_prime
		vNew := vChunk.Sub(ctx, vPrime)
		vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

		// attn_inter = (q * g_exp) @ state
		qGExp := qChunk.Mul(ctx, gExpChunk)
		attnInter := stateT.Mulmat(ctx, qGExp)

		// core_attn_out = attn_inter + attn @ v_new
		vAttn := vNewT.Mulmat(ctx, attnChunk)
		coreAttnOutChunk := attnInter.Add(ctx, vAttn)

		if coreAttnOut == nil {
			coreAttnOut = coreAttnOutChunk
		} else {
			coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
		}

		// Update state for next chunk
		gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
		kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
		kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)

		// state = state * g_last + kgd_mul_vnew
		gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
		newState = newState.Mul(ctx, gExpLastReshaped)
		newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
	}
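
	// Per chunk this implements
	// S_next = exp(g_last) * S + sum_t exp(g_last - g_cumsum[t]) * k_t @ v_new_t^T:
	// the old state decays by the whole chunk's gate, while each token's
	// contribution decays only from its own position to the chunk boundary.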

	// Final reshape
	coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)

	// Slice to remove padding
	if pad > 0 {
		coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
	}

	// Update delta state in cache
	cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))

	return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
}