package qwen3next

import (
	"errors"
	"log/slog"
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

const chunkSize = 64

// TriType constants for triangular matrix operations
const (
	TriTypeUpperDiag = 0
	TriTypeUpper     = 1
	TriTypeLowerDiag = 2
	TriTypeLower     = 3
)

// convKernel wraps the 1D convolution kernel tensor
type convKernel struct {
	Weight ml.Tensor `gguf:"weight"`
}

// Masks holds pre-computed mask tensors for chunked attention
type Masks struct {
	Causal   ml.Tensor // Lower triangular [chunkSize, chunkSize]
	Identity ml.Tensor // Diagonal [chunkSize, chunkSize]
	Diag     ml.Tensor // causal + identity
}

// GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
// It implements the Operator interface directly.
type GatedDeltaNet struct {
	// Optimized path: pre-split QKV and gate
	SSMQKV       *nn.Linear  `gguf:"attn_qkv"`  // -> Q, K, V (concatenated)
	SSMQKVGate   *nn.Linear  `gguf:"attn_gate"` // -> Z gate
	SSMBetaAlpha *nn.Linear  `gguf:"ssm_ba"`    // -> beta, alpha (legacy qwen3next)
	SSMBeta      *nn.Linear  `gguf:"ssm_beta"`  // -> beta (qwen35)
	SSMAlpha     *nn.Linear  `gguf:"ssm_alpha"` // -> alpha (qwen35)
	SSMConv1D    *convKernel `gguf:"ssm_conv1d"`
	SSMDT        ml.Tensor   `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
	SSMA         ml.Tensor   `gguf:"ssm_a"`                  // -A_log.exp()
	SSMNorm      *nn.RMSNorm `gguf:"ssm_norm"`
	SSMOut       *nn.Linear  `gguf:"ssm_out"`

	// Layer index for cache access (set during model construction)
	Layer int
}

// createMasks builds the constant mask tensors (called once, reused for all chunks)
func createMasks(ctx ml.Context) *Masks {
	ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
	ones = ones.Fill(ctx, 1.0)
	causalMask := ones.Tri(ctx, TriTypeLower)

	onesVec := ctx.Input().Zeros(ml.DTypeF32, chunkSize)
	onesVec = onesVec.Fill(ctx, 1.0)
	identity := onesVec.Diag(ctx)

	diagMask := causalMask.Add(ctx, identity)

	return &Masks{
		Causal:   causalMask,
		Identity: identity,
		Diag:     diagMask,
	}
}

func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
	layer := gdn.Layer
	nSeqTokens := hiddenStates.Dim(1)
	nSeqs := hiddenStates.Dim(2)

	if cache != nil && cache.IsSupportedForBatch() {
		seqTokens := cache.seqTokens()
		seqs := cache.numSeqs()
		if seqTokens > 0 && seqs > 0 {
			if nSeqs > 1 {
				if nSeqTokens != seqTokens || nSeqs != seqs {
					return nil, ErrUnsupportedBatchLayout
				}
			} else {
				if nSeqTokens != seqTokens*seqs {
					return nil, ErrUnsupportedBatchLayout
				}
				hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
				nSeqTokens = seqTokens
				nSeqs = seqs
			}
		}
	}

	headKDim := opts.ssmDState
	numKHeads := opts.ssmNGroup
	numVHeads := opts.ssmDtRank
	headVDim := opts.ssmDInner / numVHeads
	convKernelSize := opts.convKernelSize

	qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads

	if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
		return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
	}

	// Optimized path: pre-split QKV and gate
	qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
	z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)

	var beta ml.Tensor
	var alpha ml.Tensor
	switch {
	case gdn.SSMBetaAlpha != nil: // Legacy qwen3next path: in_proj_ba packs beta/alpha grouped by K-head.
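		// Assumed packing, inferred from the reshape/slice below rather than from
		// upstream docs: with r = numVHeads/numKHeads, each K-head group stores
		// [beta_0 .. beta_{r-1}, alpha_0 .. alpha_{r-1}], so reshaping to
		// [2*r, numKHeads, ...] lets the first r rows be sliced out as beta and the
		// remaining r rows as alpha.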
		mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
		baNewDim := 2 * numVHeads / numKHeads
		mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)

		betaSize := numVHeads / numKHeads
		alphaSize := numVHeads / numKHeads
		b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
		a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)

		// Keep beta layout consistent with qwen35.
		// [1, numVHeads, nSeqTokens, nSeqs]
		beta = b.Contiguous(ctx, 1, numVHeads, nSeqTokens, nSeqs)
		alpha = a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
	case gdn.SSMBeta != nil && gdn.SSMAlpha != nil:
		// qwen35 path: beta/alpha are separate projections.
		beta = gdn.SSMBeta.Forward(ctx, hiddenStates).Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
		alpha = gdn.SSMAlpha.Forward(ctx, hiddenStates).Reshape(ctx, numVHeads, nSeqTokens, nSeqs)
	default:
		return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
	}

	if gdn.SSMDT == nil {
		return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
	}
	if gdn.SSMA == nil {
		return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
	}
	if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
		return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
	}
	if gdn.SSMNorm == nil || gdn.SSMOut == nil {
		return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
	}

	// Compute gate: softplus(alpha + dt_bias) * -A
	alphaBiased := alpha.Add(ctx, gdn.SSMDT)
	alphaSoftplus := alphaBiased.Softplus(ctx)
	gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
	gate = gate.Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)

	qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)

	// Get conv state from cache
	convStates, err := cache.ConvState(ctx, layer)
	if err != nil {
		// Log this - if it happens, short-term context will be lost
		slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
		convStates = ctx.Input().Zeros(ml.DTypeF32, convKernelSize-1, qkvDim, nSeqs)
	}

	// Reshape conv states
	convStates = convStates.Reshape(ctx, convKernelSize-1, qkvDim, nSeqs)

	// Concatenate with input for convolution
	convInput := convStates.Concat(ctx, qkvMixed, 0)

	// Save new conv state (last convKernelSize-1 tokens)
	lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
	cache.UpdateConvState(ctx, layer, lastConvStates)

	// Apply SSM convolution (kernel must be F32 for Metal)
	convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
	convOutput = convOutput.SILU(ctx)

	// Reshape for extraction
	convQKVMix := convOutput.Contiguous(ctx, qkvDim, nSeqTokens*nSeqs)

	// Extract convolved Q, K, V
	qConv := convQKVMix.Slice(ctx, 0, 0, headKDim*numKHeads, 1)
	kConv := convQKVMix.Slice(ctx, 0, headKDim*numKHeads, 2*headKDim*numKHeads, 1)
	vConv := convQKVMix.Slice(ctx, 0, 2*headKDim*numKHeads, qkvDim, 1)

	// Reshape to 4D
	qConv = qConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
	kConv = kConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
	vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)

	// Get delta state from cache
	state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
	if err != nil {
		// Log this - if it happens frequently, context will degrade
		slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
		state = ctx.Input().Zeros(ml.DTypeF32, headVDim, headVDim*numVHeads, nSeqs)
	}
	state = state.Reshape(ctx, headVDim, headVDim*numVHeads, 1, nSeqs)

	// Repeat interleave Q and K if numKHeads != numVHeads
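	// Illustrative example (assumed head counts, and assuming the ml package's
	// innermost-first dim ordering): with numKHeads=2 and numVHeads=4, each K head
	// serves repeatFactor=2 V heads, and the reshape-through-a-singleton-dim path
	// below turns head order [k0, k1] into [k0, k0, k1, k1] (repeat_interleave)
	// rather than [k0, k1, k0, k1] (tile).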
	if numKHeads != numVHeads {
		if opts.vHeadReordered {
			qConv = qConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
			kConv = kConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
		} else {
			repeatFactor := numVHeads / numKHeads
			qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
			kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
			qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
			kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
			qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
			kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
		}
	}

	// Choose computation mode based on sequence length
	var attnOut ml.Tensor
	if nSeqTokens == 1 {
		attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
	} else {
		if opts.masks == nil {
			opts.masks = createMasks(ctx)
		}
		attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
	}

	// Apply gated normalization
	attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
	z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)

	// norm(attnOut, z) = RMSNorm(attnOut) * silu(z)
	attnOutNorm := gdn.SSMNorm.Forward(ctx, attnOut2D, opts.eps)
	zSilu := z2D.SILU(ctx)
	attnOutGated := attnOutNorm.Mul(ctx, zSilu)

	// Reshape for output projection
	finalOutput := attnOutGated.Reshape(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)

	out := gdn.SSMOut.Forward(ctx, finalOutput)
	return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
}

// deltaNetAutoregressive implements single-token state update.
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
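//
// Sketch of the per-token gated delta rule as read from the ops below (broadcast
// and transpose details omitted; not a normative spec):
//
//	S      = exp(g_t) * S                           // decay the recurrent state
//	kv_mem = S k_t                                  // current memory read at k_t
//	S      = S + outer(k_t, beta_t*(v_t - kv_mem))  // rank-1 delta-rule correction
//	o_t    = S q_t                                  // read out with the query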
func (gdn *GatedDeltaNet) deltaNetAutoregressive(
	ctx ml.Context,
	q, k, v, gate, beta, state ml.Tensor,
	opts *Options,
	layer int,
	cache *HybridCache,
) ml.Tensor {
	numVHeads := v.Dim(1)
	headVDim := v.Dim(0)
	nSeqs := q.Dim(3)

	// L2 normalize Q and K
	q = q.L2Norm(ctx, opts.eps)
	k = k.L2Norm(ctx, opts.eps)

	// Scale Q
	scale := 1.0 / math.Sqrt(float64(headVDim))
	q = q.Scale(ctx, scale)

	// Sigmoid beta
	beta = beta.Sigmoid(ctx)

	// Reshape state: [headVDim, headVDim, numVHeads, nSeqs]
	state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)

	// Reshape gate and beta for broadcasting
	gT := gate.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
	betaT := beta.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)

	// Apply exponential to gate
	gT = gT.Exp(ctx)

	// state = state * g_t
	state = state.Mul(ctx, gT)

	// kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
	kTUnsqueezed := k.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
	kvMem := state.Mul(ctx, kTUnsqueezed)
	// Sum over dim=-2 (second dimension after permute)
	kvMem = kvMem.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	kvMem = kvMem.SumRows(ctx)
	kvMem = kvMem.Permute(ctx, 1, 0, 2, 3)

	// v_t with singleton dimension
	vT := v.Reshape(ctx, headVDim, 1, numVHeads, nSeqs)

	// delta = (v_t - kv_mem) * beta_t
	vDiff := vT.Sub(ctx, kvMem)
	delta := vDiff.Mul(ctx, betaT)

	// state = state + k_t.unsqueeze(-1) * delta
	kTUnsqueezedBroad := kTUnsqueezed.Repeat4D(ctx, headVDim, headVDim, numVHeads, nSeqs)
	kTDelta := kTUnsqueezedBroad.Mul(ctx, delta)
	state = state.Add(ctx, kTDelta)

	// core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
	qTUnsqueezed := q.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
	stateQ := state.Mul(ctx, qTUnsqueezed)
	stateQ = stateQ.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	coreAttnOut := stateQ.SumRows(ctx)
	coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)

	// Update delta state in cache
	cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))

	return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
}

// deltaNetChunked implements chunked computation for prefill.
// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
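//
// Rough structure, inferred from the code below: tokens are padded to a multiple
// of chunkSize, the intra-chunk delta-rule corrections are resolved in one shot
// with a masked triangular solve (SolveTri), and a short loop then carries the
// recurrent state S across chunks, roughly as
//
//	S = exp(g_last) * S + (k * exp(g_last - g_cum))^T v_new
//
// so the sequential cost is O(nChunks) rather than O(nTokens).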
func (gdn *GatedDeltaNet) deltaNetChunked(
	ctx ml.Context,
	q, k, v, gate, beta, state ml.Tensor,
	masks *Masks,
	opts *Options,
	layer int,
	cache *HybridCache,
) ml.Tensor {
	headKDim := q.Dim(0)
	numVHeads := v.Dim(1)
	headVDim := v.Dim(0)
	nTokens := q.Dim(2)
	nSeqs := q.Dim(3)

	// L2 normalize Q and K
	q = q.L2Norm(ctx, opts.eps)
	k = k.L2Norm(ctx, opts.eps)

	// Scale Q
	scale := 1.0 / math.Sqrt(float64(headVDim))
	q = q.Scale(ctx, scale)

	// Sigmoid beta
	beta = beta.Sigmoid(ctx)

	// Permute tensors for chunked computation
	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
	v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)

	// gate/beta: [1, numVHeads, nTokens, nSeqs] -> [1, nTokens, numVHeads, nSeqs]
	gate = gate.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
	beta = beta.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)

	state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)

	// Compute padding
	pad := (chunkSize - nTokens%chunkSize) % chunkSize
	nChunks := (nTokens + pad) / chunkSize

	// Pad tensors
	if pad > 0 {
		q = q.Pad(ctx, 0, pad, 0, 0)
		k = k.Pad(ctx, 0, pad, 0, 0)
		v = v.Pad(ctx, 0, pad, 0, 0)
		gate = gate.Pad(ctx, 0, pad, 0, 0)
		beta = beta.Pad(ctx, 0, pad, 0, 0)
	}

	// Use pre-computed masks (passed in, not recreated)
	causalMask := masks.Causal
	identity := masks.Identity
	diagMask := masks.Diag

	identity4D := identity.Reshape(ctx, chunkSize, chunkSize, 1, 1)

	// v_beta = v * beta, k_beta = k * beta
	vBeta := v.Mul(ctx, beta)
	kBeta := k.Mul(ctx, beta)

	// Reshape for chunked computation
	q = q.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
	k = k.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
	kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
	vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)

	// Reshape gate and cumsum over chunk axis.
	// [1, chunkSize, nChunks, H*nSeqs] -> transpose -> [chunkSize, 1, nChunks, H*nSeqs]
	gate = gate.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)

	// g_cumsum = cumsum(gate)
	gCumsum := gate.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs).CumSum(ctx)

	// Compute decay mask
	gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
	gcsJ := gCumsum.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
	gcsBroadcast := gcsJ.Repeat4D(ctx, chunkSize, chunkSize, nChunks, numVHeads*nSeqs)
	decayMask := gcsBroadcast.Sub(ctx, gcsI)
	decayMask = decayMask.Mul(ctx, diagMask)
	decayMask = decayMask.Exp(ctx)
	decayMask = decayMask.Mul(ctx, diagMask)

	// k @ k_beta^T
	kMulKBeta := k.Mulmat(ctx, kBeta)

	// k_decay = k @ k_beta^T * decay_mask
	kDecay := kMulKBeta.Mul(ctx, decayMask)

	// attn = -k_decay * causal_mask
	attn := kDecay.Neg(ctx).Mul(ctx, causalMask)

	// Triangular solve: (I - attn_lower)^-1 @ attn
	attnLower := attn.Mul(ctx, causalMask)
	lhs := attnLower.Neg(ctx).Add(ctx, identity4D)
	linSolve := lhs.SolveTri(ctx, attn, true, true, false)
	attn = linSolve.Mul(ctx, causalMask)
	attn = attn.Add(ctx, identity4D)

	// v = v_beta^T @ attn
	vBetaT := vBeta.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	v = vBetaT.Mulmat(ctx, attn)

	// Compute g_exp for state update
	gCumsumT := gCumsum.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	gExp := gCumsumT.Exp(ctx)

	// kbeta_gexp = k_beta * g_exp
	kBetaGExp := kBeta.Mul(ctx, gExp)

	// k_cumdecay = attn @ kbeta_gexp^T
	kBetaGExpT := kBetaGExp.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	kCumdecay := attn.Mulmat(ctx, kBetaGExpT)
	kCumdecay = kCumdecay.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	// Pre-compute attn_kq = (k @ q) * decay_mask * diag_mask
	attnKQ := k.Mulmat(ctx, q)
	attnKQ = attnKQ.Mul(ctx, decayMask)
	attnKQ = attnKQ.Mul(ctx, diagMask)

	// Pre-compute g_last and key_gdiff
	// g_last = view of last element in g_cumsum along chunk_size dimension
	// We need to get the last row of gCumsum: shape [chunkSize, 1, nChunks, H*n_seqs] -> [1, 1, nChunks, H*n_seqs]
	gLast := gCumsum.Slice(ctx, 0, chunkSize-1, chunkSize, 1).Contiguous(ctx, 1, 1, nChunks, numVHeads*nSeqs)
	gLastExp := gLast.Exp(ctx)

	// g_diff = -(g_cumsum - g_last) = g_last - g_cumsum
	gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
	gDiffExp := gDiff.Exp(ctx)

	// Reshape g_diff_exp to [1, chunkSize, nChunks, ...]
	gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
	keyGDiff := k.Mul(ctx, gDiffExpReshaped)
	keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	// Process chunks and update state.
	// Keep a transposed view of v and recurrent state across chunks so the
	// chunk loop does not need extra transpose+contiguous nodes.
	vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
	stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)

	// Collect chunk outputs and concatenate at the end.
	// Avoids SET on buffer-less intermediates under partial offload.
	chunks := make([]ml.Tensor, nChunks)
	for chunk := range nChunks {
		qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
		vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
		gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
		kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
		attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
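
		// Only stateT changes between iterations; the slices above are read-only
		// views into the tensors pre-computed before the loop.
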
		// v'_t = k_cumdecay @ state_t
		vTPrime := kCumdecayChunk.Mulmat(ctx, stateT)

		// v_t_new = v_t - v'_t
		vTNewChunk := vTChunk.Sub(ctx, vTPrime)

		// attn_inter = (q * g_exp) @ state
		qGExp := qChunk.Mul(ctx, gExpChunk)
		attnInter := stateT.Mulmat(ctx, qGExp)

		// core_attn_out = attn_inter + attn @ v_new
		vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
		coreAttnOutChunk := attnInter.Add(ctx, vAttn)
		chunks[chunk] = coreAttnOutChunk

		// Update state for next chunk
		gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
		kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)

		// kgdmulvnew = key_gdiff_t @ v_new_t
		kgdMulVNew := kGDiffChunkT.Mulmat(ctx, vTNewChunk)

		// stateT = stateT * g_last + kgdmulvnew
		stateT = stateT.Mul(ctx, gExpLastChunk)
		stateT = stateT.Add(ctx, kgdMulVNew)
	}

	// Use a balanced concat tree so concat work does not balloon on long prompts.
	for len(chunks) > 1 {
		merged := make([]ml.Tensor, 0, (len(chunks)+1)/2)
		for i := 0; i < len(chunks); i += 2 {
			if i+1 < len(chunks) {
				merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2))
			} else {
				merged = append(merged, chunks[i])
			}
		}
		chunks = merged
	}
	v = chunks[0]

	// Final reshape
	coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)

	// Slice to remove padding
	if pad > 0 {
		coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
	}

	// Convert stateT back to cache layout [S_v, S_v, H_v, nSeqs]
	newState := stateT.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, numVHeads, nSeqs)

	// Update delta state in cache
	cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))

	return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
}