gemma4: add audio input support for run command

- /audio toggle in interactive mode for voice chat
- Platform-specific microphone recording (AudioToolbox AudioQueue on macOS,
  ALSA on Linux, WASAPI on Windows)
- Space to start/stop recording, automatic chunking for long audio
This commit is contained in:
Daniel Hiltgen
2026-04-01 08:33:45 -07:00
parent ebd70f73b7
commit 570c53859d
7 changed files with 1269 additions and 9 deletions

216
cmd/audio.go Normal file
View File

@@ -0,0 +1,216 @@
package cmd
import (
"encoding/binary"
"sync"
"time"
)
// Capture format requested from every platform backend via newAudioStream,
// and the format written into the WAV header by the recorder.
const (
audioSampleRate = 16000 // samples per second
audioChannels = 1 // number of channels
audioFrameSize = 1024 // samples per callback
)
// AudioRecorder captures audio from the default microphone.
// Platform-specific capture is provided by audioStream (audio_darwin.go, etc.).
// Captured samples accumulate in memory and are handed back as WAV data,
// either whole (WAV, FlushWAV) or in chunks (TakeChunk).
type AudioRecorder struct {
// stream is the capture backend created by newAudioStream.
stream audioStream
// mu guards samples and started.
mu sync.Mutex
// samples holds captured audio that has not been consumed yet.
samples []float32
// started is the time of the most recent Start call.
started time.Time
MaxChunkSeconds int // hard split limit in seconds; 0 means use default
}
// audioStream is the platform-specific audio capture interface.
// Each platform file supplies an implementation through newAudioStream.
type audioStream interface {
// Start begins capturing. Samples are delivered via the callback,
// which the implementations in this package invoke from a background
// goroutine.
Start(callback func(samples []float32)) error
// Stop ends capturing and releases resources.
Stop() error
}
// NewAudioRecorder builds a recorder backed by the platform capture stream,
// configured with the package-wide sample rate, channel count and frame size.
// It returns the backend's error unchanged if the stream cannot be created.
func NewAudioRecorder() (*AudioRecorder, error) {
	backend, err := newAudioStream(audioSampleRate, audioChannels, audioFrameSize)
	if err != nil {
		return nil, err
	}
	rec := new(AudioRecorder)
	rec.stream = backend
	return rec, nil
}
// Start resets the sample buffer, stamps the start time, and asks the
// platform stream to begin capturing. Every frame the stream delivers is
// appended to the buffer under the recorder's mutex.
func (r *AudioRecorder) Start() error {
	// collect is handed to the stream and runs wherever the stream invokes it.
	collect := func(frame []float32) {
		r.mu.Lock()
		defer r.mu.Unlock()
		r.samples = append(r.samples, frame...)
	}

	r.mu.Lock()
	defer r.mu.Unlock()
	r.started = time.Now()
	r.samples = make([]float32, 0, 30*audioSampleRate) // preallocate ~30s
	return r.stream.Start(collect)
}
// Stop ends the recording and returns how long it ran.
//
// The duration is measured from the last Start call; it is 0 if Start was
// never called (matching Duration). Any error reported by the platform
// stream's Stop is returned alongside the duration rather than discarded.
func (r *AudioRecorder) Stop() (time.Duration, error) {
	r.mu.Lock()
	var dur time.Duration
	if !r.started.IsZero() {
		dur = time.Since(r.started)
	}
	r.mu.Unlock()

	if r.stream == nil {
		return dur, nil
	}
	// The mutex is released before stopping: the stream's Stop waits for its
	// capture goroutine, which may be inside the sample callback and needs
	// r.mu to finish.
	return dur, r.stream.Stop()
}
// Duration reports the time elapsed since the most recent Start call, or 0
// if the recorder has never been started.
func (r *AudioRecorder) Duration() time.Duration {
	r.mu.Lock()
	begin := r.started
	r.mu.Unlock()

	if begin.IsZero() {
		return 0
	}
	return time.Since(begin)
}
// Chunking constants used by TakeChunk for live transcription.
const (
chunkTargetSamples = 8 * audioSampleRate // 8s — TakeChunk yields once this much (or the hard cap) is buffered
chunkMinSamples = 5 * audioSampleRate // 5s — nothing is yielded below this, and the silence scan never splits earlier
defaultMaxAudioSeconds = 28 // default hard split (just under typical 30s model cap)
silenceWindow = 800 // samples per energy window (50ms at 16kHz); the scan steps by half a window
)
// maxChunk returns the hard chunk-size limit in samples: MaxChunkSeconds when
// it is positive, otherwise defaultMaxAudioSeconds.
func (r *AudioRecorder) maxChunk() int {
	seconds := defaultMaxAudioSeconds
	if r.MaxChunkSeconds > 0 {
		seconds = r.MaxChunkSeconds
	}
	return seconds * audioSampleRate
}
// TakeChunk yields the next chunk of buffered audio as WAV bytes, or nil when
// not enough audio has accumulated.
//
// A chunk is produced once at least chunkMinSamples are buffered and the
// buffer has reached either chunkTargetSamples or the hard cap from maxChunk.
// The split point is the centre of the lowest-energy silenceWindow found
// between chunkMinSamples and the cap (scanning backwards in half-window
// steps, so ties favour the later position); if no window fits, the split is
// at the cap. Consumed samples are removed from the buffer.
func (r *AudioRecorder) TakeChunk() []byte {
	r.mu.Lock()

	buffered := len(r.samples)
	hardCap := r.maxChunk()
	if buffered < chunkMinSamples || (buffered < chunkTargetSamples && buffered < hardCap) {
		r.mu.Unlock()
		return nil
	}

	limit := buffered
	if hardCap < limit {
		limit = hardCap
	}

	const stride = silenceWindow / 2
	cut := limit
	quietest := float64(1e30)
	for start := limit - silenceWindow; start >= chunkMinSamples; start -= stride {
		stop := start + silenceWindow
		if stop > buffered {
			stop = buffered
		}
		// Mean squared amplitude of the window; only used for comparison.
		var energy float64
		for i := start; i < stop; i++ {
			energy += float64(r.samples[i]) * float64(r.samples[i])
		}
		energy /= float64(stop - start)
		if energy < quietest {
			quietest = energy
			cut = start + stride
		}
	}

	// Copy both halves so the returned chunk and the remaining buffer do not
	// share a backing array.
	head := make([]float32, cut)
	copy(head, r.samples[:cut])
	tail := make([]float32, buffered-cut)
	copy(tail, r.samples[cut:])
	r.samples = tail
	r.mu.Unlock()

	return encodeWAV(head, audioSampleRate, audioChannels)
}
// FlushWAV drains the buffer and returns whatever was in it encoded as WAV.
// It returns nil when the buffer was empty.
func (r *AudioRecorder) FlushWAV() []byte {
	var pending []float32

	r.mu.Lock()
	pending, r.samples = r.samples, nil
	r.mu.Unlock()

	if len(pending) > 0 {
		return encodeWAV(pending, audioSampleRate, audioChannels)
	}
	return nil
}
// WAV returns a WAV encoding of a snapshot of the buffered samples, leaving
// the buffer untouched. It returns errNoAudio when nothing has been captured.
func (r *AudioRecorder) WAV() ([]byte, error) {
	r.mu.Lock()
	snapshot := append(make([]float32, 0, len(r.samples)), r.samples...)
	r.mu.Unlock()

	if len(snapshot) == 0 {
		return nil, errNoAudio
	}
	return encodeWAV(snapshot, audioSampleRate, audioChannels), nil
}
// encodeWAV produces a 16-bit PCM WAV file (44-byte canonical header followed
// by little-endian samples) from float32 samples.
//
// samples holds one value per channel per frame, interleaved, as delivered by
// the capture backends; values outside [-1, 1] are clamped. sampleRate and
// channels are written to the header as given.
//
// The data size is len(samples) 16-bit values. It must not be multiplied by
// the channel count again: doing so made the header and buffer too large for
// any multi-channel input.
func encodeWAV(samples []float32, sampleRate, channels int) []byte {
	const (
		headerSize     = 44
		bitsPerSample  = 16
		bytesPerSample = bitsPerSample / 8
	)
	blockAlign := channels * bytesPerSample // bytes per frame
	byteRate := sampleRate * blockAlign
	dataSize := len(samples) * bytesPerSample

	le := binary.LittleEndian
	buf := make([]byte, headerSize+dataSize)

	// RIFF container header.
	copy(buf[0:4], "RIFF")
	le.PutUint32(buf[4:8], uint32(36+dataSize))
	copy(buf[8:12], "WAVE")

	// "fmt " sub-chunk: 16 bytes, format tag 1 (PCM).
	copy(buf[12:16], "fmt ")
	le.PutUint32(buf[16:20], 16)
	le.PutUint16(buf[20:22], 1)
	le.PutUint16(buf[22:24], uint16(channels))
	le.PutUint32(buf[24:28], uint32(sampleRate))
	le.PutUint32(buf[28:32], uint32(byteRate))
	le.PutUint16(buf[32:34], uint16(blockAlign))
	le.PutUint16(buf[34:36], bitsPerSample)

	// "data" sub-chunk.
	copy(buf[36:40], "data")
	le.PutUint32(buf[40:44], uint32(dataSize))

	pcm := buf[headerSize:]
	for i, s := range samples {
		if s > 1.0 {
			s = 1.0
		} else if s < -1.0 {
			s = -1.0
		}
		le.PutUint16(pcm[i*bytesPerSample:], uint16(int16(s*32767)))
	}
	return buf
}

