diff --git a/x/create/client/create.go b/x/create/client/create.go index 5ada0c23b..c5962fdd7 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -560,6 +560,9 @@ func getParserName(modelDir string) string { if strings.Contains(archLower, "deepseek") { return "deepseek3" } + if strings.Contains(archLower, "gemma4") { + return "gemma4" + } if strings.Contains(archLower, "qwen3") { return "qwen3" } @@ -574,6 +577,9 @@ func getParserName(modelDir string) string { if strings.Contains(typeLower, "deepseek") { return "deepseek3" } + if strings.Contains(typeLower, "gemma4") { + return "gemma4" + } if strings.Contains(typeLower, "qwen3") { return "qwen3" } @@ -602,6 +608,9 @@ func getRendererName(modelDir string) string { // Check architectures for known renderers for _, arch := range cfg.Architectures { archLower := strings.ToLower(arch) + if strings.Contains(archLower, "gemma4") { + return "gemma4" + } if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") { return "glm-4.7" } @@ -616,6 +625,9 @@ func getRendererName(modelDir string) string { // Also check model_type if cfg.ModelType != "" { typeLower := strings.ToLower(cfg.ModelType) + if strings.Contains(typeLower, "gemma4") { + return "gemma4" + } if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") { return "glm-4.7" } diff --git a/x/create/create.go b/x/create/create.go index 88747ebed..54beed3ee 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -634,6 +634,8 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{ "Qwen3_5MoeForConditionalGeneration": newQwen35ImportTransform, "Qwen3NextMoeForCausalLM": newQwen35ImportTransform, "Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform, + "Gemma4ForCausalLM": newGemma4ImportTransform, + "Gemma4ForConditionalGeneration": newGemma4ImportTransform, } func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) { diff 
--git a/x/create/gemma4.go b/x/create/gemma4.go new file mode 100644 index 000000000..35e920077 --- /dev/null +++ b/x/create/gemma4.go @@ -0,0 +1,264 @@ +package create + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + + "github.com/ollama/ollama/x/safetensors" +) + +type gemma4ImportTransform struct { + numLayers int + numExperts int +} + +// gemma4Config is a minimal subset of the Gemma 4 config.json used for quant decisions. +type gemma4Config struct { + NumHiddenLayers int `json:"num_hidden_layers"` + NumExperts int `json:"num_experts"` + TextConfig struct { + NumHiddenLayers int `json:"num_hidden_layers"` + NumExperts int `json:"num_experts"` + } `json:"text_config"` +} + +func newGemma4ImportTransform(modelDir string, _ sourceModelConfig) (tensorImportTransform, error) { + data, err := os.ReadFile(filepath.Join(modelDir, "config.json")) + if err != nil { + return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic + } + var cfg gemma4Config + if err := json.Unmarshal(data, &cfg); err != nil { + return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic + } + + numLayers := cfg.NumHiddenLayers + if numLayers == 0 { + numLayers = cfg.TextConfig.NumHiddenLayers + } + numExperts := cfg.NumExperts + if numExperts == 0 { + numExperts = cfg.TextConfig.NumExperts + } + + return gemma4ImportTransform{numLayers: numLayers, numExperts: numExperts}, nil +} + +func (t gemma4ImportTransform) skipTensor(name string) bool { + return false +} + +// layerIndexRe extracts the layer index from tensor names like +// "model.language_model.layers.5.self_attn.v_proj.weight" or +// "model.language_model.layers.5.moe.experts.42.down_proj.weight" +var layerIndexRe = regexp.MustCompile(`\.layers\.(\d+)\.`) + +// useMoreBits returns true for layers where quantization-sensitive tensors +// should use higher precision: the first and last 1/8 of layers (which handle +// input grounding and 
final output refinement), plus every 3rd layer in between +// to limit error accumulation through the residual stream. +func useMoreBits(layerIdx, numLayers int) bool { + return layerIdx < numLayers/8 || + layerIdx >= 7*numLayers/8 || + (layerIdx-numLayers/8)%3 == 2 +} + +func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quantize string) string { + quantNorm := normalizeQuantType(quantize) + + // Embedding: quantize to 8-bit variant for bandwidth efficiency. + // The embedding serves double duty: lookup (via QuantizedEmbedding) and + // lm_head projection (via AsLinear). Using 8-bit matches GGUF Q6_K quality + // (strictly higher at 8 bpw vs 6.5 bpw) while saving ~2.8 GB on 31B vs bf16. + if isEmbedTokensWeight(name) { + switch quantNorm { + case "int4", "int8": + if isAligned(shape, "int8") { + return "int8" + } + case "mxfp4", "nvfp4", "mxfp8": + if isAligned(shape, "mxfp8") { + return "mxfp8" + } + } + if isAligned(shape, quantNorm) { + return quantNorm + } + return "" + } + + // Mixed-precision quantization: sensitive tensors get higher precision. + // + // Value projections (v_proj) directly determine attention output quality. + // Down projections (down_proj) are the final MLP output and errors there + // propagate directly to the residual stream. Both benefit from higher + // precision at early layers, late layers, and periodically in between + // (the "useMoreBits" heuristic). + // + // For int4: promote → int8 (same affine family, GatherQMM compatible). + // For mxfp4/nvfp4: promote → mxfp8. MLX quantized_matmul handles mixed + // nvfp4+mxfp8 modes within the same model — each tensor carries its own + // quant metadata and the kernel dispatches per-tensor. + if t.numLayers > 0 { + layerIdx := -1 + if m := layerIndexRe.FindStringSubmatch(name); m != nil { + if idx, err := strconv.Atoi(m[1]); err == nil { + layerIdx = idx + } + } + + // Determine promotion target for sensitive tensors. 
+ // "int8" = int4 base → int8 (affine family) + // "mxfp8" = mxfp4/nvfp4 base → mxfp8 + // "" = no promotion (int8/mxfp8, already 8-bit) + promote := "" + switch quantNorm { + case "int4": + promote = "int8" + case "mxfp4", "nvfp4": + promote = "mxfp8" + } + + // Only apply to language model tensors — audio/vision tower tensors + // should pass through to GetTensorQuantization which skips them. + isModelTensor := !strings.Contains(name, "audio_tower") && + !strings.Contains(name, "vision_tower") + isSensitive := isModelTensor && + (strings.Contains(name, ".v_proj") || strings.Contains(name, "down_proj")) + isSensitiveK := isModelTensor && strings.Contains(name, "k_proj") + + if promote != "" && (isSensitive || isSensitiveK) { + shouldPromote := false + + // 8-expert models: v_proj and k_proj share very few KV heads, + // so quantization errors are amplified. Always promote. + if t.numExperts == 8 && (strings.Contains(name, ".v_proj") || isSensitiveK) { + shouldPromote = true + } + + // Layer-position heuristic for v_proj and down_proj. + if isSensitive && layerIdx >= 0 && useMoreBits(layerIdx, t.numLayers) { + shouldPromote = true + } + + if shouldPromote && isAligned(shape, promote) { + return promote + } + + // Sensitive tensor at a non-promoted layer: use base quant type. + // Return directly to bypass GetTensorQuantization's uniform + // promotion — the layer-position heuristic is authoritative here. + if !isAligned(shape, quantNorm) { + return "" + } + return quantNorm + } + } + + return GetTensorQuantization(name, shape, quantize) +} + +// isEmbedTokensWeight returns true for the main token embedding weight. 
+func isEmbedTokensWeight(name string) bool { + return strings.HasSuffix(name, "embed_tokens.weight") && + !strings.Contains(name, "per_layer") +} + +func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) { + if td == nil { + return nil, nil + } + + // Split pre-stacked MoE expert tensors [N, out, in] into per-expert + // [out, in] tensors so they go through the standard expert packing and + // quantization flow (ExpertGroupPrefix matching, per-expert quantize). + if isGemma4StackedMoETensor(td.Name, td.Shape) { + return splitStackedMoETensor(td) + } + + return []*safetensors.TensorData{td}, nil +} + +// isGemma4StackedMoETensor checks if this is a pre-stacked MoE expert weight. +// Gemma 4 HF weights come in two layouts depending on the model version: +// - Older: model.language_model.layers.N.moe.{gate,up,down}_proj [experts, dim1, dim2] +// - Newer: model.language_model.layers.N.experts.{gate_up,down}_proj [experts, dim1, dim2] +// +// The newer layout has gate+up already fused. We keep gate+up fused within each +// expert (the expert axis is still split) so the tensors flow through the standard expert packing and quantization path. 
+func isGemma4StackedMoETensor(name string, shape []int32) bool { + if len(shape) != 3 { + return false + } + if strings.Contains(name, ".moe.") || strings.Contains(name, ".experts.") { + return strings.HasSuffix(name, "_proj") || strings.HasSuffix(name, "_proj.weight") + } + return false +} + +// splitStackedMoETensor splits a [N, out, in] stacked expert tensor into +// N individual [out, in] tensors named with the per-expert convention that +// ExpertGroupPrefix expects: prefix.moe.experts.{E}.{proj}.weight +func splitStackedMoETensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) { + raw, err := io.ReadAll(td.Reader()) + if err != nil { + return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err) + } + + numExperts := int(td.Shape[0]) + rows := int(td.Shape[1]) // out_features in HF layout + cols := int(td.Shape[2]) // in_features in HF layout + + elemSize, err := DTypeSize(td.Dtype) + if err != nil { + return nil, fmt.Errorf("failed to get dtype size for %s: %w", td.Dtype, err) + } + + perExpertBytes := rows * cols * elemSize + if len(raw) != numExperts*perExpertBytes { + return nil, fmt.Errorf("tensor %s: raw byte length %d does not match shape %v and dtype %s", + td.Name, len(raw), td.Shape, td.Dtype) + } + + // Determine the per-expert name pattern. 
+ // Two source layouts: + // Old: model.language_model.layers.N.moe.gate_proj + // -> model.language_model.layers.N.moe.experts.E.gate_proj.weight + // New: model.language_model.layers.N.experts.gate_up_proj + // -> model.language_model.layers.N.moe.experts.E.gate_up_proj.weight + baseName := td.Name + baseName = strings.TrimSuffix(baseName, ".weight") + lastDot := strings.LastIndex(baseName, ".") + if lastDot < 0 { + return nil, fmt.Errorf("tensor %s: unexpected name format", td.Name) + } + parentPrefix := baseName[:lastDot] // "...layers.N.moe" or "...layers.N.experts" + projName := baseName[lastDot+1:] // "gate_proj" or "gate_up_proj" + + // Normalize: if parent already ends with ".experts", use the grandparent + ".moe" + // so we get a consistent "layers.N.moe.experts.E" pattern. + var moePrefix string + if cut, ok := strings.CutSuffix(parentPrefix, ".experts"); ok { + moePrefix = cut + ".moe" + } else { + moePrefix = parentPrefix + } + + transposedShape := []int32{td.Shape[1], td.Shape[2]} + + results := make([]*safetensors.TensorData, numExperts) + for e := range numExperts { + expertName := fmt.Sprintf("%s.experts.%d.%s.weight", moePrefix, e, projName) + start := e * perExpertBytes + end := start + perExpertBytes + results[e] = safetensors.NewTensorDataFromBytes(expertName, td.Dtype, transposedShape, raw[start:end]) + } + + return results, nil +} diff --git a/x/create/gemma4_test.go b/x/create/gemma4_test.go new file mode 100644 index 000000000..40b183162 --- /dev/null +++ b/x/create/gemma4_test.go @@ -0,0 +1,191 @@ +package create + +import ( + "testing" +) + +func TestGemma4QuantizationType(t *testing.T) { + // 26B MoE: 30 layers, 128 experts + transform26B := gemma4ImportTransform{numLayers: 30, numExperts: 128} + // 8-expert model (hypothetical) + transform8E := gemma4ImportTransform{numLayers: 30, numExperts: 8} + + aligned := []int32{2816, 2816} // divisible by 64 (int4/int8 group size) and 16 (nvfp4) + + tests := []struct { + name string + transform 
gemma4ImportTransform + tensor string + shape []int32 + quantize string + want string + }{ + // === embed_tokens: quantize to 8-bit variant (serves as both embed and lm_head) === + {"embed_tokens int4", transform26B, "model.embed_tokens.weight", aligned, "int4", "int8"}, + {"embed_tokens nvfp4", transform26B, "model.embed_tokens.weight", aligned, "nvfp4", "mxfp8"}, + {"embed_tokens mxfp4", transform26B, "model.embed_tokens.weight", aligned, "mxfp4", "mxfp8"}, + {"embed_tokens int8", transform26B, "model.embed_tokens.weight", aligned, "int8", "int8"}, + {"embed_tokens mxfp8", transform26B, "model.embed_tokens.weight", aligned, "mxfp8", "mxfp8"}, + + // === v_proj: layer-position heuristic for int4/nvfp4 === + // Layer 0 is in first 1/8 (30/8=3) → promoted + {"v_proj int4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"}, + // Layer 4 is NOT in useMoreBits → base quant + {"v_proj int4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "int4", "int4"}, + // Layer 29 is in last 1/8 → promoted + {"v_proj int4 last layer promoted", transform26B, "model.layers.29.self_attn.v_proj.weight", aligned, "int4", "int8"}, + // nvfp4: promote to mxfp8 (cross-family, validated by MLX quantized_matmul) + {"v_proj nvfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "mxfp8"}, + {"v_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"}, + // mxfp4: promoted to mxfp8 at promoted layers (same mxfp family) + {"v_proj mxfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp8"}, + {"v_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp4"}, + // int8/mxfp8: no promotion (already 8-bit) + {"v_proj int8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"}, + {"v_proj 
mxfp8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"}, + + // === down_proj (dense MLP): same heuristic as v_proj === + {"dense down_proj int4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"}, + {"dense down_proj int4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "int4", "int4"}, + {"dense down_proj nvfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "nvfp4", "mxfp8"}, + {"dense down_proj nvfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "nvfp4", "nvfp4"}, + {"dense down_proj mxfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp8"}, + {"dense down_proj mxfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"}, + + // === Expert down_proj: int4→int8, nvfp4→mxfp8 at promoted layers === + {"expert down_proj int4 promoted", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "int4", "int8"}, + {"expert down_proj int4 non-promoted", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "int4", "int4"}, + // nvfp4 experts: promote to mxfp8 (all experts at a layer get same treatment, + // so GatherQMM sees uniform quant per projection per layer) + {"expert down_proj nvfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "nvfp4", "mxfp8"}, + {"expert down_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "nvfp4", "nvfp4"}, + // mxfp4 experts: promote to mxfp8 (same mxfp family, GatherQMM compatible) + {"expert down_proj mxfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp8"}, + {"expert down_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp4"}, + + // === Expert 
gate_up_proj: always base quant (not a sensitive tensor) === + {"expert gate_up int4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "int4", "int4"}, + {"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"}, + {"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"}, + + // === k_proj: promoted only for 8-expert models === + {"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"}, + {"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"}, + {"k_proj 8 experts nvfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "nvfp4", "mxfp8"}, + {"k_proj 8 experts mxfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "mxfp4", "mxfp8"}, + + // === q_proj, o_proj, gate_proj, up_proj: always base quant === + {"q_proj int4", transform26B, "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"}, + {"o_proj int4", transform26B, "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"}, + {"gate_proj int4", transform26B, "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"}, + {"up_proj int4", transform26B, "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"}, + + // === Non-quantizable tensors: always bf16 === + {"embed_tokens per_layer skip", transform26B, "model.embed_tokens_per_layer.weight", aligned, "int4", ""}, + {"norm", transform26B, "model.layers.0.input_layernorm.weight", []int32{2816}, "int4", ""}, + {"router scale", transform26B, "model.layers.0.router.scale", []int32{2816}, "int4", ""}, + + // === Audio/vision tower tensors: must pass through unquantized for all quant types === + // These contain .v_proj and down_proj but should NOT be intercepted by + // the sensitive-tensor promotion logic. 
+ {"audio norm int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int4", ""}, + {"audio norm nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "nvfp4", ""}, + {"audio norm int8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int8", ""}, + {"audio norm mxfp8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "mxfp8", ""}, + {"audio conv int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "int4", ""}, + {"audio conv nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "nvfp4", ""}, + {"audio linear int4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "int4", ""}, + {"audio linear nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "nvfp4", ""}, + // Audio tower v_proj — must NOT be promoted despite containing .v_proj + {"audio v_proj int4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", ""}, + {"audio v_proj nvfp4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", ""}, + // Vision tower v_proj — vision tower IS quantized (unlike audio tower), + // but not intercepted by gemma4's layer-position heuristic. + // Falls through to GetTensorQuantization which applies uniform promotion. 
+ {"vision v_proj int4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", "int8"}, + {"vision v_proj nvfp4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", "nvfp4"}, + // Audio tower down_proj + {"audio down_proj int4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "int4", ""}, + {"audio down_proj nvfp4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "nvfp4", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.transform.quantizationType(tt.tensor, tt.shape, tt.quantize) + if got != tt.want { + t.Errorf("quantizationType(%q, %v, %q) = %q, want %q", + tt.tensor, tt.shape, tt.quantize, got, tt.want) + } + }) + } +} + +func TestUseMoreBits(t *testing.T) { + // 30 layers: first 1/8 = layers 0-2, last 1/8 = layers 27-29 + // In between: every 3rd from offset (i - n/8) % 3 == 2 + n := 30 + promoted := map[int]bool{} + for i := range n { + if useMoreBits(i, n) { + promoted[i] = true + } + } + + // First 1/8 (30/8 = 3): layers 0, 1, 2 + for _, i := range []int{0, 1, 2} { + if !promoted[i] { + t.Errorf("layer %d should be promoted (first 1/8)", i) + } + } + + // Last 1/8: layers 26, 27, 28, 29 (>= 7*30/8 = 26) + for _, i := range []int{26, 27, 28, 29} { + if !promoted[i] { + t.Errorf("layer %d should be promoted (last 1/8)", i) + } + } + + // Some middle layers should NOT be promoted + for _, i := range []int{3, 4, 6, 7} { + if promoted[i] { + t.Errorf("layer %d should NOT be promoted", i) + } + } + + // Layer 5 should be promoted: (5 - 3) % 3 == 2 + if !promoted[5] { + t.Errorf("layer 5 should be promoted (periodic)") + } +} + +func TestIsGemma4StackedMoETensor(t *testing.T) { + tests := []struct { + label string + tensorName string + shape []int32 + want bool + }{ + // New-style: .experts.gate_up_proj + {"experts gate_up_proj 3D", 
"model.layers.0.experts.gate_up_proj", []int32{128, 1408, 2816}, true}, + {"experts down_proj 3D", "model.layers.0.experts.down_proj", []int32{128, 2816, 704}, true}, + // Old-style: .moe.gate_proj + {"moe gate_proj 3D", "model.layers.0.moe.gate_proj", []int32{128, 2112, 2816}, true}, + {"moe down_proj 3D", "model.layers.0.moe.down_proj.weight", []int32{128, 2816, 2112}, true}, + // Not stacked: 2D + {"2D weight", "model.layers.0.experts.gate_up_proj", []int32{1408, 2816}, false}, + // Not expert + {"non-expert 3D", "model.layers.0.mlp.gate_proj", []int32{3, 2816, 2816}, false}, + // Not a projection + {"expert non-proj", "model.layers.0.experts.scale", []int32{128, 1, 1}, false}, + } + + for _, tt := range tests { + t.Run(tt.label, func(t *testing.T) { + got := isGemma4StackedMoETensor(tt.tensorName, tt.shape) + if got != tt.want { + t.Errorf("isGemma4StackedMoETensor(%q, %v) = %v, want %v", + tt.tensorName, tt.shape, got, tt.want) + } + }) + } +} diff --git a/x/mlxrunner/imports.go b/x/mlxrunner/imports.go index 6b6394d60..ea16e4ea2 100644 --- a/x/mlxrunner/imports.go +++ b/x/mlxrunner/imports.go @@ -2,6 +2,7 @@ package mlxrunner import ( _ "github.com/ollama/ollama/x/models/gemma3" + _ "github.com/ollama/ollama/x/models/gemma4" _ "github.com/ollama/ollama/x/models/glm4_moe_lite" _ "github.com/ollama/ollama/x/models/llama" _ "github.com/ollama/ollama/x/models/qwen3" diff --git a/x/models/gemma4/gemma4.go b/x/models/gemma4/gemma4.go new file mode 100644 index 000000000..90737f813 --- /dev/null +++ b/x/models/gemma4/gemma4.go @@ -0,0 +1,1514 @@ +// Package gemma4 provides the Gemma 4 text model implementation for MLX. 
+package gemma4 + +import ( + "encoding/json" + "fmt" + "math" + + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model" + "github.com/ollama/ollama/x/mlxrunner/model/base" + "github.com/ollama/ollama/x/models/nn" + "github.com/ollama/ollama/x/tokenizer" +) + +func init() { + base.Register("Gemma4ForCausalLM", newModel) + base.Register("Gemma4ForConditionalGeneration", newModel) +} + +// Compile-time interface checks. +var _ base.Model = (*Model)(nil) + +// RopeParams holds per-layer-type RoPE settings. +type RopeParams struct { + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + RopeTheta float32 `json:"rope_theta"` + RopeType string `json:"rope_type"` +} + +// TextConfig holds configuration for the Gemma 4 text model. +type TextConfig struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + HeadDim int32 `json:"head_dim"` + GlobalHeadDim int32 `json:"global_head_dim"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + SlidingWindow int32 `json:"sliding_window"` + SlidingWindowPattern int32 `json:"sliding_window_pattern"` + LayerTypes []string `json:"layer_types"` + TieWordEmbeddings bool `json:"tie_word_embeddings"` + FinalLogitSoftcapping float32 `json:"final_logit_softcapping"` + UseDoubleWideMLP bool `json:"use_double_wide_mlp"` + NumKVSharedLayers int32 `json:"num_kv_shared_layers"` + HiddenSizePerLayer int32 `json:"hidden_size_per_layer_input"` + VocabSizePerLayer int32 `json:"vocab_size_per_layer_input"` + AttentionKEqV bool `json:"attention_k_eq_v"` + NumGlobalKeyValueHeads int32 `json:"num_global_key_value_heads"` + EnableMoeBlock bool `json:"enable_moe_block"` + 
NumExperts int32 `json:"num_experts"` + TopKExperts int32 `json:"top_k_experts"` + ExpertIntermediateSize int32 `json:"moe_intermediate_size"` + RopeParameters map[string]*RopeParams `json:"rope_parameters"` + ImageTokenIDValue int32 `json:"image_token_id"` + + // Quantization parameters. + QuantGroupSize int `json:"-"` + QuantBits int `json:"-"` + QuantMode string `json:"-"` + TensorQuant map[string]*model.TensorQuantInfo `json:"-"` + + // Computed fields. + SlidingScale float32 `json:"-"` // 1/sqrt(HeadDim) for sliding layers + FullScale float32 `json:"-"` // 1/sqrt(GlobalHeadDim) for full layers + SlidingRopeDims int `json:"-"` // HeadDim (full rotation for sliding) + FullRopeDims int `json:"-"` // GlobalHeadDim (partial rotation via custom freqs) + SlidingRopeBase float32 `json:"-"` + FullRopeBase float32 `json:"-"` + FullRopeFreqs *mlx.Array `json:"-"` // Precomputed proportional RoPE frequencies + + // Precomputed scale factors (avoid per-forward math.Sqrt/Pow). + EmbedScale float32 `json:"-"` // sqrt(hidden_size) + PLEScale float32 `json:"-"` // sqrt(hidden_size_per_layer_input) + PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size) + PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071... + RouterScale float32 `json:"-"` // 1/sqrt(hidden_size) + SoftcapInv float32 `json:"-"` // 1/final_logit_softcapping + + // KV sharing: maps shared layer index -> donor layer index. + KVShareMap map[int32]int32 `json:"-"` + // Set of donor layer indices that need to store their KV. + KVDonors map[int32]bool `json:"-"` +} + +// sharedKVEntry stores cached KV state from a donor layer for KV sharing. +type sharedKVEntry struct { + K, V *mlx.Array + Offset int // RoPE offset from donor's cache +} + +// Attention implements Gemma 4 attention with Q/K normalization and v-norm. +type Attention struct { + QProj nn.LinearLayer + KProj nn.LinearLayer + VProj nn.LinearLayer + OProj nn.LinearLayer + + QNorm *nn.RMSNorm + KNorm *nn.RMSNorm + + // Norm weight for Q/K RMSNorm. 
+ QNormScaled *mlx.Array + KNormScaled *mlx.Array +} + +// MLP is the feed-forward network with GELU activation. +type MLP struct { + GateProj nn.LinearLayer + UpProj nn.LinearLayer + DownProj nn.LinearLayer +} + +// stackedExpertResult holds the result of collecting and stacking per-expert weights. +type stackedExpertResult struct { + Weight *mlx.Array + Scales *mlx.Array + Biases *mlx.Array + Bits int + GroupSize int + Mode string +} + +// firstNonNil returns the first non-nil tensor found under any of the given keys. +func firstNonNil(tensors map[string]*mlx.Array, keys ...string) *mlx.Array { + for _, k := range keys { + if t := tensors[k]; t != nil { + return t + } + } + return nil +} + +// buildCausalMaskWindow creates a [1, 1, Q, K] additive causal mask with an +// optional sliding-window restriction. When window > 0, positions where +// kv < absQ - window + 1 are also masked (the token can only see the most +// recent `window` keys). When window == 0, only the causal constraint applies. +// +// This is the prefill-time mask for sliding-window attention layers. Without +// the window restriction, a sliding layer would attend to the entire prefix +// during prompt processing and diverge from the reference starting at the +// first position past the window boundary. +func buildCausalMaskWindow(Q, K, window int32) *mlx.Array { + offset := K - Q // cache offset: kv positions before the current query chunk + vals := make([]float32, Q*K) + negInf := float32(math.Inf(-1)) + for q := range Q { + absQ := offset + q + var lo int32 + if window > 0 { + lo = max(absQ-window+1, 0) + } + for kv := range K { + if kv > absQ || kv < lo { + vals[q*K+kv] = negInf + } + } + } + return mlx.FromValues(vals, 1, 1, int(Q), int(K)) +} + +// sliceAxis1 slices a tensor along axis 1: a[:, start:stop, ...]. 
+func sliceAxis1(a *mlx.Array, start, stop int32) *mlx.Array { + dims := a.Dims() + beg := make([]int32, len(dims)) + end := make([]int32, len(dims)) + for i, d := range dims { + end[i] = int32(d) + } + beg[1] = start + end[1] = stop + return mlx.SliceStartStop(a, beg, end) +} + +// transposeForGatherMM transposes stacked expert weights from [experts, out, in] +// to [experts, in, out] for use with GatherMM (which computes a @ b[group]). +func transposeForGatherMM(w *mlx.Array) *mlx.Array { + if w == nil || !w.Valid() || w.NumDims() != 3 { + return w + } + t := mlx.Transpose(w, 0, 2, 1).Clone() + mlx.Eval(t) + return t +} + +// collectExpertProjection collects per-expert tensors, stacks them, and +// optionally keeps quantized weight/scale/bias for GatherQMM. +// prefix: e.g. "model.language_model.layers.0.moe.experts" +// proj: e.g. "gate_proj" +func collectExpertProjection(tensors map[string]*mlx.Array, cfg *TextConfig, prefix, proj string, numExperts int32) *stackedExpertResult { + weights := make([]*mlx.Array, 0, numExperts) + scales := make([]*mlx.Array, 0, numExperts) + biases := make([]*mlx.Array, 0, numExperts) + bits, groupSize := 0, 0 + mode := cfg.QuantMode + + for e := range numExperts { + // Try "prefix.E.proj.weight" then "prefix.E.proj" + base := fmt.Sprintf("%s.%d.%s", prefix, e, proj) + w := tensors[base+".weight"] + key := base + ".weight" + if w == nil { + w = tensors[base] + key = base + } + if w == nil { + return nil + } + + s := tensors[key+"_scale"] + if s == nil { + weights = append(weights, w) + continue + } + qb := tensors[key+"_qbias"] + gs, b, m := model.ResolveLinearQuantParams( + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, + cfg.TensorQuant, key, w, s, + ) + if bits == 0 { + bits = b + groupSize = gs + mode = m + } + // Keep quantized weights for GatherQMM (supports affine, nvfp4, mxfp8). 
+ weights = append(weights, w) + scales = append(scales, s) + if qb != nil { + biases = append(biases, qb) + } + } + + if len(weights) == 0 { + return nil + } + + stacked := mlx.Stack(weights, 0).Clone() + mlx.Eval(stacked) + out := &stackedExpertResult{Weight: stacked, Bits: bits, GroupSize: groupSize, Mode: mode} + if len(scales) == len(weights) { + out.Scales = mlx.Stack(scales, 0).Clone() + mlx.Eval(out.Scales) + } + if len(biases) == len(weights) { + out.Biases = mlx.Stack(biases, 0).Clone() + mlx.Eval(out.Biases) + } + return out +} + +// Router implements Gemma 4's expert routing mechanism. +type Router struct { + Proj nn.LinearLayer // [hidden_size -> num_experts] + Scale *mlx.Array // learnable scale [hidden_size] +} + +// MoEBlock implements the Gemma 4 mixture-of-experts block. +// Uses GatherQMM for quantized weights, GatherMM for dense. +type MoEBlock struct { + // Dense expert weights for GatherMM (used when not quantized). + GateUpWeight *mlx.Array // [num_experts, 2*intermediate, hidden] (fused gate+up) + GateWeight *mlx.Array // [num_experts, hidden_size, expert_intermediate_size] + UpWeight *mlx.Array // [num_experts, hidden_size, expert_intermediate_size] + DownWeight *mlx.Array // [num_experts, expert_intermediate_size, hidden_size] + + // Quantized expert weights for GatherQMM. + GateUpWeightQ, GateUpScales, GateUpBiases *mlx.Array // fused gate+up + GateWeightQ, GateScales, GateBiases *mlx.Array + UpWeightQ, UpScales, UpBiases *mlx.Array + DownWeightQ, DownScales, DownBiases *mlx.Array + + PerExpertScale *mlx.Array // [num_experts] + UseQuantized bool + UseFusedGateUp bool // true when gate+up are stored as single tensor + + // Per-projection quant params (may differ due to mixed-precision). 
+ GateUpGroupSize, GateUpBits int + GateGroupSize, UpGroupSize int + DownGroupSize int + GateBits, UpBits, DownBits int + QuantMode string // gate/up mode + DownQuantMode string // down mode (may differ for mixed mxfp4/mxfp8) +} + +// PLELayer holds per-layer PLE weights for a single decoder layer. +type PLELayer struct { + InputGate nn.LinearLayer // [hidden_size -> ple_dim] + Projection nn.LinearLayer // [ple_dim -> hidden_size] + PostNorm *nn.RMSNorm + + // Norm weight for post-norm. + PostNormScaled *mlx.Array +} + +// DecoderLayer is a single transformer block. +type DecoderLayer struct { + InputNorm *nn.RMSNorm + Attention *Attention + PostAttnNorm *nn.RMSNorm + PreFFNorm *nn.RMSNorm + MLP *MLP + PostFFNorm *nn.RMSNorm + + // PLE per-layer components (nil if no PLE). + PLE *PLELayer + + // MoE components (nil if no MoE). + Router *Router + MoE *MoEBlock + + // Additional norms for MoE dual-path (nil if no MoE). + PostFFNorm1 *nn.RMSNorm // post-norm for dense MLP path + PostFFNorm2 *nn.RMSNorm // post-norm for MoE path + PreFFNorm2 *nn.RMSNorm // pre-norm for MoE input + + // Norm weight for RMSNorm. + InputNormScaled *mlx.Array + PostAttnNormScaled *mlx.Array + PreFFNormScaled *mlx.Array + PostFFNormScaled *mlx.Array + + // Norm weight for MoE norms. + PostFFNorm1Scaled *mlx.Array + PostFFNorm2Scaled *mlx.Array + PreFFNorm2Scaled *mlx.Array + + // Layer scalar for full-attention layers (nil for sliding). + LayerScalar *mlx.Array + + // Layer metadata. + IsSliding bool + LayerIdx int32 + KVShareDonor int32 // -1 if not shared, else index of donor layer + IsDonor bool // true if this layer's KV is shared by later layers +} + +// Model is the Gemma 4 model (text + optional vision). +type Model struct { + EmbedTokens nn.EmbeddingLayer + Layers []*DecoderLayer + Norm *nn.RMSNorm + LMHead nn.LinearLayer + + // PLE model-level components (nil if no PLE). 
+ EmbedTokensPerLayer nn.EmbeddingLayer + PerLayerModelProj nn.LinearLayer + PerLayerProjNorm *nn.RMSNorm + + // Precomputed scaled weights. + NormScaled *mlx.Array + PerLayerProjNormWeight *mlx.Array + + tok *tokenizer.Tokenizer + *TextConfig + + weightPrefix string +} + +func parseTextConfig(configData []byte) (TextConfig, error) { + var cfg TextConfig + if err := json.Unmarshal(configData, &cfg); err != nil { + return TextConfig{}, fmt.Errorf("parse config: %w", err) + } + + var wrapped struct { + TextConfig *TextConfig `json:"text_config"` + } + if err := json.Unmarshal(configData, &wrapped); err != nil { + return TextConfig{}, fmt.Errorf("parse nested text config: %w", err) + } + + if wrapped.TextConfig != nil { + cfg = *wrapped.TextConfig + } + + // Apply defaults. + if cfg.HeadDim == 0 { + cfg.HeadDim = 256 + } + if cfg.GlobalHeadDim == 0 { + cfg.GlobalHeadDim = cfg.HeadDim + } + if cfg.NumAttentionHeads == 0 { + cfg.NumAttentionHeads = 8 + } + if cfg.NumKeyValueHeads == 0 { + cfg.NumKeyValueHeads = 1 + } + if cfg.RMSNormEps == 0 { + cfg.RMSNormEps = 1e-6 + } + if cfg.VocabSize == 0 { + cfg.VocabSize = 262144 + } + if cfg.SlidingWindowPattern <= 0 && len(cfg.LayerTypes) == 0 { + cfg.SlidingWindowPattern = 5 + } + if cfg.MaxPositionEmbeddings == 0 { + cfg.MaxPositionEmbeddings = 131072 + } + + // Gemma 4 uses scaling=1.0 (no 1/sqrt(head_dim) scaling); the Q/K norms + // handle magnitude control. This differs from Gemma 3 which uses + // query_pre_attn_scalar^(-0.5). + cfg.SlidingScale = 1.0 + cfg.FullScale = 1.0 + + // Compute RoPE settings from rope_parameters. 
+ cfg.SlidingRopeDims = int(cfg.HeadDim) // full rotation for sliding + cfg.SlidingRopeBase = 10000 + cfg.FullRopeDims = int(cfg.HeadDim) // default: full rotation + cfg.FullRopeBase = 1000000 + + if rp := cfg.RopeParameters; rp != nil { + if sp := rp["sliding_attention"]; sp != nil && sp.RopeTheta > 0 { + cfg.SlidingRopeBase = sp.RopeTheta + } + if fp := rp["full_attention"]; fp != nil { + if fp.RopeTheta > 0 { + cfg.FullRopeBase = fp.RopeTheta + } + if fp.PartialRotaryFactor > 0 { + // Proportional RoPE: the reference computes inv_freq with divisor + // global_head_dim, then applies rotate_half which splits at head_dim/2. + // MLX fast_rope splits at dims/2, so we use dims=global_head_dim + // and pass custom frequencies that match the reference formula. + // Non-rotated dims use 1e10 so MLX reciprocals to ~0 (identity). + ghd := int(cfg.GlobalHeadDim) + cfg.FullRopeDims = ghd + halfDim := ghd / 2 + ropeAngles := int(fp.PartialRotaryFactor * float32(ghd) / 2) + freqs := make([]float32, halfDim) + for i := range ropeAngles { + freqs[i] = float32(math.Pow(float64(cfg.FullRopeBase), float64(2*i)/float64(ghd))) + } + for i := ropeAngles; i < halfDim; i++ { + freqs[i] = 1e10 + } + cfg.FullRopeFreqs = mlx.FromValues(freqs, halfDim) + mlx.Eval(cfg.FullRopeFreqs) + } + } + } + + // Precompute constant scale factors used in forward pass. + cfg.EmbedScale = float32(math.Sqrt(float64(cfg.HiddenSize))) + if cfg.HiddenSizePerLayer > 0 { + cfg.PLEScale = float32(math.Sqrt(float64(cfg.HiddenSizePerLayer))) + cfg.PLEProjScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize))) + cfg.PLECombineScale = float32(math.Pow(2.0, -0.5)) + } + cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize))) + if cfg.FinalLogitSoftcapping > 0 { + cfg.SoftcapInv = 1.0 / cfg.FinalLogitSoftcapping + } + + // Compute KV sharing map. 
+ cfg.KVShareMap = make(map[int32]int32) + cfg.KVDonors = make(map[int32]bool) + if cfg.NumKVSharedLayers > 0 && len(cfg.LayerTypes) > 0 { + firstShared := cfg.NumHiddenLayers - cfg.NumKVSharedLayers + prevLayers := cfg.LayerTypes[:firstShared] + + for i := firstShared; i < cfg.NumHiddenLayers; i++ { + layerType := cfg.LayerTypes[i] + // Find the last non-shared layer of the same type. + donor := int32(-1) + for j := len(prevLayers) - 1; j >= 0; j-- { + if prevLayers[j] == layerType { + donor = int32(j) + break + } + } + if donor >= 0 { + cfg.KVShareMap[i] = donor + cfg.KVDonors[donor] = true + } + } + } + + return cfg, nil +} + +func (m *Model) EnableCompile() bool { + return true +} + +func resolveWeightPrefix(tensors map[string]*mlx.Array) string { + for _, prefix := range []string{"", "language_model.", "model.language_model."} { + if tensors[prefix+"embed_tokens.weight"] != nil { + return prefix + } + } + // Also try with "model." before the layer path. + for _, prefix := range []string{"", "language_model.", "model.language_model."} { + if tensors[prefix+"model.embed_tokens.weight"] != nil { + return prefix + "model." + } + } + return "" +} + +func isLayerSliding(layerIdx int32, cfg *TextConfig) bool { + if len(cfg.LayerTypes) > 0 && int(layerIdx) < len(cfg.LayerTypes) { + return cfg.LayerTypes[layerIdx] == "sliding_attention" + } + if cfg.SlidingWindowPattern <= 0 { + return false + } + return (layerIdx+1)%cfg.SlidingWindowPattern != 0 +} + +// precomputeGemmaScaledWeights assigns raw norm weights to the *Scaled fields. +// Gemma 4 uses scale_shift=0.0 for all norms (no +1.0 adjustment), so the +// precomputed weights are just the raw weights from the model. 
+func precomputeGemmaScaledWeights(m *Model) { + if m.Norm != nil { + m.NormScaled = m.Norm.Weight + } + + if m.PerLayerProjNorm != nil { + m.PerLayerProjNormWeight = m.PerLayerProjNorm.Weight + } + + for _, layer := range m.Layers { + if layer == nil || layer.Attention == nil { + continue + } + + if layer.InputNorm != nil { + layer.InputNormScaled = layer.InputNorm.Weight + } + if layer.PostAttnNorm != nil { + layer.PostAttnNormScaled = layer.PostAttnNorm.Weight + } + if layer.PreFFNorm != nil { + layer.PreFFNormScaled = layer.PreFFNorm.Weight + } + if layer.PostFFNorm != nil { + layer.PostFFNormScaled = layer.PostFFNorm.Weight + } + if layer.Attention.QNorm != nil { + layer.Attention.QNormScaled = layer.Attention.QNorm.Weight + } + if layer.Attention.KNorm != nil { + layer.Attention.KNormScaled = layer.Attention.KNorm.Weight + } + if layer.PLE != nil && layer.PLE.PostNorm != nil { + layer.PLE.PostNormScaled = layer.PLE.PostNorm.Weight + } + if layer.PostFFNorm1 != nil { + layer.PostFFNorm1Scaled = layer.PostFFNorm1.Weight + } + if layer.PostFFNorm2 != nil { + layer.PostFFNorm2Scaled = layer.PostFFNorm2.Weight + } + if layer.PreFFNorm2 != nil { + layer.PreFFNorm2Scaled = layer.PreFFNorm2.Weight + } + } +} + +func newModel(root *model.Root) (base.Model, error) { + configData, err := root.Manifest.ReadConfig("config.json") + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + + cfg, err := parseTextConfig(configData) + if err != nil { + return nil, err + } + + if qt := root.QuantType(); qt != "" { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt) + if gs := root.GroupSize(); gs > 0 { + cfg.QuantGroupSize = gs + } + } else { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("") + } + cfg.TensorQuant = root.AllTensorQuant() + + tokData, err := root.Manifest.ReadConfig("tokenizer.json") + if err != nil { + return nil, fmt.Errorf("load tokenizer config: %w", err) + } + + tokConfig := 
&tokenizer.TokenizerConfig{ConfigJSON: configData} + if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil { + tokConfig.GenerationConfigJSON = genConfigData + } + if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil { + tokConfig.TokenizerConfigJSON = tokConfigData + } + + tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig) + if err != nil { + return nil, fmt.Errorf("parse tokenizer: %w", err) + } + + m := &Model{ + Layers: make([]*DecoderLayer, cfg.NumHiddenLayers), + TextConfig: &cfg, + tok: tok, + } + + for i := range m.Layers { + donor, isShared := cfg.KVShareMap[int32(i)] + if !isShared { + donor = -1 + } + m.Layers[i] = &DecoderLayer{ + LayerIdx: int32(i), + IsSliding: isLayerSliding(int32(i), m.TextConfig), + KVShareDonor: donor, + IsDonor: cfg.KVDonors[int32(i)], + } + } + + return m, nil +} + +// LoadWeights receives all tensors loaded from the manifest and assigns them +// to model fields. +func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { + m.weightPrefix = resolveWeightPrefix(tensors) + prefix := m.weightPrefix + linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) + + // Embeddings. + embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) + if embedTokens == nil { + return fmt.Errorf("missing embedding weight: %sembed_tokens.weight", prefix) + } + m.EmbedTokens = embedTokens + + // Final norm. + normWeight := tensors[prefix+"norm.weight"] + if normWeight == nil { + return fmt.Errorf("missing final norm weight: %snorm.weight", prefix) + } + m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps) + + // LM head. + if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil { + m.LMHead = lmHead + } else if lmHead := linears.Make("lm_head"); lmHead != nil { + m.LMHead = lmHead + } else { + // Gemma 4 ties output projection to embeddings. 
+ m.LMHead = m.EmbedTokens.AsLinear() + } + + // PLE model-level weights. + if m.HiddenSizePerLayer > 0 { + pleEmbed := model.MakeEmbeddingLayer(tensors, prefix+"embed_tokens_per_layer", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) + if pleEmbed == nil { + return fmt.Errorf("missing PLE embedding weight") + } + m.EmbedTokensPerLayer = pleEmbed + + m.PerLayerModelProj = linears.Make(prefix + "per_layer_model_projection") + if m.PerLayerModelProj == nil { + return fmt.Errorf("missing per_layer_model_projection weight") + } + + pleProjNormWeight := tensors[prefix+"per_layer_projection_norm.weight"] + if pleProjNormWeight == nil { + return fmt.Errorf("missing per_layer_projection_norm weight") + } + m.PerLayerProjNorm = nn.NewRMSNorm(pleProjNormWeight, m.RMSNormEps) + } + + // Decoder layers. + for i := range m.NumHiddenLayers { + layerPrefix := fmt.Sprintf("%slayers.%d", prefix, i) + isSliding := isLayerSliding(i, m.TextConfig) + + donor, isShared := m.KVShareMap[i] + if !isShared { + donor = -1 + } + + layer := &DecoderLayer{ + LayerIdx: i, + IsSliding: isSliding, + KVShareDonor: donor, + IsDonor: m.KVDonors[i], + Attention: &Attention{}, + MLP: &MLP{}, + } + + // Norms. + if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil { + layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil { + layer.PostAttnNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil { + layer.PreFFNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil { + layer.PostFFNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + + // Attention projections. 
+ layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj") + layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj") + layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj") + layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj") + + if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil { + layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil { + layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + + // MLP. + layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj") + layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj") + layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj") + + // Layer scalar (all layers in new weights, was full-attention only in earlier releases). + if w := tensors[layerPrefix+".layer_scalar"]; w != nil { + layer.LayerScalar = w + } + + // MoE components. + if m.EnableMoeBlock { + // Router. + routerProj := linears.Make(layerPrefix + ".router.proj") + // Raw safetensors uses ".router.scale"; runner.go remaps to "_scale". + routerScale := tensors[layerPrefix+".router.scale"] + if routerScale == nil { + routerScale = tensors[layerPrefix+".router_scale"] + } + if routerProj == nil || routerScale == nil { + return fmt.Errorf("layer %d: missing router weights", i) + } + layer.Router = &Router{ + Proj: routerProj, + Scale: routerScale, + } + + // MoE expert weights. Try pre-stacked (BF16 from HF) first, + // then per-expert (from quantized create path). + perExpertScale := tensors[layerPrefix+".router.per_expert_scale"] + if perExpertScale == nil { + perExpertScale = tensors[layerPrefix+".moe.per_expert_scale"] + } + if perExpertScale == nil { + return fmt.Errorf("layer %d: missing MoE per_expert_scale", i) + } + + moe := &MoEBlock{PerExpertScale: perExpertScale} + + // Check for pre-stacked tensors (unquantized HF format). + // Try .experts. 
first (new weight drop), fall back to .moe. (old format). + gateUpW := tensors[layerPrefix+".experts.gate_up_proj"] + if gateUpW == nil { + gateUpW = tensors[layerPrefix+".moe.gate_up_proj"] + } + gateW := tensors[layerPrefix+".experts.gate_proj"] + if gateW == nil { + gateW = tensors[layerPrefix+".moe.gate_proj"] + } + if gateUpW != nil { + // Fused gate+up: split along dim 1, transpose for GatherMM. + dims := gateUpW.Dims() + half := int32(dims[1] / 2) + gateSlice := sliceAxis1(gateUpW, 0, half) + upSlice := sliceAxis1(gateUpW, half, int32(dims[1])) + moe.GateWeight = transposeForGatherMM(gateSlice) + moe.UpWeight = transposeForGatherMM(upSlice) + downW := tensors[layerPrefix+".experts.down_proj"] + if downW == nil { + downW = tensors[layerPrefix+".moe.down_proj"] + } + if downW == nil { + return fmt.Errorf("layer %d: missing MoE down_proj with fused gate_up_proj", i) + } + moe.DownWeight = transposeForGatherMM(downW) + } else if gateW != nil { + // Separate gate_proj and up_proj (older format). Transpose for GatherMM. + moe.GateWeight = transposeForGatherMM(gateW) + upW := tensors[layerPrefix+".experts.up_proj"] + if upW == nil { + upW = tensors[layerPrefix+".moe.up_proj"] + } + downW := tensors[layerPrefix+".experts.down_proj"] + if downW == nil { + downW = tensors[layerPrefix+".moe.down_proj"] + } + moe.UpWeight = transposeForGatherMM(upW) + moe.DownWeight = transposeForGatherMM(downW) + if moe.UpWeight == nil || moe.DownWeight == nil { + return fmt.Errorf("layer %d: incomplete pre-stacked MoE weights", i) + } + } else if switchGateUp := firstNonNil(tensors, + layerPrefix+".moe.switch_mlp.gate_up_proj.weight", + layerPrefix+".moe.switch_mlp.gate_up_proj"); switchGateUp != nil { + // Stacked switch_mlp format (from create pipeline with expert packing). 
+ switchDown := firstNonNil(tensors, + layerPrefix+".moe.switch_mlp.down_proj.weight", + layerPrefix+".moe.switch_mlp.down_proj") + if switchDown == nil { + return fmt.Errorf("layer %d: missing switch_mlp down_proj", i) + } + + // Check for quantized weights (scales present). + // The scale key depends on whether the tensor has .weight suffix. + gateUpKey := layerPrefix + ".moe.switch_mlp.gate_up_proj.weight" + if tensors[gateUpKey] == nil { + gateUpKey = layerPrefix + ".moe.switch_mlp.gate_up_proj" + } + downKey := layerPrefix + ".moe.switch_mlp.down_proj.weight" + if tensors[downKey] == nil { + downKey = layerPrefix + ".moe.switch_mlp.down_proj" + } + gateUpScales := firstNonNil(tensors, gateUpKey+"_scale", gateUpKey+".scale") + downScales := firstNonNil(tensors, downKey+"_scale", downKey+".scale") + + if gateUpScales != nil && downScales != nil { + // Quantized: keep fused gate_up as single tensor for GatherQMM. + // One fused call instead of two separate gate+up calls. + gateUpBiases := firstNonNil(tensors, gateUpKey+"_qbias", gateUpKey+".bias") + downBiases := firstNonNil(tensors, downKey+"_qbias", downKey+".bias") + + moe.GateUpWeightQ = switchGateUp + moe.GateUpScales = gateUpScales + moe.GateUpBiases = gateUpBiases + moe.DownWeightQ = switchDown + moe.DownScales = downScales + if downBiases != nil { + moe.DownBiases = downBiases + } + + groupSize, bits, mode := model.ResolveLinearQuantParams( + m.QuantGroupSize, m.QuantBits, m.QuantMode, + m.TensorQuant, gateUpKey, switchGateUp, gateUpScales, + ) + moe.UseQuantized = true + moe.UseFusedGateUp = true + moe.GateUpGroupSize = groupSize + moe.GateUpBits = bits + moe.QuantMode = mode + + dGroupSize, dBits, dMode := model.ResolveLinearQuantParams( + m.QuantGroupSize, m.QuantBits, m.QuantMode, + m.TensorQuant, downKey, switchDown, downScales, + ) + moe.DownGroupSize = dGroupSize + moe.DownBits = dBits + moe.DownQuantMode = dMode + } else { + // Unquantized switch_mlp: keep fused and transpose for GatherMM. 
+ moe.GateUpWeight = transposeForGatherMM(switchGateUp) + moe.UseFusedGateUp = true + moe.DownWeight = transposeForGatherMM(switchDown) + } + } else { + // Per-expert tensors (from create path). + // Try separate gate_proj/up_proj first, then fused gate_up_proj. + gateStacked := collectExpertProjection(tensors, m.TextConfig, + layerPrefix+".moe.experts", "gate_proj", m.NumExperts) + upStacked := collectExpertProjection(tensors, m.TextConfig, + layerPrefix+".moe.experts", "up_proj", m.NumExperts) + downStacked := collectExpertProjection(tensors, m.TextConfig, + layerPrefix+".moe.experts", "down_proj", m.NumExperts) + + if gateStacked == nil && upStacked == nil { + // Try fused gate_up_proj format — split along axis 1 (out-dim). + // For quantized weights, also split scales and biases. + gateUpStacked := collectExpertProjection(tensors, m.TextConfig, + layerPrefix+".moe.experts", "gate_up_proj", m.NumExperts) + if gateUpStacked != nil { + dims := gateUpStacked.Weight.Dims() + if len(dims) >= 2 { + mid := int32(dims[1] / 2) + gateStacked = &stackedExpertResult{ + Weight: sliceAxis1(gateUpStacked.Weight, 0, mid), + Bits: gateUpStacked.Bits, + GroupSize: gateUpStacked.GroupSize, + Mode: gateUpStacked.Mode, + } + upStacked = &stackedExpertResult{ + Weight: sliceAxis1(gateUpStacked.Weight, mid, int32(dims[1])), + Bits: gateUpStacked.Bits, + GroupSize: gateUpStacked.GroupSize, + Mode: gateUpStacked.Mode, + } + if gateUpStacked.Scales != nil { + sDims := gateUpStacked.Scales.Dims() + sMid := int32(sDims[1] / 2) + gateStacked.Scales = sliceAxis1(gateUpStacked.Scales, 0, sMid) + upStacked.Scales = sliceAxis1(gateUpStacked.Scales, sMid, int32(sDims[1])) + } + if gateUpStacked.Biases != nil { + bDims := gateUpStacked.Biases.Dims() + bMid := int32(bDims[1] / 2) + gateStacked.Biases = sliceAxis1(gateUpStacked.Biases, 0, bMid) + upStacked.Biases = sliceAxis1(gateUpStacked.Biases, bMid, int32(bDims[1])) + } + } + } + } + + if gateStacked == nil || upStacked == nil || downStacked == 
nil { + return fmt.Errorf("layer %d: missing MoE expert weights", i) + } + // Use GatherQMM if all projections have quantized weights. + if gateStacked.Scales != nil && upStacked.Scales != nil && downStacked.Scales != nil { + moe.GateWeightQ = gateStacked.Weight + moe.GateScales = gateStacked.Scales + moe.GateBiases = gateStacked.Biases + moe.UpWeightQ = upStacked.Weight + moe.UpScales = upStacked.Scales + moe.UpBiases = upStacked.Biases + moe.DownWeightQ = downStacked.Weight + moe.DownScales = downStacked.Scales + moe.DownBiases = downStacked.Biases + moe.UseQuantized = true + moe.GateGroupSize = gateStacked.GroupSize + moe.GateBits = gateStacked.Bits + moe.UpGroupSize = upStacked.GroupSize + moe.UpBits = upStacked.Bits + moe.DownGroupSize = downStacked.GroupSize + moe.DownBits = downStacked.Bits + moe.QuantMode = gateStacked.Mode + moe.DownQuantMode = downStacked.Mode + } else { + // Unquantized: transpose for GatherMM (expects [experts, in, out]). + moe.GateWeight = transposeForGatherMM(gateStacked.Weight) + moe.UpWeight = transposeForGatherMM(upStacked.Weight) + moe.DownWeight = transposeForGatherMM(downStacked.Weight) + } + } + layer.MoE = moe + + // Additional norms for MoE dual-path. + if w := tensors[layerPrefix+".post_feedforward_layernorm_1.weight"]; w != nil { + layer.PostFFNorm1 = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".post_feedforward_layernorm_2.weight"]; w != nil { + layer.PostFFNorm2 = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".pre_feedforward_layernorm_2.weight"]; w != nil { + layer.PreFFNorm2 = nn.NewRMSNorm(w, m.RMSNormEps) + } + + if layer.PostFFNorm1 == nil || layer.PostFFNorm2 == nil || layer.PreFFNorm2 == nil { + return fmt.Errorf("layer %d: missing MoE norm weights", i) + } + } + + // PLE per-layer weights. 
+ if m.HiddenSizePerLayer > 0 { + layer.PLE = &PLELayer{} + layer.PLE.InputGate = linears.Make(layerPrefix + ".per_layer_input_gate") + layer.PLE.Projection = linears.Make(layerPrefix + ".per_layer_projection") + if w := tensors[layerPrefix+".post_per_layer_input_norm.weight"]; w != nil { + layer.PLE.PostNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + + if layer.PLE.InputGate == nil || layer.PLE.Projection == nil || layer.PLE.PostNorm == nil { + return fmt.Errorf("layer %d: missing PLE weights", i) + } + } + + // Validation. + if layer.InputNorm == nil { + return fmt.Errorf("layer %d: missing input_layernorm", i) + } + if layer.PostAttnNorm == nil { + return fmt.Errorf("layer %d: missing post_attention_layernorm", i) + } + if layer.PreFFNorm == nil { + return fmt.Errorf("layer %d: missing pre_feedforward_layernorm", i) + } + if layer.PostFFNorm == nil { + return fmt.Errorf("layer %d: missing post_feedforward_layernorm", i) + } + if layer.Attention.QProj == nil || layer.Attention.OProj == nil { + return fmt.Errorf("layer %d: missing attention q/o projections", i) + } + if layer.Attention.KProj == nil { + return fmt.Errorf("layer %d: missing attention k projection", i) + } + // VProj is nil for K=V full-attention layers (value_states = key_states). 
+ useAltAttn := m.AttentionKEqV && !isSliding + if layer.Attention.VProj == nil && !useAltAttn { + return fmt.Errorf("layer %d: missing attention v projection", i) + } + if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil { + return fmt.Errorf("layer %d: missing attention q/k norms", i) + } + if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil { + return fmt.Errorf("layer %d: missing mlp projections", i) + } + + m.Layers[i] = layer + } + + precomputeGemmaScaledWeights(m) + if m.NormScaled == nil { + return fmt.Errorf("missing precomputed final norm weight") + } + + return nil +} + +func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + dims := tokens.Dims() + B, L := int32(dims[0]), int32(dims[1]) + h := m.EmbedTokens.Forward(tokens) + h = mlx.MulScalar(h, m.EmbedScale) + + // Compute PLE inputs if configured. + var perLayerInputs *mlx.Array + if m.HiddenSizePerLayer > 0 && m.EmbedTokensPerLayer != nil { + perLayerInputs = m.computePLEInputs(tokens, h) + } + + var sharedKV map[int32]sharedKVEntry + if len(m.KVShareMap) > 0 { + sharedKV = make(map[int32]sharedKVEntry) + } + + // Per-forward-pass sliding-window mask cache. The first sliding layer + // populates it; every subsequent sliding layer reuses the cached mask. + // Before this, Attention.Forward rebuilt a ~L×kLen mask from scratch on + // the CPU for every sliding layer (~25 rebuilds on a 26B MoE prefill), + // which was the bulk of a ~28% prefill regression vs the pre-SWA-fix + // baseline. + var smc *slidingMaskCache + if L > 1 && m.SlidingWindow > 0 { + smc = &slidingMaskCache{} + } + + for i, layer := range m.Layers { + var c cache.Cache + if caches != nil && i < len(caches) { + c = caches[i] + } + + // Extract per-layer PLE input for this layer. 
+ var pleInput *mlx.Array + if perLayerInputs != nil { + pleInput = sliceLayerDim(perLayerInputs, int32(i), B, L, m.HiddenSizePerLayer) + } + + // Get shared KV for this layer if it's a shared layer. + var donorEntry *sharedKVEntry + if layer.KVShareDonor >= 0 { + if entry, ok := sharedKV[layer.KVShareDonor]; ok { + donorEntry = &entry + } + } + + h = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc) + + // If this layer is a donor, store its cached KV for later shared layers. + if layer.IsDonor && c != nil { + state := c.State() + if len(state) >= 2 && state[0] != nil && state[1] != nil { + sharedKV[layer.LayerIdx] = sharedKVEntry{K: state[0], V: state[1], Offset: c.Offset()} + } + } + } + + return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps) +} + +// slidingMaskCache is a per-forward-pass cache for the sliding-window +// additive mask used by Attention.Forward. All sliding-attention layers in +// a single forward pass see the same (L, kLen, window, dtype) tuple, so +// we can build the mask once on the first sliding layer's call and let +// every subsequent layer reuse it. The cache is instantiated fresh at the +// top of Model.Forward and passed through DecoderLayer → Attention; it is +// nil on paths where no caching is wanted (e.g. L == 1 decode). +type slidingMaskCache struct { + mask *mlx.Array + L int32 + kLen int32 + window int32 +} + +// get returns a cached mask matching (L, kLen, window, dtype), or builds +// and caches a new one. Safe to call with a nil receiver (falls through to +// a direct build without caching). 
+func (c *slidingMaskCache) get(L, kLen, window int32, dtype mlx.DType) *mlx.Array { + if c == nil { + return buildCausalMaskWindow(L, kLen, window).AsType(dtype) + } + if c.mask != nil && c.L == L && c.kLen == kLen && c.window == window { + return c.mask + } + c.mask = buildCausalMaskWindow(L, kLen, window).AsType(dtype) + c.L = L + c.kLen = kLen + c.window = window + return c.mask +} + +func (m *Model) Unembed(x *mlx.Array) *mlx.Array { + logits := m.LMHead.Forward(x) + + if m.FinalLogitSoftcapping > 0 { + logits = mlx.MulScalar(logits, m.SoftcapInv) + logits = logits.Tanh() + logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping) + } + + return logits +} + +func (m *Model) NumLayers() int { + return len(m.Layers) +} + +func (m *Model) MaxContextLength() int { + return int(m.MaxPositionEmbeddings) +} + +func (m *Model) Tokenizer() *tokenizer.Tokenizer { + return m.tok +} + +// NewCaches creates cache objects for layers that own KV state. +func (m *Model) NewCaches() []cache.Cache { + cacheLayers := len(m.Layers) + for i, layer := range m.Layers { + if layer.KVShareDonor >= 0 { + cacheLayers = i + break + } + } + + caches := make([]cache.Cache, cacheLayers) + for i, layer := range m.Layers[:cacheLayers] { + if m.SlidingWindow > 0 && layer.IsSliding { + caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow)) + } else { + caches[i] = cache.NewKVCache() + } + } + return caches +} + +// computePLEInputs computes per-layer embeddings and projections. +// Returns shape [B, L, NumHiddenLayers, HiddenSizePerLayer]. 
+func (m *Model) computePLEInputs(tokens, h *mlx.Array) *mlx.Array { + dims := tokens.Dims() + B, L := int32(dims[0]), int32(dims[1]) + pleScale := m.PLEScale + projScale := m.PLEProjScale + + // Token-based per-layer embeddings: [B, L, NumLayers*PLEDim] + pleEmb := m.EmbedTokensPerLayer.Forward(tokens) + pleEmb = mlx.MulScalar(pleEmb, pleScale) + // Reshape to [B, L, NumLayers, PLEDim] + pleEmb = mlx.Reshape(pleEmb, B, L, m.NumHiddenLayers, m.HiddenSizePerLayer) + + // Hidden-state projection: [B, L, NumLayers*PLEDim] + pleProj := m.PerLayerModelProj.Forward(h) + pleProj = mlx.MulScalar(pleProj, projScale) + // Reshape to [B, L, NumLayers, PLEDim] + pleProj = mlx.Reshape(pleProj, B, L, m.NumHiddenLayers, m.HiddenSizePerLayer) + + // Apply per-layer projection norm (scale_shift=0.0, uses raw weight). + pleProj = mlx.RMSNormFn(pleProj, m.PerLayerProjNormWeight, m.RMSNormEps) + + // Combine: (proj + emb) * 2^(-0.5) + combined := mlx.Add(pleProj, pleEmb) + combined = mlx.MulScalar(combined, m.PLECombineScale) + + return combined +} + +// sliceLayerDim extracts a single layer's PLE input from the combined tensor. +// Input shape: [B, L, NumLayers, PLEDim], output shape: [B, L, PLEDim]. +func sliceLayerDim(combined *mlx.Array, layerIdx, B, L, pleDim int32) *mlx.Array { + sliced := mlx.SliceStartStop(combined, + []int32{0, 0, layerIdx, 0}, + []int32{B, L, layerIdx + 1, pleDim}, + ) + return mlx.Squeeze(sliced, 2) +} + +func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array { + normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps) + attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache) + attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) + h := mlx.Add(x, attnOut) + + if l.Router != nil && l.MoE != nil { + // Dual-path: dense MLP + MoE, both normed separately, then combined. 
+ residual := h + + // Path 1: Dense MLP. + normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps) + mlpOut := l.MLP.Forward(normed) + mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNorm1Scaled, cfg.RMSNormEps) + + // Path 2: MoE. + scores, inds := l.Router.Forward(h, cfg) + normed2 := mlx.RMSNormFn(h, l.PreFFNorm2Scaled, cfg.RMSNormEps) + moeOut := l.MoE.Forward(normed2, scores, inds, cfg) + moeOut = mlx.RMSNormFn(moeOut, l.PostFFNorm2Scaled, cfg.RMSNormEps) + + // Combine and apply outer post-norm. + combined := mlx.Add(mlpOut, moeOut) + combined = mlx.RMSNormFn(combined, l.PostFFNormScaled, cfg.RMSNormEps) + h = mlx.Add(residual, combined) + } else { + // Standard single MLP path. + normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps) + mlpOut := l.MLP.Forward(normed) + mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) + h = mlx.Add(h, mlpOut) + } + + // PLE injection (after MLP residual). + if l.PLE != nil && pleInput != nil { + residual := h + gate := mlx.GELUApprox(l.PLE.InputGate.Forward(h)) + gated := mlx.Mul(gate, pleInput) + projected := l.PLE.Projection.Forward(gated) + projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps) + h = mlx.Add(residual, projected) + } + + // Layer scalar for full-attention layers. + if l.LayerScalar != nil { + h = mlx.Mul(h, l.LayerScalar) + } + + return h +} + +func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array { + // Determine head dim and scale based on layer type. + headDim := cfg.HeadDim + scale := cfg.SlidingScale + ropeDims := cfg.SlidingRopeDims + ropeBase := cfg.SlidingRopeBase + if !isSliding { + headDim = cfg.GlobalHeadDim + scale = cfg.FullScale + ropeDims = cfg.FullRopeDims + ropeBase = cfg.FullRopeBase + } + + q := a.QProj.Forward(x) + q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim) + q = mlx.Transpose(q, 0, 2, 1, 3) + + // Apply Q norm. 
+ q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps) + + // RoPE offset: use cache offset for non-shared layers, donor offset for shared. + offset := 0 + if donorEntry != nil { + offset = donorEntry.Offset - int(L) + } else if c != nil { + offset = c.Offset() + } + var ropeFreqs *mlx.Array + if !isSliding { + ropeFreqs = cfg.FullRopeFreqs + } + q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs) + + var k, v *mlx.Array + + if donorEntry != nil { + // Shared layer: use donor's cached K/V. + k = donorEntry.K + v = donorEntry.V + } else { + // Determine KV head count: K=V full-attention layers use NumGlobalKeyValueHeads. + kvHeads := cfg.NumKeyValueHeads + if a.VProj == nil && !isSliding && cfg.NumGlobalKeyValueHeads > 0 { + kvHeads = cfg.NumGlobalKeyValueHeads + } + + // Non-shared layer: compute K/V. + k = a.KProj.Forward(x) + k = mlx.Reshape(k, B, L, kvHeads, headDim) + k = mlx.Transpose(k, 0, 2, 1, 3) + + if a.VProj != nil { + v = a.VProj.Forward(x) + v = mlx.Reshape(v, B, L, kvHeads, headDim) + v = mlx.Transpose(v, 0, 2, 1, 3) + } else { + // K=V: value_states = key_states (raw, before k_norm/rope). + v = k + } + + // Apply K norm. + k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps) + + // Apply RoPE to K. + k = mlx.RoPEWithFreqs(k, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs) + + // Apply V norm (no learnable weight, pure RMS normalization). + v = mlx.RMSNormFn(v, nil, cfg.RMSNormEps) + + // Update cache. + if c != nil { + k, v = c.Update(k, v) + } + } + + // Sliding-window layers must restrict attention to the last `window` keys + // during prefill. The rotating KV cache handles decode, but for L > 1 the + // cache holds all prefix keys so we need an explicit mask. Full-attention + // layers pass window=0 (plain causal). 
+ var window int32 + if isSliding && L > 1 && cfg.SlidingWindow > 0 { + window = cfg.SlidingWindow + } + + var out *mlx.Array + switch { + case headDim > 128 && L > 1 && !mlx.MetalIsAvailable(): + // Manual attention for CUDA prefill with head_dim > 128. + // cuDNN SDPA requires head_dim <= 128, and the MLX CUDA SDPA vector + // kernel only handles L < 4 (generation). For prefill, we fall back + // to explicit matmul+softmax+matmul on CUDA. + kvHeads := int32(k.Dim(1)) + nRepeats := cfg.NumAttentionHeads / kvHeads + kLen := int32(k.Dim(2)) + + q = mlx.MulScalar(q, scale) + q = mlx.Reshape(q, B, kvHeads, nRepeats, L, headDim) + k = mlx.Reshape(k, B, kvHeads, 1, kLen, headDim) + v = mlx.Reshape(v, B, kvHeads, 1, kLen, headDim) + + kT := mlx.Transpose(k, 0, 1, 2, 4, 3) + scores := mlx.Matmul(q, kT) + mask := buildCausalMaskWindow(L, kLen, window) + scores = mlx.Add(scores, mask) + scores = mlx.SoftmaxAxis(scores, -1, true) + out = mlx.Matmul(scores, v) + out = mlx.Reshape(out, B, cfg.NumAttentionHeads, L, headDim) + case window > 0: + // Sliding-window prefill: fast SDPA "causal" mode doesn't take a + // window, so supply an explicit additive mask. All sliding layers + // in one forward pass see the same (L, kLen, window, dtype) tuple, + // so we memoize the mask on the first call and reuse it for every + // subsequent sliding layer via slidingMaskCache. + kLen := int32(k.Dim(2)) + mask := slidingMaskCache.get(L, kLen, window, q.DType()) + out = mlx.ScaledDotProductAttentionMasked(q, k, v, scale, mask) + default: + out = mlx.ScaledDotProductAttentionCausal(q, k, v, scale, L > 1) + } + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*headDim) + if !mlx.MetalIsAvailable() { + // Force contiguous layout before OProj on CUDA where matmul handles + // strided views differently. Metal handles them natively. 
+ out = mlx.Contiguous(out, false) + } + return a.OProj.Forward(out) +} + +func (m *MLP) Forward(x *mlx.Array) *mlx.Array { + gate := mlx.GELUApprox(m.GateProj.Forward(x)) + up := m.UpProj.Forward(x) + return m.DownProj.Forward(mlx.Mul(gate, up)) +} + +// Forward runs the router to select top-k experts per token. +// Returns (scores [B*L, topK], indices [B*L, topK]). +func (r *Router) Forward(x *mlx.Array, cfg *TextConfig) (*mlx.Array, *mlx.Array) { + dims := x.Dims() + BL := int32(dims[0]) * int32(dims[1]) + + // Flatten to [B*L, hidden]. + xFlat := mlx.Reshape(x, BL, cfg.HiddenSize) + + // Norm (no weight) -> scale by 1/sqrt(hidden_size) -> multiply by learnable scale. + normed := mlx.RMSNormFn(xFlat, nil, cfg.RMSNormEps) + normed = mlx.MulScalar(normed, cfg.RouterScale) + normed = mlx.Mul(normed, r.Scale) + + // Project to expert scores: [B*L, num_experts]. + expertScores := r.Proj.Forward(normed) + + // Top-k selection via argpartition on negated scores. + neg := mlx.Neg(expertScores) + inds := mlx.Argpartition(neg, int(cfg.TopKExperts)-1, -1) + inds = mlx.SliceStartStop(inds, + []int32{0, 0}, + []int32{BL, cfg.TopKExperts}, + ) + + // Softmax only over selected logits. This is equivalent to full softmax + + // gather + renormalize, but avoids normalizing over every expert. + scores := mlx.TakeAlongAxis(expertScores, inds, -1) + scores = mlx.SoftmaxAxis(scores, -1, true) // [B*L, topK] + + return scores, inds +} + +// Forward runs the MoE block using GatherQMM (quantized) or GatherMM (dense). +// scores: [B*L, topK], inds: [B*L, topK], x: [B, L, hidden]. +func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfig) *mlx.Array { + dims := x.Dims() + B, L := int32(dims[0]), int32(dims[1]) + topK := cfg.TopKExperts + + // Flatten and prepare for expert dispatch. + xFlat := mlx.Reshape(x, B*L, 1, 1, cfg.HiddenSize) + idxFlat := mlx.Reshape(inds, B*L, topK) + + // Sort indices for efficiency when batch is large enough. 
+ // The sorted_indices flag tells GatherQMM the indices are pre-sorted, + // enabling coalesced memory access. Testing confirmed the sort is + // beneficial for prefill (2x faster with sort at 2048 tokens). + doSort := B*L >= 64 + var invOrder *mlx.Array + n := B * L * topK + + if doSort { + idxAll := mlx.Flatten(idxFlat) + order := mlx.Argsort(idxAll, 0) + invOrder = mlx.Argsort(order, 0) + xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1) + idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1) + } + + // Expert computation: gate+up followed by GELU and down. + // When gate+up are fused, we do 2 GatherQMM calls instead of 3. + var hidden, down *mlx.Array + if m.UseQuantized { + if m.UseFusedGateUp { + // Fused gate+up: single GatherQMM produces [B*L*topK, 1, 1, 2*intermediate] + gateUp := mlx.GatherQMM(xFlat, m.GateUpWeightQ, m.GateUpScales, m.GateUpBiases, + nil, idxFlat, true, m.GateUpGroupSize, m.GateUpBits, m.QuantMode, doSort) + // Split along last dim into gate and up + guDims := gateUp.Dims() + mid := int32(guDims[len(guDims)-1] / 2) + gate := mlx.SliceStartStop(gateUp, + []int32{0, 0, 0, 0}, + []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), mid}) + up := mlx.SliceStartStop(gateUp, + []int32{0, 0, 0, mid}, + []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])}) + hidden = mlx.Mul(mlx.GELUApprox(gate), up) + } else { + gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases, + nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort) + up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases, + nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort) + hidden = mlx.Mul(mlx.GELUApprox(gate), up) + } + downMode := m.DownQuantMode + if downMode == "" { + downMode = m.QuantMode + } + down = mlx.GatherQMM(hidden, m.DownWeightQ, m.DownScales, m.DownBiases, + nil, idxFlat, true, m.DownGroupSize, m.DownBits, downMode, 
doSort) + } else { + if m.UseFusedGateUp && m.GateUpWeight != nil { + gateUp := mlx.GatherMM(xFlat, m.GateUpWeight, nil, idxFlat, doSort) + guDims := gateUp.Dims() + mid := int32(guDims[len(guDims)-1] / 2) + gate := mlx.SliceStartStop(gateUp, + []int32{0, 0, 0, 0}, + []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), mid}) + up := mlx.SliceStartStop(gateUp, + []int32{0, 0, 0, mid}, + []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])}) + hidden = mlx.Mul(mlx.GELUApprox(gate), up) + } else { + gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort) + up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort) + hidden = mlx.Mul(mlx.GELUApprox(gate), up) + } + down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort) + } + + // Unsort if needed. + if doSort { + down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize) + } else { + down = mlx.Squeeze(down, 2) + } + + // Reshape to [B*L, topK, hidden_size]. + down = mlx.Reshape(down, B*L, topK, cfg.HiddenSize) + + // Gather per-expert scales at selected indices: flatten inds, take, reshape back. + indsFlat := mlx.Reshape(inds, B*L*topK) + expertScales := mlx.Take(m.PerExpertScale, indsFlat, 0) // [B*L*topK] + expertScales = mlx.Reshape(expertScales, B*L, topK) // [B*L, topK] + down = mlx.Mul(down, mlx.ExpandDims(expertScales, -1)) + + // Weight by dispatch scores and sum across experts (axis 1 = topK dim). 
+ y := mlx.Sum(mlx.Mul(down, mlx.ExpandDims(scores, -1)), 1, false) // [B*L, hidden_size] + + return mlx.Reshape(y, B, L, cfg.HiddenSize) +} diff --git a/x/models/gemma4/gemma4_moe_test.go b/x/models/gemma4/gemma4_moe_test.go new file mode 100644 index 000000000..ab390ae59 --- /dev/null +++ b/x/models/gemma4/gemma4_moe_test.go @@ -0,0 +1,228 @@ +package gemma4 + +import ( + "testing" + + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +// onesLike creates a tensor of the given shape filled with a small constant. +func onesLike(shape ...int) *mlx.Array { + return mlx.AddScalar(mlx.Zeros(mlx.DTypeBFloat16, shape...), 0.01) +} + +func TestMoEForward(t *testing.T) { + skipIfNoMLX(t) + + // Small config matching 26b architecture pattern. + cfg := &TextConfig{ + HiddenSize: 16, // tiny for testing + NumAttentionHeads: 2, + NumKeyValueHeads: 1, + NumGlobalKeyValueHeads: 1, + HeadDim: 8, + GlobalHeadDim: 8, + NumExperts: 4, + TopKExperts: 2, + ExpertIntermediateSize: 8, + EnableMoeBlock: true, + AttentionKEqV: false, + RMSNormEps: 1e-6, + SlidingScale: 1.0, + FullScale: 1.0, + } + + B, L := int32(1), int32(3) + x := onesLike(int(B), int(L), int(cfg.HiddenSize)) + + // Test Router.Forward. + router := &Router{ + Proj: linearFromWeight(onesLike(int(cfg.NumExperts), int(cfg.HiddenSize))), + Scale: onesLike(int(cfg.HiddenSize)), + } + + t.Run("Router", func(t *testing.T) { + scores, inds := router.Forward(x, cfg) + mlx.Eval(scores, inds) + + sDims := scores.Dims() + iDims := inds.Dims() + t.Logf("scores shape: %v, inds shape: %v", sDims, iDims) + + if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) { + t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts) + } + if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) { + t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts) + } + }) + + // Test MoEBlock.Forward. 
+ moe := &MoEBlock{ + GateWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)), + UpWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)), + DownWeight: onesLike(int(cfg.NumExperts), int(cfg.ExpertIntermediateSize), int(cfg.HiddenSize)), + PerExpertScale: onesLike(int(cfg.NumExperts)), + } + + t.Run("MoEBlock", func(t *testing.T) { + scores, inds := router.Forward(x, cfg) + mlx.Eval(scores, inds) + + out := moe.Forward(x, scores, inds, cfg) + mlx.Eval(out) + + outDims := out.Dims() + t.Logf("MoE output shape: %v", outDims) + + if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) { + t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize) + } + }) + + // Test with larger batch to exercise the sorted GatherMM path (B*L >= 64). + t.Run("MoEBlock_sorted", func(t *testing.T) { + bigB, bigL := int32(1), int32(128) + bigX := onesLike(int(bigB), int(bigL), int(cfg.HiddenSize)) + + scores, inds := router.Forward(bigX, cfg) + mlx.Eval(scores, inds) + + out := moe.Forward(bigX, scores, inds, cfg) + mlx.Eval(out) + + outDims := out.Dims() + t.Logf("MoE sorted output shape: %v", outDims) + + if len(outDims) != 3 || outDims[0] != int(bigB) || outDims[1] != int(bigL) || outDims[2] != int(cfg.HiddenSize) { + t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, bigB, bigL, cfg.HiddenSize) + } + }) +} + +// TestRouterForwardMatchesLegacy verifies the optimized Router.Forward — +// which takes the top-k of the raw logits and softmaxes only the selected +// values — produces the same indices and (within tolerance) the same +// normalized scores as the legacy path that softmaxes over every expert +// first, gathers the top-k probabilities, then renormalizes. 
+func TestRouterForwardMatchesLegacy(t *testing.T) { + skipIfNoMLX(t) + + cfg := &TextConfig{ + HiddenSize: 8, + NumExperts: 4, + TopKExperts: 2, + RMSNormEps: 1e-6, + RouterScale: 0.5, + } + + // Distinct per-expert weight rows so top-k has a well-defined ordering + // (tied scores would let argpartition pick either tied expert and make + // the index comparison below flaky). + projWeight := mlx.FromValues([]float32{ + 0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, // expert 0 + 0.30, 0.29, 0.28, 0.27, 0.26, 0.25, 0.24, 0.23, // expert 1 + -0.05, -0.06, -0.07, -0.08, -0.09, -0.10, -0.11, -0.12, // expert 2 + 0.50, 0.48, 0.46, 0.44, 0.42, 0.40, 0.38, 0.36, // expert 3 + }, int(cfg.NumExperts), int(cfg.HiddenSize)) + + scale := mlx.FromValues([]float32{ + 1.0, 0.9, 1.1, 1.0, 1.2, 0.8, 1.0, 1.05, + }, int(cfg.HiddenSize)) + + r := &Router{ + Proj: linearFromWeight(projWeight), + Scale: scale, + } + + // Varied x so different positions potentially hit different top-k. + x := mlx.FromValues([]float32{ + 0.2, -0.1, 0.3, 0.0, 0.4, -0.2, 0.1, 0.05, + -0.3, 0.2, -0.1, 0.4, -0.05, 0.3, 0.0, 0.2, + 0.5, 0.4, -0.2, 0.1, -0.3, 0.0, 0.3, -0.1, + }, 1, 3, int(cfg.HiddenSize)) + + gotScores, gotInds := r.Forward(x, cfg) + wantScores, wantInds := legacyRouterForward(r, x, cfg) + mlx.Eval(gotScores, gotInds, wantScores, wantInds) + + if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) { + t.Fatalf("indices mismatch:\n got %v\n want %v", got, want) + } + if got, want := gotScores.Floats(), wantScores.Floats(); !floatSlicesClose(got, want, 1e-5) { + t.Fatalf("scores mismatch:\n got %v\n want %v", got, want) + } +} + +// legacyRouterForward implements the pre-optimization router: full softmax +// over every expert, gather the top-k probabilities, then renormalize them +// to sum to 1. Algebraically identical to the fused form in Router.Forward. 
+func legacyRouterForward(r *Router, x *mlx.Array, cfg *TextConfig) (*mlx.Array, *mlx.Array) {
+	dims := x.Dims()
+	BL := int32(dims[0]) * int32(dims[1])
+
+	xFlat := mlx.Reshape(x, BL, cfg.HiddenSize)
+	normed := mlx.RMSNormFn(xFlat, nil, cfg.RMSNormEps)
+	normed = mlx.MulScalar(normed, cfg.RouterScale)
+	normed = mlx.Mul(normed, r.Scale)
+
+	expertScores := r.Proj.Forward(normed)
+	probs := mlx.SoftmaxAxis(expertScores, -1, true)
+
+	neg := mlx.Neg(expertScores)
+	inds := mlx.Argpartition(neg, int(cfg.TopKExperts)-1, -1)
+	inds = mlx.SliceStartStop(inds,
+		[]int32{0, 0},
+		[]int32{BL, cfg.TopKExperts},
+	)
+
+	scores := mlx.TakeAlongAxis(probs, inds, -1)
+	sumScores := mlx.Sum(scores, -1, true)
+	scores = mlx.Div(scores, sumScores)
+	return scores, inds
+}
+
+func intSlicesEqual(a, b []int) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i := range a {
+		if a[i] != b[i] {
+			return false
+		}
+	}
+	return true
+}
+
+func floatSlicesClose(a, b []float32, tol float32) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i := range a {
+		d := a[i] - b[i]
+		if d < 0 {
+			d = -d
+		}
+		if d > tol {
+			return false
+		}
+	}
+	return true
+}
+
+// linearFromWeight wraps a weight tensor in a minimal Linear-like test layer (simpleLinear, no bias).
+func linearFromWeight(w *mlx.Array) *simpleLinear { + return &simpleLinear{weight: w} +} + +type simpleLinear struct { + weight *mlx.Array +} + +func (l *simpleLinear) Forward(x *mlx.Array) *mlx.Array { + return x.Matmul(mlx.Transpose(l.weight, 1, 0)) +} + +func (l *simpleLinear) OutputDim() int32 { + return int32(l.weight.Dims()[0]) +} diff --git a/x/models/gemma4/gemma4_test.go b/x/models/gemma4/gemma4_test.go new file mode 100644 index 000000000..4f674ca66 --- /dev/null +++ b/x/models/gemma4/gemma4_test.go @@ -0,0 +1,503 @@ +package gemma4 + +import ( + "testing" + + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +func TestParseTextConfigE2B(t *testing.T) { + skipIfNoMLX(t) + data := []byte(`{ + "architectures": ["Gemma4ForConditionalGeneration"], + "text_config": { + "hidden_size": 1536, + "num_hidden_layers": 35, + "intermediate_size": 6144, + "num_attention_heads": 8, + "num_key_value_heads": 1, + "head_dim": 256, + "global_head_dim": 512, + "vocab_size": 262144, + "rms_norm_eps": 1e-6, + "max_position_embeddings": 131072, + "sliding_window": 512, + "sliding_window_pattern": 5, + "final_logit_softcapping": 30.0, + "use_double_wide_mlp": true, + "num_kv_shared_layers": 20, + "hidden_size_per_layer_input": 256, + "vocab_size_per_layer_input": 262144, + "attention_k_eq_v": false, + "tie_word_embeddings": true, + "layer_types": [ + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + 
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention" + ], + "rope_parameters": { + "full_attention": { + "partial_rotary_factor": 0.25, + "rope_theta": 1000000.0, + "rope_type": "proportional" + }, + "sliding_attention": { + "rope_theta": 10000.0, + "rope_type": "default" + } + } + } + }`) + + cfg, err := parseTextConfig(data) + if err != nil { + t.Fatalf("parseTextConfig failed: %v", err) + } + + // Basic fields. + if cfg.HiddenSize != 1536 { + t.Errorf("HiddenSize = %d, want 1536", cfg.HiddenSize) + } + if cfg.NumHiddenLayers != 35 { + t.Errorf("NumHiddenLayers = %d, want 35", cfg.NumHiddenLayers) + } + if cfg.GlobalHeadDim != 512 { + t.Errorf("GlobalHeadDim = %d, want 512", cfg.GlobalHeadDim) + } + if cfg.FinalLogitSoftcapping != 30.0 { + t.Errorf("FinalLogitSoftcapping = %f, want 30.0", cfg.FinalLogitSoftcapping) + } + if cfg.NumKVSharedLayers != 20 { + t.Errorf("NumKVSharedLayers = %d, want 20", cfg.NumKVSharedLayers) + } + if cfg.HiddenSizePerLayer != 256 { + t.Errorf("HiddenSizePerLayer = %d, want 256", cfg.HiddenSizePerLayer) + } + + // RoPE settings. + if cfg.SlidingRopeDims != 256 { + t.Errorf("SlidingRopeDims = %d, want 256", cfg.SlidingRopeDims) + } + if cfg.FullRopeDims != 512 { + t.Errorf("FullRopeDims = %d, want 512 (GlobalHeadDim, partial rotation handled via custom freqs)", cfg.FullRopeDims) + } + if cfg.SlidingRopeBase != 10000 { + t.Errorf("SlidingRopeBase = %f, want 10000", cfg.SlidingRopeBase) + } + if cfg.FullRopeBase != 1000000 { + t.Errorf("FullRopeBase = %f, want 1000000", cfg.FullRopeBase) + } + + // Attention scale. + if cfg.SlidingScale == 0 || cfg.FullScale == 0 { + t.Error("attention scales should be non-zero") + } + + // KV sharing map. + // First shared layer is 35 - 20 = 15. 
+ if donor, ok := cfg.KVShareMap[15]; !ok || donor != 13 { + t.Errorf("KVShareMap[15] = %d, ok=%v; want 13, true", donor, ok) + } + if donor, ok := cfg.KVShareMap[19]; !ok || donor != 14 { + t.Errorf("KVShareMap[19] = %d, ok=%v; want 14, true (full attn donor)", donor, ok) + } + if donor, ok := cfg.KVShareMap[34]; !ok || donor != 14 { + t.Errorf("KVShareMap[34] = %d, ok=%v; want 14, true (full attn donor)", donor, ok) + } + // Layer 14 should not be shared. + if _, ok := cfg.KVShareMap[14]; ok { + t.Error("layer 14 should not be in KVShareMap (non-shared)") + } + + // Donors. + if !cfg.KVDonors[13] { + t.Error("layer 13 should be a KV donor") + } + if !cfg.KVDonors[14] { + t.Error("layer 14 should be a KV donor") + } +} + +func TestParseTextConfig26B(t *testing.T) { + skipIfNoMLX(t) + data := []byte(`{ + "architectures": ["Gemma4ForConditionalGeneration"], + "text_config": { + "hidden_size": 2816, + "num_hidden_layers": 30, + "intermediate_size": 2112, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "num_global_key_value_heads": 2, + "head_dim": 256, + "global_head_dim": 512, + "vocab_size": 262144, + "rms_norm_eps": 1e-6, + "max_position_embeddings": 131072, + "sliding_window": 1024, + "final_logit_softcapping": 30.0, + "use_double_wide_mlp": false, + "num_kv_shared_layers": 0, + "hidden_size_per_layer_input": null, + "attention_k_eq_v": true, + "enable_moe_block": true, + "num_experts": 128, + "top_k_experts": 8, + "moe_intermediate_size": 704, + "tie_word_embeddings": true, + "layer_types": [ + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + 
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention" + ], + "rope_parameters": { + "full_attention": { + "partial_rotary_factor": 0.25, + "rope_theta": 1000000.0, + "rope_type": "proportional" + }, + "sliding_attention": { + "rope_theta": 10000.0, + "rope_type": "default" + } + } + } + }`) + + cfg, err := parseTextConfig(data) + if err != nil { + t.Fatalf("parseTextConfig failed: %v", err) + } + + if cfg.HiddenSize != 2816 { + t.Errorf("HiddenSize = %d, want 2816", cfg.HiddenSize) + } + if !cfg.AttentionKEqV { + t.Error("AttentionKEqV should be true") + } + if cfg.NumGlobalKeyValueHeads != 2 { + t.Errorf("NumGlobalKeyValueHeads = %d, want 2", cfg.NumGlobalKeyValueHeads) + } + if !cfg.EnableMoeBlock { + t.Error("EnableMoeBlock should be true") + } + if cfg.NumExperts != 128 { + t.Errorf("NumExperts = %d, want 128", cfg.NumExperts) + } + if cfg.TopKExperts != 8 { + t.Errorf("TopKExperts = %d, want 8", cfg.TopKExperts) + } + if cfg.ExpertIntermediateSize != 704 { + t.Errorf("ExpertIntermediateSize = %d, want 704", cfg.ExpertIntermediateSize) + } + if cfg.HiddenSizePerLayer != 0 { + t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer) + } +} + +func TestParseTextConfig31B(t *testing.T) { + skipIfNoMLX(t) + data := []byte(`{ + "architectures": ["Gemma4ForConditionalGeneration"], + "text_config": { + "hidden_size": 5376, + "num_hidden_layers": 60, + "intermediate_size": 21504, + "num_attention_heads": 32, + "num_key_value_heads": 16, + "num_global_key_value_heads": 4, + "head_dim": 256, + "global_head_dim": 512, + "vocab_size": 262144, + "rms_norm_eps": 1e-6, + "max_position_embeddings": 131072, + "sliding_window": 1024, + "final_logit_softcapping": 30.0, + "use_double_wide_mlp": false, + "num_kv_shared_layers": 0, + "hidden_size_per_layer_input": null, + 
"attention_k_eq_v": true, + "tie_word_embeddings": true, + "layer_types": [ + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention" + ], + "rope_parameters": { + "full_attention": { + "partial_rotary_factor": 0.25, + "rope_theta": 1000000.0, + "rope_type": "proportional" + }, + "sliding_attention": { + "rope_theta": 10000.0, + "rope_type": "default" + } + } + } + }`) + + cfg, err := parseTextConfig(data) + if err != nil { + t.Fatalf("parseTextConfig failed: %v", err) + } + + if cfg.HiddenSize != 5376 { + t.Errorf("HiddenSize = %d, want 5376", cfg.HiddenSize) + } + if cfg.NumHiddenLayers != 60 { + t.Errorf("NumHiddenLayers = %d, want 60", cfg.NumHiddenLayers) + } + if !cfg.AttentionKEqV { + t.Error("AttentionKEqV should be true") + } + if cfg.NumGlobalKeyValueHeads != 4 { + t.Errorf("NumGlobalKeyValueHeads = %d, want 4", 
cfg.NumGlobalKeyValueHeads) + } + if cfg.NumKeyValueHeads != 16 { + t.Errorf("NumKeyValueHeads = %d, want 16", cfg.NumKeyValueHeads) + } + if cfg.NumKVSharedLayers != 0 { + t.Errorf("NumKVSharedLayers = %d, want 0", cfg.NumKVSharedLayers) + } + if cfg.HiddenSizePerLayer != 0 { + t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer) + } + if cfg.SlidingWindow != 1024 { + t.Errorf("SlidingWindow = %d, want 1024", cfg.SlidingWindow) + } + + // KV sharing should be empty (no shared layers). + if len(cfg.KVShareMap) != 0 { + t.Errorf("KVShareMap should be empty, got %d entries", len(cfg.KVShareMap)) + } + + // Layer types: pattern is 5 sliding + 1 full, repeating 10 times. + if !isLayerSliding(0, &cfg) { + t.Error("layer 0 should be sliding") + } + if isLayerSliding(5, &cfg) { + t.Error("layer 5 should be full attention") + } + if !isLayerSliding(6, &cfg) { + t.Error("layer 6 should be sliding") + } + if isLayerSliding(59, &cfg) { + t.Error("layer 59 should be full attention") + } +} + +func TestParseTextConfigE4B(t *testing.T) { + skipIfNoMLX(t) + data := []byte(`{ + "architectures": ["Gemma4ForConditionalGeneration"], + "text_config": { + "hidden_size": 2560, + "num_hidden_layers": 42, + "intermediate_size": 10240, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "head_dim": 256, + "global_head_dim": 512, + "vocab_size": 262144, + "rms_norm_eps": 1e-6, + "max_position_embeddings": 131072, + "sliding_window": 512, + "final_logit_softcapping": 30.0, + "use_double_wide_mlp": false, + "num_kv_shared_layers": 18, + "hidden_size_per_layer_input": 256, + "vocab_size_per_layer_input": 262144, + "attention_k_eq_v": false, + "enable_moe_block": false, + "tie_word_embeddings": true, + "layer_types": [ + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + 
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention", + "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention" + ], + "rope_parameters": { + "full_attention": { + "partial_rotary_factor": 0.25, + "rope_theta": 1000000.0, + "rope_type": "proportional" + }, + "sliding_attention": { + "rope_theta": 10000.0, + "rope_type": "default" + } + } + } + }`) + + cfg, err := parseTextConfig(data) + if err != nil { + t.Fatalf("parseTextConfig failed: %v", err) + } + + if cfg.HiddenSize != 2560 { + t.Errorf("HiddenSize = %d, want 2560", cfg.HiddenSize) + } + if cfg.NumHiddenLayers != 42 { + t.Errorf("NumHiddenLayers = %d, want 42", cfg.NumHiddenLayers) + } + if cfg.IntermediateSize != 10240 { + t.Errorf("IntermediateSize = %d, want 10240", cfg.IntermediateSize) + } + if cfg.NumKeyValueHeads != 2 { + t.Errorf("NumKeyValueHeads = %d, want 2", cfg.NumKeyValueHeads) + } + if cfg.UseDoubleWideMLP { + t.Error("UseDoubleWideMLP should be false") + } + if cfg.NumKVSharedLayers != 18 { + t.Errorf("NumKVSharedLayers = %d, want 18", cfg.NumKVSharedLayers) + } + if cfg.HiddenSizePerLayer != 256 { + t.Errorf("HiddenSizePerLayer = %d, want 256 (has PLE)", cfg.HiddenSizePerLayer) + } + if cfg.AttentionKEqV { + t.Error("AttentionKEqV should be false") + } + if cfg.EnableMoeBlock { + t.Error("EnableMoeBlock should be false") + } + if cfg.SlidingWindow != 512 { + t.Errorf("SlidingWindow = %d, want 512", cfg.SlidingWindow) + } + + // Layer types: pattern is 5 sliding + 1 full, repeating 7 times = 42 layers. 
+ if !isLayerSliding(0, &cfg) { + t.Error("layer 0 should be sliding") + } + if isLayerSliding(5, &cfg) { + t.Error("layer 5 should be full attention") + } + if !isLayerSliding(6, &cfg) { + t.Error("layer 6 should be sliding") + } + if isLayerSliding(41, &cfg) { + t.Error("layer 41 should be full attention") + } + + // KV sharing: first shared = 42 - 18 = 24. + // Layer 24 is sliding, its donor should be the last non-shared sliding layer. + // Non-shared layers: 0-23. Last sliding in 0-23 is layer 22 (23=full). + if donor, ok := cfg.KVShareMap[24]; !ok { + t.Error("layer 24 should be in KVShareMap") + } else { + t.Logf("layer 24 donor = %d", donor) + } + // Layer 29 is full_attention (5th full), donor should be the last non-shared full layer. + // Non-shared full layers: 5, 11, 17, 23. + if donor, ok := cfg.KVShareMap[29]; !ok || donor != 23 { + t.Errorf("KVShareMap[29] = %d, ok=%v; want 23, true (full attn donor)", donor, ok) + } + // Layer 23 should NOT be shared (it's the last non-shared layer). 
+ if _, ok := cfg.KVShareMap[23]; ok { + t.Error("layer 23 should not be in KVShareMap (non-shared)") + } +} + +func TestLayerTypeDetection(t *testing.T) { + cfg := &TextConfig{ + LayerTypes: []string{ + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", + }, + } + + if !isLayerSliding(0, cfg) { + t.Error("layer 0 should be sliding") + } + if !isLayerSliding(3, cfg) { + t.Error("layer 3 should be sliding") + } + if isLayerSliding(4, cfg) { + t.Error("layer 4 should be full attention") + } +} + +func TestNewCachesOmitsSharedKVLayers(t *testing.T) { + m := &Model{ + Layers: []*DecoderLayer{ + {IsSliding: true, KVShareDonor: -1}, + {IsSliding: false, KVShareDonor: -1}, + {IsSliding: true, KVShareDonor: 0}, + {IsSliding: false, KVShareDonor: 1}, + }, + TextConfig: &TextConfig{SlidingWindow: 512}, + } + + caches := m.NewCaches() + if got, want := len(caches), 2; got != want { + t.Fatalf("len(NewCaches()) = %d, want %d", got, want) + } +} + +func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) { + m := &Model{ + Layers: []*DecoderLayer{ + {IsSliding: true, KVShareDonor: -1}, + {IsSliding: false, KVShareDonor: -1}, + {IsSliding: true, KVShareDonor: -1}, + }, + TextConfig: &TextConfig{SlidingWindow: 512}, + } + + caches := m.NewCaches() + if got, want := len(caches), len(m.Layers); got != want { + t.Fatalf("len(NewCaches()) = %d, want %d", got, want) + } +} + +func TestResolveWeightPrefix(t *testing.T) { + if err := mlx.CheckInit(); err != nil { + t.Skipf("MLX not available: %v", err) + } + + tests := []struct { + name string + key string + wantPfx string + }{ + {"bare", "embed_tokens.weight", ""}, + {"language_model", "model.language_model.embed_tokens.weight", "model.language_model."}, + {"with_model", "model.embed_tokens.weight", "model."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dummy := mlx.FromValue(float32(1.0)) + mlx.Eval(dummy) + tensors := map[string]*mlx.Array{tt.key: 
dummy} + got := resolveWeightPrefix(tensors) + if got != tt.wantPfx { + t.Errorf("resolveWeightPrefix(%q) = %q, want %q", tt.key, got, tt.wantPfx) + } + }) + } +} + +func skipIfNoMLX(t *testing.T) { + t.Helper() + if err := mlx.CheckInit(); err != nil { + t.Skipf("MLX not available: %v", err) + } +}