diff --git a/cmd/cmd.go b/cmd/cmd.go index 53caa986d..20da8cd2c 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -41,6 +41,7 @@ import ( "github.com/ollama/ollama/cmd/tui" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/readline" @@ -418,12 +419,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { return err } + requestedCloud := modelref.HasExplicitCloudSource(opts.Model) + if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil { return err - } else if info.RemoteHost != "" { + } else if info.RemoteHost != "" || requestedCloud { // Cloud model, no need to load/unload - isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com") + isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com") // Check if user is signed in for ollama.com cloud models if isCloud { @@ -434,10 +437,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { if opts.ShowConnect { p.StopAndClear() + remoteModel := info.RemoteModel + if remoteModel == "" { + remoteModel = opts.Model + } if isCloud { - fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel) + fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel) } else { - fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost) + fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost) } } @@ -509,6 +516,20 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a return nil } +// TODO(parthsareen): consolidate with TUI signin flow +func handleCloudAuthorizationError(err error) bool { + var authErr api.AuthorizationError + if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized { + fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") + if authErr.SigninURL != "" { + fmt.Printf(ConnectInstructions, authErr.SigninURL) + } + return true + } + + return false +} + func RunHandler(cmd *cobra.Command, args []string) error { interactive := true @@ -605,12 +626,16 @@ func RunHandler(cmd *cobra.Command, args []string) error { } name := args[0] + requestedCloud := modelref.HasExplicitCloudSource(name) info, err := func() (*api.ShowResponse, error) { showReq := &api.ShowRequest{Name: name} info, err := client.Show(cmd.Context(), showReq) var se api.StatusError if errors.As(err, &se) && se.StatusCode == http.StatusNotFound { + if requestedCloud { + return nil, err + } if err := PullHandler(cmd, []string{name}); err != nil { return nil, err } @@ -619,6 +644,9 @@ func RunHandler(cmd *cobra.Command, args []string) error { return info, err }() if err != nil { + if handleCloudAuthorizationError(err) { + return nil + } return err } @@ -713,7 +741,13 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generateInteractive(cmd, opts) } - return generate(cmd, opts) + if err := generate(cmd, opts); err != nil { + if handleCloudAuthorizationError(err) { + return nil + } + return err + } + return nil } func SigninHandler(cmd *cobra.Command, args []string) error { diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 7217c3d13..dfbd63a85 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -18,6 +18,7 @@ import ( "github.com/spf13/cobra" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/types/model" ) @@ -705,6 +706,139 @@ func TestRunEmbeddingModelNoInput(t *testing.T) { } } +func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) { + var generateCalled bool + + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/api/show" && r.Method == http.MethodPost: + w.WriteHeader(http.StatusUnauthorized) + if err := json.NewEncoder(w).Encode(map[string]string{ + "error": "unauthorized", + "signin_url": "https://ollama.com/signin", + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + case r.URL.Path == "/api/generate" && r.Method == http.MethodPost: + generateCalled = true + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + default: + http.NotFound(w, r) + } + })) + + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.SetContext(t.Context()) + cmd.Flags().String("keepalive", "", "") + cmd.Flags().Bool("truncate", false, "") + cmd.Flags().Int("dimensions", 0, "") + cmd.Flags().Bool("verbose", false, "") + cmd.Flags().Bool("insecure", false, "") + cmd.Flags().Bool("nowordwrap", false, "") + cmd.Flags().String("format", "", "") + cmd.Flags().String("think", "", "") + cmd.Flags().Bool("hidethinking", false, "") + + oldStdout := os.Stdout + readOut, writeOut, _ := os.Pipe() + os.Stdout = writeOut + t.Cleanup(func() { os.Stdout = oldStdout }) + + err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"}) + + _ = writeOut.Close() + var out bytes.Buffer + _, _ = io.Copy(&out, readOut) + + if err != nil { + t.Fatalf("RunHandler returned error: %v", err) + } + + if generateCalled { + t.Fatal("expected run to stop before /api/generate after unauthorized /api/show") + } + + if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") { + t.Fatalf("expected sign-in guidance message, got %q", out.String()) + } + + if !strings.Contains(out.String(), "https://ollama.com/signin") { + t.Fatalf("expected signin_url in output, got %q", out.String()) + } +} + +func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/api/show" && r.Method == http.MethodPost: + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(api.ShowResponse{ + Capabilities: []model.Capability{model.CapabilityCompletion}, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + case r.URL.Path == "/api/generate" && r.Method == http.MethodPost: + w.WriteHeader(http.StatusUnauthorized) + if err := json.NewEncoder(w).Encode(map[string]string{ + "error": "unauthorized", + "signin_url": "https://ollama.com/signin", + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + default: + http.NotFound(w, r) + } + })) + + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.SetContext(t.Context()) + cmd.Flags().String("keepalive", "", "") + cmd.Flags().Bool("truncate", false, "") + cmd.Flags().Int("dimensions", 0, "") + cmd.Flags().Bool("verbose", false, "") + cmd.Flags().Bool("insecure", false, "") + cmd.Flags().Bool("nowordwrap", false, "") + cmd.Flags().String("format", "", "") + cmd.Flags().String("think", "", "") + cmd.Flags().Bool("hidethinking", false, "") + + oldStdout := os.Stdout + readOut, writeOut, _ := os.Pipe() + os.Stdout = writeOut + t.Cleanup(func() { os.Stdout = oldStdout }) + + err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"}) + + _ = writeOut.Close() + var out bytes.Buffer + _, _ = io.Copy(&out, readOut) + + if err != nil { + t.Fatalf("RunHandler returned error: %v", err) + } + + if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") { + t.Fatalf("expected sign-in guidance message, got %q", out.String()) + } + + if !strings.Contains(out.String(), "https://ollama.com/signin") { + t.Fatalf("expected signin_url in output, got %q", out.String()) + } +} + func TestGetModelfileName(t *testing.T) { tests := []struct { name string @@ -1664,20 +1798,26 @@ func TestRunOptions_Copy_Independence(t *testing.T) { func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { tests := []struct { name string + model string remoteHost string + remoteModel string whoamiStatus int whoamiResp any expectedError string }{ { name: "ollama.com cloud model - user signed in", + model: "test-cloud-model", remoteHost: "https://ollama.com", + remoteModel: "test-model", whoamiStatus: http.StatusOK, whoamiResp: api.UserResponse{Name: "testuser"}, }, { name: "ollama.com cloud model - user not signed in", + model: "test-cloud-model", remoteHost: "https://ollama.com", + remoteModel: "test-model", whoamiStatus: http.StatusUnauthorized, whoamiResp: map[string]string{ "error": "unauthorized", @@ -1687,7 +1827,33 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { }, { name: "non-ollama.com remote - no auth check", + model: "test-cloud-model", remoteHost: "https://other-remote.com", + remoteModel: "test-model", + whoamiStatus: http.StatusUnauthorized, // should not be called + whoamiResp: nil, + }, + { + name: "explicit :cloud model - auth check without remote metadata", + model: "kimi-k2.5:cloud", + remoteHost: "", + remoteModel: "", + whoamiStatus: http.StatusOK, + whoamiResp: api.UserResponse{Name: "testuser"}, + }, + { + name: "explicit -cloud model - auth check without remote metadata", + model: "kimi-k2.5:latest-cloud", + remoteHost: "", + remoteModel: "", + whoamiStatus: http.StatusOK, + whoamiResp: api.UserResponse{Name: "testuser"}, + }, + { + name: "dash cloud-like name without explicit source does not require auth", + model: "test-cloud-model", + remoteHost: "", + remoteModel: "", whoamiStatus: http.StatusUnauthorized, // should not be called whoamiResp: nil, }, @@ -1702,7 +1868,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(api.ShowResponse{ RemoteHost: tt.remoteHost, - RemoteModel: "test-model", + RemoteModel: tt.remoteModel, }); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } @@ -1715,6 +1881,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { http.Error(w, err.Error(), http.StatusInternalServerError) } } + case "/api/generate": + w.WriteHeader(http.StatusOK) default: http.NotFound(w, r) } @@ -1727,13 +1895,13 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { cmd.SetContext(t.Context()) opts := &runOptions{ - Model: "test-cloud-model", + Model: tt.model, ShowConnect: false, } err := loadOrUnloadModel(cmd, opts) - if strings.HasPrefix(tt.remoteHost, "https://ollama.com") { + if strings.HasPrefix(tt.remoteHost, "https://ollama.com") || modelref.HasExplicitCloudSource(tt.model) { if !whoamiCalled { t.Error("expected whoami to be called for ollama.com cloud model") } diff --git a/cmd/config/claude.go b/cmd/config/claude.go index b7ed02af1..9018d193d 100644 --- a/cmd/config/claude.go +++ b/cmd/config/claude.go @@ -107,15 +107,12 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli } if !force && aliases["primary"] != "" { - client, _ := api.ClientFromEnvironment() - if isCloudModel(ctx, client, aliases["primary"]) { - if isCloudModel(ctx, client, aliases["fast"]) { - return aliases, false, nil - } - } else { - delete(aliases, "fast") + if isCloudModelName(aliases["primary"]) { + aliases["fast"] = aliases["primary"] return aliases, false, nil } + delete(aliases, "fast") + return aliases, false, nil } items, existingModels, cloudModels, client, err := listModels(ctx) @@ -139,10 +136,8 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli aliases["primary"] = primary } - if isCloudModel(ctx, client, aliases["primary"]) { - if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) { - aliases["fast"] = aliases["primary"] - } + if isCloudModelName(aliases["primary"]) { + aliases["fast"] = aliases["primary"] } else { delete(aliases, "fast") } diff --git a/cmd/config/config.go b/cmd/config/config.go index 8eb41f4ae..82bfb493d 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -233,6 +233,9 @@ func ModelExists(ctx context.Context, name string) bool { if name == "" { return false } + if isCloudModelName(name) { + return true + } client, err := api.ClientFromEnvironment() if err != nil { return false diff --git a/cmd/config/droid.go b/cmd/config/droid.go index d1a9f54dc..ed88c0177 100644 --- a/cmd/config/droid.go +++ b/cmd/config/droid.go @@ -10,7 +10,6 @@ import ( "path/filepath" "slices" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" ) @@ -125,13 +124,12 @@ func (d *Droid) Edit(models []string) error { } // Build new Ollama model entries with sequential indices (0, 1, 2, ...) - client, _ := api.ClientFromEnvironment() var newModels []any var defaultModelID string for i, model := range models { maxOutput := 64000 - if isCloudModel(context.Background(), client, model) { + if isCloudModelName(model) { if l, ok := lookupCloudModelLimit(model); ok { maxOutput = l.Output } diff --git a/cmd/config/droid_test.go b/cmd/config/droid_test.go index b8306731e..ac26aef58 100644 --- a/cmd/config/droid_test.go +++ b/cmd/config/droid_test.go @@ -1276,35 +1276,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) { func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) { // Verify that every cloud model in cloudModelLimits has a valid output - // value that would be used for maxOutputTokens when isCloudModel returns true. - // Cloud suffix normalization must also work since integrations may see either - // :cloud or -cloud model names. + // value that would be used for maxOutputTokens when the selected model uses + // the explicit :cloud source tag. for name, expected := range cloudModelLimits { t.Run(name, func(t *testing.T) { - l, ok := lookupCloudModelLimit(name) - if !ok { - t.Fatalf("lookupCloudModelLimit(%q) returned false", name) - } - if l.Output != expected.Output { - t.Errorf("output = %d, want %d", l.Output, expected.Output) - } - // Also verify :cloud suffix lookup cloudName := name + ":cloud" - l2, ok := lookupCloudModelLimit(cloudName) + l, ok := lookupCloudModelLimit(cloudName) if !ok { t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName) } - if l2.Output != expected.Output { - t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output) - } - // Also verify -cloud suffix lookup - dashCloudName := name + "-cloud" - l3, ok := lookupCloudModelLimit(dashCloudName) - if !ok { - t.Fatalf("lookupCloudModelLimit(%q) returned false", dashCloudName) - } - if l3.Output != expected.Output { - t.Errorf("-cloud output = %d, want %d", l3.Output, expected.Output) + if l.Output != expected.Output { + t.Errorf("output = %d, want %d", l.Output, expected.Output) } }) } diff --git a/cmd/config/integrations.go b/cmd/config/integrations.go index e524c14ee..b9502e9b3 100644 --- a/cmd/config/integrations.go +++ b/cmd/config/integrations.go @@ -14,6 +14,7 @@ import ( "github.com/ollama/ollama/api" internalcloud "github.com/ollama/ollama/internal/cloud" + "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/progress" "github.com/spf13/cobra" ) @@ -326,12 +327,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri // If the selected model isn't installed, pull it first if !existingModels[selected] { - if cloudModels[selected] { - // Cloud models only pull a small manifest; no confirmation needed - if err := pullModel(ctx, client, selected); err != nil { - return "", fmt.Errorf("failed to pull %s: %w", selected, err) - } - } else { + if !isCloudModelName(selected) { msg := fmt.Sprintf("Download %s?", selected) if ok, err := confirmPrompt(msg); err != nil { return "", err @@ -526,7 +522,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single var toPull []string for _, m := range selected { - if !existingModels[m] { + if !existingModels[m] && !isCloudModelName(m) { toPull = append(toPull, m) } } @@ -552,12 +548,28 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single return selected, nil } +// TODO(parthsareen): consolidate pull logic from call sites func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error { - if existingModels[model] { + if isCloudModelName(model) || existingModels[model] { return nil } - msg := fmt.Sprintf("Download %s?", model) - if ok, err := confirmPrompt(msg); err != nil { + return confirmAndPull(ctx, client, model) +} + +// TODO(parthsareen): pull this out to tui package +// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found. +func ShowOrPull(ctx context.Context, client *api.Client, model string) error { + if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil { + return nil + } + if isCloudModelName(model) { + return nil + } + return confirmAndPull(ctx, client, model) +} + +func confirmAndPull(ctx context.Context, client *api.Client, model string) error { + if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil { return err } else if !ok { return errCancelled @@ -569,26 +581,6 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st return nil } -// TODO(parthsareen): pull this out to tui package -// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found. -func ShowOrPull(ctx context.Context, client *api.Client, model string) error { - if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil { - return nil - } - // Cloud models only pull a small manifest; skip the download confirmation - // TODO(parthsareen): consolidate with cloud config changes - if strings.HasSuffix(model, "cloud") { - return pullModel(ctx, client, model) - } - if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil { - return err - } else if !ok { - return errCancelled - } - fmt.Fprintf(os.Stderr, "\n") - return pullModel(ctx, client, model) -} - func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) { client, err := api.ClientFromEnvironment() if err != nil { @@ -733,10 +725,8 @@ func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, na } aliases["primary"] = model - if isCloudModel(ctx, client, model) { - if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) { - aliases["fast"] = model - } + if isCloudModelName(model) { + aliases["fast"] = model } else { delete(aliases, "fast") } @@ -1022,7 +1012,7 @@ Examples: existingAliases = aliases // Ensure cloud models are authenticated - if isCloudModel(cmd.Context(), client, model) { + if isCloudModelName(model) { if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil { return err } @@ -1211,7 +1201,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) ( // When user has no models, preserve recommended order. notInstalled := make(map[string]bool) for i := range items { - if !existingModels[items[i].Name] { + if !existingModels[items[i].Name] && !cloudModels[items[i].Name] { notInstalled[items[i].Name] = true var parts []string if items[i].Description != "" { @@ -1305,7 +1295,8 @@ func IsCloudModelDisabled(ctx context.Context, name string) bool { } func isCloudModelName(name string) bool { - return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud") + // TODO(drifkin): Replace this wrapper with inlining once things stabilize a bit + return modelref.HasExplicitCloudSource(name) } func filterCloudModels(existing []modelInfo) []modelInfo { diff --git a/cmd/config/integrations_test.go b/cmd/config/integrations_test.go index 914a8f661..2eca19fd9 100644 --- a/cmd/config/integrations_test.go +++ b/cmd/config/integrations_test.go @@ -426,8 +426,14 @@ func TestBuildModelList_NoExistingModels(t *testing.T) { } for _, item := range items { - if !strings.HasSuffix(item.Description, "(not downloaded)") { - t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description) + if strings.HasSuffix(item.Name, ":cloud") { + if strings.HasSuffix(item.Description, "(not downloaded)") { + t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description) + } + } else { + if !strings.HasSuffix(item.Description, "(not downloaded)") { + t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description) + } } } } @@ -492,10 +498,14 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) { if strings.HasSuffix(item.Description, "(not downloaded)") { t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description) } - case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3:8b": + case "qwen3:8b": if !strings.HasSuffix(item.Description, "(not downloaded)") { t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description) } + case "minimax-m2.5:cloud", "kimi-k2.5:cloud": + if strings.HasSuffix(item.Description, "(not downloaded)") { + t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description) + } } } } @@ -536,7 +546,13 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes } for _, item := range items { - if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) { + isCloud := strings.HasSuffix(item.Name, ":cloud") + isInstalled := slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) + if isInstalled || isCloud { + if strings.HasSuffix(item.Description, "(not downloaded)") { + t.Errorf("installed or cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description) + } + } else { if !strings.HasSuffix(item.Description, "(not downloaded)") { t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description) } @@ -1000,8 +1016,8 @@ func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) { } } -func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) { - // Confirm prompt should NOT be called for cloud models +func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) { + // Confirm prompt should NOT be called for explicit cloud models oldHook := DefaultConfirmPrompt DefaultConfirmPrompt = func(prompt string) (bool, error) { t.Error("confirm prompt should not be called for cloud models") @@ -1032,8 +1048,115 @@ func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) { if err != nil { t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err) } - if !pullCalled { - t.Error("expected pull to be called for cloud model without confirmation") + if pullCalled { + t.Error("expected pull not to be called for cloud model") + } +} + +func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) { + // Confirm prompt should NOT be called for explicit cloud models + oldHook := DefaultConfirmPrompt + DefaultConfirmPrompt = func(prompt string) (bool, error) { + t.Error("confirm prompt should not be called for cloud models") + return false, nil + } + defer func() { DefaultConfirmPrompt = oldHook }() + + var pullCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"model not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"status":"success"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := ShowOrPull(context.Background(), client, "gpt-oss:20b-cloud") + if err != nil { + t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err) + } + if pullCalled { + t.Error("expected pull not to be called for cloud model") + } +} + +func TestPullIfNeeded_CloudModel_DoesNotPull(t *testing.T) { + oldHook := DefaultConfirmPrompt + DefaultConfirmPrompt = func(prompt string) (bool, error) { + t.Error("confirm prompt should not be called for cloud models") + return false, nil + } + defer func() { DefaultConfirmPrompt = oldHook }() + + err := pullIfNeeded(context.Background(), nil, map[string]bool{}, "glm-5:cloud") + if err != nil { + t.Fatalf("expected no error for cloud model, got %v", err) + } + + err = pullIfNeeded(context.Background(), nil, map[string]bool{}, "gpt-oss:20b-cloud") + if err != nil { + t.Fatalf("expected no error for cloud model with legacy suffix, got %v", err) + } +} + +func TestSelectModelsWithSelectors_CloudSelection_DoesNotPull(t *testing.T) { + var pullCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + case "/api/tags": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"models":[]}`) + case "/api/me": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"name":"test-user"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"status":"success"}`) + default: + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + single := func(title string, items []ModelItem, current string) (string, error) { + for _, item := range items { + if item.Name == "glm-5:cloud" { + return item.Name, nil + } + } + t.Fatalf("expected glm-5:cloud in selector items, got %v", items) + return "", nil + } + + multi := func(title string, items []ModelItem, preChecked []string) ([]string, error) { + return nil, fmt.Errorf("multi selector should not be called") + } + + selected, err := selectModelsWithSelectors(context.Background(), "codex", "", single, multi) + if err != nil { + t.Fatalf("selectModelsWithSelectors returned error: %v", err) + } + if !slices.Equal(selected, []string{"glm-5:cloud"}) { + t.Fatalf("unexpected selected models: %v", selected) + } + if pullCalled { + t.Fatal("expected cloud selection to skip pull") } } diff --git a/cmd/config/opencode.go b/cmd/config/opencode.go index 774c3f066..52a1426b9 100644 --- a/cmd/config/opencode.go +++ b/cmd/config/opencode.go @@ -12,8 +12,8 @@ import ( "slices" "strings" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/internal/modelref" ) // OpenCode implements Runner and Editor for OpenCode integration @@ -26,14 +26,13 @@ type cloudModelLimit struct { } // lookupCloudModelLimit returns the token limits for a cloud model. -// It normalizes common cloud suffixes before checking the shared limit map. +// It normalizes explicit cloud source suffixes before checking the shared limit map. func lookupCloudModelLimit(name string) (cloudModelLimit, bool) { - // TODO(parthsareen): migrate to using cloud check instead. - for _, suffix := range []string{"-cloud", ":cloud"} { - name = strings.TrimSuffix(name, suffix) - } - if l, ok := cloudModelLimits[name]; ok { - return l, true + base, stripped := modelref.StripCloudSourceTag(name) + if stripped { + if l, ok := cloudModelLimits[base]; ok { + return l, true + } } return cloudModelLimit{}, false } @@ -150,8 +149,6 @@ func (o *OpenCode) Edit(modelList []string) error { } } - client, _ := api.ClientFromEnvironment() - for _, model := range modelList { if existing, ok := models[model].(map[string]any); ok { // migrate existing models without _launch marker @@ -161,7 +158,7 @@ func (o *OpenCode) Edit(modelList []string) error { existing["name"] = strings.TrimSuffix(name, " [Ollama]") } } - if isCloudModel(context.Background(), client, model) { + if isCloudModelName(model) { if l, ok := lookupCloudModelLimit(model); ok { existing["limit"] = map[string]any{ "context": l.Context, @@ -175,7 +172,7 @@ func (o *OpenCode) Edit(modelList []string) error { "name": model, "_launch": true, } - if isCloudModel(context.Background(), client, model) { + if isCloudModelName(model) { if l, ok := lookupCloudModelLimit(model); ok { entry["limit"] = map[string]any{ "context": l.Context, diff --git a/cmd/config/opencode_test.go b/cmd/config/opencode_test.go index 552435453..9f7744892 100644 --- a/cmd/config/opencode_test.go +++ b/cmd/config/opencode_test.go @@ -714,16 +714,17 @@ func TestLookupCloudModelLimit(t *testing.T) { wantContext int wantOutput int }{ - {"glm-4.7", true, 202_752, 131_072}, + {"glm-4.7", false, 0, 0}, {"glm-4.7:cloud", true, 202_752, 131_072}, {"glm-5:cloud", true, 202_752, 131_072}, {"gpt-oss:120b-cloud", true, 131_072, 131_072}, {"gpt-oss:20b-cloud", true, 131_072, 131_072}, - {"kimi-k2.5", true, 262_144, 262_144}, + {"kimi-k2.5", false, 0, 0}, {"kimi-k2.5:cloud", true, 262_144, 262_144}, - {"deepseek-v3.2", true, 163_840, 65_536}, + {"deepseek-v3.2", false, 0, 0}, {"deepseek-v3.2:cloud", true, 163_840, 65_536}, - {"qwen3-coder:480b", true, 262_144, 65_536}, + {"qwen3-coder:480b", false, 0, 0}, + {"qwen3-coder:480b:cloud", true, 262_144, 65_536}, {"qwen3-coder-next:cloud", true, 262_144, 32_768}, {"llama3.2", false, 0, 0}, {"unknown-model:cloud", false, 0, 0}, diff --git a/cmd/tui/tui.go b/cmd/tui/tui.go index 389c875a9..5803d98fa 100644 --- a/cmd/tui/tui.go +++ b/cmd/tui/tui.go @@ -11,6 +11,7 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/ollama/ollama/api" "github.com/ollama/ollama/cmd/config" + "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/version" ) @@ -147,7 +148,13 @@ type signInCheckMsg struct { type clearStatusMsg struct{} func (m *model) modelExists(name string) bool { - if m.availableModels == nil || name == "" { + if name == "" { + return false + } + if modelref.HasExplicitCloudSource(name) { + return true + } + if m.availableModels == nil { return false } if m.availableModels[name] { @@ -209,7 +216,7 @@ func (m *model) openMultiModelModal(integration string) { } func isCloudModel(name string) bool { - return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud") + return modelref.HasExplicitCloudSource(name) } func cloudStatusDisabled(client *api.Client) bool { diff --git a/internal/modelref/modelref.go b/internal/modelref/modelref.go new file mode 100644 index 000000000..f62757912 --- /dev/null +++ b/internal/modelref/modelref.go @@ -0,0 +1,115 @@ +package modelref + +import ( + "errors" + "fmt" + "strings" +) + +type ModelSource uint8 + +const ( + ModelSourceUnspecified ModelSource = iota + ModelSourceLocal + ModelSourceCloud +) + +var ( + ErrConflictingSourceSuffix = errors.New("use either :local or :cloud, not both") + ErrModelRequired = errors.New("model is required") +) + +type ParsedRef struct { + Original string + Base string + Source ModelSource +} + +func ParseRef(raw string) (ParsedRef, error) { + var zero ParsedRef + + raw = strings.TrimSpace(raw) + if raw == "" { + return zero, ErrModelRequired + } + + base, source, explicit := parseSourceSuffix(raw) + if explicit { + if _, _, nested := parseSourceSuffix(base); nested { + return zero, fmt.Errorf("%w: %q", ErrConflictingSourceSuffix, raw) + } + } + + return ParsedRef{ + Original: raw, + Base: base, + Source: source, + }, nil +} + +func HasExplicitCloudSource(raw string) bool { + parsedRef, err := ParseRef(raw) + return err == nil && parsedRef.Source == ModelSourceCloud +} + +func HasExplicitLocalSource(raw string) bool { + parsedRef, err := ParseRef(raw) + return err == nil && parsedRef.Source == ModelSourceLocal +} + +func StripCloudSourceTag(raw string) (string, bool) { + parsedRef, err := ParseRef(raw) + if err != nil || parsedRef.Source != ModelSourceCloud { + return strings.TrimSpace(raw), false + } + + return parsedRef.Base, true +} + +func NormalizePullName(raw string) (string, bool, error) { + parsedRef, err := ParseRef(raw) + if err != nil { + return "", false, err + } + + if parsedRef.Source != ModelSourceCloud { + return parsedRef.Base, false, nil + } + + return toLegacyCloudPullName(parsedRef.Base), true, nil +} + +func toLegacyCloudPullName(base string) string { + if hasExplicitTag(base) { + return base + "-cloud" + } + + return base + ":cloud" +} + +func hasExplicitTag(name string) bool { + lastSlash := strings.LastIndex(name, "/") + lastColon := strings.LastIndex(name, ":") + return lastColon > lastSlash +} + +func parseSourceSuffix(raw string) (string, ModelSource, bool) { + idx := strings.LastIndex(raw, ":") + if idx >= 0 { + suffixRaw := strings.TrimSpace(raw[idx+1:]) + suffix := strings.ToLower(suffixRaw) + + switch suffix { + case "cloud": + return raw[:idx], ModelSourceCloud, true + case "local": + return raw[:idx], ModelSourceLocal, true + } + + if !strings.Contains(suffixRaw, "/") && strings.HasSuffix(suffix, "-cloud") { + return raw[:idx+1] + suffixRaw[:len(suffixRaw)-len("-cloud")], ModelSourceCloud, true + } + } + + return raw, ModelSourceUnspecified, false +} diff --git a/internal/modelref/modelref_test.go b/internal/modelref/modelref_test.go new file mode 100644 index 000000000..7d1c1bee5 --- /dev/null +++ b/internal/modelref/modelref_test.go @@ -0,0 +1,268 @@ +package modelref + +import ( + "errors" + "testing" +) + +func TestParseRef(t *testing.T) { + tests := []struct { + name string + input string + wantBase string + wantSource ModelSource + wantErr error + wantCloud bool + wantLocal bool + wantStripped string + wantStripOK bool + }{ + { + name: "cloud suffix", + input: "gpt-oss:20b:cloud", + wantBase: "gpt-oss:20b", + wantSource: ModelSourceCloud, + wantCloud: true, + wantStripped: "gpt-oss:20b", + wantStripOK: true, + }, + { + name: "legacy cloud suffix", + input: "gpt-oss:20b-cloud", + wantBase: "gpt-oss:20b", + wantSource: ModelSourceCloud, + wantCloud: true, + wantStripped: "gpt-oss:20b", + wantStripOK: true, + }, + { + name: "local suffix", + input: "qwen3:8b:local", + wantBase: "qwen3:8b", + wantSource: ModelSourceLocal, + wantLocal: true, + wantStripped: "qwen3:8b:local", + }, + { + name: "no source suffix", + input: "llama3.2", + wantBase: "llama3.2", + wantSource: ModelSourceUnspecified, + wantStripped: "llama3.2", + }, + { + name: "bare cloud name is not explicit cloud", + input: "my-cloud-model", + wantBase: "my-cloud-model", + wantSource: ModelSourceUnspecified, + wantStripped: "my-cloud-model", + }, + { + name: "slash in suffix blocks legacy cloud parsing", + input: "foo:bar-cloud/baz", + wantBase: "foo:bar-cloud/baz", + wantSource: ModelSourceUnspecified, + wantStripped: "foo:bar-cloud/baz", + }, + { + name: "conflicting source suffixes", + input: "foo:cloud:local", + wantErr: ErrConflictingSourceSuffix, + wantSource: ModelSourceUnspecified, + }, + { + name: "empty input", + input: " ", + wantErr: ErrModelRequired, + wantSource: ModelSourceUnspecified, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseRef(tt.input) + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Fatalf("ParseRef(%q) error = %v, want %v", tt.input, err, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("ParseRef(%q) returned error: %v", tt.input, err) + } + + if got.Base != tt.wantBase { + t.Fatalf("base = %q, want %q", got.Base, tt.wantBase) + } + + if got.Source != tt.wantSource { + t.Fatalf("source = %v, want %v", got.Source, tt.wantSource) + } + + if HasExplicitCloudSource(tt.input) != tt.wantCloud { + t.Fatalf("HasExplicitCloudSource(%q) = %v, want %v", tt.input, HasExplicitCloudSource(tt.input), tt.wantCloud) + } + + if HasExplicitLocalSource(tt.input) != tt.wantLocal { + t.Fatalf("HasExplicitLocalSource(%q) = %v, want %v", tt.input, HasExplicitLocalSource(tt.input), tt.wantLocal) + } + + stripped, ok := StripCloudSourceTag(tt.input) + if ok != tt.wantStripOK { + t.Fatalf("StripCloudSourceTag(%q) ok = %v, want %v", tt.input, ok, tt.wantStripOK) + } + if stripped != tt.wantStripped { + t.Fatalf("StripCloudSourceTag(%q) base = %q, want %q", tt.input, stripped, tt.wantStripped) + } + }) + } +} + +func TestNormalizePullName(t *testing.T) { + tests := []struct { + name string + input string + wantName string + wantCloud bool + wantErr error + }{ + { + name: "explicit local strips source", + input: "gpt-oss:20b:local", + wantName: "gpt-oss:20b", + }, + { + name: "explicit cloud with size maps to legacy dash cloud tag", + input: "gpt-oss:20b:cloud", + wantName: "gpt-oss:20b-cloud", + wantCloud: true, + }, + { + name: "legacy cloud with size remains stable", + input: "gpt-oss:20b-cloud", + wantName: "gpt-oss:20b-cloud", + wantCloud: true, + }, + { + name: "explicit cloud without tag maps to cloud tag", + input: "qwen3:cloud", + wantName: "qwen3:cloud", + wantCloud: true, + }, + { + name: "host port without tag keeps host port and appends cloud tag", + input: "localhost:11434/library/foo:cloud", + wantName: "localhost:11434/library/foo:cloud", + wantCloud: true, + }, + { + name: "conflicting source suffixes fail", + input: "foo:cloud:local", + wantErr: ErrConflictingSourceSuffix, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotName, gotCloud, err := NormalizePullName(tt.input) + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Fatalf("NormalizePullName(%q) error = %v, want %v", tt.input, err, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("NormalizePullName(%q) returned error: %v", tt.input, err) + } + + if gotName != tt.wantName { + t.Fatalf("normalized name = %q, want %q", gotName, tt.wantName) + } + if gotCloud != tt.wantCloud { + t.Fatalf("cloud = %v, want %v", gotCloud, tt.wantCloud) + } + }) + } +} + +func TestParseSourceSuffix(t *testing.T) { + tests := []struct { + name string + input string + wantBase string + wantSource ModelSource + wantExplicit bool + }{ + { + name: "explicit cloud suffix", + input: "gpt-oss:20b:cloud", + wantBase: "gpt-oss:20b", + wantSource: ModelSourceCloud, + wantExplicit: true, + }, + { + name: "explicit local suffix", + input: "qwen3:8b:local", + wantBase: "qwen3:8b", + wantSource: ModelSourceLocal, + wantExplicit: true, + }, + { + name: "legacy cloud suffix on tag", + input: "gpt-oss:20b-cloud", + wantBase: "gpt-oss:20b", + wantSource: ModelSourceCloud, + wantExplicit: true, + }, + { + name: "legacy cloud suffix does not match model segment", + input: "my-cloud-model", + wantBase: "my-cloud-model", + wantSource: ModelSourceUnspecified, + wantExplicit: false, + }, + { + name: "legacy cloud suffix blocked when suffix includes slash", + input: "foo:bar-cloud/baz", + wantBase: "foo:bar-cloud/baz", + wantSource: ModelSourceUnspecified, + wantExplicit: false, + }, + { + name: "unknown suffix is not explicit source", + input: "gpt-oss:clod", + wantBase: "gpt-oss:clod", + wantSource: ModelSourceUnspecified, + wantExplicit: false, + }, + { + name: "uppercase suffix is accepted", + input: "gpt-oss:20b:CLOUD", + wantBase: "gpt-oss:20b", + wantSource: ModelSourceCloud, + wantExplicit: true, + }, + { + name: "no suffix", + input: "llama3.2", + wantBase: "llama3.2", + wantSource: ModelSourceUnspecified, + wantExplicit: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotBase, gotSource, gotExplicit := parseSourceSuffix(tt.input) + if gotBase != tt.wantBase { + t.Fatalf("base = %q, want %q", gotBase, tt.wantBase) + } + if gotSource != tt.wantSource { + t.Fatalf("source = %v, want %v", gotSource, tt.wantSource) + } + if gotExplicit != tt.wantExplicit { + t.Fatalf("explicit = %v, want %v", gotExplicit, tt.wantExplicit) + } + }) + } +} diff --git a/middleware/anthropic.go b/middleware/anthropic.go index 85c95e60c..d65edd53f 100644 --- a/middleware/anthropic.go +++ b/middleware/anthropic.go @@ -17,6 +17,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" internalcloud "github.com/ollama/ollama/internal/cloud" + "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/logutil" ) @@ -919,7 +920,7 @@ func hasWebSearchTool(tools []anthropic.Tool) bool { } func isCloudModelName(name string) bool { - return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud") + return modelref.HasExplicitCloudSource(name) } // extractQueryFromToolCall extracts the search query from a web_search tool call diff --git a/server/cloud_proxy.go b/server/cloud_proxy.go new file mode 100644 index 000000000..bf91d7694 --- /dev/null +++ b/server/cloud_proxy.go @@ -0,0 +1,460 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + + "github.com/ollama/ollama/auth" + "github.com/ollama/ollama/envconfig" + internalcloud "github.com/ollama/ollama/internal/cloud" +) + +const ( + defaultCloudProxyBaseURL = "https://ollama.com:443" + defaultCloudProxySigningHost = "ollama.com" + cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL" + legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search" +) + +var ( + cloudProxyBaseURL = defaultCloudProxyBaseURL + cloudProxySigningHost = defaultCloudProxySigningHost + cloudProxySignRequest = signCloudProxyRequest + cloudProxySigninURL = signinURL +) + +var hopByHopHeaders = map[string]struct{}{ + "connection": {}, + "content-length": {}, + "proxy-connection": {}, + "keep-alive": {}, + "proxy-authenticate": {}, + "proxy-authorization": {}, + "te": {}, + "trailer": {}, + "transfer-encoding": {}, + "upgrade": {}, +} + +func init() { + baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode) + if err != nil { + slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err) + return + } + + cloudProxyBaseURL = baseURL + cloudProxySigningHost = signingHost + + if overridden { + slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode) + } +} + +func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc { + return func(c *gin.Context) { + if c.Request.Method != http.MethodPost { + c.Next() + return + } + + // TODO(drifkin): Avoid full-body buffering here for model detection. + // A future optimization can parse just enough JSON to read "model" (and + // optionally short-circuit cloud-disabled explicit-cloud requests) while + // preserving raw passthrough semantics. + body, err := readRequestBody(c.Request) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + c.Abort() + return + } + + model, ok := extractModelField(body) + if !ok { + c.Next() + return + } + + modelRef, err := parseAndValidateModelRef(model) + if err != nil || modelRef.Source != modelSourceCloud { + c.Next() + return + } + + normalizedBody, err := replaceJSONModelField(body, modelRef.Base) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + c.Abort() + return + } + + // TEMP(drifkin): keep Anthropic web search requests on the local middleware + // path so WebSearchAnthropicWriter can orchestrate follow-up calls. + if c.Request.URL.Path == "/v1/messages" { + if hasAnthropicWebSearchTool(body) { + c.Set(legacyCloudAnthropicKey, true) + c.Next() + return + } + } + + proxyCloudRequest(c, normalizedBody, disabledOperation) + c.Abort() + } +} + +func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc { + return func(c *gin.Context) { + modelName := strings.TrimSpace(c.Param("model")) + if modelName == "" { + c.Next() + return + } + + modelRef, err := parseAndValidateModelRef(modelName) + if err != nil || modelRef.Source != modelSourceCloud { + c.Next() + return + } + + proxyPath := "/v1/models/" + modelRef.Base + proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation) + c.Abort() + } +} + +func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) { + // TEMP(drifkin): we currently split out this `WithPath` method because we are + // mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we + // stop doing this, we can inline this method. + proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation) +} + +func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) { + body, err := json.Marshal(payload) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + proxyCloudRequestWithPath(c, body, path, disabledOperation) +} + +func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) { + proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation) +} + +func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) { + if disabled, _ := internalcloud.Status(); disabled { + c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)}) + return + } + + baseURL, err := url.Parse(cloudProxyBaseURL) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + targetURL := baseURL.ResolveReference(&url.URL{ + Path: path, + RawQuery: c.Request.URL.RawQuery, + }) + + outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body)) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + copyProxyRequestHeaders(outReq.Header, c.Request.Header) + if outReq.Header.Get("Content-Type") == "" && len(body) > 0 { + outReq.Header.Set("Content-Type", "application/json") + } + + if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil { + slog.Warn("cloud proxy signing failed", "error", err) + writeCloudUnauthorized(c) + return + } + + // TODO(drifkin): Add phase-specific proxy timeouts. + // Connect/TLS/TTFB should have bounded timeouts, but once streaming starts + // we should not enforce a short total timeout for long-lived responses. + resp, err := http.DefaultClient.Do(outReq) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } + defer resp.Body.Close() + + copyProxyResponseHeaders(c.Writer.Header(), resp.Header) + c.Status(resp.StatusCode) + + if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil { + c.Error(err) //nolint:errcheck + } +} + +func replaceJSONModelField(body []byte, model string) ([]byte, error) { + if len(body) == 0 { + return body, nil + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + return nil, err + } + + modelJSON, err := json.Marshal(model) + if err != nil { + return nil, err + } + payload["model"] = modelJSON + + return json.Marshal(payload) +} + +func readRequestBody(r *http.Request) ([]byte, error) { + if r.Body == nil { + return nil, nil + } + + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + + r.Body = io.NopCloser(bytes.NewReader(body)) + return body, nil +} + +func extractModelField(body []byte) (string, bool) { + if len(body) == 0 { + return "", false + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + return "", false + } + + raw, ok := payload["model"] + if !ok { + return "", false + } + + var model string + if err := json.Unmarshal(raw, &model); err != nil { + return "", false + } + + model = strings.TrimSpace(model) + return model, model != "" +} + +func hasAnthropicWebSearchTool(body []byte) bool { + if len(body) == 0 { + return false + } + + var payload struct { + Tools []struct { + Type string `json:"type"` + } `json:"tools"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return false + } + + for _, tool := range payload.Tools { + if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") { + return true + } + } + + return false +} + +func writeCloudUnauthorized(c *gin.Context) { + signinURL, err := cloudProxySigninURL() + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + return + } + + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL}) +} + +func signCloudProxyRequest(ctx context.Context, req *http.Request) error { + if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) { + return nil + } + + ts := strconv.FormatInt(time.Now().Unix(), 10) + challenge := buildCloudSignatureChallenge(req, ts) + signature, err := auth.Sign(ctx, []byte(challenge)) + if err != nil { + return err + } + + req.Header.Set("Authorization", signature) + return nil +} + +func buildCloudSignatureChallenge(req *http.Request, ts string) string { + query := req.URL.Query() + query.Set("ts", ts) + req.URL.RawQuery = query.Encode() + + return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI()) +} + +func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) { + baseURL = defaultCloudProxyBaseURL + signingHost = defaultCloudProxySigningHost + + rawOverride = strings.TrimSpace(rawOverride) + if rawOverride == "" { + return baseURL, signingHost, false, nil + } + + u, err := url.Parse(rawOverride) + if err != nil { + return "", "", false, fmt.Errorf("invalid URL: %w", err) + } + if u.Scheme == "" || u.Host == "" { + return "", "", false, fmt.Errorf("invalid URL: scheme and host are required") + } + if u.User != nil { + return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed") + } + if u.Path != "" && u.Path != "/" { + return "", "", false, fmt.Errorf("invalid URL: path is not allowed") + } + if u.RawQuery != "" || u.Fragment != "" { + return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed") + } + + host := u.Hostname() + if host == "" { + return "", "", false, fmt.Errorf("invalid URL: host is required") + } + + loopback := isLoopbackHost(host) + if runMode == gin.ReleaseMode && !loopback { + return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode") + } + if !loopback && !strings.EqualFold(u.Scheme, "https") { + return "", "", false, fmt.Errorf("non-loopback cloud override must use https") + } + + u.Path = "" + u.RawPath = "" + u.RawQuery = "" + u.Fragment = "" + + return u.String(), strings.ToLower(host), true, nil +} + +func isLoopbackHost(host string) bool { + if strings.EqualFold(host, "localhost") { + return true + } + + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() +} + +func copyProxyRequestHeaders(dst, src http.Header) { + connectionTokens := connectionHeaderTokens(src) + for key, values := range src { + if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) { + continue + } + + dst.Del(key) + for _, value := range values { + dst.Add(key, value) + } + } +} + +func copyProxyResponseHeaders(dst, src http.Header) { + connectionTokens := connectionHeaderTokens(src) + for key, values := range src { + if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) { + continue + } + + dst.Del(key) + for _, value := range values { + dst.Add(key, value) + } + } +} + +func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error { + flusher, canFlush := dst.(http.Flusher) + buf := make([]byte, 32*1024) + + for { + n, err := src.Read(buf) + if n > 0 { + if _, writeErr := dst.Write(buf[:n]); writeErr != nil { + return writeErr + } + if canFlush { + // TODO(drifkin): Consider conditional flushing so non-streaming + // responses don't flush every write and can optimize throughput. + flusher.Flush() + } + } + + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} + +func isHopByHopHeader(name string) bool { + _, ok := hopByHopHeaders[strings.ToLower(name)] + return ok +} + +func connectionHeaderTokens(header http.Header) map[string]struct{} { + tokens := map[string]struct{}{} + for _, raw := range header.Values("Connection") { + for _, token := range strings.Split(raw, ",") { + token = strings.TrimSpace(strings.ToLower(token)) + if token == "" { + continue + } + tokens[token] = struct{}{} + } + } + return tokens +} + +func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool { + if len(tokens) == 0 { + return false + } + _, ok := tokens[strings.ToLower(name)] + return ok +} diff --git a/server/cloud_proxy_test.go b/server/cloud_proxy_test.go new file mode 100644 index 000000000..1a7b27956 --- /dev/null +++ b/server/cloud_proxy_test.go @@ -0,0 +1,154 @@ +package server + +import ( + "net/http" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) { + src := http.Header{} + src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop") + src.Add("X-Trace-Hop", "drop-me") + src.Add("X-Alt-Hop", "drop-me-too") + src.Add("Keep-Alive", "timeout=5") + src.Add("X-End-To-End", "keep-me") + + dst := http.Header{} + copyProxyRequestHeaders(dst, src) + + if got := dst.Get("Connection"); got != "" { + t.Fatalf("expected Connection to be stripped, got %q", got) + } + if got := dst.Get("Keep-Alive"); got != "" { + t.Fatalf("expected Keep-Alive to be stripped, got %q", got) + } + if got := dst.Get("X-Trace-Hop"); got != "" { + t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got) + } + if got := dst.Get("X-Alt-Hop"); got != "" { + t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got) + } + if got := dst.Get("X-End-To-End"); got != "keep-me" { + t.Fatalf("expected X-End-To-End to be forwarded, got %q", got) + } +} + +func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) { + src := http.Header{} + src.Add("Connection", "X-Upstream-Hop") + src.Add("X-Upstream-Hop", "drop-me") + src.Add("Content-Type", "application/json") + src.Add("X-Server-Trace", "keep-me") + + dst := http.Header{} + copyProxyResponseHeaders(dst, src) + + if got := dst.Get("Connection"); got != "" { + t.Fatalf("expected Connection to be stripped, got %q", got) + } + if got := dst.Get("X-Upstream-Hop"); got != "" { + t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got) + } + if got := dst.Get("Content-Type"); got != "application/json" { + t.Fatalf("expected Content-Type to be forwarded, got %q", got) + } + if got := dst.Get("X-Server-Trace"); got != "keep-me" { + t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got) + } +} + +func TestResolveCloudProxyBaseURL_Default(t *testing.T) { + baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if overridden { + t.Fatal("expected override=false for empty input") + } + if baseURL != defaultCloudProxyBaseURL { + t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL) + } + if signingHost != defaultCloudProxySigningHost { + t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost) + } +} + +func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) { + baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !overridden { + t.Fatal("expected override=true") + } + if baseURL != "http://localhost:8080" { + t.Fatalf("unexpected base URL: %q", baseURL) + } + if signingHost != "localhost" { + t.Fatalf("unexpected signing host: %q", signingHost) + } +} + +func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) { + _, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode) + if err == nil { + t.Fatal("expected error for non-loopback override in release mode") + } +} + +func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) { + baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !overridden { + t.Fatal("expected override=true") + } + if baseURL != "https://example.com:8443" { + t.Fatalf("unexpected base URL: %q", baseURL) + } + if signingHost != "example.com" { + t.Fatalf("unexpected signing host: %q", signingHost) + } +} + +func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) { + _, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode) + if err == nil { + t.Fatal("expected error for non-loopback http override in dev mode") + } +} + +func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + got := buildCloudSignatureChallenge(req, "123") + want := "POST,/v1/messages?beta=true&foo=bar&ts=123" + if got != want { + t.Fatalf("challenge mismatch: got %q want %q", got, want) + } + if req.URL.RawQuery != "beta=true&foo=bar&ts=123" { + t.Fatalf("unexpected signed query: %q", req.URL.RawQuery) + } +} + +func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + got := buildCloudSignatureChallenge(req, "123") + want := "POST,/v1/messages?beta=true&ts=123" + if got != want { + t.Fatalf("challenge mismatch: got %q want %q", got, want) + } + if req.URL.RawQuery != "beta=true&ts=123" { + t.Fatalf("unexpected signed query: %q", req.URL.RawQuery) + } +} diff --git a/server/create.go b/server/create.go index c9ade530e..9797384fd 100644 --- a/server/create.go +++ b/server/create.go @@ -110,19 +110,26 @@ func (s *Server) CreateHandler(c *gin.Context) { if r.From != "" { slog.Debug("create model from model name", "from", r.From) - fromName := model.ParseName(r.From) - if !fromName.IsValid() { + fromRef, err := parseAndValidateModelRef(r.From) + if err != nil { ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest} return } - if r.RemoteHost != "" { - ru, err := remoteURL(r.RemoteHost) + + fromName := fromRef.Name + remoteHost := r.RemoteHost + if fromRef.Source == modelSourceCloud && remoteHost == "" { + remoteHost = cloudProxyBaseURL + } + + if remoteHost != "" { + ru, err := remoteURL(remoteHost) if err != nil { ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest} return } - config.RemoteModel = r.From + config.RemoteModel = fromRef.Base config.RemoteHost = ru remote = true } else { diff --git a/server/model_resolver.go b/server/model_resolver.go new file mode 100644 index 000000000..cbbeffa37 --- /dev/null +++ b/server/model_resolver.go @@ -0,0 +1,81 @@ +package server + +import ( + "github.com/ollama/ollama/internal/modelref" + "github.com/ollama/ollama/types/model" +) + +type modelSource = modelref.ModelSource + +const ( + modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified + modelSourceLocal modelSource = modelref.ModelSourceLocal + modelSourceCloud modelSource = modelref.ModelSourceCloud +) + +var ( + errConflictingModelSource = modelref.ErrConflictingSourceSuffix + errModelRequired = modelref.ErrModelRequired +) + +type parsedModelRef struct { + // Original is the caller-provided model string before source parsing. + // Example: "gpt-oss:20b:cloud". + Original string + // Base is the model string after source suffix normalization. + // Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b". + Base string + // Name is Base parsed as a fully-qualified model.Name with defaults applied. + // Example: "registry.ollama.ai/library/gpt-oss:20b". + Name model.Name + // Source captures explicit source intent from the original input. + // Example: "gpt-oss:20b:cloud" -> modelSourceCloud. + Source modelSource +} + +func parseAndValidateModelRef(raw string) (parsedModelRef, error) { + var zero parsedModelRef + + parsed, err := modelref.ParseRef(raw) + if err != nil { + return zero, err + } + + name := model.ParseName(parsed.Base) + if !name.IsValid() { + return zero, model.Unqualified(name) + } + + return parsedModelRef{ + Original: parsed.Original, + Base: parsed.Base, + Name: name, + Source: parsed.Source, + }, nil +} + +func parseNormalizePullModelRef(raw string) (parsedModelRef, error) { + var zero parsedModelRef + + parsedRef, err := modelref.ParseRef(raw) + if err != nil { + return zero, err + } + + normalizedName, _, err := modelref.NormalizePullName(raw) + if err != nil { + return zero, err + } + + name := model.ParseName(normalizedName) + if !name.IsValid() { + return zero, model.Unqualified(name) + } + + return parsedModelRef{ + Original: parsedRef.Original, + Base: normalizedName, + Name: name, + Source: parsedRef.Source, + }, nil +} diff --git a/server/model_resolver_test.go b/server/model_resolver_test.go new file mode 100644 index 000000000..c0926ec30 --- /dev/null +++ b/server/model_resolver_test.go @@ -0,0 +1,170 @@ +package server + +import ( + "errors" + "strings" + "testing" +) + +func TestParseModelSelector(t *testing.T) { + t.Run("cloud suffix", func(t *testing.T) { + got, err := parseAndValidateModelRef("gpt-oss:20b:cloud") + if err != nil { + t.Fatalf("parseModelSelector returned error: %v", err) + } + + if got.Source != modelSourceCloud { + t.Fatalf("expected source cloud, got %v", got.Source) + } + + if got.Base != "gpt-oss:20b" { + t.Fatalf("expected base gpt-oss:20b, got %q", got.Base) + } + + if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" { + t.Fatalf("unexpected resolved name: %q", got.Name.String()) + } + }) + + t.Run("legacy cloud suffix", func(t *testing.T) { + got, err := parseAndValidateModelRef("gpt-oss:20b-cloud") + if err != nil { + t.Fatalf("parseModelSelector returned error: %v", err) + } + + if got.Source != modelSourceCloud { + t.Fatalf("expected source cloud, got %v", got.Source) + } + + if got.Base != "gpt-oss:20b" { + t.Fatalf("expected base gpt-oss:20b, got %q", got.Base) + } + }) + + t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) { + got, err := parseAndValidateModelRef("my-cloud-model") + if err != nil { + t.Fatalf("parseModelSelector returned error: %v", err) + } + + if got.Source != modelSourceUnspecified { + t.Fatalf("expected source unspecified, got %v", got.Source) + } + + if got.Base != "my-cloud-model" { + t.Fatalf("expected base my-cloud-model, got %q", got.Base) + } + }) + + t.Run("local suffix", func(t *testing.T) { + got, err := parseAndValidateModelRef("qwen3:8b:local") + if err != nil { + t.Fatalf("parseModelSelector returned error: %v", err) + } + + if got.Source != modelSourceLocal { + t.Fatalf("expected source local, got %v", got.Source) + } + + if got.Base != "qwen3:8b" { + t.Fatalf("expected base qwen3:8b, got %q", got.Base) + } + }) + + t.Run("conflicting source suffixes fail", func(t *testing.T) { + _, err := parseAndValidateModelRef("foo:cloud:local") + if !errors.Is(err, errConflictingModelSource) { + t.Fatalf("expected errConflictingModelSource, got %v", err) + } + }) + + t.Run("unspecified source", func(t *testing.T) { + got, err := parseAndValidateModelRef("llama3") + if err != nil { + t.Fatalf("parseModelSelector returned error: %v", err) + } + + if got.Source != modelSourceUnspecified { + t.Fatalf("expected source unspecified, got %v", got.Source) + } + + if got.Name.Tag != "latest" { + t.Fatalf("expected default latest tag, got %q", got.Name.Tag) + } + }) + + t.Run("unknown suffix is treated as tag", func(t *testing.T) { + got, err := parseAndValidateModelRef("gpt-oss:clod") + if err != nil { + t.Fatalf("parseModelSelector returned error: %v", err) + } + + if got.Source != modelSourceUnspecified { + t.Fatalf("expected source unspecified, got %v", got.Source) + } + + if got.Name.Tag != "clod" { + t.Fatalf("expected tag clod, got %q", got.Name.Tag) + } + }) + + t.Run("empty model fails", func(t *testing.T) { + _, err := parseAndValidateModelRef("") + if !errors.Is(err, errModelRequired) { + t.Fatalf("expected errModelRequired, got %v", err) + } + }) + + t.Run("invalid model fails", func(t *testing.T) { + _, err := parseAndValidateModelRef("::cloud") + if err == nil { + t.Fatal("expected error for invalid model") + } + if !strings.Contains(err.Error(), "unqualified") { + t.Fatalf("expected unqualified model error, got %v", err) + } + }) +} + +func TestParsePullModelRef(t *testing.T) { + t.Run("explicit local is normalized", func(t *testing.T) { + got, err := parseNormalizePullModelRef("gpt-oss:20b:local") + if err != nil { + t.Fatalf("parseNormalizePullModelRef returned error: %v", err) + } + + if got.Source != modelSourceLocal { + t.Fatalf("expected source local, got %v", got.Source) + } + + if got.Base != "gpt-oss:20b" { + t.Fatalf("expected base gpt-oss:20b, got %q", got.Base) + } + }) + + t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) { + got, err := parseNormalizePullModelRef("gpt-oss:20b:cloud") + if err != nil { + t.Fatalf("parseNormalizePullModelRef returned error: %v", err) + } + if got.Base != "gpt-oss:20b-cloud" { + t.Fatalf("expected base gpt-oss:20b-cloud, got %q", got.Base) + } + if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" { + t.Fatalf("unexpected resolved name: %q", got.Name.String()) + } + }) + + t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) { + got, err := parseNormalizePullModelRef("qwen3:cloud") + if err != nil { + t.Fatalf("parseNormalizePullModelRef returned error: %v", err) + } + if got.Base != "qwen3:cloud" { + t.Fatalf("expected base qwen3:cloud, got %q", got.Base) + } + if got.Name.String() != "registry.ollama.ai/library/qwen3:cloud" { + t.Fatalf("unexpected resolved name: %q", got.Name.String()) + } + }) +} diff --git a/server/routes.go b/server/routes.go index 271b357f9..cafd2b8fb 100644 --- a/server/routes.go +++ b/server/routes.go @@ -64,6 +64,17 @@ const ( cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable" ) +func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) { + switch { + case errors.Is(err, errConflictingModelSource): + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + case errors.Is(err, model.ErrUnqualifiedName): + c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) + default: + c.JSON(fallbackStatus, gin.H{"error": fallbackMessage}) + } +} + func shouldUseHarmony(model *Model) bool { if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { // heuristic to check whether the template expects to be parsed via harmony: @@ -196,14 +207,22 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - name := model.ParseName(req.Model) - if !name.IsValid() { - // Ideally this is "invalid model name" but we're keeping with - // what the API currently returns until we can change it. - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + modelRef, err := parseAndValidateModelRef(req.Model) + if err != nil { + writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model)) return } + if modelRef.Source == modelSourceCloud { + // TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the + // original body (modulo model name normalization) is sent to cloud. + req.Model = modelRef.Base + proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable) + return + } + + name := modelRef.Name + resolvedName, _, err := s.resolveAlias(name) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -237,6 +256,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + return + } + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { if disabled, _ := internalcloud.Status(); disabled { c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)}) @@ -670,6 +694,18 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } + modelRef, err := parseAndValidateModelRef(req.Model) + if err != nil { + writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model)) + return + } + + if modelRef.Source == modelSourceCloud { + req.Model = modelRef.Base + proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable) + return + } + var input []string switch i := req.Input.(type) { @@ -692,7 +728,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { } } - name, err := getExistingName(model.ParseName(req.Model)) + name, err := getExistingName(modelRef.Name) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) return @@ -839,12 +875,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - name := model.ParseName(req.Model) - if !name.IsValid() { - c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + modelRef, err := parseAndValidateModelRef(req.Model) + if err != nil { + writeModelRefParseError(c, err, http.StatusBadRequest, "model is required") return } + if modelRef.Source == modelSourceCloud { + req.Model = modelRef.Base + proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable) + return + } + + name := modelRef.Name + r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) @@ -886,12 +930,19 @@ func (s *Server) PullHandler(c *gin.Context) { return } - name := model.ParseName(cmp.Or(req.Model, req.Name)) - if !name.IsValid() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) + // TEMP(drifkin): we're temporarily allowing to continue pulling cloud model + // stub-files until we integrate cloud models into `/api/tags` (in which case + // this roundabout way of "adding" cloud models won't be needed anymore). So + // right here normalize any `:cloud` models into the legacy-style suffixes + // `:-cloud` and `:cloud` + modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name)) + if err != nil { + writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg) return } + name := modelRef.Name + name, err = getExistingName(name) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -1018,13 +1069,20 @@ func (s *Server) DeleteHandler(c *gin.Context) { return } - n := model.ParseName(cmp.Or(r.Model, r.Name)) - if !n.IsValid() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) + modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name)) + if err != nil { + switch { + case errors.Is(err, errConflictingModelSource): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + case errors.Is(err, model.ErrUnqualifiedName): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) + default: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + } return } - n, err := getExistingName(n) + n, err := getExistingName(modelRef.Name) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))}) return @@ -1073,6 +1131,20 @@ func (s *Server) ShowHandler(c *gin.Context) { return } + modelRef, err := parseAndValidateModelRef(req.Model) + if err != nil { + writeModelRefParseError(c, err, http.StatusBadRequest, err.Error()) + return + } + + if modelRef.Source == modelSourceCloud { + req.Model = modelRef.Base + proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable) + return + } + + req.Model = modelRef.Base + resp, err := GetModelInfo(req) if err != nil { var statusErr api.StatusError @@ -1089,6 +1161,11 @@ func (s *Server) ShowHandler(c *gin.Context) { return } + if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)}) + return + } + c.JSON(http.StatusOK, resp) } @@ -1625,18 +1702,20 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/embeddings", s.EmbeddingsHandler) // Inference (OpenAI compatibility) - r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler) - r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler) - r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler) + // TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud + // parents on v1 request families while preserving this explicit :cloud passthrough. + r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler) + r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler) + r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler) r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler) - r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler) - r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler) + r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler) + r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler) // OpenAI-compatible image generation endpoints - r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler) - r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler) + r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler) + r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler) // Inference (Anthropic compatibility) - r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler) + r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler) if rc != nil { // wrap old with new @@ -1995,12 +2074,24 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - name := model.ParseName(req.Model) - if !name.IsValid() { - c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + modelRef, err := parseAndValidateModelRef(req.Model) + if err != nil { + writeModelRefParseError(c, err, http.StatusBadRequest, "model is required") return } + if modelRef.Source == modelSourceCloud { + req.Model = modelRef.Base + if c.GetBool(legacyCloudAnthropicKey) { + proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable) + return + } + proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable) + return + } + + name := modelRef.Name + resolvedName, _, err := s.resolveAlias(name) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -2032,6 +2123,11 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + return + } + // expire the runner if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { s.sched.expireRunner(m) diff --git a/server/routes_cloud_test.go b/server/routes_cloud_test.go index b0ee126ea..d6311582c 100644 --- a/server/routes_cloud_test.go +++ b/server/routes_cloud_test.go @@ -1,13 +1,22 @@ package server import ( + "bufio" + "bytes" + "context" "encoding/json" + "errors" + "io" "net/http" + "net/http/httptest" + "strings" "testing" + "time" "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" internalcloud "github.com/ollama/ollama/internal/cloud" + "github.com/ollama/ollama/middleware" ) func TestStatusHandler(t *testing.T) { @@ -92,3 +101,982 @@ func TestCloudDisabledBlocksRemoteOperations(t *testing.T) { } }) } + +func TestDeleteHandlerNormalizesExplicitSourceSuffixes(t *testing.T) { + gin.SetMode(gin.TestMode) + setTestHome(t, t.TempDir()) + + s := Server{} + + tests := []string{ + "gpt-oss:20b:local", + "gpt-oss:20b:cloud", + "qwen3:cloud", + } + + for _, modelName := range tests { + t.Run(modelName, func(t *testing.T) { + w := createRequest(t, s.DeleteHandler, api.DeleteRequest{ + Model: modelName, + }) + if w.Code != http.StatusNotFound { + t.Fatalf("expected status 404, got %d (%s)", w.Code, w.Body.String()) + } + + var resp map[string]string + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + want := "model '" + modelName + "' not found" + if resp["error"] != want { + t.Fatalf("unexpected error: got %q, want %q", resp["error"], want) + } + }) + } +} + +func TestExplicitCloudPassthroughAPIAndV1(t *testing.T) { + gin.SetMode(gin.TestMode) + setTestHome(t, t.TempDir()) + + type upstreamCapture struct { + path string + body string + header http.Header + } + + newUpstream := func(t *testing.T, responseBody string) (*httptest.Server, *upstreamCapture) { + t.Helper() + capture := &upstreamCapture{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + payload, _ := io.ReadAll(r.Body) + capture.path = r.URL.Path + capture.body = string(payload) + capture.header = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(responseBody)) + })) + + return srv, capture + } + + t.Run("api generate", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"ok":"api"}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"kimi-k2.5:cloud","prompt":"hello","stream":false}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/generate", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Test-Header", "api-header") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/api/generate" { + t.Fatalf("expected upstream path /api/generate, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { + t.Fatalf("expected normalized model in upstream body, got %q", capture.body) + } + + if got := capture.header.Get("X-Test-Header"); got != "api-header" { + t.Fatalf("expected forwarded X-Test-Header=api-header, got %q", got) + } + }) + + t.Run("api chat", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"message":{"role":"assistant","content":"ok"},"done":true}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"kimi-k2.5:cloud","messages":[{"role":"user","content":"hello"}],"stream":false}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/chat", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/api/chat" { + t.Fatalf("expected upstream path /api/chat, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { + t.Fatalf("expected normalized model in upstream body, got %q", capture.body) + } + }) + + t.Run("api embed", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"model":"kimi-k2.5:cloud","embeddings":[[0.1,0.2]]}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"kimi-k2.5:cloud","input":"hello"}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/embed", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/api/embed" { + t.Fatalf("expected upstream path /api/embed, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { + t.Fatalf("expected normalized model in upstream body, got %q", capture.body) + } + }) + + t.Run("api embeddings", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"embedding":[0.1,0.2]}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"kimi-k2.5:cloud","prompt":"hello"}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/embeddings", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/api/embeddings" { + t.Fatalf("expected upstream path /api/embeddings, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { + t.Fatalf("expected normalized model in upstream body, got %q", capture.body) + } + }) + + t.Run("api show", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"details":{"format":"gguf"}}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"kimi-k2.5:cloud"}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/show", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/api/show" { + t.Fatalf("expected upstream path /api/show, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { + t.Fatalf("expected normalized model in upstream body, got %q", capture.body) + } + }) + + t.Run("v1 chat completions bypasses conversion", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"id":"chatcmpl_test","object":"chat.completion"}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"gpt-oss:120b:cloud","messages":[{"role":"user","content":"hi"}],"max_tokens":7}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/chat/completions", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Test-Header", "v1-header") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/v1/chat/completions" { + t.Fatalf("expected upstream path /v1/chat/completions, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"max_tokens":7`) { + t.Fatalf("expected original OpenAI request body, got %q", capture.body) + } + + if !strings.Contains(capture.body, `"model":"gpt-oss:120b"`) { + t.Fatalf("expected normalized model in upstream body, got %q", capture.body) + } + + if strings.Contains(capture.body, `"options"`) { + t.Fatalf("expected no converted Ollama options in upstream body, got %q", capture.body) + } + + if got := capture.header.Get("X-Test-Header"); got != "v1-header" { + t.Fatalf("expected forwarded X-Test-Header=v1-header, got %q", got) + } + }) + + t.Run("v1 chat completions bypasses conversion with legacy cloud suffix", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"id":"chatcmpl_test","object":"chat.completion"}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"gpt-oss:120b-cloud","messages":[{"role":"user","content":"hi"}],"max_tokens":7}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/chat/completions", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Test-Header", "v1-legacy-header") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/v1/chat/completions" { + t.Fatalf("expected upstream path /v1/chat/completions, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"max_tokens":7`) { + t.Fatalf("expected original OpenAI request body, got %q", capture.body) + } + + if !strings.Contains(capture.body, `"model":"gpt-oss:120b"`) { + t.Fatalf("expected normalized model in upstream body, got %q", capture.body) + } + + if strings.Contains(capture.body, `"options"`) { + t.Fatalf("expected no converted Ollama options in upstream body, got %q", capture.body) + } + + if got := capture.header.Get("X-Test-Header"); got != "v1-legacy-header" { + t.Fatalf("expected forwarded X-Test-Header=v1-legacy-header, got %q", got) + } + }) + + t.Run("v1 messages bypasses conversion", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"id":"msg_1","type":"message"}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"kimi-k2.5:cloud","max_tokens":10,"messages":[{"role":"user","content":"hi"}]}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/v1/messages" { + t.Fatalf("expected upstream path /v1/messages, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"max_tokens":10`) { + t.Fatalf("expected original Anthropic request body, got %q", capture.body) + } + + if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { + t.Fatalf("expected normalized model in upstream body, got %q", capture.body) + } + + if strings.Contains(capture.body, `"options"`) { + t.Fatalf("expected no converted Ollama options in upstream body, got %q", capture.body) + } + }) + + t.Run("v1 messages bypasses conversion with legacy cloud suffix", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"id":"msg_1","type":"message"}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"kimi-k2.5:latest-cloud","max_tokens":10,"messages":[{"role":"user","content":"hi"}]}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/v1/messages" { + t.Fatalf("expected upstream path /v1/messages, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"max_tokens":10`) { + t.Fatalf("expected original Anthropic request body, got %q", capture.body) + } + + if !strings.Contains(capture.body, `"model":"kimi-k2.5:latest"`) { + t.Fatalf("expected normalized model in upstream body, got %q", capture.body) + } + + if strings.Contains(capture.body, `"options"`) { + t.Fatalf("expected no converted Ollama options in upstream body, got %q", capture.body) + } + }) + + t.Run("v1 messages web_search fallback uses legacy cloud /api/chat path", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"model":"gpt-oss:120b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"hello"},"done":true}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{ + "model":"gpt-oss:120b-cloud", + "max_tokens":10, + "messages":[{"role":"user","content":"search the web"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}], + "stream":false + }` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages?beta=true", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/api/chat" { + t.Fatalf("expected upstream path /api/chat for web_search fallback, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"model":"gpt-oss:120b"`) { + t.Fatalf("expected normalized model in upstream body, got %q", capture.body) + } + + if !strings.Contains(capture.body, `"num_predict":10`) { + t.Fatalf("expected converted ollama options in upstream body, got %q", capture.body) + } + }) + + t.Run("v1 model retrieve bypasses conversion", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:cloud","object":"model","created":1,"owned_by":"ollama"}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, local.URL+"/v1/models/kimi-k2.5:cloud", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("X-Test-Header", "v1-model-header") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/v1/models/kimi-k2.5" { + t.Fatalf("expected upstream path /v1/models/kimi-k2.5, got %q", capture.path) + } + + if capture.body != "" { + t.Fatalf("expected empty request body, got %q", capture.body) + } + + if got := capture.header.Get("X-Test-Header"); got != "v1-model-header" { + t.Fatalf("expected forwarded X-Test-Header=v1-model-header, got %q", got) + } + }) + + t.Run("v1 model retrieve normalizes legacy cloud suffix", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:latest","object":"model","created":1,"owned_by":"ollama"}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, local.URL+"/v1/models/kimi-k2.5:latest-cloud", nil) + if err != nil { + t.Fatal(err) + } + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/v1/models/kimi-k2.5:latest" { + t.Fatalf("expected upstream path /v1/models/kimi-k2.5:latest, got %q", capture.path) + } + }) +} + +func TestCloudDisabledBlocksExplicitCloudPassthrough(t *testing.T) { + gin.SetMode(gin.TestMode) + setTestHome(t, t.TempDir()) + t.Setenv("OLLAMA_NO_CLOUD", "1") + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + + local := httptest.NewServer(router) + defer local.Close() + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/chat/completions", bytes.NewBufferString(`{"model":"kimi-k2.5:cloud","messages":[{"role":"user","content":"hi"}]}`)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("expected status 403, got %d (%s)", resp.StatusCode, string(body)) + } + + var got map[string]string + if err := json.Unmarshal(body, &got); err != nil { + t.Fatalf("expected json error body, got: %q", string(body)) + } + + if got["error"] != internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable) { + t.Fatalf("unexpected error message: %q", got["error"]) + } +} + +func TestCloudPassthroughStreamsPromptly(t *testing.T) { + gin.SetMode(gin.TestMode) + setTestHome(t, t.TempDir()) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-ndjson") + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("upstream writer is not a flusher") + } + + _, _ = w.Write([]byte(`{"response":"first"}` + "\n")) + flusher.Flush() + + time.Sleep(700 * time.Millisecond) + + _, _ = w.Write([]byte(`{"response":"second"}` + "\n")) + flusher.Flush() + })) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"kimi-k2.5:cloud","messages":[{"role":"user","content":"hi"}],"stream":true}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/chat", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + reader := bufio.NewReader(resp.Body) + + start := time.Now() + firstLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("failed reading first streamed line: %v", err) + } + if elapsed := time.Since(start); elapsed > 400*time.Millisecond { + t.Fatalf("first streamed line arrived too late (%s), likely not flushing", elapsed) + } + if !strings.Contains(firstLine, `"first"`) { + t.Fatalf("expected first line to contain first chunk, got %q", firstLine) + } + + secondLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("failed reading second streamed line: %v", err) + } + if !strings.Contains(secondLine, `"second"`) { + t.Fatalf("expected second line to contain second chunk, got %q", secondLine) + } +} + +func TestCloudPassthroughSkipsAnthropicWebSearch(t *testing.T) { + gin.SetMode(gin.TestMode) + setTestHome(t, t.TempDir()) + + type upstreamCapture struct { + path string + } + capture := &upstreamCapture{} + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capture.path = r.URL.Path + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message"}`)) + })) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + router := gin.New() + router.POST( + "/v1/messages", + cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), + middleware.AnthropicMessagesMiddleware(), + func(c *gin.Context) { c.Status(http.StatusTeapot) }, + ) + + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{ + "model":"kimi-k2.5:cloud", + "max_tokens":10, + "messages":[{"role":"user","content":"hi"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusTeapot { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("expected local middleware path status %d, got %d (%s)", http.StatusTeapot, resp.StatusCode, string(body)) + } + + if capture.path != "" { + t.Fatalf("expected no passthrough for web_search requests, got upstream path %q", capture.path) + } +} + +func TestCloudPassthroughSkipsAnthropicWebSearchLegacySuffix(t *testing.T) { + gin.SetMode(gin.TestMode) + setTestHome(t, t.TempDir()) + + type upstreamCapture struct { + path string + } + capture := &upstreamCapture{} + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capture.path = r.URL.Path + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message"}`)) + })) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + router := gin.New() + router.POST( + "/v1/messages", + cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), + middleware.AnthropicMessagesMiddleware(), + func(c *gin.Context) { c.Status(http.StatusTeapot) }, + ) + + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{ + "model":"kimi-k2.5:latest-cloud", + "max_tokens":10, + "messages":[{"role":"user","content":"hi"}], + "tools":[{"type":"web_search_20250305","name":"web_search"}] + }` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusTeapot { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("expected local middleware path status %d, got %d (%s)", http.StatusTeapot, resp.StatusCode, string(body)) + } + + if capture.path != "" { + t.Fatalf("expected no passthrough for web_search requests, got upstream path %q", capture.path) + } +} + +func TestCloudPassthroughSigningFailureReturnsUnauthorized(t *testing.T) { + gin.SetMode(gin.TestMode) + setTestHome(t, t.TempDir()) + + origSignRequest := cloudProxySignRequest + origSigninURL := cloudProxySigninURL + cloudProxySignRequest = func(context.Context, *http.Request) error { + return errors.New("ssh: no key found") + } + cloudProxySigninURL = func() (string, error) { + return "https://ollama.com/signin/example", nil + } + t.Cleanup(func() { + cloudProxySignRequest = origSignRequest + cloudProxySigninURL = origSigninURL + }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"kimi-k2.5:cloud","prompt":"hello","stream":false}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/generate", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body)) + } + + var got map[string]any + if err := json.Unmarshal(body, &got); err != nil { + t.Fatalf("expected json error body, got: %q", string(body)) + } + + if got["error"] != "unauthorized" { + t.Fatalf("unexpected error message: %v", got["error"]) + } + + if got["signin_url"] != "https://ollama.com/signin/example" { + t.Fatalf("unexpected signin_url: %v", got["signin_url"]) + } +} + +func TestCloudPassthroughSigningFailureWithoutSigninURL(t *testing.T) { + gin.SetMode(gin.TestMode) + setTestHome(t, t.TempDir()) + + origSignRequest := cloudProxySignRequest + origSigninURL := cloudProxySigninURL + cloudProxySignRequest = func(context.Context, *http.Request) error { + return errors.New("ssh: no key found") + } + cloudProxySigninURL = func() (string, error) { + return "", errors.New("key missing") + } + t.Cleanup(func() { + cloudProxySignRequest = origSignRequest + cloudProxySigninURL = origSigninURL + }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{"model":"kimi-k2.5:cloud","prompt":"hello","stream":false}` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/generate", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body)) + } + + var got map[string]any + if err := json.Unmarshal(body, &got); err != nil { + t.Fatalf("expected json error body, got: %q", string(body)) + } + + if got["error"] != "unauthorized" { + t.Fatalf("unexpected error message: %v", got["error"]) + } + + if _, ok := got["signin_url"]; ok { + t.Fatalf("did not expect signin_url when helper fails, got %v", got["signin_url"]) + } +} diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 0d0ac6dbc..401f98d9d 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -794,6 +794,43 @@ func TestCreateAndShowRemoteModel(t *testing.T) { fmt.Printf("resp = %#v\n", resp) } +func TestCreateFromCloudSourceSuffix(t *testing.T) { + gin.SetMode(gin.TestMode) + + var s Server + + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "test-cloud-from-suffix", + From: "gpt-oss:20b:cloud", + Info: map[string]any{ + "capabilities": []string{"completion"}, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, got %d", w.Code) + } + + w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test-cloud-from-suffix"}) + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, got %d", w.Code) + } + + var resp api.ShowResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + + if resp.RemoteHost != "https://ollama.com:443" { + t.Fatalf("expected remote host https://ollama.com:443, got %q", resp.RemoteHost) + } + + if resp.RemoteModel != "gpt-oss:20b" { + t.Fatalf("expected remote model gpt-oss:20b, got %q", resp.RemoteModel) + } +} + func TestCreateLicenses(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go index a1a5f5424..444c76ed6 100644 --- a/server/routes_delete_test.go +++ b/server/routes_delete_test.go @@ -111,3 +111,32 @@ func TestDeleteDuplicateLayers(t *testing.T) { checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{}) } + +func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) { + gin.SetMode(gin.TestMode) + + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + + var s Server + + _, digest := createBinFile(t, nil, nil) + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "gpt-oss:20b-cloud", + Files: map[string]string{"test.gguf": digest}, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"), + }) + + w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"}) + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String()) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{}) +} diff --git a/x/cmd/run.go b/x/cmd/run.go index e5d7ea25e..e96c8385e 100644 --- a/x/cmd/run.go +++ b/x/cmd/run.go @@ -20,6 +20,7 @@ import ( "github.com/ollama/ollama/api" internalcloud "github.com/ollama/ollama/internal/cloud" + "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/readline" "github.com/ollama/ollama/types/model" @@ -43,7 +44,7 @@ const ( // isLocalModel checks if the model is running locally (not a cloud model). // TODO: Improve local/cloud model identification - could check model metadata func isLocalModel(modelName string) bool { - return !strings.HasSuffix(modelName, "-cloud") + return !modelref.HasExplicitCloudSource(modelName) } // isLocalServer checks if connecting to a local Ollama server. diff --git a/x/cmd/run_test.go b/x/cmd/run_test.go index a65e8cc80..75429f8ac 100644 --- a/x/cmd/run_test.go +++ b/x/cmd/run_test.go @@ -22,12 +22,22 @@ func TestIsLocalModel(t *testing.T) { }, { name: "cloud model", - modelName: "gpt-4-cloud", + modelName: "gpt-oss:latest-cloud", + expected: false, + }, + { + name: "cloud model with :cloud suffix", + modelName: "gpt-oss:cloud", expected: false, }, { name: "cloud model with version", - modelName: "claude-3-cloud", + modelName: "gpt-oss:20b-cloud", + expected: false, + }, + { + name: "cloud model with version and :cloud suffix", + modelName: "gpt-oss:20b:cloud", expected: false, }, { @@ -134,7 +144,7 @@ func TestTruncateToolOutput(t *testing.T) { { name: "long output cloud model - uses 10k limit", output: string(localLimitOutput), // 20k chars, under 10k token limit - modelName: "gpt-4-cloud", + modelName: "gpt-oss:latest-cloud", host: "", shouldTrim: false, expectedLimit: defaultTokenLimit, @@ -142,7 +152,7 @@ func TestTruncateToolOutput(t *testing.T) { { name: "very long output cloud model - trimmed at 10k", output: string(defaultLimitOutput), - modelName: "gpt-4-cloud", + modelName: "gpt-oss:latest-cloud", host: "", shouldTrim: true, expectedLimit: defaultTokenLimit,