180
cmd/audio_darwin.go Normal file
View File

@@ -0,0 +1,180 @@
package cmd
/*
#cgo LDFLAGS: -framework CoreAudio -framework AudioToolbox
#include <AudioToolbox/AudioQueue.h>
#include <string.h>
// Callback context passed to AudioQueue.
typedef struct {
int ready; // set to 1 when a buffer is filled
} AQContext;
// C callback — re-enqueues the buffer so recording continues.
// Not static — must be visible to the linker for Go's function pointer.
void aqInputCallback(
void *inUserData,
AudioQueueRef inAQ,
AudioQueueBufferRef inBuffer,
const AudioTimeStamp *inStartTime,
UInt32 inNumberPacketDescriptions,
const AudioStreamPacketDescription *inPacketDescs)
{
// Re-enqueue the buffer immediately so recording continues.
AudioQueueEnqueueBuffer(inAQ, inBuffer, 0, NULL);
}
*/
import "C"
import (
"fmt"
"math"
"sync"
"time"
)
// errNoAudio is returned by AudioRecorder.WAV when no samples were captured.
var errNoAudio = fmt.Errorf("no audio recorded")
// numAQBuffers is how many AudioQueue buffers Start allocates and enqueues.
const numAQBuffers = 3
// coreAudioStream is the macOS audioStream implementation. It records through
// an AudioToolbox input AudioQueue and copies samples out of the queue's
// buffers from a polling goroutine (pollLoop).
type coreAudioStream struct {
// queue is the input AudioQueue created by Start.
queue C.AudioQueueRef
// buffers are the queue buffers allocated by Start and read by pollLoop.
buffers [numAQBuffers]C.AudioQueueBufferRef
// mu guards callback and running, and is held by pollLoop while it reads buffers.
mu sync.Mutex
// callback receives the converted float32 samples.
callback func(samples []float32)
// running is cleared by Stop to make pollLoop exit.
running bool
// pollDone is closed when pollLoop returns.
pollDone chan struct{}
sampleRate int
channels int
// frameSize is the per-buffer size in frames; it also sets the poll interval.
frameSize int
}
// newAudioStream returns an unstarted macOS capture stream that remembers the
// requested format. No audio resources are acquired until Start is called,
// so this never fails.
func newAudioStream(sampleRate, channels, frameSize int) (audioStream, error) {
	s := new(coreAudioStream)
	s.sampleRate, s.channels, s.frameSize = sampleRate, channels, frameSize
	return s, nil
}
// Start creates an input AudioQueue for 16-bit signed PCM at the stream's
// sample rate and channel count, enqueues numAQBuffers buffers of frameSize
// frames each, starts recording, and launches pollLoop to deliver samples to
// callback.
//
// On any failure after the queue has been created, the queue is disposed and
// s.queue is cleared so that a later Stop does not touch the disposed queue.
func (s *coreAudioStream) Start(callback func(samples []float32)) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.callback = callback

	// Set up audio format: packed 16-bit signed integer PCM.
	var format C.AudioStreamBasicDescription
	format.mSampleRate = C.Float64(s.sampleRate)
	format.mFormatID = C.kAudioFormatLinearPCM
	format.mFormatFlags = C.kLinearPCMFormatFlagIsSignedInteger | C.kLinearPCMFormatFlagIsPacked
	format.mBitsPerChannel = 16
	format.mChannelsPerFrame = C.UInt32(s.channels)
	format.mBytesPerFrame = 2 * C.UInt32(s.channels)
	format.mFramesPerPacket = 1
	format.mBytesPerPacket = format.mBytesPerFrame

	// Create the audio queue.
	status := C.AudioQueueNewInput(
		&format,
		C.AudioQueueInputCallback(C.aqInputCallback),
		nil,               // user data
		C.CFRunLoopRef(0), // NULL run loop — use internal thread
		C.CFStringRef(0),  // NULL run loop mode
		0,                 // flags
		&s.queue,
	)
	if status != 0 {
		s.queue = nil
		return fmt.Errorf("AudioQueueNewInput failed: %d", status)
	}

	// abort tears down the half-built queue and forgets the handle.
	abort := func(op string, status C.OSStatus) error {
		C.AudioQueueDispose(s.queue, C.true)
		s.queue = nil
		return fmt.Errorf("%s failed: %d", op, status)
	}

	// Allocate and enqueue buffers.
	bufferBytes := C.UInt32(s.frameSize * int(format.mBytesPerFrame))
	for i := range s.buffers {
		if status = C.AudioQueueAllocateBuffer(s.queue, bufferBytes, &s.buffers[i]); status != 0 {
			return abort("AudioQueueAllocateBuffer", status)
		}
		if status = C.AudioQueueEnqueueBuffer(s.queue, s.buffers[i], 0, nil); status != 0 {
			return abort("AudioQueueEnqueueBuffer", status)
		}
	}

	// Start recording.
	if status = C.AudioQueueStart(s.queue, nil); status != 0 {
		return abort("AudioQueueStart", status)
	}

	s.running = true
	s.pollDone = make(chan struct{})
	// Poll buffers for data. AudioQueue re-enqueues in the C callback,
	// so we read the data out periodically.
	go s.pollLoop()
	return nil
}
// pollLoop runs on the goroutine started by Start. On every tick it walks the
// queue buffers, converts any buffer reporting a non-zero byte size from
// int16 to float32 (dividing by math.MaxInt16), passes the result to the
// callback, and zeroes that buffer's byte size. It exits, closing pollDone,
// once running has been cleared by Stop.
//
// NOTE(review): aqInputCallback re-enqueues every buffer immediately, so the
// queue appears free to refill a buffer while this loop is reading it, and a
// buffer filled twice between ticks would be seen only once — confirm whether
// samples can be torn or dropped here.
func (s *coreAudioStream) pollLoop() {
defer close(s.pollDone)
// Read at roughly frameSize intervals.
interval := time.Duration(float64(s.frameSize) / float64(s.sampleRate) * float64(time.Second))
// Never poll more often than every 10ms.
if interval < 10*time.Millisecond {
interval = 10 * time.Millisecond
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for range ticker.C {
// The lock is held for the whole pass, including the callback, so Stop
// cannot proceed to dispose the queue while buffers are being read.
s.mu.Lock()
if !s.running {
s.mu.Unlock()
return
}
// Read available data from each buffer.
for i := range s.buffers {
buf := s.buffers[i]
if buf.mAudioDataByteSize > 0 {
numSamples := int(buf.mAudioDataByteSize) / 2 // 16-bit samples
if numSamples > 0 {
// View the C buffer as an int16 slice without copying.
raw := (*[1 << 28]int16)(buf.mAudioData)[:numSamples:numSamples]
floats := make([]float32, numSamples)
for j, v := range raw {
floats[j] = float32(v) / float32(math.MaxInt16)
}
s.callback(floats)
}
// Mark the buffer as consumed so the next tick does not deliver it again.
buf.mAudioDataByteSize = 0
}
}
s.mu.Unlock()
}
}
// Stop ends capturing: it clears running so pollLoop exits, stops and
// disposes the AudioQueue, and waits for pollLoop to finish. It always
// returns nil.
//
// The queue handle is cleared under the lock, so calling Stop again (or
// calling it after a failed Start) does not stop or dispose a queue twice.
func (s *coreAudioStream) Stop() error {
	s.mu.Lock()
	s.running = false
	queue := s.queue
	s.queue = nil
	done := s.pollDone
	s.mu.Unlock()

	if queue != nil {
		C.AudioQueueStop(queue, C.true)
		C.AudioQueueDispose(queue, C.true)
	}
	if done != nil {
		<-done
	}
	return nil
}

275
cmd/audio_linux.go Normal file
View File

@@ -0,0 +1,275 @@
package cmd
/*
#cgo LDFLAGS: -ldl
#include <dlfcn.h>
#include <stdint.h>
#include <stdlib.h>
// Function pointer types for ALSA functions loaded at runtime.
typedef int (*pcm_open_fn)(void**, const char*, int, int);
typedef int (*pcm_simple_fn)(void*);
typedef long (*pcm_readi_fn)(void*, void*, unsigned long);
typedef int (*hw_malloc_fn)(void**);
typedef void (*hw_free_fn)(void*);
typedef int (*hw_any_fn)(void*, void*);
typedef int (*hw_set_int_fn)(void*, void*, int);
typedef int (*hw_set_uint_fn)(void*, void*, unsigned int);
typedef int (*hw_set_rate_fn)(void*, void*, unsigned int*, int*);
typedef int (*hw_set_period_fn)(void*, void*, unsigned long*, int*);
typedef int (*hw_apply_fn)(void*, void*);
typedef const char* (*strerror_fn)(int);
// Trampoline functions — call dynamically loaded ALSA symbols.
static int alsa_pcm_open(void* fn, void** h, const char* name, int stream, int mode) {
return ((pcm_open_fn)fn)(h, name, stream, mode);
}
static int alsa_pcm_close(void* fn, void* h) { return ((pcm_simple_fn)fn)(h); }
static int alsa_pcm_prepare(void* fn, void* h) { return ((pcm_simple_fn)fn)(h); }
static int alsa_pcm_drop(void* fn, void* h) { return ((pcm_simple_fn)fn)(h); }
static long alsa_pcm_readi(void* fn, void* h, void* buf, unsigned long frames) {
return ((pcm_readi_fn)fn)(h, buf, frames);
}
static int alsa_hw_malloc(void* fn, void** p) { return ((hw_malloc_fn)fn)(p); }
static void alsa_hw_free(void* fn, void* p) { ((hw_free_fn)fn)(p); }
static int alsa_hw_any(void* fn, void* h, void* p) { return ((hw_any_fn)fn)(h, p); }
static int alsa_hw_set_access(void* fn, void* h, void* p, int v) { return ((hw_set_int_fn)fn)(h, p, v); }
static int alsa_hw_set_format(void* fn, void* h, void* p, int v) { return ((hw_set_int_fn)fn)(h, p, v); }
static int alsa_hw_set_channels(void* fn, void* h, void* p, unsigned int v) { return ((hw_set_uint_fn)fn)(h, p, v); }
static int alsa_hw_set_rate(void* fn, void* h, void* p, unsigned int* v, int* d) { return ((hw_set_rate_fn)fn)(h, p, v, d); }
static int alsa_hw_set_period(void* fn, void* h, void* p, unsigned long* v, int* d) { return ((hw_set_period_fn)fn)(h, p, v, d); }
static int alsa_hw_apply(void* fn, void* h, void* p) { return ((hw_apply_fn)fn)(h, p); }
static const char* alsa_strerror(void* fn, int e) { return ((strerror_fn)fn)(e); }
*/
import "C"
import (
"fmt"
"math"
"sync"
"time"
"unsafe"
)
// errNoAudio is returned by AudioRecorder.WAV when no samples were captured.
var errNoAudio = fmt.Errorf("no audio recorded")
// Numeric values of the ALSA enums passed to the dynamically loaded functions
// (the ALSA headers are not included, so they are restated here).
const (
sndPCMStreamCapture = 1 // SND_PCM_STREAM_CAPTURE
sndPCMAccessRWInterleaved = 3 // SND_PCM_ACCESS_RW_INTERLEAVED
sndPCMFormatS16LE = 2 // SND_PCM_FORMAT_S16_LE
)
// State of the one-time libasound load performed by loadALSA.
var (
alsaLoadErr error // set by loadALSA when the library or a symbol is missing
alsaOnce sync.Once // ensures loadALSA runs once
alsa alsaFuncs // resolved ALSA entry points
)
// alsaFuncs holds the addresses of the libasound functions resolved with
// dlsym by loadALSA. Each pointer is passed as the first argument to the
// matching C trampoline in the cgo preamble.
type alsaFuncs struct {
pcmOpen, pcmClose, pcmPrepare, pcmDrop, pcmReadi unsafe.Pointer
hwMalloc, hwFree, hwAny unsafe.Pointer
hwSetAccess, hwSetFormat, hwSetChannels unsafe.Pointer
hwSetRate, hwSetPeriod, hwApply unsafe.Pointer
strerror unsafe.Pointer
}
// loadALSA opens libasound with dlopen (trying "libasound.so.2" then
// "libasound.so") and resolves every function the capture code needs into the
// package-level alsa table. If the library or any symbol is missing it
// records the reason in alsaLoadErr and stops. The library handle is kept
// open for the life of the process.
func loadALSA() {
	candidates := [...]string{"libasound.so.2", "libasound.so"}
	var handle unsafe.Pointer
	for i := 0; i < len(candidates) && handle == nil; i++ {
		soname := C.CString(candidates[i])
		handle = C.dlopen(soname, C.RTLD_NOW)
		C.free(unsafe.Pointer(soname))
	}
	if handle == nil {
		alsaLoadErr = fmt.Errorf("audio capture unavailable: libasound.so not found")
		return
	}

	type binding struct {
		symbol string
		dst    *unsafe.Pointer
	}
	bindings := []binding{
		{"snd_pcm_open", &alsa.pcmOpen},
		{"snd_pcm_close", &alsa.pcmClose},
		{"snd_pcm_prepare", &alsa.pcmPrepare},
		{"snd_pcm_drop", &alsa.pcmDrop},
		{"snd_pcm_readi", &alsa.pcmReadi},
		{"snd_pcm_hw_params_malloc", &alsa.hwMalloc},
		{"snd_pcm_hw_params_free", &alsa.hwFree},
		{"snd_pcm_hw_params_any", &alsa.hwAny},
		{"snd_pcm_hw_params_set_access", &alsa.hwSetAccess},
		{"snd_pcm_hw_params_set_format", &alsa.hwSetFormat},
		{"snd_pcm_hw_params_set_channels", &alsa.hwSetChannels},
		{"snd_pcm_hw_params_set_rate_near", &alsa.hwSetRate},
		{"snd_pcm_hw_params_set_period_size_near", &alsa.hwSetPeriod},
		{"snd_pcm_hw_params", &alsa.hwApply},
		{"snd_strerror", &alsa.strerror},
	}
	for _, b := range bindings {
		cSym := C.CString(b.symbol)
		addr := C.dlsym(handle, cSym)
		C.free(unsafe.Pointer(cSym))

		*b.dst = addr
		if addr == nil {
			alsaLoadErr = fmt.Errorf("audio capture unavailable: missing %s in libasound", b.symbol)
			return
		}
	}
}
// alsaError turns an ALSA return code into text using snd_strerror, falling
// back to "error <code>" when that symbol has not been resolved.
func alsaError(code C.int) string {
	if fn := alsa.strerror; fn != nil {
		return C.GoString(C.alsa_strerror(fn, code))
	}
	return fmt.Sprintf("error %d", code)
}
// alsaStream is the Linux audioStream implementation. It reads interleaved
// 16-bit PCM from the ALSA "default" device on a capture goroutine.
type alsaStream struct {
// handle is the PCM handle filled in by snd_pcm_open; Stop sets it to nil.
handle unsafe.Pointer
// mu guards handle, callback and running.
mu sync.Mutex
// callback receives the converted float32 samples.
callback func(samples []float32)
// running is cleared by Stop to make captureLoop exit.
running bool
// done is closed when captureLoop returns.
done chan struct{}
sampleRate int
channels int
// frameSize is the requested ALSA period size in frames.
frameSize int
}
// newAudioStream makes sure libasound has been loaded (once per process) and
// returns an unstarted ALSA capture stream for the requested format. It
// fails with the recorded load error when libasound is unavailable.
func newAudioStream(sampleRate, channels, frameSize int) (audioStream, error) {
	if alsaOnce.Do(loadALSA); alsaLoadErr != nil {
		return nil, alsaLoadErr
	}
	s := new(alsaStream)
	s.sampleRate, s.channels, s.frameSize = sampleRate, channels, frameSize
	return s, nil
}
// Start opens the ALSA "default" capture device, configures it for
// interleaved signed 16-bit little-endian PCM at the stream's channel count
// and sample rate with a period near frameSize, prepares it, and launches
// captureLoop to deliver samples to callback.
//
// Every configuration step is checked. On failure the device is closed and
// s.handle is cleared so a later Stop cannot close it a second time. Start
// also fails if the device negotiates a sample rate other than the requested
// one, because the samples would otherwise be labelled with the wrong rate
// downstream.
func (s *alsaStream) Start(callback func(samples []float32)) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.callback = callback

	cName := C.CString("default")
	defer C.free(unsafe.Pointer(cName))
	rc := C.alsa_pcm_open(alsa.pcmOpen, (*unsafe.Pointer)(unsafe.Pointer(&s.handle)), cName, C.int(sndPCMStreamCapture), 0)
	if rc < 0 {
		s.handle = nil
		return fmt.Errorf("snd_pcm_open: %s", alsaError(rc))
	}

	// fail closes the device, forgets the handle, and builds the error.
	fail := func(step, reason string) error {
		C.alsa_pcm_close(alsa.pcmClose, s.handle)
		s.handle = nil
		return fmt.Errorf("%s: %s", step, reason)
	}

	var hwParams unsafe.Pointer
	if rc = C.alsa_hw_malloc(alsa.hwMalloc, (*unsafe.Pointer)(unsafe.Pointer(&hwParams))); rc < 0 {
		return fail("allocate hw params", alsaError(rc))
	}
	defer C.alsa_hw_free(alsa.hwFree, hwParams)

	if rc = C.alsa_hw_any(alsa.hwAny, s.handle, hwParams); rc < 0 {
		return fail("init hw params", alsaError(rc))
	}
	if rc = C.alsa_hw_set_access(alsa.hwSetAccess, s.handle, hwParams, C.int(sndPCMAccessRWInterleaved)); rc < 0 {
		return fail("set access", alsaError(rc))
	}
	if rc = C.alsa_hw_set_format(alsa.hwSetFormat, s.handle, hwParams, C.int(sndPCMFormatS16LE)); rc < 0 {
		return fail("set format", alsaError(rc))
	}
	if rc = C.alsa_hw_set_channels(alsa.hwSetChannels, s.handle, hwParams, C.uint(s.channels)); rc < 0 {
		return fail("set channels", alsaError(rc))
	}

	// set_rate_near may pick a different rate and reports it back in rate.
	rate := C.uint(s.sampleRate)
	if rc = C.alsa_hw_set_rate(alsa.hwSetRate, s.handle, hwParams, &rate, nil); rc < 0 {
		return fail("set rate", alsaError(rc))
	}
	if int(rate) != s.sampleRate {
		return fail("set rate", fmt.Sprintf("device offers %d Hz, need %d Hz", int(rate), s.sampleRate))
	}

	// The period size is also a "near" request; the value actually chosen is
	// what captureLoop reads per call.
	periodSize := C.ulong(s.frameSize)
	if rc = C.alsa_hw_set_period(alsa.hwSetPeriod, s.handle, hwParams, &periodSize, nil); rc < 0 {
		return fail("set period", alsaError(rc))
	}
	if rc = C.alsa_hw_apply(alsa.hwApply, s.handle, hwParams); rc < 0 {
		return fail("apply hw params", alsaError(rc))
	}
	if rc = C.alsa_pcm_prepare(alsa.pcmPrepare, s.handle); rc < 0 {
		return fail("prepare", alsaError(rc))
	}

	s.running = true
	s.done = make(chan struct{})
	go s.captureLoop(int(periodSize))
	return nil
}
// captureLoop runs on the goroutine started by Start. It repeatedly reads up
// to periodSize frames from the PCM handle, converts the int16 samples to
// float32 (dividing by math.MaxInt16), and passes them to the callback. It
// exits, closing done, once running has been cleared by Stop.
//
// A negative read result is handled by re-preparing the stream. If that
// recovery itself fails, the loop backs off briefly instead of spinning at
// full speed until Stop is called.
func (s *alsaStream) captureLoop(periodSize int) {
	defer close(s.done)
	buf := make([]int16, periodSize*s.channels)
	for {
		s.mu.Lock()
		if !s.running {
			s.mu.Unlock()
			return
		}
		handle := s.handle
		s.mu.Unlock()

		// The read happens outside the lock so Stop is never blocked behind it.
		frames := C.alsa_pcm_readi(alsa.pcmReadi, handle, unsafe.Pointer(&buf[0]), C.ulong(periodSize))
		if frames < 0 {
			if rc := C.alsa_pcm_prepare(alsa.pcmPrepare, handle); rc < 0 {
				// The stream could not be recovered; avoid a busy loop.
				time.Sleep(5 * time.Millisecond)
			}
			continue
		}
		if frames == 0 {
			time.Sleep(5 * time.Millisecond)
			continue
		}

		numSamples := int(frames) * s.channels
		floats := make([]float32, numSamples)
		for i := 0; i < numSamples; i++ {
			floats[i] = float32(buf[i]) / float32(math.MaxInt16)
		}

		s.mu.Lock()
		if s.callback != nil {
			s.callback(floats)
		}
		s.mu.Unlock()
	}
}
// Stop ends capturing: it clears running and detaches the PCM handle under
// the lock, waits for captureLoop to finish, then drops pending frames and
// closes the device. It always returns nil.
func (s *alsaStream) Stop() error {
	s.mu.Lock()
	pcm := s.handle
	s.handle = nil
	s.running = false
	s.mu.Unlock()

	if s.done != nil {
		<-s.done
	}
	if pcm == nil {
		return nil
	}
	C.alsa_pcm_drop(alsa.pcmDrop, pcm)
	C.alsa_pcm_close(alsa.pcmClose, pcm)
	return nil
}

