Reapply "don't require pulling stubs for cloud models" again (#14608)

* Revert "Revert "Reapply "don't require pulling stubs for cloud models"" (#14606)"

This reverts commit 39982a954e.

* fix test + do cloud lookup only when seeing cloud models

---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
This commit is contained in:
Jeffrey Morgan
2026-03-06 14:27:47 -08:00
committed by GitHub
parent 1af850e6e3
commit 4eab60c1e2
25 changed files with 2862 additions and 146 deletions

View File

@@ -41,6 +41,7 @@ import (
"github.com/ollama/ollama/cmd/tui"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
@@ -418,12 +419,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err
}
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
return err
} else if info.RemoteHost != "" {
} else if info.RemoteHost != "" || requestedCloud {
// Cloud model, no need to load/unload
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
// Check if user is signed in for ollama.com cloud models
if isCloud {
@@ -434,10 +437,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
if opts.ShowConnect {
p.StopAndClear()
remoteModel := info.RemoteModel
if remoteModel == "" {
remoteModel = opts.Model
}
if isCloud {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
}
}
@@ -509,6 +516,20 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
return nil
}
// TODO(parthsareen): consolidate with TUI signin flow
func handleCloudAuthorizationError(err error) bool {
var authErr api.AuthorizationError
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
if authErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, authErr.SigninURL)
}
return true
}
return false
}
func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
@@ -605,12 +626,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
requestedCloud := modelref.HasExplicitCloudSource(name)
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
if requestedCloud {
return nil, err
}
if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err
}
@@ -619,6 +644,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return info, err
}()
if err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err
}
@@ -713,7 +741,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
if err := generate(cmd, opts); err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err
}
return nil
}
func SigninHandler(cmd *cobra.Command, args []string) error {

View File

@@ -18,6 +18,7 @@ import (
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/types/model"
)
@@ -705,6 +706,139 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
}
}
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
var generateCalled bool
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusUnauthorized)
if err := json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
generateCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
oldStdout := os.Stdout
readOut, writeOut, _ := os.Pipe()
os.Stdout = writeOut
t.Cleanup(func() { os.Stdout = oldStdout })
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
_ = writeOut.Close()
var out bytes.Buffer
_, _ = io.Copy(&out, readOut)
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if generateCalled {
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
}
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
t.Fatalf("expected sign-in guidance message, got %q", out.String())
}
if !strings.Contains(out.String(), "https://ollama.com/signin") {
t.Fatalf("expected signin_url in output, got %q", out.String())
}
}
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ShowResponse{
Capabilities: []model.Capability{model.CapabilityCompletion},
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusUnauthorized)
if err := json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
oldStdout := os.Stdout
readOut, writeOut, _ := os.Pipe()
os.Stdout = writeOut
t.Cleanup(func() { os.Stdout = oldStdout })
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
_ = writeOut.Close()
var out bytes.Buffer
_, _ = io.Copy(&out, readOut)
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
t.Fatalf("expected sign-in guidance message, got %q", out.String())
}
if !strings.Contains(out.String(), "https://ollama.com/signin") {
t.Fatalf("expected signin_url in output, got %q", out.String())
}
}
func TestGetModelfileName(t *testing.T) {
tests := []struct {
name string
@@ -1664,20 +1798,26 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
tests := []struct {
name string
model string
remoteHost string
remoteModel string
whoamiStatus int
whoamiResp any
expectedError string
}{
{
name: "ollama.com cloud model - user signed in",
model: "test-cloud-model",
remoteHost: "https://ollama.com",
remoteModel: "test-model",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
},
{
name: "ollama.com cloud model - user not signed in",
model: "test-cloud-model",
remoteHost: "https://ollama.com",
remoteModel: "test-model",
whoamiStatus: http.StatusUnauthorized,
whoamiResp: map[string]string{
"error": "unauthorized",
@@ -1687,7 +1827,33 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
},
{
name: "non-ollama.com remote - no auth check",
model: "test-cloud-model",
remoteHost: "https://other-remote.com",
remoteModel: "test-model",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
{
name: "explicit :cloud model - auth check without remote metadata",
model: "kimi-k2.5:cloud",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
},
{
name: "explicit -cloud model - auth check without remote metadata",
model: "kimi-k2.5:latest-cloud",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
},
{
name: "dash cloud-like name without explicit source does not require auth",
model: "test-cloud-model",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
@@ -1702,7 +1868,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(api.ShowResponse{
RemoteHost: tt.remoteHost,
RemoteModel: "test-model",
RemoteModel: tt.remoteModel,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
@@ -1715,6 +1881,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
case "/api/generate":
w.WriteHeader(http.StatusOK)
default:
http.NotFound(w, r)
}
@@ -1727,13 +1895,13 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
cmd.SetContext(t.Context())
opts := &runOptions{
Model: "test-cloud-model",
Model: tt.model,
ShowConnect: false,
}
err := loadOrUnloadModel(cmd, opts)
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") || modelref.HasExplicitCloudSource(tt.model) {
if !whoamiCalled {
t.Error("expected whoami to be called for ollama.com cloud model")
}

View File

@@ -107,15 +107,12 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
}
if !force && aliases["primary"] != "" {
client, _ := api.ClientFromEnvironment()
if isCloudModel(ctx, client, aliases["primary"]) {
if isCloudModel(ctx, client, aliases["fast"]) {
return aliases, false, nil
}
} else {
delete(aliases, "fast")
if isCloudModelName(aliases["primary"]) {
aliases["fast"] = aliases["primary"]
return aliases, false, nil
}
delete(aliases, "fast")
return aliases, false, nil
}
items, existingModels, cloudModels, client, err := listModels(ctx)
@@ -139,10 +136,8 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
aliases["primary"] = primary
}
if isCloudModel(ctx, client, aliases["primary"]) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = aliases["primary"]
}
if isCloudModelName(aliases["primary"]) {
aliases["fast"] = aliases["primary"]
} else {
delete(aliases, "fast")
}

View File

@@ -233,6 +233,9 @@ func ModelExists(ctx context.Context, name string) bool {
if name == "" {
return false
}
if isCloudModelName(name) {
return true
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false

View File

@@ -10,7 +10,6 @@ import (
"path/filepath"
"slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
@@ -125,13 +124,12 @@ func (d *Droid) Edit(models []string) error {
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
client, _ := api.ClientFromEnvironment()
var newModels []any
var defaultModelID string
for i, model := range models {
maxOutput := 64000
if isCloudModel(context.Background(), client, model) {
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
maxOutput = l.Output
}

View File

@@ -1276,35 +1276,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
// Verify that every cloud model in cloudModelLimits has a valid output
// value that would be used for maxOutputTokens when isCloudModel returns true.
// Cloud suffix normalization must also work since integrations may see either
// :cloud or -cloud model names.
// value that would be used for maxOutputTokens when the selected model uses
// the explicit :cloud source tag.
for name, expected := range cloudModelLimits {
t.Run(name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(name)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
}
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
// Also verify :cloud suffix lookup
cloudName := name + ":cloud"
l2, ok := lookupCloudModelLimit(cloudName)
l, ok := lookupCloudModelLimit(cloudName)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
}
if l2.Output != expected.Output {
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
}
// Also verify -cloud suffix lookup
dashCloudName := name + "-cloud"
l3, ok := lookupCloudModelLimit(dashCloudName)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", dashCloudName)
}
if l3.Output != expected.Output {
t.Errorf("-cloud output = %d, want %d", l3.Output, expected.Output)
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
})
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/ollama/ollama/api"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
)
@@ -326,12 +327,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
// If the selected model isn't installed, pull it first
if !existingModels[selected] {
if cloudModels[selected] {
// Cloud models only pull a small manifest; no confirmation needed
if err := pullModel(ctx, client, selected); err != nil {
return "", fmt.Errorf("failed to pull %s: %w", selected, err)
}
} else {
if !isCloudModelName(selected) {
msg := fmt.Sprintf("Download %s?", selected)
if ok, err := confirmPrompt(msg); err != nil {
return "", err
@@ -526,7 +522,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
var toPull []string
for _, m := range selected {
if !existingModels[m] {
if !existingModels[m] && !isCloudModelName(m) {
toPull = append(toPull, m)
}
}
@@ -552,12 +548,28 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
return selected, nil
}
// TODO(parthsareen): consolidate pull logic from call sites
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
if existingModels[model] {
if isCloudModelName(model) || existingModels[model] {
return nil
}
msg := fmt.Sprintf("Download %s?", model)
if ok, err := confirmPrompt(msg); err != nil {
return confirmAndPull(ctx, client, model)
}
// TODO(parthsareen): pull this out to tui package
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
}
if isCloudModelName(model) {
return nil
}
return confirmAndPull(ctx, client, model)
}
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
@@ -569,26 +581,6 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st
return nil
}
// TODO(parthsareen): pull this out to tui package
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
}
// Cloud models only pull a small manifest; skip the download confirmation
// TODO(parthsareen): consolidate with cloud config changes
if strings.HasSuffix(model, "cloud") {
return pullModel(ctx, client, model)
}
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
return pullModel(ctx, client, model)
}
func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -733,10 +725,8 @@ func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, na
}
aliases["primary"] = model
if isCloudModel(ctx, client, model) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = model
}
if isCloudModelName(model) {
aliases["fast"] = model
} else {
delete(aliases, "fast")
}
@@ -1022,7 +1012,7 @@ Examples:
existingAliases = aliases
// Ensure cloud models are authenticated
if isCloudModel(cmd.Context(), client, model) {
if isCloudModelName(model) {
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
return err
}
@@ -1211,7 +1201,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
// When user has no models, preserve recommended order.
notInstalled := make(map[string]bool)
for i := range items {
if !existingModels[items[i].Name] {
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
notInstalled[items[i].Name] = true
var parts []string
if items[i].Description != "" {
@@ -1305,7 +1295,8 @@ func IsCloudModelDisabled(ctx context.Context, name string) bool {
}
func isCloudModelName(name string) bool {
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
// TODO(drifkin): Replace this wrapper with inlining once things stabilize a bit
return modelref.HasExplicitCloudSource(name)
}
func filterCloudModels(existing []modelInfo) []modelInfo {

View File

@@ -426,8 +426,14 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
}
for _, item := range items {
if !strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
if strings.HasSuffix(item.Name, ":cloud") {
if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
} else {
if !strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
}
}
}
}
@@ -492,10 +498,14 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3:8b":
case "qwen3:8b":
if !strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
case "minimax-m2.5:cloud", "kimi-k2.5:cloud":
if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
}
}
}
@@ -536,7 +546,13 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
}
for _, item := range items {
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
isCloud := strings.HasSuffix(item.Name, ":cloud")
isInstalled := slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name)
if isInstalled || isCloud {
if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("installed or cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
} else {
if !strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
@@ -1000,8 +1016,8 @@ func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) {
}
}
func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
// Confirm prompt should NOT be called for cloud models
func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) {
// Confirm prompt should NOT be called for explicit cloud models
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Error("confirm prompt should not be called for cloud models")
@@ -1032,8 +1048,115 @@ func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
if err != nil {
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
}
if !pullCalled {
t.Error("expected pull to be called for cloud model without confirmation")
if pullCalled {
t.Error("expected pull not to be called for cloud model")
}
}
func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) {
// Confirm prompt should NOT be called for explicit cloud models
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Error("confirm prompt should not be called for cloud models")
return false, nil
}
defer func() { DefaultConfirmPrompt = oldHook }()
var pullCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
case "/api/pull":
pullCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"status":"success"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := ShowOrPull(context.Background(), client, "gpt-oss:20b-cloud")
if err != nil {
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
}
if pullCalled {
t.Error("expected pull not to be called for cloud model")
}
}
func TestPullIfNeeded_CloudModel_DoesNotPull(t *testing.T) {
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Error("confirm prompt should not be called for cloud models")
return false, nil
}
defer func() { DefaultConfirmPrompt = oldHook }()
err := pullIfNeeded(context.Background(), nil, map[string]bool{}, "glm-5:cloud")
if err != nil {
t.Fatalf("expected no error for cloud model, got %v", err)
}
err = pullIfNeeded(context.Background(), nil, map[string]bool{}, "gpt-oss:20b-cloud")
if err != nil {
t.Fatalf("expected no error for cloud model with legacy suffix, got %v", err)
}
}
func TestSelectModelsWithSelectors_CloudSelection_DoesNotPull(t *testing.T) {
var pullCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"not found"}`)
case "/api/tags":
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"models":[]}`)
case "/api/me":
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"name":"test-user"}`)
case "/api/pull":
pullCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"status":"success"}`)
default:
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"not found"}`)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
single := func(title string, items []ModelItem, current string) (string, error) {
for _, item := range items {
if item.Name == "glm-5:cloud" {
return item.Name, nil
}
}
t.Fatalf("expected glm-5:cloud in selector items, got %v", items)
return "", nil
}
multi := func(title string, items []ModelItem, preChecked []string) ([]string, error) {
return nil, fmt.Errorf("multi selector should not be called")
}
selected, err := selectModelsWithSelectors(context.Background(), "codex", "", single, multi)
if err != nil {
t.Fatalf("selectModelsWithSelectors returned error: %v", err)
}
if !slices.Equal(selected, []string{"glm-5:cloud"}) {
t.Fatalf("unexpected selected models: %v", selected)
}
if pullCalled {
t.Fatal("expected cloud selection to skip pull")
}
}

View File

@@ -12,8 +12,8 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/modelref"
)
// OpenCode implements Runner and Editor for OpenCode integration
@@ -26,14 +26,13 @@ type cloudModelLimit struct {
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It normalizes common cloud suffixes before checking the shared limit map.
// It normalizes explicit cloud source suffixes before checking the shared limit map.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
// TODO(parthsareen): migrate to using cloud check instead.
for _, suffix := range []string{"-cloud", ":cloud"} {
name = strings.TrimSuffix(name, suffix)
}
if l, ok := cloudModelLimits[name]; ok {
return l, true
base, stripped := modelref.StripCloudSourceTag(name)
if stripped {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
@@ -150,8 +149,6 @@ func (o *OpenCode) Edit(modelList []string) error {
}
}
client, _ := api.ClientFromEnvironment()
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
@@ -161,7 +158,7 @@ func (o *OpenCode) Edit(modelList []string) error {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModel(context.Background(), client, model) {
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
@@ -175,7 +172,7 @@ func (o *OpenCode) Edit(modelList []string) error {
"name": model,
"_launch": true,
}
if isCloudModel(context.Background(), client, model) {
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,

View File

@@ -714,16 +714,17 @@ func TestLookupCloudModelLimit(t *testing.T) {
wantContext int
wantOutput int
}{
{"glm-4.7", true, 202_752, 131_072},
{"glm-4.7", false, 0, 0},
{"glm-4.7:cloud", true, 202_752, 131_072},
{"glm-5:cloud", true, 202_752, 131_072},
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
{"kimi-k2.5", true, 262_144, 262_144},
{"kimi-k2.5", false, 0, 0},
{"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2", false, 0, 0},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"qwen3-coder:480b", false, 0, 0},
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},

View File

@@ -11,6 +11,7 @@ import (
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/version"
)
@@ -147,7 +148,13 @@ type signInCheckMsg struct {
type clearStatusMsg struct{}
func (m *model) modelExists(name string) bool {
if m.availableModels == nil || name == "" {
if name == "" {
return false
}
if modelref.HasExplicitCloudSource(name) {
return true
}
if m.availableModels == nil {
return false
}
if m.availableModels[name] {
@@ -209,7 +216,7 @@ func (m *model) openMultiModelModal(integration string) {
}
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
return modelref.HasExplicitCloudSource(name)
}
func cloudStatusDisabled(client *api.Client) bool {

View File

@@ -0,0 +1,115 @@
package modelref
import (
"errors"
"fmt"
"strings"
)
type ModelSource uint8
const (
ModelSourceUnspecified ModelSource = iota
ModelSourceLocal
ModelSourceCloud
)
var (
ErrConflictingSourceSuffix = errors.New("use either :local or :cloud, not both")
ErrModelRequired = errors.New("model is required")
)
type ParsedRef struct {
Original string
Base string
Source ModelSource
}
func ParseRef(raw string) (ParsedRef, error) {
var zero ParsedRef
raw = strings.TrimSpace(raw)
if raw == "" {
return zero, ErrModelRequired
}
base, source, explicit := parseSourceSuffix(raw)
if explicit {
if _, _, nested := parseSourceSuffix(base); nested {
return zero, fmt.Errorf("%w: %q", ErrConflictingSourceSuffix, raw)
}
}
return ParsedRef{
Original: raw,
Base: base,
Source: source,
}, nil
}
func HasExplicitCloudSource(raw string) bool {
parsedRef, err := ParseRef(raw)
return err == nil && parsedRef.Source == ModelSourceCloud
}
func HasExplicitLocalSource(raw string) bool {
parsedRef, err := ParseRef(raw)
return err == nil && parsedRef.Source == ModelSourceLocal
}
func StripCloudSourceTag(raw string) (string, bool) {
parsedRef, err := ParseRef(raw)
if err != nil || parsedRef.Source != ModelSourceCloud {
return strings.TrimSpace(raw), false
}
return parsedRef.Base, true
}
func NormalizePullName(raw string) (string, bool, error) {
parsedRef, err := ParseRef(raw)
if err != nil {
return "", false, err
}
if parsedRef.Source != ModelSourceCloud {
return parsedRef.Base, false, nil
}
return toLegacyCloudPullName(parsedRef.Base), true, nil
}
func toLegacyCloudPullName(base string) string {
if hasExplicitTag(base) {
return base + "-cloud"
}
return base + ":cloud"
}
func hasExplicitTag(name string) bool {
lastSlash := strings.LastIndex(name, "/")
lastColon := strings.LastIndex(name, ":")
return lastColon > lastSlash
}
func parseSourceSuffix(raw string) (string, ModelSource, bool) {
idx := strings.LastIndex(raw, ":")
if idx >= 0 {
suffixRaw := strings.TrimSpace(raw[idx+1:])
suffix := strings.ToLower(suffixRaw)
switch suffix {
case "cloud":
return raw[:idx], ModelSourceCloud, true
case "local":
return raw[:idx], ModelSourceLocal, true
}
if !strings.Contains(suffixRaw, "/") && strings.HasSuffix(suffix, "-cloud") {
return raw[:idx+1] + suffixRaw[:len(suffixRaw)-len("-cloud")], ModelSourceCloud, true
}
}
return raw, ModelSourceUnspecified, false
}

View File

@@ -0,0 +1,268 @@
package modelref
import (
"errors"
"testing"
)
func TestParseRef(t *testing.T) {
tests := []struct {
name string
input string
wantBase string
wantSource ModelSource
wantErr error
wantCloud bool
wantLocal bool
wantStripped string
wantStripOK bool
}{
{
name: "cloud suffix",
input: "gpt-oss:20b:cloud",
wantBase: "gpt-oss:20b",
wantSource: ModelSourceCloud,
wantCloud: true,
wantStripped: "gpt-oss:20b",
wantStripOK: true,
},
{
name: "legacy cloud suffix",
input: "gpt-oss:20b-cloud",
wantBase: "gpt-oss:20b",
wantSource: ModelSourceCloud,
wantCloud: true,
wantStripped: "gpt-oss:20b",
wantStripOK: true,
},
{
name: "local suffix",
input: "qwen3:8b:local",
wantBase: "qwen3:8b",
wantSource: ModelSourceLocal,
wantLocal: true,
wantStripped: "qwen3:8b:local",
},
{
name: "no source suffix",
input: "llama3.2",
wantBase: "llama3.2",
wantSource: ModelSourceUnspecified,
wantStripped: "llama3.2",
},
{
name: "bare cloud name is not explicit cloud",
input: "my-cloud-model",
wantBase: "my-cloud-model",
wantSource: ModelSourceUnspecified,
wantStripped: "my-cloud-model",
},
{
name: "slash in suffix blocks legacy cloud parsing",
input: "foo:bar-cloud/baz",
wantBase: "foo:bar-cloud/baz",
wantSource: ModelSourceUnspecified,
wantStripped: "foo:bar-cloud/baz",
},
{
name: "conflicting source suffixes",
input: "foo:cloud:local",
wantErr: ErrConflictingSourceSuffix,
wantSource: ModelSourceUnspecified,
},
{
name: "empty input",
input: " ",
wantErr: ErrModelRequired,
wantSource: ModelSourceUnspecified,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseRef(tt.input)
if tt.wantErr != nil {
if !errors.Is(err, tt.wantErr) {
t.Fatalf("ParseRef(%q) error = %v, want %v", tt.input, err, tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("ParseRef(%q) returned error: %v", tt.input, err)
}
if got.Base != tt.wantBase {
t.Fatalf("base = %q, want %q", got.Base, tt.wantBase)
}
if got.Source != tt.wantSource {
t.Fatalf("source = %v, want %v", got.Source, tt.wantSource)
}
if HasExplicitCloudSource(tt.input) != tt.wantCloud {
t.Fatalf("HasExplicitCloudSource(%q) = %v, want %v", tt.input, HasExplicitCloudSource(tt.input), tt.wantCloud)
}
if HasExplicitLocalSource(tt.input) != tt.wantLocal {
t.Fatalf("HasExplicitLocalSource(%q) = %v, want %v", tt.input, HasExplicitLocalSource(tt.input), tt.wantLocal)
}
stripped, ok := StripCloudSourceTag(tt.input)
if ok != tt.wantStripOK {
t.Fatalf("StripCloudSourceTag(%q) ok = %v, want %v", tt.input, ok, tt.wantStripOK)
}
if stripped != tt.wantStripped {
t.Fatalf("StripCloudSourceTag(%q) base = %q, want %q", tt.input, stripped, tt.wantStripped)
}
})
}
}
func TestNormalizePullName(t *testing.T) {
tests := []struct {
name string
input string
wantName string
wantCloud bool
wantErr error
}{
{
name: "explicit local strips source",
input: "gpt-oss:20b:local",
wantName: "gpt-oss:20b",
},
{
name: "explicit cloud with size maps to legacy dash cloud tag",
input: "gpt-oss:20b:cloud",
wantName: "gpt-oss:20b-cloud",
wantCloud: true,
},
{
name: "legacy cloud with size remains stable",
input: "gpt-oss:20b-cloud",
wantName: "gpt-oss:20b-cloud",
wantCloud: true,
},
{
name: "explicit cloud without tag maps to cloud tag",
input: "qwen3:cloud",
wantName: "qwen3:cloud",
wantCloud: true,
},
{
name: "host port without tag keeps host port and appends cloud tag",
input: "localhost:11434/library/foo:cloud",
wantName: "localhost:11434/library/foo:cloud",
wantCloud: true,
},
{
name: "conflicting source suffixes fail",
input: "foo:cloud:local",
wantErr: ErrConflictingSourceSuffix,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotName, gotCloud, err := NormalizePullName(tt.input)
if tt.wantErr != nil {
if !errors.Is(err, tt.wantErr) {
t.Fatalf("NormalizePullName(%q) error = %v, want %v", tt.input, err, tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("NormalizePullName(%q) returned error: %v", tt.input, err)
}
if gotName != tt.wantName {
t.Fatalf("normalized name = %q, want %q", gotName, tt.wantName)
}
if gotCloud != tt.wantCloud {
t.Fatalf("cloud = %v, want %v", gotCloud, tt.wantCloud)
}
})
}
}
func TestParseSourceSuffix(t *testing.T) {
tests := []struct {
name string
input string
wantBase string
wantSource ModelSource
wantExplicit bool
}{
{
name: "explicit cloud suffix",
input: "gpt-oss:20b:cloud",
wantBase: "gpt-oss:20b",
wantSource: ModelSourceCloud,
wantExplicit: true,
},
{
name: "explicit local suffix",
input: "qwen3:8b:local",
wantBase: "qwen3:8b",
wantSource: ModelSourceLocal,
wantExplicit: true,
},
{
name: "legacy cloud suffix on tag",
input: "gpt-oss:20b-cloud",
wantBase: "gpt-oss:20b",
wantSource: ModelSourceCloud,
wantExplicit: true,
},
{
name: "legacy cloud suffix does not match model segment",
input: "my-cloud-model",
wantBase: "my-cloud-model",
wantSource: ModelSourceUnspecified,
wantExplicit: false,
},
{
name: "legacy cloud suffix blocked when suffix includes slash",
input: "foo:bar-cloud/baz",
wantBase: "foo:bar-cloud/baz",
wantSource: ModelSourceUnspecified,
wantExplicit: false,
},
{
name: "unknown suffix is not explicit source",
input: "gpt-oss:clod",
wantBase: "gpt-oss:clod",
wantSource: ModelSourceUnspecified,
wantExplicit: false,
},
{
name: "uppercase suffix is accepted",
input: "gpt-oss:20b:CLOUD",
wantBase: "gpt-oss:20b",
wantSource: ModelSourceCloud,
wantExplicit: true,
},
{
name: "no suffix",
input: "llama3.2",
wantBase: "llama3.2",
wantSource: ModelSourceUnspecified,
wantExplicit: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotBase, gotSource, gotExplicit := parseSourceSuffix(tt.input)
if gotBase != tt.wantBase {
t.Fatalf("base = %q, want %q", gotBase, tt.wantBase)
}
if gotSource != tt.wantSource {
t.Fatalf("source = %v, want %v", gotSource, tt.wantSource)
}
if gotExplicit != tt.wantExplicit {
t.Fatalf("explicit = %v, want %v", gotExplicit, tt.wantExplicit)
}
})
}
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/logutil"
)
@@ -919,7 +920,7 @@ func hasWebSearchTool(tools []anthropic.Tool) bool {
}
func isCloudModelName(name string) bool {
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
return modelref.HasExplicitCloudSource(name)
}
// extractQueryFromToolCall extracts the search query from a web_search tool call

460
server/cloud_proxy.go Normal file
View File

@@ -0,0 +1,460 @@
package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
internalcloud "github.com/ollama/ollama/internal/cloud"
)
const (
defaultCloudProxyBaseURL = "https://ollama.com:443"
defaultCloudProxySigningHost = "ollama.com"
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
)
var (
cloudProxyBaseURL = defaultCloudProxyBaseURL
cloudProxySigningHost = defaultCloudProxySigningHost
cloudProxySignRequest = signCloudProxyRequest
cloudProxySigninURL = signinURL
)
var hopByHopHeaders = map[string]struct{}{
"connection": {},
"content-length": {},
"proxy-connection": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"transfer-encoding": {},
"upgrade": {},
}
func init() {
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
if err != nil {
slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
return
}
cloudProxyBaseURL = baseURL
cloudProxySigningHost = signingHost
if overridden {
slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
}
}
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request.Method != http.MethodPost {
c.Next()
return
}
// TODO(drifkin): Avoid full-body buffering here for model detection.
// A future optimization can parse just enough JSON to read "model" (and
// optionally short-circuit cloud-disabled explicit-cloud requests) while
// preserving raw passthrough semantics.
body, err := readRequestBody(c.Request)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.Abort()
return
}
model, ok := extractModelField(body)
if !ok {
c.Next()
return
}
modelRef, err := parseAndValidateModelRef(model)
if err != nil || modelRef.Source != modelSourceCloud {
c.Next()
return
}
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.Abort()
return
}
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
if c.Request.URL.Path == "/v1/messages" {
if hasAnthropicWebSearchTool(body) {
c.Set(legacyCloudAnthropicKey, true)
c.Next()
return
}
}
proxyCloudRequest(c, normalizedBody, disabledOperation)
c.Abort()
}
}
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
return func(c *gin.Context) {
modelName := strings.TrimSpace(c.Param("model"))
if modelName == "" {
c.Next()
return
}
modelRef, err := parseAndValidateModelRef(modelName)
if err != nil || modelRef.Source != modelSourceCloud {
c.Next()
return
}
proxyPath := "/v1/models/" + modelRef.Base
proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation)
c.Abort()
}
}
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
// TEMP(drifkin): we currently split out this `WithPath` method because we are
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
// stop doing this, we can inline this method.
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
}
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
body, err := json.Marshal(payload)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
proxyCloudRequestWithPath(c, body, path, disabledOperation)
}
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
}
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
return
}
baseURL, err := url.Parse(cloudProxyBaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
targetURL := baseURL.ResolveReference(&url.URL{
Path: path,
RawQuery: c.Request.URL.RawQuery,
})
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
outReq.Header.Set("Content-Type", "application/json")
}
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
slog.Warn("cloud proxy signing failed", "error", err)
writeCloudUnauthorized(c)
return
}
// TODO(drifkin): Add phase-specific proxy timeouts.
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
// we should not enforce a short total timeout for long-lived responses.
resp, err := http.DefaultClient.Do(outReq)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
defer resp.Body.Close()
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
c.Status(resp.StatusCode)
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
c.Error(err) //nolint:errcheck
}
}
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
if len(body) == 0 {
return body, nil
}
var payload map[string]json.RawMessage
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
modelJSON, err := json.Marshal(model)
if err != nil {
return nil, err
}
payload["model"] = modelJSON
return json.Marshal(payload)
}
func readRequestBody(r *http.Request) ([]byte, error) {
if r.Body == nil {
return nil, nil
}
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
r.Body = io.NopCloser(bytes.NewReader(body))
return body, nil
}
func extractModelField(body []byte) (string, bool) {
if len(body) == 0 {
return "", false
}
var payload map[string]json.RawMessage
if err := json.Unmarshal(body, &payload); err != nil {
return "", false
}
raw, ok := payload["model"]
if !ok {
return "", false
}
var model string
if err := json.Unmarshal(raw, &model); err != nil {
return "", false
}
model = strings.TrimSpace(model)
return model, model != ""
}
func hasAnthropicWebSearchTool(body []byte) bool {
if len(body) == 0 {
return false
}
var payload struct {
Tools []struct {
Type string `json:"type"`
} `json:"tools"`
}
if err := json.Unmarshal(body, &payload); err != nil {
return false
}
for _, tool := range payload.Tools {
if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") {
return true
}
}
return false
}
func writeCloudUnauthorized(c *gin.Context) {
signinURL, err := cloudProxySigninURL()
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL})
}
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
return nil
}
ts := strconv.FormatInt(time.Now().Unix(), 10)
challenge := buildCloudSignatureChallenge(req, ts)
signature, err := auth.Sign(ctx, []byte(challenge))
if err != nil {
return err
}
req.Header.Set("Authorization", signature)
return nil
}
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
query := req.URL.Query()
query.Set("ts", ts)
req.URL.RawQuery = query.Encode()
return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI())
}
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
baseURL = defaultCloudProxyBaseURL
signingHost = defaultCloudProxySigningHost
rawOverride = strings.TrimSpace(rawOverride)
if rawOverride == "" {
return baseURL, signingHost, false, nil
}
u, err := url.Parse(rawOverride)
if err != nil {
return "", "", false, fmt.Errorf("invalid URL: %w", err)
}
if u.Scheme == "" || u.Host == "" {
return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
}
if u.User != nil {
return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
}
if u.Path != "" && u.Path != "/" {
return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
}
if u.RawQuery != "" || u.Fragment != "" {
return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
}
host := u.Hostname()
if host == "" {
return "", "", false, fmt.Errorf("invalid URL: host is required")
}
loopback := isLoopbackHost(host)
if runMode == gin.ReleaseMode && !loopback {
return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
}
if !loopback && !strings.EqualFold(u.Scheme, "https") {
return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
}
u.Path = ""
u.RawPath = ""
u.RawQuery = ""
u.Fragment = ""
return u.String(), strings.ToLower(host), true, nil
}
func isLoopbackHost(host string) bool {
if strings.EqualFold(host, "localhost") {
return true
}
ip := net.ParseIP(host)
return ip != nil && ip.IsLoopback()
}
func copyProxyRequestHeaders(dst, src http.Header) {
connectionTokens := connectionHeaderTokens(src)
for key, values := range src {
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
continue
}
dst.Del(key)
for _, value := range values {
dst.Add(key, value)
}
}
}
func copyProxyResponseHeaders(dst, src http.Header) {
connectionTokens := connectionHeaderTokens(src)
for key, values := range src {
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
continue
}
dst.Del(key)
for _, value := range values {
dst.Add(key, value)
}
}
}
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
flusher, canFlush := dst.(http.Flusher)
buf := make([]byte, 32*1024)
for {
n, err := src.Read(buf)
if n > 0 {
if _, writeErr := dst.Write(buf[:n]); writeErr != nil {
return writeErr
}
if canFlush {
// TODO(drifkin): Consider conditional flushing so non-streaming
// responses don't flush every write and can optimize throughput.
flusher.Flush()
}
}
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
func isHopByHopHeader(name string) bool {
_, ok := hopByHopHeaders[strings.ToLower(name)]
return ok
}
func connectionHeaderTokens(header http.Header) map[string]struct{} {
tokens := map[string]struct{}{}
for _, raw := range header.Values("Connection") {
for _, token := range strings.Split(raw, ",") {
token = strings.TrimSpace(strings.ToLower(token))
if token == "" {
continue
}
tokens[token] = struct{}{}
}
}
return tokens
}
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
if len(tokens) == 0 {
return false
}
_, ok := tokens[strings.ToLower(name)]
return ok
}

154
server/cloud_proxy_test.go Normal file
View File

@@ -0,0 +1,154 @@
package server
import (
"net/http"
"testing"
"github.com/gin-gonic/gin"
)
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
src := http.Header{}
src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop")
src.Add("X-Trace-Hop", "drop-me")
src.Add("X-Alt-Hop", "drop-me-too")
src.Add("Keep-Alive", "timeout=5")
src.Add("X-End-To-End", "keep-me")
dst := http.Header{}
copyProxyRequestHeaders(dst, src)
if got := dst.Get("Connection"); got != "" {
t.Fatalf("expected Connection to be stripped, got %q", got)
}
if got := dst.Get("Keep-Alive"); got != "" {
t.Fatalf("expected Keep-Alive to be stripped, got %q", got)
}
if got := dst.Get("X-Trace-Hop"); got != "" {
t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got)
}
if got := dst.Get("X-Alt-Hop"); got != "" {
t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got)
}
if got := dst.Get("X-End-To-End"); got != "keep-me" {
t.Fatalf("expected X-End-To-End to be forwarded, got %q", got)
}
}
func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) {
src := http.Header{}
src.Add("Connection", "X-Upstream-Hop")
src.Add("X-Upstream-Hop", "drop-me")
src.Add("Content-Type", "application/json")
src.Add("X-Server-Trace", "keep-me")
dst := http.Header{}
copyProxyResponseHeaders(dst, src)
if got := dst.Get("Connection"); got != "" {
t.Fatalf("expected Connection to be stripped, got %q", got)
}
if got := dst.Get("X-Upstream-Hop"); got != "" {
t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got)
}
if got := dst.Get("Content-Type"); got != "application/json" {
t.Fatalf("expected Content-Type to be forwarded, got %q", got)
}
if got := dst.Get("X-Server-Trace"); got != "keep-me" {
t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got)
}
}
func TestResolveCloudProxyBaseURL_Default(t *testing.T) {
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if overridden {
t.Fatal("expected override=false for empty input")
}
if baseURL != defaultCloudProxyBaseURL {
t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL)
}
if signingHost != defaultCloudProxySigningHost {
t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost)
}
}
func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) {
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !overridden {
t.Fatal("expected override=true")
}
if baseURL != "http://localhost:8080" {
t.Fatalf("unexpected base URL: %q", baseURL)
}
if signingHost != "localhost" {
t.Fatalf("unexpected signing host: %q", signingHost)
}
}
func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) {
_, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode)
if err == nil {
t.Fatal("expected error for non-loopback override in release mode")
}
}
func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) {
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !overridden {
t.Fatal("expected override=true")
}
if baseURL != "https://example.com:8443" {
t.Fatalf("unexpected base URL: %q", baseURL)
}
if signingHost != "example.com" {
t.Fatalf("unexpected signing host: %q", signingHost)
}
}
func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) {
_, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode)
if err == nil {
t.Fatal("expected error for non-loopback http override in dev mode")
}
}
func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
got := buildCloudSignatureChallenge(req, "123")
want := "POST,/v1/messages?beta=true&foo=bar&ts=123"
if got != want {
t.Fatalf("challenge mismatch: got %q want %q", got, want)
}
if req.URL.RawQuery != "beta=true&foo=bar&ts=123" {
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
}
}
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
got := buildCloudSignatureChallenge(req, "123")
want := "POST,/v1/messages?beta=true&ts=123"
if got != want {
t.Fatalf("challenge mismatch: got %q want %q", got, want)
}
if req.URL.RawQuery != "beta=true&ts=123" {
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
}
}

View File

@@ -110,19 +110,26 @@ func (s *Server) CreateHandler(c *gin.Context) {
if r.From != "" {
slog.Debug("create model from model name", "from", r.From)
fromName := model.ParseName(r.From)
if !fromName.IsValid() {
fromRef, err := parseAndValidateModelRef(r.From)
if err != nil {
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
return
}
if r.RemoteHost != "" {
ru, err := remoteURL(r.RemoteHost)
fromName := fromRef.Name
remoteHost := r.RemoteHost
if fromRef.Source == modelSourceCloud && remoteHost == "" {
remoteHost = cloudProxyBaseURL
}
if remoteHost != "" {
ru, err := remoteURL(remoteHost)
if err != nil {
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
return
}
config.RemoteModel = r.From
config.RemoteModel = fromRef.Base
config.RemoteHost = ru
remote = true
} else {

81
server/model_resolver.go Normal file
View File

@@ -0,0 +1,81 @@
package server
import (
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/types/model"
)
type modelSource = modelref.ModelSource
const (
modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified
modelSourceLocal modelSource = modelref.ModelSourceLocal
modelSourceCloud modelSource = modelref.ModelSourceCloud
)
var (
errConflictingModelSource = modelref.ErrConflictingSourceSuffix
errModelRequired = modelref.ErrModelRequired
)
type parsedModelRef struct {
// Original is the caller-provided model string before source parsing.
// Example: "gpt-oss:20b:cloud".
Original string
// Base is the model string after source suffix normalization.
// Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b".
Base string
// Name is Base parsed as a fully-qualified model.Name with defaults applied.
// Example: "registry.ollama.ai/library/gpt-oss:20b".
Name model.Name
// Source captures explicit source intent from the original input.
// Example: "gpt-oss:20b:cloud" -> modelSourceCloud.
Source modelSource
}
func parseAndValidateModelRef(raw string) (parsedModelRef, error) {
var zero parsedModelRef
parsed, err := modelref.ParseRef(raw)
if err != nil {
return zero, err
}
name := model.ParseName(parsed.Base)
if !name.IsValid() {
return zero, model.Unqualified(name)
}
return parsedModelRef{
Original: parsed.Original,
Base: parsed.Base,
Name: name,
Source: parsed.Source,
}, nil
}
func parseNormalizePullModelRef(raw string) (parsedModelRef, error) {
var zero parsedModelRef
parsedRef, err := modelref.ParseRef(raw)
if err != nil {
return zero, err
}
normalizedName, _, err := modelref.NormalizePullName(raw)
if err != nil {
return zero, err
}
name := model.ParseName(normalizedName)
if !name.IsValid() {
return zero, model.Unqualified(name)
}
return parsedModelRef{
Original: parsedRef.Original,
Base: normalizedName,
Name: name,
Source: parsedRef.Source,
}, nil
}

View File

@@ -0,0 +1,170 @@
package server
import (
"errors"
"strings"
"testing"
)
func TestParseModelSelector(t *testing.T) {
t.Run("cloud suffix", func(t *testing.T) {
got, err := parseAndValidateModelRef("gpt-oss:20b:cloud")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceCloud {
t.Fatalf("expected source cloud, got %v", got.Source)
}
if got.Base != "gpt-oss:20b" {
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
}
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" {
t.Fatalf("unexpected resolved name: %q", got.Name.String())
}
})
t.Run("legacy cloud suffix", func(t *testing.T) {
got, err := parseAndValidateModelRef("gpt-oss:20b-cloud")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceCloud {
t.Fatalf("expected source cloud, got %v", got.Source)
}
if got.Base != "gpt-oss:20b" {
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
}
})
t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) {
got, err := parseAndValidateModelRef("my-cloud-model")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceUnspecified {
t.Fatalf("expected source unspecified, got %v", got.Source)
}
if got.Base != "my-cloud-model" {
t.Fatalf("expected base my-cloud-model, got %q", got.Base)
}
})
t.Run("local suffix", func(t *testing.T) {
got, err := parseAndValidateModelRef("qwen3:8b:local")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceLocal {
t.Fatalf("expected source local, got %v", got.Source)
}
if got.Base != "qwen3:8b" {
t.Fatalf("expected base qwen3:8b, got %q", got.Base)
}
})
t.Run("conflicting source suffixes fail", func(t *testing.T) {
_, err := parseAndValidateModelRef("foo:cloud:local")
if !errors.Is(err, errConflictingModelSource) {
t.Fatalf("expected errConflictingModelSource, got %v", err)
}
})
t.Run("unspecified source", func(t *testing.T) {
got, err := parseAndValidateModelRef("llama3")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceUnspecified {
t.Fatalf("expected source unspecified, got %v", got.Source)
}
if got.Name.Tag != "latest" {
t.Fatalf("expected default latest tag, got %q", got.Name.Tag)
}
})
t.Run("unknown suffix is treated as tag", func(t *testing.T) {
got, err := parseAndValidateModelRef("gpt-oss:clod")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceUnspecified {
t.Fatalf("expected source unspecified, got %v", got.Source)
}
if got.Name.Tag != "clod" {
t.Fatalf("expected tag clod, got %q", got.Name.Tag)
}
})
t.Run("empty model fails", func(t *testing.T) {
_, err := parseAndValidateModelRef("")
if !errors.Is(err, errModelRequired) {
t.Fatalf("expected errModelRequired, got %v", err)
}
})
t.Run("invalid model fails", func(t *testing.T) {
_, err := parseAndValidateModelRef("::cloud")
if err == nil {
t.Fatal("expected error for invalid model")
}
if !strings.Contains(err.Error(), "unqualified") {
t.Fatalf("expected unqualified model error, got %v", err)
}
})
}
func TestParsePullModelRef(t *testing.T) {
t.Run("explicit local is normalized", func(t *testing.T) {
got, err := parseNormalizePullModelRef("gpt-oss:20b:local")
if err != nil {
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
}
if got.Source != modelSourceLocal {
t.Fatalf("expected source local, got %v", got.Source)
}
if got.Base != "gpt-oss:20b" {
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
}
})
t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) {
got, err := parseNormalizePullModelRef("gpt-oss:20b:cloud")
if err != nil {
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
}
if got.Base != "gpt-oss:20b-cloud" {
t.Fatalf("expected base gpt-oss:20b-cloud, got %q", got.Base)
}
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" {
t.Fatalf("unexpected resolved name: %q", got.Name.String())
}
})
t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) {
got, err := parseNormalizePullModelRef("qwen3:cloud")
if err != nil {
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
}
if got.Base != "qwen3:cloud" {
t.Fatalf("expected base qwen3:cloud, got %q", got.Base)
}
if got.Name.String() != "registry.ollama.ai/library/qwen3:cloud" {
t.Fatalf("unexpected resolved name: %q", got.Name.String())
}
})
}

View File

@@ -64,6 +64,17 @@ const (
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
)
func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) {
switch {
case errors.Is(err, errConflictingModelSource):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, model.ErrUnqualifiedName):
c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
default:
c.JSON(fallbackStatus, gin.H{"error": fallbackMessage})
}
}
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
@@ -196,14 +207,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
// what the API currently returns until we can change it.
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
return
}
if modelRef.Source == modelSourceCloud {
// TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the
// original body (modulo model name normalization) is sent to cloud.
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
name := modelRef.Name
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -237,6 +256,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
@@ -670,6 +694,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
var input []string
switch i := req.Input.(type) {
@@ -692,7 +728,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
}
name, err := getExistingName(model.ParseName(req.Model))
name, err := getExistingName(modelRef.Name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
@@ -839,12 +875,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
name := modelRef.Name
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
@@ -886,12 +930,19 @@ func (s *Server) PullHandler(c *gin.Context) {
return
}
name := model.ParseName(cmp.Or(req.Model, req.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
// TEMP(drifkin): we're temporarily allowing to continue pulling cloud model
// stub-files until we integrate cloud models into `/api/tags` (in which case
// this roundabout way of "adding" cloud models won't be needed anymore). So
// right here normalize any `:cloud` models into the legacy-style suffixes
// `:<tag>-cloud` and `:cloud`
modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name))
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg)
return
}
name := modelRef.Name
name, err = getExistingName(name)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -1018,13 +1069,20 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name))
if err != nil {
switch {
case errors.Is(err, errConflictingModelSource):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, model.ErrUnqualifiedName):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
return
}
n, err := getExistingName(n)
n, err := getExistingName(modelRef.Name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
return
@@ -1073,6 +1131,20 @@ func (s *Server) ShowHandler(c *gin.Context) {
return
}
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, err.Error())
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable)
return
}
req.Model = modelRef.Base
resp, err := GetModelInfo(req)
if err != nil {
var statusErr api.StatusError
@@ -1089,6 +1161,11 @@ func (s *Server) ShowHandler(c *gin.Context) {
return
}
if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)})
return
}
c.JSON(http.StatusOK, resp)
}
@@ -1625,18 +1702,20 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/embeddings", s.EmbeddingsHandler)
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
// parents on v1 request families while preserving this explicit :cloud passthrough.
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoints
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
if rc != nil {
// wrap old with new
@@ -1995,12 +2074,24 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
if c.GetBool(legacyCloudAnthropicKey) {
proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable)
return
}
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
name := modelRef.Name
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -2032,6 +2123,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
// expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)

File diff suppressed because it is too large Load Diff

View File

@@ -794,6 +794,43 @@ func TestCreateAndShowRemoteModel(t *testing.T) {
fmt.Printf("resp = %#v\n", resp)
}
func TestCreateFromCloudSourceSuffix(t *testing.T) {
gin.SetMode(gin.TestMode)
var s Server
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-cloud-from-suffix",
From: "gpt-oss:20b:cloud",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, got %d", w.Code)
}
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test-cloud-from-suffix"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, got %d", w.Code)
}
var resp api.ShowResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.RemoteHost != "https://ollama.com:443" {
t.Fatalf("expected remote host https://ollama.com:443, got %q", resp.RemoteHost)
}
if resp.RemoteModel != "gpt-oss:20b" {
t.Fatalf("expected remote model gpt-oss:20b, got %q", resp.RemoteModel)
}
}
func TestCreateLicenses(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -111,3 +111,32 @@ func TestDeleteDuplicateLayers(t *testing.T) {
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
}
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
_, digest := createBinFile(t, nil, nil)
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "gpt-oss:20b-cloud",
Files: map[string]string{"test.gguf": digest},
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
})
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/ollama/ollama/api"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/model"
@@ -43,7 +44,7 @@ const (
// isLocalModel checks if the model is running locally (not a cloud model).
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
return !strings.HasSuffix(modelName, "-cloud")
return !modelref.HasExplicitCloudSource(modelName)
}
// isLocalServer checks if connecting to a local Ollama server.

View File

@@ -22,12 +22,22 @@ func TestIsLocalModel(t *testing.T) {
},
{
name: "cloud model",
modelName: "gpt-4-cloud",
modelName: "gpt-oss:latest-cloud",
expected: false,
},
{
name: "cloud model with :cloud suffix",
modelName: "gpt-oss:cloud",
expected: false,
},
{
name: "cloud model with version",
modelName: "claude-3-cloud",
modelName: "gpt-oss:20b-cloud",
expected: false,
},
{
name: "cloud model with version and :cloud suffix",
modelName: "gpt-oss:20b:cloud",
expected: false,
},
{
@@ -134,7 +144,7 @@ func TestTruncateToolOutput(t *testing.T) {
{
name: "long output cloud model - uses 10k limit",
output: string(localLimitOutput), // 20k chars, under 10k token limit
modelName: "gpt-4-cloud",
modelName: "gpt-oss:latest-cloud",
host: "",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
@@ -142,7 +152,7 @@ func TestTruncateToolOutput(t *testing.T) {
{
name: "very long output cloud model - trimmed at 10k",
output: string(defaultLimitOutput),
modelName: "gpt-4-cloud",
modelName: "gpt-oss:latest-cloud",
host: "",
shouldTrim: true,
expectedLimit: defaultTokenLimit,