From e585ecd11f5e90f8b02d224122430049c0d56d8e Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Wed, 15 Apr 2026 14:37:16 -0700 Subject: [PATCH] gemma4: render differently based on model size Following up on #15560, this change now has e2b/e4b render differently from 26b/31b. For backwards compatibility, we take the existing renderer name `gemma4` and make it do dynamic resolution based on the model name/size, but the intended use is for the models to be republished with the renderer variant specified explicitly: `gemma4-small` or `gemma4-large`. --- model/renderers/gemma4.go | 6 +- model/renderers/gemma4_reference_test.go | 76 +++++++++++++--- model/renderers/renderer.go | 4 +- server/create.go | 2 +- server/gemma4_test.go | 78 ++++++++++++++++ server/images.go | 2 +- server/images_test.go | 33 +++++++ server/prompt.go | 3 +- server/prompt_test.go | 48 ++++++++++ server/renderer_resolution.go | 110 +++++++++++++++++++++++ server/routes_create_test.go | 53 +++++++++++ 11 files changed, 399 insertions(+), 16 deletions(-) create mode 100644 server/gemma4_test.go create mode 100644 server/renderer_resolution.go diff --git a/model/renderers/gemma4.go b/model/renderers/gemma4.go index 82f5fe5b1..59133527f 100644 --- a/model/renderers/gemma4.go +++ b/model/renderers/gemma4.go @@ -12,7 +12,8 @@ import ( // <|turn>/ markers, <|"|> string delimiters, and <|tool>/ // <|tool_call>/<|tool_response> tags for function calling. type Gemma4Renderer struct { - useImgTags bool + useImgTags bool + emptyBlockOnNothink bool } const ( @@ -124,6 +125,9 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV // Generation prompt. if prevMessageType != "tool_response" && prevMessageType != "tool_call" { sb.WriteString("<|turn>model\n") + if r.emptyBlockOnNothink { + sb.WriteString("<|channel>thought\n") + } } return sb.String(), nil diff --git a/model/renderers/gemma4_reference_test.go b/model/renderers/gemma4_reference_test.go index 5c6458ffd..d65a061e4 100644 --- a/model/renderers/gemma4_reference_test.go +++ b/model/renderers/gemma4_reference_test.go @@ -3,9 +3,9 @@ package renderers // TestGemma4RendererMatchesReference verifies our renderer matches the checked-in // Gemma 4 reference template. // -// Current upstream Gemma 4 chat templates differ by model size, so the checked-in -// reference intentionally uses the shared baseline without an empty generation-time -// thought channel until renderer selection is split by size. +// Current upstream Gemma 4 chat templates differ by model size. The checked-in +// reference cases below use the small (e2b/e4b-style) baseline, with large +// (26b/31b-style) checks covered separately in this file. // // To regenerate expected values, save the E2B template to // gemma4_e2b_chat_template.jinja2 and run: @@ -1474,6 +1474,40 @@ Hi } } +func TestGemma4RendererVariantsMatchExpectedGenerationPrompt(t *testing.T) { + messages := []api.Message{{Role: "user", Content: "Hello"}} + + tests := []struct { + name string + rendererName string + expected string + }{ + { + name: "legacy_alias", + rendererName: "gemma4", + expected: "<|turn>user\nHello\n<|turn>model\n", + }, + { + name: "small", + rendererName: "gemma4-small", + expected: "<|turn>user\nHello\n<|turn>model\n", + }, + { + name: "large", + rendererName: "gemma4-large", + expected: "<|turn>user\nHello\n<|turn>model\n<|channel>thought\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := RenderWithRenderer(tt.rendererName, messages, nil, nil) + assert.NoError(t, err) + assert.Equal(t, tt.expected, got) + }) + } +} + func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) { if os.Getenv("VERIFY_JINJA2") == "" { t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks") @@ -1616,15 +1650,35 @@ func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) { }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - renderer := &Gemma4Renderer{useImgTags: RenderImgTags} - got, err := renderer.Render(tt.messages, tt.tools, tt.think) - assert.NoError(t, err) + variants := []struct { + name string + renderer *Gemma4Renderer + templateRel string + }{ + { + name: "small", + renderer: &Gemma4Renderer{useImgTags: RenderImgTags}, + templateRel: gemma4E2BTemplate, + }, + { + name: "large", + renderer: &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true}, + templateRel: gemma431BTemplate, + }, + } - jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think) - assert.Equal(t, jinja2Output, got, - "renderer output doesn't match Jinja2 template output") + for _, variant := range variants { + t.Run(variant.name, func(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := variant.renderer.Render(tt.messages, tt.tools, tt.think) + assert.NoError(t, err) + + jinja2Output := renderWithJinja2Template(t, variant.templateRel, tt.messages, tt.tools, tt.think) + assert.Equal(t, jinja2Output, got, + "renderer output doesn't match Jinja2 template output") + }) + } }) } } diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index f63eb36fd..84cc78f8d 100644 --- a/model/renderers/renderer.go +++ b/model/renderers/renderer.go @@ -81,8 +81,10 @@ func rendererForName(name string) Renderer { return renderer case "nemotron-3-nano": return &Nemotron3NanoRenderer{} - case "gemma4": + case "gemma4", "gemma4-small": return &Gemma4Renderer{useImgTags: RenderImgTags} + case "gemma4-large": + return &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true} case "functiongemma": return &FunctionGemmaRenderer{} case "glm-4.7": diff --git a/server/create.go b/server/create.go index 01fbe5738..9ddb2bf8b 100644 --- a/server/create.go +++ b/server/create.go @@ -523,7 +523,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, arch := layer.GGML.KV().Architecture() switch arch { case "gemma4": - config.Renderer = cmp.Or(config.Renderer, "gemma4") + config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy) config.Parser = cmp.Or(config.Parser, "gemma4") if _, ok := r.Parameters["stop"]; !ok { if r.Parameters == nil { diff --git a/server/gemma4_test.go b/server/gemma4_test.go new file mode 100644 index 000000000..ddcc5b6fe --- /dev/null +++ b/server/gemma4_test.go @@ -0,0 +1,78 @@ +package server + +import "testing" + +func TestResolveGemma4Renderer(t *testing.T) { + tests := []struct { + name string + model *Model + want string + }{ + { + name: "nil model falls back to legacy alias", + model: nil, + want: gemma4RendererLegacy, + }, + { + name: "explicit small passes through", + model: &Model{ + Config: testConfigWithRenderer(gemma4RendererSmall), + }, + want: gemma4RendererSmall, + }, + { + name: "explicit large passes through", + model: &Model{ + Config: testConfigWithRenderer(gemma4RendererLarge), + }, + want: gemma4RendererLarge, + }, + { + name: "legacy e4b tag resolves small", + model: &Model{ + Name: "gemma4:e4b", + ShortName: "gemma4:e4b", + Config: testConfigWithRenderer(gemma4RendererLegacy), + }, + want: gemma4RendererSmall, + }, + { + name: "legacy 31b tag resolves large", + model: &Model{ + Name: "gemma4:31b-cloud", + ShortName: "gemma4:31b-cloud", + Config: testConfigWithRenderer(gemma4RendererLegacy), + }, + want: gemma4RendererLarge, + }, + { + name: "legacy model type resolves small", + model: &Model{ + Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"), + }, + want: gemma4RendererSmall, + }, + { + name: "legacy model type resolves large", + model: &Model{ + Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"), + }, + want: gemma4RendererLarge, + }, + { + name: "legacy unknown defaults small", + model: &Model{ + Config: testConfigWithRenderer(gemma4RendererLegacy), + }, + want: gemma4RendererSmall, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := resolveGemma4Renderer(tt.model); got != tt.want { + t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/server/images.go b/server/images.go index a7fce62dc..110a7dc6e 100644 --- a/server/images.go +++ b/server/images.go @@ -156,7 +156,7 @@ func (m *Model) Capabilities() []model.Capability { // Temporary workaround — suppress vision/audio for gemma4 MLX models // until multimodal runtime pipeline lands. Remove when imageproc.go is wired up. - if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" { + if m.Config.ModelFormat == "safetensors" && isGemma4Renderer(m.Config.Renderer) { capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool { return c == model.CapabilityVision || c == "audio" }) diff --git a/server/images_test.go b/server/images_test.go index 88f1d07c5..d7f99afc8 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -118,6 +118,39 @@ func TestModelCapabilities(t *testing.T) { }, expectedCaps: []model.Capability{model.CapabilityEmbedding}, }, + { + name: "gemma4 small safetensors suppresses vision and audio", + model: Model{ + Config: model.ConfigV2{ + ModelFormat: "safetensors", + Renderer: gemma4RendererSmall, + Capabilities: []string{"vision", "audio"}, + }, + Template: chatTemplate, + }, + }, + { + name: "gemma4 large safetensors suppresses vision and audio", + model: Model{ + Config: model.ConfigV2{ + ModelFormat: "safetensors", + Renderer: gemma4RendererLarge, + Capabilities: []string{"vision", "audio"}, + }, + Template: chatTemplate, + }, + }, + { + name: "legacy gemma4 safetensors suppresses vision and audio", + model: Model{ + Config: model.ConfigV2{ + ModelFormat: "safetensors", + Renderer: gemma4RendererLegacy, + Capabilities: []string{"vision", "audio"}, + }, + Template: chatTemplate, + }, + }, } // compare two slices of model.Capability regardless of order diff --git a/server/prompt.go b/server/prompt.go index 0737fa215..8fa164557 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -115,7 +115,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { if m.Config.Renderer != "" { - rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think) + rendererName := resolveRendererName(m) + rendered, err := renderers.RenderWithRenderer(rendererName, msgs, tools, think) if err != nil { return "", err } diff --git a/server/prompt_test.go b/server/prompt_test.go index 8bbadb22d..e4cc27a5a 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -13,6 +13,14 @@ import ( "github.com/ollama/ollama/types/model" ) +func testConfigWithRenderer(renderer string) model.ConfigV2 { + return model.ConfigV2{Renderer: renderer} +} + +func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 { + return model.ConfigV2{Renderer: renderer, ModelType: modelType} +} + func TestChatPrompt(t *testing.T) { type expect struct { prompt string @@ -397,3 +405,43 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) { t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt) } } + +func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) { + msgs := []api.Message{{Role: "user", Content: "Hello"}} + + tests := []struct { + name string + model Model + want string + }{ + { + name: "small from name", + model: Model{ + Name: "gemma4:e4b", + ShortName: "gemma4:e4b", + Config: testConfigWithRenderer(gemma4RendererLegacy), + }, + want: "<|turn>user\nHello\n<|turn>model\n", + }, + { + name: "large from model type", + model: Model{ + Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"), + }, + want: "<|turn>user\nHello\n<|turn>model\n<|channel>thought\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := renderPrompt(&tt.model, msgs, nil, nil) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/server/renderer_resolution.go b/server/renderer_resolution.go new file mode 100644 index 000000000..870f700e4 --- /dev/null +++ b/server/renderer_resolution.go @@ -0,0 +1,110 @@ +package server + +import ( + "strconv" + "strings" + + "github.com/ollama/ollama/format" +) + +const ( + gemma4RendererLegacy = "gemma4" + gemma4RendererSmall = "gemma4-small" + gemma4RendererLarge = "gemma4-large" + + // Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the + // large template. Default to the small prompt unless the model is clearly in + // the large range. + gemma4LargeMinParameterCount = 16_000_000_000 +) + +func resolveRendererName(m *Model) string { + if m == nil || m.Config.Renderer == "" { + return "" + } + + switch m.Config.Renderer { + case gemma4RendererLegacy: + return resolveGemma4Renderer(m) + default: + return m.Config.Renderer + } +} + +func resolveGemma4Renderer(m *Model) string { + if m == nil || m.Config.Renderer != gemma4RendererLegacy { + if m == nil { + return gemma4RendererLegacy + } + return m.Config.Renderer + } + + if renderer, ok := gemma4RendererFromName(m.ShortName); ok { + return renderer + } + + if renderer, ok := gemma4RendererFromName(m.Name); ok { + return renderer + } + + if parameterCount, ok := parseHumanParameterCount(m.Config.ModelType); ok { + return gemma4RendererForParameterCount(parameterCount) + } + + return gemma4RendererSmall +} + +func gemma4RendererForParameterCount(parameterCount uint64) string { + if parameterCount >= gemma4LargeMinParameterCount { + return gemma4RendererLarge + } + + return gemma4RendererSmall +} + +func gemma4RendererFromName(name string) (string, bool) { + lower := strings.ToLower(name) + switch { + case strings.Contains(lower, "e2b"), strings.Contains(lower, "e4b"): + return gemma4RendererSmall, true + case strings.Contains(lower, "26b"), strings.Contains(lower, "31b"): + return gemma4RendererLarge, true + default: + return "", false + } +} + +func parseHumanParameterCount(s string) (uint64, bool) { + if s == "" { + return 0, false + } + + unit := strings.ToUpper(s[len(s)-1:]) + var multiplier float64 + switch unit { + case "B": + multiplier = float64(format.Billion) + case "M": + multiplier = float64(format.Million) + case "K": + multiplier = float64(format.Thousand) + default: + return 0, false + } + + value, err := strconv.ParseFloat(s[:len(s)-1], 64) + if err != nil { + return 0, false + } + + return uint64(value * multiplier), true +} + +func isGemma4Renderer(renderer string) bool { + switch renderer { + case gemma4RendererLegacy, gemma4RendererSmall, gemma4RendererLarge: + return true + default: + return false + } +} diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 75bdac73b..3a4dfb6dc 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -928,6 +928,59 @@ func TestCreateDetectTemplate(t *testing.T) { }) } +func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) { + gin.SetMode(gin.TestMode) + + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + _, digest := createBinFile(t, ggml.KV{ + "general.architecture": "gemma4", + "general.parameter_count": uint64(25_200_000_000), + }, nil) + + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test", + Files: map[string]string{"test.gguf": digest}, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + mf, err := manifest.ParseNamedManifest(model.ParseName("test")) + if err != nil { + t.Fatalf("parse manifest: %v", err) + } + if mf.Config.Digest == "" { + t.Fatalf("unexpected empty config digest for manifest") + } + + configPath, err := manifest.BlobsPath(mf.Config.Digest) + if err != nil { + t.Fatalf("config blob path: %v", err) + } + + cfgFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("open config blob: %v", err) + } + defer cfgFile.Close() + + var cfg model.ConfigV2 + if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil { + t.Fatalf("decode config: %v", err) + } + + if cfg.Renderer != gemma4RendererLegacy { + t.Fatalf("expected renderer %q, got %q", gemma4RendererLegacy, cfg.Renderer) + } + if cfg.Parser != "gemma4" { + t.Fatalf("expected parser %q, got %q", "gemma4", cfg.Parser) + } +} + func TestDetectModelTypeFromFiles(t *testing.T) { t.Run("gguf file", func(t *testing.T) { _, digest := createBinFile(t, nil, nil)