mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
anthropic: fix KV cache reuse degraded by tool call argument reordering
Use typed structs for tool call arguments instead of map[string]any to preserve JSON key order, which Go maps do not guarantee.
This commit is contained in:
@@ -68,7 +68,7 @@ type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"` // string or []ContentBlock
|
||||
System any `json:"system,omitempty"` // string or []map[string]any (JSON-decoded ContentBlock)
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
@@ -82,8 +82,27 @@ type MessagesRequest struct {
|
||||
|
||||
// MessageParam represents a message in the request
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"` // "user" or "assistant"
|
||||
Content any `json:"content"` // string or []ContentBlock
|
||||
Role string `json:"role"` // "user" or "assistant"
|
||||
Content []ContentBlock `json:"content"` // always []ContentBlock; plain strings are normalized on unmarshal
|
||||
}
|
||||
|
||||
func (m *MessageParam) UnmarshalJSON(data []byte) error {
|
||||
var raw struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
m.Role = raw.Role
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw.Content, &s); err == nil {
|
||||
m.Content = []ContentBlock{{Type: "text", Text: &s}}
|
||||
return nil
|
||||
}
|
||||
|
||||
return json.Unmarshal(raw.Content, &m.Content)
|
||||
}
|
||||
|
||||
// ContentBlock represents a content block in a message.
|
||||
@@ -102,9 +121,9 @@ type ContentBlock struct {
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use and server_tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input api.ToolCallFunctionArguments `json:"input,omitempty"`
|
||||
|
||||
// For tool_result and web_search_tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
@@ -377,178 +396,145 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
role := strings.ToLower(msg.Role)
|
||||
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
messages = append(messages, api.Message{Role: role, Content: content})
|
||||
var textContent strings.Builder
|
||||
var images []api.ImageData
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
textBlocks := 0
|
||||
imageBlocks := 0
|
||||
toolUseBlocks := 0
|
||||
toolResultBlocks := 0
|
||||
serverToolUseBlocks := 0
|
||||
webSearchToolResultBlocks := 0
|
||||
thinkingBlocks := 0
|
||||
unknownBlocks := 0
|
||||
|
||||
case []any:
|
||||
var textContent strings.Builder
|
||||
var images []api.ImageData
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
textBlocks := 0
|
||||
imageBlocks := 0
|
||||
toolUseBlocks := 0
|
||||
toolResultBlocks := 0
|
||||
serverToolUseBlocks := 0
|
||||
webSearchToolResultBlocks := 0
|
||||
thinkingBlocks := 0
|
||||
unknownBlocks := 0
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid content block format", "role", role)
|
||||
return nil, errors.New("invalid content block format")
|
||||
for _, block := range msg.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
textBlocks++
|
||||
if block.Text != nil {
|
||||
textContent.WriteString(*block.Text)
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
case "image":
|
||||
imageBlocks++
|
||||
if block.Source == nil {
|
||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
textBlocks++
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
if block.Source.Type == "base64" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(block.Source.Data)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", block.Source.Type)
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", block.Source.Type)
|
||||
}
|
||||
|
||||
case "image":
|
||||
imageBlocks++
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
case "tool_use":
|
||||
toolUseBlocks++
|
||||
if block.ID == "" {
|
||||
logutil.Trace("anthropic: tool_use block missing id", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
if block.Name == "" {
|
||||
logutil.Trace("anthropic: tool_use block missing name", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: block.ID,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: block.Name,
|
||||
Arguments: block.Input,
|
||||
},
|
||||
})
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
case "tool_result":
|
||||
toolResultBlocks++
|
||||
var resultContent string
|
||||
|
||||
case "tool_use":
|
||||
toolUseBlocks++
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: tool_use block missing id", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: tool_use block missing name", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
switch c := blockMap["content"].(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
switch c := block.Content.(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
thinkingBlocks++
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
|
||||
case "server_tool_use":
|
||||
serverToolUseBlocks++
|
||||
id, _ := blockMap["id"].(string)
|
||||
name, _ := blockMap["name"].(string)
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "web_search_tool_result":
|
||||
webSearchToolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: formatWebSearchToolResultContent(blockMap["content"]),
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
default:
|
||||
unknownBlocks++
|
||||
}
|
||||
}
|
||||
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
Content: textContent.String(),
|
||||
Images: images,
|
||||
ToolCalls: toolCalls,
|
||||
Thinking: thinking,
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
ToolCallID: block.ToolUseID,
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
thinkingBlocks++
|
||||
if block.Thinking != nil {
|
||||
thinking = *block.Thinking
|
||||
}
|
||||
messages = append(messages, m)
|
||||
|
||||
case "server_tool_use":
|
||||
serverToolUseBlocks++
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: block.ID,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: block.Name,
|
||||
Arguments: block.Input,
|
||||
},
|
||||
})
|
||||
|
||||
case "web_search_tool_result":
|
||||
webSearchToolResultBlocks++
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: formatWebSearchToolResultContent(block.Content),
|
||||
ToolCallID: block.ToolUseID,
|
||||
})
|
||||
default:
|
||||
unknownBlocks++
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
logutil.Trace("anthropic: converted block message",
|
||||
"role", role,
|
||||
"blocks", len(content),
|
||||
"text", textBlocks,
|
||||
"image", imageBlocks,
|
||||
"tool_use", toolUseBlocks,
|
||||
"tool_result", toolResultBlocks,
|
||||
"server_tool_use", serverToolUseBlocks,
|
||||
"web_search_result", webSearchToolResultBlocks,
|
||||
"thinking", thinkingBlocks,
|
||||
"unknown", unknownBlocks,
|
||||
"messages", TraceAPIMessages(messages),
|
||||
)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
}
|
||||
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
Content: textContent.String(),
|
||||
Images: images,
|
||||
ToolCalls: toolCalls,
|
||||
Thinking: thinking,
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
logutil.Trace("anthropic: converted block message",
|
||||
"role", role,
|
||||
"blocks", len(msg.Content),
|
||||
"text", textBlocks,
|
||||
"image", imageBlocks,
|
||||
"tool_use", toolUseBlocks,
|
||||
"tool_result", toolResultBlocks,
|
||||
"server_tool_use", serverToolUseBlocks,
|
||||
"web_search_result", webSearchToolResultBlocks,
|
||||
"thinking", thinkingBlocks,
|
||||
"unknown", unknownBlocks,
|
||||
"messages", TraceAPIMessages(messages),
|
||||
)
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
@@ -892,7 +878,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: map[string]any{},
|
||||
Input: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -989,15 +975,6 @@ func ptr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// mapToArgs converts a map to ToolCallFunctionArguments
|
||||
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// CountTokensRequest represents an Anthropic count_tokens request
|
||||
type CountTokensRequest struct {
|
||||
Model string `json:"model"`
|
||||
@@ -1030,17 +1007,13 @@ func estimateTokens(req CountTokensRequest) int {
|
||||
var totalLen int
|
||||
|
||||
// Count system prompt
|
||||
if req.System != nil {
|
||||
totalLen += countAnyContent(req.System)
|
||||
}
|
||||
totalLen += countAnyContent(req.System)
|
||||
|
||||
// Count messages
|
||||
for _, msg := range req.Messages {
|
||||
// Count role (always present)
|
||||
totalLen += len(msg.Role)
|
||||
// Count content
|
||||
contentLen := countAnyContent(msg.Content)
|
||||
totalLen += contentLen
|
||||
totalLen += countAnyContent(msg.Content)
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
@@ -1063,12 +1036,25 @@ func countAnyContent(content any) int {
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return len(c)
|
||||
case []any:
|
||||
case []ContentBlock:
|
||||
total := 0
|
||||
for _, block := range c {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
return total
|
||||
case []any:
|
||||
total := 0
|
||||
for _, item := range c {
|
||||
data, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var block ContentBlock
|
||||
if err := json.Unmarshal(data, &block); err == nil {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
}
|
||||
return total
|
||||
default:
|
||||
if data, err := json.Marshal(content); err == nil {
|
||||
return len(data)
|
||||
@@ -1077,38 +1063,19 @@ func countAnyContent(content any) int {
|
||||
}
|
||||
}
|
||||
|
||||
func countContentBlock(block any) int {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
if s, ok := block.(string); ok {
|
||||
return len(s)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func countContentBlock(block ContentBlock) int {
|
||||
total := 0
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
total += len(text)
|
||||
if block.Text != nil {
|
||||
total += len(*block.Text)
|
||||
}
|
||||
|
||||
if thinking, ok := blockMap["thinking"].(string); ok {
|
||||
total += len(thinking)
|
||||
if block.Thinking != nil {
|
||||
total += len(*block.Thinking)
|
||||
}
|
||||
|
||||
if blockType == "tool_use" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
if block.Type == "tool_use" || block.Type == "tool_result" {
|
||||
if data, err := json.Marshal(block); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if blockType == "tool_result" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
|
||||
@@ -15,11 +15,16 @@ const (
|
||||
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
// textContent is a convenience for constructing []ContentBlock with a single text block in tests.
|
||||
func textContent(s string) []ContentBlock {
|
||||
return []ContentBlock{{Type: "text", Text: &s}}
|
||||
}
|
||||
|
||||
// makeArgs creates ToolCallFunctionArguments from key-value pairs (convenience function for tests)
|
||||
func makeArgs(kvs ...any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
for i := 0; i < len(kvs)-1; i += 2 {
|
||||
args.Set(kvs[i].(string), kvs[i+1])
|
||||
}
|
||||
return args
|
||||
}
|
||||
@@ -29,7 +34,7 @@ func TestFromMessagesRequest_Basic(t *testing.T) {
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -61,7 +66,7 @@ func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
|
||||
MaxTokens: 1024,
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -88,7 +93,7 @@ func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
|
||||
map[string]any{"type": "text", "text": " Be concise."},
|
||||
},
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -113,7 +118,7 @@ func TestFromMessagesRequest_WithOptions(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 2048,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
@@ -148,14 +153,14 @@ func TestFromMessagesRequest_WithImage(t *testing.T) {
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "What's in this image?"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
Content: []ContentBlock{
|
||||
{Type: "text", Text: ptr("What's in this image?")},
|
||||
{
|
||||
Type: "image",
|
||||
Source: &ImageSource{
|
||||
Type: "base64",
|
||||
MediaType: "image/png",
|
||||
Data: testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -190,15 +195,15 @@ func TestFromMessagesRequest_WithToolUse(t *testing.T) {
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{Role: "user", Content: textContent("What's the weather in Paris?")},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": map[string]any{"location": "Paris"},
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_use",
|
||||
ID: "call_123",
|
||||
Name: "get_weather",
|
||||
Input: makeArgs("location", "Paris"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -234,11 +239,11 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_123",
|
||||
"content": "The weather in Paris is sunny, 22°C",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_result",
|
||||
ToolUseID: "call_123",
|
||||
Content: "The weather in Paris is sunny, 22°C",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -270,7 +275,7 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
@@ -305,7 +310,7 @@ func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Type: "web_search_20250305",
|
||||
@@ -346,7 +351,7 @@ func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T)
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Type: "custom",
|
||||
@@ -377,7 +382,7 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
|
||||
}
|
||||
|
||||
@@ -399,13 +404,13 @@ func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this...",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "thinking",
|
||||
Thinking: ptr("Let me think about this..."),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -434,10 +439,10 @@ func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"name": "get_weather",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_use",
|
||||
Name: "get_weather",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -460,10 +465,10 @@ func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_use",
|
||||
ID: "call_123",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -483,7 +488,7 @@ func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "bad_tool",
|
||||
@@ -548,7 +553,7 @@ func TestToMessagesResponse_WithToolCalls(t *testing.T) {
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
Arguments: makeArgs("location", "Paris"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -760,7 +765,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
Arguments: makeArgs("location", "Paris"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -843,7 +848,7 @@ func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) {
|
||||
ID: "call_abc",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "ask_user",
|
||||
Arguments: testArgs(map[string]any{"question": "cats or dogs?"}),
|
||||
Arguments: makeArgs("question", "cats or dogs?"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -965,7 +970,7 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
ID: "call_good",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "good_function",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
Arguments: makeArgs("location", "Paris"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -1140,7 +1145,7 @@ func TestEstimateTokens_SimpleMessage(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello, world!"},
|
||||
{Role: "user", Content: textContent("Hello, world!")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1161,7 +1166,7 @@ func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
|
||||
Model: "test-model",
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1177,7 +1182,7 @@ func TestEstimateTokens_WithTools(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{Role: "user", Content: textContent("What's the weather?")},
|
||||
},
|
||||
Tools: []Tool{
|
||||
{
|
||||
@@ -1200,17 +1205,17 @@ func TestEstimateTokens_WithThinking(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this carefully...",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "thinking",
|
||||
Thinking: ptr("Let me think about this carefully..."),
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "Here is my response.",
|
||||
{
|
||||
Type: "text",
|
||||
Text: ptr("Here is my response."),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1308,12 +1313,12 @@ func TestConvertTool_RegularTool(t *testing.T) {
|
||||
func TestConvertMessage_ServerToolUse(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "server_tool_use",
|
||||
"id": "srvtoolu_123",
|
||||
"name": "web_search",
|
||||
"input": map[string]any{"query": "test query"},
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "server_tool_use",
|
||||
ID: "srvtoolu_123",
|
||||
Name: "web_search",
|
||||
Input: makeArgs("query", "test query"),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1344,11 +1349,11 @@ func TestConvertMessage_ServerToolUse(t *testing.T) {
|
||||
func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_123",
|
||||
"content": []any{
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: "srvtoolu_123",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_result",
|
||||
"title": "Test Result",
|
||||
@@ -1385,11 +1390,11 @@ func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
||||
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_empty",
|
||||
"content": []any{},
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: "srvtoolu_empty",
|
||||
Content: []any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1416,11 +1421,11 @@ func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testi
|
||||
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_error",
|
||||
"content": map[string]any{
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: "srvtoolu_error",
|
||||
Content: map[string]any{
|
||||
"type": "web_search_tool_result_error",
|
||||
"error_code": "max_uses_exceeded",
|
||||
},
|
||||
|
||||
@@ -283,7 +283,7 @@ func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initial
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": query},
|
||||
Input: queryArgs(query),
|
||||
},
|
||||
anthropic.ContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
@@ -348,7 +348,7 @@ func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initial
|
||||
Type: "server_tool_use",
|
||||
ID: maxLoopToolUseID,
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": maxLoopQuery},
|
||||
Input: queryArgs(maxLoopQuery),
|
||||
},
|
||||
anthropic.ContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
@@ -786,7 +786,7 @@ func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query strin
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": query},
|
||||
Input: queryArgs(query),
|
||||
},
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
@@ -942,6 +942,13 @@ func writeSSE(w http.ResponseWriter, eventType string, data any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// queryArgs creates a ToolCallFunctionArguments with a single "query" key.
|
||||
func queryArgs(query string) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("query", query)
|
||||
return args
|
||||
}
|
||||
|
||||
// serverToolUseID derives a server tool use ID from a message ID
|
||||
func serverToolUseID(messageID string) string {
|
||||
return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_")
|
||||
|
||||
@@ -1208,7 +1208,7 @@ func TestWebSearchStreamResponse(t *testing.T) {
|
||||
Type: "server_tool_use",
|
||||
ID: "srvtoolu_test123",
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": "test query"},
|
||||
Input: queryArgs("test query"),
|
||||
},
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
@@ -1413,12 +1413,8 @@ func TestWebSearchSendError_NonStreaming(t *testing.T) {
|
||||
t.Errorf("expected name 'web_search', got %q", result.Content[0].Name)
|
||||
}
|
||||
// Verify input contains the query
|
||||
inputMap, ok := result.Content[0].Input.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected Input to be map, got %T", result.Content[0].Input)
|
||||
}
|
||||
if inputMap["query"] != "test query" {
|
||||
t.Errorf("expected query 'test query', got %v", inputMap["query"])
|
||||
if q, ok := result.Content[0].Input.Get("query"); !ok || q != "test query" {
|
||||
t.Errorf("expected query 'test query', got %v", q)
|
||||
}
|
||||
|
||||
// Block 1: web_search_tool_result with error
|
||||
@@ -1561,12 +1557,8 @@ func TestWebSearchSendError_EmptyQuery(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify the input has empty query
|
||||
inputMap, ok := result.Content[0].Input.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected Input to be map, got %T", result.Content[0].Input)
|
||||
}
|
||||
if inputMap["query"] != "" {
|
||||
t.Errorf("expected empty query, got %v", inputMap["query"])
|
||||
if q, ok := result.Content[0].Input.Get("query"); !ok || q != "" {
|
||||
t.Errorf("expected empty query, got %v", q)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -328,8 +328,8 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
|
||||
func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
var parsed struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
Name string `json:"name"`
|
||||
Arguments api.ToolCallFunctionArguments `json:"arguments"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
|
||||
@@ -345,13 +345,9 @@ func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCa
|
||||
toolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: parsed.Name,
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
Arguments: parsed.Arguments,
|
||||
},
|
||||
}
|
||||
|
||||
for key, value := range parsed.Arguments {
|
||||
toolCall.Function.Arguments.Set(key, value)
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user