Add support for gemma4 (#15214)

* bench: add prompt calibration, context size flag, and NumCtx reporting

Add a --num-ctx flag to set the context size, and report NumCtx in the
model info header. Calibrate the tokens-per-word ratio during warmup
using actual tokenization metrics from the model, replacing the fixed
1.3 heuristic. This produces more accurate prompt token counts for
--prompt-tokens.

Also add fetchContextLength() to query running model context via /api/ps.

* integration: improve vision test robustness and add thinking tests

Add skipIfNoVisionOverride() to skip vision tests when OLLAMA_TEST_MODEL
is set to a non-vision model. Add Think:false to the context exhaustion
test to prevent thinking models from consuming all context before the
test can measure it. Add a third test image (ollama homepage) and replace
the OCR test with an ImageDescription test that uses it. Relax match
strings for broader model compatibility. Add TestThinkingEnabled and
TestThinkingSuppressed to verify thinking output and channel tag handling.

* gemma4: add Gemma 4 GGML model support

Add full Gemma 4 model family support (E2B, E4B, 26B MoE, 31B Dense)
for the GGML backend including text, vision, converter, parser, and
renderer.

Text model features:
- Sliding window + full attention with per-layer patterns
- KV sharing across layers with donor map
- Per-layer embeddings (PLE) with learned projections
- MoE routing with RMSNorm + learned scale
- Proportional RoPE with freq_factors for global attention
- Final logit softcapping

Vision model features:
- SigLIP vision encoder with 2D RoPE
- ClippableLinear with input/output clamping via packed v.clamp_data
- Adaptive average pooling with nMerge kernel
- Multi-modal projection with unweighted RMSNorm

Converter:
- Safetensors to GGUF with vision tensor renaming
- Fused MoE gate_up_proj splitting
- Vision patch embedding reshape (HF to Conv2D layout)
- Packed clamp data tensor for ClippableLinear bounds
- Proportional RoPE freq_factors generation

Also includes:
- BackendGet() on ml.Tensor for reading weight tensor data
- Q6_K CUDA get_rows kernel support
- MoE-aware ffn_down quantization layer counting
- Gemma4 parser with tool calling and thinking support
- Gemma4 renderer with structured tool format
- Architecture-based auto-detection of renderer/parser/stop tokens
- Integration test gemma4 model list additions

* gemma4: add audio support with USM conformer encoder

Add audio encoding for Gemma 4 using the USM conformer architecture:
- Converter: audio tensor mapping, SSCP/conformer/embedder name replacements,
  softplus repacker for per_dim_scale, F32 enforcement for conv weights
- GGML backend: Conv1DDW and PadExt tensor ops
- Audio encoder: SSCP Conv2D, 12 conformer blocks (FFW + block-local
  attention with relative position embeddings + LightConv1d + FFW),
  output projection, audio-to-text embedding projector
- Audio preprocessing: WAV decode, mel spectrogram, FFT (pure Go)
- Model wiring: WAV detection, audio token handling, unified PostTokenize

Correctly transcribes "why is the sky blue" from test audio.

* integration: add gemma4 audio tests including OpenAI API coverage

Test audio transcription and response via the Ollama native API, plus
two new tests exercising the OpenAI-compatible endpoints:
- /v1/audio/transcriptions (multipart form upload)
- /v1/chat/completions with input_audio content type

All tests use capability checks and skip models without audio support.

* gemma4: add OpenAI audio API support and capability detection

- Add CapabilityAudio and detect from audio.block_count in GGUF
- Add /v1/audio/transcriptions endpoint with TranscriptionMiddleware
- Add input_audio content type support in /v1/chat/completions
- Add TranscriptionRequest/Response types in openai package

* gemma4: add audio input support for run command

- /audio toggle in interactive mode for voice chat
- Platform-specific microphone recording (AVFoundation on macOS,
  PulseAudio/ALSA on Linux, WASAPI on Windows)
- Space to start/stop recording, automatic chunking for long audio

* gemma4: add transcribe command (ollama transcribe MODEL)

- Interactive mode with readline prompt and slash commands
- Non-interactive mode for piped audio or record-until-Ctrl+C
- Chunked streaming transcription for long recordings
- Word-wrapped output matching run command style

* gemma4: add parser, renderer, and integration test plumbing

* gemma4: fix renderer to emit BOS token

* gemma4: add OpenAI audio transcription API and input_audio support

* gemma4: update converter for new weight drop naming

* gemma4: add per_expert_scale to MoE router and fix moe_intermediate_size config

* gemma4: rewrite renderer to match HF Jinja2 template exactly

Fix 8 bugs found by building 55 reference tests verified against the
HF Jinja2 chat template (VERIFY_JINJA2=1 shells out to Python):

- Tool responses use separate <|turn>tool turns (not inline tags)
- Tool calls emitted before content in assistant messages
- Thinking content stripped from assistant history (strip_thinking)
- User, tool, and system content trimmed (template does | trim)
- Empty system message still emits system turn (check role, not content)
- Nested object properties rendered recursively with required field
- Array items specification rendered for array-type properties
- OBJECT/ARRAY type-specific rendering comma logic matches template

Also adds Required field to api.ToolProperty for nested object schemas,
replaces old gemma4_test.go with comprehensive gemma4_reference_test.go,
and commits the Jinja2 template as testdata for verification.

* gemma4: fix MoE fused gate_up split and multiline tool-call arg parsing

- Text MoE: split `ffn_gate_up_exps` into contiguous `[gate|up]` halves instead of stride-2 slices.
- Parser: escape control characters in `<|"|>...<|"|>` string literals when converting tool-call args to JSON.
- Fixes warnings like `invalid character '\n' in string literal` for multiline tool arguments.
- Add Gemma4 parser regression tests for multiline tool-call args and `gemma4ArgsToJSON`.
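The escaping step above can be sketched as follows. This is a simplified stand-in, not the repository's `gemma4ArgsToJSON`: it only shows how raw control characters in a parsed string literal get escaped so the result is a valid JSON string.

```go
package main

import (
	"fmt"
	"strings"
)

// escapeControl escapes characters that are illegal inside a JSON string
// literal, so a raw multiline tool argument can be embedded verbatim.
// Simplified illustration of the parser fix; the real converter handles
// the full <|"|>...<|"|> grammar.
func escapeControl(s string) string {
	var b strings.Builder
	for _, r := range s {
		switch r {
		case '\\':
			b.WriteString(`\\`)
		case '"':
			b.WriteString(`\"`)
		case '\n':
			b.WriteString(`\n`)
		case '\r':
			b.WriteString(`\r`)
		case '\t':
			b.WriteString(`\t`)
		default:
			if r < 0x20 { // remaining control characters
				b.WriteString(fmt.Sprintf(`\u%04x`, r))
			} else {
				b.WriteRune(r)
			}
		}
	}
	return b.String()
}

func main() {
	raw := "line one\nline two\t\"quoted\""
	fmt.Println(`{"text": "` + escapeControl(raw) + `"}`)
}
```

Without this step, a raw `\n` lands unescaped inside the JSON string, producing exactly the `invalid character '\n' in string literal` warning noted above.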

* cmd: simplify audio input to dropped file attachments

* gemma4: use full SWA memory for better cache reuse

* gemma4: initialize clamps after backend load

* convert: align gemma4 audio tensor renames with llama.cpp

* Remove redundant comments in gemma4 vision model

* Format Gemma4 MoE block field alignment

* use 4096 for kvcache.NewSWAMemCache

* convert: support new Gemma4 audio_tower tensor naming (#15221)

Co-authored-by: jmorganca <jmorganca@gmail.com>

* fix integration test defaults for audio

* review comments and lint fixes

* remove unused audio/video files

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
Author: Daniel Hiltgen
Date: 2026-04-02 11:33:33 -07:00
Committed by: GitHub
Parent: 79865e6c5a
Commit: 96b202d34b
52 changed files with 7196 additions and 51 deletions


@@ -436,6 +436,7 @@ type ToolProperty struct {
Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"`
Properties *ToolPropertiesMap `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
// ToTypeScriptType converts a ToolProperty to a TypeScript type string


@@ -32,6 +32,7 @@ type flagOptions struct {
verbose *bool
warmup *int
promptTokens *int
numCtx *int
}
type Metrics struct {
@@ -48,6 +49,7 @@ type ModelInfo struct {
Family string
SizeBytes int64
VRAMBytes int64
NumCtx int64
}
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
@@ -64,9 +66,12 @@ var promptWordList = []string{
"old", "stone", "bridge", "that", "crosses", "winding", "river",
}
// tokensPerWord is the calibrated ratio of tokens to words for the current model.
// Initialized with a heuristic, then updated during warmup based on actual tokenization.
var tokensPerWord = 1.3
func generatePromptForTokenCount(targetTokens int, epoch int) string {
// ~1.3 tokens per word heuristic
targetWords := int(float64(targetTokens) / 1.3)
targetWords := int(float64(targetTokens) / tokensPerWord)
if targetWords < 1 {
targetWords = 1
}
@@ -81,6 +86,17 @@ func generatePromptForTokenCount(targetTokens int, epoch int) string {
return strings.Join(words, " ")
}
// calibratePromptTokens adjusts tokensPerWord based on actual tokenization from a warmup run.
func calibratePromptTokens(targetTokens, actualTokens, wordCount int) {
if actualTokens <= 0 || wordCount <= 0 {
return
}
tokensPerWord = float64(actualTokens) / float64(wordCount)
newWords := int(float64(targetTokens) / tokensPerWord)
fmt.Fprintf(os.Stderr, "bench: calibrated %.2f tokens/word (target=%d, got=%d, words=%d → %d)\n",
tokensPerWord, targetTokens, actualTokens, wordCount, newWords)
}
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
@@ -90,6 +106,9 @@ func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData,
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
options["num_ctx"] = *fOpt.numCtx
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
@@ -146,7 +165,6 @@ func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (si
return m.Size, m.SizeVRAM
}
}
// Try prefix match (model names may include :latest or tags)
for _, m := range resp.Models {
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
return m.Size, m.SizeVRAM
@@ -155,6 +173,19 @@ func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (si
return 0, 0
}
func fetchContextLength(ctx context.Context, client *api.Client, model string) int64 {
resp, err := client.ListRunning(ctx)
if err != nil {
return 0
}
for _, m := range resp.Models {
if m.Name == model || m.Model == model || strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
return int64(m.ContextLength)
}
}
return 0
}
func outputFormatHeader(w io.Writer, format string, verbose bool) {
switch format {
case "benchstat":
@@ -177,8 +208,12 @@ func outputModelInfo(w io.Writer, format string, info ModelInfo) {
if info.SizeBytes > 0 {
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
}
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s\n",
info.Name, params, quant, family, memStr)
ctxStr := ""
if info.NumCtx > 0 {
ctxStr = fmt.Sprintf(" | NumCtx: %d", info.NumCtx)
}
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s%s\n",
info.Name, params, quant, family, memStr, ctxStr)
}
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
@@ -276,21 +311,38 @@ func BenchmarkModel(fOpt flagOptions) error {
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
var warmupMetrics *api.Metrics
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if resp.Done {
warmupMetrics = &resp.Metrics
}
return nil
})
cancel()
if err != nil {
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
} else if *fOpt.debug {
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
} else {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
}
// Calibrate prompt token count on last warmup run
if i == *fOpt.warmup-1 && *fOpt.promptTokens > 0 && warmupMetrics != nil {
prompt := generatePromptForTokenCount(*fOpt.promptTokens, -(i + 1))
wordCount := len(strings.Fields(prompt))
calibratePromptTokens(*fOpt.promptTokens, warmupMetrics.PromptEvalCount, wordCount)
}
}
}
// Fetch memory usage once after warmup (model is loaded and stable)
// Fetch memory/context info once after warmup (model is loaded and stable)
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
info.NumCtx = int64(*fOpt.numCtx)
} else {
info.NumCtx = fetchContextLength(memCtx, client, model)
}
memCancel()
outputModelInfo(out, *fOpt.format, info)
@@ -479,6 +531,7 @@ func main() {
debug: flag.Bool("debug", false, "Show debug information"),
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
numCtx: flag.Int("num-ctx", 0, "Context size (0 = server default)"),
}
flag.Usage = func() {


@@ -695,7 +695,8 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
// TODO: remove the projector info and vision info checks below,
// these are left in for backwards compatibility with older servers
@@ -1494,6 +1495,9 @@ type displayResponseState struct {
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
if termWidth == 0 {
termWidth = 80
}
if wordWrap && termWidth >= 10 {
for _, ch := range content {
if state.lineLength+1 > termWidth-5 {


@@ -47,7 +47,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
}
fmt.Fprintln(os.Stderr, "")
@@ -592,7 +592,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp|wav)\b`
re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1)
@@ -608,10 +608,16 @@ func extractFileData(input string) (string, []api.ImageData, error) {
if errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
fmt.Fprintf(os.Stderr, "Couldn't process file: %q\n", err)
return "", imgs, err
}
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
ext := strings.ToLower(filepath.Ext(nfp))
switch ext {
case ".wav":
fmt.Fprintf(os.Stderr, "Added audio '%s'\n", nfp)
default:
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
}
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "")
@@ -685,9 +691,9 @@ func getImageData(filePath string) ([]byte, error) {
}
contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp", "audio/wave"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType)
return nil, fmt.Errorf("invalid file type: %s", contentType)
}
info, err := file.Stat()
@@ -695,8 +701,7 @@ func getImageData(filePath string) ([]byte, error) {
return nil, err
}
// Check if the file size exceeds 100MB
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
var maxSize int64 = 100 * 1024 * 1024 // 100MB
if info.Size() > maxSize {
return nil, errors.New("file size exceeds maximum limit (100MB)")
}


@@ -84,3 +84,33 @@ func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
assert.Len(t, imgs, 1)
assert.Equal(t, cleaned, "before after")
}
func TestExtractFileDataWAV(t *testing.T) {
dir := t.TempDir()
fp := filepath.Join(dir, "sample.wav")
data := make([]byte, 600)
copy(data[:44], []byte{
'R', 'I', 'F', 'F',
0x58, 0x02, 0x00, 0x00, // file size - 8
'W', 'A', 'V', 'E',
'f', 'm', 't', ' ',
0x10, 0x00, 0x00, 0x00, // fmt chunk size
0x01, 0x00, // PCM
0x01, 0x00, // mono
0x80, 0x3e, 0x00, 0x00, // 16000 Hz
0x00, 0x7d, 0x00, 0x00, // byte rate
0x02, 0x00, // block align
0x10, 0x00, // 16-bit
'd', 'a', 't', 'a',
0x34, 0x02, 0x00, 0x00, // data size
})
if err := os.WriteFile(fp, data, 0o600); err != nil {
t.Fatalf("failed to write test audio: %v", err)
}
input := "before " + fp + " after"
cleaned, imgs, err := extractFileData(input)
assert.NoError(t, err)
assert.Len(t, imgs, 1)
assert.Equal(t, "before after", cleaned)
}


@@ -290,6 +290,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Gemma3nForConditionalGeneration":
conv = &gemma3nModel{}
case "Gemma4ForCausalLM", "Gemma4ForConditionalGeneration":
conv = &gemma4Model{Architecture: p.Architectures[0]}
case "Phi3ForCausalLM":
conv = &phi3Model{}
case "Qwen2ForCausalLM":

convert/convert_gemma4.go (new file, 574 lines)

@@ -0,0 +1,574 @@
package convert
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type gemma4Model struct {
gemmaModel
Architecture string
TextModel struct {
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
GlobalHeadDim uint32 `json:"global_head_dim"`
VocabSize uint32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
SlidingWindow uint32 `json:"sliding_window"`
SlidingWindowPattern *int32 `json:"_sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
FinalLogitSoftcapping float32 `json:"final_logit_softcapping"`
EnableMoeBlock bool `json:"enable_moe_block"`
NumExperts *uint32 `json:"num_experts"`
TopKExperts *uint32 `json:"top_k_experts"`
ExpertIntermediateSize *uint32 `json:"moe_intermediate_size"`
HiddenSizePerLayerInput *uint32 `json:"hidden_size_per_layer_input"`
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
AttentionKEqV bool `json:"attention_k_eq_v"`
NumGlobalKeyValueHeads *uint32 `json:"num_global_key_value_heads"`
QueryPreAttnScalar *uint32 `json:"query_pre_attn_scalar"`
UseDoubleWideMLP bool `json:"use_double_wide_mlp"`
RopeParameters map[string]*struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor *float32 `json:"partial_rotary_factor"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionModel struct {
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
IntermediateSize uint32 `json:"intermediate_size"`
PatchSize uint32 `json:"patch_size"`
NumChannels uint32 `json:"num_channels"`
PoolingKernelSize uint32 `json:"pooling_kernel_size"`
LayerNormEps float32 `json:"layer_norm_eps"`
} `json:"vision_config"`
AudioModel *struct {
HiddenSize uint32 `json:"hidden_size"`
OutputProjDims uint32 `json:"output_proj_dims"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ConvKernelSize uint32 `json:"conv_kernel_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
} `json:"audio_config"`
}
func (p *gemma4Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma4"
kv["tokenizer.ggml.model"] = "llama"
kv["tokenizer.ggml.pre"] = "gemma4"
tc := p.TextModel
kv["gemma4.block_count"] = tc.NumHiddenLayers
kv["gemma4.embedding_length"] = tc.HiddenSize
// Per-layer FFN width: when use_double_wide_mlp is set, KV-shared layers get 2x FFN width.
if tc.UseDoubleWideMLP && tc.NumKVSharedLayers > 0 {
firstShared := int(tc.NumHiddenLayers) - int(tc.NumKVSharedLayers)
ffnWidths := make([]int32, tc.NumHiddenLayers)
for i := range ffnWidths {
if i >= firstShared {
ffnWidths[i] = int32(tc.IntermediateSize * 2)
} else {
ffnWidths[i] = int32(tc.IntermediateSize)
}
}
kv["gemma4.feed_forward_length"] = ffnWidths
} else {
kv["gemma4.feed_forward_length"] = tc.IntermediateSize
}
kv["gemma4.context_length"] = tc.MaxPositionEmbeddings
kv["gemma4.attention.head_count"] = tc.NumAttentionHeads
// Per-layer KV head count array: SWA layers use NumKeyValueHeads, global layers use NumGlobalKeyValueHeads
if tc.NumGlobalKeyValueHeads != nil && *tc.NumGlobalKeyValueHeads != tc.NumKeyValueHeads && len(tc.LayerTypes) > 0 {
kvHeads := make([]int32, len(tc.LayerTypes))
for i, lt := range tc.LayerTypes {
if lt == "sliding_attention" {
kvHeads[i] = int32(tc.NumKeyValueHeads)
} else {
kvHeads[i] = int32(*tc.NumGlobalKeyValueHeads)
}
}
kv["gemma4.attention.head_count_kv"] = kvHeads
} else {
kv["gemma4.attention.head_count_kv"] = tc.NumKeyValueHeads
}
// key_length = global head dim, key_length_swa = local (SWA) head dim
kv["gemma4.attention.key_length"] = tc.GlobalHeadDim
kv["gemma4.attention.value_length"] = tc.GlobalHeadDim
kv["gemma4.attention.key_length_swa"] = tc.HeadDim
kv["gemma4.attention.value_length_swa"] = tc.HeadDim
kv["gemma4.attention.layer_norm_rms_epsilon"] = tc.RMSNormEps
kv["gemma4.attention.sliding_window"] = tc.SlidingWindow
// Sliding window pattern from layer_types
if len(tc.LayerTypes) > 0 {
kv["gemma4.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for _, lt := range tc.LayerTypes {
if !yield(lt == "sliding_attention") {
break
}
}
})
}
kv["gemma4.attention.shared_kv_layers"] = tc.NumKVSharedLayers
// RoPE: dimension_count is the full global head dim (freq_factors handle partial rotation)
if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil {
kv["gemma4.rope.freq_base"] = rp.RopeTheta
kv["gemma4.rope.dimension_count"] = tc.GlobalHeadDim
}
if rp, ok := tc.RopeParameters["sliding_attention"]; ok && rp != nil {
kv["gemma4.rope.freq_base_swa"] = rp.RopeTheta
kv["gemma4.rope.dimension_count_swa"] = tc.HeadDim
}
if tc.FinalLogitSoftcapping > 0 {
kv["gemma4.final_logit_softcapping"] = tc.FinalLogitSoftcapping
}
// MoE
if tc.EnableMoeBlock && tc.NumExperts != nil {
kv["gemma4.expert_count"] = *tc.NumExperts
if tc.TopKExperts != nil {
kv["gemma4.expert_used_count"] = *tc.TopKExperts
}
if tc.ExpertIntermediateSize != nil {
kv["gemma4.expert_feed_forward_length"] = *tc.ExpertIntermediateSize
}
}
// PLE — always emit, even when 0
pleSize := uint32(0)
if tc.HiddenSizePerLayerInput != nil {
pleSize = *tc.HiddenSizePerLayerInput
}
kv["gemma4.embedding_length_per_layer_input"] = pleSize
// Vision model KV metadata
vc := p.VisionModel
if vc.NumHiddenLayers > 0 {
kv["gemma4.vision.block_count"] = vc.NumHiddenLayers
kv["gemma4.vision.embedding_length"] = vc.HiddenSize
kv["gemma4.vision.attention.head_count"] = vc.NumAttentionHeads
kv["gemma4.vision.feed_forward_length"] = vc.IntermediateSize
kv["gemma4.vision.patch_size"] = vc.PatchSize
numCh := vc.NumChannels
if numCh == 0 {
numCh = 3
}
kv["gemma4.vision.num_channels"] = numCh
nMerge := vc.PoolingKernelSize
if nMerge == 0 {
nMerge = 3
}
kv["gemma4.vision.projector.scale_factor"] = nMerge
eps := vc.LayerNormEps
if eps == 0 {
eps = 1e-6
}
kv["gemma4.vision.attention.layer_norm_epsilon"] = eps
}
// Audio model KV metadata
if p.AudioModel != nil && p.AudioModel.NumHiddenLayers > 0 {
ac := p.AudioModel
kv["gemma4.audio.block_count"] = ac.NumHiddenLayers
kv["gemma4.audio.embedding_length"] = ac.HiddenSize
kv["gemma4.audio.feed_forward_length"] = ac.HiddenSize * 4
kv["gemma4.audio.attention.head_count"] = ac.NumAttentionHeads
eps := ac.RMSNormEps
if eps == 0 {
eps = 1e-6
}
kv["gemma4.audio.attention.layer_norm_epsilon"] = eps
if ac.ConvKernelSize > 0 {
kv["gemma4.audio.conv_kernel_size"] = ac.ConvKernelSize
}
}
return kv
}
func (p *gemma4Model) Tensors(ts []Tensor) []*ggml.Tensor {
// First pass: collect vision clamp scalar values into a packed tensor.
// Layout: per vision layer (0..N-1), 7 linears (q,k,v,out,gate,up,down) × 4 values (inMin,inMax,outMin,outMax).
// Then 4 values for the projector (mm.input_projection).
clampSuffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
clampMap := make(map[string]float32)
for _, t := range ts {
name := t.Name()
for _, sfx := range clampSuffixes {
if strings.HasSuffix(name, sfx) && (strings.Contains(name, "vision_tower") || strings.Contains(name, "embed_vision")) {
var buf bytes.Buffer
t.WriteTo(&buf)
data := buf.Bytes()
if len(data) >= 4 {
clampMap[name] = math.Float32frombits(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16 | uint32(data[3])<<24)
}
}
}
}
var out []*ggml.Tensor
for _, t := range ts {
name := t.Name()
// Skip embedding_post_projection_norm — used as weightless RMS norm in inference
if strings.Contains(name, "embedding_post_projection_norm") {
continue
}
// Vision tensor renaming: match published mmproj GGUF names
if strings.HasPrefix(name, "v.blk.") {
name = strings.Replace(name, ".attn_norm.", ".ln1.", 1)
name = strings.Replace(name, ".ffn_norm.", ".ln2.", 1)
name = strings.Replace(name, ".attn_output.", ".attn_out.", 1)
name = strings.Replace(name, ".post_attention_norm.", ".attn_post_norm.", 1)
name = strings.Replace(name, ".post_ffw_norm.", ".ffn_post_norm.", 1)
name = strings.Replace(name, ".layer_output_scale.", ".out_scale.", 1)
}
// per_dim_scale: apply softplus to weight data and add .weight suffix.
if strings.HasPrefix(name, "a.blk.") && strings.HasSuffix(name, "per_dim_scale") {
name = name + ".weight"
t.SetRepacker(softplusRepacker)
}
// Depthwise conv1d: squeeze middle dimension [C, 1, K] → [C, K].
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") {
t.SetRepacker(squeezeMiddleDim)
}
shape := t.Shape()
// Convert scalar tensors (input_min/max, output_min/max) to 1D
if len(shape) == 0 {
shape = []uint64{1}
}
// Depthwise conv1d shape: safetensors [C, 1, K] → GGUF ne[K, C].
// Shape array here maps to GGUF ne[] directly, but safetensors reader
// stores shape in PyTorch order [C, 1, K] which the GGUF writer inverts.
// Published GGUF has ne[0]=K, ne[1]=C → shape array must be [K, C].
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") && len(shape) == 3 {
shape = []uint64{shape[0], shape[2]}
}
// MoE expert weights: no transpose needed. Safetensors stores [experts, out, in]
// which the framework reverses to GGUF ne=[in, out, experts], matching ggml_mul_mat_id.
// (transposeExperts was incorrectly swapping dims — removed)
// Audio conv weights are forced to F32 via tensorBase.Kind() in reader.go
// (im2col doesn't support BF16). No kindOverride needed — the Kind() method
// controls both the GGUF header type AND the WriteTo data encoding path.
var kindOverride *uint32
// Vision patch embedding: reshape from [n_embd, ksize_sq_c] to [n_embd, 3, patch_size, patch_size]
// Must be stored as F16 (not BF16) because the Conv2D im2col kernel requires F16/F32.
if strings.Contains(name, "v.patch_embd.weight") && len(shape) == 2 {
nEmbd := shape[0]
patchSize := uint64(p.VisionModel.PatchSize)
if patchSize == 0 {
patchSize = 16
}
numCh := uint64(p.VisionModel.NumChannels)
if numCh == 0 {
numCh = 3
}
t.SetRepacker(p.reshapePatchEmbed)
shape = []uint64{nEmbd, numCh, patchSize, patchSize}
f16Kind := uint32(1) // tensorKindFP16
kindOverride = &f16Kind
}
// Vision position embedding: keep 3D [2, maxPos, nEmbd] — matching published mmproj format.
// The framework reverses shape to GGUF ne=[nEmbd, maxPos, 2]. No data repacking needed.
kind := t.Kind()
if kindOverride != nil {
kind = *kindOverride
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: kind,
Shape: shape,
WriterTo: t,
})
}
// Generate a single global rope_freqs.weight for proportional RoPE on global attention layers.
// This matches the published GGUF format: one global tensor shared by all layers.
// Global layers use partial_rotary_factor (0.25) — only rotate that fraction of dims.
// Dimensions beyond the rotated portion get freq_factor=1e30 (effectively no rotation).
tc := p.TextModel
if tc.GlobalHeadDim > 0 {
globalFreqsSize := tc.GlobalHeadDim / 2 // freq_factors are per dimension pair
// Compute number of rotated pairs for global layers
partialRotaryFactor := float32(0.25) // default
if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil && rp.PartialRotaryFactor != nil {
partialRotaryFactor = *rp.PartialRotaryFactor
}
nRotFull := int(float32(tc.GlobalHeadDim) * partialRotaryFactor / 2)
freqs := make(ropeFactor, globalFreqsSize)
for j := range freqs {
if j < nRotFull {
freqs[j] = 1.0
} else {
freqs[j] = 1e30 // effectively disable rotation
}
}
out = append(out, &ggml.Tensor{
Name: "rope_freqs.weight",
Kind: 0, // F32
Shape: []uint64{uint64(len(freqs))},
WriterTo: freqs,
})
}
// Emit packed vision clamp data as a single F32 tensor.
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
// then 4 floats for the projector. Total = (numLayers*7 + 1) * 4 floats.
if len(clampMap) > 0 {
numLayers := int(p.VisionModel.NumHiddenLayers)
linearNames := []string{"attn_q", "attn_k", "attn_v", "attn_out", "ffn_gate", "ffn_up", "ffn_down"}
suffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
totalFloats := (numLayers*len(linearNames) + 1) * 4 // +1 for projector
clampData := make([]float32, totalFloats)
for layer := range numLayers {
for li, ln := range linearNames {
for si, sfx := range suffixes {
sfxMap := map[string]string{"attn_q": "q_proj", "attn_k": "k_proj", "attn_v": "v_proj", "attn_out": "o_proj", "ffn_gate": "gate_proj", "ffn_up": "up_proj", "ffn_down": "down_proj"}
for origName, val := range clampMap {
if strings.Contains(origName, fmt.Sprintf("layers.%d.", layer)) && strings.HasSuffix(origName, sfx) && strings.Contains(origName, sfxMap[ln]) {
idx := (layer*len(linearNames)+li)*4 + si
clampData[idx] = val
break
}
}
}
}
}
// Projector clamp values
projIdx := numLayers * len(linearNames) * 4
for si, sfx := range suffixes {
for origName, val := range clampMap {
if strings.Contains(origName, "input_projection") && strings.HasSuffix(origName, sfx) {
clampData[projIdx+si] = val
break
}
}
}
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, clampData)
out = append(out, &ggml.Tensor{
Name: "v.clamp_data",
Kind: 0, // F32
Shape: []uint64{uint64(totalFloats)},
WriterTo: &buf,
})
}
return out
}
// reshapePatchEmbed reshapes the vision patch embedding from HF layout [n_embd, ksize*ksize*channels]
// to GGUF layout [n_embd, channels, patch_size, patch_size].
func (*gemma4Model) reshapePatchEmbed(_ string, data []float32, shape []uint64) ([]float32, error) {
if len(shape) != 2 {
return data, nil
}
nEmbd := int(shape[0])
ksqC := int(shape[1])
nChannels := 3
patchSize := int(math.Sqrt(float64(ksqC / nChannels)))
// HF layout: [n_embd, patch_size * patch_size * channels] (row-major)
// Need: [n_embd, channels, patch_size, patch_size]
result := make([]float32, len(data))
for e := range nEmbd {
for c := range nChannels {
for h := range patchSize {
for w := range patchSize {
srcIdx := e*ksqC + h*patchSize*nChannels + w*nChannels + c
dstIdx := e*nChannels*patchSize*patchSize + c*patchSize*patchSize + h*patchSize + w
result[dstIdx] = data[srcIdx]
}
}
}
}
shape[0] = uint64(nEmbd)
shape[1] = uint64(nChannels * patchSize * patchSize)
return result, nil
}
// softplusRepacker applies softplus (ln(1 + exp(x))) to tensor data.
// Used for per_dim_scale tensors which the published GGUF stores pre-activated.
func softplusRepacker(_ string, data []float32, shape []uint64) ([]float32, error) {
result := make([]float32, len(data))
for i, x := range data {
result[i] = float32(math.Log(1 + math.Exp(float64(x))))
}
return result, nil
}
// squeezeMiddleDim squeezes the middle dimension from [C, 1, K] → [C, K] for depthwise conv1d weights.
// Because the middle dim is 1, the underlying data layout is identical, so this repacker passes the data through unchanged.
func squeezeMiddleDim(_ string, data []float32, _ []uint64) ([]float32, error) {
return data, nil
}
func (p *gemma4Model) Replacements() []string {
return []string{
// ClippableLinear wraps nn.Linear — strip .linear. from weight path
".linear.weight", ".weight",
".linear.bias", ".bias",
// Audio SSCP (Sub-Sample Convolution Projection)
"model.audio_tower.subsample_conv_projection.conv_0.conv", "a.conv1d.0",
"model.audio_tower.subsample_conv_projection.conv_0.norm", "a.conv1d.0.norm",
"model.audio_tower.subsample_conv_projection.conv_1.conv", "a.conv1d.1",
"model.audio_tower.subsample_conv_projection.conv_1.norm", "a.conv1d.1.norm",
"model.audio_tower.subsample_conv_projection.layer0.conv", "a.conv1d.0",
"model.audio_tower.subsample_conv_projection.layer0.norm", "a.conv1d.0.norm",
"model.audio_tower.subsample_conv_projection.layer1.conv", "a.conv1d.1",
"model.audio_tower.subsample_conv_projection.layer1.norm", "a.conv1d.1.norm",
"model.audio_tower.subsample_conv_projection.input_proj_linear", "a.pre_encode.out",
// Audio conformer blocks
"model.audio_tower.conformer", "a.blk",
"model.audio_tower.layers", "a.blk",
// Audio conformer attention
"attention.attn.relative_position_embedding.pos_proj", "linear_pos",
"self_attn.relative_k_proj", "linear_pos",
"attention.attn.per_dim_key_scale", "per_dim_k_scale",
"attention.attn.per_dim_scale", "per_dim_scale",
"self_attn.per_dim_scale", "per_dim_scale",
"attention.attn.q_proj", "attn_q",
"attention.attn.k_proj", "attn_k",
"attention.attn.v_proj", "attn_v",
"attention.pre_attn_norm", "ln1",
"attention.post_norm", "ln2",
"attention.post", "attn_out",
"self_attn.post", "attn_out",
"norm_pre_attn", "ln1",
"norm_post_attn", "ln2",
// Audio conformer feedforward
"ffw_layer_start.pre_layer_norm", "ffn_norm",
"ffw_layer_start.post_layer_norm", "ffn_post_norm",
"ffw_layer_start.ffw_layer_1", "ffn_up",
"ffw_layer_start.ffw_layer_2", "ffn_down",
"ffw_layer_end.pre_layer_norm", "ffn_norm_1",
"ffw_layer_end.post_layer_norm", "ffn_post_norm_1",
"ffw_layer_end.ffw_layer_1", "ffn_up_1",
"ffw_layer_end.ffw_layer_2", "ffn_down_1",
"feed_forward1.pre_layer_norm", "ffn_norm",
"feed_forward1.post_layer_norm", "ffn_post_norm",
"feed_forward1.ffw_layer_1", "ffn_up",
"feed_forward1.ffw_layer_2", "ffn_down",
"feed_forward2.pre_layer_norm", "ffn_norm_1",
"feed_forward2.post_layer_norm", "ffn_post_norm_1",
"feed_forward2.ffw_layer_1", "ffn_up_1",
"feed_forward2.ffw_layer_2", "ffn_down_1",
// Audio conformer lightweight conv1d
"lconv1d.depthwise_conv1d", "conv_dw",
"lconv1d.pre_layer_norm", "conv_norm",
"lconv1d.conv_norm", "norm_conv",
"lconv1d.linear_start", "conv_pw1",
"lconv1d.linear_end", "conv_pw2",
// Audio block final norm
"norm_out", "layer_pre_norm",
// Audio embedder and output projection
"model.embed_audio.embedding_projection", "mm.a.input_projection",
"model.audio_tower.output_proj", "mm.a.fc",
// Vision encoder
"model.vision_tower.encoder.layers", "v.blk",
"model.vision_tower.patch_embedder.input_proj", "v.patch_embd",
"model.vision_tower.patch_embedder.position_embedding_table", "v.position_embd.weight",
"model.vision_tower.std_bias", "v.std_bias",
"model.vision_tower.std_scale", "v.std_scale",
// Vision multimodal projector
"model.embed_vision.embedding_projection", "mm.input_projection",
// Text model
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
"model.language_model.embed_tokens", "token_embd",
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm",
"model.language_model.norm", "output_norm",
"model.language_model.layers", "blk",
// Shared attention replacements (work for both text and vision tensors)
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
// Post norms
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm_2", "pre_ffw_norm_2",
"pre_feedforward_layernorm", "ffn_norm",
"post_feedforward_layernorm_1", "post_ffw_norm_1",
"post_feedforward_layernorm_2", "post_ffw_norm_2",
"post_feedforward_layernorm", "post_ffw_norm",
// PLE
"per_layer_input_gate", "inp_gate",
"per_layer_projection", "proj",
"post_per_layer_input_norm", "post_norm",
// MoE
"router.proj", "ffn_gate_inp",
"router.scale", "ffn_gate_inp.scale",
"router.per_expert_scale.weight", "ffn_down_exps.scale",
"router.per_expert_scale", "ffn_down_exps.scale",
"experts.gate_up_proj.weight", "ffn_gate_up_exps.weight",
"experts.gate_up_proj", "ffn_gate_up_exps.weight",
"experts.down_proj.weight", "ffn_down_exps.weight",
"experts.down_proj", "ffn_down_exps.weight",
"moe.gate_proj", "ffn_gate_exps.weight",
"moe.up_proj", "ffn_up_exps.weight",
"moe.gate_up_proj.weight", "ffn_gate_up_exps.weight",
"moe.gate_up_proj", "ffn_gate_up_exps.weight",
"moe.down_proj", "ffn_down_exps.weight",
"moe.per_expert_scale.weight", "ffn_down_exps.scale",
"moe.per_expert_scale", "ffn_down_exps.scale",
// Layer scalar
"layer_scalar", "layer_output_scale.weight",
}
}


@@ -0,0 +1,318 @@
package convert
import (
"strings"
"testing"
)
func TestGemma4AudioReplacements(t *testing.T) {
p := gemma4Model{}
r := strings.NewReplacer(p.Replacements()...)
tests := []struct {
name string
in string
want string
}{
// SSCP convolution blocks
{
"sscp conv0 weight",
"model.audio_tower.subsample_conv_projection.conv_0.conv.weight",
"a.conv1d.0.weight",
},
{
"sscp conv0 norm",
"model.audio_tower.subsample_conv_projection.conv_0.norm.weight",
"a.conv1d.0.norm.weight",
},
{
"sscp conv1 weight",
"model.audio_tower.subsample_conv_projection.conv_1.conv.weight",
"a.conv1d.1.weight",
},
{
"sscp input proj weight",
"model.audio_tower.subsample_conv_projection.input_proj_linear.weight",
"a.pre_encode.out.weight",
},
{
"sscp input proj bias",
"model.audio_tower.subsample_conv_projection.input_proj_linear.bias",
"a.pre_encode.out.bias",
},
{
"sscp layer0 conv weight (new naming)",
"model.audio_tower.subsample_conv_projection.layer0.conv.weight",
"a.conv1d.0.weight",
},
{
"sscp layer1 norm weight (new naming)",
"model.audio_tower.subsample_conv_projection.layer1.norm.weight",
"a.conv1d.1.norm.weight",
},
// Conformer attention
{
"attn q weight",
"model.audio_tower.conformer.0.attention.attn.q_proj.linear.weight",
"a.blk.0.attn_q.weight",
},
{
"attn k weight",
"model.audio_tower.conformer.5.attention.attn.k_proj.linear.weight",
"a.blk.5.attn_k.weight",
},
{
"attn v clamp input_min",
"model.audio_tower.conformer.0.attention.attn.v_proj.input_min",
"a.blk.0.attn_v.input_min",
},
{
"attn out weight (ClippableLinear)",
"model.audio_tower.conformer.0.attention.post.linear.weight",
"a.blk.0.attn_out.weight",
},
{
"attn out clamp output_max",
"model.audio_tower.conformer.0.attention.post.output_max",
"a.blk.0.attn_out.output_max",
},
{
"attn pre norm",
"model.audio_tower.conformer.0.attention.pre_attn_norm.weight",
"a.blk.0.ln1.weight",
},
{
"attn post norm",
"model.audio_tower.conformer.0.attention.post_norm.weight",
"a.blk.0.ln2.weight",
},
{
"linear pos",
"model.audio_tower.conformer.0.attention.attn.relative_position_embedding.pos_proj.weight",
"a.blk.0.linear_pos.weight",
},
{
"per dim scale",
"model.audio_tower.conformer.0.attention.attn.per_dim_scale",
"a.blk.0.per_dim_scale",
},
{
"per dim key scale",
"model.audio_tower.conformer.0.attention.attn.per_dim_key_scale",
"a.blk.0.per_dim_k_scale",
},
{
"attn relative k proj (new naming)",
"model.audio_tower.layers.0.self_attn.relative_k_proj.weight",
"a.blk.0.linear_pos.weight",
},
{
"attn pre norm (new naming)",
"model.audio_tower.layers.0.norm_pre_attn.weight",
"a.blk.0.ln1.weight",
},
{
"attn post norm (new naming)",
"model.audio_tower.layers.0.norm_post_attn.weight",
"a.blk.0.ln2.weight",
},
{
"attn out clamp output_max (new naming)",
"model.audio_tower.layers.0.self_attn.post.output_max",
"a.blk.0.attn_out.output_max",
},
{
"per dim scale (new naming)",
"model.audio_tower.layers.0.self_attn.per_dim_scale",
"a.blk.0.per_dim_scale",
},
// Conformer feedforward start
{
"ffn up weight",
"model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_1.linear.weight",
"a.blk.0.ffn_up.weight",
},
{
"ffn down weight",
"model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_2.linear.weight",
"a.blk.0.ffn_down.weight",
},
{
"ffn norm",
"model.audio_tower.conformer.0.ffw_layer_start.pre_layer_norm.weight",
"a.blk.0.ffn_norm.weight",
},
{
"ffn post norm",
"model.audio_tower.conformer.0.ffw_layer_start.post_layer_norm.weight",
"a.blk.0.ffn_post_norm.weight",
},
// Conformer feedforward end
{
"ffn up 1 weight",
"model.audio_tower.conformer.0.ffw_layer_end.ffw_layer_1.linear.weight",
"a.blk.0.ffn_up_1.weight",
},
{
"ffn down 1 weight",
"model.audio_tower.conformer.0.ffw_layer_end.ffw_layer_2.linear.weight",
"a.blk.0.ffn_down_1.weight",
},
{
"ffn norm 1",
"model.audio_tower.conformer.0.ffw_layer_end.pre_layer_norm.weight",
"a.blk.0.ffn_norm_1.weight",
},
{
"ffn post norm 1",
"model.audio_tower.conformer.0.ffw_layer_end.post_layer_norm.weight",
"a.blk.0.ffn_post_norm_1.weight",
},
{
"ffn up output_max (new naming)",
"model.audio_tower.layers.10.feed_forward1.ffw_layer_1.output_max",
"a.blk.10.ffn_up.output_max",
},
{
"ffn down output_min (new naming)",
"model.audio_tower.layers.0.feed_forward1.ffw_layer_2.output_min",
"a.blk.0.ffn_down.output_min",
},
{
"ffn up 1 input_max (new naming)",
"model.audio_tower.layers.0.feed_forward2.ffw_layer_1.input_max",
"a.blk.0.ffn_up_1.input_max",
},
{
"ffn norm 1 (new naming)",
"model.audio_tower.layers.0.feed_forward2.pre_layer_norm.weight",
"a.blk.0.ffn_norm_1.weight",
},
// Conformer lightweight conv1d
{
"conv dw weight",
"model.audio_tower.conformer.0.lconv1d.depthwise_conv1d.weight",
"a.blk.0.conv_dw.weight",
},
{
"conv norm (pre_layer_norm)",
"model.audio_tower.conformer.0.lconv1d.pre_layer_norm.weight",
"a.blk.0.conv_norm.weight",
},
{
"norm conv (conv_norm)",
"model.audio_tower.conformer.0.lconv1d.conv_norm.weight",
"a.blk.0.norm_conv.weight",
},
{
"conv pw1 weight",
"model.audio_tower.conformer.0.lconv1d.linear_start.linear.weight",
"a.blk.0.conv_pw1.weight",
},
{
"conv pw2 weight",
"model.audio_tower.conformer.0.lconv1d.linear_end.linear.weight",
"a.blk.0.conv_pw2.weight",
},
// Audio embedder
{
"audio embedder projection weight",
"model.embed_audio.embedding_projection.linear.weight",
"mm.a.input_projection.weight",
},
{
"audio embedder projection bias",
"model.embed_audio.embedding_projection.linear.bias",
"mm.a.input_projection.bias",
},
// Audio output projection
{
"audio output proj weight",
"model.audio_tower.output_proj.weight",
"mm.a.fc.weight",
},
{
"audio output proj bias",
"model.audio_tower.output_proj.bias",
"mm.a.fc.bias",
},
// Verify vision tensors still work
{
"vision q weight",
"model.vision_tower.encoder.layers.0.self_attn.q_proj.linear.weight",
"v.blk.0.attn_q.weight",
},
{
"vision std bias",
"model.vision_tower.std_bias",
"v.std_bias",
},
{
"vision std scale",
"model.vision_tower.std_scale",
"v.std_scale",
},
{
"vision patch embd",
"model.vision_tower.patch_embedder.input_proj.weight",
"v.patch_embd.weight",
},
{
"vision projector",
"model.embed_vision.embedding_projection.linear.weight",
"mm.input_projection.weight",
},
// Verify text tensors still work
{
"text attn q",
"model.language_model.layers.0.self_attn.q_proj.weight",
"blk.0.attn_q.weight",
},
{
"text token embd",
"model.language_model.embed_tokens.weight",
"token_embd.weight",
},
{
"text moe gate up fused",
"model.language_model.layers.0.experts.gate_up_proj",
"blk.0.ffn_gate_up_exps.weight",
},
{
"text moe down",
"model.language_model.layers.0.experts.down_proj",
"blk.0.ffn_down_exps.weight",
},
{
"text moe down with weight suffix",
"model.language_model.layers.0.experts.down_proj.weight",
"blk.0.ffn_down_exps.weight",
},
{
"text moe per expert scale",
"model.language_model.layers.0.router.per_expert_scale",
"blk.0.ffn_down_exps.scale",
},
{
"text moe per expert scale with weight suffix",
"model.language_model.layers.0.router.per_expert_scale.weight",
"blk.0.ffn_down_exps.scale",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := r.Replace(tt.in); got != tt.want {
t.Errorf("Replace(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}


@@ -205,8 +205,8 @@ func TestConvertInvalidDatatype(t *testing.T) {
generateSafetensorTestData(t, tempDir, td)
err = ConvertModel(os.DirFS(tempDir), f)
-if err == nil || err.Error() != "unsupported safetensors model" {
-t.Errorf("expected error but didn't get one")
+if err == nil || !strings.Contains(err.Error(), "unknown data type") {
+t.Errorf("expected 'unknown data type' error but got: %v", err)
}
}


@@ -42,8 +42,11 @@ func (t tensorBase) Kind() uint32 {
strings.HasSuffix(t.name, ".bias") ||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights must be F32 for im2col
strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights must be F32
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.position_embd.weight" ||
t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" ||


@@ -5,7 +5,6 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
-"errors"
"fmt"
"io"
"io/fs"
@@ -53,9 +52,10 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
for _, key := range keys {
if value := headers[key]; value.Type != "" {
-// bitsandbytes quantized models are unsupported
+// Scalar tensors (e.g. clipped linear min/max) are 0-dim in safetensors.
+// Promote them to 1-dim so they can be stored in GGUF.
if len(value.Shape) == 0 {
-return nil, errors.New("unsupported safetensors model")
+value.Shape = []uint64{1}
}
ggufName := replacer.Replace(key)
if _, ok := names[ggufName]; ok {


@@ -281,6 +281,7 @@ func (kv KV) OllamaEngineRequired() bool {
"deepseekocr",
"gemma3",
"gemma3n",
"gemma4",
"gptoss", "gpt-oss",
"llama4",
"mistral3",

integration/audio_test.go

@@ -0,0 +1,259 @@
//go:build integration
package integration
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
var defaultAudioModels = []string{
"gemma4:e2b",
"gemma4:e4b",
}
// decodeTestAudio returns the test audio clip ("Why is the sky blue?", 16kHz mono WAV).
func decodeTestAudio(t *testing.T) api.ImageData {
t.Helper()
data, err := base64.StdEncoding.DecodeString(audioEncodingPrompt)
if err != nil {
t.Fatalf("failed to decode test audio: %v", err)
}
return data
}
// setupAudioModel pulls the model, preloads it, and skips if it doesn't support audio.
func setupAudioModel(ctx context.Context, t *testing.T, client *api.Client, model string) {
t.Helper()
requireCapability(ctx, t, client, model, "audio")
pullOrSkip(ctx, t, client, model)
err := client.Generate(ctx, &api.GenerateRequest{Model: model}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load model %s: %s", model, err)
}
}
// TestAudioTranscription tests that the model can transcribe audio to text.
func TestAudioTranscription(t *testing.T) {
for _, model := range testModels(defaultAudioModels) {
t.Run(model, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
setupAudioModel(ctx, t, client, model)
audio := decodeTestAudio(t)
noThink := &api.ThinkValue{Value: false}
req := api.ChatRequest{
Model: model,
Think: noThink,
Messages: []api.Message{
{
Role: "system",
Content: "Transcribe the audio exactly as spoken. Output only the transcription.",
},
{
Role: "user",
Content: "Transcribe this audio.",
Images: []api.ImageData{audio},
},
},
Stream: &stream,
Options: map[string]any{
"temperature": 0,
"seed": 123,
"num_predict": 50,
},
}
// The audio says "Why is the sky blue?" — expect key words in transcription.
DoChat(ctx, t, client, req, []string{"sky", "blue"}, 60*time.Second, 10*time.Second)
})
}
}
// TestAudioResponse tests that the model can respond to a spoken question.
func TestAudioResponse(t *testing.T) {
for _, model := range testModels(defaultAudioModels) {
t.Run(model, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
setupAudioModel(ctx, t, client, model)
audio := decodeTestAudio(t)
noThink := &api.ThinkValue{Value: false}
req := api.ChatRequest{
Model: model,
Think: noThink,
Messages: []api.Message{
{
Role: "user",
Content: "",
Images: []api.ImageData{audio},
},
},
Stream: &stream,
Options: map[string]any{
"temperature": 0,
"seed": 123,
"num_predict": 200,
},
}
// The audio asks "Why is the sky blue?" — expect an answer about light/scattering.
DoChat(ctx, t, client, req, []string{
"scatter", "light", "blue", "atmosphere", "wavelength", "rayleigh",
}, 60*time.Second, 10*time.Second)
})
}
}
// TestOpenAIAudioTranscription tests the /v1/audio/transcriptions endpoint.
func TestOpenAIAudioTranscription(t *testing.T) {
for _, model := range testModels(defaultAudioModels) {
t.Run(model, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, endpoint, cleanup := InitServerConnection(ctx, t)
defer cleanup()
setupAudioModel(ctx, t, client, model)
audioBytes := decodeTestAudio(t)
// Build multipart form request.
var body bytes.Buffer
writer := multipart.NewWriter(&body)
writer.WriteField("model", model)
part, err := writer.CreateFormFile("file", "prompt.wav")
if err != nil {
t.Fatal(err)
}
part.Write(audioBytes)
writer.Close()
url := fmt.Sprintf("http://%s/v1/audio/transcriptions", endpoint)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, &body)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(respBody))
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
text := strings.ToLower(string(respBody))
if !strings.Contains(text, "sky") && !strings.Contains(text, "blue") {
t.Errorf("transcription response missing expected words, got: %s", string(respBody))
}
})
}
}
// TestOpenAIChatWithAudio tests /v1/chat/completions with input_audio content.
func TestOpenAIChatWithAudio(t *testing.T) {
for _, model := range testModels(defaultAudioModels) {
t.Run(model, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, endpoint, cleanup := InitServerConnection(ctx, t)
defer cleanup()
setupAudioModel(ctx, t, client, model)
audioB64 := audioEncodingPrompt
reqBody := fmt.Sprintf(`{
"model": %q,
"messages": [{
"role": "user",
"content": [
{"type": "input_audio", "input_audio": {"data": %q, "format": "wav"}}
]
}],
"temperature": 0,
"seed": 123,
"max_tokens": 200,
"think": false
}`, model, strings.TrimSpace(audioB64))
url := fmt.Sprintf("http://%s/v1/chat/completions", endpoint)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(reqBody))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(respBody))
}
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var result struct {
Choices []struct {
Message struct {
Content string `json:"content"`
Reasoning string `json:"reasoning"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(respBytes, &result); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if len(result.Choices) == 0 {
t.Fatal("no choices in response")
}
text := strings.ToLower(result.Choices[0].Message.Content + " " + result.Choices[0].Message.Reasoning)
found := false
for _, word := range []string{"sky", "blue", "scatter", "light", "atmosphere"} {
if strings.Contains(text, word) {
found = true
break
}
}
if !found {
t.Errorf("response missing expected words about sky/blue/light, got: %s", result.Choices[0].Message.Content)
}
})
}
}

File diff suppressed because one or more lines are too long


@@ -51,6 +51,7 @@ func TestContextExhaustion(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Set up the test data
thinkOff := api.ThinkValue{Value: false}
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
@@ -59,6 +60,7 @@ func TestContextExhaustion(t *testing.T) {
Content: "Write me a story in english with a lot of emojis",
},
},
Think: &thinkOff,
Stream: &stream,
Options: map[string]any{
"temperature": 0,


@@ -15,6 +15,7 @@ func TestVisionModels(t *testing.T) {
skipUnderMinVRAM(t, 6)
defaultVisionModels := []string{
"gemma4",
"qwen2.5vl",
"llama3.2-vision",
"gemma3",
@@ -23,6 +24,8 @@ func TestVisionModels(t *testing.T) {
"ministral-3",
}
skipIfNoVisionOverride(t)
for _, model := range testModels(defaultVisionModels) {
t.Run(model, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
@@ -30,10 +33,7 @@ func TestVisionModels(t *testing.T) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
-if testModel != "" {
-requireCapability(ctx, t, client, model, "vision")
-}
+requireCapability(ctx, t, client, model, "vision")
pullOrSkip(ctx, t, client, model)
image, err := base64.StdEncoding.DecodeString(imageEncoding)


@@ -0,0 +1,155 @@
//go:build integration
package integration
import (
"context"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// TestThinkingEnabled verifies that when thinking is requested, the model
// produces both thinking and content output without leaking raw channel tags.
func TestThinkingEnabled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
models := testModels([]string{smol})
for _, modelName := range models {
t.Run(modelName, func(t *testing.T) {
requireCapability(ctx, t, client, modelName, "thinking")
pullOrSkip(ctx, t, client, modelName)
think := api.ThinkValue{Value: true}
stream := false
req := api.ChatRequest{
Model: modelName,
Stream: &stream,
Think: &think,
Messages: []api.Message{
{Role: "user", Content: "What is 12 * 15? Think step by step."},
},
Options: map[string]any{
"temperature": 0,
"seed": 42,
"num_predict": 512,
},
}
var response api.ChatResponse
err := client.Chat(ctx, &req, func(cr api.ChatResponse) error {
response = cr
return nil
})
if err != nil {
if strings.Contains(err.Error(), "model requires more system memory") {
t.Skip("model too large for test system")
}
t.Fatalf("chat failed: %v", err)
}
content := response.Message.Content
thinking := response.Message.Thinking
// Thinking should be non-empty when thinking is enabled
if thinking == "" {
t.Error("expected non-empty thinking output when thinking is enabled")
}
// The answer (180) should appear in thinking, content, or both.
// Some models put everything in thinking and leave content empty
// if they hit the token limit while still thinking.
combined := thinking + " " + content
if !strings.Contains(combined, "180") {
t.Errorf("expected '180' in thinking or content, got thinking=%q content=%q", thinking, content)
}
// Neither thinking nor content should contain raw channel tags
if strings.Contains(content, "<|channel>") || strings.Contains(content, "<channel|>") {
t.Errorf("content contains raw channel tags: %s", content)
}
if strings.Contains(thinking, "<|channel>") || strings.Contains(thinking, "<channel|>") {
t.Errorf("thinking contains raw channel tags: %s", thinking)
}
t.Logf("thinking (%d chars): %.100s...", len(thinking), thinking)
t.Logf("content (%d chars): %s", len(content), content)
})
}
}
// TestThinkingSuppressed verifies that when thinking is NOT requested,
// the model does not leak thinking/channel content into the response.
func TestThinkingSuppressed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
models := testModels([]string{smol})
for _, modelName := range models {
t.Run(modelName, func(t *testing.T) {
requireCapability(ctx, t, client, modelName, "thinking")
pullOrSkip(ctx, t, client, modelName)
stream := false
req := api.ChatRequest{
Model: modelName,
Stream: &stream,
// Think is nil — thinking not requested
Messages: []api.Message{
{Role: "user", Content: "What is the capital of Japan? Answer in one word."},
},
Options: map[string]any{
"temperature": 0,
"seed": 42,
"num_predict": 64,
},
}
var response api.ChatResponse
err := client.Chat(ctx, &req, func(cr api.ChatResponse) error {
response = cr
return nil
})
if err != nil {
if strings.Contains(err.Error(), "model requires more system memory") {
t.Skip("model too large for test system")
}
t.Fatalf("chat failed: %v", err)
}
content := response.Message.Content
thinking := response.Message.Thinking
// The answer should appear in content or thinking
combined := content + " " + thinking
if !strings.Contains(combined, "Tokyo") {
t.Errorf("expected 'Tokyo' in content or thinking, got content=%q thinking=%q", content, thinking)
}
// Content must NOT contain channel/thinking tags
if strings.Contains(content, "<|channel>") || strings.Contains(content, "<channel|>") {
t.Errorf("content contains leaked channel tags when thinking not requested: %s", content)
}
if strings.Contains(content, "thought") && strings.Contains(content, "<channel|>") {
t.Errorf("content contains leaked thinking block: %s", content)
}
// Thinking field should ideally be empty when not requested.
// Some small models may still produce thinking output; log but don't fail.
if thinking != "" {
t.Logf("WARNING: model produced thinking output when not requested (%d chars): %.100s...", len(thinking), thinking)
}
t.Logf("content: %s", content)
})
}
}


@@ -30,6 +30,7 @@ func TestAPIToolCalling(t *testing.T) {
defer cleanup()
minVRAM := map[string]uint64{
"gemma4": 8,
"qwen3-vl": 16,
"gpt-oss:20b": 16,
"gpt-oss:120b": 70,


@@ -45,6 +45,7 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"gemma4",
"lfm2.5-thinking",
"ministral-3",
"qwen3-coder:30b",
@@ -137,6 +138,7 @@ var (
"gemma2",
"gemma3",
"gemma3n",
"gemma4",
"glm4",
"goliath",
"gpt-oss:20b",
@@ -272,6 +274,7 @@ var (
"snowflake-arctic-embed2",
}
libraryToolsModels = []string{
"gemma4",
"lfm2.5-thinking",
"qwen3-vl",
"gpt-oss:20b",


@@ -5,23 +5,26 @@ package integration
import (
"context"
"encoding/base64"
"slices"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// Default set of vision models to test. When OLLAMA_TEST_MODEL is set,
// only that model is tested (with a capability check for vision).
var defaultVisionModels = []string{
"gemma4",
"gemma3",
"llama3.2-vision",
"qwen2.5vl",
"qwen3-vl:8b",
}
-// decodeTestImages returns the two test images (Abbey Road llamas, docs llamas).
-func decodeTestImages(t *testing.T) (abbeyRoad, docs api.ImageData) {
+// decodeTestImages returns the test images.
+func decodeTestImages(t *testing.T) (abbeyRoad, docs, ollamaHome api.ImageData) {
t.Helper()
var err error
abbeyRoad, err = base64.StdEncoding.DecodeString(imageEncoding)
@@ -32,9 +35,35 @@ func decodeTestImages(t *testing.T) (abbeyRoad, docs api.ImageData) {
if err != nil {
t.Fatalf("decode docs image: %v", err)
}
ollamaHome, err = base64.StdEncoding.DecodeString(imageEncodingOllamaHome)
if err != nil {
t.Fatalf("decode ollama home image: %v", err)
}
return
}
// skipIfNoVisionOverride skips the entire test (at parent level) when
// OLLAMA_TEST_MODEL is set to a non-vision model. This prevents the parent
// test from reporting PASS when all subtests are skipped.
func skipIfNoVisionOverride(t *testing.T) {
t.Helper()
if testModel == "" {
return
}
// Check actual model capabilities via the API rather than a hardcoded list.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
resp, err := client.Show(ctx, &api.ShowRequest{Name: testModel})
if err != nil {
return // let the test proceed and fail naturally
}
if len(resp.Capabilities) > 0 && !slices.Contains(resp.Capabilities, model.CapabilityVision) {
t.Skipf("model override %q does not have vision capability (has %v)", testModel, resp.Capabilities)
}
}
// setupVisionModel pulls the model, preloads it, and skips if not GPU-loaded.
func setupVisionModel(ctx context.Context, t *testing.T, client *api.Client, model string) {
t.Helper()
@@ -54,6 +83,7 @@ func setupVisionModel(ctx context.Context, t *testing.T, client *api.Client, mod
// handles cached image tokens across turns.
func TestVisionMultiTurn(t *testing.T) {
skipUnderMinVRAM(t, 6)
skipIfNoVisionOverride(t)
// Models that fail on multi-turn detail questions (e.g. misidentifying objects).
skipModels := map[string]string{
@@ -72,7 +102,7 @@ func TestVisionMultiTurn(t *testing.T) {
defer cleanup()
setupVisionModel(ctx, t, client, model)
-abbeyRoad, _ := decodeTestImages(t)
+abbeyRoad, _, _ := decodeTestImages(t)
// Turn 1: describe the image
req := api.ChatRequest{
@@ -100,7 +130,7 @@ func TestVisionMultiTurn(t *testing.T) {
api.Message{Role: "user", Content: "How many animals are in the image?"},
)
resp2 := DoChat(ctx, t, client, req, []string{
-"four", "4",
+"four", "4", "three", "3",
}, 60*time.Second, 30*time.Second)
if resp2 == nil {
t.Fatal("no response from turn 2")
@@ -121,6 +151,7 @@ func TestVisionMultiTurn(t *testing.T) {
// TestVisionObjectCounting asks the model to count objects in an image.
func TestVisionObjectCounting(t *testing.T) {
skipUnderMinVRAM(t, 6)
skipIfNoVisionOverride(t)
skipModels := map[string]string{
"llama3.2-vision": "consistently miscounts (says 3 instead of 4)",
@@ -137,7 +168,7 @@ func TestVisionObjectCounting(t *testing.T) {
defer cleanup()
setupVisionModel(ctx, t, client, model)
-_, docs := decodeTestImages(t)
+_, docs, _ := decodeTestImages(t)
req := api.ChatRequest{
Model: model,
@@ -160,6 +191,7 @@ func TestVisionObjectCounting(t *testing.T) {
// cultural references and scene context from an image.
func TestVisionSceneUnderstanding(t *testing.T) {
skipUnderMinVRAM(t, 6)
skipIfNoVisionOverride(t)
// Models known to be too small or not capable enough for cultural reference detection.
skipModels := map[string]string{
@@ -178,7 +210,7 @@ func TestVisionSceneUnderstanding(t *testing.T) {
defer cleanup()
setupVisionModel(ctx, t, client, model)
-abbeyRoad, _ := decodeTestImages(t)
+abbeyRoad, _, _ := decodeTestImages(t)
req := api.ChatRequest{
Model: model,
@@ -193,7 +225,7 @@ func TestVisionSceneUnderstanding(t *testing.T) {
Options: map[string]any{"temperature": 0.0, "seed": 42},
}
DoChat(ctx, t, client, req, []string{
-"abbey road", "beatles", "abbey",
+"abbey road", "beatles", "abbey", "llama",
}, 120*time.Second, 30*time.Second)
})
}
@@ -203,6 +235,7 @@ func TestVisionSceneUnderstanding(t *testing.T) {
// objects based on their spatial position in the image.
func TestVisionSpatialReasoning(t *testing.T) {
skipUnderMinVRAM(t, 6)
+skipIfNoVisionOverride(t)
for _, model := range testModels(defaultVisionModels) {
t.Run(model, func(t *testing.T) {
@@ -212,7 +245,7 @@ func TestVisionSpatialReasoning(t *testing.T) {
defer cleanup()
setupVisionModel(ctx, t, client, model)
-_, docs := decodeTestImages(t)
+_, docs, _ := decodeTestImages(t)
// The docs image has: leftmost llama on laptop with glasses,
// rightmost llama sleeping.
@@ -239,6 +272,7 @@ func TestVisionSpatialReasoning(t *testing.T) {
// small details like accessories in an image.
func TestVisionDetailRecognition(t *testing.T) {
skipUnderMinVRAM(t, 6)
+skipIfNoVisionOverride(t)
for _, model := range testModels(defaultVisionModels) {
t.Run(model, func(t *testing.T) {
@@ -248,7 +282,7 @@ func TestVisionDetailRecognition(t *testing.T) {
defer cleanup()
setupVisionModel(ctx, t, client, model)
-_, docs := decodeTestImages(t)
+_, docs, _ := decodeTestImages(t)
req := api.ChatRequest{
Model: model,
@@ -274,6 +308,7 @@ func TestVisionDetailRecognition(t *testing.T) {
// encoding and cross-image reasoning.
func TestVisionMultiImage(t *testing.T) {
skipUnderMinVRAM(t, 6)
+skipIfNoVisionOverride(t)
// Multi-image support varies across models.
skipModels := map[string]string{
@@ -291,7 +326,7 @@ func TestVisionMultiImage(t *testing.T) {
defer cleanup()
setupVisionModel(ctx, t, client, model)
-abbeyRoad, docs := decodeTestImages(t)
+abbeyRoad, docs, _ := decodeTestImages(t)
req := api.ChatRequest{
Model: model,
@@ -314,10 +349,12 @@ func TestVisionMultiImage(t *testing.T) {
}
}
-// TestVisionOCR tests text extraction from an image. The docs image
-// contains the text "Ollama's documentation" in a header.
-func TestVisionOCR(t *testing.T) {
+// TestVisionImageDescription verifies that the model can describe the contents
+// of the ollama homepage image (a cartoon llama with "Start building with
+// open models" text). Basic sanity check that the vision pipeline works.
+func TestVisionImageDescription(t *testing.T) {
skipUnderMinVRAM(t, 6)
+skipIfNoVisionOverride(t)
for _, model := range testModels(defaultVisionModels) {
t.Run(model, func(t *testing.T) {
@@ -327,22 +364,22 @@ func TestVisionOCR(t *testing.T) {
defer cleanup()
setupVisionModel(ctx, t, client, model)
-_, docs := decodeTestImages(t)
+_, _, ollamaHome := decodeTestImages(t)
req := api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
-Content: "What text appears in this image? Read all visible text.",
-Images: []api.ImageData{docs},
+Content: "Describe what you see in this image briefly.",
+Images: []api.ImageData{ollamaHome},
},
},
Stream: &stream,
Options: map[string]any{"temperature": 0.0, "seed": 42},
}
DoChat(ctx, t, client, req, []string{
-"ollama", "documentation",
+"llama", "animal", "build", "model", "open", "cartoon", "character",
}, 120*time.Second, 30*time.Second)
})
}
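The relaxed match lists used by these tests pass if the response contains any of the expected substrings. A minimal sketch of that matching style (a hypothetical helper for illustration, not the actual DoChat harness):

```go
package main

import (
	"fmt"
	"strings"
)

// anyMatch reports whether resp contains at least one of the expected
// substrings, case-insensitively. Hypothetical helper illustrating the
// relaxed match-list style used by the vision tests above.
func anyMatch(resp string, want []string) bool {
	resp = strings.ToLower(resp)
	for _, w := range want {
		if strings.Contains(resp, strings.ToLower(w)) {
			return true
		}
	}
	return false
}

func main() {
	resp := "A cartoon llama invites you to start building with open models."
	fmt.Println(anyMatch(resp, []string{"llama", "animal", "build"})) // true
}
```

Matching on any of several substrings keeps the assertions stable across models that phrase answers differently.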


@@ -383,3 +383,162 @@ yEUu0pztbKtys2RR9bUiUBGoCFQE5oTAL3/5y+ab3/xmc9JJJzWf+cxnmq9+9atzKXmuDGQuNaqFVAQq
VBGoCFQElgKBykCWoptqJSsCFYGKwOIhUBnI4vVJrVFFoCJQEVgKBCoDWYpuqpWsCFQEKgKLh0BlIIvXJ7VGFYGKQEVgKRDYOWr5q6Woaa1kRaAiUBGoCCwU
Av8fgwPy24mbuF8AAAAASUVORK5CYII=
`
// imageEncodingOllamaHome is a 415x293 JPEG of the ollama.com homepage.
// Shows a cartoon llama character with text "Start building with open models".
const imageEncodingOllamaHome = `/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA0JCgsKCA0LCgsODg0PEyAVExISEyccHhcgLikxMC4pLSwzOko+MzZGNywtQFdBRkxO
UlNSMj5aYVpQYEpRUk//2wBDAQ4ODhMREyYVFSZPNS01T09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09P
T09PT09PT0//wAARCAElAZ8DASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUF
BAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVW
V1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi
4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAEC
AxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVm
Z2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq
8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD06iiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiq
2o39rpllLeXsqxQRDLMf5e5oAs0V5XffEXXL6WeXQdOC2dsNzu8ZchfVuwrufCOvDxFocd80YjlDFJUHQMPT270AbdFFFABRRRQA
UUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUU
AFFFFABRRRQAUUUUAFeUeI7u68b+L00PT5CLG2chnHTj7zn+Q/8Ar12XjTxLZ6LpNzD9pUX8sRWGIcsCRjJ9BXmPg/xPJ4djuDa6
V9rnnI3SliMKO3A9eaAO/wDFyWHhfwDPYWSLGJlECDu5PUn1OM1V+HuoaVovhWFb7UbWGa4kaUo0oyAeBkduBXMXc2t/EfV44obc
W1vbLyGJKR56knHJPpXT2fwr0mOMfa7y6mkxyVwg/Ac0AdpZ6lY3wzZ3kE//AFzkDfyq1XmupfDF7YfafD2oypcJyqSnBP0YdKzU
+IWuabp1xpV/bltUjPlpM45X13DufQ96APQtf8U6T4fTF9PmYjKwxjc5/Dt+Nc7pnxP0291GO1ns5rZJWCrKzhgCemR2qn4V8Atd
v/a/ikvNPMd4t3Y9+7n19qzfHsVrfeLNM0PSoIkaHCMIkAwWI449AM/jQB61RSKNqgegxS0AFFFFABRRRQAUUUUAFFFFABRRRQAU
UUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFBIAyTgCsifxBbCYw2UU17KvUQLkD8e
lAGvRSIxZFYqVJGSD2paACub8b+JB4c0fzItrXk5KQKegPdj7D/Cukry7xWn9s/FLT9LnOYItgK9iMbz+fSgCfwh4I/tEDW/Exe4
luD5iQyE8g/xP/hXosFtb20Qit4I4kHRUUAD8qkAAAAGBXC6d4pvtL8XXGieIZ45Yp5M2064AXJ4U47duehoA7pUVc7VAzycDrS1
FPc29uM3E8cQ9XcL/OiG5t7gZt54pR/sOG/lQBLXn/xH8OXt3dWesaNbNLdQnEojGWOOVOO+OlegUUAeXr8ULuCxuLfUNMMeoouI
yMhd3+0p5HrV34ceHbgzSeJNWDNc3GTCH64PVz9e3tW94z8LW3iDTJGWNVv4lJhlA5J/un1BrK+F2tzXumTaXeMTPYkBN3XYeMfg
ePyoA7qiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKK
KACiimyIJInjJIDKRkdRmgDAmeTXrmWJZTDpduSJXBwZiOoz6ClsLq4uGWLQ7OGGwjbBmkBG/wBdoHX61dk0dP7DOl20rRJtC78Z
JGcnP1qle+IbDSsWNvG0hiXZ8mAF9s+tAHK6/wCKdc1nxDJofhTKCElXmXGWI6nJ+6oPFVk8ReKvCGoxReJQ13Zyn7+Q3Hcqw7j0
NSfCJkbUdXZv9aQpyeuMnP613+u6Pa65pcthdr8rjKt3RuzCgC1Z3UF7aRXVrIJIZVDIw7g15p468zQvH2na8ULQPtLY9V4Yf98k
VF4X1u58FaxLoGvZW0LZSTshP8Q/2T+n51FbxXHxF8XySTM6aTaHgA9FzwB/tNjJNAHpd9Aut6G8VpevCl1GCk8J5APORXC63oWk
eCdIW/hha+1SSQJBLcfMFfruC9OP8K9Gt4Ira3jggRY4o1Coq9FA6CuT+Jem3F74fS6tFLS2Mon2gZyuOfy60AYkfhCxAt7rxnqs
0t/ek7IzLtXdjO3d/wDqFY+k6XpOrXwttEfUtK1VC/RvMiUr0y4wRmug8RahB4s8F295ZR/aHtpo5Lq3UZkUDhgO/wCPpVXQtK/t
PxBJP4b/ALQ0jRgi+a24r5zjoFB/Xr+tAGx4V8SajHqz+HPEqhb9B+5m7TD+vHQ96u6z480PSpXt/Oe6uUO0xQLuwfQnpWN43eKb
xx4dgsyDfRygvt6qm4EZ/JjS6h4avtE8XQa1oNot1b3MmLi3IB8sk8kE9B3z2+lAHdWdwt3Zw3Ko6LKgcK4wwyM4I9a828G4i+KO
sxQcRHzuB0++K7/XdWg0XSJ7+5YARr8o7u3YD8a4j4U2E0smoa7cg7rhiiEj73OWP54FAHo9FFFABRRRQAUUUUAFFFFABRRRQAUU
UUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFADJiwhcp94KcfXFcl4NgguJLuS4RZ
JRj74zgHOf1rsK4u4EnhvxCZ1Um0nJOB3U9R9QeaAOdsivhL4pSwP+7s70kL2AV+V/JuK9WriviBoS+INCj1LTsSXNqpdCv/AC0T
uPqOtWPh94mXXdJFvcv/AKdagLICeXXs3+PvQA34mWFnP4VuLueBWuLfHkydCuWA/L2pPhdaR2/hCKZQN9xI7ufXBwP0Fb/iHTf7
Y0K80/IDTRkKT2bqP1ArB+G9lq2m6LNZarbGBYpj5O48kHr+Gen1oA6+ggEYNFFAHF6t4Bie+OoeH76TS7snJEedhP0HT+VVjonj
9h5B8QWwj6eYBhv/AEHNd7RQBzHhfwdb6HO99c3D3uoyZ3Tyds9cf4109FFAHlfxYiu11ewmupXfTGGFjU42sD834kdDXpWlwWtv
pltFYIEtljXygP7uMiuT+LMaN4TV2A3JcptP1Brd8Hu0nhLS2fkm2T+VAGzRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUU
AFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAVWv7GDULVre4XKnoR1U+oqzRQBxSSaj4WuSki+dZ
O3H90/T0PtXK69GNF1mPxP4bfEJfM8GMGJj1BH91v5/hXrsscc0bRyorowwVYZBrl9W8HxTK5sGChgQ0Mn3SPQHtQBpaV4k07UtB
OrrMscEa5mDHmIjqD/nmrOjavY63Yi806XzItxU5UggjsRXiOv6VqXh2WSzJmitrz+DPD4OcH1we9e0eGNKi0bQLSyjXDKgaQ/3n
PJNAGrRRRQAUUUUAFFFFAHmvxW1H7XNY6BaZknaQSOq9ieFH6k13+lWY0/SrWzHSCJY/yFcJ8U9GEKQeIrMmO5idUlZT1/ut9QeP
yrtfD2o/2toNlfnAaeIFsf3uh/UGgDRooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA
KKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigDzD4rZGuaKW/1fP/oQzXpy4IBHTtXC/FnTHutBgvolJazky+OytwT+YFdH
4T1P+1/DdleEEM0e18/3l4P8qANiiiigAooooArajfW+m2E17dvshhUsx/w968zOveL/ABldSJoKNZ2SHG5Ttx/vP6+wrqfibb3F
x4On+zgt5ciSSAd0B5/ofwqD4d65pEvh610+KaKC6hXbJE5Clmzyw9c0Ac3efD3xPPaO0+sJcSEZ8lpnIb8TxWr8M9ddQ/hq/i8m
4tN3l5GCQD8yn3BP5V6CzKqFmYBQMkk8AV5Xpk0er/F57zTObeMszyL0YBNpP4mgD1WiiigAooooAKKKKACiiigAooooAKKKKACi
iigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKAGyIkiMkiK6MMFWGQR9KSKK
OGJYoY1jjUYVVGAPoKfRQAUUUUAFFFFACMoZSrAEEYIPeuJ1r4aaTfyvPYSyWMrHO1BuTP07fga7eigDy9vhtrxUwHX1NueCpaTG
P93pXaeF/DFj4aszFbZknkwZZmHL+3sPatyigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiii
gAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAoo
ooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAK
KKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKA
CiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiii
gAooooAKZPNHbwSTTOEjjUszHoAOSafWR4tt5rrwrqcNsC0rW7bQOp4zj8qAPPdQ8a+IvEeqNY+GIpIoudvlqPMYf3mY8KPy+tJN
p/xIsIzdfabuXb8xVbhZSP8AgPf8Ki+F2vaZpVxeW2oSJbtc7DHM/C8Z+Unt1zXrcM0U8YkgkSRD0ZGBH5igDiPAvjmTWbn+zNWV
EvcExyKNolx1BHZu/wCddJ4n16Pw7pQv5YHmXzFj2owB5zzz9Kz/APhB9LXxH/bkU11Hced52xWUJu78Yzg/XvXH/EvUtdke5sJ7
DZpUcyGK48phuO3+9nB5J/KgD0Lw3rcfiDSU1CKF4VZ2XYxBPB9q1a8j8Cax4lt4LKysdM83THuQHn8hmwCw3fMDjiu+8W+J7bwz
p6zSoZriUkQwg43EdST2AoA3qK8qi8W+OtRiN5Y6Xm2PK+XallI9iTk/hWz4P8ftq2oLper26W92xKo6AhWYfwkHoaAO8orkvH/i
W+8N2tnLYJAzTOyt5qk8AA8YIrnLjx9r1/b28OhWHnzrCrXMkUDSAOR90DsB70Aej6neLp+mXV66F1t4mlKg4JwM4rF8JeLYPFBu
hDaSQfZtmd7A53Z9PpWB4w1jxGmgW0cOn+ZDdaduvZDA37pivzd/lxz1ri/Bmq+INMN5/YGn/a/M2eb+5aTbjOOh46mgD3aq2oyP
Fpt1JG210hdlI7EKafZPLLYwSXCbJnjVpFxjaxHIx9ai1b/kE3n/AFwf/wBBNAHC/DLXtV1i/vo9SvZLhY4lZQwHBJ9hXoteH+Bv
ENv4cGpXUqGWZ4kSGEHBds/yrYu/Gfje2T7dPpYgtTyN9owQD3JOaAPWKK5nwZ4vg8TW8iNGIL2EZkiByCP7y+38qxvGXjPVND8S
RafZx2zQvGjEyIS2SSD0I9KAO/oorgPHPjPVPD+vQ2VjHbNE8CyEyoSclmHYj0oA3/G76zH4ekbQBIbneu7yhlwnOdvv0/Wl8Evr
L+HYm14SC63tt8wYcp23D16/pSeNdautC8Ptf2SxNKJEXEikjB+hFL4K1m617w+l9erEsrSOpEYIGAfcmgDkPAviLWNS8YTWd9fy
TW6pKQjAYBBGOgr06vHfht/yPtx/1zm/9CFeg+MPFVt4ZskZk866mz5UOcZx1JPYfzoA6GuA+KGuapoz6aNMvHt/NEm/aB82NuOo
9zWNaeM/G18DeWemCa2B6R2jMh9s5z+tZXjnxJD4jstLlEZguYDKk8JP3T8uCPY4P5UAenm41SfwNHc2DeZqUlijoxAyzlQSfTPX
8azfh5L4jktLv/hIRcbA6+QbhcP33e+OlXY72XTfh5Be24UywaajqHGRkIOtU/h/4mv/ABJb3sl+kCmB0C+UpHUHrkn0oA6+ivNr
r4hXlh4vubG9S3Gn280iMyxkyEAHAHOMk4FVb3xn4ynja+stHe3sfvK32Zn+X1LH+YGKAPU6K4nwN45OvznT9QijivApZGjyFkA6
8Hoe9b3ifxDa+HNMN3cgu7HbFEpwXb+g9TQBsUV5Nb+N/GWqyvNpenK8KHlYrZnA9i3rVu++Jd5HpC7LSK21WKcJPBMjEbcH5gMg
jkDg9KAPTqKyfCupT6v4cs9QuggmnUlggwvDEcflWpLv8pvKID4O0sOM9qAHUVwngjxlqOta5c6bqsVvG8cZZfKUg7lYAg5J9f0r
b8ba9L4e0Bry2EbXDyLHGJBkZPJ4HsDQB0FFcl4A8U3PiS1u/tywrcW7rxECAVI44JPcGmfEDxXdeG47JLBYXmnLFhKpICjHoR3P
6UAXvHL61H4eZtAEhuPMXf5Qy4TnO33zj8Kk8FvrD+HYW14OLrc2PMGH2dtw9f8A61Z+v+IdU0jwTaaqUt/t0vl+YrIdg3AkjGf6
1DZeJNbv/AX9r2dpFPqJlKrFHEzAgNg8Zz096ALEfji3fxWdA+wyiQTmHzd4xkd8V1MzFYHZTghSR+VeBR3+rr4yN8lnnVftBf7P
5Z+/zkbc5r17wxqGsajodzNrtn9kuFdlVPLKZXaOcE+pNAHL/DPxDq+r61dQ6lfSXEaW5dVYDg7gM8D3r0qvC/Auuw+H7y+vJUMs
jW/lwxL1kcuuBW5f+M/G1ov2y50wW1sTwHtWCj0BJOaAPWKK5zwb4rh8TWTsYxDdwECWIHI56MPb+VdHQAUUUUAFFFZHiqbUrbw7
d3Gjvtu4VDr8gbIB+YYPtmgDC8QfDjS9Vne5s5XsZ3JLbFDIx9dvb8DXJ3Hw88TaUxn0q7SUryPIlMT/AJHH863vBPj+O8SS18Q3
kcdzvzFM4CIy+mRwCD6+tdnPrWlW8Jmm1G0SMDO4zL/jQB554O8canDrEejeIC0m+TyVkkXEkb5wA3qM8c81vfFf/kUB/wBfKfyN
cNdTJ4o+JUc2lxnypbiMhsYJVAMv7cKTXc/FZSfB+QOlyhP60ASfC3/kTYf+u0n86t+L9P8ADMqRX3iVgqxgpGTKy574CqeTVD4V
3EL+ElhWRDJHM4dc8jJyOK5D4nSNJ41hhvXdbVI49uOyE/MR79fyoA6pviZ4bto1ighvHRAFUJEAAB0AyRXCXGq22rfEW21KwieG
Oa8gIVwAc5UE8epFer2ln4XstNWa3h0xLVVyJSEII9Sx615Tf6ha6n8R4LqxQLbNeQrHhdoIUqM498ZoA634xf8AIP0z/rq/8hXS
eAbWG18Haf5KBTLH5rkdWYnqf5fhXN/GL/kH6Z/12f8AkK6rwV/yJ+lf9e60AT+Kv+RV1b/r0l/9BNcL8Gvvav8ASH/2eu78UKW8
L6qqjJNpLx/wE15/8HbiGO51OB5FWSRY2RScFgN2cfmKAPVKqat/yCbz/rg//oJq3VTVv+QTef8AXB//AEE0AeS/CrTYL3xHJc3C
BxaRb0BGfnJwD+HNexyIkkbRyKGRgQysMgj0NeH/AA912HQtfL3h22twnlSPjhDnKk+3H617Be6/pNlYteT6hb+SFyCsgYt7ADqa
APK9Aj/sP4qfY7c4hFy8AH+wwOB/L8qk+J3/ACPFt/1wi/8AQjTPBizeIPiK+qFCsaSPcv8A7IOQo/Mj8jUvxYikg8U2t1j5Ht12
ntlWOR+o/OgD2CvHvix/yN1r/wBeqf8AobV6bp/iDStQ0+O9hvrcRsoZg0gBT1DA9CK8f8eazb634qM9m2+CFFhR+z4JJI9sk0Ae
hfFL/kTW/wCu8f8AWl+Fv/InRf8AXaT+dJ8Uv+RNb/rvH/Wl+Fv/ACJ0X/XaT+dAHHfDf/kfbj/rnN/6EKZ41B1f4mJp8jHy/Mht
h7KcE/qxp/w2/wCR9uP+uc3/AKEKPiJBPo/jqHVkQlZTHPGexZMAj9B+dAHr0EMVvAkECLHFGoVFUYCgdBXkvxb02C11m1vYUCNd
xt5mB1ZSOfrgj8q9K03xDpOpWCXdvfQBCuWDyBWT2YHpXk/xJ1+31vWYksW8y1tEKCUdHYnLEe3QUAegXv8AySs/9gpf/RYrD+Dn
/Hnqn/XSP+TVuXv/ACSs/wDYKX/0WKw/g5/x56p/10j/AJNQBzwtIb74tSW1wgeJr9yynocZOD+Ve0dBXj1j/wAlkb/r+l/k1exd
qAPGtPhSy+LoitwERb1wqjgAEHj9am+LNxJceJbSzB+SK3UqP9pmOT+gpsf/ACWM/wDX8f8A0Grfxd02WPULLVUU+W8fksw/hYEk
fmCfyoA9M0uwg0vToLK2QLFCgUYHX1P1PWvPfjBp0CxWOpogWZnMLsP4hjIz9MH866vw74t0vV9Lime9ghuAg86KSQKVbv16j3rg
vif4jtdWmt7DTpBNBasWklTlS54AB74Gfz9qAO7+H3/IkaZ/uN/6G1dHXOfD7/kSNM/3G/8AQ2ro6APJbhf+Ef8Ai/G4+WG5nDex
Eowf/HifyrQ+JrvqfiDRtBhPLsGYD1Zto/IA0nxds2ifTdWh4ZGMLN6H7y/+zVF4YnHif4lz6wATBbQ7kyOh2hQPzLGgA8Mxr4d+
KN7pSjZb3SsIl7YI3r+mRVbxoDrvxKs9KBykflxMPQH52P5H9Kv/ABJjOl+JdF1+MYCuFkI/2Gz+oJH4VW8BL/bfj7VNbOTHGXdC
R0LnC/8AjoNAG/8AFUAeD8AYH2iP+tTfC/8A5Eu3/wCusn/oVRfFb/kUP+3mP+tSfC//AJEu3/66yf8AoVAHFW//ACWI/wDX+/8A
I17Bcf8AHtJ/uH+VeO+bHa/F1pLh1jQX5yzHAGen8xXsMxDWshBBGw9PpQB498KLKG68UvLMgY21u0keR0bIGfyJr2G5t4ru2ltr
hA8UqlHU9CCMGvJvg/8A8jDef9eh/wDQ1r16gDx34WlrfxlcQKx2mCRT74Yf4V7FXjvw2/5Hy4/65Tf+hCvYqACiiigAooooA4zX
Phxo+qXD3Ns8ljM5y3lAFCfXaen4EVjx/CWMPmXWXZPRbcA/nur0uigDF8O+FtL8OxsLCJjM4w80hy7D09h7Cr+qadbatp01jepv
gmXDAHBHcEe4PNW6KAPPrH4YQ2Or217Dq0hSCZZRG0IydpBxkH29K6XxN4W07xJAi3geOaP/AFc0f3l9vce1blFAHndp8KLGO4D3
epTTxA58tYwmfqcmtS68AWE2vQanDcSQLA0RSBEG0BMYHr2rsKKAMDxX4Xh8TQW8U9zJAIGLAooOcjHetPSNPTStKtrCORpFt4wg
ZhgnFXKKAEdFdGR1DKwwQehFee3/AMKrKa5aSx1KW2jY5EbRiTb7A5HFeh0UAQ2cH2Wygtg2/wAmNU3YxnAxmluoRc2ssBYqJUZC
R2yMVLRQBxulfDvTLG3vLe4nlu4rpFUh1ClCDkMCOhrKb4TWpnJXV5hD2Uwgt+ecfpXo9FAGXoGg6f4fsvs2nxEBjl5GOXc+pNJ4
h8P2HiKx+y36N8p3RyIcMh9R/hWrRQB5xH8JrUTgy6vM0WeVWEK355P8q0tU+HGmX0tsbe4ltI7eERKiKDnBJ3EnqSTXa0UAZPiT
Q4vEGknT5p3hUur7kAJ4+tL4b0SPw/pK2EMzzKrs+5wAefpWrRQBynh/wRbaHrb6nFezSu6suxlAHzHPatrW9EsNdsTaajFvTOVZ
ThkPqD2rRooA83/4VNa+fkavN5Ofu+SN355x+la2pfDrSruws7O2mltY7XecqAzSM2Mlie/y12VFAGZNo0cvhr+xDM4j+zC38zA3
YC4zjpVLwn4Wg8MRXMcFzJOLhlJ3qBjGfT610FFAHKQ+CLaHxYdfF7MZTM0vlFRtyQeM9e9dX2oooA5RfBFsviv+3/ts3m+cZfK2
jbnGMZ61Y8Ya1o2m2iWmu2001vdqQAse5TjHfIweQa6OqOr6TY61YtZ6hCJImOR2Kn1B7GgDgrP4feG9ZiW90vVbg20nzbAVYp7H
IyD9ayvH0Gh6NpVpoejlWmWbzp2Dbm4Ugbj689O1a1z8J4/NJstYkjQ/wyQ7jj6gjP5VpaF8NNL025S5vp3v5EOVVlCJn1I5z+Jx
QBueCbaSz8IaZDMpVxDuIPUbiW/rW5RRQBl+I9Eg8QaS+n3EjRhmVw6gEqQff8R+NU/CfhS28MR3K288k7XBUszqAQBnA4+proKK
AMfxP4ft/EemCyuJGiCyCRXQAkEZHf2JqLwp4YtvDNrPDbzPMZnDs7gA8DAHH4/nW7RQBkeJtCi8RaV9gmneFfMV9yAE8Z9frT/D
mix6BpCafDM8yIzNucAHk57VqUUAcj4o8BWHiC9N8tw9pdMAHZVDK+OASOOfxq/4V8NDw7pE1h9rNz5shkL7NuMqBjGT6Vv0UAct
4V8FW3hq/lu4LyadpYvLKuoAHIOePpXU0UUAcp4f8EW2ha0+pxXs0rurLsZQB8xz2rq6KKACiiigAooooAKKKKACiiigAooooAKK
hvLuCxtJbq6kEcMKl3Y9gK80uviHrmrX7W/hrTMoM4JjMkhHqQOFoA9Rory23+IWvaRfLB4l0zCN1xGY3A9Rng/55r0qyvYL+xiv
LSQSQzJvRh3FAFiivLdG+J10Zbp9Yit/KihLRJCpVpJNwAXJJ4wSfwqK78ZeNljN+NK8iz6jNqxUD1JPP48UAer0VyXgrxrD4kD2
1xEtvfRruKKcrIvqv+Fa/iTxBaeHdNN3d5dmO2KJTzI3p7D1NAGtRXlUPjDxtq+660rSx9mB48u3Lj6bieT9K1/CvxAe+1FdK122
W1u2bYjgFQW/usp5U0Ad9RXN+Otdu/D2hpe2KxNI06xkSqSMEE9iPSuWi+Ier32m29vpenLd6q4ZpvLiYpENxAwM8nGD1xQB6bRX
kqfELxLpOoLHrlguw8tE8JifHqp//XXqWn3sGpWEF7atuhnQOh9j6+9AFiiivOdT8f3um+MptOuFtl0+GYK7eWS+3GT36/hQB6NR
XleoeNfF8kbahaaS1tp33lZrdnG31LH+YwK0rD4mwSaBLPdWw/tKNhGlvGTiUnoR3A4569vWgD0KivJr7xr41sdt5eaatvbMeBJa
sF+mSc/rXeeEfEsPiXSzcJH5U8TbJos52nsQfQ0AbtFcn4y8bW/hsrawRC5vnXdsJwqD1Y/0rlB4t8dm3+3jS/8ARsbs/ZG249eu
ce9AHq9Fcj4N8cW/iNzaXEQtr5VLBAcrIB1K+/tWz4j1608PaW17d5Y52xxqeZG9P/r0Aatcl8SdUvtJ8PQ3GnXLQStcqhZQORtY
45+grkoPHPjDVpnk0nTleJDysVuZAPYt6/lVfxZ4sOu+GBY39sbTU7a6QyREEBhtbkA8jqOD60AeheBL661LwlZ3d9M008hfc7Yy
cOQOnsK6CuW+Gv8AyI9h9ZP/AEY1Y3iX4hTQ6k2leHLVbq4VijSFS4Ldwqjr9aAPQqD0ryuTxn4z0YpPrGlg27HB8yAp+G4dD9a9
B8P65aeINLS+syQD8ro33o27g0AcKviLWD8T/wCzDfyfYvtZTycDG3HTpmvTR0rx5P8Aksf/AG/H/wBBr0Lxb4nt/DOnJNJGZriY
lYYgcbiOpJ7AUAb9FeXQ+JvH2owfbrLS0+zNyuy3yGHtk5P4VueDfHX9t3h0zU7dba/AO3bkK5HUYPII9PrQB2tFFFABRRRQAUUU
UAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAcH8XLuSHw9bWqHAuLj5/cKM4/PH5Vo/DbTobLwjbSoo826zLI2OTyQB+AH86rfFLTJ
b7wwLiBSzWcolYAZ+TBBP4cH8Kr/AA08SWU2gxaVcXEcV1a5VVdgvmITkEZ64zjHtQBqfEXT4b7whdvIgMlsPOjbHKkHn8xmsf4R
3ckug3tq7ZW3mynsGHT8wfzqf4k+JLK30CbTILiOW7ugE2IwbYucknHTpj8ad8LNMlsvDMl1MpVryTegIx8gGAfx5oA4f4babDqP
i5PtCB0to2n2sMgkEAfqc/hXtxAKkEZB6g14N4J1qPQvE0d3cZ+zuGimYDO1T3/AgV7TLrukw2RvH1K1Fvt3bxKCCPbHX6UAeUPC
nh/4sRw2g8uIXiBVHQLIBkfTDEVN8VrszeKoLWVm8i3hXhevzHJP1xj8qh0d38VfE5b6KNhELgXByPuxpjbn8lH41ofFaxmtNfs9
XRN0UiKpJGQHQ5wfqP5GgDWtvib4ftLaO2t9PvkhiUKihEwAP+BVxnjbX9O17VLfUNMgngmVNsrSBQWIPyngnn/61esaNfeH9YsI
7q1Sy+ZQXQqgaM9wRWDrnjDQtN1NLCy0uDUZW4byQmAxOAucHJ+lAEPxHuDd/D/T7liCZpYZDj3jY1f+FlpDB4RjuEQCW5kdpG7n
BKgfp+tV/iqMeDrcbBHi5j+QdF+VuK0Phn/yJFl/vSf+hmgDP+LdtFJ4ZhuGUeZFcqFbuAQcj9B+VXvhi7N4JtAxzteQD6bzVf4r
/wDIoD/r5T+Rqb4X/wDIlW3/AF0k/wDQjQB1x6V4xqVtFd/FtredQ8T3qBlPQjA4r2c9K8euP+SyD/r+T+QoA9fZVZCjKCpGCCOC
K8a+H9pA3xAZGjBW381owecEHA/LNez9q8f+Hv8AyUO5/wB2f/0KgD07xHDHceHNSilUMhtZOD6hSQfzFeffBtj9p1Vc8FIjj8Wr
0XXf+QDqP/XrL/6Ca85+Dn/H3qv/AFzi/m1AGZoUS+IPifJJegSR+fJKVIyCEztH04X8q9nrxdpG8HfEt57pW+zGZmzjrFJnkeuM
/pXrQ1jTDZ/bBqFr9n27vM81cYoA8m8TRJoHxLjnsgI1MsU4UDgbvvD6Hn86ufGC6d9asbPPyRW5kA92Yj+SiqU0p8ZfEmN7NWNt
5qYbHSJMZY+mefzFavxg0+QXVjqaqTGUMDn+6QSw/PJ/KgD0XRNOh0rSLayt0CpFGAcfxHHJPuTXC/F/TYfsNnqioBMJfIdh1ZSC
Rn6YP510/hfxRp2saPBIbqGO5RAs0TuFZWA5OD2PXNcR8UvEVpqIt9LsJVnSB/MmkQ5UNjAUHv1NAG54au3sPhI11EcSRQTlD6He
2P1rivAviLSvDlzc3WoW1xNcSKEiaNVOwfxdSOvH5V3XhCyOpfCxbEHDTwzop9CXbH61y3w41Gw03UrzS9bjhiaVgEadR8jrkFST
0z/SgDb1D4k+H9QsJ7O4sL5op0KMCid/+BVl/B+6ddWv7Pd8kkAlx7qwGf8Ax6u+1S78P6VYvd3a2Soq5UBEJc+gHc1n+DPENt4g
e5ktNHFmkICmUFfmJ/h4A9M/lQBxKf8AJY/+34/+g1vfFfRb2+trPULSJ5ktgyyooyVBwQ2PTjn8KwU/5LH/ANvx/wDQa7fxZ4yH
hi6t4ZdOknSdCyyLIFGQcEdPp+dAHP8Ah74mWEGn21nqlpNE8Max+ZCAykAYzjgj9a3tItvCOt6yda0xklv1bzWIkdWBxjJQ/wCF
XTpvhnxHZLetZ2dxHKu7zQArD6kYINeX20EOm/Ey3t/D87SwJdoisrbvlON657gfMPwoA9vooooAKKKKACiiigAooooAKKKKACii
igAooooAKKKKACiiigBGVXUqwBUjBBHBrhtX+GGlXtw01jcS2Jc5MaqHQH2HBH513VFAHB6T8L9Ks7hZr+5lvdpBEZUIhPuOSfzr
ugqpFsRQqqMAAYAFOoIyCKAPFfhpY22p67f2V7EJYJbJwyn/AH059j710k3wntGuC0GqzRw54RogzD8cj+VbPhXwPD4b1SS+jv5L
gyRGLa0YXGSDnr7V1tAGP4c8Nad4ctmisUYySY8yaQ5d/wDAe1XtS0601Wyks7+FZoJByp/mD2PvVqigDzm6+E9m8pa01WaKMnhZ
Ig5H45FbnhvwHpWg3C3eXu7tfuySgAJ7qo6H35rqqKAMfxRoEXiPTEsZp3hVZRJuQAngEY5+tS+HtHj0HR4dOileVIixDsACcknt
9a06KAMjxNoUXiLSxYTTvCvmCTcgBPGfX60/w5osegaRHp8MzzIjMwZwAeTntWpRQAVykngi2fxYNfN7MJfOEvlbRtyB0z1rq6KA
CuV0LwTbaLrsmqxXs0ruHBRlAHzHPauqooAhvbcXdjcWrMVE0bRlh1GRjP61geE/CFv4YluZILuWc3CqCHUDGM+n1rpaKAMjxD4b
03xFbrFqER3p/q5UOHT6H09jXG/8Klt/Nz/bEvl/3fIGfzz/AEr0migDH8O+GdM8OwNHYRkyP/rJpDl3/HsPYVf1GwtdTspLO+hW
WCUYZT/Meh96s0UAec3Hwns2nLW2qzRxE/ceIOQPrkfyrUl+HWlf2ENMglmiJlWWSfAZ5CAQAewHzHgV2VFAGb4f0iPQtGh02KVp
Uh3YdgATlie31rI8SeBdK1+c3TF7W7b70sQGH/3gev1rqaKAPObb4T2iTBrnVZpY8/dSIISPqSf5V3emabZ6TYpZ2EKwwp0A7n1J
7n3q3RQByg8EWw8V/wBv/bZvN87zfK2jbnHTPWtrW9EsNdsTaajFvTOVYHDIfUHtWjRQB5vJ8J4PMPk6zMkZ/haEE4+oI/lXSeGf
Bel+HXM8Iee7Ix50uMqO4UDgfzrpKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiig
AooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooo
oAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKK
KKACiiigAooooAKKKKACiiigAooooAKKKKAP/9k=`


@@ -0,0 +1,121 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Fri, 20 Mar 2026 18:50:38 -0700
Subject: [PATCH] CUDA get_rows q6_k support
---
ggml/src/ggml-cuda/getrows.cu | 80 ++++++++++++++++++++++++++++++++-
ggml/src/ggml-cuda/ggml-cuda.cu | 1 +
2 files changed, 80 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
index 2fab33243..dc5c4f57a 100644
--- a/ggml/src/ggml-cuda/getrows.cu
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -155,6 +155,81 @@ static void get_rows_cuda_float(
s10, s11, s12/*, s13*/);
}
+// Specialized GET_ROWS kernel for Q6_K — the k_get_rows template doesn't work for K-quants
+// because they lack the simple dequantize_kernel_t (float2) interface.
+// Based on dequantize_block_q6_K from convert.cu with row-selection logic added.
+template<typename dst_t>
+static __global__ void k_get_rows_q6_K(
+ const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+ const int64_t ne00,
+ const int64_t ne11, const int64_t ne12,
+ const size_t s1, const size_t s2, const size_t s3,
+ const size_t nb01, const size_t nb02, const size_t nb03,
+ const size_t s10, const size_t s11, const size_t s12) {
+
+ const int64_t i10 = blockIdx.x; // row index into src1
+ const int64_t z = blockIdx.z;
+ const int64_t i11 = z / ne12;
+ const int64_t i12 = z % ne12;
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const char * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+ const int64_t nb = ne00 / QK_K; // number of Q6_K blocks per row
+
+ // blockIdx.y iterates over Q6_K blocks within the row
+ for (int64_t iblk = blockIdx.y; iblk < nb; iblk += gridDim.y) {
+ const block_q6_K * x = (const block_q6_K *)src0_row + iblk;
+
+ // Same dequantization as dequantize_block_q6_K (assumes 64 threads)
+ const int64_t tid = threadIdx.x;
+ const int64_t ip = tid / 32; // 0 or 1
+ const int64_t il = tid - 32*ip; // 0..31
+ const int64_t is = 8*ip + il/16;
+
+ const int64_t y_offset = iblk * QK_K + 128*ip + il;
+
+ const float d = x->d;
+ const uint8_t * ql = x->ql + 64*ip + il;
+ const uint8_t qh = x->qh[32*ip + il];
+ const int8_t * sc = x->scales + is;
+
+ if (y_offset + 0 < ne00) dst_row[y_offset + 0] = ggml_cuda_cast<dst_t>(d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32));
+ if (y_offset + 32 < ne00) dst_row[y_offset + 32] = ggml_cuda_cast<dst_t>(d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32));
+ if (y_offset + 64 < ne00) dst_row[y_offset + 64] = ggml_cuda_cast<dst_t>(d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32));
+ if (y_offset + 96 < ne00) dst_row[y_offset + 96] = ggml_cuda_cast<dst_t>(d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32));
+ }
+}
+
+template<typename dst_t>
+static void get_rows_cuda_q6_K(
+ const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
+ const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+ const size_t nb1, const size_t nb2, const size_t nb3,
+ cudaStream_t stream) {
+ const int64_t nb_blocks = ne00 / QK_K;
+ const dim3 block_dims(64, 1, 1);
+ const dim3 block_nums(ne10, MIN(nb_blocks, (int64_t)UINT16_MAX), MIN(ne11*ne12, (int64_t)UINT16_MAX));
+
+ const size_t s1 = nb1 / sizeof(dst_t);
+ const size_t s2 = nb2 / sizeof(dst_t);
+ const size_t s3 = nb3 / sizeof(dst_t);
+
+ const size_t s10 = nb10 / sizeof(int32_t);
+ const size_t s11 = nb11 / sizeof(int32_t);
+ const size_t s12 = nb12 / sizeof(int32_t);
+
+ k_get_rows_q6_K<<<block_nums, block_dims, 0, stream>>>(
+ src0_d, src1_d, dst_d,
+ ne00, ne11, ne12,
+ s1, s2, s3,
+ nb01, nb02, nb03,
+ s10, s11, s12);
+}
+
template <typename dst_t>
static void ggml_cuda_get_rows_switch_src0_type(
const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
@@ -199,8 +274,11 @@ static void ggml_cuda_get_rows_switch_src0_type(
get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
+ case GGML_TYPE_Q6_K:
+ get_rows_cuda_q6_K(src0_d, src1_d, dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
default:
- // TODO: k-quants
GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
break;
}
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 5c9dfd032..b8ed3709b 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4693,6 +4693,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q6_K:
return true;
default:
return false;
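The per-element math in the kernel above can be checked in isolation: each weight is a 6-bit value assembled from a 4-bit low part (ql) and a 2-bit high part (qh), recentered by 32 and scaled. A standalone Go sketch of the low-nibble case (illustrative only, not the CUDA path):

```go
package main

import "fmt"

// dequantQ6K mirrors one line of k_get_rows_q6_K for the low-nibble case:
// combine 4 low bits from ql with 2 high bits from qh into a 6-bit value
// in [0,63], recenter by -32, then scale by the block scale d and the
// per-group sub-scale sc. The high-nibble variants use ql>>4 instead.
func dequantQ6K(d float32, sc int8, ql, qh uint8, qhShift uint) float32 {
	q := int8((ql & 0xF) | (((qh >> qhShift) & 3) << 4))
	return d * float32(sc) * float32(q-32)
}

func main() {
	// ql nibble = 15, qh bits = 3 -> q = 15 | 48 = 63; 0.5 * 2 * (63-32) = 31
	fmt.Println(dequantQ6K(0.5, 2, 0x0F, 0x03, 0)) // 31
}
```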


@@ -678,3 +678,113 @@ func ImageEditsMiddleware() gin.HandlerFunc {
c.Next()
}
}
// TranscriptionWriter collects streamed chat responses and outputs a transcription response.
type TranscriptionWriter struct {
BaseWriter
responseFormat string
text strings.Builder
}
func (w *TranscriptionWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
var chatResponse api.ChatResponse
if err := json.Unmarshal(data, &chatResponse); err != nil {
return 0, err
}
w.text.WriteString(chatResponse.Message.Content)
if chatResponse.Done {
text := strings.TrimSpace(w.text.String())
if w.responseFormat == "text" {
w.ResponseWriter.Header().Set("Content-Type", "text/plain")
_, err := w.ResponseWriter.Write([]byte(text))
if err != nil {
return 0, err
}
return len(data), nil
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
resp := openai.TranscriptionResponse{Text: text}
if err := json.NewEncoder(w.ResponseWriter).Encode(resp); err != nil {
return 0, err
}
}
return len(data), nil
}
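The accumulation logic in Write can be sketched independently of the gin plumbing: each streamed chunk carries a content delta, and the final transcription is the trimmed concatenation. The chunk type below is a simplified stand-in for api.ChatResponse:

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// chunk is a simplified stand-in for the streamed chat response: only the
// content delta and the done flag matter for transcription assembly.
type chunk struct {
	Content string `json:"content"`
	Done    bool   `json:"done"`
}

// assemble concatenates streamed content deltas and returns the trimmed
// transcription once a chunk with Done set arrives.
func assemble(raw []string) (string, error) {
	var text strings.Builder
	for _, r := range raw {
		var c chunk
		if err := json.Unmarshal([]byte(r), &c); err != nil {
			return "", err
		}
		text.WriteString(c.Content)
		if c.Done {
			return strings.TrimSpace(text.String()), nil
		}
	}
	return "", errors.New("stream ended without done")
}

func main() {
	out, _ := assemble([]string{
		`{"content":"Hello, "}`,
		`{"content":"world. ","done":true}`,
	})
	fmt.Println(out) // Hello, world.
}
```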
// TranscriptionMiddleware handles /v1/audio/transcriptions requests.
// It accepts multipart/form-data with an audio file and converts it to a chat request.
func TranscriptionMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Parse multipart form (limit 25MB).
if err := c.Request.ParseMultipartForm(25 << 20); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to parse multipart form: "+err.Error()))
return
}
model := c.Request.FormValue("model")
if model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
return
}
file, _, err := c.Request.FormFile("file")
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "file is required: "+err.Error()))
return
}
defer file.Close()
audioData, err := io.ReadAll(file)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, "failed to read audio file"))
return
}
if len(audioData) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "audio file is empty"))
return
}
req := openai.TranscriptionRequest{
Model: model,
AudioData: audioData,
ResponseFormat: c.Request.FormValue("response_format"),
Language: c.Request.FormValue("language"),
Prompt: c.Request.FormValue("prompt"),
}
chatReq, err := openai.FromTranscriptionRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
c.Request.ContentLength = int64(b.Len())
c.Request.Header.Set("Content-Type", "application/json")
w := &TranscriptionWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
responseFormat: req.ResponseFormat,
}
c.Writer = w
c.Next()
}
}
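A client exercises this middleware with an ordinary multipart upload. A sketch of building such a request in Go (the localhost URL, model name, and audio bytes are placeholder assumptions; the request is constructed but not sent):

```go
package main

import (
	"bytes"
	"fmt"
	"mime/multipart"
	"net/http"
)

// newTranscriptionRequest builds a multipart request for the endpoint
// handled by TranscriptionMiddleware. Field names ("model", "file",
// "response_format") match what the middleware reads from the form;
// baseURL and the audio payload are caller-supplied placeholders.
func newTranscriptionRequest(baseURL, model string, audio []byte) (*http.Request, error) {
	var body bytes.Buffer
	mw := multipart.NewWriter(&body)
	mw.WriteField("model", model)
	mw.WriteField("response_format", "text")
	fw, err := mw.CreateFormFile("file", "clip.wav")
	if err != nil {
		return nil, err
	}
	fw.Write(audio)
	mw.Close()

	req, err := http.NewRequest(http.MethodPost, baseURL+"/v1/audio/transcriptions", &body)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", mw.FormDataContentType())
	return req, nil
}

func main() {
	req, _ := newTranscriptionRequest("http://localhost:11434", "some-audio-model", []byte("RIFF..."))
	fmt.Println(req.Method, req.URL.Path) // POST /v1/audio/transcriptions
}
```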


@@ -137,6 +137,7 @@ type Tensor interface {
Bytes() []byte
Floats() []float32
BackendGet() []float32
FromBytes([]byte)
FromFloats([]float32)
@@ -162,6 +163,7 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
Conv1DDW(ctx Context, weight Tensor, s, p, d int) Tensor
SSMConv(ctx Context, kernel Tensor) Tensor
SSMScan(ctx Context, x, dt, A, B, C, ids Tensor) Tensor
@@ -187,6 +189,9 @@ type Tensor interface {
Contiguous(ctx Context, shape ...int) Tensor
Pad(ctx Context, shape ...int) Tensor
// PadExt pads with independent left/right amounts per dimension.
// Arguments: lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 for dims 0-3.
PadExt(ctx Context, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 int) Tensor
Stack(ctx Context, dim int, s ...Tensor) Tensor
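PadExt's per-side padding can be illustrated in one dimension with plain slices. A reference sketch of the dim-0 behaviour (the real method pads up to four dimensions on the backend tensor):

```go
package main

import "fmt"

// padExt1D mimics Tensor.PadExt along dim 0 for a float slice: lp zeros
// on the left and rp zeros on the right of the original data.
func padExt1D(x []float32, lp, rp int) []float32 {
	out := make([]float32, lp+len(x)+rp) // zero-initialized
	copy(out[lp:], x)
	return out
}

func main() {
	fmt.Println(padExt1D([]float32{1, 2, 3}, 2, 1)) // [0 0 1 2 3 0]
}
```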


@@ -1069,6 +1069,21 @@ func (t *Tensor) Floats() (data []float32) {
return
}
// BackendGet reads the tensor's contents back from backend memory into a
// float32 slice, synchronizing any pending backend work first.
func (t *Tensor) BackendGet() []float32 {
n := int(C.ggml_nelements(t.t))
if n == 0 {
return nil
}
if t.sync != nil {
t.sync()
}
data := make([]float32, n)
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
return data
}
func tensorSet[S ~[]E, E byte | float32 | int32](t *Tensor, s S) {
if len(s) == 0 {
return
@@ -1313,6 +1328,13 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
}
}
func (t *Tensor) PadExt(ctx ml.Context, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_pad_ext(ctx.(*Context).ctx, t.t, C.int(lp0), C.int(rp0), C.int(lp1), C.int(rp1), C.int(lp2), C.int(rp2), C.int(lp3), C.int(rp3)),
}
}
// Permute permutes t according to order. Permute panics if the number of dimensions
// in order does not match the number of dimensions in t.
func (t *Tensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
@@ -1660,6 +1682,13 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
// Conv1DDW applies a depthwise 1D convolution of weight over t with
// stride s, padding p, and dilation d.
func (t *Tensor) Conv1DDW(ctx ml.Context, weight ml.Tensor, s, p, d int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_conv_1d_dw(ctx.(*Context).ctx, weight.(*Tensor).t, t.t, C.int(s), C.int(p), C.int(d)),
}
}
func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor {
var tt ml.Tensor = &Tensor{
b: t.b,

View File

@@ -155,6 +155,81 @@ static void get_rows_cuda_float(
s10, s11, s12/*, s13*/);
}
// Specialized GET_ROWS kernel for Q6_K — the k_get_rows template doesn't work for K-quants
// because they lack the simple dequantize_kernel_t (float2) interface.
// Based on dequantize_block_q6_K from convert.cu with row-selection logic added.
template<typename dst_t>
static __global__ void k_get_rows_q6_K(
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00,
const int64_t ne11, const int64_t ne12,
const size_t s1, const size_t s2, const size_t s3,
const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12) {
const int64_t i10 = blockIdx.x; // row index into src1
const int64_t z = blockIdx.z;
const int64_t i11 = z / ne12;
const int64_t i12 = z % ne12;
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const char * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int64_t nb = ne00 / QK_K; // number of Q6_K blocks per row
// blockIdx.y iterates over Q6_K blocks within the row
for (int64_t iblk = blockIdx.y; iblk < nb; iblk += gridDim.y) {
const block_q6_K * x = (const block_q6_K *)src0_row + iblk;
// Same dequantization as dequantize_block_q6_K (assumes 64 threads)
const int64_t tid = threadIdx.x;
const int64_t ip = tid / 32; // 0 or 1
const int64_t il = tid - 32*ip; // 0..31
const int64_t is = 8*ip + il/16;
const int64_t y_offset = iblk * QK_K + 128*ip + il;
const float d = x->d;
const uint8_t * ql = x->ql + 64*ip + il;
const uint8_t qh = x->qh[32*ip + il];
const int8_t * sc = x->scales + is;
if (y_offset + 0 < ne00) dst_row[y_offset + 0] = ggml_cuda_cast<dst_t>(d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32));
if (y_offset + 32 < ne00) dst_row[y_offset + 32] = ggml_cuda_cast<dst_t>(d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32));
if (y_offset + 64 < ne00) dst_row[y_offset + 64] = ggml_cuda_cast<dst_t>(d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32));
if (y_offset + 96 < ne00) dst_row[y_offset + 96] = ggml_cuda_cast<dst_t>(d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32));
}
}
template<typename dst_t>
static void get_rows_cuda_q6_K(
const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
const int64_t nb_blocks = ne00 / QK_K;
const dim3 block_dims(64, 1, 1);
const dim3 block_nums(ne10, MIN(nb_blocks, (int64_t)UINT16_MAX), MIN(ne11*ne12, (int64_t)UINT16_MAX));
const size_t s1 = nb1 / sizeof(dst_t);
const size_t s2 = nb2 / sizeof(dst_t);
const size_t s3 = nb3 / sizeof(dst_t);
const size_t s10 = nb10 / sizeof(int32_t);
const size_t s11 = nb11 / sizeof(int32_t);
const size_t s12 = nb12 / sizeof(int32_t);
k_get_rows_q6_K<<<block_nums, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, ne11, ne12,
s1, s2, s3,
nb01, nb02, nb03,
s10, s11, s12);
}
template <typename dst_t>
static void ggml_cuda_get_rows_switch_src0_type(
const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
@@ -199,8 +274,11 @@ static void ggml_cuda_get_rows_switch_src0_type(
get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q6_K:
get_rows_cuda_q6_K(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
default:
// TODO: remaining k-quants (Q6_K is handled above)
GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
break;
}
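The per-element math in the Q6_K kernel above can be sanity-checked with a plain-Go sketch of the same formula (a hypothetical helper, not part of this patch): low 4 bits from `ql`, high 2 bits from `qh`, recentered by 32, scaled by the block scale and the super-block scale.

```go
package main

import "fmt"

// dequantQ6K reconstructs one Q6_K value from its packed parts, mirroring the
// kernel above: low 4 bits from ql, high 2 bits from qh shifted into place,
// minus the 32 offset, scaled by the block scale sc and super-block scale d.
func dequantQ6K(d float32, sc int8, ql, qh uint8, qhShift uint) float32 {
	q := int8((ql & 0xF) | (((qh >> qhShift) & 3) << 4))
	return d * float32(sc) * float32(q-32)
}

func main() {
	// ql=0xF, qh bits=3 → q = 0x3F = 63; centered: 63-32 = 31
	fmt.Println(dequantQ6K(0.5, 2, 0xF, 0x3, 0)) // 0.5 * 2 * 31 = 31
}
```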

View File

@@ -4693,6 +4693,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q6_K:
return true;
default:
return false;

View File

@@ -47,6 +47,12 @@ type Validator interface {
Validate() error
}
// PostLoader is an optional interface that models can implement to run
// initialization steps after backend weights have been loaded.
type PostLoader interface {
PostLoad() error
}
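A loader would typically invoke this optional hook via a type assertion. The sketch below is illustrative only (the `runPostLoad` helper and `myModel` type are hypothetical, not part of the patch):

```go
package main

import "fmt"

// PostLoader mirrors the optional interface above: models that implement it
// get a hook after backend weights have been loaded.
type PostLoader interface {
	PostLoad() error
}

type myModel struct{ initialized bool }

func (m *myModel) PostLoad() error { m.initialized = true; return nil }

// runPostLoad is a hypothetical loader step: PostLoad is called only when the
// model opts in by implementing the interface.
func runPostLoad(model any) error {
	if pl, ok := model.(PostLoader); ok {
		return pl.PostLoad()
	}
	return nil
}

func main() {
	m := &myModel{}
	if err := runPostLoad(m); err != nil {
		panic(err)
	}
	fmt.Println(m.initialized) // true
}
```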
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and

View File

@@ -68,6 +68,8 @@ func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor { return f }
func (f *fakeTensor) Conv1DDW(ctx ml.Context, _ ml.Tensor, _, _, _ int) ml.Tensor { return f }
func (f *fakeTensor) PadExt(ctx ml.Context, _, _, _, _, _, _, _, _ int) ml.Tensor { return f }
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {

View File

@@ -0,0 +1,265 @@
package gemma4
import (
"bytes"
"fmt"
"image"
"log/slog"
"slices"
"time"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
*AudioModel `gguf:"a"`
*MultiModalProjector `gguf:"mm"`
*AudioMultimodalProjector `gguf:"mm.a"`
ImageProcessor
imageTokenID int32
imageEndTokenID int32
audioTokenID int32
audioEndTokenID int32
audioOpts *AudioModelOptions
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
Projection *ClippableLinear `gguf:"input_projection"`
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
visionOutputs = p.Projection.Forward(ctx, visionOutputs)
// Post-projection RMSNorm without learned weight
visionOutputs = visionOutputs.RMSNorm(ctx, nil, eps)
return visionOutputs
}
func New(c fs.Config) (model.Model, error) {
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
// Gemma 4 uses BPE with SentencePiece-style ▁ space markers (not GPT-2 byte-level encoding).
// The tokenizer.json has merges and a Replace normalizer (space → ▁), with no pre-tokenizer.
t := tokenizer.NewBytePairEncodingWithOptions(&vocabulary, []string{},
tokenizer.WithSentencePieceNormalizer())
// Look up special token IDs for vision and audio
imageTokenID := int32(-1)
imageEndTokenID := int32(-1)
audioTokenID := int32(-1)
audioEndTokenID := int32(-1)
for i, tok := range vocabulary.Values {
switch tok {
case "<|image>":
imageTokenID = int32(i)
case "<image|>":
imageEndTokenID = int32(i)
case "<|audio>":
audioTokenID = int32(i)
case "<audio|>":
audioEndTokenID = int32(i)
}
}
slog.Info("gemma4: token IDs", "image", imageTokenID, "image_end", imageEndTokenID, "audio", audioTokenID, "audio_end", audioEndTokenID)
m := Model{
Tokenizer: t,
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
AudioModel: newAudioModel(c),
MultiModalProjector: &MultiModalProjector{},
AudioMultimodalProjector: &AudioMultimodalProjector{},
ImageProcessor: newImageProcessor(c),
imageTokenID: imageTokenID,
imageEndTokenID: imageEndTokenID,
audioTokenID: audioTokenID,
audioEndTokenID: audioEndTokenID,
audioOpts: newAudioModelOptions(c),
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWAMemCache(slidingWindowLen, 4096, m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
// Audio input: detect WAV format and route to audio encoder.
if isAudioData(multimodalData) {
return m.encodeAudioMultimodal(ctx, multimodalData)
}
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
t0 := time.Now()
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
slog.Info("vision: decode", "elapsed", time.Since(t0), "bounds", img.Bounds())
t1 := time.Now()
f32s, imgW, imgH, err := m.ImageProcessor.ProcessImage(img)
if err != nil {
return nil, err
}
slog.Info("vision: preprocess", "elapsed", time.Since(t1), "size", [2]int{imgW, imgH})
pixelValues := ctx.Input().FromFloats(f32s, imgW, imgH, m.ImageProcessor.numChannels)
slog.Info("vision: pixelValues", "shape", pixelValues.Shape(), "dim0", pixelValues.Dim(0), "dim1", pixelValues.Dim(1), "dim2", pixelValues.Dim(2))
numPatchesX := imgW / m.ImageProcessor.patchSize
numPatchesY := imgH / m.ImageProcessor.patchSize
slog.Info("vision: patches", "patchesX", numPatchesX, "patchesY", numPatchesY, "total", numPatchesX*numPatchesY, "patchSize", m.ImageProcessor.patchSize)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, numPatchesX, numPatchesY)
visionOutputs = visionPoolAndProject(ctx, visionOutputs, numPatchesX, numPatchesY, m.VisionModel.VisionModelOptions, m.MultiModalProjector, m.VisionModel.StdBias, m.VisionModel.StdScale)
slog.Info("vision: encoded", "elapsed", time.Since(t0), "shape", visionOutputs.Shape())
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
func (m *Model) PostLoad() error {
m.VisionModel.InitClamp(m.MultiModalProjector)
return nil
}
func (m *Model) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
if m.AudioModel == nil || m.audioOpts == nil {
return nil, model.ErrNoVisionModel
}
t0 := time.Now()
samples, err := decodeWAV(data)
if err != nil {
return nil, err
}
slog.Info("audio: decode", "elapsed", time.Since(t0), "samples", len(samples), "duration_s", float64(len(samples))/audioSampleRate)
// Pad waveform to next multiple of 128.
if rem := len(samples) % 128; rem != 0 {
samples = append(samples, make([]float32, 128-rem)...)
}
// Compute mel spectrogram.
melData, numFrames := computeMelSpectrogram(samples)
if numFrames == 0 {
return nil, fmt.Errorf("audio too short to encode")
}
slog.Info("audio: mel", "frames", numFrames, "elapsed", time.Since(t0))
// Create input tensor [melBins, numFrames] (GGML ne order). FromFloats creates F32.
melTensor := ctx.Input().FromFloats(melData, melBins, numFrames)
// Run audio encoder.
audioOutputs := m.AudioModel.ForwardAudio(ctx, melTensor, m.AudioMultimodalProjector, m.audioOpts)
slog.Info("audio: encoded", "elapsed", time.Since(t0), "shape", audioOutputs.Shape())
return []input.Multimodal{{Tensor: audioOutputs, Data: audioTag{}}}, nil
}
// audioTag marks multimodal data as audio (vs vision) for PostTokenize.
type audioTag struct{}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
continue
}
inputMultimodal := inp.Multimodal[0].Tensor
numTokens := inputMultimodal.Dim(1)
// Determine if this is audio or vision based on the tag.
_, isAudio := inp.Multimodal[0].Data.(audioTag)
var beginToken, endToken int32
if isAudio {
beginToken = m.audioTokenID
endToken = m.audioEndTokenID
} else {
beginToken = m.imageTokenID
endToken = m.imageEndTokenID
}
if beginToken >= 0 {
result = append(result, &input.Input{Token: beginToken, SameBatch: numTokens + 2})
}
result = append(result,
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash},
)
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, numTokens-1)...)
if endToken >= 0 {
result = append(result, &input.Input{Token: endToken})
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenState = m.TextModel.Output.Forward(ctx, hiddenState)
if m.TextModel.TextOptions.finalLogitSoftcap > 0.0 {
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextModel.TextOptions.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.TextModel.TextOptions.finalLogitSoftcap))
}
return hiddenState, nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase, ropeDims := m.TextModel.ropeForLayer(layer)
return nn.RoPE(ctx, key, shift, ropeDims, ropeBase, 1.0, rope.WithTypeNeoX()), nil
}
func init() {
model.Register("gemma4", New)
}

View File

@@ -0,0 +1,611 @@
package gemma4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// AudioModel holds the audio encoder and configuration.
type AudioModel struct {
// SSCP: Sub-Sample Convolution Projection.
SSCPConv0 *AudioConvBlock `gguf:"conv1d.0"`
SSCPConv1 *AudioConvBlock `gguf:"conv1d.1"`
// SSCP output projection (linear).
SSCPInputProj *nn.Linear `gguf:"pre_encode.out"`
// Conformer blocks.
Layers []AudioConformerBlock `gguf:"blk"`
// Output projection to embedder dimension.
OutputProj *AudioOutputProj `gguf:"output_proj"`
AudioModelOptions
}
type AudioOutputProj struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
// AudioModelOptions holds audio model hyperparameters.
type AudioModelOptions struct {
hiddenSize int
numHeads int
headDim int
ffnSize int
numLayers int
melBins int
chunkSize int
maxPast int
maxFuture int
contextSize int
logitCap float32
residualWeight float32
gradClip float32
convKernelSize int
eps float32
}
// AudioConvBlock is a single 2D convolution block for the SSCP.
type AudioConvBlock struct {
Weight ml.Tensor `gguf:"weight"`
Norm *nn.LayerNorm `gguf:"norm"`
}
// AudioConformerBlock is a single conformer layer.
// All tensors are flat at the block level (a.blk.N.<name>) using underscore naming.
type AudioConformerBlock struct {
// Block-level norm
Norm *nn.RMSNorm `gguf:"layer_pre_norm"`
// FFW start
FFWNorm *nn.RMSNorm `gguf:"ffn_norm"`
FFWUp *AudioClippableLinear `gguf:"ffn_up"`
FFWDown *AudioClippableLinear `gguf:"ffn_down"`
FFWPostNorm *nn.RMSNorm `gguf:"ffn_post_norm"`
// FFW end
FFWNorm1 *nn.RMSNorm `gguf:"ffn_norm_1"`
FFWUp1 *AudioClippableLinear `gguf:"ffn_up_1"`
FFWDown1 *AudioClippableLinear `gguf:"ffn_down_1"`
FFWPostNorm1 *nn.RMSNorm `gguf:"ffn_post_norm_1"`
// Attention
AttnQ *AudioClippableLinear `gguf:"attn_q"`
AttnK *AudioClippableLinear `gguf:"attn_k"`
AttnV *AudioClippableLinear `gguf:"attn_v"`
AttnOut *AudioClippableLinear `gguf:"attn_out"`
AttnPreNorm *nn.RMSNorm `gguf:"ln1"`
AttnPostNorm *nn.RMSNorm `gguf:"ln2"`
LinearPos ml.Tensor `gguf:"linear_pos.weight"`
PerDimScale ml.Tensor `gguf:"per_dim_scale.weight"`
// LightConv1d
ConvPW1 *AudioClippableLinear `gguf:"conv_pw1"`
ConvPW2 *AudioClippableLinear `gguf:"conv_pw2"`
ConvDW ml.Tensor `gguf:"conv_dw.weight"`
ConvNorm *nn.RMSNorm `gguf:"conv_norm"`
NormConv *nn.RMSNorm `gguf:"norm_conv"`
}
// AudioClippableLinear is a linear layer with optional input/output clamping.
type AudioClippableLinear struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
InputMin ml.Tensor `gguf:"input_min"`
InputMax ml.Tensor `gguf:"input_max"`
OutputMin ml.Tensor `gguf:"output_min"`
OutputMax ml.Tensor `gguf:"output_max"`
// Cached scalar clamp values (populated on first forward).
inMin, inMax, outMin, outMax float32
clampsLoaded bool
}
func (l *AudioClippableLinear) loadClamps() {
if l.clampsLoaded {
return
}
l.clampsLoaded = true
if l.InputMin != nil {
vals := l.InputMin.BackendGet()
if len(vals) > 0 {
l.inMin = vals[0]
}
}
if l.InputMax != nil {
vals := l.InputMax.BackendGet()
if len(vals) > 0 {
l.inMax = vals[0]
}
}
if l.OutputMin != nil {
vals := l.OutputMin.BackendGet()
if len(vals) > 0 {
l.outMin = vals[0]
}
}
if l.OutputMax != nil {
vals := l.OutputMax.BackendGet()
if len(vals) > 0 {
l.outMax = vals[0]
}
}
}
func (l *AudioClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
l.loadClamps()
if l.inMax != 0 {
x = x.Clamp(ctx, l.inMin, l.inMax)
}
out := l.Weight.Mulmat(ctx, x)
if l.Bias != nil {
out = out.Add(ctx, l.Bias)
}
if l.outMax != 0 {
out = out.Clamp(ctx, l.outMin, l.outMax)
}
return out
}
// AudioMultimodalProjector is the audio-to-text embedding projector.
type AudioMultimodalProjector struct {
Projection *AudioClippableLinear `gguf:"input_projection"`
FC *AudioFC `gguf:"fc"`
}
type AudioFC struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (p *AudioMultimodalProjector) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
// FC: output projection from conformer to embedder dimension.
x = p.FC.Weight.Mulmat(ctx, x)
if p.FC.Bias != nil {
x = x.Add(ctx, p.FC.Bias)
}
// Pre-projection RMSNorm (without learned weight) — matches Python's embedding_pre_projection_norm.
x = x.RMSNorm(ctx, nil, eps)
// Embedding projection to text hidden size.
x = p.Projection.Forward(ctx, x)
return x
}
// ForwardAudio encodes mel spectrogram features into soft tokens.
// melFeatures: float32 tensor with ne[0]=melBins, ne[1]=numFrames.
// Returns: [hiddenSize, numTokens] tensor.
func (m *AudioModel) ForwardAudio(ctx ml.Context, melFeatures ml.Tensor, proj *AudioMultimodalProjector, opts *AudioModelOptions) ml.Tensor {
// SSCP Conv2D input: ne[0]=F (freq/width), ne[1]=T (time/height), ne[2]=C_in, ne[3]=B
// melFeatures is [melBins, numFrames], add channel and batch dims.
x := melFeatures.Reshape(ctx, melFeatures.Dim(0), melFeatures.Dim(1), 1, 1)
// SSCP Conv block 0: [F, T, 1, 1] → [F', T', C0, 1]
x = forwardConvBlock(ctx, m.SSCPConv0, x, opts)
// SSCP Conv block 1: [F', T', C0, 1] → [F'', T'', C1, 1]
x = forwardConvBlock(ctx, m.SSCPConv1, x, opts)
// After conv blocks, layout is [F'', T'', C_out, B].
// Permute to [C_out*F'', T'', B] for linear projection (channels+freq in ne[0]).
fOut := x.Dim(0)
tOut := x.Dim(1)
cOut := x.Dim(2)
// Permute [F'', T'', C, B] → [C, F'', T'', B]
// (1,2,0,3): old[0]→pos1, old[1]→pos2, old[2]→pos0
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
x = x.Reshape(ctx, cOut*fOut, tOut)
// Linear projection to hidden size.
x = m.SSCPInputProj.Forward(ctx, x)
// Build causal-valid mask for conformer attention.
causalMask := buildCausalValidMaskF32(opts.chunkSize, opts.maxPast, opts.maxFuture)
// Run conformer blocks.
for i := range m.Layers {
x = m.Layers[i].Forward(ctx, x, causalMask, opts, i)
}
// Output projection.
if m.OutputProj != nil {
x = m.OutputProj.Weight.Mulmat(ctx, x)
if m.OutputProj.Bias != nil {
x = x.Add(ctx, m.OutputProj.Bias)
}
}
// Audio embedder: project to text embedding space.
if proj != nil {
x = proj.Forward(ctx, x, opts.eps)
}
return x
}
// forwardConvBlock runs a single SSCP Conv2D block.
// Conv2D receiver is the kernel, argument is the input data.
// Input: [F, T, C_in, B]. Output: [F', T', C_out, B].
func forwardConvBlock(ctx ml.Context, block *AudioConvBlock, x ml.Tensor, opts *AudioModelOptions) ml.Tensor {
// Conv2D: kernel.Conv2D(ctx, input, s0, s1, p0, p1, d0, d1)
// Kernel is 3x3, stride 2x2, padding 1x1 (matching SSCP config).
// Output layout: [F', T', C_out, B]
// Make weight contiguous — the shape reversal in the converter creates
// a tensor where the physical data order doesn't match ne[]/stride[].
weight := block.Weight.Contiguous(ctx)
x = weight.Conv2D(ctx, x, 2, 2, 1, 1, 1, 1)
// LayerNorm needs channels in ne[0]. Permute [F', T', C_out, B] → [C_out, F', T', B],
// norm, then permute back.
// GGML permute: axis i says where old axis i goes.
// (1,2,0,3): old[0]→pos1, old[1]→pos2, old[2]→pos0 → [C_out, F', T', B]
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
x = block.Norm.Forward(ctx, x, opts.eps)
// (2,0,1,3): old[0]→pos2, old[1]→pos0, old[2]→pos1 → [F', T', C_out, B]
x = x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
x = x.RELU(ctx)
return x
}
// Forward runs a single conformer block.
func (cb *AudioConformerBlock) Forward(ctx ml.Context, x ml.Tensor, causalMask []float32, opts *AudioModelOptions, blockIdx int) ml.Tensor {
// FFW start (half-residual).
x = cb.forwardFFW(ctx, cb.FFWNorm, cb.FFWUp, cb.FFWDown, cb.FFWPostNorm, x, opts)
// Self-attention.
x = cb.forwardAttention(ctx, x, causalMask, opts, blockIdx)
// Lightweight Conv1d.
x = cb.forwardLightConv(ctx, x, opts, blockIdx)
// FFW end (half-residual).
x = cb.forwardFFW(ctx, cb.FFWNorm1, cb.FFWUp1, cb.FFWDown1, cb.FFWPostNorm1, x, opts)
// Gradient clipping + final norm.
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = cb.Norm.Forward(ctx, x, opts.eps)
return x
}
// forwardFFW runs a feedforward module with half-residual connection.
func (cb *AudioConformerBlock) forwardFFW(ctx ml.Context, preNorm *nn.RMSNorm, up, down *AudioClippableLinear, postNorm *nn.RMSNorm, x ml.Tensor, opts *AudioModelOptions) ml.Tensor {
residual := x
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = preNorm.Forward(ctx, x, opts.eps)
x = up.Forward(ctx, x)
x = x.SILU(ctx)
x = down.Forward(ctx, x)
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = postNorm.Forward(ctx, x, opts.eps)
x = x.Scale(ctx, float64(opts.residualWeight))
return residual.Add(ctx, x)
}
// forwardAttention runs the conformer block-local attention with relative position embeddings.
func (cb *AudioConformerBlock) forwardAttention(ctx ml.Context, x ml.Tensor, causalMask []float32, opts *AudioModelOptions, blockIdx int) ml.Tensor {
residual := x
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = cb.AttnPreNorm.Forward(ctx, x, opts.eps)
hiddenSize := x.Dim(0)
seqLen := x.Dim(1)
// QKV projections: [hiddenSize, seqLen] → [headDim, numHeads, seqLen]
q := cb.AttnQ.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
k := cb.AttnK.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
v := cb.AttnV.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
// Per-dim scaling for queries: (headDim^-0.5 / log(2)) * softplus(per_dim_scale)
// per_dim_scale is already softplus'd from the converter.
qScale := math.Pow(float64(opts.headDim), -0.5) / math.Log(2)
q = q.Scale(ctx, qScale)
if cb.PerDimScale != nil {
q = q.Mul(ctx, cb.PerDimScale)
}
// Key scaling: softplus(1) / log(2) — matches the query base scaling convention.
kScale := math.Log(1+math.E) / math.Log(2)
k = k.Scale(ctx, kScale)
// Build sinusoidal position embeddings for the block-local context.
maxSpan := opts.maxPast + opts.maxFuture + 1 // 13 unique relative positions
posEmb := cb.buildPositionEmbeddings(ctx, maxSpan, opts)
// posEmb: [headDim, numHeads, maxSpan]
// Block-local attention: process chunks of size chunkSize.
chunkSize := opts.chunkSize
numChunks := (seqLen + chunkSize - 1) / chunkSize
contextSize := opts.contextSize
// Pad q/k/v to multiple of chunkSize on the time dimension (dim 2).
padT := numChunks*chunkSize - seqLen
if padT > 0 {
q = q.Pad(ctx, 0, 0, padT, 0)
k = k.Pad(ctx, 0, 0, padT, 0)
v = v.Pad(ctx, 0, 0, padT, 0)
}
paddedLen := numChunks * chunkSize
// Pad k/v for context extraction: add maxPast zeros on the left and
// (maxFuture+chunkSize-1) on the right. Left-padding is done by Concat with
// zero tensors, since PadExt followed by Slice is unreliable here.
padLeft := opts.maxPast
padRight := opts.maxFuture + chunkSize - 1
zeroLeft := ctx.Input().FromFloats(make([]float32, opts.headDim*opts.numHeads*padLeft), opts.headDim, opts.numHeads, padLeft)
zeroRight := ctx.Input().FromFloats(make([]float32, opts.headDim*opts.numHeads*padRight), opts.headDim, opts.numHeads, padRight)
kPadded := zeroLeft.Concat(ctx, k, 2).Concat(ctx, zeroRight, 2)
vPadded := zeroLeft.Concat(ctx, v, 2).Concat(ctx, zeroRight, 2)
// Reshape q into chunks: [headDim, numHeads, numChunks, chunkSize]
qChunked := q.Reshape(ctx, opts.headDim, opts.numHeads, numChunks, chunkSize)
// Process each chunk and collect results.
chunkOutputs := make([]ml.Tensor, numChunks)
for u := range numChunks {
// Extract query block: [headDim, numHeads, 1, chunkSize] → [headDim, numHeads, chunkSize]
qBlock := qChunked.Slice(ctx, 2, u, u+1, 1).Reshape(ctx, opts.headDim, opts.numHeads, chunkSize)
// Extract key/value context: [headDim, numHeads, contextSize]
cStart := u * chunkSize // offset in kPadded (padLeft already accounts for left context)
kCtx := kPadded.Slice(ctx, 2, cStart, cStart+contextSize, 1).Contiguous(ctx)
vCtx := vPadded.Slice(ctx, 2, cStart, cStart+contextSize, 1).Contiguous(ctx)
// Content-content logits: qBlock^T @ kCtx → [chunkSize, contextSize] per head.
// Mulmat(a, b) = a^T @ b. We want Q^T K, so: kCtx.Mulmat(qBlock) but that gives
// [numHeads, chunkSize, contextSize] with wrong batching.
// Instead: permute to [headDim, chunkSize, numHeads] and [headDim, contextSize, numHeads]
// then Mulmat batches over numHeads.
// GGML permute(0,2,1,3): old[0]→0, old[1]→2, old[2]→1
qP := qBlock.Permute(ctx, 0, 2, 1, 3) // [headDim, chunkSize, numHeads]
kP := kCtx.Permute(ctx, 0, 2, 1, 3) // [headDim, contextSize, numHeads]
termAC := kP.MulmatFullPrec(ctx, qP) // [contextSize, chunkSize, numHeads]
// Content-position logits: qBlock^T @ posEmb → [chunkSize, maxSpan] per head.
pP := posEmb.Permute(ctx, 0, 2, 1, 3) // [headDim, maxSpan, numHeads]
termBDRaw := pP.MulmatFullPrec(ctx, qP) // [maxSpan, chunkSize, numHeads]
// Relative shift: [maxSpan, chunkSize, numHeads] → [contextSize, chunkSize, numHeads]
termBD := cb.relativeShiftGGML(ctx, termBDRaw, maxSpan, chunkSize, contextSize, opts.numHeads)
// Combined logits.
logits := termAC.Add(ctx, termBD)
// Logit softcap: tanh(logits / cap) * cap
logits = logits.Scale(ctx, 1.0/float64(opts.logitCap))
logits = logits.Tanh(ctx)
logits = logits.Scale(ctx, float64(opts.logitCap))
// Apply combined causal + validity mask.
// causalMask [chunkSize * contextSize]: 1=causal-allowed, 0=masked.
// Validity: context positions before the actual sequence start are invalid.
// For chunk u, context position c corresponds to actual time: u*chunkSize + c - padLeft.
// Valid if 0 <= actual_time < seqLen.
// Mask tensor layout: [contextSize, chunkSize, 1] with ne[0]=contextSize contiguous.
// Element at (context=j, chunk=i) is at flat index: i*contextSize + j.
maskData := make([]float32, contextSize*chunkSize)
for i := range chunkSize {
for j := range contextSize {
actualTime := u*chunkSize + j - padLeft
causalOK := causalMask[i*contextSize+j] > 0
validOK := actualTime >= 0 && actualTime < seqLen
if causalOK && validOK {
maskData[i*contextSize+j] = 0
} else {
maskData[i*contextSize+j] = -1e9
}
}
}
mask := ctx.Input().FromFloats(maskData, contextSize, chunkSize, 1) // 3D for broadcasting over numHeads
logits = logits.Add(ctx, mask)
// Softmax over context dimension (dim 0 = contextSize).
logits = logits.Softmax(ctx) // softmax over ne[0]=contextSize
// Weighted sum: logits^T @ vCtx.
// logits: [contextSize, chunkSize, numHeads], vCtx: [headDim, numHeads, contextSize]
// vCtx permuted: [headDim, contextSize, numHeads]
vP := vCtx.Permute(ctx, 0, 2, 1, 3) // [headDim, contextSize, numHeads]
// Weighted sum: for each head, value[headDim, contextSize] @ weights[contextSize, chunkSize]
// = [headDim, chunkSize].
// Mulmat(a, b) = a^T @ b. Need a=[contextSize, headDim, numHeads], b=[contextSize, chunkSize, numHeads].
vPT := vP.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // [contextSize, headDim, numHeads]
chunkOut := vPT.Mulmat(ctx, logits) // [headDim, chunkSize, numHeads]
// Permute back to [headDim, numHeads, chunkSize]
chunkOut = chunkOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
chunkOutputs[u] = chunkOut
}
// Concatenate chunk outputs along time dimension.
var attnOut ml.Tensor
if numChunks == 1 {
attnOut = chunkOutputs[0]
} else {
attnOut = chunkOutputs[0]
for _, co := range chunkOutputs[1:] {
attnOut = attnOut.Concat(ctx, co, 2)
}
}
// Trim to original sequence length if we padded.
if paddedLen > seqLen {
attnOut = attnOut.Slice(ctx, 2, 0, seqLen, 1).Contiguous(ctx)
}
// Reshape to [hiddenSize, seqLen] and project.
attnOut = attnOut.Reshape(ctx, hiddenSize, seqLen)
x = cb.AttnOut.Forward(ctx, attnOut)
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = cb.AttnPostNorm.Forward(ctx, x, opts.eps)
return residual.Add(ctx, x)
}
// buildPositionEmbeddings builds sinusoidal position embeddings and projects through linear_pos.
// Returns [headDim, numHeads, maxSpan] tensor.
func (cb *AudioConformerBlock) buildPositionEmbeddings(ctx ml.Context, maxSpan int, opts *AudioModelOptions) ml.Tensor {
halfDim := opts.hiddenSize / 2
hiddenSize := opts.hiddenSize
// inv_timescales: exp(-i * log(10000) / max(D/2-1, 1))
logInc := math.Log(10000.0) / math.Max(float64(halfDim-1), 1)
// Sinusoidal embeddings for relative positions [maxPast, maxPast-1, ..., -maxFuture].
posData := make([]float32, hiddenSize*maxSpan)
for p := range maxSpan {
relPos := float64(opts.maxPast - p)
for d := range halfDim {
angle := relPos * math.Exp(float64(-d)*logInc)
posData[p*hiddenSize+d] = float32(math.Sin(angle))
posData[p*hiddenSize+halfDim+d] = float32(math.Cos(angle))
}
}
// Create [hiddenSize, maxSpan] input tensor.
posEmb := ctx.Input().FromFloats(posData, hiddenSize, maxSpan)
// Project through linear_pos: [hiddenSize, maxSpan] → Mulmat → [numHeads*headDim, maxSpan]
projPos := cb.LinearPos.Mulmat(ctx, posEmb)
// Reshape to [headDim, numHeads, maxSpan].
return projPos.Reshape(ctx, opts.headDim, opts.numHeads, maxSpan)
}
// relativeShiftGGML performs the relative shift to extract correct position logits.
// Input: [maxSpan, chunkSize, numHeads]. Output: [contextSize, chunkSize, numHeads].
func (cb *AudioConformerBlock) relativeShiftGGML(ctx ml.Context, x ml.Tensor, maxSpan, chunkSize, contextSize, numHeads int) ml.Tensor {
// The shift trick: pad ne[0] from maxSpan to contextSize+1, reshape to flatten
// the first two dims, take the first contextSize*chunkSize elements, and
// reshape back. Row i of the result is row i of the input shifted right by i.
padAmt := contextSize + 1 - maxSpan
if padAmt > 0 {
x = x.Pad(ctx, padAmt, 0, 0, 0) // [maxSpan+padAmt, chunkSize, numHeads] = [contextSize+1, chunkSize, numHeads]
}
// Reshape to [(contextSize+1)*chunkSize, numHeads]
x = x.Reshape(ctx, (contextSize+1)*chunkSize, numHeads)
// Take the first contextSize*chunkSize elements (the standard relative shift trick).
x = x.Slice(ctx, 0, 0, contextSize*chunkSize, 1).Contiguous(ctx)
// Reshape to [contextSize, chunkSize, numHeads]
return x.Reshape(ctx, contextSize, chunkSize, numHeads)
}
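The flatten-and-slice trick above can be reproduced on plain slices to see why it works. In this sketch (a standalone helper, not part of the patch), row i of the output is row i of the input shifted right by i positions, which is exactly the alignment needed to turn per-relative-position logits into per-context-position logits:

```go
package main

import "fmt"

// relativeShift mimics relativeShiftGGML on plain slices: pad each row from
// maxSpan to contextSize+1, flatten row-major, keep the first
// contextSize*chunkSize values, and re-split into rows of contextSize.
func relativeShift(rows [][]float32, contextSize int) [][]float32 {
	chunkSize := len(rows)
	flat := make([]float32, 0, (contextSize+1)*chunkSize)
	for _, r := range rows {
		padded := append(append([]float32{}, r...), make([]float32, contextSize+1-len(r))...)
		flat = append(flat, padded...)
	}
	out := make([][]float32, chunkSize)
	for i := range out {
		out[i] = flat[i*contextSize : (i+1)*contextSize]
	}
	return out
}

func main() {
	// maxSpan=3 relative-position logits per query row, contextSize=4.
	in := [][]float32{{1, 2, 3}, {4, 5, 6}}
	for _, row := range relativeShift(in, 4) {
		fmt.Println(row)
	}
	// [1 2 3 0]
	// [0 4 5 6]
}
```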
// forwardLightConv runs the lightweight depthwise convolution module.
func (cb *AudioConformerBlock) forwardLightConv(ctx ml.Context, x ml.Tensor, opts *AudioModelOptions, blockIdx int) ml.Tensor {
residual := x
x = cb.ConvNorm.Forward(ctx, x, opts.eps)
x = cb.ConvPW1.Forward(ctx, x) // [2*D, T, B]
// GLU: split in half along dim 0, sigmoid gate, multiply.
d := x.Dim(0) / 2
data := x.Slice(ctx, 0, 0, d, 1).Contiguous(ctx)
gate := x.Slice(ctx, 0, d, d*2, 1).Contiguous(ctx).Sigmoid(ctx)
x = data.Mul(ctx, gate) // [D, T, B]
// Depthwise Conv1d: manual implementation using model weight tensor slices.
// Kernel cb.ConvDW has ne[0]=K=5 (contiguous) and ne[1]=D=1024 after the
// converter's shape reversal.
// We need per-tap weights [D] and shifted input copies.
kernelSize := cb.ConvDW.Dim(0) // K=5
seqLen := x.Dim(1)
// Transpose kernel to [D, K] for per-tap slicing.
// GGML permute(1,0,2,3): old[0]→pos1, old[1]→pos0 → swap ne[0] and ne[1]
kernelT := cb.ConvDW.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // [D, K]
var convOut ml.Tensor
for k := range kernelSize {
shift := kernelSize - 1 - k
var shifted ml.Tensor
if shift == 0 {
shifted = x
} else {
trimmed := x.Slice(ctx, 1, 0, seqLen-shift, 1).Contiguous(ctx)
shifted = trimmed.PadExt(ctx, 0, 0, shift, 0, 0, 0, 0, 0)
}
wk := kernelT.Slice(ctx, 1, k, k+1, 1).Contiguous(ctx) // [D, 1]
term := shifted.Mul(ctx, wk)
if convOut == nil {
convOut = term
} else {
convOut = convOut.Add(ctx, term)
}
}
x = convOut
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = cb.NormConv.Forward(ctx, x, opts.eps)
x = x.SILU(ctx)
x = cb.ConvPW2.Forward(ctx, x)
return x.Add(ctx, residual)
}
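The shift-and-accumulate loop above decomposes a causal depthwise convolution into K shifted, weighted copies of the input. A single-channel sketch (standalone helper, not part of the patch) shows the decomposition:

```go
package main

import "fmt"

// depthwiseConv1D mirrors the loop above for one channel: kernel tap k
// multiplies the input delayed by (K-1-k) samples, i.e. a causal depthwise
// convolution with implicit left zero-padding.
func depthwiseConv1D(x, kernel []float32) []float32 {
	kSize := len(kernel)
	out := make([]float32, len(x))
	for k, w := range kernel {
		shift := kSize - 1 - k
		for t := shift; t < len(x); t++ {
			out[t] += w * x[t-shift]
		}
	}
	return out
}

func main() {
	// Kernel [1, 2]: out[t] = 1*x[t-1] + 2*x[t], with x[-1] treated as 0.
	fmt.Println(depthwiseConv1D([]float32{1, 2, 3}, []float32{1, 2}))
	// [2 5 8]
}
```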
func newAudioModel(c fs.Config) *AudioModel {
numLayers := int(c.Uint("audio.block_count", 0))
if numLayers == 0 {
return nil
}
return &AudioModel{
Layers: make([]AudioConformerBlock, numLayers),
}
}
func newAudioModelOptions(c fs.Config) *AudioModelOptions {
hiddenSize := int(c.Uint("audio.embedding_length", 0))
if hiddenSize == 0 {
return nil
}
numHeads := int(c.Uint("audio.attention.head_count", 8))
headDim := hiddenSize / numHeads
chunkSize := 12 // default conformer chunk size
maxPast := 12 // conf_attention_context_left - 1
maxFuture := 0 // conf_attention_context_right
convKernel := int(c.Uint("audio.conv_kernel_size", 5))
eps := c.Float("audio.attention.layer_norm_epsilon", 1e-6)
return &AudioModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: headDim,
ffnSize: int(c.Uint("audio.feed_forward_length", uint32(hiddenSize*4))),
numLayers: int(c.Uint("audio.block_count", 12)),
melBins: int(c.Uint("audio.num_mel_bins", 128)),
chunkSize: chunkSize,
maxPast: maxPast,
maxFuture: maxFuture,
contextSize: chunkSize + maxPast + maxFuture,
logitCap: 50.0,
residualWeight: 0.5,
gradClip: 1e10,
convKernelSize: convKernel,
eps: float32(eps),
}
}
// buildCausalValidMaskF32 creates the causal-valid mask for block-local attention.
// Returns flat [chunkSize * contextSize] float32 data (1.0 = allowed, 0.0 = masked).
func buildCausalValidMaskF32(chunkSize, maxPast, maxFuture int) []float32 {
contextSize := chunkSize + maxPast + maxFuture
upperDiag := maxPast + maxFuture
result := make([]float32, chunkSize*contextSize)
for r := range chunkSize {
for c := range contextSize {
lower := (r <= c) // tril(contextSize, chunkSize) transposed
upper := (c <= r+upperDiag) // tril(chunkSize, contextSize, diag=upperDiag)
if lower && upper {
result[r*contextSize+c] = 1.0
}
}
}
return result
}


@@ -0,0 +1,475 @@
package gemma4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
const (
cacheTypeSWA = iota
cacheTypeCausal
)
type TextOptions struct {
hiddenSize int
numHeads, numKVHeads int
numGlobalKVHeads int
headDim, globalHeadDim int
hiddenLayers int
hiddenSizePerLayerInput int
eps float32
ropeBase float32
ropeLocalBase float32
partialRotaryDims int // RoPE dims for full-attention (global) layers
slidingWindowPattern []bool
// kvDonorMap maps shared layer index -> donor layer index.
// Donor is the last non-shared layer of the same type (sliding/full).
kvDonorMap map[int]int
finalLogitSoftcap float32
numExperts int
numExpertsUsed int
}
func (o *TextOptions) isLocal(layer int) bool {
if layer < len(o.slidingWindowPattern) {
return o.slidingWindowPattern[layer]
}
return false
}
func (o *TextOptions) ropeForLayer(layer int) (base float32, dims int) {
if o.isLocal(layer) {
return o.ropeLocalBase, o.headDim
}
return o.ropeBase, o.partialRotaryDims
}
func (o *TextOptions) kvHeadsForLayer(layer int) int {
if o.isLocal(layer) {
return o.numKVHeads
}
if o.numGlobalKVHeads > 0 {
return o.numGlobalKVHeads
}
return o.numKVHeads
}
func (o *TextOptions) headDimForLayer(layer int) int {
if o.isLocal(layer) {
return o.headDim
}
return o.globalHeadDim
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
*PerLayerProjector
Layers []TextLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
TextOptions
}
func newTextModel(c fs.Config) *TextModel {
numLayers := int(c.Uint("block_count"))
// Head dimensions: key_length is global head dim, key_length_swa is local (SWA) head dim.
globalHeadDim := int(c.Uint("attention.key_length", 512))
headDim := int(c.Uint("attention.key_length_swa", 256))
// RoPE dimensions for global (full attention) layers with proportional RoPE.
// The freq_factors tensor handles partial rotation (1.0 for rotated pairs,
// 1e30 for non-rotated), so ropeDims equals the full global head dim.
partialRotaryDims := int(c.Uint("rope.dimension_count", 0))
if partialRotaryDims == 0 {
partialFactor := c.Float("rope.partial_rotary_factor", 1.0)
partialRotaryDims = int(float32(globalHeadDim) * partialFactor)
}
ropeBase := c.Float("rope.freq_base", 1000000.0)
ropeLocalBase := c.Float("rope.freq_base_swa", 0)
if ropeLocalBase == 0 {
ropeLocalBase = c.Float("rope.local.freq_base", 10000.0)
}
numGlobalKVHeads := int(c.Uint("attention.global_head_count_kv", 0))
slidingPattern := c.Bools("attention.sliding_window_pattern")
// KV heads: try per-layer array first (MoE models), then fall back to scalar
numKVHeads := 0
kvHeadsArray := c.Ints("attention.head_count_kv")
if len(kvHeadsArray) > 0 {
numKVHeads = int(kvHeadsArray[0])
if numGlobalKVHeads == 0 && len(slidingPattern) > 0 {
for i, isLocal := range slidingPattern {
if !isLocal && i < len(kvHeadsArray) {
numGlobalKVHeads = int(kvHeadsArray[i])
break
}
}
}
}
if numKVHeads == 0 {
numKVHeads = int(c.Uint("attention.head_count_kv", 0))
}
// Compute KV sharing donor map (same logic as MLX)
sharedLayers := int(c.Uint("attention.shared_kv_layers", 0))
kvDonorMap := make(map[int]int)
if sharedLayers > 0 && len(slidingPattern) > 0 {
firstShared := numLayers - sharedLayers
for i := firstShared; i < numLayers; i++ {
isLocal := slidingPattern[i]
// Find last non-shared layer of same type
for j := firstShared - 1; j >= 0; j-- {
if slidingPattern[j] == isLocal {
kvDonorMap[i] = j
break
}
}
}
}
return &TextModel{
Layers: make([]TextLayer, numLayers),
TextOptions: TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: numKVHeads,
numGlobalKVHeads: numGlobalKVHeads,
headDim: headDim,
globalHeadDim: globalHeadDim,
hiddenLayers: numLayers,
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input", 0)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeBase: ropeBase,
ropeLocalBase: ropeLocalBase,
partialRotaryDims: partialRotaryDims,
slidingWindowPattern: slidingPattern,
kvDonorMap: kvDonorMap,
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
numExperts: int(c.Uint("expert_count", 0)),
numExpertsUsed: int(c.Uint("expert_used_count", 0)),
},
}
}
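The donor-map logic in newTextModel can be exercised in isolation. A minimal sketch with a hypothetical 6-layer pattern (the function name and pattern are illustrative, not from a real config):

```go
package main

import "fmt"

// donorMap replicates the KV-sharing computation from newTextModel: each of
// the last sharedLayers layers borrows K/V from the last non-shared layer
// that has the same attention type (sliding-window vs full).
func donorMap(pattern []bool, sharedLayers int) map[int]int {
	n := len(pattern)
	m := make(map[int]int)
	firstShared := n - sharedLayers
	for i := firstShared; i < n; i++ {
		for j := firstShared - 1; j >= 0; j-- {
			if pattern[j] == pattern[i] {
				m[i] = j
				break
			}
		}
	}
	return m
}

func main() {
	// true = sliding-window (local) layer, false = full-attention layer.
	pattern := []bool{true, true, false, true, true, false}
	// With 2 shared layers: layer 4 (local) donates from layer 3 (local),
	// layer 5 (full) donates from layer 2 (full).
	fmt.Println(donorMap(pattern, 2))
}
```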
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
// Inject vision embeddings into the hidden state
var except []int
for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal[0].Tensor
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
for i := range visionOutputs.Dim(1) {
except = append(except, image.Index+i)
}
}
// PLE
var perLayerInputs ml.Tensor
if m.PerLayerProjector != nil {
perLayerInputs = m.PerLayerProjector.Forward(ctx, batch, hiddenState, &m.TextOptions)
}
for i := range len(m.Layers) {
layer := m.Layers[i]
if cache != nil {
cache.SetLayer(i)
cacheType := cacheTypeSWA
if !m.isLocal(i) {
cacheType = cacheTypeCausal
}
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
}
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = batch.Outputs
}
var perLayerInput ml.Tensor
if perLayerInputs != nil {
perLayerInput = perLayerInputs.View(ctx, i*perLayerInputs.Stride(1), perLayerInputs.Dim(0), perLayerInputs.Stride(2), perLayerInputs.Dim(2))
}
// KV sharing: layers >= firstShared reuse K/V from donor layers
isShared := false
if donorLayer, ok := m.kvDonorMap[i]; ok {
// Set cache layer to donor so Get() reads donor's K/V
cache.SetLayer(donorLayer)
isShared = true
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, perLayerInput, lastLayerOutputs, cache, isShared, &m.TextOptions)
}
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
}
// PerLayerProjector implements PLE.
type PerLayerProjector struct {
TokenEmbedding *nn.Embedding `gguf:"per_layer_token_embd"`
Projector *nn.Linear `gguf:"per_layer_model_proj"`
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
}
func (p *PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs)
inputsPerLayer = inputsPerLayer.Scale(ctx, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
// Reshape to [pleDim, numLayers, numTokens] — matching projection shape
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
perLayerProjection := p.Projector.Forward(ctx, inputs)
perLayerProjection = perLayerProjection.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
if inputsPerLayer != nil {
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
}
return perLayerProjection
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` // proportional RoPE freq_factors
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
hd := opts.headDimForLayer(layer)
kvHeads := opts.kvHeadsForLayer(layer)
ropeBase, ropeDims := opts.ropeForLayer(layer)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, hd, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
var k, v ml.Tensor
if !sharedKV {
k = sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, hd, kvHeads, batchSize)
if sa.Value != nil {
v = sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, hd, kvHeads, batchSize)
} else {
// K=V: use raw K projection (before K norm) as V
v = k
}
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
v = v.RMSNorm(ctx, nil, opts.eps) // V norm: unweighted RMSNorm
}
// RoPE with proportional freq_factors on global layers
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
if sa.RopeFactors != nil && !opts.isLocal(layer) {
ropeOpts = append(ropeOpts, rope.WithFactors(sa.RopeFactors))
}
q = nn.RoPE(ctx, q, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
if k != nil {
k = nn.RoPE(ctx, k, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
}
attention := nn.Attention(ctx, q, k, v, 1.0, cache)
attention = attention.Reshape(ctx, hd*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, attention)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
// TextRouter implements the Gemma 4 MoE router.
type TextRouter struct {
Proj *nn.Linear `gguf:"ffn_gate_inp"`
Scale ml.Tensor `gguf:"ffn_gate_inp.scale"`
}
func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) (routingWeights, selectedExperts ml.Tensor) {
// RMSNorm without learned weight
x := hiddenState.RMSNorm(ctx, nil, opts.eps)
// Scale by 1/sqrt(hidden_size)
x = x.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
// Multiply by learned scale parameter
x = x.Mul(ctx, r.Scale)
// Project to expert logits
expertScores := r.Proj.Forward(ctx, x)
// Softmax over experts
routingWeights = expertScores.Softmax(ctx)
// TopK expert selection
selectedExperts = routingWeights.TopK(ctx, opts.numExpertsUsed)
return routingWeights, selectedExperts
}
// TextMoEBlock implements the Gemma 4 sparse MoE.
type TextMoEBlock struct {
GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
DownScale ml.Tensor `gguf:"ffn_down_exps.scale,alt:ffn_gate_inp.per_expert_scale"`
}
func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, opts *TextOptions) ml.Tensor {
// Select routing weights for chosen experts and renormalize
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
// Expert computation using LinearBatch (MulmatID selecting experts by index)
var gateOut, upOut ml.Tensor
if moe.GateUp != nil && moe.GateUp.Weight != nil {
gateUp := moe.GateUp.Forward(ctx, hiddenState, selectedExperts)
nFF := gateUp.Dim(0) / 2
gateOut = gateUp.Slice(ctx, 0, 0, nFF, 1)
upOut = gateUp.Slice(ctx, 0, nFF, gateUp.Dim(0), 1)
} else {
gateOut = moe.Gate.Forward(ctx, hiddenState, selectedExperts)
upOut = moe.Up.Forward(ctx, hiddenState, selectedExperts)
}
hiddenState = gateOut.GELU(ctx, upOut)
experts := moe.Down.Forward(ctx, hiddenState, selectedExperts)
// Apply per-expert down projection scale when present.
if moe.DownScale != nil {
expertScales := moe.DownScale.Reshape(ctx, opts.numExperts, 1)
expertScales = expertScales.Repeat(ctx, 1, hiddenState.Dim(2))
expertScales = expertScales.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(2)).Rows(ctx, selectedExperts)
expertScales = expertScales.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(2))
expertScales = expertScales.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(2))
experts = experts.Mul(ctx, expertScales)
}
// Apply routing weights
experts = experts.Mul(ctx, routingWeights)
// Sum across experts
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm,alt:attn_post_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm,alt:ffn_pre_norm"`
MLP *TextMLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm,alt:ffn_post_norm"`
// MoE (present only for models with enable_moe_block=true)
Router *TextRouter
MoE *TextMoEBlock
MoENorm *nn.RMSNorm `gguf:"pre_ffw_norm_2,alt:ffn_pre_norm_2"`
PostMoENorm *nn.RMSNorm `gguf:"post_ffw_norm_2,alt:ffn_post_norm_2"`
PostMLPNorm1 *nn.RMSNorm `gguf:"post_ffw_norm_1,alt:ffn_post_norm_1"` // used instead of PostMLPNorm when MoE is present
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
PerLayerProjection *nn.Linear `gguf:"proj"`
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
LayerScalar ml.Tensor `gguf:"layer_scalar,alt:layer_output_scale.weight"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, perLayerInput, outputs ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positions, cache, sharedKV, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
if perLayerInput != nil {
perLayerInput = perLayerInput.Rows(ctx, outputs)
}
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// MLP (+ optional MoE in parallel)
hasSplitExperts := l.MoE != nil && l.MoE.Gate != nil && l.MoE.Up != nil && l.MoE.Gate.Weight != nil && l.MoE.Up.Weight != nil
hasFusedExperts := l.MoE != nil && l.MoE.GateUp != nil && l.MoE.GateUp.Weight != nil
if l.Router != nil && l.MoE != nil && l.MoE.Down != nil && l.MoE.Down.Weight != nil && (hasSplitExperts || hasFusedExperts) {
// MoE layers: run MLP and MoE in parallel, sum results
mlpState := l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
mlpState = l.MLP.Forward(ctx, mlpState)
mlpState = l.PostMLPNorm1.Forward(ctx, mlpState, opts.eps)
routingWeights, selectedExperts := l.Router.Forward(ctx, hiddenState, opts)
moeState := l.MoENorm.Forward(ctx, hiddenState, opts.eps)
moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, opts)
moeState = l.PostMoENorm.Forward(ctx, moeState, opts.eps)
// Combine MLP + MoE, apply outer post-FFN norm, then add residual
combined := mlpState.Add(ctx, moeState)
combined = l.PostMLPNorm.Forward(ctx, combined, opts.eps)
hiddenState = combined.Add(ctx, residual)
} else {
// Dense layers: MLP only
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = hiddenState.Add(ctx, residual)
}
// PLE injection (after MLP residual)
if perLayerInput != nil && l.PerLayerInputGate != nil {
pleState := l.PerLayerInputGate.Forward(ctx, hiddenState)
pleState = pleState.GELU(ctx, perLayerInput)
pleState = l.PerLayerProjection.Forward(ctx, pleState)
pleState = l.PostPerLayerNorm.Forward(ctx, pleState, opts.eps)
hiddenState = hiddenState.Add(ctx, pleState)
}
// Layer scalar applied at end of layer (full-attention layers only)
if l.LayerScalar != nil {
hiddenState = hiddenState.Mul(ctx, l.LayerScalar)
}
return hiddenState
}


@@ -0,0 +1,384 @@
package gemma4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
)
const batchSize = 1
// ClippableLinear is a linear layer with optional input/output clamping.
// Required by Gemma4 vision encoder for numerical stability with F16 weights.
type ClippableLinear struct {
Weight ml.Tensor `gguf:"weight"`
InputMin ml.Tensor `gguf:"input_min"`
InputMax ml.Tensor `gguf:"input_max"`
OutputMin ml.Tensor `gguf:"output_min"`
OutputMax ml.Tensor `gguf:"output_max"`
inMin, inMax, outMin, outMax float32
hasClamp bool
clampsLoaded bool
}
func scalarValue(t ml.Tensor) (float32, bool) {
if t == nil {
return 0, false
}
data := t.BackendGet()
if len(data) == 0 {
return 0, false
}
return data[0], true
}
func (l *ClippableLinear) loadClampFromScalars() {
if l.clampsLoaded {
return
}
l.clampsLoaded = true
const (
defaultMin = -math.MaxFloat32
defaultMax = math.MaxFloat32
)
inMin, hasInMin := scalarValue(l.InputMin)
inMax, hasInMax := scalarValue(l.InputMax)
outMin, hasOutMin := scalarValue(l.OutputMin)
outMax, hasOutMax := scalarValue(l.OutputMax)
if !(hasInMin || hasInMax || hasOutMin || hasOutMax) {
return
}
l.hasClamp = true
l.inMin = defaultMin
l.inMax = defaultMax
l.outMin = defaultMin
l.outMax = defaultMax
if hasInMin {
l.inMin = inMin
}
if hasInMax {
l.inMax = inMax
}
if hasOutMin {
l.outMin = outMin
}
if hasOutMax {
l.outMax = outMax
}
}
func (l *ClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
if l.hasClamp {
x = x.Clamp(ctx, l.inMin, l.inMax)
}
out := l.Weight.Mulmat(ctx, x)
if l.hasClamp {
out = out.Clamp(ctx, l.outMin, l.outMax)
}
return out
}
// InitClamp distributes packed clamp values from v.clamp_data to ClippableLinear structs.
// If scalar clamp tensors (input_min/max, output_min/max) are present, they are used too.
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
// then 4 floats for the projector.
func (m *VisionModel) InitClamp(proj *MultiModalProjector) {
if m.clampInitDone {
return
}
m.clampInitDone = true
linears := func(l *VisionEncoderLayer) []*ClippableLinear {
return []*ClippableLinear{
l.SelfAttention.Query, l.SelfAttention.Key, l.SelfAttention.Value,
l.SelfAttention.Output, l.MLP.Gate, l.MLP.Up, l.MLP.Down,
}
}
for i := range m.Layers {
for _, cl := range linears(&m.Layers[i]) {
if cl != nil {
cl.loadClampFromScalars()
}
}
}
if proj != nil && proj.Projection != nil {
proj.Projection.loadClampFromScalars()
}
// Load packed clamp data when present (legacy Ollama format).
if m.ClampData == nil {
return
}
// Read all clamp values from packed F32 tensor
data := m.ClampData.BackendGet()
if len(data) == 0 {
return
}
// Distribute to layer linears: 7 per layer × 4 values each
for i := range m.Layers {
for li, cl := range linears(&m.Layers[i]) {
if cl == nil {
continue
}
idx := (i*7 + li) * 4
if idx+3 < len(data) {
cl.inMin = data[idx]
cl.inMax = data[idx+1]
cl.outMin = data[idx+2]
cl.outMax = data[idx+3]
cl.hasClamp = true
}
}
}
// Projector clamp values (last 4 floats)
if proj != nil && proj.Projection != nil {
projIdx := len(m.Layers) * 7 * 4
if projIdx+3 < len(data) {
proj.Projection.inMin = data[projIdx]
proj.Projection.inMax = data[projIdx+1]
proj.Projection.outMin = data[projIdx+2]
proj.Projection.outMax = data[projIdx+3]
proj.Projection.hasClamp = true
}
}
}
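The packed v.clamp_data offsets used above follow a simple stride rule. A small sketch with hypothetical helper names mirroring the indexing in InitClamp:

```go
package main

import "fmt"

// clampOffset returns the start index of the 4-float clamp record
// (inMin, inMax, outMin, outMax) for linear li within layer i, matching
// the (i*7 + li)*4 indexing above; the 7 linears per layer are
// q, k, v, out, gate, up, down.
func clampOffset(layer, linear int) int { return (layer*7 + linear) * 4 }

// projOffset is where the projector's 4-float record begins: immediately
// after all per-layer records.
func projOffset(numLayers int) int { return numLayers * 7 * 4 }

func main() {
	fmt.Println(clampOffset(0, 0)) // first record starts at 0
	fmt.Println(clampOffset(2, 5)) // (2*7+5)*4 = 76
	fmt.Println(projOffset(27))    // 27*7*4 = 756
}
```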
type VisionSelfAttention struct {
Query *ClippableLinear `gguf:"attn_q"`
Key *ClippableLinear `gguf:"attn_k"`
Value *ClippableLinear `gguf:"attn_v"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Output *ClippableLinear `gguf:"attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
numPatches := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
// Q/K norms (Gemma-style: x * (1 + weight) / rms(x))
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
// V norm (RMSNorm without learned weights)
value = value.RMSNorm(ctx, nil, opts.eps)
// 2D RoPE: split head dim in half, apply NeoX RoPE with x positions to first half,
// y positions to second half, then concatenate.
halfDim := headDim / 2
ropeOpts := rope.WithTypeNeoX()
qFirst := query.View(ctx, 0, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
qFirst = nn.RoPE(ctx, qFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)
kFirst := key.View(ctx, 0, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
kFirst = nn.RoPE(ctx, kFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)
halfOffset := halfDim * query.Stride(0)
qSecond := query.View(ctx, halfOffset, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
qSecond = nn.RoPE(ctx, qSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)
halfOffsetK := halfDim * key.Stride(0)
kSecond := key.View(ctx, halfOffsetK, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
kSecond = nn.RoPE(ctx, kSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)
query = qFirst.Concat(ctx, qSecond, 0)
key = kFirst.Concat(ctx, kSecond, 0)
// Use flash attention for numerical stability (handles large attention scores
// from unclamped RMSNorm weights, e.g. 26B has addOne weights up to 19.5)
attention := nn.Attention(ctx, query, key, value, 1.0, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *ClippableLinear `gguf:"ffn_gate"`
Up *ClippableLinear `gguf:"ffn_up"`
Down *ClippableLinear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
gate := mlp.Gate.Forward(ctx, hiddenState)
up := mlp.Up.Forward(ctx, hiddenState)
hiddenState = gate.QuickGELU(ctx, up)
return mlp.Down.Forward(ctx, hiddenState)
}
type VisionEncoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"attn_post_norm"`
FFNNorm *nn.RMSNorm `gguf:"ln2"`
MLP *VisionMLP
PostFFNNorm *nn.RMSNorm `gguf:"ffn_post_norm"`
LayerOutputScale ml.Tensor `gguf:"out_scale.weight"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// Pre-attention norm -> self attention -> post-attention norm
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, posX, posY, attnMask, opts)
hiddenState = e.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// Residual connection
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// Pre-FFN norm -> FFN -> post-FFN norm
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState)
hiddenState = e.PostFFNNorm.Forward(ctx, hiddenState, opts.eps)
// Residual connection
hiddenState = hiddenState.Add(ctx, residual)
// Per-layer output scale
if e.LayerOutputScale != nil {
hiddenState = hiddenState.Mul(ctx, e.LayerOutputScale)
}
return hiddenState
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
patchSize int
nMerge int
eps float32
ropeTheta float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
PositionEmbedding ml.Tensor `gguf:"position_embd.weight"`
ClampData ml.Tensor `gguf:"clamp_data"`
StdBias ml.Tensor `gguf:"std_bias"`
StdScale ml.Tensor `gguf:"std_scale"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
clampInitDone bool
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, numPatchesX, numPatchesY int) ml.Tensor {
numPatches := numPatchesX * numPatchesY
// Patch embedding via Conv2D
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Conv2D with F16 weights produces F16 output via im2col; cast to F32 for encoder precision
hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
// 2D positional embeddings from 3D tensor [nEmbd, maxPos, 2]
posSize := m.PositionEmbedding.Dim(1)
nb1 := m.PositionEmbedding.Stride(1)
tblX := m.PositionEmbedding.View(ctx, 0, m.hiddenSize, nb1, posSize)
tblY := m.PositionEmbedding.View(ctx, posSize*nb1, m.hiddenSize, nb1, posSize)
// Position indices for patches
posXData := make([]int32, numPatches)
posYData := make([]int32, numPatches)
for i := range numPatches {
posXData[i] = int32(i % numPatchesX)
posYData[i] = int32(i / numPatchesX)
}
posXEmb := ctx.Input().FromInts(posXData, numPatches)
posYEmb := ctx.Input().FromInts(posYData, numPatches)
hiddenState = hiddenState.Add(ctx, tblX.Rows(ctx, posXEmb))
hiddenState = hiddenState.Add(ctx, tblY.Rows(ctx, posYEmb))
// No attention mask — all positions are real patches
var attnMask ml.Tensor
// RoPE positions
posXRope := ctx.Input().FromInts(posXData, numPatches)
posYRope := ctx.Input().FromInts(posYData, numPatches)
// Vision transformer layers
for i := range m.Layers {
hiddenState = m.Layers[i].Forward(ctx, hiddenState, posXRope, posYRope, attnMask, m.VisionModelOptions)
}
return hiddenState
}
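The x/y indices built in Forward (and reused for both positional embeddings and 2D RoPE) are just row-major patch coordinates. A sketch with a hypothetical helper name:

```go
package main

import "fmt"

// patchPositions returns the x (column) and y (row) coordinate for each
// patch in row-major order, as computed inline in VisionModel.Forward.
func patchPositions(numPatchesX, numPatchesY int) (xs, ys []int32) {
	n := numPatchesX * numPatchesY
	xs = make([]int32, n)
	ys = make([]int32, n)
	for i := 0; i < n; i++ {
		xs[i] = int32(i % numPatchesX)
		ys[i] = int32(i / numPatchesX)
	}
	return xs, ys
}

func main() {
	xs, ys := patchPositions(3, 2)
	fmt.Println(xs) // [0 1 2 0 1 2]
	fmt.Println(ys) // [0 0 0 1 1 1]
}
```

The x coordinates drive RoPE on the first half of each head dimension and the y coordinates the second half.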
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
patchSize: int(c.Uint("vision.patch_size", 16)),
nMerge: int(c.Uint("vision.projector.scale_factor", 3)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
ropeTheta: 100.0,
},
}
}
func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, numPatchesY int, opts *VisionModelOptions, proj *MultiModalProjector, stdBias, stdScale ml.Tensor) ml.Tensor {
hiddenSize := opts.hiddenSize
// Reshape from [hiddenSize, numPatches] to spatial layout for pooling
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = hiddenState.Reshape(ctx, numPatchesX, numPatchesY, hiddenSize)
// AvgPool2D with kernel=stride=nMerge
hiddenState = hiddenState.AvgPool2D(ctx, opts.nMerge, opts.nMerge, 0)
// Reshape back to [hiddenSize, numMergedPatches]
mergedX := numPatchesX / opts.nMerge
mergedY := numPatchesY / opts.nMerge
hiddenState = hiddenState.Reshape(ctx, mergedX*mergedY, hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(hiddenSize)))
// Optional vision standardization before projection.
if stdBias != nil && stdScale != nil {
hiddenState = hiddenState.Sub(ctx, stdBias)
hiddenState = hiddenState.Mul(ctx, stdScale)
}
// Project to text embedding dimension
hiddenState = proj.Forward(ctx, hiddenState, opts.eps)
return hiddenState
}


@@ -0,0 +1,280 @@
package gemma4
import (
"encoding/binary"
"fmt"
"math"
"math/cmplx"
)
// Audio preprocessing constants.
const (
audioSampleRate = 16000
melBins = 128
frameLengthMs = 20.0
hopLengthMs = 10.0
minFrequency = 0.0
maxFrequency = 8000.0
melFloor = 1e-3
maxAudioSoftTokens = 750
)
// Computed from the above constants.
var (
frameLength = int(math.Round(audioSampleRate * frameLengthMs / 1000.0)) // 320
hopLength = int(math.Round(audioSampleRate * hopLengthMs / 1000.0)) // 160
)
// decodeWAV extracts mono float32 PCM samples from a WAV file, resampled to 16kHz.
func decodeWAV(data []byte) ([]float32, error) {
if len(data) < 12 {
return nil, fmt.Errorf("WAV file too short")
}
if string(data[0:4]) != "RIFF" || string(data[8:12]) != "WAVE" {
return nil, fmt.Errorf("not a WAV file")
}
var audioFormat uint16
var numChannels, sampleRate, bitsPerSample int
var audioData []byte
foundFmt := false
offset := 12
for offset+8 <= len(data) {
chunkID := string(data[offset : offset+4])
chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
chunkData := data[offset+8 : min(offset+8+chunkSize, len(data))]
switch chunkID {
case "fmt ":
if len(chunkData) < 16 {
return nil, fmt.Errorf("fmt chunk too short")
}
audioFormat = binary.LittleEndian.Uint16(chunkData[0:2])
numChannels = int(binary.LittleEndian.Uint16(chunkData[2:4]))
sampleRate = int(binary.LittleEndian.Uint32(chunkData[4:8]))
bitsPerSample = int(binary.LittleEndian.Uint16(chunkData[14:16]))
if audioFormat == 0xFFFE && len(chunkData) >= 26 {
audioFormat = binary.LittleEndian.Uint16(chunkData[24:26])
}
foundFmt = true
case "data":
audioData = chunkData
}
offset += 8 + chunkSize
if chunkSize%2 != 0 {
offset++
}
}
if !foundFmt {
return nil, fmt.Errorf("no fmt chunk found in WAV file")
}
if audioFormat != 1 && audioFormat != 3 {
return nil, fmt.Errorf("unsupported WAV format: %d (need PCM=1 or float=3)", audioFormat)
}
if audioData == nil {
return nil, fmt.Errorf("no data chunk found in WAV file")
}
samples := decodeWAVSamples(audioData, audioFormat, bitsPerSample, numChannels)
if sampleRate != audioSampleRate {
samples = resampleLinear(samples, sampleRate, audioSampleRate)
}
return samples, nil
}
func decodeWAVSamples(data []byte, format uint16, bits, channels int) []float32 {
bytesPerSample := bits / 8
totalSamples := len(data) / (bytesPerSample * channels)
mono := make([]float32, totalSamples)
for i := range totalSamples {
var sum float64
for ch := range channels {
off := (i*channels + ch) * bytesPerSample
if off+bytesPerSample > len(data) {
break
}
switch {
case format == 1 && bits == 16:
v := int16(binary.LittleEndian.Uint16(data[off : off+2]))
sum += float64(v) / 32768.0
case format == 1 && bits == 32:
v := int32(binary.LittleEndian.Uint32(data[off : off+4]))
sum += float64(v) / 2147483648.0
case format == 1 && bits == 24:
v := int32(data[off]) | int32(data[off+1])<<8 | int32(data[off+2])<<16
if v&0x800000 != 0 {
v |= ^0xFFFFFF
}
sum += float64(v) / 8388608.0
case format == 3 && bits == 32:
v := math.Float32frombits(binary.LittleEndian.Uint32(data[off : off+4]))
sum += float64(v)
case format == 1 && bits == 8:
sum += (float64(data[off]) - 128.0) / 128.0
}
}
mono[i] = float32(sum / float64(channels))
}
return mono
}
func resampleLinear(samples []float32, fromRate, toRate int) []float32 {
n := int(float64(len(samples)) / float64(fromRate) * float64(toRate))
out := make([]float32, n)
for i := range n {
pos := float64(i) * float64(len(samples)-1) / float64(n-1)
idx := int(pos)
frac := float32(pos - float64(idx))
if idx+1 < len(samples) {
out[i] = samples[idx]*(1-frac) + samples[idx+1]*frac
} else {
out[i] = samples[idx]
}
}
return out
}
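A quick check of the resampler's length math: upsampling 8 kHz input to 16 kHz doubles the sample count, with endpoints preserved. This sketch duplicates the function above so it runs standalone:

```go
package main

import "fmt"

// resampleLinear mirrors the helper above: the output length scales by
// toRate/fromRate, and each output sample linearly interpolates between
// its two nearest input samples.
func resampleLinear(samples []float32, fromRate, toRate int) []float32 {
	n := int(float64(len(samples)) / float64(fromRate) * float64(toRate))
	out := make([]float32, n)
	for i := 0; i < n; i++ {
		pos := float64(i) * float64(len(samples)-1) / float64(n-1)
		idx := int(pos)
		frac := float32(pos - float64(idx))
		if idx+1 < len(samples) {
			out[i] = samples[idx]*(1-frac) + samples[idx+1]*frac
		} else {
			out[i] = samples[idx]
		}
	}
	return out
}

func main() {
	in := []float32{0, 1, 2, 3}
	out := resampleLinear(in, 8000, 16000)
	fmt.Println(len(out))               // 8: twice the input length
	fmt.Println(out[0], out[len(out)-1]) // endpoints map to endpoints
}
```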
// computeMelSpectrogram computes the log mel spectrogram from PCM samples.
// Returns shape [numFrames, melBins] as float32 slice, and numFrames.
func computeMelSpectrogram(samples []float32) ([]float32, int) {
fftLen := 1
for fftLen < frameLength {
fftLen <<= 1
}
fftLen *= 2 // fft_overdrive=True
// Hann window with nonzero endpoints (cosine argument offset by 0.5 so neither end is zero).
window := make([]float64, frameLength)
arg := math.Pi * 2.0 / float64(frameLength)
for i := range frameLength {
window[i] = 0.5 - 0.5*math.Cos(arg*(float64(i)+0.5))
}
numFreqBins := fftLen/2 + 1
melFilters := buildMelFilterBank(numFreqBins, melBins, minFrequency, maxFrequency, audioSampleRate)
frameSizeForUnfold := frameLength + 1
numFrames := (len(samples) - frameSizeForUnfold) / hopLength
if numFrames <= 0 {
return nil, 0
}
result := make([]float32, numFrames*melBins)
fftInput := make([]complex128, fftLen)
for f := range numFrames {
start := f * hopLength
for i := range frameLength {
fftInput[i] = complex(float64(samples[start+i])*window[i], 0)
}
for i := frameLength; i < fftLen; i++ {
fftInput[i] = 0
}
fft(fftInput)
for m := range melBins {
var melVal float64
for k := range numFreqBins {
mag := cmplx.Abs(fftInput[k])
melVal += mag * float64(melFilters[k*melBins+m])
}
if melVal < melFloor {
melVal = melFloor
}
result[f*melBins+m] = float32(math.Log(melVal))
}
}
return result, numFrames
}
func buildMelFilterBank(numFreqBins, numMels int, fMin, fMax float64, sr int) []float32 {
hzToMel := func(f float64) float64 {
return 2595.0 * math.Log10(1.0+f/700.0)
}
melToHz := func(m float64) float64 {
return 700.0 * (math.Pow(10.0, m/2595.0) - 1.0)
}
melMin := hzToMel(fMin)
melMax := hzToMel(fMax)
melPts := make([]float64, numMels+2)
for i := range melPts {
melPts[i] = melMin + float64(i)*(melMax-melMin)/float64(numMels+1)
}
filterFreqs := make([]float64, numMels+2)
for i, m := range melPts {
filterFreqs[i] = melToHz(m)
}
fftFreqs := make([]float64, numFreqBins)
for i := range fftFreqs {
fftFreqs[i] = float64(i) * float64(sr) / float64(2*(numFreqBins-1))
}
filters := make([]float32, numFreqBins*numMels)
for m := range numMels {
fLeft := filterFreqs[m]
fCenter := filterFreqs[m+1]
fRight := filterFreqs[m+2]
for k := range numFreqBins {
f := fftFreqs[k]
var v float64
if f >= fLeft && f <= fCenter && fCenter > fLeft {
v = (f - fLeft) / (fCenter - fLeft)
} else if f > fCenter && f <= fRight && fRight > fCenter {
v = (fRight - f) / (fRight - fCenter)
}
if v > 0 {
filters[k*numMels+m] = float32(v)
}
}
}
return filters
}
// fft performs an in-place Cooley-Tukey radix-2 FFT.
func fft(x []complex128) {
n := len(x)
if n <= 1 {
return
}
j := 0
for i := 1; i < n; i++ {
bit := n >> 1
for j&bit != 0 {
j ^= bit
bit >>= 1
}
j ^= bit
if i < j {
x[i], x[j] = x[j], x[i]
}
}
for size := 2; size <= n; size <<= 1 {
halfSize := size / 2
w := complex(math.Cos(2*math.Pi/float64(size)), -math.Sin(2*math.Pi/float64(size)))
for start := 0; start < n; start += size {
wn := complex(1, 0)
for k := range halfSize {
t := wn * x[start+k+halfSize]
x[start+k+halfSize] = x[start+k] - t
x[start+k] = x[start+k] + t
wn *= w
}
}
}
}
// isAudioData checks if the data starts with WAV magic bytes.
func isAudioData(data []byte) bool {
return len(data) >= 12 && string(data[0:4]) == "RIFF" && string(data[8:12]) == "WAVE"
}


@@ -0,0 +1,103 @@
package gemma4
import (
"image"
"math"
"golang.org/x/image/draw"
"github.com/ollama/ollama/fs"
)
type ImageProcessor struct {
patchSize int
numChannels int
nMerge int
minPixels int
maxPixels int
}
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 16))
nMerge := int(c.Uint("vision.projector.scale_factor", 3))
numChannels := int(c.Uint("vision.num_channels", 3))
// Token limits from reference: min=40, max=280 output tokens after pooling.
// Convert to pixel counts: tokens * nMerge^2 * patchSize^2
minTokens := 40
maxTokens := 280
patchArea := patchSize * patchSize * nMerge * nMerge
minPixels := minTokens * patchArea
maxPixels := maxTokens * patchArea
return ImageProcessor{
patchSize: patchSize,
numChannels: numChannels,
nMerge: nMerge,
minPixels: minPixels,
maxPixels: maxPixels,
}
}
// ProcessImage resizes an image preserving aspect ratio, aligning dimensions
// to (patchSize * nMerge) boundaries, and normalizes pixels to [-1, 1].
// Returns the float32 pixel data and the actual output dimensions.
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, error) {
// Compute target size preserving aspect ratio
alignSize := p.patchSize * p.nMerge
targetW, targetH := p.smartResize(img.Bounds().Dx(), img.Bounds().Dy(), alignSize)
// Resize directly without alpha compositing, matching MLX reference.
dst := image.NewRGBA(image.Rect(0, 0, targetW, targetH))
draw.BiLinear.Scale(dst, dst.Bounds(), img, img.Bounds(), draw.Src, nil)
// Normalize to [-1, 1] using mean=0.5, std=0.5: (pixel/255 - 0.5) / 0.5 = 2*pixel/255 - 1
data := p.pack(dst)
return data, targetW, targetH, nil
}
// smartResize computes target dimensions that preserve aspect ratio and
// align to alignSize boundaries. It scales the image to fill the maximum
// patch budget (maxPixels), matching the MLX reference.
func (p *ImageProcessor) smartResize(origW, origH, alignSize int) (int, int) {
totalPx := origW * origH
var targetW, targetH int
if p.maxPixels > 0 && totalPx > 0 {
factor := math.Sqrt(float64(p.maxPixels) / float64(totalPx))
targetH = max(alignSize, int(math.Floor(factor*float64(origH)/float64(alignSize)))*alignSize)
targetW = max(alignSize, int(math.Floor(factor*float64(origW)/float64(alignSize)))*alignSize)
} else {
targetH = max(alignSize, (origH/alignSize)*alignSize)
targetW = max(alignSize, (origW/alignSize)*alignSize)
}
return targetW, targetH
}
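Using the defaults from newImageProcessor above (patchSize=16, nMerge=3, maxTokens=280), the pixel budget is 280 * 48 * 48. A standalone sketch of smartResize can check the two invariants the function guarantees: both output dimensions are multiples of the alignment size, and the scaled image fits the budget:

```go
package main

import (
	"fmt"
	"math"
)

// smartResize mirrors the method above: scale to fill maxPixels while
// snapping both dimensions down to multiples of alignSize.
func smartResize(origW, origH, alignSize, maxPixels int) (int, int) {
	totalPx := origW * origH
	if maxPixels <= 0 || totalPx <= 0 {
		w := max(alignSize, (origW/alignSize)*alignSize)
		h := max(alignSize, (origH/alignSize)*alignSize)
		return w, h
	}
	factor := math.Sqrt(float64(maxPixels) / float64(totalPx))
	h := max(alignSize, int(math.Floor(factor*float64(origH)/float64(alignSize)))*alignSize)
	w := max(alignSize, int(math.Floor(factor*float64(origW)/float64(alignSize)))*alignSize)
	return w, h
}

func main() {
	// Defaults from newImageProcessor: patchSize=16, nMerge=3, maxTokens=280.
	align := 16 * 3
	budget := 280 * align * align
	w, h := smartResize(1920, 1080, align, budget)
	fmt.Println(w, h)
	if w%align != 0 || h%align != 0 || w*h > budget {
		panic("dimensions must align and fit the pixel budget")
	}
}
```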
// pack extracts RGB values from an image and normalizes to [-1, 1].
// Returns channel-first layout: [R..., G..., B...].
func (p *ImageProcessor) pack(img image.Image) []float32 {
bounds := img.Bounds()
w := bounds.Dx()
h := bounds.Dy()
size := w * h
pixelVals := make([]float32, 3*size)
rOff, gOff, bOff := 0, size, 2*size
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := img.At(x, y)
r, g, b, _ := c.RGBA()
idx := (y-bounds.Min.Y)*w + (x - bounds.Min.X)
// Normalize [0, 255] -> [-1, 1]: 2 * (val/255) - 1
pixelVals[rOff+idx] = float32(r>>8)/255.0*2.0 - 1.0
pixelVals[gOff+idx] = float32(g>>8)/255.0*2.0 - 1.0
pixelVals[bOff+idx] = float32(b>>8)/255.0*2.0 - 1.0
}
}
return pixelVals
}


@@ -0,0 +1,102 @@
package gemma4
import (
"os"
"testing"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
// TestTokenizerMatchesHF compares our tokenizer output against HuggingFace reference tokens.
func TestTokenizerMatchesHF(t *testing.T) {
modelPath := os.Getenv("GEMMA4_MODEL_PATH")
if modelPath == "" {
t.Skip("set GEMMA4_MODEL_PATH to a gemma4 GGUF file")
}
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatalf("Failed to load model: %v", err)
}
defer m.Backend().Close()
tok := m.(tokenizer.Tokenizer)
tests := []struct {
name string
input string
expected []int32
}{
{
name: "simple",
input: "Hello, world!",
expected: []int32{9259, 236764, 1902, 236888},
},
{
name: "special_tokens",
input: "<|turn>user\nWhat is 2+2?<turn|>\n<|turn>model\n",
expected: []int32{105, 2364, 107, 3689, 563, 236743, 236778, 236862, 236778, 236881, 106, 107, 105, 4368, 107},
},
{
name: "tool_declaration",
input: "<|tool>declaration:bash{description:<|\"|>Run a command<|\"|>}<tool|>",
expected: []int32{46, 163688, 236787, 42422, 236782, 7777, 236787, 52, 7306, 496, 4991, 52, 236783, 47},
},
{
name: "tool_call",
input: "<|tool_call>call:bash{command:<|\"|>ls -la<|\"|>}<tool_call|>",
expected: []int32{48, 6639, 236787, 42422, 236782, 7674, 236787, 52, 5629, 753, 2149, 52, 236783, 49},
},
{
name: "thinking",
input: "<|channel>thought\nLet me think about this...<channel|>The answer is 42.",
expected: []int32{100, 45518, 107, 6481, 786, 1751, 1003, 672, 1390, 101, 818, 3890, 563, 236743, 236812, 236778, 236761},
},
{
name: "code",
input: "func main() { fmt.Println(\"hello\") }",
expected: []int32{6823, 1689, 825, 642, 22766, 236761, 29006, 885, 23391, 1373, 682},
},
{
name: "numbers",
input: "The answer is 42, not 43.5 or -1",
expected: []int32{818, 3890, 563, 236743, 236812, 236778, 236764, 711, 236743, 236812, 236800, 236761, 236810, 653, 753, 236770},
},
{
name: "mixed_chat_with_tools",
input: "<|turn>system\nYou are a helpful assistant.\n<|tool>declaration:get_weather{description:<|\"|>Get weather<|\"|>,parameters:{properties:{city:{type:<|\"|>STRING<|\"|>}},type:<|\"|>OBJECT<|\"|>}}<tool|><turn|>\n<|turn>user\nWhat's the weather in Paris?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
expected: []int32{105, 9731, 107, 3048, 659, 496, 11045, 16326, 236761, 107, 46, 163688, 236787, 828, 236779, 19323, 236782, 7777, 236787, 52, 3407, 7606, 52, 236764, 19031, 29616, 15921, 29616, 13319, 29616, 2084, 236787, 52, 35410, 52, 5237, 2084, 236787, 52, 60688, 52, 1807, 47, 106, 107, 105, 2364, 107, 3689, 236789, 236751, 506, 7606, 528, 9079, 236881, 106, 107, 105, 4368, 107, 100, 45518, 107, 101},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokens, err := tok.Encode(tt.input, false) // no BOS
if err != nil {
t.Fatalf("encode error: %v", err)
}
if len(tokens) != len(tt.expected) {
t.Errorf("token count mismatch: got %d, want %d", len(tokens), len(tt.expected))
t.Logf("got: %v", tokens)
t.Logf("want: %v", tt.expected)
return
}
mismatches := 0
for i := range tokens {
if tokens[i] != tt.expected[i] {
mismatches++
if mismatches <= 5 {
t.Errorf("mismatch at [%d]: got %d, want %d", i, tokens[i], tt.expected[i])
}
}
}
if mismatches > 5 {
t.Errorf("... and %d more mismatches", mismatches-5)
}
})
}
}


@@ -7,6 +7,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/gemma4"
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/glmocr"
_ "github.com/ollama/ollama/model/models/gptoss"

model/parsers/gemma4.go

@@ -0,0 +1,412 @@
package parsers
import (
"encoding/json"
"errors"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
type Gemma4ParserState int
const (
Gemma4CollectingContent Gemma4ParserState = iota
Gemma4CollectingThinking
Gemma4CollectingToolCall
)
const (
gemma4ThinkingOpenTag = "<|channel>"
gemma4ThinkingCloseTag = "<channel|>"
gemma4ToolCallOpenTag = "<|tool_call>"
gemma4ToolCallCloseTag = "<tool_call|>"
)
type Gemma4Parser struct {
state Gemma4ParserState
buffer strings.Builder
hasThinkingSupport bool
thinkingEnabled bool // true when both model supports and user requested thinking
needsChannelNameStrip bool // true when we just entered thinking and need to strip "thought\n"
}
func (p *Gemma4Parser) HasToolSupport() bool {
return true
}
func (p *Gemma4Parser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *Gemma4Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
p.thinkingEnabled = p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
if !p.thinkingEnabled {
p.state = Gemma4CollectingContent
return tools
}
if prefill && lastMessage.Content != "" {
p.state = Gemma4CollectingContent
return tools
}
// When thinking is enabled, start in content mode but we'll switch to
// thinking when we see <|channel>. The model typically starts with
// <|channel> immediately when thinking is enabled.
p.state = Gemma4CollectingContent
return tools
}
type gemma4Event interface {
isGemma4Event()
}
type gemma4EventThinkingContent struct {
content string
}
type gemma4EventContent struct {
content string
}
type gemma4EventToolCall struct {
toolCall api.ToolCall
}
func (gemma4EventThinkingContent) isGemma4Event() {}
func (gemma4EventContent) isGemma4Event() {}
func (gemma4EventToolCall) isGemma4Event() {}
func (p *Gemma4Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents(done)
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case gemma4EventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case gemma4EventThinkingContent:
if p.thinkingEnabled {
thinkingSb.WriteString(event.content)
}
// When thinking is disabled, silently discard channel content
case gemma4EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *Gemma4Parser) parseEvents(done bool) []gemma4Event {
var all []gemma4Event
keepLooping := true
for keepLooping {
var events []gemma4Event
events, keepLooping = p.eat(done)
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
// longestOverlap returns the longest overlap between the suffix of bufStr and
// a prefix of any of the given tags.
func longestOverlap(bufStr string, tags ...string) int {
maxOverlap := 0
for _, tag := range tags {
if o := overlap(bufStr, tag); o > maxOverlap {
maxOverlap = o
}
}
return maxOverlap
}
func (p *Gemma4Parser) eat(done bool) ([]gemma4Event, bool) {
var events []gemma4Event
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case Gemma4CollectingContent:
// Check for thinking open tag
if idx := strings.Index(bufStr, gemma4ThinkingOpenTag); idx != -1 {
contentBefore := bufStr[:idx]
remaining := bufStr[idx+len(gemma4ThinkingOpenTag):]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = Gemma4CollectingThinking
p.needsChannelNameStrip = true
if contentBefore = strings.TrimRightFunc(contentBefore, unicode.IsSpace); len(contentBefore) > 0 {
events = append(events, gemma4EventContent{content: contentBefore})
}
return events, true
}
// Check for tool call open tag
if idx := strings.Index(bufStr, gemma4ToolCallOpenTag); idx != -1 {
contentBefore := bufStr[:idx]
remaining := bufStr[idx+len(gemma4ToolCallOpenTag):]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = Gemma4CollectingToolCall
if contentBefore = strings.TrimRightFunc(contentBefore, unicode.IsSpace); len(contentBefore) > 0 {
events = append(events, gemma4EventContent{content: contentBefore})
}
return events, true
}
// Check for partial tag overlap
if !done {
if overlapLen := longestOverlap(bufStr, gemma4ThinkingOpenTag, gemma4ToolCallOpenTag); overlapLen > 0 {
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, gemma4EventContent{content: unambiguous})
}
return events, false
}
}
// No tags found, emit all content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, gemma4EventContent{content: bufStr})
}
return events, false
case Gemma4CollectingThinking:
// Strip channel name (e.g., "thought\n") after <|channel>.
// Gemma 4 format: <|channel>thought\n...content...<channel|>
// In streaming mode, "thought" and "\n" may arrive in separate chunks.
if p.needsChannelNameStrip {
if strings.HasPrefix(bufStr, "thought\n") {
bufStr = bufStr[len("thought\n"):]
p.buffer.Reset()
p.buffer.WriteString(bufStr)
p.needsChannelNameStrip = false
} else if !done && (bufStr == "thought" || strings.HasPrefix("thought\n", bufStr)) {
// Partial match — wait for more data.
return events, false
} else {
// No match (different channel name or no newline) — don't strip.
p.needsChannelNameStrip = false
}
}
if strings.Contains(bufStr, gemma4ThinkingCloseTag) {
split := strings.SplitN(bufStr, gemma4ThinkingCloseTag, 2)
thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = Gemma4CollectingContent
if len(thinking) > 0 {
events = append(events, gemma4EventThinkingContent{content: thinking})
}
return events, true
}
// Check for partial close tag
if !done {
if overlapLen := overlap(bufStr, gemma4ThinkingCloseTag); overlapLen > 0 {
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, gemma4EventThinkingContent{content: unambiguous})
}
return events, false
}
}
// No close tag, emit thinking content (hold back trailing whitespace)
if !done {
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, gemma4EventThinkingContent{content: unambiguous})
}
} else {
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, gemma4EventThinkingContent{content: bufStr})
}
}
return events, false
case Gemma4CollectingToolCall:
if idx := strings.Index(bufStr, gemma4ToolCallCloseTag); idx != -1 {
toolCallContent := bufStr[:idx]
remaining := bufStr[idx+len(gemma4ToolCallCloseTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = Gemma4CollectingContent
if toolCall, err := parseGemma4ToolCall(toolCallContent); err == nil {
events = append(events, gemma4EventToolCall{toolCall: toolCall})
} else {
slog.Warn("gemma4 tool call parsing failed", "error", err, "content", toolCallContent)
}
return events, true
}
// If done, flush any accumulated tool call content even without closing tag.
// The model may hit a stop token before emitting <tool_call|>.
if done && len(bufStr) > 0 {
p.buffer.Reset()
p.state = Gemma4CollectingContent
if toolCall, err := parseGemma4ToolCall(bufStr); err == nil {
events = append(events, gemma4EventToolCall{toolCall: toolCall})
} else {
slog.Warn("gemma4 tool call flush on done failed", "error", err, "content", bufStr)
}
return events, false
}
// Wait for closing tag
return events, false
}
return events, false
}
// parseGemma4ToolCall parses a tool call in Gemma 4 format:
// call:NAME{key:value,key:value}
func parseGemma4ToolCall(content string) (api.ToolCall, error) {
// Expected format: call:NAME{args}
if !strings.HasPrefix(content, "call:") {
return api.ToolCall{}, errors.New("expected 'call:' prefix")
}
content = content[len("call:"):]
// Find the opening brace for args
braceIdx := strings.Index(content, "{")
if braceIdx == -1 {
return api.ToolCall{}, errors.New("expected '{' in tool call")
}
toolName := strings.TrimSpace(content[:braceIdx])
argsStr := content[braceIdx:]
// Convert Gemma 4 argument format to JSON
jsonStr := gemma4ArgsToJSON(argsStr)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(jsonStr), &args); err != nil {
return api.ToolCall{}, err
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: toolName,
Arguments: args,
},
}, nil
}
// gemma4ArgsToJSON converts Gemma 4's custom argument format to valid JSON.
func gemma4ArgsToJSON(s string) string {
s = strings.ReplaceAll(s, `<|"|>`, `"`)
var buf strings.Builder
buf.Grow(len(s) + 32)
inString := false
hex := "0123456789abcdef"
i := 0
for i < len(s) {
ch := s[i]
if ch == '"' {
inString = !inString
buf.WriteByte('"')
i++
continue
}
if inString {
switch ch {
case '\\':
buf.WriteString(`\\`)
case '\n':
buf.WriteString(`\n`)
case '\r':
buf.WriteString(`\r`)
case '\t':
buf.WriteString(`\t`)
case '\b':
buf.WriteString(`\b`)
case '\f':
buf.WriteString(`\f`)
default:
if ch < 0x20 {
buf.WriteString(`\u00`)
buf.WriteByte(hex[ch>>4])
buf.WriteByte(hex[ch&0x0f])
} else {
buf.WriteByte(ch)
}
}
i++
continue
}
if isIdentStart(ch) { // inString is always false here; the branch above continues
j := i + 1
for j < len(s) && isIdentPart(s[j]) {
j++
}
word := s[i:j]
if j < len(s) && s[j] == ':' {
buf.WriteByte('"')
buf.WriteString(word)
buf.WriteByte('"')
} else {
buf.WriteString(word)
}
i = j
} else {
buf.WriteByte(ch)
i++
}
}
return buf.String()
}


@@ -0,0 +1,463 @@
package parsers
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGemma4Parser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
expectedToolCalls []api.ToolCall
thinkingEnabled bool
lastMessage *api.Message
}{
{
name: "simple_content",
input: "This is a simple response.",
expectedContent: "This is a simple response.",
},
{
name: "thinking_then_content",
input: "<|channel>thought\nLet me think about this...<channel|>The answer is 42.",
expectedContent: "The answer is 42.",
expectedThinking: "Let me think about this...",
thinkingEnabled: true,
},
{
name: "multiple_thinking_blocks",
input: "<|channel>first thought<channel|><|channel>second thought<channel|>Final answer.",
expectedContent: "Final answer.",
expectedThinking: "first thoughtsecond thought",
thinkingEnabled: true,
},
{
name: "thinking_only_no_content",
input: "<|channel>just thinking<channel|>",
expectedContent: "",
expectedThinking: "just thinking",
thinkingEnabled: true,
},
{
name: "tool_call_simple",
input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{
name: "tool_call_with_multiple_args",
input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>,units:<|"|>metric<|"|>}<tool_call|>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
"units": "metric",
}),
},
},
},
},
{
name: "tool_call_with_number_arg",
input: `<|tool_call>call:set_temp{value:42}<tool_call|>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_temp",
Arguments: testArgs(map[string]any{
"value": 42.0,
}),
},
},
},
},
{
name: "tool_call_with_boolean_arg",
input: `<|tool_call>call:toggle{enabled:true}<tool_call|>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "toggle",
Arguments: testArgs(map[string]any{
"enabled": true,
}),
},
},
},
},
{
name: "tool_call_with_nested_object",
input: `<|tool_call>call:process{config:{enabled:true,name:<|"|>test<|"|>}}<tool_call|>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process",
Arguments: testArgs(map[string]any{
"config": map[string]any{
"enabled": true,
"name": "test",
},
}),
},
},
},
},
{
name: "tool_call_with_array",
input: `<|tool_call>call:process{items:[<|"|>a<|"|>,<|"|>b<|"|>]}<tool_call|>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process",
Arguments: testArgs(map[string]any{
"items": []any{"a", "b"},
}),
},
},
},
},
{
name: "tool_call_with_multiline_string_arg",
input: `<|tool_call>call:bash{command:<|"|>date
<|"|>}<tool_call|>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "bash",
Arguments: testArgs(map[string]any{
"command": "date\n",
}),
},
},
},
},
{
name: "multiple_tool_calls",
input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|><|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "London",
}),
},
},
},
},
{
name: "thinking_then_tool_call",
input: "<|channel>thought\nI need to check the weather<channel|><|tool_call>call:get_weather{location:<|\"|>Paris<|\"|>}<tool_call|>",
expectedThinking: "I need to check the weather",
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
thinkingEnabled: true,
},
{
name: "content_then_tool_call",
input: `Let me check that for you.<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`,
expectedContent: "Let me check that for you.",
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{
name: "thinking_disabled_channel_tags_as_content",
input: "<|channel>this is not thinking<channel|>actual content",
expectedContent: "actual content",
thinkingEnabled: false,
},
{
name: "prefill_content_only",
input: "Continuing content.",
expectedContent: "Continuing content.",
lastMessage: &api.Message{
Role: "assistant",
Content: "Previous content",
},
thinkingEnabled: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &Gemma4Parser{hasThinkingSupport: true}
parser.Init(nil, tt.lastMessage, &api.ThinkValue{Value: tt.thinkingEnabled})
content, thinking, toolCalls, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls, argsComparer); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestGemma4Parser_Streaming(t *testing.T) {
parser := &Gemma4Parser{hasThinkingSupport: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
chunks := []string{
"<|channel>thought",
"\nLet me think",
"...<channel|>The answer",
" is 42.",
}
var finalContent, finalThinking strings.Builder
for i, chunk := range chunks {
done := i == len(chunks)-1
content, thinking, _, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
}
if finalContent.String() != "The answer is 42." {
t.Errorf("expected content %q, got %q", "The answer is 42.", finalContent.String())
}
if finalThinking.String() != "Let me think..." {
t.Errorf("expected thinking %q, got %q", "Let me think...", finalThinking.String())
}
}
func TestGemma4Parser_StreamingToolCall(t *testing.T) {
parser := &Gemma4Parser{hasThinkingSupport: false}
parser.Init(nil, nil, nil)
chunks := []string{
`<|tool_call>call:get_`,
`weather{location:<|"|>Par`,
`is<|"|>}<tool_call|>`,
}
var finalContent strings.Builder
var finalToolCalls []api.ToolCall
for i, chunk := range chunks {
done := i == len(chunks)-1
content, _, toolCalls, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalToolCalls = append(finalToolCalls, toolCalls...)
}
if finalContent.String() != "" {
t.Errorf("expected no content, got %q", finalContent.String())
}
expectedToolCalls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
}
if diff := cmp.Diff(expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
}
func TestGemma4Parser_StreamingSplitThinkingTag(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedThinking string
}{
{
name: "split_channel_open_tag",
chunks: []string{
"<|chan",
"nel>thinking here<channel|>content",
},
expectedContent: "content",
expectedThinking: "thinking here",
},
{
name: "split_channel_close_tag",
chunks: []string{
"<|channel>thinking here<chan",
"nel|>content",
},
expectedContent: "content",
expectedThinking: "thinking here",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &Gemma4Parser{hasThinkingSupport: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
var finalContent, finalThinking strings.Builder
for i, chunk := range tt.chunks {
done := i == len(tt.chunks)-1
content, thinking, _, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
}
if finalContent.String() != tt.expectedContent {
t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String())
}
if finalThinking.String() != tt.expectedThinking {
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
}
})
}
}
func TestGemma4ArgsToJSON(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "simple_string",
input: `{location:<|"|>Paris<|"|>}`,
expected: `{"location":"Paris"}`,
},
{
name: "multiple_args",
input: `{location:<|"|>Paris<|"|>,units:<|"|>metric<|"|>}`,
expected: `{"location":"Paris","units":"metric"}`,
},
{
name: "number_value",
input: `{value:42}`,
expected: `{"value":42}`,
},
{
name: "boolean_value",
input: `{enabled:true}`,
expected: `{"enabled":true}`,
},
{
name: "nested_object",
input: `{config:{enabled:true,name:<|"|>test<|"|>}}`,
expected: `{"config":{"enabled":true,"name":"test"}}`,
},
{
name: "array_value",
input: `{items:[<|"|>a<|"|>,<|"|>b<|"|>]}`,
expected: `{"items":["a","b"]}`,
},
{
name: "empty_object",
input: `{}`,
expected: `{}`,
},
{
name: "mixed_types",
input: `{name:<|"|>test<|"|>,count:5,active:true,tags:[<|"|>a<|"|>]}`,
expected: `{"name":"test","count":5,"active":true,"tags":["a"]}`,
},
{
name: "null_value",
input: `{value:null}`,
expected: `{"value":null}`,
},
{
name: "multiline_string_value",
input: `{command:<|"|>date
<|"|>}`,
expected: `{"command":"date\n"}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := gemma4ArgsToJSON(tt.input)
if result != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, result)
}
})
}
}
func TestGemma4Parser_HasToolSupport(t *testing.T) {
parser := &Gemma4Parser{}
if !parser.HasToolSupport() {
t.Error("Gemma4Parser should support tools")
}
}
func TestGemma4Parser_HasThinkingSupport(t *testing.T) {
parser := &Gemma4Parser{hasThinkingSupport: true}
if !parser.HasThinkingSupport() {
t.Error("Gemma4Parser with thinking support should report it")
}
parser2 := &Gemma4Parser{hasThinkingSupport: false}
if parser2.HasThinkingSupport() {
t.Error("Gemma4Parser without thinking support should not report it")
}
}


@@ -77,6 +77,10 @@ func ParserForName(name string) Parser {
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
case "gemma4":
return &Gemma4Parser{hasThinkingSupport: true}
case "gemma4-no-thinking":
return &Gemma4Parser{hasThinkingSupport: false}
case "glm-ocr":
return &GlmOcrParser{}
case "lfm2":

model/renderers/gemma4.go

@@ -0,0 +1,379 @@
package renderers
import (
"fmt"
"sort"
"strings"
"github.com/ollama/ollama/api"
)
// Gemma4Renderer renders prompts using Gemma 4's chat format with
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
// <|tool_call>/<|tool_response> tags for function calling.
type Gemma4Renderer struct {
useImgTags bool
}
const (
g4Q = `<|"|>` // Gemma 4 string delimiter
)
func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
imageOffset := 0
// BOS token — Gemma 4 models have add_bos_token=false in their tokenizer
// config, so the tokenizer does not auto-prepend BOS. We must emit it
// explicitly in the rendered prompt, matching the HF chat template.
sb.WriteString("<bos>")
// Extract system message if present.
var systemMessage string
var loopMessages []api.Message
hasSystemRole := len(messages) > 0 && (messages[0].Role == "system" || messages[0].Role == "developer")
if hasSystemRole {
systemMessage = messages[0].Content
loopMessages = messages[1:]
} else {
loopMessages = messages
}
// Emit system turn if there's a system/developer role, tools, or thinking.
hasThink := thinkValue != nil && thinkValue.Bool()
if hasSystemRole || len(tools) > 0 || hasThink {
sb.WriteString("<|turn>system\n")
if hasThink {
sb.WriteString("<|think|>")
}
if systemMessage != "" {
sb.WriteString(strings.TrimSpace(systemMessage))
}
for _, tool := range tools {
sb.WriteString(r.renderToolDeclaration(tool))
}
sb.WriteString("<turn|>\n")
}
// Each message gets its own <|turn>role\n ... <turn|>\n block,
// matching the HF chat template exactly.
for _, message := range loopMessages {
switch message.Role {
case "user":
sb.WriteString("<|turn>user\n")
r.renderContent(&sb, message, &imageOffset, true)
sb.WriteString("<turn|>\n")
case "assistant":
sb.WriteString("<|turn>model\n")
// Tool calls come before content (matching HF template order)
for _, tc := range message.ToolCalls {
sb.WriteString(r.formatToolCall(tc))
}
// Strip thinking from history (matching HF strip_thinking macro)
if message.Content != "" {
sb.WriteString(stripThinking(message.Content))
}
sb.WriteString("<turn|>\n")
case "tool":
sb.WriteString("<|turn>tool\n")
sb.WriteString(strings.TrimSpace(message.Content))
sb.WriteString("<turn|>\n")
default:
sb.WriteString("<|turn>" + message.Role + "\n")
sb.WriteString(strings.TrimSpace(message.Content))
sb.WriteString("<turn|>\n")
}
}
// Generation prompt
sb.WriteString("<|turn>model\n")
return sb.String(), nil
}
// stripThinking removes <|channel>...<channel|> thinking blocks from content,
// matching the HF chat template's strip_thinking macro.
func stripThinking(text string) string {
var result strings.Builder
for {
start := strings.Index(text, "<|channel>")
if start == -1 {
result.WriteString(text)
break
}
result.WriteString(text[:start])
end := strings.Index(text[start:], "<channel|>")
if end == -1 {
break
}
text = text[start+end+len("<channel|>"):]
}
return strings.TrimSpace(result.String())
}
// renderContent writes a message's content, interleaving [img-N] tags for images.
// When trim is true, leading/trailing whitespace is stripped (matching the Jinja2
// template's | trim filter applied to non-model content).
func (r *Gemma4Renderer) renderContent(sb *strings.Builder, msg api.Message, imageOffset *int, trim bool) {
if len(msg.Images) > 0 && r.useImgTags {
for range msg.Images {
sb.WriteString(fmt.Sprintf("[img-%d]", *imageOffset))
*imageOffset++
}
}
content := msg.Content
if trim {
content = strings.TrimSpace(content)
}
sb.WriteString(content)
}
func (r *Gemma4Renderer) renderToolDeclaration(tool api.Tool) string {
var sb strings.Builder
fn := tool.Function
sb.WriteString("<|tool>declaration:" + fn.Name + "{")
sb.WriteString("description:" + g4Q + fn.Description + g4Q)
if fn.Parameters.Properties != nil || fn.Parameters.Type != "" {
sb.WriteString(",parameters:{")
needsComma := false
if fn.Parameters.Properties != nil && fn.Parameters.Properties.Len() > 0 {
sb.WriteString("properties:{")
r.writeProperties(&sb, fn.Parameters.Properties)
sb.WriteString("}")
needsComma = true
}
if len(fn.Parameters.Required) > 0 {
if needsComma {
sb.WriteString(",")
}
sb.WriteString("required:[")
for i, req := range fn.Parameters.Required {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString(g4Q + req + g4Q)
}
sb.WriteString("]")
needsComma = true
}
if fn.Parameters.Type != "" {
if needsComma {
sb.WriteString(",")
}
sb.WriteString("type:" + g4Q + strings.ToUpper(fn.Parameters.Type) + g4Q)
}
sb.WriteString("}")
}
sb.WriteString("}<tool|>")
return sb.String()
}
func (r *Gemma4Renderer) writeProperties(sb *strings.Builder, props *api.ToolPropertiesMap) {
keys := make([]string, 0, props.Len())
for k := range props.All() {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, name := range keys {
prop, _ := props.Get(name)
if !first {
sb.WriteString(",")
}
first = false
sb.WriteString(name + ":{")
hasContent := false
if prop.Description != "" {
sb.WriteString("description:" + g4Q + prop.Description + g4Q)
hasContent = true
}
if len(prop.Type) > 0 {
typeName := strings.ToUpper(prop.Type[0])
switch typeName {
case "STRING":
if len(prop.Enum) > 0 {
if hasContent {
sb.WriteString(",")
}
sb.WriteString("enum:[")
for j, e := range prop.Enum {
if j > 0 {
sb.WriteString(",")
}
sb.WriteString(g4Q + fmt.Sprintf("%v", e) + g4Q)
}
sb.WriteString("]")
hasContent = true
}
case "OBJECT":
// Render nested properties recursively.
// Note: the leading comma is hardcoded (matching the template),
// and this does NOT set hasContent — the comma before type:
// depends only on whether description was present.
sb.WriteString(",properties:{")
if prop.Properties != nil && prop.Properties.Len() > 0 {
r.writeProperties(sb, prop.Properties)
}
sb.WriteString("}")
if len(prop.Required) > 0 {
sb.WriteString(",required:[")
for j, req := range prop.Required {
if j > 0 {
sb.WriteString(",")
}
sb.WriteString(g4Q + req + g4Q)
}
sb.WriteString("]")
}
case "ARRAY":
// Render items specification.
// Same as OBJECT: leading comma is hardcoded, does NOT set hasContent.
if items, ok := prop.Items.(map[string]any); ok && len(items) > 0 {
sb.WriteString(",items:{")
r.writeItemsSpec(sb, items)
sb.WriteString("}")
}
}
if hasContent {
sb.WriteString(",")
}
sb.WriteString("type:" + g4Q + typeName + g4Q)
}
sb.WriteString("}")
}
}
// writeItemsSpec renders the items specification for array-type properties,
// matching the Jinja2 template's dictsort iteration over items.
func (r *Gemma4Renderer) writeItemsSpec(sb *strings.Builder, items map[string]any) {
keys := make([]string, 0, len(items))
for k := range items {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, key := range keys {
value := items[key]
if value == nil {
continue
}
if !first {
sb.WriteString(",")
}
first = false
switch key {
case "type":
if s, ok := value.(string); ok {
sb.WriteString("type:" + g4Q + strings.ToUpper(s) + g4Q)
}
default:
sb.WriteString(key + ":" + r.formatArgValue(value))
}
}
}
func (r *Gemma4Renderer) formatToolCall(tc api.ToolCall) string {
var sb strings.Builder
sb.WriteString("<|tool_call>call:" + tc.Function.Name + "{")
keys := make([]string, 0, tc.Function.Arguments.Len())
for k := range tc.Function.Arguments.All() {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, key := range keys {
value, _ := tc.Function.Arguments.Get(key)
if !first {
sb.WriteString(",")
}
first = false
sb.WriteString(key + ":" + r.formatArgValue(value))
}
sb.WriteString("}<tool_call|>")
return sb.String()
}
func (r *Gemma4Renderer) formatArgValue(value any) string {
switch v := value.(type) {
case string:
return g4Q + v + g4Q
case bool:
if v {
return "true"
}
return "false"
case float64:
if v == float64(int64(v)) {
return fmt.Sprintf("%d", int64(v))
}
return fmt.Sprintf("%v", v)
case int, int64, int32:
return fmt.Sprintf("%d", v)
case map[string]any:
return r.formatMapValue(v)
case []any:
return r.formatArrayValue(v)
default:
return fmt.Sprintf("%v", v)
}
}
func (r *Gemma4Renderer) formatMapValue(m map[string]any) string {
var sb strings.Builder
sb.WriteString("{")
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, key := range keys {
if !first {
sb.WriteString(",")
}
first = false
sb.WriteString(key + ":" + r.formatArgValue(m[key]))
}
sb.WriteString("}")
return sb.String()
}
func (r *Gemma4Renderer) formatArrayValue(arr []any) string {
var sb strings.Builder
sb.WriteString("[")
for i, item := range arr {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString(r.formatArgValue(item))
}
sb.WriteString("]")
return sb.String()
}
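The thinking-stripping behavior above is easiest to verify in isolation. Below is a minimal, self-contained sketch of the same scan-and-skip loop; the helper name `stripThinkingSketch` is illustrative only, not part of the PR:

```go
package main

import (
	"fmt"
	"strings"
)

// stripThinkingSketch mirrors the renderer's stripThinking: every
// <|channel>...<channel|> span is dropped, and an unterminated
// <|channel> discards the remainder of the text.
func stripThinkingSketch(text string) string {
	var out strings.Builder
	for {
		start := strings.Index(text, "<|channel>")
		if start == -1 {
			out.WriteString(text)
			break
		}
		out.WriteString(text[:start])
		rest := text[start:]
		end := strings.Index(rest, "<channel|>")
		if end == -1 {
			break // unterminated thinking block: drop the tail
		}
		text = rest[end+len("<channel|>"):]
	}
	return strings.TrimSpace(out.String())
}

func main() {
	fmt.Println(stripThinkingSketch("<|channel>plan the reply<channel|>The answer is 4."))
	// prints: The answer is 4.
}
```

This matches the HF template's strip_thinking macro, which splits on the closing tag and discards everything before each opening tag.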

File diff suppressed because it is too large


@@ -81,6 +81,8 @@ func rendererForName(name string) Renderer {
return renderer
case "nemotron-3-nano":
return &Nemotron3NanoRenderer{}
case "gemma4":
return &Gemma4Renderer{useImgTags: RenderImgTags}
case "functiongemma":
return &FunctionGemmaRenderer{}
case "glm-4.7":


@@ -0,0 +1,263 @@
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- set add_comma = false -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{ key }}:{
{%- if value['description'] -%}
description:<|"|>{{ value['description'] }}<|"|>
{%- set add_comma = true -%}
{%- endif -%}
{%- if value['nullable'] %}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
nullable:true
{%- endif -%}
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'OBJECT' -%}
,properties:{
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
{%- elif value is mapping -%}
{{- format_parameters(value, value['required'] | default([])) -}}
{%- endif -%}
}
{%- if value['required'] -%}
,required:[
{%- for item in value['required'] | default([]) -%}
<|"|>{{- item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
,items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'] | dictsort -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<|"|>{{- req_item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
type:<|"|>{{ value['type'] | upper }}<|"|>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
{%- macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<|"|>{{- item -}}<|"|>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<|"|>{{- params['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
{%- if 'response' in tool_data['function'] -%}
{%- set response_declaration = tool_data['function']['response'] -%}
,response:{
{%- if response_declaration['description'] -%}
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
{%- endif -%}
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
{%- macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<|"|>' + argument + '<|"|>' -}}
{%- elif argument is boolean -%}
{{- 'true' if argument else 'false' -}}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<|"|>' + key + '<|"|>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is sequence -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
{%- macro strip_thinking(text) -%}
{%- set ns = namespace(result='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.result = ns.result + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.result | trim -}}
{%- endmacro -%}
{%- set ns = namespace(prev_message_type=None) -%}
{%- set loop_messages = messages -%}
{{ bos_token }}
{#- Handle System/Tool Definitions Block -#}
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
{{- '<|turn>system\n' -}}
{#- Inject Thinking token at the very top of the FIRST system turn -#}
{%- if enable_thinking is defined and enable_thinking -%}
{{- '<|think|>' -}}
{%- set ns.prev_message_type = 'think' -%}
{%- endif -%}
{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{- messages[0]['content'] | trim -}}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- for tool in tools %}
{{- '<|tool>' -}}
{{- format_function_declaration(tool) | trim -}}
{{- '<tool|>' -}}
{%- endfor %}
{%- set ns.prev_message_type = 'tool' -%}
{%- endif -%}
{{- '<turn|>\n' -}}
{%- endif %}
{#- Loop through messages -#}
{%- for message in loop_messages -%}
{%- set ns.prev_message_type = None -%}
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
{{- '<|turn>' + role + '\n' }}
{%- if message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- set function = tool_call['function'] -%}
{{- '<|tool_call>call:' + function['name'] + '{' -}}
{%- if function['arguments'] is mapping -%}
{%- set ns_args = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns_args.found_first %},{% endif -%}
{%- set ns_args.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{{- function['arguments'] -}}
{%- endif -%}
{{- '}<tool_call|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}
{%- if message['tool_responses'] -%}
{#- Tool Response handling -#}
{%- for tool_response in message['tool_responses'] -%}
{{- '<|tool_response>' -}}
{%- if tool_response['response'] is mapping -%}
{{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}
{%- for key, value in tool_response['response'] | dictsort -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- '}' -}}
{%- else -%}
{{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}
{%- endif -%}
{{- '<tool_response|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}
{%- if message['content'] is string -%}
{%- if role == 'model' -%}
{{- strip_thinking(message['content']) -}}
{%- else -%}
{{- message['content'] | trim -}}
{%- endif -%}
{%- elif message['content'] is sequence -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'text' -%}
{%- if role == 'model' -%}
{{- strip_thinking(item['text']) -}}
{%- else -%}
{{- item['text'] | trim -}}
{%- endif -%}
{%- elif item['type'] == 'image' -%}
{{- '\n\n<|image|>\n\n' -}}
{%- set ns.prev_message_type = 'image' -%}
{%- elif item['type'] == 'audio' -%}
{{- '<|audio|>' -}}
{%- set ns.prev_message_type = 'audio' -%}
{%- elif item['type'] == 'video' -%}
{{- '\n\n<|video|>\n\n' -}}
{%- set ns.prev_message_type = 'video' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if not (message['tool_responses'] and not message['content']) -%}
{{- '<turn|>\n' -}}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' -%}
{{- '<|turn>model\n' -}}
{%- endif -%}
{%- endif -%}
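The template's format_argument macro and the Go renderer's formatArgValue must agree on how argument values are serialized: strings wrapped in <|"|> delimiters, whole-valued floats printed as integers, maps emitted with sorted keys, arrays comma-joined. A self-contained Go sketch of those shared rules (the helper name `formatArg` is illustrative, not part of the PR):

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

const q = `<|"|>` // Gemma 4 string delimiter

// formatArg sketches the value serialization used by both the Go renderer
// and the Jinja2 format_argument macro.
func formatArg(v any) string {
	switch x := v.(type) {
	case string:
		return q + x + q
	case bool:
		if x {
			return "true"
		}
		return "false"
	case float64:
		// JSON numbers arrive as float64; print whole values without a decimal.
		if x == float64(int64(x)) {
			return fmt.Sprintf("%d", int64(x))
		}
		return fmt.Sprintf("%v", x)
	case map[string]any:
		keys := make([]string, 0, len(x))
		for k := range x {
			keys = append(keys, k)
		}
		sort.Strings(keys) // deterministic order, like Jinja's dictsort
		parts := make([]string, 0, len(keys))
		for _, k := range keys {
			parts = append(parts, k+":"+formatArg(x[k]))
		}
		return "{" + strings.Join(parts, ",") + "}"
	case []any:
		parts := make([]string, len(x))
		for i, item := range x {
			parts[i] = formatArg(item)
		}
		return "[" + strings.Join(parts, ",") + "]"
	default:
		return fmt.Sprintf("%v", x)
	}
}

func main() {
	fmt.Println(formatArg(map[string]any{"city": "Paris", "days": 3.0}))
	// prints: {city:<|"|>Paris<|"|>,days:3}
}
```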


@@ -522,6 +522,20 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
case "input_audio":
audioMap, ok := data["input_audio"].(map[string]any)
if !ok {
return nil, errors.New("invalid input_audio format")
}
b64Data, ok := audioMap["data"].(string)
if !ok {
return nil, errors.New("invalid input_audio format: missing data")
}
audioBytes, err := base64.StdEncoding.DecodeString(b64Data)
if err != nil {
return nil, fmt.Errorf("invalid input_audio base64 data: %w", err)
}
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{audioBytes}})
default:
return nil, errors.New("invalid message format")
}
@@ -824,6 +838,45 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
}
}
// TranscriptionResponse is the response format for /v1/audio/transcriptions.
type TranscriptionResponse struct {
Text string `json:"text"`
}
// TranscriptionRequest holds parsed fields from the multipart form.
type TranscriptionRequest struct {
Model string
AudioData []byte
ResponseFormat string // "json", "text", "verbose_json"
Language string
Prompt string
}
// FromTranscriptionRequest converts a transcription request into a ChatRequest
// by wrapping the audio with a system prompt for transcription.
func FromTranscriptionRequest(r TranscriptionRequest) (*api.ChatRequest, error) {
systemPrompt := "Transcribe the following audio exactly as spoken. Output only the transcription text, nothing else."
if r.Language != "" {
systemPrompt += " The audio is in " + r.Language + "."
}
if r.Prompt != "" {
systemPrompt += " Context: " + r.Prompt
}
stream := true
return &api.ChatRequest{
Model: r.Model,
Messages: []api.Message{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: "Transcribe this audio.", Images: []api.ImageData{r.AudioData}},
},
Stream: &stream,
Options: map[string]any{
"temperature": 0,
},
}, nil
}
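The prompt assembly in FromTranscriptionRequest is just string concatenation with two optional suffixes; a standalone sketch (the helper name `buildTranscriptionPrompt` is illustrative, not part of the PR):

```go
package main

import "fmt"

// buildTranscriptionPrompt mirrors FromTranscriptionRequest's system prompt:
// a fixed instruction, with the language hint and caller-supplied context
// appended only when present.
func buildTranscriptionPrompt(language, prompt string) string {
	s := "Transcribe the following audio exactly as spoken. Output only the transcription text, nothing else."
	if language != "" {
		s += " The audio is in " + language + "."
	}
	if prompt != "" {
		s += " Context: " + prompt
	}
	return s
}

func main() {
	fmt.Println(buildTranscriptionPrompt("French", "a weather podcast"))
}
```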
// ImageEditRequest is an OpenAI-compatible image edit request.
type ImageEditRequest struct {
Model string `json:"model"`


@@ -390,3 +390,48 @@ func (t *Terminal) Read() (rune, error) {
}
return r, nil
}
// SetRawModeOn enables raw terminal mode and keeps it on.
// Call SetRawModeOff to restore when done.
func (i *Instance) SetRawModeOn() error {
if i.Terminal.rawmode {
return nil
}
fd := os.Stdin.Fd()
termios, err := SetRawMode(fd)
if err != nil {
return err
}
i.Terminal.rawmode = true
i.Terminal.termios = termios
return nil
}
// SetRawModeOff restores the terminal to its previous mode.
func (i *Instance) SetRawModeOff() {
if !i.Terminal.rawmode {
return
}
fd := os.Stdin.Fd()
//nolint:errcheck
UnsetRawMode(fd, i.Terminal.termios)
i.Terminal.rawmode = false
}
// ReadRaw reads a single rune. If the terminal is already in raw mode
// (via SetRawModeOn), it reads directly. Otherwise it temporarily enters
// raw mode for the read.
func (i *Instance) ReadRaw() (rune, error) {
if !i.Terminal.rawmode {
fd := os.Stdin.Fd()
termios, err := SetRawMode(fd)
if err != nil {
return 0, err
}
defer func() {
//nolint:errcheck
UnsetRawMode(fd, termios)
}()
}
return i.Terminal.Read()
}


@@ -1258,6 +1258,12 @@ func (s *Server) loadModel() {
panic(fmt.Errorf("failed to load model: %v", err))
}
if postLoader, ok := s.model.(model.PostLoader); ok {
if err := postLoader.PostLoad(); err != nil {
panic(fmt.Errorf("failed to finalize model initialization: %v", err))
}
}
s.status = llm.ServerStatusReady
s.ready.Done()
}


@@ -141,7 +141,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- gin.H{"error": err.Error()}
}
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "" || len(config.Capabilities) == 0) {
mf, mErr := manifest.ParseNamedManifest(fromName)
if mErr == nil && mf.Config.Digest != "" {
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
@@ -158,6 +158,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
if config.Requires == "" {
config.Requires = baseConfig.Requires
}
if len(config.Capabilities) == 0 {
config.Capabilities = baseConfig.Capabilities
}
}
cfgFile.Close()
}
@@ -509,6 +512,24 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount()))
config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String())
config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture())
// Auto-detect renderer, parser, and stop tokens from GGUF architecture.
// TODO: abstract this into a registry/lookup table when multiple models
// need architecture-based renderer/parser/stop defaults.
if config.Renderer == "" || config.Parser == "" {
arch := layer.GGML.KV().Architecture()
switch arch {
case "gemma4":
config.Renderer = cmp.Or(config.Renderer, "gemma4")
config.Parser = cmp.Or(config.Parser, "gemma4")
if _, ok := r.Parameters["stop"]; !ok {
if r.Parameters == nil {
r.Parameters = make(map[string]any)
}
r.Parameters["stop"] = []string{"<turn|>"}
}
}
}
}
layers = append(layers, layer.Layer)
}


@@ -39,6 +39,7 @@ var (
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
errCapabilityVision = errors.New("vision")
errCapabilityAudio = errors.New("audio")
errCapabilityEmbedding = errors.New("embedding")
errCapabilityThinking = errors.New("thinking")
errCapabilityImage = errors.New("image generation")
@@ -93,14 +94,26 @@ func (m *Model) Capabilities() []model.Capability {
if f.KeyValue("vision.block_count").Valid() {
capabilities = append(capabilities, model.CapabilityVision)
}
if f.KeyValue("audio.block_count").Valid() {
capabilities = append(capabilities, model.CapabilityAudio)
}
} else {
slog.Error("couldn't open model file", "error", err)
}
} else if len(m.Config.Capabilities) > 0 {
}
// Also include capabilities from the model config (e.g. vision capability
// set during creation for MLX/safetensors models).
if len(m.Config.Capabilities) > 0 {
for _, c := range m.Config.Capabilities {
capabilities = append(capabilities, model.Capability(c))
capability := model.Capability(c)
if !slices.Contains(capabilities, capability) {
capabilities = append(capabilities, capability)
}
}
} else {
}
if len(capabilities) == 0 {
slog.Warn("unknown capabilities for model", "model", m.Name)
}
@@ -141,6 +154,14 @@ func (m *Model) Capabilities() []model.Capability {
capabilities = append(capabilities, model.CapabilityThinking)
}
// Temporary workaround — suppress vision/audio for gemma4 MLX models
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" {
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
return c == model.CapabilityVision || c == model.CapabilityAudio
})
}
return capabilities
}
@@ -156,6 +177,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
model.CapabilityTools: errCapabilityTools,
model.CapabilityInsert: errCapabilityInsert,
model.CapabilityVision: errCapabilityVision,
model.CapabilityAudio: errCapabilityAudio,
model.CapabilityEmbedding: errCapabilityEmbedding,
model.CapabilityThinking: errCapabilityThinking,
model.CapabilityImage: errCapabilityImage,


@@ -153,7 +153,16 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
// MLA tensors need higher precision to avoid quality degradation
newType = fsggml.TensorTypeQ8_0
} else if strings.Contains(name, "ffn_down") {
iLayer := qs.iFfnDown
// For MoE models, ffn_down.weight (dense) and ffn_down_exps.weight (expert) both
// exist per layer and should get the same useMoreBits treatment. Dense sorts before
// expert alphabetically, so dense increments the counter and expert uses counter-1.
var iLayer int
if strings.Contains(name, "_exps") {
iLayer = max(0, qs.iFfnDown-1)
} else {
iLayer = qs.iFfnDown
qs.iFfnDown++
}
n_layer := qs.nFfnDown
if ftype == fsggml.FileTypeQ4_K_M {
if useMoreBits(iLayer, n_layer) {
@@ -162,7 +171,6 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
} else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
newType = fsggml.TensorTypeQ5_K
}
qs.iFfnDown++
} else if strings.Contains(name, "attn_output.weight") {
if nExperts == 8 {
if ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
@@ -255,8 +263,9 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
name := t.Name
quantize := strings.HasSuffix(name, "weight")
// don't quantize vision encoder tensors (named with "v." prefix)
// don't quantize vision or audio encoder tensors
quantize = quantize && !strings.HasPrefix(name, "v.")
quantize = quantize && !strings.HasPrefix(name, "a.")
quantize = quantize && !strings.Contains(name, "mm.")
// quantize only 2D and 3D tensors (experts)
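The dense/expert counter scheme above is subtle: because ffn_down.weight sorts before ffn_down_exps.weight, the dense tensor consumes and advances the shared per-layer counter, and the expert tensor that follows reuses that same layer index. A standalone sketch (the helper name `layerIndexForFfnDown` is illustrative, not part of the PR):

```go
package main

import (
	"fmt"
	"strings"
)

// layerIndexForFfnDown mirrors the quantizer's counter scheme: dense
// ffn_down tensors take the current counter value and increment it,
// while expert (_exps) tensors reuse the index of the preceding dense
// tensor (counter-1, clamped at 0).
func layerIndexForFfnDown(counter *int, name string) int {
	if strings.Contains(name, "_exps") {
		return max(0, *counter-1)
	}
	i := *counter
	*counter++
	return i
}

func main() {
	counter := 0
	for _, n := range []string{
		"blk.0.ffn_down.weight", "blk.0.ffn_down_exps.weight",
		"blk.1.ffn_down.weight", "blk.1.ffn_down_exps.weight",
	} {
		fmt.Println(n, layerIndexForFfnDown(&counter, n))
	}
}
```

Both tensors in a layer therefore get the same useMoreBits decision, which is the stated goal of the change.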


@@ -1718,6 +1718,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// OpenAI-compatible image generation endpoints
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
// OpenAI-compatible audio endpoint
r.POST("/v1/audio/transcriptions", middleware.TranscriptionMiddleware(), s.ChatHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)


@@ -10,6 +10,7 @@ const (
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImage = Capability("image")
CapabilityAudio = Capability("audio")
)
func (c Capability) String() string {