anthropic: fix KV cache reuse degraded by tool call argument reordering

Use typed structs for tool call arguments instead of map[string]any to
preserve JSON key order, which Go maps do not guarantee.
This commit is contained in:
Jesse Gross
2026-03-09 16:24:57 -07:00
parent e7ccc129ea
commit ac83ac20c4
5 changed files with 275 additions and 308 deletions

View File

@@ -68,7 +68,7 @@ type MessagesRequest struct {
Model string `json:"model"` Model string `json:"model"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"` Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock System any `json:"system,omitempty"` // string or []map[string]any (JSON-decoded ContentBlock)
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP *float64 `json:"top_p,omitempty"`
@@ -82,8 +82,27 @@ type MessagesRequest struct {
// MessageParam represents a message in the request // MessageParam represents a message in the request
type MessageParam struct { type MessageParam struct {
Role string `json:"role"` // "user" or "assistant" Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock Content []ContentBlock `json:"content"` // always []ContentBlock; plain strings are normalized on unmarshal
}
func (m *MessageParam) UnmarshalJSON(data []byte) error {
var raw struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
m.Role = raw.Role
var s string
if err := json.Unmarshal(raw.Content, &s); err == nil {
m.Content = []ContentBlock{{Type: "text", Text: &s}}
return nil
}
return json.Unmarshal(raw.Content, &m.Content)
} }
// ContentBlock represents a content block in a message. // ContentBlock represents a content block in a message.
@@ -102,9 +121,9 @@ type ContentBlock struct {
Source *ImageSource `json:"source,omitempty"` Source *ImageSource `json:"source,omitempty"`
// For tool_use and server_tool_use blocks // For tool_use and server_tool_use blocks
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"` Input api.ToolCallFunctionArguments `json:"input,omitempty"`
// For tool_result and web_search_tool_result blocks // For tool_result and web_search_tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"` ToolUseID string `json:"tool_use_id,omitempty"`
@@ -377,178 +396,145 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message var messages []api.Message
role := strings.ToLower(msg.Role) role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) { var textContent strings.Builder
case string: var images []api.ImageData
messages = append(messages, api.Message{Role: role, Content: content}) var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
textBlocks := 0
imageBlocks := 0
toolUseBlocks := 0
toolResultBlocks := 0
serverToolUseBlocks := 0
webSearchToolResultBlocks := 0
thinkingBlocks := 0
unknownBlocks := 0
case []any: for _, block := range msg.Content {
var textContent strings.Builder switch block.Type {
var images []api.ImageData case "text":
var toolCalls []api.ToolCall textBlocks++
var thinking string if block.Text != nil {
var toolResults []api.Message textContent.WriteString(*block.Text)
textBlocks := 0
imageBlocks := 0
toolUseBlocks := 0
toolResultBlocks := 0
serverToolUseBlocks := 0
webSearchToolResultBlocks := 0
thinkingBlocks := 0
unknownBlocks := 0
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
logutil.Trace("anthropic: invalid content block format", "role", role)
return nil, errors.New("invalid content block format")
} }
blockType, _ := blockMap["type"].(string) case "image":
imageBlocks++
if block.Source == nil {
logutil.Trace("anthropic: invalid image source", "role", role)
return nil, errors.New("invalid image source")
}
switch blockType { if block.Source.Type == "base64" {
case "text": decoded, err := base64.StdEncoding.DecodeString(block.Source.Data)
textBlocks++ if err != nil {
if text, ok := blockMap["text"].(string); ok { logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
textContent.WriteString(text) return nil, fmt.Errorf("invalid base64 image data: %w", err)
} }
images = append(images, decoded)
} else {
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", block.Source.Type)
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", block.Source.Type)
}
case "image": case "tool_use":
imageBlocks++ toolUseBlocks++
source, ok := blockMap["source"].(map[string]any) if block.ID == "" {
if !ok { logutil.Trace("anthropic: tool_use block missing id", "role", role)
logutil.Trace("anthropic: invalid image source", "role", role) return nil, errors.New("tool_use block missing required 'id' field")
return nil, errors.New("invalid image source") }
} if block.Name == "" {
logutil.Trace("anthropic: tool_use block missing name", "role", role)
return nil, errors.New("tool_use block missing required 'name' field")
}
toolCalls = append(toolCalls, api.ToolCall{
ID: block.ID,
Function: api.ToolCallFunction{
Name: block.Name,
Arguments: block.Input,
},
})
sourceType, _ := source["type"].(string) case "tool_result":
if sourceType == "base64" { toolResultBlocks++
data, _ := source["data"].(string) var resultContent string
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use": switch c := block.Content.(type) {
toolUseBlocks++ case string:
id, ok := blockMap["id"].(string) resultContent = c
if !ok { case []any:
logutil.Trace("anthropic: tool_use block missing id", "role", role) for _, cb := range c {
return nil, errors.New("tool_use block missing required 'id' field") if cbMap, ok := cb.(map[string]any); ok {
} if cbMap["type"] == "text" {
name, ok := blockMap["name"].(string) if text, ok := cbMap["text"].(string); ok {
if !ok { resultContent += text
logutil.Trace("anthropic: tool_use block missing name", "role", role)
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolResultBlocks++
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
} }
} }
} }
} }
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
thinkingBlocks++
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
case "server_tool_use":
serverToolUseBlocks++
id, _ := blockMap["id"].(string)
name, _ := blockMap["name"].(string)
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "web_search_tool_result":
webSearchToolResultBlocks++
toolUseID, _ := blockMap["tool_use_id"].(string)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: formatWebSearchToolResultContent(blockMap["content"]),
ToolCallID: toolUseID,
})
default:
unknownBlocks++
} }
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" { toolResults = append(toolResults, api.Message{
m := api.Message{ Role: "tool",
Role: role, Content: resultContent,
Content: textContent.String(), ToolCallID: block.ToolUseID,
Images: images, })
ToolCalls: toolCalls,
Thinking: thinking, case "thinking":
thinkingBlocks++
if block.Thinking != nil {
thinking = *block.Thinking
} }
messages = append(messages, m)
case "server_tool_use":
serverToolUseBlocks++
toolCalls = append(toolCalls, api.ToolCall{
ID: block.ID,
Function: api.ToolCallFunction{
Name: block.Name,
Arguments: block.Input,
},
})
case "web_search_tool_result":
webSearchToolResultBlocks++
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: formatWebSearchToolResultContent(block.Content),
ToolCallID: block.ToolUseID,
})
default:
unknownBlocks++
} }
// Add tool results as separate messages
messages = append(messages, toolResults...)
logutil.Trace("anthropic: converted block message",
"role", role,
"blocks", len(content),
"text", textBlocks,
"image", imageBlocks,
"tool_use", toolUseBlocks,
"tool_result", toolResultBlocks,
"server_tool_use", serverToolUseBlocks,
"web_search_result", webSearchToolResultBlocks,
"thinking", thinkingBlocks,
"unknown", unknownBlocks,
"messages", TraceAPIMessages(messages),
)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
} }
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
logutil.Trace("anthropic: converted block message",
"role", role,
"blocks", len(msg.Content),
"text", textBlocks,
"image", imageBlocks,
"tool_use", toolUseBlocks,
"tool_result", toolResultBlocks,
"server_tool_use", serverToolUseBlocks,
"web_search_result", webSearchToolResultBlocks,
"thinking", thinkingBlocks,
"unknown", unknownBlocks,
"messages", TraceAPIMessages(messages),
)
return messages, nil return messages, nil
} }
@@ -892,7 +878,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
Type: "tool_use", Type: "tool_use",
ID: tc.ID, ID: tc.ID,
Name: tc.Function.Name, Name: tc.Function.Name,
Input: map[string]any{}, Input: api.ToolCallFunctionArguments{},
}, },
}, },
}) })
@@ -989,15 +975,6 @@ func ptr(s string) *string {
return &s return &s
} }
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
// CountTokensRequest represents an Anthropic count_tokens request // CountTokensRequest represents an Anthropic count_tokens request
type CountTokensRequest struct { type CountTokensRequest struct {
Model string `json:"model"` Model string `json:"model"`
@@ -1030,17 +1007,13 @@ func estimateTokens(req CountTokensRequest) int {
var totalLen int var totalLen int
// Count system prompt // Count system prompt
if req.System != nil { totalLen += countAnyContent(req.System)
totalLen += countAnyContent(req.System)
}
// Count messages
for _, msg := range req.Messages { for _, msg := range req.Messages {
// Count role (always present) // Count role (always present)
totalLen += len(msg.Role) totalLen += len(msg.Role)
// Count content // Count content
contentLen := countAnyContent(msg.Content) totalLen += countAnyContent(msg.Content)
totalLen += contentLen
} }
for _, tool := range req.Tools { for _, tool := range req.Tools {
@@ -1063,12 +1036,25 @@ func countAnyContent(content any) int {
switch c := content.(type) { switch c := content.(type) {
case string: case string:
return len(c) return len(c)
case []any: case []ContentBlock:
total := 0 total := 0
for _, block := range c { for _, block := range c {
total += countContentBlock(block) total += countContentBlock(block)
} }
return total return total
case []any:
total := 0
for _, item := range c {
data, err := json.Marshal(item)
if err != nil {
continue
}
var block ContentBlock
if err := json.Unmarshal(data, &block); err == nil {
total += countContentBlock(block)
}
}
return total
default: default:
if data, err := json.Marshal(content); err == nil { if data, err := json.Marshal(content); err == nil {
return len(data) return len(data)
@@ -1077,38 +1063,19 @@ func countAnyContent(content any) int {
} }
} }
func countContentBlock(block any) int { func countContentBlock(block ContentBlock) int {
blockMap, ok := block.(map[string]any)
if !ok {
if s, ok := block.(string); ok {
return len(s)
}
return 0
}
total := 0 total := 0
blockType, _ := blockMap["type"].(string) if block.Text != nil {
total += len(*block.Text)
if text, ok := blockMap["text"].(string); ok {
total += len(text)
} }
if block.Thinking != nil {
if thinking, ok := blockMap["thinking"].(string); ok { total += len(*block.Thinking)
total += len(thinking)
} }
if block.Type == "tool_use" || block.Type == "tool_result" {
if blockType == "tool_use" { if data, err := json.Marshal(block); err == nil {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data) total += len(data)
} }
} }
if blockType == "tool_result" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
return total return total
} }

View File

@@ -15,11 +15,16 @@ const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
) )
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests) // textContent is a convenience for constructing []ContentBlock with a single text block in tests.
func testArgs(m map[string]any) api.ToolCallFunctionArguments { func textContent(s string) []ContentBlock {
return []ContentBlock{{Type: "text", Text: &s}}
}
// makeArgs creates ToolCallFunctionArguments from key-value pairs (convenience function for tests)
func makeArgs(kvs ...any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments() args := api.NewToolCallFunctionArguments()
for k, v := range m { for i := 0; i < len(kvs)-1; i += 2 {
args.Set(k, v) args.Set(kvs[i].(string), kvs[i+1])
} }
return args return args
} }
@@ -29,7 +34,7 @@ func TestFromMessagesRequest_Basic(t *testing.T) {
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
}, },
} }
@@ -61,7 +66,7 @@ func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
MaxTokens: 1024, MaxTokens: 1024,
System: "You are a helpful assistant.", System: "You are a helpful assistant.",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
}, },
} }
@@ -88,7 +93,7 @@ func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
map[string]any{"type": "text", "text": " Be concise."}, map[string]any{"type": "text", "text": " Be concise."},
}, },
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
}, },
} }
@@ -113,7 +118,7 @@ func TestFromMessagesRequest_WithOptions(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 2048, MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Temperature: &temp, Temperature: &temp,
TopP: &topP, TopP: &topP,
TopK: &topK, TopK: &topK,
@@ -148,14 +153,14 @@ func TestFromMessagesRequest_WithImage(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "user", Role: "user",
Content: []any{ Content: []ContentBlock{
map[string]any{"type": "text", "text": "What's in this image?"}, {Type: "text", Text: ptr("What's in this image?")},
map[string]any{ {
"type": "image", Type: "image",
"source": map[string]any{ Source: &ImageSource{
"type": "base64", Type: "base64",
"media_type": "image/png", MediaType: "image/png",
"data": testImage, Data: testImage,
}, },
}, },
}, },
@@ -190,15 +195,15 @@ func TestFromMessagesRequest_WithToolUse(t *testing.T) {
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"}, {Role: "user", Content: textContent("What's the weather in Paris?")},
{ {
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "tool_use", Type: "tool_use",
"id": "call_123", ID: "call_123",
"name": "get_weather", Name: "get_weather",
"input": map[string]any{"location": "Paris"}, Input: makeArgs("location", "Paris"),
}, },
}, },
}, },
@@ -234,11 +239,11 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "user", Role: "user",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "tool_result", Type: "tool_result",
"tool_use_id": "call_123", ToolUseID: "call_123",
"content": "The weather in Paris is sunny, 22°C", Content: "The weather in Paris is sunny, 22°C",
}, },
}, },
}, },
@@ -270,7 +275,7 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Tools: []Tool{ Tools: []Tool{
{ {
Name: "get_weather", Name: "get_weather",
@@ -305,7 +310,7 @@ func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Tools: []Tool{ Tools: []Tool{
{ {
Type: "web_search_20250305", Type: "web_search_20250305",
@@ -346,7 +351,7 @@ func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T)
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Tools: []Tool{ Tools: []Tool{
{ {
Type: "custom", Type: "custom",
@@ -377,7 +382,7 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000}, Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
} }
@@ -399,13 +404,13 @@ func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
{ {
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "thinking", Type: "thinking",
"thinking": "Let me think about this...", Thinking: ptr("Let me think about this..."),
}, },
}, },
}, },
@@ -434,10 +439,10 @@ func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "tool_use", Type: "tool_use",
"name": "get_weather", Name: "get_weather",
}, },
}, },
}, },
@@ -460,10 +465,10 @@ func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "tool_use", Type: "tool_use",
"id": "call_123", ID: "call_123",
}, },
}, },
}, },
@@ -483,7 +488,7 @@ func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Tools: []Tool{ Tools: []Tool{
{ {
Name: "bad_tool", Name: "bad_tool",
@@ -548,7 +553,7 @@ func TestToMessagesResponse_WithToolCalls(t *testing.T) {
ID: "call_123", ID: "call_123",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}), Arguments: makeArgs("location", "Paris"),
}, },
}, },
}, },
@@ -760,7 +765,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
ID: "call_123", ID: "call_123",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}), Arguments: makeArgs("location", "Paris"),
}, },
}, },
}, },
@@ -843,7 +848,7 @@ func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) {
ID: "call_abc", ID: "call_abc",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "ask_user", Name: "ask_user",
Arguments: testArgs(map[string]any{"question": "cats or dogs?"}), Arguments: makeArgs("question", "cats or dogs?"),
}, },
}, },
}, },
@@ -965,7 +970,7 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
ID: "call_good", ID: "call_good",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "good_function", Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}), Arguments: makeArgs("location", "Paris"),
}, },
}, },
{ {
@@ -1140,7 +1145,7 @@ func TestEstimateTokens_SimpleMessage(t *testing.T) {
req := CountTokensRequest{ req := CountTokensRequest{
Model: "test-model", Model: "test-model",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello, world!"}, {Role: "user", Content: textContent("Hello, world!")},
}, },
} }
@@ -1161,7 +1166,7 @@ func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
Model: "test-model", Model: "test-model",
System: "You are a helpful assistant.", System: "You are a helpful assistant.",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
}, },
} }
@@ -1177,7 +1182,7 @@ func TestEstimateTokens_WithTools(t *testing.T) {
req := CountTokensRequest{ req := CountTokensRequest{
Model: "test-model", Model: "test-model",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "What's the weather?"}, {Role: "user", Content: textContent("What's the weather?")},
}, },
Tools: []Tool{ Tools: []Tool{
{ {
@@ -1200,17 +1205,17 @@ func TestEstimateTokens_WithThinking(t *testing.T) {
req := CountTokensRequest{ req := CountTokensRequest{
Model: "test-model", Model: "test-model",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
{ {
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "thinking", Type: "thinking",
"thinking": "Let me think about this carefully...", Thinking: ptr("Let me think about this carefully..."),
}, },
map[string]any{ {
"type": "text", Type: "text",
"text": "Here is my response.", Text: ptr("Here is my response."),
}, },
}, },
}, },
@@ -1308,12 +1313,12 @@ func TestConvertTool_RegularTool(t *testing.T) {
func TestConvertMessage_ServerToolUse(t *testing.T) { func TestConvertMessage_ServerToolUse(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "server_tool_use", Type: "server_tool_use",
"id": "srvtoolu_123", ID: "srvtoolu_123",
"name": "web_search", Name: "web_search",
"input": map[string]any{"query": "test query"}, Input: makeArgs("query", "test query"),
}, },
}, },
} }
@@ -1344,11 +1349,11 @@ func TestConvertMessage_ServerToolUse(t *testing.T) {
func TestConvertMessage_WebSearchToolResult(t *testing.T) { func TestConvertMessage_WebSearchToolResult(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "user", Role: "user",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "web_search_tool_result", Type: "web_search_tool_result",
"tool_use_id": "srvtoolu_123", ToolUseID: "srvtoolu_123",
"content": []any{ Content: []any{
map[string]any{ map[string]any{
"type": "web_search_result", "type": "web_search_result",
"title": "Test Result", "title": "Test Result",
@@ -1385,11 +1390,11 @@ func TestConvertMessage_WebSearchToolResult(t *testing.T) {
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) { func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "user", Role: "user",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "web_search_tool_result", Type: "web_search_tool_result",
"tool_use_id": "srvtoolu_empty", ToolUseID: "srvtoolu_empty",
"content": []any{}, Content: []any{},
}, },
}, },
} }
@@ -1416,11 +1421,11 @@ func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testi
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) { func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "user", Role: "user",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "web_search_tool_result", Type: "web_search_tool_result",
"tool_use_id": "srvtoolu_error", ToolUseID: "srvtoolu_error",
"content": map[string]any{ Content: map[string]any{
"type": "web_search_tool_result_error", "type": "web_search_tool_result_error",
"error_code": "max_uses_exceeded", "error_code": "max_uses_exceeded",
}, },

View File

@@ -283,7 +283,7 @@ func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initial
Type: "server_tool_use", Type: "server_tool_use",
ID: toolUseID, ID: toolUseID,
Name: "web_search", Name: "web_search",
Input: map[string]any{"query": query}, Input: queryArgs(query),
}, },
anthropic.ContentBlock{ anthropic.ContentBlock{
Type: "web_search_tool_result", Type: "web_search_tool_result",
@@ -348,7 +348,7 @@ func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initial
Type: "server_tool_use", Type: "server_tool_use",
ID: maxLoopToolUseID, ID: maxLoopToolUseID,
Name: "web_search", Name: "web_search",
Input: map[string]any{"query": maxLoopQuery}, Input: queryArgs(maxLoopQuery),
}, },
anthropic.ContentBlock{ anthropic.ContentBlock{
Type: "web_search_tool_result", Type: "web_search_tool_result",
@@ -786,7 +786,7 @@ func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query strin
Type: "server_tool_use", Type: "server_tool_use",
ID: toolUseID, ID: toolUseID,
Name: "web_search", Name: "web_search",
Input: map[string]any{"query": query}, Input: queryArgs(query),
}, },
{ {
Type: "web_search_tool_result", Type: "web_search_tool_result",
@@ -942,6 +942,13 @@ func writeSSE(w http.ResponseWriter, eventType string, data any) error {
return nil return nil
} }
// queryArgs creates a ToolCallFunctionArguments with a single "query" key.
func queryArgs(query string) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
args.Set("query", query)
return args
}
// serverToolUseID derives a server tool use ID from a message ID // serverToolUseID derives a server tool use ID from a message ID
func serverToolUseID(messageID string) string { func serverToolUseID(messageID string) string {
return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_") return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_")

View File

@@ -1208,7 +1208,7 @@ func TestWebSearchStreamResponse(t *testing.T) {
Type: "server_tool_use", Type: "server_tool_use",
ID: "srvtoolu_test123", ID: "srvtoolu_test123",
Name: "web_search", Name: "web_search",
Input: map[string]any{"query": "test query"}, Input: queryArgs("test query"),
}, },
{ {
Type: "web_search_tool_result", Type: "web_search_tool_result",
@@ -1413,12 +1413,8 @@ func TestWebSearchSendError_NonStreaming(t *testing.T) {
t.Errorf("expected name 'web_search', got %q", result.Content[0].Name) t.Errorf("expected name 'web_search', got %q", result.Content[0].Name)
} }
// Verify input contains the query // Verify input contains the query
inputMap, ok := result.Content[0].Input.(map[string]any) if q, ok := result.Content[0].Input.Get("query"); !ok || q != "test query" {
if !ok { t.Errorf("expected query 'test query', got %v", q)
t.Fatalf("expected Input to be map, got %T", result.Content[0].Input)
}
if inputMap["query"] != "test query" {
t.Errorf("expected query 'test query', got %v", inputMap["query"])
} }
// Block 1: web_search_tool_result with error // Block 1: web_search_tool_result with error
@@ -1561,12 +1557,8 @@ func TestWebSearchSendError_EmptyQuery(t *testing.T) {
} }
// Verify the input has empty query // Verify the input has empty query
inputMap, ok := result.Content[0].Input.(map[string]any) if q, ok := result.Content[0].Input.Get("query"); !ok || q != "" {
if !ok { t.Errorf("expected empty query, got %v", q)
t.Fatalf("expected Input to be map, got %T", result.Content[0].Input)
}
if inputMap["query"] != "" {
t.Errorf("expected empty query, got %v", inputMap["query"])
} }
} }

View File

@@ -328,8 +328,8 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) { func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
var parsed struct { var parsed struct {
Name string `json:"name"` Name string `json:"name"`
Arguments map[string]any `json:"arguments"` Arguments api.ToolCallFunctionArguments `json:"arguments"`
} }
if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil { if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
@@ -345,13 +345,9 @@ func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCa
toolCall := api.ToolCall{ toolCall := api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: parsed.Name, Name: parsed.Name,
Arguments: api.NewToolCallFunctionArguments(), Arguments: parsed.Arguments,
}, },
} }
for key, value := range parsed.Arguments {
toolCall.Function.Arguments.Set(key, value)
}
return toolCall, nil return toolCall, nil
} }