model: add qwen3-next compatibility for legacy ssm_in projections (#15133)

Author: Jeffrey Morgan
Date: 2026-03-29 11:50:47 -07:00
Committed by: GitHub
Parent: 8e54823fd3
Commit: b7bda92d52

3 changed files with 47 additions and 7 deletions


@@ -34,9 +34,9 @@ type Masks struct {
 // GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
 // It implements the Operator interface directly.
 type GatedDeltaNet struct {
 	// Optimized path: pre-split QKV and gate
 	SSMQKV       *nn.Linear `gguf:"attn_qkv"`  // -> Q, K, V (concatenated)
 	SSMQKVGate   *nn.Linear `gguf:"attn_gate"` // -> Z gate
+	SSMIn        *nn.Linear `gguf:"ssm_in"`    // -> Q, K, V, Z combined (legacy qwen3-next imports)
 	SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"`    // -> beta, alpha (legacy qwen3next)
 	SSMBeta      *nn.Linear `gguf:"ssm_beta"`  // -> beta (qwen35)
 	SSMAlpha     *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
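For orientation: the legacy ssm_in tensor fuses what attn_qkv and attn_gate carry separately. Its output is laid out as numKHeads blocks, each holding that head's query and key rows followed by the value and gate rows assigned to it. A minimal sketch of the per-head block arithmetic, using hypothetical head sizes (illustrative only, not read from any real qwen3-next GGUF):

    package main

    import "fmt"

    func main() {
    	// Hypothetical dimensions for illustration; real values come
    	// from the model's GGUF metadata.
    	const (
    		headKDim  = 128
    		headVDim  = 128
    		numKHeads = 16
    		numVHeads = 32
    	)

    	// Value (and gate) rows that land in each key head's block.
    	vPerHead := headVDim * numVHeads / numKHeads // 256

    	// One per-head block of the legacy ssm_in output: Q | K | V | Z.
    	qkvzDim := 2*headKDim + 2*vPerHead // 768

    	fmt.Printf("q rows: [%d, %d)\n", 0, headKDim)
    	fmt.Printf("k rows: [%d, %d)\n", headKDim, 2*headKDim)
    	fmt.Printf("v rows: [%d, %d)\n", 2*headKDim, 2*headKDim+vPerHead)
    	fmt.Printf("z rows: [%d, %d)\n", 2*headKDim+vPerHead, qkvzDim)
    	fmt.Println("rows per token:", qkvzDim*numKHeads)
    }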
@@ -100,12 +100,27 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
 	qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
-	if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
-		return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
-	}
-
-	// Optimized path: pre-split QKV and gate
-	qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
-	z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
+	// Support both current split projections and older qwen3-next imports that use ssm_in.
+	var qkvMixed, z ml.Tensor
+	switch {
+	case gdn.SSMQKV != nil && gdn.SSMQKVGate != nil:
+		// Optimized path: pre-split QKV and gate
+		qkvMixed = gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
+		z = gdn.SSMQKVGate.Forward(ctx, hiddenStates)
+	case gdn.SSMIn != nil:
+		// Legacy path: ssm_in packs Q, K, V, and the Z gate per key head.
+		vPerHead := headVDim * numVHeads / numKHeads
+		qkvzDim := 2*headKDim + 2*vPerHead
+		combined := gdn.SSMIn.Forward(ctx, hiddenStates).Reshape(ctx, qkvzDim, numKHeads, nSeqTokens, nSeqs)
+		qPart := combined.Slice(ctx, 0, 0, headKDim, 1).Contiguous(ctx, headKDim*numKHeads, nSeqTokens, nSeqs)
+		kPart := combined.Slice(ctx, 0, headKDim, 2*headKDim, 1).Contiguous(ctx, headKDim*numKHeads, nSeqTokens, nSeqs)
+		vPart := combined.Slice(ctx, 0, 2*headKDim, 2*headKDim+vPerHead, 1).Contiguous(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
+		zPart := combined.Slice(ctx, 0, 2*headKDim+vPerHead, qkvzDim, 1).Contiguous(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
+		qkvMixed = qPart.Concat(ctx, kPart, 0).Concat(ctx, vPart, 0)
+		z = zPart
+	default:
+		return nil, errors.New("qwen3next: missing attn_qkv/attn_gate or ssm_in projections")
+	}
 
 	var beta ml.Tensor
 	var alpha ml.Tensor
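Read on plain slices, the legacy branch above is a per-head de-interleave followed by concatenation. A sketch of the same reordering on one token's flat output vector, with a hypothetical helper (splitQKVZ is not part of the model code; the real implementation works on ml.Tensor values via Reshape/Slice/Contiguous/Concat):

    // splitQKVZ reorders a token's legacy ssm_in output, laid out as
    // numKHeads blocks of [q | k | v | z], into the concatenated q|k|v
    // vector that the split attn_qkv projection would produce, plus
    // the z gate that attn_gate would produce.
    func splitQKVZ(combined []float32, headKDim, numKHeads, vPerHead int) (qkv, z []float32) {
    	block := 2*headKDim + 2*vPerHead
    	var q, k, v []float32
    	for h := 0; h < numKHeads; h++ {
    		b := combined[h*block : (h+1)*block]
    		q = append(q, b[:headKDim]...)
    		k = append(k, b[headKDim:2*headKDim]...)
    		v = append(v, b[2*headKDim:2*headKDim+vPerHead]...)
    		z = append(z, b[2*headKDim+vPerHead:block]...)
    	}
    	return append(append(q, k...), v...), z
    }

The Contiguous calls in the tensor version play the role of the appends here: after slicing along the interleaved axis, each per-head piece is compacted into one dense q, k, or v region before the final Concat.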


@@ -454,7 +454,7 @@ func (m *Model) Validate() error {
 		if !ok || gdn == nil {
 			return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
 		}
-		if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
+		if gdn.SSMIn == nil && (gdn.SSMQKV == nil || gdn.SSMQKVGate == nil) {
 			return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
 		}
 		if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
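Read as a predicate, the relaxed check accepts either wiring: the legacy combined projection, or both split projections. A standalone restatement (hasQKVProjections is a hypothetical helper, not in the patch; the condition is the same):

    // hasQKVProjections is the negation of the error condition above:
    // accept ssm_in, or the attn_qkv + attn_gate pair.
    func hasQKVProjections(gdn *GatedDeltaNet) bool {
    	return gdn.SSMIn != nil || (gdn.SSMQKV != nil && gdn.SSMQKVGate != nil)
    }

This also covers mixed imports: a checkpoint carrying ssm_in plus only one of the split tensors still validates here, and the Forward switch takes the split path only when both split tensors are present.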


@@ -31,6 +31,31 @@ func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
 	}
 }
 
+func TestValidateRecurrentSSMInAccepted(t *testing.T) {
+	// When SSMIn is set, Validate must not reject the layer for missing
+	// attn_qkv/attn_gate. It should fail later on missing ssm_dt.
+	m := &Model{
+		Layers: []Layer{{
+			Operator: &GatedDeltaNet{
+				SSMIn:    &nn.Linear{},
+				SSMBeta:  &nn.Linear{},
+				SSMAlpha: &nn.Linear{},
+			},
+		}},
+		Options: &Options{
+			isRecurrent: []bool{true},
+		},
+	}
+
+	err := m.Validate()
+	if err == nil {
+		t.Fatal("Validate() expected error, got nil")
+	}
+	if strings.Contains(err.Error(), "missing attn_qkv/attn_gate") {
+		t.Fatalf("Validate() should not fail on attn_qkv/attn_gate when SSMIn is set, got: %v", err)
+	}
+}
+
 func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
 	m := &Model{
 		Layers: []Layer{{Operator: &FullAttention{}}},