Mirror of https://github.com/ollama/ollama.git
Synced 2026-04-27 03:05:43 +02:00

Compare commits: launch-cop ... v0.21.0-rc (6 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | b9cb535407 |  |
|  | 031baef094 |  |
|  | 7d271e6dc9 |  |
|  | c88dae2d6b |  |
|  | 9e3618d663 |  |
|  | e585ecd11f |  |
@@ -55,7 +55,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
 ollama
 ```

-You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw` and more.
+You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `Claude Code`, `OpenClaw`, `OpenCode`, `Codex`, `Copilot`, and more.

 ### Coding

@@ -65,7 +65,7 @@ To launch a specific integration:
 ollama launch claude
 ```

-Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
+Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Copilot CLI](https://docs.ollama.com/integrations/copilot-cli), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).

 ### AI assistant
cmd/launch/copilot.go — new file (76 lines)
@@ -0,0 +1,76 @@
package launch

import (
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"

	"github.com/ollama/ollama/envconfig"
)

// Copilot implements Runner for GitHub Copilot CLI integration.
type Copilot struct{}

func (c *Copilot) String() string { return "Copilot CLI" }

// args builds the Copilot CLI argument list, prepending --model when a model is set.
func (c *Copilot) args(model string, extra []string) []string {
	var args []string
	if model != "" {
		args = append(args, "--model", model)
	}
	args = append(args, extra...)
	return args
}

// findPath locates the copilot binary: first on PATH, then falling back to
// ~/.local/bin (copilot.exe on Windows).
func (c *Copilot) findPath() (string, error) {
	if p, err := exec.LookPath("copilot"); err == nil {
		return p, nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	name := "copilot"
	if runtime.GOOS == "windows" {
		name = "copilot.exe"
	}
	fallback := filepath.Join(home, ".local", "bin", name)
	if _, err := os.Stat(fallback); err != nil {
		return "", err
	}
	return fallback, nil
}

func (c *Copilot) Run(model string, args []string) error {
	copilotPath, err := c.findPath()
	if err != nil {
		return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
	}

	cmd := exec.Command(copilotPath, c.args(model, args)...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	cmd.Env = append(os.Environ(), c.envVars(model)...)

	return cmd.Run()
}

// envVars returns the environment variables that configure Copilot CLI
// to use Ollama as its model provider.
func (c *Copilot) envVars(model string) []string {
	env := []string{
		"COPILOT_PROVIDER_BASE_URL=" + envconfig.Host().String() + "/v1",
		"COPILOT_PROVIDER_API_KEY=",
		"COPILOT_PROVIDER_WIRE_API=responses",
	}

	if model != "" {
		env = append(env, "COPILOT_MODEL="+model)
	}

	return env
}
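For context, the PATH-then-fallback lookup in `findPath` above is a small reusable pattern: prefer whatever the user has on PATH, then check the conventional per-user install directory. A minimal self-contained sketch of the same pattern (the binary name `sometool` is a hypothetical placeholder, not part of this change):

```go
package main

import (
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
)

// lookupWithFallback prefers a binary found on PATH and otherwise checks the
// conventional per-user install directory ~/.local/bin, mirroring
// (*Copilot).findPath above.
func lookupWithFallback(name string) (string, error) {
	if p, err := exec.LookPath(name); err == nil {
		return p, nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	fallback := filepath.Join(home, ".local", "bin", name)
	if _, err := os.Stat(fallback); err != nil {
		return "", fmt.Errorf("%s not found on PATH or in ~/.local/bin: %w", name, err)
	}
	return fallback, nil
}

func main() {
	p, err := lookupWithFallback("sometool") // hypothetical binary name
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("would exec:", p)
}
```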
cmd/launch/copilot_test.go — new file (161 lines)
@@ -0,0 +1,161 @@
package launch

import (
	"os"
	"path/filepath"
	"runtime"
	"slices"
	"strings"
	"testing"
)

func TestCopilotIntegration(t *testing.T) {
	c := &Copilot{}

	t.Run("String", func(t *testing.T) {
		if got := c.String(); got != "Copilot CLI" {
			t.Errorf("String() = %q, want %q", got, "Copilot CLI")
		}
	})

	t.Run("implements Runner", func(t *testing.T) {
		var _ Runner = c
	})
}

func TestCopilotFindPath(t *testing.T) {
	c := &Copilot{}

	t.Run("finds copilot in PATH", func(t *testing.T) {
		tmpDir := t.TempDir()
		name := "copilot"
		if runtime.GOOS == "windows" {
			name = "copilot.exe"
		}
		fakeBin := filepath.Join(tmpDir, name)
		os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
		t.Setenv("PATH", tmpDir)

		got, err := c.findPath()
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got != fakeBin {
			t.Errorf("findPath() = %q, want %q", got, fakeBin)
		}
	})

	t.Run("returns error when not in PATH", func(t *testing.T) {
		t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary

		_, err := c.findPath()
		if err == nil {
			t.Fatal("expected error, got nil")
		}
	})

	t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary

		name := "copilot"
		if runtime.GOOS == "windows" {
			name = "copilot.exe"
		}
		fallback := filepath.Join(tmpDir, ".local", "bin", name)
		os.MkdirAll(filepath.Dir(fallback), 0o755)
		os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)

		got, err := c.findPath()
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got != fallback {
			t.Errorf("findPath() = %q, want %q", got, fallback)
		}
	})

	t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary

		_, err := c.findPath()
		if err == nil {
			t.Fatal("expected error, got nil")
		}
	})
}

func TestCopilotArgs(t *testing.T) {
	c := &Copilot{}

	tests := []struct {
		name  string
		model string
		args  []string
		want  []string
	}{
		{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
		{"empty model", "", nil, nil},
		{"with model and extra", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
		{"empty model with help", "", []string{"--help"}, []string{"--help"}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := c.args(tt.model, tt.args)
			if !slices.Equal(got, tt.want) {
				t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
			}
		})
	}
}

func TestCopilotEnvVars(t *testing.T) {
	c := &Copilot{}

	envMap := func(envs []string) map[string]string {
		m := make(map[string]string)
		for _, e := range envs {
			k, v, _ := strings.Cut(e, "=")
			m[k] = v
		}
		return m
	}

	t.Run("sets required provider env vars with model", func(t *testing.T) {
		got := envMap(c.envVars("llama3.2"))
		if got["COPILOT_PROVIDER_BASE_URL"] == "" {
			t.Error("COPILOT_PROVIDER_BASE_URL should be set")
		}
		if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
			t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
		}
		if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
			t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
		}
		if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
			t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
		}
		if got["COPILOT_MODEL"] != "llama3.2" {
			t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
		}
	})

	t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
		got := envMap(c.envVars(""))
		if _, ok := got["COPILOT_MODEL"]; ok {
			t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
		}
	})

	t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
		t.Setenv("OLLAMA_HOST", "http://myhost:9999")
		got := envMap(c.envVars("test"))
		if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
			t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
		}
	})
}
@@ -206,6 +206,7 @@ Supported integrations:
   claude     Claude Code
   cline      Cline
   codex      Codex
+  copilot    Copilot CLI (aliases: copilot-cli)
   droid      Droid
   hermes     Hermes Agent
   opencode   OpenCode
@@ -33,7 +33,7 @@ type IntegrationInfo struct {
 	Description string
 }

-var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "droid", "pi"}
+var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"}

 var integrationSpecs = []*IntegrationSpec{
 	{
@@ -74,6 +74,19 @@ var integrationSpecs = []*IntegrationSpec{
 			Command: []string{"npm", "install", "-g", "@openai/codex"},
 		},
 	},
+	{
+		Name:        "copilot",
+		Runner:      &Copilot{},
+		Aliases:     []string{"copilot-cli"},
+		Description: "GitHub's AI coding agent for the terminal",
+		Install: IntegrationInstallSpec{
+			CheckInstalled: func() bool {
+				_, err := (&Copilot{}).findPath()
+				return err == nil
+			},
+			URL: "https://github.com/features/copilot/cli/",
+		},
+	},
 	{
 		Name:   "droid",
 		Runner: &Droid{},
@@ -120,6 +120,7 @@
 "pages": [
   "/integrations/claude-code",
   "/integrations/codex",
+  "/integrations/copilot-cli",
   "/integrations/opencode",
   "/integrations/droid",
   "/integrations/goose",
docs/integrations/copilot-cli.mdx — new file (93 lines)
@@ -0,0 +1,93 @@
---
title: Copilot CLI
---

GitHub Copilot CLI is GitHub's AI coding agent for the terminal. It can understand your codebase, make edits, run commands, and help you build software faster.

Open models can be used with Copilot CLI through Ollama, enabling you to use models such as `qwen3.5`, `glm-5.1:cloud`, and `kimi-k2.5:cloud`.

## Install

Install [Copilot CLI](https://github.com/features/copilot/cli/):

<CodeGroup>

```shell macOS / Linux (Homebrew)
brew install copilot-cli
```

```shell npm (all platforms)
npm install -g @github/copilot
```

```shell macOS / Linux (script)
curl -fsSL https://gh.io/copilot-install | bash
```

```powershell Windows (WinGet)
winget install GitHub.Copilot
```

</CodeGroup>

## Usage with Ollama

### Quick setup

```shell
ollama launch copilot
```

### Run directly with a model

```shell
ollama launch copilot --model kimi-k2.5:cloud
```

## Recommended Models

- `kimi-k2.5:cloud`
- `glm-5:cloud`
- `minimax-m2.7:cloud`
- `qwen3.5:cloud`
- `glm-4.7-flash`
- `qwen3.5`

Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

## Non-interactive (headless) mode

Run Copilot CLI without interaction for use in Docker, CI/CD, or scripts:

```shell
ollama launch copilot --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
```

The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Copilot CLI.

## Manual setup

Copilot CLI connects to Ollama through its OpenAI-compatible API, configured via environment variables.

1. Set the environment variables:

   ```shell
   export COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1
   export COPILOT_PROVIDER_API_KEY=
   export COPILOT_PROVIDER_WIRE_API=responses
   export COPILOT_MODEL=qwen3.5
   ```

2. Run Copilot CLI:

   ```shell
   copilot
   ```

Or run with the environment variables inline:

```shell
COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1 COPILOT_PROVIDER_API_KEY= COPILOT_PROVIDER_WIRE_API=responses COPILOT_MODEL=glm-5:cloud copilot
```

**Note:** Copilot requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
@@ -10,6 +10,7 @@ Coding assistants that can read, modify, and execute code in your projects.

 - [Claude Code](/integrations/claude-code)
 - [Codex](/integrations/codex)
+- [Copilot CLI](/integrations/copilot-cli)
 - [OpenCode](/integrations/opencode)
 - [Droid](/integrations/droid)
 - [Goose](/integrations/goose)
@@ -12,7 +12,8 @@ import (
 // <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
 // <|tool_call>/<|tool_response> tags for function calling.
 type Gemma4Renderer struct {
-	useImgTags bool
+	useImgTags          bool
+	emptyBlockOnNothink bool
 }

 const (
@@ -124,6 +125,9 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
 	// Generation prompt.
 	if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
 		sb.WriteString("<|turn>model\n")
+		if r.emptyBlockOnNothink && !hasThink {
+			sb.WriteString("<|channel>thought\n<channel|>")
+		}
 	}

 	return sb.String(), nil
@@ -3,9 +3,9 @@ package renderers
 // TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
 // Gemma 4 reference template.
 //
-// Current upstream Gemma 4 chat templates differ by model size, so the checked-in
-// reference intentionally uses the shared baseline without an empty generation-time
-// thought channel until renderer selection is split by size.
+// Current upstream Gemma 4 chat templates differ by model size. The checked-in
+// reference cases below use the small (e2b/e4b-style) baseline, with large
+// (26b/31b-style) checks covered separately in this file.
 //
 // To regenerate expected values, save the E2B template to
 // gemma4_e2b_chat_template.jinja2 and run:
@@ -1474,6 +1474,47 @@ Hi<turn|>
 	}
 }

+func TestGemma4RendererVariantsMatchExpectedGenerationPrompt(t *testing.T) {
+	messages := []api.Message{{Role: "user", Content: "Hello"}}
+
+	tests := []struct {
+		name         string
+		rendererName string
+		expected     string
+	}{
+		{
+			name:         "legacy_alias",
+			rendererName: "gemma4",
+			expected:     "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
+		},
+		{
+			name:         "small",
+			rendererName: "gemma4-small",
+			expected:     "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
+		},
+		{
+			name:         "large",
+			rendererName: "gemma4-large",
+			expected:     "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := RenderWithRenderer(tt.rendererName, messages, nil, nil)
+			assert.NoError(t, err)
+			assert.Equal(t, tt.expected, got)
+		})
+	}
+}
+
+func TestGemma4LargeRendererOmitsEmptyThoughtBlockWhenThinkingEnabled(t *testing.T) {
+	got, err := RenderWithRenderer("gemma4-large", []api.Message{{Role: "user", Content: "Hello"}}, nil, thinkTrue())
+	assert.NoError(t, err)
+	assert.Equal(t, "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHello<turn|>\n<|turn>model\n", got)
+	assert.NotContains(t, got, "<|channel>thought\n<channel|>")
+}
+
 func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
 	if os.Getenv("VERIFY_JINJA2") == "" {
 		t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks")
@@ -1616,15 +1657,35 @@ func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
 		},
 	}

-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
-			got, err := renderer.Render(tt.messages, tt.tools, tt.think)
-			assert.NoError(t, err)
+	variants := []struct {
+		name        string
+		renderer    *Gemma4Renderer
+		templateRel string
+	}{
+		{
+			name:        "small",
+			renderer:    &Gemma4Renderer{useImgTags: RenderImgTags},
+			templateRel: gemma4E2BTemplate,
+		},
+		{
+			name:        "large",
+			renderer:    &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true},
+			templateRel: gemma431BTemplate,
+		},
+	}

-			jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think)
-			assert.Equal(t, jinja2Output, got,
-				"renderer output doesn't match Jinja2 template output")
+	for _, variant := range variants {
+		t.Run(variant.name, func(t *testing.T) {
+			for _, tt := range tests {
+				t.Run(tt.name, func(t *testing.T) {
+					got, err := variant.renderer.Render(tt.messages, tt.tools, tt.think)
+					assert.NoError(t, err)
+
+					jinja2Output := renderWithJinja2Template(t, variant.templateRel, tt.messages, tt.tools, tt.think)
+					assert.Equal(t, jinja2Output, got,
+						"renderer output doesn't match Jinja2 template output")
+				})
+			}
 		})
 	}
 }
@@ -81,8 +81,10 @@ func rendererForName(name string) Renderer {
 		return renderer
 	case "nemotron-3-nano":
 		return &Nemotron3NanoRenderer{}
-	case "gemma4":
+	case "gemma4", "gemma4-small":
 		return &Gemma4Renderer{useImgTags: RenderImgTags}
+	case "gemma4-large":
+		return &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true}
 	case "functiongemma":
 		return &FunctionGemmaRenderer{}
 	case "glm-4.7":
@@ -523,7 +523,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
 	arch := layer.GGML.KV().Architecture()
 	switch arch {
 	case "gemma4":
-		config.Renderer = cmp.Or(config.Renderer, "gemma4")
+		config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy)
 		config.Parser = cmp.Or(config.Parser, "gemma4")
 		if _, ok := r.Parameters["stop"]; !ok {
 			if r.Parameters == nil {
server/gemma4_test.go — new file (78 lines)
@@ -0,0 +1,78 @@
package server

import "testing"

func TestResolveGemma4Renderer(t *testing.T) {
	tests := []struct {
		name  string
		model *Model
		want  string
	}{
		{
			name:  "nil model falls back to legacy alias",
			model: nil,
			want:  gemma4RendererLegacy,
		},
		{
			name: "explicit small passes through",
			model: &Model{
				Config: testConfigWithRenderer(gemma4RendererSmall),
			},
			want: gemma4RendererSmall,
		},
		{
			name: "explicit large passes through",
			model: &Model{
				Config: testConfigWithRenderer(gemma4RendererLarge),
			},
			want: gemma4RendererLarge,
		},
		{
			name: "legacy e4b tag resolves small",
			model: &Model{
				Name:      "gemma4:e4b",
				ShortName: "gemma4:e4b",
				Config:    testConfigWithRenderer(gemma4RendererLegacy),
			},
			want: gemma4RendererSmall,
		},
		{
			name: "legacy 31b tag resolves large",
			model: &Model{
				Name:      "gemma4:31b-cloud",
				ShortName: "gemma4:31b-cloud",
				Config:    testConfigWithRenderer(gemma4RendererLegacy),
			},
			want: gemma4RendererLarge,
		},
		{
			name: "legacy model type resolves small",
			model: &Model{
				Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"),
			},
			want: gemma4RendererSmall,
		},
		{
			name: "legacy model type resolves large",
			model: &Model{
				Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
			},
			want: gemma4RendererLarge,
		},
		{
			name: "legacy unknown defaults small",
			model: &Model{
				Config: testConfigWithRenderer(gemma4RendererLegacy),
			},
			want: gemma4RendererSmall,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := resolveGemma4Renderer(tt.model); got != tt.want {
				t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want)
			}
		})
	}
}
@@ -156,7 +156,7 @@ func (m *Model) Capabilities() []model.Capability {

 	// Temporary workaround — suppress vision/audio for gemma4 MLX models
 	// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
-	if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" {
+	if m.Config.ModelFormat == "safetensors" && isGemma4Renderer(m.Config.Renderer) {
 		capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
 			return c == model.CapabilityVision || c == "audio"
 		})
@@ -118,6 +118,39 @@ func TestModelCapabilities(t *testing.T) {
 			},
 			expectedCaps: []model.Capability{model.CapabilityEmbedding},
 		},
+		{
+			name: "gemma4 small safetensors suppresses vision and audio",
+			model: Model{
+				Config: model.ConfigV2{
+					ModelFormat:  "safetensors",
+					Renderer:     gemma4RendererSmall,
+					Capabilities: []string{"vision", "audio"},
+				},
+				Template: chatTemplate,
+			},
+		},
+		{
+			name: "gemma4 large safetensors suppresses vision and audio",
+			model: Model{
+				Config: model.ConfigV2{
+					ModelFormat:  "safetensors",
+					Renderer:     gemma4RendererLarge,
+					Capabilities: []string{"vision", "audio"},
+				},
+				Template: chatTemplate,
+			},
+		},
+		{
+			name: "legacy gemma4 safetensors suppresses vision and audio",
+			model: Model{
+				Config: model.ConfigV2{
+					ModelFormat:  "safetensors",
+					Renderer:     gemma4RendererLegacy,
+					Capabilities: []string{"vision", "audio"},
+				},
+				Template: chatTemplate,
+			},
+		},
 	}

 	// compare two slices of model.Capability regardless of order
@@ -115,7 +115,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.

 func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
 	if m.Config.Renderer != "" {
-		rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
+		rendererName := resolveRendererName(m)
+		rendered, err := renderers.RenderWithRenderer(rendererName, msgs, tools, think)
 		if err != nil {
 			return "", err
 		}
@@ -13,6 +13,14 @@ import (
 	"github.com/ollama/ollama/types/model"
 )

+func testConfigWithRenderer(renderer string) model.ConfigV2 {
+	return model.ConfigV2{Renderer: renderer}
+}
+
+func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 {
+	return model.ConfigV2{Renderer: renderer, ModelType: modelType}
+}
+
 func TestChatPrompt(t *testing.T) {
 	type expect struct {
 		prompt string
@@ -397,3 +405,43 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
 		t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
 	}
 }
+
+func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
+	msgs := []api.Message{{Role: "user", Content: "Hello"}}
+
+	tests := []struct {
+		name  string
+		model Model
+		want  string
+	}{
+		{
+			name: "small from name",
+			model: Model{
+				Name:      "gemma4:e4b",
+				ShortName: "gemma4:e4b",
+				Config:    testConfigWithRenderer(gemma4RendererLegacy),
+			},
+			want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
+		},
+		{
+			name: "large from model type",
+			model: Model{
+				Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
+			},
+			want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := renderPrompt(&tt.model, msgs, nil, nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if diff := cmp.Diff(got, tt.want); diff != "" {
+				t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}
server/renderer_resolution.go — new file (110 lines)
@@ -0,0 +1,110 @@
package server

import (
	"strconv"
	"strings"

	"github.com/ollama/ollama/format"
)

const (
	gemma4RendererLegacy = "gemma4"
	gemma4RendererSmall  = "gemma4-small"
	gemma4RendererLarge  = "gemma4-large"

	// Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the
	// large template. Default to the small prompt unless the model is clearly in
	// the large range.
	gemma4LargeMinParameterCount = 16_000_000_000
)

// resolveRendererName maps a model's configured renderer to the concrete
// renderer name to use, expanding the legacy "gemma4" alias by model size.
func resolveRendererName(m *Model) string {
	if m == nil || m.Config.Renderer == "" {
		return ""
	}

	switch m.Config.Renderer {
	case gemma4RendererLegacy:
		return resolveGemma4Renderer(m)
	default:
		return m.Config.Renderer
	}
}

func resolveGemma4Renderer(m *Model) string {
	if m == nil || m.Config.Renderer != gemma4RendererLegacy {
		if m == nil {
			return gemma4RendererLegacy
		}
		return m.Config.Renderer
	}

	if renderer, ok := gemma4RendererFromName(m.ShortName); ok {
		return renderer
	}

	if renderer, ok := gemma4RendererFromName(m.Name); ok {
		return renderer
	}

	if parameterCount, ok := parseHumanParameterCount(m.Config.ModelType); ok {
		return gemma4RendererForParameterCount(parameterCount)
	}

	return gemma4RendererSmall
}

func gemma4RendererForParameterCount(parameterCount uint64) string {
	if parameterCount >= gemma4LargeMinParameterCount {
		return gemma4RendererLarge
	}

	return gemma4RendererSmall
}

// gemma4RendererFromName infers the variant from size markers in the model name.
func gemma4RendererFromName(name string) (string, bool) {
	lower := strings.ToLower(name)
	switch {
	case strings.Contains(lower, "e2b"), strings.Contains(lower, "e4b"):
		return gemma4RendererSmall, true
	case strings.Contains(lower, "26b"), strings.Contains(lower, "31b"):
		return gemma4RendererLarge, true
	default:
		return "", false
	}
}

// parseHumanParameterCount parses strings like "4.3B" or "800M" into a raw
// parameter count.
func parseHumanParameterCount(s string) (uint64, bool) {
	if s == "" {
		return 0, false
	}

	unit := strings.ToUpper(s[len(s)-1:])
	var multiplier float64
	switch unit {
	case "B":
		multiplier = float64(format.Billion)
	case "M":
		multiplier = float64(format.Million)
	case "K":
		multiplier = float64(format.Thousand)
	default:
		return 0, false
	}

	value, err := strconv.ParseFloat(s[:len(s)-1], 64)
	if err != nil {
		return 0, false
	}

	return uint64(value * multiplier), true
}

func isGemma4Renderer(renderer string) bool {
	switch renderer {
	case gemma4RendererLegacy, gemma4RendererSmall, gemma4RendererLarge:
		return true
	default:
		return false
	}
}
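The size resolution above hinges on parsing human-readable parameter counts like "4.3B" and comparing them against the 16B threshold. A minimal standalone sketch of the same logic, reimplemented here for illustration without the repo's format package:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseParams converts strings like "4.3B", "25.2B", or "800M" into a raw
// parameter count, mirroring parseHumanParameterCount above.
func parseParams(s string) (uint64, bool) {
	if s == "" {
		return 0, false
	}
	mult := map[string]float64{"B": 1e9, "M": 1e6, "K": 1e3}[strings.ToUpper(s[len(s)-1:])]
	if mult == 0 { // unknown or missing unit suffix
		return 0, false
	}
	v, err := strconv.ParseFloat(s[:len(s)-1], 64)
	if err != nil {
		return 0, false
	}
	return uint64(v * mult), true
}

func main() {
	const largeMin = 16_000_000_000 // gemma4LargeMinParameterCount
	for _, s := range []string{"4.3B", "25.2B", "800M", "oops"} {
		if n, ok := parseParams(s); ok {
			variant := "gemma4-small"
			if n >= largeMin {
				variant = "gemma4-large"
			}
			fmt.Printf("%s -> %d params -> %s\n", s, n, variant)
		} else {
			fmt.Printf("%s -> unparseable, defaults to gemma4-small\n", s)
		}
	}
}
```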
@@ -928,6 +928,59 @@ func TestCreateDetectTemplate(t *testing.T) {
 	})
 }

+func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	_, digest := createBinFile(t, ggml.KV{
+		"general.architecture":    "gemma4",
+		"general.parameter_count": uint64(25_200_000_000),
+	}, nil)
+
+	w := createRequest(t, s.CreateHandler, api.CreateRequest{
+		Name:   "test",
+		Files:  map[string]string{"test.gguf": digest},
+		Stream: &stream,
+	})
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	mf, err := manifest.ParseNamedManifest(model.ParseName("test"))
+	if err != nil {
+		t.Fatalf("parse manifest: %v", err)
+	}
+	if mf.Config.Digest == "" {
+		t.Fatalf("unexpected empty config digest for manifest")
+	}
+
+	configPath, err := manifest.BlobsPath(mf.Config.Digest)
+	if err != nil {
+		t.Fatalf("config blob path: %v", err)
+	}
+
+	cfgFile, err := os.Open(configPath)
+	if err != nil {
+		t.Fatalf("open config blob: %v", err)
+	}
+	defer cfgFile.Close()
+
+	var cfg model.ConfigV2
+	if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
+		t.Fatalf("decode config: %v", err)
+	}
+
+	if cfg.Renderer != gemma4RendererLegacy {
+		t.Fatalf("expected renderer %q, got %q", gemma4RendererLegacy, cfg.Renderer)
+	}
+	if cfg.Parser != "gemma4" {
+		t.Fatalf("expected parser %q, got %q", "gemma4", cfg.Parser)
+	}
+}
+
 func TestDetectModelTypeFromFiles(t *testing.T) {
 	t.Run("gguf file", func(t *testing.T) {
 		_, digest := createBinFile(t, nil, nil)
@@ -115,36 +115,7 @@ func (s *Server) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
 	// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
 	cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port))
 	cmd.Env = os.Environ()

-	// On Linux, set LD_LIBRARY_PATH to include MLX library directories
-	if runtime.GOOS == "linux" {
-		// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
-		libraryPaths := []string{ml.LibOllamaPath}
-		if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
-			libraryPaths = append(libraryPaths, mlxDirs...)
-		}
-
-		// Append existing LD_LIBRARY_PATH if set
-		if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
-			libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
-		}
-
-		pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
-
-		// Update or add LD_LIBRARY_PATH in cmd.Env
-		found := false
-		for i := range cmd.Env {
-			if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
-				cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
-				found = true
-				break
-			}
-		}
-		if !found {
-			cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
-		}
-		slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
-	}
+	configureMLXSubprocessEnv(cmd, ml.LibraryPaths(gpus))

 	s.cmd = cmd
@@ -200,6 +171,53 @@ func (s *Server) Ping(ctx context.Context) error {
 	return nil
 }

+// mlxLibraryPathEnv returns the platform's dynamic-library search path variable.
+func mlxLibraryPathEnv() string {
+	switch runtime.GOOS {
+	case "windows":
+		return "PATH"
+	case "darwin":
+		return "DYLD_LIBRARY_PATH"
+	default:
+		return "LD_LIBRARY_PATH"
+	}
+}
+
+func configureMLXSubprocessEnv(cmd *exec.Cmd, libraryPaths []string) {
+	if len(libraryPaths) == 0 {
+		return
+	}
+
+	// Search order for the imagegen runner is:
+	//  1. bundled lib/ollama root
+	//  2. backend-specific library dirs selected during GPU discovery
+	//  3. any existing caller-provided library path values
+	pathEnv := mlxLibraryPathEnv()
+	pathEnvPaths := append([]string{}, libraryPaths...)
+	if existingPath, ok := os.LookupEnv(pathEnv); ok {
+		pathEnvPaths = append(pathEnvPaths, filepath.SplitList(existingPath)...)
+	}
+	setSubprocessEnv(cmd, pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
+	slog.Debug("mlx subprocess library path", pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
+
+	ollamaLibraryPaths := append([]string{}, libraryPaths...)
+	if existingPath, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
+		ollamaLibraryPaths = append(ollamaLibraryPaths, filepath.SplitList(existingPath)...)
+	}
+	setSubprocessEnv(cmd, "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
+	slog.Debug("mlx subprocess library path", "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
+}
+
+// setSubprocessEnv sets key=value in cmd.Env, replacing any existing entry
+// (matched case-insensitively) or appending if none exists.
+func setSubprocessEnv(cmd *exec.Cmd, key, value string) {
+	for i := range cmd.Env {
+		name, _, ok := strings.Cut(cmd.Env[i], "=")
+		if ok && strings.EqualFold(name, key) {
+			cmd.Env[i] = key + "=" + value
+			return
+		}
+	}
+	cmd.Env = append(cmd.Env, key+"="+value)
+}
+
 // getLastErr returns the last stderr line.
 func (s *Server) getLastErr() string {
 	s.lastErrLock.Lock()
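setSubprocessEnv above updates an existing entry case-insensitively rather than blindly appending a duplicate, which matters for PATH on Windows where environment variable names are not case-sensitive. A quick standalone illustration of why the case-insensitive match is needed:

```go
package main

import (
	"fmt"
	"strings"
)

// setEnv replaces key=value in env case-insensitively, appending if absent,
// following the same approach as setSubprocessEnv above.
func setEnv(env []string, key, value string) []string {
	for i := range env {
		name, _, ok := strings.Cut(env[i], "=")
		if ok && strings.EqualFold(name, key) {
			env[i] = key + "=" + value
			return env
		}
	}
	return append(env, key+"="+value)
}

func main() {
	// On Windows the inherited entry may be spelled "Path"; a case-sensitive
	// match would append a second PATH entry instead of replacing it.
	env := []string{"Path=C:\\Windows", "HOME=/home/u"}
	env = setEnv(env, "PATH", `C:\ollama\lib`)
	fmt.Println(env) // [PATH=C:\ollama\lib HOME=/home/u]
}
```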
@@ -1061,14 +1061,12 @@ func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
 			}
 		}

-		h = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
+		var donorKV *sharedKVEntry
+		h, donorKV = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)

 		// If this layer is a donor, store its cached KV for later shared layers.
-		if layer.IsDonor && c != nil {
-			state := c.State()
-			if len(state) >= 2 && state[0] != nil && state[1] != nil {
-				sharedKV[layer.LayerIdx] = sharedKVEntry{K: state[0], V: state[1], Offset: c.Offset()}
-			}
+		if layer.IsDonor && donorKV != nil {
+			sharedKV[layer.LayerIdx] = *donorKV
 		}
 	}
@@ -1190,9 +1188,9 @@ func sliceLayerDim(combined *mlx.Array, layerIdx, B, L, pleDim int32) *mlx.Array
 	return mlx.Squeeze(sliced, 2)
 }

-func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
+func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
 	normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
-	attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
+	attnOut, donorKV := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
 	attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
 	h := mlx.Add(x, attnOut)
@@ -1237,10 +1235,10 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
 		h = mlx.Mul(h, l.LayerScalar)
 	}

-	return h
+	return h, donorKV
 }

-func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
+func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
 	// Determine head dim and scale based on layer type.
 	headDim := cfg.HeadDim
 	scale := cfg.SlidingScale
@@ -1274,6 +1272,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
 	q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)

 	var k, v *mlx.Array
+	var donorKV *sharedKVEntry

 	if donorEntry != nil {
 		// Shared layer: use donor's cached K/V.
@@ -1312,6 +1311,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
 		// Update cache.
 		if c != nil {
 			k, v = c.Update(k, v)
+			donorKV = &sharedKVEntry{K: k, V: v, Offset: c.Offset()}
 		}
 	}
@@ -1365,7 +1365,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
 		// strided views differently. Metal handles them natively.
 		out = mlx.Contiguous(out, false)
 	}
-	return a.OProj.Forward(out)
+	return a.OProj.Forward(out), donorKV
 }

 func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
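The thread running through these hunks: instead of writing K/V into the cache and then reading it back out via `c.State()` in the model loop, `Attention.Forward` now returns the post-update K/V directly, and the value is threaded up through `DecoderLayer.Forward` to the caller. A minimal sketch of that return-value-threading pattern, using hypothetical stand-in types rather than the real mlx ones:

```go
package main

import "fmt"

// kvEntry stands in for sharedKVEntry; the real type holds mlx arrays plus an offset.
type kvEntry struct{ K, V string }

// attend stands in for Attention.Forward after the refactor: alongside its
// output it returns the exact K/V it just wrote to the cache, so the caller
// no longer has to read the cache's state back out.
func attend(x string) (string, *kvEntry) {
	return x + ".attn", &kvEntry{K: x + ".K", V: x + ".V"}
}

func main() {
	sharedKV := map[int]kvEntry{}
	for i, h := range []string{"h0", "h1"} {
		out, donorKV := attend(h)
		if donorKV != nil { // donor layers publish their K/V for later shared layers
			sharedKV[i] = *donorKV
		}
		fmt.Println(out, sharedKV[i])
	}
}
```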