gemma4: render differently based on model size

Following up on #15560, this change now has e2b/e4b render differently
from 26b/31b.

For backwards compatibility, we take the existing renderer name `gemma4`
and make it do dynamic resolution based on the model name/size, but the
intended use is for the models to be republished with the renderer
variant specified explicitly: `gemma4-small` or `gemma4-large`.
This commit is contained in:
Devon Rifkin
2026-04-15 14:37:16 -07:00
parent 06ae6367bd
commit e585ecd11f
11 changed files with 399 additions and 16 deletions

View File

@@ -12,7 +12,8 @@ import (
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
// <|tool_call>/<|tool_response> tags for function calling.
type Gemma4Renderer struct {
useImgTags bool
useImgTags bool
emptyBlockOnNothink bool
}
const (
@@ -124,6 +125,9 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
// Generation prompt.
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
sb.WriteString("<|turn>model\n")
if r.emptyBlockOnNothink {
sb.WriteString("<|channel>thought\n<channel|>")
}
}
return sb.String(), nil

View File

@@ -3,9 +3,9 @@ package renderers
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
// Gemma 4 reference template.
//
// Current upstream Gemma 4 chat templates differ by model size, so the checked-in
// reference intentionally uses the shared baseline without an empty generation-time
// thought channel until renderer selection is split by size.
// Current upstream Gemma 4 chat templates differ by model size. The checked-in
// reference cases below use the small (e2b/e4b-style) baseline, with large
// (26b/31b-style) checks covered separately in this file.
//
// To regenerate expected values, save the E2B template to
// gemma4_e2b_chat_template.jinja2 and run:
@@ -1474,6 +1474,40 @@ Hi<turn|>
}
}
func TestGemma4RendererVariantsMatchExpectedGenerationPrompt(t *testing.T) {
messages := []api.Message{{Role: "user", Content: "Hello"}}
tests := []struct {
name string
rendererName string
expected string
}{
{
name: "legacy_alias",
rendererName: "gemma4",
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "small",
rendererName: "gemma4-small",
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "large",
rendererName: "gemma4-large",
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := RenderWithRenderer(tt.rendererName, messages, nil, nil)
assert.NoError(t, err)
assert.Equal(t, tt.expected, got)
})
}
}
func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
if os.Getenv("VERIFY_JINJA2") == "" {
t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks")
@@ -1616,15 +1650,35 @@ func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
assert.NoError(t, err)
variants := []struct {
name string
renderer *Gemma4Renderer
templateRel string
}{
{
name: "small",
renderer: &Gemma4Renderer{useImgTags: RenderImgTags},
templateRel: gemma4E2BTemplate,
},
{
name: "large",
renderer: &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true},
templateRel: gemma431BTemplate,
},
}
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think)
assert.Equal(t, jinja2Output, got,
"renderer output doesn't match Jinja2 template output")
for _, variant := range variants {
t.Run(variant.name, func(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := variant.renderer.Render(tt.messages, tt.tools, tt.think)
assert.NoError(t, err)
jinja2Output := renderWithJinja2Template(t, variant.templateRel, tt.messages, tt.tools, tt.think)
assert.Equal(t, jinja2Output, got,
"renderer output doesn't match Jinja2 template output")
})
}
})
}
}

View File

@@ -81,8 +81,10 @@ func rendererForName(name string) Renderer {
return renderer
case "nemotron-3-nano":
return &Nemotron3NanoRenderer{}
case "gemma4":
case "gemma4", "gemma4-small":
return &Gemma4Renderer{useImgTags: RenderImgTags}
case "gemma4-large":
return &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true}
case "functiongemma":
return &FunctionGemmaRenderer{}
case "glm-4.7":

View File

