mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
Add support for gemma4 (#15214)
* bench: add prompt calibration, context size flag, and NumCtx reporting Add --num-ctx flag to set context size, and report NumCtx in model info header. Calibrate tokens-per-word ratio during warmup using actual tokenization metrics from the model, replacing the fixed 1.3 heuristic. This produces more accurate prompt token counts for --prompt-tokens. Also add fetchContextLength() to query running model context via /api/ps. * integration: improve vision test robustness and add thinking tests Add skipIfNoVisionOverride() to skip vision tests when OLLAMA_TEST_MODEL is set to a non-vision model. Add Think:false to context exhaustion test to prevent thinking models from using all context before the test can measure it. Add third test image (ollama homepage) and replace OCR test with ImageDescription test using it. Relax match strings for broader model compatibility. Add TestThinkingEnabled and TestThinkingSuppressed to verify thinking output and channel tag handling. * gemma4: add Gemma 4 GGML model support Add full Gemma 4 model family support (E2B, E4B, 26B MoE, 31B Dense) for the GGML backend including text, vision, converter, parser, and renderer. Text model features: - Sliding window + full attention with per-layer patterns - KV sharing across layers with donor map - Per-layer embeddings (PLE) with learned projections - MoE routing with RMSNorm + learned scale - Proportional RoPE with freq_factors for global attention - Final logit softcapping Vision model features: - SigLIP vision encoder with 2D RoPE - ClippableLinear with input/output clamping via packed v.clamp_data - Adaptive average pooling with nMerge kernel - Multi-modal projection with unweighted RMSNorm Converter: - Safetensors to GGUF with vision tensor renaming - Fused MoE gate_up_proj splitting - Vision patch embedding reshape (HF to Conv2D layout) - Packed clamp data tensor for ClippableLinear bounds - Proportional RoPE freq_factors generation Also includes: - BackendGet() on ml.Tensor for reading weight tensor data - Q6_K CUDA get_rows kernel support - MoE-aware ffn_down quantization layer counting - Gemma4 parser with tool calling and thinking support - Gemma4 renderer with structured tool format - Architecture-based auto-detection of renderer/parser/stop tokens - Integration test gemma4 model list additions * gemma4: add audio support with USM conformer encoder Add audio encoding for Gemma 4 using the USM conformer architecture: - Converter: audio tensor mapping, SSCP/conformer/embedder name replacements, softplus repacker for per_dim_scale, F32 enforcement for conv weights - GGML backend: Conv1DDW and PadExt tensor ops - Audio encoder: SSCP Conv2D, 12 conformer blocks (FFW + block-local attention with relative position embeddings + LightConv1d + FFW), output projection, audio-to-text embedding projector - Audio preprocessing: WAV decode, mel spectrogram, FFT (pure Go) - Model wiring: WAV detection, audio token handling, unified PostTokenize Correctly transcribes "why is the sky blue" from test audio. * integration: add gemma4 audio tests including OpenAI API coverage Test audio transcription and response via the Ollama native API, plus two new tests exercising the OpenAI-compatible endpoints: - /v1/audio/transcriptions (multipart form upload) - /v1/chat/completions with input_audio content type All tests use capability checks and skip models without audio support. * gemma4: add OpenAI audio API support and capability detection - Add CapabilityAudio and detect from audio.block_count in GGUF - Add /v1/audio/transcriptions endpoint with TranscriptionMiddleware - Add input_audio content type support in /v1/chat/completions - Add TranscriptionRequest/Response types in openai package * gemma4: add audio input support for run command - /audio toggle in interactive mode for voice chat - Platform-specific microphone recording (AVFoundation on macOS, PulseAudio/ALSA on Linux, WASAPI on Windows) - Space to start/stop recording, automatic chunking for long audio * gemma4: add transcribe command (ollama transcribe MODEL) - Interactive mode with readline prompt and slash commands - Non-interactive mode for piped audio or record-until-Ctrl+C - Chunked streaming transcription for long recordings - Word-wrapped output matching run command style * gemma4: add parser, renderer, and integration test plumbing * gemma4: fix renderer to emit BOS token * gemma4: add OpenAI audio transcription API and input_audio support * gemma4: update converter for new weight drop naming * gemma4: add per_expert_scale to MoE router and fix moe_intermediate_size config * gemma4: rewrite renderer to match HF Jinja2 template exactly Fix 8 bugs found by building 55 reference tests verified against the HF Jinja2 chat template (VERIFY_JINJA2=1 shells out to Python): - Tool responses use separate <|turn>tool turns (not inline tags) - Tool calls emitted before content in assistant messages - Thinking content stripped from assistant history (strip_thinking) - User, tool, and system content trimmed (template does | trim) - Empty system message still emits system turn (check role, not content) - Nested object properties rendered recursively with required field - Array items specification rendered for array-type properties - OBJECT/ARRAY type-specific rendering comma logic matches template Also adds Required field to api.ToolProperty for nested object schemas, replaces old gemma4_test.go with comprehensive gemma4_reference_test.go, and commits the Jinja2 template as testdata for verification. * gemma4: fix MoE fused gate_up split and multiline tool-call arg parsing - Text MoE: split `ffn_gate_up_exps` into contiguous `[gate|up]` halves instead of stride-2 slices. - Parser: escape control characters in `<|"|>...<|"|>` string literals when converting tool-call args to JSON. - Fixes warnings like `invalid character '\n' in string literal` for multiline tool arguments. - Add Gemma4 parser regressions for multiline tool-call args and `gemma4ArgsToJSON`. * cmd: simplify audio input to dropped file attachments * gemma4: use full SWA memory for better cache reuse * gemma4: initialize clamps after backend load * convert: align gemma4 audio tensor renames with llama.cpp * Remove redundant comments in gemma4 vision model * Format Gemma4 MoE block field alignment * use 4096 kvcache.NewSWAMemCache * convert: support new Gemma4 audio_tower tensor naming (#15221) Co-authored-by: jmorganca <jmorganca@gmail.com> * fix integration test defaults for audio * review comments and lint fixes * remove unused audio/video files --------- Co-authored-by: jmorganca <jmorganca@gmail.com>
This commit is contained in:
@@ -436,6 +436,7 @@ type ToolProperty struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Properties *ToolPropertiesMap `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||
|
||||
@@ -32,6 +32,7 @@ type flagOptions struct {
|
||||
verbose *bool
|
||||
warmup *int
|
||||
promptTokens *int
|
||||
numCtx *int
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
@@ -48,6 +49,7 @@ type ModelInfo struct {
|
||||
Family string
|
||||
SizeBytes int64
|
||||
VRAMBytes int64
|
||||
NumCtx int64
|
||||
}
|
||||
|
||||
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||
@@ -64,9 +66,12 @@ var promptWordList = []string{
|
||||
"old", "stone", "bridge", "that", "crosses", "winding", "river",
|
||||
}
|
||||
|
||||
// tokensPerWord is the calibrated ratio of tokens to words for the current model.
|
||||
// Initialized with a heuristic, then updated during warmup based on actual tokenization.
|
||||
var tokensPerWord = 1.3
|
||||
|
||||
func generatePromptForTokenCount(targetTokens int, epoch int) string {
|
||||
// ~1.3 tokens per word heuristic
|
||||
targetWords := int(float64(targetTokens) / 1.3)
|
||||
targetWords := int(float64(targetTokens) / tokensPerWord)
|
||||
if targetWords < 1 {
|
||||
targetWords = 1
|
||||
}
|
||||
@@ -81,6 +86,17 @@ func generatePromptForTokenCount(targetTokens int, epoch int) string {
|
||||
return strings.Join(words, " ")
|
||||
}
|
||||
|
||||
// calibratePromptTokens adjusts tokensPerWord based on actual tokenization from a warmup run.
|
||||
func calibratePromptTokens(targetTokens, actualTokens, wordCount int) {
|
||||
if actualTokens <= 0 || wordCount <= 0 {
|
||||
return
|
||||
}
|
||||
tokensPerWord = float64(actualTokens) / float64(wordCount)
|
||||
newWords := int(float64(targetTokens) / tokensPerWord)
|
||||
fmt.Fprintf(os.Stderr, "bench: calibrated %.2f tokens/word (target=%d, got=%d, words=%d → %d)\n",
|
||||
tokensPerWord, targetTokens, actualTokens, wordCount, newWords)
|
||||
}
|
||||
|
||||
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
|
||||
options := make(map[string]interface{})
|
||||
if *fOpt.maxTokens > 0 {
|
||||
@@ -90,6 +106,9 @@ func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData,
|
||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||
options["seed"] = *fOpt.seed
|
||||
}
|
||||
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
|
||||
options["num_ctx"] = *fOpt.numCtx
|
||||
}
|
||||
|
||||
var keepAliveDuration *api.Duration
|
||||
if *fOpt.keepAlive > 0 {
|
||||
@@ -146,7 +165,6 @@ func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (si
|
||||
return m.Size, m.SizeVRAM
|
||||
}
|
||||
}
|
||||
// Try prefix match (model names may include :latest or tags)
|
||||
for _, m := range resp.Models {
|
||||
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||
return m.Size, m.SizeVRAM
|
||||
@@ -155,6 +173,19 @@ func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (si
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func fetchContextLength(ctx context.Context, client *api.Client, model string) int64 {
|
||||
resp, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == model || m.Model == model || strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||
return int64(m.ContextLength)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func outputFormatHeader(w io.Writer, format string, verbose bool) {
|
||||
switch format {
|
||||
case "benchstat":
|
||||
@@ -177,8 +208,12 @@ func outputModelInfo(w io.Writer, format string, info ModelInfo) {
|
||||
if info.SizeBytes > 0 {
|
||||
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
|
||||
}
|
||||
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s\n",
|
||||
info.Name, params, quant, family, memStr)
|
||||
ctxStr := ""
|
||||
if info.NumCtx > 0 {
|
||||
ctxStr = fmt.Sprintf(" | NumCtx: %d", info.NumCtx)
|
||||
}
|
||||
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s%s\n",
|
||||
info.Name, params, quant, family, memStr, ctxStr)
|
||||
}
|
||||
|
||||
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||
@@ -276,21 +311,38 @@ func BenchmarkModel(fOpt flagOptions) error {
|
||||
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||
|
||||
var warmupMetrics *api.Metrics
|
||||
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if resp.Done {
|
||||
warmupMetrics = &resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
|
||||
} else if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
|
||||
} else {
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
|
||||
}
|
||||
// Calibrate prompt token count on last warmup run
|
||||
if i == *fOpt.warmup-1 && *fOpt.promptTokens > 0 && warmupMetrics != nil {
|
||||
prompt := generatePromptForTokenCount(*fOpt.promptTokens, -(i + 1))
|
||||
wordCount := len(strings.Fields(prompt))
|
||||
calibratePromptTokens(*fOpt.promptTokens, warmupMetrics.PromptEvalCount, wordCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch memory usage once after warmup (model is loaded and stable)
|
||||
// Fetch memory/context info once after warmup (model is loaded and stable)
|
||||
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
|
||||
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
|
||||
info.NumCtx = int64(*fOpt.numCtx)
|
||||
} else {
|
||||
info.NumCtx = fetchContextLength(memCtx, client, model)
|
||||
}
|
||||
memCancel()
|
||||
|
||||
outputModelInfo(out, *fOpt.format, info)
|
||||
@@ -479,6 +531,7 @@ func main() {
|
||||
debug: flag.Bool("debug", false, "Show debug information"),
|
||||
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
|
||||
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
|
||||
numCtx: flag.Int("num-ctx", 0, "Context size (0 = server default)"),
|
||||
}
|
||||
|
||||
flag.Usage = func() {
|
||||
|
||||
@@ -695,7 +695,8 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
||||
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
// these are left in for backwards compatibility with older servers
|
||||
@@ -1494,6 +1495,9 @@ type displayResponseState struct {
|
||||
|
||||
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
|
||||
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if termWidth == 0 {
|
||||
termWidth = 80
|
||||
}
|
||||
if wordWrap && termWidth >= 10 {
|
||||
for _, ch := range content {
|
||||
if state.lineLength+1 > termWidth-5 {
|
||||
|
||||
@@ -47,7 +47,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
if opts.MultiModal {
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
@@ -592,7 +592,7 @@ func extractFileNames(input string) []string {
|
||||
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
|
||||
// and followed by more characters and a file extension
|
||||
// This will capture non filename strings, but we'll check for file existence to remove mismatches
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp|wav)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
|
||||
return re.FindAllString(input, -1)
|
||||
@@ -608,10 +608,16 @@ func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Couldn't process file: %q\n", err)
|
||||
return "", imgs, err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
ext := strings.ToLower(filepath.Ext(nfp))
|
||||
switch ext {
|
||||
case ".wav":
|
||||
fmt.Fprintf(os.Stderr, "Added audio '%s'\n", nfp)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
}
|
||||
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
|
||||
input = strings.ReplaceAll(input, "'"+fp+"'", "")
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
@@ -685,9 +691,9 @@ func getImageData(filePath string) ([]byte, error) {
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp", "audio/wave"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
return nil, fmt.Errorf("invalid file type: %s", contentType)
|
||||
}
|
||||
|
||||
info, err := file.Stat()
|
||||
@@ -695,8 +701,7 @@ func getImageData(filePath string) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the file size exceeds 100MB
|
||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
|
||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB
|
||||
if info.Size() > maxSize {
|
||||
return nil, errors.New("file size exceeds maximum limit (100MB)")
|
||||
}
|
||||
|
||||
@@ -84,3 +84,33 @@ func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, cleaned, "before after")
|
||||
}
|
||||
|
||||
func TestExtractFileDataWAV(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fp := filepath.Join(dir, "sample.wav")
|
||||
data := make([]byte, 600)
|
||||
copy(data[:44], []byte{
|
||||
'R', 'I', 'F', 'F',
|
||||
0x58, 0x02, 0x00, 0x00, // file size - 8
|
||||
'W', 'A', 'V', 'E',
|
||||
'f', 'm', 't', ' ',
|
||||
0x10, 0x00, 0x00, 0x00, // fmt chunk size
|
||||
0x01, 0x00, // PCM
|
||||
0x01, 0x00, // mono
|
||||
0x80, 0x3e, 0x00, 0x00, // 16000 Hz
|
||||
0x00, 0x7d, 0x00, 0x00, // byte rate
|
||||
0x02, 0x00, // block align
|
||||
0x10, 0x00, // 16-bit
|
||||
'd', 'a', 't', 'a',
|
||||
0x34, 0x02, 0x00, 0x00, // data size
|
||||
})
|
||||
if err := os.WriteFile(fp, data, 0o600); err != nil {
|
||||
t.Fatalf("failed to write test audio: %v", err)
|
||||
}
|
||||
|
||||
input := "before " + fp + " after"
|
||||
cleaned, imgs, err := extractFileData(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, "before after", cleaned)
|
||||
}
|
||||
|
||||
@@ -290,6 +290,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||
case "Gemma3nForConditionalGeneration":
|
||||
conv = &gemma3nModel{}
|
||||
case "Gemma4ForCausalLM", "Gemma4ForConditionalGeneration":
|
||||
conv = &gemma4Model{Architecture: p.Architectures[0]}
|
||||
case "Phi3ForCausalLM":
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
|
||||
574
convert/convert_gemma4.go
Normal file
574
convert/convert_gemma4.go
Normal file
@@ -0,0 +1,574 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type gemma4Model struct {
|
||||
gemmaModel
|
||||
Architecture string
|
||||
TextModel struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
GlobalHeadDim uint32 `json:"global_head_dim"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
SlidingWindowPattern *int32 `json:"_sliding_window_pattern"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
FinalLogitSoftcapping float32 `json:"final_logit_softcapping"`
|
||||
EnableMoeBlock bool `json:"enable_moe_block"`
|
||||
NumExperts *uint32 `json:"num_experts"`
|
||||
TopKExperts *uint32 `json:"top_k_experts"`
|
||||
ExpertIntermediateSize *uint32 `json:"moe_intermediate_size"`
|
||||
HiddenSizePerLayerInput *uint32 `json:"hidden_size_per_layer_input"`
|
||||
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
|
||||
AttentionKEqV bool `json:"attention_k_eq_v"`
|
||||
NumGlobalKeyValueHeads *uint32 `json:"num_global_key_value_heads"`
|
||||
QueryPreAttnScalar *uint32 `json:"query_pre_attn_scalar"`
|
||||
UseDoubleWideMLP bool `json:"use_double_wide_mlp"`
|
||||
RopeParameters map[string]*struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor *float32 `json:"partial_rotary_factor"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
|
||||
VisionModel struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
NumChannels uint32 `json:"num_channels"`
|
||||
PoolingKernelSize uint32 `json:"pooling_kernel_size"`
|
||||
LayerNormEps float32 `json:"layer_norm_eps"`
|
||||
} `json:"vision_config"`
|
||||
|
||||
AudioModel *struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
OutputProjDims uint32 `json:"output_proj_dims"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
ConvKernelSize uint32 `json:"conv_kernel_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
} `json:"audio_config"`
|
||||
}
|
||||
|
||||
func (p *gemma4Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma4"
|
||||
kv["tokenizer.ggml.model"] = "llama"
|
||||
kv["tokenizer.ggml.pre"] = "gemma4"
|
||||
|
||||
tc := p.TextModel
|
||||
|
||||
kv["gemma4.block_count"] = tc.NumHiddenLayers
|
||||
kv["gemma4.embedding_length"] = tc.HiddenSize
|
||||
|
||||
// Per-layer FFN width: when use_double_wide_mlp is set, KV-shared layers get 2x FFN width.
|
||||
if tc.UseDoubleWideMLP && tc.NumKVSharedLayers > 0 {
|
||||
firstShared := int(tc.NumHiddenLayers) - int(tc.NumKVSharedLayers)
|
||||
ffnWidths := make([]int32, tc.NumHiddenLayers)
|
||||
for i := range ffnWidths {
|
||||
if i >= firstShared {
|
||||
ffnWidths[i] = int32(tc.IntermediateSize * 2)
|
||||
} else {
|
||||
ffnWidths[i] = int32(tc.IntermediateSize)
|
||||
}
|
||||
}
|
||||
kv["gemma4.feed_forward_length"] = ffnWidths
|
||||
} else {
|
||||
kv["gemma4.feed_forward_length"] = tc.IntermediateSize
|
||||
}
|
||||
kv["gemma4.context_length"] = tc.MaxPositionEmbeddings
|
||||
kv["gemma4.attention.head_count"] = tc.NumAttentionHeads
|
||||
// Per-layer KV head count array: SWA layers use NumKeyValueHeads, global layers use NumGlobalKeyValueHeads
|
||||
if tc.NumGlobalKeyValueHeads != nil && *tc.NumGlobalKeyValueHeads != tc.NumKeyValueHeads && len(tc.LayerTypes) > 0 {
|
||||
kvHeads := make([]int32, len(tc.LayerTypes))
|
||||
for i, lt := range tc.LayerTypes {
|
||||
if lt == "sliding_attention" {
|
||||
kvHeads[i] = int32(tc.NumKeyValueHeads)
|
||||
} else {
|
||||
kvHeads[i] = int32(*tc.NumGlobalKeyValueHeads)
|
||||
}
|
||||
}
|
||||
kv["gemma4.attention.head_count_kv"] = kvHeads
|
||||
} else {
|
||||
kv["gemma4.attention.head_count_kv"] = tc.NumKeyValueHeads
|
||||
}
|
||||
// key_length = global head dim, key_length_swa = local (SWA) head dim
|
||||
kv["gemma4.attention.key_length"] = tc.GlobalHeadDim
|
||||
kv["gemma4.attention.value_length"] = tc.GlobalHeadDim
|
||||
kv["gemma4.attention.key_length_swa"] = tc.HeadDim
|
||||
kv["gemma4.attention.value_length_swa"] = tc.HeadDim
|
||||
kv["gemma4.attention.layer_norm_rms_epsilon"] = tc.RMSNormEps
|
||||
kv["gemma4.attention.sliding_window"] = tc.SlidingWindow
|
||||
|
||||
// Sliding window pattern from layer_types
|
||||
if len(tc.LayerTypes) > 0 {
|
||||
kv["gemma4.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
|
||||
for _, lt := range tc.LayerTypes {
|
||||
if !yield(lt == "sliding_attention") {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
kv["gemma4.attention.shared_kv_layers"] = tc.NumKVSharedLayers
|
||||
|
||||
// RoPE: dimension_count is the full global head dim (freq_factors handle partial rotation)
|
||||
if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil {
|
||||
kv["gemma4.rope.freq_base"] = rp.RopeTheta
|
||||
kv["gemma4.rope.dimension_count"] = tc.GlobalHeadDim
|
||||
}
|
||||
if rp, ok := tc.RopeParameters["sliding_attention"]; ok && rp != nil {
|
||||
kv["gemma4.rope.freq_base_swa"] = rp.RopeTheta
|
||||
kv["gemma4.rope.dimension_count_swa"] = tc.HeadDim
|
||||
}
|
||||
|
||||
if tc.FinalLogitSoftcapping > 0 {
|
||||
kv["gemma4.final_logit_softcapping"] = tc.FinalLogitSoftcapping
|
||||
}
|
||||
|
||||
// MoE
|
||||
if tc.EnableMoeBlock && tc.NumExperts != nil {
|
||||
kv["gemma4.expert_count"] = *tc.NumExperts
|
||||
if tc.TopKExperts != nil {
|
||||
kv["gemma4.expert_used_count"] = *tc.TopKExperts
|
||||
}
|
||||
if tc.ExpertIntermediateSize != nil {
|
||||
kv["gemma4.expert_feed_forward_length"] = *tc.ExpertIntermediateSize
|
||||
}
|
||||
}
|
||||
|
||||
// PLE — always emit, even when 0
|
||||
pleSize := uint32(0)
|
||||
if tc.HiddenSizePerLayerInput != nil {
|
||||
pleSize = *tc.HiddenSizePerLayerInput
|
||||
}
|
||||
kv["gemma4.embedding_length_per_layer_input"] = pleSize
|
||||
|
||||
// Vision model KV metadata
|
||||
vc := p.VisionModel
|
||||
if vc.NumHiddenLayers > 0 {
|
||||
kv["gemma4.vision.block_count"] = vc.NumHiddenLayers
|
||||
kv["gemma4.vision.embedding_length"] = vc.HiddenSize
|
||||
kv["gemma4.vision.attention.head_count"] = vc.NumAttentionHeads
|
||||
kv["gemma4.vision.feed_forward_length"] = vc.IntermediateSize
|
||||
kv["gemma4.vision.patch_size"] = vc.PatchSize
|
||||
numCh := vc.NumChannels
|
||||
if numCh == 0 {
|
||||
numCh = 3
|
||||
}
|
||||
kv["gemma4.vision.num_channels"] = numCh
|
||||
nMerge := vc.PoolingKernelSize
|
||||
if nMerge == 0 {
|
||||
nMerge = 3
|
||||
}
|
||||
kv["gemma4.vision.projector.scale_factor"] = nMerge
|
||||
eps := vc.LayerNormEps
|
||||
if eps == 0 {
|
||||
eps = 1e-6
|
||||
}
|
||||
kv["gemma4.vision.attention.layer_norm_epsilon"] = eps
|
||||
}
|
||||
|
||||
// Audio model KV metadata
|
||||
if p.AudioModel != nil && p.AudioModel.NumHiddenLayers > 0 {
|
||||
ac := p.AudioModel
|
||||
kv["gemma4.audio.block_count"] = ac.NumHiddenLayers
|
||||
kv["gemma4.audio.embedding_length"] = ac.HiddenSize
|
||||
kv["gemma4.audio.feed_forward_length"] = ac.HiddenSize * 4
|
||||
kv["gemma4.audio.attention.head_count"] = ac.NumAttentionHeads
|
||||
eps := ac.RMSNormEps
|
||||
if eps == 0 {
|
||||
eps = 1e-6
|
||||
}
|
||||
kv["gemma4.audio.attention.layer_norm_epsilon"] = eps
|
||||
if ac.ConvKernelSize > 0 {
|
||||
kv["gemma4.audio.conv_kernel_size"] = ac.ConvKernelSize
|
||||
}
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *gemma4Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
// First pass: collect vision clamp scalar values into a packed tensor.
|
||||
// Layout: per vision layer (0..N-1), 7 linears (q,k,v,out,gate,up,down) × 4 values (inMin,inMax,outMin,outMax).
|
||||
// Then 4 values for the projector (mm.input_projection).
|
||||
clampSuffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
|
||||
clampMap := make(map[string]float32)
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
for _, sfx := range clampSuffixes {
|
||||
if strings.HasSuffix(name, sfx) && (strings.Contains(name, "vision_tower") || strings.Contains(name, "embed_vision")) {
|
||||
var buf bytes.Buffer
|
||||
t.WriteTo(&buf)
|
||||
data := buf.Bytes()
|
||||
if len(data) >= 4 {
|
||||
clampMap[name] = math.Float32frombits(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16 | uint32(data[3])<<24)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
|
||||
// Skip embedding_post_projection_norm — used as weightless RMS norm in inference
|
||||
if strings.Contains(name, "embedding_post_projection_norm") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Vision tensor renaming: match published mmproj GGUF names
|
||||
if strings.HasPrefix(name, "v.blk.") {
|
||||
name = strings.Replace(name, ".attn_norm.", ".ln1.", 1)
|
||||
name = strings.Replace(name, ".ffn_norm.", ".ln2.", 1)
|
||||
name = strings.Replace(name, ".attn_output.", ".attn_out.", 1)
|
||||
name = strings.Replace(name, ".post_attention_norm.", ".attn_post_norm.", 1)
|
||||
name = strings.Replace(name, ".post_ffw_norm.", ".ffn_post_norm.", 1)
|
||||
name = strings.Replace(name, ".layer_output_scale.", ".out_scale.", 1)
|
||||
}
|
||||
|
||||
// per_dim_scale: apply softplus to weight data and add .weight suffix.
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.HasSuffix(name, "per_dim_scale") {
|
||||
name = name + ".weight"
|
||||
t.SetRepacker(softplusRepacker)
|
||||
}
|
||||
|
||||
// Depthwise conv1d: squeeze middle dimension [C, 1, K] → [C, K].
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") {
|
||||
t.SetRepacker(squeezeMiddleDim)
|
||||
}
|
||||
|
||||
shape := t.Shape()
|
||||
|
||||
// Convert scalar tensors (input_min/max, output_min/max) to 1D
|
||||
if len(shape) == 0 {
|
||||
shape = []uint64{1}
|
||||
}
|
||||
|
||||
// Depthwise conv1d shape: safetensors [C, 1, K] → GGUF ne[K, C].
|
||||
// Shape array here maps to GGUF ne[] directly, but safetensors reader
|
||||
// stores shape in PyTorch order [C, 1, K] which the GGUF writer inverts.
|
||||
// Published GGUF has ne[0]=K, ne[1]=C → shape array must be [K, C].
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") && len(shape) == 3 {
|
||||
shape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
|
||||
// MoE expert weights: no transpose needed. Safetensors stores [experts, out, in]
|
||||
// which the framework reverses to GGUF ne=[in, out, experts], matching ggml_mul_mat_id.
|
||||
// (transposeExperts was incorrectly swapping dims — removed)
|
||||
|
||||
// Audio conv weights are forced to F32 via tensorBase.Kind() in reader.go
|
||||
// (im2col doesn't support BF16). No kindOverride needed — the Kind() method
|
||||
// controls both the GGUF header type AND the WriteTo data encoding path.
|
||||
var kindOverride *uint32
|
||||
|
||||
// Vision patch embedding: reshape from [n_embd, ksize_sq_c] to [n_embd, 3, patch_size, patch_size]
|
||||
// Must be stored as F16 (not BF16) because the Conv2D im2col kernel requires F16/F32.
|
||||
if strings.Contains(name, "v.patch_embd.weight") && len(shape) == 2 {
|
||||
nEmbd := shape[0]
|
||||
patchSize := uint64(p.VisionModel.PatchSize)
|
||||
if patchSize == 0 {
|
||||
patchSize = 16
|
||||
}
|
||||
numCh := uint64(p.VisionModel.NumChannels)
|
||||
if numCh == 0 {
|
||||
numCh = 3
|
||||
}
|
||||
t.SetRepacker(p.reshapePatchEmbed)
|
||||
shape = []uint64{nEmbd, numCh, patchSize, patchSize}
|
||||
f16Kind := uint32(1) // tensorKindFP16
|
||||
kindOverride = &f16Kind
|
||||
}
|
||||
|
||||
// Vision position embedding: keep 3D [2, maxPos, nEmbd] — matching published mmproj format.
|
||||
// The framework reverses shape to GGUF ne=[nEmbd, maxPos, 2]. No data repacking needed.
|
||||
|
||||
kind := t.Kind()
|
||||
if kindOverride != nil {
|
||||
kind = *kindOverride
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: kind,
|
||||
Shape: shape,
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
// Generate a single global rope_freqs.weight for proportional RoPE on global attention layers.
|
||||
// This matches the published GGUF format: one global tensor shared by all layers.
|
||||
// Global layers use partial_rotary_factor (0.25) — only rotate that fraction of dims.
|
||||
// Dimensions beyond the rotated portion get freq_factor=1e30 (effectively no rotation).
|
||||
tc := p.TextModel
|
||||
if tc.GlobalHeadDim > 0 {
|
||||
globalFreqsSize := tc.GlobalHeadDim / 2 // freq_factors are per dimension pair
|
||||
|
||||
// Compute number of rotated pairs for global layers
|
||||
partialRotaryFactor := float32(0.25) // default
|
||||
if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil && rp.PartialRotaryFactor != nil {
|
||||
partialRotaryFactor = *rp.PartialRotaryFactor
|
||||
}
|
||||
nRotFull := int(float32(tc.GlobalHeadDim) * partialRotaryFactor / 2)
|
||||
|
||||
freqs := make(ropeFactor, globalFreqsSize)
|
||||
for j := range freqs {
|
||||
if j < nRotFull {
|
||||
freqs[j] = 1.0
|
||||
} else {
|
||||
freqs[j] = 1e30 // effectively disable rotation
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: "rope_freqs.weight",
|
||||
Kind: 0, // F32
|
||||
Shape: []uint64{uint64(len(freqs))},
|
||||
WriterTo: freqs,
|
||||
})
|
||||
}
|
||||
|
||||
// Emit packed vision clamp data as a single F32 tensor.
|
||||
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
|
||||
// then 4 floats for the projector. Total = (numLayers*7 + 1) * 4 floats.
|
||||
if len(clampMap) > 0 {
|
||||
numLayers := int(p.VisionModel.NumHiddenLayers)
|
||||
linearNames := []string{"attn_q", "attn_k", "attn_v", "attn_out", "ffn_gate", "ffn_up", "ffn_down"}
|
||||
suffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
|
||||
|
||||
totalFloats := (numLayers*len(linearNames) + 1) * 4 // +1 for projector
|
||||
clampData := make([]float32, totalFloats)
|
||||
|
||||
for layer := range numLayers {
|
||||
for li, ln := range linearNames {
|
||||
for si, sfx := range suffixes {
|
||||
sfxMap := map[string]string{"attn_q": "q_proj", "attn_k": "k_proj", "attn_v": "v_proj", "attn_out": "o_proj", "ffn_gate": "gate_proj", "ffn_up": "up_proj", "ffn_down": "down_proj"}
|
||||
for origName, val := range clampMap {
|
||||
if strings.Contains(origName, fmt.Sprintf("layers.%d.", layer)) && strings.HasSuffix(origName, sfx) && strings.Contains(origName, sfxMap[ln]) {
|
||||
idx := (layer*len(linearNames)+li)*4 + si
|
||||
clampData[idx] = val
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Projector clamp values
|
||||
projIdx := numLayers * len(linearNames) * 4
|
||||
for si, sfx := range suffixes {
|
||||
for origName, val := range clampMap {
|
||||
if strings.Contains(origName, "input_projection") && strings.HasSuffix(origName, sfx) {
|
||||
clampData[projIdx+si] = val
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, clampData)
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: "v.clamp_data",
|
||||
Kind: 0, // F32
|
||||
Shape: []uint64{uint64(totalFloats)},
|
||||
WriterTo: &buf,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// reshapePatchEmbed reshapes the vision patch embedding from HF layout [n_embd, ksize*ksize*channels]
|
||||
// to GGUF layout [n_embd, channels, patch_size, patch_size].
|
||||
func (*gemma4Model) reshapePatchEmbed(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
if len(shape) != 2 {
|
||||
return data, nil
|
||||
}
|
||||
nEmbd := int(shape[0])
|
||||
ksqC := int(shape[1])
|
||||
nChannels := 3
|
||||
patchSize := int(math.Sqrt(float64(ksqC / nChannels)))
|
||||
|
||||
// HF layout: [n_embd, patch_size * patch_size * channels] (row-major)
|
||||
// Need: [n_embd, channels, patch_size, patch_size]
|
||||
result := make([]float32, len(data))
|
||||
for e := range nEmbd {
|
||||
for c := range nChannels {
|
||||
for h := range patchSize {
|
||||
for w := range patchSize {
|
||||
srcIdx := e*ksqC + h*patchSize*nChannels + w*nChannels + c
|
||||
dstIdx := e*nChannels*patchSize*patchSize + c*patchSize*patchSize + h*patchSize + w
|
||||
result[dstIdx] = data[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
shape[0] = uint64(nEmbd)
|
||||
shape[1] = uint64(nChannels * patchSize * patchSize)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// softplusRepacker applies softplus (ln(1 + exp(x))) to tensor data.
|
||||
// Used for per_dim_scale tensors which the published GGUF stores pre-activated.
|
||||
func softplusRepacker(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
result := make([]float32, len(data))
|
||||
for i, x := range data {
|
||||
result[i] = float32(math.Log(1 + math.Exp(float64(x))))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// squeezeMiddleDim squeezes the middle dimension from [C, 1, K] → [C, K] for depthwise conv1d weights.
|
||||
// Data layout stays the same since the middle dim is 1 — just a shape change.
|
||||
func squeezeMiddleDim(_ string, data []float32, _ []uint64) ([]float32, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (p *gemma4Model) Replacements() []string {
|
||||
return []string{
|
||||
// ClippableLinear wraps nn.Linear — strip .linear. from weight path
|
||||
".linear.weight", ".weight",
|
||||
".linear.bias", ".bias",
|
||||
|
||||
// Audio SSCP (Sub-Sample Convolution Projection)
|
||||
"model.audio_tower.subsample_conv_projection.conv_0.conv", "a.conv1d.0",
|
||||
"model.audio_tower.subsample_conv_projection.conv_0.norm", "a.conv1d.0.norm",
|
||||
"model.audio_tower.subsample_conv_projection.conv_1.conv", "a.conv1d.1",
|
||||
"model.audio_tower.subsample_conv_projection.conv_1.norm", "a.conv1d.1.norm",
|
||||
"model.audio_tower.subsample_conv_projection.layer0.conv", "a.conv1d.0",
|
||||
"model.audio_tower.subsample_conv_projection.layer0.norm", "a.conv1d.0.norm",
|
||||
"model.audio_tower.subsample_conv_projection.layer1.conv", "a.conv1d.1",
|
||||
"model.audio_tower.subsample_conv_projection.layer1.norm", "a.conv1d.1.norm",
|
||||
"model.audio_tower.subsample_conv_projection.input_proj_linear", "a.pre_encode.out",
|
||||
|
||||
// Audio conformer blocks
|
||||
"model.audio_tower.conformer", "a.blk",
|
||||
"model.audio_tower.layers", "a.blk",
|
||||
|
||||
// Audio conformer attention
|
||||
"attention.attn.relative_position_embedding.pos_proj", "linear_pos",
|
||||
"self_attn.relative_k_proj", "linear_pos",
|
||||
"attention.attn.per_dim_key_scale", "per_dim_k_scale",
|
||||
"attention.attn.per_dim_scale", "per_dim_scale",
|
||||
"self_attn.per_dim_scale", "per_dim_scale",
|
||||
"attention.attn.q_proj", "attn_q",
|
||||
"attention.attn.k_proj", "attn_k",
|
||||
"attention.attn.v_proj", "attn_v",
|
||||
"attention.pre_attn_norm", "ln1",
|
||||
"attention.post_norm", "ln2",
|
||||
"attention.post", "attn_out",
|
||||
"self_attn.post", "attn_out",
|
||||
"norm_pre_attn", "ln1",
|
||||
"norm_post_attn", "ln2",
|
||||
|
||||
// Audio conformer feedforward
|
||||
"ffw_layer_start.pre_layer_norm", "ffn_norm",
|
||||
"ffw_layer_start.post_layer_norm", "ffn_post_norm",
|
||||
"ffw_layer_start.ffw_layer_1", "ffn_up",
|
||||
"ffw_layer_start.ffw_layer_2", "ffn_down",
|
||||
"ffw_layer_end.pre_layer_norm", "ffn_norm_1",
|
||||
"ffw_layer_end.post_layer_norm", "ffn_post_norm_1",
|
||||
"ffw_layer_end.ffw_layer_1", "ffn_up_1",
|
||||
"ffw_layer_end.ffw_layer_2", "ffn_down_1",
|
||||
"feed_forward1.pre_layer_norm", "ffn_norm",
|
||||
"feed_forward1.post_layer_norm", "ffn_post_norm",
|
||||
"feed_forward1.ffw_layer_1", "ffn_up",
|
||||
"feed_forward1.ffw_layer_2", "ffn_down",
|
||||
"feed_forward2.pre_layer_norm", "ffn_norm_1",
|
||||
"feed_forward2.post_layer_norm", "ffn_post_norm_1",
|
||||
"feed_forward2.ffw_layer_1", "ffn_up_1",
|
||||
"feed_forward2.ffw_layer_2", "ffn_down_1",
|
||||
|
||||
// Audio conformer lightweight conv1d
|
||||
"lconv1d.depthwise_conv1d", "conv_dw",
|
||||
"lconv1d.pre_layer_norm", "conv_norm",
|
||||
"lconv1d.conv_norm", "norm_conv",
|
||||
"lconv1d.linear_start", "conv_pw1",
|
||||
"lconv1d.linear_end", "conv_pw2",
|
||||
|
||||
// Audio block final norm
|
||||
"norm_out", "layer_pre_norm",
|
||||
|
||||
// Audio embedder and output projection
|
||||
"model.embed_audio.embedding_projection", "mm.a.input_projection",
|
||||
"model.audio_tower.output_proj", "mm.a.fc",
|
||||
|
||||
// Vision encoder
|
||||
"model.vision_tower.encoder.layers", "v.blk",
|
||||
"model.vision_tower.patch_embedder.input_proj", "v.patch_embd",
|
||||
"model.vision_tower.patch_embedder.position_embedding_table", "v.position_embd.weight",
|
||||
"model.vision_tower.std_bias", "v.std_bias",
|
||||
"model.vision_tower.std_scale", "v.std_scale",
|
||||
|
||||
// Vision multimodal projector
|
||||
"model.embed_vision.embedding_projection", "mm.input_projection",
|
||||
|
||||
// Text model
|
||||
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
|
||||
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"model.language_model.layers", "blk",
|
||||
|
||||
// Shared attention replacements (work for both text and vision tensors)
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
|
||||
// Post norms
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"pre_feedforward_layernorm_2", "pre_ffw_norm_2",
|
||||
"pre_feedforward_layernorm", "ffn_norm",
|
||||
"post_feedforward_layernorm_1", "post_ffw_norm_1",
|
||||
"post_feedforward_layernorm_2", "post_ffw_norm_2",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
|
||||
// PLE
|
||||
"per_layer_input_gate", "inp_gate",
|
||||
"per_layer_projection", "proj",
|
||||
"post_per_layer_input_norm", "post_norm",
|
||||
|
||||
// MoE
|
||||
"router.proj", "ffn_gate_inp",
|
||||
"router.scale", "ffn_gate_inp.scale",
|
||||
"router.per_expert_scale.weight", "ffn_down_exps.scale",
|
||||
"router.per_expert_scale", "ffn_down_exps.scale",
|
||||
"experts.gate_up_proj.weight", "ffn_gate_up_exps.weight",
|
||||
"experts.gate_up_proj", "ffn_gate_up_exps.weight",
|
||||
"experts.down_proj.weight", "ffn_down_exps.weight",
|
||||
"experts.down_proj", "ffn_down_exps.weight",
|
||||
"moe.gate_proj", "ffn_gate_exps.weight",
|
||||
"moe.up_proj", "ffn_up_exps.weight",
|
||||
"moe.gate_up_proj.weight", "ffn_gate_up_exps.weight",
|
||||
"moe.gate_up_proj", "ffn_gate_up_exps.weight",
|
||||
"moe.down_proj", "ffn_down_exps.weight",
|
||||
"moe.per_expert_scale.weight", "ffn_down_exps.scale",
|
||||
"moe.per_expert_scale", "ffn_down_exps.scale",
|
||||
|
||||
// Layer scalar
|
||||
"layer_scalar", "layer_output_scale.weight",
|
||||
}
|
||||
}
|
||||
318
convert/convert_gemma4_test.go
Normal file
318
convert/convert_gemma4_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGemma4AudioReplacements(t *testing.T) {
|
||||
p := gemma4Model{}
|
||||
r := strings.NewReplacer(p.Replacements()...)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
// SSCP convolution blocks
|
||||
{
|
||||
"sscp conv0 weight",
|
||||
"model.audio_tower.subsample_conv_projection.conv_0.conv.weight",
|
||||
"a.conv1d.0.weight",
|
||||
},
|
||||
{
|
||||
"sscp conv0 norm",
|
||||
"model.audio_tower.subsample_conv_projection.conv_0.norm.weight",
|
||||
"a.conv1d.0.norm.weight",
|
||||
},
|
||||
{
|
||||
"sscp conv1 weight",
|
||||
"model.audio_tower.subsample_conv_projection.conv_1.conv.weight",
|
||||
"a.conv1d.1.weight",
|
||||
},
|
||||
{
|
||||
"sscp input proj weight",
|
||||
"model.audio_tower.subsample_conv_projection.input_proj_linear.weight",
|
||||
"a.pre_encode.out.weight",
|
||||
},
|
||||
{
|
||||
"sscp input proj bias",
|
||||
"model.audio_tower.subsample_conv_projection.input_proj_linear.bias",
|
||||
"a.pre_encode.out.bias",
|
||||
},
|
||||
{
|
||||
"sscp layer0 conv weight (new naming)",
|
||||
"model.audio_tower.subsample_conv_projection.layer0.conv.weight",
|
||||
"a.conv1d.0.weight",
|
||||
},
|
||||
{
|
||||
"sscp layer1 norm weight (new naming)",
|
||||
"model.audio_tower.subsample_conv_projection.layer1.norm.weight",
|
||||
"a.conv1d.1.norm.weight",
|
||||
},
|
||||
|
||||
// Conformer attention
|
||||
{
|
||||
"attn q weight",
|
||||
"model.audio_tower.conformer.0.attention.attn.q_proj.linear.weight",
|
||||
"a.blk.0.attn_q.weight",
|
||||
},
|
||||
{
|
||||
"attn k weight",
|
||||
"model.audio_tower.conformer.5.attention.attn.k_proj.linear.weight",
|
||||
"a.blk.5.attn_k.weight",
|
||||
},
|
||||
{
|
||||
"attn v clamp input_min",
|
||||
"model.audio_tower.conformer.0.attention.attn.v_proj.input_min",
|
||||
"a.blk.0.attn_v.input_min",
|
||||
},
|
||||
{
|
||||
"attn out weight (ClippableLinear)",
|
||||
"model.audio_tower.conformer.0.attention.post.linear.weight",
|
||||
"a.blk.0.attn_out.weight",
|
||||
},
|
||||
{
|
||||
"attn out clamp output_max",
|
||||
"model.audio_tower.conformer.0.attention.post.output_max",
|
||||
"a.blk.0.attn_out.output_max",
|
||||
},
|
||||
{
|
||||
"attn pre norm",
|
||||
"model.audio_tower.conformer.0.attention.pre_attn_norm.weight",
|
||||
"a.blk.0.ln1.weight",
|
||||
},
|
||||
{
|
||||
"attn post norm",
|
||||
"model.audio_tower.conformer.0.attention.post_norm.weight",
|
||||
"a.blk.0.ln2.weight",
|
||||
},
|
||||
{
|
||||
"linear pos",
|
||||
"model.audio_tower.conformer.0.attention.attn.relative_position_embedding.pos_proj.weight",
|
||||
"a.blk.0.linear_pos.weight",
|
||||
},
|
||||
{
|
||||
"per dim scale",
|
||||
"model.audio_tower.conformer.0.attention.attn.per_dim_scale",
|
||||
"a.blk.0.per_dim_scale",
|
||||
},
|
||||
{
|
||||
"per dim key scale",
|
||||
"model.audio_tower.conformer.0.attention.attn.per_dim_key_scale",
|
||||
"a.blk.0.per_dim_k_scale",
|
||||
},
|
||||
{
|
||||
"attn relative k proj (new naming)",
|
||||
"model.audio_tower.layers.0.self_attn.relative_k_proj.weight",
|
||||
"a.blk.0.linear_pos.weight",
|
||||
},
|
||||
{
|
||||
"attn pre norm (new naming)",
|
||||
"model.audio_tower.layers.0.norm_pre_attn.weight",
|
||||
"a.blk.0.ln1.weight",
|
||||
},
|
||||
{
|
||||
"attn post norm (new naming)",
|
||||
"model.audio_tower.layers.0.norm_post_attn.weight",
|
||||
"a.blk.0.ln2.weight",
|
||||
},
|
||||
{
|
||||
"attn out clamp output_max (new naming)",
|
||||
"model.audio_tower.layers.0.self_attn.post.output_max",
|
||||
"a.blk.0.attn_out.output_max",
|
||||
},
|
||||
{
|
||||
"per dim scale (new naming)",
|
||||
"model.audio_tower.layers.0.self_attn.per_dim_scale",
|
||||
"a.blk.0.per_dim_scale",
|
||||
},
|
||||
|
||||
// Conformer feedforward start
|
||||
{
|
||||
"ffn up weight",
|
||||
"model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_1.linear.weight",
|
||||
"a.blk.0.ffn_up.weight",
|
||||
},
|
||||
{
|
||||
"ffn down weight",
|
||||
"model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_2.linear.weight",
|
||||
"a.blk.0.ffn_down.weight",
|
||||
},
|
||||
{
|
||||
"ffn norm",
|
||||
"model.audio_tower.conformer.0.ffw_layer_start.pre_layer_norm.weight",
|
||||
"a.blk.0.ffn_norm.weight",
|
||||
},
|
||||
{
|
||||
"ffn post norm",
|
||||
"model.audio_tower.conformer.0.ffw_layer_start.post_layer_norm.weight",
|
||||
"a.blk.0.ffn_post_norm.weight",
|
||||
},
|
||||
|
||||
// Conformer feedforward end
|
||||
{
|
||||
"ffn up 1 weight",
|
||||
"model.audio_tower.conformer.0.ffw_layer_end.ffw_layer_1.linear.weight",
|
||||
"a.blk.0.ffn_up_1.weight",
|
||||
},
|
||||
{
|
||||
"ffn down 1 weight",
|
||||
"model.audio_tower.conformer.0.ffw_layer_end.ffw_layer_2.linear.weight",
|
||||
"a.blk.0.ffn_down_1.weight",
|
||||
},
|
||||
{
|
||||
"ffn norm 1",
|
||||
"model.audio_tower.conformer.0.ffw_layer_end.pre_layer_norm.weight",
|
||||
"a.blk.0.ffn_norm_1.weight",
|
||||
},
|
||||
{
|
||||
"ffn post norm 1",
|
||||
"model.audio_tower.conformer.0.ffw_layer_end.post_layer_norm.weight",
|
||||
"a.blk.0.ffn_post_norm_1.weight",
|
||||
},
|
||||
{
|
||||
"ffn up output_max (new naming)",
|
||||
"model.audio_tower.layers.10.feed_forward1.ffw_layer_1.output_max",
|
||||
"a.blk.10.ffn_up.output_max",
|
||||
},
|
||||
{
|
||||
"ffn down output_min (new naming)",
|
||||
"model.audio_tower.layers.0.feed_forward1.ffw_layer_2.output_min",
|
||||
"a.blk.0.ffn_down.output_min",
|
||||
},
|
||||
{
|
||||
"ffn up 1 input_max (new naming)",
|
||||
"model.audio_tower.layers.0.feed_forward2.ffw_layer_1.input_max",
|
||||
"a.blk.0.ffn_up_1.input_max",
|
||||
},
|
||||
{
|
||||
"ffn norm 1 (new naming)",
|
||||
"model.audio_tower.layers.0.feed_forward2.pre_layer_norm.weight",
|
||||
"a.blk.0.ffn_norm_1.weight",
|
||||
},
|
||||
|
||||
// Conformer lightweight conv1d
|
||||
{
|
||||
"conv dw weight",
|
||||
"model.audio_tower.conformer.0.lconv1d.depthwise_conv1d.weight",
|
||||
"a.blk.0.conv_dw.weight",
|
||||
},
|
||||
{
|
||||
"conv norm (pre_layer_norm)",
|
||||
"model.audio_tower.conformer.0.lconv1d.pre_layer_norm.weight",
|
||||
"a.blk.0.conv_norm.weight",
|
||||
},
|
||||
{
|
||||
"norm conv (conv_norm)",
|
||||
"model.audio_tower.conformer.0.lconv1d.conv_norm.weight",
|
||||
"a.blk.0.norm_conv.weight",
|
||||
},
|
||||
{
|
||||
"conv pw1 weight",
|
||||
"model.audio_tower.conformer.0.lconv1d.linear_start.linear.weight",
|
||||
"a.blk.0.conv_pw1.weight",
|
||||
},
|
||||
{
|
||||
"conv pw2 weight",
|
||||
"model.audio_tower.conformer.0.lconv1d.linear_end.linear.weight",
|
||||
"a.blk.0.conv_pw2.weight",
|
||||
},
|
||||
|
||||
// Audio embedder
|
||||
{
|
||||
"audio embedder projection weight",
|
||||
"model.embed_audio.embedding_projection.linear.weight",
|
||||
"mm.a.input_projection.weight",
|
||||
},
|
||||
{
|
||||
"audio embedder projection bias",
|
||||
"model.embed_audio.embedding_projection.linear.bias",
|
||||
"mm.a.input_projection.bias",
|
||||
},
|
||||
|
||||
// Audio output projection
|
||||
{
|
||||
"audio output proj weight",
|
||||
"model.audio_tower.output_proj.weight",
|
||||
"mm.a.fc.weight",
|
||||
},
|
||||
{
|
||||
"audio output proj bias",
|
||||
"model.audio_tower.output_proj.bias",
|
||||
"mm.a.fc.bias",
|
||||
},
|
||||
|
||||
// Verify vision tensors still work
|
||||
{
|
||||
"vision q weight",
|
||||
"model.vision_tower.encoder.layers.0.self_attn.q_proj.linear.weight",
|
||||
"v.blk.0.attn_q.weight",
|
||||
},
|
||||
{
|
||||
"vision std bias",
|
||||
"model.vision_tower.std_bias",
|
||||
"v.std_bias",
|
||||
},
|
||||
{
|
||||
"vision std scale",
|
||||
"model.vision_tower.std_scale",
|
||||
"v.std_scale",
|
||||
},
|
||||
{
|
||||
"vision patch embd",
|
||||
"model.vision_tower.patch_embedder.input_proj.weight",
|
||||
"v.patch_embd.weight",
|
||||
},
|
||||
{
|
||||
"vision projector",
|
||||
"model.embed_vision.embedding_projection.linear.weight",
|
||||
"mm.input_projection.weight",
|
||||
},
|
||||
|
||||
// Verify text tensors still work
|
||||
{
|
||||
"text attn q",
|
||||
"model.language_model.layers.0.self_attn.q_proj.weight",
|
||||
"blk.0.attn_q.weight",
|
||||
},
|
||||
{
|
||||
"text token embd",
|
||||
"model.language_model.embed_tokens.weight",
|
||||
"token_embd.weight",
|
||||
},
|
||||
{
|
||||
"text moe gate up fused",
|
||||
"model.language_model.layers.0.experts.gate_up_proj",
|
||||
"blk.0.ffn_gate_up_exps.weight",
|
||||
},
|
||||
{
|
||||
"text moe down",
|
||||
"model.language_model.layers.0.experts.down_proj",
|
||||
"blk.0.ffn_down_exps.weight",
|
||||
},
|
||||
{
|
||||
"text moe down with weight suffix",
|
||||
"model.language_model.layers.0.experts.down_proj.weight",
|
||||
"blk.0.ffn_down_exps.weight",
|
||||
},
|
||||
{
|
||||
"text moe per expert scale",
|
||||
"model.language_model.layers.0.router.per_expert_scale",
|
||||
"blk.0.ffn_down_exps.scale",
|
||||
},
|
||||
{
|
||||
"text moe per expert scale with weight suffix",
|
||||
"model.language_model.layers.0.router.per_expert_scale.weight",
|
||||
"blk.0.ffn_down_exps.scale",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := r.Replace(tt.in); got != tt.want {
|
||||
t.Errorf("Replace(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -205,8 +205,8 @@ func TestConvertInvalidDatatype(t *testing.T) {
|
||||
generateSafetensorTestData(t, tempDir, td)
|
||||
|
||||
err = ConvertModel(os.DirFS(tempDir), f)
|
||||
if err == nil || err.Error() != "unsupported safetensors model" {
|
||||
t.Errorf("expected error but didn't get one")
|
||||
if err == nil || !strings.Contains(err.Error(), "unknown data type") {
|
||||
t.Errorf("expected 'unknown data type' error but got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -42,8 +42,11 @@ func (t tensorBase) Kind() uint32 {
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
|
||||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
|
||||
strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights must be F32 for im2col
|
||||
strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights must be F32
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.position_embd.weight" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
t.name == "v.pre_tile_position_embd.weight" ||
|
||||
t.name == "v.post_tile_position_embd.weight" ||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
@@ -53,9 +52,10 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
|
||||
|
||||
for _, key := range keys {
|
||||
if value := headers[key]; value.Type != "" {
|
||||
// bitsandbytes quantized models are unsupported
|
||||
// Scalar tensors (e.g. clipped linear min/max) are 0-dim in safetensors.
|
||||
// Promote them to 1-dim so they can be stored in GGUF.
|
||||
if len(value.Shape) == 0 {
|
||||
return nil, errors.New("unsupported safetensors model")
|
||||
value.Shape = []uint64{1}
|
||||
}
|
||||
ggufName := replacer.Replace(key)
|
||||
if _, ok := names[ggufName]; ok {
|
||||
|
||||
@@ -281,6 +281,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"deepseekocr",
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"gemma4",
|
||||
"gptoss", "gpt-oss",
|
||||
"llama4",
|
||||
"mistral3",
|
||||
|
||||
259
integration/audio_test.go
Normal file
259
integration/audio_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var defaultAudioModels = []string{
|
||||
"gemma4:e2b",
|
||||
"gemma4:e4b",
|
||||
}
|
||||
|
||||
// decodeTestAudio returns the test audio clip ("Why is the sky blue?", 16kHz mono WAV).
|
||||
func decodeTestAudio(t *testing.T) api.ImageData {
|
||||
t.Helper()
|
||||
data, err := base64.StdEncoding.DecodeString(audioEncodingPrompt)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode test audio: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// setupAudioModel pulls the model, preloads it, and skips if it doesn't support audio.
|
||||
func setupAudioModel(ctx context.Context, t *testing.T, client *api.Client, model string) {
|
||||
t.Helper()
|
||||
requireCapability(ctx, t, client, model, "audio")
|
||||
pullOrSkip(ctx, t, client, model)
|
||||
err := client.Generate(ctx, &api.GenerateRequest{Model: model}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", model, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudioTranscription tests that the model can transcribe audio to text.
|
||||
func TestAudioTranscription(t *testing.T) {
|
||||
for _, model := range testModels(defaultAudioModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
setupAudioModel(ctx, t, client, model)
|
||||
audio := decodeTestAudio(t)
|
||||
noThink := &api.ThinkValue{Value: false}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Think: noThink,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "Transcribe the audio exactly as spoken. Output only the transcription.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Transcribe this audio.",
|
||||
Images: []api.ImageData{audio},
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_predict": 50,
|
||||
},
|
||||
}
|
||||
|
||||
// The audio says "Why is the sky blue?" — expect key words in transcription.
|
||||
DoChat(ctx, t, client, req, []string{"sky", "blue"}, 60*time.Second, 10*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudioResponse tests that the model can respond to a spoken question.
|
||||
func TestAudioResponse(t *testing.T) {
|
||||
for _, model := range testModels(defaultAudioModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
setupAudioModel(ctx, t, client, model)
|
||||
audio := decodeTestAudio(t)
|
||||
noThink := &api.ThinkValue{Value: false}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Think: noThink,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "",
|
||||
Images: []api.ImageData{audio},
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_predict": 200,
|
||||
},
|
||||
}
|
||||
|
||||
// The audio asks "Why is the sky blue?" — expect an answer about light/scattering.
|
||||
DoChat(ctx, t, client, req, []string{
|
||||
"scatter", "light", "blue", "atmosphere", "wavelength", "rayleigh",
|
||||
}, 60*time.Second, 10*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIAudioTranscription tests the /v1/audio/transcriptions endpoint.
|
||||
func TestOpenAIAudioTranscription(t *testing.T) {
|
||||
for _, model := range testModels(defaultAudioModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, endpoint, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
setupAudioModel(ctx, t, client, model)
|
||||
audioBytes := decodeTestAudio(t)
|
||||
|
||||
// Build multipart form request.
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
writer.WriteField("model", model)
|
||||
part, err := writer.CreateFormFile("file", "prompt.wav")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
part.Write(audioBytes)
|
||||
writer.Close()
|
||||
|
||||
url := fmt.Sprintf("http://%s/v1/audio/transcriptions", endpoint)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, &body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
text := strings.ToLower(string(respBody))
|
||||
if !strings.Contains(text, "sky") && !strings.Contains(text, "blue") {
|
||||
t.Errorf("transcription response missing expected words, got: %s", string(respBody))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIChatWithAudio tests /v1/chat/completions with input_audio content.
|
||||
func TestOpenAIChatWithAudio(t *testing.T) {
|
||||
for _, model := range testModels(defaultAudioModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, endpoint, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
setupAudioModel(ctx, t, client, model)
|
||||
audioB64 := audioEncodingPrompt
|
||||
|
||||
reqBody := fmt.Sprintf(`{
|
||||
"model": %q,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "input_audio", "input_audio": {"data": %q, "format": "wav"}}
|
||||
]
|
||||
}],
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"max_tokens": 200,
|
||||
"think": false
|
||||
}`, model, strings.TrimSpace(audioB64))
|
||||
|
||||
url := fmt.Sprintf("http://%s/v1/chat/completions", endpoint)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(reqBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal(respBytes, &result); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Choices) == 0 {
|
||||
t.Fatal("no choices in response")
|
||||
}
|
||||
|
||||
text := strings.ToLower(result.Choices[0].Message.Content + " " + result.Choices[0].Message.Reasoning)
|
||||
found := false
|
||||
for _, word := range []string{"sky", "blue", "scatter", "light", "atmosphere"} {
|
||||
if strings.Contains(text, word) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("response missing expected words about sky/blue/light, got: %s", result.Choices[0].Message.Content)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
9
integration/audio_test_data_test.go
Normal file
9
integration/audio_test_data_test.go
Normal file
File diff suppressed because one or more lines are too long
@@ -51,6 +51,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
thinkOff := api.ThinkValue{Value: false}
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
@@ -59,6 +60,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
Content: "Write me a story in english with a lot of emojis",
|
||||
},
|
||||
},
|
||||
Think: &thinkOff,
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
|
||||
@@ -15,6 +15,7 @@ func TestVisionModels(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
|
||||
defaultVisionModels := []string{
|
||||
"gemma4",
|
||||
"qwen2.5vl",
|
||||
"llama3.2-vision",
|
||||
"gemma3",
|
||||
@@ -23,6 +24,8 @@ func TestVisionModels(t *testing.T) {
|
||||
"ministral-3",
|
||||
}
|
||||
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
for _, model := range testModels(defaultVisionModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
@@ -30,10 +33,7 @@ func TestVisionModels(t *testing.T) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
if testModel != "" {
|
||||
requireCapability(ctx, t, client, model, "vision")
|
||||
}
|
||||
|
||||
requireCapability(ctx, t, client, model, "vision")
|
||||
pullOrSkip(ctx, t, client, model)
|
||||
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
|
||||
155
integration/thinking_test.go
Normal file
155
integration/thinking_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TestThinkingEnabled verifies that when thinking is requested, the model
|
||||
// produces both thinking and content output without leaking raw channel tags.
|
||||
func TestThinkingEnabled(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
models := testModels([]string{smol})
|
||||
for _, modelName := range models {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
requireCapability(ctx, t, client, modelName, "thinking")
|
||||
pullOrSkip(ctx, t, client, modelName)
|
||||
|
||||
think := api.ThinkValue{Value: true}
|
||||
stream := false
|
||||
req := api.ChatRequest{
|
||||
Model: modelName,
|
||||
Stream: &stream,
|
||||
Think: &think,
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What is 12 * 15? Think step by step."},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 42,
|
||||
"num_predict": 512,
|
||||
},
|
||||
}
|
||||
|
||||
var response api.ChatResponse
|
||||
err := client.Chat(ctx, &req, func(cr api.ChatResponse) error {
|
||||
response = cr
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "model requires more system memory") {
|
||||
t.Skip("model too large for test system")
|
||||
}
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
thinking := response.Message.Thinking
|
||||
|
||||
// Thinking should be non-empty when thinking is enabled
|
||||
if thinking == "" {
|
||||
t.Error("expected non-empty thinking output when thinking is enabled")
|
||||
}
|
||||
|
||||
// The answer (180) should appear in thinking, content, or both.
|
||||
// Some models put everything in thinking and leave content empty
|
||||
// if they hit the token limit while still thinking.
|
||||
combined := thinking + " " + content
|
||||
if !strings.Contains(combined, "180") {
|
||||
t.Errorf("expected '180' in thinking or content, got thinking=%q content=%q", thinking, content)
|
||||
}
|
||||
|
||||
// Neither thinking nor content should contain raw channel tags
|
||||
if strings.Contains(content, "<|channel>") || strings.Contains(content, "<channel|>") {
|
||||
t.Errorf("content contains raw channel tags: %s", content)
|
||||
}
|
||||
if strings.Contains(thinking, "<|channel>") || strings.Contains(thinking, "<channel|>") {
|
||||
t.Errorf("thinking contains raw channel tags: %s", thinking)
|
||||
}
|
||||
|
||||
t.Logf("thinking (%d chars): %.100s...", len(thinking), thinking)
|
||||
t.Logf("content (%d chars): %s", len(content), content)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestThinkingSuppressed verifies that when thinking is NOT requested,
|
||||
// the model does not leak thinking/channel content into the response.
|
||||
func TestThinkingSuppressed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
models := testModels([]string{smol})
|
||||
for _, modelName := range models {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
requireCapability(ctx, t, client, modelName, "thinking")
|
||||
pullOrSkip(ctx, t, client, modelName)
|
||||
|
||||
stream := false
|
||||
req := api.ChatRequest{
|
||||
Model: modelName,
|
||||
Stream: &stream,
|
||||
// Think is nil — thinking not requested
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What is the capital of Japan? Answer in one word."},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 42,
|
||||
"num_predict": 64,
|
||||
},
|
||||
}
|
||||
|
||||
var response api.ChatResponse
|
||||
err := client.Chat(ctx, &req, func(cr api.ChatResponse) error {
|
||||
response = cr
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "model requires more system memory") {
|
||||
t.Skip("model too large for test system")
|
||||
}
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
thinking := response.Message.Thinking
|
||||
|
||||
// The answer should appear in content or thinking
|
||||
combined := content + " " + thinking
|
||||
if !strings.Contains(combined, "Tokyo") {
|
||||
t.Errorf("expected 'Tokyo' in content or thinking, got content=%q thinking=%q", content, thinking)
|
||||
}
|
||||
|
||||
// Content must NOT contain channel/thinking tags
|
||||
if strings.Contains(content, "<|channel>") || strings.Contains(content, "<channel|>") {
|
||||
t.Errorf("content contains leaked channel tags when thinking not requested: %s", content)
|
||||
}
|
||||
if strings.Contains(content, "thought") && strings.Contains(content, "<channel|>") {
|
||||
t.Errorf("content contains leaked thinking block: %s", content)
|
||||
}
|
||||
|
||||
// Thinking field should ideally be empty when not requested.
|
||||
// Some small models may still produce thinking output; log but don't fail.
|
||||
if thinking != "" {
|
||||
t.Logf("WARNING: model produced thinking output when not requested (%d chars): %.100s...", len(thinking), thinking)
|
||||
}
|
||||
|
||||
t.Logf("content: %s", content)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,7 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
minVRAM := map[string]uint64{
|
||||
"gemma4": 8,
|
||||
"qwen3-vl": 16,
|
||||
"gpt-oss:20b": 16,
|
||||
"gpt-oss:120b": 70,
|
||||
|
||||
@@ -45,6 +45,7 @@ var (
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"gemma4",
|
||||
"lfm2.5-thinking",
|
||||
"ministral-3",
|
||||
"qwen3-coder:30b",
|
||||
@@ -137,6 +138,7 @@ var (
|
||||
"gemma2",
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"gemma4",
|
||||
"glm4",
|
||||
"goliath",
|
||||
"gpt-oss:20b",
|
||||
@@ -272,6 +274,7 @@ var (
|
||||
"snowflake-arctic-embed2",
|
||||
}
|
||||
libraryToolsModels = []string{
|
||||
"gemma4",
|
||||
"lfm2.5-thinking",
|
||||
"qwen3-vl",
|
||||
"gpt-oss:20b",
|
||||
|
||||
@@ -5,23 +5,26 @@ package integration
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Default set of vision models to test. When OLLAMA_TEST_MODEL is set,
|
||||
// only that model is tested (with a capability check for vision).
|
||||
var defaultVisionModels = []string{
|
||||
"gemma4",
|
||||
"gemma3",
|
||||
"llama3.2-vision",
|
||||
"qwen2.5vl",
|
||||
"qwen3-vl:8b",
|
||||
}
|
||||
|
||||
// decodeTestImages returns the two test images (Abbey Road llamas, docs llamas).
|
||||
func decodeTestImages(t *testing.T) (abbeyRoad, docs api.ImageData) {
|
||||
// decodeTestImages returns the test images.
|
||||
func decodeTestImages(t *testing.T) (abbeyRoad, docs, ollamaHome api.ImageData) {
|
||||
t.Helper()
|
||||
var err error
|
||||
abbeyRoad, err = base64.StdEncoding.DecodeString(imageEncoding)
|
||||
@@ -32,9 +35,35 @@ func decodeTestImages(t *testing.T) (abbeyRoad, docs api.ImageData) {
|
||||
if err != nil {
|
||||
t.Fatalf("decode docs image: %v", err)
|
||||
}
|
||||
ollamaHome, err = base64.StdEncoding.DecodeString(imageEncodingOllamaHome)
|
||||
if err != nil {
|
||||
t.Fatalf("decode ollama home image: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// skipIfNoVisionOverride skips the entire test (at parent level) when
|
||||
// OLLAMA_TEST_MODEL is set to a non-vision model. This prevents the parent
|
||||
// test from reporting PASS when all subtests are skipped.
|
||||
func skipIfNoVisionOverride(t *testing.T) {
|
||||
t.Helper()
|
||||
if testModel == "" {
|
||||
return
|
||||
}
|
||||
// Check actual model capabilities via the API rather than a hardcoded list.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: testModel})
|
||||
if err != nil {
|
||||
return // let the test proceed and fail naturally
|
||||
}
|
||||
if len(resp.Capabilities) > 0 && !slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||
t.Skipf("model override %q does not have vision capability (has %v)", testModel, resp.Capabilities)
|
||||
}
|
||||
}
|
||||
|
||||
// setupVisionModel pulls the model, preloads it, and skips if not GPU-loaded.
|
||||
func setupVisionModel(ctx context.Context, t *testing.T, client *api.Client, model string) {
|
||||
t.Helper()
|
||||
@@ -54,6 +83,7 @@ func setupVisionModel(ctx context.Context, t *testing.T, client *api.Client, mod
|
||||
// handles cached image tokens across turns.
|
||||
func TestVisionMultiTurn(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
// Models that fail on multi-turn detail questions (e.g. misidentifying objects).
|
||||
skipModels := map[string]string{
|
||||
@@ -72,7 +102,7 @@ func TestVisionMultiTurn(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
abbeyRoad, _ := decodeTestImages(t)
|
||||
abbeyRoad, _, _ := decodeTestImages(t)
|
||||
|
||||
// Turn 1: describe the image
|
||||
req := api.ChatRequest{
|
||||
@@ -100,7 +130,7 @@ func TestVisionMultiTurn(t *testing.T) {
|
||||
api.Message{Role: "user", Content: "How many animals are in the image?"},
|
||||
)
|
||||
resp2 := DoChat(ctx, t, client, req, []string{
|
||||
"four", "4",
|
||||
"four", "4", "three", "3",
|
||||
}, 60*time.Second, 30*time.Second)
|
||||
if resp2 == nil {
|
||||
t.Fatal("no response from turn 2")
|
||||
@@ -121,6 +151,7 @@ func TestVisionMultiTurn(t *testing.T) {
|
||||
// TestVisionObjectCounting asks the model to count objects in an image.
|
||||
func TestVisionObjectCounting(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
skipModels := map[string]string{
|
||||
"llama3.2-vision": "consistently miscounts (says 3 instead of 4)",
|
||||
@@ -137,7 +168,7 @@ func TestVisionObjectCounting(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
_, docs := decodeTestImages(t)
|
||||
_, docs, _ := decodeTestImages(t)
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
@@ -160,6 +191,7 @@ func TestVisionObjectCounting(t *testing.T) {
|
||||
// cultural references and scene context from an image.
|
||||
func TestVisionSceneUnderstanding(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
// Models known to be too small or not capable enough for cultural reference detection.
|
||||
skipModels := map[string]string{
|
||||
@@ -178,7 +210,7 @@ func TestVisionSceneUnderstanding(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
abbeyRoad, _ := decodeTestImages(t)
|
||||
abbeyRoad, _, _ := decodeTestImages(t)
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
@@ -193,7 +225,7 @@ func TestVisionSceneUnderstanding(t *testing.T) {
|
||||
Options: map[string]any{"temperature": 0.0, "seed": 42},
|
||||
}
|
||||
DoChat(ctx, t, client, req, []string{
|
||||
"abbey road", "beatles", "abbey",
|
||||
"abbey road", "beatles", "abbey", "llama",
|
||||
}, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
@@ -203,6 +235,7 @@ func TestVisionSceneUnderstanding(t *testing.T) {
|
||||
// objects based on their spatial position in the image.
|
||||
func TestVisionSpatialReasoning(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
for _, model := range testModels(defaultVisionModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
@@ -212,7 +245,7 @@ func TestVisionSpatialReasoning(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
_, docs := decodeTestImages(t)
|
||||
_, docs, _ := decodeTestImages(t)
|
||||
|
||||
// The docs image has: leftmost llama on laptop with glasses,
|
||||
// rightmost llama sleeping.
|
||||
@@ -239,6 +272,7 @@ func TestVisionSpatialReasoning(t *testing.T) {
|
||||
// small details like accessories in an image.
|
||||
func TestVisionDetailRecognition(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
for _, model := range testModels(defaultVisionModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
@@ -248,7 +282,7 @@ func TestVisionDetailRecognition(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
_, docs := decodeTestImages(t)
|
||||
_, docs, _ := decodeTestImages(t)
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
@@ -274,6 +308,7 @@ func TestVisionDetailRecognition(t *testing.T) {
|
||||
// encoding and cross-image reasoning.
|
||||
func TestVisionMultiImage(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
// Multi-image support varies across models.
|
||||
skipModels := map[string]string{
|
||||
@@ -291,7 +326,7 @@ func TestVisionMultiImage(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
abbeyRoad, docs := decodeTestImages(t)
|
||||
abbeyRoad, docs, _ := decodeTestImages(t)
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
@@ -314,10 +349,12 @@ func TestVisionMultiImage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestVisionOCR tests text extraction from an image. The docs image
|
||||
// contains the text "Ollama's documentation" in a header.
|
||||
func TestVisionOCR(t *testing.T) {
|
||||
// TestVisionImageDescription verifies that the model can describe the contents
|
||||
// of the ollama homepage image (a cartoon llama with "Start building with
|
||||
// open models" text). Basic sanity check that the vision pipeline works.
|
||||
func TestVisionImageDescription(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
for _, model := range testModels(defaultVisionModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
@@ -327,22 +364,22 @@ func TestVisionOCR(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
_, docs := decodeTestImages(t)
|
||||
_, _, ollamaHome := decodeTestImages(t)
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What text appears in this image? Read all visible text.",
|
||||
Images: []api.ImageData{docs},
|
||||
Content: "Describe what you see in this image briefly.",
|
||||
Images: []api.ImageData{ollamaHome},
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{"temperature": 0.0, "seed": 42},
|
||||
}
|
||||
DoChat(ctx, t, client, req, []string{
|
||||
"ollama", "documentation",
|
||||
"llama", "animal", "build", "model", "open", "cartoon", "character",
|
||||
}, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -383,3 +383,162 @@ yEUu0pztbKtys2RR9bUiUBGoCFQE5oTAL3/5y+ab3/xmc9JJJzWf+cxnmq9+9atzKXmuDGQuNaqFVAQq
|
||||
VBGoCFQElgKBykCWoptqJSsCFYGKwOIhUBnI4vVJrVFFoCJQEVgKBCoDWYpuqpWsCFQEKgKLh0BlIIvXJ7VGFYGKQEVgKRDYOWr5q6Woaa1kRaAiUBGoCCwU
|
||||
Av8fgwPy24mbuF8AAAAASUVORK5CYII=
|
||||
`
|
||||
// imageEncodingOllamaHome is a 415x293 JPEG of the ollama.com homepage.
|
||||
// Shows a cartoon llama character with text "Start building with open models".
|
||||
const imageEncodingOllamaHome = `/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA0JCgsKCA0LCgsODg0PEyAVExISEyccHhcgLikxMC4pLSwzOko+MzZGNywtQFdBRkxO
|
||||
UlNSMj5aYVpQYEpRUk//2wBDAQ4ODhMREyYVFSZPNS01T09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09P
|
||||
T09PT09PT0//wAARCAElAZ8DASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUF
|
||||
BAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVW
|
||||
V1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi
|
||||
4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAEC
|
||||
AxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVm
|
||||
Z2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq
|
||||
8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD06iiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiq
|
||||
2o39rpllLeXsqxQRDLMf5e5oAs0V5XffEXXL6WeXQdOC2dsNzu8ZchfVuwrufCOvDxFocd80YjlDFJUHQMPT270AbdFFFABRRRQA
|
||||
UUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUU
|
||||
AFFFFABRRRQAUUUUAFeUeI7u68b+L00PT5CLG2chnHTj7zn+Q/8Ar12XjTxLZ6LpNzD9pUX8sRWGIcsCRjJ9BXmPg/xPJ4djuDa6
|
||||
V9rnnI3SliMKO3A9eaAO/wDFyWHhfwDPYWSLGJlECDu5PUn1OM1V+HuoaVovhWFb7UbWGa4kaUo0oyAeBkduBXMXc2t/EfV44obc
|
||||
W1vbLyGJKR56knHJPpXT2fwr0mOMfa7y6mkxyVwg/Ac0AdpZ6lY3wzZ3kE//AFzkDfyq1XmupfDF7YfafD2oypcJyqSnBP0YdKzU
|
||||
+IWuabp1xpV/bltUjPlpM45X13DufQ96APQtf8U6T4fTF9PmYjKwxjc5/Dt+Nc7pnxP0291GO1ns5rZJWCrKzhgCemR2qn4V8Atd
|
||||
v/a/ikvNPMd4t3Y9+7n19qzfHsVrfeLNM0PSoIkaHCMIkAwWI449AM/jQB61RSKNqgegxS0AFFFFABRRRQAUUUUAFFFFABRRRQAU
|
||||
UUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFBIAyTgCsifxBbCYw2UU17KvUQLkD8e
|
||||
lAGvRSIxZFYqVJGSD2paACub8b+JB4c0fzItrXk5KQKegPdj7D/Cukry7xWn9s/FLT9LnOYItgK9iMbz+fSgCfwh4I/tEDW/Exe4
|
||||
luD5iQyE8g/xP/hXosFtb20Qit4I4kHRUUAD8qkAAAAGBXC6d4pvtL8XXGieIZ45Yp5M2064AXJ4U47duehoA7pUVc7VAzycDrS1
|
||||
FPc29uM3E8cQ9XcL/OiG5t7gZt54pR/sOG/lQBLXn/xH8OXt3dWesaNbNLdQnEojGWOOVOO+OlegUUAeXr8ULuCxuLfUNMMeoouI
|
||||
yMhd3+0p5HrV34ceHbgzSeJNWDNc3GTCH64PVz9e3tW94z8LW3iDTJGWNVv4lJhlA5J/un1BrK+F2tzXumTaXeMTPYkBN3XYeMfg
|
||||
ePyoA7qiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKK
|
||||
KACiimyIJInjJIDKRkdRmgDAmeTXrmWJZTDpduSJXBwZiOoz6ClsLq4uGWLQ7OGGwjbBmkBG/wBdoHX61dk0dP7DOl20rRJtC78Z
|
||||
JGcnP1qle+IbDSsWNvG0hiXZ8mAF9s+tAHK6/wCKdc1nxDJofhTKCElXmXGWI6nJ+6oPFVk8ReKvCGoxReJQ13Zyn7+Q3Hcqw7j0
|
||||
NSfCJkbUdXZv9aQpyeuMnP613+u6Pa65pcthdr8rjKt3RuzCgC1Z3UF7aRXVrIJIZVDIw7g15p468zQvH2na8ULQPtLY9V4Yf98k
|
||||
VF4X1u58FaxLoGvZW0LZSTshP8Q/2T+n51FbxXHxF8XySTM6aTaHgA9FzwB/tNjJNAHpd9Aut6G8VpevCl1GCk8J5APORXC63oWk
|
||||
eCdIW/hha+1SSQJBLcfMFfruC9OP8K9Gt4Ira3jggRY4o1Coq9FA6CuT+Jem3F74fS6tFLS2Mon2gZyuOfy60AYkfhCxAt7rxnqs
|
||||
0t/ek7IzLtXdjO3d/wDqFY+k6XpOrXwttEfUtK1VC/RvMiUr0y4wRmug8RahB4s8F295ZR/aHtpo5Lq3UZkUDhgO/wCPpVXQtK/t
|
||||
PxBJP4b/ALQ0jRgi+a24r5zjoFB/Xr+tAGx4V8SajHqz+HPEqhb9B+5m7TD+vHQ96u6z480PSpXt/Oe6uUO0xQLuwfQnpWN43eKb
|
||||
xx4dgsyDfRygvt6qm4EZ/JjS6h4avtE8XQa1oNot1b3MmLi3IB8sk8kE9B3z2+lAHdWdwt3Zw3Ko6LKgcK4wwyM4I9a828G4i+KO
|
||||
sxQcRHzuB0++K7/XdWg0XSJ7+5YARr8o7u3YD8a4j4U2E0smoa7cg7rhiiEj73OWP54FAHo9FFFABRRRQAUUUUAFFFFABRRRQAUU
|
||||
UUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFADJiwhcp94KcfXFcl4NgguJLuS4RZ
|
||||
JRj74zgHOf1rsK4u4EnhvxCZ1Um0nJOB3U9R9QeaAOdsivhL4pSwP+7s70kL2AV+V/JuK9WriviBoS+INCj1LTsSXNqpdCv/AC0T
|
||||
uPqOtWPh94mXXdJFvcv/AKdagLICeXXs3+PvQA34mWFnP4VuLueBWuLfHkydCuWA/L2pPhdaR2/hCKZQN9xI7ufXBwP0Fb/iHTf7
|
||||
Y0K80/IDTRkKT2bqP1ArB+G9lq2m6LNZarbGBYpj5O48kHr+Gen1oA6+ggEYNFFAHF6t4Bie+OoeH76TS7snJEedhP0HT+VVjonj
|
||||
9h5B8QWwj6eYBhv/AEHNd7RQBzHhfwdb6HO99c3D3uoyZ3Tyds9cf4109FFAHlfxYiu11ewmupXfTGGFjU42sD834kdDXpWlwWtv
|
||||
pltFYIEtljXygP7uMiuT+LMaN4TV2A3JcptP1Brd8Hu0nhLS2fkm2T+VAGzRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUU
|
||||
AFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAVWv7GDULVre4XKnoR1U+oqzRQBxSSaj4WuSki+dZ
|
||||
O3H90/T0PtXK69GNF1mPxP4bfEJfM8GMGJj1BH91v5/hXrsscc0bRyorowwVYZBrl9W8HxTK5sGChgQ0Mn3SPQHtQBpaV4k07UtB
|
||||
OrrMscEa5mDHmIjqD/nmrOjavY63Yi806XzItxU5UggjsRXiOv6VqXh2WSzJmitrz+DPD4OcH1we9e0eGNKi0bQLSyjXDKgaQ/3n
|
||||
PJNAGrRRRQAUUUUAFFFFAHmvxW1H7XNY6BaZknaQSOq9ieFH6k13+lWY0/SrWzHSCJY/yFcJ8U9GEKQeIrMmO5idUlZT1/ut9QeP
|
||||
yrtfD2o/2toNlfnAaeIFsf3uh/UGgDRooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA
|
||||
KKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigDzD4rZGuaKW/1fP/oQzXpy4IBHTtXC/FnTHutBgvolJazky+OytwT+YFdH
|
||||
4T1P+1/DdleEEM0e18/3l4P8qANiiiigAooooArajfW+m2E17dvshhUsx/w968zOveL/ABldSJoKNZ2SHG5Ttx/vP6+wrqfibb3F
|
||||
x4On+zgt5ciSSAd0B5/ofwqD4d65pEvh610+KaKC6hXbJE5Clmzyw9c0Ac3efD3xPPaO0+sJcSEZ8lpnIb8TxWr8M9ddQ/hq/i8m
|
||||
4tN3l5GCQD8yn3BP5V6CzKqFmYBQMkk8AV5Xpk0er/F57zTObeMszyL0YBNpP4mgD1WiiigAooooAKKKKACiiigAooooAKKKKACi
|
||||
iigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKAGyIkiMkiK6MMFWGQR9KSKK
|
||||
OGJYoY1jjUYVVGAPoKfRQAUUUUAFFFFACMoZSrAEEYIPeuJ1r4aaTfyvPYSyWMrHO1BuTP07fga7eigDy9vhtrxUwHX1NueCpaTG
|
||||
P93pXaeF/DFj4aszFbZknkwZZmHL+3sPatyigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiii
|
||||
gAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAoo
|
||||
ooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAK
|
||||
KKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKA
|
||||
CiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiii
|
||||
gAooooAKZPNHbwSTTOEjjUszHoAOSafWR4tt5rrwrqcNsC0rW7bQOp4zj8qAPPdQ8a+IvEeqNY+GIpIoudvlqPMYf3mY8KPy+tJN
|
||||
p/xIsIzdfabuXb8xVbhZSP8AgPf8Ki+F2vaZpVxeW2oSJbtc7DHM/C8Z+Unt1zXrcM0U8YkgkSRD0ZGBH5igDiPAvjmTWbn+zNWV
|
||||
EvcExyKNolx1BHZu/wCddJ4n16Pw7pQv5YHmXzFj2owB5zzz9Kz/APhB9LXxH/bkU11Hced52xWUJu78Yzg/XvXH/EvUtdke5sJ7
|
||||
DZpUcyGK48phuO3+9nB5J/KgD0Lw3rcfiDSU1CKF4VZ2XYxBPB9q1a8j8Cax4lt4LKysdM83THuQHn8hmwCw3fMDjiu+8W+J7bwz
|
||||
p6zSoZriUkQwg43EdST2AoA3qK8qi8W+OtRiN5Y6Xm2PK+XallI9iTk/hWz4P8ftq2oLper26W92xKo6AhWYfwkHoaAO8orkvH/i
|
||||
W+8N2tnLYJAzTOyt5qk8AA8YIrnLjx9r1/b28OhWHnzrCrXMkUDSAOR90DsB70Aej6neLp+mXV66F1t4mlKg4JwM4rF8JeLYPFBu
|
||||
hDaSQfZtmd7A53Z9PpWB4w1jxGmgW0cOn+ZDdaduvZDA37pivzd/lxz1ri/Bmq+INMN5/YGn/a/M2eb+5aTbjOOh46mgD3aq2oyP
|
||||
Fpt1JG210hdlI7EKafZPLLYwSXCbJnjVpFxjaxHIx9ai1b/kE3n/AFwf/wBBNAHC/DLXtV1i/vo9SvZLhY4lZQwHBJ9hXoteH+Bv
|
||||
ENv4cGpXUqGWZ4kSGEHBds/yrYu/Gfje2T7dPpYgtTyN9owQD3JOaAPWKK5nwZ4vg8TW8iNGIL2EZkiByCP7y+38qxvGXjPVND8S
|
||||
RafZx2zQvGjEyIS2SSD0I9KAO/oorgPHPjPVPD+vQ2VjHbNE8CyEyoSclmHYj0oA3/G76zH4ekbQBIbneu7yhlwnOdvv0/Wl8Evr
|
||||
L+HYm14SC63tt8wYcp23D16/pSeNdautC8Ptf2SxNKJEXEikjB+hFL4K1m617w+l9erEsrSOpEYIGAfcmgDkPAviLWNS8YTWd9fy
|
||||
TW6pKQjAYBBGOgr06vHfht/yPtx/1zm/9CFeg+MPFVt4ZskZk866mz5UOcZx1JPYfzoA6GuA+KGuapoz6aNMvHt/NEm/aB82NuOo
|
||||
9zWNaeM/G18DeWemCa2B6R2jMh9s5z+tZXjnxJD4jstLlEZguYDKk8JP3T8uCPY4P5UAenm41SfwNHc2DeZqUlijoxAyzlQSfTPX
|
||||
8azfh5L4jktLv/hIRcbA6+QbhcP33e+OlXY72XTfh5Be24UywaajqHGRkIOtU/h/4mv/ABJb3sl+kCmB0C+UpHUHrkn0oA6+ivNr
|
||||
r4hXlh4vubG9S3Gn280iMyxkyEAHAHOMk4FVb3xn4ynja+stHe3sfvK32Zn+X1LH+YGKAPU6K4nwN45OvznT9QijivApZGjyFkA6
|
||||
8Hoe9b3ifxDa+HNMN3cgu7HbFEpwXb+g9TQBsUV5Nb+N/GWqyvNpenK8KHlYrZnA9i3rVu++Jd5HpC7LSK21WKcJPBMjEbcH5gMg
|
||||
jkDg9KAPTqKyfCupT6v4cs9QuggmnUlggwvDEcflWpLv8pvKID4O0sOM9qAHUVwngjxlqOta5c6bqsVvG8cZZfKUg7lYAg5J9f0r
|
||||
b8ba9L4e0Bry2EbXDyLHGJBkZPJ4HsDQB0FFcl4A8U3PiS1u/tywrcW7rxECAVI44JPcGmfEDxXdeG47JLBYXmnLFhKpICjHoR3P
|
||||
6UAXvHL61H4eZtAEhuPMXf5Qy4TnO33zj8Kk8FvrD+HYW14OLrc2PMGH2dtw9f8A61Z+v+IdU0jwTaaqUt/t0vl+YrIdg3AkjGf6
|
||||
1DZeJNbv/AX9r2dpFPqJlKrFHEzAgNg8Zz096ALEfji3fxWdA+wyiQTmHzd4xkd8V1MzFYHZTghSR+VeBR3+rr4yN8lnnVftBf7P
|
||||
5Z+/zkbc5r17wxqGsajodzNrtn9kuFdlVPLKZXaOcE+pNAHL/DPxDq+r61dQ6lfSXEaW5dVYDg7gM8D3r0qvC/Auuw+H7y+vJUMs
|
||||
jW/lwxL1kcuuBW5f+M/G1ov2y50wW1sTwHtWCj0BJOaAPWKK5zwb4rh8TWTsYxDdwECWIHI56MPb+VdHQAUUUUAFFFZHiqbUrbw7
|
||||
d3Gjvtu4VDr8gbIB+YYPtmgDC8QfDjS9Vne5s5XsZ3JLbFDIx9dvb8DXJ3Hw88TaUxn0q7SUryPIlMT/AJHH863vBPj+O8SS18Q3
|
||||
kcdzvzFM4CIy+mRwCD6+tdnPrWlW8Jmm1G0SMDO4zL/jQB554O8canDrEejeIC0m+TyVkkXEkb5wA3qM8c81vfFf/kUB/wBfKfyN
|
||||
cNdTJ4o+JUc2lxnypbiMhsYJVAMv7cKTXc/FZSfB+QOlyhP60ASfC3/kTYf+u0n86t+L9P8ADMqRX3iVgqxgpGTKy574CqeTVD4V
|
||||
3EL+ElhWRDJHM4dc8jJyOK5D4nSNJ41hhvXdbVI49uOyE/MR79fyoA6pviZ4bto1ighvHRAFUJEAAB0AyRXCXGq22rfEW21KwieG
|
||||
Oa8gIVwAc5UE8epFer2ln4XstNWa3h0xLVVyJSEII9Sx615Tf6ha6n8R4LqxQLbNeQrHhdoIUqM498ZoA634xf8AIP0z/rq/8hXS
|
||||
eAbWG18Haf5KBTLH5rkdWYnqf5fhXN/GL/kH6Z/12f8AkK6rwV/yJ+lf9e60AT+Kv+RV1b/r0l/9BNcL8Gvvav8ASH/2eu78UKW8
|
||||
L6qqjJNpLx/wE15/8HbiGO51OB5FWSRY2RScFgN2cfmKAPVKqat/yCbz/rg//oJq3VTVv+QTef8AXB//AEE0AeS/CrTYL3xHJc3C
|
||||
BxaRb0BGfnJwD+HNexyIkkbRyKGRgQysMgj0NeH/AA912HQtfL3h22twnlSPjhDnKk+3H617Be6/pNlYteT6hb+SFyCsgYt7ADqa
|
||||
APK9Aj/sP4qfY7c4hFy8AH+wwOB/L8qk+J3/ACPFt/1wi/8AQjTPBizeIPiK+qFCsaSPcv8A7IOQo/Mj8jUvxYikg8U2t1j5Ht12
|
||||
ntlWOR+o/OgD2CvHvix/yN1r/wBeqf8AobV6bp/iDStQ0+O9hvrcRsoZg0gBT1DA9CK8f8eazb634qM9m2+CFFhR+z4JJI9sk0Ae
|
||||
hfFL/kTW/wCu8f8AWl+Fv/InRf8AXaT+dJ8Uv+RNb/rvH/Wl+Fv/ACJ0X/XaT+dAHHfDf/kfbj/rnN/6EKZ41B1f4mJp8jHy/Mht
|
||||
h7KcE/qxp/w2/wCR9uP+uc3/AKEKPiJBPo/jqHVkQlZTHPGexZMAj9B+dAHr0EMVvAkECLHFGoVFUYCgdBXkvxb02C11m1vYUCNd
|
||||
xt5mB1ZSOfrgj8q9K03xDpOpWCXdvfQBCuWDyBWT2YHpXk/xJ1+31vWYksW8y1tEKCUdHYnLEe3QUAegXv8AySs/9gpf/RYrD+Dn
|
||||
/Hnqn/XSP+TVuXv/ACSs/wDYKX/0WKw/g5/x56p/10j/AJNQBzwtIb74tSW1wgeJr9yynocZOD+Ve0dBXj1j/wAlkb/r+l/k1exd
|
||||
qAPGtPhSy+LoitwERb1wqjgAEHj9am+LNxJceJbSzB+SK3UqP9pmOT+gpsf/ACWM/wDX8f8A0Grfxd02WPULLVUU+W8fksw/hYEk
|
||||
fmCfyoA9M0uwg0vToLK2QLFCgUYHX1P1PWvPfjBp0CxWOpogWZnMLsP4hjIz9MH866vw74t0vV9Lime9ghuAg86KSQKVbv16j3rg
|
||||
vif4jtdWmt7DTpBNBasWklTlS54AB74Gfz9qAO7+H3/IkaZ/uN/6G1dHXOfD7/kSNM/3G/8AQ2ro6APJbhf+Ef8Ai/G4+WG5nDex
|
||||
Eowf/HifyrQ+JrvqfiDRtBhPLsGYD1Zto/IA0nxds2ifTdWh4ZGMLN6H7y/+zVF4YnHif4lz6wATBbQ7kyOh2hQPzLGgA8Mxr4d+
|
||||
KN7pSjZb3SsIl7YI3r+mRVbxoDrvxKs9KBykflxMPQH52P5H9Kv/ABJjOl+JdF1+MYCuFkI/2Gz+oJH4VW8BL/bfj7VNbOTHGXdC
|
||||
R0LnC/8AjoNAG/8AFUAeD8AYH2iP+tTfC/8A5Eu3/wCusn/oVRfFb/kUP+3mP+tSfC//AJEu3/66yf8AoVAHFW//ACWI/wDX+/8A
|
||||
I17Bcf8AHtJ/uH+VeO+bHa/F1pLh1jQX5yzHAGen8xXsMxDWshBBGw9PpQB498KLKG68UvLMgY21u0keR0bIGfyJr2G5t4ru2ltr
|
||||
hA8UqlHU9CCMGvJvg/8A8jDef9eh/wDQ1r16gDx34WlrfxlcQKx2mCRT74Yf4V7FXjvw2/5Hy4/65Tf+hCvYqACiiigAooooA4zX
|
||||
Phxo+qXD3Ns8ljM5y3lAFCfXaen4EVjx/CWMPmXWXZPRbcA/nur0uigDF8O+FtL8OxsLCJjM4w80hy7D09h7Cr+qadbatp01jepv
|
||||
gmXDAHBHcEe4PNW6KAPPrH4YQ2Or217Dq0hSCZZRG0IydpBxkH29K6XxN4W07xJAi3geOaP/AFc0f3l9vce1blFAHndp8KLGO4D3
|
||||
epTTxA58tYwmfqcmtS68AWE2vQanDcSQLA0RSBEG0BMYHr2rsKKAMDxX4Xh8TQW8U9zJAIGLAooOcjHetPSNPTStKtrCORpFt4wg
|
||||
ZhgnFXKKAEdFdGR1DKwwQehFee3/AMKrKa5aSx1KW2jY5EbRiTb7A5HFeh0UAQ2cH2Wygtg2/wAmNU3YxnAxmluoRc2ssBYqJUZC
|
||||
R2yMVLRQBxulfDvTLG3vLe4nlu4rpFUh1ClCDkMCOhrKb4TWpnJXV5hD2Uwgt+ecfpXo9FAGXoGg6f4fsvs2nxEBjl5GOXc+pNJ4
|
||||
h8P2HiKx+y36N8p3RyIcMh9R/hWrRQB5xH8JrUTgy6vM0WeVWEK355P8q0tU+HGmX0tsbe4ltI7eERKiKDnBJ3EnqSTXa0UAZPiT
|
||||
Q4vEGknT5p3hUur7kAJ4+tL4b0SPw/pK2EMzzKrs+5wAefpWrRQBynh/wRbaHrb6nFezSu6suxlAHzHPatrW9EsNdsTaajFvTOVZ
|
||||
ThkPqD2rRooA83/4VNa+fkavN5Ofu+SN355x+la2pfDrSruws7O2mltY7XecqAzSM2Mlie/y12VFAGZNo0cvhr+xDM4j+zC38zA3
|
||||
YC4zjpVLwn4Wg8MRXMcFzJOLhlJ3qBjGfT610FFAHKQ+CLaHxYdfF7MZTM0vlFRtyQeM9e9dX2oooA5RfBFsviv+3/ts3m+cZfK2
|
||||
jbnGMZ61Y8Ya1o2m2iWmu2001vdqQAse5TjHfIweQa6OqOr6TY61YtZ6hCJImOR2Kn1B7GgDgrP4feG9ZiW90vVbg20nzbAVYp7H
|
||||
IyD9ayvH0Gh6NpVpoejlWmWbzp2Dbm4Ugbj689O1a1z8J4/NJstYkjQ/wyQ7jj6gjP5VpaF8NNL025S5vp3v5EOVVlCJn1I5z+Jx
|
||||
QBueCbaSz8IaZDMpVxDuIPUbiW/rW5RRQBl+I9Eg8QaS+n3EjRhmVw6gEqQff8R+NU/CfhS28MR3K288k7XBUszqAQBnA4+proKK
|
||||
AMfxP4ft/EemCyuJGiCyCRXQAkEZHf2JqLwp4YtvDNrPDbzPMZnDs7gA8DAHH4/nW7RQBkeJtCi8RaV9gmneFfMV9yAE8Z9frT/D
|
||||
mix6BpCafDM8yIzNucAHk57VqUUAcj4o8BWHiC9N8tw9pdMAHZVDK+OASOOfxq/4V8NDw7pE1h9rNz5shkL7NuMqBjGT6Vv0UAct
|
||||
4V8FW3hq/lu4LyadpYvLKuoAHIOePpXU0UUAcp4f8EW2ha0+pxXs0rurLsZQB8xz2rq6KKACiiigAooooAKKKKACiiigAooooAKK
|
||||
hvLuCxtJbq6kEcMKl3Y9gK80uviHrmrX7W/hrTMoM4JjMkhHqQOFoA9Rory23+IWvaRfLB4l0zCN1xGY3A9Rng/55r0qyvYL+xiv
|
||||
LSQSQzJvRh3FAFiivLdG+J10Zbp9Yit/KihLRJCpVpJNwAXJJ4wSfwqK78ZeNljN+NK8iz6jNqxUD1JPP48UAer0VyXgrxrD4kD2
|
||||
1xEtvfRruKKcrIvqv+Fa/iTxBaeHdNN3d5dmO2KJTzI3p7D1NAGtRXlUPjDxtq+660rSx9mB48u3Lj6bieT9K1/CvxAe+1FdK122
|
||||
W1u2bYjgFQW/usp5U0Ad9RXN+Otdu/D2hpe2KxNI06xkSqSMEE9iPSuWi+Ier32m29vpenLd6q4ZpvLiYpENxAwM8nGD1xQB6bRX
|
||||
kqfELxLpOoLHrlguw8tE8JifHqp//XXqWn3sGpWEF7atuhnQOh9j6+9AFiiivOdT8f3um+MptOuFtl0+GYK7eWS+3GT36/hQB6NR
|
||||
XleoeNfF8kbahaaS1tp33lZrdnG31LH+YwK0rD4mwSaBLPdWw/tKNhGlvGTiUnoR3A4569vWgD0KivJr7xr41sdt5eaatvbMeBJa
|
||||
sF+mSc/rXeeEfEsPiXSzcJH5U8TbJos52nsQfQ0AbtFcn4y8bW/hsrawRC5vnXdsJwqD1Y/0rlB4t8dm3+3jS/8ARsbs/ZG249eu
|
||||
ce9AHq9Fcj4N8cW/iNzaXEQtr5VLBAcrIB1K+/tWz4j1608PaW17d5Y52xxqeZG9P/r0Aatcl8SdUvtJ8PQ3GnXLQStcqhZQORtY
|
||||
45+grkoPHPjDVpnk0nTleJDysVuZAPYt6/lVfxZ4sOu+GBY39sbTU7a6QyREEBhtbkA8jqOD60AeheBL661LwlZ3d9M008hfc7Yy
|
||||
cOQOnsK6CuW+Gv8AyI9h9ZP/AEY1Y3iX4hTQ6k2leHLVbq4VijSFS4Ldwqjr9aAPQqD0ryuTxn4z0YpPrGlg27HB8yAp+G4dD9a9
|
||||
B8P65aeINLS+syQD8ro33o27g0AcKviLWD8T/wCzDfyfYvtZTycDG3HTpmvTR0rx5P8Aksf/AG/H/wBBr0Lxb4nt/DOnJNJGZriY
|
||||
lYYgcbiOpJ7AUAb9FeXQ+JvH2owfbrLS0+zNyuy3yGHtk5P4VueDfHX9t3h0zU7dba/AO3bkK5HUYPII9PrQB2tFFFABRRRQAUUU
|
||||
UAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAcH8XLuSHw9bWqHAuLj5/cKM4/PH5Vo/DbTobLwjbSoo826zLI2OTyQB+AH86rfFLTJ
|
||||
b7wwLiBSzWcolYAZ+TBBP4cH8Kr/AA08SWU2gxaVcXEcV1a5VVdgvmITkEZ64zjHtQBqfEXT4b7whdvIgMlsPOjbHKkHn8xmsf4R
|
||||
3ckug3tq7ZW3mynsGHT8wfzqf4k+JLK30CbTILiOW7ugE2IwbYucknHTpj8ad8LNMlsvDMl1MpVryTegIx8gGAfx5oA4f4babDqP
|
||||
i5PtCB0to2n2sMgkEAfqc/hXtxAKkEZB6g14N4J1qPQvE0d3cZ+zuGimYDO1T3/AgV7TLrukw2RvH1K1Fvt3bxKCCPbHX6UAeUPC
|
||||
nh/4sRw2g8uIXiBVHQLIBkfTDEVN8VrszeKoLWVm8i3hXhevzHJP1xj8qh0d38VfE5b6KNhELgXByPuxpjbn8lH41ofFaxmtNfs9
|
||||
XRN0UiKpJGQHQ5wfqP5GgDWtvib4ftLaO2t9PvkhiUKihEwAP+BVxnjbX9O17VLfUNMgngmVNsrSBQWIPyngnn/61esaNfeH9YsI
|
||||
7q1Sy+ZQXQqgaM9wRWDrnjDQtN1NLCy0uDUZW4byQmAxOAucHJ+lAEPxHuDd/D/T7liCZpYZDj3jY1f+FlpDB4RjuEQCW5kdpG7n
|
||||
BKgfp+tV/iqMeDrcbBHi5j+QdF+VuK0Phn/yJFl/vSf+hmgDP+LdtFJ4ZhuGUeZFcqFbuAQcj9B+VXvhi7N4JtAxzteQD6bzVf4r
|
||||
/wDIoD/r5T+Rqb4X/wDIlW3/AF0k/wDQjQB1x6V4xqVtFd/FtredQ8T3qBlPQjA4r2c9K8euP+SyD/r+T+QoA9fZVZCjKCpGCCOC
|
||||
K8a+H9pA3xAZGjBW381owecEHA/LNez9q8f+Hv8AyUO5/wB2f/0KgD07xHDHceHNSilUMhtZOD6hSQfzFeffBtj9p1Vc8FIjj8Wr
|
||||
0XXf+QDqP/XrL/6Ca85+Dn/H3qv/AFzi/m1AGZoUS+IPifJJegSR+fJKVIyCEztH04X8q9nrxdpG8HfEt57pW+zGZmzjrFJnkeuM
|
||||
/pXrQ1jTDZ/bBqFr9n27vM81cYoA8m8TRJoHxLjnsgI1MsU4UDgbvvD6Hn86ufGC6d9asbPPyRW5kA92Yj+SiqU0p8ZfEmN7NWNt
|
||||
5qYbHSJMZY+mefzFavxg0+QXVjqaqTGUMDn+6QSw/PJ/KgD0XRNOh0rSLayt0CpFGAcfxHHJPuTXC/F/TYfsNnqioBMJfIdh1ZSC
|
||||
Rn6YP510/hfxRp2saPBIbqGO5RAs0TuFZWA5OD2PXNcR8UvEVpqIt9LsJVnSB/MmkQ5UNjAUHv1NAG54au3sPhI11EcSRQTlD6He
|
||||
2P1rivAviLSvDlzc3WoW1xNcSKEiaNVOwfxdSOvH5V3XhCyOpfCxbEHDTwzop9CXbH61y3w41Gw03UrzS9bjhiaVgEadR8jrkFST
|
||||
0z/SgDb1D4k+H9QsJ7O4sL5op0KMCid/+BVl/B+6ddWv7Pd8kkAlx7qwGf8Ax6u+1S78P6VYvd3a2Soq5UBEJc+gHc1n+DPENt4g
|
||||
e5ktNHFmkICmUFfmJ/h4A9M/lQBxKf8AJY/+34/+g1vfFfRb2+trPULSJ5ktgyyooyVBwQ2PTjn8KwU/5LH/ANvx/wDQa7fxZ4yH
|
||||
hi6t4ZdOknSdCyyLIFGQcEdPp+dAHP8Ah74mWEGn21nqlpNE8Max+ZCAykAYzjgj9a3tItvCOt6yda0xklv1bzWIkdWBxjJQ/wCF
|
||||
XTpvhnxHZLetZ2dxHKu7zQArD6kYINeX20EOm/Ey3t/D87SwJdoisrbvlON657gfMPwoA9vooooAKKKKACiiigAooooAKKKKACii
|
||||
igAooooAKKKKACiiigBGVXUqwBUjBBHBrhtX+GGlXtw01jcS2Jc5MaqHQH2HBH513VFAHB6T8L9Ks7hZr+5lvdpBEZUIhPuOSfzr
|
||||
ugqpFsRQqqMAAYAFOoIyCKAPFfhpY22p67f2V7EJYJbJwyn/AH059j710k3wntGuC0GqzRw54RogzD8cj+VbPhXwPD4b1SS+jv5L
|
||||
gyRGLa0YXGSDnr7V1tAGP4c8Nad4ctmisUYySY8yaQ5d/wDAe1XtS0601Wyks7+FZoJByp/mD2PvVqigDzm6+E9m8pa01WaKMnhZ
|
||||
Ig5H45FbnhvwHpWg3C3eXu7tfuySgAJ7qo6H35rqqKAMfxRoEXiPTEsZp3hVZRJuQAngEY5+tS+HtHj0HR4dOileVIixDsACcknt
|
||||
9a06KAMjxNoUXiLSxYTTvCvmCTcgBPGfX60/w5osegaRHp8MzzIjMwZwAeTntWpRQAVykngi2fxYNfN7MJfOEvlbRtyB0z1rq6KA
|
||||
CuV0LwTbaLrsmqxXs0ruHBRlAHzHPauqooAhvbcXdjcWrMVE0bRlh1GRjP61geE/CFv4YluZILuWc3CqCHUDGM+n1rpaKAMjxD4b
|
||||
03xFbrFqER3p/q5UOHT6H09jXG/8Klt/Nz/bEvl/3fIGfzz/AEr0migDH8O+GdM8OwNHYRkyP/rJpDl3/HsPYVf1GwtdTspLO+hW
|
||||
WCUYZT/Meh96s0UAec3Hwns2nLW2qzRxE/ceIOQPrkfyrUl+HWlf2ENMglmiJlWWSfAZ5CAQAewHzHgV2VFAGb4f0iPQtGh02KVp
|
||||
Uh3YdgATlie31rI8SeBdK1+c3TF7W7b70sQGH/3gev1rqaKAPObb4T2iTBrnVZpY8/dSIISPqSf5V3emabZ6TYpZ2EKwwp0A7n1J
|
||||
7n3q3RQByg8EWw8V/wBv/bZvN87zfK2jbnHTPWtrW9EsNdsTaajFvTOVYHDIfUHtWjRQB5vJ8J4PMPk6zMkZ/haEE4+oI/lXSeGf
|
||||
Bel+HXM8Iee7Ix50uMqO4UDgfzrpKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiig
|
||||
AooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooo
|
||||
oAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKK
|
||||
KKACiiigAooooAKKKKACiiigAooooAKKKKAP/9k=`
|
||||
|
||||
121
llama/patches/0035-CUDA-get_rows-q6_k-support.patch
Normal file
121
llama/patches/0035-CUDA-get_rows-q6_k-support.patch
Normal file
@@ -0,0 +1,121 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Hiltgen <daniel@ollama.com>
|
||||
Date: Fri, 20 Mar 2026 18:50:38 -0700
|
||||
Subject: [PATCH] CUDA get_rows q6_k support
|
||||
|
||||
---
|
||||
ggml/src/ggml-cuda/getrows.cu | 80 ++++++++++++++++++++++++++++++++-
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 1 +
|
||||
2 files changed, 80 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
|
||||
index 2fab33243..dc5c4f57a 100644
|
||||
--- a/ggml/src/ggml-cuda/getrows.cu
|
||||
+++ b/ggml/src/ggml-cuda/getrows.cu
|
||||
@@ -155,6 +155,81 @@ static void get_rows_cuda_float(
|
||||
s10, s11, s12/*, s13*/);
|
||||
}
|
||||
|
||||
+// Specialized GET_ROWS kernel for Q6_K — the k_get_rows template doesn't work for K-quants
|
||||
+// because they lack the simple dequantize_kernel_t (float2) interface.
|
||||
+// Based on dequantize_block_q6_K from convert.cu with row-selection logic added.
|
||||
+template<typename dst_t>
|
||||
+static __global__ void k_get_rows_q6_K(
|
||||
+ const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
+ const int64_t ne00,
|
||||
+ const int64_t ne11, const int64_t ne12,
|
||||
+ const size_t s1, const size_t s2, const size_t s3,
|
||||
+ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
+ const size_t s10, const size_t s11, const size_t s12) {
|
||||
+
|
||||
+ const int64_t i10 = blockIdx.x; // row index into src1
|
||||
+ const int64_t z = blockIdx.z;
|
||||
+ const int64_t i11 = z / ne12;
|
||||
+ const int64_t i12 = z % ne12;
|
||||
+
|
||||
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
+
|
||||
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||
+ const char * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
+
|
||||
+ const int64_t nb = ne00 / QK_K; // number of Q6_K blocks per row
|
||||
+
|
||||
+ // blockIdx.y iterates over Q6_K blocks within the row
|
||||
+ for (int64_t iblk = blockIdx.y; iblk < nb; iblk += gridDim.y) {
|
||||
+ const block_q6_K * x = (const block_q6_K *)src0_row + iblk;
|
||||
+
|
||||
+ // Same dequantization as dequantize_block_q6_K (assumes 64 threads)
|
||||
+ const int64_t tid = threadIdx.x;
|
||||
+ const int64_t ip = tid / 32; // 0 or 1
|
||||
+ const int64_t il = tid - 32*ip; // 0..31
|
||||
+ const int64_t is = 8*ip + il/16;
|
||||
+
|
||||
+ const int64_t y_offset = iblk * QK_K + 128*ip + il;
|
||||
+
|
||||
+ const float d = x->d;
|
||||
+ const uint8_t * ql = x->ql + 64*ip + il;
|
||||
+ const uint8_t qh = x->qh[32*ip + il];
|
||||
+ const int8_t * sc = x->scales + is;
|
||||
+
|
||||
+ if (y_offset + 0 < ne00) dst_row[y_offset + 0] = ggml_cuda_cast<dst_t>(d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32));
|
||||
+ if (y_offset + 32 < ne00) dst_row[y_offset + 32] = ggml_cuda_cast<dst_t>(d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32));
|
||||
+ if (y_offset + 64 < ne00) dst_row[y_offset + 64] = ggml_cuda_cast<dst_t>(d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32));
|
||||
+ if (y_offset + 96 < ne00) dst_row[y_offset + 96] = ggml_cuda_cast<dst_t>(d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32));
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+template<typename dst_t>
|
||||
+static void get_rows_cuda_q6_K(
|
||||
+ const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
|
||||
+ const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
|
||||
+ const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
+ cudaStream_t stream) {
|
||||
+ const int64_t nb_blocks = ne00 / QK_K;
|
||||
+ const dim3 block_dims(64, 1, 1);
|
||||
+ const dim3 block_nums(ne10, MIN(nb_blocks, (int64_t)UINT16_MAX), MIN(ne11*ne12, (int64_t)UINT16_MAX));
|
||||
+
|
||||
+ const size_t s1 = nb1 / sizeof(dst_t);
|
||||
+ const size_t s2 = nb2 / sizeof(dst_t);
|
||||
+ const size_t s3 = nb3 / sizeof(dst_t);
|
||||
+
|
||||
+ const size_t s10 = nb10 / sizeof(int32_t);
|
||||
+ const size_t s11 = nb11 / sizeof(int32_t);
|
||||
+ const size_t s12 = nb12 / sizeof(int32_t);
|
||||
+
|
||||
+ k_get_rows_q6_K<<<block_nums, block_dims, 0, stream>>>(
|
||||
+ src0_d, src1_d, dst_d,
|
||||
+ ne00, ne11, ne12,
|
||||
+ s1, s2, s3,
|
||||
+ nb01, nb02, nb03,
|
||||
+ s10, s11, s12);
|
||||
+}
|
||||
+
|
||||
template <typename dst_t>
|
||||
static void ggml_cuda_get_rows_switch_src0_type(
|
||||
const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
|
||||
@@ -199,8 +274,11 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
||||
get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
+ case GGML_TYPE_Q6_K:
|
||||
+ get_rows_cuda_q6_K(src0_d, src1_d, dst_d,
|
||||
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
+ break;
|
||||
default:
|
||||
- // TODO: k-quants
|
||||
GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
|
||||
break;
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 5c9dfd032..b8ed3709b 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -4693,6 +4693,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
+ case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -678,3 +678,113 @@ func ImageEditsMiddleware() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// TranscriptionWriter collects streamed chat responses and outputs a transcription response.
|
||||
type TranscriptionWriter struct {
|
||||
BaseWriter
|
||||
responseFormat string
|
||||
text strings.Builder
|
||||
}
|
||||
|
||||
func (w *TranscriptionWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
var chatResponse api.ChatResponse
|
||||
if err := json.Unmarshal(data, &chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.text.WriteString(chatResponse.Message.Content)
|
||||
|
||||
if chatResponse.Done {
|
||||
text := strings.TrimSpace(w.text.String())
|
||||
|
||||
if w.responseFormat == "text" {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/plain")
|
||||
_, err := w.ResponseWriter.Write([]byte(text))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
resp := openai.TranscriptionResponse{Text: text}
|
||||
if err := json.NewEncoder(w.ResponseWriter).Encode(resp); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// TranscriptionMiddleware handles /v1/audio/transcriptions requests.
|
||||
// It accepts multipart/form-data with an audio file and converts it to a chat request.
|
||||
func TranscriptionMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Parse multipart form (limit 25MB).
|
||||
if err := c.Request.ParseMultipartForm(25 << 20); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to parse multipart form: "+err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
model := c.Request.FormValue("model")
|
||||
if model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
file, _, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "file is required: "+err.Error()))
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
audioData, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, "failed to read audio file"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(audioData) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "audio file is empty"))
|
||||
return
|
||||
}
|
||||
|
||||
req := openai.TranscriptionRequest{
|
||||
Model: model,
|
||||
AudioData: audioData,
|
||||
ResponseFormat: c.Request.FormValue("response_format"),
|
||||
Language: c.Request.FormValue("language"),
|
||||
Prompt: c.Request.FormValue("prompt"),
|
||||
}
|
||||
|
||||
chatReq, err := openai.FromTranscriptionRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
c.Request.ContentLength = int64(b.Len())
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := &TranscriptionWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
responseFormat: req.ResponseFormat,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,6 +137,7 @@ type Tensor interface {
|
||||
|
||||
Bytes() []byte
|
||||
Floats() []float32
|
||||
BackendGet() []float32
|
||||
|
||||
FromBytes([]byte)
|
||||
FromFloats([]float32)
|
||||
@@ -162,6 +163,7 @@ type Tensor interface {
|
||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
|
||||
Conv1DDW(ctx Context, weight Tensor, s, p, d int) Tensor
|
||||
SSMConv(ctx Context, kernel Tensor) Tensor
|
||||
SSMScan(ctx Context, x, dt, A, B, C, ids Tensor) Tensor
|
||||
|
||||
@@ -187,6 +189,9 @@ type Tensor interface {
|
||||
Contiguous(ctx Context, shape ...int) Tensor
|
||||
|
||||
Pad(ctx Context, shape ...int) Tensor
|
||||
// PadExt pads with independent left/right amounts per dimension.
|
||||
// Arguments: lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 for dims 0-3.
|
||||
PadExt(ctx Context, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 int) Tensor
|
||||
|
||||
Stack(ctx Context, dim int, s ...Tensor) Tensor
|
||||
|
||||
|
||||
@@ -1069,6 +1069,21 @@ func (t *Tensor) Floats() (data []float32) {
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Tensor) BackendGet() []float32 {
|
||||
n := int(C.ggml_nelements(t.t))
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if t.sync != nil {
|
||||
t.sync()
|
||||
}
|
||||
|
||||
data := make([]float32, n)
|
||||
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
|
||||
return data
|
||||
}
|
||||
|
||||
func tensorSet[S ~[]E, E byte | float32 | int32](t *Tensor, s S) {
|
||||
if len(s) == 0 {
|
||||
return
|
||||
@@ -1313,6 +1328,13 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) PadExt(ctx ml.Context, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_pad_ext(ctx.(*Context).ctx, t.t, C.int(lp0), C.int(rp0), C.int(lp1), C.int(rp1), C.int(lp2), C.int(rp2), C.int(lp3), C.int(rp3)),
|
||||
}
|
||||
}
|
||||
|
||||
// Permute permutes t according to order. Permute panics if the number of dimensions
|
||||
// in order does not match the number of dimensions in t.
|
||||
func (t *Tensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
|
||||
@@ -1660,6 +1682,13 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Conv1DDW(ctx ml.Context, weight ml.Tensor, s, p, d int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_conv_1d_dw(ctx.(*Context).ctx, weight.(*Tensor).t, t.t, C.int(s), C.int(p), C.int(d)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor {
|
||||
var tt ml.Tensor = &Tensor{
|
||||
b: t.b,
|
||||
|
||||
80
ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
vendored
80
ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
vendored
@@ -155,6 +155,81 @@ static void get_rows_cuda_float(
|
||||
s10, s11, s12/*, s13*/);
|
||||
}
|
||||
|
||||
// Specialized GET_ROWS kernel for Q6_K — the k_get_rows template doesn't work for K-quants
|
||||
// because they lack the simple dequantize_kernel_t (float2) interface.
|
||||
// Based on dequantize_block_q6_K from convert.cu with row-selection logic added.
|
||||
template<typename dst_t>
|
||||
static __global__ void k_get_rows_q6_K(
|
||||
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00,
|
||||
const int64_t ne11, const int64_t ne12,
|
||||
const size_t s1, const size_t s2, const size_t s3,
|
||||
const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12) {
|
||||
|
||||
const int64_t i10 = blockIdx.x; // row index into src1
|
||||
const int64_t z = blockIdx.z;
|
||||
const int64_t i11 = z / ne12;
|
||||
const int64_t i12 = z % ne12;
|
||||
|
||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
|
||||
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||
const char * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
|
||||
const int64_t nb = ne00 / QK_K; // number of Q6_K blocks per row
|
||||
|
||||
// blockIdx.y iterates over Q6_K blocks within the row
|
||||
for (int64_t iblk = blockIdx.y; iblk < nb; iblk += gridDim.y) {
|
||||
const block_q6_K * x = (const block_q6_K *)src0_row + iblk;
|
||||
|
||||
// Same dequantization as dequantize_block_q6_K (assumes 64 threads)
|
||||
const int64_t tid = threadIdx.x;
|
||||
const int64_t ip = tid / 32; // 0 or 1
|
||||
const int64_t il = tid - 32*ip; // 0..31
|
||||
const int64_t is = 8*ip + il/16;
|
||||
|
||||
const int64_t y_offset = iblk * QK_K + 128*ip + il;
|
||||
|
||||
const float d = x->d;
|
||||
const uint8_t * ql = x->ql + 64*ip + il;
|
||||
const uint8_t qh = x->qh[32*ip + il];
|
||||
const int8_t * sc = x->scales + is;
|
||||
|
||||
if (y_offset + 0 < ne00) dst_row[y_offset + 0] = ggml_cuda_cast<dst_t>(d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32));
|
||||
if (y_offset + 32 < ne00) dst_row[y_offset + 32] = ggml_cuda_cast<dst_t>(d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32));
|
||||
if (y_offset + 64 < ne00) dst_row[y_offset + 64] = ggml_cuda_cast<dst_t>(d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32));
|
||||
if (y_offset + 96 < ne00) dst_row[y_offset + 96] = ggml_cuda_cast<dst_t>(d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void get_rows_cuda_q6_K(
|
||||
const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
|
||||
const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
|
||||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
cudaStream_t stream) {
|
||||
const int64_t nb_blocks = ne00 / QK_K;
|
||||
const dim3 block_dims(64, 1, 1);
|
||||
const dim3 block_nums(ne10, MIN(nb_blocks, (int64_t)UINT16_MAX), MIN(ne11*ne12, (int64_t)UINT16_MAX));
|
||||
|
||||
const size_t s1 = nb1 / sizeof(dst_t);
|
||||
const size_t s2 = nb2 / sizeof(dst_t);
|
||||
const size_t s3 = nb3 / sizeof(dst_t);
|
||||
|
||||
const size_t s10 = nb10 / sizeof(int32_t);
|
||||
const size_t s11 = nb11 / sizeof(int32_t);
|
||||
const size_t s12 = nb12 / sizeof(int32_t);
|
||||
|
||||
k_get_rows_q6_K<<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne11, ne12,
|
||||
s1, s2, s3,
|
||||
nb01, nb02, nb03,
|
||||
s10, s11, s12);
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void ggml_cuda_get_rows_switch_src0_type(
|
||||
const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
|
||||
@@ -199,8 +274,11 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
||||
get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
get_rows_cuda_q6_K(src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
default:
|
||||
// TODO: k-quants
|
||||
GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -4693,6 +4693,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -47,6 +47,12 @@ type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// PostLoader is an optional interface that models can implement to run
|
||||
// initialization steps after backend weights have been loaded.
|
||||
type PostLoader interface {
|
||||
PostLoad() error
|
||||
}
|
||||
|
||||
// MultimodalProcessor must be implemented by multimodal models.
|
||||
type MultimodalProcessor interface {
|
||||
// EncodeMultimodal processes a single input (such as an image) and
|
||||
|
||||
@@ -68,6 +68,8 @@ func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor
|
||||
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
|
||||
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Conv1DDW(ctx ml.Context, _ ml.Tensor, _, _, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) PadExt(ctx ml.Context, _, _, _, _, _, _, _, _ int) ml.Tensor { return f }
|
||||
|
||||
func (m *fakeBackend) Get(name string) ml.Tensor {
|
||||
if slices.Contains(m.names, name) {
|
||||
|
||||
265
model/models/gemma4/model.go
Normal file
265
model/models/gemma4/model.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
*AudioModel `gguf:"a"`
|
||||
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
*AudioMultimodalProjector `gguf:"mm.a"`
|
||||
|
||||
ImageProcessor
|
||||
|
||||
imageTokenID int32
|
||||
imageEndTokenID int32
|
||||
audioTokenID int32
|
||||
audioEndTokenID int32
|
||||
|
||||
audioOpts *AudioModelOptions
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
type MultiModalProjector struct {
|
||||
Projection *ClippableLinear `gguf:"input_projection"`
|
||||
}
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
||||
visionOutputs = p.Projection.Forward(ctx, visionOutputs)
|
||||
// Post-projection RMSNorm without learned weight
|
||||
visionOutputs = visionOutputs.RMSNorm(ctx, nil, eps)
|
||||
return visionOutputs
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
|
||||
// Gemma 4 uses BPE with SentencePiece-style ▁ space markers (not GPT-2 byte-level encoding).
|
||||
// The tokenizer.json has merges and a Replace normalizer (space → ▁), with no pre-tokenizer.
|
||||
t := tokenizer.NewBytePairEncodingWithOptions(&vocabulary, []string{},
|
||||
tokenizer.WithSentencePieceNormalizer())
|
||||
|
||||
// Look up special token IDs for vision and audio
|
||||
imageTokenID := int32(-1)
|
||||
imageEndTokenID := int32(-1)
|
||||
audioTokenID := int32(-1)
|
||||
audioEndTokenID := int32(-1)
|
||||
for i, tok := range vocabulary.Values {
|
||||
switch tok {
|
||||
case "<|image>":
|
||||
imageTokenID = int32(i)
|
||||
case "<image|>":
|
||||
imageEndTokenID = int32(i)
|
||||
case "<|audio>":
|
||||
audioTokenID = int32(i)
|
||||
case "<audio|>":
|
||||
audioEndTokenID = int32(i)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("gemma4: token IDs", "image", imageTokenID, "image_end", imageEndTokenID, "audio", audioTokenID, "audio_end", audioEndTokenID)
|
||||
|
||||
m := Model{
|
||||
Tokenizer: t,
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
AudioModel: newAudioModel(c),
|
||||
MultiModalProjector: &MultiModalProjector{},
|
||||
AudioMultimodalProjector: &AudioMultimodalProjector{},
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
imageTokenID: imageTokenID,
|
||||
imageEndTokenID: imageEndTokenID,
|
||||
audioTokenID: audioTokenID,
|
||||
audioEndTokenID: audioEndTokenID,
|
||||
audioOpts: newAudioModelOptions(c),
|
||||
}
|
||||
|
||||
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWAMemCache(slidingWindowLen, 4096, m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
// Audio input: detect WAV format and route to audio encoder.
|
||||
if isAudioData(multimodalData) {
|
||||
return m.encodeAudioMultimodal(ctx, multimodalData)
|
||||
}
|
||||
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Info("vision: decode", "elapsed", time.Since(t0), "bounds", img.Bounds())
|
||||
|
||||
t1 := time.Now()
|
||||
f32s, imgW, imgH, err := m.ImageProcessor.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Info("vision: preprocess", "elapsed", time.Since(t1), "size", [2]int{imgW, imgH})
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s, imgW, imgH, m.ImageProcessor.numChannels)
|
||||
slog.Info("vision: pixelValues", "shape", pixelValues.Shape(), "dim0", pixelValues.Dim(0), "dim1", pixelValues.Dim(1), "dim2", pixelValues.Dim(2))
|
||||
|
||||
numPatchesX := imgW / m.ImageProcessor.patchSize
|
||||
numPatchesY := imgH / m.ImageProcessor.patchSize
|
||||
slog.Info("vision: patches", "patchesX", numPatchesX, "patchesY", numPatchesY, "total", numPatchesX*numPatchesY, "patchSize", m.ImageProcessor.patchSize)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, numPatchesX, numPatchesY)
|
||||
visionOutputs = visionPoolAndProject(ctx, visionOutputs, numPatchesX, numPatchesY, m.VisionModel.VisionModelOptions, m.MultiModalProjector, m.VisionModel.StdBias, m.VisionModel.StdScale)
|
||||
slog.Info("vision: encoded", "elapsed", time.Since(t0), "shape", visionOutputs.Shape())
|
||||
|
||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostLoad() error {
|
||||
m.VisionModel.InitClamp(m.MultiModalProjector)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
|
||||
if m.AudioModel == nil || m.audioOpts == nil {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
samples, err := decodeWAV(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Info("audio: decode", "elapsed", time.Since(t0), "samples", len(samples), "duration_s", float64(len(samples))/audioSampleRate)
|
||||
|
||||
// Pad waveform to next multiple of 128.
|
||||
if rem := len(samples) % 128; rem != 0 {
|
||||
samples = append(samples, make([]float32, 128-rem)...)
|
||||
}
|
||||
|
||||
// Compute mel spectrogram.
|
||||
melData, numFrames := computeMelSpectrogram(samples)
|
||||
if numFrames == 0 {
|
||||
return nil, fmt.Errorf("audio too short to encode")
|
||||
}
|
||||
slog.Info("audio: mel", "frames", numFrames, "elapsed", time.Since(t0))
|
||||
|
||||
// Create input tensor [melBins, numFrames] (GGML ne order). FromFloats creates F32.
|
||||
melTensor := ctx.Input().FromFloats(melData, melBins, numFrames)
|
||||
|
||||
// Run audio encoder.
|
||||
audioOutputs := m.AudioModel.ForwardAudio(ctx, melTensor, m.AudioMultimodalProjector, m.audioOpts)
|
||||
slog.Info("audio: encoded", "elapsed", time.Since(t0), "shape", audioOutputs.Shape())
|
||||
|
||||
return []input.Multimodal{{Tensor: audioOutputs, Data: audioTag{}}}, nil
|
||||
}
|
||||
|
||||
// audioTag marks multimodal data as audio (vs vision) for PostTokenize.
|
||||
type audioTag struct{}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
continue
|
||||
}
|
||||
|
||||
inputMultimodal := inp.Multimodal[0].Tensor
|
||||
numTokens := inputMultimodal.Dim(1)
|
||||
|
||||
// Determine if this is audio or vision based on the tag.
|
||||
_, isAudio := inp.Multimodal[0].Data.(audioTag)
|
||||
|
||||
var beginToken, endToken int32
|
||||
if isAudio {
|
||||
beginToken = m.audioTokenID
|
||||
endToken = m.audioEndTokenID
|
||||
} else {
|
||||
beginToken = m.imageTokenID
|
||||
endToken = m.imageEndTokenID
|
||||
}
|
||||
|
||||
if beginToken >= 0 {
|
||||
result = append(result, &input.Input{Token: beginToken, SameBatch: numTokens + 2})
|
||||
}
|
||||
|
||||
result = append(result,
|
||||
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash},
|
||||
)
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, numTokens-1)...)
|
||||
|
||||
if endToken >= 0 {
|
||||
result = append(result, &input.Input{Token: endToken})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
|
||||
hiddenState = m.TextModel.Output.Forward(ctx, hiddenState)
|
||||
|
||||
if m.TextModel.TextOptions.finalLogitSoftcap > 0.0 {
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextModel.TextOptions.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Tanh(ctx)
|
||||
hiddenState = hiddenState.Scale(ctx, float64(m.TextModel.TextOptions.finalLogitSoftcap))
|
||||
}
|
||||
|
||||
return hiddenState, nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeBase, ropeDims := m.TextModel.ropeForLayer(layer)
|
||||
return nn.RoPE(ctx, key, shift, ropeDims, ropeBase, 1.0, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma4", New)
|
||||
}
|
||||
611
model/models/gemma4/model_audio.go
Normal file
611
model/models/gemma4/model_audio.go
Normal file
@@ -0,0 +1,611 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
// AudioModel holds the audio encoder and configuration.
|
||||
type AudioModel struct {
|
||||
// SSCP: Sub-Sample Convolution Projection.
|
||||
SSCPConv0 *AudioConvBlock `gguf:"conv1d.0"`
|
||||
SSCPConv1 *AudioConvBlock `gguf:"conv1d.1"`
|
||||
|
||||
// SSCP output projection (linear).
|
||||
SSCPInputProj *nn.Linear `gguf:"pre_encode.out"`
|
||||
|
||||
// Conformer blocks.
|
||||
Layers []AudioConformerBlock `gguf:"blk"`
|
||||
|
||||
// Output projection to embedder dimension.
|
||||
OutputProj *AudioOutputProj `gguf:"output_proj"`
|
||||
|
||||
AudioModelOptions
|
||||
}
|
||||
|
||||
type AudioOutputProj struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
// AudioModelOptions holds audio model hyperparameters.
|
||||
type AudioModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
headDim int
|
||||
ffnSize int
|
||||
numLayers int
|
||||
melBins int
|
||||
chunkSize int
|
||||
maxPast int
|
||||
maxFuture int
|
||||
contextSize int
|
||||
logitCap float32
|
||||
residualWeight float32
|
||||
gradClip float32
|
||||
convKernelSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
// AudioConvBlock is a single 2D convolution block for the SSCP.
|
||||
type AudioConvBlock struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Norm *nn.LayerNorm `gguf:"norm"`
|
||||
}
|
||||
|
||||
// AudioConformerBlock is a single conformer layer.
|
||||
// All tensors are flat at the block level (a.blk.N.<name>) using underscore naming.
|
||||
type AudioConformerBlock struct {
|
||||
// Block-level norm
|
||||
Norm *nn.RMSNorm `gguf:"layer_pre_norm"`
|
||||
|
||||
// FFW start
|
||||
FFWNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
FFWUp *AudioClippableLinear `gguf:"ffn_up"`
|
||||
FFWDown *AudioClippableLinear `gguf:"ffn_down"`
|
||||
FFWPostNorm *nn.RMSNorm `gguf:"ffn_post_norm"`
|
||||
|
||||
// FFW end
|
||||
FFWNorm1 *nn.RMSNorm `gguf:"ffn_norm_1"`
|
||||
FFWUp1 *AudioClippableLinear `gguf:"ffn_up_1"`
|
||||
FFWDown1 *AudioClippableLinear `gguf:"ffn_down_1"`
|
||||
FFWPostNorm1 *nn.RMSNorm `gguf:"ffn_post_norm_1"`
|
||||
|
||||
// Attention
|
||||
AttnQ *AudioClippableLinear `gguf:"attn_q"`
|
||||
AttnK *AudioClippableLinear `gguf:"attn_k"`
|
||||
AttnV *AudioClippableLinear `gguf:"attn_v"`
|
||||
AttnOut *AudioClippableLinear `gguf:"attn_out"`
|
||||
AttnPreNorm *nn.RMSNorm `gguf:"ln1"`
|
||||
AttnPostNorm *nn.RMSNorm `gguf:"ln2"`
|
||||
LinearPos ml.Tensor `gguf:"linear_pos.weight"`
|
||||
PerDimScale ml.Tensor `gguf:"per_dim_scale.weight"`
|
||||
|
||||
// LightConv1d
|
||||
ConvPW1 *AudioClippableLinear `gguf:"conv_pw1"`
|
||||
ConvPW2 *AudioClippableLinear `gguf:"conv_pw2"`
|
||||
ConvDW ml.Tensor `gguf:"conv_dw.weight"`
|
||||
ConvNorm *nn.RMSNorm `gguf:"conv_norm"`
|
||||
NormConv *nn.RMSNorm `gguf:"norm_conv"`
|
||||
}
|
||||
|
||||
// AudioClippableLinear is a linear layer with optional input/output clamping.
|
||||
type AudioClippableLinear struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
InputMin ml.Tensor `gguf:"input_min"`
|
||||
InputMax ml.Tensor `gguf:"input_max"`
|
||||
OutputMin ml.Tensor `gguf:"output_min"`
|
||||
OutputMax ml.Tensor `gguf:"output_max"`
|
||||
|
||||
// Cached scalar clamp values (populated on first forward).
|
||||
inMin, inMax, outMin, outMax float32
|
||||
clampsLoaded bool
|
||||
}
|
||||
|
||||
func (l *AudioClippableLinear) loadClamps() {
|
||||
if l.clampsLoaded {
|
||||
return
|
||||
}
|
||||
l.clampsLoaded = true
|
||||
if l.InputMin != nil {
|
||||
vals := l.InputMin.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.inMin = vals[0]
|
||||
}
|
||||
}
|
||||
if l.InputMax != nil {
|
||||
vals := l.InputMax.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.inMax = vals[0]
|
||||
}
|
||||
}
|
||||
if l.OutputMin != nil {
|
||||
vals := l.OutputMin.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.outMin = vals[0]
|
||||
}
|
||||
}
|
||||
if l.OutputMax != nil {
|
||||
vals := l.OutputMax.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.outMax = vals[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *AudioClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
|
||||
l.loadClamps()
|
||||
if l.inMax != 0 {
|
||||
x = x.Clamp(ctx, l.inMin, l.inMax)
|
||||
}
|
||||
out := l.Weight.Mulmat(ctx, x)
|
||||
if l.Bias != nil {
|
||||
out = out.Add(ctx, l.Bias)
|
||||
}
|
||||
if l.outMax != 0 {
|
||||
out = out.Clamp(ctx, l.outMin, l.outMax)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AudioMultimodalProjector is the audio-to-text embedding projector.
|
||||
type AudioMultimodalProjector struct {
|
||||
Projection *AudioClippableLinear `gguf:"input_projection"`
|
||||
FC *AudioFC `gguf:"fc"`
|
||||
}
|
||||
|
||||
type AudioFC struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (p *AudioMultimodalProjector) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
|
||||
// FC: output projection from conformer to embedder dimension.
|
||||
x = p.FC.Weight.Mulmat(ctx, x)
|
||||
if p.FC.Bias != nil {
|
||||
x = x.Add(ctx, p.FC.Bias)
|
||||
}
|
||||
// Pre-projection RMSNorm (without learned weight) — matches Python's embedding_pre_projection_norm.
|
||||
x = x.RMSNorm(ctx, nil, eps)
|
||||
// Embedding projection to text hidden size.
|
||||
x = p.Projection.Forward(ctx, x)
|
||||
return x
|
||||
}
|
||||
|
||||
// ForwardAudio encodes mel spectrogram features into soft tokens.
|
||||
// melFeatures: float32 tensor with ne[0]=melBins, ne[1]=numFrames.
|
||||
// Returns: [hiddenSize, numTokens] tensor.
|
||||
func (m *AudioModel) ForwardAudio(ctx ml.Context, melFeatures ml.Tensor, proj *AudioMultimodalProjector, opts *AudioModelOptions) ml.Tensor {
|
||||
// SSCP Conv2D input: ne[0]=F (freq/width), ne[1]=T (time/height), ne[2]=C_in, ne[3]=B
|
||||
// melFeatures is [melBins, numFrames], add channel and batch dims.
|
||||
x := melFeatures.Reshape(ctx, melFeatures.Dim(0), melFeatures.Dim(1), 1, 1)
|
||||
|
||||
// SSCP Conv block 0: [F, T, 1, 1] → [F', T', C0, 1]
|
||||
x = forwardConvBlock(ctx, m.SSCPConv0, x, opts)
|
||||
|
||||
// SSCP Conv block 1: [F', T', C0, 1] → [F'', T'', C1, 1]
|
||||
x = forwardConvBlock(ctx, m.SSCPConv1, x, opts)
|
||||
|
||||
// After conv blocks, layout is [F'', T'', C_out, B].
|
||||
// Permute to [C_out*F'', T'', B] for linear projection (channels+freq in ne[0]).
|
||||
fOut := x.Dim(0)
|
||||
tOut := x.Dim(1)
|
||||
cOut := x.Dim(2)
|
||||
// Permute [F'', T'', C, B] → [C, F'', T'', B]
|
||||
// (1,2,0,3): old[0]→pos1, old[1]→pos2, old[2]→pos0
|
||||
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
x = x.Reshape(ctx, cOut*fOut, tOut)
|
||||
|
||||
// Linear projection to hidden size.
|
||||
x = m.SSCPInputProj.Forward(ctx, x)
|
||||
|
||||
// Build causal-valid mask for conformer attention.
|
||||
causalMask := buildCausalValidMaskF32(opts.chunkSize, opts.maxPast, opts.maxFuture)
|
||||
|
||||
// Run conformer blocks.
|
||||
for i := range m.Layers {
|
||||
x = m.Layers[i].Forward(ctx, x, causalMask, opts, i)
|
||||
}
|
||||
|
||||
// Output projection.
|
||||
if m.OutputProj != nil {
|
||||
x = m.OutputProj.Weight.Mulmat(ctx, x)
|
||||
if m.OutputProj.Bias != nil {
|
||||
x = x.Add(ctx, m.OutputProj.Bias)
|
||||
}
|
||||
}
|
||||
|
||||
// Audio embedder: project to text embedding space.
|
||||
if proj != nil {
|
||||
x = proj.Forward(ctx, x, opts.eps)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardConvBlock runs a single SSCP Conv2D block.
|
||||
// Conv2D receiver is the kernel, argument is the input data.
|
||||
// Input: [F, T, C_in, B]. Output: [F', T', C_out, B].
|
||||
func forwardConvBlock(ctx ml.Context, block *AudioConvBlock, x ml.Tensor, opts *AudioModelOptions) ml.Tensor {
|
||||
// Conv2D: kernel.Conv2D(ctx, input, s0, s1, p0, p1, d0, d1)
|
||||
// Kernel is 3x3, stride 2x2, padding 1x1 (matching SSCP config).
|
||||
// Output layout: [F', T', C_out, B]
|
||||
// Make weight contiguous — the shape reversal in the converter creates
|
||||
// a tensor where the physical data order doesn't match ne[]/stride[].
|
||||
weight := block.Weight.Contiguous(ctx)
|
||||
x = weight.Conv2D(ctx, x, 2, 2, 1, 1, 1, 1)
|
||||
|
||||
// LayerNorm needs channels in ne[0]. Permute [F', T', C_out, B] → [C_out, F', T', B],
|
||||
// norm, then permute back.
|
||||
// GGML permute: axis i says where old axis i goes.
|
||||
// (1,2,0,3): old[0]→pos1, old[1]→pos2, old[2]→pos0 → [C_out, F', T', B]
|
||||
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
x = block.Norm.Forward(ctx, x, opts.eps)
|
||||
// (2,0,1,3): old[0]→pos2, old[1]→pos0, old[2]→pos1 → [F', T', C_out, B]
|
||||
x = x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
|
||||
x = x.RELU(ctx)
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward runs a single conformer block.
|
||||
func (cb *AudioConformerBlock) Forward(ctx ml.Context, x ml.Tensor, causalMask []float32, opts *AudioModelOptions, blockIdx int) ml.Tensor {
|
||||
// FFW start (half-residual).
|
||||
x = cb.forwardFFW(ctx, cb.FFWNorm, cb.FFWUp, cb.FFWDown, cb.FFWPostNorm, x, opts)
|
||||
|
||||
// Self-attention.
|
||||
x = cb.forwardAttention(ctx, x, causalMask, opts, blockIdx)
|
||||
|
||||
// Lightweight Conv1d.
|
||||
x = cb.forwardLightConv(ctx, x, opts, blockIdx)
|
||||
|
||||
// FFW end (half-residual).
|
||||
x = cb.forwardFFW(ctx, cb.FFWNorm1, cb.FFWUp1, cb.FFWDown1, cb.FFWPostNorm1, x, opts)
|
||||
|
||||
// Gradient clipping + final norm.
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = cb.Norm.Forward(ctx, x, opts.eps)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardFFW runs a feedforward module with half-residual connection.
|
||||
func (cb *AudioConformerBlock) forwardFFW(ctx ml.Context, preNorm *nn.RMSNorm, up, down *AudioClippableLinear, postNorm *nn.RMSNorm, x ml.Tensor, opts *AudioModelOptions) ml.Tensor {
|
||||
residual := x
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = preNorm.Forward(ctx, x, opts.eps)
|
||||
x = up.Forward(ctx, x)
|
||||
x = x.SILU(ctx)
|
||||
x = down.Forward(ctx, x)
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = postNorm.Forward(ctx, x, opts.eps)
|
||||
x = x.Scale(ctx, float64(opts.residualWeight))
|
||||
return residual.Add(ctx, x)
|
||||
}
|
||||
|
||||
// forwardAttention runs the conformer block-local attention with relative position embeddings.
|
||||
func (cb *AudioConformerBlock) forwardAttention(ctx ml.Context, x ml.Tensor, causalMask []float32, opts *AudioModelOptions, blockIdx int) ml.Tensor {
|
||||
residual := x
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = cb.AttnPreNorm.Forward(ctx, x, opts.eps)
|
||||
|
||||
hiddenSize := x.Dim(0)
|
||||
seqLen := x.Dim(1)
|
||||
|
||||
// QKV projections: [hiddenSize, seqLen] → [headDim, numHeads, seqLen]
|
||||
q := cb.AttnQ.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
|
||||
k := cb.AttnK.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
|
||||
v := cb.AttnV.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
|
||||
|
||||
// Per-dim scaling for queries: (headDim^-0.5 / log(2)) * softplus(per_dim_scale)
|
||||
// per_dim_scale is already softplus'd from the converter.
|
||||
qScale := float64(math.Pow(float64(opts.headDim), -0.5)) / math.Log(2)
|
||||
q = q.Scale(ctx, qScale)
|
||||
if cb.PerDimScale != nil {
|
||||
q = q.Mul(ctx, cb.PerDimScale)
|
||||
}
|
||||
|
||||
// Key scaling: softplus(1) / log(2) — matches the query base scaling convention.
|
||||
kScale := math.Log(1+math.E) / math.Log(2)
|
||||
k = k.Scale(ctx, kScale)
|
||||
|
||||
// Build sinusoidal position embeddings for the block-local context.
|
||||
maxSpan := opts.maxPast + opts.maxFuture + 1 // 13 unique relative positions
|
||||
posEmb := cb.buildPositionEmbeddings(ctx, maxSpan, opts)
|
||||
// posEmb: [headDim, numHeads, maxSpan]
|
||||
|
||||
// Block-local attention: process chunks of size chunkSize.
|
||||
chunkSize := opts.chunkSize
|
||||
numChunks := (seqLen + chunkSize - 1) / chunkSize
|
||||
contextSize := opts.contextSize
|
||||
|
||||
// Pad q/k/v to multiple of chunkSize on the time dimension (dim 2).
|
||||
padT := numChunks*chunkSize - seqLen
|
||||
if padT > 0 {
|
||||
q = q.Pad(ctx, 0, 0, padT, 0)
|
||||
k = k.Pad(ctx, 0, 0, padT, 0)
|
||||
v = v.Pad(ctx, 0, 0, padT, 0)
|
||||
}
|
||||
paddedLen := numChunks * chunkSize
|
||||
|
||||
// Pad k/v for context extraction: add maxPast on left, (maxFuture+chunkSize-1) on right.
|
||||
// Use Pad (right) + PadExt (left) workaround since PadExt+Slice has issues.
|
||||
// Actually use Concat with zero tensors for reliable left-padding.
|
||||
padLeft := opts.maxPast
|
||||
padRight := opts.maxFuture + chunkSize - 1
|
||||
zeroLeft := ctx.Input().FromFloats(make([]float32, opts.headDim*opts.numHeads*padLeft), opts.headDim, opts.numHeads, padLeft)
|
||||
zeroRight := ctx.Input().FromFloats(make([]float32, opts.headDim*opts.numHeads*padRight), opts.headDim, opts.numHeads, padRight)
|
||||
kPadded := zeroLeft.Concat(ctx, k, 2).Concat(ctx, zeroRight, 2)
|
||||
vPadded := zeroLeft.Concat(ctx, v, 2).Concat(ctx, zeroRight, 2)
|
||||
|
||||
// Reshape q into chunks: [headDim, numHeads, numChunks, chunkSize]
|
||||
qChunked := q.Reshape(ctx, opts.headDim, opts.numHeads, numChunks, chunkSize)
|
||||
|
||||
// Process each chunk and collect results.
|
||||
chunkOutputs := make([]ml.Tensor, numChunks)
|
||||
for u := range numChunks {
|
||||
// Extract query block: [headDim, numHeads, 1, chunkSize] → [headDim, numHeads, chunkSize]
|
||||
qBlock := qChunked.Slice(ctx, 2, u, u+1, 1).Reshape(ctx, opts.headDim, opts.numHeads, chunkSize)
|
||||
|
||||
// Extract key/value context: [headDim, numHeads, contextSize]
|
||||
cStart := u * chunkSize // offset in kPadded (padLeft already accounts for left context)
|
||||
kCtx := kPadded.Slice(ctx, 2, cStart, cStart+contextSize, 1).Contiguous(ctx)
|
||||
vCtx := vPadded.Slice(ctx, 2, cStart, cStart+contextSize, 1).Contiguous(ctx)
|
||||
|
||||
// Content-content logits: qBlock^T @ kCtx → [chunkSize, contextSize] per head.
|
||||
// Mulmat(a, b) = a^T @ b. We want Q^T K, so: kCtx.Mulmat(qBlock) but that gives
|
||||
// [numHeads, chunkSize, contextSize] with wrong batching.
|
||||
// Instead: permute to [headDim, chunkSize, numHeads] and [headDim, contextSize, numHeads]
|
||||
// then Mulmat batches over numHeads.
|
||||
// GGML permute(0,2,1,3): old[0]→0, old[1]→2, old[2]→1
|
||||
qP := qBlock.Permute(ctx, 0, 2, 1, 3) // [headDim, chunkSize, numHeads]
|
||||
kP := kCtx.Permute(ctx, 0, 2, 1, 3) // [headDim, contextSize, numHeads]
|
||||
|
||||
termAC := kP.MulmatFullPrec(ctx, qP) // [contextSize, chunkSize, numHeads]
|
||||
|
||||
// Content-position logits: qBlock^T @ posEmb → [chunkSize, maxSpan] per head.
|
||||
pP := posEmb.Permute(ctx, 0, 2, 1, 3) // [headDim, maxSpan, numHeads]
|
||||
termBDRaw := pP.MulmatFullPrec(ctx, qP) // [maxSpan, chunkSize, numHeads]
|
||||
|
||||
// Relative shift: [maxSpan, chunkSize, numHeads] → [contextSize, chunkSize, numHeads]
|
||||
termBD := cb.relativeShiftGGML(ctx, termBDRaw, maxSpan, chunkSize, contextSize, opts.numHeads)
|
||||
|
||||
// Combined logits.
|
||||
logits := termAC.Add(ctx, termBD)
|
||||
|
||||
// Logit softcap: tanh(logits / cap) * cap
|
||||
logits = logits.Scale(ctx, 1.0/float64(opts.logitCap))
|
||||
logits = logits.Tanh(ctx)
|
||||
logits = logits.Scale(ctx, float64(opts.logitCap))
|
||||
|
||||
// Apply combined causal + validity mask.
|
||||
// causalMask [chunkSize * contextSize]: 1=causal-allowed, 0=masked.
|
||||
// Validity: context positions before the actual sequence start are invalid.
|
||||
// For chunk u, context position c corresponds to actual time: u*chunkSize + c - padLeft.
|
||||
// Valid if 0 <= actual_time < seqLen.
|
||||
// Mask tensor layout: [contextSize, chunkSize, 1] with ne[0]=contextSize contiguous.
|
||||
// Element at (context=j, chunk=i) is at flat index: i*contextSize + j.
|
||||
maskData := make([]float32, contextSize*chunkSize)
|
||||
for i := range chunkSize {
|
||||
for j := range contextSize {
|
||||
actualTime := u*chunkSize + j - padLeft
|
||||
causalOK := causalMask[i*contextSize+j] > 0
|
||||
validOK := actualTime >= 0 && actualTime < seqLen
|
||||
if causalOK && validOK {
|
||||
maskData[i*contextSize+j] = 0
|
||||
} else {
|
||||
maskData[i*contextSize+j] = -1e9
|
||||
}
|
||||
}
|
||||
}
|
||||
mask := ctx.Input().FromFloats(maskData, contextSize, chunkSize, 1) // 3D for broadcasting over numHeads
|
||||
logits = logits.Add(ctx, mask)
|
||||
|
||||
// Softmax over context dimension (dim 0 = contextSize).
|
||||
logits = logits.Softmax(ctx) // softmax over ne[0]=contextSize
|
||||
|
||||
// Weighted sum: logits^T @ vCtx.
|
||||
// logits: [contextSize, chunkSize, numHeads], vCtx: [headDim, numHeads, contextSize]
|
||||
// vCtx permuted: [headDim, contextSize, numHeads]
|
||||
vP := vCtx.Permute(ctx, 0, 2, 1, 3) // [headDim, contextSize, numHeads]
|
||||
// Weighted sum: for each head, value[headDim, contextSize] @ weights[contextSize, chunkSize]
|
||||
// = [headDim, chunkSize].
|
||||
// Mulmat(a, b) = a^T @ b. Need a=[contextSize, headDim, numHeads], b=[contextSize, chunkSize, numHeads].
|
||||
vPT := vP.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // [contextSize, headDim, numHeads]
|
||||
chunkOut := vPT.Mulmat(ctx, logits) // [headDim, chunkSize, numHeads]
|
||||
|
||||
// Permute back to [headDim, numHeads, chunkSize]
|
||||
chunkOut = chunkOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
chunkOutputs[u] = chunkOut
|
||||
}
|
||||
|
||||
// Concatenate chunk outputs along time dimension.
|
||||
var attnOut ml.Tensor
|
||||
if numChunks == 1 {
|
||||
attnOut = chunkOutputs[0]
|
||||
} else {
|
||||
attnOut = chunkOutputs[0]
|
||||
for _, co := range chunkOutputs[1:] {
|
||||
attnOut = attnOut.Concat(ctx, co, 2)
|
||||
}
|
||||
}
|
||||
|
||||
// Trim to original sequence length if we padded.
|
||||
if paddedLen > seqLen {
|
||||
attnOut = attnOut.Slice(ctx, 2, 0, seqLen, 1).Contiguous(ctx)
|
||||
}
|
||||
|
||||
// Reshape to [hiddenSize, seqLen] and project.
|
||||
attnOut = attnOut.Reshape(ctx, hiddenSize, seqLen)
|
||||
x = cb.AttnOut.Forward(ctx, attnOut)
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = cb.AttnPostNorm.Forward(ctx, x, opts.eps)
|
||||
|
||||
return residual.Add(ctx, x)
|
||||
}
|
||||
|
||||
// buildPositionEmbeddings builds sinusoidal position embeddings and projects through linear_pos.
|
||||
// Returns [headDim, numHeads, maxSpan] tensor.
|
||||
func (cb *AudioConformerBlock) buildPositionEmbeddings(ctx ml.Context, maxSpan int, opts *AudioModelOptions) ml.Tensor {
|
||||
halfDim := opts.hiddenSize / 2
|
||||
hiddenSize := opts.hiddenSize
|
||||
|
||||
// inv_timescales: exp(-i * log(10000) / max(D/2-1, 1))
|
||||
logInc := math.Log(10000.0) / math.Max(float64(halfDim-1), 1)
|
||||
|
||||
// Sinusoidal embeddings for relative positions [maxPast, maxPast-1, ..., -maxFuture].
|
||||
posData := make([]float32, hiddenSize*maxSpan)
|
||||
for p := range maxSpan {
|
||||
relPos := float64(opts.maxPast - p)
|
||||
for d := range halfDim {
|
||||
angle := relPos * math.Exp(float64(-d)*logInc)
|
||||
posData[p*hiddenSize+d] = float32(math.Sin(angle))
|
||||
posData[p*hiddenSize+halfDim+d] = float32(math.Cos(angle))
|
||||
}
|
||||
}
|
||||
|
||||
// Create [hiddenSize, maxSpan] input tensor.
|
||||
posEmb := ctx.Input().FromFloats(posData, hiddenSize, maxSpan)
|
||||
|
||||
// Project through linear_pos: [hiddenSize, maxSpan] → Mulmat → [numHeads*headDim, maxSpan]
|
||||
projPos := cb.LinearPos.Mulmat(ctx, posEmb)
|
||||
|
||||
// Reshape to [headDim, numHeads, maxSpan].
|
||||
return projPos.Reshape(ctx, opts.headDim, opts.numHeads, maxSpan)
|
||||
}
|
||||
|
||||
// relativeShiftGGML performs the relative shift to extract correct position logits.
|
||||
// Input: [maxSpan, chunkSize, numHeads]. Output: [contextSize, chunkSize, numHeads].
|
||||
func (cb *AudioConformerBlock) relativeShiftGGML(ctx ml.Context, x ml.Tensor, maxSpan, chunkSize, contextSize, numHeads int) ml.Tensor {
|
||||
// The shift trick: pad ne[0] to contextSize+1, reshape to flatten first two dims,
|
||||
// skip first (contextSize+1-maxSpan) elements, take contextSize*chunkSize elements, reshape back.
|
||||
padAmt := contextSize + 1 - maxSpan
|
||||
if padAmt > 0 {
|
||||
x = x.Pad(ctx, padAmt, 0, 0, 0) // [maxSpan+padAmt, chunkSize, numHeads] = [contextSize+1, chunkSize, numHeads]
|
||||
}
|
||||
// Reshape to [(contextSize+1)*chunkSize, numHeads]
|
||||
x = x.Reshape(ctx, (contextSize+1)*chunkSize, numHeads)
|
||||
// Take the first contextSize*chunkSize elements (the standard relative shift trick).
|
||||
x = x.Slice(ctx, 0, 0, contextSize*chunkSize, 1).Contiguous(ctx)
|
||||
// Reshape to [contextSize, chunkSize, numHeads]
|
||||
return x.Reshape(ctx, contextSize, chunkSize, numHeads)
|
||||
}
|
||||
|
||||
// forwardLightConv runs the lightweight depthwise convolution module.
|
||||
func (cb *AudioConformerBlock) forwardLightConv(ctx ml.Context, x ml.Tensor, opts *AudioModelOptions, blockIdx int) ml.Tensor {
|
||||
residual := x
|
||||
|
||||
x = cb.ConvNorm.Forward(ctx, x, opts.eps)
|
||||
x = cb.ConvPW1.Forward(ctx, x) // [2*D, T, B]
|
||||
|
||||
// GLU: split in half along dim 0, sigmoid gate, multiply.
|
||||
d := x.Dim(0) / 2
|
||||
data := x.Slice(ctx, 0, 0, d, 1).Contiguous(ctx)
|
||||
gate := x.Slice(ctx, 0, d, d*2, 1).Contiguous(ctx).Sigmoid(ctx)
|
||||
x = data.Mul(ctx, gate) // [D, T, B]
|
||||
|
||||
// Depthwise Conv1d: manual implementation using model weight tensor slices.
|
||||
// Kernel cb.ConvDW shape: [K=5, D=1024] (ne[0]=K, ne[1]=D) after shape reversal.
|
||||
// Actually in GGML, ne[0]=K=5 contiguous, ne[1]=D=1024.
|
||||
// We need per-tap weights [D] and shifted input copies.
|
||||
kernelSize := cb.ConvDW.Dim(0) // K=5
|
||||
seqLen := x.Dim(1)
|
||||
|
||||
// Transpose kernel to [D, K] for per-tap slicing.
|
||||
// GGML permute(1,0,2,3): old[0]→pos1, old[1]→pos0 → swap ne[0] and ne[1]
|
||||
kernelT := cb.ConvDW.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // [D, K]
|
||||
|
||||
var convOut ml.Tensor
|
||||
for k := range kernelSize {
|
||||
shift := kernelSize - 1 - k
|
||||
var shifted ml.Tensor
|
||||
if shift == 0 {
|
||||
shifted = x
|
||||
} else {
|
||||
trimmed := x.Slice(ctx, 1, 0, seqLen-shift, 1).Contiguous(ctx)
|
||||
shifted = trimmed.PadExt(ctx, 0, 0, shift, 0, 0, 0, 0, 0)
|
||||
}
|
||||
|
||||
wk := kernelT.Slice(ctx, 1, k, k+1, 1).Contiguous(ctx) // [D, 1]
|
||||
term := shifted.Mul(ctx, wk)
|
||||
if convOut == nil {
|
||||
convOut = term
|
||||
} else {
|
||||
convOut = convOut.Add(ctx, term)
|
||||
}
|
||||
}
|
||||
x = convOut
|
||||
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = cb.NormConv.Forward(ctx, x, opts.eps)
|
||||
x = x.SILU(ctx)
|
||||
x = cb.ConvPW2.Forward(ctx, x)
|
||||
|
||||
return x.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func newAudioModel(c fs.Config) *AudioModel {
|
||||
numLayers := int(c.Uint("audio.block_count", 0))
|
||||
if numLayers == 0 {
|
||||
return nil
|
||||
}
|
||||
return &AudioModel{
|
||||
Layers: make([]AudioConformerBlock, numLayers),
|
||||
}
|
||||
}
|
||||
|
||||
func newAudioModelOptions(c fs.Config) *AudioModelOptions {
|
||||
hiddenSize := int(c.Uint("audio.embedding_length", 0))
|
||||
if hiddenSize == 0 {
|
||||
return nil
|
||||
}
|
||||
numHeads := int(c.Uint("audio.attention.head_count", 8))
|
||||
headDim := hiddenSize / numHeads
|
||||
chunkSize := 12 // default conformer chunk size
|
||||
maxPast := 12 // conf_attention_context_left - 1
|
||||
maxFuture := 0 // conf_attention_context_right
|
||||
convKernel := int(c.Uint("audio.conv_kernel_size", 5))
|
||||
|
||||
eps := c.Float("audio.attention.layer_norm_epsilon", 1e-6)
|
||||
|
||||
return &AudioModelOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: headDim,
|
||||
ffnSize: int(c.Uint("audio.feed_forward_length", uint32(hiddenSize*4))),
|
||||
numLayers: int(c.Uint("audio.block_count", 12)),
|
||||
melBins: int(c.Uint("audio.num_mel_bins", 128)),
|
||||
chunkSize: chunkSize,
|
||||
maxPast: maxPast,
|
||||
maxFuture: maxFuture,
|
||||
contextSize: chunkSize + maxPast + maxFuture,
|
||||
logitCap: 50.0,
|
||||
residualWeight: 0.5,
|
||||
gradClip: 1e10,
|
||||
convKernelSize: convKernel,
|
||||
eps: float32(eps),
|
||||
}
|
||||
}
|
||||
|
||||
// buildCausalValidMaskF32 creates the causal-valid mask for block-local attention.
|
||||
// Returns flat [chunkSize * contextSize] float32 data (1.0 = allowed, 0.0 = masked).
|
||||
func buildCausalValidMaskF32(chunkSize, maxPast, maxFuture int) []float32 {
|
||||
contextSize := chunkSize + maxPast + maxFuture
|
||||
upperDiag := maxPast + maxFuture
|
||||
|
||||
result := make([]float32, chunkSize*contextSize)
|
||||
for r := range chunkSize {
|
||||
for c := range contextSize {
|
||||
lower := (r <= c) // tril(contextSize, chunkSize) transposed
|
||||
upper := (c <= r+upperDiag) // tril(chunkSize, contextSize, diag=upperDiag)
|
||||
if lower && upper {
|
||||
result[r*contextSize+c] = 1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
475
model/models/gemma4/model_text.go
Normal file
475
model/models/gemma4/model_text.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheTypeSWA = iota
|
||||
cacheTypeCausal
|
||||
)
|
||||
|
||||
type TextOptions struct {
|
||||
hiddenSize int
|
||||
numHeads, numKVHeads int
|
||||
numGlobalKVHeads int
|
||||
headDim, globalHeadDim int
|
||||
hiddenLayers int
|
||||
hiddenSizePerLayerInput int
|
||||
|
||||
eps float32
|
||||
ropeBase float32
|
||||
ropeLocalBase float32
|
||||
partialRotaryDims int // RoPE dims for full-attention (global) layers
|
||||
|
||||
slidingWindowPattern []bool
|
||||
// kvDonorMap maps shared layer index -> donor layer index.
|
||||
// Donor is the last non-shared layer of the same type (sliding/full).
|
||||
kvDonorMap map[int]int
|
||||
|
||||
finalLogitSoftcap float32
|
||||
|
||||
numExperts int
|
||||
numExpertsUsed int
|
||||
}
|
||||
|
||||
func (o *TextOptions) isLocal(layer int) bool {
|
||||
if layer < len(o.slidingWindowPattern) {
|
||||
return o.slidingWindowPattern[layer]
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (o *TextOptions) ropeForLayer(layer int) (base float32, dims int) {
|
||||
if o.isLocal(layer) {
|
||||
return o.ropeLocalBase, o.headDim
|
||||
}
|
||||
return o.ropeBase, o.partialRotaryDims
|
||||
}
|
||||
|
||||
func (o *TextOptions) kvHeadsForLayer(layer int) int {
|
||||
if o.isLocal(layer) {
|
||||
return o.numKVHeads
|
||||
}
|
||||
if o.numGlobalKVHeads > 0 {
|
||||
return o.numGlobalKVHeads
|
||||
}
|
||||
return o.numKVHeads
|
||||
}
|
||||
|
||||
func (o *TextOptions) headDimForLayer(layer int) int {
|
||||
if o.isLocal(layer) {
|
||||
return o.headDim
|
||||
}
|
||||
return o.globalHeadDim
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
*PerLayerProjector
|
||||
Layers []TextLayer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
TextOptions
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
|
||||
// Head dimensions: key_length is global head dim, key_length_swa is local (SWA) head dim.
|
||||
globalHeadDim := int(c.Uint("attention.key_length", 512))
|
||||
headDim := int(c.Uint("attention.key_length_swa", 256))
|
||||
|
||||
// RoPE dimensions for global (full attention) layers with proportional RoPE.
|
||||
// The freq_factors tensor handles partial rotation (1.0 for rotated pairs,
|
||||
// 1e30 for non-rotated), so ropeDims equals the full global head dim.
|
||||
partialRotaryDims := int(c.Uint("rope.dimension_count", 0))
|
||||
if partialRotaryDims == 0 {
|
||||
partialFactor := c.Float("rope.partial_rotary_factor", 1.0)
|
||||
partialRotaryDims = int(float32(globalHeadDim) * partialFactor)
|
||||
}
|
||||
|
||||
ropeBase := c.Float("rope.freq_base", 1000000.0)
|
||||
ropeLocalBase := c.Float("rope.freq_base_swa", 0)
|
||||
if ropeLocalBase == 0 {
|
||||
ropeLocalBase = c.Float("rope.local.freq_base", 10000.0)
|
||||
}
|
||||
|
||||
numGlobalKVHeads := int(c.Uint("attention.global_head_count_kv", 0))
|
||||
slidingPattern := c.Bools("attention.sliding_window_pattern")
|
||||
|
||||
// KV heads: try per-layer array first (MoE models), then fall back to scalar
|
||||
numKVHeads := 0
|
||||
kvHeadsArray := c.Ints("attention.head_count_kv")
|
||||
if len(kvHeadsArray) > 0 {
|
||||
numKVHeads = int(kvHeadsArray[0])
|
||||
if numGlobalKVHeads == 0 && len(slidingPattern) > 0 {
|
||||
for i, isLocal := range slidingPattern {
|
||||
if !isLocal && i < len(kvHeadsArray) {
|
||||
numGlobalKVHeads = int(kvHeadsArray[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if numKVHeads == 0 {
|
||||
numKVHeads = int(c.Uint("attention.head_count_kv", 0))
|
||||
}
|
||||
|
||||
// Compute KV sharing donor map (same logic as MLX)
|
||||
sharedLayers := int(c.Uint("attention.shared_kv_layers", 0))
|
||||
kvDonorMap := make(map[int]int)
|
||||
if sharedLayers > 0 && len(slidingPattern) > 0 {
|
||||
firstShared := numLayers - sharedLayers
|
||||
for i := firstShared; i < numLayers; i++ {
|
||||
isLocal := slidingPattern[i]
|
||||
// Find last non-shared layer of same type
|
||||
for j := firstShared - 1; j >= 0; j-- {
|
||||
if slidingPattern[j] == isLocal {
|
||||
kvDonorMap[i] = j
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &TextModel{
|
||||
Layers: make([]TextLayer, numLayers),
|
||||
TextOptions: TextOptions{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: numKVHeads,
|
||||
numGlobalKVHeads: numGlobalKVHeads,
|
||||
headDim: headDim,
|
||||
globalHeadDim: globalHeadDim,
|
||||
hiddenLayers: numLayers,
|
||||
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input", 0)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeBase: ropeBase,
|
||||
ropeLocalBase: ropeLocalBase,
|
||||
partialRotaryDims: partialRotaryDims,
|
||||
slidingWindowPattern: slidingPattern,
|
||||
kvDonorMap: kvDonorMap,
|
||||
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
|
||||
numExperts: int(c.Uint("expert_count", 0)),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count", 0)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
||||
|
||||
// Inject vision embeddings into the hidden state
|
||||
var except []int
|
||||
for _, image := range batch.Multimodal {
|
||||
visionOutputs := image.Multimodal[0].Tensor
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
for i := range visionOutputs.Dim(1) {
|
||||
except = append(except, image.Index+i)
|
||||
}
|
||||
}
|
||||
|
||||
// PLE
|
||||
var perLayerInputs ml.Tensor
|
||||
if m.PerLayerProjector != nil {
|
||||
perLayerInputs = m.PerLayerProjector.Forward(ctx, batch, hiddenState, &m.TextOptions)
|
||||
}
|
||||
|
||||
for i := range len(m.Layers) {
|
||||
layer := m.Layers[i]
|
||||
if cache != nil {
|
||||
cache.SetLayer(i)
|
||||
cacheType := cacheTypeSWA
|
||||
if !m.isLocal(i) {
|
||||
cacheType = cacheTypeCausal
|
||||
}
|
||||
wc := cache.(*kvcache.WrapperCache)
|
||||
wc.SetLayerType(cacheType)
|
||||
|
||||
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
||||
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
}
|
||||
}
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
var perLayerInput ml.Tensor
|
||||
if perLayerInputs != nil {
|
||||
perLayerInput = perLayerInputs.View(ctx, i*perLayerInputs.Stride(1), perLayerInputs.Dim(0), perLayerInputs.Stride(2), perLayerInputs.Dim(2))
|
||||
}
|
||||
|
||||
// KV sharing: layers >= firstShared reuse K/V from donor layers
|
||||
isShared := false
|
||||
if donorLayer, ok := m.kvDonorMap[i]; ok {
|
||||
// Set cache layer to donor so Get() reads donor's K/V
|
||||
cache.SetLayer(donorLayer)
|
||||
isShared = true
|
||||
}
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, perLayerInput, lastLayerOutputs, cache, isShared, &m.TextOptions)
|
||||
}
|
||||
|
||||
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
}
|
||||
|
||||
// PerLayerProjector implements PLE.
|
||||
type PerLayerProjector struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"per_layer_token_embd"`
|
||||
Projector *nn.Linear `gguf:"per_layer_model_proj"`
|
||||
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
|
||||
}
|
||||
|
||||
func (p *PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
inputsPerLayer = inputsPerLayer.Scale(ctx, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
|
||||
// Reshape to [pleDim, numLayers, numTokens] — matching projection shape
|
||||
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
|
||||
perLayerProjection := p.Projector.Forward(ctx, inputs)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
|
||||
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
|
||||
|
||||
if inputsPerLayer != nil {
|
||||
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
|
||||
}
|
||||
|
||||
return perLayerProjection
|
||||
}
|
||||
|
||||
type TextSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` // proportional RoPE freq_factors
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
hd := opts.headDimForLayer(layer)
|
||||
kvHeads := opts.kvHeadsForLayer(layer)
|
||||
ropeBase, ropeDims := opts.ropeForLayer(layer)
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, hd, opts.numHeads, batchSize)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
|
||||
var k, v ml.Tensor
|
||||
if !sharedKV {
|
||||
k = sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, hd, kvHeads, batchSize)
|
||||
|
||||
if sa.Value != nil {
|
||||
v = sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, hd, kvHeads, batchSize)
|
||||
} else {
|
||||
// K=V: use raw K projection (before K norm) as V
|
||||
v = k
|
||||
}
|
||||
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
v = v.RMSNorm(ctx, nil, opts.eps) // V norm: unweighted RMSNorm
|
||||
}
|
||||
|
||||
// RoPE with proportional freq_factors on global layers
|
||||
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
if sa.RopeFactors != nil && !opts.isLocal(layer) {
|
||||
ropeOpts = append(ropeOpts, rope.WithFactors(sa.RopeFactors))
|
||||
}
|
||||
q = nn.RoPE(ctx, q, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
|
||||
if k != nil {
|
||||
k = nn.RoPE(ctx, k, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
|
||||
}
|
||||
|
||||
attention := nn.Attention(ctx, q, k, v, 1.0, cache)
|
||||
|
||||
attention = attention.Reshape(ctx, hd*opts.numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
// TextRouter implements the Gemma 4 MoE router.
|
||||
type TextRouter struct {
|
||||
Proj *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Scale ml.Tensor `gguf:"ffn_gate_inp.scale"`
|
||||
}
|
||||
|
||||
func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) (routingWeights, selectedExperts ml.Tensor) {
|
||||
// RMSNorm without learned weight
|
||||
x := hiddenState.RMSNorm(ctx, nil, opts.eps)
|
||||
// Scale by 1/sqrt(hidden_size)
|
||||
x = x.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
|
||||
// Multiply by learned scale parameter
|
||||
x = x.Mul(ctx, r.Scale)
|
||||
// Project to expert logits
|
||||
expertScores := r.Proj.Forward(ctx, x)
|
||||
// Softmax over experts
|
||||
routingWeights = expertScores.Softmax(ctx)
|
||||
// TopK expert selection
|
||||
selectedExperts = routingWeights.TopK(ctx, opts.numExpertsUsed)
|
||||
return routingWeights, selectedExperts
|
||||
}
|
||||
|
||||
// TextMoEBlock implements the Gemma 4 sparse MoE.
|
||||
type TextMoEBlock struct {
|
||||
GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
DownScale ml.Tensor `gguf:"ffn_down_exps.scale,alt:ffn_gate_inp.per_expert_scale"`
|
||||
}
|
||||
|
||||
func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
// Select routing weights for chosen experts and renormalize
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
|
||||
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
|
||||
|
||||
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
|
||||
|
||||
// Expert computation using LinearBatch (MulmatID selecting experts by index)
|
||||
var gateOut, upOut ml.Tensor
|
||||
if moe.GateUp != nil && moe.GateUp.Weight != nil {
|
||||
gateUp := moe.GateUp.Forward(ctx, hiddenState, selectedExperts)
|
||||
nFF := gateUp.Dim(0) / 2
|
||||
gateOut = gateUp.Slice(ctx, 0, 0, nFF, 1)
|
||||
upOut = gateUp.Slice(ctx, 0, nFF, gateUp.Dim(0), 1)
|
||||
} else {
|
||||
gateOut = moe.Gate.Forward(ctx, hiddenState, selectedExperts)
|
||||
upOut = moe.Up.Forward(ctx, hiddenState, selectedExperts)
|
||||
}
|
||||
hiddenState = gateOut.GELU(ctx, upOut)
|
||||
experts := moe.Down.Forward(ctx, hiddenState, selectedExperts)
|
||||
|
||||
// Apply per-expert down projection scale when present.
|
||||
if moe.DownScale != nil {
|
||||
expertScales := moe.DownScale.Reshape(ctx, opts.numExperts, 1)
|
||||
expertScales = expertScales.Repeat(ctx, 1, hiddenState.Dim(2))
|
||||
expertScales = expertScales.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(2)).Rows(ctx, selectedExperts)
|
||||
expertScales = expertScales.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(2))
|
||||
expertScales = expertScales.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(2))
|
||||
experts = experts.Mul(ctx, expertScales)
|
||||
}
|
||||
|
||||
// Apply routing weights
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
// Sum across experts
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
|
||||
return nextStates
|
||||
}
|
||||
|
||||
type TextLayer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *TextSelfAttention
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm,alt:attn_post_norm"`
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm,alt:ffn_pre_norm"`
|
||||
MLP *TextMLP
|
||||
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm,alt:ffn_post_norm"`
|
||||
|
||||
// MoE (present only for models with enable_moe_block=true)
|
||||
Router *TextRouter
|
||||
MoE *TextMoEBlock
|
||||
MoENorm *nn.RMSNorm `gguf:"pre_ffw_norm_2,alt:ffn_pre_norm_2"`
|
||||
PostMoENorm *nn.RMSNorm `gguf:"post_ffw_norm_2,alt:ffn_post_norm_2"`
|
||||
PostMLPNorm1 *nn.RMSNorm `gguf:"post_ffw_norm_1,alt:ffn_post_norm_1"` // used instead of PostMLPNorm when MoE is present
|
||||
|
||||
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
|
||||
PerLayerProjection *nn.Linear `gguf:"proj"`
|
||||
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
|
||||
LayerScalar ml.Tensor `gguf:"layer_scalar,alt:layer_output_scale.weight"`
|
||||
}
|
||||
|
||||
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, perLayerInput, outputs ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positions, cache, sharedKV, opts)
|
||||
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
if perLayerInput != nil {
|
||||
perLayerInput = perLayerInput.Rows(ctx, outputs)
|
||||
}
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// MLP (+ optional MoE in parallel)
|
||||
hasSplitExperts := l.MoE != nil && l.MoE.Gate != nil && l.MoE.Up != nil && l.MoE.Gate.Weight != nil && l.MoE.Up.Weight != nil
|
||||
hasFusedExperts := l.MoE != nil && l.MoE.GateUp != nil && l.MoE.GateUp.Weight != nil
|
||||
if l.Router != nil && l.MoE != nil && l.MoE.Down != nil && l.MoE.Down.Weight != nil && (hasSplitExperts || hasFusedExperts) {
|
||||
// MoE layers: run MLP and MoE in parallel, sum results
|
||||
mlpState := l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
mlpState = l.MLP.Forward(ctx, mlpState)
|
||||
mlpState = l.PostMLPNorm1.Forward(ctx, mlpState, opts.eps)
|
||||
|
||||
routingWeights, selectedExperts := l.Router.Forward(ctx, hiddenState, opts)
|
||||
moeState := l.MoENorm.Forward(ctx, hiddenState, opts.eps)
|
||||
moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, opts)
|
||||
moeState = l.PostMoENorm.Forward(ctx, moeState, opts.eps)
|
||||
|
||||
// Combine MLP + MoE, apply outer post-FFN norm, then add residual
|
||||
combined := mlpState.Add(ctx, moeState)
|
||||
combined = l.PostMLPNorm.Forward(ctx, combined, opts.eps)
|
||||
hiddenState = combined.Add(ctx, residual)
|
||||
} else {
|
||||
// Dense layers: MLP only
|
||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState)
|
||||
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
// PLE injection (after MLP residual)
|
||||
if perLayerInput != nil && l.PerLayerInputGate != nil {
|
||||
pleState := l.PerLayerInputGate.Forward(ctx, hiddenState)
|
||||
pleState = pleState.GELU(ctx, perLayerInput)
|
||||
pleState = l.PerLayerProjection.Forward(ctx, pleState)
|
||||
pleState = l.PostPerLayerNorm.Forward(ctx, pleState, opts.eps)
|
||||
hiddenState = hiddenState.Add(ctx, pleState)
|
||||
}
|
||||
|
||||
// Layer scalar applied at end of layer (full-attention layers only)
|
||||
if l.LayerScalar != nil {
|
||||
hiddenState = hiddenState.Mul(ctx, l.LayerScalar)
|
||||
}
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
384
model/models/gemma4/model_vision.go
Normal file
384
model/models/gemma4/model_vision.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
const batchSize = 1
|
||||
|
||||
// ClippableLinear is a linear layer with optional input/output clamping.
|
||||
// Required by Gemma4 vision encoder for numerical stability with F16 weights.
|
||||
type ClippableLinear struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
|
||||
InputMin ml.Tensor `gguf:"input_min"`
|
||||
InputMax ml.Tensor `gguf:"input_max"`
|
||||
OutputMin ml.Tensor `gguf:"output_min"`
|
||||
OutputMax ml.Tensor `gguf:"output_max"`
|
||||
|
||||
inMin, inMax, outMin, outMax float32
|
||||
hasClamp bool
|
||||
clampsLoaded bool
|
||||
}
|
||||
|
||||
func scalarValue(t ml.Tensor) (float32, bool) {
|
||||
if t == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
data := t.BackendGet()
|
||||
if len(data) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return data[0], true
|
||||
}
|
||||
|
||||
func (l *ClippableLinear) loadClampFromScalars() {
|
||||
if l.clampsLoaded {
|
||||
return
|
||||
}
|
||||
l.clampsLoaded = true
|
||||
|
||||
const (
|
||||
defaultMin = -math.MaxFloat32
|
||||
defaultMax = math.MaxFloat32
|
||||
)
|
||||
|
||||
inMin, hasInMin := scalarValue(l.InputMin)
|
||||
inMax, hasInMax := scalarValue(l.InputMax)
|
||||
outMin, hasOutMin := scalarValue(l.OutputMin)
|
||||
outMax, hasOutMax := scalarValue(l.OutputMax)
|
||||
|
||||
if !(hasInMin || hasInMax || hasOutMin || hasOutMax) {
|
||||
return
|
||||
}
|
||||
|
||||
l.hasClamp = true
|
||||
l.inMin = defaultMin
|
||||
l.inMax = defaultMax
|
||||
l.outMin = defaultMin
|
||||
l.outMax = defaultMax
|
||||
|
||||
if hasInMin {
|
||||
l.inMin = inMin
|
||||
}
|
||||
if hasInMax {
|
||||
l.inMax = inMax
|
||||
}
|
||||
if hasOutMin {
|
||||
l.outMin = outMin
|
||||
}
|
||||
if hasOutMax {
|
||||
l.outMax = outMax
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
|
||||
if l.hasClamp {
|
||||
x = x.Clamp(ctx, l.inMin, l.inMax)
|
||||
}
|
||||
out := l.Weight.Mulmat(ctx, x)
|
||||
if l.hasClamp {
|
||||
out = out.Clamp(ctx, l.outMin, l.outMax)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// InitClamp distributes packed clamp values from v.clamp_data to ClippableLinear structs.
|
||||
// If scalar clamp tensors (input_min/max, output_min/max) are present, they are used too.
|
||||
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
|
||||
// then 4 floats for the projector.
|
||||
func (m *VisionModel) InitClamp(proj *MultiModalProjector) {
|
||||
if m.clampInitDone {
|
||||
return
|
||||
}
|
||||
m.clampInitDone = true
|
||||
|
||||
linears := func(l *VisionEncoderLayer) []*ClippableLinear {
|
||||
return []*ClippableLinear{
|
||||
l.SelfAttention.Query, l.SelfAttention.Key, l.SelfAttention.Value,
|
||||
l.SelfAttention.Output, l.MLP.Gate, l.MLP.Up, l.MLP.Down,
|
||||
}
|
||||
}
|
||||
|
||||
for i := range m.Layers {
|
||||
for _, cl := range linears(&m.Layers[i]) {
|
||||
if cl != nil {
|
||||
cl.loadClampFromScalars()
|
||||
}
|
||||
}
|
||||
}
|
||||
if proj != nil && proj.Projection != nil {
|
||||
proj.Projection.loadClampFromScalars()
|
||||
}
|
||||
|
||||
// Load packed clamp data when present (legacy Ollama format).
|
||||
if m.ClampData == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Read all clamp values from packed F32 tensor
|
||||
data := m.ClampData.BackendGet()
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Distribute to layer linears: 7 per layer × 4 values each
|
||||
for i := range m.Layers {
|
||||
for li, cl := range linears(&m.Layers[i]) {
|
||||
if cl == nil {
|
||||
continue
|
||||
}
|
||||
idx := (i*7 + li) * 4
|
||||
if idx+3 < len(data) {
|
||||
cl.inMin = data[idx]
|
||||
cl.inMax = data[idx+1]
|
||||
cl.outMin = data[idx+2]
|
||||
cl.outMax = data[idx+3]
|
||||
cl.hasClamp = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Projector clamp values (last 4 floats)
|
||||
if proj != nil && proj.Projection != nil {
|
||||
projIdx := len(m.Layers) * 7 * 4
|
||||
if projIdx+3 < len(data) {
|
||||
proj.Projection.inMin = data[projIdx]
|
||||
proj.Projection.inMax = data[projIdx+1]
|
||||
proj.Projection.outMin = data[projIdx+2]
|
||||
proj.Projection.outMax = data[projIdx+3]
|
||||
proj.Projection.hasClamp = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *ClippableLinear `gguf:"attn_q"`
|
||||
Key *ClippableLinear `gguf:"attn_k"`
|
||||
Value *ClippableLinear `gguf:"attn_v"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Output *ClippableLinear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
numPatches := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
|
||||
|
||||
// Q/K norms (Gemma-style: x * (1 + weight) / rms(x))
|
||||
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
// V norm (RMSNorm without learned weights)
|
||||
value = value.RMSNorm(ctx, nil, opts.eps)
|
||||
|
||||
// 2D RoPE: split head dim in half, apply NeoX RoPE with x positions to first half,
|
||||
// y positions to second half, then concatenate.
|
||||
halfDim := headDim / 2
|
||||
ropeOpts := rope.WithTypeNeoX()
|
||||
|
||||
qFirst := query.View(ctx, 0, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
|
||||
qFirst = nn.RoPE(ctx, qFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)
|
||||
|
||||
kFirst := key.View(ctx, 0, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
|
||||
kFirst = nn.RoPE(ctx, kFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)
|
||||
|
||||
halfOffset := halfDim * query.Stride(0)
|
||||
qSecond := query.View(ctx, halfOffset, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
|
||||
qSecond = nn.RoPE(ctx, qSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)
|
||||
|
||||
halfOffsetK := halfDim * key.Stride(0)
|
||||
kSecond := key.View(ctx, halfOffsetK, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
|
||||
kSecond = nn.RoPE(ctx, kSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)
|
||||
|
||||
query = qFirst.Concat(ctx, qSecond, 0)
|
||||
key = kFirst.Concat(ctx, kSecond, 0)
|
||||
|
||||
// Use flash attention for numerical stability (handles large attention scores
|
||||
// from unclamped RMSNorm weights, e.g. 26B has addOne weights up to 19.5)
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0, nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
Gate *ClippableLinear `gguf:"ffn_gate"`
|
||||
Up *ClippableLinear `gguf:"ffn_up"`
|
||||
Down *ClippableLinear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
gate := mlp.Gate.Forward(ctx, hiddenState)
|
||||
up := mlp.Up.Forward(ctx, hiddenState)
|
||||
hiddenState = gate.QuickGELU(ctx, up)
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"ln1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"attn_post_norm"`
|
||||
|
||||
FFNNorm *nn.RMSNorm `gguf:"ln2"`
|
||||
MLP *VisionMLP
|
||||
PostFFNNorm *nn.RMSNorm `gguf:"ffn_post_norm"`
|
||||
|
||||
LayerOutputScale ml.Tensor `gguf:"out_scale.weight"`
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
// Pre-attention norm -> self attention -> post-attention norm
|
||||
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, posX, posY, attnMask, opts)
|
||||
hiddenState = e.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
// Residual connection
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// Pre-FFN norm -> FFN -> post-FFN norm
|
||||
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.MLP.Forward(ctx, hiddenState)
|
||||
hiddenState = e.PostFFNNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
// Residual connection
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
|
||||
// Per-layer output scale
|
||||
if e.LayerOutputScale != nil {
|
||||
hiddenState = hiddenState.Mul(ctx, e.LayerOutputScale)
|
||||
}
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
patchSize int
|
||||
nMerge int
|
||||
eps float32
|
||||
ropeTheta float32
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
|
||||
PositionEmbedding ml.Tensor `gguf:"position_embd.weight"`
|
||||
ClampData ml.Tensor `gguf:"clamp_data"`
|
||||
StdBias ml.Tensor `gguf:"std_bias"`
|
||||
StdScale ml.Tensor `gguf:"std_scale"`
|
||||
|
||||
Layers []VisionEncoderLayer `gguf:"blk"`
|
||||
|
||||
*VisionModelOptions
|
||||
clampInitDone bool
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, numPatchesX, numPatchesY int) ml.Tensor {
|
||||
numPatches := numPatchesX * numPatchesY
|
||||
|
||||
// Patch embedding via Conv2D
|
||||
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Conv2D with F16 weights produces F16 output via im2col; cast to F32 for encoder precision
|
||||
hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
|
||||
|
||||
// 2D positional embeddings from 3D tensor [nEmbd, maxPos, 2]
|
||||
posSize := m.PositionEmbedding.Dim(1)
|
||||
nb1 := m.PositionEmbedding.Stride(1)
|
||||
tblX := m.PositionEmbedding.View(ctx, 0, m.hiddenSize, nb1, posSize)
|
||||
tblY := m.PositionEmbedding.View(ctx, posSize*nb1, m.hiddenSize, nb1, posSize)
|
||||
|
||||
// Position indices for patches
|
||||
posXData := make([]int32, numPatches)
|
||||
posYData := make([]int32, numPatches)
|
||||
for i := range numPatches {
|
||||
posXData[i] = int32(i % numPatchesX)
|
||||
posYData[i] = int32(i / numPatchesX)
|
||||
}
|
||||
|
||||
posXEmb := ctx.Input().FromInts(posXData, numPatches)
|
||||
posYEmb := ctx.Input().FromInts(posYData, numPatches)
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, tblX.Rows(ctx, posXEmb))
|
||||
hiddenState = hiddenState.Add(ctx, tblY.Rows(ctx, posYEmb))
|
||||
|
||||
// No attention mask — all positions are real patches
|
||||
var attnMask ml.Tensor
|
||||
|
||||
// RoPE positions
|
||||
posXRope := ctx.Input().FromInts(posXData, numPatches)
|
||||
posYRope := ctx.Input().FromInts(posYData, numPatches)
|
||||
|
||||
// Vision transformer layers
|
||||
for i := range m.Layers {
|
||||
hiddenState = m.Layers[i].Forward(ctx, hiddenState, posXRope, posYRope, attnMask, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length")),
|
||||
numHeads: int(c.Uint("vision.attention.head_count")),
|
||||
patchSize: int(c.Uint("vision.patch_size", 16)),
|
||||
nMerge: int(c.Uint("vision.projector.scale_factor", 3)),
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
||||
ropeTheta: 100.0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, numPatchesY int, opts *VisionModelOptions, proj *MultiModalProjector, stdBias, stdScale ml.Tensor) ml.Tensor {
|
||||
hiddenSize := opts.hiddenSize
|
||||
|
||||
// Reshape from [hiddenSize, numPatches] to spatial layout for pooling
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatchesX, numPatchesY, hiddenSize)
|
||||
|
||||
// AvgPool2D with kernel=stride=nMerge
|
||||
hiddenState = hiddenState.AvgPool2D(ctx, opts.nMerge, opts.nMerge, 0)
|
||||
|
||||
// Reshape back to [hiddenSize, numMergedPatches]
|
||||
mergedX := numPatchesX / opts.nMerge
|
||||
mergedY := numPatchesY / opts.nMerge
|
||||
hiddenState = hiddenState.Reshape(ctx, mergedX*mergedY, hiddenSize)
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(hiddenSize)))
|
||||
|
||||
// Optional vision standardization before projection.
|
||||
if stdBias != nil && stdScale != nil {
|
||||
hiddenState = hiddenState.Sub(ctx, stdBias)
|
||||
hiddenState = hiddenState.Mul(ctx, stdScale)
|
||||
}
|
||||
|
||||
// Project to text embedding dimension
|
||||
hiddenState = proj.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
280
model/models/gemma4/process_audio.go
Normal file
280
model/models/gemma4/process_audio.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/cmplx"
|
||||
)
|
||||
|
||||
// Audio preprocessing constants.
|
||||
const (
|
||||
audioSampleRate = 16000
|
||||
melBins = 128
|
||||
frameLengthMs = 20.0
|
||||
hopLengthMs = 10.0
|
||||
minFrequency = 0.0
|
||||
maxFrequency = 8000.0
|
||||
melFloor = 1e-3
|
||||
maxAudioSoftTokens = 750
|
||||
)
|
||||
|
||||
// Computed from the above constants.
|
||||
var (
|
||||
frameLength = int(math.Round(audioSampleRate * frameLengthMs / 1000.0)) // 320
|
||||
hopLength = int(math.Round(audioSampleRate * hopLengthMs / 1000.0)) // 160
|
||||
)
|
||||
|
||||
// decodeWAV extracts mono float32 PCM samples from a WAV file, resampled to 16kHz.
|
||||
func decodeWAV(data []byte) ([]float32, error) {
|
||||
if len(data) < 12 {
|
||||
return nil, fmt.Errorf("WAV file too short")
|
||||
}
|
||||
if string(data[0:4]) != "RIFF" || string(data[8:12]) != "WAVE" {
|
||||
return nil, fmt.Errorf("not a WAV file")
|
||||
}
|
||||
|
||||
var audioFormat uint16
|
||||
var numChannels, sampleRate, bitsPerSample int
|
||||
var audioData []byte
|
||||
foundFmt := false
|
||||
|
||||
offset := 12
|
||||
for offset+8 <= len(data) {
|
||||
chunkID := string(data[offset : offset+4])
|
||||
chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
|
||||
chunkData := data[offset+8 : min(offset+8+chunkSize, len(data))]
|
||||
|
||||
switch chunkID {
|
||||
case "fmt ":
|
||||
if len(chunkData) < 16 {
|
||||
return nil, fmt.Errorf("fmt chunk too short")
|
||||
}
|
||||
audioFormat = binary.LittleEndian.Uint16(chunkData[0:2])
|
||||
numChannels = int(binary.LittleEndian.Uint16(chunkData[2:4]))
|
||||
sampleRate = int(binary.LittleEndian.Uint32(chunkData[4:8]))
|
||||
bitsPerSample = int(binary.LittleEndian.Uint16(chunkData[14:16]))
|
||||
if audioFormat == 0xFFFE && len(chunkData) >= 26 {
|
||||
audioFormat = binary.LittleEndian.Uint16(chunkData[24:26])
|
||||
}
|
||||
foundFmt = true
|
||||
case "data":
|
||||
audioData = chunkData
|
||||
}
|
||||
|
||||
offset += 8 + chunkSize
|
||||
if chunkSize%2 != 0 {
|
||||
offset++
|
||||
}
|
||||
}
|
||||
|
||||
if !foundFmt {
|
||||
return nil, fmt.Errorf("no fmt chunk found in WAV file")
|
||||
}
|
||||
if audioFormat != 1 && audioFormat != 3 {
|
||||
return nil, fmt.Errorf("unsupported WAV format: %d (need PCM=1 or float=3)", audioFormat)
|
||||
}
|
||||
if audioData == nil {
|
||||
return nil, fmt.Errorf("no data chunk found in WAV file")
|
||||
}
|
||||
|
||||
samples := decodeWAVSamples(audioData, audioFormat, bitsPerSample, numChannels)
|
||||
if sampleRate != audioSampleRate {
|
||||
samples = resampleLinear(samples, sampleRate, audioSampleRate)
|
||||
}
|
||||
return samples, nil
|
||||
}
|
||||
|
||||
func decodeWAVSamples(data []byte, format uint16, bits, channels int) []float32 {
|
||||
bytesPerSample := bits / 8
|
||||
totalSamples := len(data) / (bytesPerSample * channels)
|
||||
mono := make([]float32, totalSamples)
|
||||
|
||||
for i := range totalSamples {
|
||||
var sum float64
|
||||
for ch := range channels {
|
||||
off := (i*channels + ch) * bytesPerSample
|
||||
if off+bytesPerSample > len(data) {
|
||||
break
|
||||
}
|
||||
switch {
|
||||
case format == 1 && bits == 16:
|
||||
v := int16(binary.LittleEndian.Uint16(data[off : off+2]))
|
||||
sum += float64(v) / 32768.0
|
||||
case format == 1 && bits == 32:
|
||||
v := int32(binary.LittleEndian.Uint32(data[off : off+4]))
|
||||
sum += float64(v) / 2147483648.0
|
||||
case format == 1 && bits == 24:
|
||||
v := int32(data[off]) | int32(data[off+1])<<8 | int32(data[off+2])<<16
|
||||
if v&0x800000 != 0 {
|
||||
v |= ^0xFFFFFF
|
||||
}
|
||||
sum += float64(v) / 8388608.0
|
||||
case format == 3 && bits == 32:
|
||||
v := math.Float32frombits(binary.LittleEndian.Uint32(data[off : off+4]))
|
||||
sum += float64(v)
|
||||
case format == 1 && bits == 8:
|
||||
sum += (float64(data[off]) - 128.0) / 128.0
|
||||
}
|
||||
}
|
||||
mono[i] = float32(sum / float64(channels))
|
||||
}
|
||||
return mono
|
||||
}
|
||||
|
||||
func resampleLinear(samples []float32, fromRate, toRate int) []float32 {
|
||||
n := int(float64(len(samples)) / float64(fromRate) * float64(toRate))
|
||||
out := make([]float32, n)
|
||||
for i := range n {
|
||||
pos := float64(i) * float64(len(samples)-1) / float64(n-1)
|
||||
idx := int(pos)
|
||||
frac := float32(pos - float64(idx))
|
||||
if idx+1 < len(samples) {
|
||||
out[i] = samples[idx]*(1-frac) + samples[idx+1]*frac
|
||||
} else {
|
||||
out[i] = samples[idx]
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// computeMelSpectrogram computes the log mel spectrogram from PCM samples.
|
||||
// Returns shape [numFrames, melBins] as float32 slice, and numFrames.
|
||||
func computeMelSpectrogram(samples []float32) ([]float32, int) {
|
||||
fftLen := 1
|
||||
for fftLen < frameLength {
|
||||
fftLen <<= 1
|
||||
}
|
||||
fftLen *= 2 // fft_overdrive=True
|
||||
|
||||
// Hanning-nonzero window.
|
||||
window := make([]float64, frameLength)
|
||||
arg := math.Pi * 2.0 / float64(frameLength)
|
||||
for i := range frameLength {
|
||||
window[i] = 0.5 - 0.5*math.Cos(arg*(float64(i)+0.5))
|
||||
}
|
||||
|
||||
numFreqBins := fftLen/2 + 1
|
||||
melFilters := buildMelFilterBank(numFreqBins, melBins, minFrequency, maxFrequency, audioSampleRate)
|
||||
|
||||
frameSizeForUnfold := frameLength + 1
|
||||
numFrames := (len(samples) - frameSizeForUnfold) / hopLength
|
||||
if numFrames <= 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
result := make([]float32, numFrames*melBins)
|
||||
fftInput := make([]complex128, fftLen)
|
||||
|
||||
for f := range numFrames {
|
||||
start := f * hopLength
|
||||
for i := range frameLength {
|
||||
fftInput[i] = complex(float64(samples[start+i])*window[i], 0)
|
||||
}
|
||||
for i := frameLength; i < fftLen; i++ {
|
||||
fftInput[i] = 0
|
||||
}
|
||||
|
||||
fft(fftInput)
|
||||
|
||||
for m := range melBins {
|
||||
var melVal float64
|
||||
for k := range numFreqBins {
|
||||
mag := cmplx.Abs(fftInput[k])
|
||||
melVal += mag * float64(melFilters[k*melBins+m])
|
||||
}
|
||||
if melVal < melFloor {
|
||||
melVal = melFloor
|
||||
}
|
||||
result[f*melBins+m] = float32(math.Log(melVal))
|
||||
}
|
||||
}
|
||||
|
||||
return result, numFrames
|
||||
}
|
||||
|
||||
func buildMelFilterBank(numFreqBins, numMels int, fMin, fMax float64, sr int) []float32 {
|
||||
hzToMel := func(f float64) float64 {
|
||||
return 2595.0 * math.Log10(1.0+f/700.0)
|
||||
}
|
||||
melToHz := func(m float64) float64 {
|
||||
return 700.0 * (math.Pow(10.0, m/2595.0) - 1.0)
|
||||
}
|
||||
|
||||
melMin := hzToMel(fMin)
|
||||
melMax := hzToMel(fMax)
|
||||
|
||||
melPts := make([]float64, numMels+2)
|
||||
for i := range melPts {
|
||||
melPts[i] = melMin + float64(i)*(melMax-melMin)/float64(numMels+1)
|
||||
}
|
||||
filterFreqs := make([]float64, numMels+2)
|
||||
for i, m := range melPts {
|
||||
filterFreqs[i] = melToHz(m)
|
||||
}
|
||||
|
||||
fftFreqs := make([]float64, numFreqBins)
|
||||
for i := range fftFreqs {
|
||||
fftFreqs[i] = float64(i) * float64(sr) / float64(2*(numFreqBins-1))
|
||||
}
|
||||
|
||||
filters := make([]float32, numFreqBins*numMels)
|
||||
for m := range numMels {
|
||||
fLeft := filterFreqs[m]
|
||||
fCenter := filterFreqs[m+1]
|
||||
fRight := filterFreqs[m+2]
|
||||
for k := range numFreqBins {
|
||||
f := fftFreqs[k]
|
||||
var v float64
|
||||
if f >= fLeft && f <= fCenter && fCenter > fLeft {
|
||||
v = (f - fLeft) / (fCenter - fLeft)
|
||||
} else if f > fCenter && f <= fRight && fRight > fCenter {
|
||||
v = (fRight - f) / (fRight - fCenter)
|
||||
}
|
||||
if v > 0 {
|
||||
filters[k*numMels+m] = float32(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
return filters
|
||||
}
|
||||
|
||||
// fft performs an in-place Cooley-Tukey radix-2 FFT.
|
||||
func fft(x []complex128) {
|
||||
n := len(x)
|
||||
if n <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
j := 0
|
||||
for i := 1; i < n; i++ {
|
||||
bit := n >> 1
|
||||
for j&bit != 0 {
|
||||
j ^= bit
|
||||
bit >>= 1
|
||||
}
|
||||
j ^= bit
|
||||
if i < j {
|
||||
x[i], x[j] = x[j], x[i]
|
||||
}
|
||||
}
|
||||
|
||||
for size := 2; size <= n; size <<= 1 {
|
||||
halfSize := size / 2
|
||||
w := complex(math.Cos(2*math.Pi/float64(size)), -math.Sin(2*math.Pi/float64(size)))
|
||||
for start := 0; start < n; start += size {
|
||||
wn := complex(1, 0)
|
||||
for k := range halfSize {
|
||||
t := wn * x[start+k+halfSize]
|
||||
x[start+k+halfSize] = x[start+k] - t
|
||||
x[start+k] = x[start+k] + t
|
||||
wn *= w
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isAudioData checks if the data starts with WAV magic bytes.
|
||||
func isAudioData(data []byte) bool {
|
||||
return len(data) >= 12 && string(data[0:4]) == "RIFF" && string(data[8:12]) == "WAVE"
|
||||
}
|
||||
103
model/models/gemma4/process_image.go
Normal file
103
model/models/gemma4/process_image.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"image"
|
||||
"math"
|
||||
|
||||
"golang.org/x/image/draw"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
patchSize int
|
||||
numChannels int
|
||||
nMerge int
|
||||
minPixels int
|
||||
maxPixels int
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
patchSize := int(c.Uint("vision.patch_size", 16))
|
||||
nMerge := int(c.Uint("vision.projector.scale_factor", 3))
|
||||
numChannels := int(c.Uint("vision.num_channels", 3))
|
||||
|
||||
// Token limits from reference: min=40, max=280 output tokens after pooling.
|
||||
// Convert to pixel counts: tokens * nMerge^2 * patchSize^2
|
||||
minTokens := 40
|
||||
maxTokens := 280
|
||||
patchArea := patchSize * patchSize * nMerge * nMerge
|
||||
minPixels := minTokens * patchArea
|
||||
maxPixels := maxTokens * patchArea
|
||||
|
||||
return ImageProcessor{
|
||||
patchSize: patchSize,
|
||||
numChannels: numChannels,
|
||||
nMerge: nMerge,
|
||||
minPixels: minPixels,
|
||||
maxPixels: maxPixels,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessImage resizes an image preserving aspect ratio, aligning dimensions
|
||||
// to (patchSize * nMerge) boundaries, and normalizes pixels to [-1, 1].
|
||||
// Returns the float32 pixel data and the actual output dimensions.
|
||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, error) {
|
||||
// Compute target size preserving aspect ratio
|
||||
alignSize := p.patchSize * p.nMerge
|
||||
targetW, targetH := p.smartResize(img.Bounds().Dx(), img.Bounds().Dy(), alignSize)
|
||||
|
||||
// Resize directly without alpha compositing, matching MLX reference.
|
||||
dst := image.NewRGBA(image.Rect(0, 0, targetW, targetH))
|
||||
draw.BiLinear.Scale(dst, dst.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||
|
||||
// Normalize to [-1, 1] using mean=0.5, std=0.5: (pixel/255 - 0.5) / 0.5 = 2*pixel/255 - 1
|
||||
data := p.pack(dst)
|
||||
return data, targetW, targetH, nil
|
||||
}
|
||||
|
||||
// smartResize computes target dimensions that preserve aspect ratio and
|
||||
// align to alignSize boundaries. It scales the image to fill the maximum
|
||||
// patch budget (maxPixels), matching the MLX reference.
|
||||
func (p *ImageProcessor) smartResize(origW, origH, alignSize int) (int, int) {
|
||||
totalPx := origW * origH
|
||||
|
||||
var targetW, targetH int
|
||||
if p.maxPixels > 0 && totalPx > 0 {
|
||||
factor := math.Sqrt(float64(p.maxPixels) / float64(totalPx))
|
||||
targetH = max(alignSize, int(math.Floor(factor*float64(origH)/float64(alignSize)))*alignSize)
|
||||
targetW = max(alignSize, int(math.Floor(factor*float64(origW)/float64(alignSize)))*alignSize)
|
||||
} else {
|
||||
targetH = max(alignSize, (origH/alignSize)*alignSize)
|
||||
targetW = max(alignSize, (origW/alignSize)*alignSize)
|
||||
}
|
||||
|
||||
return targetW, targetH
|
||||
}
|
||||
|
||||
// pack extracts RGB values from an image and normalizes to [-1, 1].
|
||||
// Returns channel-first layout: [R..., G..., B...].
|
||||
func (p *ImageProcessor) pack(img image.Image) []float32 {
|
||||
bounds := img.Bounds()
|
||||
w := bounds.Dx()
|
||||
h := bounds.Dy()
|
||||
size := w * h
|
||||
|
||||
pixelVals := make([]float32, 3*size)
|
||||
rOff, gOff, bOff := 0, size, 2*size
|
||||
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
c := img.At(x, y)
|
||||
r, g, b, _ := c.RGBA()
|
||||
idx := (y-bounds.Min.Y)*w + (x - bounds.Min.X)
|
||||
|
||||
// Normalize [0, 255] -> [-1, 1]: 2 * (val/255) - 1
|
||||
pixelVals[rOff+idx] = float32(r>>8)/255.0*2.0 - 1.0
|
||||
pixelVals[gOff+idx] = float32(g>>8)/255.0*2.0 - 1.0
|
||||
pixelVals[bOff+idx] = float32(b>>8)/255.0*2.0 - 1.0
|
||||
}
|
||||
}
|
||||
|
||||
return pixelVals
|
||||
}
|
||||
102
model/models/gemma4/tokenizer_compare_test.go
Normal file
102
model/models/gemma4/tokenizer_compare_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// TestTokenizerMatchesHF compares our tokenizer output against HuggingFace reference tokens.
|
||||
func TestTokenizerMatchesHF(t *testing.T) {
|
||||
modelPath := os.Getenv("GEMMA4_MODEL_PATH")
|
||||
if modelPath == "" {
|
||||
t.Skip("set GEMMA4_MODEL_PATH to a gemma4 GGUF file")
|
||||
}
|
||||
|
||||
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load model: %v", err)
|
||||
}
|
||||
defer m.Backend().Close()
|
||||
|
||||
tok := m.(tokenizer.Tokenizer)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []int32
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
input: "Hello, world!",
|
||||
expected: []int32{9259, 236764, 1902, 236888},
|
||||
},
|
||||
{
|
||||
name: "special_tokens",
|
||||
input: "<|turn>user\nWhat is 2+2?<turn|>\n<|turn>model\n",
|
||||
expected: []int32{105, 2364, 107, 3689, 563, 236743, 236778, 236862, 236778, 236881, 106, 107, 105, 4368, 107},
|
||||
},
|
||||
{
|
||||
name: "tool_declaration",
|
||||
input: "<|tool>declaration:bash{description:<|\"|>Run a command<|\"|>}<tool|>",
|
||||
expected: []int32{46, 163688, 236787, 42422, 236782, 7777, 236787, 52, 7306, 496, 4991, 52, 236783, 47},
|
||||
},
|
||||
{
|
||||
name: "tool_call",
|
||||
input: "<|tool_call>call:bash{command:<|\"|>ls -la<|\"|>}<tool_call|>",
|
||||
expected: []int32{48, 6639, 236787, 42422, 236782, 7674, 236787, 52, 5629, 753, 2149, 52, 236783, 49},
|
||||
},
|
||||
{
|
||||
name: "thinking",
|
||||
input: "<|channel>thought\nLet me think about this...<channel|>The answer is 42.",
|
||||
expected: []int32{100, 45518, 107, 6481, 786, 1751, 1003, 672, 1390, 101, 818, 3890, 563, 236743, 236812, 236778, 236761},
|
||||
},
|
||||
{
|
||||
name: "code",
|
||||
input: "func main() { fmt.Println(\"hello\") }",
|
||||
expected: []int32{6823, 1689, 825, 642, 22766, 236761, 29006, 885, 23391, 1373, 682},
|
||||
},
|
||||
{
|
||||
name: "numbers",
|
||||
input: "The answer is 42, not 43.5 or -1",
|
||||
expected: []int32{818, 3890, 563, 236743, 236812, 236778, 236764, 711, 236743, 236812, 236800, 236761, 236810, 653, 753, 236770},
|
||||
},
|
||||
{
|
||||
name: "mixed_chat_with_tools",
|
||||
input: "<|turn>system\nYou are a helpful assistant.\n<|tool>declaration:get_weather{description:<|\"|>Get weather<|\"|>,parameters:{properties:{city:{type:<|\"|>STRING<|\"|>}},type:<|\"|>OBJECT<|\"|>}}<tool|><turn|>\n<|turn>user\nWhat's the weather in Paris?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: []int32{105, 9731, 107, 3048, 659, 496, 11045, 16326, 236761, 107, 46, 163688, 236787, 828, 236779, 19323, 236782, 7777, 236787, 52, 3407, 7606, 52, 236764, 19031, 29616, 15921, 29616, 13319, 29616, 2084, 236787, 52, 35410, 52, 5237, 2084, 236787, 52, 60688, 52, 1807, 47, 106, 107, 105, 2364, 107, 3689, 236789, 236751, 506, 7606, 528, 9079, 236881, 106, 107, 105, 4368, 107, 100, 45518, 107, 101},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokens, err := tok.Encode(tt.input, false) // no BOS
|
||||
if err != nil {
|
||||
t.Fatalf("encode error: %v", err)
|
||||
}
|
||||
|
||||
if len(tokens) != len(tt.expected) {
|
||||
t.Errorf("token count mismatch: got %d, want %d", len(tokens), len(tt.expected))
|
||||
t.Logf("got: %v", tokens)
|
||||
t.Logf("want: %v", tt.expected)
|
||||
return
|
||||
}
|
||||
|
||||
mismatches := 0
|
||||
for i := range tokens {
|
||||
if tokens[i] != tt.expected[i] {
|
||||
mismatches++
|
||||
if mismatches <= 5 {
|
||||
t.Errorf("mismatch at [%d]: got %d, want %d", i, tokens[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
if mismatches > 5 {
|
||||
t.Errorf("... and %d more mismatches", mismatches-5)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
_ "github.com/ollama/ollama/model/models/gemma4"
|
||||
_ "github.com/ollama/ollama/model/models/glm4moelite"
|
||||
_ "github.com/ollama/ollama/model/models/glmocr"
|
||||
_ "github.com/ollama/ollama/model/models/gptoss"
|
||||
|
||||
412
model/parsers/gemma4.go
Normal file
412
model/parsers/gemma4.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type Gemma4ParserState int
|
||||
|
||||
const (
|
||||
Gemma4CollectingContent Gemma4ParserState = iota
|
||||
Gemma4CollectingThinking
|
||||
Gemma4CollectingToolCall
|
||||
)
|
||||
|
||||
const (
|
||||
gemma4ThinkingOpenTag = "<|channel>"
|
||||
gemma4ThinkingCloseTag = "<channel|>"
|
||||
gemma4ToolCallOpenTag = "<|tool_call>"
|
||||
gemma4ToolCallCloseTag = "<tool_call|>"
|
||||
)
|
||||
|
||||
type Gemma4Parser struct {
|
||||
state Gemma4ParserState
|
||||
buffer strings.Builder
|
||||
hasThinkingSupport bool
|
||||
thinkingEnabled bool // true when both model supports and user requested thinking
|
||||
needsChannelNameStrip bool // true when we just entered thinking and need to strip "thought\n"
|
||||
}
|
||||
|
||||
func (p *Gemma4Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Gemma4Parser) HasThinkingSupport() bool {
|
||||
return p.hasThinkingSupport
|
||||
}
|
||||
|
||||
func (p *Gemma4Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
||||
|
||||
p.thinkingEnabled = p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
|
||||
|
||||
if !p.thinkingEnabled {
|
||||
p.state = Gemma4CollectingContent
|
||||
return tools
|
||||
}
|
||||
|
||||
if prefill && lastMessage.Content != "" {
|
||||
p.state = Gemma4CollectingContent
|
||||
return tools
|
||||
}
|
||||
|
||||
// When thinking is enabled, start in content mode but we'll switch to
|
||||
// thinking when we see <|channel>. The model typically starts with
|
||||
// <|channel> immediately when thinking is enabled.
|
||||
p.state = Gemma4CollectingContent
|
||||
return tools
|
||||
}
|
||||
|
||||
type gemma4Event interface {
|
||||
isGemma4Event()
|
||||
}
|
||||
|
||||
type gemma4EventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type gemma4EventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type gemma4EventToolCall struct {
|
||||
toolCall api.ToolCall
|
||||
}
|
||||
|
||||
func (gemma4EventThinkingContent) isGemma4Event() {}
|
||||
func (gemma4EventContent) isGemma4Event() {}
|
||||
func (gemma4EventToolCall) isGemma4Event() {}
|
||||
|
||||
func (p *Gemma4Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents(done)
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case gemma4EventToolCall:
|
||||
toolCalls = append(toolCalls, event.toolCall)
|
||||
case gemma4EventThinkingContent:
|
||||
if p.thinkingEnabled {
|
||||
thinkingSb.WriteString(event.content)
|
||||
}
|
||||
// When thinking is disabled, silently discard channel content
|
||||
case gemma4EventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *Gemma4Parser) parseEvents(done bool) []gemma4Event {
|
||||
var all []gemma4Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []gemma4Event
|
||||
events, keepLooping = p.eat(done)
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// longestOverlap returns the longest overlap between the suffix of bufStr and
|
||||
// a prefix of any of the given tags.
|
||||
func longestOverlap(bufStr string, tags ...string) int {
|
||||
maxOverlap := 0
|
||||
for _, tag := range tags {
|
||||
if o := overlap(bufStr, tag); o > maxOverlap {
|
||||
maxOverlap = o
|
||||
}
|
||||
}
|
||||
return maxOverlap
|
||||
}
|
||||
|
||||
func (p *Gemma4Parser) eat(done bool) ([]gemma4Event, bool) {
|
||||
var events []gemma4Event
|
||||
bufStr := p.buffer.String()
|
||||
if bufStr == "" {
|
||||
return events, false
|
||||
}
|
||||
|
||||
switch p.state {
|
||||
case Gemma4CollectingContent:
|
||||
// Check for thinking open tag
|
||||
if idx := strings.Index(bufStr, gemma4ThinkingOpenTag); idx != -1 {
|
||||
contentBefore := bufStr[:idx]
|
||||
remaining := bufStr[idx+len(gemma4ThinkingOpenTag):]
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = Gemma4CollectingThinking
|
||||
p.needsChannelNameStrip = true
|
||||
|
||||
if contentBefore = strings.TrimRightFunc(contentBefore, unicode.IsSpace); len(contentBefore) > 0 {
|
||||
events = append(events, gemma4EventContent{content: contentBefore})
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for tool call open tag
|
||||
if idx := strings.Index(bufStr, gemma4ToolCallOpenTag); idx != -1 {
|
||||
contentBefore := bufStr[:idx]
|
||||
remaining := bufStr[idx+len(gemma4ToolCallOpenTag):]
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = Gemma4CollectingToolCall
|
||||
|
||||
if contentBefore = strings.TrimRightFunc(contentBefore, unicode.IsSpace); len(contentBefore) > 0 {
|
||||
events = append(events, gemma4EventContent{content: contentBefore})
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for partial tag overlap
|
||||
if !done {
|
||||
if overlapLen := longestOverlap(bufStr, gemma4ThinkingOpenTag, gemma4ToolCallOpenTag); overlapLen > 0 {
|
||||
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
|
||||
trailingLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingLen
|
||||
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, gemma4EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
}
|
||||
|
||||
// No tags found, emit all content
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, gemma4EventContent{content: bufStr})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case Gemma4CollectingThinking:
|
||||
// Strip channel name (e.g., "thought\n") after <|channel>.
|
||||
// Gemma 4 format: <|channel>thought\n...content...<channel|>
|
||||
// In streaming mode, "thought" and "\n" may arrive in separate chunks.
|
||||
if p.needsChannelNameStrip {
|
||||
if strings.HasPrefix(bufStr, "thought\n") {
|
||||
bufStr = bufStr[len("thought\n"):]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr)
|
||||
p.needsChannelNameStrip = false
|
||||
} else if !done && (bufStr == "thought" || strings.HasPrefix("thought\n", bufStr)) {
|
||||
// Partial match — wait for more data.
|
||||
return events, false
|
||||
} else {
|
||||
// No match (different channel name or no newline) — don't strip.
|
||||
p.needsChannelNameStrip = false
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(bufStr, gemma4ThinkingCloseTag) {
|
||||
split := strings.SplitN(bufStr, gemma4ThinkingCloseTag, 2)
|
||||
thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = Gemma4CollectingContent
|
||||
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, gemma4EventThinkingContent{content: thinking})
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
|
||||
// Check for partial close tag
|
||||
if !done {
|
||||
if overlapLen := overlap(bufStr, gemma4ThinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
|
||||
trailingLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingLen
|
||||
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, gemma4EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
}
|
||||
|
||||
// No close tag, emit thinking content (hold back trailing whitespace)
|
||||
if !done {
|
||||
whitespaceLen := trailingWhitespaceLen(bufStr)
|
||||
ambiguousStart := len(bufStr) - whitespaceLen
|
||||
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, gemma4EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
} else {
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, gemma4EventThinkingContent{content: bufStr})
|
||||
}
|
||||
}
|
||||
return events, false
|
||||
|
||||
case Gemma4CollectingToolCall:
|
||||
if idx := strings.Index(bufStr, gemma4ToolCallCloseTag); idx != -1 {
|
||||
toolCallContent := bufStr[:idx]
|
||||
remaining := bufStr[idx+len(gemma4ToolCallCloseTag):]
|
||||
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = Gemma4CollectingContent
|
||||
|
||||
if toolCall, err := parseGemma4ToolCall(toolCallContent); err == nil {
|
||||
events = append(events, gemma4EventToolCall{toolCall: toolCall})
|
||||
} else {
|
||||
slog.Warn("gemma4 tool call parsing failed", "error", err, "content", toolCallContent)
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
|
||||
// If done, flush any accumulated tool call content even without closing tag.
|
||||
// The model may hit a stop token before emitting <tool_call|>.
|
||||
if done && len(bufStr) > 0 {
|
||||
p.buffer.Reset()
|
||||
p.state = Gemma4CollectingContent
|
||||
if toolCall, err := parseGemma4ToolCall(bufStr); err == nil {
|
||||
events = append(events, gemma4EventToolCall{toolCall: toolCall})
|
||||
} else {
|
||||
slog.Warn("gemma4 tool call flush on done failed", "error", err, "content", bufStr)
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
// Wait for closing tag
|
||||
return events, false
|
||||
}
|
||||
|
||||
return events, false
|
||||
}
|
||||
|
||||
// parseGemma4ToolCall parses a tool call in Gemma 4 format:
|
||||
// call:NAME{key:value,key:value}
|
||||
func parseGemma4ToolCall(content string) (api.ToolCall, error) {
|
||||
// Expected format: call:NAME{args}
|
||||
if !strings.HasPrefix(content, "call:") {
|
||||
return api.ToolCall{}, errors.New("expected 'call:' prefix")
|
||||
}
|
||||
content = content[len("call:"):]
|
||||
|
||||
// Find the opening brace for args
|
||||
braceIdx := strings.Index(content, "{")
|
||||
if braceIdx == -1 {
|
||||
return api.ToolCall{}, errors.New("expected '{' in tool call")
|
||||
}
|
||||
|
||||
toolName := strings.TrimSpace(content[:braceIdx])
|
||||
argsStr := content[braceIdx:]
|
||||
|
||||
// Convert Gemma 4 argument format to JSON
|
||||
jsonStr := gemma4ArgsToJSON(argsStr)
|
||||
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(jsonStr), &args); err != nil {
|
||||
return api.ToolCall{}, err
|
||||
}
|
||||
|
||||
return api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: toolName,
|
||||
Arguments: args,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// gemma4ArgsToJSON converts Gemma 4's custom argument format to valid JSON.
|
||||
func gemma4ArgsToJSON(s string) string {
|
||||
s = strings.ReplaceAll(s, `<|"|>`, `"`)
|
||||
|
||||
var buf strings.Builder
|
||||
buf.Grow(len(s) + 32)
|
||||
inString := false
|
||||
hex := "0123456789abcdef"
|
||||
i := 0
|
||||
for i < len(s) {
|
||||
ch := s[i]
|
||||
|
||||
if ch == '"' {
|
||||
inString = !inString
|
||||
buf.WriteByte('"')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if inString {
|
||||
switch ch {
|
||||
case '\\':
|
||||
buf.WriteString(`\\`)
|
||||
case '\n':
|
||||
buf.WriteString(`\n`)
|
||||
case '\r':
|
||||
buf.WriteString(`\r`)
|
||||
case '\t':
|
||||
buf.WriteString(`\t`)
|
||||
case '\b':
|
||||
buf.WriteString(`\b`)
|
||||
case '\f':
|
||||
buf.WriteString(`\f`)
|
||||
default:
|
||||
if ch < 0x20 {
|
||||
buf.WriteString(`\u00`)
|
||||
buf.WriteByte(hex[ch>>4])
|
||||
buf.WriteByte(hex[ch&0x0f])
|
||||
} else {
|
||||
buf.WriteByte(ch)
|
||||
}
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if !inString && isIdentStart(ch) {
|
||||
j := i + 1
|
||||
for j < len(s) && isIdentPart(s[j]) {
|
||||
j++
|
||||
}
|
||||
word := s[i:j]
|
||||
if j < len(s) && s[j] == ':' {
|
||||
buf.WriteByte('"')
|
||||
buf.WriteString(word)
|
||||
buf.WriteByte('"')
|
||||
} else {
|
||||
buf.WriteString(word)
|
||||
}
|
||||
i = j
|
||||
} else {
|
||||
buf.WriteByte(ch)
|
||||
i++
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
463
model/parsers/gemma4_test.go
Normal file
463
model/parsers/gemma4_test.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestGemma4Parser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedContent string
|
||||
expectedThinking string
|
||||
expectedToolCalls []api.ToolCall
|
||||
thinkingEnabled bool
|
||||
lastMessage *api.Message
|
||||
}{
|
||||
{
|
||||
name: "simple_content",
|
||||
input: "This is a simple response.",
|
||||
expectedContent: "This is a simple response.",
|
||||
},
|
||||
{
|
||||
name: "thinking_then_content",
|
||||
input: "<|channel>thought\nLet me think about this...<channel|>The answer is 42.",
|
||||
expectedContent: "The answer is 42.",
|
||||
expectedThinking: "Let me think about this...",
|
||||
thinkingEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "multiple_thinking_blocks",
|
||||
input: "<|channel>first thought<channel|><|channel>second thought<channel|>Final answer.",
|
||||
expectedContent: "Final answer.",
|
||||
expectedThinking: "first thoughtsecond thought",
|
||||
thinkingEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "thinking_only_no_content",
|
||||
input: "<|channel>just thinking<channel|>",
|
||||
expectedContent: "",
|
||||
expectedThinking: "just thinking",
|
||||
thinkingEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "tool_call_simple",
|
||||
input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`,
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_multiple_args",
|
||||
input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>,units:<|"|>metric<|"|>}<tool_call|>`,
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
"units": "metric",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_number_arg",
|
||||
input: `<|tool_call>call:set_temp{value:42}<tool_call|>`,
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temp",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"value": 42.0,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_boolean_arg",
|
||||
input: `<|tool_call>call:toggle{enabled:true}<tool_call|>`,
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "toggle",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"enabled": true,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_nested_object",
|
||||
input: `<|tool_call>call:process{config:{enabled:true,name:<|"|>test<|"|>}}<tool_call|>`,
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"config": map[string]any{
|
||||
"enabled": true,
|
||||
"name": "test",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_array",
|
||||
input: `<|tool_call>call:process{items:[<|"|>a<|"|>,<|"|>b<|"|>]}<tool_call|>`,
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"items": []any{"a", "b"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_multiline_string_arg",
|
||||
input: `<|tool_call>call:bash{command:<|"|>date
|
||||
<|"|>}<tool_call|>`,
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bash",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "date\n",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple_tool_calls",
|
||||
input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|><|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "London",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking_then_tool_call",
|
||||
input: "<|channel>thought\nI need to check the weather<channel|><|tool_call>call:get_weather{location:<|\"|>Paris<|\"|>}<tool_call|>",
|
||||
expectedThinking: "I need to check the weather",
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkingEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "content_then_tool_call",
|
||||
input: `Let me check that for you.<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`,
|
||||
expectedContent: "Let me check that for you.",
|
||||
expectedToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking_disabled_channel_tags_as_content",
|
||||
input: "<|channel>this is not thinking<channel|>actual content",
|
||||
expectedContent: "actual content",
|
||||
thinkingEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "prefill_content_only",
|
||||
input: "Continuing content.",
|
||||
expectedContent: "Continuing content.",
|
||||
lastMessage: &api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Previous content",
|
||||
},
|
||||
thinkingEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &Gemma4Parser{hasThinkingSupport: true}
|
||||
parser.Init(nil, tt.lastMessage, &api.ThinkValue{Value: tt.thinkingEnabled})
|
||||
|
||||
content, thinking, toolCalls, err := parser.Add(tt.input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error = %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
|
||||
t.Errorf("content mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls, argsComparer); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma4Parser_Streaming(t *testing.T) {
|
||||
parser := &Gemma4Parser{hasThinkingSupport: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
chunks := []string{
|
||||
"<|channel>thought",
|
||||
"\nLet me think",
|
||||
"...<channel|>The answer",
|
||||
" is 42.",
|
||||
}
|
||||
|
||||
var finalContent, finalThinking strings.Builder
|
||||
|
||||
for i, chunk := range chunks {
|
||||
done := i == len(chunks)-1
|
||||
content, thinking, _, err := parser.Add(chunk, done)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error on chunk %d: %v", i, err)
|
||||
}
|
||||
|
||||
finalContent.WriteString(content)
|
||||
finalThinking.WriteString(thinking)
|
||||
}
|
||||
|
||||
if finalContent.String() != "The answer is 42." {
|
||||
t.Errorf("expected content %q, got %q", "The answer is 42.", finalContent.String())
|
||||
}
|
||||
|
||||
if finalThinking.String() != "Let me think..." {
|
||||
t.Errorf("expected thinking %q, got %q", "Let me think...", finalThinking.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma4Parser_StreamingToolCall(t *testing.T) {
|
||||
parser := &Gemma4Parser{hasThinkingSupport: false}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
chunks := []string{
|
||||
`<|tool_call>call:get_`,
|
||||
`weather{location:<|"|>Par`,
|
||||
`is<|"|>}<tool_call|>`,
|
||||
}
|
||||
|
||||
var finalContent strings.Builder
|
||||
var finalToolCalls []api.ToolCall
|
||||
|
||||
for i, chunk := range chunks {
|
||||
done := i == len(chunks)-1
|
||||
content, _, toolCalls, err := parser.Add(chunk, done)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error on chunk %d: %v", i, err)
|
||||
}
|
||||
|
||||
finalContent.WriteString(content)
|
||||
finalToolCalls = append(finalToolCalls, toolCalls...)
|
||||
}
|
||||
|
||||
if finalContent.String() != "" {
|
||||
t.Errorf("expected no content, got %q", finalContent.String())
|
||||
}
|
||||
|
||||
expectedToolCalls := []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma4Parser_StreamingSplitThinkingTag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chunks []string
|
||||
expectedContent string
|
||||
expectedThinking string
|
||||
}{
|
||||
{
|
||||
name: "split_channel_open_tag",
|
||||
chunks: []string{
|
||||
"<|chan",
|
||||
"nel>thinking here<channel|>content",
|
||||
},
|
||||
expectedContent: "content",
|
||||
expectedThinking: "thinking here",
|
||||
},
|
||||
{
|
||||
name: "split_channel_close_tag",
|
||||
chunks: []string{
|
||||
"<|channel>thinking here<chan",
|
||||
"nel|>content",
|
||||
},
|
||||
expectedContent: "content",
|
||||
expectedThinking: "thinking here",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &Gemma4Parser{hasThinkingSupport: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
var finalContent, finalThinking strings.Builder
|
||||
for i, chunk := range tt.chunks {
|
||||
done := i == len(tt.chunks)-1
|
||||
content, thinking, _, err := parser.Add(chunk, done)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error on chunk %d: %v", i, err)
|
||||
}
|
||||
finalContent.WriteString(content)
|
||||
finalThinking.WriteString(thinking)
|
||||
}
|
||||
|
||||
if finalContent.String() != tt.expectedContent {
|
||||
t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String())
|
||||
}
|
||||
if finalThinking.String() != tt.expectedThinking {
|
||||
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma4ArgsToJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple_string",
|
||||
input: `{location:<|"|>Paris<|"|>}`,
|
||||
expected: `{"location":"Paris"}`,
|
||||
},
|
||||
{
|
||||
name: "multiple_args",
|
||||
input: `{location:<|"|>Paris<|"|>,units:<|"|>metric<|"|>}`,
|
||||
expected: `{"location":"Paris","units":"metric"}`,
|
||||
},
|
||||
{
|
||||
name: "number_value",
|
||||
input: `{value:42}`,
|
||||
expected: `{"value":42}`,
|
||||
},
|
||||
{
|
||||
name: "boolean_value",
|
||||
input: `{enabled:true}`,
|
||||
expected: `{"enabled":true}`,
|
||||
},
|
||||
{
|
||||
name: "nested_object",
|
||||
input: `{config:{enabled:true,name:<|"|>test<|"|>}}`,
|
||||
expected: `{"config":{"enabled":true,"name":"test"}}`,
|
||||
},
|
||||
{
|
||||
name: "array_value",
|
||||
input: `{items:[<|"|>a<|"|>,<|"|>b<|"|>]}`,
|
||||
expected: `{"items":["a","b"]}`,
|
||||
},
|
||||
{
|
||||
name: "empty_object",
|
||||
input: `{}`,
|
||||
expected: `{}`,
|
||||
},
|
||||
{
|
||||
name: "mixed_types",
|
||||
input: `{name:<|"|>test<|"|>,count:5,active:true,tags:[<|"|>a<|"|>]}`,
|
||||
expected: `{"name":"test","count":5,"active":true,"tags":["a"]}`,
|
||||
},
|
||||
{
|
||||
name: "null_value",
|
||||
input: `{value:null}`,
|
||||
expected: `{"value":null}`,
|
||||
},
|
||||
{
|
||||
name: "multiline_string_value",
|
||||
input: `{command:<|"|>date
|
||||
<|"|>}`,
|
||||
expected: `{"command":"date\n"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := gemma4ArgsToJSON(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma4Parser_HasToolSupport(t *testing.T) {
|
||||
parser := &Gemma4Parser{}
|
||||
if !parser.HasToolSupport() {
|
||||
t.Error("Gemma4Parser should support tools")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma4Parser_HasThinkingSupport(t *testing.T) {
|
||||
parser := &Gemma4Parser{hasThinkingSupport: true}
|
||||
if !parser.HasThinkingSupport() {
|
||||
t.Error("Gemma4Parser with thinking support should report it")
|
||||
}
|
||||
|
||||
parser2 := &Gemma4Parser{hasThinkingSupport: false}
|
||||
if parser2.HasThinkingSupport() {
|
||||
t.Error("Gemma4Parser without thinking support should not report it")
|
||||
}
|
||||
}
|
||||
@@ -77,6 +77,10 @@ func ParserForName(name string) Parser {
|
||||
return &FunctionGemmaParser{}
|
||||
case "glm-4.7":
|
||||
return &GLM47Parser{}
|
||||
case "gemma4":
|
||||
return &Gemma4Parser{hasThinkingSupport: true}
|
||||
case "gemma4-no-thinking":
|
||||
return &Gemma4Parser{hasThinkingSupport: false}
|
||||
case "glm-ocr":
|
||||
return &GlmOcrParser{}
|
||||
case "lfm2":
|
||||
|
||||
379
model/renderers/gemma4.go
Normal file
379
model/renderers/gemma4.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Gemma4Renderer renders prompts using Gemma 4's chat format with
|
||||
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
|
||||
// <|tool_call>/<|tool_response> tags for function calling.
|
||||
type Gemma4Renderer struct {
|
||||
useImgTags bool
|
||||
}
|
||||
|
||||
const (
|
||||
g4Q = `<|"|>` // Gemma 4 string delimiter
|
||||
)
|
||||
|
||||
func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
imageOffset := 0
|
||||
|
||||
// BOS token — Gemma 4 models have add_bos_token=false in their tokenizer
|
||||
// config, so the tokenizer does not auto-prepend BOS. We must emit it
|
||||
// explicitly in the rendered prompt, matching the HF chat template.
|
||||
sb.WriteString("<bos>")
|
||||
// Extract system message if present.
|
||||
var systemMessage string
|
||||
var loopMessages []api.Message
|
||||
hasSystemRole := len(messages) > 0 && (messages[0].Role == "system" || messages[0].Role == "developer")
|
||||
if hasSystemRole {
|
||||
systemMessage = messages[0].Content
|
||||
loopMessages = messages[1:]
|
||||
} else {
|
||||
loopMessages = messages
|
||||
}
|
||||
|
||||
// Emit system turn if there's a system/developer role, tools, or thinking.
|
||||
hasThink := thinkValue != nil && thinkValue.Bool()
|
||||
if hasSystemRole || len(tools) > 0 || hasThink {
|
||||
sb.WriteString("<|turn>system\n")
|
||||
if hasThink {
|
||||
sb.WriteString("<|think|>")
|
||||
}
|
||||
if systemMessage != "" {
|
||||
sb.WriteString(strings.TrimSpace(systemMessage))
|
||||
}
|
||||
for _, tool := range tools {
|
||||
sb.WriteString(r.renderToolDeclaration(tool))
|
||||
}
|
||||
sb.WriteString("<turn|>\n")
|
||||
}
|
||||
|
||||
// Each message gets its own <|turn>role\n ... <turn|>\n block,
|
||||
// matching the HF chat template exactly.
|
||||
for _, message := range loopMessages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|turn>user\n")
|
||||
r.renderContent(&sb, message, &imageOffset, true)
|
||||
sb.WriteString("<turn|>\n")
|
||||
|
||||
case "assistant":
|
||||
sb.WriteString("<|turn>model\n")
|
||||
// Tool calls come before content (matching HF template order)
|
||||
for _, tc := range message.ToolCalls {
|
||||
sb.WriteString(r.formatToolCall(tc))
|
||||
}
|
||||
// Strip thinking from history (matching HF strip_thinking macro)
|
||||
if message.Content != "" {
|
||||
sb.WriteString(stripThinking(message.Content))
|
||||
}
|
||||
sb.WriteString("<turn|>\n")
|
||||
|
||||
case "tool":
|
||||
sb.WriteString("<|turn>tool\n")
|
||||
sb.WriteString(strings.TrimSpace(message.Content))
|
||||
sb.WriteString("<turn|>\n")
|
||||
|
||||
default:
|
||||
sb.WriteString("<|turn>" + message.Role + "\n")
|
||||
sb.WriteString(strings.TrimSpace(message.Content))
|
||||
sb.WriteString("<turn|>\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Generation prompt
|
||||
sb.WriteString("<|turn>model\n")
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// stripThinking removes <|channel>...<channel|> thinking blocks from content,
|
||||
// matching the HF chat template's strip_thinking macro.
|
||||
func stripThinking(text string) string {
|
||||
var result strings.Builder
|
||||
for {
|
||||
start := strings.Index(text, "<|channel>")
|
||||
if start == -1 {
|
||||
result.WriteString(text)
|
||||
break
|
||||
}
|
||||
result.WriteString(text[:start])
|
||||
end := strings.Index(text[start:], "<channel|>")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
text = text[start+end+len("<channel|>"):]
|
||||
}
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
|
||||
// renderContent writes a message's content, interleaving [img-N] tags for images.
|
||||
// When trim is true, leading/trailing whitespace is stripped (matching the Jinja2
|
||||
// template's | trim filter applied to non-model content).
|
||||
func (r *Gemma4Renderer) renderContent(sb *strings.Builder, msg api.Message, imageOffset *int, trim bool) {
|
||||
if len(msg.Images) > 0 && r.useImgTags {
|
||||
for range msg.Images {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", *imageOffset))
|
||||
*imageOffset++
|
||||
}
|
||||
}
|
||||
content := msg.Content
|
||||
if trim {
|
||||
content = strings.TrimSpace(content)
|
||||
}
|
||||
sb.WriteString(content)
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) renderToolDeclaration(tool api.Tool) string {
|
||||
var sb strings.Builder
|
||||
fn := tool.Function
|
||||
|
||||
sb.WriteString("<|tool>declaration:" + fn.Name + "{")
|
||||
sb.WriteString("description:" + g4Q + fn.Description + g4Q)
|
||||
|
||||
if fn.Parameters.Properties != nil || fn.Parameters.Type != "" {
|
||||
sb.WriteString(",parameters:{")
|
||||
|
||||
needsComma := false
|
||||
|
||||
if fn.Parameters.Properties != nil && fn.Parameters.Properties.Len() > 0 {
|
||||
sb.WriteString("properties:{")
|
||||
r.writeProperties(&sb, fn.Parameters.Properties)
|
||||
sb.WriteString("}")
|
||||
needsComma = true
|
||||
}
|
||||
|
||||
if len(fn.Parameters.Required) > 0 {
|
||||
if needsComma {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("required:[")
|
||||
for i, req := range fn.Parameters.Required {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(g4Q + req + g4Q)
|
||||
}
|
||||
sb.WriteString("]")
|
||||
needsComma = true
|
||||
}
|
||||
|
||||
if fn.Parameters.Type != "" {
|
||||
if needsComma {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("type:" + g4Q + strings.ToUpper(fn.Parameters.Type) + g4Q)
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
}
|
||||
|
||||
sb.WriteString("}<tool|>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) writeProperties(sb *strings.Builder, props *api.ToolPropertiesMap) {
|
||||
keys := make([]string, 0, props.Len())
|
||||
for k := range props.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, name := range keys {
|
||||
prop, _ := props.Get(name)
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
|
||||
sb.WriteString(name + ":{")
|
||||
|
||||
hasContent := false
|
||||
if prop.Description != "" {
|
||||
sb.WriteString("description:" + g4Q + prop.Description + g4Q)
|
||||
hasContent = true
|
||||
}
|
||||
|
||||
if len(prop.Type) > 0 {
|
||||
typeName := strings.ToUpper(prop.Type[0])
|
||||
|
||||
switch typeName {
|
||||
case "STRING":
|
||||
if len(prop.Enum) > 0 {
|
||||
if hasContent {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("enum:[")
|
||||
for j, e := range prop.Enum {
|
||||
if j > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(g4Q + fmt.Sprintf("%v", e) + g4Q)
|
||||
}
|
||||
sb.WriteString("]")
|
||||
hasContent = true
|
||||
}
|
||||
|
||||
case "OBJECT":
|
||||
// Render nested properties recursively.
|
||||
// Note: the leading comma is hardcoded (matching the template),
|
||||
// and this does NOT set hasContent — the comma before type:
|
||||
// depends only on whether description was present.
|
||||
sb.WriteString(",properties:{")
|
||||
if prop.Properties != nil && prop.Properties.Len() > 0 {
|
||||
r.writeProperties(sb, prop.Properties)
|
||||
}
|
||||
sb.WriteString("}")
|
||||
if len(prop.Required) > 0 {
|
||||
sb.WriteString(",required:[")
|
||||
for j, req := range prop.Required {
|
||||
if j > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(g4Q + req + g4Q)
|
||||
}
|
||||
sb.WriteString("]")
|
||||
}
|
||||
|
||||
case "ARRAY":
|
||||
// Render items specification.
|
||||
// Same as OBJECT: leading comma is hardcoded, does NOT set hasContent.
|
||||
if items, ok := prop.Items.(map[string]any); ok && len(items) > 0 {
|
||||
sb.WriteString(",items:{")
|
||||
r.writeItemsSpec(sb, items)
|
||||
sb.WriteString("}")
|
||||
}
|
||||
}
|
||||
|
||||
if hasContent {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("type:" + g4Q + typeName + g4Q)
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
}
|
||||
}
|
||||
|
||||
// writeItemsSpec renders the items specification for array-type properties,
|
||||
// matching the Jinja2 template's dictsort iteration over items.
|
||||
func (r *Gemma4Renderer) writeItemsSpec(sb *strings.Builder, items map[string]any) {
|
||||
keys := make([]string, 0, len(items))
|
||||
for k := range items {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
value := items[key]
|
||||
if value == nil {
|
||||
continue
|
||||
}
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
|
||||
switch key {
|
||||
case "type":
|
||||
if s, ok := value.(string); ok {
|
||||
sb.WriteString("type:" + g4Q + strings.ToUpper(s) + g4Q)
|
||||
}
|
||||
default:
|
||||
sb.WriteString(key + ":" + r.formatArgValue(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) formatToolCall(tc api.ToolCall) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<|tool_call>call:" + tc.Function.Name + "{")
|
||||
|
||||
keys := make([]string, 0, tc.Function.Arguments.Len())
|
||||
for k := range tc.Function.Arguments.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
value, _ := tc.Function.Arguments.Get(key)
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(value))
|
||||
}
|
||||
|
||||
sb.WriteString("}<tool_call|>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) formatArgValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return g4Q + v + g4Q
|
||||
case bool:
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
case float64:
|
||||
if v == float64(int64(v)) {
|
||||
return fmt.Sprintf("%d", int64(v))
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
case int, int64, int32:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case map[string]any:
|
||||
return r.formatMapValue(v)
|
||||
case []any:
|
||||
return r.formatArrayValue(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) formatMapValue(m map[string]any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("{")
|
||||
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(m[key]))
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) formatArrayValue(arr []any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[")
|
||||
for i, item := range arr {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(r.formatArgValue(item))
|
||||
}
|
||||
sb.WriteString("]")
|
||||
return sb.String()
|
||||
}
|
||||
1274
model/renderers/gemma4_reference_test.go
Normal file
1274
model/renderers/gemma4_reference_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -81,6 +81,8 @@ func rendererForName(name string) Renderer {
|
||||
return renderer
|
||||
case "nemotron-3-nano":
|
||||
return &Nemotron3NanoRenderer{}
|
||||
case "gemma4":
|
||||
return &Gemma4Renderer{useImgTags: RenderImgTags}
|
||||
case "functiongemma":
|
||||
return &FunctionGemmaRenderer{}
|
||||
case "glm-4.7":
|
||||
|
||||
263
model/renderers/testdata/gemma4_chat_template.jinja2
vendored
Normal file
263
model/renderers/testdata/gemma4_chat_template.jinja2
vendored
Normal file
@@ -0,0 +1,263 @@
|
||||
{%- macro format_parameters(properties, required) -%}
|
||||
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in properties | dictsort -%}
|
||||
{%- set add_comma = false -%}
|
||||
{%- if key not in standard_keys -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{ key }}:{
|
||||
{%- if value['description'] -%}
|
||||
description:<|"|>{{ value['description'] }}<|"|>
|
||||
{%- set add_comma = true -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'OBJECT' -%}
|
||||
,properties:{
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
{%- elif value is mapping -%}
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- if value['required'] -%}
|
||||
,required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
,items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'] | dictsort -%}
|
||||
{%- if item_value is not none -%}
|
||||
{%- if ns_items.found_first %},{% endif -%}
|
||||
{%- set ns_items.found_first = true -%}
|
||||
{%- if item_key == 'properties' -%}
|
||||
properties:{
|
||||
{%- if item_value is mapping -%}
|
||||
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- elif item_key == 'required' -%}
|
||||
required:[
|
||||
{%- for req_item in item_value -%}
|
||||
<|"|>{{- req_item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- elif item_key == 'type' -%}
|
||||
{%- if item_value is string -%}
|
||||
type:{{ format_argument(item_value | upper) }}
|
||||
{%- else -%}
|
||||
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ item_key }}:{{ format_argument(item_value) }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
type:<|"|>{{ value['type'] | upper }}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_function_declaration(tool_data) -%}
|
||||
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
|
||||
{%- set params = tool_data['function']['parameters'] -%}
|
||||
{%- if params -%}
|
||||
,parameters:{
|
||||
{%- if params['properties'] -%}
|
||||
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||
{%- endif -%}
|
||||
{%- if params['required'] -%}
|
||||
required:[
|
||||
{%- for item in params['required'] -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{{- ',' if not loop.last -}}
|
||||
{%- endfor -%}
|
||||
],
|
||||
{%- endif -%}
|
||||
{%- if params['type'] -%}
|
||||
type:<|"|>{{- params['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if 'response' in tool_data['function'] -%}
|
||||
{%- set response_declaration = tool_data['function']['response'] -%}
|
||||
,response:{
|
||||
{%- if response_declaration['description'] -%}
|
||||
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
|
||||
{%- endif -%}
|
||||
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
|
||||
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_argument(argument, escape_keys=True) -%}
|
||||
{%- if argument is string -%}
|
||||
{{- '<|"|>' + argument + '<|"|>' -}}
|
||||
{%- elif argument is boolean -%}
|
||||
{{- 'true' if argument else 'false' -}}
|
||||
{%- elif argument is mapping -%}
|
||||
{{- '{' -}}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in argument | dictsort -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{%- if escape_keys -%}
|
||||
{{- '<|"|>' + key + '<|"|>' -}}
|
||||
{%- else -%}
|
||||
{{- key -}}
|
||||
{%- endif -%}
|
||||
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- elif argument is sequence -%}
|
||||
{{- '[' -}}
|
||||
{%- for item in argument -%}
|
||||
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']' -}}
|
||||
{%- else -%}
|
||||
{{- argument -}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro strip_thinking(text) -%}
|
||||
{%- set ns = namespace(result='') -%}
|
||||
{%- for part in text.split('<channel|>') -%}
|
||||
{%- if '<|channel>' in part -%}
|
||||
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
|
||||
{%- else -%}
|
||||
{%- set ns.result = ns.result + part -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{ bos_token }}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- messages[0]['content'] | trim -}}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if tools -%}
|
||||
{%- for tool in tools %}
|
||||
{{- '<|tool>' -}}
|
||||
{{- format_function_declaration(tool) | trim -}}
|
||||
{{- '<tool|>' -}}
|
||||
{%- endfor %}
|
||||
{%- set ns.prev_message_type = 'tool' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif %}
|
||||
|
||||
{#- Loop through messages -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
|
||||
{{- '<|turn>' + role + '\n' }}
|
||||
|
||||
{%- if message['tool_calls'] -%}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- set function = tool_call['function'] -%}
|
||||
{{- '<|tool_call>call:' + function['name'] + '{' -}}
|
||||
{%- if function['arguments'] is mapping -%}
|
||||
{%- set ns_args = namespace(found_first=false) -%}
|
||||
{%- for key, value in function['arguments'] | dictsort -%}
|
||||
{%- if ns_args.found_first %},{% endif -%}
|
||||
{%- set ns_args.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{%- elif function['arguments'] is string -%}
|
||||
{{- function['arguments'] -}}
|
||||
{%- endif -%}
|
||||
{{- '}<tool_call|>' -}}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_responses'] -%}
|
||||
{#- Tool Response handling -#}
|
||||
{%- for tool_response in message['tool_responses'] -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if tool_response['response'] is mapping -%}
|
||||
{{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}
|
||||
{%- for key, value in tool_response['response'] | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['content'] is string -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(message['content']) -}}
|
||||
{%- else -%}
|
||||
{{- message['content'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is sequence -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'text' -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(item['text']) -}}
|
||||
{%- else -%}
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '\n\n<|image|>\n\n' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '\n\n<|video|>\n\n' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if not (message['tool_responses'] and not message['content']) -%}
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||
{{- '<|turn>model\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -522,6 +522,20 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
||||
case "input_audio":
|
||||
audioMap, ok := data["input_audio"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid input_audio format")
|
||||
}
|
||||
b64Data, ok := audioMap["data"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid input_audio format: missing data")
|
||||
}
|
||||
audioBytes, err := base64.StdEncoding.DecodeString(b64Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input_audio base64 data: %w", err)
|
||||
}
|
||||
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{audioBytes}})
|
||||
default:
|
||||
return nil, errors.New("invalid message format")
|
||||
}
|
||||
@@ -824,6 +838,45 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
|
||||
}
|
||||
}
|
||||
|
||||
// TranscriptionResponse is the response format for /v1/audio/transcriptions.
|
||||
type TranscriptionResponse struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// TranscriptionRequest holds parsed fields from the multipart form.
|
||||
type TranscriptionRequest struct {
|
||||
Model string
|
||||
AudioData []byte
|
||||
ResponseFormat string // "json", "text", "verbose_json"
|
||||
Language string
|
||||
Prompt string
|
||||
}
|
||||
|
||||
// FromTranscriptionRequest converts a transcription request into a ChatRequest
|
||||
// by wrapping the audio with a system prompt for transcription.
|
||||
func FromTranscriptionRequest(r TranscriptionRequest) (*api.ChatRequest, error) {
|
||||
systemPrompt := "Transcribe the following audio exactly as spoken. Output only the transcription text, nothing else."
|
||||
if r.Language != "" {
|
||||
systemPrompt += " The audio is in " + r.Language + "."
|
||||
}
|
||||
if r.Prompt != "" {
|
||||
systemPrompt += " Context: " + r.Prompt
|
||||
}
|
||||
|
||||
stream := true
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: "Transcribe this audio.", Images: []api.ImageData{r.AudioData}},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImageEditRequest is an OpenAI-compatible image edit request.
|
||||
type ImageEditRequest struct {
|
||||
Model string `json:"model"`
|
||||
|
||||
@@ -390,3 +390,48 @@ func (t *Terminal) Read() (rune, error) {
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// SetRawModeOn enables raw terminal mode and keeps it on.
|
||||
// Call SetRawModeOff to restore when done.
|
||||
func (i *Instance) SetRawModeOn() error {
|
||||
if i.Terminal.rawmode {
|
||||
return nil
|
||||
}
|
||||
fd := os.Stdin.Fd()
|
||||
termios, err := SetRawMode(fd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.Terminal.rawmode = true
|
||||
i.Terminal.termios = termios
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRawModeOff restores the terminal to its previous mode.
|
||||
func (i *Instance) SetRawModeOff() {
|
||||
if !i.Terminal.rawmode {
|
||||
return
|
||||
}
|
||||
fd := os.Stdin.Fd()
|
||||
//nolint:errcheck
|
||||
UnsetRawMode(fd, i.Terminal.termios)
|
||||
i.Terminal.rawmode = false
|
||||
}
|
||||
|
||||
// ReadRaw reads a single rune. If the terminal is already in raw mode
|
||||
// (via SetRawModeOn), it reads directly. Otherwise it temporarily enters
|
||||
// raw mode for the read.
|
||||
func (i *Instance) ReadRaw() (rune, error) {
|
||||
if !i.Terminal.rawmode {
|
||||
fd := os.Stdin.Fd()
|
||||
termios, err := SetRawMode(fd)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
//nolint:errcheck
|
||||
UnsetRawMode(fd, termios)
|
||||
}()
|
||||
}
|
||||
return i.Terminal.Read()
|
||||
}
|
||||
|
||||
@@ -1258,6 +1258,12 @@ func (s *Server) loadModel() {
|
||||
panic(fmt.Errorf("failed to load model: %v", err))
|
||||
}
|
||||
|
||||
if postLoader, ok := s.model.(model.PostLoader); ok {
|
||||
if err := postLoader.PostLoad(); err != nil {
|
||||
panic(fmt.Errorf("failed to finalize model initialization: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
s.status = llm.ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "" || len(config.Capabilities) == 0) {
|
||||
mf, mErr := manifest.ParseNamedManifest(fromName)
|
||||
if mErr == nil && mf.Config.Digest != "" {
|
||||
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
|
||||
@@ -158,6 +158,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
if config.Requires == "" {
|
||||
config.Requires = baseConfig.Requires
|
||||
}
|
||||
if len(config.Capabilities) == 0 {
|
||||
config.Capabilities = baseConfig.Capabilities
|
||||
}
|
||||
}
|
||||
cfgFile.Close()
|
||||
}
|
||||
@@ -509,6 +512,24 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount()))
|
||||
config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String())
|
||||
config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture())
|
||||
|
||||
// Auto-detect renderer, parser, and stop tokens from GGUF architecture.
|
||||
// TODO: abstract this into a registry/lookup table when multiple models
|
||||
// need architecture-based renderer/parser/stop defaults.
|
||||
if config.Renderer == "" || config.Parser == "" {
|
||||
arch := layer.GGML.KV().Architecture()
|
||||
switch arch {
|
||||
case "gemma4":
|
||||
config.Renderer = cmp.Or(config.Renderer, "gemma4")
|
||||
config.Parser = cmp.Or(config.Parser, "gemma4")
|
||||
if _, ok := r.Parameters["stop"]; !ok {
|
||||
if r.Parameters == nil {
|
||||
r.Parameters = make(map[string]any)
|
||||
}
|
||||
r.Parameters["stop"] = []string{"<turn|>"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
layers = append(layers, layer.Layer)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ var (
|
||||
errCapabilityTools = errors.New("tools")
|
||||
errCapabilityInsert = errors.New("insert")
|
||||
errCapabilityVision = errors.New("vision")
|
||||
errCapabilityAudio = errors.New("audio")
|
||||
errCapabilityEmbedding = errors.New("embedding")
|
||||
errCapabilityThinking = errors.New("thinking")
|
||||
errCapabilityImage = errors.New("image generation")
|
||||
@@ -93,14 +94,26 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
if f.KeyValue("vision.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
if f.KeyValue("audio.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityAudio)
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
}
|
||||
} else if len(m.Config.Capabilities) > 0 {
|
||||
}
|
||||
|
||||
// Also include capabilities from the model config (e.g. vision capability
|
||||
// set during creation for MLX/safetensors models).
|
||||
if len(m.Config.Capabilities) > 0 {
|
||||
for _, c := range m.Config.Capabilities {
|
||||
capabilities = append(capabilities, model.Capability(c))
|
||||
cap := model.Capability(c)
|
||||
if !slices.Contains(capabilities, cap) {
|
||||
capabilities = append(capabilities, cap)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
}
|
||||
|
||||
if len(capabilities) == 0 {
|
||||
slog.Warn("unknown capabilities for model", "model", m.Name)
|
||||
}
|
||||
|
||||
@@ -141,6 +154,14 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities = append(capabilities, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
// Temporary workaround — suppress vision/audio for gemma4 MLX models
|
||||
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
|
||||
if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" {
|
||||
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
|
||||
return c == model.CapabilityVision || c == "audio"
|
||||
})
|
||||
}
|
||||
|
||||
return capabilities
|
||||
}
|
||||
|
||||
@@ -156,6 +177,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
model.CapabilityTools: errCapabilityTools,
|
||||
model.CapabilityInsert: errCapabilityInsert,
|
||||
model.CapabilityVision: errCapabilityVision,
|
||||
model.CapabilityAudio: errCapabilityAudio,
|
||||
model.CapabilityEmbedding: errCapabilityEmbedding,
|
||||
model.CapabilityThinking: errCapabilityThinking,
|
||||
model.CapabilityImage: errCapabilityImage,
|
||||
|
||||
@@ -153,7 +153,16 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
||||
// MLA tensors need higher precision to avoid quality degradation
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
} else if strings.Contains(name, "ffn_down") {
|
||||
iLayer := qs.iFfnDown
|
||||
// For MoE models, ffn_down.weight (dense) and ffn_down_exps.weight (expert) both
|
||||
// exist per layer and should get the same useMoreBits treatment. Dense sorts before
|
||||
// expert alphabetically, so dense increments the counter and expert uses counter-1.
|
||||
var iLayer int
|
||||
if strings.Contains(name, "_exps") {
|
||||
iLayer = max(0, qs.iFfnDown-1)
|
||||
} else {
|
||||
iLayer = qs.iFfnDown
|
||||
qs.iFfnDown++
|
||||
}
|
||||
n_layer := qs.nFfnDown
|
||||
if ftype == fsggml.FileTypeQ4_K_M {
|
||||
if useMoreBits(iLayer, n_layer) {
|
||||
@@ -162,7 +171,6 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
||||
} else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
qs.iFfnDown++
|
||||
} else if strings.Contains(name, "attn_output.weight") {
|
||||
if nExperts == 8 {
|
||||
if ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
|
||||
@@ -255,8 +263,9 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
name := t.Name
|
||||
quantize := strings.HasSuffix(name, "weight")
|
||||
|
||||
// don't quantize vision encoder tensors (named with "v." prefix)
|
||||
// don't quantize vision or audio encoder tensors
|
||||
quantize = quantize && !strings.HasPrefix(name, "v.")
|
||||
quantize = quantize && !strings.HasPrefix(name, "a.")
|
||||
quantize = quantize && !strings.Contains(name, "mm.")
|
||||
|
||||
// quantize only 2D and 3D tensors (experts)
|
||||
|
||||
@@ -1718,6 +1718,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
// OpenAI-compatible audio endpoint
|
||||
r.POST("/v1/audio/transcriptions", middleware.TranscriptionMiddleware(), s.ChatHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
|
||||
|
||||
@@ -10,6 +10,7 @@ const (
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityImage = Capability("image")
|
||||
CapabilityAudio = Capability("audio")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
||||
Reference in New Issue
Block a user