mirror of
https://github.com/ollama/ollama.git
synced 2026-04-21 00:05:40 +02:00
Compare commits
20 Commits
pdevine/sa
...
v0.17.7-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b0c7cc7b9 | ||
|
|
6928630601 | ||
|
|
9896e3627f | ||
|
|
15732f0ea7 | ||
|
|
562c76d7cc | ||
|
|
122c68c151 | ||
|
|
82848a7806 | ||
|
|
39982a954e | ||
|
|
e9f6ea232f | ||
|
|
110eff01a9 | ||
|
|
799e51d419 | ||
|
|
e8fcb29586 | ||
|
|
97d2f05a6d | ||
|
|
8207e55ec7 | ||
|
|
ad16bffc7d | ||
|
|
c1e3ef4bcc | ||
|
|
a3093cd5e5 | ||
|
|
23d4cad1a2 | ||
|
|
86513cb697 | ||
|
|
3490e9590b |
16
api/types.go
16
api/types.go
@@ -15,7 +15,6 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
|
||||||
"github.com/ollama/ollama/internal/orderedmap"
|
"github.com/ollama/ollama/internal/orderedmap"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
@@ -570,7 +569,6 @@ type DebugInfo struct {
|
|||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||||
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
|
||||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||||
@@ -936,10 +934,6 @@ func (m *Metrics) Summary() {
|
|||||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.PeakMemory > 0 {
|
|
||||||
fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.LoadDuration > 0 {
|
if m.LoadDuration > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
||||||
}
|
}
|
||||||
@@ -963,14 +957,6 @@ func (m *Metrics) Summary() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func formatPeakMemory(b uint64) string {
|
|
||||||
if b >= format.GibiByte {
|
|
||||||
return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
|
|
||||||
}
|
|
||||||
|
|
||||||
return format.HumanBytes2(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (opts *Options) FromMap(m map[string]any) error {
|
func (opts *Options) FromMap(m map[string]any) error {
|
||||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||||
@@ -1077,7 +1063,7 @@ func DefaultOptions() Options {
|
|||||||
TopP: 0.9,
|
TopP: 0.9,
|
||||||
TypicalP: 1.0,
|
TypicalP: 1.0,
|
||||||
RepeatLastN: 64,
|
RepeatLastN: 64,
|
||||||
RepeatPenalty: 1.1,
|
RepeatPenalty: 1.0,
|
||||||
PresencePenalty: 0.0,
|
PresencePenalty: 0.0,
|
||||||
FrequencyPenalty: 0.0,
|
FrequencyPenalty: 0.0,
|
||||||
Seed: -1,
|
Seed: -1,
|
||||||
|
|||||||
23
cmd/cmd.go
23
cmd/cmd.go
@@ -145,6 +145,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
// Check for --experimental flag for safetensors model creation
|
// Check for --experimental flag for safetensors model creation
|
||||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||||
if experimental {
|
if experimental {
|
||||||
|
host := envconfig.Host()
|
||||||
|
h, _, _ := net.SplitHostPort(host.Host)
|
||||||
|
ip := net.ParseIP(h)
|
||||||
|
if ip == nil || (!ip.IsLoopback() && !ip.IsUnspecified()) {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
filename, err := getModelfileName(cmd)
|
filename, err := getModelfileName(cmd)
|
||||||
@@ -214,6 +220,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
if filename == "" {
|
if filename == "" {
|
||||||
// No Modelfile found - check if current directory is an image gen model
|
// No Modelfile found - check if current directory is an image gen model
|
||||||
if create.IsTensorModelDir(".") {
|
if create.IsTensorModelDir(".") {
|
||||||
|
host := envconfig.Host()
|
||||||
|
h, _, _ := net.SplitHostPort(host.Host)
|
||||||
|
ip := net.ParseIP(h)
|
||||||
|
if ip == nil || (!ip.IsLoopback() && !ip.IsUnspecified()) {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
quantize, _ := cmd.Flags().GetString("quantize")
|
quantize, _ := cmd.Flags().GetString("quantize")
|
||||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
@@ -585,17 +597,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
opts.WordWrap = !nowrap
|
opts.WordWrap = !nowrap
|
||||||
|
|
||||||
useImagegen := false
|
|
||||||
if cmd.Flags().Lookup("imagegen") != nil {
|
|
||||||
useImagegen, err = cmd.Flags().GetBool("imagegen")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if useImagegen {
|
|
||||||
opts.Options["use_imagegen_runner"] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill out the rest of the options based on information about the
|
// Fill out the rest of the options based on information about the
|
||||||
// model.
|
// model.
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
|
|||||||
@@ -1277,7 +1277,8 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
|||||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
||||||
// :cloud suffix stripping must also work since that's how users specify them.
|
// Cloud suffix normalization must also work since integrations may see either
|
||||||
|
// :cloud or -cloud model names.
|
||||||
for name, expected := range cloudModelLimits {
|
for name, expected := range cloudModelLimits {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
l, ok := lookupCloudModelLimit(name)
|
l, ok := lookupCloudModelLimit(name)
|
||||||
@@ -1296,6 +1297,15 @@ func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
|||||||
if l2.Output != expected.Output {
|
if l2.Output != expected.Output {
|
||||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
||||||
}
|
}
|
||||||
|
// Also verify -cloud suffix lookup
|
||||||
|
dashCloudName := name + "-cloud"
|
||||||
|
l3, ok := lookupCloudModelLimit(dashCloudName)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("lookupCloudModelLimit(%q) returned false", dashCloudName)
|
||||||
|
}
|
||||||
|
if l3.Output != expected.Output {
|
||||||
|
t.Errorf("-cloud output = %d, want %d", l3.Output, expected.Output)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
|
|||||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||||
|
"glm-5": {Context: 202_752, Output: 131_072},
|
||||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||||
@@ -90,6 +91,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
|
|||||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||||
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||||
|
"qwen3.5": {Context: 262_144, Output: 32_768},
|
||||||
}
|
}
|
||||||
|
|
||||||
// recommendedVRAM maps local recommended models to their approximate VRAM requirement.
|
// recommendedVRAM maps local recommended models to their approximate VRAM requirement.
|
||||||
|
|||||||
@@ -502,7 +502,7 @@ func (c *Openclaw) Edit(models []string) error {
|
|||||||
ollama = make(map[string]any)
|
ollama = make(map[string]any)
|
||||||
}
|
}
|
||||||
|
|
||||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
ollama["baseUrl"] = envconfig.Host().String()
|
||||||
// needed to register provider
|
// needed to register provider
|
||||||
ollama["apiKey"] = "ollama-local"
|
ollama["apiKey"] = "ollama-local"
|
||||||
ollama["api"] = "ollama"
|
ollama["api"] = "ollama"
|
||||||
|
|||||||
@@ -589,7 +589,7 @@ const testOpenclawFixture = `{
|
|||||||
"providers": {
|
"providers": {
|
||||||
"anthropic": {"apiKey": "xxx"},
|
"anthropic": {"apiKey": "xxx"},
|
||||||
"ollama": {
|
"ollama": {
|
||||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
"baseUrl": "http://127.0.0.1:11434",
|
||||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
"models": [{"id": "old-model", "customField": "preserved"}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,17 +26,15 @@ type cloudModelLimit struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
// It normalizes common cloud suffixes before checking the shared limit map.
|
||||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||||
|
// TODO(parthsareen): migrate to using cloud check instead.
|
||||||
|
for _, suffix := range []string{"-cloud", ":cloud"} {
|
||||||
|
name = strings.TrimSuffix(name, suffix)
|
||||||
|
}
|
||||||
if l, ok := cloudModelLimits[name]; ok {
|
if l, ok := cloudModelLimits[name]; ok {
|
||||||
return l, true
|
return l, true
|
||||||
}
|
}
|
||||||
base := strings.TrimSuffix(name, ":cloud")
|
|
||||||
if base != name {
|
|
||||||
if l, ok := cloudModelLimits[base]; ok {
|
|
||||||
return l, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cloudModelLimit{}, false
|
return cloudModelLimit{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,13 +120,18 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if !ok {
|
if !ok {
|
||||||
ollama = map[string]any{
|
ollama = map[string]any{
|
||||||
"npm": "@ai-sdk/openai-compatible",
|
"npm": "@ai-sdk/openai-compatible",
|
||||||
"name": "Ollama (local)",
|
"name": "Ollama",
|
||||||
"options": map[string]any{
|
"options": map[string]any{
|
||||||
"baseURL": envconfig.Host().String() + "/v1",
|
"baseURL": envconfig.Host().String() + "/v1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Migrate legacy provider name
|
||||||
|
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||||
|
ollama["name"] = "Ollama"
|
||||||
|
}
|
||||||
|
|
||||||
models, ok := ollama["models"].(map[string]any)
|
models, ok := ollama["models"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
models = make(map[string]any)
|
models = make(map[string]any)
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package config
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -232,6 +234,44 @@ func TestOpenCodeEdit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "Ollama" {
|
||||||
|
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserve custom provider name", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "My Custom Ollama" {
|
||||||
|
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||||
cleanup()
|
cleanup()
|
||||||
os.MkdirAll(configDir, 0o755)
|
os.MkdirAll(configDir, 0o755)
|
||||||
@@ -619,6 +659,54 @@ func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{
|
||||||
|
"provider": {
|
||||||
|
"ollama": {
|
||||||
|
"models": {
|
||||||
|
"glm-5:cloud": {
|
||||||
|
"name": "glm-5:cloud",
|
||||||
|
"_launch": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
|
||||||
|
limit, ok := entry["limit"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("cloud model limit was not added on re-edit")
|
||||||
|
}
|
||||||
|
if limit["context"] != float64(202_752) {
|
||||||
|
t.Errorf("context = %v, want 202752", limit["context"])
|
||||||
|
}
|
||||||
|
if limit["output"] != float64(131_072) {
|
||||||
|
t.Errorf("output = %v, want 131072", limit["output"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLookupCloudModelLimit(t *testing.T) {
|
func TestLookupCloudModelLimit(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -628,6 +716,9 @@ func TestLookupCloudModelLimit(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"glm-4.7", true, 202_752, 131_072},
|
{"glm-4.7", true, 202_752, 131_072},
|
||||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||||
|
{"glm-5:cloud", true, 202_752, 131_072},
|
||||||
|
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||||
|
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||||
{"kimi-k2.5", true, 262_144, 262_144},
|
{"kimi-k2.5", true, 262_144, 262_144},
|
||||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||||
|
|||||||
@@ -107,7 +107,8 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
|
|
||||||
// Build new models list:
|
// Build new models list:
|
||||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||||
// 2. Keep ollama-managed models (_launch marker) that are still selected
|
// 2. Keep ollama-managed models (_launch marker) that are still selected,
|
||||||
|
// except stale cloud entries that should be rebuilt below
|
||||||
// 3. Add new ollama-managed models
|
// 3. Add new ollama-managed models
|
||||||
var newModels []any
|
var newModels []any
|
||||||
for _, m := range existingModels {
|
for _, m := range existingModels {
|
||||||
@@ -117,7 +118,13 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
if !isPiOllamaModel(modelObj) {
|
if !isPiOllamaModel(modelObj) {
|
||||||
newModels = append(newModels, m)
|
newModels = append(newModels, m)
|
||||||
} else if selectedSet[id] {
|
} else if selectedSet[id] {
|
||||||
// Ollama-managed and still selected - keep it
|
// Rebuild stale managed cloud entries so createConfig refreshes
|
||||||
|
// the whole entry instead of patching it in place.
|
||||||
|
if !hasContextWindow(modelObj) {
|
||||||
|
if _, ok := lookupCloudModelLimit(id); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
newModels = append(newModels, m)
|
newModels = append(newModels, m)
|
||||||
selectedSet[id] = false
|
selectedSet[id] = false
|
||||||
}
|
}
|
||||||
@@ -199,12 +206,28 @@ func isPiOllamaModel(cfg map[string]any) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasContextWindow(cfg map[string]any) bool {
|
||||||
|
switch v := cfg["contextWindow"].(type) {
|
||||||
|
case float64:
|
||||||
|
return v > 0
|
||||||
|
case int:
|
||||||
|
return v > 0
|
||||||
|
case int64:
|
||||||
|
return v > 0
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// createConfig builds Pi model config with capability detection
|
// createConfig builds Pi model config with capability detection
|
||||||
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||||
cfg := map[string]any{
|
cfg := map[string]any{
|
||||||
"id": modelID,
|
"id": modelID,
|
||||||
"_launch": true,
|
"_launch": true,
|
||||||
}
|
}
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
cfg["contextWindow"] = l.Context
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -223,7 +246,8 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
|
|||||||
cfg["reasoning"] = true
|
cfg["reasoning"] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract context window from ModelInfo
|
// Extract context window from ModelInfo. For known cloud models, the
|
||||||
|
// pre-filled shared limit remains unless the server provides a positive value.
|
||||||
for key, val := range resp.ModelInfo {
|
for key, val := range resp.ModelInfo {
|
||||||
if strings.HasSuffix(key, ".context_length") {
|
if strings.HasSuffix(key, ".context_length") {
|
||||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||||
|
|||||||
@@ -192,6 +192,48 @@ func TestPiEdit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
|
||||||
|
existingConfig := `{
|
||||||
|
"providers": {
|
||||||
|
"ollama": {
|
||||||
|
"baseUrl": "http://localhost:11434/v1",
|
||||||
|
"api": "openai-completions",
|
||||||
|
"apiKey": "ollama",
|
||||||
|
"models": [
|
||||||
|
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatalf("Edit() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := readConfig()
|
||||||
|
providers := cfg["providers"].(map[string]any)
|
||||||
|
ollama := providers["ollama"].(map[string]any)
|
||||||
|
modelsArray := ollama["models"].([]any)
|
||||||
|
modelEntry := modelsArray[0].(map[string]any)
|
||||||
|
|
||||||
|
if modelEntry["contextWindow"] != float64(202_752) {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
|
||||||
|
}
|
||||||
|
input, ok := modelEntry["input"].([]any)
|
||||||
|
if !ok || len(input) != 1 || input[0] != "text" {
|
||||||
|
t.Errorf("input = %v, want [text]", modelEntry["input"])
|
||||||
|
}
|
||||||
|
if _, ok := modelEntry["legacyField"]; ok {
|
||||||
|
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||||
cleanup()
|
cleanup()
|
||||||
os.MkdirAll(configDir, 0o755)
|
os.MkdirAll(configDir, 0o755)
|
||||||
@@ -798,6 +840,60 @@ func TestCreateConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to cloud context when show fails", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 262_144 {
|
||||||
|
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to cloud context when model info is empty", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "glm-5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 202_752 {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to cloud context for dash cloud suffix", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 131_072 {
|
||||||
|
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("skips zero context length", func(t *testing.T) {
|
t.Run("skips zero context length", func(t *testing.T) {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/api/show" {
|
if r.URL.Path == "/api/show" {
|
||||||
|
|||||||
@@ -152,7 +152,9 @@ PARAMETER <parameter> <parametervalue>
|
|||||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.0) | float | repeat_penalty 1.0 |
|
||||||
|
| presence_penalty | Penalizes tokens that have already appeared in the generated text to reduce repetition. (Default: 0.0) | float | presence_penalty 1.5 |
|
||||||
|
| frequency_penalty | Penalizes tokens based on how often they have appeared in the generated text. (Default: 0.0) | float | frequency_penalty 1.0 |
|
||||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||||
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
||||||
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
||||||
|
|||||||
@@ -1518,7 +1518,6 @@ type CompletionResponse struct {
|
|||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||||
EvalCount int `json:"eval_count"`
|
EvalCount int `json:"eval_count"`
|
||||||
EvalDuration time.Duration `json:"eval_duration"`
|
EvalDuration time.Duration `json:"eval_duration"`
|
||||||
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
|
||||||
|
|
||||||
// Logprobs contains log probability information if requested
|
// Logprobs contains log probability information if requested
|
||||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||||
|
|||||||
@@ -454,6 +454,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
||||||
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||||
|
|
||||||
|
// Collect chunk outputs and concatenate at the end.
|
||||||
|
// Avoids SET on buffer-less intermediates under partial offload.
|
||||||
|
chunks := make([]ml.Tensor, nChunks)
|
||||||
|
|
||||||
for chunk := range nChunks {
|
for chunk := range nChunks {
|
||||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
@@ -475,14 +479,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
||||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||||
|
|
||||||
v = v.SetInplace(
|
chunks[chunk] = coreAttnOutChunk
|
||||||
ctx,
|
|
||||||
coreAttnOutChunk,
|
|
||||||
v.Stride(1),
|
|
||||||
v.Stride(2),
|
|
||||||
v.Stride(3),
|
|
||||||
chunk*v.Stride(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Update state for next chunk
|
// Update state for next chunk
|
||||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
@@ -495,6 +492,20 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
stateT = stateT.Add(ctx, kgdMulVNew)
|
stateT = stateT.Add(ctx, kgdMulVNew)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use a balanced concat tree so concat work does not balloon on long prompts.
|
||||||
|
for len(chunks) > 1 {
|
||||||
|
merged := make([]ml.Tensor, 0, (len(chunks)+1)/2)
|
||||||
|
for i := 0; i < len(chunks); i += 2 {
|
||||||
|
if i+1 < len(chunks) {
|
||||||
|
merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2))
|
||||||
|
} else {
|
||||||
|
merged = append(merged, chunks[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chunks = merged
|
||||||
|
}
|
||||||
|
v = chunks[0]
|
||||||
|
|
||||||
// Final reshape
|
// Final reshape
|
||||||
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func ParserForName(name string) Parser {
|
|||||||
case "qwen3-thinking":
|
case "qwen3-thinking":
|
||||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||||
case "qwen3.5":
|
case "qwen3.5":
|
||||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
p = &Qwen35Parser{}
|
||||||
case "qwen3-coder":
|
case "qwen3-coder":
|
||||||
p = &Qwen3CoderParser{}
|
p = &Qwen3CoderParser{}
|
||||||
case "qwen3-vl-instruct":
|
case "qwen3-vl-instruct":
|
||||||
|
|||||||
238
model/parsers/qwen35.go
Normal file
238
model/parsers/qwen35.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type qwen35ParserState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
qwen35ParserStateCollectingThinking qwen35ParserState = iota
|
||||||
|
qwen35ParserStateThinkingDoneEatingWhitespace
|
||||||
|
qwen35ParserStateCollectingContent
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
qwen35ThinkingOpenTag = "<think>"
|
||||||
|
qwen35ThinkingCloseTag = "</think>"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Qwen35Parser handles qwen3.5 reasoning extraction and delegates post-thinking
|
||||||
|
// content (including XML tool calls) to Qwen3CoderParser.
|
||||||
|
type Qwen35Parser struct {
|
||||||
|
toolParser Qwen3CoderParser
|
||||||
|
|
||||||
|
state qwen35ParserState
|
||||||
|
buffer strings.Builder
|
||||||
|
// Some checkpoints may emit an explicit leading <think> even when the
|
||||||
|
// prompt already opened thinking. Strip at most one such tag.
|
||||||
|
allowLeadingThinkOpenTag bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) HasToolSupport() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) HasThinkingSupport() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.toolParser = Qwen3CoderParser{}
|
||||||
|
p.toolParser.Init(tools, nil, nil)
|
||||||
|
|
||||||
|
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||||
|
if thinkValue == nil {
|
||||||
|
thinkingEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
assistantPrefill := lastMessage != nil && lastMessage.Role == "assistant" && lastMessage.Content != ""
|
||||||
|
if thinkingEnabled && !assistantPrefill {
|
||||||
|
p.state = qwen35ParserStateCollectingThinking
|
||||||
|
p.allowLeadingThinkOpenTag = true
|
||||||
|
} else {
|
||||||
|
p.state = qwen35ParserStateCollectingContent
|
||||||
|
p.allowLeadingThinkOpenTag = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen35Event interface {
|
||||||
|
isQwen35Event()
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen35EventContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qwen35EventContent) isQwen35Event() {}
|
||||||
|
|
||||||
|
type qwen35EventThinkingContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qwen35EventThinkingContent) isQwen35Event() {}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
|
p.buffer.WriteString(s)
|
||||||
|
events := p.parseEvents()
|
||||||
|
|
||||||
|
var contentSb strings.Builder
|
||||||
|
var thinkingSb strings.Builder
|
||||||
|
for _, event := range events {
|
||||||
|
switch event := event.(type) {
|
||||||
|
case qwen35EventContent:
|
||||||
|
parsedContent, _, parsedCalls, err := p.toolParser.Add(event.content, done)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("qwen3.5 tool call parsing failed", "error", err)
|
||||||
|
return "", "", nil, err
|
||||||
|
}
|
||||||
|
contentSb.WriteString(parsedContent)
|
||||||
|
calls = append(calls, parsedCalls...)
|
||||||
|
case qwen35EventThinkingContent:
|
||||||
|
thinkingSb.WriteString(event.content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return contentSb.String(), thinkingSb.String(), calls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) parseEvents() []qwen35Event {
|
||||||
|
var all []qwen35Event
|
||||||
|
|
||||||
|
keepLooping := true
|
||||||
|
for keepLooping {
|
||||||
|
var events []qwen35Event
|
||||||
|
events, keepLooping = p.eat()
|
||||||
|
if len(events) > 0 {
|
||||||
|
all = append(all, events...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(all) > 0 {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3.5 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
||||||
|
return splitAtTag(&p.buffer, tag, trimAfter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen35ParserState) ([]qwen35Event, bool) {
|
||||||
|
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
p.state = nextState
|
||||||
|
p.buffer.WriteString(trimmed)
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeConsumeLeadingThinkOpenTag handles a single optional leading <think> tag.
|
||||||
|
// Returns (handled, shouldContinueParsingNow).
|
||||||
|
func (p *Qwen35Parser) maybeConsumeLeadingThinkOpenTag(acc string) (bool, bool) {
|
||||||
|
if !p.allowLeadingThinkOpenTag {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||||
|
if strings.HasPrefix(trimmed, qwen35ThinkingOpenTag) {
|
||||||
|
after := strings.TrimPrefix(trimmed, qwen35ThinkingOpenTag)
|
||||||
|
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(after)
|
||||||
|
if after == "" {
|
||||||
|
return true, false
|
||||||
|
}
|
||||||
|
p.allowLeadingThinkOpenTag = false
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(qwen35ThinkingOpenTag, trimmed) {
|
||||||
|
return true, false
|
||||||
|
}
|
||||||
|
|
||||||
|
p.allowLeadingThinkOpenTag = false
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) eat() ([]qwen35Event, bool) {
|
||||||
|
var events []qwen35Event
|
||||||
|
|
||||||
|
switch p.state {
|
||||||
|
case qwen35ParserStateCollectingThinking:
|
||||||
|
acc := p.buffer.String()
|
||||||
|
|
||||||
|
if handled, continueNow := p.maybeConsumeLeadingThinkOpenTag(acc); handled {
|
||||||
|
return events, continueNow
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(acc, qwen35ThinkingCloseTag) {
|
||||||
|
thinking, remaining := p.splitAtTag(qwen35ThinkingCloseTag, true)
|
||||||
|
if len(thinking) > 0 {
|
||||||
|
events = append(events, qwen35EventThinkingContent{content: thinking})
|
||||||
|
}
|
||||||
|
if remaining == "" {
|
||||||
|
p.state = qwen35ParserStateThinkingDoneEatingWhitespace
|
||||||
|
} else {
|
||||||
|
p.state = qwen35ParserStateCollectingContent
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
} else if overlapLen := overlap(acc, qwen35ThinkingCloseTag); overlapLen > 0 {
|
||||||
|
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||||
|
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
|
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||||
|
|
||||||
|
unambiguous := acc[:ambiguousStart]
|
||||||
|
ambiguous := acc[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, qwen35EventThinkingContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
whitespaceLen := trailingWhitespaceLen(acc)
|
||||||
|
ambiguousStart := len(acc) - whitespaceLen
|
||||||
|
unambiguous := acc[:ambiguousStart]
|
||||||
|
ambiguous := acc[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, qwen35EventThinkingContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
|
||||||
|
case qwen35ParserStateThinkingDoneEatingWhitespace:
|
||||||
|
return p.eatLeadingWhitespaceAndTransitionTo(qwen35ParserStateCollectingContent)
|
||||||
|
|
||||||
|
case qwen35ParserStateCollectingContent:
|
||||||
|
if p.buffer.Len() == 0 {
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
content := p.buffer.String()
|
||||||
|
p.buffer.Reset()
|
||||||
|
if len(content) > 0 {
|
||||||
|
events = append(events, qwen35EventContent{content: content})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
|
||||||
|
default:
|
||||||
|
slog.Warn("qwen3.5 parser entered unknown state; resetting to content mode", "state", p.state)
|
||||||
|
p.state = qwen35ParserStateCollectingContent
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
}
|
||||||
382
model/parsers/qwen35_test.go
Normal file
382
model/parsers/qwen35_test.go
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQwen35ParserXMLToolCall(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
tools := []api.Tool{
|
||||||
|
{
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Properties: func() *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||||
|
props.Set("days", api.ToolProperty{Type: api.PropertyType{"integer"}})
|
||||||
|
return props
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(tools, nil, &api.ThinkValue{Value: false})
|
||||||
|
input := "<tool_call><function=get_weather><parameter=location>\nSan Francisco\n</parameter><parameter=days>\n3\n</parameter></function></tool_call>"
|
||||||
|
content, thinking, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
if calls[0].Function.Name != "get_weather" {
|
||||||
|
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
location, ok := calls[0].Function.Arguments.Get("location")
|
||||||
|
if !ok || location != "San Francisco" {
|
||||||
|
t.Fatalf("expected location %q, got %v", "San Francisco", location)
|
||||||
|
}
|
||||||
|
|
||||||
|
days, ok := calls[0].Function.Arguments.Get("days")
|
||||||
|
if !ok || days != 3 {
|
||||||
|
t.Fatalf("expected days %d, got %v", 3, days)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserThinkingWithExplicitOpeningTag(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "Let me think..." {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||||
|
}
|
||||||
|
if content != "Answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserAssistantPrefillStartsInContent(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
last := &api.Message{Role: "assistant", Content: "Prefilled response start"}
|
||||||
|
parser.Init(nil, last, nil)
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add(" and continued", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no thinking for assistant prefill continuation, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != " and continued" {
|
||||||
|
t.Fatalf("expected content %q, got %q", " and continued", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserToolCallEmittedInThinkingIsNotParsed(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
tools := []api.Tool{
|
||||||
|
{
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Properties: func() *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||||
|
return props
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||||
|
input := `Need weather lookup<tool_call><function=get_weather><parameter=location>
|
||||||
|
SF
|
||||||
|
</parameter></function></tool_call>`
|
||||||
|
content, thinking, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
expectedThinking := `Need weather lookup<tool_call><function=get_weather><parameter=location>
|
||||||
|
SF
|
||||||
|
</parameter></function></tool_call>`
|
||||||
|
if thinking != expectedThinking {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", expectedThinking, thinking)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls before </think>, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserToolCallAfterThinkingCloseIsParsed(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
tools := []api.Tool{
|
||||||
|
{
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Properties: func() *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||||
|
return props
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||||
|
input := `Need weather lookup</think><tool_call><function=get_weather><parameter=location>
|
||||||
|
SF
|
||||||
|
</parameter></function></tool_call>`
|
||||||
|
content, thinking, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if thinking != "Need weather lookup" {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Need weather lookup", thinking)
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call after </think>, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if calls[0].Function.Name != "get_weather" {
|
||||||
|
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
location, ok := calls[0].Function.Arguments.Get("location")
|
||||||
|
if !ok || location != "SF" {
|
||||||
|
t.Fatalf("expected location %q, got %v", "SF", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserThinkingDisabledPassesContentThrough(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
content, thinking, calls, err := parser.Add("Plain answer without think close tag.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "Plain answer without think close tag." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "Plain answer without think close tag.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserThinkingDisabledWithCloseTagTreatsAsContent(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
content, thinking, calls, err := parser.Add("</think>Some content after spurious tag.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "</think>Some content after spurious tag." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "</think>Some content after spurious tag.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserLeadingThinkCloseProducesContent(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
content, thinking, calls, err := parser.Add("</think>The final answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "The final answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "The final answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserStreamingSplitThinkCloseTag(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add("Reasoning text</thi", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on first chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "Reasoning text" {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Reasoning text", thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
content, thinking, calls, err = parser.Add("nk>The final answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on second chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "The final answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "The final answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserStreamingEatsWhitespaceAfterThinkClose(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add("Reasoning</think>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on first chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "Reasoning" {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Reasoning", thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
content, thinking, calls, err = parser.Add("\n \t", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on whitespace chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no thinking on whitespace chunk, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected whitespace after </think> to be eaten, got content %q", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
content, thinking, calls, err = parser.Add("The final answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on content chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no additional thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "The final answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "The final answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserThinkingTruncatedWithoutCloseTag(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
content, thinking, calls, err := parser.Add("Reasoning that never closes", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "Reasoning that never closes" {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Reasoning that never closes", thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,7 +8,21 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GlmOcrRenderer struct{}
|
type GlmOcrRenderer struct {
|
||||||
|
useImgTags bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GlmOcrRenderer) renderContent(message api.Message, imageOffset int) (string, int) {
|
||||||
|
var sb strings.Builder
|
||||||
|
for range message.Images {
|
||||||
|
if r.useImgTags {
|
||||||
|
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||||
|
imageOffset++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
return sb.String(), imageOffset
|
||||||
|
}
|
||||||
|
|
||||||
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
@@ -38,11 +52,14 @@ func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
|||||||
thinkingExplicitlySet = true
|
thinkingExplicitlySet = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
imageOffset := 0
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
switch message.Role {
|
switch message.Role {
|
||||||
case "user":
|
case "user":
|
||||||
sb.WriteString("<|user|>\n")
|
sb.WriteString("<|user|>\n")
|
||||||
sb.WriteString(message.Content)
|
content, nextOffset := r.renderContent(message, imageOffset)
|
||||||
|
imageOffset = nextOffset
|
||||||
|
sb.WriteString(content)
|
||||||
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
|
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
|
||||||
sb.WriteString("/nothink")
|
sb.WriteString("/nothink")
|
||||||
}
|
}
|
||||||
|
|||||||
99
model/renderers/glmocr_test.go
Normal file
99
model/renderers/glmocr_test.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGlmOcrRenderer_Images(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
renderer *GlmOcrRenderer
|
||||||
|
messages []api.Message
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "use_img_tags_single_image",
|
||||||
|
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Describe this image.",
|
||||||
|
Images: []api.ImageData{api.ImageData("img1")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "[gMASK]<sop><|user|>\n[img-0]Describe this image.<|assistant|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "use_img_tags_multiple_images",
|
||||||
|
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Describe these images.",
|
||||||
|
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "[gMASK]<sop><|user|>\n[img-0][img-1]Describe these images.<|assistant|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi_turn_increments_image_offset",
|
||||||
|
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "First image",
|
||||||
|
Images: []api.ImageData{api.ImageData("img1")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Processed.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Second image",
|
||||||
|
Images: []api.ImageData{api.ImageData("img2")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "[gMASK]<sop><|user|>\n[img-0]First image<|assistant|>\n<think></think>\nProcessed.\n<|user|>\n[img-1]Second image<|assistant|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default_no_img_tags",
|
||||||
|
renderer: &GlmOcrRenderer{},
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "No image tags expected.",
|
||||||
|
Images: []api.ImageData{api.ImageData("img1")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "[gMASK]<sop><|user|>\nNo image tags expected.<|assistant|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no_images_content_unchanged",
|
||||||
|
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Text only message.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "[gMASK]<sop><|user|>\nText only message.<|assistant|>\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.renderer.Render(tt.messages, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Render() error = %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tt.expected, got); diff != "" {
|
||||||
|
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
194
model/renderers/qwen35.go
Normal file
194
model/renderers/qwen35.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
qwen35ThinkOpenTag = "<think>"
|
||||||
|
qwen35ThinkCloseTag = "</think>"
|
||||||
|
qwen35ToolPostamble = `
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
<function=example_function_name>
|
||||||
|
<parameter=example_parameter_1>
|
||||||
|
value_1
|
||||||
|
</parameter>
|
||||||
|
<parameter=example_parameter_2>
|
||||||
|
This is the value for the second parameter
|
||||||
|
that can span
|
||||||
|
multiple lines
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>
|
||||||
|
|
||||||
|
<IMPORTANT>
|
||||||
|
Reminder:
|
||||||
|
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||||
|
- Required parameters MUST be specified
|
||||||
|
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||||
|
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||||
|
</IMPORTANT>`
|
||||||
|
)
|
||||||
|
|
||||||
|
type Qwen35Renderer struct {
|
||||||
|
isThinking bool
|
||||||
|
|
||||||
|
emitEmptyThinkOnNoThink bool
|
||||||
|
useImgTags bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Qwen35Renderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
||||||
|
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
||||||
|
var subSb strings.Builder
|
||||||
|
for range content.Images {
|
||||||
|
if r.useImgTags {
|
||||||
|
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||||
|
imageOffset++
|
||||||
|
} else {
|
||||||
|
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: support videos
|
||||||
|
|
||||||
|
subSb.WriteString(content.Content)
|
||||||
|
return subSb.String(), imageOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitQwen35ReasoningContent separates an assistant turn into its reasoning
// text and the visible content that follows it.
//
// When thinking mode is on and the message carries an explicit Thinking
// field, that field wins and content is passed through untouched. Otherwise
// the function looks for an inline </think> close tag: everything before it
// (minus a matching <think> open tag, if any) is the reasoning, and the
// remainder — with leading newlines stripped — is the content. Reasoning is
// always returned whitespace-trimmed.
func splitQwen35ReasoningContent(content, messageThinking string, isThinking bool) (reasoning string, remaining string) {
	if isThinking && messageThinking != "" {
		return strings.TrimSpace(messageThinking), content
	}

	// Tag literals mirror qwen35ThinkOpenTag / qwen35ThinkCloseTag.
	const openTag, closeTag = "<think>", "</think>"

	closeIdx := strings.Index(content, closeTag)
	if closeIdx == -1 {
		// No inline reasoning block present.
		return "", content
	}

	head := content[:closeIdx]
	if openIdx := strings.LastIndex(head, openTag); openIdx >= 0 {
		head = head[openIdx+len(openTag):]
	}
	remaining = strings.TrimLeft(content[closeIdx+len(closeTag):], "\n")
	return strings.TrimSpace(head), remaining
}
|
||||||
|
|
||||||
|
// Render serializes the conversation into the Qwen 3.5 ChatML-style prompt:
// <|im_start|>role ... <|im_end|> turns, an optional tool manifest in the
// leading system turn, XML <tool_call> blocks for assistant tool calls,
// grouped <tool_response> payloads inside synthetic user turns, and <think>
// blocks when thinking is enabled. The prompt ends with an open assistant
// turn (or continues the caller-supplied assistant prefill).
func (r *Qwen35Renderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
	var sb strings.Builder

	// A per-request think override beats the renderer's default.
	isThinking := r.isThinking
	if think != nil {
		isThinking = think.Bool()
	}

	// System turn. With tools, the manifest and postamble come first and any
	// user-supplied system prompt is appended after them in the same turn.
	if len(tools) > 0 {
		sb.WriteString(imStartTag + "system\n")
		sb.WriteString("# Tools\n\nYou have access to the following functions:\n\n<tools>")
		for _, tool := range tools {
			sb.WriteString("\n")
			// Tools that fail to marshal are silently omitted from the manifest.
			if b, err := marshalWithSpaces(tool); err == nil {
				sb.Write(b)
			}
		}
		sb.WriteString(qwen35ToolPostamble)
		if len(messages) > 0 && messages[0].Role == "system" {
			systemContent, _ := r.renderContent(messages[0], 0)
			systemContent = strings.TrimSpace(systemContent)
			if systemContent != "" {
				sb.WriteString("\n\n")
				sb.WriteString(systemContent)
			}
		}
		sb.WriteString(imEndTag + "\n")
	} else if len(messages) > 0 && messages[0].Role == "system" {
		systemContent, _ := r.renderContent(messages[0], 0)
		sb.WriteString(imStartTag + "system\n" + strings.TrimSpace(systemContent) + imEndTag + "\n")
	}

	// Locate the last "real" user query: scanning backwards, user messages
	// that merely wrap <tool_response> payloads are treated as part of an
	// ongoing multi-step tool exchange. <think> blocks are only re-rendered
	// for assistant turns after this index.
	multiStepTool := true
	lastQueryIndex := len(messages) - 1 // so this is the last user message

	for i := len(messages) - 1; i >= 0; i-- {
		message := messages[i]
		if multiStepTool && message.Role == "user" {
			content, _ := r.renderContent(message, 0)
			content = strings.TrimSpace(content)
			if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
				multiStepTool = false
				lastQueryIndex = i
			}
		}
	}

	imageOffset := 0
	for i, message := range messages {
		content, nextImageOffset := r.renderContent(message, imageOffset)
		imageOffset = nextImageOffset
		content = strings.TrimSpace(content)

		lastMessage := i == len(messages)-1
		// A trailing assistant message is a prefill: its turn stays open.
		prefill := lastMessage && message.Role == "assistant"

		if message.Role == "user" || (message.Role == "system" && i != 0) {
			sb.WriteString(imStartTag + message.Role + "\n" + content + imEndTag + "\n")
		} else if message.Role == "assistant" {
			contentReasoning, content := splitQwen35ReasoningContent(content, message.Thinking, isThinking)

			// Only assistant turns after the last real user query keep their
			// <think> block; earlier reasoning is dropped from the prompt.
			if isThinking && i > lastQueryIndex {
				sb.WriteString(imStartTag + message.Role + "\n<think>\n" + contentReasoning + "\n</think>\n\n" + content)
			} else {
				sb.WriteString(imStartTag + message.Role + "\n" + content)
			}

			if len(message.ToolCalls) > 0 {
				for j, toolCall := range message.ToolCalls {
					// Blank line between content and the first call; single
					// newline between back-to-back calls.
					if j == 0 {
						if strings.TrimSpace(content) != "" {
							sb.WriteString("\n\n")
						}
					} else {
						sb.WriteString("\n")
					}

					sb.WriteString("<tool_call>\n<function=" + toolCall.Function.Name + ">\n")
					for name, value := range toolCall.Function.Arguments.All() {
						sb.WriteString("<parameter=" + name + ">\n")
						sb.WriteString(formatToolCallArgument(value))
						sb.WriteString("\n</parameter>\n")
					}
					sb.WriteString("</function>\n</tool_call>")
				}
			}

			if !prefill {
				sb.WriteString(imEndTag + "\n")
			}
		} else if message.Role == "tool" {
			// Consecutive tool results are grouped into a single user turn:
			// open it on the first, close it after the last.
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString(imStartTag + "user")
			}
			sb.WriteString("\n<tool_response>\n" + content + "\n</tool_response>")
			if i == len(messages)-1 || messages[i+1].Role != "tool" {
				sb.WriteString(imEndTag + "\n")
			}
		}

		// prefill at the end
		if lastMessage && !prefill {
			sb.WriteString(imStartTag + "assistant\n")
			if isThinking {
				sb.WriteString("<think>\n")
			} else if r.emitEmptyThinkOnNoThink {
				sb.WriteString("<think>\n\n</think>\n\n")
			}
		}
	}

	return sb.String(), nil
}
|
||||||
389
model/renderers/qwen35_test.go
Normal file
389
model/renderers/qwen35_test.go
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestQwen35RendererUsesXMLToolCallingFormat verifies a full tool-use round
// trip: the system turn carries the <tools> manifest and XML instructions,
// the assistant's call is rendered as nested <tool_call>/<function>/<parameter>
// tags, and the user's system prompt lands after the tool instructions.
func TestQwen35RendererUsesXMLToolCallingFormat(t *testing.T) {
	renderer := &Qwen35Renderer{isThinking: true}
	msgs := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "What's the weather in Paris?"},
		{
			Role:    "assistant",
			Content: "I'll check.",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgsOrdered([]orderedArg{
							{Key: "location", Value: "Paris"},
						}),
					},
				},
			},
		},
		{Role: "tool", Content: "22C"},
		{Role: "user", Content: "Thanks"},
	}
	tools := []api.Tool{
		{
			Type: "function",
			Function: api.ToolFunction{
				Name: "get_weather",
				Parameters: api.ToolFunctionParameters{
					Type: "object",
					Properties: testPropsOrdered([]orderedProp{
						{
							Key: "location",
							Value: api.ToolProperty{
								Type: api.PropertyType{"string"},
							},
						},
					}),
					Required: []string{"location"},
				},
			},
		},
	}

	got, err := renderer.Render(msgs, tools, nil)
	if err != nil {
		t.Fatalf("render failed: %v", err)
	}

	if !strings.Contains(got, "<tools>") {
		t.Fatalf("expected tools section in prompt, got:\n%s", got)
	}
	// The example call comes from the static qwen35ToolPostamble text.
	if !strings.Contains(got, "<function=example_function_name>") {
		t.Fatalf("expected xml-style tool call instructions, got:\n%s", got)
	}

	wantToolCall := "<tool_call>\n<function=get_weather>\n<parameter=location>\nParis\n</parameter>\n</function>\n</tool_call>"
	if !strings.Contains(got, wantToolCall) {
		t.Fatalf("expected xml tool call payload, got:\n%s", got)
	}

	// The system prompt must appear after the tool instructions.
	toolsIdx := strings.Index(got, "# Tools")
	systemIdx := strings.Index(got, "You are a helpful assistant.")
	if toolsIdx == -1 || systemIdx == -1 || systemIdx < toolsIdx {
		t.Fatalf("expected system prompt appended after tool instructions, got:\n%s", got)
	}
}
|
||||||
|
|
||||||
|
// TestQwen35RendererNoThinkPrefill verifies that when a request disables
// thinking and emitEmptyThinkOnNoThink is set, the renderer ends the prompt
// with an explicit empty <think></think> block in the assistant prefill.
func TestQwen35RendererNoThinkPrefill(t *testing.T) {
	renderer := &Qwen35Renderer{isThinking: true, emitEmptyThinkOnNoThink: true}
	msgs := []api.Message{
		{Role: "user", Content: "hello"},
	}

	got, err := renderer.Render(msgs, nil, &api.ThinkValue{Value: false})
	if err != nil {
		t.Fatalf("render failed: %v", err)
	}

	if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
		t.Fatalf("expected explicit no-think prefill, got:\n%s", got)
	}
}
|
||||||
|
|
||||||
|
// TestQwen35RendererBackToBackToolCallsAndResponses checks that multiple tool
// calls in one assistant turn render consecutively, that the matching tool
// results are grouped into a single user turn, and that reasoning from a turn
// before the last real user query is dropped from the prompt.
func TestQwen35RendererBackToBackToolCallsAndResponses(t *testing.T) {
	renderer := &Qwen35Renderer{isThinking: true}

	msgs := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Run add and multiply."},
		{
			Role:     "assistant",
			Content:  "I'll run both now.",
			Thinking: "Need to call add and multiply.",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "add",
						Arguments: testArgsOrdered([]orderedArg{
							{Key: "a", Value: 2},
							{Key: "b", Value: 3},
						}),
					},
				},
				{
					Function: api.ToolCallFunction{
						Name: "multiply",
						Arguments: testArgsOrdered([]orderedArg{
							{Key: "x", Value: 4},
							{Key: "y", Value: 5},
						}),
					},
				},
			},
		},
		{Role: "tool", Content: "5"},
		{Role: "tool", Content: "20"},
		{Role: "user", Content: "Summarize the results."},
	}

	got, err := renderer.Render(msgs, qwen35MathTools(), nil)
	if err != nil {
		t.Fatalf("render failed: %v", err)
	}

	// The assistant turn precedes the final user query, so its reasoning
	// must not be re-rendered.
	if strings.Contains(got, "Need to call add and multiply.") {
		t.Fatalf("did not expect historical reasoning block in this sequence, got:\n%s", got)
	}

	wantToolCalls := `<tool_call>
<function=add>
<parameter=a>
2
</parameter>
<parameter=b>
3
</parameter>
</function>
</tool_call>
<tool_call>
<function=multiply>
<parameter=x>
4
</parameter>
<parameter=y>
5
</parameter>
</function>
</tool_call>`
	if !strings.Contains(got, wantToolCalls) {
		t.Fatalf("expected back-to-back tool calls, got:\n%s", got)
	}

	wantToolResponses := `<|im_start|>user
<tool_response>
5
</tool_response>
<tool_response>
20
</tool_response><|im_end|>`
	if !strings.Contains(got, wantToolResponses) {
		t.Fatalf("expected grouped back-to-back tool responses, got:\n%s", got)
	}

	if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n") {
		t.Fatalf("expected assistant thinking prefill at end, got:\n%s", got)
	}
}
|
||||||
|
|
||||||
|
// TestQwen35RendererInterleavedThinkingAndTools checks a multi-step tool
// exchange: because every later user turn is a <tool_response> wrapper, both
// assistant turns sit after the last real query and keep their <think>
// blocks, and the prompt ends with a thinking prefill.
func TestQwen35RendererInterleavedThinkingAndTools(t *testing.T) {
	renderer := &Qwen35Renderer{isThinking: true}

	msgs := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Plan a picnic in Paris."},
		{
			Role:     "assistant",
			Content:  "Checking weather first.",
			Thinking: "Need weather before giving advice.",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgsOrdered([]orderedArg{
							{Key: "location", Value: "Paris"},
						}),
					},
				},
			},
		},
		{Role: "tool", Content: "22C"},
		{
			Role:     "assistant",
			Content:  "Checking UV too.",
			Thinking: "Need UV index for sunscreen advice.",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_uv",
						Arguments: testArgsOrdered([]orderedArg{
							{Key: "location", Value: "Paris"},
						}),
					},
				},
			},
		},
		{Role: "tool", Content: "5"},
	}

	got, err := renderer.Render(msgs, qwen35WeatherUVTools(), nil)
	if err != nil {
		t.Fatalf("render failed: %v", err)
	}

	wantFirstTurn := `<|im_start|>assistant
<think>
Need weather before giving advice.
</think>

Checking weather first.

<tool_call>
<function=get_weather>
<parameter=location>
Paris
</parameter>
</function>
</tool_call><|im_end|>`
	if !strings.Contains(got, wantFirstTurn) {
		t.Fatalf("expected first assistant thinking/tool sequence, got:\n%s", got)
	}

	wantSecondTurn := `<|im_start|>assistant
<think>
Need UV index for sunscreen advice.
</think>

Checking UV too.

<tool_call>
<function=get_uv>
<parameter=location>
Paris
</parameter>
</function>
</tool_call><|im_end|>`
	if !strings.Contains(got, wantSecondTurn) {
		t.Fatalf("expected second assistant thinking/tool sequence, got:\n%s", got)
	}

	if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n") {
		t.Fatalf("expected assistant thinking prefill at end, got:\n%s", got)
	}
}
|
||||||
|
|
||||||
|
// TestQwen35RendererAssistantPrefillWithThinking verifies that a trailing
// assistant message is rendered as an open (prefill) turn: its <think> block
// and content are emitted with no closing <|im_end|> tag.
func TestQwen35RendererAssistantPrefillWithThinking(t *testing.T) {
	renderer := &Qwen35Renderer{isThinking: true}
	msgs := []api.Message{
		{Role: "user", Content: "Write two words."},
		{
			Role:     "assistant",
			Thinking: "Keep it short.",
			Content:  "Hello world",
		},
	}

	got, err := renderer.Render(msgs, nil, nil)
	if err != nil {
		t.Fatalf("render failed: %v", err)
	}

	want := `<|im_start|>user
Write two words.<|im_end|>
<|im_start|>assistant
<think>
Keep it short.
</think>

Hello world`
	if got != want {
		t.Fatalf("unexpected prefill output\n--- got ---\n%s\n--- want ---\n%s", got, want)
	}
}
|
||||||
|
|
||||||
|
// qwen35MathTools returns a fixture of two function tools (add, multiply)
// with ordered integer parameters, used by the back-to-back tool-call tests.
func qwen35MathTools() []api.Tool {
	return []api.Tool{
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "add",
				Description: "Add two numbers",
				Parameters: api.ToolFunctionParameters{
					Type: "object",
					Properties: testPropsOrdered([]orderedProp{
						{
							Key: "a",
							Value: api.ToolProperty{
								Type: api.PropertyType{"integer"},
							},
						},
						{
							Key: "b",
							Value: api.ToolProperty{
								Type: api.PropertyType{"integer"},
							},
						},
					}),
					Required: []string{"a", "b"},
				},
			},
		},
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "multiply",
				Description: "Multiply two numbers",
				Parameters: api.ToolFunctionParameters{
					Type: "object",
					Properties: testPropsOrdered([]orderedProp{
						{
							Key: "x",
							Value: api.ToolProperty{
								Type: api.PropertyType{"integer"},
							},
						},
						{
							Key: "y",
							Value: api.ToolProperty{
								Type: api.PropertyType{"integer"},
							},
						},
					}),
					Required: []string{"x", "y"},
				},
			},
		},
	}
}
|
||||||
|
|
||||||
|
// qwen35WeatherUVTools returns a fixture of two function tools (get_weather,
// get_uv), each taking a required string "location" parameter, used by the
// interleaved thinking/tool tests.
func qwen35WeatherUVTools() []api.Tool {
	return []api.Tool{
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "get_weather",
				Description: "Get weather for a location",
				Parameters: api.ToolFunctionParameters{
					Type: "object",
					Properties: testPropsOrdered([]orderedProp{
						{
							Key: "location",
							Value: api.ToolProperty{
								Type: api.PropertyType{"string"},
							},
						},
					}),
					Required: []string{"location"},
				},
			},
		},
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "get_uv",
				Description: "Get UV index for a location",
				Parameters: api.ToolFunctionParameters{
					Type: "object",
					Properties: testPropsOrdered([]orderedProp{
						{
							Key: "location",
							Value: api.ToolProperty{
								Type: api.PropertyType{"string"},
							},
						},
					}),
					Required: []string{"location"},
				},
			},
		},
	}
}
|
||||||
@@ -57,7 +57,7 @@ func rendererForName(name string) Renderer {
|
|||||||
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
|
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
|
||||||
return renderer
|
return renderer
|
||||||
case "qwen3.5":
|
case "qwen3.5":
|
||||||
renderer := &Qwen3VLRenderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
|
renderer := &Qwen35Renderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
|
||||||
return renderer
|
return renderer
|
||||||
case "cogito":
|
case "cogito":
|
||||||
renderer := &CogitoRenderer{isThinking: true}
|
renderer := &CogitoRenderer{isThinking: true}
|
||||||
@@ -86,7 +86,7 @@ func rendererForName(name string) Renderer {
|
|||||||
case "glm-4.7":
|
case "glm-4.7":
|
||||||
return &GLM47Renderer{}
|
return &GLM47Renderer{}
|
||||||
case "glm-ocr":
|
case "glm-ocr":
|
||||||
return &GlmOcrRenderer{}
|
return &GlmOcrRenderer{useImgTags: RenderImgTags}
|
||||||
case "lfm2":
|
case "lfm2":
|
||||||
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
|
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
|
||||||
case "lfm2-thinking":
|
case "lfm2-thinking":
|
||||||
|
|||||||
@@ -562,6 +562,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
|||||||
if errors.As(err, &reprocess) {
|
if errors.As(err, &reprocess) {
|
||||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||||
|
seq.sampler.Reset()
|
||||||
// Skip this sequence but continue processing the rest
|
// Skip this sequence but continue processing the rest
|
||||||
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
||||||
err = nil
|
err = nil
|
||||||
@@ -692,6 +693,12 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
|||||||
// (unless we take down the whole runner).
|
// (unless we take down the whole runner).
|
||||||
if len(seq.pendingInputs) > 0 {
|
if len(seq.pendingInputs) > 0 {
|
||||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||||
|
for _, inp := range seq.pendingInputs {
|
||||||
|
if len(inp.Multimodal) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seq.sampler.Accept(inp.Token)
|
||||||
|
}
|
||||||
seq.pendingInputs = []*input.Input{}
|
seq.pendingInputs = []*input.Input{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -892,6 +899,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
req.Options.TopK,
|
req.Options.TopK,
|
||||||
req.Options.TopP,
|
req.Options.TopP,
|
||||||
req.Options.MinP,
|
req.Options.MinP,
|
||||||
|
req.Options.RepeatPenalty,
|
||||||
|
req.Options.PresencePenalty,
|
||||||
|
req.Options.FrequencyPenalty,
|
||||||
req.Options.Seed,
|
req.Options.Seed,
|
||||||
grammar,
|
grammar,
|
||||||
)
|
)
|
||||||
@@ -938,6 +948,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
seq.sampler.Reset()
|
||||||
|
for _, inp := range seq.cache.Inputs {
|
||||||
|
if len(inp.Multimodal) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seq.sampler.Accept(inp.Token)
|
||||||
|
}
|
||||||
|
|
||||||
s.seqs[i] = seq
|
s.seqs[i] = seq
|
||||||
s.cond.Signal()
|
s.cond.Signal()
|
||||||
found = true
|
found = true
|
||||||
|
|||||||
@@ -16,24 +16,49 @@ type token struct {
|
|||||||
value float32 // The raw logit or probability from the model
|
value float32 // The raw logit or probability from the model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const DefaultPenaltyLookback = 64
|
||||||
|
|
||||||
type Sampler struct {
|
type Sampler struct {
|
||||||
rng *rand.Rand
|
rng *rand.Rand
|
||||||
topK int
|
topK int
|
||||||
topP float32
|
topP float32
|
||||||
minP float32
|
minP float32
|
||||||
temperature float32
|
temperature float32
|
||||||
|
repeat float32
|
||||||
|
presence float32
|
||||||
|
frequency float32
|
||||||
|
history []int32
|
||||||
grammar *GrammarSampler
|
grammar *GrammarSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset clears the sampler's penalty history while keeping the backing
// array for reuse.
func (s *Sampler) Reset() {
	s.history = s.history[:0]
}
|
||||||
|
|
||||||
|
// Accept records a token into the penalty history, retaining only the most
// recent DefaultPenaltyLookback tokens.
func (s *Sampler) Accept(token int32) {
	s.history = append(s.history, token)
	if len(s.history) > DefaultPenaltyLookback {
		// Shift the newest window to the front in place instead of
		// reallocating, so the slice capacity is reused.
		copy(s.history, s.history[len(s.history)-DefaultPenaltyLookback:])
		s.history = s.history[:DefaultPenaltyLookback]
	}
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||||
if len(logits) == 0 {
|
if len(logits) == 0 {
|
||||||
return -1, errors.New("sample: no logits provided to sample")
|
return -1, errors.New("sample: no logits provided to sample")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
counts := tokenCounts(s.history, len(logits))
|
||||||
|
|
||||||
tokens := make([]token, len(logits))
|
tokens := make([]token, len(logits))
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
|
value := logits[i]
|
||||||
|
if count := counts[int32(i)]; count > 0 {
|
||||||
|
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||||
|
}
|
||||||
|
|
||||||
tokens[i].id = int32(i)
|
tokens[i].id = int32(i)
|
||||||
tokens[i].value = logits[i]
|
tokens[i].value = value
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err := s.sample(tokens)
|
t, err := s.sample(tokens)
|
||||||
@@ -55,8 +80,12 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
|
|||||||
// we need to reset them before applying the grammar and
|
// we need to reset them before applying the grammar and
|
||||||
// sampling again
|
// sampling again
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
|
value := logits[i]
|
||||||
|
if count := counts[int32(i)]; count > 0 {
|
||||||
|
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||||
|
}
|
||||||
tokens[i].id = int32(i)
|
tokens[i].id = int32(i)
|
||||||
tokens[i].value = logits[i]
|
tokens[i].value = value
|
||||||
}
|
}
|
||||||
s.grammar.Apply(tokens)
|
s.grammar.Apply(tokens)
|
||||||
t, err = s.sample(tokens)
|
t, err = s.sample(tokens)
|
||||||
@@ -127,7 +156,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
|
func NewSampler(temperature float32, topK int, topP float32, minP float32, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32, seed int, grammar *GrammarSampler) Sampler {
|
||||||
var rng *rand.Rand
|
var rng *rand.Rand
|
||||||
if seed != -1 {
|
if seed != -1 {
|
||||||
// PCG requires two parameters: sequence and stream
|
// PCG requires two parameters: sequence and stream
|
||||||
@@ -154,12 +183,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
|||||||
minP = 1.0
|
minP = 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if repeatPenalty <= 0 {
|
||||||
|
repeatPenalty = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
return Sampler{
|
return Sampler{
|
||||||
rng: rng,
|
rng: rng,
|
||||||
topK: topK,
|
topK: topK,
|
||||||
topP: topP,
|
topP: topP,
|
||||||
minP: minP,
|
minP: minP,
|
||||||
temperature: temperature,
|
temperature: temperature,
|
||||||
|
repeat: repeatPenalty,
|
||||||
|
presence: presencePenalty,
|
||||||
|
frequency: frequencyPenalty,
|
||||||
grammar: grammar,
|
grammar: grammar,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
sampler := NewSampler(0.8, 0, 0, 0, 1, 0, 0, 42, nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
sampler.Sample(logits)
|
sampler.Sample(logits)
|
||||||
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range configs {
|
for _, tc := range configs {
|
||||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, 1, 0, 0, tc.seed, nil)
|
||||||
sampler.Sample(logits)
|
sampler.Sample(logits)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
|
|
||||||
// Test with combined transforms separately - topK influences performance greatly
|
// Test with combined transforms separately - topK influences performance greatly
|
||||||
b.Run("TransformCombined", func(b *testing.B) {
|
b.Run("TransformCombined", func(b *testing.B) {
|
||||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
sampler := NewSampler(0.8, 50, 0.9, 0.05, 1, 0, 0, 42, nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
|
|||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
sampler := NewSampler(0, -1, 0, 0, 1, 0, 0, -1, nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
func TestWeighted(t *testing.T) {
|
func TestWeighted(t *testing.T) {
|
||||||
logits := []float32{-10, 3, -10, -10}
|
logits := []float32{-10, 3, -10, -10}
|
||||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
sampler := NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||||
got, err := sampler.Sample(logits)
|
got, err := sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -25,7 +25,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logits = []float32{-100, -10, 0, 10}
|
logits = []float32{-100, -10, 0, 10}
|
||||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
sampler = NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -39,7 +39,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
// Test very high p
|
// Test very high p
|
||||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||||
// Use extremely small topP to filter out all tokens
|
// Use extremely small topP to filter out all tokens
|
||||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
sampler = NewSampler(1.0, 0, 1e-10, 0, 1, 0, 0, 0, nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -52,7 +52,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||||
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
sampler = NewSampler(1, 0, 0.95, 0.05, 1, 0, 0, 0, nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error, got %d", got)
|
t.Errorf("expected error, got %d", got)
|
||||||
@@ -151,8 +151,8 @@ func TestGrammar(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkSample(b *testing.B) {
|
func BenchmarkSample(b *testing.B) {
|
||||||
samplers := map[string]Sampler{
|
samplers := map[string]Sampler{
|
||||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
"Greedy": NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, 1, 0, 0, -1, nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate random logits for benchmarking
|
// Generate random logits for benchmarking
|
||||||
|
|||||||
@@ -25,6 +25,48 @@ func (h *tokenHeap) Pop() any {
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func tokenCounts(history []int32, vocabSize int) map[int32]int {
|
||||||
|
if len(history) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
start := 0
|
||||||
|
if len(history) > DefaultPenaltyLookback {
|
||||||
|
start = len(history) - DefaultPenaltyLookback
|
||||||
|
}
|
||||||
|
|
||||||
|
counts := make(map[int32]int, len(history)-start)
|
||||||
|
for _, token := range history[start:] {
|
||||||
|
if token < 0 || int(token) >= vocabSize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
counts[token]++
|
||||||
|
}
|
||||||
|
|
||||||
|
return counts
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyPenalty(logit float32, count int, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32) float32 {
|
||||||
|
if repeatPenalty != 1.0 {
|
||||||
|
// Preserve ordering for negative logits when applying repeat penalty.
|
||||||
|
if logit < 0 {
|
||||||
|
logit *= repeatPenalty
|
||||||
|
} else {
|
||||||
|
logit /= repeatPenalty
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if frequencyPenalty != 0 {
|
||||||
|
logit -= float32(count) * frequencyPenalty
|
||||||
|
}
|
||||||
|
|
||||||
|
if presencePenalty != 0 {
|
||||||
|
logit -= presencePenalty
|
||||||
|
}
|
||||||
|
|
||||||
|
return logit
|
||||||
|
}
|
||||||
|
|
||||||
// temperature applies scaling to the logits
|
// temperature applies scaling to the logits
|
||||||
func temperature(ts []token, temp float32) {
|
func temperature(ts []token, temp float32) {
|
||||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||||
|
|||||||
@@ -295,6 +295,86 @@ func TestMinP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTokenCounts(t *testing.T) {
|
||||||
|
history := make([]int32, 70)
|
||||||
|
history[0] = 7
|
||||||
|
history[69] = 7
|
||||||
|
|
||||||
|
counts := tokenCounts(history, 8)
|
||||||
|
if got := counts[7]; got != 1 {
|
||||||
|
t.Fatalf("lookback mismatch: got %d want %d", got, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyPenalty(t *testing.T) {
|
||||||
|
logit := applyPenalty(5.0, 3, 1.0, 1.5, 0.5)
|
||||||
|
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||||
|
t.Fatalf("unexpected penalty result: got %f want %f", logit, 2.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
logit = applyPenalty(4.0, 1, 2.0, 0, 0)
|
||||||
|
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||||
|
t.Fatalf("unexpected repeat penalty result for positive logits: got %f want %f", logit, 2.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
logit = applyPenalty(-4.0, 1, 2.0, 0, 0)
|
||||||
|
if math.Abs(float64(logit-(-8.0))) > 1e-6 {
|
||||||
|
t.Fatalf("unexpected repeat penalty result for negative logits: got %f want %f", logit, -8.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSamplerPresencePenalty(t *testing.T) {
|
||||||
|
logits := []float32{0.0, 5.0, 0.0}
|
||||||
|
|
||||||
|
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||||
|
baseline.Accept(1)
|
||||||
|
got, err := baseline.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got != 1 {
|
||||||
|
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
presence := NewSampler(0, 0, 1, 0, 1, 6, 0, -1, nil)
|
||||||
|
presence.Accept(1)
|
||||||
|
got, err = presence.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got == 1 {
|
||||||
|
t.Fatalf("presence penalty did not change repeated token selection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSamplerFrequencyPenalty(t *testing.T) {
|
||||||
|
logits := []float32{0.0, 5.0, 4.0}
|
||||||
|
|
||||||
|
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||||
|
baseline.Accept(1)
|
||||||
|
baseline.Accept(1)
|
||||||
|
baseline.Accept(1)
|
||||||
|
got, err := baseline.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got != 1 {
|
||||||
|
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
frequency := NewSampler(0, 0, 1, 0, 1, 0, 1.0, -1, nil)
|
||||||
|
frequency.Accept(1)
|
||||||
|
frequency.Accept(1)
|
||||||
|
frequency.Accept(1)
|
||||||
|
got, err = frequency.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got != 2 {
|
||||||
|
t.Fatalf("frequency penalty did not demote repeated token as expected: got %d want %d", got, 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkTransforms(b *testing.B) {
|
func BenchmarkTransforms(b *testing.B) {
|
||||||
// Generate random logits
|
// Generate random logits
|
||||||
tokens := make([]token, 1<<16)
|
tokens := make([]token, 1<<16)
|
||||||
|
|||||||
@@ -65,11 +65,22 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
|||||||
config.Parser = r.Parser
|
config.Parser = r.Parser
|
||||||
config.Requires = r.Requires
|
config.Requires = r.Requires
|
||||||
|
|
||||||
for v := range r.Files {
|
for v, digest := range r.Files {
|
||||||
if !fs.ValidPath(v) {
|
if !fs.ValidPath(v) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if digest == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, digest := range r.Adapters {
|
||||||
|
if digest == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@@ -366,3 +367,33 @@ func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
|
|||||||
t.Fatal("prompt is empty")
|
t.Fatal("prompt is empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
||||||
|
msgs := []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "extract text",
|
||||||
|
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := Model{
|
||||||
|
Config: model.ConfigV2{Renderer: "glm-ocr"},
|
||||||
|
ProjectorPaths: []string{"vision"},
|
||||||
|
}
|
||||||
|
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||||
|
think := false
|
||||||
|
|
||||||
|
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := len(images), 2; got != want {
|
||||||
|
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1]extract text") {
|
||||||
|
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
|||||||
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
|
// Deprecated runner override option; ignore if present.
|
||||||
delete(requestOpts, "use_imagegen_runner")
|
delete(requestOpts, "use_imagegen_runner")
|
||||||
|
|
||||||
opts, err := s.modelOptions(model, requestOpts)
|
opts, err := s.modelOptions(model, requestOpts)
|
||||||
@@ -158,7 +158,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
|
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
|
||||||
var runner *runnerRef
|
var runner *runnerRef
|
||||||
select {
|
select {
|
||||||
case runner = <-runnerCh:
|
case runner = <-runnerCh:
|
||||||
@@ -370,12 +370,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate Think value: string values currently only allowed for harmony/gptoss models
|
|
||||||
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
caps := []model.Capability{model.CapabilityCompletion}
|
caps := []model.Capability{model.CapabilityCompletion}
|
||||||
if req.Suffix != "" {
|
if req.Suffix != "" {
|
||||||
caps = append(caps, model.CapabilityInsert)
|
caps = append(caps, model.CapabilityInsert)
|
||||||
@@ -558,7 +552,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
PromptEvalDuration: cr.PromptEvalDuration,
|
PromptEvalDuration: cr.PromptEvalDuration,
|
||||||
EvalCount: cr.EvalCount,
|
EvalCount: cr.EvalCount,
|
||||||
EvalDuration: cr.EvalDuration,
|
EvalDuration: cr.EvalDuration,
|
||||||
PeakMemory: cr.PeakMemory,
|
|
||||||
},
|
},
|
||||||
Logprobs: toAPILogprobs(cr.Logprobs),
|
Logprobs: toAPILogprobs(cr.Logprobs),
|
||||||
}
|
}
|
||||||
@@ -2241,12 +2234,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate Think value: string values currently only allowed for harmony/gptoss models
|
|
||||||
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var thinkingState *thinking.Parser
|
var thinkingState *thinking.Parser
|
||||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||||
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
||||||
@@ -2317,7 +2304,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
PromptEvalDuration: r.PromptEvalDuration,
|
PromptEvalDuration: r.PromptEvalDuration,
|
||||||
EvalCount: r.EvalCount,
|
EvalCount: r.EvalCount,
|
||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: r.EvalDuration,
|
||||||
PeakMemory: r.PeakMemory,
|
|
||||||
},
|
},
|
||||||
Logprobs: toAPILogprobs(r.Logprobs),
|
Logprobs: toAPILogprobs(r.Logprobs),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -144,6 +144,37 @@ func TestCreateFromBin(t *testing.T) {
|
|||||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||||
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("empty file digest", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Name: "my-gguf-model",
|
||||||
|
Files: map[string]string{"0.gguf": ""},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "invalid digest format") {
|
||||||
|
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty adapter digest", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Name: "my-gguf-model",
|
||||||
|
Files: map[string]string{"0.gguf": digest},
|
||||||
|
Adapters: map[string]string{"adapter.gguf": ""},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "invalid digest format") {
|
||||||
|
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateFromModel(t *testing.T) {
|
func TestCreateFromModel(t *testing.T) {
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ type LlmRequest struct {
|
|||||||
successCh chan *runnerRef
|
successCh chan *runnerRef
|
||||||
errCh chan error
|
errCh chan error
|
||||||
schedAttempts uint
|
schedAttempts uint
|
||||||
useImagegen bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Scheduler struct {
|
type Scheduler struct {
|
||||||
@@ -106,7 +105,7 @@ func schedulerModelKey(m *Model) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// context must be canceled to decrement ref count and release the runner
|
// context must be canceled to decrement ref count and release the runner
|
||||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
|
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
|
||||||
if opts.NumCtx < 4 {
|
if opts.NumCtx < 4 {
|
||||||
opts.NumCtx = 4
|
opts.NumCtx = 4
|
||||||
}
|
}
|
||||||
@@ -123,7 +122,6 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
|
|||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
successCh: make(chan *runnerRef, 1),
|
successCh: make(chan *runnerRef, 1),
|
||||||
errCh: make(chan error, 1),
|
errCh: make(chan error, 1),
|
||||||
useImagegen: useImagegen,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
key := schedulerModelKey(req.model)
|
key := schedulerModelKey(req.model)
|
||||||
@@ -593,20 +591,15 @@ iGPUScan:
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadMLX loads an experimental safetensors model using the unified MLX runner.
|
// loadMLX loads an experimental safetensors model using MLX runners.
|
||||||
// This supports both LLM (completion) and image generation models.
|
// Image models use x/imagegen; LLM models use x/mlxrunner.
|
||||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||||
modelName := req.model.ShortName
|
modelName := req.model.ShortName
|
||||||
var server llm.LlamaServer
|
var server llm.LlamaServer
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
isImagegen := false
|
|
||||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||||
server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
|
server, err = imagegen.NewServer(modelName)
|
||||||
isImagegen = true
|
|
||||||
} else if req.useImagegen {
|
|
||||||
server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
|
|
||||||
isImagegen = true
|
|
||||||
} else {
|
} else {
|
||||||
server, err = mlxrunner.NewClient(modelName)
|
server, err = mlxrunner.NewClient(modelName)
|
||||||
}
|
}
|
||||||
@@ -628,7 +621,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
llama: server,
|
llama: server,
|
||||||
Options: &req.opts,
|
Options: &req.opts,
|
||||||
loading: false,
|
loading: false,
|
||||||
isImagegen: isImagegen,
|
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
|
||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
totalSize: totalSize,
|
totalSize: totalSize,
|
||||||
vramSize: vramSize,
|
vramSize: vramSize,
|
||||||
@@ -737,8 +730,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
|||||||
runner.refMu.Lock()
|
runner.refMu.Lock()
|
||||||
defer runner.refMu.Unlock()
|
defer runner.refMu.Unlock()
|
||||||
|
|
||||||
// Check if runner type (imagegen vs mlxrunner) matches what's requested
|
// Check if runner type (imagegen vs mlxrunner) matches what's requested.
|
||||||
wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
|
wantImagegen := slices.Contains(req.model.Config.Capabilities, "image")
|
||||||
if runner.isImagegen != wantImagegen {
|
if runner.isImagegen != wantImagegen {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
|
|||||||
s.getSystemInfoFn = getSystemInfoFn
|
s.getSystemInfoFn = getSystemInfoFn
|
||||||
s.newServerFn = a.newServer
|
s.newServerFn = a.newServer
|
||||||
slog.Info("a")
|
slog.Info("a")
|
||||||
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
|
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
slog.Info("b")
|
slog.Info("b")
|
||||||
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
|
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
require.Empty(t, successCh1b)
|
require.Empty(t, successCh1b)
|
||||||
require.Len(t, errCh1b, 1)
|
require.Len(t, errCh1b, 1)
|
||||||
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {
|
|||||||
|
|
||||||
c.req.model.ModelPath = "bad path"
|
c.req.model.ModelPath = "bad path"
|
||||||
slog.Info("c")
|
slog.Info("c")
|
||||||
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
|
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
|
||||||
// Starts in pending channel, then should be quickly processed to return an error
|
// Starts in pending channel, then should be quickly processed to return an error
|
||||||
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
||||||
require.Empty(t, successCh1c)
|
require.Empty(t, successCh1c)
|
||||||
@@ -470,7 +470,7 @@ func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
|
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
|
||||||
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
|
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil)
|
||||||
|
|
||||||
require.Empty(t, successCh)
|
require.Empty(t, successCh)
|
||||||
require.Empty(t, errCh)
|
require.Empty(t, errCh)
|
||||||
@@ -499,7 +499,7 @@ func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
reqCtx, cancelReq := context.WithCancel(ctx)
|
reqCtx, cancelReq := context.WithCancel(ctx)
|
||||||
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
|
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil)
|
||||||
cancelReq()
|
cancelReq()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -574,7 +574,7 @@ func TestSchedPrematureExpired(t *testing.T) {
|
|||||||
s.getGpuFn = getGpuFn
|
s.getGpuFn = getGpuFn
|
||||||
s.getSystemInfoFn = getSystemInfoFn
|
s.getSystemInfoFn = getSystemInfoFn
|
||||||
s.newServerFn = scenario1a.newServer
|
s.newServerFn = scenario1a.newServer
|
||||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
|
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
s.Run(ctx)
|
s.Run(ctx)
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -288,6 +288,18 @@ func normalizeQuantType(quantize string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isStackedExpertWeight(name string) bool {
|
||||||
|
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
|
||||||
|
// or "...proj" (pre-stacked packed tensor).
|
||||||
|
if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Contains(name, ".mlp.switch_mlp.") ||
|
||||||
|
strings.Contains(name, ".mlp.experts.") ||
|
||||||
|
strings.Contains(name, ".mlp.shared_experts.")
|
||||||
|
}
|
||||||
|
|
||||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||||
// Returns "" if the tensor should not be quantized.
|
// Returns "" if the tensor should not be quantized.
|
||||||
// This implements mixed-precision quantization:
|
// This implements mixed-precision quantization:
|
||||||
@@ -296,18 +308,25 @@ func normalizeQuantType(quantize string) string {
|
|||||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||||
// - Norms, embeddings, biases, routing gates: no quantization
|
// - Norms, embeddings, biases, routing gates: no quantization
|
||||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||||
|
stackedExpert := isStackedExpertWeight(name)
|
||||||
|
|
||||||
// Use basic name-based check first
|
// Use basic name-based check first
|
||||||
if !ShouldQuantize(name, "") {
|
if !stackedExpert && !ShouldQuantize(name, "") {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
// Quantize standard linear weights (2D). Also allow stacked expert weights (3D),
|
||||||
if len(shape) != 2 {
|
// e.g. qwen switch_mlp / experts combined tensors.
|
||||||
|
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
var elems int64 = 1
|
||||||
|
for _, d := range shape {
|
||||||
|
elems *= int64(d)
|
||||||
|
}
|
||||||
|
if elems < 1024 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -557,6 +557,10 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
|||||||
// 3D+ tensors should not be quantized
|
// 3D+ tensors should not be quantized
|
||||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
||||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
||||||
|
{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
|
||||||
|
{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
|
||||||
|
{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
|
||||||
|
{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
|
||||||
|
|
||||||
// Embeddings should not be quantized regardless of shape
|
// Embeddings should not be quantized regardless of shape
|
||||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
||||||
@@ -619,6 +623,44 @@ func TestExpertGroupPrefix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
||||||
|
gateUp := GetTensorQuantization(
|
||||||
|
"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
|
||||||
|
[]int32{64, 22016, 4096},
|
||||||
|
"int4",
|
||||||
|
)
|
||||||
|
if gateUp != "int4" {
|
||||||
|
t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
|
||||||
|
}
|
||||||
|
|
||||||
|
down := GetTensorQuantization(
|
||||||
|
"model.layers.1.mlp.experts.down_proj.weight",
|
||||||
|
[]int32{64, 4096, 14336},
|
||||||
|
"int4",
|
||||||
|
)
|
||||||
|
if down != "int8" {
|
||||||
|
t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedGateUp := GetTensorQuantization(
|
||||||
|
"model.language_model.layers.0.mlp.experts.gate_up_proj",
|
||||||
|
[]int32{256, 1024, 2048},
|
||||||
|
"int8",
|
||||||
|
)
|
||||||
|
if combinedGateUp != "int8" {
|
||||||
|
t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedDown := GetTensorQuantization(
|
||||||
|
"model.language_model.layers.0.mlp.experts.down_proj",
|
||||||
|
[]int32{256, 2048, 512},
|
||||||
|
"int4",
|
||||||
|
)
|
||||||
|
if combinedDown != "int8" {
|
||||||
|
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
|||||||
@@ -10,17 +10,7 @@ go build -tags mlx -o engine ./x/imagegen/cmd/engine
|
|||||||
|
|
||||||
## Text Generation
|
## Text Generation
|
||||||
|
|
||||||
```bash
|
Text generation models are no longer supported by this engine.
|
||||||
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
|
|
||||||
```
|
|
||||||
|
|
||||||
Options:
|
|
||||||
|
|
||||||
- `-temperature` - sampling temperature (default 0.7)
|
|
||||||
- `-top-p` - nucleus sampling (default 0.9)
|
|
||||||
- `-top-k` - top-k sampling (default 40)
|
|
||||||
|
|
||||||
Supports: Llama, Gemma3, GPT-OSS
|
|
||||||
|
|
||||||
## Image Generation
|
## Image Generation
|
||||||
|
|
||||||
|
|||||||
@@ -18,9 +18,6 @@ import (
|
|||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/x/imagegen"
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||||
)
|
)
|
||||||
@@ -170,11 +167,11 @@ func main() {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load image if provided and model supports it
|
// Load image if provided and model supports it.
|
||||||
var image *mlx.Array
|
var image *mlx.Array
|
||||||
if *imagePath != "" {
|
if *imagePath != "" {
|
||||||
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
||||||
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
|
image, err = imagegen.ProcessImage(*imagePath, mm.ImageSize())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("load image:", err)
|
log.Fatal("load image:", err)
|
||||||
}
|
}
|
||||||
@@ -236,14 +233,8 @@ func load(modelPath string) (Model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch kind {
|
switch kind {
|
||||||
case "gpt_oss":
|
|
||||||
return gpt_oss.Load(modelPath)
|
|
||||||
case "gemma3":
|
|
||||||
return gemma3.Load(modelPath)
|
|
||||||
case "gemma3_text":
|
|
||||||
return gemma3.LoadText(modelPath)
|
|
||||||
default:
|
default:
|
||||||
return llama.Load(modelPath)
|
return nil, fmt.Errorf("model type %q is not supported by x/imagegen/cmd/engine", kind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build mlx
|
//go:build mlx
|
||||||
|
|
||||||
package gemma3
|
package imagegen
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -13,8 +13,8 @@ import (
|
|||||||
"golang.org/x/image/draw"
|
"golang.org/x/image/draw"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProcessImage loads and preprocesses an image for the vision tower
|
// ProcessImage loads and preprocesses an image for multimodal vision towers.
|
||||||
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
|
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP.
|
||||||
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -30,20 +30,20 @@ func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
|||||||
return ProcessImageData(img, imageSize)
|
return ProcessImageData(img, imageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessImageData preprocesses an image.Image for the vision tower
|
// ProcessImageData preprocesses an image.Image for multimodal vision towers.
|
||||||
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
||||||
// Resize to target size using bilinear interpolation
|
// Resize to target size using bilinear interpolation.
|
||||||
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
|
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
|
||||||
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||||
|
|
||||||
// Convert to float32 array [H, W, C] and normalize
|
// Convert to float32 array [H, W, C] and normalize.
|
||||||
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
|
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0.
|
||||||
data := make([]float32, imageSize*imageSize*3)
|
data := make([]float32, imageSize*imageSize*3)
|
||||||
idx := 0
|
idx := 0
|
||||||
for y := int32(0); y < imageSize; y++ {
|
for y := int32(0); y < imageSize; y++ {
|
||||||
for x := int32(0); x < imageSize; x++ {
|
for x := int32(0); x < imageSize; x++ {
|
||||||
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
|
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
|
||||||
// RGBA returns 16-bit values, convert to 8-bit
|
// RGBA returns 16-bit values, convert to 8-bit.
|
||||||
data[idx] = float32(r>>8)/127.5 - 1.0
|
data[idx] = float32(r>>8)/127.5 - 1.0
|
||||||
data[idx+1] = float32(g>>8)/127.5 - 1.0
|
data[idx+1] = float32(g>>8)/127.5 - 1.0
|
||||||
data[idx+2] = float32(b>>8)/127.5 - 1.0
|
data[idx+2] = float32(b>>8)/127.5 - 1.0
|
||||||
@@ -51,8 +51,8 @@ func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create MLX array [1, H, W, C] for NHWC layout
|
// Create MLX array [1, H, W, C] for NHWC layout.
|
||||||
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
|
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
|
||||||
mlx.Eval(arr) // Materialize to prevent use-after-free
|
mlx.Eval(arr) // Materialize to prevent use-after-free.
|
||||||
return arr, nil
|
return arr, nil
|
||||||
}
|
}
|
||||||
@@ -1,420 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package imagegen
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/cache"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TextModel is the interface for LLM text generation models.
|
|
||||||
type TextModel interface {
|
|
||||||
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
|
|
||||||
NewCache(maxSeqLen int32) []cache.Cache
|
|
||||||
Tokenizer() *tokenizer.Tokenizer
|
|
||||||
VocabSize() int32
|
|
||||||
MaxContextLength() int32
|
|
||||||
NumLayers() int
|
|
||||||
}
|
|
||||||
|
|
||||||
// llmState holds the state for LLM generation
|
|
||||||
type llmState struct {
|
|
||||||
model TextModel
|
|
||||||
}
|
|
||||||
|
|
||||||
var llmMu sync.Mutex
|
|
||||||
|
|
||||||
// Dedicated stream for generation (like mlx-lm's generation_stream)
|
|
||||||
var generationStream *mlx.Stream
|
|
||||||
|
|
||||||
// withStream runs fn with the generation stream as default
|
|
||||||
func withStream(fn func()) {
|
|
||||||
// Lazy initialization of generationStream
|
|
||||||
if generationStream == nil {
|
|
||||||
generationStream = mlx.NewStream()
|
|
||||||
}
|
|
||||||
orig := mlx.GetDefaultStream()
|
|
||||||
mlx.SetDefaultStream(generationStream)
|
|
||||||
fn()
|
|
||||||
mlx.SetDefaultStream(orig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decoder wraps model + cache for autoregressive generation.
|
|
||||||
// This matches the pattern from cmd/engine/generate.go
|
|
||||||
type Decoder struct {
|
|
||||||
model TextModel
|
|
||||||
caches []cache.Cache
|
|
||||||
vocabSize int32
|
|
||||||
temp float32
|
|
||||||
token *mlx.Array // Current token (kept across iterations)
|
|
||||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDecoder(m TextModel, temp float32) *Decoder {
|
|
||||||
caches := m.NewCache(0)
|
|
||||||
return &Decoder{
|
|
||||||
model: m,
|
|
||||||
caches: caches,
|
|
||||||
vocabSize: m.VocabSize(),
|
|
||||||
temp: temp,
|
|
||||||
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Decoder) prefill(inputIDs []int32) int {
|
|
||||||
processed := 0
|
|
||||||
|
|
||||||
// Track old cache state to free after each chunk
|
|
||||||
var oldCacheState []*mlx.Array
|
|
||||||
|
|
||||||
// Process all-but-1 tokens in chunks, eval cache state for memory management
|
|
||||||
for len(inputIDs) > 1 {
|
|
||||||
chunkSize := min(2048, len(inputIDs)-1)
|
|
||||||
if chunkSize <= 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
chunk := inputIDs[:chunkSize]
|
|
||||||
|
|
||||||
// Save old cache state before forward
|
|
||||||
oldCacheState = oldCacheState[:0]
|
|
||||||
for _, c := range d.caches {
|
|
||||||
oldCacheState = append(oldCacheState, c.State()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cacheState []*mlx.Array
|
|
||||||
withStream(func() {
|
|
||||||
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
|
|
||||||
d.model.Forward(x, d.caches)
|
|
||||||
for _, c := range d.caches {
|
|
||||||
cacheState = append(cacheState, c.State()...)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
mlx.Eval(cacheState...)
|
|
||||||
|
|
||||||
// Free old cache state
|
|
||||||
for _, arr := range oldCacheState {
|
|
||||||
if arr != nil {
|
|
||||||
arr.Free()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inputIDs = inputIDs[chunkSize:]
|
|
||||||
processed += chunkSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save old cache state before final step
|
|
||||||
oldCacheState = oldCacheState[:0]
|
|
||||||
for _, c := range d.caches {
|
|
||||||
oldCacheState = append(oldCacheState, c.State()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final token + sampling
|
|
||||||
withStream(func() {
|
|
||||||
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
|
|
||||||
mlx.Eval(x) // Materialize before any other evals
|
|
||||||
logits := d.model.Forward(x, d.caches)
|
|
||||||
d.token = sample(logits, d.temp, d.vocabSize)
|
|
||||||
})
|
|
||||||
// Keep cache state (token auto-kept by AsyncEval)
|
|
||||||
for _, c := range d.caches {
|
|
||||||
mlx.Keep(c.State()...)
|
|
||||||
}
|
|
||||||
mlx.AsyncEval(d.token)
|
|
||||||
|
|
||||||
// Free old cache state from before final step
|
|
||||||
for _, arr := range oldCacheState {
|
|
||||||
if arr != nil {
|
|
||||||
arr.Free()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mlx.ClearCache()
|
|
||||||
|
|
||||||
return processed + len(inputIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Decoder) step() int32 {
|
|
||||||
prevToken := d.token
|
|
||||||
|
|
||||||
// Save old cache state (reuse preallocated slice)
|
|
||||||
d.oldCacheState = d.oldCacheState[:0]
|
|
||||||
for _, c := range d.caches {
|
|
||||||
d.oldCacheState = append(d.oldCacheState, c.State()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
withStream(func() {
|
|
||||||
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
|
||||||
d.token = sample(logits, d.temp, d.vocabSize)
|
|
||||||
})
|
|
||||||
// Keep token and new cache state so they survive cleanup
|
|
||||||
mlx.Keep(d.token)
|
|
||||||
for _, c := range d.caches {
|
|
||||||
mlx.Keep(c.State()...)
|
|
||||||
}
|
|
||||||
mlx.AsyncEval(d.token)
|
|
||||||
|
|
||||||
// Sync on previous token (GPU already working on next step)
|
|
||||||
val := prevToken.ItemInt32()
|
|
||||||
|
|
||||||
// Free old token and old cache state
|
|
||||||
prevToken.Free()
|
|
||||||
for _, arr := range d.oldCacheState {
|
|
||||||
arr.Free()
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
// sample samples from logits using temperature scaling
|
|
||||||
func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
|
|
||||||
// Get last position logits: [1, L, vocab] -> [vocab]
|
|
||||||
shape := logits.Shape()
|
|
||||||
seqLen := shape[1]
|
|
||||||
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
|
|
||||||
lastLogits = mlx.Reshape(lastLogits, vocabSize)
|
|
||||||
|
|
||||||
if temp <= 0 || temp < 0.01 {
|
|
||||||
// Greedy decoding
|
|
||||||
return mlx.Argmax(lastLogits, -1, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply temperature scaling
|
|
||||||
scaled := mlx.DivScalar(lastLogits, temp)
|
|
||||||
return mlx.RandomCategorical(scaled, -1, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
|
|
||||||
func (s *server) loadLLMModel() error {
|
|
||||||
// Load the manifest to get model information
|
|
||||||
modelManifest, err := manifest.LoadManifest(s.modelName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to load manifest: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect model architecture from config.json
|
|
||||||
configData, err := modelManifest.ReadConfig("config.json")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read config.json: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var modelConfig struct {
|
|
||||||
Architectures []string `json:"architectures"`
|
|
||||||
ModelType string `json:"model_type"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(configData, &modelConfig); err != nil {
|
|
||||||
return fmt.Errorf("failed to parse config.json: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
arch := ""
|
|
||||||
if len(modelConfig.Architectures) > 0 {
|
|
||||||
arch = modelConfig.Architectures[0]
|
|
||||||
}
|
|
||||||
if arch == "" {
|
|
||||||
arch = modelConfig.ModelType
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType)
|
|
||||||
|
|
||||||
// Load the appropriate model based on architecture
|
|
||||||
var model TextModel
|
|
||||||
archLower := strings.ToLower(arch)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case strings.Contains(archLower, "glm4moelite"):
|
|
||||||
m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
|
|
||||||
}
|
|
||||||
model = m
|
|
||||||
slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers())
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("LLM architecture %q is not yet supported. "+
|
|
||||||
"Supported architectures: glm4-moe-lite. "+
|
|
||||||
"Please convert your model to GGUF format or use a supported architecture", arch)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.llmModel = &llmState{
|
|
||||||
model: model,
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleLLMCompletion handles LLM text generation requests.
|
|
||||||
func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
|
|
||||||
if s.llmModel == nil {
|
|
||||||
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize generation requests
|
|
||||||
llmMu.Lock()
|
|
||||||
defer llmMu.Unlock()
|
|
||||||
|
|
||||||
if err := s.llmGenerate(w, r, req); err != nil {
|
|
||||||
slog.Error("LLM generation failed", "error", err)
|
|
||||||
// Don't send error if we've already started streaming
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine
|
|
||||||
func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
|
|
||||||
state := s.llmModel
|
|
||||||
|
|
||||||
// Set up streaming response
|
|
||||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
|
||||||
w.Header().Set("Transfer-Encoding", "chunked")
|
|
||||||
flusher, ok := w.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("streaming not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
tok := state.model.Tokenizer()
|
|
||||||
|
|
||||||
// The prompt is already formatted by the server using the model's renderer
|
|
||||||
// (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
|
|
||||||
prompt := req.Prompt
|
|
||||||
|
|
||||||
// Tokenize the prompt
|
|
||||||
inputIDs := tok.Encode(prompt, true)
|
|
||||||
slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))
|
|
||||||
|
|
||||||
// Generation parameters
|
|
||||||
maxTokens := int(state.model.MaxContextLength())
|
|
||||||
if maxTokens <= 0 {
|
|
||||||
maxTokens = 4096
|
|
||||||
}
|
|
||||||
if req.Options != nil && req.Options.NumPredict > 0 {
|
|
||||||
maxTokens = req.Options.NumPredict
|
|
||||||
}
|
|
||||||
|
|
||||||
temperature := float32(0.7)
|
|
||||||
if req.Options != nil && req.Options.Temperature > 0 {
|
|
||||||
temperature = float32(req.Options.Temperature)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable MLX compilation for better performance
|
|
||||||
mlx.EnableCompile()
|
|
||||||
|
|
||||||
// Create decoder with fresh caches
|
|
||||||
dec := NewDecoder(state.model, temperature)
|
|
||||||
|
|
||||||
prefillStart := time.Now()
|
|
||||||
prefillTokens := dec.prefill(inputIDs)
|
|
||||||
// Prefill measurement includes time to first token
|
|
||||||
firstToken := dec.step()
|
|
||||||
prefillDuration := time.Since(prefillStart)
|
|
||||||
promptEvalDuration := prefillDuration
|
|
||||||
|
|
||||||
enc := json.NewEncoder(w)
|
|
||||||
ctx := r.Context()
|
|
||||||
generated := 0
|
|
||||||
stopReason := "max_tokens"
|
|
||||||
|
|
||||||
// Handle first token
|
|
||||||
generated++
|
|
||||||
if tok.IsEOS(firstToken) {
|
|
||||||
resp := Response{
|
|
||||||
Done: true,
|
|
||||||
StopReason: fmt.Sprintf("first_token_eos:%d", firstToken),
|
|
||||||
PromptEvalCount: prefillTokens,
|
|
||||||
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
|
|
||||||
}
|
|
||||||
enc.Encode(resp)
|
|
||||||
flusher.Flush()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
text := tok.Decode([]int32{firstToken})
|
|
||||||
resp := Response{Content: text}
|
|
||||||
enc.Encode(resp)
|
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
genStart := time.Now()
|
|
||||||
|
|
||||||
// Generation loop
|
|
||||||
for n := 1; n < maxTokens; n++ {
|
|
||||||
// Check for cancellation
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
stopReason = fmt.Sprintf("context_cancelled:%d", generated)
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if stopReason != "max_tokens" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
token := dec.step()
|
|
||||||
generated++
|
|
||||||
|
|
||||||
if tok.IsEOS(token) {
|
|
||||||
stopReason = fmt.Sprintf("eos_token:%d", token)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
text := tok.Decode([]int32{token})
|
|
||||||
|
|
||||||
// Check for stop sequences
|
|
||||||
if req.Options != nil && len(req.Options.Stop) > 0 {
|
|
||||||
shouldStop := false
|
|
||||||
var matchedStop string
|
|
||||||
for _, stop := range req.Options.Stop {
|
|
||||||
if strings.Contains(text, stop) {
|
|
||||||
text = strings.Split(text, stop)[0]
|
|
||||||
shouldStop = true
|
|
||||||
matchedStop = stop
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if shouldStop {
|
|
||||||
if text != "" {
|
|
||||||
resp := Response{Content: text}
|
|
||||||
enc.Encode(resp)
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := Response{Content: text}
|
|
||||||
enc.Encode(resp)
|
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
// Periodically clear MLX cache
|
|
||||||
if n%256 == 0 {
|
|
||||||
mlx.ClearCache()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
mlx.ClearCache()
|
|
||||||
|
|
||||||
// Send final response with stats
|
|
||||||
evalDuration := time.Since(genStart)
|
|
||||||
resp = Response{
|
|
||||||
Done: true,
|
|
||||||
StopReason: fmt.Sprintf("%s:generated=%d", stopReason, generated),
|
|
||||||
PromptEvalCount: prefillTokens,
|
|
||||||
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
|
|
||||||
EvalCount: generated,
|
|
||||||
EvalDuration: int(evalDuration.Nanoseconds()),
|
|
||||||
}
|
|
||||||
enc.Encode(resp)
|
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,614 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package gemma3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/cache"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/nn"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TextConfig holds configuration for the text model
|
|
||||||
type TextConfig struct {
|
|
||||||
HiddenSize int32 `json:"hidden_size"`
|
|
||||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
|
||||||
IntermediateSize int32 `json:"intermediate_size"`
|
|
||||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
|
||||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
|
||||||
HeadDim int32 `json:"head_dim"`
|
|
||||||
VocabSize int32 `json:"vocab_size"`
|
|
||||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
|
||||||
RopeTheta float32 `json:"rope_theta"`
|
|
||||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
|
||||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
|
||||||
SlidingWindow int32 `json:"sliding_window"`
|
|
||||||
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
|
|
||||||
|
|
||||||
// Computed fields
|
|
||||||
Scale float32 `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TextModel is the Gemma 3 text-only model
|
|
||||||
type TextModel struct {
|
|
||||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
|
||||||
Layers []*DecoderLayer `weight:"model.layers"`
|
|
||||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
|
||||||
Output *nn.Linear `weight:"-"` // Tied to EmbedTokens, set manually
|
|
||||||
|
|
||||||
// Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward
|
|
||||||
NormScaled *mlx.Array `weight:"-"`
|
|
||||||
|
|
||||||
tok *tokenizer.Tokenizer
|
|
||||||
*TextConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecoderLayer is a single transformer block
|
|
||||||
type DecoderLayer struct {
|
|
||||||
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
|
|
||||||
Attention *Attention
|
|
||||||
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
|
||||||
PreFFNorm *nn.RMSNorm `weight:"pre_feedforward_layernorm"`
|
|
||||||
MLP *MLP
|
|
||||||
PostFFNorm *nn.RMSNorm `weight:"post_feedforward_layernorm"`
|
|
||||||
|
|
||||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
|
||||||
InputNormScaled *mlx.Array `weight:"-"`
|
|
||||||
PostAttnNormScaled *mlx.Array `weight:"-"`
|
|
||||||
PreFFNormScaled *mlx.Array `weight:"-"`
|
|
||||||
PostFFNormScaled *mlx.Array `weight:"-"`
|
|
||||||
|
|
||||||
// Whether this layer uses sliding window attention
|
|
||||||
IsSliding bool
|
|
||||||
LayerIdx int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attention implements Gemma 3 attention with Q/K normalization
|
|
||||||
type Attention struct {
|
|
||||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
|
||||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
|
||||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
|
||||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
|
||||||
QNorm *nn.RMSNorm `weight:"self_attn.q_norm"`
|
|
||||||
KNorm *nn.RMSNorm `weight:"self_attn.k_norm"`
|
|
||||||
|
|
||||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
|
||||||
QNormScaled *mlx.Array `weight:"-"`
|
|
||||||
KNormScaled *mlx.Array `weight:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// MLP is the feed-forward network with GELU activation
|
|
||||||
type MLP struct {
|
|
||||||
GateProj *nn.Linear `weight:"mlp.gate_proj"`
|
|
||||||
UpProj *nn.Linear `weight:"mlp.up_proj"`
|
|
||||||
DownProj *nn.Linear `weight:"mlp.down_proj"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadText loads the text-only Gemma 3 model
|
|
||||||
func LoadText(modelPath string) (*TextModel, error) {
|
|
||||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load config: %w", err)
|
|
||||||
}
|
|
||||||
var cfg TextConfig
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute scale
|
|
||||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
|
||||||
|
|
||||||
// Set defaults if not specified
|
|
||||||
if cfg.RopeTheta == 0 {
|
|
||||||
cfg.RopeTheta = 1000000
|
|
||||||
}
|
|
||||||
if cfg.RopeLocalBaseFreq == 0 {
|
|
||||||
cfg.RopeLocalBaseFreq = 10000
|
|
||||||
}
|
|
||||||
if cfg.RMSNormEps == 0 {
|
|
||||||
cfg.RMSNormEps = 1e-6
|
|
||||||
}
|
|
||||||
|
|
||||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load weights: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m := &TextModel{
|
|
||||||
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
|
|
||||||
TextConfig: &cfg,
|
|
||||||
tok: tok,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize layer metadata
|
|
||||||
for i := range m.Layers {
|
|
||||||
m.Layers[i] = &DecoderLayer{
|
|
||||||
LayerIdx: int32(i),
|
|
||||||
IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tied embeddings for output
|
|
||||||
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
|
|
||||||
|
|
||||||
mlx.Eval(mlx.Collect(m)...)
|
|
||||||
weights.ReleaseAll()
|
|
||||||
|
|
||||||
// Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation
|
|
||||||
precomputeGemmaScaledWeights(m)
|
|
||||||
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers
|
|
||||||
// This avoids creating temporary arrays on every forward pass
|
|
||||||
func precomputeGemmaScaledWeights(m *TextModel) {
|
|
||||||
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
|
||||||
|
|
||||||
for _, layer := range m.Layers {
|
|
||||||
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
|
||||||
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
|
||||||
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
|
||||||
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
|
||||||
|
|
||||||
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
|
||||||
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Eval all the precomputed weights
|
|
||||||
var scaled []*mlx.Array
|
|
||||||
scaled = append(scaled, m.NormScaled)
|
|
||||||
for _, layer := range m.Layers {
|
|
||||||
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
|
|
||||||
layer.PreFFNormScaled, layer.PostFFNormScaled,
|
|
||||||
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
|
|
||||||
}
|
|
||||||
mlx.Eval(scaled...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// isLayerSliding determines if a layer uses sliding window attention
|
|
||||||
// Pattern N means: layers 0 to N-1 sliding, N full, N+1 to 2N-1 sliding, 2N full, etc.
|
|
||||||
func isLayerSliding(layerIdx, pattern int32) bool {
|
|
||||||
if pattern <= 0 {
|
|
||||||
return false // No sliding window
|
|
||||||
}
|
|
||||||
// Layer is full attention if (layerIdx + 1) % pattern == 0
|
|
||||||
return (layerIdx+1)%pattern != 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward runs the text model forward pass
|
|
||||||
func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
|
||||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
|
||||||
|
|
||||||
// Get embeddings and scale by sqrt(hidden_size)
|
|
||||||
h := m.EmbedTokens.Forward(tokens)
|
|
||||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
|
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
|
||||||
h = layer.Forward(h, caches[i], B, L, m.TextConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final norm and output projection
|
|
||||||
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward runs a decoder layer
|
|
||||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
|
||||||
// Pre-attention norm (use precomputed scaled weight)
|
|
||||||
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
|
|
||||||
|
|
||||||
// Attention
|
|
||||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
|
|
||||||
|
|
||||||
// Post-attention norm and residual
|
|
||||||
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
|
||||||
h := mlx.Add(x, attnOut)
|
|
||||||
|
|
||||||
// Pre-FFN norm
|
|
||||||
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
|
||||||
|
|
||||||
// MLP
|
|
||||||
mlpOut := l.MLP.Forward(normed)
|
|
||||||
|
|
||||||
// Post-FFN norm and residual
|
|
||||||
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
|
||||||
return mlx.Add(h, mlpOut)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward runs attention with Q/K normalization
|
|
||||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
|
||||||
q := a.QProj.Forward(x)
|
|
||||||
k := a.KProj.Forward(x)
|
|
||||||
v := a.VProj.Forward(x)
|
|
||||||
|
|
||||||
// Reshape to [B, num_heads, L, head_dim]
|
|
||||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
|
||||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
|
||||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
|
||||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
|
||||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
|
||||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
|
||||||
|
|
||||||
// Q/K normalization after reshaping (use precomputed scaled weight)
|
|
||||||
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
|
|
||||||
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
|
|
||||||
|
|
||||||
// Apply RoPE with appropriate theta
|
|
||||||
ropeTheta := cfg.RopeTheta
|
|
||||||
if isSliding {
|
|
||||||
ropeTheta = cfg.RopeLocalBaseFreq
|
|
||||||
}
|
|
||||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
|
||||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
|
||||||
|
|
||||||
// Update cache
|
|
||||||
k, v = c.Update(k, v, int(L))
|
|
||||||
|
|
||||||
// Repeat K/V for GQA if needed
|
|
||||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
|
||||||
if repeatFactor > 1 {
|
|
||||||
k = nn.RepeatKV(k, repeatFactor)
|
|
||||||
v = nn.RepeatKV(v, repeatFactor)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attention
|
|
||||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
|
||||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
|
||||||
return a.OProj.Forward(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// compiledGeluApprox is a singleton compiled GELU function shared across all layers
|
|
||||||
var compiledGeluApprox *mlx.CompiledFunc
|
|
||||||
|
|
||||||
// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed
|
|
||||||
func getCompiledGeluApprox() *mlx.CompiledFunc {
|
|
||||||
if compiledGeluApprox == nil {
|
|
||||||
compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
|
||||||
return []*mlx.Array{geluApproxImpl(inputs[0])}
|
|
||||||
}, true)
|
|
||||||
}
|
|
||||||
return compiledGeluApprox
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward runs the MLP with GELU approximation (tanh variant)
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0]
|
|
||||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh):
|
|
||||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
|
||||||
func geluApproxImpl(x *mlx.Array) *mlx.Array {
|
|
||||||
// Constants
|
|
||||||
const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi)
|
|
||||||
const coeff = 0.044715
|
|
||||||
|
|
||||||
// x^3
|
|
||||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
|
||||||
// x + 0.044715 * x^3
|
|
||||||
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
|
|
||||||
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
|
||||||
scaled := mlx.MulScalar(inner, sqrt2OverPi)
|
|
||||||
// tanh(...)
|
|
||||||
tanh := mlx.Tanh(scaled)
|
|
||||||
// 1 + tanh(...)
|
|
||||||
onePlusTanh := mlx.AddScalar(tanh, 1.0)
|
|
||||||
// 0.5 * x * (1 + tanh(...))
|
|
||||||
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh)
|
|
||||||
}
|
|
||||||
|
|
||||||
// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight)
|
|
||||||
// Uses mlx.RMSNorm fast kernel with pre-computed (1 + weight)
|
|
||||||
func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array {
|
|
||||||
// Gemma uses (1 + weight) instead of weight
|
|
||||||
scaledWeight := mlx.AddScalar(weight, 1.0)
|
|
||||||
return mlx.RMSNorm(x, scaledWeight, eps)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Interface methods
|
|
||||||
func (m *TextModel) NumLayers() int { return len(m.Layers) }
|
|
||||||
func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
|
||||||
func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize }
|
|
||||||
|
|
||||||
// Tokenizer returns the tokenizer wrapped to add BOS and apply chat template
|
|
||||||
func (m *TextModel) Tokenizer() *tokenizer.Tokenizer {
|
|
||||||
return m.tok
|
|
||||||
}
|
|
||||||
|
|
||||||
// FormatPrompt applies the Gemma 3 chat template to a prompt
|
|
||||||
func (m *TextModel) FormatPrompt(prompt string) string {
|
|
||||||
// Gemma 3 chat format: <start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n
|
|
||||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache {
|
|
||||||
caches := make([]cache.Cache, len(m.Layers))
|
|
||||||
for i := range caches {
|
|
||||||
if m.Layers[i].IsSliding {
|
|
||||||
// Use rotating cache for sliding window layers
|
|
||||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
|
||||||
} else {
|
|
||||||
// Use regular cache for global attention layers
|
|
||||||
caches[i] = cache.NewKVCache()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return caches
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config holds config for the full multimodal model
|
|
||||||
type Config struct {
|
|
||||||
TextConfig TextConfig `json:"text_config"`
|
|
||||||
VisionConfig VisionConfig `json:"vision_config"`
|
|
||||||
|
|
||||||
// Image token config (from config.json)
|
|
||||||
BOITokenIndex int32 `json:"boi_token_index"` // <start_of_image> = 255999
|
|
||||||
EOITokenIndex int32 `json:"eoi_token_index"` // <end_of_image> = 256000
|
|
||||||
ImageTokenIndex int32 `json:"image_token_index"` // <image_soft_token> = 262144
|
|
||||||
MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model is the full Gemma 3 multimodal model
|
|
||||||
type Model struct {
|
|
||||||
VisionTower *VisionTower `weight:"vision_tower"`
|
|
||||||
Projector *MultiModalProjector `weight:"multi_modal_projector"`
|
|
||||||
TextModel *TextModel `weight:"language_model"`
|
|
||||||
Config *Config
|
|
||||||
tok *tokenizer.Tokenizer
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load loads the full multimodal Gemma 3 model
|
|
||||||
func Load(modelPath string) (*Model, error) {
|
|
||||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cfg Config
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set defaults for text config (multimodal config often has incomplete text_config)
|
|
||||||
// These defaults match transformers.Gemma3TextConfig defaults
|
|
||||||
tc := &cfg.TextConfig
|
|
||||||
if tc.HeadDim == 0 {
|
|
||||||
tc.HeadDim = 256 // Gemma 3 uses head_dim=256
|
|
||||||
}
|
|
||||||
if tc.NumAttentionHeads == 0 {
|
|
||||||
// Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim)
|
|
||||||
tc.NumAttentionHeads = 8
|
|
||||||
}
|
|
||||||
if tc.NumKeyValueHeads == 0 {
|
|
||||||
// Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio)
|
|
||||||
tc.NumKeyValueHeads = 4
|
|
||||||
}
|
|
||||||
if tc.VocabSize == 0 {
|
|
||||||
tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!)
|
|
||||||
}
|
|
||||||
if tc.RopeTheta == 0 {
|
|
||||||
tc.RopeTheta = 1000000
|
|
||||||
}
|
|
||||||
if tc.RopeLocalBaseFreq == 0 {
|
|
||||||
tc.RopeLocalBaseFreq = 10000
|
|
||||||
}
|
|
||||||
if tc.RMSNormEps == 0 {
|
|
||||||
tc.RMSNormEps = 1e-6
|
|
||||||
}
|
|
||||||
if tc.SlidingWindowPattern == 0 {
|
|
||||||
tc.SlidingWindowPattern = 6
|
|
||||||
}
|
|
||||||
if tc.MaxPositionEmbeddings == 0 {
|
|
||||||
tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute text model scale
|
|
||||||
tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim)))
|
|
||||||
|
|
||||||
// Set defaults for image token config
|
|
||||||
if cfg.BOITokenIndex == 0 {
|
|
||||||
cfg.BOITokenIndex = 255999 // <start_of_image>
|
|
||||||
}
|
|
||||||
if cfg.EOITokenIndex == 0 {
|
|
||||||
cfg.EOITokenIndex = 256000 // <end_of_image>
|
|
||||||
}
|
|
||||||
if cfg.ImageTokenIndex == 0 {
|
|
||||||
cfg.ImageTokenIndex = 262144 // <image_soft_token>
|
|
||||||
}
|
|
||||||
if cfg.MMTokensPerImage == 0 {
|
|
||||||
cfg.MMTokensPerImage = 256
|
|
||||||
}
|
|
||||||
|
|
||||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load weights: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m := &Model{
|
|
||||||
VisionTower: &VisionTower{
|
|
||||||
Embeddings: &VisionEmbeddings{},
|
|
||||||
Encoder: make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers),
|
|
||||||
Config: &cfg.VisionConfig,
|
|
||||||
},
|
|
||||||
Projector: &MultiModalProjector{},
|
|
||||||
TextModel: &TextModel{
|
|
||||||
Layers: make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers),
|
|
||||||
TextConfig: &cfg.TextConfig,
|
|
||||||
},
|
|
||||||
Config: &cfg,
|
|
||||||
tok: tok,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize text layer metadata
|
|
||||||
for i := range m.TextModel.Layers {
|
|
||||||
m.TextModel.Layers[i] = &DecoderLayer{
|
|
||||||
LayerIdx: int32(i),
|
|
||||||
IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize vision encoder layers
|
|
||||||
for i := range m.VisionTower.Encoder {
|
|
||||||
m.VisionTower.Encoder[i] = &VisionEncoderLayer{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tied embeddings for text output
|
|
||||||
m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil)
|
|
||||||
m.TextModel.tok = tok
|
|
||||||
|
|
||||||
mlx.Eval(mlx.Collect(m)...)
|
|
||||||
weights.ReleaseAll()
|
|
||||||
|
|
||||||
// Precompute (1 + weight) for Gemma-style RMSNorm
|
|
||||||
precomputeGemmaScaledWeights(m.TextModel)
|
|
||||||
|
|
||||||
// Precompute projector's scaled weight
|
|
||||||
m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0)
|
|
||||||
mlx.Eval(m.Projector.SoftEmbNormScaled)
|
|
||||||
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward runs the text-only forward pass
|
|
||||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
|
||||||
return m.TextModel.Forward(tokens, caches)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardWithImage runs the multimodal forward pass
|
|
||||||
// tokens: [B, L] input token IDs (with image placeholder tokens)
|
|
||||||
// image: [B, H, W, C] preprocessed image tensor
|
|
||||||
func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array {
|
|
||||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
|
||||||
cfg := m.Config.TextConfig
|
|
||||||
|
|
||||||
// Find image token position FIRST before any eval that might free tokens
|
|
||||||
imageStartPos := int32(-1)
|
|
||||||
if image != nil && B == 1 {
|
|
||||||
tokenData := tokens.DataInt32() // This evals tokens
|
|
||||||
for i, t := range tokenData {
|
|
||||||
if t == m.Config.ImageTokenIndex {
|
|
||||||
imageStartPos = int32(i)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get text embeddings and scale
|
|
||||||
h := m.TextModel.EmbedTokens.Forward(tokens)
|
|
||||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize))))
|
|
||||||
|
|
||||||
// Process image if provided
|
|
||||||
if image != nil && imageStartPos >= 0 {
|
|
||||||
// Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden]
|
|
||||||
visionFeatures := m.VisionTower.Forward(image)
|
|
||||||
|
|
||||||
// Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden]
|
|
||||||
imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps)
|
|
||||||
|
|
||||||
// Eval h and imageEmbeds together so neither gets freed
|
|
||||||
mlx.Eval(h, imageEmbeds)
|
|
||||||
|
|
||||||
// Cast imageEmbeds to match text embeddings dtype (bf16)
|
|
||||||
if imageEmbeds.Dtype() != h.Dtype() {
|
|
||||||
imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype())
|
|
||||||
mlx.Eval(imageEmbeds)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert image embeddings at the known position
|
|
||||||
h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run through text model layers
|
|
||||||
for i, layer := range m.TextModel.Layers {
|
|
||||||
h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final norm and output projection
|
|
||||||
return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps))
|
|
||||||
}
|
|
||||||
|
|
||||||
// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings
|
|
||||||
// at a known position (to avoid re-scanning tokens after eval)
|
|
||||||
// textEmbeds: [B, L, hidden_size] text embeddings
|
|
||||||
// imageEmbeds: [B, 256, hidden_size] image embeddings from projector
|
|
||||||
// startPos: starting position of image tokens in the sequence
|
|
||||||
func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array {
|
|
||||||
numImageTokens := imageEmbeds.Shape()[1]
|
|
||||||
L := textEmbeds.Shape()[1]
|
|
||||||
|
|
||||||
// Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L]
|
|
||||||
afterStart := startPos + numImageTokens
|
|
||||||
|
|
||||||
// Slice before image tokens: textEmbeds[:, 0:startPos, :]
|
|
||||||
before := mlx.SliceAxis(textEmbeds, 1, 0, startPos)
|
|
||||||
|
|
||||||
// Slice after image tokens: textEmbeds[:, startPos+256:L, :]
|
|
||||||
after := mlx.SliceAxis(textEmbeds, 1, afterStart, L)
|
|
||||||
|
|
||||||
// Concatenate: before + imageEmbeds + after along axis 1
|
|
||||||
return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Interface methods for Model
|
|
||||||
func (m *Model) NumLayers() int { return len(m.TextModel.Layers) }
|
|
||||||
func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings }
|
|
||||||
func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize }
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
|
||||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) }
|
|
||||||
func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize }
|
|
||||||
|
|
||||||
// FormatPrompt applies the Gemma 3 multimodal chat template
|
|
||||||
func (m *Model) FormatPrompt(prompt string) string {
|
|
||||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FormatPromptWithImage applies the Gemma 3 multimodal chat template with image
|
|
||||||
func (m *Model) FormatPromptWithImage(prompt string) string {
|
|
||||||
return fmt.Sprintf("<start_of_turn>user\n<start_of_image>%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExpandImageTokens expands <start_of_image> into 256 image placeholder tokens
|
|
||||||
// Input tokens containing boi_token (255999) are expanded to:
|
|
||||||
// boi_token + 256 * image_token + eoi_token
|
|
||||||
func (m *Model) ExpandImageTokens(tokens []int32) []int32 {
|
|
||||||
result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1)
|
|
||||||
|
|
||||||
for _, t := range tokens {
|
|
||||||
if t == m.Config.BOITokenIndex {
|
|
||||||
// Expand: boi + 256 * image_token + eoi
|
|
||||||
result = append(result, m.Config.BOITokenIndex)
|
|
||||||
for i := int32(0); i < m.Config.MMTokensPerImage; i++ {
|
|
||||||
result = append(result, m.Config.ImageTokenIndex)
|
|
||||||
}
|
|
||||||
result = append(result, m.Config.EOITokenIndex)
|
|
||||||
} else {
|
|
||||||
result = append(result, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package gemma3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/nn"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MultiModalProjector projects vision features to text embedding space
|
|
||||||
type MultiModalProjector struct {
|
|
||||||
// mm_input_projection_weight: [vision_hidden, text_hidden]
|
|
||||||
InputProjection *mlx.Array `weight:"mm_input_projection_weight"`
|
|
||||||
SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"`
|
|
||||||
|
|
||||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
|
||||||
SoftEmbNormScaled *mlx.Array `weight:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward projects vision features to text space
|
|
||||||
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
|
|
||||||
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
|
|
||||||
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
|
|
||||||
// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
|
|
||||||
// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
|
|
||||||
B := visionFeatures.Shape()[0]
|
|
||||||
visionHidden := visionFeatures.Shape()[2]
|
|
||||||
|
|
||||||
// Reshape to [B, 64, 64, hidden]
|
|
||||||
gridSize := int32(64) // sqrt(4096)
|
|
||||||
pooledSize := int32(16) // 64/4
|
|
||||||
h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)
|
|
||||||
|
|
||||||
// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
|
|
||||||
h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)
|
|
||||||
|
|
||||||
// Average over pooling dimensions (axes 2 and 4)
|
|
||||||
h = mlx.Mean(h, 4, false)
|
|
||||||
h = mlx.Mean(h, 2, false)
|
|
||||||
|
|
||||||
// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
|
|
||||||
numTokens := pooledSize * pooledSize
|
|
||||||
h = mlx.Reshape(h, B, numTokens, visionHidden)
|
|
||||||
|
|
||||||
// Apply Gemma-style RMS norm (use precomputed 1 + weight)
|
|
||||||
h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)
|
|
||||||
|
|
||||||
// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
|
|
||||||
return mlx.Linear(h, p.InputProjection)
|
|
||||||
}
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package gemma3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/nn"
|
|
||||||
)
|
|
||||||
|
|
||||||
// VisionConfig holds configuration for the SigLIP vision tower
|
|
||||||
type VisionConfig struct {
|
|
||||||
HiddenSize int32 `json:"hidden_size"`
|
|
||||||
ImageSize int32 `json:"image_size"`
|
|
||||||
IntermediateSize int32 `json:"intermediate_size"`
|
|
||||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
|
||||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
|
||||||
PatchSize int32 `json:"patch_size"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// VisionTower is the SigLIP vision encoder
|
|
||||||
type VisionTower struct {
|
|
||||||
Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"`
|
|
||||||
Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
|
|
||||||
PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"`
|
|
||||||
Config *VisionConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// VisionEmbeddings handles patch and position embeddings
|
|
||||||
type VisionEmbeddings struct {
|
|
||||||
// PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX
|
|
||||||
PatchWeight *mlx.Array `weight:"patch_embedding.weight"`
|
|
||||||
PatchBias *mlx.Array `weight:"patch_embedding.bias"`
|
|
||||||
PosEmbed *nn.Embedding `weight:"position_embedding"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// VisionEncoderLayer is a single transformer encoder layer
|
|
||||||
type VisionEncoderLayer struct {
|
|
||||||
LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"`
|
|
||||||
Attention *VisionAttention `weight:"self_attn"`
|
|
||||||
LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"`
|
|
||||||
MLP *VisionMLP `weight:"mlp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// VisionAttention implements multi-head self-attention
|
|
||||||
type VisionAttention struct {
|
|
||||||
QProj *nn.Linear `weight:"q_proj"`
|
|
||||||
KProj *nn.Linear `weight:"k_proj"`
|
|
||||||
VProj *nn.Linear `weight:"v_proj"`
|
|
||||||
OutProj *nn.Linear `weight:"out_proj"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// VisionMLP is the feed-forward network
|
|
||||||
type VisionMLP struct {
|
|
||||||
FC1 *nn.Linear `weight:"fc1"`
|
|
||||||
FC2 *nn.Linear `weight:"fc2"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward runs the vision tower on preprocessed images
|
|
||||||
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
|
|
||||||
// Output: [B, num_patches, hidden_size]
|
|
||||||
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
|
|
||||||
// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
|
|
||||||
weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
|
|
||||||
h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding
|
|
||||||
|
|
||||||
// Add bias: [O] -> [1, 1, 1, O] for broadcasting
|
|
||||||
bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
|
|
||||||
h = mlx.Add(h, bias)
|
|
||||||
|
|
||||||
// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
|
|
||||||
B := h.Shape()[0]
|
|
||||||
gridH, gridW := h.Shape()[1], h.Shape()[2]
|
|
||||||
hidden := h.Shape()[3]
|
|
||||||
numPatches := gridH * gridW
|
|
||||||
h = mlx.Reshape(h, B, numPatches, hidden)
|
|
||||||
|
|
||||||
// Add position embeddings
|
|
||||||
posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
|
|
||||||
posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
|
|
||||||
h = mlx.Add(h, posEmbed)
|
|
||||||
|
|
||||||
// Encoder layers
|
|
||||||
headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
|
|
||||||
scale := float32(1.0 / math.Sqrt(float64(headDim)))
|
|
||||||
for _, layer := range v.Encoder {
|
|
||||||
h = layer.Forward(h, v.Config, scale)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final layer norm
|
|
||||||
h = v.PostLayerNorm.Forward(h)
|
|
||||||
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward runs a vision encoder layer
|
|
||||||
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
|
|
||||||
// Pre-norm attention
|
|
||||||
h := l.LayerNorm1.Forward(x)
|
|
||||||
h = l.Attention.Forward(h, cfg, scale)
|
|
||||||
x = mlx.Add(x, h)
|
|
||||||
|
|
||||||
// Pre-norm MLP
|
|
||||||
h = l.LayerNorm2.Forward(x)
|
|
||||||
h = l.MLP.Forward(h)
|
|
||||||
return mlx.Add(x, h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward runs multi-head self-attention
|
|
||||||
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
|
|
||||||
B, L := x.Shape()[0], x.Shape()[1]
|
|
||||||
headDim := cfg.HiddenSize / cfg.NumAttentionHeads
|
|
||||||
|
|
||||||
q := a.QProj.Forward(x)
|
|
||||||
k := a.KProj.Forward(x)
|
|
||||||
v := a.VProj.Forward(x)
|
|
||||||
|
|
||||||
// Reshape to [B, num_heads, L, head_dim]
|
|
||||||
q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
|
||||||
k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
|
||||||
v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
|
||||||
|
|
||||||
// Scaled dot-product attention (no causal mask for vision)
|
|
||||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
|
||||||
|
|
||||||
// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
|
|
||||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)
|
|
||||||
|
|
||||||
return a.OutProj.Forward(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward runs the MLP with GELU activation
|
|
||||||
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
h := mlx.GELU(m.FC1.Forward(x))
|
|
||||||
return m.FC2.Forward(h)
|
|
||||||
}
|
|
||||||
@@ -1,840 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
|
|
||||||
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
|
|
||||||
package glm4_moe_lite
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/cache"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/nn"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RopeScaling holds RoPE scaling configuration
|
|
||||||
type RopeScaling struct {
|
|
||||||
Factor float32 `json:"factor"`
|
|
||||||
MscaleAllDim float32 `json:"mscale_all_dim"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config holds GLM4-MoE-Lite model configuration
|
|
||||||
type Config struct {
|
|
||||||
HiddenSize int32 `json:"hidden_size"`
|
|
||||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
|
||||||
IntermediateSize int32 `json:"intermediate_size"`
|
|
||||||
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
|
|
||||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
|
||||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
|
||||||
VocabSize int32 `json:"vocab_size"`
|
|
||||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
|
||||||
RopeTheta float32 `json:"rope_theta"`
|
|
||||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
|
||||||
AttentionBias bool `json:"attention_bias"`
|
|
||||||
|
|
||||||
// MLA (Multi-head Latent Attention) parameters
|
|
||||||
QLoraRank int32 `json:"q_lora_rank"`
|
|
||||||
KVLoraRank int32 `json:"kv_lora_rank"`
|
|
||||||
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
|
||||||
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
|
||||||
VHeadDim int32 `json:"v_head_dim"`
|
|
||||||
|
|
||||||
// MoE parameters
|
|
||||||
NRoutedExperts int32 `json:"n_routed_experts"`
|
|
||||||
NSharedExperts int32 `json:"n_shared_experts"`
|
|
||||||
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
|
||||||
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
|
|
||||||
NormTopKProb bool `json:"norm_topk_prob"`
|
|
||||||
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
|
|
||||||
NGroup int32 `json:"n_group"`
|
|
||||||
TopKGroup int32 `json:"topk_group"`
|
|
||||||
|
|
||||||
// RoPE scaling
|
|
||||||
RopeScaling *RopeScaling `json:"rope_scaling"`
|
|
||||||
|
|
||||||
// Quantization parameters (set during load based on model quantization)
|
|
||||||
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
|
||||||
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
|
||||||
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
|
||||||
|
|
||||||
// Computed fields
|
|
||||||
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
|
||||||
Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
|
|
||||||
}
|
|
||||||
|
|
||||||
// MLAAttention implements Multi-head Latent Attention with absorption.
|
|
||||||
// This uses absorbed MLA which operates in latent space for reduced KV cache.
|
|
||||||
type MLAAttention struct {
|
|
||||||
// Low-rank query projections
|
|
||||||
QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"`
|
|
||||||
QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
|
|
||||||
QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"`
|
|
||||||
|
|
||||||
// Low-rank KV projections (with shared rope component)
|
|
||||||
KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"`
|
|
||||||
KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
|
|
||||||
|
|
||||||
// Absorbed MLA projections (derived from kv_b_proj)
|
|
||||||
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
|
|
||||||
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
|
|
||||||
EmbedQ *nn.MultiLinear `weight:"-"`
|
|
||||||
UnembedOut *nn.MultiLinear `weight:"-"`
|
|
||||||
|
|
||||||
// Output projection
|
|
||||||
OProj nn.LinearLayer `weight:"self_attn.o_proj"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward computes absorbed MLA attention output.
|
|
||||||
// This operates in latent space for reduced KV cache memory.
|
|
||||||
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
||||||
// Query path: q_a_proj -> layernorm -> q_b_proj
|
|
||||||
q := a.QAProj.Forward(x)
|
|
||||||
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
|
|
||||||
q = a.QBProj.Forward(q)
|
|
||||||
|
|
||||||
// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
|
|
||||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
|
|
||||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
|
||||||
|
|
||||||
// Split Q into nope and rope parts
|
|
||||||
qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
|
|
||||||
qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
|
|
||||||
|
|
||||||
// KV path: get compressed KV and k_pe
|
|
||||||
compressedKV := a.KVAProjWithMQA.Forward(x)
|
|
||||||
|
|
||||||
// Split into compressed_kv and k_pe (shared rope component)
|
|
||||||
kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
|
|
||||||
kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
|
|
||||||
|
|
||||||
// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
|
|
||||||
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
|
|
||||||
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
|
|
||||||
|
|
||||||
// Apply layernorm to get kv latent representation
|
|
||||||
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
|
|
||||||
// kvLatent: [B, L, kv_lora_rank] -> [B, 1, L, kv_lora_rank] for broadcasting
|
|
||||||
kvLatent = mlx.ExpandDims(kvLatent, 1)
|
|
||||||
|
|
||||||
// Apply RoPE to the rope parts
|
|
||||||
offset := 0
|
|
||||||
if c != nil {
|
|
||||||
offset = c.Offset()
|
|
||||||
}
|
|
||||||
qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
|
||||||
kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
|
||||||
|
|
||||||
// ABSORBED MLA: project q_nope to latent space
|
|
||||||
// qNope: [B, num_heads, L, qk_nope_head_dim]
|
|
||||||
// EmbedQ: [num_heads, kv_lora_rank, qk_nope_head_dim]
|
|
||||||
// Result: [B, num_heads, L, kv_lora_rank]
|
|
||||||
qLatent := a.EmbedQ.Forward(qNope)
|
|
||||||
|
|
||||||
// Keys = concat(kvLatent, kPE)
|
|
||||||
// kvLatent: [B, 1, L, kv_lora_rank]
|
|
||||||
// kPE: [B, 1, L, qk_rope_head_dim]
|
|
||||||
// keys: [B, 1, L, kv_lora_rank + qk_rope_head_dim]
|
|
||||||
keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
|
|
||||||
|
|
||||||
// Cache the smaller latent representation
|
|
||||||
// We cache keys (latent + rope) and use empty values since values are derived from keys
|
|
||||||
cachedL := L
|
|
||||||
if c != nil {
|
|
||||||
// Create placeholder values with 0 dims for cache (we don't actually use cached values)
|
|
||||||
placeholderValues := mlx.Zeros([]int32{B, 1, L, 0}, mlx.DtypeFloat32)
|
|
||||||
keys, _ = c.Update(keys, placeholderValues, int(L))
|
|
||||||
cachedL = int32(keys.Shape()[2])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Values are the first kv_lora_rank dims of keys (slice off rope part)
|
|
||||||
values := mlx.Slice(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
|
|
||||||
|
|
||||||
// Queries = concat(qLatent, qPE)
|
|
||||||
// qLatent: [B, num_heads, L, kv_lora_rank]
|
|
||||||
// qPE: [B, num_heads, L, qk_rope_head_dim]
|
|
||||||
// queries: [B, num_heads, L, kv_lora_rank + qk_rope_head_dim]
|
|
||||||
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
|
|
||||||
|
|
||||||
// Attention in latent space
|
|
||||||
// queries: [B, num_heads, L, kv_lora_rank + rope_dim]
|
|
||||||
// keys: [B, 1, cachedL, kv_lora_rank + rope_dim]
|
|
||||||
// values: [B, 1, cachedL, kv_lora_rank]
|
|
||||||
out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
|
|
||||||
|
|
||||||
// ABSORBED MLA: unembed from latent space
|
|
||||||
// out: [B, num_heads, L, kv_lora_rank]
|
|
||||||
// UnembedOut: [num_heads, v_head_dim, kv_lora_rank]
|
|
||||||
// Result: [B, num_heads, L, v_head_dim]
|
|
||||||
out = a.UnembedOut.Forward(out)
|
|
||||||
|
|
||||||
// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
|
|
||||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
|
|
||||||
|
|
||||||
return a.OProj.Forward(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DenseMLP implements the standard SwiGLU MLP for dense layers
|
|
||||||
type DenseMLP struct {
|
|
||||||
GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
|
|
||||||
UpProj nn.LinearLayer `weight:"mlp.up_proj"`
|
|
||||||
DownProj nn.LinearLayer `weight:"mlp.down_proj"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies the SwiGLU MLP
|
|
||||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
|
||||||
up := m.UpProj.Forward(x)
|
|
||||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MoEGate implements the expert gating mechanism
|
|
||||||
type MoEGate struct {
|
|
||||||
Gate nn.LinearLayer `weight:"mlp.gate"`
|
|
||||||
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward computes expert selection indices and scores
|
|
||||||
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
|
||||||
// Compute gate logits through linear layer (handles both quantized and non-quantized)
|
|
||||||
gates := g.Gate.Forward(x)
|
|
||||||
|
|
||||||
// Sigmoid scoring
|
|
||||||
scores := mlx.Sigmoid(gates)
|
|
||||||
origScores := scores
|
|
||||||
|
|
||||||
// Add correction bias if present
|
|
||||||
if g.EScoreCorrectionBias != nil {
|
|
||||||
scores = mlx.Add(scores, g.EScoreCorrectionBias)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Group-wise expert selection (simplified for n_group=1)
|
|
||||||
// Select top-k experts
|
|
||||||
topK := cfg.NumExpertsPerTok
|
|
||||||
negScores := mlx.Neg(scores)
|
|
||||||
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
|
||||||
|
|
||||||
shape := inds.Shape()
|
|
||||||
inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})
|
|
||||||
|
|
||||||
// Get scores for selected experts
|
|
||||||
scores = mlx.TakeAlongAxis(origScores, inds, -1)
|
|
||||||
|
|
||||||
// Normalize if configured
|
|
||||||
if topK > 1 && cfg.NormTopKProb {
|
|
||||||
sumScores := mlx.Sum(scores, -1, true)
|
|
||||||
scores = mlx.Div(scores, sumScores)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply routing scaling factor
|
|
||||||
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
|
|
||||||
|
|
||||||
return inds, scores
|
|
||||||
}
|
|
||||||
|
|
||||||
// SwitchMLP implements the MoE expert computation using stacked weights
|
|
||||||
// Note: No weight tags - these are populated manually by stacking expert weights
|
|
||||||
type SwitchMLP struct {
|
|
||||||
// Dequantized weights (used when GatherQMM not available)
|
|
||||||
GateWeight *mlx.Array
|
|
||||||
UpWeight *mlx.Array
|
|
||||||
DownWeight *mlx.Array
|
|
||||||
|
|
||||||
// Quantized weights (used with GatherQMM for 4/8-bit affine)
|
|
||||||
GateWeightQ, GateScales, GateBiases *mlx.Array
|
|
||||||
UpWeightQ, UpScales, UpBiases *mlx.Array
|
|
||||||
DownWeightQ, DownScales, DownBiases *mlx.Array
|
|
||||||
|
|
||||||
// Quantization bits per projection (supports mixed precision Q4/Q8)
|
|
||||||
GateBits int
|
|
||||||
UpBits int
|
|
||||||
DownBits int
|
|
||||||
|
|
||||||
// Quantization group size per projection (detected from tensor shapes)
|
|
||||||
GateGroupSize int
|
|
||||||
UpGroupSize int
|
|
||||||
DownGroupSize int
|
|
||||||
|
|
||||||
// If true, use GatherQMM with quantized weights
|
|
||||||
UseQuantized bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward applies the switched expert MLP.
//
// x is [B, L, D] hidden states; indices is [B, L, topK] expert ids per token
// (topK = cfg.NumExpertsPerTok). Returns [B, L, topK, D] per-expert outputs;
// the caller weights them by gate scores and sums over topK.
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
	shape := x.Shape()
	B, L := shape[0], shape[1]
	topK := cfg.NumExpertsPerTok

	// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
	xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)

	// Flatten for gather_mm: [B*L, 1, 1, D]
	xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)

	// Flatten indices: [B, L, topK] -> [B*L, topK]
	idxFlat := mlx.Reshape(indices, B*L, topK)

	// Sort tokens by expert id for efficient gather; only worthwhile when we
	// have many tokens (threshold: 64 token slots).
	doSort := B*L >= 64
	var invOrder *mlx.Array // permutation that undoes the sort afterwards
	n := B * L * topK

	if doSort {
		idxAll := mlx.Flatten(idxFlat)
		order := mlx.Argsort(idxAll, 0)
		invOrder = mlx.Argsort(order, 0)
		// Reorder x based on sorted indices. order/topK maps each sorted
		// (token, expert) slot back to its source token row.
		xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
		idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
	}

	var gate, up, hidden, down *mlx.Array

	if s.UseQuantized {
		// Use GatherQMM for quantized weights (faster, keeps weights quantized).
		// Each projection may have different bits and group sizes
		// (mixed precision: Q4 for gate/up, Q8 for down).
		gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
			nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
		up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
			nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)

		// Gated activation: SiLU(gate) * up
		hidden = mlx.Mul(mlx.SiLU(gate), up)

		down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
			nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
	} else {
		// Use GatherMM for dequantized/non-quantized weights.
		gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
		up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)

		// Gated activation: SiLU(gate) * up
		hidden = mlx.Mul(mlx.SiLU(gate), up)

		down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
	}

	// Unsort if we sorted, restoring the original token order.
	if doSort {
		down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
	} else {
		down = mlx.Squeeze(down, 2)
	}

	return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
|
|
||||||
|
|
||||||
// SharedExperts implements the shared expert MLP.
// Unlike routed experts, shared experts are applied to every token; their
// output is added to the routed-expert output (see MoE.Forward).
type SharedExperts struct {
	GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
	UpProj   nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
	DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
}
|
|
||||||
|
|
||||||
// Forward applies the shared expert MLP
|
|
||||||
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
gate := mlx.SiLU(s.GateProj.Forward(x))
|
|
||||||
up := s.UpProj.Forward(x)
|
|
||||||
return s.DownProj.Forward(mlx.Mul(gate, up))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MoE implements the full Mixture of Experts layer.
type MoE struct {
	Gate          *MoEGate       // router producing per-token expert indices and scores
	SwitchMLP     *SwitchMLP     // routed experts
	SharedExperts *SharedExperts // shared experts; nil when the model has none
}
|
|
||||||
|
|
||||||
// Forward applies the MoE layer
|
|
||||||
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
|
||||||
shape := x.Shape()
|
|
||||||
B, L := shape[0], shape[1]
|
|
||||||
|
|
||||||
// Get expert indices and scores
|
|
||||||
inds, scores := m.Gate.Forward(x, cfg)
|
|
||||||
|
|
||||||
// Apply routed experts
|
|
||||||
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
|
|
||||||
|
|
||||||
// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
|
|
||||||
scoresExpanded := mlx.ExpandDims(scores, -1)
|
|
||||||
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
|
|
||||||
|
|
||||||
// Add shared experts if present
|
|
||||||
if m.SharedExperts != nil {
|
|
||||||
y = mlx.Add(y, m.SharedExperts.Forward(x))
|
|
||||||
}
|
|
||||||
|
|
||||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers).
type DenseBlock struct {
	Attention *MLAAttention // multi-head latent attention
	MLP       *DenseMLP     // plain feed-forward MLP (no expert routing)

	InputLayerNorm         *nn.RMSNorm `weight:"input_layernorm"`          // pre-attention norm
	PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"` // pre-MLP norm
}
|
|
||||||
|
|
||||||
// Forward applies the dense block
|
|
||||||
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
||||||
// Pre-norm attention with residual
|
|
||||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
|
||||||
h := mlx.Add(x, r)
|
|
||||||
|
|
||||||
// Pre-norm MLP with residual
|
|
||||||
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
|
|
||||||
return mlx.Add(h, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MoEBlock represents a MoE transformer block.
type MoEBlock struct {
	Attention *MLAAttention // multi-head latent attention
	MoE       *MoE          // mixture-of-experts feed-forward

	InputLayerNorm         *nn.RMSNorm `weight:"input_layernorm"`          // pre-attention norm
	PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"` // pre-MoE norm
}
|
|
||||||
|
|
||||||
// Forward applies the MoE block
|
|
||||||
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
||||||
// Pre-norm attention with residual
|
|
||||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
|
||||||
h := mlx.Add(x, r)
|
|
||||||
|
|
||||||
// Pre-norm MoE with residual
|
|
||||||
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
|
|
||||||
return mlx.Add(h, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Block is the common interface for both dense and MoE transformer blocks,
// allowing Model.Layers to hold a mix of the two.
type Block interface {
	Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}
|
|
||||||
|
|
||||||
// Model represents the complete GLM4-MoE-Lite model.
type Model struct {
	EmbedTokens *nn.Embedding  `weight:"model.embed_tokens"` // token embedding table
	Layers      []Block        `weight:"-"`                  // Loaded manually due to different block types (dense vs MoE)
	Norm        *nn.RMSNorm    `weight:"model.norm"`         // final RMS norm before the LM head
	LMHead      nn.LinearLayer `weight:"lm_head"`            // output projection to vocabulary logits

	tok *tokenizer.Tokenizer // tokenizer loaded alongside the weights
	*Config                  // embedded model hyperparameters
}
|
|
||||||
|
|
||||||
// computeScale computes the attention scale.
|
|
||||||
// Uses the full key head dimension (qkNopeHeadDim + qkRopeHeadDim) to match the Ollama runner.
|
|
||||||
func computeScale(cfg *Config) float32 {
|
|
||||||
keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
|
||||||
scale := float32(1.0 / math.Sqrt(float64(keyLength)))
|
|
||||||
if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
|
|
||||||
s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
|
|
||||||
scale *= s * s
|
|
||||||
}
|
|
||||||
return scale
|
|
||||||
}
|
|
||||||
|
|
||||||
// supportsGatherQMM reports whether the quantization mode has GatherQMM
// kernel support. Currently only 4-bit and 8-bit affine quantization qualify.
func supportsGatherQMM(mode string, bits int) bool {
	if mode != "affine" {
		return false
	}
	switch bits {
	case 4, 8:
		return true
	default:
		return false
	}
}
|
|
||||||
|
|
||||||
// ExpertWeight holds a single expert's weight with optional quantization
// components, as returned by loadExpertWeight.
type ExpertWeight struct {
	Weight    *mlx.Array // Quantized weight (if quantized) or dequantized weight
	Scales    *mlx.Array // Quantization scales (nil if not quantized)
	Biases    *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases)
	Bits      int        // Quantization bits (4 or 8), 0 if not quantized
	GroupSize int        // Quantization group size, 0 if not quantized
}
|
|
||||||
|
|
||||||
// getQuantParams returns quantization parameters from model metadata.
|
|
||||||
// Returns groupSize, bits, and mode for the model's quantization type.
|
|
||||||
func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) {
|
|
||||||
groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization())
|
|
||||||
// Use metadata group_size if available (overrides default)
|
|
||||||
if gs := weights.GroupSize(); gs > 0 {
|
|
||||||
groupSize = gs
|
|
||||||
}
|
|
||||||
return groupSize, bits, mode
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadExpertWeight loads an expert weight.
|
|
||||||
// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components.
|
|
||||||
// Otherwise dequantizes and returns only the weight.
|
|
||||||
func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight {
|
|
||||||
w, _ := weights.GetTensor(path + ".weight")
|
|
||||||
if w == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this is a quantized weight by looking for scales
|
|
||||||
scalePath := path + ".weight_scale"
|
|
||||||
if weights.HasTensor(scalePath) {
|
|
||||||
scales, _ := weights.GetTensor(scalePath)
|
|
||||||
var qbiases *mlx.Array
|
|
||||||
qbiasPath := path + ".weight_qbias"
|
|
||||||
if weights.HasTensor(qbiasPath) {
|
|
||||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get quantization params from metadata
|
|
||||||
groupSize, bits, mode := getQuantParams(weights)
|
|
||||||
|
|
||||||
// Update config with group size (for GatherQMM calls)
|
|
||||||
if cfg.QuantGroupSize == 0 {
|
|
||||||
cfg.QuantGroupSize = groupSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// If GatherQMM is supported and requested, return quantized components
|
|
||||||
if useQuantized && supportsGatherQMM(mode, bits) {
|
|
||||||
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise dequantize
|
|
||||||
return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ExpertWeight{Weight: w}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
// Returns embed_q and unembed_out weights for per-head projections, or
// (nil, nil) when the layer has no kv_b_proj tensor.
//
// kv_b_proj.weight shape: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
// Output:
//   - embed_q: [num_heads, kv_lora_rank, qk_nope_head_dim] - projects q_nope to latent
//   - unembed_out: [num_heads, v_head_dim, kv_lora_rank] - projects latent to output
func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
	path := prefix + ".self_attn.kv_b_proj"
	w, err := weights.GetTensor(path + ".weight")
	if err != nil || w == nil {
		return nil, nil
	}

	// Check if quantized (scales tensor present) and dequantize before
	// reshaping/slicing below.
	scalePath := path + ".weight_scale"
	if weights.HasTensor(scalePath) {
		scales, _ := weights.GetTensor(scalePath)
		var qbiases *mlx.Array
		qbiasPath := path + ".weight_qbias"
		if weights.HasTensor(qbiasPath) {
			qbiases, _ = weights.GetTensor(qbiasPath)
		}

		groupSize, bits, mode := getQuantParams(weights)
		w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
	}

	// w: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
	// Reshape to [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
	headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
	w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)

	// Split along the middle axis into the key and value sub-projections:
	// wk: [num_heads, qk_nope_head_dim, kv_lora_rank]
	// wv: [num_heads, v_head_dim, kv_lora_rank]
	wk := mlx.Slice(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
	wv := mlx.Slice(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})

	// Transform for absorbed MLA:
	// embed_q: transpose(wk) -> [num_heads, kv_lora_rank, qk_nope_head_dim]
	// This allows: q_nope @ embed_q.T = q_nope @ wk (absorbed key projection)
	embedQ := mlx.Transpose(wk, 0, 2, 1)

	// unembed_out: wv stays [num_heads, v_head_dim, kv_lora_rank]
	// This allows: latent_out @ unembed_out.T = latent_out @ wv.T (absorbed value projection)
	unembedOut := wv

	return embedQ, unembedOut
}
|
|
||||||
|
|
||||||
// StackedExpertWeights holds stacked weights for all experts of one
// projection (gate/up/down), as produced by collectAndStackExpertWeights.
type StackedExpertWeights struct {
	Weight    *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed]
	Scales    *mlx.Array // Stacked scales (nil if not quantized)
	Biases    *mlx.Array // Stacked biases (nil if not quantized)
	Bits      int        // Quantization bits (4 or 8), 0 if not quantized
	GroupSize int        // Quantization group size, 0 if not quantized
}
|
|
||||||
|
|
||||||
// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
|
|
||||||
func collectAndStackExpertWeights(
|
|
||||||
weights safetensors.WeightSource,
|
|
||||||
prefix string,
|
|
||||||
projName string,
|
|
||||||
numExperts int32,
|
|
||||||
useQuantized bool,
|
|
||||||
cfg *Config,
|
|
||||||
) *StackedExpertWeights {
|
|
||||||
var w, s, b []*mlx.Array
|
|
||||||
var bits, groupSize int
|
|
||||||
|
|
||||||
for e := int32(0); e < numExperts; e++ {
|
|
||||||
path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
|
|
||||||
ew := loadExpertWeight(weights, path, useQuantized, cfg)
|
|
||||||
if ew == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
w = append(w, ew.Weight)
|
|
||||||
if ew.Scales != nil {
|
|
||||||
s = append(s, ew.Scales)
|
|
||||||
}
|
|
||||||
if ew.Biases != nil {
|
|
||||||
b = append(b, ew.Biases)
|
|
||||||
}
|
|
||||||
if e == 0 {
|
|
||||||
bits = ew.Bits
|
|
||||||
groupSize = ew.GroupSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
|
|
||||||
if len(w) > 0 {
|
|
||||||
result.Weight = mlx.Stack(w, 0)
|
|
||||||
if len(s) > 0 {
|
|
||||||
result.Scales = mlx.Stack(s, 0)
|
|
||||||
}
|
|
||||||
if len(b) > 0 {
|
|
||||||
result.Biases = mlx.Stack(b, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// sanitizeExpertWeights stacks individual expert weights into tensors.
|
|
||||||
// If useQuantized is true and weights support GatherQMM, returns quantized components.
|
|
||||||
// Otherwise returns dequantized weights with nil scales/biases.
|
|
||||||
// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down).
|
|
||||||
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
|
|
||||||
gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg)
|
|
||||||
up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg)
|
|
||||||
down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg)
|
|
||||||
return gate, up, down
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
//
// Steps: parse config.json, load and map weight blobs, detect quantization,
// load the tokenizer (with generation/tokenizer config for EOS detection),
// then build each transformer layer — dense blocks for the first
// first_k_dense_replace layers, MoE blocks (with stacked expert weights and
// optional shared experts) for the rest. All arrays are evaluated before the
// source weight buffers are released.
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
	// Read config from manifest
	configData, err := modelManifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}

	var cfg Config
	if err := json.Unmarshal(configData, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	// Compute derived fields not present in config.json
	cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
	cfg.Scale = computeScale(&cfg)

	// Load weights from manifest blobs
	weights, err := manifest.LoadWeightsFromManifest(modelManifest, "")
	if err != nil {
		return nil, fmt.Errorf("load weights: %w", err)
	}

	if err := weights.Load(0); err != nil {
		return nil, fmt.Errorf("load weight data: %w", err)
	}

	// Set up quantization parameters (only if model is actually quantized).
	// Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading.
	quantization := weights.Quantization()
	useQuantized := false
	if quantization != "" {
		_, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization)
		// Only keep weights quantized when GatherQMM kernels can consume them.
		useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
	}

	// Load tokenizer from manifest with config files for EOS token detection
	tokData, err := modelManifest.ReadConfig("tokenizer.json")
	if err != nil {
		return nil, fmt.Errorf("load tokenizer config: %w", err)
	}

	// Build tokenizer config with companion files for EOS/BOS token loading
	tokConfig := &tokenizer.TokenizerConfig{
		ConfigJSON: configData, // Already loaded above, contains eos_token_id
	}

	// Try to load generation_config.json if available (preferred source for EOS)
	if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
		tokConfig.GenerationConfigJSON = genConfigData
	}

	// Try to load tokenizer_config.json if available
	if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
		tokConfig.TokenizerConfigJSON = tokConfigData
	}

	tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
	if err != nil {
		return nil, fmt.Errorf("parse tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]Block, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}

	// Load embedding, norm, and lm_head (fields with weight tags on Model)
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return nil, err
	}

	// Load layers manually due to different block types
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)

		// Load attention (same for both block types)
		attn := &MLAAttention{}
		if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
			return nil, fmt.Errorf("layer %d attention: %w", i, err)
		}

		// Sanitize MLA weights for absorbed attention
		embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg)
		attn.EmbedQ = nn.NewMultiLinear(embedQ)
		attn.UnembedOut = nn.NewMultiLinear(unembedOut)

		if i < cfg.FirstKDenseReplace {
			// Dense block (the first few layers use a plain MLP)
			block := &DenseBlock{Attention: attn}
			if err := safetensors.LoadModule(block, weights, prefix); err != nil {
				return nil, fmt.Errorf("layer %d dense: %w", i, err)
			}
			m.Layers[i] = block
		} else {
			// MoE block
			block := &MoEBlock{Attention: attn}
			if err := safetensors.LoadModule(block, weights, prefix); err != nil {
				return nil, fmt.Errorf("layer %d moe block: %w", i, err)
			}

			// Stack expert weights (pass cfg so group sizes can be detected)
			gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg)

			switchMLP := &SwitchMLP{UseQuantized: useQuantized}
			if useQuantized {
				// Keep per-projection quantized components; bits/group size
				// may differ per projection (mixed precision).
				switchMLP.GateWeightQ = gate.Weight
				switchMLP.GateScales = gate.Scales
				switchMLP.GateBiases = gate.Biases
				switchMLP.GateBits = gate.Bits
				switchMLP.GateGroupSize = gate.GroupSize
				switchMLP.UpWeightQ = up.Weight
				switchMLP.UpScales = up.Scales
				switchMLP.UpBiases = up.Biases
				switchMLP.UpBits = up.Bits
				switchMLP.UpGroupSize = up.GroupSize
				switchMLP.DownWeightQ = down.Weight
				switchMLP.DownScales = down.Scales
				switchMLP.DownBiases = down.Biases
				switchMLP.DownBits = down.Bits
				switchMLP.DownGroupSize = down.GroupSize
			} else {
				switchMLP.GateWeight = gate.Weight
				switchMLP.UpWeight = up.Weight
				switchMLP.DownWeight = down.Weight
			}

			block.MoE = &MoE{
				Gate:      &MoEGate{},
				SwitchMLP: switchMLP,
			}

			// Load gate weights
			if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
				return nil, fmt.Errorf("layer %d gate: %w", i, err)
			}

			// Load shared experts if present
			if cfg.NSharedExperts > 0 {
				block.MoE.SharedExperts = &SharedExperts{}
				if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
					return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
				}
			}

			m.Layers[i] = block
		}
	}

	// Force evaluation of all arrays, then release the source buffers.
	mlx.Eval(mlx.Collect(m)...)
	weights.ReleaseAll()

	return m, nil
}
|
|
||||||
|
|
||||||
// Forward computes the forward pass of the model
|
|
||||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
|
||||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
|
||||||
|
|
||||||
h := m.EmbedTokens.Forward(tokens)
|
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
|
||||||
var c cache.Cache
|
|
||||||
if caches != nil {
|
|
||||||
c = caches[i]
|
|
||||||
}
|
|
||||||
h = layer.Forward(h, c, B, L, m.Config)
|
|
||||||
}
|
|
||||||
|
|
||||||
h = m.Norm.Forward(h, m.RMSNormEps)
|
|
||||||
return m.LMHead.Forward(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Interface methods

// NumLayers returns the number of transformer layers.
func (m *Model) NumLayers() int { return len(m.Layers) }

// MaxContextLength returns the maximum context length (from config).
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }

// VocabSize returns the vocabulary size (from config).
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }

// Tokenizer returns the model's tokenizer.
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
|
||||||
|
|
||||||
// NewCache creates a new KV cache for the model
|
|
||||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
|
||||||
caches := make([]cache.Cache, len(m.Layers))
|
|
||||||
for i := range caches {
|
|
||||||
caches[i] = cache.NewKVCache()
|
|
||||||
}
|
|
||||||
return caches
|
|
||||||
}
|
|
||||||
|
|
||||||
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
|
|
||||||
// This follows the GLM-4.7 format with <think> tag for reasoning mode.
|
|
||||||
func (m *Model) FormatPrompt(prompt string) string {
|
|
||||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
|
||||||
}
|
|
||||||
|
|
||||||
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
|
|
||||||
// When think is true, the prompt ends with <think> to enable reasoning mode.
|
|
||||||
// When think is false, the prompt ends with </think> to skip reasoning.
|
|
||||||
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
|
|
||||||
if think {
|
|
||||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
|
||||||
}
|
|
||||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
func (m *Model) NewRenderer() *Renderer {
	return &Renderer{}
}

// NewParser returns a new Parser for extracting thinking and tool calls from output.
// The zero-value Parser starts in parserState_LookingForThinkingOpen; callers
// configure it via Parser.Init.
func (m *Model) NewParser() *Parser {
	return &Parser{}
}
|
|
||||||
@@ -1,479 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package glm4_moe_lite
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"encoding/xml"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"strings"
|
|
||||||
"unicode"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/logutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
// parserState tracks the parser's position in the GLM-4 output grammar:
// optional thinking delimited by <think>…</think>, then regular content,
// optionally interrupted by <tool_call>…</tool_call> blocks.
type parserState int

const (
	parserState_LookingForThinkingOpen parserState = iota // awaiting a possible <think> tag
	parserState_ThinkingStartedEatingWhitespace           // swallowing whitespace after <think>
	parserState_CollectingThinking                        // inside thinking content
	parserState_ThinkingDoneEatingWhitespace              // swallowing whitespace after </think>
	parserState_CollectingContent                         // regular content
	parserState_ToolStartedEatingWhitespace               // swallowing whitespace after <tool_call>
	parserState_CollectingToolContent                     // inside a tool call body
)

// Tags the model emits to delimit thinking and tool-call sections.
const (
	thinkingOpenTag = "<think>"
	thinkingCloseTag = "</think>"
	toolOpenTag = "<tool_call>"
	toolCloseTag = "</tool_call>"
)
|
|
||||||
|
|
||||||
// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type Parser struct {
	state parserState     // current position in the output grammar
	buffer strings.Builder // accumulated, not-yet-emitted model output
	tools []api.Tool       // available tools, forwarded to parseToolCall
}
|
|
||||||
|
|
||||||
// HasToolSupport returns true as GLM4 supports tool calling.
func (p *Parser) HasToolSupport() bool {
	return true
}

// HasThinkingSupport returns true as GLM4 supports thinking mode.
func (p *Parser) HasThinkingSupport() bool {
	return true
}
|
|
||||||
|
|
||||||
// Init initializes the parser with tools and thinking configuration.
|
|
||||||
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
|
||||||
p.tools = tools
|
|
||||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
|
||||||
// so model output starts directly with thinking content (no opening tag).
|
|
||||||
if thinkValue == nil || thinkValue.Bool() {
|
|
||||||
p.state = parserState_CollectingThinking
|
|
||||||
}
|
|
||||||
return tools
|
|
||||||
}
|
|
||||||
|
|
||||||
// parserEvent is the sum type of events produced by the parsing loop.
type parserEvent interface {
	isParserEvent()
}

// eventContent carries a chunk of regular (non-thinking) output.
type eventContent struct {
	content string
}

func (eventContent) isParserEvent() {}

// eventRawToolCall carries the raw text of a <tool_call> section, to be
// parsed into an api.ToolCall by parseToolCall.
type eventRawToolCall struct {
	raw string
}

func (eventRawToolCall) isParserEvent() {}

// eventThinkingContent carries a chunk of thinking (reasoning) output.
type eventThinkingContent struct {
	content string
}

func (eventThinkingContent) isParserEvent() {}
|
|
||||||
|
|
||||||
// Add processes new output text and returns parsed content, thinking, and tool calls.
|
|
||||||
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
|
||||||
p.buffer.WriteString(s)
|
|
||||||
events := p.parseEvents()
|
|
||||||
|
|
||||||
var toolCalls []api.ToolCall
|
|
||||||
var contentSb strings.Builder
|
|
||||||
var thinkingSb strings.Builder
|
|
||||||
|
|
||||||
for _, event := range events {
|
|
||||||
switch event := event.(type) {
|
|
||||||
case eventRawToolCall:
|
|
||||||
toolCall, err := parseToolCall(event, p.tools)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("glm-4 tool call parsing failed", "error", err)
|
|
||||||
return "", "", nil, err
|
|
||||||
}
|
|
||||||
toolCalls = append(toolCalls, toolCall)
|
|
||||||
case eventThinkingContent:
|
|
||||||
thinkingSb.WriteString(event.content)
|
|
||||||
case eventContent:
|
|
||||||
contentSb.WriteString(event.content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Parser) parseEvents() []parserEvent {
|
|
||||||
var all []parserEvent
|
|
||||||
|
|
||||||
keepLooping := true
|
|
||||||
for keepLooping {
|
|
||||||
var events []parserEvent
|
|
||||||
events, keepLooping = p.eat()
|
|
||||||
if len(events) > 0 {
|
|
||||||
all = append(all, events...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(all) > 0 {
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
return all
|
|
||||||
}
|
|
||||||
|
|
||||||
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
|
|
||||||
// and transitions to the next state. Returns (nil, false) if only whitespace remains
|
|
||||||
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
|
|
||||||
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
|
|
||||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
|
||||||
p.buffer.Reset()
|
|
||||||
if trimmed == "" {
|
|
||||||
return nil, false // Still only whitespace, keep waiting for more input
|
|
||||||
}
|
|
||||||
p.state = nextState
|
|
||||||
p.buffer.WriteString(trimmed)
|
|
||||||
return nil, true // Successfully transitioned
|
|
||||||
}
|
|
||||||
|
|
||||||
// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
|
|
||||||
// the content after (optionally trimmed of leading whitespace), and updates the buffer
|
|
||||||
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
|
||||||
split := strings.SplitN(p.buffer.String(), tag, 2)
|
|
||||||
before := split[0]
|
|
||||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
|
||||||
after := split[1]
|
|
||||||
if trimAfter {
|
|
||||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
|
||||||
}
|
|
||||||
p.buffer.Reset()
|
|
||||||
p.buffer.WriteString(after)
|
|
||||||
return before, after
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Parser) eat() ([]parserEvent, bool) {
|
|
||||||
var events []parserEvent
|
|
||||||
|
|
||||||
switch p.state {
|
|
||||||
case parserState_LookingForThinkingOpen:
|
|
||||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
|
||||||
if strings.HasPrefix(trimmed, thinkingOpenTag) {
|
|
||||||
// Found <think> opening tag
|
|
||||||
after := strings.TrimPrefix(trimmed, thinkingOpenTag)
|
|
||||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
|
||||||
p.buffer.Reset()
|
|
||||||
p.buffer.WriteString(after)
|
|
||||||
if after == "" {
|
|
||||||
p.state = parserState_ThinkingStartedEatingWhitespace
|
|
||||||
} else {
|
|
||||||
p.state = parserState_CollectingThinking
|
|
||||||
}
|
|
||||||
return events, true
|
|
||||||
} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
|
|
||||||
// Partial opening tag seen, keep accumulating
|
|
||||||
return events, false
|
|
||||||
} else if trimmed == "" {
|
|
||||||
// Only whitespace, keep accumulating
|
|
||||||
return events, false
|
|
||||||
} else {
|
|
||||||
// No thinking tag found, skip to content collection
|
|
||||||
p.state = parserState_CollectingContent
|
|
||||||
// Don't trim - we want to keep the original content
|
|
||||||
return events, true
|
|
||||||
}
|
|
||||||
|
|
||||||
case parserState_ThinkingStartedEatingWhitespace:
|
|
||||||
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
|
|
||||||
|
|
||||||
case parserState_CollectingThinking:
|
|
||||||
acc := p.buffer.String()
|
|
||||||
if strings.Contains(acc, thinkingCloseTag) {
|
|
||||||
thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
|
|
||||||
if len(thinking) > 0 {
|
|
||||||
events = append(events, eventThinkingContent{content: thinking})
|
|
||||||
}
|
|
||||||
if remaining == "" {
|
|
||||||
p.state = parserState_ThinkingDoneEatingWhitespace
|
|
||||||
} else {
|
|
||||||
p.state = parserState_CollectingContent
|
|
||||||
}
|
|
||||||
return events, true
|
|
||||||
} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
|
|
||||||
// Partial closing tag - withhold it along with any trailing whitespace before it
|
|
||||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
|
||||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
|
||||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
|
||||||
|
|
||||||
unambiguous := acc[:ambiguousStart]
|
|
||||||
ambiguous := acc[ambiguousStart:]
|
|
||||||
p.buffer.Reset()
|
|
||||||
p.buffer.WriteString(ambiguous)
|
|
||||||
if len(unambiguous) > 0 {
|
|
||||||
events = append(events, eventThinkingContent{content: unambiguous})
|
|
||||||
}
|
|
||||||
return events, false
|
|
||||||
} else {
|
|
||||||
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
|
|
||||||
whitespaceLen := trailingWhitespaceLen(acc)
|
|
||||||
ambiguousStart := len(acc) - whitespaceLen
|
|
||||||
|
|
||||||
unambiguous := acc[:ambiguousStart]
|
|
||||||
ambiguous := acc[ambiguousStart:]
|
|
||||||
p.buffer.Reset()
|
|
||||||
p.buffer.WriteString(ambiguous)
|
|
||||||
if len(unambiguous) > 0 {
|
|
||||||
events = append(events, eventThinkingContent{content: unambiguous})
|
|
||||||
}
|
|
||||||
return events, false
|
|
||||||
}
|
|
||||||
|
|
||||||
case parserState_ThinkingDoneEatingWhitespace:
|
|
||||||
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
|
|
||||||
|
|
||||||
case parserState_CollectingContent:
|
|
||||||
if strings.Contains(p.buffer.String(), toolOpenTag) {
|
|
||||||
before, after := p.splitAtTag(toolOpenTag, true)
|
|
||||||
if len(before) > 0 {
|
|
||||||
events = append(events, eventContent{content: before})
|
|
||||||
}
|
|
||||||
if after == "" {
|
|
||||||
p.state = parserState_ToolStartedEatingWhitespace
|
|
||||||
} else {
|
|
||||||
p.state = parserState_CollectingToolContent
|
|
||||||
}
|
|
||||||
return events, true
|
|
||||||
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
|
|
||||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
|
||||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
|
||||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
|
||||||
|
|
||||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
|
||||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
|
||||||
p.buffer.Reset()
|
|
||||||
p.buffer.WriteString(ambiguous)
|
|
||||||
if len(unambiguous) > 0 {
|
|
||||||
events = append(events, eventContent{content: unambiguous})
|
|
||||||
}
|
|
||||||
return events, false
|
|
||||||
} else {
|
|
||||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
|
||||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
|
||||||
|
|
||||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
|
||||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
|
||||||
p.buffer.Reset()
|
|
||||||
p.buffer.WriteString(ambiguous)
|
|
||||||
if len(unambiguous) > 0 {
|
|
||||||
events = append(events, eventContent{content: unambiguous})
|
|
||||||
}
|
|
||||||
return events, false
|
|
||||||
}
|
|
||||||
|
|
||||||
case parserState_ToolStartedEatingWhitespace:
|
|
||||||
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
|
|
||||||
|
|
||||||
case parserState_CollectingToolContent:
|
|
||||||
acc := p.buffer.String()
|
|
||||||
if strings.Contains(acc, toolCloseTag) {
|
|
||||||
toolContent, _ := p.splitAtTag(toolCloseTag, true)
|
|
||||||
if len(toolContent) == 0 {
|
|
||||||
slog.Warn("glm4 tool call closing tag found but no content before it")
|
|
||||||
}
|
|
||||||
events = append(events, eventRawToolCall{raw: toolContent})
|
|
||||||
p.state = parserState_CollectingContent
|
|
||||||
return events, true
|
|
||||||
} else {
|
|
||||||
// Keep accumulating - tool calls are not streamed
|
|
||||||
// We just wait for the closing tag
|
|
||||||
return events, false
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
panic("unreachable")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// overlap reports how many bytes at the end of s form a prefix of tag,
// returning the shortest such match, or 0 when no prefix of tag is a
// suffix of s.
func overlap(s, tag string) int {
	limit := min(len(tag), len(s))
	for n := 1; n <= limit; n++ {
		if strings.HasSuffix(s, tag[:n]) {
			return n
		}
	}
	return 0
}
|
|
||||||
|
|
||||||
// trailingWhitespaceLen reports how many bytes of Unicode whitespace s
// ends with (byte count, not rune count).
func trailingWhitespaceLen(s string) int {
	return len(s) - len(strings.TrimRightFunc(s, unicode.IsSpace))
}
|
|
||||||
|
|
||||||
// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing.
// The raw payload is wrapped in a <tool_call> root element before
// unmarshaling (see parseToolCall), which is what XMLName binds to.
// Keys and Values are collected independently in document order and are
// paired up by index afterwards.
type ToolCallXML struct {
	XMLName xml.Name `xml:"tool_call"`
	Content string   `xml:",chardata"` // Function name (text nodes between tags)
	Keys    []string `xml:"arg_key"`   // All arg_key elements in document order
	Values  []string `xml:"arg_value"` // All arg_value elements in document order
}
|
|
||||||
|
|
||||||
// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
|
|
||||||
func escapeContent(s string) string {
|
|
||||||
var result strings.Builder
|
|
||||||
inTag := false
|
|
||||||
|
|
||||||
for i := range len(s) {
|
|
||||||
ch := s[i]
|
|
||||||
|
|
||||||
if ch == '<' {
|
|
||||||
// Check if this is a known tag
|
|
||||||
if strings.HasPrefix(s[i:], "<arg_key>") ||
|
|
||||||
strings.HasPrefix(s[i:], "</arg_key>") ||
|
|
||||||
strings.HasPrefix(s[i:], "<arg_value>") ||
|
|
||||||
strings.HasPrefix(s[i:], "</arg_value>") {
|
|
||||||
inTag = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if inTag {
|
|
||||||
result.WriteByte(ch)
|
|
||||||
if ch == '>' {
|
|
||||||
inTag = false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Escape special characters in text content
|
|
||||||
switch ch {
|
|
||||||
case '&':
|
|
||||||
result.WriteString("&")
|
|
||||||
case '<':
|
|
||||||
result.WriteString("<")
|
|
||||||
case '>':
|
|
||||||
result.WriteString(">")
|
|
||||||
default:
|
|
||||||
result.WriteByte(ch)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
|
||||||
// Escape any unescaped entities in text content
|
|
||||||
escaped := escapeContent(raw.raw)
|
|
||||||
|
|
||||||
// Wrap the content in a root element to make it valid XML
|
|
||||||
xmlString := "<tool_call>" + escaped + "</tool_call>"
|
|
||||||
|
|
||||||
// Parse XML into struct
|
|
||||||
var parsed ToolCallXML
|
|
||||||
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
|
|
||||||
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract and trim function name
|
|
||||||
functionName := strings.TrimSpace(parsed.Content)
|
|
||||||
if functionName == "" {
|
|
||||||
return api.ToolCall{}, fmt.Errorf("empty function name")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify keys and values are paired correctly
|
|
||||||
if len(parsed.Keys) != len(parsed.Values) {
|
|
||||||
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the matching tool to get parameter types
|
|
||||||
var matchedTool *api.Tool
|
|
||||||
for i := range tools {
|
|
||||||
if tools[i].Function.Name == functionName {
|
|
||||||
matchedTool = &tools[i]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build arguments map by pairing keys and values
|
|
||||||
toolCall := api.ToolCall{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: functionName,
|
|
||||||
Arguments: api.NewToolCallFunctionArguments(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range parsed.Keys {
|
|
||||||
key := strings.TrimSpace(parsed.Keys[i])
|
|
||||||
value := parsed.Values[i] // Don't trim here - parseValue handles it
|
|
||||||
|
|
||||||
// Look up parameter type
|
|
||||||
var paramType api.PropertyType
|
|
||||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
|
||||||
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
|
|
||||||
// Handle anyOf by collecting all types from the union
|
|
||||||
if len(prop.AnyOf) > 0 {
|
|
||||||
for _, anyOfProp := range prop.AnyOf {
|
|
||||||
paramType = append(paramType, anyOfProp.Type...)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
paramType = prop.Type
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse value with type coercion
|
|
||||||
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
|
|
||||||
}
|
|
||||||
|
|
||||||
return toolCall, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseValue parses a string value and coerces it to the appropriate type based on paramType.
|
|
||||||
func parseValue(value string, paramType api.PropertyType) any {
|
|
||||||
value = strings.TrimSpace(value)
|
|
||||||
|
|
||||||
// If no type specified, return as string
|
|
||||||
if len(paramType) == 0 {
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to parse based on specified types
|
|
||||||
for _, t := range paramType {
|
|
||||||
switch t {
|
|
||||||
case "boolean":
|
|
||||||
if value == "true" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if value == "false" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
case "integer":
|
|
||||||
var i int64
|
|
||||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
case "number":
|
|
||||||
var f float64
|
|
||||||
if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
case "array", "object":
|
|
||||||
// Try to parse as JSON
|
|
||||||
var result any
|
|
||||||
if err := json.Unmarshal([]byte(value), &result); err == nil {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default to string
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
@@ -1,192 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package glm4_moe_lite
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestParserThinking exercises the parser end-to-end across inputs with
// thinking enabled/disabled, checking the separated thinking text, content
// text, and number of parsed tool calls.
func TestParserThinking(t *testing.T) {
	tests := []struct {
		name          string
		input         string
		thinkEnabled  bool
		wantContent   string
		wantThinking  string
		wantToolCalls int
	}{
		{
			name:         "thinking enabled - simple thinking then content",
			input:        "Let me think about this...</think>Here is my answer.",
			thinkEnabled: true,
			wantThinking: "Let me think about this...",
			wantContent:  "Here is my answer.",
		},
		{
			name:         "thinking enabled - only thinking",
			input:        "I need to consider multiple factors...",
			thinkEnabled: true,
			wantThinking: "I need to consider multiple factors...",
			wantContent:  "",
		},
		{
			name:         "thinking disabled - direct content",
			input:        "Here is my direct answer.",
			thinkEnabled: false,
			wantThinking: "",
			wantContent:  "Here is my direct answer.",
		},
		{
			name:          "thinking with tool call",
			input:         "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
			thinkEnabled:  true,
			wantThinking:  "Let me search for that...",
			wantContent:   "I'll use a tool.",
			wantToolCalls: 1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := &Parser{}

			// Map the boolean test flag onto the API's ThinkValue wrapper.
			var thinkValue *api.ThinkValue
			if tt.thinkEnabled {
				thinkValue = &api.ThinkValue{Value: true}
			} else {
				thinkValue = &api.ThinkValue{Value: false}
			}

			// Define tools for tool call tests
			props := api.NewToolPropertiesMap()
			props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
			tools := []api.Tool{
				{
					Function: api.ToolFunction{
						Name: "search",
						Parameters: api.ToolFunctionParameters{
							Properties: props,
						},
					},
				},
			}

			p.Init(tools, nil, thinkValue)

			// NOTE(review): the second argument to Add presumably marks the
			// final chunk of a stream — confirm against Parser.Add.
			content, thinking, calls, err := p.Add(tt.input, true)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if thinking != tt.wantThinking {
				t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
			}
			if content != tt.wantContent {
				t.Errorf("content = %q, want %q", content, tt.wantContent)
			}
			if len(calls) != tt.wantToolCalls {
				t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
			}
		})
	}
}
|
|
||||||
|
|
||||||
// TestParserToolCall verifies that a complete <tool_call> payload is parsed
// into a single api.ToolCall with the expected function name and that
// string-typed arguments pass through unmodified.
func TestParserToolCall(t *testing.T) {
	p := &Parser{}

	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
	props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name: "get_weather",
				Parameters: api.ToolFunctionParameters{
					Properties: props,
				},
			},
		},
	}

	// Initialize with thinking disabled
	tv := &api.ThinkValue{Value: false}
	p.Init(tools, nil, tv)

	input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"

	// NOTE(review): the second argument to Add presumably marks the final
	// chunk of a stream — confirm against Parser.Add.
	_, _, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}

	call := calls[0]
	if call.Function.Name != "get_weather" {
		t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
	}

	location, ok := call.Function.Arguments.Get("location")
	if !ok || location != "San Francisco" {
		t.Errorf("location = %v, want %q", location, "San Francisco")
	}

	unit, ok := call.Function.Arguments.Get("unit")
	if !ok || unit != "celsius" {
		t.Errorf("unit = %v, want %q", unit, "celsius")
	}
}
|
|
||||||
|
|
||||||
func TestOverlap(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
s string
|
|
||||||
tag string
|
|
||||||
want int
|
|
||||||
}{
|
|
||||||
{"hello<", "</think>", 1},
|
|
||||||
{"hello</", "</think>", 2},
|
|
||||||
{"hello</t", "</think>", 3},
|
|
||||||
{"hello</th", "</think>", 4},
|
|
||||||
{"hello</thi", "</think>", 5},
|
|
||||||
{"hello</thin", "</think>", 6},
|
|
||||||
{"hello</think", "</think>", 7},
|
|
||||||
{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
|
|
||||||
{"hello", "</think>", 0},
|
|
||||||
{"", "</think>", 0},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
|
|
||||||
got := overlap(tt.s, tt.tag)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTrailingWhitespaceLen(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
s string
|
|
||||||
want int
|
|
||||||
}{
|
|
||||||
{"hello ", 3},
|
|
||||||
{"hello\n\t ", 3},
|
|
||||||
{"hello", 0},
|
|
||||||
{"", 0},
|
|
||||||
{" ", 3},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.s, func(t *testing.T) {
|
|
||||||
got := trailingWhitespaceLen(tt.s)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,175 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package glm4_moe_lite
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Renderer renders messages for GLM4-MoE-Lite models.
//
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
//  1. INTERLEAVED THINKING
//     The model thinks between tool calls and after receiving tool results.
//     This enables complex step-by-step reasoning: interpreting each tool output
//     before deciding what to do next. Thinking blocks are preserved and returned
//     with tool results to maintain reasoning continuity.
//
//  2. PRESERVED THINKING
//     The model retains reasoning content from previous assistant turns in context.
//     This preserves reasoning continuity across multi-turn conversations. The
//     upstream API has a "clear_thinking" parameter to control this:
//     - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
//     - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
//  3. TURN-LEVEL THINKING
//     Controls whether the model should reason on each turn. The upstream API
//     uses "enable_thinking" parameter:
//     - enable_thinking=true: outputs <think> to start reasoning
//     - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
//   - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
//   - Thinking is PRESERVED by default (reasoning content from previous turns is always
//     included in <think>...</think> blocks, equivalent to clear_thinking=false)
//   - Users can disable thinking per-turn via thinkValue=false
//
// Renderer is stateless; the zero value is ready to use.
type Renderer struct{}
|
|
||||||
|
|
||||||
// Render renders messages into the GLM4 chat format.
//
// The prompt always begins with "[gMASK]<sop>". When tools are supplied, a
// system block describing them (JSON signatures inside <tools>...</tools>)
// is emitted first. Each message is then rendered under its role tag
// (<|user|>, <|assistant|>, <|system|>, <|observation|>), and the prompt
// ends with "<|assistant|>" plus either "<think>" (thinking enabled) or
// "</think>" (disabled) to steer the next turn. The error result is always
// nil; it exists to satisfy the renderer interface.
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("[gMASK]<sop>")

	if len(tools) > 0 {
		sb.WriteString("<|system|>\n")
		sb.WriteString("# Tools\n\n")
		sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
		sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
		sb.WriteString("<tools>\n")
		for _, tool := range tools {
			// Marshal error is intentionally ignored; api.Tool is a plain
			// data struct and marshals without error in practice.
			d, _ := json.Marshal(tool)
			sb.WriteString(formatToolJSON(d))
			sb.WriteString("\n")
		}
		sb.WriteString("</tools>\n\n")
		sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
		sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
	}

	// Thinking defaults to enabled; only an explicit false disables it.
	think := true
	if thinkValue != nil && !thinkValue.Bool() {
		think = false
	}

	for i, message := range messages {
		switch message.Role {
		case "user":
			sb.WriteString("<|user|>")
			sb.WriteString(message.Content)
		case "assistant":
			sb.WriteString("<|assistant|>")
			// Preserved thinking: prior reasoning is replayed verbatim;
			// otherwise a bare </think> marks an empty reasoning block.
			if message.Thinking != "" {
				sb.WriteString("<think>" + message.Thinking + "</think>")
			} else {
				sb.WriteString("</think>")
			}
			if message.Content != "" {
				sb.WriteString(message.Content)
			}
			if len(message.ToolCalls) > 0 {
				for _, toolCall := range message.ToolCalls {
					sb.WriteString("<tool_call>" + toolCall.Function.Name)
					sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
					sb.WriteString("</tool_call>")
				}
			}
		case "tool":
			// Consecutive tool results share a single <|observation|> tag.
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString("<|observation|>")
			}
			sb.WriteString("<tool_response>")
			sb.WriteString(message.Content)
			sb.WriteString("</tool_response>")
		case "system":
			sb.WriteString("<|system|>")
			sb.WriteString(message.Content)
		}
	}

	// Generation prompt for the next assistant turn.
	sb.WriteString("<|assistant|>")
	if think {
		sb.WriteString("<think>")
	} else {
		sb.WriteString("</think>")
	}

	return sb.String(), nil
}
|
|
||||||
|
|
||||||
// renderToolArguments converts tool call arguments to GLM4 XML format.
|
|
||||||
func renderToolArguments(args api.ToolCallFunctionArguments) string {
|
|
||||||
var sb strings.Builder
|
|
||||||
for key, value := range args.All() {
|
|
||||||
sb.WriteString("<arg_key>" + key + "</arg_key>")
|
|
||||||
var valueStr string
|
|
||||||
if str, ok := value.(string); ok {
|
|
||||||
valueStr = str
|
|
||||||
} else {
|
|
||||||
jsonBytes, err := json.Marshal(value)
|
|
||||||
if err != nil {
|
|
||||||
valueStr = fmt.Sprintf("%v", value)
|
|
||||||
} else {
|
|
||||||
valueStr = string(jsonBytes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatToolJSON formats compact JSON for GLM4 tool definitions by inserting
// a single space after every structural ':' and ',' — those outside string
// literals, honoring backslash escapes within strings.
func formatToolJSON(raw []byte) string {
	var out strings.Builder
	out.Grow(len(raw) + len(raw)/10)

	inString := false
	escaped := false
	for _, ch := range raw {
		out.WriteByte(ch)

		switch {
		case inString && escaped:
			// The escaped character is consumed as-is.
			escaped = false
		case inString && ch == '\\':
			escaped = true
		case inString && ch == '"':
			inString = false
		case !inString && ch == '"':
			inString = true
		case !inString && (ch == ':' || ch == ','):
			out.WriteByte(' ')
		}
	}

	return out.String()
}
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package glm4_moe_lite
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRendererSimple(t *testing.T) {
|
|
||||||
r := &Renderer{}
|
|
||||||
|
|
||||||
messages := []api.Message{
|
|
||||||
{Role: "user", Content: "Hello"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Thinking enabled (default)
|
|
||||||
result, err := r.Render(messages, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
|
|
||||||
if result != expected {
|
|
||||||
t.Errorf("result = %q, want %q", result, expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRendererThinkingDisabled(t *testing.T) {
|
|
||||||
r := &Renderer{}
|
|
||||||
|
|
||||||
messages := []api.Message{
|
|
||||||
{Role: "user", Content: "Hello"},
|
|
||||||
}
|
|
||||||
|
|
||||||
tv := &api.ThinkValue{Value: false}
|
|
||||||
|
|
||||||
result, err := r.Render(messages, nil, tv)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
|
|
||||||
if result != expected {
|
|
||||||
t.Errorf("result = %q, want %q", result, expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRendererMultiTurn checks a two-turn conversation: prior assistant
// thinking is preserved inside <think>...</think>, both user messages are
// present, and the prompt ends with a fresh generation prompt.
func TestRendererMultiTurn(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "What is 2+2?"},
		{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
		{Role: "user", Content: "And 3+3?"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Check key parts
	if !strings.Contains(result, "[gMASK]<sop>") {
		t.Error("missing [gMASK]<sop> prefix")
	}
	if !strings.Contains(result, "<|user|>What is 2+2?") {
		t.Error("missing first user message")
	}
	if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
		t.Error("missing assistant message with thinking")
	}
	if !strings.Contains(result, "<|user|>And 3+3?") {
		t.Error("missing second user message")
	}
	if !strings.HasSuffix(result, "<|assistant|><think>") {
		t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
	}
}
|
|
||||||
|
|
||||||
func TestRendererWithSystem(t *testing.T) {
|
|
||||||
r := &Renderer{}
|
|
||||||
|
|
||||||
messages := []api.Message{
|
|
||||||
{Role: "system", Content: "You are a helpful assistant."},
|
|
||||||
{Role: "user", Content: "Hello"},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := r.Render(messages, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
|
|
||||||
t.Error("missing system message")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRendererWithTools checks that supplying a tool definition produces the
// tool system prompt: the <|system|> block, "# Tools" header, a
// <tools>...</tools> section, and the tool's name inside it.
func TestRendererWithTools(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "What's the weather?"},
	}

	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name:        "get_weather",
				Description: "Get the weather for a location",
				Parameters: api.ToolFunctionParameters{
					Type:       "object",
					Properties: props,
					Required:   []string{"location"},
				},
			},
		},
	}

	result, err := r.Render(messages, tools, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Check for tool system prompt
	if !strings.Contains(result, "<|system|>") {
		t.Error("missing system tag for tools")
	}
	if !strings.Contains(result, "# Tools") {
		t.Error("missing tools header")
	}
	if !strings.Contains(result, "<tools>") {
		t.Error("missing tools tag")
	}
	if !strings.Contains(result, "get_weather") {
		t.Error("missing tool name")
	}
	if !strings.Contains(result, "</tools>") {
		t.Error("missing closing tools tag")
	}
}
|
|
||||||
|
|
||||||
// TestRendererWithToolCalls checks that an assistant tool call renders as a
// <tool_call> block with arg_key/arg_value pairs, and that the following
// tool result is rendered under <|observation|> inside <tool_response>.
func TestRendererWithToolCalls(t *testing.T) {
	r := &Renderer{}

	args := api.NewToolCallFunctionArguments()
	args.Set("location", "San Francisco")

	messages := []api.Message{
		{Role: "user", Content: "What's the weather in SF?"},
		{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: args,
					},
				},
			},
		},
		{Role: "tool", Content: "Sunny, 72F"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if !strings.Contains(result, "<tool_call>get_weather") {
		t.Error("missing tool call")
	}
	if !strings.Contains(result, "<arg_key>location</arg_key>") {
		t.Error("missing arg_key")
	}
	if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
		t.Error("missing arg_value")
	}
	if !strings.Contains(result, "</tool_call>") {
		t.Error("missing tool call closing tag")
	}
	if !strings.Contains(result, "<|observation|>") {
		t.Error("missing observation tag")
	}
	if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
		t.Error("missing tool response")
	}
}
|
|
||||||
|
|
||||||
func TestFormatToolJSON(t *testing.T) {
|
|
||||||
input := []byte(`{"name":"test","value":123}`)
|
|
||||||
result := formatToolJSON(input)
|
|
||||||
|
|
||||||
// Should add spaces after : and ,
|
|
||||||
if !strings.Contains(result, ": ") {
|
|
||||||
t.Error("should add space after colon")
|
|
||||||
}
|
|
||||||
if !strings.Contains(result, ", ") {
|
|
||||||
t.Error("should add space after comma")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,487 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package gpt_oss
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/cache"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/nn"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RopeScaling holds YaRN or other RoPE scaling configuration, decoded from
// the model config's "rope_scaling" object (see Config.RopeScaling).
type RopeScaling struct {
	RopeType                      string  `json:"rope_type"`
	Factor                        float32 `json:"factor"` // context extension factor
	OriginalMaxPositionEmbeddings int32   `json:"original_max_position_embeddings"`
	BetaFast                      float32 `json:"beta_fast"` // NOTE(review): presumably YaRN ramp bounds — confirm
	BetaSlow                      float32 `json:"beta_slow"`
}
|
|
||||||
|
|
||||||
// Config mirrors the GPT-OSS model hyperparameters decoded from the model's
// JSON configuration.
type Config struct {
	HiddenSize        int32        `json:"hidden_size"`
	NumHiddenLayers   int32        `json:"num_hidden_layers"`
	IntermediateSize  int32        `json:"intermediate_size"`
	NumAttentionHeads int32        `json:"num_attention_heads"`
	NumKeyValueHeads  int32        `json:"num_key_value_heads"`
	VocabSize         int32        `json:"vocab_size"`
	RMSNormEps        float32      `json:"rms_norm_eps"`
	RopeTheta         float32      `json:"rope_theta"`
	HeadDim           int32        `json:"head_dim"`
	SlidingWindow     int32        `json:"sliding_window"`
	NumLocalExperts   int32        `json:"num_local_experts"`
	NumExpertsPerTok  int32        `json:"num_experts_per_tok"`
	LayerTypes        []string     `json:"layer_types"` // per-layer attention kind; NOTE(review): values not visible here — confirm against config.json
	SwiGLULimit       float32      `json:"swiglu_limit"`
	RopeScaling       *RopeScaling `json:"rope_scaling"` // nil when the config has no rope_scaling block
	Scale             float32      `json:"-"`            // computed: 1/sqrt(HeadDim)
}
|
|
||||||
|
|
||||||
type Attention struct {
|
|
||||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
|
||||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
|
||||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
|
||||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
|
||||||
Sinks *mlx.Array `weight:"self_attn.sinks,optional"`
|
|
||||||
YarnFreqs *mlx.Array // computed
|
|
||||||
YarnMscale float32
|
|
||||||
}
|
|
||||||
|
|
||||||
// swiGLU applies the GPT-OSS custom SwiGLU activation.
|
|
||||||
// Formula: (gate * sigmoid(alpha * gate)) * (up + 1)
|
|
||||||
// with clipping: gate to [None, limit], up to [-limit, limit]
|
|
||||||
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
|
|
||||||
// Clip gate to [None, limit]
|
|
||||||
gateClipped := mlx.ClipScalar(gate, 0, limit, false, true)
|
|
||||||
|
|
||||||
// Clip up to [-limit, limit]
|
|
||||||
upClipped := mlx.ClipScalar(up, -limit, limit, true, true)
|
|
||||||
|
|
||||||
// glu_scaled = alpha * gate_clipped
|
|
||||||
gluScaled := mlx.MulScalar(gateClipped, alpha)
|
|
||||||
|
|
||||||
// sig = sigmoid(glu_scaled)
|
|
||||||
sig := mlx.Sigmoid(gluScaled)
|
|
||||||
|
|
||||||
// out_glu = gate_clipped * sig
|
|
||||||
outGlu := mlx.Mul(gateClipped, sig)
|
|
||||||
|
|
||||||
// result = out_glu * (up_clipped + 1)
|
|
||||||
return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0))
|
|
||||||
}
|
|
||||||
|
|
||||||
// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers
|
|
||||||
var compiledSwiGLU *mlx.CompiledFunc
|
|
||||||
|
|
||||||
// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed
|
|
||||||
func getCompiledSwiGLU() *mlx.CompiledFunc {
|
|
||||||
if compiledSwiGLU == nil {
|
|
||||||
const alpha float32 = 1.702
|
|
||||||
const limit float32 = 7.0
|
|
||||||
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
|
||||||
return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
|
|
||||||
}, true) // shapeless=true so it works for any input size
|
|
||||||
}
|
|
||||||
return compiledSwiGLU
|
|
||||||
}
|
|
||||||
|
|
||||||
// ComputeYarnFreqs computes YaRN-modified RoPE frequencies
|
|
||||||
// Based on mlx-lm's YarnRoPE implementation
|
|
||||||
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
|
|
||||||
// yarn_find_correction_dim
|
|
||||||
yarnFindCorrectionDim := func(numRotations float64) float64 {
|
|
||||||
return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// yarn_find_correction_range
|
|
||||||
low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast))))
|
|
||||||
high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow))))
|
|
||||||
if low < 0 {
|
|
||||||
low = 0
|
|
||||||
}
|
|
||||||
if high > int(dims)-1 {
|
|
||||||
high = int(dims) - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// yarn_get_mscale
|
|
||||||
yarnGetMscale := func(scale, mscale float64) float64 {
|
|
||||||
if scale <= 1 {
|
|
||||||
return 1.0
|
|
||||||
}
|
|
||||||
return 0.1*mscale*math.Log(scale) + 1.0
|
|
||||||
}
|
|
||||||
mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0))
|
|
||||||
|
|
||||||
// Compute frequencies
|
|
||||||
// freq_extra = base ** (arange(0, dims, 2) / dims)
|
|
||||||
// freq_inter = scaling_factor * freq_extra
|
|
||||||
halfDims := dims / 2
|
|
||||||
freqData := make([]float32, halfDims)
|
|
||||||
for i := int32(0); i < halfDims; i++ {
|
|
||||||
exp := float64(2*i) / float64(dims)
|
|
||||||
freqExtra := math.Pow(float64(base), exp)
|
|
||||||
freqInter := float64(scalingFactor) * freqExtra
|
|
||||||
|
|
||||||
// linear ramp mask
|
|
||||||
var freqMask float64
|
|
||||||
if low == high {
|
|
||||||
freqMask = 0.0
|
|
||||||
} else {
|
|
||||||
t := (float64(i) - float64(low)) / float64(high-low)
|
|
||||||
if t < 0 {
|
|
||||||
t = 0
|
|
||||||
}
|
|
||||||
if t > 1 {
|
|
||||||
t = 1
|
|
||||||
}
|
|
||||||
freqMask = 1.0 - t
|
|
||||||
}
|
|
||||||
|
|
||||||
// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
|
|
||||||
freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
|
|
||||||
}
|
|
||||||
|
|
||||||
return mlx.NewArray(freqData, []int32{halfDims}), mscale
|
|
||||||
}
|
|
||||||
|
|
||||||
// initYarn initializes YaRN RoPE if configured
|
|
||||||
func (a *Attention) initYarn(cfg *Config) {
|
|
||||||
a.YarnMscale = 1.0
|
|
||||||
if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" {
|
|
||||||
a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
|
|
||||||
cfg.HeadDim,
|
|
||||||
cfg.RopeTheta,
|
|
||||||
cfg.RopeScaling.Factor,
|
|
||||||
cfg.RopeScaling.OriginalMaxPositionEmbeddings,
|
|
||||||
cfg.RopeScaling.BetaFast,
|
|
||||||
cfg.RopeScaling.BetaSlow,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
|
|
||||||
q := a.QProj.Forward(x)
|
|
||||||
k := a.KProj.Forward(x)
|
|
||||||
v := a.VProj.Forward(x)
|
|
||||||
|
|
||||||
// Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim]
|
|
||||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
|
||||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
|
||||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
|
||||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
|
||||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
|
||||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
|
||||||
|
|
||||||
offset := 0
|
|
||||||
if c != nil {
|
|
||||||
offset = c.Offset()
|
|
||||||
}
|
|
||||||
if a.YarnFreqs != nil {
|
|
||||||
if a.YarnMscale != 1.0 {
|
|
||||||
q = mlx.MulScalar(q, a.YarnMscale)
|
|
||||||
}
|
|
||||||
q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
|
|
||||||
k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
|
|
||||||
} else {
|
|
||||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
|
||||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c != nil {
|
|
||||||
k, v = c.Update(k, v, int(L))
|
|
||||||
}
|
|
||||||
|
|
||||||
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)
|
|
||||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
|
||||||
return a.OProj.Forward(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSlidingWindowMask creates a causal mask with sliding window
|
|
||||||
// Mirrors mlx-lm's create_causal_mask with window_size
|
|
||||||
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
|
|
||||||
// Build mask aligned to actual cache length (may be rotated)
|
|
||||||
// rinds covers existing keys: [keyStart, keyStart+keyLen)
|
|
||||||
// linds covers new queries: [queryStart, queryStart+seqLen)
|
|
||||||
rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen]
|
|
||||||
linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen]
|
|
||||||
|
|
||||||
linds = mlx.ExpandDims(linds, 1) // [seqLen, 1]
|
|
||||||
rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen]
|
|
||||||
|
|
||||||
causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen]
|
|
||||||
windowLimit := mlx.AddScalar(rinds, float32(windowSize))
|
|
||||||
windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen]
|
|
||||||
|
|
||||||
return mlx.LogicalAnd(causalMask, windowMask)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MoE represents the Mixture of Experts SwiGLU layer with quantized experts.
|
|
||||||
type MoE struct {
|
|
||||||
Router *nn.Linear `weight:"mlp.router"`
|
|
||||||
TopK int32
|
|
||||||
HiddenSize int32
|
|
||||||
GroupSize int
|
|
||||||
Bits int
|
|
||||||
// Expert weights (loaded manually via sanitizeExpertWeights)
|
|
||||||
GateBlocks, GateScales, GateBias *mlx.Array
|
|
||||||
UpBlocks, UpScales, UpBias *mlx.Array
|
|
||||||
DownBlocks, DownScales, DownBias *mlx.Array
|
|
||||||
}
|
|
||||||
|
|
||||||
func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
|
|
||||||
logits := moe.Router.Forward(x)
|
|
||||||
neg := mlx.Neg(logits)
|
|
||||||
part := mlx.Argpartition(neg, int(moe.TopK)-1, -1)
|
|
||||||
topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
|
|
||||||
topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1)
|
|
||||||
weights := mlx.Softmax(topKVal, -1)
|
|
||||||
|
|
||||||
xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
|
|
||||||
idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK)
|
|
||||||
|
|
||||||
doSort := B*L >= 64
|
|
||||||
var invOrder *mlx.Array
|
|
||||||
sorted := false
|
|
||||||
n := B * L * moe.TopK
|
|
||||||
|
|
||||||
if doSort {
|
|
||||||
idxAll := mlx.Flatten(idxFlat)
|
|
||||||
order := mlx.Argsort(idxAll, 0)
|
|
||||||
invOrder = mlx.Argsort(order, 0)
|
|
||||||
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1)
|
|
||||||
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
|
|
||||||
sorted = true
|
|
||||||
}
|
|
||||||
|
|
||||||
gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
|
||||||
up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
|
||||||
|
|
||||||
if moe.GateBias != nil {
|
|
||||||
gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2))
|
|
||||||
}
|
|
||||||
if moe.UpBias != nil {
|
|
||||||
up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2))
|
|
||||||
}
|
|
||||||
|
|
||||||
hidden := getCompiledSwiGLU().Call(gate, up)[0]
|
|
||||||
|
|
||||||
down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
|
||||||
if moe.DownBias != nil {
|
|
||||||
down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2))
|
|
||||||
}
|
|
||||||
|
|
||||||
if doSort {
|
|
||||||
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize)
|
|
||||||
} else {
|
|
||||||
down = mlx.Squeeze(down, 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1)
|
|
||||||
return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Block struct {
|
|
||||||
Attention *Attention
|
|
||||||
MLP *MoE
|
|
||||||
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
|
|
||||||
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
|
||||||
LayerType string // "sliding_attention" or "full_attention"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
|
|
||||||
h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg))
|
|
||||||
return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L))
|
|
||||||
}
|
|
||||||
|
|
||||||
type Model struct {
|
|
||||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
|
||||||
Layers []*Block `weight:"-"` // loaded manually due to MoE sanitization
|
|
||||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
|
||||||
LMHead *nn.Linear `weight:"lm_head"`
|
|
||||||
|
|
||||||
tok *tokenizer.Tokenizer
|
|
||||||
*Config
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
|
||||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
|
||||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
|
||||||
|
|
||||||
func (m *Model) NewCache(int32) []cache.Cache {
|
|
||||||
caches := make([]cache.Cache, len(m.Layers))
|
|
||||||
for i, layer := range m.Layers {
|
|
||||||
if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 {
|
|
||||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
|
||||||
} else {
|
|
||||||
caches[i] = cache.NewKVCache()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return caches
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
|
||||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
|
||||||
x := m.EmbedTokens.Forward(tokens)
|
|
||||||
|
|
||||||
// Find representative cache indices for sliding window attention
|
|
||||||
var swaIdx int = -1
|
|
||||||
for i, layer := range m.Layers {
|
|
||||||
if layer.LayerType == "sliding_attention" {
|
|
||||||
swaIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create masks once at model level
|
|
||||||
var fullMask, swaMask *mlx.Array
|
|
||||||
var fullMaskMode, swaMaskMode string
|
|
||||||
|
|
||||||
if L > 1 {
|
|
||||||
fullMaskMode = "causal"
|
|
||||||
if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
|
|
||||||
c := caches[swaIdx]
|
|
||||||
offset := c.Offset()
|
|
||||||
windowSize := int(m.SlidingWindow)
|
|
||||||
cacheLen := min(int(L), windowSize)
|
|
||||||
if offset > 0 {
|
|
||||||
cacheLen = min(c.Len()+int(L), windowSize)
|
|
||||||
}
|
|
||||||
if int(L) > windowSize {
|
|
||||||
swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize)
|
|
||||||
} else {
|
|
||||||
swaMaskMode = "causal"
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
swaMaskMode = "causal"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
|
||||||
var c cache.Cache
|
|
||||||
if caches != nil {
|
|
||||||
c = caches[i]
|
|
||||||
}
|
|
||||||
mask, maskMode := fullMask, fullMaskMode
|
|
||||||
if layer.LayerType == "sliding_attention" {
|
|
||||||
mask, maskMode = swaMask, swaMaskMode
|
|
||||||
}
|
|
||||||
x = layer.Forward(x, c, B, L, mask, maskMode, m.Config)
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
|
|
||||||
}
|
|
||||||
|
|
||||||
// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays.
|
|
||||||
// MXFP4 quantized weights require contiguous memory - strided views give wrong results.
|
|
||||||
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
|
|
||||||
gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks")
|
|
||||||
gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales")
|
|
||||||
gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias")
|
|
||||||
downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks")
|
|
||||||
downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales")
|
|
||||||
downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias")
|
|
||||||
|
|
||||||
moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}
|
|
||||||
|
|
||||||
if gateUpBlocks != nil {
|
|
||||||
gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1)
|
|
||||||
s := gub.Shape()
|
|
||||||
moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
|
||||||
moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
|
||||||
}
|
|
||||||
if gateUpScales != nil {
|
|
||||||
s := gateUpScales.Shape()
|
|
||||||
moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
|
||||||
moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
|
||||||
}
|
|
||||||
if gateUpBias != nil {
|
|
||||||
s := gateUpBias.Shape()
|
|
||||||
moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
|
|
||||||
moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
|
|
||||||
}
|
|
||||||
if downBlocks != nil {
|
|
||||||
moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1)
|
|
||||||
}
|
|
||||||
return moe
|
|
||||||
}
|
|
||||||
|
|
||||||
func Load(modelPath string) (*Model, error) {
|
|
||||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load config: %w", err)
|
|
||||||
}
|
|
||||||
var cfg Config
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse config: %w", err)
|
|
||||||
}
|
|
||||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
|
||||||
|
|
||||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load weights: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m := &Model{
|
|
||||||
Layers: make([]*Block, cfg.NumHiddenLayers),
|
|
||||||
Config: &cfg,
|
|
||||||
tok: tok,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load simple weights via struct tags
|
|
||||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load layers with custom MoE handling
|
|
||||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
|
||||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
|
||||||
layer := &Block{}
|
|
||||||
if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
|
|
||||||
return nil, fmt.Errorf("layer %d: %w", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize attention YaRN
|
|
||||||
layer.Attention.initYarn(&cfg)
|
|
||||||
|
|
||||||
// Load MoE with weight sanitization
|
|
||||||
moe := sanitizeExpertWeights(weights, prefix)
|
|
||||||
moe.Router = layer.MLP.Router // Router was loaded by LoadModule
|
|
||||||
moe.TopK = cfg.NumExpertsPerTok
|
|
||||||
moe.HiddenSize = cfg.HiddenSize
|
|
||||||
layer.MLP = moe
|
|
||||||
|
|
||||||
// Set layer type
|
|
||||||
layer.LayerType = "full_attention"
|
|
||||||
if int(i) < len(cfg.LayerTypes) {
|
|
||||||
layer.LayerType = cfg.LayerTypes[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
m.Layers[i] = layer
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release safetensors BEFORE eval - lazy arrays have captured data,
|
|
||||||
// this reduces peak memory by freeing mmap during materialization
|
|
||||||
weights.ReleaseAll()
|
|
||||||
mlx.Eval(mlx.Collect(m)...)
|
|
||||||
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) MaxContextLength() int32 {
|
|
||||||
if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 {
|
|
||||||
return m.RopeScaling.OriginalMaxPositionEmbeddings
|
|
||||||
}
|
|
||||||
return 131072
|
|
||||||
}
|
|
||||||
@@ -1,152 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package llama
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/cache"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/nn"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
HiddenSize int32 `json:"hidden_size"`
|
|
||||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
|
||||||
IntermediateSize int32 `json:"intermediate_size"`
|
|
||||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
|
||||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
|
||||||
VocabSize int32 `json:"vocab_size"`
|
|
||||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
|
||||||
RopeTheta float32 `json:"rope_theta"`
|
|
||||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
|
||||||
HeadDim int32 `json:"-"`
|
|
||||||
Scale float32 `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Model struct {
|
|
||||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
|
||||||
Layers []*Layer `weight:"model.layers"`
|
|
||||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
|
||||||
Output *nn.Linear `weight:"lm_head,optional"`
|
|
||||||
|
|
||||||
tok *tokenizer.Tokenizer
|
|
||||||
*Config
|
|
||||||
}
|
|
||||||
|
|
||||||
type Layer struct {
|
|
||||||
Attention *Attention
|
|
||||||
MLP *MLP
|
|
||||||
AttentionNorm *nn.RMSNorm `weight:"input_layernorm"`
|
|
||||||
MLPNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Attention struct {
|
|
||||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
|
||||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
|
||||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
|
||||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type MLP struct {
|
|
||||||
GateProj *nn.Linear `weight:"mlp.gate_proj"`
|
|
||||||
UpProj *nn.Linear `weight:"mlp.up_proj"`
|
|
||||||
DownProj *nn.Linear `weight:"mlp.down_proj"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func Load(modelPath string) (*Model, error) {
|
|
||||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load config: %w", err)
|
|
||||||
}
|
|
||||||
var cfg Config
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse config: %w", err)
|
|
||||||
}
|
|
||||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
|
||||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
|
||||||
|
|
||||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load weights: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m := &Model{
|
|
||||||
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
|
||||||
Config: &cfg,
|
|
||||||
tok: tok,
|
|
||||||
}
|
|
||||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
|
|
||||||
|
|
||||||
mlx.Eval(mlx.Collect(m)...)
|
|
||||||
weights.ReleaseAll()
|
|
||||||
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
|
||||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
|
||||||
h := m.EmbedTokens.Forward(tokens)
|
|
||||||
for i, layer := range m.Layers {
|
|
||||||
h = layer.Forward(h, caches[i], B, L, m.Config)
|
|
||||||
}
|
|
||||||
return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
||||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
|
||||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
||||||
q := a.QProj.Forward(x)
|
|
||||||
k := a.KProj.Forward(x)
|
|
||||||
v := a.VProj.Forward(x)
|
|
||||||
|
|
||||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
|
||||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
|
||||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
|
||||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
|
||||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
|
||||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
|
||||||
|
|
||||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
|
||||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
|
||||||
|
|
||||||
k, v = c.Update(k, v, int(L))
|
|
||||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
|
||||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
|
||||||
return a.OProj.Forward(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
|
||||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Interface methods
|
|
||||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
|
||||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
|
||||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
|
||||||
|
|
||||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
|
||||||
caches := make([]cache.Cache, len(m.Layers))
|
|
||||||
for i := range caches {
|
|
||||||
caches[i] = cache.NewKVCache()
|
|
||||||
}
|
|
||||||
return caches
|
|
||||||
}
|
|
||||||
@@ -39,19 +39,23 @@ func Execute(args []string) error {
|
|||||||
return fmt.Errorf("--port is required")
|
return fmt.Errorf("--port is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize MLX
|
// Detect model type from capabilities
|
||||||
|
mode := detectModelMode(*modelName)
|
||||||
|
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
|
||||||
|
|
||||||
|
if mode != ModeImageGen {
|
||||||
|
return fmt.Errorf("imagegen runner only supports image generation models")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize MLX only for image generation mode.
|
||||||
if err := mlx.InitMLX(); err != nil {
|
if err := mlx.InitMLX(); err != nil {
|
||||||
slog.Error("unable to initialize MLX", "error", err)
|
slog.Error("unable to initialize MLX", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
slog.Info("MLX library initialized")
|
slog.Info("MLX library initialized")
|
||||||
|
|
||||||
// Detect model type from capabilities
|
|
||||||
mode := detectModelMode(*modelName)
|
|
||||||
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
|
|
||||||
|
|
||||||
// Create and start server
|
// Create and start server
|
||||||
server, err := newServer(*modelName, *port, mode)
|
server, err := newServer(*modelName, *port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create server: %w", err)
|
return fmt.Errorf("failed to create server: %w", err)
|
||||||
}
|
}
|
||||||
@@ -61,12 +65,6 @@ func Execute(args []string) error {
|
|||||||
mux.HandleFunc("/health", server.healthHandler)
|
mux.HandleFunc("/health", server.healthHandler)
|
||||||
mux.HandleFunc("/completion", server.completionHandler)
|
mux.HandleFunc("/completion", server.completionHandler)
|
||||||
|
|
||||||
// LLM-specific endpoints
|
|
||||||
if mode == ModeLLM {
|
|
||||||
mux.HandleFunc("/tokenize", server.tokenizeHandler)
|
|
||||||
mux.HandleFunc("/embedding", server.embeddingHandler)
|
|
||||||
}
|
|
||||||
|
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
|
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
@@ -112,34 +110,22 @@ func detectModelMode(modelName string) ModelMode {
|
|||||||
|
|
||||||
// server holds the model and handles HTTP requests.
|
// server holds the model and handles HTTP requests.
|
||||||
type server struct {
|
type server struct {
|
||||||
mode ModelMode
|
|
||||||
modelName string
|
modelName string
|
||||||
port int
|
port int
|
||||||
|
|
||||||
// Image generation model (when mode == ModeImageGen)
|
// Image generation model.
|
||||||
imageModel ImageModel
|
imageModel ImageModel
|
||||||
|
|
||||||
// LLM model (when mode == ModeLLM)
|
|
||||||
llmModel *llmState
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newServer creates a new server instance and loads the appropriate model.
|
// newServer creates a new server instance for image generation models.
|
||||||
func newServer(modelName string, port int, mode ModelMode) (*server, error) {
|
func newServer(modelName string, port int) (*server, error) {
|
||||||
s := &server{
|
s := &server{
|
||||||
mode: mode,
|
|
||||||
modelName: modelName,
|
modelName: modelName,
|
||||||
port: port,
|
port: port,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch mode {
|
if err := s.loadImageModel(); err != nil {
|
||||||
case ModeImageGen:
|
return nil, fmt.Errorf("failed to load image model: %w", err)
|
||||||
if err := s.loadImageModel(); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load image model: %w", err)
|
|
||||||
}
|
|
||||||
case ModeLLM:
|
|
||||||
if err := s.loadLLMModel(); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load LLM model: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
@@ -163,41 +149,5 @@ func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch s.mode {
|
s.handleImageCompletion(w, r, req)
|
||||||
case ModeImageGen:
|
|
||||||
s.handleImageCompletion(w, r, req)
|
|
||||||
case ModeLLM:
|
|
||||||
s.handleLLMCompletion(w, r, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if s.llmModel == nil {
|
|
||||||
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tok := s.llmModel.model.Tokenizer()
|
|
||||||
tokens := tok.Encode(req.Content, false)
|
|
||||||
|
|
||||||
// Convert int32 to int for JSON response
|
|
||||||
intTokens := make([]int, len(tokens))
|
|
||||||
for i, t := range tokens {
|
|
||||||
intTokens[i] = int(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,13 +30,12 @@ import (
|
|||||||
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
|
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
|
||||||
//
|
//
|
||||||
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
|
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
|
||||||
// like any other model. It supports both LLM (safetensors) and image generation models.
|
// like any other model. It is used for image generation models.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
port int
|
port int
|
||||||
modelName string
|
modelName string
|
||||||
mode ModelMode
|
|
||||||
vramSize uint64
|
vramSize uint64
|
||||||
done chan error
|
done chan error
|
||||||
client *http.Client
|
client *http.Client
|
||||||
@@ -45,7 +44,7 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
|
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
|
||||||
func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
func NewServer(modelName string) (*Server, error) {
|
||||||
// Validate platform support before attempting to start
|
// Validate platform support before attempting to start
|
||||||
if err := CheckPlatformSupport(); err != nil {
|
if err := CheckPlatformSupport(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -119,7 +118,6 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
|||||||
cmd: cmd,
|
cmd: cmd,
|
||||||
port: port,
|
port: port,
|
||||||
modelName: modelName,
|
modelName: modelName,
|
||||||
mode: mode,
|
|
||||||
vramSize: vramSize,
|
vramSize: vramSize,
|
||||||
done: make(chan error, 1),
|
done: make(chan error, 1),
|
||||||
client: &http.Client{Timeout: 10 * time.Minute},
|
client: &http.Client{Timeout: 10 * time.Minute},
|
||||||
@@ -145,7 +143,7 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
|
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port)
|
||||||
if err := cmd.Start(); err != nil {
|
if err := cmd.Start(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
|
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
|
||||||
}
|
}
|
||||||
@@ -396,36 +394,7 @@ func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, e
|
|||||||
|
|
||||||
// Tokenize tokenizes the input content.
|
// Tokenize tokenizes the input content.
|
||||||
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
|
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
body, err := json.Marshal(map[string]string{"content": content})
|
return nil, errors.New("tokenization not supported for image generation models")
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
resp, err := s.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
var result struct {
|
|
||||||
Tokens []int `json:"tokens"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.Tokens, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detokenize converts tokens back to text.
|
// Detokenize converts tokens back to text.
|
||||||
|
|||||||
@@ -30,21 +30,80 @@ type cacheSession struct {
|
|||||||
remaining []int32
|
remaining []int32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func appendCacheState(dst []*mlx.Array, c cache.Cache) []*mlx.Array {
|
||||||
|
if c == nil {
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, values := c.State()
|
||||||
|
if keys != nil && keys.Valid() {
|
||||||
|
dst = append(dst, keys)
|
||||||
|
}
|
||||||
|
if values != nil && values.Valid() {
|
||||||
|
dst = append(dst, values)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *kvCache) free() {
|
||||||
|
for i, kv := range c.caches {
|
||||||
|
if kv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kv.Free()
|
||||||
|
c.caches[i] = nil
|
||||||
|
}
|
||||||
|
c.caches = nil
|
||||||
|
c.tokens = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *kvCache) cachesCanTrim() bool {
|
||||||
|
for _, kv := range c.caches {
|
||||||
|
if kv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !kv.CanTrim() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *kvCache) trimToPrefix(prefix int) {
|
||||||
|
for _, kv := range c.caches {
|
||||||
|
if kv == nil || !kv.CanTrim() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if trim := kv.Offset() - prefix; trim > 0 {
|
||||||
|
kv.Trim(trim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if prefix < len(c.tokens) {
|
||||||
|
c.tokens = c.tokens[:prefix]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// begin prepares caches for a new request. It finds the nearest
|
// begin prepares caches for a new request. It finds the nearest
|
||||||
// matching cache or creates new caches if none match.
|
// matching cache or creates new caches if none match.
|
||||||
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||||
if len(c.caches) == 0 {
|
ensureCaches := func() {
|
||||||
|
if len(c.caches) != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
||||||
c.caches = cacheFactory.NewCaches()
|
c.caches = cacheFactory.NewCaches()
|
||||||
} else {
|
return
|
||||||
c.caches = make([]cache.Cache, m.NumLayers())
|
}
|
||||||
for i := range c.caches {
|
c.caches = make([]cache.Cache, m.NumLayers())
|
||||||
c.caches[i] = cache.NewKVCache()
|
for i := range c.caches {
|
||||||
}
|
c.caches[i] = cache.NewKVCache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ensureCaches()
|
||||||
|
|
||||||
remaining := c.findRemaining(inputs)
|
remaining := c.findRemaining(inputs)
|
||||||
|
ensureCaches()
|
||||||
|
|
||||||
return &cacheSession{
|
return &cacheSession{
|
||||||
cache: c,
|
cache: c,
|
||||||
@@ -56,18 +115,36 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|||||||
|
|
||||||
// close saves the token state if the forward pass ran.
|
// close saves the token state if the forward pass ran.
|
||||||
func (s *cacheSession) close() {
|
func (s *cacheSession) close() {
|
||||||
if offset := s.caches[0].Offset(); offset > 0 {
|
if len(s.caches) == 0 {
|
||||||
// Ensure that if we have run the forward pass and set the metadata
|
return
|
||||||
// that we also actually have the data
|
|
||||||
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
|
||||||
for _, c := range s.caches {
|
|
||||||
k, v := c.State()
|
|
||||||
arrays = append(arrays, k, v)
|
|
||||||
}
|
|
||||||
mlx.AsyncEval(arrays...)
|
|
||||||
|
|
||||||
s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
offset := -1
|
||||||
|
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
||||||
|
for _, kv := range s.caches {
|
||||||
|
if kv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Mixed cache types (e.g. recurrent + KV) can transiently report different
|
||||||
|
// offsets, so use the minimum as the safe reusable token prefix.
|
||||||
|
if off := kv.Offset(); offset < 0 || off < offset {
|
||||||
|
offset = off
|
||||||
|
}
|
||||||
|
arrays = appendCacheState(arrays, kv)
|
||||||
|
}
|
||||||
|
if offset <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that if we have run the forward pass and set the metadata
|
||||||
|
// that we also actually have the data.
|
||||||
|
mlx.AsyncEval(arrays...)
|
||||||
|
|
||||||
|
stored := append(s.inputs, s.outputs...)
|
||||||
|
if offset > len(stored) {
|
||||||
|
offset = len(stored)
|
||||||
|
}
|
||||||
|
s.cache.tokens = stored[:offset]
|
||||||
}
|
}
|
||||||
|
|
||||||
// findRemaining finds the longest common prefix between tokens and the cached
|
// findRemaining finds the longest common prefix between tokens and the cached
|
||||||
@@ -85,11 +162,13 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if prefix < len(c.tokens) {
|
if prefix < len(c.tokens) {
|
||||||
trim := len(c.tokens) - prefix
|
if c.cachesCanTrim() {
|
||||||
for _, kv := range c.caches {
|
c.trimToPrefix(prefix)
|
||||||
kv.Trim(trim)
|
} else {
|
||||||
|
c.free()
|
||||||
|
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
|
||||||
|
return tokens
|
||||||
}
|
}
|
||||||
c.tokens = c.tokens[:prefix]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefix == 0 {
|
if prefix == 0 {
|
||||||
@@ -104,10 +183,21 @@ func (c *kvCache) log() {
|
|||||||
if len(c.caches) == 0 {
|
if len(c.caches) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
offset := -1
|
||||||
var totalBytes int
|
var totalBytes int
|
||||||
for _, kv := range c.caches {
|
for _, kv := range c.caches {
|
||||||
k, v := kv.State()
|
if kv == nil {
|
||||||
totalBytes += k.NumBytes() + v.NumBytes()
|
continue
|
||||||
|
}
|
||||||
|
if off := kv.Offset(); offset < 0 || off < offset {
|
||||||
|
offset = off
|
||||||
|
}
|
||||||
|
for _, a := range appendCacheState(nil, kv) {
|
||||||
|
totalBytes += a.NumBytes()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
if offset < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
|
||||||
}
|
}
|
||||||
|
|||||||
18
x/mlxrunner/cache/cache.go
vendored
18
x/mlxrunner/cache/cache.go
vendored
@@ -9,7 +9,9 @@ import (
|
|||||||
|
|
||||||
type Cache interface {
|
type Cache interface {
|
||||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||||
|
// State returns the cache-owned state roots that should be kept/evaluated.
|
||||||
State() (keys, values *mlx.Array)
|
State() (keys, values *mlx.Array)
|
||||||
|
CanTrim() bool
|
||||||
Trim(int) int
|
Trim(int) int
|
||||||
Clone() Cache
|
Clone() Cache
|
||||||
Free()
|
Free()
|
||||||
@@ -60,13 +62,15 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
||||||
if c.offset == c.keys.Dim(2) {
|
if c.keys == nil || c.values == nil {
|
||||||
return c.keys, c.values
|
return nil, nil
|
||||||
}
|
}
|
||||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *KVCache) CanTrim() bool { return true }
|
||||||
|
|
||||||
func (c *KVCache) Trim(n int) int {
|
func (c *KVCache) Trim(n int) int {
|
||||||
n = min(c.offset, n)
|
n = min(c.offset, n)
|
||||||
c.offset -= n
|
c.offset -= n
|
||||||
@@ -183,13 +187,15 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
|
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
|
||||||
if c.offset < c.keys.Dim(2) {
|
if c.keys == nil || c.values == nil {
|
||||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
return nil, nil
|
||||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
|
||||||
}
|
}
|
||||||
return c.keys, c.values
|
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||||
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *RotatingKVCache) CanTrim() bool { return true }
|
||||||
|
|
||||||
func (c *RotatingKVCache) Trim(n int) int {
|
func (c *RotatingKVCache) Trim(n int) int {
|
||||||
n = min(c.offset, n)
|
n = min(c.offset, n)
|
||||||
c.offset -= n
|
c.offset -= n
|
||||||
|
|||||||
161
x/mlxrunner/cache/recurrent.go
vendored
Normal file
161
x/mlxrunner/cache/recurrent.go
vendored
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
|
||||||
|
// RecurrentCache stores state for linear-recurrent layers.
|
||||||
|
//
|
||||||
|
// Conv state shape: [B, convTail, convDim]
|
||||||
|
// Delta state shape: [B, numVHeads, headVDim, headKDim]
|
||||||
|
type RecurrentCache struct {
|
||||||
|
convState *mlx.Array
|
||||||
|
deltaState *mlx.Array
|
||||||
|
offset int
|
||||||
|
|
||||||
|
convTail int
|
||||||
|
convDim int
|
||||||
|
numVHeads int
|
||||||
|
headVDim int
|
||||||
|
headKDim int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
|
||||||
|
if v == nil || !v.Valid() {
|
||||||
|
return old
|
||||||
|
}
|
||||||
|
if old == v {
|
||||||
|
return old
|
||||||
|
}
|
||||||
|
|
||||||
|
mlx.Pin(v)
|
||||||
|
if old != nil && old != v {
|
||||||
|
mlx.Unpin(old)
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bool) *mlx.Array {
|
||||||
|
if v == nil || !v.Valid() {
|
||||||
|
return old
|
||||||
|
}
|
||||||
|
if old == v {
|
||||||
|
return old
|
||||||
|
}
|
||||||
|
|
||||||
|
root := v
|
||||||
|
if ensureContiguous {
|
||||||
|
root = mlx.Contiguous(v, false)
|
||||||
|
}
|
||||||
|
detached := root.Clone()
|
||||||
|
|
||||||
|
mlx.Pin(detached)
|
||||||
|
if old != nil && old != detached {
|
||||||
|
mlx.Unpin(old)
|
||||||
|
}
|
||||||
|
|
||||||
|
return detached
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotPinned(a *mlx.Array) *mlx.Array {
|
||||||
|
if a == nil || !a.Valid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
snap := mlx.Copy(a)
|
||||||
|
mlx.Eval(snap)
|
||||||
|
mlx.Pin(snap)
|
||||||
|
return snap
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||||
|
return &RecurrentCache{
|
||||||
|
convTail: int(convTail),
|
||||||
|
convDim: int(convDim),
|
||||||
|
numVHeads: int(numVHeads),
|
||||||
|
headVDim: int(headVDim),
|
||||||
|
headKDim: int(headKDim),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
|
||||||
|
if batch <= 0 {
|
||||||
|
batch = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
|
||||||
|
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
|
||||||
|
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
|
||||||
|
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
|
||||||
|
if !needConv && !needDelta {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if needConv {
|
||||||
|
c.convState = c.setStateRaw(c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
|
||||||
|
}
|
||||||
|
if needDelta {
|
||||||
|
c.deltaState = c.setStateRaw(c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
|
||||||
|
c.ensure(batch, dtype)
|
||||||
|
return c.convState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
|
||||||
|
c.convState = c.setStateDetached(c.convState, v, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
|
||||||
|
c.ensure(batch, dtype)
|
||||||
|
return c.deltaState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
|
||||||
|
c.deltaState = c.setStateDetached(c.deltaState, v, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Advance(n int) {
|
||||||
|
c.offset += n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||||
|
return keys, values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
|
||||||
|
return c.convState, c.deltaState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) CanTrim() bool { return false }
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Trim(n int) int {
|
||||||
|
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
|
||||||
|
_ = n
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Clone() Cache {
|
||||||
|
clone := &RecurrentCache{
|
||||||
|
offset: c.offset,
|
||||||
|
convTail: c.convTail,
|
||||||
|
convDim: c.convDim,
|
||||||
|
numVHeads: c.numVHeads,
|
||||||
|
headVDim: c.headVDim,
|
||||||
|
headKDim: c.headKDim,
|
||||||
|
convState: snapshotPinned(c.convState),
|
||||||
|
deltaState: snapshotPinned(c.deltaState),
|
||||||
|
}
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Free() {
|
||||||
|
mlx.Unpin(c.convState, c.deltaState)
|
||||||
|
c.convState, c.deltaState = nil, nil
|
||||||
|
c.offset = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||||
|
func (c *RecurrentCache) Len() int { return c.offset }
|
||||||
@@ -202,7 +202,6 @@ type CompletionResponse struct {
|
|||||||
PromptEvalDuration time.Duration
|
PromptEvalDuration time.Duration
|
||||||
EvalCount int
|
EvalCount int
|
||||||
EvalDuration time.Duration
|
EvalDuration time.Duration
|
||||||
PeakMemory uint64
|
|
||||||
|
|
||||||
Error *api.StatusError
|
Error *api.StatusError
|
||||||
}
|
}
|
||||||
@@ -284,7 +283,6 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
PromptEvalDuration: raw.PromptEvalDuration,
|
PromptEvalDuration: raw.PromptEvalDuration,
|
||||||
EvalCount: raw.EvalCount,
|
EvalCount: raw.EvalCount,
|
||||||
EvalDuration: raw.EvalDuration,
|
EvalDuration: raw.EvalDuration,
|
||||||
PeakMemory: raw.PeakMemory,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(cresp)
|
fn(cresp)
|
||||||
|
|||||||
@@ -7,4 +7,6 @@ import (
|
|||||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||||
_ "github.com/ollama/ollama/x/models/llama"
|
_ "github.com/ollama/ollama/x/models/llama"
|
||||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||||
|
_ "github.com/ollama/ollama/x/models/qwen3_5"
|
||||||
|
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
type Array struct {
|
type Array struct {
|
||||||
ctx C.mlx_array
|
ctx C.mlx_array
|
||||||
name string
|
name string
|
||||||
pinned bool
|
pinned int
|
||||||
}
|
}
|
||||||
|
|
||||||
var arrays []*Array
|
var arrays []*Array
|
||||||
@@ -129,7 +129,7 @@ func (t *Array) Clone() *Array {
|
|||||||
func Pin(s ...*Array) {
|
func Pin(s ...*Array) {
|
||||||
for _, t := range s {
|
for _, t := range s {
|
||||||
if t != nil {
|
if t != nil {
|
||||||
t.pinned = true
|
t.pinned++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -138,7 +138,7 @@ func Pin(s ...*Array) {
|
|||||||
func Unpin(s ...*Array) {
|
func Unpin(s ...*Array) {
|
||||||
for _, t := range s {
|
for _, t := range s {
|
||||||
if t != nil {
|
if t != nil {
|
||||||
t.pinned = false
|
t.pinned--
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -148,7 +148,7 @@ func Unpin(s ...*Array) {
|
|||||||
func Sweep() {
|
func Sweep() {
|
||||||
n := 0
|
n := 0
|
||||||
for _, t := range arrays {
|
for _, t := range arrays {
|
||||||
if t.pinned && t.Valid() {
|
if t.pinned > 0 && t.Valid() {
|
||||||
arrays[n] = t
|
arrays[n] = t
|
||||||
n++
|
n++
|
||||||
} else if t.Valid() {
|
} else if t.Valid() {
|
||||||
@@ -175,7 +175,7 @@ func (t *Array) String() string {
|
|||||||
func (t *Array) LogValue() slog.Value {
|
func (t *Array) LogValue() slog.Value {
|
||||||
attrs := []slog.Attr{
|
attrs := []slog.Attr{
|
||||||
slog.String("name", t.name),
|
slog.String("name", t.name),
|
||||||
slog.Bool("pinned", t.pinned),
|
slog.Int("pinned", t.pinned),
|
||||||
}
|
}
|
||||||
if t.Valid() {
|
if t.Valid() {
|
||||||
attrs = append(attrs,
|
attrs = append(attrs,
|
||||||
|
|||||||
370
x/mlxrunner/mlx/gated_delta.go
Normal file
370
x/mlxrunner/mlx/gated_delta.go
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package mlx
|
||||||
|
|
||||||
|
// #include <stdlib.h>
|
||||||
|
// #include "generated.h"
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
gatedDeltaMetalKernelOnce sync.Once
|
||||||
|
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||||
|
gatedDeltaMetalDisabled bool
|
||||||
|
)
|
||||||
|
|
||||||
|
const gatedDeltaMetalKernelSource = `
|
||||||
|
auto n = thread_position_in_grid.z;
|
||||||
|
auto b_idx = n / Hv;
|
||||||
|
auto hv_idx = n % Hv;
|
||||||
|
auto hk_idx = hv_idx / (Hv / Hk);
|
||||||
|
constexpr int n_per_t = Dk / 32;
|
||||||
|
|
||||||
|
// q, k: [B, T, Hk, Dk]
|
||||||
|
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||||
|
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||||
|
|
||||||
|
// v, y: [B, T, Hv, Dv]
|
||||||
|
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||||
|
y += b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||||
|
|
||||||
|
auto dk_idx = thread_position_in_threadgroup.x;
|
||||||
|
auto dv_idx = thread_position_in_grid.y;
|
||||||
|
|
||||||
|
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||||
|
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||||
|
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||||
|
|
||||||
|
float state[n_per_t];
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = static_cast<float>(i_state[s_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// g: [B, T, Hv]
|
||||||
|
auto g_ = g + b_idx * T * Hv;
|
||||||
|
auto beta_ = beta + b_idx * T * Hv;
|
||||||
|
|
||||||
|
for (int t = 0; t < T; ++t) {
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = state[i] * g_[hv_idx];
|
||||||
|
kv_mem += state[i] * k_[s_idx];
|
||||||
|
}
|
||||||
|
kv_mem = simd_sum(kv_mem);
|
||||||
|
|
||||||
|
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
|
||||||
|
|
||||||
|
float out = 0.0f;
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = state[i] + k_[s_idx] * delta;
|
||||||
|
out += state[i] * q_[s_idx];
|
||||||
|
}
|
||||||
|
out = simd_sum(out);
|
||||||
|
if (thread_index_in_simdgroup == 0) {
|
||||||
|
y[dv_idx] = static_cast<InT>(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
q_ += Hk * Dk;
|
||||||
|
k_ += Hk * Dk;
|
||||||
|
v_ += Hv * Dv;
|
||||||
|
y += Hv * Dv;
|
||||||
|
g_ += Hv;
|
||||||
|
beta_ += Hv;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
o_state[s_idx] = static_cast<InT>(state[i]);
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||||
|
vec := C.mlx_vector_string_new()
|
||||||
|
ok := true
|
||||||
|
for _, s := range values {
|
||||||
|
cs := C.CString(s)
|
||||||
|
if C.mlx_vector_string_append_value(vec, cs) != 0 {
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
C.free(unsafe.Pointer(cs))
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cleanup := func() {
|
||||||
|
C.mlx_vector_string_free(vec)
|
||||||
|
}
|
||||||
|
return vec, cleanup, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func initGatedDeltaMetalKernel() {
|
||||||
|
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||||
|
if !ok {
|
||||||
|
gatedDeltaMetalDisabled = true
|
||||||
|
freeInputs()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer freeInputs()
|
||||||
|
|
||||||
|
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||||
|
if !ok {
|
||||||
|
gatedDeltaMetalDisabled = true
|
||||||
|
freeOutputs()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer freeOutputs()
|
||||||
|
|
||||||
|
cName := C.CString("gated_delta_step")
|
||||||
|
defer C.free(unsafe.Pointer(cName))
|
||||||
|
cSource := C.CString(gatedDeltaMetalKernelSource)
|
||||||
|
defer C.free(unsafe.Pointer(cSource))
|
||||||
|
cHeader := C.CString("")
|
||||||
|
defer C.free(unsafe.Pointer(cHeader))
|
||||||
|
|
||||||
|
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
|
||||||
|
cName,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
cSource,
|
||||||
|
cHeader,
|
||||||
|
C.bool(true),
|
||||||
|
C.bool(false),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// gatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
|
||||||
|
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
|
||||||
|
func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||||
|
if gatedDeltaMetalDisabled {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
qd := q.Dims()
|
||||||
|
kd := k.Dims()
|
||||||
|
vd := v.Dims()
|
||||||
|
gd := g.Dims()
|
||||||
|
bd := beta.Dims()
|
||||||
|
sd := state.Dims()
|
||||||
|
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||||
|
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
Hv, Dv := vd[2], vd[3]
|
||||||
|
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
dtype := q.DType()
|
||||||
|
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
|
||||||
|
if gatedDeltaMetalDisabled {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := C.mlx_fast_metal_kernel_config_new()
|
||||||
|
defer C.mlx_fast_metal_kernel_config_free(cfg)
|
||||||
|
|
||||||
|
cInT := C.CString("InT")
|
||||||
|
defer C.free(unsafe.Pointer(cInT))
|
||||||
|
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaMetalDisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
for _, tpl := range []struct {
|
||||||
|
name string
|
||||||
|
value int
|
||||||
|
}{
|
||||||
|
{name: "Dk", value: Dk},
|
||||||
|
{name: "Dv", value: Dv},
|
||||||
|
{name: "Hk", value: Hk},
|
||||||
|
{name: "Hv", value: Hv},
|
||||||
|
} {
|
||||||
|
cn := C.CString(tpl.name)
|
||||||
|
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||||
|
C.free(unsafe.Pointer(cn))
|
||||||
|
if rc != 0 {
|
||||||
|
gatedDeltaMetalDisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||||
|
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||||
|
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaMetalDisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaMetalDisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||||
|
gatedDeltaMetalDisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
threadY := Dv
|
||||||
|
if threadY > 4 {
|
||||||
|
threadY = 4
|
||||||
|
}
|
||||||
|
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||||
|
gatedDeltaMetalDisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
tScalar := FromValue(T)
|
||||||
|
inputs := []C.mlx_array{
|
||||||
|
q.ctx,
|
||||||
|
k.ctx,
|
||||||
|
v.ctx,
|
||||||
|
g.ctx,
|
||||||
|
beta.ctx,
|
||||||
|
state.ctx,
|
||||||
|
tScalar.ctx,
|
||||||
|
}
|
||||||
|
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||||
|
defer C.mlx_vector_array_free(inVec)
|
||||||
|
|
||||||
|
outVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(outVec)
|
||||||
|
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||||
|
gatedDeltaMetalDisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
y = New("GATED_DELTA_METAL_Y")
|
||||||
|
nextState = New("GATED_DELTA_METAL_STATE")
|
||||||
|
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||||
|
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||||
|
return y, nextState, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func repeatHeadsForGatedDelta(x *Array, repeatFactor int) *Array {
|
||||||
|
if repeatFactor <= 1 {
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
shape := x.Dims()
|
||||||
|
x = ExpandDims(x, 3)
|
||||||
|
x = Tile(x, []int32{1, 1, 1, int32(repeatFactor), 1})
|
||||||
|
return Reshape(x, int32(shape[0]), int32(shape[1]), int32(shape[2]*repeatFactor), int32(shape[3]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
||||||
|
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
qd := q.Dims()
|
||||||
|
kd := k.Dims()
|
||||||
|
vd := v.Dims()
|
||||||
|
gd := g.Dims()
|
||||||
|
bd := beta.Dims()
|
||||||
|
sd := state.Dims()
|
||||||
|
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
B, T, Hk, Dk := int32(qd[0]), int32(qd[1]), int32(qd[2]), int32(qd[3])
|
||||||
|
Hv, Dv := int32(vd[2]), int32(vd[3])
|
||||||
|
if T <= 0 || Hk <= 0 || Dk <= 0 || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if kd[0] != int(B) || kd[1] != int(T) || kd[2] != int(Hk) || kd[3] != int(Dk) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if vd[0] != int(B) || vd[1] != int(T) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if gd[0] != int(B) || gd[1] != int(T) || gd[2] != int(Hv) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if bd[0] != int(B) || bd[1] != int(T) || bd[2] != int(Hv) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if sd[0] != int(B) || sd[1] != int(Hv) || sd[2] != int(Dv) || sd[3] != int(Dk) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
repeatFactor := int(Hv / Hk)
|
||||||
|
q = repeatHeadsForGatedDelta(q, repeatFactor)
|
||||||
|
k = repeatHeadsForGatedDelta(k, repeatFactor)
|
||||||
|
|
||||||
|
nextState = state
|
||||||
|
if T == 1 {
|
||||||
|
qt := Squeeze(q, 1)
|
||||||
|
kt := Squeeze(k, 1)
|
||||||
|
vt := Squeeze(v, 1)
|
||||||
|
gt := Squeeze(g, 1)
|
||||||
|
bt := Squeeze(beta, 1)
|
||||||
|
|
||||||
|
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
|
||||||
|
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
|
||||||
|
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
|
||||||
|
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
|
||||||
|
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
|
||||||
|
return ExpandDims(yt, 1), nextState
|
||||||
|
}
|
||||||
|
|
||||||
|
outs := make([]*Array, 0, T)
|
||||||
|
for t := int32(0); t < T; t++ {
|
||||||
|
qt := Squeeze(SliceStartStop(q, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
|
||||||
|
kt := Squeeze(SliceStartStop(k, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
|
||||||
|
vt := Squeeze(SliceStartStop(v, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dv}), 1)
|
||||||
|
gt := Squeeze(SliceStartStop(g, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
|
||||||
|
bt := Squeeze(SliceStartStop(beta, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
|
||||||
|
|
||||||
|
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
|
||||||
|
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
|
||||||
|
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
|
||||||
|
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
|
||||||
|
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
|
||||||
|
outs = append(outs, ExpandDims(yt, 1))
|
||||||
|
}
|
||||||
|
return Concatenate(outs, 1), nextState
|
||||||
|
}
|
||||||
|
|
||||||
|
// GatedDelta runs the recurrent update operation.
|
||||||
|
//
|
||||||
|
// It uses the fused Metal kernel when available and otherwise falls back to a
|
||||||
|
// backend-agnostic MLX implementation with identical inputs/outputs.
|
||||||
|
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
||||||
|
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
|
||||||
|
return y, nextState
|
||||||
|
}
|
||||||
|
y, nextState = gatedDeltaFallback(q, k, v, g, beta, state)
|
||||||
|
if y == nil || nextState == nil {
|
||||||
|
panic("mlx.GatedDelta: fallback failed (invalid inputs or unsupported shapes)")
|
||||||
|
}
|
||||||
|
return y, nextState
|
||||||
|
}
|
||||||
@@ -19,7 +19,7 @@ func doEval(outputs []*Array, async bool) {
|
|||||||
defer C.mlx_vector_array_free(vector)
|
defer C.mlx_vector_array_free(vector)
|
||||||
|
|
||||||
for _, output := range outputs {
|
for _, output := range outputs {
|
||||||
if output.Valid() {
|
if output != nil && output.Valid() {
|
||||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -113,6 +113,35 @@ func Where(condition, a, b *Array) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
|
||||||
|
out := New("CONV1D")
|
||||||
|
C.mlx_conv1d(
|
||||||
|
&out.ctx,
|
||||||
|
x.ctx,
|
||||||
|
weight.ctx,
|
||||||
|
C.int(stride),
|
||||||
|
C.int(padding),
|
||||||
|
C.int(dilation),
|
||||||
|
C.int(groups),
|
||||||
|
DefaultStream().ctx,
|
||||||
|
)
|
||||||
|
if bias != nil && bias.Valid() {
|
||||||
|
out = Add(out, bias)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||||
|
out := New("CONTIGUOUS")
|
||||||
|
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||||
|
groups := int32(x.Dim(x.NumDims() - 1))
|
||||||
|
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||||
|
}
|
||||||
|
|
||||||
// Convenience wrappers (function-style for the model code)
|
// Convenience wrappers (function-style for the model code)
|
||||||
|
|
||||||
func Stack(arrays []*Array, axis int) *Array {
|
func Stack(arrays []*Array, axis int) *Array {
|
||||||
@@ -271,6 +300,24 @@ func Sigmoid(a *Array) *Array {
|
|||||||
return a.Sigmoid()
|
return a.Sigmoid()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Exp(a *Array) *Array {
|
||||||
|
out := New("EXP")
|
||||||
|
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func Log(a *Array) *Array {
|
||||||
|
out := New("LOG")
|
||||||
|
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||||
|
out := New("SOFTMAX_AXIS")
|
||||||
|
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
||||||
mask := New("")
|
mask := New("")
|
||||||
sinks := New("")
|
sinks := New("")
|
||||||
@@ -288,7 +335,11 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
|||||||
|
|
||||||
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
||||||
out := New("FAST_RMSNORM")
|
out := New("FAST_RMSNORM")
|
||||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
|
var w C.mlx_array
|
||||||
|
if weight != nil {
|
||||||
|
w = weight.ctx
|
||||||
|
}
|
||||||
|
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -378,6 +429,15 @@ func Collect(v any) []*Array {
|
|||||||
return arrays
|
return arrays
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Copy(a *Array) *Array {
|
||||||
|
if a == nil || !a.Valid() {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
out := New("COPY")
|
||||||
|
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
||||||
if !v.IsValid() {
|
if !v.IsValid() {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -16,11 +16,26 @@ import (
|
|||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func prefillChunkSize() int {
|
||||||
|
return 2 << 10
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||||
if r.Model == nil {
|
if r.Model == nil {
|
||||||
return errors.New("model not loaded")
|
return errors.New("model not loaded")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enableCompile := true
|
||||||
|
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||||
|
enableCompile = modelCompile.EnableCompile()
|
||||||
|
}
|
||||||
|
if enableCompile {
|
||||||
|
mlx.EnableCompile()
|
||||||
|
} else {
|
||||||
|
mlx.DisableCompile()
|
||||||
|
}
|
||||||
|
mlx.ResetPeakMemory()
|
||||||
|
ctx := request.Ctx
|
||||||
var (
|
var (
|
||||||
sample, logprobs *mlx.Array
|
sample, logprobs *mlx.Array
|
||||||
nextSample, nextLogprobs *mlx.Array
|
nextSample, nextLogprobs *mlx.Array
|
||||||
@@ -36,19 +51,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
mlx.LogArrays()
|
mlx.LogArrays()
|
||||||
r.cache.log()
|
r.cache.log()
|
||||||
}
|
}
|
||||||
|
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
enableCompile := true
|
|
||||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
|
||||||
enableCompile = modelCompile.EnableCompile()
|
|
||||||
}
|
|
||||||
if enableCompile {
|
|
||||||
mlx.EnableCompile()
|
|
||||||
} else {
|
|
||||||
mlx.DisableCompile()
|
|
||||||
}
|
|
||||||
mlx.ResetPeakMemory()
|
|
||||||
|
|
||||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||||
if len(inputs) == 0 {
|
if len(inputs) == 0 {
|
||||||
return errors.New("empty prompt")
|
return errors.New("empty prompt")
|
||||||
@@ -73,24 +78,30 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
defer session.close()
|
defer session.close()
|
||||||
caches := session.caches
|
caches := session.caches
|
||||||
tokens := session.remaining
|
tokens := session.remaining
|
||||||
|
prefillChunk := prefillChunkSize()
|
||||||
|
|
||||||
|
materializeCaches := func() {
|
||||||
|
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||||
|
for _, c := range caches {
|
||||||
|
state = appendCacheState(state, c)
|
||||||
|
}
|
||||||
|
if len(state) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mlx.Eval(state...)
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
total, processed := len(tokens), 0
|
total, processed := len(tokens), 0
|
||||||
for total-processed > 1 {
|
for total-processed > 1 {
|
||||||
if err := request.Ctx.Err(); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
n := min(2<<10, total-processed-1)
|
n := min(prefillChunk, total-processed-1)
|
||||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
mlx.Eval(func() []*mlx.Array {
|
materializeCaches()
|
||||||
s := make([]*mlx.Array, 2*len(caches))
|
|
||||||
for i, c := range caches {
|
|
||||||
s[2*i], s[2*i+1] = c.State()
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}()...)
|
|
||||||
processed += n
|
processed += n
|
||||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
@@ -117,7 +128,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
|
|
||||||
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||||
for i := range request.Options.MaxTokens {
|
for i := range request.Options.MaxTokens {
|
||||||
if err := request.Ctx.Err(); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,8 +150,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-request.Ctx.Done():
|
case <-ctx.Done():
|
||||||
return request.Ctx.Err()
|
return ctx.Err()
|
||||||
case request.Responses <- CompletionResponse{
|
case request.Responses <- CompletionResponse{
|
||||||
Content: r.Decode(output, &b),
|
Content: r.Decode(output, &b),
|
||||||
}:
|
}:
|
||||||
@@ -156,10 +167,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
final.EvalDuration = time.Since(now)
|
final.EvalDuration = time.Since(now)
|
||||||
final.PeakMemory = uint64(mlx.PeakMemory())
|
|
||||||
select {
|
select {
|
||||||
case <-request.Ctx.Done():
|
case <-ctx.Done():
|
||||||
return request.Ctx.Err()
|
return ctx.Err()
|
||||||
case request.Responses <- final:
|
case request.Responses <- final:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,40 @@ type LinearLayer interface {
|
|||||||
OutputDim() int32
|
OutputDim() int32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Conv1d applies 1D convolution over NLC input.
|
||||||
|
type Conv1d struct {
|
||||||
|
Weight *mlx.Array
|
||||||
|
Bias *mlx.Array
|
||||||
|
Stride int32
|
||||||
|
Padding int32
|
||||||
|
Dilation int32
|
||||||
|
Groups int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
|
||||||
|
if stride <= 0 {
|
||||||
|
stride = 1
|
||||||
|
}
|
||||||
|
if dilation <= 0 {
|
||||||
|
dilation = 1
|
||||||
|
}
|
||||||
|
if groups <= 0 {
|
||||||
|
groups = 1
|
||||||
|
}
|
||||||
|
return &Conv1d{
|
||||||
|
Weight: weight,
|
||||||
|
Bias: bias,
|
||||||
|
Stride: stride,
|
||||||
|
Padding: padding,
|
||||||
|
Dilation: dilation,
|
||||||
|
Groups: groups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
|
||||||
|
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
|
||||||
|
}
|
||||||
|
|
||||||
// Linear applies an affine transformation: y = x @ W.T + b
|
// Linear applies an affine transformation: y = x @ W.T + b
|
||||||
type Linear struct {
|
type Linear struct {
|
||||||
Weight *mlx.Array
|
Weight *mlx.Array
|
||||||
|
|||||||
1387
x/models/qwen3_5/qwen3_5.go
Normal file
1387
x/models/qwen3_5/qwen3_5.go
Normal file
File diff suppressed because it is too large
Load Diff
159
x/models/qwen3_5/qwen3_5_test.go
Normal file
159
x/models/qwen3_5/qwen3_5_test.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package qwen3_5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseConfigNestedDefaults(t *testing.T) {
|
||||||
|
data := []byte(`{
|
||||||
|
"model_type": "Qwen3_5MoeForConditionalGeneration",
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 4096,
|
||||||
|
"intermediate_size": 14336,
|
||||||
|
"num_hidden_layers": 8,
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"head_dim": 128,
|
||||||
|
"linear_num_value_heads": 64,
|
||||||
|
"linear_num_key_heads": 16,
|
||||||
|
"linear_key_head_dim": 128,
|
||||||
|
"linear_value_head_dim": 128,
|
||||||
|
"linear_conv_kernel_dim": 4,
|
||||||
|
"num_experts": 16,
|
||||||
|
"num_experts_per_tok": 4,
|
||||||
|
"moe_intermediate_size": 2048,
|
||||||
|
"shared_expert_intermediate_size": 4096,
|
||||||
|
"rope_parameters": {
|
||||||
|
"rope_theta": 500000,
|
||||||
|
"partial_rotary_factor": 0.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.RopeTheta != 500000 {
|
||||||
|
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
|
||||||
|
}
|
||||||
|
if cfg.RopeDim != 64 {
|
||||||
|
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||||
|
}
|
||||||
|
if cfg.FullAttentionInterval != 4 {
|
||||||
|
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
|
||||||
|
}
|
||||||
|
if !cfg.NormTopKProb {
|
||||||
|
t.Fatalf("norm_topk_prob should default to true for MoE")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLayerSelectionHelpers(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
NumHiddenLayers: 6,
|
||||||
|
FullAttentionInterval: 3,
|
||||||
|
NumExperts: 8,
|
||||||
|
DecoderSparseStep: 2,
|
||||||
|
MLPOnlyLayers: []int32{1},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !layerIsLinear(cfg, 0) {
|
||||||
|
t.Fatalf("layer 0 should be linear")
|
||||||
|
}
|
||||||
|
if layerIsLinear(cfg, 2) {
|
||||||
|
t.Fatalf("layer 2 should be full attention")
|
||||||
|
}
|
||||||
|
|
||||||
|
if layerUsesMoE(cfg, 1) {
|
||||||
|
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
|
||||||
|
}
|
||||||
|
if !layerUsesMoE(cfg, 3) {
|
||||||
|
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveTensorPathLayout(t *testing.T) {
|
||||||
|
dummy := mlx.New("dummy")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
wantContainer string
|
||||||
|
wantModel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "standard",
|
||||||
|
key: "model.embed_tokens.weight",
|
||||||
|
wantContainer: "",
|
||||||
|
wantModel: "model.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested language model with inner model",
|
||||||
|
key: "model.language_model.model.embed_tokens.weight",
|
||||||
|
wantContainer: "model.language_model.",
|
||||||
|
wantModel: "model.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested language model without inner model",
|
||||||
|
key: "model.language_model.embed_tokens.weight",
|
||||||
|
wantContainer: "model.language_model.",
|
||||||
|
wantModel: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
layout := resolveTensorPathLayout(map[string]*mlx.Array{
|
||||||
|
tt.key: dummy,
|
||||||
|
})
|
||||||
|
|
||||||
|
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
|
||||||
|
t.Fatalf(
|
||||||
|
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
|
||||||
|
layout.containerPrefix,
|
||||||
|
layout.modelPrefix,
|
||||||
|
tt.wantContainer,
|
||||||
|
tt.wantModel,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCachesLayout(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Config: &Config{
|
||||||
|
LinearConvKernelDim: 4,
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearKeyHeadDim: 8,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 16,
|
||||||
|
},
|
||||||
|
Layers: []*Layer{
|
||||||
|
{IsLinear: true},
|
||||||
|
{IsLinear: false},
|
||||||
|
{IsLinear: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
caches := m.NewCaches()
|
||||||
|
if len(caches) != len(m.Layers) {
|
||||||
|
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
|
||||||
|
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
|
||||||
|
}
|
||||||
|
if _, ok := caches[1].(*cache.KVCache); !ok {
|
||||||
|
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
|
||||||
|
}
|
||||||
|
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
|
||||||
|
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
|
||||||
|
}
|
||||||
|
}
|
||||||
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
|
||||||
|
package qwen3_5_moe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
|
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
|
||||||
|
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
|
||||||
|
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
|
||||||
|
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user