Compare commits

...

13 Commits

Author SHA1 Message Date
Patrick Devine
67ce53b9b5 wip sampling 2026-02-28 23:39:34 -08:00
Patrick Devine
dd497534c4 allow think/nothink in mlxrunner 2026-02-28 23:35:54 -08:00
Patrick Devine
560626fb43 cleanup 2026-02-28 23:35:53 -08:00
Patrick Devine
1a23c1a810 add qwen3.5 2026-02-28 23:35:53 -08:00
Patrick Devine
a6c1aa4da5 smaller recurrent cache 2026-02-28 23:35:53 -08:00
Jeffrey Morgan
8da09b1e7e qwen3next: add compatibility with imported GGUF models (#14517) 2026-02-28 14:21:42 -08:00
Jesse Gross
a60b9adcce mlxrunner: Fix prompt eval timing and count metrics
Only the last token's processing time is included in prompt processing,
giving an artificially high rate. In addition, the token count
included only the tokens that missed the cache, rather than our
historic total token count.
2026-02-27 17:29:47 -08:00
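
A minimal sketch of the accounting fix described above (all names are hypothetical, not the actual runner code): time every prefill batch rather than only the last one, and count every prompt token, cache hits included.

package main

import (
    "fmt"
    "time"
)

func prefill(batches [][]int32, cachedTokens int, forward func([]int32)) (count int, dur time.Duration) {
    count = cachedTokens // report historic total tokens, including cache hits
    for _, b := range batches {
        start := time.Now()
        forward(b)
        dur += time.Since(start) // accumulate across all batches, not just the final one
        count += len(b)
    }
    return count, dur
}

func main() {
    n, d := prefill([][]int32{make([]int32, 8), make([]int32, 8)}, 16,
        func([]int32) { time.Sleep(10 * time.Millisecond) })
    fmt.Printf("prompt_eval_count: %d, prompt_eval_duration: %v\n", n, d)
}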
Jesse Gross
a16f96658b mlxrunner: Enforce model context limit
Currently, context length is unbounded - the cache will keep
growing forever independent of the model's trained context
length. This caps it and enforces semantics similar to most
cloud services:
 - Long prompts will result in an error, not truncation.
 - Generation that exceeds the context will be stopped.
2026-02-27 17:29:47 -08:00
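
A hedged sketch of the semantics listed above (function and parameter names are assumptions): long prompts fail fast with an error, and generation halts once the next token would exceed the trained context.

package main

import "fmt"

func checkContext(promptLen, generated, numCtx int) (stop bool, err error) {
    if promptLen > numCtx {
        // long prompt: error, never silent truncation
        return false, fmt.Errorf("prompt length (%d) exceeds context length (%d)", promptLen, numCtx)
    }
    if promptLen+generated >= numCtx {
        return true, nil // generation reached the cap: stop, do not grow the cache
    }
    return false, nil
}

func main() {
    if _, err := checkContext(5000, 0, 4096); err != nil {
        fmt.Println("reject:", err)
    }
    stop, _ := checkContext(4000, 96, 4096)
    fmt.Println("stop generation:", stop)
}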
Jesse Gross
18ab09b431 mlxrunner: Propagate pipeline errors to client via api.StatusError
Errors that occur during pipeline processing are currently only
logged but not sent back to the client. Rather than using HTTP
status codes as we have historically done, this serializes errors
as messages to allow sending them at any time during the stream.
2026-02-27 17:29:47 -08:00
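
The wire format below is an assumption; the sketch only illustrates the idea of carrying errors in-band as stream messages, since the HTTP status line is already gone once streaming has started.

package main

import (
    "encoding/json"
    "fmt"
)

type streamMsg struct {
    Content string `json:"content,omitempty"`
    Error   string `json:"error,omitempty"` // may appear at any point in the stream
    Done    bool   `json:"done"`
}

func main() {
    stream := []streamMsg{
        {Content: "partial output"},
        {Error: "pipeline failure mid-generation", Done: true}, // serialized, not an HTTP status
    }
    for _, m := range stream {
        b, _ := json.Marshal(m)
        fmt.Println(string(b)) // client maps this back to an api.StatusError-style error
    }
}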
Jesse Gross
638faeac54 mlxrunner: Report actual memory usage from runner
The MLX runner previously reported a static VRAM estimate that was
computed at load time and consisted only of the weights. This is
strictly less than the actual memory usage, as it does not include
the KV cache or compute graph.
2026-02-27 17:29:47 -08:00
Jesse Gross
dd5eb6337d mlxrunner: Fix panic on full KV cache hit
When the entire prompt was already cached (e.g. repeated prompt),
findRemaining returned an empty slice, causing FromValues to panic
on an index-out-of-range accessing a zero-length byte slice.

Fix by always keeping at least one token to re-evaluate so the
pipeline can seed token generation. Also reject empty prompts
early rather than panicking.
2026-02-27 11:07:03 -08:00
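
In isolation, the keep-one-token rule reads like this (a sketch, not the runner's findRemaining; the real change is in the kv cache diff further down):

package main

import "fmt"

func remaining(cached, prompt []int32) []int32 {
    prefix := 0
    for prefix < len(cached) && prefix < len(prompt) && cached[prefix] == prompt[prefix] {
        prefix++
    }
    // Full cache hit: back off by one so at least one token is re-evaluated
    // and the pipeline has something to seed generation from.
    if prefix == len(prompt) && prefix > 0 {
        prefix--
    }
    return prompt[prefix:]
}

func main() {
    p := []int32{1, 2, 3}
    fmt.Println(remaining(p, p)) // [3], never an empty slice for a repeated prompt
}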
Patrick Devine
79917cf80b show peak memory usage (#14485) 2026-02-26 18:38:27 -08:00
Parth Sareen
cc90a035a0 model/parsers: add stable tool call indexing for glm47 and qwen3 parsers (#14484) 2026-02-26 18:14:29 -08:00
47 changed files with 4033 additions and 302 deletions

View File

@@ -15,6 +15,7 @@ import (
     "github.com/google/uuid"
     "github.com/ollama/ollama/envconfig"
+    "github.com/ollama/ollama/format"
     "github.com/ollama/ollama/internal/orderedmap"
     "github.com/ollama/ollama/types/model"
 )
@@ -569,6 +570,7 @@ type DebugInfo struct {
 type Metrics struct {
     TotalDuration time.Duration `json:"total_duration,omitempty"`
+    PeakMemory    uint64        `json:"peak_memory,omitempty"`
     LoadDuration  time.Duration `json:"load_duration,omitempty"`
     PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
     PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
@@ -934,6 +936,10 @@ func (m *Metrics) Summary() {
         fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
     }
+    if m.PeakMemory > 0 {
+        fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
+    }
     if m.LoadDuration > 0 {
         fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
     }
@@ -957,6 +963,14 @@ func (m *Metrics) Summary() {
     }
 }

+func formatPeakMemory(b uint64) string {
+    if b >= format.GibiByte {
+        return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
+    }
+    return format.HumanBytes2(b)
+}
+
 func (opts *Options) FromMap(m map[string]any) error {
     valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
     typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct
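
For reference, assuming format.GibiByte is 1<<30, the new helper renders values like this (illustrative only):

fmt.Println(formatPeakMemory(5 << 30))   // "5.000 GiB"
fmt.Println(formatPeakMemory(512 << 20)) // sub-GiB values fall through to format.HumanBytes2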

View File

@@ -74,8 +74,7 @@ type LlamaServer interface {
     Tokenize(ctx context.Context, content string) ([]int, error)
     Detokenize(ctx context.Context, tokens []int) (string, error)
     Close() error
-    VRAMSize() uint64 // Total VRAM across all GPUs
-    TotalSize() uint64
+    MemorySize() (total, vram uint64)
     VRAMByGPU(id ml.DeviceID) uint64
     Pid() int
     GetPort() int
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
     // Windows CUDA should not use mmap for best performance
     // Linux with a model larger than free space, mmap leads to thrashing
     // For CPU loads we want the memory to be allocated, not FS cache
+    totalSize, _ := s.MemorySize()
     if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
-        (runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
+        (runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
         (len(gpus) == 0 && s.options.UseMMap == nil) ||
         (len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
         (s.options.UseMMap != nil && !*s.options.UseMMap) {
@@ -1457,6 +1457,8 @@ type CompletionRequest struct {
     Format  json.RawMessage
     Images  []ImageData
     Options *api.Options
+    Think           *api.ThinkValue
+    ExplicitOptions map[string]struct{}

     Grammar string // set before sending the request to the subprocess
     Shift   bool
@@ -1518,6 +1520,7 @@ type CompletionResponse struct {
     PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
     EvalCount          int           `json:"eval_count"`
     EvalDuration       time.Duration `json:"eval_duration"`
+    PeakMemory         uint64        `json:"peak_memory,omitempty"`

     // Logprobs contains log probability information if requested
     Logprobs []Logprob `json:"logprobs,omitempty"`
@@ -1848,17 +1851,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
     return nil
 }

-func (s *llmServer) VRAMSize() uint64 {
+func (s *llmServer) MemorySize() (total, vram uint64) {
     if s.mem == nil {
-        return 0
+        return 0, 0
     }
-    var mem uint64
     for _, g := range s.mem.GPUs {
-        mem += g.Size()
+        vram += g.Size()
     }
+    total = s.mem.InputWeights + s.mem.CPU.Size() + vram

     // Some elements are always on CPU. However, if we have allocated all layers
     // on the GPU then include the CPU components as well, to represent complete offloading.
     noCPULayers := true
@@ -1869,25 +1872,11 @@ func (s *llmServer) VRAMSize() uint64 {
         }
     }
     if noCPULayers {
-        mem += s.mem.InputWeights
-        mem += s.mem.CPU.Graph
+        vram += s.mem.InputWeights
+        vram += s.mem.CPU.Graph
     }
-    return mem
-}
-
-func (s *llmServer) TotalSize() uint64 {
-    if s.mem == nil {
-        return 0
-    }
-    mem := s.mem.InputWeights
-    mem += s.mem.CPU.Size()
-    for _, g := range s.mem.GPUs {
-        mem += g.Size()
-    }
-    return mem
+    return total, vram
 }

 func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
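
A sketch of a call site after the consolidation (variable names are illustrative): one call now yields both numbers the old VRAMSize()/TotalSize() pair returned.

total, vram := server.MemorySize()
slog.Info("runner memory", "total", format.HumanBytes2(total), "vram", format.HumanBytes2(vram))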

View File

@@ -41,7 +41,7 @@ type GatedDeltaNet struct {
     SSMBeta   *nn.Linear  `gguf:"ssm_beta"`  // -> beta (qwen35)
     SSMAlpha  *nn.Linear  `gguf:"ssm_alpha"` // -> alpha (qwen35)
     SSMConv1D *convKernel `gguf:"ssm_conv1d"`
-    SSMDT     ml.Tensor   `gguf:"ssm_dt"` // alpha bias
+    SSMDT     ml.Tensor   `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
     SSMA      ml.Tensor   `gguf:"ssm_a"` // -A_log.exp()
     SSMNorm   *nn.RMSNorm `gguf:"ssm_norm"`
     SSMOut    *nn.Linear  `gguf:"ssm_out"`
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
     default:
         return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
     }
+    if gdn.SSMDT == nil {
+        return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
+    }
+    if gdn.SSMA == nil {
+        return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
+    }
+    if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
+        return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
+    }
+    if gdn.SSMNorm == nil || gdn.SSMOut == nil {
+        return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
+    }

     // Compute gate: softplus(alpha + dt_bias) * -A
     alphaBiased := alpha.Add(ctx, gdn.SSMDT)

View File

@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
     return m.Output.Forward(ctx, hiddenStates), nil
 }

+func (m *Model) Validate() error {
+    if m.Options == nil {
+        return fmt.Errorf("qwen3next: missing model options")
+    }
+    if len(m.Layers) != len(m.Options.isRecurrent) {
+        return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
+    }
+    for i, layer := range m.Layers {
+        if !m.Options.isRecurrent[i] {
+            continue
+        }
+        gdn, ok := layer.Operator.(*GatedDeltaNet)
+        if !ok || gdn == nil {
+            return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
+        }
+        if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
+            return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
+        }
+        if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
+            return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
+        }
+        if gdn.SSMDT == nil {
+            return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
+        }
+        if gdn.SSMA == nil {
+            return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
+        }
+        if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
+            return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
+        }
+        if gdn.SSMNorm == nil || gdn.SSMOut == nil {
+            return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
+        }
+    }
+    return nil
+}
+
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
     m.positionCache = nil
     if len(m.mropeSections) > 0 {
@@ -450,6 +490,64 @@ var (
     _ model.MultimodalProcessor = (*Model)(nil)
 )

+func defaultVHeadReordered(arch string) bool {
+    return arch == "qwen35" || arch == "qwen35moe"
+}
+
+func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
+    isRecurrent := make([]bool, numLayers)
+    hasZero := false
+    hasFull := false
+    for i := range numLayers {
+        if i >= len(headCountKV) {
+            continue
+        }
+        if headCountKV[i] == 0 {
+            isRecurrent[i] = true
+            hasZero = true
+        } else {
+            hasFull = true
+        }
+    }
+    if hasZero && hasFull {
+        return isRecurrent, nil
+    }
+    if !hasFull {
+        return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
+    }
+
+    // Compatibility path: older imports store a scalar KV head count and omit
+    // per-layer recurrent flags. Derive the hybrid layout from the interval.
+    interval := int(fullAttentionInterval)
+    if interval == 0 {
+        interval = min(4, numLayers)
+    }
+    if interval <= 0 {
+        return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
+    }
+    if interval > numLayers {
+        return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
+    }
+    hasZero = false
+    hasFull = false
+    for i := range numLayers {
+        isRecurrent[i] = (i+1)%interval != 0
+        if isRecurrent[i] {
+            hasZero = true
+        } else {
+            hasFull = true
+        }
+    }
+    if !hasZero || !hasFull {
+        return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
+    }
+    return isRecurrent, nil
+}
+
 func New(c fs.Config) (model.Model, error) {
     numLayers := int(c.Uint("block_count"))
     layers := make([]Layer, numLayers)
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
         HeadCountKV() []uint64
     }

-    var isRecurrent []bool
     var headCountKV []uint64
     if hc, ok := c.(headCounts); ok {
         headCountKV = hc.HeadCountKV()
     }

-    isRecurrent = make([]bool, numLayers)
-    hasZero := false
-    hasFull := false
-    for i := range numLayers {
-        // If KV head count is 0, it's a recurrent layer
-        if i < len(headCountKV) && headCountKV[i] == 0 {
-            isRecurrent[i] = true
-            hasZero = true
-        } else if i < len(headCountKV) && headCountKV[i] > 0 {
-            hasFull = true
-        }
-    }
-    if !hasZero || !hasFull {
-        return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
+    isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
+    if err != nil {
+        return nil, err
     }

     // Determine if MoE
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
         ssmNGroup:      int(c.Uint("ssm.group_count")),
         ssmDtRank:      int(c.Uint("ssm.time_step_rank")),
         convKernelSize: int(c.Uint("ssm.conv_kernel")),
-        vHeadReordered: c.Bool("ssm.v_head_reordered", false),
+        vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
         isRecurrent:    isRecurrent,
         mropeSections: slices.Collect(func(yield func(int) bool) {
             for _, section := range mropeSections {
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
         mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
     }
     if opts.numKVHeads == 0 {
-        return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
+        return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
     }

     // Calculate cache dimensions

View File

@@ -0,0 +1,65 @@
package qwen3next

import (
    "slices"
    "strings"
    "testing"
)

func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
    got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
    if err != nil {
        t.Fatalf("inferRecurrentLayers() error = %v", err)
    }
    want := []bool{true, false, true, false}
    if !slices.Equal(got, want) {
        t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
    }
}

func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
    got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
    if err != nil {
        t.Fatalf("inferRecurrentLayers() error = %v", err)
    }
    want := []bool{true, true, true, false, true, true, true, false}
    if !slices.Equal(got, want) {
        t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
    }
}

func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
    got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
    if err != nil {
        t.Fatalf("inferRecurrentLayers() error = %v", err)
    }
    want := []bool{true, true, false, true, true, false}
    if !slices.Equal(got, want) {
        t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
    }
}

func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
    _, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
    if err == nil {
        t.Fatal("inferRecurrentLayers() expected error, got nil")
    }
    if !strings.Contains(err.Error(), "must include at least one non-zero value") {
        t.Fatalf("unexpected error = %v", err)
    }
}

func TestDefaultVHeadReordered(t *testing.T) {
    if !defaultVHeadReordered("qwen35") {
        t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
    }
    if !defaultVHeadReordered("qwen35moe") {
        t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
    }
    if defaultVHeadReordered("qwen3next") {
        t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
    }
}

View File

@@ -0,0 +1,45 @@
package qwen3next

import (
    "strings"
    "testing"

    "github.com/ollama/ollama/ml/nn"
)

func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
    m := &Model{
        Layers: []Layer{{
            Operator: &GatedDeltaNet{
                SSMQKV:     &nn.Linear{},
                SSMQKVGate: &nn.Linear{},
                SSMBeta:    &nn.Linear{},
                SSMAlpha:   &nn.Linear{},
            },
        }},
        Options: &Options{
            isRecurrent: []bool{true},
        },
    }

    err := m.Validate()
    if err == nil {
        t.Fatal("Validate() expected error, got nil")
    }
    if !strings.Contains(err.Error(), "missing ssm_dt") {
        t.Fatalf("unexpected error = %v", err)
    }
}

func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
    m := &Model{
        Layers: []Layer{{Operator: &FullAttention{}}},
        Options: &Options{
            isRecurrent: []bool{false},
        },
    }

    if err := m.Validate(); err != nil {
        t.Fatalf("Validate() error = %v", err)
    }
}

View File

@@ -35,6 +35,7 @@ type GLM46Parser struct {
     state  glm46ParserState
     buffer strings.Builder
     tools  []api.Tool
+    callIndex int
 }

 func (p *GLM46Parser) HasToolSupport() bool {
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
 // func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
 func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
     p.tools = tools
+    p.callIndex = 0
     return tools
 }
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
             slog.Warn("glm-4.6 tool call parsing failed", "error", err)
             return "", "", nil, err
         }
+        toolCall.Function.Index = p.callIndex
+        p.callIndex++
         toolCalls = append(toolCalls, toolCall)
     case glm46EventThinkingContent:
         thinkingSb.WriteString(event.content)

View File

@@ -11,6 +11,7 @@ type GLM47Parser struct {
 func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
     p.tools = tools
+    p.callIndex = 0
     // When thinking is enabled (nil or true), the prompt ends with <think>,
     // so model output starts directly with thinking content (no opening tag).
     if thinkValue == nil || thinkValue.Bool() {

View File

@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
         t.Fatalf("expected %#v, got %#v", expected, toolCall)
     }
 }
+
+func TestGLM47ParserToolCallIndexing(t *testing.T) {
+    parser := GLM47Parser{}
+    parser.Init(nil, nil, nil)
+
+    input := `plan</think>
+<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
+<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
+<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
+
+    _, _, calls, err := parser.Add(input, true)
+    if err != nil {
+        t.Fatalf("parse failed: %v", err)
+    }
+
+    want := []api.ToolCall{
+        {Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
+        {Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
+        {Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
+    }
+    if len(calls) != len(want) {
+        t.Fatalf("expected %d calls, got %d", len(want), len(calls))
+    }
+    for i := range want {
+        if !toolCallEqual(calls[i], want[i]) {
+            t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
+        }
+    }
+}
+
+func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
+    parser := GLM47Parser{}
+    parser.Init(nil, nil, nil)
+
+    var all []api.ToolCall
+    _, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
+    if err != nil {
+        t.Fatalf("step 1 parse failed: %v", err)
+    }
+    all = append(all, calls...)
+
+    _, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
+    if err != nil {
+        t.Fatalf("step 2 parse failed: %v", err)
+    }
+    all = append(all, calls...)
+
+    want := []api.ToolCall{
+        {Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
+        {Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
+        {Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
+    }
+    if len(all) != len(want) {
+        t.Fatalf("expected %d calls, got %d", len(want), len(all))
+    }
+    for i := range want {
+        if !toolCallEqual(all[i], want[i]) {
+            t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
+        }
+    }
+}
+
+func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
+    parser := GLM47Parser{}
+    parser.Init(nil, nil, nil)
+
+    _, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
+    if err != nil {
+        t.Fatalf("first parse failed: %v", err)
+    }
+
+    parser.Init(nil, nil, nil)
+    _, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
+    if err != nil {
+        t.Fatalf("second parse failed: %v", err)
+    }
+
+    want := api.ToolCall{
+        Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
+    }
+    if len(calls) != 1 {
+        t.Fatalf("expected 1 call, got %d", len(calls))
+    }
+    if !toolCallEqual(calls[0], want) {
+        t.Fatalf("got %#v, want %#v", calls[0], want)
+    }
+}

View File

@@ -38,6 +38,7 @@ type Qwen3Parser struct {
     state  qwen3ParserState
     buffer strings.Builder
     tools  []api.Tool
+    callIndex int
     hasThinkingSupport     bool
     defaultThinking        bool
     maybeThinkingOpenAtBOL bool
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
 func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
     p.tools = tools
     p.buffer.Reset()
+    p.callIndex = 0

     thinkingEnabled := thinkValue != nil && thinkValue.Bool()
     if thinkValue == nil {
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
             slog.Warn("qwen3 tool call parsing failed", "error", err)
             return "", "", nil, err
         }
+        toolCall.Function.Index = p.callIndex
+        p.callIndex++
         calls = append(calls, toolCall)
     case qwen3EventThinkingContent:
         thinkingSb.WriteString(event.content)

View File

@@ -230,3 +230,89 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
         t.Fatalf("expected no tool calls, got %d", len(calls))
     }
 }
+
+func TestQwen3ParserToolCallIndexing(t *testing.T) {
+    parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+    parser.Init(nil, nil, &api.ThinkValue{Value: false})
+
+    input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
+<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
+<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
+
+    _, _, calls, err := parser.Add(input, true)
+    if err != nil {
+        t.Fatalf("parse failed: %v", err)
+    }
+
+    want := []api.ToolCall{
+        {Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
+        {Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
+        {Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
+    }
+    if len(calls) != len(want) {
+        t.Fatalf("expected %d calls, got %d", len(want), len(calls))
+    }
+    for i := range want {
+        if !toolCallEqual(calls[i], want[i]) {
+            t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
+        }
+    }
+}
+
+func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
+    parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+    parser.Init(nil, nil, &api.ThinkValue{Value: false})
+
+    var all []api.ToolCall
+    _, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
+    if err != nil {
+        t.Fatalf("step 1 parse failed: %v", err)
+    }
+    all = append(all, calls...)
+
+    _, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
+    if err != nil {
+        t.Fatalf("step 2 parse failed: %v", err)
+    }
+    all = append(all, calls...)
+
+    want := []api.ToolCall{
+        {Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
+        {Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
+        {Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
+    }
+    if len(all) != len(want) {
+        t.Fatalf("expected %d calls, got %d", len(want), len(all))
+    }
+    for i := range want {
+        if !toolCallEqual(all[i], want[i]) {
+            t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
+        }
+    }
+}
+
+func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
+    parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+    parser.Init(nil, nil, &api.ThinkValue{Value: false})
+
+    _, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
+    if err != nil {
+        t.Fatalf("first parse failed: %v", err)
+    }
+
+    parser.Init(nil, nil, &api.ThinkValue{Value: false})
+    _, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
+    if err != nil {
+        t.Fatalf("second parse failed: %v", err)
+    }
+
+    want := api.ToolCall{
+        Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
+    }
+    if len(calls) != 1 {
+        t.Fatalf("expected 1 call, got %d", len(calls))
+    }
+    if !toolCallEqual(calls[0], want) {
+        t.Fatalf("got %#v, want %#v", calls[0], want)
+    }
+}

View File

@@ -32,6 +32,7 @@ type Qwen3CoderParser struct {
     state qwenParserState
     acc   strings.Builder
     tools []api.Tool
+    callIndex int
 }

 func (p *Qwen3CoderParser) HasToolSupport() bool {
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
 func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
     p.tools = tools
+    p.callIndex = 0
     return tools // Qwen doesn't modify tools
 }
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
             slog.Warn("qwen tool call parsing failed", "error", err)
             return "", "", nil, err
         }
+        toolCall.Function.Index = p.callIndex
+        p.callIndex++
         toolCalls = append(toolCalls, toolCall)
     case qwenEventContent:
         // TODO(drifkin): if the same turn contains multiple interleaved content

View File

@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
     }
 }
+
+func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
+    parser := Qwen3CoderParser{}
+    parser.Init(nil, nil, nil)
+
+    input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
+<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
+<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
+
+    _, _, calls, err := parser.Add(input, true)
+    if err != nil {
+        t.Fatalf("parse failed: %v", err)
+    }
+
+    want := []api.ToolCall{
+        {Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
+        {Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
+        {Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
+    }
+    if len(calls) != len(want) {
+        t.Fatalf("expected %d calls, got %d", len(want), len(calls))
+    }
+    for i := range want {
+        if !toolCallEqual(calls[i], want[i]) {
+            t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
+        }
+    }
+}
+
+func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
+    parser := Qwen3CoderParser{}
+    parser.Init(nil, nil, nil)
+
+    var all []api.ToolCall
+    _, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
+    if err != nil {
+        t.Fatalf("step 1 parse failed: %v", err)
+    }
+    all = append(all, calls...)
+
+    _, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
+    if err != nil {
+        t.Fatalf("step 2 parse failed: %v", err)
+    }
+    all = append(all, calls...)
+
+    want := []api.ToolCall{
+        {Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
+        {Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
+        {Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
+    }
+    if len(all) != len(want) {
+        t.Fatalf("expected %d calls, got %d", len(want), len(all))
+    }
+    for i := range want {
+        if !toolCallEqual(all[i], want[i]) {
+            t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
+        }
+    }
+}
+
+func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
+    parser := Qwen3CoderParser{}
+    parser.Init(nil, nil, nil)
+
+    _, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
+    if err != nil {
+        t.Fatalf("first parse failed: %v", err)
+    }
+
+    parser.Init(nil, nil, nil)
+    _, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
+    if err != nil {
+        t.Fatalf("second parse failed: %v", err)
+    }
+
+    want := api.ToolCall{
+        Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
+    }
+    if len(calls) != 1 {
+        t.Fatalf("expected 1 call, got %d", len(calls))
+    }
+    if !toolCallEqual(calls[0], want) {
+        t.Fatalf("got %#v, want %#v", calls[0], want)
+    }
+}
+
 func TestQwenXMLTransform(t *testing.T) {
     cases := []struct {
         desc string

View File

@@ -71,6 +71,10 @@ type Model struct {
     Template *template.Template
 }

+func (m *Model) IsMLX() bool {
+    return m.Config.ModelFormat == "safetensors"
+}
+
 // Capabilities returns the capabilities that the model supports
 func (m *Model) Capabilities() []model.Capability {
     capabilities := []model.Capability{}

View File

@@ -30,6 +30,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
     lastMsgIdx := len(msgs) - 1
     currMsgIdx := 0

+    if truncate {
     // Start with all messages and remove from the front until it fits in context
     for i := 0; i <= lastMsgIdx; i++ {
         // Collect system messages from the portion we're about to skip
@@ -57,7 +58,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
             }
         }

-        if !truncate || ctxLen <= opts.NumCtx {
+        if ctxLen <= opts.NumCtx {
             currMsgIdx = i
             break
         }
@@ -68,6 +69,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
             break
         }
     }
+    }

     if currMsgIdx > 0 {
         slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))

View File

@@ -130,6 +130,35 @@ func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Opt
     return opts, nil
 }

+func explicitOptions(modelOpts, requestOpts map[string]any) map[string]struct{} {
+    keys := []string{
+        "temperature",
+        "top_p",
+        "min_p",
+        "top_k",
+        "repeat_last_n",
+        "repeat_penalty",
+        "presence_penalty",
+        "frequency_penalty",
+    }
+    explicit := make(map[string]struct{}, len(keys))
+    for _, key := range keys {
+        if optionSpecified(modelOpts, requestOpts, key) {
+            explicit[key] = struct{}{}
+        }
+    }
+    return explicit
+}
+
+func optionSpecified(modelOpts, requestOpts map[string]any, key string) bool {
+    if _, ok := requestOpts[key]; ok {
+        return true
+    }
+    _, ok := modelOpts[key]
+    return ok
+}
+
 // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
 // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
 func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
@@ -484,7 +513,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
     // the real chat handler, but doing this as a stopgap to get renderer
     // support for generate
     if values.Messages != nil && values.Suffix == "" && req.Template == "" {
-        prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
+        genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
+        prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
         if err != nil {
             c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
             return
@@ -542,6 +572,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
         Images:  images,
         Format:  req.Format,
         Options: opts,
+        Think:           req.Think,
+        ExplicitOptions: explicitOptions(m.Options, req.Options),
         Shift:    req.Shift == nil || *req.Shift,
         Truncate: req.Truncate == nil || *req.Truncate,
         Logprobs: req.Logprobs,
@@ -557,6 +589,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
             PromptEvalDuration: cr.PromptEvalDuration,
             EvalCount:          cr.EvalCount,
             EvalDuration:       cr.EvalDuration,
+            PeakMemory:         cr.PeakMemory,
         },
         Logprobs: toAPILogprobs(cr.Logprobs),
     }
@@ -1951,6 +1984,9 @@ func (s *Server) PsHandler(c *gin.Context) {
         }
         if v.llama != nil {
             mr.ContextLength = v.llama.ContextLength()
+            total, vram := v.llama.MemorySize()
+            mr.Size = int64(total)
+            mr.SizeVRAM = int64(vram)
         }

         // The scheduler waits to set expiresAt, so if a model is loading it's
         // possible that it will be set to the unix epoch. For those cases, just
@@ -2213,6 +2249,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
     }

     truncate := req.Truncate == nil || *req.Truncate
+    if m.IsMLX() {
+        truncate = false
+    }

     prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
     if err != nil {
         slog.Error("chat prompt error", "error", err)
@@ -2294,6 +2333,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
         Images:  images,
         Format:  currentFormat,
         Options: opts,
+        Think:           req.Think,
+        ExplicitOptions: explicitOptions(m.Options, req.Options),
         Shift:    req.Shift == nil || *req.Shift,
         Truncate: truncate,
         Logprobs: req.Logprobs,
@@ -2309,6 +2350,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
             PromptEvalDuration: r.PromptEvalDuration,
             EvalCount:          r.EvalCount,
             EvalDuration:       r.EvalDuration,
+            PeakMemory:         r.PeakMemory,
         },
         Logprobs: toAPILogprobs(r.Logprobs),
     }
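
A small worked example of what the new explicitOptions helper computes (values made up): only the tracked sampling keys that appear in either the Modelfile options or the request options are marked explicit.

modelOpts := map[string]any{"temperature": 0.6}         // set in the Modelfile
reqOpts := map[string]any{"top_k": 40, "num_ctx": 8192} // set on the request
explicit := explicitOptions(modelOpts, reqOpts)
// explicit == map[string]struct{}{"temperature": {}, "top_k": {}}
// num_ctx is not in the tracked key list, so it is not marked explicit.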

View File

@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
     }

     // Check for experimental safetensors LLM models
-    if pending.model.Config.ModelFormat == "safetensors" {
+    if pending.model.IsMLX() {
         if slices.Contains(pending.model.Config.Capabilities, "completion") {
             // LLM model with safetensors format - use MLX runner
             if s.loadMLX(pending) {
@@ -536,6 +536,7 @@ iGPUScan:
         }
     }

+    totalSize, vramSize := llama.MemorySize()
     runner := &runnerRef{
         model:     req.model,
         modelPath: req.model.ModelPath,
@@ -545,8 +546,8 @@ iGPUScan:
         sessionDuration: sessionDuration,
         gpus:            gpuIDs,
         discreteGPUs:    discreteGPUs,
-        vramSize:        llama.VRAMSize(),
-        totalSize:       llama.TotalSize(),
+        totalSize:       totalSize,
+        vramSize:        vramSize,
         loading:         true,
         pid:             llama.Pid(),
     }
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
         sessionDuration = req.sessionDuration.Duration
     }

+    totalSize, vramSize := server.MemorySize()
     runner := &runnerRef{
         model:     req.model,
         modelPath: req.model.ModelPath,
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
         loading:         false,
         isImagegen:      isImagegen,
         sessionDuration: sessionDuration,
-        totalSize:       server.TotalSize(),
-        vramSize:        server.VRAMSize(),
+        totalSize:       totalSize,
+        vramSize:        vramSize,
     }

     s.loadedMu.Lock()
@@ -762,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
     defer cancel()

     if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
         !reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
-        !reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
+        (!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
         runner.llama.Ping(ctx) != nil {
         return true
     }

View File

@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
     s.closeCalled = true
     return s.closeResp
 }
-func (s *mockLlm) VRAMSize() uint64  { return s.vramSize }
-func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
+func (s *mockLlm) MemorySize() (uint64, uint64)    { return s.totalSize, s.vramSize }
 func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
 func (s *mockLlm) Pid() int     { return -1 }
 func (s *mockLlm) GetPort() int { return -1 }

View File

@@ -288,6 +288,18 @@ func normalizeQuantType(quantize string) string {
     }
 }

+func isStackedExpertWeight(name string) bool {
+    // Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
+    // or "...proj" (pre-stacked packed tensor).
+    if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") {
+        return false
+    }
+    return strings.Contains(name, ".mlp.switch_mlp.") ||
+        strings.Contains(name, ".mlp.experts.") ||
+        strings.Contains(name, ".mlp.shared_experts.")
+}
+
 // GetTensorQuantization returns the appropriate quantization type for a tensor.
 // Returns "" if the tensor should not be quantized.
 // This implements mixed-precision quantization:
@@ -296,18 +308,25 @@ func normalizeQuantType(quantize string) string {
 // - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
 // - Norms, embeddings, biases, routing gates: no quantization
 func GetTensorQuantization(name string, shape []int32, quantize string) string {
+    stackedExpert := isStackedExpertWeight(name)
+
     // Use basic name-based check first
-    if !ShouldQuantize(name, "") {
+    if !stackedExpert && !ShouldQuantize(name, "") {
         return ""
     }

-    // Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
-    if len(shape) != 2 {
+    // Quantize standard linear weights (2D). Also allow stacked expert weights (3D),
+    // e.g. qwen switch_mlp / experts combined tensors.
+    if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
         return ""
     }

     // Skip small tensors (less than 1024 elements) - not worth quantizing
-    if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
+    var elems int64 = 1
+    for _, d := range shape {
+        elems *= int64(d)
+    }
+    if elems < 1024 {
         return ""
     }

View File

@@ -557,6 +557,10 @@ func TestShouldQuantizeTensor(t *testing.T) {
     // 3D+ tensors should not be quantized
     {"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
     {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
+    {"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
+    {"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
+    {"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
+    {"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},

     // Embeddings should not be quantized regardless of shape
     {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
@@ -619,6 +623,44 @@ func TestExpertGroupPrefix(t *testing.T) {
     }
 }
+
+func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
+    gateUp := GetTensorQuantization(
+        "model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
+        []int32{64, 22016, 4096},
+        "int4",
+    )
+    if gateUp != "int4" {
+        t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
+    }
+
+    down := GetTensorQuantization(
+        "model.layers.1.mlp.experts.down_proj.weight",
+        []int32{64, 4096, 14336},
+        "int4",
+    )
+    if down != "int8" {
+        t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
+    }
+
+    combinedGateUp := GetTensorQuantization(
+        "model.language_model.layers.0.mlp.experts.gate_up_proj",
+        []int32{256, 1024, 2048},
+        "int8",
+    )
+    if combinedGateUp != "int8" {
+        t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
+    }
+
+    combinedDown := GetTensorQuantization(
+        "model.language_model.layers.0.mlp.experts.down_proj",
+        []int32{256, 2048, 512},
+        "int4",
+    )
+    if combinedDown != "int8" {
+        t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
+    }
+}
+
 func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
     dir := t.TempDir()

View File

@@ -374,14 +374,9 @@ func (s *Server) Close() error {
     return nil
 }

-// VRAMSize returns the estimated VRAM usage.
-func (s *Server) VRAMSize() uint64 {
-    return s.vramSize
-}
-
-// TotalSize returns the total memory usage.
-func (s *Server) TotalSize() uint64 {
-    return s.vramSize
+// MemorySize returns the total and VRAM memory usage.
+func (s *Server) MemorySize() (total, vram uint64) {
+    return s.vramSize, s.vramSize
 }

 // VRAMByGPU returns VRAM usage for a specific GPU.

View File

@@ -30,21 +30,64 @@ type cacheSession struct {
     remaining []int32
 }

+func (c *kvCache) free() {
+    for i, kv := range c.caches {
+        if kv == nil {
+            continue
+        }
+        kv.Free()
+        c.caches[i] = nil
+    }
+    c.caches = nil
+    c.tokens = nil
+}
+
+func (c *kvCache) cachesCanTrim() bool {
+    for _, kv := range c.caches {
+        if kv == nil {
+            continue
+        }
+        if !kv.CanTrim() {
+            return false
+        }
+    }
+    return true
+}
+
+func (c *kvCache) trimToPrefix(prefix int) {
+    for _, kv := range c.caches {
+        if kv == nil || !kv.CanTrim() {
+            continue
+        }
+        if trim := kv.Offset() - prefix; trim > 0 {
+            kv.Trim(trim)
+        }
+    }
+    if prefix < len(c.tokens) {
+        c.tokens = c.tokens[:prefix]
+    }
+}
+
 // begin prepares caches for a new request. It finds the nearest
 // matching cache or creates new caches if none match.
 func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
-    if len(c.caches) == 0 {
+    ensureCaches := func() {
+        if len(c.caches) != 0 {
+            return
+        }
         if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
             c.caches = cacheFactory.NewCaches()
-        } else {
+            return
+        }
         c.caches = make([]cache.Cache, m.NumLayers())
         for i := range c.caches {
             c.caches[i] = cache.NewKVCache()
         }
     }
-    }
+    ensureCaches()

     remaining := c.findRemaining(inputs)
+    ensureCaches()

     return &cacheSession{
@@ -56,18 +99,34 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {

 // close saves the token state if the forward pass ran.
 func (s *cacheSession) close() {
-    if offset := s.caches[0].Offset(); offset > 0 {
-        // Ensure that if we have run the forward pass and set the metadata
-        // that we also actually have the data
-        arrays := make([]*mlx.Array, 0, 2*len(s.caches))
-        for _, c := range s.caches {
-            k, v := c.State()
-            arrays = append(arrays, k, v)
-        }
+    if len(s.caches) == 0 {
+        return
+    }
+
+    offset := -1
+    arrays := make([]*mlx.Array, 0, 2*len(s.caches))
+    for _, kv := range s.caches {
+        if kv == nil {
+            continue
+        }
+        if off := kv.Offset(); offset < 0 || off < offset {
+            offset = off
+        }
+        arrays = append(arrays, kv.Materialize()...)
+    }
+    if offset <= 0 {
+        return
+    }

+    // Ensure that if we have run the forward pass and set the metadata
+    // that we also actually have the data.
     mlx.AsyncEval(arrays...)
-    s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
+
+    stored := append(s.inputs, s.outputs...)
+    if offset > len(stored) {
+        offset = len(stored)
     }
+    s.cache.tokens = stored[:offset]
 }

 // findRemaining finds the longest common prefix between tokens and the cached
@@ -78,12 +137,20 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
         prefix++
     }

-    if prefix < len(c.tokens) {
-        trim := len(c.tokens) - prefix
-        for _, kv := range c.caches {
-            kv.Trim(trim)
+    // Always keep at least one token to re-evaluate so the
+    // pipeline can seed token generation from it.
+    if prefix == len(tokens) && prefix > 0 {
+        prefix--
+    }
+
+    if prefix < len(c.tokens) {
+        if c.cachesCanTrim() {
+            c.trimToPrefix(prefix)
+        } else {
+            c.free()
+            slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
+            return tokens
         }
-        c.tokens = c.tokens[:prefix]
     }

     if prefix == 0 {
@@ -98,10 +165,21 @@ func (c *kvCache) log() {
     if len(c.caches) == 0 {
         return
     }
+
+    offset := -1
     var totalBytes int
     for _, kv := range c.caches {
-        k, v := kv.State()
-        totalBytes += k.NumBytes() + v.NumBytes()
+        if kv == nil {
+            continue
+        }
+        if off := kv.Offset(); offset < 0 || off < offset {
+            offset = off
+        }
+        for _, a := range kv.Materialize() {
+            totalBytes += a.NumBytes()
+        }
+    }
+    if offset < 0 {
+        return
     }
-    logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
+    logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
 }
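
A toy simulation of the divergence policy above, assuming nothing about the real cache types: trimmable caches drop only the divergent suffix, while any non-trimmable layer forces a full rebuild.

package main

import "fmt"

func plan(cached, prompt []int32, allCanTrim bool) (action string, reeval []int32) {
    prefix := 0
    for prefix < len(cached) && prefix < len(prompt) && cached[prefix] == prompt[prefix] {
        prefix++
    }
    if prefix < len(cached) && !allCanTrim {
        return "free", prompt // recurrent state cannot rewind: re-evaluate everything
    }
    return "trim", prompt[prefix:] // keep the shared prefix, redo the rest
}

func main() {
    fmt.Println(plan([]int32{1, 2, 3, 4}, []int32{1, 2, 9}, true))  // trim [9]
    fmt.Println(plan([]int32{1, 2, 3, 4}, []int32{1, 2, 9}, false)) // free [1 2 9]
}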

View File

@@ -10,6 +10,8 @@ import (
 type Cache interface {
     Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
     State() (keys, values *mlx.Array)
+    Materialize() []*mlx.Array
+    CanTrim() bool
     Trim(int) int
     Clone() Cache
     Free()
@@ -67,6 +69,20 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
         c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
 }

+// Materialize returns the backing key/value buffers currently held by the cache.
+func (c *KVCache) Materialize() []*mlx.Array {
+    out := make([]*mlx.Array, 0, 2)
+    if c.keys != nil && c.keys.Valid() {
+        out = append(out, c.keys)
+    }
+    if c.values != nil && c.values.Valid() {
+        out = append(out, c.values)
+    }
+    return out
+}
+
+func (c *KVCache) CanTrim() bool { return true }
+
 func (c *KVCache) Trim(n int) int {
     n = min(c.offset, n)
     c.offset -= n
@@ -190,6 +206,8 @@ func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
     return c.keys, c.values
 }

+func (c *RotatingKVCache) CanTrim() bool { return true }
+
 func (c *RotatingKVCache) Trim(n int) int {
     n = min(c.offset, n)
     c.offset -= n

x/mlxrunner/cache/recurrent.go (vendored, new file, 220 lines)
View File

@@ -0,0 +1,220 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/mlxrunner/mlx"
// RecurrentCache stores state for linear-recurrent layers.
//
// Conv state shape: [B, convTail, convDim]
// Delta state shape: [B, numVHeads, headVDim, headKDim]
type RecurrentCache struct {
convState *mlx.Array
deltaState *mlx.Array
offset int
convTail int
convDim int
numVHeads int
headVDim int
headKDim int
}
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
// Break dependency chains so recurrent state does not retain the full
// per-token compute graph over time.
snap := mlx.Snapshot(v)
mlx.Eval(snap)
old := *dst
*dst = snap
mlx.Pin(snap)
// Drop references to the previous cached state root and transient incoming
// graph root now that a detached snapshot is retained in cache. Actual
// cleanup happens at the runner's normal sweep points.
if old != nil && old != snap {
mlx.Unpin(old)
}
if v != snap && v != old {
mlx.Unpin(v)
}
}
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
old := *dst
*dst = v
mlx.Pin(v)
if old != nil && old != v {
mlx.Unpin(old)
}
}
func (c *RecurrentCache) setStateDetached(dst **mlx.Array, v *mlx.Array, ensureContiguous bool) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
root := v
if ensureContiguous {
root = mlx.Contiguous(v, false)
}
detached := mlx.Detach(root)
old := *dst
*dst = detached
mlx.Pin(detached)
if old != nil && old != detached {
mlx.Unpin(old)
}
// Intentionally do not force-release root/v here. In the fast path, the detached
// handle aliases the same MLX value and may still be lazily computed. Releasing the
// source handles can invalidate the cached state before the next eval/sweep point.
}
func snapshotPinned(a *mlx.Array) *mlx.Array {
if a == nil || !a.Valid() {
return nil
}
snap := mlx.Snapshot(a)
mlx.Eval(snap)
mlx.Pin(snap)
return snap
}
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
return &RecurrentCache{
convTail: int(convTail),
convDim: int(convDim),
numVHeads: int(numVHeads),
headVDim: int(headVDim),
headKDim: int(headKDim),
}
}
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
if batch <= 0 {
batch = 1
}
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
if !needConv && !needDelta {
return
}
if needConv {
c.setStateRaw(&c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
}
if needDelta {
c.setStateRaw(&c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
}
}
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.convState
}
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
c.setStateMaterialized(&c.convState, v)
}
// SetConvStateFast stores conv state without forcing an immediate snapshot/eval.
// Use only for decode hot paths that accept higher transient memory until the next
// sync/sweep point. The conv-state input is usually a slice view, so request a
// compact contiguous copy to avoid pinning the whole source buffer.
func (c *RecurrentCache) SetConvStateFast(v *mlx.Array) {
c.setStateDetached(&c.convState, v, true)
}
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.deltaState
}
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
c.setStateMaterialized(&c.deltaState, v)
}
// SetDeltaStateFast stores delta state without forcing an immediate snapshot/eval.
// Use only for decode hot paths that accept higher transient memory until the next
// sync/sweep point.
func (c *RecurrentCache) SetDeltaStateFast(v *mlx.Array) {
c.setStateDetached(&c.deltaState, v, false)
}
func (c *RecurrentCache) Advance(n int) {
c.offset += n
}
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return keys, values
}
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
return c.convState, c.deltaState
}
// Materialize returns the recurrent state roots (conv and delta) held by the cache.
func (c *RecurrentCache) Materialize() []*mlx.Array {
out := make([]*mlx.Array, 0, 2)
if c.convState != nil && c.convState.Valid() {
out = append(out, c.convState)
}
if c.deltaState != nil && c.deltaState.Valid() {
out = append(out, c.deltaState)
}
return out
}
func (c *RecurrentCache) CanTrim() bool { return false }
func (c *RecurrentCache) Trim(n int) int {
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
_ = n
return 0
}
func (c *RecurrentCache) Clone() Cache {
clone := &RecurrentCache{
offset: c.offset,
convTail: c.convTail,
convDim: c.convDim,
numVHeads: c.numVHeads,
headVDim: c.headVDim,
headKDim: c.headKDim,
convState: snapshotPinned(c.convState),
deltaState: snapshotPinned(c.deltaState),
}
return clone
}
func (c *RecurrentCache) Free() {
mlx.Unpin(c.convState, c.deltaState)
c.convState, c.deltaState = nil, nil
c.offset = 0
}
func (c *RecurrentCache) Offset() int { return c.offset }
func (c *RecurrentCache) Len() int { return c.offset }
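A hedged sketch of the expected per-layer decode flow (variable names are illustrative; during prefill the non-Fast setters would be used instead):

// Allocate-or-fetch state for batch 1 (shapes come from the constructor).
conv := c.ConvState(1, mlx.DTypeFloat32)
delta := c.DeltaState(1, mlx.DTypeFloat32)

// ... the layer computes newConv/newDelta from conv, delta, and the new token ...

c.SetConvStateFast(newConv)   // decode: defer snapshot cost to the next sweep
c.SetDeltaStateFast(newDelta)
c.Advance(1)

fork := c.Clone()             // forked sessions get pinned, evaluated copies
defer fork.Free()             // Free drops the pins and resets the offset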

View File

@@ -8,7 +8,6 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"math"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
@@ -19,19 +18,21 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
) )
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models. // Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
type Client struct { type Client struct {
port int port int
modelName string modelName string
vramSize uint64 contextLength atomic.Int64
memory atomic.Uint64
done chan error done chan error
client *http.Client client *http.Client
lastErr string lastErr string
@@ -98,18 +99,9 @@ func NewClient(modelName string) (*Client, error) {
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
} }
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
vramSize = 8 * 1024 * 1024 * 1024
}
c := &Client{ c := &Client{
port: port, port: port,
modelName: modelName, modelName: modelName,
vramSize: vramSize,
done: make(chan error, 1), done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute}, client: &http.Client{Timeout: 10 * time.Minute},
cmd: cmd, cmd: cmd,
@@ -190,17 +182,36 @@ func (c *Client) waitUntilRunning() error {
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization. // completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
type completionRequest struct { type completionRequest struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Think *bool `json:"think,omitempty"`
Options *completionOpts `json:"options,omitempty"` Options *completionOpts `json:"options,omitempty"`
} }
type completionOpts struct { type completionOpts struct {
Temperature float32 `json:"temperature,omitempty"` Temperature *float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"` TopP *float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"` MinP *float32 `json:"min_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK *int `json:"top_k,omitempty"`
RepeatLastN *int `json:"repeat_last_n,omitempty"`
RepeatPenalty *float32 `json:"repeat_penalty,omitempty"`
PresencePenalty *float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"` NumPredict int `json:"num_predict,omitempty"`
} }
type CompletionResponse struct {
Content string
Done bool
DoneReason int
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
PeakMemory uint64
Error *api.StatusError
}
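Making every sampling field a pointer with omitempty means only explicitly-set options cross the wire, letting the runner fill in model-specific defaults for the rest; an illustrative request body:

{"prompt":"hello","think":true,"options":{"temperature":0.7,"top_k":20}}

CompletionResponse, by contrast, is untagged, so response stream lines use Go field names (JSON decoding matches them case-insensitively).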
// Close terminates the subprocess. // Close terminates the subprocess.
func (c *Client) Close() error { func (c *Client) Close() error {
c.mu.Lock() c.mu.Lock()
@@ -222,15 +233,26 @@ func (c *Client) Close() error {
// Completion implements llm.LlamaServer. // Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
var think *bool
if req.Think != nil {
enabled := req.Think.Bool()
think = &enabled
}
creq := completionRequest{ creq := completionRequest{
Prompt: req.Prompt, Prompt: req.Prompt,
Think: think,
} }
if req.Options != nil { if req.Options != nil {
creq.Options = &completionOpts{ creq.Options = &completionOpts{
Temperature: req.Options.Temperature, Temperature: float32Ptr(req.Options.Temperature, hasExplicitOption(req.ExplicitOptions, "temperature")),
TopP: req.Options.TopP, TopP: float32Ptr(req.Options.TopP, hasExplicitOption(req.ExplicitOptions, "top_p")),
MinP: req.Options.MinP, MinP: float32Ptr(req.Options.MinP, hasExplicitOption(req.ExplicitOptions, "min_p")),
TopK: req.Options.TopK, TopK: intPtr(req.Options.TopK, hasExplicitOption(req.ExplicitOptions, "top_k")),
RepeatLastN: intPtr(req.Options.RepeatLastN, hasExplicitOption(req.ExplicitOptions, "repeat_last_n")),
RepeatPenalty: float32Ptr(req.Options.RepeatPenalty, hasExplicitOption(req.ExplicitOptions, "repeat_penalty")),
PresencePenalty: float32Ptr(req.Options.PresencePenalty, hasExplicitOption(req.ExplicitOptions, "presence_penalty")),
FrequencyPenalty: float32Ptr(req.Options.FrequencyPenalty, hasExplicitOption(req.ExplicitOptions, "frequency_penalty")),
NumPredict: req.Options.NumPredict, NumPredict: req.Options.NumPredict,
} }
} }
@@ -260,28 +282,25 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() { for scanner.Scan() {
var raw struct { var raw CompletionResponse
Content string `json:"content,omitempty"`
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil { if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes())) slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue continue
} }
if raw.Error != nil {
return *raw.Error
}
cresp := llm.CompletionResponse{ cresp := llm.CompletionResponse{
Content: raw.Content, Content: raw.Content,
Done: raw.Done, Done: raw.Done,
DoneReason: llm.DoneReason(raw.DoneReason), DoneReason: llm.DoneReason(raw.DoneReason),
PromptEvalCount: raw.PromptEvalCount, PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration), PromptEvalDuration: raw.PromptEvalDuration,
EvalCount: raw.EvalCount, EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration), EvalDuration: raw.EvalDuration,
PeakMemory: raw.PeakMemory,
} }
fn(cresp) fn(cresp)
@@ -293,8 +312,27 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
return scanner.Err() return scanner.Err()
} }
func hasExplicitOption(explicit map[string]struct{}, key string) bool {
_, ok := explicit[key]
return ok
}
func float32Ptr(v float32, ok bool) *float32 {
if !ok {
return nil
}
return &v
}
func intPtr(v int, ok bool) *int {
if !ok {
return nil
}
return &v
}
func (c *Client) ContextLength() int { func (c *Client) ContextLength() int {
return math.MaxInt return int(c.contextLength.Load())
} }
// Detokenize implements llm.LlamaServer. // Detokenize implements llm.LlamaServer.
@@ -347,9 +385,16 @@ func (c *Client) Pid() int {
return -1 return -1
} }
type statusResponse struct {
Status int
Progress int
ContextLength int
Memory uint64
}
// Ping implements llm.LlamaServer. // Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error { func (c *Client) Ping(ctx context.Context) error {
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port) reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil { if err != nil {
return err return err
@@ -362,6 +407,15 @@ func (c *Client) Ping(ctx context.Context) error {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode) return fmt.Errorf("health check failed: %d", resp.StatusCode)
} }
var status statusResponse
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
return err
}
c.contextLength.Store(int64(status.ContextLength))
c.memory.Store(status.Memory)
return nil return nil
} }
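Since statusResponse carries no JSON tags, /v1/status replies with Go field names; an illustrative payload (values are examples):

{"Status":0,"Progress":100,"ContextLength":262144,"Memory":8589934592}

Ping now doubles as the memory and context probe: every successful health check refreshes the cached contextLength and memory values.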
@@ -388,19 +442,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
return tokens, nil return tokens, nil
} }
// TotalSize implements llm.LlamaServer. func (c *Client) currentMemory() uint64 {
func (c *Client) TotalSize() uint64 { ctx, cancel := context.WithTimeout(context.Background(), time.Second)
return c.vramSize defer cancel()
if err := c.Ping(ctx); err != nil {
slog.Warn("failed to get current memory", "error", err)
}
return c.memory.Load()
}
// MemorySize implements llm.LlamaServer.
func (c *Client) MemorySize() (total, vram uint64) {
mem := c.currentMemory()
return mem, mem
} }
// VRAMByGPU implements llm.LlamaServer. // VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 { func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
return c.vramSize return c.currentMemory()
}
// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
return c.vramSize
} }
// WaitUntilRunning implements llm.LlamaServer. // WaitUntilRunning implements llm.LlamaServer.

x/mlxrunner/client_test.go Normal file (167 lines)
View File

@@ -0,0 +1,167 @@
package mlxrunner
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
func TestCompletionForwardsThink(t *testing.T) {
boolPtr := func(v bool) *bool { return &v }
testCases := []struct {
name string
think *api.ThinkValue
want *bool
}{
{name: "unset", think: nil, want: nil},
{name: "enabled", think: &api.ThinkValue{Value: true}, want: boolPtr(true)},
{name: "disabled", think: &api.ThinkValue{Value: false}, want: boolPtr(false)},
{name: "level maps to enabled", think: &api.ThinkValue{Value: "high"}, want: boolPtr(true)},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var got completionRequest
rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path != "/completion" {
t.Fatalf("request path = %q, want %q", r.URL.Path, "/completion")
}
if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("{\"done\":true}\n")),
Request: r,
}, nil
})
c := &Client{
port: 11434,
client: &http.Client{
Transport: rt,
},
}
err := c.Completion(context.Background(), llm.CompletionRequest{
Prompt: "hello",
Think: tc.think,
}, func(llm.CompletionResponse) {})
if err != nil {
t.Fatalf("completion request failed: %v", err)
}
if got.Prompt != "hello" {
t.Fatalf("prompt = %q, want %q", got.Prompt, "hello")
}
switch {
case tc.want == nil && got.Think != nil:
t.Fatalf("think = %v, want nil", *got.Think)
case tc.want != nil && got.Think == nil:
t.Fatalf("think = nil, want %v", *tc.want)
case tc.want != nil && got.Think != nil && *tc.want != *got.Think:
t.Fatalf("think = %v, want %v", *got.Think, *tc.want)
}
})
}
}
func TestCompletionForwardsOnlySpecifiedSamplingOptions(t *testing.T) {
var got completionRequest
rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("{\"done\":true}\n")),
Request: r,
}, nil
})
c := &Client{
port: 11434,
client: &http.Client{
Transport: rt,
},
}
opts := &api.Options{
Temperature: 1.0,
TopP: 0.95,
MinP: 0.1,
TopK: 20,
RepeatLastN: 128,
RepeatPenalty: 1.2,
PresencePenalty: 1.5,
FrequencyPenalty: 0.25,
NumPredict: 64,
}
err := c.Completion(context.Background(), llm.CompletionRequest{
Prompt: "hello",
Options: opts,
ExplicitOptions: map[string]struct{}{
"temperature": {},
"top_k": {},
"repeat_penalty": {},
"presence_penalty": {},
},
}, func(llm.CompletionResponse) {})
if err != nil {
t.Fatalf("completion request failed: %v", err)
}
if got.Options == nil {
t.Fatal("options = nil, want serialized options")
}
if got.Options.Temperature == nil || *got.Options.Temperature != opts.Temperature {
t.Fatalf("temperature = %v, want %v", got.Options.Temperature, opts.Temperature)
}
if got.Options.TopK == nil || *got.Options.TopK != opts.TopK {
t.Fatalf("top_k = %v, want %v", got.Options.TopK, opts.TopK)
}
if got.Options.RepeatPenalty == nil || *got.Options.RepeatPenalty != opts.RepeatPenalty {
t.Fatalf("repeat_penalty = %v, want %v", got.Options.RepeatPenalty, opts.RepeatPenalty)
}
if got.Options.PresencePenalty == nil || *got.Options.PresencePenalty != opts.PresencePenalty {
t.Fatalf("presence_penalty = %v, want %v", got.Options.PresencePenalty, opts.PresencePenalty)
}
if got.Options.TopP != nil {
t.Fatalf("top_p = %v, want nil", *got.Options.TopP)
}
if got.Options.MinP != nil {
t.Fatalf("min_p = %v, want nil", *got.Options.MinP)
}
if got.Options.RepeatLastN != nil {
t.Fatalf("repeat_last_n = %v, want nil", *got.Options.RepeatLastN)
}
if got.Options.FrequencyPenalty != nil {
t.Fatalf("frequency_penalty = %v, want nil", *got.Options.FrequencyPenalty)
}
if got.Options.NumPredict != opts.NumPredict {
t.Fatalf("num_predict = %d, want %d", got.Options.NumPredict, opts.NumPredict)
}
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}

View File

@@ -7,4 +7,6 @@ import (
_ "github.com/ollama/ollama/x/models/glm4_moe_lite" _ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama" _ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3" _ "github.com/ollama/ollama/x/models/qwen3"
_ "github.com/ollama/ollama/x/models/qwen3_5"
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
) )

View File

@@ -0,0 +1,275 @@
//go:build mlx
package mlx
// #include <stdlib.h>
// #include "generated.h"
import "C"
import (
"sync"
"sync/atomic"
"unsafe"
)
var (
gatedDeltaMetalKernelOnce sync.Once
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
gatedDeltaMetalDisabled atomic.Bool
)
const gatedDeltaMetalKernelSource = `
auto n = thread_position_in_grid.z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
y += b_idx * T * Hv * Dv + hv_idx * Dv;
auto dk_idx = thread_position_in_threadgroup.x;
auto dv_idx = thread_position_in_grid.y;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T * Hv;
auto beta_ = beta + b_idx * T * Hv;
for (int t = 0; t < T; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * g_[hv_idx];
kv_mem += state[i] * k_[s_idx];
}
kv_mem = simd_sum(kv_mem);
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + k_[s_idx] * delta;
out += state[i] * q_[s_idx];
}
out = simd_sum(out);
if (thread_index_in_simdgroup == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<InT>(state[i]);
}
`
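Read back into math, each (batch, value-head) pair carries a state matrix S ∈ R^{Dv×Dk}, and per timestep the kernel computes (my transcription of the source above):

S ← g_t · S                     decay by the gate
Δ ← β_t · (v_t − S k_t)         delta-rule prediction error
S ← S + Δ k_tᵀ                  rank-1 state update
y_t = S q_t                     readout

Each simdgroup lane owns Dk/32 columns of one row of S (hence the Dk % 32 == 0 requirement checked below), with simd_sum reducing the dot products across lanes.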
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
vec := C.mlx_vector_string_new()
ok := true
for _, s := range values {
cs := C.CString(s)
if C.mlx_vector_string_append_value(vec, cs) != 0 {
ok = false
}
C.free(unsafe.Pointer(cs))
if !ok {
break
}
}
cleanup := func() {
C.mlx_vector_string_free(vec)
}
return vec, cleanup, ok
}
func initGatedDeltaMetalKernel() {
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaMetalDisabled.Store(true)
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaMetalDisabled.Store(true)
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaMetalKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.bool(false),
)
}
// GatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
func GatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaMetalDisabled.Load() {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
if !q.Valid() || !k.Valid() || !v.Valid() || !g.Valid() || !beta.Valid() || !state.Valid() {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
dtype := q.DType()
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
return nil, nil, false
}
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
if gatedDeltaMetalDisabled.Load() {
return nil, nil, false
}
cfg := C.mlx_fast_metal_kernel_config_new()
defer C.mlx_fast_metal_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_METAL_Y")
nextState = New("GATED_DELTA_METAL_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}
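Callers are expected to treat ok=false as "use the composed-ops path"; a minimal sketch (the fallback function name is hypothetical):

y, next, ok := mlx.GatedDeltaKernel(q, k, v, g, beta, state)
if !ok {
    // Unsupported shape/dtype or kernel setup failed; fall back to the
    // slower path built from ordinary MLX ops (gatedDeltaUnfused is hypothetical).
    y, next = gatedDeltaUnfused(q, k, v, g, beta, state)
}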

View File

@@ -64,6 +64,10 @@ func PeakMemory() int {
return int(peak) return int(peak)
} }
func ResetPeakMemory() {
C.mlx_reset_peak_memory()
}
type Memory struct{} type Memory struct{}
func (Memory) LogValue() slog.Value { func (Memory) LogValue() slog.Value {

View File

@@ -19,7 +19,7 @@ func doEval(outputs []*Array, async bool) {
defer C.mlx_vector_array_free(vector) defer C.mlx_vector_array_free(vector)
for _, output := range outputs { for _, output := range outputs {
if output.Valid() { if output != nil && output.Valid() {
C.mlx_vector_array_append_value(vector, output.ctx) C.mlx_vector_array_append_value(vector, output.ctx)
} }
} }

View File

@@ -93,6 +93,12 @@ func (t *Array) Divide(other *Array) *Array {
return out return out
} }
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
out := New("CUMSUM")
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
return out
}
func (t *Array) ExpandDims(axis int) *Array { func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS") out := New("EXPAND_DIMS")
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx) C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
@@ -123,12 +129,30 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
return out return out
} }
func (t *Array) GreaterEqual(other *Array) *Array {
out := New("GREATER_EQUAL")
C.mlx_greater_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Logsumexp(keepDims bool) *Array { func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP") out := New("LOGSUMEXP")
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx) C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
return out return out
} }
func (t *Array) Less(other *Array) *Array {
out := New("LESS")
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) LogicalOr(other *Array) *Array {
out := New("LOGICAL_OR")
C.mlx_logical_or(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array { func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL") out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)

View File

@@ -113,6 +113,35 @@ func Where(condition, a, b *Array) *Array {
return out return out
} }
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
out := New("CONV1D")
C.mlx_conv1d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(stride),
C.int(padding),
C.int(dilation),
C.int(groups),
DefaultStream().ctx,
)
if bias != nil && bias.Valid() {
out = Add(out, bias)
}
return out
}
func Contiguous(a *Array, allowColMajor bool) *Array {
out := New("CONTIGUOUS")
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
return out
}
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
groups := int32(x.Dim(x.NumDims() - 1))
return Conv1d(x, weight, bias, 1, 0, 1, groups)
}
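DepthwiseConv1d reads the group count off the channel (last) axis of the NLC input, so each channel is convolved with its own filter; under MLX's channels-last weight layout that implies a [C, K, 1] weight for C channels and kernel width K (my reading, not stated in the patch):

// x: [B, L, C] NLC input; w: [C, K, 1] per-channel filters (assumed layout)
y := mlx.DepthwiseConv1d(x, w, nil)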
// Convenience wrappers (function-style for the model code) // Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array { func Stack(arrays []*Array, axis int) *Array {
@@ -271,6 +300,24 @@ func Sigmoid(a *Array) *Array {
return a.Sigmoid() return a.Sigmoid()
} }
func Exp(a *Array) *Array {
out := New("EXP")
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Log(a *Array) *Array {
out := New("LOG")
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
out := New("SOFTMAX_AXIS")
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
return out
}
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array { func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
mask := New("") mask := New("")
sinks := New("") sinks := New("")
@@ -288,7 +335,11 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
func RMSNormFn(x, weight *Array, eps float32) *Array { func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM") out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx) var w C.mlx_array
if weight != nil {
w = weight.ctx
}
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
return out return out
} }
@@ -378,6 +429,27 @@ func Collect(v any) []*Array {
return arrays return arrays
} }
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
func Snapshot(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("SNAPSHOT")
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Detach returns a new Array handle that shares the same MLX value but does
// not retain Go-side graph input references.
func Detach(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("DETACH")
C.mlx_array_set(&out.ctx, a.ctx)
return out
}
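The difference between the two is ownership rather than timing; a hedged sketch (produceLazy is a stand-in for any unevaluated result):

a := produceLazy()        // some not-yet-evaluated array (hypothetical)

s := mlx.Snapshot(a)      // independent copy: safe to Pin and keep after
                          // a's graph inputs are swept
d := mlx.Detach(a)        // new handle to the same MLX value: cheap, but
                          // still computed from a's pending graph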
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) { func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() { if !v.IsValid() {
return return

View File

@@ -20,6 +20,7 @@ type Model interface {
Unembed(x *mlx.Array) *mlx.Array Unembed(x *mlx.Array) *mlx.Array
NumLayers() int NumLayers() int
Tokenizer() *tokenizer.Tokenizer Tokenizer() *tokenizer.Tokenizer
MaxContextLength() int
// LoadWeights receives all tensors loaded from the manifest and assigns // LoadWeights receives all tensors loaded from the manifest and assigns
// them to model fields. Model-specific logic (MLA absorption, expert // them to model fields. Model-specific logic (MLA absorption, expert

View File

@@ -6,18 +6,30 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"log/slog" "log/slog"
"net/http"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
) )
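// prefillChunkSize bounds each prompt prefill batch to 2<<10 = 2048 tokens.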
func prefillChunkSize() int {
return 2 << 10
}
func (r *Runner) TextGenerationPipeline(request Request) error { func (r *Runner) TextGenerationPipeline(request Request) error {
if r.Model == nil { if r.Model == nil {
return errors.New("model not loaded") return errors.New("model not loaded")
} }
ctx := request.Ctx
if ctx == nil {
ctx = context.Background()
}
var ( var (
sample, logprobs *mlx.Array sample, logprobs *mlx.Array
nextSample, nextLogprobs *mlx.Array nextSample, nextLogprobs *mlx.Array
@@ -44,43 +56,72 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} else { } else {
mlx.DisableCompile() mlx.DisableCompile()
} }
mlx.ResetPeakMemory()
inputs := r.Tokenizer.Encode(request.Prompt, true) inputs := r.Tokenizer.Encode(request.Prompt, true)
if len(inputs) == 0 {
return errors.New("empty prompt")
}
if len(inputs) >= r.contextLength {
return api.StatusError{
StatusCode: http.StatusBadRequest,
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
}
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
}
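For example, with a 32768-token context and a 30000-token prompt, maxGenerate is 2768: a request for 4096 tokens is capped to 2768, and an unset limit defaults to the full 2768.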
session := r.cache.begin(r.Model, inputs) session := r.cache.begin(r.Model, inputs)
defer session.close() defer session.close()
caches := session.caches caches := session.caches
tokens := session.remaining tokens := session.remaining
history := append([]int32(nil), session.inputs...)
prefillChunk := prefillChunkSize()
materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
if c == nil {
continue
}
state = append(state, c.Materialize()...)
}
if len(state) == 0 {
return
}
mlx.Eval(state...)
}
now := time.Now()
total, processed := len(tokens), 0 total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 { for total-processed > 1 {
if err := request.Ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return err return err
} }
n := min(2<<10, total-processed-1) n := min(prefillChunk, total-processed-1)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep() mlx.Sweep()
mlx.Eval(func() []*mlx.Array { materializeCaches()
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
s[2*i], s[2*i+1] = c.State()
}
return s
}()...)
processed += n processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total) slog.Info("Prompt processing progress", "processed", processed, "total", total)
mlx.ClearCache() mlx.ClearCache()
} }
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) { step := func(token *mlx.Array, history []int32) (*mlx.Array, *mlx.Array) {
fwd := r.Model.Forward(token.ExpandDims(0), caches) fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd) logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true)) logprobs := logits.Subtract(logits.Logsumexp(true))
sample := request.Sample(logprobs) sample := request.Sample(logprobs, history)
mlx.Pin(sample, logprobs) mlx.Pin(sample, logprobs)
mlx.Sweep() mlx.Sweep()
@@ -89,45 +130,42 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return sample, logprobs return sample, logprobs
} }
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed)) sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed), history)
var b bytes.Buffer var b bytes.Buffer
now := time.Now() final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens { for i := range request.Options.MaxTokens {
if err := request.Ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return err return err
} }
nextSample, nextLogprobs = step(sample)
if i == 0 { if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample) mlx.Eval(sample)
final.PromptTokensDuration = time.Since(now) final.PromptEvalDuration = time.Since(now)
now = time.Now() now = time.Now()
} }
output := int32(sample.Int()) output := int32(sample.Int())
session.outputs = append(session.outputs, output) session.outputs = append(session.outputs, output)
history = append(history, output)
if r.Tokenizer.IsEOS(output) { if r.Tokenizer.IsEOS(output) {
final.Token = int(output)
final.DoneReason = 0 final.DoneReason = 0
final.CompletionTokens = i final.EvalCount = i
break break
} }
select { select {
case <-request.Ctx.Done(): case <-request.Ctx.Done():
return request.Ctx.Err() return request.Ctx.Err()
case request.Responses <- Response{ case request.Responses <- CompletionResponse{
Text: r.Decode(output, &b), Content: r.Decode(output, &b),
Token: int(output),
}: }:
} }
nextSample, nextLogprobs = step(sample, history)
mlx.Unpin(sample, logprobs) mlx.Unpin(sample, logprobs)
sample, logprobs = nextSample, nextLogprobs sample, logprobs = nextSample, nextLogprobs
nextSample, nextLogprobs = nil, nil nextSample, nextLogprobs = nil, nil
@@ -137,10 +175,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} }
} }
final.CompletionTokensDuration = time.Since(now) final.EvalDuration = time.Since(now)
final.PeakMemory = uint64(mlx.PeakMemory())
select { select {
case <-request.Ctx.Done(): case <-ctx.Done():
return request.Ctx.Err() return ctx.Err()
case request.Responses <- final: case request.Responses <- final:
return nil return nil
} }

View File

@@ -4,14 +4,15 @@ package mlxrunner
import ( import (
"context" "context"
"errors"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"strings" "strings"
"time"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -21,7 +22,7 @@ import (
type Request struct { type Request struct {
TextCompletionsRequest TextCompletionsRequest
Responses chan Response Responses chan CompletionResponse
Pipeline func(Request) error Pipeline func(Request) error
Ctx context.Context Ctx context.Context
@@ -31,11 +32,16 @@ type Request struct {
type TextCompletionsRequest struct { type TextCompletionsRequest struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Think *bool `json:"think,omitempty"`
Options struct { Options struct {
Temperature float32 `json:"temperature"` Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"` TopP *float32 `json:"top_p"`
MinP float32 `json:"min_p"` MinP *float32 `json:"min_p"`
TopK int `json:"top_k"` TopK *int `json:"top_k"`
RepeatLastN *int `json:"repeat_last_n"`
RepeatPenalty *float32 `json:"repeat_penalty"`
PresencePenalty *float32 `json:"presence_penalty"`
FrequencyPenalty *float32 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead // Deprecated: use MaxTokens instead
@@ -43,25 +49,12 @@ type TextCompletionsRequest struct {
} `json:"options"` } `json:"options"`
} }
type Response struct {
Text string `json:"content,omitempty"`
Token int `json:"token,omitempty"`
Logprobs []float32 `json:"logprobs,omitempty"`
Done bool `json:"done,omitempty"`
DoneReason int `json:"done_reason,omitempty"`
PromptTokens int `json:"prompt_eval_count,omitempty"`
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
CompletionTokens int `json:"eval_count,omitempty"`
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Runner struct { type Runner struct {
Model base.Model Model base.Model
Tokenizer *tokenizer.Tokenizer Tokenizer *tokenizer.Tokenizer
Requests chan Request Requests chan Request
cache kvCache cache kvCache
contextLength int
} }
func (r *Runner) Load(modelName string) error { func (r *Runner) Load(modelName string) error {
@@ -90,6 +83,7 @@ func (r *Runner) Load(modelName string) error {
r.Model = m r.Model = m
r.Tokenizer = m.Tokenizer() r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()
return nil return nil
} }
@@ -158,6 +152,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case request := <-r.Requests: case request := <-r.Requests:
if err := request.Pipeline(request); err != nil { if err := request.Pipeline(request); err != nil {
slog.Info("Request terminated", "error", err) slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
statusErr = api.StatusError{
StatusCode: http.StatusInternalServerError,
ErrorMessage: err.Error(),
}
}
select {
case request.Responses <- CompletionResponse{Error: &statusErr}:
case <-request.Ctx.Done():
}
} }
close(request.Responses) close(request.Responses)

View File

@@ -9,69 +9,204 @@ import (
) )
type Sampler interface { type Sampler interface {
Sample(*mlx.Array) *mlx.Array Sample(*mlx.Array, []int32) *mlx.Array
} }
func New(temp, top_p, min_p float32, top_k int) Sampler { func New(temp, top_p, min_p float32, top_k, repeatLastN int, repeatPenalty, presencePenalty, frequencyPenalty float32) Sampler {
if temp == 0 {
return greedy{}
}
var samplers []Sampler var samplers []Sampler
if top_p > 0 && top_p < 1 { if repeatLastN > 0 && (repeatPenalty != 1 || presencePenalty != 0 || frequencyPenalty != 0) {
samplers = append(samplers, TopP(top_p)) samplers = append(samplers, Penalty{
RepeatLastN: repeatLastN,
RepeatPenalty: repeatPenalty,
PresencePenalty: presencePenalty,
FrequencyPenalty: frequencyPenalty,
})
} }
if min_p != 0 { if temp == 0 {
samplers = append(samplers, MinP(min_p)) samplers = append(samplers, greedy{})
} else {
samplers = append(samplers, Distribution{
Temperature: temp,
TopK: top_k,
TopP: top_p,
MinP: min_p,
})
} }
if top_k > 0 {
samplers = append(samplers, TopK(top_k))
}
samplers = append(samplers, Temperature(temp))
return chain(samplers) return chain(samplers)
} }
type greedy struct{} type greedy struct{}
func (greedy) Sample(logits *mlx.Array) *mlx.Array { func (greedy) Sample(logits *mlx.Array, _ []int32) *mlx.Array {
return logits.Argmax(-1, false) return logits.Argmax(-1, false)
} }
type chain []Sampler type chain []Sampler
func (c chain) Sample(logits *mlx.Array) *mlx.Array { func (c chain) Sample(logits *mlx.Array, history []int32) *mlx.Array {
for _, sampler := range c { for _, sampler := range c {
logits = sampler.Sample(logits) logits = sampler.Sample(logits, history)
} }
return logits return logits
} }
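New composes these in a fixed order: penalties first (rewriting logprobs), then a terminal sampler that returns token ids, so Distribution or greedy must sit last in the chain. An illustrative equivalent for temp > 0 with penalties enabled:

s := chain{
    Penalty{RepeatLastN: 64, RepeatPenalty: 1.1},        // adjusts logprobs
    Distribution{Temperature: 0.8, TopK: 40, TopP: 0.9}, // draws a token id
}
next := s.Sample(logprobs, history)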
type Temperature float32 type Distribution struct {
Temperature float32
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array { TopK int
return mlx.DivScalar(logits, float32(t)).Categorical(-1) TopP float32
MinP float32
} }
type TopP float32 func (d Distribution) Sample(logits *mlx.Array, _ []int32) *mlx.Array {
filtered, indices := d.filter(logits)
sample := filtered.Categorical(-1)
if indices == nil {
return sample
}
func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array { positions := sample.ExpandDims(1)
// TODO: implement return indices.TakeAlongAxis(positions, -1).Squeeze(1)
}
func (d Distribution) filter(logits *mlx.Array) (*mlx.Array, *mlx.Array) {
candidates := logits
var candidateIndices *mlx.Array
if d.TopK > 0 && d.TopK < logits.Dim(logits.NumDims()-1) {
partitions := logits.Negative().ArgpartitionAxis(d.TopK-1, -1)
switch logits.NumDims() {
case 1:
candidateIndices = partitions.Slice(mlx.Slice(0, d.TopK))
default:
candidateIndices = partitions.Slice(mlx.Slice(), mlx.Slice(0, d.TopK))
}
candidates = logits.TakeAlongAxis(candidateIndices, -1)
}
if d.Temperature != 1 {
candidates = mlx.DivScalar(candidates, d.Temperature)
}
if !d.needsProbabilityFilters() {
return candidates, candidateIndices
}
order := candidates.Negative().ArgsortAxis(-1)
sortedLogits := candidates.TakeAlongAxis(order, -1)
sortedProbs := mlx.SoftmaxAxis(candidates, -1, true).TakeAlongAxis(order, -1)
remove := d.topPRemovalMask(sortedProbs)
if d.MinP > 0 {
minPRemove := d.minPRemovalMask(sortedProbs)
if remove == nil {
remove = minPRemove
} else {
remove = remove.LogicalOr(minPRemove)
}
}
if remove == nil {
return candidates, candidateIndices
}
negInf := mlx.FromValue(float32(math.Inf(-1)))
filtered := mlx.Where(remove, negInf, sortedLogits)
return candidates.PutAlongAxis(order, filtered, -1), candidateIndices
}
func (d Distribution) needsProbabilityFilters() bool {
return (d.TopP > 0 && d.TopP < 1) || d.MinP > 0
}
func (d Distribution) topPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
if d.TopP <= 0 || d.TopP >= 1 {
return nil
}
threshold := mlx.NewScalarArray(d.TopP)
prevCum := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
return prevCum.GreaterEqual(threshold)
}
func (d Distribution) minPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
if d.MinP <= 0 {
return nil
}
var maxProb *mlx.Array
switch sortedProbs.NumDims() {
case 1:
maxProb = sortedProbs.Slice(mlx.Slice(0, 1))
default:
maxProb = sortedProbs.Slice(mlx.Slice(), mlx.Slice(0, 1))
}
threshold := mlx.MulScalar(maxProb, d.MinP)
return sortedProbs.Less(threshold)
}
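Concretely, with the inputs from TestDistributionFilterTopP below — logits [10, 9, 1, 0], TopK 2, TopP 0.55 — top-k keeps [10, 9]; softmax gives probs ≈ [0.731, 0.269]; prevCum (the cumulative probability before each token) is [0, 0.731]; the mask prevCum ≥ 0.55 removes the second entry, so only the top logit survives and everything else becomes −Inf. Note that the prevCum formulation always retains the highest-probability token, even when its probability alone exceeds TopP.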
type Penalty struct {
RepeatLastN int
RepeatPenalty float32
PresencePenalty float32
FrequencyPenalty float32
}
func (p Penalty) Sample(logprobs *mlx.Array, history []int32) *mlx.Array {
if len(history) == 0 {
return logprobs return logprobs
} }
type MinP float32 window := p.RepeatLastN
if window <= 0 || window > len(history) {
window = len(history)
}
func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array { counts := make(map[int32]int, window)
// TODO: implement order := make([]int32, 0, window)
for _, token := range history[len(history)-window:] {
if token < 0 {
continue
}
if counts[token] == 0 {
order = append(order, token)
}
counts[token]++
}
if len(order) == 0 {
return logprobs return logprobs
} }
type TopK int indexShape := []int32{int32(len(order))}
valueShape := []int{len(order)}
if logprobs.NumDims() > 1 {
indexShape = []int32{1, int32(len(order))}
valueShape = []int{1, len(order)}
}
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array { indices := mlx.NewArrayInt32(order, indexShape)
mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0)) selected := logprobs.TakeAlongAxis(indices, -1)
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) mlx.Eval(selected)
values := selected.Floats()
for i, token := range order {
v := values[i]
if p.RepeatPenalty != 1 {
if v < 0 {
v *= p.RepeatPenalty
} else {
v /= p.RepeatPenalty
}
}
if p.PresencePenalty != 0 {
v -= p.PresencePenalty
}
if p.FrequencyPenalty != 0 {
v -= p.FrequencyPenalty * float32(counts[token])
}
values[i] = v
}
return logprobs.PutAlongAxis(indices, mlx.FromValues(values, valueShape...), -1)
} }
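Worked through with the values from TestPenaltySample below — logprobs [1, −2, 3, 4], history [2, 1, 2], window 3 — the counts are token 2 → 2, token 1 → 1. Token 1 (logprob −2, negative): −2 × 2.0 = −4.0, minus presence 1.5 → −5.5, minus frequency 0.25 × 1 → −5.75. Token 2 (logprob 3, positive): 3 / 2.0 = 1.5, minus presence 1.5 → 0, minus frequency 0.25 × 2 → −0.5. Tokens 0 and 3 are untouched, giving [1, −5.75, −0.5, 4], exactly what the test asserts.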

View File

@@ -0,0 +1,104 @@
//go:build mlx
package sample
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestPenaltySample(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
logprobs := mlx.FromValues([]float32{
1.0, -2.0, 3.0, 4.0,
}, 1, 4)
got := Penalty{
RepeatLastN: 3,
RepeatPenalty: 2.0,
PresencePenalty: 1.5,
FrequencyPenalty: 0.25,
}.Sample(logprobs, []int32{2, 1, 2})
mlx.Eval(got)
want := []float32{1.0, -5.75, -0.5, 4.0}
values := got.Floats()
if len(values) != len(want) {
t.Fatalf("len(values) = %d, want %d", len(values), len(want))
}
for i := range want {
if math.Abs(float64(values[i]-want[i])) > 1e-5 {
t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
}
}
}
func TestPenaltySampleHonorsRepeatWindow(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
logprobs := mlx.FromValues([]float32{
1.0, 2.0, 3.0,
}, 1, 3)
got := Penalty{
RepeatLastN: 1,
PresencePenalty: 1.0,
}.Sample(logprobs, []int32{0, 1})
mlx.Eval(got)
want := []float32{1.0, 1.0, 3.0}
values := got.Floats()
for i := range want {
if math.Abs(float64(values[i]-want[i])) > 1e-5 {
t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
}
}
}
func TestDistributionFilterTopP(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
logits := mlx.FromValues([]float32{
10.0, 9.0, 1.0, 0.0,
}, 1, 4)
filtered, indices := Distribution{
Temperature: 1.0,
TopK: 2,
TopP: 0.55,
}.filter(logits)
got := materializeFilteredLogits(filtered, indices, 4)
mlx.Eval(got)
values := got.Floats()
if values[0] != 10.0 {
t.Fatalf("values[0] = %v, want 10", values[0])
}
for i := 1; i < len(values); i++ {
if !math.IsInf(float64(values[i]), -1) {
t.Fatalf("values[%d] = %v, want -Inf", i, values[i])
}
}
}
func materializeFilteredLogits(filtered, indices *mlx.Array, width int) *mlx.Array {
if indices == nil {
return filtered
}
base := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, width), float32(math.Inf(-1)))
return base.PutAlongAxis(indices, filtered, -1)
}

View File

@@ -16,12 +16,89 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample" "github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/models/qwen3_5"
) )
type samplingConfig struct {
temperature float32
topP float32
minP float32
topK int
repeatLastN int
repeatPenalty float32
presencePenalty float32
frequencyPenalty float32
}
func defaultSamplingConfig(m base.Model, think *bool) samplingConfig {
if _, ok := m.(*qwen3_5.Model); ok {
cfg := samplingConfig{
temperature: 1.0,
topP: 0.95,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
}
if think != nil && !*think {
cfg.temperature = 0.7
cfg.topP = 0.8
}
return cfg
}
opts := api.DefaultOptions()
return samplingConfig{
temperature: opts.Temperature,
topP: opts.TopP,
minP: opts.MinP,
topK: opts.TopK,
repeatLastN: opts.RepeatLastN,
repeatPenalty: opts.RepeatPenalty,
presencePenalty: opts.PresencePenalty,
frequencyPenalty: opts.FrequencyPenalty,
}
}
func resolveSamplingConfig(m base.Model, req Request) samplingConfig {
cfg := defaultSamplingConfig(m, req.Think)
if req.Options.Temperature != nil {
cfg.temperature = *req.Options.Temperature
}
if req.Options.TopP != nil {
cfg.topP = *req.Options.TopP
}
if req.Options.MinP != nil {
cfg.minP = *req.Options.MinP
}
if req.Options.TopK != nil {
cfg.topK = *req.Options.TopK
}
if req.Options.RepeatLastN != nil {
cfg.repeatLastN = *req.Options.RepeatLastN
}
if req.Options.RepeatPenalty != nil {
cfg.repeatPenalty = *req.Options.RepeatPenalty
}
if req.Options.PresencePenalty != nil {
cfg.presencePenalty = *req.Options.PresencePenalty
}
if req.Options.FrequencyPenalty != nil {
cfg.frequencyPenalty = *req.Options.FrequencyPenalty
}
return cfg
}
func Execute(args []string) error { func Execute(args []string) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel())) slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
@@ -50,9 +127,11 @@ func Execute(args []string) error {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(map[string]any{ if err := json.NewEncoder(w).Encode(statusResponse{
"status": 0, Status: 0,
"progress": 100, Progress: 100,
ContextLength: runner.contextLength,
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
}); err != nil { }); err != nil {
slog.Error("Failed to encode response", "error", err) slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -78,7 +157,7 @@ func Execute(args []string) error {
}) })
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan Response)} request := Request{Responses: make(chan CompletionResponse)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil { if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
slog.Error("Failed to decode request", "error", err) slog.Error("Failed to decode request", "error", err)
@@ -87,16 +166,19 @@ func Execute(args []string) error {
} }
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict) request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
if request.Options.MaxTokens < 1 {
request.Options.MaxTokens = 16 << 10 sampling := resolveSamplingConfig(runner.Model, request)
}
request.Pipeline = runner.TextGenerationPipeline request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New( request.Sampler = sample.New(
request.Options.Temperature, sampling.temperature,
request.Options.TopP, sampling.topP,
request.Options.MinP, sampling.minP,
request.Options.TopK, sampling.topK,
sampling.repeatLastN,
sampling.repeatPenalty,
sampling.presencePenalty,
sampling.frequencyPenalty,
) )
var cancel context.CancelFunc var cancel context.CancelFunc

x/mlxrunner/server_test.go Normal file (172 lines)
View File

@@ -0,0 +1,172 @@
//go:build mlx
package mlxrunner
import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/qwen3_5"
"github.com/ollama/ollama/x/tokenizer"
)
type stubModel struct{}
func (stubModel) Forward(*mlx.Array, []cache.Cache) *mlx.Array { return nil }
func (stubModel) Unembed(*mlx.Array) *mlx.Array { return nil }
func (stubModel) NumLayers() int { return 0 }
func (stubModel) Tokenizer() *tokenizer.Tokenizer { return nil }
func (stubModel) LoadWeights(map[string]*mlx.Array) error { return nil }
func (stubModel) MaxContextLength() int { return 0 } // satisfies base.Model's new MaxContextLength requirement
func TestResolveSamplingConfigDefaults(t *testing.T) {
trueValue := true
falseValue := false
tests := []struct {
name string
model base.Model
req Request
want samplingConfig
}{
{
name: "generic model uses api defaults",
model: stubModel{},
req: Request{},
want: samplingConfig{
temperature: 0.8,
topP: 0.9,
minP: 0.0,
topK: 40,
repeatLastN: 64,
repeatPenalty: 1.1,
presencePenalty: 0.0,
frequencyPenalty: 0.0,
},
},
{
name: "qwen3.5 defaults to thinking profile when think unset",
model: &qwen3_5.Model{},
req: Request{},
want: samplingConfig{
temperature: 1.0,
topP: 0.95,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
},
},
{
name: "qwen3.5 thinking disabled defaults",
model: &qwen3_5.Model{},
req: Request{TextCompletionsRequest: TextCompletionsRequest{Think: &falseValue}},
want: samplingConfig{
temperature: 0.7,
topP: 0.8,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
},
},
{
name: "qwen3.5 thinking enabled defaults",
model: &qwen3_5.Model{},
req: Request{TextCompletionsRequest: TextCompletionsRequest{Think: &trueValue}},
want: samplingConfig{
temperature: 1.0,
topP: 0.95,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveSamplingConfig(tt.model, tt.req); got != tt.want {
t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, tt.want)
}
})
}
}
func TestResolveSamplingConfigOverridesSpecifiedValues(t *testing.T) {
trueValue := true
temperature := float32(0.4)
topP := float32(0.6)
minP := float32(0.05)
topK := 12
repeatLastN := 32
repeatPenalty := float32(1.1)
presencePenalty := float32(0.7)
frequencyPenalty := float32(0.2)
got := resolveSamplingConfig(stubModel{}, Request{
TextCompletionsRequest: TextCompletionsRequest{
Think: &trueValue,
Options: struct {
Temperature *float32 `json:"temperature"`
TopP *float32 `json:"top_p"`
MinP *float32 `json:"min_p"`
TopK *int `json:"top_k"`
RepeatLastN *int `json:"repeat_last_n"`
RepeatPenalty *float32 `json:"repeat_penalty"`
PresencePenalty *float32 `json:"presence_penalty"`
FrequencyPenalty *float32 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
NumPredict int `json:"num_predict"`
}{
Temperature: &temperature,
TopP: &topP,
MinP: &minP,
TopK: &topK,
RepeatLastN: &repeatLastN,
RepeatPenalty: &repeatPenalty,
PresencePenalty: &presencePenalty,
FrequencyPenalty: &frequencyPenalty,
},
},
})
want := samplingConfig{
temperature: temperature,
topP: topP,
minP: minP,
topK: topK,
repeatLastN: repeatLastN,
repeatPenalty: repeatPenalty,
presencePenalty: presencePenalty,
frequencyPenalty: frequencyPenalty,
}
if got != want {
t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, want)
}
}
func TestResolveSamplingConfigMatchesGenericDefaults(t *testing.T) {
want := api.DefaultOptions()
got := defaultSamplingConfig(stubModel{}, nil)
if got.temperature != want.Temperature ||
got.topP != want.TopP ||
got.minP != want.MinP ||
got.topK != want.TopK ||
got.repeatLastN != want.RepeatLastN ||
got.repeatPenalty != want.RepeatPenalty ||
got.presencePenalty != want.PresencePenalty ||
got.frequencyPenalty != want.FrequencyPenalty {
t.Fatalf("defaultSamplingConfig() = %+v, want api defaults %+v", got, want)
}
}

View File

@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
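
This MaxContextLength() int accessor is added to several model files in this change (the same hunk repeats below). As a hedged illustration of how a caller might consume it, with hypothetical names rather than the runner's actual code:

package main

import "fmt"

// contextSized is an assumed interface over the accessor added in
// these hunks.
type contextSized interface {
	MaxContextLength() int
}

// checkPrompt rejects prompts that cannot fit in the model's trained
// context, leaving room for at least one generated token.
func checkPrompt(m contextSized, promptTokens int) error {
	if limit := m.MaxContextLength(); promptTokens >= limit {
		return fmt.Errorf("prompt of %d tokens exceeds context limit %d", promptTokens, limit)
	}
	return nil
}

type fakeModel struct{ limit int }

func (f fakeModel) MaxContextLength() int { return f.limit }

func main() {
	fmt.Println(checkPrompt(fakeModel{limit: 4096}, 5000)) // error
	fmt.Println(checkPrompt(fakeModel{limit: 4096}, 128))  // <nil>
}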

View File

@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length
-func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
+func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }

View File

@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}

View File

@@ -15,6 +15,40 @@ type LinearLayer interface {
OutputDim() int32
}
// Conv1d applies 1D convolution over NLC input.
type Conv1d struct {
Weight *mlx.Array
Bias *mlx.Array
Stride int32
Padding int32
Dilation int32
Groups int32
}
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
if stride <= 0 {
stride = 1
}
if dilation <= 0 {
dilation = 1
}
if groups <= 0 {
groups = 1
}
return &Conv1d{
Weight: weight,
Bias: bias,
Stride: stride,
Padding: padding,
Dilation: dilation,
Groups: groups,
}
}
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
}
// Linear applies an affine transformation: y = x @ W.T + b
type Linear struct {
Weight *mlx.Array
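
One detail worth noting in NewConv1d: stride, dilation, and groups are clamped to a minimum of 1, while padding is passed through unchanged, since zero is a valid padding value. A tiny standalone sketch of just that defaulting rule (normalize is a hypothetical name):

package main

import "fmt"

// normalize mirrors the defaulting in NewConv1d above: non-positive
// stride/dilation/groups fall back to 1; padding is kept as given.
func normalize(stride, padding, dilation, groups int32) (int32, int32, int32, int32) {
	if stride <= 0 {
		stride = 1
	}
	if dilation <= 0 {
		dilation = 1
	}
	if groups <= 0 {
		groups = 1
	}
	return stride, padding, dilation, groups
}

func main() {
	s, p, d, g := normalize(0, 0, 0, 0)
	fmt.Println(s, p, d, g) // 1 0 1 1
}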

View File

@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}

x/models/qwen3_5/qwen3_5.go (new file, 1457 lines)

File diff suppressed because it is too large

View File

@@ -0,0 +1,166 @@
//go:build mlx

package qwen3_5
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestParseConfigNestedDefaults(t *testing.T) {
data := []byte(`{
"model_type": "Qwen3_5MoeForConditionalGeneration",
"text_config": {
"hidden_size": 4096,
"intermediate_size": 14336,
"num_hidden_layers": 8,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 128,
"linear_num_value_heads": 64,
"linear_num_key_heads": 16,
"linear_key_head_dim": 128,
"linear_value_head_dim": 128,
"linear_conv_kernel_dim": 4,
"num_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 2048,
"shared_expert_intermediate_size": 4096,
"rope_parameters": {
"rope_theta": 500000,
"partial_rotary_factor": 0.5
}
}
}`)
cfg, err := parseConfig(data)
if err != nil {
t.Fatalf("parseConfig failed: %v", err)
}
if cfg.RopeTheta != 500000 {
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
}
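// RopeDim presumably derives from head_dim (128) * partial_rotary_factor (0.5) = 64.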
if cfg.RopeDim != 64 {
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
}
if cfg.FullAttentionInterval != 4 {
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
}
if !cfg.NormTopKProb {
t.Fatalf("norm_topk_prob should default to true for MoE")
}
}
func TestLayerSelectionHelpers(t *testing.T) {
cfg := &Config{
NumHiddenLayers: 6,
FullAttentionInterval: 3,
NumExperts: 8,
DecoderSparseStep: 2,
MLPOnlyLayers: []int32{1},
}
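// With FullAttentionInterval=3, layer index 2 falls on the interval
// boundary and uses full attention; other layers are linear. MoE applies
// per DecoderSparseStep unless MLPOnlyLayers pins a layer to a dense MLP.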
if !layerIsLinear(cfg, 0) {
t.Fatalf("layer 0 should be linear")
}
if layerIsLinear(cfg, 2) {
t.Fatalf("layer 2 should be full attention")
}
if layerUsesMoE(cfg, 1) {
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
}
if !layerUsesMoE(cfg, 3) {
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
}
}
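// TestResolveTensorPathLayout covers checkpoints that nest the language
// model under a container prefix (as in multimodal exports).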
func TestResolveTensorPathLayout(t *testing.T) {
dummy := mlx.New("dummy")
tests := []struct {
name string
key string
wantContainer string
wantModel string
}{
{
name: "standard",
key: "model.embed_tokens.weight",
wantContainer: "",
wantModel: "model.",
},
{
name: "nested language model with inner model",
key: "model.language_model.model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "model.",
},
{
name: "nested language model without inner model",
key: "model.language_model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
layout := resolveTensorPathLayout(map[string]*mlx.Array{
tt.key: dummy,
})
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
t.Fatalf(
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
layout.containerPrefix,
layout.modelPrefix,
tt.wantContainer,
tt.wantModel,
)
}
})
}
}
func TestModelRuntimeDefaults(t *testing.T) {
m := &Model{}
if m.DisablePromptCache() {
t.Fatal("DisablePromptCache() = true, want false")
}
}
func TestNewCachesLayout(t *testing.T) {
m := &Model{
Config: &Config{
LinearConvKernelDim: 4,
LinearNumKeyHeads: 2,
LinearKeyHeadDim: 8,
LinearNumValueHeads: 4,
LinearValueHeadDim: 16,
},
Layers: []*Layer{
{IsLinear: true},
{IsLinear: false},
{IsLinear: true},
},
}
caches := m.NewCaches()
if len(caches) != len(m.Layers) {
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
}
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
}
if _, ok := caches[1].(*cache.KVCache); !ok {
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
}
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
}
}
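
The cache-layout test encodes the hybrid design: linear-attention layers get a RecurrentCache (fixed-size state), while full-attention layers get a growing KVCache. A standalone sketch of that mapping, assuming only the per-layer IsLinear flag seen in the test:

package main

import "fmt"

func main() {
	isLinear := []bool{true, false, true} // mirrors the test's layer layout
	for i, linear := range isLinear {
		kind := "KVCache" // full attention: grows with sequence length
		if linear {
			kind = "RecurrentCache" // linear attention: fixed-size recurrent state
		}
		fmt.Printf("layer %d -> %s\n", i, kind)
	}
}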

View File

@@ -0,0 +1,16 @@
//go:build mlx

// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
package qwen3_5_moe
import (
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/qwen3_5"
)
func init() {
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
}
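
The init above maps four architecture strings, covering both Qwen3.5-MoE and Qwen3-Next MoE checkpoints, onto a single constructor. A minimal sketch of the registry pattern this relies on (only base.Register appears in the diff; the map and Lookup below are assumptions):

package main

import "fmt"

// Model stands in for base.Model; the real constructor is qwen3_5.NewModel,
// whose signature is not shown in this diff.
type Model interface{}

var registry = map[string]func() Model{}

// Register associates an architecture name with a constructor, so
// multiple aliases can share one implementation.
func Register(name string, ctor func() Model) { registry[name] = ctor }

// Lookup is a hypothetical counterpart used at load time.
func Lookup(name string) (func() Model, bool) {
	ctor, ok := registry[name]
	return ctor, ok
}

func main() {
	newModel := func() Model { return struct{}{} }
	Register("Qwen3_5MoeForCausalLM", newModel)
	Register("Qwen3NextMoeForCausalLM", newModel) // alias, same constructor
	_, ok := Lookup("Qwen3_5MoeForCausalLM")
	fmt.Println(len(registry), ok) // 2 true
}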