mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
cmd: refactor tui and launch (#14609)
This commit is contained in:
259
cmd/cmd.go
259
cmd/cmd.go
@@ -38,6 +38,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
@@ -58,36 +59,36 @@ import (
|
||||
|
||||
func init() {
|
||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
return "", launch.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
return nil, launch.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
userName, err := tui.RunSignIn(modelName, signInURL)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
return "", launch.ErrCancelled
|
||||
}
|
||||
return userName, err
|
||||
}
|
||||
|
||||
config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
launch.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
ok, err := tui.RunConfirm(prompt)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return false, config.ErrCancelled
|
||||
return false, launch.ErrCancelled
|
||||
}
|
||||
return ok, err
|
||||
}
|
||||
@@ -1912,6 +1913,24 @@ func ensureServerRunning(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
|
||||
opts := runOptions{
|
||||
Model: modelName,
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
|
||||
// and only validate auth/connectivity before interactive chat starts.
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
return fmt.Errorf("error loading model: %w", err)
|
||||
}
|
||||
if err := generateInteractive(cmd, opts); err != nil {
|
||||
return fmt.Errorf("error running model: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runInteractiveTUI runs the main interactive TUI menu.
|
||||
func runInteractiveTUI(cmd *cobra.Command) {
|
||||
// Ensure the server is running before showing the TUI
|
||||
@@ -1920,175 +1939,81 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
return
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
deps := launcherDeps{
|
||||
buildState: launch.BuildLauncherState,
|
||||
runMenu: tui.RunMenu,
|
||||
resolveRunModel: launch.ResolveRunModel,
|
||||
launchIntegration: launch.LaunchIntegration,
|
||||
runModel: launchInteractiveModel,
|
||||
}
|
||||
|
||||
for {
|
||||
result, err := tui.Run()
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
runModel := func(modelName string) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
_ = config.SetLastModel(modelName)
|
||||
opts := runOptions{
|
||||
Model: modelName,
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := generateInteractive(cmd, opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
|
||||
}
|
||||
}
|
||||
type launcherDeps struct {
|
||||
buildState func(context.Context) (*launch.LauncherState, error)
|
||||
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
|
||||
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
|
||||
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
|
||||
runModel func(*cobra.Command, string) error
|
||||
}
|
||||
|
||||
launchIntegration := func(name string) bool {
|
||||
if err := config.EnsureInstalled(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return true
|
||||
}
|
||||
// If not configured or model no longer exists, prompt for model selection
|
||||
configuredModel := config.IntegrationModel(name)
|
||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return false // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
|
||||
return true
|
||||
}
|
||||
}
|
||||
if err := config.LaunchIntegration(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) {
|
||||
state, err := deps.buildState(cmd.Context())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("build launcher state: %w", err)
|
||||
}
|
||||
|
||||
switch result.Selection {
|
||||
case tui.SelectionNone:
|
||||
// User quit
|
||||
return
|
||||
case tui.SelectionRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
||||
runModel(modelName)
|
||||
} else {
|
||||
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
runModel(modelName)
|
||||
}
|
||||
case tui.SelectionChangeRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
// Use model from modal if selected, otherwise show picker
|
||||
modelName := result.Model
|
||||
if modelName == "" {
|
||||
var err error
|
||||
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
runModel(modelName)
|
||||
case tui.SelectionIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
if !launchIntegration(result.Integration) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
case tui.SelectionChangeIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
if len(result.Models) > 0 {
|
||||
// Filter out cloud-disabled models
|
||||
var filtered []string
|
||||
for _, m := range result.Models {
|
||||
if !config.IsCloudModelDisabled(cmd.Context(), m) {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
continue
|
||||
}
|
||||
result.Models = filtered
|
||||
// Multi-select from modal (Editor integrations)
|
||||
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
} else if result.Model != "" {
|
||||
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
|
||||
continue
|
||||
}
|
||||
// Single-select from modal - save and launch
|
||||
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
} else {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegration(result.Integration); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
}
|
||||
action, err := deps.runMenu(state)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("run launcher menu: %w", err)
|
||||
}
|
||||
|
||||
return runLauncherAction(cmd, action, deps)
|
||||
}
|
||||
|
||||
func saveLauncherSelection(action tui.TUIAction) {
|
||||
// Best effort only: this affects menu recall, not launch correctness.
|
||||
_ = config.SetLastSelection(action.LastSelection())
|
||||
}
|
||||
|
||||
func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDeps) (bool, error) {
|
||||
switch action.Kind {
|
||||
case tui.TUIActionNone:
|
||||
return false, nil
|
||||
case tui.TUIActionRunModel:
|
||||
saveLauncherSelection(action)
|
||||
modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
|
||||
if errors.Is(err, launch.ErrCancelled) {
|
||||
return true, nil
|
||||
}
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("selecting model: %w", err)
|
||||
}
|
||||
if err := deps.runModel(cmd, modelName); err != nil {
|
||||
return true, err
|
||||
}
|
||||
return true, nil
|
||||
case tui.TUIActionLaunchIntegration:
|
||||
saveLauncherSelection(action)
|
||||
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
|
||||
if errors.Is(err, launch.ErrCancelled) {
|
||||
return true, nil
|
||||
}
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("launching %s: %w", action.Integration, err)
|
||||
}
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unknown launcher action: %d", action.Kind)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2358,7 +2283,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
233
cmd/cmd_launcher_test.go
Normal file
233
cmd/cmd_launcher_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
)
|
||||
|
||||
func setCmdTestHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
|
||||
t.Helper()
|
||||
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
t.Fatalf("did not expect run-model resolution: %+v", req)
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
|
||||
t.Helper()
|
||||
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
t.Fatalf("did not expect integration launch: %+v", req)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
|
||||
t.Helper()
|
||||
return func(cmd *cobra.Command, model string) error {
|
||||
t.Fatalf("did not expect chat launch: %s", model)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
action tui.TUIAction
|
||||
wantForce bool
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "enter uses saved model flow",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
|
||||
wantModel: "qwen3:8b",
|
||||
},
|
||||
{
|
||||
name: "right forces picker",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
|
||||
wantForce: true,
|
||||
wantModel: "glm-5:cloud",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
var menuCalls int
|
||||
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
menuCalls++
|
||||
if menuCalls == 1 {
|
||||
return tt.action, nil
|
||||
}
|
||||
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||
}
|
||||
|
||||
var gotReq launch.RunModelRequest
|
||||
var launched string
|
||||
deps := launcherDeps{
|
||||
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||
return &launch.LauncherState{}, nil
|
||||
},
|
||||
runMenu: runMenu,
|
||||
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
gotReq = req
|
||||
return tt.wantModel, nil
|
||||
},
|
||||
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||
runModel: func(cmd *cobra.Command, model string) error {
|
||||
launched = model
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
for {
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected step error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if gotReq.ForcePicker != tt.wantForce {
|
||||
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
|
||||
}
|
||||
if launched != tt.wantModel {
|
||||
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
|
||||
}
|
||||
if got := config.LastSelection(); got != "run" {
|
||||
t.Fatalf("expected last selection to be run, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
action tui.TUIAction
|
||||
wantForce bool
|
||||
}{
|
||||
{
|
||||
name: "enter launches integration",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
|
||||
},
|
||||
{
|
||||
name: "right forces configure",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
|
||||
wantForce: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
var menuCalls int
|
||||
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
menuCalls++
|
||||
if menuCalls == 1 {
|
||||
return tt.action, nil
|
||||
}
|
||||
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||
}
|
||||
|
||||
var gotReq launch.IntegrationLaunchRequest
|
||||
deps := launcherDeps{
|
||||
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||
return &launch.LauncherState{}, nil
|
||||
},
|
||||
runMenu: runMenu,
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
gotReq = req
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
for {
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected step error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if gotReq.Name != "claude" {
|
||||
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
|
||||
}
|
||||
if gotReq.ForceConfigure != tt.wantForce {
|
||||
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
|
||||
}
|
||||
if got := config.LastSelection(); got != "claude" {
|
||||
t.Fatalf("expected last selection to be claude, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
|
||||
buildState: nil,
|
||||
runMenu: nil,
|
||||
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
return "", launch.ErrCancelled
|
||||
},
|
||||
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected cancellation to continue the menu loop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||
buildState: nil,
|
||||
runMenu: nil,
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
return launch.ErrCancelled
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected cancellation to continue the menu loop")
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -1797,13 +1796,16 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
|
||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
remoteHost string
|
||||
remoteModel string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectedError string
|
||||
name string
|
||||
model string
|
||||
showStatus int
|
||||
remoteHost string
|
||||
remoteModel string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectWhoami bool
|
||||
expectedError string
|
||||
expectAuthError bool
|
||||
}{
|
||||
{
|
||||
name: "ollama.com cloud model - user signed in",
|
||||
@@ -1812,6 +1814,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectWhoami: true,
|
||||
},
|
||||
{
|
||||
name: "ollama.com cloud model - user not signed in",
|
||||
@@ -1823,7 +1826,9 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
},
|
||||
expectedError: "unauthorized",
|
||||
expectWhoami: true,
|
||||
expectedError: "unauthorized",
|
||||
expectAuthError: true,
|
||||
},
|
||||
{
|
||||
name: "non-ollama.com remote - no auth check",
|
||||
@@ -1840,6 +1845,17 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectWhoami: true,
|
||||
},
|
||||
{
|
||||
name: "explicit :cloud model without local stub returns not found by default",
|
||||
model: "minimax-m2.5:cloud",
|
||||
showStatus: http.StatusNotFound,
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectedError: "not found",
|
||||
expectWhoami: false,
|
||||
expectAuthError: false,
|
||||
},
|
||||
{
|
||||
name: "explicit -cloud model - auth check without remote metadata",
|
||||
@@ -1848,6 +1864,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectWhoami: true,
|
||||
},
|
||||
{
|
||||
name: "dash cloud-like name without explicit source does not require auth",
|
||||
@@ -1865,6 +1882,11 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
if tt.showStatus != 0 && tt.showStatus != http.StatusOK {
|
||||
w.WriteHeader(tt.showStatus)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
RemoteHost: tt.remoteHost,
|
||||
@@ -1901,23 +1923,22 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
|
||||
err := loadOrUnloadModel(cmd, opts)
|
||||
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") || modelref.HasExplicitCloudSource(tt.model) {
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
||||
}
|
||||
} else {
|
||||
if whoamiCalled {
|
||||
t.Error("whoami should not be called for non-ollama.com remote")
|
||||
}
|
||||
if whoamiCalled != tt.expectWhoami {
|
||||
t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami)
|
||||
}
|
||||
|
||||
if tt.expectedError != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||
} else {
|
||||
var authErr api.AuthorizationError
|
||||
if !errors.As(err, &authErr) {
|
||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||
if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) {
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||
}
|
||||
if tt.expectAuthError {
|
||||
var authErr api.AuthorizationError
|
||||
if !errors.As(err, &authErr) {
|
||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -1,187 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
// Compile-time check that Claude implements AliasConfigurer
|
||||
var _ AliasConfigurer = (*Claude)(nil)
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("claude"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".claude", "local", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
env := append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
cmd.Env = env
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
primary := model
|
||||
fast := model
|
||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
||||
if p := cfg.Aliases["primary"]; p != "" {
|
||||
primary = p
|
||||
}
|
||||
if f := cfg.Aliases["fast"]; f != "" {
|
||||
fast = f
|
||||
}
|
||||
}
|
||||
return []string{
|
||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
|
||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up model aliases for Claude Code.
|
||||
// model: the model to use (if empty, user will be prompted to select)
|
||||
// aliases: existing alias configuration to preserve/update
|
||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
||||
// there is a better strategy for prompt caching on local models.
|
||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existingAliases {
|
||||
aliases[k] = v
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
aliases["primary"] = model
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
if isCloudModelName(aliases["primary"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
return aliases, false, nil
|
||||
}
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModelName(aliases["primary"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
return aliases, true, nil
|
||||
}
|
||||
|
||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
||||
// prevent stale routing to a previous cloud model.
|
||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
if aliases["fast"] == "" {
|
||||
for _, prefix := range prefixes {
|
||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
prefixAliases := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
var errs []string
|
||||
for prefix, target := range prefixAliases {
|
||||
req := &api.AliasRequest{
|
||||
Alias: prefix,
|
||||
Target: target,
|
||||
PrefixMatching: true,
|
||||
}
|
||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
||||
errs = append(errs, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -3,7 +3,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -11,7 +10,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
@@ -20,6 +19,9 @@ type integration struct {
|
||||
Onboarded bool `json:"onboarded,omitempty"`
|
||||
}
|
||||
|
||||
// IntegrationConfig is the persisted config for one integration.
|
||||
type IntegrationConfig = integration
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
LastModel string `json:"last_model,omitempty"`
|
||||
@@ -124,7 +126,7 @@ func save(cfg *config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeWithBackup(path, data)
|
||||
return fileutil.WriteWithBackup(path, data)
|
||||
}
|
||||
|
||||
func SaveIntegration(appName string, models []string) error {
|
||||
@@ -155,8 +157,8 @@ func SaveIntegration(appName string, models []string) error {
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// integrationOnboarded marks an integration as onboarded in ollama's config.
|
||||
func integrationOnboarded(appName string) error {
|
||||
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
|
||||
func MarkIntegrationOnboarded(appName string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -174,7 +176,7 @@ func integrationOnboarded(appName string) error {
|
||||
|
||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||
func IntegrationModel(appName string) string {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
integrationConfig, err := LoadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -183,7 +185,7 @@ func IntegrationModel(appName string) string {
|
||||
|
||||
// IntegrationModels returns all configured models for an integration, or nil.
|
||||
func IntegrationModels(appName string) []string {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
integrationConfig, err := LoadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -228,31 +230,8 @@ func SetLastSelection(selection string) error {
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// ModelExists checks if a model exists on the Ollama server.
|
||||
func ModelExists(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if isCloudModelName(name) {
|
||||
return true
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
// LoadIntegration returns the saved config for one integration.
|
||||
func LoadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -266,7 +245,8 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return integrationConfig, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
// SaveAliases replaces the saved aliases for one integration.
|
||||
func SaveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -45,12 +44,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := saveAliases("claude", initial); err != nil {
|
||||
if err := SaveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -63,12 +62,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := saveAliases("claude", updated); err != nil {
|
||||
if err := SaveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = loadIntegration("claude")
|
||||
loaded, err = LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
@@ -91,12 +90,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -111,16 +110,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -135,16 +134,16 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := saveAliases("claude", nil); err != nil {
|
||||
if err := SaveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -155,7 +154,7 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := saveAliases("", map[string]string{"primary": "model"})
|
||||
err := SaveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
@@ -165,12 +164,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -179,11 +178,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = loadIntegration("claude")
|
||||
loaded, err = LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
@@ -198,11 +197,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("newintegration")
|
||||
loaded, err := LoadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -371,12 +370,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
t.Fatal("server should succeed")
|
||||
}
|
||||
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -408,7 +407,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
@@ -440,11 +439,6 @@ func containsHelper(s, substr string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
||||
c := &Claude{}
|
||||
var _ AliasConfigurer = c // Compile-time check
|
||||
}
|
||||
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -464,11 +458,11 @@ func TestModelNameEdgeCases(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -485,7 +479,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
@@ -493,13 +487,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
@@ -513,21 +507,21 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
@@ -538,7 +532,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
@@ -546,14 +540,14 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
@@ -563,43 +557,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCapabilityFiltering(t *testing.T) {
|
||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
||||
// Both cloud and local models are checked for tool capability via Show API
|
||||
// Only models with "tools" in capabilities are included
|
||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("tool capable model should be marked as such")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("ToolCapable field should be accessible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
||||
t.Run("nil client always returns false", func(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, no suffix detection
|
||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
||||
t.Error("nil client should return false regardless of suffix")
|
||||
}
|
||||
if isCloudModel(context.Background(), nil, "local-model") {
|
||||
t.Error("nil client should return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -608,7 +572,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
@@ -622,11 +586,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
@@ -638,7 +602,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = loadIntegration("claude")
|
||||
loaded, _ = LoadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
@@ -653,20 +617,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
if err := SaveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
|
||||
@@ -10,17 +10,10 @@ import (
|
||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("TMPDIR", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
||||
func editorPaths(r Runner) []string {
|
||||
if editor, ok := r.(Editor); ok {
|
||||
return editor.Paths()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIntegrationConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
@@ -31,7 +24,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -55,11 +48,11 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -77,14 +70,14 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -96,7 +89,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
SaveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
config, _ := loadIntegration("codex")
|
||||
config, _ := LoadIntegration("codex")
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
@@ -120,7 +113,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||
SaveIntegration("Claude", []string{"model-x"})
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -137,8 +130,8 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
SaveIntegration("app1", []string{"model-1"})
|
||||
SaveIntegration("app2", []string{"model-2"})
|
||||
|
||||
config1, _ := loadIntegration("app1")
|
||||
config2, _ := loadIntegration("app2")
|
||||
config1, _ := LoadIntegration("app1")
|
||||
config2, _ := LoadIntegration("app2")
|
||||
|
||||
defaultModel1 := ""
|
||||
if len(config1.Models) > 0 {
|
||||
@@ -185,64 +178,6 @@ func TestListIntegrations(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestEditorPaths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
||||
r := integrations["claude"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for claude, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
||||
r := integrations["codex"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for codex, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
||||
settingsDir, _ := os.UserHomeDir()
|
||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
configDir := filepath.Join(home, ".config", "opencode")
|
||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["opencode"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 2 {
|
||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
@@ -251,7 +186,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
_, err := loadIntegration("test")
|
||||
_, err := LoadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
}
|
||||
@@ -265,7 +200,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
|
||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("test")
|
||||
config, err := LoadIntegration("test")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration failed: %v", err)
|
||||
}
|
||||
@@ -294,7 +229,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
_, err := loadIntegration("nonexistent")
|
||||
_, err := LoadIntegration("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration, got nil")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,59 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiYellow = "\033[33m"
|
||||
)
|
||||
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
// errCancelled is kept as an alias for backward compatibility within the package.
|
||||
var errCancelled = ErrCancelled
|
||||
|
||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
|
||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
||||
|
||||
func confirmPrompt(prompt string) (bool, error) {
|
||||
if DefaultConfirmPrompt != nil {
|
||||
return DefaultConfirmPrompt(prompt)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
t.Error("errCancelled should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Message", func(t *testing.T) {
|
||||
if errCancelled.Error() != "cancelled" {
|
||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
package config
|
||||
// Package fileutil provides small shared helpers for reading JSON files
|
||||
// and writing config files with backup-on-overwrite semantics.
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -9,7 +11,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func readJSONFile(path string) (map[string]any, error) {
|
||||
// ReadJSON reads a JSON object file into a generic map.
|
||||
func ReadJSON(path string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -33,12 +36,13 @@ func copyFile(src, dst string) error {
|
||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||
}
|
||||
|
||||
func backupDir() string {
|
||||
// BackupDir returns the shared backup directory used before overwriting files.
|
||||
func BackupDir() string {
|
||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||
}
|
||||
|
||||
func backupToTmp(srcPath string) (string, error) {
|
||||
dir := backupDir()
|
||||
dir := BackupDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -50,8 +54,8 @@ func backupToTmp(srcPath string) (string, error) {
|
||||
return backupPath, nil
|
||||
}
|
||||
|
||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
||||
func writeWithBackup(path string, data []byte) error {
|
||||
// WriteWithBackup writes data to path via temp file + rename, backing up any existing file first.
|
||||
func WriteWithBackup(path string, data []byte) error {
|
||||
var backupPath string
|
||||
// backup must be created before any writes to the target file
|
||||
if existingContent, err := os.ReadFile(path); err == nil {
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -9,6 +9,21 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
_ = os.RemoveAll(tmpRoot)
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func mustMarshal(t *testing.T, v any) []byte {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(v, "", " ")
|
||||
@@ -18,14 +33,19 @@ func mustMarshal(t *testing.T, v any) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func isolatedTempDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
return t.TempDir()
|
||||
}
|
||||
|
||||
func TestWriteWithBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
t.Run("creates file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "new.json")
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -43,17 +63,17 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
||||
t.Run("creates backup in the temp backup directory", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "backup.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(backupDir())
|
||||
entries, err := os.ReadDir(BackupDir())
|
||||
if err != nil {
|
||||
t.Fatal("backup directory not created")
|
||||
}
|
||||
@@ -63,7 +83,7 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
if filepath.Ext(entry.Name()) != ".json" {
|
||||
name := entry.Name()
|
||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||
backupPath := filepath.Join(backupDir(), name)
|
||||
backupPath := filepath.Join(BackupDir(), name)
|
||||
backup, err := os.ReadFile(backupPath)
|
||||
if err == nil {
|
||||
var backupData map[string]bool
|
||||
@@ -79,7 +99,7 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup file not created in /tmp/ollama-backups")
|
||||
t.Error("backup file not created in backup directory")
|
||||
}
|
||||
|
||||
current, _ := os.ReadFile(path)
|
||||
@@ -94,11 +114,11 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "nobak.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
entries, _ := os.ReadDir(BackupDir())
|
||||
for _, entry := range entries {
|
||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||
t.Error("backup should not exist for new file")
|
||||
@@ -111,11 +131,11 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries1, _ := os.ReadDir(backupDir())
|
||||
entries1, _ := os.ReadDir(BackupDir())
|
||||
countBefore := 0
|
||||
for _, e := range entries1 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
@@ -123,11 +143,11 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries2, _ := os.ReadDir(backupDir())
|
||||
entries2, _ := os.ReadDir(BackupDir())
|
||||
countAfter := 0
|
||||
for _, e := range entries2 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
@@ -145,11 +165,11 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
data := mustMarshal(t, map[string]int{"v": 2})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
entries, _ := os.ReadDir(BackupDir())
|
||||
var found bool
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
@@ -161,7 +181,7 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
found = true
|
||||
os.Remove(filepath.Join(backupDir(), name))
|
||||
os.Remove(filepath.Join(BackupDir(), name))
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -180,7 +200,7 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
// Create original file
|
||||
@@ -188,13 +208,13 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||
os.WriteFile(path, originalContent, 0o644)
|
||||
|
||||
// Make backup directory read-only to force backup failure
|
||||
backupDir := backupDir()
|
||||
backupDir := BackupDir()
|
||||
os.MkdirAll(backupDir, 0o755)
|
||||
os.Chmod(backupDir, 0o444) // Read-only
|
||||
defer os.Chmod(backupDir, 0o755)
|
||||
|
||||
newContent := []byte(`{"updated": true}`)
|
||||
err := writeWithBackup(path, newContent)
|
||||
err := WriteWithBackup(path, newContent)
|
||||
|
||||
// Should fail because backup couldn't be created
|
||||
if err == nil {
|
||||
@@ -215,7 +235,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
// Create a read-only directory
|
||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||
@@ -224,7 +244,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
defer os.Chmod(readOnlyDir, 0o755)
|
||||
|
||||
path := filepath.Join(readOnlyDir, "config.json")
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected permission error, got nil")
|
||||
@@ -234,10 +254,10 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||
// writeWithBackup doesn't create directories - caller is responsible.
|
||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
// Should fail because directory doesn't exist
|
||||
if err == nil {
|
||||
@@ -252,7 +272,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
t.Skip("symlink tests may require admin on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
realFile := filepath.Join(tmpDir, "real.json")
|
||||
symlink := filepath.Join(tmpDir, "link.json")
|
||||
|
||||
@@ -261,7 +281,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
os.Symlink(realFile, symlink)
|
||||
|
||||
// Write through symlink
|
||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
||||
err := WriteWithBackup(symlink, []byte(`{"v": 2}`))
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||
}
|
||||
@@ -276,7 +296,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||
// User may have config files with unusual names.
|
||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
// File with spaces and special chars
|
||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||
@@ -305,7 +325,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||
t.Skip("permission preservation tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
src := filepath.Join(tmpDir, "src.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
@@ -327,7 +347,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||
|
||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
@@ -339,11 +359,11 @@ func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||
|
||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||
os.MkdirAll(dirPath, 0o755)
|
||||
|
||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
||||
err := WriteWithBackup(dirPath, []byte(`{"test": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when target is a directory, got nil")
|
||||
}
|
||||
@@ -351,10 +371,10 @@ func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||
|
||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "empty.json")
|
||||
|
||||
err := writeWithBackup(path, []byte{})
|
||||
err := WriteWithBackup(path, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||
}
|
||||
@@ -375,7 +395,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "unreadable.json")
|
||||
|
||||
// Create file and make it unreadable
|
||||
@@ -384,7 +404,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
defer os.Chmod(path, 0o644)
|
||||
|
||||
// Should fail because we can't read the file to compare/backup
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when file is unreadable, got nil")
|
||||
}
|
||||
@@ -393,7 +413,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||
// within the same second (timestamp collision scenario).
|
||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "rapid.json")
|
||||
|
||||
// Create initial file
|
||||
@@ -402,7 +422,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
// Rapid successive writes
|
||||
for i := 1; i <= 3; i++ {
|
||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatalf("write %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
@@ -414,7 +434,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify at least one backup exists
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
entries, _ := os.ReadDir(BackupDir())
|
||||
var backupCount int
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||
@@ -432,8 +452,9 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||
t.Skip("test modifies system temp directory")
|
||||
}
|
||||
|
||||
tmpDir := isolatedTempDir(t)
|
||||
// Create a file at the backup directory path
|
||||
backupPath := backupDir()
|
||||
backupPath := BackupDir()
|
||||
// Clean up any existing directory first
|
||||
os.RemoveAll(backupPath)
|
||||
// Create a file instead of directory
|
||||
@@ -443,11 +464,10 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||
os.MkdirAll(backupPath, 0o755)
|
||||
}()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "test.json")
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when backup dir is a file, got nil")
|
||||
}
|
||||
@@ -459,7 +479,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
// Count existing temp files
|
||||
countTempFiles := func() int {
|
||||
@@ -493,7 +513,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||
badPath := filepath.Join(tmpDir, "isdir")
|
||||
os.MkdirAll(badPath, 0o755)
|
||||
|
||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
||||
_ = WriteWithBackup(badPath, []byte(`{"test": true}`))
|
||||
|
||||
after := countTempFiles()
|
||||
if after > before {
|
||||
77
cmd/launch/claude.go
Normal file
77
cmd/launch/claude.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration.
|
||||
type Claude struct{}
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("claude"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".claude", "local", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
env := append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
cmd.Env = env
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
return []string{
|
||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
|
||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
|
||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
|
||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -117,10 +117,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
||||
return m
|
||||
}
|
||||
|
||||
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) {
|
||||
got := envMap(c.modelEnvVars("llama3.2"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
@@ -136,63 +133,19 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"qwen3:8b"})
|
||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
||||
|
||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
t.Run("supports empty model", func(t *testing.T) {
|
||||
got := envMap(c.modelEnvVars(""))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" {
|
||||
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" {
|
||||
t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" {
|
||||
t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses fast alias for haiku", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"llama3.2:70b"})
|
||||
saveAliases("claude", map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
})
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2:70b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("alias primary overrides model param", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"saved-model"})
|
||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
||||
|
||||
got := envMap(c.modelEnvVars("different-model"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
||||
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" {
|
||||
t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,14 +1,13 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -22,24 +21,6 @@ func (c *Cline) Run(model string, args []string) error {
|
||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "cline", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("cline", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -97,7 +78,7 @@ func (c *Cline) Edit(models []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
return fileutil.WriteWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Cline) Models() []string {
|
||||
@@ -106,7 +87,7 @@ func (c *Cline) Models() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"slices"
|
||||
494
cmd/launch/command_test.go
Normal file
494
cmd/launch/command_test.go
Normal file
@@ -0,0 +1,494 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func captureStderr(t *testing.T, fn func()) string {
|
||||
t.Helper()
|
||||
|
||||
oldStderr := os.Stderr
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stderr = w
|
||||
defer func() {
|
||||
os.Stderr = oldStderr
|
||||
}()
|
||||
|
||||
done := make(chan string, 1)
|
||||
go func() {
|
||||
var buf bytes.Buffer
|
||||
_, _ = io.Copy(&buf, r)
|
||||
done <- buf.String()
|
||||
}()
|
||||
|
||||
fn()
|
||||
|
||||
_ = w.Close()
|
||||
return <-done
|
||||
}
|
||||
|
||||
func TestLaunchCmd(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
mockTUI := func(cmd *cobra.Command) {}
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
}
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
if cmd.Flags().Lookup("model") == nil {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
if cmd.Flags().Lookup("config") == nil {
|
||||
t.Error("--config flag should exist")
|
||||
}
|
||||
if cmd.Flags().Lookup("yes") == nil {
|
||||
t.Error("--yes flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreRunE is set", func(t *testing.T) {
|
||||
if cmd.PreRunE == nil {
|
||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmdTUICallback(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("no args calls TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if !tuiCalled {
|
||||
t.Error("TUI callback should be called when no args provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"claude"})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--model flag without integration returns error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected --model without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --model is provided without an integration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--config flag without integration returns error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--config"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected --config without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --config is provided without an integration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("extra args without integration return error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected flags and extra args without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmdNilHeartbeat(t *testing.T) {
|
||||
cmd := LaunchCmd(nil, nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
if cmd.PreRunE != nil {
|
||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/show":
|
||||
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherEditorRunner{}
|
||||
restore := OverrideIntegration("stubeditor", stub)
|
||||
defer restore()
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubeditor")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
|
||||
var selectorCalls int
|
||||
var gotCurrent string
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
selectorCalls++
|
||||
gotCurrent = current
|
||||
return "llama3.2", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
|
||||
stderr := captureStderr(t, func() {
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if selectorCalls != 1 {
|
||||
t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
|
||||
}
|
||||
if gotCurrent != "" {
|
||||
t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
|
||||
}
|
||||
if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
|
||||
t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubapp")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||
restore := OverrideIntegration("stubeditor", stub)
|
||||
defer restore()
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatalf("unexpected prompt with --yes: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command with --yes failed: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command with --yes failed: %v", err)
|
||||
}
|
||||
|
||||
if !pullCalled {
|
||||
t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
|
||||
}
|
||||
if stub.ranModel != "missing-model" {
|
||||
t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||
restore := OverrideIntegration("stubeditor", stub)
|
||||
defer restore()
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||
err := cmd.Execute()
|
||||
if err == nil {
|
||||
t.Fatal("expected launch command to fail without --yes in headless mode")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "re-run with --yes") {
|
||||
t.Fatalf("expected actionable --yes guidance, got %v", err)
|
||||
}
|
||||
if len(stub.edited) != 0 {
|
||||
t.Fatalf("expected no editor writes when confirmation is blocked, got %v", stub.edited)
|
||||
}
|
||||
if stub.ranModel != "" {
|
||||
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
|
||||
var gotCurrent string
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
gotCurrent = current
|
||||
return "qwen3:8b", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
if gotCurrent != "llama3.2" {
|
||||
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
|
||||
}
|
||||
if stub.ranModel != "qwen3:8b" {
|
||||
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubapp")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,14 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -46,25 +45,6 @@ func (d *Droid) Run(model string, args []string) error {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "droid", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -110,6 +90,16 @@ func (d *Droid) Edit(models []string) error {
|
||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||
}
|
||||
|
||||
settingsMap = updateDroidSettings(settingsMap, settings, models)
|
||||
|
||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(settingsPath, data)
|
||||
}
|
||||
|
||||
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []string) map[string]any {
|
||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||
// Rebuild Ollama models
|
||||
var nonOllamaModels []any
|
||||
@@ -165,12 +155,7 @@ func (d *Droid) Edit(models []string) error {
|
||||
}
|
||||
|
||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||
|
||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, data)
|
||||
return settingsMap
|
||||
}
|
||||
|
||||
func (d *Droid) Models() []string {
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
)
|
||||
|
||||
func TestDroidIntegration(t *testing.T) {
|
||||
@@ -362,7 +364,7 @@ func TestDroidEdit_DuplicateModels(t *testing.T) {
|
||||
t.Fatalf("Edit with duplicates failed: %v", err)
|
||||
}
|
||||
|
||||
settings, err := readJSONFile(settingsPath)
|
||||
settings, err := fileutil.ReadJSON(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("readJSONFile failed: %v", err)
|
||||
}
|
||||
@@ -392,7 +394,7 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) {
|
||||
}
|
||||
|
||||
// Malformed entries (non-object) are dropped - only valid model objects are preserved
|
||||
settings, _ := readJSONFile(settingsPath)
|
||||
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||
customModels, _ := settings["customModels"].([]any)
|
||||
|
||||
// Should have: 1 new Ollama model only (malformed entries dropped)
|
||||
@@ -419,7 +421,7 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should create proper sessionDefaultSettings
|
||||
settings, _ := readJSONFile(settingsPath)
|
||||
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
||||
@@ -1008,34 +1010,34 @@ func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
|
||||
// No customModels key at all
|
||||
original := `{
|
||||
"diffMode": "github",
|
||||
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
||||
}`
|
||||
os.WriteFile(settingsPath, []byte(original), 0o644)
|
||||
|
||||
if err := d.Edit([]string{"model-a"}); err != nil {
|
||||
var settingsStruct droidSettings
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal([]byte(original), &settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(original), &settingsStruct); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(settingsPath)
|
||||
var settings map[string]any
|
||||
json.Unmarshal(data, &settings)
|
||||
settings = updateDroidSettings(settings, settingsStruct, []string{"model-a"})
|
||||
|
||||
// Original fields preserved
|
||||
if settings["diffMode"] != "github" {
|
||||
t.Error("diffMode not preserved")
|
||||
}
|
||||
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("sessionDefaultSettings not preserved")
|
||||
}
|
||||
if session["autonomyMode"] != "auto-high" {
|
||||
t.Error("sessionDefaultSettings.autonomyMode not preserved")
|
||||
}
|
||||
|
||||
// customModels created
|
||||
models, ok := settings["customModels"].([]any)
|
||||
@@ -1,8 +1,9 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -13,12 +14,12 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type stubEditorRunner struct {
|
||||
edited [][]string
|
||||
ranModel string
|
||||
editErr error
|
||||
}
|
||||
|
||||
func (s *stubEditorRunner) Run(model string, args []string) error {
|
||||
@@ -31,6 +32,9 @@ func (s *stubEditorRunner) String() string { return "StubEditor" }
|
||||
func (s *stubEditorRunner) Paths() []string { return nil }
|
||||
|
||||
func (s *stubEditorRunner) Edit(models []string) error {
|
||||
if s.editErr != nil {
|
||||
return s.editErr
|
||||
}
|
||||
cloned := append([]string(nil), models...)
|
||||
s.edited = append(s.edited, cloned)
|
||||
return nil
|
||||
@@ -111,120 +115,8 @@ func TestHasLocalModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd(t *testing.T) {
|
||||
// Mock checkServerHeartbeat that always succeeds
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
mockTUI := func(cmd *cobra.Command) {}
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
}
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
modelFlag := cmd.Flags().Lookup("model")
|
||||
if modelFlag == nil {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
|
||||
configFlag := cmd.Flags().Lookup("config")
|
||||
if configFlag == nil {
|
||||
t.Error("--config flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreRunE is set", func(t *testing.T) {
|
||||
if cmd.PreRunE == nil {
|
||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmd_TUICallback(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("no args calls TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if !tuiCalled {
|
||||
t.Error("TUI callback should be called when no args provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"claude"})
|
||||
// Will error because claude isn't configured, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--model flag bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model"})
|
||||
// Will error because no integration specified, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --model flag provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--config flag bypasses TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--config"})
|
||||
// Will error because no integration specified, but that's OK
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --config flag provided")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model", nil)
|
||||
func TestLookupIntegration_UnknownIntegration(t *testing.T) {
|
||||
_, _, err := LookupIntegration("unknown-integration")
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
@@ -233,6 +125,17 @@ func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsIntegrationInstalled_UnknownIntegrationReturnsFalse(t *testing.T) {
|
||||
stderr := captureStderr(t, func() {
|
||||
if IsIntegrationInstalled("unknown-integration") {
|
||||
t.Fatal("expected unknown integration to report not installed")
|
||||
}
|
||||
})
|
||||
if !strings.Contains(stderr, `Ollama couldn't find integration "unknown-integration", so it'll show up as not installed.`) {
|
||||
t.Fatalf("expected unknown-integration warning, got stderr: %q", stderr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -261,19 +164,6 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
// This should not panic - cmd creation should work even with nil
|
||||
cmd := LaunchCmd(nil, nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
|
||||
// PreRunE should be nil when passed nil
|
||||
if cmd.PreRunE != nil {
|
||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
for name, r := range integrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
@@ -730,7 +620,7 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify loadIntegration returns the saved models
|
||||
saved, err := loadIntegration("opencode")
|
||||
saved, err := LoadIntegration("opencode")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -742,153 +632,56 @@ func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEditorLaunchModels_PicksWhenAllFiltered(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
func TestLauncherClientFilterDisabledCloudModels_ChecksStatusOncePerInvocation(t *testing.T) {
|
||||
var statusCalls int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
statusCalls++
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
pickerCalled := false
|
||||
models, err := resolveEditorModels("opencode", []string{"glm-5:cloud"}, func() ([]string, error) {
|
||||
pickerCalled = true
|
||||
return []string{"llama3.2"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveEditorLaunchModels returned error: %v", err)
|
||||
}
|
||||
if !pickerCalled {
|
||||
t.Fatal("expected model picker to be called when all models are filtered")
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" {
|
||||
t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := &launcherClient{
|
||||
apiClient: api.NewClient(u, srv.Client()),
|
||||
}
|
||||
|
||||
saved, err := loadIntegration("opencode")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
filtered := client.filterDisabledCloudModels(context.Background(), []string{"llama3.2", "glm-5:cloud", "qwen3.5:cloud"})
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, filtered); diff != "" {
|
||||
t.Fatalf("filtered models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
if statusCalls != 1 {
|
||||
t.Fatalf("expected one cloud status lookup, got %d", statusCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveEditorLaunchModels_FiltersAndSkipsPickerWhenLocalRemains(t *testing.T) {
|
||||
func TestPrepareEditorIntegration_SavesOnlyAfterSuccessfulEdit(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
pickerCalled := false
|
||||
models, err := resolveEditorModels("droid", []string{"llama3.2", "glm-5:cloud"}, func() ([]string, error) {
|
||||
pickerCalled = true
|
||||
return []string{"qwen3.5"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveEditorLaunchModels returned error: %v", err)
|
||||
}
|
||||
if pickerCalled {
|
||||
t.Fatal("picker should not be called when a local model remains")
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" {
|
||||
t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
|
||||
if err := SaveIntegration("droid", []string{"existing-model"}); err != nil {
|
||||
t.Fatalf("failed to seed config: %v", err)
|
||||
}
|
||||
|
||||
saved, err := loadIntegration("droid")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
editor := &stubEditorRunner{editErr: errors.New("boom")}
|
||||
err := prepareEditorIntegration("droid", editor, editor, []string{"new-model"})
|
||||
if err == nil || !strings.Contains(err.Error(), "setup failed") {
|
||||
t.Fatalf("expected setup failure, got %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
|
||||
saved, err := LoadIntegration("droid")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload saved config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"existing-model"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmd_ModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/show":
|
||||
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &stubEditorRunner{}
|
||||
old, existed := integrations["stubeditor"]
|
||||
integrations["stubeditor"] = stub
|
||||
defer func() {
|
||||
if existed {
|
||||
integrations["stubeditor"] = old
|
||||
} else {
|
||||
delete(integrations, "stubeditor")
|
||||
}
|
||||
}()
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
saved, err := loadIntegration("stubeditor")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasConfigurerInterface(t *testing.T) {
|
||||
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
|
||||
claude := &Claude{}
|
||||
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
|
||||
t.Error("Claude should implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
|
||||
codex := &Codex{}
|
||||
if _, ok := interface{}(codex).(AliasConfigurer); ok {
|
||||
t.Error("Codex should not implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelExists(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
@@ -903,12 +696,237 @@ func TestShowOrPull_ModelExists(t *testing.T) {
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "test-model")
|
||||
err := showOrPullWithPolicy(context.Background(), client, "test-model", missingModelPromptPull, false)
|
||||
if err != nil {
|
||||
t.Errorf("showOrPull should return nil when model exists, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPullWithPolicy_ModelExists(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"model":"test-model"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := showOrPullWithPolicy(context.Background(), client, "test-model", missingModelFail, false)
|
||||
if err != nil {
|
||||
t.Errorf("showOrPullWithPolicy should return nil when model exists, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPullWithPolicy_ModelNotFound_FailDoesNotPromptOrPull(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatal("confirm prompt should not be called with fail policy")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelFail, false)
|
||||
if err == nil {
|
||||
t.Fatal("expected fail policy to return an error for missing model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ollama pull missing-model") {
|
||||
t.Fatalf("expected actionable pull guidance, got: %v", err)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Fatal("expected pull not to be called with fail policy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPullWithPolicy_ModelNotFound_PromptPolicyPulls(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
if !strings.Contains(prompt, "missing-model") {
|
||||
t.Fatalf("expected prompt to mention missing model, got %q", prompt)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelPromptPull, false)
|
||||
if err != nil {
|
||||
t.Fatalf("expected prompt policy to pull and succeed, got %v", err)
|
||||
}
|
||||
if !pullCalled {
|
||||
t.Fatal("expected pull to be called with prompt policy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPullWithPolicy_ModelNotFound_AutoPullPolicyPullsWithoutPrompt(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatalf("confirm prompt should not be called with auto-pull policy: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelAutoPull, false)
|
||||
if err != nil {
|
||||
t.Fatalf("expected auto-pull policy to pull and succeed, got %v", err)
|
||||
}
|
||||
if !pullCalled {
|
||||
t.Fatal("expected pull to be called with auto-pull policy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPullWithPolicy_CloudModelNotFound_FailsEarlyForAllPolicies(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatal("confirm prompt should not be called for explicit cloud models")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
for _, policy := range []missingModelPolicy{missingModelPromptPull, missingModelAutoPull, missingModelFail} {
|
||||
t.Run(fmt.Sprintf("policy=%d", policy), func(t *testing.T) {
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := showOrPullWithPolicy(context.Background(), client, "glm-5:cloud", policy, true)
|
||||
if err == nil {
|
||||
t.Fatalf("expected cloud model not-found error for policy %d", policy)
|
||||
}
|
||||
if !strings.Contains(err.Error(), `model "glm-5:cloud" not found`) {
|
||||
t.Fatalf("expected not-found error for policy %d, got %v", policy, err)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Fatalf("expected pull not to be called for cloud model with policy %d", policy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPullWithPolicy_CloudModelDisabled_FailsWithCloudDisabledError(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatal("confirm prompt should not be called for explicit cloud models")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
for _, policy := range []missingModelPolicy{missingModelPromptPull, missingModelAutoPull, missingModelFail} {
|
||||
t.Run(fmt.Sprintf("policy=%d", policy), func(t *testing.T) {
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := showOrPullWithPolicy(context.Background(), client, "glm-5:cloud", policy, true)
|
||||
if err == nil {
|
||||
t.Fatalf("expected cloud disabled error for policy %d", policy)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "remote inference is unavailable") {
|
||||
t.Fatalf("expected cloud disabled error for policy %d, got %v", policy, err)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Fatalf("expected pull not to be called for cloud model with policy %d", policy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -920,7 +938,7 @@ func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
// confirmPrompt will fail in test (no terminal), so showOrPull should return an error
|
||||
err := ShowOrPull(context.Background(), client, "missing-model")
|
||||
err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelPromptPull, false)
|
||||
if err == nil {
|
||||
t.Error("showOrPull should return error when model not found and no terminal available")
|
||||
}
|
||||
@@ -945,7 +963,7 @@ func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) {
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
_ = ShowOrPull(context.Background(), client, "qwen3.5")
|
||||
_ = showOrPullWithPolicy(context.Background(), client, "qwen3.5", missingModelPromptPull, false)
|
||||
if receivedModel != "qwen3.5" {
|
||||
t.Errorf("expected Show to be called with %q, got %q", "qwen3.5", receivedModel)
|
||||
}
|
||||
@@ -981,7 +999,7 @@ func TestShowOrPull_ModelNotFound_ConfirmYes_Pulls(t *testing.T) {
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "missing-model")
|
||||
err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelPromptPull, false)
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPull should succeed after pull, got: %v", err)
|
||||
}
|
||||
@@ -1013,13 +1031,13 @@ func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) {
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "missing-model")
|
||||
err := showOrPullWithPolicy(context.Background(), client, "missing-model", missingModelPromptPull, false)
|
||||
if err == nil {
|
||||
t.Error("ShowOrPull should return error when user declines")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) {
|
||||
func TestShowOrPull_CloudModel_NotFoundDoesNotPull(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for explicit cloud models
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
@@ -1047,16 +1065,19 @@ func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) {
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "glm-5:cloud")
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
||||
err := showOrPullWithPolicy(context.Background(), client, "glm-5:cloud", missingModelPromptPull, true)
|
||||
if err == nil {
|
||||
t.Error("ShowOrPull should return not-found error for cloud model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `model "glm-5:cloud" not found`) {
|
||||
t.Errorf("expected cloud model not-found error, got: %v", err)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Error("expected pull not to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) {
|
||||
func TestShowOrPull_CloudLegacySuffix_NotFoundDoesNotPull(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for explicit cloud models
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
@@ -1084,85 +1105,18 @@ func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) {
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
||||
err := showOrPullWithPolicy(context.Background(), client, "gpt-oss:20b-cloud", missingModelPromptPull, true)
|
||||
if err == nil {
|
||||
t.Error("ShowOrPull should return not-found error for cloud model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `model "gpt-oss:20b-cloud" not found`) {
|
||||
t.Errorf("expected cloud model not-found error, got: %v", err)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Error("expected pull not to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullIfNeeded_CloudModel_DoesNotPull(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Error("confirm prompt should not be called for cloud models")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
err := pullIfNeeded(context.Background(), nil, map[string]bool{}, "glm-5:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for cloud model, got %v", err)
|
||||
}
|
||||
|
||||
err = pullIfNeeded(context.Background(), nil, map[string]bool{}, "gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for cloud model with legacy suffix, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectModelsWithSelectors_CloudSelection_DoesNotPull(t *testing.T) {
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
case "/api/tags":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"models":[]}`)
|
||||
case "/api/me":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"name":"test-user"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
single := func(title string, items []ModelItem, current string) (string, error) {
|
||||
for _, item := range items {
|
||||
if item.Name == "glm-5:cloud" {
|
||||
return item.Name, nil
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected glm-5:cloud in selector items, got %v", items)
|
||||
return "", nil
|
||||
}
|
||||
|
||||
multi := func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
return nil, fmt.Errorf("multi selector should not be called")
|
||||
}
|
||||
|
||||
selected, err := selectModelsWithSelectors(context.Background(), "codex", "", single, multi)
|
||||
if err != nil {
|
||||
t.Fatalf("selectModelsWithSelectors returned error: %v", err)
|
||||
}
|
||||
if !slices.Equal(selected, []string{"glm-5:cloud"}) {
|
||||
t.Fatalf("unexpected selected models: %v", selected)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Fatal("expected cloud selection to skip pull")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmPrompt_DelegatesToHook(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
var hookCalled bool
|
||||
@@ -1175,7 +1129,7 @@ func TestConfirmPrompt_DelegatesToHook(t *testing.T) {
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
ok, err := confirmPrompt("test prompt?")
|
||||
ok, err := ConfirmPrompt("test prompt?")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -1258,6 +1212,66 @@ func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_PreservesCancelledSignInHook(t *testing.T) {
|
||||
oldSignIn := DefaultSignIn
|
||||
DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
defer func() { DefaultSignIn = oldSignIn }()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
case "/api/me":
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
fmt.Fprintf(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ensureAuth(context.Background(), client, map[string]bool{"cloud-model:cloud": true}, []string{"cloud-model:cloud"})
|
||||
if !errors.Is(err, ErrCancelled) {
|
||||
t.Fatalf("expected ErrCancelled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_DeclinedFallbackReturnsCancelled(t *testing.T) {
|
||||
oldConfirm := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldConfirm }()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
case "/api/me":
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
fmt.Fprintf(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ensureAuth(context.Background(), client, map[string]bool{"cloud-model:cloud": true}, []string{"cloud-model:cloud"})
|
||||
if !errors.Is(err, ErrCancelled) {
|
||||
t.Fatalf("expected ErrCancelled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHyperlink(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1306,7 +1320,7 @@ func TestHyperlink(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationInstallHint(t *testing.T) {
|
||||
func TestIntegration_InstallHint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
@@ -1342,7 +1356,11 @@ func TestIntegrationInstallHint(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IntegrationInstallHint(tt.input)
|
||||
got := ""
|
||||
integration, err := integrationFor(tt.input)
|
||||
if err == nil {
|
||||
got = integration.installHint
|
||||
}
|
||||
if tt.wantEmpty {
|
||||
if got != "" {
|
||||
t.Errorf("expected empty hint, got %q", got)
|
||||
@@ -1477,31 +1495,7 @@ func TestBuildModelList_Descriptions(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := LaunchIntegration("nonexistent-integration")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown integration")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_NotConfigured(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Claude is a known integration but not configured in temp dir
|
||||
err := LaunchIntegration("claude")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when integration is not configured")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not configured") {
|
||||
t.Errorf("error should mention 'not configured', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEditorIntegration(t *testing.T) {
|
||||
func TestIntegration_Editor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
@@ -1515,8 +1509,13 @@ func TestIsEditorIntegration(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsEditorIntegration(tt.name); got != tt.want {
|
||||
t.Errorf("IsEditorIntegration(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
got := false
|
||||
integration, err := integrationFor(tt.name)
|
||||
if err == nil {
|
||||
got = integration.editor
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("integrationFor(%q).editor = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1543,13 +1542,3 @@ func TestIntegrationModels(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSaveAndEditIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := SaveAndEditIntegration("nonexistent", []string{"model"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown integration")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown integration") {
|
||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||
}
|
||||
}
|
||||
833
cmd/launch/launch.go
Normal file
833
cmd/launch/launch.go
Normal file
@@ -0,0 +1,833 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// LauncherState is the launch-owned snapshot used to render the root launcher menu.
|
||||
type LauncherState struct {
|
||||
LastSelection string
|
||||
RunModel string
|
||||
RunModelUsable bool
|
||||
Integrations map[string]LauncherIntegrationState
|
||||
}
|
||||
|
||||
// LauncherIntegrationState is the launch-owned status for one launcher integration.
|
||||
type LauncherIntegrationState struct {
|
||||
Name string
|
||||
DisplayName string
|
||||
Description string
|
||||
Installed bool
|
||||
AutoInstallable bool
|
||||
Selectable bool
|
||||
Changeable bool
|
||||
CurrentModel string
|
||||
ModelUsable bool
|
||||
InstallHint string
|
||||
Editor bool
|
||||
}
|
||||
|
||||
// RunModelRequest controls how the root launcher resolves the chat model.
|
||||
type RunModelRequest struct {
|
||||
ForcePicker bool
|
||||
}
|
||||
|
||||
// LaunchConfirmMode controls confirmation behavior across launch flows.
|
||||
type LaunchConfirmMode int
|
||||
|
||||
const (
|
||||
// LaunchConfirmPrompt prompts the user for confirmation.
|
||||
LaunchConfirmPrompt LaunchConfirmMode = iota
|
||||
// LaunchConfirmAutoApprove skips prompts and treats confirmation as accepted.
|
||||
LaunchConfirmAutoApprove
|
||||
// LaunchConfirmRequireYes rejects confirmation requests with a --yes hint.
|
||||
LaunchConfirmRequireYes
|
||||
)
|
||||
|
||||
// LaunchMissingModelMode controls local missing-model handling in launch flows.
|
||||
type LaunchMissingModelMode int
|
||||
|
||||
const (
|
||||
// LaunchMissingModelPromptToPull prompts to pull a missing local model.
|
||||
LaunchMissingModelPromptToPull LaunchMissingModelMode = iota
|
||||
// LaunchMissingModelAutoPull pulls a missing local model without prompting.
|
||||
LaunchMissingModelAutoPull
|
||||
// LaunchMissingModelFail fails immediately when a local model is missing.
|
||||
LaunchMissingModelFail
|
||||
)
|
||||
|
||||
// LaunchPolicy controls launch behavior that may vary by caller context.
|
||||
type LaunchPolicy struct {
|
||||
Confirm LaunchConfirmMode
|
||||
MissingModel LaunchMissingModelMode
|
||||
}
|
||||
|
||||
func defaultLaunchPolicy(interactive bool) LaunchPolicy {
|
||||
policy := LaunchPolicy{
|
||||
Confirm: LaunchConfirmPrompt,
|
||||
MissingModel: LaunchMissingModelPromptToPull,
|
||||
}
|
||||
if !interactive {
|
||||
policy.Confirm = LaunchConfirmRequireYes
|
||||
policy.MissingModel = LaunchMissingModelFail
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
func (p LaunchPolicy) confirmPolicy() launchConfirmPolicy {
|
||||
switch p.Confirm {
|
||||
case LaunchConfirmAutoApprove:
|
||||
return launchConfirmPolicy{yes: true}
|
||||
case LaunchConfirmRequireYes:
|
||||
return launchConfirmPolicy{requireYesMessage: true}
|
||||
default:
|
||||
return launchConfirmPolicy{}
|
||||
}
|
||||
}
|
||||
|
||||
func (p LaunchPolicy) missingModelPolicy() missingModelPolicy {
|
||||
switch p.MissingModel {
|
||||
case LaunchMissingModelAutoPull:
|
||||
return missingModelAutoPull
|
||||
case LaunchMissingModelFail:
|
||||
return missingModelFail
|
||||
default:
|
||||
return missingModelPromptPull
|
||||
}
|
||||
}
|
||||
|
||||
// IntegrationLaunchRequest controls the canonical integration launcher flow.
|
||||
type IntegrationLaunchRequest struct {
|
||||
Name string
|
||||
ModelOverride string
|
||||
ForceConfigure bool
|
||||
ConfigureOnly bool
|
||||
ExtraArgs []string
|
||||
Policy *LaunchPolicy
|
||||
}
|
||||
|
||||
var isInteractiveSession = func() bool {
|
||||
return term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
|
||||
}
|
||||
|
||||
// Runner executes a model with an integration.
|
||||
type Runner interface {
|
||||
Run(model string, args []string) error
|
||||
String() string
|
||||
}
|
||||
|
||||
// Editor can edit config files for integrations that support model configuration.
|
||||
type Editor interface {
|
||||
Paths() []string
|
||||
Edit(models []string) error
|
||||
Models() []string
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Remote bool
|
||||
ToolCapable bool
|
||||
}
|
||||
|
||||
// ModelInfo re-exports launcher model inventory details for callers.
|
||||
type ModelInfo = modelInfo
|
||||
|
||||
// ModelItem represents a model for selection UIs.
|
||||
type ModelItem struct {
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
}
|
||||
|
||||
// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
|
||||
func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
|
||||
oldSingle := DefaultSingleSelector
|
||||
oldMulti := DefaultMultiSelector
|
||||
if single != nil {
|
||||
DefaultSingleSelector = single
|
||||
}
|
||||
if multi != nil {
|
||||
DefaultMultiSelector = multi
|
||||
}
|
||||
defer func() {
|
||||
DefaultSingleSelector = oldSingle
|
||||
DefaultMultiSelector = oldMulti
|
||||
}()
|
||||
|
||||
return LaunchIntegration(ctx, IntegrationLaunchRequest{
|
||||
Name: name,
|
||||
ForceConfigure: true,
|
||||
ConfigureOnly: true,
|
||||
})
|
||||
}
|
||||
|
||||
// ConfigureIntegration allows the user to select/change the model for an integration.
|
||||
func ConfigureIntegration(ctx context.Context, name string) error {
|
||||
return LaunchIntegration(ctx, IntegrationLaunchRequest{
|
||||
Name: name,
|
||||
ForceConfigure: true,
|
||||
ConfigureOnly: true,
|
||||
})
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
// The runTUI callback is called when the root launcher UI should be shown.
|
||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
|
||||
var modelFlag string
|
||||
var configFlag bool
|
||||
var yesFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||
Short: "Launch the Ollama menu or an integration",
|
||||
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
|
||||
|
||||
Without arguments, this is equivalent to running 'ollama' directly.
|
||||
Flags and extra arguments require an integration name.
|
||||
|
||||
Supported integrations:
|
||||
claude Claude Code
|
||||
cline Cline
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||
pi Pi
|
||||
|
||||
Examples:
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch droid --config (does not auto-launch)
|
||||
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||
ollama launch codex -- --sandbox workspace-write`,
|
||||
Args: cobra.ArbitraryArgs,
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
policy := defaultLaunchPolicy(isInteractiveSession())
|
||||
if yesFlag {
|
||||
policy.Confirm = LaunchConfirmAutoApprove
|
||||
if policy.MissingModel == LaunchMissingModelFail {
|
||||
policy.MissingModel = LaunchMissingModelAutoPull
|
||||
}
|
||||
}
|
||||
restoreConfirmPolicy := withLaunchConfirmPolicy(policy.confirmPolicy())
|
||||
defer restoreConfirmPolicy()
|
||||
|
||||
var name string
|
||||
var passArgs []string
|
||||
dashIdx := cmd.ArgsLenAtDash()
|
||||
|
||||
if dashIdx == -1 {
|
||||
if len(args) > 1 {
|
||||
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
|
||||
}
|
||||
if len(args) == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
} else {
|
||||
if dashIdx > 1 {
|
||||
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||
}
|
||||
if dashIdx == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
passArgs = args[dashIdx:]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
if cmd.Flags().Changed("model") || cmd.Flags().Changed("config") || len(passArgs) > 0 {
|
||||
return fmt.Errorf("flags and extra args require an integration name, for example: 'ollama launch claude --model qwen3.5'")
|
||||
}
|
||||
runTUI(cmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
if modelFlag != "" && isCloudModelName(modelFlag) {
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled {
|
||||
fmt.Fprintf(os.Stderr, "Warning: ignoring --model %s because cloud is disabled\n", modelFlag)
|
||||
modelFlag = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := LaunchIntegration(cmd.Context(), IntegrationLaunchRequest{
|
||||
Name: name,
|
||||
ModelOverride: modelFlag,
|
||||
ForceConfigure: configFlag || modelFlag == "",
|
||||
ConfigureOnly: configFlag,
|
||||
ExtraArgs: passArgs,
|
||||
Policy: &policy,
|
||||
})
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||
cmd.Flags().BoolVarP(&yesFlag, "yes", "y", false, "Automatically answer yes to confirmation prompts")
|
||||
return cmd
|
||||
}
|
||||
|
||||
type launcherClient struct {
|
||||
apiClient *api.Client
|
||||
modelInventory []ModelInfo
|
||||
inventoryLoaded bool
|
||||
policy LaunchPolicy
|
||||
}
|
||||
|
||||
func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) {
|
||||
apiClient, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &launcherClient{
|
||||
apiClient: apiClient,
|
||||
policy: policy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BuildLauncherState returns the launch-owned root launcher menu snapshot.
|
||||
func BuildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||
launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return launchClient.buildLauncherState(ctx)
|
||||
}
|
||||
|
||||
// ResolveRunModel returns the model that should be used for interactive chat.
|
||||
func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||
launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession()))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return launchClient.resolveRunModel(ctx, req)
|
||||
}
|
||||
|
||||
// LaunchIntegration runs the canonical launcher flow for one integration.
|
||||
func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error {
|
||||
name, runner, err := LookupIntegration(req.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !req.ConfigureOnly {
|
||||
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
policy := defaultLaunchPolicy(isInteractiveSession())
|
||||
if req.Policy != nil {
|
||||
policy = *req.Policy
|
||||
}
|
||||
|
||||
launchClient, err := newLauncherClient(policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
saved, _ := loadStoredIntegrationConfig(name)
|
||||
|
||||
if editor, ok := runner.(Editor); ok {
|
||||
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
|
||||
}
|
||||
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
|
||||
}
|
||||
|
||||
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||
_ = c.loadModelInventoryOnce(ctx)
|
||||
|
||||
state := &LauncherState{
|
||||
LastSelection: config.LastSelection(),
|
||||
RunModel: config.LastModel(),
|
||||
Integrations: make(map[string]LauncherIntegrationState),
|
||||
}
|
||||
runModelUsable, err := c.savedModelUsable(ctx, state.RunModel)
|
||||
if err != nil {
|
||||
runModelUsable = false
|
||||
}
|
||||
state.RunModelUsable = runModelUsable
|
||||
|
||||
for _, info := range ListIntegrationInfos() {
|
||||
integrationState, err := c.buildLauncherIntegrationState(ctx, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state.Integrations[info.Name] = integrationState
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info IntegrationInfo) (LauncherIntegrationState, error) {
|
||||
integration, err := integrationFor(info.Name)
|
||||
if err != nil {
|
||||
return LauncherIntegrationState{}, err
|
||||
}
|
||||
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
|
||||
if err != nil {
|
||||
return LauncherIntegrationState{}, err
|
||||
}
|
||||
|
||||
return LauncherIntegrationState{
|
||||
Name: info.Name,
|
||||
DisplayName: info.DisplayName,
|
||||
Description: info.Description,
|
||||
Installed: integration.installed,
|
||||
AutoInstallable: integration.autoInstallable,
|
||||
Selectable: integration.installed || integration.autoInstallable,
|
||||
Changeable: integration.installed || integration.autoInstallable,
|
||||
CurrentModel: currentModel,
|
||||
ModelUsable: usable,
|
||||
InstallHint: integration.installHint,
|
||||
Editor: integration.editor,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) launcherModelState(ctx context.Context, name string, isEditor bool) (string, bool, error) {
|
||||
cfg, loadErr := loadStoredIntegrationConfig(name)
|
||||
hasModels := loadErr == nil && len(cfg.Models) > 0
|
||||
if !hasModels {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
if isEditor {
|
||||
filtered := c.filterDisabledCloudModels(ctx, cfg.Models)
|
||||
if len(filtered) > 0 {
|
||||
return filtered[0], true, nil
|
||||
}
|
||||
return cfg.Models[0], false, nil
|
||||
}
|
||||
|
||||
model := cfg.Models[0]
|
||||
usable, usableErr := c.savedModelUsable(ctx, model)
|
||||
return model, usableErr == nil && usable, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||
current := config.LastModel()
|
||||
if !req.ForcePicker {
|
||||
usable, err := c.savedModelUsable(ctx, current)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if usable {
|
||||
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := config.SetLastModel(current); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
}
|
||||
|
||||
model, err := c.selectSingleModelWithSelector(ctx, "Select model to run:", current, DefaultSingleSelector)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := config.SetLastModel(model); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||
current := primaryModelFromConfig(saved)
|
||||
target := req.ModelOverride
|
||||
needsConfigure := req.ForceConfigure
|
||||
|
||||
if target == "" {
|
||||
target = current
|
||||
usable, err := c.savedModelUsable(ctx, target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !usable {
|
||||
needsConfigure = true
|
||||
}
|
||||
}
|
||||
|
||||
if needsConfigure {
|
||||
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target = selected
|
||||
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if target == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
|
||||
return launchAfterConfiguration(name, runner, target, req)
|
||||
}
|
||||
|
||||
func (c *launcherClient) launchEditorIntegration(ctx context.Context, name string, runner Runner, editor Editor, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||
models, needsConfigure := c.resolveEditorLaunchModels(ctx, saved, req)
|
||||
|
||||
if needsConfigure {
|
||||
selected, err := c.selectMultiModelsForIntegration(ctx, runner, models)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
models = selected
|
||||
} else if err := c.ensureModelsReady(ctx, models); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if needsConfigure || req.ModelOverride != "" {
|
||||
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return launchAfterConfiguration(name, runner, models[0], req)
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
||||
if selector == nil {
|
||||
return "", fmt.Errorf("no selector configured")
|
||||
}
|
||||
|
||||
items, _, err := c.loadSelectableModels(ctx, singleModelPrechecked(current), current, "no models available, run 'ollama pull <model>' first")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
selected, err := selector(title, items, current)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
|
||||
if DefaultMultiSelector == nil {
|
||||
return nil, fmt.Errorf("no selector configured")
|
||||
}
|
||||
|
||||
items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, firstModel(preChecked), "no models available")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.ensureModelsReady(ctx, selected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
|
||||
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||
items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current)
|
||||
if cloudDisabled {
|
||||
items = filterCloudItems(items)
|
||||
orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil, nil, errors.New(emptyMessage)
|
||||
}
|
||||
return items, orderedChecked, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
|
||||
var deduped []string
|
||||
seen := make(map[string]bool, len(models))
|
||||
for _, model := range models {
|
||||
if model == "" || seen[model] {
|
||||
continue
|
||||
}
|
||||
seen[model] = true
|
||||
deduped = append(deduped, model)
|
||||
}
|
||||
models = deduped
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloudModels := make(map[string]bool, len(models))
|
||||
for _, model := range models {
|
||||
isCloudModel := isCloudModelName(model)
|
||||
if isCloudModel {
|
||||
cloudModels[model] = true
|
||||
}
|
||||
if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return ensureAuth(ctx, c.apiClient, cloudModels, models)
|
||||
}
|
||||
|
||||
func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) {
|
||||
if req.ForceConfigure {
|
||||
return editorPreCheckedModels(saved, req.ModelOverride), true
|
||||
}
|
||||
|
||||
if req.ModelOverride != "" {
|
||||
models := append([]string{req.ModelOverride}, additionalSavedModels(saved, req.ModelOverride)...)
|
||||
models = c.filterDisabledCloudModels(ctx, models)
|
||||
return models, len(models) == 0
|
||||
}
|
||||
|
||||
if saved == nil || len(saved.Models) == 0 {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
models := c.filterDisabledCloudModels(ctx, saved.Models)
|
||||
return models, len(models) == 0
|
||||
}
|
||||
|
||||
func (c *launcherClient) filterDisabledCloudModels(ctx context.Context, models []string) []string {
|
||||
// if connection cannot be established or there is a 404, cloud models will continue to be displayed
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||
if !cloudDisabled {
|
||||
return append([]string(nil), models...)
|
||||
}
|
||||
|
||||
filtered := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
if !isCloudModelName(model) {
|
||||
filtered = append(filtered, model)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (c *launcherClient) savedModelUsable(ctx context.Context, name string) (bool, error) {
|
||||
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||
return c.showBasedModelUsable(ctx, name)
|
||||
}
|
||||
return c.singleModelUsable(ctx, name), nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) showBasedModelUsable(ctx context.Context, name string) (bool, error) {
|
||||
if name == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
info, err := c.apiClient.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
if isCloudModelName(name) || info.RemoteModel != "" {
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||
|
||||
return !cloudDisabled, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) singleModelUsable(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if isCloudModelName(name) {
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||
return !cloudDisabled
|
||||
}
|
||||
return c.hasLocalModel(name)
|
||||
}
|
||||
|
||||
func (c *launcherClient) hasLocalModel(name string) bool {
|
||||
for _, model := range c.modelInventory {
|
||||
if model.Remote {
|
||||
continue
|
||||
}
|
||||
if model.Name == name || strings.HasPrefix(model.Name, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *launcherClient) loadModelInventoryOnce(ctx context.Context) error {
|
||||
if c.inventoryLoaded {
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := c.apiClient.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.modelInventory = c.modelInventory[:0]
|
||||
for _, model := range resp.Models {
|
||||
c.modelInventory = append(c.modelInventory, ModelInfo{
|
||||
Name: model.Name,
|
||||
Remote: model.RemoteModel != "",
|
||||
})
|
||||
}
|
||||
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||
if cloudDisabled {
|
||||
c.modelInventory = filterCloudModels(c.modelInventory)
|
||||
}
|
||||
c.inventoryLoaded = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func runIntegration(runner Runner, modelName string, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", runner, modelName)
|
||||
return runner.Run(modelName, args)
|
||||
}
|
||||
|
||||
func launchAfterConfiguration(name string, runner Runner, model string, req IntegrationLaunchRequest) error {
|
||||
if req.ConfigureOnly {
|
||||
launch, err := ConfirmPrompt(fmt.Sprintf("Launch %s now?", runner))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !launch {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||
return err
|
||||
}
|
||||
return runIntegration(runner, model, req.ExtraArgs)
|
||||
}
|
||||
|
||||
func loadStoredIntegrationConfig(name string) (*config.IntegrationConfig, error) {
|
||||
cfg, err := config.LoadIntegration(name)
|
||||
if err == nil {
|
||||
return cfg, nil
|
||||
}
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
spec, specErr := LookupIntegrationSpec(name)
|
||||
if specErr != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, alias := range spec.Aliases {
|
||||
legacy, legacyErr := config.LoadIntegration(alias)
|
||||
if legacyErr == nil {
|
||||
migrateLegacyIntegrationConfig(spec.Name, legacy)
|
||||
if migrated, migratedErr := config.LoadIntegration(spec.Name); migratedErr == nil {
|
||||
return migrated, nil
|
||||
}
|
||||
return legacy, nil
|
||||
}
|
||||
if legacyErr != nil && !errors.Is(legacyErr, os.ErrNotExist) {
|
||||
return nil, legacyErr
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func migrateLegacyIntegrationConfig(canonical string, legacy *config.IntegrationConfig) {
|
||||
if legacy == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_ = config.SaveIntegration(canonical, append([]string(nil), legacy.Models...))
|
||||
if len(legacy.Aliases) > 0 {
|
||||
_ = config.SaveAliases(canonical, cloneAliases(legacy.Aliases))
|
||||
}
|
||||
if legacy.Onboarded {
|
||||
_ = config.MarkIntegrationOnboarded(canonical)
|
||||
}
|
||||
}
|
||||
|
||||
func primaryModelFromConfig(cfg *config.IntegrationConfig) string {
|
||||
if cfg == nil || len(cfg.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return cfg.Models[0]
|
||||
}
|
||||
|
||||
func cloneAliases(aliases map[string]string) map[string]string {
|
||||
if len(aliases) == 0 {
|
||||
return make(map[string]string)
|
||||
}
|
||||
|
||||
cloned := make(map[string]string, len(aliases))
|
||||
for key, value := range aliases {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func singleModelPrechecked(current string) []string {
|
||||
if current == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{current}
|
||||
}
|
||||
|
||||
func firstModel(models []string) string {
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return models[0]
|
||||
}
|
||||
|
||||
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
|
||||
if override == "" {
|
||||
if saved == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]string(nil), saved.Models...)
|
||||
}
|
||||
return append([]string{override}, additionalSavedModels(saved, override)...)
|
||||
}
|
||||
|
||||
func additionalSavedModels(saved *config.IntegrationConfig, exclude string) []string {
|
||||
if saved == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var models []string
|
||||
for _, model := range saved.Models {
|
||||
if model != exclude {
|
||||
models = append(models, model)
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
1210
cmd/launch/launch_test.go
Normal file
1210
cmd/launch/launch_test.go
Normal file
File diff suppressed because it is too large
Load Diff
477
cmd/launch/models.go
Normal file
477
cmd/launch/models.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/progress"
|
||||
)
|
||||
|
||||
var recommendedModels = []ModelItem{
|
||||
{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
|
||||
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
|
||||
{Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
|
||||
{Name: "minimax-m2.5:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
|
||||
{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
|
||||
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
|
||||
}
|
||||
|
||||
var recommendedVRAM = map[string]string{
|
||||
"glm-4.7-flash": "~25GB",
|
||||
"qwen3.5": "~11GB",
|
||||
}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// cloudModelLimits maps cloud model base names to their token limits.
|
||||
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||
var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"minimax-m2.5": {Context: 204_800, Output: 128_000},
|
||||
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||
"glm-5": {Context: 202_752, Output: 131_072},
|
||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2.5": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
|
||||
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
|
||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||
"qwen3.5": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
base, stripped := modelref.StripCloudSourceTag(name)
|
||||
if stripped {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
// missingModelPolicy controls how model-not-found errors should be handled.
|
||||
type missingModelPolicy int
|
||||
|
||||
const (
|
||||
// missingModelPromptPull prompts the user to download missing local models.
|
||||
missingModelPromptPull missingModelPolicy = iota
|
||||
// missingModelAutoPull downloads missing local models without prompting.
|
||||
missingModelAutoPull
|
||||
// missingModelFail returns an error for missing local models without prompting.
|
||||
missingModelFail
|
||||
)
|
||||
|
||||
// OpenBrowser opens the URL in the user's browser.
|
||||
func OpenBrowser(url string) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", url).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||
}
|
||||
}
|
||||
|
||||
// ensureAuth ensures the user is signed in before cloud-backed models run.
|
||||
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||
}
|
||||
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return err
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
|
||||
if DefaultSignIn != nil {
|
||||
_, err := DefaultSignIn(modelList, aErr.SigninURL)
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return ErrCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return ErrCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !yes {
|
||||
return ErrCancelled
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
OpenBrowser(aErr.SigninURL)
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
|
||||
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if isCloudModel {
|
||||
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||
}
|
||||
return fmt.Errorf("model %q not found", model)
|
||||
}
|
||||
|
||||
switch policy {
|
||||
case missingModelAutoPull:
|
||||
return pullMissingModel(ctx, client, model)
|
||||
case missingModelFail:
|
||||
return fmt.Errorf("model %q not found; run 'ollama pull %s' first", model, model)
|
||||
default:
|
||||
return confirmAndPull(ctx, client, model)
|
||||
}
|
||||
}
|
||||
|
||||
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return pullMissingModel(ctx, client, model)
|
||||
}
|
||||
|
||||
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
|
||||
if err := pullModel(ctx, client, model, false); err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareEditorIntegration persists models and applies editor-managed config files.
|
||||
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
||||
if ok, err := confirmEditorEdit(runner, editor); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
if err := editor.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
if err := config.SaveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
|
||||
paths := editor.Paths()
|
||||
if len(paths) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner)
|
||||
for _, path := range paths {
|
||||
fmt.Fprintf(os.Stderr, " %s\n", path)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", fileutil.BackupDir())
|
||||
|
||||
return ConfirmPrompt("Proceed?")
|
||||
}
|
||||
|
||||
// buildModelList merges existing models with recommendations for selection UIs.
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
existingModels = make(map[string]bool)
|
||||
cloudModels = make(map[string]bool)
|
||||
recommended := make(map[string]bool)
|
||||
var hasLocalModel, hasCloudModel bool
|
||||
|
||||
recDesc := make(map[string]string)
|
||||
for _, rec := range recommendedModels {
|
||||
recommended[rec.Name] = true
|
||||
recDesc[rec.Name] = rec.Description
|
||||
}
|
||||
|
||||
for _, m := range existing {
|
||||
existingModels[m.Name] = true
|
||||
if m.Remote {
|
||||
cloudModels[m.Name] = true
|
||||
hasCloudModel = true
|
||||
} else {
|
||||
hasLocalModel = true
|
||||
}
|
||||
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||
existingModels[displayName] = true
|
||||
item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if isCloudModelName(rec.Name) {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
notInstalled := make(map[string]bool)
|
||||
for i := range items {
|
||||
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
|
||||
notInstalled[items[i].Name] = true
|
||||
var parts []string
|
||||
if items[i].Description != "" {
|
||||
parts = append(parts, items[i].Description)
|
||||
}
|
||||
if vram := recommendedVRAM[items[i].Name]; vram != "" {
|
||||
parts = append(parts, vram)
|
||||
}
|
||||
parts = append(parts, "(not downloaded)")
|
||||
items[i].Description = strings.Join(parts, ", ")
|
||||
}
|
||||
}
|
||||
|
||||
recRank := make(map[string]int)
|
||||
for i, rec := range recommendedModels {
|
||||
recRank[rec.Name] = i + 1
|
||||
}
|
||||
|
||||
onlyLocal := hasLocalModel && !hasCloudModel
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
|
||||
aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
|
||||
aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]
|
||||
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if aRec != bRec {
|
||||
if aRec {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if aRec && bRec {
|
||||
if aCloud != bCloud {
|
||||
if onlyLocal {
|
||||
if aCloud {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
if aCloud {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return recRank[a.Name] - recRank[b.Name]
|
||||
}
|
||||
if aNew != bNew {
|
||||
if aNew {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
}
|
||||
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
// isCloudModelName reports whether the model name has an explicit cloud source.
|
||||
func isCloudModelName(name string) bool {
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
// filterCloudModels drops remote-only models from the given inventory.
|
||||
func filterCloudModels(existing []modelInfo) []modelInfo {
|
||||
filtered := existing[:0]
|
||||
for _, m := range existing {
|
||||
if !m.Remote {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// filterCloudItems removes cloud models from selection items.
|
||||
func filterCloudItems(items []ModelItem) []ModelItem {
|
||||
filtered := items[:0]
|
||||
for _, item := range items {
|
||||
if !isCloudModelName(item.Name) {
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return resp.RemoteModel != ""
|
||||
}
|
||||
|
||||
// cloudStatusDisabled returns whether cloud usage is currently disabled.
|
||||
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
|
||||
status, err := client.CloudStatusExperimental(ctx)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||
return false, false
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
return status.Cloud.Disabled, true
|
||||
}
|
||||
|
||||
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
|
||||
// Move the shared pull rendering to a small utility once the package boundary settles.
|
||||
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: model, Insecure: insecure}
|
||||
return client.Pull(ctx, &request, fn)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -35,7 +37,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
}
|
||||
|
||||
firstLaunch := true
|
||||
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
|
||||
if integrationConfig, err := loadStoredIntegrationConfig("openclaw"); err == nil {
|
||||
firstLaunch = !integrationConfig.Onboarded
|
||||
}
|
||||
|
||||
@@ -45,7 +47,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
||||
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
ok, err := confirmPrompt("I understand the risks. Continue?")
|
||||
ok, err := ConfirmPrompt("I understand the risks. Continue?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -107,7 +109,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
return windowsHint(err)
|
||||
}
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
if err := config.MarkIntegrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -166,7 +168,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
if err := config.MarkIntegrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -426,7 +428,7 @@ func ensureOpenclawInstalled() (string, error) {
|
||||
"and select OpenClaw")
|
||||
}
|
||||
|
||||
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -561,7 +563,7 @@ func (c *Openclaw) Edit(models []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, data); err != nil {
|
||||
if err := fileutil.WriteWithBackup(configPath, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -776,9 +778,9 @@ func (c *Openclaw) Models() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||
if err != nil {
|
||||
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -116,9 +116,9 @@ func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration() error = %v", err)
|
||||
t.Fatalf("LoadIntegration() error = %v", err)
|
||||
}
|
||||
if !integrationConfig.Onboarded {
|
||||
t.Fatal("expected onboarding flag to be persisted after successful run")
|
||||
@@ -147,7 +147,7 @@ func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
|
||||
t.Fatal("expected run failure")
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err == nil && integrationConfig.Onboarded {
|
||||
t.Fatal("expected onboarding flag to remain unset after failed run")
|
||||
}
|
||||
@@ -1528,7 +1528,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err == nil && integrationConfig.Onboarded {
|
||||
t.Error("expected false for fresh config")
|
||||
}
|
||||
@@ -1542,7 +1542,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true after integrationOnboarded")
|
||||
}
|
||||
@@ -1556,7 +1556,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
if err := integrationOnboarded("OpenClaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true when set with different case")
|
||||
}
|
||||
@@ -1575,7 +1575,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify onboarded is set
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true after integrationOnboarded")
|
||||
}
|
||||
@@ -1,9 +1,7 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
@@ -12,31 +10,13 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
base, stripped := modelref.StripCloudSourceTag(name)
|
||||
if stripped {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string, args []string) error {
|
||||
@@ -44,25 +24,6 @@ func (o *OpenCode) Run(model string, args []string) error {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "opencode", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := o.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -191,7 +152,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -243,7 +204,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(statePath, stateData)
|
||||
return fileutil.WriteWithBackup(statePath, stateData)
|
||||
}
|
||||
|
||||
func (o *OpenCode) Models() []string {
|
||||
@@ -251,7 +212,7 @@ func (o *OpenCode) Models() []string {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -26,15 +27,6 @@ func (p *Pi) Run(model string, args []string) error {
|
||||
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := p.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("pi", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -149,7 +141,7 @@ func (p *Pi) Edit(models []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -167,7 +159,7 @@ func (p *Pi) Edit(models []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, settingsData)
|
||||
return fileutil.WriteWithBackup(settingsPath, settingsData)
|
||||
}
|
||||
|
||||
func (p *Pi) Models() []string {
|
||||
@@ -177,7 +169,7 @@ func (p *Pi) Models() []string {
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
config, err := readJSONFile(configPath)
|
||||
config, err := fileutil.ReadJSON(configPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -229,8 +221,15 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
|
||||
cfg["contextWindow"] = l.Context
|
||||
}
|
||||
|
||||
applyCloudContextFallback := func() {
|
||||
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||
cfg["contextWindow"] = l.Context
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||
if err != nil {
|
||||
applyCloudContextFallback()
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -248,14 +247,19 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
|
||||
|
||||
// Extract context window from ModelInfo. For known cloud models, the
|
||||
// pre-filled shared limit remains unless the server provides a positive value.
|
||||
hasContextWindow := false
|
||||
for key, val := range resp.ModelInfo {
|
||||
if strings.HasSuffix(key, ".context_length") {
|
||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||
cfg["contextWindow"] = int(ctxLen)
|
||||
hasContextWindow = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasContextWindow {
|
||||
applyCloudContextFallback()
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -840,7 +840,7 @@ func TestCreateConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to cloud context when show fails", func(t *testing.T) {
|
||||
t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
@@ -857,7 +857,7 @@ func TestCreateConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to cloud context when model info is empty", func(t *testing.T) {
|
||||
t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||
@@ -877,7 +877,7 @@ func TestCreateConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to cloud context for dash cloud suffix", func(t *testing.T) {
|
||||
t.Run("cloud model with dash suffix falls back to hardcoded context", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
@@ -893,7 +893,6 @@ func TestCreateConfig(t *testing.T) {
|
||||
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips zero context length", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
355
cmd/launch/registry.go
Normal file
355
cmd/launch/registry.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IntegrationInstallSpec describes how launcher should detect and guide installation.
|
||||
type IntegrationInstallSpec struct {
|
||||
CheckInstalled func() bool
|
||||
EnsureInstalled func() error
|
||||
URL string
|
||||
Command []string
|
||||
}
|
||||
|
||||
// IntegrationSpec is the canonical registry entry for one integration.
|
||||
type IntegrationSpec struct {
|
||||
Name string
|
||||
Runner Runner
|
||||
Aliases []string
|
||||
Hidden bool
|
||||
Description string
|
||||
Install IntegrationInstallSpec
|
||||
}
|
||||
|
||||
// IntegrationInfo contains display information about a registered integration.
|
||||
type IntegrationInfo struct {
|
||||
Name string
|
||||
DisplayName string
|
||||
Description string
|
||||
}
|
||||
|
||||
var launcherIntegrationOrder = []string{"opencode", "droid", "pi", "cline"}
|
||||
|
||||
var integrationSpecs = []*IntegrationSpec{
|
||||
{
|
||||
Name: "claude",
|
||||
Runner: &Claude{},
|
||||
Description: "Anthropic's coding tool with subagents",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := (&Claude{}).findPath()
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://code.claude.com/docs/en/quickstart",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "cline",
|
||||
Runner: &Cline{},
|
||||
Description: "Autonomous coding agent with parallel execution",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("cline")
|
||||
return err == nil
|
||||
},
|
||||
Command: []string{"npm", "install", "-g", "cline"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "codex",
|
||||
Runner: &Codex{},
|
||||
Description: "OpenAI's open-source coding agent",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("codex")
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://developers.openai.com/codex/cli/",
|
||||
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "droid",
|
||||
Runner: &Droid{},
|
||||
Description: "Factory's coding agent across terminal and IDEs",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("droid")
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "opencode",
|
||||
Runner: &OpenCode{},
|
||||
Description: "Anomaly's open-source coding agent",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("opencode")
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://opencode.ai",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "openclaw",
|
||||
Runner: &Openclaw{},
|
||||
Aliases: []string{"clawdbot", "moltbot"},
|
||||
Description: "Personal AI with 100+ skills",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||
return true
|
||||
}
|
||||
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
EnsureInstalled: func() error {
|
||||
_, err := ensureOpenclawInstalled()
|
||||
return err
|
||||
},
|
||||
URL: "https://docs.openclaw.ai",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "pi",
|
||||
Runner: &Pi{},
|
||||
Description: "Minimal AI agent toolkit with plugin support",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("pi")
|
||||
return err == nil
|
||||
},
|
||||
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var integrationSpecsByName map[string]*IntegrationSpec
|
||||
|
||||
func init() {
|
||||
rebuildIntegrationSpecIndexes()
|
||||
}
|
||||
|
||||
func hyperlink(url, text string) string {
|
||||
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
|
||||
}
|
||||
|
||||
func rebuildIntegrationSpecIndexes() {
|
||||
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
|
||||
|
||||
canonical := make(map[string]bool, len(integrationSpecs))
|
||||
for _, spec := range integrationSpecs {
|
||||
key := strings.ToLower(spec.Name)
|
||||
if key == "" {
|
||||
panic("launch: integration spec missing name")
|
||||
}
|
||||
if canonical[key] {
|
||||
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
|
||||
}
|
||||
canonical[key] = true
|
||||
integrationSpecsByName[key] = spec
|
||||
}
|
||||
|
||||
seenAliases := make(map[string]string)
|
||||
for _, spec := range integrationSpecs {
|
||||
for _, alias := range spec.Aliases {
|
||||
key := strings.ToLower(alias)
|
||||
if key == "" {
|
||||
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
|
||||
}
|
||||
if canonical[key] {
|
||||
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
|
||||
}
|
||||
if owner, exists := seenAliases[key]; exists {
|
||||
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
|
||||
}
|
||||
seenAliases[key] = spec.Name
|
||||
integrationSpecsByName[key] = spec
|
||||
}
|
||||
}
|
||||
|
||||
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
|
||||
for _, name := range launcherIntegrationOrder {
|
||||
key := strings.ToLower(name)
|
||||
if orderSeen[key] {
|
||||
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
|
||||
}
|
||||
orderSeen[key] = true
|
||||
|
||||
spec, ok := integrationSpecsByName[key]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
|
||||
}
|
||||
if spec.Name != key {
|
||||
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
|
||||
}
|
||||
if spec.Hidden {
|
||||
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
|
||||
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
|
||||
spec, ok := integrationSpecsByName[strings.ToLower(name)]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
// LookupIntegration resolves a registry name to the canonical key and runner.
|
||||
func LookupIntegration(name string) (string, Runner, error) {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return spec.Name, spec.Runner, nil
|
||||
}
|
||||
|
||||
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
|
||||
func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
||||
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
|
||||
for _, spec := range integrationSpecs {
|
||||
if spec.Hidden {
|
||||
continue
|
||||
}
|
||||
visible = append(visible, *spec)
|
||||
}
|
||||
|
||||
orderRank := make(map[string]int, len(launcherIntegrationOrder))
|
||||
for i, name := range launcherIntegrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
}
|
||||
|
||||
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
|
||||
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||
if aRank > 0 && bRank > 0 {
|
||||
return aRank - bRank
|
||||
}
|
||||
if aRank > 0 {
|
||||
return 1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
return visible
|
||||
}
|
||||
|
||||
// ListIntegrationInfos returns the registered integrations in launcher display order.
|
||||
func ListIntegrationInfos() []IntegrationInfo {
|
||||
visible := ListVisibleIntegrationSpecs()
|
||||
infos := make([]IntegrationInfo, 0, len(visible))
|
||||
for _, spec := range visible {
|
||||
infos = append(infos, IntegrationInfo{
|
||||
Name: spec.Name,
|
||||
DisplayName: spec.Runner.String(),
|
||||
Description: spec.Description,
|
||||
})
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
|
||||
func IntegrationSelectionItems() ([]ModelItem, error) {
|
||||
visible := ListVisibleIntegrationSpecs()
|
||||
if len(visible) == 0 {
|
||||
return nil, fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
items := make([]ModelItem, 0, len(visible))
|
||||
for _, spec := range visible {
|
||||
description := spec.Runner.String()
|
||||
if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
|
||||
}
|
||||
items = append(items, ModelItem{Name: spec.Name, Description: description})
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// IsIntegrationInstalled checks if an integration binary is installed.
|
||||
func IsIntegrationInstalled(name string) bool {
|
||||
integration, err := integrationFor(name)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
|
||||
return false
|
||||
}
|
||||
return integration.installed
|
||||
}
|
||||
|
||||
// integration is resolved registry metadata used by launcher state and install checks.
|
||||
// It combines immutable registry spec data with computed runtime traits.
|
||||
type integration struct {
|
||||
spec *IntegrationSpec
|
||||
installed bool
|
||||
autoInstallable bool
|
||||
editor bool
|
||||
installHint string
|
||||
}
|
||||
|
||||
// integrationFor resolves an integration name into the canonical spec plus
|
||||
// derived launcher/install traits used across registry and launch flows.
|
||||
func integrationFor(name string) (integration, error) {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return integration{}, err
|
||||
}
|
||||
|
||||
installed := true
|
||||
if spec.Install.CheckInstalled != nil {
|
||||
installed = spec.Install.CheckInstalled()
|
||||
}
|
||||
|
||||
_, editor := spec.Runner.(Editor)
|
||||
hint := ""
|
||||
if spec.Install.URL != "" {
|
||||
hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
|
||||
} else if len(spec.Install.Command) > 0 {
|
||||
hint = "Install with: " + strings.Join(spec.Install.Command, " ")
|
||||
}
|
||||
|
||||
return integration{
|
||||
spec: spec,
|
||||
installed: installed,
|
||||
autoInstallable: spec.Install.EnsureInstalled != nil,
|
||||
editor: editor,
|
||||
installHint: hint,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
|
||||
func EnsureIntegrationInstalled(name string, runner Runner) error {
|
||||
integration, err := integrationFor(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s is not installed", runner)
|
||||
}
|
||||
|
||||
if integration.installed {
|
||||
return nil
|
||||
}
|
||||
if integration.autoInstallable {
|
||||
return integration.spec.Install.EnsureInstalled()
|
||||
}
|
||||
|
||||
switch {
|
||||
case integration.spec.Install.URL != "":
|
||||
return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL)
|
||||
case len(integration.spec.Install.Command) > 0:
|
||||
return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " "))
|
||||
default:
|
||||
return fmt.Errorf("%s is not installed", runner)
|
||||
}
|
||||
}
|
||||
21
cmd/launch/registry_test_helpers_test.go
Normal file
21
cmd/launch/registry_test_helpers_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package launch
|
||||
|
||||
import "strings"
|
||||
|
||||
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
|
||||
func OverrideIntegration(name string, runner Runner) func() {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
key := strings.ToLower(name)
|
||||
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
|
||||
return func() {
|
||||
delete(integrationSpecsByName, key)
|
||||
}
|
||||
}
|
||||
|
||||
original := spec.Runner
|
||||
spec.Runner = runner
|
||||
return func() {
|
||||
spec.Runner = original
|
||||
}
|
||||
}
|
||||
68
cmd/launch/runner_exec_only_test.go
Normal file
68
cmd/launch/runner_exec_only_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
binary string
|
||||
runner Runner
|
||||
checkPath func(home string) string
|
||||
}{
|
||||
{
|
||||
name: "droid",
|
||||
binary: "droid",
|
||||
runner: &Droid{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".factory", "settings.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "opencode",
|
||||
binary: "opencode",
|
||||
runner: &OpenCode{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cline",
|
||||
binary: "cline",
|
||||
runner: &Cline{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pi",
|
||||
binary: "pi",
|
||||
runner: &Pi{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".pi", "agent", "models.json")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
setTestHome(t, home)
|
||||
|
||||
binDir := t.TempDir()
|
||||
writeFakeBinary(t, binDir, tt.binary)
|
||||
t.Setenv("PATH", binDir)
|
||||
|
||||
configPath := tt.checkPath(home)
|
||||
if err := tt.runner.Run("llama3.2", nil); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
111
cmd/launch/selector_hooks.go
Normal file
111
cmd/launch/selector_hooks.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiYellow = "\033[33m"
|
||||
)
|
||||
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
// errCancelled is kept as an internal alias for existing call sites.
|
||||
var errCancelled = ErrCancelled
|
||||
|
||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
|
||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
||||
|
||||
// SingleSelector is a function type for single item selection.
|
||||
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
|
||||
|
||||
// MultiSelector is a function type for multi item selection.
|
||||
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||
|
||||
// DefaultSingleSelector is the default single-select implementation.
|
||||
var DefaultSingleSelector SingleSelector
|
||||
|
||||
// DefaultMultiSelector is the default multi-select implementation.
|
||||
var DefaultMultiSelector MultiSelector
|
||||
|
||||
// DefaultSignIn provides a TUI-based sign-in flow.
|
||||
// When set, ensureAuth uses it instead of plain text prompts.
|
||||
// Returns the signed-in username or an error.
|
||||
var DefaultSignIn func(modelName, signInURL string) (string, error)
|
||||
|
||||
type launchConfirmPolicy struct {
|
||||
yes bool
|
||||
requireYesMessage bool
|
||||
}
|
||||
|
||||
var currentLaunchConfirmPolicy launchConfirmPolicy
|
||||
|
||||
func (p launchConfirmPolicy) chain(next launchConfirmPolicy) launchConfirmPolicy {
|
||||
chained := launchConfirmPolicy{
|
||||
yes: p.yes || next.yes,
|
||||
requireYesMessage: p.requireYesMessage || next.requireYesMessage,
|
||||
}
|
||||
if chained.yes {
|
||||
chained.requireYesMessage = false
|
||||
}
|
||||
return chained
|
||||
}
|
||||
|
||||
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
|
||||
old := currentLaunchConfirmPolicy
|
||||
currentLaunchConfirmPolicy = old.chain(policy)
|
||||
return func() {
|
||||
currentLaunchConfirmPolicy = old
|
||||
}
|
||||
}
|
||||
|
||||
// ConfirmPrompt asks the user to confirm an action using the configured prompt hook.
|
||||
func ConfirmPrompt(prompt string) (bool, error) {
|
||||
if currentLaunchConfirmPolicy.yes {
|
||||
return true, nil
|
||||
}
|
||||
if currentLaunchConfirmPolicy.requireYesMessage {
|
||||
return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
|
||||
}
|
||||
|
||||
if DefaultConfirmPrompt != nil {
|
||||
return DefaultConfirmPrompt(prompt)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
76
cmd/launch/selector_test.go
Normal file
76
cmd/launch/selector_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
t.Error("errCancelled should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Message", func(t *testing.T) {
|
||||
if errCancelled.Error() != "cancelled" {
|
||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithLaunchConfirmPolicy_ChainsAndRestores(t *testing.T) {
|
||||
oldPolicy := currentLaunchConfirmPolicy
|
||||
oldHook := DefaultConfirmPrompt
|
||||
t.Cleanup(func() {
|
||||
currentLaunchConfirmPolicy = oldPolicy
|
||||
DefaultConfirmPrompt = oldHook
|
||||
})
|
||||
|
||||
currentLaunchConfirmPolicy = launchConfirmPolicy{}
|
||||
var hookCalls int
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
hookCalls++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
restoreOuter := withLaunchConfirmPolicy(launchConfirmPolicy{requireYesMessage: true})
|
||||
restoreInner := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
|
||||
|
||||
ok, err := ConfirmPrompt("test prompt")
|
||||
if err != nil {
|
||||
t.Fatalf("expected --yes policy to allow prompt, got error: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected --yes policy to auto-accept prompt")
|
||||
}
|
||||
if hookCalls != 0 {
|
||||
t.Fatalf("expected --yes to skip hook, got %d hook calls", hookCalls)
|
||||
}
|
||||
|
||||
restoreInner()
|
||||
|
||||
_, err = ConfirmPrompt("test prompt")
|
||||
if err == nil {
|
||||
t.Fatal("expected requireYesMessage policy to block prompt")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "re-run with --yes") {
|
||||
t.Fatalf("expected actionable --yes error, got: %v", err)
|
||||
}
|
||||
if hookCalls != 0 {
|
||||
t.Fatalf("expected blocking policy to skip hook, got %d hook calls", hookCalls)
|
||||
}
|
||||
|
||||
restoreOuter()
|
||||
|
||||
ok, err = ConfirmPrompt("test prompt")
|
||||
if err != nil {
|
||||
t.Fatalf("expected restored default behavior to use hook, got error: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected hook to return true")
|
||||
}
|
||||
if hookCalls != 1 {
|
||||
t.Fatalf("expected one hook call after restore, got %d", hookCalls)
|
||||
}
|
||||
}
|
||||
82
cmd/launch/test_config_helpers_test.go
Normal file
82
cmd/launch/test_config_helpers_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
)
|
||||
|
||||
var (
|
||||
integrations map[string]Runner
|
||||
integrationAliases map[string]bool
|
||||
integrationOrder = launcherIntegrationOrder
|
||||
)
|
||||
|
||||
func init() {
|
||||
integrations = buildTestIntegrations()
|
||||
integrationAliases = buildTestIntegrationAliases()
|
||||
}
|
||||
|
||||
func buildTestIntegrations() map[string]Runner {
|
||||
result := make(map[string]Runner, len(integrationSpecsByName))
|
||||
for name, spec := range integrationSpecsByName {
|
||||
result[strings.ToLower(name)] = spec.Runner
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func buildTestIntegrationAliases() map[string]bool {
|
||||
result := make(map[string]bool)
|
||||
for _, spec := range integrationSpecs {
|
||||
for _, alias := range spec.Aliases {
|
||||
result[strings.ToLower(alias)] = true
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
setLaunchTestHome(t, dir)
|
||||
}
|
||||
|
||||
func SaveIntegration(appName string, models []string) error {
|
||||
return config.SaveIntegration(appName, models)
|
||||
}
|
||||
|
||||
func LoadIntegration(appName string) (*config.IntegrationConfig, error) {
|
||||
return config.LoadIntegration(appName)
|
||||
}
|
||||
|
||||
func SaveAliases(appName string, aliases map[string]string) error {
|
||||
return config.SaveAliases(appName, aliases)
|
||||
}
|
||||
|
||||
func LastModel() string {
|
||||
return config.LastModel()
|
||||
}
|
||||
|
||||
func SetLastModel(model string) error {
|
||||
return config.SetLastModel(model)
|
||||
}
|
||||
|
||||
func LastSelection() string {
|
||||
return config.LastSelection()
|
||||
}
|
||||
|
||||
func SetLastSelection(selection string) error {
|
||||
return config.SetLastSelection(selection)
|
||||
}
|
||||
|
||||
func IntegrationModel(appName string) string {
|
||||
return config.IntegrationModel(appName)
|
||||
}
|
||||
|
||||
func IntegrationModels(appName string) []string {
|
||||
return config.IntegrationModels(appName)
|
||||
}
|
||||
|
||||
func integrationOnboarded(appName string) error {
|
||||
return config.MarkIntegrationOnboarded(appName)
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -64,8 +64,8 @@ type SelectItem struct {
|
||||
Recommended bool
|
||||
}
|
||||
|
||||
// ConvertItems converts config.ModelItem slice to SelectItem slice.
|
||||
func ConvertItems(items []config.ModelItem) []SelectItem {
|
||||
// ConvertItems converts launch.ModelItem slice to SelectItem slice.
|
||||
func ConvertItems(items []launch.ModelItem) []SelectItem {
|
||||
out := make([]SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
@@ -101,6 +101,16 @@ type selectorModel struct {
|
||||
width int
|
||||
}
|
||||
|
||||
func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel {
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
cursor: cursorForCurrent(items, current),
|
||||
}
|
||||
m.updateScroll(m.otherStart())
|
||||
return m
|
||||
}
|
||||
|
||||
func (m selectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
@@ -382,11 +392,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
cursor: cursorForCurrent(items, current),
|
||||
}
|
||||
m := selectorModelWithCurrent(title, items, current)
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
|
||||
@@ -216,6 +216,22 @@ func TestUpdateScroll(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectorModelWithCurrent_ScrollsToCurrentInMoreSection(t *testing.T) {
|
||||
m := selectorModelWithCurrent("Pick:", mixedItems(), "other-10")
|
||||
|
||||
if m.cursor != 11 {
|
||||
t.Fatalf("cursor = %d, want 11", m.cursor)
|
||||
}
|
||||
if m.scrollOffset == 0 {
|
||||
t.Fatal("scrollOffset should move to reveal current item in More section")
|
||||
}
|
||||
|
||||
content := m.renderContent()
|
||||
if !strings.Contains(content, "▸ other-10") {
|
||||
t.Fatalf("expected current item to be visible and highlighted\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_SectionHeaders(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
type signInTickMsg struct{}
|
||||
|
||||
type signInCheckMsg struct {
|
||||
signedIn bool
|
||||
userName string
|
||||
}
|
||||
|
||||
type signInModel struct {
|
||||
modelName string
|
||||
signInURL string
|
||||
@@ -104,9 +113,21 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
||||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||
}
|
||||
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
||||
func RunSignIn(modelName, signInURL string) (string, error) {
|
||||
config.OpenBrowser(signInURL)
|
||||
launch.OpenBrowser(signInURL)
|
||||
|
||||
m := signInModel{
|
||||
modelName: modelName,
|
||||
|
||||
846
cmd/tui/tui.go
846
cmd/tui/tui.go
@@ -1,17 +1,11 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -46,7 +40,7 @@ var (
|
||||
type menuItem struct {
|
||||
title string
|
||||
description string
|
||||
integration string // integration name for loading model config, empty if not an integration
|
||||
integration string
|
||||
isRunModel bool
|
||||
isOthers bool
|
||||
}
|
||||
@@ -58,18 +52,12 @@ var mainMenuItems = []menuItem{
|
||||
isRunModel: true,
|
||||
},
|
||||
{
|
||||
title: "Launch Claude Code",
|
||||
description: "Agentic coding across large codebases",
|
||||
integration: "claude",
|
||||
},
|
||||
{
|
||||
title: "Launch Codex",
|
||||
description: "OpenAI's open-source coding agent",
|
||||
integration: "codex",
|
||||
},
|
||||
{
|
||||
title: "Launch OpenClaw",
|
||||
description: "Personal AI with 100+ skills",
|
||||
integration: "openclaw",
|
||||
},
|
||||
}
|
||||
@@ -80,283 +68,106 @@ var othersMenuItem = menuItem{
|
||||
isOthers: true,
|
||||
}
|
||||
|
||||
// getOtherIntegrations dynamically builds the "Others" list from the integration
|
||||
// registry, excluding any integrations already present in the pinned mainMenuItems.
|
||||
func getOtherIntegrations() []menuItem {
|
||||
pinned := map[string]bool{
|
||||
"run": true, // not an integration but in the pinned list
|
||||
type model struct {
|
||||
state *launch.LauncherState
|
||||
items []menuItem
|
||||
cursor int
|
||||
showOthers bool
|
||||
width int
|
||||
quitting bool
|
||||
selected bool
|
||||
action TUIAction
|
||||
}
|
||||
|
||||
func newModel(state *launch.LauncherState) model {
|
||||
m := model{
|
||||
state: state,
|
||||
}
|
||||
m.showOthers = shouldExpandOthers(state)
|
||||
m.items = buildMenuItems(state, m.showOthers)
|
||||
m.cursor = initialCursor(state, m.items)
|
||||
return m
|
||||
}
|
||||
|
||||
func shouldExpandOthers(state *launch.LauncherState) bool {
|
||||
if state == nil {
|
||||
return false
|
||||
}
|
||||
for _, item := range otherIntegrationItems(state) {
|
||||
if item.integration == state.LastSelection {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
||||
items := make([]menuItem, 0, len(mainMenuItems)+1)
|
||||
for _, item := range mainMenuItems {
|
||||
if item.integration != "" {
|
||||
pinned[item.integration] = true
|
||||
if item.integration == "" {
|
||||
items = append(items, item)
|
||||
continue
|
||||
}
|
||||
if integrationState, ok := state.Integrations[item.integration]; ok {
|
||||
items = append(items, integrationMenuItem(integrationState))
|
||||
}
|
||||
}
|
||||
|
||||
var others []menuItem
|
||||
for _, info := range config.ListIntegrationInfos() {
|
||||
if showOthers {
|
||||
items = append(items, otherIntegrationItems(state)...)
|
||||
} else {
|
||||
items = append(items, othersMenuItem)
|
||||
}
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
|
||||
description := state.Description
|
||||
if description == "" {
|
||||
description = "Open " + state.DisplayName + " integration"
|
||||
}
|
||||
return menuItem{
|
||||
title: "Launch " + state.DisplayName,
|
||||
description: description,
|
||||
integration: state.Name,
|
||||
}
|
||||
}
|
||||
|
||||
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||
pinned := map[string]bool{
|
||||
"claude": true,
|
||||
"codex": true,
|
||||
"openclaw": true,
|
||||
}
|
||||
|
||||
var items []menuItem
|
||||
for _, info := range launch.ListIntegrationInfos() {
|
||||
if pinned[info.Name] {
|
||||
continue
|
||||
}
|
||||
desc := info.Description
|
||||
if desc == "" {
|
||||
desc = "Open " + info.DisplayName + " integration"
|
||||
}
|
||||
others = append(others, menuItem{
|
||||
title: "Launch " + info.DisplayName,
|
||||
description: desc,
|
||||
integration: info.Name,
|
||||
})
|
||||
}
|
||||
return others
|
||||
}
|
||||
|
||||
type model struct {
|
||||
items []menuItem
|
||||
cursor int
|
||||
quitting bool
|
||||
selected bool
|
||||
changeModel bool
|
||||
changeModels []string // multi-select result for Editor integrations
|
||||
showOthers bool
|
||||
availableModels map[string]bool
|
||||
err error
|
||||
|
||||
showingModal bool
|
||||
modalSelector selectorModel
|
||||
modalItems []SelectItem
|
||||
|
||||
showingMultiModal bool
|
||||
multiModalSelector multiSelectorModel
|
||||
|
||||
showingSignIn bool
|
||||
signInURL string
|
||||
signInModel string
|
||||
signInSpinner int
|
||||
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
|
||||
|
||||
width int // terminal width from WindowSizeMsg
|
||||
statusMsg string // temporary status message shown near help text
|
||||
}
|
||||
|
||||
type signInTickMsg struct{}
|
||||
|
||||
type signInCheckMsg struct {
|
||||
signedIn bool
|
||||
userName string
|
||||
}
|
||||
|
||||
type clearStatusMsg struct{}
|
||||
|
||||
func (m *model) modelExists(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if modelref.HasExplicitCloudSource(name) {
|
||||
return true
|
||||
}
|
||||
if m.availableModels == nil {
|
||||
return false
|
||||
}
|
||||
if m.availableModels[name] {
|
||||
return true
|
||||
}
|
||||
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
|
||||
for modelName := range m.availableModels {
|
||||
if strings.HasPrefix(modelName, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *model) buildModalItems() []SelectItem {
|
||||
modelItems, _ := config.GetModelItems(context.Background())
|
||||
return ReorderItems(ConvertItems(modelItems))
|
||||
}
|
||||
|
||||
func (m *model) openModelModal(currentModel string) {
|
||||
m.modalItems = m.buildModalItems()
|
||||
cursor := 0
|
||||
if currentModel != "" {
|
||||
for i, item := range m.modalItems {
|
||||
if item.Name == currentModel || strings.HasPrefix(item.Name, currentModel+":") || strings.HasPrefix(currentModel, item.Name+":") {
|
||||
cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
m.modalSelector = selectorModel{
|
||||
title: "Select model:",
|
||||
items: m.modalItems,
|
||||
cursor: cursor,
|
||||
helpText: "↑/↓ navigate • enter select • ← back",
|
||||
}
|
||||
m.modalSelector.updateScroll(m.modalSelector.otherStart())
|
||||
m.showingModal = true
|
||||
}
|
||||
|
||||
func (m *model) openMultiModelModal(integration string) {
|
||||
items := m.buildModalItems()
|
||||
var preChecked []string
|
||||
if models := config.IntegrationModels(integration); len(models) > 0 {
|
||||
preChecked = models
|
||||
}
|
||||
m.multiModalSelector = newMultiSelectorModel("Select models:", items, preChecked)
|
||||
// Set cursor to the first pre-checked (last used) model
|
||||
if len(preChecked) > 0 {
|
||||
for i, item := range items {
|
||||
if item.Name == preChecked[0] {
|
||||
m.multiModalSelector.cursor = i
|
||||
m.multiModalSelector.updateScroll(m.multiModalSelector.otherStart())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
m.showingMultiModal = true
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
func cloudStatusDisabled(client *api.Client) bool {
|
||||
status, err := client.CloudStatusExperimental(context.Background())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return status.Cloud.Disabled
|
||||
}
|
||||
|
||||
func cloudModelDisabled(name string) bool {
|
||||
if !isCloudModel(name) {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cloudStatusDisabled(client)
|
||||
}
|
||||
|
||||
// checkCloudSignIn checks if a cloud model needs sign-in.
|
||||
// Returns a command to start sign-in if needed, or nil if already signed in.
|
||||
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
|
||||
if modelName == "" || !isCloudModel(modelName) {
|
||||
return nil
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if cloudStatusDisabled(client) {
|
||||
return nil
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startSignIn initiates the sign-in flow for a cloud model.
|
||||
// fromModal indicates if this was triggered from the model picker modal.
|
||||
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
|
||||
m.showingModal = false
|
||||
m.showingSignIn = true
|
||||
m.signInURL = signInURL
|
||||
m.signInModel = modelName
|
||||
m.signInSpinner = 0
|
||||
m.signInFromModal = fromModal
|
||||
|
||||
config.OpenBrowser(signInURL)
|
||||
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
func (m *model) loadAvailableModels() {
|
||||
m.availableModels = make(map[string]bool)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
models, err := client.List(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cloudDisabled := cloudStatusDisabled(client)
|
||||
for _, mdl := range models.Models {
|
||||
if cloudDisabled && mdl.RemoteModel != "" {
|
||||
integrationState, ok := state.Integrations[info.Name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
m.availableModels[mdl.Name] = true
|
||||
items = append(items, integrationMenuItem(integrationState))
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func (m *model) buildItems() {
|
||||
others := getOtherIntegrations()
|
||||
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
|
||||
m.items = append(m.items, mainMenuItems...)
|
||||
|
||||
if m.showOthers {
|
||||
m.items = append(m.items, others...)
|
||||
} else {
|
||||
m.items = append(m.items, othersMenuItem)
|
||||
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
||||
if state == nil || state.LastSelection == "" {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func isOthersIntegration(name string) bool {
|
||||
for _, item := range getOtherIntegrations() {
|
||||
if item.integration == name {
|
||||
return true
|
||||
for i, item := range items {
|
||||
if state.LastSelection == "run" && item.isRunModel {
|
||||
return i
|
||||
}
|
||||
if item.integration == state.LastSelection {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func initialModel() model {
|
||||
m := model{
|
||||
cursor: 0,
|
||||
}
|
||||
m.loadAvailableModels()
|
||||
|
||||
lastSelection := config.LastSelection()
|
||||
if isOthersIntegration(lastSelection) {
|
||||
m.showOthers = true
|
||||
}
|
||||
|
||||
m.buildItems()
|
||||
|
||||
if lastSelection != "" {
|
||||
for i, item := range m.items {
|
||||
if lastSelection == "run" && item.isRunModel {
|
||||
m.cursor = i
|
||||
break
|
||||
} else if item.integration == lastSelection {
|
||||
m.cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
@@ -364,143 +175,11 @@ func (m model) Init() tea.Cmd {
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
|
||||
wasSet := m.width > 0
|
||||
m.width = wmsg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if _, ok := msg.(clearStatusMsg); ok {
|
||||
m.statusMsg = ""
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingSignIn {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.showingSignIn = false
|
||||
if m.signInFromModal {
|
||||
m.showingModal = true
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case signInTickMsg:
|
||||
m.signInSpinner++
|
||||
// Check sign-in status every 5th tick (~1 second)
|
||||
if m.signInSpinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
}),
|
||||
checkSignIn,
|
||||
)
|
||||
}
|
||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
|
||||
case signInCheckMsg:
|
||||
if msg.signedIn {
|
||||
if m.signInFromModal {
|
||||
m.modalSelector.selected = m.signInModel
|
||||
m.changeModel = true
|
||||
} else {
|
||||
m.selected = true
|
||||
}
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingMultiModal {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
if msg.Type == tea.KeyLeft {
|
||||
m.showingMultiModal = false
|
||||
return m, nil
|
||||
}
|
||||
updated, cmd := m.multiModalSelector.Update(msg)
|
||||
m.multiModalSelector = updated.(multiSelectorModel)
|
||||
|
||||
if m.multiModalSelector.cancelled {
|
||||
m.showingMultiModal = false
|
||||
return m, nil
|
||||
}
|
||||
if m.multiModalSelector.confirmed {
|
||||
var selected []string
|
||||
if m.multiModalSelector.singleAdd != "" {
|
||||
// Single-add mode: prepend picked model, keep existing deduped
|
||||
selected = []string{m.multiModalSelector.singleAdd}
|
||||
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
|
||||
if name != m.multiModalSelector.singleAdd {
|
||||
selected = append(selected, name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Last checked is default (first in result)
|
||||
co := m.multiModalSelector.checkOrder
|
||||
last := co[len(co)-1]
|
||||
selected = []string{m.multiModalSelector.items[last].Name}
|
||||
for _, idx := range co {
|
||||
if idx != last {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(selected) > 0 {
|
||||
m.changeModels = selected
|
||||
m.changeModel = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
m.multiModalSelector.confirmed = false
|
||||
return m, nil
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingModal {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
|
||||
m.showingModal = false
|
||||
return m, nil
|
||||
|
||||
case tea.KeyEnter:
|
||||
filtered := m.modalSelector.filteredItems()
|
||||
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
|
||||
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
|
||||
}
|
||||
if m.modalSelector.selected != "" {
|
||||
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
m.changeModel = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
|
||||
default:
|
||||
// Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
|
||||
m.modalSelector.updateNavigation(msg)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q", "esc":
|
||||
@@ -511,162 +190,78 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
}
|
||||
// Auto-collapse "Others" when cursor moves back into pinned items
|
||||
if m.showOthers && m.cursor < len(mainMenuItems) {
|
||||
m.showOthers = false
|
||||
m.buildItems()
|
||||
m.items = buildMenuItems(m.state, false)
|
||||
m.cursor = min(m.cursor, len(m.items)-1)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "down", "j":
|
||||
if m.cursor < len(m.items)-1 {
|
||||
m.cursor++
|
||||
}
|
||||
// Auto-expand "Others..." when cursor lands on it
|
||||
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
|
||||
m.showOthers = true
|
||||
m.buildItems()
|
||||
// cursor now points at the first "other" integration
|
||||
m.items = buildMenuItems(m.state, true)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "enter", " ":
|
||||
item := m.items[m.cursor]
|
||||
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) {
|
||||
return m, nil
|
||||
if m.selectableItem(m.items[m.cursor]) {
|
||||
m.selected = true
|
||||
m.action = actionForMenuItem(m.items[m.cursor], false)
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
var configuredModel string
|
||||
if item.isRunModel {
|
||||
configuredModel = config.LastModel()
|
||||
} else if item.integration != "" {
|
||||
configuredModel = config.IntegrationModel(item.integration)
|
||||
}
|
||||
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
|
||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
||||
m.openMultiModelModal(item.integration)
|
||||
} else {
|
||||
m.openModelModal(configuredModel)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
return m, nil
|
||||
|
||||
case "right", "l":
|
||||
item := m.items[m.cursor]
|
||||
if item.integration != "" || item.isRunModel {
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
// Auto-installable: select to trigger install flow
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
||||
m.openMultiModelModal(item.integration)
|
||||
} else {
|
||||
var currentModel string
|
||||
if item.isRunModel {
|
||||
currentModel = config.LastModel()
|
||||
} else if item.integration != "" {
|
||||
currentModel = config.IntegrationModel(item.integration)
|
||||
}
|
||||
m.openModelModal(currentModel)
|
||||
}
|
||||
if item.isRunModel || m.changeableItem(item) {
|
||||
m.selected = true
|
||||
m.action = actionForMenuItem(item, true)
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) selectableItem(item menuItem) bool {
|
||||
if item.isRunModel {
|
||||
return true
|
||||
}
|
||||
if item.integration == "" || item.isOthers {
|
||||
return false
|
||||
}
|
||||
state, ok := m.state.Integrations[item.integration]
|
||||
return ok && state.Selectable
|
||||
}
|
||||
|
||||
func (m model) changeableItem(item menuItem) bool {
|
||||
if item.integration == "" || item.isOthers {
|
||||
return false
|
||||
}
|
||||
state, ok := m.state.Integrations[item.integration]
|
||||
return ok && state.Changeable
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
if m.quitting {
|
||||
return ""
|
||||
}
|
||||
|
||||
if m.showingSignIn {
|
||||
return m.renderSignInDialog()
|
||||
}
|
||||
|
||||
if m.showingMultiModal {
|
||||
return m.multiModalSelector.View()
|
||||
}
|
||||
|
||||
if m.showingModal {
|
||||
return m.renderModal()
|
||||
}
|
||||
|
||||
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
|
||||
|
||||
for i, item := range m.items {
|
||||
cursor := ""
|
||||
style := menuItemStyle
|
||||
isInstalled := true
|
||||
|
||||
if item.integration != "" {
|
||||
isInstalled = config.IsIntegrationInstalled(item.integration)
|
||||
}
|
||||
|
||||
if m.cursor == i {
|
||||
cursor = "▸ "
|
||||
if isInstalled {
|
||||
style = menuSelectedItemStyle
|
||||
} else {
|
||||
style = greyedSelectedStyle
|
||||
}
|
||||
} else if !isInstalled && item.integration != "" {
|
||||
style = greyedStyle
|
||||
}
|
||||
|
||||
title := item.title
|
||||
var modelSuffix string
|
||||
if item.integration != "" {
|
||||
if !isInstalled {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
title += " " + notInstalledStyle.Render("(install)")
|
||||
} else {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
}
|
||||
} else if m.cursor == i {
|
||||
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
} else if item.isRunModel && m.cursor == i {
|
||||
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
|
||||
s += style.Render(cursor+title) + modelSuffix + "\n"
|
||||
|
||||
desc := item.description
|
||||
if !isInstalled && item.integration != "" && m.cursor == i {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
desc = "Press enter to install"
|
||||
} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
||||
desc = hint
|
||||
} else {
|
||||
desc = "not installed"
|
||||
}
|
||||
}
|
||||
s += menuDescStyle.Render(desc) + "\n\n"
|
||||
s += m.renderMenuItem(i, item)
|
||||
}
|
||||
|
||||
if m.statusMsg != "" {
|
||||
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
|
||||
}
|
||||
|
||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
|
||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → configure • esc quit")
|
||||
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
@@ -674,80 +269,125 @@ func (m model) View() string {
|
||||
return s
|
||||
}
|
||||
|
||||
func (m model) renderModal() string {
|
||||
modalStyle := lipgloss.NewStyle().
|
||||
PaddingBottom(1).
|
||||
PaddingRight(2)
|
||||
func (m model) renderMenuItem(index int, item menuItem) string {
|
||||
cursor := ""
|
||||
style := menuItemStyle
|
||||
title := item.title
|
||||
description := item.description
|
||||
modelSuffix := ""
|
||||
|
||||
s := modalStyle.Render(m.modalSelector.renderContent())
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m model) renderSignInDialog() string {
|
||||
return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
|
||||
}
|
||||
|
||||
type Selection int
|
||||
|
||||
const (
|
||||
SelectionNone Selection = iota
|
||||
SelectionRunModel
|
||||
SelectionChangeRunModel
|
||||
SelectionIntegration // Generic integration selection
|
||||
SelectionChangeIntegration // Generic change model for integration
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Selection Selection
|
||||
Integration string // integration name if applicable
|
||||
Model string // model name if selected from single-select modal
|
||||
Models []string // models selected from multi-select modal (Editor integrations)
|
||||
}
|
||||
|
||||
func Run() (Result, error) {
|
||||
m := initialModel()
|
||||
p := tea.NewProgram(m)
|
||||
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(model)
|
||||
if fm.err != nil {
|
||||
return Result{Selection: SelectionNone}, fm.err
|
||||
}
|
||||
|
||||
if !fm.selected && !fm.changeModel {
|
||||
return Result{Selection: SelectionNone}, nil
|
||||
}
|
||||
|
||||
item := fm.items[fm.cursor]
|
||||
|
||||
if fm.changeModel {
|
||||
if item.isRunModel {
|
||||
return Result{
|
||||
Selection: SelectionChangeRunModel,
|
||||
Model: fm.modalSelector.selected,
|
||||
}, nil
|
||||
}
|
||||
return Result{
|
||||
Selection: SelectionChangeIntegration,
|
||||
Integration: item.integration,
|
||||
Model: fm.modalSelector.selected,
|
||||
Models: fm.changeModels,
|
||||
}, nil
|
||||
if m.cursor == index {
|
||||
cursor = "▸ "
|
||||
}
|
||||
|
||||
if item.isRunModel {
|
||||
return Result{Selection: SelectionRunModel}, nil
|
||||
if m.cursor == index && m.state.RunModel != "" {
|
||||
modelSuffix = " " + modelStyle.Render("("+m.state.RunModel+")")
|
||||
}
|
||||
if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
} else if item.isOthers {
|
||||
if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
} else {
|
||||
integrationState := m.state.Integrations[item.integration]
|
||||
if !integrationState.Selectable {
|
||||
if m.cursor == index {
|
||||
style = greyedSelectedStyle
|
||||
} else {
|
||||
style = greyedStyle
|
||||
}
|
||||
} else if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
|
||||
if m.cursor == index && integrationState.CurrentModel != "" {
|
||||
modelSuffix = " " + modelStyle.Render("("+integrationState.CurrentModel+")")
|
||||
}
|
||||
|
||||
if !integrationState.Installed {
|
||||
if integrationState.AutoInstallable {
|
||||
title += " " + notInstalledStyle.Render("(install)")
|
||||
} else {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
}
|
||||
if m.cursor == index {
|
||||
if integrationState.AutoInstallable {
|
||||
description = "Press enter to install"
|
||||
} else if integrationState.InstallHint != "" {
|
||||
description = integrationState.InstallHint
|
||||
} else {
|
||||
description = "not installed"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Result{
|
||||
Selection: SelectionIntegration,
|
||||
Integration: item.integration,
|
||||
}, nil
|
||||
return style.Render(cursor+title) + modelSuffix + "\n" + menuDescStyle.Render(description) + "\n\n"
|
||||
}
|
||||
|
||||
type TUIActionKind int
|
||||
|
||||
const (
|
||||
TUIActionNone TUIActionKind = iota
|
||||
TUIActionRunModel
|
||||
TUIActionLaunchIntegration
|
||||
)
|
||||
|
||||
type TUIAction struct {
|
||||
Kind TUIActionKind
|
||||
Integration string
|
||||
ForceConfigure bool
|
||||
}
|
||||
|
||||
func (a TUIAction) LastSelection() string {
|
||||
switch a.Kind {
|
||||
case TUIActionRunModel:
|
||||
return "run"
|
||||
case TUIActionLaunchIntegration:
|
||||
return a.Integration
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (a TUIAction) RunModelRequest() launch.RunModelRequest {
|
||||
return launch.RunModelRequest{ForcePicker: a.ForceConfigure}
|
||||
}
|
||||
|
||||
func (a TUIAction) IntegrationLaunchRequest() launch.IntegrationLaunchRequest {
|
||||
return launch.IntegrationLaunchRequest{
|
||||
Name: a.Integration,
|
||||
ForceConfigure: a.ForceConfigure,
|
||||
}
|
||||
}
|
||||
|
||||
func actionForMenuItem(item menuItem, forceConfigure bool) TUIAction {
|
||||
switch {
|
||||
case item.isRunModel:
|
||||
return TUIAction{Kind: TUIActionRunModel, ForceConfigure: forceConfigure}
|
||||
case item.integration != "":
|
||||
return TUIAction{Kind: TUIActionLaunchIntegration, Integration: item.integration, ForceConfigure: forceConfigure}
|
||||
default:
|
||||
return TUIAction{Kind: TUIActionNone}
|
||||
}
|
||||
}
|
||||
|
||||
func RunMenu(state *launch.LauncherState) (TUIAction, error) {
|
||||
menu := newModel(state)
|
||||
program := tea.NewProgram(menu)
|
||||
|
||||
finalModel, err := program.Run()
|
||||
if err != nil {
|
||||
return TUIAction{Kind: TUIActionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||
}
|
||||
|
||||
finalMenu := finalModel.(model)
|
||||
if !finalMenu.selected {
|
||||
return TUIAction{Kind: TUIActionNone}, nil
|
||||
}
|
||||
|
||||
return finalMenu.action, nil
|
||||
}
|
||||
|
||||
178
cmd/tui/tui_test.go
Normal file
178
cmd/tui/tui_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
func launcherTestState() *launch.LauncherState {
|
||||
return &launch.LauncherState{
|
||||
LastSelection: "run",
|
||||
RunModel: "qwen3:8b",
|
||||
Integrations: map[string]launch.LauncherIntegrationState{
|
||||
"claude": {
|
||||
Name: "claude",
|
||||
DisplayName: "Claude Code",
|
||||
Description: "Anthropic's coding tool with subagents",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
CurrentModel: "glm-5:cloud",
|
||||
},
|
||||
"codex": {
|
||||
Name: "codex",
|
||||
DisplayName: "Codex",
|
||||
Description: "OpenAI's open-source coding agent",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"openclaw": {
|
||||
Name: "openclaw",
|
||||
DisplayName: "OpenClaw",
|
||||
Description: "Personal AI with 100+ skills",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
AutoInstallable: true,
|
||||
},
|
||||
"droid": {
|
||||
Name: "droid",
|
||||
DisplayName: "Droid",
|
||||
Description: "Factory's coding agent across terminal and IDEs",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"pi": {
|
||||
Name: "pi",
|
||||
DisplayName: "Pi",
|
||||
Description: "Minimal AI agent toolkit with plugin support",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
||||
view := newModel(launcherTestState()).View()
|
||||
for _, want := range []string{"Run a model", "Launch Claude Code", "Launch Codex", "Launch OpenClaw", "More..."} {
|
||||
if !strings.Contains(view, want) {
|
||||
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
state.LastSelection = "pi"
|
||||
|
||||
menu := newModel(state)
|
||||
if !menu.showOthers {
|
||||
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
||||
}
|
||||
view := menu.View()
|
||||
if !strings.Contains(view, "Launch Pi") {
|
||||
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
||||
}
|
||||
if strings.Contains(view, "More...") {
|
||||
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionRunModel}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected enter on run to select run action, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuRightOnRunSelectsChangeRun(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionRunModel, ForceConfigure: true}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected right on run to select change-run action, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuEnterOnIntegrationSelectsLaunch(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
menu.cursor = 1
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude"}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected enter on integration to launch, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuRightOnIntegrationSelectsConfigure(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
menu.cursor = 1
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected right on integration to configure, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuIgnoresDisabledActions(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
claude := state.Integrations["claude"]
|
||||
claude.Selectable = false
|
||||
claude.Changeable = false
|
||||
state.Integrations["claude"] = claude
|
||||
|
||||
menu := newModel(state)
|
||||
menu.cursor = 1
|
||||
|
||||
updatedEnter, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
if updatedEnter.(model).selected {
|
||||
t.Fatal("expected non-selectable integration to ignore enter")
|
||||
}
|
||||
|
||||
updatedRight, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
if updatedRight.(model).selected {
|
||||
t.Fatal("expected non-changeable integration to ignore right")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuShowsCurrentModelSuffixes(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
runView := menu.View()
|
||||
if !strings.Contains(runView, "(qwen3:8b)") {
|
||||
t.Fatalf("expected run row to show current model suffix\n%s", runView)
|
||||
}
|
||||
|
||||
menu.cursor = 1
|
||||
integrationView := menu.View()
|
||||
if !strings.Contains(integrationView, "(glm-5:cloud)") {
|
||||
t.Fatalf("expected integration row to show current model suffix\n%s", integrationView)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuShowsInstallStatusAndHint(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
codex := state.Integrations["codex"]
|
||||
codex.Installed = false
|
||||
codex.Selectable = false
|
||||
codex.Changeable = false
|
||||
codex.InstallHint = "Install from https://example.com/codex"
|
||||
state.Integrations["codex"] = codex
|
||||
|
||||
menu := newModel(state)
|
||||
menu.cursor = 2
|
||||
view := menu.View()
|
||||
if !strings.Contains(view, "(not installed)") {
|
||||
t.Fatalf("expected not-installed marker\n%s", view)
|
||||
}
|
||||
if !strings.Contains(view, codex.InstallHint) {
|
||||
t.Fatalf("expected install hint in description\n%s", view)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user