diff --git a/cmd/agent_loop_test.go b/cmd/agent_loop_test.go new file mode 100644 index 000000000..330703cb7 --- /dev/null +++ b/cmd/agent_loop_test.go @@ -0,0 +1,402 @@ +package cmd + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +// TestToolMessage verifies that tool messages are constructed correctly +// with ToolName and ToolCallID preserved from the tool call. +func TestToolMessage(t *testing.T) { + tests := []struct { + name string + call api.ToolCall + content string + expected api.Message + }{ + { + name: "basic tool message with ID", + call: api.ToolCall{ + ID: "call_abc123", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + content: "Sunny, 22°C", + expected: api.Message{ + Role: "tool", + Content: "Sunny, 22°C", + ToolName: "get_weather", + ToolCallID: "call_abc123", + }, + }, + { + name: "tool message without ID", + call: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "calculate", + Arguments: api.ToolCallFunctionArguments{ + "expression": "2+2", + }, + }, + }, + content: "4", + expected: api.Message{ + Role: "tool", + Content: "4", + ToolName: "calculate", + // ToolCallID should be empty when call.ID is empty + }, + }, + { + name: "MCP tool message", + call: api.ToolCall{ + ID: "call_mcp123", + Function: api.ToolCallFunction{ + Name: "mcp_websearch_search", + Arguments: api.ToolCallFunctionArguments{ + "query": "ollama agents", + }, + }, + }, + content: "Found 10 results", + expected: api.Message{ + Role: "tool", + Content: "Found 10 results", + ToolName: "mcp_websearch_search", + ToolCallID: "call_mcp123", + }, + }, + { + name: "skill tool message", + call: api.ToolCall{ + ID: "call_skill456", + Function: api.ToolCallFunction{ + Name: "run_skill_script", + Arguments: api.ToolCallFunctionArguments{ + "skill": "calculator", + "command": "python scripts/calc.py 2+2", + }, + }, + }, + content: "Result: 4", + expected: api.Message{ + Role: "tool", + Content: "Result: 4", + ToolName: "run_skill_script", + ToolCallID: "call_skill456", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := toolMessage(tt.call, tt.content) + if diff := cmp.Diff(tt.expected, result); diff != "" { + t.Errorf("toolMessage() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +// TestAssistantMessageWithThinking verifies that assistant messages +// in the tool loop should include thinking content. +func TestAssistantMessageConstruction(t *testing.T) { + tests := []struct { + name string + content string + thinking string + toolCalls []api.ToolCall + expectedMsg api.Message + }{ + { + name: "assistant with thinking and tool calls", + content: "", + thinking: "I need to check the weather for Paris.", + toolCalls: []api.ToolCall{ + { + ID: "call_1", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + }, + expectedMsg: api.Message{ + Role: "assistant", + Content: "", + Thinking: "I need to check the weather for Paris.", + ToolCalls: []api.ToolCall{ + { + ID: "call_1", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + }, + }, + }, + { + name: "assistant with content, thinking, and tool calls", + content: "Let me check that for you.", + thinking: "User wants weather info.", + toolCalls: []api.ToolCall{ + { + ID: "call_2", + Function: api.ToolCallFunction{ + Name: "search", + Arguments: api.ToolCallFunctionArguments{"query": "weather"}, + }, + }, + }, + expectedMsg: api.Message{ + Role: "assistant", + Content: "Let me check that for you.", + Thinking: "User wants weather info.", + ToolCalls: []api.ToolCall{ + { + ID: "call_2", + Function: api.ToolCallFunction{ + Name: "search", + Arguments: api.ToolCallFunctionArguments{"query": "weather"}, + }, + }, + }, + }, + }, + { + name: "assistant with multiple tool calls", + content: "", + thinking: "I'll check both cities.", + toolCalls: []api.ToolCall{ + { + ID: "call_a", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + { + ID: "call_b", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "London"}, + }, + }, + }, + expectedMsg: api.Message{ + Role: "assistant", + Content: "", + Thinking: "I'll check both cities.", + ToolCalls: []api.ToolCall{ + { + ID: "call_a", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + { + ID: "call_b", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "London"}, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Simulate the assistant message construction as done in chat() + assistantMsg := api.Message{ + Role: "assistant", + Content: tt.content, + Thinking: tt.thinking, + ToolCalls: tt.toolCalls, + } + + if diff := cmp.Diff(tt.expectedMsg, assistantMsg); diff != "" { + t.Errorf("assistant message mismatch (-want +got):\n%s", diff) + } + }) + } +} + +// TestMessageStitchingOrder verifies that messages in a tool loop +// are stitched in the correct order: +// 1. User message +// 2. Assistant message with tool calls (and thinking) +// 3. Tool result messages (one per tool call, in order) +// 4. Next assistant response +func TestMessageStitchingOrder(t *testing.T) { + // Simulate a complete tool loop conversation + messages := []api.Message{ + // Initial user message + {Role: "user", Content: "What's the weather in Paris and London?"}, + // Assistant's first response with tool calls + { + Role: "assistant", + Content: "", + Thinking: "I need to check the weather for both cities.", + ToolCalls: []api.ToolCall{ + {ID: "call_1", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}}, + {ID: "call_2", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}}, + }, + }, + // Tool results (in order matching tool calls) + {Role: "tool", Content: "Sunny, 22°C", ToolName: "get_weather", ToolCallID: "call_1"}, + {Role: "tool", Content: "Rainy, 15°C", ToolName: "get_weather", ToolCallID: "call_2"}, + // Final assistant response + {Role: "assistant", Content: "Paris is sunny at 22°C, and London is rainy at 15°C.", Thinking: "Got the data, now summarizing."}, + } + + // Verify structure + expectedRoles := []string{"user", "assistant", "tool", "tool", "assistant"} + for i, msg := range messages { + if msg.Role != expectedRoles[i] { + t.Errorf("message %d: expected role %q, got %q", i, expectedRoles[i], msg.Role) + } + } + + // Verify tool results match tool calls in order + assistantWithTools := messages[1] + toolResults := []api.Message{messages[2], messages[3]} + + if len(toolResults) != len(assistantWithTools.ToolCalls) { + t.Errorf("expected %d tool results for %d tool calls", len(assistantWithTools.ToolCalls), len(toolResults)) + } + + for i, result := range toolResults { + expectedToolCallID := assistantWithTools.ToolCalls[i].ID + if result.ToolCallID != expectedToolCallID { + t.Errorf("tool result %d: expected ToolCallID %q, got %q", i, expectedToolCallID, result.ToolCallID) + } + expectedToolName := assistantWithTools.ToolCalls[i].Function.Name + if result.ToolName != expectedToolName { + t.Errorf("tool result %d: expected ToolName %q, got %q", i, expectedToolName, result.ToolName) + } + } + + // Verify thinking is present in assistant messages + if messages[1].Thinking == "" { + t.Error("first assistant message should have thinking content") + } + if messages[4].Thinking == "" { + t.Error("final assistant message should have thinking content") + } +} + +// TestMultiTurnToolLoop verifies message stitching across multiple +// tool call iterations. +func TestMultiTurnToolLoop(t *testing.T) { + messages := []api.Message{ + {Role: "user", Content: "What's 2+2 and also what's the weather in Paris?"}, + // First tool call: calculate + { + Role: "assistant", + Thinking: "I'll start with the calculation.", + ToolCalls: []api.ToolCall{ + {ID: "calc_1", Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expr": "2+2"}}}, + }, + }, + {Role: "tool", Content: "4", ToolName: "calculate", ToolCallID: "calc_1"}, + // Second tool call: weather + { + Role: "assistant", + Thinking: "Got the calculation. Now checking weather.", + ToolCalls: []api.ToolCall{ + {ID: "weather_1", Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}}, + }, + }, + {Role: "tool", Content: "Sunny, 20°C", ToolName: "get_weather", ToolCallID: "weather_1"}, + // Final response + {Role: "assistant", Content: "2+2 equals 4, and Paris is sunny at 20°C."}, + } + + // Count message types + roleCounts := map[string]int{} + for _, msg := range messages { + roleCounts[msg.Role]++ + } + + if roleCounts["user"] != 1 { + t.Errorf("expected 1 user message, got %d", roleCounts["user"]) + } + if roleCounts["assistant"] != 3 { + t.Errorf("expected 3 assistant messages, got %d", roleCounts["assistant"]) + } + if roleCounts["tool"] != 2 { + t.Errorf("expected 2 tool messages, got %d", roleCounts["tool"]) + } + + // Verify each tool message follows an assistant with matching tool call + for i, msg := range messages { + if msg.Role == "tool" { + // Find preceding assistant message with tool calls + var precedingAssistant *api.Message + for j := i - 1; j >= 0; j-- { + if messages[j].Role == "assistant" && len(messages[j].ToolCalls) > 0 { + precedingAssistant = &messages[j] + break + } + } + + if precedingAssistant == nil { + t.Errorf("tool message at index %d has no preceding assistant with tool calls", i) + continue + } + + // Verify tool result matches one of the tool calls + found := false + for _, tc := range precedingAssistant.ToolCalls { + if tc.ID == msg.ToolCallID { + found = true + break + } + } + if !found { + t.Errorf("tool message at index %d has ToolCallID %q not found in preceding tool calls", i, msg.ToolCallID) + } + } + } +} + +// TestSkillCatalogRunToolCallPreservesFields tests that skill catalog +// returns tool messages with correct fields. +func TestSkillCatalogToolMessageFields(t *testing.T) { + // Create a minimal test for toolMessage function + call := api.ToolCall{ + ID: "test_id_123", + Function: api.ToolCallFunction{ + Name: "run_skill_script", + Arguments: api.ToolCallFunctionArguments{ + "skill": "test-skill", + "command": "echo hello", + }, + }, + } + + msg := toolMessage(call, "hello") + + if msg.Role != "tool" { + t.Errorf("expected role 'tool', got %q", msg.Role) + } + if msg.Content != "hello" { + t.Errorf("expected content 'hello', got %q", msg.Content) + } + if msg.ToolName != "run_skill_script" { + t.Errorf("expected ToolName 'run_skill_script', got %q", msg.ToolName) + } + if msg.ToolCallID != "test_id_123" { + t.Errorf("expected ToolCallID 'test_id_123', got %q", msg.ToolCallID) + } +} diff --git a/cmd/cmd.go b/cmd/cmd.go index 68fee0323..8565a1d2b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1665,10 +1665,11 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { // Execute tool calls and continue the conversation fmt.Fprintf(os.Stderr, "\n") - // Add assistant's tool call message to history + // Add assistant's tool call message to history (include thinking for proper rendering) assistantMsg := api.Message{ Role: "assistant", Content: fullResponse.String(), + Thinking: thinkingContent.String(), ToolCalls: pendingToolCalls, } messages = append(messages, assistantMsg) @@ -1727,11 +1728,8 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { fmt.Fprintf(os.Stderr, "Output:\n%s\n", result.Content) } - // Add tool result to messages - toolResults = append(toolResults, api.Message{ - Role: "tool", - Content: result.Content, - }) + // Add tool result to messages (preserves ToolName, ToolCallID from result) + toolResults = append(toolResults, result) } // Add tool results to message history diff --git a/server/create.go b/server/create.go index f1ad2eeda..28a73df69 100644 --- a/server/create.go +++ b/server/create.go @@ -125,7 +125,10 @@ func (s *Server) CreateHandler(c *gin.Context) { ch <- gin.H{"error": err.Error()} } - if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") { + // Inherit config from base model (Renderer, Parser, Requires, Capabilities, etc.) + // This is especially important for cloud models which don't have GGUF files + // to detect capabilities from. + if err == nil && !remote { manifest, mErr := ParseNamedManifest(fromName) if mErr == nil && manifest.Config.Digest != "" { configPath, pErr := GetBlobsPath(manifest.Config.Digest) @@ -142,6 +145,29 @@ func (s *Server) CreateHandler(c *gin.Context) { if config.Requires == "" { config.Requires = baseConfig.Requires } + // Inherit capabilities for cloud/remote models + // (local models detect capabilities from GGUF file) + if len(config.Capabilities) == 0 && len(baseConfig.Capabilities) > 0 { + config.Capabilities = baseConfig.Capabilities + } + // Inherit remote host/model if base is a cloud model + if config.RemoteHost == "" && baseConfig.RemoteHost != "" { + config.RemoteHost = baseConfig.RemoteHost + } + if config.RemoteModel == "" && baseConfig.RemoteModel != "" { + config.RemoteModel = baseConfig.RemoteModel + } + // Inherit model family for proper rendering + if config.ModelFamily == "" && baseConfig.ModelFamily != "" { + config.ModelFamily = baseConfig.ModelFamily + } + if len(config.ModelFamilies) == 0 && len(baseConfig.ModelFamilies) > 0 { + config.ModelFamilies = baseConfig.ModelFamilies + } + // Inherit context length for cloud models + if config.ContextLen == 0 && baseConfig.ContextLen > 0 { + config.ContextLen = baseConfig.ContextLen + } } cfgFile.Close() }