Compare commits

...

10 Commits

Author SHA1 Message Date
Jeffrey Morgan
86513cb697 runner: add token history sampling parameters to ollama runner (#14537) 2026-03-01 19:16:07 -08:00
Jeffrey Morgan
3490e9590b model/qwen3next: avoid crash in in DeltaNet when offloading (#14541)
Co-authored-by: Yossi Ovadia <jabadia@gmail.com>
2026-03-01 18:44:04 -08:00
Jeffrey Morgan
8da09b1e7e qwen3next: add compatibility with imported GGUF models (#14517) 2026-02-28 14:21:42 -08:00
Jesse Gross
a60b9adcce mlxrunner: Fix prompt eval timing and count metrics
Only the last token's processing time is included in prompt processing,
giving an artificially high rate. In addition, the number of tokens
only included the tokens that miss the cache, instead of our historic
total tokens.
2026-02-27 17:29:47 -08:00
Jesse Gross
a16f96658b mlxrunner: Enforce model context limit
Currently, context length is unbounded - the cache will keep
growing forever independent of the model's trained context
length. This caps it and enforces semantics similar to most
cloud services:
 - Long prompts will result in an error, not truncation.
 - Generation that exceeds the context will be stopped
2026-02-27 17:29:47 -08:00
Jesse Gross
18ab09b431 mlxrunner: Propagate pipeline errors to client via api.StatusError
Errors that occur during pipeline processing are currently only
logged but not sent back to the client. Rather than using HTTP
status codes as we have historically done, this serializes errors
as messages to allow sending them at any time during the stream.
2026-02-27 17:29:47 -08:00
Jesse Gross
638faeac54 mlxrunner: Report actual memory usage from runner
The MLX runner previously reported a static VRAM estimate that was
computed at load time and consisted only of the weights. This is
strictly less than the actual memory usage, as it does not include
the KV cache or compute graph.
2026-02-27 17:29:47 -08:00
Jesse Gross
dd5eb6337d mlxrunner: Fix panic on full KV cache hit
When the entire prompt was already cached (e.g. repeated prompt),
findRemaining returned an empty slice, causing FromValues to panic
on an index-out-of-range accessing a zero-length byte slice.

Fix by always keeping at least one token to re-evaluate so the
pipeline can seed token generation. Also reject empty prompts
early rather than panicking.
2026-02-27 11:07:03 -08:00
Patrick Devine
79917cf80b show peak memory usage (#14485) 2026-02-26 18:38:27 -08:00
Parth Sareen
cc90a035a0 model/parsers: add stable tool call indexing for glm47 and qwen3 parsers (#14484) 2026-02-26 18:14:29 -08:00
37 changed files with 950 additions and 203 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/internal/orderedmap" "github.com/ollama/ollama/internal/orderedmap"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@@ -569,6 +570,7 @@ type DebugInfo struct {
type Metrics struct { type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"` TotalDuration time.Duration `json:"total_duration,omitempty"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"` PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
@@ -934,6 +936,10 @@ func (m *Metrics) Summary() {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
} }
if m.PeakMemory > 0 {
fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
}
if m.LoadDuration > 0 { if m.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration) fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
} }
@@ -957,6 +963,14 @@ func (m *Metrics) Summary() {
} }
} }
func formatPeakMemory(b uint64) string {
if b >= format.GibiByte {
return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
}
return format.HumanBytes2(b)
}
func (opts *Options) FromMap(m map[string]any) error { func (opts *Options) FromMap(m map[string]any) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
@@ -1063,7 +1077,7 @@ func DefaultOptions() Options {
TopP: 0.9, TopP: 0.9,
TypicalP: 1.0, TypicalP: 1.0,
RepeatLastN: 64, RepeatLastN: 64,
RepeatPenalty: 1.1, RepeatPenalty: 1.0,
PresencePenalty: 0.0, PresencePenalty: 0.0,
FrequencyPenalty: 0.0, FrequencyPenalty: 0.0,
Seed: -1, Seed: -1,

View File

@@ -152,7 +152,9 @@ PARAMETER <parameter> <parametervalue>
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 | | num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 | | repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 | | repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.0) | float | repeat_penalty 1.0 |
| presence_penalty | Penalizes tokens that have already appeared in the generated text to reduce repetition. (Default: 0.0) | float | presence_penalty 1.5 |
| frequency_penalty | Penalizes tokens based on how often they have appeared in the generated text. (Default: 0.0) | float | frequency_penalty 1.0 |
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 | | temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 | | seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" | | stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |

View File

@@ -74,8 +74,7 @@ type LlamaServer interface {
Tokenize(ctx context.Context, content string) ([]int, error) Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
VRAMSize() uint64 // Total VRAM across all GPUs MemorySize() (total, vram uint64)
TotalSize() uint64
VRAMByGPU(id ml.DeviceID) uint64 VRAMByGPU(id ml.DeviceID) uint64
Pid() int Pid() int
GetPort() int GetPort() int
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
// Windows CUDA should not use mmap for best performance // Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing // Linux with a model larger than free space, mmap leads to thrashing
// For CPU loads we want the memory to be allocated, not FS cache // For CPU loads we want the memory to be allocated, not FS cache
totalSize, _ := s.MemorySize()
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) || if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) || (runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
(len(gpus) == 0 && s.options.UseMMap == nil) || (len(gpus) == 0 && s.options.UseMMap == nil) ||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) || (len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) { (s.options.UseMMap != nil && !*s.options.UseMMap) {
@@ -1518,6 +1518,7 @@ type CompletionResponse struct {
PromptEvalDuration time.Duration `json:"prompt_eval_duration"` PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"` EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"` EvalDuration time.Duration `json:"eval_duration"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
// Logprobs contains log probability information if requested // Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"` Logprobs []Logprob `json:"logprobs,omitempty"`
@@ -1848,17 +1849,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil return nil
} }
func (s *llmServer) VRAMSize() uint64 { func (s *llmServer) MemorySize() (total, vram uint64) {
if s.mem == nil { if s.mem == nil {
return 0 return 0, 0
} }
var mem uint64
for _, g := range s.mem.GPUs { for _, g := range s.mem.GPUs {
mem += g.Size() vram += g.Size()
} }
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
// Some elements are always on CPU. However, if we have allocated all layers // Some elements are always on CPU. However, if we have allocated all layers
// on the GPU then include the CPU components as well, to represent complete offloading. // on the GPU then include the CPU components as well, to represent complete offloading.
noCPULayers := true noCPULayers := true
@@ -1869,25 +1870,11 @@ func (s *llmServer) VRAMSize() uint64 {
} }
} }
if noCPULayers { if noCPULayers {
mem += s.mem.InputWeights vram += s.mem.InputWeights
mem += s.mem.CPU.Graph vram += s.mem.CPU.Graph
} }
return mem return total, vram
}
func (s *llmServer) TotalSize() uint64 {
if s.mem == nil {
return 0
}
mem := s.mem.InputWeights
mem += s.mem.CPU.Size()
for _, g := range s.mem.GPUs {
mem += g.Size()
}
return mem
} }
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 { func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {

View File

@@ -41,7 +41,7 @@ type GatedDeltaNet struct {
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35) SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35) SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
SSMConv1D *convKernel `gguf:"ssm_conv1d"` SSMConv1D *convKernel `gguf:"ssm_conv1d"`
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp() SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
SSMOut *nn.Linear `gguf:"ssm_out"` SSMOut *nn.Linear `gguf:"ssm_out"`
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
default: default:
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections") return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
} }
if gdn.SSMDT == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
}
if gdn.SSMA == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
}
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
}
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
}
// Compute gate: softplus(alpha + dt_bias) * -A // Compute gate: softplus(alpha + dt_bias) * -A
alphaBiased := alpha.Add(ctx, gdn.SSMDT) alphaBiased := alpha.Add(ctx, gdn.SSMDT)
@@ -442,6 +454,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs) vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs) stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
// Collect chunk outputs and concatenate at the end.
// Avoids SET on buffer-less intermediates under partial offload.
chunks := make([]ml.Tensor, nChunks)
for chunk := range nChunks { for chunk := range nChunks {
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1) qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1) vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
@@ -463,14 +479,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
vAttn := vTNewChunk.Mulmat(ctx, attnChunk) vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
coreAttnOutChunk := attnInter.Add(ctx, vAttn) coreAttnOutChunk := attnInter.Add(ctx, vAttn)
v = v.SetInplace( chunks[chunk] = coreAttnOutChunk
ctx,
coreAttnOutChunk,
v.Stride(1),
v.Stride(2),
v.Stride(3),
chunk*v.Stride(2),
)
// Update state for next chunk // Update state for next chunk
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1) gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
@@ -483,6 +492,20 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
stateT = stateT.Add(ctx, kgdMulVNew) stateT = stateT.Add(ctx, kgdMulVNew)
} }
// Use a balanced concat tree so concat work does not balloon on long prompts.
for len(chunks) > 1 {
merged := make([]ml.Tensor, 0, (len(chunks)+1)/2)
for i := 0; i < len(chunks); i += 2 {
if i+1 < len(chunks) {
merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2))
} else {
merged = append(merged, chunks[i])
}
}
chunks = merged
}
v = chunks[0]
// Final reshape // Final reshape
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs) coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)

