mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
cloud_proxy: for the web_search legacy path, flush on newlines (#14897)
`WebSearchAnthropicWriter` expects a single object per write. The new transparent proxy will instead send it whatever bytes it sees. This cloud-model + local-orchestration + cloud-search is a temporary code path, so instead of making the web search code more robust to this, I put an adapter in the middle that will flush line-by-line to preserve the old behavior.
This commit is contained in:
@@ -226,7 +226,24 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
|
|||||||
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
||||||
c.Status(resp.StatusCode)
|
c.Status(resp.StatusCode)
|
||||||
|
|
||||||
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
|
var bodyWriter http.ResponseWriter = c.Writer
|
||||||
|
var framedWriter *jsonlFramingResponseWriter
|
||||||
|
// TEMP(drifkin): only needed on the cloud-proxied first leg of Anthropic
|
||||||
|
// web_search fallback (which is a path we're removing soon). Local
|
||||||
|
// /v1/messages writes one JSON value per streamResponse callback directly
|
||||||
|
// into WebSearchAnthropicWriter, but this proxy copy loop may coalesce
|
||||||
|
// multiple jsonl records into one Write. WebSearchAnthropicWriter currently
|
||||||
|
// unmarshals one JSON value per Write.
|
||||||
|
if path == "/api/chat" && resp.StatusCode == http.StatusOK && c.GetBool(legacyCloudAnthropicKey) {
|
||||||
|
framedWriter = &jsonlFramingResponseWriter{ResponseWriter: c.Writer}
|
||||||
|
bodyWriter = framedWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
err = copyProxyResponseBody(bodyWriter, resp.Body)
|
||||||
|
if err == nil && framedWriter != nil {
|
||||||
|
err = framedWriter.FlushPending()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
ctxErr := c.Request.Context().Err()
|
ctxErr := c.Request.Context().Err()
|
||||||
if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
|
if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
|
||||||
slog.Debug(
|
slog.Debug(
|
||||||
@@ -240,6 +257,7 @@ func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disable
|
|||||||
slog.Warn(
|
slog.Warn(
|
||||||
"cloud proxy response copy failed",
|
"cloud proxy response copy failed",
|
||||||
"path", c.Request.URL.Path,
|
"path", c.Request.URL.Path,
|
||||||
|
"upstream_path", path,
|
||||||
"status", resp.StatusCode,
|
"status", resp.StatusCode,
|
||||||
"request_context_canceled", ctxErr != nil,
|
"request_context_canceled", ctxErr != nil,
|
||||||
"request_context_err", ctxErr,
|
"request_context_err", ctxErr,
|
||||||
@@ -473,6 +491,55 @@ func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type jsonlFramingResponseWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
pending []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *jsonlFramingResponseWriter) Flush() {
|
||||||
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *jsonlFramingResponseWriter) Write(p []byte) (int, error) {
|
||||||
|
w.pending = append(w.pending, p...)
|
||||||
|
if err := w.flushCompleteLines(); err != nil {
|
||||||
|
return len(p), err
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *jsonlFramingResponseWriter) FlushPending() error {
|
||||||
|
trailing := bytes.TrimSpace(w.pending)
|
||||||
|
w.pending = nil
|
||||||
|
if len(trailing) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := w.ResponseWriter.Write(trailing)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *jsonlFramingResponseWriter) flushCompleteLines() error {
|
||||||
|
for {
|
||||||
|
newline := bytes.IndexByte(w.pending, '\n')
|
||||||
|
if newline < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line := bytes.TrimSpace(w.pending[:newline])
|
||||||
|
w.pending = w.pending[newline+1:]
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := w.ResponseWriter.Write(line); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func isHopByHopHeader(name string) bool {
|
func isHopByHopHeader(name string) bool {
|
||||||
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
||||||
return ok
|
return ok
|
||||||
|
|||||||
@@ -248,3 +248,71 @@ func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
|
|||||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJSONLFramingResponseWriter_SplitsCoalescedLines(t *testing.T) {
|
||||||
|
rec := &chunkRecorder{header: http.Header{}}
|
||||||
|
w := &jsonlFramingResponseWriter{ResponseWriter: rec}
|
||||||
|
|
||||||
|
payload := []byte("{\"a\":1}\n{\"b\":2}\n")
|
||||||
|
if n, err := w.Write(payload); err != nil {
|
||||||
|
t.Fatalf("write failed: %v", err)
|
||||||
|
} else if n != len(payload) {
|
||||||
|
t.Fatalf("write byte count mismatch: got %d want %d", n, len(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.FlushPending(); err != nil {
|
||||||
|
t.Fatalf("FlushPending failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rec.chunks) != 2 {
|
||||||
|
t.Fatalf("expected 2 framed writes, got %d", len(rec.chunks))
|
||||||
|
}
|
||||||
|
if got := string(rec.chunks[0]); got != `{"a":1}` {
|
||||||
|
t.Fatalf("first chunk mismatch: got %q", got)
|
||||||
|
}
|
||||||
|
if got := string(rec.chunks[1]); got != `{"b":2}` {
|
||||||
|
t.Fatalf("second chunk mismatch: got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONLFramingResponseWriter_FlushPendingWritesTrailingLine(t *testing.T) {
|
||||||
|
rec := &chunkRecorder{header: http.Header{}}
|
||||||
|
w := &jsonlFramingResponseWriter{ResponseWriter: rec}
|
||||||
|
|
||||||
|
if _, err := w.Write([]byte("{\"a\":1")); err != nil {
|
||||||
|
t.Fatalf("write failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(rec.chunks) != 0 {
|
||||||
|
t.Fatalf("expected no writes before newline/flush, got %d", len(rec.chunks))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.FlushPending(); err != nil {
|
||||||
|
t.Fatalf("FlushPending failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(rec.chunks) != 1 {
|
||||||
|
t.Fatalf("expected 1 write after FlushPending, got %d", len(rec.chunks))
|
||||||
|
}
|
||||||
|
if got := string(rec.chunks[0]); got != `{"a":1` {
|
||||||
|
t.Fatalf("trailing chunk mismatch: got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type chunkRecorder struct {
|
||||||
|
header http.Header
|
||||||
|
status int
|
||||||
|
chunks [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *chunkRecorder) Header() http.Header {
|
||||||
|
return r.header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *chunkRecorder) WriteHeader(statusCode int) {
|
||||||
|
r.status = statusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *chunkRecorder) Write(p []byte) (int, error) {
|
||||||
|
cp := append([]byte(nil), p...)
|
||||||
|
r.chunks = append(r.chunks, cp)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -652,6 +652,67 @@ func TestExplicitCloudPassthroughAPIAndV1(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("v1 messages web_search fallback frames coalesced jsonl chunks", func(t *testing.T) {
|
||||||
|
type upstreamCapture struct {
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
capture := &upstreamCapture{}
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
capture.path = r.URL.Path
|
||||||
|
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
combined := strings.Join([]string{
|
||||||
|
`{"model":"gpt-oss:120b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hel"},"done":false}`,
|
||||||
|
`{"model":"gpt-oss:120b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"lo"},"done":true}`,
|
||||||
|
}, "\n") + "\n"
|
||||||
|
_, _ = w.Write([]byte(combined))
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
original := cloudProxyBaseURL
|
||||||
|
cloudProxyBaseURL = upstream.URL
|
||||||
|
t.Cleanup(func() { cloudProxyBaseURL = original })
|
||||||
|
|
||||||
|
s := &Server{}
|
||||||
|
router, err := s.GenerateRoutes(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
local := httptest.NewServer(router)
|
||||||
|
defer local.Close()
|
||||||
|
|
||||||
|
reqBody := `{
|
||||||
|
"model":"gpt-oss:120b-cloud",
|
||||||
|
"max_tokens":10,
|
||||||
|
"stream":true,
|
||||||
|
"messages":[{"role":"user","content":"search the web"}],
|
||||||
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
||||||
|
}`
|
||||||
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages?beta=true", bytes.NewBufferString(reqBody))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := local.Client().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
if capture.path != "/api/chat" {
|
||||||
|
t.Fatalf("expected upstream path /api/chat for web_search fallback, got %q", capture.path)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(body), "event: message_stop") {
|
||||||
|
t.Fatalf("expected anthropic streaming message_stop event, got body %q", string(body))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("v1 model retrieve bypasses conversion", func(t *testing.T) {
|
t.Run("v1 model retrieve bypasses conversion", func(t *testing.T) {
|
||||||
upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:cloud","object":"model","created":1,"owned_by":"ollama"}`)
|
upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:cloud","object":"model","created":1,"owned_by":"ollama"}`)
|
||||||
defer upstream.Close()
|
defer upstream.Close()
|
||||||
|
|||||||
Reference in New Issue
Block a user