mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 01:54:17 +02:00
cmd: refactor tui and launch (#14609)
This commit is contained in:
@@ -1,187 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
// Compile-time check that Claude implements AliasConfigurer
|
||||
var _ AliasConfigurer = (*Claude)(nil)
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("claude"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".claude", "local", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
env := append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
cmd.Env = env
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
primary := model
|
||||
fast := model
|
||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
||||
if p := cfg.Aliases["primary"]; p != "" {
|
||||
primary = p
|
||||
}
|
||||
if f := cfg.Aliases["fast"]; f != "" {
|
||||
fast = f
|
||||
}
|
||||
}
|
||||
return []string{
|
||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
|
||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up model aliases for Claude Code.
|
||||
// model: the model to use (if empty, user will be prompted to select)
|
||||
// aliases: existing alias configuration to preserve/update
|
||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
||||
// there is a better strategy for prompt caching on local models.
|
||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existingAliases {
|
||||
aliases[k] = v
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
aliases["primary"] = model
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
if isCloudModelName(aliases["primary"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
return aliases, false, nil
|
||||
}
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModelName(aliases["primary"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
return aliases, true, nil
|
||||
}
|
||||
|
||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
||||
// prevent stale routing to a previous cloud model.
|
||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
if aliases["fast"] == "" {
|
||||
for _, prefix := range prefixes {
|
||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
prefixAliases := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
var errs []string
|
||||
for prefix, target := range prefixAliases {
|
||||
req := &api.AliasRequest{
|
||||
Alias: prefix,
|
||||
Target: target,
|
||||
PrefixMatching: true,
|
||||
}
|
||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
||||
errs = append(errs, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,198 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClaudeIntegration(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Claude Code" {
|
||||
t.Errorf("String() = %q, want %q", got, "Claude Code")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeFindPath(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("finds claude in PATH", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fakeBin := filepath.Join(tmpDir, name)
|
||||
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fakeBin {
|
||||
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(tmpDir, ".claude", "local", name)
|
||||
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fallback {
|
||||
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
_, err := c.findPath()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeArgs(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil, nil},
|
||||
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeModelEnvVars(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
envMap := func(envs []string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, e := range envs {
|
||||
k, v, _ := strings.Cut(e, "=")
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"qwen3:8b"})
|
||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
||||
|
||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses fast alias for haiku", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"llama3.2:70b"})
|
||||
saveAliases("claude", map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
})
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2:70b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("alias primary overrides model param", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"saved-model"})
|
||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
||||
|
||||
got := envMap(c.modelEnvVars("different-model"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
||||
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Cline implements Runner and Editor for the Cline CLI integration
|
||||
type Cline struct{}
|
||||
|
||||
func (c *Cline) String() string { return "Cline" }
|
||||
|
||||
func (c *Cline) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("cline"); err != nil {
|
||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "cline", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("cline", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (c *Cline) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cline) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Set Ollama as the provider for both act and plan modes
|
||||
baseURL := envconfig.Host().String()
|
||||
config["ollamaBaseUrl"] = baseURL
|
||||
config["actModeApiProvider"] = "ollama"
|
||||
config["actModeOllamaModelId"] = models[0]
|
||||
config["actModeOllamaBaseUrl"] = baseURL
|
||||
config["planModeApiProvider"] = "ollama"
|
||||
config["planModeOllamaModelId"] = models[0]
|
||||
config["planModeOllamaBaseUrl"] = baseURL
|
||||
|
||||
config["welcomeViewCompleted"] = true
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Cline) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if config["actModeApiProvider"] != "ollama" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelID, _ := config["actModeOllamaModelId"].(string)
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{modelID}
|
||||
}
|
||||
@@ -1,204 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClineIntegration(t *testing.T) {
|
||||
c := &Cline{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Cline" {
|
||||
t.Errorf("String() = %q, want %q", got, "Cline")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClineEdit(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
|
||||
readConfig := func() map[string]any {
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var config map[string]any
|
||||
json.Unmarshal(data, &config)
|
||||
return config
|
||||
}
|
||||
|
||||
t.Run("creates config from scratch", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeApiProvider"] != "ollama" {
|
||||
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
|
||||
}
|
||||
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
if config["planModeApiProvider"] != "ollama" {
|
||||
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
|
||||
}
|
||||
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
|
||||
}
|
||||
if config["welcomeViewCompleted"] != true {
|
||||
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves existing fields", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existing := map[string]any{
|
||||
"remoteRulesToggles": map[string]any{},
|
||||
"remoteWorkflowToggles": map[string]any{},
|
||||
"customSetting": "keep-me",
|
||||
}
|
||||
data, _ := json.Marshal(existing)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["customSetting"] != "keep-me" {
|
||||
t.Errorf("customSetting was not preserved")
|
||||
}
|
||||
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates model on re-edit", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
if config["planModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit(nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||
t.Error("expected no config file to be created for empty models")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses first model as primary", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClineModels(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
|
||||
t.Run("returns nil when no config", func(t *testing.T) {
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
config := map[string]any{
|
||||
"actModeApiProvider": "anthropic",
|
||||
"actModeOllamaModelId": "some-model",
|
||||
}
|
||||
data, _ := json.Marshal(config)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns model when ollama is configured", func(t *testing.T) {
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
config := map[string]any{
|
||||
"actModeApiProvider": "ollama",
|
||||
"actModeOllamaModelId": "kimi-k2.5:cloud",
|
||||
}
|
||||
data, _ := json.Marshal(config)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClinePaths(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||
if paths := c.Paths(); paths != nil {
|
||||
t.Errorf("Paths() = %v, want nil", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 || paths[0] != configPath {
|
||||
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// Codex implements Runner for Codex integration
|
||||
type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string, extra []string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string, args []string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
|
||||
"OPENAI_API_KEY=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func checkCodexVersion() error {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
|
||||
}
|
||||
|
||||
out, err := exec.Command("codex", "--version").Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get codex version: %w", err)
|
||||
}
|
||||
|
||||
// Parse output like "codex-cli 0.87.0"
|
||||
fields := strings.Fields(strings.TrimSpace(string(out)))
|
||||
if len(fields) < 2 {
|
||||
return fmt.Errorf("unexpected codex version output: %s", string(out))
|
||||
}
|
||||
|
||||
version := "v" + fields[len(fields)-1]
|
||||
minVersion := "v0.81.0"
|
||||
|
||||
if semver.Compare(version, minVersion) < 0 {
|
||||
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCodexArgs(t *testing.T) {
|
||||
c := &Codex{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", nil, []string{"--oss"}},
|
||||
{"with model and profile", "qwen3.5", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3.5", "-p", "myprofile"}},
|
||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -11,7 +10,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
@@ -20,6 +19,9 @@ type integration struct {
|
||||
Onboarded bool `json:"onboarded,omitempty"`
|
||||
}
|
||||
|
||||
// IntegrationConfig is the persisted config for one integration.
|
||||
type IntegrationConfig = integration
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
LastModel string `json:"last_model,omitempty"`
|
||||
@@ -124,7 +126,7 @@ func save(cfg *config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeWithBackup(path, data)
|
||||
return fileutil.WriteWithBackup(path, data)
|
||||
}
|
||||
|
||||
func SaveIntegration(appName string, models []string) error {
|
||||
@@ -155,8 +157,8 @@ func SaveIntegration(appName string, models []string) error {
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// integrationOnboarded marks an integration as onboarded in ollama's config.
|
||||
func integrationOnboarded(appName string) error {
|
||||
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
|
||||
func MarkIntegrationOnboarded(appName string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -174,7 +176,7 @@ func integrationOnboarded(appName string) error {
|
||||
|
||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||
func IntegrationModel(appName string) string {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
integrationConfig, err := LoadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -183,7 +185,7 @@ func IntegrationModel(appName string) string {
|
||||
|
||||
// IntegrationModels returns all configured models for an integration, or nil.
|
||||
func IntegrationModels(appName string) []string {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
integrationConfig, err := LoadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -228,31 +230,8 @@ func SetLastSelection(selection string) error {
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// ModelExists checks if a model exists on the Ollama server.
|
||||
func ModelExists(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if isCloudModelName(name) {
|
||||
return true
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
// LoadIntegration returns the saved config for one integration.
|
||||
func LoadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -266,7 +245,8 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return integrationConfig, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
// SaveAliases replaces the saved aliases for one integration.
|
||||
func SaveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -45,12 +44,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := saveAliases("claude", initial); err != nil {
|
||||
if err := SaveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -63,12 +62,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := saveAliases("claude", updated); err != nil {
|
||||
if err := SaveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = loadIntegration("claude")
|
||||
loaded, err = LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
@@ -91,12 +90,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -111,16 +110,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -135,16 +134,16 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := saveAliases("claude", nil); err != nil {
|
||||
if err := SaveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -155,7 +154,7 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := saveAliases("", map[string]string{"primary": "model"})
|
||||
err := SaveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
@@ -165,12 +164,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -179,11 +178,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = loadIntegration("claude")
|
||||
loaded, err = LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
@@ -198,11 +197,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("newintegration")
|
||||
loaded, err := LoadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -371,12 +370,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
t.Fatal("server should succeed")
|
||||
}
|
||||
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -408,7 +407,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
@@ -440,11 +439,6 @@ func containsHelper(s, substr string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
||||
c := &Claude{}
|
||||
var _ AliasConfigurer = c // Compile-time check
|
||||
}
|
||||
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -464,11 +458,11 @@ func TestModelNameEdgeCases(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -485,7 +479,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
@@ -493,13 +487,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
@@ -513,21 +507,21 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
@@ -538,7 +532,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
@@ -546,14 +540,14 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
@@ -563,43 +557,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCapabilityFiltering(t *testing.T) {
|
||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
||||
// Both cloud and local models are checked for tool capability via Show API
|
||||
// Only models with "tools" in capabilities are included
|
||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("tool capable model should be marked as such")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("ToolCapable field should be accessible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
||||
t.Run("nil client always returns false", func(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, no suffix detection
|
||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
||||
t.Error("nil client should return false regardless of suffix")
|
||||
}
|
||||
if isCloudModel(context.Background(), nil, "local-model") {
|
||||
t.Error("nil client should return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -608,7 +572,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
@@ -622,11 +586,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
@@ -638,7 +602,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = loadIntegration("claude")
|
||||
loaded, _ = LoadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
@@ -653,20 +617,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
if err := SaveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
|
||||
@@ -10,17 +10,10 @@ import (
|
||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("TMPDIR", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
||||
func editorPaths(r Runner) []string {
|
||||
if editor, ok := r.(Editor); ok {
|
||||
return editor.Paths()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIntegrationConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
@@ -31,7 +24,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -55,11 +48,11 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -77,14 +70,14 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -96,7 +89,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
SaveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
config, _ := loadIntegration("codex")
|
||||
config, _ := LoadIntegration("codex")
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
@@ -120,7 +113,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||
SaveIntegration("Claude", []string{"model-x"})
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -137,8 +130,8 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
SaveIntegration("app1", []string{"model-1"})
|
||||
SaveIntegration("app2", []string{"model-2"})
|
||||
|
||||
config1, _ := loadIntegration("app1")
|
||||
config2, _ := loadIntegration("app2")
|
||||
config1, _ := LoadIntegration("app1")
|
||||
config2, _ := LoadIntegration("app2")
|
||||
|
||||
defaultModel1 := ""
|
||||
if len(config1.Models) > 0 {
|
||||
@@ -185,64 +178,6 @@ func TestListIntegrations(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestEditorPaths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
||||
r := integrations["claude"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for claude, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
||||
r := integrations["codex"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for codex, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
||||
settingsDir, _ := os.UserHomeDir()
|
||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
configDir := filepath.Join(home, ".config", "opencode")
|
||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["opencode"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 2 {
|
||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
@@ -251,7 +186,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
_, err := loadIntegration("test")
|
||||
_, err := LoadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
}
|
||||
@@ -265,7 +200,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
|
||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("test")
|
||||
config, err := LoadIntegration("test")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration failed: %v", err)
|
||||
}
|
||||
@@ -294,7 +229,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
_, err := loadIntegration("nonexistent")
|
||||
_, err := LoadIntegration("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration, got nil")
|
||||
}
|
||||
|
||||
@@ -1,205 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
type Droid struct{}
|
||||
|
||||
// droidSettings represents the Droid settings.json file (only fields we use)
|
||||
type droidSettings struct {
|
||||
CustomModels []modelEntry `json:"customModels"`
|
||||
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
|
||||
}
|
||||
|
||||
type sessionSettings struct {
|
||||
Model string `json:"model"`
|
||||
ReasoningEffort string `json:"reasoningEffort"`
|
||||
}
|
||||
|
||||
type modelEntry struct {
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"displayName"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey"`
|
||||
Provider string `json:"provider"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens"`
|
||||
SupportsImages bool `json:"supportsImages"`
|
||||
ID string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "droid", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (d *Droid) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".factory", "settings.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Droid) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settingsPath := filepath.Join(home, ".factory", "settings.json")
|
||||
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read file once, unmarshal twice:
|
||||
// map preserves unknown fields for writing back (including extra fields in model entries)
|
||||
settingsMap := make(map[string]any)
|
||||
var settings droidSettings
|
||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||
if err := json.Unmarshal(data, &settingsMap); err != nil {
|
||||
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
|
||||
}
|
||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||
}
|
||||
|
||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||
// Rebuild Ollama models
|
||||
var nonOllamaModels []any
|
||||
if rawModels, ok := settingsMap["customModels"].([]any); ok {
|
||||
for _, raw := range rawModels {
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
if m["apiKey"] != "ollama" {
|
||||
nonOllamaModels = append(nonOllamaModels, raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
maxOutput := 64000
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
maxOutput = l.Output
|
||||
}
|
||||
}
|
||||
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: maxOutput,
|
||||
SupportsImages: false,
|
||||
ID: modelID,
|
||||
Index: i,
|
||||
})
|
||||
if i == 0 {
|
||||
defaultModelID = modelID
|
||||
}
|
||||
}
|
||||
|
||||
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
|
||||
|
||||
// Update session default settings (preserve unknown fields in the nested object)
|
||||
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
sessionSettings = make(map[string]any)
|
||||
}
|
||||
sessionSettings["model"] = defaultModelID
|
||||
|
||||
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
|
||||
sessionSettings["reasoningEffort"] = "none"
|
||||
}
|
||||
|
||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||
|
||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, data)
|
||||
}
|
||||
|
||||
func (d *Droid) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings droidSettings
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, m := range settings.CustomModels {
|
||||
if m.APIKey == "ollama" {
|
||||
result = append(result, m.Model)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
|
||||
|
||||
func isValidReasoningEffort(effort string) bool {
|
||||
return slices.Contains(validReasoningEfforts, effort)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,99 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
func readJSONFile(path string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||
}
|
||||
|
||||
func backupDir() string {
|
||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||
}
|
||||
|
||||
func backupToTmp(srcPath string) (string, error) {
|
||||
dir := backupDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
|
||||
if err := copyFile(srcPath, backupPath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return backupPath, nil
|
||||
}
|
||||
|
||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
||||
func writeWithBackup(path string, data []byte) error {
|
||||
var backupPath string
|
||||
// backup must be created before any writes to the target file
|
||||
if existingContent, err := os.ReadFile(path); err == nil {
|
||||
if !bytes.Equal(existingContent, data) {
|
||||
backupPath, err = backupToTmp(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backup failed: %w", err)
|
||||
}
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("read existing file: %w", err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
tmp, err := os.CreateTemp(dir, ".tmp-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp failed: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("write failed: %w", err)
|
||||
}
|
||||
if err := tmp.Sync(); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("sync failed: %w", err)
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("close failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
if backupPath != "" {
|
||||
_ = copyFile(backupPath, path)
|
||||
}
|
||||
return fmt.Errorf("rename failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,502 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustMarshal(t *testing.T, v any) []byte {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestWriteWithBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
t.Run("creates file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "new.json")
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.Unmarshal(content, &result); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result["key"] != "value" {
|
||||
t.Errorf("expected value, got %s", result["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "backup.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(backupDir())
|
||||
if err != nil {
|
||||
t.Fatal("backup directory not created")
|
||||
}
|
||||
|
||||
var foundBackup bool
|
||||
for _, entry := range entries {
|
||||
if filepath.Ext(entry.Name()) != ".json" {
|
||||
name := entry.Name()
|
||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||
backupPath := filepath.Join(backupDir(), name)
|
||||
backup, err := os.ReadFile(backupPath)
|
||||
if err == nil {
|
||||
var backupData map[string]bool
|
||||
json.Unmarshal(backup, &backupData)
|
||||
if backupData["original"] {
|
||||
foundBackup = true
|
||||
os.Remove(backupPath)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup file not created in /tmp/ollama-backups")
|
||||
}
|
||||
|
||||
current, _ := os.ReadFile(path)
|
||||
var currentData map[string]bool
|
||||
json.Unmarshal(current, ¤tData)
|
||||
if !currentData["updated"] {
|
||||
t.Error("file doesn't contain updated data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup for new file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "nobak.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
for _, entry := range entries {
|
||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||
t.Error("backup should not exist for new file")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup when content unchanged", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "unchanged.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries1, _ := os.ReadDir(backupDir())
|
||||
countBefore := 0
|
||||
for _, e := range entries1 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countBefore++
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries2, _ := os.ReadDir(backupDir())
|
||||
countAfter := 0
|
||||
for _, e := range entries2 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countAfter++
|
||||
}
|
||||
}
|
||||
|
||||
if countAfter != countBefore {
|
||||
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "timestamped.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
data := mustMarshal(t, map[string]int{"v": 2})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
var found bool
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
|
||||
timestamp := name[len("timestamped.json."):]
|
||||
for _, c := range timestamp {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
|
||||
}
|
||||
}
|
||||
found = true
|
||||
os.Remove(filepath.Join(backupDir(), name))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("backup file with timestamp not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for files.go
|
||||
|
||||
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
|
||||
// User could lose their config with no way to recover.
|
||||
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
// Create original file
|
||||
originalContent := []byte(`{"original": true}`)
|
||||
os.WriteFile(path, originalContent, 0o644)
|
||||
|
||||
// Make backup directory read-only to force backup failure
|
||||
backupDir := backupDir()
|
||||
os.MkdirAll(backupDir, 0o755)
|
||||
os.Chmod(backupDir, 0o444) // Read-only
|
||||
defer os.Chmod(backupDir, 0o755)
|
||||
|
||||
newContent := []byte(`{"updated": true}`)
|
||||
err := writeWithBackup(path, newContent)
|
||||
|
||||
// Should fail because backup couldn't be created
|
||||
if err == nil {
|
||||
t.Error("expected error when backup fails, got nil")
|
||||
}
|
||||
|
||||
// Original file should be preserved
|
||||
current, _ := os.ReadFile(path)
|
||||
if string(current) != string(originalContent) {
|
||||
t.Errorf("original file was modified despite backup failure: got %s", string(current))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
|
||||
// Common issue when config owned by root or wrong perms.
|
||||
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a read-only directory
|
||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||
os.MkdirAll(readOnlyDir, 0o755)
|
||||
os.Chmod(readOnlyDir, 0o444)
|
||||
defer os.Chmod(readOnlyDir, 0o755)
|
||||
|
||||
path := filepath.Join(readOnlyDir, "config.json")
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected permission error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||
// writeWithBackup doesn't create directories - caller is responsible.
|
||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
// Should fail because directory doesn't exist
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
|
||||
// Documents what happens if user symlinks their config file.
|
||||
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlink tests may require admin on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
realFile := filepath.Join(tmpDir, "real.json")
|
||||
symlink := filepath.Join(tmpDir, "link.json")
|
||||
|
||||
// Create real file and symlink
|
||||
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
|
||||
os.Symlink(realFile, symlink)
|
||||
|
||||
// Write through symlink
|
||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||
}
|
||||
|
||||
// The real file should be updated (symlink followed for temp file creation)
|
||||
content, _ := os.ReadFile(symlink)
|
||||
if string(content) != `{"v": 2}` {
|
||||
t.Errorf("symlink target not updated correctly: got %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||
// User may have config files with unusual names.
|
||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// File with spaces and special chars
|
||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
|
||||
|
||||
backupPath, err := backupToTmp(path)
|
||||
if err != nil {
|
||||
t.Fatalf("backupToTmp with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify backup exists and has correct content
|
||||
content, err := os.ReadFile(backupPath)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read backup: %v", err)
|
||||
}
|
||||
if string(content) != `{"test": true}` {
|
||||
t.Errorf("backup content mismatch: got %s", string(content))
|
||||
}
|
||||
|
||||
os.Remove(backupPath)
|
||||
}
|
||||
|
||||
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
|
||||
func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission preservation tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
src := filepath.Join(tmpDir, "src.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
// Create source with specific permissions
|
||||
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err != nil {
|
||||
t.Fatalf("copyFile failed: %v", err)
|
||||
}
|
||||
|
||||
srcInfo, _ := os.Stat(src)
|
||||
dstInfo, _ := os.Stat(dst)
|
||||
|
||||
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
|
||||
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent source, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||
os.MkdirAll(dirPath, 0o755)
|
||||
|
||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when target is a directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "empty.json")
|
||||
|
||||
err := writeWithBackup(path, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read file: %v", err)
|
||||
}
|
||||
if len(content) != 0 {
|
||||
t.Errorf("expected empty file, got %d bytes", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
|
||||
// cannot be read (for backup comparison) but directory is writable.
|
||||
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "unreadable.json")
|
||||
|
||||
// Create file and make it unreadable
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
os.Chmod(path, 0o000)
|
||||
defer os.Chmod(path, 0o644)
|
||||
|
||||
// Should fail because we can't read the file to compare/backup
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when file is unreadable, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||
// within the same second (timestamp collision scenario).
|
||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "rapid.json")
|
||||
|
||||
// Create initial file
|
||||
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
|
||||
|
||||
// Rapid successive writes
|
||||
for i := 1; i <= 3; i++ {
|
||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
t.Fatalf("write %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify final content
|
||||
content, _ := os.ReadFile(path)
|
||||
if string(content) != `{"v": 3}` {
|
||||
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
|
||||
}
|
||||
|
||||
// Verify at least one backup exists
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
var backupCount int
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||
backupCount++
|
||||
}
|
||||
}
|
||||
if backupCount == 0 {
|
||||
t.Error("expected at least one backup file from rapid writes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
|
||||
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("test modifies system temp directory")
|
||||
}
|
||||
|
||||
// Create a file at the backup directory path
|
||||
backupPath := backupDir()
|
||||
// Clean up any existing directory first
|
||||
os.RemoveAll(backupPath)
|
||||
// Create a file instead of directory
|
||||
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
|
||||
defer func() {
|
||||
os.Remove(backupPath)
|
||||
os.MkdirAll(backupPath, 0o755)
|
||||
}()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "test.json")
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when backup dir is a file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
|
||||
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Count existing temp files
|
||||
countTempFiles := func() int {
|
||||
entries, _ := os.ReadDir(tmpDir)
|
||||
count := 0
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
before := countTempFiles()
|
||||
|
||||
// Create a file, then make directory read-only to cause rename failure
|
||||
path := filepath.Join(tmpDir, "orphan.json")
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make a subdirectory and try to write there after making parent read-only
|
||||
subDir := filepath.Join(tmpDir, "subdir")
|
||||
os.MkdirAll(subDir, 0o755)
|
||||
subPath := filepath.Join(subDir, "config.json")
|
||||
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make subdir read-only after creating temp file would succeed but rename would fail
|
||||
// This is tricky to test - the temp file is created in the same dir, so if we can't
|
||||
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
|
||||
|
||||
// Force a failure by making the target a directory
|
||||
badPath := filepath.Join(tmpDir, "isdir")
|
||||
os.MkdirAll(badPath, 0o755)
|
||||
|
||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
||||
|
||||
after := countTempFiles()
|
||||
if after > before {
|
||||
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,801 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const defaultGatewayPort = 18789
|
||||
|
||||
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
|
||||
var openclawModelShowTimeout = 5 * time.Second
|
||||
|
||||
type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin, err := ensureOpenclawInstalled()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
firstLaunch := true
|
||||
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
|
||||
firstLaunch = !integrationConfig.Onboarded
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
|
||||
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
||||
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
ok, err := confirmPrompt("I understand the risks. Continue?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !c.onboarded() {
|
||||
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
|
||||
|
||||
cmd := exec.Command(bin, "onboard",
|
||||
"--non-interactive",
|
||||
"--accept-risk",
|
||||
"--auth-choice", "skip",
|
||||
"--gateway-token", "ollama",
|
||||
"--install-daemon",
|
||||
"--skip-channels",
|
||||
"--skip-skills",
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
|
||||
}
|
||||
|
||||
patchDeviceScopes()
|
||||
|
||||
// Onboarding overwrites openclaw.json, so re-apply the model config
|
||||
// that Edit() wrote before Run() was called.
|
||||
if err := c.Edit([]string{model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
|
||||
if ensureWebSearchPlugin() {
|
||||
registerWebSearchPlugin()
|
||||
}
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||
}
|
||||
|
||||
// When extra args are passed through, run exactly what the user asked for
|
||||
// after setup and skip the built-in gateway+TUI convenience flow.
|
||||
if len(args) > 0 {
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.Env = openclawEnv()
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
addr := fmt.Sprintf("localhost:%d", port)
|
||||
|
||||
// If the gateway is already running (e.g. via the daemon), restart it
|
||||
// so it picks up any config changes from Edit() above (model, provider, etc.).
|
||||
if portOpen(addr) {
|
||||
restart := exec.Command(bin, "daemon", "restart")
|
||||
restart.Env = openclawEnv()
|
||||
if err := restart.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
|
||||
}
|
||||
if !waitForPort(addr, 10*time.Second) {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
// If the gateway isn't running, start it as a background child process.
|
||||
if !portOpen(addr) {
|
||||
gw := exec.Command(bin, "gateway", "run", "--force")
|
||||
gw.Env = openclawEnv()
|
||||
if err := gw.Start(); err != nil {
|
||||
return windowsHint(fmt.Errorf("failed to start gateway: %w", err))
|
||||
}
|
||||
defer func() {
|
||||
if gw.Process != nil {
|
||||
_ = gw.Process.Kill()
|
||||
_ = gw.Wait()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
|
||||
if !waitForPort(addr, 30*time.Second) {
|
||||
return windowsHint(fmt.Errorf("gateway did not start on %s", addr))
|
||||
}
|
||||
|
||||
printOpenclawReady(bin, token, port, firstLaunch)
|
||||
|
||||
tuiArgs := []string{"tui"}
|
||||
if firstLaunch {
|
||||
tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
|
||||
}
|
||||
tui := exec.Command(bin, tuiArgs...)
|
||||
tui.Env = openclawEnv()
|
||||
tui.Stdin = os.Stdin
|
||||
tui.Stdout = os.Stdout
|
||||
tui.Stderr = os.Stderr
|
||||
if err := tui.Run(); err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
|
||||
func (c *Openclaw) gatewayInfo() (token string, port int) {
|
||||
port = defaultGatewayPort
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", port
|
||||
}
|
||||
|
||||
for _, path := range []string{
|
||||
filepath.Join(home, ".openclaw", "openclaw.json"),
|
||||
filepath.Join(home, ".clawdbot", "clawdbot.json"),
|
||||
} {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var config map[string]any
|
||||
if json.Unmarshal(data, &config) != nil {
|
||||
continue
|
||||
}
|
||||
gw, _ := config["gateway"].(map[string]any)
|
||||
if p, ok := gw["port"].(float64); ok && p > 0 {
|
||||
port = int(p)
|
||||
}
|
||||
auth, _ := gw["auth"].(map[string]any)
|
||||
if t, _ := auth["token"].(string); t != "" {
|
||||
token = t
|
||||
}
|
||||
return token, port
|
||||
}
|
||||
return "", port
|
||||
}
|
||||
|
||||
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
|
||||
u := fmt.Sprintf("http://localhost:%d", port)
|
||||
if token != "" {
|
||||
u += "/#token=" + url.QueryEscape(token)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
|
||||
fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
// openclawEnv returns the current environment with provider API keys cleared
|
||||
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
|
||||
func openclawEnv() []string {
|
||||
clear := map[string]bool{
|
||||
"ANTHROPIC_API_KEY": true,
|
||||
"ANTHROPIC_OAUTH_TOKEN": true,
|
||||
"OPENAI_API_KEY": true,
|
||||
"GEMINI_API_KEY": true,
|
||||
"MISTRAL_API_KEY": true,
|
||||
"GROQ_API_KEY": true,
|
||||
"XAI_API_KEY": true,
|
||||
"OPENROUTER_API_KEY": true,
|
||||
}
|
||||
var env []string
|
||||
for _, e := range os.Environ() {
|
||||
key, _, _ := strings.Cut(e, "=")
|
||||
if !clear[key] {
|
||||
env = append(env, e)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// portOpen checks if a TCP port is currently accepting connections.
|
||||
func portOpen(addr string) bool {
|
||||
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func waitForPort(addr string, timeout time.Duration) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func windowsHint(err error) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("%w\n\n"+
|
||||
"OpenClaw runs best on WSL2.\n"+
|
||||
"Quick setup: wsl --install\n"+
|
||||
"Guide: https://docs.openclaw.ai/windows", err)
|
||||
}
|
||||
|
||||
// onboarded checks if OpenClaw onboarding wizard was completed
|
||||
// by looking for the wizard.lastRunAt marker in the config
|
||||
func (c *Openclaw) onboarded() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for wizard.lastRunAt marker (set when onboarding completes)
|
||||
wizard, _ := config["wizard"].(map[string]any)
|
||||
if wizard == nil {
|
||||
return false
|
||||
}
|
||||
lastRunAt, _ := wizard["lastRunAt"].(string)
|
||||
return lastRunAt != ""
|
||||
}
|
||||
|
||||
// patchDeviceScopes upgrades the local CLI device's paired scopes to include
|
||||
// operator.admin. Only patches the local device, not remote ones.
|
||||
// Best-effort: silently returns on any error.
|
||||
func patchDeviceScopes() {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := readLocalDeviceID(home)
|
||||
if deviceID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
path := filepath.Join(home, ".openclaw", "devices", "paired.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var devices map[string]map[string]any
|
||||
if err := json.Unmarshal(data, &devices); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dev, ok := devices[deviceID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"operator.read",
|
||||
"operator.admin",
|
||||
"operator.approvals",
|
||||
"operator.pairing",
|
||||
}
|
||||
|
||||
changed := patchScopes(dev, "scopes", required)
|
||||
if tokens, ok := dev["tokens"].(map[string]any); ok {
|
||||
for _, tok := range tokens {
|
||||
if tokenMap, ok := tok.(map[string]any); ok {
|
||||
if patchScopes(tokenMap, "scopes", required) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
|
||||
out, err := json.MarshalIndent(devices, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(path, out, 0o600)
|
||||
}
|
||||
|
||||
// readLocalDeviceID reads the local device ID from openclaw's identity file.
|
||||
func readLocalDeviceID(home string) string {
|
||||
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var auth map[string]any
|
||||
if err := json.Unmarshal(data, &auth); err != nil {
|
||||
return ""
|
||||
}
|
||||
id, _ := auth["deviceId"].(string)
|
||||
return id
|
||||
}
|
||||
|
||||
// patchScopes ensures obj[key] contains all required scopes. Returns true if
|
||||
// any scopes were added.
|
||||
func patchScopes(obj map[string]any, key string, required []string) bool {
|
||||
existing, _ := obj[key].([]any)
|
||||
have := make(map[string]bool, len(existing))
|
||||
for _, s := range existing {
|
||||
if str, ok := s.(string); ok {
|
||||
have[str] = true
|
||||
}
|
||||
}
|
||||
added := false
|
||||
for _, s := range required {
|
||||
if !have[s] {
|
||||
existing = append(existing, s)
|
||||
added = true
|
||||
}
|
||||
}
|
||||
if added {
|
||||
obj[key] = existing
|
||||
}
|
||||
return added
|
||||
}
|
||||
|
||||
func ensureOpenclawInstalled() (string, error) {
|
||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||
return "openclaw", nil
|
||||
}
|
||||
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||
return "clawdbot", nil
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("npm"); err != nil {
|
||||
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
|
||||
"Install Node.js first:\n" +
|
||||
" https://nodejs.org/\n\n" +
|
||||
"Then rerun:\n" +
|
||||
" ollama launch\n" +
|
||||
"and select OpenClaw")
|
||||
}
|
||||
|
||||
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
return "", fmt.Errorf("openclaw installation cancelled")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
|
||||
cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("failed to install openclaw: %w", err)
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("openclaw"); err != nil {
|
||||
return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
return "openclaw", nil
|
||||
}
|
||||
|
||||
func (c *Openclaw) Paths() []string {
|
||||
home, _ := os.UserHomeDir()
|
||||
p := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if _, err := os.Stat(legacy); err == nil {
|
||||
return []string{legacy}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Openclaw) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read into map[string]any to preserve unknown fields
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
// Navigate/create: models.providers.ollama (preserving other providers)
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
if modelsSection == nil {
|
||||
modelsSection = make(map[string]any)
|
||||
}
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
if providers == nil {
|
||||
providers = make(map[string]any)
|
||||
}
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
if ollama == nil {
|
||||
ollama = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama["baseUrl"] = envconfig.Host().String()
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
ollama["api"] = "ollama"
|
||||
|
||||
// Build map of existing models to preserve user customizations
|
||||
existingModels, _ := ollama["models"].([]any)
|
||||
existingByID := make(map[string]map[string]any)
|
||||
for _, m := range existingModels {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
existingByID[id] = entry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
for _, m := range models {
|
||||
entry, _ := openclawModelConfig(context.Background(), client, m)
|
||||
// Merge existing fields (user customizations)
|
||||
if existing, ok := existingByID[m]; ok {
|
||||
for k, v := range existing {
|
||||
if _, isNew := entry[k]; !isNew {
|
||||
entry[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
newModels = append(newModels, entry)
|
||||
}
|
||||
ollama["models"] = newModels
|
||||
|
||||
providers["ollama"] = ollama
|
||||
modelsSection["providers"] = providers
|
||||
config["models"] = modelsSection
|
||||
|
||||
// Update agents.defaults.model.primary (preserving other agent settings)
|
||||
agents, _ := config["agents"].(map[string]any)
|
||||
if agents == nil {
|
||||
agents = make(map[string]any)
|
||||
}
|
||||
defaults, _ := agents["defaults"].(map[string]any)
|
||||
if defaults == nil {
|
||||
defaults = make(map[string]any)
|
||||
}
|
||||
modelConfig, _ := defaults["model"].(map[string]any)
|
||||
if modelConfig == nil {
|
||||
modelConfig = make(map[string]any)
|
||||
}
|
||||
modelConfig["primary"] = "ollama/" + models[0]
|
||||
defaults["model"] = modelConfig
|
||||
agents["defaults"] = defaults
|
||||
config["agents"] = agents
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clear any per-session model overrides so the new primary takes effect
|
||||
// immediately rather than being shadowed by a cached modelOverride.
|
||||
clearSessionModelOverride(models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearSessionModelOverride removes per-session model overrides from the main
|
||||
// agent session so the global primary model takes effect on the next TUI launch.
|
||||
func clearSessionModelOverride(primary string) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var sessions map[string]map[string]any
|
||||
if json.Unmarshal(data, &sessions) != nil {
|
||||
return
|
||||
}
|
||||
changed := false
|
||||
for _, sess := range sessions {
|
||||
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
||||
delete(sess, "modelOverride")
|
||||
delete(sess, "providerOverride")
|
||||
sess["model"] = primary
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
out, err := json.MarshalIndent(sessions, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(path, out, 0o600)
|
||||
}
|
||||
|
||||
const webSearchNpmPackage = "@ollama/openclaw-web-search"
|
||||
|
||||
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
|
||||
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
|
||||
// present. Returns true if the extension is available.
|
||||
func ensureWebSearchPlugin() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
|
||||
return true // already installed
|
||||
}
|
||||
|
||||
npmBin, err := exec.LookPath("npm")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Download the tarball via `npm pack`, extract it flat into the plugin dir.
|
||||
pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
|
||||
out, err := pack.Output()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
|
||||
return false
|
||||
}
|
||||
|
||||
tgzName := strings.TrimSpace(string(out))
|
||||
tgzPath := filepath.Join(pluginDir, tgzName)
|
||||
defer os.Remove(tgzPath)
|
||||
|
||||
tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
|
||||
if err := tar.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
|
||||
return false
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
|
||||
return true
|
||||
}
|
||||
|
||||
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
|
||||
// config so the gateway activates it on next start. Best-effort; silently returns
|
||||
// on any error.
|
||||
func registerWebSearchPlugin() {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var config map[string]any
|
||||
if json.Unmarshal(data, &config) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
plugins, _ := config["plugins"].(map[string]any)
|
||||
if plugins == nil {
|
||||
plugins = make(map[string]any)
|
||||
}
|
||||
entries, _ := plugins["entries"].(map[string]any)
|
||||
if entries == nil {
|
||||
entries = make(map[string]any)
|
||||
}
|
||||
if _, ok := entries["openclaw-web-search"]; ok {
|
||||
return // already registered
|
||||
}
|
||||
entries["openclaw-web-search"] = map[string]any{"enabled": true}
|
||||
plugins["entries"] = entries
|
||||
config["plugins"] = plugins
|
||||
|
||||
// Disable the built-in web search since our plugin replaces it.
|
||||
tools, _ := config["tools"].(map[string]any)
|
||||
if tools == nil {
|
||||
tools = make(map[string]any)
|
||||
}
|
||||
web, _ := tools["web"].(map[string]any)
|
||||
if web == nil {
|
||||
web = make(map[string]any)
|
||||
}
|
||||
web["search"] = map[string]any{"enabled": false}
|
||||
tools["web"] = web
|
||||
config["tools"] = tools
|
||||
|
||||
out, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(configPath, out, 0o600)
|
||||
}
|
||||
|
||||
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
|
||||
// The second return value indicates whether the model is a cloud (remote) model.
|
||||
func openclawModelConfig(ctx context.Context, client *api.Client, modelID string) (map[string]any, bool) {
|
||||
entry := map[string]any{
|
||||
"id": modelID,
|
||||
"name": modelID,
|
||||
"input": []any{"text"},
|
||||
"cost": map[string]any{
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cacheRead": 0,
|
||||
"cacheWrite": 0,
|
||||
},
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return entry, false
|
||||
}
|
||||
|
||||
showCtx := ctx
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
showCtx, cancel = context.WithTimeout(ctx, openclawModelShowTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
resp, err := client.Show(showCtx, &api.ShowRequest{Model: modelID})
|
||||
if err != nil {
|
||||
return entry, false
|
||||
}
|
||||
|
||||
// Set input types based on vision capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||
entry["input"] = []any{"text", "image"}
|
||||
}
|
||||
|
||||
// Set reasoning based on thinking capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
entry["reasoning"] = true
|
||||
}
|
||||
|
||||
// Cloud models: use hardcoded limits for context/output tokens.
|
||||
// Capability detection above still applies (vision, thinking).
|
||||
if resp.RemoteModel != "" {
|
||||
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||
entry["contextWindow"] = l.Context
|
||||
entry["maxTokens"] = l.Output
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
// Extract context window from ModelInfo (local models only)
|
||||
for key, val := range resp.ModelInfo {
|
||||
if strings.HasSuffix(key, ".context_length") {
|
||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||
entry["contextWindow"] = int(ctxLen)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return entry, false
|
||||
}
|
||||
|
||||
func (c *Openclaw) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||
if err != nil {
|
||||
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,279 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
base, stripped := modelref.StripCloudSourceTag(name)
|
||||
if stripped {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("opencode"); err != nil {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "opencode", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := o.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (o *OpenCode) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
p := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
paths = append(paths, p)
|
||||
}
|
||||
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if _, err := os.Stat(sp); err == nil {
|
||||
paths = append(paths, sp)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func (o *OpenCode) Edit(modelList []string) error {
|
||||
if len(modelList) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
|
||||
}
|
||||
|
||||
config["$schema"] = "https://opencode.ai/config.json"
|
||||
|
||||
provider, ok := config["provider"].(map[string]any)
|
||||
if !ok {
|
||||
provider = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate legacy provider name
|
||||
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||
ollama["name"] = "Ollama"
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
}
|
||||
|
||||
selectedSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
selectedSet[m] = true
|
||||
}
|
||||
|
||||
for name, cfg := range models {
|
||||
if cfgMap, ok := cfg.(map[string]any); ok {
|
||||
if isOllamaModel(cfgMap) && !selectedSet[name] {
|
||||
delete(models, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
if isOllamaModel(existing) {
|
||||
existing["_launch"] = true
|
||||
if name, ok := existing["name"].(string); ok {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
entry := map[string]any{
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
models[model] = entry
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
provider["ollama"] = ollama
|
||||
config["provider"] = provider
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := map[string]any{
|
||||
"recent": []any{},
|
||||
"favorite": []any{},
|
||||
"variant": map[string]any{},
|
||||
}
|
||||
if data, err := os.ReadFile(statePath); err == nil {
|
||||
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
|
||||
}
|
||||
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
modelSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
modelSet[m] = true
|
||||
}
|
||||
|
||||
// Filter out existing Ollama models we're about to re-add
|
||||
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
|
||||
e, ok := entry.(map[string]any)
|
||||
if !ok || e["providerID"] != "ollama" {
|
||||
return false
|
||||
}
|
||||
modelID, _ := e["modelID"].(string)
|
||||
return modelSet[modelID]
|
||||
})
|
||||
|
||||
// Prepend models in reverse order so first model ends up first
|
||||
for _, model := range slices.Backward(modelList) {
|
||||
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
|
||||
"providerID": "ollama",
|
||||
"modelID": model,
|
||||
}))
|
||||
}
|
||||
|
||||
const maxRecentModels = 10
|
||||
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
|
||||
|
||||
state["recent"] = newRecent
|
||||
|
||||
stateData, err := json.MarshalIndent(state, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(statePath, stateData)
|
||||
}
|
||||
|
||||
func (o *OpenCode) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
provider, _ := config["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := slices.Collect(maps.Keys(models))
|
||||
slices.Sort(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
// isOllamaModel reports whether a model config entry is managed by us
|
||||
func isOllamaModel(cfg map[string]any) bool {
|
||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||
return true
|
||||
}
|
||||
// previously used [Ollama] as a suffix for the model managed by ollama launch
|
||||
if name, ok := cfg["name"].(string); ok {
|
||||
return strings.HasSuffix(name, "[Ollama]")
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,762 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenCodeIntegration(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := o.String(); got != "OpenCode" {
|
||||
t.Errorf("String() = %q, want %q", got, "OpenCode")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = o
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = o
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(configDir)
|
||||
os.RemoveAll(stateDir)
|
||||
}
|
||||
|
||||
t.Run("fresh install", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
if provider["anthropic"] == nil {
|
||||
t.Error("anthropic provider was removed")
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("update existing model", func(t *testing.T) {
|
||||
cleanup()
|
||||
o.Edit([]string{"llama3.2"})
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme was removed")
|
||||
}
|
||||
if cfg["keybindings"] == nil {
|
||||
t.Error("keybindings was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model state - insert at index 0", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
|
||||
})
|
||||
|
||||
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
if len(state["favorite"].([]any)) != 1 {
|
||||
t.Error("favorite was modified")
|
||||
}
|
||||
if state["variant"].(map[string]any)["a"] != "b" {
|
||||
t.Error("variant was modified")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
recent := state["recent"].([]any)
|
||||
if len(recent) != 2 {
|
||||
t.Errorf("expected 2 recent entries, got %d", len(recent))
|
||||
}
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
t.Run("remove model", func(t *testing.T) {
|
||||
cleanup()
|
||||
// First add two models
|
||||
o.Edit([]string{"llama3.2", "mistral"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
|
||||
// Then remove one by only selecting the other
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on managed models", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add custom fields to the model entry (simulating user edits)
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry := models["llama3.2"].(map[string]any)
|
||||
entry["_myPref"] = "custom-value"
|
||||
entry["_myNum"] = 42
|
||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configPath, configData, 0o644)
|
||||
|
||||
// Re-run Edit — should preserve custom fields
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ = os.ReadFile(configPath)
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider = cfg["provider"].(map[string]any)
|
||||
ollama = provider["ollama"].(map[string]any)
|
||||
models = ollama["models"].(map[string]any)
|
||||
entry = models["llama3.2"].(map[string]any)
|
||||
|
||||
if entry["_myPref"] != "custom-value" {
|
||||
t.Errorf("_myPref was lost: got %v", entry["_myPref"])
|
||||
}
|
||||
if entry["_myNum"] != float64(42) {
|
||||
t.Errorf("_myNum was lost: got %v", entry["_myNum"])
|
||||
}
|
||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
||||
t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
|
||||
cleanup()
|
||||
// Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry := models["llama3.2"].(map[string]any)
|
||||
|
||||
// _launch marker should be added
|
||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
||||
t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
|
||||
}
|
||||
// [Ollama] suffix should be stripped
|
||||
if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
|
||||
t.Errorf("name suffix not stripped: got %q", entry["name"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "Ollama" {
|
||||
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve custom provider name", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "My Custom Ollama" {
|
||||
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
// Add a non-Ollama model manually
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
|
||||
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
|
||||
})
|
||||
}
|
||||
|
||||
func assertOpenCodeModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider not found")
|
||||
}
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("ollama provider not found")
|
||||
}
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("models not found")
|
||||
}
|
||||
if models[model] == nil {
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
return // No provider means no model
|
||||
}
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
return // No ollama means no model
|
||||
}
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
return // No models means no model
|
||||
}
|
||||
if models[model] != nil {
|
||||
t.Errorf("model %s should not exist but was found", model)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
recent, ok := state["recent"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("recent not found")
|
||||
}
|
||||
if index >= len(recent) {
|
||||
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
|
||||
}
|
||||
entry, ok := recent[index].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("entry is not a map")
|
||||
}
|
||||
if entry["providerID"] != providerID {
|
||||
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
|
||||
}
|
||||
if entry["modelID"] != modelID {
|
||||
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
|
||||
}
|
||||
}
|
||||
|
||||
// Edge case tests for opencode.go
|
||||
|
||||
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
|
||||
|
||||
// Should not panic - corrupted JSON should be treated as empty
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with corrupted config: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid JSON was created
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Errorf("resulting config is not valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit failed with corrupted state: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid state was created
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Errorf("resulting state is not valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type provider failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify provider is now correct type
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
|
||||
provider, ok := cfg["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
|
||||
}
|
||||
if provider["ollama"] == nil {
|
||||
t.Error("ollama provider should be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
statePath := filepath.Join(stateDir, "model.json")
|
||||
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
|
||||
|
||||
err := o.Edit([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with wrong type recent failed: %v", err)
|
||||
}
|
||||
|
||||
// The function should handle this gracefully
|
||||
data, _ := os.ReadFile(statePath)
|
||||
var state map[string]any
|
||||
json.Unmarshal(data, &state)
|
||||
|
||||
// recent should be properly set after setup
|
||||
recent, ok := state["recent"].([]any)
|
||||
if !ok {
|
||||
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
|
||||
} else if len(recent) == 0 {
|
||||
t.Logf("Note: recent is empty (documenting behavior)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
|
||||
os.WriteFile(configPath, []byte(originalContent), 0o644)
|
||||
|
||||
// Empty models should be no-op
|
||||
err := o.Edit([]string{})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with empty models failed: %v", err)
|
||||
}
|
||||
|
||||
// Original content should be preserved (file not modified)
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != originalContent {
|
||||
t.Errorf("empty models should not modify file, but content changed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Model name with special characters (though unusual)
|
||||
specialModel := `model-with-"quotes"`
|
||||
|
||||
err := o.Edit([]string{specialModel})
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was stored correctly
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
data, _ := os.ReadFile(configPath)
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("resulting config is invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Model should be accessible
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
|
||||
if models[specialModel] == nil {
|
||||
t.Errorf("model with special chars not found in config")
|
||||
}
|
||||
}
|
||||
|
||||
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry, ok := models[model].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("model %s not found in config", model)
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
if entry["limit"] != nil {
|
||||
t.Errorf("local model should not have limit set, got %v", entry["limit"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
// Set up a model with a user-configured limit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"llama3.2": {
|
||||
"name": "llama3.2",
|
||||
"_launch": true,
|
||||
"limit": {"context": 8192, "output": 4096}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
// Re-edit should preserve the user's limit (not delete it)
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("user-configured limit was removed")
|
||||
}
|
||||
if limit["context"] != float64(8192) {
|
||||
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
|
||||
}
|
||||
if limit["output"] != float64(4096) {
|
||||
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||
// Verify that when a cloud model entry has limits set (as Edit would do),
|
||||
// the structure matches what opencode expects and re-edit preserves them.
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
|
||||
expected := cloudModelLimits["glm-4.7"]
|
||||
|
||||
// Simulate a cloud model that already has the limit set by a previous Edit
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud",
|
||||
"_launch": true,
|
||||
"limit": {"context": %d, "output": %d}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, expected.Context, expected.Output)), 0o644)
|
||||
|
||||
// Re-edit should preserve the cloud model limit
|
||||
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model limit was removed on re-edit")
|
||||
}
|
||||
if limit["context"] != float64(expected.Context) {
|
||||
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||
}
|
||||
if limit["output"] != float64(expected.Output) {
|
||||
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"glm-5:cloud": {
|
||||
"name": "glm-5:cloud",
|
||||
"_launch": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model limit was not added on re-edit")
|
||||
}
|
||||
if limit["context"] != float64(202_752) {
|
||||
t.Errorf("context = %v, want 202752", limit["context"])
|
||||
}
|
||||
if limit["output"] != float64(131_072) {
|
||||
t.Errorf("output = %v, want 131072", limit["output"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupCloudModelLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantOK bool
|
||||
wantContext int
|
||||
wantOutput int
|
||||
}{
|
||||
{"glm-4.7", false, 0, 0},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"glm-5:cloud", true, 202_752, 131_072},
|
||||
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||
{"kimi-k2.5", false, 0, 0},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", false, 0, 0},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3.5", false, 0, 0},
|
||||
{"qwen3.5:cloud", true, 262_144, 32_768},
|
||||
{"qwen3-coder:480b", false, 0, 0},
|
||||
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
|
||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(tt.name)
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
|
||||
}
|
||||
if ok {
|
||||
if l.Context != tt.wantContext {
|
||||
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
|
||||
}
|
||||
if l.Output != tt.wantOutput {
|
||||
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
models := o.Models()
|
||||
if len(models) > 0 {
|
||||
t.Errorf("expected nil/empty for missing config, got %v", models)
|
||||
}
|
||||
}
|
||||
261
cmd/config/pi.go
261
cmd/config/pi.go
@@ -1,261 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
|
||||
type Pi struct{}
|
||||
|
||||
func (p *Pi) String() string { return "Pi" }
|
||||
|
||||
func (p *Pi) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("pi"); err != nil {
|
||||
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := p.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("pi", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (p *Pi) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
if _, err := os.Stat(modelsPath); err == nil {
|
||||
paths = append(paths, modelsPath)
|
||||
}
|
||||
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||
if _, err := os.Stat(settingsPath); err == nil {
|
||||
paths = append(paths, settingsPath)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func (p *Pi) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
providers, ok := config["providers"].(map[string]any)
|
||||
if !ok {
|
||||
providers = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"baseUrl": envconfig.Host().String() + "/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
}
|
||||
}
|
||||
|
||||
existingModels, ok := ollama["models"].([]any)
|
||||
if !ok {
|
||||
existingModels = make([]any, 0)
|
||||
}
|
||||
|
||||
// Build set of selected models to track which need to be added
|
||||
selectedSet := make(map[string]bool, len(models))
|
||||
for _, m := range models {
|
||||
selectedSet[m] = true
|
||||
}
|
||||
|
||||
// Build new models list:
|
||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||
// 2. Keep ollama-managed models (_launch marker) that are still selected,
|
||||
// except stale cloud entries that should be rebuilt below
|
||||
// 3. Add new ollama-managed models
|
||||
var newModels []any
|
||||
for _, m := range existingModels {
|
||||
if modelObj, ok := m.(map[string]any); ok {
|
||||
if id, ok := modelObj["id"].(string); ok {
|
||||
// User-managed model (no _launch marker) - always preserve
|
||||
if !isPiOllamaModel(modelObj) {
|
||||
newModels = append(newModels, m)
|
||||
} else if selectedSet[id] {
|
||||
// Rebuild stale managed cloud entries so createConfig refreshes
|
||||
// the whole entry instead of patching it in place.
|
||||
if !hasContextWindow(modelObj) {
|
||||
if _, ok := lookupCloudModelLimit(id); ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
newModels = append(newModels, m)
|
||||
selectedSet[id] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add newly selected models that weren't already in the list
|
||||
client := api.NewClient(envconfig.Host(), http.DefaultClient)
|
||||
ctx := context.Background()
|
||||
for _, model := range models {
|
||||
if selectedSet[model] {
|
||||
newModels = append(newModels, createConfig(ctx, client, model))
|
||||
}
|
||||
}
|
||||
|
||||
ollama["models"] = newModels
|
||||
providers["ollama"] = ollama
|
||||
config["providers"] = providers
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update settings.json with default provider and model
|
||||
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||
settings := make(map[string]any)
|
||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||
_ = json.Unmarshal(data, &settings)
|
||||
}
|
||||
|
||||
settings["defaultProvider"] = "ollama"
|
||||
settings["defaultModel"] = models[0]
|
||||
|
||||
settingsData, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, settingsData)
|
||||
}
|
||||
|
||||
func (p *Pi) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
config, err := readJSONFile(configPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
providers, _ := config["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range models {
|
||||
if modelObj, ok := m.(map[string]any); ok {
|
||||
if id, ok := modelObj["id"].(string); ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
slices.Sort(result)
|
||||
return result
|
||||
}
|
||||
|
||||
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
|
||||
func isPiOllamaModel(cfg map[string]any) bool {
|
||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasContextWindow(cfg map[string]any) bool {
|
||||
switch v := cfg["contextWindow"].(type) {
|
||||
case float64:
|
||||
return v > 0
|
||||
case int:
|
||||
return v > 0
|
||||
case int64:
|
||||
return v > 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// createConfig builds Pi model config with capability detection
|
||||
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||
cfg := map[string]any{
|
||||
"id": modelID,
|
||||
"_launch": true,
|
||||
}
|
||||
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||
cfg["contextWindow"] = l.Context
|
||||
}
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||
if err != nil {
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Set input types based on vision capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||
cfg["input"] = []string{"text", "image"}
|
||||
} else {
|
||||
cfg["input"] = []string{"text"}
|
||||
}
|
||||
|
||||
// Set reasoning based on thinking capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
cfg["reasoning"] = true
|
||||
}
|
||||
|
||||
// Extract context window from ModelInfo. For known cloud models, the
|
||||
// pre-filled shared limit remains unless the server provides a positive value.
|
||||
for key, val := range resp.ModelInfo {
|
||||
if strings.HasSuffix(key, ".context_length") {
|
||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||
cfg["contextWindow"] = int(ctxLen)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
@@ -1,926 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestPiIntegration(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := pi.String(); got != "Pi" {
|
||||
t.Errorf("String() = %q, want %q", got, "Pi")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = pi
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = pi
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiPaths(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("returns empty when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
paths := pi.Paths()
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("Paths() = %v, want empty", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
paths := pi.Paths()
|
||||
if len(paths) != 1 || paths[0] != configPath {
|
||||
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiEdit(t *testing.T) {
|
||||
// Mock Ollama server for createConfig calls during Edit
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
pi := &Pi{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(configDir)
|
||||
}
|
||||
|
||||
readConfig := func() map[string]any {
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
return cfg
|
||||
}
|
||||
|
||||
t.Run("returns nil for empty models", func(t *testing.T) {
|
||||
if err := pi.Edit([]string{}); err != nil {
|
||||
t.Errorf("Edit([]) error = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates config with models", func(t *testing.T) {
|
||||
cleanup()
|
||||
|
||||
models := []string{"llama3.2", "qwen3:8b"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
|
||||
providers, ok := cfg["providers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Config missing providers")
|
||||
}
|
||||
|
||||
ollama, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Providers missing ollama")
|
||||
}
|
||||
|
||||
modelsArray, ok := ollama["models"].([]any)
|
||||
if !ok || len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %v", modelsArray)
|
||||
}
|
||||
|
||||
if ollama["baseUrl"] == nil {
|
||||
t.Error("Missing baseUrl")
|
||||
}
|
||||
if ollama["api"] != "openai-completions" {
|
||||
t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
|
||||
}
|
||||
if ollama["apiKey"] != "ollama" {
|
||||
t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://custom:8080/v1",
|
||||
"api": "custom-api",
|
||||
"apiKey": "custom-key",
|
||||
"models": [
|
||||
{"id": "old-model", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"new-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
|
||||
if ollama["baseUrl"] != "http://custom:8080/v1" {
|
||||
t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
|
||||
}
|
||||
if ollama["api"] != "custom-api" {
|
||||
t.Errorf("Custom api not preserved, got %v", ollama["api"])
|
||||
}
|
||||
if ollama["apiKey"] != "custom-key" {
|
||||
t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
|
||||
}
|
||||
|
||||
modelsArray := ollama["models"].([]any)
|
||||
if len(modelsArray) != 1 {
|
||||
t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
|
||||
} else {
|
||||
modelEntry := modelsArray[0].(map[string]any)
|
||||
if modelEntry["id"] != "new-model" {
|
||||
t.Errorf("Expected new-model, got %v", modelEntry["id"])
|
||||
}
|
||||
// Verify _launch marker is present
|
||||
if modelEntry["_launch"] != true {
|
||||
t.Errorf("Expected _launch marker to be true")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
modelEntry := modelsArray[0].(map[string]any)
|
||||
|
||||
if modelEntry["contextWindow"] != float64(202_752) {
|
||||
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
|
||||
}
|
||||
input, ok := modelEntry["input"].([]any)
|
||||
if !ok || len(input) != 1 || input[0] != "text" {
|
||||
t.Errorf("input = %v, want [text]", modelEntry["input"])
|
||||
}
|
||||
if _, ok := modelEntry["legacyField"]; ok {
|
||||
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Old models must have _launch marker to be managed by us
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "old-model-1", "_launch": true},
|
||||
{"id": "old-model-2", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newModels := []string{"new-model-1", "new-model-2"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = true
|
||||
}
|
||||
|
||||
if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
|
||||
t.Errorf("Expected new models, got %v", modelIDs)
|
||||
}
|
||||
if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
|
||||
t.Errorf("Old models should have been removed, got %v", modelIDs)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles partial overlap in model list", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Models must have _launch marker to be managed
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "keep-model", "_launch": true},
|
||||
{"id": "remove-model", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newModels := []string{"keep-model", "add-model"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 2 {
|
||||
t.Errorf("Expected 2 models, got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = true
|
||||
}
|
||||
|
||||
if !modelIDs["keep-model"] || !modelIDs["add-model"] {
|
||||
t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
|
||||
}
|
||||
if modelIDs["remove-model"] {
|
||||
t.Errorf("remove-model should have been removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt config gracefully", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"test-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read config: %v", err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
|
||||
}
|
||||
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
if len(modelsArray) != 1 {
|
||||
t.Errorf("Expected 1 model, got %d", len(modelsArray))
|
||||
}
|
||||
})
|
||||
|
||||
// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
|
||||
t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// User has manually configured models in ollama provider (no _launch marker)
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "user-model-1"},
|
||||
{"id": "user-model-2", "customField": "preserved"},
|
||||
{"id": "ollama-managed", "_launch": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add a new ollama-managed model
|
||||
newModels := []string{"new-ollama-model"}
|
||||
if err := pi.Edit(newModels); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
|
||||
// Should have: new-ollama-model (managed) + 2 user models (preserved)
|
||||
if len(modelsArray) != 3 {
|
||||
t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
|
||||
}
|
||||
|
||||
modelIDs := make(map[string]map[string]any)
|
||||
for _, m := range modelsArray {
|
||||
modelObj := m.(map[string]any)
|
||||
id := modelObj["id"].(string)
|
||||
modelIDs[id] = modelObj
|
||||
}
|
||||
|
||||
// Verify new model has _launch marker
|
||||
if m, ok := modelIDs["new-ollama-model"]; !ok {
|
||||
t.Errorf("new-ollama-model should be present")
|
||||
} else if m["_launch"] != true {
|
||||
t.Errorf("new-ollama-model should have _launch marker")
|
||||
}
|
||||
|
||||
// Verify user models are preserved
|
||||
if _, ok := modelIDs["user-model-1"]; !ok {
|
||||
t.Errorf("user-model-1 should be preserved")
|
||||
}
|
||||
if _, ok := modelIDs["user-model-2"]; !ok {
|
||||
t.Errorf("user-model-2 should be preserved")
|
||||
} else if modelIDs["user-model-2"]["customField"] != "preserved" {
|
||||
t.Errorf("user-model-2 customField should be preserved")
|
||||
}
|
||||
|
||||
// Verify old ollama-managed model is removed (not in new list)
|
||||
if _, ok := modelIDs["ollama-managed"]; ok {
|
||||
t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates settings.json with default provider and model", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Create existing settings with other fields
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
existingSettings := `{
|
||||
"theme": "dark",
|
||||
"customSetting": "value",
|
||||
"defaultProvider": "anthropic",
|
||||
"defaultModel": "claude-3"
|
||||
}`
|
||||
if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"llama3.2"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read settings: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("Failed to parse settings: %v", err)
|
||||
}
|
||||
|
||||
// Verify defaultProvider is set to ollama
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
|
||||
// Verify defaultModel is set to first model
|
||||
if settings["defaultModel"] != "llama3.2" {
|
||||
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
|
||||
}
|
||||
|
||||
// Verify other fields are preserved
|
||||
if settings["theme"] != "dark" {
|
||||
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
|
||||
}
|
||||
if settings["customSetting"] != "value" {
|
||||
t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates settings.json if it does not exist", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
models := []string{"qwen3:8b"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("settings.json should be created: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("Failed to parse settings: %v", err)
|
||||
}
|
||||
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
if settings["defaultModel"] != "qwen3:8b" {
|
||||
t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
// Create corrupt settings
|
||||
settingsPath := filepath.Join(configDir, "settings.json")
|
||||
if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := []string{"test-model"}
|
||||
if err := pi.Edit(models); err != nil {
|
||||
t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read settings: %v", err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
|
||||
}
|
||||
|
||||
if settings["defaultProvider"] != "ollama" {
|
||||
t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
|
||||
}
|
||||
if settings["defaultModel"] != "test-model" {
|
||||
t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiModels(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns models from config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"models": [
|
||||
{"id": "llama3.2"},
|
||||
{"id": "qwen3:8b"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if len(models) != 2 {
|
||||
t.Errorf("Models() returned %d models, want 2", len(models))
|
||||
}
|
||||
if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
|
||||
t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns sorted models", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"models": [
|
||||
{"id": "z-model"},
|
||||
{"id": "a-model"},
|
||||
{"id": "m-model"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
|
||||
t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when models array is missing", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := `{
|
||||
"providers": {
|
||||
"ollama": {}
|
||||
}
|
||||
}`
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil when models array is missing", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupt config gracefully", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".pi", "agent")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "models.json")
|
||||
if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := pi.Models()
|
||||
if models != nil {
|
||||
t.Errorf("Models() = %v, want nil for corrupt config", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsPiOllamaModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg map[string]any
|
||||
want bool
|
||||
}{
|
||||
{"with _launch true", map[string]any{"id": "m", "_launch": true}, true},
|
||||
{"with _launch false", map[string]any{"id": "m", "_launch": false}, false},
|
||||
{"without _launch", map[string]any{"id": "m"}, false},
|
||||
{"with _launch non-bool", map[string]any{"id": "m", "_launch": "yes"}, false},
|
||||
{"empty map", map[string]any{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isPiOllamaModel(tt.cfg); got != tt.want {
|
||||
t.Errorf("isPiOllamaModel(%v) = %v, want %v", tt.cfg, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateConfig(t *testing.T) {
|
||||
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "llava:7b")
|
||||
|
||||
if cfg["id"] != "llava:7b" {
|
||||
t.Errorf("id = %v, want llava:7b", cfg["id"])
|
||||
}
|
||||
if cfg["_launch"] != true {
|
||||
t.Error("expected _launch = true")
|
||||
}
|
||||
input, ok := cfg["input"].([]string)
|
||||
if !ok || len(input) != 2 || input[0] != "text" || input[1] != "image" {
|
||||
t.Errorf("input = %v, want [text image]", cfg["input"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "llama3.2")
|
||||
|
||||
input, ok := cfg["input"].([]string)
|
||||
if !ok || len(input) != 1 || input[0] != "text" {
|
||||
t.Errorf("input = %v, want [text]", cfg["input"])
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set for non-thinking model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "qwq")
|
||||
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true for thinking model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("extracts context window from model info", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "llama3.2")
|
||||
|
||||
if cfg["contextWindow"] != 131072 {
|
||||
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles all capabilities together", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "qwen3-vision")
|
||||
|
||||
input := cfg["input"].([]string)
|
||||
if len(input) != 2 || input[0] != "text" || input[1] != "image" {
|
||||
t.Errorf("input = %v, want [text image]", input)
|
||||
}
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true")
|
||||
}
|
||||
if cfg["contextWindow"] != 32768 {
|
||||
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns minimal config when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "missing-model")
|
||||
|
||||
if cfg["id"] != "missing-model" {
|
||||
t.Errorf("id = %v, want missing-model", cfg["id"])
|
||||
}
|
||||
if cfg["_launch"] != true {
|
||||
t.Error("expected _launch = true")
|
||||
}
|
||||
// Should not have capability fields
|
||||
if _, ok := cfg["input"]; ok {
|
||||
t.Error("input should not be set when show fails")
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set when show fails")
|
||||
}
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set when show fails")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to cloud context when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
|
||||
|
||||
if cfg["contextWindow"] != 262_144 {
|
||||
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to cloud context when model info is empty", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "glm-5:cloud")
|
||||
|
||||
if cfg["contextWindow"] != 202_752 {
|
||||
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to cloud context for dash cloud suffix", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
|
||||
|
||||
if cfg["contextWindow"] != 131_072 {
|
||||
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips zero context length", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "test-model")
|
||||
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set for zero value")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure Capability constants used in createConfig match expected values
|
||||
func TestPiCapabilityConstants(t *testing.T) {
|
||||
if model.CapabilityVision != "vision" {
|
||||
t.Errorf("CapabilityVision = %q, want %q", model.CapabilityVision, "vision")
|
||||
}
|
||||
if model.CapabilityThinking != "thinking" {
|
||||
t.Errorf("CapabilityThinking = %q, want %q", model.CapabilityThinking, "thinking")
|
||||
}
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiYellow = "\033[33m"
|
||||
)
|
||||
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
// errCancelled is kept as an alias for backward compatibility within the package.
|
||||
var errCancelled = ErrCancelled
|
||||
|
||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
|
||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
||||
|
||||
func confirmPrompt(prompt string) (bool, error) {
|
||||
if DefaultConfirmPrompt != nil {
|
||||
return DefaultConfirmPrompt(prompt)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
t.Error("errCancelled should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Message", func(t *testing.T) {
|
||||
if errCancelled.Error() != "cancelled" {
|
||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user