mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
Gemma4 on MLX (#15244)
* gemma4: implement Gemma 4 model for MLX (text-only runtime) * gemma4: two MoE + SWA prefill perf fixes Two performance optimizations in the gemma4 forward pass 1. Memoize the sliding-window prefill mask across layers. 2. Softmax only over the selected experts in Router.Forward. * review comments
This commit is contained in:
@@ -560,6 +560,9 @@ func getParserName(modelDir string) string {
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
@@ -574,6 +577,9 @@ func getParserName(modelDir string) string {
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
@@ -602,6 +608,9 @@ func getRendererName(modelDir string) string {
|
||||
// Check architectures for known renderers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
@@ -616,6 +625,9 @@ func getRendererName(modelDir string) string {
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
|
||||
@@ -634,6 +634,8 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
|
||||
"Qwen3_5MoeForConditionalGeneration": newQwen35ImportTransform,
|
||||
"Qwen3NextMoeForCausalLM": newQwen35ImportTransform,
|
||||
"Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform,
|
||||
"Gemma4ForCausalLM": newGemma4ImportTransform,
|
||||
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
|
||||
}
|
||||
|
||||
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
|
||||
|
||||
264
x/create/gemma4.go
Normal file
264
x/create/gemma4.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
// gemma4ImportTransform rewrites Gemma 4 safetensors for import: it splits
// pre-stacked MoE expert weights into per-expert tensors and applies
// mixed-precision quantization heuristics per tensor.
type gemma4ImportTransform struct {
	// numLayers is the transformer layer count; 0 means unknown, which
	// disables the layer-position quantization heuristic.
	numLayers int
	// numExperts is the MoE expert count; 0 means unknown.
	numExperts int
}

// gemma4Config is a minimal subset of the Gemma 4 config.json used for quant decisions.
// Some checkpoints carry these fields at the top level, others nest them
// under "text_config", so both locations are decoded.
type gemma4Config struct {
	NumHiddenLayers int `json:"num_hidden_layers"`
	NumExperts      int `json:"num_experts"`
	TextConfig      struct {
		NumHiddenLayers int `json:"num_hidden_layers"`
		NumExperts      int `json:"num_experts"`
	} `json:"text_config"`
}
|
||||
|
||||
func newGemma4ImportTransform(modelDir string, _ sourceModelConfig) (tensorImportTransform, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||
if err != nil {
|
||||
return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
|
||||
}
|
||||
var cfg gemma4Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
|
||||
}
|
||||
|
||||
numLayers := cfg.NumHiddenLayers
|
||||
if numLayers == 0 {
|
||||
numLayers = cfg.TextConfig.NumHiddenLayers
|
||||
}
|
||||
numExperts := cfg.NumExperts
|
||||
if numExperts == 0 {
|
||||
numExperts = cfg.TextConfig.NumExperts
|
||||
}
|
||||
|
||||
return gemma4ImportTransform{numLayers: numLayers, numExperts: numExperts}, nil
|
||||
}
|
||||
|
||||
func (t gemma4ImportTransform) skipTensor(name string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// layerIndexRe extracts the layer index from tensor names like
// "model.language_model.layers.5.self_attn.v_proj.weight" or
// "model.language_model.layers.5.moe.experts.42.down_proj.weight".
// The single capture group holds the decimal layer index.
var layerIndexRe = regexp.MustCompile(`\.layers\.(\d+)\.`)
|
||||
|
||||
// useMoreBits reports whether quantization-sensitive tensors at layerIdx
// (of numLayers total) should use higher precision: the first and last 1/8
// of layers (which handle input grounding and final output refinement),
// plus every 3rd layer in between to limit error accumulation through the
// residual stream.
func useMoreBits(layerIdx, numLayers int) bool {
	firstEighth := numLayers / 8
	switch {
	case layerIdx < firstEighth:
		return true
	case layerIdx >= 7*numLayers/8:
		return true
	default:
		return (layerIdx-firstEighth)%3 == 2
	}
}
|
||||
|
||||
// quantizationType returns the quantization type to use for the named
// tensor with the given shape, or "" to leave it unquantized. quantize is
// the user-requested base type (e.g. "int4", "nvfp4"); it is normalized
// first. Gemma 4-specific mixed-precision rules are applied before falling
// through to the generic GetTensorQuantization.
func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
	quantNorm := normalizeQuantType(quantize)

	// Embedding: quantize to 8-bit variant for bandwidth efficiency.
	// The embedding serves double duty: lookup (via QuantizedEmbedding) and
	// lm_head projection (via AsLinear). Using 8-bit matches GGUF Q6_K quality
	// (strictly higher at 8 bpw vs 6.5 bpw) while saving ~2.8 GB on 31B vs bf16.
	if isEmbedTokensWeight(name) {
		switch quantNorm {
		case "int4", "int8":
			if isAligned(shape, "int8") {
				return "int8"
			}
		case "mxfp4", "nvfp4", "mxfp8":
			if isAligned(shape, "mxfp8") {
				return "mxfp8"
			}
		}
		// 8-bit shape alignment failed: fall back to the requested base
		// type if that aligns, otherwise leave unquantized.
		if isAligned(shape, quantNorm) {
			return quantNorm
		}
		return ""
	}

	// Mixed-precision quantization: sensitive tensors get higher precision.
	//
	// Value projections (v_proj) directly determine attention output quality.
	// Down projections (down_proj) are the final MLP output and errors there
	// propagate directly to the residual stream. Both benefit from higher
	// precision at early layers, late layers, and periodically in between
	// (the "useMoreBits" heuristic).
	//
	// For int4: promote → int8 (same affine family, GatherQMM compatible).
	// For mxfp4/nvfp4: promote → mxfp8. MLX quantized_matmul handles mixed
	// nvfp4+mxfp8 modes within the same model — each tensor carries its own
	// quant metadata and the kernel dispatches per-tensor.
	if t.numLayers > 0 {
		// Parse the layer index out of the tensor name; -1 when absent.
		layerIdx := -1
		if m := layerIndexRe.FindStringSubmatch(name); m != nil {
			if idx, err := strconv.Atoi(m[1]); err == nil {
				layerIdx = idx
			}
		}

		// Determine promotion target for sensitive tensors.
		//   "int8"  = int4 base → int8 (affine family)
		//   "mxfp8" = mxfp4/nvfp4 base → mxfp8
		//   ""      = no promotion (int8/mxfp8, already 8-bit)
		promote := ""
		switch quantNorm {
		case "int4":
			promote = "int8"
		case "mxfp4", "nvfp4":
			promote = "mxfp8"
		}

		// Only apply to language model tensors — audio/vision tower tensors
		// should pass through to GetTensorQuantization which skips them.
		isModelTensor := !strings.Contains(name, "audio_tower") &&
			!strings.Contains(name, "vision_tower")
		isSensitive := isModelTensor &&
			(strings.Contains(name, ".v_proj") || strings.Contains(name, "down_proj"))
		isSensitiveK := isModelTensor && strings.Contains(name, "k_proj")

		if promote != "" && (isSensitive || isSensitiveK) {
			shouldPromote := false

			// 8-expert models: v_proj and k_proj share very few KV heads,
			// so quantization errors are amplified. Always promote.
			if t.numExperts == 8 && (strings.Contains(name, ".v_proj") || isSensitiveK) {
				shouldPromote = true
			}

			// Layer-position heuristic for v_proj and down_proj.
			if isSensitive && layerIdx >= 0 && useMoreBits(layerIdx, t.numLayers) {
				shouldPromote = true
			}

			if shouldPromote && isAligned(shape, promote) {
				return promote
			}

			// Sensitive tensor at a non-promoted layer: use base quant type.
			// Return directly to bypass GetTensorQuantization's uniform
			// promotion — the layer-position heuristic is authoritative here.
			if !isAligned(shape, quantNorm) {
				return ""
			}
			return quantNorm
		}
	}

	return GetTensorQuantization(name, shape, quantize)
}
|
||||
|
||||
// isEmbedTokensWeight returns true for the main token embedding weight,
// excluding any per-layer embedding variants.
func isEmbedTokensWeight(name string) bool {
	if strings.Contains(name, "per_layer") {
		return false
	}
	return strings.HasSuffix(name, "embed_tokens.weight")
}
|
||||
|
||||
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
||||
if td == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Split pre-stacked MoE expert tensors [N, out, in] into per-expert
|
||||
// [out, in] tensors so they go through the standard expert packing and
|
||||
// quantization flow (ExpertGroupPrefix matching, per-expert quantize).
|
||||
if isGemma4StackedMoETensor(td.Name, td.Shape) {
|
||||
return splitStackedMoETensor(td)
|
||||
}
|
||||
|
||||
return []*safetensors.TensorData{td}, nil
|
||||
}
|
||||
|
||||
// isGemma4StackedMoETensor checks if this is a pre-stacked MoE expert weight.
// Gemma 4 HF weights come in two layouts depending on the model version:
//   - Older: model.language_model.layers.N.moe.{gate,up,down}_proj [experts, dim1, dim2]
//   - Newer: model.language_model.layers.N.experts.{gate_up,down}_proj [experts, dim1, dim2]
//
// The newer layout has gate+up already fused. We keep it fused (no splitting)
// so the tensors flow through the standard expert packing and quantization path.
func isGemma4StackedMoETensor(name string, shape []int32) bool {
	// Stacked expert weights are always rank-3.
	if len(shape) != 3 {
		return false
	}
	inExpertScope := strings.Contains(name, ".moe.") || strings.Contains(name, ".experts.")
	if !inExpertScope {
		return false
	}
	// Match both bare "..._proj" names and the "..._proj.weight" form.
	return strings.HasSuffix(name, "_proj") || strings.HasSuffix(name, "_proj.weight")
}
|
||||
|
||||
// splitStackedMoETensor splits a [N, out, in] stacked expert tensor into
// N individual [out, in] tensors named with the per-expert convention that
// ExpertGroupPrefix expects: prefix.moe.experts.{E}.{proj}.weight
//
// The split is a pure byte partition: experts are contiguous along the
// first axis, so expert e occupies raw[e*perExpertBytes : (e+1)*perExpertBytes].
func splitStackedMoETensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
	raw, err := io.ReadAll(td.Reader())
	if err != nil {
		return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
	}

	numExperts := int(td.Shape[0])
	rows := int(td.Shape[1]) // out_features in HF layout
	cols := int(td.Shape[2]) // in_features in HF layout

	elemSize, err := DTypeSize(td.Dtype)
	if err != nil {
		return nil, fmt.Errorf("failed to get dtype size for %s: %w", td.Dtype, err)
	}

	// Sanity check: the payload must exactly cover shape × dtype size.
	perExpertBytes := rows * cols * elemSize
	if len(raw) != numExperts*perExpertBytes {
		return nil, fmt.Errorf("tensor %s: raw byte length %d does not match shape %v and dtype %s",
			td.Name, len(raw), td.Shape, td.Dtype)
	}

	// Determine the per-expert name pattern.
	// Two source layouts:
	//   Old: model.language_model.layers.N.moe.gate_proj
	//        -> model.language_model.layers.N.moe.experts.E.gate_proj.weight
	//   New: model.language_model.layers.N.experts.gate_up_proj
	//        -> model.language_model.layers.N.moe.experts.E.gate_up_proj.weight
	baseName := td.Name
	baseName = strings.TrimSuffix(baseName, ".weight")
	lastDot := strings.LastIndex(baseName, ".")
	if lastDot < 0 {
		return nil, fmt.Errorf("tensor %s: unexpected name format", td.Name)
	}
	parentPrefix := baseName[:lastDot] // "...layers.N.moe" or "...layers.N.experts"
	projName := baseName[lastDot+1:]  // "gate_proj" or "gate_up_proj"

	// Normalize: if parent already ends with ".experts", use the grandparent + ".moe"
	// so we get a consistent "layers.N.moe.experts.E" pattern.
	var moePrefix string
	if cut, ok := strings.CutSuffix(parentPrefix, ".experts"); ok {
		moePrefix = cut + ".moe"
	} else {
		moePrefix = parentPrefix
	}

	// NOTE(review): despite the name, this is not a transpose — it is the
	// per-expert [out, in] shape with the leading expert axis dropped.
	transposedShape := []int32{td.Shape[1], td.Shape[2]}

	results := make([]*safetensors.TensorData, numExperts)
	for e := range numExperts {
		expertName := fmt.Sprintf("%s.experts.%d.%s.weight", moePrefix, e, projName)
		start := e * perExpertBytes
		end := start + perExpertBytes
		results[e] = safetensors.NewTensorDataFromBytes(expertName, td.Dtype, transposedShape, raw[start:end])
	}

	return results, nil
}
|
||||
191
x/create/gemma4_test.go
Normal file
191
x/create/gemma4_test.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGemma4QuantizationType pins the mixed-precision quantization table:
// embedding handling, the layer-position promotion heuristic for sensitive
// tensors (v_proj/down_proj, plus k_proj on 8-expert models), and the
// pass-through behavior for audio/vision tower tensors.
func TestGemma4QuantizationType(t *testing.T) {
	// 26B MoE: 30 layers, 128 experts
	transform26B := gemma4ImportTransform{numLayers: 30, numExperts: 128}
	// 8-expert model (hypothetical)
	transform8E := gemma4ImportTransform{numLayers: 30, numExperts: 8}

	aligned := []int32{2816, 2816} // divisible by 64 (int4/int8 group size) and 16 (nvfp4)

	tests := []struct {
		name      string
		transform gemma4ImportTransform
		tensor    string
		shape     []int32
		quantize  string
		want      string
	}{
		// === embed_tokens: quantize to 8-bit variant (serves as both embed and lm_head) ===
		{"embed_tokens int4", transform26B, "model.embed_tokens.weight", aligned, "int4", "int8"},
		{"embed_tokens nvfp4", transform26B, "model.embed_tokens.weight", aligned, "nvfp4", "mxfp8"},
		{"embed_tokens mxfp4", transform26B, "model.embed_tokens.weight", aligned, "mxfp4", "mxfp8"},
		{"embed_tokens int8", transform26B, "model.embed_tokens.weight", aligned, "int8", "int8"},
		{"embed_tokens mxfp8", transform26B, "model.embed_tokens.weight", aligned, "mxfp8", "mxfp8"},

		// === v_proj: layer-position heuristic for int4/nvfp4 ===
		// Layer 0 is in first 1/8 (30/8=3) → promoted
		{"v_proj int4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
		// Layer 4 is NOT in useMoreBits → base quant
		{"v_proj int4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "int4", "int4"},
		// Layer 29 is in last 1/8 → promoted
		{"v_proj int4 last layer promoted", transform26B, "model.layers.29.self_attn.v_proj.weight", aligned, "int4", "int8"},
		// nvfp4: promote to mxfp8 (cross-family, validated by MLX quantized_matmul)
		{"v_proj nvfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"v_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
		// mxfp4: promoted to mxfp8 at promoted layers (same mxfp family)
		{"v_proj mxfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"v_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp4"},
		// int8/mxfp8: no promotion (already 8-bit)
		{"v_proj int8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
		{"v_proj mxfp8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},

		// === down_proj (dense MLP): same heuristic as v_proj ===
		{"dense down_proj int4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
		{"dense down_proj int4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "int4", "int4"},
		{"dense down_proj nvfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"dense down_proj nvfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "nvfp4", "nvfp4"},
		{"dense down_proj mxfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"dense down_proj mxfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},

		// === Expert down_proj: int4→int8, nvfp4→nvfp8 at promoted layers ===
		{"expert down_proj int4 promoted", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "int4", "int8"},
		{"expert down_proj int4 non-promoted", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "int4", "int4"},
		// nvfp4 experts: promote to mxfp8 (all experts at a layer get same treatment,
		// so GatherQMM sees uniform quant per projection per layer)
		{"expert down_proj nvfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"expert down_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "nvfp4", "nvfp4"},
		// mxfp4 experts: promote to mxfp8 (same mxfp family, GatherQMM compatible)
		{"expert down_proj mxfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"expert down_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp4"},

		// === Expert gate_up_proj: always base quant (not a sensitive tensor) ===
		{"expert gate_up int4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "int4", "int4"},
		{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
		{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},

		// === k_proj: promoted only for 8-expert models ===
		{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
		{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
		{"k_proj 8 experts nvfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"k_proj 8 experts mxfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "mxfp4", "mxfp8"},

		// === q_proj, o_proj, gate_proj, up_proj: always base quant ===
		{"q_proj int4", transform26B, "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
		{"o_proj int4", transform26B, "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
		{"gate_proj int4", transform26B, "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
		{"up_proj int4", transform26B, "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},

		// === Non-quantizable tensors: always bf16 ===
		{"embed_tokens per_layer skip", transform26B, "model.embed_tokens_per_layer.weight", aligned, "int4", ""},
		{"norm", transform26B, "model.layers.0.input_layernorm.weight", []int32{2816}, "int4", ""},
		{"router scale", transform26B, "model.layers.0.router.scale", []int32{2816}, "int4", ""},

		// === Audio/vision tower tensors: must pass through unquantized for all quant types ===
		// These contain .v_proj and down_proj but should NOT be intercepted by
		// the sensitive-tensor promotion logic.
		{"audio norm int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int4", ""},
		{"audio norm nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "nvfp4", ""},
		{"audio norm int8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int8", ""},
		{"audio norm mxfp8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "mxfp8", ""},
		{"audio conv int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "int4", ""},
		{"audio conv nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "nvfp4", ""},
		{"audio linear int4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "int4", ""},
		{"audio linear nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "nvfp4", ""},
		// Audio tower v_proj — must NOT be promoted despite containing .v_proj
		{"audio v_proj int4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", ""},
		{"audio v_proj nvfp4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", ""},
		// Vision tower v_proj — vision tower IS quantized (unlike audio tower),
		// but not intercepted by gemma4's layer-position heuristic.
		// Falls through to GetTensorQuantization which applies uniform promotion.
		{"vision v_proj int4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", "int8"},
		{"vision v_proj nvfp4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", "nvfp4"},
		// Audio tower down_proj
		{"audio down_proj int4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "int4", ""},
		{"audio down_proj nvfp4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "nvfp4", ""},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.transform.quantizationType(tt.tensor, tt.shape, tt.quantize)
			if got != tt.want {
				t.Errorf("quantizationType(%q, %v, %q) = %q, want %q",
					tt.tensor, tt.shape, tt.quantize, got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestUseMoreBits pins the layer-position promotion schedule for a
// 30-layer model: the first and last 1/8 of layers are always promoted,
// plus a periodic stride of 3 through the middle.
func TestUseMoreBits(t *testing.T) {
	// 30 layers: first 1/8 = layers 0-2, last 1/8 = layers 27-29
	// In between: every 3rd from offset (i - n/8) % 3 == 2
	n := 30
	promoted := map[int]bool{}
	for i := range n {
		if useMoreBits(i, n) {
			promoted[i] = true
		}
	}

	// First 1/8 (30/8 = 3): layers 0, 1, 2
	for _, i := range []int{0, 1, 2} {
		if !promoted[i] {
			t.Errorf("layer %d should be promoted (first 1/8)", i)
		}
	}

	// Last 1/8: layers 26, 27, 28, 29 (>= 7*30/8 = 26)
	for _, i := range []int{26, 27, 28, 29} {
		if !promoted[i] {
			t.Errorf("layer %d should be promoted (last 1/8)", i)
		}
	}

	// Some middle layers should NOT be promoted
	for _, i := range []int{3, 4, 6, 7} {
		if promoted[i] {
			t.Errorf("layer %d should NOT be promoted", i)
		}
	}

	// Layer 5 should be promoted: (5 - 3) % 3 == 2
	if !promoted[5] {
		t.Errorf("layer 5 should be promoted (periodic)")
	}
}
|
||||
|
||||
// TestIsGemma4StackedMoETensor covers both HF layouts (old ".moe." and new
// ".experts.") plus the negative cases: 2D tensors, non-expert tensors, and
// expert tensors that are not projections.
func TestIsGemma4StackedMoETensor(t *testing.T) {
	tests := []struct {
		label      string
		tensorName string
		shape      []int32
		want       bool
	}{
		// New-style: .experts.gate_up_proj
		{"experts gate_up_proj 3D", "model.layers.0.experts.gate_up_proj", []int32{128, 1408, 2816}, true},
		{"experts down_proj 3D", "model.layers.0.experts.down_proj", []int32{128, 2816, 704}, true},
		// Old-style: .moe.gate_proj
		{"moe gate_proj 3D", "model.layers.0.moe.gate_proj", []int32{128, 2112, 2816}, true},
		{"moe down_proj 3D", "model.layers.0.moe.down_proj.weight", []int32{128, 2816, 2112}, true},
		// Not stacked: 2D
		{"2D weight", "model.layers.0.experts.gate_up_proj", []int32{1408, 2816}, false},
		// Not expert
		{"non-expert 3D", "model.layers.0.mlp.gate_proj", []int32{3, 2816, 2816}, false},
		// Not a projection
		{"expert non-proj", "model.layers.0.experts.scale", []int32{128, 1, 1}, false},
	}

	for _, tt := range tests {
		t.Run(tt.label, func(t *testing.T) {
			got := isGemma4StackedMoETensor(tt.tensorName, tt.shape)
			if got != tt.want {
				t.Errorf("isGemma4StackedMoETensor(%q, %v) = %v, want %v",
					tt.tensorName, tt.shape, got, tt.want)
			}
		})
	}
}
|
||||
@@ -2,6 +2,7 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/x/models/gemma3"
|
||||
_ "github.com/ollama/ollama/x/models/gemma4"
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
|
||||
1514
x/models/gemma4/gemma4.go
Normal file
1514
x/models/gemma4/gemma4.go
Normal file
File diff suppressed because it is too large
Load Diff
228
x/models/gemma4/gemma4_moe_test.go
Normal file
228
x/models/gemma4/gemma4_moe_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package gemma4
|
||||
|
||||
import (
	"slices"
	"testing"

	"github.com/ollama/ollama/x/mlxrunner/mlx"
)
|
||||
|
||||
// onesLike creates a tensor of the given shape filled with a small constant.
|
||||
func onesLike(shape ...int) *mlx.Array {
|
||||
return mlx.AddScalar(mlx.Zeros(mlx.DTypeBFloat16, shape...), 0.01)
|
||||
}
|
||||
|
||||
// TestMoEForward exercises Router.Forward and MoEBlock.Forward on a tiny
// config, checking output shapes on both the default path and the sorted
// GatherMM path that kicks in for larger batches (B*L >= 64).
func TestMoEForward(t *testing.T) {
	skipIfNoMLX(t)

	// Small config matching 26b architecture pattern.
	cfg := &TextConfig{
		HiddenSize:             16, // tiny for testing
		NumAttentionHeads:      2,
		NumKeyValueHeads:       1,
		NumGlobalKeyValueHeads: 1,
		HeadDim:                8,
		GlobalHeadDim:          8,
		NumExperts:             4,
		TopKExperts:            2,
		ExpertIntermediateSize: 8,
		EnableMoeBlock:         true,
		AttentionKEqV:          false,
		RMSNormEps:             1e-6,
		SlidingScale:           1.0,
		FullScale:              1.0,
	}

	B, L := int32(1), int32(3)
	x := onesLike(int(B), int(L), int(cfg.HiddenSize))

	// Test Router.Forward.
	router := &Router{
		Proj:  linearFromWeight(onesLike(int(cfg.NumExperts), int(cfg.HiddenSize))),
		Scale: onesLike(int(cfg.HiddenSize)),
	}

	t.Run("Router", func(t *testing.T) {
		// Scores and indices should both be [B*L, topK].
		scores, inds := router.Forward(x, cfg)
		mlx.Eval(scores, inds)

		sDims := scores.Dims()
		iDims := inds.Dims()
		t.Logf("scores shape: %v, inds shape: %v", sDims, iDims)

		if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) {
			t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts)
		}
		if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) {
			t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts)
		}
	})

	// Test MoEBlock.Forward.
	moe := &MoEBlock{
		GateWeight:     onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
		UpWeight:       onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
		DownWeight:     onesLike(int(cfg.NumExperts), int(cfg.ExpertIntermediateSize), int(cfg.HiddenSize)),
		PerExpertScale: onesLike(int(cfg.NumExperts)),
	}

	t.Run("MoEBlock", func(t *testing.T) {
		scores, inds := router.Forward(x, cfg)
		mlx.Eval(scores, inds)

		// Output should recover the input's [B, L, hidden] shape.
		out := moe.Forward(x, scores, inds, cfg)
		mlx.Eval(out)

		outDims := out.Dims()
		t.Logf("MoE output shape: %v", outDims)

		if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) {
			t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize)
		}
	})

	// Test with larger batch to exercise the sorted GatherMM path (B*L >= 64).
	t.Run("MoEBlock_sorted", func(t *testing.T) {
		bigB, bigL := int32(1), int32(128)
		bigX := onesLike(int(bigB), int(bigL), int(cfg.HiddenSize))

		scores, inds := router.Forward(bigX, cfg)
		mlx.Eval(scores, inds)

		out := moe.Forward(bigX, scores, inds, cfg)
		mlx.Eval(out)

		outDims := out.Dims()
		t.Logf("MoE sorted output shape: %v", outDims)

		if len(outDims) != 3 || outDims[0] != int(bigB) || outDims[1] != int(bigL) || outDims[2] != int(cfg.HiddenSize) {
			t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, bigB, bigL, cfg.HiddenSize)
		}
	})
}
|
||||
|
||||
// TestRouterForwardMatchesLegacy verifies the optimized Router.Forward —
// which takes the top-k of the raw logits and softmaxes only the selected
// values — produces the same indices and (within tolerance) the same
// normalized scores as the legacy path that softmaxes over every expert
// first, gathers the top-k probabilities, then renormalizes.
func TestRouterForwardMatchesLegacy(t *testing.T) {
	skipIfNoMLX(t)

	cfg := &TextConfig{
		HiddenSize:  8,
		NumExperts:  4,
		TopKExperts: 2,
		RMSNormEps:  1e-6,
		RouterScale: 0.5,
	}

	// Distinct per-expert weight rows so top-k has a well-defined ordering
	// (tied scores would let argpartition pick either tied expert and make
	// the index comparison below flaky).
	projWeight := mlx.FromValues([]float32{
		0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, // expert 0
		0.30, 0.29, 0.28, 0.27, 0.26, 0.25, 0.24, 0.23, // expert 1
		-0.05, -0.06, -0.07, -0.08, -0.09, -0.10, -0.11, -0.12, // expert 2
		0.50, 0.48, 0.46, 0.44, 0.42, 0.40, 0.38, 0.36, // expert 3
	}, int(cfg.NumExperts), int(cfg.HiddenSize))

	scale := mlx.FromValues([]float32{
		1.0, 0.9, 1.1, 1.0, 1.2, 0.8, 1.0, 1.05,
	}, int(cfg.HiddenSize))

	r := &Router{
		Proj:  linearFromWeight(projWeight),
		Scale: scale,
	}

	// Varied x so different positions potentially hit different top-k.
	x := mlx.FromValues([]float32{
		0.2, -0.1, 0.3, 0.0, 0.4, -0.2, 0.1, 0.05,
		-0.3, 0.2, -0.1, 0.4, -0.05, 0.3, 0.0, 0.2,
		0.5, 0.4, -0.2, 0.1, -0.3, 0.0, 0.3, -0.1,
	}, 1, 3, int(cfg.HiddenSize))

	gotScores, gotInds := r.Forward(x, cfg)
	wantScores, wantInds := legacyRouterForward(r, x, cfg)
	mlx.Eval(gotScores, gotInds, wantScores, wantInds)

	// Indices must match exactly; scores only within float tolerance since
	// the two formulations differ in operation order.
	if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
		t.Fatalf("indices mismatch:\n  got %v\n want %v", got, want)
	}
	if got, want := gotScores.Floats(), wantScores.Floats(); !floatSlicesClose(got, want, 1e-5) {
		t.Fatalf("scores mismatch:\n  got %v\n want %v", got, want)
	}
}
|
||||
|
||||
// legacyRouterForward implements the pre-optimization router: full softmax
// over every expert, gather the top-k probabilities, then renormalize them
// to sum to 1. Algebraically identical to the fused form in Router.Forward.
// Kept only as the reference oracle for TestRouterForwardMatchesLegacy.
func legacyRouterForward(r *Router, x *mlx.Array, cfg *TextConfig) (*mlx.Array, *mlx.Array) {
	dims := x.Dims()
	BL := int32(dims[0]) * int32(dims[1])

	// Same pre-projection normalization as the optimized path.
	xFlat := mlx.Reshape(x, BL, cfg.HiddenSize)
	normed := mlx.RMSNormFn(xFlat, nil, cfg.RMSNormEps)
	normed = mlx.MulScalar(normed, cfg.RouterScale)
	normed = mlx.Mul(normed, r.Scale)

	expertScores := r.Proj.Forward(normed)
	probs := mlx.SoftmaxAxis(expertScores, -1, true)

	// Top-k by negating scores and argpartitioning the smallest k.
	neg := mlx.Neg(expertScores)
	inds := mlx.Argpartition(neg, int(cfg.TopKExperts)-1, -1)
	inds = mlx.SliceStartStop(inds,
		[]int32{0, 0},
		[]int32{BL, cfg.TopKExperts},
	)

	// Gather the selected probabilities and renormalize to sum to 1.
	scores := mlx.TakeAlongAxis(probs, inds, -1)
	sumScores := mlx.Sum(scores, -1, true)
	scores = mlx.Div(scores, sumScores)
	return scores, inds
}
|
||||
|
||||
// intSlicesEqual reports whether a and b have the same length and equal
// elements at every index.
func intSlicesEqual(a, b []int) bool {
	if len(b) != len(a) {
		return false
	}
	for i, v := range a {
		if v != b[i] {
			return false
		}
	}
	return true
}
|
||||
|
||||
// floatSlicesClose reports whether a and b have the same length and every
// pair of elements differs by at most tol (inclusive).
func floatSlicesClose(a, b []float32, tol float32) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		// |a[i]-b[i]| > tol, written without an explicit abs.
		if d := a[i] - b[i]; d > tol || -d > tol {
			return false
		}
	}
	return true
}
|
||||
|
||||
// linearFromWeight wraps a weight tensor in a *simpleLinear — a minimal,
// bias-free linear layer used only by these tests.
func linearFromWeight(w *mlx.Array) *simpleLinear {
	return &simpleLinear{weight: w}
}
|
||||
|
||||
// simpleLinear is a minimal bias-free linear layer for tests.
type simpleLinear struct {
	// weight holds the projection matrix; its first dimension is the output
	// size (see OutputDim) and it is transposed before the matmul in Forward,
	// so the layout is (outDim, inDim).
	weight *mlx.Array
}
|
||||
|
||||
func (l *simpleLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||
return x.Matmul(mlx.Transpose(l.weight, 1, 0))
|
||||
}
|
||||
|
||||
func (l *simpleLinear) OutputDim() int32 {
|
||||
return int32(l.weight.Dims()[0])
|
||||
}
|
||||
503
x/models/gemma4/gemma4_test.go
Normal file
503
x/models/gemma4/gemma4_test.go
Normal file
@@ -0,0 +1,503 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// TestParseTextConfigE2B parses an E2B-sized config — per-layer embeddings
// (PLE), a double-wide MLP, 20 KV-shared layers, and a repeating
// 4-sliding/1-full attention layer pattern — and checks the basic fields,
// the derived RoPE settings, the attention scales, and the computed
// KV-sharing donor map.
func TestParseTextConfigE2B(t *testing.T) {
	skipIfNoMLX(t)
	data := []byte(`{
	"architectures": ["Gemma4ForConditionalGeneration"],
	"text_config": {
		"hidden_size": 1536,
		"num_hidden_layers": 35,
		"intermediate_size": 6144,
		"num_attention_heads": 8,
		"num_key_value_heads": 1,
		"head_dim": 256,
		"global_head_dim": 512,
		"vocab_size": 262144,
		"rms_norm_eps": 1e-6,
		"max_position_embeddings": 131072,
		"sliding_window": 512,
		"sliding_window_pattern": 5,
		"final_logit_softcapping": 30.0,
		"use_double_wide_mlp": true,
		"num_kv_shared_layers": 20,
		"hidden_size_per_layer_input": 256,
		"vocab_size_per_layer_input": 262144,
		"attention_k_eq_v": false,
		"tie_word_embeddings": true,
		"layer_types": [
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
		],
		"rope_parameters": {
			"full_attention": {
				"partial_rotary_factor": 0.25,
				"rope_theta": 1000000.0,
				"rope_type": "proportional"
			},
			"sliding_attention": {
				"rope_theta": 10000.0,
				"rope_type": "default"
			}
		}
	}
}`)

	cfg, err := parseTextConfig(data)
	if err != nil {
		t.Fatalf("parseTextConfig failed: %v", err)
	}

	// Basic fields.
	if cfg.HiddenSize != 1536 {
		t.Errorf("HiddenSize = %d, want 1536", cfg.HiddenSize)
	}
	if cfg.NumHiddenLayers != 35 {
		t.Errorf("NumHiddenLayers = %d, want 35", cfg.NumHiddenLayers)
	}
	if cfg.GlobalHeadDim != 512 {
		t.Errorf("GlobalHeadDim = %d, want 512", cfg.GlobalHeadDim)
	}
	if cfg.FinalLogitSoftcapping != 30.0 {
		t.Errorf("FinalLogitSoftcapping = %f, want 30.0", cfg.FinalLogitSoftcapping)
	}
	if cfg.NumKVSharedLayers != 20 {
		t.Errorf("NumKVSharedLayers = %d, want 20", cfg.NumKVSharedLayers)
	}
	if cfg.HiddenSizePerLayer != 256 {
		t.Errorf("HiddenSizePerLayer = %d, want 256", cfg.HiddenSizePerLayer)
	}

	// RoPE settings derived from rope_parameters above.
	if cfg.SlidingRopeDims != 256 {
		t.Errorf("SlidingRopeDims = %d, want 256", cfg.SlidingRopeDims)
	}
	if cfg.FullRopeDims != 512 {
		t.Errorf("FullRopeDims = %d, want 512 (GlobalHeadDim, partial rotation handled via custom freqs)", cfg.FullRopeDims)
	}
	if cfg.SlidingRopeBase != 10000 {
		t.Errorf("SlidingRopeBase = %f, want 10000", cfg.SlidingRopeBase)
	}
	if cfg.FullRopeBase != 1000000 {
		t.Errorf("FullRopeBase = %f, want 1000000", cfg.FullRopeBase)
	}

	// Attention scale. Exact value is not pinned here, only that parsing
	// produced something non-zero for both attention kinds.
	if cfg.SlidingScale == 0 || cfg.FullScale == 0 {
		t.Error("attention scales should be non-zero")
	}

	// KV sharing map.
	// First shared layer is 35 - 20 = 15.
	// Sliding shared layers borrow from the last non-shared sliding layer,
	// full-attention shared layers from the last non-shared full layer.
	if donor, ok := cfg.KVShareMap[15]; !ok || donor != 13 {
		t.Errorf("KVShareMap[15] = %d, ok=%v; want 13, true", donor, ok)
	}
	if donor, ok := cfg.KVShareMap[19]; !ok || donor != 14 {
		t.Errorf("KVShareMap[19] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
	}
	if donor, ok := cfg.KVShareMap[34]; !ok || donor != 14 {
		t.Errorf("KVShareMap[34] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
	}
	// Layer 14 should not be shared.
	if _, ok := cfg.KVShareMap[14]; ok {
		t.Error("layer 14 should not be in KVShareMap (non-shared)")
	}

	// Donors.
	if !cfg.KVDonors[13] {
		t.Error("layer 13 should be a KV donor")
	}
	if !cfg.KVDonors[14] {
		t.Error("layer 14 should be a KV donor")
	}
}
|
||||
|
||||
// TestParseTextConfig26B parses a 26B-sized MoE config — K==V attention, a
// separate global-KV head count, an enabled MoE block (128 experts, top-8),
// and neither per-layer embeddings nor KV sharing — and checks that the
// MoE-specific fields come through.
func TestParseTextConfig26B(t *testing.T) {
	skipIfNoMLX(t)
	data := []byte(`{
	"architectures": ["Gemma4ForConditionalGeneration"],
	"text_config": {
		"hidden_size": 2816,
		"num_hidden_layers": 30,
		"intermediate_size": 2112,
		"num_attention_heads": 16,
		"num_key_value_heads": 8,
		"num_global_key_value_heads": 2,
		"head_dim": 256,
		"global_head_dim": 512,
		"vocab_size": 262144,
		"rms_norm_eps": 1e-6,
		"max_position_embeddings": 131072,
		"sliding_window": 1024,
		"final_logit_softcapping": 30.0,
		"use_double_wide_mlp": false,
		"num_kv_shared_layers": 0,
		"hidden_size_per_layer_input": null,
		"attention_k_eq_v": true,
		"enable_moe_block": true,
		"num_experts": 128,
		"top_k_experts": 8,
		"moe_intermediate_size": 704,
		"tie_word_embeddings": true,
		"layer_types": [
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
		],
		"rope_parameters": {
			"full_attention": {
				"partial_rotary_factor": 0.25,
				"rope_theta": 1000000.0,
				"rope_type": "proportional"
			},
			"sliding_attention": {
				"rope_theta": 10000.0,
				"rope_type": "default"
			}
		}
	}
}`)

	cfg, err := parseTextConfig(data)
	if err != nil {
		t.Fatalf("parseTextConfig failed: %v", err)
	}

	if cfg.HiddenSize != 2816 {
		t.Errorf("HiddenSize = %d, want 2816", cfg.HiddenSize)
	}
	if !cfg.AttentionKEqV {
		t.Error("AttentionKEqV should be true")
	}
	if cfg.NumGlobalKeyValueHeads != 2 {
		t.Errorf("NumGlobalKeyValueHeads = %d, want 2", cfg.NumGlobalKeyValueHeads)
	}
	// MoE fields.
	if !cfg.EnableMoeBlock {
		t.Error("EnableMoeBlock should be true")
	}
	if cfg.NumExperts != 128 {
		t.Errorf("NumExperts = %d, want 128", cfg.NumExperts)
	}
	if cfg.TopKExperts != 8 {
		t.Errorf("TopKExperts = %d, want 8", cfg.TopKExperts)
	}
	if cfg.ExpertIntermediateSize != 704 {
		t.Errorf("ExpertIntermediateSize = %d, want 704", cfg.ExpertIntermediateSize)
	}
	// hidden_size_per_layer_input is null above, so no PLE.
	if cfg.HiddenSizePerLayer != 0 {
		t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
	}
}
|
||||
|
||||
// TestParseTextConfig31B parses a 31B-sized dense config — K==V attention,
// no KV sharing, no per-layer embeddings, 60 layers in a repeating
// 5-sliding/1-full pattern — and checks the parsed fields, the empty
// KV-share map, and layer-type detection at the pattern boundaries.
func TestParseTextConfig31B(t *testing.T) {
	skipIfNoMLX(t)
	data := []byte(`{
	"architectures": ["Gemma4ForConditionalGeneration"],
	"text_config": {
		"hidden_size": 5376,
		"num_hidden_layers": 60,
		"intermediate_size": 21504,
		"num_attention_heads": 32,
		"num_key_value_heads": 16,
		"num_global_key_value_heads": 4,
		"head_dim": 256,
		"global_head_dim": 512,
		"vocab_size": 262144,
		"rms_norm_eps": 1e-6,
		"max_position_embeddings": 131072,
		"sliding_window": 1024,
		"final_logit_softcapping": 30.0,
		"use_double_wide_mlp": false,
		"num_kv_shared_layers": 0,
		"hidden_size_per_layer_input": null,
		"attention_k_eq_v": true,
		"tie_word_embeddings": true,
		"layer_types": [
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
		],
		"rope_parameters": {
			"full_attention": {
				"partial_rotary_factor": 0.25,
				"rope_theta": 1000000.0,
				"rope_type": "proportional"
			},
			"sliding_attention": {
				"rope_theta": 10000.0,
				"rope_type": "default"
			}
		}
	}
}`)

	cfg, err := parseTextConfig(data)
	if err != nil {
		t.Fatalf("parseTextConfig failed: %v", err)
	}

	if cfg.HiddenSize != 5376 {
		t.Errorf("HiddenSize = %d, want 5376", cfg.HiddenSize)
	}
	if cfg.NumHiddenLayers != 60 {
		t.Errorf("NumHiddenLayers = %d, want 60", cfg.NumHiddenLayers)
	}
	if !cfg.AttentionKEqV {
		t.Error("AttentionKEqV should be true")
	}
	if cfg.NumGlobalKeyValueHeads != 4 {
		t.Errorf("NumGlobalKeyValueHeads = %d, want 4", cfg.NumGlobalKeyValueHeads)
	}
	if cfg.NumKeyValueHeads != 16 {
		t.Errorf("NumKeyValueHeads = %d, want 16", cfg.NumKeyValueHeads)
	}
	if cfg.NumKVSharedLayers != 0 {
		t.Errorf("NumKVSharedLayers = %d, want 0", cfg.NumKVSharedLayers)
	}
	if cfg.HiddenSizePerLayer != 0 {
		t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
	}
	if cfg.SlidingWindow != 1024 {
		t.Errorf("SlidingWindow = %d, want 1024", cfg.SlidingWindow)
	}

	// KV sharing should be empty (no shared layers).
	if len(cfg.KVShareMap) != 0 {
		t.Errorf("KVShareMap should be empty, got %d entries", len(cfg.KVShareMap))
	}

	// Layer types: pattern is 5 sliding + 1 full, repeating 10 times.
	if !isLayerSliding(0, &cfg) {
		t.Error("layer 0 should be sliding")
	}
	if isLayerSliding(5, &cfg) {
		t.Error("layer 5 should be full attention")
	}
	if !isLayerSliding(6, &cfg) {
		t.Error("layer 6 should be sliding")
	}
	// 59 is the last layer and the last entry of the pattern (59 % 6 == 5).
	if isLayerSliding(59, &cfg) {
		t.Error("layer 59 should be full attention")
	}
}
|
||||
|
||||
// TestParseTextConfigE4B parses an E4B-sized config — per-layer embeddings,
// 18 KV-shared layers, no MoE, 42 layers in a repeating 5-sliding/1-full
// pattern — and checks the parsed fields, layer-type detection, and the
// KV-sharing donor map for both sliding and full-attention shared layers.
func TestParseTextConfigE4B(t *testing.T) {
	skipIfNoMLX(t)
	data := []byte(`{
	"architectures": ["Gemma4ForConditionalGeneration"],
	"text_config": {
		"hidden_size": 2560,
		"num_hidden_layers": 42,
		"intermediate_size": 10240,
		"num_attention_heads": 8,
		"num_key_value_heads": 2,
		"head_dim": 256,
		"global_head_dim": 512,
		"vocab_size": 262144,
		"rms_norm_eps": 1e-6,
		"max_position_embeddings": 131072,
		"sliding_window": 512,
		"final_logit_softcapping": 30.0,
		"use_double_wide_mlp": false,
		"num_kv_shared_layers": 18,
		"hidden_size_per_layer_input": 256,
		"vocab_size_per_layer_input": 262144,
		"attention_k_eq_v": false,
		"enable_moe_block": false,
		"tie_word_embeddings": true,
		"layer_types": [
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
			"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
		],
		"rope_parameters": {
			"full_attention": {
				"partial_rotary_factor": 0.25,
				"rope_theta": 1000000.0,
				"rope_type": "proportional"
			},
			"sliding_attention": {
				"rope_theta": 10000.0,
				"rope_type": "default"
			}
		}
	}
}`)

	cfg, err := parseTextConfig(data)
	if err != nil {
		t.Fatalf("parseTextConfig failed: %v", err)
	}

	if cfg.HiddenSize != 2560 {
		t.Errorf("HiddenSize = %d, want 2560", cfg.HiddenSize)
	}
	if cfg.NumHiddenLayers != 42 {
		t.Errorf("NumHiddenLayers = %d, want 42", cfg.NumHiddenLayers)
	}
	if cfg.IntermediateSize != 10240 {
		t.Errorf("IntermediateSize = %d, want 10240", cfg.IntermediateSize)
	}
	if cfg.NumKeyValueHeads != 2 {
		t.Errorf("NumKeyValueHeads = %d, want 2", cfg.NumKeyValueHeads)
	}
	if cfg.UseDoubleWideMLP {
		t.Error("UseDoubleWideMLP should be false")
	}
	if cfg.NumKVSharedLayers != 18 {
		t.Errorf("NumKVSharedLayers = %d, want 18", cfg.NumKVSharedLayers)
	}
	if cfg.HiddenSizePerLayer != 256 {
		t.Errorf("HiddenSizePerLayer = %d, want 256 (has PLE)", cfg.HiddenSizePerLayer)
	}
	if cfg.AttentionKEqV {
		t.Error("AttentionKEqV should be false")
	}
	if cfg.EnableMoeBlock {
		t.Error("EnableMoeBlock should be false")
	}
	if cfg.SlidingWindow != 512 {
		t.Errorf("SlidingWindow = %d, want 512", cfg.SlidingWindow)
	}

	// Layer types: pattern is 5 sliding + 1 full, repeating 7 times = 42 layers.
	if !isLayerSliding(0, &cfg) {
		t.Error("layer 0 should be sliding")
	}
	if isLayerSliding(5, &cfg) {
		t.Error("layer 5 should be full attention")
	}
	if !isLayerSliding(6, &cfg) {
		t.Error("layer 6 should be sliding")
	}
	if isLayerSliding(41, &cfg) {
		t.Error("layer 41 should be full attention")
	}

	// KV sharing: first shared = 42 - 18 = 24.
	// Layer 24 is sliding, its donor should be the last non-shared sliding layer.
	// Non-shared layers: 0-23. Last sliding in 0-23 is layer 22 (23=full).
	// The exact donor is only logged, not asserted, here.
	if donor, ok := cfg.KVShareMap[24]; !ok {
		t.Error("layer 24 should be in KVShareMap")
	} else {
		t.Logf("layer 24 donor = %d", donor)
	}
	// Layer 29 is full_attention (5th full), donor should be the last non-shared full layer.
	// Non-shared full layers: 5, 11, 17, 23.
	if donor, ok := cfg.KVShareMap[29]; !ok || donor != 23 {
		t.Errorf("KVShareMap[29] = %d, ok=%v; want 23, true (full attn donor)", donor, ok)
	}
	// Layer 23 should NOT be shared (it's the last non-shared layer).
	if _, ok := cfg.KVShareMap[23]; ok {
		t.Error("layer 23 should not be in KVShareMap (non-shared)")
	}
}
|
||||
|
||||
func TestLayerTypeDetection(t *testing.T) {
|
||||
cfg := &TextConfig{
|
||||
LayerTypes: []string{
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
},
|
||||
}
|
||||
|
||||
if !isLayerSliding(0, cfg) {
|
||||
t.Error("layer 0 should be sliding")
|
||||
}
|
||||
if !isLayerSliding(3, cfg) {
|
||||
t.Error("layer 3 should be sliding")
|
||||
}
|
||||
if isLayerSliding(4, cfg) {
|
||||
t.Error("layer 4 should be full attention")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesOmitsSharedKVLayers(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []*DecoderLayer{
|
||||
{IsSliding: true, KVShareDonor: -1},
|
||||
{IsSliding: false, KVShareDonor: -1},
|
||||
{IsSliding: true, KVShareDonor: 0},
|
||||
{IsSliding: false, KVShareDonor: 1},
|
||||
},
|
||||
TextConfig: &TextConfig{SlidingWindow: 512},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if got, want := len(caches), 2; got != want {
|
||||
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []*DecoderLayer{
|
||||
{IsSliding: true, KVShareDonor: -1},
|
||||
{IsSliding: false, KVShareDonor: -1},
|
||||
{IsSliding: true, KVShareDonor: -1},
|
||||
},
|
||||
TextConfig: &TextConfig{SlidingWindow: 512},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if got, want := len(caches), len(m.Layers); got != want {
|
||||
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWeightPrefix(t *testing.T) {
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
wantPfx string
|
||||
}{
|
||||
{"bare", "embed_tokens.weight", ""},
|
||||
{"language_model", "model.language_model.embed_tokens.weight", "model.language_model."},
|
||||
{"with_model", "model.embed_tokens.weight", "model."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dummy := mlx.FromValue(float32(1.0))
|
||||
mlx.Eval(dummy)
|
||||
tensors := map[string]*mlx.Array{tt.key: dummy}
|
||||
got := resolveWeightPrefix(tensors)
|
||||
if got != tt.wantPfx {
|
||||
t.Errorf("resolveWeightPrefix(%q) = %q, want %q", tt.key, got, tt.wantPfx)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user