mirror of
https://github.com/ollama/ollama.git
synced 2026-04-21 08:15:42 +02:00
Compare commits
9 Commits
brucemacd/
...
v0.15.0-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2eda97f1c3 | ||
|
|
66831dcf70 | ||
|
|
1044b0419a | ||
|
|
771d9280ec | ||
|
|
862bc0a3bf | ||
|
|
c01608b6a1 | ||
|
|
199c41e16e | ||
|
|
3b3bf6c217 | ||
|
|
f52c21f457 |
@@ -35,6 +35,7 @@ import (
|
|||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
@@ -1018,8 +1019,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if resp.ModelInfo != nil {
|
if resp.ModelInfo != nil {
|
||||||
arch := resp.ModelInfo["general.architecture"].(string)
|
arch, _ := resp.ModelInfo["general.architecture"].(string)
|
||||||
|
if arch != "" {
|
||||||
rows = append(rows, []string{"", "architecture", arch})
|
rows = append(rows, []string{"", "architecture", arch})
|
||||||
|
}
|
||||||
|
|
||||||
var paramStr string
|
var paramStr string
|
||||||
if resp.Details.ParameterSize != "" {
|
if resp.Details.ParameterSize != "" {
|
||||||
@@ -1029,7 +1032,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
|||||||
paramStr = format.HumanNumber(uint64(f))
|
paramStr = format.HumanNumber(uint64(f))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if paramStr != "" {
|
||||||
rows = append(rows, []string{"", "parameters", paramStr})
|
rows = append(rows, []string{"", "parameters", paramStr})
|
||||||
|
}
|
||||||
|
|
||||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||||
if f, ok := v.(float64); ok {
|
if f, ok := v.(float64); ok {
|
||||||
@@ -2026,6 +2031,7 @@ func NewCLI() *cobra.Command {
|
|||||||
copyCmd,
|
copyCmd,
|
||||||
deleteCmd,
|
deleteCmd,
|
||||||
runnerCmd,
|
runnerCmd,
|
||||||
|
config.ConfigCmd(checkServerHeartbeat),
|
||||||
)
|
)
|
||||||
|
|
||||||
return rootCmd
|
return rootCmd
|
||||||
|
|||||||
36
cmd/config/claude.go
Normal file
36
cmd/config/claude.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claude implements Runner for Claude Code integration
|
||||||
|
type Claude struct{}
|
||||||
|
|
||||||
|
func (c *Claude) String() string { return "Claude Code" }
|
||||||
|
|
||||||
|
func (c *Claude) args(model string) []string {
|
||||||
|
if model != "" {
|
||||||
|
return []string{"--model", model}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Claude) Run(model string) error {
|
||||||
|
if _, err := exec.LookPath("claude"); err != nil {
|
||||||
|
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("claude", c.args(model)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
cmd.Env = append(os.Environ(),
|
||||||
|
"ANTHROPIC_BASE_URL=http://localhost:11434",
|
||||||
|
"ANTHROPIC_API_KEY=",
|
||||||
|
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||||
|
)
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
42
cmd/config/claude_test.go
Normal file
42
cmd/config/claude_test.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClaudeIntegration(t *testing.T) {
|
||||||
|
c := &Claude{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := c.String(); got != "Claude Code" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "Claude Code")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeArgs(t *testing.T) {
|
||||||
|
c := &Claude{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||||
|
{"empty model", "", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := c.args(tt.model)
|
||||||
|
if !slices.Equal(got, tt.want) {
|
||||||
|
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
61
cmd/config/codex.go
Normal file
61
cmd/config/codex.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/mod/semver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Codex implements Runner for Codex integration
|
||||||
|
type Codex struct{}
|
||||||
|
|
||||||
|
func (c *Codex) String() string { return "Codex" }
|
||||||
|
|
||||||
|
func (c *Codex) args(model string) []string {
|
||||||
|
args := []string{"--oss"}
|
||||||
|
if model != "" {
|
||||||
|
args = append(args, "-m", model)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Codex) Run(model string) error {
|
||||||
|
if err := checkCodexVersion(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("codex", c.args(model)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkCodexVersion() error {
|
||||||
|
if _, err := exec.LookPath("codex"); err != nil {
|
||||||
|
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command("codex", "--version").Output()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get codex version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse output like "codex-cli 0.87.0"
|
||||||
|
fields := strings.Fields(strings.TrimSpace(string(out)))
|
||||||
|
if len(fields) < 2 {
|
||||||
|
return fmt.Errorf("unexpected codex version output: %s", string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
version := "v" + fields[len(fields)-1]
|
||||||
|
minVersion := "v0.81.0"
|
||||||
|
|
||||||
|
if semver.Compare(version, minVersion) < 0 {
|
||||||
|
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
28
cmd/config/codex_test.go
Normal file
28
cmd/config/codex_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCodexArgs(t *testing.T) {
|
||||||
|
c := &Codex{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||||
|
{"empty model", "", []string{"--oss"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := c.args(tt.model)
|
||||||
|
if !slices.Equal(got, tt.want) {
|
||||||
|
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
115
cmd/config/config.go
Normal file
115
cmd/config/config.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
// Package config provides integration configuration for external coding tools
|
||||||
|
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type integration struct {
|
||||||
|
Models []string `json:"models"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
Integrations map[string]*integration `json:"integrations"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func configPath() (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func load() (*config, error) {
|
||||||
|
path, err := configPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return &config{Integrations: make(map[string]*integration)}, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg config
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
|
||||||
|
}
|
||||||
|
if cfg.Integrations == nil {
|
||||||
|
cfg.Integrations = make(map[string]*integration)
|
||||||
|
}
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func save(cfg *config) error {
|
||||||
|
path, err := configPath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(cfg, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return writeWithBackup(path, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveIntegration(appName string, models []string) error {
|
||||||
|
if appName == "" {
|
||||||
|
return errors.New("app name cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||||
|
Models: models,
|
||||||
|
}
|
||||||
|
|
||||||
|
return save(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadIntegration(appName string) (*integration, error) {
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ic, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||||
|
if !ok {
|
||||||
|
return nil, os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
return ic, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func listIntegrations() ([]integration, error) {
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]integration, 0, len(cfg.Integrations))
|
||||||
|
for _, ic := range cfg.Integrations {
|
||||||
|
result = append(result, *ic)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
373
cmd/config/config_test.go
Normal file
373
cmd/config/config_test.go
Normal file
@@ -0,0 +1,373 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||||
|
func setTestHome(t *testing.T, dir string) {
|
||||||
|
t.Setenv("HOME", dir)
|
||||||
|
t.Setenv("USERPROFILE", dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
||||||
|
func editorPaths(r Runner) []string {
|
||||||
|
if editor, ok := r.(Editor); ok {
|
||||||
|
return editor.Paths()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationConfig(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
t.Run("save and load round-trip", func(t *testing.T) {
|
||||||
|
models := []string{"llama3.2", "mistral", "qwen2.5"}
|
||||||
|
if err := saveIntegration("claude", models); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := loadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.Models) != len(models) {
|
||||||
|
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
|
||||||
|
}
|
||||||
|
for i, m := range models {
|
||||||
|
if config.Models[i] != m {
|
||||||
|
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||||
|
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||||
|
|
||||||
|
config, _ := loadIntegration("codex")
|
||||||
|
defaultModel := ""
|
||||||
|
if len(config.Models) > 0 {
|
||||||
|
defaultModel = config.Models[0]
|
||||||
|
}
|
||||||
|
if defaultModel != "model-a" {
|
||||||
|
t.Errorf("expected model-a, got %s", defaultModel)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
|
||||||
|
config := &integration{Models: []string{}}
|
||||||
|
defaultModel := ""
|
||||||
|
if len(config.Models) > 0 {
|
||||||
|
defaultModel = config.Models[0]
|
||||||
|
}
|
||||||
|
if defaultModel != "" {
|
||||||
|
t.Errorf("expected empty string, got %s", defaultModel)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||||
|
saveIntegration("Claude", []string{"model-x"})
|
||||||
|
|
||||||
|
config, err := loadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defaultModel := ""
|
||||||
|
if len(config.Models) > 0 {
|
||||||
|
defaultModel = config.Models[0]
|
||||||
|
}
|
||||||
|
if defaultModel != "model-x" {
|
||||||
|
t.Errorf("expected model-x, got %s", defaultModel)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple integrations in single file", func(t *testing.T) {
|
||||||
|
saveIntegration("app1", []string{"model-1"})
|
||||||
|
saveIntegration("app2", []string{"model-2"})
|
||||||
|
|
||||||
|
config1, _ := loadIntegration("app1")
|
||||||
|
config2, _ := loadIntegration("app2")
|
||||||
|
|
||||||
|
defaultModel1 := ""
|
||||||
|
if len(config1.Models) > 0 {
|
||||||
|
defaultModel1 = config1.Models[0]
|
||||||
|
}
|
||||||
|
defaultModel2 := ""
|
||||||
|
if len(config2.Models) > 0 {
|
||||||
|
defaultModel2 = config2.Models[0]
|
||||||
|
}
|
||||||
|
if defaultModel1 != "model-1" {
|
||||||
|
t.Errorf("expected model-1, got %s", defaultModel1)
|
||||||
|
}
|
||||||
|
if defaultModel2 != "model-2" {
|
||||||
|
t.Errorf("expected model-2, got %s", defaultModel2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListIntegrations(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
t.Run("returns empty when no integrations", func(t *testing.T) {
|
||||||
|
configs, err := listIntegrations()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(configs) != 0 {
|
||||||
|
t.Errorf("expected 0 integrations, got %d", len(configs))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns all saved integrations", func(t *testing.T) {
|
||||||
|
saveIntegration("claude", []string{"model-1"})
|
||||||
|
saveIntegration("droid", []string{"model-2"})
|
||||||
|
|
||||||
|
configs, err := listIntegrations()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(configs) != 2 {
|
||||||
|
t.Errorf("expected 2 integrations, got %d", len(configs))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEditorPaths(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
||||||
|
r := integrations["claude"]
|
||||||
|
paths := editorPaths(r)
|
||||||
|
if len(paths) != 0 {
|
||||||
|
t.Errorf("expected no paths for claude, got %v", paths)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
||||||
|
r := integrations["codex"]
|
||||||
|
paths := editorPaths(r)
|
||||||
|
if len(paths) != 0 {
|
||||||
|
t.Errorf("expected no paths for codex, got %v", paths)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
||||||
|
r := integrations["droid"]
|
||||||
|
paths := editorPaths(r)
|
||||||
|
if len(paths) != 0 {
|
||||||
|
t.Errorf("expected no paths, got %v", paths)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
||||||
|
settingsDir, _ := os.UserHomeDir()
|
||||||
|
settingsDir = filepath.Join(settingsDir, ".factory")
|
||||||
|
os.MkdirAll(settingsDir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
||||||
|
|
||||||
|
r := integrations["droid"]
|
||||||
|
paths := editorPaths(r)
|
||||||
|
if len(paths) != 1 {
|
||||||
|
t.Errorf("expected 1 path, got %d", len(paths))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
configDir := filepath.Join(home, ".config", "opencode")
|
||||||
|
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||||
|
|
||||||
|
r := integrations["opencode"]
|
||||||
|
paths := editorPaths(r)
|
||||||
|
if len(paths) != 2 {
|
||||||
|
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Create corrupted config.json file
|
||||||
|
dir := filepath.Join(tmpDir, ".ollama", "config")
|
||||||
|
os.MkdirAll(dir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||||
|
|
||||||
|
// Corrupted file is treated as empty, so loadIntegration returns not found
|
||||||
|
_, err := loadIntegration("test")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for nonexistent integration in corrupted file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveIntegration_NilModels(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := saveIntegration("test", nil); err != nil {
|
||||||
|
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := loadIntegration("test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("loadIntegration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Models == nil {
|
||||||
|
// nil is acceptable
|
||||||
|
} else if len(config.Models) != 0 {
|
||||||
|
t.Errorf("expected empty or nil models, got %v", config.Models)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveIntegration_EmptyAppName(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
err := saveIntegration("", []string{"model"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty app name, got nil")
|
||||||
|
}
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
|
||||||
|
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
_, err := loadIntegration("nonexistent")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for nonexistent integration, got nil")
|
||||||
|
}
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
t.Logf("error type is os.ErrNotExist as expected: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigPath(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
path, err := configPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
||||||
|
if path != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoad(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
t.Run("returns empty config when file does not exist", func(t *testing.T) {
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
t.Fatal("expected non-nil config")
|
||||||
|
}
|
||||||
|
if cfg.Integrations == nil {
|
||||||
|
t.Error("expected non-nil Integrations map")
|
||||||
|
}
|
||||||
|
if len(cfg.Integrations) != 0 {
|
||||||
|
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("loads existing config", func(t *testing.T) {
|
||||||
|
path, _ := configPath()
|
||||||
|
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||||
|
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
|
||||||
|
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if cfg.Integrations["test"] == nil {
|
||||||
|
t.Fatal("expected test integration")
|
||||||
|
}
|
||||||
|
if len(cfg.Integrations["test"].Models) != 1 {
|
||||||
|
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error for corrupted JSON", func(t *testing.T) {
|
||||||
|
path, _ := configPath()
|
||||||
|
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||||
|
os.WriteFile(path, []byte(`{corrupted`), 0o644)
|
||||||
|
|
||||||
|
_, err := load()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for corrupted JSON")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSave(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
t.Run("creates config file", func(t *testing.T) {
|
||||||
|
cfg := &config{
|
||||||
|
Integrations: map[string]*integration{
|
||||||
|
"test": {Models: []string{"model-a", "model-b"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := save(cfg); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
path, _ := configPath()
|
||||||
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||||
|
t.Error("config file was not created")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("round-trip preserves data", func(t *testing.T) {
|
||||||
|
cfg := &config{
|
||||||
|
Integrations: map[string]*integration{
|
||||||
|
"claude": {Models: []string{"llama3.2", "mistral"}},
|
||||||
|
"codex": {Models: []string{"qwen2.5"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := save(cfg); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(loaded.Integrations) != 2 {
|
||||||
|
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
|
||||||
|
}
|
||||||
|
if loaded.Integrations["claude"] == nil {
|
||||||
|
t.Error("missing claude integration")
|
||||||
|
}
|
||||||
|
if len(loaded.Integrations["claude"].Models) != 2 {
|
||||||
|
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
184
cmd/config/droid.go
Normal file
184
cmd/config/droid.go
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Droid implements Runner and Editor for Droid integration
|
||||||
|
type Droid struct{}
|
||||||
|
|
||||||
|
// droidSettings represents the Droid settings.json file (only fields we use)
|
||||||
|
type droidSettings struct {
|
||||||
|
CustomModels []modelEntry `json:"customModels"`
|
||||||
|
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type sessionSettings struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
ReasoningEffort string `json:"reasoningEffort"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type modelEntry struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
BaseURL string `json:"baseUrl"`
|
||||||
|
APIKey string `json:"apiKey"`
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
MaxOutputTokens int `json:"maxOutputTokens"`
|
||||||
|
SupportsImages bool `json:"supportsImages"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Droid) String() string { return "Droid" }
|
||||||
|
|
||||||
|
func (d *Droid) Run(model string) error {
|
||||||
|
if _, err := exec.LookPath("droid"); err != nil {
|
||||||
|
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call Edit() to ensure config is up-to-date before launch
|
||||||
|
models := []string{model}
|
||||||
|
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
||||||
|
models = config.Models
|
||||||
|
}
|
||||||
|
if err := d.Edit(models); err != nil {
|
||||||
|
return fmt.Errorf("setup failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("droid")
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Droid) Paths() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p := filepath.Join(home, ".factory", "settings.json")
|
||||||
|
if _, err := os.Stat(p); err == nil {
|
||||||
|
return []string{p}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Droid) Edit(models []string) error {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
settingsPath := filepath.Join(home, ".factory", "settings.json")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read file once, unmarshal twice:
|
||||||
|
// map preserves unknown fields for writing back (including extra fields in model entries)
|
||||||
|
settingsMap := make(map[string]any)
|
||||||
|
var settings droidSettings
|
||||||
|
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||||
|
if err := json.Unmarshal(data, &settingsMap); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
|
||||||
|
}
|
||||||
|
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||||
|
// Rebuild Ollama models
|
||||||
|
var nonOllamaModels []any
|
||||||
|
if rawModels, ok := settingsMap["customModels"].([]any); ok {
|
||||||
|
for _, raw := range rawModels {
|
||||||
|
if m, ok := raw.(map[string]any); ok {
|
||||||
|
if m["apiKey"] != "ollama" {
|
||||||
|
nonOllamaModels = append(nonOllamaModels, raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||||
|
var newModels []any
|
||||||
|
var defaultModelID string
|
||||||
|
for i, model := range models {
|
||||||
|
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
||||||
|
newModels = append(newModels, modelEntry{
|
||||||
|
Model: model,
|
||||||
|
DisplayName: model,
|
||||||
|
BaseURL: "http://localhost:11434/v1",
|
||||||
|
APIKey: "ollama",
|
||||||
|
Provider: "generic-chat-completion-api",
|
||||||
|
MaxOutputTokens: 64000,
|
||||||
|
SupportsImages: false,
|
||||||
|
ID: modelID,
|
||||||
|
Index: i,
|
||||||
|
})
|
||||||
|
if i == 0 {
|
||||||
|
defaultModelID = modelID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
|
||||||
|
|
||||||
|
// Update session default settings (preserve unknown fields in the nested object)
|
||||||
|
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
sessionSettings = make(map[string]any)
|
||||||
|
}
|
||||||
|
sessionSettings["model"] = defaultModelID
|
||||||
|
|
||||||
|
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
|
||||||
|
sessionSettings["reasoningEffort"] = "none"
|
||||||
|
}
|
||||||
|
|
||||||
|
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeWithBackup(settingsPath, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Droid) Models() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var settings droidSettings
|
||||||
|
if err := json.Unmarshal(data, &settings); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
for _, m := range settings.CustomModels {
|
||||||
|
if m.APIKey == "ollama" {
|
||||||
|
result = append(result, m.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
|
||||||
|
|
||||||
|
func isValidReasoningEffort(effort string) bool {
|
||||||
|
return slices.Contains(validReasoningEfforts, effort)
|
||||||
|
}
|
||||||
1302
cmd/config/droid_test.go
Normal file
1302
cmd/config/droid_test.go
Normal file
File diff suppressed because it is too large
Load Diff
99
cmd/config/files.go
Normal file
99
cmd/config/files.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func readJSONFile(path string) (map[string]any, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var result map[string]any
|
||||||
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyFile(src, dst string) error {
|
||||||
|
info, err := os.Stat(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||||
|
}
|
||||||
|
|
||||||
|
func backupDir() string {
|
||||||
|
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||||
|
}
|
||||||
|
|
||||||
|
func backupToTmp(srcPath string) (string, error) {
|
||||||
|
dir := backupDir()
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
|
||||||
|
if err := copyFile(srcPath, backupPath); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return backupPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
||||||
|
func writeWithBackup(path string, data []byte) error {
|
||||||
|
var backupPath string
|
||||||
|
// backup must be created before any writes to the target file
|
||||||
|
if existingContent, err := os.ReadFile(path); err == nil {
|
||||||
|
if !bytes.Equal(existingContent, data) {
|
||||||
|
backupPath, err = backupToTmp(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("backup failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("read existing file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := filepath.Dir(path)
|
||||||
|
tmp, err := os.CreateTemp(dir, ".tmp-*")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create temp failed: %w", err)
|
||||||
|
}
|
||||||
|
tmpPath := tmp.Name()
|
||||||
|
|
||||||
|
if _, err := tmp.Write(data); err != nil {
|
||||||
|
_ = tmp.Close()
|
||||||
|
_ = os.Remove(tmpPath)
|
||||||
|
return fmt.Errorf("write failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := tmp.Sync(); err != nil {
|
||||||
|
_ = tmp.Close()
|
||||||
|
_ = os.Remove(tmpPath)
|
||||||
|
return fmt.Errorf("sync failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := tmp.Close(); err != nil {
|
||||||
|
_ = os.Remove(tmpPath)
|
||||||
|
return fmt.Errorf("close failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Rename(tmpPath, path); err != nil {
|
||||||
|
_ = os.Remove(tmpPath)
|
||||||
|
if backupPath != "" {
|
||||||
|
_ = copyFile(backupPath, path)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("rename failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
502
cmd/config/files_test.go
Normal file
502
cmd/config/files_test.go
Normal file
@@ -0,0 +1,502 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustMarshal(t *testing.T, v any) []byte {
|
||||||
|
t.Helper()
|
||||||
|
data, err := json.MarshalIndent(v, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteWithBackup(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
t.Run("creates file", func(t *testing.T) {
|
||||||
|
path := filepath.Join(tmpDir, "new.json")
|
||||||
|
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||||
|
|
||||||
|
if err := writeWithBackup(path, data); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
if err := json.Unmarshal(content, &result); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result["key"] != "value" {
|
||||||
|
t.Errorf("expected value, got %s", result["key"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
||||||
|
path := filepath.Join(tmpDir, "backup.json")
|
||||||
|
|
||||||
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
|
||||||
|
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||||
|
if err := writeWithBackup(path, data); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(backupDir())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("backup directory not created")
|
||||||
|
}
|
||||||
|
|
||||||
|
var foundBackup bool
|
||||||
|
for _, entry := range entries {
|
||||||
|
if filepath.Ext(entry.Name()) != ".json" {
|
||||||
|
name := entry.Name()
|
||||||
|
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||||
|
backupPath := filepath.Join(backupDir(), name)
|
||||||
|
backup, err := os.ReadFile(backupPath)
|
||||||
|
if err == nil {
|
||||||
|
var backupData map[string]bool
|
||||||
|
json.Unmarshal(backup, &backupData)
|
||||||
|
if backupData["original"] {
|
||||||
|
foundBackup = true
|
||||||
|
os.Remove(backupPath)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundBackup {
|
||||||
|
t.Error("backup file not created in /tmp/ollama-backups")
|
||||||
|
}
|
||||||
|
|
||||||
|
current, _ := os.ReadFile(path)
|
||||||
|
var currentData map[string]bool
|
||||||
|
json.Unmarshal(current, ¤tData)
|
||||||
|
if !currentData["updated"] {
|
||||||
|
t.Error("file doesn't contain updated data")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no backup for new file", func(t *testing.T) {
|
||||||
|
path := filepath.Join(tmpDir, "nobak.json")
|
||||||
|
|
||||||
|
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||||
|
if err := writeWithBackup(path, data); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, _ := os.ReadDir(backupDir())
|
||||||
|
for _, entry := range entries {
|
||||||
|
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||||
|
t.Error("backup should not exist for new file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no backup when content unchanged", func(t *testing.T) {
|
||||||
|
path := filepath.Join(tmpDir, "unchanged.json")
|
||||||
|
|
||||||
|
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||||
|
|
||||||
|
if err := writeWithBackup(path, data); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries1, _ := os.ReadDir(backupDir())
|
||||||
|
countBefore := 0
|
||||||
|
for _, e := range entries1 {
|
||||||
|
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||||
|
countBefore++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeWithBackup(path, data); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries2, _ := os.ReadDir(backupDir())
|
||||||
|
countAfter := 0
|
||||||
|
for _, e := range entries2 {
|
||||||
|
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||||
|
countAfter++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if countAfter != countBefore {
|
||||||
|
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
|
||||||
|
path := filepath.Join(tmpDir, "timestamped.json")
|
||||||
|
|
||||||
|
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||||
|
data := mustMarshal(t, map[string]int{"v": 2})
|
||||||
|
if err := writeWithBackup(path, data); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, _ := os.ReadDir(backupDir())
|
||||||
|
var found bool
|
||||||
|
for _, entry := range entries {
|
||||||
|
name := entry.Name()
|
||||||
|
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
|
||||||
|
timestamp := name[len("timestamped.json."):]
|
||||||
|
for _, c := range timestamp {
|
||||||
|
if c < '0' || c > '9' {
|
||||||
|
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
found = true
|
||||||
|
os.Remove(filepath.Join(backupDir(), name))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("backup file with timestamp not found")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edge case tests for files.go
|
||||||
|
|
||||||
|
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
|
||||||
|
// User could lose their config with no way to recover.
|
||||||
|
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("permission tests unreliable on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
path := filepath.Join(tmpDir, "config.json")
|
||||||
|
|
||||||
|
// Create original file
|
||||||
|
originalContent := []byte(`{"original": true}`)
|
||||||
|
os.WriteFile(path, originalContent, 0o644)
|
||||||
|
|
||||||
|
// Make backup directory read-only to force backup failure
|
||||||
|
backupDir := backupDir()
|
||||||
|
os.MkdirAll(backupDir, 0o755)
|
||||||
|
os.Chmod(backupDir, 0o444) // Read-only
|
||||||
|
defer os.Chmod(backupDir, 0o755)
|
||||||
|
|
||||||
|
newContent := []byte(`{"updated": true}`)
|
||||||
|
err := writeWithBackup(path, newContent)
|
||||||
|
|
||||||
|
// Should fail because backup couldn't be created
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when backup fails, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original file should be preserved
|
||||||
|
current, _ := os.ReadFile(path)
|
||||||
|
if string(current) != string(originalContent) {
|
||||||
|
t.Errorf("original file was modified despite backup failure: got %s", string(current))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
|
||||||
|
// Common issue when config owned by root or wrong perms.
|
||||||
|
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("permission tests unreliable on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create a read-only directory
|
||||||
|
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||||
|
os.MkdirAll(readOnlyDir, 0o755)
|
||||||
|
os.Chmod(readOnlyDir, 0o444)
|
||||||
|
defer os.Chmod(readOnlyDir, 0o755)
|
||||||
|
|
||||||
|
path := filepath.Join(readOnlyDir, "config.json")
|
||||||
|
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected permission error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||||
|
// writeWithBackup doesn't create directories - caller is responsible.
|
||||||
|
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||||
|
|
||||||
|
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||||
|
|
||||||
|
// Should fail because directory doesn't exist
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for nonexistent directory, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
|
||||||
|
// Documents what happens if user symlinks their config file.
|
||||||
|
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("symlink tests may require admin on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
realFile := filepath.Join(tmpDir, "real.json")
|
||||||
|
symlink := filepath.Join(tmpDir, "link.json")
|
||||||
|
|
||||||
|
// Create real file and symlink
|
||||||
|
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
|
||||||
|
os.Symlink(realFile, symlink)
|
||||||
|
|
||||||
|
// Write through symlink
|
||||||
|
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The real file should be updated (symlink followed for temp file creation)
|
||||||
|
content, _ := os.ReadFile(symlink)
|
||||||
|
if string(content) != `{"v": 2}` {
|
||||||
|
t.Errorf("symlink target not updated correctly: got %s", string(content))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||||
|
// User may have config files with unusual names.
|
||||||
|
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// File with spaces and special chars
|
||||||
|
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||||
|
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
|
||||||
|
|
||||||
|
backupPath, err := backupToTmp(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("backupToTmp with special chars failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify backup exists and has correct content
|
||||||
|
content, err := os.ReadFile(backupPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("could not read backup: %v", err)
|
||||||
|
}
|
||||||
|
if string(content) != `{"test": true}` {
|
||||||
|
t.Errorf("backup content mismatch: got %s", string(content))
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Remove(backupPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
|
||||||
|
func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("permission preservation tests unreliable on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
src := filepath.Join(tmpDir, "src.json")
|
||||||
|
dst := filepath.Join(tmpDir, "dst.json")
|
||||||
|
|
||||||
|
// Create source with specific permissions
|
||||||
|
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
|
||||||
|
|
||||||
|
err := copyFile(src, dst)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("copyFile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srcInfo, _ := os.Stat(src)
|
||||||
|
dstInfo, _ := os.Stat(dst)
|
||||||
|
|
||||||
|
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
|
||||||
|
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||||
|
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||||
|
dst := filepath.Join(tmpDir, "dst.json")
|
||||||
|
|
||||||
|
err := copyFile(src, dst)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for nonexistent source, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||||
|
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||||
|
os.MkdirAll(dirPath, 0o755)
|
||||||
|
|
||||||
|
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when target is a directory, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||||
|
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
path := filepath.Join(tmpDir, "empty.json")
|
||||||
|
|
||||||
|
err := writeWithBackup(path, []byte{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("could not read file: %v", err)
|
||||||
|
}
|
||||||
|
if len(content) != 0 {
|
||||||
|
t.Errorf("expected empty file, got %d bytes", len(content))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
|
||||||
|
// cannot be read (for backup comparison) but directory is writable.
|
||||||
|
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("permission tests unreliable on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
path := filepath.Join(tmpDir, "unreadable.json")
|
||||||
|
|
||||||
|
// Create file and make it unreadable
|
||||||
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
os.Chmod(path, 0o000)
|
||||||
|
defer os.Chmod(path, 0o644)
|
||||||
|
|
||||||
|
// Should fail because we can't read the file to compare/backup
|
||||||
|
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when file is unreadable, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||||
|
// within the same second (timestamp collision scenario).
|
||||||
|
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
path := filepath.Join(tmpDir, "rapid.json")
|
||||||
|
|
||||||
|
// Create initial file
|
||||||
|
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
|
||||||
|
|
||||||
|
// Rapid successive writes
|
||||||
|
for i := 1; i <= 3; i++ {
|
||||||
|
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||||
|
if err := writeWithBackup(path, data); err != nil {
|
||||||
|
t.Fatalf("write %d failed: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify final content
|
||||||
|
content, _ := os.ReadFile(path)
|
||||||
|
if string(content) != `{"v": 3}` {
|
||||||
|
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify at least one backup exists
|
||||||
|
entries, _ := os.ReadDir(backupDir())
|
||||||
|
var backupCount int
|
||||||
|
for _, e := range entries {
|
||||||
|
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||||
|
backupCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if backupCount == 0 {
|
||||||
|
t.Error("expected at least one backup file from rapid writes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
|
||||||
|
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("test modifies system temp directory")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a file at the backup directory path
|
||||||
|
backupPath := backupDir()
|
||||||
|
// Clean up any existing directory first
|
||||||
|
os.RemoveAll(backupPath)
|
||||||
|
// Create a file instead of directory
|
||||||
|
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
|
||||||
|
defer func() {
|
||||||
|
os.Remove(backupPath)
|
||||||
|
os.MkdirAll(backupPath, 0o755)
|
||||||
|
}()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
path := filepath.Join(tmpDir, "test.json")
|
||||||
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
|
||||||
|
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when backup dir is a file, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
|
||||||
|
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("permission tests unreliable on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Count existing temp files
|
||||||
|
countTempFiles := func() int {
|
||||||
|
entries, _ := os.ReadDir(tmpDir)
|
||||||
|
count := 0
|
||||||
|
for _, e := range entries {
|
||||||
|
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
before := countTempFiles()
|
||||||
|
|
||||||
|
// Create a file, then make directory read-only to cause rename failure
|
||||||
|
path := filepath.Join(tmpDir, "orphan.json")
|
||||||
|
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||||
|
|
||||||
|
// Make a subdirectory and try to write there after making parent read-only
|
||||||
|
subDir := filepath.Join(tmpDir, "subdir")
|
||||||
|
os.MkdirAll(subDir, 0o755)
|
||||||
|
subPath := filepath.Join(subDir, "config.json")
|
||||||
|
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
|
||||||
|
|
||||||
|
// Make subdir read-only after creating temp file would succeed but rename would fail
|
||||||
|
// This is tricky to test - the temp file is created in the same dir, so if we can't
|
||||||
|
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
|
||||||
|
|
||||||
|
// Force a failure by making the target a directory
|
||||||
|
badPath := filepath.Join(tmpDir, "isdir")
|
||||||
|
os.MkdirAll(badPath, 0o755)
|
||||||
|
|
||||||
|
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
||||||
|
|
||||||
|
after := countTempFiles()
|
||||||
|
if after > before {
|
||||||
|
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
|
||||||
|
}
|
||||||
|
}
|
||||||
362
cmd/config/integrations.go
Normal file
362
cmd/config/integrations.go
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Runners execute the launching of a model with the integration - claude, codex
|
||||||
|
// Editors can edit config files (supports multi-model selection) - opencode, droid
|
||||||
|
// They are composable interfaces where in some cases an editor is also a runner - opencode, droid
|
||||||
|
// Runner can run an integration with a model.
|
||||||
|
|
||||||
|
type Runner interface {
|
||||||
|
Run(model string) error
|
||||||
|
// String returns the human-readable name of the integration
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Editor can edit config files (supports multi-model selection)
|
||||||
|
type Editor interface {
|
||||||
|
// Paths returns the paths to the config files for the integration
|
||||||
|
Paths() []string
|
||||||
|
// Edit updates the config files for the integration with the given models
|
||||||
|
Edit(models []string) error
|
||||||
|
// Models returns the models currently configured for the integration
|
||||||
|
// TODO(parthsareen): add error return to Models()
|
||||||
|
Models() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// integrations is the registry of available integrations.
|
||||||
|
var integrations = map[string]Runner{
|
||||||
|
"claude": &Claude{},
|
||||||
|
"codex": &Codex{},
|
||||||
|
"droid": &Droid{},
|
||||||
|
"opencode": &OpenCode{},
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectIntegration() (string, error) {
|
||||||
|
if len(integrations) == 0 {
|
||||||
|
return "", fmt.Errorf("no integrations available")
|
||||||
|
}
|
||||||
|
|
||||||
|
names := slices.Sorted(maps.Keys(integrations))
|
||||||
|
var items []selectItem
|
||||||
|
for _, name := range names {
|
||||||
|
r := integrations[name]
|
||||||
|
description := r.String()
|
||||||
|
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||||
|
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
||||||
|
}
|
||||||
|
items = append(items, selectItem{Name: name, Description: description})
|
||||||
|
}
|
||||||
|
|
||||||
|
return selectPrompt("Select integration:", items)
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectModels lets the user select models for an integration
|
||||||
|
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||||
|
r, ok := integrations[name]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
models, err := client.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(models.Models) == 0 {
|
||||||
|
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||||
|
}
|
||||||
|
|
||||||
|
var items []selectItem
|
||||||
|
cloudModels := make(map[string]bool)
|
||||||
|
for _, m := range models.Models {
|
||||||
|
if m.RemoteModel != "" {
|
||||||
|
cloudModels[m.Name] = true
|
||||||
|
}
|
||||||
|
items = append(items, selectItem{Name: m.Name})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(items) == 0 {
|
||||||
|
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get previously configured models (saved config takes precedence)
|
||||||
|
var preChecked []string
|
||||||
|
if saved, err := loadIntegration(name); err == nil {
|
||||||
|
preChecked = saved.Models
|
||||||
|
} else if editor, ok := r.(Editor); ok {
|
||||||
|
preChecked = editor.Models()
|
||||||
|
}
|
||||||
|
checked := make(map[string]bool, len(preChecked))
|
||||||
|
for _, n := range preChecked {
|
||||||
|
checked[n] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||||
|
for _, item := range items {
|
||||||
|
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||||
|
current = item.Name
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If current model is configured, move to front of preChecked
|
||||||
|
if checked[current] {
|
||||||
|
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort: checked first, then alphabetical
|
||||||
|
slices.SortFunc(items, func(a, b selectItem) int {
|
||||||
|
ac, bc := checked[a.Name], checked[b.Name]
|
||||||
|
if ac != bc {
|
||||||
|
if ac {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||||
|
})
|
||||||
|
|
||||||
|
var selected []string
|
||||||
|
// only editors support multi-model selection
|
||||||
|
if _, ok := r.(Editor); ok {
|
||||||
|
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
selected = []string{model}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if any model in selected is a cloud model, ensure signed in
|
||||||
|
var selectedCloudModels []string
|
||||||
|
for _, m := range selected {
|
||||||
|
if cloudModels[m] {
|
||||||
|
selectedCloudModels = append(selectedCloudModels, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(selectedCloudModels) > 0 {
|
||||||
|
// ensure user is signed in
|
||||||
|
user, err := client.Whoami(ctx)
|
||||||
|
if err == nil && user != nil && user.Name != "" {
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var aErr api.AuthorizationError
|
||||||
|
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelList := strings.Join(selectedCloudModels, ", ")
|
||||||
|
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||||
|
if err != nil || !yes {
|
||||||
|
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||||
|
|
||||||
|
// TODO(parthsareen): extract into auth package for cmd
|
||||||
|
// Auto-open browser (best effort, fail silently)
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||||
|
case "linux":
|
||||||
|
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||||
|
case "windows":
|
||||||
|
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||||
|
frame := 0
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||||
|
|
||||||
|
ticker := time.NewTicker(200 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
frame++
|
||||||
|
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||||
|
|
||||||
|
// poll every 10th frame (~2 seconds)
|
||||||
|
if frame%10 == 0 {
|
||||||
|
u, err := client.Whoami(ctx)
|
||||||
|
if err == nil && u != nil && u.Name != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runIntegration(name, modelName string) error {
|
||||||
|
r, ok := integrations[name]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unknown integration: %s", name)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||||
|
return r.Run(modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigCmd returns the cobra command for configuring integrations.
|
||||||
|
func ConfigCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
||||||
|
var modelFlag string
|
||||||
|
var launchFlag bool
|
||||||
|
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "config [INTEGRATION]",
|
||||||
|
Short: "Configure an external integration to use Ollama",
|
||||||
|
Long: `Configure an external application to use Ollama models.
|
||||||
|
|
||||||
|
Supported integrations:
|
||||||
|
claude Claude Code
|
||||||
|
codex Codex
|
||||||
|
droid Droid
|
||||||
|
opencode OpenCode
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
ollama config
|
||||||
|
ollama config claude
|
||||||
|
ollama config droid --launch`,
|
||||||
|
Args: cobra.MaximumNArgs(1),
|
||||||
|
PreRunE: checkServerHeartbeat,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
var name string
|
||||||
|
if len(args) > 0 {
|
||||||
|
name = args[0]
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
name, err = selectIntegration()
|
||||||
|
if errors.Is(err, errCancelled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r, ok := integrations[strings.ToLower(name)]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unknown integration: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If --launch without --model, use saved config if available
|
||||||
|
if launchFlag && modelFlag == "" {
|
||||||
|
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||||
|
return runIntegration(name, config.Models[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var models []string
|
||||||
|
if modelFlag != "" {
|
||||||
|
// When --model is specified, merge with existing models (new model becomes default)
|
||||||
|
models = []string{modelFlag}
|
||||||
|
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||||
|
for _, m := range existing.Models {
|
||||||
|
if m != modelFlag {
|
||||||
|
models = append(models, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
models, err = selectModels(cmd.Context(), name, "")
|
||||||
|
if errors.Is(err, errCancelled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if editor, isEditor := r.(Editor); isEditor {
|
||||||
|
paths := editor.Paths()
|
||||||
|
if len(paths) > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
|
||||||
|
for _, p := range paths {
|
||||||
|
fmt.Fprintf(os.Stderr, " %s\n", p)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
|
||||||
|
|
||||||
|
if ok, _ := confirmPrompt("Proceed?"); !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := saveIntegration(name, models); err != nil {
|
||||||
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if editor, isEditor := r.(Editor); isEditor {
|
||||||
|
if err := editor.Edit(models); err != nil {
|
||||||
|
return fmt.Errorf("setup failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, isEditor := r.(Editor); isEditor {
|
||||||
|
if len(models) == 1 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.ContainsFunc(models, func(m string) bool {
|
||||||
|
return !strings.HasSuffix(m, "cloud")
|
||||||
|
}) {
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
|
fmt.Fprintln(os.Stderr, "Coding agents work best with at least 64k context. Either:")
|
||||||
|
fmt.Fprintln(os.Stderr, " - Set the context slider in Ollama app settings")
|
||||||
|
fmt.Fprintln(os.Stderr, " - Run: OLLAMA_CONTEXT_LENGTH=64000 ollama serve")
|
||||||
|
}
|
||||||
|
|
||||||
|
if launchFlag {
|
||||||
|
return runIntegration(name, models[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||||
|
return runIntegration(name, models[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "Run 'ollama config %s --launch' to start with %s\n", strings.ToLower(name), models[0])
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||||
|
cmd.Flags().BoolVar(&launchFlag, "launch", false, "Launch the integration after configuring")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
188
cmd/config/integrations_test.go
Normal file
188
cmd/config/integrations_test.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIntegrationLookup(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantFound bool
|
||||||
|
wantName string
|
||||||
|
}{
|
||||||
|
{"claude lowercase", "claude", true, "Claude Code"},
|
||||||
|
{"claude uppercase", "CLAUDE", true, "Claude Code"},
|
||||||
|
{"claude mixed case", "Claude", true, "Claude Code"},
|
||||||
|
{"codex", "codex", true, "Codex"},
|
||||||
|
{"droid", "droid", true, "Droid"},
|
||||||
|
{"opencode", "opencode", true, "OpenCode"},
|
||||||
|
{"unknown integration", "unknown", false, ""},
|
||||||
|
{"empty string", "", false, ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, found := integrations[strings.ToLower(tt.input)]
|
||||||
|
if found != tt.wantFound {
|
||||||
|
t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
|
||||||
|
}
|
||||||
|
if found && r.String() != tt.wantName {
|
||||||
|
t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationRegistry(t *testing.T) {
|
||||||
|
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
|
||||||
|
|
||||||
|
for _, name := range expectedIntegrations {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
r, ok := integrations[name]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("integration %q not found in registry", name)
|
||||||
|
}
|
||||||
|
if r.String() == "" {
|
||||||
|
t.Error("integration.String() should not be empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasLocalModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
models []string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"empty list", []string{}, false},
|
||||||
|
{"single local model", []string{"llama3.2"}, true},
|
||||||
|
{"single cloud model", []string{"cloud-model"}, false},
|
||||||
|
{"mixed models", []string{"cloud-model", "llama3.2"}, true},
|
||||||
|
{"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
|
||||||
|
{"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
|
||||||
|
{"local model first", []string{"llama3.2", "cloud-model"}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
||||||
|
return !strings.Contains(m, "cloud")
|
||||||
|
})
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigCmd(t *testing.T) {
|
||||||
|
// Mock checkServerHeartbeat that always succeeds
|
||||||
|
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := ConfigCmd(mockCheck)
|
||||||
|
|
||||||
|
t.Run("command structure", func(t *testing.T) {
|
||||||
|
if cmd.Use != "config [INTEGRATION]" {
|
||||||
|
t.Errorf("Use = %q, want %q", cmd.Use, "config [INTEGRATION]")
|
||||||
|
}
|
||||||
|
if cmd.Short == "" {
|
||||||
|
t.Error("Short description should not be empty")
|
||||||
|
}
|
||||||
|
if cmd.Long == "" {
|
||||||
|
t.Error("Long description should not be empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("flags exist", func(t *testing.T) {
|
||||||
|
modelFlag := cmd.Flags().Lookup("model")
|
||||||
|
if modelFlag == nil {
|
||||||
|
t.Error("--model flag should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
launchFlag := cmd.Flags().Lookup("launch")
|
||||||
|
if launchFlag == nil {
|
||||||
|
t.Error("--launch flag should exist")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PreRunE is set", func(t *testing.T) {
|
||||||
|
if cmd.PreRunE == nil {
|
||||||
|
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||||
|
err := runIntegration("unknown-integration", "model")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for unknown integration, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "unknown integration") {
|
||||||
|
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
models []string
|
||||||
|
want bool
|
||||||
|
reason string
|
||||||
|
}{
|
||||||
|
{"empty list", []string{}, false, "empty list has no local models"},
|
||||||
|
{"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
|
||||||
|
{"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
|
||||||
|
{"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
|
||||||
|
{"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
|
||||||
|
{"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
|
||||||
|
{"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
||||||
|
return !strings.Contains(m, "cloud")
|
||||||
|
})
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigCmd_NilHeartbeat(t *testing.T) {
|
||||||
|
// This should not panic - cmd creation should work even with nil
|
||||||
|
cmd := ConfigCmd(nil)
|
||||||
|
if cmd == nil {
|
||||||
|
t.Fatal("ConfigCmd returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreRunE should be nil when passed nil
|
||||||
|
if cmd.PreRunE != nil {
|
||||||
|
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||||
|
for name, r := range integrations {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
// Test String() doesn't panic and returns non-empty
|
||||||
|
displayName := r.String()
|
||||||
|
if displayName == "" {
|
||||||
|
t.Error("String() should not return empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Run() exists (we can't call it without actually running the command)
|
||||||
|
// Just verify the method is available
|
||||||
|
var _ func(string) error = r.Run
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
203
cmd/config/opencode.go
Normal file
203
cmd/config/opencode.go
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenCode implements Runner and Editor for OpenCode integration
|
||||||
|
type OpenCode struct{}
|
||||||
|
|
||||||
|
func (o *OpenCode) String() string { return "OpenCode" }
|
||||||
|
|
||||||
|
func (o *OpenCode) Run(model string) error {
|
||||||
|
if _, err := exec.LookPath("opencode"); err != nil {
|
||||||
|
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call Edit() to ensure config is up-to-date before launch
|
||||||
|
models := []string{model}
|
||||||
|
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
||||||
|
models = config.Models
|
||||||
|
}
|
||||||
|
if err := o.Edit(models); err != nil {
|
||||||
|
return fmt.Errorf("setup failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("opencode")
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenCode) Paths() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var paths []string
|
||||||
|
p := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||||
|
if _, err := os.Stat(p); err == nil {
|
||||||
|
paths = append(paths, p)
|
||||||
|
}
|
||||||
|
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||||
|
if _, err := os.Stat(sp); err == nil {
|
||||||
|
paths = append(paths, sp)
|
||||||
|
}
|
||||||
|
return paths
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenCode) Edit(modelList []string) error {
|
||||||
|
if len(modelList) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := make(map[string]any)
|
||||||
|
if data, err := os.ReadFile(configPath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
|
||||||
|
}
|
||||||
|
|
||||||
|
config["$schema"] = "https://opencode.ai/config.json"
|
||||||
|
|
||||||
|
provider, ok := config["provider"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
provider = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
ollama, ok := provider["ollama"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
ollama = map[string]any{
|
||||||
|
"npm": "@ai-sdk/openai-compatible",
|
||||||
|
"name": "Ollama (local)",
|
||||||
|
"options": map[string]any{
|
||||||
|
"baseURL": "http://localhost:11434/v1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
models, ok := ollama["models"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
models = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
selectedSet := make(map[string]bool)
|
||||||
|
for _, m := range modelList {
|
||||||
|
selectedSet[m] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, cfg := range models {
|
||||||
|
if cfgMap, ok := cfg.(map[string]any); ok {
|
||||||
|
if displayName, ok := cfgMap["name"].(string); ok {
|
||||||
|
if strings.HasSuffix(displayName, "[Ollama]") && !selectedSet[name] {
|
||||||
|
delete(models, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range modelList {
|
||||||
|
models[model] = map[string]any{
|
||||||
|
"name": fmt.Sprintf("%s [Ollama]", model),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ollama["models"] = models
|
||||||
|
provider["ollama"] = ollama
|
||||||
|
config["provider"] = provider
|
||||||
|
|
||||||
|
configData, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := writeWithBackup(configPath, configData); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
state := map[string]any{
|
||||||
|
"recent": []any{},
|
||||||
|
"favorite": []any{},
|
||||||
|
"variant": map[string]any{},
|
||||||
|
}
|
||||||
|
if data, err := os.ReadFile(statePath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
|
||||||
|
}
|
||||||
|
|
||||||
|
recent, _ := state["recent"].([]any)
|
||||||
|
|
||||||
|
modelSet := make(map[string]bool)
|
||||||
|
for _, m := range modelList {
|
||||||
|
modelSet[m] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out existing Ollama models we're about to re-add
|
||||||
|
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
|
||||||
|
e, ok := entry.(map[string]any)
|
||||||
|
if !ok || e["providerID"] != "ollama" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
modelID, _ := e["modelID"].(string)
|
||||||
|
return modelSet[modelID]
|
||||||
|
})
|
||||||
|
|
||||||
|
// Prepend models in reverse order so first model ends up first
|
||||||
|
for _, model := range slices.Backward(modelList) {
|
||||||
|
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
|
||||||
|
"providerID": "ollama",
|
||||||
|
"modelID": model,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxRecentModels = 10
|
||||||
|
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
|
||||||
|
|
||||||
|
state["recent"] = newRecent
|
||||||
|
|
||||||
|
stateData, err := json.MarshalIndent(state, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeWithBackup(statePath, stateData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenCode) Models() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
provider, _ := config["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
models, _ := ollama["models"].(map[string]any)
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
keys := slices.Collect(maps.Keys(models))
|
||||||
|
slices.Sort(keys)
|
||||||
|
return keys
|
||||||
|
}
|
||||||
437
cmd/config/opencode_test.go
Normal file
437
cmd/config/opencode_test.go
Normal file
@@ -0,0 +1,437 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenCodeIntegration(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := o.String(); got != "OpenCode" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "OpenCode")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = o
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Editor", func(t *testing.T) {
|
||||||
|
var _ Editor = o
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
statePath := filepath.Join(stateDir, "model.json")
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
os.RemoveAll(configDir)
|
||||||
|
os.RemoveAll(stateDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("fresh install", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||||
|
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserve other providers", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
if provider["anthropic"] == nil {
|
||||||
|
t.Error("anthropic provider was removed")
|
||||||
|
}
|
||||||
|
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserve other models", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||||
|
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update existing model", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
o.Edit([]string{"llama3.2"})
|
||||||
|
o.Edit([]string{"llama3.2"})
|
||||||
|
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
if cfg["theme"] != "dark" {
|
||||||
|
t.Error("theme was removed")
|
||||||
|
}
|
||||||
|
if cfg["keybindings"] == nil {
|
||||||
|
t.Error("keybindings was removed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("model state - insert at index 0", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||||
|
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
data, _ := os.ReadFile(statePath)
|
||||||
|
var state map[string]any
|
||||||
|
json.Unmarshal(data, &state)
|
||||||
|
if len(state["favorite"].([]any)) != 1 {
|
||||||
|
t.Error("favorite was modified")
|
||||||
|
}
|
||||||
|
if state["variant"].(map[string]any)["a"] != "b" {
|
||||||
|
t.Error("variant was modified")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
data, _ := os.ReadFile(statePath)
|
||||||
|
var state map[string]any
|
||||||
|
json.Unmarshal(data, &state)
|
||||||
|
recent := state["recent"].([]any)
|
||||||
|
if len(recent) != 2 {
|
||||||
|
t.Errorf("expected 2 recent entries, got %d", len(recent))
|
||||||
|
}
|
||||||
|
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove model", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
// First add two models
|
||||||
|
o.Edit([]string{"llama3.2", "mistral"})
|
||||||
|
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||||
|
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||||
|
|
||||||
|
// Then remove one by only selecting the other
|
||||||
|
o.Edit([]string{"llama3.2"})
|
||||||
|
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||||
|
assertOpenCodeModelNotExists(t, configPath, "mistral")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
// Add a non-Ollama model manually
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
|
||||||
|
|
||||||
|
o.Edit([]string{"llama3.2"})
|
||||||
|
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||||
|
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertOpenCodeModelExists(t *testing.T, path, model string) {
|
||||||
|
t.Helper()
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var cfg map[string]any
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
provider, ok := cfg["provider"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("provider not found")
|
||||||
|
}
|
||||||
|
ollama, ok := provider["ollama"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("ollama provider not found")
|
||||||
|
}
|
||||||
|
models, ok := ollama["models"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("models not found")
|
||||||
|
}
|
||||||
|
if models[model] == nil {
|
||||||
|
t.Errorf("model %s not found", model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
|
||||||
|
t.Helper()
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var cfg map[string]any
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
provider, ok := cfg["provider"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return // No provider means no model
|
||||||
|
}
|
||||||
|
ollama, ok := provider["ollama"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return // No ollama means no model
|
||||||
|
}
|
||||||
|
models, ok := ollama["models"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return // No models means no model
|
||||||
|
}
|
||||||
|
if models[model] != nil {
|
||||||
|
t.Errorf("model %s should not exist but was found", model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
|
||||||
|
t.Helper()
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var state map[string]any
|
||||||
|
if err := json.Unmarshal(data, &state); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
recent, ok := state["recent"].([]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("recent not found")
|
||||||
|
}
|
||||||
|
if index >= len(recent) {
|
||||||
|
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
|
||||||
|
}
|
||||||
|
entry, ok := recent[index].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("entry is not a map")
|
||||||
|
}
|
||||||
|
if entry["providerID"] != providerID {
|
||||||
|
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
|
||||||
|
}
|
||||||
|
if entry["modelID"] != modelID {
|
||||||
|
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edge case tests for opencode.go
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
|
||||||
|
|
||||||
|
// Should not panic - corrupted JSON should be treated as empty
|
||||||
|
err := o.Edit([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Edit failed with corrupted config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify valid JSON was created
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
t.Errorf("resulting config is not valid JSON: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
statePath := filepath.Join(stateDir, "model.json")
|
||||||
|
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
|
||||||
|
|
||||||
|
err := o.Edit([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Edit failed with corrupted state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify valid state was created
|
||||||
|
data, _ := os.ReadFile(statePath)
|
||||||
|
var state map[string]any
|
||||||
|
if err := json.Unmarshal(data, &state); err != nil {
|
||||||
|
t.Errorf("resulting state is not valid JSON: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
|
||||||
|
|
||||||
|
err := o.Edit([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Edit with wrong type provider failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify provider is now correct type
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
|
||||||
|
provider, ok := cfg["provider"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
|
||||||
|
}
|
||||||
|
if provider["ollama"] == nil {
|
||||||
|
t.Error("ollama provider should be created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
statePath := filepath.Join(stateDir, "model.json")
|
||||||
|
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
|
||||||
|
|
||||||
|
err := o.Edit([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Edit with wrong type recent failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The function should handle this gracefully
|
||||||
|
data, _ := os.ReadFile(statePath)
|
||||||
|
var state map[string]any
|
||||||
|
json.Unmarshal(data, &state)
|
||||||
|
|
||||||
|
// recent should be properly set after setup
|
||||||
|
recent, ok := state["recent"].([]any)
|
||||||
|
if !ok {
|
||||||
|
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
|
||||||
|
} else if len(recent) == 0 {
|
||||||
|
t.Logf("Note: recent is empty (documenting behavior)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
|
||||||
|
os.WriteFile(configPath, []byte(originalContent), 0o644)
|
||||||
|
|
||||||
|
// Empty models should be no-op
|
||||||
|
err := o.Edit([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Edit with empty models failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original content should be preserved (file not modified)
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
if string(data) != originalContent {
|
||||||
|
t.Errorf("empty models should not modify file, but content changed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Model name with special characters (though unusual)
|
||||||
|
specialModel := `model-with-"quotes"`
|
||||||
|
|
||||||
|
err := o.Edit([]string{specialModel})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Edit with special chars failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it was stored correctly
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
t.Fatalf("resulting config is invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model should be accessible
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
models, _ := ollama["models"].(map[string]any)
|
||||||
|
|
||||||
|
if models[specialModel] == nil {
|
||||||
|
t.Errorf("model with special chars not found in config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
models := o.Models()
|
||||||
|
if len(models) > 0 {
|
||||||
|
t.Errorf("expected nil/empty for missing config, got %v", models)
|
||||||
|
}
|
||||||
|
}
|
||||||
499
cmd/config/selector.go
Normal file
499
cmd/config/selector.go
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ANSI escape sequences for terminal formatting.
|
||||||
|
const (
|
||||||
|
ansiHideCursor = "\033[?25l"
|
||||||
|
ansiShowCursor = "\033[?25h"
|
||||||
|
ansiBold = "\033[1m"
|
||||||
|
ansiReset = "\033[0m"
|
||||||
|
ansiGray = "\033[37m"
|
||||||
|
ansiClearDown = "\033[J"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxDisplayedItems = 10
|
||||||
|
|
||||||
|
var errCancelled = errors.New("cancelled")
|
||||||
|
|
||||||
|
type selectItem struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
}
|
||||||
|
|
||||||
|
type inputEvent int
|
||||||
|
|
||||||
|
const (
|
||||||
|
eventNone inputEvent = iota
|
||||||
|
eventEnter
|
||||||
|
eventEscape
|
||||||
|
eventUp
|
||||||
|
eventDown
|
||||||
|
eventTab
|
||||||
|
eventBackspace
|
||||||
|
eventChar
|
||||||
|
)
|
||||||
|
|
||||||
|
type selectState struct {
|
||||||
|
items []selectItem
|
||||||
|
filter string
|
||||||
|
selected int
|
||||||
|
scrollOffset int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSelectState(items []selectItem) *selectState {
|
||||||
|
return &selectState{items: items}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *selectState) filtered() []selectItem {
|
||||||
|
return filterItems(s.items, s.filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
|
||||||
|
filtered := s.filtered()
|
||||||
|
|
||||||
|
switch event {
|
||||||
|
case eventEnter:
|
||||||
|
if len(filtered) > 0 && s.selected < len(filtered) {
|
||||||
|
return true, filtered[s.selected].Name, nil
|
||||||
|
}
|
||||||
|
case eventEscape:
|
||||||
|
return true, "", errCancelled
|
||||||
|
case eventBackspace:
|
||||||
|
if len(s.filter) > 0 {
|
||||||
|
s.filter = s.filter[:len(s.filter)-1]
|
||||||
|
s.selected = 0
|
||||||
|
s.scrollOffset = 0
|
||||||
|
}
|
||||||
|
case eventUp:
|
||||||
|
if s.selected > 0 {
|
||||||
|
s.selected--
|
||||||
|
if s.selected < s.scrollOffset {
|
||||||
|
s.scrollOffset = s.selected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case eventDown:
|
||||||
|
if s.selected < len(filtered)-1 {
|
||||||
|
s.selected++
|
||||||
|
if s.selected >= s.scrollOffset+maxDisplayedItems {
|
||||||
|
s.scrollOffset = s.selected - maxDisplayedItems + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case eventChar:
|
||||||
|
s.filter += string(char)
|
||||||
|
s.selected = 0
|
||||||
|
s.scrollOffset = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type multiSelectState struct {
|
||||||
|
items []selectItem
|
||||||
|
itemIndex map[string]int
|
||||||
|
filter string
|
||||||
|
highlighted int
|
||||||
|
scrollOffset int
|
||||||
|
checked map[int]bool
|
||||||
|
checkOrder []int
|
||||||
|
focusOnButton bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
|
||||||
|
s := &multiSelectState{
|
||||||
|
items: items,
|
||||||
|
itemIndex: make(map[string]int, len(items)),
|
||||||
|
checked: make(map[int]bool),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, item := range items {
|
||||||
|
s.itemIndex[item.Name] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range preChecked {
|
||||||
|
if idx, ok := s.itemIndex[name]; ok {
|
||||||
|
s.checked[idx] = true
|
||||||
|
s.checkOrder = append(s.checkOrder, idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *multiSelectState) filtered() []selectItem {
|
||||||
|
return filterItems(s.items, s.filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *multiSelectState) toggleItem() {
|
||||||
|
filtered := s.filtered()
|
||||||
|
if len(filtered) == 0 || s.highlighted >= len(filtered) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
item := filtered[s.highlighted]
|
||||||
|
origIdx := s.itemIndex[item.Name]
|
||||||
|
|
||||||
|
if s.checked[origIdx] {
|
||||||
|
delete(s.checked, origIdx)
|
||||||
|
for i, idx := range s.checkOrder {
|
||||||
|
if idx == origIdx {
|
||||||
|
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
s.checked[origIdx] = true
|
||||||
|
s.checkOrder = append(s.checkOrder, origIdx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
|
||||||
|
filtered := s.filtered()
|
||||||
|
|
||||||
|
switch event {
|
||||||
|
case eventEnter:
|
||||||
|
if s.focusOnButton && len(s.checkOrder) > 0 {
|
||||||
|
var res []string
|
||||||
|
for _, idx := range s.checkOrder {
|
||||||
|
res = append(res, s.items[idx].Name)
|
||||||
|
}
|
||||||
|
return true, res, nil
|
||||||
|
} else if !s.focusOnButton {
|
||||||
|
s.toggleItem()
|
||||||
|
}
|
||||||
|
case eventTab:
|
||||||
|
if len(s.checkOrder) > 0 {
|
||||||
|
s.focusOnButton = !s.focusOnButton
|
||||||
|
}
|
||||||
|
case eventEscape:
|
||||||
|
return true, nil, errCancelled
|
||||||
|
case eventBackspace:
|
||||||
|
if len(s.filter) > 0 {
|
||||||
|
s.filter = s.filter[:len(s.filter)-1]
|
||||||
|
s.highlighted = 0
|
||||||
|
s.scrollOffset = 0
|
||||||
|
s.focusOnButton = false
|
||||||
|
}
|
||||||
|
case eventUp:
|
||||||
|
if s.focusOnButton {
|
||||||
|
s.focusOnButton = false
|
||||||
|
} else if s.highlighted > 0 {
|
||||||
|
s.highlighted--
|
||||||
|
if s.highlighted < s.scrollOffset {
|
||||||
|
s.scrollOffset = s.highlighted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case eventDown:
|
||||||
|
if s.focusOnButton {
|
||||||
|
s.focusOnButton = false
|
||||||
|
} else if s.highlighted < len(filtered)-1 {
|
||||||
|
s.highlighted++
|
||||||
|
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
|
||||||
|
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case eventChar:
|
||||||
|
s.filter += string(char)
|
||||||
|
s.highlighted = 0
|
||||||
|
s.scrollOffset = 0
|
||||||
|
s.focusOnButton = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *multiSelectState) selectedCount() int {
|
||||||
|
return len(s.checkOrder)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Terminal I/O handling
|
||||||
|
|
||||||
|
type terminalState struct {
|
||||||
|
fd int
|
||||||
|
oldState *term.State
|
||||||
|
}
|
||||||
|
|
||||||
|
func enterRawMode() (*terminalState, error) {
|
||||||
|
fd := int(os.Stdin.Fd())
|
||||||
|
oldState, err := term.MakeRaw(fd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fmt.Fprint(os.Stderr, ansiHideCursor)
|
||||||
|
return &terminalState{fd: fd, oldState: oldState}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *terminalState) restore() {
|
||||||
|
fmt.Fprint(os.Stderr, ansiShowCursor)
|
||||||
|
term.Restore(t.fd, t.oldState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearLines(n int) {
|
||||||
|
if n > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[%dA", n)
|
||||||
|
fmt.Fprint(os.Stderr, ansiClearDown)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInput(r io.Reader) (inputEvent, byte, error) {
|
||||||
|
buf := make([]byte, 3)
|
||||||
|
n, err := r.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case n == 1 && buf[0] == 13:
|
||||||
|
return eventEnter, 0, nil
|
||||||
|
case n == 1 && (buf[0] == 3 || buf[0] == 27):
|
||||||
|
return eventEscape, 0, nil
|
||||||
|
case n == 1 && buf[0] == 9:
|
||||||
|
return eventTab, 0, nil
|
||||||
|
case n == 1 && buf[0] == 127:
|
||||||
|
return eventBackspace, 0, nil
|
||||||
|
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
|
||||||
|
return eventUp, 0, nil
|
||||||
|
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
|
||||||
|
return eventDown, 0, nil
|
||||||
|
case n == 1 && buf[0] >= 32 && buf[0] < 127:
|
||||||
|
return eventChar, buf[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return eventNone, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rendering
|
||||||
|
|
||||||
|
func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||||
|
filtered := s.filtered()
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||||
|
lineCount := 1
|
||||||
|
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||||
|
lineCount++
|
||||||
|
} else {
|
||||||
|
displayCount := min(len(filtered), maxDisplayedItems)
|
||||||
|
|
||||||
|
for i := range displayCount {
|
||||||
|
idx := s.scrollOffset + i
|
||||||
|
if idx >= len(filtered) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
item := filtered[idx]
|
||||||
|
prefix := " "
|
||||||
|
if idx == s.selected {
|
||||||
|
prefix = " " + ansiBold + "> "
|
||||||
|
}
|
||||||
|
if item.Description != "" {
|
||||||
|
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
|
||||||
|
}
|
||||||
|
lineCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||||
|
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||||
|
lineCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lineCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||||
|
filtered := s.filtered()
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||||
|
lineCount := 1
|
||||||
|
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
||||||
|
lineCount++
|
||||||
|
} else {
|
||||||
|
displayCount := min(len(filtered), maxDisplayedItems)
|
||||||
|
|
||||||
|
for i := range displayCount {
|
||||||
|
idx := s.scrollOffset + i
|
||||||
|
if idx >= len(filtered) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
item := filtered[idx]
|
||||||
|
origIdx := s.itemIndex[item.Name]
|
||||||
|
|
||||||
|
checkbox := "[ ]"
|
||||||
|
if s.checked[origIdx] {
|
||||||
|
checkbox = "[x]"
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := " "
|
||||||
|
suffix := ""
|
||||||
|
if idx == s.highlighted && !s.focusOnButton {
|
||||||
|
prefix = "> "
|
||||||
|
}
|
||||||
|
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
|
||||||
|
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||||
|
}
|
||||||
|
|
||||||
|
if idx == s.highlighted && !s.focusOnButton {
|
||||||
|
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
||||||
|
}
|
||||||
|
lineCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
||||||
|
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
||||||
|
lineCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "\r\n")
|
||||||
|
lineCount++
|
||||||
|
count := s.selectedCount()
|
||||||
|
switch {
|
||||||
|
case count == 0:
|
||||||
|
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
|
||||||
|
case s.focusOnButton:
|
||||||
|
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
|
||||||
|
default:
|
||||||
|
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
|
||||||
|
}
|
||||||
|
lineCount++
|
||||||
|
|
||||||
|
return lineCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectPrompt prompts the user to select a single item from a list.
|
||||||
|
func selectPrompt(prompt string, items []selectItem) (string, error) {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return "", fmt.Errorf("no items to select from")
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, err := enterRawMode()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer ts.restore()
|
||||||
|
|
||||||
|
state := newSelectState(items)
|
||||||
|
var lastLineCount int
|
||||||
|
|
||||||
|
render := func() {
|
||||||
|
clearLines(lastLineCount)
|
||||||
|
lastLineCount = renderSelect(os.Stderr, prompt, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
render()
|
||||||
|
|
||||||
|
for {
|
||||||
|
event, char, err := parseInput(os.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
done, result, err := state.handleInput(event, char)
|
||||||
|
if done {
|
||||||
|
clearLines(lastLineCount)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
render()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// multiSelectPrompt prompts the user to select multiple items from a list.
|
||||||
|
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return nil, fmt.Errorf("no items to select from")
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, err := enterRawMode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer ts.restore()
|
||||||
|
|
||||||
|
state := newMultiSelectState(items, preChecked)
|
||||||
|
var lastLineCount int
|
||||||
|
|
||||||
|
render := func() {
|
||||||
|
clearLines(lastLineCount)
|
||||||
|
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
render()
|
||||||
|
|
||||||
|
for {
|
||||||
|
event, char, err := parseInput(os.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
done, result, err := state.handleInput(event, char)
|
||||||
|
if done {
|
||||||
|
clearLines(lastLineCount)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
render()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmPrompt(prompt string) (bool, error) {
|
||||||
|
fd := int(os.Stdin.Fd())
|
||||||
|
oldState, err := term.MakeRaw(fd)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer term.Restore(fd, oldState)
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%s [y/n] ", prompt)
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
for {
|
||||||
|
if _, err := os.Stdin.Read(buf); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch buf[0] {
|
||||||
|
case 'Y', 'y', 13:
|
||||||
|
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||||
|
return true, nil
|
||||||
|
case 'N', 'n', 27, 3:
|
||||||
|
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterItems(items []selectItem, filter string) []selectItem {
|
||||||
|
if filter == "" {
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
var result []selectItem
|
||||||
|
filterLower := strings.ToLower(filter)
|
||||||
|
for _, item := range items {
|
||||||
|
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||||
|
result = append(result, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
913
cmd/config/selector_test.go
Normal file
913
cmd/config/selector_test.go
Normal file
@@ -0,0 +1,913 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilterItems(t *testing.T) {
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "llama3.2:latest"},
|
||||||
|
{Name: "qwen2.5:7b"},
|
||||||
|
{Name: "deepseek-v3:cloud"},
|
||||||
|
{Name: "GPT-OSS:20b"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
|
||||||
|
result := filterItems(items, "")
|
||||||
|
if len(result) != len(items) {
|
||||||
|
t.Errorf("expected %d items, got %d", len(items), len(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
|
||||||
|
result := filterItems(items, "LLAMA")
|
||||||
|
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
|
||||||
|
t.Errorf("expected llama3.2:latest, got %v", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
|
||||||
|
result := filterItems(items, "gpt")
|
||||||
|
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
|
||||||
|
t.Errorf("expected GPT-OSS:20b, got %v", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PartialMatch", func(t *testing.T) {
|
||||||
|
result := filterItems(items, "deep")
|
||||||
|
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
|
||||||
|
t.Errorf("expected deepseek-v3:cloud, got %v", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
|
||||||
|
result := filterItems(items, "nonexistent")
|
||||||
|
if len(result) != 0 {
|
||||||
|
t.Errorf("expected 0 items, got %d", len(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectState(t *testing.T) {
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "item1"},
|
||||||
|
{Name: "item2"},
|
||||||
|
{Name: "item3"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("InitialState", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
if s.selected != 0 {
|
||||||
|
t.Errorf("expected selected=0, got %d", s.selected)
|
||||||
|
}
|
||||||
|
if s.filter != "" {
|
||||||
|
t.Errorf("expected empty filter, got %q", s.filter)
|
||||||
|
}
|
||||||
|
if s.scrollOffset != 0 {
|
||||||
|
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
done, result, err := s.handleInput(eventEnter, 0)
|
||||||
|
if !done || result != "item1" || err != nil {
|
||||||
|
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.filter = "item3"
|
||||||
|
done, result, err := s.handleInput(eventEnter, 0)
|
||||||
|
if !done || result != "item3" || err != nil {
|
||||||
|
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.filter = "nonexistent"
|
||||||
|
done, result, err := s.handleInput(eventEnter, 0)
|
||||||
|
if done || result != "" || err != nil {
|
||||||
|
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
done, result, err := s.handleInput(eventEscape, 0)
|
||||||
|
if !done || result != "" || err != errCancelled {
|
||||||
|
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Down_MovesSelection", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.handleInput(eventDown, 0)
|
||||||
|
if s.selected != 1 {
|
||||||
|
t.Errorf("expected selected=1, got %d", s.selected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.selected = 2
|
||||||
|
s.handleInput(eventDown, 0)
|
||||||
|
if s.selected != 2 {
|
||||||
|
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Up_MovesSelection", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.selected = 2
|
||||||
|
s.handleInput(eventUp, 0)
|
||||||
|
if s.selected != 1 {
|
||||||
|
t.Errorf("expected selected=1, got %d", s.selected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.handleInput(eventUp, 0)
|
||||||
|
if s.selected != 0 {
|
||||||
|
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.handleInput(eventChar, 'i')
|
||||||
|
s.handleInput(eventChar, 't')
|
||||||
|
s.handleInput(eventChar, 'e')
|
||||||
|
s.handleInput(eventChar, 'm')
|
||||||
|
s.handleInput(eventChar, '2')
|
||||||
|
if s.filter != "item2" {
|
||||||
|
t.Errorf("expected filter='item2', got %q", s.filter)
|
||||||
|
}
|
||||||
|
filtered := s.filtered()
|
||||||
|
if len(filtered) != 1 || filtered[0].Name != "item2" {
|
||||||
|
t.Errorf("expected [item2], got %v", filtered)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.selected = 2
|
||||||
|
s.handleInput(eventChar, 'x')
|
||||||
|
if s.selected != 0 {
|
||||||
|
t.Errorf("expected selected=0 after typing, got %d", s.selected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.filter = "test"
|
||||||
|
s.handleInput(eventBackspace, 0)
|
||||||
|
if s.filter != "tes" {
|
||||||
|
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.handleInput(eventBackspace, 0)
|
||||||
|
if s.filter != "" {
|
||||||
|
t.Errorf("expected filter='', got %q", s.filter)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.filter = "test"
|
||||||
|
s.selected = 2
|
||||||
|
s.handleInput(eventBackspace, 0)
|
||||||
|
if s.selected != 0 {
|
||||||
|
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
|
||||||
|
// maxDisplayedItems is 10, so with 15 items we need to scroll
|
||||||
|
manyItems := make([]selectItem, 15)
|
||||||
|
for i := range manyItems {
|
||||||
|
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||||
|
}
|
||||||
|
s := newSelectState(manyItems)
|
||||||
|
|
||||||
|
// move down 12 times (past the 10-item viewport)
|
||||||
|
for range 12 {
|
||||||
|
s.handleInput(eventDown, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.selected != 12 {
|
||||||
|
t.Errorf("expected selected=12, got %d", s.selected)
|
||||||
|
}
|
||||||
|
if s.scrollOffset != 3 {
|
||||||
|
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
|
||||||
|
manyItems := make([]selectItem, 15)
|
||||||
|
for i := range manyItems {
|
||||||
|
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||||
|
}
|
||||||
|
s := newSelectState(manyItems)
|
||||||
|
s.selected = 5
|
||||||
|
s.scrollOffset = 5
|
||||||
|
|
||||||
|
s.handleInput(eventUp, 0)
|
||||||
|
|
||||||
|
if s.selected != 4 {
|
||||||
|
t.Errorf("expected selected=4, got %d", s.selected)
|
||||||
|
}
|
||||||
|
if s.scrollOffset != 4 {
|
||||||
|
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultiSelectState(t *testing.T) {
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "item1"},
|
||||||
|
{Name: "item2"},
|
||||||
|
{Name: "item3"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
if s.highlighted != 0 {
|
||||||
|
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||||
|
}
|
||||||
|
if s.selectedCount() != 0 {
|
||||||
|
t.Errorf("expected 0 selected, got %d", s.selectedCount())
|
||||||
|
}
|
||||||
|
if s.focusOnButton {
|
||||||
|
t.Error("expected focusOnButton=false initially")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item2", "item3"})
|
||||||
|
if s.selectedCount() != 2 {
|
||||||
|
t.Errorf("expected 2 selected, got %d", s.selectedCount())
|
||||||
|
}
|
||||||
|
if !s.checked[1] || !s.checked[2] {
|
||||||
|
t.Error("expected item2 and item3 to be checked")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
|
||||||
|
// order matters: first checked = default model
|
||||||
|
s := newMultiSelectState(items, []string{"item3", "item1"})
|
||||||
|
if len(s.checkOrder) != 2 {
|
||||||
|
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||||
|
}
|
||||||
|
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
|
||||||
|
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
|
||||||
|
if s.selectedCount() != 1 {
|
||||||
|
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
s.toggleItem()
|
||||||
|
if !s.checked[0] {
|
||||||
|
t.Error("expected item1 to be checked after toggle")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1"})
|
||||||
|
s.toggleItem()
|
||||||
|
if s.checked[0] {
|
||||||
|
t.Error("expected item1 to be unchecked after toggle")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
|
||||||
|
s.highlighted = 1 // toggle item2
|
||||||
|
s.toggleItem()
|
||||||
|
|
||||||
|
if len(s.checkOrder) != 2 {
|
||||||
|
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
||||||
|
}
|
||||||
|
// should be [0, 2] (item1, item3) with item2 removed
|
||||||
|
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
|
||||||
|
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
s.handleInput(eventEnter, 0)
|
||||||
|
if !s.checked[0] {
|
||||||
|
t.Error("expected item1 to be checked after enter")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||||
|
s.focusOnButton = true
|
||||||
|
|
||||||
|
done, result, err := s.handleInput(eventEnter, 0)
|
||||||
|
|
||||||
|
if !done || err != nil {
|
||||||
|
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
|
||||||
|
}
|
||||||
|
// result should preserve selection order
|
||||||
|
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
|
||||||
|
t.Errorf("expected [item2, item1], got %v", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
s.focusOnButton = true
|
||||||
|
done, result, err := s.handleInput(eventEnter, 0)
|
||||||
|
if done || result != nil || err != nil {
|
||||||
|
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1"})
|
||||||
|
s.handleInput(eventTab, 0)
|
||||||
|
if !s.focusOnButton {
|
||||||
|
t.Error("expected focus on button after tab")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
s.handleInput(eventTab, 0)
|
||||||
|
if s.focusOnButton {
|
||||||
|
t.Error("tab should not focus button when nothing selected")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1"})
|
||||||
|
s.handleInput(eventTab, 0)
|
||||||
|
if !s.focusOnButton {
|
||||||
|
t.Error("expected focus on button after first tab")
|
||||||
|
}
|
||||||
|
s.handleInput(eventTab, 0)
|
||||||
|
if s.focusOnButton {
|
||||||
|
t.Error("expected focus back on list after second tab")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1"})
|
||||||
|
done, result, err := s.handleInput(eventEscape, 0)
|
||||||
|
if !done || result != nil || err != errCancelled {
|
||||||
|
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item2", "item1"})
|
||||||
|
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
|
||||||
|
t.Error("expected item2 (idx 1) to be default (first checked)")
|
||||||
|
}
|
||||||
|
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||||
|
t.Error("expected item1 (idx 0) to NOT be default")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
||||||
|
t.Error("expected isDefault=false when nothing checked")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Down_MovesHighlight", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
s.handleInput(eventDown, 0)
|
||||||
|
if s.highlighted != 1 {
|
||||||
|
t.Errorf("expected highlighted=1, got %d", s.highlighted)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Up_MovesHighlight", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
s.highlighted = 1
|
||||||
|
s.handleInput(eventUp, 0)
|
||||||
|
if s.highlighted != 0 {
|
||||||
|
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1"})
|
||||||
|
s.focusOnButton = true
|
||||||
|
s.handleInput(eventDown, 0)
|
||||||
|
if s.focusOnButton {
|
||||||
|
t.Error("expected focus to return to list on arrow key")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
s.handleInput(eventChar, 'x')
|
||||||
|
if s.filter != "x" {
|
||||||
|
t.Errorf("expected filter='x', got %q", s.filter)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
|
||||||
|
manyItems := make([]selectItem, 15)
|
||||||
|
for i := range manyItems {
|
||||||
|
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||||
|
}
|
||||||
|
s := newMultiSelectState(manyItems, nil)
|
||||||
|
s.highlighted = 10
|
||||||
|
s.scrollOffset = 5
|
||||||
|
|
||||||
|
s.handleInput(eventChar, 'x')
|
||||||
|
|
||||||
|
if s.highlighted != 0 {
|
||||||
|
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
||||||
|
}
|
||||||
|
if s.scrollOffset != 0 {
|
||||||
|
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
s.filter = "test"
|
||||||
|
s.handleInput(eventBackspace, 0)
|
||||||
|
if s.filter != "tes" {
|
||||||
|
t.Errorf("expected filter='tes', got %q", s.filter)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1"})
|
||||||
|
s.filter = "x"
|
||||||
|
s.focusOnButton = true
|
||||||
|
s.handleInput(eventBackspace, 0)
|
||||||
|
if s.focusOnButton {
|
||||||
|
t.Error("expected focusOnButton=false after backspace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseInput(t *testing.T) {
|
||||||
|
t.Run("Enter", func(t *testing.T) {
|
||||||
|
event, char, err := parseInput(bytes.NewReader([]byte{13}))
|
||||||
|
if err != nil || event != eventEnter || char != 0 {
|
||||||
|
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Escape", func(t *testing.T) {
|
||||||
|
event, _, err := parseInput(bytes.NewReader([]byte{27}))
|
||||||
|
if err != nil || event != eventEscape {
|
||||||
|
t.Errorf("expected eventEscape, got %v", event)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
|
||||||
|
event, _, err := parseInput(bytes.NewReader([]byte{3}))
|
||||||
|
if err != nil || event != eventEscape {
|
||||||
|
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Tab", func(t *testing.T) {
|
||||||
|
event, _, err := parseInput(bytes.NewReader([]byte{9}))
|
||||||
|
if err != nil || event != eventTab {
|
||||||
|
t.Errorf("expected eventTab, got %v", event)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Backspace", func(t *testing.T) {
|
||||||
|
event, _, err := parseInput(bytes.NewReader([]byte{127}))
|
||||||
|
if err != nil || event != eventBackspace {
|
||||||
|
t.Errorf("expected eventBackspace, got %v", event)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UpArrow", func(t *testing.T) {
|
||||||
|
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
|
||||||
|
if err != nil || event != eventUp {
|
||||||
|
t.Errorf("expected eventUp, got %v", event)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DownArrow", func(t *testing.T) {
|
||||||
|
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
|
||||||
|
if err != nil || event != eventDown {
|
||||||
|
t.Errorf("expected eventDown, got %v", event)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PrintableChars", func(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
char byte
|
||||||
|
}{
|
||||||
|
{"lowercase", 'a'},
|
||||||
|
{"uppercase", 'Z'},
|
||||||
|
{"digit", '5'},
|
||||||
|
{"space", ' '},
|
||||||
|
{"tilde", '~'},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
|
||||||
|
if err != nil || event != eventChar || char != tt.char {
|
||||||
|
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenderSelect(t *testing.T) {
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "item1", Description: "first item"},
|
||||||
|
{Name: "item2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("ShowsPromptAndItems", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
lineCount := renderSelect(&buf, "Select:", s)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "Select:") {
|
||||||
|
t.Error("expected prompt in output")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "item1") {
|
||||||
|
t.Error("expected item1 in output")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "first item") {
|
||||||
|
t.Error("expected description in output")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "item2") {
|
||||||
|
t.Error("expected item2 in output")
|
||||||
|
}
|
||||||
|
if lineCount != 3 { // 1 prompt + 2 items
|
||||||
|
t.Errorf("expected 3 lines, got %d", lineCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.filter = "xyz"
|
||||||
|
var buf bytes.Buffer
|
||||||
|
renderSelect(&buf, "Select:", s)
|
||||||
|
|
||||||
|
if !strings.Contains(buf.String(), "no matches") {
|
||||||
|
t.Error("expected 'no matches' message")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
|
||||||
|
manyItems := make([]selectItem, 15)
|
||||||
|
for i := range manyItems {
|
||||||
|
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
||||||
|
}
|
||||||
|
s := newSelectState(manyItems)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
renderSelect(&buf, "Select:", s)
|
||||||
|
|
||||||
|
// 15 items - 10 displayed = 5 more
|
||||||
|
if !strings.Contains(buf.String(), "5 more") {
|
||||||
|
t.Error("expected '5 more' indicator")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenderMultiSelect(t *testing.T) {
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "item1"},
|
||||||
|
{Name: "item2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("ShowsCheckboxes", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1"})
|
||||||
|
var buf bytes.Buffer
|
||||||
|
renderMultiSelect(&buf, "Select:", s)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "[x]") {
|
||||||
|
t.Error("expected checked checkbox [x]")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "[ ]") {
|
||||||
|
t.Error("expected unchecked checkbox [ ]")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ShowsDefaultMarker", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1"})
|
||||||
|
var buf bytes.Buffer
|
||||||
|
renderMultiSelect(&buf, "Select:", s)
|
||||||
|
|
||||||
|
if !strings.Contains(buf.String(), "(default)") {
|
||||||
|
t.Error("expected (default) marker for first checked item")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ShowsSelectedCount", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, []string{"item1", "item2"})
|
||||||
|
var buf bytes.Buffer
|
||||||
|
renderMultiSelect(&buf, "Select:", s)
|
||||||
|
|
||||||
|
if !strings.Contains(buf.String(), "2 selected") {
|
||||||
|
t.Error("expected '2 selected' in output")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
renderMultiSelect(&buf, "Select:", s)
|
||||||
|
|
||||||
|
if !strings.Contains(buf.String(), "Select at least one") {
|
||||||
|
t.Error("expected 'Select at least one' helper text")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrCancelled(t *testing.T) {
|
||||||
|
t.Run("NotNil", func(t *testing.T) {
|
||||||
|
if errCancelled == nil {
|
||||||
|
t.Error("errCancelled should not be nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Message", func(t *testing.T) {
|
||||||
|
if errCancelled.Error() != "cancelled" {
|
||||||
|
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edge case tests for selector.go
|
||||||
|
|
||||||
|
// TestSelectState_SingleItem verifies that single item list works without crash.
|
||||||
|
// List with only one item should still work.
|
||||||
|
func TestSelectState_SingleItem(t *testing.T) {
|
||||||
|
items := []selectItem{{Name: "only-one"}}
|
||||||
|
|
||||||
|
s := newSelectState(items)
|
||||||
|
|
||||||
|
// Down should do nothing (already at bottom)
|
||||||
|
s.handleInput(eventDown, 0)
|
||||||
|
if s.selected != 0 {
|
||||||
|
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Up should do nothing (already at top)
|
||||||
|
s.handleInput(eventUp, 0)
|
||||||
|
if s.selected != 0 {
|
||||||
|
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enter should select the only item
|
||||||
|
done, result, err := s.handleInput(eventEnter, 0)
|
||||||
|
if !done || result != "only-one" || err != nil {
|
||||||
|
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
|
||||||
|
// List with exactly maxDisplayedItems items should not scroll.
|
||||||
|
func TestSelectState_ExactlyMaxItems(t *testing.T) {
|
||||||
|
items := make([]selectItem, maxDisplayedItems)
|
||||||
|
for i := range items {
|
||||||
|
items[i] = selectItem{Name: string(rune('a' + i))}
|
||||||
|
}
|
||||||
|
|
||||||
|
s := newSelectState(items)
|
||||||
|
|
||||||
|
// Move to last item
|
||||||
|
for range maxDisplayedItems - 1 {
|
||||||
|
s.handleInput(eventDown, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.selected != maxDisplayedItems-1 {
|
||||||
|
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not scroll when exactly at max
|
||||||
|
if s.scrollOffset != 0 {
|
||||||
|
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// One more down should do nothing
|
||||||
|
s.handleInput(eventDown, 0)
|
||||||
|
if s.selected != maxDisplayedItems-1 {
|
||||||
|
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
|
||||||
|
// User typing "model.v1" shouldn't match "modelsv1".
|
||||||
|
func TestFilterItems_RegexSpecialChars(t *testing.T) {
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "model.v1"},
|
||||||
|
{Name: "modelsv1"},
|
||||||
|
{Name: "model-v1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter with dot should only match literal dot
|
||||||
|
result := filterItems(items, "model.v1")
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Errorf("expected 1 exact match, got %d", len(result))
|
||||||
|
}
|
||||||
|
if len(result) > 0 && result[0].Name != "model.v1" {
|
||||||
|
t.Errorf("expected 'model.v1', got %s", result[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other regex special chars should be literal too
|
||||||
|
items2 := []selectItem{
|
||||||
|
{Name: "test[0]"},
|
||||||
|
{Name: "test0"},
|
||||||
|
{Name: "test(1)"},
|
||||||
|
}
|
||||||
|
|
||||||
|
result2 := filterItems(items2, "test[0]")
|
||||||
|
if len(result2) != 1 || result2[0].Name != "test[0]" {
|
||||||
|
t.Errorf("expected only 'test[0]', got %v", result2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
|
||||||
|
// itemIndex uses name as key - duplicates cause collision. This documents
|
||||||
|
// the current behavior: the last index for a duplicate name is stored
|
||||||
|
func TestMultiSelectState_DuplicateNames(t *testing.T) {
|
||||||
|
// Duplicate names - this is an edge case that shouldn't happen in practice
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "duplicate"},
|
||||||
|
{Name: "duplicate"},
|
||||||
|
{Name: "unique"},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
|
||||||
|
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
|
||||||
|
// When there are duplicates, only the last occurrence's index is stored
|
||||||
|
if s.itemIndex["duplicate"] != 1 {
|
||||||
|
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Toggle item at highlighted=0 (first "duplicate")
|
||||||
|
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
|
||||||
|
// So it actually toggles the SECOND duplicate item, not the first
|
||||||
|
s.toggleItem()
|
||||||
|
|
||||||
|
// This documents the potentially surprising behavior:
|
||||||
|
// We toggled at highlighted=0, but itemIndex lookup returned 1
|
||||||
|
if !s.checked[1] {
|
||||||
|
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
|
||||||
|
}
|
||||||
|
if s.checked[0] {
|
||||||
|
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
|
||||||
|
// Prevents index-out-of-bounds on next keystroke
|
||||||
|
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "apple"},
|
||||||
|
{Name: "banana"},
|
||||||
|
{Name: "cherry"},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := newSelectState(items)
|
||||||
|
s.selected = 2 // Select "cherry"
|
||||||
|
|
||||||
|
// Type a filter that removes cherry from results
|
||||||
|
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
|
||||||
|
|
||||||
|
// Selection should reset to 0
|
||||||
|
if s.selected != 0 {
|
||||||
|
t.Errorf("expected selected=0 after filter, got %d", s.selected)
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := s.filtered()
|
||||||
|
if len(filtered) != 2 {
|
||||||
|
t.Errorf("expected 2 filtered items, got %d", len(filtered))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
|
||||||
|
// Model names might contain unicode characters
|
||||||
|
func TestFilterItems_UnicodeCharacters(t *testing.T) {
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "llama-日本語"},
|
||||||
|
{Name: "模型-chinese"},
|
||||||
|
{Name: "émoji-🦙"},
|
||||||
|
{Name: "regular-model"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("filter japanese", func(t *testing.T) {
|
||||||
|
result := filterItems(items, "日本")
|
||||||
|
if len(result) != 1 || result[0].Name != "llama-日本語" {
|
||||||
|
t.Errorf("expected llama-日本語, got %v", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("filter chinese", func(t *testing.T) {
|
||||||
|
result := filterItems(items, "模型")
|
||||||
|
if len(result) != 1 || result[0].Name != "模型-chinese" {
|
||||||
|
t.Errorf("expected 模型-chinese, got %v", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("filter emoji", func(t *testing.T) {
|
||||||
|
result := filterItems(items, "🦙")
|
||||||
|
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||||
|
t.Errorf("expected émoji-🦙, got %v", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("filter accented char", func(t *testing.T) {
|
||||||
|
result := filterItems(items, "émoji")
|
||||||
|
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
||||||
|
t.Errorf("expected émoji-🦙, got %v", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
|
||||||
|
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "apple"},
|
||||||
|
{Name: "banana"},
|
||||||
|
{Name: "cherry"},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := newMultiSelectState(items, nil)
|
||||||
|
s.highlighted = 2 // Highlight "cherry"
|
||||||
|
|
||||||
|
// Type a filter that removes cherry
|
||||||
|
s.handleInput(eventChar, 'a')
|
||||||
|
|
||||||
|
if s.highlighted != 0 {
|
||||||
|
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
|
||||||
|
// Empty list should be handled gracefully.
|
||||||
|
func TestMultiSelectState_EmptyItems(t *testing.T) {
|
||||||
|
s := newMultiSelectState([]selectItem{}, nil)
|
||||||
|
|
||||||
|
// Toggle should not panic on empty list
|
||||||
|
s.toggleItem()
|
||||||
|
|
||||||
|
if s.selectedCount() != 0 {
|
||||||
|
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render should handle empty list
|
||||||
|
var buf bytes.Buffer
|
||||||
|
lineCount := renderMultiSelect(&buf, "Select:", s)
|
||||||
|
if lineCount == 0 {
|
||||||
|
t.Error("renderMultiSelect should produce output even for empty list")
|
||||||
|
}
|
||||||
|
if !strings.Contains(buf.String(), "no matches") {
|
||||||
|
t.Error("expected 'no matches' for empty list")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
|
||||||
|
func TestSelectState_RenderWithDescriptions(t *testing.T) {
|
||||||
|
items := []selectItem{
|
||||||
|
{Name: "item1", Description: "First item description"},
|
||||||
|
{Name: "item2", Description: ""},
|
||||||
|
{Name: "item3", Description: "Third item"},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := newSelectState(items)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
renderSelect(&buf, "Select:", s)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "First item description") {
|
||||||
|
t.Error("expected description to be rendered")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "item2") {
|
||||||
|
t.Error("expected item without description to be rendered")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -159,6 +159,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
sb.WriteString(before)
|
sb.WriteString(before)
|
||||||
if !ok {
|
if !ok {
|
||||||
fmt.Fprintln(&sb)
|
fmt.Fprintln(&sb)
|
||||||
|
scanner.Prompt.UseAlt = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -609,3 +609,49 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
|
|||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ImageEditsMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var req openai.ImageEditRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Prompt == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Model == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Image == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
genReq, err := openai.FromImageEditRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
w := &ImageWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1112,3 +1112,129 @@ func TestImageWriterResponse(t *testing.T) {
|
|||||||
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
|
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestImageEditsMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
req api.GenerateRequest
|
||||||
|
err openai.ErrorResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.GenerateRequest
|
||||||
|
|
||||||
|
// Base64-encoded test image (1x1 pixel PNG)
|
||||||
|
testImage := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
|
||||||
|
decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
name: "image edit basic",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "make it blue",
|
||||||
|
"image": "` + testImage + `"
|
||||||
|
}`,
|
||||||
|
req: api.GenerateRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "make it blue",
|
||||||
|
Images: []api.ImageData{decodedImage},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "image edit with size",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "make it blue",
|
||||||
|
"image": "` + testImage + `",
|
||||||
|
"size": "512x768"
|
||||||
|
}`,
|
||||||
|
req: api.GenerateRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "make it blue",
|
||||||
|
Images: []api.ImageData{decodedImage},
|
||||||
|
Width: 512,
|
||||||
|
Height: 768,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "image edit missing prompt",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"image": "` + testImage + `"
|
||||||
|
}`,
|
||||||
|
err: openai.ErrorResponse{
|
||||||
|
Error: openai.Error{
|
||||||
|
Message: "prompt is required",
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "image edit missing model",
|
||||||
|
body: `{
|
||||||
|
"prompt": "make it blue",
|
||||||
|
"image": "` + testImage + `"
|
||||||
|
}`,
|
||||||
|
err: openai.ErrorResponse{
|
||||||
|
Error: openai.Error{
|
||||||
|
Message: "model is required",
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "image edit missing image",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "make it blue"
|
||||||
|
}`,
|
||||||
|
err: openai.ErrorResponse{
|
||||||
|
Error: openai.Error{
|
||||||
|
Message: "image is required",
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
defer func() { capturedRequest = nil }()
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
if tc.err.Error.Message != "" {
|
||||||
|
var errResp openai.ErrorResponse
|
||||||
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||||
|
t.Fatalf("errors did not match:\n%s", diff)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||||
|
t.Fatalf("requests did not match:\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -794,3 +794,47 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
|
|||||||
Data: data,
|
Data: data,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImageEditRequest is an OpenAI-compatible image edit request.
|
||||||
|
type ImageEditRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Image string `json:"image"` // Base64-encoded image data
|
||||||
|
Size string `json:"size,omitempty"` // e.g., "1024x1024"
|
||||||
|
Seed *int64 `json:"seed,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
|
||||||
|
func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
|
||||||
|
req := api.GenerateRequest{
|
||||||
|
Model: r.Model,
|
||||||
|
Prompt: r.Prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the input image
|
||||||
|
if r.Image != "" {
|
||||||
|
imgData, err := decodeImageURL(r.Image)
|
||||||
|
if err != nil {
|
||||||
|
return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
|
||||||
|
}
|
||||||
|
req.Images = append(req.Images, imgData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse size if provided (e.g., "1024x768")
|
||||||
|
if r.Size != "" {
|
||||||
|
var w, h int32
|
||||||
|
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
|
||||||
|
req.Width = w
|
||||||
|
req.Height = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Seed != nil {
|
||||||
|
if req.Options == nil {
|
||||||
|
req.Options = map[string]any{}
|
||||||
|
}
|
||||||
|
req.Options["seed"] = *r.Seed
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -448,3 +448,86 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFromImageEditRequest_Basic(t *testing.T) {
|
||||||
|
req := ImageEditRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "make it blue",
|
||||||
|
Image: prefix + image,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromImageEditRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Model != "test-model" {
|
||||||
|
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Prompt != "make it blue" {
|
||||||
|
t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Images) != 1 {
|
||||||
|
t.Fatalf("expected 1 image, got %d", len(result.Images))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromImageEditRequest_WithSize(t *testing.T) {
|
||||||
|
req := ImageEditRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "make it blue",
|
||||||
|
Image: prefix + image,
|
||||||
|
Size: "512x768",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromImageEditRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Width != 512 {
|
||||||
|
t.Errorf("expected width 512, got %d", result.Width)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Height != 768 {
|
||||||
|
t.Errorf("expected height 768, got %d", result.Height)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromImageEditRequest_WithSeed(t *testing.T) {
|
||||||
|
seed := int64(12345)
|
||||||
|
req := ImageEditRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "make it blue",
|
||||||
|
Image: prefix + image,
|
||||||
|
Seed: &seed,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromImageEditRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Options == nil {
|
||||||
|
t.Fatal("expected options to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Options["seed"] != seed {
|
||||||
|
t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromImageEditRequest_InvalidImage(t *testing.T) {
|
||||||
|
req := ImageEditRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "make it blue",
|
||||||
|
Image: "not-valid-base64",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := FromImageEditRequest(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid image")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -95,7 +95,21 @@ func (i *Instance) Readline() (string, error) {
|
|||||||
|
|
||||||
var currentLineBuf []rune
|
var currentLineBuf []rune
|
||||||
|
|
||||||
|
// draining tracks if we're processing buffered input from cooked mode.
|
||||||
|
// In cooked mode Enter sends \n, but in raw mode Ctrl+J sends \n.
|
||||||
|
// We treat \n from cooked mode as submit, not multiline.
|
||||||
|
// We check Buffered() after the first read since the bufio buffer is
|
||||||
|
// empty until then. This is compatible with """ multiline mode in
|
||||||
|
// interactive.go since each Readline() call is independent.
|
||||||
|
var draining, stopDraining bool
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
// Apply deferred state change from previous iteration
|
||||||
|
if stopDraining {
|
||||||
|
draining = false
|
||||||
|
stopDraining = false
|
||||||
|
}
|
||||||
|
|
||||||
// don't show placeholder when pasting unless we're in multiline mode
|
// don't show placeholder when pasting unless we're in multiline mode
|
||||||
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
|
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
|
||||||
if buf.IsEmpty() && showPlaceholder {
|
if buf.IsEmpty() && showPlaceholder {
|
||||||
@@ -105,6 +119,15 @@ func (i *Instance) Readline() (string, error) {
|
|||||||
|
|
||||||
r, err := i.Terminal.Read()
|
r, err := i.Terminal.Read()
|
||||||
|
|
||||||
|
// After reading, check if there's more buffered data. If so, we're
|
||||||
|
// processing cooked-mode input. Once buffer empties, the current
|
||||||
|
// char is the last buffered one (still drain it), then stop next iteration.
|
||||||
|
if i.Terminal.reader.Buffered() > 0 {
|
||||||
|
draining = true
|
||||||
|
} else if draining {
|
||||||
|
stopDraining = true
|
||||||
|
}
|
||||||
|
|
||||||
if buf.IsEmpty() {
|
if buf.IsEmpty() {
|
||||||
fmt.Print(ClearToEOL)
|
fmt.Print(ClearToEOL)
|
||||||
}
|
}
|
||||||
@@ -232,6 +255,8 @@ func (i *Instance) Readline() (string, error) {
|
|||||||
fd := os.Stdin.Fd()
|
fd := os.Stdin.Fd()
|
||||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||||
case CharCtrlJ:
|
case CharCtrlJ:
|
||||||
|
// If not draining cooked-mode input, treat as multiline
|
||||||
|
if !draining {
|
||||||
i.pastedLines = append(i.pastedLines, buf.String())
|
i.pastedLines = append(i.pastedLines, buf.String())
|
||||||
buf.Buf.Clear()
|
buf.Buf.Clear()
|
||||||
buf.Pos = 0
|
buf.Pos = 0
|
||||||
@@ -241,6 +266,9 @@ func (i *Instance) Readline() (string, error) {
|
|||||||
fmt.Print(i.Prompt.AltPrompt)
|
fmt.Print(i.Prompt.AltPrompt)
|
||||||
i.Prompt.UseAlt = true
|
i.Prompt.UseAlt = true
|
||||||
continue
|
continue
|
||||||
|
}
|
||||||
|
// Draining cooked-mode input: treat \n as submit
|
||||||
|
fallthrough
|
||||||
case CharEnter:
|
case CharEnter:
|
||||||
output := buf.String()
|
output := buf.String()
|
||||||
if len(i.pastedLines) > 0 {
|
if len(i.pastedLines) > 0 {
|
||||||
|
|||||||
@@ -75,12 +75,6 @@ type Model struct {
|
|||||||
func (m *Model) Capabilities() []model.Capability {
|
func (m *Model) Capabilities() []model.Capability {
|
||||||
capabilities := []model.Capability{}
|
capabilities := []model.Capability{}
|
||||||
|
|
||||||
// Check for image generation model via config capabilities
|
|
||||||
if slices.Contains(m.Config.Capabilities, "image") {
|
|
||||||
return []model.Capability{model.CapabilityImage}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for completion capability
|
|
||||||
if m.ModelPath != "" {
|
if m.ModelPath != "" {
|
||||||
f, err := gguf.Open(m.ModelPath)
|
f, err := gguf.Open(m.ModelPath)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|||||||
@@ -56,6 +56,15 @@ func TestModelCapabilities(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectedCaps: []model.Capability{model.CapabilityImage},
|
expectedCaps: []model.Capability{model.CapabilityImage},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "model with image and vision capability (image editing)",
|
||||||
|
model: Model{
|
||||||
|
Config: model.ConfigV2{
|
||||||
|
Capabilities: []string{"image", "vision"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "model with completion capability",
|
name: "model with completion capability",
|
||||||
model: Model{
|
model: Model{
|
||||||
|
|||||||
@@ -1604,8 +1604,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||||
// OpenAI-compatible image generation endpoint
|
// OpenAI-compatible image generation endpoints
|
||||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||||
|
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||||
|
|
||||||
// Inference (Anthropic compatibility)
|
// Inference (Anthropic compatibility)
|
||||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||||
@@ -2507,8 +2508,14 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set headers for streaming response
|
// Check streaming preference
|
||||||
c.Header("Content-Type", "application/x-ndjson")
|
isStreaming := req.Stream == nil || *req.Stream
|
||||||
|
|
||||||
|
contentType := "application/x-ndjson"
|
||||||
|
if !isStreaming {
|
||||||
|
contentType = "application/json; charset=utf-8"
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", contentType)
|
||||||
|
|
||||||
// Get seed from options if provided
|
// Get seed from options if provided
|
||||||
var seed int64
|
var seed int64
|
||||||
@@ -2523,13 +2530,21 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var images []llm.ImageData
|
||||||
|
for i, imgData := range req.Images {
|
||||||
|
images = append(images, llm.ImageData{ID: i, Data: imgData})
|
||||||
|
}
|
||||||
|
|
||||||
var streamStarted bool
|
var streamStarted bool
|
||||||
|
var finalResponse api.GenerateResponse
|
||||||
|
|
||||||
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: req.Prompt,
|
Prompt: req.Prompt,
|
||||||
Width: req.Width,
|
Width: req.Width,
|
||||||
Height: req.Height,
|
Height: req.Height,
|
||||||
Steps: req.Steps,
|
Steps: req.Steps,
|
||||||
Seed: seed,
|
Seed: seed,
|
||||||
|
Images: images,
|
||||||
}, func(cr llm.CompletionResponse) {
|
}, func(cr llm.CompletionResponse) {
|
||||||
streamStarted = true
|
streamStarted = true
|
||||||
res := api.GenerateResponse{
|
res := api.GenerateResponse{
|
||||||
@@ -2553,6 +2568,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
|||||||
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !isStreaming {
|
||||||
|
finalResponse = res
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
data, _ := json.Marshal(res)
|
data, _ := json.Marshal(res)
|
||||||
c.Writer.Write(append(data, '\n'))
|
c.Writer.Write(append(data, '\n'))
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
@@ -2562,5 +2582,10 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
|||||||
if !streamStarted {
|
if !streamStarted {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isStreaming {
|
||||||
|
c.JSON(http.StatusOK, finalResponse)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,9 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||||
@@ -71,6 +73,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mockRunner) Ping(_ context.Context) error { return nil }
|
||||||
|
|
||||||
func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
||||||
return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
|
return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
|
||||||
return mock, nil
|
return mock, nil
|
||||||
@@ -2193,3 +2197,246 @@ func TestGenerateUnload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateWithImages(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
mock := mockRunner{
|
||||||
|
CompletionResponse: llm.CompletionResponse{
|
||||||
|
Done: true,
|
||||||
|
DoneReason: llm.DoneReasonStop,
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: newMockServer(&mock),
|
||||||
|
getGpuFn: getGpuFn,
|
||||||
|
getSystemInfoFn: getSystemInfoFn,
|
||||||
|
waitForRecovery: 250 * time.Millisecond,
|
||||||
|
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
req.successCh <- &runnerRef{
|
||||||
|
llama: &mock,
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(t.Context())
|
||||||
|
|
||||||
|
_, digest := createBinFile(t, ggml.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.block_count": uint32(1),
|
||||||
|
"llama.context_length": uint32(8192),
|
||||||
|
"llama.embedding_length": uint32(4096),
|
||||||
|
"llama.attention.head_count": uint32(32),
|
||||||
|
"llama.attention.head_count_kv": uint32(8),
|
||||||
|
"tokenizer.ggml.tokens": []string{""},
|
||||||
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
|
}, []*ggml.Tensor{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Files: map[string]string{"file.gguf": digest},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("images passed to completion request", func(t *testing.T) {
|
||||||
|
testImage := []byte("test-image-data")
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Image processed"
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Prompt: "Describe this image",
|
||||||
|
Images: []api.ImageData{testImage},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify images were passed to the completion request
|
||||||
|
if len(mock.CompletionRequest.Images) != 1 {
|
||||||
|
t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
|
||||||
|
t.Errorf("image data mismatch in completion request")
|
||||||
|
}
|
||||||
|
|
||||||
|
if mock.CompletionRequest.Images[0].ID != 0 {
|
||||||
|
t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple images passed to completion request", func(t *testing.T) {
|
||||||
|
testImage1 := []byte("test-image-1")
|
||||||
|
testImage2 := []byte("test-image-2")
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Images processed"
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Prompt: "Compare these images",
|
||||||
|
Images: []api.ImageData{testImage1, testImage2},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify both images were passed
|
||||||
|
if len(mock.CompletionRequest.Images) != 2 {
|
||||||
|
t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
|
||||||
|
t.Errorf("first image data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
|
||||||
|
t.Errorf("second image data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
|
||||||
|
t.Errorf("expected image IDs 0 and 1, got %d and %d",
|
||||||
|
mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no images when none provided", func(t *testing.T) {
|
||||||
|
mock.CompletionResponse.Content = "No images"
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no images in completion request
|
||||||
|
if len(mock.CompletionRequest.Images) != 0 {
|
||||||
|
t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestImageGenerateStreamFalse tests that image generation respects stream=false
|
||||||
|
// and returns a single JSON response instead of streaming ndjson.
|
||||||
|
func TestImageGenerateStreamFalse(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
p := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
|
|
||||||
|
mock := mockRunner{}
|
||||||
|
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||||
|
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
|
||||||
|
fn(llm.CompletionResponse{Step: 2, TotalSteps: 3, Done: false})
|
||||||
|
fn(llm.CompletionResponse{Step: 3, TotalSteps: 3, Done: true, DoneReason: llm.DoneReasonStop, Image: "base64image"})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
s := Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: map[string]*runnerRef{
|
||||||
|
"": {
|
||||||
|
llama: &mock,
|
||||||
|
Options: &opts,
|
||||||
|
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
|
||||||
|
numParallel: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
newServerFn: newMockServer(&mock),
|
||||||
|
getGpuFn: getGpuFn,
|
||||||
|
getSystemInfoFn: getSystemInfoFn,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(t.Context())
|
||||||
|
|
||||||
|
// Create model manifest with image capability
|
||||||
|
n := model.ParseName("test-image")
|
||||||
|
cfg := model.ConfigV2{Capabilities: []string{"image"}}
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(&cfg); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := manifest.WriteManifest(n, configLayer, nil); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
streamFalse := false
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-image",
|
||||||
|
Prompt: "test prompt",
|
||||||
|
Stream: &streamFalse,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if ct := w.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" {
|
||||||
|
t.Errorf("expected Content-Type 'application/json; charset=utf-8', got %q", ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
lines := strings.Split(strings.TrimSpace(body), "\n")
|
||||||
|
if len(lines) != 1 {
|
||||||
|
t.Errorf("expected 1 response line, got %d:\n%s", len(lines), body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp api.GenerateResponse
|
||||||
|
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Image != "base64image" {
|
||||||
|
t.Errorf("expected image 'base64image', got %q", resp.Image)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resp.Done {
|
||||||
|
t.Errorf("expected done=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/ollama/ollama/manifest"
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
@@ -209,10 +211,23 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
|||||||
return fmt.Errorf("invalid model name: %s", modelName)
|
return fmt.Errorf("invalid model name: %s", modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: find a better way to detect image input support
|
||||||
|
// For now, hardcode Flux2KleinPipeline as supporting vision (image input)
|
||||||
|
caps := capabilities
|
||||||
|
modelIndex := filepath.Join(opts.ModelDir, "model_index.json")
|
||||||
|
if data, err := os.ReadFile(modelIndex); err == nil {
|
||||||
|
var cfg struct {
|
||||||
|
ClassName string `json:"_class_name"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(data, &cfg) == nil && cfg.ClassName == "Flux2KleinPipeline" {
|
||||||
|
caps = append(caps, "vision")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create config blob with version requirement
|
// Create config blob with version requirement
|
||||||
configData := model.ConfigV2{
|
configData := model.ConfigV2{
|
||||||
ModelFormat: "safetensors",
|
ModelFormat: "safetensors",
|
||||||
Capabilities: capabilities,
|
Capabilities: caps,
|
||||||
Requires: MinOllamaVersion,
|
Requires: MinOllamaVersion,
|
||||||
}
|
}
|
||||||
configJSON, err := json.Marshal(configData)
|
configJSON, err := json.Marshal(configData)
|
||||||
|
|||||||
@@ -10,7 +10,10 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -75,6 +78,7 @@ Image Generation Flags (experimental):
|
|||||||
// RunCLI handles the CLI for image generation models.
|
// RunCLI handles the CLI for image generation models.
|
||||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||||
|
// Image paths can be included in the prompt and will be extracted automatically.
|
||||||
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
||||||
// Get options from flags (with env var defaults)
|
// Get options from flags (with env var defaults)
|
||||||
opts := DefaultOptions()
|
opts := DefaultOptions()
|
||||||
@@ -111,9 +115,16 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract any image paths from the prompt
|
||||||
|
prompt, images, err := extractFileData(prompt)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
req := &api.GenerateRequest{
|
req := &api.GenerateRequest{
|
||||||
Model: modelName,
|
Model: modelName,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
|
Images: images,
|
||||||
Width: int32(opts.Width),
|
Width: int32(opts.Width),
|
||||||
Height: int32(opts.Height),
|
Height: int32(opts.Height),
|
||||||
Steps: int32(opts.Steps),
|
Steps: int32(opts.Steps),
|
||||||
@@ -254,14 +265,33 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
|||||||
printCurrentSettings(opts)
|
printCurrentSettings(opts)
|
||||||
continue
|
continue
|
||||||
case strings.HasPrefix(line, "/"):
|
case strings.HasPrefix(line, "/"):
|
||||||
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
|
// Check if it's a file path, not a command
|
||||||
|
args := strings.Fields(line)
|
||||||
|
isFile := false
|
||||||
|
for _, f := range extractFileNames(line) {
|
||||||
|
if strings.HasPrefix(f, args[0]) {
|
||||||
|
isFile = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isFile {
|
||||||
|
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", args[0])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract any image paths from the input
|
||||||
|
prompt, images, err := extractFileData(line)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate image with current options
|
// Generate image with current options
|
||||||
req := &api.GenerateRequest{
|
req := &api.GenerateRequest{
|
||||||
Model: modelName,
|
Model: modelName,
|
||||||
Prompt: line,
|
Prompt: prompt,
|
||||||
|
Images: images,
|
||||||
Width: int32(opts.Width),
|
Width: int32(opts.Width),
|
||||||
Height: int32(opts.Height),
|
Height: int32(opts.Height),
|
||||||
Steps: int32(opts.Steps),
|
Steps: int32(opts.Steps),
|
||||||
@@ -486,3 +516,61 @@ func displayImageInTerminal(imagePath string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractFileNames finds image file paths in the input string.
|
||||||
|
func extractFileNames(input string) []string {
|
||||||
|
// Regex to match file paths with image extensions
|
||||||
|
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
|
||||||
|
re := regexp.MustCompile(regexPattern)
|
||||||
|
return re.FindAllString(input, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFileData extracts image data from file paths found in the input.
|
||||||
|
// Returns the cleaned prompt (with file paths removed) and the image data.
|
||||||
|
func extractFileData(input string) (string, []api.ImageData, error) {
|
||||||
|
filePaths := extractFileNames(input)
|
||||||
|
var imgs []api.ImageData
|
||||||
|
|
||||||
|
for _, fp := range filePaths {
|
||||||
|
// Normalize shell escapes
|
||||||
|
nfp := strings.ReplaceAll(fp, "\\ ", " ")
|
||||||
|
nfp = strings.ReplaceAll(nfp, "\\(", "(")
|
||||||
|
nfp = strings.ReplaceAll(nfp, "\\)", ")")
|
||||||
|
nfp = strings.ReplaceAll(nfp, "%20", " ")
|
||||||
|
|
||||||
|
data, err := getImageData(nfp)
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
continue
|
||||||
|
} else if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||||
|
input = strings.ReplaceAll(input, fp, "")
|
||||||
|
imgs = append(imgs, data)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(input), imgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getImageData reads and validates image data from a file.
|
||||||
|
func getImageData(filePath string) ([]byte, error) {
|
||||||
|
file, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, 512)
|
||||||
|
_, err = file.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := http.DetectContentType(buf)
|
||||||
|
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
||||||
|
if !slices.Contains(allowedTypes, contentType) {
|
||||||
|
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-read the full file
|
||||||
|
return os.ReadFile(filePath)
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"image"
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"image/draw"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
"image/png"
|
"image/png"
|
||||||
"os"
|
"os"
|
||||||
@@ -111,6 +113,7 @@ func clampF(v, min, max float32) float32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DecodeImage decodes image bytes with EXIF orientation applied.
|
// DecodeImage decodes image bytes with EXIF orientation applied.
|
||||||
|
// Transparent images are composited onto a white background.
|
||||||
func DecodeImage(data []byte) (image.Image, error) {
|
func DecodeImage(data []byte) (image.Image, error) {
|
||||||
orientation := readJPEGOrientation(data)
|
orientation := readJPEGOrientation(data)
|
||||||
|
|
||||||
@@ -119,9 +122,33 @@ func DecodeImage(data []byte) (image.Image, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
img = flattenAlpha(img)
|
||||||
return applyOrientation(img, orientation), nil
|
return applyOrientation(img, orientation), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// flattenAlpha composites an image onto a white background,
|
||||||
|
// removing any transparency. This is needed because image
|
||||||
|
// generation models don't handle alpha channels well.
|
||||||
|
func flattenAlpha(img image.Image) image.Image {
|
||||||
|
if _, ok := img.(*image.RGBA); !ok {
|
||||||
|
if _, ok := img.(*image.NRGBA); !ok {
|
||||||
|
// No alpha channel, return as-is
|
||||||
|
return img
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bounds := img.Bounds()
|
||||||
|
dst := image.NewRGBA(bounds)
|
||||||
|
|
||||||
|
// Fill with white background
|
||||||
|
draw.Draw(dst, bounds, &image.Uniform{color.White}, image.Point{}, draw.Src)
|
||||||
|
|
||||||
|
// Composite the image on top
|
||||||
|
draw.Draw(dst, bounds, img, bounds.Min, draw.Over)
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
|
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
|
||||||
// Returns 1 (normal) for non-JPEG or if orientation not found.
|
// Returns 1 (normal) for non-JPEG or if orientation not found.
|
||||||
func readJPEGOrientation(data []byte) int {
|
func readJPEGOrientation(data []byte) int {
|
||||||
|
|||||||
@@ -161,6 +161,17 @@ func (m *ModelManifest) HasTensorLayers() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TotalTensorSize returns the total size in bytes of all tensor layers.
|
||||||
|
func (m *ModelManifest) TotalTensorSize() int64 {
|
||||||
|
var total int64
|
||||||
|
for _, layer := range m.Manifest.Layers {
|
||||||
|
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||||
|
total += layer.Size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
// ModelInfo contains metadata about an image generation model.
|
// ModelInfo contains metadata about an image generation model.
|
||||||
type ModelInfo struct {
|
type ModelInfo struct {
|
||||||
Architecture string
|
Architecture string
|
||||||
|
|||||||
@@ -5,6 +5,37 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestTotalTensorSize(t *testing.T) {
|
||||||
|
m := &ModelManifest{
|
||||||
|
Manifest: &Manifest{
|
||||||
|
Layers: []ManifestLayer{
|
||||||
|
{MediaType: "application/vnd.ollama.image.tensor", Size: 1000},
|
||||||
|
{MediaType: "application/vnd.ollama.image.tensor", Size: 2000},
|
||||||
|
{MediaType: "application/vnd.ollama.image.json", Size: 500}, // not a tensor
|
||||||
|
{MediaType: "application/vnd.ollama.image.tensor", Size: 3000},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := m.TotalTensorSize()
|
||||||
|
want := int64(6000)
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("TotalTensorSize() = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTotalTensorSizeEmpty(t *testing.T) {
|
||||||
|
m := &ModelManifest{
|
||||||
|
Manifest: &Manifest{
|
||||||
|
Layers: []ManifestLayer{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := m.TotalTensorSize(); got != 0 {
|
||||||
|
t.Errorf("TotalTensorSize() = %d, want 0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
|
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
|
||||||
modelsDir := filepath.Join(t.TempDir(), "models")
|
modelsDir := filepath.Join(t.TempDir(), "models")
|
||||||
|
|
||||||
|
|||||||
@@ -16,18 +16,9 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GB is a convenience constant for gigabytes.
|
|
||||||
const GB = 1024 * 1024 * 1024
|
|
||||||
|
|
||||||
// SupportedBackends lists the backends that support image generation.
|
// SupportedBackends lists the backends that support image generation.
|
||||||
var SupportedBackends = []string{"metal", "cuda", "cpu"}
|
var SupportedBackends = []string{"metal", "cuda", "cpu"}
|
||||||
|
|
||||||
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
|
|
||||||
var modelVRAMEstimates = map[string]uint64{
|
|
||||||
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
|
|
||||||
"FluxPipeline": 20 * GB, // ~20GB for Flux
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckPlatformSupport validates that image generation is supported on the current platform.
|
// CheckPlatformSupport validates that image generation is supported on the current platform.
|
||||||
// Returns nil if supported, or an error describing why it's not supported.
|
// Returns nil if supported, or an error describing why it's not supported.
|
||||||
func CheckPlatformSupport() error {
|
func CheckPlatformSupport() error {
|
||||||
@@ -47,17 +38,6 @@ func CheckPlatformSupport() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckMemoryRequirements validates that there's enough memory for image generation.
|
|
||||||
// Returns nil if memory is sufficient, or an error if not.
|
|
||||||
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
|
|
||||||
required := EstimateVRAM(modelName)
|
|
||||||
if availableMemory < required {
|
|
||||||
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
|
|
||||||
required/GB, availableMemory/GB)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResolveModelName checks if a model name is a known image generation model.
|
// ResolveModelName checks if a model name is a known image generation model.
|
||||||
// Returns the normalized model name if found, empty string otherwise.
|
// Returns the normalized model name if found, empty string otherwise.
|
||||||
func ResolveModelName(modelName string) string {
|
func ResolveModelName(modelName string) string {
|
||||||
@@ -68,16 +48,6 @@ func ResolveModelName(modelName string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
|
|
||||||
// Returns a conservative default of 21GB if the model type cannot be determined.
|
|
||||||
func EstimateVRAM(modelName string) uint64 {
|
|
||||||
className := DetectModelType(modelName)
|
|
||||||
if estimate, ok := modelVRAMEstimates[className]; ok {
|
|
||||||
return estimate
|
|
||||||
}
|
|
||||||
return 21 * GB
|
|
||||||
}
|
|
||||||
|
|
||||||
// DetectModelType reads model_index.json and returns the model type.
|
// DetectModelType reads model_index.json and returns the model type.
|
||||||
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
|
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
|
||||||
// Returns empty string if detection fails.
|
// Returns empty string if detection fails.
|
||||||
|
|||||||
@@ -30,69 +30,6 @@ func TestCheckPlatformSupport(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckMemoryRequirements(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
availableMemory uint64
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "sufficient memory",
|
|
||||||
availableMemory: 32 * GB,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "exactly enough memory",
|
|
||||||
availableMemory: 21 * GB,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "insufficient memory",
|
|
||||||
availableMemory: 16 * GB,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero memory",
|
|
||||||
availableMemory: 0,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// Use a non-existent model name which will default to 21GB estimate
|
|
||||||
err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelVRAMEstimates(t *testing.T) {
|
|
||||||
// Verify the VRAM estimates map has expected entries
|
|
||||||
expected := map[string]uint64{
|
|
||||||
"ZImagePipeline": 21 * GB,
|
|
||||||
"FluxPipeline": 20 * GB,
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, expectedVRAM := range expected {
|
|
||||||
if actual, ok := modelVRAMEstimates[name]; !ok {
|
|
||||||
t.Errorf("Missing VRAM estimate for %s", name)
|
|
||||||
} else if actual != expectedVRAM {
|
|
||||||
t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEstimateVRAMDefault(t *testing.T) {
|
|
||||||
// Non-existent model should return default 21GB
|
|
||||||
vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
|
|
||||||
if vram != 21*GB {
|
|
||||||
t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveModelName(t *testing.T) {
|
func TestResolveModelName(t *testing.T) {
|
||||||
// Non-existent model should return empty string
|
// Non-existent model should return empty string
|
||||||
result := ResolveModelName("nonexistent-model")
|
result := ResolveModelName("nonexistent-model")
|
||||||
|
|||||||
@@ -177,6 +177,20 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateImageWithInputs implements runner.ImageEditModel interface.
|
||||||
|
// It generates an image conditioned on the provided input images for image editing.
|
||||||
|
func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
|
||||||
|
return m.GenerateFromConfig(ctx, &GenerateConfig{
|
||||||
|
Prompt: prompt,
|
||||||
|
Width: width,
|
||||||
|
Height: height,
|
||||||
|
Steps: steps,
|
||||||
|
Seed: seed,
|
||||||
|
InputImages: inputImages,
|
||||||
|
Progress: progress,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
|
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
|
||||||
const MaxOutputPixels = 2048 * 2048
|
const MaxOutputPixels = 2048 * 2048
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"image"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -30,6 +31,7 @@ type Request struct {
|
|||||||
Height int32 `json:"height,omitempty"`
|
Height int32 `json:"height,omitempty"`
|
||||||
Steps int `json:"steps,omitempty"`
|
Steps int `json:"steps,omitempty"`
|
||||||
Seed int64 `json:"seed,omitempty"`
|
Seed int64 `json:"seed,omitempty"`
|
||||||
|
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
|
||||||
}
|
}
|
||||||
|
|
||||||
// Response is streamed back for each progress update
|
// Response is streamed back for each progress update
|
||||||
@@ -46,6 +48,13 @@ type ImageModel interface {
|
|||||||
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
|
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImageEditModel extends ImageModel with image editing/conditioning capability.
|
||||||
|
// Models that support input images for editing should implement this interface.
|
||||||
|
type ImageEditModel interface {
|
||||||
|
ImageModel
|
||||||
|
GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
|
||||||
|
}
|
||||||
|
|
||||||
// Server holds the model and handles requests
|
// Server holds the model and handles requests
|
||||||
type Server struct {
|
type Server struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -78,14 +87,6 @@ func Execute(args []string) error {
|
|||||||
slog.Info("MLX library initialized")
|
slog.Info("MLX library initialized")
|
||||||
slog.Info("starting image runner", "model", *modelName, "port", *port)
|
slog.Info("starting image runner", "model", *modelName, "port", *port)
|
||||||
|
|
||||||
// Check memory requirements before loading
|
|
||||||
requiredMemory := imagegen.EstimateVRAM(*modelName)
|
|
||||||
availableMemory := mlx.GetMemoryLimit()
|
|
||||||
if availableMemory > 0 && availableMemory < requiredMemory {
|
|
||||||
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
|
|
||||||
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect model type and load appropriate model
|
// Detect model type and load appropriate model
|
||||||
modelType := imagegen.DetectModelType(*modelName)
|
modelType := imagegen.DetectModelType(*modelName)
|
||||||
slog.Info("detected model type", "type", modelType)
|
slog.Info("detected model type", "type", modelType)
|
||||||
@@ -161,6 +162,44 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate and decode input images
|
||||||
|
const maxInputImages = 2
|
||||||
|
if len(req.Images) > maxInputImages {
|
||||||
|
http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var inputImages []image.Image
|
||||||
|
if len(req.Images) > 0 {
|
||||||
|
// TODO: add memory check for input images
|
||||||
|
|
||||||
|
inputImages = make([]image.Image, len(req.Images))
|
||||||
|
for i, imgBytes := range req.Images {
|
||||||
|
img, err := imagegen.DecodeImage(imgBytes)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
inputImages[i] = img
|
||||||
|
}
|
||||||
|
slog.Info("decoded input images", "count", len(inputImages))
|
||||||
|
|
||||||
|
// Default width/height to first input image dimensions, scaled to max 1024
|
||||||
|
bounds := inputImages[0].Bounds()
|
||||||
|
w, h := bounds.Dx(), bounds.Dy()
|
||||||
|
if w > 1024 || h > 1024 {
|
||||||
|
if w > h {
|
||||||
|
h = h * 1024 / w
|
||||||
|
w = 1024
|
||||||
|
} else {
|
||||||
|
w = w * 1024 / h
|
||||||
|
h = 1024
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req.Width = int32(w)
|
||||||
|
req.Height = int32(h)
|
||||||
|
}
|
||||||
|
|
||||||
// Serialize generation requests - MLX model may not handle concurrent generation
|
// Serialize generation requests - MLX model may not handle concurrent generation
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -192,7 +231,19 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
|
// Use ImageEditModel if available and images provided, otherwise use basic ImageModel
|
||||||
|
var img *mlx.Array
|
||||||
|
var err error
|
||||||
|
if len(inputImages) > 0 {
|
||||||
|
editModel, ok := s.model.(ImageEditModel)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "model does not support image editing", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
|
||||||
|
} else {
|
||||||
|
img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't send error for cancellation
|
// Don't send error for cancellation
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
@@ -104,11 +105,17 @@ func NewServer(modelName string) (*Server, error) {
|
|||||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get total weight size from manifest
|
||||||
|
var weightSize uint64
|
||||||
|
if manifest, err := LoadManifest(modelName); err == nil {
|
||||||
|
weightSize = uint64(manifest.TotalTensorSize())
|
||||||
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
cmd: cmd,
|
cmd: cmd,
|
||||||
port: port,
|
port: port,
|
||||||
modelName: modelName,
|
modelName: modelName,
|
||||||
vramSize: EstimateVRAM(modelName),
|
vramSize: weightSize,
|
||||||
done: make(chan error, 1),
|
done: make(chan error, 1),
|
||||||
client: &http.Client{Timeout: 10 * time.Minute},
|
client: &http.Client{Timeout: 10 * time.Minute},
|
||||||
}
|
}
|
||||||
@@ -226,6 +233,12 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
seed = time.Now().UnixNano()
|
seed = time.Now().UnixNano()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract raw image bytes from llm.ImageData slice
|
||||||
|
var images [][]byte
|
||||||
|
for _, img := range req.Images {
|
||||||
|
images = append(images, img.Data)
|
||||||
|
}
|
||||||
|
|
||||||
// Build request for subprocess
|
// Build request for subprocess
|
||||||
creq := struct {
|
creq := struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
@@ -233,12 +246,14 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
Height int32 `json:"height,omitempty"`
|
Height int32 `json:"height,omitempty"`
|
||||||
Steps int32 `json:"steps,omitempty"`
|
Steps int32 `json:"steps,omitempty"`
|
||||||
Seed int64 `json:"seed,omitempty"`
|
Seed int64 `json:"seed,omitempty"`
|
||||||
|
Images [][]byte `json:"images,omitempty"`
|
||||||
}{
|
}{
|
||||||
Prompt: req.Prompt,
|
Prompt: req.Prompt,
|
||||||
Width: req.Width,
|
Width: req.Width,
|
||||||
Height: req.Height,
|
Height: req.Height,
|
||||||
Steps: req.Steps,
|
Steps: req.Steps,
|
||||||
Seed: seed,
|
Seed: seed,
|
||||||
|
Images: images,
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := json.Marshal(creq)
|
body, err := json.Marshal(creq)
|
||||||
@@ -260,7 +275,8 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("request failed: %d", resp.StatusCode)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("%s", strings.TrimSpace(string(body)))
|
||||||
}
|
}
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
|||||||
@@ -38,40 +38,6 @@ func TestPlatformSupport(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestMemoryRequirementsError verifies memory check returns clear error.
|
|
||||||
func TestMemoryRequirementsError(t *testing.T) {
|
|
||||||
// Test with insufficient memory
|
|
||||||
err := CheckMemoryRequirements("test-model", 8*GB)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error for insufficient memory (8GB < 21GB default)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test with sufficient memory
|
|
||||||
err = CheckMemoryRequirements("test-model", 32*GB)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Expected no error for sufficient memory (32GB), got: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestEstimateVRAMReturnsReasonableDefaults verifies VRAM estimates are sensible.
|
|
||||||
func TestEstimateVRAMReturnsReasonableDefaults(t *testing.T) {
|
|
||||||
// Unknown model should return default (21GB)
|
|
||||||
vram := EstimateVRAM("unknown-model")
|
|
||||||
if vram < 10*GB || vram > 100*GB {
|
|
||||||
t.Errorf("VRAM estimate %d GB is outside reasonable range (10-100 GB)", vram/GB)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify known pipeline estimates exist and are reasonable
|
|
||||||
for name, estimate := range modelVRAMEstimates {
|
|
||||||
if estimate < 10*GB {
|
|
||||||
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously low", name, estimate/GB)
|
|
||||||
}
|
|
||||||
if estimate > 200*GB {
|
|
||||||
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously high", name, estimate/GB)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
|
// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
|
||||||
// This is a compile-time check but we document it as a test.
|
// This is a compile-time check but we document it as a test.
|
||||||
func TestServerInterfaceCompliance(t *testing.T) {
|
func TestServerInterfaceCompliance(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user