View File

@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return m.Output.Forward(ctx, hiddenStates), nil return m.Output.Forward(ctx, hiddenStates), nil
} }
func (m *Model) Validate() error {
if m.Options == nil {
return fmt.Errorf("qwen3next: missing model options")
}
if len(m.Layers) != len(m.Options.isRecurrent) {
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
}
for i, layer := range m.Layers {
if !m.Options.isRecurrent[i] {
continue
}
gdn, ok := layer.Operator.(*GatedDeltaNet)
if !ok || gdn == nil {
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
}
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
}
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
}
if gdn.SSMDT == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
}
if gdn.SSMA == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
}
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
}
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
}
}
return nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
m.positionCache = nil m.positionCache = nil
if len(m.mropeSections) > 0 { if len(m.mropeSections) > 0 {
@@ -450,6 +490,64 @@ var (
_ model.MultimodalProcessor = (*Model)(nil) _ model.MultimodalProcessor = (*Model)(nil)
) )
func defaultVHeadReordered(arch string) bool {
return arch == "qwen35" || arch == "qwen35moe"
}
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
isRecurrent := make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
if i >= len(headCountKV) {
continue
}
if headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else {
hasFull = true
}
}
if hasZero && hasFull {
return isRecurrent, nil
}
if !hasFull {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
}
// Compatibility path: older imports store a scalar KV head count and omit
// per-layer recurrent flags. Derive the hybrid layout from the interval.
interval := int(fullAttentionInterval)
if interval == 0 {
interval = min(4, numLayers)
}
if interval <= 0 {
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
}
if interval > numLayers {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
}
hasZero = false
hasFull = false
for i := range numLayers {
isRecurrent[i] = (i+1)%interval != 0
if isRecurrent[i] {
hasZero = true
} else {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
}
return isRecurrent, nil
}
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count")) numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers) layers := make([]Layer, numLayers)
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
HeadCountKV() []uint64 HeadCountKV() []uint64
} }
var isRecurrent []bool
var headCountKV []uint64 var headCountKV []uint64
if hc, ok := c.(headCounts); ok { if hc, ok := c.(headCounts); ok {
headCountKV = hc.HeadCountKV() headCountKV = hc.HeadCountKV()
} }
isRecurrent = make([]bool, numLayers) isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
hasZero := false if err != nil {
hasFull := false return nil, err
for i := range numLayers {
// If KV head count is 0, it's a recurrent layer
if i < len(headCountKV) && headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else if i < len(headCountKV) && headCountKV[i] > 0 {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
} }
// Determine if MoE // Determine if MoE
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
ssmNGroup: int(c.Uint("ssm.group_count")), ssmNGroup: int(c.Uint("ssm.group_count")),
ssmDtRank: int(c.Uint("ssm.time_step_rank")), ssmDtRank: int(c.Uint("ssm.time_step_rank")),
convKernelSize: int(c.Uint("ssm.conv_kernel")), convKernelSize: int(c.Uint("ssm.conv_kernel")),
vHeadReordered: c.Bool("ssm.v_head_reordered", false), vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
isRecurrent: isRecurrent, isRecurrent: isRecurrent,
mropeSections: slices.Collect(func(yield func(int) bool) { mropeSections: slices.Collect(func(yield func(int) bool) {
for _, section := range mropeSections { for _, section := range mropeSections {
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)), mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
} }
if opts.numKVHeads == 0 { if opts.numKVHeads == 0 {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value") return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
} }
// Calculate cache dimensions // Calculate cache dimensions

View File

@@ -0,0 +1,65 @@
package qwen3next
import (
"slices"
"strings"
"testing"
)
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, false, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, true, true, false, true, true, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, true, false, true, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
if err == nil {
t.Fatal("inferRecurrentLayers() expected error, got nil")
}
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
t.Fatalf("unexpected error = %v", err)
}
}
func TestDefaultVHeadReordered(t *testing.T) {
if !defaultVHeadReordered("qwen35") {
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
}
if !defaultVHeadReordered("qwen35moe") {
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
}
if defaultVHeadReordered("qwen3next") {
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
}
}

View File

@@ -0,0 +1,45 @@
package qwen3next
import (
"strings"
"testing"
"github.com/ollama/ollama/ml/nn"
)
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
m := &Model{
Layers: []Layer{{
Operator: &GatedDeltaNet{
SSMQKV: &nn.Linear{},
SSMQKVGate: &nn.Linear{},
SSMBeta: &nn.Linear{},
SSMAlpha: &nn.Linear{},
},
}},
Options: &Options{
isRecurrent: []bool{true},
},
}
err := m.Validate()
if err == nil {
t.Fatal("Validate() expected error, got nil")
}
if !strings.Contains(err.Error(), "missing ssm_dt") {
t.Fatalf("unexpected error = %v", err)
}
}
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
m := &Model{
Layers: []Layer{{Operator: &FullAttention{}}},
Options: &Options{
isRecurrent: []bool{false},
},
}
if err := m.Validate(); err != nil {
t.Fatalf("Validate() error = %v", err)
}
}

