Gemma4 on MLX (#15244)

* gemma4: implement Gemma 4 model for MLX (text-only runtime)

* gemma4: two MoE + SWA prefill perf fixes

Two performance optimizations in the gemma4 forward pass:

1. Memoize the sliding-window prefill mask across layers.
2. Softmax only over the selected experts in Router.Forward.

* review comments
This commit is contained in:
Daniel Hiltgen
2026-04-13 16:36:51 -07:00
committed by GitHub
parent bf2a421727
commit 2cba7756c5
8 changed files with 2715 additions and 0 deletions

View File

@@ -560,6 +560,9 @@ func getParserName(modelDir string) string {
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "gemma4") {
return "gemma4"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3"
}
@@ -574,6 +577,9 @@ func getParserName(modelDir string) string {
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "gemma4") {
return "gemma4"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3"
}
@@ -602,6 +608,9 @@ func getRendererName(modelDir string) string {
// Check architectures for known renderers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "gemma4") {
return "gemma4"
}
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
@@ -616,6 +625,9 @@ func getRendererName(modelDir string) string {
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "gemma4") {
return "gemma4"
}
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}

View File

@@ -634,6 +634,8 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
"Qwen3_5MoeForConditionalGeneration": newQwen35ImportTransform,
"Qwen3NextMoeForCausalLM": newQwen35ImportTransform,
"Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform,
"Gemma4ForCausalLM": newGemma4ImportTransform,
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
}
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {

264
x/create/gemma4.go Normal file
View File

@@ -0,0 +1,264 @@
package create
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/x/safetensors"
)
// gemma4ImportTransform applies Gemma 4-specific tensor transforms during
// model import: mixed-precision quantization decisions and splitting of
// pre-stacked MoE expert tensors.
type gemma4ImportTransform struct {
	numLayers  int // text-model hidden layer count; 0 when config.json was unreadable
	numExperts int // MoE expert count; 0 for dense or unknown configs
}
// gemma4Config is a minimal subset of the Gemma 4 config.json used for quant decisions.
// Fields may appear either at the top level or nested under text_config
// (multimodal checkpoints); newGemma4ImportTransform checks both.
type gemma4Config struct {
	NumHiddenLayers int `json:"num_hidden_layers"`
	NumExperts      int `json:"num_experts"`
	TextConfig      struct {
		NumHiddenLayers int `json:"num_hidden_layers"`
		NumExperts      int `json:"num_experts"`
	} `json:"text_config"`
}
// newGemma4ImportTransform builds a gemma4ImportTransform by reading layer
// and expert counts from the model's config.json. Any read or parse failure
// falls back to a zero-valued transform (layer/expert heuristics disabled)
// rather than failing the import.
func newGemma4ImportTransform(modelDir string, _ sourceModelConfig) (tensorImportTransform, error) {
	raw, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
	if err != nil {
		return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
	}
	var cfg gemma4Config
	if err := json.Unmarshal(raw, &cfg); err != nil {
		return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
	}
	t := gemma4ImportTransform{
		numLayers:  cfg.NumHiddenLayers,
		numExperts: cfg.NumExperts,
	}
	// Multimodal checkpoints nest the text-model settings under text_config.
	if t.numLayers == 0 {
		t.numLayers = cfg.TextConfig.NumHiddenLayers
	}
	if t.numExperts == 0 {
		t.numExperts = cfg.TextConfig.NumExperts
	}
	return t, nil
}
// skipTensor reports whether the named tensor should be omitted from the
// import. Gemma 4 imports every tensor, so this is always false.
func (t gemma4ImportTransform) skipTensor(string) bool {
	return false
}
// layerIndexRe extracts the layer index from tensor names like
// "model.language_model.layers.5.self_attn.v_proj.weight" or
// "model.language_model.layers.5.moe.experts.42.down_proj.weight".
// Compiled once at package scope so quantizationType stays cheap per tensor.
var layerIndexRe = regexp.MustCompile(`\.layers\.(\d+)\.`)
// useMoreBits reports whether the given layer should receive higher-precision
// quantization for sensitive tensors: the first and last 1/8 of layers (which
// handle input grounding and final output refinement), plus every 3rd layer
// in between to limit error accumulation through the residual stream.
func useMoreBits(layerIdx, numLayers int) bool {
	edge := numLayers / 8
	switch {
	case layerIdx < edge:
		return true // early layers: input grounding
	case layerIdx >= 7*numLayers/8:
		return true // late layers: output refinement
	default:
		// Periodic promotion through the middle of the stack.
		return (layerIdx-edge)%3 == 2
	}
}
// quantizationType selects the quantization format for a single tensor given
// the requested base quant type, or "" to leave the tensor unquantized
// (kept at its source dtype; see the "always bf16" cases in the tests).
//
// Decision order:
//  1. embed_tokens.weight gets a dedicated 8-bit path.
//  2. Language-model v_proj/down_proj (and k_proj on 8-expert models) get
//     mixed-precision promotion driven by the useMoreBits layer heuristic.
//  3. Everything else falls through to the generic GetTensorQuantization.
func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
	quantNorm := normalizeQuantType(quantize)
	// Embedding: quantize to 8-bit variant for bandwidth efficiency.
	// The embedding serves double duty: lookup (via QuantizedEmbedding) and
	// lm_head projection (via AsLinear). Using 8-bit matches GGUF Q6_K quality
	// (strictly higher at 8 bpw vs 6.5 bpw) while saving ~2.8 GB on 31B vs bf16.
	if isEmbedTokensWeight(name) {
		switch quantNorm {
		case "int4", "int8":
			if isAligned(shape, "int8") {
				return "int8"
			}
		case "mxfp4", "nvfp4", "mxfp8":
			if isAligned(shape, "mxfp8") {
				return "mxfp8"
			}
		}
		// 8-bit target not alignment-compatible: fall back to the requested
		// type if that aligns, otherwise skip quantization entirely.
		if isAligned(shape, quantNorm) {
			return quantNorm
		}
		return ""
	}
	// Mixed-precision quantization: sensitive tensors get higher precision.
	//
	// Value projections (v_proj) directly determine attention output quality.
	// Down projections (down_proj) are the final MLP output and errors there
	// propagate directly to the residual stream. Both benefit from higher
	// precision at early layers, late layers, and periodically in between
	// (the "useMoreBits" heuristic).
	//
	// For int4: promote → int8 (same affine family, GatherQMM compatible).
	// For mxfp4/nvfp4: promote → mxfp8. MLX quantized_matmul handles mixed
	// nvfp4+mxfp8 modes within the same model — each tensor carries its own
	// quant metadata and the kernel dispatches per-tensor.
	if t.numLayers > 0 {
		// layerIdx stays -1 when the name carries no ".layers.N." segment;
		// such tensors never match the layer-position heuristic below.
		layerIdx := -1
		if m := layerIndexRe.FindStringSubmatch(name); m != nil {
			if idx, err := strconv.Atoi(m[1]); err == nil {
				layerIdx = idx
			}
		}
		// Determine promotion target for sensitive tensors.
		// "int8"  = int4 base → int8 (affine family)
		// "mxfp8" = mxfp4/nvfp4 base → mxfp8
		// ""      = no promotion (int8/mxfp8, already 8-bit)
		promote := ""
		switch quantNorm {
		case "int4":
			promote = "int8"
		case "mxfp4", "nvfp4":
			promote = "mxfp8"
		}
		// Only apply to language model tensors — audio/vision tower tensors
		// should pass through to GetTensorQuantization which skips them.
		isModelTensor := !strings.Contains(name, "audio_tower") &&
			!strings.Contains(name, "vision_tower")
		isSensitive := isModelTensor &&
			(strings.Contains(name, ".v_proj") || strings.Contains(name, "down_proj"))
		isSensitiveK := isModelTensor && strings.Contains(name, "k_proj")
		if promote != "" && (isSensitive || isSensitiveK) {
			shouldPromote := false
			// 8-expert models: v_proj and k_proj share very few KV heads,
			// so quantization errors are amplified. Always promote.
			if t.numExperts == 8 && (strings.Contains(name, ".v_proj") || isSensitiveK) {
				shouldPromote = true
			}
			// Layer-position heuristic for v_proj and down_proj.
			if isSensitive && layerIdx >= 0 && useMoreBits(layerIdx, t.numLayers) {
				shouldPromote = true
			}
			if shouldPromote && isAligned(shape, promote) {
				return promote
			}
			// Sensitive tensor at a non-promoted layer: use base quant type.
			// Return directly to bypass GetTensorQuantization's uniform
			// promotion — the layer-position heuristic is authoritative here.
			if !isAligned(shape, quantNorm) {
				return ""
			}
			return quantNorm
		}
	}
	// Non-sensitive (or unknown-layer-count) tensors: generic policy.
	return GetTensorQuantization(name, shape, quantize)
}
// isEmbedTokensWeight reports whether name is the main token embedding
// weight, excluding the per-layer embedding variant
// (e.g. "embed_tokens_per_layer.weight").
func isEmbedTokensWeight(name string) bool {
	if strings.Contains(name, "per_layer") {
		return false
	}
	return strings.HasSuffix(name, "embed_tokens.weight")
}
// transformTensor rewrites one source tensor into zero or more output
// tensors. Pre-stacked MoE expert weights ([N, out, in]) are split into
// per-expert [out, in] tensors so they go through the standard expert
// packing and quantization flow (ExpertGroupPrefix matching, per-expert
// quantize); everything else passes through unchanged.
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
	switch {
	case td == nil:
		return nil, nil
	case isGemma4StackedMoETensor(td.Name, td.Shape):
		return splitStackedMoETensor(td)
	default:
		return []*safetensors.TensorData{td}, nil
	}
}
// isGemma4StackedMoETensor reports whether the tensor is a pre-stacked MoE
// expert weight. Gemma 4 HF weights come in two layouts depending on the
// model version:
//   - Older: model.language_model.layers.N.moe.{gate,up,down}_proj [experts, dim1, dim2]
//   - Newer: model.language_model.layers.N.experts.{gate_up,down}_proj [experts, dim1, dim2]
//
// The newer layout has gate+up already fused. We keep it fused (no splitting)
// so the tensors flow through the standard expert packing and quantization path.
func isGemma4StackedMoETensor(name string, shape []int32) bool {
	// Stacked expert weights are always rank-3.
	if len(shape) != 3 {
		return false
	}
	inExpertScope := strings.Contains(name, ".moe.") || strings.Contains(name, ".experts.")
	if !inExpertScope {
		return false
	}
	return strings.HasSuffix(name, "_proj") || strings.HasSuffix(name, "_proj.weight")
}
// splitStackedMoETensor splits a [N, out, in] stacked expert tensor into
// N individual [out, in] tensors named with the per-expert convention that
// ExpertGroupPrefix expects: prefix.moe.experts.{E}.{proj}.weight
func splitStackedMoETensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
	raw, err := io.ReadAll(td.Reader())
	if err != nil {
		return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
	}
	elemSize, err := DTypeSize(td.Dtype)
	if err != nil {
		return nil, fmt.Errorf("failed to get dtype size for %s: %w", td.Dtype, err)
	}
	numExperts := int(td.Shape[0])
	// HF layout: Shape[1] = out_features, Shape[2] = in_features.
	stride := int(td.Shape[1]) * int(td.Shape[2]) * elemSize
	if len(raw) != numExperts*stride {
		return nil, fmt.Errorf("tensor %s: raw byte length %d does not match shape %v and dtype %s",
			td.Name, len(raw), td.Shape, td.Dtype)
	}
	// Determine the per-expert name pattern. Two source layouts:
	//   Old: model.language_model.layers.N.moe.gate_proj
	//        -> model.language_model.layers.N.moe.experts.E.gate_proj.weight
	//   New: model.language_model.layers.N.experts.gate_up_proj
	//        -> model.language_model.layers.N.moe.experts.E.gate_up_proj.weight
	base := strings.TrimSuffix(td.Name, ".weight")
	dot := strings.LastIndex(base, ".")
	if dot < 0 {
		return nil, fmt.Errorf("tensor %s: unexpected name format", td.Name)
	}
	parent := base[:dot]   // "...layers.N.moe" or "...layers.N.experts"
	proj := base[dot+1:]   // "gate_proj" or "gate_up_proj"
	// Normalize: a new-style "...layers.N.experts" parent becomes
	// "...layers.N.moe" so every output follows the same
	// "layers.N.moe.experts.E" pattern.
	if trimmed, ok := strings.CutSuffix(parent, ".experts"); ok {
		parent = trimmed + ".moe"
	}
	// Per-expert shape: the leading expert axis is dropped.
	expertShape := []int32{td.Shape[1], td.Shape[2]}
	out := make([]*safetensors.TensorData, numExperts)
	for e := range numExperts {
		name := fmt.Sprintf("%s.experts.%d.%s.weight", parent, e, proj)
		chunk := raw[e*stride : (e+1)*stride]
		out[e] = safetensors.NewTensorDataFromBytes(name, td.Dtype, expertShape, chunk)
	}
	return out, nil
}

