mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 06:54:09 +02:00
Compare commits
10 Commits
pdevine/sa
...
v0.17.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
86513cb697 | ||
|
|
3490e9590b | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 |
16
api/types.go
16
api/types.go
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/internal/orderedmap"
|
"github.com/ollama/ollama/internal/orderedmap"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
@@ -569,6 +570,7 @@ type DebugInfo struct {
|
|||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||||
|
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||||
@@ -934,6 +936,10 @@ func (m *Metrics) Summary() {
|
|||||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.PeakMemory > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
|
||||||
|
}
|
||||||
|
|
||||||
if m.LoadDuration > 0 {
|
if m.LoadDuration > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
||||||
}
|
}
|
||||||
@@ -957,6 +963,14 @@ func (m *Metrics) Summary() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatPeakMemory(b uint64) string {
|
||||||
|
if b >= format.GibiByte {
|
||||||
|
return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
|
||||||
|
}
|
||||||
|
|
||||||
|
return format.HumanBytes2(b)
|
||||||
|
}
|
||||||
|
|
||||||
func (opts *Options) FromMap(m map[string]any) error {
|
func (opts *Options) FromMap(m map[string]any) error {
|
||||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||||
@@ -1063,7 +1077,7 @@ func DefaultOptions() Options {
|
|||||||
TopP: 0.9,
|
TopP: 0.9,
|
||||||
TypicalP: 1.0,
|
TypicalP: 1.0,
|
||||||
RepeatLastN: 64,
|
RepeatLastN: 64,
|
||||||
RepeatPenalty: 1.1,
|
RepeatPenalty: 1.0,
|
||||||
PresencePenalty: 0.0,
|
PresencePenalty: 0.0,
|
||||||
FrequencyPenalty: 0.0,
|
FrequencyPenalty: 0.0,
|
||||||
Seed: -1,
|
Seed: -1,
|
||||||
|
|||||||
@@ -152,7 +152,9 @@ PARAMETER <parameter> <parametervalue>
|
|||||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.0) | float | repeat_penalty 1.0 |
|
||||||
|
| presence_penalty | Penalizes tokens that have already appeared in the generated text to reduce repetition. (Default: 0.0) | float | presence_penalty 1.5 |
|
||||||
|
| frequency_penalty | Penalizes tokens based on how often they have appeared in the generated text. (Default: 0.0) | float | frequency_penalty 1.0 |
|
||||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||||
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
||||||
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
||||||
|
|||||||
@@ -74,8 +74,7 @@ type LlamaServer interface {
|
|||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
VRAMSize() uint64 // Total VRAM across all GPUs
|
MemorySize() (total, vram uint64)
|
||||||
TotalSize() uint64
|
|
||||||
VRAMByGPU(id ml.DeviceID) uint64
|
VRAMByGPU(id ml.DeviceID) uint64
|
||||||
Pid() int
|
Pid() int
|
||||||
GetPort() int
|
GetPort() int
|
||||||
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
|||||||
// Windows CUDA should not use mmap for best performance
|
// Windows CUDA should not use mmap for best performance
|
||||||
// Linux with a model larger than free space, mmap leads to thrashing
|
// Linux with a model larger than free space, mmap leads to thrashing
|
||||||
// For CPU loads we want the memory to be allocated, not FS cache
|
// For CPU loads we want the memory to be allocated, not FS cache
|
||||||
|
totalSize, _ := s.MemorySize()
|
||||||
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
||||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
|
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
|
||||||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
||||||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
||||||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
||||||
@@ -1518,6 +1518,7 @@ type CompletionResponse struct {
|
|||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||||
EvalCount int `json:"eval_count"`
|
EvalCount int `json:"eval_count"`
|
||||||
EvalDuration time.Duration `json:"eval_duration"`
|
EvalDuration time.Duration `json:"eval_duration"`
|
||||||
|
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||||
|
|
||||||
// Logprobs contains log probability information if requested
|
// Logprobs contains log probability information if requested
|
||||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||||
@@ -1848,17 +1849,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) VRAMSize() uint64 {
|
func (s *llmServer) MemorySize() (total, vram uint64) {
|
||||||
if s.mem == nil {
|
if s.mem == nil {
|
||||||
return 0
|
return 0, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
var mem uint64
|
|
||||||
|
|
||||||
for _, g := range s.mem.GPUs {
|
for _, g := range s.mem.GPUs {
|
||||||
mem += g.Size()
|
vram += g.Size()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
|
||||||
|
|
||||||
// Some elements are always on CPU. However, if we have allocated all layers
|
// Some elements are always on CPU. However, if we have allocated all layers
|
||||||
// on the GPU then include the CPU components as well, to represent complete offloading.
|
// on the GPU then include the CPU components as well, to represent complete offloading.
|
||||||
noCPULayers := true
|
noCPULayers := true
|
||||||
@@ -1869,25 +1870,11 @@ func (s *llmServer) VRAMSize() uint64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if noCPULayers {
|
if noCPULayers {
|
||||||
mem += s.mem.InputWeights
|
vram += s.mem.InputWeights
|
||||||
mem += s.mem.CPU.Graph
|
vram += s.mem.CPU.Graph
|
||||||
}
|
}
|
||||||
|
|
||||||
return mem
|
return total, vram
|
||||||
}
|
|
||||||
|
|
||||||
func (s *llmServer) TotalSize() uint64 {
|
|
||||||
if s.mem == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
mem := s.mem.InputWeights
|
|
||||||
mem += s.mem.CPU.Size()
|
|
||||||
for _, g := range s.mem.GPUs {
|
|
||||||
mem += g.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
return mem
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type GatedDeltaNet struct {
|
|||||||
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||||
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
|
||||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||||
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
|||||||
default:
|
default:
|
||||||
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
||||||
}
|
}
|
||||||
|
if gdn.SSMDT == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMA == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
|
||||||
|
}
|
||||||
|
|
||||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||||
@@ -442,6 +454,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
||||||
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||||
|
|
||||||
|
// Collect chunk outputs and concatenate at the end.
|
||||||
|
// Avoids SET on buffer-less intermediates under partial offload.
|
||||||
|
chunks := make([]ml.Tensor, nChunks)
|
||||||
|
|
||||||
for chunk := range nChunks {
|
for chunk := range nChunks {
|
||||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
@@ -463,14 +479,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
||||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||||
|
|
||||||
v = v.SetInplace(
|
chunks[chunk] = coreAttnOutChunk
|
||||||
ctx,
|
|
||||||
coreAttnOutChunk,
|
|
||||||
v.Stride(1),
|
|
||||||
v.Stride(2),
|
|
||||||
v.Stride(3),
|
|
||||||
chunk*v.Stride(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Update state for next chunk
|
// Update state for next chunk
|
||||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
@@ -483,6 +492,20 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
stateT = stateT.Add(ctx, kgdMulVNew)
|
stateT = stateT.Add(ctx, kgdMulVNew)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use a balanced concat tree so concat work does not balloon on long prompts.
|
||||||
|
for len(chunks) > 1 {
|
||||||
|
merged := make([]ml.Tensor, 0, (len(chunks)+1)/2)
|
||||||
|
for i := 0; i < len(chunks); i += 2 {
|
||||||
|
if i+1 < len(chunks) {
|
||||||
|
merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2))
|
||||||
|
} else {
|
||||||
|
merged = append(merged, chunks[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chunks = merged
|
||||||
|
}
|
||||||
|
v = chunks[0]
|
||||||
|
|
||||||
// Final reshape
|
// Final reshape
|
||||||
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||||
|
|
||||||
|
|||||||
@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
return m.Output.Forward(ctx, hiddenStates), nil
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) Validate() error {
|
||||||
|
if m.Options == nil {
|
||||||
|
return fmt.Errorf("qwen3next: missing model options")
|
||||||
|
}
|
||||||
|
if len(m.Layers) != len(m.Options.isRecurrent) {
|
||||||
|
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
if !m.Options.isRecurrent[i] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
gdn, ok := layer.Operator.(*GatedDeltaNet)
|
||||||
|
if !ok || gdn == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMDT == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMA == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
m.positionCache = nil
|
m.positionCache = nil
|
||||||
if len(m.mropeSections) > 0 {
|
if len(m.mropeSections) > 0 {
|
||||||
@@ -450,6 +490,64 @@ var (
|
|||||||
_ model.MultimodalProcessor = (*Model)(nil)
|
_ model.MultimodalProcessor = (*Model)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// defaultVHeadReordered returns the default for the v-head-reordered setting
// of the given architecture name: true for "qwen35" and "qwen35moe", false
// for every other name.
func defaultVHeadReordered(arch string) bool {
	switch arch {
	case "qwen35", "qwen35moe":
		return true
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
// inferRecurrentLayers decides, for each of numLayers layers, whether the
// layer is recurrent (true) or full attention (false).
//
// If headCountKV contains both zero and non-zero entries within the first
// numLayers positions, the layout is read directly from it: a zero KV head
// count marks a recurrent layer. Layers without an entry in headCountKV are
// left as full attention.
//
// If every inspected entry is non-zero, the layout is derived from
// fullAttentionInterval instead: counting layers from 1, every interval-th
// layer is full attention and all others are recurrent. An interval of 0
// selects the default of 4, capped at numLayers.
//
// An error is returned when numLayers is negative, when headCountKV has no
// non-zero entry, when the interval exceeds numLayers, or when the derived
// layout is not a mix of recurrent and full attention layers.
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
	// make panics on a negative length; report it as a configuration error.
	if numLayers < 0 {
		return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
	}

	isRecurrent := make([]bool, numLayers)

	// Only entries that correspond to an actual layer are inspected.
	known := headCountKV
	if len(known) > numLayers {
		known = known[:numLayers]
	}

	hasZero, hasFull := false, false
	for i, heads := range known {
		if heads == 0 {
			isRecurrent[i] = true
			hasZero = true
		} else {
			hasFull = true
		}
	}
	if hasZero && hasFull {
		return isRecurrent, nil
	}
	if !hasFull {
		return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
	}

	// Compatibility path: older imports store a scalar KV head count and omit
	// per-layer recurrent flags. Derive the hybrid layout from the interval.
	const defaultInterval = 4
	interval := int(fullAttentionInterval)
	if interval == 0 {
		interval = defaultInterval
		if numLayers < interval {
			interval = numLayers
		}
	}
	// numLayers is at least 1 here, so this only triggers if the uint32 to
	// int conversion wrapped to a negative value.
	if interval <= 0 {
		return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
	}
	if interval > numLayers {
		return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
	}

	hasZero, hasFull = false, false
	for i := range isRecurrent {
		isRecurrent[i] = (i+1)%interval != 0
		if isRecurrent[i] {
			hasZero = true
		} else {
			hasFull = true
		}
	}
	if !hasZero || !hasFull {
		return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
	}

	return isRecurrent, nil
}
|
||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
numLayers := int(c.Uint("block_count"))
|
numLayers := int(c.Uint("block_count"))
|
||||||
layers := make([]Layer, numLayers)
|
layers := make([]Layer, numLayers)
|
||||||
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
HeadCountKV() []uint64
|
HeadCountKV() []uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
var isRecurrent []bool
|
|
||||||
var headCountKV []uint64
|
var headCountKV []uint64
|
||||||
if hc, ok := c.(headCounts); ok {
|
if hc, ok := c.(headCounts); ok {
|
||||||
headCountKV = hc.HeadCountKV()
|
headCountKV = hc.HeadCountKV()
|
||||||
}
|
}
|
||||||
|
|
||||||
isRecurrent = make([]bool, numLayers)
|
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
|
||||||
hasZero := false
|
if err != nil {
|
||||||
hasFull := false
|
return nil, err
|
||||||
for i := range numLayers {
|
|
||||||
// If KV head count is 0, it's a recurrent layer
|
|
||||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
|
||||||
isRecurrent[i] = true
|
|
||||||
hasZero = true
|
|
||||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
|
||||||
hasFull = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !hasZero || !hasFull {
|
|
||||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if MoE
|
// Determine if MoE
|
||||||
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||||
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
|
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
|
||||||
isRecurrent: isRecurrent,
|
isRecurrent: isRecurrent,
|
||||||
mropeSections: slices.Collect(func(yield func(int) bool) {
|
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||||
for _, section := range mropeSections {
|
for _, section := range mropeSections {
|
||||||
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
||||||
}
|
}
|
||||||
if opts.numKVHeads == 0 {
|
if opts.numKVHeads == 0 {
|
||||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate cache dimensions
|
// Calculate cache dimensions
|
||||||
|
|||||||
65
model/models/qwen3next/model_new_test.go
Normal file
65
model/models/qwen3next/model_new_test.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package qwen3next
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, false, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, true, true, false, true, true, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, true, false, true, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
|
||||||
|
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("inferRecurrentLayers() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
|
||||||
|
t.Fatalf("unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultVHeadReordered(t *testing.T) {
|
||||||
|
if !defaultVHeadReordered("qwen35") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
|
||||||
|
}
|
||||||
|
if !defaultVHeadReordered("qwen35moe") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
|
||||||
|
}
|
||||||
|
if defaultVHeadReordered("qwen3next") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
45
model/models/qwen3next/model_validate_test.go
Normal file
45
model/models/qwen3next/model_validate_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package qwen3next
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Layers: []Layer{{
|
||||||
|
Operator: &GatedDeltaNet{
|
||||||
|
SSMQKV: &nn.Linear{},
|
||||||
|
SSMQKVGate: &nn.Linear{},
|
||||||
|
SSMBeta: &nn.Linear{},
|
||||||
|
SSMAlpha: &nn.Linear{},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
Options: &Options{
|
||||||
|
isRecurrent: []bool{true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Validate() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "missing ssm_dt") {
|
||||||
|
t.Fatalf("unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Layers: []Layer{{Operator: &FullAttention{}}},
|
||||||
|
Options: &Options{
|
||||||
|
isRecurrent: []bool{false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Validate(); err != nil {
|
||||||
|
t.Fatalf("Validate() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -35,6 +35,7 @@ type GLM46Parser struct {
|
|||||||
state glm46ParserState
|
state glm46ParserState
|
||||||
buffer strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GLM46Parser) HasToolSupport() bool {
|
func (p *GLM46Parser) HasToolSupport() bool {
|
||||||
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
|
|||||||
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
|
|||||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case glm46EventThinkingContent:
|
case glm46EventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ type GLM47Parser struct {
|
|||||||
|
|
||||||
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||||
// so model output starts directly with thinking content (no opening tag).
|
// so model output starts directly with thinking content (no opening tag).
|
||||||
if thinkValue == nil || thinkValue.Bool() {
|
if thinkValue == nil || thinkValue.Bool() {
|
||||||
|
|||||||
@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
|
|||||||
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
input := `plan</think>
|
||||||
|
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
|
||||||
|
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
|
||||||
|
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ type Qwen3Parser struct {
|
|||||||
state qwen3ParserState
|
state qwen3ParserState
|
||||||
buffer strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
hasThinkingSupport bool
|
hasThinkingSupport bool
|
||||||
defaultThinking bool
|
defaultThinking bool
|
||||||
maybeThinkingOpenAtBOL bool
|
maybeThinkingOpenAtBOL bool
|
||||||
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
|
|||||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
p.buffer.Reset()
|
p.buffer.Reset()
|
||||||
|
p.callIndex = 0
|
||||||
|
|
||||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||||
if thinkValue == nil {
|
if thinkValue == nil {
|
||||||
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
|
|||||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
calls = append(calls, toolCall)
|
calls = append(calls, toolCall)
|
||||||
case qwen3EventThinkingContent:
|
case qwen3EventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
|
|||||||
@@ -230,3 +230,89 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
|||||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
|
||||||
|
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
|
||||||
|
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ type Qwen3CoderParser struct {
|
|||||||
state qwenParserState
|
state qwenParserState
|
||||||
acc strings.Builder
|
acc strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||||
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
|||||||
|
|
||||||
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
return tools // Qwen doesn't modify tools
|
return tools // Qwen doesn't modify tools
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
|
|||||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case qwenEventContent:
|
case qwenEventContent:
|
||||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||||
|
|||||||
@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
|
||||||
|
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
|
||||||
|
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQwenXMLTransform(t *testing.T) {
|
func TestQwenXMLTransform(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
desc string
|
desc string
|
||||||
|
|||||||
@@ -562,6 +562,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
|||||||
if errors.As(err, &reprocess) {
|
if errors.As(err, &reprocess) {
|
||||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||||
|
seq.sampler.Reset()
|
||||||
// Skip this sequence but continue processing the rest
|
// Skip this sequence but continue processing the rest
|
||||||
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
||||||
err = nil
|
err = nil
|
||||||
@@ -692,6 +693,12 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
|||||||
// (unless we take down the whole runner).
|
// (unless we take down the whole runner).
|
||||||
if len(seq.pendingInputs) > 0 {
|
if len(seq.pendingInputs) > 0 {
|
||||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||||
|
for _, inp := range seq.pendingInputs {
|
||||||
|
if len(inp.Multimodal) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seq.sampler.Accept(inp.Token)
|
||||||
|
}
|
||||||
seq.pendingInputs = []*input.Input{}
|
seq.pendingInputs = []*input.Input{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -892,6 +899,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
req.Options.TopK,
|
req.Options.TopK,
|
||||||
req.Options.TopP,
|
req.Options.TopP,
|
||||||
req.Options.MinP,
|
req.Options.MinP,
|
||||||
|
req.Options.RepeatPenalty,
|
||||||
|
req.Options.PresencePenalty,
|
||||||
|
req.Options.FrequencyPenalty,
|
||||||
req.Options.Seed,
|
req.Options.Seed,
|
||||||
grammar,
|
grammar,
|
||||||
)
|
)
|
||||||
@@ -938,6 +948,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
seq.sampler.Reset()
|
||||||
|
for _, inp := range seq.cache.Inputs {
|
||||||
|
if len(inp.Multimodal) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seq.sampler.Accept(inp.Token)
|
||||||
|
}
|
||||||
|
|
||||||
s.seqs[i] = seq
|
s.seqs[i] = seq
|
||||||
s.cond.Signal()
|
s.cond.Signal()
|
||||||
found = true
|
found = true
|
||||||
|
|||||||
@@ -16,24 +16,49 @@ type token struct {
|
|||||||
value float32 // The raw logit or probability from the model
|
value float32 // The raw logit or probability from the model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const DefaultPenaltyLookback = 64
|
||||||
|
|
||||||
type Sampler struct {
|
type Sampler struct {
|
||||||
rng *rand.Rand
|
rng *rand.Rand
|
||||||
topK int
|
topK int
|
||||||
topP float32
|
topP float32
|
||||||
minP float32
|
minP float32
|
||||||
temperature float32
|
temperature float32
|
||||||
|
repeat float32
|
||||||
|
presence float32
|
||||||
|
frequency float32
|
||||||
|
history []int32
|
||||||
grammar *GrammarSampler
|
grammar *GrammarSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Sampler) Reset() {
|
||||||
|
s.history = s.history[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sampler) Accept(token int32) {
|
||||||
|
s.history = append(s.history, token)
|
||||||
|
if len(s.history) > DefaultPenaltyLookback {
|
||||||
|
copy(s.history, s.history[len(s.history)-DefaultPenaltyLookback:])
|
||||||
|
s.history = s.history[:DefaultPenaltyLookback]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||||
if len(logits) == 0 {
|
if len(logits) == 0 {
|
||||||
return -1, errors.New("sample: no logits provided to sample")
|
return -1, errors.New("sample: no logits provided to sample")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
counts := tokenCounts(s.history, len(logits))
|
||||||
|
|
||||||
tokens := make([]token, len(logits))
|
tokens := make([]token, len(logits))
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
|
value := logits[i]
|
||||||
|
if count := counts[int32(i)]; count > 0 {
|
||||||
|
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||||
|
}
|
||||||
|
|
||||||
tokens[i].id = int32(i)
|
tokens[i].id = int32(i)
|
||||||
tokens[i].value = logits[i]
|
tokens[i].value = value
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err := s.sample(tokens)
|
t, err := s.sample(tokens)
|
||||||
@@ -55,8 +80,12 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
|
|||||||
// we need to reset them before applying the grammar and
|
// we need to reset them before applying the grammar and
|
||||||
// sampling again
|
// sampling again
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
|
value := logits[i]
|
||||||
|
if count := counts[int32(i)]; count > 0 {
|
||||||
|
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||||
|
}
|
||||||
tokens[i].id = int32(i)
|
tokens[i].id = int32(i)
|
||||||
tokens[i].value = logits[i]
|
tokens[i].value = value
|
||||||
}
|
}
|
||||||
s.grammar.Apply(tokens)
|
s.grammar.Apply(tokens)
|
||||||
t, err = s.sample(tokens)
|
t, err = s.sample(tokens)
|
||||||
@@ -127,7 +156,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
|
func NewSampler(temperature float32, topK int, topP float32, minP float32, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32, seed int, grammar *GrammarSampler) Sampler {
|
||||||
var rng *rand.Rand
|
var rng *rand.Rand
|
||||||
if seed != -1 {
|
if seed != -1 {
|
||||||
// PCG requires two parameters: sequence and stream
|
// PCG requires two parameters: sequence and stream
|
||||||
@@ -154,12 +183,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
|||||||
minP = 1.0
|
minP = 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if repeatPenalty <= 0 {
|
||||||
|
repeatPenalty = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
return Sampler{
|
return Sampler{
|
||||||
rng: rng,
|
rng: rng,
|
||||||
topK: topK,
|
topK: topK,
|
||||||
topP: topP,
|
topP: topP,
|
||||||
minP: minP,
|
minP: minP,
|
||||||
temperature: temperature,
|
temperature: temperature,
|
||||||
|
repeat: repeatPenalty,
|
||||||
|
presence: presencePenalty,
|
||||||
|
frequency: frequencyPenalty,
|
||||||
grammar: grammar,
|
grammar: grammar,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
sampler := NewSampler(0.8, 0, 0, 0, 1, 0, 0, 42, nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
sampler.Sample(logits)
|
sampler.Sample(logits)
|
||||||
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range configs {
|
for _, tc := range configs {
|
||||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, 1, 0, 0, tc.seed, nil)
|
||||||
sampler.Sample(logits)
|
sampler.Sample(logits)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
|
|
||||||
// Test with combined transforms separately - topK influences performance greatly
|
// Test with combined transforms separately - topK influences performance greatly
|
||||||
b.Run("TransformCombined", func(b *testing.B) {
|
b.Run("TransformCombined", func(b *testing.B) {
|
||||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
sampler := NewSampler(0.8, 50, 0.9, 0.05, 1, 0, 0, 42, nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
|
|||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
sampler := NewSampler(0, -1, 0, 0, 1, 0, 0, -1, nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
func TestWeighted(t *testing.T) {
|
func TestWeighted(t *testing.T) {
|
||||||
logits := []float32{-10, 3, -10, -10}
|
logits := []float32{-10, 3, -10, -10}
|
||||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
sampler := NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||||
got, err := sampler.Sample(logits)
|
got, err := sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -25,7 +25,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logits = []float32{-100, -10, 0, 10}
|
logits = []float32{-100, -10, 0, 10}
|
||||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
sampler = NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -39,7 +39,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
// Test very high p
|
// Test very high p
|
||||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||||
// Use extremely small topP to filter out all tokens
|
// Use extremely small topP to filter out all tokens
|
||||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
sampler = NewSampler(1.0, 0, 1e-10, 0, 1, 0, 0, 0, nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -52,7 +52,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||||
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
sampler = NewSampler(1, 0, 0.95, 0.05, 1, 0, 0, 0, nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error, got %d", got)
|
t.Errorf("expected error, got %d", got)
|
||||||
@@ -151,8 +151,8 @@ func TestGrammar(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkSample(b *testing.B) {
|
func BenchmarkSample(b *testing.B) {
|
||||||
samplers := map[string]Sampler{
|
samplers := map[string]Sampler{
|
||||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
"Greedy": NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, 1, 0, 0, -1, nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate random logits for benchmarking
|
// Generate random logits for benchmarking
|
||||||
|
|||||||
@@ -25,6 +25,48 @@ func (h *tokenHeap) Pop() any {
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func tokenCounts(history []int32, vocabSize int) map[int32]int {
|
||||||
|
if len(history) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
start := 0
|
||||||
|
if len(history) > DefaultPenaltyLookback {
|
||||||
|
start = len(history) - DefaultPenaltyLookback
|
||||||
|
}
|
||||||
|
|
||||||
|
counts := make(map[int32]int, len(history)-start)
|
||||||
|
for _, token := range history[start:] {
|
||||||
|
if token < 0 || int(token) >= vocabSize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
counts[token]++
|
||||||
|
}
|
||||||
|
|
||||||
|
return counts
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyPenalty returns logit adjusted by the three penalties, in this order:
//
//   - repeatPenalty: skipped when it equals 1.0; otherwise a negative logit is
//     multiplied by it and a non-negative logit is divided by it.
//   - frequencyPenalty: when non-zero, count*frequencyPenalty is subtracted.
//   - presencePenalty: when non-zero, it is subtracted once.
func applyPenalty(logit float32, count int, repeatPenalty, presencePenalty, frequencyPenalty float32) float32 {
	switch {
	case repeatPenalty == 1.0:
		// Neutral repeat penalty: leave the logit as is.
	case logit < 0:
		// Multiplying keeps the ordering of negative logits intact.
		logit *= repeatPenalty
	default:
		logit /= repeatPenalty
	}

	if frequencyPenalty != 0 {
		logit -= float32(count) * frequencyPenalty
	}
	if presencePenalty != 0 {
		logit -= presencePenalty
	}
	return logit
}
|
||||||
|
|
||||||
// temperature applies scaling to the logits
|
// temperature applies scaling to the logits
|
||||||
func temperature(ts []token, temp float32) {
|
func temperature(ts []token, temp float32) {
|
||||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||||
|
|||||||
@@ -295,6 +295,86 @@ func TestMinP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTokenCounts(t *testing.T) {
|
||||||
|
history := make([]int32, 70)
|
||||||
|
history[0] = 7
|
||||||
|
history[69] = 7
|
||||||
|
|
||||||
|
counts := tokenCounts(history, 8)
|
||||||
|
if got := counts[7]; got != 1 {
|
||||||
|
t.Fatalf("lookback mismatch: got %d want %d", got, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyPenalty(t *testing.T) {
|
||||||
|
logit := applyPenalty(5.0, 3, 1.0, 1.5, 0.5)
|
||||||
|
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||||
|
t.Fatalf("unexpected penalty result: got %f want %f", logit, 2.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
logit = applyPenalty(4.0, 1, 2.0, 0, 0)
|
||||||
|
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||||
|
t.Fatalf("unexpected repeat penalty result for positive logits: got %f want %f", logit, 2.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
logit = applyPenalty(-4.0, 1, 2.0, 0, 0)
|
||||||
|
if math.Abs(float64(logit-(-8.0))) > 1e-6 {
|
||||||
|
t.Fatalf("unexpected repeat penalty result for negative logits: got %f want %f", logit, -8.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSamplerPresencePenalty(t *testing.T) {
|
||||||
|
logits := []float32{0.0, 5.0, 0.0}
|
||||||
|
|
||||||
|
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||||
|
baseline.Accept(1)
|
||||||
|
got, err := baseline.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got != 1 {
|
||||||
|
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
presence := NewSampler(0, 0, 1, 0, 1, 6, 0, -1, nil)
|
||||||
|
presence.Accept(1)
|
||||||
|
got, err = presence.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got == 1 {
|
||||||
|
t.Fatalf("presence penalty did not change repeated token selection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSamplerFrequencyPenalty(t *testing.T) {
|
||||||
|
logits := []float32{0.0, 5.0, 4.0}
|
||||||
|
|
||||||
|
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||||
|
baseline.Accept(1)
|
||||||
|
baseline.Accept(1)
|
||||||
|
baseline.Accept(1)
|
||||||
|
got, err := baseline.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got != 1 {
|
||||||
|
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
frequency := NewSampler(0, 0, 1, 0, 1, 0, 1.0, -1, nil)
|
||||||
|
frequency.Accept(1)
|
||||||
|
frequency.Accept(1)
|
||||||
|
frequency.Accept(1)
|
||||||
|
got, err = frequency.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got != 2 {
|
||||||
|
t.Fatalf("frequency penalty did not demote repeated token as expected: got %d want %d", got, 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkTransforms(b *testing.B) {
|
func BenchmarkTransforms(b *testing.B) {
|
||||||
// Generate random logits
|
// Generate random logits
|
||||||
tokens := make([]token, 1<<16)
|
tokens := make([]token, 1<<16)
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ type Model struct {
|
|||||||
Template *template.Template
|
Template *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) IsMLX() bool {
|
||||||
|
return m.Config.ModelFormat == "safetensors"
|
||||||
|
}
|
||||||
|
|
||||||
// Capabilities returns the capabilities that the model supports
|
// Capabilities returns the capabilities that the model supports
|
||||||
func (m *Model) Capabilities() []model.Capability {
|
func (m *Model) Capabilities() []model.Capability {
|
||||||
capabilities := []model.Capability{}
|
capabilities := []model.Capability{}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
lastMsgIdx := len(msgs) - 1
|
lastMsgIdx := len(msgs) - 1
|
||||||
currMsgIdx := 0
|
currMsgIdx := 0
|
||||||
|
|
||||||
|
if truncate {
|
||||||
// Start with all messages and remove from the front until it fits in context
|
// Start with all messages and remove from the front until it fits in context
|
||||||
for i := 0; i <= lastMsgIdx; i++ {
|
for i := 0; i <= lastMsgIdx; i++ {
|
||||||
// Collect system messages from the portion we're about to skip
|
// Collect system messages from the portion we're about to skip
|
||||||
@@ -57,7 +58,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !truncate || ctxLen <= opts.NumCtx {
|
if ctxLen <= opts.NumCtx {
|
||||||
currMsgIdx = i
|
currMsgIdx = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -68,6 +69,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if currMsgIdx > 0 {
|
if currMsgIdx > 0 {
|
||||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
||||||
|
|||||||
@@ -484,7 +484,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
// the real chat handler, but doing this as a stopgap to get renderer
|
// the real chat handler, but doing this as a stopgap to get renderer
|
||||||
// support for generate
|
// support for generate
|
||||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
|
||||||
|
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -557,6 +558,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
PromptEvalDuration: cr.PromptEvalDuration,
|
PromptEvalDuration: cr.PromptEvalDuration,
|
||||||
EvalCount: cr.EvalCount,
|
EvalCount: cr.EvalCount,
|
||||||
EvalDuration: cr.EvalDuration,
|
EvalDuration: cr.EvalDuration,
|
||||||
|
PeakMemory: cr.PeakMemory,
|
||||||
},
|
},
|
||||||
Logprobs: toAPILogprobs(cr.Logprobs),
|
Logprobs: toAPILogprobs(cr.Logprobs),
|
||||||
}
|
}
|
||||||
@@ -1951,6 +1953,9 @@ func (s *Server) PsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if v.llama != nil {
|
if v.llama != nil {
|
||||||
mr.ContextLength = v.llama.ContextLength()
|
mr.ContextLength = v.llama.ContextLength()
|
||||||
|
total, vram := v.llama.MemorySize()
|
||||||
|
mr.Size = int64(total)
|
||||||
|
mr.SizeVRAM = int64(vram)
|
||||||
}
|
}
|
||||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||||
// possible that it will be set to the unix epoch. For those cases, just
|
// possible that it will be set to the unix epoch. For those cases, just
|
||||||
@@ -2213,6 +2218,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
truncate := req.Truncate == nil || *req.Truncate
|
truncate := req.Truncate == nil || *req.Truncate
|
||||||
|
if m.IsMLX() {
|
||||||
|
truncate = false
|
||||||
|
}
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("chat prompt error", "error", err)
|
slog.Error("chat prompt error", "error", err)
|
||||||
@@ -2309,6 +2317,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
PromptEvalDuration: r.PromptEvalDuration,
|
PromptEvalDuration: r.PromptEvalDuration,
|
||||||
EvalCount: r.EvalCount,
|
EvalCount: r.EvalCount,
|
||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: r.EvalDuration,
|
||||||
|
PeakMemory: r.PeakMemory,
|
||||||
},
|
},
|
||||||
Logprobs: toAPILogprobs(r.Logprobs),
|
Logprobs: toAPILogprobs(r.Logprobs),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for experimental safetensors LLM models
|
// Check for experimental safetensors LLM models
|
||||||
if pending.model.Config.ModelFormat == "safetensors" {
|
if pending.model.IsMLX() {
|
||||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||||
// LLM model with safetensors format - use MLX runner
|
// LLM model with safetensors format - use MLX runner
|
||||||
if s.loadMLX(pending) {
|
if s.loadMLX(pending) {
|
||||||
@@ -536,6 +536,7 @@ iGPUScan:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalSize, vramSize := llama.MemorySize()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
@@ -545,8 +546,8 @@ iGPUScan:
|
|||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
gpus: gpuIDs,
|
gpus: gpuIDs,
|
||||||
discreteGPUs: discreteGPUs,
|
discreteGPUs: discreteGPUs,
|
||||||
vramSize: llama.VRAMSize(),
|
totalSize: totalSize,
|
||||||
totalSize: llama.TotalSize(),
|
vramSize: vramSize,
|
||||||
loading: true,
|
loading: true,
|
||||||
pid: llama.Pid(),
|
pid: llama.Pid(),
|
||||||
}
|
}
|
||||||
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
sessionDuration = req.sessionDuration.Duration
|
sessionDuration = req.sessionDuration.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalSize, vramSize := server.MemorySize()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
loading: false,
|
loading: false,
|
||||||
isImagegen: isImagegen,
|
isImagegen: isImagegen,
|
||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
totalSize: server.TotalSize(),
|
totalSize: totalSize,
|
||||||
vramSize: server.VRAMSize(),
|
vramSize: vramSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
@@ -762,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||||
runner.llama.Ping(ctx) != nil {
|
runner.llama.Ping(ctx) != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
|
|||||||
s.closeCalled = true
|
s.closeCalled = true
|
||||||
return s.closeResp
|
return s.closeResp
|
||||||
}
|
}
|
||||||
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
|
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
|
||||||
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
|
|
||||||
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
||||||
func (s *mockLlm) Pid() int { return -1 }
|
func (s *mockLlm) Pid() int { return -1 }
|
||||||
func (s *mockLlm) GetPort() int { return -1 }
|
func (s *mockLlm) GetPort() int { return -1 }
|
||||||
|
|||||||
@@ -374,14 +374,9 @@ func (s *Server) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMSize returns the estimated VRAM usage.
|
// MemorySize returns the total and VRAM memory usage.
|
||||||
func (s *Server) VRAMSize() uint64 {
|
func (s *Server) MemorySize() (total, vram uint64) {
|
||||||
return s.vramSize
|
return s.vramSize, s.vramSize
|
||||||
}
|
|
||||||
|
|
||||||
// TotalSize returns the total memory usage.
|
|
||||||
func (s *Server) TotalSize() uint64 {
|
|
||||||
return s.vramSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMByGPU returns VRAM usage for a specific GPU.
|
// VRAMByGPU returns VRAM usage for a specific GPU.
|
||||||
|
|||||||
@@ -78,6 +78,12 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
|||||||
prefix++
|
prefix++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always keep at least one token to re-evaluate so the
|
||||||
|
// pipeline can seed token generation from it.
|
||||||
|
if prefix == len(tokens) && prefix > 0 {
|
||||||
|
prefix--
|
||||||
|
}
|
||||||
|
|
||||||
if prefix < len(c.tokens) {
|
if prefix < len(c.tokens) {
|
||||||
trim := len(c.tokens) - prefix
|
trim := len(c.tokens) - prefix
|
||||||
for _, kv := range c.caches {
|
for _, kv := range c.caches {
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -19,19 +18,21 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/x/imagegen"
|
||||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
port int
|
port int
|
||||||
modelName string
|
modelName string
|
||||||
vramSize uint64
|
contextLength atomic.Int64
|
||||||
|
memory atomic.Uint64
|
||||||
done chan error
|
done chan error
|
||||||
client *http.Client
|
client *http.Client
|
||||||
lastErr string
|
lastErr string
|
||||||
@@ -98,18 +99,9 @@ func NewClient(modelName string) (*Client, error) {
|
|||||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Estimate VRAM based on tensor size from manifest
|
|
||||||
var vramSize uint64
|
|
||||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
|
||||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
|
||||||
} else {
|
|
||||||
vramSize = 8 * 1024 * 1024 * 1024
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &Client{
|
c := &Client{
|
||||||
port: port,
|
port: port,
|
||||||
modelName: modelName,
|
modelName: modelName,
|
||||||
vramSize: vramSize,
|
|
||||||
done: make(chan error, 1),
|
done: make(chan error, 1),
|
||||||
client: &http.Client{Timeout: 10 * time.Minute},
|
client: &http.Client{Timeout: 10 * time.Minute},
|
||||||
cmd: cmd,
|
cmd: cmd,
|
||||||
@@ -201,6 +193,20 @@ type completionOpts struct {
|
|||||||
NumPredict int `json:"num_predict,omitempty"`
|
NumPredict int `json:"num_predict,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CompletionResponse struct {
|
||||||
|
Content string
|
||||||
|
Done bool
|
||||||
|
DoneReason int
|
||||||
|
|
||||||
|
PromptEvalCount int
|
||||||
|
PromptEvalDuration time.Duration
|
||||||
|
EvalCount int
|
||||||
|
EvalDuration time.Duration
|
||||||
|
PeakMemory uint64
|
||||||
|
|
||||||
|
Error *api.StatusError
|
||||||
|
}
|
||||||
|
|
||||||
// Close terminates the subprocess.
|
// Close terminates the subprocess.
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@@ -260,28 +266,25 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
var raw struct {
|
var raw CompletionResponse
|
||||||
Content string `json:"content,omitempty"`
|
|
||||||
Done bool `json:"done"`
|
|
||||||
DoneReason int `json:"done_reason,omitempty"`
|
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
|
||||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
|
||||||
EvalCount int `json:"eval_count,omitempty"`
|
|
||||||
EvalDuration int `json:"eval_duration,omitempty"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if raw.Error != nil {
|
||||||
|
return *raw.Error
|
||||||
|
}
|
||||||
|
|
||||||
cresp := llm.CompletionResponse{
|
cresp := llm.CompletionResponse{
|
||||||
Content: raw.Content,
|
Content: raw.Content,
|
||||||
Done: raw.Done,
|
Done: raw.Done,
|
||||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||||
PromptEvalCount: raw.PromptEvalCount,
|
PromptEvalCount: raw.PromptEvalCount,
|
||||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
PromptEvalDuration: raw.PromptEvalDuration,
|
||||||
EvalCount: raw.EvalCount,
|
EvalCount: raw.EvalCount,
|
||||||
EvalDuration: time.Duration(raw.EvalDuration),
|
EvalDuration: raw.EvalDuration,
|
||||||
|
PeakMemory: raw.PeakMemory,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(cresp)
|
fn(cresp)
|
||||||
@@ -294,7 +297,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ContextLength() int {
|
func (c *Client) ContextLength() int {
|
||||||
return math.MaxInt
|
return int(c.contextLength.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detokenize implements llm.LlamaServer.
|
// Detokenize implements llm.LlamaServer.
|
||||||
@@ -347,9 +350,16 @@ func (c *Client) Pid() int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type statusResponse struct {
|
||||||
|
Status int
|
||||||
|
Progress int
|
||||||
|
ContextLength int
|
||||||
|
Memory uint64
|
||||||
|
}
|
||||||
|
|
||||||
// Ping implements llm.LlamaServer.
|
// Ping implements llm.LlamaServer.
|
||||||
func (c *Client) Ping(ctx context.Context) error {
|
func (c *Client) Ping(ctx context.Context) error {
|
||||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
|
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -362,6 +372,15 @@ func (c *Client) Ping(ctx context.Context) error {
|
|||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var status statusResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.contextLength.Store(int64(status.ContextLength))
|
||||||
|
c.memory.Store(status.Memory)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,19 +407,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
|||||||
return tokens, nil
|
return tokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TotalSize implements llm.LlamaServer.
|
func (c *Client) currentMemory() uint64 {
|
||||||
func (c *Client) TotalSize() uint64 {
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
return c.vramSize
|
defer cancel()
|
||||||
|
if err := c.Ping(ctx); err != nil {
|
||||||
|
slog.Warn("failed to get current memory", "error", err)
|
||||||
|
}
|
||||||
|
return c.memory.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemorySize implements llm.LlamaServer.
|
||||||
|
func (c *Client) MemorySize() (total, vram uint64) {
|
||||||
|
mem := c.currentMemory()
|
||||||
|
return mem, mem
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMByGPU implements llm.LlamaServer.
|
// VRAMByGPU implements llm.LlamaServer.
|
||||||
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||||
return c.vramSize
|
return c.currentMemory()
|
||||||
}
|
|
||||||
|
|
||||||
// VRAMSize implements llm.LlamaServer.
|
|
||||||
func (c *Client) VRAMSize() uint64 {
|
|
||||||
return c.vramSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WaitUntilRunning implements llm.LlamaServer.
|
// WaitUntilRunning implements llm.LlamaServer.
|
||||||
|
|||||||
@@ -64,6 +64,10 @@ func PeakMemory() int {
|
|||||||
return int(peak)
|
return int(peak)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ResetPeakMemory() {
|
||||||
|
C.mlx_reset_peak_memory()
|
||||||
|
}
|
||||||
|
|
||||||
type Memory struct{}
|
type Memory struct{}
|
||||||
|
|
||||||
func (Memory) LogValue() slog.Value {
|
func (Memory) LogValue() slog.Value {
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type Model interface {
|
|||||||
Unembed(x *mlx.Array) *mlx.Array
|
Unembed(x *mlx.Array) *mlx.Array
|
||||||
NumLayers() int
|
NumLayers() int
|
||||||
Tokenizer() *tokenizer.Tokenizer
|
Tokenizer() *tokenizer.Tokenizer
|
||||||
|
MaxContextLength() int
|
||||||
|
|
||||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||||
|
|||||||
@@ -6,9 +6,12 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
)
|
)
|
||||||
@@ -44,16 +47,35 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
} else {
|
} else {
|
||||||
mlx.DisableCompile()
|
mlx.DisableCompile()
|
||||||
}
|
}
|
||||||
|
mlx.ResetPeakMemory()
|
||||||
|
|
||||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||||
|
if len(inputs) == 0 {
|
||||||
|
return errors.New("empty prompt")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inputs) >= r.contextLength {
|
||||||
|
return api.StatusError{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap generation to stay within the model's context length
|
||||||
|
maxGenerate := r.contextLength - len(inputs)
|
||||||
|
if request.Options.MaxTokens <= 0 {
|
||||||
|
request.Options.MaxTokens = maxGenerate
|
||||||
|
} else {
|
||||||
|
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
||||||
|
}
|
||||||
|
|
||||||
session := r.cache.begin(r.Model, inputs)
|
session := r.cache.begin(r.Model, inputs)
|
||||||
defer session.close()
|
defer session.close()
|
||||||
|
|
||||||
caches := session.caches
|
caches := session.caches
|
||||||
tokens := session.remaining
|
tokens := session.remaining
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
total, processed := len(tokens), 0
|
total, processed := len(tokens), 0
|
||||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
|
||||||
for total-processed > 1 {
|
for total-processed > 1 {
|
||||||
if err := request.Ctx.Err(); err != nil {
|
if err := request.Ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -93,8 +115,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
|
|
||||||
now := time.Now()
|
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
|
||||||
for i := range request.Options.MaxTokens {
|
for i := range request.Options.MaxTokens {
|
||||||
if err := request.Ctx.Err(); err != nil {
|
if err := request.Ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -103,9 +124,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
nextSample, nextLogprobs = step(sample)
|
nextSample, nextLogprobs = step(sample)
|
||||||
|
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
|
||||||
mlx.Eval(sample)
|
mlx.Eval(sample)
|
||||||
final.PromptTokensDuration = time.Since(now)
|
final.PromptEvalDuration = time.Since(now)
|
||||||
now = time.Now()
|
now = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,18 +133,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
session.outputs = append(session.outputs, output)
|
session.outputs = append(session.outputs, output)
|
||||||
|
|
||||||
if r.Tokenizer.IsEOS(output) {
|
if r.Tokenizer.IsEOS(output) {
|
||||||
final.Token = int(output)
|
|
||||||
final.DoneReason = 0
|
final.DoneReason = 0
|
||||||
final.CompletionTokens = i
|
final.EvalCount = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-request.Ctx.Done():
|
case <-request.Ctx.Done():
|
||||||
return request.Ctx.Err()
|
return request.Ctx.Err()
|
||||||
case request.Responses <- Response{
|
case request.Responses <- CompletionResponse{
|
||||||
Text: r.Decode(output, &b),
|
Content: r.Decode(output, &b),
|
||||||
Token: int(output),
|
|
||||||
}:
|
}:
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,7 +155,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
final.CompletionTokensDuration = time.Since(now)
|
final.EvalDuration = time.Since(now)
|
||||||
|
final.PeakMemory = uint64(mlx.PeakMemory())
|
||||||
select {
|
select {
|
||||||
case <-request.Ctx.Done():
|
case <-request.Ctx.Done():
|
||||||
return request.Ctx.Err()
|
return request.Ctx.Err()
|
||||||
|
|||||||
@@ -4,14 +4,15 @@ package mlxrunner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
@@ -21,7 +22,7 @@ import (
|
|||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
TextCompletionsRequest
|
TextCompletionsRequest
|
||||||
Responses chan Response
|
Responses chan CompletionResponse
|
||||||
Pipeline func(Request) error
|
Pipeline func(Request) error
|
||||||
|
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
@@ -43,25 +44,12 @@ type TextCompletionsRequest struct {
|
|||||||
} `json:"options"`
|
} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
Text string `json:"content,omitempty"`
|
|
||||||
Token int `json:"token,omitempty"`
|
|
||||||
Logprobs []float32 `json:"logprobs,omitempty"`
|
|
||||||
Done bool `json:"done,omitempty"`
|
|
||||||
DoneReason int `json:"done_reason,omitempty"`
|
|
||||||
|
|
||||||
PromptTokens int `json:"prompt_eval_count,omitempty"`
|
|
||||||
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
|
||||||
CompletionTokens int `json:"eval_count,omitempty"`
|
|
||||||
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
|
|
||||||
TotalTokens int `json:"total_tokens,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Runner struct {
|
type Runner struct {
|
||||||
Model base.Model
|
Model base.Model
|
||||||
Tokenizer *tokenizer.Tokenizer
|
Tokenizer *tokenizer.Tokenizer
|
||||||
Requests chan Request
|
Requests chan Request
|
||||||
cache kvCache
|
cache kvCache
|
||||||
|
contextLength int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Runner) Load(modelName string) error {
|
func (r *Runner) Load(modelName string) error {
|
||||||
@@ -90,6 +78,7 @@ func (r *Runner) Load(modelName string) error {
|
|||||||
|
|
||||||
r.Model = m
|
r.Model = m
|
||||||
r.Tokenizer = m.Tokenizer()
|
r.Tokenizer = m.Tokenizer()
|
||||||
|
r.contextLength = m.MaxContextLength()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,6 +147,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
|||||||
case request := <-r.Requests:
|
case request := <-r.Requests:
|
||||||
if err := request.Pipeline(request); err != nil {
|
if err := request.Pipeline(request); err != nil {
|
||||||
slog.Info("Request terminated", "error", err)
|
slog.Info("Request terminated", "error", err)
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if !errors.As(err, &statusErr) {
|
||||||
|
statusErr = api.StatusError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
ErrorMessage: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||||
|
case <-request.Ctx.Done():
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
close(request.Responses)
|
close(request.Responses)
|
||||||
|
|||||||
@@ -50,9 +50,11 @@ func Execute(args []string) error {
|
|||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
if err := json.NewEncoder(w).Encode(statusResponse{
|
||||||
"status": 0,
|
Status: 0,
|
||||||
"progress": 100,
|
Progress: 100,
|
||||||
|
ContextLength: runner.contextLength,
|
||||||
|
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
slog.Error("Failed to encode response", "error", err)
|
slog.Error("Failed to encode response", "error", err)
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
@@ -78,7 +80,7 @@ func Execute(args []string) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
request := Request{Responses: make(chan Response)}
|
request := Request{Responses: make(chan CompletionResponse)}
|
||||||
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||||
slog.Error("Failed to decode request", "error", err)
|
slog.Error("Failed to decode request", "error", err)
|
||||||
@@ -87,9 +89,6 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
||||||
if request.Options.MaxTokens < 1 {
|
|
||||||
request.Options.MaxTokens = 16 << 10
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Pipeline = runner.TextGenerationPipeline
|
request.Pipeline = runner.TextGenerationPipeline
|
||||||
request.Sampler = sample.New(
|
request.Sampler = sample.New(
|
||||||
|
|||||||
@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
|||||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||||
|
|
||||||
// MaxContextLength returns the maximum context length
|
// MaxContextLength returns the maximum context length
|
||||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
|
||||||
|
|
||||||
// VocabSize returns the vocabulary size
|
// VocabSize returns the vocabulary size
|
||||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||||
|
|||||||
@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user