diff --git a/server/routes.go b/server/routes.go index 383e9b29c..d6c1cbe16 100644 --- a/server/routes.go +++ b/server/routes.go @@ -2508,8 +2508,14 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo return } - // Set headers for streaming response - c.Header("Content-Type", "application/x-ndjson") + // Check streaming preference + isStreaming := req.Stream == nil || *req.Stream + + contentType := "application/x-ndjson" + if !isStreaming { + contentType = "application/json; charset=utf-8" + } + c.Header("Content-Type", contentType) // Get seed from options if provided var seed int64 @@ -2530,6 +2536,8 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo } var streamStarted bool + var finalResponse api.GenerateResponse + if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: req.Prompt, Width: req.Width, @@ -2560,6 +2568,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart) } + if !isStreaming { + finalResponse = res + return + } + data, _ := json.Marshal(res) c.Writer.Write(append(data, '\n')) c.Writer.Flush() @@ -2569,5 +2582,10 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo if !streamStarted { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } + return + } + + if !isStreaming { + c.JSON(http.StatusOK, finalResponse) } } diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index ca149641a..26d40e3c8 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -19,7 +19,9 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/manifest" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/types/model" ) // testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests) @@ -71,6 +73,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error return } +func (mockRunner) Ping(_ context.Context) error { return nil } + func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) { return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) { return mock, nil @@ -2347,3 +2351,92 @@ func TestGenerateWithImages(t *testing.T) { } }) } + +// TestImageGenerateStreamFalse tests that image generation respects stream=false +// and returns a single JSON response instead of streaming ndjson. +func TestImageGenerateStreamFalse(t *testing.T) { + gin.SetMode(gin.TestMode) + + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + + mock := mockRunner{} + mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false}) + fn(llm.CompletionResponse{Step: 2, TotalSteps: 3, Done: false}) + fn(llm.CompletionResponse{Step: 3, TotalSteps: 3, Done: true, DoneReason: llm.DoneReasonStop, Image: "base64image"}) + return nil + } + + opts := api.DefaultOptions() + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: map[string]*runnerRef{ + "": { + llama: &mock, + Options: &opts, + model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}}, + numParallel: 1, + }, + }, + newServerFn: newMockServer(&mock), + getGpuFn: getGpuFn, + getSystemInfoFn: getSystemInfoFn, + }, + } + + go s.sched.Run(t.Context()) + + // Create model manifest with image capability + n := model.ParseName("test-image") + cfg := model.ConfigV2{Capabilities: []string{"image"}} + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(&cfg); err != nil { + t.Fatal(err) + } + configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json") + if err != nil { + t.Fatal(err) + } + if err := manifest.WriteManifest(n, configLayer, nil); err != nil { + t.Fatal(err) + } + + streamFalse := false + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-image", + Prompt: "test prompt", + Stream: &streamFalse, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + if ct := w.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" { + t.Errorf("expected Content-Type 'application/json; charset=utf-8', got %q", ct) + } + + body := w.Body.String() + lines := strings.Split(strings.TrimSpace(body), "\n") + if len(lines) != 1 { + t.Errorf("expected 1 response line, got %d:\n%s", len(lines), body) + } + + var resp api.GenerateResponse + if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil { + t.Fatalf("failed to parse response: %v", err) + } + + if resp.Image != "base64image" { + t.Errorf("expected image 'base64image', got %q", resp.Image) + } + + if !resp.Done { + t.Errorf("expected done=true") + } +}