diff --git a/cmd/cmd.go b/cmd/cmd.go index 3cecbbe2f..f56a1d4b7 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -520,6 +520,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { // Check for experimental flag isExperimental, _ := cmd.Flags().GetBool("experimental") + yoloMode, _ := cmd.Flags().GetBool("yolo") if interactive { if err := loadOrUnloadModel(cmd, &opts); err != nil { @@ -547,9 +548,9 @@ func RunHandler(cmd *cobra.Command, args []string) error { } } - // Use experimental agent loop with + // Use experimental agent loop with tools if isExperimental { - return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive) + return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode) } return generateInteractive(cmd, opts) @@ -1764,6 +1765,7 @@ func NewCLI() *cobra.Command { runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). 
Set --truncate=false to error instead") runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)") runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools") + runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)") stopCmd := &cobra.Command{ Use: "stop MODEL", diff --git a/readline/errors.go b/readline/errors.go index bb3fbd473..83c26e86e 100644 --- a/readline/errors.go +++ b/readline/errors.go @@ -6,6 +6,9 @@ import ( var ErrInterrupt = errors.New("Interrupt") +// ErrExpandOutput is returned when user presses Ctrl+O to expand tool output +var ErrExpandOutput = errors.New("ExpandOutput") + type InterruptError struct { Line []rune } diff --git a/readline/readline.go b/readline/readline.go index c12327472..8bbef66fd 100644 --- a/readline/readline.go +++ b/readline/readline.go @@ -206,6 +206,9 @@ func (i *Instance) Readline() (string, error) { buf.DeleteBefore() case CharCtrlL: buf.ClearScreen() + case CharCtrlO: + // Ctrl+O - expand tool output + return "", ErrExpandOutput case CharCtrlW: buf.DeleteWord() case CharCtrlZ: diff --git a/readline/types.go b/readline/types.go index f4efa8d92..f1509f63f 100644 --- a/readline/types.go +++ b/readline/types.go @@ -18,6 +18,7 @@ const ( CharCtrlL = 12 CharEnter = 13 CharNext = 14 + CharCtrlO = 15 // Ctrl+O - used for expanding tool output CharPrev = 16 CharBckSearch = 18 CharFwdSearch = 19 diff --git a/x/agent/approval.go b/x/agent/approval.go index e2d429ac6..3a160e4d9 100644 --- a/x/agent/approval.go +++ b/x/agent/approval.go @@ -4,6 +4,7 @@ package agent import ( "fmt" "os" + "path" "path/filepath" "strings" "sync" @@ -179,6 +180,7 @@ func FormatDeniedResult(command string, pattern string) string { // extractBashPrefix extracts a prefix pattern from a bash command. // For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/" // For commands without path args, returns empty string. 
+// Paths with ".." traversal that escape the base directory return empty string for security. func extractBashPrefix(command string) string { // Split command by pipes and get the first part parts := strings.Split(command, "|") @@ -204,8 +206,8 @@ func extractBashPrefix(command string) string { return "" } - // Find the first path-like argument (must contain / or start with .) - // First pass: look for clear paths (containing / or starting with .) + // Find the first path-like argument (must contain / or \ or start with .) + // First pass: look for clear paths (containing path separators or starting with .) for _, arg := range fields[1:] { // Skip flags if strings.HasPrefix(arg, "-") { @@ -215,19 +217,49 @@ func extractBashPrefix(command string) string { if isNumeric(arg) { continue } - // Only process if it looks like a path (contains / or starts with .) - if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") { + // Only process if it looks like a path (contains / or \ or starts with .) + if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") { continue } - // If arg ends with /, it's a directory - use it directly - if strings.HasSuffix(arg, "/") { - return fmt.Sprintf("%s:%s", baseCmd, arg) + // Normalize to forward slashes for consistent cross-platform matching + arg = strings.ReplaceAll(arg, "\\", "/") + + // Security: reject absolute paths + if path.IsAbs(arg) { + return "" // Absolute path - don't create prefix } - // Get the directory part of a file path - dir := filepath.Dir(arg) + + // Normalize the path using stdlib path.Clean (resolves . and ..) 
+ cleaned := path.Clean(arg) + + // Security: reject if cleaned path escapes to parent directory + if strings.HasPrefix(cleaned, "..") { + return "" // Path escapes - don't create prefix + } + + // Security: if original had "..", verify cleaned path didn't escape to sibling + // e.g., "tools/a/b/../../../etc" -> "etc" (escaped tools/ to sibling) + if strings.Contains(arg, "..") { + origBase := strings.SplitN(arg, "/", 2)[0] + cleanedBase := strings.SplitN(cleaned, "/", 2)[0] + if origBase != cleanedBase { + return "" // Path escaped to sibling directory + } + } + + // Check if arg ends with / (explicit directory) + isDir := strings.HasSuffix(arg, "/") + + // Get the directory part + var dir string + if isDir { + dir = cleaned + } else { + dir = path.Dir(cleaned) + } + if dir == "." { - // Path is just a directory like "tools" or "src" (no trailing /) - return fmt.Sprintf("%s:%s/", baseCmd, arg) + return fmt.Sprintf("%s:./", baseCmd) } return fmt.Sprintf("%s:%s/", baseCmd, dir) } @@ -332,6 +364,8 @@ func AllowlistKey(toolName string, args map[string]any) string { } // IsAllowed checks if a tool/command is allowed (exact match or prefix match). +// For bash commands, hierarchical path matching is used - if "cat:tools/" is allowed, +// then "cat:tools/subdir/" is also allowed (subdirectories inherit parent permissions). 
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool { a.mu.RLock() defer a.mu.RUnlock() @@ -342,12 +376,20 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool { return true } - // For bash commands, check prefix matches + // For bash commands, check prefix matches with hierarchical path support if toolName == "bash" { if cmd, ok := args["command"].(string); ok { prefix := extractBashPrefix(cmd) - if prefix != "" && a.prefixes[prefix] { - return true + if prefix != "" { + // Check exact prefix match first + if a.prefixes[prefix] { + return true + } + // Check hierarchical match: if any stored prefix is a parent of current prefix + // e.g., stored "cat:tools/" should match current "cat:tools/subdir/" + if a.matchesHierarchicalPrefix(prefix) { + return true + } } } } @@ -360,6 +402,40 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool { return false } +// matchesHierarchicalPrefix checks if the given prefix matches any stored prefix hierarchically. +// For example, if "cat:tools/" is stored, it will match "cat:tools/subdir/" or "cat:tools/a/b/c/". 
+func (a *ApprovalManager) matchesHierarchicalPrefix(currentPrefix string) bool { + // Split prefix into command and path parts (format: "cmd:path/") + colonIdx := strings.Index(currentPrefix, ":") + if colonIdx == -1 { + return false + } + currentCmd := currentPrefix[:colonIdx] + currentPath := currentPrefix[colonIdx+1:] + + for storedPrefix := range a.prefixes { + storedColonIdx := strings.Index(storedPrefix, ":") + if storedColonIdx == -1 { + continue + } + storedCmd := storedPrefix[:storedColonIdx] + storedPath := storedPrefix[storedColonIdx+1:] + + // Commands must match exactly + if currentCmd != storedCmd { + continue + } + + // Check if current path starts with stored path (hierarchical match) + // e.g., "tools/subdir/" starts with "tools/" + if strings.HasPrefix(currentPath, storedPath) { + return true + } + } + + return false +} + // AddToAllowlist adds a tool/command to the session allowlist. // For bash commands, it adds the prefix pattern instead of exact command. func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) { @@ -443,11 +519,12 @@ func formatToolDisplay(toolName string, args map[string]any) string { } } - // For web search, show query + // For web search, show query and internet notice if toolName == "web_search" { if query, ok := args["query"].(string); ok { sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName)) - sb.WriteString(fmt.Sprintf("Query: %s", query)) + sb.WriteString(fmt.Sprintf("Query: %s\n", query)) + sb.WriteString("Uses internet via ollama.com") return sb.String() } } @@ -951,3 +1028,79 @@ func FormatDenyResult(toolName string, reason string) string { } return fmt.Sprintf("User denied execution of %s.", toolName) } + +// PromptYesNo displays a simple Yes/No prompt and returns the user's choice. +// Returns true for Yes, false for No. 
+func PromptYesNo(question string) (bool, error) { + fd := int(os.Stdin.Fd()) + oldState, err := term.MakeRaw(fd) + if err != nil { + return false, err + } + defer term.Restore(fd, oldState) + + selected := 0 // 0 = Yes, 1 = No + options := []string{"Yes", "No"} + + // Hide cursor + fmt.Fprint(os.Stderr, "\033[?25l") + defer fmt.Fprint(os.Stderr, "\033[?25h") + + renderYesNo := func() { + // Move to start of line and clear + fmt.Fprintf(os.Stderr, "\r\033[K") + fmt.Fprintf(os.Stderr, "\033[36m%s\033[0m ", question) + for i, opt := range options { + if i == selected { + fmt.Fprintf(os.Stderr, "\033[1;32m[%s]\033[0m ", opt) + } else { + fmt.Fprintf(os.Stderr, "\033[90m %s \033[0m ", opt) + } + } + fmt.Fprintf(os.Stderr, "\033[90m(←/→ or y/n, Enter to confirm)\033[0m") + } + + renderYesNo() + + buf := make([]byte, 3) + for { + n, err := os.Stdin.Read(buf) + if err != nil { + return false, err + } + + if n == 1 { + switch buf[0] { + case 'y', 'Y': + selected = 0 + renderYesNo() + case 'n', 'N': + selected = 1 + renderYesNo() + case '\r', '\n': // Enter + fmt.Fprintf(os.Stderr, "\r\033[K") // Clear line + return selected == 0, nil + case 3: // Ctrl+C + fmt.Fprintf(os.Stderr, "\r\033[K") + return false, nil + case 27: // Escape - could be arrow key + // Read more bytes for arrow keys + continue + } + } else if n == 3 && buf[0] == 27 && buf[1] == 91 { + // Arrow keys + switch buf[2] { + case 'D': // Left + if selected > 0 { + selected-- + } + renderYesNo() + case 'C': // Right + if selected < len(options)-1 { + selected++ + } + renderYesNo() + } + } + } +} diff --git a/x/agent/approval_test.go b/x/agent/approval_test.go index 652ca8c3b..a05ea3d42 100644 --- a/x/agent/approval_test.go +++ b/x/agent/approval_test.go @@ -151,6 +151,27 @@ func TestExtractBashPrefix(t *testing.T) { command: "head -n 100", expected: "", }, + // Path traversal security tests + { + name: "path traversal - parent escape", + command: "cat tools/../../etc/passwd", + expected: "", // Should NOT 
create a prefix - path escapes + }, + { + name: "path traversal - deep escape", + command: "cat tools/a/b/../../../etc/passwd", + expected: "", // Normalizes to "etc/passwd" - escapes tools/ to a sibling directory + }, + { + name: "path traversal - absolute path", + command: "cat /etc/passwd", + expected: "", // Absolute paths should not create prefix + }, + { + name: "path with safe dotdot - normalized", + command: "cat tools/subdir/../file.go", + expected: "cat:tools/", // Normalizes to tools/file.go - safe, creates prefix + }, } for _, tt := range tests { @@ -164,6 +185,34 @@ func TestExtractBashPrefix(t *testing.T) { } } +func TestApprovalManager_PathTraversalBlocked(t *testing.T) { + am := NewApprovalManager() + + // Allow "cat tools/file.go" - creates prefix "cat:tools/" + am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"}) + + // Path traversal attack: should NOT be allowed + if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../etc/passwd"}) { + t.Error("SECURITY: path traversal attack should NOT be allowed") + } + + // Another traversal variant + if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../../etc/shadow"}) { + t.Error("SECURITY: deep path traversal should NOT be allowed") + } + + // Valid subdirectory access should still work + if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) { + t.Error("expected cat tools/subdir/file.go to be allowed") + } + + // Safe ".." 
that normalizes to within allowed directory should work + // tools/subdir/../other.go normalizes to tools/other.go which is under tools/ + if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/../other.go"}) { + t.Error("expected cat tools/subdir/../other.go to be allowed (normalizes to tools/other.go)") + } +} + func TestApprovalManager_PrefixAllowlist(t *testing.T) { am := NewApprovalManager() @@ -186,6 +235,119 @@ func TestApprovalManager_PrefixAllowlist(t *testing.T) { } } +func TestApprovalManager_HierarchicalPrefixAllowlist(t *testing.T) { + am := NewApprovalManager() + + // Allow "cat tools/file.go" - this creates prefix "cat:tools/" + am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"}) + + // Should allow subdirectories (hierarchical matching) + if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) { + t.Error("expected cat tools/subdir/file.go to be allowed via hierarchical prefix") + } + + // Should allow deeply nested subdirectories + if !am.IsAllowed("bash", map[string]any{"command": "cat tools/a/b/c/deep.go"}) { + t.Error("expected cat tools/a/b/c/deep.go to be allowed via hierarchical prefix") + } + + // Should still allow same directory + if !am.IsAllowed("bash", map[string]any{"command": "cat tools/another.go"}) { + t.Error("expected cat tools/another.go to be allowed") + } + + // Should NOT allow different base directory + if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) { + t.Error("expected cat src/main.go to NOT be allowed") + } + + // Should NOT allow different command even in subdirectory + if am.IsAllowed("bash", map[string]any{"command": "ls tools/subdir/"}) { + t.Error("expected ls tools/subdir/ to NOT be allowed (different command)") + } + + // Should NOT allow similar but different directory name + if am.IsAllowed("bash", map[string]any{"command": "cat toolsbin/file.go"}) { + t.Error("expected cat toolsbin/file.go to NOT be allowed (different directory)") 
+ } +} + +func TestApprovalManager_HierarchicalPrefixAllowlist_CrossPlatform(t *testing.T) { + am := NewApprovalManager() + + // Allow with forward slashes (Unix-style) + am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"}) + + // Should work with backslashes too (Windows-style) - normalized internally + if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\subdir\\file.go"}) { + t.Error("expected cat tools\\subdir\\file.go to be allowed via hierarchical prefix (Windows path)") + } + + // Mixed slashes should also work + if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\a/b\\c/deep.go"}) { + t.Error("expected mixed slash path to be allowed via hierarchical prefix") + } +} + +func TestMatchesHierarchicalPrefix(t *testing.T) { + am := NewApprovalManager() + + // Add prefix for "cat:tools/" + am.prefixes["cat:tools/"] = true + + tests := []struct { + name string + prefix string + expected bool + }{ + { + name: "exact match", + prefix: "cat:tools/", + expected: true, // exact match also passes HasPrefix - caller handles exact match first + }, + { + name: "subdirectory", + prefix: "cat:tools/subdir/", + expected: true, + }, + { + name: "deeply nested", + prefix: "cat:tools/a/b/c/", + expected: true, + }, + { + name: "different base directory", + prefix: "cat:src/", + expected: false, + }, + { + name: "different command same path", + prefix: "ls:tools/", + expected: false, + }, + { + name: "similar directory name", + prefix: "cat:toolsbin/", + expected: false, + }, + { + name: "invalid prefix format", + prefix: "cattools", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := am.matchesHierarchicalPrefix(tt.prefix) + if result != tt.expected { + t.Errorf("matchesHierarchicalPrefix(%q) = %v, expected %v", + tt.prefix, result, tt.expected) + } + }) + } +} + func TestFormatApprovalResult(t *testing.T) { tests := []struct { name string diff --git a/x/cmd/run.go b/x/cmd/run.go 
index 2a76a5592..178659e9f 100644 --- a/x/cmd/run.go +++ b/x/cmd/run.go @@ -6,10 +6,12 @@ import ( "errors" "fmt" "io" + "net/url" "os" "os/signal" "strings" "syscall" + "time" "github.com/spf13/cobra" "golang.org/x/term" @@ -22,6 +24,101 @@ import ( "github.com/ollama/ollama/x/tools" ) +// Tool output capping constants +const ( + // localModelTokenLimit is the token limit for local models (smaller context). + localModelTokenLimit = 4000 + + // defaultTokenLimit is the token limit for cloud/remote models. + defaultTokenLimit = 10000 + + // charsPerToken is a rough estimate of characters per token. + // TODO: Estimate tokens more accurately using tokenizer if available + charsPerToken = 4 +) + +// isLocalModel checks if the model is running locally (not a cloud model). +// TODO: Improve local/cloud model identification - could check model metadata +func isLocalModel(modelName string) bool { + return !strings.HasSuffix(modelName, "-cloud") +} + +// isLocalServer checks if connecting to a local Ollama server. +// TODO: Could also check other indicators of local vs cloud server +func isLocalServer() bool { + host := os.Getenv("OLLAMA_HOST") + if host == "" { + return true // Default is localhost:11434 + } + + // Parse the URL to check host + parsed, err := url.Parse(host) + if err != nil { + return true // If can't parse, assume local + } + + hostname := parsed.Hostname() + return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434") +} + +// truncateToolOutput truncates tool output to prevent context overflow. +// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote. +func truncateToolOutput(output, modelName string) string { + var tokenLimit int + if isLocalModel(modelName) && isLocalServer() { + tokenLimit = localModelTokenLimit + } else { + tokenLimit = defaultTokenLimit + } + + maxChars := tokenLimit * charsPerToken + if len(output) > maxChars { + return output[:maxChars] + "\n... 
(output truncated)" + } + return output +} + +// waitForOllamaSignin shows the signin URL and polls until authentication completes. +func waitForOllamaSignin(ctx context.Context) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + // Get signin URL from initial Whoami call + _, err = client.Whoami(ctx) + if err != nil { + var aErr api.AuthorizationError + if errors.As(err, &aErr) && aErr.SigninURL != "" { + fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n") + fmt.Fprintf(os.Stderr, " \033[36m%s\033[0m\n\n", aErr.SigninURL) + fmt.Fprintf(os.Stderr, " \033[90mWaiting for sign in to complete...\033[0m") + + // Poll until auth succeeds + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + fmt.Fprintf(os.Stderr, "\n") + return ctx.Err() + case <-ticker.C: + user, whoamiErr := client.Whoami(ctx) + if whoamiErr == nil && user != nil && user.Name != "" { + fmt.Fprintf(os.Stderr, "\r\033[K \033[32mSigned in as %s\033[0m\n", user.Name) + return nil + } + // Still waiting, show dot + fmt.Fprintf(os.Stderr, ".") + } + } + } + return err + } + return nil +} + // RunOptions contains options for running an interactive agent session. type RunOptions struct { Model string @@ -37,6 +134,16 @@ type RunOptions struct { // Agent fields (managed externally for session persistence) Tools *tools.Registry Approval *agent.ApprovalManager + + // YoloMode skips all tool approval prompts + YoloMode bool + + // LastToolOutput stores the full output of the last tool execution + // for Ctrl+O expansion. Updated by Chat(), read by caller. + LastToolOutput *string + + // LastToolOutputTruncated stores the truncated version shown inline + LastToolOutputTruncated *string } // Chat runs an agent chat loop with tool support. 
@@ -77,6 +184,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) { var thinkTagOpened bool = false var thinkTagClosed bool = false var pendingToolCalls []api.ToolCall + var consecutiveErrors int // Track consecutive 500 errors for retry limit role := "assistant" messages := opts.Messages @@ -159,6 +267,58 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) { return nil, nil } + // Check for 401 Unauthorized - prompt user to sign in + var authErr api.AuthorizationError + if errors.As(err, &authErr) { + p.StopAndClear() + fmt.Fprintf(os.Stderr, "\033[33mAuthentication required to use this cloud model.\033[0m\n") + result, promptErr := agent.PromptYesNo("Sign in to Ollama?") + if promptErr == nil && result { + if signinErr := waitForOllamaSignin(ctx); signinErr == nil { + // Retry the chat request + fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n") + continue // Retry the loop + } + } + return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate") + } + + // Check for 500 errors (often tool parsing failures) - inform the model + var statusErr api.StatusError + if errors.As(err, &statusErr) && statusErr.StatusCode >= 500 { + consecutiveErrors++ + p.StopAndClear() + + if consecutiveErrors >= 3 { + fmt.Fprintf(os.Stderr, "\033[31m✗ Too many consecutive errors, giving up\033[0m\n") + return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage) + } + + fmt.Fprintf(os.Stderr, "\033[33m⚠ Server error (attempt %d/3): %s\033[0m\n", consecutiveErrors, statusErr.ErrorMessage) + + // Include both the model's response and the error so it can learn + assistantContent := fullResponse.String() + if assistantContent == "" { + assistantContent = "(empty response)" + } + errorMsg := fmt.Sprintf("Your previous response caused an error: %s\n\nYour response was:\n%s\n\nPlease try again with a valid response.", statusErr.ErrorMessage, assistantContent) + messages = append(messages, + 
api.Message{Role: "user", Content: errorMsg}, + ) + + // Reset state and retry + fullResponse.Reset() + thinkingContent.Reset() + thinkTagOpened = false + thinkTagClosed = false + pendingToolCalls = nil + state = &displayResponseState{} + p = progress.NewProgress(os.Stderr) + spinner = progress.NewSpinner("") + p.Add("", spinner) + continue + } + if strings.Contains(err.Error(), "upstream error") { p.StopAndClear() fmt.Println("An error occurred while processing your message. Please try again.") @@ -168,6 +328,9 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) { return nil, err } + // Reset consecutive error counter on success + consecutiveErrors = 0 + // If no tool calls, we're done if len(pendingToolCalls) == 0 || toolRegistry == nil { break @@ -216,7 +379,12 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) { } // Check approval (uses prefix matching for bash commands) - if !skipApproval && !approval.IsAllowed(toolName, args) { + // In yolo mode, skip all approval prompts + if opts.YoloMode { + if !skipApproval { + fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args)) + } + } else if !skipApproval && !approval.IsAllowed(toolName, args) { result, err := approval.RequestApproval(toolName, args) if err != nil { fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err) @@ -250,6 +418,23 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) { // Execute the tool toolResult, err := toolRegistry.Execute(call) if err != nil { + // Check if web search needs authentication + if errors.Is(err, tools.ErrWebSearchAuthRequired) { + // Prompt user to sign in + fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n") + result, promptErr := agent.PromptYesNo("Sign in to Ollama?") + if promptErr == nil && result { + // Get signin URL and wait for auth completion + if signinErr := waitForOllamaSignin(ctx); signinErr == nil { + // Retry the web search + 
fmt.Fprintf(os.Stderr, "\033[90m Retrying web search...\033[0m\n") + toolResult, err = toolRegistry.Execute(call) + if err == nil { + goto toolSuccess + } + } + } + } fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err) toolResults = append(toolResults, api.Message{ Role: "tool", @@ -258,20 +443,34 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) { }) continue } + toolSuccess: // Display tool output (truncated for display) + truncatedOutput := "" if toolResult != "" { output := toolResult if len(output) > 300 { - output = output[:300] + "... (truncated)" + output = output[:300] + "... (truncated, press Ctrl+O to expand)" } + truncatedOutput = output // Show result in grey, indented fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n ")) } + // Store full and truncated output for Ctrl+O toggle + if opts.LastToolOutput != nil { + *opts.LastToolOutput = toolResult + } + if opts.LastToolOutputTruncated != nil { + *opts.LastToolOutputTruncated = truncatedOutput + } + + // Truncate output to prevent context overflow + toolResultForLLM := truncateToolOutput(toolResult, opts.Model) + toolResults = append(toolResults, api.Message{ Role: "tool", - Content: toolResult, + Content: toolResultForLLM, ToolCallID: call.ID, }) } @@ -449,7 +648,8 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool // GenerateInteractive runs an interactive agent session. // This is called from cmd.go when --experimental flag is set. -func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error { +// If yoloMode is true, all tool approvals are skipped. 
+func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error { scanner, err := readline.New(readline.Prompt{ Prompt: ">>> ", AltPrompt: "... ", @@ -474,11 +674,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op var toolRegistry *tools.Registry if supportsTools { toolRegistry = tools.DefaultRegistry() - fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", ")) - - // Check for OLLAMA_API_KEY for web search - if os.Getenv("OLLAMA_API_KEY") == "" { - fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n") + if toolRegistry.Count() > 0 { + fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", ")) + } + if yoloMode { + fmt.Fprintf(os.Stderr, "\033[33m⚠ YOLO mode: All tool approvals will be skipped\033[0m\n") } } else { fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n") @@ -490,6 +690,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op var messages []api.Message var sb strings.Builder + // Track last tool output for Ctrl+O toggle + var lastToolOutput string + var lastToolOutputTruncated string + var toolOutputExpanded bool + for { line, err := scanner.Readline() switch { @@ -502,6 +707,20 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op } sb.Reset() continue + case errors.Is(err, readline.ErrExpandOutput): + // Ctrl+O pressed - toggle between expanded and collapsed tool output + if lastToolOutput == "" { + fmt.Fprintf(os.Stderr, "\033[90mNo tool output to expand\033[0m\n") + } else if toolOutputExpanded { + // Currently expanded, show truncated + fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutputTruncated, "\n", "\n ")) + 
toolOutputExpanded = false + } else { + // Currently collapsed, show full + fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutput, "\n", "\n ")) + toolOutputExpanded = true + } + continue case err != nil: return err } @@ -524,6 +743,9 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op fmt.Fprintln(os.Stderr, " /bye Exit") fmt.Fprintln(os.Stderr, " /?, /help Help for a command") fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(os.Stderr, "Keyboard Shortcuts:") + fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output") + fmt.Fprintln(os.Stderr, "") continue case strings.HasPrefix(line, "/"): fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0]) @@ -537,16 +759,21 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op messages = append(messages, newMessage) opts := RunOptions{ - Model: modelName, - Messages: messages, - WordWrap: wordWrap, - Options: options, - Think: think, - HideThinking: hideThinking, - KeepAlive: keepAlive, - Tools: toolRegistry, - Approval: approval, + Model: modelName, + Messages: messages, + WordWrap: wordWrap, + Options: options, + Think: think, + HideThinking: hideThinking, + KeepAlive: keepAlive, + Tools: toolRegistry, + Approval: approval, + YoloMode: yoloMode, + LastToolOutput: &lastToolOutput, + LastToolOutputTruncated: &lastToolOutputTruncated, } + // Reset expanded state for new tool execution + toolOutputExpanded = false assistant, err := Chat(cmd.Context(), opts) if err != nil { diff --git a/x/cmd/run_test.go b/x/cmd/run_test.go new file mode 100644 index 000000000..a65e8cc80 --- /dev/null +++ b/x/cmd/run_test.go @@ -0,0 +1,180 @@ +package cmd + +import ( + "testing" +) + +func TestIsLocalModel(t *testing.T) { + tests := []struct { + name string + modelName string + expected bool + }{ + { + name: "local model without suffix", + modelName: "llama3.2", + expected: true, + }, + { + name: "local model with version", + modelName: 
"qwen2.5:7b", + expected: true, + }, + { + name: "cloud model", + modelName: "gpt-4-cloud", + expected: false, + }, + { + name: "cloud model with version", + modelName: "claude-3-cloud", + expected: false, + }, + { + name: "empty model name", + modelName: "", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isLocalModel(tt.modelName) + if result != tt.expected { + t.Errorf("isLocalModel(%q) = %v, expected %v", tt.modelName, result, tt.expected) + } + }) + } +} + +func TestIsLocalServer(t *testing.T) { + tests := []struct { + name string + host string + expected bool + }{ + { + name: "empty host (default)", + host: "", + expected: true, + }, + { + name: "localhost", + host: "http://localhost:11434", + expected: true, + }, + { + name: "127.0.0.1", + host: "http://127.0.0.1:11434", + expected: true, + }, + { + name: "custom port on localhost", + host: "http://localhost:8080", + expected: true, // localhost is always considered local + }, + { + name: "remote host", + host: "http://ollama.example.com:11434", + expected: true, // has :11434 + }, + { + name: "remote host different port", + host: "http://ollama.example.com:8080", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", tt.host) + result := isLocalServer() + if result != tt.expected { + t.Errorf("isLocalServer() with OLLAMA_HOST=%q = %v, expected %v", tt.host, result, tt.expected) + } + }) + } +} + +func TestTruncateToolOutput(t *testing.T) { + // Create outputs of different sizes + localLimitOutput := make([]byte, 20000) // > 4k tokens (16k chars) + defaultLimitOutput := make([]byte, 50000) // > 10k tokens (40k chars) + for i := range localLimitOutput { + localLimitOutput[i] = 'a' + } + for i := range defaultLimitOutput { + defaultLimitOutput[i] = 'b' + } + + tests := []struct { + name string + output string + modelName string + host string + shouldTrim bool + expectedLimit int + }{ 
+ { + name: "short output local model", + output: "hello world", + modelName: "llama3.2", + host: "", + shouldTrim: false, + expectedLimit: localModelTokenLimit, + }, + { + name: "long output local model - trimmed at 4k", + output: string(localLimitOutput), + modelName: "llama3.2", + host: "", + shouldTrim: true, + expectedLimit: localModelTokenLimit, + }, + { + name: "long output cloud model - uses 10k limit", + output: string(localLimitOutput), // 20k chars, under 10k token limit + modelName: "gpt-4-cloud", + host: "", + shouldTrim: false, + expectedLimit: defaultTokenLimit, + }, + { + name: "very long output cloud model - trimmed at 10k", + output: string(defaultLimitOutput), + modelName: "gpt-4-cloud", + host: "", + shouldTrim: true, + expectedLimit: defaultTokenLimit, + }, + { + name: "long output remote server - uses 10k limit", + output: string(localLimitOutput), + modelName: "llama3.2", + host: "http://remote.example.com:8080", + shouldTrim: false, + expectedLimit: defaultTokenLimit, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", tt.host) + result := truncateToolOutput(tt.output, tt.modelName) + + if tt.shouldTrim { + maxLen := tt.expectedLimit * charsPerToken + if len(result) > maxLen+50 { // +50 for the truncation message + t.Errorf("expected output to be truncated to ~%d chars, got %d", maxLen, len(result)) + } + if result == tt.output { + t.Error("expected output to be truncated but it wasn't") + } + } else { + if result != tt.output { + t.Error("expected output to not be truncated") + } + } + }) + } +} diff --git a/x/tools/registry.go b/x/tools/registry.go index f9136c9d7..fba0898b7 100644 --- a/x/tools/registry.go +++ b/x/tools/registry.go @@ -3,6 +3,7 @@ package tools import ( "fmt" + "os" "sort" "github.com/ollama/ollama/api" @@ -88,9 +89,16 @@ func (r *Registry) Count() int { } // DefaultRegistry creates a registry with all built-in tools. 
+// Tools can be disabled by setting environment variables to any non-empty value: +// - OLLAMA_AGENT_DISABLE_WEBSEARCH=1 disables web_search +// - OLLAMA_AGENT_DISABLE_BASH=1 disables bash func DefaultRegistry() *Registry { r := NewRegistry() - r.Register(&WebSearchTool{}) - r.Register(&BashTool{}) + if os.Getenv("OLLAMA_AGENT_DISABLE_WEBSEARCH") == "" { + r.Register(&WebSearchTool{}) + } + if os.Getenv("OLLAMA_AGENT_DISABLE_BASH") == "" { + r.Register(&BashTool{}) + } return r } diff --git a/x/tools/registry_test.go b/x/tools/registry_test.go index 59539c721..a37410936 100644 --- a/x/tools/registry_test.go +++ b/x/tools/registry_test.go @@ -108,6 +108,57 @@ func TestDefaultRegistry(t *testing.T) { } } +func TestDefaultRegistry_DisableWebsearch(t *testing.T) { + t.Setenv("OLLAMA_AGENT_DISABLE_WEBSEARCH", "1") + + r := DefaultRegistry() + + if r.Count() != 1 { + t.Errorf("expected 1 tool with websearch disabled, got %d", r.Count()) + } + + _, ok := r.Get("bash") + if !ok { + t.Error("expected bash tool in registry") + } + + _, ok = r.Get("web_search") + if ok { + t.Error("expected web_search to be disabled") + } +} + +func TestDefaultRegistry_DisableBash(t *testing.T) { + t.Setenv("OLLAMA_AGENT_DISABLE_BASH", "1") + + r := DefaultRegistry() + + if r.Count() != 1 { + t.Errorf("expected 1 tool with bash disabled, got %d", r.Count()) + } + + _, ok := r.Get("web_search") + if !ok { + t.Error("expected web_search tool in registry") + } + + _, ok = r.Get("bash") + if ok { + t.Error("expected bash to be disabled") + } +} + +func TestDefaultRegistry_DisableBoth(t *testing.T) { + t.Setenv("OLLAMA_AGENT_DISABLE_WEBSEARCH", "1") + t.Setenv("OLLAMA_AGENT_DISABLE_BASH", "1") + + r := DefaultRegistry() + + if r.Count() != 0 { + t.Errorf("expected 0 tools with both disabled, got %d", r.Count()) + } +} + func TestBashTool_Schema(t *testing.T) { tool := &BashTool{} diff --git a/x/tools/websearch.go b/x/tools/websearch.go index 04c3578e1..16b0dde2c 100644 --- a/x/tools/websearch.go +++ b/x/tools/websearch.go @@ 
-2,15 +2,19 @@ package tools import ( "bytes" + "context" "encoding/json" + "errors" "fmt" "io" "net/http" - "os" + "net/url" + "strconv" "strings" "time" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" ) const ( @@ -18,6 +22,9 @@ const ( webSearchTimeout = 15 * time.Second ) +// ErrWebSearchAuthRequired is returned when web search requires authentication +var ErrWebSearchAuthRequired = errors.New("web search requires authentication") + // WebSearchTool implements web search using Ollama's hosted API. type WebSearchTool struct{} @@ -68,17 +75,13 @@ type webSearchResult struct { } // Execute performs the web search. +// Uses Ollama key signing for authentication - this makes requests via ollama.com API. func (w *WebSearchTool) Execute(args map[string]any) (string, error) { query, ok := args["query"].(string) if !ok || query == "" { return "", fmt.Errorf("query parameter is required") } - apiKey := os.Getenv("OLLAMA_API_KEY") - if apiKey == "" { - return "", fmt.Errorf("OLLAMA_API_KEY environment variable is required for web search") - } - // Prepare request reqBody := webSearchRequest{ Query: query, @@ -90,13 +93,34 @@ func (w *WebSearchTool) Execute(args map[string]any) (string, error) { return "", fmt.Errorf("marshaling request: %w", err) } - req, err := http.NewRequest("POST", webSearchAPI, bytes.NewBuffer(jsonBody)) + // Parse URL and add timestamp for signing + searchURL, err := url.Parse(webSearchAPI) + if err != nil { + return "", fmt.Errorf("parsing search URL: %w", err) + } + + q := searchURL.Query() + q.Add("ts", strconv.FormatInt(time.Now().Unix(), 10)) + searchURL.RawQuery = q.Encode() + + // Sign the request using Ollama key (~/.ollama/id_ed25519) + // This authenticates with ollama.com using the local signing key + ctx := context.Background() + data := fmt.Appendf(nil, "%s,%s", http.MethodPost, searchURL.RequestURI()) + signature, err := auth.Sign(ctx, data) + if err != nil { + return "", fmt.Errorf("signing request: %w", err) + } + + 
req, err := http.NewRequestWithContext(ctx, http.MethodPost, searchURL.String(), bytes.NewBuffer(jsonBody)) if err != nil { return "", fmt.Errorf("creating request: %w", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) + if signature != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature)) + } // Send request client := &http.Client{Timeout: webSearchTimeout} @@ -111,6 +135,9 @@ func (w *WebSearchTool) Execute(args map[string]any) (string, error) { return "", fmt.Errorf("reading response: %w", err) } + if resp.StatusCode == http.StatusUnauthorized { + return "", ErrWebSearchAuthRequired + } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("web search API returned status %d: %s", resp.StatusCode, string(body)) } diff --git a/x/tools/websearch_test.go b/x/tools/websearch_test.go new file mode 100644 index 000000000..8f5774728 --- /dev/null +++ b/x/tools/websearch_test.go @@ -0,0 +1,58 @@ +package tools + +import ( + "errors" + "testing" +) + +func TestWebSearchTool_Name(t *testing.T) { + tool := &WebSearchTool{} + if tool.Name() != "web_search" { + t.Errorf("expected name 'web_search', got '%s'", tool.Name()) + } +} + +func TestWebSearchTool_Description(t *testing.T) { + tool := &WebSearchTool{} + if tool.Description() == "" { + t.Error("expected non-empty description") + } +} + +func TestWebSearchTool_Execute_MissingQuery(t *testing.T) { + tool := &WebSearchTool{} + + // Test with no query + _, err := tool.Execute(map[string]any{}) + if err == nil { + t.Error("expected error for missing query") + } + + // Test with empty query + _, err = tool.Execute(map[string]any{"query": ""}) + if err == nil { + t.Error("expected error for empty query") + } +} + +func TestErrWebSearchAuthRequired(t *testing.T) { + // Test that the error type exists and can be checked with errors.Is + err := ErrWebSearchAuthRequired + if err == nil { + t.Fatal("ErrWebSearchAuthRequired should not be nil") 
+ } + + if err.Error() != "web search requires authentication" { + t.Errorf("unexpected error message: %s", err.Error()) + } + + // errors.Is must not match an error that only embeds the sentinel's message (errors.New does not wrap) + wrappedErr := errors.New("wrapped: " + err.Error()) + if errors.Is(wrappedErr, ErrWebSearchAuthRequired) { + t.Error("wrapped error should not match with errors.Is") + } + + if !errors.Is(ErrWebSearchAuthRequired, ErrWebSearchAuthRequired) { + t.Error("ErrWebSearchAuthRequired should match itself with errors.Is") + } +}