mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 05:54:09 +02:00
writeError in both OpenAI and Anthropic middleware writers would return a raw json.SyntaxError when the error payload wasn't valid JSON (e.g. "invalid character 'e' looking for beginning of value"). Fall back to using the raw bytes as the error message instead. Also use the actual HTTP status code rather than hardcoding 500, so error types map correctly
949 lines
27 KiB
Go
949 lines
27 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
"github.com/ollama/ollama/anthropic"
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/ollama/ollama/envconfig"
|
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
|
"github.com/ollama/ollama/internal/modelref"
|
|
"github.com/ollama/ollama/logutil"
|
|
)
|
|
|
|
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
|
|
type AnthropicWriter struct {
|
|
BaseWriter
|
|
stream bool
|
|
id string
|
|
converter *anthropic.StreamConverter
|
|
}
|
|
|
|
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
|
|
var errData struct {
|
|
Error string `json:"error"`
|
|
}
|
|
if err := json.Unmarshal(data, &errData); err != nil {
|
|
// If the error response isn't valid JSON, use the raw bytes as the
|
|
// error message rather than surfacing a confusing JSON parse error.
|
|
errData.Error = string(data)
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.Status(), errData.Error)); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
|
|
return writeSSE(w.ResponseWriter, eventType, data)
|
|
}
|
|
|
|
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
|
var chatResponse api.ChatResponse
|
|
err := json.Unmarshal(data, &chatResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if w.stream {
|
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
|
|
|
events := w.converter.Process(chatResponse)
|
|
logutil.Trace("anthropic middleware: stream chunk", "resp", anthropic.TraceChatResponse(chatResponse), "events", len(events))
|
|
for _, event := range events {
|
|
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
return len(data), nil
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
response := anthropic.ToMessagesResponse(w.id, chatResponse)
|
|
logutil.Trace("anthropic middleware: converted response", "resp", anthropic.TraceMessagesResponse(response))
|
|
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
|
}
|
|
|
|
func (w *AnthropicWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
// WebSearchAnthropicWriter intercepts responses containing web_search tool calls,
|
|
// executes the search, re-invokes the model with results, and assembles the
|
|
// Anthropic-format response (server_tool_use + web_search_tool_result + text).
|
|
type WebSearchAnthropicWriter struct {
|
|
BaseWriter
|
|
newLoopContext func() (context.Context, context.CancelFunc)
|
|
inner *AnthropicWriter
|
|
req anthropic.MessagesRequest // original Anthropic request
|
|
chatReq *api.ChatRequest // converted Ollama request (for followup calls)
|
|
stream bool
|
|
|
|
estimatedInputTokens int
|
|
|
|
terminalSent bool
|
|
|
|
observedPromptEvalCount int
|
|
observedEvalCount int
|
|
|
|
loopInFlight bool
|
|
loopBaseInputTok int
|
|
loopBaseOutputTok int
|
|
loopResultCh chan webSearchLoopResult
|
|
|
|
streamMessageStarted bool
|
|
streamHasOpenBlock bool
|
|
streamOpenBlockIndex int
|
|
streamNextIndex int
|
|
}
|
|
|
|
const maxWebSearchLoops = 3
|
|
|
|
type webSearchLoopResult struct {
|
|
response anthropic.MessagesResponse
|
|
loopErr *webSearchLoopError
|
|
}
|
|
|
|
type webSearchLoopError struct {
|
|
code string
|
|
query string
|
|
usage anthropic.Usage
|
|
err error
|
|
}
|
|
|
|
func (e *webSearchLoopError) Error() string {
|
|
if e.err == nil {
|
|
return e.code
|
|
}
|
|
return fmt.Sprintf("%s: %v", e.code, e.err)
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) Write(data []byte) (int, error) {
|
|
if w.terminalSent {
|
|
return len(data), nil
|
|
}
|
|
|
|
code := w.Status()
|
|
if code != http.StatusOK {
|
|
return w.inner.writeError(data)
|
|
}
|
|
|
|
var chatResponse api.ChatResponse
|
|
if err := json.Unmarshal(data, &chatResponse); err != nil {
|
|
return 0, err
|
|
}
|
|
w.recordObservedUsage(chatResponse.Metrics)
|
|
|
|
if w.stream && w.loopInFlight {
|
|
if !chatResponse.Done {
|
|
return len(data), nil
|
|
}
|
|
if err := w.writeLoopResult(); err != nil {
|
|
return len(data), err
|
|
}
|
|
return len(data), nil
|
|
}
|
|
|
|
webSearchCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(chatResponse.Message.ToolCalls)
|
|
logutil.Trace("anthropic middleware: upstream chunk",
|
|
"resp", anthropic.TraceChatResponse(chatResponse),
|
|
"web_search", hasWebSearch,
|
|
"other_tools", hasOtherTools,
|
|
)
|
|
if hasWebSearch && hasOtherTools {
|
|
// Prefer web_search if both server and client tools are present in one chunk.
|
|
slog.Debug("preferring web_search tool call over client tool calls in mixed tool response")
|
|
}
|
|
|
|
if !hasWebSearch {
|
|
if w.stream {
|
|
if err := w.writePassthroughStreamChunk(chatResponse); err != nil {
|
|
return 0, err
|
|
}
|
|
return len(data), nil
|
|
}
|
|
return w.inner.writeResponse(data)
|
|
}
|
|
|
|
if w.stream {
|
|
// Let the original generation continue to completion while web search runs in parallel.
|
|
logutil.Trace("anthropic middleware: starting async web_search loop",
|
|
"tool_call", anthropic.TraceToolCall(webSearchCall),
|
|
"resp", anthropic.TraceChatResponse(chatResponse),
|
|
)
|
|
w.startLoopWorker(chatResponse, webSearchCall)
|
|
if chatResponse.Done {
|
|
if err := w.writeLoopResult(); err != nil {
|
|
return len(data), err
|
|
}
|
|
}
|
|
return len(data), nil
|
|
}
|
|
|
|
loopCtx, cancel := w.startLoopContext()
|
|
defer cancel()
|
|
|
|
initialUsage := anthropic.Usage{
|
|
InputTokens: max(w.observedPromptEvalCount, chatResponse.Metrics.PromptEvalCount),
|
|
OutputTokens: max(w.observedEvalCount, chatResponse.Metrics.EvalCount),
|
|
}
|
|
logutil.Trace("anthropic middleware: starting sync web_search loop",
|
|
"tool_call", anthropic.TraceToolCall(webSearchCall),
|
|
"resp", anthropic.TraceChatResponse(chatResponse),
|
|
"usage", initialUsage,
|
|
)
|
|
response, loopErr := w.runWebSearchLoop(loopCtx, chatResponse, webSearchCall, initialUsage)
|
|
if loopErr != nil {
|
|
return len(data), w.sendError(loopErr.code, loopErr.query, loopErr.usage)
|
|
}
|
|
|
|
if err := w.writeTerminalResponse(response); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initialResponse api.ChatResponse, initialToolCall api.ToolCall, initialUsage anthropic.Usage) (anthropic.MessagesResponse, *webSearchLoopError) {
|
|
followUpMessages := make([]api.Message, 0, len(w.chatReq.Messages)+maxWebSearchLoops*2)
|
|
followUpMessages = append(followUpMessages, w.chatReq.Messages...)
|
|
|
|
followUpTools := append(api.Tools(nil), w.chatReq.Tools...)
|
|
usage := initialUsage
|
|
logutil.TraceContext(ctx, "anthropic middleware: web_search loop init",
|
|
"model", w.req.Model,
|
|
"tool_call", anthropic.TraceToolCall(initialToolCall),
|
|
"messages", len(followUpMessages),
|
|
"tools", len(followUpTools),
|
|
"max_loops", maxWebSearchLoops,
|
|
)
|
|
|
|
currentResponse := initialResponse
|
|
currentToolCall := initialToolCall
|
|
|
|
var serverContent []anthropic.ContentBlock
|
|
|
|
for loop := 1; loop <= maxWebSearchLoops; loop++ {
|
|
query := extractQueryFromToolCall(¤tToolCall)
|
|
logutil.TraceContext(ctx, "anthropic middleware: web_search loop iteration",
|
|
"loop", loop,
|
|
"query", anthropic.TraceTruncateString(query),
|
|
"messages", len(followUpMessages),
|
|
)
|
|
if query == "" {
|
|
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
|
code: "invalid_request",
|
|
query: "",
|
|
usage: usage,
|
|
}
|
|
}
|
|
|
|
const defaultMaxResults = 5
|
|
searchResp, err := anthropic.WebSearch(ctx, query, defaultMaxResults)
|
|
if err != nil {
|
|
logutil.TraceContext(ctx, "anthropic middleware: web_search request failed",
|
|
"loop", loop,
|
|
"query", query,
|
|
"error", err,
|
|
)
|
|
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
|
code: "unavailable",
|
|
query: query,
|
|
usage: usage,
|
|
err: err,
|
|
}
|
|
}
|
|
logutil.TraceContext(ctx, "anthropic middleware: web_search results",
|
|
"loop", loop,
|
|
"results", len(searchResp.Results),
|
|
)
|
|
|
|
toolUseID := loopServerToolUseID(w.inner.id, loop)
|
|
searchResults := anthropic.ConvertOllamaToAnthropicResults(searchResp)
|
|
serverContent = append(serverContent,
|
|
anthropic.ContentBlock{
|
|
Type: "server_tool_use",
|
|
ID: toolUseID,
|
|
Name: "web_search",
|
|
Input: map[string]any{"query": query},
|
|
},
|
|
anthropic.ContentBlock{
|
|
Type: "web_search_tool_result",
|
|
ToolUseID: toolUseID,
|
|
Content: searchResults,
|
|
},
|
|
)
|
|
|
|
assistantMsg := buildWebSearchAssistantMessage(currentResponse, currentToolCall)
|
|
toolResultMsg := api.Message{
|
|
Role: "tool",
|
|
Content: formatWebSearchResultsForToolMessage(searchResp.Results),
|
|
ToolCallID: currentToolCall.ID,
|
|
}
|
|
followUpMessages = append(followUpMessages, assistantMsg, toolResultMsg)
|
|
|
|
followUpResponse, err := w.callFollowUpChat(ctx, followUpMessages, followUpTools)
|
|
if err != nil {
|
|
logutil.TraceContext(ctx, "anthropic middleware: followup /api/chat failed",
|
|
"loop", loop,
|
|
"query", query,
|
|
"error", err,
|
|
)
|
|
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
|
code: "api_error",
|
|
query: query,
|
|
usage: usage,
|
|
err: err,
|
|
}
|
|
}
|
|
logutil.TraceContext(ctx, "anthropic middleware: followup response",
|
|
"loop", loop,
|
|
"resp", anthropic.TraceChatResponse(followUpResponse),
|
|
)
|
|
|
|
usage.InputTokens += followUpResponse.Metrics.PromptEvalCount
|
|
usage.OutputTokens += followUpResponse.Metrics.EvalCount
|
|
|
|
nextToolCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(followUpResponse.Message.ToolCalls)
|
|
if hasWebSearch && hasOtherTools {
|
|
// Prefer web_search if both server and client tools are present in one chunk.
|
|
slog.Debug("preferring web_search tool call over client tool calls in mixed followup response")
|
|
}
|
|
|
|
if !hasWebSearch {
|
|
finalResponse := w.combineServerAndFinalContent(serverContent, followUpResponse, usage)
|
|
logutil.TraceContext(ctx, "anthropic middleware: web_search loop complete",
|
|
"loop", loop,
|
|
"resp", anthropic.TraceMessagesResponse(finalResponse),
|
|
)
|
|
return finalResponse, nil
|
|
}
|
|
|
|
currentResponse = followUpResponse
|
|
currentToolCall = nextToolCall
|
|
}
|
|
|
|
maxLoopQuery := extractQueryFromToolCall(¤tToolCall)
|
|
maxLoopToolUseID := loopServerToolUseID(w.inner.id, maxWebSearchLoops+1)
|
|
serverContent = append(serverContent,
|
|
anthropic.ContentBlock{
|
|
Type: "server_tool_use",
|
|
ID: maxLoopToolUseID,
|
|
Name: "web_search",
|
|
Input: map[string]any{"query": maxLoopQuery},
|
|
},
|
|
anthropic.ContentBlock{
|
|
Type: "web_search_tool_result",
|
|
ToolUseID: maxLoopToolUseID,
|
|
Content: anthropic.WebSearchToolResultError{
|
|
Type: "web_search_tool_result_error",
|
|
ErrorCode: "max_uses_exceeded",
|
|
},
|
|
},
|
|
)
|
|
|
|
maxResponse := anthropic.MessagesResponse{
|
|
ID: w.inner.id,
|
|
Type: "message",
|
|
Role: "assistant",
|
|
Model: w.req.Model,
|
|
Content: serverContent,
|
|
StopReason: "end_turn",
|
|
Usage: usage,
|
|
}
|
|
logutil.TraceContext(ctx, "anthropic middleware: web_search loop max reached",
|
|
"resp", anthropic.TraceMessagesResponse(maxResponse),
|
|
)
|
|
return maxResponse, nil
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) startLoopWorker(initialResponse api.ChatResponse, initialToolCall api.ToolCall) {
|
|
if w.loopInFlight {
|
|
return
|
|
}
|
|
|
|
initialUsage := anthropic.Usage{
|
|
InputTokens: max(w.observedPromptEvalCount, initialResponse.Metrics.PromptEvalCount),
|
|
OutputTokens: max(w.observedEvalCount, initialResponse.Metrics.EvalCount),
|
|
}
|
|
w.loopBaseInputTok = initialUsage.InputTokens
|
|
w.loopBaseOutputTok = initialUsage.OutputTokens
|
|
w.loopResultCh = make(chan webSearchLoopResult, 1)
|
|
w.loopInFlight = true
|
|
logutil.Trace("anthropic middleware: loop worker started",
|
|
"usage", initialUsage,
|
|
"tool_call", anthropic.TraceToolCall(initialToolCall),
|
|
)
|
|
|
|
go func() {
|
|
ctx, cancel := w.startLoopContext()
|
|
defer cancel()
|
|
|
|
response, loopErr := w.runWebSearchLoop(ctx, initialResponse, initialToolCall, initialUsage)
|
|
w.loopResultCh <- webSearchLoopResult{
|
|
response: response,
|
|
loopErr: loopErr,
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) writeLoopResult() error {
|
|
if w.loopResultCh == nil {
|
|
return w.sendError("api_error", "", w.currentObservedUsage())
|
|
}
|
|
|
|
result := <-w.loopResultCh
|
|
w.loopResultCh = nil
|
|
w.loopInFlight = false
|
|
if result.loopErr != nil {
|
|
logutil.Trace("anthropic middleware: loop worker returned error",
|
|
"code", result.loopErr.code,
|
|
"query", result.loopErr.query,
|
|
"usage", result.loopErr.usage,
|
|
"error", result.loopErr.err,
|
|
)
|
|
usage := result.loopErr.usage
|
|
w.applyObservedUsageDeltaToUsage(&usage)
|
|
return w.sendError(result.loopErr.code, result.loopErr.query, usage)
|
|
}
|
|
logutil.Trace("anthropic middleware: loop worker done", "resp", anthropic.TraceMessagesResponse(result.response))
|
|
|
|
w.applyObservedUsageDelta(&result.response)
|
|
return w.writeTerminalResponse(result.response)
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) applyObservedUsageDelta(response *anthropic.MessagesResponse) {
|
|
w.applyObservedUsageDeltaToUsage(&response.Usage)
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) recordObservedUsage(metrics api.Metrics) {
|
|
if metrics.PromptEvalCount > w.observedPromptEvalCount {
|
|
w.observedPromptEvalCount = metrics.PromptEvalCount
|
|
}
|
|
if metrics.EvalCount > w.observedEvalCount {
|
|
w.observedEvalCount = metrics.EvalCount
|
|
}
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) applyObservedUsageDeltaToUsage(usage *anthropic.Usage) {
|
|
if deltaIn := w.observedPromptEvalCount - w.loopBaseInputTok; deltaIn > 0 {
|
|
usage.InputTokens += deltaIn
|
|
}
|
|
if deltaOut := w.observedEvalCount - w.loopBaseOutputTok; deltaOut > 0 {
|
|
usage.OutputTokens += deltaOut
|
|
}
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) currentObservedUsage() anthropic.Usage {
|
|
return anthropic.Usage{
|
|
InputTokens: w.observedPromptEvalCount,
|
|
OutputTokens: w.observedEvalCount,
|
|
}
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) startLoopContext() (context.Context, context.CancelFunc) {
|
|
if w.newLoopContext != nil {
|
|
return w.newLoopContext()
|
|
}
|
|
return context.WithTimeout(context.Background(), 5*time.Minute)
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) combineServerAndFinalContent(serverContent []anthropic.ContentBlock, finalResponse api.ChatResponse, usage anthropic.Usage) anthropic.MessagesResponse {
|
|
converted := anthropic.ToMessagesResponse(w.inner.id, finalResponse)
|
|
|
|
content := make([]anthropic.ContentBlock, 0, len(serverContent)+len(converted.Content))
|
|
content = append(content, serverContent...)
|
|
content = append(content, converted.Content...)
|
|
|
|
return anthropic.MessagesResponse{
|
|
ID: w.inner.id,
|
|
Type: "message",
|
|
Role: "assistant",
|
|
Model: w.req.Model,
|
|
Content: content,
|
|
StopReason: converted.StopReason,
|
|
StopSequence: converted.StopSequence,
|
|
Usage: usage,
|
|
}
|
|
}
|
|
|
|
func buildWebSearchAssistantMessage(response api.ChatResponse, webSearchCall api.ToolCall) api.Message {
|
|
assistantMsg := api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{webSearchCall},
|
|
}
|
|
if response.Message.Content != "" {
|
|
assistantMsg.Content = response.Message.Content
|
|
}
|
|
if response.Message.Thinking != "" {
|
|
assistantMsg.Thinking = response.Message.Thinking
|
|
}
|
|
return assistantMsg
|
|
}
|
|
|
|
func formatWebSearchResultsForToolMessage(results []anthropic.OllamaWebSearchResult) string {
|
|
var resultText strings.Builder
|
|
for _, r := range results {
|
|
fmt.Fprintf(&resultText, "Title: %s\nURL: %s\n", r.Title, r.URL)
|
|
if r.Content != "" {
|
|
fmt.Fprintf(&resultText, "Content: %s\n", r.Content)
|
|
}
|
|
resultText.WriteString("\n")
|
|
}
|
|
return resultText.String()
|
|
}
|
|
|
|
func findWebSearchToolCall(toolCalls []api.ToolCall) (api.ToolCall, bool, bool) {
|
|
var webSearchCall api.ToolCall
|
|
hasWebSearch := false
|
|
hasOtherTools := false
|
|
|
|
for _, toolCall := range toolCalls {
|
|
if toolCall.Function.Name == "web_search" {
|
|
if !hasWebSearch {
|
|
webSearchCall = toolCall
|
|
hasWebSearch = true
|
|
}
|
|
continue
|
|
}
|
|
hasOtherTools = true
|
|
}
|
|
|
|
return webSearchCall, hasWebSearch, hasOtherTools
|
|
}
|
|
|
|
func loopServerToolUseID(messageID string, loop int) string {
|
|
base := serverToolUseID(messageID)
|
|
if loop <= 1 {
|
|
return base
|
|
}
|
|
return fmt.Sprintf("%s_%d", base, loop)
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) callFollowUpChat(ctx context.Context, messages []api.Message, tools api.Tools) (api.ChatResponse, error) {
|
|
streaming := false
|
|
followUp := api.ChatRequest{
|
|
Model: w.chatReq.Model,
|
|
Messages: messages,
|
|
Stream: &streaming,
|
|
Tools: tools,
|
|
Options: w.chatReq.Options,
|
|
}
|
|
|
|
body, err := json.Marshal(followUp)
|
|
if err != nil {
|
|
return api.ChatResponse{}, err
|
|
}
|
|
|
|
chatURL := envconfig.Host().String() + "/api/chat"
|
|
logutil.TraceContext(ctx, "anthropic middleware: followup request",
|
|
"url", chatURL,
|
|
"req", anthropic.TraceChatRequest(&followUp),
|
|
)
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(body))
|
|
if err != nil {
|
|
return api.ChatResponse{}, err
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(httpReq)
|
|
if err != nil {
|
|
return api.ChatResponse{}, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
logutil.TraceContext(ctx, "anthropic middleware: followup non-200 response",
|
|
"status", resp.StatusCode,
|
|
"response", strings.TrimSpace(string(respBody)),
|
|
)
|
|
return api.ChatResponse{}, fmt.Errorf("followup /api/chat returned status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
|
}
|
|
|
|
var chatResp api.ChatResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
|
return api.ChatResponse{}, err
|
|
}
|
|
logutil.TraceContext(ctx, "anthropic middleware: followup decoded", "resp", anthropic.TraceChatResponse(chatResp))
|
|
|
|
return chatResp, nil
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) writePassthroughStreamChunk(chatResponse api.ChatResponse) error {
|
|
events := w.inner.converter.Process(chatResponse)
|
|
for _, event := range events {
|
|
switch e := event.Data.(type) {
|
|
case anthropic.MessageStartEvent:
|
|
w.streamMessageStarted = true
|
|
case anthropic.ContentBlockStartEvent:
|
|
w.streamHasOpenBlock = true
|
|
w.streamOpenBlockIndex = e.Index
|
|
if e.Index+1 > w.streamNextIndex {
|
|
w.streamNextIndex = e.Index + 1
|
|
}
|
|
case anthropic.ContentBlockStopEvent:
|
|
if w.streamHasOpenBlock && w.streamOpenBlockIndex == e.Index {
|
|
w.streamHasOpenBlock = false
|
|
}
|
|
if e.Index+1 > w.streamNextIndex {
|
|
w.streamNextIndex = e.Index + 1
|
|
}
|
|
case anthropic.MessageStopEvent:
|
|
w.terminalSent = true
|
|
}
|
|
|
|
if err := writeSSE(w.ResponseWriter, event.Event, event.Data); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) ensureStreamMessageStart(usage anthropic.Usage) error {
|
|
if w.streamMessageStarted {
|
|
return nil
|
|
}
|
|
|
|
inputTokens := usage.InputTokens
|
|
if inputTokens == 0 {
|
|
inputTokens = w.estimatedInputTokens
|
|
}
|
|
|
|
if err := writeSSE(w.ResponseWriter, "message_start", anthropic.MessageStartEvent{
|
|
Type: "message_start",
|
|
Message: anthropic.MessagesResponse{
|
|
ID: w.inner.id,
|
|
Type: "message",
|
|
Role: "assistant",
|
|
Model: w.req.Model,
|
|
Content: []anthropic.ContentBlock{},
|
|
Usage: anthropic.Usage{
|
|
InputTokens: inputTokens,
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
w.streamMessageStarted = true
|
|
return nil
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) closeOpenStreamBlock() error {
|
|
if !w.streamHasOpenBlock {
|
|
return nil
|
|
}
|
|
|
|
if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
|
|
Type: "content_block_stop",
|
|
Index: w.streamOpenBlockIndex,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
if w.streamOpenBlockIndex+1 > w.streamNextIndex {
|
|
w.streamNextIndex = w.streamOpenBlockIndex + 1
|
|
}
|
|
w.streamHasOpenBlock = false
|
|
return nil
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) writeStreamContentBlocks(content []anthropic.ContentBlock) error {
|
|
for _, block := range content {
|
|
index := w.streamNextIndex
|
|
if block.Type == "text" {
|
|
emptyText := ""
|
|
if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
|
|
Type: "content_block_start",
|
|
Index: index,
|
|
ContentBlock: anthropic.ContentBlock{
|
|
Type: "text",
|
|
Text: &emptyText,
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
text := ""
|
|
if block.Text != nil {
|
|
text = *block.Text
|
|
}
|
|
if err := writeSSE(w.ResponseWriter, "content_block_delta", anthropic.ContentBlockDeltaEvent{
|
|
Type: "content_block_delta",
|
|
Index: index,
|
|
Delta: anthropic.Delta{
|
|
Type: "text_delta",
|
|
Text: text,
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
|
|
Type: "content_block_start",
|
|
Index: index,
|
|
ContentBlock: block,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
|
|
Type: "content_block_stop",
|
|
Index: index,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
w.streamNextIndex++
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) writeTerminalResponse(response anthropic.MessagesResponse) error {
|
|
if w.terminalSent {
|
|
return nil
|
|
}
|
|
|
|
if !w.stream {
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w.ResponseWriter).Encode(response); err != nil {
|
|
return err
|
|
}
|
|
w.terminalSent = true
|
|
return nil
|
|
}
|
|
|
|
if err := w.ensureStreamMessageStart(response.Usage); err != nil {
|
|
return err
|
|
}
|
|
if err := w.closeOpenStreamBlock(); err != nil {
|
|
return err
|
|
}
|
|
if err := w.writeStreamContentBlocks(response.Content); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := writeSSE(w.ResponseWriter, "message_delta", anthropic.MessageDeltaEvent{
|
|
Type: "message_delta",
|
|
Delta: anthropic.MessageDelta{
|
|
StopReason: response.StopReason,
|
|
},
|
|
Usage: anthropic.DeltaUsage{
|
|
InputTokens: response.Usage.InputTokens,
|
|
OutputTokens: response.Usage.OutputTokens,
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := writeSSE(w.ResponseWriter, "message_stop", anthropic.MessageStopEvent{
|
|
Type: "message_stop",
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
w.terminalSent = true
|
|
return nil
|
|
}
|
|
|
|
// streamResponse emits a complete MessagesResponse as SSE events.
|
|
func (w *WebSearchAnthropicWriter) streamResponse(response anthropic.MessagesResponse) error {
|
|
return w.writeTerminalResponse(response)
|
|
}
|
|
|
|
func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query string, usage anthropic.Usage) anthropic.MessagesResponse {
|
|
toolUseID := serverToolUseID(w.inner.id)
|
|
|
|
return anthropic.MessagesResponse{
|
|
ID: w.inner.id,
|
|
Type: "message",
|
|
Role: "assistant",
|
|
Model: w.req.Model,
|
|
Content: []anthropic.ContentBlock{
|
|
{
|
|
Type: "server_tool_use",
|
|
ID: toolUseID,
|
|
Name: "web_search",
|
|
Input: map[string]any{"query": query},
|
|
},
|
|
{
|
|
Type: "web_search_tool_result",
|
|
ToolUseID: toolUseID,
|
|
Content: anthropic.WebSearchToolResultError{
|
|
Type: "web_search_tool_result_error",
|
|
ErrorCode: errorCode,
|
|
},
|
|
},
|
|
},
|
|
StopReason: "end_turn",
|
|
Usage: usage,
|
|
}
|
|
}
|
|
|
|
// sendError sends a web search error response.
|
|
func (w *WebSearchAnthropicWriter) sendError(errorCode, query string, usage anthropic.Usage) error {
|
|
response := w.webSearchErrorResponse(errorCode, query, usage)
|
|
logutil.Trace("anthropic middleware: web_search error", "code", errorCode, "query", query, "usage", usage)
|
|
return w.writeTerminalResponse(response)
|
|
}
|
|
|
|
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
|
|
func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
requestCtx := c.Request.Context()
|
|
|
|
var req anthropic.MessagesRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if req.Model == "" {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
|
|
return
|
|
}
|
|
|
|
if req.MaxTokens <= 0 {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
|
|
return
|
|
}
|
|
|
|
if len(req.Messages) == 0 {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
|
|
return
|
|
}
|
|
|
|
chatReq, err := anthropic.FromMessagesRequest(req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
// Set think to nil when being used with Anthropic API to connect to tools like claude code
|
|
c.Set("relax_thinking", true)
|
|
|
|
var b bytes.Buffer
|
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
messageID := anthropic.GenerateMessageID()
|
|
|
|
// Estimate input tokens for streaming (actual count not available until generation completes)
|
|
estimatedTokens := anthropic.EstimateInputTokens(req)
|
|
|
|
innerWriter := &AnthropicWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
stream: req.Stream,
|
|
id: messageID,
|
|
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
|
}
|
|
|
|
if req.Stream {
|
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
|
c.Writer.Header().Set("Connection", "keep-alive")
|
|
}
|
|
|
|
if hasWebSearchTool(req.Tools) {
|
|
// Guard against runtime cloud-disable policy (OLLAMA_NO_CLOUD/server.json)
|
|
// for cloud models. Local models may still receive web_search tool definitions;
|
|
// execution is validated when the model actually emits a web_search tool call.
|
|
if isCloudModelName(req.Model) {
|
|
if disabled, _ := internalcloud.Status(); disabled {
|
|
c.AbortWithStatusJSON(http.StatusForbidden, anthropic.NewError(http.StatusForbidden, internalcloud.DisabledError("web search is unavailable")))
|
|
return
|
|
}
|
|
}
|
|
|
|
c.Writer = &WebSearchAnthropicWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
newLoopContext: func() (context.Context, context.CancelFunc) {
|
|
return context.WithTimeout(requestCtx, 5*time.Minute)
|
|
},
|
|
inner: innerWriter,
|
|
req: req,
|
|
chatReq: chatReq,
|
|
stream: req.Stream,
|
|
estimatedInputTokens: estimatedTokens,
|
|
}
|
|
} else {
|
|
c.Writer = innerWriter
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// hasWebSearchTool checks if the request tools include a web_search tool
|
|
func hasWebSearchTool(tools []anthropic.Tool) bool {
|
|
for _, tool := range tools {
|
|
if strings.HasPrefix(tool.Type, "web_search") {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func isCloudModelName(name string) bool {
|
|
return modelref.HasExplicitCloudSource(name)
|
|
}
|
|
|
|
// extractQueryFromToolCall extracts the search query from a web_search tool call
|
|
func extractQueryFromToolCall(tc *api.ToolCall) string {
|
|
q, ok := tc.Function.Arguments.Get("query")
|
|
if !ok {
|
|
return ""
|
|
}
|
|
if s, ok := q.(string); ok {
|
|
return s
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// writeSSE writes one Server-Sent Event of the form
// "event: <type>\ndata: <json>\n\n". It flushes the writer when it implements
// http.Flusher. JSON marshalling and write errors are returned unchanged.
func writeSSE(w http.ResponseWriter, eventType string, data any) error {
	payload, err := json.Marshal(data)
	if err != nil {
		return err
	}
	if _, err = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, payload); err != nil {
		return err
	}
	if flusher, ok := w.(http.Flusher); ok {
		flusher.Flush()
	}
	return nil
}
|
|
|
|
// serverToolUseID derives a server tool use ID from a message ID. A single
// leading "msg_" prefix, if present, is replaced by "srvtoolu_"; otherwise
// "srvtoolu_" is simply prepended.
func serverToolUseID(messageID string) string {
	suffix, _ := strings.CutPrefix(messageID, "msg_")
	return "srvtoolu_" + suffix
}
|