View File

@@ -35,6 +35,7 @@ type GLM46Parser struct {
state glm46ParserState state glm46ParserState
buffer strings.Builder buffer strings.Builder
tools []api.Tool tools []api.Tool
callIndex int
} }
func (p *GLM46Parser) HasToolSupport() bool { func (p *GLM46Parser) HasToolSupport() bool {
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { // func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools p.tools = tools
p.callIndex = 0
return tools return tools
} }
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
slog.Warn("glm-4.6 tool call parsing failed", "error", err) slog.Warn("glm-4.6 tool call parsing failed", "error", err)
return "", "", nil, err return "", "", nil, err
} }
toolCall.Function.Index = p.callIndex
p.callIndex++
toolCalls = append(toolCalls, toolCall) toolCalls = append(toolCalls, toolCall)
case glm46EventThinkingContent: case glm46EventThinkingContent:
thinkingSb.WriteString(event.content) thinkingSb.WriteString(event.content)

View File

@@ -11,6 +11,7 @@ type GLM47Parser struct {
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools p.tools = tools
p.callIndex = 0
// When thinking is enabled (nil or true), the prompt ends with <think>, // When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag). // so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() { if thinkValue == nil || thinkValue.Bool() {

View File

@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
t.Fatalf("expected %#v, got %#v", expected, toolCall) t.Fatalf("expected %#v, got %#v", expected, toolCall)
} }
} }
func TestGLM47ParserToolCallIndexing(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
input := `plan</think>
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
var all []api.ToolCall
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, nil)
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}

View File

@@ -38,6 +38,7 @@ type Qwen3Parser struct {
state qwen3ParserState state qwen3ParserState
buffer strings.Builder buffer strings.Builder
tools []api.Tool tools []api.Tool
callIndex int
hasThinkingSupport bool hasThinkingSupport bool
defaultThinking bool defaultThinking bool
maybeThinkingOpenAtBOL bool maybeThinkingOpenAtBOL bool
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools p.tools = tools
p.buffer.Reset() p.buffer.Reset()
p.callIndex = 0
thinkingEnabled := thinkValue != nil && thinkValue.Bool() thinkingEnabled := thinkValue != nil && thinkValue.Bool()
if thinkValue == nil { if thinkValue == nil {
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
slog.Warn("qwen3 tool call parsing failed", "error", err) slog.Warn("qwen3 tool call parsing failed", "error", err)
return "", "", nil, err return "", "", nil, err
} }
toolCall.Function.Index = p.callIndex
p.callIndex++
calls = append(calls, toolCall) calls = append(calls, toolCall)
case qwen3EventThinkingContent: case qwen3EventThinkingContent:
thinkingSb.WriteString(event.content) thinkingSb.WriteString(event.content)

View File

@@ -230,3 +230,89 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
t.Fatalf("expected no tool calls, got %d", len(calls)) t.Fatalf("expected no tool calls, got %d", len(calls))
} }
} }
func TestQwen3ParserToolCallIndexing(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
var all []api.ToolCall
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}

View File

@@ -32,6 +32,7 @@ type Qwen3CoderParser struct {
state qwenParserState state qwenParserState
acc strings.Builder acc strings.Builder
tools []api.Tool tools []api.Tool
callIndex int
} }
func (p *Qwen3CoderParser) HasToolSupport() bool { func (p *Qwen3CoderParser) HasToolSupport() bool {
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools p.tools = tools
p.callIndex = 0
return tools // Qwen doesn't modify tools return tools // Qwen doesn't modify tools
} }
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
slog.Warn("qwen tool call parsing failed", "error", err) slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err return "", "", nil, err
} }
toolCall.Function.Index = p.callIndex
p.callIndex++
toolCalls = append(toolCalls, toolCall) toolCalls = append(toolCalls, toolCall)
case qwenEventContent: case qwenEventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content // TODO(drifkin): if the same turn contains multiple interleaved content

View File

@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
} }
} }
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
var all []api.ToolCall
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, nil)
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}
func TestQwenXMLTransform(t *testing.T) { func TestQwenXMLTransform(t *testing.T) {
cases := []struct { cases := []struct {
desc string desc string

View File

@@ -562,6 +562,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
if errors.As(err, &reprocess) { if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing // Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...) seq.inputs = append(reprocess.Inputs, seq.inputs...)
seq.sampler.Reset()
// Skip this sequence but continue processing the rest // Skip this sequence but continue processing the rest
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
err = nil err = nil
@@ -692,6 +693,12 @@ func (s *Server) computeBatch(activeBatch batchState) {
// (unless we take down the whole runner). // (unless we take down the whole runner).
if len(seq.pendingInputs) > 0 { if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
for _, inp := range seq.pendingInputs {
if len(inp.Multimodal) != 0 {
continue
}
seq.sampler.Accept(inp.Token)
}
seq.pendingInputs = []*input.Input{} seq.pendingInputs = []*input.Input{}
} }
@@ -892,6 +899,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options.TopK, req.Options.TopK,
req.Options.TopP, req.Options.TopP,
req.Options.MinP, req.Options.MinP,
req.Options.RepeatPenalty,
req.Options.PresencePenalty,
req.Options.FrequencyPenalty,
req.Options.Seed, req.Options.Seed,
grammar, grammar,
) )
@@ -938,6 +948,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
seq.sampler.Reset()
for _, inp := range seq.cache.Inputs {
if len(inp.Multimodal) != 0 {
continue
}
seq.sampler.Accept(inp.Token)
}
s.seqs[i] = seq s.seqs[i] = seq
s.cond.Signal() s.cond.Signal()
found = true found = true

View File

@@ -16,24 +16,49 @@ type token struct {
value float32 // The raw logit or probability from the model value float32 // The raw logit or probability from the model
} }
const DefaultPenaltyLookback = 64
type Sampler struct { type Sampler struct {
rng *rand.Rand rng *rand.Rand
topK int topK int
topP float32 topP float32
minP float32 minP float32
temperature float32 temperature float32
repeat float32
presence float32
frequency float32
history []int32
grammar *GrammarSampler grammar *GrammarSampler
} }
func (s *Sampler) Reset() {
s.history = s.history[:0]
}
func (s *Sampler) Accept(token int32) {
s.history = append(s.history, token)
if len(s.history) > DefaultPenaltyLookback {
copy(s.history, s.history[len(s.history)-DefaultPenaltyLookback:])
s.history = s.history[:DefaultPenaltyLookback]
}
}
func (s *Sampler) Sample(logits []float32) (int32, error) { func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 { if len(logits) == 0 {
return -1, errors.New("sample: no logits provided to sample") return -1, errors.New("sample: no logits provided to sample")
} }
counts := tokenCounts(s.history, len(logits))
tokens := make([]token, len(logits)) tokens := make([]token, len(logits))
for i := range logits { for i := range logits {
value := logits[i]
if count := counts[int32(i)]; count > 0 {
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
}
tokens[i].id = int32(i) tokens[i].id = int32(i)
tokens[i].value = logits[i] tokens[i].value = value
} }
t, err := s.sample(tokens) t, err := s.sample(tokens)
@@ -55,8 +80,12 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
// we need to reset them before applying the grammar and // we need to reset them before applying the grammar and
// sampling again // sampling again
for i := range logits { for i := range logits {
value := logits[i]
if count := counts[int32(i)]; count > 0 {
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
}
tokens[i].id = int32(i) tokens[i].id = int32(i)
tokens[i].value = logits[i] tokens[i].value = value
} }
s.grammar.Apply(tokens) s.grammar.Apply(tokens)
t, err = s.sample(tokens) t, err = s.sample(tokens)
@@ -127,7 +156,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
} }
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler { func NewSampler(temperature float32, topK int, topP float32, minP float32, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32, seed int, grammar *GrammarSampler) Sampler {
var rng *rand.Rand var rng *rand.Rand
if seed != -1 { if seed != -1 {
// PCG requires two parameters: sequence and stream // PCG requires two parameters: sequence and stream
@@ -154,12 +183,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
minP = 1.0 minP = 1.0
} }
if repeatPenalty <= 0 {
repeatPenalty = 1.0
}
return Sampler{ return Sampler{
rng: rng, rng: rng,
topK: topK, topK: topK,
topP: topP, topP: topP,
minP: minP, minP: minP,
temperature: temperature, temperature: temperature,
repeat: repeatPenalty,
presence: presencePenalty,
frequency: frequencyPenalty,
grammar: grammar, grammar: grammar,
} }
} }

View File

@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5) logits[i] = float32(rand.Float64()*10 - 5)
} }
sampler := NewSampler(0.8, 0, 0, 0, 42, nil) sampler := NewSampler(0.8, 0, 0, 0, 1, 0, 0, 42, nil)
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
sampler.Sample(logits) sampler.Sample(logits)
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
for _, tc := range configs { for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) { b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil) sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, 1, 0, 0, tc.seed, nil)
sampler.Sample(logits) sampler.Sample(logits)
b.ResetTimer() b.ResetTimer()
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
// Test with combined transforms separately - topK influences performance greatly // Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) { b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil) sampler := NewSampler(0.8, 50, 0.9, 0.05, 1, 0, 0, 42, nil)
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5) logits[i] = float32(rand.Float64()*10 - 5)
} }
sampler := NewSampler(0, -1, 0, 0, -1, nil) sampler := NewSampler(0, -1, 0, 0, 1, 0, 0, -1, nil)
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {

View File

@@ -13,7 +13,7 @@ import (
func TestWeighted(t *testing.T) { func TestWeighted(t *testing.T) {
logits := []float32{-10, 3, -10, -10} logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0, nil) sampler := NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
got, err := sampler.Sample(logits) got, err := sampler.Sample(logits)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@@ -25,7 +25,7 @@ func TestWeighted(t *testing.T) {
} }
logits = []float32{-100, -10, 0, 10} logits = []float32{-100, -10, 0, 10}
sampler = NewSampler(0, 0, 0, 0, 0, nil) sampler = NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits) got, err = sampler.Sample(logits)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@@ -39,7 +39,7 @@ func TestWeighted(t *testing.T) {
// Test very high p // Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1} logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens // Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil) sampler = NewSampler(1.0, 0, 1e-10, 0, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits) got, err = sampler.Sample(logits)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@@ -52,7 +52,7 @@ func TestWeighted(t *testing.T) {
} }
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())} logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil) sampler = NewSampler(1, 0, 0.95, 0.05, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits) got, err = sampler.Sample(logits)
if err == nil { if err == nil {
t.Errorf("expected error, got %d", got) t.Errorf("expected error, got %d", got)
@@ -151,8 +151,8 @@ func TestGrammar(t *testing.T) {
func BenchmarkSample(b *testing.B) { func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{ samplers := map[string]Sampler{
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy "Greedy": NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil), "Weighted": NewSampler(0.5, 10, 0.9, 0.2, 1, 0, 0, -1, nil),
} }
// Generate random logits for benchmarking // Generate random logits for benchmarking

View File

@@ -25,6 +25,48 @@ func (h *tokenHeap) Pop() any {
return x return x
} }
func tokenCounts(history []int32, vocabSize int) map[int32]int {
if len(history) == 0 {
return nil
}
start := 0
if len(history) > DefaultPenaltyLookback {
start = len(history) - DefaultPenaltyLookback
}
counts := make(map[int32]int, len(history)-start)
for _, token := range history[start:] {
if token < 0 || int(token) >= vocabSize {
continue
}
counts[token]++
}
return counts
}
func applyPenalty(logit float32, count int, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32) float32 {
if repeatPenalty != 1.0 {
// Preserve ordering for negative logits when applying repeat penalty.
if logit < 0 {
logit *= repeatPenalty
} else {
logit /= repeatPenalty
}
}
if frequencyPenalty != 0 {
logit -= float32(count) * frequencyPenalty
}
if presencePenalty != 0 {
logit -= presencePenalty
}
return logit
}
// temperature applies scaling to the logits // temperature applies scaling to the logits
func temperature(ts []token, temp float32) { func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability // Ensure temperature clipping near 0 to avoid numerical instability

View File

@@ -295,6 +295,86 @@ func TestMinP(t *testing.T) {
} }
} }
func TestTokenCounts(t *testing.T) {
history := make([]int32, 70)
history[0] = 7
history[69] = 7
counts := tokenCounts(history, 8)
if got := counts[7]; got != 1 {
t.Fatalf("lookback mismatch: got %d want %d", got, 1)
}
}
func TestApplyPenalty(t *testing.T) {
logit := applyPenalty(5.0, 3, 1.0, 1.5, 0.5)
if math.Abs(float64(logit-2.0)) > 1e-6 {
t.Fatalf("unexpected penalty result: got %f want %f", logit, 2.0)
}
logit = applyPenalty(4.0, 1, 2.0, 0, 0)
if math.Abs(float64(logit-2.0)) > 1e-6 {
t.Fatalf("unexpected repeat penalty result for positive logits: got %f want %f", logit, 2.0)
}
logit = applyPenalty(-4.0, 1, 2.0, 0, 0)
if math.Abs(float64(logit-(-8.0))) > 1e-6 {
t.Fatalf("unexpected repeat penalty result for negative logits: got %f want %f", logit, -8.0)
}
}
func TestSamplerPresencePenalty(t *testing.T) {
logits := []float32{0.0, 5.0, 0.0}
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
baseline.Accept(1)
got, err := baseline.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 1 {
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
}
presence := NewSampler(0, 0, 1, 0, 1, 6, 0, -1, nil)
presence.Accept(1)
got, err = presence.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got == 1 {
t.Fatalf("presence penalty did not change repeated token selection")
}
}
func TestSamplerFrequencyPenalty(t *testing.T) {
logits := []float32{0.0, 5.0, 4.0}
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
baseline.Accept(1)
baseline.Accept(1)
baseline.Accept(1)
got, err := baseline.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 1 {
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
}
frequency := NewSampler(0, 0, 1, 0, 1, 0, 1.0, -1, nil)
frequency.Accept(1)
frequency.Accept(1)
frequency.Accept(1)
got, err = frequency.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 2 {
t.Fatalf("frequency penalty did not demote repeated token as expected: got %d want %d", got, 2)
}
}
func BenchmarkTransforms(b *testing.B) { func BenchmarkTransforms(b *testing.B) {
// Generate random logits // Generate random logits
tokens := make([]token, 1<<16) tokens := make([]token, 1<<16)

View File

@@ -71,6 +71,10 @@ type Model struct {
Template *template.Template Template *template.Template
} }
func (m *Model) IsMLX() bool {
return m.Config.ModelFormat == "safetensors"
}
// Capabilities returns the capabilities that the model supports // Capabilities returns the capabilities that the model supports
func (m *Model) Capabilities() []model.Capability { func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{} capabilities := []model.Capability{}

View File

@@ -30,6 +30,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
lastMsgIdx := len(msgs) - 1 lastMsgIdx := len(msgs) - 1
currMsgIdx := 0 currMsgIdx := 0
if truncate {
// Start with all messages and remove from the front until it fits in context // Start with all messages and remove from the front until it fits in context
for i := 0; i <= lastMsgIdx; i++ { for i := 0; i <= lastMsgIdx; i++ {
// Collect system messages from the portion we're about to skip // Collect system messages from the portion we're about to skip
@@ -57,7 +58,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
} }
} }
if !truncate || ctxLen <= opts.NumCtx { if ctxLen <= opts.NumCtx {
currMsgIdx = i currMsgIdx = i
break break
} }
@@ -68,6 +69,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
break break
} }
} }
}
if currMsgIdx > 0 { if currMsgIdx > 0 {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:])) slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))

View File

@@ -484,7 +484,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// the real chat handler, but doing this as a stopgap to get renderer // the real chat handler, but doing this as a stopgap to get renderer
// support for generate // support for generate
if values.Messages != nil && values.Suffix == "" && req.Template == "" { if values.Messages != nil && values.Suffix == "" && req.Template == "" {
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate) genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@@ -557,6 +558,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
PromptEvalDuration: cr.PromptEvalDuration, PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: cr.EvalCount, EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration, EvalDuration: cr.EvalDuration,
PeakMemory: cr.PeakMemory,
}, },
Logprobs: toAPILogprobs(cr.Logprobs), Logprobs: toAPILogprobs(cr.Logprobs),
} }
@@ -1951,6 +1953,9 @@ func (s *Server) PsHandler(c *gin.Context) {
} }
if v.llama != nil { if v.llama != nil {
mr.ContextLength = v.llama.ContextLength() mr.ContextLength = v.llama.ContextLength()
total, vram := v.llama.MemorySize()
mr.Size = int64(total)
mr.SizeVRAM = int64(vram)
} }
// The scheduler waits to set expiresAt, so if a model is loading it's // The scheduler waits to set expiresAt, so if a model is loading it's
// possible that it will be set to the unix epoch. For those cases, just // possible that it will be set to the unix epoch. For those cases, just
@@ -2213,6 +2218,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
truncate := req.Truncate == nil || *req.Truncate truncate := req.Truncate == nil || *req.Truncate
if m.IsMLX() {
truncate = false
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil { if err != nil {
slog.Error("chat prompt error", "error", err) slog.Error("chat prompt error", "error", err)
@@ -2309,6 +2317,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration, EvalDuration: r.EvalDuration,
PeakMemory: r.PeakMemory,
}, },
Logprobs: toAPILogprobs(r.Logprobs), Logprobs: toAPILogprobs(r.Logprobs),
} }

View File

@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
} }
// Check for experimental safetensors LLM models // Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" { if pending.model.IsMLX() {
if slices.Contains(pending.model.Config.Capabilities, "completion") { if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner // LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) { if s.loadMLX(pending) {
@@ -536,6 +536,7 @@ iGPUScan:
} }
} }
totalSize, vramSize := llama.MemorySize()
runner := &runnerRef{ runner := &runnerRef{
model: req.model, model: req.model,
modelPath: req.model.ModelPath, modelPath: req.model.ModelPath,
@@ -545,8 +546,8 @@ iGPUScan:
sessionDuration: sessionDuration, sessionDuration: sessionDuration,
gpus: gpuIDs, gpus: gpuIDs,
discreteGPUs: discreteGPUs, discreteGPUs: discreteGPUs,
vramSize: llama.VRAMSize(), totalSize: totalSize,
totalSize: llama.TotalSize(), vramSize: vramSize,
loading: true, loading: true,
pid: llama.Pid(), pid: llama.Pid(),
} }
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
sessionDuration = req.sessionDuration.Duration sessionDuration = req.sessionDuration.Duration
} }
totalSize, vramSize := server.MemorySize()
runner := &runnerRef{ runner := &runnerRef{
model: req.model, model: req.model,
modelPath: req.model.ModelPath, modelPath: req.model.ModelPath,
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
loading: false, loading: false,
isImagegen: isImagegen, isImagegen: isImagegen,
sessionDuration: sessionDuration, sessionDuration: sessionDuration,
totalSize: server.TotalSize(), totalSize: totalSize,
vramSize: server.VRAMSize(), vramSize: vramSize,
} }
s.loadedMu.Lock() s.loadedMu.Lock()
@@ -762,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
defer cancel() defer cancel()
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed? if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed? !reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed? (!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
runner.llama.Ping(ctx) != nil { runner.llama.Ping(ctx) != nil {
return true return true
} }

View File

@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
s.closeCalled = true s.closeCalled = true
return s.closeResp return s.closeResp
} }
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize } func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] } func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
func (s *mockLlm) Pid() int { return -1 } func (s *mockLlm) Pid() int { return -1 }
func (s *mockLlm) GetPort() int { return -1 } func (s *mockLlm) GetPort() int { return -1 }

View File

@@ -374,14 +374,9 @@ func (s *Server) Close() error {
return nil return nil
} }
// VRAMSize returns the estimated VRAM usage. // MemorySize returns the total and VRAM memory usage.
func (s *Server) VRAMSize() uint64 { func (s *Server) MemorySize() (total, vram uint64) {
return s.vramSize return s.vramSize, s.vramSize
}
// TotalSize returns the total memory usage.
func (s *Server) TotalSize() uint64 {
return s.vramSize
} }
// VRAMByGPU returns VRAM usage for a specific GPU. // VRAMByGPU returns VRAM usage for a specific GPU.

View File

@@ -78,6 +78,12 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
prefix++ prefix++
} }
// Always keep at least one token to re-evaluate so the
// pipeline can seed token generation from it.
if prefix == len(tokens) && prefix > 0 {
prefix--
}
if prefix < len(c.tokens) { if prefix < len(c.tokens) {
trim := len(c.tokens) - prefix trim := len(c.tokens) - prefix
for _, kv := range c.caches { for _, kv := range c.caches {

View File

@@ -8,7 +8,6 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"math"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
@@ -19,19 +18,21 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
) )
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models. // Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
type Client struct { type Client struct {
port int port int
modelName string modelName string
vramSize uint64 contextLength atomic.Int64
memory atomic.Uint64
done chan error done chan error
client *http.Client client *http.Client
lastErr string lastErr string
@@ -98,18 +99,9 @@ func NewClient(modelName string) (*Client, error) {
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
} }
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
vramSize = 8 * 1024 * 1024 * 1024
}
c := &Client{ c := &Client{
port: port, port: port,
modelName: modelName, modelName: modelName,
vramSize: vramSize,
done: make(chan error, 1), done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute}, client: &http.Client{Timeout: 10 * time.Minute},
cmd: cmd, cmd: cmd,
@@ -201,6 +193,20 @@ type completionOpts struct {
NumPredict int `json:"num_predict,omitempty"` NumPredict int `json:"num_predict,omitempty"`
} }
type CompletionResponse struct {
Content string
Done bool
DoneReason int
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
PeakMemory uint64
Error *api.StatusError
}
// Close terminates the subprocess. // Close terminates the subprocess.
func (c *Client) Close() error { func (c *Client) Close() error {
c.mu.Lock() c.mu.Lock()
@@ -260,28 +266,25 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() { for scanner.Scan() {
var raw struct { var raw CompletionResponse
Content string `json:"content,omitempty"`
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil { if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes())) slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue continue
} }
if raw.Error != nil {
return *raw.Error
}
cresp := llm.CompletionResponse{ cresp := llm.CompletionResponse{
Content: raw.Content, Content: raw.Content,
Done: raw.Done, Done: raw.Done,
DoneReason: llm.DoneReason(raw.DoneReason), DoneReason: llm.DoneReason(raw.DoneReason),
PromptEvalCount: raw.PromptEvalCount, PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration), PromptEvalDuration: raw.PromptEvalDuration,
EvalCount: raw.EvalCount, EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration), EvalDuration: raw.EvalDuration,
PeakMemory: raw.PeakMemory,
} }
fn(cresp) fn(cresp)
@@ -294,7 +297,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
} }
func (c *Client) ContextLength() int { func (c *Client) ContextLength() int {
return math.MaxInt return int(c.contextLength.Load())
} }
// Detokenize implements llm.LlamaServer. // Detokenize implements llm.LlamaServer.
@@ -347,9 +350,16 @@ func (c *Client) Pid() int {
return -1 return -1
} }
type statusResponse struct {
Status int
Progress int
ContextLength int
Memory uint64
}
// Ping implements llm.LlamaServer. // Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error { func (c *Client) Ping(ctx context.Context) error {
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port) reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil { if err != nil {
return err return err
@@ -362,6 +372,15 @@ func (c *Client) Ping(ctx context.Context) error {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode) return fmt.Errorf("health check failed: %d", resp.StatusCode)
} }
var status statusResponse
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
return err
}
c.contextLength.Store(int64(status.ContextLength))
c.memory.Store(status.Memory)
return nil return nil
} }
@@ -388,19 +407,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
return tokens, nil return tokens, nil
} }
// TotalSize implements llm.LlamaServer. func (c *Client) currentMemory() uint64 {
func (c *Client) TotalSize() uint64 { ctx, cancel := context.WithTimeout(context.Background(), time.Second)
return c.vramSize defer cancel()
if err := c.Ping(ctx); err != nil {
slog.Warn("failed to get current memory", "error", err)
}
return c.memory.Load()
}
// MemorySize implements llm.LlamaServer.
func (c *Client) MemorySize() (total, vram uint64) {
mem := c.currentMemory()
return mem, mem
} }
// VRAMByGPU implements llm.LlamaServer. // VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 { func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
return c.vramSize return c.currentMemory()
}
// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
return c.vramSize
} }
// WaitUntilRunning implements llm.LlamaServer. // WaitUntilRunning implements llm.LlamaServer.

View File

@@ -64,6 +64,10 @@ func PeakMemory() int {
return int(peak) return int(peak)
} }
func ResetPeakMemory() {
C.mlx_reset_peak_memory()
}
type Memory struct{} type Memory struct{}
func (Memory) LogValue() slog.Value { func (Memory) LogValue() slog.Value {

View File

@@ -20,6 +20,7 @@ type Model interface {
Unembed(x *mlx.Array) *mlx.Array Unembed(x *mlx.Array) *mlx.Array
NumLayers() int NumLayers() int
Tokenizer() *tokenizer.Tokenizer Tokenizer() *tokenizer.Tokenizer
MaxContextLength() int
// LoadWeights receives all tensors loaded from the manifest and assigns // LoadWeights receives all tensors loaded from the manifest and assigns
// them to model fields. Model-specific logic (MLA absorption, expert // them to model fields. Model-specific logic (MLA absorption, expert

View File

@@ -6,9 +6,12 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"log/slog" "log/slog"
"net/http"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
) )
@@ -44,16 +47,35 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} else { } else {
mlx.DisableCompile() mlx.DisableCompile()
} }
mlx.ResetPeakMemory()
inputs := r.Tokenizer.Encode(request.Prompt, true) inputs := r.Tokenizer.Encode(request.Prompt, true)
if len(inputs) == 0 {
return errors.New("empty prompt")
}
if len(inputs) >= r.contextLength {
return api.StatusError{
StatusCode: http.StatusBadRequest,
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
}
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
}
session := r.cache.begin(r.Model, inputs) session := r.cache.begin(r.Model, inputs)
defer session.close() defer session.close()
caches := session.caches caches := session.caches
tokens := session.remaining tokens := session.remaining
now := time.Now()
total, processed := len(tokens), 0 total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 { for total-processed > 1 {
if err := request.Ctx.Err(); err != nil { if err := request.Ctx.Err(); err != nil {
return err return err
@@ -93,8 +115,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
var b bytes.Buffer var b bytes.Buffer
now := time.Now() final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens { for i := range request.Options.MaxTokens {
if err := request.Ctx.Err(); err != nil { if err := request.Ctx.Err(); err != nil {
return err return err
@@ -103,9 +124,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
nextSample, nextLogprobs = step(sample) nextSample, nextLogprobs = step(sample)
if i == 0 { if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample) mlx.Eval(sample)
final.PromptTokensDuration = time.Since(now) final.PromptEvalDuration = time.Since(now)
now = time.Now() now = time.Now()
} }
@@ -113,18 +133,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
session.outputs = append(session.outputs, output) session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) { if r.Tokenizer.IsEOS(output) {
final.Token = int(output)
final.DoneReason = 0 final.DoneReason = 0
final.CompletionTokens = i final.EvalCount = i
break break
} }
select { select {
case <-request.Ctx.Done(): case <-request.Ctx.Done():
return request.Ctx.Err() return request.Ctx.Err()
case request.Responses <- Response{ case request.Responses <- CompletionResponse{
Text: r.Decode(output, &b), Content: r.Decode(output, &b),
Token: int(output),
}: }:
} }
@@ -137,7 +155,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} }
} }
final.CompletionTokensDuration = time.Since(now) final.EvalDuration = time.Since(now)
final.PeakMemory = uint64(mlx.PeakMemory())
select { select {
case <-request.Ctx.Done(): case <-request.Ctx.Done():
return request.Ctx.Err() return request.Ctx.Err()

View File

@@ -4,14 +4,15 @@ package mlxrunner
import ( import (
"context" "context"
"errors"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"strings" "strings"
"time"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -21,7 +22,7 @@ import (
type Request struct { type Request struct {
TextCompletionsRequest TextCompletionsRequest
Responses chan Response Responses chan CompletionResponse
Pipeline func(Request) error Pipeline func(Request) error
Ctx context.Context Ctx context.Context
@@ -43,25 +44,12 @@ type TextCompletionsRequest struct {
} `json:"options"` } `json:"options"`
} }
type Response struct {
Text string `json:"content,omitempty"`
Token int `json:"token,omitempty"`
Logprobs []float32 `json:"logprobs,omitempty"`
Done bool `json:"done,omitempty"`
DoneReason int `json:"done_reason,omitempty"`
PromptTokens int `json:"prompt_eval_count,omitempty"`
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
CompletionTokens int `json:"eval_count,omitempty"`
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Runner struct { type Runner struct {
Model base.Model Model base.Model
Tokenizer *tokenizer.Tokenizer Tokenizer *tokenizer.Tokenizer
Requests chan Request Requests chan Request
cache kvCache cache kvCache
contextLength int
} }
func (r *Runner) Load(modelName string) error { func (r *Runner) Load(modelName string) error {
@@ -90,6 +78,7 @@ func (r *Runner) Load(modelName string) error {
r.Model = m r.Model = m
r.Tokenizer = m.Tokenizer() r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()
return nil return nil
} }
@@ -158,6 +147,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case request := <-r.Requests: case request := <-r.Requests:
if err := request.Pipeline(request); err != nil { if err := request.Pipeline(request); err != nil {
slog.Info("Request terminated", "error", err) slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
statusErr = api.StatusError{
StatusCode: http.StatusInternalServerError,
ErrorMessage: err.Error(),
}
}
select {
case request.Responses <- CompletionResponse{Error: &statusErr}:
case <-request.Ctx.Done():
}
} }
close(request.Responses) close(request.Responses)

View File

@@ -50,9 +50,11 @@ func Execute(args []string) error {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(map[string]any{ if err := json.NewEncoder(w).Encode(statusResponse{
"status": 0, Status: 0,
"progress": 100, Progress: 100,
ContextLength: runner.contextLength,
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
}); err != nil { }); err != nil {
slog.Error("Failed to encode response", "error", err) slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -78,7 +80,7 @@ func Execute(args []string) error {
}) })
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan Response)} request := Request{Responses: make(chan CompletionResponse)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil { if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
slog.Error("Failed to decode request", "error", err) slog.Error("Failed to decode request", "error", err)
@@ -87,9 +89,6 @@ func Execute(args []string) error {
} }
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict) request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
if request.Options.MaxTokens < 1 {
request.Options.MaxTokens = 16 << 10
}
request.Pipeline = runner.TextGenerationPipeline request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New( request.Sampler = sample.New(

View File

@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
return len(m.Layers) return len(m.Layers)
} }
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok return m.tok
} }

View File

@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
func (m *Model) NumLayers() int { return len(m.Layers) } func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length // MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings } func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
// VocabSize returns the vocabulary size // VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize } func (m *Model) VocabSize() int32 { return m.Config.VocabSize }

View File

@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
return len(m.Layers) return len(m.Layers)
} }
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok return m.tok
} }

View File

@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
return len(m.Layers) return len(m.Layers)
} }
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok return m.tok
} }