Mirror of https://github.com/ollama/ollama.git, synced 2026-04-17 15:53:27 +02:00
model: add qwen3-next compatibility for legacy ssm_in projections (#15133)
This commit is contained in:
@@ -34,9 +34,9 @@ type Masks struct {
|
||||
// GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
|
||||
// It implements the Operator interface directly.
|
||||
type GatedDeltaNet struct {
|
||||
// Optimized path: pre-split QKV and gate
|
||||
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
|
||||
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
|
||||
SSMIn *nn.Linear `gguf:"ssm_in"`
|
||||
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha (legacy qwen3next)
|
||||
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||
@@ -100,12 +100,27 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
|
||||
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
|
||||
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
|
||||
// Support both current split projections and older qwen3-next imports that use ssm_in.
|
||||
var qkvMixed, z ml.Tensor
|
||||
switch {
|
||||
case gdn.SSMQKV != nil && gdn.SSMQKVGate != nil:
|
||||
qkvMixed = gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
|
||||
z = gdn.SSMQKVGate.Forward(ctx, hiddenStates)
|
||||
case gdn.SSMIn != nil:
|
||||
vPerHead := headVDim * numVHeads / numKHeads
|
||||
qkvzDim := 2*headKDim + 2*vPerHead
|
||||
|
||||
combined := gdn.SSMIn.Forward(ctx, hiddenStates).Reshape(ctx, qkvzDim, numKHeads, nSeqTokens, nSeqs)
|
||||
qPart := combined.Slice(ctx, 0, 0, headKDim, 1).Contiguous(ctx, headKDim*numKHeads, nSeqTokens, nSeqs)
|
||||
kPart := combined.Slice(ctx, 0, headKDim, 2*headKDim, 1).Contiguous(ctx, headKDim*numKHeads, nSeqTokens, nSeqs)
|
||||
vPart := combined.Slice(ctx, 0, 2*headKDim, 2*headKDim+vPerHead, 1).Contiguous(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
|
||||
zPart := combined.Slice(ctx, 0, 2*headKDim+vPerHead, qkvzDim, 1).Contiguous(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
|
||||
|
||||
qkvMixed = qPart.Concat(ctx, kPart, 0).Concat(ctx, vPart, 0)
|
||||
z = zPart
|
||||
default:
|
||||
return nil, errors.New("qwen3next: missing attn_qkv/attn_gate or ssm_in projections")
|
||||
}
|
||||
// Optimized path: pre-split QKV and gate
|
||||
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
|
||||
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
|
||||
|
||||
var beta ml.Tensor
|
||||
var alpha ml.Tensor
|
||||
|
||||
@@ -454,7 +454,7 @@ func (m *Model) Validate() error {
|
||||
if !ok || gdn == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
|
||||
}
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
if gdn.SSMIn == nil && (gdn.SSMQKV == nil || gdn.SSMQKVGate == nil) {
|
||||
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
|
||||
}
|
||||
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
|
||||
|
||||
@@ -31,6 +31,31 @@ func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRecurrentSSMInAccepted(t *testing.T) {
|
||||
// When SSMIn is set, Validate must not reject the layer for missing
|
||||
// attn_qkv/attn_gate. It should fail later on missing ssm_dt.
|
||||
m := &Model{
|
||||
Layers: []Layer{{
|
||||
Operator: &GatedDeltaNet{
|
||||
SSMIn: &nn.Linear{},
|
||||
SSMBeta: &nn.Linear{},
|
||||
SSMAlpha: &nn.Linear{},
|
||||
},
|
||||
}},
|
||||
Options: &Options{
|
||||
isRecurrent: []bool{true},
|
||||
},
|
||||
}
|
||||
|
||||
err := m.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Validate() expected error, got nil")
|
||||
}
|
||||
if strings.Contains(err.Error(), "missing attn_qkv/attn_gate") {
|
||||
t.Fatalf("Validate() should not fail on attn_qkv/attn_gate when SSMIn is set, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []Layer{{Operator: &FullAttention{}}},
|
||||
|
||||
Reference in New Issue
Block a user