launch: improve multi-select for already added models (#15113)

This commit is contained in:
Parth Sareen
2026-03-28 13:44:40 -07:00
committed by GitHub
parent 6214103e66
commit 7c8da5679e
2 changed files with 519 additions and 14 deletions

View File

@@ -494,8 +494,10 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
return err
}
models = selected
} else if err := c.ensureModelsReady(ctx, models); err != nil {
return err
} else if len(models) > 0 {
if err := c.ensureModelsReady(ctx, models[:1]); err != nil {
return err
}
}
if len(models) == 0 {
@@ -560,10 +562,14 @@ func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, ru
if err != nil {
return nil, err
}
if err := c.ensureModelsReady(ctx, selected); err != nil {
accepted, skipped, err := c.selectReadyModelsForSave(ctx, selected)
if err != nil {
return nil, err
}
return selected, nil
for _, skip := range skipped {
fmt.Fprintf(os.Stderr, "Skipped %s: %s\n", skip.model, skip.reason)
}
return accepted, nil
}
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
@@ -584,16 +590,7 @@ func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []
}
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
var deduped []string
seen := make(map[string]bool, len(models))
for _, model := range models {
if model == "" || seen[model] {
continue
}
seen[model] = true
deduped = append(deduped, model)
}
models = deduped
models = dedupeModelList(models)
if len(models) == 0 {
return nil
}
@@ -611,6 +608,56 @@ func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string)
return ensureAuth(ctx, c.apiClient, cloudModels, models)
}
func dedupeModelList(models []string) []string {
deduped := make([]string, 0, len(models))
seen := make(map[string]bool, len(models))
for _, model := range models {
if model == "" || seen[model] {
continue
}
seen[model] = true
deduped = append(deduped, model)
}
return deduped
}
type skippedModel struct {
model string
reason string
}
func (c *launcherClient) selectReadyModelsForSave(ctx context.Context, selected []string) ([]string, []skippedModel, error) {
selected = dedupeModelList(selected)
accepted := make([]string, 0, len(selected))
skipped := make([]skippedModel, 0, len(selected))
for _, model := range selected {
if err := c.ensureModelsReady(ctx, []string{model}); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, nil, err
}
skipped = append(skipped, skippedModel{
model: model,
reason: skippedModelReason(model, err),
})
continue
}
accepted = append(accepted, model)
}
return accepted, skipped, nil
}
func skippedModelReason(model string, err error) string {
if errors.Is(err, ErrCancelled) {
if isCloudModelName(model) {
return "sign in was cancelled"
}
return "download was cancelled"
}
return err.Error()
}
func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) {
if req.ForceConfigure {
return editorPreCheckedModels(saved, req.ModelOverride), true

View File

@@ -834,6 +834,403 @@ func TestLaunchIntegration_EditorCloudDisabledFallsBackToSelector(t *testing.T)
}
}
func TestLaunchIntegration_EditorConfigureMultiSkipsMissingLocalAndPersistsAccepted(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
binDir := t.TempDir()
writeFakeBinary(t, binDir, "droid")
t.Setenv("PATH", binDir)
editor := &launcherEditorRunner{}
withIntegrationOverride(t, "droid", editor)
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
return []string{"glm-5:cloud", "missing-local"}, nil
}
DefaultConfirmPrompt = func(prompt string) (bool, error) {
if prompt == "Proceed?" {
return true, nil
}
if prompt == "Download missing-local?" {
return false, nil
}
t.Fatalf("unexpected prompt: %q", prompt)
return false, nil
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"glm-5:cloud","remote_model":"glm-5"}]}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
case "/api/show":
var req apiShowRequest
_ = json.NewDecoder(r.Body).Decode(&req)
switch req.Model {
case "glm-5:cloud":
fmt.Fprint(w, `{"remote_model":"glm-5"}`)
case "missing-local":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"model not found"}`)
default:
http.NotFound(w, r)
}
case "/api/me":
fmt.Fprint(w, `{"name":"test-user"}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
var launchErr error
stderr := captureStderr(t, func() {
launchErr = LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "droid",
ForceConfigure: true,
})
})
if launchErr != nil {
t.Fatalf("LaunchIntegration returned error: %v", launchErr)
}
if editor.ranModel != "glm-5:cloud" {
t.Fatalf("expected launch to use cloud primary, got %q", editor.ranModel)
}
saved, err := config.LoadIntegration("droid")
if err != nil {
t.Fatalf("failed to reload saved config: %v", err)
}
if diff := compareStrings(saved.Models, []string{"glm-5:cloud"}); diff != "" {
t.Fatalf("unexpected saved models (-want +got):\n%s", diff)
}
if diff := compareStringSlices(editor.edited, [][]string{{"glm-5:cloud"}}); diff != "" {
t.Fatalf("unexpected edited models (-want +got):\n%s", diff)
}
if !strings.Contains(stderr, "Skipped missing-local:") {
t.Fatalf("expected skip reason in stderr, got %q", stderr)
}
}
func TestLaunchIntegration_EditorConfigureMultiSkipsUnauthedCloudAndPersistsAccepted(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
binDir := t.TempDir()
writeFakeBinary(t, binDir, "droid")
t.Setenv("PATH", binDir)
editor := &launcherEditorRunner{}
withIntegrationOverride(t, "droid", editor)
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
return []string{"llama3.2", "glm-5:cloud"}, nil
}
DefaultConfirmPrompt = func(prompt string) (bool, error) {
if prompt == "Proceed?" {
return true, nil
}
t.Fatalf("unexpected prompt: %q", prompt)
return false, nil
}
DefaultSignIn = func(modelName, signInURL string) (string, error) {
return "", ErrCancelled
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"glm-5:cloud","remote_model":"glm-5"}]}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
case "/api/show":
var req apiShowRequest
_ = json.NewDecoder(r.Body).Decode(&req)
switch req.Model {
case "llama3.2":
fmt.Fprint(w, `{"model":"llama3.2"}`)
case "glm-5:cloud":
fmt.Fprint(w, `{"remote_model":"glm-5"}`)
default:
http.NotFound(w, r)
}
case "/api/me":
w.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
var launchErr error
stderr := captureStderr(t, func() {
launchErr = LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "droid",
ForceConfigure: true,
})
})
if launchErr != nil {
t.Fatalf("LaunchIntegration returned error: %v", launchErr)
}
if editor.ranModel != "llama3.2" {
t.Fatalf("expected launch to use local primary, got %q", editor.ranModel)
}
saved, err := config.LoadIntegration("droid")
if err != nil {
t.Fatalf("failed to reload saved config: %v", err)
}
if diff := compareStrings(saved.Models, []string{"llama3.2"}); diff != "" {
t.Fatalf("unexpected saved models (-want +got):\n%s", diff)
}
if diff := compareStringSlices(editor.edited, [][]string{{"llama3.2"}}); diff != "" {
t.Fatalf("unexpected edited models (-want +got):\n%s", diff)
}
if !strings.Contains(stderr, "Skipped glm-5:cloud: sign in was cancelled") {
t.Fatalf("expected skip reason in stderr, got %q", stderr)
}
}
func TestLaunchIntegration_EditorConfigureMultiRemovesReselectedFailingModel(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
binDir := t.TempDir()
writeFakeBinary(t, binDir, "droid")
t.Setenv("PATH", binDir)
editor := &launcherEditorRunner{}
withIntegrationOverride(t, "droid", editor)
if err := config.SaveIntegration("droid", []string{"glm-5:cloud", "llama3.2"}); err != nil {
t.Fatalf("failed to seed config: %v", err)
}
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
return append([]string(nil), preChecked...), nil
}
DefaultConfirmPrompt = func(prompt string) (bool, error) {
if prompt == "Proceed?" {
return true, nil
}
t.Fatalf("unexpected prompt: %q", prompt)
return false, nil
}
DefaultSignIn = func(modelName, signInURL string) (string, error) {
return "", ErrCancelled
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"glm-5:cloud","remote_model":"glm-5"},{"name":"llama3.2"}]}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
case "/api/show":
var req apiShowRequest
_ = json.NewDecoder(r.Body).Decode(&req)
if req.Model == "glm-5:cloud" {
fmt.Fprint(w, `{"remote_model":"glm-5"}`)
return
}
if req.Model == "llama3.2" {
fmt.Fprint(w, `{"model":"llama3.2"}`)
return
}
http.NotFound(w, r)
case "/api/me":
w.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(w, `{"error":"unauthorized","signin_url":"https://example.com/signin"}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
var launchErr error
stderr := captureStderr(t, func() {
launchErr = LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "droid",
ForceConfigure: true,
})
})
if launchErr != nil {
t.Fatalf("LaunchIntegration returned error: %v", launchErr)
}
if editor.ranModel != "llama3.2" {
t.Fatalf("expected launch to use surviving model, got %q", editor.ranModel)
}
if diff := compareStringSlices(editor.edited, [][]string{{"llama3.2"}}); diff != "" {
t.Fatalf("unexpected edited models (-want +got):\n%s", diff)
}
saved, loadErr := config.LoadIntegration("droid")
if loadErr != nil {
t.Fatalf("failed to reload saved config: %v", loadErr)
}
if diff := compareStrings(saved.Models, []string{"llama3.2"}); diff != "" {
t.Fatalf("unexpected saved models (-want +got):\n%s", diff)
}
if !strings.Contains(stderr, "Skipped glm-5:cloud: sign in was cancelled") {
t.Fatalf("expected skip reason in stderr, got %q", stderr)
}
}
func TestLaunchIntegration_EditorConfigureMultiAllFailuresKeepsExistingAndSkipsLaunch(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
binDir := t.TempDir()
writeFakeBinary(t, binDir, "droid")
t.Setenv("PATH", binDir)
editor := &launcherEditorRunner{}
withIntegrationOverride(t, "droid", editor)
if err := config.SaveIntegration("droid", []string{"llama3.2"}); err != nil {
t.Fatalf("failed to seed config: %v", err)
}
DefaultMultiSelector = func(title string, items []ModelItem, preChecked []string) ([]string, error) {
return []string{"missing-local-a", "missing-local-b"}, nil
}
DefaultConfirmPrompt = func(prompt string) (bool, error) {
if prompt == "Download missing-local-a?" || prompt == "Download missing-local-b?" {
return false, nil
}
if prompt == "Proceed?" {
t.Fatal("did not expect proceed prompt when no models are accepted")
}
t.Fatalf("unexpected prompt: %q", prompt)
return false, nil
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[]}`)
case "/api/show":
var req apiShowRequest
_ = json.NewDecoder(r.Body).Decode(&req)
switch req.Model {
case "missing-local-a", "missing-local-b":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"model not found"}`)
default:
http.NotFound(w, r)
}
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
var launchErr error
stderr := captureStderr(t, func() {
launchErr = LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "droid",
ForceConfigure: true,
})
})
if launchErr != nil {
t.Fatalf("LaunchIntegration returned error: %v", launchErr)
}
if editor.ranModel != "" {
t.Fatalf("expected no launch when all selected models are skipped, got %q", editor.ranModel)
}
if len(editor.edited) != 0 {
t.Fatalf("expected no editor writes when all selections fail, got %v", editor.edited)
}
saved, err := config.LoadIntegration("droid")
if err != nil {
t.Fatalf("failed to reload saved config: %v", err)
}
if diff := compareStrings(saved.Models, []string{"llama3.2"}); diff != "" {
t.Fatalf("unexpected saved models (-want +got):\n%s", diff)
}
if !strings.Contains(stderr, "Skipped missing-local-a:") {
t.Fatalf("expected first skip reason in stderr, got %q", stderr)
}
if !strings.Contains(stderr, "Skipped missing-local-b:") {
t.Fatalf("expected second skip reason in stderr, got %q", stderr)
}
}
func TestLaunchIntegration_ConfiguredEditorLaunchValidatesPrimaryOnly(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
binDir := t.TempDir()
writeFakeBinary(t, binDir, "droid")
t.Setenv("PATH", binDir)
editor := &launcherEditorRunner{}
withIntegrationOverride(t, "droid", editor)
if err := config.SaveIntegration("droid", []string{"llama3.2", "missing-local"}); err != nil {
t.Fatalf("failed to seed config: %v", err)
}
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Fatalf("did not expect prompt during normal configured launch: %q", prompt)
return false, nil
}
var missingShowCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" {
http.NotFound(w, r)
return
}
var req apiShowRequest
_ = json.NewDecoder(r.Body).Decode(&req)
switch req.Model {
case "llama3.2":
fmt.Fprint(w, `{"model":"llama3.2"}`)
case "missing-local":
missingShowCalled = true
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"model not found"}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "droid"}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if missingShowCalled {
t.Fatal("expected configured launch to validate only the primary model")
}
if editor.ranModel != "llama3.2" {
t.Fatalf("expected launch to use saved primary model, got %q", editor.ranModel)
}
if len(editor.edited) != 0 {
t.Fatalf("expected no editor writes during normal launch, got %v", editor.edited)
}
saved, err := config.LoadIntegration("droid")
if err != nil {
t.Fatalf("failed to reload saved config: %v", err)
}
if diff := compareStrings(saved.Models, []string{"llama3.2", "missing-local"}); diff != "" {
t.Fatalf("unexpected saved models (-want +got):\n%s", diff)
}
}
func TestLaunchIntegration_ConfiguredEditorLaunchSkipsReconfigure(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
@@ -1158,6 +1555,67 @@ func TestLaunchIntegration_ClaudeForceConfigureReprompts(t *testing.T) {
}
}
func TestLaunchIntegration_ClaudeForceConfigureMissingSelectionDoesNotSave(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
binDir := t.TempDir()
writeFakeBinary(t, binDir, "claude")
t.Setenv("PATH", binDir)
if err := config.SaveIntegration("claude", []string{"llama3.2"}); err != nil {
t.Fatalf("failed to seed config: %v", err)
}
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
return "missing-model", nil
}
DefaultConfirmPrompt = func(prompt string) (bool, error) {
if prompt == "Download missing-model?" {
return false, nil
}
t.Fatalf("unexpected prompt: %q", prompt)
return false, nil
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
var req apiShowRequest
_ = json.NewDecoder(r.Body).Decode(&req)
if req.Model == "missing-model" {
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"model not found"}`)
return
}
fmt.Fprintf(w, `{"model":%q}`, req.Model)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "claude",
ForceConfigure: true,
})
if err == nil {
t.Fatal("expected missing selected model to abort launch")
}
saved, loadErr := config.LoadIntegration("claude")
if loadErr != nil {
t.Fatalf("failed to reload saved config: %v", loadErr)
}
if diff := compareStrings(saved.Models, []string{"llama3.2"}); diff != "" {
t.Fatalf("unexpected saved models (-want +got):\n%s", diff)
}
}
func TestLaunchIntegration_ClaudeModelOverrideSkipsSelector(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)