191
x/create/gemma4_test.go Normal file
View File

@@ -0,0 +1,191 @@
package create
import (
"testing"
)
// TestGemma4QuantizationType exercises the per-tensor quantization policy
// across embeddings, sensitive projections, experts, and tower tensors.
func TestGemma4QuantizationType(t *testing.T) {
	// 26B MoE: 30 layers, 128 experts
	transform26B := gemma4ImportTransform{numLayers: 30, numExperts: 128}
	// 8-expert model (hypothetical)
	transform8E := gemma4ImportTransform{numLayers: 30, numExperts: 8}
	aligned := []int32{2816, 2816} // divisible by 64 (int4/int8 group size) and 16 (nvfp4)
	tests := []struct {
		name      string
		transform gemma4ImportTransform
		tensor    string
		shape     []int32
		quantize  string
		want      string
	}{
		// === embed_tokens: quantize to 8-bit variant (serves as both embed and lm_head) ===
		{"embed_tokens int4", transform26B, "model.embed_tokens.weight", aligned, "int4", "int8"},
		{"embed_tokens nvfp4", transform26B, "model.embed_tokens.weight", aligned, "nvfp4", "mxfp8"},
		{"embed_tokens mxfp4", transform26B, "model.embed_tokens.weight", aligned, "mxfp4", "mxfp8"},
		{"embed_tokens int8", transform26B, "model.embed_tokens.weight", aligned, "int8", "int8"},
		{"embed_tokens mxfp8", transform26B, "model.embed_tokens.weight", aligned, "mxfp8", "mxfp8"},
		// === v_proj: layer-position heuristic for int4/nvfp4 ===
		// Layer 0 is in first 1/8 (30/8=3) → promoted
		{"v_proj int4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
		// Layer 4 is NOT in useMoreBits → base quant
		{"v_proj int4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "int4", "int4"},
		// Layer 29 is in last 1/8 → promoted
		{"v_proj int4 last layer promoted", transform26B, "model.layers.29.self_attn.v_proj.weight", aligned, "int4", "int8"},
		// nvfp4: promote to mxfp8 (cross-family, validated by MLX quantized_matmul)
		{"v_proj nvfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"v_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
		// mxfp4: promoted to mxfp8 at promoted layers (same mxfp family)
		{"v_proj mxfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"v_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp4"},
		// int8/mxfp8: no promotion (already 8-bit)
		{"v_proj int8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
		{"v_proj mxfp8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},
		// === down_proj (dense MLP): same heuristic as v_proj ===
		{"dense down_proj int4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
		{"dense down_proj int4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "int4", "int4"},
		{"dense down_proj nvfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"dense down_proj nvfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "nvfp4", "nvfp4"},
		{"dense down_proj mxfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"dense down_proj mxfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
		// === Expert down_proj: int4→int8, nvfp4→mxfp8 at promoted layers ===
		{"expert down_proj int4 promoted", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "int4", "int8"},
		{"expert down_proj int4 non-promoted", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "int4", "int4"},
		// nvfp4 experts: promote to mxfp8 (all experts at a layer get same treatment,
		// so GatherQMM sees uniform quant per projection per layer)
		{"expert down_proj nvfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"expert down_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "nvfp4", "nvfp4"},
		// mxfp4 experts: promote to mxfp8 (same mxfp family, GatherQMM compatible)
		{"expert down_proj mxfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"expert down_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp4"},
		// === Expert gate_up_proj: always base quant (not a sensitive tensor) ===
		{"expert gate_up int4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "int4", "int4"},
		{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
		{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},
		// === k_proj: promoted only for 8-expert models ===
		{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
		{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
		{"k_proj 8 experts nvfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"k_proj 8 experts mxfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "mxfp4", "mxfp8"},
		// === q_proj, o_proj, gate_proj, up_proj: always base quant ===
		{"q_proj int4", transform26B, "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
		{"o_proj int4", transform26B, "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
		{"gate_proj int4", transform26B, "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
		{"up_proj int4", transform26B, "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},
		// === Non-quantizable tensors: always bf16 ===
		{"embed_tokens per_layer skip", transform26B, "model.embed_tokens_per_layer.weight", aligned, "int4", ""},
		{"norm", transform26B, "model.layers.0.input_layernorm.weight", []int32{2816}, "int4", ""},
		{"router scale", transform26B, "model.layers.0.router.scale", []int32{2816}, "int4", ""},
		// === Audio/vision tower tensors: must pass through unquantized for all quant types ===
		// These contain .v_proj and down_proj but should NOT be intercepted by
		// the sensitive-tensor promotion logic.
		{"audio norm int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int4", ""},
		{"audio norm nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "nvfp4", ""},
		{"audio norm int8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int8", ""},
		{"audio norm mxfp8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "mxfp8", ""},
		{"audio conv int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "int4", ""},
		{"audio conv nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "nvfp4", ""},
		{"audio linear int4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "int4", ""},
		{"audio linear nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "nvfp4", ""},
		// Audio tower v_proj — must NOT be promoted despite containing .v_proj
		{"audio v_proj int4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", ""},
		{"audio v_proj nvfp4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", ""},
		// Vision tower v_proj — vision tower IS quantized (unlike audio tower),
		// but not intercepted by gemma4's layer-position heuristic.
		// Falls through to GetTensorQuantization which applies uniform promotion.
		{"vision v_proj int4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", "int8"},
		{"vision v_proj nvfp4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", "nvfp4"},
		// Audio tower down_proj
		{"audio down_proj int4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "int4", ""},
		{"audio down_proj nvfp4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "nvfp4", ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.transform.quantizationType(tt.tensor, tt.shape, tt.quantize)
			if got != tt.want {
				t.Errorf("quantizationType(%q, %v, %q) = %q, want %q",
					tt.tensor, tt.shape, tt.quantize, got, tt.want)
			}
		})
	}
}
func TestUseMoreBits(t *testing.T) {
// 30 layers: first 1/8 = layers 0-2, last 1/8 = layers 27-29
// In between: every 3rd from offset (i - n/8) % 3 == 2
n := 30
promoted := map[int]bool{}
for i := range n {
if useMoreBits(i, n) {
promoted[i] = true
}
}
// First 1/8 (30/8 = 3): layers 0, 1, 2
for _, i := range []int{0, 1, 2} {
if !promoted[i] {
t.Errorf("layer %d should be promoted (first 1/8)", i)
}
}
// Last 1/8: layers 26, 27, 28, 29 (>= 7*30/8 = 26)
for _, i := range []int{26, 27, 28, 29} {
if !promoted[i] {
t.Errorf("layer %d should be promoted (last 1/8)", i)
}
}
// Some middle layers should NOT be promoted
for _, i := range []int{3, 4, 6, 7} {
if promoted[i] {
t.Errorf("layer %d should NOT be promoted", i)
}
}
// Layer 5 should be promoted: (5 - 3) % 3 == 2
if !promoted[5] {
t.Errorf("layer 5 should be promoted (periodic)")
}
}
// TestIsGemma4StackedMoETensor covers both HF expert layouts plus the
// negative cases (2D tensors, non-expert scopes, non-projection names).
func TestIsGemma4StackedMoETensor(t *testing.T) {
	type testCase struct {
		label      string
		tensorName string
		shape      []int32
		want       bool
	}
	cases := []testCase{
		// New-style fused layout: .experts.{gate_up,down}_proj
		{"experts gate_up_proj 3D", "model.layers.0.experts.gate_up_proj", []int32{128, 1408, 2816}, true},
		{"experts down_proj 3D", "model.layers.0.experts.down_proj", []int32{128, 2816, 704}, true},
		// Old-style layout: .moe.{gate,up,down}_proj
		{"moe gate_proj 3D", "model.layers.0.moe.gate_proj", []int32{128, 2112, 2816}, true},
		{"moe down_proj 3D", "model.layers.0.moe.down_proj.weight", []int32{128, 2816, 2112}, true},
		// 2D tensors are never stacked experts.
		{"2D weight", "model.layers.0.experts.gate_up_proj", []int32{1408, 2816}, false},
		// 3D but not inside an expert scope.
		{"non-expert 3D", "model.layers.0.mlp.gate_proj", []int32{3, 2816, 2816}, false},
		// Expert-scoped but not a projection weight.
		{"expert non-proj", "model.layers.0.experts.scale", []int32{128, 1, 1}, false},
	}
	for _, c := range cases {
		t.Run(c.label, func(t *testing.T) {
			if got := isGemma4StackedMoETensor(c.tensorName, c.shape); got != c.want {
				t.Errorf("isGemma4StackedMoETensor(%q, %v) = %v, want %v",
					c.tensorName, c.shape, got, c.want)
			}
		})
	}
}

View File

@@ -2,6 +2,7 @@ package mlxrunner
import (
_ "github.com/ollama/ollama/x/models/gemma3"
_ "github.com/ollama/ollama/x/models/gemma4"
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"

1514
x/models/gemma4/gemma4.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,228 @@
package gemma4
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// onesLike creates a bf16 tensor of the given shape filled with a small
// constant (0.01) — despite the name, not literally ones.
func onesLike(shape ...int) *mlx.Array {
	base := mlx.Zeros(mlx.DTypeBFloat16, shape...)
	return mlx.AddScalar(base, 0.01)
}
// TestMoEForward runs Router.Forward and MoEBlock.Forward on a tiny
// 26B-shaped config and checks output shapes, including the large-batch
// path (B*L >= 64) that exercises the sorted GatherMM code.
func TestMoEForward(t *testing.T) {
	skipIfNoMLX(t)
	// Small config matching 26b architecture pattern.
	cfg := &TextConfig{
		HiddenSize:             16, // tiny for testing
		NumAttentionHeads:      2,
		NumKeyValueHeads:       1,
		NumGlobalKeyValueHeads: 1,
		HeadDim:                8,
		GlobalHeadDim:          8,
		NumExperts:             4,
		TopKExperts:            2,
		ExpertIntermediateSize: 8,
		EnableMoeBlock:         true,
		AttentionKEqV:          false,
		RMSNormEps:             1e-6,
		SlidingScale:           1.0,
		FullScale:              1.0,
	}
	B, L := int32(1), int32(3)
	x := onesLike(int(B), int(L), int(cfg.HiddenSize))
	// Test Router.Forward.
	// Router weights are uniform constants; only the output shapes are
	// checked here, not the routing decisions themselves.
	router := &Router{
		Proj:  linearFromWeight(onesLike(int(cfg.NumExperts), int(cfg.HiddenSize))),
		Scale: onesLike(int(cfg.HiddenSize)),
	}
	t.Run("Router", func(t *testing.T) {
		scores, inds := router.Forward(x, cfg)
		mlx.Eval(scores, inds)
		sDims := scores.Dims()
		iDims := inds.Dims()
		t.Logf("scores shape: %v, inds shape: %v", sDims, iDims)
		// Both outputs are flattened to [B*L, top_k].
		if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) {
			t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts)
		}
		if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) {
			t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts)
		}
	})
	// Test MoEBlock.Forward.
	moe := &MoEBlock{
		GateWeight:     onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
		UpWeight:       onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
		DownWeight:     onesLike(int(cfg.NumExperts), int(cfg.ExpertIntermediateSize), int(cfg.HiddenSize)),
		PerExpertScale: onesLike(int(cfg.NumExperts)),
	}
	t.Run("MoEBlock", func(t *testing.T) {
		scores, inds := router.Forward(x, cfg)
		mlx.Eval(scores, inds)
		out := moe.Forward(x, scores, inds, cfg)
		mlx.Eval(out)
		outDims := out.Dims()
		t.Logf("MoE output shape: %v", outDims)
		// MoE output restores the [B, L, hidden] shape of the input.
		if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) {
			t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize)
		}
	})
	// Test with larger batch to exercise the sorted GatherMM path (B*L >= 64).
	t.Run("MoEBlock_sorted", func(t *testing.T) {
		bigB, bigL := int32(1), int32(128)
		bigX := onesLike(int(bigB), int(bigL), int(cfg.HiddenSize))
		scores, inds := router.Forward(bigX, cfg)
		mlx.Eval(scores, inds)
		out := moe.Forward(bigX, scores, inds, cfg)
		mlx.Eval(out)
		outDims := out.Dims()
		t.Logf("MoE sorted output shape: %v", outDims)
		if len(outDims) != 3 || outDims[0] != int(bigB) || outDims[1] != int(bigL) || outDims[2] != int(cfg.HiddenSize) {
			t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, bigB, bigL, cfg.HiddenSize)
		}
	})
}
// TestRouterForwardMatchesLegacy verifies the optimized Router.Forward —
// which takes the top-k of the raw logits and softmaxes only the selected
// values — produces the same indices and (within tolerance) the same
// normalized scores as the legacy path that softmaxes over every expert
// first, gathers the top-k probabilities, then renormalizes.
func TestRouterForwardMatchesLegacy(t *testing.T) {
	skipIfNoMLX(t)
	cfg := &TextConfig{
		HiddenSize:  8,
		NumExperts:  4,
		TopKExperts: 2,
		RMSNormEps:  1e-6,
		RouterScale: 0.5,
	}
	// Distinct per-expert weight rows so top-k has a well-defined ordering
	// (tied scores would let argpartition pick either tied expert and make
	// the index comparison below flaky).
	projWeight := mlx.FromValues([]float32{
		0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, // expert 0
		0.30, 0.29, 0.28, 0.27, 0.26, 0.25, 0.24, 0.23, // expert 1
		-0.05, -0.06, -0.07, -0.08, -0.09, -0.10, -0.11, -0.12, // expert 2
		0.50, 0.48, 0.46, 0.44, 0.42, 0.40, 0.38, 0.36, // expert 3
	}, int(cfg.NumExperts), int(cfg.HiddenSize))
	// Non-uniform router scale so the scale multiply is actually exercised.
	scale := mlx.FromValues([]float32{
		1.0, 0.9, 1.1, 1.0, 1.2, 0.8, 1.0, 1.05,
	}, int(cfg.HiddenSize))
	r := &Router{
		Proj:  linearFromWeight(projWeight),
		Scale: scale,
	}
	// Varied x so different positions potentially hit different top-k.
	x := mlx.FromValues([]float32{
		0.2, -0.1, 0.3, 0.0, 0.4, -0.2, 0.1, 0.05,
		-0.3, 0.2, -0.1, 0.4, -0.05, 0.3, 0.0, 0.2,
		0.5, 0.4, -0.2, 0.1, -0.3, 0.0, 0.3, -0.1,
	}, 1, 3, int(cfg.HiddenSize))
	gotScores, gotInds := r.Forward(x, cfg)
	wantScores, wantInds := legacyRouterForward(r, x, cfg)
	mlx.Eval(gotScores, gotInds, wantScores, wantInds)
	// Indices must match exactly; scores only within float tolerance since
	// the two formulations differ in operation order.
	if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
		t.Fatalf("indices mismatch:\n  got  %v\n  want %v", got, want)
	}
	if got, want := gotScores.Floats(), wantScores.Floats(); !floatSlicesClose(got, want, 1e-5) {
		t.Fatalf("scores mismatch:\n  got  %v\n  want %v", got, want)
	}
}
// legacyRouterForward implements the pre-optimization router: full softmax
// over every expert, gather the top-k probabilities, then renormalize them
// to sum to 1. Algebraically identical to the fused form in Router.Forward.
func legacyRouterForward(r *Router, x *mlx.Array, cfg *TextConfig) (*mlx.Array, *mlx.Array) {
	dims := x.Dims()
	// Flatten [B, L, hidden] to [B*L, hidden] for the router matmul.
	BL := int32(dims[0]) * int32(dims[1])
	xFlat := mlx.Reshape(x, BL, cfg.HiddenSize)
	// Pre-projection normalization: RMSNorm (no learned weight here),
	// then the scalar RouterScale and the per-feature Scale vector.
	normed := mlx.RMSNormFn(xFlat, nil, cfg.RMSNormEps)
	normed = mlx.MulScalar(normed, cfg.RouterScale)
	normed = mlx.Mul(normed, r.Scale)
	expertScores := r.Proj.Forward(normed)
	// Legacy order: softmax over ALL experts first.
	probs := mlx.SoftmaxAxis(expertScores, -1, true)
	// Argpartition on negated logits puts the top-k expert indices in the
	// first k columns (unordered within the partition).
	neg := mlx.Neg(expertScores)
	inds := mlx.Argpartition(neg, int(cfg.TopKExperts)-1, -1)
	inds = mlx.SliceStartStop(inds,
		[]int32{0, 0},
		[]int32{BL, cfg.TopKExperts},
	)
	// Gather the selected probabilities and renormalize them to sum to 1.
	scores := mlx.TakeAlongAxis(probs, inds, -1)
	sumScores := mlx.Sum(scores, -1, true)
	scores = mlx.Div(scores, sumScores)
	return scores, inds
}
// intSlicesEqual reports whether a and b have the same length and
// identical elements in order.
func intSlicesEqual(a, b []int) bool {
	if len(a) != len(b) {
		return false
	}
	for i, v := range a {
		if v != b[i] {
			return false
		}
	}
	return true
}
// floatSlicesClose reports whether a and b have the same length and every
// pair of elements differs by at most tol in absolute value.
func floatSlicesClose(a, b []float32, tol float32) bool {
	if len(a) != len(b) {
		return false
	}
	for i, v := range a {
		diff := v - b[i]
		if diff < 0 {
			diff = -diff
		}
		if diff > tol {
			return false
		}
	}
	return true
}
// linearFromWeight wraps a weight tensor in a minimal bias-free linear
// layer for use in tests.
func linearFromWeight(w *mlx.Array) *simpleLinear {
	return &simpleLinear{weight: w}
}

// simpleLinear is a test-only linear layer: y = x Wᵀ, no bias.
type simpleLinear struct {
	weight *mlx.Array // [out_features, in_features]
}

// Forward applies the linear projection to x.
func (l *simpleLinear) Forward(x *mlx.Array) *mlx.Array {
	wT := mlx.Transpose(l.weight, 1, 0)
	return x.Matmul(wT)
}

// OutputDim returns the layer's output feature count (leading weight dim).
func (l *simpleLinear) OutputDim() int32 {
	return int32(l.weight.Dims()[0])
}

View File

@@ -0,0 +1,503 @@
package gemma4
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// TestParseTextConfigE2B parses an E2B-style config (KV-shared layers,
// double-wide MLP, per-layer embeddings) and checks the derived fields:
// RoPE dims/bases, attention scales, and the KV sharing/donor maps.
func TestParseTextConfigE2B(t *testing.T) {
	skipIfNoMLX(t)
	data := []byte(`{
		"architectures": ["Gemma4ForConditionalGeneration"],
		"text_config": {
			"hidden_size": 1536,
			"num_hidden_layers": 35,
			"intermediate_size": 6144,
			"num_attention_heads": 8,
			"num_key_value_heads": 1,
			"head_dim": 256,
			"global_head_dim": 512,
			"vocab_size": 262144,
			"rms_norm_eps": 1e-6,
			"max_position_embeddings": 131072,
			"sliding_window": 512,
			"sliding_window_pattern": 5,
			"final_logit_softcapping": 30.0,
			"use_double_wide_mlp": true,
			"num_kv_shared_layers": 20,
			"hidden_size_per_layer_input": 256,
			"vocab_size_per_layer_input": 262144,
			"attention_k_eq_v": false,
			"tie_word_embeddings": true,
			"layer_types": [
				"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
				"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
				"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
				"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
				"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
				"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
				"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
			],
			"rope_parameters": {
				"full_attention": {
					"partial_rotary_factor": 0.25,
					"rope_theta": 1000000.0,
					"rope_type": "proportional"
				},
				"sliding_attention": {
					"rope_theta": 10000.0,
					"rope_type": "default"
				}
			}
		}
	}`)
	cfg, err := parseTextConfig(data)
	if err != nil {
		t.Fatalf("parseTextConfig failed: %v", err)
	}
	// Basic fields.
	if cfg.HiddenSize != 1536 {
		t.Errorf("HiddenSize = %d, want 1536", cfg.HiddenSize)
	}
	if cfg.NumHiddenLayers != 35 {
		t.Errorf("NumHiddenLayers = %d, want 35", cfg.NumHiddenLayers)
	}
	if cfg.GlobalHeadDim != 512 {
		t.Errorf("GlobalHeadDim = %d, want 512", cfg.GlobalHeadDim)
	}
	if cfg.FinalLogitSoftcapping != 30.0 {
		t.Errorf("FinalLogitSoftcapping = %f, want 30.0", cfg.FinalLogitSoftcapping)
	}
	if cfg.NumKVSharedLayers != 20 {
		t.Errorf("NumKVSharedLayers = %d, want 20", cfg.NumKVSharedLayers)
	}
	if cfg.HiddenSizePerLayer != 256 {
		t.Errorf("HiddenSizePerLayer = %d, want 256", cfg.HiddenSizePerLayer)
	}
	// RoPE settings: sliding layers rotate head_dim; full-attention layers
	// use GlobalHeadDim with partial rotation handled via custom freqs.
	if cfg.SlidingRopeDims != 256 {
		t.Errorf("SlidingRopeDims = %d, want 256", cfg.SlidingRopeDims)
	}
	if cfg.FullRopeDims != 512 {
		t.Errorf("FullRopeDims = %d, want 512 (GlobalHeadDim, partial rotation handled via custom freqs)", cfg.FullRopeDims)
	}
	if cfg.SlidingRopeBase != 10000 {
		t.Errorf("SlidingRopeBase = %f, want 10000", cfg.SlidingRopeBase)
	}
	if cfg.FullRopeBase != 1000000 {
		t.Errorf("FullRopeBase = %f, want 1000000", cfg.FullRopeBase)
	}
	// Attention scale.
	if cfg.SlidingScale == 0 || cfg.FullScale == 0 {
		t.Error("attention scales should be non-zero")
	}
	// KV sharing map.
	// First shared layer is 35 - 20 = 15.
	if donor, ok := cfg.KVShareMap[15]; !ok || donor != 13 {
		t.Errorf("KVShareMap[15] = %d, ok=%v; want 13, true", donor, ok)
	}
	if donor, ok := cfg.KVShareMap[19]; !ok || donor != 14 {
		t.Errorf("KVShareMap[19] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
	}
	if donor, ok := cfg.KVShareMap[34]; !ok || donor != 14 {
		t.Errorf("KVShareMap[34] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
	}
	// Layer 14 should not be shared.
	if _, ok := cfg.KVShareMap[14]; ok {
		t.Error("layer 14 should not be in KVShareMap (non-shared)")
	}
	// Donors: the last sliding (13) and last full (14) layers before the
	// shared region supply KV to the shared layers.
	if !cfg.KVDonors[13] {
		t.Error("layer 13 should be a KV donor")
	}
	if !cfg.KVDonors[14] {
		t.Error("layer 14 should be a KV donor")
	}
}
// TestParseTextConfig26B parses a 26B-style text config: MoE enabled,
// K==V attention, grouped global KV heads, and neither KV sharing nor
// per-layer embeddings (PLE).
func TestParseTextConfig26B(t *testing.T) {
	skipIfNoMLX(t)
	data := []byte(`{
	"architectures": ["Gemma4ForConditionalGeneration"],
	"text_config": {
		"hidden_size": 2816,
		"num_hidden_layers": 30,
		"intermediate_size": 2112,
		"num_attention_heads": 16,
		"num_key_value_heads": 8,
		"num_global_key_value_heads": 2,
		"head_dim": 256,
		"global_head_dim": 512,
		"vocab_size": 262144,
		"rms_norm_eps": 1e-6,
		"max_position_embeddings": 131072,
		"sliding_window": 1024,
		"final_logit_softcapping": 30.0,
		"use_double_wide_mlp": false,
		"num_kv_shared_layers": 0,
		"hidden_size_per_layer_input": null,
		"attention_k_eq_v": true,
		"enable_moe_block": true,
		"num_experts": 128,
		"top_k_experts": 8,
		"moe_intermediate_size": 704,
		"tie_word_embeddings": true,
		"layer_types": [
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
		],
		"rope_parameters": {
			"full_attention": {
				"partial_rotary_factor": 0.25,
				"rope_theta": 1000000.0,
				"rope_type": "proportional"
			},
			"sliding_attention": {
				"rope_theta": 10000.0,
				"rope_type": "default"
			}
		}
	}
}`)
	cfg, err := parseTextConfig(data)
	if err != nil {
		t.Fatalf("parseTextConfig failed: %v", err)
	}
	// Boolean feature flags.
	if !cfg.AttentionKEqV {
		t.Error("AttentionKEqV should be true")
	}
	if !cfg.EnableMoeBlock {
		t.Error("EnableMoeBlock should be true")
	}
	// Scalar dimensions.
	if got := cfg.HiddenSize; got != 2816 {
		t.Errorf("HiddenSize = %d, want 2816", got)
	}
	if got := cfg.NumGlobalKeyValueHeads; got != 2 {
		t.Errorf("NumGlobalKeyValueHeads = %d, want 2", got)
	}
	// MoE parameters.
	if got := cfg.NumExperts; got != 128 {
		t.Errorf("NumExperts = %d, want 128", got)
	}
	if got := cfg.TopKExperts; got != 8 {
		t.Errorf("TopKExperts = %d, want 8", got)
	}
	if got := cfg.ExpertIntermediateSize; got != 704 {
		t.Errorf("ExpertIntermediateSize = %d, want 704", got)
	}
	// This variant has no per-layer embeddings.
	if got := cfg.HiddenSizePerLayer; got != 0 {
		t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", got)
	}
}
// TestParseTextConfig31B parses a 31B-style text config: dense MLP,
// K==V attention, no KV sharing, and a repeating 5-sliding/1-full
// layer-type pattern across 60 layers.
func TestParseTextConfig31B(t *testing.T) {
	skipIfNoMLX(t)
	data := []byte(`{
	"architectures": ["Gemma4ForConditionalGeneration"],
	"text_config": {
		"hidden_size": 5376,
		"num_hidden_layers": 60,
		"intermediate_size": 21504,
		"num_attention_heads": 32,
		"num_key_value_heads": 16,
		"num_global_key_value_heads": 4,
		"head_dim": 256,
		"global_head_dim": 512,
		"vocab_size": 262144,
		"rms_norm_eps": 1e-6,
		"max_position_embeddings": 131072,
		"sliding_window": 1024,
		"final_logit_softcapping": 30.0,
		"use_double_wide_mlp": false,
		"num_kv_shared_layers": 0,
		"hidden_size_per_layer_input": null,
		"attention_k_eq_v": true,
		"tie_word_embeddings": true,
		"layer_types": [
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
		],
		"rope_parameters": {
			"full_attention": {
				"partial_rotary_factor": 0.25,
				"rope_theta": 1000000.0,
				"rope_type": "proportional"
			},
			"sliding_attention": {
				"rope_theta": 10000.0,
				"rope_type": "default"
			}
		}
	}
}`)
	cfg, err := parseTextConfig(data)
	if err != nil {
		t.Fatalf("parseTextConfig failed: %v", err)
	}
	// Core dimensions.
	if got := cfg.HiddenSize; got != 5376 {
		t.Errorf("HiddenSize = %d, want 5376", got)
	}
	if got := cfg.NumHiddenLayers; got != 60 {
		t.Errorf("NumHiddenLayers = %d, want 60", got)
	}
	if got := cfg.NumKeyValueHeads; got != 16 {
		t.Errorf("NumKeyValueHeads = %d, want 16", got)
	}
	if got := cfg.NumGlobalKeyValueHeads; got != 4 {
		t.Errorf("NumGlobalKeyValueHeads = %d, want 4", got)
	}
	if !cfg.AttentionKEqV {
		t.Error("AttentionKEqV should be true")
	}
	if got := cfg.SlidingWindow; got != 1024 {
		t.Errorf("SlidingWindow = %d, want 1024", got)
	}
	// No KV sharing and no PLE on this variant.
	if got := cfg.NumKVSharedLayers; got != 0 {
		t.Errorf("NumKVSharedLayers = %d, want 0", got)
	}
	if got := cfg.HiddenSizePerLayer; got != 0 {
		t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", got)
	}
	// KV sharing should be empty (no shared layers).
	if got := len(cfg.KVShareMap); got != 0 {
		t.Errorf("KVShareMap should be empty, got %d entries", got)
	}
	// Layer types: pattern is 5 sliding + 1 full, repeating 10 times.
	if !isLayerSliding(0, &cfg) {
		t.Error("layer 0 should be sliding")
	}
	if isLayerSliding(5, &cfg) {
		t.Error("layer 5 should be full attention")
	}
	if !isLayerSliding(6, &cfg) {
		t.Error("layer 6 should be sliding")
	}
	if isLayerSliding(59, &cfg) {
		t.Error("layer 59 should be full attention")
	}
}
// TestParseTextConfigE4B parses an E4B-style text config: per-layer
// embeddings (PLE) present, 18 KV-shared layers, K!=V attention, and
// no MoE block.
func TestParseTextConfigE4B(t *testing.T) {
	skipIfNoMLX(t)
	data := []byte(`{
	"architectures": ["Gemma4ForConditionalGeneration"],
	"text_config": {
		"hidden_size": 2560,
		"num_hidden_layers": 42,
		"intermediate_size": 10240,
		"num_attention_heads": 8,
		"num_key_value_heads": 2,
		"head_dim": 256,
		"global_head_dim": 512,
		"vocab_size": 262144,
		"rms_norm_eps": 1e-6,
		"max_position_embeddings": 131072,
		"sliding_window": 512,
		"final_logit_softcapping": 30.0,
		"use_double_wide_mlp": false,
		"num_kv_shared_layers": 18,
		"hidden_size_per_layer_input": 256,
		"vocab_size_per_layer_input": 262144,
		"attention_k_eq_v": false,
		"enable_moe_block": false,
		"tie_word_embeddings": true,
		"layer_types": [
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
		],
		"rope_parameters": {
			"full_attention": {
				"partial_rotary_factor": 0.25,
				"rope_theta": 1000000.0,
				"rope_type": "proportional"
			},
			"sliding_attention": {
				"rope_theta": 10000.0,
				"rope_type": "default"
			}
		}
	}
}`)
	cfg, err := parseTextConfig(data)
	if err != nil {
		t.Fatalf("parseTextConfig failed: %v", err)
	}
	// Core dimensions.
	if got := cfg.HiddenSize; got != 2560 {
		t.Errorf("HiddenSize = %d, want 2560", got)
	}
	if got := cfg.NumHiddenLayers; got != 42 {
		t.Errorf("NumHiddenLayers = %d, want 42", got)
	}
	if got := cfg.IntermediateSize; got != 10240 {
		t.Errorf("IntermediateSize = %d, want 10240", got)
	}
	if got := cfg.NumKeyValueHeads; got != 2 {
		t.Errorf("NumKeyValueHeads = %d, want 2", got)
	}
	if got := cfg.SlidingWindow; got != 512 {
		t.Errorf("SlidingWindow = %d, want 512", got)
	}
	// Feature flags: PLE on, MoE off, K!=V, single-wide MLP.
	if cfg.UseDoubleWideMLP {
		t.Error("UseDoubleWideMLP should be false")
	}
	if cfg.AttentionKEqV {
		t.Error("AttentionKEqV should be false")
	}
	if cfg.EnableMoeBlock {
		t.Error("EnableMoeBlock should be false")
	}
	if got := cfg.NumKVSharedLayers; got != 18 {
		t.Errorf("NumKVSharedLayers = %d, want 18", got)
	}
	if got := cfg.HiddenSizePerLayer; got != 256 {
		t.Errorf("HiddenSizePerLayer = %d, want 256 (has PLE)", got)
	}
	// Layer types: pattern is 5 sliding + 1 full, repeating 7 times = 42 layers.
	if !isLayerSliding(0, &cfg) {
		t.Error("layer 0 should be sliding")
	}
	if isLayerSliding(5, &cfg) {
		t.Error("layer 5 should be full attention")
	}
	if !isLayerSliding(6, &cfg) {
		t.Error("layer 6 should be sliding")
	}
	if isLayerSliding(41, &cfg) {
		t.Error("layer 41 should be full attention")
	}
	// KV sharing: first shared = 42 - 18 = 24.
	// Layer 24 is sliding, its donor should be the last non-shared sliding layer.
	// Non-shared layers: 0-23. Last sliding in 0-23 is layer 22 (23=full).
	if donor, ok := cfg.KVShareMap[24]; !ok {
		t.Error("layer 24 should be in KVShareMap")
	} else {
		t.Logf("layer 24 donor = %d", donor)
	}
	// Layer 29 is full_attention (5th full), donor should be the last non-shared full layer.
	// Non-shared full layers: 5, 11, 17, 23.
	if donor, ok := cfg.KVShareMap[29]; !ok || donor != 23 {
		t.Errorf("KVShareMap[29] = %d, ok=%v; want 23, true (full attn donor)", donor, ok)
	}
	// Layer 23 should NOT be shared (it's the last non-shared layer).
	if _, ok := cfg.KVShareMap[23]; ok {
		t.Error("layer 23 should not be in KVShareMap (non-shared)")
	}
}
// TestLayerTypeDetection checks isLayerSliding against an explicit
// layer_types list: four sliding layers followed by one full-attention layer.
func TestLayerTypeDetection(t *testing.T) {
	tc := &TextConfig{
		LayerTypes: []string{
			"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
		},
	}
	if isLayerSliding(4, tc) {
		t.Error("layer 4 should be full attention")
	}
	if !isLayerSliding(0, tc) {
		t.Error("layer 0 should be sliding")
	}
	if !isLayerSliding(3, tc) {
		t.Error("layer 3 should be sliding")
	}
}
func TestNewCachesOmitsSharedKVLayers(t *testing.T) {
m := &Model{
Layers: []*DecoderLayer{
{IsSliding: true, KVShareDonor: -1},
{IsSliding: false, KVShareDonor: -1},
{IsSliding: true, KVShareDonor: 0},
{IsSliding: false, KVShareDonor: 1},
},
TextConfig: &TextConfig{SlidingWindow: 512},
}
caches := m.NewCaches()
if got, want := len(caches), 2; got != want {
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
}
}
func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) {
m := &Model{
Layers: []*DecoderLayer{
{IsSliding: true, KVShareDonor: -1},
{IsSliding: false, KVShareDonor: -1},
{IsSliding: true, KVShareDonor: -1},
},
TextConfig: &TextConfig{SlidingWindow: 512},
}
caches := m.NewCaches()
if got, want := len(caches), len(m.Layers); got != want {
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
}
}
func TestResolveWeightPrefix(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
tests := []struct {
name string
key string
wantPfx string
}{
{"bare", "embed_tokens.weight", ""},
{"language_model", "model.language_model.embed_tokens.weight", "model.language_model."},
{"with_model", "model.embed_tokens.weight", "model."},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dummy := mlx.FromValue(float32(1.0))
mlx.Eval(dummy)
tensors := map[string]*mlx.Array{tt.key: dummy}
got := resolveWeightPrefix(tensors)
if got != tt.wantPfx {
t.Errorf("resolveWeightPrefix(%q) = %q, want %q", tt.key, got, tt.wantPfx)
}
})
}
}
// skipIfNoMLX skips the calling test when the MLX runtime cannot be
// initialized (e.g. on platforms without MLX support).
func skipIfNoMLX(t *testing.T) {
	t.Helper()
	err := mlx.CheckInit()
	if err == nil {
		return
	}
	t.Skipf("MLX not available: %v", err)
}