mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 06:54:09 +02:00
Compare commits
9 Commits
pdevine/sa
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
330b19b73f | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 |
14
api/types.go
14
api/types.go
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/internal/orderedmap"
|
"github.com/ollama/ollama/internal/orderedmap"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
@@ -569,6 +570,7 @@ type DebugInfo struct {
|
|||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||||
|
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||||
@@ -934,6 +936,10 @@ func (m *Metrics) Summary() {
|
|||||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.PeakMemory > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
|
||||||
|
}
|
||||||
|
|
||||||
if m.LoadDuration > 0 {
|
if m.LoadDuration > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
||||||
}
|
}
|
||||||
@@ -957,6 +963,14 @@ func (m *Metrics) Summary() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatPeakMemory(b uint64) string {
|
||||||
|
if b >= format.GibiByte {
|
||||||
|
return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
|
||||||
|
}
|
||||||
|
|
||||||
|
return format.HumanBytes2(b)
|
||||||
|
}
|
||||||
|
|
||||||
func (opts *Options) FromMap(m map[string]any) error {
|
func (opts *Options) FromMap(m map[string]any) error {
|
||||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||||
|
|||||||
@@ -74,8 +74,7 @@ type LlamaServer interface {
|
|||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
VRAMSize() uint64 // Total VRAM across all GPUs
|
MemorySize() (total, vram uint64)
|
||||||
TotalSize() uint64
|
|
||||||
VRAMByGPU(id ml.DeviceID) uint64
|
VRAMByGPU(id ml.DeviceID) uint64
|
||||||
Pid() int
|
Pid() int
|
||||||
GetPort() int
|
GetPort() int
|
||||||
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
|||||||
// Windows CUDA should not use mmap for best performance
|
// Windows CUDA should not use mmap for best performance
|
||||||
// Linux with a model larger than free space, mmap leads to thrashing
|
// Linux with a model larger than free space, mmap leads to thrashing
|
||||||
// For CPU loads we want the memory to be allocated, not FS cache
|
// For CPU loads we want the memory to be allocated, not FS cache
|
||||||
|
totalSize, _ := s.MemorySize()
|
||||||
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
||||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
|
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
|
||||||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
||||||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
||||||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
||||||
@@ -1518,6 +1518,7 @@ type CompletionResponse struct {
|
|||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||||
EvalCount int `json:"eval_count"`
|
EvalCount int `json:"eval_count"`
|
||||||
EvalDuration time.Duration `json:"eval_duration"`
|
EvalDuration time.Duration `json:"eval_duration"`
|
||||||
|
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||||
|
|
||||||
// Logprobs contains log probability information if requested
|
// Logprobs contains log probability information if requested
|
||||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||||
@@ -1848,17 +1849,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) VRAMSize() uint64 {
|
func (s *llmServer) MemorySize() (total, vram uint64) {
|
||||||
if s.mem == nil {
|
if s.mem == nil {
|
||||||
return 0
|
return 0, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
var mem uint64
|
|
||||||
|
|
||||||
for _, g := range s.mem.GPUs {
|
for _, g := range s.mem.GPUs {
|
||||||
mem += g.Size()
|
vram += g.Size()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
|
||||||
|
|
||||||
// Some elements are always on CPU. However, if we have allocated all layers
|
// Some elements are always on CPU. However, if we have allocated all layers
|
||||||
// on the GPU then include the CPU components as well, to represent complete offloading.
|
// on the GPU then include the CPU components as well, to represent complete offloading.
|
||||||
noCPULayers := true
|
noCPULayers := true
|
||||||
@@ -1869,25 +1870,11 @@ func (s *llmServer) VRAMSize() uint64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if noCPULayers {
|
if noCPULayers {
|
||||||
mem += s.mem.InputWeights
|
vram += s.mem.InputWeights
|
||||||
mem += s.mem.CPU.Graph
|
vram += s.mem.CPU.Graph
|
||||||
}
|
}
|
||||||
|
|
||||||
return mem
|
return total, vram
|
||||||
}
|
|
||||||
|
|
||||||
func (s *llmServer) TotalSize() uint64 {
|
|
||||||
if s.mem == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
mem := s.mem.InputWeights
|
|
||||||
mem += s.mem.CPU.Size()
|
|
||||||
for _, g := range s.mem.GPUs {
|
|
||||||
mem += g.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
return mem
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type GatedDeltaNet struct {
|
|||||||
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||||
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
|
||||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||||
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
|||||||
default:
|
default:
|
||||||
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
||||||
}
|
}
|
||||||
|
if gdn.SSMDT == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMA == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
|
||||||
|
}
|
||||||
|
|
||||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||||
|
|||||||
@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
return m.Output.Forward(ctx, hiddenStates), nil
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) Validate() error {
|
||||||
|
if m.Options == nil {
|
||||||
|
return fmt.Errorf("qwen3next: missing model options")
|
||||||
|
}
|
||||||
|
if len(m.Layers) != len(m.Options.isRecurrent) {
|
||||||
|
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
if !m.Options.isRecurrent[i] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
gdn, ok := layer.Operator.(*GatedDeltaNet)
|
||||||
|
if !ok || gdn == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMDT == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMA == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
m.positionCache = nil
|
m.positionCache = nil
|
||||||
if len(m.mropeSections) > 0 {
|
if len(m.mropeSections) > 0 {
|
||||||
@@ -450,6 +490,64 @@ var (
|
|||||||
_ model.MultimodalProcessor = (*Model)(nil)
|
_ model.MultimodalProcessor = (*Model)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func defaultVHeadReordered(arch string) bool {
|
||||||
|
return arch == "qwen35" || arch == "qwen35moe"
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
|
||||||
|
isRecurrent := make([]bool, numLayers)
|
||||||
|
|
||||||
|
hasZero := false
|
||||||
|
hasFull := false
|
||||||
|
for i := range numLayers {
|
||||||
|
if i >= len(headCountKV) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if headCountKV[i] == 0 {
|
||||||
|
isRecurrent[i] = true
|
||||||
|
hasZero = true
|
||||||
|
} else {
|
||||||
|
hasFull = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasZero && hasFull {
|
||||||
|
return isRecurrent, nil
|
||||||
|
}
|
||||||
|
if !hasFull {
|
||||||
|
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compatibility path: older imports store a scalar KV head count and omit
|
||||||
|
// per-layer recurrent flags. Derive the hybrid layout from the interval.
|
||||||
|
interval := int(fullAttentionInterval)
|
||||||
|
if interval == 0 {
|
||||||
|
interval = min(4, numLayers)
|
||||||
|
}
|
||||||
|
if interval <= 0 {
|
||||||
|
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
|
||||||
|
}
|
||||||
|
if interval > numLayers {
|
||||||
|
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasZero = false
|
||||||
|
hasFull = false
|
||||||
|
for i := range numLayers {
|
||||||
|
isRecurrent[i] = (i+1)%interval != 0
|
||||||
|
if isRecurrent[i] {
|
||||||
|
hasZero = true
|
||||||
|
} else {
|
||||||
|
hasFull = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasZero || !hasFull {
|
||||||
|
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
|
||||||
|
}
|
||||||
|
|
||||||
|
return isRecurrent, nil
|
||||||
|
}
|
||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
numLayers := int(c.Uint("block_count"))
|
numLayers := int(c.Uint("block_count"))
|
||||||
layers := make([]Layer, numLayers)
|
layers := make([]Layer, numLayers)
|
||||||
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
HeadCountKV() []uint64
|
HeadCountKV() []uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
var isRecurrent []bool
|
|
||||||
var headCountKV []uint64
|
var headCountKV []uint64
|
||||||
if hc, ok := c.(headCounts); ok {
|
if hc, ok := c.(headCounts); ok {
|
||||||
headCountKV = hc.HeadCountKV()
|
headCountKV = hc.HeadCountKV()
|
||||||
}
|
}
|
||||||
|
|
||||||
isRecurrent = make([]bool, numLayers)
|
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
|
||||||
hasZero := false
|
if err != nil {
|
||||||
hasFull := false
|
return nil, err
|
||||||
for i := range numLayers {
|
|
||||||
// If KV head count is 0, it's a recurrent layer
|
|
||||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
|
||||||
isRecurrent[i] = true
|
|
||||||
hasZero = true
|
|
||||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
|
||||||
hasFull = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !hasZero || !hasFull {
|
|
||||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if MoE
|
// Determine if MoE
|
||||||
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||||
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
|
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
|
||||||
isRecurrent: isRecurrent,
|
isRecurrent: isRecurrent,
|
||||||
mropeSections: slices.Collect(func(yield func(int) bool) {
|
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||||
for _, section := range mropeSections {
|
for _, section := range mropeSections {
|
||||||
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
||||||
}
|
}
|
||||||
if opts.numKVHeads == 0 {
|
if opts.numKVHeads == 0 {
|
||||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate cache dimensions
|
// Calculate cache dimensions
|
||||||
|
|||||||
65
model/models/qwen3next/model_new_test.go
Normal file
65
model/models/qwen3next/model_new_test.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package qwen3next
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, false, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, true, true, false, true, true, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, true, false, true, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
|
||||||
|
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("inferRecurrentLayers() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
|
||||||
|
t.Fatalf("unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultVHeadReordered(t *testing.T) {
|
||||||
|
if !defaultVHeadReordered("qwen35") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
|
||||||
|
}
|
||||||
|
if !defaultVHeadReordered("qwen35moe") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
|
||||||
|
}
|
||||||
|
if defaultVHeadReordered("qwen3next") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
45
model/models/qwen3next/model_validate_test.go
Normal file
45
model/models/qwen3next/model_validate_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package qwen3next
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Layers: []Layer{{
|
||||||
|
Operator: &GatedDeltaNet{
|
||||||
|
SSMQKV: &nn.Linear{},
|
||||||
|
SSMQKVGate: &nn.Linear{},
|
||||||
|
SSMBeta: &nn.Linear{},
|
||||||
|
SSMAlpha: &nn.Linear{},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
Options: &Options{
|
||||||
|
isRecurrent: []bool{true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Validate() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "missing ssm_dt") {
|
||||||
|
t.Fatalf("unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Layers: []Layer{{Operator: &FullAttention{}}},
|
||||||
|
Options: &Options{
|
||||||
|
isRecurrent: []bool{false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Validate(); err != nil {
|
||||||
|
t.Fatalf("Validate() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -35,6 +35,7 @@ type GLM46Parser struct {
|
|||||||
state glm46ParserState
|
state glm46ParserState
|
||||||
buffer strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GLM46Parser) HasToolSupport() bool {
|
func (p *GLM46Parser) HasToolSupport() bool {
|
||||||
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
|
|||||||
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
|
|||||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case glm46EventThinkingContent:
|
case glm46EventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ type GLM47Parser struct {
|
|||||||
|
|
||||||
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||||
// so model output starts directly with thinking content (no opening tag).
|
// so model output starts directly with thinking content (no opening tag).
|
||||||
if thinkValue == nil || thinkValue.Bool() {
|
if thinkValue == nil || thinkValue.Bool() {
|
||||||
|
|||||||
@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
|
|||||||
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
input := `plan</think>
|
||||||
|
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
|
||||||
|
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
|
||||||
|
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ type Qwen3Parser struct {
|
|||||||
state qwen3ParserState
|
state qwen3ParserState
|
||||||
buffer strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
hasThinkingSupport bool
|
hasThinkingSupport bool
|
||||||
defaultThinking bool
|
defaultThinking bool
|
||||||
maybeThinkingOpenAtBOL bool
|
maybeThinkingOpenAtBOL bool
|
||||||
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
|
|||||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
p.buffer.Reset()
|
p.buffer.Reset()
|
||||||
|
p.callIndex = 0
|
||||||
|
|
||||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||||
if thinkValue == nil {
|
if thinkValue == nil {
|
||||||
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
|
|||||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
calls = append(calls, toolCall)
|
calls = append(calls, toolCall)
|
||||||
case qwen3EventThinkingContent:
|
case qwen3EventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
|
|||||||
@@ -230,3 +230,89 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
|||||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
|
||||||
|
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
|
||||||
|
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ type Qwen3CoderParser struct {
|
|||||||
state qwenParserState
|
state qwenParserState
|
||||||
acc strings.Builder
|
acc strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||||
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
|||||||
|
|
||||||
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
return tools // Qwen doesn't modify tools
|
return tools // Qwen doesn't modify tools
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
|
|||||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case qwenEventContent:
|
case qwenEventContent:
|
||||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||||
|
|||||||
@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
|
||||||
|
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
|
||||||
|
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQwenXMLTransform(t *testing.T) {
|
func TestQwenXMLTransform(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
desc string
|
desc string
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ type Model struct {
|
|||||||
Template *template.Template
|
Template *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) IsMLX() bool {
|
||||||
|
return m.Config.ModelFormat == "safetensors"
|
||||||
|
}
|
||||||
|
|
||||||
// Capabilities returns the capabilities that the model supports
|
// Capabilities returns the capabilities that the model supports
|
||||||
func (m *Model) Capabilities() []model.Capability {
|
func (m *Model) Capabilities() []model.Capability {
|
||||||
capabilities := []model.Capability{}
|
capabilities := []model.Capability{}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
lastMsgIdx := len(msgs) - 1
|
lastMsgIdx := len(msgs) - 1
|
||||||
currMsgIdx := 0
|
currMsgIdx := 0
|
||||||
|
|
||||||
|
if truncate {
|
||||||
// Start with all messages and remove from the front until it fits in context
|
// Start with all messages and remove from the front until it fits in context
|
||||||
for i := 0; i <= lastMsgIdx; i++ {
|
for i := 0; i <= lastMsgIdx; i++ {
|
||||||
// Collect system messages from the portion we're about to skip
|
// Collect system messages from the portion we're about to skip
|
||||||
@@ -57,7 +58,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !truncate || ctxLen <= opts.NumCtx {
|
if ctxLen <= opts.NumCtx {
|
||||||
currMsgIdx = i
|
currMsgIdx = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -68,6 +69,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if currMsgIdx > 0 {
|
if currMsgIdx > 0 {
|
||||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
||||||
|
|||||||
@@ -21,33 +21,76 @@ type quantizer struct {
|
|||||||
progressFn func(n uint64)
|
progressFn func(n uint64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const quantizationChunkElements uint64 = 4 * 1024 * 1024
|
||||||
|
|
||||||
func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
||||||
quantize := q.from.Kind != q.to.Kind
|
quantize := q.from.Kind != q.to.Kind
|
||||||
sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size()))
|
sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size()))
|
||||||
if !quantize {
|
if !quantize {
|
||||||
n, err := io.Copy(w, sr)
|
n, err := io.Copy(w, sr)
|
||||||
|
if q.progressFn != nil {
|
||||||
q.progressFn(q.from.Size())
|
q.progressFn(q.from.Size())
|
||||||
|
}
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(sr)
|
|
||||||
if err != nil {
|
if len(q.from.Shape) == 0 || q.from.Shape[0] == 0 {
|
||||||
|
return 0, fmt.Errorf("tensor %s has invalid shape %v", q.from.Name, q.from.Shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
fromType := fsggml.TensorType(q.from.Kind)
|
||||||
|
toType := fsggml.TensorType(q.to.Kind)
|
||||||
|
nPerRow := q.from.Shape[0]
|
||||||
|
totalElements := q.from.Elements()
|
||||||
|
if totalElements%nPerRow != 0 {
|
||||||
|
return 0, fmt.Errorf("tensor %s has non-row-aligned shape %v", q.from.Name, q.from.Shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
inRowSize := fromType.RowSize(nPerRow)
|
||||||
|
if inRowSize == 0 {
|
||||||
|
return 0, fmt.Errorf("tensor %s has unsupported source type %v", q.from.Name, fromType)
|
||||||
|
}
|
||||||
|
|
||||||
|
totalRows := totalElements / nPerRow
|
||||||
|
rowsPerChunk := max(quantizationChunkElements/nPerRow, uint64(1))
|
||||||
|
chunkBuf := make([]byte, inRowSize*rowsPerChunk)
|
||||||
|
var written int64
|
||||||
|
|
||||||
|
for row := uint64(0); row < totalRows; {
|
||||||
|
chunkRows := min(rowsPerChunk, totalRows-row)
|
||||||
|
chunkBytes := inRowSize * chunkRows
|
||||||
|
data := chunkBuf[:chunkBytes]
|
||||||
|
|
||||||
|
if _, err := io.ReadFull(sr, data); err != nil {
|
||||||
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
||||||
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
|
return written, fmt.Errorf("unable to read tensor %s from %s: %w", q.from.Name, q.Name(), err)
|
||||||
}
|
|
||||||
if uint64(len(data)) < q.from.Size() {
|
|
||||||
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var f32s []float32
|
var f32s []float32
|
||||||
newType := fsggml.TensorType(q.to.Kind)
|
chunkElements := chunkRows * nPerRow
|
||||||
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
|
if fromType == fsggml.TensorTypeF32 {
|
||||||
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), q.from.Elements())
|
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), chunkElements)
|
||||||
} else {
|
} else {
|
||||||
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
|
f32s = ggml.ConvertToF32(data, q.from.Kind, chunkElements)
|
||||||
}
|
}
|
||||||
data = ggml.Quantize(newType, f32s, q.from.Shape)
|
|
||||||
n, err := w.Write(data)
|
quantized := ggml.Quantize(toType, f32s, []uint64{nPerRow, chunkRows})
|
||||||
q.progressFn(q.from.Size())
|
n, err := w.Write(quantized)
|
||||||
return int64(n), err
|
written += int64(n)
|
||||||
|
if err != nil {
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
|
if n != len(quantized) {
|
||||||
|
return written, io.ErrShortWrite
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.progressFn != nil {
|
||||||
|
q.progressFn(chunkBytes)
|
||||||
|
}
|
||||||
|
row += chunkRows
|
||||||
|
}
|
||||||
|
|
||||||
|
return written, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type quantizeState struct {
|
type quantizeState struct {
|
||||||
|
|||||||
@@ -484,7 +484,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
// the real chat handler, but doing this as a stopgap to get renderer
|
// the real chat handler, but doing this as a stopgap to get renderer
|
||||||
// support for generate
|
// support for generate
|
||||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
|
||||||
|
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -557,6 +558,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
PromptEvalDuration: cr.PromptEvalDuration,
|
PromptEvalDuration: cr.PromptEvalDuration,
|
||||||
EvalCount: cr.EvalCount,
|
EvalCount: cr.EvalCount,
|
||||||
EvalDuration: cr.EvalDuration,
|
EvalDuration: cr.EvalDuration,
|
||||||
|
PeakMemory: cr.PeakMemory,
|
||||||
},
|
},
|
||||||
Logprobs: toAPILogprobs(cr.Logprobs),
|
Logprobs: toAPILogprobs(cr.Logprobs),
|
||||||
}
|
}
|
||||||
@@ -1951,6 +1953,9 @@ func (s *Server) PsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if v.llama != nil {
|
if v.llama != nil {
|
||||||
mr.ContextLength = v.llama.ContextLength()
|
mr.ContextLength = v.llama.ContextLength()
|
||||||
|
total, vram := v.llama.MemorySize()
|
||||||
|
mr.Size = int64(total)
|
||||||
|
mr.SizeVRAM = int64(vram)
|
||||||
}
|
}
|
||||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||||
// possible that it will be set to the unix epoch. For those cases, just
|
// possible that it will be set to the unix epoch. For those cases, just
|
||||||
@@ -2213,6 +2218,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
truncate := req.Truncate == nil || *req.Truncate
|
truncate := req.Truncate == nil || *req.Truncate
|
||||||
|
if m.IsMLX() {
|
||||||
|
truncate = false
|
||||||
|
}
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("chat prompt error", "error", err)
|
slog.Error("chat prompt error", "error", err)
|
||||||
@@ -2309,6 +2317,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
PromptEvalDuration: r.PromptEvalDuration,
|
PromptEvalDuration: r.PromptEvalDuration,
|
||||||
EvalCount: r.EvalCount,
|
EvalCount: r.EvalCount,
|
||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: r.EvalDuration,
|
||||||
|
PeakMemory: r.PeakMemory,
|
||||||
},
|
},
|
||||||
Logprobs: toAPILogprobs(r.Logprobs),
|
Logprobs: toAPILogprobs(r.Logprobs),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for experimental safetensors LLM models
|
// Check for experimental safetensors LLM models
|
||||||
if pending.model.Config.ModelFormat == "safetensors" {
|
if pending.model.IsMLX() {
|
||||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||||
// LLM model with safetensors format - use MLX runner
|
// LLM model with safetensors format - use MLX runner
|
||||||
if s.loadMLX(pending) {
|
if s.loadMLX(pending) {
|
||||||
@@ -536,6 +536,7 @@ iGPUScan:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalSize, vramSize := llama.MemorySize()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
@@ -545,8 +546,8 @@ iGPUScan:
|
|||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
gpus: gpuIDs,
|
gpus: gpuIDs,
|
||||||
discreteGPUs: discreteGPUs,
|
discreteGPUs: discreteGPUs,
|
||||||
vramSize: llama.VRAMSize(),
|
totalSize: totalSize,
|
||||||
totalSize: llama.TotalSize(),
|
vramSize: vramSize,
|
||||||
loading: true,
|
loading: true,
|
||||||
pid: llama.Pid(),
|
pid: llama.Pid(),
|
||||||
}
|
}
|
||||||
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
sessionDuration = req.sessionDuration.Duration
|
sessionDuration = req.sessionDuration.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalSize, vramSize := server.MemorySize()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
loading: false,
|
loading: false,
|
||||||
isImagegen: isImagegen,
|
isImagegen: isImagegen,
|
||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
totalSize: server.TotalSize(),
|
totalSize: totalSize,
|
||||||
vramSize: server.VRAMSize(),
|
vramSize: vramSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
@@ -762,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||||
runner.llama.Ping(ctx) != nil {
|
runner.llama.Ping(ctx) != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
|
|||||||
s.closeCalled = true
|
s.closeCalled = true
|
||||||
return s.closeResp
|
return s.closeResp
|
||||||
}
|
}
|
||||||
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
|
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
|
||||||
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
|
|
||||||
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
||||||
func (s *mockLlm) Pid() int { return -1 }
|
func (s *mockLlm) Pid() int { return -1 }
|
||||||
func (s *mockLlm) GetPort() int { return -1 }
|
func (s *mockLlm) GetPort() int { return -1 }
|
||||||
|
|||||||
@@ -374,14 +374,9 @@ func (s *Server) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMSize returns the estimated VRAM usage.
|
// MemorySize returns the total and VRAM memory usage.
|
||||||
func (s *Server) VRAMSize() uint64 {
|
func (s *Server) MemorySize() (total, vram uint64) {
|
||||||
return s.vramSize
|
return s.vramSize, s.vramSize
|
||||||
}
|
|
||||||
|
|
||||||
// TotalSize returns the total memory usage.
|
|
||||||
func (s *Server) TotalSize() uint64 {
|
|
||||||
return s.vramSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMByGPU returns VRAM usage for a specific GPU.
|
// VRAMByGPU returns VRAM usage for a specific GPU.
|
||||||
|
|||||||
@@ -78,6 +78,12 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
|||||||
prefix++
|
prefix++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always keep at least one token to re-evaluate so the
|
||||||
|
// pipeline can seed token generation from it.
|
||||||
|
if prefix == len(tokens) && prefix > 0 {
|
||||||
|
prefix--
|
||||||
|
}
|
||||||
|
|
||||||
if prefix < len(c.tokens) {
|
if prefix < len(c.tokens) {
|
||||||
trim := len(c.tokens) - prefix
|
trim := len(c.tokens) - prefix
|
||||||
for _, kv := range c.caches {
|
for _, kv := range c.caches {
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -19,19 +18,21 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/x/imagegen"
|
||||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
port int
|
port int
|
||||||
modelName string
|
modelName string
|
||||||
vramSize uint64
|
contextLength atomic.Int64
|
||||||
|
memory atomic.Uint64
|
||||||
done chan error
|
done chan error
|
||||||
client *http.Client
|
client *http.Client
|
||||||
lastErr string
|
lastErr string
|
||||||
@@ -98,18 +99,9 @@ func NewClient(modelName string) (*Client, error) {
|
|||||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Estimate VRAM based on tensor size from manifest
|
|
||||||
var vramSize uint64
|
|
||||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
|
||||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
|
||||||
} else {
|
|
||||||
vramSize = 8 * 1024 * 1024 * 1024
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &Client{
|
c := &Client{
|
||||||
port: port,
|
port: port,
|
||||||
modelName: modelName,
|
modelName: modelName,
|
||||||
vramSize: vramSize,
|
|
||||||
done: make(chan error, 1),
|
done: make(chan error, 1),
|
||||||
client: &http.Client{Timeout: 10 * time.Minute},
|
client: &http.Client{Timeout: 10 * time.Minute},
|
||||||
cmd: cmd,
|
cmd: cmd,
|
||||||
@@ -201,6 +193,20 @@ type completionOpts struct {
|
|||||||
NumPredict int `json:"num_predict,omitempty"`
|
NumPredict int `json:"num_predict,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CompletionResponse struct {
|
||||||
|
Content string
|
||||||
|
Done bool
|
||||||
|
DoneReason int
|
||||||
|
|
||||||
|
PromptEvalCount int
|
||||||
|
PromptEvalDuration time.Duration
|
||||||
|
EvalCount int
|
||||||
|
EvalDuration time.Duration
|
||||||
|
PeakMemory uint64
|
||||||
|
|
||||||
|
Error *api.StatusError
|
||||||
|
}
|
||||||
|
|
||||||
// Close terminates the subprocess.
|
// Close terminates the subprocess.
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@@ -260,28 +266,25 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
var raw struct {
|
var raw CompletionResponse
|
||||||
Content string `json:"content,omitempty"`
|
|
||||||
Done bool `json:"done"`
|
|
||||||
DoneReason int `json:"done_reason,omitempty"`
|
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
|
||||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
|
||||||
EvalCount int `json:"eval_count,omitempty"`
|
|
||||||
EvalDuration int `json:"eval_duration,omitempty"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if raw.Error != nil {
|
||||||
|
return *raw.Error
|
||||||
|
}
|
||||||
|
|
||||||
cresp := llm.CompletionResponse{
|
cresp := llm.CompletionResponse{
|
||||||
Content: raw.Content,
|
Content: raw.Content,
|
||||||
Done: raw.Done,
|
Done: raw.Done,
|
||||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||||
PromptEvalCount: raw.PromptEvalCount,
|
PromptEvalCount: raw.PromptEvalCount,
|
||||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
PromptEvalDuration: raw.PromptEvalDuration,
|
||||||
EvalCount: raw.EvalCount,
|
EvalCount: raw.EvalCount,
|
||||||
EvalDuration: time.Duration(raw.EvalDuration),
|
EvalDuration: raw.EvalDuration,
|
||||||
|
PeakMemory: raw.PeakMemory,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(cresp)
|
fn(cresp)
|
||||||
@@ -294,7 +297,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ContextLength() int {
|
func (c *Client) ContextLength() int {
|
||||||
return math.MaxInt
|
return int(c.contextLength.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detokenize implements llm.LlamaServer.
|
// Detokenize implements llm.LlamaServer.
|
||||||
@@ -347,9 +350,16 @@ func (c *Client) Pid() int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type statusResponse struct {
|
||||||
|
Status int
|
||||||
|
Progress int
|
||||||
|
ContextLength int
|
||||||
|
Memory uint64
|
||||||
|
}
|
||||||
|
|
||||||
// Ping implements llm.LlamaServer.
|
// Ping implements llm.LlamaServer.
|
||||||
func (c *Client) Ping(ctx context.Context) error {
|
func (c *Client) Ping(ctx context.Context) error {
|
||||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
|
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -362,6 +372,15 @@ func (c *Client) Ping(ctx context.Context) error {
|
|||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var status statusResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.contextLength.Store(int64(status.ContextLength))
|
||||||
|
c.memory.Store(status.Memory)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,19 +407,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
|||||||
return tokens, nil
|
return tokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TotalSize implements llm.LlamaServer.
|
func (c *Client) currentMemory() uint64 {
|
||||||
func (c *Client) TotalSize() uint64 {
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
return c.vramSize
|
defer cancel()
|
||||||
|
if err := c.Ping(ctx); err != nil {
|
||||||
|
slog.Warn("failed to get current memory", "error", err)
|
||||||
|
}
|
||||||
|
return c.memory.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemorySize implements llm.LlamaServer.
|
||||||
|
func (c *Client) MemorySize() (total, vram uint64) {
|
||||||
|
mem := c.currentMemory()
|
||||||
|
return mem, mem
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMByGPU implements llm.LlamaServer.
|
// VRAMByGPU implements llm.LlamaServer.
|
||||||
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||||
return c.vramSize
|
return c.currentMemory()
|
||||||
}
|
|
||||||
|
|
||||||
// VRAMSize implements llm.LlamaServer.
|
|
||||||
func (c *Client) VRAMSize() uint64 {
|
|
||||||
return c.vramSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WaitUntilRunning implements llm.LlamaServer.
|
// WaitUntilRunning implements llm.LlamaServer.
|
||||||
|
|||||||
@@ -64,6 +64,10 @@ func PeakMemory() int {
|
|||||||
return int(peak)
|
return int(peak)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ResetPeakMemory() {
|
||||||
|
C.mlx_reset_peak_memory()
|
||||||
|
}
|
||||||
|
|
||||||
type Memory struct{}
|
type Memory struct{}
|
||||||
|
|
||||||
func (Memory) LogValue() slog.Value {
|
func (Memory) LogValue() slog.Value {
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type Model interface {
|
|||||||
Unembed(x *mlx.Array) *mlx.Array
|
Unembed(x *mlx.Array) *mlx.Array
|
||||||
NumLayers() int
|
NumLayers() int
|
||||||
Tokenizer() *tokenizer.Tokenizer
|
Tokenizer() *tokenizer.Tokenizer
|
||||||
|
MaxContextLength() int
|
||||||
|
|
||||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||||
|
|||||||
@@ -6,9 +6,12 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
)
|
)
|
||||||
@@ -44,16 +47,35 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
} else {
|
} else {
|
||||||
mlx.DisableCompile()
|
mlx.DisableCompile()
|
||||||
}
|
}
|
||||||
|
mlx.ResetPeakMemory()
|
||||||
|
|
||||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||||
|
if len(inputs) == 0 {
|
||||||
|
return errors.New("empty prompt")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inputs) >= r.contextLength {
|
||||||
|
return api.StatusError{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap generation to stay within the model's context length
|
||||||
|
maxGenerate := r.contextLength - len(inputs)
|
||||||
|
if request.Options.MaxTokens <= 0 {
|
||||||
|
request.Options.MaxTokens = maxGenerate
|
||||||
|
} else {
|
||||||
|
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
||||||
|
}
|
||||||
|
|
||||||
session := r.cache.begin(r.Model, inputs)
|
session := r.cache.begin(r.Model, inputs)
|
||||||
defer session.close()
|
defer session.close()
|
||||||
|
|
||||||
caches := session.caches
|
caches := session.caches
|
||||||
tokens := session.remaining
|
tokens := session.remaining
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
total, processed := len(tokens), 0
|
total, processed := len(tokens), 0
|
||||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
|
||||||
for total-processed > 1 {
|
for total-processed > 1 {
|
||||||
if err := request.Ctx.Err(); err != nil {
|
if err := request.Ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -93,8 +115,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
|
|
||||||
now := time.Now()
|
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
|
||||||
for i := range request.Options.MaxTokens {
|
for i := range request.Options.MaxTokens {
|
||||||
if err := request.Ctx.Err(); err != nil {
|
if err := request.Ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -103,9 +124,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
nextSample, nextLogprobs = step(sample)
|
nextSample, nextLogprobs = step(sample)
|
||||||
|
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
|
||||||
mlx.Eval(sample)
|
mlx.Eval(sample)
|
||||||
final.PromptTokensDuration = time.Since(now)
|
final.PromptEvalDuration = time.Since(now)
|
||||||
now = time.Now()
|
now = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,18 +133,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
session.outputs = append(session.outputs, output)
|
session.outputs = append(session.outputs, output)
|
||||||
|
|
||||||
if r.Tokenizer.IsEOS(output) {
|
if r.Tokenizer.IsEOS(output) {
|
||||||
final.Token = int(output)
|
|
||||||
final.DoneReason = 0
|
final.DoneReason = 0
|
||||||
final.CompletionTokens = i
|
final.EvalCount = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-request.Ctx.Done():
|
case <-request.Ctx.Done():
|
||||||
return request.Ctx.Err()
|
return request.Ctx.Err()
|
||||||
case request.Responses <- Response{
|
case request.Responses <- CompletionResponse{
|
||||||
Text: r.Decode(output, &b),
|
Content: r.Decode(output, &b),
|
||||||
Token: int(output),
|
|
||||||
}:
|
}:
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,7 +155,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
final.CompletionTokensDuration = time.Since(now)
|
final.EvalDuration = time.Since(now)
|
||||||
|
final.PeakMemory = uint64(mlx.PeakMemory())
|
||||||
select {
|
select {
|
||||||
case <-request.Ctx.Done():
|
case <-request.Ctx.Done():
|
||||||
return request.Ctx.Err()
|
return request.Ctx.Err()
|
||||||
|
|||||||
@@ -4,14 +4,15 @@ package mlxrunner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
@@ -21,7 +22,7 @@ import (
|
|||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
TextCompletionsRequest
|
TextCompletionsRequest
|
||||||
Responses chan Response
|
Responses chan CompletionResponse
|
||||||
Pipeline func(Request) error
|
Pipeline func(Request) error
|
||||||
|
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
@@ -43,25 +44,12 @@ type TextCompletionsRequest struct {
|
|||||||
} `json:"options"`
|
} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
Text string `json:"content,omitempty"`
|
|
||||||
Token int `json:"token,omitempty"`
|
|
||||||
Logprobs []float32 `json:"logprobs,omitempty"`
|
|
||||||
Done bool `json:"done,omitempty"`
|
|
||||||
DoneReason int `json:"done_reason,omitempty"`
|
|
||||||
|
|
||||||
PromptTokens int `json:"prompt_eval_count,omitempty"`
|
|
||||||
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
|
||||||
CompletionTokens int `json:"eval_count,omitempty"`
|
|
||||||
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
|
|
||||||
TotalTokens int `json:"total_tokens,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Runner struct {
|
type Runner struct {
|
||||||
Model base.Model
|
Model base.Model
|
||||||
Tokenizer *tokenizer.Tokenizer
|
Tokenizer *tokenizer.Tokenizer
|
||||||
Requests chan Request
|
Requests chan Request
|
||||||
cache kvCache
|
cache kvCache
|
||||||
|
contextLength int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Runner) Load(modelName string) error {
|
func (r *Runner) Load(modelName string) error {
|
||||||
@@ -90,6 +78,7 @@ func (r *Runner) Load(modelName string) error {
|
|||||||
|
|
||||||
r.Model = m
|
r.Model = m
|
||||||
r.Tokenizer = m.Tokenizer()
|
r.Tokenizer = m.Tokenizer()
|
||||||
|
r.contextLength = m.MaxContextLength()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,6 +147,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
|||||||
case request := <-r.Requests:
|
case request := <-r.Requests:
|
||||||
if err := request.Pipeline(request); err != nil {
|
if err := request.Pipeline(request); err != nil {
|
||||||
slog.Info("Request terminated", "error", err)
|
slog.Info("Request terminated", "error", err)
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if !errors.As(err, &statusErr) {
|
||||||
|
statusErr = api.StatusError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
ErrorMessage: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||||
|
case <-request.Ctx.Done():
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
close(request.Responses)
|
close(request.Responses)
|
||||||
|
|||||||
@@ -50,9 +50,11 @@ func Execute(args []string) error {
|
|||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
if err := json.NewEncoder(w).Encode(statusResponse{
|
||||||
"status": 0,
|
Status: 0,
|
||||||
"progress": 100,
|
Progress: 100,
|
||||||
|
ContextLength: runner.contextLength,
|
||||||
|
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
slog.Error("Failed to encode response", "error", err)
|
slog.Error("Failed to encode response", "error", err)
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
@@ -78,7 +80,7 @@ func Execute(args []string) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
request := Request{Responses: make(chan Response)}
|
request := Request{Responses: make(chan CompletionResponse)}
|
||||||
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||||
slog.Error("Failed to decode request", "error", err)
|
slog.Error("Failed to decode request", "error", err)
|
||||||
@@ -87,9 +89,6 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
||||||
if request.Options.MaxTokens < 1 {
|
|
||||||
request.Options.MaxTokens = 16 << 10
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Pipeline = runner.TextGenerationPipeline
|
request.Pipeline = runner.TextGenerationPipeline
|
||||||
request.Sampler = sample.New(
|
request.Sampler = sample.New(
|
||||||
|
|||||||
@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
|||||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||||
|
|
||||||
// MaxContextLength returns the maximum context length
|
// MaxContextLength returns the maximum context length
|
||||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
|
||||||
|
|
||||||
// VocabSize returns the vocabulary size
|
// VocabSize returns the vocabulary size
|
||||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||||
|
|||||||
@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user