288
cmd/audio_windows.go Normal file
View File

@@ -0,0 +1,288 @@
package cmd
import (
"fmt"
"math"
"sync"
"syscall"
"time"
"unsafe"
)
// errNoAudio is returned by AudioRecorder.WAV when no samples were captured.
var errNoAudio = fmt.Errorf("no audio recorded")
// WASAPI COM GUIDs: the interface IDs and class ID passed to
// CoCreateInstance, IMMDevice::Activate and IAudioClient::GetService.
var (
iidIMMDeviceEnumerator = guid{0xA95664D2, 0x9614, 0x4F35, [8]byte{0xA7, 0x46, 0xDE, 0x8D, 0xB6, 0x36, 0x17, 0xE6}}
clsidMMDeviceEnumerator = guid{0xBCDE0395, 0xE52F, 0x467C, [8]byte{0x8E, 0x3D, 0xC4, 0x57, 0x92, 0x91, 0x69, 0x2E}}
iidIAudioClient = guid{0x1CB9AD4C, 0xDBFA, 0x4C32, [8]byte{0xB1, 0x78, 0xC2, 0xF5, 0x68, 0xA7, 0x03, 0xB2}}
iidIAudioCaptureClient = guid{0xC8ADBD64, 0xE71E, 0x48A0, [8]byte{0xA4, 0xDE, 0x18, 0x5C, 0x39, 0x5C, 0xD3, 0x17}}
)
// guid has the field layout of the Windows GUID structure so that a pointer
// to it can be passed directly to COM calls.
type guid struct {
Data1 uint32
Data2 uint16
Data3 uint16
Data4 [8]byte
}
// waveFormatEx has the field layout of the Windows WAVEFORMATEX structure.
// A pointer to it is passed to IAudioClient::Initialize to describe the
// capture format.
type waveFormatEx struct {
FormatTag uint16 // format type; wavePCM here
Channels uint16 // number of channels
SamplesPerSec uint32 // sample rate in Hz
AvgBytesPerSec uint32 // SamplesPerSec * BlockAlign
BlockAlign uint16 // bytes per frame (all channels)
BitsPerSample uint16 // bits per single-channel sample
CbSize uint16 // size of extra format data; 0 here
}
// Numeric values of the Windows constants used by the WASAPI capture code.
const (
wavePCM = 1 // WAVE_FORMAT_PCM
eCapture = 1 // EDataFlow: capture endpoints
eConsole = 0 // ERole: console role
audclntSharemode = 0 // AUDCLNT_SHAREMODE_SHARED
audclntStreamflagsEventcallback = 0x00040000 // AUDCLNT_STREAMFLAGS_EVENTCALLBACK; not referenced in the code visible here
coinitMultithreaded = 0x0 // COINIT_MULTITHREADED
clsctxAll = 0x17 // CLSCTX_ALL
reftimesPerSec = 10000000 // 100ns units per second
reftimesPerMillis = 10000 // 100ns units per millisecond; not referenced in the code visible here
)
// Lazily loaded ole32.dll entry points used to initialise COM and create the
// device enumerator.
var (
ole32 = syscall.NewLazyDLL("ole32.dll")
coInit = ole32.NewProc("CoInitializeEx")
coCreate = ole32.NewProc("CoCreateInstance")
)
// wasapiStream is the Windows audioStream implementation. It captures from
// the default WASAPI capture endpoint and polls for packets on a goroutine.
type wasapiStream struct {
// mu guards callback and running, and is held by captureLoop while it drains packets.
mu sync.Mutex
// callback receives the converted float32 samples.
callback func(samples []float32)
// running is cleared by Stop to make captureLoop exit.
running bool
// done is closed when captureLoop returns.
done chan struct{}
sampleRate int
channels int
// frameSize is stored but not read by the Windows code visible here.
frameSize int
// COM interfaces (stored as uintptr for syscall)
enumerator uintptr // IMMDeviceEnumerator
device uintptr // IMMDevice for the default capture endpoint
client uintptr // IAudioClient
capture uintptr // IAudioCaptureClient
}
// newAudioStream returns an unstarted WASAPI capture stream that remembers
// the requested format. No COM objects are created until Start is called, so
// this never fails.
func newAudioStream(sampleRate, channels, frameSize int) (audioStream, error) {
	s := new(wasapiStream)
	s.sampleRate, s.channels, s.frameSize = sampleRate, channels, frameSize
	return s, nil
}
// Start initialises COM, opens the default capture endpoint in shared mode
// for 16-bit PCM at the stream's sample rate and channel count, starts the
// audio client, and launches captureLoop to deliver samples to callback.
//
// Fixes relative to the first version:
//   - IAudioClient methods are called through their real vtable slots
//     (Start is 10 and GetService is 14; slots 6 and 8 are GetCurrentPadding
//     and GetMixFormat).
//   - Initialize passes AUTOCONVERTPCM | SRC_DEFAULT_QUALITY so the audio
//     engine converts from its mix format to the requested one; without these
//     flags a shared-mode stream must use the mix format.
//   - COM interfaces acquired before a failure are released and cleared.
//
// NOTE(review): the REFERENCE_TIME arguments are passed as single uintptr
// values, which assumes a 64-bit build — confirm no 32-bit target is needed.
func (s *wasapiStream) Start(callback func(samples []float32)) error {
	const (
		// IAudioClient vtable slots, counted from IUnknown's three methods.
		audioClientInitialize = 3
		audioClientStart      = 10
		audioClientGetService = 14

		streamflagsAutoConvertPCM    = 0x80000000 // AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM
		streamflagsSrcDefaultQuality = 0x08000000 // AUDCLNT_STREAMFLAGS_SRC_DEFAULT_QUALITY
	)

	s.mu.Lock()
	defer s.mu.Unlock()
	s.callback = callback

	// Forget any pointers left from an earlier run so the failure cleanup
	// below only ever releases objects acquired by this call.
	s.enumerator, s.device, s.client, s.capture = 0, 0, 0, 0
	started := false
	defer func() {
		if started {
			return
		}
		for _, p := range []*uintptr{&s.capture, &s.client, &s.device, &s.enumerator} {
			if *p != 0 {
				comCall(*p, 2) // IUnknown::Release
				*p = 0
			}
		}
	}()

	// Initialize COM
	hr, _, _ := coInit.Call(0, uintptr(coinitMultithreaded))
	// S_OK or S_FALSE (already initialized) are both fine
	if hr != 0 && hr != 1 {
		return fmt.Errorf("CoInitializeEx failed: 0x%08x", hr)
	}

	// Create device enumerator
	hr, _, _ = coCreate.Call(
		uintptr(unsafe.Pointer(&clsidMMDeviceEnumerator)),
		0,
		uintptr(clsctxAll),
		uintptr(unsafe.Pointer(&iidIMMDeviceEnumerator)),
		uintptr(unsafe.Pointer(&s.enumerator)),
	)
	if hr != 0 {
		return fmt.Errorf("CoCreateInstance(MMDeviceEnumerator) failed: 0x%08x", hr)
	}

	// Get default capture device
	// IMMDeviceEnumerator::GetDefaultAudioEndpoint is vtable index 4
	hr = comCall(s.enumerator, 4, uintptr(eCapture), uintptr(eConsole), uintptr(unsafe.Pointer(&s.device)))
	if hr != 0 {
		return fmt.Errorf("GetDefaultAudioEndpoint failed: 0x%08x", hr)
	}

	// Activate IAudioClient
	// IMMDevice::Activate is vtable index 3
	hr = comCall(s.device, 3,
		uintptr(unsafe.Pointer(&iidIAudioClient)),
		uintptr(clsctxAll),
		0,
		uintptr(unsafe.Pointer(&s.client)),
	)
	if hr != 0 {
		return fmt.Errorf("IMMDevice::Activate failed: 0x%08x", hr)
	}

	// Set up format: 16-bit PCM at the requested rate and channel count.
	format := waveFormatEx{
		FormatTag:      wavePCM,
		Channels:       uint16(s.channels),
		SamplesPerSec:  uint32(s.sampleRate),
		BitsPerSample:  16,
		BlockAlign:     uint16(2 * s.channels),
		AvgBytesPerSec: uint32(s.sampleRate * 2 * s.channels),
		CbSize:         0,
	}

	// Initialize audio client
	bufferDuration := int64(reftimesPerSec) // 1 second buffer
	hr = comCall(s.client, audioClientInitialize,
		uintptr(audclntSharemode),
		uintptr(streamflagsAutoConvertPCM|streamflagsSrcDefaultQuality),
		uintptr(bufferDuration),
		0, // periodicity (0 = use default)
		uintptr(unsafe.Pointer(&format)),
		0, // audio session GUID (NULL = default)
	)
	if hr != 0 {
		return fmt.Errorf("IAudioClient::Initialize failed: 0x%08x", hr)
	}

	// Get capture client
	hr = comCall(s.client, audioClientGetService,
		uintptr(unsafe.Pointer(&iidIAudioCaptureClient)),
		uintptr(unsafe.Pointer(&s.capture)),
	)
	if hr != 0 {
		return fmt.Errorf("IAudioClient::GetService failed: 0x%08x", hr)
	}

	// Start capture
	hr = comCall(s.client, audioClientStart)
	if hr != 0 {
		return fmt.Errorf("IAudioClient::Start failed: 0x%08x", hr)
	}

	started = true
	s.running = true
	s.done = make(chan struct{})
	go s.captureLoop()
	return nil
}
// captureLoop runs on the goroutine started by Start. Every 20ms it drains
// all packets available from the capture client, converts the int16 samples
// to float32 (dividing by math.MaxInt16), and passes them to the callback.
// It exits, closing done, once running has been cleared by Stop.
//
// Packets flagged AUDCLNT_BUFFERFLAGS_SILENT are delivered as zeros without
// reading the buffer contents, as the flag requires.
func (s *wasapiStream) captureLoop() {
	const bufferflagsSilent = 0x2 // AUDCLNT_BUFFERFLAGS_SILENT

	defer close(s.done)
	ticker := time.NewTicker(20 * time.Millisecond)
	defer ticker.Stop()
	for range ticker.C {
		s.mu.Lock()
		if !s.running {
			s.mu.Unlock()
			return
		}
		// Read available packets
		for {
			var data uintptr
			var numFrames uint32
			var flags uint32
			// IAudioCaptureClient::GetBuffer is vtable index 3
			hr := comCall(s.capture, 3,
				uintptr(unsafe.Pointer(&data)),
				uintptr(unsafe.Pointer(&numFrames)),
				uintptr(unsafe.Pointer(&flags)),
				0, // device position (not needed)
				0, // QPC position (not needed)
			)
			if hr != 0 || numFrames == 0 {
				break
			}
			// Convert int16 samples to float32; silent packets stay all-zero.
			samples := make([]float32, numFrames*uint32(s.channels))
			if flags&bufferflagsSilent == 0 && data != 0 {
				raw := (*[1 << 28]int16)(unsafe.Pointer(data))[:len(samples):len(samples)]
				for i, v := range raw {
					samples[i] = float32(v) / float32(math.MaxInt16)
				}
			}
			s.callback(samples)
			// IAudioCaptureClient::ReleaseBuffer is vtable index 4
			comCall(s.capture, 4, uintptr(numFrames))
		}
		s.mu.Unlock()
	}
}
// Stop ends capturing: it clears running, waits for captureLoop to exit,
// stops the audio client, and releases every COM interface acquired by
// Start. It always returns nil.
//
// IAudioClient::Stop is vtable slot 11 (slot 7 is IsFormatSupported). Each
// interface pointer is cleared after release so a repeated Stop is a no-op
// instead of a double release.
func (s *wasapiStream) Stop() error {
	const audioClientStop = 11

	s.mu.Lock()
	s.running = false
	done := s.done
	s.mu.Unlock()

	// Wait outside the lock: captureLoop needs it to observe running == false.
	if done != nil {
		<-done
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.client != 0 {
		comCall(s.client, audioClientStop)
	}
	// Release COM interfaces (IUnknown::Release is vtable index 2), in the
	// reverse of the order Start acquired them.
	for _, p := range []*uintptr{&s.capture, &s.client, &s.device, &s.enumerator} {
		if *p != 0 {
			comCall(*p, 2)
			*p = 0
		}
	}
	return nil
}
// comCall invokes the COM method in vtable slot `method` of the interface
// pointer obj, passing obj as the implicit `this` argument followed by args,
// and returns the raw result register (the HRESULT for most methods).
func comCall(obj uintptr, method uintptr, args ...uintptr) uintptr {
	const slotSize = unsafe.Sizeof(uintptr(0))

	// An interface pointer points at a pointer to the vtable.
	vtbl := *(*uintptr)(unsafe.Pointer(obj))
	entry := *(*uintptr)(unsafe.Pointer(vtbl + method*slotSize))

	argv := make([]uintptr, 0, 1+len(args))
	argv = append(argv, obj)
	argv = append(argv, args...)

	ret, _, _ := syscall.SyscallN(entry, argv...)
	return ret
}

View File

@@ -568,6 +568,25 @@ func hasListedModelName(models []api.ListModelResponse, name string) bool {
return false
}
// getMaxAudioSeconds looks through the model info metadata for a key ending
// in ".max_audio_seconds" and returns its value as an int. Only float64 and
// int values are recognised; entries of any other type are skipped.
// Returns 0 if info is nil or no usable entry exists.
func getMaxAudioSeconds(info *api.ShowResponse) int {
	if info == nil || info.ModelInfo == nil {
		return 0
	}
	for key, raw := range info.ModelInfo {
		if !strings.HasSuffix(key, ".max_audio_seconds") {
			continue
		}
		if f, ok := raw.(float64); ok {
			return int(f)
		}
		if i, ok := raw.(int); ok {
			return i
		}
	}
	return 0
}
func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
@@ -712,6 +731,19 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts.ParentModel = info.Details.ParentModel
opts.AudioCapable = slices.Contains(info.Capabilities, model.CapabilityAudio)
audioin, _ := cmd.Flags().GetBool("audioin")
if audioin {
if !opts.AudioCapable {
fmt.Fprintf(os.Stderr, "Warning: audio input disabled — %s does not support audio\n", opts.Model)
} else {
opts.AudioInput = true
opts.MultiModal = true // audio uses the multimodal pipeline
opts.MaxAudioSeconds = getMaxAudioSeconds(info)
}
}
// Check if this is an embedding model
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
@@ -1435,8 +1467,12 @@ type runOptions struct {
System string
Images []api.ImageData
Options map[string]any
MultiModal bool
KeepAlive *api.Duration
MultiModal bool
AudioInput bool
AudioCapable bool // model supports audio input
MaxAudioSeconds int // from model metadata; 0 = use default
Language string // language hint for transcription
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
ShowConnect bool
@@ -1494,6 +1530,9 @@ type displayResponseState struct {
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
if termWidth == 0 {
termWidth = 80
}
if wordWrap && termWidth >= 10 {
for _, ch := range content {
if state.lineLength+1 > termWidth-5 {
@@ -2159,6 +2198,7 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
runCmd.Flags().Bool("audioin", false, "Enable audio input via microphone (press Space to record)")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)

View File

@@ -12,6 +12,7 @@ import (
"regexp"
"slices"
"strings"
"time"
"github.com/spf13/cobra"
@@ -23,6 +24,129 @@ import (
"github.com/ollama/ollama/types/model"
)
// errFallbackToText is the error doAudioRecording returns when the user
// presses a printable key other than Space while waiting to record. It tells
// the caller to switch to normal text input, seeded with the typed character.
type errFallbackToText struct {
	// prefill is the character that was typed.
	prefill string
}

// Error implements the error interface with a fixed message.
func (errFallbackToText) Error() string {
	return "fallback to text"
}
// doAudioRecording handles the spacebar-driven recording flow.
//
// It prompts, waits for Space, records until Space or Ctrl+C is pressed, and
// returns the recording as WAV bytes. Other outcomes:
//   - (nil, nil): Ctrl+C at the prompt; the caller should prompt again.
//   - (nil, io.EOF): Ctrl+D at the prompt, or a read error from the scanner.
//   - (nil, errFallbackToText): a printable ASCII key other than Space was
//     typed; the caller should switch to text input prefilled with that key.
//   - (nil, err): the recorder failed to start or produced no audio.
//
// The elapsed-time indicator goroutine is waited for before the final status
// line is printed, so it can no longer redraw "Recording..." over it.
func doAudioRecording(scanner *readline.Instance, recorder *AudioRecorder) ([]byte, error) {
	fmt.Print(">>> \033[90m◉ Press Space to record...\033[0m")

	// Wait for spacebar to start.
	for armed := false; !armed; {
		r, err := scanner.ReadRaw()
		if err != nil {
			return nil, io.EOF
		}
		switch {
		case r == 3: // Ctrl+C
			fmt.Print("\r\033[K")
			fmt.Println("Use Ctrl + d or /bye to exit.")
			return nil, nil
		case r == 4: // Ctrl+D
			fmt.Println()
			return nil, io.EOF
		case r == ' ':
			armed = true
		case r >= 32 && r < 127:
			// User typed a regular character — fall back to text input with this char.
			fmt.Print("\r\033[K") // clear the "Press Space" line
			return nil, errFallbackToText{prefill: string(r)}
		}
		// Anything else (other control keys, non-ASCII) is ignored.
	}

	// Start recording.
	if err := recorder.Start(); err != nil {
		fmt.Println()
		return nil, fmt.Errorf("start recording: %w", err)
	}

	// Show recording indicator with elapsed time.
	stop := make(chan struct{})
	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
				d := recorder.Duration()
				fmt.Printf("\r>>> \033[91m◈ Recording... %.1fs\033[0m ", d.Seconds())
			}
		}
	}()
	// stopIndicator halts the indicator and waits until it has fully exited.
	stopIndicator := func() {
		close(stop)
		<-stopped
	}

	// Wait for spacebar to stop.
	for {
		r, err := scanner.ReadRaw()
		if err != nil {
			stopIndicator()
			recorder.Stop()
			return nil, io.EOF
		}
		if r == ' ' || r == 3 { // Space or Ctrl+C
			break
		}
	}
	stopIndicator()

	// A stop error is ignored on purpose: whatever was buffered is still usable.
	dur, _ := recorder.Stop()
	fmt.Printf("\r>>> \033[90m◇ Recorded %.1fs\033[0m \n", dur.Seconds())

	// Encode to WAV.
	return recorder.WAV()
}
// tokenCallback is called by streamChat with the content of each streamed
// response. It returns nothing, so it cannot abort the stream.
type tokenCallback func(token string)
// streamChat issues a streaming chat request for model with the given
// messages, with thinking disabled and temperature 0. Each response's content
// is forwarded to onToken (when non-nil) as it arrives. On success it returns
// the concatenated content with surrounding whitespace trimmed.
func streamChat(cmd *cobra.Command, model string, messages []api.Message, onToken tokenCallback) (string, error) {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return "", err
	}

	streaming := true
	request := api.ChatRequest{
		Model:    model,
		Messages: messages,
		Stream:   &streaming,
		Think:    &api.ThinkValue{Value: false},
		Options:  map[string]any{"temperature": 0},
	}

	var transcript strings.Builder
	collect := func(resp api.ChatResponse) error {
		piece := resp.Message.Content
		transcript.WriteString(piece)
		if onToken != nil {
			onToken(piece)
		}
		return nil
	}

	if err = client.Chat(cmd.Context(), &request, collect); err != nil {
		return "", err
	}
	return strings.TrimSpace(transcript.String()), nil
}
type MultilineState int
const (
@@ -39,6 +163,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /clear Clear session context")
if opts.AudioCapable {
fmt.Fprintln(os.Stderr, " /audio Toggle voice input mode")
} else {
fmt.Fprintln(os.Stderr, " /audio (not supported by current model)")
}
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
@@ -136,7 +265,66 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var multiline MultilineState
var thinkExplicitlySet bool = opts.Think != nil
audioMode := opts.AudioInput
var recorder *AudioRecorder
if audioMode {
var err error
recorder, err = NewAudioRecorder()
if err != nil {
fmt.Fprintf(os.Stderr, "Audio input unavailable: %v\n", err)
audioMode = false
} else {
if opts.MaxAudioSeconds > 0 {
recorder.MaxChunkSeconds = opts.MaxAudioSeconds - 2 // 2s headroom
}
fmt.Fprintln(os.Stderr, "Voice input enabled. Press Space to record, Space again to send.")
}
}
for {
// Audio recording mode: wait for spacebar instead of text input.
if audioMode && recorder != nil {
audioData, err := doAudioRecording(scanner, recorder)
if err != nil {
if err == io.EOF {
fmt.Println()
return nil
}
// User typed a regular key — fall through to normal readline.
if fb, ok := err.(errFallbackToText); ok {
scanner.Prefill = fb.prefill
goto textInput
}
fmt.Fprintf(os.Stderr, "Audio error: %v\n", err)
continue
}
if audioData == nil {
continue
}
// Send audio as the user's input — the model hears and responds.
newMessage := api.Message{
Role: "user",
Images: []api.ImageData{audioData},
}
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
if strings.Contains(err.Error(), "does not support thinking") ||
strings.Contains(err.Error(), "invalid think value") {
fmt.Printf("error: %v\n", err)
continue
}
return err
}
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
continue
}
textInput:
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
@@ -474,6 +662,29 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} else {
usage()
}
case line == "/audio":
if !opts.AudioCapable {
fmt.Fprintf(os.Stderr, "Audio input not supported by %s\n", opts.Model)
continue
}
if audioMode {
audioMode = false
fmt.Fprintln(os.Stderr, "Voice input disabled.")
} else {
audioMode = true
if recorder == nil {
var recErr error
recorder, recErr = NewAudioRecorder()
if recErr != nil {
fmt.Fprintf(os.Stderr, "Audio input unavailable: %v\n", recErr)
audioMode = false
continue
}
}
opts.MultiModal = true
fmt.Fprintln(os.Stderr, "Voice input enabled. Press Space to record, Space again to send.")
}
continue
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/"):
@@ -592,7 +803,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp|wav|mp4|webm|mov|avi|mkv|m4v)\b`
re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1)
@@ -608,10 +819,16 @@ func extractFileData(input string) (string, []api.ImageData, error) {
if errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
fmt.Fprintf(os.Stderr, "Couldn't process file: %q\n", err)
return "", imgs, err
}
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
ext := strings.ToLower(filepath.Ext(nfp))
switch ext {
case ".wav":
fmt.Fprintf(os.Stderr, "Added audio '%s'\n", nfp)
default:
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
}
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "")
@@ -685,9 +902,9 @@ func getImageData(filePath string) ([]byte, error) {
}
contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp", "audio/wave"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType)
return nil, fmt.Errorf("invalid file type: %s", contentType)
}
info, err := file.Stat()
@@ -695,8 +912,7 @@ func getImageData(filePath string) ([]byte, error) {
return nil, err
}
// Check if the file size exceeds 100MB
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
var maxSize int64 = 100 * 1024 * 1024 // 100MB
if info.Size() > maxSize {
return nil, errors.New("file size exceeds maximum limit (100MB)")
}

View File

@@ -390,3 +390,48 @@ func (t *Terminal) Read() (rune, error) {
}
return r, nil
}
// SetRawModeOn switches stdin into raw mode and leaves it that way until
// SetRawModeOff is called. When the terminal is already flagged as raw the
// call does nothing and returns nil. Any error from SetRawMode is returned
// unchanged and the terminal state is left untouched.
func (i *Instance) SetRawModeOn() error {
	if !i.Terminal.rawmode {
		saved, err := SetRawMode(os.Stdin.Fd())
		if err != nil {
			return err
		}
		// Remember the value SetRawMode handed back so SetRawModeOff can
		// pass it to UnsetRawMode later.
		i.Terminal.rawmode, i.Terminal.termios = true, saved
	}
	return nil
}
// SetRawModeOff leaves raw mode by handing the saved termios value to
// UnsetRawMode and clearing the rawmode flag. It does nothing when the
// terminal is not currently flagged as raw.
func (i *Instance) SetRawModeOff() {
	if i.Terminal.rawmode {
		// Best effort: the result of UnsetRawMode is deliberately ignored.
		//nolint:errcheck
		UnsetRawMode(os.Stdin.Fd(), i.Terminal.termios)
		i.Terminal.rawmode = false
	}
}
// ReadRaw returns the next rune from the terminal.
//
// When the terminal is already flagged as raw (see SetRawModeOn) the rune is
// read directly. Otherwise stdin is put into raw mode just for this call and
// restored before ReadRaw returns; the rawmode flag itself is not modified.
// An error from SetRawMode is returned with a zero rune.
func (i *Instance) ReadRaw() (rune, error) {
	// Fast path: raw mode is being managed by the caller.
	if i.Terminal.rawmode {
		return i.Terminal.Read()
	}

	stdin := os.Stdin.Fd()
	saved, err := SetRawMode(stdin)
	if err != nil {
		return 0, err
	}
	// Restore the previous mode once the read completes; the result of
	// UnsetRawMode is intentionally ignored.
	//nolint:errcheck
	defer UnsetRawMode(stdin, saved)

	return i.Terminal.Read()
}