@@ -523,7 +523,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
arch := layer.GGML.KV().Architecture()
switch arch {
case "gemma4":
config.Renderer = cmp.Or(config.Renderer, "gemma4")
config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy)
config.Parser = cmp.Or(config.Parser, "gemma4")
if _, ok := r.Parameters["stop"]; !ok {
if r.Parameters == nil {

78
server/gemma4_test.go Normal file
View File

@@ -0,0 +1,78 @@
package server
import "testing"
func TestResolveGemma4Renderer(t *testing.T) {
tests := []struct {
name string
model *Model
want string
}{
{
name: "nil model falls back to legacy alias",
model: nil,
want: gemma4RendererLegacy,
},
{
name: "explicit small passes through",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererSmall),
},
want: gemma4RendererSmall,
},
{
name: "explicit large passes through",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererLarge),
},
want: gemma4RendererLarge,
},
{
name: "legacy e4b tag resolves small",
model: &Model{
Name: "gemma4:e4b",
ShortName: "gemma4:e4b",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererSmall,
},
{
name: "legacy 31b tag resolves large",
model: &Model{
Name: "gemma4:31b-cloud",
ShortName: "gemma4:31b-cloud",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererLarge,
},
{
name: "legacy model type resolves small",
model: &Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"),
},
want: gemma4RendererSmall,
},
{
name: "legacy model type resolves large",
model: &Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
},
want: gemma4RendererLarge,
},
{
name: "legacy unknown defaults small",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererSmall,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveGemma4Renderer(tt.model); got != tt.want {
t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -156,7 +156,7 @@ func (m *Model) Capabilities() []model.Capability {
// Temporary workaround — suppress vision/audio for gemma4 MLX models
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" {
if m.Config.ModelFormat == "safetensors" && isGemma4Renderer(m.Config.Renderer) {
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
return c == model.CapabilityVision || c == "audio"
})

View File

@@ -118,6 +118,39 @@ func TestModelCapabilities(t *testing.T) {
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
{
name: "gemma4 small safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererSmall,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
{
name: "gemma4 large safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererLarge,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
{
name: "legacy gemma4 safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererLegacy,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
}
// compare two slices of model.Capability regardless of order

View File

@@ -115,7 +115,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
if m.Config.Renderer != "" {
rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
rendererName := resolveRendererName(m)
rendered, err := renderers.RenderWithRenderer(rendererName, msgs, tools, think)
if err != nil {
return "", err
}

View File

@@ -13,6 +13,14 @@ import (
"github.com/ollama/ollama/types/model"
)
func testConfigWithRenderer(renderer string) model.ConfigV2 {
return model.ConfigV2{Renderer: renderer}
}
func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 {
return model.ConfigV2{Renderer: renderer, ModelType: modelType}
}
func TestChatPrompt(t *testing.T) {
type expect struct {
prompt string
@@ -397,3 +405,43 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
}
}
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
msgs := []api.Message{{Role: "user", Content: "Hello"}}
tests := []struct {
name string
model Model
want string
}{
{
name: "small from name",
model: Model{
Name: "gemma4:e4b",
ShortName: "gemma4:e4b",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "large from model type",
model: Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
},
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := renderPrompt(&tt.model, msgs, nil, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -0,0 +1,110 @@
package server
import (
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
const (
gemma4RendererLegacy = "gemma4"
gemma4RendererSmall = "gemma4-small"
gemma4RendererLarge = "gemma4-large"
// Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the
// large template. Default to the small prompt unless the model is clearly in
// the large range.
gemma4LargeMinParameterCount = 16_000_000_000
)
func resolveRendererName(m *Model) string {
if m == nil || m.Config.Renderer == "" {
return ""
}
switch m.Config.Renderer {
case gemma4RendererLegacy:
return resolveGemma4Renderer(m)
default:
return m.Config.Renderer
}
}
func resolveGemma4Renderer(m *Model) string {
if m == nil || m.Config.Renderer != gemma4RendererLegacy {
if m == nil {
return gemma4RendererLegacy
}
return m.Config.Renderer
}
if renderer, ok := gemma4RendererFromName(m.ShortName); ok {
return renderer
}
if renderer, ok := gemma4RendererFromName(m.Name); ok {
return renderer
}
if parameterCount, ok := parseHumanParameterCount(m.Config.ModelType); ok {
return gemma4RendererForParameterCount(parameterCount)
}
return gemma4RendererSmall
}
func gemma4RendererForParameterCount(parameterCount uint64) string {
if parameterCount >= gemma4LargeMinParameterCount {
return gemma4RendererLarge
}
return gemma4RendererSmall
}
func gemma4RendererFromName(name string) (string, bool) {
lower := strings.ToLower(name)
switch {
case strings.Contains(lower, "e2b"), strings.Contains(lower, "e4b"):
return gemma4RendererSmall, true
case strings.Contains(lower, "26b"), strings.Contains(lower, "31b"):
return gemma4RendererLarge, true
default:
return "", false
}
}
func parseHumanParameterCount(s string) (uint64, bool) {
if s == "" {
return 0, false
}
unit := strings.ToUpper(s[len(s)-1:])
var multiplier float64
switch unit {
case "B":
multiplier = float64(format.Billion)
case "M":
multiplier = float64(format.Million)
case "K":
multiplier = float64(format.Thousand)
default:
return 0, false
}
value, err := strconv.ParseFloat(s[:len(s)-1], 64)
if err != nil {
return 0, false
}
return uint64(value * multiplier), true
}
func isGemma4Renderer(renderer string) bool {
switch renderer {
case gemma4RendererLegacy, gemma4RendererSmall, gemma4RendererLarge:
return true
default:
return false
}
}

View File

@@ -928,6 +928,59 @@ func TestCreateDetectTemplate(t *testing.T) {
})
}
func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "gemma4",
"general.parameter_count": uint64(25_200_000_000),
}, nil)
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
mf, err := manifest.ParseNamedManifest(model.ParseName("test"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
if mf.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for manifest")
}
configPath, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}
cfgFile, err := os.Open(configPath)
if err != nil {
t.Fatalf("open config blob: %v", err)
}
defer cfgFile.Close()
var cfg model.ConfigV2
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
t.Fatalf("decode config: %v", err)
}
if cfg.Renderer != gemma4RendererLegacy {
t.Fatalf("expected renderer %q, got %q", gemma4RendererLegacy, cfg.Renderer)
}
if cfg.Parser != "gemma4" {
t.Fatalf("expected parser %q, got %q", "gemma4", cfg.Parser)
}
}
func TestDetectModelTypeFromFiles(t *testing.T) {
t.Run("gguf file", func(t *testing.T) {
_, digest := createBinFile(t, nil, nil)