mirror of
https://github.com/ollama/ollama.git
synced 2026-04-21 08:15:42 +02:00
Compare commits
1 Commits
pdevine/ml
...
parth-anth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40f56cf543 |
@@ -372,6 +372,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
|||||||
return convertedRequest, nil
|
return convertedRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractBase64Image(blockMap map[string]any) (api.ImageData, error) {
|
||||||
|
source, ok := blockMap["source"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid image source")
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceType, _ := source["type"].(string)
|
||||||
|
if sourceType == "base64" {
|
||||||
|
data, _ := source["data"].(string)
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported", sourceType)
|
||||||
|
}
|
||||||
|
|
||||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||||
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
@@ -414,26 +432,12 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
|
|
||||||
case "image":
|
case "image":
|
||||||
imageBlocks++
|
imageBlocks++
|
||||||
source, ok := blockMap["source"].(map[string]any)
|
decoded, err := extractBase64Image(blockMap)
|
||||||
if !ok {
|
if err != nil {
|
||||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
logutil.Trace("anthropic: failed to extract image", "role", role, "error", err)
|
||||||
return nil, errors.New("invalid image source")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
images = append(images, decoded)
|
||||||
sourceType, _ := source["type"].(string)
|
|
||||||
if sourceType == "base64" {
|
|
||||||
data, _ := source["data"].(string)
|
|
||||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
|
||||||
if err != nil {
|
|
||||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
|
||||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
|
||||||
}
|
|
||||||
images = append(images, decoded)
|
|
||||||
} else {
|
|
||||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
|
|
||||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
|
||||||
}
|
|
||||||
// URL images would need to be fetched - skip for now
|
|
||||||
|
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
toolUseBlocks++
|
toolUseBlocks++
|
||||||
@@ -462,6 +466,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
toolResultBlocks++
|
toolResultBlocks++
|
||||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||||
var resultContent string
|
var resultContent string
|
||||||
|
var resultImages []api.ImageData
|
||||||
|
|
||||||
switch c := blockMap["content"].(type) {
|
switch c := blockMap["content"].(type) {
|
||||||
case string:
|
case string:
|
||||||
@@ -469,10 +474,18 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
case []any:
|
case []any:
|
||||||
for _, cb := range c {
|
for _, cb := range c {
|
||||||
if cbMap, ok := cb.(map[string]any); ok {
|
if cbMap, ok := cb.(map[string]any); ok {
|
||||||
if cbMap["type"] == "text" {
|
switch cbMap["type"] {
|
||||||
|
case "text":
|
||||||
if text, ok := cbMap["text"].(string); ok {
|
if text, ok := cbMap["text"].(string); ok {
|
||||||
resultContent += text
|
resultContent += text
|
||||||
}
|
}
|
||||||
|
case "image":
|
||||||
|
decoded, err := extractBase64Image(cbMap)
|
||||||
|
if err != nil {
|
||||||
|
logutil.Trace("anthropic: failed to extract image from tool_result", "role", role, "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resultImages = append(resultImages, decoded)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -481,6 +494,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
toolResults = append(toolResults, api.Message{
|
toolResults = append(toolResults, api.Message{
|
||||||
Role: "tool",
|
Role: "tool",
|
||||||
Content: resultContent,
|
Content: resultContent,
|
||||||
|
Images: resultImages,
|
||||||
ToolCallID: toolUseID,
|
ToolCallID: toolUseID,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -266,6 +266,124 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFromMessagesRequest_WithToolResultContainingImage(t *testing.T) {
|
||||||
|
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||||
|
|
||||||
|
req := MessagesRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "call_456",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "Here is the screenshot:"},
|
||||||
|
map[string]any{
|
||||||
|
"type": "image",
|
||||||
|
"source": map[string]any{
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/png",
|
||||||
|
"data": testImage,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromMessagesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Messages) != 1 {
|
||||||
|
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := result.Messages[0]
|
||||||
|
if msg.Role != "tool" {
|
||||||
|
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||||
|
}
|
||||||
|
if msg.ToolCallID != "call_456" {
|
||||||
|
t.Errorf("expected tool_call_id 'call_456', got %q", msg.ToolCallID)
|
||||||
|
}
|
||||||
|
if msg.Content != "Here is the screenshot:" {
|
||||||
|
t.Errorf("expected content 'Here is the screenshot:', got %q", msg.Content)
|
||||||
|
}
|
||||||
|
if len(msg.Images) != 1 {
|
||||||
|
t.Fatalf("expected 1 image in tool result, got %d", len(msg.Images))
|
||||||
|
}
|
||||||
|
if string(msg.Images[0]) != string(imgData) {
|
||||||
|
t.Error("image data mismatch in tool result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromMessagesRequest_WithToolResultContainingMultipleImages(t *testing.T) {
|
||||||
|
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||||
|
|
||||||
|
req := MessagesRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "call_789",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "image",
|
||||||
|
"source": map[string]any{
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/png",
|
||||||
|
"data": testImage,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
map[string]any{"type": "text", "text": "First image above, second below:"},
|
||||||
|
map[string]any{
|
||||||
|
"type": "image",
|
||||||
|
"source": map[string]any{
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/png",
|
||||||
|
"data": testImage,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromMessagesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Messages) != 1 {
|
||||||
|
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := result.Messages[0]
|
||||||
|
if msg.Role != "tool" {
|
||||||
|
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||||
|
}
|
||||||
|
if len(msg.Images) != 2 {
|
||||||
|
t.Fatalf("expected 2 images in tool result, got %d", len(msg.Images))
|
||||||
|
}
|
||||||
|
for i, img := range msg.Images {
|
||||||
|
if string(img) != string(imgData) {
|
||||||
|
t.Errorf("image %d data mismatch in tool result", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||||
req := MessagesRequest{
|
req := MessagesRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
|
|||||||
Reference in New Issue
Block a user