diff --git a/cmd/cmd.go b/cmd/cmd.go
index 010ea9df5..7efd2b20b 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"io"
 	"log"
+	"log/slog"
 	"math"
 	"net"
 	"net/http"
@@ -43,6 +44,7 @@ import (
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/internal/modelref"
+	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
@@ -516,6 +518,61 @@ func handleCloudAuthorizationError(err error) bool {
 	return false
 }
 
+// TEMP(drifkin): To match legacy `ollama run some-model:cloud` behavior, we
+// best-effort pull cloud stub files for any explicit cloud source models.
+// Remove this once `/api/tags` is cloud-aware.
+//
+// ensureCloudStub does nothing unless modelName has an explicit cloud source.
+// Otherwise it lists the server's models and, when neither modelName nor its
+// normalized pull name is listed, pulls the normalized name. Failures are only
+// logged as warnings, never returned, so the caller always continues.
+func ensureCloudStub(ctx context.Context, client *api.Client, modelName string) {
+	if !modelref.HasExplicitCloudSource(modelName) {
+		return
+	}
+
+	normalizedName, _, err := modelref.NormalizePullName(modelName)
+	if err != nil {
+		slog.Warn("failed to normalize pull name", "model", modelName, "error", err, "normalizedName", normalizedName)
+		return
+	}
+
+	listResp, err := client.List(ctx)
+	if err != nil {
+		slog.Warn("failed to list models", "error", err)
+		return
+	}
+
+	if hasListedModelName(listResp.Models, modelName) || hasListedModelName(listResp.Models, normalizedName) {
+		return
+	}
+
+	logutil.Trace("pulling cloud stub", "model", modelName, "normalizedName", normalizedName)
+	// Progress updates are discarded; this pull is a silent side effect.
+	err = client.Pull(ctx, &api.PullRequest{
+		Model: normalizedName,
+	}, func(api.ProgressResponse) error {
+		return nil
+	})
+	if err != nil {
+		// Include normalizedName: it is the name the pull request actually carried.
+		slog.Warn("failed to pull cloud stub", "model", modelName, "normalizedName", normalizedName, "error", err)
+	}
+}
+
+// hasListedModelName reports whether name matches the Name or Model field of
+// any entry in models, compared case-insensitively.
+func hasListedModelName(models []api.ListModelResponse, name string) bool {
+	for i := range models {
+		// Index rather than range by value so each entry is not copied.
+		m := &models[i]
+		if strings.EqualFold(m.Name, name) || strings.EqualFold(m.Model, name) {
+			return true
+		}
+	}
+	return false
+}
+
 func RunHandler(cmd *cobra.Command, args []string) error {
 	interactive := true
 
@@ 
-636,6 +682,8 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
+	ensureCloudStub(cmd.Context(), client, name)
+
 	opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
 	if err != nil {
 		return err
diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go
index fe21e400a..9326215bd 100644
--- a/cmd/cmd_test.go
+++ b/cmd/cmd_test.go
@@ -838,6 +838,134 @@ func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
 	}
 }
 
+// cloudStubMockTEMP is a fake Ollama server shared by the TEMP cloud-stub
+// tests. It answers /api/show, /api/tags, /api/pull and /api/generate, and
+// records whether /api/pull and /api/generate were reached.
+type cloudStubMockTEMP struct {
+	t *testing.T
+
+	// listed is the model list served by /api/tags.
+	listed []api.ListModelResponse
+	// failPull makes /api/pull answer with a 500 and an error body.
+	failPull bool
+
+	pullCalled     bool
+	pulledModel    string
+	generateCalled bool
+}
+
+// reply writes status and then v encoded as JSON.
+func (m *cloudStubMockTEMP) reply(w http.ResponseWriter, status int, v any) {
+	w.WriteHeader(status)
+	if err := json.NewEncoder(w).Encode(v); err != nil {
+		// The header is already sent, so report the failure through the test.
+		m.t.Errorf("encoding mock response: %v", err)
+	}
+}
+
+// ServeHTTP dispatches on path and method; unknown requests get a 404.
+func (m *cloudStubMockTEMP) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	switch {
+	case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
+		m.reply(w, http.StatusOK, api.ShowResponse{
+			Capabilities: []model.Capability{model.CapabilityCompletion},
+			RemoteModel:  "gpt-oss:20b",
+		})
+	case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
+		m.reply(w, http.StatusOK, api.ListResponse{Models: m.listed})
+	case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
+		m.pullCalled = true
+		var req api.PullRequest
+		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+		m.pulledModel = req.Model
+		if m.failPull {
+			m.reply(w, http.StatusInternalServerError, map[string]string{"error": "pull failed"})
+			return
+		}
+		m.reply(w, http.StatusOK, api.ProgressResponse{Status: "success"})
+	case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
+		m.generateCalled = true
+		m.reply(w, http.StatusOK, api.GenerateResponse{Done: true})
+	default:
+		http.NotFound(w, r)
+	}
+}
+
+// runExplicitCloudModelTEMP points OLLAMA_HOST at a test server backed by mock
+// and invokes RunHandler with the args "gpt-oss:20b:cloud" and "hi", failing
+// the test if RunHandler returns an error.
+func runExplicitCloudModelTEMP(t *testing.T, mock *cloudStubMockTEMP) {
+	t.Helper()
+	mock.t = t
+
+	mockServer := httptest.NewServer(mock)
+	t.Setenv("OLLAMA_HOST", mockServer.URL)
+	t.Cleanup(mockServer.Close)
+
+	cmd := &cobra.Command{}
+	cmd.SetContext(t.Context())
+	cmd.Flags().String("keepalive", "", "")
+	cmd.Flags().Bool("truncate", false, "")
+	cmd.Flags().Int("dimensions", 0, "")
+	cmd.Flags().Bool("verbose", false, "")
+	cmd.Flags().Bool("insecure", false, "")
+	cmd.Flags().Bool("nowordwrap", false, "")
+	cmd.Flags().String("format", "", "")
+	cmd.Flags().String("think", "", "")
+	cmd.Flags().Bool("hidethinking", false, "")
+
+	if err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"}); err != nil {
+		t.Fatalf("RunHandler returned error: %v", err)
+	}
+}
+
+// With an empty model list, the pull request must carry the normalized name
+// and generation must still run.
+func TestRunHandler_ExplicitCloudStubMissing_PullsNormalizedNameTEMP(t *testing.T) {
+	mock := &cloudStubMockTEMP{}
+	runExplicitCloudModelTEMP(t, mock)
+
+	if mock.pulledModel != "gpt-oss:20b-cloud" {
+		t.Fatalf("expected normalized pull model %q, got %q", "gpt-oss:20b-cloud", mock.pulledModel)
+	}
+	if !mock.generateCalled {
+		t.Fatal("expected /api/generate to be called")
+	}
+}
+
+// When the normalized name is already listed, no pull may be issued and
+// generation must still run.
+func TestRunHandler_ExplicitCloudStubPresent_SkipsPullTEMP(t *testing.T) {
+	mock := &cloudStubMockTEMP{
+		listed: []api.ListModelResponse{{Name: "gpt-oss:20b-cloud"}},
+	}
+	runExplicitCloudModelTEMP(t, mock)
+
+	if mock.pullCalled {
+		t.Fatal("expected /api/pull not to be called when cloud stub already exists")
+	}
+	if !mock.generateCalled {
+		t.Fatal("expected /api/generate to be called")
+	}
+}
+
+// A failing pull must not stop the run: the pull has to be attempted, and
+// generation must still happen afterwards.
+func TestRunHandler_ExplicitCloudStubPullFailure_IsBestEffortTEMP(t *testing.T) {
+	mock := &cloudStubMockTEMP{failPull: true}
+	runExplicitCloudModelTEMP(t, mock)
+
+	if !mock.pullCalled {
+		t.Fatal("expected /api/pull to be attempted")
+	}
+	if !mock.generateCalled {
+		t.Fatal("expected /api/generate to be called despite pull failure")
+	}
+}
+
 func TestGetModelfileName(t *testing.T) {
 	tests := []struct {
 		name string