mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 17:54:03 +02:00
Reapply "don't require pulling stubs for cloud models" again (#14608)
* Revert "Revert "Reapply "don't require pulling stubs for cloud models"" (#14606)"
This reverts commit 39982a954e.
* fix test + do cloud lookup only when seeing cloud models
---------
Co-authored-by: ParthSareen <parth.sareen@ollama.com>
This commit is contained in:
44
cmd/cmd.go
44
cmd/cmd.go
@@ -41,6 +41,7 @@ import (
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
@@ -418,12 +419,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
|
||||
|
||||
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
||||
return err
|
||||
} else if info.RemoteHost != "" {
|
||||
} else if info.RemoteHost != "" || requestedCloud {
|
||||
// Cloud model, no need to load/unload
|
||||
|
||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
|
||||
// Check if user is signed in for ollama.com cloud models
|
||||
if isCloud {
|
||||
@@ -434,10 +437,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
|
||||
if opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
remoteModel := info.RemoteModel
|
||||
if remoteModel == "" {
|
||||
remoteModel = opts.Model
|
||||
}
|
||||
if isCloud {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -509,6 +516,20 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): consolidate with TUI signin flow
|
||||
func handleCloudAuthorizationError(err error) bool {
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||
if authErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, authErr.SigninURL)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
@@ -605,12 +626,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
requestedCloud := modelref.HasExplicitCloudSource(name)
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
var se api.StatusError
|
||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||
if requestedCloud {
|
||||
return nil, err
|
||||
}
|
||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -619,6 +644,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return info, err
|
||||
}()
|
||||
if err != nil {
|
||||
if handleCloudAuthorizationError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -713,7 +741,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return generate(cmd, opts)
|
||||
if err := generate(cmd, opts); err != nil {
|
||||
if handleCloudAuthorizationError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
174
cmd/cmd_test.go
174
cmd/cmd_test.go
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -705,6 +706,139 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
|
||||
var generateCalled bool
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
generateCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
oldStdout := os.Stdout
|
||||
readOut, writeOut, _ := os.Pipe()
|
||||
os.Stdout = writeOut
|
||||
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
|
||||
_ = writeOut.Close()
|
||||
var out bytes.Buffer
|
||||
_, _ = io.Copy(&out, readOut)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if generateCalled {
|
||||
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
oldStdout := os.Stdout
|
||||
readOut, writeOut, _ := os.Pipe()
|
||||
os.Stdout = writeOut
|
||||
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
|
||||
_ = writeOut.Close()
|
||||
var out bytes.Buffer
|
||||
_, _ = io.Copy(&out, readOut)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelfileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1664,20 +1798,26 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
remoteHost string
|
||||
remoteModel string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ollama.com cloud model - user signed in",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://ollama.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
},
|
||||
{
|
||||
name: "ollama.com cloud model - user not signed in",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://ollama.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusUnauthorized,
|
||||
whoamiResp: map[string]string{
|
||||
"error": "unauthorized",
|
||||
@@ -1687,7 +1827,33 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "non-ollama.com remote - no auth check",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://other-remote.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
{
|
||||
name: "explicit :cloud model - auth check without remote metadata",
|
||||
model: "kimi-k2.5:cloud",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
},
|
||||
{
|
||||
name: "explicit -cloud model - auth check without remote metadata",
|
||||
model: "kimi-k2.5:latest-cloud",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
},
|
||||
{
|
||||
name: "dash cloud-like name without explicit source does not require auth",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
@@ -1702,7 +1868,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
RemoteHost: tt.remoteHost,
|
||||
RemoteModel: "test-model",
|
||||
RemoteModel: tt.remoteModel,
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -1715,6 +1881,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
case "/api/generate":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
@@ -1727,13 +1895,13 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
opts := &runOptions{
|
||||
Model: "test-cloud-model",
|
||||
Model: tt.model,
|
||||
ShowConnect: false,
|
||||
}
|
||||
|
||||
err := loadOrUnloadModel(cmd, opts)
|
||||
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") || modelref.HasExplicitCloudSource(tt.model) {
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
||||
}
|
||||
|
||||
@@ -107,15 +107,12 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
||||
return aliases, false, nil
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
if isCloudModelName(aliases["primary"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
return aliases, false, nil
|
||||
}
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
@@ -139,10 +136,8 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
if isCloudModelName(aliases["primary"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
@@ -233,6 +233,9 @@ func ModelExists(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if isCloudModelName(name) {
|
||||
return true
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -125,13 +124,12 @@ func (d *Droid) Edit(models []string) error {
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
maxOutput := 64000
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
maxOutput = l.Output
|
||||
}
|
||||
|
||||
@@ -1276,35 +1276,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
||||
|
||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
||||
// Cloud suffix normalization must also work since integrations may see either
|
||||
// :cloud or -cloud model names.
|
||||
// value that would be used for maxOutputTokens when the selected model uses
|
||||
// the explicit :cloud source tag.
|
||||
for name, expected := range cloudModelLimits {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(name)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
||||
}
|
||||
if l.Output != expected.Output {
|
||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||
}
|
||||
// Also verify :cloud suffix lookup
|
||||
cloudName := name + ":cloud"
|
||||
l2, ok := lookupCloudModelLimit(cloudName)
|
||||
l, ok := lookupCloudModelLimit(cloudName)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||
}
|
||||
if l2.Output != expected.Output {
|
||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
||||
}
|
||||
// Also verify -cloud suffix lookup
|
||||
dashCloudName := name + "-cloud"
|
||||
l3, ok := lookupCloudModelLimit(dashCloudName)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", dashCloudName)
|
||||
}
|
||||
if l3.Output != expected.Output {
|
||||
t.Errorf("-cloud output = %d, want %d", l3.Output, expected.Output)
|
||||
if l.Output != expected.Output {
|
||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -326,12 +327,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
|
||||
|
||||
// If the selected model isn't installed, pull it first
|
||||
if !existingModels[selected] {
|
||||
if cloudModels[selected] {
|
||||
// Cloud models only pull a small manifest; no confirmation needed
|
||||
if err := pullModel(ctx, client, selected); err != nil {
|
||||
return "", fmt.Errorf("failed to pull %s: %w", selected, err)
|
||||
}
|
||||
} else {
|
||||
if !isCloudModelName(selected) {
|
||||
msg := fmt.Sprintf("Download %s?", selected)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return "", err
|
||||
@@ -526,7 +522,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
|
||||
var toPull []string
|
||||
for _, m := range selected {
|
||||
if !existingModels[m] {
|
||||
if !existingModels[m] && !isCloudModelName(m) {
|
||||
toPull = append(toPull, m)
|
||||
}
|
||||
}
|
||||
@@ -552,12 +548,28 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): consolidate pull logic from call sites
|
||||
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
||||
if existingModels[model] {
|
||||
if isCloudModelName(model) || existingModels[model] {
|
||||
return nil
|
||||
}
|
||||
msg := fmt.Sprintf("Download %s?", model)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return confirmAndPull(ctx, client, model)
|
||||
}
|
||||
|
||||
// TODO(parthsareen): pull this out to tui package
|
||||
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
|
||||
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
return nil
|
||||
}
|
||||
return confirmAndPull(ctx, client, model)
|
||||
}
|
||||
|
||||
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
@@ -569,26 +581,6 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): pull this out to tui package
|
||||
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
|
||||
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
}
|
||||
// Cloud models only pull a small manifest; skip the download confirmation
|
||||
// TODO(parthsareen): consolidate with cloud config changes
|
||||
if strings.HasSuffix(model, "cloud") {
|
||||
return pullModel(ctx, client, model)
|
||||
}
|
||||
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return pullModel(ctx, client, model)
|
||||
}
|
||||
|
||||
func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -733,10 +725,8 @@ func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, na
|
||||
}
|
||||
aliases["primary"] = model
|
||||
|
||||
if isCloudModel(ctx, client, model) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = model
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
aliases["fast"] = model
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
@@ -1022,7 +1012,7 @@ Examples:
|
||||
existingAliases = aliases
|
||||
|
||||
// Ensure cloud models are authenticated
|
||||
if isCloudModel(cmd.Context(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1211,7 +1201,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
// When user has no models, preserve recommended order.
|
||||
notInstalled := make(map[string]bool)
|
||||
for i := range items {
|
||||
if !existingModels[items[i].Name] {
|
||||
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
|
||||
notInstalled[items[i].Name] = true
|
||||
var parts []string
|
||||
if items[i].Description != "" {
|
||||
@@ -1305,7 +1295,8 @@ func IsCloudModelDisabled(ctx context.Context, name string) bool {
|
||||
}
|
||||
|
||||
func isCloudModelName(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
// TODO(drifkin): Replace this wrapper with inlining once things stabilize a bit
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
func filterCloudModels(existing []modelInfo) []modelInfo {
|
||||
|
||||
@@ -426,8 +426,14 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
|
||||
if strings.HasSuffix(item.Name, ":cloud") {
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
} else {
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -492,10 +498,14 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3:8b":
|
||||
case "qwen3:8b":
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "minimax-m2.5:cloud", "kimi-k2.5:cloud":
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -536,7 +546,13 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
|
||||
isCloud := strings.HasSuffix(item.Name, ":cloud")
|
||||
isInstalled := slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name)
|
||||
if isInstalled || isCloud {
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("installed or cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
} else {
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
@@ -1000,8 +1016,8 @@ func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for cloud models
|
||||
func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for explicit cloud models
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Error("confirm prompt should not be called for cloud models")
|
||||
@@ -1032,8 +1048,115 @@ func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
||||
}
|
||||
if !pullCalled {
|
||||
t.Error("expected pull to be called for cloud model without confirmation")
|
||||
if pullCalled {
|
||||
t.Error("expected pull not to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for explicit cloud models
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Error("confirm prompt should not be called for cloud models")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Error("expected pull not to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullIfNeeded_CloudModel_DoesNotPull(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Error("confirm prompt should not be called for cloud models")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
err := pullIfNeeded(context.Background(), nil, map[string]bool{}, "glm-5:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for cloud model, got %v", err)
|
||||
}
|
||||
|
||||
err = pullIfNeeded(context.Background(), nil, map[string]bool{}, "gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for cloud model with legacy suffix, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectModelsWithSelectors_CloudSelection_DoesNotPull(t *testing.T) {
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
case "/api/tags":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"models":[]}`)
|
||||
case "/api/me":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"name":"test-user"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
single := func(title string, items []ModelItem, current string) (string, error) {
|
||||
for _, item := range items {
|
||||
if item.Name == "glm-5:cloud" {
|
||||
return item.Name, nil
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected glm-5:cloud in selector items, got %v", items)
|
||||
return "", nil
|
||||
}
|
||||
|
||||
multi := func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
return nil, fmt.Errorf("multi selector should not be called")
|
||||
}
|
||||
|
||||
selected, err := selectModelsWithSelectors(context.Background(), "codex", "", single, multi)
|
||||
if err != nil {
|
||||
t.Fatalf("selectModelsWithSelectors returned error: %v", err)
|
||||
}
|
||||
if !slices.Equal(selected, []string{"glm-5:cloud"}) {
|
||||
t.Fatalf("unexpected selected models: %v", selected)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Fatal("expected cloud selection to skip pull")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
@@ -26,14 +26,13 @@ type cloudModelLimit struct {
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It normalizes common cloud suffixes before checking the shared limit map.
|
||||
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
// TODO(parthsareen): migrate to using cloud check instead.
|
||||
for _, suffix := range []string{"-cloud", ":cloud"} {
|
||||
name = strings.TrimSuffix(name, suffix)
|
||||
}
|
||||
if l, ok := cloudModelLimits[name]; ok {
|
||||
return l, true
|
||||
base, stripped := modelref.StripCloudSourceTag(name)
|
||||
if stripped {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
@@ -150,8 +149,6 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
@@ -161,7 +158,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
@@ -175,7 +172,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
|
||||
@@ -714,16 +714,17 @@ func TestLookupCloudModelLimit(t *testing.T) {
|
||||
wantContext int
|
||||
wantOutput int
|
||||
}{
|
||||
{"glm-4.7", true, 202_752, 131_072},
|
||||
{"glm-4.7", false, 0, 0},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"glm-5:cloud", true, 202_752, 131_072},
|
||||
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||
{"kimi-k2.5", true, 262_144, 262_144},
|
||||
{"kimi-k2.5", false, 0, 0},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||
{"deepseek-v3.2", false, 0, 0},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
||||
{"qwen3-coder:480b", false, 0, 0},
|
||||
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
|
||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -147,7 +148,13 @@ type signInCheckMsg struct {
|
||||
type clearStatusMsg struct{}
|
||||
|
||||
func (m *model) modelExists(name string) bool {
|
||||
if m.availableModels == nil || name == "" {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if modelref.HasExplicitCloudSource(name) {
|
||||
return true
|
||||
}
|
||||
if m.availableModels == nil {
|
||||
return false
|
||||
}
|
||||
if m.availableModels[name] {
|
||||
@@ -209,7 +216,7 @@ func (m *model) openMultiModelModal(integration string) {
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
func cloudStatusDisabled(client *api.Client) bool {
|
||||
|
||||
115
internal/modelref/modelref.go
Normal file
115
internal/modelref/modelref.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package modelref
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelSource uint8
|
||||
|
||||
const (
|
||||
ModelSourceUnspecified ModelSource = iota
|
||||
ModelSourceLocal
|
||||
ModelSourceCloud
|
||||
)
|
||||
|
||||
var (
|
||||
ErrConflictingSourceSuffix = errors.New("use either :local or :cloud, not both")
|
||||
ErrModelRequired = errors.New("model is required")
|
||||
)
|
||||
|
||||
type ParsedRef struct {
|
||||
Original string
|
||||
Base string
|
||||
Source ModelSource
|
||||
}
|
||||
|
||||
func ParseRef(raw string) (ParsedRef, error) {
|
||||
var zero ParsedRef
|
||||
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return zero, ErrModelRequired
|
||||
}
|
||||
|
||||
base, source, explicit := parseSourceSuffix(raw)
|
||||
if explicit {
|
||||
if _, _, nested := parseSourceSuffix(base); nested {
|
||||
return zero, fmt.Errorf("%w: %q", ErrConflictingSourceSuffix, raw)
|
||||
}
|
||||
}
|
||||
|
||||
return ParsedRef{
|
||||
Original: raw,
|
||||
Base: base,
|
||||
Source: source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func HasExplicitCloudSource(raw string) bool {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
return err == nil && parsedRef.Source == ModelSourceCloud
|
||||
}
|
||||
|
||||
func HasExplicitLocalSource(raw string) bool {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
return err == nil && parsedRef.Source == ModelSourceLocal
|
||||
}
|
||||
|
||||
func StripCloudSourceTag(raw string) (string, bool) {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
if err != nil || parsedRef.Source != ModelSourceCloud {
|
||||
return strings.TrimSpace(raw), false
|
||||
}
|
||||
|
||||
return parsedRef.Base, true
|
||||
}
|
||||
|
||||
func NormalizePullName(raw string) (string, bool, error) {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
if parsedRef.Source != ModelSourceCloud {
|
||||
return parsedRef.Base, false, nil
|
||||
}
|
||||
|
||||
return toLegacyCloudPullName(parsedRef.Base), true, nil
|
||||
}
|
||||
|
||||
func toLegacyCloudPullName(base string) string {
|
||||
if hasExplicitTag(base) {
|
||||
return base + "-cloud"
|
||||
}
|
||||
|
||||
return base + ":cloud"
|
||||
}
|
||||
|
||||
func hasExplicitTag(name string) bool {
|
||||
lastSlash := strings.LastIndex(name, "/")
|
||||
lastColon := strings.LastIndex(name, ":")
|
||||
return lastColon > lastSlash
|
||||
}
|
||||
|
||||
func parseSourceSuffix(raw string) (string, ModelSource, bool) {
|
||||
idx := strings.LastIndex(raw, ":")
|
||||
if idx >= 0 {
|
||||
suffixRaw := strings.TrimSpace(raw[idx+1:])
|
||||
suffix := strings.ToLower(suffixRaw)
|
||||
|
||||
switch suffix {
|
||||
case "cloud":
|
||||
return raw[:idx], ModelSourceCloud, true
|
||||
case "local":
|
||||
return raw[:idx], ModelSourceLocal, true
|
||||
}
|
||||
|
||||
if !strings.Contains(suffixRaw, "/") && strings.HasSuffix(suffix, "-cloud") {
|
||||
return raw[:idx+1] + suffixRaw[:len(suffixRaw)-len("-cloud")], ModelSourceCloud, true
|
||||
}
|
||||
}
|
||||
|
||||
return raw, ModelSourceUnspecified, false
|
||||
}
|
||||
268
internal/modelref/modelref_test.go
Normal file
268
internal/modelref/modelref_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package modelref
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseRef(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantBase string
|
||||
wantSource ModelSource
|
||||
wantErr error
|
||||
wantCloud bool
|
||||
wantLocal bool
|
||||
wantStripped string
|
||||
wantStripOK bool
|
||||
}{
|
||||
{
|
||||
name: "cloud suffix",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantCloud: true,
|
||||
wantStripped: "gpt-oss:20b",
|
||||
wantStripOK: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantCloud: true,
|
||||
wantStripped: "gpt-oss:20b",
|
||||
wantStripOK: true,
|
||||
},
|
||||
{
|
||||
name: "local suffix",
|
||||
input: "qwen3:8b:local",
|
||||
wantBase: "qwen3:8b",
|
||||
wantSource: ModelSourceLocal,
|
||||
wantLocal: true,
|
||||
wantStripped: "qwen3:8b:local",
|
||||
},
|
||||
{
|
||||
name: "no source suffix",
|
||||
input: "llama3.2",
|
||||
wantBase: "llama3.2",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "llama3.2",
|
||||
},
|
||||
{
|
||||
name: "bare cloud name is not explicit cloud",
|
||||
input: "my-cloud-model",
|
||||
wantBase: "my-cloud-model",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "my-cloud-model",
|
||||
},
|
||||
{
|
||||
name: "slash in suffix blocks legacy cloud parsing",
|
||||
input: "foo:bar-cloud/baz",
|
||||
wantBase: "foo:bar-cloud/baz",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "foo:bar-cloud/baz",
|
||||
},
|
||||
{
|
||||
name: "conflicting source suffixes",
|
||||
input: "foo:cloud:local",
|
||||
wantErr: ErrConflictingSourceSuffix,
|
||||
wantSource: ModelSourceUnspecified,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: " ",
|
||||
wantErr: ErrModelRequired,
|
||||
wantSource: ModelSourceUnspecified,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseRef(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("ParseRef(%q) error = %v, want %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("ParseRef(%q) returned error: %v", tt.input, err)
|
||||
}
|
||||
|
||||
if got.Base != tt.wantBase {
|
||||
t.Fatalf("base = %q, want %q", got.Base, tt.wantBase)
|
||||
}
|
||||
|
||||
if got.Source != tt.wantSource {
|
||||
t.Fatalf("source = %v, want %v", got.Source, tt.wantSource)
|
||||
}
|
||||
|
||||
if HasExplicitCloudSource(tt.input) != tt.wantCloud {
|
||||
t.Fatalf("HasExplicitCloudSource(%q) = %v, want %v", tt.input, HasExplicitCloudSource(tt.input), tt.wantCloud)
|
||||
}
|
||||
|
||||
if HasExplicitLocalSource(tt.input) != tt.wantLocal {
|
||||
t.Fatalf("HasExplicitLocalSource(%q) = %v, want %v", tt.input, HasExplicitLocalSource(tt.input), tt.wantLocal)
|
||||
}
|
||||
|
||||
stripped, ok := StripCloudSourceTag(tt.input)
|
||||
if ok != tt.wantStripOK {
|
||||
t.Fatalf("StripCloudSourceTag(%q) ok = %v, want %v", tt.input, ok, tt.wantStripOK)
|
||||
}
|
||||
if stripped != tt.wantStripped {
|
||||
t.Fatalf("StripCloudSourceTag(%q) base = %q, want %q", tt.input, stripped, tt.wantStripped)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizePullName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantName string
|
||||
wantCloud bool
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "explicit local strips source",
|
||||
input: "gpt-oss:20b:local",
|
||||
wantName: "gpt-oss:20b",
|
||||
},
|
||||
{
|
||||
name: "explicit cloud with size maps to legacy dash cloud tag",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantName: "gpt-oss:20b-cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud with size remains stable",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantName: "gpt-oss:20b-cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "explicit cloud without tag maps to cloud tag",
|
||||
input: "qwen3:cloud",
|
||||
wantName: "qwen3:cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "host port without tag keeps host port and appends cloud tag",
|
||||
input: "localhost:11434/library/foo:cloud",
|
||||
wantName: "localhost:11434/library/foo:cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "conflicting source suffixes fail",
|
||||
input: "foo:cloud:local",
|
||||
wantErr: ErrConflictingSourceSuffix,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotName, gotCloud, err := NormalizePullName(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("NormalizePullName(%q) error = %v, want %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NormalizePullName(%q) returned error: %v", tt.input, err)
|
||||
}
|
||||
|
||||
if gotName != tt.wantName {
|
||||
t.Fatalf("normalized name = %q, want %q", gotName, tt.wantName)
|
||||
}
|
||||
if gotCloud != tt.wantCloud {
|
||||
t.Fatalf("cloud = %v, want %v", gotCloud, tt.wantCloud)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSourceSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantBase string
|
||||
wantSource ModelSource
|
||||
wantExplicit bool
|
||||
}{
|
||||
{
|
||||
name: "explicit cloud suffix",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "explicit local suffix",
|
||||
input: "qwen3:8b:local",
|
||||
wantBase: "qwen3:8b",
|
||||
wantSource: ModelSourceLocal,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix on tag",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix does not match model segment",
|
||||
input: "my-cloud-model",
|
||||
wantBase: "my-cloud-model",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix blocked when suffix includes slash",
|
||||
input: "foo:bar-cloud/baz",
|
||||
wantBase: "foo:bar-cloud/baz",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "unknown suffix is not explicit source",
|
||||
input: "gpt-oss:clod",
|
||||
wantBase: "gpt-oss:clod",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "uppercase suffix is accepted",
|
||||
input: "gpt-oss:20b:CLOUD",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "no suffix",
|
||||
input: "llama3.2",
|
||||
wantBase: "llama3.2",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBase, gotSource, gotExplicit := parseSourceSuffix(tt.input)
|
||||
if gotBase != tt.wantBase {
|
||||
t.Fatalf("base = %q, want %q", gotBase, tt.wantBase)
|
||||
}
|
||||
if gotSource != tt.wantSource {
|
||||
t.Fatalf("source = %v, want %v", gotSource, tt.wantSource)
|
||||
}
|
||||
if gotExplicit != tt.wantExplicit {
|
||||
t.Fatalf("explicit = %v, want %v", gotExplicit, tt.wantExplicit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
@@ -919,7 +920,7 @@ func hasWebSearchTool(tools []anthropic.Tool) bool {
|
||||
}
|
||||
|
||||
func isCloudModelName(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
// extractQueryFromToolCall extracts the search query from a web_search tool call
|
||||
|
||||
460
server/cloud_proxy.go
Normal file
460
server/cloud_proxy.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCloudProxyBaseURL = "https://ollama.com:443"
|
||||
defaultCloudProxySigningHost = "ollama.com"
|
||||
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
|
||||
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
|
||||
)
|
||||
|
||||
var (
|
||||
cloudProxyBaseURL = defaultCloudProxyBaseURL
|
||||
cloudProxySigningHost = defaultCloudProxySigningHost
|
||||
cloudProxySignRequest = signCloudProxyRequest
|
||||
cloudProxySigninURL = signinURL
|
||||
)
|
||||
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
"connection": {},
|
||||
"content-length": {},
|
||||
"proxy-connection": {},
|
||||
"keep-alive": {},
|
||||
"proxy-authenticate": {},
|
||||
"proxy-authorization": {},
|
||||
"te": {},
|
||||
"trailer": {},
|
||||
"transfer-encoding": {},
|
||||
"upgrade": {},
|
||||
}
|
||||
|
||||
func init() {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
|
||||
if err != nil {
|
||||
slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
cloudProxyBaseURL = baseURL
|
||||
cloudProxySigningHost = signingHost
|
||||
|
||||
if overridden {
|
||||
slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Method != http.MethodPost {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(drifkin): Avoid full-body buffering here for model detection.
|
||||
// A future optimization can parse just enough JSON to read "model" (and
|
||||
// optionally short-circuit cloud-disabled explicit-cloud requests) while
|
||||
// preserving raw passthrough semantics.
|
||||
body, err := readRequestBody(c.Request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
model, ok := extractModelField(body)
|
||||
if !ok {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(model)
|
||||
if err != nil || modelRef.Source != modelSourceCloud {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
|
||||
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
|
||||
if c.Request.URL.Path == "/v1/messages" {
|
||||
if hasAnthropicWebSearchTool(body) {
|
||||
c.Set(legacyCloudAnthropicKey, true)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxyCloudRequest(c, normalizedBody, disabledOperation)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
modelName := strings.TrimSpace(c.Param("model"))
|
||||
if modelName == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(modelName)
|
||||
if err != nil || modelRef.Source != modelSourceCloud {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
proxyPath := "/v1/models/" + modelRef.Base
|
||||
proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
|
||||
// TEMP(drifkin): we currently split out this `WithPath` method because we are
|
||||
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
|
||||
// stop doing this, we can inline this method.
|
||||
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
proxyCloudRequestWithPath(c, body, path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
|
||||
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(cloudProxyBaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
targetURL := baseURL.ResolveReference(&url.URL{
|
||||
Path: path,
|
||||
RawQuery: c.Request.URL.RawQuery,
|
||||
})
|
||||
|
||||
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
|
||||
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
|
||||
outReq.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
|
||||
slog.Warn("cloud proxy signing failed", "error", err)
|
||||
writeCloudUnauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(drifkin): Add phase-specific proxy timeouts.
|
||||
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
|
||||
// we should not enforce a short total timeout for long-lived responses.
|
||||
resp, err := http.DefaultClient.Do(outReq)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
|
||||
c.Error(err) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
|
||||
if len(body) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelJSON, err := json.Marshal(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload["model"] = modelJSON
|
||||
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
func readRequestBody(r *http.Request) ([]byte, error) {
|
||||
if r.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func extractModelField(body []byte) (string, bool) {
|
||||
if len(body) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
raw, ok := payload["model"]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var model string
|
||||
if err := json.Unmarshal(raw, &model); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
model = strings.TrimSpace(model)
|
||||
return model, model != ""
|
||||
}
|
||||
|
||||
func hasAnthropicWebSearchTool(body []byte) bool {
|
||||
if len(body) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Tools []struct {
|
||||
Type string `json:"type"`
|
||||
} `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, tool := range payload.Tools {
|
||||
if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func writeCloudUnauthorized(c *gin.Context) {
|
||||
signinURL, err := cloudProxySigninURL()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL})
|
||||
}
|
||||
|
||||
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
|
||||
if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ts := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
challenge := buildCloudSignatureChallenge(req, ts)
|
||||
signature, err := auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", signature)
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
|
||||
query := req.URL.Query()
|
||||
query.Set("ts", ts)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI())
|
||||
}
|
||||
|
||||
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
|
||||
baseURL = defaultCloudProxyBaseURL
|
||||
signingHost = defaultCloudProxySigningHost
|
||||
|
||||
rawOverride = strings.TrimSpace(rawOverride)
|
||||
if rawOverride == "" {
|
||||
return baseURL, signingHost, false, nil
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawOverride)
|
||||
if err != nil {
|
||||
return "", "", false, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
|
||||
}
|
||||
if u.User != nil {
|
||||
return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
|
||||
}
|
||||
if u.Path != "" && u.Path != "/" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
|
||||
}
|
||||
if u.RawQuery != "" || u.Fragment != "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: host is required")
|
||||
}
|
||||
|
||||
loopback := isLoopbackHost(host)
|
||||
if runMode == gin.ReleaseMode && !loopback {
|
||||
return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
|
||||
}
|
||||
if !loopback && !strings.EqualFold(u.Scheme, "https") {
|
||||
return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
|
||||
}
|
||||
|
||||
u.Path = ""
|
||||
u.RawPath = ""
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
|
||||
return u.String(), strings.ToLower(host), true, nil
|
||||
}
|
||||
|
||||
func isLoopbackHost(host string) bool {
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func copyProxyRequestHeaders(dst, src http.Header) {
|
||||
connectionTokens := connectionHeaderTokens(src)
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Del(key)
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyProxyResponseHeaders(dst, src http.Header) {
|
||||
connectionTokens := connectionHeaderTokens(src)
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Del(key)
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
|
||||
flusher, canFlush := dst.(http.Flusher)
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := dst.Write(buf[:n]); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
if canFlush {
|
||||
// TODO(drifkin): Consider conditional flushing so non-streaming
|
||||
// responses don't flush every write and can optimize throughput.
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isHopByHopHeader(name string) bool {
|
||||
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func connectionHeaderTokens(header http.Header) map[string]struct{} {
|
||||
tokens := map[string]struct{}{}
|
||||
for _, raw := range header.Values("Connection") {
|
||||
for _, token := range strings.Split(raw, ",") {
|
||||
token = strings.TrimSpace(strings.ToLower(token))
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
tokens[token] = struct{}{}
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
|
||||
if len(tokens) == 0 {
|
||||
return false
|
||||
}
|
||||
_, ok := tokens[strings.ToLower(name)]
|
||||
return ok
|
||||
}
|
||||
154
server/cloud_proxy_test.go
Normal file
154
server/cloud_proxy_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop")
|
||||
src.Add("X-Trace-Hop", "drop-me")
|
||||
src.Add("X-Alt-Hop", "drop-me-too")
|
||||
src.Add("Keep-Alive", "timeout=5")
|
||||
src.Add("X-End-To-End", "keep-me")
|
||||
|
||||
dst := http.Header{}
|
||||
copyProxyRequestHeaders(dst, src)
|
||||
|
||||
if got := dst.Get("Connection"); got != "" {
|
||||
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("Keep-Alive"); got != "" {
|
||||
t.Fatalf("expected Keep-Alive to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Trace-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Alt-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-End-To-End"); got != "keep-me" {
|
||||
t.Fatalf("expected X-End-To-End to be forwarded, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "X-Upstream-Hop")
|
||||
src.Add("X-Upstream-Hop", "drop-me")
|
||||
src.Add("Content-Type", "application/json")
|
||||
src.Add("X-Server-Trace", "keep-me")
|
||||
|
||||
dst := http.Header{}
|
||||
copyProxyResponseHeaders(dst, src)
|
||||
|
||||
if got := dst.Get("Connection"); got != "" {
|
||||
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Upstream-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("expected Content-Type to be forwarded, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Server-Trace"); got != "keep-me" {
|
||||
t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_Default(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if overridden {
|
||||
t.Fatal("expected override=false for empty input")
|
||||
}
|
||||
if baseURL != defaultCloudProxyBaseURL {
|
||||
t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL)
|
||||
}
|
||||
if signingHost != defaultCloudProxySigningHost {
|
||||
t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !overridden {
|
||||
t.Fatal("expected override=true")
|
||||
}
|
||||
if baseURL != "http://localhost:8080" {
|
||||
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||
}
|
||||
if signingHost != "localhost" {
|
||||
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) {
|
||||
_, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-loopback override in release mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !overridden {
|
||||
t.Fatal("expected override=true")
|
||||
}
|
||||
if baseURL != "https://example.com:8443" {
|
||||
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||
}
|
||||
if signingHost != "example.com" {
|
||||
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) {
|
||||
_, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-loopback http override in dev mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
got := buildCloudSignatureChallenge(req, "123")
|
||||
want := "POST,/v1/messages?beta=true&foo=bar&ts=123"
|
||||
if got != want {
|
||||
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if req.URL.RawQuery != "beta=true&foo=bar&ts=123" {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
got := buildCloudSignatureChallenge(req, "123")
|
||||
want := "POST,/v1/messages?beta=true&ts=123"
|
||||
if got != want {
|
||||
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if req.URL.RawQuery != "beta=true&ts=123" {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
@@ -110,19 +110,26 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
|
||||
if r.From != "" {
|
||||
slog.Debug("create model from model name", "from", r.From)
|
||||
fromName := model.ParseName(r.From)
|
||||
if !fromName.IsValid() {
|
||||
fromRef, err := parseAndValidateModelRef(r.From)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
if r.RemoteHost != "" {
|
||||
ru, err := remoteURL(r.RemoteHost)
|
||||
|
||||
fromName := fromRef.Name
|
||||
remoteHost := r.RemoteHost
|
||||
if fromRef.Source == modelSourceCloud && remoteHost == "" {
|
||||
remoteHost = cloudProxyBaseURL
|
||||
}
|
||||
|
||||
if remoteHost != "" {
|
||||
ru, err := remoteURL(remoteHost)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
|
||||
config.RemoteModel = r.From
|
||||
config.RemoteModel = fromRef.Base
|
||||
config.RemoteHost = ru
|
||||
remote = true
|
||||
} else {
|
||||
|
||||
81
server/model_resolver.go
Normal file
81
server/model_resolver.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type modelSource = modelref.ModelSource
|
||||
|
||||
const (
|
||||
modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified
|
||||
modelSourceLocal modelSource = modelref.ModelSourceLocal
|
||||
modelSourceCloud modelSource = modelref.ModelSourceCloud
|
||||
)
|
||||
|
||||
var (
|
||||
errConflictingModelSource = modelref.ErrConflictingSourceSuffix
|
||||
errModelRequired = modelref.ErrModelRequired
|
||||
)
|
||||
|
||||
type parsedModelRef struct {
|
||||
// Original is the caller-provided model string before source parsing.
|
||||
// Example: "gpt-oss:20b:cloud".
|
||||
Original string
|
||||
// Base is the model string after source suffix normalization.
|
||||
// Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b".
|
||||
Base string
|
||||
// Name is Base parsed as a fully-qualified model.Name with defaults applied.
|
||||
// Example: "registry.ollama.ai/library/gpt-oss:20b".
|
||||
Name model.Name
|
||||
// Source captures explicit source intent from the original input.
|
||||
// Example: "gpt-oss:20b:cloud" -> modelSourceCloud.
|
||||
Source modelSource
|
||||
}
|
||||
|
||||
func parseAndValidateModelRef(raw string) (parsedModelRef, error) {
|
||||
var zero parsedModelRef
|
||||
|
||||
parsed, err := modelref.ParseRef(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
name := model.ParseName(parsed.Base)
|
||||
if !name.IsValid() {
|
||||
return zero, model.Unqualified(name)
|
||||
}
|
||||
|
||||
return parsedModelRef{
|
||||
Original: parsed.Original,
|
||||
Base: parsed.Base,
|
||||
Name: name,
|
||||
Source: parsed.Source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseNormalizePullModelRef(raw string) (parsedModelRef, error) {
|
||||
var zero parsedModelRef
|
||||
|
||||
parsedRef, err := modelref.ParseRef(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
normalizedName, _, err := modelref.NormalizePullName(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
name := model.ParseName(normalizedName)
|
||||
if !name.IsValid() {
|
||||
return zero, model.Unqualified(name)
|
||||
}
|
||||
|
||||
return parsedModelRef{
|
||||
Original: parsedRef.Original,
|
||||
Base: normalizedName,
|
||||
Name: name,
|
||||
Source: parsedRef.Source,
|
||||
}, nil
|
||||
}
|
||||
170
server/model_resolver_test.go
Normal file
170
server/model_resolver_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseModelSelector(t *testing.T) {
|
||||
t.Run("cloud suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:20b:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceCloud {
|
||||
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
|
||||
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy cloud suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceCloud {
|
||||
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("my-cloud-model")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "my-cloud-model" {
|
||||
t.Fatalf("expected base my-cloud-model, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("qwen3:8b:local")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceLocal {
|
||||
t.Fatalf("expected source local, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "qwen3:8b" {
|
||||
t.Fatalf("expected base qwen3:8b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("conflicting source suffixes fail", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("foo:cloud:local")
|
||||
if !errors.Is(err, errConflictingModelSource) {
|
||||
t.Fatalf("expected errConflictingModelSource, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unspecified source", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("llama3")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Name.Tag != "latest" {
|
||||
t.Fatalf("expected default latest tag, got %q", got.Name.Tag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown suffix is treated as tag", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:clod")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Name.Tag != "clod" {
|
||||
t.Fatalf("expected tag clod, got %q", got.Name.Tag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty model fails", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("")
|
||||
if !errors.Is(err, errModelRequired) {
|
||||
t.Fatalf("expected errModelRequired, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid model fails", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("::cloud")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unqualified") {
|
||||
t.Fatalf("expected unqualified model error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParsePullModelRef(t *testing.T) {
|
||||
t.Run("explicit local is normalized", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("gpt-oss:20b:local")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceLocal {
|
||||
t.Fatalf("expected source local, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("gpt-oss:20b:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
if got.Base != "gpt-oss:20b-cloud" {
|
||||
t.Fatalf("expected base gpt-oss:20b-cloud, got %q", got.Base)
|
||||
}
|
||||
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("qwen3:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
if got.Base != "qwen3:cloud" {
|
||||
t.Fatalf("expected base qwen3:cloud, got %q", got.Base)
|
||||
}
|
||||
if got.Name.String() != "registry.ollama.ai/library/qwen3:cloud" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
150
server/routes.go
150
server/routes.go
@@ -64,6 +64,17 @@ const (
|
||||
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
|
||||
)
|
||||
|
||||
func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) {
|
||||
switch {
|
||||
case errors.Is(err, errConflictingModelSource):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case errors.Is(err, model.ErrUnqualifiedName):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
default:
|
||||
c.JSON(fallbackStatus, gin.H{"error": fallbackMessage})
|
||||
}
|
||||
}
|
||||
|
||||
func shouldUseHarmony(model *Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
@@ -196,14 +207,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
// what the API currently returns until we can change it.
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
// TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the
|
||||
// original body (modulo model name normalization) is sent to cloud.
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -237,6 +256,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
|
||||
@@ -670,6 +694,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
var input []string
|
||||
|
||||
switch i := req.Input.(type) {
|
||||
@@ -692,7 +728,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
name, err := getExistingName(model.ParseName(req.Model))
|
||||
name, err := getExistingName(modelRef.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
@@ -839,12 +875,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
@@ -886,12 +930,19 @@ func (s *Server) PullHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(cmp.Or(req.Model, req.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
// TEMP(drifkin): we're temporarily allowing to continue pulling cloud model
|
||||
// stub-files until we integrate cloud models into `/api/tags` (in which case
|
||||
// this roundabout way of "adding" cloud models won't be needed anymore). So
|
||||
// right here normalize any `:cloud` models into the legacy-style suffixes
|
||||
// `:<tag>-cloud` and `:cloud`
|
||||
modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name))
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
name, err = getExistingName(name)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -1018,13 +1069,20 @@ func (s *Server) DeleteHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
n := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
if !n.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||
modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name))
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errConflictingModelSource):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case errors.Is(err, model.ErrUnqualifiedName):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||
default:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
n, err := getExistingName(n)
|
||||
n, err := getExistingName(modelRef.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
|
||||
return
|
||||
@@ -1073,6 +1131,20 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
req.Model = modelRef.Base
|
||||
|
||||
resp, err := GetModelInfo(req)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
@@ -1089,6 +1161,11 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
@@ -1625,18 +1702,20 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
|
||||
// parents on v1 request families while preserving this explicit :cloud passthrough.
|
||||
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
@@ -1995,12 +2074,24 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
if c.GetBool(legacyCloudAnthropicKey) {
|
||||
proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -2032,6 +2123,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -794,6 +794,43 @@ func TestCreateAndShowRemoteModel(t *testing.T) {
|
||||
fmt.Printf("resp = %#v\n", resp)
|
||||
}
|
||||
|
||||
func TestCreateFromCloudSourceSuffix(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-cloud-from-suffix",
|
||||
From: "gpt-oss:20b:cloud",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test-cloud-from-suffix"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ShowResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.RemoteHost != "https://ollama.com:443" {
|
||||
t.Fatalf("expected remote host https://ollama.com:443, got %q", resp.RemoteHost)
|
||||
}
|
||||
|
||||
if resp.RemoteModel != "gpt-oss:20b" {
|
||||
t.Fatalf("expected remote model gpt-oss:20b, got %q", resp.RemoteModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateLicenses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -111,3 +111,32 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
|
||||
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
|
||||
var s Server
|
||||
|
||||
_, digest := createBinFile(t, nil, nil)
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "gpt-oss:20b-cloud",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
|
||||
})
|
||||
|
||||
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -43,7 +44,7 @@ const (
|
||||
// isLocalModel checks if the model is running locally (not a cloud model).
|
||||
// TODO: Improve local/cloud model identification - could check model metadata
|
||||
func isLocalModel(modelName string) bool {
|
||||
return !strings.HasSuffix(modelName, "-cloud")
|
||||
return !modelref.HasExplicitCloudSource(modelName)
|
||||
}
|
||||
|
||||
// isLocalServer checks if connecting to a local Ollama server.
|
||||
|
||||
@@ -22,12 +22,22 @@ func TestIsLocalModel(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "cloud model",
|
||||
modelName: "gpt-4-cloud",
|
||||
modelName: "gpt-oss:latest-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with :cloud suffix",
|
||||
modelName: "gpt-oss:cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with version",
|
||||
modelName: "claude-3-cloud",
|
||||
modelName: "gpt-oss:20b-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with version and :cloud suffix",
|
||||
modelName: "gpt-oss:20b:cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
@@ -134,7 +144,7 @@ func TestTruncateToolOutput(t *testing.T) {
|
||||
{
|
||||
name: "long output cloud model - uses 10k limit",
|
||||
output: string(localLimitOutput), // 20k chars, under 10k token limit
|
||||
modelName: "gpt-4-cloud",
|
||||
modelName: "gpt-oss:latest-cloud",
|
||||
host: "",
|
||||
shouldTrim: false,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
@@ -142,7 +152,7 @@ func TestTruncateToolOutput(t *testing.T) {
|
||||
{
|
||||
name: "very long output cloud model - trimmed at 10k",
|
||||
output: string(defaultLimitOutput),
|
||||
modelName: "gpt-4-cloud",
|
||||
modelName: "gpt-oss:latest-cloud",
|
||||
host: "",
|
||||
shouldTrim: true,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
|
||||
Reference in New Issue
Block a user