diff --git a/cmd/cmd.go b/cmd/cmd.go index 2a1de7f5a..010ea9df5 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -38,6 +38,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/cmd/config" + "github.com/ollama/ollama/cmd/launch" "github.com/ollama/ollama/cmd/tui" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" @@ -58,36 +59,36 @@ import ( func init() { // Override default selectors to use Bubbletea TUI instead of raw terminal I/O. - config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) { + launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) { tuiItems := tui.ReorderItems(tui.ConvertItems(items)) result, err := tui.SelectSingle(title, tuiItems, current) if errors.Is(err, tui.ErrCancelled) { - return "", config.ErrCancelled + return "", launch.ErrCancelled } return result, err } - config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) { + launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) { tuiItems := tui.ReorderItems(tui.ConvertItems(items)) result, err := tui.SelectMultiple(title, tuiItems, preChecked) if errors.Is(err, tui.ErrCancelled) { - return nil, config.ErrCancelled + return nil, launch.ErrCancelled } return result, err } - config.DefaultSignIn = func(modelName, signInURL string) (string, error) { + launch.DefaultSignIn = func(modelName, signInURL string) (string, error) { userName, err := tui.RunSignIn(modelName, signInURL) if errors.Is(err, tui.ErrCancelled) { - return "", config.ErrCancelled + return "", launch.ErrCancelled } return userName, err } - config.DefaultConfirmPrompt = func(prompt string) (bool, error) { + launch.DefaultConfirmPrompt = func(prompt string) (bool, error) { ok, err := tui.RunConfirm(prompt) if errors.Is(err, tui.ErrCancelled) { - return false, config.ErrCancelled + return false, launch.ErrCancelled } return ok, err } @@ -1912,6 +1913,24 @@ func ensureServerRunning(ctx context.Context) error { } } +func launchInteractiveModel(cmd *cobra.Command, modelName string) error { + opts := runOptions{ + Model: modelName, + WordWrap: os.Getenv("TERM") == "xterm-256color", + Options: map[string]any{}, + ShowConnect: true, + } + // loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload + // and only validate auth/connectivity before interactive chat starts. + if err := loadOrUnloadModel(cmd, &opts); err != nil { + return fmt.Errorf("error loading model: %w", err) + } + if err := generateInteractive(cmd, opts); err != nil { + return fmt.Errorf("error running model: %w", err) + } + return nil +} + // runInteractiveTUI runs the main interactive TUI menu. func runInteractiveTUI(cmd *cobra.Command) { // Ensure the server is running before showing the TUI @@ -1920,175 +1939,81 @@ func runInteractiveTUI(cmd *cobra.Command) { return } - // Selector adapters for tui - singleSelector := func(title string, items []config.ModelItem, current string) (string, error) { - tuiItems := tui.ReorderItems(tui.ConvertItems(items)) - result, err := tui.SelectSingle(title, tuiItems, current) - if errors.Is(err, tui.ErrCancelled) { - return "", config.ErrCancelled - } - return result, err - } - - multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) { - tuiItems := tui.ReorderItems(tui.ConvertItems(items)) - result, err := tui.SelectMultiple(title, tuiItems, preChecked) - if errors.Is(err, tui.ErrCancelled) { - return nil, config.ErrCancelled - } - return result, err + deps := launcherDeps{ + buildState: launch.BuildLauncherState, + runMenu: tui.RunMenu, + resolveRunModel: launch.ResolveRunModel, + launchIntegration: launch.LaunchIntegration, + runModel: launchInteractiveModel, } for { - result, err := tui.Run() + continueLoop, err := runInteractiveTUIStep(cmd, deps) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) + } + if !continueLoop { return } + } +} - runModel := func(modelName string) { - client, err := api.ClientFromEnvironment() - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - return - } - if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil { - if errors.Is(err, config.ErrCancelled) { - return - } - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - return - } - _ = config.SetLastModel(modelName) - opts := runOptions{ - Model: modelName, - WordWrap: os.Getenv("TERM") == "xterm-256color", - Options: map[string]any{}, - ShowConnect: true, - } - if err := loadOrUnloadModel(cmd, &opts); err != nil { - fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err) - return - } - if err := generateInteractive(cmd, opts); err != nil { - fmt.Fprintf(os.Stderr, "Error running model: %v\n", err) - } - } +type launcherDeps struct { + buildState func(context.Context) (*launch.LauncherState, error) + runMenu func(*launch.LauncherState) (tui.TUIAction, error) + resolveRunModel func(context.Context, launch.RunModelRequest) (string, error) + launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error + runModel func(*cobra.Command, string) error +} - launchIntegration := func(name string) bool { - if err := config.EnsureInstalled(name); err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - return true - } - // If not configured or model no longer exists, prompt for model selection - configuredModel := config.IntegrationModel(name) - if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) { - err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector) - if errors.Is(err, config.ErrCancelled) { - return false // Return to main menu - } - if err != nil { - fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err) - return true - } - } - if err := config.LaunchIntegration(name); err != nil { - fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err) - } - return true - } +func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) { + state, err := deps.buildState(cmd.Context()) + if err != nil { + return false, fmt.Errorf("build launcher state: %w", err) + } - switch result.Selection { - case tui.SelectionNone: - // User quit - return - case tui.SelectionRunModel: - _ = config.SetLastSelection("run") - if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) { - runModel(modelName) - } else { - modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector) - if errors.Is(err, config.ErrCancelled) { - continue // Return to main menu - } - if err != nil { - fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err) - continue - } - runModel(modelName) - } - case tui.SelectionChangeRunModel: - _ = config.SetLastSelection("run") - // Use model from modal if selected, otherwise show picker - modelName := result.Model - if modelName == "" { - var err error - modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector) - if errors.Is(err, config.ErrCancelled) { - continue // Return to main menu - } - if err != nil { - fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err) - continue - } - } - if config.IsCloudModelDisabled(cmd.Context(), modelName) { - continue // Return to main menu - } - runModel(modelName) - case tui.SelectionIntegration: - _ = config.SetLastSelection(result.Integration) - if !launchIntegration(result.Integration) { - continue // Return to main menu - } - case tui.SelectionChangeIntegration: - _ = config.SetLastSelection(result.Integration) - if len(result.Models) > 0 { - // Filter out cloud-disabled models - var filtered []string - for _, m := range result.Models { - if !config.IsCloudModelDisabled(cmd.Context(), m) { - filtered = append(filtered, m) - } - } - if len(filtered) == 0 { - continue - } - result.Models = filtered - // Multi-select from modal (Editor integrations) - if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil { - fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err) - continue - } - if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil { - fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err) - } - } else if result.Model != "" { - if config.IsCloudModelDisabled(cmd.Context(), result.Model) { - continue - } - // Single-select from modal - save and launch - if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil { - fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err) - continue - } - if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil { - fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err) - } - } else { - err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector) - if errors.Is(err, config.ErrCancelled) { - continue // Return to main menu - } - if err != nil { - fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err) - continue - } - if err := config.LaunchIntegration(result.Integration); err != nil { - fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err) - } - } + action, err := deps.runMenu(state) + if err != nil { + return false, fmt.Errorf("run launcher menu: %w", err) + } + + return runLauncherAction(cmd, action, deps) +} + +func saveLauncherSelection(action tui.TUIAction) { + // Best effort only: this affects menu recall, not launch correctness. + _ = config.SetLastSelection(action.LastSelection()) +} + +func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDeps) (bool, error) { + switch action.Kind { + case tui.TUIActionNone: + return false, nil + case tui.TUIActionRunModel: + saveLauncherSelection(action) + modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest()) + if errors.Is(err, launch.ErrCancelled) { + return true, nil } + if err != nil { + return true, fmt.Errorf("selecting model: %w", err) + } + if err := deps.runModel(cmd, modelName); err != nil { + return true, err + } + return true, nil + case tui.TUIActionLaunchIntegration: + saveLauncherSelection(action) + err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest()) + if errors.Is(err, launch.ErrCancelled) { + return true, nil + } + if err != nil { + return true, fmt.Errorf("launching %s: %w", action.Integration, err) + } + return true, nil + default: + return false, fmt.Errorf("unknown launcher action: %d", action.Kind) } } @@ -2358,7 +2283,7 @@ func NewCLI() *cobra.Command { copyCmd, deleteCmd, runnerCmd, - config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI), + launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI), ) return rootCmd diff --git a/cmd/cmd_launcher_test.go b/cmd/cmd_launcher_test.go new file mode 100644 index 000000000..f9ff2739b --- /dev/null +++ b/cmd/cmd_launcher_test.go @@ -0,0 +1,233 @@ +package cmd + +import ( + "context" + "testing" + + "github.com/spf13/cobra" + + "github.com/ollama/ollama/cmd/config" + "github.com/ollama/ollama/cmd/launch" + "github.com/ollama/ollama/cmd/tui" +) + +func setCmdTestHome(t *testing.T, dir string) { + t.Helper() + t.Setenv("HOME", dir) + t.Setenv("USERPROFILE", dir) +} + +func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) { + t.Helper() + return func(ctx context.Context, req launch.RunModelRequest) (string, error) { + t.Fatalf("did not expect run-model resolution: %+v", req) + return "", nil + } +} + +func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error { + t.Helper() + return func(ctx context.Context, req launch.IntegrationLaunchRequest) error { + t.Fatalf("did not expect integration launch: %+v", req) + return nil + } +} + +func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error { + t.Helper() + return func(cmd *cobra.Command, model string) error { + t.Fatalf("did not expect chat launch: %s", model) + return nil + } +} + +func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) { + tests := []struct { + name string + action tui.TUIAction + wantForce bool + wantModel string + }{ + { + name: "enter uses saved model flow", + action: tui.TUIAction{Kind: tui.TUIActionRunModel}, + wantModel: "qwen3:8b", + }, + { + name: "right forces picker", + action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true}, + wantForce: true, + wantModel: "glm-5:cloud", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setCmdTestHome(t, t.TempDir()) + + var menuCalls int + runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) { + menuCalls++ + if menuCalls == 1 { + return tt.action, nil + } + return tui.TUIAction{Kind: tui.TUIActionNone}, nil + } + + var gotReq launch.RunModelRequest + var launched string + deps := launcherDeps{ + buildState: func(ctx context.Context) (*launch.LauncherState, error) { + return &launch.LauncherState{}, nil + }, + runMenu: runMenu, + resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) { + gotReq = req + return tt.wantModel, nil + }, + launchIntegration: unexpectedIntegrationLaunch(t), + runModel: func(cmd *cobra.Command, model string) error { + launched = model + return nil + }, + } + + cmd := &cobra.Command{} + cmd.SetContext(context.Background()) + for { + continueLoop, err := runInteractiveTUIStep(cmd, deps) + if err != nil { + t.Fatalf("unexpected step error: %v", err) + } + if !continueLoop { + break + } + } + + if gotReq.ForcePicker != tt.wantForce { + t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker) + } + if launched != tt.wantModel { + t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched) + } + if got := config.LastSelection(); got != "run" { + t.Fatalf("expected last selection to be run, got %q", got) + } + }) + } +} + +func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) { + tests := []struct { + name string + action tui.TUIAction + wantForce bool + }{ + { + name: "enter launches integration", + action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, + }, + { + name: "right forces configure", + action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}, + wantForce: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setCmdTestHome(t, t.TempDir()) + + var menuCalls int + runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) { + menuCalls++ + if menuCalls == 1 { + return tt.action, nil + } + return tui.TUIAction{Kind: tui.TUIActionNone}, nil + } + + var gotReq launch.IntegrationLaunchRequest + deps := launcherDeps{ + buildState: func(ctx context.Context) (*launch.LauncherState, error) { + return &launch.LauncherState{}, nil + }, + runMenu: runMenu, + resolveRunModel: unexpectedRunModelResolution(t), + launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error { + gotReq = req + return nil + }, + runModel: unexpectedModelLaunch(t), + } + + cmd := &cobra.Command{} + cmd.SetContext(context.Background()) + for { + continueLoop, err := runInteractiveTUIStep(cmd, deps) + if err != nil { + t.Fatalf("unexpected step error: %v", err) + } + if !continueLoop { + break + } + } + + if gotReq.Name != "claude" { + t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name) + } + if gotReq.ForceConfigure != tt.wantForce { + t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure) + } + if got := config.LastSelection(); got != "claude" { + t.Fatalf("expected last selection to be claude, got %q", got) + } + }) + } +} + +func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) { + setCmdTestHome(t, t.TempDir()) + + cmd := &cobra.Command{} + cmd.SetContext(context.Background()) + + continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{ + buildState: nil, + runMenu: nil, + resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) { + return "", launch.ErrCancelled + }, + launchIntegration: unexpectedIntegrationLaunch(t), + runModel: unexpectedModelLaunch(t), + }) + if err != nil { + t.Fatalf("expected nil error on cancellation, got %v", err) + } + if !continueLoop { + t.Fatal("expected cancellation to continue the menu loop") + } +} + +func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) { + setCmdTestHome(t, t.TempDir()) + + cmd := &cobra.Command{} + cmd.SetContext(context.Background()) + + continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{ + buildState: nil, + runMenu: nil, + resolveRunModel: unexpectedRunModelResolution(t), + launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error { + return launch.ErrCancelled + }, + runModel: unexpectedModelLaunch(t), + }) + if err != nil { + t.Fatalf("expected nil error on cancellation, got %v", err) + } + if !continueLoop { + t.Fatal("expected cancellation to continue the menu loop") + } +} diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 49852c02d..ea4cf5983 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -18,7 +18,6 @@ import ( "github.com/spf13/cobra" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/types/model" ) @@ -1797,13 +1796,16 @@ func TestRunOptions_Copy_Independence(t *testing.T) { func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { tests := []struct { - name string - model string - remoteHost string - remoteModel string - whoamiStatus int - whoamiResp any - expectedError string + name string + model string + showStatus int + remoteHost string + remoteModel string + whoamiStatus int + whoamiResp any + expectWhoami bool + expectedError string + expectAuthError bool }{ { name: "ollama.com cloud model - user signed in", @@ -1812,6 +1814,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { remoteModel: "test-model", whoamiStatus: http.StatusOK, whoamiResp: api.UserResponse{Name: "testuser"}, + expectWhoami: true, }, { name: "ollama.com cloud model - user not signed in", @@ -1823,7 +1826,9 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { "error": "unauthorized", "signin_url": "https://ollama.com/signin", }, - expectedError: "unauthorized", + expectWhoami: true, + expectedError: "unauthorized", + expectAuthError: true, }, { name: "non-ollama.com remote - no auth check", @@ -1840,6 +1845,17 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { remoteModel: "", whoamiStatus: http.StatusOK, whoamiResp: api.UserResponse{Name: "testuser"}, + expectWhoami: true, + }, + { + name: "explicit :cloud model without local stub returns not found by default", + model: "minimax-m2.5:cloud", + showStatus: http.StatusNotFound, + whoamiStatus: http.StatusOK, + whoamiResp: api.UserResponse{Name: "testuser"}, + expectedError: "not found", + expectWhoami: false, + expectAuthError: false, }, { name: "explicit -cloud model - auth check without remote metadata", @@ -1848,6 +1864,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { remoteModel: "", whoamiStatus: http.StatusOK, whoamiResp: api.UserResponse{Name: "testuser"}, + expectWhoami: true, }, { name: "dash cloud-like name without explicit source does not require auth", @@ -1865,6 +1882,11 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/api/show": + if tt.showStatus != 0 && tt.showStatus != http.StatusOK { + w.WriteHeader(tt.showStatus) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"}) + return + } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(api.ShowResponse{ RemoteHost: tt.remoteHost, @@ -1901,23 +1923,22 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { err := loadOrUnloadModel(cmd, opts) - if strings.HasPrefix(tt.remoteHost, "https://ollama.com") || modelref.HasExplicitCloudSource(tt.model) { - if !whoamiCalled { - t.Error("expected whoami to be called for ollama.com cloud model") - } - } else { - if whoamiCalled { - t.Error("whoami should not be called for non-ollama.com remote") - } + if whoamiCalled != tt.expectWhoami { + t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami) } if tt.expectedError != "" { if err == nil { t.Errorf("expected error containing %q, got nil", tt.expectedError) } else { - var authErr api.AuthorizationError - if !errors.As(err, &authErr) { - t.Errorf("expected AuthorizationError, got %T: %v", err, err) + if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) { + t.Errorf("expected error containing %q, got %v", tt.expectedError, err) + } + if tt.expectAuthError { + var authErr api.AuthorizationError + if !errors.As(err, &authErr) { + t.Errorf("expected AuthorizationError, got %T: %v", err, err) + } } } } else { diff --git a/cmd/config/claude.go b/cmd/config/claude.go deleted file mode 100644 index 9018d193d..000000000 --- a/cmd/config/claude.go +++ /dev/null @@ -1,187 +0,0 @@ -package config - -import ( - "context" - "fmt" - "os" - "os/exec" - "path/filepath" - "runtime" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" -) - -// Claude implements Runner and AliasConfigurer for Claude Code integration -type Claude struct{} - -// Compile-time check that Claude implements AliasConfigurer -var _ AliasConfigurer = (*Claude)(nil) - -func (c *Claude) String() string { return "Claude Code" } - -func (c *Claude) args(model string, extra []string) []string { - var args []string - if model != "" { - args = append(args, "--model", model) - } - args = append(args, extra...) - return args -} - -func (c *Claude) findPath() (string, error) { - if p, err := exec.LookPath("claude"); err == nil { - return p, nil - } - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - name := "claude" - if runtime.GOOS == "windows" { - name = "claude.exe" - } - fallback := filepath.Join(home, ".claude", "local", name) - if _, err := os.Stat(fallback); err != nil { - return "", err - } - return fallback, nil -} - -func (c *Claude) Run(model string, args []string) error { - claudePath, err := c.findPath() - if err != nil { - return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart") - } - - cmd := exec.Command(claudePath, c.args(model, args)...) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - env := append(os.Environ(), - "ANTHROPIC_BASE_URL="+envconfig.Host().String(), - "ANTHROPIC_API_KEY=", - "ANTHROPIC_AUTH_TOKEN=ollama", - ) - - env = append(env, c.modelEnvVars(model)...) - - cmd.Env = env - return cmd.Run() -} - -// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama. -func (c *Claude) modelEnvVars(model string) []string { - primary := model - fast := model - if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil { - if p := cfg.Aliases["primary"]; p != "" { - primary = p - } - if f := cfg.Aliases["fast"]; f != "" { - fast = f - } - } - return []string{ - "ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary, - "ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary, - "ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast, - "CLAUDE_CODE_SUBAGENT_MODEL=" + primary, - } -} - -// ConfigureAliases sets up model aliases for Claude Code. -// model: the model to use (if empty, user will be prompted to select) -// aliases: existing alias configuration to preserve/update -// Cloud-only: subagent routing (fast model) is gated to cloud models only until -// there is a better strategy for prompt caching on local models. -func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) { - aliases := make(map[string]string) - for k, v := range existingAliases { - aliases[k] = v - } - - if model != "" { - aliases["primary"] = model - } - - if !force && aliases["primary"] != "" { - if isCloudModelName(aliases["primary"]) { - aliases["fast"] = aliases["primary"] - return aliases, false, nil - } - delete(aliases, "fast") - return aliases, false, nil - } - - items, existingModels, cloudModels, client, err := listModels(ctx) - if err != nil { - return nil, false, err - } - - fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset) - - if aliases["primary"] == "" || force { - primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"]) - if err != nil { - return nil, false, err - } - if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil { - return nil, false, err - } - if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil { - return nil, false, err - } - aliases["primary"] = primary - } - - if isCloudModelName(aliases["primary"]) { - aliases["fast"] = aliases["primary"] - } else { - delete(aliases, "fast") - } - - return aliases, true, nil -} - -// SetAliases syncs the configured aliases to the Ollama server using prefix matching. -// Cloud-only: for local models (fast is empty), we delete any existing aliases to -// prevent stale routing to a previous cloud model. -func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error { - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - - prefixes := []string{"claude-sonnet-", "claude-haiku-"} - - if aliases["fast"] == "" { - for _, prefix := range prefixes { - _ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix}) - } - return nil - } - - prefixAliases := map[string]string{ - "claude-sonnet-": aliases["primary"], - "claude-haiku-": aliases["fast"], - } - - var errs []string - for prefix, target := range prefixAliases { - req := &api.AliasRequest{ - Alias: prefix, - Target: target, - PrefixMatching: true, - } - if err := client.SetAliasExperimental(ctx, req); err != nil { - errs = append(errs, prefix) - } - } - - if len(errs) > 0 { - return fmt.Errorf("failed to set aliases: %v", errs) - } - return nil -} diff --git a/cmd/config/config.go b/cmd/config/config.go index 82bfb493d..20caeaae5 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -3,7 +3,6 @@ package config import ( - "context" "encoding/json" "errors" "fmt" @@ -11,7 +10,7 @@ import ( "path/filepath" "strings" - "github.com/ollama/ollama/api" + "github.com/ollama/ollama/cmd/internal/fileutil" ) type integration struct { @@ -20,6 +19,9 @@ type integration struct { Onboarded bool `json:"onboarded,omitempty"` } +// IntegrationConfig is the persisted config for one integration. +type IntegrationConfig = integration + type config struct { Integrations map[string]*integration `json:"integrations"` LastModel string `json:"last_model,omitempty"` @@ -124,7 +126,7 @@ func save(cfg *config) error { return err } - return writeWithBackup(path, data) + return fileutil.WriteWithBackup(path, data) } func SaveIntegration(appName string, models []string) error { @@ -155,8 +157,8 @@ func SaveIntegration(appName string, models []string) error { return save(cfg) } -// integrationOnboarded marks an integration as onboarded in ollama's config. -func integrationOnboarded(appName string) error { +// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config. +func MarkIntegrationOnboarded(appName string) error { cfg, err := load() if err != nil { return err @@ -174,7 +176,7 @@ func integrationOnboarded(appName string) error { // IntegrationModel returns the first configured model for an integration, or empty string if not configured. func IntegrationModel(appName string) string { - integrationConfig, err := loadIntegration(appName) + integrationConfig, err := LoadIntegration(appName) if err != nil || len(integrationConfig.Models) == 0 { return "" } @@ -183,7 +185,7 @@ func IntegrationModel(appName string) string { // IntegrationModels returns all configured models for an integration, or nil. func IntegrationModels(appName string) []string { - integrationConfig, err := loadIntegration(appName) + integrationConfig, err := LoadIntegration(appName) if err != nil || len(integrationConfig.Models) == 0 { return nil } @@ -228,31 +230,8 @@ func SetLastSelection(selection string) error { return save(cfg) } -// ModelExists checks if a model exists on the Ollama server. -func ModelExists(ctx context.Context, name string) bool { - if name == "" { - return false - } - if isCloudModelName(name) { - return true - } - client, err := api.ClientFromEnvironment() - if err != nil { - return false - } - models, err := client.List(ctx) - if err != nil { - return false - } - for _, m := range models.Models { - if m.Name == name || strings.HasPrefix(m.Name, name+":") { - return true - } - } - return false -} - -func loadIntegration(appName string) (*integration, error) { +// LoadIntegration returns the saved config for one integration. +func LoadIntegration(appName string) (*integration, error) { cfg, err := load() if err != nil { return nil, err @@ -266,7 +245,8 @@ func loadIntegration(appName string) (*integration, error) { return integrationConfig, nil } -func saveAliases(appName string, aliases map[string]string) error { +// SaveAliases replaces the saved aliases for one integration. +func SaveAliases(appName string, aliases map[string]string) error { if appName == "" { return errors.New("app name cannot be empty") } diff --git a/cmd/config/config_cloud_test.go b/cmd/config/config_cloud_test.go index 23e7313d9..bd917bde8 100644 --- a/cmd/config/config_cloud_test.go +++ b/cmd/config/config_cloud_test.go @@ -1,7 +1,6 @@ package config import ( - "context" "errors" "os" "path/filepath" @@ -45,12 +44,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) { "primary": "cloud-model", "fast": "cloud-model", } - if err := saveAliases("claude", initial); err != nil { + if err := SaveAliases("claude", initial); err != nil { t.Fatalf("failed to save initial aliases: %v", err) } // Verify both are saved - loaded, err := loadIntegration("claude") + loaded, err := LoadIntegration("claude") if err != nil { t.Fatalf("failed to load: %v", err) } @@ -63,12 +62,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) { "primary": "local-model", // fast intentionally missing } - if err := saveAliases("claude", updated); err != nil { + if err := SaveAliases("claude", updated); err != nil { t.Fatalf("failed to save updated aliases: %v", err) } // Verify fast is GONE (not merged/preserved) - loaded, err = loadIntegration("claude") + loaded, err = LoadIntegration("claude") if err != nil { t.Fatalf("failed to load after update: %v", err) } @@ -91,12 +90,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) { // Then update aliases aliases := map[string]string{"primary": "new-model"} - if err := saveAliases("claude", aliases); err != nil { + if err := SaveAliases("claude", aliases); err != nil { t.Fatalf("failed to save aliases: %v", err) } // Verify models are preserved - loaded, err := loadIntegration("claude") + loaded, err := LoadIntegration("claude") if err != nil { t.Fatalf("failed to load: %v", err) } @@ -111,16 +110,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) { setTestHome(t, tmpDir) // Save with aliases - if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil { + if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil { t.Fatalf("failed to save: %v", err) } // Save empty map - if err := saveAliases("claude", map[string]string{}); err != nil { + if err := SaveAliases("claude", map[string]string{}); err != nil { t.Fatalf("failed to save empty: %v", err) } - loaded, err := loadIntegration("claude") + loaded, err := LoadIntegration("claude") if err != nil { t.Fatalf("failed to load: %v", err) } @@ -135,16 +134,16 @@ func TestSaveAliases_NilMap(t *testing.T) { setTestHome(t, tmpDir) // Save with aliases first - if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil { + if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil { t.Fatalf("failed to save: %v", err) } // Save nil map - should clear aliases - if err := saveAliases("claude", nil); err != nil { + if err := SaveAliases("claude", nil); err != nil { t.Fatalf("failed to save nil: %v", err) } - loaded, err := loadIntegration("claude") + loaded, err := LoadIntegration("claude") if err != nil { t.Fatalf("failed to load: %v", err) } @@ -155,7 +154,7 @@ func TestSaveAliases_NilMap(t *testing.T) { // TestSaveAliases_EmptyAppName returns error func TestSaveAliases_EmptyAppName(t *testing.T) { - err := saveAliases("", map[string]string{"primary": "model"}) + err := SaveAliases("", map[string]string{"primary": "model"}) if err == nil { t.Error("expected error for empty app name") } @@ -165,12 +164,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) { tmpDir := t.TempDir() setTestHome(t, tmpDir) - if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil { + if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil { t.Fatalf("failed to save: %v", err) } // Load with different case - loaded, err := loadIntegration("claude") + loaded, err := LoadIntegration("claude") if err != nil { t.Fatalf("failed to load: %v", err) } @@ -179,11 +178,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) { } // Update with different case - if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil { + if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil { t.Fatalf("failed to update: %v", err) } - loaded, err = loadIntegration("claude") + loaded, err = LoadIntegration("claude") if err != nil { t.Fatalf("failed to load after update: %v", err) } @@ -198,11 +197,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) { setTestHome(t, tmpDir) // Save aliases for non-existent integration - if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil { + if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil { t.Fatalf("failed to save: %v", err) } - loaded, err := loadIntegration("newintegration") + loaded, err := LoadIntegration("newintegration") if err != nil { t.Fatalf("failed to load: %v", err) } @@ -371,12 +370,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) { t.Fatal("server should succeed") } - if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil { + if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil { t.Fatalf("saveAliases failed: %v", err) } // Verify it was actually saved - loaded, err := loadIntegration("claude") + loaded, err := LoadIntegration("claude") if err != nil { t.Fatalf("failed to load: %v", err) } @@ -408,7 +407,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) { os.WriteFile(configPath, []byte(initialConfig), 0o644) // Update aliases - if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil { + if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil { t.Fatalf("failed to save: %v", err) } @@ -440,11 +439,6 @@ func containsHelper(s, substr string) bool { return false } -func TestClaudeImplementsAliasConfigurer(t *testing.T) { - c := &Claude{} - var _ AliasConfigurer = c // Compile-time check -} - func TestModelNameEdgeCases(t *testing.T) { testCases := []struct { name string @@ -464,11 +458,11 @@ func TestModelNameEdgeCases(t *testing.T) { setTestHome(t, tmpDir) aliases := map[string]string{"primary": tc.model} - if err := saveAliases("claude", aliases); err != nil { + if err := SaveAliases("claude", aliases); err != nil { t.Fatalf("failed to save model %q: %v", tc.model, err) } - loaded, err := loadIntegration("claude") + loaded, err := LoadIntegration("claude") if err != nil { t.Fatalf("failed to load: %v", err) } @@ -485,7 +479,7 @@ func TestSwitchingScenarios(t *testing.T) { setTestHome(t, tmpDir) // Initial cloud config - if err := saveAliases("claude", map[string]string{ + if err := SaveAliases("claude", map[string]string{ "primary": "cloud-model", "fast": "cloud-model", }); err != nil { @@ -493,13 +487,13 @@ func TestSwitchingScenarios(t *testing.T) { } // Switch to local (no fast) - if err := saveAliases("claude", map[string]string{ + if err := SaveAliases("claude", map[string]string{ "primary": "local-model", }); err != nil { t.Fatal(err) } - loaded, _ := loadIntegration("claude") + loaded, _ := LoadIntegration("claude") if loaded.Aliases["fast"] != "" { t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"]) } @@ -513,21 +507,21 @@ func TestSwitchingScenarios(t *testing.T) { setTestHome(t, tmpDir) // Initial local config - if err := saveAliases("claude", map[string]string{ + if err := SaveAliases("claude", map[string]string{ "primary": "local-model", }); err != nil { t.Fatal(err) } // Switch to cloud (with fast) - if err := saveAliases("claude", map[string]string{ + if err := SaveAliases("claude", map[string]string{ "primary": "cloud-model", "fast": "cloud-model", }); err != nil { t.Fatal(err) } - loaded, _ := loadIntegration("claude") + loaded, _ := LoadIntegration("claude") if loaded.Aliases["fast"] != "cloud-model" { t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"]) } @@ -538,7 +532,7 @@ func TestSwitchingScenarios(t *testing.T) { setTestHome(t, tmpDir) // Initial cloud config - if err := saveAliases("claude", map[string]string{ + if err := SaveAliases("claude", map[string]string{ "primary": "cloud-model-1", "fast": "cloud-model-1", }); err != nil { @@ -546,14 +540,14 @@ func TestSwitchingScenarios(t *testing.T) { } // Switch to different cloud - if err := saveAliases("claude", map[string]string{ + if err := SaveAliases("claude", map[string]string{ "primary": "cloud-model-2", "fast": "cloud-model-2", }); err != nil { t.Fatal(err) } - loaded, _ := loadIntegration("claude") + loaded, _ := LoadIntegration("claude") if loaded.Aliases["primary"] != "cloud-model-2" { t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"]) } @@ -563,43 +557,13 @@ func TestSwitchingScenarios(t *testing.T) { }) } -func TestToolCapabilityFiltering(t *testing.T) { - t.Run("all models checked for tool capability", func(t *testing.T) { - // Both cloud and local models are checked for tool capability via Show API - // Only models with "tools" in capabilities are included - m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true} - if !m.ToolCapable { - t.Error("tool capable model should be marked as such") - } - }) - - t.Run("modelInfo includes ToolCapable field", func(t *testing.T) { - m := modelInfo{Name: "test", Remote: true, ToolCapable: true} - if !m.ToolCapable { - t.Error("ToolCapable field should be accessible") - } - }) -} - -func TestIsCloudModel_RequiresClient(t *testing.T) { - t.Run("nil client always returns false", func(t *testing.T) { - // isCloudModel now only uses Show API, no suffix detection - if isCloudModel(context.Background(), nil, "model:cloud") { - t.Error("nil client should return false regardless of suffix") - } - if isCloudModel(context.Background(), nil, "local-model") { - t.Error("nil client should return false") - } - }) -} - func TestModelsAndAliasesMustStayInSync(t *testing.T) { t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) { tmpDir := t.TempDir() setTestHome(t, tmpDir) // Save aliases with one model - if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil { + if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil { t.Fatal(err) } @@ -608,7 +572,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) { t.Fatal(err) } - loaded, _ := loadIntegration("claude") + loaded, _ := LoadIntegration("claude") if loaded.Aliases["primary"] != loaded.Models[0] { t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0]) } @@ -622,11 +586,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) { if err := SaveIntegration("claude", []string{"old-model"}); err != nil { t.Fatal(err) } - if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil { + if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil { t.Fatal(err) } - loaded, _ := loadIntegration("claude") + loaded, _ := LoadIntegration("claude") // They should be different (this is the bug state) if loaded.Models[0] == loaded.Aliases["primary"] { @@ -638,7 +602,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) { t.Fatal(err) } - loaded, _ = loadIntegration("claude") + loaded, _ = LoadIntegration("claude") if loaded.Models[0] != loaded.Aliases["primary"] { t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)", loaded.Models[0], loaded.Aliases["primary"]) @@ -653,20 +617,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) { if err := SaveIntegration("claude", []string{"initial-model"}); err != nil { t.Fatal(err) } - if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil { + if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil { t.Fatal(err) } // Update aliases AND models together newAliases := map[string]string{"primary": "updated-model"} - if err := saveAliases("claude", newAliases); err != nil { + if err := SaveAliases("claude", newAliases); err != nil { t.Fatal(err) } if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil { t.Fatal(err) } - loaded, _ := loadIntegration("claude") + loaded, _ := LoadIntegration("claude") if loaded.Models[0] != "updated-model" { t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0]) } diff --git a/cmd/config/config_test.go b/cmd/config/config_test.go index fedde7af8..044285cdb 100644 --- a/cmd/config/config_test.go +++ b/cmd/config/config_test.go @@ -10,17 +10,10 @@ import ( // setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests func setTestHome(t *testing.T, dir string) { t.Setenv("HOME", dir) + t.Setenv("TMPDIR", dir) t.Setenv("USERPROFILE", dir) } -// editorPaths is a test helper that safely calls Paths if the runner implements Editor -func editorPaths(r Runner) []string { - if editor, ok := r.(Editor); ok { - return editor.Paths() - } - return nil -} - func TestIntegrationConfig(t *testing.T) { tmpDir := t.TempDir() setTestHome(t, tmpDir) @@ -31,7 +24,7 @@ func TestIntegrationConfig(t *testing.T) { t.Fatal(err) } - config, err := loadIntegration("claude") + config, err := LoadIntegration("claude") if err != nil { t.Fatal(err) } @@ -55,11 +48,11 @@ func TestIntegrationConfig(t *testing.T) { "primary": "llama3.2:70b", "fast": "llama3.2:8b", } - if err := saveAliases("claude", aliases); err != nil { + if err := SaveAliases("claude", aliases); err != nil { t.Fatal(err) } - config, err := loadIntegration("claude") + config, err := LoadIntegration("claude") if err != nil { t.Fatal(err) } @@ -77,14 +70,14 @@ func TestIntegrationConfig(t *testing.T) { if err := SaveIntegration("claude", []string{"model-a"}); err != nil { t.Fatal(err) } - if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil { + if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil { t.Fatal(err) } if err := SaveIntegration("claude", []string{"model-b"}); err != nil { t.Fatal(err) } - config, err := loadIntegration("claude") + config, err := LoadIntegration("claude") if err != nil { t.Fatal(err) } @@ -96,7 +89,7 @@ func TestIntegrationConfig(t *testing.T) { t.Run("defaultModel returns first model", func(t *testing.T) { SaveIntegration("codex", []string{"model-a", "model-b"}) - config, _ := loadIntegration("codex") + config, _ := LoadIntegration("codex") defaultModel := "" if len(config.Models) > 0 { defaultModel = config.Models[0] @@ -120,7 +113,7 @@ func TestIntegrationConfig(t *testing.T) { t.Run("app name is case-insensitive", func(t *testing.T) { SaveIntegration("Claude", []string{"model-x"}) - config, err := loadIntegration("claude") + config, err := LoadIntegration("claude") if err != nil { t.Fatal(err) } @@ -137,8 +130,8 @@ func TestIntegrationConfig(t *testing.T) { SaveIntegration("app1", []string{"model-1"}) SaveIntegration("app2", []string{"model-2"}) - config1, _ := loadIntegration("app1") - config2, _ := loadIntegration("app2") + config1, _ := LoadIntegration("app1") + config2, _ := LoadIntegration("app2") defaultModel1 := "" if len(config1.Models) > 0 { @@ -185,64 +178,6 @@ func TestListIntegrations(t *testing.T) { }) } -func TestEditorPaths(t *testing.T) { - tmpDir := t.TempDir() - setTestHome(t, tmpDir) - - t.Run("returns empty for claude (no Editor)", func(t *testing.T) { - r := integrations["claude"] - paths := editorPaths(r) - if len(paths) != 0 { - t.Errorf("expected no paths for claude, got %v", paths) - } - }) - - t.Run("returns empty for codex (no Editor)", func(t *testing.T) { - r := integrations["codex"] - paths := editorPaths(r) - if len(paths) != 0 { - t.Errorf("expected no paths for codex, got %v", paths) - } - }) - - t.Run("returns empty for droid when no config exists", func(t *testing.T) { - r := integrations["droid"] - paths := editorPaths(r) - if len(paths) != 0 { - t.Errorf("expected no paths, got %v", paths) - } - }) - - t.Run("returns path for droid when config exists", func(t *testing.T) { - settingsDir, _ := os.UserHomeDir() - settingsDir = filepath.Join(settingsDir, ".factory") - os.MkdirAll(settingsDir, 0o755) - os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644) - - r := integrations["droid"] - paths := editorPaths(r) - if len(paths) != 1 { - t.Errorf("expected 1 path, got %d", len(paths)) - } - }) - - t.Run("returns paths for opencode when configs exist", func(t *testing.T) { - home, _ := os.UserHomeDir() - configDir := filepath.Join(home, ".config", "opencode") - stateDir := filepath.Join(home, ".local", "state", "opencode") - os.MkdirAll(configDir, 0o755) - os.MkdirAll(stateDir, 0o755) - os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644) - os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644) - - r := integrations["opencode"] - paths := editorPaths(r) - if len(paths) != 2 { - t.Errorf("expected 2 paths, got %d: %v", len(paths), paths) - } - }) -} - func TestLoadIntegration_CorruptedJSON(t *testing.T) { tmpDir := t.TempDir() setTestHome(t, tmpDir) @@ -251,7 +186,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) { os.MkdirAll(dir, 0o755) os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644) - _, err := loadIntegration("test") + _, err := LoadIntegration("test") if err == nil { t.Error("expected error for nonexistent integration in corrupted file") } @@ -265,7 +200,7 @@ func TestSaveIntegration_NilModels(t *testing.T) { t.Fatalf("saveIntegration with nil models failed: %v", err) } - config, err := loadIntegration("test") + config, err := LoadIntegration("test") if err != nil { t.Fatalf("loadIntegration failed: %v", err) } @@ -294,7 +229,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) { tmpDir := t.TempDir() setTestHome(t, tmpDir) - _, err := loadIntegration("nonexistent") + _, err := LoadIntegration("nonexistent") if err == nil { t.Error("expected error for nonexistent integration, got nil") } diff --git a/cmd/config/integrations.go b/cmd/config/integrations.go deleted file mode 100644 index e5522d562..000000000 --- a/cmd/config/integrations.go +++ /dev/null @@ -1,1441 +0,0 @@ -package config - -import ( - "context" - "errors" - "fmt" - "net/http" - "os" - "os/exec" - "runtime" - "slices" - "strings" - "time" - - "github.com/ollama/ollama/api" - internalcloud "github.com/ollama/ollama/internal/cloud" - "github.com/ollama/ollama/internal/modelref" - "github.com/ollama/ollama/progress" - "github.com/spf13/cobra" -) - -// Runners execute the launching of a model with the integration - claude, codex -// Editors can edit config files (supports multi-model selection) - opencode, droid -// They are composable interfaces where in some cases an editor is also a runner - opencode, droid -// Runner can run an integration with a model. - -type Runner interface { - Run(model string, args []string) error - // String returns the human-readable name of the integration - String() string -} - -// Editor can edit config files (supports multi-model selection) -type Editor interface { - // Paths returns the paths to the config files for the integration - Paths() []string - // Edit updates the config files for the integration with the given models - Edit(models []string) error - // Models returns the models currently configured for the integration - // TODO(parthsareen): add error return to Models() - Models() []string -} - -// AliasConfigurer can configure model aliases (e.g., for subagent routing). -// Integrations like Claude and Codex use this to route model requests to local models. -type AliasConfigurer interface { - // ConfigureAliases prompts the user to configure aliases and returns the updated map. - ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error) - // SetAliases syncs the configured aliases to the server - SetAliases(ctx context.Context, aliases map[string]string) error -} - -// integrations is the registry of available integrations. -var integrations = map[string]Runner{ - "claude": &Claude{}, - "clawdbot": &Openclaw{}, - "cline": &Cline{}, - "codex": &Codex{}, - "moltbot": &Openclaw{}, - "droid": &Droid{}, - "opencode": &OpenCode{}, - "openclaw": &Openclaw{}, - "pi": &Pi{}, -} - -// recommendedModels are shown when the user has no models or as suggestions. -// Order matters: local models first, then cloud models. -var recommendedModels = []ModelItem{ - {Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true}, - {Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true}, - {Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true}, - {Name: "minimax-m2.5:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true}, - {Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true}, - {Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true}, -} - -// cloudModelLimits maps cloud model base names to their token limits. -// TODO(parthsareen): grab context/output limits from model info instead of hardcoding -var cloudModelLimits = map[string]cloudModelLimit{ - "minimax-m2.5": {Context: 204_800, Output: 128_000}, - "cogito-2.1:671b": {Context: 163_840, Output: 65_536}, - "deepseek-v3.1:671b": {Context: 163_840, Output: 163_840}, - "deepseek-v3.2": {Context: 163_840, Output: 65_536}, - "glm-4.6": {Context: 202_752, Output: 131_072}, - "glm-4.7": {Context: 202_752, Output: 131_072}, - "glm-5": {Context: 202_752, Output: 131_072}, - "gpt-oss:120b": {Context: 131_072, Output: 131_072}, - "gpt-oss:20b": {Context: 131_072, Output: 131_072}, - "kimi-k2:1t": {Context: 262_144, Output: 262_144}, - "kimi-k2.5": {Context: 262_144, Output: 262_144}, - "kimi-k2-thinking": {Context: 262_144, Output: 262_144}, - "nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072}, - "qwen3-coder:480b": {Context: 262_144, Output: 65_536}, - "qwen3-coder-next": {Context: 262_144, Output: 32_768}, - "qwen3-next:80b": {Context: 262_144, Output: 32_768}, - "qwen3.5": {Context: 262_144, Output: 32_768}, -} - -// recommendedVRAM maps local recommended models to their approximate VRAM requirement. -var recommendedVRAM = map[string]string{ - "glm-4.7-flash": "~25GB", - "qwen3.5": "~11GB", -} - -// integrationAliases are hidden from the interactive selector but work as CLI arguments. -var integrationAliases = map[string]bool{ - "clawdbot": true, - "moltbot": true, -} - -// integrationInstallHints maps integration names to install URLs. -var integrationInstallHints = map[string]string{ - "claude": "https://code.claude.com/docs/en/quickstart", - "cline": "https://cline.bot/cli", - "openclaw": "https://docs.openclaw.ai", - "codex": "https://developers.openai.com/codex/cli/", - "droid": "https://docs.factory.ai/cli/getting-started/quickstart", - "opencode": "https://opencode.ai", - "pi": "https://github.com/badlogic/pi-mono", -} - -// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable. -func hyperlink(url, text string) string { - return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text) -} - -// IntegrationInfo contains display information about a registered integration. -type IntegrationInfo struct { - Name string // registry key, e.g. "claude" - DisplayName string // human-readable, e.g. "Claude Code" - Description string // short description, e.g. "Anthropic's agentic coding tool" -} - -// integrationDescriptions maps integration names to short descriptions. -var integrationDescriptions = map[string]string{ - "claude": "Anthropic's coding tool with subagents", - "cline": "Autonomous coding agent with parallel execution", - "codex": "OpenAI's open-source coding agent", - "openclaw": "Personal AI with 100+ skills", - "droid": "Factory's coding agent across terminal and IDEs", - "opencode": "Anomaly's open-source coding agent", - "pi": "Minimal AI agent toolkit with plugin support", -} - -// integrationOrder defines a custom display order for integrations. -// Integrations listed here are placed at the end in the given order; -// all others appear first, sorted alphabetically. -var integrationOrder = []string{"opencode", "droid", "pi", "cline"} - -// ListIntegrationInfos returns all non-alias registered integrations, sorted by name -// with integrationOrder entries placed at the end. -func ListIntegrationInfos() []IntegrationInfo { - var result []IntegrationInfo - for name, r := range integrations { - if integrationAliases[name] { - continue - } - result = append(result, IntegrationInfo{ - Name: name, - DisplayName: r.String(), - Description: integrationDescriptions[name], - }) - } - - orderRank := make(map[string]int, len(integrationOrder)) - for i, name := range integrationOrder { - orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list" - } - - slices.SortFunc(result, func(a, b IntegrationInfo) int { - aRank, bRank := orderRank[a.Name], orderRank[b.Name] - // Both have custom order: sort by their rank - if aRank > 0 && bRank > 0 { - return aRank - bRank - } - // Only one has custom order: it goes last - if aRank > 0 { - return 1 - } - if bRank > 0 { - return -1 - } - // Neither has custom order: alphabetical - return strings.Compare(a.Name, b.Name) - }) - return result -} - -// IntegrationInstallHint returns a user-friendly install hint for the given integration, -// or an empty string if none is available. The URL is wrapped in an OSC 8 hyperlink -// so it is cmd+clickable in supported terminals. -func IntegrationInstallHint(name string) string { - url := integrationInstallHints[name] - if url == "" { - return "" - } - return "Install from " + hyperlink(url, url) -} - -// IsIntegrationInstalled checks if an integration binary is installed. -func IsIntegrationInstalled(name string) bool { - switch name { - case "claude": - c := &Claude{} - _, err := c.findPath() - return err == nil - case "openclaw": - if _, err := exec.LookPath("openclaw"); err == nil { - return true - } - if _, err := exec.LookPath("clawdbot"); err == nil { - return true - } - return false - case "codex": - _, err := exec.LookPath("codex") - return err == nil - case "droid": - _, err := exec.LookPath("droid") - return err == nil - case "cline": - _, err := exec.LookPath("cline") - return err == nil - case "opencode": - _, err := exec.LookPath("opencode") - return err == nil - case "pi": - _, err := exec.LookPath("pi") - return err == nil - default: - return true // Assume installed for unknown integrations - } -} - -// AutoInstallable returns true if the integration can be automatically -// installed when not found (e.g. via npm). -func AutoInstallable(name string) bool { - switch strings.ToLower(name) { - case "openclaw", "clawdbot", "moltbot": - return true - default: - return false - } -} - -// EnsureInstalled checks if an auto-installable integration is present and -// offers to install it if missing. Returns nil for non-auto-installable -// integrations or when the binary is already on PATH. -func EnsureInstalled(name string) error { - if !AutoInstallable(name) { - return nil - } - if IsIntegrationInstalled(name) { - return nil - } - _, err := ensureOpenclawInstalled() - return err -} - -// IsEditorIntegration returns true if the named integration uses multi-model -// selection (implements the Editor interface). -func IsEditorIntegration(name string) bool { - r, ok := integrations[strings.ToLower(name)] - if !ok { - return false - } - _, isEditor := r.(Editor) - return isEditor -} - -// SelectModel lets the user select a model to run. -// ModelItem represents a model for selection. -type ModelItem struct { - Name string - Description string - Recommended bool -} - -// SingleSelector is a function type for single item selection. -// current is the name of the previously selected item to highlight; empty means no pre-selection. -type SingleSelector func(title string, items []ModelItem, current string) (string, error) - -// MultiSelector is a function type for multi item selection. -type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error) - -// SelectModelWithSelector prompts the user to select a model using the provided selector. -func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (string, error) { - client, err := api.ClientFromEnvironment() - if err != nil { - return "", err - } - - models, err := client.List(ctx) - if err != nil { - return "", err - } - - var existing []modelInfo - for _, m := range models.Models { - existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""}) - } - - cloudDisabled, _ := cloudStatusDisabled(ctx, client) - if cloudDisabled { - existing = filterCloudModels(existing) - } - - lastModel := LastModel() - var preChecked []string - if lastModel != "" { - preChecked = []string{lastModel} - } - - items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel) - - if cloudDisabled { - items = filterCloudItems(items) - } - - if len(items) == 0 { - return "", fmt.Errorf("no models available, run 'ollama pull ' first") - } - - selected, err := selector("Select model to run:", items, "") - if err != nil { - return "", err - } - - // If the selected model isn't installed, pull it first - if !existingModels[selected] { - if !isCloudModelName(selected) { - msg := fmt.Sprintf("Download %s?", selected) - if ok, err := confirmPrompt(msg); err != nil { - return "", err - } else if !ok { - return "", errCancelled - } - fmt.Fprintf(os.Stderr, "\n") - if err := pullModel(ctx, client, selected); err != nil { - return "", fmt.Errorf("failed to pull %s: %w", selected, err) - } - } - } - - // If it's a cloud model, ensure user is signed in - if cloudModels[selected] { - user, err := client.Whoami(ctx) - if err == nil && user != nil && user.Name != "" { - return selected, nil - } - - var aErr api.AuthorizationError - if !errors.As(err, &aErr) || aErr.SigninURL == "" { - return "", err - } - - yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", selected)) - if err != nil || !yes { - return "", fmt.Errorf("%s requires sign in", selected) - } - - fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL) - - // Auto-open browser (best effort, fail silently) - switch runtime.GOOS { - case "darwin": - _ = exec.Command("open", aErr.SigninURL).Start() - case "linux": - _ = exec.Command("xdg-open", aErr.SigninURL).Start() - case "windows": - _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start() - } - - spinnerFrames := []string{"|", "/", "-", "\\"} - frame := 0 - - fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0]) - - ticker := time.NewTicker(200 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - fmt.Fprintf(os.Stderr, "\r\033[K") - return "", ctx.Err() - case <-ticker.C: - frame++ - fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) - - // poll every 10th frame (~2 seconds) - if frame%10 == 0 { - u, err := client.Whoami(ctx) - if err == nil && u != nil && u.Name != "" { - fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) - return selected, nil - } - } - } - } - } - - return selected, nil -} - -func SelectModel(ctx context.Context) (string, error) { - return SelectModelWithSelector(ctx, DefaultSingleSelector) -} - -// DefaultSingleSelector is the default single-select implementation. -var DefaultSingleSelector SingleSelector - -// DefaultMultiSelector is the default multi-select implementation. -var DefaultMultiSelector MultiSelector - -// DefaultSignIn provides a TUI-based sign-in flow. -// When set, ensureAuth uses it instead of plain text prompts. -// Returns the signed-in username or an error. -var DefaultSignIn func(modelName, signInURL string) (string, error) - -func selectIntegration() (string, error) { - if DefaultSingleSelector == nil { - return "", fmt.Errorf("no selector configured") - } - if len(integrations) == 0 { - return "", fmt.Errorf("no integrations available") - } - - var items []ModelItem - for name, r := range integrations { - if integrationAliases[name] { - continue - } - description := r.String() - if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 { - description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0]) - } - items = append(items, ModelItem{Name: name, Description: description}) - } - - orderRank := make(map[string]int, len(integrationOrder)) - for i, name := range integrationOrder { - orderRank[name] = i + 1 - } - slices.SortFunc(items, func(a, b ModelItem) int { - aRank, bRank := orderRank[a.Name], orderRank[b.Name] - if aRank > 0 && bRank > 0 { - return aRank - bRank - } - if aRank > 0 { - return 1 - } - if bRank > 0 { - return -1 - } - return strings.Compare(a.Name, b.Name) - }) - - return DefaultSingleSelector("Select integration:", items, "") -} - -// selectModelsWithSelectors lets the user select models for an integration using provided selectors. -func selectModelsWithSelectors(ctx context.Context, name, current string, single SingleSelector, multi MultiSelector) ([]string, error) { - r, ok := integrations[name] - if !ok { - return nil, fmt.Errorf("unknown integration: %s", name) - } - - client, err := api.ClientFromEnvironment() - if err != nil { - return nil, err - } - - models, err := client.List(ctx) - if err != nil { - return nil, err - } - - var existing []modelInfo - for _, m := range models.Models { - existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""}) - } - - cloudDisabled, _ := cloudStatusDisabled(ctx, client) - if cloudDisabled { - existing = filterCloudModels(existing) - } - - var preChecked []string - if saved, err := loadIntegration(name); err == nil { - preChecked = saved.Models - } else if editor, ok := r.(Editor); ok { - preChecked = editor.Models() - } - - items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current) - - if cloudDisabled { - items = filterCloudItems(items) - } - - if len(items) == 0 { - return nil, fmt.Errorf("no models available") - } - - var selected []string - if _, ok := r.(Editor); ok { - selected, err = multi(fmt.Sprintf("Select models for %s:", r), items, preChecked) - if err != nil { - return nil, err - } - } else { - prompt := fmt.Sprintf("Select model for %s:", r) - if _, ok := r.(AliasConfigurer); ok { - prompt = fmt.Sprintf("Select Primary model for %s:", r) - } - model, err := single(prompt, items, current) - if err != nil { - return nil, err - } - selected = []string{model} - } - - var toPull []string - for _, m := range selected { - if !existingModels[m] && !isCloudModelName(m) { - toPull = append(toPull, m) - } - } - if len(toPull) > 0 { - msg := fmt.Sprintf("Download %s?", strings.Join(toPull, ", ")) - if ok, err := confirmPrompt(msg); err != nil { - return nil, err - } else if !ok { - return nil, errCancelled - } - for _, m := range toPull { - fmt.Fprintf(os.Stderr, "\n") - if err := pullModel(ctx, client, m); err != nil { - return nil, fmt.Errorf("failed to pull %s: %w", m, err) - } - } - } - - if err := ensureAuth(ctx, client, cloudModels, selected); err != nil { - return nil, err - } - - return selected, nil -} - -// TODO(parthsareen): consolidate pull logic from call sites -func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error { - if isCloudModelName(model) || existingModels[model] { - return nil - } - return confirmAndPull(ctx, client, model) -} - -// TODO(parthsareen): pull this out to tui package -// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found. -func ShowOrPull(ctx context.Context, client *api.Client, model string) error { - if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil { - return nil - } - if isCloudModelName(model) { - return nil - } - return confirmAndPull(ctx, client, model) -} - -func confirmAndPull(ctx context.Context, client *api.Client, model string) error { - if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil { - return err - } else if !ok { - return errCancelled - } - fmt.Fprintf(os.Stderr, "\n") - if err := pullModel(ctx, client, model); err != nil { - return fmt.Errorf("failed to pull %s: %w", model, err) - } - return nil -} - -func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) { - client, err := api.ClientFromEnvironment() - if err != nil { - return nil, nil, nil, nil, err - } - - models, err := client.List(ctx) - if err != nil { - return nil, nil, nil, nil, err - } - - var existing []modelInfo - for _, m := range models.Models { - existing = append(existing, modelInfo{ - Name: m.Name, - Remote: m.RemoteModel != "", - }) - } - - cloudDisabled, _ := cloudStatusDisabled(ctx, client) - if cloudDisabled { - existing = filterCloudModels(existing) - } - - items, _, existingModels, cloudModels := buildModelList(existing, nil, "") - - if cloudDisabled { - items = filterCloudItems(items) - } - - if len(items) == 0 { - return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull ' first") - } - - return items, existingModels, cloudModels, client, nil -} - -func OpenBrowser(url string) { - switch runtime.GOOS { - case "darwin": - _ = exec.Command("open", url).Start() - case "linux": - _ = exec.Command("xdg-open", url).Start() - case "windows": - _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() - } -} - -func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error { - var selectedCloudModels []string - for _, m := range selected { - if cloudModels[m] { - selectedCloudModels = append(selectedCloudModels, m) - } - } - if len(selectedCloudModels) == 0 { - return nil - } - if disabled, known := cloudStatusDisabled(ctx, client); known && disabled { - return errors.New(internalcloud.DisabledError("remote inference is unavailable")) - } - - user, err := client.Whoami(ctx) - if err == nil && user != nil && user.Name != "" { - return nil - } - - var aErr api.AuthorizationError - if !errors.As(err, &aErr) || aErr.SigninURL == "" { - return err - } - - modelList := strings.Join(selectedCloudModels, ", ") - - if DefaultSignIn != nil { - _, err := DefaultSignIn(modelList, aErr.SigninURL) - if err != nil { - return fmt.Errorf("%s requires sign in", modelList) - } - return nil - } - - // Fallback: plain text sign-in flow - yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList)) - if err != nil || !yes { - return fmt.Errorf("%s requires sign in", modelList) - } - - fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL) - - OpenBrowser(aErr.SigninURL) - - spinnerFrames := []string{"|", "/", "-", "\\"} - frame := 0 - - fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0]) - - ticker := time.NewTicker(200 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - fmt.Fprintf(os.Stderr, "\r\033[K") - return ctx.Err() - case <-ticker.C: - frame++ - fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) - - // poll every 10th frame (~2 seconds) - if frame%10 == 0 { - u, err := client.Whoami(ctx) - if err == nil && u != nil && u.Name != "" { - fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) - return nil - } - } - } - } -} - -// selectModels lets the user select models for an integration using default selectors. -func selectModels(ctx context.Context, name, current string) ([]string, error) { - return selectModelsWithSelectors(ctx, name, current, DefaultSingleSelector, DefaultMultiSelector) -} - -func runIntegration(name, modelName string, args []string) error { - r, ok := integrations[name] - if !ok { - return fmt.Errorf("unknown integration: %s", name) - } - - fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName) - return r.Run(modelName, args) -} - -// syncAliases syncs aliases to server and saves locally for an AliasConfigurer. -func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error { - aliases := make(map[string]string) - for k, v := range existing { - aliases[k] = v - } - aliases["primary"] = model - - if isCloudModelName(model) { - aliases["fast"] = model - } else { - delete(aliases, "fast") - } - - if err := ac.SetAliases(ctx, aliases); err != nil { - return err - } - return saveAliases(name, aliases) -} - -// LaunchIntegration launches the named integration using saved config or prompts for setup. -func LaunchIntegration(name string) error { - r, ok := integrations[name] - if !ok { - return fmt.Errorf("unknown integration: %s", name) - } - - // Try to use saved config - if ic, err := loadIntegration(name); err == nil && len(ic.Models) > 0 { - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - if err := ShowOrPull(context.Background(), client, ic.Models[0]); err != nil { - return err - } - return runIntegration(name, ic.Models[0], nil) - } - - // No saved config - prompt user to run setup - return fmt.Errorf("%s is not configured. Run 'ollama launch %s' to set it up", r, name) -} - -// LaunchIntegrationWithModel launches the named integration with the specified model. -func LaunchIntegrationWithModel(name, modelName string) error { - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - if err := ShowOrPull(context.Background(), client, modelName); err != nil { - return err - } - return runIntegration(name, modelName, nil) -} - -// SaveAndEditIntegration saves the models for an Editor integration and runs its Edit method -// to write the integration's config files. -func SaveAndEditIntegration(name string, models []string) error { - r, ok := integrations[strings.ToLower(name)] - if !ok { - return fmt.Errorf("unknown integration: %s", name) - } - if err := SaveIntegration(name, models); err != nil { - return fmt.Errorf("failed to save: %w", err) - } - if editor, isEditor := r.(Editor); isEditor { - if err := editor.Edit(models); err != nil { - return fmt.Errorf("setup failed: %w", err) - } - } - return nil -} - -// resolveEditorModels filters out cloud-disabled models before editor launch. -// If no models remain, it invokes picker to collect a valid replacement list. -func resolveEditorModels(name string, models []string, picker func() ([]string, error)) ([]string, error) { - filtered := filterDisabledCloudModels(models) - if len(filtered) != len(models) { - if err := SaveIntegration(name, filtered); err != nil { - return nil, fmt.Errorf("failed to save: %w", err) - } - } - if len(filtered) > 0 { - return filtered, nil - } - - selected, err := picker() - if err != nil { - return nil, err - } - if err := SaveIntegration(name, selected); err != nil { - return nil, fmt.Errorf("failed to save: %w", err) - } - return selected, nil -} - -// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors. -func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error { - r, ok := integrations[name] - if !ok { - return fmt.Errorf("unknown integration: %s", name) - } - - models, err := selectModelsWithSelectors(ctx, name, "", single, multi) - if errors.Is(err, errCancelled) { - return errCancelled - } - if err != nil { - return err - } - - if editor, isEditor := r.(Editor); isEditor { - paths := editor.Paths() - if len(paths) > 0 { - fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r) - for _, p := range paths { - fmt.Fprintf(os.Stderr, " %s\n", p) - } - fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir()) - - if ok, _ := confirmPrompt("Proceed?"); !ok { - return nil - } - } - - if err := editor.Edit(models); err != nil { - return fmt.Errorf("setup failed: %w", err) - } - } - - if err := SaveIntegration(name, models); err != nil { - return fmt.Errorf("failed to save: %w", err) - } - - if len(models) == 1 { - fmt.Fprintf(os.Stderr, "Configured %s with %s\n", r, models[0]) - } else { - fmt.Fprintf(os.Stderr, "Configured %s with %d models (default: %s)\n", r, len(models), models[0]) - } - - return nil -} - -// ConfigureIntegration allows the user to select/change the model for an integration. -func ConfigureIntegration(ctx context.Context, name string) error { - return ConfigureIntegrationWithSelectors(ctx, name, DefaultSingleSelector, DefaultMultiSelector) -} - -// LaunchCmd returns the cobra command for launching integrations. -// The runTUI callback is called when no arguments are provided (alias for main TUI). -func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command { - var modelFlag string - var configFlag bool - - cmd := &cobra.Command{ - Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]", - Short: "Launch the Ollama menu or an integration", - Long: `Launch the Ollama interactive menu, or directly launch a specific integration. - -Without arguments, this is equivalent to running 'ollama' directly. - -Supported integrations: - claude Claude Code - cline Cline - codex Codex - droid Droid - opencode OpenCode - openclaw OpenClaw (aliases: clawdbot, moltbot) - pi Pi - -Examples: - ollama launch - ollama launch claude - ollama launch claude --model - ollama launch droid --config (does not auto-launch) - ollama launch codex -- -p myprofile (pass extra args to integration) - ollama launch codex -- --sandbox workspace-write`, - Args: cobra.ArbitraryArgs, - PreRunE: checkServerHeartbeat, - RunE: func(cmd *cobra.Command, args []string) error { - // No args and no flags - show the full TUI (same as bare 'ollama') - if len(args) == 0 && modelFlag == "" && !configFlag { - runTUI(cmd) - return nil - } - - // Extract integration name and args to pass through using -- separator - var name string - var passArgs []string - dashIdx := cmd.ArgsLenAtDash() - - if dashIdx == -1 { - // No "--" separator: only allow 0 or 1 args (integration name) - if len(args) > 1 { - return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:]) - } - if len(args) == 1 { - name = args[0] - } - } else { - // "--" was used: args before it = integration name, args after = passthrough - if dashIdx > 1 { - return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx) - } - if dashIdx == 1 { - name = args[0] - } - passArgs = args[dashIdx:] - } - - if name == "" { - var err error - name, err = selectIntegration() - if errors.Is(err, errCancelled) { - return nil - } - if err != nil { - return err - } - } - - r, ok := integrations[strings.ToLower(name)] - if !ok { - return fmt.Errorf("unknown integration: %s", name) - } - - if err := EnsureInstalled(name); err != nil { - return err - } - - if modelFlag != "" && IsCloudModelDisabled(cmd.Context(), modelFlag) { - modelFlag = "" - } - - // Handle AliasConfigurer integrations (claude, codex) - if ac, ok := r.(AliasConfigurer); ok { - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - - // Validate --model flag if provided - if modelFlag != "" { - if err := ShowOrPull(cmd.Context(), client, modelFlag); err != nil { - if errors.Is(err, errCancelled) { - return nil - } - return err - } - } - - var model string - var existingAliases map[string]string - - // Load saved config - if cfg, err := loadIntegration(name); err == nil { - existingAliases = cfg.Aliases - if len(cfg.Models) > 0 { - model = cfg.Models[0] - // AliasConfigurer integrations use single model; sanitize if multiple - if len(cfg.Models) > 1 { - _ = SaveIntegration(name, []string{model}) - } - } - } - - // --model flag overrides saved model - if modelFlag != "" { - model = modelFlag - } - - // Validate saved model still exists - if model != "" && modelFlag == "" { - if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) { - model = "" - } else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil { - fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset) - if err := ShowOrPull(cmd.Context(), client, model); err != nil { - model = "" - } - } - } - - // Show picker so user can change model (skip when --model flag provided) - aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "") - if errors.Is(err, errCancelled) { - return nil - } - if err != nil { - return err - } - model = aliases["primary"] - existingAliases = aliases - - // Ensure cloud models are authenticated - if isCloudModelName(model) { - if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil { - return err - } - } - - // Sync aliases and save - if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil { - fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset) - } - if err := SaveIntegration(name, []string{model}); err != nil { - return fmt.Errorf("failed to save: %w", err) - } - - // Launch (unless --config without confirmation) - if configFlag { - if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch { - return runIntegration(name, model, passArgs) - } - return nil - } - return runIntegration(name, model, passArgs) - } - - // Validate --model flag for non-AliasConfigurer integrations - if modelFlag != "" { - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - if err := ShowOrPull(cmd.Context(), client, modelFlag); err != nil { - if errors.Is(err, errCancelled) { - return nil - } - return err - } - } - - var models []string - if modelFlag != "" { - models = []string{modelFlag} - if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 { - for _, m := range existing.Models { - if m != modelFlag { - models = append(models, m) - } - } - } - models = filterDisabledCloudModels(models) - if len(models) == 0 { - var err error - models, err = selectModels(cmd.Context(), name, "") - if errors.Is(err, errCancelled) { - return nil - } - if err != nil { - return err - } - } - } else { - current := "" - if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 { - current = saved.Models[0] - } - var err error - models, err = selectModels(cmd.Context(), name, current) - if errors.Is(err, errCancelled) { - return nil - } - if err != nil { - return err - } - } - - if editor, isEditor := r.(Editor); isEditor { - paths := editor.Paths() - if len(paths) > 0 { - fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r) - for _, p := range paths { - fmt.Fprintf(os.Stderr, " %s\n", p) - } - fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir()) - - if ok, _ := confirmPrompt("Proceed?"); !ok { - return nil - } - } - } - - if err := SaveIntegration(name, models); err != nil { - return fmt.Errorf("failed to save: %w", err) - } - - if editor, isEditor := r.(Editor); isEditor { - if err := editor.Edit(models); err != nil { - return fmt.Errorf("setup failed: %w", err) - } - } - - if _, isEditor := r.(Editor); isEditor { - if len(models) == 1 { - fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r) - } else { - fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0]) - } - } - - if configFlag { - if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch { - return runIntegration(name, models[0], passArgs) - } - fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0]) - return nil - } - - return runIntegration(name, models[0], passArgs) - }, - } - - cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use") - cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching") - return cmd -} - -type modelInfo struct { - Name string - Remote bool - ToolCapable bool -} - -// buildModelList merges existing models with recommendations, sorts them, and returns -// the ordered items along with maps of existing and cloud model names. -func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) { - existingModels = make(map[string]bool) - cloudModels = make(map[string]bool) - recommended := make(map[string]bool) - var hasLocalModel, hasCloudModel bool - - recDesc := make(map[string]string) - for _, rec := range recommendedModels { - recommended[rec.Name] = true - recDesc[rec.Name] = rec.Description - } - - for _, m := range existing { - existingModels[m.Name] = true - if m.Remote { - cloudModels[m.Name] = true - hasCloudModel = true - } else { - hasLocalModel = true - } - displayName := strings.TrimSuffix(m.Name, ":latest") - existingModels[displayName] = true - item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]} - items = append(items, item) - } - - for _, rec := range recommendedModels { - if existingModels[rec.Name] || existingModels[rec.Name+":latest"] { - continue - } - items = append(items, rec) - if isCloudModelName(rec.Name) { - cloudModels[rec.Name] = true - } - } - - checked := make(map[string]bool, len(preChecked)) - for _, n := range preChecked { - checked[n] = true - } - - // Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest") - for _, item := range items { - if item.Name == current || strings.HasPrefix(item.Name, current+":") { - current = item.Name - break - } - } - - if checked[current] { - preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...) - } - - // Non-existing models get "install?" suffix and are pushed to the bottom. - // When user has no models, preserve recommended order. - notInstalled := make(map[string]bool) - for i := range items { - if !existingModels[items[i].Name] && !cloudModels[items[i].Name] { - notInstalled[items[i].Name] = true - var parts []string - if items[i].Description != "" { - parts = append(parts, items[i].Description) - } - if vram := recommendedVRAM[items[i].Name]; vram != "" { - parts = append(parts, vram) - } - parts = append(parts, "(not downloaded)") - items[i].Description = strings.Join(parts, ", ") - } - } - - // Build a recommended rank map to preserve ordering within tiers. - recRank := make(map[string]int) - for i, rec := range recommendedModels { - recRank[rec.Name] = i + 1 // 1-indexed; 0 means not recommended - } - - onlyLocal := hasLocalModel && !hasCloudModel - - if hasLocalModel || hasCloudModel { - slices.SortStableFunc(items, func(a, b ModelItem) int { - ac, bc := checked[a.Name], checked[b.Name] - aNew, bNew := notInstalled[a.Name], notInstalled[b.Name] - aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0 - aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name] - - // Checked/pre-selected always first - if ac != bc { - if ac { - return -1 - } - return 1 - } - - // Recommended above non-recommended - if aRec != bRec { - if aRec { - return -1 - } - return 1 - } - - // Both recommended - if aRec && bRec { - if aCloud != bCloud { - if onlyLocal { - // Local before cloud when only local installed - if aCloud { - return 1 - } - return -1 - } - // Cloud before local in mixed case - if aCloud { - return -1 - } - return 1 - } - return recRank[a.Name] - recRank[b.Name] - } - - // Both non-recommended: installed before not-installed - if aNew != bNew { - if aNew { - return 1 - } - return -1 - } - - return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)) - }) - } - - return items, preChecked, existingModels, cloudModels -} - -// IsCloudModelDisabled reports whether the given model name looks like a cloud -// model and cloud features are currently disabled on the server. -func IsCloudModelDisabled(ctx context.Context, name string) bool { - if !isCloudModelName(name) { - return false - } - client, err := api.ClientFromEnvironment() - if err != nil { - return false - } - disabled, _ := cloudStatusDisabled(ctx, client) - return disabled -} - -func isCloudModelName(name string) bool { - // TODO(drifkin): Replace this wrapper with inlining once things stabilize a bit - return modelref.HasExplicitCloudSource(name) -} - -func filterCloudModels(existing []modelInfo) []modelInfo { - filtered := existing[:0] - for _, m := range existing { - if !m.Remote { - filtered = append(filtered, m) - } - } - return filtered -} - -// filterDisabledCloudModels removes cloud models from a list when cloud is disabled. -func filterDisabledCloudModels(models []string) []string { - var filtered []string - for _, m := range models { - if !IsCloudModelDisabled(context.Background(), m) { - filtered = append(filtered, m) - } - } - return filtered -} - -func filterCloudItems(items []ModelItem) []ModelItem { - filtered := items[:0] - for _, item := range items { - if !isCloudModelName(item.Name) { - filtered = append(filtered, item) - } - } - return filtered -} - -func isCloudModel(ctx context.Context, client *api.Client, name string) bool { - if client == nil { - return false - } - resp, err := client.Show(ctx, &api.ShowRequest{Model: name}) - if err != nil { - return false - } - return resp.RemoteModel != "" -} - -// GetModelItems returns a list of model items including recommendations for the TUI. -// It includes all locally available models plus recommended models that aren't installed. -func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) { - client, err := api.ClientFromEnvironment() - if err != nil { - return nil, nil - } - - models, err := client.List(ctx) - if err != nil { - return nil, nil - } - - var existing []modelInfo - for _, m := range models.Models { - existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""}) - } - - cloudDisabled, _ := cloudStatusDisabled(ctx, client) - if cloudDisabled { - existing = filterCloudModels(existing) - } - - lastModel := LastModel() - var preChecked []string - if lastModel != "" { - preChecked = []string{lastModel} - } - - items, _, existingModels, _ := buildModelList(existing, preChecked, lastModel) - - if cloudDisabled { - items = filterCloudItems(items) - } - - return items, existingModels -} - -func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) { - status, err := client.CloudStatusExperimental(ctx) - if err != nil { - var statusErr api.StatusError - if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound { - return false, false - } - return false, false - } - return status.Cloud.Disabled, true -} - -func pullModel(ctx context.Context, client *api.Client, model string) error { - p := progress.NewProgress(os.Stderr) - defer p.Stop() - - bars := make(map[string]*progress.Bar) - var status string - var spinner *progress.Spinner - - fn := func(resp api.ProgressResponse) error { - if resp.Digest != "" { - if resp.Completed == 0 { - return nil - } - - if spinner != nil { - spinner.Stop() - } - - bar, ok := bars[resp.Digest] - if !ok { - name, isDigest := strings.CutPrefix(resp.Digest, "sha256:") - name = strings.TrimSpace(name) - if isDigest { - name = name[:min(12, len(name))] - } - bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed) - bars[resp.Digest] = bar - p.Add(resp.Digest, bar) - } - - bar.Set(resp.Completed) - } else if status != resp.Status { - if spinner != nil { - spinner.Stop() - } - - status = resp.Status - spinner = progress.NewSpinner(status) - p.Add(status, spinner) - } - - return nil - } - - request := api.PullRequest{Name: model} - return client.Pull(ctx, &request, fn) -} diff --git a/cmd/config/selector.go b/cmd/config/selector.go deleted file mode 100644 index e94f3bffd..000000000 --- a/cmd/config/selector.go +++ /dev/null @@ -1,59 +0,0 @@ -package config - -import ( - "errors" - "fmt" - "os" - - "golang.org/x/term" -) - -// ANSI escape sequences for terminal formatting. -const ( - ansiBold = "\033[1m" - ansiReset = "\033[0m" - ansiGray = "\033[37m" - ansiGreen = "\033[32m" - ansiYellow = "\033[33m" -) - -// ErrCancelled is returned when the user cancels a selection. -var ErrCancelled = errors.New("cancelled") - -// errCancelled is kept as an alias for backward compatibility within the package. -var errCancelled = ErrCancelled - -// DefaultConfirmPrompt provides a TUI-based confirmation prompt. -// When set, confirmPrompt delegates to it instead of using raw terminal I/O. -var DefaultConfirmPrompt func(prompt string) (bool, error) - -func confirmPrompt(prompt string) (bool, error) { - if DefaultConfirmPrompt != nil { - return DefaultConfirmPrompt(prompt) - } - - fd := int(os.Stdin.Fd()) - oldState, err := term.MakeRaw(fd) - if err != nil { - return false, err - } - defer term.Restore(fd, oldState) - - fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt) - - buf := make([]byte, 1) - for { - if _, err := os.Stdin.Read(buf); err != nil { - return false, err - } - - switch buf[0] { - case 'Y', 'y', 13: - fmt.Fprintf(os.Stderr, "yes\r\n") - return true, nil - case 'N', 'n', 27, 3: - fmt.Fprintf(os.Stderr, "no\r\n") - return false, nil - } - } -} diff --git a/cmd/config/selector_test.go b/cmd/config/selector_test.go deleted file mode 100644 index 3e84d1b5d..000000000 --- a/cmd/config/selector_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package config - -import ( - "testing" -) - -func TestErrCancelled(t *testing.T) { - t.Run("NotNil", func(t *testing.T) { - if errCancelled == nil { - t.Error("errCancelled should not be nil") - } - }) - - t.Run("Message", func(t *testing.T) { - if errCancelled.Error() != "cancelled" { - t.Errorf("expected 'cancelled', got %q", errCancelled.Error()) - } - }) -} diff --git a/cmd/config/files.go b/cmd/internal/fileutil/files.go similarity index 78% rename from cmd/config/files.go rename to cmd/internal/fileutil/files.go index 545e25c4d..bc57ba05e 100644 --- a/cmd/config/files.go +++ b/cmd/internal/fileutil/files.go @@ -1,4 +1,6 @@ -package config +// Package fileutil provides small shared helpers for reading JSON files +// and writing config files with backup-on-overwrite semantics. +package fileutil import ( "bytes" @@ -9,7 +11,8 @@ import ( "time" ) -func readJSONFile(path string) (map[string]any, error) { +// ReadJSON reads a JSON object file into a generic map. +func ReadJSON(path string) (map[string]any, error) { data, err := os.ReadFile(path) if err != nil { return nil, err @@ -33,12 +36,13 @@ func copyFile(src, dst string) error { return os.WriteFile(dst, data, info.Mode().Perm()) } -func backupDir() string { +// BackupDir returns the shared backup directory used before overwriting files. +func BackupDir() string { return filepath.Join(os.TempDir(), "ollama-backups") } func backupToTmp(srcPath string) (string, error) { - dir := backupDir() + dir := BackupDir() if err := os.MkdirAll(dir, 0o755); err != nil { return "", err } @@ -50,8 +54,8 @@ func backupToTmp(srcPath string) (string, error) { return backupPath, nil } -// writeWithBackup writes data to path via temp file + rename, backing up any existing file first -func writeWithBackup(path string, data []byte) error { +// WriteWithBackup writes data to path via temp file + rename, backing up any existing file first. +func WriteWithBackup(path string, data []byte) error { var backupPath string // backup must be created before any writes to the target file if existingContent, err := os.ReadFile(path); err == nil { diff --git a/cmd/config/files_test.go b/cmd/internal/fileutil/files_test.go similarity index 85% rename from cmd/config/files_test.go rename to cmd/internal/fileutil/files_test.go index e0aaea2b5..f00a63b14 100644 --- a/cmd/config/files_test.go +++ b/cmd/internal/fileutil/files_test.go @@ -1,4 +1,4 @@ -package config +package fileutil import ( "encoding/json" @@ -9,6 +9,21 @@ import ( "testing" ) +func TestMain(m *testing.M) { + tmpRoot, err := os.MkdirTemp("", "fileutil-test-*") + if err != nil { + panic(err) + } + + if err := os.Setenv("TMPDIR", tmpRoot); err != nil { + panic(err) + } + + code := m.Run() + _ = os.RemoveAll(tmpRoot) + os.Exit(code) +} + func mustMarshal(t *testing.T, v any) []byte { t.Helper() data, err := json.MarshalIndent(v, "", " ") @@ -18,14 +33,19 @@ func mustMarshal(t *testing.T, v any) []byte { return data } +func isolatedTempDir(t *testing.T) string { + t.Helper() + return t.TempDir() +} + func TestWriteWithBackup(t *testing.T) { - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) t.Run("creates file", func(t *testing.T) { path := filepath.Join(tmpDir, "new.json") data := mustMarshal(t, map[string]string{"key": "value"}) - if err := writeWithBackup(path, data); err != nil { + if err := WriteWithBackup(path, data); err != nil { t.Fatal(err) } @@ -43,17 +63,17 @@ func TestWriteWithBackup(t *testing.T) { } }) - t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) { + t.Run("creates backup in the temp backup directory", func(t *testing.T) { path := filepath.Join(tmpDir, "backup.json") os.WriteFile(path, []byte(`{"original": true}`), 0o644) data := mustMarshal(t, map[string]bool{"updated": true}) - if err := writeWithBackup(path, data); err != nil { + if err := WriteWithBackup(path, data); err != nil { t.Fatal(err) } - entries, err := os.ReadDir(backupDir()) + entries, err := os.ReadDir(BackupDir()) if err != nil { t.Fatal("backup directory not created") } @@ -63,7 +83,7 @@ func TestWriteWithBackup(t *testing.T) { if filepath.Ext(entry.Name()) != ".json" { name := entry.Name() if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." { - backupPath := filepath.Join(backupDir(), name) + backupPath := filepath.Join(BackupDir(), name) backup, err := os.ReadFile(backupPath) if err == nil { var backupData map[string]bool @@ -79,7 +99,7 @@ func TestWriteWithBackup(t *testing.T) { } if !foundBackup { - t.Error("backup file not created in /tmp/ollama-backups") + t.Error("backup file not created in backup directory") } current, _ := os.ReadFile(path) @@ -94,11 +114,11 @@ func TestWriteWithBackup(t *testing.T) { path := filepath.Join(tmpDir, "nobak.json") data := mustMarshal(t, map[string]string{"new": "file"}) - if err := writeWithBackup(path, data); err != nil { + if err := WriteWithBackup(path, data); err != nil { t.Fatal(err) } - entries, _ := os.ReadDir(backupDir()) + entries, _ := os.ReadDir(BackupDir()) for _, entry := range entries { if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." { t.Error("backup should not exist for new file") @@ -111,11 +131,11 @@ func TestWriteWithBackup(t *testing.T) { data := mustMarshal(t, map[string]string{"key": "value"}) - if err := writeWithBackup(path, data); err != nil { + if err := WriteWithBackup(path, data); err != nil { t.Fatal(err) } - entries1, _ := os.ReadDir(backupDir()) + entries1, _ := os.ReadDir(BackupDir()) countBefore := 0 for _, e := range entries1 { if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." { @@ -123,11 +143,11 @@ func TestWriteWithBackup(t *testing.T) { } } - if err := writeWithBackup(path, data); err != nil { + if err := WriteWithBackup(path, data); err != nil { t.Fatal(err) } - entries2, _ := os.ReadDir(backupDir()) + entries2, _ := os.ReadDir(BackupDir()) countAfter := 0 for _, e := range entries2 { if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." { @@ -145,11 +165,11 @@ func TestWriteWithBackup(t *testing.T) { os.WriteFile(path, []byte(`{"v": 1}`), 0o644) data := mustMarshal(t, map[string]int{"v": 2}) - if err := writeWithBackup(path, data); err != nil { + if err := WriteWithBackup(path, data); err != nil { t.Fatal(err) } - entries, _ := os.ReadDir(backupDir()) + entries, _ := os.ReadDir(BackupDir()) var found bool for _, entry := range entries { name := entry.Name() @@ -161,7 +181,7 @@ func TestWriteWithBackup(t *testing.T) { } } found = true - os.Remove(filepath.Join(backupDir(), name)) + os.Remove(filepath.Join(BackupDir(), name)) break } } @@ -180,7 +200,7 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) { t.Skip("permission tests unreliable on Windows") } - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) path := filepath.Join(tmpDir, "config.json") // Create original file @@ -188,13 +208,13 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) { os.WriteFile(path, originalContent, 0o644) // Make backup directory read-only to force backup failure - backupDir := backupDir() + backupDir := BackupDir() os.MkdirAll(backupDir, 0o755) os.Chmod(backupDir, 0o444) // Read-only defer os.Chmod(backupDir, 0o755) newContent := []byte(`{"updated": true}`) - err := writeWithBackup(path, newContent) + err := WriteWithBackup(path, newContent) // Should fail because backup couldn't be created if err == nil { @@ -215,7 +235,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) { t.Skip("permission tests unreliable on Windows") } - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) // Create a read-only directory readOnlyDir := filepath.Join(tmpDir, "readonly") @@ -224,7 +244,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) { defer os.Chmod(readOnlyDir, 0o755) path := filepath.Join(readOnlyDir, "config.json") - err := writeWithBackup(path, []byte(`{"test": true}`)) + err := WriteWithBackup(path, []byte(`{"test": true}`)) if err == nil { t.Error("expected permission error, got nil") @@ -234,10 +254,10 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) { // TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist. // writeWithBackup doesn't create directories - caller is responsible. func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) { - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json") - err := writeWithBackup(path, []byte(`{"test": true}`)) + err := WriteWithBackup(path, []byte(`{"test": true}`)) // Should fail because directory doesn't exist if err == nil { @@ -252,7 +272,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) { t.Skip("symlink tests may require admin on Windows") } - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) realFile := filepath.Join(tmpDir, "real.json") symlink := filepath.Join(tmpDir, "link.json") @@ -261,7 +281,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) { os.Symlink(realFile, symlink) // Write through symlink - err := writeWithBackup(symlink, []byte(`{"v": 2}`)) + err := WriteWithBackup(symlink, []byte(`{"v": 2}`)) if err != nil { t.Fatalf("writeWithBackup through symlink failed: %v", err) } @@ -276,7 +296,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) { // TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters. // User may have config files with unusual names. func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) { - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) // File with spaces and special chars path := filepath.Join(tmpDir, "my config (backup).json") @@ -305,7 +325,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) { t.Skip("permission preservation tests unreliable on Windows") } - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) src := filepath.Join(tmpDir, "src.json") dst := filepath.Join(tmpDir, "dst.json") @@ -327,7 +347,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) { // TestCopyFile_SourceNotFound verifies clear error when source doesn't exist. func TestCopyFile_SourceNotFound(t *testing.T) { - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) src := filepath.Join(tmpDir, "nonexistent.json") dst := filepath.Join(tmpDir, "dst.json") @@ -339,11 +359,11 @@ func TestCopyFile_SourceNotFound(t *testing.T) { // TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory. func TestWriteWithBackup_TargetIsDirectory(t *testing.T) { - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) dirPath := filepath.Join(tmpDir, "actualdir") os.MkdirAll(dirPath, 0o755) - err := writeWithBackup(dirPath, []byte(`{"test": true}`)) + err := WriteWithBackup(dirPath, []byte(`{"test": true}`)) if err == nil { t.Error("expected error when target is a directory, got nil") } @@ -351,10 +371,10 @@ func TestWriteWithBackup_TargetIsDirectory(t *testing.T) { // TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly. func TestWriteWithBackup_EmptyData(t *testing.T) { - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) path := filepath.Join(tmpDir, "empty.json") - err := writeWithBackup(path, []byte{}) + err := WriteWithBackup(path, []byte{}) if err != nil { t.Fatalf("writeWithBackup with empty data failed: %v", err) } @@ -375,7 +395,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) { t.Skip("permission tests unreliable on Windows") } - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) path := filepath.Join(tmpDir, "unreadable.json") // Create file and make it unreadable @@ -384,7 +404,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) { defer os.Chmod(path, 0o644) // Should fail because we can't read the file to compare/backup - err := writeWithBackup(path, []byte(`{"updated": true}`)) + err := WriteWithBackup(path, []byte(`{"updated": true}`)) if err == nil { t.Error("expected error when file is unreadable, got nil") } @@ -393,7 +413,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) { // TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes // within the same second (timestamp collision scenario). func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) { - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) path := filepath.Join(tmpDir, "rapid.json") // Create initial file @@ -402,7 +422,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) { // Rapid successive writes for i := 1; i <= 3; i++ { data := []byte(fmt.Sprintf(`{"v": %d}`, i)) - if err := writeWithBackup(path, data); err != nil { + if err := WriteWithBackup(path, data); err != nil { t.Fatalf("write %d failed: %v", i, err) } } @@ -414,7 +434,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) { } // Verify at least one backup exists - entries, _ := os.ReadDir(backupDir()) + entries, _ := os.ReadDir(BackupDir()) var backupCount int for _, e := range entries { if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." { @@ -432,8 +452,9 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) { t.Skip("test modifies system temp directory") } + tmpDir := isolatedTempDir(t) // Create a file at the backup directory path - backupPath := backupDir() + backupPath := BackupDir() // Clean up any existing directory first os.RemoveAll(backupPath) // Create a file instead of directory @@ -443,11 +464,10 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) { os.MkdirAll(backupPath, 0o755) }() - tmpDir := t.TempDir() path := filepath.Join(tmpDir, "test.json") os.WriteFile(path, []byte(`{"original": true}`), 0o644) - err := writeWithBackup(path, []byte(`{"updated": true}`)) + err := WriteWithBackup(path, []byte(`{"updated": true}`)) if err == nil { t.Error("expected error when backup dir is a file, got nil") } @@ -459,7 +479,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) { t.Skip("permission tests unreliable on Windows") } - tmpDir := t.TempDir() + tmpDir := isolatedTempDir(t) // Count existing temp files countTempFiles := func() int { @@ -493,7 +513,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) { badPath := filepath.Join(tmpDir, "isdir") os.MkdirAll(badPath, 0o755) - _ = writeWithBackup(badPath, []byte(`{"test": true}`)) + _ = WriteWithBackup(badPath, []byte(`{"test": true}`)) after := countTempFiles() if after > before { diff --git a/cmd/launch/claude.go b/cmd/launch/claude.go new file mode 100644 index 000000000..19ef49b01 --- /dev/null +++ b/cmd/launch/claude.go @@ -0,0 +1,77 @@ +package launch + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + + "github.com/ollama/ollama/envconfig" +) + +// Claude implements Runner for Claude Code integration. +type Claude struct{} + +func (c *Claude) String() string { return "Claude Code" } + +func (c *Claude) args(model string, extra []string) []string { + var args []string + if model != "" { + args = append(args, "--model", model) + } + args = append(args, extra...) + return args +} + +func (c *Claude) findPath() (string, error) { + if p, err := exec.LookPath("claude"); err == nil { + return p, nil + } + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + name := "claude" + if runtime.GOOS == "windows" { + name = "claude.exe" + } + fallback := filepath.Join(home, ".claude", "local", name) + if _, err := os.Stat(fallback); err != nil { + return "", err + } + return fallback, nil +} + +func (c *Claude) Run(model string, args []string) error { + claudePath, err := c.findPath() + if err != nil { + return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart") + } + + cmd := exec.Command(claudePath, c.args(model, args)...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + env := append(os.Environ(), + "ANTHROPIC_BASE_URL="+envconfig.Host().String(), + "ANTHROPIC_API_KEY=", + "ANTHROPIC_AUTH_TOKEN=ollama", + ) + + env = append(env, c.modelEnvVars(model)...) + + cmd.Env = env + return cmd.Run() +} + +// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama. +func (c *Claude) modelEnvVars(model string) []string { + return []string{ + "ANTHROPIC_DEFAULT_OPUS_MODEL=" + model, + "ANTHROPIC_DEFAULT_SONNET_MODEL=" + model, + "ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model, + "CLAUDE_CODE_SUBAGENT_MODEL=" + model, + } +} diff --git a/cmd/config/claude_test.go b/cmd/launch/claude_test.go similarity index 59% rename from cmd/config/claude_test.go rename to cmd/launch/claude_test.go index e5ad16a20..689415b44 100644 --- a/cmd/config/claude_test.go +++ b/cmd/launch/claude_test.go @@ -1,4 +1,4 @@ -package config +package launch import ( "os" @@ -117,10 +117,7 @@ func TestClaudeModelEnvVars(t *testing.T) { return m } - t.Run("falls back to model param when no aliases saved", func(t *testing.T) { - tmpDir := t.TempDir() - setTestHome(t, tmpDir) - + t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) { got := envMap(c.modelEnvVars("llama3.2")) if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" { t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"]) @@ -136,63 +133,19 @@ func TestClaudeModelEnvVars(t *testing.T) { } }) - t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) { - tmpDir := t.TempDir() - setTestHome(t, tmpDir) - - SaveIntegration("claude", []string{"qwen3:8b"}) - saveAliases("claude", map[string]string{"primary": "qwen3:8b"}) - - got := envMap(c.modelEnvVars("qwen3:8b")) - if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" { - t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"]) + t.Run("supports empty model", func(t *testing.T) { + got := envMap(c.modelEnvVars("")) + if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" { + t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"]) } - if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" { - t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"]) + if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" { + t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"]) } - if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" { - t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"]) + if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" { + t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"]) } - if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" { - t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"]) - } - }) - - t.Run("uses fast alias for haiku", func(t *testing.T) { - tmpDir := t.TempDir() - setTestHome(t, tmpDir) - - SaveIntegration("claude", []string{"llama3.2:70b"}) - saveAliases("claude", map[string]string{ - "primary": "llama3.2:70b", - "fast": "llama3.2:8b", - }) - - got := envMap(c.modelEnvVars("llama3.2:70b")) - if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" { - t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"]) - } - if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" { - t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"]) - } - if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" { - t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"]) - } - if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" { - t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"]) - } - }) - - t.Run("alias primary overrides model param", func(t *testing.T) { - tmpDir := t.TempDir() - setTestHome(t, tmpDir) - - SaveIntegration("claude", []string{"saved-model"}) - saveAliases("claude", map[string]string{"primary": "saved-model"}) - - got := envMap(c.modelEnvVars("different-model")) - if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" { - t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"]) + if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" { + t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"]) } }) } diff --git a/cmd/config/cline.go b/cmd/launch/cline.go similarity index 77% rename from cmd/config/cline.go rename to cmd/launch/cline.go index 847d8d431..9ad886809 100644 --- a/cmd/config/cline.go +++ b/cmd/launch/cline.go @@ -1,14 +1,13 @@ -package config +package launch import ( - "context" "encoding/json" - "errors" "fmt" "os" "os/exec" "path/filepath" + "github.com/ollama/ollama/cmd/internal/fileutil" "github.com/ollama/ollama/envconfig" ) @@ -22,24 +21,6 @@ func (c *Cline) Run(model string, args []string) error { return fmt.Errorf("cline is not installed, install with: npm install -g cline") } - models := []string{model} - if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 { - models = config.Models - } - var err error - models, err = resolveEditorModels("cline", models, func() ([]string, error) { - return selectModels(context.Background(), "cline", "") - }) - if errors.Is(err, errCancelled) { - return nil - } - if err != nil { - return err - } - if err := c.Edit(models); err != nil { - return fmt.Errorf("setup failed: %w", err) - } - cmd := exec.Command("cline", args...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -97,7 +78,7 @@ func (c *Cline) Edit(models []string) error { if err != nil { return err } - return writeWithBackup(configPath, data) + return fileutil.WriteWithBackup(configPath, data) } func (c *Cline) Models() []string { @@ -106,7 +87,7 @@ func (c *Cline) Models() []string { return nil } - config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json")) + config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json")) if err != nil { return nil } diff --git a/cmd/config/cline_test.go b/cmd/launch/cline_test.go similarity index 99% rename from cmd/config/cline_test.go rename to cmd/launch/cline_test.go index 7e9f7f07c..f2440a27f 100644 --- a/cmd/config/cline_test.go +++ b/cmd/launch/cline_test.go @@ -1,4 +1,4 @@ -package config +package launch import ( "encoding/json" diff --git a/cmd/config/codex.go b/cmd/launch/codex.go similarity index 99% rename from cmd/config/codex.go rename to cmd/launch/codex.go index ee2c70542..821669563 100644 --- a/cmd/config/codex.go +++ b/cmd/launch/codex.go @@ -1,4 +1,4 @@ -package config +package launch import ( "fmt" diff --git a/cmd/config/codex_test.go b/cmd/launch/codex_test.go similarity index 98% rename from cmd/config/codex_test.go rename to cmd/launch/codex_test.go index e886fc4ef..c547d55fe 100644 --- a/cmd/config/codex_test.go +++ b/cmd/launch/codex_test.go @@ -1,4 +1,4 @@ -package config +package launch import ( "slices" diff --git a/cmd/launch/command_test.go b/cmd/launch/command_test.go new file mode 100644 index 000000000..cb90bb3d1 --- /dev/null +++ b/cmd/launch/command_test.go @@ -0,0 +1,494 @@ +package launch + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/cmd/config" + "github.com/spf13/cobra" +) + +func captureStderr(t *testing.T, fn func()) string { + t.Helper() + + oldStderr := os.Stderr + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("failed to create stderr pipe: %v", err) + } + os.Stderr = w + defer func() { + os.Stderr = oldStderr + }() + + done := make(chan string, 1) + go func() { + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + done <- buf.String() + }() + + fn() + + _ = w.Close() + return <-done +} + +func TestLaunchCmd(t *testing.T) { + mockCheck := func(cmd *cobra.Command, args []string) error { + return nil + } + mockTUI := func(cmd *cobra.Command) {} + cmd := LaunchCmd(mockCheck, mockTUI) + + t.Run("command structure", func(t *testing.T) { + if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" { + t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]") + } + if cmd.Short == "" { + t.Error("Short description should not be empty") + } + if cmd.Long == "" { + t.Error("Long description should not be empty") + } + }) + + t.Run("flags exist", func(t *testing.T) { + if cmd.Flags().Lookup("model") == nil { + t.Error("--model flag should exist") + } + if cmd.Flags().Lookup("config") == nil { + t.Error("--config flag should exist") + } + if cmd.Flags().Lookup("yes") == nil { + t.Error("--yes flag should exist") + } + }) + + t.Run("PreRunE is set", func(t *testing.T) { + if cmd.PreRunE == nil { + t.Error("PreRunE should be set to checkServerHeartbeat") + } + }) +} + +func TestLaunchCmdTUICallback(t *testing.T) { + mockCheck := func(cmd *cobra.Command, args []string) error { + return nil + } + + t.Run("no args calls TUI", func(t *testing.T) { + tuiCalled := false + mockTUI := func(cmd *cobra.Command) { + tuiCalled = true + } + + cmd := LaunchCmd(mockCheck, mockTUI) + cmd.SetArgs([]string{}) + _ = cmd.Execute() + + if !tuiCalled { + t.Error("TUI callback should be called when no args provided") + } + }) + + t.Run("integration arg bypasses TUI", func(t *testing.T) { + srv := httptest.NewServer(http.NotFoundHandler()) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + tuiCalled := false + mockTUI := func(cmd *cobra.Command) { + tuiCalled = true + } + + cmd := LaunchCmd(mockCheck, mockTUI) + cmd.SetArgs([]string{"claude"}) + _ = cmd.Execute() + + if tuiCalled { + t.Error("TUI callback should NOT be called when integration arg provided") + } + }) + + t.Run("--model flag without integration returns error", func(t *testing.T) { + tuiCalled := false + mockTUI := func(cmd *cobra.Command) { + tuiCalled = true + } + + cmd := LaunchCmd(mockCheck, mockTUI) + cmd.SetArgs([]string{"--model", "test-model"}) + err := cmd.Execute() + + if err == nil { + t.Fatal("expected --model without an integration to fail") + } + if !strings.Contains(err.Error(), "require an integration name") { + t.Fatalf("expected integration-name guidance, got %v", err) + } + if tuiCalled { + t.Error("TUI callback should NOT be called when --model is provided without an integration") + } + }) + + t.Run("--config flag without integration returns error", func(t *testing.T) { + tuiCalled := false + mockTUI := func(cmd *cobra.Command) { + tuiCalled = true + } + + cmd := LaunchCmd(mockCheck, mockTUI) + cmd.SetArgs([]string{"--config"}) + err := cmd.Execute() + + if err == nil { + t.Fatal("expected --config without an integration to fail") + } + if !strings.Contains(err.Error(), "require an integration name") { + t.Fatalf("expected integration-name guidance, got %v", err) + } + if tuiCalled { + t.Error("TUI callback should NOT be called when --config is provided without an integration") + } + }) + + t.Run("extra args without integration return error", func(t *testing.T) { + tuiCalled := false + mockTUI := func(cmd *cobra.Command) { + tuiCalled = true + } + + cmd := LaunchCmd(mockCheck, mockTUI) + cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"}) + err := cmd.Execute() + + if err == nil { + t.Fatal("expected flags and extra args without an integration to fail") + } + if !strings.Contains(err.Error(), "require an integration name") { + t.Fatalf("expected integration-name guidance, got %v", err) + } + if tuiCalled { + t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration") + } + }) +} + +func TestLaunchCmdNilHeartbeat(t *testing.T) { + cmd := LaunchCmd(nil, nil) + if cmd == nil { + t.Fatal("LaunchCmd returned nil") + } + if cmd.PreRunE != nil { + t.Log("Note: PreRunE is set even when nil is passed (acceptable)") + } +} + +func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + + if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil { + t.Fatalf("failed to seed saved config: %v", err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/status": + fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`) + case "/api/show": + fmt.Fprintf(w, `{"model":"llama3.2"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + stub := &launcherEditorRunner{} + restore := OverrideIntegration("stubeditor", stub) + defer restore() + + cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {}) + cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("launch command failed: %v", err) + } + + saved, err := config.LoadIntegration("stubeditor") + if err != nil { + t.Fatalf("failed to reload integration config: %v", err) + } + if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" { + t.Fatalf("saved models mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" { + t.Fatalf("editor models mismatch (-want +got):\n%s", diff) + } + if stub.ranModel != "llama3.2" { + t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel) + } +} + +func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/status": + fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`) + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`) + case "/api/show": + fmt.Fprint(w, `{"model":"llama3.2"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + stub := &launcherSingleRunner{} + restore := OverrideIntegration("stubapp", stub) + defer restore() + + oldSelector := DefaultSingleSelector + defer func() { DefaultSingleSelector = oldSelector }() + + var selectorCalls int + var gotCurrent string + DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + selectorCalls++ + gotCurrent = current + return "llama3.2", nil + } + + cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {}) + cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"}) + stderr := captureStderr(t, func() { + if err := cmd.Execute(); err != nil { + t.Fatalf("launch command failed: %v", err) + } + }) + + if selectorCalls != 1 { + t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls) + } + if gotCurrent != "" { + t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent) + } + if stub.ranModel != "llama3.2" { + t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel) + } + if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") { + t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr) + } + + saved, err := config.LoadIntegration("stubapp") + if err != nil { + t.Fatalf("failed to reload integration config: %v", err) + } + if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" { + t.Fatalf("saved models mismatch (-want +got):\n%s", diff) + } +} + +func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + withInteractiveSession(t, false) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + fmt.Fprint(w, `{"model":"llama3.2"}`) + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"not found"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}} + restore := OverrideIntegration("stubeditor", stub) + defer restore() + + DefaultConfirmPrompt = func(prompt string) (bool, error) { + t.Fatalf("unexpected prompt with --yes: %q", prompt) + return false, nil + } + + cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {}) + cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("launch command with --yes failed: %v", err) + } + + if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" { + t.Fatalf("editor models mismatch (-want +got):\n%s", diff) + } + if stub.ranModel != "llama3.2" { + t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel) + } +} + +func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + withInteractiveSession(t, false) + + var pullCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"model not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"status":"success"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + stub := &launcherSingleRunner{} + restore := OverrideIntegration("stubapp", stub) + defer restore() + + DefaultConfirmPrompt = func(prompt string) (bool, error) { + t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt) + return false, nil + } + + cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {}) + cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("launch command with --yes failed: %v", err) + } + + if !pullCalled { + t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode") + } + if stub.ranModel != "missing-model" { + t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel) + } +} + +func TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + withInteractiveSession(t, false) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + fmt.Fprint(w, `{"model":"llama3.2"}`) + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"not found"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}} + restore := OverrideIntegration("stubeditor", stub) + defer restore() + + DefaultConfirmPrompt = func(prompt string) (bool, error) { + t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt) + return false, nil + } + + cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {}) + cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"}) + err := cmd.Execute() + if err == nil { + t.Fatal("expected launch command to fail without --yes in headless mode") + } + if !strings.Contains(err.Error(), "re-run with --yes") { + t.Fatalf("expected actionable --yes guidance, got %v", err) + } + if len(stub.edited) != 0 { + t.Fatalf("expected no editor writes when confirmation is blocked, got %v", stub.edited) + } + if stub.ranModel != "" { + t.Fatalf("expected launch to abort before run, got %q", stub.ranModel) + } +} + +func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + + if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil { + t.Fatalf("failed to seed saved config: %v", err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`) + case "/api/show": + fmt.Fprint(w, `{"model":"qwen3:8b"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + stub := &launcherSingleRunner{} + restore := OverrideIntegration("stubapp", stub) + defer restore() + + oldSelector := DefaultSingleSelector + defer func() { DefaultSingleSelector = oldSelector }() + + var gotCurrent string + DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + gotCurrent = current + return "qwen3:8b", nil + } + + cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {}) + cmd.SetArgs([]string{"stubapp"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("launch command failed: %v", err) + } + + if gotCurrent != "llama3.2" { + t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent) + } + if stub.ranModel != "qwen3:8b" { + t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel) + } + + saved, err := config.LoadIntegration("stubapp") + if err != nil { + t.Fatalf("failed to reload integration config: %v", err) + } + if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" { + t.Fatalf("saved models mismatch (-want +got):\n%s", diff) + } +} diff --git a/cmd/config/droid.go b/cmd/launch/droid.go similarity index 88% rename from cmd/config/droid.go rename to cmd/launch/droid.go index ed88c0177..1612352a5 100644 --- a/cmd/config/droid.go +++ b/cmd/launch/droid.go @@ -1,15 +1,14 @@ -package config +package launch import ( - "context" "encoding/json" - "errors" "fmt" "os" "os/exec" "path/filepath" "slices" + "github.com/ollama/ollama/cmd/internal/fileutil" "github.com/ollama/ollama/envconfig" ) @@ -46,25 +45,6 @@ func (d *Droid) Run(model string, args []string) error { return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart") } - // Call Edit() to ensure config is up-to-date before launch - models := []string{model} - if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 { - models = config.Models - } - var err error - models, err = resolveEditorModels("droid", models, func() ([]string, error) { - return selectModels(context.Background(), "droid", "") - }) - if errors.Is(err, errCancelled) { - return nil - } - if err != nil { - return err - } - if err := d.Edit(models); err != nil { - return fmt.Errorf("setup failed: %w", err) - } - cmd := exec.Command("droid", args...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -110,6 +90,16 @@ func (d *Droid) Edit(models []string) error { json.Unmarshal(data, &settings) // ignore error, zero values are fine } + settingsMap = updateDroidSettings(settingsMap, settings, models) + + data, err := json.MarshalIndent(settingsMap, "", " ") + if err != nil { + return err + } + return fileutil.WriteWithBackup(settingsPath, data) +} + +func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []string) map[string]any { // Keep only non-Ollama models from the raw map (preserves extra fields) // Rebuild Ollama models var nonOllamaModels []any @@ -165,12 +155,7 @@ func (d *Droid) Edit(models []string) error { } settingsMap["sessionDefaultSettings"] = sessionSettings - - data, err := json.MarshalIndent(settingsMap, "", " ") - if err != nil { - return err - } - return writeWithBackup(settingsPath, data) + return settingsMap } func (d *Droid) Models() []string { diff --git a/cmd/config/droid_test.go b/cmd/launch/droid_test.go similarity index 98% rename from cmd/config/droid_test.go rename to cmd/launch/droid_test.go index ac26aef58..f37143060 100644 --- a/cmd/config/droid_test.go +++ b/cmd/launch/droid_test.go @@ -1,4 +1,4 @@ -package config +package launch import ( "encoding/json" @@ -6,6 +6,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/ollama/ollama/cmd/internal/fileutil" ) func TestDroidIntegration(t *testing.T) { @@ -362,7 +364,7 @@ func TestDroidEdit_DuplicateModels(t *testing.T) { t.Fatalf("Edit with duplicates failed: %v", err) } - settings, err := readJSONFile(settingsPath) + settings, err := fileutil.ReadJSON(settingsPath) if err != nil { t.Fatalf("readJSONFile failed: %v", err) } @@ -392,7 +394,7 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) { } // Malformed entries (non-object) are dropped - only valid model objects are preserved - settings, _ := readJSONFile(settingsPath) + settings, _ := fileutil.ReadJSON(settingsPath) customModels, _ := settings["customModels"].([]any) // Should have: 1 new Ollama model only (malformed entries dropped) @@ -419,7 +421,7 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) { } // Should create proper sessionDefaultSettings - settings, _ := readJSONFile(settingsPath) + settings, _ := fileutil.ReadJSON(settingsPath) session, ok := settings["sessionDefaultSettings"].(map[string]any) if !ok { t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"]) @@ -1008,34 +1010,34 @@ func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) { } func TestDroidEdit_MissingCustomModelsKey(t *testing.T) { - d := &Droid{} - tmpDir := t.TempDir() - setTestHome(t, tmpDir) - - settingsDir := filepath.Join(tmpDir, ".factory") - settingsPath := filepath.Join(settingsDir, "settings.json") - - os.MkdirAll(settingsDir, 0o755) - // No customModels key at all original := `{ "diffMode": "github", "sessionDefaultSettings": {"autonomyMode": "auto-high"} }` - os.WriteFile(settingsPath, []byte(original), 0o644) - if err := d.Edit([]string{"model-a"}); err != nil { + var settingsStruct droidSettings + var settings map[string]any + if err := json.Unmarshal([]byte(original), &settings); err != nil { + t.Fatal(err) + } + if err := json.Unmarshal([]byte(original), &settingsStruct); err != nil { t.Fatal(err) } - data, _ := os.ReadFile(settingsPath) - var settings map[string]any - json.Unmarshal(data, &settings) + settings = updateDroidSettings(settings, settingsStruct, []string{"model-a"}) // Original fields preserved if settings["diffMode"] != "github" { t.Error("diffMode not preserved") } + session, ok := settings["sessionDefaultSettings"].(map[string]any) + if !ok { + t.Fatal("sessionDefaultSettings not preserved") + } + if session["autonomyMode"] != "auto-high" { + t.Error("sessionDefaultSettings.autonomyMode not preserved") + } // customModels created models, ok := settings["customModels"].([]any) diff --git a/cmd/config/integrations_test.go b/cmd/launch/integrations_test.go similarity index 75% rename from cmd/config/integrations_test.go rename to cmd/launch/integrations_test.go index 8f4c262df..1219f4e55 100644 --- a/cmd/config/integrations_test.go +++ b/cmd/launch/integrations_test.go @@ -1,8 +1,9 @@ -package config +package launch import ( "context" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -13,12 +14,12 @@ import ( "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" - "github.com/spf13/cobra" ) type stubEditorRunner struct { edited [][]string ranModel string + editErr error } func (s *stubEditorRunner) Run(model string, args []string) error { @@ -31,6 +32,9 @@ func (s *stubEditorRunner) String() string { return "StubEditor" } func (s *stubEditorRunner) Paths() []string { return nil } func (s *stubEditorRunner) Edit(models []string) error { + if s.editErr != nil { + return s.editErr + } cloned := append([]string(nil), models...) s.edited = append(s.edited, cloned) return nil @@ -111,120 +115,8 @@ func TestHasLocalModel(t *testing.T) { } } -func TestLaunchCmd(t *testing.T) { - // Mock checkServerHeartbeat that always succeeds - mockCheck := func(cmd *cobra.Command, args []string) error { - return nil - } - mockTUI := func(cmd *cobra.Command) {} - cmd := LaunchCmd(mockCheck, mockTUI) - - t.Run("command structure", func(t *testing.T) { - if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" { - t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]") - } - if cmd.Short == "" { - t.Error("Short description should not be empty") - } - if cmd.Long == "" { - t.Error("Long description should not be empty") - } - }) - - t.Run("flags exist", func(t *testing.T) { - modelFlag := cmd.Flags().Lookup("model") - if modelFlag == nil { - t.Error("--model flag should exist") - } - - configFlag := cmd.Flags().Lookup("config") - if configFlag == nil { - t.Error("--config flag should exist") - } - }) - - t.Run("PreRunE is set", func(t *testing.T) { - if cmd.PreRunE == nil { - t.Error("PreRunE should be set to checkServerHeartbeat") - } - }) -} - -func TestLaunchCmd_TUICallback(t *testing.T) { - mockCheck := func(cmd *cobra.Command, args []string) error { - return nil - } - - t.Run("no args calls TUI", func(t *testing.T) { - tuiCalled := false - mockTUI := func(cmd *cobra.Command) { - tuiCalled = true - } - - cmd := LaunchCmd(mockCheck, mockTUI) - cmd.SetArgs([]string{}) - _ = cmd.Execute() - - if !tuiCalled { - t.Error("TUI callback should be called when no args provided") - } - }) - - t.Run("integration arg bypasses TUI", func(t *testing.T) { - srv := httptest.NewServer(http.NotFoundHandler()) - defer srv.Close() - t.Setenv("OLLAMA_HOST", srv.URL) - - tuiCalled := false - mockTUI := func(cmd *cobra.Command) { - tuiCalled = true - } - - cmd := LaunchCmd(mockCheck, mockTUI) - cmd.SetArgs([]string{"claude"}) - // Will error because claude isn't configured, but that's OK - _ = cmd.Execute() - - if tuiCalled { - t.Error("TUI callback should NOT be called when integration arg provided") - } - }) - - t.Run("--model flag bypasses TUI", func(t *testing.T) { - tuiCalled := false - mockTUI := func(cmd *cobra.Command) { - tuiCalled = true - } - - cmd := LaunchCmd(mockCheck, mockTUI) - cmd.SetArgs([]string{"--model", "test-model"}) - // Will error because no integration specified, but that's OK - _ = cmd.Execute() - - if tuiCalled { - t.Error("TUI callback should NOT be called when --model flag provided") - } - }) - - t.Run("--config flag bypasses TUI", func(t *testing.T) { - tuiCalled := false - mockTUI := func(cmd *cobra.Command) { - tuiCalled = true - } - - cmd := LaunchCmd(mockCheck, mockTUI) - cmd.SetArgs([]string{"--config"}) - // Will error because no integration specified, but that's OK - _ = cmd.Execute() - - if tuiCalled { - t.Error("TUI callback should NOT be called when --config flag provided") - } - }) -} - -func TestRunIntegration_UnknownIntegration(t *testing.T) { - err := runIntegration("unknown-integration", "model", nil) +func TestLookupIntegration_UnknownIntegration(t *testing.T) { + _, _, err := LookupIntegration("unknown-integration") if err == nil { t.Error("expected error for unknown integration, got nil") } @@ -233,6 +125,17 @@ func TestRunIntegration_UnknownIntegration(t *testing.T) { } } +func TestIsIntegrationInstalled_UnknownIntegrationReturnsFalse(t *testing.T) { + stderr := captureStderr(t, func() { + if IsIntegrationInstalled("unknown-integration") { + t.Fatal("expected unknown integration to report not installed") + } + }) + if !strings.Contains(stderr, `Ollama couldn't find integration "unknown-integration", so it'll show up as not installed.`) { + t.Fatalf("expected unknown-integration warning, got stderr: %q", stderr) + } +} + func TestHasLocalModel_DocumentsHeuristic(t *testing.T) { tests := []struct { name string @@ -261,19 +164,6 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) { } } -func TestLaunchCmd_NilHeartbeat(t *testing.T) { - // This should not panic - cmd creation should work even with nil - cmd := LaunchCmd(nil, nil) - if cmd == nil { - t.Fatal("LaunchCmd returned nil") - } - - // PreRunE should be nil when passed nil - if cmd.PreRunE != nil { - t.Log("Note: PreRunE is set even when nil is passed (acceptable)") - } -} - func TestAllIntegrations_HaveRequiredMethods(t *testing.T) { for name, r := range integrations { t.Run(name, func(t *testing.T) { @@ -730,7 +620,7 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) { } // Verify loadIntegration returns the saved models - saved, err := loadIntegration("opencode") + saved, err := LoadIntegration("opencode") if err != nil { t.Fatal(err) } @@ -742,153 +632,56 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) { } } -func TestResolveEditorLaunchModels_PicksWhenAllFiltered(t *testing.T) { - tmpDir := t.TempDir() - setTestHome(t, tmpDir) - +func TestLauncherClientFilterDisabledCloudModels_ChecksStatusOncePerInvocation(t *testing.T) { + var statusCalls int srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/api/status": + statusCalls++ fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`) default: w.WriteHeader(http.StatusNotFound) } })) defer srv.Close() - t.Setenv("OLLAMA_HOST", srv.URL) - pickerCalled := false - models, err := resolveEditorModels("opencode", []string{"glm-5:cloud"}, func() ([]string, error) { - pickerCalled = true - return []string{"llama3.2"}, nil - }) - if err != nil { - t.Fatalf("resolveEditorLaunchModels returned error: %v", err) - } - if !pickerCalled { - t.Fatal("expected model picker to be called when all models are filtered") - } - if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" { - t.Fatalf("resolved models mismatch (-want +got):\n%s", diff) + u, _ := url.Parse(srv.URL) + client := &launcherClient{ + apiClient: api.NewClient(u, srv.Client()), } - saved, err := loadIntegration("opencode") - if err != nil { - t.Fatalf("failed to reload integration config: %v", err) + filtered := client.filterDisabledCloudModels(context.Background(), []string{"llama3.2", "glm-5:cloud", "qwen3.5:cloud"}) + if diff := cmp.Diff([]string{"llama3.2"}, filtered); diff != "" { + t.Fatalf("filtered models mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" { - t.Fatalf("saved models mismatch (-want +got):\n%s", diff) + if statusCalls != 1 { + t.Fatalf("expected one cloud status lookup, got %d", statusCalls) } } -func TestResolveEditorLaunchModels_FiltersAndSkipsPickerWhenLocalRemains(t *testing.T) { +func TestPrepareEditorIntegration_SavesOnlyAfterSuccessfulEdit(t *testing.T) { tmpDir := t.TempDir() setTestHome(t, tmpDir) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/api/status": - fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer srv.Close() - t.Setenv("OLLAMA_HOST", srv.URL) - - pickerCalled := false - models, err := resolveEditorModels("droid", []string{"llama3.2", "glm-5:cloud"}, func() ([]string, error) { - pickerCalled = true - return []string{"qwen3.5"}, nil - }) - if err != nil { - t.Fatalf("resolveEditorLaunchModels returned error: %v", err) - } - if pickerCalled { - t.Fatal("picker should not be called when a local model remains") - } - if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" { - t.Fatalf("resolved models mismatch (-want +got):\n%s", diff) + if err := SaveIntegration("droid", []string{"existing-model"}); err != nil { + t.Fatalf("failed to seed config: %v", err) } - saved, err := loadIntegration("droid") - if err != nil { - t.Fatalf("failed to reload integration config: %v", err) + editor := &stubEditorRunner{editErr: errors.New("boom")} + err := prepareEditorIntegration("droid", editor, editor, []string{"new-model"}) + if err == nil || !strings.Contains(err.Error(), "setup failed") { + t.Fatalf("expected setup failure, got %v", err) } - if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" { + + saved, err := LoadIntegration("droid") + if err != nil { + t.Fatalf("failed to reload saved config: %v", err) + } + if diff := cmp.Diff([]string{"existing-model"}, saved.Models); diff != "" { t.Fatalf("saved models mismatch (-want +got):\n%s", diff) } } -func TestLaunchCmd_ModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) { - tmpDir := t.TempDir() - setTestHome(t, tmpDir) - - if err := SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil { - t.Fatalf("failed to seed saved config: %v", err) - } - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/api/status": - fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`) - case "/api/show": - fmt.Fprintf(w, `{"model":"llama3.2"}`) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer srv.Close() - t.Setenv("OLLAMA_HOST", srv.URL) - - stub := &stubEditorRunner{} - old, existed := integrations["stubeditor"] - integrations["stubeditor"] = stub - defer func() { - if existed { - integrations["stubeditor"] = old - } else { - delete(integrations, "stubeditor") - } - }() - - cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {}) - cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"}) - if err := cmd.Execute(); err != nil { - t.Fatalf("launch command failed: %v", err) - } - - saved, err := loadIntegration("stubeditor") - if err != nil { - t.Fatalf("failed to reload integration config: %v", err) - } - if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" { - t.Fatalf("saved models mismatch (-want +got):\n%s", diff) - } - if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" { - t.Fatalf("editor models mismatch (-want +got):\n%s", diff) - } - if stub.ranModel != "llama3.2" { - t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel) - } -} - -func TestAliasConfigurerInterface(t *testing.T) { - t.Run("claude implements AliasConfigurer", func(t *testing.T) { - claude := &Claude{} - if _, ok := interface{}(claude).(AliasConfigurer); !ok { - t.Error("Claude should implement AliasConfigurer") - } - }) - - t.Run("codex does not implement AliasConfigurer", func(t *testing.T) { - codex := &Codex{} - if _, ok := interface{}(codex).(AliasConfigurer); ok { - t.Error("Codex should not implement AliasConfigurer") - } - }) -} - func TestShowOrPull_ModelExists(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/api/show" { @@ -903,12 +696,237 @@ func TestShowOrPull_ModelExists(t *testing.T) { u, _ := url.Parse(srv.URL) client := api.NewClient(u, srv.Client()) - err := ShowOrPull(context.Background(), client, "test-model") + err := showOrPullWithPolicy(context.Background(), client, "test-model", missingModelPromptPull, false) if err != nil { t.Errorf("showOrPull should return nil when model exists, got: %v", err) } } +func TestShowOrPullWithPolicy_ModelExists(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"model":"test-model"}`) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := showOrPullWithPolicy(context.Background(), client, "test-model", missingModelFail, false) + if err != nil { + t.Errorf("showOrPullWithPolicy should return nil when model exists, got: %v", err) + } +} + +func TestShowOrPullWithPolicy_ModelNotFound_FailDoesNotPromptOrPull(t *testing.T) { + oldHook := DefaultConfirmPrompt + DefaultConfirmPrompt = func(prompt string) (bool, error) { + t.Fatal("confirm prompt should not be called with fail policy") + return false, nil + } + defer func() { DefaultConfirmPrompt = oldHook }() + + var pullCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"model not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"status":"success"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelFail, false) + if err == nil { + t.Fatal("expected fail policy to return an error for missing model") + } + if !strings.Contains(err.Error(), "ollama pull missing-model") { + t.Fatalf("expected actionable pull guidance, got: %v", err) + } + if pullCalled { + t.Fatal("expected pull not to be called with fail policy") + } +} + +func TestShowOrPullWithPolicy_ModelNotFound_PromptPolicyPulls(t *testing.T) { + oldHook := DefaultConfirmPrompt + DefaultConfirmPrompt = func(prompt string) (bool, error) { + if !strings.Contains(prompt, "missing-model") { + t.Fatalf("expected prompt to mention missing model, got %q", prompt) + } + return true, nil + } + defer func() { DefaultConfirmPrompt = oldHook }() + + var pullCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"model not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"status":"success"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelPromptPull, false) + if err != nil { + t.Fatalf("expected prompt policy to pull and succeed, got %v", err) + } + if !pullCalled { + t.Fatal("expected pull to be called with prompt policy") + } +} + +func TestShowOrPullWithPolicy_ModelNotFound_AutoPullPolicyPullsWithoutPrompt(t *testing.T) { + oldHook := DefaultConfirmPrompt + DefaultConfirmPrompt = func(prompt string) (bool, error) { + t.Fatalf("confirm prompt should not be called with auto-pull policy: %q", prompt) + return false, nil + } + defer func() { DefaultConfirmPrompt = oldHook }() + + var pullCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"model not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"status":"success"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelAutoPull, false) + if err != nil { + t.Fatalf("expected auto-pull policy to pull and succeed, got %v", err) + } + if !pullCalled { + t.Fatal("expected pull to be called with auto-pull policy") + } +} + +func TestShowOrPullWithPolicy_CloudModelNotFound_FailsEarlyForAllPolicies(t *testing.T) { + oldHook := DefaultConfirmPrompt + DefaultConfirmPrompt = func(prompt string) (bool, error) { + t.Fatal("confirm prompt should not be called for explicit cloud models") + return false, nil + } + defer func() { DefaultConfirmPrompt = oldHook }() + + for _, policy := range []missingModelPolicy{missingModelPromptPull, missingModelAutoPull, missingModelFail} { + t.Run(fmt.Sprintf("policy=%d", policy), func(t *testing.T) { + var pullCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"model not found"}`) + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"status":"success"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := showOrPullWithPolicy(context.Background(), client, "glm-5:cloud", policy, true) + if err == nil { + t.Fatalf("expected cloud model not-found error for policy %d", policy) + } + if !strings.Contains(err.Error(), `model "glm-5:cloud" not found`) { + t.Fatalf("expected not-found error for policy %d, got %v", policy, err) + } + if pullCalled { + t.Fatalf("expected pull not to be called for cloud model with policy %d", policy) + } + }) + } +} + +func TestShowOrPullWithPolicy_CloudModelDisabled_FailsWithCloudDisabledError(t *testing.T) { + oldHook := DefaultConfirmPrompt + DefaultConfirmPrompt = func(prompt string) (bool, error) { + t.Fatal("confirm prompt should not be called for explicit cloud models") + return false, nil + } + defer func() { DefaultConfirmPrompt = oldHook }() + + for _, policy := range []missingModelPolicy{missingModelPromptPull, missingModelAutoPull, missingModelFail} { + t.Run(fmt.Sprintf("policy=%d", policy), func(t *testing.T) { + var pullCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"model not found"}`) + case "/api/status": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"status":"success"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := showOrPullWithPolicy(context.Background(), client, "glm-5:cloud", policy, true) + if err == nil { + t.Fatalf("expected cloud disabled error for policy %d", policy) + } + if !strings.Contains(err.Error(), "remote inference is unavailable") { + t.Fatalf("expected cloud disabled error for policy %d, got %v", policy, err) + } + if pullCalled { + t.Fatalf("expected pull not to be called for cloud model with policy %d", policy) + } + }) + } +} + func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) @@ -920,7 +938,7 @@ func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) { client := api.NewClient(u, srv.Client()) // confirmPrompt will fail in test (no terminal), so showOrPull should return an error - err := ShowOrPull(context.Background(), client, "missing-model") + err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelPromptPull, false) if err == nil { t.Error("showOrPull should return error when model not found and no terminal available") } @@ -945,7 +963,7 @@ func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) { u, _ := url.Parse(srv.URL) client := api.NewClient(u, srv.Client()) - _ = ShowOrPull(context.Background(), client, "qwen3.5") + _ = showOrPullWithPolicy(context.Background(), client, "qwen3.5", missingModelPromptPull, false) if receivedModel != "qwen3.5" { t.Errorf("expected Show to be called with %q, got %q", "qwen3.5", receivedModel) } @@ -981,7 +999,7 @@ func TestShowOrPull_ModelNotFound_ConfirmYes_Pulls(t *testing.T) { u, _ := url.Parse(srv.URL) client := api.NewClient(u, srv.Client()) - err := ShowOrPull(context.Background(), client, "missing-model") + err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelPromptPull, false) if err != nil { t.Errorf("ShowOrPull should succeed after pull, got: %v", err) } @@ -1013,13 +1031,13 @@ func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) { u, _ := url.Parse(srv.URL) client := api.NewClient(u, srv.Client()) - err := ShowOrPull(context.Background(), client, "missing-model") + err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelPromptPull, false) if err == nil { t.Error("ShowOrPull should return error when user declines") } } -func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) { +func TestShowOrPull_CloudModel_NotFoundDoesNotPull(t *testing.T) { // Confirm prompt should NOT be called for explicit cloud models oldHook := DefaultConfirmPrompt DefaultConfirmPrompt = func(prompt string) (bool, error) { @@ -1047,16 +1065,19 @@ func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) { u, _ := url.Parse(srv.URL) client := api.NewClient(u, srv.Client()) - err := ShowOrPull(context.Background(), client, "glm-5:cloud") - if err != nil { - t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err) + err := showOrPullWithPolicy(context.Background(), client, "glm-5:cloud", missingModelPromptPull, true) + if err == nil { + t.Error("ShowOrPull should return not-found error for cloud model") + } + if !strings.Contains(err.Error(), `model "glm-5:cloud" not found`) { + t.Errorf("expected cloud model not-found error, got: %v", err) } if pullCalled { t.Error("expected pull not to be called for cloud model") } } -func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) { +func TestShowOrPull_CloudLegacySuffix_NotFoundDoesNotPull(t *testing.T) { // Confirm prompt should NOT be called for explicit cloud models oldHook := DefaultConfirmPrompt DefaultConfirmPrompt = func(prompt string) (bool, error) { @@ -1084,85 +1105,18 @@ func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) { u, _ := url.Parse(srv.URL) client := api.NewClient(u, srv.Client()) - err := ShowOrPull(context.Background(), client, "gpt-oss:20b-cloud") - if err != nil { - t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err) + err := showOrPullWithPolicy(context.Background(), client, "gpt-oss:20b-cloud", missingModelPromptPull, true) + if err == nil { + t.Error("ShowOrPull should return not-found error for cloud model") + } + if !strings.Contains(err.Error(), `model "gpt-oss:20b-cloud" not found`) { + t.Errorf("expected cloud model not-found error, got: %v", err) } if pullCalled { t.Error("expected pull not to be called for cloud model") } } -func TestPullIfNeeded_CloudModel_DoesNotPull(t *testing.T) { - oldHook := DefaultConfirmPrompt - DefaultConfirmPrompt = func(prompt string) (bool, error) { - t.Error("confirm prompt should not be called for cloud models") - return false, nil - } - defer func() { DefaultConfirmPrompt = oldHook }() - - err := pullIfNeeded(context.Background(), nil, map[string]bool{}, "glm-5:cloud") - if err != nil { - t.Fatalf("expected no error for cloud model, got %v", err) - } - - err = pullIfNeeded(context.Background(), nil, map[string]bool{}, "gpt-oss:20b-cloud") - if err != nil { - t.Fatalf("expected no error for cloud model with legacy suffix, got %v", err) - } -} - -func TestSelectModelsWithSelectors_CloudSelection_DoesNotPull(t *testing.T) { - var pullCalled bool - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/api/status": - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, `{"error":"not found"}`) - case "/api/tags": - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"models":[]}`) - case "/api/me": - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"name":"test-user"}`) - case "/api/pull": - pullCalled = true - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"status":"success"}`) - default: - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, `{"error":"not found"}`) - } - })) - defer srv.Close() - t.Setenv("OLLAMA_HOST", srv.URL) - - single := func(title string, items []ModelItem, current string) (string, error) { - for _, item := range items { - if item.Name == "glm-5:cloud" { - return item.Name, nil - } - } - t.Fatalf("expected glm-5:cloud in selector items, got %v", items) - return "", nil - } - - multi := func(title string, items []ModelItem, preChecked []string) ([]string, error) { - return nil, fmt.Errorf("multi selector should not be called") - } - - selected, err := selectModelsWithSelectors(context.Background(), "codex", "", single, multi) - if err != nil { - t.Fatalf("selectModelsWithSelectors returned error: %v", err) - } - if !slices.Equal(selected, []string{"glm-5:cloud"}) { - t.Fatalf("unexpected selected models: %v", selected) - } - if pullCalled { - t.Fatal("expected cloud selection to skip pull") - } -} - func TestConfirmPrompt_DelegatesToHook(t *testing.T) { oldHook := DefaultConfirmPrompt var hookCalled bool @@ -1175,7 +1129,7 @@ func TestConfirmPrompt_DelegatesToHook(t *testing.T) { } defer func() { DefaultConfirmPrompt = oldHook }() - ok, err := confirmPrompt("test prompt?") + ok, err := ConfirmPrompt("test prompt?") if err != nil { t.Errorf("unexpected error: %v", err) } @@ -1258,6 +1212,66 @@ func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) { } } +func TestEnsureAuth_PreservesCancelledSignInHook(t *testing.T) { + oldSignIn := DefaultSignIn + DefaultSignIn = func(modelName, signInURL string) (string, error) { + return "", ErrCancelled + } + defer func() { DefaultSignIn = oldSignIn }() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + case "/api/me": + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprintf(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := ensureAuth(context.Background(), client, map[string]bool{"cloud-model:cloud": true}, []string{"cloud-model:cloud"}) + if !errors.Is(err, ErrCancelled) { + t.Fatalf("expected ErrCancelled, got %v", err) + } +} + +func TestEnsureAuth_DeclinedFallbackReturnsCancelled(t *testing.T) { + oldConfirm := DefaultConfirmPrompt + DefaultConfirmPrompt = func(prompt string) (bool, error) { + return false, nil + } + defer func() { DefaultConfirmPrompt = oldConfirm }() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + case "/api/me": + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprintf(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + u, _ := url.Parse(srv.URL) + client := api.NewClient(u, srv.Client()) + + err := ensureAuth(context.Background(), client, map[string]bool{"cloud-model:cloud": true}, []string{"cloud-model:cloud"}) + if !errors.Is(err, ErrCancelled) { + t.Fatalf("expected ErrCancelled, got %v", err) + } +} + func TestHyperlink(t *testing.T) { tests := []struct { name string @@ -1306,7 +1320,7 @@ func TestHyperlink(t *testing.T) { } } -func TestIntegrationInstallHint(t *testing.T) { +func TestIntegration_InstallHint(t *testing.T) { tests := []struct { name string input string @@ -1342,7 +1356,11 @@ func TestIntegrationInstallHint(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := IntegrationInstallHint(tt.input) + got := "" + integration, err := integrationFor(tt.input) + if err == nil { + got = integration.installHint + } if tt.wantEmpty { if got != "" { t.Errorf("expected empty hint, got %q", got) @@ -1477,31 +1495,7 @@ func TestBuildModelList_Descriptions(t *testing.T) { }) } -func TestLaunchIntegration_UnknownIntegration(t *testing.T) { - err := LaunchIntegration("nonexistent-integration") - if err == nil { - t.Fatal("expected error for unknown integration") - } - if !strings.Contains(err.Error(), "unknown integration") { - t.Errorf("error should mention 'unknown integration', got: %v", err) - } -} - -func TestLaunchIntegration_NotConfigured(t *testing.T) { - tmpDir := t.TempDir() - setTestHome(t, tmpDir) - - // Claude is a known integration but not configured in temp dir - err := LaunchIntegration("claude") - if err == nil { - t.Fatal("expected error when integration is not configured") - } - if !strings.Contains(err.Error(), "not configured") { - t.Errorf("error should mention 'not configured', got: %v", err) - } -} - -func TestIsEditorIntegration(t *testing.T) { +func TestIntegration_Editor(t *testing.T) { tests := []struct { name string want bool @@ -1515,8 +1509,13 @@ func TestIsEditorIntegration(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := IsEditorIntegration(tt.name); got != tt.want { - t.Errorf("IsEditorIntegration(%q) = %v, want %v", tt.name, got, tt.want) + got := false + integration, err := integrationFor(tt.name) + if err == nil { + got = integration.editor + } + if got != tt.want { + t.Errorf("integrationFor(%q).editor = %v, want %v", tt.name, got, tt.want) } }) } @@ -1543,13 +1542,3 @@ func TestIntegrationModels(t *testing.T) { } }) } - -func TestSaveAndEditIntegration_UnknownIntegration(t *testing.T) { - err := SaveAndEditIntegration("nonexistent", []string{"model"}) - if err == nil { - t.Fatal("expected error for unknown integration") - } - if !strings.Contains(err.Error(), "unknown integration") { - t.Errorf("error should mention 'unknown integration', got: %v", err) - } -} diff --git a/cmd/launch/launch.go b/cmd/launch/launch.go new file mode 100644 index 000000000..5704c4aa0 --- /dev/null +++ b/cmd/launch/launch.go @@ -0,0 +1,833 @@ +package launch + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "strings" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/cmd/config" + "github.com/spf13/cobra" + "golang.org/x/term" +) + +// LauncherState is the launch-owned snapshot used to render the root launcher menu. +type LauncherState struct { + LastSelection string + RunModel string + RunModelUsable bool + Integrations map[string]LauncherIntegrationState +} + +// LauncherIntegrationState is the launch-owned status for one launcher integration. +type LauncherIntegrationState struct { + Name string + DisplayName string + Description string + Installed bool + AutoInstallable bool + Selectable bool + Changeable bool + CurrentModel string + ModelUsable bool + InstallHint string + Editor bool +} + +// RunModelRequest controls how the root launcher resolves the chat model. +type RunModelRequest struct { + ForcePicker bool +} + +// LaunchConfirmMode controls confirmation behavior across launch flows. +type LaunchConfirmMode int + +const ( + // LaunchConfirmPrompt prompts the user for confirmation. + LaunchConfirmPrompt LaunchConfirmMode = iota + // LaunchConfirmAutoApprove skips prompts and treats confirmation as accepted. + LaunchConfirmAutoApprove + // LaunchConfirmRequireYes rejects confirmation requests with a --yes hint. + LaunchConfirmRequireYes +) + +// LaunchMissingModelMode controls local missing-model handling in launch flows. +type LaunchMissingModelMode int + +const ( + // LaunchMissingModelPromptToPull prompts to pull a missing local model. + LaunchMissingModelPromptToPull LaunchMissingModelMode = iota + // LaunchMissingModelAutoPull pulls a missing local model without prompting. + LaunchMissingModelAutoPull + // LaunchMissingModelFail fails immediately when a local model is missing. + LaunchMissingModelFail +) + +// LaunchPolicy controls launch behavior that may vary by caller context. +type LaunchPolicy struct { + Confirm LaunchConfirmMode + MissingModel LaunchMissingModelMode +} + +func defaultLaunchPolicy(interactive bool) LaunchPolicy { + policy := LaunchPolicy{ + Confirm: LaunchConfirmPrompt, + MissingModel: LaunchMissingModelPromptToPull, + } + if !interactive { + policy.Confirm = LaunchConfirmRequireYes + policy.MissingModel = LaunchMissingModelFail + } + return policy +} + +func (p LaunchPolicy) confirmPolicy() launchConfirmPolicy { + switch p.Confirm { + case LaunchConfirmAutoApprove: + return launchConfirmPolicy{yes: true} + case LaunchConfirmRequireYes: + return launchConfirmPolicy{requireYesMessage: true} + default: + return launchConfirmPolicy{} + } +} + +func (p LaunchPolicy) missingModelPolicy() missingModelPolicy { + switch p.MissingModel { + case LaunchMissingModelAutoPull: + return missingModelAutoPull + case LaunchMissingModelFail: + return missingModelFail + default: + return missingModelPromptPull + } +} + +// IntegrationLaunchRequest controls the canonical integration launcher flow. +type IntegrationLaunchRequest struct { + Name string + ModelOverride string + ForceConfigure bool + ConfigureOnly bool + ExtraArgs []string + Policy *LaunchPolicy +} + +var isInteractiveSession = func() bool { + return term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd())) +} + +// Runner executes a model with an integration. +type Runner interface { + Run(model string, args []string) error + String() string +} + +// Editor can edit config files for integrations that support model configuration. +type Editor interface { + Paths() []string + Edit(models []string) error + Models() []string +} + +type modelInfo struct { + Name string + Remote bool + ToolCapable bool +} + +// ModelInfo re-exports launcher model inventory details for callers. +type ModelInfo = modelInfo + +// ModelItem represents a model for selection UIs. +type ModelItem struct { + Name string + Description string + Recommended bool +} + +// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors. +func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error { + oldSingle := DefaultSingleSelector + oldMulti := DefaultMultiSelector + if single != nil { + DefaultSingleSelector = single + } + if multi != nil { + DefaultMultiSelector = multi + } + defer func() { + DefaultSingleSelector = oldSingle + DefaultMultiSelector = oldMulti + }() + + return LaunchIntegration(ctx, IntegrationLaunchRequest{ + Name: name, + ForceConfigure: true, + ConfigureOnly: true, + }) +} + +// ConfigureIntegration allows the user to select/change the model for an integration. +func ConfigureIntegration(ctx context.Context, name string) error { + return LaunchIntegration(ctx, IntegrationLaunchRequest{ + Name: name, + ForceConfigure: true, + ConfigureOnly: true, + }) +} + +// LaunchCmd returns the cobra command for launching integrations. +// The runTUI callback is called when the root launcher UI should be shown. +func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command { + var modelFlag string + var configFlag bool + var yesFlag bool + + cmd := &cobra.Command{ + Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]", + Short: "Launch the Ollama menu or an integration", + Long: `Launch the Ollama interactive menu, or directly launch a specific integration. + +Without arguments, this is equivalent to running 'ollama' directly. +Flags and extra arguments require an integration name. + +Supported integrations: + claude Claude Code + cline Cline + codex Codex + droid Droid + opencode OpenCode + openclaw OpenClaw (aliases: clawdbot, moltbot) + pi Pi + +Examples: + ollama launch + ollama launch claude + ollama launch claude --model + ollama launch droid --config (does not auto-launch) + ollama launch codex -- -p myprofile (pass extra args to integration) + ollama launch codex -- --sandbox workspace-write`, + Args: cobra.ArbitraryArgs, + PreRunE: checkServerHeartbeat, + RunE: func(cmd *cobra.Command, args []string) error { + policy := defaultLaunchPolicy(isInteractiveSession()) + if yesFlag { + policy.Confirm = LaunchConfirmAutoApprove + if policy.MissingModel == LaunchMissingModelFail { + policy.MissingModel = LaunchMissingModelAutoPull + } + } + restoreConfirmPolicy := withLaunchConfirmPolicy(policy.confirmPolicy()) + defer restoreConfirmPolicy() + + var name string + var passArgs []string + dashIdx := cmd.ArgsLenAtDash() + + if dashIdx == -1 { + if len(args) > 1 { + return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:]) + } + if len(args) == 1 { + name = args[0] + } + } else { + if dashIdx > 1 { + return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx) + } + if dashIdx == 1 { + name = args[0] + } + passArgs = args[dashIdx:] + } + + if name == "" { + if cmd.Flags().Changed("model") || cmd.Flags().Changed("config") || len(passArgs) > 0 { + return fmt.Errorf("flags and extra args require an integration name, for example: 'ollama launch claude --model qwen3.5'") + } + runTUI(cmd) + return nil + } + + if modelFlag != "" && isCloudModelName(modelFlag) { + if client, err := api.ClientFromEnvironment(); err == nil { + if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled { + fmt.Fprintf(os.Stderr, "Warning: ignoring --model %s because cloud is disabled\n", modelFlag) + modelFlag = "" + } + } + } + + err := LaunchIntegration(cmd.Context(), IntegrationLaunchRequest{ + Name: name, + ModelOverride: modelFlag, + ForceConfigure: configFlag || modelFlag == "", + ConfigureOnly: configFlag, + ExtraArgs: passArgs, + Policy: &policy, + }) + if errors.Is(err, ErrCancelled) { + return nil + } + return err + }, + } + + cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use") + cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching") + cmd.Flags().BoolVarP(&yesFlag, "yes", "y", false, "Automatically answer yes to confirmation prompts") + return cmd +} + +type launcherClient struct { + apiClient *api.Client + modelInventory []ModelInfo + inventoryLoaded bool + policy LaunchPolicy +} + +func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) { + apiClient, err := api.ClientFromEnvironment() + if err != nil { + return nil, err + } + + return &launcherClient{ + apiClient: apiClient, + policy: policy, + }, nil +} + +// BuildLauncherState returns the launch-owned root launcher menu snapshot. +func BuildLauncherState(ctx context.Context) (*LauncherState, error) { + launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession())) + if err != nil { + return nil, err + } + return launchClient.buildLauncherState(ctx) +} + +// ResolveRunModel returns the model that should be used for interactive chat. +func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) { + launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession())) + if err != nil { + return "", err + } + return launchClient.resolveRunModel(ctx, req) +} + +// LaunchIntegration runs the canonical launcher flow for one integration. +func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error { + name, runner, err := LookupIntegration(req.Name) + if err != nil { + return err + } + if !req.ConfigureOnly { + if err := EnsureIntegrationInstalled(name, runner); err != nil { + return err + } + } + + policy := defaultLaunchPolicy(isInteractiveSession()) + if req.Policy != nil { + policy = *req.Policy + } + + launchClient, err := newLauncherClient(policy) + if err != nil { + return err + } + saved, _ := loadStoredIntegrationConfig(name) + + if editor, ok := runner.(Editor); ok { + return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req) + } + return launchClient.launchSingleIntegration(ctx, name, runner, saved, req) +} + +func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) { + _ = c.loadModelInventoryOnce(ctx) + + state := &LauncherState{ + LastSelection: config.LastSelection(), + RunModel: config.LastModel(), + Integrations: make(map[string]LauncherIntegrationState), + } + runModelUsable, err := c.savedModelUsable(ctx, state.RunModel) + if err != nil { + runModelUsable = false + } + state.RunModelUsable = runModelUsable + + for _, info := range ListIntegrationInfos() { + integrationState, err := c.buildLauncherIntegrationState(ctx, info) + if err != nil { + return nil, err + } + state.Integrations[info.Name] = integrationState + } + + return state, nil +} + +func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info IntegrationInfo) (LauncherIntegrationState, error) { + integration, err := integrationFor(info.Name) + if err != nil { + return LauncherIntegrationState{}, err + } + currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor) + if err != nil { + return LauncherIntegrationState{}, err + } + + return LauncherIntegrationState{ + Name: info.Name, + DisplayName: info.DisplayName, + Description: info.Description, + Installed: integration.installed, + AutoInstallable: integration.autoInstallable, + Selectable: integration.installed || integration.autoInstallable, + Changeable: integration.installed || integration.autoInstallable, + CurrentModel: currentModel, + ModelUsable: usable, + InstallHint: integration.installHint, + Editor: integration.editor, + }, nil +} + +func (c *launcherClient) launcherModelState(ctx context.Context, name string, isEditor bool) (string, bool, error) { + cfg, loadErr := loadStoredIntegrationConfig(name) + hasModels := loadErr == nil && len(cfg.Models) > 0 + if !hasModels { + return "", false, nil + } + + if isEditor { + filtered := c.filterDisabledCloudModels(ctx, cfg.Models) + if len(filtered) > 0 { + return filtered[0], true, nil + } + return cfg.Models[0], false, nil + } + + model := cfg.Models[0] + usable, usableErr := c.savedModelUsable(ctx, model) + return model, usableErr == nil && usable, nil +} + +func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) { + current := config.LastModel() + if !req.ForcePicker { + usable, err := c.savedModelUsable(ctx, current) + if err != nil { + return "", err + } + if usable { + if err := c.ensureModelsReady(ctx, []string{current}); err != nil { + return "", err + } + if err := config.SetLastModel(current); err != nil { + return "", err + } + return current, nil + } + } + + model, err := c.selectSingleModelWithSelector(ctx, "Select model to run:", current, DefaultSingleSelector) + if err != nil { + return "", err + } + if err := config.SetLastModel(model); err != nil { + return "", err + } + return model, nil +} + +func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error { + current := primaryModelFromConfig(saved) + target := req.ModelOverride + needsConfigure := req.ForceConfigure + + if target == "" { + target = current + usable, err := c.savedModelUsable(ctx, target) + if err != nil { + return err + } + if !usable { + needsConfigure = true + } + } + + if needsConfigure { + selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector) + if err != nil { + return err + } + target = selected + } else if err := c.ensureModelsReady(ctx, []string{target}); err != nil { + return err + } + + if target == "" { + return nil + } + + if err := config.SaveIntegration(name, []string{target}); err != nil { + return fmt.Errorf("failed to save: %w", err) + } + + return launchAfterConfiguration(name, runner, target, req) +} + +func (c *launcherClient) launchEditorIntegration(ctx context.Context, name string, runner Runner, editor Editor, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error { + models, needsConfigure := c.resolveEditorLaunchModels(ctx, saved, req) + + if needsConfigure { + selected, err := c.selectMultiModelsForIntegration(ctx, runner, models) + if err != nil { + return err + } + models = selected + } else if err := c.ensureModelsReady(ctx, models); err != nil { + return err + } + + if len(models) == 0 { + return nil + } + + if needsConfigure || req.ModelOverride != "" { + if err := prepareEditorIntegration(name, runner, editor, models); err != nil { + return err + } + } + + return launchAfterConfiguration(name, runner, models[0], req) +} + +func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) { + if selector == nil { + return "", fmt.Errorf("no selector configured") + } + + items, _, err := c.loadSelectableModels(ctx, singleModelPrechecked(current), current, "no models available, run 'ollama pull ' first") + if err != nil { + return "", err + } + + selected, err := selector(title, items, current) + if err != nil { + return "", err + } + if err := c.ensureModelsReady(ctx, []string{selected}); err != nil { + return "", err + } + return selected, nil +} + +func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) { + if DefaultMultiSelector == nil { + return nil, fmt.Errorf("no selector configured") + } + + items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, firstModel(preChecked), "no models available") + if err != nil { + return nil, err + } + + selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked) + if err != nil { + return nil, err + } + if err := c.ensureModelsReady(ctx, selected); err != nil { + return nil, err + } + return selected, nil +} + +func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) { + if err := c.loadModelInventoryOnce(ctx); err != nil { + return nil, nil, err + } + + cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient) + items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current) + if cloudDisabled { + items = filterCloudItems(items) + orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked) + } + if len(items) == 0 { + return nil, nil, errors.New(emptyMessage) + } + return items, orderedChecked, nil +} + +func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error { + var deduped []string + seen := make(map[string]bool, len(models)) + for _, model := range models { + if model == "" || seen[model] { + continue + } + seen[model] = true + deduped = append(deduped, model) + } + models = deduped + if len(models) == 0 { + return nil + } + + cloudModels := make(map[string]bool, len(models)) + for _, model := range models { + isCloudModel := isCloudModelName(model) + if isCloudModel { + cloudModels[model] = true + } + if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil { + return err + } + } + return ensureAuth(ctx, c.apiClient, cloudModels, models) +} + +func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) { + if req.ForceConfigure { + return editorPreCheckedModels(saved, req.ModelOverride), true + } + + if req.ModelOverride != "" { + models := append([]string{req.ModelOverride}, additionalSavedModels(saved, req.ModelOverride)...) + models = c.filterDisabledCloudModels(ctx, models) + return models, len(models) == 0 + } + + if saved == nil || len(saved.Models) == 0 { + return nil, true + } + + models := c.filterDisabledCloudModels(ctx, saved.Models) + return models, len(models) == 0 +} + +func (c *launcherClient) filterDisabledCloudModels(ctx context.Context, models []string) []string { + // if connection cannot be established or there is a 404, cloud models will continue to be displayed + cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient) + if !cloudDisabled { + return append([]string(nil), models...) + } + + filtered := make([]string, 0, len(models)) + for _, model := range models { + if !isCloudModelName(model) { + filtered = append(filtered, model) + } + } + return filtered +} + +func (c *launcherClient) savedModelUsable(ctx context.Context, name string) (bool, error) { + if err := c.loadModelInventoryOnce(ctx); err != nil { + return c.showBasedModelUsable(ctx, name) + } + return c.singleModelUsable(ctx, name), nil +} + +func (c *launcherClient) showBasedModelUsable(ctx context.Context, name string) (bool, error) { + if name == "" { + return false, nil + } + + info, err := c.apiClient.Show(ctx, &api.ShowRequest{Model: name}) + if err != nil { + var statusErr api.StatusError + if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound { + return false, nil + } + return false, err + } + + if isCloudModelName(name) || info.RemoteModel != "" { + cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient) + + return !cloudDisabled, nil + } + + return true, nil +} + +func (c *launcherClient) singleModelUsable(ctx context.Context, name string) bool { + if name == "" { + return false + } + if isCloudModelName(name) { + cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient) + return !cloudDisabled + } + return c.hasLocalModel(name) +} + +func (c *launcherClient) hasLocalModel(name string) bool { + for _, model := range c.modelInventory { + if model.Remote { + continue + } + if model.Name == name || strings.HasPrefix(model.Name, name+":") { + return true + } + } + return false +} + +func (c *launcherClient) loadModelInventoryOnce(ctx context.Context) error { + if c.inventoryLoaded { + return nil + } + + resp, err := c.apiClient.List(ctx) + if err != nil { + return err + } + + c.modelInventory = c.modelInventory[:0] + for _, model := range resp.Models { + c.modelInventory = append(c.modelInventory, ModelInfo{ + Name: model.Name, + Remote: model.RemoteModel != "", + }) + } + + cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient) + if cloudDisabled { + c.modelInventory = filterCloudModels(c.modelInventory) + } + c.inventoryLoaded = true + return nil +} + +func runIntegration(runner Runner, modelName string, args []string) error { + fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", runner, modelName) + return runner.Run(modelName, args) +} + +func launchAfterConfiguration(name string, runner Runner, model string, req IntegrationLaunchRequest) error { + if req.ConfigureOnly { + launch, err := ConfirmPrompt(fmt.Sprintf("Launch %s now?", runner)) + if err != nil { + return err + } + if !launch { + return nil + } + } + if err := EnsureIntegrationInstalled(name, runner); err != nil { + return err + } + return runIntegration(runner, model, req.ExtraArgs) +} + +func loadStoredIntegrationConfig(name string) (*config.IntegrationConfig, error) { + cfg, err := config.LoadIntegration(name) + if err == nil { + return cfg, nil + } + if !errors.Is(err, os.ErrNotExist) { + return nil, err + } + + spec, specErr := LookupIntegrationSpec(name) + if specErr != nil { + return nil, err + } + + for _, alias := range spec.Aliases { + legacy, legacyErr := config.LoadIntegration(alias) + if legacyErr == nil { + migrateLegacyIntegrationConfig(spec.Name, legacy) + if migrated, migratedErr := config.LoadIntegration(spec.Name); migratedErr == nil { + return migrated, nil + } + return legacy, nil + } + if legacyErr != nil && !errors.Is(legacyErr, os.ErrNotExist) { + return nil, legacyErr + } + } + + return nil, err +} + +func migrateLegacyIntegrationConfig(canonical string, legacy *config.IntegrationConfig) { + if legacy == nil { + return + } + + _ = config.SaveIntegration(canonical, append([]string(nil), legacy.Models...)) + if len(legacy.Aliases) > 0 { + _ = config.SaveAliases(canonical, cloneAliases(legacy.Aliases)) + } + if legacy.Onboarded { + _ = config.MarkIntegrationOnboarded(canonical) + } +} + +func primaryModelFromConfig(cfg *config.IntegrationConfig) string { + if cfg == nil || len(cfg.Models) == 0 { + return "" + } + return cfg.Models[0] +} + +func cloneAliases(aliases map[string]string) map[string]string { + if len(aliases) == 0 { + return make(map[string]string) + } + + cloned := make(map[string]string, len(aliases)) + for key, value := range aliases { + cloned[key] = value + } + return cloned +} + +func singleModelPrechecked(current string) []string { + if current == "" { + return nil + } + return []string{current} +} + +func firstModel(models []string) string { + if len(models) == 0 { + return "" + } + return models[0] +} + +func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string { + if override == "" { + if saved == nil { + return nil + } + return append([]string(nil), saved.Models...) + } + return append([]string{override}, additionalSavedModels(saved, override)...) +} + +func additionalSavedModels(saved *config.IntegrationConfig, exclude string) []string { + if saved == nil { + return nil + } + + var models []string + for _, model := range saved.Models { + if model != exclude { + models = append(models, model) + } + } + return models +} diff --git a/cmd/launch/launch_test.go b/cmd/launch/launch_test.go new file mode 100644 index 000000000..ca29afcfb --- /dev/null +++ b/cmd/launch/launch_test.go @@ -0,0 +1,1210 @@ +package launch + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "slices" + "strings" + "testing" + + "github.com/ollama/ollama/cmd/config" +) + +type launcherEditorRunner struct { + paths []string + edited [][]string + ranModel string +} + +func (r *launcherEditorRunner) Run(model string, args []string) error { + r.ranModel = model + return nil +} + +func (r *launcherEditorRunner) String() string { return "LauncherEditor" } + +func (r *launcherEditorRunner) Paths() []string { return r.paths } + +func (r *launcherEditorRunner) Edit(models []string) error { + r.edited = append(r.edited, append([]string(nil), models...)) + return nil +} + +func (r *launcherEditorRunner) Models() []string { return nil } + +type launcherSingleRunner struct { + ranModel string +} + +func (r *launcherSingleRunner) Run(model string, args []string) error { + r.ranModel = model + return nil +} + +func (r *launcherSingleRunner) String() string { return "StubSingle" } + +func setLaunchTestHome(t *testing.T, dir string) { + t.Helper() + t.Setenv("HOME", dir) + t.Setenv("TMPDIR", dir) + t.Setenv("USERPROFILE", dir) +} + +func writeFakeBinary(t *testing.T, dir, name string) { + t.Helper() + path := filepath.Join(dir, name) + data := []byte("#!/bin/sh\nexit 0\n") + if runtime.GOOS == "windows" { + path += ".cmd" + data = []byte("@echo off\r\nexit /b 0\r\n") + } + if err := os.WriteFile(path, data, 0o755); err != nil { + t.Fatalf("failed to write fake binary: %v", err) + } +} + +func withIntegrationOverride(t *testing.T, name string, runner Runner) { + t.Helper() + restore := OverrideIntegration(name, runner) + t.Cleanup(restore) +} + +func withInteractiveSession(t *testing.T, interactive bool) { + t.Helper() + old := isInteractiveSession + isInteractiveSession = func() bool { return interactive } + t.Cleanup(func() { + isInteractiveSession = old + }) +} + +func withLauncherHooks(t *testing.T) { + t.Helper() + oldSingle := DefaultSingleSelector + oldMulti := DefaultMultiSelector + oldConfirm := DefaultConfirmPrompt + oldSignIn := DefaultSignIn + t.Cleanup(func() { + DefaultSingleSelector = oldSingle + DefaultMultiSelector = oldMulti + DefaultConfirmPrompt = oldConfirm + DefaultSignIn = oldSignIn + }) +} + +func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "opencode") + t.Setenv("PATH", binDir) + + if err := config.SetLastModel("glm-5:cloud"); err != nil { + t.Fatalf("failed to save last model: %v", err) + } + if err := config.SaveIntegration("claude", []string{"glm-5:cloud"}); err != nil { + t.Fatalf("failed to save claude config: %v", err) + } + if err := config.SaveIntegration("opencode", []string{"glm-5:cloud", "llama3.2"}); err != nil { + t.Fatalf("failed to save opencode config: %v", err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`) + case "/api/status": + fmt.Fprint(w, `{"cloud":{"disabled":true,"source":"config"}}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + state, err := BuildLauncherState(context.Background()) + if err != nil { + t.Fatalf("BuildLauncherState returned error: %v", err) + } + + if !state.Integrations["opencode"].Installed { + t.Fatal("expected opencode to be marked installed") + } + if state.Integrations["claude"].Installed { + t.Fatal("expected claude to be marked not installed") + } + if state.RunModelUsable { + t.Fatal("expected saved cloud run model to be unusable when cloud is disabled") + } + if state.Integrations["claude"].ModelUsable { + t.Fatal("expected claude cloud config to be unusable when cloud is disabled") + } + if !state.Integrations["opencode"].ModelUsable { + t.Fatal("expected editor config with a remaining local model to stay usable") + } + if state.Integrations["opencode"].CurrentModel != "llama3.2" { + t.Fatalf("expected editor current model to fall back to remaining local model, got %q", state.Integrations["opencode"].CurrentModel) + } +} + +func TestBuildLauncherState_MigratesLegacyOpenclawAliasConfig(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + + if err := config.SaveIntegration("clawdbot", []string{"llama3.2"}); err != nil { + t.Fatalf("failed to seed legacy alias config: %v", err) + } + if err := config.SaveAliases("clawdbot", map[string]string{"primary": "llama3.2"}); err != nil { + t.Fatalf("failed to seed legacy alias map: %v", err) + } + if err := config.MarkIntegrationOnboarded("clawdbot"); err != nil { + t.Fatalf("failed to seed legacy onboarding state: %v", err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + state, err := BuildLauncherState(context.Background()) + if err != nil { + t.Fatalf("BuildLauncherState returned error: %v", err) + } + if state.Integrations["openclaw"].CurrentModel != "llama3.2" { + t.Fatalf("expected openclaw state to reuse legacy alias config, got %q", state.Integrations["openclaw"].CurrentModel) + } + + migrated, err := config.LoadIntegration("openclaw") + if err != nil { + t.Fatalf("expected canonical config to be migrated, got %v", err) + } + if !slices.Equal(migrated.Models, []string{"llama3.2"}) { + t.Fatalf("unexpected migrated models: %v", migrated.Models) + } + if migrated.Aliases["primary"] != "llama3.2" { + t.Fatalf("expected aliases to migrate, got %v", migrated.Aliases) + } + if !migrated.Onboarded { + t.Fatal("expected onboarding state to migrate to canonical openclaw key") + } +} + +func TestBuildLauncherState_ToleratesInventoryFailure(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + + if err := config.SetLastModel("llama3.2"); err != nil { + t.Fatalf("failed to seed last model: %v", err) + } + if err := config.SaveIntegration("claude", []string{"qwen3:8b"}); err != nil { + t.Fatalf("failed to seed claude config: %v", err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, `{"error":"temporary failure"}`) + case "/api/show": + var req apiShowRequest + _ = json.NewDecoder(r.Body).Decode(&req) + fmt.Fprintf(w, `{"model":%q}`, req.Model) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + state, err := BuildLauncherState(context.Background()) + if err != nil { + t.Fatalf("BuildLauncherState should tolerate inventory failure, got %v", err) + } + if !state.RunModelUsable { + t.Fatal("expected saved run model to remain usable via show fallback") + } + if state.Integrations["claude"].CurrentModel != "qwen3:8b" { + t.Fatalf("expected saved integration model to remain visible, got %q", state.Integrations["claude"].CurrentModel) + } + if !state.Integrations["claude"].ModelUsable { + t.Fatal("expected saved integration model to remain usable via show fallback") + } +} + +func TestResolveRunModel_UsesSavedModelWithoutSelector(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + if err := config.SetLastModel("llama3.2"); err != nil { + t.Fatalf("failed to save last model: %v", err) + } + + selectorCalled := false + DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + selectorCalled = true + return "", nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`) + case "/api/show": + fmt.Fprint(w, `{"model":"llama3.2"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + model, err := ResolveRunModel(context.Background(), RunModelRequest{}) + if err != nil { + t.Fatalf("ResolveRunModel returned error: %v", err) + } + if model != "llama3.2" { + t.Fatalf("expected saved model, got %q", model) + } + if selectorCalled { + t.Fatal("selector should not be called when saved model is usable") + } +} + +func TestResolveRunModel_ForcePickerAlwaysUsesSelector(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + if err := config.SetLastModel("llama3.2"); err != nil { + t.Fatalf("failed to save last model: %v", err) + } + + var selectorCalls int + DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + selectorCalls++ + if current != "llama3.2" { + t.Fatalf("expected current selection to be last model, got %q", current) + } + return "qwen3:8b", nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`) + case "/api/show": + fmt.Fprint(w, `{"model":"qwen3:8b"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true}) + if err != nil { + t.Fatalf("ResolveRunModel returned error: %v", err) + } + if selectorCalls != 1 { + t.Fatalf("expected selector to be called once, got %d", selectorCalls) + } + if model != "qwen3:8b" { + t.Fatalf("expected forced selection to win, got %q", model) + } + if got := config.LastModel(); got != "qwen3:8b" { + t.Fatalf("expected last model to be updated, got %q", got) + } +} + +func TestResolveRunModel_UsesSignInHookForCloudModel(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + return "glm-5:cloud", nil + } + + signInCalled := false + DefaultSignIn = func(modelName, signInURL string) (string, error) { + signInCalled = true + if modelName != "glm-5:cloud" { + t.Fatalf("unexpected model passed to sign-in: %q", modelName) + } + return "test-user", nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[]}`) + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"not found"}`) + case "/api/show": + fmt.Fprint(w, `{"remote_model":"glm-5"}`) + case "/api/me": + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + model, err := ResolveRunModel(context.Background(), RunModelRequest{ForcePicker: true}) + if err != nil { + t.Fatalf("ResolveRunModel returned error: %v", err) + } + if model != "glm-5:cloud" { + t.Fatalf("expected selected cloud model, got %q", model) + } + if !signInCalled { + t.Fatal("expected sign-in hook to be used for cloud model") + } +} + +func TestLaunchIntegration_EditorForceConfigure(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "droid") + t.Setenv("PATH", binDir) + + editor := &launcherEditorRunner{paths: []string{"/tmp/settings.json"}} + withIntegrationOverride(t, "droid", editor) + + var multiCalled bool + DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + multiCalled = true + return []string{"llama3.2", "qwen3:8b"}, nil + } + + var proceedPrompt bool + DefaultConfirmPrompt = func(prompt string) (bool, error) { + if prompt == "Proceed?" { + proceedPrompt = true + } + return true, nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`) + case "/api/show": + var req apiShowRequest + _ = json.NewDecoder(r.Body).Decode(&req) + fmt.Fprintf(w, `{"model":%q}`, req.Model) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "droid", + ForceConfigure: true, + }); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + + if !multiCalled { + t.Fatal("expected multi selector to be used for forced editor configure") + } + if !proceedPrompt { + t.Fatal("expected backup warning confirmation before edit") + } + if diff := compareStringSlices(editor.edited, [][]string{{"llama3.2", "qwen3:8b"}}); diff != "" { + t.Fatalf("unexpected edited models (-want +got):\n%s", diff) + } + if editor.ranModel != "llama3.2" { + t.Fatalf("expected launch to use first selected model, got %q", editor.ranModel) + } + saved, err := config.LoadIntegration("droid") + if err != nil { + t.Fatalf("failed to reload saved config: %v", err) + } + if diff := compareStrings(saved.Models, []string{"llama3.2", "qwen3:8b"}); diff != "" { + t.Fatalf("unexpected saved models (-want +got):\n%s", diff) + } +} + +func TestLaunchIntegration_EditorModelOverridePreservesExtras(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "droid") + t.Setenv("PATH", binDir) + + editor := &launcherEditorRunner{} + withIntegrationOverride(t, "droid", editor) + + if err := config.SaveIntegration("droid", []string{"llama3.2", "mistral"}); err != nil { + t.Fatalf("failed to seed config: %v", err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" { + var req apiShowRequest + _ = json.NewDecoder(r.Body).Decode(&req) + fmt.Fprintf(w, `{"model":%q}`, req.Model) + return + } + http.NotFound(w, r) + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "droid", + ModelOverride: "qwen3:8b", + }); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + + want := []string{"qwen3:8b", "llama3.2", "mistral"} + saved, err := config.LoadIntegration("droid") + if err != nil { + t.Fatalf("failed to reload saved config: %v", err) + } + if diff := compareStrings(saved.Models, want); diff != "" { + t.Fatalf("unexpected saved models (-want +got):\n%s", diff) + } + if diff := compareStringSlices(editor.edited, [][]string{want}); diff != "" { + t.Fatalf("unexpected edited models (-want +got):\n%s", diff) + } + if editor.ranModel != "qwen3:8b" { + t.Fatalf("expected override model to launch first, got %q", editor.ranModel) + } +} + +func TestLaunchIntegration_EditorCloudDisabledFallsBackToSelector(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "droid") + t.Setenv("PATH", binDir) + + editor := &launcherEditorRunner{} + withIntegrationOverride(t, "droid", editor) + + if err := config.SaveIntegration("droid", []string{"glm-5:cloud"}); err != nil { + t.Fatalf("failed to seed config: %v", err) + } + + var multiCalled bool + DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + multiCalled = true + return []string{"llama3.2"}, nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/status": + fmt.Fprint(w, `{"cloud":{"disabled":true,"source":"config"}}`) + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`) + case "/api/show": + fmt.Fprint(w, `{"model":"llama3.2"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "droid"}); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + if !multiCalled { + t.Fatal("expected editor flow to reopen selector when cloud-only config is unusable") + } +} + +func TestLaunchIntegration_ConfiguredEditorLaunchSkipsReconfigure(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "droid") + t.Setenv("PATH", binDir) + + editor := &launcherEditorRunner{paths: []string{"/tmp/settings.json"}} + withIntegrationOverride(t, "droid", editor) + + if err := config.SaveIntegration("droid", []string{"llama3.2", "qwen3:8b"}); err != nil { + t.Fatalf("failed to seed config: %v", err) + } + + DefaultConfirmPrompt = func(prompt string) (bool, error) { + t.Fatalf("did not expect prompt during a normal editor launch: %s", prompt) + return false, nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" { + var req apiShowRequest + _ = json.NewDecoder(r.Body).Decode(&req) + fmt.Fprintf(w, `{"model":%q}`, req.Model) + return + } + http.NotFound(w, r) + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "droid"}); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + if len(editor.edited) != 0 { + t.Fatalf("expected normal launch to skip editor rewrites, got %v", editor.edited) + } + if editor.ranModel != "llama3.2" { + t.Fatalf("expected launch to use saved primary model, got %q", editor.ranModel) + } + + saved, err := config.LoadIntegration("droid") + if err != nil { + t.Fatalf("failed to reload saved config: %v", err) + } + if diff := compareStrings(saved.Models, []string{"llama3.2", "qwen3:8b"}); diff != "" { + t.Fatalf("unexpected saved models (-want +got):\n%s", diff) + } +} + +func TestLaunchIntegration_OpenclawPreservesExistingModelList(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "openclaw") + t.Setenv("PATH", binDir) + + editor := &launcherEditorRunner{} + withIntegrationOverride(t, "openclaw", editor) + + if err := config.SaveIntegration("openclaw", []string{"llama3.2", "mistral"}); err != nil { + t.Fatalf("failed to seed config: %v", err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" { + var req apiShowRequest + _ = json.NewDecoder(r.Body).Decode(&req) + fmt.Fprintf(w, `{"model":%q}`, req.Model) + return + } + http.NotFound(w, r) + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "openclaw"}); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + if len(editor.edited) != 0 { + t.Fatalf("expected launch to preserve the existing OpenClaw config, got rewrites %v", editor.edited) + } + if editor.ranModel != "llama3.2" { + t.Fatalf("expected launch to use first saved model, got %q", editor.ranModel) + } + + saved, err := config.LoadIntegration("openclaw") + if err != nil { + t.Fatalf("failed to reload saved config: %v", err) + } + if diff := compareStrings(saved.Models, []string{"llama3.2", "mistral"}); diff != "" { + t.Fatalf("unexpected saved models (-want +got):\n%s", diff) + } +} + +func TestLaunchIntegration_OpenclawInstallsBeforeConfigSideEffects(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + t.Setenv("PATH", t.TempDir()) + + editor := &launcherEditorRunner{} + withIntegrationOverride(t, "openclaw", editor) + + selectorCalled := false + DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + selectorCalled = true + return []string{"llama3.2"}, nil + } + + err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "openclaw"}) + if err == nil { + t.Fatal("expected launch to fail before configuration when OpenClaw is missing") + } + if !strings.Contains(err.Error(), "npm was not found") { + t.Fatalf("expected install prerequisite error, got %v", err) + } + if selectorCalled { + t.Fatal("expected install check to happen before model selection") + } + if len(editor.edited) != 0 { + t.Fatalf("expected no editor writes before install succeeds, got %v", editor.edited) + } + if _, statErr := os.Stat(filepath.Join(tmpDir, ".openclaw", "openclaw.json")); !os.IsNotExist(statErr) { + t.Fatalf("expected no OpenClaw config file to be created, stat err = %v", statErr) + } +} + +func TestLaunchIntegration_ConfigureOnlyDoesNotRequireInstalledBinary(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + t.Setenv("PATH", t.TempDir()) + + editor := &launcherEditorRunner{paths: []string{"/tmp/settings.json"}} + withIntegrationOverride(t, "droid", editor) + + DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) { + return []string{"llama3.2"}, nil + } + + var prompts []string + DefaultConfirmPrompt = func(prompt string) (bool, error) { + prompts = append(prompts, prompt) + if strings.Contains(prompt, "Launch LauncherEditor now?") { + return false, nil + } + return true, nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`) + case "/api/show": + fmt.Fprint(w, `{"model":"llama3.2"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "droid", + ForceConfigure: true, + ConfigureOnly: true, + }); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + if diff := compareStringSlices(editor.edited, [][]string{{"llama3.2"}}); diff != "" { + t.Fatalf("unexpected edited models (-want +got):\n%s", diff) + } + if editor.ranModel != "" { + t.Fatalf("expected configure-only flow to skip launch, got %q", editor.ranModel) + } + if !slices.Contains(prompts, "Proceed?") { + t.Fatalf("expected editor warning prompt, got %v", prompts) + } + if !slices.Contains(prompts, "Launch LauncherEditor now?") { + t.Fatalf("expected configure-only launch prompt, got %v", prompts) + } +} + +func TestLaunchIntegration_ClaudeSavesPrimaryModel(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "claude") + t.Setenv("PATH", binDir) + + var aliasSyncCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[]}`) + case "/api/status": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"not found"}`) + case "/api/show": + fmt.Fprint(w, `{"remote_model":"glm-5"}`) + case "/api/me": + fmt.Fprint(w, `{"name":"test-user"}`) + case "/api/experimental/aliases": + aliasSyncCalled = true + t.Fatalf("did not expect alias sync call after removing Claude alias flow") + default: + t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "claude", + ModelOverride: "glm-5:cloud", + }); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + + saved, err := config.LoadIntegration("claude") + if err != nil { + t.Fatalf("failed to reload saved config: %v", err) + } + if diff := compareStrings(saved.Models, []string{"glm-5:cloud"}); diff != "" { + t.Fatalf("unexpected saved models (-want +got):\n%s", diff) + } + if aliasSyncCalled { + t.Fatal("expected Claude launch flow not to sync aliases") + } +} + +func TestLaunchIntegration_ClaudeForceConfigureReprompts(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "claude") + t.Setenv("PATH", binDir) + + if err := config.SaveIntegration("claude", []string{"qwen3:8b"}); err != nil { + t.Fatalf("failed to seed config: %v", err) + } + + var selectorCalls int + DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + selectorCalls++ + return "glm-5:cloud", nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"qwen3:8b"}]}`) + case "/api/show": + fmt.Fprint(w, `{"model":"qwen3:8b"}`) + case "/api/me": + fmt.Fprint(w, `{"name":"test-user"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "claude", + ForceConfigure: true, + }); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + if selectorCalls != 1 { + t.Fatalf("expected forced configure to reprompt for model selection, got %d calls", selectorCalls) + } + saved, err := config.LoadIntegration("claude") + if err != nil { + t.Fatalf("failed to reload saved config: %v", err) + } + if saved.Models[0] != "glm-5:cloud" { + t.Fatalf("expected saved primary to be replaced, got %q", saved.Models[0]) + } +} + +func TestLaunchIntegration_ClaudeModelOverrideSkipsSelector(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + withInteractiveSession(t, true) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "claude") + t.Setenv("PATH", binDir) + + var selectorCalls int + DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + selectorCalls++ + return "", fmt.Errorf("selector should not run when --model override is set") + } + + var confirmCalls int + DefaultConfirmPrompt = func(prompt string) (bool, error) { + confirmCalls++ + if !strings.Contains(prompt, "glm-4") { + t.Fatalf("expected download prompt for override model, got %q", prompt) + } + return true, nil + } + + var pullCalled bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"model not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"status":"success"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "claude", + ModelOverride: "glm-4", + }); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + + if selectorCalls != 0 { + t.Fatalf("expected model override to skip selector, got %d calls", selectorCalls) + } + if confirmCalls == 0 { + t.Fatal("expected missing override model to prompt for download in interactive mode") + } + if !pullCalled { + t.Fatal("expected missing override model to be pulled after confirmation") + } + + saved, err := config.LoadIntegration("claude") + if err != nil { + t.Fatalf("failed to reload saved config: %v", err) + } + if saved.Models[0] != "glm-4" { + t.Fatalf("expected saved primary to match override, got %q", saved.Models[0]) + } +} + +func TestLaunchIntegration_ConfigureOnlyPrompt(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + + runner := &launcherSingleRunner{} + withIntegrationOverride(t, "stubsingle", runner) + + DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + return "llama3.2", nil + } + + var prompts []string + DefaultConfirmPrompt = func(prompt string) (bool, error) { + prompts = append(prompts, prompt) + if strings.Contains(prompt, "Launch StubSingle now?") { + return false, nil + } + return true, nil + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`) + case "/api/show": + fmt.Fprint(w, `{"model":"llama3.2"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "stubsingle", + ForceConfigure: true, + ConfigureOnly: true, + }); err != nil { + t.Fatalf("LaunchIntegration returned error: %v", err) + } + if runner.ranModel != "" { + t.Fatalf("expected configure-only flow to skip launch when prompt is declined, got %q", runner.ranModel) + } + if !slices.Contains(prompts, "Launch StubSingle now?") { + t.Fatalf("expected launch confirmation prompt, got %v", prompts) + } +} + +func TestLaunchIntegration_ModelOverrideHeadlessMissingFailsWithoutPrompt(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + withInteractiveSession(t, false) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "droid") + t.Setenv("PATH", binDir) + + runner := &launcherSingleRunner{} + withIntegrationOverride(t, "droid", runner) + + confirmCalled := false + DefaultConfirmPrompt = func(prompt string) (bool, error) { + confirmCalled = true + return true, nil + } + + pullCalled := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"model not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"status":"success"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "droid", + ModelOverride: "missing-model", + }) + if err == nil { + t.Fatal("expected missing model to fail in headless mode") + } + if !strings.Contains(err.Error(), "ollama pull missing-model") { + t.Fatalf("expected actionable missing model error, got %v", err) + } + if confirmCalled { + t.Fatal("expected no confirmation prompt in headless mode") + } + if pullCalled { + t.Fatal("expected pull request not to run in headless mode") + } + if runner.ranModel != "" { + t.Fatalf("expected launch to abort before running integration, got %q", runner.ranModel) + } +} + +func TestLaunchIntegration_ModelOverrideHeadlessCanOverrideMissingModelPolicy(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + withInteractiveSession(t, false) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "droid") + t.Setenv("PATH", binDir) + + runner := &launcherSingleRunner{} + withIntegrationOverride(t, "droid", runner) + + confirmCalled := false + DefaultConfirmPrompt = func(prompt string) (bool, error) { + confirmCalled = true + if !strings.Contains(prompt, "missing-model") { + t.Fatalf("expected prompt to mention missing model, got %q", prompt) + } + return true, nil + } + + pullCalled := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"model not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"status":"success"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + customPolicy := LaunchPolicy{MissingModel: LaunchMissingModelPromptToPull} + if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "droid", + ModelOverride: "missing-model", + Policy: &customPolicy, + }); err != nil { + t.Fatalf("expected policy override to allow prompt/pull in headless mode, got %v", err) + } + if !confirmCalled { + t.Fatal("expected confirmation prompt when missing-model policy is overridden to prompt/pull") + } + if !pullCalled { + t.Fatal("expected pull request to run when missing-model policy is overridden to prompt/pull") + } + if runner.ranModel != "missing-model" { + t.Fatalf("expected integration to launch after pull, got %q", runner.ranModel) + } +} + +func TestLaunchIntegration_ModelOverrideInteractiveMissingPromptsAndPulls(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + withInteractiveSession(t, true) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "droid") + t.Setenv("PATH", binDir) + + runner := &launcherSingleRunner{} + withIntegrationOverride(t, "droid", runner) + + confirmCalled := false + DefaultConfirmPrompt = func(prompt string) (bool, error) { + confirmCalled = true + if !strings.Contains(prompt, "missing-model") { + t.Fatalf("expected prompt to mention missing model, got %q", prompt) + } + return true, nil + } + + pullCalled := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"model not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"status":"success"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "droid", + ModelOverride: "missing-model", + }) + if err != nil { + t.Fatalf("expected interactive override to prompt/pull and succeed, got %v", err) + } + if !confirmCalled { + t.Fatal("expected interactive flow to prompt before pulling missing model") + } + if !pullCalled { + t.Fatal("expected pull request to run after interactive confirmation") + } + if runner.ranModel != "missing-model" { + t.Fatalf("expected integration to run with pulled model, got %q", runner.ranModel) + } +} + +func TestLaunchIntegration_HeadlessSelectorFlowFailsWithoutPrompt(t *testing.T) { + tmpDir := t.TempDir() + setLaunchTestHome(t, tmpDir) + withLauncherHooks(t) + withInteractiveSession(t, false) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, "droid") + t.Setenv("PATH", binDir) + + runner := &launcherSingleRunner{} + withIntegrationOverride(t, "droid", runner) + + DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) { + return "missing-model", nil + } + + confirmCalled := false + DefaultConfirmPrompt = func(prompt string) (bool, error) { + confirmCalled = true + return true, nil + } + + pullCalled := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/tags": + fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`) + case "/api/show": + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":"model not found"}`) + case "/api/pull": + pullCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"status":"success"}`) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + t.Setenv("OLLAMA_HOST", srv.URL) + + err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{ + Name: "droid", + ForceConfigure: true, + }) + if err == nil { + t.Fatal("expected headless selector flow to fail on missing model") + } + if !strings.Contains(err.Error(), "ollama pull missing-model") { + t.Fatalf("expected actionable missing model error, got %v", err) + } + if confirmCalled { + t.Fatal("expected no confirmation prompt in headless selector flow") + } + if pullCalled { + t.Fatal("expected no pull request in headless selector flow") + } + if runner.ranModel != "" { + t.Fatalf("expected flow to abort before launch, got %q", runner.ranModel) + } +} + +type apiShowRequest struct { + Model string `json:"model"` +} + +func compareStrings(got, want []string) string { + if slices.Equal(got, want) { + return "" + } + return fmt.Sprintf("want %v got %v", want, got) +} + +func compareStringSlices(got, want [][]string) string { + if len(got) != len(want) { + return fmt.Sprintf("want %v got %v", want, got) + } + for i := range got { + if !slices.Equal(got[i], want[i]) { + return fmt.Sprintf("want %v got %v", want, got) + } + } + return "" +} diff --git a/cmd/launch/models.go b/cmd/launch/models.go new file mode 100644 index 000000000..40cb9e414 --- /dev/null +++ b/cmd/launch/models.go @@ -0,0 +1,477 @@ +package launch + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "runtime" + "slices" + "strings" + "time" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/cmd/config" + "github.com/ollama/ollama/cmd/internal/fileutil" + internalcloud "github.com/ollama/ollama/internal/cloud" + "github.com/ollama/ollama/internal/modelref" + "github.com/ollama/ollama/progress" +) + +var recommendedModels = []ModelItem{ + {Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true}, + {Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true}, + {Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true}, + {Name: "minimax-m2.5:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true}, + {Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true}, + {Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true}, +} + +var recommendedVRAM = map[string]string{ + "glm-4.7-flash": "~25GB", + "qwen3.5": "~11GB", +} + +// cloudModelLimit holds context and output token limits for a cloud model. +type cloudModelLimit struct { + Context int + Output int +} + +// cloudModelLimits maps cloud model base names to their token limits. +// TODO(parthsareen): grab context/output limits from model info instead of hardcoding +var cloudModelLimits = map[string]cloudModelLimit{ + "minimax-m2.5": {Context: 204_800, Output: 128_000}, + "cogito-2.1:671b": {Context: 163_840, Output: 65_536}, + "deepseek-v3.1:671b": {Context: 163_840, Output: 163_840}, + "deepseek-v3.2": {Context: 163_840, Output: 65_536}, + "glm-4.6": {Context: 202_752, Output: 131_072}, + "glm-4.7": {Context: 202_752, Output: 131_072}, + "glm-5": {Context: 202_752, Output: 131_072}, + "gpt-oss:120b": {Context: 131_072, Output: 131_072}, + "gpt-oss:20b": {Context: 131_072, Output: 131_072}, + "kimi-k2:1t": {Context: 262_144, Output: 262_144}, + "kimi-k2.5": {Context: 262_144, Output: 262_144}, + "kimi-k2-thinking": {Context: 262_144, Output: 262_144}, + "nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072}, + "qwen3-coder:480b": {Context: 262_144, Output: 65_536}, + "qwen3-coder-next": {Context: 262_144, Output: 32_768}, + "qwen3-next:80b": {Context: 262_144, Output: 32_768}, + "qwen3.5": {Context: 262_144, Output: 32_768}, +} + +// lookupCloudModelLimit returns the token limits for a cloud model. +// It normalizes explicit cloud source suffixes before checking the shared limit map. +func lookupCloudModelLimit(name string) (cloudModelLimit, bool) { + base, stripped := modelref.StripCloudSourceTag(name) + if stripped { + if l, ok := cloudModelLimits[base]; ok { + return l, true + } + } + return cloudModelLimit{}, false +} + +// missingModelPolicy controls how model-not-found errors should be handled. +type missingModelPolicy int + +const ( + // missingModelPromptPull prompts the user to download missing local models. + missingModelPromptPull missingModelPolicy = iota + // missingModelAutoPull downloads missing local models without prompting. + missingModelAutoPull + // missingModelFail returns an error for missing local models without prompting. + missingModelFail +) + +// OpenBrowser opens the URL in the user's browser. +func OpenBrowser(url string) { + switch runtime.GOOS { + case "darwin": + _ = exec.Command("open", url).Start() + case "linux": + _ = exec.Command("xdg-open", url).Start() + case "windows": + _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + } +} + +// ensureAuth ensures the user is signed in before cloud-backed models run. +func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error { + var selectedCloudModels []string + for _, m := range selected { + if cloudModels[m] { + selectedCloudModels = append(selectedCloudModels, m) + } + } + if len(selectedCloudModels) == 0 { + return nil + } + if disabled, known := cloudStatusDisabled(ctx, client); known && disabled { + return errors.New(internalcloud.DisabledError("remote inference is unavailable")) + } + + user, err := client.Whoami(ctx) + if err == nil && user != nil && user.Name != "" { + return nil + } + + var aErr api.AuthorizationError + if !errors.As(err, &aErr) || aErr.SigninURL == "" { + return err + } + + modelList := strings.Join(selectedCloudModels, ", ") + + if DefaultSignIn != nil { + _, err := DefaultSignIn(modelList, aErr.SigninURL) + if errors.Is(err, ErrCancelled) { + return ErrCancelled + } + if err != nil { + return fmt.Errorf("%s requires sign in", modelList) + } + return nil + } + + yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList)) + if errors.Is(err, ErrCancelled) { + return ErrCancelled + } + if err != nil { + return err + } + if !yes { + return ErrCancelled + } + + fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL) + OpenBrowser(aErr.SigninURL) + + spinnerFrames := []string{"|", "/", "-", "\\"} + frame := 0 + fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0]) + + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + fmt.Fprintf(os.Stderr, "\r\033[K") + return ctx.Err() + case <-ticker.C: + frame++ + fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) + + if frame%10 == 0 { + u, err := client.Whoami(ctx) + if err == nil && u != nil && u.Name != "" { + fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) + return nil + } + } + } + } +} + +// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy. +func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error { + if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil { + return nil + } else { + var statusErr api.StatusError + if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound { + return err + } + } + + if isCloudModel { + if disabled, known := cloudStatusDisabled(ctx, client); known && disabled { + return errors.New(internalcloud.DisabledError("remote inference is unavailable")) + } + return fmt.Errorf("model %q not found", model) + } + + switch policy { + case missingModelAutoPull: + return pullMissingModel(ctx, client, model) + case missingModelFail: + return fmt.Errorf("model %q not found; run 'ollama pull %s' first", model, model) + default: + return confirmAndPull(ctx, client, model) + } +} + +func confirmAndPull(ctx context.Context, client *api.Client, model string) error { + if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil { + return err + } else if !ok { + return errCancelled + } + fmt.Fprintf(os.Stderr, "\n") + return pullMissingModel(ctx, client, model) +} + +func pullMissingModel(ctx context.Context, client *api.Client, model string) error { + if err := pullModel(ctx, client, model, false); err != nil { + return fmt.Errorf("failed to pull %s: %w", model, err) + } + return nil +} + +// prepareEditorIntegration persists models and applies editor-managed config files. +func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error { + if ok, err := confirmEditorEdit(runner, editor); err != nil { + return err + } else if !ok { + return errCancelled + } + if err := editor.Edit(models); err != nil { + return fmt.Errorf("setup failed: %w", err) + } + if err := config.SaveIntegration(name, models); err != nil { + return fmt.Errorf("failed to save: %w", err) + } + return nil +} + +func confirmEditorEdit(runner Runner, editor Editor) (bool, error) { + paths := editor.Paths() + if len(paths) == 0 { + return true, nil + } + + fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner) + for _, path := range paths { + fmt.Fprintf(os.Stderr, " %s\n", path) + } + fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", fileutil.BackupDir()) + + return ConfirmPrompt("Proceed?") +} + +// buildModelList merges existing models with recommendations for selection UIs. +func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) { + existingModels = make(map[string]bool) + cloudModels = make(map[string]bool) + recommended := make(map[string]bool) + var hasLocalModel, hasCloudModel bool + + recDesc := make(map[string]string) + for _, rec := range recommendedModels { + recommended[rec.Name] = true + recDesc[rec.Name] = rec.Description + } + + for _, m := range existing { + existingModels[m.Name] = true + if m.Remote { + cloudModels[m.Name] = true + hasCloudModel = true + } else { + hasLocalModel = true + } + displayName := strings.TrimSuffix(m.Name, ":latest") + existingModels[displayName] = true + item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]} + items = append(items, item) + } + + for _, rec := range recommendedModels { + if existingModels[rec.Name] || existingModels[rec.Name+":latest"] { + continue + } + items = append(items, rec) + if isCloudModelName(rec.Name) { + cloudModels[rec.Name] = true + } + } + + checked := make(map[string]bool, len(preChecked)) + for _, n := range preChecked { + checked[n] = true + } + + for _, item := range items { + if item.Name == current || strings.HasPrefix(item.Name, current+":") { + current = item.Name + break + } + } + if checked[current] { + preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...) + } + + notInstalled := make(map[string]bool) + for i := range items { + if !existingModels[items[i].Name] && !cloudModels[items[i].Name] { + notInstalled[items[i].Name] = true + var parts []string + if items[i].Description != "" { + parts = append(parts, items[i].Description) + } + if vram := recommendedVRAM[items[i].Name]; vram != "" { + parts = append(parts, vram) + } + parts = append(parts, "(not downloaded)") + items[i].Description = strings.Join(parts, ", ") + } + } + + recRank := make(map[string]int) + for i, rec := range recommendedModels { + recRank[rec.Name] = i + 1 + } + + onlyLocal := hasLocalModel && !hasCloudModel + + if hasLocalModel || hasCloudModel { + slices.SortStableFunc(items, func(a, b ModelItem) int { + ac, bc := checked[a.Name], checked[b.Name] + aNew, bNew := notInstalled[a.Name], notInstalled[b.Name] + aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0 + aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name] + + if ac != bc { + if ac { + return -1 + } + return 1 + } + if aRec != bRec { + if aRec { + return -1 + } + return 1 + } + if aRec && bRec { + if aCloud != bCloud { + if onlyLocal { + if aCloud { + return 1 + } + return -1 + } + if aCloud { + return -1 + } + return 1 + } + return recRank[a.Name] - recRank[b.Name] + } + if aNew != bNew { + if aNew { + return 1 + } + return -1 + } + return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)) + }) + } + + return items, preChecked, existingModels, cloudModels +} + +// isCloudModelName reports whether the model name has an explicit cloud source. +func isCloudModelName(name string) bool { + return modelref.HasExplicitCloudSource(name) +} + +// filterCloudModels drops remote-only models from the given inventory. +func filterCloudModels(existing []modelInfo) []modelInfo { + filtered := existing[:0] + for _, m := range existing { + if !m.Remote { + filtered = append(filtered, m) + } + } + return filtered +} + +// filterCloudItems removes cloud models from selection items. +func filterCloudItems(items []ModelItem) []ModelItem { + filtered := items[:0] + for _, item := range items { + if !isCloudModelName(item.Name) { + filtered = append(filtered, item) + } + } + return filtered +} + +func isCloudModel(ctx context.Context, client *api.Client, name string) bool { + if client == nil { + return false + } + resp, err := client.Show(ctx, &api.ShowRequest{Model: name}) + if err != nil { + return false + } + return resp.RemoteModel != "" +} + +// cloudStatusDisabled returns whether cloud usage is currently disabled. +func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) { + status, err := client.CloudStatusExperimental(ctx) + if err != nil { + var statusErr api.StatusError + if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound { + return false, false + } + return false, false + } + return status.Cloud.Disabled, true +} + +// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler. +// Move the shared pull rendering to a small utility once the package boundary settles. +func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error { + p := progress.NewProgress(os.Stderr) + defer p.Stop() + + bars := make(map[string]*progress.Bar) + var status string + var spinner *progress.Spinner + + fn := func(resp api.ProgressResponse) error { + if resp.Digest != "" { + if resp.Completed == 0 { + return nil + } + + if spinner != nil { + spinner.Stop() + } + + bar, ok := bars[resp.Digest] + if !ok { + name, isDigest := strings.CutPrefix(resp.Digest, "sha256:") + name = strings.TrimSpace(name) + if isDigest { + name = name[:min(12, len(name))] + } + bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed) + bars[resp.Digest] = bar + p.Add(resp.Digest, bar) + } + + bar.Set(resp.Completed) + } else if status != resp.Status { + if spinner != nil { + spinner.Stop() + } + + status = resp.Status + spinner = progress.NewSpinner(status) + p.Add(status, spinner) + } + + return nil + } + + request := api.PullRequest{Name: model, Insecure: insecure} + return client.Pull(ctx, &request, fn) +} diff --git a/cmd/config/openclaw.go b/cmd/launch/openclaw.go similarity index 96% rename from cmd/config/openclaw.go rename to cmd/launch/openclaw.go index c4788f9c5..df79f36e5 100644 --- a/cmd/config/openclaw.go +++ b/cmd/launch/openclaw.go @@ -1,4 +1,4 @@ -package config +package launch import ( "context" @@ -15,6 +15,8 @@ import ( "time" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/cmd/config" + "github.com/ollama/ollama/cmd/internal/fileutil" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) @@ -35,7 +37,7 @@ func (c *Openclaw) Run(model string, args []string) error { } firstLaunch := true - if integrationConfig, err := loadIntegration("openclaw"); err == nil { + if integrationConfig, err := loadStoredIntegrationConfig("openclaw"); err == nil { firstLaunch = !integrationConfig.Onboarded } @@ -45,7 +47,7 @@ func (c *Openclaw) Run(model string, args []string) error { fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n") fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset) - ok, err := confirmPrompt("I understand the risks. Continue?") + ok, err := ConfirmPrompt("I understand the risks. Continue?") if err != nil { return err } @@ -107,7 +109,7 @@ func (c *Openclaw) Run(model string, args []string) error { return windowsHint(err) } if firstLaunch { - if err := integrationOnboarded("openclaw"); err != nil { + if err := config.MarkIntegrationOnboarded("openclaw"); err != nil { return fmt.Errorf("failed to save onboarding state: %w", err) } } @@ -166,7 +168,7 @@ func (c *Openclaw) Run(model string, args []string) error { } if firstLaunch { - if err := integrationOnboarded("openclaw"); err != nil { + if err := config.MarkIntegrationOnboarded("openclaw"); err != nil { return fmt.Errorf("failed to save onboarding state: %w", err) } } @@ -426,7 +428,7 @@ func ensureOpenclawInstalled() (string, error) { "and select OpenClaw") } - ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?") + ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?") if err != nil { return "", err } @@ -561,7 +563,7 @@ func (c *Openclaw) Edit(models []string) error { if err != nil { return err } - if err := writeWithBackup(configPath, data); err != nil { + if err := fileutil.WriteWithBackup(configPath, data); err != nil { return err } @@ -776,9 +778,9 @@ func (c *Openclaw) Models() []string { return nil } - config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json")) + config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json")) if err != nil { - config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json")) + config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json")) if err != nil { return nil } diff --git a/cmd/config/openclaw_test.go b/cmd/launch/openclaw_test.go similarity index 99% rename from cmd/config/openclaw_test.go rename to cmd/launch/openclaw_test.go index 1fcdf0050..14601fe87 100644 --- a/cmd/config/openclaw_test.go +++ b/cmd/launch/openclaw_test.go @@ -1,4 +1,4 @@ -package config +package launch import ( "bytes" @@ -116,9 +116,9 @@ func TestOpenclawRunFirstLaunchPersistence(t *testing.T) { if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil { t.Fatalf("Run() error = %v", err) } - integrationConfig, err := loadIntegration("openclaw") + integrationConfig, err := LoadIntegration("openclaw") if err != nil { - t.Fatalf("loadIntegration() error = %v", err) + t.Fatalf("LoadIntegration() error = %v", err) } if !integrationConfig.Onboarded { t.Fatal("expected onboarding flag to be persisted after successful run") @@ -147,7 +147,7 @@ func TestOpenclawRunFirstLaunchPersistence(t *testing.T) { if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil { t.Fatal("expected run failure") } - integrationConfig, err := loadIntegration("openclaw") + integrationConfig, err := LoadIntegration("openclaw") if err == nil && integrationConfig.Onboarded { t.Fatal("expected onboarding flag to remain unset after failed run") } @@ -1528,7 +1528,7 @@ func TestIntegrationOnboarded(t *testing.T) { tmpDir := t.TempDir() setTestHome(t, tmpDir) - integrationConfig, err := loadIntegration("openclaw") + integrationConfig, err := LoadIntegration("openclaw") if err == nil && integrationConfig.Onboarded { t.Error("expected false for fresh config") } @@ -1542,7 +1542,7 @@ func TestIntegrationOnboarded(t *testing.T) { if err := integrationOnboarded("openclaw"); err != nil { t.Fatal(err) } - integrationConfig, err := loadIntegration("openclaw") + integrationConfig, err := LoadIntegration("openclaw") if err != nil || !integrationConfig.Onboarded { t.Error("expected true after integrationOnboarded") } @@ -1556,7 +1556,7 @@ func TestIntegrationOnboarded(t *testing.T) { if err := integrationOnboarded("OpenClaw"); err != nil { t.Fatal(err) } - integrationConfig, err := loadIntegration("openclaw") + integrationConfig, err := LoadIntegration("openclaw") if err != nil || !integrationConfig.Onboarded { t.Error("expected true when set with different case") } @@ -1575,7 +1575,7 @@ func TestIntegrationOnboarded(t *testing.T) { } // Verify onboarded is set - integrationConfig, err := loadIntegration("openclaw") + integrationConfig, err := LoadIntegration("openclaw") if err != nil || !integrationConfig.Onboarded { t.Error("expected true after integrationOnboarded") } diff --git a/cmd/config/opencode.go b/cmd/launch/opencode.go similarity index 80% rename from cmd/config/opencode.go rename to cmd/launch/opencode.go index 52a1426b9..3a0b2ae8e 100644 --- a/cmd/config/opencode.go +++ b/cmd/launch/opencode.go @@ -1,9 +1,7 @@ -package config +package launch import ( - "context" "encoding/json" - "errors" "fmt" "maps" "os" @@ -12,31 +10,13 @@ import ( "slices" "strings" + "github.com/ollama/ollama/cmd/internal/fileutil" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/internal/modelref" ) // OpenCode implements Runner and Editor for OpenCode integration type OpenCode struct{} -// cloudModelLimit holds context and output token limits for a cloud model. -type cloudModelLimit struct { - Context int - Output int -} - -// lookupCloudModelLimit returns the token limits for a cloud model. -// It normalizes explicit cloud source suffixes before checking the shared limit map. -func lookupCloudModelLimit(name string) (cloudModelLimit, bool) { - base, stripped := modelref.StripCloudSourceTag(name) - if stripped { - if l, ok := cloudModelLimits[base]; ok { - return l, true - } - } - return cloudModelLimit{}, false -} - func (o *OpenCode) String() string { return "OpenCode" } func (o *OpenCode) Run(model string, args []string) error { @@ -44,25 +24,6 @@ func (o *OpenCode) Run(model string, args []string) error { return fmt.Errorf("opencode is not installed, install from https://opencode.ai") } - // Call Edit() to ensure config is up-to-date before launch - models := []string{model} - if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 { - models = config.Models - } - var err error - models, err = resolveEditorModels("opencode", models, func() ([]string, error) { - return selectModels(context.Background(), "opencode", "") - }) - if errors.Is(err, errCancelled) { - return nil - } - if err != nil { - return err - } - if err := o.Edit(models); err != nil { - return fmt.Errorf("setup failed: %w", err) - } - cmd := exec.Command("opencode", args...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -191,7 +152,7 @@ func (o *OpenCode) Edit(modelList []string) error { if err != nil { return err } - if err := writeWithBackup(configPath, configData); err != nil { + if err := fileutil.WriteWithBackup(configPath, configData); err != nil { return err } @@ -243,7 +204,7 @@ func (o *OpenCode) Edit(modelList []string) error { if err != nil { return err } - return writeWithBackup(statePath, stateData) + return fileutil.WriteWithBackup(statePath, stateData) } func (o *OpenCode) Models() []string { @@ -251,7 +212,7 @@ func (o *OpenCode) Models() []string { if err != nil { return nil } - config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json")) + config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json")) if err != nil { return nil } diff --git a/cmd/config/opencode_test.go b/cmd/launch/opencode_test.go similarity index 99% rename from cmd/config/opencode_test.go rename to cmd/launch/opencode_test.go index bd02bbbf0..c1833430d 100644 --- a/cmd/config/opencode_test.go +++ b/cmd/launch/opencode_test.go @@ -1,4 +1,4 @@ -package config +package launch import ( "encoding/json" diff --git a/cmd/config/pi.go b/cmd/launch/pi.go similarity index 92% rename from cmd/config/pi.go rename to cmd/launch/pi.go index f2ae0c5f7..e123c955c 100644 --- a/cmd/config/pi.go +++ b/cmd/launch/pi.go @@ -1,4 +1,4 @@ -package config +package launch import ( "context" @@ -12,6 +12,7 @@ import ( "strings" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/cmd/internal/fileutil" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) @@ -26,15 +27,6 @@ func (p *Pi) Run(model string, args []string) error { return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent") } - // Call Edit() to ensure config is up-to-date before launch - models := []string{model} - if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 { - models = config.Models - } - if err := p.Edit(models); err != nil { - return fmt.Errorf("setup failed: %w", err) - } - cmd := exec.Command("pi", args...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -149,7 +141,7 @@ func (p *Pi) Edit(models []string) error { if err != nil { return err } - if err := writeWithBackup(configPath, configData); err != nil { + if err := fileutil.WriteWithBackup(configPath, configData); err != nil { return err } @@ -167,7 +159,7 @@ func (p *Pi) Edit(models []string) error { if err != nil { return err } - return writeWithBackup(settingsPath, settingsData) + return fileutil.WriteWithBackup(settingsPath, settingsData) } func (p *Pi) Models() []string { @@ -177,7 +169,7 @@ func (p *Pi) Models() []string { } configPath := filepath.Join(home, ".pi", "agent", "models.json") - config, err := readJSONFile(configPath) + config, err := fileutil.ReadJSON(configPath) if err != nil { return nil } @@ -229,8 +221,15 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s cfg["contextWindow"] = l.Context } + applyCloudContextFallback := func() { + if l, ok := lookupCloudModelLimit(modelID); ok { + cfg["contextWindow"] = l.Context + } + } + resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID}) if err != nil { + applyCloudContextFallback() return cfg } @@ -248,14 +247,19 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s // Extract context window from ModelInfo. For known cloud models, the // pre-filled shared limit remains unless the server provides a positive value. + hasContextWindow := false for key, val := range resp.ModelInfo { if strings.HasSuffix(key, ".context_length") { if ctxLen, ok := val.(float64); ok && ctxLen > 0 { cfg["contextWindow"] = int(ctxLen) + hasContextWindow = true } break } } + if !hasContextWindow { + applyCloudContextFallback() + } return cfg } diff --git a/cmd/config/pi_test.go b/cmd/launch/pi_test.go similarity index 98% rename from cmd/config/pi_test.go rename to cmd/launch/pi_test.go index 1ca572e81..fb869fdb8 100644 --- a/cmd/config/pi_test.go +++ b/cmd/launch/pi_test.go @@ -1,4 +1,4 @@ -package config +package launch import ( "context" @@ -840,7 +840,7 @@ func TestCreateConfig(t *testing.T) { } }) - t.Run("falls back to cloud context when show fails", func(t *testing.T) { + t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, `{"error":"model not found"}`) @@ -857,7 +857,7 @@ func TestCreateConfig(t *testing.T) { } }) - t.Run("falls back to cloud context when model info is empty", func(t *testing.T) { + t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/api/show" { fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`) @@ -877,7 +877,7 @@ func TestCreateConfig(t *testing.T) { } }) - t.Run("falls back to cloud context for dash cloud suffix", func(t *testing.T) { + t.Run("cloud model with dash suffix falls back to hardcoded context", func(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, `{"error":"model not found"}`) @@ -893,7 +893,6 @@ func TestCreateConfig(t *testing.T) { t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"]) } }) - t.Run("skips zero context length", func(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/api/show" { diff --git a/cmd/launch/registry.go b/cmd/launch/registry.go new file mode 100644 index 000000000..ebafe40b6 --- /dev/null +++ b/cmd/launch/registry.go @@ -0,0 +1,355 @@ +package launch + +import ( + "fmt" + "os" + "os/exec" + "slices" + "strings" +) + +// IntegrationInstallSpec describes how launcher should detect and guide installation. +type IntegrationInstallSpec struct { + CheckInstalled func() bool + EnsureInstalled func() error + URL string + Command []string +} + +// IntegrationSpec is the canonical registry entry for one integration. +type IntegrationSpec struct { + Name string + Runner Runner + Aliases []string + Hidden bool + Description string + Install IntegrationInstallSpec +} + +// IntegrationInfo contains display information about a registered integration. +type IntegrationInfo struct { + Name string + DisplayName string + Description string +} + +var launcherIntegrationOrder = []string{"opencode", "droid", "pi", "cline"} + +var integrationSpecs = []*IntegrationSpec{ + { + Name: "claude", + Runner: &Claude{}, + Description: "Anthropic's coding tool with subagents", + Install: IntegrationInstallSpec{ + CheckInstalled: func() bool { + _, err := (&Claude{}).findPath() + return err == nil + }, + URL: "https://code.claude.com/docs/en/quickstart", + }, + }, + { + Name: "cline", + Runner: &Cline{}, + Description: "Autonomous coding agent with parallel execution", + Install: IntegrationInstallSpec{ + CheckInstalled: func() bool { + _, err := exec.LookPath("cline") + return err == nil + }, + Command: []string{"npm", "install", "-g", "cline"}, + }, + }, + { + Name: "codex", + Runner: &Codex{}, + Description: "OpenAI's open-source coding agent", + Install: IntegrationInstallSpec{ + CheckInstalled: func() bool { + _, err := exec.LookPath("codex") + return err == nil + }, + URL: "https://developers.openai.com/codex/cli/", + Command: []string{"npm", "install", "-g", "@openai/codex"}, + }, + }, + { + Name: "droid", + Runner: &Droid{}, + Description: "Factory's coding agent across terminal and IDEs", + Install: IntegrationInstallSpec{ + CheckInstalled: func() bool { + _, err := exec.LookPath("droid") + return err == nil + }, + URL: "https://docs.factory.ai/cli/getting-started/quickstart", + }, + }, + { + Name: "opencode", + Runner: &OpenCode{}, + Description: "Anomaly's open-source coding agent", + Install: IntegrationInstallSpec{ + CheckInstalled: func() bool { + _, err := exec.LookPath("opencode") + return err == nil + }, + URL: "https://opencode.ai", + }, + }, + { + Name: "openclaw", + Runner: &Openclaw{}, + Aliases: []string{"clawdbot", "moltbot"}, + Description: "Personal AI with 100+ skills", + Install: IntegrationInstallSpec{ + CheckInstalled: func() bool { + if _, err := exec.LookPath("openclaw"); err == nil { + return true + } + if _, err := exec.LookPath("clawdbot"); err == nil { + return true + } + return false + }, + EnsureInstalled: func() error { + _, err := ensureOpenclawInstalled() + return err + }, + URL: "https://docs.openclaw.ai", + }, + }, + { + Name: "pi", + Runner: &Pi{}, + Description: "Minimal AI agent toolkit with plugin support", + Install: IntegrationInstallSpec{ + CheckInstalled: func() bool { + _, err := exec.LookPath("pi") + return err == nil + }, + Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent"}, + }, + }, +} + +var integrationSpecsByName map[string]*IntegrationSpec + +func init() { + rebuildIntegrationSpecIndexes() +} + +func hyperlink(url, text string) string { + return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text) +} + +func rebuildIntegrationSpecIndexes() { + integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs)) + + canonical := make(map[string]bool, len(integrationSpecs)) + for _, spec := range integrationSpecs { + key := strings.ToLower(spec.Name) + if key == "" { + panic("launch: integration spec missing name") + } + if canonical[key] { + panic(fmt.Sprintf("launch: duplicate integration name %q", key)) + } + canonical[key] = true + integrationSpecsByName[key] = spec + } + + seenAliases := make(map[string]string) + for _, spec := range integrationSpecs { + for _, alias := range spec.Aliases { + key := strings.ToLower(alias) + if key == "" { + panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name)) + } + if canonical[key] { + panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key)) + } + if owner, exists := seenAliases[key]; exists { + panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name)) + } + seenAliases[key] = spec.Name + integrationSpecsByName[key] = spec + } + } + + orderSeen := make(map[string]bool, len(launcherIntegrationOrder)) + for _, name := range launcherIntegrationOrder { + key := strings.ToLower(name) + if orderSeen[key] { + panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key)) + } + orderSeen[key] = true + + spec, ok := integrationSpecsByName[key] + if !ok { + panic(fmt.Sprintf("launch: unknown launcher order entry %q", key)) + } + if spec.Name != key { + panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key)) + } + if spec.Hidden { + panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key)) + } + } +} + +// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec. +func LookupIntegrationSpec(name string) (*IntegrationSpec, error) { + spec, ok := integrationSpecsByName[strings.ToLower(name)] + if !ok { + return nil, fmt.Errorf("unknown integration: %s", name) + } + return spec, nil +} + +// LookupIntegration resolves a registry name to the canonical key and runner. +func LookupIntegration(name string) (string, Runner, error) { + spec, err := LookupIntegrationSpec(name) + if err != nil { + return "", nil, err + } + return spec.Name, spec.Runner, nil +} + +// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs. +func ListVisibleIntegrationSpecs() []IntegrationSpec { + visible := make([]IntegrationSpec, 0, len(integrationSpecs)) + for _, spec := range integrationSpecs { + if spec.Hidden { + continue + } + visible = append(visible, *spec) + } + + orderRank := make(map[string]int, len(launcherIntegrationOrder)) + for i, name := range launcherIntegrationOrder { + orderRank[name] = i + 1 + } + + slices.SortFunc(visible, func(a, b IntegrationSpec) int { + aRank, bRank := orderRank[a.Name], orderRank[b.Name] + if aRank > 0 && bRank > 0 { + return aRank - bRank + } + if aRank > 0 { + return 1 + } + if bRank > 0 { + return -1 + } + return strings.Compare(a.Name, b.Name) + }) + + return visible +} + +// ListIntegrationInfos returns the registered integrations in launcher display order. +func ListIntegrationInfos() []IntegrationInfo { + visible := ListVisibleIntegrationSpecs() + infos := make([]IntegrationInfo, 0, len(visible)) + for _, spec := range visible { + infos = append(infos, IntegrationInfo{ + Name: spec.Name, + DisplayName: spec.Runner.String(), + Description: spec.Description, + }) + } + return infos +} + +// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs. +func IntegrationSelectionItems() ([]ModelItem, error) { + visible := ListVisibleIntegrationSpecs() + if len(visible) == 0 { + return nil, fmt.Errorf("no integrations available") + } + + items := make([]ModelItem, 0, len(visible)) + for _, spec := range visible { + description := spec.Runner.String() + if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 { + description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0]) + } + items = append(items, ModelItem{Name: spec.Name, Description: description}) + } + return items, nil +} + +// IsIntegrationInstalled checks if an integration binary is installed. +func IsIntegrationInstalled(name string) bool { + integration, err := integrationFor(name) + if err != nil { + fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name) + return false + } + return integration.installed +} + +// integration is resolved registry metadata used by launcher state and install checks. +// It combines immutable registry spec data with computed runtime traits. +type integration struct { + spec *IntegrationSpec + installed bool + autoInstallable bool + editor bool + installHint string +} + +// integrationFor resolves an integration name into the canonical spec plus +// derived launcher/install traits used across registry and launch flows. +func integrationFor(name string) (integration, error) { + spec, err := LookupIntegrationSpec(name) + if err != nil { + return integration{}, err + } + + installed := true + if spec.Install.CheckInstalled != nil { + installed = spec.Install.CheckInstalled() + } + + _, editor := spec.Runner.(Editor) + hint := "" + if spec.Install.URL != "" { + hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL) + } else if len(spec.Install.Command) > 0 { + hint = "Install with: " + strings.Join(spec.Install.Command, " ") + } + + return integration{ + spec: spec, + installed: installed, + autoInstallable: spec.Install.EnsureInstalled != nil, + editor: editor, + installHint: hint, + }, nil +} + +// EnsureIntegrationInstalled installs auto-installable integrations when missing. +func EnsureIntegrationInstalled(name string, runner Runner) error { + integration, err := integrationFor(name) + if err != nil { + return fmt.Errorf("%s is not installed", runner) + } + + if integration.installed { + return nil + } + if integration.autoInstallable { + return integration.spec.Install.EnsureInstalled() + } + + switch { + case integration.spec.Install.URL != "": + return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL) + case len(integration.spec.Install.Command) > 0: + return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " ")) + default: + return fmt.Errorf("%s is not installed", runner) + } +} diff --git a/cmd/launch/registry_test_helpers_test.go b/cmd/launch/registry_test_helpers_test.go new file mode 100644 index 000000000..909b31428 --- /dev/null +++ b/cmd/launch/registry_test_helpers_test.go @@ -0,0 +1,21 @@ +package launch + +import "strings" + +// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function. +func OverrideIntegration(name string, runner Runner) func() { + spec, err := LookupIntegrationSpec(name) + if err != nil { + key := strings.ToLower(name) + integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner} + return func() { + delete(integrationSpecsByName, key) + } + } + + original := spec.Runner + spec.Runner = runner + return func() { + spec.Runner = original + } +} diff --git a/cmd/launch/runner_exec_only_test.go b/cmd/launch/runner_exec_only_test.go new file mode 100644 index 000000000..97dbab9c9 --- /dev/null +++ b/cmd/launch/runner_exec_only_test.go @@ -0,0 +1,68 @@ +package launch + +import ( + "os" + "path/filepath" + "testing" +) + +func TestEditorRunsDoNotRewriteConfig(t *testing.T) { + tests := []struct { + name string + binary string + runner Runner + checkPath func(home string) string + }{ + { + name: "droid", + binary: "droid", + runner: &Droid{}, + checkPath: func(home string) string { + return filepath.Join(home, ".factory", "settings.json") + }, + }, + { + name: "opencode", + binary: "opencode", + runner: &OpenCode{}, + checkPath: func(home string) string { + return filepath.Join(home, ".config", "opencode", "opencode.json") + }, + }, + { + name: "cline", + binary: "cline", + runner: &Cline{}, + checkPath: func(home string) string { + return filepath.Join(home, ".cline", "data", "globalState.json") + }, + }, + { + name: "pi", + binary: "pi", + runner: &Pi{}, + checkPath: func(home string) string { + return filepath.Join(home, ".pi", "agent", "models.json") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + binDir := t.TempDir() + writeFakeBinary(t, binDir, tt.binary) + t.Setenv("PATH", binDir) + + configPath := tt.checkPath(home) + if err := tt.runner.Run("llama3.2", nil); err != nil { + t.Fatalf("Run returned error: %v", err) + } + if _, err := os.Stat(configPath); !os.IsNotExist(err) { + t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err) + } + }) + } +} diff --git a/cmd/launch/selector_hooks.go b/cmd/launch/selector_hooks.go new file mode 100644 index 000000000..1204023bb --- /dev/null +++ b/cmd/launch/selector_hooks.go @@ -0,0 +1,111 @@ +package launch + +import ( + "errors" + "fmt" + "os" + + "golang.org/x/term" +) + +// ANSI escape sequences for terminal formatting. +const ( + ansiBold = "\033[1m" + ansiReset = "\033[0m" + ansiGray = "\033[37m" + ansiGreen = "\033[32m" + ansiYellow = "\033[33m" +) + +// ErrCancelled is returned when the user cancels a selection. +var ErrCancelled = errors.New("cancelled") + +// errCancelled is kept as an internal alias for existing call sites. +var errCancelled = ErrCancelled + +// DefaultConfirmPrompt provides a TUI-based confirmation prompt. +// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O. +var DefaultConfirmPrompt func(prompt string) (bool, error) + +// SingleSelector is a function type for single item selection. +// current is the name of the previously selected item to highlight; empty means no pre-selection. +type SingleSelector func(title string, items []ModelItem, current string) (string, error) + +// MultiSelector is a function type for multi item selection. +type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error) + +// DefaultSingleSelector is the default single-select implementation. +var DefaultSingleSelector SingleSelector + +// DefaultMultiSelector is the default multi-select implementation. +var DefaultMultiSelector MultiSelector + +// DefaultSignIn provides a TUI-based sign-in flow. +// When set, ensureAuth uses it instead of plain text prompts. +// Returns the signed-in username or an error. +var DefaultSignIn func(modelName, signInURL string) (string, error) + +type launchConfirmPolicy struct { + yes bool + requireYesMessage bool +} + +var currentLaunchConfirmPolicy launchConfirmPolicy + +func (p launchConfirmPolicy) chain(next launchConfirmPolicy) launchConfirmPolicy { + chained := launchConfirmPolicy{ + yes: p.yes || next.yes, + requireYesMessage: p.requireYesMessage || next.requireYesMessage, + } + if chained.yes { + chained.requireYesMessage = false + } + return chained +} + +func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() { + old := currentLaunchConfirmPolicy + currentLaunchConfirmPolicy = old.chain(policy) + return func() { + currentLaunchConfirmPolicy = old + } +} + +// ConfirmPrompt asks the user to confirm an action using the configured prompt hook. +func ConfirmPrompt(prompt string) (bool, error) { + if currentLaunchConfirmPolicy.yes { + return true, nil + } + if currentLaunchConfirmPolicy.requireYesMessage { + return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt) + } + + if DefaultConfirmPrompt != nil { + return DefaultConfirmPrompt(prompt) + } + + fd := int(os.Stdin.Fd()) + oldState, err := term.MakeRaw(fd) + if err != nil { + return false, err + } + defer term.Restore(fd, oldState) + + fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt) + + buf := make([]byte, 1) + for { + if _, err := os.Stdin.Read(buf); err != nil { + return false, err + } + + switch buf[0] { + case 'Y', 'y', 13: + fmt.Fprintf(os.Stderr, "yes\r\n") + return true, nil + case 'N', 'n', 27, 3: + fmt.Fprintf(os.Stderr, "no\r\n") + return false, nil + } + } +} diff --git a/cmd/launch/selector_test.go b/cmd/launch/selector_test.go new file mode 100644 index 000000000..92fa28460 --- /dev/null +++ b/cmd/launch/selector_test.go @@ -0,0 +1,76 @@ +package launch + +import ( + "strings" + "testing" +) + +func TestErrCancelled(t *testing.T) { + t.Run("NotNil", func(t *testing.T) { + if errCancelled == nil { + t.Error("errCancelled should not be nil") + } + }) + + t.Run("Message", func(t *testing.T) { + if errCancelled.Error() != "cancelled" { + t.Errorf("expected 'cancelled', got %q", errCancelled.Error()) + } + }) +} + +func TestWithLaunchConfirmPolicy_ChainsAndRestores(t *testing.T) { + oldPolicy := currentLaunchConfirmPolicy + oldHook := DefaultConfirmPrompt + t.Cleanup(func() { + currentLaunchConfirmPolicy = oldPolicy + DefaultConfirmPrompt = oldHook + }) + + currentLaunchConfirmPolicy = launchConfirmPolicy{} + var hookCalls int + DefaultConfirmPrompt = func(prompt string) (bool, error) { + hookCalls++ + return true, nil + } + + restoreOuter := withLaunchConfirmPolicy(launchConfirmPolicy{requireYesMessage: true}) + restoreInner := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true}) + + ok, err := ConfirmPrompt("test prompt") + if err != nil { + t.Fatalf("expected --yes policy to allow prompt, got error: %v", err) + } + if !ok { + t.Fatal("expected --yes policy to auto-accept prompt") + } + if hookCalls != 0 { + t.Fatalf("expected --yes to skip hook, got %d hook calls", hookCalls) + } + + restoreInner() + + _, err = ConfirmPrompt("test prompt") + if err == nil { + t.Fatal("expected requireYesMessage policy to block prompt") + } + if !strings.Contains(err.Error(), "re-run with --yes") { + t.Fatalf("expected actionable --yes error, got: %v", err) + } + if hookCalls != 0 { + t.Fatalf("expected blocking policy to skip hook, got %d hook calls", hookCalls) + } + + restoreOuter() + + ok, err = ConfirmPrompt("test prompt") + if err != nil { + t.Fatalf("expected restored default behavior to use hook, got error: %v", err) + } + if !ok { + t.Fatal("expected hook to return true") + } + if hookCalls != 1 { + t.Fatalf("expected one hook call after restore, got %d", hookCalls) + } +} diff --git a/cmd/launch/test_config_helpers_test.go b/cmd/launch/test_config_helpers_test.go new file mode 100644 index 000000000..e73a51932 --- /dev/null +++ b/cmd/launch/test_config_helpers_test.go @@ -0,0 +1,82 @@ +package launch + +import ( + "strings" + "testing" + + "github.com/ollama/ollama/cmd/config" +) + +var ( + integrations map[string]Runner + integrationAliases map[string]bool + integrationOrder = launcherIntegrationOrder +) + +func init() { + integrations = buildTestIntegrations() + integrationAliases = buildTestIntegrationAliases() +} + +func buildTestIntegrations() map[string]Runner { + result := make(map[string]Runner, len(integrationSpecsByName)) + for name, spec := range integrationSpecsByName { + result[strings.ToLower(name)] = spec.Runner + } + return result +} + +func buildTestIntegrationAliases() map[string]bool { + result := make(map[string]bool) + for _, spec := range integrationSpecs { + for _, alias := range spec.Aliases { + result[strings.ToLower(alias)] = true + } + } + return result +} + +func setTestHome(t *testing.T, dir string) { + t.Helper() + setLaunchTestHome(t, dir) +} + +func SaveIntegration(appName string, models []string) error { + return config.SaveIntegration(appName, models) +} + +func LoadIntegration(appName string) (*config.IntegrationConfig, error) { + return config.LoadIntegration(appName) +} + +func SaveAliases(appName string, aliases map[string]string) error { + return config.SaveAliases(appName, aliases) +} + +func LastModel() string { + return config.LastModel() +} + +func SetLastModel(model string) error { + return config.SetLastModel(model) +} + +func LastSelection() string { + return config.LastSelection() +} + +func SetLastSelection(selection string) error { + return config.SetLastSelection(selection) +} + +func IntegrationModel(appName string) string { + return config.IntegrationModel(appName) +} + +func IntegrationModels(appName string) []string { + return config.IntegrationModels(appName) +} + +func integrationOnboarded(appName string) error { + return config.MarkIntegrationOnboarded(appName) +} diff --git a/cmd/tui/selector.go b/cmd/tui/selector.go index 7bf8180be..c863f089b 100644 --- a/cmd/tui/selector.go +++ b/cmd/tui/selector.go @@ -7,7 +7,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/ollama/ollama/cmd/config" + "github.com/ollama/ollama/cmd/launch" ) var ( @@ -64,8 +64,8 @@ type SelectItem struct { Recommended bool } -// ConvertItems converts config.ModelItem slice to SelectItem slice. -func ConvertItems(items []config.ModelItem) []SelectItem { +// ConvertItems converts launch.ModelItem slice to SelectItem slice. +func ConvertItems(items []launch.ModelItem) []SelectItem { out := make([]SelectItem, len(items)) for i, item := range items { out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended} @@ -101,6 +101,16 @@ type selectorModel struct { width int } +func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel { + m := selectorModel{ + title: title, + items: items, + cursor: cursorForCurrent(items, current), + } + m.updateScroll(m.otherStart()) + return m +} + func (m selectorModel) filteredItems() []SelectItem { if m.filter == "" { return m.items @@ -382,11 +392,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err return "", fmt.Errorf("no items to select from") } - m := selectorModel{ - title: title, - items: items, - cursor: cursorForCurrent(items, current), - } + m := selectorModelWithCurrent(title, items, current) p := tea.NewProgram(m) finalModel, err := p.Run() diff --git a/cmd/tui/selector_test.go b/cmd/tui/selector_test.go index fa8ff4dc4..33e12a02e 100644 --- a/cmd/tui/selector_test.go +++ b/cmd/tui/selector_test.go @@ -216,6 +216,22 @@ func TestUpdateScroll(t *testing.T) { } } +func TestSelectorModelWithCurrent_ScrollsToCurrentInMoreSection(t *testing.T) { + m := selectorModelWithCurrent("Pick:", mixedItems(), "other-10") + + if m.cursor != 11 { + t.Fatalf("cursor = %d, want 11", m.cursor) + } + if m.scrollOffset == 0 { + t.Fatal("scrollOffset should move to reveal current item in More section") + } + + content := m.renderContent() + if !strings.Contains(content, "▸ other-10") { + t.Fatalf("expected current item to be visible and highlighted\n%s", content) + } +} + func TestRenderContent_SectionHeaders(t *testing.T) { m := selectorModel{ title: "Pick:", diff --git a/cmd/tui/signin.go b/cmd/tui/signin.go index 118dbdf1c..3667499b2 100644 --- a/cmd/tui/signin.go +++ b/cmd/tui/signin.go @@ -1,15 +1,24 @@ package tui import ( + "context" "fmt" "strings" "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/ollama/ollama/cmd/config" + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/cmd/launch" ) +type signInTickMsg struct{} + +type signInCheckMsg struct { + signedIn bool + userName string +} + type signInModel struct { modelName string signInURL string @@ -104,9 +113,21 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string { return lipgloss.NewStyle().PaddingLeft(2).Render(s.String()) } +func checkSignIn() tea.Msg { + client, err := api.ClientFromEnvironment() + if err != nil { + return signInCheckMsg{signedIn: false} + } + user, err := client.Whoami(context.Background()) + if err == nil && user != nil && user.Name != "" { + return signInCheckMsg{signedIn: true, userName: user.Name} + } + return signInCheckMsg{signedIn: false} +} + // RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels. func RunSignIn(modelName, signInURL string) (string, error) { - config.OpenBrowser(signInURL) + launch.OpenBrowser(signInURL) m := signInModel{ modelName: modelName, diff --git a/cmd/tui/tui.go b/cmd/tui/tui.go index 5803d98fa..9f1ecfcb6 100644 --- a/cmd/tui/tui.go +++ b/cmd/tui/tui.go @@ -1,17 +1,11 @@ package tui import ( - "context" - "errors" "fmt" - "strings" - "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/cmd/config" - "github.com/ollama/ollama/internal/modelref" + "github.com/ollama/ollama/cmd/launch" "github.com/ollama/ollama/version" ) @@ -46,7 +40,7 @@ var ( type menuItem struct { title string description string - integration string // integration name for loading model config, empty if not an integration + integration string isRunModel bool isOthers bool } @@ -58,18 +52,12 @@ var mainMenuItems = []menuItem{ isRunModel: true, }, { - title: "Launch Claude Code", - description: "Agentic coding across large codebases", integration: "claude", }, { - title: "Launch Codex", - description: "OpenAI's open-source coding agent", integration: "codex", }, { - title: "Launch OpenClaw", - description: "Personal AI with 100+ skills", integration: "openclaw", }, } @@ -80,283 +68,106 @@ var othersMenuItem = menuItem{ isOthers: true, } -// getOtherIntegrations dynamically builds the "Others" list from the integration -// registry, excluding any integrations already present in the pinned mainMenuItems. -func getOtherIntegrations() []menuItem { - pinned := map[string]bool{ - "run": true, // not an integration but in the pinned list +type model struct { + state *launch.LauncherState + items []menuItem + cursor int + showOthers bool + width int + quitting bool + selected bool + action TUIAction +} + +func newModel(state *launch.LauncherState) model { + m := model{ + state: state, } + m.showOthers = shouldExpandOthers(state) + m.items = buildMenuItems(state, m.showOthers) + m.cursor = initialCursor(state, m.items) + return m +} + +func shouldExpandOthers(state *launch.LauncherState) bool { + if state == nil { + return false + } + for _, item := range otherIntegrationItems(state) { + if item.integration == state.LastSelection { + return true + } + } + return false +} + +func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem { + items := make([]menuItem, 0, len(mainMenuItems)+1) for _, item := range mainMenuItems { - if item.integration != "" { - pinned[item.integration] = true + if item.integration == "" { + items = append(items, item) + continue + } + if integrationState, ok := state.Integrations[item.integration]; ok { + items = append(items, integrationMenuItem(integrationState)) } } - var others []menuItem - for _, info := range config.ListIntegrationInfos() { + if showOthers { + items = append(items, otherIntegrationItems(state)...) + } else { + items = append(items, othersMenuItem) + } + + return items +} + +func integrationMenuItem(state launch.LauncherIntegrationState) menuItem { + description := state.Description + if description == "" { + description = "Open " + state.DisplayName + " integration" + } + return menuItem{ + title: "Launch " + state.DisplayName, + description: description, + integration: state.Name, + } +} + +func otherIntegrationItems(state *launch.LauncherState) []menuItem { + pinned := map[string]bool{ + "claude": true, + "codex": true, + "openclaw": true, + } + + var items []menuItem + for _, info := range launch.ListIntegrationInfos() { if pinned[info.Name] { continue } - desc := info.Description - if desc == "" { - desc = "Open " + info.DisplayName + " integration" - } - others = append(others, menuItem{ - title: "Launch " + info.DisplayName, - description: desc, - integration: info.Name, - }) - } - return others -} - -type model struct { - items []menuItem - cursor int - quitting bool - selected bool - changeModel bool - changeModels []string // multi-select result for Editor integrations - showOthers bool - availableModels map[string]bool - err error - - showingModal bool - modalSelector selectorModel - modalItems []SelectItem - - showingMultiModal bool - multiModalSelector multiSelectorModel - - showingSignIn bool - signInURL string - signInModel string - signInSpinner int - signInFromModal bool // true if sign-in was triggered from modal (not main menu) - - width int // terminal width from WindowSizeMsg - statusMsg string // temporary status message shown near help text -} - -type signInTickMsg struct{} - -type signInCheckMsg struct { - signedIn bool - userName string -} - -type clearStatusMsg struct{} - -func (m *model) modelExists(name string) bool { - if name == "" { - return false - } - if modelref.HasExplicitCloudSource(name) { - return true - } - if m.availableModels == nil { - return false - } - if m.availableModels[name] { - return true - } - // Check for prefix match (e.g., "llama2" matches "llama2:latest") - for modelName := range m.availableModels { - if strings.HasPrefix(modelName, name+":") { - return true - } - } - return false -} - -func (m *model) buildModalItems() []SelectItem { - modelItems, _ := config.GetModelItems(context.Background()) - return ReorderItems(ConvertItems(modelItems)) -} - -func (m *model) openModelModal(currentModel string) { - m.modalItems = m.buildModalItems() - cursor := 0 - if currentModel != "" { - for i, item := range m.modalItems { - if item.Name == currentModel || strings.HasPrefix(item.Name, currentModel+":") || strings.HasPrefix(currentModel, item.Name+":") { - cursor = i - break - } - } - } - m.modalSelector = selectorModel{ - title: "Select model:", - items: m.modalItems, - cursor: cursor, - helpText: "↑/↓ navigate • enter select • ← back", - } - m.modalSelector.updateScroll(m.modalSelector.otherStart()) - m.showingModal = true -} - -func (m *model) openMultiModelModal(integration string) { - items := m.buildModalItems() - var preChecked []string - if models := config.IntegrationModels(integration); len(models) > 0 { - preChecked = models - } - m.multiModalSelector = newMultiSelectorModel("Select models:", items, preChecked) - // Set cursor to the first pre-checked (last used) model - if len(preChecked) > 0 { - for i, item := range items { - if item.Name == preChecked[0] { - m.multiModalSelector.cursor = i - m.multiModalSelector.updateScroll(m.multiModalSelector.otherStart()) - break - } - } - } - m.showingMultiModal = true -} - -func isCloudModel(name string) bool { - return modelref.HasExplicitCloudSource(name) -} - -func cloudStatusDisabled(client *api.Client) bool { - status, err := client.CloudStatusExperimental(context.Background()) - if err != nil { - return false - } - return status.Cloud.Disabled -} - -func cloudModelDisabled(name string) bool { - if !isCloudModel(name) { - return false - } - client, err := api.ClientFromEnvironment() - if err != nil { - return false - } - return cloudStatusDisabled(client) -} - -// checkCloudSignIn checks if a cloud model needs sign-in. -// Returns a command to start sign-in if needed, or nil if already signed in. -func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd { - if modelName == "" || !isCloudModel(modelName) { - return nil - } - client, err := api.ClientFromEnvironment() - if err != nil { - return nil - } - if cloudStatusDisabled(client) { - return nil - } - user, err := client.Whoami(context.Background()) - if err == nil && user != nil && user.Name != "" { - return nil - } - var aErr api.AuthorizationError - if errors.As(err, &aErr) && aErr.SigninURL != "" { - return m.startSignIn(modelName, aErr.SigninURL, fromModal) - } - return nil -} - -// startSignIn initiates the sign-in flow for a cloud model. -// fromModal indicates if this was triggered from the model picker modal. -func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd { - m.showingModal = false - m.showingSignIn = true - m.signInURL = signInURL - m.signInModel = modelName - m.signInSpinner = 0 - m.signInFromModal = fromModal - - config.OpenBrowser(signInURL) - - return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg { - return signInTickMsg{} - }) -} - -func checkSignIn() tea.Msg { - client, err := api.ClientFromEnvironment() - if err != nil { - return signInCheckMsg{signedIn: false} - } - user, err := client.Whoami(context.Background()) - if err == nil && user != nil && user.Name != "" { - return signInCheckMsg{signedIn: true, userName: user.Name} - } - return signInCheckMsg{signedIn: false} -} - -func (m *model) loadAvailableModels() { - m.availableModels = make(map[string]bool) - client, err := api.ClientFromEnvironment() - if err != nil { - return - } - models, err := client.List(context.Background()) - if err != nil { - return - } - cloudDisabled := cloudStatusDisabled(client) - for _, mdl := range models.Models { - if cloudDisabled && mdl.RemoteModel != "" { + integrationState, ok := state.Integrations[info.Name] + if !ok { continue } - m.availableModels[mdl.Name] = true + items = append(items, integrationMenuItem(integrationState)) } + return items } -func (m *model) buildItems() { - others := getOtherIntegrations() - m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others)) - m.items = append(m.items, mainMenuItems...) - - if m.showOthers { - m.items = append(m.items, others...) - } else { - m.items = append(m.items, othersMenuItem) +func initialCursor(state *launch.LauncherState, items []menuItem) int { + if state == nil || state.LastSelection == "" { + return 0 } -} - -func isOthersIntegration(name string) bool { - for _, item := range getOtherIntegrations() { - if item.integration == name { - return true + for i, item := range items { + if state.LastSelection == "run" && item.isRunModel { + return i + } + if item.integration == state.LastSelection { + return i } } - return false -} - -func initialModel() model { - m := model{ - cursor: 0, - } - m.loadAvailableModels() - - lastSelection := config.LastSelection() - if isOthersIntegration(lastSelection) { - m.showOthers = true - } - - m.buildItems() - - if lastSelection != "" { - for i, item := range m.items { - if lastSelection == "run" && item.isRunModel { - m.cursor = i - break - } else if item.integration == lastSelection { - m.cursor = i - break - } - } - } - - return m + return 0 } func (m model) Init() tea.Cmd { @@ -364,143 +175,11 @@ func (m model) Init() tea.Cmd { } func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - if wmsg, ok := msg.(tea.WindowSizeMsg); ok { - wasSet := m.width > 0 - m.width = wmsg.Width - if wasSet { - return m, tea.EnterAltScreen - } - return m, nil - } - - if _, ok := msg.(clearStatusMsg); ok { - m.statusMsg = "" - return m, nil - } - - if m.showingSignIn { - switch msg := msg.(type) { - case tea.KeyMsg: - switch msg.Type { - case tea.KeyCtrlC, tea.KeyEsc: - m.showingSignIn = false - if m.signInFromModal { - m.showingModal = true - } - return m, nil - } - - case signInTickMsg: - m.signInSpinner++ - // Check sign-in status every 5th tick (~1 second) - if m.signInSpinner%5 == 0 { - return m, tea.Batch( - tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg { - return signInTickMsg{} - }), - checkSignIn, - ) - } - return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg { - return signInTickMsg{} - }) - - case signInCheckMsg: - if msg.signedIn { - if m.signInFromModal { - m.modalSelector.selected = m.signInModel - m.changeModel = true - } else { - m.selected = true - } - m.quitting = true - return m, tea.Quit - } - } - return m, nil - } - - if m.showingMultiModal { - switch msg := msg.(type) { - case tea.KeyMsg: - if msg.Type == tea.KeyLeft { - m.showingMultiModal = false - return m, nil - } - updated, cmd := m.multiModalSelector.Update(msg) - m.multiModalSelector = updated.(multiSelectorModel) - - if m.multiModalSelector.cancelled { - m.showingMultiModal = false - return m, nil - } - if m.multiModalSelector.confirmed { - var selected []string - if m.multiModalSelector.singleAdd != "" { - // Single-add mode: prepend picked model, keep existing deduped - selected = []string{m.multiModalSelector.singleAdd} - for _, name := range config.IntegrationModels(m.items[m.cursor].integration) { - if name != m.multiModalSelector.singleAdd { - selected = append(selected, name) - } - } - } else { - // Last checked is default (first in result) - co := m.multiModalSelector.checkOrder - last := co[len(co)-1] - selected = []string{m.multiModalSelector.items[last].Name} - for _, idx := range co { - if idx != last { - selected = append(selected, m.multiModalSelector.items[idx].Name) - } - } - } - if len(selected) > 0 { - m.changeModels = selected - m.changeModel = true - m.quitting = true - return m, tea.Quit - } - m.multiModalSelector.confirmed = false - return m, nil - } - return m, cmd - } - return m, nil - } - - if m.showingModal { - switch msg := msg.(type) { - case tea.KeyMsg: - switch msg.Type { - case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft: - m.showingModal = false - return m, nil - - case tea.KeyEnter: - filtered := m.modalSelector.filteredItems() - if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) { - m.modalSelector.selected = filtered[m.modalSelector.cursor].Name - } - if m.modalSelector.selected != "" { - if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil { - return m, cmd - } - m.changeModel = true - m.quitting = true - return m, tea.Quit - } - return m, nil - - default: - // Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel - m.modalSelector.updateNavigation(msg) - } - } - return m, nil - } - switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + return m, nil + case tea.KeyMsg: switch msg.String() { case "ctrl+c", "q", "esc": @@ -511,162 +190,78 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.cursor > 0 { m.cursor-- } - // Auto-collapse "Others" when cursor moves back into pinned items if m.showOthers && m.cursor < len(mainMenuItems) { m.showOthers = false - m.buildItems() + m.items = buildMenuItems(m.state, false) + m.cursor = min(m.cursor, len(m.items)-1) } + return m, nil case "down", "j": if m.cursor < len(m.items)-1 { m.cursor++ } - // Auto-expand "Others..." when cursor lands on it if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers { m.showOthers = true - m.buildItems() - // cursor now points at the first "other" integration + m.items = buildMenuItems(m.state, true) } + return m, nil case "enter", " ": - item := m.items[m.cursor] - - if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) { - return m, nil + if m.selectableItem(m.items[m.cursor]) { + m.selected = true + m.action = actionForMenuItem(m.items[m.cursor], false) + m.quitting = true + return m, tea.Quit } - - var configuredModel string - if item.isRunModel { - configuredModel = config.LastModel() - } else if item.integration != "" { - configuredModel = config.IntegrationModel(item.integration) - } - if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil { - return m, cmd - } - - if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) { - if item.integration != "" && config.IsEditorIntegration(item.integration) { - m.openMultiModelModal(item.integration) - } else { - m.openModelModal(configuredModel) - } - return m, nil - } - - m.selected = true - m.quitting = true - return m, tea.Quit + return m, nil case "right", "l": item := m.items[m.cursor] - if item.integration != "" || item.isRunModel { - if item.integration != "" && !config.IsIntegrationInstalled(item.integration) { - if config.AutoInstallable(item.integration) { - // Auto-installable: select to trigger install flow - m.selected = true - m.quitting = true - return m, tea.Quit - } - return m, nil - } - if item.integration != "" && config.IsEditorIntegration(item.integration) { - m.openMultiModelModal(item.integration) - } else { - var currentModel string - if item.isRunModel { - currentModel = config.LastModel() - } else if item.integration != "" { - currentModel = config.IntegrationModel(item.integration) - } - m.openModelModal(currentModel) - } + if item.isRunModel || m.changeableItem(item) { + m.selected = true + m.action = actionForMenuItem(item, true) + m.quitting = true + return m, tea.Quit } + return m, nil } } return m, nil } +func (m model) selectableItem(item menuItem) bool { + if item.isRunModel { + return true + } + if item.integration == "" || item.isOthers { + return false + } + state, ok := m.state.Integrations[item.integration] + return ok && state.Selectable +} + +func (m model) changeableItem(item menuItem) bool { + if item.integration == "" || item.isOthers { + return false + } + state, ok := m.state.Integrations[item.integration] + return ok && state.Changeable +} + func (m model) View() string { if m.quitting { return "" } - if m.showingSignIn { - return m.renderSignInDialog() - } - - if m.showingMultiModal { - return m.multiModalSelector.View() - } - - if m.showingModal { - return m.renderModal() - } - s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n" for i, item := range m.items { - cursor := "" - style := menuItemStyle - isInstalled := true - - if item.integration != "" { - isInstalled = config.IsIntegrationInstalled(item.integration) - } - - if m.cursor == i { - cursor = "▸ " - if isInstalled { - style = menuSelectedItemStyle - } else { - style = greyedSelectedStyle - } - } else if !isInstalled && item.integration != "" { - style = greyedStyle - } - - title := item.title - var modelSuffix string - if item.integration != "" { - if !isInstalled { - if config.AutoInstallable(item.integration) { - title += " " + notInstalledStyle.Render("(install)") - } else { - title += " " + notInstalledStyle.Render("(not installed)") - } - } else if m.cursor == i { - if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) { - modelSuffix = " " + modelStyle.Render("("+mdl+")") - } - } - } else if item.isRunModel && m.cursor == i { - if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) { - modelSuffix = " " + modelStyle.Render("("+mdl+")") - } - } - - s += style.Render(cursor+title) + modelSuffix + "\n" - - desc := item.description - if !isInstalled && item.integration != "" && m.cursor == i { - if config.AutoInstallable(item.integration) { - desc = "Press enter to install" - } else if hint := config.IntegrationInstallHint(item.integration); hint != "" { - desc = hint - } else { - desc = "not installed" - } - } - s += menuDescStyle.Render(desc) + "\n\n" + s += m.renderMenuItem(i, item) } - if m.statusMsg != "" { - s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n" - } - - s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit") + s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → configure • esc quit") if m.width > 0 { return lipgloss.NewStyle().MaxWidth(m.width).Render(s) @@ -674,80 +269,125 @@ func (m model) View() string { return s } -func (m model) renderModal() string { - modalStyle := lipgloss.NewStyle(). - PaddingBottom(1). - PaddingRight(2) +func (m model) renderMenuItem(index int, item menuItem) string { + cursor := "" + style := menuItemStyle + title := item.title + description := item.description + modelSuffix := "" - s := modalStyle.Render(m.modalSelector.renderContent()) - if m.width > 0 { - return lipgloss.NewStyle().MaxWidth(m.width).Render(s) - } - return s -} - -func (m model) renderSignInDialog() string { - return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width) -} - -type Selection int - -const ( - SelectionNone Selection = iota - SelectionRunModel - SelectionChangeRunModel - SelectionIntegration // Generic integration selection - SelectionChangeIntegration // Generic change model for integration -) - -type Result struct { - Selection Selection - Integration string // integration name if applicable - Model string // model name if selected from single-select modal - Models []string // models selected from multi-select modal (Editor integrations) -} - -func Run() (Result, error) { - m := initialModel() - p := tea.NewProgram(m) - - finalModel, err := p.Run() - if err != nil { - return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err) - } - - fm := finalModel.(model) - if fm.err != nil { - return Result{Selection: SelectionNone}, fm.err - } - - if !fm.selected && !fm.changeModel { - return Result{Selection: SelectionNone}, nil - } - - item := fm.items[fm.cursor] - - if fm.changeModel { - if item.isRunModel { - return Result{ - Selection: SelectionChangeRunModel, - Model: fm.modalSelector.selected, - }, nil - } - return Result{ - Selection: SelectionChangeIntegration, - Integration: item.integration, - Model: fm.modalSelector.selected, - Models: fm.changeModels, - }, nil + if m.cursor == index { + cursor = "▸ " } if item.isRunModel { - return Result{Selection: SelectionRunModel}, nil + if m.cursor == index && m.state.RunModel != "" { + modelSuffix = " " + modelStyle.Render("("+m.state.RunModel+")") + } + if m.cursor == index { + style = menuSelectedItemStyle + } + } else if item.isOthers { + if m.cursor == index { + style = menuSelectedItemStyle + } + } else { + integrationState := m.state.Integrations[item.integration] + if !integrationState.Selectable { + if m.cursor == index { + style = greyedSelectedStyle + } else { + style = greyedStyle + } + } else if m.cursor == index { + style = menuSelectedItemStyle + } + + if m.cursor == index && integrationState.CurrentModel != "" { + modelSuffix = " " + modelStyle.Render("("+integrationState.CurrentModel+")") + } + + if !integrationState.Installed { + if integrationState.AutoInstallable { + title += " " + notInstalledStyle.Render("(install)") + } else { + title += " " + notInstalledStyle.Render("(not installed)") + } + if m.cursor == index { + if integrationState.AutoInstallable { + description = "Press enter to install" + } else if integrationState.InstallHint != "" { + description = integrationState.InstallHint + } else { + description = "not installed" + } + } + } } - return Result{ - Selection: SelectionIntegration, - Integration: item.integration, - }, nil + return style.Render(cursor+title) + modelSuffix + "\n" + menuDescStyle.Render(description) + "\n\n" +} + +type TUIActionKind int + +const ( + TUIActionNone TUIActionKind = iota + TUIActionRunModel + TUIActionLaunchIntegration +) + +type TUIAction struct { + Kind TUIActionKind + Integration string + ForceConfigure bool +} + +func (a TUIAction) LastSelection() string { + switch a.Kind { + case TUIActionRunModel: + return "run" + case TUIActionLaunchIntegration: + return a.Integration + default: + return "" + } +} + +func (a TUIAction) RunModelRequest() launch.RunModelRequest { + return launch.RunModelRequest{ForcePicker: a.ForceConfigure} +} + +func (a TUIAction) IntegrationLaunchRequest() launch.IntegrationLaunchRequest { + return launch.IntegrationLaunchRequest{ + Name: a.Integration, + ForceConfigure: a.ForceConfigure, + } +} + +func actionForMenuItem(item menuItem, forceConfigure bool) TUIAction { + switch { + case item.isRunModel: + return TUIAction{Kind: TUIActionRunModel, ForceConfigure: forceConfigure} + case item.integration != "": + return TUIAction{Kind: TUIActionLaunchIntegration, Integration: item.integration, ForceConfigure: forceConfigure} + default: + return TUIAction{Kind: TUIActionNone} + } +} + +func RunMenu(state *launch.LauncherState) (TUIAction, error) { + menu := newModel(state) + program := tea.NewProgram(menu) + + finalModel, err := program.Run() + if err != nil { + return TUIAction{Kind: TUIActionNone}, fmt.Errorf("error running TUI: %w", err) + } + + finalMenu := finalModel.(model) + if !finalMenu.selected { + return TUIAction{Kind: TUIActionNone}, nil + } + + return finalMenu.action, nil } diff --git a/cmd/tui/tui_test.go b/cmd/tui/tui_test.go new file mode 100644 index 000000000..730325387 --- /dev/null +++ b/cmd/tui/tui_test.go @@ -0,0 +1,178 @@ +package tui + +import ( + "strings" + "testing" + + tea "github.com/charmbracelet/bubbletea" + "github.com/ollama/ollama/cmd/launch" +) + +func launcherTestState() *launch.LauncherState { + return &launch.LauncherState{ + LastSelection: "run", + RunModel: "qwen3:8b", + Integrations: map[string]launch.LauncherIntegrationState{ + "claude": { + Name: "claude", + DisplayName: "Claude Code", + Description: "Anthropic's coding tool with subagents", + Selectable: true, + Changeable: true, + CurrentModel: "glm-5:cloud", + }, + "codex": { + Name: "codex", + DisplayName: "Codex", + Description: "OpenAI's open-source coding agent", + Selectable: true, + Changeable: true, + }, + "openclaw": { + Name: "openclaw", + DisplayName: "OpenClaw", + Description: "Personal AI with 100+ skills", + Selectable: true, + Changeable: true, + AutoInstallable: true, + }, + "droid": { + Name: "droid", + DisplayName: "Droid", + Description: "Factory's coding agent across terminal and IDEs", + Selectable: true, + Changeable: true, + }, + "pi": { + Name: "pi", + DisplayName: "Pi", + Description: "Minimal AI agent toolkit with plugin support", + Selectable: true, + Changeable: true, + }, + }, + } +} + +func TestMenuRendersPinnedItemsAndMore(t *testing.T) { + view := newModel(launcherTestState()).View() + for _, want := range []string{"Run a model", "Launch Claude Code", "Launch Codex", "Launch OpenClaw", "More..."} { + if !strings.Contains(view, want) { + t.Fatalf("expected menu view to contain %q\n%s", want, view) + } + } +} + +func TestMenuExpandsOthersFromLastSelection(t *testing.T) { + state := launcherTestState() + state.LastSelection = "pi" + + menu := newModel(state) + if !menu.showOthers { + t.Fatal("expected others section to expand when last selection is in the overflow list") + } + view := menu.View() + if !strings.Contains(view, "Launch Pi") { + t.Fatalf("expected expanded view to contain overflow integration\n%s", view) + } + if strings.Contains(view, "More...") { + t.Fatalf("expected expanded view to replace More... item\n%s", view) + } +} + +func TestMenuEnterOnRunSelectsRun(t *testing.T) { + menu := newModel(launcherTestState()) + updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter}) + got := updated.(model) + want := TUIAction{Kind: TUIActionRunModel} + if !got.selected || got.action != want { + t.Fatalf("expected enter on run to select run action, got selected=%v action=%v", got.selected, got.action) + } +} + +func TestMenuRightOnRunSelectsChangeRun(t *testing.T) { + menu := newModel(launcherTestState()) + updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight}) + got := updated.(model) + want := TUIAction{Kind: TUIActionRunModel, ForceConfigure: true} + if !got.selected || got.action != want { + t.Fatalf("expected right on run to select change-run action, got selected=%v action=%v", got.selected, got.action) + } +} + +func TestMenuEnterOnIntegrationSelectsLaunch(t *testing.T) { + menu := newModel(launcherTestState()) + menu.cursor = 1 + updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter}) + got := updated.(model) + want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude"} + if !got.selected || got.action != want { + t.Fatalf("expected enter on integration to launch, got selected=%v action=%v", got.selected, got.action) + } +} + +func TestMenuRightOnIntegrationSelectsConfigure(t *testing.T) { + menu := newModel(launcherTestState()) + menu.cursor = 1 + updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight}) + got := updated.(model) + want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true} + if !got.selected || got.action != want { + t.Fatalf("expected right on integration to configure, got selected=%v action=%v", got.selected, got.action) + } +} + +func TestMenuIgnoresDisabledActions(t *testing.T) { + state := launcherTestState() + claude := state.Integrations["claude"] + claude.Selectable = false + claude.Changeable = false + state.Integrations["claude"] = claude + + menu := newModel(state) + menu.cursor = 1 + + updatedEnter, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if updatedEnter.(model).selected { + t.Fatal("expected non-selectable integration to ignore enter") + } + + updatedRight, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight}) + if updatedRight.(model).selected { + t.Fatal("expected non-changeable integration to ignore right") + } +} + +func TestMenuShowsCurrentModelSuffixes(t *testing.T) { + menu := newModel(launcherTestState()) + runView := menu.View() + if !strings.Contains(runView, "(qwen3:8b)") { + t.Fatalf("expected run row to show current model suffix\n%s", runView) + } + + menu.cursor = 1 + integrationView := menu.View() + if !strings.Contains(integrationView, "(glm-5:cloud)") { + t.Fatalf("expected integration row to show current model suffix\n%s", integrationView) + } +} + +func TestMenuShowsInstallStatusAndHint(t *testing.T) { + state := launcherTestState() + codex := state.Integrations["codex"] + codex.Installed = false + codex.Selectable = false + codex.Changeable = false + codex.InstallHint = "Install from https://example.com/codex" + state.Integrations["codex"] = codex + + menu := newModel(state) + menu.cursor = 2 + view := menu.View() + if !strings.Contains(view, "(not installed)") { + t.Fatalf("expected not-installed marker\n%s", view) + } + if !strings.Contains(view, codex.InstallHint) { + t.Fatalf("expected install hint in description\n%s", view) + } +}