cmd: simplify audio input to dropped file attachments

This commit is contained in:
jmorganca
2026-04-02 00:40:27 -07:00
parent 1cbe7950d6
commit 9c8bcecdb2
3 changed files with 35 additions and 648 deletions

View File

@@ -568,395 +568,6 @@ func hasListedModelName(models []api.ListModelResponse, name string) bool {
return false
}
// getMaxAudioSeconds looks through the model's ModelInfo metadata for an entry
// whose key ends in ".max_audio_seconds" (i.e. "{arch}.max_audio_seconds") and
// returns its value as an int.
//
// It returns 0 when info is nil, when it carries no ModelInfo, or when no
// matching entry holds a float64 or int value.
func getMaxAudioSeconds(info *api.ShowResponse) int {
	if info == nil || info.ModelInfo == nil {
		return 0
	}

	const keySuffix = ".max_audio_seconds"
	for key, raw := range info.ModelInfo {
		if !strings.HasSuffix(key, keySuffix) {
			continue
		}
		// Only float64 and int values are recognized; a matching key with any
		// other value type is skipped and the scan continues.
		if seconds, ok := raw.(float64); ok {
			return int(seconds)
		}
		if seconds, ok := raw.(int); ok {
			return seconds
		}
	}
	return 0
}
// ANSI escape helpers for transcription display.
// NOTE(review): no constants are declared in this block as shown — confirm
// whether entries are missing or the empty declaration can be dropped.
const (
)
// TranscribeHandler implements `ollama transcribe MODEL`.
//
// Two modes:
//   - Interactive (tty on stdin): spacebar start/stop with >>> prompt,
//     slash commands (/set, /show, /load, /bye, /?), word-wrapped output.
//   - Non-interactive (pipe/redirect): reads audio from stdin or records
//     until Ctrl+C, transcribes, writes word-wrapped text to stdout.
//
// args[0] is the model name. If the server answers the initial Show request
// with 404 the model is pulled via PullHandler and Show is retried.
//
// Errors from client setup, model lookup/pull, reading stdin, and recorder or
// readline initialization are returned. Transcription and recording failures
// are reported on stderr and do not end an interactive session.
func TranscribeHandler(cmd *cobra.Command, args []string) error {
	// Seconds subtracted from the model's advertised audio limit so a chunk
	// never reaches the limit itself.
	const audioHeadroomSeconds = 2

	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}

	modelName := args[0]
	interactive := term.IsTerminal(int(os.Stdin.Fd()))

	// Pull model if needed and get model info.
	showReq := &api.ShowRequest{Name: modelName}
	info, err := client.Show(cmd.Context(), showReq)
	if err != nil {
		var se api.StatusError
		if !errors.As(err, &se) || se.StatusCode != http.StatusNotFound {
			return err
		}
		if err := PullHandler(cmd, []string{modelName}); err != nil {
			return err
		}
		info, err = client.Show(cmd.Context(), showReq)
		if err != nil {
			return err
		}
	}

	// A lookup failure leaves language empty, which means "no language hint".
	language, _ := cmd.Flags().GetString("language")

	opts := runOptions{
		Model:    modelName,
		WordWrap: true,
		Options:  map[string]any{"temperature": 0},
		Language: language,
	}

	// transcribeAndDisplay streams the transcription of wav to stdout using
	// the current opts (so /load and /set take effect on the next call).
	transcribeAndDisplay := func(wav []byte) {
		state := &displayResponseState{}
		_, err := transcribeAudio(cmd, opts, wav, func(tok string) {
			displayResponse(tok, opts.WordWrap, state)
		})
		if err != nil {
			fmt.Fprintln(os.Stderr, "Transcription error:", err)
		}
		fmt.Println()
	}

	// --- Non-interactive mode ---
	if !interactive {
		audioData, err := io.ReadAll(os.Stdin)
		if err != nil {
			return fmt.Errorf("read stdin: %w", err)
		}

		if len(audioData) > 44 {
			// Pipe with data (more than a 44-byte WAV header): transcribe and output.
			transcribeAndDisplay(audioData)
			return nil
		}

		// Empty stdin (< /dev/null or echo "" |): record until Ctrl+C.
		recorder, err := NewAudioRecorder()
		if err != nil {
			return fmt.Errorf("audio input unavailable: %w", err)
		}
		if maxSec := getMaxAudioSeconds(info); maxSec > 0 {
			recorder.MaxChunkSeconds = maxSec - audioHeadroomSeconds
		}

		if err := recorder.Start(); err != nil {
			return fmt.Errorf("start recording: %w", err)
		}
		fmt.Fprintln(os.Stderr, "Recording... Press Ctrl+C to stop.")

		sigCh := make(chan os.Signal, 1)
		signal.Notify(sigCh, os.Interrupt)
		<-sigCh
		signal.Stop(sigCh)

		recorder.Stop()
		fmt.Fprintln(os.Stderr)

		if wav := recorder.FlushWAV(); wav != nil {
			transcribeAndDisplay(wav)
		}
		return nil
	}

	// --- Interactive mode ---
	recorder, err := NewAudioRecorder()
	if err != nil {
		return fmt.Errorf("audio input unavailable: %w", err)
	}
	if maxSec := getMaxAudioSeconds(info); maxSec > 0 {
		recorder.MaxChunkSeconds = maxSec - audioHeadroomSeconds
	}

	scanner, err := readline.New(readline.Prompt{
		Prompt:      ">>> ",
		Placeholder: "Press Space to record (/? for help)",
	})
	if err != nil {
		return err
	}

	// Print, not Printf: the escape sequence is data, not a format string.
	fmt.Print(readline.StartBracketedPaste)
	defer fmt.Print(readline.EndBracketedPaste)

	usage := func() {
		fmt.Fprintln(os.Stderr, "Available Commands:")
		fmt.Fprintln(os.Stderr, " /set Set session variables")
		fmt.Fprintln(os.Stderr, " /show Show model information")
		fmt.Fprintln(os.Stderr, " /load <model> Load a different model")
		fmt.Fprintln(os.Stderr, " /bye Exit")
		fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
		fmt.Fprintln(os.Stderr, "")
		fmt.Fprintln(os.Stderr, "Press Space to start/stop recording.")
		fmt.Fprintln(os.Stderr, "")
	}

	usageSet := func() {
		fmt.Fprintln(os.Stderr, "Available Commands:")
		fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
		fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
		fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
		fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
		fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
		fmt.Fprintln(os.Stderr, "")
	}

	// setVerbose flips the "verbose" flag and confirms with label; a failure
	// to set the flag is reported instead of being silently ignored.
	setVerbose := func(value, label string) {
		if err := cmd.Flags().Set("verbose", value); err != nil {
			fmt.Println("Error:", err)
			return
		}
		fmt.Printf("Set '%s' mode.\n", label)
	}

	// doTranscribeRecording is like doAudioRecording but polls TakeChunk()
	// during recording to stream transcription of long recordings.
	//
	// It returns the remaining WAV data when Space stops the recording,
	// (nil, nil) when Ctrl+C discards it, io.EOF on Ctrl+D or a read failure,
	// and errFallbackToText when a printable key is typed at the prompt.
	doTranscribeRecording := func() ([]byte, error) {
		fmt.Print(">>> \033[90m◉ Press Space to record...\033[0m")

		for {
			r, err := scanner.ReadRaw()
			if err != nil {
				return nil, io.EOF
			}
			if r == 3 { // Ctrl+C
				fmt.Print("\r\033[K")
				fmt.Println("Use Ctrl + d or /bye to exit.")
				return nil, nil
			}
			if r == 4 { // Ctrl+D
				fmt.Println()
				return nil, io.EOF
			}
			if r == ' ' {
				fmt.Print("\r\033[K") // clear the prompt line
				break
			}
			// Any other printable ASCII key (including '/') switches to text input.
			if r >= 32 && r < 127 {
				fmt.Print("\r\033[K")
				return nil, errFallbackToText{prefill: string(r)}
			}
		}

		if err := recorder.Start(); err != nil {
			fmt.Println()
			return nil, fmt.Errorf("start recording: %w", err)
		}

		// Poll for chunks in a background goroutine while recording.
		chunkDone := make(chan struct{})
		chunkStopped := make(chan struct{})
		go func() {
			defer close(chunkStopped)
			ticker := time.NewTicker(500 * time.Millisecond)
			defer ticker.Stop()
			for {
				select {
				case <-chunkDone:
					return
				case <-ticker.C:
					if wav := recorder.TakeChunk(); wav != nil {
						transcribeAndDisplay(wav)
					}
				}
			}
		}()

		// stopRecording ends polling and recording, then waits for the poller
		// to exit so an in-flight chunk transcription cannot interleave its
		// output with whatever is printed next.
		stopRecording := func() {
			close(chunkDone)
			recorder.Stop()
			<-chunkStopped
		}

		// Wait for Space to stop, Ctrl+C to discard, Ctrl+D to exit.
		for {
			r, err := scanner.ReadRaw()
			switch {
			case err != nil:
				stopRecording()
				return nil, io.EOF
			case r == 4: // Ctrl+D
				stopRecording()
				fmt.Println()
				return nil, io.EOF
			case r == 3: // Ctrl+C
				stopRecording()
				return nil, nil
			case r == ' ': // Space: stop recording
				stopRecording()
				return recorder.FlushWAV(), nil
			}
			// Ignore other keys while recording.
		}
	}

	for {
		wav, err := doTranscribeRecording()
		if err == nil {
			// A nil recording means it was discarded (Ctrl+C) — just retry.
			if wav != nil {
				transcribeAndDisplay(wav)
			}
			continue
		}

		var fallback errFallbackToText
		if !errors.As(err, &fallback) {
			if errors.Is(err, io.EOF) {
				fmt.Println()
				return nil
			}
			fmt.Fprintf(os.Stderr, "Recording error: %v\n", err)
			continue
		}

		// User typed text instead of pressing Space.
		line := fallback.prefill
		if line == "/" {
			// Need the rest of the command — read via readline.
			scanner.Prefill = "/"
			fullLine, err := scanner.Readline()
			if errors.Is(err, io.EOF) {
				fmt.Println()
				return nil
			}
			if err != nil {
				return err
			}
			line = fullLine
		}

		line = strings.TrimSpace(line)
		switch {
		case line == "/?" || line == "/help":
			usage()

		case strings.HasPrefix(line, "/? "):
			if strings.TrimSpace(line[3:]) == "set" {
				usageSet()
			} else {
				usage()
			}

		case strings.HasPrefix(line, "/set"):
			fields := strings.Fields(line)
			if len(fields) == 1 {
				usageSet()
				continue
			}
			switch fields[1] {
			case "wordwrap":
				opts.WordWrap = true
				fmt.Println("Set 'wordwrap' mode.")
			case "nowordwrap":
				opts.WordWrap = false
				fmt.Println("Set 'nowordwrap' mode.")
			case "verbose":
				setVerbose("true", "verbose")
			case "quiet":
				setVerbose("false", "quiet")
			case "parameter":
				if len(fields) < 4 {
					fmt.Println("Usage: /set parameter <name> <value>")
					continue
				}
				// The value is stored as the raw string typed by the user.
				opts.Options[fields[2]] = fields[3]
				fmt.Printf("Set parameter '%s' to '%s'\n", fields[2], fields[3])
			default:
				fmt.Printf("Unknown option: %s\n", fields[1])
				usageSet()
			}

		case strings.HasPrefix(line, "/show"):
			fields := strings.Fields(line)
			if len(fields) == 1 {
				fields = append(fields, "info")
			}
			resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model})
			if err != nil {
				fmt.Println("Error:", err)
				continue
			}
			switch fields[1] {
			case "info":
				if err := showInfo(resp, false, os.Stdout); err != nil {
					fmt.Println("Error:", err)
				}
			case "license":
				fmt.Println(resp.License)
			case "parameters":
				fmt.Println(resp.Parameters)
			case "system":
				fmt.Println(resp.System)
			default:
				fmt.Printf("Unknown show command: %s\n", fields[1])
			}

		case strings.HasPrefix(line, "/load"):
			fields := strings.Fields(line)
			if len(fields) != 2 {
				fmt.Println("Usage: /load <modelname>")
				continue
			}
			newModel := fields[1]
			fmt.Printf("Loading model '%s'\n", newModel)
			newInfo, err := client.Show(cmd.Context(), &api.ShowRequest{Name: newModel})
			if err != nil {
				var se api.StatusError
				if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
					fmt.Printf("error: model '%s' not found\n", newModel)
				} else {
					fmt.Println("Error:", err)
				}
				continue
			}

			// Verify audio capability.
			hasAudio := false
			for _, c := range newInfo.Capabilities {
				if c == "audio" {
					hasAudio = true
					break
				}
			}
			if !hasAudio {
				fmt.Printf("error: model '%s' does not support audio input\n", newModel)
				continue
			}

			opts.Model = newModel
			// NOTE(review): a model that reports no limit keeps the previous
			// model's MaxChunkSeconds — confirm that is intended.
			if maxSec := getMaxAudioSeconds(newInfo); maxSec > 0 {
				recorder.MaxChunkSeconds = maxSec - audioHeadroomSeconds
			}

		case line == "/exit" || line == "/bye":
			fmt.Println()
			return nil

		case line != "":
			fmt.Printf("Unknown command: %s (type /? for help)\n", line)
		}
	}
}
func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
@@ -1084,7 +695,8 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
// TODO: remove the projector info and vision info checks below,
// these are left in for backwards compatibility with older servers
@@ -1101,19 +713,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts.ParentModel = info.Details.ParentModel
opts.AudioCapable = slices.Contains(info.Capabilities, model.CapabilityAudio)
audioin, _ := cmd.Flags().GetBool("audioin")
if audioin {
if !opts.AudioCapable {
fmt.Fprintf(os.Stderr, "Warning: audio input disabled — %s does not support audio\n", opts.Model)
} else {
opts.AudioInput = true
opts.MultiModal = true // audio uses the multimodal pipeline
opts.MaxAudioSeconds = getMaxAudioSeconds(info)
}
}
// Check if this is an embedding model
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
@@ -1837,12 +1436,8 @@ type runOptions struct {
System string
Images []api.ImageData
Options map[string]any
MultiModal bool
AudioInput bool
AudioCapable bool // model supports audio input
MaxAudioSeconds int // from model metadata; 0 = use default
Language string // language hint for transcription
KeepAlive *api.Duration
MultiModal bool
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
ShowConnect bool
@@ -2568,7 +2163,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
runCmd.Flags().Bool("audioin", false, "Enable audio input via microphone (press Space to record)")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
@@ -2576,16 +2170,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
runCmd.Flags().MarkHidden("imagegen")
transcribeCmd := &cobra.Command{
Use: "transcribe MODEL",
Short: "Transcribe audio to text using microphone",
Long: "Record audio via microphone and transcribe to text.\nPress Space to start/stop recording. Ctrl+D to exit.",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: TranscribeHandler,
}
transcribeCmd.Flags().String("language", "", "Language hint (e.g. en, es, fr)")
stopCmd := &cobra.Command{
Use: "stop MODEL",
Short: "Stop a running model",
@@ -2706,7 +2290,6 @@ func NewCLI() *cobra.Command {
createCmd,
showCmd,
runCmd,
transcribeCmd,
stopCmd,
pullCmd,
pushCmd,
@@ -2750,7 +2333,6 @@ func NewCLI() *cobra.Command {
createCmd,
showCmd,
runCmd,
transcribeCmd,
stopCmd,
pullCmd,
pushCmd,

View File

@@ -12,7 +12,6 @@ import (
"regexp"
"slices"
"strings"
"time"
"github.com/spf13/cobra"
@@ -24,143 +23,6 @@ import (
"github.com/ollama/ollama/types/model"
)
// errFallbackToText signals that a non-space key was pressed while in audio
// mode and that the caller should switch to regular text input. prefill holds
// the character that was typed so it can seed the text prompt.
type errFallbackToText struct {
	prefill string
}

// Error implements the error interface with a fixed message.
func (errFallbackToText) Error() string {
	const message = "fallback to text"
	return message
}
// doAudioRecording handles the spacebar-driven recording flow.
//
// At the prompt: Space starts recording, Ctrl+C prints an exit hint and
// returns (nil, nil) so the caller can retry, Ctrl+D or a read failure returns
// io.EOF, and any other printable ASCII key returns errFallbackToText carrying
// that key.
//
// While recording: Space or Ctrl+C stops and returns the WAV-encoded audio; a
// read failure stops the recorder and returns io.EOF.
func doAudioRecording(scanner *readline.Instance, recorder *AudioRecorder) ([]byte, error) {
	fmt.Print(">>> \033[90m◉ Press Space to record...\033[0m")

	// Wait for spacebar to start.
	for {
		r, err := scanner.ReadRaw()
		if err != nil {
			return nil, io.EOF
		}
		if r == 3 { // Ctrl+C
			fmt.Print("\r\033[K")
			fmt.Println("Use Ctrl + d or /bye to exit.")
			return nil, nil
		}
		if r == 4 { // Ctrl+D
			fmt.Println()
			return nil, io.EOF
		}
		if r == ' ' {
			break
		}
		// User typed a regular character — fall back to text input with this char.
		if r >= 32 && r < 127 {
			fmt.Print("\r\033[K") // clear the "Press Space" line
			return nil, errFallbackToText{prefill: string(r)}
		}
	}

	// Start recording.
	if err := recorder.Start(); err != nil {
		fmt.Println()
		return nil, fmt.Errorf("start recording: %w", err)
	}

	// Show recording indicator with elapsed time. indicatorStopped is closed
	// when the goroutine has exited, so callers can be sure it prints nothing
	// further.
	done := make(chan struct{})
	indicatorStopped := make(chan struct{})
	go func() {
		defer close(indicatorStopped)
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				d := recorder.Duration()
				fmt.Printf("\r>>> \033[91m◈ Recording... %.1fs\033[0m ", d.Seconds())
			}
		}
	}()

	// Wait for spacebar to stop.
	for {
		r, err := scanner.ReadRaw()
		if err != nil {
			close(done)
			recorder.Stop()
			<-indicatorStopped
			return nil, io.EOF
		}
		if r == ' ' || r == 3 { // Space or Ctrl+C
			break
		}
	}

	close(done)
	// NOTE(review): Stop's error is discarded as in the original flow —
	// confirm whether it can report anything the user should see.
	dur, _ := recorder.Stop()
	// Wait for the indicator to exit before printing the summary; otherwise a
	// late tick could overwrite the "Recorded" line with "Recording...".
	<-indicatorStopped
	fmt.Printf("\r>>> \033[90m◇ Recorded %.1fs\033[0m \n", dur.Seconds())

	// Encode to WAV.
	wav, err := recorder.WAV()
	if err != nil {
		return nil, err
	}
	return wav, nil
}
// tokenCallback is called for each streamed token. It has no return value, so
// a callback cannot report an error or abort the stream.
type tokenCallback func(token string)
// streamChat issues a streaming chat request for model with the given
// messages, with thinking disabled and temperature fixed at 0. Every content
// fragment received is forwarded to onToken (when non-nil) and accumulated.
//
// It returns the accumulated text with surrounding whitespace trimmed, or an
// empty string and the error if the client cannot be created or the chat
// request fails.
func streamChat(cmd *cobra.Command, model string, messages []api.Message, onToken tokenCallback) (string, error) {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return "", err
	}

	var transcript strings.Builder
	collect := func(resp api.ChatResponse) error {
		piece := resp.Message.Content
		transcript.WriteString(piece)
		if onToken != nil {
			onToken(piece)
		}
		return nil
	}

	streaming := true
	request := &api.ChatRequest{
		Model:    model,
		Messages: messages,
		Stream:   &streaming,
		Think:    &api.ThinkValue{Value: false},
		Options:  map[string]any{"temperature": 0},
	}

	if err := client.Chat(cmd.Context(), request, collect); err != nil {
		return "", err
	}
	return strings.TrimSpace(transcript.String()), nil
}
// transcribeAudio asks opts.Model to transcribe audioData. The request is a
// system prompt (extended with opts.Language when set) followed by a user
// message that carries the audio bytes in its Images field.
//
// onToken receives each streamed token and may be nil for silent operation.
// The result is whatever streamChat returns for that conversation.
func transcribeAudio(cmd *cobra.Command, opts runOptions, audioData []byte, onToken tokenCallback) (string, error) {
	instructions := "Transcribe the following audio exactly as spoken. Output only the transcription text, nothing else."
	if lang := opts.Language; lang != "" {
		instructions = instructions + " The audio is in " + lang + "."
	}

	system := api.Message{Role: "system", Content: instructions}
	user := api.Message{
		Role:    "user",
		Content: "Transcribe this audio.",
		Images:  []api.ImageData{audioData},
	}
	return streamChat(cmd, opts.Model, []api.Message{system, user}, onToken)
}
type MultilineState int
const (
@@ -177,11 +39,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /clear Clear session context")
if opts.AudioCapable {
fmt.Fprintln(os.Stderr, " /audio Toggle voice input mode")
} else {
fmt.Fprintln(os.Stderr, " /audio (not supported by current model)")
}
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
@@ -190,7 +47,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
}
fmt.Fprintln(os.Stderr, "")
@@ -279,66 +136,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var multiline MultilineState
var thinkExplicitlySet bool = opts.Think != nil
audioMode := opts.AudioInput
var recorder *AudioRecorder
if audioMode {
var err error
recorder, err = NewAudioRecorder()
if err != nil {
fmt.Fprintf(os.Stderr, "Audio input unavailable: %v\n", err)
audioMode = false
} else {
if opts.MaxAudioSeconds > 0 {
recorder.MaxChunkSeconds = opts.MaxAudioSeconds - 2 // 2s headroom
}
fmt.Fprintln(os.Stderr, "Voice input enabled. Press Space to record, Space again to send.")
}
}
for {
// Audio recording mode: wait for spacebar instead of text input.
if audioMode && recorder != nil {
audioData, err := doAudioRecording(scanner, recorder)
if err != nil {
if err == io.EOF {
fmt.Println()
return nil
}
// User typed a regular key — fall through to normal readline.
if fb, ok := err.(errFallbackToText); ok {
scanner.Prefill = fb.prefill
goto textInput
}
fmt.Fprintf(os.Stderr, "Audio error: %v\n", err)
continue
}
if audioData == nil {
continue
}
// Send audio as the user's input — the model hears and responds.
newMessage := api.Message{
Role: "user",
Images: []api.ImageData{audioData},
}
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
if strings.Contains(err.Error(), "does not support thinking") ||
strings.Contains(err.Error(), "invalid think value") {
fmt.Printf("error: %v\n", err)
continue
}
return err
}
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
continue
}
textInput:
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
@@ -676,29 +474,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} else {
usage()
}
case line == "/audio":
if !opts.AudioCapable {
fmt.Fprintf(os.Stderr, "Audio input not supported by %s\n", opts.Model)
continue
}
if audioMode {
audioMode = false
fmt.Fprintln(os.Stderr, "Voice input disabled.")
} else {
audioMode = true
if recorder == nil {
var recErr error
recorder, recErr = NewAudioRecorder()
if recErr != nil {
fmt.Fprintf(os.Stderr, "Audio input unavailable: %v\n", recErr)
audioMode = false
continue
}
}
opts.MultiModal = true
fmt.Fprintln(os.Stderr, "Voice input enabled. Press Space to record, Space again to send.")
}
continue
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/"):

View File

@@ -84,3 +84,33 @@ func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
assert.Len(t, imgs, 1)
assert.Equal(t, cleaned, "before after")
}
// TestExtractFileDataWAV checks that a .wav path embedded in the input is
// picked up as one attachment and removed from the returned text.
//
// The fixture is a 44-byte canonical PCM WAV header followed by zeroed sample
// data. The RIFF and data chunk sizes are derived from the fixture length so
// the header always describes the bytes actually written.
func TestExtractFileDataWAV(t *testing.T) {
	const (
		headerSize = 44
		fileSize   = 600
		riffSize   = fileSize - 8          // everything after "RIFF" + the size field
		dataSize   = fileSize - headerSize // sample bytes following the header
	)

	dir := t.TempDir()
	fp := filepath.Join(dir, "sample.wav")

	data := make([]byte, fileSize)
	copy(data[:headerSize], []byte{
		'R', 'I', 'F', 'F',
		riffSize & 0xff, (riffSize >> 8) & 0xff, 0x00, 0x00, // file size - 8 (little-endian)
		'W', 'A', 'V', 'E',
		'f', 'm', 't', ' ',
		0x10, 0x00, 0x00, 0x00, // fmt chunk size
		0x01, 0x00, // PCM
		0x01, 0x00, // mono
		0x80, 0x3e, 0x00, 0x00, // 16000 Hz
		0x00, 0x7d, 0x00, 0x00, // byte rate
		0x02, 0x00, // block align
		0x10, 0x00, // 16-bit
		'd', 'a', 't', 'a',
		dataSize & 0xff, (dataSize >> 8) & 0xff, 0x00, 0x00, // data size (little-endian)
	})
	if err := os.WriteFile(fp, data, 0o600); err != nil {
		t.Fatalf("failed to write test audio: %v", err)
	}

	input := "before " + fp + " after"
	cleaned, imgs, err := extractFileData(input)
	assert.NoError(t, err)
	assert.Len(t, imgs, 1)
	assert.Equal(t, "before after", cleaned)
}