mirror of
https://github.com/ollama/ollama.git
synced 2026-04-25 02:06:11 +02:00
Compare commits
23 Commits
jessegross
...
v0.20.0-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d846fdbc0 | ||
|
|
f3536a356e | ||
|
|
c89280fb0c | ||
|
|
eb5434d7fb | ||
|
|
2b949a11d9 | ||
|
|
6b013002fc | ||
|
|
5e622289c5 | ||
|
|
9c8bcecdb2 | ||
|
|
1cbe7950d6 | ||
|
|
95073400fc | ||
|
|
c29932c631 | ||
|
|
1ce101c9a0 | ||
|
|
5a7928ed38 | ||
|
|
7fdc051091 | ||
|
|
5bad871241 | ||
|
|
82437d620a | ||
|
|
570c53859d | ||
|
|
ebd70f73b7 | ||
|
|
eb5df80733 | ||
|
|
356c0b8e34 | ||
|
|
ea3c6a3cbe | ||
|
|
f6b69f3f28 | ||
|
|
e38b606e8b |
@@ -436,6 +436,7 @@ type ToolProperty struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Properties *ToolPropertiesMap `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||
|
||||
216
cmd/audio.go
Normal file
216
cmd/audio.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Capture parameters shared by every platform backend.
const (
	audioSampleRate = 16000 // Hz; mono 16 kHz capture
	audioChannels   = 1     // mono
	audioFrameSize  = 1024  // samples delivered per capture callback
)
|
||||
|
||||
// AudioRecorder captures audio from the default microphone.
// Platform-specific capture is provided by audioStream (audio_darwin.go, etc.).
// mu guards samples and started; MaxChunkSeconds is expected to be set
// before Start.
type AudioRecorder struct {
	stream audioStream // platform capture backend

	mu      sync.Mutex
	samples []float32 // accumulated samples at audioSampleRate
	started time.Time // when Start was called; zero if never started

	MaxChunkSeconds int // hard split limit in seconds; 0 means use default
}
|
||||
|
||||
// audioStream is the platform-specific audio capture interface.
type audioStream interface {
	// Start begins capturing. Sample batches (float32 converted from
	// 16-bit PCM) are delivered via the callback from a backend-owned
	// goroutine.
	Start(callback func(samples []float32)) error
	// Stop ends capturing and releases platform resources.
	Stop() error
}
|
||||
|
||||
// NewAudioRecorder creates a recorder ready to capture from the default mic.
|
||||
func NewAudioRecorder() (*AudioRecorder, error) {
|
||||
stream, err := newAudioStream(audioSampleRate, audioChannels, audioFrameSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AudioRecorder{stream: stream}, nil
|
||||
}
|
||||
|
||||
// Start begins capturing audio from the microphone.
|
||||
func (r *AudioRecorder) Start() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.samples = make([]float32, 0, audioSampleRate*30) // preallocate ~30s
|
||||
r.started = time.Now()
|
||||
|
||||
return r.stream.Start(func(samples []float32) {
|
||||
r.mu.Lock()
|
||||
r.samples = append(r.samples, samples...)
|
||||
r.mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
// Stop ends the recording and returns the duration.
|
||||
func (r *AudioRecorder) Stop() (time.Duration, error) {
|
||||
r.mu.Lock()
|
||||
dur := time.Since(r.started)
|
||||
r.mu.Unlock()
|
||||
|
||||
if r.stream != nil {
|
||||
r.stream.Stop()
|
||||
}
|
||||
|
||||
return dur, nil
|
||||
}
|
||||
|
||||
// Duration returns how long the current recording has been running.
|
||||
func (r *AudioRecorder) Duration() time.Duration {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.started.IsZero() {
|
||||
return 0
|
||||
}
|
||||
return time.Since(r.started)
|
||||
}
|
||||
|
||||
// Chunking constants for live transcription.
const (
	chunkTargetSamples     = 8 * audioSampleRate // 8s — start yielding when silence found
	chunkMinSamples        = 5 * audioSampleRate // start scanning for silence at 5s
	defaultMaxAudioSeconds = 28                  // default hard split (just under typical 30s model cap)
	silenceWindow          = 800                 // energy window: 800 samples / 16 kHz = 50ms
)
|
||||
|
||||
func (r *AudioRecorder) maxChunk() int {
|
||||
if r.MaxChunkSeconds > 0 {
|
||||
return r.MaxChunkSeconds * audioSampleRate
|
||||
}
|
||||
return defaultMaxAudioSeconds * audioSampleRate
|
||||
}
|
||||
|
||||
// TakeChunk checks if there are enough accumulated samples to yield a chunk.
|
||||
// If so, it splits at the best silence boundary, removes the consumed samples
|
||||
// from the buffer, and returns the chunk as WAV bytes. Returns nil if not enough
|
||||
// audio has accumulated yet.
|
||||
func (r *AudioRecorder) TakeChunk() []byte {
|
||||
r.mu.Lock()
|
||||
n := len(r.samples)
|
||||
if n < chunkMinSamples {
|
||||
r.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
maxSamples := r.maxChunk()
|
||||
|
||||
if n < chunkTargetSamples && n < maxSamples {
|
||||
r.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
limit := n
|
||||
if limit > maxSamples {
|
||||
limit = maxSamples
|
||||
}
|
||||
|
||||
splitAt := limit
|
||||
bestEnergy := float64(1e30)
|
||||
|
||||
scanStart := limit - silenceWindow
|
||||
scanEnd := chunkMinSamples
|
||||
for pos := scanStart; pos >= scanEnd; pos -= silenceWindow / 2 {
|
||||
end := pos + silenceWindow
|
||||
if end > n {
|
||||
end = n
|
||||
}
|
||||
var sumSq float64
|
||||
for _, s := range r.samples[pos:end] {
|
||||
sumSq += float64(s) * float64(s)
|
||||
}
|
||||
rms := sumSq / float64(end-pos)
|
||||
if rms < bestEnergy {
|
||||
bestEnergy = rms
|
||||
splitAt = pos + silenceWindow/2
|
||||
}
|
||||
}
|
||||
|
||||
chunk := make([]float32, splitAt)
|
||||
copy(chunk, r.samples[:splitAt])
|
||||
remaining := make([]float32, n-splitAt)
|
||||
copy(remaining, r.samples[splitAt:])
|
||||
r.samples = remaining
|
||||
r.mu.Unlock()
|
||||
|
||||
return encodeWAV(chunk, audioSampleRate, audioChannels)
|
||||
}
|
||||
|
||||
// FlushWAV returns any remaining samples as WAV, clearing the buffer.
|
||||
func (r *AudioRecorder) FlushWAV() []byte {
|
||||
r.mu.Lock()
|
||||
samples := r.samples
|
||||
r.samples = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
if len(samples) == 0 {
|
||||
return nil
|
||||
}
|
||||
return encodeWAV(samples, audioSampleRate, audioChannels)
|
||||
}
|
||||
|
||||
// WAV encodes the captured samples as a WAV file in memory.
|
||||
func (r *AudioRecorder) WAV() ([]byte, error) {
|
||||
r.mu.Lock()
|
||||
samples := make([]float32, len(r.samples))
|
||||
copy(samples, r.samples)
|
||||
r.mu.Unlock()
|
||||
|
||||
if len(samples) == 0 {
|
||||
return nil, errNoAudio
|
||||
}
|
||||
|
||||
return encodeWAV(samples, audioSampleRate, audioChannels), nil
|
||||
}
|
||||
|
||||
// encodeWAV produces a 16-bit PCM WAV file from float32 samples.
|
||||
func encodeWAV(samples []float32, sampleRate, channels int) []byte {
|
||||
numSamples := len(samples)
|
||||
bitsPerSample := 16
|
||||
byteRate := sampleRate * channels * bitsPerSample / 8
|
||||
blockAlign := channels * bitsPerSample / 8
|
||||
dataSize := numSamples * blockAlign
|
||||
|
||||
buf := make([]byte, 44+dataSize)
|
||||
|
||||
copy(buf[0:4], "RIFF")
|
||||
binary.LittleEndian.PutUint32(buf[4:8], uint32(36+dataSize))
|
||||
copy(buf[8:12], "WAVE")
|
||||
|
||||
copy(buf[12:16], "fmt ")
|
||||
binary.LittleEndian.PutUint32(buf[16:20], 16)
|
||||
binary.LittleEndian.PutUint16(buf[20:22], 1)
|
||||
binary.LittleEndian.PutUint16(buf[22:24], uint16(channels))
|
||||
binary.LittleEndian.PutUint32(buf[24:28], uint32(sampleRate))
|
||||
binary.LittleEndian.PutUint32(buf[28:32], uint32(byteRate))
|
||||
binary.LittleEndian.PutUint16(buf[32:34], uint16(blockAlign))
|
||||
binary.LittleEndian.PutUint16(buf[34:36], uint16(bitsPerSample))
|
||||
|
||||
copy(buf[36:40], "data")
|
||||
binary.LittleEndian.PutUint32(buf[40:44], uint32(dataSize))
|
||||
|
||||
offset := 44
|
||||
for _, s := range samples {
|
||||
if s > 1.0 {
|
||||
s = 1.0
|
||||
} else if s < -1.0 {
|
||||
s = -1.0
|
||||
}
|
||||
val := int16(s * 32767)
|
||||
binary.LittleEndian.PutUint16(buf[offset:offset+2], uint16(val))
|
||||
offset += 2
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
180
cmd/audio_darwin.go
Normal file
180
cmd/audio_darwin.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package cmd
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -framework CoreAudio -framework AudioToolbox
|
||||
#include <AudioToolbox/AudioQueue.h>
|
||||
#include <string.h>
|
||||
|
||||
// Callback context passed to AudioQueue.
|
||||
typedef struct {
|
||||
int ready; // set to 1 when a buffer is filled
|
||||
} AQContext;
|
||||
|
||||
// C callback — re-enqueues the buffer so recording continues.
|
||||
// Not static — must be visible to the linker for Go's function pointer.
|
||||
void aqInputCallback(
|
||||
void *inUserData,
|
||||
AudioQueueRef inAQ,
|
||||
AudioQueueBufferRef inBuffer,
|
||||
const AudioTimeStamp *inStartTime,
|
||||
UInt32 inNumberPacketDescriptions,
|
||||
const AudioStreamPacketDescription *inPacketDescs)
|
||||
{
|
||||
// Re-enqueue the buffer immediately so recording continues.
|
||||
AudioQueueEnqueueBuffer(inAQ, inBuffer, 0, NULL);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// errNoAudio is returned by AudioRecorder.WAV when nothing has been captured.
var errNoAudio = fmt.Errorf("no audio recorded")

// numAQBuffers is how many AudioQueue capture buffers are kept in flight.
const numAQBuffers = 3

// coreAudioStream implements audioStream on macOS via the AudioToolbox
// AudioQueue API.
type coreAudioStream struct {
	queue   C.AudioQueueRef
	buffers [numAQBuffers]C.AudioQueueBufferRef

	mu       sync.Mutex // guards callback, running, and buffer draining
	callback func(samples []float32)
	running  bool
	pollDone chan struct{} // closed when pollLoop exits

	// Capture configuration, fixed at construction.
	sampleRate int
	channels   int
	frameSize  int
}
|
||||
|
||||
func newAudioStream(sampleRate, channels, frameSize int) (audioStream, error) {
|
||||
return &coreAudioStream{
|
||||
sampleRate: sampleRate,
|
||||
channels: channels,
|
||||
frameSize: frameSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start configures a 16-bit signed PCM input AudioQueue at the stream's
// sample rate and channel count, allocates and enqueues numAQBuffers
// capture buffers, starts recording, and launches pollLoop to drain
// captured data into the callback.
func (s *coreAudioStream) Start(callback func(samples []float32)) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.callback = callback

	// Set up audio format: 16-bit signed integer PCM, interleaved/packed.
	var format C.AudioStreamBasicDescription
	format.mSampleRate = C.Float64(s.sampleRate)
	format.mFormatID = C.kAudioFormatLinearPCM
	format.mFormatFlags = C.kLinearPCMFormatFlagIsSignedInteger | C.kLinearPCMFormatFlagIsPacked
	format.mBitsPerChannel = 16
	format.mChannelsPerFrame = C.UInt32(s.channels)
	format.mBytesPerFrame = 2 * C.UInt32(s.channels) // 2 bytes per 16-bit sample
	format.mFramesPerPacket = 1
	format.mBytesPerPacket = format.mBytesPerFrame

	// Create the input audio queue; the C-side aqInputCallback only
	// re-enqueues filled buffers (see the cgo preamble).
	var status C.OSStatus
	status = C.AudioQueueNewInput(
		&format,
		C.AudioQueueInputCallback(C.aqInputCallback),
		nil,               // user data
		C.CFRunLoopRef(0), // NULL run loop — use internal thread
		C.CFStringRef(0),  // NULL run loop mode
		0,                 // flags
		&s.queue,
	)
	if status != 0 {
		return fmt.Errorf("AudioQueueNewInput failed: %d", status)
	}

	// Allocate and enqueue the capture buffers; on any failure the whole
	// queue is disposed, which also frees already-allocated buffers.
	bufferBytes := C.UInt32(s.frameSize * int(format.mBytesPerFrame))
	for i := range s.buffers {
		status = C.AudioQueueAllocateBuffer(s.queue, bufferBytes, &s.buffers[i])
		if status != 0 {
			C.AudioQueueDispose(s.queue, C.true)
			return fmt.Errorf("AudioQueueAllocateBuffer failed: %d", status)
		}
		status = C.AudioQueueEnqueueBuffer(s.queue, s.buffers[i], 0, nil)
		if status != 0 {
			C.AudioQueueDispose(s.queue, C.true)
			return fmt.Errorf("AudioQueueEnqueueBuffer failed: %d", status)
		}
	}

	// Start recording.
	status = C.AudioQueueStart(s.queue, nil)
	if status != 0 {
		C.AudioQueueDispose(s.queue, C.true)
		return fmt.Errorf("AudioQueueStart failed: %d", status)
	}

	s.running = true
	s.pollDone = make(chan struct{})

	// Poll buffers for data. AudioQueue re-enqueues in the C callback,
	// so we read the data out periodically.
	// NOTE(review): AudioQueue buffer contents are conventionally
	// consumed inside the input callback; draining them from a Go
	// goroutine while the queue's capture thread refills them looks
	// racy — confirm against AudioQueue threading guarantees.
	go s.pollLoop()

	return nil
}
|
||||
|
||||
// pollLoop periodically inspects the AudioQueue buffers, converts any
// captured 16-bit samples to float32 in [-1, 1], delivers them to the
// callback (while holding mu), and marks each buffer drained. It exits
// once Stop clears running, then closes pollDone.
func (s *coreAudioStream) pollLoop() {
	defer close(s.pollDone)

	// Tick at roughly one buffer's worth of audio (frameSize samples),
	// but never faster than every 10ms.
	interval := time.Duration(float64(s.frameSize) / float64(s.sampleRate) * float64(time.Second))
	if interval < 10*time.Millisecond {
		interval = 10 * time.Millisecond
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for range ticker.C {
		s.mu.Lock()
		if !s.running {
			s.mu.Unlock()
			return
		}

		// Read available data from each buffer.
		// NOTE(review): the queue's capture thread writes these buffers
		// concurrently; this unsynchronized read and reset of
		// mAudioDataByteSize is a potential data race — confirm.
		for i := range s.buffers {
			buf := s.buffers[i]
			if buf.mAudioDataByteSize > 0 {
				numSamples := int(buf.mAudioDataByteSize) / 2 // 16-bit samples
				if numSamples > 0 {
					// View the C buffer as an int16 slice without copying.
					raw := (*[1 << 28]int16)(buf.mAudioData)[:numSamples:numSamples]
					floats := make([]float32, numSamples)
					for j, v := range raw {
						floats[j] = float32(v) / float32(math.MaxInt16)
					}
					s.callback(floats)
				}
				// Mark the buffer drained so it isn't re-read next tick.
				buf.mAudioDataByteSize = 0
			}
		}
		s.mu.Unlock()
	}
}
|
||||
|
||||
func (s *coreAudioStream) Stop() error {
|
||||
s.mu.Lock()
|
||||
s.running = false
|
||||
queue := s.queue
|
||||
s.mu.Unlock()
|
||||
|
||||
if queue != nil {
|
||||
C.AudioQueueStop(queue, C.true)
|
||||
C.AudioQueueDispose(queue, C.true)
|
||||
}
|
||||
|
||||
if s.pollDone != nil {
|
||||
<-s.pollDone
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
275
cmd/audio_linux.go
Normal file
275
cmd/audio_linux.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package cmd
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -ldl
|
||||
#include <dlfcn.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// Function pointer types for ALSA functions loaded at runtime.
|
||||
typedef int (*pcm_open_fn)(void**, const char*, int, int);
|
||||
typedef int (*pcm_simple_fn)(void*);
|
||||
typedef long (*pcm_readi_fn)(void*, void*, unsigned long);
|
||||
typedef int (*hw_malloc_fn)(void**);
|
||||
typedef void (*hw_free_fn)(void*);
|
||||
typedef int (*hw_any_fn)(void*, void*);
|
||||
typedef int (*hw_set_int_fn)(void*, void*, int);
|
||||
typedef int (*hw_set_uint_fn)(void*, void*, unsigned int);
|
||||
typedef int (*hw_set_rate_fn)(void*, void*, unsigned int*, int*);
|
||||
typedef int (*hw_set_period_fn)(void*, void*, unsigned long*, int*);
|
||||
typedef int (*hw_apply_fn)(void*, void*);
|
||||
typedef const char* (*strerror_fn)(int);
|
||||
|
||||
// Trampoline functions — call dynamically loaded ALSA symbols.
|
||||
static int alsa_pcm_open(void* fn, void** h, const char* name, int stream, int mode) {
|
||||
return ((pcm_open_fn)fn)(h, name, stream, mode);
|
||||
}
|
||||
static int alsa_pcm_close(void* fn, void* h) { return ((pcm_simple_fn)fn)(h); }
|
||||
static int alsa_pcm_prepare(void* fn, void* h) { return ((pcm_simple_fn)fn)(h); }
|
||||
static int alsa_pcm_drop(void* fn, void* h) { return ((pcm_simple_fn)fn)(h); }
|
||||
static long alsa_pcm_readi(void* fn, void* h, void* buf, unsigned long frames) {
|
||||
return ((pcm_readi_fn)fn)(h, buf, frames);
|
||||
}
|
||||
static int alsa_hw_malloc(void* fn, void** p) { return ((hw_malloc_fn)fn)(p); }
|
||||
static void alsa_hw_free(void* fn, void* p) { ((hw_free_fn)fn)(p); }
|
||||
static int alsa_hw_any(void* fn, void* h, void* p) { return ((hw_any_fn)fn)(h, p); }
|
||||
static int alsa_hw_set_access(void* fn, void* h, void* p, int v) { return ((hw_set_int_fn)fn)(h, p, v); }
|
||||
static int alsa_hw_set_format(void* fn, void* h, void* p, int v) { return ((hw_set_int_fn)fn)(h, p, v); }
|
||||
static int alsa_hw_set_channels(void* fn, void* h, void* p, unsigned int v) { return ((hw_set_uint_fn)fn)(h, p, v); }
|
||||
static int alsa_hw_set_rate(void* fn, void* h, void* p, unsigned int* v, int* d) { return ((hw_set_rate_fn)fn)(h, p, v, d); }
|
||||
static int alsa_hw_set_period(void* fn, void* h, void* p, unsigned long* v, int* d) { return ((hw_set_period_fn)fn)(h, p, v, d); }
|
||||
static int alsa_hw_apply(void* fn, void* h, void* p) { return ((hw_apply_fn)fn)(h, p); }
|
||||
static const char* alsa_strerror(void* fn, int e) { return ((strerror_fn)fn)(e); }
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// errNoAudio is returned by AudioRecorder.WAV when nothing has been captured.
var errNoAudio = fmt.Errorf("no audio recorded")

// ALSA enum values mirrored by hand since libasound is loaded with
// dlopen and its headers are not included.
const (
	sndPCMStreamCapture       = 1 // SND_PCM_STREAM_CAPTURE
	sndPCMAccessRWInterleaved = 3 // SND_PCM_ACCESS_RW_INTERLEAVED
	sndPCMFormatS16LE         = 2 // SND_PCM_FORMAT_S16_LE
)

// Lazily-resolved libasound state, populated exactly once by loadALSA.
var (
	alsaLoadErr error     // non-nil when the library or a symbol is missing
	alsaOnce    sync.Once // guards the single loadALSA invocation
	alsa        alsaFuncs // resolved function pointers
)

// alsaFuncs holds dlsym-resolved libasound entry points, invoked
// through the C trampolines in the cgo preamble.
type alsaFuncs struct {
	pcmOpen, pcmClose, pcmPrepare, pcmDrop, pcmReadi unsafe.Pointer
	hwMalloc, hwFree, hwAny                          unsafe.Pointer
	hwSetAccess, hwSetFormat, hwSetChannels          unsafe.Pointer
	hwSetRate, hwSetPeriod, hwApply                  unsafe.Pointer
	strerror                                         unsafe.Pointer
}
|
||||
|
||||
func loadALSA() {
|
||||
var lib unsafe.Pointer
|
||||
for _, name := range []string{"libasound.so.2", "libasound.so"} {
|
||||
cName := C.CString(name)
|
||||
lib = C.dlopen(cName, C.RTLD_NOW)
|
||||
C.free(unsafe.Pointer(cName))
|
||||
if lib != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if lib == nil {
|
||||
alsaLoadErr = fmt.Errorf("audio capture unavailable: libasound.so not found")
|
||||
return
|
||||
}
|
||||
|
||||
sym := func(name string) unsafe.Pointer {
|
||||
cName := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
return C.dlsym(lib, cName)
|
||||
}
|
||||
|
||||
syms := []struct {
|
||||
ptr *unsafe.Pointer
|
||||
name string
|
||||
}{
|
||||
{&alsa.pcmOpen, "snd_pcm_open"},
|
||||
{&alsa.pcmClose, "snd_pcm_close"},
|
||||
{&alsa.pcmPrepare, "snd_pcm_prepare"},
|
||||
{&alsa.pcmDrop, "snd_pcm_drop"},
|
||||
{&alsa.pcmReadi, "snd_pcm_readi"},
|
||||
{&alsa.hwMalloc, "snd_pcm_hw_params_malloc"},
|
||||
{&alsa.hwFree, "snd_pcm_hw_params_free"},
|
||||
{&alsa.hwAny, "snd_pcm_hw_params_any"},
|
||||
{&alsa.hwSetAccess, "snd_pcm_hw_params_set_access"},
|
||||
{&alsa.hwSetFormat, "snd_pcm_hw_params_set_format"},
|
||||
{&alsa.hwSetChannels, "snd_pcm_hw_params_set_channels"},
|
||||
{&alsa.hwSetRate, "snd_pcm_hw_params_set_rate_near"},
|
||||
{&alsa.hwSetPeriod, "snd_pcm_hw_params_set_period_size_near"},
|
||||
{&alsa.hwApply, "snd_pcm_hw_params"},
|
||||
{&alsa.strerror, "snd_strerror"},
|
||||
}
|
||||
|
||||
for _, s := range syms {
|
||||
*s.ptr = sym(s.name)
|
||||
if *s.ptr == nil {
|
||||
alsaLoadErr = fmt.Errorf("audio capture unavailable: missing %s in libasound", s.name)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func alsaError(code C.int) string {
|
||||
if alsa.strerror == nil {
|
||||
return fmt.Sprintf("error %d", code)
|
||||
}
|
||||
return C.GoString(C.alsa_strerror(alsa.strerror, code))
|
||||
}
|
||||
|
||||
// alsaStream implements audioStream on Linux over the dynamically
// loaded ALSA PCM API.
type alsaStream struct {
	handle unsafe.Pointer // snd_pcm_t*; nil once Stop has taken ownership

	mu       sync.Mutex // guards callback, running, handle
	callback func(samples []float32)
	running  bool
	done     chan struct{} // closed when captureLoop exits

	// Capture configuration, fixed at construction.
	sampleRate int
	channels   int
	frameSize  int
}
|
||||
|
||||
func newAudioStream(sampleRate, channels, frameSize int) (audioStream, error) {
|
||||
alsaOnce.Do(loadALSA)
|
||||
if alsaLoadErr != nil {
|
||||
return nil, alsaLoadErr
|
||||
}
|
||||
return &alsaStream{
|
||||
sampleRate: sampleRate,
|
||||
channels: channels,
|
||||
frameSize: frameSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *alsaStream) Start(callback func(samples []float32)) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.callback = callback
|
||||
|
||||
cName := C.CString("default")
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
|
||||
rc := C.alsa_pcm_open(alsa.pcmOpen, (*unsafe.Pointer)(unsafe.Pointer(&s.handle)), cName, C.int(sndPCMStreamCapture), 0)
|
||||
if rc < 0 {
|
||||
return fmt.Errorf("snd_pcm_open: %s", alsaError(rc))
|
||||
}
|
||||
|
||||
var hwParams unsafe.Pointer
|
||||
C.alsa_hw_malloc(alsa.hwMalloc, (*unsafe.Pointer)(unsafe.Pointer(&hwParams)))
|
||||
defer C.alsa_hw_free(alsa.hwFree, hwParams)
|
||||
|
||||
C.alsa_hw_any(alsa.hwAny, s.handle, hwParams)
|
||||
|
||||
if rc = C.alsa_hw_set_access(alsa.hwSetAccess, s.handle, hwParams, C.int(sndPCMAccessRWInterleaved)); rc < 0 {
|
||||
C.alsa_pcm_close(alsa.pcmClose, s.handle)
|
||||
return fmt.Errorf("set access: %s", alsaError(rc))
|
||||
}
|
||||
if rc = C.alsa_hw_set_format(alsa.hwSetFormat, s.handle, hwParams, C.int(sndPCMFormatS16LE)); rc < 0 {
|
||||
C.alsa_pcm_close(alsa.pcmClose, s.handle)
|
||||
return fmt.Errorf("set format: %s", alsaError(rc))
|
||||
}
|
||||
if rc = C.alsa_hw_set_channels(alsa.hwSetChannels, s.handle, hwParams, C.uint(s.channels)); rc < 0 {
|
||||
C.alsa_pcm_close(alsa.pcmClose, s.handle)
|
||||
return fmt.Errorf("set channels: %s", alsaError(rc))
|
||||
}
|
||||
|
||||
rate := C.uint(s.sampleRate)
|
||||
if rc = C.alsa_hw_set_rate(alsa.hwSetRate, s.handle, hwParams, &rate, nil); rc < 0 {
|
||||
C.alsa_pcm_close(alsa.pcmClose, s.handle)
|
||||
return fmt.Errorf("set rate: %s", alsaError(rc))
|
||||
}
|
||||
|
||||
periodSize := C.ulong(s.frameSize)
|
||||
if rc = C.alsa_hw_set_period(alsa.hwSetPeriod, s.handle, hwParams, &periodSize, nil); rc < 0 {
|
||||
C.alsa_pcm_close(alsa.pcmClose, s.handle)
|
||||
return fmt.Errorf("set period: %s", alsaError(rc))
|
||||
}
|
||||
|
||||
if rc = C.alsa_hw_apply(alsa.hwApply, s.handle, hwParams); rc < 0 {
|
||||
C.alsa_pcm_close(alsa.pcmClose, s.handle)
|
||||
return fmt.Errorf("apply hw params: %s", alsaError(rc))
|
||||
}
|
||||
|
||||
if rc = C.alsa_pcm_prepare(alsa.pcmPrepare, s.handle); rc < 0 {
|
||||
C.alsa_pcm_close(alsa.pcmClose, s.handle)
|
||||
return fmt.Errorf("prepare: %s", alsaError(rc))
|
||||
}
|
||||
|
||||
s.running = true
|
||||
s.done = make(chan struct{})
|
||||
go s.captureLoop(int(periodSize))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *alsaStream) captureLoop(periodSize int) {
|
||||
defer close(s.done)
|
||||
|
||||
buf := make([]int16, periodSize*s.channels)
|
||||
|
||||
for {
|
||||
s.mu.Lock()
|
||||
if !s.running {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
handle := s.handle
|
||||
s.mu.Unlock()
|
||||
|
||||
frames := C.alsa_pcm_readi(alsa.pcmReadi, handle, unsafe.Pointer(&buf[0]), C.ulong(periodSize))
|
||||
if frames < 0 {
|
||||
C.alsa_pcm_prepare(alsa.pcmPrepare, handle)
|
||||
continue
|
||||
}
|
||||
if frames == 0 {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
numSamples := int(frames) * s.channels
|
||||
floats := make([]float32, numSamples)
|
||||
for i := 0; i < numSamples; i++ {
|
||||
floats[i] = float32(buf[i]) / float32(math.MaxInt16)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if s.callback != nil {
|
||||
s.callback(floats)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *alsaStream) Stop() error {
|
||||
s.mu.Lock()
|
||||
s.running = false
|
||||
handle := s.handle
|
||||
s.handle = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
if s.done != nil {
|
||||
<-s.done
|
||||
}
|
||||
|
||||
if handle != nil {
|
||||
C.alsa_pcm_drop(alsa.pcmDrop, handle)
|
||||
C.alsa_pcm_close(alsa.pcmClose, handle)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
288
cmd/audio_windows.go
Normal file
288
cmd/audio_windows.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// errNoAudio is returned by AudioRecorder.WAV when nothing has been captured.
var errNoAudio = fmt.Errorf("no audio recorded")

// WASAPI COM GUIDs (interface and class IDs used for device enumeration
// and audio capture).
var (
	iidIMMDeviceEnumerator  = guid{0xA95664D2, 0x9614, 0x4F35, [8]byte{0xA7, 0x46, 0xDE, 0x8D, 0xB6, 0x36, 0x17, 0xE6}}
	clsidMMDeviceEnumerator = guid{0xBCDE0395, 0xE52F, 0x467C, [8]byte{0x8E, 0x3D, 0xC4, 0x57, 0x92, 0x91, 0x69, 0x2E}}
	iidIAudioClient         = guid{0x1CB9AD4C, 0xDBFA, 0x4C32, [8]byte{0xB1, 0x78, 0xC2, 0xF5, 0x68, 0xA7, 0x03, 0xB2}}
	iidIAudioCaptureClient  = guid{0xC8ADBD64, 0xE71E, 0x48A0, [8]byte{0xA4, 0xDE, 0x18, 0x5C, 0x39, 0x5C, 0xD3, 0x17}}
)

// guid matches the Windows GUID struct layout.
type guid struct {
	Data1 uint32
	Data2 uint16
	Data3 uint16
	Data4 [8]byte
}

// waveFormatEx matches the Windows WAVEFORMATEX struct layout.
type waveFormatEx struct {
	FormatTag      uint16
	Channels       uint16
	SamplesPerSec  uint32
	AvgBytesPerSec uint32
	BlockAlign     uint16
	BitsPerSample  uint16
	CbSize         uint16
}

// Windows audio constants mirrored by hand (no SDK headers).
const (
	wavePCM          = 1 // WAVE_FORMAT_PCM
	eCapture         = 1 // EDataFlow: capture endpoints
	eConsole         = 0 // ERole: default console device
	audclntSharemode = 0 // AUDCLNT_SHAREMODE_SHARED
	// NOTE(review): defined but not passed to Initialize below (stream
	// flags are 0) — confirm whether event-callback mode was intended.
	audclntStreamflagsEventcallback = 0x00040000

	coinitMultithreaded = 0x0
	clsctxAll           = 0x17

	reftimesPerSec    = 10000000 // 100ns units per second
	reftimesPerMillis = 10000
)

// ole32 entry points used to bootstrap COM.
var (
	ole32    = syscall.NewLazyDLL("ole32.dll")
	coInit   = ole32.NewProc("CoInitializeEx")
	coCreate = ole32.NewProc("CoCreateInstance")
)

// wasapiStream implements audioStream on Windows via WASAPI in shared
// mode, driving the capture client by polling.
type wasapiStream struct {
	mu       sync.Mutex // guards callback, running
	callback func(samples []float32)
	running  bool
	done     chan struct{} // closed when captureLoop exits

	// Capture configuration, fixed at construction.
	sampleRate int
	channels   int
	frameSize  int

	// COM interfaces (stored as uintptr for syscall).
	enumerator uintptr
	device     uintptr
	client     uintptr
	capture    uintptr
}
|
||||
|
||||
func newAudioStream(sampleRate, channels, frameSize int) (audioStream, error) {
|
||||
return &wasapiStream{
|
||||
sampleRate: sampleRate,
|
||||
channels: channels,
|
||||
frameSize: frameSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start initializes COM, obtains the default capture endpoint, activates
// and initializes an IAudioClient for shared-mode 16-bit PCM at the
// stream's rate, acquires an IAudioCaptureClient, starts capture, and
// launches captureLoop to poll for data.
func (s *wasapiStream) Start(callback func(samples []float32)) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.callback = callback

	// Initialize COM
	hr, _, _ := coInit.Call(0, uintptr(coinitMultithreaded))
	// S_OK or S_FALSE (already initialized) are both fine
	if hr != 0 && hr != 1 {
		return fmt.Errorf("CoInitializeEx failed: 0x%08x", hr)
	}

	// Create device enumerator
	hr, _, _ = coCreate.Call(
		uintptr(unsafe.Pointer(&clsidMMDeviceEnumerator)),
		0,
		uintptr(clsctxAll),
		uintptr(unsafe.Pointer(&iidIMMDeviceEnumerator)),
		uintptr(unsafe.Pointer(&s.enumerator)),
	)
	if hr != 0 {
		return fmt.Errorf("CoCreateInstance(MMDeviceEnumerator) failed: 0x%08x", hr)
	}

	// Get default capture device
	// IMMDeviceEnumerator::GetDefaultAudioEndpoint is vtable index 4
	hr = comCall(s.enumerator, 4, uintptr(eCapture), uintptr(eConsole), uintptr(unsafe.Pointer(&s.device)))
	if hr != 0 {
		return fmt.Errorf("GetDefaultAudioEndpoint failed: 0x%08x", hr)
	}

	// Activate IAudioClient
	// IMMDevice::Activate is vtable index 3
	hr = comCall(s.device, 3,
		uintptr(unsafe.Pointer(&iidIAudioClient)),
		uintptr(clsctxAll),
		0,
		uintptr(unsafe.Pointer(&s.client)),
	)
	if hr != 0 {
		return fmt.Errorf("IMMDevice::Activate failed: 0x%08x", hr)
	}

	// Set up format: 16-bit PCM at the configured rate/channels.
	// NOTE(review): shared-mode WASAPI may reject a format that differs
	// from the device mix format unless AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM
	// is passed; stream flags below are 0 — confirm on real hardware.
	format := waveFormatEx{
		FormatTag:      wavePCM,
		Channels:       uint16(s.channels),
		SamplesPerSec:  uint32(s.sampleRate),
		BitsPerSample:  16,
		BlockAlign:     uint16(2 * s.channels),
		AvgBytesPerSec: uint32(s.sampleRate * 2 * s.channels),
		CbSize:         0,
	}

	// Initialize audio client
	// IAudioClient::Initialize is vtable index 3
	bufferDuration := int64(reftimesPerSec) // 1 second buffer
	hr = comCall(s.client, 3,
		uintptr(audclntSharemode),
		0, // stream flags
		uintptr(bufferDuration),
		0, // periodicity (0 = use default)
		uintptr(unsafe.Pointer(&format)),
		0, // audio session GUID (NULL = default)
	)
	if hr != 0 {
		return fmt.Errorf("IAudioClient::Initialize failed: 0x%08x", hr)
	}

	// Get capture client
	// IAudioClient::GetService is vtable index 8
	hr = comCall(s.client, 8,
		uintptr(unsafe.Pointer(&iidIAudioCaptureClient)),
		uintptr(unsafe.Pointer(&s.capture)),
	)
	if hr != 0 {
		return fmt.Errorf("IAudioClient::GetService failed: 0x%08x", hr)
	}

	// Start capture
	// IAudioClient::Start is vtable index 6
	hr = comCall(s.client, 6)
	if hr != 0 {
		return fmt.Errorf("IAudioClient::Start failed: 0x%08x", hr)
	}

	s.running = true
	s.done = make(chan struct{})
	go s.captureLoop()

	return nil
}
|
||||
|
||||
// captureLoop polls the IAudioCaptureClient every 20ms, draining all
// available packets: each packet's int16 samples are converted to
// float32 in [-1, 1] and delivered to the callback (under mu), then the
// buffer is released. It exits once Stop clears running, closing done.
func (s *wasapiStream) captureLoop() {
	defer close(s.done)

	ticker := time.NewTicker(20 * time.Millisecond)
	defer ticker.Stop()

	for range ticker.C {
		s.mu.Lock()
		if !s.running {
			s.mu.Unlock()
			return
		}

		// Read available packets until GetBuffer reports none.
		for {
			var data uintptr
			var numFrames uint32
			var flags uint32

			// IAudioCaptureClient::GetBuffer is vtable index 3
			hr := comCall(s.capture, 3,
				uintptr(unsafe.Pointer(&data)),
				uintptr(unsafe.Pointer(&numFrames)),
				uintptr(unsafe.Pointer(&flags)),
				0, // device position (not needed)
				0, // QPC position (not needed)
			)
			if hr != 0 || numFrames == 0 {
				break
			}

			// Convert int16 samples to float32.
			// NOTE(review): flags is never inspected, so a packet marked
			// AUDCLNT_BUFFERFLAGS_SILENT would still be read as PCM data
			// — confirm whether silent packets need special handling.
			samples := make([]float32, numFrames*uint32(s.channels))
			raw := (*[1 << 28]int16)(unsafe.Pointer(data))[:len(samples):len(samples)]
			for i, v := range raw {
				samples[i] = float32(v) / float32(math.MaxInt16)
			}

			s.callback(samples)

			// IAudioCaptureClient::ReleaseBuffer is vtable index 4
			comCall(s.capture, 4, uintptr(numFrames))
		}

		s.mu.Unlock()
	}
}
|
||||
|
||||
func (s *wasapiStream) Stop() error {
|
||||
s.mu.Lock()
|
||||
s.running = false
|
||||
s.mu.Unlock()
|
||||
|
||||
if s.done != nil {
|
||||
<-s.done
|
||||
}
|
||||
|
||||
// IAudioClient::Stop is vtable index 7
|
||||
if s.client != 0 {
|
||||
comCall(s.client, 7)
|
||||
}
|
||||
|
||||
// Release COM interfaces (IUnknown::Release is vtable index 2)
|
||||
if s.capture != 0 {
|
||||
comCall(s.capture, 2)
|
||||
}
|
||||
if s.client != 0 {
|
||||
comCall(s.client, 2)
|
||||
}
|
||||
if s.device != 0 {
|
||||
comCall(s.device, 2)
|
||||
}
|
||||
if s.enumerator != 0 {
|
||||
comCall(s.enumerator, 2)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// comCall invokes a COM method by vtable index.
|
||||
func comCall(obj uintptr, method uintptr, args ...uintptr) uintptr {
|
||||
vtable := *(*uintptr)(unsafe.Pointer(obj))
|
||||
fn := *(*uintptr)(unsafe.Pointer(vtable + method*unsafe.Sizeof(uintptr(0))))
|
||||
|
||||
// Build syscall args: first arg is always 'this' pointer
|
||||
callArgs := make([]uintptr, 1+len(args))
|
||||
callArgs[0] = obj
|
||||
copy(callArgs[1:], args)
|
||||
|
||||
var hr uintptr
|
||||
switch len(callArgs) {
|
||||
case 1:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0])
|
||||
case 2:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1])
|
||||
case 3:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1], callArgs[2])
|
||||
case 4:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1], callArgs[2], callArgs[3])
|
||||
case 5:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1], callArgs[2], callArgs[3], callArgs[4])
|
||||
case 6:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1], callArgs[2], callArgs[3], callArgs[4], callArgs[5])
|
||||
case 7:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1], callArgs[2], callArgs[3], callArgs[4], callArgs[5], callArgs[6])
|
||||
default:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs...)
|
||||
}
|
||||
return hr
|
||||
}
|
||||
@@ -32,6 +32,7 @@ type flagOptions struct {
|
||||
verbose *bool
|
||||
warmup *int
|
||||
promptTokens *int
|
||||
numCtx *int
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
@@ -48,6 +49,7 @@ type ModelInfo struct {
|
||||
Family string
|
||||
SizeBytes int64
|
||||
VRAMBytes int64
|
||||
NumCtx int64
|
||||
}
|
||||
|
||||
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||
@@ -64,9 +66,12 @@ var promptWordList = []string{
|
||||
"old", "stone", "bridge", "that", "crosses", "winding", "river",
|
||||
}
|
||||
|
||||
// tokensPerWord is the calibrated ratio of tokens to words for the current model.
|
||||
// Initialized with a heuristic, then updated during warmup based on actual tokenization.
|
||||
var tokensPerWord = 1.3
|
||||
|
||||
func generatePromptForTokenCount(targetTokens int, epoch int) string {
|
||||
// ~1.3 tokens per word heuristic
|
||||
targetWords := int(float64(targetTokens) / 1.3)
|
||||
targetWords := int(float64(targetTokens) / tokensPerWord)
|
||||
if targetWords < 1 {
|
||||
targetWords = 1
|
||||
}
|
||||
@@ -81,6 +86,17 @@ func generatePromptForTokenCount(targetTokens int, epoch int) string {
|
||||
return strings.Join(words, " ")
|
||||
}
|
||||
|
||||
// calibratePromptTokens adjusts tokensPerWord based on actual tokenization from a warmup run.
|
||||
func calibratePromptTokens(targetTokens, actualTokens, wordCount int) {
|
||||
if actualTokens <= 0 || wordCount <= 0 {
|
||||
return
|
||||
}
|
||||
tokensPerWord = float64(actualTokens) / float64(wordCount)
|
||||
newWords := int(float64(targetTokens) / tokensPerWord)
|
||||
fmt.Fprintf(os.Stderr, "bench: calibrated %.2f tokens/word (target=%d, got=%d, words=%d → %d)\n",
|
||||
tokensPerWord, targetTokens, actualTokens, wordCount, newWords)
|
||||
}
|
||||
|
||||
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
|
||||
options := make(map[string]interface{})
|
||||
if *fOpt.maxTokens > 0 {
|
||||
@@ -90,6 +106,9 @@ func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData,
|
||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||
options["seed"] = *fOpt.seed
|
||||
}
|
||||
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
|
||||
options["num_ctx"] = *fOpt.numCtx
|
||||
}
|
||||
|
||||
var keepAliveDuration *api.Duration
|
||||
if *fOpt.keepAlive > 0 {
|
||||
@@ -146,7 +165,6 @@ func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (si
|
||||
return m.Size, m.SizeVRAM
|
||||
}
|
||||
}
|
||||
// Try prefix match (model names may include :latest or tags)
|
||||
for _, m := range resp.Models {
|
||||
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||
return m.Size, m.SizeVRAM
|
||||
@@ -155,6 +173,19 @@ func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (si
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func fetchContextLength(ctx context.Context, client *api.Client, model string) int64 {
|
||||
resp, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == model || m.Model == model || strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||
return int64(m.ContextLength)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func outputFormatHeader(w io.Writer, format string, verbose bool) {
|
||||
switch format {
|
||||
case "benchstat":
|
||||
@@ -177,8 +208,12 @@ func outputModelInfo(w io.Writer, format string, info ModelInfo) {
|
||||
if info.SizeBytes > 0 {
|
||||
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
|
||||
}
|
||||
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s\n",
|
||||
info.Name, params, quant, family, memStr)
|
||||
ctxStr := ""
|
||||
if info.NumCtx > 0 {
|
||||
ctxStr = fmt.Sprintf(" | NumCtx: %d", info.NumCtx)
|
||||
}
|
||||
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s%s\n",
|
||||
info.Name, params, quant, family, memStr, ctxStr)
|
||||
}
|
||||
|
||||
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||
@@ -276,21 +311,38 @@ func BenchmarkModel(fOpt flagOptions) error {
|
||||
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||
|
||||
var warmupMetrics *api.Metrics
|
||||
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if resp.Done {
|
||||
warmupMetrics = &resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
|
||||
} else if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
|
||||
} else {
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
|
||||
}
|
||||
// Calibrate prompt token count on last warmup run
|
||||
if i == *fOpt.warmup-1 && *fOpt.promptTokens > 0 && warmupMetrics != nil {
|
||||
prompt := generatePromptForTokenCount(*fOpt.promptTokens, -(i + 1))
|
||||
wordCount := len(strings.Fields(prompt))
|
||||
calibratePromptTokens(*fOpt.promptTokens, warmupMetrics.PromptEvalCount, wordCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch memory usage once after warmup (model is loaded and stable)
|
||||
// Fetch memory/context info once after warmup (model is loaded and stable)
|
||||
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
|
||||
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
|
||||
info.NumCtx = int64(*fOpt.numCtx)
|
||||
} else {
|
||||
info.NumCtx = fetchContextLength(memCtx, client, model)
|
||||
}
|
||||
memCancel()
|
||||
|
||||
outputModelInfo(out, *fOpt.format, info)
|
||||
@@ -479,6 +531,7 @@ func main() {
|
||||
debug: flag.Bool("debug", false, "Show debug information"),
|
||||
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
|
||||
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
|
||||
numCtx: flag.Int("num-ctx", 0, "Context size (0 = server default)"),
|
||||
}
|
||||
|
||||
flag.Usage = func() {
|
||||
|
||||
@@ -695,7 +695,8 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
||||
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
// these are left in for backwards compatibility with older servers
|
||||
@@ -1494,6 +1495,9 @@ type displayResponseState struct {
|
||||
|
||||
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
|
||||
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if termWidth == 0 {
|
||||
termWidth = 80
|
||||
}
|
||||
if wordWrap && termWidth >= 10 {
|
||||
for _, ch := range content {
|
||||
if state.lineLength+1 > termWidth-5 {
|
||||
|
||||
@@ -47,7 +47,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
if opts.MultiModal {
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
@@ -592,7 +592,7 @@ func extractFileNames(input string) []string {
|
||||
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
|
||||
// and followed by more characters and a file extension
|
||||
// This will capture non filename strings, but we'll check for file existence to remove mismatches
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp|wav|mp4|webm|mov|avi|mkv|m4v)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
|
||||
return re.FindAllString(input, -1)
|
||||
@@ -608,10 +608,16 @@ func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Couldn't process file: %q\n", err)
|
||||
return "", imgs, err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
ext := strings.ToLower(filepath.Ext(nfp))
|
||||
switch ext {
|
||||
case ".wav":
|
||||
fmt.Fprintf(os.Stderr, "Added audio '%s'\n", nfp)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
}
|
||||
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
|
||||
input = strings.ReplaceAll(input, "'"+fp+"'", "")
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
@@ -685,9 +691,9 @@ func getImageData(filePath string) ([]byte, error) {
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp", "audio/wave"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
return nil, fmt.Errorf("invalid file type: %s", contentType)
|
||||
}
|
||||
|
||||
info, err := file.Stat()
|
||||
@@ -695,8 +701,7 @@ func getImageData(filePath string) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the file size exceeds 100MB
|
||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
|
||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB
|
||||
if info.Size() > maxSize {
|
||||
return nil, errors.New("file size exceeds maximum limit (100MB)")
|
||||
}
|
||||
|
||||
@@ -84,3 +84,33 @@ func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, cleaned, "before after")
|
||||
}
|
||||
|
||||
func TestExtractFileDataWAV(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fp := filepath.Join(dir, "sample.wav")
|
||||
data := make([]byte, 600)
|
||||
copy(data[:44], []byte{
|
||||
'R', 'I', 'F', 'F',
|
||||
0x58, 0x02, 0x00, 0x00, // file size - 8
|
||||
'W', 'A', 'V', 'E',
|
||||
'f', 'm', 't', ' ',
|
||||
0x10, 0x00, 0x00, 0x00, // fmt chunk size
|
||||
0x01, 0x00, // PCM
|
||||
0x01, 0x00, // mono
|
||||
0x80, 0x3e, 0x00, 0x00, // 16000 Hz
|
||||
0x00, 0x7d, 0x00, 0x00, // byte rate
|
||||
0x02, 0x00, // block align
|
||||
0x10, 0x00, // 16-bit
|
||||
'd', 'a', 't', 'a',
|
||||
0x34, 0x02, 0x00, 0x00, // data size
|
||||
})
|
||||
if err := os.WriteFile(fp, data, 0o600); err != nil {
|
||||
t.Fatalf("failed to write test audio: %v", err)
|
||||
}
|
||||
|
||||
input := "before " + fp + " after"
|
||||
cleaned, imgs, err := extractFileData(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, "before after", cleaned)
|
||||
}
|
||||
|
||||
@@ -290,6 +290,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||
case "Gemma3nForConditionalGeneration":
|
||||
conv = &gemma3nModel{}
|
||||
case "Gemma4ForCausalLM", "Gemma4ForConditionalGeneration":
|
||||
conv = &gemma4Model{Architecture: p.Architectures[0]}
|
||||
case "Phi3ForCausalLM":
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
|
||||
556
convert/convert_gemma4.go
Normal file
556
convert/convert_gemma4.go
Normal file
@@ -0,0 +1,556 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
// gemma4Model converts a HuggingFace Gemma 4 checkpoint — text model plus
// optional vision and audio towers — into GGUF metadata and tensors.
// The nested structs mirror the "text_config", "vision_config", and
// "audio_config" sections of the HF config.json; pointer fields distinguish
// "absent from config" from an explicit zero.
type gemma4Model struct {
	gemmaModel
	Architecture string
	TextModel    struct {
		HiddenSize            uint32   `json:"hidden_size"`
		NumHiddenLayers       uint32   `json:"num_hidden_layers"`
		IntermediateSize      uint32   `json:"intermediate_size"`
		NumAttentionHeads     uint32   `json:"num_attention_heads"`
		NumKeyValueHeads      uint32   `json:"num_key_value_heads"`
		// HeadDim is the local (sliding-window) head dim; GlobalHeadDim is
		// the head dim used by global-attention layers.
		HeadDim               uint32   `json:"head_dim"`
		GlobalHeadDim         uint32   `json:"global_head_dim"`
		VocabSize             uint32   `json:"vocab_size"`
		RMSNormEps            float32  `json:"rms_norm_eps"`
		MaxPositionEmbeddings uint32   `json:"max_position_embeddings"`
		SlidingWindow         uint32   `json:"sliding_window"`
		SlidingWindowPattern  *int32   `json:"_sliding_window_pattern"`
		// LayerTypes marks each layer as "sliding_attention" or global; it
		// drives the per-layer KV-head and sliding-window-pattern arrays in KV().
		LayerTypes            []string `json:"layer_types"`
		FinalLogitSoftcapping float32  `json:"final_logit_softcapping"`
		// MoE settings; only meaningful when EnableMoeBlock is true.
		EnableMoeBlock          bool    `json:"enable_moe_block"`
		NumExperts              *uint32 `json:"num_experts"`
		TopKExperts             *uint32 `json:"top_k_experts"`
		ExpertIntermediateSize  *uint32 `json:"moe_intermediate_size"`
		// Per-layer-embedding (PLE) width; nil when the model has no PLE.
		HiddenSizePerLayerInput *uint32 `json:"hidden_size_per_layer_input"`
		// NumKVSharedLayers: trailing layers that share KV caches (and, with
		// UseDoubleWideMLP, get a 2x-wide FFN).
		NumKVSharedLayers      uint32  `json:"num_kv_shared_layers"`
		AttentionKEqV          bool    `json:"attention_k_eq_v"`
		NumGlobalKeyValueHeads *uint32 `json:"num_global_key_value_heads"`
		QueryPreAttnScalar     *uint32 `json:"query_pre_attn_scalar"`
		UseDoubleWideMLP       bool    `json:"use_double_wide_mlp"`
		// RopeParameters is keyed by layer type ("full_attention",
		// "sliding_attention"); entries may be nil.
		RopeParameters map[string]*struct {
			RopeTheta           float32  `json:"rope_theta"`
			PartialRotaryFactor *float32 `json:"partial_rotary_factor"`
		} `json:"rope_parameters"`
	} `json:"text_config"`

	// VisionModel is the (optional) vision tower config; a zero
	// NumHiddenLayers means no vision tower.
	VisionModel struct {
		HiddenSize        uint32  `json:"hidden_size"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		PatchSize         uint32  `json:"patch_size"`
		NumChannels       uint32  `json:"num_channels"`
		PoolingKernelSize uint32  `json:"pooling_kernel_size"`
		LayerNormEps      float32 `json:"layer_norm_eps"`
	} `json:"vision_config"`

	// AudioModel is the (optional) audio tower config; nil when the
	// checkpoint has no audio_config section.
	AudioModel *struct {
		HiddenSize        uint32  `json:"hidden_size"`
		OutputProjDims    uint32  `json:"output_proj_dims"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		ConvKernelSize    uint32  `json:"conv_kernel_size"`
		RMSNormEps        float32 `json:"rms_norm_eps"`
	} `json:"audio_config"`
}
|
||||
|
||||
// KV builds the GGUF key/value metadata for a Gemma 4 checkpoint: core text
// hyperparameters, per-layer arrays for heterogeneous layers (FFN widths,
// KV-head counts, sliding-window pattern), RoPE settings, and the optional
// MoE, vision, and audio sections when present in the config.
func (p *gemma4Model) KV(t *Tokenizer) KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "gemma4"
	kv["tokenizer.ggml.model"] = "llama"
	kv["tokenizer.ggml.pre"] = "gemma4"

	tc := p.TextModel

	kv["gemma4.block_count"] = tc.NumHiddenLayers
	kv["gemma4.embedding_length"] = tc.HiddenSize

	// Per-layer FFN width: when use_double_wide_mlp is set, KV-shared layers get 2x FFN width.
	// KV-shared layers are the trailing NumKVSharedLayers layers.
	if tc.UseDoubleWideMLP && tc.NumKVSharedLayers > 0 {
		firstShared := int(tc.NumHiddenLayers) - int(tc.NumKVSharedLayers)
		ffnWidths := make([]int32, tc.NumHiddenLayers)
		for i := range ffnWidths {
			if i >= firstShared {
				ffnWidths[i] = int32(tc.IntermediateSize * 2)
			} else {
				ffnWidths[i] = int32(tc.IntermediateSize)
			}
		}
		kv["gemma4.feed_forward_length"] = ffnWidths
	} else {
		// Homogeneous FFN width: emit a single scalar instead of an array.
		kv["gemma4.feed_forward_length"] = tc.IntermediateSize
	}
	kv["gemma4.context_length"] = tc.MaxPositionEmbeddings
	kv["gemma4.attention.head_count"] = tc.NumAttentionHeads
	// Per-layer KV head count array: SWA layers use NumKeyValueHeads, global layers use NumGlobalKeyValueHeads.
	// When the two counts agree (or layer types are unknown) a single scalar suffices.
	if tc.NumGlobalKeyValueHeads != nil && *tc.NumGlobalKeyValueHeads != tc.NumKeyValueHeads && len(tc.LayerTypes) > 0 {
		kvHeads := make([]int32, len(tc.LayerTypes))
		for i, lt := range tc.LayerTypes {
			if lt == "sliding_attention" {
				kvHeads[i] = int32(tc.NumKeyValueHeads)
			} else {
				kvHeads[i] = int32(*tc.NumGlobalKeyValueHeads)
			}
		}
		kv["gemma4.attention.head_count_kv"] = kvHeads
	} else {
		kv["gemma4.attention.head_count_kv"] = tc.NumKeyValueHeads
	}
	// key_length = global head dim, key_length_swa = local (SWA) head dim
	kv["gemma4.attention.key_length"] = tc.GlobalHeadDim
	kv["gemma4.attention.value_length"] = tc.GlobalHeadDim
	kv["gemma4.attention.key_length_swa"] = tc.HeadDim
	kv["gemma4.attention.value_length_swa"] = tc.HeadDim
	kv["gemma4.attention.layer_norm_rms_epsilon"] = tc.RMSNormEps
	kv["gemma4.attention.sliding_window"] = tc.SlidingWindow

	// Sliding window pattern from layer_types: one bool per layer,
	// true when the layer uses sliding (local) attention.
	if len(tc.LayerTypes) > 0 {
		kv["gemma4.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
			for _, lt := range tc.LayerTypes {
				if !yield(lt == "sliding_attention") {
					break
				}
			}
		})
	}

	kv["gemma4.attention.shared_kv_layers"] = tc.NumKVSharedLayers

	// RoPE: dimension_count is the full global head dim (freq_factors handle partial rotation)
	if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil {
		kv["gemma4.rope.freq_base"] = rp.RopeTheta
		kv["gemma4.rope.dimension_count"] = tc.GlobalHeadDim
	}
	if rp, ok := tc.RopeParameters["sliding_attention"]; ok && rp != nil {
		kv["gemma4.rope.freq_base_swa"] = rp.RopeTheta
		kv["gemma4.rope.dimension_count_swa"] = tc.HeadDim
	}

	if tc.FinalLogitSoftcapping > 0 {
		kv["gemma4.final_logit_softcapping"] = tc.FinalLogitSoftcapping
	}

	// MoE — only emitted when the config enables the MoE block.
	if tc.EnableMoeBlock && tc.NumExperts != nil {
		kv["gemma4.expert_count"] = *tc.NumExperts
		if tc.TopKExperts != nil {
			kv["gemma4.expert_used_count"] = *tc.TopKExperts
		}
		if tc.ExpertIntermediateSize != nil {
			kv["gemma4.expert_feed_forward_length"] = *tc.ExpertIntermediateSize
		}
	}

	// PLE (per-layer embedding) width — always emit, even when 0.
	pleSize := uint32(0)
	if tc.HiddenSizePerLayerInput != nil {
		pleSize = *tc.HiddenSizePerLayerInput
	}
	kv["gemma4.embedding_length_per_layer_input"] = pleSize

	// Vision model KV metadata — present only when the config has a vision tower.
	vc := p.VisionModel
	if vc.NumHiddenLayers > 0 {
		kv["gemma4.vision.block_count"] = vc.NumHiddenLayers
		kv["gemma4.vision.embedding_length"] = vc.HiddenSize
		kv["gemma4.vision.attention.head_count"] = vc.NumAttentionHeads
		kv["gemma4.vision.feed_forward_length"] = vc.IntermediateSize
		kv["gemma4.vision.patch_size"] = vc.PatchSize
		numCh := vc.NumChannels
		if numCh == 0 {
			numCh = 3 // default to RGB when num_channels is omitted
		}
		kv["gemma4.vision.num_channels"] = numCh
		nMerge := vc.PoolingKernelSize
		if nMerge == 0 {
			nMerge = 3 // default pooling kernel — TODO confirm against reference config
		}
		kv["gemma4.vision.projector.scale_factor"] = nMerge
		eps := vc.LayerNormEps
		if eps == 0 {
			eps = 1e-6
		}
		kv["gemma4.vision.attention.layer_norm_epsilon"] = eps
	}

	// Audio model KV metadata — present only when the config has an audio tower.
	if p.AudioModel != nil && p.AudioModel.NumHiddenLayers > 0 {
		ac := p.AudioModel
		kv["gemma4.audio.block_count"] = ac.NumHiddenLayers
		kv["gemma4.audio.embedding_length"] = ac.HiddenSize
		// FFN width is fixed at 4x hidden size for the conformer blocks.
		kv["gemma4.audio.feed_forward_length"] = ac.HiddenSize * 4
		kv["gemma4.audio.attention.head_count"] = ac.NumAttentionHeads
		eps := ac.RMSNormEps
		if eps == 0 {
			eps = 1e-6
		}
		kv["gemma4.audio.attention.layer_norm_epsilon"] = eps
		if ac.ConvKernelSize > 0 {
			kv["gemma4.audio.conv_kernel_size"] = ac.ConvKernelSize
		}
	}

	return kv
}
|
||||
|
||||
func (p *gemma4Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
// First pass: collect vision clamp scalar values into a packed tensor.
|
||||
// Layout: per vision layer (0..N-1), 7 linears (q,k,v,out,gate,up,down) × 4 values (inMin,inMax,outMin,outMax).
|
||||
// Then 4 values for the projector (mm.input_projection).
|
||||
clampSuffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
|
||||
clampMap := make(map[string]float32)
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
for _, sfx := range clampSuffixes {
|
||||
if strings.HasSuffix(name, sfx) && (strings.Contains(name, "vision_tower") || strings.Contains(name, "embed_vision")) {
|
||||
var buf bytes.Buffer
|
||||
t.WriteTo(&buf)
|
||||
data := buf.Bytes()
|
||||
if len(data) >= 4 {
|
||||
clampMap[name] = math.Float32frombits(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16 | uint32(data[3])<<24)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
|
||||
// Skip embedding_post_projection_norm — used as weightless RMS norm in inference
|
||||
if strings.Contains(name, "embedding_post_projection_norm") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Vision tensor renaming: match published mmproj GGUF names
|
||||
if strings.HasPrefix(name, "v.blk.") {
|
||||
name = strings.Replace(name, ".attn_norm.", ".ln1.", 1)
|
||||
name = strings.Replace(name, ".ffn_norm.", ".ln2.", 1)
|
||||
name = strings.Replace(name, ".attn_output.", ".attn_out.", 1)
|
||||
name = strings.Replace(name, ".post_attention_norm.", ".attn_post_norm.", 1)
|
||||
name = strings.Replace(name, ".post_ffw_norm.", ".ffn_post_norm.", 1)
|
||||
name = strings.Replace(name, ".layer_output_scale.", ".out_scale.", 1)
|
||||
}
|
||||
|
||||
// per_dim_scale: apply softplus to weight data and add .weight suffix.
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.HasSuffix(name, "per_dim_scale") {
|
||||
name = name + ".weight"
|
||||
t.SetRepacker(softplusRepacker)
|
||||
}
|
||||
|
||||
// Depthwise conv1d: squeeze middle dimension [C, 1, K] → [C, K].
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") {
|
||||
t.SetRepacker(squeezeMiddleDim)
|
||||
}
|
||||
|
||||
shape := t.Shape()
|
||||
|
||||
// Convert scalar tensors (input_min/max, output_min/max) to 1D
|
||||
if len(shape) == 0 {
|
||||
shape = []uint64{1}
|
||||
}
|
||||
|
||||
// Depthwise conv1d shape: safetensors [C, 1, K] → GGUF ne[K, C].
|
||||
// Shape array here maps to GGUF ne[] directly, but safetensors reader
|
||||
// stores shape in PyTorch order [C, 1, K] which the GGUF writer inverts.
|
||||
// Published GGUF has ne[0]=K, ne[1]=C → shape array must be [K, C].
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") && len(shape) == 3 {
|
||||
shape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
|
||||
// MoE expert weights: no transpose needed. Safetensors stores [experts, out, in]
|
||||
// which the framework reverses to GGUF ne=[in, out, experts], matching ggml_mul_mat_id.
|
||||
// (transposeExperts was incorrectly swapping dims — removed)
|
||||
|
||||
// Audio conv weights are forced to F32 via tensorBase.Kind() in reader.go
|
||||
// (im2col doesn't support BF16). No kindOverride needed — the Kind() method
|
||||
// controls both the GGUF header type AND the WriteTo data encoding path.
|
||||
var kindOverride *uint32
|
||||
|
||||
// Vision patch embedding: reshape from [n_embd, ksize_sq_c] to [n_embd, 3, patch_size, patch_size]
|
||||
// Must be stored as F16 (not BF16) because the Conv2D im2col kernel requires F16/F32.
|
||||
if strings.Contains(name, "v.patch_embd.weight") && len(shape) == 2 {
|
||||
nEmbd := shape[0]
|
||||
patchSize := uint64(p.VisionModel.PatchSize)
|
||||
if patchSize == 0 {
|
||||
patchSize = 16
|
||||
}
|
||||
numCh := uint64(p.VisionModel.NumChannels)
|
||||
if numCh == 0 {
|
||||
numCh = 3
|
||||
}
|
||||
t.SetRepacker(p.reshapePatchEmbed)
|
||||
shape = []uint64{nEmbd, numCh, patchSize, patchSize}
|
||||
f16Kind := uint32(1) // tensorKindFP16
|
||||
kindOverride = &f16Kind
|
||||
}
|
||||
|
||||
// Vision position embedding: keep 3D [2, maxPos, nEmbd] — matching published mmproj format.
|
||||
// The framework reverses shape to GGUF ne=[nEmbd, maxPos, 2]. No data repacking needed.
|
||||
|
||||
kind := t.Kind()
|
||||
if kindOverride != nil {
|
||||
kind = *kindOverride
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: kind,
|
||||
Shape: shape,
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
// Generate a single global rope_freqs.weight for proportional RoPE on global attention layers.
|
||||
// This matches the published GGUF format: one global tensor shared by all layers.
|
||||
// Global layers use partial_rotary_factor (0.25) — only rotate that fraction of dims.
|
||||
// Dimensions beyond the rotated portion get freq_factor=1e30 (effectively no rotation).
|
||||
tc := p.TextModel
|
||||
if tc.GlobalHeadDim > 0 {
|
||||
globalFreqsSize := tc.GlobalHeadDim / 2 // freq_factors are per dimension pair
|
||||
|
||||
// Compute number of rotated pairs for global layers
|
||||
partialRotaryFactor := float32(0.25) // default
|
||||
if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil && rp.PartialRotaryFactor != nil {
|
||||
partialRotaryFactor = *rp.PartialRotaryFactor
|
||||
}
|
||||
nRotFull := int(float32(tc.GlobalHeadDim) * partialRotaryFactor / 2)
|
||||
|
||||
freqs := make(ropeFactor, globalFreqsSize)
|
||||
for j := range freqs {
|
||||
if j < nRotFull {
|
||||
freqs[j] = 1.0
|
||||
} else {
|
||||
freqs[j] = 1e30 // effectively disable rotation
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: "rope_freqs.weight",
|
||||
Kind: 0, // F32
|
||||
Shape: []uint64{uint64(len(freqs))},
|
||||
WriterTo: freqs,
|
||||
})
|
||||
}
|
||||
|
||||
// Emit packed vision clamp data as a single F32 tensor.
|
||||
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
|
||||
// then 4 floats for the projector. Total = (numLayers*7 + 1) * 4 floats.
|
||||
if len(clampMap) > 0 {
|
||||
numLayers := int(p.VisionModel.NumHiddenLayers)
|
||||
linearNames := []string{"attn_q", "attn_k", "attn_v", "attn_out", "ffn_gate", "ffn_up", "ffn_down"}
|
||||
suffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
|
||||
|
||||
totalFloats := (numLayers*len(linearNames) + 1) * 4 // +1 for projector
|
||||
clampData := make([]float32, totalFloats)
|
||||
|
||||
for layer := range numLayers {
|
||||
for li, ln := range linearNames {
|
||||
for si, sfx := range suffixes {
|
||||
sfxMap := map[string]string{"attn_q": "q_proj", "attn_k": "k_proj", "attn_v": "v_proj", "attn_out": "o_proj", "ffn_gate": "gate_proj", "ffn_up": "up_proj", "ffn_down": "down_proj"}
|
||||
for origName, val := range clampMap {
|
||||
if strings.Contains(origName, fmt.Sprintf("layers.%d.", layer)) && strings.HasSuffix(origName, sfx) && strings.Contains(origName, sfxMap[ln]) {
|
||||
idx := (layer*len(linearNames)+li)*4 + si
|
||||
clampData[idx] = val
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Projector clamp values
|
||||
projIdx := numLayers * len(linearNames) * 4
|
||||
for si, sfx := range suffixes {
|
||||
for origName, val := range clampMap {
|
||||
if strings.Contains(origName, "input_projection") && strings.HasSuffix(origName, sfx) {
|
||||
clampData[projIdx+si] = val
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, clampData)
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: "v.clamp_data",
|
||||
Kind: 0, // F32
|
||||
Shape: []uint64{uint64(totalFloats)},
|
||||
WriterTo: &buf,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// reshapePatchEmbed reshapes the vision patch embedding from HF layout [n_embd, ksize*ksize*channels]
|
||||
// to GGUF layout [n_embd, channels, patch_size, patch_size].
|
||||
func (*gemma4Model) reshapePatchEmbed(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
if len(shape) != 2 {
|
||||
return data, nil
|
||||
}
|
||||
nEmbd := int(shape[0])
|
||||
ksqC := int(shape[1])
|
||||
nChannels := 3
|
||||
patchSize := int(math.Sqrt(float64(ksqC / nChannels)))
|
||||
|
||||
// HF layout: [n_embd, patch_size * patch_size * channels] (row-major)
|
||||
// Need: [n_embd, channels, patch_size, patch_size]
|
||||
result := make([]float32, len(data))
|
||||
for e := range nEmbd {
|
||||
for c := range nChannels {
|
||||
for h := range patchSize {
|
||||
for w := range patchSize {
|
||||
srcIdx := e*ksqC + h*patchSize*nChannels + w*nChannels + c
|
||||
dstIdx := e*nChannels*patchSize*patchSize + c*patchSize*patchSize + h*patchSize + w
|
||||
result[dstIdx] = data[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
shape[0] = uint64(nEmbd)
|
||||
shape[1] = uint64(nChannels * patchSize * patchSize)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// softplusRepacker applies softplus (ln(1 + exp(x))) to tensor data.
// Used for per_dim_scale tensors which the published GGUF stores pre-activated.
//
// The naive form math.Log(1 + math.Exp(x)) overflows to +Inf once exp(x)
// overflows float64; the numerically stable identity
// softplus(x) = max(x, 0) + log1p(exp(-|x|)) yields the correct finite
// result (≈x for large positive x, ≈0 for large negative x).
func softplusRepacker(_ string, data []float32, shape []uint64) ([]float32, error) {
	result := make([]float32, len(data))
	for i, x := range data {
		v := float64(x)
		if v > 0 {
			result[i] = float32(v + math.Log1p(math.Exp(-v)))
		} else {
			result[i] = float32(math.Log1p(math.Exp(v)))
		}
	}
	return result, nil
}
|
||||
|
||||
// squeezeMiddleDim squeezes the middle dimension from [C, 1, K] → [C, K] for depthwise conv1d weights.
// Data layout stays the same since the middle dim is 1 — just a shape change.
// The shape rewrite itself happens at the call site (Tensors sets the 2-D
// shape); registering this identity repacker only forces the writer down the
// repack path, so the float data is passed through untouched.
func squeezeMiddleDim(_ string, data []float32, _ []uint64) ([]float32, error) {
	return data, nil
}
|
||||
|
||||
func (p *gemma4Model) Replacements() []string {
|
||||
return []string{
|
||||
// ClippableLinear wraps nn.Linear — strip .linear. from weight path
|
||||
".linear.weight", ".weight",
|
||||
".linear.bias", ".bias",
|
||||
|
||||
// Audio SSCP (Sub-Sample Convolution Projection)
|
||||
"model.audio_tower.subsample_conv_projection.conv_0.conv", "a.conv1d.0",
|
||||
"model.audio_tower.subsample_conv_projection.conv_0.norm", "a.conv1d.0.norm",
|
||||
"model.audio_tower.subsample_conv_projection.conv_1.conv", "a.conv1d.1",
|
||||
"model.audio_tower.subsample_conv_projection.conv_1.norm", "a.conv1d.1.norm",
|
||||
"model.audio_tower.subsample_conv_projection.input_proj_linear", "a.pre_encode.out",
|
||||
|
||||
// Audio conformer blocks
|
||||
"model.audio_tower.conformer", "a.blk",
|
||||
|
||||
// Audio conformer attention
|
||||
"attention.attn.relative_position_embedding.pos_proj", "linear_pos",
|
||||
"attention.attn.per_dim_key_scale", "per_dim_k_scale",
|
||||
"attention.attn.per_dim_scale", "per_dim_scale",
|
||||
"attention.attn.q_proj", "attn_q",
|
||||
"attention.attn.k_proj", "attn_k",
|
||||
"attention.attn.v_proj", "attn_v",
|
||||
"attention.pre_attn_norm", "ln1",
|
||||
"attention.post_norm", "ln2",
|
||||
"attention.post", "attn_out",
|
||||
|
||||
// Audio conformer feedforward
|
||||
"ffw_layer_start.pre_layer_norm", "ffn_norm",
|
||||
"ffw_layer_start.post_layer_norm", "ffn_post_norm",
|
||||
"ffw_layer_start.ffw_layer_1", "ffn_up",
|
||||
"ffw_layer_start.ffw_layer_2", "ffn_down",
|
||||
"ffw_layer_end.pre_layer_norm", "ffn_norm_1",
|
||||
"ffw_layer_end.post_layer_norm", "ffn_post_norm_1",
|
||||
"ffw_layer_end.ffw_layer_1", "ffn_up_1",
|
||||
"ffw_layer_end.ffw_layer_2", "ffn_down_1",
|
||||
|
||||
// Audio conformer lightweight conv1d
|
||||
"lconv1d.depthwise_conv1d", "conv_dw",
|
||||
"lconv1d.pre_layer_norm", "conv_norm",
|
||||
"lconv1d.conv_norm", "norm_conv",
|
||||
"lconv1d.linear_start", "conv_pw1",
|
||||
"lconv1d.linear_end", "conv_pw2",
|
||||
|
||||
// Audio block final norm
|
||||
"norm_out", "layer_pre_norm",
|
||||
|
||||
// Audio embedder and output projection
|
||||
"model.embed_audio.embedding_projection", "mm.a.input_projection",
|
||||
"model.audio_tower.output_proj", "mm.a.fc",
|
||||
|
||||
// Vision encoder
|
||||
"model.vision_tower.encoder.layers", "v.blk",
|
||||
"model.vision_tower.patch_embedder.input_proj", "v.patch_embd",
|
||||
"model.vision_tower.patch_embedder.position_embedding_table", "v.position_embd.weight",
|
||||
"model.vision_tower.std_bias", "v.std_bias",
|
||||
"model.vision_tower.std_scale", "v.std_scale",
|
||||
|
||||
// Vision multimodal projector
|
||||
"model.embed_vision.embedding_projection", "mm.input_projection",
|
||||
|
||||
// Text model
|
||||
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
|
||||
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"model.language_model.layers", "blk",
|
||||
|
||||
// Shared attention replacements (work for both text and vision tensors)
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
|
||||
// Post norms
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"pre_feedforward_layernorm_2", "pre_ffw_norm_2",
|
||||
"pre_feedforward_layernorm", "ffn_norm",
|
||||
"post_feedforward_layernorm_1", "post_ffw_norm_1",
|
||||
"post_feedforward_layernorm_2", "post_ffw_norm_2",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
|
||||
// PLE
|
||||
"per_layer_input_gate", "inp_gate",
|
||||
"per_layer_projection", "proj",
|
||||
"post_per_layer_input_norm", "post_norm",
|
||||
|
||||
// MoE
|
||||
"router.proj", "ffn_gate_inp",
|
||||
"router.scale", "ffn_gate_inp.scale",
|
||||
"router.per_expert_scale.weight", "ffn_down_exps.scale",
|
||||
"router.per_expert_scale", "ffn_down_exps.scale",
|
||||
"experts.gate_up_proj.weight", "ffn_gate_up_exps.weight",
|
||||
"experts.gate_up_proj", "ffn_gate_up_exps.weight",
|
||||
"experts.down_proj.weight", "ffn_down_exps.weight",
|
||||
"experts.down_proj", "ffn_down_exps.weight",
|
||||
"moe.gate_proj", "ffn_gate_exps.weight",
|
||||
"moe.up_proj", "ffn_up_exps.weight",
|
||||
"moe.gate_up_proj.weight", "ffn_gate_up_exps.weight",
|
||||
"moe.gate_up_proj", "ffn_gate_up_exps.weight",
|
||||
"moe.down_proj", "ffn_down_exps.weight",
|
||||
"moe.per_expert_scale.weight", "ffn_down_exps.scale",
|
||||
"moe.per_expert_scale", "ffn_down_exps.scale",
|
||||
|
||||
// Layer scalar
|
||||
"layer_scalar", "layer_output_scale.weight",
|
||||
}
|
||||
}
|
||||
263
convert/convert_gemma4_test.go
Normal file
263
convert/convert_gemma4_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGemma4AudioReplacements(t *testing.T) {
|
||||
p := gemma4Model{}
|
||||
r := strings.NewReplacer(p.Replacements()...)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
// SSCP convolution blocks
|
||||
{
|
||||
"sscp conv0 weight",
|
||||
"model.audio_tower.subsample_conv_projection.conv_0.conv.weight",
|
||||
"a.conv1d.0.weight",
|
||||
},
|
||||
{
|
||||
"sscp conv0 norm",
|
||||
"model.audio_tower.subsample_conv_projection.conv_0.norm.weight",
|
||||
"a.conv1d.0.norm.weight",
|
||||
},
|
||||
{
|
||||
"sscp conv1 weight",
|
||||
"model.audio_tower.subsample_conv_projection.conv_1.conv.weight",
|
||||
"a.conv1d.1.weight",
|
||||
},
|
||||
{
|
||||
"sscp input proj weight",
|
||||
"model.audio_tower.subsample_conv_projection.input_proj_linear.weight",
|
||||
"a.pre_encode.out.weight",
|
||||
},
|
||||
{
|
||||
"sscp input proj bias",
|
||||
"model.audio_tower.subsample_conv_projection.input_proj_linear.bias",
|
||||
"a.pre_encode.out.bias",
|
||||
},
|
||||
|
||||
// Conformer attention
|
||||
{
|
||||
"attn q weight",
|
||||
"model.audio_tower.conformer.0.attention.attn.q_proj.linear.weight",
|
||||
"a.blk.0.attn_q.weight",
|
||||
},
|
||||
{
|
||||
"attn k weight",
|
||||
"model.audio_tower.conformer.5.attention.attn.k_proj.linear.weight",
|
||||
"a.blk.5.attn_k.weight",
|
||||
},
|
||||
{
|
||||
"attn v clamp input_min",
|
||||
"model.audio_tower.conformer.0.attention.attn.v_proj.input_min",
|
||||
"a.blk.0.attn_v.input_min",
|
||||
},
|
||||
{
|
||||
"attn out weight (ClippableLinear)",
|
||||
"model.audio_tower.conformer.0.attention.post.linear.weight",
|
||||
"a.blk.0.attn_out.weight",
|
||||
},
|
||||
{
|
||||
"attn out clamp output_max",
|
||||
"model.audio_tower.conformer.0.attention.post.output_max",
|
||||
"a.blk.0.attn_out.output_max",
|
||||
},
|
||||
{
|
||||
"attn pre norm",
|
||||
"model.audio_tower.conformer.0.attention.pre_attn_norm.weight",
|
||||
"a.blk.0.ln1.weight",
|
||||
},
|
||||
{
|
||||
"attn post norm",
|
||||
"model.audio_tower.conformer.0.attention.post_norm.weight",
|
||||
"a.blk.0.ln2.weight",
|
||||
},
|
||||
{
|
||||
"linear pos",
|
||||
"model.audio_tower.conformer.0.attention.attn.relative_position_embedding.pos_proj.weight",
|
||||
"a.blk.0.linear_pos.weight",
|
||||
},
|
||||
{
|
||||
"per dim scale",
|
||||
"model.audio_tower.conformer.0.attention.attn.per_dim_scale",
|
||||
"a.blk.0.per_dim_scale",
|
||||
},
|
||||
{
|
||||
"per dim key scale",
|
||||
"model.audio_tower.conformer.0.attention.attn.per_dim_key_scale",
|
||||
"a.blk.0.per_dim_k_scale",
|
||||
},
|
||||
|
||||
// Conformer feedforward start
|
||||
{
|
||||
"ffn up weight",
|
||||
"model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_1.linear.weight",
|
||||
"a.blk.0.ffn_up.weight",
|
||||
},
|
||||
{
|
||||
"ffn down weight",
|
||||
"model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_2.linear.weight",
|
||||
"a.blk.0.ffn_down.weight",
|
||||
},
|
||||
{
|
||||
"ffn norm",
|
||||
"model.audio_tower.conformer.0.ffw_layer_start.pre_layer_norm.weight",
|
||||
"a.blk.0.ffn_norm.weight",
|
||||
},
|
||||
{
|
||||
"ffn post norm",
|
||||
"model.audio_tower.conformer.0.ffw_layer_start.post_layer_norm.weight",
|
||||
"a.blk.0.ffn_post_norm.weight",
|
||||
},
|
||||
|
||||
// Conformer feedforward end
|
||||
{
|
||||
"ffn up 1 weight",
|
||||
"model.audio_tower.conformer.0.ffw_layer_end.ffw_layer_1.linear.weight",
|
||||
"a.blk.0.ffn_up_1.weight",
|
||||
},
|
||||
{
|
||||
"ffn down 1 weight",
|
||||
"model.audio_tower.conformer.0.ffw_layer_end.ffw_layer_2.linear.weight",
|
||||
"a.blk.0.ffn_down_1.weight",
|
||||
},
|
||||
{
|
||||
"ffn norm 1",
|
||||
"model.audio_tower.conformer.0.ffw_layer_end.pre_layer_norm.weight",
|
||||
"a.blk.0.ffn_norm_1.weight",
|
||||
},
|
||||
{
|
||||
"ffn post norm 1",
|
||||
"model.audio_tower.conformer.0.ffw_layer_end.post_layer_norm.weight",
|
||||
"a.blk.0.ffn_post_norm_1.weight",
|
||||
},
|
||||
|
||||
// Conformer lightweight conv1d
|
||||
{
|
||||
"conv dw weight",
|
||||
"model.audio_tower.conformer.0.lconv1d.depthwise_conv1d.weight",
|
||||
"a.blk.0.conv_dw.weight",
|
||||
},
|
||||
{
|
||||
"conv norm (pre_layer_norm)",
|
||||
"model.audio_tower.conformer.0.lconv1d.pre_layer_norm.weight",
|
||||
"a.blk.0.conv_norm.weight",
|
||||
},
|
||||
{
|
||||
"norm conv (conv_norm)",
|
||||
"model.audio_tower.conformer.0.lconv1d.conv_norm.weight",
|
||||
"a.blk.0.norm_conv.weight",
|
||||
},
|
||||
{
|
||||
"conv pw1 weight",
|
||||
"model.audio_tower.conformer.0.lconv1d.linear_start.linear.weight",
|
||||
"a.blk.0.conv_pw1.weight",
|
||||
},
|
||||
{
|
||||
"conv pw2 weight",
|
||||
"model.audio_tower.conformer.0.lconv1d.linear_end.linear.weight",
|
||||
"a.blk.0.conv_pw2.weight",
|
||||
},
|
||||
|
||||
// Audio embedder
|
||||
{
|
||||
"audio embedder projection weight",
|
||||
"model.embed_audio.embedding_projection.linear.weight",
|
||||
"mm.a.input_projection.weight",
|
||||
},
|
||||
{
|
||||
"audio embedder projection bias",
|
||||
"model.embed_audio.embedding_projection.linear.bias",
|
||||
"mm.a.input_projection.bias",
|
||||
},
|
||||
|
||||
// Audio output projection
|
||||
{
|
||||
"audio output proj weight",
|
||||
"model.audio_tower.output_proj.weight",
|
||||
"mm.a.fc.weight",
|
||||
},
|
||||
{
|
||||
"audio output proj bias",
|
||||
"model.audio_tower.output_proj.bias",
|
||||
"mm.a.fc.bias",
|
||||
},
|
||||
|
||||
// Verify vision tensors still work
|
||||
{
|
||||
"vision q weight",
|
||||
"model.vision_tower.encoder.layers.0.self_attn.q_proj.linear.weight",
|
||||
"v.blk.0.attn_q.weight",
|
||||
},
|
||||
{
|
||||
"vision std bias",
|
||||
"model.vision_tower.std_bias",
|
||||
"v.std_bias",
|
||||
},
|
||||
{
|
||||
"vision std scale",
|
||||
"model.vision_tower.std_scale",
|
||||
"v.std_scale",
|
||||
},
|
||||
{
|
||||
"vision patch embd",
|
||||
"model.vision_tower.patch_embedder.input_proj.weight",
|
||||
"v.patch_embd.weight",
|
||||
},
|
||||
{
|
||||
"vision projector",
|
||||
"model.embed_vision.embedding_projection.linear.weight",
|
||||
"mm.input_projection.weight",
|
||||
},
|
||||
|
||||
// Verify text tensors still work
|
||||
{
|
||||
"text attn q",
|
||||
"model.language_model.layers.0.self_attn.q_proj.weight",
|
||||
"blk.0.attn_q.weight",
|
||||
},
|
||||
{
|
||||
"text token embd",
|
||||
"model.language_model.embed_tokens.weight",
|
||||
"token_embd.weight",
|
||||
},
|
||||
{
|
||||
"text moe gate up fused",
|
||||
"model.language_model.layers.0.experts.gate_up_proj",
|
||||
"blk.0.ffn_gate_up_exps.weight",
|
||||
},
|
||||
{
|
||||
"text moe down",
|
||||
"model.language_model.layers.0.experts.down_proj",
|
||||
"blk.0.ffn_down_exps.weight",
|
||||
},
|
||||
{
|
||||
"text moe down with weight suffix",
|
||||
"model.language_model.layers.0.experts.down_proj.weight",
|
||||
"blk.0.ffn_down_exps.weight",
|
||||
},
|
||||
{
|
||||
"text moe per expert scale",
|
||||
"model.language_model.layers.0.router.per_expert_scale",
|
||||
"blk.0.ffn_down_exps.scale",
|
||||
},
|
||||
{
|
||||
"text moe per expert scale with weight suffix",
|
||||
"model.language_model.layers.0.router.per_expert_scale.weight",
|
||||
"blk.0.ffn_down_exps.scale",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := r.Replace(tt.in); got != tt.want {
|
||||
t.Errorf("Replace(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -205,8 +205,8 @@ func TestConvertInvalidDatatype(t *testing.T) {
|
||||
generateSafetensorTestData(t, tempDir, td)
|
||||
|
||||
err = ConvertModel(os.DirFS(tempDir), f)
|
||||
if err == nil || err.Error() != "unsupported safetensors model" {
|
||||
t.Errorf("expected error but didn't get one")
|
||||
if err == nil || !strings.Contains(err.Error(), "unknown data type") {
|
||||
t.Errorf("expected 'unknown data type' error but got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -42,8 +42,11 @@ func (t tensorBase) Kind() uint32 {
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
|
||||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
|
||||
strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights must be F32 for im2col
|
||||
strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights must be F32
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.position_embd.weight" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
t.name == "v.pre_tile_position_embd.weight" ||
|
||||
t.name == "v.post_tile_position_embd.weight" ||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
@@ -53,9 +52,10 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
|
||||
|
||||
for _, key := range keys {
|
||||
if value := headers[key]; value.Type != "" {
|
||||
// bitsandbytes quantized models are unsupported
|
||||
// Scalar tensors (e.g. clipped linear min/max) are 0-dim in safetensors.
|
||||
// Promote them to 1-dim so they can be stored in GGUF.
|
||||
if len(value.Shape) == 0 {
|
||||
return nil, errors.New("unsupported safetensors model")
|
||||
value.Shape = []uint64{1}
|
||||
}
|
||||
ggufName := replacer.Replace(key)
|
||||
if _, ok := names[ggufName]; ok {
|
||||
|
||||
@@ -281,6 +281,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"deepseekocr",
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"gemma4",
|
||||
"gptoss", "gpt-oss",
|
||||
"llama4",
|
||||
"mistral3",
|
||||
|
||||
259
integration/audio_test.go
Normal file
259
integration/audio_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var defaultAudioModels = []string{
|
||||
"gemma4-e2b",
|
||||
"gemma4-e4b",
|
||||
}
|
||||
|
||||
// decodeTestAudio returns the test audio clip ("Why is the sky blue?", 16kHz mono WAV).
|
||||
func decodeTestAudio(t *testing.T) api.ImageData {
|
||||
t.Helper()
|
||||
data, err := base64.StdEncoding.DecodeString(audioEncodingPrompt)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode test audio: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// setupAudioModel pulls the model, preloads it, and skips if it doesn't support audio.
|
||||
func setupAudioModel(ctx context.Context, t *testing.T, client *api.Client, model string) {
|
||||
t.Helper()
|
||||
requireCapability(ctx, t, client, model, "audio")
|
||||
pullOrSkip(ctx, t, client, model)
|
||||
err := client.Generate(ctx, &api.GenerateRequest{Model: model}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", model, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudioTranscription tests that the model can transcribe audio to text.
|
||||
func TestAudioTranscription(t *testing.T) {
|
||||
for _, model := range testModels(defaultAudioModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
setupAudioModel(ctx, t, client, model)
|
||||
audio := decodeTestAudio(t)
|
||||
noThink := &api.ThinkValue{Value: false}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Think: noThink,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "Transcribe the audio exactly as spoken. Output only the transcription.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Transcribe this audio.",
|
||||
Images: []api.ImageData{audio},
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_predict": 50,
|
||||
},
|
||||
}
|
||||
|
||||
// The audio says "Why is the sky blue?" — expect key words in transcription.
|
||||
DoChat(ctx, t, client, req, []string{"sky", "blue"}, 60*time.Second, 10*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudioResponse tests that the model can respond to a spoken question.
|
||||
func TestAudioResponse(t *testing.T) {
|
||||
for _, model := range testModels(defaultAudioModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
setupAudioModel(ctx, t, client, model)
|
||||
audio := decodeTestAudio(t)
|
||||
noThink := &api.ThinkValue{Value: false}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Think: noThink,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "",
|
||||
Images: []api.ImageData{audio},
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_predict": 200,
|
||||
},
|
||||
}
|
||||
|
||||
// The audio asks "Why is the sky blue?" — expect an answer about light/scattering.
|
||||
DoChat(ctx, t, client, req, []string{
|
||||
"scatter", "light", "blue", "atmosphere", "wavelength", "rayleigh",
|
||||
}, 60*time.Second, 10*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIAudioTranscription tests the /v1/audio/transcriptions endpoint.
|
||||
func TestOpenAIAudioTranscription(t *testing.T) {
|
||||
for _, model := range testModels(defaultAudioModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, endpoint, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
setupAudioModel(ctx, t, client, model)
|
||||
audioBytes := decodeTestAudio(t)
|
||||
|
||||
// Build multipart form request.
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
writer.WriteField("model", model)
|
||||
part, err := writer.CreateFormFile("file", "prompt.wav")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
part.Write(audioBytes)
|
||||
writer.Close()
|
||||
|
||||
url := fmt.Sprintf("http://%s/v1/audio/transcriptions", endpoint)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, &body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
text := strings.ToLower(string(respBody))
|
||||
if !strings.Contains(text, "sky") && !strings.Contains(text, "blue") {
|
||||
t.Errorf("transcription response missing expected words, got: %s", string(respBody))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIChatWithAudio tests /v1/chat/completions with input_audio content.
|
||||
func TestOpenAIChatWithAudio(t *testing.T) {
|
||||
for _, model := range testModels(defaultAudioModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, endpoint, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
setupAudioModel(ctx, t, client, model)
|
||||
audioB64 := audioEncodingPrompt
|
||||
|
||||
reqBody := fmt.Sprintf(`{
|
||||
"model": %q,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "input_audio", "input_audio": {"data": %q, "format": "wav"}}
|
||||
]
|
||||
}],
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"max_tokens": 200,
|
||||
"think": false
|
||||
}`, model, strings.TrimSpace(audioB64))
|
||||
|
||||
url := fmt.Sprintf("http://%s/v1/chat/completions", endpoint)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(reqBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal(respBytes, &result); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Choices) == 0 {
|
||||
t.Fatal("no choices in response")
|
||||
}
|
||||
|
||||
text := strings.ToLower(result.Choices[0].Message.Content + " " + result.Choices[0].Message.Reasoning)
|
||||
found := false
|
||||
for _, word := range []string{"sky", "blue", "scatter", "light", "atmosphere"} {
|
||||
if strings.Contains(text, word) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("response missing expected words about sky/blue/light, got: %s", result.Choices[0].Message.Content)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
9
integration/audio_test_data_test.go
Normal file
9
integration/audio_test_data_test.go
Normal file
File diff suppressed because one or more lines are too long
@@ -51,6 +51,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
thinkOff := api.ThinkValue{Value: false}
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
@@ -59,6 +60,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
Content: "Write me a story in english with a lot of emojis",
|
||||
},
|
||||
},
|
||||
Think: &thinkOff,
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
|
||||
@@ -15,6 +15,7 @@ func TestVisionModels(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
|
||||
defaultVisionModels := []string{
|
||||
"gemma4",
|
||||
"qwen2.5vl",
|
||||
"llama3.2-vision",
|
||||
"gemma3",
|
||||
@@ -23,6 +24,8 @@ func TestVisionModels(t *testing.T) {
|
||||
"ministral-3",
|
||||
}
|
||||
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
for _, model := range testModels(defaultVisionModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
@@ -30,10 +33,7 @@ func TestVisionModels(t *testing.T) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
if testModel != "" {
|
||||
requireCapability(ctx, t, client, model, "vision")
|
||||
}
|
||||
|
||||
requireCapability(ctx, t, client, model, "vision")
|
||||
pullOrSkip(ctx, t, client, model)
|
||||
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
|
||||
155
integration/thinking_test.go
Normal file
155
integration/thinking_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TestThinkingEnabled verifies that when thinking is requested, the model
|
||||
// produces both thinking and content output without leaking raw channel tags.
|
||||
func TestThinkingEnabled(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
models := testModels([]string{smol})
|
||||
for _, modelName := range models {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
requireCapability(ctx, t, client, modelName, "thinking")
|
||||
pullOrSkip(ctx, t, client, modelName)
|
||||
|
||||
think := api.ThinkValue{Value: true}
|
||||
stream := false
|
||||
req := api.ChatRequest{
|
||||
Model: modelName,
|
||||
Stream: &stream,
|
||||
Think: &think,
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What is 12 * 15? Think step by step."},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 42,
|
||||
"num_predict": 512,
|
||||
},
|
||||
}
|
||||
|
||||
var response api.ChatResponse
|
||||
err := client.Chat(ctx, &req, func(cr api.ChatResponse) error {
|
||||
response = cr
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "model requires more system memory") {
|
||||
t.Skip("model too large for test system")
|
||||
}
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
thinking := response.Message.Thinking
|
||||
|
||||
// Thinking should be non-empty when thinking is enabled
|
||||
if thinking == "" {
|
||||
t.Error("expected non-empty thinking output when thinking is enabled")
|
||||
}
|
||||
|
||||
// The answer (180) should appear in thinking, content, or both.
|
||||
// Some models put everything in thinking and leave content empty
|
||||
// if they hit the token limit while still thinking.
|
||||
combined := thinking + " " + content
|
||||
if !strings.Contains(combined, "180") {
|
||||
t.Errorf("expected '180' in thinking or content, got thinking=%q content=%q", thinking, content)
|
||||
}
|
||||
|
||||
// Neither thinking nor content should contain raw channel tags
|
||||
if strings.Contains(content, "<|channel>") || strings.Contains(content, "<channel|>") {
|
||||
t.Errorf("content contains raw channel tags: %s", content)
|
||||
}
|
||||
if strings.Contains(thinking, "<|channel>") || strings.Contains(thinking, "<channel|>") {
|
||||
t.Errorf("thinking contains raw channel tags: %s", thinking)
|
||||
}
|
||||
|
||||
t.Logf("thinking (%d chars): %.100s...", len(thinking), thinking)
|
||||
t.Logf("content (%d chars): %s", len(content), content)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestThinkingSuppressed verifies that when thinking is NOT requested,
|
||||
// the model does not leak thinking/channel content into the response.
|
||||
func TestThinkingSuppressed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
models := testModels([]string{smol})
|
||||
for _, modelName := range models {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
requireCapability(ctx, t, client, modelName, "thinking")
|
||||
pullOrSkip(ctx, t, client, modelName)
|
||||
|
||||
stream := false
|
||||
req := api.ChatRequest{
|
||||
Model: modelName,
|
||||
Stream: &stream,
|
||||
// Think is nil — thinking not requested
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What is the capital of Japan? Answer in one word."},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 42,
|
||||
"num_predict": 64,
|
||||
},
|
||||
}
|
||||
|
||||
var response api.ChatResponse
|
||||
err := client.Chat(ctx, &req, func(cr api.ChatResponse) error {
|
||||
response = cr
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "model requires more system memory") {
|
||||
t.Skip("model too large for test system")
|
||||
}
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
thinking := response.Message.Thinking
|
||||
|
||||
// The answer should appear in content or thinking
|
||||
combined := content + " " + thinking
|
||||
if !strings.Contains(combined, "Tokyo") {
|
||||
t.Errorf("expected 'Tokyo' in content or thinking, got content=%q thinking=%q", content, thinking)
|
||||
}
|
||||
|
||||
// Content must NOT contain channel/thinking tags
|
||||
if strings.Contains(content, "<|channel>") || strings.Contains(content, "<channel|>") {
|
||||
t.Errorf("content contains leaked channel tags when thinking not requested: %s", content)
|
||||
}
|
||||
if strings.Contains(content, "thought") && strings.Contains(content, "<channel|>") {
|
||||
t.Errorf("content contains leaked thinking block: %s", content)
|
||||
}
|
||||
|
||||
// Thinking field should ideally be empty when not requested.
|
||||
// Some small models may still produce thinking output; log but don't fail.
|
||||
if thinking != "" {
|
||||
t.Logf("WARNING: model produced thinking output when not requested (%d chars): %.100s...", len(thinking), thinking)
|
||||
}
|
||||
|
||||
t.Logf("content: %s", content)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,7 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
minVRAM := map[string]uint64{
|
||||
"gemma4": 8,
|
||||
"qwen3-vl": 16,
|
||||
"gpt-oss:20b": 16,
|
||||
"gpt-oss:120b": 70,
|
||||
|
||||
@@ -45,6 +45,7 @@ var (
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"gemma4",
|
||||
"lfm2.5-thinking",
|
||||
"ministral-3",
|
||||
"qwen3-coder:30b",
|
||||
@@ -137,6 +138,7 @@ var (
|
||||
"gemma2",
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"gemma4",
|
||||
"glm4",
|
||||
"goliath",
|
||||
"gpt-oss:20b",
|
||||
@@ -272,6 +274,7 @@ var (
|
||||
"snowflake-arctic-embed2",
|
||||
}
|
||||
libraryToolsModels = []string{
|
||||
"gemma4",
|
||||
"lfm2.5-thinking",
|
||||
"qwen3-vl",
|
||||
"gpt-oss:20b",
|
||||
|
||||
@@ -5,23 +5,26 @@ package integration
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Default set of vision models to test. When OLLAMA_TEST_MODEL is set,
|
||||
// only that model is tested (with a capability check for vision).
|
||||
var defaultVisionModels = []string{
|
||||
"gemma4",
|
||||
"gemma3",
|
||||
"llama3.2-vision",
|
||||
"qwen2.5vl",
|
||||
"qwen3-vl:8b",
|
||||
}
|
||||
|
||||
// decodeTestImages returns the two test images (Abbey Road llamas, docs llamas).
|
||||
func decodeTestImages(t *testing.T) (abbeyRoad, docs api.ImageData) {
|
||||
// decodeTestImages returns the test images.
|
||||
func decodeTestImages(t *testing.T) (abbeyRoad, docs, ollamaHome api.ImageData) {
|
||||
t.Helper()
|
||||
var err error
|
||||
abbeyRoad, err = base64.StdEncoding.DecodeString(imageEncoding)
|
||||
@@ -32,9 +35,35 @@ func decodeTestImages(t *testing.T) (abbeyRoad, docs api.ImageData) {
|
||||
if err != nil {
|
||||
t.Fatalf("decode docs image: %v", err)
|
||||
}
|
||||
ollamaHome, err = base64.StdEncoding.DecodeString(imageEncodingOllamaHome)
|
||||
if err != nil {
|
||||
t.Fatalf("decode ollama home image: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// skipIfNoVisionOverride skips the entire test (at parent level) when
|
||||
// OLLAMA_TEST_MODEL is set to a non-vision model. This prevents the parent
|
||||
// test from reporting PASS when all subtests are skipped.
|
||||
func skipIfNoVisionOverride(t *testing.T) {
|
||||
t.Helper()
|
||||
if testModel == "" {
|
||||
return
|
||||
}
|
||||
// Check actual model capabilities via the API rather than a hardcoded list.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: testModel})
|
||||
if err != nil {
|
||||
return // let the test proceed and fail naturally
|
||||
}
|
||||
if len(resp.Capabilities) > 0 && !slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||
t.Skipf("model override %q does not have vision capability (has %v)", testModel, resp.Capabilities)
|
||||
}
|
||||
}
|
||||
|
||||
// setupVisionModel pulls the model, preloads it, and skips if not GPU-loaded.
|
||||
func setupVisionModel(ctx context.Context, t *testing.T, client *api.Client, model string) {
|
||||
t.Helper()
|
||||
@@ -54,6 +83,7 @@ func setupVisionModel(ctx context.Context, t *testing.T, client *api.Client, mod
|
||||
// handles cached image tokens across turns.
|
||||
func TestVisionMultiTurn(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
// Models that fail on multi-turn detail questions (e.g. misidentifying objects).
|
||||
skipModels := map[string]string{
|
||||
@@ -72,7 +102,7 @@ func TestVisionMultiTurn(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
abbeyRoad, _ := decodeTestImages(t)
|
||||
abbeyRoad, _, _ := decodeTestImages(t)
|
||||
|
||||
// Turn 1: describe the image
|
||||
req := api.ChatRequest{
|
||||
@@ -100,7 +130,7 @@ func TestVisionMultiTurn(t *testing.T) {
|
||||
api.Message{Role: "user", Content: "How many animals are in the image?"},
|
||||
)
|
||||
resp2 := DoChat(ctx, t, client, req, []string{
|
||||
"four", "4",
|
||||
"four", "4", "three", "3",
|
||||
}, 60*time.Second, 30*time.Second)
|
||||
if resp2 == nil {
|
||||
t.Fatal("no response from turn 2")
|
||||
@@ -121,6 +151,7 @@ func TestVisionMultiTurn(t *testing.T) {
|
||||
// TestVisionObjectCounting asks the model to count objects in an image.
|
||||
func TestVisionObjectCounting(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
skipModels := map[string]string{
|
||||
"llama3.2-vision": "consistently miscounts (says 3 instead of 4)",
|
||||
@@ -137,7 +168,7 @@ func TestVisionObjectCounting(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
_, docs := decodeTestImages(t)
|
||||
_, docs, _ := decodeTestImages(t)
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
@@ -160,6 +191,7 @@ func TestVisionObjectCounting(t *testing.T) {
|
||||
// cultural references and scene context from an image.
|
||||
func TestVisionSceneUnderstanding(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
// Models known to be too small or not capable enough for cultural reference detection.
|
||||
skipModels := map[string]string{
|
||||
@@ -178,7 +210,7 @@ func TestVisionSceneUnderstanding(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
abbeyRoad, _ := decodeTestImages(t)
|
||||
abbeyRoad, _, _ := decodeTestImages(t)
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
@@ -193,7 +225,7 @@ func TestVisionSceneUnderstanding(t *testing.T) {
|
||||
Options: map[string]any{"temperature": 0.0, "seed": 42},
|
||||
}
|
||||
DoChat(ctx, t, client, req, []string{
|
||||
"abbey road", "beatles", "abbey",
|
||||
"abbey road", "beatles", "abbey", "llama",
|
||||
}, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
@@ -203,6 +235,7 @@ func TestVisionSceneUnderstanding(t *testing.T) {
|
||||
// objects based on their spatial position in the image.
|
||||
func TestVisionSpatialReasoning(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
for _, model := range testModels(defaultVisionModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
@@ -212,7 +245,7 @@ func TestVisionSpatialReasoning(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
_, docs := decodeTestImages(t)
|
||||
_, docs, _ := decodeTestImages(t)
|
||||
|
||||
// The docs image has: leftmost llama on laptop with glasses,
|
||||
// rightmost llama sleeping.
|
||||
@@ -239,6 +272,7 @@ func TestVisionSpatialReasoning(t *testing.T) {
|
||||
// small details like accessories in an image.
|
||||
func TestVisionDetailRecognition(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
for _, model := range testModels(defaultVisionModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
@@ -248,7 +282,7 @@ func TestVisionDetailRecognition(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
_, docs := decodeTestImages(t)
|
||||
_, docs, _ := decodeTestImages(t)
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
@@ -274,6 +308,7 @@ func TestVisionDetailRecognition(t *testing.T) {
|
||||
// encoding and cross-image reasoning.
|
||||
func TestVisionMultiImage(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
// Multi-image support varies across models.
|
||||
skipModels := map[string]string{
|
||||
@@ -291,7 +326,7 @@ func TestVisionMultiImage(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
abbeyRoad, docs := decodeTestImages(t)
|
||||
abbeyRoad, docs, _ := decodeTestImages(t)
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
@@ -314,10 +349,12 @@ func TestVisionMultiImage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestVisionOCR tests text extraction from an image. The docs image
|
||||
// contains the text "Ollama's documentation" in a header.
|
||||
func TestVisionOCR(t *testing.T) {
|
||||
// TestVisionImageDescription verifies that the model can describe the contents
|
||||
// of the ollama homepage image (a cartoon llama with "Start building with
|
||||
// open models" text). Basic sanity check that the vision pipeline works.
|
||||
func TestVisionImageDescription(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
skipIfNoVisionOverride(t)
|
||||
|
||||
for _, model := range testModels(defaultVisionModels) {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
@@ -327,22 +364,22 @@ func TestVisionOCR(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
setupVisionModel(ctx, t, client, model)
|
||||
_, docs := decodeTestImages(t)
|
||||
_, _, ollamaHome := decodeTestImages(t)
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What text appears in this image? Read all visible text.",
|
||||
Images: []api.ImageData{docs},
|
||||
Content: "Describe what you see in this image briefly.",
|
||||
Images: []api.ImageData{ollamaHome},
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{"temperature": 0.0, "seed": 42},
|
||||
}
|
||||
DoChat(ctx, t, client, req, []string{
|
||||
"ollama", "documentation",
|
||||
"llama", "animal", "build", "model", "open", "cartoon", "character",
|
||||
}, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -383,3 +383,162 @@ yEUu0pztbKtys2RR9bUiUBGoCFQE5oTAL3/5y+ab3/xmc9JJJzWf+cxnmq9+9atzKXmuDGQuNaqFVAQq
|
||||
VBGoCFQElgKBykCWoptqJSsCFYGKwOIhUBnI4vVJrVFFoCJQEVgKBCoDWYpuqpWsCFQEKgKLh0BlIIvXJ7VGFYGKQEVgKRDYOWr5q6Woaa1kRaAiUBGoCCwU
|
||||
Av8fgwPy24mbuF8AAAAASUVORK5CYII=
|
||||
`
|
||||
// imageEncodingOllamaHome is a 415x293 JPEG of the ollama.com homepage.
|
||||
// Shows a cartoon llama character with text "Start building with open models".
|
||||
const imageEncodingOllamaHome = `/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA0JCgsKCA0LCgsODg0PEyAVExISEyccHhcgLikxMC4pLSwzOko+MzZGNywtQFdBRkxO
|
||||
UlNSMj5aYVpQYEpRUk//2wBDAQ4ODhMREyYVFSZPNS01T09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09P
|
||||
T09PT09PT0//wAARCAElAZ8DASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUF
|
||||
BAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVW
|
||||
V1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi
|
||||
4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAEC
|
||||
AxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVm
|
||||
Z2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq
|
||||
8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD06iiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiq
|
||||
2o39rpllLeXsqxQRDLMf5e5oAs0V5XffEXXL6WeXQdOC2dsNzu8ZchfVuwrufCOvDxFocd80YjlDFJUHQMPT270AbdFFFABRRRQA
|
||||
UUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUU
|
||||
AFFFFABRRRQAUUUUAFeUeI7u68b+L00PT5CLG2chnHTj7zn+Q/8Ar12XjTxLZ6LpNzD9pUX8sRWGIcsCRjJ9BXmPg/xPJ4djuDa6
|
||||
V9rnnI3SliMKO3A9eaAO/wDFyWHhfwDPYWSLGJlECDu5PUn1OM1V+HuoaVovhWFb7UbWGa4kaUo0oyAeBkduBXMXc2t/EfV44obc
|
||||
W1vbLyGJKR56knHJPpXT2fwr0mOMfa7y6mkxyVwg/Ac0AdpZ6lY3wzZ3kE//AFzkDfyq1XmupfDF7YfafD2oypcJyqSnBP0YdKzU
|
||||
+IWuabp1xpV/bltUjPlpM45X13DufQ96APQtf8U6T4fTF9PmYjKwxjc5/Dt+Nc7pnxP0291GO1ns5rZJWCrKzhgCemR2qn4V8Atd
|
||||
v/a/ikvNPMd4t3Y9+7n19qzfHsVrfeLNM0PSoIkaHCMIkAwWI449AM/jQB61RSKNqgegxS0AFFFFABRRRQAUUUUAFFFFABRRRQAU
|
||||
UUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFBIAyTgCsifxBbCYw2UU17KvUQLkD8e
|
||||
lAGvRSIxZFYqVJGSD2paACub8b+JB4c0fzItrXk5KQKegPdj7D/Cukry7xWn9s/FLT9LnOYItgK9iMbz+fSgCfwh4I/tEDW/Exe4
|
||||
luD5iQyE8g/xP/hXosFtb20Qit4I4kHRUUAD8qkAAAAGBXC6d4pvtL8XXGieIZ45Yp5M2064AXJ4U47duehoA7pUVc7VAzycDrS1
|
||||
FPc29uM3E8cQ9XcL/OiG5t7gZt54pR/sOG/lQBLXn/xH8OXt3dWesaNbNLdQnEojGWOOVOO+OlegUUAeXr8ULuCxuLfUNMMeoouI
|
||||
yMhd3+0p5HrV34ceHbgzSeJNWDNc3GTCH64PVz9e3tW94z8LW3iDTJGWNVv4lJhlA5J/un1BrK+F2tzXumTaXeMTPYkBN3XYeMfg
|
||||
ePyoA7qiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKK
|
||||
KACiimyIJInjJIDKRkdRmgDAmeTXrmWJZTDpduSJXBwZiOoz6ClsLq4uGWLQ7OGGwjbBmkBG/wBdoHX61dk0dP7DOl20rRJtC78Z
|
||||
JGcnP1qle+IbDSsWNvG0hiXZ8mAF9s+tAHK6/wCKdc1nxDJofhTKCElXmXGWI6nJ+6oPFVk8ReKvCGoxReJQ13Zyn7+Q3Hcqw7j0
|
||||
NSfCJkbUdXZv9aQpyeuMnP613+u6Pa65pcthdr8rjKt3RuzCgC1Z3UF7aRXVrIJIZVDIw7g15p468zQvH2na8ULQPtLY9V4Yf98k
|
||||
VF4X1u58FaxLoGvZW0LZSTshP8Q/2T+n51FbxXHxF8XySTM6aTaHgA9FzwB/tNjJNAHpd9Aut6G8VpevCl1GCk8J5APORXC63oWk
|
||||
eCdIW/hha+1SSQJBLcfMFfruC9OP8K9Gt4Ira3jggRY4o1Coq9FA6CuT+Jem3F74fS6tFLS2Mon2gZyuOfy60AYkfhCxAt7rxnqs
|
||||
0t/ek7IzLtXdjO3d/wDqFY+k6XpOrXwttEfUtK1VC/RvMiUr0y4wRmug8RahB4s8F295ZR/aHtpo5Lq3UZkUDhgO/wCPpVXQtK/t
|
||||
PxBJP4b/ALQ0jRgi+a24r5zjoFB/Xr+tAGx4V8SajHqz+HPEqhb9B+5m7TD+vHQ96u6z480PSpXt/Oe6uUO0xQLuwfQnpWN43eKb
|
||||
xx4dgsyDfRygvt6qm4EZ/JjS6h4avtE8XQa1oNot1b3MmLi3IB8sk8kE9B3z2+lAHdWdwt3Zw3Ko6LKgcK4wwyM4I9a828G4i+KO
|
||||
sxQcRHzuB0++K7/XdWg0XSJ7+5YARr8o7u3YD8a4j4U2E0smoa7cg7rhiiEj73OWP54FAHo9FFFABRRRQAUUUUAFFFFABRRRQAUU
|
||||
UUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFADJiwhcp94KcfXFcl4NgguJLuS4RZ
|
||||
JRj74zgHOf1rsK4u4EnhvxCZ1Um0nJOB3U9R9QeaAOdsivhL4pSwP+7s70kL2AV+V/JuK9WriviBoS+INCj1LTsSXNqpdCv/AC0T
|
||||
uPqOtWPh94mXXdJFvcv/AKdagLICeXXs3+PvQA34mWFnP4VuLueBWuLfHkydCuWA/L2pPhdaR2/hCKZQN9xI7ufXBwP0Fb/iHTf7
|
||||
Y0K80/IDTRkKT2bqP1ArB+G9lq2m6LNZarbGBYpj5O48kHr+Gen1oA6+ggEYNFFAHF6t4Bie+OoeH76TS7snJEedhP0HT+VVjonj
|
||||
9h5B8QWwj6eYBhv/AEHNd7RQBzHhfwdb6HO99c3D3uoyZ3Tyds9cf4109FFAHlfxYiu11ewmupXfTGGFjU42sD834kdDXpWlwWtv
|
||||
pltFYIEtljXygP7uMiuT+LMaN4TV2A3JcptP1Brd8Hu0nhLS2fkm2T+VAGzRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUU
|
||||
AFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAVWv7GDULVre4XKnoR1U+oqzRQBxSSaj4WuSki+dZ
|
||||
O3H90/T0PtXK69GNF1mPxP4bfEJfM8GMGJj1BH91v5/hXrsscc0bRyorowwVYZBrl9W8HxTK5sGChgQ0Mn3SPQHtQBpaV4k07UtB
|
||||
OrrMscEa5mDHmIjqD/nmrOjavY63Yi806XzItxU5UggjsRXiOv6VqXh2WSzJmitrz+DPD4OcH1we9e0eGNKi0bQLSyjXDKgaQ/3n
|
||||
PJNAGrRRRQAUUUUAFFFFAHmvxW1H7XNY6BaZknaQSOq9ieFH6k13+lWY0/SrWzHSCJY/yFcJ8U9GEKQeIrMmO5idUlZT1/ut9QeP
|
||||
yrtfD2o/2toNlfnAaeIFsf3uh/UGgDRooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA
|
||||
KKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigDzD4rZGuaKW/1fP/oQzXpy4IBHTtXC/FnTHutBgvolJazky+OytwT+YFdH
|
||||
4T1P+1/DdleEEM0e18/3l4P8qANiiiigAooooArajfW+m2E17dvshhUsx/w968zOveL/ABldSJoKNZ2SHG5Ttx/vP6+wrqfibb3F
|
||||
x4On+zgt5ciSSAd0B5/ofwqD4d65pEvh610+KaKC6hXbJE5Clmzyw9c0Ac3efD3xPPaO0+sJcSEZ8lpnIb8TxWr8M9ddQ/hq/i8m
|
||||
4tN3l5GCQD8yn3BP5V6CzKqFmYBQMkk8AV5Xpk0er/F57zTObeMszyL0YBNpP4mgD1WiiigAooooAKKKKACiiigAooooAKKKKACi
|
||||
iigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKAGyIkiMkiK6MMFWGQR9KSKK
|
||||
OGJYoY1jjUYVVGAPoKfRQAUUUUAFFFFACMoZSrAEEYIPeuJ1r4aaTfyvPYSyWMrHO1BuTP07fga7eigDy9vhtrxUwHX1NueCpaTG
|
||||
P93pXaeF/DFj4aszFbZknkwZZmHL+3sPatyigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiii
|
||||
gAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAoo
|
||||
ooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAK
|
||||
KKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKA
|
||||
CiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiii
|
||||
gAooooAKZPNHbwSTTOEjjUszHoAOSafWR4tt5rrwrqcNsC0rW7bQOp4zj8qAPPdQ8a+IvEeqNY+GIpIoudvlqPMYf3mY8KPy+tJN
|
||||
p/xIsIzdfabuXb8xVbhZSP8AgPf8Ki+F2vaZpVxeW2oSJbtc7DHM/C8Z+Unt1zXrcM0U8YkgkSRD0ZGBH5igDiPAvjmTWbn+zNWV
|
||||
EvcExyKNolx1BHZu/wCddJ4n16Pw7pQv5YHmXzFj2owB5zzz9Kz/APhB9LXxH/bkU11Hced52xWUJu78Yzg/XvXH/EvUtdke5sJ7
|
||||
DZpUcyGK48phuO3+9nB5J/KgD0Lw3rcfiDSU1CKF4VZ2XYxBPB9q1a8j8Cax4lt4LKysdM83THuQHn8hmwCw3fMDjiu+8W+J7bwz
|
||||
p6zSoZriUkQwg43EdST2AoA3qK8qi8W+OtRiN5Y6Xm2PK+XallI9iTk/hWz4P8ftq2oLper26W92xKo6AhWYfwkHoaAO8orkvH/i
|
||||
W+8N2tnLYJAzTOyt5qk8AA8YIrnLjx9r1/b28OhWHnzrCrXMkUDSAOR90DsB70Aej6neLp+mXV66F1t4mlKg4JwM4rF8JeLYPFBu
|
||||
hDaSQfZtmd7A53Z9PpWB4w1jxGmgW0cOn+ZDdaduvZDA37pivzd/lxz1ri/Bmq+INMN5/YGn/a/M2eb+5aTbjOOh46mgD3aq2oyP
|
||||
Fpt1JG210hdlI7EKafZPLLYwSXCbJnjVpFxjaxHIx9ai1b/kE3n/AFwf/wBBNAHC/DLXtV1i/vo9SvZLhY4lZQwHBJ9hXoteH+Bv
|
||||
ENv4cGpXUqGWZ4kSGEHBds/yrYu/Gfje2T7dPpYgtTyN9owQD3JOaAPWKK5nwZ4vg8TW8iNGIL2EZkiByCP7y+38qxvGXjPVND8S
|
||||
RafZx2zQvGjEyIS2SSD0I9KAO/oorgPHPjPVPD+vQ2VjHbNE8CyEyoSclmHYj0oA3/G76zH4ekbQBIbneu7yhlwnOdvv0/Wl8Evr
|
||||
L+HYm14SC63tt8wYcp23D16/pSeNdautC8Ptf2SxNKJEXEikjB+hFL4K1m617w+l9erEsrSOpEYIGAfcmgDkPAviLWNS8YTWd9fy
|
||||
TW6pKQjAYBBGOgr06vHfht/yPtx/1zm/9CFeg+MPFVt4ZskZk866mz5UOcZx1JPYfzoA6GuA+KGuapoz6aNMvHt/NEm/aB82NuOo
|
||||
9zWNaeM/G18DeWemCa2B6R2jMh9s5z+tZXjnxJD4jstLlEZguYDKk8JP3T8uCPY4P5UAenm41SfwNHc2DeZqUlijoxAyzlQSfTPX
|
||||
8azfh5L4jktLv/hIRcbA6+QbhcP33e+OlXY72XTfh5Be24UywaajqHGRkIOtU/h/4mv/ABJb3sl+kCmB0C+UpHUHrkn0oA6+ivNr
|
||||
r4hXlh4vubG9S3Gn280iMyxkyEAHAHOMk4FVb3xn4ynja+stHe3sfvK32Zn+X1LH+YGKAPU6K4nwN45OvznT9QijivApZGjyFkA6
|
||||
8Hoe9b3ifxDa+HNMN3cgu7HbFEpwXb+g9TQBsUV5Nb+N/GWqyvNpenK8KHlYrZnA9i3rVu++Jd5HpC7LSK21WKcJPBMjEbcH5gMg
|
||||
jkDg9KAPTqKyfCupT6v4cs9QuggmnUlggwvDEcflWpLv8pvKID4O0sOM9qAHUVwngjxlqOta5c6bqsVvG8cZZfKUg7lYAg5J9f0r
|
||||
b8ba9L4e0Bry2EbXDyLHGJBkZPJ4HsDQB0FFcl4A8U3PiS1u/tywrcW7rxECAVI44JPcGmfEDxXdeG47JLBYXmnLFhKpICjHoR3P
|
||||
6UAXvHL61H4eZtAEhuPMXf5Qy4TnO33zj8Kk8FvrD+HYW14OLrc2PMGH2dtw9f8A61Z+v+IdU0jwTaaqUt/t0vl+YrIdg3AkjGf6
|
||||
1DZeJNbv/AX9r2dpFPqJlKrFHEzAgNg8Zz096ALEfji3fxWdA+wyiQTmHzd4xkd8V1MzFYHZTghSR+VeBR3+rr4yN8lnnVftBf7P
|
||||
5Z+/zkbc5r17wxqGsajodzNrtn9kuFdlVPLKZXaOcE+pNAHL/DPxDq+r61dQ6lfSXEaW5dVYDg7gM8D3r0qvC/Auuw+H7y+vJUMs
|
||||
jW/lwxL1kcuuBW5f+M/G1ov2y50wW1sTwHtWCj0BJOaAPWKK5zwb4rh8TWTsYxDdwECWIHI56MPb+VdHQAUUUUAFFFZHiqbUrbw7
|
||||
d3Gjvtu4VDr8gbIB+YYPtmgDC8QfDjS9Vne5s5XsZ3JLbFDIx9dvb8DXJ3Hw88TaUxn0q7SUryPIlMT/AJHH863vBPj+O8SS18Q3
|
||||
kcdzvzFM4CIy+mRwCD6+tdnPrWlW8Jmm1G0SMDO4zL/jQB554O8canDrEejeIC0m+TyVkkXEkb5wA3qM8c81vfFf/kUB/wBfKfyN
|
||||
cNdTJ4o+JUc2lxnypbiMhsYJVAMv7cKTXc/FZSfB+QOlyhP60ASfC3/kTYf+u0n86t+L9P8ADMqRX3iVgqxgpGTKy574CqeTVD4V
|
||||
3EL+ElhWRDJHM4dc8jJyOK5D4nSNJ41hhvXdbVI49uOyE/MR79fyoA6pviZ4bto1ighvHRAFUJEAAB0AyRXCXGq22rfEW21KwieG
|
||||
Oa8gIVwAc5UE8epFer2ln4XstNWa3h0xLVVyJSEII9Sx615Tf6ha6n8R4LqxQLbNeQrHhdoIUqM498ZoA634xf8AIP0z/rq/8hXS
|
||||
eAbWG18Haf5KBTLH5rkdWYnqf5fhXN/GL/kH6Z/12f8AkK6rwV/yJ+lf9e60AT+Kv+RV1b/r0l/9BNcL8Gvvav8ASH/2eu78UKW8
|
||||
L6qqjJNpLx/wE15/8HbiGO51OB5FWSRY2RScFgN2cfmKAPVKqat/yCbz/rg//oJq3VTVv+QTef8AXB//AEE0AeS/CrTYL3xHJc3C
|
||||
BxaRb0BGfnJwD+HNexyIkkbRyKGRgQysMgj0NeH/AA912HQtfL3h22twnlSPjhDnKk+3H617Be6/pNlYteT6hb+SFyCsgYt7ADqa
|
||||
APK9Aj/sP4qfY7c4hFy8AH+wwOB/L8qk+J3/ACPFt/1wi/8AQjTPBizeIPiK+qFCsaSPcv8A7IOQo/Mj8jUvxYikg8U2t1j5Ht12
|
||||
ntlWOR+o/OgD2CvHvix/yN1r/wBeqf8AobV6bp/iDStQ0+O9hvrcRsoZg0gBT1DA9CK8f8eazb634qM9m2+CFFhR+z4JJI9sk0Ae
|
||||
hfFL/kTW/wCu8f8AWl+Fv/InRf8AXaT+dJ8Uv+RNb/rvH/Wl+Fv/ACJ0X/XaT+dAHHfDf/kfbj/rnN/6EKZ41B1f4mJp8jHy/Mht
|
||||
h7KcE/qxp/w2/wCR9uP+uc3/AKEKPiJBPo/jqHVkQlZTHPGexZMAj9B+dAHr0EMVvAkECLHFGoVFUYCgdBXkvxb02C11m1vYUCNd
|
||||
xt5mB1ZSOfrgj8q9K03xDpOpWCXdvfQBCuWDyBWT2YHpXk/xJ1+31vWYksW8y1tEKCUdHYnLEe3QUAegXv8AySs/9gpf/RYrD+Dn
|
||||
/Hnqn/XSP+TVuXv/ACSs/wDYKX/0WKw/g5/x56p/10j/AJNQBzwtIb74tSW1wgeJr9yynocZOD+Ve0dBXj1j/wAlkb/r+l/k1exd
|
||||
qAPGtPhSy+LoitwERb1wqjgAEHj9am+LNxJceJbSzB+SK3UqP9pmOT+gpsf/ACWM/wDX8f8A0Grfxd02WPULLVUU+W8fksw/hYEk
|
||||
fmCfyoA9M0uwg0vToLK2QLFCgUYHX1P1PWvPfjBp0CxWOpogWZnMLsP4hjIz9MH866vw74t0vV9Lime9ghuAg86KSQKVbv16j3rg
|
||||
vif4jtdWmt7DTpBNBasWklTlS54AB74Gfz9qAO7+H3/IkaZ/uN/6G1dHXOfD7/kSNM/3G/8AQ2ro6APJbhf+Ef8Ai/G4+WG5nDex
|
||||
Eowf/HifyrQ+JrvqfiDRtBhPLsGYD1Zto/IA0nxds2ifTdWh4ZGMLN6H7y/+zVF4YnHif4lz6wATBbQ7kyOh2hQPzLGgA8Mxr4d+
|
||||
KN7pSjZb3SsIl7YI3r+mRVbxoDrvxKs9KBykflxMPQH52P5H9Kv/ABJjOl+JdF1+MYCuFkI/2Gz+oJH4VW8BL/bfj7VNbOTHGXdC
|
||||
R0LnC/8AjoNAG/8AFUAeD8AYH2iP+tTfC/8A5Eu3/wCusn/oVRfFb/kUP+3mP+tSfC//AJEu3/66yf8AoVAHFW//ACWI/wDX+/8A
|
||||
I17Bcf8AHtJ/uH+VeO+bHa/F1pLh1jQX5yzHAGen8xXsMxDWshBBGw9PpQB498KLKG68UvLMgY21u0keR0bIGfyJr2G5t4ru2ltr
|
||||
hA8UqlHU9CCMGvJvg/8A8jDef9eh/wDQ1r16gDx34WlrfxlcQKx2mCRT74Yf4V7FXjvw2/5Hy4/65Tf+hCvYqACiiigAooooA4zX
|
||||
Phxo+qXD3Ns8ljM5y3lAFCfXaen4EVjx/CWMPmXWXZPRbcA/nur0uigDF8O+FtL8OxsLCJjM4w80hy7D09h7Cr+qadbatp01jepv
|
||||
gmXDAHBHcEe4PNW6KAPPrH4YQ2Or217Dq0hSCZZRG0IydpBxkH29K6XxN4W07xJAi3geOaP/AFc0f3l9vce1blFAHndp8KLGO4D3
|
||||
epTTxA58tYwmfqcmtS68AWE2vQanDcSQLA0RSBEG0BMYHr2rsKKAMDxX4Xh8TQW8U9zJAIGLAooOcjHetPSNPTStKtrCORpFt4wg
|
||||
ZhgnFXKKAEdFdGR1DKwwQehFee3/AMKrKa5aSx1KW2jY5EbRiTb7A5HFeh0UAQ2cH2Wygtg2/wAmNU3YxnAxmluoRc2ssBYqJUZC
|
||||
R2yMVLRQBxulfDvTLG3vLe4nlu4rpFUh1ClCDkMCOhrKb4TWpnJXV5hD2Uwgt+ecfpXo9FAGXoGg6f4fsvs2nxEBjl5GOXc+pNJ4
|
||||
h8P2HiKx+y36N8p3RyIcMh9R/hWrRQB5xH8JrUTgy6vM0WeVWEK355P8q0tU+HGmX0tsbe4ltI7eERKiKDnBJ3EnqSTXa0UAZPiT
|
||||
Q4vEGknT5p3hUur7kAJ4+tL4b0SPw/pK2EMzzKrs+5wAefpWrRQBynh/wRbaHrb6nFezSu6suxlAHzHPatrW9EsNdsTaajFvTOVZ
|
||||
ThkPqD2rRooA83/4VNa+fkavN5Ofu+SN355x+la2pfDrSruws7O2mltY7XecqAzSM2Mlie/y12VFAGZNo0cvhr+xDM4j+zC38zA3
|
||||
YC4zjpVLwn4Wg8MRXMcFzJOLhlJ3qBjGfT610FFAHKQ+CLaHxYdfF7MZTM0vlFRtyQeM9e9dX2oooA5RfBFsviv+3/ts3m+cZfK2
|
||||
jbnGMZ61Y8Ya1o2m2iWmu2001vdqQAse5TjHfIweQa6OqOr6TY61YtZ6hCJImOR2Kn1B7GgDgrP4feG9ZiW90vVbg20nzbAVYp7H
|
||||
IyD9ayvH0Gh6NpVpoejlWmWbzp2Dbm4Ugbj689O1a1z8J4/NJstYkjQ/wyQ7jj6gjP5VpaF8NNL025S5vp3v5EOVVlCJn1I5z+Jx
|
||||
QBueCbaSz8IaZDMpVxDuIPUbiW/rW5RRQBl+I9Eg8QaS+n3EjRhmVw6gEqQff8R+NU/CfhS28MR3K288k7XBUszqAQBnA4+proKK
|
||||
AMfxP4ft/EemCyuJGiCyCRXQAkEZHf2JqLwp4YtvDNrPDbzPMZnDs7gA8DAHH4/nW7RQBkeJtCi8RaV9gmneFfMV9yAE8Z9frT/D
|
||||
mix6BpCafDM8yIzNucAHk57VqUUAcj4o8BWHiC9N8tw9pdMAHZVDK+OASOOfxq/4V8NDw7pE1h9rNz5shkL7NuMqBjGT6Vv0UAct
|
||||
4V8FW3hq/lu4LyadpYvLKuoAHIOePpXU0UUAcp4f8EW2ha0+pxXs0rurLsZQB8xz2rq6KKACiiigAooooAKKKKACiiigAooooAKK
|
||||
hvLuCxtJbq6kEcMKl3Y9gK80uviHrmrX7W/hrTMoM4JjMkhHqQOFoA9Rory23+IWvaRfLB4l0zCN1xGY3A9Rng/55r0qyvYL+xiv
|
||||
LSQSQzJvRh3FAFiivLdG+J10Zbp9Yit/KihLRJCpVpJNwAXJJ4wSfwqK78ZeNljN+NK8iz6jNqxUD1JPP48UAer0VyXgrxrD4kD2
|
||||
1xEtvfRruKKcrIvqv+Fa/iTxBaeHdNN3d5dmO2KJTzI3p7D1NAGtRXlUPjDxtq+660rSx9mB48u3Lj6bieT9K1/CvxAe+1FdK122
|
||||
W1u2bYjgFQW/usp5U0Ad9RXN+Otdu/D2hpe2KxNI06xkSqSMEE9iPSuWi+Ier32m29vpenLd6q4ZpvLiYpENxAwM8nGD1xQB6bRX
|
||||
kqfELxLpOoLHrlguw8tE8JifHqp//XXqWn3sGpWEF7atuhnQOh9j6+9AFiiivOdT8f3um+MptOuFtl0+GYK7eWS+3GT36/hQB6NR
|
||||
XleoeNfF8kbahaaS1tp33lZrdnG31LH+YwK0rD4mwSaBLPdWw/tKNhGlvGTiUnoR3A4569vWgD0KivJr7xr41sdt5eaatvbMeBJa
|
||||
sF+mSc/rXeeEfEsPiXSzcJH5U8TbJos52nsQfQ0AbtFcn4y8bW/hsrawRC5vnXdsJwqD1Y/0rlB4t8dm3+3jS/8ARsbs/ZG249eu
|
||||
ce9AHq9Fcj4N8cW/iNzaXEQtr5VLBAcrIB1K+/tWz4j1608PaW17d5Y52xxqeZG9P/r0Aatcl8SdUvtJ8PQ3GnXLQStcqhZQORtY
|
||||
45+grkoPHPjDVpnk0nTleJDysVuZAPYt6/lVfxZ4sOu+GBY39sbTU7a6QyREEBhtbkA8jqOD60AeheBL661LwlZ3d9M008hfc7Yy
|
||||
cOQOnsK6CuW+Gv8AyI9h9ZP/AEY1Y3iX4hTQ6k2leHLVbq4VijSFS4Ldwqjr9aAPQqD0ryuTxn4z0YpPrGlg27HB8yAp+G4dD9a9
|
||||
B8P65aeINLS+syQD8ro33o27g0AcKviLWD8T/wCzDfyfYvtZTycDG3HTpmvTR0rx5P8Aksf/AG/H/wBBr0Lxb4nt/DOnJNJGZriY
|
||||
lYYgcbiOpJ7AUAb9FeXQ+JvH2owfbrLS0+zNyuy3yGHtk5P4VueDfHX9t3h0zU7dba/AO3bkK5HUYPII9PrQB2tFFFABRRRQAUUU
|
||||
UAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAcH8XLuSHw9bWqHAuLj5/cKM4/PH5Vo/DbTobLwjbSoo826zLI2OTyQB+AH86rfFLTJ
|
||||
b7wwLiBSzWcolYAZ+TBBP4cH8Kr/AA08SWU2gxaVcXEcV1a5VVdgvmITkEZ64zjHtQBqfEXT4b7whdvIgMlsPOjbHKkHn8xmsf4R
|
||||
3ckug3tq7ZW3mynsGHT8wfzqf4k+JLK30CbTILiOW7ugE2IwbYucknHTpj8ad8LNMlsvDMl1MpVryTegIx8gGAfx5oA4f4babDqP
|
||||
i5PtCB0to2n2sMgkEAfqc/hXtxAKkEZB6g14N4J1qPQvE0d3cZ+zuGimYDO1T3/AgV7TLrukw2RvH1K1Fvt3bxKCCPbHX6UAeUPC
|
||||
nh/4sRw2g8uIXiBVHQLIBkfTDEVN8VrszeKoLWVm8i3hXhevzHJP1xj8qh0d38VfE5b6KNhELgXByPuxpjbn8lH41ofFaxmtNfs9
|
||||
XRN0UiKpJGQHQ5wfqP5GgDWtvib4ftLaO2t9PvkhiUKihEwAP+BVxnjbX9O17VLfUNMgngmVNsrSBQWIPyngnn/61esaNfeH9YsI
|
||||
7q1Sy+ZQXQqgaM9wRWDrnjDQtN1NLCy0uDUZW4byQmAxOAucHJ+lAEPxHuDd/D/T7liCZpYZDj3jY1f+FlpDB4RjuEQCW5kdpG7n
|
||||
BKgfp+tV/iqMeDrcbBHi5j+QdF+VuK0Phn/yJFl/vSf+hmgDP+LdtFJ4ZhuGUeZFcqFbuAQcj9B+VXvhi7N4JtAxzteQD6bzVf4r
|
||||
/wDIoD/r5T+Rqb4X/wDIlW3/AF0k/wDQjQB1x6V4xqVtFd/FtredQ8T3qBlPQjA4r2c9K8euP+SyD/r+T+QoA9fZVZCjKCpGCCOC
|
||||
K8a+H9pA3xAZGjBW381owecEHA/LNez9q8f+Hv8AyUO5/wB2f/0KgD07xHDHceHNSilUMhtZOD6hSQfzFeffBtj9p1Vc8FIjj8Wr
|
||||
0XXf+QDqP/XrL/6Ca85+Dn/H3qv/AFzi/m1AGZoUS+IPifJJegSR+fJKVIyCEztH04X8q9nrxdpG8HfEt57pW+zGZmzjrFJnkeuM
|
||||
/pXrQ1jTDZ/bBqFr9n27vM81cYoA8m8TRJoHxLjnsgI1MsU4UDgbvvD6Hn86ufGC6d9asbPPyRW5kA92Yj+SiqU0p8ZfEmN7NWNt
|
||||
5qYbHSJMZY+mefzFavxg0+QXVjqaqTGUMDn+6QSw/PJ/KgD0XRNOh0rSLayt0CpFGAcfxHHJPuTXC/F/TYfsNnqioBMJfIdh1ZSC
|
||||
Rn6YP510/hfxRp2saPBIbqGO5RAs0TuFZWA5OD2PXNcR8UvEVpqIt9LsJVnSB/MmkQ5UNjAUHv1NAG54au3sPhI11EcSRQTlD6He
|
||||
2P1rivAviLSvDlzc3WoW1xNcSKEiaNVOwfxdSOvH5V3XhCyOpfCxbEHDTwzop9CXbH61y3w41Gw03UrzS9bjhiaVgEadR8jrkFST
|
||||
0z/SgDb1D4k+H9QsJ7O4sL5op0KMCid/+BVl/B+6ddWv7Pd8kkAlx7qwGf8Ax6u+1S78P6VYvd3a2Soq5UBEJc+gHc1n+DPENt4g
|
||||
e5ktNHFmkICmUFfmJ/h4A9M/lQBxKf8AJY/+34/+g1vfFfRb2+trPULSJ5ktgyyooyVBwQ2PTjn8KwU/5LH/ANvx/wDQa7fxZ4yH
|
||||
hi6t4ZdOknSdCyyLIFGQcEdPp+dAHP8Ah74mWEGn21nqlpNE8Max+ZCAykAYzjgj9a3tItvCOt6yda0xklv1bzWIkdWBxjJQ/wCF
|
||||
XTpvhnxHZLetZ2dxHKu7zQArD6kYINeX20EOm/Ey3t/D87SwJdoisrbvlON657gfMPwoA9vooooAKKKKACiiigAooooAKKKKACii
|
||||
igAooooAKKKKACiiigBGVXUqwBUjBBHBrhtX+GGlXtw01jcS2Jc5MaqHQH2HBH513VFAHB6T8L9Ks7hZr+5lvdpBEZUIhPuOSfzr
|
||||
ugqpFsRQqqMAAYAFOoIyCKAPFfhpY22p67f2V7EJYJbJwyn/AH059j710k3wntGuC0GqzRw54RogzD8cj+VbPhXwPD4b1SS+jv5L
|
||||
gyRGLa0YXGSDnr7V1tAGP4c8Nad4ctmisUYySY8yaQ5d/wDAe1XtS0601Wyks7+FZoJByp/mD2PvVqigDzm6+E9m8pa01WaKMnhZ
|
||||
Ig5H45FbnhvwHpWg3C3eXu7tfuySgAJ7qo6H35rqqKAMfxRoEXiPTEsZp3hVZRJuQAngEY5+tS+HtHj0HR4dOileVIixDsACcknt
|
||||
9a06KAMjxNoUXiLSxYTTvCvmCTcgBPGfX60/w5osegaRHp8MzzIjMwZwAeTntWpRQAVykngi2fxYNfN7MJfOEvlbRtyB0z1rq6KA
|
||||
CuV0LwTbaLrsmqxXs0ruHBRlAHzHPauqooAhvbcXdjcWrMVE0bRlh1GRjP61geE/CFv4YluZILuWc3CqCHUDGM+n1rpaKAMjxD4b
|
||||
03xFbrFqER3p/q5UOHT6H09jXG/8Klt/Nz/bEvl/3fIGfzz/AEr0migDH8O+GdM8OwNHYRkyP/rJpDl3/HsPYVf1GwtdTspLO+hW
|
||||
WCUYZT/Meh96s0UAec3Hwns2nLW2qzRxE/ceIOQPrkfyrUl+HWlf2ENMglmiJlWWSfAZ5CAQAewHzHgV2VFAGb4f0iPQtGh02KVp
|
||||
Uh3YdgATlie31rI8SeBdK1+c3TF7W7b70sQGH/3gev1rqaKAPObb4T2iTBrnVZpY8/dSIISPqSf5V3emabZ6TYpZ2EKwwp0A7n1J
|
||||
7n3q3RQByg8EWw8V/wBv/bZvN87zfK2jbnHTPWtrW9EsNdsTaajFvTOVYHDIfUHtWjRQB5vJ8J4PMPk6zMkZ/haEE4+oI/lXSeGf
|
||||
Bel+HXM8Iee7Ix50uMqO4UDgfzrpKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiig
|
||||
AooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooo
|
||||
oAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKK
|
||||
KKACiiigAooooAKKKKACiiigAooooAKKKKAP/9k=`
|
||||
|
||||
121
llama/patches/0035-CUDA-get_rows-q6_k-support.patch
Normal file
121
llama/patches/0035-CUDA-get_rows-q6_k-support.patch
Normal file
@@ -0,0 +1,121 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Hiltgen <daniel@ollama.com>
|
||||
Date: Fri, 20 Mar 2026 18:50:38 -0700
|
||||
Subject: [PATCH] CUDA get_rows q6_k support
|
||||
|
||||
---
|
||||
ggml/src/ggml-cuda/getrows.cu | 80 ++++++++++++++++++++++++++++++++-
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 1 +
|
||||
2 files changed, 80 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
|
||||
index 2fab33243..dc5c4f57a 100644
|
||||
--- a/ggml/src/ggml-cuda/getrows.cu
|
||||
+++ b/ggml/src/ggml-cuda/getrows.cu
|
||||
@@ -155,6 +155,81 @@ static void get_rows_cuda_float(
|
||||
s10, s11, s12/*, s13*/);
|
||||
}
|
||||
|
||||
+// Specialized GET_ROWS kernel for Q6_K — the k_get_rows template doesn't work for K-quants
|
||||
+// because they lack the simple dequantize_kernel_t (float2) interface.
|
||||
+// Based on dequantize_block_q6_K from convert.cu with row-selection logic added.
|
||||
+template<typename dst_t>
|
||||
+static __global__ void k_get_rows_q6_K(
|
||||
+ const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
+ const int64_t ne00,
|
||||
+ const int64_t ne11, const int64_t ne12,
|
||||
+ const size_t s1, const size_t s2, const size_t s3,
|
||||
+ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
+ const size_t s10, const size_t s11, const size_t s12) {
|
||||
+
|
||||
+ const int64_t i10 = blockIdx.x; // row index into src1
|
||||
+ const int64_t z = blockIdx.z;
|
||||
+ const int64_t i11 = z / ne12;
|
||||
+ const int64_t i12 = z % ne12;
|
||||
+
|
||||
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
+
|
||||
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||
+ const char * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
+
|
||||
+ const int64_t nb = ne00 / QK_K; // number of Q6_K blocks per row
|
||||
+
|
||||
+ // blockIdx.y iterates over Q6_K blocks within the row
|
||||
+ for (int64_t iblk = blockIdx.y; iblk < nb; iblk += gridDim.y) {
|
||||
+ const block_q6_K * x = (const block_q6_K *)src0_row + iblk;
|
||||
+
|
||||
+ // Same dequantization as dequantize_block_q6_K (assumes 64 threads)
|
||||
+ const int64_t tid = threadIdx.x;
|
||||
+ const int64_t ip = tid / 32; // 0 or 1
|
||||
+ const int64_t il = tid - 32*ip; // 0..31
|
||||
+ const int64_t is = 8*ip + il/16;
|
||||
+
|
||||
+ const int64_t y_offset = iblk * QK_K + 128*ip + il;
|
||||
+
|
||||
+ const float d = x->d;
|
||||
+ const uint8_t * ql = x->ql + 64*ip + il;
|
||||
+ const uint8_t qh = x->qh[32*ip + il];
|
||||
+ const int8_t * sc = x->scales + is;
|
||||
+
|
||||
+ if (y_offset + 0 < ne00) dst_row[y_offset + 0] = ggml_cuda_cast<dst_t>(d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32));
|
||||
+ if (y_offset + 32 < ne00) dst_row[y_offset + 32] = ggml_cuda_cast<dst_t>(d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32));
|
||||
+ if (y_offset + 64 < ne00) dst_row[y_offset + 64] = ggml_cuda_cast<dst_t>(d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32));
|
||||
+ if (y_offset + 96 < ne00) dst_row[y_offset + 96] = ggml_cuda_cast<dst_t>(d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32));
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+template<typename dst_t>
|
||||
+static void get_rows_cuda_q6_K(
|
||||
+ const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
|
||||
+ const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
|
||||
+ const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
+ cudaStream_t stream) {
|
||||
+ const int64_t nb_blocks = ne00 / QK_K;
|
||||
+ const dim3 block_dims(64, 1, 1);
|
||||
+ const dim3 block_nums(ne10, MIN(nb_blocks, (int64_t)UINT16_MAX), MIN(ne11*ne12, (int64_t)UINT16_MAX));
|
||||
+
|
||||
+ const size_t s1 = nb1 / sizeof(dst_t);
|
||||
+ const size_t s2 = nb2 / sizeof(dst_t);
|
||||
+ const size_t s3 = nb3 / sizeof(dst_t);
|
||||
+
|
||||
+ const size_t s10 = nb10 / sizeof(int32_t);
|
||||
+ const size_t s11 = nb11 / sizeof(int32_t);
|
||||
+ const size_t s12 = nb12 / sizeof(int32_t);
|
||||
+
|
||||
+ k_get_rows_q6_K<<<block_nums, block_dims, 0, stream>>>(
|
||||
+ src0_d, src1_d, dst_d,
|
||||
+ ne00, ne11, ne12,
|
||||
+ s1, s2, s3,
|
||||
+ nb01, nb02, nb03,
|
||||
+ s10, s11, s12);
|
||||
+}
|
||||
+
|
||||
template <typename dst_t>
|
||||
static void ggml_cuda_get_rows_switch_src0_type(
|
||||
const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
|
||||
@@ -199,8 +274,11 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
||||
get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
+ case GGML_TYPE_Q6_K:
|
||||
+ get_rows_cuda_q6_K(src0_d, src1_d, dst_d,
|
||||
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
+ break;
|
||||
default:
|
||||
- // TODO: k-quants
|
||||
GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
|
||||
break;
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 5c9dfd032..b8ed3709b 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -4693,6 +4693,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
+ case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -678,3 +678,113 @@ func ImageEditsMiddleware() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// TranscriptionWriter collects streamed chat responses and outputs a transcription response.
|
||||
type TranscriptionWriter struct {
|
||||
BaseWriter
|
||||
responseFormat string
|
||||
text strings.Builder
|
||||
}
|
||||
|
||||
func (w *TranscriptionWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
var chatResponse api.ChatResponse
|
||||
if err := json.Unmarshal(data, &chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.text.WriteString(chatResponse.Message.Content)
|
||||
|
||||
if chatResponse.Done {
|
||||
text := strings.TrimSpace(w.text.String())
|
||||
|
||||
if w.responseFormat == "text" {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/plain")
|
||||
_, err := w.ResponseWriter.Write([]byte(text))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
resp := openai.TranscriptionResponse{Text: text}
|
||||
if err := json.NewEncoder(w.ResponseWriter).Encode(resp); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// TranscriptionMiddleware handles /v1/audio/transcriptions requests.
|
||||
// It accepts multipart/form-data with an audio file and converts it to a chat request.
|
||||
func TranscriptionMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Parse multipart form (limit 25MB).
|
||||
if err := c.Request.ParseMultipartForm(25 << 20); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to parse multipart form: "+err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
model := c.Request.FormValue("model")
|
||||
if model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
file, _, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "file is required: "+err.Error()))
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
audioData, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, "failed to read audio file"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(audioData) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "audio file is empty"))
|
||||
return
|
||||
}
|
||||
|
||||
req := openai.TranscriptionRequest{
|
||||
Model: model,
|
||||
AudioData: audioData,
|
||||
ResponseFormat: c.Request.FormValue("response_format"),
|
||||
Language: c.Request.FormValue("language"),
|
||||
Prompt: c.Request.FormValue("prompt"),
|
||||
}
|
||||
|
||||
chatReq, err := openai.FromTranscriptionRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
c.Request.ContentLength = int64(b.Len())
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := &TranscriptionWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
responseFormat: req.ResponseFormat,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,6 +137,7 @@ type Tensor interface {
|
||||
|
||||
Bytes() []byte
|
||||
Floats() []float32
|
||||
BackendGet() []float32
|
||||
|
||||
FromBytes([]byte)
|
||||
FromFloats([]float32)
|
||||
@@ -162,6 +163,7 @@ type Tensor interface {
|
||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
|
||||
Conv1DDW(ctx Context, weight Tensor, s, p, d int) Tensor
|
||||
SSMConv(ctx Context, kernel Tensor) Tensor
|
||||
SSMScan(ctx Context, x, dt, A, B, C, ids Tensor) Tensor
|
||||
|
||||
@@ -187,6 +189,9 @@ type Tensor interface {
|
||||
Contiguous(ctx Context, shape ...int) Tensor
|
||||
|
||||
Pad(ctx Context, shape ...int) Tensor
|
||||
// PadExt pads with independent left/right amounts per dimension.
|
||||
// Arguments: lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 for dims 0-3.
|
||||
PadExt(ctx Context, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 int) Tensor
|
||||
|
||||
Stack(ctx Context, dim int, s ...Tensor) Tensor
|
||||
|
||||
|
||||
@@ -1069,6 +1069,21 @@ func (t *Tensor) Floats() (data []float32) {
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Tensor) BackendGet() []float32 {
|
||||
n := int(C.ggml_nelements(t.t))
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if t.sync != nil {
|
||||
t.sync()
|
||||
}
|
||||
|
||||
data := make([]float32, n)
|
||||
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
|
||||
return data
|
||||
}
|
||||
|
||||
func tensorSet[S ~[]E, E byte | float32 | int32](t *Tensor, s S) {
|
||||
if len(s) == 0 {
|
||||
return
|
||||
@@ -1313,6 +1328,13 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) PadExt(ctx ml.Context, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_pad_ext(ctx.(*Context).ctx, t.t, C.int(lp0), C.int(rp0), C.int(lp1), C.int(rp1), C.int(lp2), C.int(rp2), C.int(lp3), C.int(rp3)),
|
||||
}
|
||||
}
|
||||
|
||||
// Permute permutes t according to order. Permute panics if the number of dimensions
|
||||
// in order does not match the number of dimensions in t.
|
||||
func (t *Tensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
|
||||
@@ -1660,6 +1682,13 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Conv1DDW(ctx ml.Context, weight ml.Tensor, s, p, d int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_conv_1d_dw(ctx.(*Context).ctx, weight.(*Tensor).t, t.t, C.int(s), C.int(p), C.int(d)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor {
|
||||
var tt ml.Tensor = &Tensor{
|
||||
b: t.b,
|
||||
|
||||
80
ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
vendored
80
ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
vendored
@@ -155,6 +155,81 @@ static void get_rows_cuda_float(
|
||||
s10, s11, s12/*, s13*/);
|
||||
}
|
||||
|
||||
// Specialized GET_ROWS kernel for Q6_K — the k_get_rows template doesn't work for K-quants
|
||||
// because they lack the simple dequantize_kernel_t (float2) interface.
|
||||
// Based on dequantize_block_q6_K from convert.cu with row-selection logic added.
|
||||
template<typename dst_t>
|
||||
static __global__ void k_get_rows_q6_K(
|
||||
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00,
|
||||
const int64_t ne11, const int64_t ne12,
|
||||
const size_t s1, const size_t s2, const size_t s3,
|
||||
const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12) {
|
||||
|
||||
const int64_t i10 = blockIdx.x; // row index into src1
|
||||
const int64_t z = blockIdx.z;
|
||||
const int64_t i11 = z / ne12;
|
||||
const int64_t i12 = z % ne12;
|
||||
|
||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
|
||||
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||
const char * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
|
||||
const int64_t nb = ne00 / QK_K; // number of Q6_K blocks per row
|
||||
|
||||
// blockIdx.y iterates over Q6_K blocks within the row
|
||||
for (int64_t iblk = blockIdx.y; iblk < nb; iblk += gridDim.y) {
|
||||
const block_q6_K * x = (const block_q6_K *)src0_row + iblk;
|
||||
|
||||
// Same dequantization as dequantize_block_q6_K (assumes 64 threads)
|
||||
const int64_t tid = threadIdx.x;
|
||||
const int64_t ip = tid / 32; // 0 or 1
|
||||
const int64_t il = tid - 32*ip; // 0..31
|
||||
const int64_t is = 8*ip + il/16;
|
||||
|
||||
const int64_t y_offset = iblk * QK_K + 128*ip + il;
|
||||
|
||||
const float d = x->d;
|
||||
const uint8_t * ql = x->ql + 64*ip + il;
|
||||
const uint8_t qh = x->qh[32*ip + il];
|
||||
const int8_t * sc = x->scales + is;
|
||||
|
||||
if (y_offset + 0 < ne00) dst_row[y_offset + 0] = ggml_cuda_cast<dst_t>(d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32));
|
||||
if (y_offset + 32 < ne00) dst_row[y_offset + 32] = ggml_cuda_cast<dst_t>(d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32));
|
||||
if (y_offset + 64 < ne00) dst_row[y_offset + 64] = ggml_cuda_cast<dst_t>(d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32));
|
||||
if (y_offset + 96 < ne00) dst_row[y_offset + 96] = ggml_cuda_cast<dst_t>(d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void get_rows_cuda_q6_K(
|
||||
const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
|
||||
const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
|
||||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
cudaStream_t stream) {
|
||||
const int64_t nb_blocks = ne00 / QK_K;
|
||||
const dim3 block_dims(64, 1, 1);
|
||||
const dim3 block_nums(ne10, MIN(nb_blocks, (int64_t)UINT16_MAX), MIN(ne11*ne12, (int64_t)UINT16_MAX));
|
||||
|
||||
const size_t s1 = nb1 / sizeof(dst_t);
|
||||
const size_t s2 = nb2 / sizeof(dst_t);
|
||||
const size_t s3 = nb3 / sizeof(dst_t);
|
||||
|
||||
const size_t s10 = nb10 / sizeof(int32_t);
|
||||
const size_t s11 = nb11 / sizeof(int32_t);
|
||||
const size_t s12 = nb12 / sizeof(int32_t);
|
||||
|
||||
k_get_rows_q6_K<<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne11, ne12,
|
||||
s1, s2, s3,
|
||||
nb01, nb02, nb03,
|
||||
s10, s11, s12);
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void ggml_cuda_get_rows_switch_src0_type(
|
||||
const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
|
||||
@@ -199,8 +274,11 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
||||
get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
get_rows_cuda_q6_K(src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
default:
|
||||
// TODO: k-quants
|
||||
GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -4693,6 +4693,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -47,6 +47,12 @@ type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// PostLoader is an optional interface that models can implement to run
|
||||
// initialization steps after backend weights have been loaded.
|
||||
type PostLoader interface {
|
||||
PostLoad() error
|
||||
}
|
||||
|
||||
// MultimodalProcessor must be implemented by multimodal models.
|
||||
type MultimodalProcessor interface {
|
||||
// EncodeMultimodal processes a single input (such as an image) and
|
||||
|
||||
@@ -68,6 +68,8 @@ func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor
|
||||
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
|
||||
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Conv1DDW(ctx ml.Context, _ ml.Tensor, _, _, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) PadExt(ctx ml.Context, _, _, _, _, _, _, _, _ int) ml.Tensor { return f }
|
||||
|
||||
func (m *fakeBackend) Get(name string) ml.Tensor {
|
||||
if slices.Contains(m.names, name) {
|
||||
|
||||
265
model/models/gemma4/model.go
Normal file
265
model/models/gemma4/model.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
*AudioModel `gguf:"a"`
|
||||
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
*AudioMultimodalProjector `gguf:"mm.a"`
|
||||
|
||||
ImageProcessor
|
||||
|
||||
imageTokenID int32
|
||||
imageEndTokenID int32
|
||||
audioTokenID int32
|
||||
audioEndTokenID int32
|
||||
|
||||
audioOpts *AudioModelOptions
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
type MultiModalProjector struct {
|
||||
Projection *ClippableLinear `gguf:"input_projection"`
|
||||
}
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
||||
visionOutputs = p.Projection.Forward(ctx, visionOutputs)
|
||||
// Post-projection RMSNorm without learned weight
|
||||
visionOutputs = visionOutputs.RMSNorm(ctx, nil, eps)
|
||||
return visionOutputs
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
|
||||
// Gemma 4 uses BPE with SentencePiece-style ▁ space markers (not GPT-2 byte-level encoding).
|
||||
// The tokenizer.json has merges and a Replace normalizer (space → ▁), with no pre-tokenizer.
|
||||
t := tokenizer.NewBytePairEncodingWithOptions(&vocabulary, []string{},
|
||||
tokenizer.WithSentencePieceNormalizer())
|
||||
|
||||
// Look up special token IDs for vision and audio
|
||||
imageTokenID := int32(-1)
|
||||
imageEndTokenID := int32(-1)
|
||||
audioTokenID := int32(-1)
|
||||
audioEndTokenID := int32(-1)
|
||||
for i, tok := range vocabulary.Values {
|
||||
switch tok {
|
||||
case "<|image>":
|
||||
imageTokenID = int32(i)
|
||||
case "<image|>":
|
||||
imageEndTokenID = int32(i)
|
||||
case "<|audio>":
|
||||
audioTokenID = int32(i)
|
||||
case "<audio|>":
|
||||
audioEndTokenID = int32(i)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("gemma4: token IDs", "image", imageTokenID, "image_end", imageEndTokenID, "audio", audioTokenID, "audio_end", audioEndTokenID)
|
||||
|
||||
m := Model{
|
||||
Tokenizer: t,
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
AudioModel: newAudioModel(c),
|
||||
MultiModalProjector: &MultiModalProjector{},
|
||||
AudioMultimodalProjector: &AudioMultimodalProjector{},
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
imageTokenID: imageTokenID,
|
||||
imageEndTokenID: imageEndTokenID,
|
||||
audioTokenID: audioTokenID,
|
||||
audioEndTokenID: audioEndTokenID,
|
||||
audioOpts: newAudioModelOptions(c),
|
||||
}
|
||||
|
||||
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWAMemCache(slidingWindowLen, 4096, m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
// Audio input: detect WAV format and route to audio encoder.
|
||||
if isAudioData(multimodalData) {
|
||||
return m.encodeAudioMultimodal(ctx, multimodalData)
|
||||
}
|
||||
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Info("vision: decode", "elapsed", time.Since(t0), "bounds", img.Bounds())
|
||||
|
||||
t1 := time.Now()
|
||||
f32s, imgW, imgH, err := m.ImageProcessor.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Info("vision: preprocess", "elapsed", time.Since(t1), "size", [2]int{imgW, imgH})
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s, imgW, imgH, m.ImageProcessor.numChannels)
|
||||
slog.Info("vision: pixelValues", "shape", pixelValues.Shape(), "dim0", pixelValues.Dim(0), "dim1", pixelValues.Dim(1), "dim2", pixelValues.Dim(2))
|
||||
|
||||
numPatchesX := imgW / m.ImageProcessor.patchSize
|
||||
numPatchesY := imgH / m.ImageProcessor.patchSize
|
||||
slog.Info("vision: patches", "patchesX", numPatchesX, "patchesY", numPatchesY, "total", numPatchesX*numPatchesY, "patchSize", m.ImageProcessor.patchSize)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, numPatchesX, numPatchesY)
|
||||
visionOutputs = visionPoolAndProject(ctx, visionOutputs, numPatchesX, numPatchesY, m.VisionModel.VisionModelOptions, m.MultiModalProjector, m.VisionModel.StdBias, m.VisionModel.StdScale)
|
||||
slog.Info("vision: encoded", "elapsed", time.Since(t0), "shape", visionOutputs.Shape())
|
||||
|
||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostLoad() error {
|
||||
m.VisionModel.InitClamp(m.MultiModalProjector)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
|
||||
if m.AudioModel == nil || m.audioOpts == nil {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
samples, err := decodeWAV(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Info("audio: decode", "elapsed", time.Since(t0), "samples", len(samples), "duration_s", float64(len(samples))/audioSampleRate)
|
||||
|
||||
// Pad waveform to next multiple of 128.
|
||||
if rem := len(samples) % 128; rem != 0 {
|
||||
samples = append(samples, make([]float32, 128-rem)...)
|
||||
}
|
||||
|
||||
// Compute mel spectrogram.
|
||||
melData, numFrames := computeMelSpectrogram(samples)
|
||||
if numFrames == 0 {
|
||||
return nil, fmt.Errorf("audio too short to encode")
|
||||
}
|
||||
slog.Info("audio: mel", "frames", numFrames, "elapsed", time.Since(t0))
|
||||
|
||||
// Create input tensor [melBins, numFrames] (GGML ne order). FromFloats creates F32.
|
||||
melTensor := ctx.Input().FromFloats(melData, melBins, numFrames)
|
||||
|
||||
// Run audio encoder.
|
||||
audioOutputs := m.AudioModel.ForwardAudio(ctx, melTensor, m.AudioMultimodalProjector, m.audioOpts)
|
||||
slog.Info("audio: encoded", "elapsed", time.Since(t0), "shape", audioOutputs.Shape())
|
||||
|
||||
return []input.Multimodal{{Tensor: audioOutputs, Data: audioTag{}}}, nil
|
||||
}
|
||||
|
||||
// audioTag marks multimodal data as audio (vs vision) for PostTokenize.
|
||||
type audioTag struct{}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
continue
|
||||
}
|
||||
|
||||
inputMultimodal := inp.Multimodal[0].Tensor
|
||||
numTokens := inputMultimodal.Dim(1)
|
||||
|
||||
// Determine if this is audio or vision based on the tag.
|
||||
_, isAudio := inp.Multimodal[0].Data.(audioTag)
|
||||
|
||||
var beginToken, endToken int32
|
||||
if isAudio {
|
||||
beginToken = m.audioTokenID
|
||||
endToken = m.audioEndTokenID
|
||||
} else {
|
||||
beginToken = m.imageTokenID
|
||||
endToken = m.imageEndTokenID
|
||||
}
|
||||
|
||||
if beginToken >= 0 {
|
||||
result = append(result, &input.Input{Token: beginToken, SameBatch: numTokens + 2})
|
||||
}
|
||||
|
||||
result = append(result,
|
||||
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash},
|
||||
)
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, numTokens-1)...)
|
||||
|
||||
if endToken >= 0 {
|
||||
result = append(result, &input.Input{Token: endToken})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
|
||||
hiddenState = m.TextModel.Output.Forward(ctx, hiddenState)
|
||||
|
||||
if m.TextModel.TextOptions.finalLogitSoftcap > 0.0 {
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextModel.TextOptions.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Tanh(ctx)
|
||||
hiddenState = hiddenState.Scale(ctx, float64(m.TextModel.TextOptions.finalLogitSoftcap))
|
||||
}
|
||||
|
||||
return hiddenState, nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeBase, ropeDims := m.TextModel.ropeForLayer(layer)
|
||||
return nn.RoPE(ctx, key, shift, ropeDims, ropeBase, 1.0, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma4", New)
|
||||
}
|
||||
612
model/models/gemma4/model_audio.go
Normal file
612
model/models/gemma4/model_audio.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
// AudioModel holds the audio encoder and configuration.
|
||||
type AudioModel struct {
|
||||
// SSCP: Sub-Sample Convolution Projection.
|
||||
SSCPConv0 *AudioConvBlock `gguf:"conv1d.0"`
|
||||
SSCPConv1 *AudioConvBlock `gguf:"conv1d.1"`
|
||||
|
||||
// SSCP output projection (linear).
|
||||
SSCPInputProj *nn.Linear `gguf:"pre_encode.out"`
|
||||
|
||||
// Conformer blocks.
|
||||
Layers []AudioConformerBlock `gguf:"blk"`
|
||||
|
||||
// Output projection to embedder dimension.
|
||||
OutputProj *AudioOutputProj `gguf:"output_proj"`
|
||||
|
||||
AudioModelOptions
|
||||
}
|
||||
|
||||
type AudioOutputProj struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
// AudioModelOptions holds audio model hyperparameters.
|
||||
type AudioModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
headDim int
|
||||
ffnSize int
|
||||
numLayers int
|
||||
melBins int
|
||||
chunkSize int
|
||||
maxPast int
|
||||
maxFuture int
|
||||
contextSize int
|
||||
logitCap float32
|
||||
residualWeight float32
|
||||
gradClip float32
|
||||
convKernelSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
// AudioConvBlock is a single 2D convolution block for the SSCP.
|
||||
type AudioConvBlock struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Norm *nn.LayerNorm `gguf:"norm"`
|
||||
}
|
||||
|
||||
// AudioConformerBlock is a single conformer layer.
|
||||
// All tensors are flat at the block level (a.blk.N.<name>) using underscore naming.
|
||||
type AudioConformerBlock struct {
|
||||
// Block-level norm
|
||||
Norm *nn.RMSNorm `gguf:"layer_pre_norm"`
|
||||
|
||||
// FFW start
|
||||
FFWNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
FFWUp *AudioClippableLinear `gguf:"ffn_up"`
|
||||
FFWDown *AudioClippableLinear `gguf:"ffn_down"`
|
||||
FFWPostNorm *nn.RMSNorm `gguf:"ffn_post_norm"`
|
||||
|
||||
// FFW end
|
||||
FFWNorm1 *nn.RMSNorm `gguf:"ffn_norm_1"`
|
||||
FFWUp1 *AudioClippableLinear `gguf:"ffn_up_1"`
|
||||
FFWDown1 *AudioClippableLinear `gguf:"ffn_down_1"`
|
||||
FFWPostNorm1 *nn.RMSNorm `gguf:"ffn_post_norm_1"`
|
||||
|
||||
// Attention
|
||||
AttnQ *AudioClippableLinear `gguf:"attn_q"`
|
||||
AttnK *AudioClippableLinear `gguf:"attn_k"`
|
||||
AttnV *AudioClippableLinear `gguf:"attn_v"`
|
||||
AttnOut *AudioClippableLinear `gguf:"attn_out"`
|
||||
AttnPreNorm *nn.RMSNorm `gguf:"ln1"`
|
||||
AttnPostNorm *nn.RMSNorm `gguf:"ln2"`
|
||||
LinearPos ml.Tensor `gguf:"linear_pos.weight"`
|
||||
PerDimScale ml.Tensor `gguf:"per_dim_scale.weight"`
|
||||
|
||||
// LightConv1d
|
||||
ConvPW1 *AudioClippableLinear `gguf:"conv_pw1"`
|
||||
ConvPW2 *AudioClippableLinear `gguf:"conv_pw2"`
|
||||
ConvDW ml.Tensor `gguf:"conv_dw.weight"`
|
||||
ConvNorm *nn.RMSNorm `gguf:"conv_norm"`
|
||||
NormConv *nn.RMSNorm `gguf:"norm_conv"`
|
||||
}
|
||||
|
||||
// AudioClippableLinear is a linear layer with optional input/output clamping.
|
||||
type AudioClippableLinear struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
InputMin ml.Tensor `gguf:"input_min"`
|
||||
InputMax ml.Tensor `gguf:"input_max"`
|
||||
OutputMin ml.Tensor `gguf:"output_min"`
|
||||
OutputMax ml.Tensor `gguf:"output_max"`
|
||||
|
||||
// Cached scalar clamp values (populated on first forward).
|
||||
inMin, inMax, outMin, outMax float32
|
||||
clampsLoaded bool
|
||||
}
|
||||
|
||||
func (l *AudioClippableLinear) loadClamps() {
|
||||
if l.clampsLoaded {
|
||||
return
|
||||
}
|
||||
l.clampsLoaded = true
|
||||
if l.InputMin != nil {
|
||||
vals := l.InputMin.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.inMin = vals[0]
|
||||
}
|
||||
}
|
||||
if l.InputMax != nil {
|
||||
vals := l.InputMax.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.inMax = vals[0]
|
||||
}
|
||||
}
|
||||
if l.OutputMin != nil {
|
||||
vals := l.OutputMin.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.outMin = vals[0]
|
||||
}
|
||||
}
|
||||
if l.OutputMax != nil {
|
||||
vals := l.OutputMax.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.outMax = vals[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *AudioClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
|
||||
l.loadClamps()
|
||||
if l.inMax != 0 {
|
||||
x = x.Clamp(ctx, l.inMin, l.inMax)
|
||||
}
|
||||
out := l.Weight.Mulmat(ctx, x)
|
||||
if l.Bias != nil {
|
||||
out = out.Add(ctx, l.Bias)
|
||||
}
|
||||
if l.outMax != 0 {
|
||||
out = out.Clamp(ctx, l.outMin, l.outMax)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AudioMultimodalProjector is the audio-to-text embedding projector.
|
||||
type AudioMultimodalProjector struct {
|
||||
Projection *AudioClippableLinear `gguf:"input_projection"`
|
||||
FC *AudioFC `gguf:"fc"`
|
||||
}
|
||||
|
||||
type AudioFC struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (p *AudioMultimodalProjector) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
|
||||
// FC: output projection from conformer to embedder dimension.
|
||||
x = p.FC.Weight.Mulmat(ctx, x)
|
||||
if p.FC.Bias != nil {
|
||||
x = x.Add(ctx, p.FC.Bias)
|
||||
}
|
||||
// Pre-projection RMSNorm (without learned weight) — matches Python's embedding_pre_projection_norm.
|
||||
x = x.RMSNorm(ctx, nil, eps)
|
||||
// Embedding projection to text hidden size.
|
||||
x = p.Projection.Forward(ctx, x)
|
||||
return x
|
||||
}
|
||||
|
||||
// ForwardAudio encodes mel spectrogram features into soft tokens.
|
||||
// melFeatures: float32 tensor with ne[0]=melBins, ne[1]=numFrames.
|
||||
// Returns: [hiddenSize, numTokens] tensor.
|
||||
func (m *AudioModel) ForwardAudio(ctx ml.Context, melFeatures ml.Tensor, proj *AudioMultimodalProjector, opts *AudioModelOptions) ml.Tensor {
|
||||
// SSCP Conv2D input: ne[0]=F (freq/width), ne[1]=T (time/height), ne[2]=C_in, ne[3]=B
|
||||
// melFeatures is [melBins, numFrames], add channel and batch dims.
|
||||
x := melFeatures.Reshape(ctx, melFeatures.Dim(0), melFeatures.Dim(1), 1, 1)
|
||||
|
||||
// SSCP Conv block 0: [F, T, 1, 1] → [F', T', C0, 1]
|
||||
x = forwardConvBlock(ctx, m.SSCPConv0, x, opts)
|
||||
|
||||
// SSCP Conv block 1: [F', T', C0, 1] → [F'', T'', C1, 1]
|
||||
x = forwardConvBlock(ctx, m.SSCPConv1, x, opts)
|
||||
|
||||
// After conv blocks, layout is [F'', T'', C_out, B].
|
||||
// Permute to [C_out*F'', T'', B] for linear projection (channels+freq in ne[0]).
|
||||
fOut := x.Dim(0)
|
||||
tOut := x.Dim(1)
|
||||
cOut := x.Dim(2)
|
||||
// Permute [F'', T'', C, B] → [C, F'', T'', B]
|
||||
// (1,2,0,3): old[0]→pos1, old[1]→pos2, old[2]→pos0
|
||||
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
x = x.Reshape(ctx, cOut*fOut, tOut)
|
||||
|
||||
// Linear projection to hidden size.
|
||||
x = m.SSCPInputProj.Forward(ctx, x)
|
||||
|
||||
// Build causal-valid mask for conformer attention.
|
||||
causalMask := buildCausalValidMaskF32(int(opts.chunkSize), opts.maxPast, opts.maxFuture)
|
||||
|
||||
// Run conformer blocks.
|
||||
for i := range m.Layers {
|
||||
x = m.Layers[i].Forward(ctx, x, causalMask, opts, i)
|
||||
}
|
||||
|
||||
// Output projection.
|
||||
if m.OutputProj != nil {
|
||||
x = m.OutputProj.Weight.Mulmat(ctx, x)
|
||||
if m.OutputProj.Bias != nil {
|
||||
x = x.Add(ctx, m.OutputProj.Bias)
|
||||
}
|
||||
}
|
||||
|
||||
// Audio embedder: project to text embedding space.
|
||||
if proj != nil {
|
||||
x = proj.Forward(ctx, x, opts.eps)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardConvBlock runs a single SSCP Conv2D block.
|
||||
// Conv2D receiver is the kernel, argument is the input data.
|
||||
// Input: [F, T, C_in, B]. Output: [F', T', C_out, B].
|
||||
func forwardConvBlock(ctx ml.Context, block *AudioConvBlock, x ml.Tensor, opts *AudioModelOptions) ml.Tensor {
|
||||
// Conv2D: kernel.Conv2D(ctx, input, s0, s1, p0, p1, d0, d1)
|
||||
// Kernel is 3x3, stride 2x2, padding 1x1 (matching SSCP config).
|
||||
// Output layout: [F', T', C_out, B]
|
||||
// Make weight contiguous — the shape reversal in the converter creates
|
||||
// a tensor where the physical data order doesn't match ne[]/stride[].
|
||||
weight := block.Weight.Contiguous(ctx)
|
||||
x = weight.Conv2D(ctx, x, 2, 2, 1, 1, 1, 1)
|
||||
|
||||
// LayerNorm needs channels in ne[0]. Permute [F', T', C_out, B] → [C_out, F', T', B],
|
||||
// norm, then permute back.
|
||||
// GGML permute: axis i says where old axis i goes.
|
||||
// (1,2,0,3): old[0]→pos1, old[1]→pos2, old[2]→pos0 → [C_out, F', T', B]
|
||||
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
x = block.Norm.Forward(ctx, x, opts.eps)
|
||||
// (2,0,1,3): old[0]→pos2, old[1]→pos0, old[2]→pos1 → [F', T', C_out, B]
|
||||
x = x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
|
||||
x = x.RELU(ctx)
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward runs a single conformer block.
|
||||
func (cb *AudioConformerBlock) Forward(ctx ml.Context, x ml.Tensor, causalMask []float32, opts *AudioModelOptions, blockIdx int) ml.Tensor {
|
||||
|
||||
// FFW start (half-residual).
|
||||
x = cb.forwardFFW(ctx, cb.FFWNorm, cb.FFWUp, cb.FFWDown, cb.FFWPostNorm, x, opts)
|
||||
|
||||
// Self-attention.
|
||||
x = cb.forwardAttention(ctx, x, causalMask, opts, blockIdx)
|
||||
|
||||
// Lightweight Conv1d.
|
||||
x = cb.forwardLightConv(ctx, x, opts, blockIdx)
|
||||
|
||||
// FFW end (half-residual).
|
||||
x = cb.forwardFFW(ctx, cb.FFWNorm1, cb.FFWUp1, cb.FFWDown1, cb.FFWPostNorm1, x, opts)
|
||||
|
||||
// Gradient clipping + final norm.
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = cb.Norm.Forward(ctx, x, opts.eps)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardFFW runs a feedforward module with half-residual connection.
|
||||
func (cb *AudioConformerBlock) forwardFFW(ctx ml.Context, preNorm *nn.RMSNorm, up, down *AudioClippableLinear, postNorm *nn.RMSNorm, x ml.Tensor, opts *AudioModelOptions) ml.Tensor {
|
||||
residual := x
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = preNorm.Forward(ctx, x, opts.eps)
|
||||
x = up.Forward(ctx, x)
|
||||
x = x.SILU(ctx)
|
||||
x = down.Forward(ctx, x)
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = postNorm.Forward(ctx, x, opts.eps)
|
||||
x = x.Scale(ctx, float64(opts.residualWeight))
|
||||
return residual.Add(ctx, x)
|
||||
}
|
||||
|
||||
// forwardAttention runs the conformer block-local self-attention with relative
// position embeddings, wrapped in a full residual connection. The sequence is
// processed in fixed-size chunks; each chunk's queries attend to a context
// window of maxPast past frames, the chunk itself, and maxFuture future frames.
func (cb *AudioConformerBlock) forwardAttention(ctx ml.Context, x ml.Tensor, causalMask []float32, opts *AudioModelOptions, blockIdx int) ml.Tensor {
	residual := x
	x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
	x = cb.AttnPreNorm.Forward(ctx, x, opts.eps)

	hiddenSize := x.Dim(0)
	seqLen := x.Dim(1)

	// QKV projections: [hiddenSize, seqLen] → [headDim, numHeads, seqLen]
	q := cb.AttnQ.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
	k := cb.AttnK.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
	v := cb.AttnV.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)

	// Per-dim scaling for queries: (headDim^-0.5 / log(2)) * softplus(per_dim_scale).
	// per_dim_scale is already softplus'd from the converter.
	qScale := float64(math.Pow(float64(opts.headDim), -0.5)) / math.Log(2)
	q = q.Scale(ctx, qScale)
	if cb.PerDimScale != nil {
		q = q.Mul(ctx, cb.PerDimScale)
	}

	// Key scaling: softplus(1) / log(2) — matches the query base scaling convention.
	kScale := math.Log(1+math.E) / math.Log(2)
	k = k.Scale(ctx, kScale)

	// Build sinusoidal position embeddings for the block-local context.
	maxSpan := opts.maxPast + opts.maxFuture + 1 // number of unique relative positions
	posEmb := cb.buildPositionEmbeddings(ctx, maxSpan, opts)
	// posEmb: [headDim, numHeads, maxSpan]

	// Block-local attention: process chunks of size chunkSize.
	chunkSize := opts.chunkSize
	numChunks := (seqLen + chunkSize - 1) / chunkSize
	contextSize := opts.contextSize

	// Pad q/k/v to a multiple of chunkSize on the time dimension (dim 2).
	padT := numChunks*chunkSize - seqLen
	if padT > 0 {
		q = q.Pad(ctx, 0, 0, padT, 0)
		k = k.Pad(ctx, 0, 0, padT, 0)
		v = v.Pad(ctx, 0, 0, padT, 0)
	}
	paddedLen := numChunks * chunkSize

	// Pad k/v for context extraction: maxPast zeros on the left and
	// (maxFuture + chunkSize - 1) on the right. Left-padding is done by
	// concatenating explicit zero tensors because PadExt+Slice has issues here.
	padLeft := opts.maxPast
	padRight := opts.maxFuture + chunkSize - 1
	zeroLeft := ctx.Input().FromFloats(make([]float32, opts.headDim*opts.numHeads*padLeft), opts.headDim, opts.numHeads, padLeft)
	zeroRight := ctx.Input().FromFloats(make([]float32, opts.headDim*opts.numHeads*padRight), opts.headDim, opts.numHeads, padRight)
	kPadded := zeroLeft.Concat(ctx, k, 2).Concat(ctx, zeroRight, 2)
	vPadded := zeroLeft.Concat(ctx, v, 2).Concat(ctx, zeroRight, 2)

	// Reshape q into chunks: [headDim, numHeads, numChunks, chunkSize]
	qChunked := q.Reshape(ctx, opts.headDim, opts.numHeads, numChunks, chunkSize)

	// Process each chunk independently and collect the results.
	chunkOutputs := make([]ml.Tensor, numChunks)
	for u := range numChunks {
		// Extract query block: [headDim, numHeads, 1, chunkSize] → [headDim, numHeads, chunkSize]
		qBlock := qChunked.Slice(ctx, 2, u, u+1, 1).Reshape(ctx, opts.headDim, opts.numHeads, chunkSize)

		// Extract key/value context: [headDim, numHeads, contextSize]
		cStart := u * chunkSize // offset in kPadded (padLeft already accounts for left context)
		kCtx := kPadded.Slice(ctx, 2, cStart, cStart+contextSize, 1).Contiguous(ctx)
		vCtx := vPadded.Slice(ctx, 2, cStart, cStart+contextSize, 1).Contiguous(ctx)

		// Content-content logits: qBlock^T @ kCtx → [chunkSize, contextSize] per head.
		// Mulmat(a, b) = a^T @ b, so permute both operands to
		// [headDim, chunkSize, numHeads] / [headDim, contextSize, numHeads]
		// so Mulmat batches over numHeads.
		// GGML permute(0,2,1,3): old[0]→0, old[1]→2, old[2]→1
		qP := qBlock.Permute(ctx, 0, 2, 1, 3) // [headDim, chunkSize, numHeads]
		kP := kCtx.Permute(ctx, 0, 2, 1, 3)   // [headDim, contextSize, numHeads]

		termAC := kP.MulmatFullPrec(ctx, qP) // [contextSize, chunkSize, numHeads]

		// Content-position logits: qBlock^T @ posEmb → [chunkSize, maxSpan] per head.
		pP := posEmb.Permute(ctx, 0, 2, 1, 3)   // [headDim, maxSpan, numHeads]
		termBDRaw := pP.MulmatFullPrec(ctx, qP) // [maxSpan, chunkSize, numHeads]

		// Relative shift: [maxSpan, chunkSize, numHeads] → [contextSize, chunkSize, numHeads]
		termBD := cb.relativeShiftGGML(ctx, termBDRaw, maxSpan, chunkSize, contextSize, opts.numHeads)

		// Combined logits.
		logits := termAC.Add(ctx, termBD)

		// Logit softcap: tanh(logits / cap) * cap
		logits = logits.Scale(ctx, 1.0/float64(opts.logitCap))
		logits = logits.Tanh(ctx)
		logits = logits.Scale(ctx, float64(opts.logitCap))

		// Apply combined causal + validity mask.
		// causalMask [chunkSize * contextSize]: 1=causal-allowed, 0=masked.
		// Validity: context positions before the actual sequence start (or past
		// its end) are invalid. For chunk u, context position c corresponds to
		// actual time u*chunkSize + c - padLeft; valid if 0 <= t < seqLen.
		// Mask layout: [contextSize, chunkSize, 1], flat index i*contextSize + j.
		maskData := make([]float32, contextSize*chunkSize)
		for i := range chunkSize {
			for j := range contextSize {
				actualTime := u*chunkSize + j - padLeft
				causalOK := causalMask[i*contextSize+j] > 0
				validOK := actualTime >= 0 && actualTime < seqLen
				if causalOK && validOK {
					maskData[i*contextSize+j] = 0
				} else {
					// Large negative value masks the position out of the softmax.
					maskData[i*contextSize+j] = -1e9
				}
			}
		}
		mask := ctx.Input().FromFloats(maskData, contextSize, chunkSize, 1) // 3D for broadcasting over numHeads
		logits = logits.Add(ctx, mask)

		// Softmax over context dimension (dim 0 = contextSize).
		logits = logits.Softmax(ctx)

		// Weighted sum: logits^T @ vCtx.
		// logits: [contextSize, chunkSize, numHeads], vCtx: [headDim, numHeads, contextSize]
		vP := vCtx.Permute(ctx, 0, 2, 1, 3) // [headDim, contextSize, numHeads]
		// Mulmat(a, b) = a^T @ b, so a = [contextSize, headDim, numHeads],
		// b = [contextSize, chunkSize, numHeads] → [headDim, chunkSize, numHeads].
		vPT := vP.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // [contextSize, headDim, numHeads]
		chunkOut := vPT.Mulmat(ctx, logits)                // [headDim, chunkSize, numHeads]

		// Permute back to [headDim, numHeads, chunkSize]
		chunkOut = chunkOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
		chunkOutputs[u] = chunkOut
	}

	// Concatenate chunk outputs along the time dimension.
	var attnOut ml.Tensor
	if numChunks == 1 {
		attnOut = chunkOutputs[0]
	} else {
		attnOut = chunkOutputs[0]
		for _, co := range chunkOutputs[1:] {
			attnOut = attnOut.Concat(ctx, co, 2)
		}
	}

	// Trim back to the original sequence length if we padded.
	if paddedLen > seqLen {
		attnOut = attnOut.Slice(ctx, 2, 0, seqLen, 1).Contiguous(ctx)
	}

	// Reshape to [hiddenSize, seqLen], project, clamp, and post-norm.
	attnOut = attnOut.Reshape(ctx, hiddenSize, seqLen)
	x = cb.AttnOut.Forward(ctx, attnOut)
	x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
	x = cb.AttnPostNorm.Forward(ctx, x, opts.eps)

	return residual.Add(ctx, x)
}
|
||||
|
||||
// buildPositionEmbeddings builds sinusoidal position embeddings and projects through linear_pos.
|
||||
// Returns [headDim, numHeads, maxSpan] tensor.
|
||||
func (cb *AudioConformerBlock) buildPositionEmbeddings(ctx ml.Context, maxSpan int, opts *AudioModelOptions) ml.Tensor {
|
||||
halfDim := opts.hiddenSize / 2
|
||||
hiddenSize := opts.hiddenSize
|
||||
|
||||
// inv_timescales: exp(-i * log(10000) / max(D/2-1, 1))
|
||||
logInc := math.Log(10000.0) / math.Max(float64(halfDim-1), 1)
|
||||
|
||||
// Sinusoidal embeddings for relative positions [maxPast, maxPast-1, ..., -maxFuture].
|
||||
posData := make([]float32, hiddenSize*maxSpan)
|
||||
for p := range maxSpan {
|
||||
relPos := float64(opts.maxPast - p)
|
||||
for d := range halfDim {
|
||||
angle := relPos * math.Exp(float64(-d)*logInc)
|
||||
posData[p*hiddenSize+d] = float32(math.Sin(angle))
|
||||
posData[p*hiddenSize+halfDim+d] = float32(math.Cos(angle))
|
||||
}
|
||||
}
|
||||
|
||||
// Create [hiddenSize, maxSpan] input tensor.
|
||||
posEmb := ctx.Input().FromFloats(posData, hiddenSize, maxSpan)
|
||||
|
||||
// Project through linear_pos: [hiddenSize, maxSpan] → Mulmat → [numHeads*headDim, maxSpan]
|
||||
projPos := cb.LinearPos.Mulmat(ctx, posEmb)
|
||||
|
||||
// Reshape to [headDim, numHeads, maxSpan].
|
||||
return projPos.Reshape(ctx, opts.headDim, opts.numHeads, maxSpan)
|
||||
}
|
||||
|
||||
// relativeShiftGGML performs the relative shift that realigns
// per-relative-position logits into per-context-position logits.
// Input: [maxSpan, chunkSize, numHeads]. Output: [contextSize, chunkSize, numHeads].
func (cb *AudioConformerBlock) relativeShiftGGML(ctx ml.Context, x ml.Tensor, maxSpan, chunkSize, contextSize, numHeads int) ml.Tensor {
	// The shift trick: pad ne[0] to contextSize+1, flatten the first two dims,
	// then take the first contextSize*chunkSize elements — the extra column
	// makes each successive row start one relative position later, which is
	// exactly the alignment the block-local attention needs.
	padAmt := contextSize + 1 - maxSpan
	if padAmt > 0 {
		x = x.Pad(ctx, padAmt, 0, 0, 0) // [maxSpan+padAmt, chunkSize, numHeads] = [contextSize+1, chunkSize, numHeads]
	}
	// Reshape to [(contextSize+1)*chunkSize, numHeads]
	x = x.Reshape(ctx, (contextSize+1)*chunkSize, numHeads)
	// Take the first contextSize*chunkSize elements (the standard relative shift trick).
	x = x.Slice(ctx, 0, 0, contextSize*chunkSize, 1).Contiguous(ctx)
	// Reshape to [contextSize, chunkSize, numHeads]
	return x.Reshape(ctx, contextSize, chunkSize, numHeads)
}
|
||||
|
||||
// forwardLightConv runs the lightweight depthwise convolution module with a
// full residual connection: norm → pointwise conv + GLU → causal depthwise
// conv → clamp → norm → SiLU → pointwise conv.
func (cb *AudioConformerBlock) forwardLightConv(ctx ml.Context, x ml.Tensor, opts *AudioModelOptions, blockIdx int) ml.Tensor {
	residual := x

	x = cb.ConvNorm.Forward(ctx, x, opts.eps)
	x = cb.ConvPW1.Forward(ctx, x) // [2*D, T, B]

	// GLU: split in half along dim 0, sigmoid gate, multiply.
	d := x.Dim(0) / 2
	data := x.Slice(ctx, 0, 0, d, 1).Contiguous(ctx)
	gate := x.Slice(ctx, 0, d, d*2, 1).Contiguous(ctx).Sigmoid(ctx)
	x = data.Mul(ctx, gate) // [D, T, B]

	// Depthwise Conv1d: manual implementation using model weight tensor slices.
	// Kernel cb.ConvDW layout in GGML: ne[0]=K contiguous, ne[1]=D.
	// The convolution is built from per-tap weights [D] applied to shifted
	// copies of the input.
	kernelSize := cb.ConvDW.Dim(0) // K
	seqLen := x.Dim(1)

	// Transpose kernel to [D, K] for per-tap slicing.
	// GGML permute(1,0,2,3): old[0]→pos1, old[1]→pos0 → swap ne[0] and ne[1]
	kernelT := cb.ConvDW.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // [D, K]

	var convOut ml.Tensor
	for k := range kernelSize {
		// Tap k reads the input shifted right by (K-1-k) frames, giving a
		// causal (left-context only) convolution.
		shift := kernelSize - 1 - k
		var shifted ml.Tensor
		if shift == 0 {
			shifted = x
		} else {
			// Drop the last `shift` frames and zero-pad on the left.
			trimmed := x.Slice(ctx, 1, 0, seqLen-shift, 1).Contiguous(ctx)
			shifted = trimmed.PadExt(ctx, 0, 0, shift, 0, 0, 0, 0, 0)
		}

		wk := kernelT.Slice(ctx, 1, k, k+1, 1).Contiguous(ctx) // [D, 1]
		term := shifted.Mul(ctx, wk)
		if convOut == nil {
			convOut = term
		} else {
			convOut = convOut.Add(ctx, term)
		}
	}
	x = convOut

	x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
	x = cb.NormConv.Forward(ctx, x, opts.eps)
	x = x.SILU(ctx)
	x = cb.ConvPW2.Forward(ctx, x)

	return x.Add(ctx, residual)
}
|
||||
|
||||
func newAudioModel(c fs.Config) *AudioModel {
|
||||
numLayers := int(c.Uint("audio.block_count", 0))
|
||||
if numLayers == 0 {
|
||||
return nil
|
||||
}
|
||||
return &AudioModel{
|
||||
Layers: make([]AudioConformerBlock, numLayers),
|
||||
}
|
||||
}
|
||||
|
||||
func newAudioModelOptions(c fs.Config) *AudioModelOptions {
|
||||
hiddenSize := int(c.Uint("audio.embedding_length", 0))
|
||||
if hiddenSize == 0 {
|
||||
return nil
|
||||
}
|
||||
numHeads := int(c.Uint("audio.attention.head_count", 8))
|
||||
headDim := hiddenSize / numHeads
|
||||
chunkSize := 12 // default conformer chunk size
|
||||
maxPast := 12 // conf_attention_context_left - 1
|
||||
maxFuture := 0 // conf_attention_context_right
|
||||
convKernel := int(c.Uint("audio.conv_kernel_size", 5))
|
||||
|
||||
eps := c.Float("audio.attention.layer_norm_epsilon", 1e-6)
|
||||
|
||||
return &AudioModelOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: headDim,
|
||||
ffnSize: int(c.Uint("audio.feed_forward_length", uint32(hiddenSize*4))),
|
||||
numLayers: int(c.Uint("audio.block_count", 12)),
|
||||
melBins: int(c.Uint("audio.num_mel_bins", 128)),
|
||||
chunkSize: chunkSize,
|
||||
maxPast: maxPast,
|
||||
maxFuture: maxFuture,
|
||||
contextSize: chunkSize + maxPast + maxFuture,
|
||||
logitCap: 50.0,
|
||||
residualWeight: 0.5,
|
||||
gradClip: 1e10,
|
||||
convKernelSize: convKernel,
|
||||
eps: float32(eps),
|
||||
}
|
||||
}
|
||||
|
||||
// buildCausalValidMaskF32 creates the causal-valid mask for block-local attention.
|
||||
// Returns flat [chunkSize * contextSize] float32 data (1.0 = allowed, 0.0 = masked).
|
||||
func buildCausalValidMaskF32(chunkSize, maxPast, maxFuture int) []float32 {
|
||||
contextSize := chunkSize + maxPast + maxFuture
|
||||
upperDiag := maxPast + maxFuture
|
||||
|
||||
result := make([]float32, chunkSize*contextSize)
|
||||
for r := range chunkSize {
|
||||
for c := range contextSize {
|
||||
lower := (r <= c) // tril(contextSize, chunkSize) transposed
|
||||
upper := (c <= r+int(upperDiag)) // tril(chunkSize, contextSize, diag=upperDiag)
|
||||
if lower && upper {
|
||||
result[r*contextSize+c] = 1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
475
model/models/gemma4/model_text.go
Normal file
475
model/models/gemma4/model_text.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
// Cache layer types for the kvcache.WrapperCache: sliding-window attention
// (SWA) layers use one underlying cache, full causal-attention layers another.
const (
	cacheTypeSWA = iota
	cacheTypeCausal
)
|
||||
|
||||
// TextOptions holds the text-model hyperparameters read from the GGUF config.
type TextOptions struct {
	hiddenSize              int // token embedding width
	numHeads, numKVHeads    int // query heads; KV heads on sliding-window layers
	numGlobalKVHeads        int // KV heads on full-attention layers (0 = fall back to numKVHeads)
	headDim, globalHeadDim  int // per-head dims: sliding-window vs. full-attention layers
	hiddenLayers            int // total transformer layer count
	hiddenSizePerLayerInput int // per-layer embedding (PLE) width

	eps               float32 // RMSNorm epsilon
	ropeBase          float32 // RoPE frequency base for full-attention layers
	ropeLocalBase     float32 // RoPE frequency base for sliding-window layers
	partialRotaryDims int     // RoPE dims for full-attention (global) layers

	// slidingWindowPattern flags, per layer, whether the layer uses
	// sliding-window (local) attention.
	slidingWindowPattern []bool
	// kvDonorMap maps shared layer index -> donor layer index.
	// Donor is the last non-shared layer of the same type (sliding/full).
	kvDonorMap map[int]int

	finalLogitSoftcap float32 // tanh softcap on final logits (0 = disabled)

	numExperts     int // MoE expert count (0 = dense model)
	numExpertsUsed int // experts routed per token
}
|
||||
|
||||
func (o *TextOptions) isLocal(layer int) bool {
|
||||
if layer < len(o.slidingWindowPattern) {
|
||||
return o.slidingWindowPattern[layer]
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (o *TextOptions) ropeForLayer(layer int) (base float32, dims int) {
|
||||
if o.isLocal(layer) {
|
||||
return o.ropeLocalBase, o.headDim
|
||||
}
|
||||
return o.ropeBase, o.partialRotaryDims
|
||||
}
|
||||
|
||||
func (o *TextOptions) kvHeadsForLayer(layer int) int {
|
||||
if o.isLocal(layer) {
|
||||
return o.numKVHeads
|
||||
}
|
||||
if o.numGlobalKVHeads > 0 {
|
||||
return o.numGlobalKVHeads
|
||||
}
|
||||
return o.numKVHeads
|
||||
}
|
||||
|
||||
func (o *TextOptions) headDimForLayer(layer int) int {
|
||||
if o.isLocal(layer) {
|
||||
return o.headDim
|
||||
}
|
||||
return o.globalHeadDim
|
||||
}
|
||||
|
||||
// TextModel is the text decoder: token embedding, optional per-layer
// projector (PLE), the transformer layer stack, and the output head.
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	*PerLayerProjector
	Layers     []TextLayer `gguf:"blk"`
	OutputNorm *nn.RMSNorm `gguf:"output_norm"`
	// Output falls back to the (tied) token embedding when no separate
	// output tensor is present in the GGUF.
	Output *nn.Linear `gguf:"output,alt:token_embd"`
	TextOptions
}
|
||||
|
||||
// newTextModel builds the text decoder configuration from GGUF metadata,
// resolving head dims, RoPE bases, per-layer KV head counts, and the KV
// sharing donor map.
func newTextModel(c fs.Config) *TextModel {
	numLayers := int(c.Uint("block_count"))

	// Head dimensions: key_length is global head dim, key_length_swa is local (SWA) head dim.
	globalHeadDim := int(c.Uint("attention.key_length", 512))
	headDim := int(c.Uint("attention.key_length_swa", 256))

	// RoPE dimensions for global (full attention) layers with proportional RoPE.
	// The freq_factors tensor handles partial rotation (1.0 for rotated pairs,
	// 1e30 for non-rotated), so ropeDims equals the full global head dim.
	partialRotaryDims := int(c.Uint("rope.dimension_count", 0))
	if partialRotaryDims == 0 {
		// Derive rotary dims from the partial rotary factor when not explicit.
		partialFactor := c.Float("rope.partial_rotary_factor", 1.0)
		partialRotaryDims = int(float32(globalHeadDim) * partialFactor)
	}

	ropeBase := c.Float("rope.freq_base", 1000000.0)
	ropeLocalBase := c.Float("rope.freq_base_swa", 0)
	if ropeLocalBase == 0 {
		// Older metadata key for the sliding-window RoPE base.
		ropeLocalBase = c.Float("rope.local.freq_base", 10000.0)
	}

	numGlobalKVHeads := int(c.Uint("attention.global_head_count_kv", 0))
	slidingPattern := c.Bools("attention.sliding_window_pattern")

	// KV heads: try per-layer array first (MoE models), then fall back to scalar
	numKVHeads := 0
	kvHeadsArray := c.Ints("attention.head_count_kv")
	if len(kvHeadsArray) > 0 {
		numKVHeads = int(kvHeadsArray[0])
		// Derive the global KV head count from the first full-attention layer
		// in the pattern when it was not given explicitly.
		if numGlobalKVHeads == 0 && len(slidingPattern) > 0 {
			for i, isLocal := range slidingPattern {
				if !isLocal && i < len(kvHeadsArray) {
					numGlobalKVHeads = int(kvHeadsArray[i])
					break
				}
			}
		}
	}
	if numKVHeads == 0 {
		numKVHeads = int(c.Uint("attention.head_count_kv", 0))
	}

	// Compute KV sharing donor map (same logic as MLX): the trailing
	// shared_kv_layers layers reuse K/V from the last non-shared layer of the
	// same attention type (sliding vs. full).
	sharedLayers := int(c.Uint("attention.shared_kv_layers", 0))
	kvDonorMap := make(map[int]int)
	if sharedLayers > 0 && len(slidingPattern) > 0 {
		firstShared := numLayers - sharedLayers
		for i := firstShared; i < numLayers; i++ {
			isLocal := slidingPattern[i]
			// Find last non-shared layer of same type
			for j := firstShared - 1; j >= 0; j-- {
				if slidingPattern[j] == isLocal {
					kvDonorMap[i] = j
					break
				}
			}
		}
	}

	return &TextModel{
		Layers: make([]TextLayer, numLayers),
		TextOptions: TextOptions{
			hiddenSize:              int(c.Uint("embedding_length")),
			numHeads:                int(c.Uint("attention.head_count")),
			numKVHeads:              numKVHeads,
			numGlobalKVHeads:        numGlobalKVHeads,
			headDim:                 headDim,
			globalHeadDim:           globalHeadDim,
			hiddenLayers:            numLayers,
			hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input", 0)),
			eps:                     c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeBase:                ropeBase,
			ropeLocalBase:           ropeLocalBase,
			partialRotaryDims:       partialRotaryDims,
			slidingWindowPattern:    slidingPattern,
			kvDonorMap:              kvDonorMap,
			finalLogitSoftcap:       c.Float("final_logit_softcapping", 0.0),
			numExperts:              int(c.Uint("expert_count", 0)),
			numExpertsUsed:          int(c.Uint("expert_used_count", 0)),
		},
	}
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
||||
|
||||
// Inject vision embeddings into the hidden state
|
||||
var except []int
|
||||
for _, image := range batch.Multimodal {
|
||||
visionOutputs := image.Multimodal[0].Tensor
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
for i := range visionOutputs.Dim(1) {
|
||||
except = append(except, image.Index+i)
|
||||
}
|
||||
}
|
||||
|
||||
// PLE
|
||||
var perLayerInputs ml.Tensor
|
||||
if m.PerLayerProjector != nil {
|
||||
perLayerInputs = m.PerLayerProjector.Forward(ctx, batch, hiddenState, &m.TextOptions)
|
||||
}
|
||||
|
||||
for i := range len(m.Layers) {
|
||||
layer := m.Layers[i]
|
||||
if cache != nil {
|
||||
cache.SetLayer(i)
|
||||
cacheType := cacheTypeSWA
|
||||
if !m.isLocal(i) {
|
||||
cacheType = cacheTypeCausal
|
||||
}
|
||||
wc := cache.(*kvcache.WrapperCache)
|
||||
wc.SetLayerType(cacheType)
|
||||
|
||||
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
||||
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
}
|
||||
}
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
var perLayerInput ml.Tensor
|
||||
if perLayerInputs != nil {
|
||||
perLayerInput = perLayerInputs.View(ctx, i*perLayerInputs.Stride(1), perLayerInputs.Dim(0), perLayerInputs.Stride(2), perLayerInputs.Dim(2))
|
||||
}
|
||||
|
||||
// KV sharing: layers >= firstShared reuse K/V from donor layers
|
||||
isShared := false
|
||||
if donorLayer, ok := m.kvDonorMap[i]; ok {
|
||||
// Set cache layer to donor so Get() reads donor's K/V
|
||||
cache.SetLayer(donorLayer)
|
||||
isShared = true
|
||||
}
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, perLayerInput, lastLayerOutputs, cache, isShared, &m.TextOptions)
|
||||
}
|
||||
|
||||
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
}
|
||||
|
||||
// PerLayerProjector implements PLE (per-layer embeddings): it combines a
// dedicated per-layer token embedding with a normed projection of the main
// embeddings into one small input vector per transformer layer.
type PerLayerProjector struct {
	TokenEmbedding *nn.Embedding `gguf:"per_layer_token_embd"`
	Projector      *nn.Linear    `gguf:"per_layer_model_proj"`
	Norm           *nn.RMSNorm   `gguf:"per_layer_proj_norm"`
}
|
||||
|
||||
func (p *PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
inputsPerLayer = inputsPerLayer.Scale(ctx, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
|
||||
// Reshape to [pleDim, numLayers, numTokens] — matching projection shape
|
||||
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
|
||||
perLayerProjection := p.Projector.Forward(ctx, inputs)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
|
||||
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
|
||||
|
||||
if inputsPerLayer != nil {
|
||||
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
|
||||
}
|
||||
|
||||
return perLayerProjection
|
||||
}
|
||||
|
||||
// TextSelfAttention holds the attention projections for one layer. Value may
// be absent in converted models (K is then reused as V — see Forward).
type TextSelfAttention struct {
	Query       *nn.Linear  `gguf:"attn_q"`
	QueryNorm   *nn.RMSNorm `gguf:"attn_q_norm"`
	Key         *nn.Linear  `gguf:"attn_k"`
	KeyNorm     *nn.RMSNorm `gguf:"attn_k_norm"`
	Value       *nn.Linear  `gguf:"attn_v"`
	Output      *nn.Linear  `gguf:"attn_output"`
	RopeFactors ml.Tensor   `gguf:"rope_freqs.weight"` // proportional RoPE freq_factors
}
|
||||
|
||||
// Forward computes self-attention for one layer. When sharedKV is true, K/V
// are not computed here; the caller has repointed the cache at the donor
// layer, so nn.Attention reads the donor's cached K/V instead.
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	// Per-layer geometry: head dim, KV head count, and RoPE base/dims all
	// depend on whether this is a sliding-window or full-attention layer.
	hd := opts.headDimForLayer(layer)
	kvHeads := opts.kvHeadsForLayer(layer)
	ropeBase, ropeDims := opts.ropeForLayer(layer)

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, hd, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)

	var k, v ml.Tensor
	if !sharedKV {
		k = sa.Key.Forward(ctx, hiddenState)
		k = k.Reshape(ctx, hd, kvHeads, batchSize)

		if sa.Value != nil {
			v = sa.Value.Forward(ctx, hiddenState)
			v = v.Reshape(ctx, hd, kvHeads, batchSize)
		} else {
			// K=V: use raw K projection (before K norm) as V
			v = k
		}

		k = sa.KeyNorm.Forward(ctx, k, opts.eps)
		v = v.RMSNorm(ctx, nil, opts.eps) // V norm: unweighted RMSNorm
	}

	// RoPE with proportional freq_factors on global layers
	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
	if sa.RopeFactors != nil && !opts.isLocal(layer) {
		ropeOpts = append(ropeOpts, rope.WithFactors(sa.RopeFactors))
	}
	q = nn.RoPE(ctx, q, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
	if k != nil {
		// k is nil on shared-KV layers; the cached (already-roped) donor K is used.
		k = nn.RoPE(ctx, k, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
	}

	// Attention scale is 1.0 here — NOTE(review): presumably the 1/sqrt(d)
	// factor is folded into the weights/norms upstream; confirm against the
	// reference implementation.
	attention := nn.Attention(ctx, q, k, v, 1.0, cache)

	attention = attention.Reshape(ctx, hd*opts.numHeads, batchSize)
	return sa.Output.Forward(ctx, attention)
}
|
||||
|
||||
// TextMLP is the dense gated feedforward block (gate/up/down projections).
type TextMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
// TextRouter implements the Gemma 4 MoE router: a normed, scaled projection
// of the hidden state into per-expert logits.
type TextRouter struct {
	Proj  *nn.Linear `gguf:"ffn_gate_inp"`
	Scale ml.Tensor  `gguf:"ffn_gate_inp.scale"` // learned input scale applied before projection
}
|
||||
|
||||
func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) (routingWeights, selectedExperts ml.Tensor) {
|
||||
// RMSNorm without learned weight
|
||||
x := hiddenState.RMSNorm(ctx, nil, opts.eps)
|
||||
// Scale by 1/sqrt(hidden_size)
|
||||
x = x.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
|
||||
// Multiply by learned scale parameter
|
||||
x = x.Mul(ctx, r.Scale)
|
||||
// Project to expert logits
|
||||
expertScores := r.Proj.Forward(ctx, x)
|
||||
// Softmax over experts
|
||||
routingWeights = expertScores.Softmax(ctx)
|
||||
// TopK expert selection
|
||||
selectedExperts = routingWeights.TopK(ctx, opts.numExpertsUsed)
|
||||
return routingWeights, selectedExperts
|
||||
}
|
||||
|
||||
// TextMoEBlock implements the Gemma 4 sparse MoE expert bank. Converted
// models carry either a fused GateUp tensor or separate Gate/Up tensors.
type TextMoEBlock struct {
	GateUp    *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
	Gate      *nn.LinearBatch `gguf:"ffn_gate_exps"`
	Up        *nn.LinearBatch `gguf:"ffn_up_exps"`
	Down      *nn.LinearBatch `gguf:"ffn_down_exps"`
	DownScale ml.Tensor       `gguf:"ffn_down_exps.scale,alt:ffn_gate_inp.per_expert_scale"` // optional per-expert output scale
}
|
||||
|
||||
// Forward runs the sparse MoE over the selected experts and combines their
// outputs weighted by the (renormalized) routing probabilities.
func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, opts *TextOptions) ml.Tensor {
	// Select routing weights for chosen experts and renormalize so the
	// selected weights sum to 1 per token.
	routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
	routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
	routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
	routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))

	hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))

	// Expert computation using LinearBatch (MulmatID selecting experts by index).
	// Handle both fused gate+up and split gate/up tensor layouts.
	var gateOut, upOut ml.Tensor
	if moe.GateUp != nil && moe.GateUp.Weight != nil {
		gateUp := moe.GateUp.Forward(ctx, hiddenState, selectedExperts)
		nFF := gateUp.Dim(0) / 2
		gateOut = gateUp.Slice(ctx, 0, 0, nFF, 1)
		upOut = gateUp.Slice(ctx, 0, nFF, gateUp.Dim(0), 1)
	} else {
		gateOut = moe.Gate.Forward(ctx, hiddenState, selectedExperts)
		upOut = moe.Up.Forward(ctx, hiddenState, selectedExperts)
	}
	hiddenState = gateOut.GELU(ctx, upOut)
	experts := moe.Down.Forward(ctx, hiddenState, selectedExperts)

	// Apply per-expert down projection scale when present.
	if moe.DownScale != nil {
		expertScales := moe.DownScale.Reshape(ctx, opts.numExperts, 1)
		expertScales = expertScales.Repeat(ctx, 1, hiddenState.Dim(2))
		expertScales = expertScales.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(2)).Rows(ctx, selectedExperts)
		expertScales = expertScales.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(2))
		expertScales = expertScales.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(2))
		experts = experts.Mul(ctx, expertScales)
	}

	// Apply routing weights
	experts = experts.Mul(ctx, routingWeights)

	// Sum across experts via strided views of the experts tensor.
	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
	for i := 1; i < opts.numExpertsUsed; i++ {
		nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
	}

	return nextStates
}
|
||||
|
||||
// TextLayer is one transformer block: self-attention and a dense MLP, with an
// optional parallel MoE branch and per-layer-embedding (PLE) injection tensors.
type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm,alt:attn_post_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm,alt:ffn_pre_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm,alt:ffn_post_norm"`

	// MoE (present only for models with enable_moe_block=true)
	Router       *TextRouter
	MoE          *TextMoEBlock
	MoENorm      *nn.RMSNorm `gguf:"pre_ffw_norm_2,alt:ffn_pre_norm_2"`
	PostMoENorm  *nn.RMSNorm `gguf:"post_ffw_norm_2,alt:ffn_post_norm_2"`
	PostMLPNorm1 *nn.RMSNorm `gguf:"post_ffw_norm_1,alt:ffn_post_norm_1"` // used instead of PostMLPNorm when MoE is present

	// Per-layer embedding (PLE) injection tensors.
	PerLayerInputGate  *nn.Linear  `gguf:"inp_gate"`
	PerLayerProjection *nn.Linear  `gguf:"proj"`
	PostPerLayerNorm   *nn.RMSNorm `gguf:"post_norm"`
	LayerScalar        ml.Tensor   `gguf:"layer_scalar,alt:layer_output_scale.weight"`
}
|
||||
|
||||
// Forward applies one transformer layer: self-attention with residual, MLP
// (optionally in parallel with an MoE branch) with residual, per-layer-input
// (PLE) injection, and an optional final layer scalar. outputs, when non-nil,
// selects the row subset kept after attention so later work only covers the
// positions whose logits are needed.
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, perLayerInput, outputs ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positions, cache, sharedKV, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// Trim to the requested output rows; residual and per-layer input must
	// be trimmed identically so the additions below stay aligned.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
		if perLayerInput != nil {
			perLayerInput = perLayerInput.Rows(ctx, outputs)
		}
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	// MLP (+ optional MoE in parallel). Expert weights may be stored either
	// as split gate/up tensors or as one fused gate-up tensor.
	hasSplitExperts := l.MoE != nil && l.MoE.Gate != nil && l.MoE.Up != nil && l.MoE.Gate.Weight != nil && l.MoE.Up.Weight != nil
	hasFusedExperts := l.MoE != nil && l.MoE.GateUp != nil && l.MoE.GateUp.Weight != nil
	if l.Router != nil && l.MoE != nil && l.MoE.Down != nil && l.MoE.Down.Weight != nil && (hasSplitExperts || hasFusedExperts) {
		// MoE layers: run MLP and MoE in parallel, sum results.
		// Note the dense branch uses PostMLPNorm1 here, not PostMLPNorm.
		mlpState := l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
		mlpState = l.MLP.Forward(ctx, mlpState)
		mlpState = l.PostMLPNorm1.Forward(ctx, mlpState, opts.eps)

		routingWeights, selectedExperts := l.Router.Forward(ctx, hiddenState, opts)
		moeState := l.MoENorm.Forward(ctx, hiddenState, opts.eps)
		moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, opts)
		moeState = l.PostMoENorm.Forward(ctx, moeState, opts.eps)

		// Combine MLP + MoE, apply outer post-FFN norm, then add residual
		combined := mlpState.Add(ctx, moeState)
		combined = l.PostMLPNorm.Forward(ctx, combined, opts.eps)
		hiddenState = combined.Add(ctx, residual)
	} else {
		// Dense layers: MLP only
		hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
		hiddenState = l.MLP.Forward(ctx, hiddenState)
		hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
		hiddenState = hiddenState.Add(ctx, residual)
	}

	// PLE injection (after MLP residual)
	if perLayerInput != nil && l.PerLayerInputGate != nil {
		pleState := l.PerLayerInputGate.Forward(ctx, hiddenState)
		pleState = pleState.GELU(ctx, perLayerInput)
		pleState = l.PerLayerProjection.Forward(ctx, pleState)
		pleState = l.PostPerLayerNorm.Forward(ctx, pleState, opts.eps)
		hiddenState = hiddenState.Add(ctx, pleState)
	}

	// Layer scalar applied at end of layer (full-attention layers only)
	if l.LayerScalar != nil {
		hiddenState = hiddenState.Mul(ctx, l.LayerScalar)
	}

	return hiddenState
}
|
||||
392
model/models/gemma4/model_vision.go
Normal file
392
model/models/gemma4/model_vision.go
Normal file
@@ -0,0 +1,392 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
// batchSize is the fixed batch dimension used throughout the vision encoder graph.
const batchSize = 1
|
||||
|
||||
// ClippableLinear is a linear layer with optional input/output clamping.
// Required by the Gemma4 vision encoder for numerical stability with F16
// weights: activations are clamped before and after the matmul.
type ClippableLinear struct {
	Weight ml.Tensor `gguf:"weight"`

	// Optional scalar tensors holding the clamp bounds.
	InputMin  ml.Tensor `gguf:"input_min"`
	InputMax  ml.Tensor `gguf:"input_max"`
	OutputMin ml.Tensor `gguf:"output_min"`
	OutputMax ml.Tensor `gguf:"output_max"`

	// Cached bounds extracted from the tensors above, or from the packed
	// v.clamp_data tensor (see VisionModel.InitClamp).
	inMin, inMax, outMin, outMax float32
	hasClamp                     bool // true once any bound is known
	clampsLoaded                 bool // guards one-time scalar extraction
}
|
||||
|
||||
func scalarValue(t ml.Tensor) (float32, bool) {
|
||||
if t == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
data := t.BackendGet()
|
||||
if len(data) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return data[0], true
|
||||
}
|
||||
|
||||
func (l *ClippableLinear) loadClampFromScalars() {
|
||||
if l.clampsLoaded {
|
||||
return
|
||||
}
|
||||
l.clampsLoaded = true
|
||||
|
||||
const (
|
||||
defaultMin = -math.MaxFloat32
|
||||
defaultMax = math.MaxFloat32
|
||||
)
|
||||
|
||||
inMin, hasInMin := scalarValue(l.InputMin)
|
||||
inMax, hasInMax := scalarValue(l.InputMax)
|
||||
outMin, hasOutMin := scalarValue(l.OutputMin)
|
||||
outMax, hasOutMax := scalarValue(l.OutputMax)
|
||||
|
||||
if !(hasInMin || hasInMax || hasOutMin || hasOutMax) {
|
||||
return
|
||||
}
|
||||
|
||||
l.hasClamp = true
|
||||
l.inMin = defaultMin
|
||||
l.inMax = defaultMax
|
||||
l.outMin = defaultMin
|
||||
l.outMax = defaultMax
|
||||
|
||||
if hasInMin {
|
||||
l.inMin = inMin
|
||||
}
|
||||
if hasInMax {
|
||||
l.inMax = inMax
|
||||
}
|
||||
if hasOutMin {
|
||||
l.outMin = outMin
|
||||
}
|
||||
if hasOutMax {
|
||||
l.outMax = outMax
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
|
||||
if l.hasClamp {
|
||||
x = x.Clamp(ctx, l.inMin, l.inMax)
|
||||
}
|
||||
out := l.Weight.Mulmat(ctx, x)
|
||||
if l.hasClamp {
|
||||
out = out.Clamp(ctx, l.outMin, l.outMax)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// InitClamp distributes packed clamp values from v.clamp_data to ClippableLinear structs.
// If scalar clamp tensors (input_min/max, output_min/max) are present, they are used too.
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
// then 4 floats for the projector. Runs at most once per model.
func (m *VisionModel) InitClamp(proj *MultiModalProjector) {
	if m.clampInitDone {
		return
	}
	m.clampInitDone = true

	// linears lists a layer's clippable projections in packed-data order.
	linears := func(l *VisionEncoderLayer) []*ClippableLinear {
		return []*ClippableLinear{
			l.SelfAttention.Query, l.SelfAttention.Key, l.SelfAttention.Value,
			l.SelfAttention.Output, l.MLP.Gate, l.MLP.Up, l.MLP.Down,
		}
	}

	// First, pick up any per-linear scalar clamp tensors.
	for i := range m.Layers {
		for _, cl := range linears(&m.Layers[i]) {
			if cl != nil {
				cl.loadClampFromScalars()
			}
		}
	}
	if proj != nil && proj.Projection != nil {
		proj.Projection.loadClampFromScalars()
	}

	// Load packed clamp data when present (legacy Ollama format).
	if m.ClampData == nil {
		return
	}

	// Read all clamp values from packed F32 tensor
	data := m.ClampData.BackendGet()
	if len(data) == 0 {
		return
	}

	// Distribute to layer linears: 7 per layer × 4 values each.
	// Packed values overwrite any bounds loaded from scalars above.
	for i := range m.Layers {
		for li, cl := range linears(&m.Layers[i]) {
			if cl == nil {
				continue
			}
			idx := (i*7 + li) * 4
			if idx+3 < len(data) {
				cl.inMin = data[idx]
				cl.inMax = data[idx+1]
				cl.outMin = data[idx+2]
				cl.outMax = data[idx+3]
				cl.hasClamp = true
			}
		}
	}

	// Projector clamp values (last 4 floats)
	if proj != nil && proj.Projection != nil {
		projIdx := len(m.Layers) * 7 * 4
		if projIdx+3 < len(data) {
			proj.Projection.inMin = data[projIdx]
			proj.Projection.inMax = data[projIdx+1]
			proj.Projection.outMin = data[projIdx+2]
			proj.Projection.outMax = data[projIdx+3]
			proj.Projection.hasClamp = true
		}
	}
}
|
||||
|
||||
// VisionSelfAttention holds the attention projections for one vision encoder
// layer. All projections are clamp-capable for F16 numerical stability.
type VisionSelfAttention struct {
	Query     *ClippableLinear `gguf:"attn_q"`
	Key       *ClippableLinear `gguf:"attn_k"`
	Value     *ClippableLinear `gguf:"attn_v"`
	QueryNorm *nn.RMSNorm      `gguf:"attn_q_norm"`
	KeyNorm   *nn.RMSNorm      `gguf:"attn_k_norm"`
	Output    *ClippableLinear `gguf:"attn_out"`
}
|
||||
|
||||
// Forward computes multi-head self-attention over image patches with 2D RoPE:
// the head dimension is split in half, the first half rotated by x positions
// and the second half by y positions, before scaled dot-product attention.
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	numPatches := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads

	query := sa.Query.Forward(ctx, hiddenState)
	key := sa.Key.Forward(ctx, hiddenState)
	value := sa.Value.Forward(ctx, hiddenState)

	// Split into heads: [headDim, numHeads, numPatches, 1].
	query = query.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
	key = key.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
	value = value.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)

	// Q/K norms (Gemma-style: x * (1 + weight) / rms(x))
	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
	key = sa.KeyNorm.Forward(ctx, key, opts.eps)

	// V norm (RMSNorm without learned weights)
	value = value.RMSNorm(ctx, nil, opts.eps)

	// 2D RoPE: split head dim in half, apply NeoX RoPE with x positions to first half,
	// y positions to second half, then concatenate.
	halfDim := headDim / 2
	ropeOpts := rope.WithTypeNeoX()

	qFirst := query.View(ctx, 0, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
	qFirst = nn.RoPE(ctx, qFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)

	kFirst := key.View(ctx, 0, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
	kFirst = nn.RoPE(ctx, kFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)

	// Second half starts halfDim elements into each head.
	halfOffset := halfDim * query.Stride(0)
	qSecond := query.View(ctx, halfOffset, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
	qSecond = nn.RoPE(ctx, qSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)

	halfOffsetK := halfDim * key.Stride(0)
	kSecond := key.View(ctx, halfOffsetK, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
	kSecond = nn.RoPE(ctx, kSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)

	query = qFirst.Concat(ctx, qSecond, 0)
	key = kFirst.Concat(ctx, kSecond, 0)

	// Use flash attention for numerical stability (handles large attention scores
	// from unclamped RMSNorm weights, e.g. 26B has addOne weights up to 19.5).
	// NOTE(review): attnMask is intentionally ignored here — a nil mask means
	// full attention over all patches.
	attention := nn.Attention(ctx, query, key, value, 1.0, nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)

	return sa.Output.Forward(ctx, attention)
}
|
||||
|
||||
// VisionMLP is the gated feed-forward block of a vision encoder layer.
type VisionMLP struct {
	Gate *ClippableLinear `gguf:"ffn_gate"`
	Up   *ClippableLinear `gguf:"ffn_up"`
	Down *ClippableLinear `gguf:"ffn_down"`
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
gate := mlp.Gate.Forward(ctx, hiddenState)
|
||||
up := mlp.Up.Forward(ctx, hiddenState)
|
||||
hiddenState = gate.QuickGELU(ctx, up)
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
// VisionEncoderLayer is one pre/post-normed transformer block of the vision
// encoder, with an optional learned output scale.
type VisionEncoderLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"ln1"`
	SelfAttention     *VisionSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"attn_post_norm"`

	FFNNorm     *nn.RMSNorm `gguf:"ln2"`
	MLP         *VisionMLP
	PostFFNNorm *nn.RMSNorm `gguf:"ffn_post_norm"`

	LayerOutputScale ml.Tensor `gguf:"out_scale.weight"` // optional per-layer scale on the block output
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
// Pre-attention norm -> self attention -> post-attention norm
|
||||
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, posX, posY, attnMask, opts)
|
||||
hiddenState = e.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
// Residual connection
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// Pre-FFN norm -> FFN -> post-FFN norm
|
||||
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.MLP.Forward(ctx, hiddenState)
|
||||
hiddenState = e.PostFFNNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
// Residual connection
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
|
||||
// Per-layer output scale
|
||||
if e.LayerOutputScale != nil {
|
||||
hiddenState = hiddenState.Mul(ctx, e.LayerOutputScale)
|
||||
}
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
// VisionModelOptions holds vision encoder hyperparameters read from the GGUF
// config at load time.
type VisionModelOptions struct {
	hiddenSize int     // embedding width of the encoder
	numHeads   int     // attention head count
	patchSize  int     // pixels per patch edge
	nMerge     int     // pooling factor applied after the encoder
	eps        float32 // RMSNorm epsilon
	ropeTheta  float32 // base frequency for 2D RoPE
}
|
||||
|
||||
// VisionModel is the Gemma4 vision tower: a Conv2D patch embedder, learned 2D
// position embeddings, and a stack of RMSNorm transformer encoder layers.
type VisionModel struct {
	PatchEmbedding    *nn.Conv2D `gguf:"patch_embd"`
	PositionEmbedding ml.Tensor  `gguf:"position_embd.weight"`
	ClampData         ml.Tensor  `gguf:"clamp_data"` // packed clamp bounds (legacy format); see InitClamp
	StdBias           ml.Tensor  `gguf:"std_bias"`   // optional standardization bias used before projection
	StdScale          ml.Tensor  `gguf:"std_scale"`  // optional standardization scale used before projection

	Layers []VisionEncoderLayer `gguf:"blk"`

	*VisionModelOptions
	clampInitDone bool // guards one-time InitClamp work
}
|
||||
|
||||
// Forward embeds pixelValues into a sequence of patch embeddings (Conv2D patch
// embedding plus learned 2D position embeddings) and runs them through the
// encoder layers. numPatchesX and numPatchesY describe the input patch grid.
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, numPatchesX, numPatchesY int) ml.Tensor {
	numPatches := numPatchesX * numPatchesY

	// Patch embedding via Conv2D
	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	// Conv2D with F16 weights produces F16 output via im2col; cast to F32 for encoder precision
	hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)

	// 2D positional embeddings from 3D tensor [nEmbd, maxPos, 2]:
	// the first slab holds x-position embeddings, the second y-position ones.
	posSize := m.PositionEmbedding.Dim(1)
	nb1 := m.PositionEmbedding.Stride(1)
	tblX := m.PositionEmbedding.View(ctx, 0, m.hiddenSize, nb1, posSize)
	tblY := m.PositionEmbedding.View(ctx, posSize*nb1, m.hiddenSize, nb1, posSize)

	// Position indices for patches, row-major over the grid.
	posXData := make([]int32, numPatches)
	posYData := make([]int32, numPatches)
	for i := range numPatches {
		posXData[i] = int32(i % numPatchesX)
		posYData[i] = int32(i / numPatchesX)
	}

	posXEmb := ctx.Input().FromInts(posXData, numPatches)
	posYEmb := ctx.Input().FromInts(posYData, numPatches)

	hiddenState = hiddenState.Add(ctx, tblX.Rows(ctx, posXEmb))
	hiddenState = hiddenState.Add(ctx, tblY.Rows(ctx, posYEmb))

	// No attention mask — all positions are real patches
	var attnMask ml.Tensor

	// RoPE positions. NOTE(review): these duplicate posXEmb/posYEmb above;
	// reusing those tensors looks possible — confirm the backend allows one
	// input tensor in both roles before consolidating.
	posXRope := ctx.Input().FromInts(posXData, numPatches)
	posYRope := ctx.Input().FromInts(posYData, numPatches)

	// Vision transformer layers
	for i := range m.Layers {
		hiddenState = m.Layers[i].Forward(ctx, hiddenState, posXRope, posYRope, attnMask, m.VisionModelOptions)
	}

	return hiddenState
}
|
||||
|
||||
// newVisionModel constructs the vision tower from the GGUF vision config.
// ropeTheta is fixed at 100.0 to match the reference implementation.
func newVisionModel(c fs.Config) *VisionModel {
	return &VisionModel{
		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
		VisionModelOptions: &VisionModelOptions{
			hiddenSize: int(c.Uint("vision.embedding_length")),
			numHeads:   int(c.Uint("vision.attention.head_count")),
			patchSize:  int(c.Uint("vision.patch_size", 16)),
			nMerge:     int(c.Uint("vision.projector.scale_factor", 3)),
			eps:        c.Float("vision.attention.layer_norm_epsilon", 1e-6),
			ropeTheta:  100.0,
		},
	}
}
|
||||
|
||||
// visionTokenCount returns the number of soft tokens an image produces: the
// patch count along each axis, divided by the merge factor, multiplied
// together. All divisions truncate.
func visionTokenCount(imageWidth, imageHeight, patchSize, nMerge int) int {
	tokensX := imageWidth / patchSize / nMerge
	tokensY := imageHeight / patchSize / nMerge
	return tokensX * tokensY
}
|
||||
|
||||
// visionPoolAndProject average-pools the encoder output down by nMerge along
// both spatial axes, rescales it, optionally standardizes it, and projects it
// into the text embedding space.
func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, numPatchesY int, opts *VisionModelOptions, proj *MultiModalProjector, stdBias, stdScale ml.Tensor) ml.Tensor {
	hiddenSize := opts.hiddenSize

	// Reshape from [hiddenSize, numPatches] to spatial layout for pooling
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	hiddenState = hiddenState.Reshape(ctx, numPatchesX, numPatchesY, hiddenSize)

	// AvgPool2D with kernel=stride=nMerge
	hiddenState = hiddenState.AvgPool2D(ctx, opts.nMerge, opts.nMerge, 0)

	// Reshape back to [hiddenSize, numMergedPatches]
	mergedX := numPatchesX / opts.nMerge
	mergedY := numPatchesY / opts.nMerge
	hiddenState = hiddenState.Reshape(ctx, mergedX*mergedY, hiddenSize)
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	// Rescale by sqrt(hiddenSize) in F32 before projection.
	hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(hiddenSize)))

	// Optional vision standardization before projection.
	if stdBias != nil && stdScale != nil {
		hiddenState = hiddenState.Sub(ctx, stdBias)
		hiddenState = hiddenState.Mul(ctx, stdScale)
	}

	// Project to text embedding dimension
	hiddenState = proj.Forward(ctx, hiddenState, opts.eps)

	return hiddenState
}
|
||||
331
model/models/gemma4/process_audio.go
Normal file
331
model/models/gemma4/process_audio.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/cmplx"
|
||||
)
|
||||
|
||||
// Audio preprocessing constants.
const (
	audioSampleRate = 16000  // expected PCM rate in Hz; WAV input is resampled to this
	melBins         = 128    // mel filterbank size
	frameLengthMs   = 20.0   // STFT window length in milliseconds
	hopLengthMs     = 10.0   // STFT hop in milliseconds
	minFrequency    = 0.0    // lowest mel band edge (Hz)
	maxFrequency    = 8000.0 // highest mel band edge (Hz; Nyquist at 16 kHz)
	melFloor        = 1e-3   // lower bound applied before the log to avoid log(0)
	maxAudioSoftTokens = 750

	// Chunking parameters for long audio.
	maxChunkSamples   = 28 * audioSampleRate // 28s target (headroom below 30s cap)
	minChunkSamples   = 20 * audioSampleRate // don't scan for silence before 20s
	silenceWindowSize = 800                  // 50ms at 16kHz for RMS window
)

// Computed from the above constants.
var (
	frameLength = int(math.Round(audioSampleRate * frameLengthMs / 1000.0)) // 320
	hopLength   = int(math.Round(audioSampleRate * hopLengthMs / 1000.0))   // 160
)
|
||||
|
||||
// decodeWAV extracts mono float32 PCM samples from a WAV file, resampled to 16kHz.
// It walks the RIFF chunks to find "fmt " and "data", accepts PCM (format 1)
// and IEEE float (format 3) encodings — including WAVE_FORMAT_EXTENSIBLE
// wrappers — downmixes multi-channel audio by averaging, and resamples when
// the file's rate differs from audioSampleRate.
func decodeWAV(data []byte) ([]float32, error) {
	if len(data) < 12 {
		return nil, fmt.Errorf("WAV file too short")
	}
	if string(data[0:4]) != "RIFF" || string(data[8:12]) != "WAVE" {
		return nil, fmt.Errorf("not a WAV file")
	}

	var audioFormat uint16
	var numChannels, sampleRate, bitsPerSample int
	var audioData []byte
	foundFmt := false

	// Chunk walk: each chunk is an 8-byte header (4-byte ID + 4-byte size)
	// followed by the payload, padded to an even byte boundary.
	offset := 12
	for offset+8 <= len(data) {
		chunkID := string(data[offset : offset+4])
		chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
		// Clamp to the buffer so a lying chunk size can't slice out of range.
		chunkData := data[offset+8 : min(offset+8+chunkSize, len(data))]

		switch chunkID {
		case "fmt ":
			if len(chunkData) < 16 {
				return nil, fmt.Errorf("fmt chunk too short")
			}
			audioFormat = binary.LittleEndian.Uint16(chunkData[0:2])
			numChannels = int(binary.LittleEndian.Uint16(chunkData[2:4]))
			sampleRate = int(binary.LittleEndian.Uint32(chunkData[4:8]))
			bitsPerSample = int(binary.LittleEndian.Uint16(chunkData[14:16]))
			// WAVE_FORMAT_EXTENSIBLE: the real format tag is the first two
			// bytes of the extension's SubFormat GUID.
			if audioFormat == 0xFFFE && len(chunkData) >= 26 {
				audioFormat = binary.LittleEndian.Uint16(chunkData[24:26])
			}
			foundFmt = true
		case "data":
			audioData = chunkData
		}

		offset += 8 + chunkSize
		if chunkSize%2 != 0 {
			offset++ // chunks are word-aligned; skip the pad byte
		}
	}

	if !foundFmt {
		return nil, fmt.Errorf("no fmt chunk found in WAV file")
	}
	if audioFormat != 1 && audioFormat != 3 {
		return nil, fmt.Errorf("unsupported WAV format: %d (need PCM=1 or float=3)", audioFormat)
	}
	if audioData == nil {
		return nil, fmt.Errorf("no data chunk found in WAV file")
	}

	samples := decodeWAVSamples(audioData, audioFormat, bitsPerSample, numChannels)
	if sampleRate != audioSampleRate {
		samples = resampleLinear(samples, sampleRate, audioSampleRate)
	}
	return samples, nil
}
|
||||
|
||||
// decodeWAVSamples converts raw interleaved WAV sample bytes into mono
// float32 values in [-1, 1], averaging across channels. Supported encodings:
// PCM (format 1) at 8/16/24/32 bits and IEEE float (format 3) at 32 bits;
// anything else contributes zero.
func decodeWAVSamples(data []byte, format uint16, bits, channels int) []float32 {
	step := bits / 8
	frames := len(data) / (step * channels)
	out := make([]float32, frames)

	// sampleAt decodes the single sample starting at byte offset off.
	sampleAt := func(off int) float64 {
		switch {
		case format == 1 && bits == 16:
			return float64(int16(binary.LittleEndian.Uint16(data[off:off+2]))) / 32768.0
		case format == 1 && bits == 32:
			return float64(int32(binary.LittleEndian.Uint32(data[off:off+4]))) / 2147483648.0
		case format == 1 && bits == 24:
			// 24-bit little-endian with manual sign extension.
			v := int32(data[off]) | int32(data[off+1])<<8 | int32(data[off+2])<<16
			if v&0x800000 != 0 {
				v |= ^0xFFFFFF
			}
			return float64(v) / 8388608.0
		case format == 3 && bits == 32:
			return float64(math.Float32frombits(binary.LittleEndian.Uint32(data[off : off+4])))
		case format == 1 && bits == 8:
			// 8-bit PCM is unsigned, biased at 128.
			return (float64(data[off]) - 128.0) / 128.0
		}
		return 0
	}

	for i := range out {
		var acc float64
		for ch := 0; ch < channels; ch++ {
			off := (i*channels + ch) * step
			if off+step > len(data) {
				break
			}
			acc += sampleAt(off)
		}
		out[i] = float32(acc / float64(channels))
	}
	return out
}
|
||||
|
||||
func resampleLinear(samples []float32, fromRate, toRate int) []float32 {
|
||||
n := int(float64(len(samples)) / float64(fromRate) * float64(toRate))
|
||||
out := make([]float32, n)
|
||||
for i := range n {
|
||||
pos := float64(i) * float64(len(samples)-1) / float64(n-1)
|
||||
idx := int(pos)
|
||||
frac := float32(pos - float64(idx))
|
||||
if idx+1 < len(samples) {
|
||||
out[i] = samples[idx]*(1-frac) + samples[idx+1]*frac
|
||||
} else {
|
||||
out[i] = samples[idx]
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// computeMelSpectrogram computes the log mel spectrogram from PCM samples.
// Returns shape [numFrames, melBins] as float32 slice, and numFrames.
// Frames of frameLength samples are taken every hopLength samples, windowed,
// zero-padded to a power-of-two FFT length (doubled once for fft_overdrive),
// transformed, passed through a mel filterbank, floored, and logged.
func computeMelSpectrogram(samples []float32) ([]float32, int) {
	// Smallest power of two >= frameLength, then doubled.
	fftLen := 1
	for fftLen < frameLength {
		fftLen <<= 1
	}
	fftLen *= 2 // fft_overdrive=True

	// Hanning-nonzero window.
	window := make([]float64, frameLength)
	arg := math.Pi * 2.0 / float64(frameLength)
	for i := range frameLength {
		window[i] = 0.5 - 0.5*math.Cos(arg*(float64(i)+0.5))
	}

	numFreqBins := fftLen/2 + 1
	melFilters := buildMelFilterBank(numFreqBins, melBins, minFrequency, maxFrequency, audioSampleRate)

	// Frame count mirrors an unfold with window size frameLength+1, so the
	// result is empty (nil, 0) for clips shorter than one window.
	frameSizeForUnfold := frameLength + 1
	numFrames := (len(samples) - frameSizeForUnfold) / hopLength
	if numFrames <= 0 {
		return nil, 0
	}

	result := make([]float32, numFrames*melBins)
	fftInput := make([]complex128, fftLen)

	for f := range numFrames {
		// Window the frame into the FFT buffer and zero-pad the rest.
		start := f * hopLength
		for i := range frameLength {
			fftInput[i] = complex(float64(samples[start+i])*window[i], 0)
		}
		for i := frameLength; i < fftLen; i++ {
			fftInput[i] = 0
		}

		fft(fftInput)

		// Magnitude spectrum through the mel filterbank, floored, then log.
		for m := range melBins {
			var melVal float64
			for k := range numFreqBins {
				mag := cmplx.Abs(fftInput[k])
				melVal += mag * float64(melFilters[k*melBins+m])
			}
			if melVal < melFloor {
				melVal = melFloor
			}
			result[f*melBins+m] = float32(math.Log(melVal))
		}
	}

	return result, numFrames
}
|
||||
|
||||
func buildMelFilterBank(numFreqBins, numMels int, fMin, fMax float64, sr int) []float32 {
|
||||
hzToMel := func(f float64) float64 {
|
||||
return 2595.0 * math.Log10(1.0+f/700.0)
|
||||
}
|
||||
melToHz := func(m float64) float64 {
|
||||
return 700.0 * (math.Pow(10.0, m/2595.0) - 1.0)
|
||||
}
|
||||
|
||||
melMin := hzToMel(fMin)
|
||||
melMax := hzToMel(fMax)
|
||||
|
||||
melPts := make([]float64, numMels+2)
|
||||
for i := range melPts {
|
||||
melPts[i] = melMin + float64(i)*(melMax-melMin)/float64(numMels+1)
|
||||
}
|
||||
filterFreqs := make([]float64, numMels+2)
|
||||
for i, m := range melPts {
|
||||
filterFreqs[i] = melToHz(m)
|
||||
}
|
||||
|
||||
fftFreqs := make([]float64, numFreqBins)
|
||||
for i := range fftFreqs {
|
||||
fftFreqs[i] = float64(i) * float64(sr) / float64(2*(numFreqBins-1))
|
||||
}
|
||||
|
||||
filters := make([]float32, numFreqBins*numMels)
|
||||
for m := range numMels {
|
||||
fLeft := filterFreqs[m]
|
||||
fCenter := filterFreqs[m+1]
|
||||
fRight := filterFreqs[m+2]
|
||||
for k := range numFreqBins {
|
||||
f := fftFreqs[k]
|
||||
var v float64
|
||||
if f >= fLeft && f <= fCenter && fCenter > fLeft {
|
||||
v = (f - fLeft) / (fCenter - fLeft)
|
||||
} else if f > fCenter && f <= fRight && fRight > fCenter {
|
||||
v = (fRight - f) / (fRight - fCenter)
|
||||
}
|
||||
if v > 0 {
|
||||
filters[k*numMels+m] = float32(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
return filters
|
||||
}
|
||||
|
||||
// fft performs an in-place iterative Cooley-Tukey radix-2 FFT.
// len(x) must be a power of two.
func fft(x []complex128) {
	n := len(x)
	if n <= 1 {
		return
	}

	// Bit-reversal permutation so the butterfly stages can run in place.
	rev := 0
	for i := 1; i < n; i++ {
		mask := n >> 1
		for rev&mask != 0 {
			rev ^= mask
			mask >>= 1
		}
		rev ^= mask
		if i < rev {
			x[i], x[rev] = x[rev], x[i]
		}
	}

	// Butterfly stages with doubling sub-transform sizes; twiddle factors
	// are accumulated multiplicatively within each group.
	for size := 2; size <= n; size <<= 1 {
		half := size / 2
		root := complex(math.Cos(2*math.Pi/float64(size)), -math.Sin(2*math.Pi/float64(size)))
		for base := 0; base < n; base += size {
			tw := complex(1, 0)
			for k := 0; k < half; k++ {
				odd := tw * x[base+k+half]
				x[base+k+half] = x[base+k] - odd
				x[base+k] = x[base+k] + odd
				tw *= root
			}
		}
	}
}
|
||||
|
||||
// splitAudioChunks splits PCM samples into chunks of at most maxChunkSamples,
// preferring to split at low-energy (silence) regions for natural boundaries.
// For each chunk it scans backwards from just under the hard limit toward
// minChunkSamples in half-window steps and cuts at the centre of the quietest
// (lowest-RMS) window found.
func splitAudioChunks(samples []float32) [][]float32 {
	if len(samples) <= maxChunkSamples {
		return [][]float32{samples}
	}

	var chunks [][]float32
	offset := 0
	for offset < len(samples) {
		remaining := len(samples) - offset
		if remaining <= maxChunkSamples {
			chunks = append(chunks, samples[offset:])
			break
		}

		// Default split at the hard limit; the scan below always finds a
		// window and overrides this.
		splitAt := offset + maxChunkSamples
		bestEnergy := float64(math.MaxFloat64)

		scanStart := offset + maxChunkSamples - silenceWindowSize
		scanEnd := offset + minChunkSamples
		for pos := scanStart; pos >= scanEnd; pos -= silenceWindowSize / 2 {
			end := pos + silenceWindowSize
			if end > len(samples) {
				end = len(samples)
			}
			var sumSq float64
			for _, s := range samples[pos:end] {
				sumSq += float64(s) * float64(s)
			}
			rms := math.Sqrt(sumSq / float64(end-pos))
			if rms < bestEnergy {
				bestEnergy = rms
				// Cut at the window centre.
				splitAt = pos + silenceWindowSize/2
			}
		}

		chunks = append(chunks, samples[offset:splitAt])
		offset = splitAt
	}

	slog.Debug("Audio chunked", "chunks", len(chunks), "total_samples", len(samples))
	return chunks
}
|
||||
|
||||
// isAudioData reports whether data begins with a RIFF/WAVE header.
func isAudioData(data []byte) bool {
	if len(data) < 12 {
		return false
	}
	return string(data[:4]) == "RIFF" && string(data[8:12]) == "WAVE"
}
|
||||
103
model/models/gemma4/process_image.go
Normal file
103
model/models/gemma4/process_image.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"image"
|
||||
"math"
|
||||
|
||||
"golang.org/x/image/draw"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
)
|
||||
|
||||
// ImageProcessor prepares input images for the vision encoder: an
// aspect-ratio-preserving resize aligned to the patch/merge grid, followed by
// normalization to [-1, 1].
type ImageProcessor struct {
	patchSize   int // pixels per patch edge
	numChannels int // expected color channels (typically 3)
	nMerge      int // post-encoder pooling factor; resize aligns to patchSize*nMerge
	minPixels   int // lower pixel budget derived from the minimum token count
	maxPixels   int // upper pixel budget derived from the maximum token count
}
|
||||
|
||||
// newImageProcessor builds an ImageProcessor from the GGUF vision config,
// converting the reference output-token limits into pixel budgets.
func newImageProcessor(c fs.Config) ImageProcessor {
	patchSize := int(c.Uint("vision.patch_size", 16))
	nMerge := int(c.Uint("vision.projector.scale_factor", 3))
	numChannels := int(c.Uint("vision.num_channels", 3))

	// Token limits from reference: min=40, max=280 output tokens after pooling.
	// Convert to pixel counts: tokens * nMerge^2 * patchSize^2
	minTokens := 40
	maxTokens := 280
	patchArea := patchSize * patchSize * nMerge * nMerge
	minPixels := minTokens * patchArea
	maxPixels := maxTokens * patchArea

	return ImageProcessor{
		patchSize:   patchSize,
		numChannels: numChannels,
		nMerge:      nMerge,
		minPixels:   minPixels,
		maxPixels:   maxPixels,
	}
}
|
||||
|
||||
// ProcessImage resizes an image preserving aspect ratio, aligning dimensions
// to (patchSize * nMerge) boundaries, and normalizes pixels to [-1, 1].
// Returns the float32 pixel data (channel-first: R plane, G plane, B plane)
// and the actual output dimensions. The error result is currently always nil;
// it is kept for interface stability.
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, error) {
	// Compute target size preserving aspect ratio
	alignSize := p.patchSize * p.nMerge
	targetW, targetH := p.smartResize(img.Bounds().Dx(), img.Bounds().Dy(), alignSize)

	// Resize directly without alpha compositing, matching MLX reference.
	dst := image.NewRGBA(image.Rect(0, 0, targetW, targetH))
	draw.BiLinear.Scale(dst, dst.Bounds(), img, img.Bounds(), draw.Over, nil)

	// Normalize to [-1, 1] using mean=0.5, std=0.5: (pixel/255 - 0.5) / 0.5 = 2*pixel/255 - 1
	data := p.pack(dst)
	return data, targetW, targetH, nil
}
|
||||
|
||||
// smartResize computes target dimensions that preserve aspect ratio and
|
||||
// align to alignSize boundaries. It scales the image to fill the maximum
|
||||
// patch budget (maxPixels), matching the MLX reference.
|
||||
func (p *ImageProcessor) smartResize(origW, origH, alignSize int) (int, int) {
|
||||
totalPx := origW * origH
|
||||
|
||||
var targetW, targetH int
|
||||
if p.maxPixels > 0 && totalPx > 0 {
|
||||
factor := math.Sqrt(float64(p.maxPixels) / float64(totalPx))
|
||||
targetH = max(alignSize, int(math.Floor(factor*float64(origH)/float64(alignSize)))*alignSize)
|
||||
targetW = max(alignSize, int(math.Floor(factor*float64(origW)/float64(alignSize)))*alignSize)
|
||||
} else {
|
||||
targetH = max(alignSize, (origH/alignSize)*alignSize)
|
||||
targetW = max(alignSize, (origW/alignSize)*alignSize)
|
||||
}
|
||||
|
||||
return targetW, targetH
|
||||
}
|
||||
|
||||
// pack extracts RGB values from an image and normalizes to [-1, 1].
|
||||
// Returns channel-first layout: [R..., G..., B...].
|
||||
func (p *ImageProcessor) pack(img image.Image) []float32 {
|
||||
bounds := img.Bounds()
|
||||
w := bounds.Dx()
|
||||
h := bounds.Dy()
|
||||
size := w * h
|
||||
|
||||
pixelVals := make([]float32, 3*size)
|
||||
rOff, gOff, bOff := 0, size, 2*size
|
||||
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
c := img.At(x, y)
|
||||
r, g, b, _ := c.RGBA()
|
||||
idx := (y-bounds.Min.Y)*w + (x - bounds.Min.X)
|
||||
|
||||
// Normalize [0, 255] -> [-1, 1]: 2 * (val/255) - 1
|
||||
pixelVals[rOff+idx] = float32(r>>8)/255.0*2.0 - 1.0
|
||||
pixelVals[gOff+idx] = float32(g>>8)/255.0*2.0 - 1.0
|
||||
pixelVals[bOff+idx] = float32(b>>8)/255.0*2.0 - 1.0
|
||||
}
|
||||
}
|
||||
|
||||
return pixelVals
|
||||
}
|
||||
102
model/models/gemma4/tokenizer_compare_test.go
Normal file
102
model/models/gemma4/tokenizer_compare_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// TestTokenizerMatchesHF compares our tokenizer output against HuggingFace reference tokens.
func TestTokenizerMatchesHF(t *testing.T) {
	// Opt-in integration test: requires a local GGUF model file.
	modelPath := os.Getenv("GEMMA4_MODEL_PATH")
	if modelPath == "" {
		t.Skip("set GEMMA4_MODEL_PATH to a gemma4 GGUF file")
	}

	m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
	if err != nil {
		t.Fatalf("Failed to load model: %v", err)
	}
	defer m.Backend().Close()

	tok := m.(tokenizer.Tokenizer)

	// Expected token IDs are the output of the HuggingFace reference
	// tokenizer for each input.
	tests := []struct {
		name     string
		input    string
		expected []int32
	}{
		{
			name:     "simple",
			input:    "Hello, world!",
			expected: []int32{9259, 236764, 1902, 236888},
		},
		{
			name:     "special_tokens",
			input:    "<|turn>user\nWhat is 2+2?<turn|>\n<|turn>model\n",
			expected: []int32{105, 2364, 107, 3689, 563, 236743, 236778, 236862, 236778, 236881, 106, 107, 105, 4368, 107},
		},
		{
			name:     "tool_declaration",
			input:    "<|tool>declaration:bash{description:<|\"|>Run a command<|\"|>}<tool|>",
			expected: []int32{46, 163688, 236787, 42422, 236782, 7777, 236787, 52, 7306, 496, 4991, 52, 236783, 47},
		},
		{
			name:     "tool_call",
			input:    "<|tool_call>call:bash{command:<|\"|>ls -la<|\"|>}<tool_call|>",
			expected: []int32{48, 6639, 236787, 42422, 236782, 7674, 236787, 52, 5629, 753, 2149, 52, 236783, 49},
		},
		{
			name:     "thinking",
			input:    "<|channel>thought\nLet me think about this...<channel|>The answer is 42.",
			expected: []int32{100, 45518, 107, 6481, 786, 1751, 1003, 672, 1390, 101, 818, 3890, 563, 236743, 236812, 236778, 236761},
		},
		{
			name:     "code",
			input:    "func main() { fmt.Println(\"hello\") }",
			expected: []int32{6823, 1689, 825, 642, 22766, 236761, 29006, 885, 23391, 1373, 682},
		},
		{
			name:     "numbers",
			input:    "The answer is 42, not 43.5 or -1",
			expected: []int32{818, 3890, 563, 236743, 236812, 236778, 236764, 711, 236743, 236812, 236800, 236761, 236810, 653, 753, 236770},
		},
		{
			name:     "mixed_chat_with_tools",
			input:    "<|turn>system\nYou are a helpful assistant.\n<|tool>declaration:get_weather{description:<|\"|>Get weather<|\"|>,parameters:{properties:{city:{type:<|\"|>STRING<|\"|>}},type:<|\"|>OBJECT<|\"|>}}<tool|><turn|>\n<|turn>user\nWhat's the weather in Paris?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
			expected: []int32{105, 9731, 107, 3048, 659, 496, 11045, 16326, 236761, 107, 46, 163688, 236787, 828, 236779, 19323, 236782, 7777, 236787, 52, 3407, 7606, 52, 236764, 19031, 29616, 15921, 29616, 13319, 29616, 2084, 236787, 52, 35410, 52, 5237, 2084, 236787, 52, 60688, 52, 1807, 47, 106, 107, 105, 2364, 107, 3689, 236789, 236751, 506, 7606, 528, 9079, 236881, 106, 107, 105, 4368, 107, 100, 45518, 107, 101},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tokens, err := tok.Encode(tt.input, false) // no BOS
			if err != nil {
				t.Fatalf("encode error: %v", err)
			}

			// A length mismatch makes positional comparison meaningless;
			// dump both sequences and stop.
			if len(tokens) != len(tt.expected) {
				t.Errorf("token count mismatch: got %d, want %d", len(tokens), len(tt.expected))
				t.Logf("got: %v", tokens)
				t.Logf("want: %v", tt.expected)
				return
			}

			// Report at most five positional mismatches to keep output readable.
			mismatches := 0
			for i := range tokens {
				if tokens[i] != tt.expected[i] {
					mismatches++
					if mismatches <= 5 {
						t.Errorf("mismatch at [%d]: got %d, want %d", i, tokens[i], tt.expected[i])
					}
				}
			}
			if mismatches > 5 {
				t.Errorf("... and %d more mismatches", mismatches-5)
			}
		})
	}
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
_ "github.com/ollama/ollama/model/models/gemma4"
|
||||
_ "github.com/ollama/ollama/model/models/glm4moelite"
|
||||
_ "github.com/ollama/ollama/model/models/glmocr"
|
||||
_ "github.com/ollama/ollama/model/models/gptoss"
|
||||
|
||||
412
model/parsers/gemma4.go
Normal file
412
model/parsers/gemma4.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Gemma4ParserState identifies which section of the model's output the
// parser is currently inside.
type Gemma4ParserState int

const (
	// Gemma4CollectingContent: plain response text (the default state).
	Gemma4CollectingContent Gemma4ParserState = iota
	// Gemma4CollectingThinking: inside a <|channel>...<channel|> block.
	Gemma4CollectingThinking
	// Gemma4CollectingToolCall: inside a <|tool_call>...<tool_call|> block.
	Gemma4CollectingToolCall
)
|
||||
|
||||
// Gemma 4 control tags. Open tags use the <|name> form and close tags the
// mirrored <name|> form.
const (
	gemma4ThinkingOpenTag  = "<|channel>"
	gemma4ThinkingCloseTag = "<channel|>"
	gemma4ToolCallOpenTag  = "<|tool_call>"
	gemma4ToolCallCloseTag = "<tool_call|>"
)
|
||||
|
||||
// Gemma4Parser incrementally parses Gemma 4 model output, separating plain
// content, <|channel> thinking blocks, and <|tool_call> tool calls. It is
// fed chunks via Add and buffers text that may end mid-tag.
type Gemma4Parser struct {
	state Gemma4ParserState
	// buffer holds unconsumed model output, possibly ending in a partial tag.
	buffer strings.Builder
	// hasThinkingSupport reports whether the model variant emits <|channel>.
	hasThinkingSupport bool
	thinkingEnabled    bool // true when both model supports and user requested thinking
	needsChannelNameStrip bool // true when we just entered thinking and need to strip "thought\n"
}
|
||||
|
||||
// HasToolSupport reports whether the parser can handle tool calls.
// Gemma 4 always supports the <|tool_call> format.
func (p *Gemma4Parser) HasToolSupport() bool {
	return true
}
|
||||
|
||||
// HasThinkingSupport reports whether this parser instance was configured
// for a model variant that emits <|channel> thinking blocks.
func (p *Gemma4Parser) HasThinkingSupport() bool {
	return p.hasThinkingSupport
}
|
||||
|
||||
func (p *Gemma4Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
||||
|
||||
p.thinkingEnabled = p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
|
||||
|
||||
if !p.thinkingEnabled {
|
||||
p.state = Gemma4CollectingContent
|
||||
return tools
|
||||
}
|
||||
|
||||
if prefill && lastMessage.Content != "" {
|
||||
p.state = Gemma4CollectingContent
|
||||
return tools
|
||||
}
|
||||
|
||||
// When thinking is enabled, start in content mode but we'll switch to
|
||||
// thinking when we see <|channel>. The model typically starts with
|
||||
// <|channel> immediately when thinking is enabled.
|
||||
p.state = Gemma4CollectingContent
|
||||
return tools
|
||||
}
|
||||
|
||||
// gemma4Event is the internal event stream produced while parsing: each
// event is a chunk of thinking text, a chunk of plain content, or a fully
// parsed tool call.
type gemma4Event interface {
	isGemma4Event()
}

// gemma4EventThinkingContent carries a chunk of <|channel> thinking text.
type gemma4EventThinkingContent struct {
	content string
}

// gemma4EventContent carries a chunk of plain response text.
type gemma4EventContent struct {
	content string
}

// gemma4EventToolCall carries one parsed tool call.
type gemma4EventToolCall struct {
	toolCall api.ToolCall
}

// Marker methods sealing the gemma4Event interface to these three types.
func (gemma4EventThinkingContent) isGemma4Event() {}
func (gemma4EventContent) isGemma4Event()         {}
func (gemma4EventToolCall) isGemma4Event()        {}
|
||||
|
||||
func (p *Gemma4Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents(done)
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case gemma4EventToolCall:
|
||||
toolCalls = append(toolCalls, event.toolCall)
|
||||
case gemma4EventThinkingContent:
|
||||
if p.thinkingEnabled {
|
||||
thinkingSb.WriteString(event.content)
|
||||
}
|
||||
// When thinking is disabled, silently discard channel content
|
||||
case gemma4EventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *Gemma4Parser) parseEvents(done bool) []gemma4Event {
|
||||
var all []gemma4Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []gemma4Event
|
||||
events, keepLooping = p.eat(done)
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// longestOverlap returns the longest overlap between the suffix of bufStr and
|
||||
// a prefix of any of the given tags.
|
||||
func longestOverlap(bufStr string, tags ...string) int {
|
||||
maxOverlap := 0
|
||||
for _, tag := range tags {
|
||||
if o := overlap(bufStr, tag); o > maxOverlap {
|
||||
maxOverlap = o
|
||||
}
|
||||
}
|
||||
return maxOverlap
|
||||
}
|
||||
|
||||
// eat performs one parsing step against the buffered model output. It
// returns any events that became unambiguous, plus true when a state
// transition occurred and the caller should call eat again because the
// remaining buffer may contain further parseable structure.
//
// done=true means no more input will arrive, so partial-tag holdbacks are
// flushed instead of retained.
func (p *Gemma4Parser) eat(done bool) ([]gemma4Event, bool) {
	var events []gemma4Event
	bufStr := p.buffer.String()
	if bufStr == "" {
		return events, false
	}

	switch p.state {
	case Gemma4CollectingContent:
		// Check for thinking open tag
		if idx := strings.Index(bufStr, gemma4ThinkingOpenTag); idx != -1 {
			contentBefore := bufStr[:idx]
			remaining := bufStr[idx+len(gemma4ThinkingOpenTag):]

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = Gemma4CollectingThinking
			// The channel name (e.g. "thought\n") follows the open tag and
			// must be stripped before emitting thinking content.
			p.needsChannelNameStrip = true

			// Whitespace immediately before a tag is formatting, not content.
			if contentBefore = strings.TrimRightFunc(contentBefore, unicode.IsSpace); len(contentBefore) > 0 {
				events = append(events, gemma4EventContent{content: contentBefore})
			}
			return events, true
		}

		// Check for tool call open tag
		if idx := strings.Index(bufStr, gemma4ToolCallOpenTag); idx != -1 {
			contentBefore := bufStr[:idx]
			remaining := bufStr[idx+len(gemma4ToolCallOpenTag):]

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = Gemma4CollectingToolCall

			if contentBefore = strings.TrimRightFunc(contentBefore, unicode.IsSpace); len(contentBefore) > 0 {
				events = append(events, gemma4EventContent{content: contentBefore})
			}
			return events, true
		}

		// Check for partial tag overlap: a buffer suffix that could be the
		// beginning of an open tag (plus any whitespace before it) is held
		// back until more input disambiguates it.
		if !done {
			if overlapLen := longestOverlap(bufStr, gemma4ThinkingOpenTag, gemma4ToolCallOpenTag); overlapLen > 0 {
				beforePartialTag := bufStr[:len(bufStr)-overlapLen]
				trailingLen := trailingWhitespaceLen(beforePartialTag)
				ambiguousStart := len(beforePartialTag) - trailingLen

				unambiguous := bufStr[:ambiguousStart]
				ambiguous := bufStr[ambiguousStart:]
				p.buffer.Reset()
				p.buffer.WriteString(ambiguous)
				if len(unambiguous) > 0 {
					events = append(events, gemma4EventContent{content: unambiguous})
				}
				return events, false
			}
		}

		// No tags found, emit all content
		p.buffer.Reset()
		if len(bufStr) > 0 {
			events = append(events, gemma4EventContent{content: bufStr})
		}
		return events, false

	case Gemma4CollectingThinking:
		// Strip channel name (e.g., "thought\n") after <|channel>.
		// Gemma 4 format: <|channel>thought\n...content...<channel|>
		// In streaming mode, "thought" and "\n" may arrive in separate chunks.
		if p.needsChannelNameStrip {
			if strings.HasPrefix(bufStr, "thought\n") {
				bufStr = bufStr[len("thought\n"):]
				p.buffer.Reset()
				p.buffer.WriteString(bufStr)
				p.needsChannelNameStrip = false
			} else if !done && (bufStr == "thought" || strings.HasPrefix("thought\n", bufStr)) {
				// Partial match — wait for more data.
				return events, false
			} else {
				// No match (different channel name or no newline) — don't strip.
				p.needsChannelNameStrip = false
			}
		}

		// Complete close tag: emit the thinking text and return to content.
		if strings.Contains(bufStr, gemma4ThinkingCloseTag) {
			split := strings.SplitN(bufStr, gemma4ThinkingCloseTag, 2)
			thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
			remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = Gemma4CollectingContent

			if len(thinking) > 0 {
				events = append(events, gemma4EventThinkingContent{content: thinking})
			}
			return events, true
		}

		// Check for partial close tag
		if !done {
			if overlapLen := overlap(bufStr, gemma4ThinkingCloseTag); overlapLen > 0 {
				beforePartialTag := bufStr[:len(bufStr)-overlapLen]
				trailingLen := trailingWhitespaceLen(beforePartialTag)
				ambiguousStart := len(beforePartialTag) - trailingLen

				unambiguous := bufStr[:ambiguousStart]
				ambiguous := bufStr[ambiguousStart:]
				p.buffer.Reset()
				p.buffer.WriteString(ambiguous)
				if len(unambiguous) > 0 {
					events = append(events, gemma4EventThinkingContent{content: unambiguous})
				}
				return events, false
			}
		}

		// No close tag, emit thinking content (hold back trailing whitespace)
		if !done {
			whitespaceLen := trailingWhitespaceLen(bufStr)
			ambiguousStart := len(bufStr) - whitespaceLen

			unambiguous := bufStr[:ambiguousStart]
			ambiguous := bufStr[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, gemma4EventThinkingContent{content: unambiguous})
			}
		} else {
			// Final chunk: flush everything, including trailing whitespace.
			p.buffer.Reset()
			if len(bufStr) > 0 {
				events = append(events, gemma4EventThinkingContent{content: bufStr})
			}
		}
		return events, false

	case Gemma4CollectingToolCall:
		if idx := strings.Index(bufStr, gemma4ToolCallCloseTag); idx != -1 {
			toolCallContent := bufStr[:idx]
			remaining := bufStr[idx+len(gemma4ToolCallCloseTag):]
			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = Gemma4CollectingContent

			// A malformed tool call is logged and dropped rather than
			// failing the whole response.
			if toolCall, err := parseGemma4ToolCall(toolCallContent); err == nil {
				events = append(events, gemma4EventToolCall{toolCall: toolCall})
			} else {
				slog.Warn("gemma4 tool call parsing failed", "error", err, "content", toolCallContent)
			}
			return events, true
		}

		// If done, flush any accumulated tool call content even without closing tag.
		// The model may hit a stop token before emitting <tool_call|>.
		if done && len(bufStr) > 0 {
			p.buffer.Reset()
			p.state = Gemma4CollectingContent
			if toolCall, err := parseGemma4ToolCall(bufStr); err == nil {
				events = append(events, gemma4EventToolCall{toolCall: toolCall})
			} else {
				slog.Warn("gemma4 tool call flush on done failed", "error", err, "content", bufStr)
			}
			return events, false
		}

		// Wait for closing tag
		return events, false
	}

	return events, false
}
|
||||
|
||||
// parseGemma4ToolCall parses a tool call in Gemma 4 format:
|
||||
// call:NAME{key:value,key:value}
|
||||
func parseGemma4ToolCall(content string) (api.ToolCall, error) {
|
||||
// Expected format: call:NAME{args}
|
||||
if !strings.HasPrefix(content, "call:") {
|
||||
return api.ToolCall{}, errors.New("expected 'call:' prefix")
|
||||
}
|
||||
content = content[len("call:"):]
|
||||
|
||||
// Find the opening brace for args
|
||||
braceIdx := strings.Index(content, "{")
|
||||
if braceIdx == -1 {
|
||||
return api.ToolCall{}, errors.New("expected '{' in tool call")
|
||||
}
|
||||
|
||||
toolName := strings.TrimSpace(content[:braceIdx])
|
||||
argsStr := content[braceIdx:]
|
||||
|
||||
// Convert Gemma 4 argument format to JSON
|
||||
jsonStr := gemma4ArgsToJSON(argsStr)
|
||||
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(jsonStr), &args); err != nil {
|
||||
return api.ToolCall{}, err
|
||||
}
|
||||
|
||||
return api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: toolName,
|
||||
Arguments: args,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// gemma4ArgsToJSON converts Gemma 4's custom argument format to valid JSON.
|
||||
func gemma4ArgsToJSON(s string) string {
|
||||
s = strings.ReplaceAll(s, `<|"|>`, `"`)
|
||||
|
||||
var buf strings.Builder
|
||||
buf.Grow(len(s) + 32)
|
||||
inString := false
|
||||
hex := "0123456789abcdef"
|
||||
i := 0
|
||||
for i < len(s) {
|
||||
ch := s[i]
|
||||
|
||||
if ch == '"' {
|
||||
inString = !inString
|
||||
buf.WriteByte('"')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if inString {
|
||||
switch ch {
|
||||
case '\\':
|
||||
buf.WriteString(`\\`)
|
||||
case '\n':
|
||||
buf.WriteString(`\n`)
|
||||
case '\r':
|
||||
buf.WriteString(`\r`)
|
||||
case '\t':
|
||||
buf.WriteString(`\t`)
|
||||
case '\b':
|
||||
buf.WriteString(`\b`)
|
||||
case '\f':
|
||||
buf.WriteString(`\f`)
|
||||
default:
|
||||
if ch < 0x20 {
|
||||
buf.WriteString(`\u00`)
|
||||
buf.WriteByte(hex[ch>>4])
|
||||
buf.WriteByte(hex[ch&0x0f])
|
||||
} else {
|
||||
buf.WriteByte(ch)
|
||||
}
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if !inString && isIdentStart(ch) {
|
||||
j := i + 1
|
||||
for j < len(s) && isIdentPart(s[j]) {
|
||||
j++
|
||||
}
|
||||
word := s[i:j]
|
||||
if j < len(s) && s[j] == ':' {
|
||||
buf.WriteByte('"')
|
||||
buf.WriteString(word)
|
||||
buf.WriteByte('"')
|
||||
} else {
|
||||
buf.WriteString(word)
|
||||
}
|
||||
i = j
|
||||
} else {
|
||||
buf.WriteByte(ch)
|
||||
i++
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
463
model/parsers/gemma4_test.go
Normal file
463
model/parsers/gemma4_test.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TestGemma4Parser exercises single-shot (non-streaming) parsing across
// content, thinking blocks, and tool calls, including mixed sequences,
// prefill, and thinking-disabled behavior.
func TestGemma4Parser(t *testing.T) {
	tests := []struct {
		name              string
		input             string
		expectedContent   string
		expectedThinking  string
		expectedToolCalls []api.ToolCall
		thinkingEnabled   bool
		lastMessage       *api.Message
	}{
		{
			name:            "simple_content",
			input:           "This is a simple response.",
			expectedContent: "This is a simple response.",
		},
		{
			name:             "thinking_then_content",
			input:            "<|channel>thought\nLet me think about this...<channel|>The answer is 42.",
			expectedContent:  "The answer is 42.",
			expectedThinking: "Let me think about this...",
			thinkingEnabled:  true,
		},
		{
			name:             "multiple_thinking_blocks",
			input:            "<|channel>first thought<channel|><|channel>second thought<channel|>Final answer.",
			expectedContent:  "Final answer.",
			expectedThinking: "first thoughtsecond thought",
			thinkingEnabled:  true,
		},
		{
			name:             "thinking_only_no_content",
			input:            "<|channel>just thinking<channel|>",
			expectedContent:  "",
			expectedThinking: "just thinking",
			thinkingEnabled:  true,
		},
		{
			name:  "tool_call_simple",
			input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{
							"location": "Paris",
						}),
					},
				},
			},
		},
		{
			name:  "tool_call_with_multiple_args",
			input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>,units:<|"|>metric<|"|>}<tool_call|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{
							"location": "Paris",
							"units":    "metric",
						}),
					},
				},
			},
		},
		{
			// JSON numbers decode as float64, hence 42.0.
			name:  "tool_call_with_number_arg",
			input: `<|tool_call>call:set_temp{value:42}<tool_call|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "set_temp",
						Arguments: testArgs(map[string]any{
							"value": 42.0,
						}),
					},
				},
			},
		},
		{
			name:  "tool_call_with_boolean_arg",
			input: `<|tool_call>call:toggle{enabled:true}<tool_call|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "toggle",
						Arguments: testArgs(map[string]any{
							"enabled": true,
						}),
					},
				},
			},
		},
		{
			name:  "tool_call_with_nested_object",
			input: `<|tool_call>call:process{config:{enabled:true,name:<|"|>test<|"|>}}<tool_call|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "process",
						Arguments: testArgs(map[string]any{
							"config": map[string]any{
								"enabled": true,
								"name":    "test",
							},
						}),
					},
				},
			},
		},
		{
			name:  "tool_call_with_array",
			input: `<|tool_call>call:process{items:[<|"|>a<|"|>,<|"|>b<|"|>]}<tool_call|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "process",
						Arguments: testArgs(map[string]any{
							"items": []any{"a", "b"},
						}),
					},
				},
			},
		},
		{
			// The raw string literal intentionally embeds a newline inside
			// the <|"|> delimiters.
			name: "tool_call_with_multiline_string_arg",
			input: `<|tool_call>call:bash{command:<|"|>date
<|"|>}<tool_call|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "bash",
						Arguments: testArgs(map[string]any{
							"command": "date\n",
						}),
					},
				},
			},
		},
		{
			name:  "multiple_tool_calls",
			input: `<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|><|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{
							"location": "Paris",
						}),
					},
				},
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{
							"location": "London",
						}),
					},
				},
			},
		},
		{
			name:             "thinking_then_tool_call",
			input:            "<|channel>thought\nI need to check the weather<channel|><|tool_call>call:get_weather{location:<|\"|>Paris<|\"|>}<tool_call|>",
			expectedThinking: "I need to check the weather",
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{
							"location": "Paris",
						}),
					},
				},
			},
			thinkingEnabled: true,
		},
		{
			name:            "content_then_tool_call",
			input:           `Let me check that for you.<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>`,
			expectedContent: "Let me check that for you.",
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgs(map[string]any{
							"location": "Paris",
						}),
					},
				},
			},
		},
		{
			// With thinking off, channel content is discarded, not surfaced.
			name:            "thinking_disabled_channel_tags_as_content",
			input:           "<|channel>this is not thinking<channel|>actual content",
			expectedContent: "actual content",
			thinkingEnabled: false,
		},
		{
			name:            "prefill_content_only",
			input:           "Continuing content.",
			expectedContent: "Continuing content.",
			lastMessage: &api.Message{
				Role:    "assistant",
				Content: "Previous content",
			},
			thinkingEnabled: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parser := &Gemma4Parser{hasThinkingSupport: true}
			parser.Init(nil, tt.lastMessage, &api.ThinkValue{Value: tt.thinkingEnabled})

			// done=true: the whole input arrives as one final chunk.
			content, thinking, toolCalls, err := parser.Add(tt.input, true)
			if err != nil {
				t.Fatalf("Add() error = %v", err)
			}

			if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
				t.Errorf("content mismatch (-want +got):\n%s", diff)
			}

			if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
				t.Errorf("thinking mismatch (-want +got):\n%s", diff)
			}

			if diff := cmp.Diff(tt.expectedToolCalls, toolCalls, argsComparer); diff != "" {
				t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
|
||||
|
||||
// TestGemma4Parser_Streaming verifies chunked parsing of a thinking block
// followed by content, including a chunk boundary between the channel name
// and its newline.
func TestGemma4Parser_Streaming(t *testing.T) {
	parser := &Gemma4Parser{hasThinkingSupport: true}
	parser.Init(nil, nil, &api.ThinkValue{Value: true})

	chunks := []string{
		"<|channel>thought",
		"\nLet me think",
		"...<channel|>The answer",
		" is 42.",
	}

	var finalContent, finalThinking strings.Builder

	for i, chunk := range chunks {
		// Only the last chunk is marked done.
		done := i == len(chunks)-1
		content, thinking, _, err := parser.Add(chunk, done)
		if err != nil {
			t.Fatalf("Add() error on chunk %d: %v", i, err)
		}

		finalContent.WriteString(content)
		finalThinking.WriteString(thinking)
	}

	if finalContent.String() != "The answer is 42." {
		t.Errorf("expected content %q, got %q", "The answer is 42.", finalContent.String())
	}

	if finalThinking.String() != "Let me think..." {
		t.Errorf("expected thinking %q, got %q", "Let me think...", finalThinking.String())
	}
}
|
||||
|
||||
// TestGemma4Parser_StreamingToolCall verifies that a tool call split across
// arbitrary chunk boundaries (mid-name, mid-argument) is reassembled and
// parsed once the closing tag arrives, with no content leaking out.
func TestGemma4Parser_StreamingToolCall(t *testing.T) {
	parser := &Gemma4Parser{hasThinkingSupport: false}
	parser.Init(nil, nil, nil)

	chunks := []string{
		`<|tool_call>call:get_`,
		`weather{location:<|"|>Par`,
		`is<|"|>}<tool_call|>`,
	}

	var finalContent strings.Builder
	var finalToolCalls []api.ToolCall

	for i, chunk := range chunks {
		done := i == len(chunks)-1
		content, _, toolCalls, err := parser.Add(chunk, done)
		if err != nil {
			t.Fatalf("Add() error on chunk %d: %v", i, err)
		}

		finalContent.WriteString(content)
		finalToolCalls = append(finalToolCalls, toolCalls...)
	}

	if finalContent.String() != "" {
		t.Errorf("expected no content, got %q", finalContent.String())
	}

	expectedToolCalls := []api.ToolCall{
		{
			Function: api.ToolCallFunction{
				Name: "get_weather",
				Arguments: testArgs(map[string]any{
					"location": "Paris",
				}),
			},
		},
	}

	if diff := cmp.Diff(expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
		t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
	}
}
|
||||
|
||||
// TestGemma4Parser_StreamingSplitThinkingTag verifies that the <|channel>
// open tag and <channel|> close tag are each recognized when split across a
// chunk boundary mid-tag.
func TestGemma4Parser_StreamingSplitThinkingTag(t *testing.T) {
	tests := []struct {
		name             string
		chunks           []string
		expectedContent  string
		expectedThinking string
	}{
		{
			name: "split_channel_open_tag",
			chunks: []string{
				"<|chan",
				"nel>thinking here<channel|>content",
			},
			expectedContent:  "content",
			expectedThinking: "thinking here",
		},
		{
			name: "split_channel_close_tag",
			chunks: []string{
				"<|channel>thinking here<chan",
				"nel|>content",
			},
			expectedContent:  "content",
			expectedThinking: "thinking here",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parser := &Gemma4Parser{hasThinkingSupport: true}
			parser.Init(nil, nil, &api.ThinkValue{Value: true})

			var finalContent, finalThinking strings.Builder
			for i, chunk := range tt.chunks {
				done := i == len(tt.chunks)-1
				content, thinking, _, err := parser.Add(chunk, done)
				if err != nil {
					t.Fatalf("Add() error on chunk %d: %v", i, err)
				}
				finalContent.WriteString(content)
				finalThinking.WriteString(thinking)
			}

			if finalContent.String() != tt.expectedContent {
				t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String())
			}
			if finalThinking.String() != tt.expectedThinking {
				t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
			}
		})
	}
}
|
||||
|
||||
// TestGemma4ArgsToJSON covers the Gemma-4-syntax-to-JSON conversion:
// <|"|> delimiters, bare keys, scalar literals, nesting, arrays, and
// control-character escaping inside strings.
func TestGemma4ArgsToJSON(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name:     "simple_string",
			input:    `{location:<|"|>Paris<|"|>}`,
			expected: `{"location":"Paris"}`,
		},
		{
			name:     "multiple_args",
			input:    `{location:<|"|>Paris<|"|>,units:<|"|>metric<|"|>}`,
			expected: `{"location":"Paris","units":"metric"}`,
		},
		{
			name:     "number_value",
			input:    `{value:42}`,
			expected: `{"value":42}`,
		},
		{
			name:     "boolean_value",
			input:    `{enabled:true}`,
			expected: `{"enabled":true}`,
		},
		{
			name:     "nested_object",
			input:    `{config:{enabled:true,name:<|"|>test<|"|>}}`,
			expected: `{"config":{"enabled":true,"name":"test"}}`,
		},
		{
			name:     "array_value",
			input:    `{items:[<|"|>a<|"|>,<|"|>b<|"|>]}`,
			expected: `{"items":["a","b"]}`,
		},
		{
			name:     "empty_object",
			input:    `{}`,
			expected: `{}`,
		},
		{
			name:     "mixed_types",
			input:    `{name:<|"|>test<|"|>,count:5,active:true,tags:[<|"|>a<|"|>]}`,
			expected: `{"name":"test","count":5,"active":true,"tags":["a"]}`,
		},
		{
			name:     "null_value",
			input:    `{value:null}`,
			expected: `{"value":null}`,
		},
		{
			// The embedded literal newline must become an escaped \n.
			name: "multiline_string_value",
			input: `{command:<|"|>date
<|"|>}`,
			expected: `{"command":"date\n"}`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := gemma4ArgsToJSON(tt.input)
			if result != tt.expected {
				t.Errorf("expected %q, got %q", tt.expected, result)
			}
		})
	}
}
|
||||
|
||||
func TestGemma4Parser_HasToolSupport(t *testing.T) {
|
||||
parser := &Gemma4Parser{}
|
||||
if !parser.HasToolSupport() {
|
||||
t.Error("Gemma4Parser should support tools")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGemma4Parser_HasThinkingSupport(t *testing.T) {
|
||||
parser := &Gemma4Parser{hasThinkingSupport: true}
|
||||
if !parser.HasThinkingSupport() {
|
||||
t.Error("Gemma4Parser with thinking support should report it")
|
||||
}
|
||||
|
||||
parser2 := &Gemma4Parser{hasThinkingSupport: false}
|
||||
if parser2.HasThinkingSupport() {
|
||||
t.Error("Gemma4Parser without thinking support should not report it")
|
||||
}
|
||||
}
|
||||
@@ -77,6 +77,10 @@ func ParserForName(name string) Parser {
|
||||
return &FunctionGemmaParser{}
|
||||
case "glm-4.7":
|
||||
return &GLM47Parser{}
|
||||
case "gemma4":
|
||||
return &Gemma4Parser{hasThinkingSupport: true}
|
||||
case "gemma4-no-thinking":
|
||||
return &Gemma4Parser{hasThinkingSupport: false}
|
||||
case "glm-ocr":
|
||||
return &GlmOcrParser{}
|
||||
case "lfm2":
|
||||
|
||||
380
model/renderers/gemma4.go
Normal file
380
model/renderers/gemma4.go
Normal file
@@ -0,0 +1,380 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Gemma4Renderer renders prompts using Gemma 4's chat format with
|
||||
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
|
||||
// <|tool_call>/<|tool_response> tags for function calling.
|
||||
type Gemma4Renderer struct {
|
||||
useImgTags bool
|
||||
}
|
||||
|
||||
const (
|
||||
g4Q = `<|"|>` // Gemma 4 string delimiter
|
||||
)
|
||||
|
||||
func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
imageOffset := 0
|
||||
|
||||
// BOS token — Gemma 4 models have add_bos_token=false in their tokenizer
|
||||
// config, so the tokenizer does not auto-prepend BOS. We must emit it
|
||||
// explicitly in the rendered prompt, matching the HF chat template.
|
||||
sb.WriteString("<bos>")
|
||||
// Extract system message if present.
|
||||
var systemMessage string
|
||||
var loopMessages []api.Message
|
||||
hasSystemRole := len(messages) > 0 && (messages[0].Role == "system" || messages[0].Role == "developer")
|
||||
if hasSystemRole {
|
||||
systemMessage = messages[0].Content
|
||||
loopMessages = messages[1:]
|
||||
} else {
|
||||
loopMessages = messages
|
||||
}
|
||||
|
||||
// Emit system turn if there's a system/developer role, tools, or thinking.
|
||||
hasThink := thinkValue != nil && thinkValue.Bool()
|
||||
if hasSystemRole || len(tools) > 0 || hasThink {
|
||||
sb.WriteString("<|turn>system\n")
|
||||
if hasThink {
|
||||
sb.WriteString("<|think|>")
|
||||
}
|
||||
if systemMessage != "" {
|
||||
sb.WriteString(strings.TrimSpace(systemMessage))
|
||||
}
|
||||
for _, tool := range tools {
|
||||
sb.WriteString(r.renderToolDeclaration(tool))
|
||||
}
|
||||
sb.WriteString("<turn|>\n")
|
||||
}
|
||||
|
||||
// Each message gets its own <|turn>role\n ... <turn|>\n block,
|
||||
// matching the HF chat template exactly.
|
||||
for _, message := range loopMessages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|turn>user\n")
|
||||
r.renderContent(&sb, message, &imageOffset, true)
|
||||
sb.WriteString("<turn|>\n")
|
||||
|
||||
case "assistant":
|
||||
sb.WriteString("<|turn>model\n")
|
||||
// Tool calls come before content (matching HF template order)
|
||||
for _, tc := range message.ToolCalls {
|
||||
sb.WriteString(r.formatToolCall(tc))
|
||||
}
|
||||
// Strip thinking from history (matching HF strip_thinking macro)
|
||||
if message.Content != "" {
|
||||
sb.WriteString(stripThinking(message.Content))
|
||||
}
|
||||
sb.WriteString("<turn|>\n")
|
||||
|
||||
case "tool":
|
||||
sb.WriteString("<|turn>tool\n")
|
||||
sb.WriteString(strings.TrimSpace(message.Content))
|
||||
sb.WriteString("<turn|>\n")
|
||||
|
||||
default:
|
||||
sb.WriteString("<|turn>" + message.Role + "\n")
|
||||
sb.WriteString(strings.TrimSpace(message.Content))
|
||||
sb.WriteString("<turn|>\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Generation prompt
|
||||
sb.WriteString("<|turn>model\n")
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// stripThinking removes <|channel>...<channel|> thinking blocks from content,
|
||||
// matching the HF chat template's strip_thinking macro.
|
||||
func stripThinking(text string) string {
|
||||
var result strings.Builder
|
||||
for {
|
||||
start := strings.Index(text, "<|channel>")
|
||||
if start == -1 {
|
||||
result.WriteString(text)
|
||||
break
|
||||
}
|
||||
result.WriteString(text[:start])
|
||||
end := strings.Index(text[start:], "<channel|>")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
text = text[start+end+len("<channel|>"):]
|
||||
}
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
|
||||
// renderContent writes a message's content, interleaving [img-N] tags for images.
|
||||
// When trim is true, leading/trailing whitespace is stripped (matching the Jinja2
|
||||
// template's | trim filter applied to non-model content).
|
||||
func (r *Gemma4Renderer) renderContent(sb *strings.Builder, msg api.Message, imageOffset *int, trim bool) {
|
||||
if len(msg.Images) > 0 && r.useImgTags {
|
||||
for range msg.Images {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", *imageOffset))
|
||||
*imageOffset++
|
||||
}
|
||||
}
|
||||
content := msg.Content
|
||||
if trim {
|
||||
content = strings.TrimSpace(content)
|
||||
}
|
||||
sb.WriteString(content)
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) renderToolDeclaration(tool api.Tool) string {
|
||||
var sb strings.Builder
|
||||
fn := tool.Function
|
||||
|
||||
sb.WriteString("<|tool>declaration:" + fn.Name + "{")
|
||||
sb.WriteString("description:" + g4Q + fn.Description + g4Q)
|
||||
|
||||
if fn.Parameters.Properties != nil || fn.Parameters.Type != "" {
|
||||
sb.WriteString(",parameters:{")
|
||||
|
||||
needsComma := false
|
||||
|
||||
if fn.Parameters.Properties != nil && fn.Parameters.Properties.Len() > 0 {
|
||||
sb.WriteString("properties:{")
|
||||
r.writeProperties(&sb, fn.Parameters.Properties)
|
||||
sb.WriteString("}")
|
||||
needsComma = true
|
||||
}
|
||||
|
||||
if len(fn.Parameters.Required) > 0 {
|
||||
if needsComma {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("required:[")
|
||||
for i, req := range fn.Parameters.Required {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(g4Q + req + g4Q)
|
||||
}
|
||||
sb.WriteString("]")
|
||||
needsComma = true
|
||||
}
|
||||
|
||||
if fn.Parameters.Type != "" {
|
||||
if needsComma {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("type:" + g4Q + strings.ToUpper(fn.Parameters.Type) + g4Q)
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
}
|
||||
|
||||
sb.WriteString("}<tool|>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) writeProperties(sb *strings.Builder, props *api.ToolPropertiesMap) {
|
||||
keys := make([]string, 0, props.Len())
|
||||
for k := range props.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, name := range keys {
|
||||
prop, _ := props.Get(name)
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
|
||||
sb.WriteString(name + ":{")
|
||||
|
||||
hasContent := false
|
||||
if prop.Description != "" {
|
||||
sb.WriteString("description:" + g4Q + prop.Description + g4Q)
|
||||
hasContent = true
|
||||
}
|
||||
|
||||
if len(prop.Type) > 0 {
|
||||
typeName := strings.ToUpper(prop.Type[0])
|
||||
|
||||
switch typeName {
|
||||
case "STRING":
|
||||
if len(prop.Enum) > 0 {
|
||||
if hasContent {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("enum:[")
|
||||
for j, e := range prop.Enum {
|
||||
if j > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(g4Q + fmt.Sprintf("%v", e) + g4Q)
|
||||
}
|
||||
sb.WriteString("]")
|
||||
hasContent = true
|
||||
}
|
||||
|
||||
case "OBJECT":
|
||||
// Render nested properties recursively.
|
||||
// Note: the leading comma is hardcoded (matching the template),
|
||||
// and this does NOT set hasContent — the comma before type:
|
||||
// depends only on whether description was present.
|
||||
sb.WriteString(",properties:{")
|
||||
if prop.Properties != nil && prop.Properties.Len() > 0 {
|
||||
r.writeProperties(sb, prop.Properties)
|
||||
}
|
||||
sb.WriteString("}")
|
||||
if len(prop.Required) > 0 {
|
||||
sb.WriteString(",required:[")
|
||||
for j, req := range prop.Required {
|
||||
if j > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(g4Q + req + g4Q)
|
||||
}
|
||||
sb.WriteString("]")
|
||||
}
|
||||
|
||||
case "ARRAY":
|
||||
// Render items specification.
|
||||
// Same as OBJECT: leading comma is hardcoded, does NOT set hasContent.
|
||||
if items, ok := prop.Items.(map[string]any); ok && len(items) > 0 {
|
||||
sb.WriteString(",items:{")
|
||||
r.writeItemsSpec(sb, items)
|
||||
sb.WriteString("}")
|
||||
}
|
||||
}
|
||||
|
||||
if hasContent {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("type:" + g4Q + typeName + g4Q)
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
}
|
||||
}
|
||||
|
||||
// writeItemsSpec renders the items specification for array-type properties,
|
||||
// matching the Jinja2 template's dictsort iteration over items.
|
||||
func (r *Gemma4Renderer) writeItemsSpec(sb *strings.Builder, items map[string]any) {
|
||||
keys := make([]string, 0, len(items))
|
||||
for k := range items {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
value := items[key]
|
||||
if value == nil {
|
||||
continue
|
||||
}
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
|
||||
switch key {
|
||||
case "type":
|
||||
if s, ok := value.(string); ok {
|
||||
sb.WriteString("type:" + g4Q + strings.ToUpper(s) + g4Q)
|
||||
}
|
||||
default:
|
||||
sb.WriteString(key + ":" + r.formatArgValue(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) formatToolCall(tc api.ToolCall) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<|tool_call>call:" + tc.Function.Name + "{")
|
||||
|
||||
keys := make([]string, 0, tc.Function.Arguments.Len())
|
||||
for k := range tc.Function.Arguments.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
value, _ := tc.Function.Arguments.Get(key)
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(value))
|
||||
}
|
||||
|
||||
sb.WriteString("}<tool_call|>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) formatArgValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return g4Q + v + g4Q
|
||||
case bool:
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
case float64:
|
||||
if v == float64(int64(v)) {
|
||||
return fmt.Sprintf("%d", int64(v))
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
case int, int64, int32:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case map[string]any:
|
||||
return r.formatMapValue(v)
|
||||
case []any:
|
||||
return r.formatArrayValue(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) formatMapValue(m map[string]any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("{")
|
||||
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(m[key]))
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) formatArrayValue(arr []any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[")
|
||||
for i, item := range arr {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(r.formatArgValue(item))
|
||||
}
|
||||
sb.WriteString("]")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
1272
model/renderers/gemma4_reference_test.go
Normal file
1272
model/renderers/gemma4_reference_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -81,6 +81,8 @@ func rendererForName(name string) Renderer {
|
||||
return renderer
|
||||
case "nemotron-3-nano":
|
||||
return &Nemotron3NanoRenderer{}
|
||||
case "gemma4":
|
||||
return &Gemma4Renderer{useImgTags: RenderImgTags}
|
||||
case "functiongemma":
|
||||
return &FunctionGemmaRenderer{}
|
||||
case "glm-4.7":
|
||||
|
||||
263
model/renderers/testdata/gemma4_chat_template.jinja2
vendored
Normal file
263
model/renderers/testdata/gemma4_chat_template.jinja2
vendored
Normal file
@@ -0,0 +1,263 @@
|
||||
{%- macro format_parameters(properties, required) -%}
|
||||
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in properties | dictsort -%}
|
||||
{%- set add_comma = false -%}
|
||||
{%- if key not in standard_keys -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{ key }}:{
|
||||
{%- if value['description'] -%}
|
||||
description:<|"|>{{ value['description'] }}<|"|>
|
||||
{%- set add_comma = true -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'OBJECT' -%}
|
||||
,properties:{
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
{%- elif value is mapping -%}
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- if value['required'] -%}
|
||||
,required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
,items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'] | dictsort -%}
|
||||
{%- if item_value is not none -%}
|
||||
{%- if ns_items.found_first %},{% endif -%}
|
||||
{%- set ns_items.found_first = true -%}
|
||||
{%- if item_key == 'properties' -%}
|
||||
properties:{
|
||||
{%- if item_value is mapping -%}
|
||||
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- elif item_key == 'required' -%}
|
||||
required:[
|
||||
{%- for req_item in item_value -%}
|
||||
<|"|>{{- req_item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- elif item_key == 'type' -%}
|
||||
{%- if item_value is string -%}
|
||||
type:{{ format_argument(item_value | upper) }}
|
||||
{%- else -%}
|
||||
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ item_key }}:{{ format_argument(item_value) }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
type:<|"|>{{ value['type'] | upper }}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_function_declaration(tool_data) -%}
|
||||
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
|
||||
{%- set params = tool_data['function']['parameters'] -%}
|
||||
{%- if params -%}
|
||||
,parameters:{
|
||||
{%- if params['properties'] -%}
|
||||
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||
{%- endif -%}
|
||||
{%- if params['required'] -%}
|
||||
required:[
|
||||
{%- for item in params['required'] -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{{- ',' if not loop.last -}}
|
||||
{%- endfor -%}
|
||||
],
|
||||
{%- endif -%}
|
||||
{%- if params['type'] -%}
|
||||
type:<|"|>{{- params['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if 'response' in tool_data['function'] -%}
|
||||
{%- set response_declaration = tool_data['function']['response'] -%}
|
||||
,response:{
|
||||
{%- if response_declaration['description'] -%}
|
||||
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
|
||||
{%- endif -%}
|
||||
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
|
||||
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_argument(argument, escape_keys=True) -%}
|
||||
{%- if argument is string -%}
|
||||
{{- '<|"|>' + argument + '<|"|>' -}}
|
||||
{%- elif argument is boolean -%}
|
||||
{{- 'true' if argument else 'false' -}}
|
||||
{%- elif argument is mapping -%}
|
||||
{{- '{' -}}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in argument | dictsort -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{%- if escape_keys -%}
|
||||
{{- '<|"|>' + key + '<|"|>' -}}
|
||||
{%- else -%}
|
||||
{{- key -}}
|
||||
{%- endif -%}
|
||||
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- elif argument is sequence -%}
|
||||
{{- '[' -}}
|
||||
{%- for item in argument -%}
|
||||
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']' -}}
|
||||
{%- else -%}
|
||||
{{- argument -}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro strip_thinking(text) -%}
|
||||
{%- set ns = namespace(result='') -%}
|
||||
{%- for part in text.split('<channel|>') -%}
|
||||
{%- if '<|channel>' in part -%}
|
||||
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
|
||||
{%- else -%}
|
||||
{%- set ns.result = ns.result + part -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{ bos_token }}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- messages[0]['content'] | trim -}}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if tools -%}
|
||||
{%- for tool in tools %}
|
||||
{{- '<|tool>' -}}
|
||||
{{- format_function_declaration(tool) | trim -}}
|
||||
{{- '<tool|>' -}}
|
||||
{%- endfor %}
|
||||
{%- set ns.prev_message_type = 'tool' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif %}
|
||||
|
||||
{#- Loop through messages -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
|
||||
{{- '<|turn>' + role + '\n' }}
|
||||
|
||||
{%- if message['tool_calls'] -%}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- set function = tool_call['function'] -%}
|
||||
{{- '<|tool_call>call:' + function['name'] + '{' -}}
|
||||
{%- if function['arguments'] is mapping -%}
|
||||
{%- set ns_args = namespace(found_first=false) -%}
|
||||
{%- for key, value in function['arguments'] | dictsort -%}
|
||||
{%- if ns_args.found_first %},{% endif -%}
|
||||
{%- set ns_args.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{%- elif function['arguments'] is string -%}
|
||||
{{- function['arguments'] -}}
|
||||
{%- endif -%}
|
||||
{{- '}<tool_call|>' -}}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_responses'] -%}
|
||||
{#- Tool Response handling -#}
|
||||
{%- for tool_response in message['tool_responses'] -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if tool_response['response'] is mapping -%}
|
||||
{{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}
|
||||
{%- for key, value in tool_response['response'] | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['content'] is string -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(message['content']) -}}
|
||||
{%- else -%}
|
||||
{{- message['content'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is sequence -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'text' -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(item['text']) -}}
|
||||
{%- else -%}
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '\n\n<|image|>\n\n' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '\n\n<|video|>\n\n' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if not (message['tool_responses'] and not message['content']) -%}
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||
{{- '<|turn>model\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -522,6 +522,20 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
||||
case "input_audio":
|
||||
audioMap, ok := data["input_audio"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid input_audio format")
|
||||
}
|
||||
b64Data, ok := audioMap["data"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid input_audio format: missing data")
|
||||
}
|
||||
audioBytes, err := base64.StdEncoding.DecodeString(b64Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input_audio base64 data: %w", err)
|
||||
}
|
||||
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{audioBytes}})
|
||||
default:
|
||||
return nil, errors.New("invalid message format")
|
||||
}
|
||||
@@ -824,6 +838,45 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
|
||||
}
|
||||
}
|
||||
|
||||
// TranscriptionResponse is the response format for /v1/audio/transcriptions.
|
||||
type TranscriptionResponse struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// TranscriptionRequest holds parsed fields from the multipart form.
|
||||
type TranscriptionRequest struct {
|
||||
Model string
|
||||
AudioData []byte
|
||||
ResponseFormat string // "json", "text", "verbose_json"
|
||||
Language string
|
||||
Prompt string
|
||||
}
|
||||
|
||||
// FromTranscriptionRequest converts a transcription request into a ChatRequest
|
||||
// by wrapping the audio with a system prompt for transcription.
|
||||
func FromTranscriptionRequest(r TranscriptionRequest) (*api.ChatRequest, error) {
|
||||
systemPrompt := "Transcribe the following audio exactly as spoken. Output only the transcription text, nothing else."
|
||||
if r.Language != "" {
|
||||
systemPrompt += " The audio is in " + r.Language + "."
|
||||
}
|
||||
if r.Prompt != "" {
|
||||
systemPrompt += " Context: " + r.Prompt
|
||||
}
|
||||
|
||||
stream := true
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: "Transcribe this audio.", Images: []api.ImageData{r.AudioData}},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImageEditRequest is an OpenAI-compatible image edit request.
|
||||
type ImageEditRequest struct {
|
||||
Model string `json:"model"`
|
||||
|
||||
@@ -390,3 +390,48 @@ func (t *Terminal) Read() (rune, error) {
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// SetRawModeOn enables raw terminal mode and keeps it on.
|
||||
// Call SetRawModeOff to restore when done.
|
||||
func (i *Instance) SetRawModeOn() error {
|
||||
if i.Terminal.rawmode {
|
||||
return nil
|
||||
}
|
||||
fd := os.Stdin.Fd()
|
||||
termios, err := SetRawMode(fd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.Terminal.rawmode = true
|
||||
i.Terminal.termios = termios
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRawModeOff restores the terminal to its previous mode.
|
||||
func (i *Instance) SetRawModeOff() {
|
||||
if !i.Terminal.rawmode {
|
||||
return
|
||||
}
|
||||
fd := os.Stdin.Fd()
|
||||
//nolint:errcheck
|
||||
UnsetRawMode(fd, i.Terminal.termios)
|
||||
i.Terminal.rawmode = false
|
||||
}
|
||||
|
||||
// ReadRaw reads a single rune. If the terminal is already in raw mode
|
||||
// (via SetRawModeOn), it reads directly. Otherwise it temporarily enters
|
||||
// raw mode for the read.
|
||||
func (i *Instance) ReadRaw() (rune, error) {
|
||||
if !i.Terminal.rawmode {
|
||||
fd := os.Stdin.Fd()
|
||||
termios, err := SetRawMode(fd)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
//nolint:errcheck
|
||||
UnsetRawMode(fd, termios)
|
||||
}()
|
||||
}
|
||||
return i.Terminal.Read()
|
||||
}
|
||||
|
||||
@@ -1258,6 +1258,12 @@ func (s *Server) loadModel() {
|
||||
panic(fmt.Errorf("failed to load model: %v", err))
|
||||
}
|
||||
|
||||
if postLoader, ok := s.model.(model.PostLoader); ok {
|
||||
if err := postLoader.PostLoad(); err != nil {
|
||||
panic(fmt.Errorf("failed to finalize model initialization: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
s.status = llm.ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "" || len(config.Capabilities) == 0) {
|
||||
mf, mErr := manifest.ParseNamedManifest(fromName)
|
||||
if mErr == nil && mf.Config.Digest != "" {
|
||||
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
|
||||
@@ -158,6 +158,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
if config.Requires == "" {
|
||||
config.Requires = baseConfig.Requires
|
||||
}
|
||||
if len(config.Capabilities) == 0 {
|
||||
config.Capabilities = baseConfig.Capabilities
|
||||
}
|
||||
}
|
||||
cfgFile.Close()
|
||||
}
|
||||
@@ -509,6 +512,24 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount()))
|
||||
config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String())
|
||||
config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture())
|
||||
|
||||
// Auto-detect renderer, parser, and stop tokens from GGUF architecture.
|
||||
// TODO: abstract this into a registry/lookup table when multiple models
|
||||
// need architecture-based renderer/parser/stop defaults.
|
||||
if config.Renderer == "" || config.Parser == "" {
|
||||
arch := layer.GGML.KV().Architecture()
|
||||
switch arch {
|
||||
case "gemma4":
|
||||
config.Renderer = cmp.Or(config.Renderer, "gemma4")
|
||||
config.Parser = cmp.Or(config.Parser, "gemma4")
|
||||
if _, ok := r.Parameters["stop"]; !ok {
|
||||
if r.Parameters == nil {
|
||||
r.Parameters = make(map[string]any)
|
||||
}
|
||||
r.Parameters["stop"] = []string{"<turn|>"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
layers = append(layers, layer.Layer)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ var (
|
||||
errCapabilityTools = errors.New("tools")
|
||||
errCapabilityInsert = errors.New("insert")
|
||||
errCapabilityVision = errors.New("vision")
|
||||
errCapabilityAudio = errors.New("audio")
|
||||
errCapabilityEmbedding = errors.New("embedding")
|
||||
errCapabilityThinking = errors.New("thinking")
|
||||
errCapabilityImage = errors.New("image generation")
|
||||
@@ -93,14 +94,26 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
if f.KeyValue("vision.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
if f.KeyValue("audio.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityAudio)
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
}
|
||||
} else if len(m.Config.Capabilities) > 0 {
|
||||
}
|
||||
|
||||
// Also include capabilities from the model config (e.g. vision capability
|
||||
// set during creation for MLX/safetensors models).
|
||||
if len(m.Config.Capabilities) > 0 {
|
||||
for _, c := range m.Config.Capabilities {
|
||||
capabilities = append(capabilities, model.Capability(c))
|
||||
cap := model.Capability(c)
|
||||
if !slices.Contains(capabilities, cap) {
|
||||
capabilities = append(capabilities, cap)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
}
|
||||
|
||||
if len(capabilities) == 0 {
|
||||
slog.Warn("unknown capabilities for model", "model", m.Name)
|
||||
}
|
||||
|
||||
@@ -141,6 +154,14 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities = append(capabilities, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
// Temporary workaround — suppress vision/audio for gemma4 MLX models
|
||||
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
|
||||
if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" {
|
||||
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
|
||||
return c == model.CapabilityVision || c == "audio"
|
||||
})
|
||||
}
|
||||
|
||||
return capabilities
|
||||
}
|
||||
|
||||
@@ -156,6 +177,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
model.CapabilityTools: errCapabilityTools,
|
||||
model.CapabilityInsert: errCapabilityInsert,
|
||||
model.CapabilityVision: errCapabilityVision,
|
||||
model.CapabilityAudio: errCapabilityAudio,
|
||||
model.CapabilityEmbedding: errCapabilityEmbedding,
|
||||
model.CapabilityThinking: errCapabilityThinking,
|
||||
model.CapabilityImage: errCapabilityImage,
|
||||
|
||||
@@ -153,7 +153,16 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
||||
// MLA tensors need higher precision to avoid quality degradation
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
} else if strings.Contains(name, "ffn_down") {
|
||||
iLayer := qs.iFfnDown
|
||||
// For MoE models, ffn_down.weight (dense) and ffn_down_exps.weight (expert) both
|
||||
// exist per layer and should get the same useMoreBits treatment. Dense sorts before
|
||||
// expert alphabetically, so dense increments the counter and expert uses counter-1.
|
||||
var iLayer int
|
||||
if strings.Contains(name, "_exps") {
|
||||
iLayer = max(0, qs.iFfnDown-1)
|
||||
} else {
|
||||
iLayer = qs.iFfnDown
|
||||
qs.iFfnDown++
|
||||
}
|
||||
n_layer := qs.nFfnDown
|
||||
if ftype == fsggml.FileTypeQ4_K_M {
|
||||
if useMoreBits(iLayer, n_layer) {
|
||||
@@ -162,7 +171,6 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
||||
} else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
qs.iFfnDown++
|
||||
} else if strings.Contains(name, "attn_output.weight") {
|
||||
if nExperts == 8 {
|
||||
if ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
|
||||
@@ -255,8 +263,9 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
name := t.Name
|
||||
quantize := strings.HasSuffix(name, "weight")
|
||||
|
||||
// don't quantize vision encoder tensors (named with "v." prefix)
|
||||
// don't quantize vision or audio encoder tensors
|
||||
quantize = quantize && !strings.HasPrefix(name, "v.")
|
||||
quantize = quantize && !strings.HasPrefix(name, "a.")
|
||||
quantize = quantize && !strings.Contains(name, "mm.")
|
||||
|
||||
// quantize only 2D and 3D tensors (experts)
|
||||
|
||||
@@ -1718,6 +1718,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
// OpenAI-compatible audio endpoint
|
||||
r.POST("/v1/audio/transcriptions", middleware.TranscriptionMiddleware(), s.ChatHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
|
||||
|
||||
@@ -10,6 +10,7 @@ const (
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityImage = Capability("image")
|
||||
CapabilityAudio = Capability("audio")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
||||
73
video/video.go
Normal file
73
video/video.go
Normal file
@@ -0,0 +1,73 @@
|
||||
// Package video extracts frames and audio from video files.
|
||||
//
|
||||
// Video files are decomposed into JPEG image frames and a WAV audio track.
|
||||
// The frames and audio can then be fed into a multimodal model's vision
|
||||
// and audio encoders respectively.
|
||||
//
|
||||
// Platform-specific implementations:
|
||||
// - macOS: AVFoundation (system framework, zero external deps)
|
||||
// - Windows: Media Foundation (pure Go via syscall, zero external deps)
|
||||
// - Linux: shells out to ffmpeg (must be installed)
|
||||
package video
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Result holds extracted frames and optional audio from a video file.
type Result struct {
	// Frames are JPEG-encoded image frames in temporal order.
	Frames [][]byte
	// Audio is a complete 16kHz mono WAV container (RIFF header + PCM);
	// nil when the source has no audio track or audio was not requested.
	Audio []byte
}
|
||||
|
||||
// Options controls frame extraction behavior.
type Options struct {
	// MaxFrames caps how many frames are extracted. Extract substitutes
	// a default of 4 when this is <= 0 and clamps values above 64.
	MaxFrames int
	// ExtractAudio requests extraction of the audio track as WAV.
	ExtractAudio bool
}
|
||||
|
||||
// Extract reads a video file and returns extracted frames and audio.
|
||||
func Extract(path string, opts Options) (*Result, error) {
|
||||
if opts.MaxFrames <= 0 {
|
||||
opts.MaxFrames = 4
|
||||
}
|
||||
if opts.MaxFrames > 64 {
|
||||
opts.MaxFrames = 64
|
||||
}
|
||||
return extract(path, opts)
|
||||
}
|
||||
|
||||
// IsVideo reports whether the content sniffs as a video file. Detection
// runs http.DetectContentType over the first 512 bytes, so inputs
// shorter than that are never classified as video.
func IsVideo(data []byte) bool {
	const sniffLen = 512
	if len(data) < sniffLen {
		return false
	}
	// Any "video/*" MIME type counts (same test as IsVideoContentType).
	return strings.HasPrefix(http.DetectContentType(data[:sniffLen]), "video/")
}
|
||||
|
||||
// IsVideoContentType returns true if the MIME type is a video type.
|
||||
func IsVideoContentType(contentType string) bool {
|
||||
return strings.HasPrefix(contentType, "video/")
|
||||
}
|
||||
|
||||
// VideoExtensions lists recognized video file extensions (lowercase,
// with the leading dot).
var VideoExtensions = []string{
	".mp4", ".webm", ".mov", ".avi", ".mkv", ".m4v", ".wmv", ".flv",
}

// IsVideoExtension reports whether ext (with dot, any case) is a
// recognized video format extension.
func IsVideoExtension(ext string) bool {
	lowered := strings.ToLower(ext)
	for _, candidate := range VideoExtensions {
		if lowered == candidate {
			return true
		}
	}
	return false
}
|
||||
|
||||
// ErrFFmpegNotFound is returned on Linux when ffmpeg is not installed.
// The message carries distro-specific install hints; callers can match
// it with errors.Is since it is a package-level sentinel.
// NOTE(review): fmt.Errorf is used with no format verbs — errors.New is
// the conventional constructor (staticcheck S1039); switching would
// leave "fmt" unused in this file, so flagged rather than changed here.
var ErrFFmpegNotFound = fmt.Errorf("video support requires ffmpeg; install it with: sudo apt install ffmpeg (Debian/Ubuntu) or sudo dnf install ffmpeg (Fedora/RHEL)")
|
||||
118
video/video_darwin.go
Normal file
118
video/video_darwin.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package video
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -framework AVFoundation -framework CoreMedia -framework CoreGraphics -framework CoreVideo -framework Foundation -framework ImageIO -framework UniformTypeIdentifiers
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// Extract frames from a video file using AVFoundation.
|
||||
// Returns JPEG data for each frame concatenated, with offsets/sizes in the out arrays.
|
||||
// Audio is extracted as 16kHz mono PCM (int16).
|
||||
int extract_video_frames(
|
||||
const char* path,
|
||||
int max_frames,
|
||||
int extract_audio,
|
||||
// Frame output: caller provides buffers, function fills them
|
||||
uint8_t** frame_data, // out: array of frame JPEG pointers (caller frees each)
|
||||
int* frame_sizes, // out: array of frame JPEG sizes
|
||||
int* num_frames, // out: actual number of frames extracted
|
||||
// Audio output
|
||||
uint8_t** audio_data, // out: PCM int16 data (caller frees)
|
||||
int* audio_size // out: PCM data size in bytes
|
||||
);
|
||||
|
||||
void free_ptr(void* p);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// extract is the macOS implementation of Extract. It calls the
// AVFoundation-based helper extract_video_frames (video_darwin.m) via
// cgo, then copies every returned buffer into Go-managed memory and
// frees the C allocations.
func extract(path string, opts Options) (*Result, error) {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))

	maxFrames := C.int(opts.MaxFrames)
	extractAudio := C.int(0)
	if opts.ExtractAudio {
		extractAudio = 1
	}

	// Allocate output arrays; the C side fills at most opts.MaxFrames
	// entries and reports the actual count via numFrames.
	// NOTE(review): assumes opts.MaxFrames >= 1 (Extract guarantees
	// this) — &frameData[0] below would panic on a zero-length slice.
	frameData := make([]*C.uint8_t, opts.MaxFrames)
	frameSizes := make([]C.int, opts.MaxFrames)
	var numFrames C.int

	var audioData *C.uint8_t
	var audioSize C.int

	rc := C.extract_video_frames(
		cPath,
		maxFrames,
		extractAudio,
		(**C.uint8_t)(unsafe.Pointer(&frameData[0])),
		(*C.int)(unsafe.Pointer(&frameSizes[0])),
		&numFrames,
		&audioData,
		&audioSize,
	)
	if rc != 0 {
		return nil, fmt.Errorf("video extraction failed (code %d)", rc)
	}

	result := &Result{}

	// Copy each frame into a Go slice and free the malloc'd C buffer.
	for i := 0; i < int(numFrames); i++ {
		if frameData[i] != nil && frameSizes[i] > 0 {
			size := int(frameSizes[i])
			data := C.GoBytes(unsafe.Pointer(frameData[i]), C.int(size))
			result.Frames = append(result.Frames, data)
			C.free_ptr(unsafe.Pointer(frameData[i]))
		}
	}

	// Copy the raw PCM (16kHz mono int16, per the output settings in
	// video_darwin.m) and wrap it in a WAV header for downstream use.
	if audioData != nil && audioSize > 0 {
		pcm := C.GoBytes(unsafe.Pointer(audioData), audioSize)
		C.free_ptr(unsafe.Pointer(audioData))
		result.Audio = wrapPCMAsWAV(pcm, 16000, 1, 16)
	}

	return result, nil
}
|
||||
|
||||
// wrapPCMAsWAV wraps raw PCM int16 data in a WAV header.
|
||||
func wrapPCMAsWAV(pcm []byte, sampleRate, channels, bitsPerSample int) []byte {
|
||||
var buf bytes.Buffer
|
||||
dataSize := len(pcm)
|
||||
fileSize := 36 + dataSize
|
||||
|
||||
// RIFF header
|
||||
buf.WriteString("RIFF")
|
||||
binary.Write(&buf, binary.LittleEndian, int32(fileSize))
|
||||
buf.WriteString("WAVE")
|
||||
|
||||
// fmt chunk
|
||||
buf.WriteString("fmt ")
|
||||
binary.Write(&buf, binary.LittleEndian, int32(16)) // chunk size
|
||||
binary.Write(&buf, binary.LittleEndian, int16(1)) // PCM format
|
||||
binary.Write(&buf, binary.LittleEndian, int16(channels))
|
||||
binary.Write(&buf, binary.LittleEndian, int32(sampleRate))
|
||||
byteRate := sampleRate * channels * bitsPerSample / 8
|
||||
binary.Write(&buf, binary.LittleEndian, int32(byteRate))
|
||||
blockAlign := channels * bitsPerSample / 8
|
||||
binary.Write(&buf, binary.LittleEndian, int16(blockAlign))
|
||||
binary.Write(&buf, binary.LittleEndian, int16(bitsPerSample))
|
||||
|
||||
// data chunk
|
||||
buf.WriteString("data")
|
||||
binary.Write(&buf, binary.LittleEndian, int32(dataSize))
|
||||
buf.Write(pcm)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
154
video/video_darwin.m
Normal file
154
video/video_darwin.m
Normal file
@@ -0,0 +1,154 @@
|
||||
#import <AVFoundation/AVFoundation.h>
|
||||
#import <CoreGraphics/CoreGraphics.h>
|
||||
#import <CoreMedia/CoreMedia.h>
|
||||
#import <ImageIO/ImageIO.h>
|
||||
#import <UniformTypeIdentifiers/UniformTypeIdentifiers.h>
|
||||
#import <Foundation/Foundation.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// free_ptr releases a buffer previously malloc'd by
// extract_video_frames; exported so the Go side can free C allocations.
void free_ptr(void* p) {
    free(p);
}
|
||||
|
||||
// Convert a CGImage to JPEG data (lossy quality 0.85) via ImageIO.
// Returns nil if no image destination can be created; the returned
// NSData is autoreleased.
// NOTE(review): CGImageDestinationFinalize's result is unchecked — a
// failed finalize would yield truncated/empty data; confirm acceptable.
static NSData* cgImageToJPEG(CGImageRef image) {
    NSMutableData *data = [NSMutableData data];
    CGImageDestinationRef dest = CGImageDestinationCreateWithData(
        (__bridge CFMutableDataRef)data, (__bridge CFStringRef)UTTypeJPEG.identifier, 1, NULL);
    if (!dest) return nil;

    NSDictionary *props = @{(__bridge NSString *)kCGImageDestinationLossyCompressionQuality: @(0.85)};
    CGImageDestinationAddImage(dest, image, (__bridge CFDictionaryRef)props);
    CGImageDestinationFinalize(dest);
    CFRelease(dest);
    return data;
}
|
||||
|
||||
// extract_video_frames decodes up to max_frames evenly spaced JPEG
// frames — and, when extract_audio is nonzero, a 16kHz mono int16 PCM
// audio track — from the video at path using AVFoundation.
//
// frame_data/frame_sizes must have room for max_frames entries. Every
// frame buffer and *audio_data is malloc'd here; the caller releases
// them via free_ptr. Returns 0 on success, -1 when the asset has no
// valid positive duration.
int extract_video_frames(
    const char* path,
    int max_frames,
    int extract_audio,
    uint8_t** frame_data,
    int* frame_sizes,
    int* num_frames,
    uint8_t** audio_data,
    int* audio_size)
{
    @autoreleasepool {
        // Initialize outputs so partial failure never leaves garbage.
        *num_frames = 0;
        *audio_size = 0;
        *audio_data = NULL;

        NSString *filePath = [NSString stringWithUTF8String:path];
        NSURL *fileURL = [NSURL fileURLWithPath:filePath];
        AVURLAsset *asset = [AVURLAsset URLAssetWithURL:fileURL options:nil];

        // Get video duration; bail out if the asset is unreadable.
        CMTime duration = asset.duration;
        if (CMTIME_IS_INVALID(duration) || CMTimeGetSeconds(duration) <= 0) {
            return -1;
        }
        Float64 durationSecs = CMTimeGetSeconds(duration);

        // Create image generator with exact (zero-tolerance) seeking so
        // frames land at the requested timestamps.
        AVAssetImageGenerator *generator = [[AVAssetImageGenerator alloc]
            initWithAsset:asset];
        generator.appliesPreferredTrackTransform = YES;
        generator.requestedTimeToleranceBefore = kCMTimeZero;
        generator.requestedTimeToleranceAfter = kCMTimeZero;

        // Calculate frame times evenly spaced across duration; never
        // request more than one frame per second of footage.
        int frameCount = max_frames;
        if (durationSecs < frameCount) {
            frameCount = (int)durationSecs;
        }
        if (frameCount < 1) frameCount = 1;

        // Extract frames using synchronous API.
        // Note: copyCGImageAtTime: is deprecated in macOS 15 in favor of the
        // async generateCGImagesAsynchronouslyForTimes:, but the async API
        // is incompatible with CGo (callbacks on arbitrary threads). The sync
        // API remains functional and is the safest approach for CGo callers.
        int extracted = 0;
        for (int i = 0; i < frameCount; i++) {
            Float64 t = (durationSecs * i) / frameCount;
            CMTime requestTime = CMTimeMakeWithSeconds(t, 600);
            CMTime actualTime;
            NSError *error = nil;

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
            CGImageRef cgImage = [generator copyCGImageAtTime:requestTime
                                                   actualTime:&actualTime
                                                        error:&error];
#pragma clang diagnostic pop
            // Skip frames that fail to decode rather than aborting.
            if (!cgImage) continue;

            NSData *jpegData = cgImageToJPEG(cgImage);
            CGImageRelease(cgImage);

            if (!jpegData || jpegData.length == 0) continue;

            // NOTE(review): malloc result is unchecked here; an OOM
            // would crash on memcpy — confirm acceptable for this path.
            uint8_t *buf = (uint8_t *)malloc(jpegData.length);
            memcpy(buf, jpegData.bytes, jpegData.length);
            frame_data[extracted] = buf;
            frame_sizes[extracted] = (int)jpegData.length;
            extracted++;
        }
        *num_frames = extracted;

        // Extract audio if requested. Failures here are silent: the
        // function still returns 0 with *audio_size == 0.
        if (extract_audio) {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
            NSArray<AVAssetTrack *> *audioTracks =
                [asset tracksWithMediaType:AVMediaTypeAudio];
#pragma clang diagnostic pop

            if (audioTracks.count > 0) {
                NSError *error = nil;
                AVAssetReader *reader = [[AVAssetReader alloc]
                    initWithAsset:asset error:&error];
                if (reader) {
                    // Decode to 16kHz mono 16-bit little-endian PCM —
                    // must match wrapPCMAsWAV's parameters on the Go side.
                    NSDictionary *settings = @{
                        AVFormatIDKey: @(kAudioFormatLinearPCM),
                        AVSampleRateKey: @(16000),
                        AVNumberOfChannelsKey: @(1),
                        AVLinearPCMBitDepthKey: @(16),
                        AVLinearPCMIsFloatKey: @(NO),
                        AVLinearPCMIsBigEndianKey: @(NO),
                    };
                    AVAssetReaderTrackOutput *output =
                        [[AVAssetReaderTrackOutput alloc]
                            initWithTrack:audioTracks[0]
                           outputSettings:settings];
                    [reader addOutput:output];

                    if ([reader startReading]) {
                        NSMutableData *pcmData = [NSMutableData data];
                        CMSampleBufferRef sampleBuffer;
                        // NOTE(review): CMSampleBufferGetDataBuffer may
                        // return NULL for a corrupt sample, which would
                        // crash CMBlockBufferGetDataLength — confirm.
                        while ((sampleBuffer = [output copyNextSampleBuffer])) {
                            CMBlockBufferRef blockBuffer =
                                CMSampleBufferGetDataBuffer(sampleBuffer);
                            size_t length = CMBlockBufferGetDataLength(blockBuffer);
                            uint8_t *tmp = (uint8_t *)malloc(length);
                            CMBlockBufferCopyDataBytes(blockBuffer, 0, length, tmp);
                            [pcmData appendBytes:tmp length:length];
                            free(tmp);
                            CFRelease(sampleBuffer);
                        }

                        if (pcmData.length > 0) {
                            *audio_data = (uint8_t *)malloc(pcmData.length);
                            memcpy(*audio_data, pcmData.bytes, pcmData.length);
                            *audio_size = (int)pcmData.length;
                        }
                    }
                }
            }
        }

        return 0;
    }
}
|
||||
48
video/video_darwin_test.go
Normal file
48
video/video_darwin_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestExtractRealVideo exercises the full platform extraction path
// against a local fixture video. It is skipped when the fixture is
// absent (e.g. CI), and validates the JPEG SOI magic on each frame and
// the RIFF/WAVE header on the audio.
func TestExtractRealVideo(t *testing.T) {
	// Skip if test video doesn't exist (CI environments)
	testVideo := ".tmp/test_video.mp4"
	// Try repo root first, then one directory up.
	if _, err := os.Stat(testVideo); err != nil {
		testVideo = "../.tmp/test_video.mp4"
		if _, err := os.Stat(testVideo); err != nil {
			t.Skip("test video not available")
		}
	}

	result, err := Extract(testVideo, Options{MaxFrames: 2, ExtractAudio: true})
	if err != nil {
		t.Fatalf("Extract failed: %v", err)
	}

	if len(result.Frames) == 0 {
		t.Fatal("no frames extracted")
	}
	if len(result.Frames) > 2 {
		t.Errorf("expected at most 2 frames, got %d", len(result.Frames))
	}

	// Verify frames are valid JPEG (start with the FF D8 SOI marker).
	for i, frame := range result.Frames {
		if len(frame) < 2 || frame[0] != 0xFF || frame[1] != 0xD8 {
			t.Errorf("frame %d is not valid JPEG (first bytes: %x)", i, frame[:min(4, len(frame))])
		}
	}

	// Verify audio is valid WAV: at least a 44-byte header with the
	// RIFF....WAVE signature.
	if result.Audio == nil {
		t.Error("expected audio but got nil")
	} else if len(result.Audio) < 44 {
		t.Error("audio WAV too short")
	} else if string(result.Audio[:4]) != "RIFF" || string(result.Audio[8:12]) != "WAVE" {
		t.Error("audio is not valid WAV")
	}

	t.Logf("Extracted %d frames, audio %d bytes", len(result.Frames), len(result.Audio))
}
|
||||
162
video/video_linux.go
Normal file
162
video/video_linux.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func extract(path string, opts Options) (*Result, error) {
|
||||
// Check ffmpeg is available
|
||||
ffmpeg, err := exec.LookPath("ffmpeg")
|
||||
if err != nil {
|
||||
return nil, ErrFFmpegNotFound
|
||||
}
|
||||
|
||||
// Probe video duration
|
||||
duration, err := probeDuration(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to probe video: %w", err)
|
||||
}
|
||||
|
||||
frameCount := opts.MaxFrames
|
||||
if duration < float64(frameCount) {
|
||||
frameCount = int(duration)
|
||||
}
|
||||
if frameCount < 1 {
|
||||
frameCount = 1
|
||||
}
|
||||
|
||||
// Calculate FPS to get evenly spaced frames
|
||||
fps := float64(frameCount) / duration
|
||||
|
||||
result := &Result{}
|
||||
|
||||
// Extract frames as JPEG via pipe
|
||||
args := []string{
|
||||
"-i", path,
|
||||
"-vf", fmt.Sprintf("fps=%.4f", fps),
|
||||
"-frames:v", strconv.Itoa(frameCount),
|
||||
"-f", "image2pipe",
|
||||
"-c:v", "mjpeg",
|
||||
"-q:v", "5",
|
||||
"pipe:1",
|
||||
}
|
||||
|
||||
cmd := exec.Command(ffmpeg, args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("ffmpeg frame extraction failed: %s", stderr.String())
|
||||
}
|
||||
|
||||
// Split JPEG frames from the pipe output (each starts with FFD8, ends with FFD9)
|
||||
result.Frames = splitJPEGs(stdout.Bytes())
|
||||
|
||||
// Extract audio if requested
|
||||
if opts.ExtractAudio {
|
||||
audio, err := extractAudio(ffmpeg, path)
|
||||
if err == nil && len(audio) > 44 { // WAV header is 44 bytes
|
||||
result.Audio = audio
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// probeDuration uses ffprobe (or ffmpeg) to get the video duration in seconds.
func probeDuration(path string) (float64, error) {
	ffprobe, err := exec.LookPath("ffprobe")
	if err != nil {
		// Fall back to "ffmpeg -i", which prints "Duration: ..." to stderr.
		// NOTE(review): if ffmpeg is also absent, LookPath's error is
		// discarded and exec.Command runs with an empty path, so the
		// caller only sees a parse failure — confirm this is intended.
		ffmpeg, _ := exec.LookPath("ffmpeg")
		cmd := exec.Command(ffmpeg, "-i", path)
		var stderr bytes.Buffer
		cmd.Stderr = &stderr
		cmd.Run() // Ignore error — ffmpeg -i with no output always exits non-zero
		return parseDurationFromFFmpeg(stderr.String())
	}

	// ffprobe prints the duration in seconds as a bare CSV value.
	cmd := exec.Command(ffprobe,
		"-v", "quiet",
		"-show_entries", "format=duration",
		"-of", "csv=p=0",
		path,
	)
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	if err := cmd.Run(); err != nil {
		return 0, err
	}

	return strconv.ParseFloat(strings.TrimSpace(stdout.String()), 64)
}
|
||||
|
||||
// parseDurationFromFFmpeg extracts the media duration, in seconds, from
// the stderr output of "ffmpeg -i", which contains a line of the form
// "  Duration: HH:MM:SS.ss, start: ...".
//
// Returns an error when no duration line is present or the field does
// not parse as three numeric colon-separated components.
func parseDurationFromFFmpeg(output string) (float64, error) {
	// Look for "Duration: HH:MM:SS.ss"
	idx := strings.Index(output, "Duration: ")
	if idx < 0 {
		return 0, fmt.Errorf("could not find duration in ffmpeg output")
	}
	durStr := output[idx+len("Duration: "):]
	// The duration field is comma-terminated; trim everything after it.
	if commaIdx := strings.Index(durStr, ","); commaIdx > 0 {
		durStr = durStr[:commaIdx]
	}
	durStr = strings.TrimSpace(durStr)

	// Parse HH:MM:SS.ss. Previously strconv errors were silently
	// discarded, so garbage like "Duration: N/A" variants could yield
	// a bogus 0 duration with a nil error; surface them instead.
	parts := strings.Split(durStr, ":")
	if len(parts) != 3 {
		return 0, fmt.Errorf("unexpected duration format: %s", durStr)
	}
	var fields [3]float64
	for i, p := range parts {
		v, err := strconv.ParseFloat(p, 64)
		if err != nil {
			return 0, fmt.Errorf("unexpected duration format: %s", durStr)
		}
		fields[i] = v
	}
	return fields[0]*3600 + fields[1]*60 + fields[2], nil
}
|
||||
|
||||
// extractAudio extracts audio as 16kHz mono WAV.
// ffmpeg writes a complete WAV container (header + PCM) to stdout.
func extractAudio(ffmpeg, path string) ([]byte, error) {
	args := []string{
		"-i", path,
		"-ar", "16000", // resample to 16kHz
		"-ac", "1", // downmix to mono
		"-f", "wav",
		"pipe:1",
	}

	cmd := exec.Command(ffmpeg, args...)
	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	if err := cmd.Run(); err != nil {
		return nil, fmt.Errorf("ffmpeg audio extraction failed: %s", stderr.String())
	}

	return stdout.Bytes(), nil
}
|
||||
|
||||
// splitJPEGs splits concatenated JPEG data into individual images.
// Each JPEG begins with the SOI marker FF D8 and ends with EOI FF D9.
// NOTE: this is a byte-level heuristic; an FF D8 / FF D9 byte pair
// occurring inside entropy-coded data would be misread as a boundary.
func splitJPEGs(data []byte) [][]byte {
	var frames [][]byte
	begin := -1

	for i := 0; i+1 < len(data); i++ {
		if data[i] != 0xFF {
			continue
		}
		switch data[i+1] {
		case 0xD8: // SOI: (re)mark the start of a frame
			begin = i
		case 0xD9: // EOI: close the open frame, including the marker
			if begin >= 0 {
				frames = append(frames, data[begin:i+2])
				begin = -1
			}
		}
	}

	return frames
}
|
||||
131
video/video_test.go
Normal file
131
video/video_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestIsVideo covers the negative paths of IsVideo: non-video bytes,
// too-short inputs, and nil all report false.
func TestIsVideo(t *testing.T) {
	tests := []struct {
		name string
		data []byte
		want bool
	}{
		{
			name: "mp4 header",
			// NOTE(review): despite the name, this is 512 zero bytes —
			// no mp4 magic is ever written, so detection reports false.
			data: make([]byte, 512),
			want: false, // zeros aren't video
		},
		{
			name: "jpeg data",
			data: append([]byte{0xFF, 0xD8, 0xFF, 0xE0}, make([]byte, 508)...),
			want: false,
		},
		{
			name: "wav data",
			// "RIFF" + 4 filler bytes + "WAVE"; under 512 bytes total.
			data: append([]byte("RIFF"), append(make([]byte, 4), []byte("WAVE")...)...),
			want: false,
		},
		{
			name: "too short",
			data: []byte{0x00, 0x01, 0x02},
			want: false,
		},
		{
			name: "nil data",
			data: nil,
			want: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := IsVideo(tt.data)
			if got != tt.want {
				t.Errorf("IsVideo() = %v, want %v", got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestIsVideoContentType table-tests the MIME-prefix check for common
// video, non-video, and degenerate content types.
func TestIsVideoContentType(t *testing.T) {
	tests := []struct {
		ct   string
		want bool
	}{
		{"video/mp4", true},
		{"video/webm", true},
		{"video/quicktime", true},
		{"video/x-msvideo", true},
		{"image/jpeg", false},
		{"audio/wave", false},
		{"application/octet-stream", false},
		{"", false},
	}

	for _, tt := range tests {
		t.Run(tt.ct, func(t *testing.T) {
			if got := IsVideoContentType(tt.ct); got != tt.want {
				t.Errorf("IsVideoContentType(%q) = %v, want %v", tt.ct, got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestIsVideoExtension table-tests extension matching, including the
// case-insensitive path (".MP4") and the empty string.
func TestIsVideoExtension(t *testing.T) {
	tests := []struct {
		ext  string
		want bool
	}{
		{".mp4", true},
		{".MP4", true}, // matching is case-insensitive
		{".webm", true},
		{".mov", true},
		{".avi", true},
		{".mkv", true},
		{".m4v", true},
		{".jpg", false},
		{".png", false},
		{".wav", false},
		{".txt", false},
		{"", false},
	}

	for _, tt := range tests {
		t.Run(tt.ext, func(t *testing.T) {
			if got := IsVideoExtension(tt.ext); got != tt.want {
				t.Errorf("IsVideoExtension(%q) = %v, want %v", tt.ext, got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestExtractDefaults verifies the zero-value Options and that Extract
// propagates an error for a missing input file.
func TestExtractDefaults(t *testing.T) {
	// Test that Options defaults are applied correctly: the zero value
	// leaves MaxFrames at 0, and Extract substitutes the default itself.
	opts := Options{}
	if opts.MaxFrames != 0 {
		t.Errorf("zero value MaxFrames should be 0, got %d", opts.MaxFrames)
	}

	// Extract with a non-existent file should error
	_, err := Extract("/nonexistent/video.mp4", Options{})
	if err == nil {
		t.Error("expected error for non-existent file")
	}
}
|
||||
|
||||
// TestIsVideoWithRealContentType probes http.DetectContentType with a
// hand-built MP4 ftyp box. It only logs the detected type — MP4
// detection varies by Go version, so no assertion is made.
func TestIsVideoWithRealContentType(t *testing.T) {
	// Verify that http.DetectContentType correctly identifies video types
	// by testing with real file magic bytes

	// MP4 ftyp box: 4 bytes size + "ftyp" + brand
	mp4Data := make([]byte, 512)
	copy(mp4Data[0:], []byte{0x00, 0x00, 0x00, 0x20}) // size=32
	copy(mp4Data[4:], "ftypisom")                     // ftyp + brand
	copy(mp4Data[12:], []byte{0x00, 0x00, 0x02, 0x00}) // minor version
	ct := http.DetectContentType(mp4Data)
	t.Logf("MP4 content type: %s", ct)
	// Note: Go's http.DetectContentType may or may not detect MP4.
	// The ftyp box detection depends on the Go version.
}
|
||||
155
video/video_windows.go
Normal file
155
video/video_windows.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Windows implementation: shell out to ffmpeg, same as Linux.
|
||||
// Media Foundation via syscall is planned but ffmpeg is simpler for v1
|
||||
// and commonly available via winget/chocolatey/scoop.
|
||||
//
|
||||
// TODO: implement Media Foundation via golang.org/x/sys/windows for
|
||||
// zero-dependency video extraction on Windows.
|
||||
|
||||
func extract(path string, opts Options) (*Result, error) {
|
||||
ffmpeg, err := exec.LookPath("ffmpeg")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("video support requires ffmpeg; install it with: winget install ffmpeg (or scoop install ffmpeg)")
|
||||
}
|
||||
|
||||
duration, err := probeDuration(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to probe video: %w", err)
|
||||
}
|
||||
|
||||
frameCount := opts.MaxFrames
|
||||
if duration < float64(frameCount) {
|
||||
frameCount = int(duration)
|
||||
}
|
||||
if frameCount < 1 {
|
||||
frameCount = 1
|
||||
}
|
||||
|
||||
fps := float64(frameCount) / duration
|
||||
|
||||
result := &Result{}
|
||||
|
||||
args := []string{
|
||||
"-i", path,
|
||||
"-vf", fmt.Sprintf("fps=%.4f", fps),
|
||||
"-frames:v", strconv.Itoa(frameCount),
|
||||
"-f", "image2pipe",
|
||||
"-c:v", "mjpeg",
|
||||
"-q:v", "5",
|
||||
"pipe:1",
|
||||
}
|
||||
|
||||
cmd := exec.Command(ffmpeg, args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("ffmpeg frame extraction failed: %s", stderr.String())
|
||||
}
|
||||
|
||||
result.Frames = splitJPEGs(stdout.Bytes())
|
||||
|
||||
if opts.ExtractAudio {
|
||||
audio, err := extractAudio(ffmpeg, path)
|
||||
if err == nil && len(audio) > 44 {
|
||||
result.Audio = audio
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// probeDuration returns the video duration in seconds, preferring
// ffprobe and falling back to parsing "ffmpeg -i" stderr output.
func probeDuration(path string) (float64, error) {
	ffprobe, err := exec.LookPath("ffprobe")
	if err != nil {
		// Fall back to "ffmpeg -i", which prints "Duration: ..." to stderr.
		// NOTE(review): if ffmpeg is also absent, LookPath's error is
		// discarded and exec.Command runs with an empty path — confirm
		// the resulting parse failure is acceptable.
		ffmpeg, _ := exec.LookPath("ffmpeg")
		cmd := exec.Command(ffmpeg, "-i", path)
		var stderr bytes.Buffer
		cmd.Stderr = &stderr
		cmd.Run() // ignore error: ffmpeg -i without output exits non-zero
		return parseDurationFromFFmpeg(stderr.String())
	}

	// ffprobe prints the duration in seconds as a bare CSV value.
	cmd := exec.Command(ffprobe,
		"-v", "quiet",
		"-show_entries", "format=duration",
		"-of", "csv=p=0",
		path,
	)
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	if err := cmd.Run(); err != nil {
		return 0, err
	}

	return strconv.ParseFloat(strings.TrimSpace(stdout.String()), 64)
}
|
||||
|
||||
// parseDurationFromFFmpeg extracts the media duration, in seconds, from
// the stderr output of "ffmpeg -i" (a line like "Duration: HH:MM:SS.ss, ...").
//
// Returns an error when no duration line is present or the field does
// not parse as three numeric colon-separated components.
func parseDurationFromFFmpeg(output string) (float64, error) {
	idx := strings.Index(output, "Duration: ")
	if idx < 0 {
		return 0, fmt.Errorf("could not find duration in ffmpeg output")
	}
	durStr := output[idx+len("Duration: "):]
	// The duration field is comma-terminated; trim everything after it.
	if commaIdx := strings.Index(durStr, ","); commaIdx > 0 {
		durStr = durStr[:commaIdx]
	}
	durStr = strings.TrimSpace(durStr)

	// Parse HH:MM:SS.ss. Previously strconv errors were silently
	// discarded, so non-numeric fields produced a bogus 0 duration with
	// a nil error; surface them instead.
	parts := strings.Split(durStr, ":")
	if len(parts) != 3 {
		return 0, fmt.Errorf("unexpected duration format: %s", durStr)
	}
	var fields [3]float64
	for i, p := range parts {
		v, err := strconv.ParseFloat(p, 64)
		if err != nil {
			return 0, fmt.Errorf("unexpected duration format: %s", durStr)
		}
		fields[i] = v
	}
	return fields[0]*3600 + fields[1]*60 + fields[2], nil
}
|
||||
|
||||
// extractAudio extracts the audio track as a complete 16kHz mono WAV
// container (header + PCM) written by ffmpeg to stdout.
func extractAudio(ffmpeg, path string) ([]byte, error) {
	args := []string{
		"-i", path,
		"-ar", "16000", // resample to 16kHz
		"-ac", "1", // downmix to mono
		"-f", "wav",
		"pipe:1",
	}

	cmd := exec.Command(ffmpeg, args...)
	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	if err := cmd.Run(); err != nil {
		return nil, fmt.Errorf("ffmpeg audio extraction failed: %s", stderr.String())
	}

	return stdout.Bytes(), nil
}
|
||||
|
||||
// splitJPEGs splits concatenated JPEG data into individual images.
// Each JPEG begins with the SOI marker FF D8 and ends with EOI FF D9.
// NOTE(review): this is a byte-level heuristic; an FF D8 / FF D9 byte
// pair inside entropy-coded data would be misread as a boundary.
func splitJPEGs(data []byte) [][]byte {
	var frames [][]byte
	start := -1

	for i := 0; i < len(data)-1; i++ {
		if data[i] == 0xFF && data[i+1] == 0xD8 {
			start = i // SOI: (re)mark the frame start
		} else if data[i] == 0xFF && data[i+1] == 0xD9 && start >= 0 {
			frames = append(frames, data[start:i+2]) // include the EOI marker
			start = -1
		}
	}

	return frames
}
|
||||
Reference in New Issue
Block a user