diff --git a/openai/responses.go b/openai/responses.go index f6f202b9a..4420eec53 100644 --- a/openai/responses.go +++ b/openai/responses.go @@ -4,13 +4,15 @@ import ( "encoding/json" "fmt" "math/rand" + "strings" "time" "github.com/ollama/ollama/api" ) // ResponsesContent is a discriminated union for input content types. -// Concrete types: ResponsesTextContent, ResponsesImageContent +// Concrete types: ResponsesTextContent, ResponsesImageContent, +// ResponsesOutputTextContent, ResponsesFileContent. type ResponsesContent interface { responsesContent() // unexported marker method } @@ -41,6 +43,16 @@ type ResponsesOutputTextContent struct { func (ResponsesOutputTextContent) responsesContent() {} +type ResponsesFileContent struct { + Type string `json:"type"` // always "input_file" + FileData string `json:"file_data,omitempty"` + FileID string `json:"file_id,omitempty"` + FileURL string `json:"file_url,omitempty"` + Filename string `json:"filename,omitempty"` +} + +func (ResponsesFileContent) responsesContent() {} + type ResponsesInputMessage struct { Type string `json:"type"` // always "message" Role string `json:"role"` // one of `user`, `system`, `developer` @@ -82,41 +94,55 @@ func (m *ResponsesInputMessage) UnmarshalJSON(data []byte) error { m.Content = make([]ResponsesContent, 0, len(rawItems)) for i, raw := range rawItems { - // Peek at the type field to determine which concrete type to use - var typeField struct { - Type string `json:"type"` - } - if err := json.Unmarshal(raw, &typeField); err != nil { + content, err := unmarshalResponsesContent(raw) + if err != nil { return fmt.Errorf("content[%d]: %w", i, err) } - - switch typeField.Type { - case "input_text": - var content ResponsesTextContent - if err := json.Unmarshal(raw, &content); err != nil { - return fmt.Errorf("content[%d]: %w", i, err) - } - m.Content = append(m.Content, content) - case "input_image": - var content ResponsesImageContent - if err := json.Unmarshal(raw, &content); err != nil { - return fmt.Errorf("content[%d]: %w", i, err) - } - m.Content = append(m.Content, content) - case "output_text": - var content ResponsesOutputTextContent - if err := json.Unmarshal(raw, &content); err != nil { - return fmt.Errorf("content[%d]: %w", i, err) - } - m.Content = append(m.Content, content) - default: - return fmt.Errorf("content[%d]: unknown content type: %s", i, typeField.Type) - } + m.Content = append(m.Content, content) } return nil } +func unmarshalResponsesContent(data []byte) (ResponsesContent, error) { + // Peek at the type field to determine which concrete type to use + var typeField struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &typeField); err != nil { + return nil, err + } + + switch typeField.Type { + case "input_text": + var content ResponsesTextContent + if err := json.Unmarshal(data, &content); err != nil { + return nil, err + } + return content, nil + case "input_image": + var content ResponsesImageContent + if err := json.Unmarshal(data, &content); err != nil { + return nil, err + } + return content, nil + case "output_text": + var content ResponsesOutputTextContent + if err := json.Unmarshal(data, &content); err != nil { + return nil, err + } + return content, nil + case "input_file": + var content ResponsesFileContent + if err := json.Unmarshal(data, &content); err != nil { + return nil, err + } + return content, nil + default: + return nil, fmt.Errorf("unknown content type: %s", typeField.Type) + } +} + type ResponsesOutputMessage struct{} // ResponsesInputItem is a discriminated union for input items. @@ -143,6 +169,60 @@ type ResponsesFunctionCallOutput struct { Type string `json:"type"` // always "function_call_output" CallID string `json:"call_id"` // links to the original function call Output string `json:"output"` // the function result + + // OutputItems is populated when output is provided as Responses content + // items instead of the string shorthand. + OutputItems []ResponsesContent `json:"-"` +} + +func (o *ResponsesFunctionCallOutput) UnmarshalJSON(data []byte) error { + var aux struct { + Type string `json:"type"` + CallID string `json:"call_id"` + Output json.RawMessage `json:"output"` + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + o.Type = aux.Type + o.CallID = aux.CallID + o.Output = "" + o.OutputItems = nil + + if len(aux.Output) == 0 { + return nil + } + + var output string + if err := json.Unmarshal(aux.Output, &output); err == nil { + o.Output = output + return nil + } + + var rawItems []json.RawMessage + if err := json.Unmarshal(aux.Output, &rawItems); err != nil { + return fmt.Errorf("output must be a string or array: %w", err) + } + + o.OutputItems = make([]ResponsesContent, 0, len(rawItems)) + var outputText strings.Builder + for i, raw := range rawItems { + content, err := unmarshalResponsesContent(raw) + if err != nil { + return fmt.Errorf("output[%d]: %w", i, err) + } + o.OutputItems = append(o.OutputItems, content) + + switch v := content.(type) { + case ResponsesTextContent: + outputText.WriteString(v.Text) + case ResponsesOutputTextContent: + outputText.WriteString(v.Text) + } + } + o.Output = outputText.String() + return nil } func (ResponsesFunctionCallOutput) responsesInputItem() {} @@ -394,9 +474,19 @@ func FromResponsesRequest(r ResponsesRequest) (*api.ChatRequest, error) { messages = append(messages, msg) } case ResponsesFunctionCallOutput: + content := v.Output + var images []api.ImageData + if len(v.OutputItems) > 0 { + var err error + content, images, err = convertResponsesContent(v.OutputItems) + if err != nil { + return nil, err + } + } messages = append(messages, api.Message{ Role: "tool", - Content: v.Output, + Content: content, + Images: images, ToolCallID: v.CallID, }) } @@ -492,10 +582,23 @@ func convertTool(t ResponsesTool) (api.Tool, error) { } func convertInputMessage(m ResponsesInputMessage) (api.Message, error) { + content, images, err := convertResponsesContent(m.Content) + if err != nil { + return api.Message{}, err + } + + return api.Message{ + Role: m.Role, + Content: content, + Images: images, + }, nil +} + +func convertResponsesContent(contents []ResponsesContent) (string, []api.ImageData, error) { var content string var images []api.ImageData - for _, c := range m.Content { + for _, c := range contents { switch v := c.(type) { case ResponsesTextContent: content += v.Text @@ -507,17 +610,17 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) { } img, err := decodeImageURL(v.ImageURL) if err != nil { - return api.Message{}, err + return "", nil, err } images = append(images, img) + case ResponsesFileContent: + // TODO(drifkin): support inlining text-only file_data when it is safe + // to decode and of a reasonable size + return "", nil, fmt.Errorf("file inputs are not currently supported") } } - return api.Message{ - Role: m.Role, - Content: content, - Images: images, - }, nil + return content, images, nil } // Response types for the Responses API diff --git a/openai/responses_test.go b/openai/responses_test.go index 743821b29..3d08f2aa5 100644 --- a/openai/responses_test.go +++ b/openai/responses_test.go @@ -226,6 +226,28 @@ func TestUnmarshalResponsesInputItem(t *testing.T) { } }) + t.Run("function_call_output item with content array", func(t *testing.T) { + got, err := unmarshalResponsesInputItem([]byte(`{"type": "function_call_output", "call_id": "call_abc123", "output": [{"type": "input_text", "text": "the result"}]}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + output, ok := got.(ResponsesFunctionCallOutput) + if !ok { + t.Fatalf("got type %T, want ResponsesFunctionCallOutput", got) + } + + if output.Type != "function_call_output" { + t.Errorf("Type = %q, want %q", output.Type, "function_call_output") + } + if output.CallID != "call_abc123" { + t.Errorf("CallID = %q, want %q", output.CallID, "call_abc123") + } + if output.Output != "the result" { + t.Errorf("Output = %q, want %q", output.Output, "the result") + } + }) + t.Run("unknown item type", func(t *testing.T) { _, err := unmarshalResponsesInputItem([]byte(`{"type": "unknown_type"}`)) if err == nil { @@ -456,6 +478,90 @@ func TestFromResponsesRequest_FunctionCallOutput(t *testing.T) { } } +func TestFromResponsesRequest_FunctionCallOutputContentArray(t *testing.T) { + reqJSON := `{ + "model": "gpt-oss:20b", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what is the weather?"}]}, + {"type": "function_call", "call_id": "call_abc123", "name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}, + {"type": "function_call_output", "call_id": "call_abc123", "output": [{"type": "input_text", "text": "sunny"}, {"type": "input_text", "text": ", 72F"}]} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + if len(chatReq.Messages) != 3 { + t.Fatalf("expected 3 messages, got %d", len(chatReq.Messages)) + } + + toolMsg := chatReq.Messages[2] + if toolMsg.Role != "tool" { + t.Errorf("expected role 'tool', got %q", toolMsg.Role) + } + if toolMsg.Content != "sunny, 72F" { + t.Errorf("expected content 'sunny, 72F', got %q", toolMsg.Content) + } + if toolMsg.ToolCallID != "call_abc123" { + t.Errorf("expected ToolCallID 'call_abc123', got %q", toolMsg.ToolCallID) + } +} + +func TestFromResponsesRequest_FunctionCallOutputContentArrayWithImage(t *testing.T) { + // 1x1 red PNG pixel + pngBase64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + + reqJSON := `{ + "model": "gpt-oss:20b", + "input": [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "inspect the image"}]}, + {"type": "function_call", "call_id": "call_abc123", "name": "inspect_image", "arguments": "{}"}, + {"type": "function_call_output", "call_id": "call_abc123", "output": [ + {"type": "input_text", "text": "attached image"}, + {"type": "input_image", "detail": "auto", "image_url": "data:image/png;base64,` + pngBase64 + `"} + ]} + ] + }` + + var req ResponsesRequest + if err := json.Unmarshal([]byte(reqJSON), &req); err != nil { + t.Fatalf("failed to unmarshal request: %v", err) + } + + chatReq, err := FromResponsesRequest(req) + if err != nil { + t.Fatalf("failed to convert request: %v", err) + } + + if len(chatReq.Messages) != 3 { + t.Fatalf("expected 3 messages, got %d", len(chatReq.Messages)) + } + + toolMsg := chatReq.Messages[2] + if toolMsg.Role != "tool" { + t.Errorf("expected role 'tool', got %q", toolMsg.Role) + } + if toolMsg.Content != "attached image" { + t.Errorf("expected content 'attached image', got %q", toolMsg.Content) + } + if len(toolMsg.Images) != 1 { + t.Fatalf("expected 1 image, got %d", len(toolMsg.Images)) + } + if len(toolMsg.Images[0]) == 0 { + t.Error("expected non-empty image data") + } + if toolMsg.ToolCallID != "call_abc123" { + t.Errorf("expected ToolCallID 'call_abc123', got %q", toolMsg.ToolCallID) + } +} + func TestFromResponsesRequest_FunctionCallMerge(t *testing.T) { t.Run("function call merges with preceding assistant message", func(t *testing.T) { // When assistant message has content followed by function_call,