mirror of
https://github.com/ollama/ollama.git
synced 2026-04-20 15:55:46 +02:00
Compare commits
11 Commits
hoyyeva/op
...
launch-cop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a67e30cf4e | ||
|
|
283b393ed9 | ||
|
|
1b3a200c25 | ||
|
|
f4438d8215 | ||
|
|
5d920cc6bc | ||
|
|
cdddea0592 | ||
|
|
43f90def04 | ||
|
|
06ae6367bd | ||
|
|
48ad7085c4 | ||
|
|
e1e3cec8d0 | ||
|
|
d3e67e305c |
@@ -55,7 +55,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
|||||||
ollama
|
ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw` and more.
|
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `Claude Code`, `OpenClaw`, `OpenCode` , `Codex`, `Copilot`, and more.
|
||||||
|
|
||||||
### Coding
|
### Coding
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ To launch a specific integration:
|
|||||||
ollama launch claude
|
ollama launch claude
|
||||||
```
|
```
|
||||||
|
|
||||||
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
|
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Copilot CLI](https://docs.ollama.com/integrations/copilot-cli), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
|
||||||
|
|
||||||
### AI assistant
|
### AI assistant
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,9 @@ func TestLaunchCmd(t *testing.T) {
|
|||||||
if cmd.Long == "" {
|
if cmd.Long == "" {
|
||||||
t.Error("Long description should not be empty")
|
t.Error("Long description should not be empty")
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(cmd.Long, "hermes") {
|
||||||
|
t.Error("Long description should mention hermes")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("flags exist", func(t *testing.T) {
|
t.Run("flags exist", func(t *testing.T) {
|
||||||
|
|||||||
76
cmd/launch/copilot.go
Normal file
76
cmd/launch/copilot.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Copilot implements Runner for GitHub Copilot CLI integration.
|
||||||
|
type Copilot struct{}
|
||||||
|
|
||||||
|
func (c *Copilot) String() string { return "Copilot CLI" }
|
||||||
|
|
||||||
|
func (c *Copilot) args(model string, extra []string) []string {
|
||||||
|
var args []string
|
||||||
|
if model != "" {
|
||||||
|
args = append(args, "--model", model)
|
||||||
|
}
|
||||||
|
args = append(args, extra...)
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Copilot) findPath() (string, error) {
|
||||||
|
if p, err := exec.LookPath("copilot"); err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
name := "copilot"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "copilot.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(home, ".local", "bin", name)
|
||||||
|
if _, err := os.Stat(fallback); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Copilot) Run(model string, args []string) error {
|
||||||
|
copilotPath, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(copilotPath, c.args(model, args)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
cmd.Env = append(os.Environ(), c.envVars(model)...)
|
||||||
|
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// envVars returns the environment variables that configure Copilot CLI
|
||||||
|
// to use Ollama as its model provider.
|
||||||
|
func (c *Copilot) envVars(model string) []string {
|
||||||
|
env := []string{
|
||||||
|
"COPILOT_PROVIDER_BASE_URL=" + envconfig.Host().String() + "/v1",
|
||||||
|
"COPILOT_PROVIDER_API_KEY=",
|
||||||
|
"COPILOT_PROVIDER_WIRE_API=responses",
|
||||||
|
}
|
||||||
|
|
||||||
|
if model != "" {
|
||||||
|
env = append(env, "COPILOT_MODEL="+model)
|
||||||
|
}
|
||||||
|
|
||||||
|
return env
|
||||||
|
}
|
||||||
161
cmd/launch/copilot_test.go
Normal file
161
cmd/launch/copilot_test.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCopilotIntegration(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := c.String(); got != "Copilot CLI" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "Copilot CLI")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopilotFindPath(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
t.Run("finds copilot in PATH", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
name := "copilot"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "copilot.exe"
|
||||||
|
}
|
||||||
|
fakeBin := filepath.Join(tmpDir, name)
|
||||||
|
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
|
||||||
|
got, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != fakeBin {
|
||||||
|
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error when not in PATH", func(t *testing.T) {
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||||
|
|
||||||
|
_, err := c.findPath()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||||
|
|
||||||
|
name := "copilot"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "copilot.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(tmpDir, ".local", "bin", name)
|
||||||
|
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||||
|
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||||
|
|
||||||
|
got, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != fallback {
|
||||||
|
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||||
|
|
||||||
|
_, err := c.findPath()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopilotArgs(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
args []string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||||
|
{"empty model", "", nil, nil},
|
||||||
|
{"with model and extra", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||||
|
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := c.args(tt.model, tt.args)
|
||||||
|
if !slices.Equal(got, tt.want) {
|
||||||
|
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopilotEnvVars(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
envMap := func(envs []string) map[string]string {
|
||||||
|
m := make(map[string]string)
|
||||||
|
for _, e := range envs {
|
||||||
|
k, v, _ := strings.Cut(e, "=")
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("sets required provider env vars with model", func(t *testing.T) {
|
||||||
|
got := envMap(c.envVars("llama3.2"))
|
||||||
|
if got["COPILOT_PROVIDER_BASE_URL"] == "" {
|
||||||
|
t.Error("COPILOT_PROVIDER_BASE_URL should be set")
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
|
||||||
|
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
|
||||||
|
}
|
||||||
|
if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
|
||||||
|
t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
|
||||||
|
}
|
||||||
|
if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
|
||||||
|
t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
|
||||||
|
}
|
||||||
|
if got["COPILOT_MODEL"] != "llama3.2" {
|
||||||
|
t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
|
||||||
|
got := envMap(c.envVars(""))
|
||||||
|
if _, ok := got["COPILOT_MODEL"]; ok {
|
||||||
|
t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
|
||||||
|
got := envMap(c.envVars("test"))
|
||||||
|
if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
|
||||||
|
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
962
cmd/launch/hermes.go
Normal file
962
cmd/launch/hermes.go
Normal file
@@ -0,0 +1,962 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
pathpkg "path"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
hermesInstallScript = "curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash -s -- --skip-setup"
|
||||||
|
hermesProviderName = "Ollama"
|
||||||
|
hermesProviderKey = "ollama-launch"
|
||||||
|
hermesLegacyKey = "ollama"
|
||||||
|
hermesPlaceholderKey = "ollama"
|
||||||
|
hermesGatewaySetupHint = "hermes gateway setup"
|
||||||
|
hermesGatewaySetupTitle = "Connect a messaging app now?"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
hermesGOOS = runtime.GOOS
|
||||||
|
hermesLookPath = exec.LookPath
|
||||||
|
hermesCommand = exec.Command
|
||||||
|
hermesUserHome = os.UserHomeDir
|
||||||
|
hermesOllamaURL = envconfig.ConnectableHost
|
||||||
|
)
|
||||||
|
|
||||||
|
var hermesMessagingEnvGroups = [][]string{
|
||||||
|
{"TELEGRAM_BOT_TOKEN"},
|
||||||
|
{"DISCORD_BOT_TOKEN"},
|
||||||
|
{"SLACK_BOT_TOKEN"},
|
||||||
|
{"SIGNAL_ACCOUNT"},
|
||||||
|
{"EMAIL_ADDRESS"},
|
||||||
|
{"TWILIO_ACCOUNT_SID"},
|
||||||
|
{"MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD"},
|
||||||
|
{"MATTERMOST_TOKEN"},
|
||||||
|
{"WHATSAPP_PHONE_NUMBER_ID"},
|
||||||
|
{"DINGTALK_CLIENT_ID"},
|
||||||
|
{"FEISHU_APP_ID"},
|
||||||
|
{"WECOM_BOT_ID"},
|
||||||
|
{"WEIXIN_ACCOUNT_ID"},
|
||||||
|
{"BLUEBUBBLES_SERVER_URL"},
|
||||||
|
{"WEBHOOK_ENABLED"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hermes is intentionally not an Editor integration: launch owns one primary
|
||||||
|
// model and the local Ollama endpoint, while Hermes keeps its own discovery and
|
||||||
|
// switching UX after startup.
|
||||||
|
type Hermes struct{}
|
||||||
|
|
||||||
|
type hermesConfigBackend struct {
|
||||||
|
displayPath string
|
||||||
|
read func() ([]byte, error)
|
||||||
|
write func([]byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) String() string { return "Hermes Agent" }
|
||||||
|
|
||||||
|
func (h *Hermes) Run(_ string, args []string) error {
|
||||||
|
// Hermes reads its primary model from config.yaml. launch configures that
|
||||||
|
// default model ahead of time so we can keep runtime invocation simple and
|
||||||
|
// still let Hermes discover additional models later via its own UX.
|
||||||
|
if hermesGOOS == "windows" {
|
||||||
|
return h.runWindows(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
bin, err := h.findUnixBinary()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := h.runGatewaySetupPreflight(args, func() error {
|
||||||
|
return hermesAttachedCommand(bin, "gateway", "setup").Run()
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return hermesAttachedCommand(bin, args...).Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) Paths() []string {
|
||||||
|
backend, err := h.configBackend()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{backend.displayPath}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) Configure(model string) error {
|
||||||
|
backend, err := h.configBackend()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := map[string]any{}
|
||||||
|
if data, err := backend.read(); err == nil {
|
||||||
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return fmt.Errorf("parse hermes config: %w", err)
|
||||||
|
}
|
||||||
|
} else if !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelSection, _ := cfg["model"].(map[string]any)
|
||||||
|
if modelSection == nil {
|
||||||
|
modelSection = make(map[string]any)
|
||||||
|
}
|
||||||
|
models := h.listModels(model)
|
||||||
|
applyHermesManagedProviders(cfg, hermesBaseURL(), model, models)
|
||||||
|
|
||||||
|
// launch writes the minimum provider/default-model settings needed to
|
||||||
|
// bootstrap Hermes against Ollama. The active provider stays on a
|
||||||
|
// launch-owned key so /model stays aligned with the launcher-managed entry,
|
||||||
|
// and the Ollama endpoint lives in providers: so the picker shows one row.
|
||||||
|
modelSection["provider"] = hermesProviderKey
|
||||||
|
modelSection["default"] = model
|
||||||
|
modelSection["base_url"] = hermesBaseURL()
|
||||||
|
modelSection["api_key"] = hermesPlaceholderKey
|
||||||
|
cfg["model"] = modelSection
|
||||||
|
|
||||||
|
// use Hermes' built-in web toolset for now.
|
||||||
|
// TODO(parthsareen): move this to using Ollama web search
|
||||||
|
cfg["toolsets"] = mergeHermesToolsets(cfg["toolsets"])
|
||||||
|
|
||||||
|
data, err := yaml.Marshal(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return backend.write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) CurrentModel() string {
|
||||||
|
backend, err := h.configBackend()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
data, err := backend.read()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := map[string]any{}
|
||||||
|
if yaml.Unmarshal(data, &cfg) != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return hermesManagedCurrentModel(cfg, hermesBaseURL())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) Onboard() error {
|
||||||
|
return config.MarkIntegrationOnboarded("hermes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) RequiresInteractiveOnboarding() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) RefreshRuntimeAfterConfigure() error {
|
||||||
|
running, err := h.gatewayRunning()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check Hermes gateway status: %w", err)
|
||||||
|
}
|
||||||
|
if !running {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%sRefreshing Hermes messaging gateway...%s\n", ansiGray, ansiReset)
|
||||||
|
if err := h.restartGateway(); err != nil {
|
||||||
|
return fmt.Errorf("restart Hermes gateway: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) installed() bool {
|
||||||
|
if hermesGOOS == "windows" {
|
||||||
|
if _, err := hermesLookPath("hermes"); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return h.wslHasHermes()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := h.findUnixBinary()
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) ensureInstalled() error {
|
||||||
|
if h.installed() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if hermesGOOS == "windows" {
|
||||||
|
return h.ensureInstalledWindows()
|
||||||
|
}
|
||||||
|
|
||||||
|
var missing []string
|
||||||
|
for _, dep := range []string{"bash", "curl", "git"} {
|
||||||
|
if _, err := hermesLookPath(dep); err != nil {
|
||||||
|
missing = append(missing, dep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(missing) > 0 {
|
||||||
|
return fmt.Errorf("Hermes is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch hermes", strings.Join(missing, "\n "))
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := ConfirmPrompt("Hermes is not installed. Install now?")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("hermes installation cancelled")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\nInstalling Hermes...\n")
|
||||||
|
if err := hermesAttachedCommand("bash", "-lc", hermesInstallScript).Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to install hermes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.installed() {
|
||||||
|
return fmt.Errorf("hermes was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%sHermes installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) ensureInstalledWindows() error {
|
||||||
|
// Hermes upstream support is WSL-oriented, so Windows launch uses a hybrid
|
||||||
|
// WSL handoff that stays on the same install path as upstream Hermes.
|
||||||
|
if _, err := hermesLookPath("hermes"); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !h.wslAvailable() {
|
||||||
|
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
|
||||||
|
}
|
||||||
|
if h.wslHasHermes() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := ConfirmPromptWithOptions("Hermes runs through WSL2 on Windows. Install it in WSL now?", ConfirmOptions{
|
||||||
|
YesLabel: "Use WSL",
|
||||||
|
NoLabel: "Show manual steps",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\nInstalling Hermes in WSL...\n")
|
||||||
|
if err := h.runWSL("bash", "-lc", hermesInstallScript); err != nil {
|
||||||
|
return hermesWindowsHint(fmt.Errorf("failed to install hermes in WSL: %w", err))
|
||||||
|
}
|
||||||
|
if !h.wslHasHermes() {
|
||||||
|
return hermesWindowsHint(fmt.Errorf("hermes install finished but the WSL binary was not found"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%sHermes installed successfully in WSL%s\n\n", ansiGreen, ansiReset)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) listModels(defaultModel string) []string {
|
||||||
|
client := hermesOllamaClient()
|
||||||
|
resp, err := client.List(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return []string{defaultModel}
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]string, 0, len(resp.Models)+1)
|
||||||
|
seen := make(map[string]struct{}, len(resp.Models)+1)
|
||||||
|
add := func(name string) {
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := seen[name]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[name] = struct{}{}
|
||||||
|
models = append(models, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
add(defaultModel)
|
||||||
|
for _, entry := range resp.Models {
|
||||||
|
add(entry.Name)
|
||||||
|
}
|
||||||
|
if len(models) == 0 {
|
||||||
|
return []string{defaultModel}
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) findUnixBinary() (string, error) {
|
||||||
|
if path, err := hermesLookPath("hermes"); err == nil {
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
home, err := hermesUserHome()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(home, ".local", "bin", "hermes")
|
||||||
|
if _, err := os.Stat(fallback); err == nil {
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("hermes is not installed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) runWindows(args []string) error {
|
||||||
|
if path, err := hermesLookPath("hermes"); err == nil {
|
||||||
|
if err := h.runGatewaySetupPreflight(args, func() error {
|
||||||
|
return hermesAttachedCommand(path, "gateway", "setup").Run()
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return hermesAttachedCommand(path, args...).Run()
|
||||||
|
}
|
||||||
|
if !h.wslAvailable() {
|
||||||
|
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
|
||||||
|
}
|
||||||
|
if err := h.runGatewaySetupPreflight(args, func() error {
|
||||||
|
return h.runWSL("hermes", "gateway", "setup")
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := h.runWSL(append([]string{"hermes"}, args...)...); err != nil {
|
||||||
|
return hermesWindowsHint(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) runWSL(args ...string) error {
|
||||||
|
if !h.wslAvailable() {
|
||||||
|
return fmt.Errorf("wsl.exe is not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return hermesAttachedCommand("wsl.exe", "bash", "-lc", shellQuoteArgs(args)).Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) runWSLCombinedOutput(args ...string) ([]byte, error) {
|
||||||
|
if !h.wslAvailable() {
|
||||||
|
return nil, fmt.Errorf("wsl.exe is not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return hermesCommand("wsl.exe", "bash", "-lc", shellQuoteArgs(args)).CombinedOutput()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) wslAvailable() bool {
|
||||||
|
_, err := hermesLookPath("wsl.exe")
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) wslHasHermes() bool {
|
||||||
|
if !h.wslAvailable() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
cmd := hermesCommand("wsl.exe", "bash", "-lc", "command -v hermes >/dev/null 2>&1")
|
||||||
|
return cmd.Run() == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) configBackend() (*hermesConfigBackend, error) {
|
||||||
|
if hermesGOOS == "windows" {
|
||||||
|
if _, err := hermesLookPath("hermes"); err == nil {
|
||||||
|
return hermesLocalConfigBackend()
|
||||||
|
}
|
||||||
|
if h.wslAvailable() {
|
||||||
|
return h.wslConfigBackend()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hermesLocalConfigBackend()
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesConfigPath() (string, error) {
|
||||||
|
home, err := hermesUserHome()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".hermes", "config.yaml"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesLocalConfigBackend() (*hermesConfigBackend, error) {
|
||||||
|
configPath, err := hermesConfigPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &hermesConfigBackend{
|
||||||
|
displayPath: configPath,
|
||||||
|
read: func() ([]byte, error) {
|
||||||
|
return os.ReadFile(configPath)
|
||||||
|
},
|
||||||
|
write: func(data []byte) error {
|
||||||
|
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fileutil.WriteWithBackup(configPath, data)
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) wslConfigBackend() (*hermesConfigBackend, error) {
|
||||||
|
home, err := h.wslHome()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
configPath := pathpkg.Join(home, ".hermes", "config.yaml")
|
||||||
|
return &hermesConfigBackend{
|
||||||
|
displayPath: configPath,
|
||||||
|
read: func() ([]byte, error) {
|
||||||
|
return h.readWSLFile(configPath)
|
||||||
|
},
|
||||||
|
write: func(data []byte) error {
|
||||||
|
return h.writeWSLConfig(configPath, data)
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) wslHome() (string, error) {
|
||||||
|
if !h.wslAvailable() {
|
||||||
|
return "", fmt.Errorf("wsl.exe is not available")
|
||||||
|
}
|
||||||
|
cmd := hermesCommand("wsl.exe", "bash", "-lc", `printf %s "$HOME"`)
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
home := strings.TrimSpace(string(out))
|
||||||
|
if home == "" {
|
||||||
|
return "", fmt.Errorf("could not resolve WSL home directory")
|
||||||
|
}
|
||||||
|
return home, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) readWSLFile(path string) ([]byte, error) {
|
||||||
|
pathArg := shellQuoteArgs([]string{path})
|
||||||
|
cmd := hermesCommand("wsl.exe", "bash", "-lc", fmt.Sprintf("if [ -f %s ]; then cat %s; else exit 42; fi", pathArg, pathArg))
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err == nil {
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
if errors.As(err, &exitErr) && exitErr.ExitCode() == 42 {
|
||||||
|
return nil, os.ErrNotExist
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) writeWSLConfig(path string, data []byte) error {
|
||||||
|
if existing, err := h.readWSLFile(path); err == nil {
|
||||||
|
if !bytes.Equal(existing, data) {
|
||||||
|
if err := hermesBackupData(path, existing); err != nil {
|
||||||
|
return fmt.Errorf("backup failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("read existing file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := pathpkg.Dir(path)
|
||||||
|
dirArg := shellQuoteArgs([]string{dir})
|
||||||
|
pathArg := shellQuoteArgs([]string{path})
|
||||||
|
script := fmt.Sprintf(
|
||||||
|
"dir=%s; path=%s; mkdir -p \"$dir\" && tmp=$(mktemp \"$dir/.tmp-XXXXXX\") && cat > \"$tmp\" && mv \"$tmp\" \"$path\"",
|
||||||
|
dirArg,
|
||||||
|
pathArg,
|
||||||
|
)
|
||||||
|
cmd := hermesCommand("wsl.exe", "bash", "-lc", script)
|
||||||
|
cmd.Stdin = bytes.NewReader(data)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
if msg := strings.TrimSpace(string(out)); msg != "" {
|
||||||
|
return fmt.Errorf("%w: %s", err, msg)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesBackupData(path string, data []byte) error {
|
||||||
|
if err := os.MkdirAll(fileutil.BackupDir(), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
backupPath := filepath.Join(fileutil.BackupDir(), fmt.Sprintf("%s.%d", filepath.Base(path), time.Now().Unix()))
|
||||||
|
return os.WriteFile(backupPath, data, 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesBaseURL() string {
|
||||||
|
return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesEnvPath() (string, error) {
|
||||||
|
home, err := hermesUserHome()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".hermes", ".env"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) runGatewaySetupPreflight(args []string, runSetup func() error) error {
|
||||||
|
if len(args) > 0 || !isInteractiveSession() || currentLaunchConfirmPolicy.yes || currentLaunchConfirmPolicy.requireYesMessage {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if h.messagingConfigured() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\nHermes can message you on Telegram, Discord, Slack, and more.\n\n")
|
||||||
|
ok, err := ConfirmPromptWithOptions(hermesGatewaySetupTitle, ConfirmOptions{
|
||||||
|
YesLabel: "Yes",
|
||||||
|
NoLabel: "Set up later",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := runSetup(); err != nil {
|
||||||
|
return fmt.Errorf("hermes messaging setup failed: %w\n\nTry running: %s", err, hermesGatewaySetupHint)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) messagingConfigured() bool {
|
||||||
|
envVars, err := h.gatewayEnvVars()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, group := range hermesMessagingEnvGroups {
|
||||||
|
for _, key := range group {
|
||||||
|
if strings.TrimSpace(envVars[key]) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
|
||||||
|
envVars := make(map[string]string)
|
||||||
|
|
||||||
|
data, err := h.readGatewayEnvFile()
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
for key, value := range hermesParseEnvFile(data) {
|
||||||
|
envVars[key] = value
|
||||||
|
}
|
||||||
|
case os.IsNotExist(err):
|
||||||
|
// nothing persisted yet
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.usesLocalRuntimeEnv() {
|
||||||
|
for _, group := range hermesMessagingEnvGroups {
|
||||||
|
for _, key := range group {
|
||||||
|
if value, ok := os.LookupEnv(key); ok {
|
||||||
|
envVars[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return envVars, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) readGatewayEnvFile() ([]byte, error) {
|
||||||
|
if hermesGOOS == "windows" {
|
||||||
|
if _, err := hermesLookPath("hermes"); err == nil {
|
||||||
|
path, err := hermesEnvPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return os.ReadFile(path)
|
||||||
|
}
|
||||||
|
if h.wslAvailable() {
|
||||||
|
home, err := h.wslHome()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h.readWSLFile(pathpkg.Join(home, ".hermes", ".env"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
path, err := hermesEnvPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return os.ReadFile(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) usesLocalRuntimeEnv() bool {
|
||||||
|
if hermesGOOS != "windows" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
_, err := hermesLookPath("hermes")
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) gatewayRunning() (bool, error) {
|
||||||
|
status, err := h.gatewayStatusOutput()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return hermesGatewayStatusRunning(status), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) gatewayStatusOutput() (string, error) {
|
||||||
|
if hermesGOOS == "windows" {
|
||||||
|
if path, err := hermesLookPath("hermes"); err == nil {
|
||||||
|
out, err := hermesCommand(path, "gateway", "status").CombinedOutput()
|
||||||
|
return string(out), err
|
||||||
|
}
|
||||||
|
if !h.wslAvailable() {
|
||||||
|
return "", hermesWindowsHint(fmt.Errorf("hermes is not installed"))
|
||||||
|
}
|
||||||
|
out, err := h.runWSLCombinedOutput("hermes", "gateway", "status")
|
||||||
|
return string(out), err
|
||||||
|
}
|
||||||
|
|
||||||
|
bin, err := h.findUnixBinary()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
out, err := hermesCommand(bin, "gateway", "status").CombinedOutput()
|
||||||
|
return string(out), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) restartGateway() error {
|
||||||
|
if hermesGOOS == "windows" {
|
||||||
|
if path, err := hermesLookPath("hermes"); err == nil {
|
||||||
|
return hermesAttachedCommand(path, "gateway", "restart").Run()
|
||||||
|
}
|
||||||
|
if !h.wslAvailable() {
|
||||||
|
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
|
||||||
|
}
|
||||||
|
if err := h.runWSL("hermes", "gateway", "restart"); err != nil {
|
||||||
|
return hermesWindowsHint(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bin, err := h.findUnixBinary()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return hermesAttachedCommand(bin, "gateway", "restart").Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesGatewayStatusRunning(output string) bool {
|
||||||
|
status := strings.ToLower(output)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(status, "gateway is not running"):
|
||||||
|
return false
|
||||||
|
case strings.Contains(status, "gateway service is stopped"):
|
||||||
|
return false
|
||||||
|
case strings.Contains(status, "gateway service is not loaded"):
|
||||||
|
return false
|
||||||
|
case strings.Contains(status, "gateway is running"):
|
||||||
|
return true
|
||||||
|
case strings.Contains(status, "gateway service is running"):
|
||||||
|
return true
|
||||||
|
case strings.Contains(status, "gateway service is loaded"):
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesParseEnvFile(data []byte) map[string]string {
|
||||||
|
out := make(map[string]string)
|
||||||
|
scanner := bufio.NewScanner(bytes.NewReader(data))
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "\ufeff"))
|
||||||
|
if line == "" || strings.HasPrefix(line, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(line, "export ") {
|
||||||
|
line = strings.TrimSpace(strings.TrimPrefix(line, "export "))
|
||||||
|
}
|
||||||
|
|
||||||
|
key, value, ok := strings.Cut(line, "=")
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if len(value) >= 2 {
|
||||||
|
switch {
|
||||||
|
case value[0] == '"' && value[len(value)-1] == '"':
|
||||||
|
if unquoted, err := strconv.Unquote(value); err == nil {
|
||||||
|
value = unquoted
|
||||||
|
}
|
||||||
|
case value[0] == '\'' && value[len(value)-1] == '\'':
|
||||||
|
value = value[1 : len(value)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesOllamaClient() *api.Client {
|
||||||
|
// Hermes queries the same launch-resolved Ollama host that launch writes
|
||||||
|
// into config, so model discovery follows the configured endpoint.
|
||||||
|
return api.NewClient(hermesOllamaURL(), http.DefaultClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyHermesManagedProviders(cfg map[string]any, baseURL string, model string, models []string) {
|
||||||
|
providers := hermesUserProviders(cfg["providers"])
|
||||||
|
entry := hermesManagedProviderEntry(providers)
|
||||||
|
if entry == nil {
|
||||||
|
entry = make(map[string]any)
|
||||||
|
}
|
||||||
|
entry["name"] = hermesProviderName
|
||||||
|
entry["api"] = baseURL
|
||||||
|
entry["default_model"] = model
|
||||||
|
entry["models"] = hermesStringListAny(models)
|
||||||
|
providers[hermesProviderKey] = entry
|
||||||
|
delete(providers, hermesLegacyKey)
|
||||||
|
cfg["providers"] = providers
|
||||||
|
|
||||||
|
customProviders := hermesWithoutManagedCustomProviders(cfg["custom_providers"])
|
||||||
|
if len(customProviders) == 0 {
|
||||||
|
delete(cfg, "custom_providers")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg["custom_providers"] = customProviders
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesManagedCurrentModel(cfg map[string]any, baseURL string) string {
|
||||||
|
modelCfg, _ := cfg["model"].(map[string]any)
|
||||||
|
if modelCfg == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, _ := modelCfg["provider"].(string)
|
||||||
|
if strings.TrimSpace(strings.ToLower(provider)) != hermesProviderKey {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
configBaseURL, _ := modelCfg["base_url"].(string)
|
||||||
|
if hermesNormalizeURL(configBaseURL) != hermesNormalizeURL(baseURL) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
current, _ := modelCfg["default"].(string)
|
||||||
|
current = strings.TrimSpace(current)
|
||||||
|
if current == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
providers := hermesUserProviders(cfg["providers"])
|
||||||
|
entry, _ := providers[hermesProviderKey].(map[string]any)
|
||||||
|
if entry == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if hermesHasManagedCustomProvider(cfg["custom_providers"]) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL, _ := entry["api"].(string)
|
||||||
|
if hermesNormalizeURL(apiURL) != hermesNormalizeURL(baseURL) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultModel, _ := entry["default_model"].(string)
|
||||||
|
if strings.TrimSpace(defaultModel) != current {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return current
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesUserProviders(current any) map[string]any {
|
||||||
|
switch existing := current.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
out := make(map[string]any, len(existing))
|
||||||
|
for key, value := range existing {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case map[any]any:
|
||||||
|
out := make(map[string]any, len(existing))
|
||||||
|
for key, value := range existing {
|
||||||
|
if s, ok := key.(string); ok {
|
||||||
|
out[s] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
return make(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesCustomProviders(current any) []any {
|
||||||
|
switch existing := current.(type) {
|
||||||
|
case []any:
|
||||||
|
return append([]any(nil), existing...)
|
||||||
|
case []map[string]any:
|
||||||
|
out := make([]any, 0, len(existing))
|
||||||
|
for _, entry := range existing {
|
||||||
|
out = append(out, entry)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesManagedProviderEntry(providers map[string]any) map[string]any {
|
||||||
|
for _, key := range []string{hermesProviderKey, hermesLegacyKey} {
|
||||||
|
if entry, _ := providers[key].(map[string]any); entry != nil {
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesWithoutManagedCustomProviders(current any) []any {
|
||||||
|
customProviders := hermesCustomProviders(current)
|
||||||
|
preserved := make([]any, 0, len(customProviders))
|
||||||
|
|
||||||
|
for _, item := range customProviders {
|
||||||
|
entry, _ := item.(map[string]any)
|
||||||
|
if entry == nil {
|
||||||
|
preserved = append(preserved, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if hermesManagedCustomProvider(entry) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
preserved = append(preserved, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
return preserved
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesHasManagedCustomProvider(current any) bool {
|
||||||
|
for _, item := range hermesCustomProviders(current) {
|
||||||
|
entry, _ := item.(map[string]any)
|
||||||
|
if entry != nil && hermesManagedCustomProvider(entry) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesManagedCustomProvider(entry map[string]any) bool {
|
||||||
|
name, _ := entry["name"].(string)
|
||||||
|
return strings.EqualFold(strings.TrimSpace(name), hermesProviderName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesNormalizeURL(raw string) string {
|
||||||
|
return strings.TrimRight(strings.TrimSpace(raw), "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesStringListAny(models []string) []any {
|
||||||
|
out := make([]any, 0, len(models))
|
||||||
|
for _, model := range dedupeModelList(models) {
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
|
if model == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, model)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeHermesToolsets(current any) any {
|
||||||
|
added := false
|
||||||
|
switch existing := current.(type) {
|
||||||
|
case []any:
|
||||||
|
out := make([]any, 0, len(existing)+1)
|
||||||
|
for _, item := range existing {
|
||||||
|
out = append(out, item)
|
||||||
|
if s, _ := item.(string); s == "web" {
|
||||||
|
added = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !added {
|
||||||
|
out = append(out, "web")
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case []string:
|
||||||
|
out := append([]string(nil), existing...)
|
||||||
|
if !slices.Contains(out, "web") {
|
||||||
|
out = append(out, "web")
|
||||||
|
}
|
||||||
|
asAny := make([]any, 0, len(out))
|
||||||
|
for _, item := range out {
|
||||||
|
asAny = append(asAny, item)
|
||||||
|
}
|
||||||
|
return asAny
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(existing) == "" {
|
||||||
|
return []any{"hermes-cli", "web"}
|
||||||
|
}
|
||||||
|
parts := strings.Split(existing, ",")
|
||||||
|
out := make([]any, 0, len(parts)+1)
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if part == "web" {
|
||||||
|
added = true
|
||||||
|
}
|
||||||
|
out = append(out, part)
|
||||||
|
}
|
||||||
|
if !added {
|
||||||
|
out = append(out, "web")
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
return []any{"hermes-cli", "web"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shellQuoteArgs(args []string) string {
|
||||||
|
quoted := make([]string, 0, len(args))
|
||||||
|
for _, arg := range args {
|
||||||
|
quoted = append(quoted, "'"+strings.ReplaceAll(arg, "'", `'\''`)+"'")
|
||||||
|
}
|
||||||
|
return strings.Join(quoted, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
|
||||||
|
cmd := hermesCommand(name, args...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesWindowsHint(err error) error {
|
||||||
|
if hermesGOOS != "windows" {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%w\n\nHermes runs on Windows through WSL2.\nQuick setup: wsl --install\nInstaller docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/", err)
|
||||||
|
}
|
||||||
1236
cmd/launch/hermes_test.go
Normal file
1236
cmd/launch/hermes_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -74,7 +74,7 @@ func TestIntegrationLookup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIntegrationRegistry(t *testing.T) {
|
func TestIntegrationRegistry(t *testing.T) {
|
||||||
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
|
expectedIntegrations := []string{"claude", "codex", "droid", "opencode", "hermes"}
|
||||||
|
|
||||||
for _, name := range expectedIntegrations {
|
for _, name := range expectedIntegrations {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
@@ -329,7 +329,7 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
func TestBuildModelList_OnlyLocalModels_CloudRecsStillFirst(t *testing.T) {
|
||||||
existing := []modelInfo{
|
existing := []modelInfo{
|
||||||
{Name: "llama3.2:latest", Remote: false},
|
{Name: "llama3.2:latest", Remote: false},
|
||||||
{Name: "qwen2.5:latest", Remote: false},
|
{Name: "qwen2.5:latest", Remote: false},
|
||||||
@@ -338,10 +338,11 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
|||||||
items, _, _, _ := buildModelList(existing, nil, "")
|
items, _, _, _ := buildModelList(existing, nil, "")
|
||||||
got := names(items)
|
got := names(items)
|
||||||
|
|
||||||
// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
|
// Cloud recs always come first among recommended, regardless of installed inventory.
|
||||||
want := []string{"gemma4", "qwen3.5", "kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "llama3.2", "qwen2.5"}
|
// Cloud disablement is handled upstream in loadSelectableModels via filterCloudItems.
|
||||||
|
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "gemma4", "qwen3.5", "llama3.2", "qwen2.5"}
|
||||||
if diff := cmp.Diff(want, got); diff != "" {
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff)
|
t.Errorf("cloud recs pinned first even when no cloud models installed (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -588,7 +589,7 @@ func TestBuildModelList_MixedCase_CloudRecsFirst(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
|
func TestBuildModelList_OnlyLocal_CloudRecsStillFirst(t *testing.T) {
|
||||||
existing := []modelInfo{
|
existing := []modelInfo{
|
||||||
{Name: "llama3.2:latest", Remote: false},
|
{Name: "llama3.2:latest", Remote: false},
|
||||||
}
|
}
|
||||||
@@ -596,11 +597,11 @@ func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
|
|||||||
items, _, _, _ := buildModelList(existing, nil, "")
|
items, _, _, _ := buildModelList(existing, nil, "")
|
||||||
got := names(items)
|
got := names(items)
|
||||||
|
|
||||||
// Local recs should sort before cloud recs in only-local case
|
// Cloud recs sort before local recs regardless of installed inventory.
|
||||||
localIdx := slices.Index(got, "gemma4")
|
localIdx := slices.Index(got, "gemma4")
|
||||||
cloudIdx := slices.Index(got, "glm-5.1:cloud")
|
cloudIdx := slices.Index(got, "glm-5.1:cloud")
|
||||||
if localIdx > cloudIdx {
|
if cloudIdx > localIdx {
|
||||||
t.Errorf("local recs should be before cloud recs in only-local case, got %v", got)
|
t.Errorf("cloud recs should be before local recs even when only local models installed, got %v", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1509,27 +1510,13 @@ func TestListIntegrationInfos(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("sorted with custom order at end", func(t *testing.T) {
|
t.Run("follows launcher order", func(t *testing.T) {
|
||||||
// integrationOrder entries (cline, opencode) should appear last, in that order.
|
got := make([]string, 0, len(infos))
|
||||||
// All other entries should be sorted alphabetically before them.
|
for _, info := range infos {
|
||||||
orderRank := make(map[string]int)
|
got = append(got, info.Name)
|
||||||
for i, name := range integrationOrder {
|
|
||||||
orderRank[name] = i + 1
|
|
||||||
}
|
}
|
||||||
for i := 1; i < len(infos); i++ {
|
if diff := compareStrings(got, integrationOrder); diff != "" {
|
||||||
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
|
t.Fatalf("launcher integration order mismatch: %s", diff)
|
||||||
switch {
|
|
||||||
case aRank == 0 && bRank == 0:
|
|
||||||
if infos[i-1].Name >= infos[i].Name {
|
|
||||||
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
|
||||||
}
|
|
||||||
case aRank > 0 && bRank == 0:
|
|
||||||
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
|
|
||||||
case aRank > 0 && bRank > 0:
|
|
||||||
if aRank >= bRank {
|
|
||||||
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1557,6 +1544,28 @@ func TestListIntegrationInfos(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("includes hermes", func(t *testing.T) {
|
||||||
|
for _, info := range infos {
|
||||||
|
if info.Name == "hermes" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatal("expected hermes to be included in ListIntegrationInfos")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("hermes still resolves explicitly", func(t *testing.T) {
|
||||||
|
name, runner, err := LookupIntegration("hermes")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected explicit hermes integration lookup to work, got %v", err)
|
||||||
|
}
|
||||||
|
if name != "hermes" {
|
||||||
|
t.Fatalf("expected canonical name hermes, got %q", name)
|
||||||
|
}
|
||||||
|
if runner.String() == "" {
|
||||||
|
t.Fatal("expected hermes integration runner to be present")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildModelList_Descriptions(t *testing.T) {
|
func TestBuildModelList_Descriptions(t *testing.T) {
|
||||||
@@ -1645,6 +1654,7 @@ func TestIntegration_AutoInstallable(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"openclaw", true},
|
{"openclaw", true},
|
||||||
{"pi", true},
|
{"pi", true},
|
||||||
|
{"hermes", true},
|
||||||
{"claude", false},
|
{"claude", false},
|
||||||
{"codex", false},
|
{"codex", false},
|
||||||
{"opencode", false},
|
{"opencode", false},
|
||||||
|
|||||||
@@ -141,6 +141,36 @@ type Editor interface {
|
|||||||
Models() []string
|
Models() []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ManagedSingleModel is the narrow launch-owned config path for integrations
|
||||||
|
// like Hermes that have one primary model selected by launcher, need launcher
|
||||||
|
// to persist minimal config, and still keep their own model discovery and
|
||||||
|
// onboarding UX. This stays separate from Runner-only integrations and the
|
||||||
|
// multi-model Editor flow so Hermes-specific behavior stays scoped to one path.
|
||||||
|
type ManagedSingleModel interface {
|
||||||
|
Paths() []string
|
||||||
|
Configure(model string) error
|
||||||
|
CurrentModel() string
|
||||||
|
Onboard() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManagedRuntimeRefresher lets managed integrations refresh any long-lived
|
||||||
|
// background runtime after launch rewrites their config.
|
||||||
|
type ManagedRuntimeRefresher interface {
|
||||||
|
RefreshRuntimeAfterConfigure() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManagedOnboardingValidator lets managed integrations re-check saved
|
||||||
|
// onboarding state when launcher needs a stronger live readiness signal.
|
||||||
|
type ManagedOnboardingValidator interface {
|
||||||
|
OnboardingComplete() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManagedInteractiveOnboarding lets a managed integration declare whether its
|
||||||
|
// onboarding step really requires an interactive terminal. Hermes does not.
|
||||||
|
type ManagedInteractiveOnboarding interface {
|
||||||
|
RequiresInteractiveOnboarding() bool
|
||||||
|
}
|
||||||
|
|
||||||
type modelInfo struct {
|
type modelInfo struct {
|
||||||
Name string
|
Name string
|
||||||
Remote bool
|
Remote bool
|
||||||
@@ -176,7 +206,9 @@ Supported integrations:
|
|||||||
claude Claude Code
|
claude Claude Code
|
||||||
cline Cline
|
cline Cline
|
||||||
codex Codex
|
codex Codex
|
||||||
|
copilot Copilot CLI (aliases: copilot-cli)
|
||||||
droid Droid
|
droid Droid
|
||||||
|
hermes Hermes Agent
|
||||||
opencode OpenCode
|
opencode OpenCode
|
||||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||||
pi Pi
|
pi Pi
|
||||||
@@ -186,6 +218,7 @@ Examples:
|
|||||||
ollama launch
|
ollama launch
|
||||||
ollama launch claude
|
ollama launch claude
|
||||||
ollama launch claude --model <model>
|
ollama launch claude --model <model>
|
||||||
|
ollama launch hermes
|
||||||
ollama launch droid --config (does not auto-launch)
|
ollama launch droid --config (does not auto-launch)
|
||||||
ollama launch codex -- -p myprofile (pass extra args to integration)
|
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||||
ollama launch codex -- --sandbox workspace-write`,
|
ollama launch codex -- --sandbox workspace-write`,
|
||||||
@@ -308,36 +341,54 @@ func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
policy := launchIntegrationPolicy(req)
|
||||||
|
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
|
||||||
|
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
launchClient, saved, err := prepareIntegrationLaunch(name, policy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if managed, ok := runner.(ManagedSingleModel); ok {
|
||||||
|
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return launchClient.launchManagedSingleIntegration(ctx, name, runner, managed, saved, req)
|
||||||
|
}
|
||||||
|
|
||||||
if !req.ConfigureOnly {
|
if !req.ConfigureOnly {
|
||||||
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var policy LaunchPolicy
|
|
||||||
// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
|
|
||||||
if req.Policy == nil {
|
|
||||||
policy = defaultLaunchPolicy(isInteractiveSession(), false)
|
|
||||||
} else {
|
|
||||||
policy = *req.Policy
|
|
||||||
}
|
|
||||||
|
|
||||||
launchClient, err := newLauncherClient(policy)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
saved, _ := loadStoredIntegrationConfig(name)
|
|
||||||
// In headless --yes mode we cannot prompt, so require an explicit --model.
|
|
||||||
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
|
|
||||||
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
if editor, ok := runner.(Editor); ok {
|
if editor, ok := runner.(Editor); ok {
|
||||||
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
|
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
|
||||||
}
|
}
|
||||||
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
|
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func launchIntegrationPolicy(req IntegrationLaunchRequest) LaunchPolicy {
|
||||||
|
// TUI does not set a policy, whereas ollama launch <app> does as it can
|
||||||
|
// have flags which change the behavior.
|
||||||
|
if req.Policy != nil {
|
||||||
|
return *req.Policy
|
||||||
|
}
|
||||||
|
return defaultLaunchPolicy(isInteractiveSession(), false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareIntegrationLaunch(name string, policy LaunchPolicy) (*launcherClient, *config.IntegrationConfig, error) {
|
||||||
|
launchClient, err := newLauncherClient(policy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
saved, _ := loadStoredIntegrationConfig(name)
|
||||||
|
return launchClient, saved, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||||
_ = c.loadModelInventoryOnce(ctx)
|
_ = c.loadModelInventoryOnce(ctx)
|
||||||
|
|
||||||
@@ -368,9 +419,18 @@ func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return LauncherIntegrationState{}, err
|
return LauncherIntegrationState{}, err
|
||||||
}
|
}
|
||||||
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
|
var currentModel string
|
||||||
if err != nil {
|
var usable bool
|
||||||
return LauncherIntegrationState{}, err
|
if managed, ok := integration.spec.Runner.(ManagedSingleModel); ok {
|
||||||
|
currentModel, usable, err = c.launcherManagedModelState(ctx, info.Name, managed)
|
||||||
|
if err != nil {
|
||||||
|
return LauncherIntegrationState{}, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
currentModel, usable, err = c.launcherModelState(ctx, info.Name, integration.editor)
|
||||||
|
if err != nil {
|
||||||
|
return LauncherIntegrationState{}, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return LauncherIntegrationState{
|
return LauncherIntegrationState{
|
||||||
@@ -408,6 +468,28 @@ func (c *launcherClient) launcherModelState(ctx context.Context, name string, is
|
|||||||
return model, usableErr == nil && usable, nil
|
return model, usableErr == nil && usable, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launcherManagedModelState(ctx context.Context, name string, managed ManagedSingleModel) (string, bool, error) {
|
||||||
|
current := managed.CurrentModel()
|
||||||
|
if current == "" {
|
||||||
|
cfg, loadErr := loadStoredIntegrationConfig(name)
|
||||||
|
if loadErr == nil {
|
||||||
|
current = primaryModelFromConfig(cfg)
|
||||||
|
}
|
||||||
|
if current != "" {
|
||||||
|
return current, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if current == "" {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
usable, err := c.savedModelUsable(ctx, current)
|
||||||
|
if err != nil {
|
||||||
|
return current, false, err
|
||||||
|
}
|
||||||
|
return current, usable, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||||
current := config.LastModel()
|
current := config.LastModel()
|
||||||
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
|
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
|
||||||
@@ -444,35 +526,15 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||||
current := primaryModelFromConfig(saved)
|
target, _, err := c.resolveSingleIntegrationTarget(ctx, runner, primaryModelFromConfig(saved), req)
|
||||||
target := req.ModelOverride
|
if err != nil {
|
||||||
needsConfigure := req.ForceConfigure
|
|
||||||
|
|
||||||
if target == "" {
|
|
||||||
target = current
|
|
||||||
usable, err := c.savedModelUsable(ctx, target)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !usable {
|
|
||||||
needsConfigure = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if needsConfigure {
|
|
||||||
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
target = selected
|
|
||||||
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if target == "" {
|
if target == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
current := primaryModelFromConfig(saved)
|
||||||
if target != current {
|
if target != current {
|
||||||
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||||
return fmt.Errorf("failed to save: %w", err)
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
@@ -510,6 +572,102 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
|
|||||||
return launchAfterConfiguration(name, runner, models[0], req)
|
return launchAfterConfiguration(name, runner, models[0], req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launchManagedSingleIntegration(ctx context.Context, name string, runner Runner, managed ManagedSingleModel, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||||
|
current := managed.CurrentModel()
|
||||||
|
selectionCurrent := current
|
||||||
|
if selectionCurrent == "" {
|
||||||
|
selectionCurrent = primaryModelFromConfig(saved)
|
||||||
|
}
|
||||||
|
|
||||||
|
target, needsConfigure, err := c.resolveSingleIntegrationTarget(ctx, runner, selectionCurrent, req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if target == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if current == "" || needsConfigure || req.ModelOverride != "" || target != current {
|
||||||
|
if err := prepareManagedSingleIntegration(name, runner, managed, target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if refresher, ok := managed.(ManagedRuntimeRefresher); ok {
|
||||||
|
if err := refresher.RefreshRuntimeAfterConfigure(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !managedIntegrationOnboarded(saved, managed) {
|
||||||
|
if !isInteractiveSession() && managedRequiresInteractiveOnboarding(managed) {
|
||||||
|
return fmt.Errorf("%s still needs interactive gateway setup; run 'ollama launch %s' in a terminal to finish onboarding", runner, name)
|
||||||
|
}
|
||||||
|
if err := managed.Onboard(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ConfigureOnly {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return runIntegration(runner, target, req.ExtraArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) resolveSingleIntegrationTarget(ctx context.Context, runner Runner, current string, req IntegrationLaunchRequest) (string, bool, error) {
|
||||||
|
target := req.ModelOverride
|
||||||
|
needsConfigure := req.ForceConfigure
|
||||||
|
|
||||||
|
if target == "" {
|
||||||
|
target = current
|
||||||
|
usable, err := c.savedModelUsable(ctx, target)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
if !usable {
|
||||||
|
needsConfigure = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsConfigure {
|
||||||
|
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
target = selected
|
||||||
|
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return target, needsConfigure, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func savedIntegrationOnboarded(saved *config.IntegrationConfig) bool {
|
||||||
|
return saved != nil && saved.Onboarded
|
||||||
|
}
|
||||||
|
|
||||||
|
func managedIntegrationOnboarded(saved *config.IntegrationConfig, managed ManagedSingleModel) bool {
|
||||||
|
if !savedIntegrationOnboarded(saved) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
validator, ok := managed.(ManagedOnboardingValidator)
|
||||||
|
if !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return validator.OnboardingComplete()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Most managed integrations treat onboarding as an interactive terminal step.
|
||||||
|
// Hermes opts out because its launch-owned onboarding is just bookkeeping, so
|
||||||
|
// headless launches should not be blocked once config is already prepared.
|
||||||
|
func managedRequiresInteractiveOnboarding(managed ManagedSingleModel) bool {
|
||||||
|
onboarding, ok := managed.(ManagedInteractiveOnboarding)
|
||||||
|
if !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return onboarding.RequiresInteractiveOnboarding()
|
||||||
|
}
|
||||||
|
|
||||||
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
||||||
if selector == nil {
|
if selector == nil {
|
||||||
return "", fmt.Errorf("no selector configured")
|
return "", fmt.Errorf("no selector configured")
|
||||||
|
|||||||
@@ -49,6 +49,55 @@ func (r *launcherSingleRunner) Run(model string, args []string) error {
|
|||||||
|
|
||||||
func (r *launcherSingleRunner) String() string { return "StubSingle" }
|
func (r *launcherSingleRunner) String() string { return "StubSingle" }
|
||||||
|
|
||||||
|
type launcherManagedRunner struct {
|
||||||
|
paths []string
|
||||||
|
currentModel string
|
||||||
|
configured []string
|
||||||
|
ranModel string
|
||||||
|
onboarded bool
|
||||||
|
onboardCalls int
|
||||||
|
onboardingComplete bool
|
||||||
|
refreshCalls int
|
||||||
|
refreshErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) Run(model string, args []string) error {
|
||||||
|
r.ranModel = model
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) String() string { return "StubManaged" }
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) Paths() []string { return r.paths }
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) Configure(model string) error {
|
||||||
|
r.configured = append(r.configured, model)
|
||||||
|
r.currentModel = model
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) CurrentModel() string { return r.currentModel }
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) Onboard() error {
|
||||||
|
r.onboardCalls++
|
||||||
|
r.onboarded = true
|
||||||
|
r.onboardingComplete = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) OnboardingComplete() bool { return r.onboardingComplete }
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) RefreshRuntimeAfterConfigure() error {
|
||||||
|
r.refreshCalls++
|
||||||
|
return r.refreshErr
|
||||||
|
}
|
||||||
|
|
||||||
|
type launcherHeadlessManagedRunner struct {
|
||||||
|
launcherManagedRunner
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *launcherHeadlessManagedRunner) RequiresInteractiveOnboarding() bool { return false }
|
||||||
|
|
||||||
func setLaunchTestHome(t *testing.T, dir string) {
|
func setLaunchTestHome(t *testing.T, dir string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
t.Setenv("HOME", dir)
|
t.Setenv("HOME", dir)
|
||||||
@@ -141,6 +190,448 @@ func TestDefaultLaunchPolicy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildLauncherState_ManagedSingleIntegrationUsesCurrentModel(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{currentModel: "gemma4"}
|
||||||
|
withIntegrationOverride(t, "pi", runner)
|
||||||
|
|
||||||
|
state, err := BuildLauncherState(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BuildLauncherState returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.Integrations["pi"].CurrentModel != "gemma4" {
|
||||||
|
t.Fatalf("expected managed current model from integration config, got %q", state.Integrations["pi"].CurrentModel)
|
||||||
|
}
|
||||||
|
if !state.Integrations["pi"].ModelUsable {
|
||||||
|
t.Fatal("expected managed current model to be usable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildLauncherState_ManagedSingleIntegrationShowsSavedModelWhenLiveConfigMissing(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("pi", []string{"gemma4"}); err != nil {
|
||||||
|
t.Fatalf("failed to save managed integration config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{}
|
||||||
|
withIntegrationOverride(t, "pi", runner)
|
||||||
|
|
||||||
|
state, err := BuildLauncherState(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BuildLauncherState returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.Integrations["pi"].CurrentModel != "gemma4" {
|
||||||
|
t.Fatalf("expected saved model to remain visible, got %q", state.Integrations["pi"].CurrentModel)
|
||||||
|
}
|
||||||
|
if state.Integrations["pi"].ModelUsable {
|
||||||
|
t.Fatal("expected missing live config to mark managed model unusable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationConfiguresOnboardsAndRuns(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{
|
||||||
|
paths: nil,
|
||||||
|
}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
return "gemma4", nil
|
||||||
|
}
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
|
||||||
|
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||||
|
t.Fatalf("configured models mismatch: %s", diff)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 1 {
|
||||||
|
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.onboardCalls != 1 {
|
||||||
|
t.Fatalf("expected onboarding to run once, got %d", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "gemma4" {
|
||||||
|
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubmanaged")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload managed integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := compareStrings(saved.Models, []string{"gemma4"}); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch: %s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationReOnboardsWhenSavedFlagIsStale(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{
|
||||||
|
currentModel: "gemma4",
|
||||||
|
onboardingComplete: false,
|
||||||
|
}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
|
||||||
|
t.Fatalf("failed to save managed integration config: %v", err)
|
||||||
|
}
|
||||||
|
if err := config.MarkIntegrationOnboarded("stubmanaged"); err != nil {
|
||||||
|
t.Fatalf("failed to mark managed integration onboarded: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
|
||||||
|
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runner.onboardCalls != 1 {
|
||||||
|
t.Fatalf("expected stale onboarded flag to trigger onboarding, got %d calls", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 0 {
|
||||||
|
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "gemma4" {
|
||||||
|
t.Fatalf("expected launch to run saved model after onboarding repair, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationConfigOnlySkipsFinalRun(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{
|
||||||
|
paths: nil,
|
||||||
|
}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||||
|
Name: "stubmanaged",
|
||||||
|
ModelOverride: "gemma4",
|
||||||
|
ConfigureOnly: true,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runner.ranModel != "" {
|
||||||
|
t.Fatalf("expected configure-only flow to skip final launch, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 1 {
|
||||||
|
t.Fatalf("expected configure-only flow to refresh runtime once, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.onboardCalls != 1 {
|
||||||
|
t.Fatalf("expected configure-only flow to onboard once, got %d", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationRepairsMissingLiveConfigUsingSavedModel(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
|
||||||
|
t.Fatalf("failed to save managed integration config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called when saved model is reused for repair")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
|
||||||
|
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||||
|
t.Fatalf("expected missing live config to be rewritten from saved model: %s", diff)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 1 {
|
||||||
|
t.Fatalf("expected repaired config to refresh runtime once, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "gemma4" {
|
||||||
|
t.Fatalf("expected launch to use repaired saved model, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLiveConfig(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
|
||||||
|
t.Fatalf("failed to save managed integration config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called when saved model is reused for repair")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||||
|
Name: "stubmanaged",
|
||||||
|
ConfigureOnly: true,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||||
|
t.Fatalf("expected configure-only flow to rewrite missing live config: %s", diff)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 1 {
|
||||||
|
t.Fatalf("expected configure-only repair to refresh runtime once, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "" {
|
||||||
|
t.Fatalf("expected configure-only flow to skip final launch, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationStopsWhenRuntimeRefreshFails(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{
|
||||||
|
refreshErr: fmt.Errorf("boom"),
|
||||||
|
}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||||
|
Name: "stubmanaged",
|
||||||
|
ModelOverride: "gemma4",
|
||||||
|
})
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "boom") {
|
||||||
|
t.Fatalf("expected runtime refresh error, got %v", err)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "" {
|
||||||
|
t.Fatalf("expected final launch to stop on runtime refresh failure, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 1 {
|
||||||
|
t.Fatalf("expected one runtime refresh attempt, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.onboardCalls != 0 {
|
||||||
|
t.Fatalf("expected onboarding to stop after refresh failure, got %d", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationHeadlessNeedsInteractiveOnboarding(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{
|
||||||
|
paths: nil,
|
||||||
|
}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||||
|
Name: "stubmanaged",
|
||||||
|
ModelOverride: "gemma4",
|
||||||
|
Policy: &LaunchPolicy{Confirm: LaunchConfirmAutoApprove, MissingModel: LaunchMissingModelAutoPull},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected headless onboarding requirement to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "interactive gateway setup") {
|
||||||
|
t.Fatalf("expected interactive onboarding guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "" {
|
||||||
|
t.Fatalf("expected no final launch when onboarding is still required, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
if runner.onboardCalls != 0 {
|
||||||
|
t.Fatalf("expected no onboarding attempts in headless mode, got %d", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationHeadlessAllowsNonInteractiveOnboarding(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherHeadlessManagedRunner{}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||||
|
Name: "stubmanaged",
|
||||||
|
ModelOverride: "gemma4",
|
||||||
|
Policy: &LaunchPolicy{Confirm: LaunchConfirmAutoApprove, MissingModel: LaunchMissingModelAutoPull},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected non-interactive onboarding to succeed headlessly, got %v", err)
|
||||||
|
}
|
||||||
|
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||||
|
t.Fatalf("configured models mismatch: %s", diff)
|
||||||
|
}
|
||||||
|
if runner.onboardCalls != 1 {
|
||||||
|
t.Fatalf("expected onboarding to run once, got %d", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "gemma4" {
|
||||||
|
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) {
|
func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setLaunchTestHome(t, tmpDir)
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ func pullMissingModel(ctx context.Context, client *api.Client, model string) err
|
|||||||
|
|
||||||
// prepareEditorIntegration persists models and applies editor-managed config files.
|
// prepareEditorIntegration persists models and applies editor-managed config files.
|
||||||
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
||||||
if ok, err := confirmEditorEdit(runner, editor); err != nil {
|
if ok, err := confirmConfigEdit(runner, editor.Paths()); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if !ok {
|
} else if !ok {
|
||||||
return errCancelled
|
return errCancelled
|
||||||
@@ -244,8 +244,22 @@ func prepareEditorIntegration(name string, runner Runner, editor Editor, models
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
|
func prepareManagedSingleIntegration(name string, runner Runner, managed ManagedSingleModel, model string) error {
|
||||||
paths := editor.Paths()
|
if ok, err := confirmConfigEdit(runner, managed.Paths()); err != nil {
|
||||||
|
return err
|
||||||
|
} else if !ok {
|
||||||
|
return errCancelled
|
||||||
|
}
|
||||||
|
if err := managed.Configure(model); err != nil {
|
||||||
|
return fmt.Errorf("setup failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := config.SaveIntegration(name, []string{model}); err != nil {
|
||||||
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmConfigEdit(runner Runner, paths []string) (bool, error) {
|
||||||
if len(paths) == 0 {
|
if len(paths) == 0 {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@@ -345,8 +359,6 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
|||||||
recRank[rec.Name] = i + 1
|
recRank[rec.Name] = i + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
onlyLocal := hasLocalModel && !hasCloudModel
|
|
||||||
|
|
||||||
if hasLocalModel || hasCloudModel {
|
if hasLocalModel || hasCloudModel {
|
||||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||||
ac, bc := checked[a.Name], checked[b.Name]
|
ac, bc := checked[a.Name], checked[b.Name]
|
||||||
@@ -368,12 +380,6 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
|||||||
}
|
}
|
||||||
if aRec && bRec {
|
if aRec && bRec {
|
||||||
if aCloud != bCloud {
|
if aCloud != bCloud {
|
||||||
if onlyLocal {
|
|
||||||
if aCloud {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
if aCloud {
|
if aCloud {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ type IntegrationInfo struct {
|
|||||||
Description string
|
Description string
|
||||||
}
|
}
|
||||||
|
|
||||||
var launcherIntegrationOrder = []string{"opencode", "droid", "pi"}
|
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "droid", "pi"}
|
||||||
|
|
||||||
var integrationSpecs = []*IntegrationSpec{
|
var integrationSpecs = []*IntegrationSpec{
|
||||||
{
|
{
|
||||||
@@ -74,6 +74,19 @@ var integrationSpecs = []*IntegrationSpec{
|
|||||||
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "copilot",
|
||||||
|
Runner: &Copilot{},
|
||||||
|
Aliases: []string{"copilot-cli"},
|
||||||
|
Description: "GitHub's AI coding agent for the terminal",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := (&Copilot{}).findPath()
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://github.com/features/copilot/cli/",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "droid",
|
Name: "droid",
|
||||||
Runner: &Droid{},
|
Runner: &Droid{},
|
||||||
@@ -136,6 +149,20 @@ var integrationSpecs = []*IntegrationSpec{
|
|||||||
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
|
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "hermes",
|
||||||
|
Runner: &Hermes{},
|
||||||
|
Description: "Self-improving AI agent built by Nous Research",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
return (&Hermes{}).installed()
|
||||||
|
},
|
||||||
|
EnsureInstalled: func() error {
|
||||||
|
return (&Hermes{}).ensureInstalled()
|
||||||
|
},
|
||||||
|
URL: "https://hermes-agent.nousresearch.com/docs/getting-started/installation/",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "vscode",
|
Name: "vscode",
|
||||||
Runner: &VSCode{},
|
Runner: &VSCode{},
|
||||||
@@ -255,10 +282,10 @@ func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
|||||||
return aRank - bRank
|
return aRank - bRank
|
||||||
}
|
}
|
||||||
if aRank > 0 {
|
if aRank > 0 {
|
||||||
return 1
|
return -1
|
||||||
}
|
}
|
||||||
if bRank > 0 {
|
if bRank > 0 {
|
||||||
return -1
|
return 1
|
||||||
}
|
}
|
||||||
return strings.Compare(a.Name, b.Name)
|
return strings.Compare(a.Name, b.Name)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -45,21 +45,12 @@ type menuItem struct {
|
|||||||
isOthers bool
|
isOthers bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var mainMenuItems = []menuItem{
|
const pinnedIntegrationCount = 3
|
||||||
{
|
|
||||||
title: "Chat with a model",
|
var runModelMenuItem = menuItem{
|
||||||
description: "Start an interactive chat with a model",
|
title: "Chat with a model",
|
||||||
isRunModel: true,
|
description: "Start an interactive chat with a model",
|
||||||
},
|
isRunModel: true,
|
||||||
{
|
|
||||||
integration: "openclaw",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
integration: "claude",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
integration: "opencode",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var othersMenuItem = menuItem{
|
var othersMenuItem = menuItem{
|
||||||
@@ -102,20 +93,14 @@ func shouldExpandOthers(state *launch.LauncherState) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
||||||
items := make([]menuItem, 0, len(mainMenuItems)+1)
|
items := []menuItem{runModelMenuItem}
|
||||||
for _, item := range mainMenuItems {
|
items = append(items, pinnedIntegrationItems(state)...)
|
||||||
if item.integration == "" {
|
|
||||||
items = append(items, item)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if integrationState, ok := state.Integrations[item.integration]; ok {
|
|
||||||
items = append(items, integrationMenuItem(integrationState))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if showOthers {
|
otherItems := otherIntegrationItems(state)
|
||||||
items = append(items, otherIntegrationItems(state)...)
|
switch {
|
||||||
} else {
|
case showOthers:
|
||||||
|
items = append(items, otherItems...)
|
||||||
|
case len(otherItems) > 0:
|
||||||
items = append(items, othersMenuItem)
|
items = append(items, othersMenuItem)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,17 +120,28 @@ func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||||
pinned := map[string]bool{
|
ordered := orderedIntegrationItems(state)
|
||||||
"openclaw": true,
|
if len(ordered) <= pinnedIntegrationCount {
|
||||||
"claude": true,
|
return nil
|
||||||
"opencode": true,
|
}
|
||||||
|
return ordered[pinnedIntegrationCount:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func pinnedIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||||
|
ordered := orderedIntegrationItems(state)
|
||||||
|
if len(ordered) <= pinnedIntegrationCount {
|
||||||
|
return ordered
|
||||||
|
}
|
||||||
|
return ordered[:pinnedIntegrationCount]
|
||||||
|
}
|
||||||
|
|
||||||
|
func orderedIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||||
|
if state == nil {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var items []menuItem
|
items := make([]menuItem, 0, len(state.Integrations))
|
||||||
for _, info := range launch.ListIntegrationInfos() {
|
for _, info := range launch.ListIntegrationInfos() {
|
||||||
if pinned[info.Name] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
integrationState, ok := state.Integrations[info.Name]
|
integrationState, ok := state.Integrations[info.Name]
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
@@ -155,6 +151,10 @@ func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
|||||||
return items
|
return items
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func primaryMenuItemCount(state *launch.LauncherState) int {
|
||||||
|
return 1 + len(pinnedIntegrationItems(state))
|
||||||
|
}
|
||||||
|
|
||||||
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
||||||
if state == nil || state.LastSelection == "" {
|
if state == nil || state.LastSelection == "" {
|
||||||
return 0
|
return 0
|
||||||
@@ -190,7 +190,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
if m.cursor > 0 {
|
if m.cursor > 0 {
|
||||||
m.cursor--
|
m.cursor--
|
||||||
}
|
}
|
||||||
if m.showOthers && m.cursor < len(mainMenuItems) {
|
if m.showOthers && m.cursor < primaryMenuItemCount(m.state) {
|
||||||
m.showOthers = false
|
m.showOthers = false
|
||||||
m.items = buildMenuItems(m.state, false)
|
m.items = buildMenuItems(m.state, false)
|
||||||
m.cursor = min(m.cursor, len(m.items)-1)
|
m.cursor = min(m.cursor, len(m.items)-1)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/cmd/launch"
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,6 +44,13 @@ func launcherTestState() *launch.LauncherState {
|
|||||||
Selectable: true,
|
Selectable: true,
|
||||||
Changeable: true,
|
Changeable: true,
|
||||||
},
|
},
|
||||||
|
"hermes": {
|
||||||
|
Name: "hermes",
|
||||||
|
DisplayName: "Hermes Agent",
|
||||||
|
Description: "Self-improving AI agent built by Nous Research",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
},
|
||||||
"droid": {
|
"droid": {
|
||||||
Name: "droid",
|
Name: "droid",
|
||||||
DisplayName: "Droid",
|
DisplayName: "Droid",
|
||||||
@@ -70,8 +78,28 @@ func findMenuCursorByIntegration(items []menuItem, name string) int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func integrationSequence(items []menuItem) []string {
|
||||||
|
sequence := make([]string, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
switch {
|
||||||
|
case item.isRunModel:
|
||||||
|
sequence = append(sequence, "run")
|
||||||
|
case item.isOthers:
|
||||||
|
sequence = append(sequence, "more")
|
||||||
|
case item.integration != "":
|
||||||
|
sequence = append(sequence, item.integration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sequence
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareStrings(got, want []string) string {
|
||||||
|
return cmp.Diff(want, got)
|
||||||
|
}
|
||||||
|
|
||||||
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
||||||
view := newModel(launcherTestState()).View()
|
menu := newModel(launcherTestState())
|
||||||
|
view := menu.View()
|
||||||
for _, want := range []string{"Chat with a model", "Launch OpenClaw", "Launch Claude Code", "Launch OpenCode", "More..."} {
|
for _, want := range []string{"Chat with a model", "Launch OpenClaw", "Launch Claude Code", "Launch OpenCode", "More..."} {
|
||||||
if !strings.Contains(view, want) {
|
if !strings.Contains(view, want) {
|
||||||
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
||||||
@@ -80,23 +108,31 @@ func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
|||||||
if strings.Contains(view, "Launch Codex") {
|
if strings.Contains(view, "Launch Codex") {
|
||||||
t.Fatalf("expected Codex to be under More, not pinned\n%s", view)
|
t.Fatalf("expected Codex to be under More, not pinned\n%s", view)
|
||||||
}
|
}
|
||||||
|
wantOrder := []string{"run", "openclaw", "claude", "opencode", "more"}
|
||||||
|
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
|
||||||
|
t.Fatalf("unexpected pinned order: %s", diff)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
||||||
state := launcherTestState()
|
state := launcherTestState()
|
||||||
state.LastSelection = "pi"
|
state.LastSelection = "codex"
|
||||||
|
|
||||||
menu := newModel(state)
|
menu := newModel(state)
|
||||||
if !menu.showOthers {
|
if !menu.showOthers {
|
||||||
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
||||||
}
|
}
|
||||||
view := menu.View()
|
view := menu.View()
|
||||||
if !strings.Contains(view, "Launch Pi") {
|
if !strings.Contains(view, "Launch Codex") {
|
||||||
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
||||||
}
|
}
|
||||||
if strings.Contains(view, "More...") {
|
if strings.Contains(view, "More...") {
|
||||||
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
||||||
}
|
}
|
||||||
|
wantOrder := []string{"run", "openclaw", "claude", "opencode", "hermes", "codex", "droid", "pi"}
|
||||||
|
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
|
||||||
|
t.Fatalf("unexpected expanded order: %s", diff)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
||||||
|
|||||||
@@ -120,6 +120,7 @@
|
|||||||
"pages": [
|
"pages": [
|
||||||
"/integrations/claude-code",
|
"/integrations/claude-code",
|
||||||
"/integrations/codex",
|
"/integrations/codex",
|
||||||
|
"/integrations/copilot-cli",
|
||||||
"/integrations/opencode",
|
"/integrations/opencode",
|
||||||
"/integrations/droid",
|
"/integrations/droid",
|
||||||
"/integrations/goose",
|
"/integrations/goose",
|
||||||
|
|||||||
93
docs/integrations/copilot-cli.mdx
Normal file
93
docs/integrations/copilot-cli.mdx
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
---
|
||||||
|
title: Copilot CLI
|
||||||
|
---
|
||||||
|
|
||||||
|
GitHub Copilot CLI is GitHub's AI coding agent for the terminal. It can understand your codebase, make edits, run commands, and help you build software faster.
|
||||||
|
|
||||||
|
Open models can be used with Copilot CLI through Ollama, enabling you to use models such as `qwen3.5`, `glm-5.1:cloud`, `kimi-k2.5:cloud`.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
Install [Copilot CLI](https://github.com/features/copilot/cli/):
|
||||||
|
|
||||||
|
<CodeGroup>
|
||||||
|
|
||||||
|
```shell macOS / Linux (Homebrew)
|
||||||
|
brew install copilot-cli
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell npm (all platforms)
|
||||||
|
npm install -g @github/copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell macOS / Linux (script)
|
||||||
|
curl -fsSL https://gh.io/copilot-install | bash
|
||||||
|
```
|
||||||
|
|
||||||
|
```powershell Windows (WinGet)
|
||||||
|
winget install GitHub.Copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
</CodeGroup>
|
||||||
|
|
||||||
|
## Usage with Ollama
|
||||||
|
|
||||||
|
### Quick setup
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run directly with a model
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch copilot --model kimi-k2.5:cloud
|
||||||
|
```
|
||||||
|
|
||||||
|
## Recommended Models
|
||||||
|
|
||||||
|
- `kimi-k2.5:cloud`
|
||||||
|
- `glm-5:cloud`
|
||||||
|
- `minimax-m2.7:cloud`
|
||||||
|
- `qwen3.5:cloud`
|
||||||
|
- `glm-4.7-flash`
|
||||||
|
- `qwen3.5`
|
||||||
|
|
||||||
|
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Non-interactive (headless) mode
|
||||||
|
|
||||||
|
Run Copilot CLI without interaction for use in Docker, CI/CD, or scripts:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch copilot --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
|
||||||
|
```
|
||||||
|
|
||||||
|
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Copilot CLI.
|
||||||
|
|
||||||
|
## Manual setup
|
||||||
|
|
||||||
|
Copilot CLI connects to Ollama using the OpenAI-compatible API via environment variables.
|
||||||
|
|
||||||
|
1. Set the environment variables:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1
|
||||||
|
export COPILOT_PROVIDER_API_KEY=
|
||||||
|
export COPILOT_PROVIDER_WIRE_API=responses
|
||||||
|
export COPILOT_MODEL=qwen3.5
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Run Copilot CLI:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
Or run with environment variables inline:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1 COPILOT_PROVIDER_API_KEY= COPILOT_PROVIDER_WIRE_API=responses COPILOT_MODEL=glm-5:cloud copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note:** Copilot requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||||
@@ -6,6 +6,10 @@ Hermes Agent is a self-improving AI agent built by Nous Research. It features au
|
|||||||
|
|
||||||
## Quick start
|
## Quick start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama launch hermes
|
||||||
|
```
|
||||||
|
|
||||||
### Pull a model
|
### Pull a model
|
||||||
|
|
||||||
Before running the setup wizard, make sure you have a model available. Hermes will auto-detect models downloaded through Ollama.
|
Before running the setup wizard, make sure you have a model available. Hermes will auto-detect models downloaded through Ollama.
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ Coding assistants that can read, modify, and execute code in your projects.
|
|||||||
|
|
||||||
- [Claude Code](/integrations/claude-code)
|
- [Claude Code](/integrations/claude-code)
|
||||||
- [Codex](/integrations/codex)
|
- [Codex](/integrations/codex)
|
||||||
|
- [Copilot CLI](/integrations/copilot-cli)
|
||||||
- [OpenCode](/integrations/opencode)
|
- [OpenCode](/integrations/opencode)
|
||||||
- [Droid](/integrations/droid)
|
- [Droid](/integrations/droid)
|
||||||
- [Goose](/integrations/goose)
|
- [Goose](/integrations/goose)
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -106,5 +106,5 @@ require (
|
|||||||
golang.org/x/term v0.36.0
|
golang.org/x/term v0.36.0
|
||||||
golang.org/x/text v0.30.0
|
golang.org/x/text v0.30.0
|
||||||
google.golang.org/protobuf v1.34.1
|
google.golang.org/protobuf v1.34.1
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -93,6 +93,13 @@ func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quan
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MoE router logits choose the top-k expert set. Quantization noise here
|
||||||
|
// can flip expert selection, after which downstream activations diverge
|
||||||
|
// sharply. The tensor is small, so leave it in source precision.
|
||||||
|
if isGemma4RouterProjection(name) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Mixed-precision quantization: sensitive tensors get higher precision.
|
// Mixed-precision quantization: sensitive tensors get higher precision.
|
||||||
//
|
//
|
||||||
// Value projections (v_proj) directly determine attention output quality.
|
// Value projections (v_proj) directly determine attention output quality.
|
||||||
@@ -170,6 +177,12 @@ func isEmbedTokensWeight(name string) bool {
|
|||||||
!strings.Contains(name, "per_layer")
|
!strings.Contains(name, "per_layer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isGemma4RouterProjection(name string) bool {
|
||||||
|
return strings.HasSuffix(name, ".router.proj.weight") &&
|
||||||
|
!strings.Contains(name, "audio_tower") &&
|
||||||
|
!strings.Contains(name, "vision_tower")
|
||||||
|
}
|
||||||
|
|
||||||
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
||||||
if td == nil {
|
if td == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
|||||||
@@ -68,6 +68,11 @@ func TestGemma4QuantizationType(t *testing.T) {
|
|||||||
{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
|
{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||||
{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},
|
{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||||
|
|
||||||
|
// === Router projection: expert selection is sensitive; keep source precision ===
|
||||||
|
{"router proj int4", transform26B, "model.layers.0.router.proj.weight", aligned, "int4", ""},
|
||||||
|
{"router proj nvfp4", transform26B, "model.layers.0.router.proj.weight", aligned, "nvfp4", ""},
|
||||||
|
{"router proj mxfp4", transform26B, "model.layers.0.router.proj.weight", aligned, "mxfp4", ""},
|
||||||
|
|
||||||
// === k_proj: promoted only for 8-expert models ===
|
// === k_proj: promoted only for 8-expert models ===
|
||||||
{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
|
{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
|
||||||
{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
|
{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
|
||||||
|
|||||||
19
x/mlxrunner/cache/cache.go
vendored
19
x/mlxrunner/cache/cache.go
vendored
@@ -254,8 +254,23 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV
|
|||||||
mlx.Pin(c.keys, c.values)
|
mlx.Pin(c.keys, c.values)
|
||||||
} else {
|
} else {
|
||||||
if c.idx < c.keys.Dim(2) {
|
if c.idx < c.keys.Dim(2) {
|
||||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
if c.offset <= c.maxSize {
|
||||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
// Not yet wrapped: slots [c.idx, Dim) are grow padding
|
||||||
|
// or stale post-rewind data, not live window content.
|
||||||
|
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||||
|
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||||
|
} else {
|
||||||
|
// Wrapped: logical order is slots[idx..Dim) then slots[0..idx).
|
||||||
|
// Linearize so the trim + concat below operate on contiguous
|
||||||
|
// positions and preserve the last (maxSize - 1) old tokens.
|
||||||
|
tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice())
|
||||||
|
tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice())
|
||||||
|
headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||||
|
headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||||
|
c.keys.Set(tailK.Concatenate(2, headK))
|
||||||
|
c.values.Set(tailV.Concatenate(2, headV))
|
||||||
|
c.idx = c.keys.Dim(2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trim to max_size to maintain sliding window
|
// Trim to max_size to maintain sliding window
|
||||||
|
|||||||
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// singleTokenKV and multiTokenKV fabricate [B=1, H=1, L, D=2] key/value
|
||||||
|
// tensors whose channel value is the token id, so stateIDs can recover
|
||||||
|
// which ids survived in the cache.
|
||||||
|
func singleTokenKV(id float32) (*mlx.Array, *mlx.Array) {
|
||||||
|
k := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||||
|
v := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||||
|
return k, v
|
||||||
|
}
|
||||||
|
|
||||||
|
func multiTokenKV(ids []float32) (*mlx.Array, *mlx.Array) {
|
||||||
|
data := make([]float32, 0, 2*len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
data = append(data, id, id)
|
||||||
|
}
|
||||||
|
k := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||||
|
v := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||||
|
return k, v
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateIDs returns the ids currently in the cache in slot order (logical
|
||||||
|
// after a concat, physical/rotated after a single-token update).
|
||||||
|
func stateIDs(t *testing.T, c *RotatingKVCache) []float32 {
|
||||||
|
t.Helper()
|
||||||
|
state := c.State()
|
||||||
|
if state == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mlx.Eval(state[0])
|
||||||
|
flat := state[0].Floats()
|
||||||
|
n := state[0].Dim(2)
|
||||||
|
out := make([]float32, n)
|
||||||
|
for i := range n {
|
||||||
|
out[i] = flat[i*2]
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func equalSlice(a, b []float32) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func feedMulti(c *RotatingKVCache, startID float32, n int) float32 {
|
||||||
|
ids := make([]float32, n)
|
||||||
|
for i := range ids {
|
||||||
|
ids[i] = startID + float32(i)
|
||||||
|
}
|
||||||
|
k, v := multiTokenKV(ids)
|
||||||
|
c.Update(k, v)
|
||||||
|
return startID + float32(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func feedSingle(c *RotatingKVCache, id float32) {
|
||||||
|
k, v := singleTokenKV(id)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatMidRotationPreservesContext: after the buffer
|
||||||
|
// has wrapped, a multi-token concat must keep the (maxSize-1) most recent
|
||||||
|
// pre-existing tokens in logical order so the first Q of the new batch
|
||||||
|
// has a full sliding window.
|
||||||
|
func TestRotatingKVCacheConcatMidRotationPreservesContext(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
nextID := feedMulti(c, 1, 3)
|
||||||
|
for range 6 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
}
|
||||||
|
if c.Offset() != 9 {
|
||||||
|
t.Fatalf("setup: offset=%d want 9", c.Offset())
|
||||||
|
}
|
||||||
|
if c.idx >= c.maxSize {
|
||||||
|
t.Fatalf("setup: expected mid-rotation idx (<%d), got %d", c.maxSize, c.idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 10, 2)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{7, 8, 9, 10, 11}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("post-concat window=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
if c.Offset() != 11 {
|
||||||
|
t.Fatalf("offset=%d want 11", c.Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatAlignedInvariant: with an aligned buffer
|
||||||
|
// (c.idx == Dim), an L>1 concat keeps the last (maxSize-1) pre-existing
|
||||||
|
// tokens plus the full new batch. This is the chunked-prefill contract
|
||||||
|
// x/mlxrunner/pipeline.go relies on.
|
||||||
|
func TestRotatingKVCacheConcatAlignedInvariant(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
// Chunk 1 fills past maxSize, leaving Dim == maxSize aligned.
|
||||||
|
feedMulti(c, 1, 6)
|
||||||
|
// Chunk 2: the buffer is intentionally oversized to (maxSize-1) + L
|
||||||
|
// so the first new Q has its full window in scope for this forward.
|
||||||
|
feedMulti(c, 7, 3)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{4, 5, 6, 7, 8, 9}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("post-chunk-2 buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The next decode trims oversize back to maxSize; order may be
|
||||||
|
// physical (rotated), so check as a set.
|
||||||
|
feedSingle(c, 10)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
if len(got) != window {
|
||||||
|
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||||
|
}
|
||||||
|
seen := map[float32]bool{}
|
||||||
|
for _, v := range got {
|
||||||
|
seen[v] = true
|
||||||
|
}
|
||||||
|
for _, w := range []float32{7, 8, 9, 10} {
|
||||||
|
if !seen[w] {
|
||||||
|
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatAfterDecodeGrowsBuffer: update() grows the
|
||||||
|
// underlying buffer by `step` slots via mlx.Zeros before writing, so
|
||||||
|
// after one decode on a short prefill c.idx < Dim even though the cache
|
||||||
|
// has not wrapped. Those trailing slots are zero padding and must not
|
||||||
|
// be pulled back into the live window on the next concat.
|
||||||
|
func TestRotatingKVCacheConcatAfterDecodeGrowsBuffer(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 512
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
feedMulti(c, 1, 3)
|
||||||
|
feedSingle(c, 4)
|
||||||
|
feedMulti(c, 5, 3)
|
||||||
|
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{1, 2, 3, 4, 5, 6, 7}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("growing-buffer concat=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatAfterLiveRewind: x/mlxrunner/cache.go calls
|
||||||
|
// Restore(nil, target) between conversation turns to rewind the cache to
|
||||||
|
// the matched prefix. Restore moves c.offset/c.idx without trimming the
|
||||||
|
// underlying buffer, so slots [c.idx, Dim) still hold stale pre-rewind
|
||||||
|
// tokens. A subsequent concat must drop those, not treat them as wrapped
|
||||||
|
// window content.
|
||||||
|
func TestRotatingKVCacheConcatAfterLiveRewind(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 8
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
// Grow the buffer to exactly maxSize without wrapping.
|
||||||
|
feedMulti(c, 1, 2)
|
||||||
|
for id := float32(3); id <= 8; id++ {
|
||||||
|
feedSingle(c, id)
|
||||||
|
}
|
||||||
|
if c.Offset() != window {
|
||||||
|
t.Fatalf("setup: offset=%d want %d", c.Offset(), window)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.Restore(nil, 2) {
|
||||||
|
t.Fatalf("live rewind to 2 failed")
|
||||||
|
}
|
||||||
|
if c.Offset() != 2 {
|
||||||
|
t.Fatalf("post-rewind offset=%d want 2", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 9, 3)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{1, 2, 9, 10, 11}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("post-rewind concat=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
if c.Offset() != 5 {
|
||||||
|
t.Fatalf("offset=%d want 5", c.Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatGrowingBuffer: when oldLen < maxSize the trim
|
||||||
|
// formula drops to non-positive and all pre-existing tokens are kept.
|
||||||
|
func TestRotatingKVCacheConcatGrowingBuffer(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
feedMulti(c, 1, 2)
|
||||||
|
feedMulti(c, 3, 2)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{1, 2, 3, 4}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("growing buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheRunnerChunkedPrefill mirrors the
|
||||||
|
// x/mlxrunner/pipeline.go prefill loop: a long prompt fed through
|
||||||
|
// repeated L>1 Update() calls on a single cache. Scaled-down proxy for
|
||||||
|
// the Gemma 4 26B case (sliding_window=1024, prefillChunkSize=2048).
|
||||||
|
func TestRotatingKVCacheRunnerChunkedPrefill(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
feedMulti(c, 1, 8)
|
||||||
|
if c.Offset() != 8 {
|
||||||
|
t.Fatalf("chunk 1: offset=%d want 8", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 9, 8)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("chunk 2: buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 17, 4)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
want = []float32{14, 15, 16, 17, 18, 19, 20}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("chunk 3: buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode trims oversize back to maxSize; order may be physical.
|
||||||
|
feedSingle(c, 21)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
if len(got) != window {
|
||||||
|
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||||
|
}
|
||||||
|
seen := map[float32]bool{}
|
||||||
|
for _, v := range got {
|
||||||
|
seen[v] = true
|
||||||
|
}
|
||||||
|
for _, w := range []float32{18, 19, 20, 21} {
|
||||||
|
if !seen[w] {
|
||||||
|
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheMultiTurnChatSimulation walks a prefill → decode →
|
||||||
|
// prefill sequence and checks that each new prefill retains the last
|
||||||
|
// (maxSize-1) pre-existing tokens in logical order.
|
||||||
|
func TestRotatingKVCacheMultiTurnChatSimulation(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
nextID := feedMulti(c, 1, 2)
|
||||||
|
for range 5 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
}
|
||||||
|
if c.Offset() != 7 {
|
||||||
|
t.Fatalf("turn 1: offset=%d want 7", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, nextID, 3)
|
||||||
|
nextID += 3
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{5, 6, 7, 8, 9, 10}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("turn 2 prefill buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range 4 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
}
|
||||||
|
if c.Offset() != 14 {
|
||||||
|
t.Fatalf("turn 2 decode: offset=%d want 14", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, nextID, 2)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
want = []float32{12, 13, 14, 15, 16}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("turn 3 prefill buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheOffsetTracking: Offset() is the monotonic logical
|
||||||
|
// token count through any mix of Update() calls — Gemma 4 uses
|
||||||
|
// donorEntry.Offset - L for the consumer's RoPE offset.
|
||||||
|
func TestRotatingKVCacheOffsetTracking(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
c := NewRotatingKVCache(4)
|
||||||
|
nextID := feedMulti(c, 1, 3)
|
||||||
|
if c.Offset() != 3 {
|
||||||
|
t.Fatalf("after prefill 3: offset=%d want 3", c.Offset())
|
||||||
|
}
|
||||||
|
for i := range 5 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
if c.Offset() != 3+i+1 {
|
||||||
|
t.Fatalf("after decode %d: offset=%d want %d", i, c.Offset(), 3+i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nextID = feedMulti(c, nextID, 2)
|
||||||
|
if c.Offset() != 10 {
|
||||||
|
t.Fatalf("after turn-2 prefill: offset=%d want 10", c.Offset())
|
||||||
|
}
|
||||||
|
// L > maxSize concat.
|
||||||
|
feedMulti(c, nextID, 7)
|
||||||
|
if c.Offset() != 17 {
|
||||||
|
t.Fatalf("after large prefill: offset=%d want 17", c.Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,62 +1,64 @@
|
|||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
// #include "generated.h"
|
|
||||||
import "C"
|
|
||||||
import "math"
|
import "math"
|
||||||
|
|
||||||
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
||||||
|
|
||||||
// GELUApprox matches mlx.nn.gelu_approx:
|
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||||
//
|
// as a fused kernel.
|
||||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
var GELUApprox = Compile1(
|
||||||
func GELUApprox(x *Array) *Array {
|
"GELUApprox",
|
||||||
// Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs.
|
func(x *Array) *Array {
|
||||||
half := scalarWithDtype(0.5, x)
|
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
|
||||||
defer C.mlx_array_free(half)
|
dt := x.DType()
|
||||||
coeff := scalarWithDtype(geluCoeff, x)
|
half := FromValue[float32](0.5).AsType(dt)
|
||||||
defer C.mlx_array_free(coeff)
|
coeff := FromValue(geluCoeff).AsType(dt)
|
||||||
c := scalarWithDtype(0.044715, x)
|
c := FromValue[float32](0.044715).AsType(dt)
|
||||||
defer C.mlx_array_free(c)
|
one := FromValue[float32](1.0).AsType(dt)
|
||||||
|
|
||||||
// x^3 via x*x*x (avoids general Power which is slower)
|
// x^3 via x*x*x (avoids general Power which is slower).
|
||||||
x3 := New("GELU_X3")
|
x3 := x.Multiply(x).Multiply(x)
|
||||||
C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx)
|
inner := x.Add(c.Multiply(x3))
|
||||||
tmp := New("GELU_X3b")
|
tanh := coeff.Multiply(inner).Tanh()
|
||||||
C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx)
|
return half.Multiply(x).Multiply(one.Add(tanh))
|
||||||
x3 = tmp
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
// 0.044715 * x^3
|
// SiLU returns a * sigmoid(a) as a fused kernel.
|
||||||
cx3 := New("GELU_CX3")
|
var SiLU = Compile1(
|
||||||
C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx)
|
"SiLU",
|
||||||
|
func(a *Array) *Array {
|
||||||
|
return a.Multiply(a.Sigmoid())
|
||||||
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
// x + 0.044715 * x^3
|
// SwiGLU returns silu(gate) * up as a fused kernel.
|
||||||
inner := New("GELU_INNER")
|
var SwiGLU = Compile2(
|
||||||
C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx)
|
"SwiGLU",
|
||||||
|
func(gate, up *Array) *Array {
|
||||||
|
return SiLU(gate).Multiply(up)
|
||||||
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
|
||||||
scaled := New("GELU_SCALED")
|
// geglu, used by Gemma-family MLP and MoE paths.
|
||||||
C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx)
|
var GeGLU = Compile2(
|
||||||
|
"GeGLU",
|
||||||
|
func(gate, up *Array) *Array {
|
||||||
|
return GELUApprox(gate).Multiply(up)
|
||||||
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
// tanh(...)
|
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
|
||||||
th := New("GELU_TANH")
|
// mlx_lm's logit_softcap. cap must have the same dtype as x.
|
||||||
C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx)
|
var LogitSoftcap = Compile2(
|
||||||
|
"LogitSoftcap",
|
||||||
// 1 + tanh(...)
|
func(x, cap *Array) *Array {
|
||||||
one := scalarWithDtype(1.0, x)
|
return x.Divide(cap).Tanh().Multiply(cap)
|
||||||
defer C.mlx_array_free(one)
|
},
|
||||||
onePlusTanh := New("GELU_1PT")
|
Shapeless(),
|
||||||
C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx)
|
)
|
||||||
|
|
||||||
// 0.5 * x
|
|
||||||
halfX := New("GELU_HALFX")
|
|
||||||
C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx)
|
|
||||||
|
|
||||||
// 0.5 * x * (1 + tanh(...))
|
|
||||||
out := New("GELU_APPROX")
|
|
||||||
C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx)
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func SILU(t *Array) *Array {
|
|
||||||
return t.Multiply(t.Sigmoid()).AsType(t.DType())
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -27,7 +27,11 @@ var arrays []*Array
|
|||||||
|
|
||||||
func New(name string) *Array {
|
func New(name string) *Array {
|
||||||
t := &Array{name: name}
|
t := &Array{name: name}
|
||||||
arrays = append(arrays, t)
|
if tracing {
|
||||||
|
traceScratch = append(traceScratch, t)
|
||||||
|
} else {
|
||||||
|
arrays = append(arrays, t)
|
||||||
|
}
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
192
x/mlxrunner/mlx/compile.go
Normal file
192
x/mlxrunner/mlx/compile.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
package mlx
|
||||||
|
|
||||||
|
// #include <stdlib.h>
|
||||||
|
// #include "generated.h"
|
||||||
|
//
|
||||||
|
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||||
|
// extern void closureDestructor(void* payload);
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"runtime/cgo"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompileFunc is the signature of a function that can be compiled.
|
||||||
|
type CompileFunc func(inputs ...*Array) []*Array
|
||||||
|
|
||||||
|
// CompileOption configures Compile behavior.
|
||||||
|
type CompileOption func(*compileConfig)
|
||||||
|
|
||||||
|
type compileConfig struct {
|
||||||
|
shapeless bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shapeless traces the function once against symbolic shapes so the compiled
|
||||||
|
// graph accepts any input shape afterwards. Without this option, MLX re-traces
|
||||||
|
// on each new (shape, dtype) combination and caches each specialization.
|
||||||
|
func Shapeless() CompileOption {
|
||||||
|
return func(c *compileConfig) { c.shapeless = true }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile returns a compiled version of fn. When called during another
|
||||||
|
// compile's trace, fn is inlined directly so outer compiles can fuse through
|
||||||
|
// inner ones.
|
||||||
|
//
|
||||||
|
// Compiled functions must not have side effects outside of the function. Do
|
||||||
|
// not access data other than the arguments passed in (either Go data or MLX
|
||||||
|
// arrays) unless it is a constant.
|
||||||
|
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
|
||||||
|
var cfg compileConfig
|
||||||
|
for _, o := range opts {
|
||||||
|
o(&cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
var closure C.mlx_closure
|
||||||
|
var once sync.Once
|
||||||
|
|
||||||
|
return func(inputs ...*Array) []*Array {
|
||||||
|
if tracing {
|
||||||
|
return fn(inputs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
once.Do(func() {
|
||||||
|
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
|
||||||
|
*payload = cgo.NewHandle(fn)
|
||||||
|
src := C.mlx_closure_new_func_payload(
|
||||||
|
(*[0]byte)(C.closureCallback),
|
||||||
|
unsafe.Pointer(payload),
|
||||||
|
(*[0]byte)(C.closureDestructor),
|
||||||
|
)
|
||||||
|
defer C.mlx_closure_free(src)
|
||||||
|
|
||||||
|
closure = C.mlx_closure_new()
|
||||||
|
mlxCheck(name+": compile failed", func() C.int {
|
||||||
|
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
inVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(inVec)
|
||||||
|
for _, in := range inputs {
|
||||||
|
C.mlx_vector_array_append_value(inVec, in.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
outVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(outVec)
|
||||||
|
mlxCheck(name+": closure apply failed", func() C.int {
|
||||||
|
return C.mlx_closure_apply(&outVec, closure, inVec)
|
||||||
|
})
|
||||||
|
|
||||||
|
n := int(C.mlx_vector_array_size(outVec))
|
||||||
|
outputs := make([]*Array, n)
|
||||||
|
for i := range n {
|
||||||
|
outputs[i] = New(name)
|
||||||
|
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile1 compiles a unary function. See Compile.
|
||||||
|
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
|
||||||
|
cf := Compile(name, func(in ...*Array) []*Array {
|
||||||
|
return []*Array{fn(in[0])}
|
||||||
|
}, opts...)
|
||||||
|
return func(a *Array) *Array {
|
||||||
|
return cf(a)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile2 compiles a binary function. See Compile.
|
||||||
|
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
|
||||||
|
cf := Compile(name, func(in ...*Array) []*Array {
|
||||||
|
return []*Array{fn(in[0], in[1])}
|
||||||
|
}, opts...)
|
||||||
|
return func(a, b *Array) *Array {
|
||||||
|
return cf(a, b)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile3 compiles a ternary function. See Compile.
|
||||||
|
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
|
||||||
|
cf := Compile(name, func(in ...*Array) []*Array {
|
||||||
|
return []*Array{fn(in[0], in[1], in[2])}
|
||||||
|
}, opts...)
|
||||||
|
return func(a, b, c *Array) *Array {
|
||||||
|
return cf(a, b, c)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tracing is true while a compile callback is running. Since MLX is
|
||||||
|
// single-threaded at this level a plain Go bool suffices.
|
||||||
|
var tracing bool
|
||||||
|
|
||||||
|
// traceScratch collects arrays created during a compile trace so they can be
|
||||||
|
// freed as a group when the callback returns.
|
||||||
|
var traceScratch []*Array
|
||||||
|
|
||||||
|
//export closureCallback
|
||||||
|
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
slog.Error("mlx closure callback panicked", "panic", r)
|
||||||
|
rc = 1
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
handle := *(*cgo.Handle)(payload)
|
||||||
|
fn := handle.Value().(CompileFunc)
|
||||||
|
|
||||||
|
// When tracing, we track all of the intermediates that are created and free them separately at the end of
|
||||||
|
// the process. This will give the effect of a single op - inputs are owned by the original caller (via
|
||||||
|
// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
|
||||||
|
if tracing {
|
||||||
|
panic("mlx: nested compile trace")
|
||||||
|
}
|
||||||
|
tracing = true
|
||||||
|
traceScratch = nil
|
||||||
|
defer func() {
|
||||||
|
for _, a := range traceScratch {
|
||||||
|
if a.pinned > 0 {
|
||||||
|
panic("mlx: traced array was pinned during compilation")
|
||||||
|
}
|
||||||
|
if a.Valid() {
|
||||||
|
C.mlx_array_free(a.ctx)
|
||||||
|
a.ctx.ctx = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing = false
|
||||||
|
traceScratch = nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
n := int(C.mlx_vector_array_size(input))
|
||||||
|
inputs := make([]*Array, n)
|
||||||
|
for i := range n {
|
||||||
|
a := New("")
|
||||||
|
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
|
||||||
|
inputs[i] = a
|
||||||
|
}
|
||||||
|
|
||||||
|
outputs := fn(inputs...)
|
||||||
|
|
||||||
|
var arrPtr *C.mlx_array
|
||||||
|
if len(outputs) > 0 {
|
||||||
|
handles := make([]C.mlx_array, len(outputs))
|
||||||
|
for i, out := range outputs {
|
||||||
|
handles[i] = out.ctx
|
||||||
|
}
|
||||||
|
arrPtr = &handles[0]
|
||||||
|
}
|
||||||
|
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
//export closureDestructor
|
||||||
|
func closureDestructor(payload unsafe.Pointer) {
|
||||||
|
handle := *(*cgo.Handle)(payload)
|
||||||
|
handle.Delete()
|
||||||
|
C.free(payload)
|
||||||
|
}
|
||||||
147
x/mlxrunner/mlx/compile_test.go
Normal file
147
x/mlxrunner/mlx/compile_test.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package mlx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCompileFusion(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
// Compile fuses the ops inside a function body into a single kernel,
|
||||||
|
// eliminating intermediate buffers. Use a diamond-shaped graph where
|
||||||
|
// two branches must be materialized simultaneously without fusion,
|
||||||
|
// then compare peak memory against the compiled version which fuses
|
||||||
|
// everything into one kernel with no intermediates.
|
||||||
|
const n = 1024 * 1024 // 4MB per float32 array
|
||||||
|
data := make([]float32, n)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = float32(i + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Diamond: both a*b and a+b must be live for the final multiply.
|
||||||
|
// Without fusion: peak includes both intermediates (~8MB extra).
|
||||||
|
// With fusion: single kernel, no intermediates.
|
||||||
|
body := func(a, b *Array) *Array {
|
||||||
|
return a.Multiply(b).Multiply(a.Add(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
a := FromValues(data, n)
|
||||||
|
b := FromValues(data, n)
|
||||||
|
Pin(a, b)
|
||||||
|
defer Unpin(a, b)
|
||||||
|
|
||||||
|
// Compiled: ops fused into a single kernel.
|
||||||
|
EnableCompile()
|
||||||
|
fn := Compile2("diamond", body, Shapeless())
|
||||||
|
warm := fn(a, b)
|
||||||
|
Eval(warm)
|
||||||
|
Sweep()
|
||||||
|
ClearCache()
|
||||||
|
ResetPeakMemory()
|
||||||
|
y := fn(a, b)
|
||||||
|
Eval(y)
|
||||||
|
compiledPeak := PeakMemory()
|
||||||
|
Sweep()
|
||||||
|
|
||||||
|
// Uncompiled: ops evaluated individually, intermediates materialized.
|
||||||
|
ClearCache()
|
||||||
|
ResetPeakMemory()
|
||||||
|
z := body(a, b)
|
||||||
|
Eval(z)
|
||||||
|
uncompiledPeak := PeakMemory()
|
||||||
|
Sweep()
|
||||||
|
|
||||||
|
if compiledPeak == 0 && uncompiledPeak == 0 {
|
||||||
|
t.Skip("peak memory tracking not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||||
|
|
||||||
|
if compiledPeak >= uncompiledPeak {
|
||||||
|
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileNested(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
// A compiled function that calls another compiled function should
|
||||||
|
// produce correct results. The inner function inlines via isTracing()
|
||||||
|
// during the outer's trace.
|
||||||
|
inner := Compile1("silu", func(a *Array) *Array {
|
||||||
|
return a.Multiply(a.Sigmoid())
|
||||||
|
}, Shapeless())
|
||||||
|
|
||||||
|
outer := Compile2("swiglu", func(gate, up *Array) *Array {
|
||||||
|
return inner(gate).Multiply(up)
|
||||||
|
}, Shapeless())
|
||||||
|
|
||||||
|
gate := FromValues([]float32{0, 1, 2}, 3)
|
||||||
|
up := FromValues([]float32{1, 1, 1}, 3)
|
||||||
|
Pin(gate, up)
|
||||||
|
defer Unpin(gate, up)
|
||||||
|
|
||||||
|
y := outer(gate, up)
|
||||||
|
Eval(y)
|
||||||
|
|
||||||
|
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
|
||||||
|
got := y.Floats()
|
||||||
|
want := []float32{0, 0.7310586, 1.7615942}
|
||||||
|
for i, v := range got {
|
||||||
|
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
|
||||||
|
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileCallbackPanicRecovers(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
boom := Compile1("boom", func(a *Array) *Array {
|
||||||
|
panic("intentional test panic")
|
||||||
|
})
|
||||||
|
|
||||||
|
x := FromValues([]float32{1}, 1)
|
||||||
|
Pin(x)
|
||||||
|
defer Unpin(x)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if r == nil {
|
||||||
|
t.Fatal("expected panic from Call, got none")
|
||||||
|
}
|
||||||
|
if _, ok := r.(string); !ok {
|
||||||
|
t.Fatalf("expected string panic, got %T: %v", r, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
boom(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileNoTrackingGrowth(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
// Repeated invocations of a compiled kernel should not grow the
|
||||||
|
// tracked-arrays list — the callback's traceScratch collects
|
||||||
|
// intermediates during tracing and frees them when the callback returns.
|
||||||
|
fn := Compile2("mul_add", func(a, b *Array) *Array {
|
||||||
|
return a.Multiply(b).Add(b)
|
||||||
|
})
|
||||||
|
|
||||||
|
a := FromValues([]float32{1, 2}, 2)
|
||||||
|
b := FromValues([]float32{3, 4}, 2)
|
||||||
|
Pin(a, b)
|
||||||
|
defer Unpin(a, b)
|
||||||
|
|
||||||
|
Sweep()
|
||||||
|
before := len(arrays)
|
||||||
|
|
||||||
|
for range 100 {
|
||||||
|
_ = fn(a, b)
|
||||||
|
Sweep()
|
||||||
|
}
|
||||||
|
|
||||||
|
after := len(arrays)
|
||||||
|
if after > before+2 {
|
||||||
|
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,8 +9,8 @@ package mlx
|
|||||||
// #include "generated.h"
|
// #include "generated.h"
|
||||||
// #include <string.h>
|
// #include <string.h>
|
||||||
//
|
//
|
||||||
// static char _mlx_last_error_msg[1024] = {0};
|
// static __thread char _mlx_last_error_msg[1024] = {0};
|
||||||
// static int _mlx_last_error_flag = 0;
|
// static __thread int _mlx_last_error_flag = 0;
|
||||||
//
|
//
|
||||||
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
||||||
// (void)data;
|
// (void)data;
|
||||||
@@ -30,15 +30,13 @@ package mlx
|
|||||||
// _mlx_last_error_msg[0] = '\0';
|
// _mlx_last_error_msg[0] = '\0';
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// static int mlx_had_last_error(void) {
|
|
||||||
// return _mlx_last_error_flag;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// static const char* mlx_get_last_error(void) {
|
// static const char* mlx_get_last_error(void) {
|
||||||
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
|
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
|
||||||
// }
|
// }
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
|
import "runtime"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Replace the default exit(-1) error handler with one that captures
|
// Replace the default exit(-1) error handler with one that captures
|
||||||
// the error message so we can surface it in Go.
|
// the error message so we can surface it in Go.
|
||||||
@@ -53,6 +51,24 @@ func Version() string {
|
|||||||
return C.GoString(C.mlx_string_data(str))
|
return C.GoString(C.mlx_string_data(str))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mlxCheck locks the goroutine to its OS thread, clears the captured error
|
||||||
|
// state, calls fn, and panics with the captured message if fn returns non-zero.
|
||||||
|
// The thread lock ensures the thread-local error state is read from the same
|
||||||
|
// thread that executed the call.
|
||||||
|
func mlxCheck(fallback string, fn func() C.int) {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
defer runtime.UnlockOSThread()
|
||||||
|
|
||||||
|
C.mlx_clear_last_error()
|
||||||
|
if fn() != 0 {
|
||||||
|
msg := C.GoString(C.mlx_get_last_error())
|
||||||
|
if msg == "" {
|
||||||
|
msg = fallback
|
||||||
|
}
|
||||||
|
panic("mlx: " + msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func doEval(outputs []*Array, async bool) {
|
func doEval(outputs []*Array, async bool) {
|
||||||
if len(outputs) == 0 {
|
if len(outputs) == 0 {
|
||||||
return
|
return
|
||||||
@@ -67,20 +83,12 @@ func doEval(outputs []*Array, async bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
C.mlx_clear_last_error()
|
mlxCheck("eval failed", func() C.int {
|
||||||
var rc C.int
|
if async {
|
||||||
if async {
|
return C.mlx_async_eval(vector)
|
||||||
rc = C.mlx_async_eval(vector)
|
|
||||||
} else {
|
|
||||||
rc = C.mlx_eval(vector)
|
|
||||||
}
|
|
||||||
if rc != 0 {
|
|
||||||
msg := "mlx eval failed"
|
|
||||||
if C.mlx_had_last_error() != 0 {
|
|
||||||
msg = C.GoString(C.mlx_get_last_error())
|
|
||||||
}
|
}
|
||||||
panic("mlx: " + msg)
|
return C.mlx_eval(vector)
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func AsyncEval(outputs ...*Array) {
|
func AsyncEval(outputs ...*Array) {
|
||||||
|
|||||||
@@ -404,11 +404,6 @@ func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *A
|
|||||||
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
|
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SiLU(a *Array) *Array {
|
|
||||||
sig := a.Sigmoid()
|
|
||||||
return a.Multiply(sig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||||
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
|
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,15 +23,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
return errors.New("model not loaded")
|
return errors.New("model not loaded")
|
||||||
}
|
}
|
||||||
|
|
||||||
enableCompile := true
|
|
||||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
|
||||||
enableCompile = modelCompile.EnableCompile()
|
|
||||||
}
|
|
||||||
if enableCompile {
|
|
||||||
mlx.EnableCompile()
|
|
||||||
} else {
|
|
||||||
mlx.DisableCompile()
|
|
||||||
}
|
|
||||||
mlx.ResetPeakMemory()
|
mlx.ResetPeakMemory()
|
||||||
ctx := request.Ctx
|
ctx := request.Ctx
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ func (r *Runner) Load(modelName string) error {
|
|||||||
r.Model = m
|
r.Model = m
|
||||||
r.Tokenizer = m.Tokenizer()
|
r.Tokenizer = m.Tokenizer()
|
||||||
r.contextLength = m.MaxContextLength()
|
r.contextLength = m.MaxContextLength()
|
||||||
|
|
||||||
|
mlx.EnableCompile()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -80,7 +80,6 @@ type TextConfig struct {
|
|||||||
PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
||||||
PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071...
|
PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071...
|
||||||
RouterScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
RouterScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
||||||
SoftcapInv float32 `json:"-"` // 1/final_logit_softcapping
|
|
||||||
|
|
||||||
// KV sharing: maps shared layer index -> donor layer index.
|
// KV sharing: maps shared layer index -> donor layer index.
|
||||||
KVShareMap map[int32]int32 `json:"-"`
|
KVShareMap map[int32]int32 `json:"-"`
|
||||||
@@ -455,9 +454,6 @@ func parseTextConfig(configData []byte) (TextConfig, error) {
|
|||||||
cfg.PLECombineScale = float32(math.Pow(2.0, -0.5))
|
cfg.PLECombineScale = float32(math.Pow(2.0, -0.5))
|
||||||
}
|
}
|
||||||
cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize)))
|
cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize)))
|
||||||
if cfg.FinalLogitSoftcapping > 0 {
|
|
||||||
cfg.SoftcapInv = 1.0 / cfg.FinalLogitSoftcapping
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute KV sharing map.
|
// Compute KV sharing map.
|
||||||
cfg.KVShareMap = make(map[int32]int32)
|
cfg.KVShareMap = make(map[int32]int32)
|
||||||
@@ -1114,9 +1110,8 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
|||||||
logits := m.LMHead.Forward(x)
|
logits := m.LMHead.Forward(x)
|
||||||
|
|
||||||
if m.FinalLogitSoftcapping > 0 {
|
if m.FinalLogitSoftcapping > 0 {
|
||||||
logits = mlx.MulScalar(logits, m.SoftcapInv)
|
cap := mlx.FromValue(m.FinalLogitSoftcapping).AsType(logits.DType())
|
||||||
logits = logits.Tanh()
|
logits = mlx.LogitSoftcap(logits, cap)
|
||||||
logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
@@ -1231,8 +1226,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
|
|||||||
// PLE injection (after MLP residual).
|
// PLE injection (after MLP residual).
|
||||||
if l.PLE != nil && pleInput != nil {
|
if l.PLE != nil && pleInput != nil {
|
||||||
residual := h
|
residual := h
|
||||||
gate := mlx.GELUApprox(l.PLE.InputGate.Forward(h))
|
gated := mlx.GeGLU(l.PLE.InputGate.Forward(h), pleInput)
|
||||||
gated := mlx.Mul(gate, pleInput)
|
|
||||||
projected := l.PLE.Projection.Forward(gated)
|
projected := l.PLE.Projection.Forward(gated)
|
||||||
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
|
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
|
||||||
h = mlx.Add(residual, projected)
|
h = mlx.Add(residual, projected)
|
||||||
@@ -1375,9 +1369,9 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
gate := mlx.GELUApprox(m.GateProj.Forward(x))
|
gate := m.GateProj.Forward(x)
|
||||||
up := m.UpProj.Forward(x)
|
up := m.UpProj.Forward(x)
|
||||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
return m.DownProj.Forward(mlx.GeGLU(gate, up))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward runs the router to select top-k experts per token.
|
// Forward runs the router to select top-k experts per token.
|
||||||
@@ -1457,13 +1451,13 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
|
|||||||
up := mlx.SliceStartStop(gateUp,
|
up := mlx.SliceStartStop(gateUp,
|
||||||
[]int32{0, 0, 0, mid},
|
[]int32{0, 0, 0, mid},
|
||||||
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
} else {
|
} else {
|
||||||
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
|
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
|
||||||
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
|
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
|
||||||
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
|
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
|
||||||
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
|
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
}
|
}
|
||||||
downMode := m.DownQuantMode
|
downMode := m.DownQuantMode
|
||||||
if downMode == "" {
|
if downMode == "" {
|
||||||
@@ -1482,11 +1476,11 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
|
|||||||
up := mlx.SliceStartStop(gateUp,
|
up := mlx.SliceStartStop(gateUp,
|
||||||
[]int32{0, 0, 0, mid},
|
[]int32{0, 0, 0, mid},
|
||||||
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
} else {
|
} else {
|
||||||
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
|
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
|
||||||
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
|
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
}
|
}
|
||||||
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
|
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -148,9 +148,7 @@ type DenseMLP struct {
|
|||||||
|
|
||||||
// Forward applies the SwiGLU MLP
|
// Forward applies the SwiGLU MLP
|
||||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||||
up := m.UpProj.Forward(x)
|
|
||||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MoEGate implements the expert gating mechanism
|
// MoEGate implements the expert gating mechanism
|
||||||
@@ -242,7 +240,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
|||||||
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
||||||
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
||||||
|
|
||||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
hidden = mlx.SwiGLU(gate, up)
|
||||||
|
|
||||||
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
||||||
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
||||||
@@ -250,7 +248,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
|||||||
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||||
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||||
|
|
||||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
hidden = mlx.SwiGLU(gate, up)
|
||||||
|
|
||||||
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||||
}
|
}
|
||||||
@@ -273,9 +271,7 @@ type SharedExperts struct {
|
|||||||
|
|
||||||
// Forward applies the shared expert MLP
|
// Forward applies the shared expert MLP
|
||||||
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
||||||
gate := mlx.SiLU(s.GateProj.Forward(x))
|
return s.DownProj.Forward(mlx.SwiGLU(s.GateProj.Forward(x), s.UpProj.Forward(x)))
|
||||||
up := s.UpProj.Forward(x)
|
|
||||||
return s.DownProj.Forward(mlx.Mul(gate, up))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MoE implements the full Mixture of Experts layer
|
// MoE implements the full Mixture of Experts layer
|
||||||
|
|||||||
@@ -314,5 +314,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -333,5 +333,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1253,7 +1253,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array {
|
func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array {
|
||||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
||||||
@@ -1283,13 +1283,13 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
|||||||
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
|
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
|
||||||
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
||||||
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
||||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
hidden = mlx.SwiGLU(gate, up)
|
||||||
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
||||||
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
||||||
} else {
|
} else {
|
||||||
gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
|
gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
|
||||||
up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
|
up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
|
||||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
hidden = mlx.SwiGLU(gate, up)
|
||||||
down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
|
down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user