mirror of
https://github.com/ollama/ollama.git
synced 2026-04-27 19:25:55 +02:00
Compare commits
21 Commits
v0.17.1-rc
...
pdevine/sa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67ce53b9b5 | ||
|
|
dd497534c4 | ||
|
|
560626fb43 | ||
|
|
1a23c1a810 | ||
|
|
a6c1aa4da5 | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 | ||
|
|
d98dda4676 | ||
|
|
d69ddc1edc | ||
|
|
9bf41969f0 | ||
|
|
0f23b7bff5 | ||
|
|
4e57d2094e | ||
|
|
7f9efd53df | ||
|
|
da70c3222e | ||
|
|
9d902d63ce |
14
api/types.go
14
api/types.go
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/internal/orderedmap"
|
"github.com/ollama/ollama/internal/orderedmap"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
@@ -569,6 +570,7 @@ type DebugInfo struct {
|
|||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||||
|
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||||
@@ -934,6 +936,10 @@ func (m *Metrics) Summary() {
|
|||||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.PeakMemory > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
|
||||||
|
}
|
||||||
|
|
||||||
if m.LoadDuration > 0 {
|
if m.LoadDuration > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
||||||
}
|
}
|
||||||
@@ -957,6 +963,14 @@ func (m *Metrics) Summary() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatPeakMemory(b uint64) string {
|
||||||
|
if b >= format.GibiByte {
|
||||||
|
return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
|
||||||
|
}
|
||||||
|
|
||||||
|
return format.HumanBytes2(b)
|
||||||
|
}
|
||||||
|
|
||||||
func (opts *Options) FromMap(m map[string]any) error {
|
func (opts *Options) FromMap(m map[string]any) error {
|
||||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
wv = &Webview{}
|
wv = &Webview{}
|
||||||
uiServerPort int
|
uiServerPort int
|
||||||
|
appStore *store.Store
|
||||||
)
|
)
|
||||||
|
|
||||||
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
|
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
|
||||||
@@ -208,6 +209,7 @@ func main() {
|
|||||||
uiServerPort = port
|
uiServerPort = port
|
||||||
|
|
||||||
st := &store.Store{}
|
st := &store.Store{}
|
||||||
|
appStore = st
|
||||||
|
|
||||||
// Enable CORS in development mode
|
// Enable CORS in development mode
|
||||||
if devMode {
|
if devMode {
|
||||||
@@ -294,8 +296,15 @@ func main() {
|
|||||||
|
|
||||||
// Check for pending updates on startup (show tray notification if update is ready)
|
// Check for pending updates on startup (show tray notification if update is ready)
|
||||||
if updater.IsUpdatePending() {
|
if updater.IsUpdatePending() {
|
||||||
slog.Debug("update pending on startup, showing tray notification")
|
// On Windows, the tray is initialized in osRun(). Calling UpdateAvailable
|
||||||
UpdateAvailable("")
|
// before that would dereference a nil tray callback.
|
||||||
|
// TODO: refactor so the update check runs after platform init on all platforms.
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
slog.Debug("update pending on startup, deferring tray notification until tray initialization")
|
||||||
|
} else {
|
||||||
|
slog.Debug("update pending on startup, showing tray notification")
|
||||||
|
UpdateAvailable("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
||||||
@@ -360,8 +369,7 @@ func startHiddenTasks() {
|
|||||||
slog.Info("deferring pending update for fast startup")
|
slog.Info("deferring pending update for fast startup")
|
||||||
} else {
|
} else {
|
||||||
// Check if auto-update is enabled before automatically upgrading
|
// Check if auto-update is enabled before automatically upgrading
|
||||||
st := &store.Store{}
|
settings, err := appStore.Settings()
|
||||||
settings, err := st.Settings()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to load settings for upgrade check", "error", err)
|
slog.Warn("failed to load settings for upgrade check", "error", err)
|
||||||
} else if !settings.AutoUpdateEnabled {
|
} else if !settings.AutoUpdateEnabled {
|
||||||
|
|||||||
@@ -154,6 +154,10 @@ func handleURLSchemeRequest(urlScheme string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateAvailable(ver string) error {
|
func UpdateAvailable(ver string) error {
|
||||||
|
if app.t == nil {
|
||||||
|
slog.Debug("tray not yet initialized, skipping update notification")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return app.t.UpdateAvailable(ver)
|
return app.t.UpdateAvailable(ver)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +169,14 @@ func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
|
|||||||
log.Fatalf("Failed to start: %s", err)
|
log.Fatalf("Failed to start: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for pending updates now that the tray is initialized.
|
||||||
|
// The platform-independent check in app.go fires before osRun,
|
||||||
|
// when app.t is still nil, so we must re-check here.
|
||||||
|
if updater.IsUpdatePending() {
|
||||||
|
slog.Debug("update pending on startup, showing tray notification")
|
||||||
|
UpdateAvailable("")
|
||||||
|
}
|
||||||
|
|
||||||
signals := make(chan os.Signal, 1)
|
signals := make(chan os.Signal, 1)
|
||||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
|||||||
@@ -289,6 +289,7 @@ func (u *Updater) TriggerImmediateCheck() {
|
|||||||
|
|
||||||
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
||||||
u.checkNow = make(chan struct{}, 1)
|
u.checkNow = make(chan struct{}, 1)
|
||||||
|
u.checkNow <- struct{}{} // Trigger first check after initial delay
|
||||||
go func() {
|
go func() {
|
||||||
// Don't blast an update message immediately after startup
|
// Don't blast an update message immediately after startup
|
||||||
time.Sleep(UpdateCheckInitialDelay)
|
time.Sleep(UpdateCheckInitialDelay)
|
||||||
@@ -333,7 +334,7 @@ func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(str
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Download successful - show tray notification (regardless of toggle state)
|
// Download successful - show tray notification
|
||||||
err = cb(resp.UpdateVersion)
|
err = cb(resp.UpdateVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to register update available with tray", "error", err)
|
slog.Warn("failed to register update available with tray", "error", err)
|
||||||
|
|||||||
@@ -351,10 +351,13 @@ func TestTriggerImmediateCheck(t *testing.T) {
|
|||||||
|
|
||||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||||
|
|
||||||
// Wait for goroutine to start and pass initial delay
|
// Wait for the initial check that fires after the initial delay
|
||||||
time.Sleep(10 * time.Millisecond)
|
select {
|
||||||
|
case <-checkDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("initial check did not happen")
|
||||||
|
}
|
||||||
|
|
||||||
// With 1 hour interval, no check should have happened yet
|
|
||||||
initialCount := checkCount.Load()
|
initialCount := checkCount.Load()
|
||||||
|
|
||||||
// Trigger immediate check
|
// Trigger immediate check
|
||||||
|
|||||||
@@ -320,7 +320,7 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
|||||||
conv = &lfm2Model{}
|
conv = &lfm2Model{}
|
||||||
case "Lfm2VlForConditionalGeneration":
|
case "Lfm2VlForConditionalGeneration":
|
||||||
conv = &lfm2VLTextModel{}
|
conv = &lfm2VLTextModel{}
|
||||||
case "Qwen3NextForCausalLM":
|
case "Qwen3NextForCausalLM", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration":
|
||||||
conv = &qwen3NextModel{}
|
conv = &qwen3NextModel{}
|
||||||
case "NemotronHForCausalLM":
|
case "NemotronHForCausalLM":
|
||||||
conv = &nemotronHModel{}
|
conv = &nemotronHModel{}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package convert
|
package convert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"math"
|
"math"
|
||||||
@@ -13,8 +14,21 @@ import (
|
|||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
)
|
)
|
||||||
|
|
||||||
type qwen3NextModel struct {
|
type qwen3NextRopeScaling struct {
|
||||||
ModelParameters
|
Type string `json:"type"`
|
||||||
|
Factor ropeFactor `json:"factor"`
|
||||||
|
MropeSection []int32 `json:"mrope_section"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen3NextRopeParams struct {
|
||||||
|
MRopeInterleaved bool `json:"mrope_interleaved"`
|
||||||
|
MropeSection []int32 `json:"mrope_section"`
|
||||||
|
RopeType string `json:"rope_type"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen3NextTextConfig struct {
|
||||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||||
HiddenSize uint32 `json:"hidden_size"`
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
@@ -28,12 +42,13 @@ type qwen3NextModel struct {
|
|||||||
// MoE config
|
// MoE config
|
||||||
NumExperts uint32 `json:"num_experts"`
|
NumExperts uint32 `json:"num_experts"`
|
||||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||||
NormTopkProb bool `json:"norm_topk_prob"`
|
NormTopkProb *bool `json:"norm_topk_prob"`
|
||||||
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||||
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
|
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
|
||||||
|
|
||||||
// Hybrid attention config
|
// Hybrid attention config
|
||||||
FullAttentionInterval uint32 `json:"full_attention_interval"`
|
FullAttentionInterval uint32 `json:"full_attention_interval"`
|
||||||
|
LayerTypes []string `json:"layer_types"`
|
||||||
|
|
||||||
// Linear attention (Gated Delta Net) config
|
// Linear attention (Gated Delta Net) config
|
||||||
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
|
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
|
||||||
@@ -43,16 +58,102 @@ type qwen3NextModel struct {
|
|||||||
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
|
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
|
||||||
|
|
||||||
// RoPE config
|
// RoPE config
|
||||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||||
RopeScaling struct {
|
RopeScaling qwen3NextRopeScaling `json:"rope_scaling"`
|
||||||
Type string `json:"type"`
|
RopeParameters qwen3NextRopeParams `json:"rope_parameters"`
|
||||||
Factor ropeFactor `json:"factor"`
|
}
|
||||||
} `json:"rope_scaling"`
|
|
||||||
|
type qwen3NextVisionConfig struct {
|
||||||
|
Depth uint32 `json:"depth"`
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
NumHeads uint32 `json:"num_heads"`
|
||||||
|
InChannels uint32 `json:"in_channels"`
|
||||||
|
PatchSize uint32 `json:"patch_size"`
|
||||||
|
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||||
|
RMSNormEps float32 `json:"layer_norm_epsilon"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||||
|
DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"`
|
||||||
|
|
||||||
|
Size struct {
|
||||||
|
ShortestEdge uint32 `json:"shortest_edge"`
|
||||||
|
LongestEdge uint32 `json:"longest_edge"`
|
||||||
|
} `json:"size"`
|
||||||
|
|
||||||
|
ImageMean []float32 `json:"image_mean"`
|
||||||
|
ImageStd []float32 `json:"image_std"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen3NextModel struct {
|
||||||
|
ModelParameters
|
||||||
|
qwen3NextTextConfig
|
||||||
|
|
||||||
|
TextConfig *qwen3NextTextConfig `json:"text_config"`
|
||||||
|
VisionModel qwen3NextVisionConfig `json:"vision_config"`
|
||||||
|
|
||||||
|
ImageTokenID uint32 `json:"image_token_id"`
|
||||||
|
VisionStartTokenID uint32 `json:"vision_start_token_id"`
|
||||||
|
VisionEndTokenID uint32 `json:"vision_end_token_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ModelConverter = (*qwen3NextModel)(nil)
|
var _ ModelConverter = (*qwen3NextModel)(nil)
|
||||||
|
|
||||||
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
|
func (q *qwen3NextModel) parseMore(fsys fs.FS) error {
|
||||||
|
if q.TextConfig != nil {
|
||||||
|
q.qwen3NextTextConfig = *q.TextConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.RopeTheta == 0 {
|
||||||
|
q.RopeTheta = q.RopeParameters.RopeTheta
|
||||||
|
}
|
||||||
|
if q.PartialRotaryFactor == 0 {
|
||||||
|
q.PartialRotaryFactor = q.RopeParameters.PartialRotaryFactor
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.RopeScaling.Type == "" && q.RopeParameters.RopeType != "" {
|
||||||
|
q.RopeScaling.Type = q.RopeParameters.RopeType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pull vision preprocessing fields when present.
|
||||||
|
if q.VisionModel.Depth > 0 {
|
||||||
|
if bts, err := fs.ReadFile(fsys, "preprocessor_config.json"); err == nil {
|
||||||
|
var pre struct {
|
||||||
|
Size struct {
|
||||||
|
ShortestEdge uint32 `json:"shortest_edge"`
|
||||||
|
LongestEdge uint32 `json:"longest_edge"`
|
||||||
|
} `json:"size"`
|
||||||
|
PatchSize uint32 `json:"patch_size"`
|
||||||
|
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||||
|
MergeSize uint32 `json:"merge_size"`
|
||||||
|
ImageMean []float32 `json:"image_mean"`
|
||||||
|
ImageStd []float32 `json:"image_std"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(bts, &pre) == nil {
|
||||||
|
if q.VisionModel.PatchSize == 0 {
|
||||||
|
q.VisionModel.PatchSize = pre.PatchSize
|
||||||
|
}
|
||||||
|
if q.VisionModel.TemporalPatchSize == 0 {
|
||||||
|
q.VisionModel.TemporalPatchSize = pre.TemporalPatchSize
|
||||||
|
}
|
||||||
|
if q.VisionModel.SpatialMergeSize == 0 {
|
||||||
|
q.VisionModel.SpatialMergeSize = pre.MergeSize
|
||||||
|
}
|
||||||
|
if q.VisionModel.Size.ShortestEdge == 0 {
|
||||||
|
q.VisionModel.Size.ShortestEdge = pre.Size.ShortestEdge
|
||||||
|
}
|
||||||
|
if q.VisionModel.Size.LongestEdge == 0 {
|
||||||
|
q.VisionModel.Size.LongestEdge = pre.Size.LongestEdge
|
||||||
|
}
|
||||||
|
if len(q.VisionModel.ImageMean) == 0 {
|
||||||
|
q.VisionModel.ImageMean = pre.ImageMean
|
||||||
|
}
|
||||||
|
if len(q.VisionModel.ImageStd) == 0 {
|
||||||
|
q.VisionModel.ImageStd = pre.ImageStd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if q.NumHiddenLayers == 0 {
|
if q.NumHiddenLayers == 0 {
|
||||||
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
|
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
|
||||||
}
|
}
|
||||||
@@ -74,36 +175,96 @@ func (q *qwen3NextModel) parseMore(_ fs.FS) error {
|
|||||||
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
|
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
|
||||||
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
|
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
|
||||||
}
|
}
|
||||||
if q.FullAttentionInterval == 0 {
|
if _, err := q.kvHeadCounts(); err != nil {
|
||||||
return fmt.Errorf("qwen3next: full_attention_interval must be set")
|
return err
|
||||||
}
|
|
||||||
if q.FullAttentionInterval > q.NumHiddenLayers {
|
|
||||||
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
|
||||||
}
|
|
||||||
|
|
||||||
hasFull := false
|
|
||||||
for i := range q.NumHiddenLayers {
|
|
||||||
if (i+1)%q.FullAttentionInterval == 0 {
|
|
||||||
hasFull = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !hasFull {
|
|
||||||
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) kvHeadCounts() ([]uint32, error) {
|
||||||
|
if len(q.LayerTypes) > 0 {
|
||||||
|
kv := make([]uint32, q.NumHiddenLayers)
|
||||||
|
hasFull := false
|
||||||
|
hasRecurrent := false
|
||||||
|
for i := range q.NumHiddenLayers {
|
||||||
|
layerType := ""
|
||||||
|
if i < uint32(len(q.LayerTypes)) {
|
||||||
|
layerType = q.LayerTypes[i]
|
||||||
|
}
|
||||||
|
if layerType == "full_attention" {
|
||||||
|
kv[i] = q.NumKeyValueHeads
|
||||||
|
hasFull = true
|
||||||
|
} else {
|
||||||
|
hasRecurrent = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasFull || !hasRecurrent {
|
||||||
|
return nil, fmt.Errorf("qwen3next: layer_types must include both full_attention and linear_attention")
|
||||||
|
}
|
||||||
|
return kv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.FullAttentionInterval == 0 {
|
||||||
|
return nil, fmt.Errorf("qwen3next: full_attention_interval must be set")
|
||||||
|
}
|
||||||
|
if q.FullAttentionInterval > q.NumHiddenLayers {
|
||||||
|
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := make([]uint32, q.NumHiddenLayers)
|
||||||
|
hasFull := false
|
||||||
|
for i := range q.NumHiddenLayers {
|
||||||
|
if (i+1)%q.FullAttentionInterval == 0 {
|
||||||
|
kv[i] = q.NumKeyValueHeads
|
||||||
|
hasFull = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasFull {
|
||||||
|
return nil, fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
return kv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) ropeSections() []int32 {
|
||||||
|
if len(q.RopeParameters.MropeSection) > 0 {
|
||||||
|
return q.RopeParameters.MropeSection
|
||||||
|
}
|
||||||
|
return q.RopeScaling.MropeSection
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) shouldReorderVHeads() bool {
|
||||||
|
modelType := strings.ToLower(q.ModelType)
|
||||||
|
if strings.Contains(modelType, "qwen3_next") || strings.Contains(modelType, "qwen3next") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, arch := range q.Architectures {
|
||||||
|
arch = strings.ToLower(arch)
|
||||||
|
if strings.Contains(arch, "qwen3next") || strings.Contains(arch, "qwen3_next") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to qwen3.5 layout for all other qwen3next-family imports.
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||||
kv := q.ModelParameters.KV(t)
|
kv := q.ModelParameters.KV(t)
|
||||||
kv["general.architecture"] = "qwen3next"
|
|
||||||
kv["tokenizer.ggml.pre"] = "qwen2"
|
arch := "qwen35"
|
||||||
|
if q.NumExperts > 0 {
|
||||||
|
arch = "qwen35moe"
|
||||||
|
}
|
||||||
|
kv["general.architecture"] = arch
|
||||||
|
kv["tokenizer.ggml.pre"] = "qwen35"
|
||||||
kv["block_count"] = q.NumHiddenLayers
|
kv["block_count"] = q.NumHiddenLayers
|
||||||
kv["context_length"] = q.MaxPositionEmbeddings
|
kv["context_length"] = q.MaxPositionEmbeddings
|
||||||
kv["embedding_length"] = q.HiddenSize
|
kv["embedding_length"] = q.HiddenSize
|
||||||
kv["feed_forward_length"] = q.IntermediateSize
|
kv["feed_forward_length"] = q.IntermediateSize
|
||||||
kv["attention.head_count"] = q.NumAttentionHeads
|
kv["attention.head_count"] = q.NumAttentionHeads
|
||||||
|
|
||||||
headDim := q.HeadDim
|
headDim := q.HeadDim
|
||||||
if headDim == 0 && q.NumAttentionHeads > 0 {
|
if headDim == 0 && q.NumAttentionHeads > 0 {
|
||||||
headDim = q.HiddenSize / q.NumAttentionHeads
|
headDim = q.HiddenSize / q.NumAttentionHeads
|
||||||
@@ -113,18 +274,31 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
|||||||
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
||||||
kv["rope.freq_base"] = q.RopeTheta
|
kv["rope.freq_base"] = q.RopeTheta
|
||||||
|
|
||||||
// RoPE dimension count (partial rotary)
|
|
||||||
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
|
|
||||||
partialRotary := q.PartialRotaryFactor
|
partialRotary := q.PartialRotaryFactor
|
||||||
if partialRotary > 0 && partialRotary <= 1 {
|
if partialRotary > 0 && partialRotary <= 1 {
|
||||||
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
|
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MoE config
|
if sections := q.ropeSections(); len(sections) > 0 {
|
||||||
|
kv["mrope_sections"] = sections
|
||||||
|
kv["rope.mrope_section"] = sections
|
||||||
|
kv["rope.dimension_sections"] = sections
|
||||||
|
}
|
||||||
|
if q.RopeParameters.MRopeInterleaved {
|
||||||
|
kv["rope.mrope_interleaved"] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.RopeScaling.Type != "" && q.RopeScaling.Type != "default" {
|
||||||
|
kv["rope.scaling.type"] = q.RopeScaling.Type
|
||||||
|
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
||||||
|
}
|
||||||
|
|
||||||
if q.NumExperts > 0 {
|
if q.NumExperts > 0 {
|
||||||
kv["expert_count"] = q.NumExperts
|
kv["expert_count"] = q.NumExperts
|
||||||
kv["expert_used_count"] = q.NumExpertsPerToken
|
kv["expert_used_count"] = q.NumExpertsPerToken
|
||||||
kv["norm_top_k_prob"] = q.NormTopkProb
|
if q.NormTopkProb != nil {
|
||||||
|
kv["norm_top_k_prob"] = *q.NormTopkProb
|
||||||
|
}
|
||||||
if q.MoEIntermediateSize > 0 {
|
if q.MoEIntermediateSize > 0 {
|
||||||
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
|
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
|
||||||
}
|
}
|
||||||
@@ -133,33 +307,66 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSM/Linear attention config
|
|
||||||
// d_inner = linear_value_head_dim * linear_num_value_heads
|
|
||||||
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
|
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
|
||||||
kv["ssm.inner_size"] = dInner
|
kv["ssm.inner_size"] = dInner
|
||||||
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
|
kv["ssm.state_size"] = q.LinearKeyHeadDim
|
||||||
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
|
kv["ssm.group_count"] = q.LinearNumKeyHeads
|
||||||
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
|
kv["ssm.time_step_rank"] = q.LinearNumValueHeads
|
||||||
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
|
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
|
||||||
interval := q.FullAttentionInterval
|
if q.shouldReorderVHeads() {
|
||||||
kv["full_attention_interval"] = interval
|
kv["ssm.v_head_reordered"] = true
|
||||||
|
}
|
||||||
// Build per-layer KV head count array to identify layer types
|
if q.FullAttentionInterval > 0 {
|
||||||
// 0 = recurrent (linear attention), non-zero = full attention
|
kv["full_attention_interval"] = q.FullAttentionInterval
|
||||||
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
|
|
||||||
for i := range q.NumHiddenLayers {
|
|
||||||
// Full attention every full_attention_interval layers (starting at interval-1)
|
|
||||||
if interval > 0 && (i+1)%interval == 0 {
|
|
||||||
kvHeadCounts[i] = q.NumKeyValueHeads
|
|
||||||
}
|
|
||||||
// else stays 0 (recurrent layer)
|
|
||||||
}
|
}
|
||||||
kv["attention.head_count_kv"] = kvHeadCounts
|
|
||||||
|
|
||||||
// RoPE scaling
|
if headCounts, err := q.kvHeadCounts(); err == nil {
|
||||||
if q.RopeScaling.Type != "" {
|
kv["attention.head_count_kv"] = headCounts
|
||||||
kv["rope.scaling.type"] = q.RopeScaling.Type
|
}
|
||||||
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
|
||||||
|
if q.VisionModel.Depth > 0 {
|
||||||
|
kv["vision.block_count"] = q.VisionModel.Depth
|
||||||
|
kv["vision.embedding_length"] = q.VisionModel.HiddenSize
|
||||||
|
kv["vision.attention.head_count"] = q.VisionModel.NumHeads
|
||||||
|
kv["vision.num_channels"] = q.VisionModel.InChannels
|
||||||
|
if q.VisionModel.PatchSize > 0 {
|
||||||
|
kv["vision.patch_size"] = q.VisionModel.PatchSize
|
||||||
|
}
|
||||||
|
if q.VisionModel.SpatialMergeSize > 0 {
|
||||||
|
kv["vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
|
||||||
|
}
|
||||||
|
if q.VisionModel.RMSNormEps > 0 {
|
||||||
|
kv["vision.attention.layer_norm_epsilon"] = q.VisionModel.RMSNormEps
|
||||||
|
}
|
||||||
|
if q.VisionModel.RopeTheta > 0 {
|
||||||
|
kv["vision.rope.freq_base"] = q.VisionModel.RopeTheta
|
||||||
|
}
|
||||||
|
if q.VisionModel.TemporalPatchSize > 0 {
|
||||||
|
kv["vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
|
||||||
|
}
|
||||||
|
kv["vision.deepstack_visual_indexes"] = q.VisionModel.DeepstackVisualIndexes
|
||||||
|
if q.VisionModel.Size.ShortestEdge > 0 {
|
||||||
|
kv["vision.shortest_edge"] = q.VisionModel.Size.ShortestEdge
|
||||||
|
}
|
||||||
|
if q.VisionModel.Size.LongestEdge > 0 {
|
||||||
|
kv["vision.longest_edge"] = q.VisionModel.Size.LongestEdge
|
||||||
|
}
|
||||||
|
if len(q.VisionModel.ImageMean) > 0 {
|
||||||
|
kv["vision.image_mean"] = q.VisionModel.ImageMean
|
||||||
|
}
|
||||||
|
if len(q.VisionModel.ImageStd) > 0 {
|
||||||
|
kv["vision.image_std"] = q.VisionModel.ImageStd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.ImageTokenID > 0 {
|
||||||
|
kv["image_token_id"] = q.ImageTokenID
|
||||||
|
}
|
||||||
|
if q.VisionStartTokenID > 0 {
|
||||||
|
kv["vision_start_token_id"] = q.VisionStartTokenID
|
||||||
|
}
|
||||||
|
if q.VisionEndTokenID > 0 {
|
||||||
|
kv["vision_end_token_id"] = q.VisionEndTokenID
|
||||||
}
|
}
|
||||||
|
|
||||||
return kv
|
return kv
|
||||||
@@ -168,7 +375,6 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
|||||||
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||||
var out []*ggml.Tensor
|
var out []*ggml.Tensor
|
||||||
|
|
||||||
// Create merges for expert tensors - stack individual experts into batched tensors
|
|
||||||
merges := make([]merge, q.NumHiddenLayers*3)
|
merges := make([]merge, q.NumHiddenLayers*3)
|
||||||
for i := range q.NumHiddenLayers {
|
for i := range q.NumHiddenLayers {
|
||||||
merges[i*3+0] = merge{
|
merges[i*3+0] = merge{
|
||||||
@@ -185,16 +391,13 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge expert tensors
|
|
||||||
merged, remaining := mergeTensors(ts, merges...)
|
merged, remaining := mergeTensors(ts, merges...)
|
||||||
out = append(out, merged...)
|
out = append(out, merged...)
|
||||||
|
|
||||||
// Process remaining tensors
|
|
||||||
for _, t := range remaining {
|
for _, t := range remaining {
|
||||||
name := t.Name()
|
name := t.Name()
|
||||||
shape := t.Shape()
|
shape := t.Shape()
|
||||||
|
|
||||||
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
|
|
||||||
if strings.HasSuffix(name, ".ssm_in.weight") {
|
if strings.HasSuffix(name, ".ssm_in.weight") {
|
||||||
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
|
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
|
||||||
out = append(out, qkv, gate)
|
out = append(out, qkv, gate)
|
||||||
@@ -204,84 +407,299 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
|
case strings.Contains(name, ".mlp.experts.gate_up_proj"):
|
||||||
// This matches the Python converter behavior for qwen3next
|
out = append(out, slices.Collect(splitDim(t, 1,
|
||||||
|
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_gate_exps.weight")},
|
||||||
|
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_up_exps.weight")},
|
||||||
|
))...)
|
||||||
|
|
||||||
|
case strings.Contains(name, ".mlp.experts.down_proj"):
|
||||||
|
out = append(out, &ggml.Tensor{
|
||||||
|
Name: strings.NewReplacer(".mlp.experts.down_proj", ".ffn_down_exps.weight").Replace(name),
|
||||||
|
Kind: t.Kind(),
|
||||||
|
Shape: slices.Clone(shape),
|
||||||
|
WriterTo: t,
|
||||||
|
})
|
||||||
|
|
||||||
|
case strings.HasPrefix(name, "v.blk.") && strings.Contains(name, ".attn_qkv"):
|
||||||
|
out = append(out, slices.Collect(splitDim(t, 0,
|
||||||
|
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
|
||||||
|
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
|
||||||
|
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
|
||||||
|
))...)
|
||||||
|
|
||||||
|
case strings.Contains(name, "patch_embed") && strings.HasSuffix(name, "weight"):
|
||||||
|
out = append(out, &ggml.Tensor{
|
||||||
|
Name: name,
|
||||||
|
Kind: t.Kind(),
|
||||||
|
Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...),
|
||||||
|
WriterTo: t,
|
||||||
|
})
|
||||||
|
|
||||||
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
|
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
|
||||||
t.SetRepacker(q.addOne)
|
t.SetRepacker(q.addOne)
|
||||||
out = append(out, &ggml.Tensor{
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
Name: name,
|
|
||||||
Kind: t.Kind(),
|
|
||||||
Shape: slices.Clone(shape),
|
|
||||||
WriterTo: t,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Handle linear attention A_log -> ssm_a (negate and exp)
|
|
||||||
// Note: name has already been transformed by Replacements at this point
|
|
||||||
case strings.HasSuffix(name, ".ssm_a"):
|
case strings.HasSuffix(name, ".ssm_a"):
|
||||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
t.SetRepacker(q.repackSSMA())
|
||||||
// Compute -exp(A_log)
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
result := make([]float32, len(data))
|
|
||||||
for i, v := range data {
|
case strings.HasSuffix(name, ".attn_qkv.weight"):
|
||||||
// -exp(v)
|
if q.shouldReorderVHeads() {
|
||||||
result[i] = -float32(math.Exp(float64(v)))
|
t.SetRepacker(q.repackAttnQKV())
|
||||||
}
|
}
|
||||||
return result, nil
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
})
|
|
||||||
out = append(out, &ggml.Tensor{
|
case strings.HasSuffix(name, ".attn_gate.weight"):
|
||||||
Name: name,
|
if q.shouldReorderVHeads() {
|
||||||
Kind: t.Kind(),
|
// HF tensor layout is [out_features, in_features]; reorder rows.
|
||||||
Shape: slices.Clone(shape),
|
t.SetRepacker(q.repackReorderDim(0, int(q.LinearValueHeadDim)))
|
||||||
WriterTo: t,
|
}
|
||||||
})
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
|
|
||||||
|
case strings.HasSuffix(name, ".ssm_beta.weight"), strings.HasSuffix(name, ".ssm_alpha.weight"):
|
||||||
|
if q.shouldReorderVHeads() {
|
||||||
|
// HF tensor layout is [out_features, in_features]; reorder rows.
|
||||||
|
t.SetRepacker(q.repackReorderDim(0, 1))
|
||||||
|
}
|
||||||
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
|
|
||||||
|
case strings.HasSuffix(name, ".ssm_dt"):
|
||||||
|
if q.shouldReorderVHeads() {
|
||||||
|
t.SetRepacker(q.repackReorderDim(0, 1))
|
||||||
|
}
|
||||||
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
|
|
||||||
|
case strings.HasSuffix(name, ".ssm_out.weight"):
|
||||||
|
if q.shouldReorderVHeads() {
|
||||||
|
// HF out_proj layout is [out_features, in_features]; reorder columns.
|
||||||
|
t.SetRepacker(q.repackReorderDim(1, int(q.LinearValueHeadDim)))
|
||||||
|
}
|
||||||
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
|
|
||||||
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
|
|
||||||
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
||||||
newShape := slices.Clone(shape)
|
newShape := slices.Clone(shape)
|
||||||
if len(shape) == 3 {
|
if len(shape) == 3 {
|
||||||
if shape[0] == 1 {
|
if shape[0] == 1 {
|
||||||
// [1, D, K] -> [D, K]
|
|
||||||
newShape = []uint64{shape[1], shape[2]}
|
newShape = []uint64{shape[1], shape[2]}
|
||||||
} else if shape[1] == 1 {
|
} else if shape[1] == 1 {
|
||||||
// [D, 1, K] -> [D, K]
|
|
||||||
newShape = []uint64{shape[0], shape[2]}
|
newShape = []uint64{shape[0], shape[2]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
out = append(out, &ggml.Tensor{
|
if q.shouldReorderVHeads() {
|
||||||
Name: name,
|
t.SetRepacker(q.repackConv1D())
|
||||||
Kind: t.Kind(),
|
|
||||||
Shape: newShape,
|
|
||||||
WriterTo: t,
|
|
||||||
})
|
|
||||||
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
|
|
||||||
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
|
|
||||||
newShape := slices.Clone(shape)
|
|
||||||
if len(shape) == 2 {
|
|
||||||
if shape[0] == 1 && shape[1] > 1 {
|
|
||||||
newShape = []uint64{shape[1]}
|
|
||||||
} else if shape[1] == 1 && shape[0] > 1 {
|
|
||||||
newShape = []uint64{shape[0]}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
out = append(out, &ggml.Tensor{
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: newShape, WriterTo: t})
|
||||||
Name: name,
|
|
||||||
Kind: t.Kind(),
|
|
||||||
Shape: newShape,
|
|
||||||
WriterTo: t,
|
|
||||||
})
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
out = append(out, &ggml.Tensor{
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
Name: name,
|
|
||||||
Kind: t.Kind(),
|
|
||||||
Shape: slices.Clone(shape),
|
|
||||||
WriterTo: t,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) repackReorderDim(dim, headDim int) Repacker {
|
||||||
|
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
|
if !q.shouldReorderVHeads() {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
numK := int(q.LinearNumKeyHeads)
|
||||||
|
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
|
||||||
|
return reorderHeadLayout(data, shape, dim, numK, numVPerK, headDim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) repackAttnQKV() Repacker {
|
||||||
|
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
|
if !q.shouldReorderVHeads() || len(shape) != 2 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := int(shape[0])
|
||||||
|
cols := int(shape[1])
|
||||||
|
numK := int(q.LinearNumKeyHeads)
|
||||||
|
numV := int(q.LinearNumValueHeads)
|
||||||
|
headK := int(q.LinearKeyHeadDim)
|
||||||
|
headV := int(q.LinearValueHeadDim)
|
||||||
|
qDim := headK * numK
|
||||||
|
kDim := headK * numK
|
||||||
|
vDim := headV * numV
|
||||||
|
qkvDim := qDim + kDim + vDim
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case rows == qkvDim:
|
||||||
|
// HF layout: [out_features, in_features]. Keep Q/K rows unchanged and
|
||||||
|
// reorder only V rows from grouped -> tiled head layout.
|
||||||
|
out := make([]float32, len(data))
|
||||||
|
qkRows := qDim + kDim
|
||||||
|
qkSize := qkRows * cols
|
||||||
|
copy(out[:qkSize], data[:qkSize])
|
||||||
|
|
||||||
|
vStart := qkSize
|
||||||
|
vEnd := vStart + vDim*cols
|
||||||
|
reorderedV, err := reorderHeadLayout(data[vStart:vEnd], []uint64{uint64(vDim), uint64(cols)}, 0, numK, numV/numK, headV)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(out[vStart:vEnd], reorderedV)
|
||||||
|
copy(out[vEnd:], data[vEnd:])
|
||||||
|
return out, nil
|
||||||
|
|
||||||
|
case cols == qkvDim:
|
||||||
|
// Fallback for already-transposed [in_features, out_features] tensors.
|
||||||
|
out := make([]float32, len(data))
|
||||||
|
copy(out, data)
|
||||||
|
for r := range rows {
|
||||||
|
base := r * cols
|
||||||
|
vStart := base + qDim + kDim
|
||||||
|
vEnd := vStart + vDim
|
||||||
|
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vDim)}, 0, numK, numV/numK, headV)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(out[vStart:vEnd], reorderedV)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) repackConv1D() Repacker {
|
||||||
|
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
|
if !q.shouldReorderVHeads() {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
normShape := slices.Clone(shape)
|
||||||
|
if len(shape) == 3 {
|
||||||
|
if shape[0] == 1 {
|
||||||
|
normShape = []uint64{shape[1], shape[2]}
|
||||||
|
} else if shape[1] == 1 {
|
||||||
|
normShape = []uint64{shape[0], shape[2]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(normShape) != 2 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := int(normShape[0])
|
||||||
|
cols := int(normShape[1])
|
||||||
|
numK := int(q.LinearNumKeyHeads)
|
||||||
|
numV := int(q.LinearNumValueHeads)
|
||||||
|
headK := int(q.LinearKeyHeadDim)
|
||||||
|
headV := int(q.LinearValueHeadDim)
|
||||||
|
qkChannels := 2 * headK * numK
|
||||||
|
totalChannels := qkChannels + headV*numV
|
||||||
|
if qkChannels <= 0 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case rows == totalChannels:
|
||||||
|
// HF layout after squeeze: [channels, kernel]
|
||||||
|
out := make([]float32, len(data))
|
||||||
|
prefix := qkChannels * cols
|
||||||
|
copy(out[:prefix], data[:prefix])
|
||||||
|
reorderedV, err := reorderHeadLayout(data[prefix:], []uint64{uint64(totalChannels - qkChannels), uint64(cols)}, 0, numK, numV/numK, headV)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(out[prefix:], reorderedV)
|
||||||
|
return out, nil
|
||||||
|
case cols == totalChannels:
|
||||||
|
// Fallback for transposed [kernel, channels]
|
||||||
|
out := make([]float32, len(data))
|
||||||
|
copy(out, data)
|
||||||
|
vChannels := totalChannels - qkChannels
|
||||||
|
for r := range rows {
|
||||||
|
base := r * cols
|
||||||
|
vStart := base + qkChannels
|
||||||
|
vEnd := vStart + vChannels
|
||||||
|
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vChannels)}, 0, numK, numV/numK, headV)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(out[vStart:vEnd], reorderedV)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
default:
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) repackSSMA() Repacker {
|
||||||
|
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
|
result := make([]float32, len(data))
|
||||||
|
for i, v := range data {
|
||||||
|
result[i] = -float32(math.Exp(float64(v)))
|
||||||
|
}
|
||||||
|
if !q.shouldReorderVHeads() {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
numK := int(q.LinearNumKeyHeads)
|
||||||
|
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
|
||||||
|
return reorderHeadLayout(result, shape, 0, numK, numVPerK, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func reorderHeadLayout(data []float32, shape []uint64, dim int, numKHeads, numVPerK, headDim int) ([]float32, error) {
|
||||||
|
if len(shape) == 0 || numKHeads <= 0 || numVPerK <= 0 || headDim <= 0 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dims := make([]int, len(shape))
|
||||||
|
for i := range shape {
|
||||||
|
dims[i] = int(shape[i])
|
||||||
|
}
|
||||||
|
if dim < 0 {
|
||||||
|
dim += len(dims)
|
||||||
|
}
|
||||||
|
if dim < 0 || dim >= len(dims) {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := numKHeads * numVPerK * headDim
|
||||||
|
if dims[dim] != expected {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newShape := make([]int, 0, len(dims)+2)
|
||||||
|
newShape = append(newShape, dims[:dim]...)
|
||||||
|
newShape = append(newShape, numKHeads, numVPerK, headDim)
|
||||||
|
newShape = append(newShape, dims[dim+1:]...)
|
||||||
|
|
||||||
|
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||||
|
if err := tt.Reshape(newShape...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
perm := make([]int, len(newShape))
|
||||||
|
for i := range perm {
|
||||||
|
perm[i] = i
|
||||||
|
}
|
||||||
|
perm[dim], perm[dim+1] = perm[dim+1], perm[dim]
|
||||||
|
|
||||||
|
tt, err := tensor.Transpose(tt, perm...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tt = tensor.Materialize(tt)
|
||||||
|
|
||||||
|
total := 1
|
||||||
|
for _, d := range dims {
|
||||||
|
total *= d
|
||||||
|
}
|
||||||
|
if err := tt.Reshape(total); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return native.VectorF32(tt.(*tensor.Dense))
|
||||||
|
}
|
||||||
|
|
||||||
type qkvzSplitSpec struct {
|
type qkvzSplitSpec struct {
|
||||||
hidden int
|
hidden int
|
||||||
headKDim int
|
headKDim int
|
||||||
@@ -369,7 +787,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
|
|||||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Convert to [hidden, out_features] layout for slicing
|
|
||||||
tt, err = tensor.Transpose(tt, 1, 0)
|
tt, err = tensor.Transpose(tt, 1, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -444,7 +861,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// addOne adds 1.0 to all elements in the tensor (for norm weights)
|
|
||||||
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
|
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
|
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
|
||||||
ones := tensor.Ones(tensor.Float32, int(shape[0]))
|
ones := tensor.Ones(tensor.Float32, int(shape[0]))
|
||||||
@@ -471,10 +887,21 @@ func (q *qwen3NextModel) Replacements() []string {
|
|||||||
return []string{
|
return []string{
|
||||||
// Embeddings and output
|
// Embeddings and output
|
||||||
"lm_head", "output",
|
"lm_head", "output",
|
||||||
|
"model.language_model.embed_tokens", "token_embd",
|
||||||
|
"model.language_model.norm", "output_norm",
|
||||||
|
"model.language_model.layers", "blk",
|
||||||
"model.embed_tokens", "token_embd",
|
"model.embed_tokens", "token_embd",
|
||||||
"model.norm", "output_norm",
|
"model.norm", "output_norm",
|
||||||
"model.layers", "blk",
|
"model.layers", "blk",
|
||||||
|
|
||||||
|
// Vision
|
||||||
|
"model.visual", "v",
|
||||||
|
"patch_embed.proj", "patch_embed",
|
||||||
|
"blocks", "blk",
|
||||||
|
"attn.qkv", "attn_qkv",
|
||||||
|
"attn.proj", "attn_out",
|
||||||
|
"deepstack_merger_list", "deepstack_merger",
|
||||||
|
|
||||||
// Layer norms
|
// Layer norms
|
||||||
"input_layernorm", "attn_norm",
|
"input_layernorm", "attn_norm",
|
||||||
"post_attention_layernorm", "post_attention_norm",
|
"post_attention_layernorm", "post_attention_norm",
|
||||||
@@ -487,9 +914,16 @@ func (q *qwen3NextModel) Replacements() []string {
|
|||||||
"self_attn.v_proj", "attn_v",
|
"self_attn.v_proj", "attn_v",
|
||||||
"self_attn.o_proj", "attn_output",
|
"self_attn.o_proj", "attn_output",
|
||||||
|
|
||||||
// Linear attention (Gated Delta Net)
|
// Linear attention (legacy qwen3next)
|
||||||
"linear_attn.in_proj_qkvz", "ssm_in",
|
"linear_attn.in_proj_qkvz", "ssm_in",
|
||||||
"linear_attn.in_proj_ba", "ssm_ba",
|
"linear_attn.in_proj_ba", "ssm_ba",
|
||||||
|
|
||||||
|
// Linear attention (qwen35)
|
||||||
|
"linear_attn.in_proj_qkv", "attn_qkv",
|
||||||
|
"linear_attn.in_proj_z", "attn_gate",
|
||||||
|
"linear_attn.in_proj_a", "ssm_alpha",
|
||||||
|
"linear_attn.in_proj_b", "ssm_beta",
|
||||||
|
|
||||||
"linear_attn.conv1d", "ssm_conv1d",
|
"linear_attn.conv1d", "ssm_conv1d",
|
||||||
"linear_attn.dt_bias", "ssm_dt",
|
"linear_attn.dt_bias", "ssm_dt",
|
||||||
"linear_attn.dt_proj", "ssm_dt",
|
"linear_attn.dt_proj", "ssm_dt",
|
||||||
@@ -497,14 +931,14 @@ func (q *qwen3NextModel) Replacements() []string {
|
|||||||
"linear_attn.norm", "ssm_norm",
|
"linear_attn.norm", "ssm_norm",
|
||||||
"linear_attn.out_proj", "ssm_out",
|
"linear_attn.out_proj", "ssm_out",
|
||||||
|
|
||||||
// MoE (experts are stacked via mergeTensors, not replaced here)
|
// MoE
|
||||||
"mlp.gate.weight", "ffn_gate_inp.weight",
|
"mlp.gate.weight", "ffn_gate_inp.weight",
|
||||||
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
||||||
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
||||||
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
||||||
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
|
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
|
||||||
|
|
||||||
// Dense FFN (if any layers use it)
|
// Dense FFN
|
||||||
"mlp.down_proj", "ffn_down",
|
"mlp.down_proj", "ffn_down",
|
||||||
"mlp.gate_proj", "ffn_gate",
|
"mlp.gate_proj", "ffn_gate",
|
||||||
"mlp.up_proj", "ffn_up",
|
"mlp.up_proj", "ffn_up",
|
||||||
|
|||||||
563
convert/convert_qwen3next_test.go
Normal file
563
convert/convert_qwen3next_test.go
Normal file
@@ -0,0 +1,563 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
)
|
||||||
|
|
||||||
|
func boolPtr(v bool) *bool {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTensorData(t *testing.T, tensor *ggml.Tensor) []float32 {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tensor.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
numel := 1
|
||||||
|
for _, d := range tensor.Shape {
|
||||||
|
numel *= int(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
values := make([]float32, numel)
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3NextLegacyModelTypeDisablesReorder(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_next",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.shouldReorderVHeads() {
|
||||||
|
t.Fatalf("legacy qwen3_next model_type should not reorder v-head layout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3NextLegacyArchitectureDisablesReorder(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
Architectures: []string{"Qwen3NextForCausalLM"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.shouldReorderVHeads() {
|
||||||
|
t.Fatalf("legacy Qwen3Next architecture should not reorder v-head layout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3NextKVLegacyConfig(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_next",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
MaxPositionEmbeddings: 8192,
|
||||||
|
HiddenSize: 512,
|
||||||
|
NumHiddenLayers: 4,
|
||||||
|
IntermediateSize: 2048,
|
||||||
|
NumAttentionHeads: 8,
|
||||||
|
NumKeyValueHeads: 2,
|
||||||
|
HeadDim: 64,
|
||||||
|
RopeTheta: 1_000_000,
|
||||||
|
RMSNormEPS: 1e-6,
|
||||||
|
|
||||||
|
NumExperts: 8,
|
||||||
|
NumExpertsPerToken: 2,
|
||||||
|
NormTopkProb: boolPtr(true),
|
||||||
|
MoEIntermediateSize: 256,
|
||||||
|
SharedExpertIntermSize: 512,
|
||||||
|
|
||||||
|
FullAttentionInterval: 2,
|
||||||
|
|
||||||
|
LinearConvKernelDim: 4,
|
||||||
|
LinearKeyHeadDim: 64,
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 64,
|
||||||
|
|
||||||
|
PartialRotaryFactor: 0.25,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||||
|
if got, want := kv["general.architecture"], "qwen35moe"; got != want {
|
||||||
|
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := kv["tokenizer.ggml.pre"], "qwen35"; got != want {
|
||||||
|
t.Fatalf("unexpected tokenizer pre: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
|
||||||
|
}
|
||||||
|
if got, want := headCountKV, []uint32{0, 2, 0, 2}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := kv["ssm.v_head_reordered"]; ok {
|
||||||
|
t.Fatalf("legacy qwen3next should not enable ssm.v_head_reordered")
|
||||||
|
}
|
||||||
|
if got, want := kv["norm_top_k_prob"], true; got != want {
|
||||||
|
t.Fatalf("unexpected norm_top_k_prob: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35MoeOmitsNormTopKProbWhenUnset(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
MaxPositionEmbeddings: 4096,
|
||||||
|
HiddenSize: 512,
|
||||||
|
NumHiddenLayers: 4,
|
||||||
|
IntermediateSize: 2048,
|
||||||
|
NumAttentionHeads: 8,
|
||||||
|
NumKeyValueHeads: 2,
|
||||||
|
HeadDim: 64,
|
||||||
|
RopeTheta: 1_000_000,
|
||||||
|
RMSNormEPS: 1e-6,
|
||||||
|
NumExperts: 8,
|
||||||
|
NumExpertsPerToken: 2,
|
||||||
|
FullAttentionInterval: 2,
|
||||||
|
LinearConvKernelDim: 4,
|
||||||
|
LinearKeyHeadDim: 64,
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 64,
|
||||||
|
PartialRotaryFactor: 0.25,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||||
|
if _, ok := kv["norm_top_k_prob"]; ok {
|
||||||
|
t.Fatalf("expected norm_top_k_prob to be omitted when not set in config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35KVFromTextConfig(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
TextConfig: &qwen3NextTextConfig{
|
||||||
|
MaxPositionEmbeddings: 16384,
|
||||||
|
HiddenSize: 1024,
|
||||||
|
NumHiddenLayers: 4,
|
||||||
|
IntermediateSize: 4096,
|
||||||
|
NumAttentionHeads: 8,
|
||||||
|
NumKeyValueHeads: 4,
|
||||||
|
HeadDim: 128,
|
||||||
|
RMSNormEPS: 1e-6,
|
||||||
|
|
||||||
|
LayerTypes: []string{
|
||||||
|
"linear_attention",
|
||||||
|
"full_attention",
|
||||||
|
"linear_attention",
|
||||||
|
"full_attention",
|
||||||
|
},
|
||||||
|
|
||||||
|
LinearConvKernelDim: 4,
|
||||||
|
LinearKeyHeadDim: 128,
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 128,
|
||||||
|
|
||||||
|
RopeParameters: qwen3NextRopeParams{
|
||||||
|
MRopeInterleaved: true,
|
||||||
|
MropeSection: []int32{11, 11, 10},
|
||||||
|
RopeType: "default",
|
||||||
|
RopeTheta: 10_000_000,
|
||||||
|
PartialRotaryFactor: 0.25,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
VisionModel: qwen3NextVisionConfig{
|
||||||
|
Depth: 2,
|
||||||
|
HiddenSize: 128,
|
||||||
|
NumHeads: 4,
|
||||||
|
InChannels: 3,
|
||||||
|
PatchSize: 16,
|
||||||
|
SpatialMergeSize: 2,
|
||||||
|
RMSNormEps: 1e-6,
|
||||||
|
RopeTheta: 10_000,
|
||||||
|
TemporalPatchSize: 2,
|
||||||
|
DeepstackVisualIndexes: []int32{1},
|
||||||
|
},
|
||||||
|
ImageTokenID: 1001,
|
||||||
|
VisionStartTokenID: 1002,
|
||||||
|
VisionEndTokenID: 1003,
|
||||||
|
}
|
||||||
|
m.VisionModel.Size.ShortestEdge = 224
|
||||||
|
m.VisionModel.Size.LongestEdge = 4096
|
||||||
|
m.VisionModel.ImageMean = []float32{0.5, 0.5, 0.5}
|
||||||
|
m.VisionModel.ImageStd = []float32{0.2, 0.2, 0.2}
|
||||||
|
|
||||||
|
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||||
|
if got, want := kv["general.architecture"], "qwen35"; got != want {
|
||||||
|
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
|
||||||
|
}
|
||||||
|
if got, want := headCountKV, []uint32{0, 4, 0, 4}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := kv["ssm.v_head_reordered"].(bool); !ok || !got {
|
||||||
|
t.Fatalf("expected ssm.v_head_reordered=true, got %v (%T)", kv["ssm.v_head_reordered"], kv["ssm.v_head_reordered"])
|
||||||
|
}
|
||||||
|
|
||||||
|
mrope, ok := kv["mrope_sections"].([]int32)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("mrope_sections has unexpected type: %T", kv["mrope_sections"])
|
||||||
|
}
|
||||||
|
if got, want := mrope, []int32{11, 11, 10}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected mrope_sections: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
ropeSections, ok := kv["rope.dimension_sections"].([]int32)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("rope.dimension_sections has unexpected type: %T", kv["rope.dimension_sections"])
|
||||||
|
}
|
||||||
|
if got, want := ropeSections, []int32{11, 11, 10}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected rope.dimension_sections: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := kv["rope.mrope_interleaved"].(bool); !ok || !got {
|
||||||
|
t.Fatalf("expected rope.mrope_interleaved=true, got %v (%T)", kv["rope.mrope_interleaved"], kv["rope.mrope_interleaved"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := kv["vision.block_count"], uint32(2); got != want {
|
||||||
|
t.Fatalf("unexpected vision.block_count: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3NextReplacements(t *testing.T) {
|
||||||
|
r := strings.NewReplacer((&qwen3NextModel{}).Replacements()...)
|
||||||
|
|
||||||
|
if got, want := r.Replace("model.language_model.layers.1.linear_attn.in_proj_qkv.weight"), "blk.1.attn_qkv.weight"; got != want {
|
||||||
|
t.Fatalf("unexpected language-model replacement: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := r.Replace("model.visual.blocks.0.attn.qkv.weight"), "v.blk.0.attn_qkv.weight"; got != want {
|
||||||
|
t.Fatalf("unexpected vision replacement: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := r.Replace("model.layers.1.linear_attn.in_proj_qkvz.weight"), "blk.1.ssm_in.weight"; got != want {
|
||||||
|
t.Fatalf("unexpected legacy replacement: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ReordersVHeads(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.attn_gate.weight",
|
||||||
|
shape: []uint64{4, 2},
|
||||||
|
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected data: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ReordersAttnQKVOutputDim(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearKeyHeadDim: 1,
|
||||||
|
LinearValueHeadDim: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.attn_qkv.weight",
|
||||||
|
shape: []uint64{8, 2}, // [out_features, in_features] (HF layout)
|
||||||
|
data: []float32{
|
||||||
|
0, 1, // q0
|
||||||
|
2, 3, // q1
|
||||||
|
4, 5, // k0
|
||||||
|
6, 7, // k1
|
||||||
|
10, 11, // v(k0,v0)
|
||||||
|
12, 13, // v(k0,v1)
|
||||||
|
20, 21, // v(k1,v0)
|
||||||
|
22, 23, // v(k1,v1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7,
|
||||||
|
10, 11, 20, 21, 12, 13, 22, 23,
|
||||||
|
}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected qkv data: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ReordersSsmOutInputDim(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.ssm_out.weight",
|
||||||
|
shape: []uint64{2, 4},
|
||||||
|
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{0, 2, 1, 3, 4, 6, 5, 7}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected ssm_out data: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ReordersSsmBetaRows(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.ssm_beta.weight",
|
||||||
|
shape: []uint64{4, 2},
|
||||||
|
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected ssm_beta data: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ReordersConv1DChannelDim(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearKeyHeadDim: 1,
|
||||||
|
LinearValueHeadDim: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.ssm_conv1d.weight",
|
||||||
|
shape: []uint64{8, 2}, // [channels, kernel] after squeeze
|
||||||
|
data: []float32{
|
||||||
|
0, 1, // q0
|
||||||
|
2, 3, // q1
|
||||||
|
4, 5, // k0
|
||||||
|
6, 7, // k1
|
||||||
|
10, 11, // v(k0,v0)
|
||||||
|
12, 13, // v(k0,v1)
|
||||||
|
20, 21, // v(k1,v0)
|
||||||
|
22, 23, // v(k1,v1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7,
|
||||||
|
10, 11, 20, 21, 12, 13, 22, 23,
|
||||||
|
}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected conv1d data: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLegacyQwen3NextDoesNotReorderVHeads(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_next",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.attn_gate.weight",
|
||||||
|
shape: []uint64{4, 1},
|
||||||
|
data: []float32{0, 1, 2, 3},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{0, 1, 2, 3}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected data for legacy qwen3next: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35MoePackedExperts(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
NumHiddenLayers: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.mlp.experts.gate_up_proj",
|
||||||
|
shape: []uint64{2, 4, 3},
|
||||||
|
data: []float32{
|
||||||
|
0, 1, 2,
|
||||||
|
3, 4, 5,
|
||||||
|
6, 7, 8,
|
||||||
|
9, 10, 11,
|
||||||
|
12, 13, 14,
|
||||||
|
15, 16, 17,
|
||||||
|
18, 19, 20,
|
||||||
|
21, 22, 23,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.mlp.experts.down_proj",
|
||||||
|
shape: []uint64{2, 5, 3},
|
||||||
|
data: make([]float32, 2*5*3),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
get := func(name string) *ggml.Tensor {
|
||||||
|
for _, tensor := range out {
|
||||||
|
if tensor.Name == name {
|
||||||
|
return tensor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
gate := get("blk.0.ffn_gate_exps.weight")
|
||||||
|
if gate == nil {
|
||||||
|
t.Fatalf("missing tensor %q", "blk.0.ffn_gate_exps.weight")
|
||||||
|
}
|
||||||
|
if got, want := gate.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected gate shape: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := readTensorData(t, gate), []float32{
|
||||||
|
0, 1, 2, 3, 4, 5,
|
||||||
|
12, 13, 14, 15, 16, 17,
|
||||||
|
}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected gate values: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
up := get("blk.0.ffn_up_exps.weight")
|
||||||
|
if up == nil {
|
||||||
|
t.Fatalf("missing tensor %q", "blk.0.ffn_up_exps.weight")
|
||||||
|
}
|
||||||
|
if got, want := up.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected up shape: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := readTensorData(t, up), []float32{
|
||||||
|
6, 7, 8, 9, 10, 11,
|
||||||
|
18, 19, 20, 21, 22, 23,
|
||||||
|
}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected up values: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
down := get("blk.0.ffn_down_exps.weight")
|
||||||
|
if down == nil {
|
||||||
|
t.Fatalf("missing tensor %q", "blk.0.ffn_down_exps.weight")
|
||||||
|
}
|
||||||
|
if got, want := down.Shape, []uint64{2, 5, 3}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected down shape: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35SharedExpertGateKeepsMatrixShape(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.ffn_gate_inp_shexp.weight",
|
||||||
|
shape: []uint64{1, 4},
|
||||||
|
data: []float32{0, 1, 2, 3},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := out[0].Shape, []uint64{1, 4}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected shared gate shape: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -101,6 +101,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
|||||||
t.Pre = "deepseek-coder"
|
t.Pre = "deepseek-coder"
|
||||||
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
|
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
|
||||||
t.Pre = "qwen2"
|
t.Pre = "qwen2"
|
||||||
|
case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba":
|
||||||
|
t.Pre = "qwen35"
|
||||||
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
|
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
|
||||||
// noop, empty pretokenizer
|
// noop, empty pretokenizer
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -386,6 +386,28 @@ func TestParseTokenizer(t *testing.T) {
|
|||||||
Pre: "default",
|
Pre: "default",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "qwen35 pretokenizer",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{
|
||||||
|
"pre_tokenizer": {
|
||||||
|
"type": "Sequence",
|
||||||
|
"pretokenizers": [
|
||||||
|
{
|
||||||
|
"type": "Split",
|
||||||
|
"pattern": {
|
||||||
|
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{Model: "gpt2"},
|
||||||
|
Pre: "qwen35",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
|
|||||||
@@ -290,6 +290,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
|||||||
"olmo3",
|
"olmo3",
|
||||||
"qwen25vl",
|
"qwen25vl",
|
||||||
"qwen3", "qwen3moe",
|
"qwen3", "qwen3moe",
|
||||||
|
"qwen35", "qwen35moe",
|
||||||
"qwen3next",
|
"qwen3next",
|
||||||
"qwen3vl", "qwen3vlmoe",
|
"qwen3vl", "qwen3vlmoe",
|
||||||
"glm4moelite",
|
"glm4moelite",
|
||||||
@@ -868,7 +869,12 @@ func (f GGML) SupportsFlashAttention() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
|
arch := f.KV().Architecture()
|
||||||
|
if slices.Contains([]string{"qwen35", "qwen35moe", "qwen3next"}, arch) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains([]string{"gemma2"}, arch) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -892,6 +898,7 @@ func (f GGML) FlashAttention() bool {
|
|||||||
"nemotron_h", "nemotron_h_moe",
|
"nemotron_h", "nemotron_h_moe",
|
||||||
"olmo3",
|
"olmo3",
|
||||||
"qwen3", "qwen3moe",
|
"qwen3", "qwen3moe",
|
||||||
|
"qwen35", "qwen35moe",
|
||||||
"qwen3next",
|
"qwen3next",
|
||||||
"qwen3vl", "qwen3vlmoe",
|
"qwen3vl", "qwen3vlmoe",
|
||||||
}, f.KV().String("general.architecture"))
|
}, f.KV().String("general.architecture"))
|
||||||
|
|||||||
@@ -245,7 +245,22 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
|||||||
padding := ggufPadding(offset, int64(alignment))
|
padding := ggufPadding(offset, int64(alignment))
|
||||||
llm.tensorOffset = uint64(offset + padding)
|
llm.tensorOffset = uint64(offset + padding)
|
||||||
|
|
||||||
|
// get file size to validate tensor bounds
|
||||||
|
fileSize, err := rs.Seek(0, io.SeekEnd)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to determine file size: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := rs.Seek(offset, io.SeekStart); err != nil {
|
||||||
|
return fmt.Errorf("failed to seek back after size check: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
for _, tensor := range llm.tensors {
|
for _, tensor := range llm.tensors {
|
||||||
|
tensorEnd := llm.tensorOffset + tensor.Offset + tensor.Size()
|
||||||
|
if tensorEnd > uint64(fileSize) {
|
||||||
|
return fmt.Errorf("tensor %q offset+size (%d) exceeds file size (%d)", tensor.Name, tensorEnd, fileSize)
|
||||||
|
}
|
||||||
|
|
||||||
offset, err := rs.Seek(0, io.SeekCurrent)
|
offset, err := rs.Seek(0, io.SeekCurrent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get current offset: %w", err)
|
return fmt.Errorf("failed to get current offset: %w", err)
|
||||||
|
|||||||
@@ -11,21 +11,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestWriteGGUF(t *testing.T) {
|
func TestWriteGGUF(t *testing.T) {
|
||||||
b := bytes.NewBuffer(make([]byte, 2*3))
|
tensorData := make([]byte, 2*3*4) // 6 F32 elements = 24 bytes
|
||||||
for range 8 {
|
for range 8 {
|
||||||
t.Run("shuffle", func(t *testing.T) {
|
t.Run("shuffle", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ts := []*Tensor{
|
ts := []*Tensor{
|
||||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
}
|
}
|
||||||
|
|
||||||
rand.Shuffle(len(ts), func(i, j int) {
|
rand.Shuffle(len(ts), func(i, j int) {
|
||||||
@@ -98,4 +98,32 @@ func TestWriteGGUF(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("truncated_tensor_data", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ts := []*Tensor{
|
||||||
|
{Name: "blk.0.attn.weight", Kind: 0, Shape: []uint64{512, 2}, WriterTo: bytes.NewBuffer(make([]byte, 32))},
|
||||||
|
}
|
||||||
|
|
||||||
|
w, err := os.CreateTemp(t.TempDir(), "truncated_*.bin")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer w.Close()
|
||||||
|
|
||||||
|
if err := WriteGGUF(w, KV{"general.architecture": "test"}, ts); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := os.Open(w.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
if _, err := Decode(r, -1); err == nil {
|
||||||
|
t.Error("Decode should reject GGUF files where tensor data extends beyond file size")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DefaultCheckpointCount = 32
|
DefaultCheckpointCount = 24
|
||||||
DefaultCheckpointMinPos = int32(16)
|
DefaultCheckpointMinPos = int32(16)
|
||||||
DefaultCheckpointInterval = int32(1280)
|
DefaultCheckpointInterval = int32(1664)
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
|
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
|
||||||
|
|||||||
@@ -74,8 +74,7 @@ type LlamaServer interface {
|
|||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
VRAMSize() uint64 // Total VRAM across all GPUs
|
MemorySize() (total, vram uint64)
|
||||||
TotalSize() uint64
|
|
||||||
VRAMByGPU(id ml.DeviceID) uint64
|
VRAMByGPU(id ml.DeviceID) uint64
|
||||||
Pid() int
|
Pid() int
|
||||||
GetPort() int
|
GetPort() int
|
||||||
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
|||||||
// Windows CUDA should not use mmap for best performance
|
// Windows CUDA should not use mmap for best performance
|
||||||
// Linux with a model larger than free space, mmap leads to thrashing
|
// Linux with a model larger than free space, mmap leads to thrashing
|
||||||
// For CPU loads we want the memory to be allocated, not FS cache
|
// For CPU loads we want the memory to be allocated, not FS cache
|
||||||
|
totalSize, _ := s.MemorySize()
|
||||||
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
||||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
|
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
|
||||||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
||||||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
||||||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
||||||
@@ -1453,10 +1453,12 @@ type ImageData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CompletionRequest struct {
|
type CompletionRequest struct {
|
||||||
Prompt string
|
Prompt string
|
||||||
Format json.RawMessage
|
Format json.RawMessage
|
||||||
Images []ImageData
|
Images []ImageData
|
||||||
Options *api.Options
|
Options *api.Options
|
||||||
|
Think *api.ThinkValue
|
||||||
|
ExplicitOptions map[string]struct{}
|
||||||
|
|
||||||
Grammar string // set before sending the request to the subprocess
|
Grammar string // set before sending the request to the subprocess
|
||||||
Shift bool
|
Shift bool
|
||||||
@@ -1518,6 +1520,7 @@ type CompletionResponse struct {
|
|||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||||
EvalCount int `json:"eval_count"`
|
EvalCount int `json:"eval_count"`
|
||||||
EvalDuration time.Duration `json:"eval_duration"`
|
EvalDuration time.Duration `json:"eval_duration"`
|
||||||
|
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||||
|
|
||||||
// Logprobs contains log probability information if requested
|
// Logprobs contains log probability information if requested
|
||||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||||
@@ -1848,17 +1851,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) VRAMSize() uint64 {
|
func (s *llmServer) MemorySize() (total, vram uint64) {
|
||||||
if s.mem == nil {
|
if s.mem == nil {
|
||||||
return 0
|
return 0, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
var mem uint64
|
|
||||||
|
|
||||||
for _, g := range s.mem.GPUs {
|
for _, g := range s.mem.GPUs {
|
||||||
mem += g.Size()
|
vram += g.Size()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
|
||||||
|
|
||||||
// Some elements are always on CPU. However, if we have allocated all layers
|
// Some elements are always on CPU. However, if we have allocated all layers
|
||||||
// on the GPU then include the CPU components as well, to represent complete offloading.
|
// on the GPU then include the CPU components as well, to represent complete offloading.
|
||||||
noCPULayers := true
|
noCPULayers := true
|
||||||
@@ -1869,25 +1872,11 @@ func (s *llmServer) VRAMSize() uint64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if noCPULayers {
|
if noCPULayers {
|
||||||
mem += s.mem.InputWeights
|
vram += s.mem.InputWeights
|
||||||
mem += s.mem.CPU.Graph
|
vram += s.mem.CPU.Graph
|
||||||
}
|
}
|
||||||
|
|
||||||
return mem
|
return total, vram
|
||||||
}
|
|
||||||
|
|
||||||
func (s *llmServer) TotalSize() uint64 {
|
|
||||||
if s.mem == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
mem := s.mem.InputWeights
|
|
||||||
mem += s.mem.CPU.Size()
|
|
||||||
for _, g := range s.mem.GPUs {
|
|
||||||
mem += g.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
return mem
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||||
|
|||||||
@@ -195,6 +195,7 @@ type Tensor interface {
|
|||||||
Concat(ctx Context, t2 Tensor, dim int) Tensor
|
Concat(ctx Context, t2 Tensor, dim int) Tensor
|
||||||
Rows(ctx Context, t2 Tensor) Tensor
|
Rows(ctx Context, t2 Tensor) Tensor
|
||||||
SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
|
SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
|
||||||
|
SetInplace(ctx Context, src Tensor, nb1, nb2, nb3, offset int) Tensor
|
||||||
Copy(ctx Context, t2 Tensor) Tensor
|
Copy(ctx Context, t2 Tensor) Tensor
|
||||||
Duplicate(ctx Context) Tensor
|
Duplicate(ctx Context) Tensor
|
||||||
|
|
||||||
|
|||||||
@@ -1345,6 +1345,21 @@ func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tenso
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tensor) SetInplace(ctx ml.Context, src ml.Tensor, nb1, nb2, nb3, offset int) ml.Tensor {
|
||||||
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_set_inplace(
|
||||||
|
ctx.(*Context).ctx,
|
||||||
|
t.t,
|
||||||
|
src.(*Tensor).t,
|
||||||
|
C.size_t(nb1),
|
||||||
|
C.size_t(nb2),
|
||||||
|
C.size_t(nb3),
|
||||||
|
C.size_t(offset),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
b: t.b,
|
b: t.b,
|
||||||
|
|||||||
@@ -2,595 +2,58 @@ package qwen3next
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
"github.com/ollama/ollama/kvcache"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ kvcache.Cache = (*HybridCache)(nil)
|
var (
|
||||||
|
_ kvcache.Cache = (*HybridCache)(nil)
|
||||||
|
_ kvcache.CheckpointCache = (*HybridCache)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
// HybridCache stores:
|
// HybridCache adapts the shared recurrent cache base for Qwen3-Next naming.
|
||||||
// - a standard causal KV cache for full attention layers
|
|
||||||
// - per-sequence conv state for linear attention layers
|
|
||||||
// - per-sequence delta state for linear attention layers
|
|
||||||
//
|
|
||||||
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
|
|
||||||
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
|
|
||||||
type HybridCache struct {
|
type HybridCache struct {
|
||||||
kv *kvcache.Causal
|
*kvcache.Recurrent
|
||||||
|
|
||||||
backend ml.Backend
|
|
||||||
dtype ml.DType
|
|
||||||
maxSequences int
|
|
||||||
|
|
||||||
// Conv state dimensions
|
|
||||||
convDim int // convKernelSize - 1
|
|
||||||
convChannels int // d_inner + 2 * num_k_heads * head_k_dim
|
|
||||||
|
|
||||||
// Delta state dimensions
|
|
||||||
deltaStateSize int // headVDim * headVDim * numVHeads
|
|
||||||
|
|
||||||
// slot mapping for recurrent state (copy-on-write)
|
|
||||||
slotForSeq map[int]int
|
|
||||||
refCount []int
|
|
||||||
freeSlots []int
|
|
||||||
|
|
||||||
// per-layer conv state buffers (allocated lazily)
|
|
||||||
convCtxs map[int]ml.Context
|
|
||||||
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
|
|
||||||
|
|
||||||
// per-layer delta state buffers (allocated lazily)
|
|
||||||
deltaCtxs map[int]ml.Context
|
|
||||||
deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
|
|
||||||
|
|
||||||
// recurrent checkpoints (per slot)
|
|
||||||
checkpointCount int
|
|
||||||
checkpointMinPos int32
|
|
||||||
checkpointInterval int32
|
|
||||||
checkpointCtxSize int
|
|
||||||
checkpoints map[int]*slotCheckpointStore
|
|
||||||
pendingRestore map[int]checkpointRestore
|
|
||||||
curCheckpointPos []int32
|
|
||||||
curCheckpointSlots map[int]int
|
|
||||||
reserveCheckpoints bool
|
|
||||||
checkpointConvCtxs map[int]ml.Context
|
|
||||||
checkpointDeltaCtxs map[int]ml.Context
|
|
||||||
checkpointReserved map[int]struct{}
|
|
||||||
|
|
||||||
// current forward batch (derived in StartForward)
|
|
||||||
curSeqs []int
|
|
||||||
curSlots []int
|
|
||||||
curSlotsInput ml.Tensor
|
|
||||||
curSeqTokens int
|
|
||||||
|
|
||||||
// track if EnsureWritable has been called for this forward pass
|
|
||||||
writableEnsured bool
|
|
||||||
writableError error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHybridCache(
|
func NewHybridCache(
|
||||||
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
|
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
|
||||||
convDim, convChannels, deltaStateSize int,
|
convDim, convChannels, deltaStateSize int,
|
||||||
) *HybridCache {
|
) *HybridCache {
|
||||||
return &HybridCache{
|
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
|
||||||
kv: kvcache.NewCausalCache(shift),
|
Shift: shift,
|
||||||
convDim: convDim,
|
ConvDim: convDim,
|
||||||
convChannels: convChannels,
|
ConvChannels: convChannels,
|
||||||
deltaStateSize: deltaStateSize,
|
RecurrentStateSize: deltaStateSize,
|
||||||
slotForSeq: make(map[int]int),
|
CheckpointLogPrefix: "qwen3next",
|
||||||
convCtxs: make(map[int]ml.Context),
|
})
|
||||||
convStates: make(map[int]ml.Tensor),
|
return &HybridCache{Recurrent: base}
|
||||||
deltaCtxs: make(map[int]ml.Context),
|
|
||||||
deltaStates: make(map[int]ml.Tensor),
|
|
||||||
checkpointCount: checkpointCountDefault,
|
|
||||||
checkpointMinPos: checkpointMinPosDefault,
|
|
||||||
checkpointInterval: checkpointIntervalDefault,
|
|
||||||
checkpoints: make(map[int]*slotCheckpointStore),
|
|
||||||
pendingRestore: make(map[int]checkpointRestore),
|
|
||||||
curCheckpointSlots: make(map[int]int),
|
|
||||||
checkpointConvCtxs: make(map[int]ml.Context),
|
|
||||||
checkpointDeltaCtxs: make(map[int]ml.Context),
|
|
||||||
checkpointReserved: make(map[int]struct{}),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
// DeltaState returns the delta state for current batch sequences as
|
||||||
c.backend = backend
|
// [headVDim, headVDim*numVHeads, nSeqs].
|
||||||
c.dtype = dtype
|
|
||||||
c.maxSequences = maxSequences
|
|
||||||
c.checkpoints = make(map[int]*slotCheckpointStore)
|
|
||||||
c.pendingRestore = make(map[int]checkpointRestore)
|
|
||||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
|
||||||
c.curCheckpointSlots = make(map[int]int)
|
|
||||||
c.checkpointReserved = make(map[int]struct{})
|
|
||||||
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
|
|
||||||
if c.checkpointCtxSize < 8 {
|
|
||||||
c.checkpointCtxSize = 8
|
|
||||||
}
|
|
||||||
|
|
||||||
// initialize slot allocator
|
|
||||||
c.refCount = make([]int, maxSequences)
|
|
||||||
c.freeSlots = c.freeSlots[:0]
|
|
||||||
for i := maxSequences - 1; i >= 0; i-- {
|
|
||||||
c.freeSlots = append(c.freeSlots, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) Close() {
|
|
||||||
for _, ctx := range c.convCtxs {
|
|
||||||
ctx.Close()
|
|
||||||
}
|
|
||||||
for _, ctx := range c.deltaCtxs {
|
|
||||||
ctx.Close()
|
|
||||||
}
|
|
||||||
for _, ctx := range c.checkpointConvCtxs {
|
|
||||||
ctx.Close()
|
|
||||||
}
|
|
||||||
for _, ctx := range c.checkpointDeltaCtxs {
|
|
||||||
ctx.Close()
|
|
||||||
}
|
|
||||||
c.kv.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
|
|
||||||
c.kv.SetConfig(config)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) SetLayer(layer int) {
|
|
||||||
c.kv.SetLayer(layer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
|
||||||
return c.kv.Get(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
|
||||||
c.kv.Put(ctx, key, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
|
||||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Derive equal-length sequence layout for recurrent layers
|
|
||||||
seqCounts := make(map[int]int)
|
|
||||||
c.curSeqs = c.curSeqs[:0]
|
|
||||||
for _, s := range batch.Sequences {
|
|
||||||
if _, ok := seqCounts[s]; !ok {
|
|
||||||
c.curSeqs = append(c.curSeqs, s)
|
|
||||||
}
|
|
||||||
seqCounts[s]++
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(c.curSeqs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nTokens := len(batch.Sequences)
|
|
||||||
nSeqs := len(c.curSeqs)
|
|
||||||
want := nTokens / nSeqs
|
|
||||||
for _, s := range c.curSeqs {
|
|
||||||
if seqCounts[s] != want {
|
|
||||||
return kvcache.ErrNotSupported
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.curSeqTokens = want
|
|
||||||
|
|
||||||
// When reserving memory for estimation, use fake slot assignments
|
|
||||||
if reserve {
|
|
||||||
c.curSlots = c.curSlots[:0]
|
|
||||||
slots := make([]int32, nSeqs)
|
|
||||||
for i := range nSeqs {
|
|
||||||
c.curSlots = append(c.curSlots, i)
|
|
||||||
slots[i] = int32(i)
|
|
||||||
}
|
|
||||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
|
||||||
c.reserveCheckpoints = true
|
|
||||||
c.planCheckpoints(batch)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure slots exist for sequences in this batch
|
|
||||||
c.curSlots = c.curSlots[:0]
|
|
||||||
var newSlots []int
|
|
||||||
for _, s := range c.curSeqs {
|
|
||||||
slot, ok := c.slotForSeq[s]
|
|
||||||
if !ok {
|
|
||||||
var err error
|
|
||||||
slot, err = c.allocSlot()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.slotForSeq[s] = slot
|
|
||||||
c.refCount[slot] = 1
|
|
||||||
newSlots = append(newSlots, slot)
|
|
||||||
}
|
|
||||||
c.curSlots = append(c.curSlots, slot)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Zero state for newly allocated slots
|
|
||||||
if len(newSlots) > 0 {
|
|
||||||
c.zeroSlots(ctx, newSlots)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a tensor for the current slots
|
|
||||||
slots := make([]int32, len(c.curSlots))
|
|
||||||
for i, v := range c.curSlots {
|
|
||||||
slots[i] = int32(v)
|
|
||||||
}
|
|
||||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
|
||||||
|
|
||||||
// Reset writable state for new forward pass
|
|
||||||
c.writableEnsured = false
|
|
||||||
c.writableError = nil
|
|
||||||
c.reserveCheckpoints = false
|
|
||||||
c.planCheckpoints(batch)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) allocSlot() (int, error) {
|
|
||||||
if len(c.freeSlots) == 0 {
|
|
||||||
return 0, kvcache.ErrKvCacheFull
|
|
||||||
}
|
|
||||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
|
||||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
|
||||||
return slot, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) freeSlot(slot int) {
|
|
||||||
if slot >= 0 && slot < c.maxSequences {
|
|
||||||
c.freeSlots = append(c.freeSlots, slot)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// zeroSlots zeros the recurrent state for the given slots across all layers.
|
|
||||||
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
|
|
||||||
if len(slots) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
inputCtx := ctx.Input()
|
|
||||||
|
|
||||||
slotIndices := make([]int32, len(slots))
|
|
||||||
for i, s := range slots {
|
|
||||||
slotIndices[i] = int32(s)
|
|
||||||
}
|
|
||||||
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
|
|
||||||
|
|
||||||
// Zero conv states
|
|
||||||
if len(c.convStates) > 0 {
|
|
||||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
|
|
||||||
for _, buf := range c.convStates {
|
|
||||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Zero delta states
|
|
||||||
if len(c.deltaStates) > 0 {
|
|
||||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
|
|
||||||
for _, buf := range c.deltaStates {
|
|
||||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnsureWritable ensures sequences have private slots (copy-on-write).
|
|
||||||
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
|
|
||||||
for i, seq := range c.curSeqs {
|
|
||||||
slot, ok := c.slotForSeq[seq]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if slot < 0 || slot >= len(c.refCount) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.refCount[slot] <= 1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
newSlot, err := c.allocSlot()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.refCount[slot]--
|
|
||||||
c.refCount[newSlot] = 1
|
|
||||||
c.slotForSeq[seq] = newSlot
|
|
||||||
c.curSlots[i] = newSlot
|
|
||||||
|
|
||||||
c.copyRecurrentState(ctx, slot, newSlot)
|
|
||||||
c.copyCheckpoints(ctx, slot, newSlot)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rebuild current slots tensor
|
|
||||||
slots := make([]int32, len(c.curSlots))
|
|
||||||
for i, v := range c.curSlots {
|
|
||||||
slots[i] = int32(v)
|
|
||||||
}
|
|
||||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
|
||||||
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
|
||||||
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
|
||||||
|
|
||||||
for _, buf := range c.convStates {
|
|
||||||
rows := buf.Rows(ctx, src)
|
|
||||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
|
||||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, buf := range c.deltaStates {
|
|
||||||
rows := buf.Rows(ctx, src)
|
|
||||||
rowsF32 := rows.Cast(ctx, ml.DTypeF32)
|
|
||||||
ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
|
||||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
|
||||||
|
|
||||||
// Copy-on-write for recurrent state
|
|
||||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
|
||||||
if c.validSlot(dstSlot) {
|
|
||||||
c.refCount[dstSlot]--
|
|
||||||
if c.refCount[dstSlot] <= 0 {
|
|
||||||
c.refCount[dstSlot] = 0
|
|
||||||
c.freeSlot(dstSlot)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(c.slotForSeq, dstSeq)
|
|
||||||
}
|
|
||||||
|
|
||||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.validSlot(srcSlot) {
|
|
||||||
c.slotForSeq[dstSeq] = srcSlot
|
|
||||||
c.refCount[srcSlot]++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) CanResume(seq int, pos int32) bool {
|
|
||||||
if !c.kv.CanResume(seq, pos) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if pos == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return c.hasCheckpoint(seq, pos)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
|
||||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
|
||||||
return kvcache.ErrNotSupported
|
|
||||||
}
|
|
||||||
|
|
||||||
if beginIndex > 0 {
|
|
||||||
restore, ok := c.pendingRestore[seq]
|
|
||||||
if !ok || restore.pos+1 != beginIndex {
|
|
||||||
return kvcache.ErrNotSupported
|
|
||||||
}
|
|
||||||
if !c.restoreComplete(restore) {
|
|
||||||
return kvcache.ErrNotSupported
|
|
||||||
}
|
|
||||||
// If the recurrent slot is shared, detach it before applying a restore.
|
|
||||||
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
|
||||||
newSlot, err := c.allocSlot()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ctx := c.backend.NewContext()
|
|
||||||
c.copyRecurrentState(ctx, slot, newSlot)
|
|
||||||
c.copyCheckpoints(ctx, slot, newSlot)
|
|
||||||
if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
|
|
||||||
ctx.Compute()
|
|
||||||
}
|
|
||||||
ctx.Close()
|
|
||||||
|
|
||||||
c.refCount[slot]--
|
|
||||||
c.refCount[newSlot] = 1
|
|
||||||
c.slotForSeq[seq] = newSlot
|
|
||||||
|
|
||||||
restore.slot = newSlot
|
|
||||||
c.pendingRestore[seq] = restore
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if beginIndex > 0 {
|
|
||||||
restore := c.pendingRestore[seq]
|
|
||||||
delete(c.pendingRestore, seq)
|
|
||||||
return c.applyCheckpointRestore(restore)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Removal invalidates recurrent state
|
|
||||||
slot, ok := c.slotForSeq[seq]
|
|
||||||
delete(c.pendingRestore, seq)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !c.validSlot(slot) {
|
|
||||||
delete(c.slotForSeq, seq)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c.refCount[slot]--
|
|
||||||
if c.refCount[slot] <= 0 {
|
|
||||||
c.refCount[slot] = 0
|
|
||||||
c.clearCheckpoints(slot)
|
|
||||||
c.freeSlot(slot)
|
|
||||||
}
|
|
||||||
delete(c.slotForSeq, seq)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) validSlot(slot int) bool {
|
|
||||||
return slot >= 0 && slot < len(c.refCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) slotsTensor() ml.Tensor {
|
|
||||||
return c.curSlotsInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
|
||||||
func (c *HybridCache) contiguousSlots() (int, bool) {
|
|
||||||
if len(c.curSlots) == 0 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
start := c.curSlots[0]
|
|
||||||
for i, s := range c.curSlots {
|
|
||||||
if s != start+i {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return start, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) seqTokens() int {
|
|
||||||
return c.curSeqTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) numSeqs() int {
|
|
||||||
return len(c.curSeqs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
|
|
||||||
if buf, ok := c.convStates[layer]; ok {
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := c.convCtxs[layer]; !ok {
|
|
||||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
|
|
||||||
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
|
|
||||||
c.convStates[layer] = buf
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
|
|
||||||
if buf, ok := c.deltaStates[layer]; ok {
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := c.deltaCtxs[layer]; !ok {
|
|
||||||
c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recurrent delta state must stay in F32.
|
|
||||||
buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
|
|
||||||
c.deltaStates[layer] = buf
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
|
|
||||||
if !c.writableEnsured {
|
|
||||||
needsWritable := false
|
|
||||||
for _, seq := range c.curSeqs {
|
|
||||||
slot, ok := c.slotForSeq[seq]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
|
||||||
needsWritable = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if needsWritable {
|
|
||||||
if err := c.EnsureWritable(ctx); err != nil {
|
|
||||||
c.writableError = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.writableEnsured = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
|
|
||||||
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
|
||||||
c.ensureWritableOnce(ctx)
|
|
||||||
|
|
||||||
if c.writableError != nil {
|
|
||||||
return nil, c.writableError
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := c.convBuffer(ctx, layer)
|
|
||||||
cur := buf.Rows(ctx, c.slotsTensor())
|
|
||||||
return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateConvState writes a new conv state for current batch sequences.
|
|
||||||
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
|
||||||
buf := c.convBuffer(ctx, layer)
|
|
||||||
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
|
|
||||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
|
||||||
if start, ok := c.contiguousSlots(); ok {
|
|
||||||
// Fast path: contiguous slots allow a single view + copy
|
|
||||||
offset := start * buf.Stride(1)
|
|
||||||
view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
|
|
||||||
ctx.Forward(srcF32.Copy(ctx, view))
|
|
||||||
} else {
|
|
||||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
|
||||||
}
|
|
||||||
|
|
||||||
c.captureConvCheckpoint(ctx, layer, srcF32)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
|
|
||||||
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
|
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
|
||||||
c.ensureWritableOnce(ctx)
|
return c.RecurrentState(ctx, layer, headVDim, headVDim*numVHeads)
|
||||||
|
|
||||||
if c.writableError != nil {
|
|
||||||
return nil, c.writableError
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := c.deltaBuffer(ctx, layer)
|
|
||||||
cur := buf.Rows(ctx, c.slotsTensor())
|
|
||||||
return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDeltaState writes a new delta state for current batch sequences.
|
// UpdateDeltaState writes a new delta state for current batch sequences.
|
||||||
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
|
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||||
buf := c.deltaBuffer(ctx, layer)
|
c.UpdateRecurrentState(ctx, layer, newState)
|
||||||
src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
|
}
|
||||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
|
||||||
if start, ok := c.contiguousSlots(); ok {
|
func (c *HybridCache) seqTokens() int {
|
||||||
// Fast path: contiguous slots allow a single view + copy
|
return c.SeqTokens()
|
||||||
offset := start * buf.Stride(1)
|
}
|
||||||
view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
|
|
||||||
ctx.Forward(srcF32.Copy(ctx, view))
|
func (c *HybridCache) numSeqs() int {
|
||||||
} else {
|
return c.NumSeqs()
|
||||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
}
|
||||||
|
|
||||||
|
// Keep qwen3next behavior for partial mid-sequence removals.
|
||||||
|
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||||
|
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||||
|
return kvcache.ErrNotSupported
|
||||||
}
|
}
|
||||||
|
return c.Recurrent.Remove(seq, beginIndex, endIndex)
|
||||||
c.captureDeltaCheckpoint(ctx, layer, srcF32)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
|
|
||||||
func (c *HybridCache) IsSupportedForBatch() bool {
|
|
||||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
|
||||||
func (c *HybridCache) Seqs() []int {
|
|
||||||
return slices.Clone(c.curSeqs)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,498 +0,0 @@
|
|||||||
package qwen3next
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log/slog"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
|
||||||
"github.com/ollama/ollama/ml"
|
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
checkpointCountDefault = 32
|
|
||||||
checkpointMinPosDefault = int32(16)
|
|
||||||
checkpointIntervalDefault = int32(1280)
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
|
|
||||||
// memory usage while preserving prefix reuse for recurrent state.
|
|
||||||
|
|
||||||
type checkpointEntry struct {
|
|
||||||
pos int32
|
|
||||||
conv map[int]ml.Tensor
|
|
||||||
delta map[int]ml.Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
type slotCheckpointStore struct {
|
|
||||||
entries []checkpointEntry
|
|
||||||
size int
|
|
||||||
next int
|
|
||||||
lastPos int32
|
|
||||||
}
|
|
||||||
|
|
||||||
type checkpointRestore struct {
|
|
||||||
slot int
|
|
||||||
idx int
|
|
||||||
pos int32
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSlotCheckpointStore(n int) *slotCheckpointStore {
|
|
||||||
entries := make([]checkpointEntry, n)
|
|
||||||
for i := range entries {
|
|
||||||
entries[i].pos = -1
|
|
||||||
}
|
|
||||||
return &slotCheckpointStore{
|
|
||||||
entries: entries,
|
|
||||||
lastPos: -1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *slotCheckpointStore) reset() {
|
|
||||||
s.size = 0
|
|
||||||
s.next = 0
|
|
||||||
s.lastPos = -1
|
|
||||||
for i := range s.entries {
|
|
||||||
s.entries[i].pos = -1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *slotCheckpointStore) record(pos int32) int {
|
|
||||||
if len(s.entries) == 0 {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
idx := s.next
|
|
||||||
s.next = (s.next + 1) % len(s.entries)
|
|
||||||
if s.size < len(s.entries) {
|
|
||||||
s.size++
|
|
||||||
}
|
|
||||||
s.entries[idx].pos = pos
|
|
||||||
s.lastPos = pos
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
|
|
||||||
bestIdx := -1
|
|
||||||
bestPos := int32(-1)
|
|
||||||
for i := range s.entries {
|
|
||||||
pos := s.entries[i].pos
|
|
||||||
if pos < 0 || pos >= targetPos {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if pos > bestPos {
|
|
||||||
bestPos = pos
|
|
||||||
bestIdx = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if bestIdx < 0 {
|
|
||||||
return -1, -1, false
|
|
||||||
}
|
|
||||||
return bestIdx, bestPos, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *slotCheckpointStore) pruneAfter(pos int32) {
|
|
||||||
if len(s.entries) == 0 {
|
|
||||||
s.size = 0
|
|
||||||
s.next = 0
|
|
||||||
s.lastPos = -1
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
size := 0
|
|
||||||
next := -1
|
|
||||||
minPos := int32(math.MaxInt32)
|
|
||||||
minIdx := 0
|
|
||||||
for i := range s.entries {
|
|
||||||
if s.entries[i].pos > pos {
|
|
||||||
s.entries[i].pos = -1
|
|
||||||
}
|
|
||||||
if s.entries[i].pos >= 0 {
|
|
||||||
size++
|
|
||||||
if s.entries[i].pos < minPos {
|
|
||||||
minPos = s.entries[i].pos
|
|
||||||
minIdx = i
|
|
||||||
}
|
|
||||||
} else if next == -1 {
|
|
||||||
next = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.size = size
|
|
||||||
if size == 0 {
|
|
||||||
s.next = 0
|
|
||||||
s.lastPos = -1
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if next != -1 {
|
|
||||||
s.next = next
|
|
||||||
} else {
|
|
||||||
// Full ring: overwrite the oldest checkpoint next.
|
|
||||||
s.next = minIdx
|
|
||||||
}
|
|
||||||
s.lastPos = pos
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
|
|
||||||
minPos = int32(math.MaxInt32)
|
|
||||||
maxPos = int32(-1)
|
|
||||||
for i := range s.entries {
|
|
||||||
pos := s.entries[i].pos
|
|
||||||
if pos < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
size++
|
|
||||||
if pos < minPos {
|
|
||||||
minPos = pos
|
|
||||||
}
|
|
||||||
if pos > maxPos {
|
|
||||||
maxPos = pos
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if size == 0 {
|
|
||||||
minPos = -1
|
|
||||||
maxPos = -1
|
|
||||||
}
|
|
||||||
return size, minPos, maxPos, s.lastPos
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) planCheckpoints(batch input.Batch) {
|
|
||||||
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
|
|
||||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
|
||||||
for k := range c.curCheckpointSlots {
|
|
||||||
delete(c.curCheckpointSlots, k)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if cap(c.curCheckpointPos) < len(c.curSeqs) {
|
|
||||||
c.curCheckpointPos = make([]int32, len(c.curSeqs))
|
|
||||||
} else {
|
|
||||||
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
|
|
||||||
}
|
|
||||||
for i := range c.curCheckpointPos {
|
|
||||||
c.curCheckpointPos[i] = -1
|
|
||||||
}
|
|
||||||
for k := range c.curCheckpointSlots {
|
|
||||||
delete(c.curCheckpointSlots, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
posMax := make(map[int]int32, len(c.curSeqs))
|
|
||||||
for i, seq := range batch.Sequences {
|
|
||||||
pos := batch.Positions[i]
|
|
||||||
if cur, ok := posMax[seq]; !ok || pos > cur {
|
|
||||||
posMax[seq] = pos
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, seq := range c.curSeqs {
|
|
||||||
pos, ok := posMax[seq]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if pos < c.checkpointMinPos {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
slot := c.curSlots[i]
|
|
||||||
store := c.checkpointStore(slot)
|
|
||||||
lastPos := store.lastPos
|
|
||||||
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
|
|
||||||
c.curCheckpointPos[i] = pos
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
|
|
||||||
store, ok := c.checkpoints[slot]
|
|
||||||
if ok {
|
|
||||||
return store
|
|
||||||
}
|
|
||||||
store = newSlotCheckpointStore(c.checkpointCount)
|
|
||||||
c.checkpoints[slot] = store
|
|
||||||
return store
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
|
|
||||||
if c.checkpointCount == 0 {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
if idx, ok := c.curCheckpointSlots[slot]; ok {
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
store := c.checkpointStore(slot)
|
|
||||||
idx := store.record(pos)
|
|
||||||
if idx >= 0 {
|
|
||||||
c.curCheckpointSlots[slot] = idx
|
|
||||||
}
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
|
|
||||||
if pos <= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
slot, ok := c.slotForSeq[seq]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
store, ok := c.checkpoints[slot]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
_, _, ok = store.bestIndex(pos)
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
|
|
||||||
if targetPos <= 0 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
slot, ok := c.slotForSeq[seq]
|
|
||||||
if !ok {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
store, ok := c.checkpoints[slot]
|
|
||||||
if !ok {
|
|
||||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
idx, pos, ok := store.bestIndex(targetPos)
|
|
||||||
if !ok {
|
|
||||||
size, minPos, maxPos, lastPos := store.window()
|
|
||||||
slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
|
|
||||||
"min", minPos, "max", maxPos, "last", lastPos)
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
c.pendingRestore[seq] = checkpointRestore{
|
|
||||||
slot: slot,
|
|
||||||
idx: idx,
|
|
||||||
pos: pos,
|
|
||||||
}
|
|
||||||
return pos + 1, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
|
|
||||||
entry, ok := c.restoreEntry(restore)
|
|
||||||
if !ok {
|
|
||||||
return kvcache.ErrNotSupported
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := c.backend.NewContext()
|
|
||||||
defer ctx.Close()
|
|
||||||
|
|
||||||
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
|
|
||||||
for layer, src := range entry.conv {
|
|
||||||
buf := c.convBuffer(ctx, layer)
|
|
||||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
|
||||||
}
|
|
||||||
for layer, src := range entry.delta {
|
|
||||||
buf := c.deltaBuffer(ctx, layer)
|
|
||||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(entry.conv) > 0 || len(entry.delta) > 0 {
|
|
||||||
ctx.Compute()
|
|
||||||
}
|
|
||||||
store := c.checkpoints[restore.slot]
|
|
||||||
store.pruneAfter(restore.pos)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
|
|
||||||
_, ok := c.restoreEntry(restore)
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
|
|
||||||
store, ok := c.checkpoints[restore.slot]
|
|
||||||
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
entry := &store.entries[restore.idx]
|
|
||||||
if entry.pos < 0 {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
if !c.entryComplete(entry) {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
return entry, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
|
|
||||||
for layer := range c.convStates {
|
|
||||||
if entry.conv == nil || entry.conv[layer] == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for layer := range c.deltaStates {
|
|
||||||
if entry.delta == nil || entry.delta[layer] == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) clearCheckpoints(slot int) {
|
|
||||||
if store, ok := c.checkpoints[slot]; ok {
|
|
||||||
store.reset()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
|
|
||||||
if c.checkpointCount == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
srcStore, ok := c.checkpoints[srcSlot]
|
|
||||||
if !ok || srcStore.size == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dstStore := c.checkpointStore(dstSlot)
|
|
||||||
dstStore.size = srcStore.size
|
|
||||||
dstStore.next = srcStore.next
|
|
||||||
dstStore.lastPos = srcStore.lastPos
|
|
||||||
|
|
||||||
for i := range srcStore.entries {
|
|
||||||
srcEntry := &srcStore.entries[i]
|
|
||||||
dstEntry := &dstStore.entries[i]
|
|
||||||
dstEntry.pos = srcEntry.pos
|
|
||||||
if srcEntry.conv != nil {
|
|
||||||
if dstEntry.conv == nil {
|
|
||||||
dstEntry.conv = make(map[int]ml.Tensor)
|
|
||||||
}
|
|
||||||
for layer, src := range srcEntry.conv {
|
|
||||||
dst := c.ensureCheckpointConv(layer, dstEntry)
|
|
||||||
ctx.Forward(src.Copy(ctx, dst))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if srcEntry.delta != nil {
|
|
||||||
if dstEntry.delta == nil {
|
|
||||||
dstEntry.delta = make(map[int]ml.Tensor)
|
|
||||||
}
|
|
||||||
for layer, src := range srcEntry.delta {
|
|
||||||
dst := c.ensureCheckpointDelta(layer, dstEntry)
|
|
||||||
ctx.Forward(src.Copy(ctx, dst))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
|
||||||
if c.checkpointCount == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if c.reserveCheckpoints {
|
|
||||||
c.reserveCheckpointConv(layer)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(c.curCheckpointPos) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for i, pos := range c.curCheckpointPos {
|
|
||||||
if pos < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
slot := c.curSlots[i]
|
|
||||||
idx := c.checkpointIndexForSlot(slot, pos)
|
|
||||||
if idx < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entry := &c.checkpoints[slot].entries[idx]
|
|
||||||
dst := c.ensureCheckpointConv(layer, entry)
|
|
||||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
|
||||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
|
||||||
if c.checkpointCount == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if c.reserveCheckpoints {
|
|
||||||
c.reserveCheckpointDelta(layer)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(c.curCheckpointPos) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for i, pos := range c.curCheckpointPos {
|
|
||||||
if pos < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
slot := c.curSlots[i]
|
|
||||||
idx := c.checkpointIndexForSlot(slot, pos)
|
|
||||||
if idx < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entry := &c.checkpoints[slot].entries[idx]
|
|
||||||
dst := c.ensureCheckpointDelta(layer, entry)
|
|
||||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
|
||||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
|
|
||||||
if entry.conv == nil {
|
|
||||||
entry.conv = make(map[int]ml.Tensor)
|
|
||||||
}
|
|
||||||
if t, ok := entry.conv[layer]; ok {
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
ctx, ok := c.checkpointConvCtxs[layer]
|
|
||||||
if !ok {
|
|
||||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
|
||||||
c.checkpointConvCtxs[layer] = ctx
|
|
||||||
}
|
|
||||||
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
|
|
||||||
entry.conv[layer] = t
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
|
|
||||||
if entry.delta == nil {
|
|
||||||
entry.delta = make(map[int]ml.Tensor)
|
|
||||||
}
|
|
||||||
if t, ok := entry.delta[layer]; ok {
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
ctx, ok := c.checkpointDeltaCtxs[layer]
|
|
||||||
if !ok {
|
|
||||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
|
||||||
c.checkpointDeltaCtxs[layer] = ctx
|
|
||||||
}
|
|
||||||
t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
|
|
||||||
entry.delta[layer] = t
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) reserveCheckpointConv(layer int) {
|
|
||||||
key := checkpointReserveKey(layer, 0)
|
|
||||||
if _, ok := c.checkpointReserved[key]; ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for slot := range c.maxSequences {
|
|
||||||
store := c.checkpointStore(slot)
|
|
||||||
for i := range store.entries {
|
|
||||||
entry := &store.entries[i]
|
|
||||||
_ = c.ensureCheckpointConv(layer, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.checkpointReserved[key] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HybridCache) reserveCheckpointDelta(layer int) {
|
|
||||||
key := checkpointReserveKey(layer, 1)
|
|
||||||
if _, ok := c.checkpointReserved[key]; ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for slot := range c.maxSequences {
|
|
||||||
store := c.checkpointStore(slot)
|
|
||||||
for i := range store.entries {
|
|
||||||
entry := &store.entries[i]
|
|
||||||
_ = c.ensureCheckpointDelta(layer, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.checkpointReserved[key] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkpointReserveKey(layer int, kind int) int {
|
|
||||||
return layer*2 + kind
|
|
||||||
}
|
|
||||||
@@ -1,300 +0,0 @@
|
|||||||
package qwen3next
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"math"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
|
||||||
"github.com/ollama/ollama/ml"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestBackend(tb testing.TB) ml.Backend {
|
|
||||||
tb.Helper()
|
|
||||||
|
|
||||||
f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
|
|
||||||
if err != nil {
|
|
||||||
tb.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
|
|
||||||
_ = f.Close()
|
|
||||||
tb.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := f.Close(); err != nil {
|
|
||||||
tb.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
|
|
||||||
if err != nil {
|
|
||||||
tb.Fatal(err)
|
|
||||||
}
|
|
||||||
tb.Cleanup(func() {
|
|
||||||
b.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
|
|
||||||
store := newSlotCheckpointStore(2)
|
|
||||||
store.record(10)
|
|
||||||
store.record(20)
|
|
||||||
|
|
||||||
_, pos, ok := store.bestIndex(15)
|
|
||||||
if !ok || pos != 10 {
|
|
||||||
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
store.record(30) // overwrite oldest (10)
|
|
||||||
|
|
||||||
if _, _, ok := store.bestIndex(15); ok {
|
|
||||||
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, pos, ok = store.bestIndex(40)
|
|
||||||
if !ok || pos != 30 {
|
|
||||||
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHybridCachePrepareRestore(t *testing.T) {
|
|
||||||
cache := NewHybridCache(nil, 1, 1, 1)
|
|
||||||
cache.checkpointCount = 3
|
|
||||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
|
||||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
|
||||||
|
|
||||||
cache.slotForSeq[1] = 0
|
|
||||||
store := cache.checkpointStore(0)
|
|
||||||
store.record(5)
|
|
||||||
store.record(9)
|
|
||||||
store.record(15)
|
|
||||||
|
|
||||||
restorePos, ok := cache.PrepareRestore(1, 12)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected restore ok")
|
|
||||||
}
|
|
||||||
if restorePos != 10 {
|
|
||||||
t.Fatalf("expected restorePos 10, got %d", restorePos)
|
|
||||||
}
|
|
||||||
rest, ok := cache.pendingRestore[1]
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected pending restore entry")
|
|
||||||
}
|
|
||||||
if rest.pos != 9 {
|
|
||||||
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
|
|
||||||
store := newSlotCheckpointStore(3)
|
|
||||||
store.record(10)
|
|
||||||
store.record(20)
|
|
||||||
store.record(30)
|
|
||||||
|
|
||||||
store.pruneAfter(20)
|
|
||||||
|
|
||||||
if store.lastPos != 20 {
|
|
||||||
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, pos, ok := store.bestIndex(25)
|
|
||||||
if !ok || pos != 20 {
|
|
||||||
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, pos, ok = store.bestIndex(35)
|
|
||||||
if !ok || pos != 20 {
|
|
||||||
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
|
|
||||||
backend := newTestBackend(t)
|
|
||||||
|
|
||||||
cache := NewHybridCache(nil, 1, 2, 2)
|
|
||||||
cache.Init(backend, ml.DTypeF16, 2, 8, 2)
|
|
||||||
|
|
||||||
cache.slotForSeq[1] = 0
|
|
||||||
cache.slotForSeq[2] = 0
|
|
||||||
cache.refCount[0] = 2
|
|
||||||
cache.refCount[1] = 0
|
|
||||||
cache.freeSlots = []int{1}
|
|
||||||
|
|
||||||
store := cache.checkpointStore(0)
|
|
||||||
idx := store.record(9)
|
|
||||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
|
||||||
|
|
||||||
if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
|
|
||||||
t.Fatalf("Remove failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cache.slotForSeq[1] == cache.slotForSeq[2] {
|
|
||||||
t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
|
|
||||||
}
|
|
||||||
if cache.slotForSeq[1] != 1 {
|
|
||||||
t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
|
|
||||||
}
|
|
||||||
if cache.slotForSeq[2] != 0 {
|
|
||||||
t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
|
|
||||||
}
|
|
||||||
if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
|
|
||||||
t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
|
|
||||||
}
|
|
||||||
if _, ok := cache.pendingRestore[1]; ok {
|
|
||||||
t.Fatalf("expected pending restore to be cleared")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
|
|
||||||
cache := NewHybridCache(nil, 1, 2, 2)
|
|
||||||
cache.checkpointCount = 3
|
|
||||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
|
||||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
|
||||||
|
|
||||||
cache.slotForSeq[1] = 0
|
|
||||||
cache.refCount = []int{1}
|
|
||||||
cache.freeSlots = nil
|
|
||||||
|
|
||||||
// Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
|
|
||||||
cache.convStates[0] = nil // placeholder to indicate layer 0 exists
|
|
||||||
cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
|
|
||||||
|
|
||||||
store := cache.checkpointStore(0)
|
|
||||||
idx := store.record(9)
|
|
||||||
entry := &store.entries[idx]
|
|
||||||
// Only set conv checkpoint, not delta - making it incomplete
|
|
||||||
entry.conv = map[int]ml.Tensor{0: nil}
|
|
||||||
// entry.delta is not set, so checkpoint is incomplete
|
|
||||||
|
|
||||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
|
||||||
|
|
||||||
err := cache.Remove(1, 10, math.MaxInt32)
|
|
||||||
if !errors.Is(err, kvcache.ErrNotSupported) {
|
|
||||||
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
|
|
||||||
cache := NewHybridCache(nil, 1, 2, 2)
|
|
||||||
cache.checkpointCount = 3
|
|
||||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
|
||||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
|
||||||
|
|
||||||
cache.slotForSeq[1] = 0
|
|
||||||
cache.refCount = []int{1}
|
|
||||||
cache.freeSlots = nil
|
|
||||||
|
|
||||||
// Don't set convStates/deltaStates - with no layers to check,
|
|
||||||
// entryComplete will return true as long as entry.pos >= 0
|
|
||||||
|
|
||||||
store := cache.checkpointStore(0)
|
|
||||||
idx := store.record(9)
|
|
||||||
|
|
||||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
|
||||||
|
|
||||||
// Test that restoreComplete returns true when no layers need checkpoints
|
|
||||||
restore := cache.pendingRestore[1]
|
|
||||||
if !cache.restoreComplete(restore) {
|
|
||||||
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
|
|
||||||
// Test that ring buffer wrap-around reuses entries without clearing maps.
|
|
||||||
store := newSlotCheckpointStore(3)
|
|
||||||
|
|
||||||
// Fill the buffer
|
|
||||||
store.record(10)
|
|
||||||
store.record(20)
|
|
||||||
store.record(30)
|
|
||||||
|
|
||||||
// Create fake tensor data in the first entry's maps
|
|
||||||
store.entries[0].conv = make(map[int]ml.Tensor)
|
|
||||||
store.entries[0].conv[0] = nil // Simulated tensor reference
|
|
||||||
store.entries[0].delta = make(map[int]ml.Tensor)
|
|
||||||
store.entries[0].delta[0] = nil // Simulated tensor reference
|
|
||||||
|
|
||||||
// Record another entry, which should wrap around and overwrite entry 0
|
|
||||||
store.record(40)
|
|
||||||
|
|
||||||
// Verify the maps are still present (we reuse tensors)
|
|
||||||
if store.entries[0].conv == nil {
|
|
||||||
t.Fatalf("expected conv map to be preserved on reuse")
|
|
||||||
}
|
|
||||||
if store.entries[0].delta == nil {
|
|
||||||
t.Fatalf("expected delta map to be preserved on reuse")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the new position was recorded
|
|
||||||
if store.entries[0].pos != 40 {
|
|
||||||
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
|
|
||||||
// Test behavior when buffer is exactly at capacity
|
|
||||||
store := newSlotCheckpointStore(2)
|
|
||||||
|
|
||||||
idx1 := store.record(10)
|
|
||||||
idx2 := store.record(20)
|
|
||||||
|
|
||||||
if idx1 != 0 || idx2 != 1 {
|
|
||||||
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
|
|
||||||
}
|
|
||||||
|
|
||||||
if store.size != 2 {
|
|
||||||
t.Fatalf("expected size 2, got %d", store.size)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify both checkpoints are accessible
|
|
||||||
_, pos1, ok1 := store.bestIndex(15)
|
|
||||||
_, pos2, ok2 := store.bestIndex(25)
|
|
||||||
|
|
||||||
if !ok1 || pos1 != 10 {
|
|
||||||
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
|
|
||||||
}
|
|
||||||
if !ok2 || pos2 != 20 {
|
|
||||||
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
|
|
||||||
// Test behavior with zero-size buffer
|
|
||||||
store := newSlotCheckpointStore(0)
|
|
||||||
|
|
||||||
idx := store.record(10)
|
|
||||||
if idx != -1 {
|
|
||||||
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, ok := store.bestIndex(15)
|
|
||||||
if ok {
|
|
||||||
t.Fatalf("expected no checkpoint for empty buffer")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
|
|
||||||
// Test pruning that removes all checkpoints
|
|
||||||
store := newSlotCheckpointStore(3)
|
|
||||||
store.record(10)
|
|
||||||
store.record(20)
|
|
||||||
store.record(30)
|
|
||||||
|
|
||||||
// Prune everything by setting threshold below all positions
|
|
||||||
store.pruneAfter(5)
|
|
||||||
|
|
||||||
if store.size != 0 {
|
|
||||||
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
|
|
||||||
}
|
|
||||||
// When all checkpoints are pruned, lastPos is reset to -1
|
|
||||||
if store.lastPos != -1 {
|
|
||||||
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, ok := store.bestIndex(100)
|
|
||||||
if ok {
|
|
||||||
t.Fatalf("expected no checkpoint after pruning all")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -37,10 +37,12 @@ type GatedDeltaNet struct {
|
|||||||
// Optimized path: pre-split QKV and gate
|
// Optimized path: pre-split QKV and gate
|
||||||
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
|
SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
|
||||||
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
|
SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
|
||||||
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
|
SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha (legacy qwen3next)
|
||||||
|
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||||
|
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
|
||||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||||
|
|
||||||
@@ -96,7 +98,6 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
|||||||
headVDim := opts.ssmDInner / numVHeads
|
headVDim := opts.ssmDInner / numVHeads
|
||||||
convKernelSize := opts.convKernelSize
|
convKernelSize := opts.convKernelSize
|
||||||
|
|
||||||
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
|
|
||||||
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
|
qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
|
||||||
|
|
||||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||||
@@ -106,24 +107,52 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
|||||||
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
|
qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
|
||||||
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
|
z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
|
||||||
|
|
||||||
baNewDim := 2 * numVHeads / numKHeads
|
var beta ml.Tensor
|
||||||
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
|
var alpha ml.Tensor
|
||||||
|
switch {
|
||||||
|
case gdn.SSMBetaAlpha != nil:
|
||||||
|
// Legacy qwen3next path: in_proj_ba packs beta/alpha grouped by K-head.
|
||||||
|
mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
|
||||||
|
baNewDim := 2 * numVHeads / numKHeads
|
||||||
|
mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
|
||||||
|
|
||||||
// Split beta and alpha
|
betaSize := numVHeads / numKHeads
|
||||||
betaSize := numVHeads / numKHeads
|
alphaSize := numVHeads / numKHeads
|
||||||
alphaSize := numVHeads / numKHeads
|
|
||||||
|
|
||||||
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
|
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
|
||||||
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
|
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
|
||||||
|
|
||||||
// Reshape to merge head dimensions
|
// Keep beta layout consistent with qwen35.
|
||||||
beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
|
// [1, numVHeads, nSeqTokens, nSeqs]
|
||||||
alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
|
beta = b.Contiguous(ctx, 1, numVHeads, nSeqTokens, nSeqs)
|
||||||
|
alpha = a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||||
|
|
||||||
|
case gdn.SSMBeta != nil && gdn.SSMAlpha != nil:
|
||||||
|
// qwen35 path: beta/alpha are separate projections.
|
||||||
|
beta = gdn.SSMBeta.Forward(ctx, hiddenStates).Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
|
||||||
|
alpha = gdn.SSMAlpha.Forward(ctx, hiddenStates).Reshape(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
||||||
|
}
|
||||||
|
if gdn.SSMDT == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMA == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
|
||||||
|
}
|
||||||
|
|
||||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||||
alphaSoftplus := alphaBiased.Softplus(ctx)
|
alphaSoftplus := alphaBiased.Softplus(ctx)
|
||||||
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
|
gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
|
||||||
|
gate = gate.Reshape(ctx, 1, numVHeads, nSeqTokens, nSeqs)
|
||||||
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
|
qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
|
||||||
|
|
||||||
// Get conv state from cache
|
// Get conv state from cache
|
||||||
@@ -172,16 +201,20 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
|||||||
|
|
||||||
// Repeat interleave Q and K if numKHeads != numVHeads
|
// Repeat interleave Q and K if numKHeads != numVHeads
|
||||||
if numKHeads != numVHeads {
|
if numKHeads != numVHeads {
|
||||||
repeatFactor := numVHeads / numKHeads
|
if opts.vHeadReordered {
|
||||||
|
qConv = qConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
|
||||||
|
kConv = kConv.Repeat4D(ctx, headKDim, numVHeads, nSeqTokens, nSeqs)
|
||||||
|
} else {
|
||||||
|
repeatFactor := numVHeads / numKHeads
|
||||||
|
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||||
|
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
||||||
|
|
||||||
qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||||
kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
|
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
||||||
|
|
||||||
qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||||
kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
|
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
||||||
|
}
|
||||||
qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
|
||||||
kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Choose computation mode based on sequence length
|
// Choose computation mode based on sequence length
|
||||||
@@ -189,7 +222,9 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
|||||||
if nSeqTokens == 1 {
|
if nSeqTokens == 1 {
|
||||||
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
|
attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
|
||||||
} else {
|
} else {
|
||||||
// Use pre-computed masks from opts (created once in Model.Forward)
|
if opts.masks == nil {
|
||||||
|
opts.masks = createMasks(ctx)
|
||||||
|
}
|
||||||
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
|
attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -310,9 +345,9 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||||
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
|
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
|
||||||
gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
|
// gate/beta: [1, numVHeads, nTokens, nSeqs] -> [1, nTokens, numVHeads, nSeqs]
|
||||||
|
gate = gate.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
|
||||||
beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
beta = beta.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
|
||||||
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||||
|
|
||||||
// Compute padding
|
// Compute padding
|
||||||
@@ -324,7 +359,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
q = q.Pad(ctx, 0, pad, 0, 0)
|
q = q.Pad(ctx, 0, pad, 0, 0)
|
||||||
k = k.Pad(ctx, 0, pad, 0, 0)
|
k = k.Pad(ctx, 0, pad, 0, 0)
|
||||||
v = v.Pad(ctx, 0, pad, 0, 0)
|
v = v.Pad(ctx, 0, pad, 0, 0)
|
||||||
gate = gate.Pad(ctx, pad, 0, 0, 0)
|
gate = gate.Pad(ctx, 0, pad, 0, 0)
|
||||||
beta = beta.Pad(ctx, 0, pad, 0, 0)
|
beta = beta.Pad(ctx, 0, pad, 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -344,10 +379,12 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||||
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
|
vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
|
||||||
|
|
||||||
gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
// Reshape gate and cumsum over chunk axis.
|
||||||
|
// [1, chunkSize, nChunks, H*nSeqs] -> transpose -> [chunkSize, 1, nChunks, H*nSeqs]
|
||||||
|
gate = gate.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
|
||||||
|
|
||||||
// g_cumsum = cumsum(gate)
|
// g_cumsum = cumsum(gate)
|
||||||
gCumsum := gate.CumSum(ctx)
|
gCumsum := gate.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs).CumSum(ctx)
|
||||||
|
|
||||||
// Compute decay mask
|
// Compute decay mask
|
||||||
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
|
||||||
@@ -411,60 +448,64 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
|
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
|
||||||
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||||
|
|
||||||
// Process chunks and update state
|
// Process chunks and update state.
|
||||||
var coreAttnOut ml.Tensor
|
// Keep a transposed view of v and recurrent state across chunks so the
|
||||||
newState := state
|
// chunk loop does not need extra transpose+contiguous nodes.
|
||||||
|
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
||||||
|
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||||
|
|
||||||
for chunk := range nChunks {
|
for chunk := range nChunks {
|
||||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
|
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
|
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
|
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
|
||||||
|
|
||||||
// state^T - permute is needed but Contiguous creates a copy
|
// v'_t = k_cumdecay @ state_t
|
||||||
stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
vTPrime := kCumdecayChunk.Mulmat(ctx, stateT)
|
||||||
|
|
||||||
// v_prime = k_cumdecay @ state
|
// v_t_new = v_t - v'_t
|
||||||
vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
|
vTNewChunk := vTChunk.Sub(ctx, vTPrime)
|
||||||
|
|
||||||
// v_new = v - v_prime
|
|
||||||
vNew := vChunk.Sub(ctx, vPrime)
|
|
||||||
vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
|
||||||
|
|
||||||
// attn_inter = (q * g_exp) @ state
|
// attn_inter = (q * g_exp) @ state
|
||||||
qGExp := qChunk.Mul(ctx, gExpChunk)
|
qGExp := qChunk.Mul(ctx, gExpChunk)
|
||||||
attnInter := stateT.Mulmat(ctx, qGExp)
|
attnInter := stateT.Mulmat(ctx, qGExp)
|
||||||
|
|
||||||
// core_attn_out = attn_inter + attn @ v_new
|
// core_attn_out = attn_inter + attn @ v_new
|
||||||
vAttn := vNewT.Mulmat(ctx, attnChunk)
|
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
||||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||||
|
|
||||||
if coreAttnOut == nil {
|
v = v.SetInplace(
|
||||||
coreAttnOut = coreAttnOutChunk
|
ctx,
|
||||||
} else {
|
coreAttnOutChunk,
|
||||||
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
|
v.Stride(1),
|
||||||
}
|
v.Stride(2),
|
||||||
|
v.Stride(3),
|
||||||
|
chunk*v.Stride(2),
|
||||||
|
)
|
||||||
|
|
||||||
// Update state for next chunk
|
// Update state for next chunk
|
||||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
|
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
|
// kgdmulvnew = key_gdiff_t @ v_new_t
|
||||||
|
kgdMulVNew := kGDiffChunkT.Mulmat(ctx, vTNewChunk)
|
||||||
|
|
||||||
// state = state * g_last + kgdmulvnew
|
// stateT = stateT * g_last + kgdmulvnew
|
||||||
gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
stateT = stateT.Mul(ctx, gExpLastChunk)
|
||||||
newState = newState.Mul(ctx, gExpLastReshaped)
|
stateT = stateT.Add(ctx, kgdMulVNew)
|
||||||
newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Final reshape
|
// Final reshape
|
||||||
coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||||
|
|
||||||
// Slice to remove padding
|
// Slice to remove padding
|
||||||
if pad > 0 {
|
if pad > 0 {
|
||||||
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
|
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert stateT back to cache layout [S_v, S_v, H_v, nSeqs]
|
||||||
|
newState := stateT.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||||
|
|
||||||
// Update delta state in cache
|
// Update delta state in cache
|
||||||
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
|
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
package qwen3next
|
package qwen3next
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"cmp"
|
"cmp"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"image"
|
||||||
"math"
|
"math"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs"
|
"github.com/ollama/ollama/fs"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
@@ -11,6 +14,7 @@ import (
|
|||||||
"github.com/ollama/ollama/ml/nn/rope"
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
|
"github.com/ollama/ollama/model/models/qwen3vl"
|
||||||
"github.com/ollama/ollama/tokenizer"
|
"github.com/ollama/ollama/tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,10 +45,15 @@ type Options struct {
|
|||||||
ssmNGroup int // num_k_heads
|
ssmNGroup int // num_k_heads
|
||||||
ssmDtRank int // num_v_heads
|
ssmDtRank int // num_v_heads
|
||||||
convKernelSize int // SSM conv kernel size
|
convKernelSize int // SSM conv kernel size
|
||||||
|
vHeadReordered bool
|
||||||
|
|
||||||
// Per-layer type from GGUF metadata
|
// Per-layer type from GGUF metadata
|
||||||
isRecurrent []bool
|
isRecurrent []bool
|
||||||
|
|
||||||
|
// RoPE mode config (used by qwen35/qwen35moe)
|
||||||
|
mropeSections []int
|
||||||
|
mropeInterleaved bool
|
||||||
|
|
||||||
// Pre-computed masks for chunked attention (created once per forward pass)
|
// Pre-computed masks for chunked attention (created once per forward pass)
|
||||||
masks *Masks
|
masks *Masks
|
||||||
}
|
}
|
||||||
@@ -54,7 +63,17 @@ func (o Options) headDim() int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||||
opts := []func(*rope.Options){rope.WithTypeNeoX()}
|
var opts []func(*rope.Options)
|
||||||
|
if len(o.mropeSections) > 0 {
|
||||||
|
if o.mropeInterleaved {
|
||||||
|
opts = append(opts, rope.WithInterleaveMRoPE(o.mropeSections))
|
||||||
|
} else {
|
||||||
|
opts = append(opts, rope.WithMRoPE(o.mropeSections))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
opts = append(opts, rope.WithTypeNeoX())
|
||||||
|
}
|
||||||
|
|
||||||
if o.ropeType == "yarn" {
|
if o.ropeType == "yarn" {
|
||||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
@@ -214,20 +233,190 @@ type Model struct {
|
|||||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||||
|
|
||||||
Layers []Layer `gguf:"blk"`
|
Layers []Layer `gguf:"blk"`
|
||||||
|
Vision *qwen3vl.VisionModel `gguf:"v"`
|
||||||
|
|
||||||
|
ImageProcessor *qwen3vl.ImageProcessor
|
||||||
|
|
||||||
*Options
|
*Options
|
||||||
|
|
||||||
|
positionCache []int32
|
||||||
|
imageToken int32
|
||||||
|
visionStart int32
|
||||||
|
visionEnd int32
|
||||||
|
spatialMergeSize uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) mapPosition(id int32) int32 {
|
||||||
|
if id < int32(len(m.positionCache)) {
|
||||||
|
return m.positionCache[id]
|
||||||
|
}
|
||||||
|
if len(m.positionCache) > 0 {
|
||||||
|
return id - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) buildPositions(ctx ml.Context, batch input.Batch) ml.Tensor {
|
||||||
|
if len(m.mropeSections) == 0 {
|
||||||
|
return ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggml MRoPE expects [time, height, width, extra] for each token.
|
||||||
|
positionSlice := [][]int32{
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
make([]int32, len(batch.Positions)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, id := range batch.Positions {
|
||||||
|
p := m.mapPosition(id)
|
||||||
|
positionSlice[0][i] = p
|
||||||
|
positionSlice[1][i] = p
|
||||||
|
positionSlice[2][i] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Vision != nil {
|
||||||
|
for _, mi := range batch.Multimodal {
|
||||||
|
grid, ok := mi.Multimodal[0].Data.(*qwen3vl.Grid)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
w := max(1, grid.Width/int(m.spatialMergeSize))
|
||||||
|
for i := range mi.Multimodal[0].Tensor.Dim(1) {
|
||||||
|
positionSlice[1][mi.Index+i] += int32(i / w)
|
||||||
|
positionSlice[2][mi.Index+i] += int32(i % w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||||
|
if m.Vision == nil || m.ImageProcessor == nil || len(m.Vision.Layers) == 0 {
|
||||||
|
return nil, model.ErrNoVisionModel
|
||||||
|
}
|
||||||
|
|
||||||
|
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pixelValues, grid, err := m.ImageProcessor.ProcessImage(ctx, img)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
visionOutputs, deepstackVisualEmbeds := m.Vision.Forward(ctx, pixelValues, grid)
|
||||||
|
mm := []input.Multimodal{{Tensor: visionOutputs, Data: grid}}
|
||||||
|
for i := range deepstackVisualEmbeds {
|
||||||
|
mm = append(mm, input.Multimodal{Tensor: deepstackVisualEmbeds[i]})
|
||||||
|
}
|
||||||
|
|
||||||
|
return mm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
|
m.positionCache = m.positionCache[:0]
|
||||||
|
var result []*input.Input
|
||||||
|
appendInput := func(inp *input.Input, position int32) {
|
||||||
|
result = append(result, inp)
|
||||||
|
m.positionCache = append(m.positionCache, position)
|
||||||
|
}
|
||||||
|
|
||||||
|
var p int32
|
||||||
|
for _, inp := range inputs {
|
||||||
|
if inp.Multimodal == nil {
|
||||||
|
appendInput(inp, p)
|
||||||
|
p++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
grid := inp.Multimodal[0].Data.(*qwen3vl.Grid)
|
||||||
|
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
|
||||||
|
|
||||||
|
appendInput(&input.Input{
|
||||||
|
Token: m.visionStart,
|
||||||
|
SameBatch: tokensPerGrid + 1,
|
||||||
|
}, p)
|
||||||
|
p++
|
||||||
|
|
||||||
|
appendInput(&input.Input{
|
||||||
|
Token: m.imageToken,
|
||||||
|
Multimodal: inp.Multimodal,
|
||||||
|
MultimodalHash: inp.MultimodalHash,
|
||||||
|
}, p)
|
||||||
|
|
||||||
|
for range tokensPerGrid - 1 {
|
||||||
|
appendInput(&input.Input{
|
||||||
|
Token: m.imageToken,
|
||||||
|
}, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
gridSpan := max(grid.Width/int(m.spatialMergeSize), grid.Height/int(m.spatialMergeSize))
|
||||||
|
p = p + int32(gridSpan)
|
||||||
|
appendInput(&input.Input{
|
||||||
|
Token: m.visionEnd,
|
||||||
|
}, p)
|
||||||
|
p++
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
positions := m.buildPositions(ctx, batch)
|
||||||
|
|
||||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
if len(batch.Multimodal) > 0 {
|
||||||
|
hiddenStates = hiddenStates.Duplicate(ctx)
|
||||||
|
|
||||||
|
var deepstackVisualEmbeds []ml.Tensor
|
||||||
|
for _, mi := range batch.Multimodal {
|
||||||
|
visionOutputs := mi.Multimodal[0].Tensor
|
||||||
|
ctx.Forward(visionOutputs.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||||
|
|
||||||
|
if len(mi.Multimodal[1:]) > len(deepstackVisualEmbeds) {
|
||||||
|
deepstackVisualEmbeds = append(deepstackVisualEmbeds, make([]ml.Tensor, len(mi.Multimodal[1:])-len(deepstackVisualEmbeds))...)
|
||||||
|
}
|
||||||
|
for i, mm := range mi.Multimodal[1:] {
|
||||||
|
if deepstackVisualEmbeds[i] == nil {
|
||||||
|
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
|
||||||
|
}
|
||||||
|
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := m.Cache.(*HybridCache)
|
||||||
|
m.Options.masks = nil
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
cache.SetLayer(i)
|
||||||
|
|
||||||
|
var outputs ml.Tensor
|
||||||
|
if i == len(m.Layers)-1 {
|
||||||
|
outputs = batch.Outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if i < len(deepstackVisualEmbeds) {
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
|
}
|
||||||
|
|
||||||
cache := m.Cache.(*HybridCache)
|
cache := m.Cache.(*HybridCache)
|
||||||
|
|
||||||
// Create masks once per forward pass
|
// Masks are allocated lazily only for chunked recurrent prefill.
|
||||||
m.Options.masks = createMasks(ctx)
|
m.Options.masks = nil
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
cache.SetLayer(i)
|
cache.SetLayer(i)
|
||||||
@@ -248,11 +437,116 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
return m.Output.Forward(ctx, hiddenStates), nil
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) Validate() error {
|
||||||
|
if m.Options == nil {
|
||||||
|
return fmt.Errorf("qwen3next: missing model options")
|
||||||
|
}
|
||||||
|
if len(m.Layers) != len(m.Options.isRecurrent) {
|
||||||
|
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
if !m.Options.isRecurrent[i] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
gdn, ok := layer.Operator.(*GatedDeltaNet)
|
||||||
|
if !ok || gdn == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMDT == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMA == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
m.positionCache = nil
|
||||||
|
if len(m.mropeSections) > 0 {
|
||||||
|
shift = shift.Repeat(ctx, 1, 4).Reshape(ctx, -1)
|
||||||
|
}
|
||||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.Model = (*Model)(nil)
|
var (
|
||||||
|
_ model.Model = (*Model)(nil)
|
||||||
|
_ model.MultimodalProcessor = (*Model)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultVHeadReordered(arch string) bool {
|
||||||
|
return arch == "qwen35" || arch == "qwen35moe"
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
|
||||||
|
isRecurrent := make([]bool, numLayers)
|
||||||
|
|
||||||
|
hasZero := false
|
||||||
|
hasFull := false
|
||||||
|
for i := range numLayers {
|
||||||
|
if i >= len(headCountKV) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if headCountKV[i] == 0 {
|
||||||
|
isRecurrent[i] = true
|
||||||
|
hasZero = true
|
||||||
|
} else {
|
||||||
|
hasFull = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasZero && hasFull {
|
||||||
|
return isRecurrent, nil
|
||||||
|
}
|
||||||
|
if !hasFull {
|
||||||
|
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compatibility path: older imports store a scalar KV head count and omit
|
||||||
|
// per-layer recurrent flags. Derive the hybrid layout from the interval.
|
||||||
|
interval := int(fullAttentionInterval)
|
||||||
|
if interval == 0 {
|
||||||
|
interval = min(4, numLayers)
|
||||||
|
}
|
||||||
|
if interval <= 0 {
|
||||||
|
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
|
||||||
|
}
|
||||||
|
if interval > numLayers {
|
||||||
|
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasZero = false
|
||||||
|
hasFull = false
|
||||||
|
for i := range numLayers {
|
||||||
|
isRecurrent[i] = (i+1)%interval != 0
|
||||||
|
if isRecurrent[i] {
|
||||||
|
hasZero = true
|
||||||
|
} else {
|
||||||
|
hasFull = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasZero || !hasFull {
|
||||||
|
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
|
||||||
|
}
|
||||||
|
|
||||||
|
return isRecurrent, nil
|
||||||
|
}
|
||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
numLayers := int(c.Uint("block_count"))
|
numLayers := int(c.Uint("block_count"))
|
||||||
@@ -264,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
HeadCountKV() []uint64
|
HeadCountKV() []uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
var isRecurrent []bool
|
|
||||||
var headCountKV []uint64
|
var headCountKV []uint64
|
||||||
if hc, ok := c.(headCounts); ok {
|
if hc, ok := c.(headCounts); ok {
|
||||||
headCountKV = hc.HeadCountKV()
|
headCountKV = hc.HeadCountKV()
|
||||||
}
|
}
|
||||||
|
|
||||||
isRecurrent = make([]bool, numLayers)
|
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
|
||||||
hasZero := false
|
if err != nil {
|
||||||
hasFull := false
|
return nil, err
|
||||||
for i := range numLayers {
|
|
||||||
// If KV head count is 0, it's a recurrent layer
|
|
||||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
|
||||||
isRecurrent[i] = true
|
|
||||||
hasZero = true
|
|
||||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
|
||||||
hasFull = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !hasZero || !hasFull {
|
|
||||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if MoE
|
// Determine if MoE
|
||||||
@@ -303,6 +585,22 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mropeSections := c.Ints("mrope_sections", nil)
|
||||||
|
if len(mropeSections) == 0 {
|
||||||
|
mropeSections = c.Ints("rope.mrope_section", nil)
|
||||||
|
}
|
||||||
|
if len(mropeSections) == 0 {
|
||||||
|
mropeSections = c.Ints("rope.dimension_sections", nil)
|
||||||
|
}
|
||||||
|
if len(mropeSections) > 4 {
|
||||||
|
mropeSections = mropeSections[:4]
|
||||||
|
}
|
||||||
|
|
||||||
|
ropeType := c.String("rope.scaling.type")
|
||||||
|
if ropeType == "" {
|
||||||
|
ropeType = c.String("rope.type")
|
||||||
|
}
|
||||||
|
|
||||||
opts := &Options{
|
opts := &Options{
|
||||||
hiddenSize: int(c.Uint("embedding_length")),
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
numHeads: int(c.Uint("attention.head_count")),
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
@@ -318,7 +616,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
valueLength: int(c.Uint("attention.value_length")),
|
valueLength: int(c.Uint("attention.value_length")),
|
||||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
ropeType: c.String("rope.scaling.type"),
|
ropeType: ropeType,
|
||||||
ropeBase: c.Float("rope.freq_base"),
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||||
@@ -331,10 +629,19 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||||
|
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
|
||||||
isRecurrent: isRecurrent,
|
isRecurrent: isRecurrent,
|
||||||
|
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||||
|
for _, section := range mropeSections {
|
||||||
|
if !yield(int(section)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
||||||
}
|
}
|
||||||
if opts.numKVHeads == 0 {
|
if opts.numKVHeads == 0 {
|
||||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate cache dimensions
|
// Calculate cache dimensions
|
||||||
@@ -353,6 +660,19 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
|
return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var vision *qwen3vl.VisionModel
|
||||||
|
var imageProcessor *qwen3vl.ImageProcessor
|
||||||
|
if c.Uint("vision.block_count", 0) > 0 {
|
||||||
|
vision = qwen3vl.NewVisionModel(c)
|
||||||
|
processor := qwen3vl.NewImageProcessor(c)
|
||||||
|
imageProcessor = &processor
|
||||||
|
}
|
||||||
|
|
||||||
|
spatialMergeSize := c.Uint("vision.spatial_merge_size", 2)
|
||||||
|
if spatialMergeSize == 0 {
|
||||||
|
spatialMergeSize = 2
|
||||||
|
}
|
||||||
|
|
||||||
m := Model{
|
m := Model{
|
||||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||||
&tokenizer.Vocabulary{
|
&tokenizer.Vocabulary{
|
||||||
@@ -371,8 +691,14 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
},
|
},
|
||||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
),
|
),
|
||||||
Layers: layers,
|
Layers: layers,
|
||||||
Options: opts,
|
Vision: vision,
|
||||||
|
ImageProcessor: imageProcessor,
|
||||||
|
Options: opts,
|
||||||
|
imageToken: int32(c.Uint("image_token_id", 151655)),
|
||||||
|
visionStart: int32(c.Uint("vision_start_token_id", 151652)),
|
||||||
|
visionEnd: int32(c.Uint("vision_end_token_id", 151653)),
|
||||||
|
spatialMergeSize: spatialMergeSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
|
m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
|
||||||
@@ -380,5 +706,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
model.Register("qwen35", New)
|
||||||
|
model.Register("qwen35moe", New)
|
||||||
model.Register("qwen3next", New)
|
model.Register("qwen3next", New)
|
||||||
}
|
}
|
||||||
|
|||||||
65
model/models/qwen3next/model_new_test.go
Normal file
65
model/models/qwen3next/model_new_test.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package qwen3next
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, false, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, true, true, false, true, true, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, true, false, true, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
|
||||||
|
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("inferRecurrentLayers() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
|
||||||
|
t.Fatalf("unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultVHeadReordered(t *testing.T) {
|
||||||
|
if !defaultVHeadReordered("qwen35") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
|
||||||
|
}
|
||||||
|
if !defaultVHeadReordered("qwen35moe") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
|
||||||
|
}
|
||||||
|
if defaultVHeadReordered("qwen3next") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
101
model/models/qwen3next/model_posttokenize_test.go
Normal file
101
model/models/qwen3next/model_posttokenize_test.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package qwen3next
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml/backend/ggml"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
"github.com/ollama/ollama/model/models/qwen3vl"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeTensor struct {
|
||||||
|
*ggml.Tensor
|
||||||
|
dims []int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *fakeTensor) Dim(i int) int {
|
||||||
|
return t.dims[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeImageInput(hash uint64, width, height, tokens int) *input.Input {
|
||||||
|
return &input.Input{
|
||||||
|
Multimodal: []input.Multimodal{{
|
||||||
|
Tensor: &fakeTensor{dims: []int{1, tokens, 1, 1}},
|
||||||
|
Data: &qwen3vl.Grid{Width: width, Height: height},
|
||||||
|
}},
|
||||||
|
MultimodalHash: hash,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostTokenizeMultiImageSpans(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
imageToken: 10,
|
||||||
|
visionStart: 11,
|
||||||
|
visionEnd: 12,
|
||||||
|
spatialMergeSize: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs := []*input.Input{
|
||||||
|
{Token: 100},
|
||||||
|
makeImageInput(1, 8, 4, 4),
|
||||||
|
makeImageInput(2, 4, 8, 4),
|
||||||
|
{Token: 200},
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := m.PostTokenize(inputs)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PostTokenize() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []struct {
|
||||||
|
token int32
|
||||||
|
hash uint64
|
||||||
|
sameBatch int
|
||||||
|
hasMM bool
|
||||||
|
}{
|
||||||
|
{token: 100},
|
||||||
|
{token: 11, sameBatch: 5},
|
||||||
|
{token: 10, hash: 1, hasMM: true},
|
||||||
|
{token: 10},
|
||||||
|
{token: 10},
|
||||||
|
{token: 10},
|
||||||
|
{token: 12},
|
||||||
|
{token: 11, sameBatch: 5},
|
||||||
|
{token: 10, hash: 2, hasMM: true},
|
||||||
|
{token: 10},
|
||||||
|
{token: 10},
|
||||||
|
{token: 10},
|
||||||
|
{token: 12},
|
||||||
|
{token: 200},
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(got) != len(want) {
|
||||||
|
t.Fatalf("len(got) = %d, want %d", len(got), len(want))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range want {
|
||||||
|
if got[i].Token != want[i].token {
|
||||||
|
t.Fatalf("got[%d].Token = %d, want %d", i, got[i].Token, want[i].token)
|
||||||
|
}
|
||||||
|
if got[i].MultimodalHash != want[i].hash {
|
||||||
|
t.Fatalf("got[%d].MultimodalHash = %d, want %d", i, got[i].MultimodalHash, want[i].hash)
|
||||||
|
}
|
||||||
|
if got[i].SameBatch != want[i].sameBatch {
|
||||||
|
t.Fatalf("got[%d].SameBatch = %d, want %d", i, got[i].SameBatch, want[i].sameBatch)
|
||||||
|
}
|
||||||
|
hasMM := len(got[i].Multimodal) > 0
|
||||||
|
if hasMM != want[i].hasMM {
|
||||||
|
t.Fatalf("got[%d].hasMM = %v, want %v", i, hasMM, want[i].hasMM)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wantPositions := []int32{0, 1, 2, 2, 2, 2, 6, 7, 8, 8, 8, 8, 12, 13}
|
||||||
|
if len(m.positionCache) != len(wantPositions) {
|
||||||
|
t.Fatalf("len(positionCache) = %d, want %d", len(m.positionCache), len(wantPositions))
|
||||||
|
}
|
||||||
|
for i := range wantPositions {
|
||||||
|
if m.positionCache[i] != wantPositions[i] {
|
||||||
|
t.Fatalf("positionCache[%d] = %d, want %d", i, m.positionCache[i], wantPositions[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
45
model/models/qwen3next/model_validate_test.go
Normal file
45
model/models/qwen3next/model_validate_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package qwen3next
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Layers: []Layer{{
|
||||||
|
Operator: &GatedDeltaNet{
|
||||||
|
SSMQKV: &nn.Linear{},
|
||||||
|
SSMQKVGate: &nn.Linear{},
|
||||||
|
SSMBeta: &nn.Linear{},
|
||||||
|
SSMAlpha: &nn.Linear{},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
Options: &Options{
|
||||||
|
isRecurrent: []bool{true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Validate() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "missing ssm_dt") {
|
||||||
|
t.Fatalf("unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Layers: []Layer{{Operator: &FullAttention{}}},
|
||||||
|
Options: &Options{
|
||||||
|
isRecurrent: []bool{false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Validate(); err != nil {
|
||||||
|
t.Fatalf("Validate() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -24,8 +24,8 @@ type ImageProcessor struct {
|
|||||||
imageStd []float32
|
imageStd []float32
|
||||||
}
|
}
|
||||||
|
|
||||||
// newImageProcessor creates a new image processor with default values
|
// NewImageProcessor creates a new image processor with default values.
|
||||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
func NewImageProcessor(c fs.Config) ImageProcessor {
|
||||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||||
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||||
|
|
||||||
|
|||||||
@@ -56,60 +56,46 @@ var (
|
|||||||
tokenVisionEnd int32 = 151653
|
tokenVisionEnd int32 = 151653
|
||||||
)
|
)
|
||||||
|
|
||||||
type modelInput struct {
|
|
||||||
*input.Input
|
|
||||||
position int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostTokenize arranges Qwen 3 VL's inputs for the forward pass
|
// PostTokenize arranges Qwen 3 VL's inputs for the forward pass
|
||||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
m.positionCache = m.positionCache[:0]
|
m.positionCache = m.positionCache[:0]
|
||||||
return slices.Collect(func(yield func(*input.Input) bool) {
|
var result []*input.Input
|
||||||
for i := range inputs {
|
appendInput := func(inp *input.Input, position int32) {
|
||||||
s := []modelInput{{Input: inputs[i]}}
|
result = append(result, inp)
|
||||||
if mm := inputs[i].Multimodal; mm != nil {
|
m.positionCache = append(m.positionCache, position)
|
||||||
t := mm[0].Tensor
|
}
|
||||||
s = slices.Repeat([]modelInput{
|
|
||||||
{
|
|
||||||
position: int32(i + 1),
|
|
||||||
Input: &input.Input{Token: tokenVision},
|
|
||||||
},
|
|
||||||
}, t.Dim(1)+1+1)
|
|
||||||
|
|
||||||
s[0] = modelInput{
|
var p int32
|
||||||
Input: &input.Input{Token: tokenVisionStart},
|
for _, inp := range inputs {
|
||||||
position: int32(i),
|
if inp.Multimodal == nil {
|
||||||
}
|
appendInput(inp, p)
|
||||||
|
p++
|
||||||
s[len(s)-1] = modelInput{
|
continue
|
||||||
Input: &input.Input{Token: tokenVisionEnd},
|
|
||||||
position: int32(i + mm[0].Data.(*Grid).Width/m.spatialMergeSize + 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
s[1] = modelInput{
|
|
||||||
Input: &input.Input{
|
|
||||||
Token: tokenVision,
|
|
||||||
Multimodal: inputs[i].Multimodal,
|
|
||||||
MultimodalHash: inputs[i].MultimodalHash,
|
|
||||||
SameBatch: t.Dim(1),
|
|
||||||
},
|
|
||||||
position: int32(i + 1),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, e := range s {
|
|
||||||
position := e.position
|
|
||||||
if position == 0 && len(m.positionCache) > 0 {
|
|
||||||
position = m.positionCache[len(m.positionCache)-1] + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
m.positionCache = append(m.positionCache, position)
|
|
||||||
if !yield(e.Input) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}), nil
|
|
||||||
|
grid := inp.Multimodal[0].Data.(*Grid)
|
||||||
|
tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
|
||||||
|
|
||||||
|
appendInput(&input.Input{Token: tokenVisionStart}, p)
|
||||||
|
p++
|
||||||
|
|
||||||
|
appendInput(&input.Input{
|
||||||
|
Token: tokenVision,
|
||||||
|
Multimodal: inp.Multimodal,
|
||||||
|
MultimodalHash: inp.MultimodalHash,
|
||||||
|
SameBatch: tokensPerGrid,
|
||||||
|
}, p)
|
||||||
|
|
||||||
|
for range tokensPerGrid - 1 {
|
||||||
|
appendInput(&input.Input{Token: tokenVision}, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
p = p + int32(grid.Width/m.spatialMergeSize)
|
||||||
|
appendInput(&input.Input{Token: tokenVisionEnd}, p)
|
||||||
|
p++
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
@@ -143,9 +129,13 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
deepstackVisualEmbeds = make([]ml.Tensor, len(mi.Multimodal[1:]))
|
if len(mi.Multimodal[1:]) > len(deepstackVisualEmbeds) {
|
||||||
|
deepstackVisualEmbeds = append(deepstackVisualEmbeds, make([]ml.Tensor, len(mi.Multimodal[1:])-len(deepstackVisualEmbeds))...)
|
||||||
|
}
|
||||||
for i, mm := range mi.Multimodal[1:] {
|
for i, mm := range mi.Multimodal[1:] {
|
||||||
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
|
if deepstackVisualEmbeds[i] == nil {
|
||||||
|
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
|
||||||
|
}
|
||||||
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
|
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -189,8 +179,8 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
),
|
),
|
||||||
TextModel: newTextModel(c),
|
TextModel: newTextModel(c),
|
||||||
VisionModel: newVisionModel(c),
|
VisionModel: NewVisionModel(c),
|
||||||
ImageProcessor: newImageProcessor(c),
|
ImageProcessor: NewImageProcessor(c),
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {
|
m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {
|
||||||
|
|||||||
@@ -238,8 +238,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
return hiddenStates, deepstackStates
|
return hiddenStates, deepstackStates
|
||||||
}
|
}
|
||||||
|
|
||||||
// newVisionModel creates a new instance of the Qwen vision model
|
// NewVisionModel creates a new instance of the Qwen vision model.
|
||||||
func newVisionModel(c fs.Config) *VisionModel {
|
func NewVisionModel(c fs.Config) *VisionModel {
|
||||||
deepstackVisualIndexes := c.Ints("vision.deepstack_visual_indexes")
|
deepstackVisualIndexes := c.Ints("vision.deepstack_visual_indexes")
|
||||||
model := &VisionModel{
|
model := &VisionModel{
|
||||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
|
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
|
||||||
|
|||||||
@@ -32,9 +32,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type GLM46Parser struct {
|
type GLM46Parser struct {
|
||||||
state glm46ParserState
|
state glm46ParserState
|
||||||
buffer strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GLM46Parser) HasToolSupport() bool {
|
func (p *GLM46Parser) HasToolSupport() bool {
|
||||||
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
|
|||||||
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
|
|||||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case glm46EventThinkingContent:
|
case glm46EventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ type GLM47Parser struct {
|
|||||||
|
|
||||||
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||||
// so model output starts directly with thinking content (no opening tag).
|
// so model output starts directly with thinking content (no opening tag).
|
||||||
if thinkValue == nil || thinkValue.Bool() {
|
if thinkValue == nil || thinkValue.Bool() {
|
||||||
|
|||||||
@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
|
|||||||
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
input := `plan</think>
|
||||||
|
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
|
||||||
|
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
|
||||||
|
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ func ParserForName(name string) Parser {
|
|||||||
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
case "qwen3-thinking":
|
case "qwen3-thinking":
|
||||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||||
|
case "qwen3.5":
|
||||||
|
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||||
case "qwen3-coder":
|
case "qwen3-coder":
|
||||||
p = &Qwen3CoderParser{}
|
p = &Qwen3CoderParser{}
|
||||||
case "qwen3-vl-instruct":
|
case "qwen3-vl-instruct":
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ func TestBuiltInParsersStillWork(t *testing.T) {
|
|||||||
{"qwen3-coder"},
|
{"qwen3-coder"},
|
||||||
{"lfm2"},
|
{"lfm2"},
|
||||||
{"lfm2-thinking"},
|
{"lfm2-thinking"},
|
||||||
|
{"qwen3.5"},
|
||||||
{"harmony"},
|
{"harmony"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ type Qwen3Parser struct {
|
|||||||
state qwen3ParserState
|
state qwen3ParserState
|
||||||
buffer strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
hasThinkingSupport bool
|
hasThinkingSupport bool
|
||||||
defaultThinking bool
|
defaultThinking bool
|
||||||
maybeThinkingOpenAtBOL bool
|
maybeThinkingOpenAtBOL bool
|
||||||
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
|
|||||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
p.buffer.Reset()
|
p.buffer.Reset()
|
||||||
|
p.callIndex = 0
|
||||||
|
|
||||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||||
if thinkValue == nil {
|
if thinkValue == nil {
|
||||||
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
|
|||||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
calls = append(calls, toolCall)
|
calls = append(calls, toolCall)
|
||||||
case qwen3EventThinkingContent:
|
case qwen3EventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
@@ -204,6 +208,24 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
|||||||
p.maybeThinkingOpenAtBOL = false
|
p.maybeThinkingOpenAtBOL = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
thinkingCloseIdx := strings.Index(acc, qwen3ThinkingCloseTag)
|
||||||
|
toolOpenIdx := strings.Index(acc, qwen3ToolOpenTag)
|
||||||
|
|
||||||
|
// If a tool call starts before </think>, treat that as the end of thinking
|
||||||
|
// for parsing purposes and continue in tool-call mode.
|
||||||
|
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
|
||||||
|
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
|
||||||
|
if len(before) > 0 {
|
||||||
|
events = append(events, qwen3EventThinkingContent{content: before})
|
||||||
|
}
|
||||||
|
if after == "" {
|
||||||
|
p.state = qwen3ParserStateToolStartedEatingWhitespace
|
||||||
|
} else {
|
||||||
|
p.state = qwen3ParserStateCollectingToolContent
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
}
|
||||||
|
|
||||||
if strings.Contains(acc, qwen3ThinkingCloseTag) {
|
if strings.Contains(acc, qwen3ThinkingCloseTag) {
|
||||||
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
|
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
|
||||||
if len(thinking) > 0 {
|
if len(thinking) > 0 {
|
||||||
@@ -215,7 +237,7 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
|||||||
p.state = qwen3ParserStateCollectingContent
|
p.state = qwen3ParserStateCollectingContent
|
||||||
}
|
}
|
||||||
return events, true
|
return events, true
|
||||||
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
|
} else if overlapLen := max(overlap(acc, qwen3ThinkingCloseTag), overlap(acc, qwen3ToolOpenTag)); overlapLen > 0 {
|
||||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||||
|
|||||||
@@ -145,3 +145,174 @@ func TestQwen3ParserToolCall(t *testing.T) {
|
|||||||
t.Fatalf("expected unit %q, got %v", "celsius", unit)
|
t.Fatalf("expected unit %q, got %v", "celsius", unit)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserThinkingWithToolCallBeforeThinkingClose(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
|
input := "Let me think<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
|
||||||
|
content, thinking, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if thinking != "Let me think" {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Let me think", thinking)
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if calls[0].Function.Name != "get_weather" {
|
||||||
|
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserThinkingWithSplitToolOpenTag(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add("Let me think<tool_ca", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on first chunk: %v", err)
|
||||||
|
}
|
||||||
|
if content != "" || thinking != "Let me think" || len(calls) != 0 {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected content=%q thinking=%q calls=%d, got content=%q thinking=%q calls=%d",
|
||||||
|
"",
|
||||||
|
"Let me think",
|
||||||
|
0,
|
||||||
|
content,
|
||||||
|
thinking,
|
||||||
|
len(calls),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, thinking, calls, err = parser.Add("ll>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}</tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on second chunk: %v", err)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if calls[0].Function.Name != "get_weather" {
|
||||||
|
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
content, thinking, calls, err := parser.Add("Hello! How can I help you today?", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "Hello! How can I help you today?" {
|
||||||
|
t.Fatalf("expected content %q, got %q", "Hello! How can I help you today?", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
|
||||||
|
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
|
||||||
|
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -29,9 +29,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Qwen3CoderParser struct {
|
type Qwen3CoderParser struct {
|
||||||
state qwenParserState
|
state qwenParserState
|
||||||
acc strings.Builder
|
acc strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||||
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
|||||||
|
|
||||||
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
return tools // Qwen doesn't modify tools
|
return tools // Qwen doesn't modify tools
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
|
|||||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case qwenEventContent:
|
case qwenEventContent:
|
||||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||||
|
|||||||
@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
|
||||||
|
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
|
||||||
|
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQwenXMLTransform(t *testing.T) {
|
func TestQwenXMLTransform(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
desc string
|
desc string
|
||||||
|
|||||||
@@ -180,7 +180,22 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
|||||||
return events, false
|
return events, false
|
||||||
}
|
}
|
||||||
case CollectingThinkingContent:
|
case CollectingThinkingContent:
|
||||||
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
acc := p.buffer.String()
|
||||||
|
thinkingCloseIdx := strings.Index(acc, thinkingCloseTag)
|
||||||
|
toolOpenIdx := strings.Index(acc, toolOpenTag)
|
||||||
|
|
||||||
|
// If a tool call starts before </think>, treat that as the end of thinking
|
||||||
|
// for parsing purposes and continue in tool-call mode.
|
||||||
|
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
|
||||||
|
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
|
||||||
|
if len(before) > 0 {
|
||||||
|
events = append(events, qwenEventThinkingContent{content: before})
|
||||||
|
}
|
||||||
|
p.state = CollectingToolContent
|
||||||
|
return events, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(acc, thinkingCloseTag) {
|
||||||
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
|
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
|
||||||
if len(thinking) > 0 {
|
if len(thinking) > 0 {
|
||||||
events = append(events, qwenEventThinkingContent{content: thinking})
|
events = append(events, qwenEventThinkingContent{content: thinking})
|
||||||
@@ -191,13 +206,13 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
|||||||
p.state = CollectingContent
|
p.state = CollectingContent
|
||||||
}
|
}
|
||||||
return events, true
|
return events, true
|
||||||
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 {
|
} else if overlapLen := max(overlap(acc, thinkingCloseTag), overlap(acc, toolOpenTag)); overlapLen > 0 {
|
||||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||||
|
|
||||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
unambiguous := acc[:ambiguousStart]
|
||||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
ambiguous := acc[ambiguousStart:]
|
||||||
p.buffer.Reset()
|
p.buffer.Reset()
|
||||||
p.buffer.WriteString(ambiguous)
|
p.buffer.WriteString(ambiguous)
|
||||||
if len(unambiguous) > 0 {
|
if len(unambiguous) > 0 {
|
||||||
@@ -205,11 +220,11 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
|||||||
}
|
}
|
||||||
return events, false
|
return events, false
|
||||||
} else {
|
} else {
|
||||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
whitespaceLen := trailingWhitespaceLen(acc)
|
||||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
ambiguousStart := len(acc) - whitespaceLen
|
||||||
|
|
||||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
unambiguous := acc[:ambiguousStart]
|
||||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
ambiguous := acc[ambiguousStart:]
|
||||||
p.buffer.Reset()
|
p.buffer.Reset()
|
||||||
p.buffer.WriteString(ambiguous)
|
p.buffer.WriteString(ambiguous)
|
||||||
if len(unambiguous) > 0 {
|
if len(unambiguous) > 0 {
|
||||||
|
|||||||
@@ -98,8 +98,12 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
|||||||
desc: "nested thinking and tool call (outside thinking, inside tool call)",
|
desc: "nested thinking and tool call (outside thinking, inside tool call)",
|
||||||
steps: []step{
|
steps: []step{
|
||||||
{
|
{
|
||||||
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
|
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
|
||||||
wantEvents: []qwenEvent{qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm nested tool call</tool_call>"}},
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventThinkingContent{content: "I'm thinking"},
|
||||||
|
qwenEventRawToolCall{raw: "I'm nested tool call"},
|
||||||
|
qwenEventContent{content: "</think>"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -109,8 +113,7 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
|||||||
{
|
{
|
||||||
input: "<tool_call>I'm nested tool call<think>I'm thinking</think></tool_call>",
|
input: "<tool_call>I'm nested tool call<think>I'm thinking</think></tool_call>",
|
||||||
wantEvents: []qwenEvent{
|
wantEvents: []qwenEvent{
|
||||||
qwenEventThinkingContent{content: "<tool_call>I'm nested tool call<think>I'm thinking"},
|
qwenEventRawToolCall{raw: "I'm nested tool call<think>I'm thinking</think>"},
|
||||||
qwenEventContent{content: "</tool_call>"},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -121,8 +124,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
|||||||
{
|
{
|
||||||
input: "I'm thinking<tool_call>I'm NOT a nested tool call</think></tool_call><tool_call>I'm nested tool call 2<think></tool_call></think>",
|
input: "I'm thinking<tool_call>I'm NOT a nested tool call</think></tool_call><tool_call>I'm nested tool call 2<think></tool_call></think>",
|
||||||
wantEvents: []qwenEvent{
|
wantEvents: []qwenEvent{
|
||||||
qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm NOT a nested tool call"},
|
qwenEventThinkingContent{content: "I'm thinking"},
|
||||||
qwenEventContent{content: "</tool_call>"},
|
qwenEventRawToolCall{raw: "I'm NOT a nested tool call</think>"},
|
||||||
qwenEventRawToolCall{raw: "I'm nested tool call 2<think>"},
|
qwenEventRawToolCall{raw: "I'm nested tool call 2<think>"},
|
||||||
qwenEventContent{content: "</think>"},
|
qwenEventContent{content: "</think>"},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package renderers
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -192,21 +193,25 @@ func lfm2RenderToolCalls(calls []api.ToolCall) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *LFM2Renderer) renderMessageContent(message api.Message) string {
|
func (r *LFM2Renderer) renderMessageContent(message api.Message, imageOffset int) string {
|
||||||
content := lfm2RenderContent(message.Content, r.useImgTags)
|
content := lfm2RenderContent(message.Content, r.useImgTags)
|
||||||
if len(message.Images) == 0 {
|
if len(message.Images) == 0 {
|
||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
// chatPrompt may already have inserted [img] / [img-n] placeholders.
|
|
||||||
if strings.Contains(content, "[img]") || strings.Contains(content, "[img-") || strings.Contains(content, "<image>") {
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
placeholder := lfm2ImagePlaceholder(r.useImgTags)
|
if r.useImgTags {
|
||||||
for range message.Images {
|
for i := range message.Images {
|
||||||
sb.WriteString(placeholder)
|
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset+i))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
placeholder := lfm2ImagePlaceholder(false)
|
||||||
|
if strings.Contains(content, placeholder) {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
for range message.Images {
|
||||||
|
sb.WriteString(placeholder)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
sb.WriteString(content)
|
sb.WriteString(content)
|
||||||
return sb.String()
|
return sb.String()
|
||||||
@@ -262,6 +267,11 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
imageOffset := 0
|
||||||
|
for i := range startIdx {
|
||||||
|
imageOffset += len(messages[i].Images)
|
||||||
|
}
|
||||||
|
|
||||||
for i := startIdx; i < len(messages); i++ {
|
for i := startIdx; i < len(messages); i++ {
|
||||||
message := messages[i]
|
message := messages[i]
|
||||||
lastMessage := i == len(messages)-1
|
lastMessage := i == len(messages)-1
|
||||||
@@ -271,7 +281,8 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
|
|||||||
sb.WriteString(message.Role)
|
sb.WriteString(message.Role)
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
|
|
||||||
content := r.renderMessageContent(message)
|
content := r.renderMessageContent(message, imageOffset)
|
||||||
|
imageOffset += len(message.Images)
|
||||||
if message.Role == "assistant" && !keepPastThinking && i != lastAssistantIndex {
|
if message.Role == "assistant" && !keepPastThinking && i != lastAssistantIndex {
|
||||||
if idx := strings.LastIndex(content, "</think>"); idx >= 0 {
|
if idx := strings.LastIndex(content, "</think>"); idx >= 0 {
|
||||||
content = strings.TrimSpace(content[idx+len("</think>"):])
|
content = strings.TrimSpace(content[idx+len("</think>"):])
|
||||||
|
|||||||
@@ -236,16 +236,6 @@ func TestLFM2Renderer_Images(t *testing.T) {
|
|||||||
Content: "Describe this image.",
|
Content: "Describe this image.",
|
||||||
Images: []api.ImageData{api.ImageData("img1")},
|
Images: []api.ImageData{api.ImageData("img1")},
|
||||||
},
|
},
|
||||||
expected: "<|startoftext|><|im_start|>user\n[img]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "existing_indexed_img_placeholder_not_duplicated",
|
|
||||||
renderer: &LFM2Renderer{useImgTags: true},
|
|
||||||
message: api.Message{
|
|
||||||
Role: "user",
|
|
||||||
Content: "[img-0]Describe this image.",
|
|
||||||
Images: []api.ImageData{api.ImageData("img1")},
|
|
||||||
},
|
|
||||||
expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package renderers
|
package renderers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
@@ -9,10 +10,11 @@ import (
|
|||||||
type Qwen3VLRenderer struct {
|
type Qwen3VLRenderer struct {
|
||||||
isThinking bool
|
isThinking bool
|
||||||
|
|
||||||
useImgTags bool
|
emitEmptyThinkOnNoThink bool
|
||||||
|
useImgTags bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
|
func (r *Qwen3VLRenderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
||||||
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
||||||
var subSb strings.Builder
|
var subSb strings.Builder
|
||||||
for range content.Images {
|
for range content.Images {
|
||||||
@@ -20,7 +22,8 @@ func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
|
|||||||
// model backends, and so we should eventually parameterize this or
|
// model backends, and so we should eventually parameterize this or
|
||||||
// only output a placeholder such as [img]
|
// only output a placeholder such as [img]
|
||||||
if r.useImgTags {
|
if r.useImgTags {
|
||||||
subSb.WriteString("[img]")
|
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||||
|
imageOffset++
|
||||||
} else {
|
} else {
|
||||||
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||||
}
|
}
|
||||||
@@ -28,12 +31,17 @@ func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
|
|||||||
// TODO: support videos
|
// TODO: support videos
|
||||||
|
|
||||||
subSb.WriteString(content.Content)
|
subSb.WriteString(content.Content)
|
||||||
return subSb.String()
|
return subSb.String(), imageOffset
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
|
isThinking := r.isThinking
|
||||||
|
if think != nil {
|
||||||
|
isThinking = think.Bool()
|
||||||
|
}
|
||||||
|
|
||||||
if len(tools) > 0 {
|
if len(tools) > 0 {
|
||||||
sb.WriteString(imStartTag + "system\n")
|
sb.WriteString(imStartTag + "system\n")
|
||||||
if len(messages) > 0 && messages[0].Role == "system" {
|
if len(messages) > 0 && messages[0].Role == "system" {
|
||||||
@@ -57,7 +65,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
|
|||||||
message := messages[i]
|
message := messages[i]
|
||||||
if multiStepTool && message.Role == "user" {
|
if multiStepTool && message.Role == "user" {
|
||||||
// Check if content starts with <tool_response> and ends with </tool_response>
|
// Check if content starts with <tool_response> and ends with </tool_response>
|
||||||
content := r.renderContent(message)
|
content, _ := r.renderContent(message, 0)
|
||||||
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
|
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
|
||||||
multiStepTool = false
|
multiStepTool = false
|
||||||
lastQueryIndex = i
|
lastQueryIndex = i
|
||||||
@@ -65,8 +73,10 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
imageOffset := 0
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
content := r.renderContent(message)
|
content, nextImageOffset := r.renderContent(message, imageOffset)
|
||||||
|
imageOffset = nextImageOffset
|
||||||
|
|
||||||
lastMessage := i == len(messages)-1
|
lastMessage := i == len(messages)-1
|
||||||
prefill := lastMessage && message.Role == "assistant"
|
prefill := lastMessage && message.Role == "assistant"
|
||||||
@@ -76,13 +86,13 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
|
|||||||
} else if message.Role == "assistant" {
|
} else if message.Role == "assistant" {
|
||||||
contentReasoning := ""
|
contentReasoning := ""
|
||||||
|
|
||||||
if r.isThinking {
|
if isThinking {
|
||||||
if message.Thinking != "" {
|
if message.Thinking != "" {
|
||||||
contentReasoning = message.Thinking
|
contentReasoning = message.Thinking
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.isThinking && i > lastQueryIndex {
|
if isThinking && i > lastQueryIndex {
|
||||||
if i == len(messages)-1 || contentReasoning != "" {
|
if i == len(messages)-1 || contentReasoning != "" {
|
||||||
sb.WriteString("<|im_start|>" + message.Role + "\n<think>\n" + strings.Trim(contentReasoning, "\n")) // do we want to add a new line here?
|
sb.WriteString("<|im_start|>" + message.Role + "\n<think>\n" + strings.Trim(contentReasoning, "\n")) // do we want to add a new line here?
|
||||||
if content != "" {
|
if content != "" {
|
||||||
@@ -125,8 +135,10 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
|
|||||||
// prefill at the end
|
// prefill at the end
|
||||||
if lastMessage && !prefill {
|
if lastMessage && !prefill {
|
||||||
sb.WriteString("<|im_start|>assistant\n")
|
sb.WriteString("<|im_start|>assistant\n")
|
||||||
if r.isThinking {
|
if isThinking {
|
||||||
sb.WriteString("<think>\n")
|
sb.WriteString("<think>\n")
|
||||||
|
} else if r.emitEmptyThinkOnNoThink {
|
||||||
|
sb.WriteString("<think>\n\n</think>\n\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ Let me analyze this image.`,
|
|||||||
},
|
},
|
||||||
useImgTags: true,
|
useImgTags: true,
|
||||||
expected: `<|im_start|>user
|
expected: `<|im_start|>user
|
||||||
[img]Describe this image.<|im_end|>
|
[img-0]Describe this image.<|im_end|>
|
||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
Let me analyze this image.`,
|
Let me analyze this image.`,
|
||||||
},
|
},
|
||||||
@@ -123,7 +123,7 @@ Let me analyze this image.`,
|
|||||||
},
|
},
|
||||||
useImgTags: true,
|
useImgTags: true,
|
||||||
expected: `<|im_start|>user
|
expected: `<|im_start|>user
|
||||||
[img][img]Describe these images.<|im_end|>
|
[img-0][img-1]Describe these images.<|im_end|>
|
||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
Let me analyze this image.`,
|
Let me analyze this image.`,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package renderers
|
package renderers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@@ -370,3 +371,74 @@ func TestFormatToolCallArgumentThinkingVL(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3VLRendererThinkOverride(t *testing.T) {
|
||||||
|
msgs := []api.Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
}
|
||||||
|
|
||||||
|
renderThinking, err := (&Qwen3VLRenderer{isThinking: true}).Render(msgs, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(renderThinking, "<|im_start|>assistant\n<think>\n") {
|
||||||
|
t.Fatalf("expected default thinking renderer to emit <think>, got:\n%s", renderThinking)
|
||||||
|
}
|
||||||
|
|
||||||
|
renderNonThinking, err := (&Qwen3VLRenderer{isThinking: true}).Render(msgs, nil, &api.ThinkValue{Value: false})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if strings.Contains(renderNonThinking, "<think>") {
|
||||||
|
t.Fatalf("expected think=false override to suppress <think>, got:\n%s", renderNonThinking)
|
||||||
|
}
|
||||||
|
|
||||||
|
renderForcedThinking, err := (&Qwen3VLRenderer{isThinking: false}).Render(msgs, nil, &api.ThinkValue{Value: true})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(renderForcedThinking, "<|im_start|>assistant\n<think>\n") {
|
||||||
|
t.Fatalf("expected think=true override to emit <think>, got:\n%s", renderForcedThinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3VLRendererThinkOverrideWithExplicitNoThinkPrefill(t *testing.T) {
|
||||||
|
msgs := []api.Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
}
|
||||||
|
|
||||||
|
renderNonThinking, err := (&Qwen3VLRenderer{
|
||||||
|
isThinking: true,
|
||||||
|
emitEmptyThinkOnNoThink: true,
|
||||||
|
}).Render(msgs, nil, &api.ThinkValue{Value: false})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(renderNonThinking, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
|
||||||
|
t.Fatalf("expected explicit think=false prefill block, got:\n%s", renderNonThinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwenRendererNameNoThinkBehaviorSplit(t *testing.T) {
|
||||||
|
msgs := []api.Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
}
|
||||||
|
thinkFalse := &api.ThinkValue{Value: false}
|
||||||
|
|
||||||
|
qwen35Rendered, err := RenderWithRenderer("qwen3.5", msgs, nil, thinkFalse)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(qwen35Rendered, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
|
||||||
|
t.Fatalf("expected qwen3.5 renderer to emit explicit no-think prefill, got:\n%s", qwen35Rendered)
|
||||||
|
}
|
||||||
|
|
||||||
|
qwen3VLRendered, err := RenderWithRenderer("qwen3-vl-thinking", msgs, nil, thinkFalse)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if strings.Contains(qwen3VLRendered, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
|
||||||
|
t.Fatalf("expected qwen3-vl-thinking renderer to keep legacy non-empty no-think behavior, got:\n%s", qwen3VLRendered)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -56,6 +56,9 @@ func rendererForName(name string) Renderer {
|
|||||||
case "qwen3-vl-thinking":
|
case "qwen3-vl-thinking":
|
||||||
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
|
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
|
||||||
return renderer
|
return renderer
|
||||||
|
case "qwen3.5":
|
||||||
|
renderer := &Qwen3VLRenderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
|
||||||
|
return renderer
|
||||||
case "cogito":
|
case "cogito":
|
||||||
renderer := &CogitoRenderer{isThinking: true}
|
renderer := &CogitoRenderer{isThinking: true}
|
||||||
return renderer
|
return renderer
|
||||||
|
|||||||
@@ -29,17 +29,27 @@ func TestRegisterCustomRenderer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBuiltInRendererStillWorks(t *testing.T) {
|
func TestBuiltInRendererStillWorks(t *testing.T) {
|
||||||
// Test that qwen3-coder still works
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
}{
|
||||||
|
{name: "qwen3-coder"},
|
||||||
|
{name: "qwen3.5"},
|
||||||
|
}
|
||||||
|
|
||||||
messages := []api.Message{
|
messages := []api.Message{
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "user", Content: "Hello"},
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := RenderWithRenderer("qwen3-coder", messages, nil, nil)
|
for _, tt := range tests {
|
||||||
if err != nil {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
result, err := RenderWithRenderer(tt.name, messages, nil, nil)
|
||||||
}
|
if err != nil {
|
||||||
if result == "" {
|
t.Fatalf("unexpected error: %v", err)
|
||||||
t.Error("expected non-empty result from qwen3-coder renderer")
|
}
|
||||||
|
if result == "" {
|
||||||
|
t.Fatalf("expected non-empty result from %s renderer", tt.name)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ type Model struct {
|
|||||||
Template *template.Template
|
Template *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) IsMLX() bool {
|
||||||
|
return m.Config.ModelFormat == "safetensors"
|
||||||
|
}
|
||||||
|
|
||||||
// Capabilities returns the capabilities that the model supports
|
// Capabilities returns the capabilities that the model supports
|
||||||
func (m *Model) Capabilities() []model.Capability {
|
func (m *Model) Capabilities() []model.Capability {
|
||||||
capabilities := []model.Capability{}
|
capabilities := []model.Capability{}
|
||||||
|
|||||||
@@ -30,42 +30,44 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
lastMsgIdx := len(msgs) - 1
|
lastMsgIdx := len(msgs) - 1
|
||||||
currMsgIdx := 0
|
currMsgIdx := 0
|
||||||
|
|
||||||
// Start with all messages and remove from the front until it fits in context
|
if truncate {
|
||||||
for i := 0; i <= lastMsgIdx; i++ {
|
// Start with all messages and remove from the front until it fits in context
|
||||||
// Collect system messages from the portion we're about to skip
|
for i := 0; i <= lastMsgIdx; i++ {
|
||||||
system = make([]api.Message, 0)
|
// Collect system messages from the portion we're about to skip
|
||||||
for j := range i {
|
system = make([]api.Message, 0)
|
||||||
if msgs[j].Role == "system" {
|
for j := range i {
|
||||||
system = append(system, msgs[j])
|
if msgs[j].Role == "system" {
|
||||||
|
system = append(system, msgs[j])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
|
||||||
|
|
||||||
s, err := tokenize(ctx, p)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctxLen := len(s)
|
|
||||||
if m.ProjectorPaths != nil {
|
|
||||||
for _, msg := range msgs[i:] {
|
|
||||||
ctxLen += imageNumTokens * len(msg.Images)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if !truncate || ctxLen <= opts.NumCtx {
|
s, err := tokenize(ctx, p)
|
||||||
currMsgIdx = i
|
if err != nil {
|
||||||
break
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Must always include at least the last message
|
ctxLen := len(s)
|
||||||
if i == lastMsgIdx {
|
if m.ProjectorPaths != nil {
|
||||||
currMsgIdx = lastMsgIdx
|
for _, msg := range msgs[i:] {
|
||||||
break
|
ctxLen += imageNumTokens * len(msg.Images)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctxLen <= opts.NumCtx {
|
||||||
|
currMsgIdx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must always include at least the last message
|
||||||
|
if i == lastMsgIdx {
|
||||||
|
currMsgIdx = lastMsgIdx
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,6 +88,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
ID: len(images),
|
ID: len(images),
|
||||||
Data: i,
|
Data: i,
|
||||||
}
|
}
|
||||||
|
images = append(images, imgData)
|
||||||
|
|
||||||
|
if m.Config.Renderer != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
|
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
|
||||||
if !strings.Contains(prompt, "[img]") {
|
if !strings.Contains(prompt, "[img]") {
|
||||||
@@ -93,8 +100,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
} else {
|
} else {
|
||||||
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
|
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
images = append(images, imgData)
|
|
||||||
}
|
}
|
||||||
msgs[currMsgIdx+cnt].Content = prefix + prompt
|
msgs[currMsgIdx+cnt].Content = prefix + prompt
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestChatPrompt(t *testing.T) {
|
func TestChatPrompt(t *testing.T) {
|
||||||
@@ -330,3 +331,38 @@ func TestChatPromptTokenizeCalls(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
|
||||||
|
msgs := []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "what do these photos have in common?",
|
||||||
|
Images: []api.ImageData{[]byte("img-1"), []byte("img-2"), []byte("img-3")},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
originalContent := msgs[0].Content
|
||||||
|
|
||||||
|
m := Model{
|
||||||
|
Config: model.ConfigV2{Renderer: "qwen3-vl-instruct"},
|
||||||
|
ProjectorPaths: []string{"vision"},
|
||||||
|
}
|
||||||
|
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||||
|
think := false
|
||||||
|
|
||||||
|
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgs[0].Content != originalContent {
|
||||||
|
t.Fatalf("renderer path should not mutate message content: got %q, want %q", msgs[0].Content, originalContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := len(images), 3; got != want {
|
||||||
|
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if prompt == "" {
|
||||||
|
t.Fatal("prompt is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"maps"
|
"maps"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
@@ -33,6 +34,9 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
|||||||
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
||||||
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
|
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
|
||||||
}
|
}
|
||||||
|
if uint64(len(data)) < q.from.Size() {
|
||||||
|
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
|
||||||
|
}
|
||||||
var f32s []float32
|
var f32s []float32
|
||||||
newType := fsggml.TensorType(q.to.Kind)
|
newType := fsggml.TensorType(q.to.Kind)
|
||||||
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
|
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
|
||||||
@@ -58,7 +62,7 @@ func useMoreBits(iLayer, nLayers int) bool {
|
|||||||
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
|
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
|
||||||
}
|
}
|
||||||
|
|
||||||
func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
|
func qwen3LinearAttnQuantType(name string) (fsggml.TensorType, bool) {
|
||||||
switch {
|
switch {
|
||||||
// Full attention
|
// Full attention
|
||||||
case strings.HasSuffix(name, ".attn_q.weight"):
|
case strings.HasSuffix(name, ".attn_q.weight"):
|
||||||
@@ -79,6 +83,10 @@ func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
|
|||||||
// SSM
|
// SSM
|
||||||
case strings.HasSuffix(name, ".ssm_ba.weight"):
|
case strings.HasSuffix(name, ".ssm_ba.weight"):
|
||||||
return fsggml.TensorTypeQ4_K, true
|
return fsggml.TensorTypeQ4_K, true
|
||||||
|
case strings.HasSuffix(name, ".ssm_beta.weight"):
|
||||||
|
return fsggml.TensorTypeQ4_K, true
|
||||||
|
case strings.HasSuffix(name, ".ssm_alpha.weight"):
|
||||||
|
return fsggml.TensorTypeQ4_K, true
|
||||||
case strings.HasSuffix(name, ".ssm_out.weight"):
|
case strings.HasSuffix(name, ".ssm_out.weight"):
|
||||||
return fsggml.TensorTypeQ4_K, true
|
return fsggml.TensorTypeQ4_K, true
|
||||||
|
|
||||||
@@ -287,8 +295,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
|||||||
|
|
||||||
newType := fsggml.TensorType(t.Kind)
|
newType := fsggml.TensorType(t.Kind)
|
||||||
if quantize {
|
if quantize {
|
||||||
if kv.Architecture() == "qwen3next" && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
|
if slices.Contains([]string{"qwen3next", "qwen35", "qwen35moe"}, kv.Architecture()) && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
|
||||||
if qt, ok := qwen3nextQuantType(name); ok {
|
if qt, ok := qwen3LinearAttnQuantType(name); ok {
|
||||||
return qt
|
return qt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -166,6 +166,60 @@ func TestGetTensorNewType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3LinearAttentionQuantOverride(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
arch string
|
||||||
|
tensor string
|
||||||
|
fileType fsggml.FileType
|
||||||
|
expected fsggml.TensorType
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "qwen35_beta",
|
||||||
|
arch: "qwen35",
|
||||||
|
tensor: "blk.0.ssm_beta.weight",
|
||||||
|
fileType: fsggml.FileTypeQ4_K_M,
|
||||||
|
expected: fsggml.TensorTypeQ4_K,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen35_alpha",
|
||||||
|
arch: "qwen35",
|
||||||
|
tensor: "blk.0.ssm_alpha.weight",
|
||||||
|
fileType: fsggml.FileTypeQ4_K_M,
|
||||||
|
expected: fsggml.TensorTypeQ4_K,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen35moe_attn_qkv",
|
||||||
|
arch: "qwen35moe",
|
||||||
|
tensor: "blk.0.attn_qkv.weight",
|
||||||
|
fileType: fsggml.FileTypeQ4_K_M,
|
||||||
|
expected: fsggml.TensorTypeQ4_K,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non_qwen35_falls_back",
|
||||||
|
arch: "foo",
|
||||||
|
tensor: "blk.0.attn_qkv.weight",
|
||||||
|
fileType: fsggml.FileTypeQ4_K_M,
|
||||||
|
expected: fsggml.TensorTypeQ5_K,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
kv := fsggml.KV{"general.architecture": tt.arch}
|
||||||
|
got := newType(&fsggml.Tensor{
|
||||||
|
Name: tt.tensor,
|
||||||
|
Shape: []uint64{256, 256},
|
||||||
|
Kind: uint32(fsggml.TensorTypeF16),
|
||||||
|
}, kv, &quantizeState{}, tt.fileType)
|
||||||
|
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Fatalf("unexpected tensor type for %s (%s): got %s want %s", tt.tensor, tt.arch, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQuantizeModel(t *testing.T) {
|
func TestQuantizeModel(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -173,6 +227,7 @@ func TestQuantizeModel(t *testing.T) {
|
|||||||
tensors []*fsggml.Tensor
|
tensors []*fsggml.Tensor
|
||||||
newType string
|
newType string
|
||||||
expectedTensorTypes map[string]fsggml.TensorType
|
expectedTensorTypes map[string]fsggml.TensorType
|
||||||
|
expectErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "f16_q4_k",
|
name: "f16_q4_k",
|
||||||
@@ -253,6 +308,36 @@ func TestQuantizeModel(t *testing.T) {
|
|||||||
"output.weight": fsggml.TensorTypeQ8_0,
|
"output.weight": fsggml.TensorTypeQ8_0,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "f32_short_data",
|
||||||
|
kv: map[string]any{
|
||||||
|
"general.architecture": "foo",
|
||||||
|
},
|
||||||
|
tensors: []*fsggml.Tensor{
|
||||||
|
{
|
||||||
|
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF32),
|
||||||
|
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||||
|
WriterTo: bytes.NewReader(make([]byte, 32)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
newType: "Q4_K",
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "f16_short_data",
|
||||||
|
kv: map[string]any{
|
||||||
|
"general.architecture": "foo",
|
||||||
|
},
|
||||||
|
tensors: []*fsggml.Tensor{
|
||||||
|
{
|
||||||
|
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
|
||||||
|
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||||
|
WriterTo: bytes.NewReader(make([]byte, 32)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
newType: "Q4_K",
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
@@ -264,6 +349,9 @@ func TestQuantizeModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer fp.Close()
|
defer fp.Close()
|
||||||
meta, err := fsggml.Decode(fp, -1)
|
meta, err := fsggml.Decode(fp, -1)
|
||||||
|
if tt.expectErr && err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
@@ -283,6 +371,12 @@ func TestQuantizeModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = quantize(fp, tmp, meta, ftype, progress)
|
err = quantize(fp, tmp, meta, ftype, progress)
|
||||||
|
if tt.expectErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected quantize to return an error")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error during quantize: %s", err)
|
t.Fatalf("error during quantize: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -130,6 +130,35 @@ func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Opt
|
|||||||
return opts, nil
|
return opts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func explicitOptions(modelOpts, requestOpts map[string]any) map[string]struct{} {
|
||||||
|
keys := []string{
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"min_p",
|
||||||
|
"top_k",
|
||||||
|
"repeat_last_n",
|
||||||
|
"repeat_penalty",
|
||||||
|
"presence_penalty",
|
||||||
|
"frequency_penalty",
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit := make(map[string]struct{}, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
if optionSpecified(modelOpts, requestOpts, key) {
|
||||||
|
explicit[key] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return explicit
|
||||||
|
}
|
||||||
|
|
||||||
|
func optionSpecified(modelOpts, requestOpts map[string]any, key string) bool {
|
||||||
|
if _, ok := requestOpts[key]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
_, ok := modelOpts[key]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
|
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
|
||||||
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
|
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
|
||||||
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
|
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
|
||||||
@@ -484,7 +513,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
// the real chat handler, but doing this as a stopgap to get renderer
|
// the real chat handler, but doing this as a stopgap to get renderer
|
||||||
// support for generate
|
// support for generate
|
||||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
|
||||||
|
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -538,14 +568,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
Shift: req.Shift == nil || *req.Shift,
|
Think: req.Think,
|
||||||
Truncate: req.Truncate == nil || *req.Truncate,
|
ExplicitOptions: explicitOptions(m.Options, req.Options),
|
||||||
Logprobs: req.Logprobs,
|
Shift: req.Shift == nil || *req.Shift,
|
||||||
TopLogprobs: req.TopLogprobs,
|
Truncate: req.Truncate == nil || *req.Truncate,
|
||||||
|
Logprobs: req.Logprobs,
|
||||||
|
TopLogprobs: req.TopLogprobs,
|
||||||
}, func(cr llm.CompletionResponse) {
|
}, func(cr llm.CompletionResponse) {
|
||||||
res := api.GenerateResponse{
|
res := api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
@@ -557,6 +589,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
PromptEvalDuration: cr.PromptEvalDuration,
|
PromptEvalDuration: cr.PromptEvalDuration,
|
||||||
EvalCount: cr.EvalCount,
|
EvalCount: cr.EvalCount,
|
||||||
EvalDuration: cr.EvalDuration,
|
EvalDuration: cr.EvalDuration,
|
||||||
|
PeakMemory: cr.PeakMemory,
|
||||||
},
|
},
|
||||||
Logprobs: toAPILogprobs(cr.Logprobs),
|
Logprobs: toAPILogprobs(cr.Logprobs),
|
||||||
}
|
}
|
||||||
@@ -1951,6 +1984,9 @@ func (s *Server) PsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if v.llama != nil {
|
if v.llama != nil {
|
||||||
mr.ContextLength = v.llama.ContextLength()
|
mr.ContextLength = v.llama.ContextLength()
|
||||||
|
total, vram := v.llama.MemorySize()
|
||||||
|
mr.Size = int64(total)
|
||||||
|
mr.SizeVRAM = int64(vram)
|
||||||
}
|
}
|
||||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||||
// possible that it will be set to the unix epoch. For those cases, just
|
// possible that it will be set to the unix epoch. For those cases, just
|
||||||
@@ -2213,6 +2249,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
truncate := req.Truncate == nil || *req.Truncate
|
truncate := req.Truncate == nil || *req.Truncate
|
||||||
|
if m.IsMLX() {
|
||||||
|
truncate = false
|
||||||
|
}
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("chat prompt error", "error", err)
|
slog.Error("chat prompt error", "error", err)
|
||||||
@@ -2290,14 +2329,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
// sets up new context given parent context per request
|
// sets up new context given parent context per request
|
||||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
err := r.Completion(ctx, llm.CompletionRequest{
|
err := r.Completion(ctx, llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: currentFormat,
|
Format: currentFormat,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
Shift: req.Shift == nil || *req.Shift,
|
Think: req.Think,
|
||||||
Truncate: truncate,
|
ExplicitOptions: explicitOptions(m.Options, req.Options),
|
||||||
Logprobs: req.Logprobs,
|
Shift: req.Shift == nil || *req.Shift,
|
||||||
TopLogprobs: req.TopLogprobs,
|
Truncate: truncate,
|
||||||
|
Logprobs: req.Logprobs,
|
||||||
|
TopLogprobs: req.TopLogprobs,
|
||||||
}, func(r llm.CompletionResponse) {
|
}, func(r llm.CompletionResponse) {
|
||||||
res := api.ChatResponse{
|
res := api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
@@ -2309,6 +2350,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
PromptEvalDuration: r.PromptEvalDuration,
|
PromptEvalDuration: r.PromptEvalDuration,
|
||||||
EvalCount: r.EvalCount,
|
EvalCount: r.EvalCount,
|
||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: r.EvalDuration,
|
||||||
|
PeakMemory: r.PeakMemory,
|
||||||
},
|
},
|
||||||
Logprobs: toAPILogprobs(r.Logprobs),
|
Logprobs: toAPILogprobs(r.Logprobs),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for experimental safetensors LLM models
|
// Check for experimental safetensors LLM models
|
||||||
if pending.model.Config.ModelFormat == "safetensors" {
|
if pending.model.IsMLX() {
|
||||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||||
// LLM model with safetensors format - use MLX runner
|
// LLM model with safetensors format - use MLX runner
|
||||||
if s.loadMLX(pending) {
|
if s.loadMLX(pending) {
|
||||||
@@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
|
|||||||
|
|
||||||
// Some architectures are not safe with num_parallel > 1.
|
// Some architectures are not safe with num_parallel > 1.
|
||||||
// ref: https://github.com/ollama/ollama/issues/4165
|
// ref: https://github.com/ollama/ollama/issues/4165
|
||||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||||
numParallel = 1
|
numParallel = 1
|
||||||
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
|
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
|
||||||
}
|
}
|
||||||
@@ -536,6 +536,7 @@ iGPUScan:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalSize, vramSize := llama.MemorySize()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
@@ -545,8 +546,8 @@ iGPUScan:
|
|||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
gpus: gpuIDs,
|
gpus: gpuIDs,
|
||||||
discreteGPUs: discreteGPUs,
|
discreteGPUs: discreteGPUs,
|
||||||
vramSize: llama.VRAMSize(),
|
totalSize: totalSize,
|
||||||
totalSize: llama.TotalSize(),
|
vramSize: vramSize,
|
||||||
loading: true,
|
loading: true,
|
||||||
pid: llama.Pid(),
|
pid: llama.Pid(),
|
||||||
}
|
}
|
||||||
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
sessionDuration = req.sessionDuration.Duration
|
sessionDuration = req.sessionDuration.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalSize, vramSize := server.MemorySize()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
loading: false,
|
loading: false,
|
||||||
isImagegen: isImagegen,
|
isImagegen: isImagegen,
|
||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
totalSize: server.TotalSize(),
|
totalSize: totalSize,
|
||||||
vramSize: server.VRAMSize(),
|
vramSize: vramSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
@@ -762,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||||
runner.llama.Ping(ctx) != nil {
|
runner.llama.Ping(ctx) != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
|
|||||||
s.closeCalled = true
|
s.closeCalled = true
|
||||||
return s.closeResp
|
return s.closeResp
|
||||||
}
|
}
|
||||||
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
|
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
|
||||||
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
|
|
||||||
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
||||||
func (s *mockLlm) Pid() int { return -1 }
|
func (s *mockLlm) Pid() int { return -1 }
|
||||||
func (s *mockLlm) GetPort() int { return -1 }
|
func (s *mockLlm) GetPort() int { return -1 }
|
||||||
|
|||||||
@@ -288,6 +288,18 @@ func normalizeQuantType(quantize string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isStackedExpertWeight(name string) bool {
|
||||||
|
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
|
||||||
|
// or "...proj" (pre-stacked packed tensor).
|
||||||
|
if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Contains(name, ".mlp.switch_mlp.") ||
|
||||||
|
strings.Contains(name, ".mlp.experts.") ||
|
||||||
|
strings.Contains(name, ".mlp.shared_experts.")
|
||||||
|
}
|
||||||
|
|
||||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||||
// Returns "" if the tensor should not be quantized.
|
// Returns "" if the tensor should not be quantized.
|
||||||
// This implements mixed-precision quantization:
|
// This implements mixed-precision quantization:
|
||||||
@@ -296,18 +308,25 @@ func normalizeQuantType(quantize string) string {
|
|||||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||||
// - Norms, embeddings, biases, routing gates: no quantization
|
// - Norms, embeddings, biases, routing gates: no quantization
|
||||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||||
|
stackedExpert := isStackedExpertWeight(name)
|
||||||
|
|
||||||
// Use basic name-based check first
|
// Use basic name-based check first
|
||||||
if !ShouldQuantize(name, "") {
|
if !stackedExpert && !ShouldQuantize(name, "") {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
// Quantize standard linear weights (2D). Also allow stacked expert weights (3D),
|
||||||
if len(shape) != 2 {
|
// e.g. qwen switch_mlp / experts combined tensors.
|
||||||
|
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
var elems int64 = 1
|
||||||
|
for _, d := range shape {
|
||||||
|
elems *= int64(d)
|
||||||
|
}
|
||||||
|
if elems < 1024 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -557,6 +557,10 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
|||||||
// 3D+ tensors should not be quantized
|
// 3D+ tensors should not be quantized
|
||||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
||||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
||||||
|
{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
|
||||||
|
{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
|
||||||
|
{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
|
||||||
|
{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
|
||||||
|
|
||||||
// Embeddings should not be quantized regardless of shape
|
// Embeddings should not be quantized regardless of shape
|
||||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
||||||
@@ -619,6 +623,44 @@ func TestExpertGroupPrefix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
||||||
|
gateUp := GetTensorQuantization(
|
||||||
|
"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
|
||||||
|
[]int32{64, 22016, 4096},
|
||||||
|
"int4",
|
||||||
|
)
|
||||||
|
if gateUp != "int4" {
|
||||||
|
t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
|
||||||
|
}
|
||||||
|
|
||||||
|
down := GetTensorQuantization(
|
||||||
|
"model.layers.1.mlp.experts.down_proj.weight",
|
||||||
|
[]int32{64, 4096, 14336},
|
||||||
|
"int4",
|
||||||
|
)
|
||||||
|
if down != "int8" {
|
||||||
|
t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedGateUp := GetTensorQuantization(
|
||||||
|
"model.language_model.layers.0.mlp.experts.gate_up_proj",
|
||||||
|
[]int32{256, 1024, 2048},
|
||||||
|
"int8",
|
||||||
|
)
|
||||||
|
if combinedGateUp != "int8" {
|
||||||
|
t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedDown := GetTensorQuantization(
|
||||||
|
"model.language_model.layers.0.mlp.experts.down_proj",
|
||||||
|
[]int32{256, 2048, 512},
|
||||||
|
"int4",
|
||||||
|
)
|
||||||
|
if combinedDown != "int8" {
|
||||||
|
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
|||||||
@@ -374,14 +374,9 @@ func (s *Server) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMSize returns the estimated VRAM usage.
|
// MemorySize returns the total and VRAM memory usage.
|
||||||
func (s *Server) VRAMSize() uint64 {
|
func (s *Server) MemorySize() (total, vram uint64) {
|
||||||
return s.vramSize
|
return s.vramSize, s.vramSize
|
||||||
}
|
|
||||||
|
|
||||||
// TotalSize returns the total memory usage.
|
|
||||||
func (s *Server) TotalSize() uint64 {
|
|
||||||
return s.vramSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMByGPU returns VRAM usage for a specific GPU.
|
// VRAMByGPU returns VRAM usage for a specific GPU.
|
||||||
|
|||||||
@@ -9,59 +9,177 @@ import (
|
|||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CacheEntry stores a single sequence
|
type kvCache struct {
|
||||||
type CacheEntry struct {
|
// For now we only support a single entry, so this is just one sequence
|
||||||
Tokens []int32
|
tokens []int32
|
||||||
Caches []cache.Cache
|
caches []cache.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
|
// cacheSession manages caches for a single pipeline run.
|
||||||
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
|
// Callers should append generated tokens to outputs and
|
||||||
if r.cache == nil {
|
// defer close to save the cache state.
|
||||||
slog.Info("Cache miss", "left", len(tokens))
|
type cacheSession struct {
|
||||||
return nil, tokens
|
cache *kvCache
|
||||||
|
inputs []int32
|
||||||
|
outputs []int32
|
||||||
|
|
||||||
|
caches []cache.Cache
|
||||||
|
remaining []int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *kvCache) free() {
|
||||||
|
for i, kv := range c.caches {
|
||||||
|
if kv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kv.Free()
|
||||||
|
c.caches[i] = nil
|
||||||
|
}
|
||||||
|
c.caches = nil
|
||||||
|
c.tokens = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *kvCache) cachesCanTrim() bool {
|
||||||
|
for _, kv := range c.caches {
|
||||||
|
if kv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !kv.CanTrim() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *kvCache) trimToPrefix(prefix int) {
|
||||||
|
for _, kv := range c.caches {
|
||||||
|
if kv == nil || !kv.CanTrim() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if trim := kv.Offset() - prefix; trim > 0 {
|
||||||
|
kv.Trim(trim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if prefix < len(c.tokens) {
|
||||||
|
c.tokens = c.tokens[:prefix]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// begin prepares caches for a new request. It finds the nearest
|
||||||
|
// matching cache or creates new caches if none match.
|
||||||
|
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||||
|
ensureCaches := func() {
|
||||||
|
if len(c.caches) != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
||||||
|
c.caches = cacheFactory.NewCaches()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.caches = make([]cache.Cache, m.NumLayers())
|
||||||
|
for i := range c.caches {
|
||||||
|
c.caches[i] = cache.NewKVCache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ensureCaches()
|
||||||
|
|
||||||
|
remaining := c.findRemaining(inputs)
|
||||||
|
ensureCaches()
|
||||||
|
|
||||||
|
return &cacheSession{
|
||||||
|
cache: c,
|
||||||
|
inputs: inputs,
|
||||||
|
caches: c.caches,
|
||||||
|
remaining: remaining,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// close saves the token state if the forward pass ran.
|
||||||
|
func (s *cacheSession) close() {
|
||||||
|
if len(s.caches) == 0 {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find longest common prefix
|
offset := -1
|
||||||
|
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
||||||
|
for _, kv := range s.caches {
|
||||||
|
if kv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if off := kv.Offset(); offset < 0 || off < offset {
|
||||||
|
offset = off
|
||||||
|
}
|
||||||
|
arrays = append(arrays, kv.Materialize()...)
|
||||||
|
}
|
||||||
|
if offset <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that if we have run the forward pass and set the metadata
|
||||||
|
// that we also actually have the data.
|
||||||
|
mlx.AsyncEval(arrays...)
|
||||||
|
|
||||||
|
stored := append(s.inputs, s.outputs...)
|
||||||
|
if offset > len(stored) {
|
||||||
|
offset = len(stored)
|
||||||
|
}
|
||||||
|
s.cache.tokens = stored[:offset]
|
||||||
|
}
|
||||||
|
|
||||||
|
// findRemaining finds the longest common prefix between tokens and the cached
|
||||||
|
// sequence, trims stale cache entries, and returns the remaining tokens.
|
||||||
|
func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
||||||
prefix := 0
|
prefix := 0
|
||||||
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
|
for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
|
||||||
prefix++
|
prefix++
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
// Always keep at least one token to re-evaluate so the
|
||||||
case prefix == 0:
|
// pipeline can seed token generation from it.
|
||||||
for _, c := range r.cache.Caches {
|
if prefix == len(tokens) && prefix > 0 {
|
||||||
c.Free()
|
prefix--
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix < len(c.tokens) {
|
||||||
|
if c.cachesCanTrim() {
|
||||||
|
c.trimToPrefix(prefix)
|
||||||
|
} else {
|
||||||
|
c.free()
|
||||||
|
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
|
||||||
|
return tokens
|
||||||
}
|
}
|
||||||
r.cache = nil
|
}
|
||||||
|
|
||||||
|
if prefix == 0 {
|
||||||
slog.Info("Cache miss", "left", len(tokens))
|
slog.Info("Cache miss", "left", len(tokens))
|
||||||
return nil, tokens
|
} else {
|
||||||
case prefix < len(r.cache.Tokens):
|
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
|
||||||
trim := len(r.cache.Tokens) - prefix
|
|
||||||
for _, c := range r.cache.Caches {
|
|
||||||
c.Trim(trim)
|
|
||||||
}
|
|
||||||
r.cache.Tokens = r.cache.Tokens[:prefix]
|
|
||||||
}
|
}
|
||||||
|
return tokens[prefix:]
|
||||||
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
|
|
||||||
return r.cache.Caches, tokens[prefix:]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
func (c *kvCache) log() {
|
||||||
r.cache = &CacheEntry{
|
if len(c.caches) == 0 {
|
||||||
Tokens: tokens,
|
return
|
||||||
Caches: caches,
|
|
||||||
}
|
}
|
||||||
}
|
offset := -1
|
||||||
|
|
||||||
func (c *CacheEntry) LogCache() {
|
|
||||||
var totalBytes int
|
var totalBytes int
|
||||||
for _, kv := range c.Caches {
|
for _, kv := range c.caches {
|
||||||
k, v := kv.State()
|
if kv == nil {
|
||||||
totalBytes += k.NumBytes() + v.NumBytes()
|
continue
|
||||||
|
}
|
||||||
|
if off := kv.Offset(); offset < 0 || off < offset {
|
||||||
|
offset = off
|
||||||
|
}
|
||||||
|
for _, a := range kv.Materialize() {
|
||||||
|
totalBytes += a.NumBytes()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
if offset < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
|
||||||
}
|
}
|
||||||
|
|||||||
18
x/mlxrunner/cache/cache.go
vendored
18
x/mlxrunner/cache/cache.go
vendored
@@ -10,6 +10,8 @@ import (
|
|||||||
type Cache interface {
|
type Cache interface {
|
||||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||||
State() (keys, values *mlx.Array)
|
State() (keys, values *mlx.Array)
|
||||||
|
Materialize() []*mlx.Array
|
||||||
|
CanTrim() bool
|
||||||
Trim(int) int
|
Trim(int) int
|
||||||
Clone() Cache
|
Clone() Cache
|
||||||
Free()
|
Free()
|
||||||
@@ -67,6 +69,20 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
|||||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Materialize returns the backing key/value buffers currently held by the cache.
|
||||||
|
func (c *KVCache) Materialize() []*mlx.Array {
|
||||||
|
out := make([]*mlx.Array, 0, 2)
|
||||||
|
if c.keys != nil && c.keys.Valid() {
|
||||||
|
out = append(out, c.keys)
|
||||||
|
}
|
||||||
|
if c.values != nil && c.values.Valid() {
|
||||||
|
out = append(out, c.values)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *KVCache) CanTrim() bool { return true }
|
||||||
|
|
||||||
func (c *KVCache) Trim(n int) int {
|
func (c *KVCache) Trim(n int) int {
|
||||||
n = min(c.offset, n)
|
n = min(c.offset, n)
|
||||||
c.offset -= n
|
c.offset -= n
|
||||||
@@ -190,6 +206,8 @@ func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
|
|||||||
return c.keys, c.values
|
return c.keys, c.values
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *RotatingKVCache) CanTrim() bool { return true }
|
||||||
|
|
||||||
func (c *RotatingKVCache) Trim(n int) int {
|
func (c *RotatingKVCache) Trim(n int) int {
|
||||||
n = min(c.offset, n)
|
n = min(c.offset, n)
|
||||||
c.offset -= n
|
c.offset -= n
|
||||||
|
|||||||
220
x/mlxrunner/cache/recurrent.go
vendored
Normal file
220
x/mlxrunner/cache/recurrent.go
vendored
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
|
||||||
|
// RecurrentCache stores state for linear-recurrent layers.
|
||||||
|
//
|
||||||
|
// Conv state shape: [B, convTail, convDim]
|
||||||
|
// Delta state shape: [B, numVHeads, headVDim, headKDim]
|
||||||
|
type RecurrentCache struct {
|
||||||
|
convState *mlx.Array
|
||||||
|
deltaState *mlx.Array
|
||||||
|
offset int
|
||||||
|
|
||||||
|
convTail int
|
||||||
|
convDim int
|
||||||
|
numVHeads int
|
||||||
|
headVDim int
|
||||||
|
headKDim int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
|
||||||
|
if v == nil || !v.Valid() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if *dst == v {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Break dependency chains so recurrent state does not retain the full
|
||||||
|
// per-token compute graph over time.
|
||||||
|
snap := mlx.Snapshot(v)
|
||||||
|
mlx.Eval(snap)
|
||||||
|
|
||||||
|
old := *dst
|
||||||
|
*dst = snap
|
||||||
|
mlx.Pin(snap)
|
||||||
|
|
||||||
|
// Drop references to the previous cached state root and transient incoming
|
||||||
|
// graph root now that a detached snapshot is retained in cache. Actual
|
||||||
|
// cleanup happens at the runner's normal sweep points.
|
||||||
|
if old != nil && old != snap {
|
||||||
|
mlx.Unpin(old)
|
||||||
|
}
|
||||||
|
if v != snap && v != old {
|
||||||
|
mlx.Unpin(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
|
||||||
|
if v == nil || !v.Valid() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if *dst == v {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
old := *dst
|
||||||
|
*dst = v
|
||||||
|
mlx.Pin(v)
|
||||||
|
if old != nil && old != v {
|
||||||
|
mlx.Unpin(old)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) setStateDetached(dst **mlx.Array, v *mlx.Array, ensureContiguous bool) {
|
||||||
|
if v == nil || !v.Valid() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if *dst == v {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
root := v
|
||||||
|
if ensureContiguous {
|
||||||
|
root = mlx.Contiguous(v, false)
|
||||||
|
}
|
||||||
|
detached := mlx.Detach(root)
|
||||||
|
|
||||||
|
old := *dst
|
||||||
|
*dst = detached
|
||||||
|
mlx.Pin(detached)
|
||||||
|
if old != nil && old != detached {
|
||||||
|
mlx.Unpin(old)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intentionally do not force-release root/v here. In the fast path, the detached
|
||||||
|
// handle aliases the same MLX value and may still be lazily computed. Releasing the
|
||||||
|
// source handles can invalidate the cached state before the next eval/sweep point.
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotPinned(a *mlx.Array) *mlx.Array {
|
||||||
|
if a == nil || !a.Valid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
snap := mlx.Snapshot(a)
|
||||||
|
mlx.Eval(snap)
|
||||||
|
mlx.Pin(snap)
|
||||||
|
return snap
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||||
|
return &RecurrentCache{
|
||||||
|
convTail: int(convTail),
|
||||||
|
convDim: int(convDim),
|
||||||
|
numVHeads: int(numVHeads),
|
||||||
|
headVDim: int(headVDim),
|
||||||
|
headKDim: int(headKDim),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
|
||||||
|
if batch <= 0 {
|
||||||
|
batch = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
|
||||||
|
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
|
||||||
|
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
|
||||||
|
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
|
||||||
|
if !needConv && !needDelta {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if needConv {
|
||||||
|
c.setStateRaw(&c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
|
||||||
|
}
|
||||||
|
if needDelta {
|
||||||
|
c.setStateRaw(&c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
|
||||||
|
c.ensure(batch, dtype)
|
||||||
|
return c.convState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
|
||||||
|
c.setStateMaterialized(&c.convState, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConvStateFast stores conv state without forcing an immediate snapshot/eval.
|
||||||
|
// Use only for decode hot paths that accept higher transient memory until the next
|
||||||
|
// sync/sweep point. The conv-state input is usually a slice view, so request a
|
||||||
|
// compact contiguous copy to avoid pinning the whole source buffer.
|
||||||
|
func (c *RecurrentCache) SetConvStateFast(v *mlx.Array) {
|
||||||
|
c.setStateDetached(&c.convState, v, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
|
||||||
|
c.ensure(batch, dtype)
|
||||||
|
return c.deltaState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
|
||||||
|
c.setStateMaterialized(&c.deltaState, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeltaStateFast stores delta state without forcing an immediate snapshot/eval.
|
||||||
|
// Use only for decode hot paths that accept higher transient memory until the next
|
||||||
|
// sync/sweep point.
|
||||||
|
func (c *RecurrentCache) SetDeltaStateFast(v *mlx.Array) {
|
||||||
|
c.setStateDetached(&c.deltaState, v, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Advance(n int) {
|
||||||
|
c.offset += n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||||
|
return keys, values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
|
||||||
|
return c.convState, c.deltaState
|
||||||
|
}
|
||||||
|
|
||||||
|
// Materialize returns the recurrent state roots (conv and delta) held by the cache.
|
||||||
|
func (c *RecurrentCache) Materialize() []*mlx.Array {
|
||||||
|
out := make([]*mlx.Array, 0, 2)
|
||||||
|
if c.convState != nil && c.convState.Valid() {
|
||||||
|
out = append(out, c.convState)
|
||||||
|
}
|
||||||
|
if c.deltaState != nil && c.deltaState.Valid() {
|
||||||
|
out = append(out, c.deltaState)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) CanTrim() bool { return false }
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Trim(n int) int {
|
||||||
|
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
|
||||||
|
_ = n
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Clone() Cache {
|
||||||
|
clone := &RecurrentCache{
|
||||||
|
offset: c.offset,
|
||||||
|
convTail: c.convTail,
|
||||||
|
convDim: c.convDim,
|
||||||
|
numVHeads: c.numVHeads,
|
||||||
|
headVDim: c.headVDim,
|
||||||
|
headKDim: c.headKDim,
|
||||||
|
convState: snapshotPinned(c.convState),
|
||||||
|
deltaState: snapshotPinned(c.deltaState),
|
||||||
|
}
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Free() {
|
||||||
|
mlx.Unpin(c.convState, c.deltaState)
|
||||||
|
c.convState, c.deltaState = nil, nil
|
||||||
|
c.offset = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||||
|
func (c *RecurrentCache) Len() int { return c.offset }
|
||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -19,25 +18,27 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/x/imagegen"
|
||||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
port int
|
port int
|
||||||
modelName string
|
modelName string
|
||||||
vramSize uint64
|
contextLength atomic.Int64
|
||||||
done chan error
|
memory atomic.Uint64
|
||||||
client *http.Client
|
done chan error
|
||||||
lastErr string
|
client *http.Client
|
||||||
lastErrLock sync.Mutex
|
lastErr string
|
||||||
mu sync.Mutex
|
lastErrLock sync.Mutex
|
||||||
cmd *exec.Cmd
|
mu sync.Mutex
|
||||||
|
cmd *exec.Cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
|
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
|
||||||
@@ -98,18 +99,9 @@ func NewClient(modelName string) (*Client, error) {
|
|||||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Estimate VRAM based on tensor size from manifest
|
|
||||||
var vramSize uint64
|
|
||||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
|
||||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
|
||||||
} else {
|
|
||||||
vramSize = 8 * 1024 * 1024 * 1024
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &Client{
|
c := &Client{
|
||||||
port: port,
|
port: port,
|
||||||
modelName: modelName,
|
modelName: modelName,
|
||||||
vramSize: vramSize,
|
|
||||||
done: make(chan error, 1),
|
done: make(chan error, 1),
|
||||||
client: &http.Client{Timeout: 10 * time.Minute},
|
client: &http.Client{Timeout: 10 * time.Minute},
|
||||||
cmd: cmd,
|
cmd: cmd,
|
||||||
@@ -190,15 +182,34 @@ func (c *Client) waitUntilRunning() error {
|
|||||||
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
||||||
type completionRequest struct {
|
type completionRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
Think *bool `json:"think,omitempty"`
|
||||||
Options *completionOpts `json:"options,omitempty"`
|
Options *completionOpts `json:"options,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type completionOpts struct {
|
type completionOpts struct {
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
Temperature *float32 `json:"temperature,omitempty"`
|
||||||
TopP float32 `json:"top_p,omitempty"`
|
TopP *float32 `json:"top_p,omitempty"`
|
||||||
MinP float32 `json:"min_p,omitempty"`
|
MinP *float32 `json:"min_p,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK *int `json:"top_k,omitempty"`
|
||||||
NumPredict int `json:"num_predict,omitempty"`
|
RepeatLastN *int `json:"repeat_last_n,omitempty"`
|
||||||
|
RepeatPenalty *float32 `json:"repeat_penalty,omitempty"`
|
||||||
|
PresencePenalty *float32 `json:"presence_penalty,omitempty"`
|
||||||
|
FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
|
||||||
|
NumPredict int `json:"num_predict,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionResponse struct {
|
||||||
|
Content string
|
||||||
|
Done bool
|
||||||
|
DoneReason int
|
||||||
|
|
||||||
|
PromptEvalCount int
|
||||||
|
PromptEvalDuration time.Duration
|
||||||
|
EvalCount int
|
||||||
|
EvalDuration time.Duration
|
||||||
|
PeakMemory uint64
|
||||||
|
|
||||||
|
Error *api.StatusError
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close terminates the subprocess.
|
// Close terminates the subprocess.
|
||||||
@@ -222,16 +233,27 @@ func (c *Client) Close() error {
|
|||||||
|
|
||||||
// Completion implements llm.LlamaServer.
|
// Completion implements llm.LlamaServer.
|
||||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||||
|
var think *bool
|
||||||
|
if req.Think != nil {
|
||||||
|
enabled := req.Think.Bool()
|
||||||
|
think = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
creq := completionRequest{
|
creq := completionRequest{
|
||||||
Prompt: req.Prompt,
|
Prompt: req.Prompt,
|
||||||
|
Think: think,
|
||||||
}
|
}
|
||||||
if req.Options != nil {
|
if req.Options != nil {
|
||||||
creq.Options = &completionOpts{
|
creq.Options = &completionOpts{
|
||||||
Temperature: req.Options.Temperature,
|
Temperature: float32Ptr(req.Options.Temperature, hasExplicitOption(req.ExplicitOptions, "temperature")),
|
||||||
TopP: req.Options.TopP,
|
TopP: float32Ptr(req.Options.TopP, hasExplicitOption(req.ExplicitOptions, "top_p")),
|
||||||
MinP: req.Options.MinP,
|
MinP: float32Ptr(req.Options.MinP, hasExplicitOption(req.ExplicitOptions, "min_p")),
|
||||||
TopK: req.Options.TopK,
|
TopK: intPtr(req.Options.TopK, hasExplicitOption(req.ExplicitOptions, "top_k")),
|
||||||
NumPredict: req.Options.NumPredict,
|
RepeatLastN: intPtr(req.Options.RepeatLastN, hasExplicitOption(req.ExplicitOptions, "repeat_last_n")),
|
||||||
|
RepeatPenalty: float32Ptr(req.Options.RepeatPenalty, hasExplicitOption(req.ExplicitOptions, "repeat_penalty")),
|
||||||
|
PresencePenalty: float32Ptr(req.Options.PresencePenalty, hasExplicitOption(req.ExplicitOptions, "presence_penalty")),
|
||||||
|
FrequencyPenalty: float32Ptr(req.Options.FrequencyPenalty, hasExplicitOption(req.ExplicitOptions, "frequency_penalty")),
|
||||||
|
NumPredict: req.Options.NumPredict,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,28 +282,25 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
var raw struct {
|
var raw CompletionResponse
|
||||||
Content string `json:"content,omitempty"`
|
|
||||||
Done bool `json:"done"`
|
|
||||||
DoneReason int `json:"done_reason,omitempty"`
|
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
|
||||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
|
||||||
EvalCount int `json:"eval_count,omitempty"`
|
|
||||||
EvalDuration int `json:"eval_duration,omitempty"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if raw.Error != nil {
|
||||||
|
return *raw.Error
|
||||||
|
}
|
||||||
|
|
||||||
cresp := llm.CompletionResponse{
|
cresp := llm.CompletionResponse{
|
||||||
Content: raw.Content,
|
Content: raw.Content,
|
||||||
Done: raw.Done,
|
Done: raw.Done,
|
||||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||||
PromptEvalCount: raw.PromptEvalCount,
|
PromptEvalCount: raw.PromptEvalCount,
|
||||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
PromptEvalDuration: raw.PromptEvalDuration,
|
||||||
EvalCount: raw.EvalCount,
|
EvalCount: raw.EvalCount,
|
||||||
EvalDuration: time.Duration(raw.EvalDuration),
|
EvalDuration: raw.EvalDuration,
|
||||||
|
PeakMemory: raw.PeakMemory,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(cresp)
|
fn(cresp)
|
||||||
@@ -293,8 +312,27 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
return scanner.Err()
|
return scanner.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasExplicitOption(explicit map[string]struct{}, key string) bool {
|
||||||
|
_, ok := explicit[key]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func float32Ptr(v float32, ok bool) *float32 {
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
func intPtr(v int, ok bool) *int {
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) ContextLength() int {
|
func (c *Client) ContextLength() int {
|
||||||
return math.MaxInt
|
return int(c.contextLength.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detokenize implements llm.LlamaServer.
|
// Detokenize implements llm.LlamaServer.
|
||||||
@@ -347,9 +385,16 @@ func (c *Client) Pid() int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type statusResponse struct {
|
||||||
|
Status int
|
||||||
|
Progress int
|
||||||
|
ContextLength int
|
||||||
|
Memory uint64
|
||||||
|
}
|
||||||
|
|
||||||
// Ping implements llm.LlamaServer.
|
// Ping implements llm.LlamaServer.
|
||||||
func (c *Client) Ping(ctx context.Context) error {
|
func (c *Client) Ping(ctx context.Context) error {
|
||||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
|
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -362,6 +407,15 @@ func (c *Client) Ping(ctx context.Context) error {
|
|||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var status statusResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.contextLength.Store(int64(status.ContextLength))
|
||||||
|
c.memory.Store(status.Memory)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,19 +442,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
|||||||
return tokens, nil
|
return tokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TotalSize implements llm.LlamaServer.
|
func (c *Client) currentMemory() uint64 {
|
||||||
func (c *Client) TotalSize() uint64 {
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
return c.vramSize
|
defer cancel()
|
||||||
|
if err := c.Ping(ctx); err != nil {
|
||||||
|
slog.Warn("failed to get current memory", "error", err)
|
||||||
|
}
|
||||||
|
return c.memory.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemorySize implements llm.LlamaServer.
|
||||||
|
func (c *Client) MemorySize() (total, vram uint64) {
|
||||||
|
mem := c.currentMemory()
|
||||||
|
return mem, mem
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMByGPU implements llm.LlamaServer.
|
// VRAMByGPU implements llm.LlamaServer.
|
||||||
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||||
return c.vramSize
|
return c.currentMemory()
|
||||||
}
|
|
||||||
|
|
||||||
// VRAMSize implements llm.LlamaServer.
|
|
||||||
func (c *Client) VRAMSize() uint64 {
|
|
||||||
return c.vramSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WaitUntilRunning implements llm.LlamaServer.
|
// WaitUntilRunning implements llm.LlamaServer.
|
||||||
|
|||||||
167
x/mlxrunner/client_test.go
Normal file
167
x/mlxrunner/client_test.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package mlxrunner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCompletionForwardsThink(t *testing.T) {
|
||||||
|
boolPtr := func(v bool) *bool { return &v }
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
think *api.ThinkValue
|
||||||
|
want *bool
|
||||||
|
}{
|
||||||
|
{name: "unset", think: nil, want: nil},
|
||||||
|
{name: "enabled", think: &api.ThinkValue{Value: true}, want: boolPtr(true)},
|
||||||
|
{name: "disabled", think: &api.ThinkValue{Value: false}, want: boolPtr(false)},
|
||||||
|
{name: "level maps to enabled", think: &api.ThinkValue{Value: "high"}, want: boolPtr(true)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var got completionRequest
|
||||||
|
|
||||||
|
rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
if r.URL.Path != "/completion" {
|
||||||
|
t.Fatalf("request path = %q, want %q", r.URL.Path, "/completion")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("{\"done\":true}\n")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
c := &Client{
|
||||||
|
port: 11434,
|
||||||
|
client: &http.Client{
|
||||||
|
Transport: rt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.Completion(context.Background(), llm.CompletionRequest{
|
||||||
|
Prompt: "hello",
|
||||||
|
Think: tc.think,
|
||||||
|
}, func(llm.CompletionResponse) {})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("completion request failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Prompt != "hello" {
|
||||||
|
t.Fatalf("prompt = %q, want %q", got.Prompt, "hello")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case tc.want == nil && got.Think != nil:
|
||||||
|
t.Fatalf("think = %v, want nil", *got.Think)
|
||||||
|
case tc.want != nil && got.Think == nil:
|
||||||
|
t.Fatalf("think = nil, want %v", *tc.want)
|
||||||
|
case tc.want != nil && got.Think != nil && *tc.want != *got.Think:
|
||||||
|
t.Fatalf("think = %v, want %v", *got.Think, *tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompletionForwardsOnlySpecifiedSamplingOptions(t *testing.T) {
|
||||||
|
var got completionRequest
|
||||||
|
|
||||||
|
rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("{\"done\":true}\n")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
c := &Client{
|
||||||
|
port: 11434,
|
||||||
|
client: &http.Client{
|
||||||
|
Transport: rt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := &api.Options{
|
||||||
|
Temperature: 1.0,
|
||||||
|
TopP: 0.95,
|
||||||
|
MinP: 0.1,
|
||||||
|
TopK: 20,
|
||||||
|
RepeatLastN: 128,
|
||||||
|
RepeatPenalty: 1.2,
|
||||||
|
PresencePenalty: 1.5,
|
||||||
|
FrequencyPenalty: 0.25,
|
||||||
|
NumPredict: 64,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.Completion(context.Background(), llm.CompletionRequest{
|
||||||
|
Prompt: "hello",
|
||||||
|
Options: opts,
|
||||||
|
ExplicitOptions: map[string]struct{}{
|
||||||
|
"temperature": {},
|
||||||
|
"top_k": {},
|
||||||
|
"repeat_penalty": {},
|
||||||
|
"presence_penalty": {},
|
||||||
|
},
|
||||||
|
}, func(llm.CompletionResponse) {})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("completion request failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Options == nil {
|
||||||
|
t.Fatal("options = nil, want serialized options")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Options.Temperature == nil || *got.Options.Temperature != opts.Temperature {
|
||||||
|
t.Fatalf("temperature = %v, want %v", got.Options.Temperature, opts.Temperature)
|
||||||
|
}
|
||||||
|
if got.Options.TopK == nil || *got.Options.TopK != opts.TopK {
|
||||||
|
t.Fatalf("top_k = %v, want %v", got.Options.TopK, opts.TopK)
|
||||||
|
}
|
||||||
|
if got.Options.RepeatPenalty == nil || *got.Options.RepeatPenalty != opts.RepeatPenalty {
|
||||||
|
t.Fatalf("repeat_penalty = %v, want %v", got.Options.RepeatPenalty, opts.RepeatPenalty)
|
||||||
|
}
|
||||||
|
if got.Options.PresencePenalty == nil || *got.Options.PresencePenalty != opts.PresencePenalty {
|
||||||
|
t.Fatalf("presence_penalty = %v, want %v", got.Options.PresencePenalty, opts.PresencePenalty)
|
||||||
|
}
|
||||||
|
if got.Options.TopP != nil {
|
||||||
|
t.Fatalf("top_p = %v, want nil", *got.Options.TopP)
|
||||||
|
}
|
||||||
|
if got.Options.MinP != nil {
|
||||||
|
t.Fatalf("min_p = %v, want nil", *got.Options.MinP)
|
||||||
|
}
|
||||||
|
if got.Options.RepeatLastN != nil {
|
||||||
|
t.Fatalf("repeat_last_n = %v, want nil", *got.Options.RepeatLastN)
|
||||||
|
}
|
||||||
|
if got.Options.FrequencyPenalty != nil {
|
||||||
|
t.Fatalf("frequency_penalty = %v, want nil", *got.Options.FrequencyPenalty)
|
||||||
|
}
|
||||||
|
if got.Options.NumPredict != opts.NumPredict {
|
||||||
|
t.Fatalf("num_predict = %d, want %d", got.Options.NumPredict, opts.NumPredict)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
return f(r)
|
||||||
|
}
|
||||||
@@ -7,4 +7,6 @@ import (
|
|||||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||||
_ "github.com/ollama/ollama/x/models/llama"
|
_ "github.com/ollama/ollama/x/models/llama"
|
||||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||||
|
_ "github.com/ollama/ollama/x/models/qwen3_5"
|
||||||
|
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
|
||||||
)
|
)
|
||||||
|
|||||||
275
x/mlxrunner/mlx/gated_delta_metal.go
Normal file
275
x/mlxrunner/mlx/gated_delta_metal.go
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package mlx
|
||||||
|
|
||||||
|
// #include <stdlib.h>
|
||||||
|
// #include "generated.h"
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
gatedDeltaMetalKernelOnce sync.Once
|
||||||
|
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||||
|
gatedDeltaMetalDisabled atomic.Bool
|
||||||
|
)
|
||||||
|
|
||||||
|
const gatedDeltaMetalKernelSource = `
|
||||||
|
auto n = thread_position_in_grid.z;
|
||||||
|
auto b_idx = n / Hv;
|
||||||
|
auto hv_idx = n % Hv;
|
||||||
|
auto hk_idx = hv_idx / (Hv / Hk);
|
||||||
|
constexpr int n_per_t = Dk / 32;
|
||||||
|
|
||||||
|
// q, k: [B, T, Hk, Dk]
|
||||||
|
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||||
|
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||||
|
|
||||||
|
// v, y: [B, T, Hv, Dv]
|
||||||
|
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||||
|
y += b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||||
|
|
||||||
|
auto dk_idx = thread_position_in_threadgroup.x;
|
||||||
|
auto dv_idx = thread_position_in_grid.y;
|
||||||
|
|
||||||
|
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||||
|
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||||
|
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||||
|
|
||||||
|
float state[n_per_t];
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = static_cast<float>(i_state[s_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// g: [B, T, Hv]
|
||||||
|
auto g_ = g + b_idx * T * Hv;
|
||||||
|
auto beta_ = beta + b_idx * T * Hv;
|
||||||
|
|
||||||
|
for (int t = 0; t < T; ++t) {
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = state[i] * g_[hv_idx];
|
||||||
|
kv_mem += state[i] * k_[s_idx];
|
||||||
|
}
|
||||||
|
kv_mem = simd_sum(kv_mem);
|
||||||
|
|
||||||
|
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
|
||||||
|
|
||||||
|
float out = 0.0f;
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = state[i] + k_[s_idx] * delta;
|
||||||
|
out += state[i] * q_[s_idx];
|
||||||
|
}
|
||||||
|
out = simd_sum(out);
|
||||||
|
if (thread_index_in_simdgroup == 0) {
|
||||||
|
y[dv_idx] = static_cast<InT>(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
q_ += Hk * Dk;
|
||||||
|
k_ += Hk * Dk;
|
||||||
|
v_ += Hv * Dv;
|
||||||
|
y += Hv * Dv;
|
||||||
|
g_ += Hv;
|
||||||
|
beta_ += Hv;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
o_state[s_idx] = static_cast<InT>(state[i]);
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||||
|
vec := C.mlx_vector_string_new()
|
||||||
|
ok := true
|
||||||
|
for _, s := range values {
|
||||||
|
cs := C.CString(s)
|
||||||
|
if C.mlx_vector_string_append_value(vec, cs) != 0 {
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
C.free(unsafe.Pointer(cs))
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cleanup := func() {
|
||||||
|
C.mlx_vector_string_free(vec)
|
||||||
|
}
|
||||||
|
return vec, cleanup, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func initGatedDeltaMetalKernel() {
|
||||||
|
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||||
|
if !ok {
|
||||||
|
gatedDeltaMetalDisabled.Store(true)
|
||||||
|
freeInputs()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer freeInputs()
|
||||||
|
|
||||||
|
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||||
|
if !ok {
|
||||||
|
gatedDeltaMetalDisabled.Store(true)
|
||||||
|
freeOutputs()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer freeOutputs()
|
||||||
|
|
||||||
|
cName := C.CString("gated_delta_step")
|
||||||
|
defer C.free(unsafe.Pointer(cName))
|
||||||
|
cSource := C.CString(gatedDeltaMetalKernelSource)
|
||||||
|
defer C.free(unsafe.Pointer(cSource))
|
||||||
|
cHeader := C.CString("")
|
||||||
|
defer C.free(unsafe.Pointer(cHeader))
|
||||||
|
|
||||||
|
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
|
||||||
|
cName,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
cSource,
|
||||||
|
cHeader,
|
||||||
|
C.bool(true),
|
||||||
|
C.bool(false),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
|
||||||
|
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
|
||||||
|
func GatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||||
|
if gatedDeltaMetalDisabled.Load() {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if !q.Valid() || !k.Valid() || !v.Valid() || !g.Valid() || !beta.Valid() || !state.Valid() {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
qd := q.Dims()
|
||||||
|
kd := k.Dims()
|
||||||
|
vd := v.Dims()
|
||||||
|
gd := g.Dims()
|
||||||
|
bd := beta.Dims()
|
||||||
|
sd := state.Dims()
|
||||||
|
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||||
|
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
Hv, Dv := vd[2], vd[3]
|
||||||
|
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
dtype := q.DType()
|
||||||
|
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
|
||||||
|
if gatedDeltaMetalDisabled.Load() {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := C.mlx_fast_metal_kernel_config_new()
|
||||||
|
defer C.mlx_fast_metal_kernel_config_free(cfg)
|
||||||
|
|
||||||
|
cInT := C.CString("InT")
|
||||||
|
defer C.free(unsafe.Pointer(cInT))
|
||||||
|
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaMetalDisabled.Store(true)
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
for _, tpl := range []struct {
|
||||||
|
name string
|
||||||
|
value int
|
||||||
|
}{
|
||||||
|
{name: "Dk", value: Dk},
|
||||||
|
{name: "Dv", value: Dv},
|
||||||
|
{name: "Hk", value: Hk},
|
||||||
|
{name: "Hv", value: Hv},
|
||||||
|
} {
|
||||||
|
cn := C.CString(tpl.name)
|
||||||
|
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||||
|
C.free(unsafe.Pointer(cn))
|
||||||
|
if rc != 0 {
|
||||||
|
gatedDeltaMetalDisabled.Store(true)
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||||
|
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||||
|
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaMetalDisabled.Store(true)
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaMetalDisabled.Store(true)
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||||
|
gatedDeltaMetalDisabled.Store(true)
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
threadY := Dv
|
||||||
|
if threadY > 4 {
|
||||||
|
threadY = 4
|
||||||
|
}
|
||||||
|
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||||
|
gatedDeltaMetalDisabled.Store(true)
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
tScalar := FromValue(T)
|
||||||
|
inputs := []C.mlx_array{
|
||||||
|
q.ctx,
|
||||||
|
k.ctx,
|
||||||
|
v.ctx,
|
||||||
|
g.ctx,
|
||||||
|
beta.ctx,
|
||||||
|
state.ctx,
|
||||||
|
tScalar.ctx,
|
||||||
|
}
|
||||||
|
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||||
|
defer C.mlx_vector_array_free(inVec)
|
||||||
|
|
||||||
|
outVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(outVec)
|
||||||
|
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||||
|
gatedDeltaMetalDisabled.Store(true)
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
y = New("GATED_DELTA_METAL_Y")
|
||||||
|
nextState = New("GATED_DELTA_METAL_STATE")
|
||||||
|
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||||
|
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||||
|
return y, nextState, true
|
||||||
|
}
|
||||||
@@ -64,6 +64,10 @@ func PeakMemory() int {
|
|||||||
return int(peak)
|
return int(peak)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ResetPeakMemory() {
|
||||||
|
C.mlx_reset_peak_memory()
|
||||||
|
}
|
||||||
|
|
||||||
type Memory struct{}
|
type Memory struct{}
|
||||||
|
|
||||||
func (Memory) LogValue() slog.Value {
|
func (Memory) LogValue() slog.Value {
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func doEval(outputs []*Array, async bool) {
|
|||||||
defer C.mlx_vector_array_free(vector)
|
defer C.mlx_vector_array_free(vector)
|
||||||
|
|
||||||
for _, output := range outputs {
|
for _, output := range outputs {
|
||||||
if output.Valid() {
|
if output != nil && output.Valid() {
|
||||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -93,6 +93,12 @@ func (t *Array) Divide(other *Array) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
|
||||||
|
out := New("CUMSUM")
|
||||||
|
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Array) ExpandDims(axis int) *Array {
|
func (t *Array) ExpandDims(axis int) *Array {
|
||||||
out := New("EXPAND_DIMS")
|
out := New("EXPAND_DIMS")
|
||||||
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||||
@@ -123,12 +129,30 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Array) GreaterEqual(other *Array) *Array {
|
||||||
|
out := New("GREATER_EQUAL")
|
||||||
|
C.mlx_greater_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Array) Logsumexp(keepDims bool) *Array {
|
func (t *Array) Logsumexp(keepDims bool) *Array {
|
||||||
out := New("LOGSUMEXP")
|
out := New("LOGSUMEXP")
|
||||||
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
|
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Array) Less(other *Array) *Array {
|
||||||
|
out := New("LESS")
|
||||||
|
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Array) LogicalOr(other *Array) *Array {
|
||||||
|
out := New("LOGICAL_OR")
|
||||||
|
C.mlx_logical_or(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Array) Matmul(other *Array) *Array {
|
func (t *Array) Matmul(other *Array) *Array {
|
||||||
out := New("MATMUL")
|
out := New("MATMUL")
|
||||||
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||||
|
|||||||
@@ -113,6 +113,35 @@ func Where(condition, a, b *Array) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
|
||||||
|
out := New("CONV1D")
|
||||||
|
C.mlx_conv1d(
|
||||||
|
&out.ctx,
|
||||||
|
x.ctx,
|
||||||
|
weight.ctx,
|
||||||
|
C.int(stride),
|
||||||
|
C.int(padding),
|
||||||
|
C.int(dilation),
|
||||||
|
C.int(groups),
|
||||||
|
DefaultStream().ctx,
|
||||||
|
)
|
||||||
|
if bias != nil && bias.Valid() {
|
||||||
|
out = Add(out, bias)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||||
|
out := New("CONTIGUOUS")
|
||||||
|
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||||
|
groups := int32(x.Dim(x.NumDims() - 1))
|
||||||
|
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||||
|
}
|
||||||
|
|
||||||
// Convenience wrappers (function-style for the model code)
|
// Convenience wrappers (function-style for the model code)
|
||||||
|
|
||||||
func Stack(arrays []*Array, axis int) *Array {
|
func Stack(arrays []*Array, axis int) *Array {
|
||||||
@@ -271,6 +300,24 @@ func Sigmoid(a *Array) *Array {
|
|||||||
return a.Sigmoid()
|
return a.Sigmoid()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Exp(a *Array) *Array {
|
||||||
|
out := New("EXP")
|
||||||
|
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func Log(a *Array) *Array {
|
||||||
|
out := New("LOG")
|
||||||
|
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||||
|
out := New("SOFTMAX_AXIS")
|
||||||
|
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
||||||
mask := New("")
|
mask := New("")
|
||||||
sinks := New("")
|
sinks := New("")
|
||||||
@@ -288,7 +335,11 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
|||||||
|
|
||||||
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
||||||
out := New("FAST_RMSNORM")
|
out := New("FAST_RMSNORM")
|
||||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
|
var w C.mlx_array
|
||||||
|
if weight != nil {
|
||||||
|
w = weight.ctx
|
||||||
|
}
|
||||||
|
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -378,6 +429,27 @@ func Collect(v any) []*Array {
|
|||||||
return arrays
|
return arrays
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
|
||||||
|
func Snapshot(a *Array) *Array {
|
||||||
|
if a == nil || !a.Valid() {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
out := New("SNAPSHOT")
|
||||||
|
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detach returns a new Array handle that shares the same MLX value but does
|
||||||
|
// not retain Go-side graph input references.
|
||||||
|
func Detach(a *Array) *Array {
|
||||||
|
if a == nil || !a.Valid() {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
out := New("DETACH")
|
||||||
|
C.mlx_array_set(&out.ctx, a.ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
||||||
if !v.IsValid() {
|
if !v.IsValid() {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type Model interface {
|
|||||||
Unembed(x *mlx.Array) *mlx.Array
|
Unembed(x *mlx.Array) *mlx.Array
|
||||||
NumLayers() int
|
NumLayers() int
|
||||||
Tokenizer() *tokenizer.Tokenizer
|
Tokenizer() *tokenizer.Tokenizer
|
||||||
|
MaxContextLength() int
|
||||||
|
|
||||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||||
|
|||||||
@@ -6,19 +6,47 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func prefillChunkSize() int {
|
||||||
|
return 2 << 10
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||||
if r.Model == nil {
|
if r.Model == nil {
|
||||||
return errors.New("model not loaded")
|
return errors.New("model not loaded")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := request.Ctx
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
sample, logprobs *mlx.Array
|
||||||
|
nextSample, nextLogprobs *mlx.Array
|
||||||
|
)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
mlx.Unpin(sample, logprobs)
|
||||||
|
mlx.Unpin(nextSample, nextLogprobs)
|
||||||
|
mlx.Sweep()
|
||||||
|
mlx.ClearCache()
|
||||||
|
|
||||||
|
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
||||||
|
mlx.LogArrays()
|
||||||
|
r.cache.log()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
enableCompile := true
|
enableCompile := true
|
||||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||||
enableCompile = modelCompile.EnableCompile()
|
enableCompile = modelCompile.EnableCompile()
|
||||||
@@ -28,46 +56,72 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
} else {
|
} else {
|
||||||
mlx.DisableCompile()
|
mlx.DisableCompile()
|
||||||
}
|
}
|
||||||
|
mlx.ResetPeakMemory()
|
||||||
|
|
||||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||||
|
if len(inputs) == 0 {
|
||||||
|
return errors.New("empty prompt")
|
||||||
|
}
|
||||||
|
|
||||||
caches, tokens := r.FindNearestCache(inputs)
|
if len(inputs) >= r.contextLength {
|
||||||
if len(caches) == 0 {
|
return api.StatusError{
|
||||||
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
|
StatusCode: http.StatusBadRequest,
|
||||||
caches = cacheFactory.NewCaches()
|
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
|
||||||
} else {
|
|
||||||
caches = make([]cache.Cache, r.Model.NumLayers())
|
|
||||||
for i := range caches {
|
|
||||||
caches[i] = cache.NewKVCache()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cap generation to stay within the model's context length
|
||||||
|
maxGenerate := r.contextLength - len(inputs)
|
||||||
|
if request.Options.MaxTokens <= 0 {
|
||||||
|
request.Options.MaxTokens = maxGenerate
|
||||||
|
} else {
|
||||||
|
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
||||||
|
}
|
||||||
|
|
||||||
|
session := r.cache.begin(r.Model, inputs)
|
||||||
|
defer session.close()
|
||||||
|
caches := session.caches
|
||||||
|
tokens := session.remaining
|
||||||
|
history := append([]int32(nil), session.inputs...)
|
||||||
|
prefillChunk := prefillChunkSize()
|
||||||
|
|
||||||
|
materializeCaches := func() {
|
||||||
|
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||||
|
for _, c := range caches {
|
||||||
|
if c == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
state = append(state, c.Materialize()...)
|
||||||
|
}
|
||||||
|
if len(state) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mlx.Eval(state...)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
total, processed := len(tokens), 0
|
total, processed := len(tokens), 0
|
||||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
|
||||||
for total-processed > 1 {
|
for total-processed > 1 {
|
||||||
n := min(2<<10, total-processed-1)
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
n := min(prefillChunk, total-processed-1)
|
||||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
mlx.Eval(func() []*mlx.Array {
|
materializeCaches()
|
||||||
s := make([]*mlx.Array, 2*len(caches))
|
|
||||||
for i, c := range caches {
|
|
||||||
s[2*i], s[2*i+1] = c.State()
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}()...)
|
|
||||||
processed += n
|
processed += n
|
||||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
step := func(token *mlx.Array, history []int32) (*mlx.Array, *mlx.Array) {
|
||||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||||
logits := r.Model.Unembed(fwd)
|
logits := r.Model.Unembed(fwd)
|
||||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||||
|
|
||||||
logprobs := logits.Subtract(logits.Logsumexp(true))
|
logprobs := logits.Subtract(logits.Logsumexp(true))
|
||||||
sample := request.Sample(logprobs)
|
sample := request.Sample(logprobs, history)
|
||||||
|
|
||||||
mlx.Pin(sample, logprobs)
|
mlx.Pin(sample, logprobs)
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
@@ -76,61 +130,59 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
return sample, logprobs
|
return sample, logprobs
|
||||||
}
|
}
|
||||||
|
|
||||||
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
|
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed), history)
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
|
|
||||||
now := time.Now()
|
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
|
||||||
outputs := make([]int32, 0, request.Options.MaxTokens)
|
|
||||||
for i := range request.Options.MaxTokens {
|
for i := range request.Options.MaxTokens {
|
||||||
nextSample, nextLogprobs := step(sample)
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
|
||||||
mlx.Eval(sample)
|
mlx.Eval(sample)
|
||||||
final.PromptTokensDuration = time.Since(now)
|
final.PromptEvalDuration = time.Since(now)
|
||||||
now = time.Now()
|
now = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
output := int32(sample.Int())
|
output := int32(sample.Int())
|
||||||
outputs = append(outputs, output)
|
session.outputs = append(session.outputs, output)
|
||||||
|
history = append(history, output)
|
||||||
|
|
||||||
if r.Tokenizer.IsEOS(output) {
|
if r.Tokenizer.IsEOS(output) {
|
||||||
mlx.Unpin(nextSample, nextLogprobs)
|
|
||||||
final.Token = int(output)
|
|
||||||
final.DoneReason = 0
|
final.DoneReason = 0
|
||||||
final.CompletionTokens = i
|
final.EvalCount = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
request.Responses <- Response{
|
select {
|
||||||
Text: r.Decode(output, &b),
|
case <-request.Ctx.Done():
|
||||||
Token: int(output),
|
return request.Ctx.Err()
|
||||||
|
case request.Responses <- CompletionResponse{
|
||||||
|
Content: r.Decode(output, &b),
|
||||||
|
}:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nextSample, nextLogprobs = step(sample, history)
|
||||||
|
|
||||||
mlx.Unpin(sample, logprobs)
|
mlx.Unpin(sample, logprobs)
|
||||||
|
sample, logprobs = nextSample, nextLogprobs
|
||||||
|
nextSample, nextLogprobs = nil, nil
|
||||||
|
|
||||||
if i%256 == 0 {
|
if i%256 == 0 {
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
sample, logprobs = nextSample, nextLogprobs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mlx.Unpin(sample, logprobs)
|
final.EvalDuration = time.Since(now)
|
||||||
final.CompletionTokensDuration = time.Since(now)
|
final.PeakMemory = uint64(mlx.PeakMemory())
|
||||||
request.Responses <- final
|
select {
|
||||||
r.InsertCache(append(inputs, outputs...), caches)
|
case <-ctx.Done():
|
||||||
mlx.Sweep()
|
return ctx.Err()
|
||||||
|
case request.Responses <- final:
|
||||||
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
return nil
|
||||||
mlx.LogArrays()
|
|
||||||
if r.cache != nil {
|
|
||||||
r.cache.LogCache()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
||||||
|
|||||||
@@ -4,15 +4,15 @@ package mlxrunner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
@@ -22,46 +22,39 @@ import (
|
|||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
TextCompletionsRequest
|
TextCompletionsRequest
|
||||||
Responses chan Response
|
Responses chan CompletionResponse
|
||||||
Pipeline func(Request) error
|
Pipeline func(Request) error
|
||||||
|
|
||||||
|
Ctx context.Context
|
||||||
|
|
||||||
sample.Sampler
|
sample.Sampler
|
||||||
caches []cache.Cache
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextCompletionsRequest struct {
|
type TextCompletionsRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
Think *bool `json:"think,omitempty"`
|
||||||
Options struct {
|
Options struct {
|
||||||
Temperature float32 `json:"temperature"`
|
Temperature *float32 `json:"temperature"`
|
||||||
TopP float32 `json:"top_p"`
|
TopP *float32 `json:"top_p"`
|
||||||
MinP float32 `json:"min_p"`
|
MinP *float32 `json:"min_p"`
|
||||||
TopK int `json:"top_k"`
|
TopK *int `json:"top_k"`
|
||||||
MaxTokens int `json:"max_tokens"`
|
RepeatLastN *int `json:"repeat_last_n"`
|
||||||
|
RepeatPenalty *float32 `json:"repeat_penalty"`
|
||||||
|
PresencePenalty *float32 `json:"presence_penalty"`
|
||||||
|
FrequencyPenalty *float32 `json:"frequency_penalty"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
|
||||||
// Deprecated: use MaxTokens instead
|
// Deprecated: use MaxTokens instead
|
||||||
NumPredict int `json:"num_predict"`
|
NumPredict int `json:"num_predict"`
|
||||||
} `json:"options"`
|
} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
Text string `json:"content,omitempty"`
|
|
||||||
Token int `json:"token,omitempty"`
|
|
||||||
Logprobs []float32 `json:"logprobs,omitempty"`
|
|
||||||
Done bool `json:"done,omitempty"`
|
|
||||||
DoneReason int `json:"done_reason,omitempty"`
|
|
||||||
|
|
||||||
PromptTokens int `json:"prompt_eval_count,omitempty"`
|
|
||||||
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
|
||||||
CompletionTokens int `json:"eval_count,omitempty"`
|
|
||||||
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
|
|
||||||
TotalTokens int `json:"total_tokens,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Runner struct {
|
type Runner struct {
|
||||||
Model base.Model
|
Model base.Model
|
||||||
Tokenizer *tokenizer.Tokenizer
|
Tokenizer *tokenizer.Tokenizer
|
||||||
Requests chan Request
|
Requests chan Request
|
||||||
cache *CacheEntry
|
cache kvCache
|
||||||
|
contextLength int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Runner) Load(modelName string) error {
|
func (r *Runner) Load(modelName string) error {
|
||||||
@@ -90,6 +83,7 @@ func (r *Runner) Load(modelName string) error {
|
|||||||
|
|
||||||
r.Model = m
|
r.Model = m
|
||||||
r.Tokenizer = m.Tokenizer()
|
r.Tokenizer = m.Tokenizer()
|
||||||
|
r.contextLength = m.MaxContextLength()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +151,18 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
|||||||
return nil
|
return nil
|
||||||
case request := <-r.Requests:
|
case request := <-r.Requests:
|
||||||
if err := request.Pipeline(request); err != nil {
|
if err := request.Pipeline(request); err != nil {
|
||||||
break
|
slog.Info("Request terminated", "error", err)
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if !errors.As(err, &statusErr) {
|
||||||
|
statusErr = api.StatusError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
ErrorMessage: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||||
|
case <-request.Ctx.Done():
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
close(request.Responses)
|
close(request.Responses)
|
||||||
|
|||||||
@@ -9,69 +9,204 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Sampler interface {
|
type Sampler interface {
|
||||||
Sample(*mlx.Array) *mlx.Array
|
Sample(*mlx.Array, []int32) *mlx.Array
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(temp, top_p, min_p float32, top_k int) Sampler {
|
func New(temp, top_p, min_p float32, top_k, repeatLastN int, repeatPenalty, presencePenalty, frequencyPenalty float32) Sampler {
|
||||||
if temp == 0 {
|
|
||||||
return greedy{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var samplers []Sampler
|
var samplers []Sampler
|
||||||
if top_p > 0 && top_p < 1 {
|
if repeatLastN > 0 && (repeatPenalty != 1 || presencePenalty != 0 || frequencyPenalty != 0) {
|
||||||
samplers = append(samplers, TopP(top_p))
|
samplers = append(samplers, Penalty{
|
||||||
|
RepeatLastN: repeatLastN,
|
||||||
|
RepeatPenalty: repeatPenalty,
|
||||||
|
PresencePenalty: presencePenalty,
|
||||||
|
FrequencyPenalty: frequencyPenalty,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if min_p != 0 {
|
if temp == 0 {
|
||||||
samplers = append(samplers, MinP(min_p))
|
samplers = append(samplers, greedy{})
|
||||||
|
} else {
|
||||||
|
samplers = append(samplers, Distribution{
|
||||||
|
Temperature: temp,
|
||||||
|
TopK: top_k,
|
||||||
|
TopP: top_p,
|
||||||
|
MinP: min_p,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if top_k > 0 {
|
|
||||||
samplers = append(samplers, TopK(top_k))
|
|
||||||
}
|
|
||||||
|
|
||||||
samplers = append(samplers, Temperature(temp))
|
|
||||||
return chain(samplers)
|
return chain(samplers)
|
||||||
}
|
}
|
||||||
|
|
||||||
type greedy struct{}
|
type greedy struct{}
|
||||||
|
|
||||||
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
|
func (greedy) Sample(logits *mlx.Array, _ []int32) *mlx.Array {
|
||||||
return logits.Argmax(-1, false)
|
return logits.Argmax(-1, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
type chain []Sampler
|
type chain []Sampler
|
||||||
|
|
||||||
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
|
func (c chain) Sample(logits *mlx.Array, history []int32) *mlx.Array {
|
||||||
for _, sampler := range c {
|
for _, sampler := range c {
|
||||||
logits = sampler.Sample(logits)
|
logits = sampler.Sample(logits, history)
|
||||||
}
|
}
|
||||||
return logits
|
return logits
|
||||||
}
|
}
|
||||||
|
|
||||||
type Temperature float32
|
type Distribution struct {
|
||||||
|
Temperature float32
|
||||||
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
|
TopK int
|
||||||
return mlx.DivScalar(logits, float32(t)).Categorical(-1)
|
TopP float32
|
||||||
|
MinP float32
|
||||||
}
|
}
|
||||||
|
|
||||||
type TopP float32
|
func (d Distribution) Sample(logits *mlx.Array, _ []int32) *mlx.Array {
|
||||||
|
filtered, indices := d.filter(logits)
|
||||||
|
sample := filtered.Categorical(-1)
|
||||||
|
if indices == nil {
|
||||||
|
return sample
|
||||||
|
}
|
||||||
|
|
||||||
func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
|
positions := sample.ExpandDims(1)
|
||||||
// TODO: implement
|
return indices.TakeAlongAxis(positions, -1).Squeeze(1)
|
||||||
return logprobs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type MinP float32
|
func (d Distribution) filter(logits *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||||
|
candidates := logits
|
||||||
|
var candidateIndices *mlx.Array
|
||||||
|
|
||||||
func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
|
if d.TopK > 0 && d.TopK < logits.Dim(logits.NumDims()-1) {
|
||||||
// TODO: implement
|
partitions := logits.Negative().ArgpartitionAxis(d.TopK-1, -1)
|
||||||
return logprobs
|
switch logits.NumDims() {
|
||||||
|
case 1:
|
||||||
|
candidateIndices = partitions.Slice(mlx.Slice(0, d.TopK))
|
||||||
|
default:
|
||||||
|
candidateIndices = partitions.Slice(mlx.Slice(), mlx.Slice(0, d.TopK))
|
||||||
|
}
|
||||||
|
candidates = logits.TakeAlongAxis(candidateIndices, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.Temperature != 1 {
|
||||||
|
candidates = mlx.DivScalar(candidates, d.Temperature)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !d.needsProbabilityFilters() {
|
||||||
|
return candidates, candidateIndices
|
||||||
|
}
|
||||||
|
|
||||||
|
order := candidates.Negative().ArgsortAxis(-1)
|
||||||
|
sortedLogits := candidates.TakeAlongAxis(order, -1)
|
||||||
|
sortedProbs := mlx.SoftmaxAxis(candidates, -1, true).TakeAlongAxis(order, -1)
|
||||||
|
|
||||||
|
remove := d.topPRemovalMask(sortedProbs)
|
||||||
|
if d.MinP > 0 {
|
||||||
|
minPRemove := d.minPRemovalMask(sortedProbs)
|
||||||
|
if remove == nil {
|
||||||
|
remove = minPRemove
|
||||||
|
} else {
|
||||||
|
remove = remove.LogicalOr(minPRemove)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if remove == nil {
|
||||||
|
return candidates, candidateIndices
|
||||||
|
}
|
||||||
|
|
||||||
|
negInf := mlx.FromValue(float32(math.Inf(-1)))
|
||||||
|
filtered := mlx.Where(remove, negInf, sortedLogits)
|
||||||
|
return candidates.PutAlongAxis(order, filtered, -1), candidateIndices
|
||||||
}
|
}
|
||||||
|
|
||||||
type TopK int
|
func (d Distribution) needsProbabilityFilters() bool {
|
||||||
|
return (d.TopP > 0 && d.TopP < 1) || d.MinP > 0
|
||||||
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
|
}
|
||||||
mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
|
|
||||||
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
func (d Distribution) topPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
|
||||||
|
if d.TopP <= 0 || d.TopP >= 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
threshold := mlx.NewScalarArray(d.TopP)
|
||||||
|
prevCum := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
|
||||||
|
return prevCum.GreaterEqual(threshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Distribution) minPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
|
||||||
|
if d.MinP <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var maxProb *mlx.Array
|
||||||
|
switch sortedProbs.NumDims() {
|
||||||
|
case 1:
|
||||||
|
maxProb = sortedProbs.Slice(mlx.Slice(0, 1))
|
||||||
|
default:
|
||||||
|
maxProb = sortedProbs.Slice(mlx.Slice(), mlx.Slice(0, 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
threshold := mlx.MulScalar(maxProb, d.MinP)
|
||||||
|
return sortedProbs.Less(threshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Penalty struct {
|
||||||
|
RepeatLastN int
|
||||||
|
RepeatPenalty float32
|
||||||
|
PresencePenalty float32
|
||||||
|
FrequencyPenalty float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Penalty) Sample(logprobs *mlx.Array, history []int32) *mlx.Array {
|
||||||
|
if len(history) == 0 {
|
||||||
|
return logprobs
|
||||||
|
}
|
||||||
|
|
||||||
|
window := p.RepeatLastN
|
||||||
|
if window <= 0 || window > len(history) {
|
||||||
|
window = len(history)
|
||||||
|
}
|
||||||
|
|
||||||
|
counts := make(map[int32]int, window)
|
||||||
|
order := make([]int32, 0, window)
|
||||||
|
for _, token := range history[len(history)-window:] {
|
||||||
|
if token < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if counts[token] == 0 {
|
||||||
|
order = append(order, token)
|
||||||
|
}
|
||||||
|
counts[token]++
|
||||||
|
}
|
||||||
|
if len(order) == 0 {
|
||||||
|
return logprobs
|
||||||
|
}
|
||||||
|
|
||||||
|
indexShape := []int32{int32(len(order))}
|
||||||
|
valueShape := []int{len(order)}
|
||||||
|
if logprobs.NumDims() > 1 {
|
||||||
|
indexShape = []int32{1, int32(len(order))}
|
||||||
|
valueShape = []int{1, len(order)}
|
||||||
|
}
|
||||||
|
|
||||||
|
indices := mlx.NewArrayInt32(order, indexShape)
|
||||||
|
selected := logprobs.TakeAlongAxis(indices, -1)
|
||||||
|
mlx.Eval(selected)
|
||||||
|
|
||||||
|
values := selected.Floats()
|
||||||
|
for i, token := range order {
|
||||||
|
v := values[i]
|
||||||
|
if p.RepeatPenalty != 1 {
|
||||||
|
if v < 0 {
|
||||||
|
v *= p.RepeatPenalty
|
||||||
|
} else {
|
||||||
|
v /= p.RepeatPenalty
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if p.PresencePenalty != 0 {
|
||||||
|
v -= p.PresencePenalty
|
||||||
|
}
|
||||||
|
if p.FrequencyPenalty != 0 {
|
||||||
|
v -= p.FrequencyPenalty * float32(counts[token])
|
||||||
|
}
|
||||||
|
values[i] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return logprobs.PutAlongAxis(indices, mlx.FromValues(values, valueShape...), -1)
|
||||||
}
|
}
|
||||||
|
|||||||
104
x/mlxrunner/sample/sample_test.go
Normal file
104
x/mlxrunner/sample/sample_test.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPenaltySample(t *testing.T) {
|
||||||
|
if err := mlx.CheckInit(); err != nil {
|
||||||
|
t.Skipf("MLX not available: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logprobs := mlx.FromValues([]float32{
|
||||||
|
1.0, -2.0, 3.0, 4.0,
|
||||||
|
}, 1, 4)
|
||||||
|
|
||||||
|
got := Penalty{
|
||||||
|
RepeatLastN: 3,
|
||||||
|
RepeatPenalty: 2.0,
|
||||||
|
PresencePenalty: 1.5,
|
||||||
|
FrequencyPenalty: 0.25,
|
||||||
|
}.Sample(logprobs, []int32{2, 1, 2})
|
||||||
|
|
||||||
|
mlx.Eval(got)
|
||||||
|
|
||||||
|
want := []float32{1.0, -5.75, -0.5, 4.0}
|
||||||
|
values := got.Floats()
|
||||||
|
if len(values) != len(want) {
|
||||||
|
t.Fatalf("len(values) = %d, want %d", len(values), len(want))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range want {
|
||||||
|
if math.Abs(float64(values[i]-want[i])) > 1e-5 {
|
||||||
|
t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPenaltySampleHonorsRepeatWindow(t *testing.T) {
|
||||||
|
if err := mlx.CheckInit(); err != nil {
|
||||||
|
t.Skipf("MLX not available: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logprobs := mlx.FromValues([]float32{
|
||||||
|
1.0, 2.0, 3.0,
|
||||||
|
}, 1, 3)
|
||||||
|
|
||||||
|
got := Penalty{
|
||||||
|
RepeatLastN: 1,
|
||||||
|
PresencePenalty: 1.0,
|
||||||
|
}.Sample(logprobs, []int32{0, 1})
|
||||||
|
|
||||||
|
mlx.Eval(got)
|
||||||
|
|
||||||
|
want := []float32{1.0, 1.0, 3.0}
|
||||||
|
values := got.Floats()
|
||||||
|
for i := range want {
|
||||||
|
if math.Abs(float64(values[i]-want[i])) > 1e-5 {
|
||||||
|
t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDistributionFilterTopP(t *testing.T) {
|
||||||
|
if err := mlx.CheckInit(); err != nil {
|
||||||
|
t.Skipf("MLX not available: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logits := mlx.FromValues([]float32{
|
||||||
|
10.0, 9.0, 1.0, 0.0,
|
||||||
|
}, 1, 4)
|
||||||
|
|
||||||
|
filtered, indices := Distribution{
|
||||||
|
Temperature: 1.0,
|
||||||
|
TopK: 2,
|
||||||
|
TopP: 0.55,
|
||||||
|
}.filter(logits)
|
||||||
|
|
||||||
|
got := materializeFilteredLogits(filtered, indices, 4)
|
||||||
|
mlx.Eval(got)
|
||||||
|
|
||||||
|
values := got.Floats()
|
||||||
|
if values[0] != 10.0 {
|
||||||
|
t.Fatalf("values[0] = %v, want 10", values[0])
|
||||||
|
}
|
||||||
|
for i := 1; i < len(values); i++ {
|
||||||
|
if !math.IsInf(float64(values[i]), -1) {
|
||||||
|
t.Fatalf("values[%d] = %v, want -Inf", i, values[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func materializeFilteredLogits(filtered, indices *mlx.Array, width int) *mlx.Array {
|
||||||
|
if indices == nil {
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
base := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, width), float32(math.Inf(-1)))
|
||||||
|
return base.PutAlongAxis(indices, filtered, -1)
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ package mlxrunner
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -15,12 +16,89 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/sample"
|
"github.com/ollama/ollama/x/mlxrunner/sample"
|
||||||
|
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type samplingConfig struct {
|
||||||
|
temperature float32
|
||||||
|
topP float32
|
||||||
|
minP float32
|
||||||
|
topK int
|
||||||
|
repeatLastN int
|
||||||
|
repeatPenalty float32
|
||||||
|
presencePenalty float32
|
||||||
|
frequencyPenalty float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultSamplingConfig(m base.Model, think *bool) samplingConfig {
|
||||||
|
if _, ok := m.(*qwen3_5.Model); ok {
|
||||||
|
cfg := samplingConfig{
|
||||||
|
temperature: 1.0,
|
||||||
|
topP: 0.95,
|
||||||
|
minP: 0.0,
|
||||||
|
topK: 20,
|
||||||
|
repeatLastN: 64,
|
||||||
|
repeatPenalty: 1.0,
|
||||||
|
presencePenalty: 1.5,
|
||||||
|
frequencyPenalty: 0.0,
|
||||||
|
}
|
||||||
|
if think != nil && !*think {
|
||||||
|
cfg.temperature = 0.7
|
||||||
|
cfg.topP = 0.8
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
return samplingConfig{
|
||||||
|
temperature: opts.Temperature,
|
||||||
|
topP: opts.TopP,
|
||||||
|
minP: opts.MinP,
|
||||||
|
topK: opts.TopK,
|
||||||
|
repeatLastN: opts.RepeatLastN,
|
||||||
|
repeatPenalty: opts.RepeatPenalty,
|
||||||
|
presencePenalty: opts.PresencePenalty,
|
||||||
|
frequencyPenalty: opts.FrequencyPenalty,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveSamplingConfig(m base.Model, req Request) samplingConfig {
|
||||||
|
cfg := defaultSamplingConfig(m, req.Think)
|
||||||
|
|
||||||
|
if req.Options.Temperature != nil {
|
||||||
|
cfg.temperature = *req.Options.Temperature
|
||||||
|
}
|
||||||
|
if req.Options.TopP != nil {
|
||||||
|
cfg.topP = *req.Options.TopP
|
||||||
|
}
|
||||||
|
if req.Options.MinP != nil {
|
||||||
|
cfg.minP = *req.Options.MinP
|
||||||
|
}
|
||||||
|
if req.Options.TopK != nil {
|
||||||
|
cfg.topK = *req.Options.TopK
|
||||||
|
}
|
||||||
|
if req.Options.RepeatLastN != nil {
|
||||||
|
cfg.repeatLastN = *req.Options.RepeatLastN
|
||||||
|
}
|
||||||
|
if req.Options.RepeatPenalty != nil {
|
||||||
|
cfg.repeatPenalty = *req.Options.RepeatPenalty
|
||||||
|
}
|
||||||
|
if req.Options.PresencePenalty != nil {
|
||||||
|
cfg.presencePenalty = *req.Options.PresencePenalty
|
||||||
|
}
|
||||||
|
if req.Options.FrequencyPenalty != nil {
|
||||||
|
cfg.frequencyPenalty = *req.Options.FrequencyPenalty
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
func Execute(args []string) error {
|
func Execute(args []string) error {
|
||||||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||||
|
|
||||||
@@ -49,9 +127,11 @@ func Execute(args []string) error {
|
|||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
if err := json.NewEncoder(w).Encode(statusResponse{
|
||||||
"status": 0,
|
Status: 0,
|
||||||
"progress": 100,
|
Progress: 100,
|
||||||
|
ContextLength: runner.contextLength,
|
||||||
|
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
slog.Error("Failed to encode response", "error", err)
|
slog.Error("Failed to encode response", "error", err)
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
@@ -77,7 +157,7 @@ func Execute(args []string) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
request := Request{Responses: make(chan Response)}
|
request := Request{Responses: make(chan CompletionResponse)}
|
||||||
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||||
slog.Error("Failed to decode request", "error", err)
|
slog.Error("Failed to decode request", "error", err)
|
||||||
@@ -86,31 +166,51 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
||||||
if request.Options.MaxTokens < 1 {
|
|
||||||
request.Options.MaxTokens = 16 << 10
|
sampling := resolveSamplingConfig(runner.Model, request)
|
||||||
}
|
|
||||||
|
|
||||||
request.Pipeline = runner.TextGenerationPipeline
|
request.Pipeline = runner.TextGenerationPipeline
|
||||||
request.Sampler = sample.New(
|
request.Sampler = sample.New(
|
||||||
request.Options.Temperature,
|
sampling.temperature,
|
||||||
request.Options.TopP,
|
sampling.topP,
|
||||||
request.Options.MinP,
|
sampling.minP,
|
||||||
request.Options.TopK,
|
sampling.topK,
|
||||||
|
sampling.repeatLastN,
|
||||||
|
sampling.repeatPenalty,
|
||||||
|
sampling.presencePenalty,
|
||||||
|
sampling.frequencyPenalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
runner.Requests <- request
|
var cancel context.CancelFunc
|
||||||
|
request.Ctx, cancel = context.WithCancel(r.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case runner.Requests <- request:
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/jsonl")
|
w.Header().Set("Content-Type", "application/jsonl")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
enc := json.NewEncoder(w)
|
enc := json.NewEncoder(w)
|
||||||
for response := range request.Responses {
|
for {
|
||||||
if err := enc.Encode(response); err != nil {
|
select {
|
||||||
slog.Error("Failed to encode response", "error", err)
|
case <-r.Context().Done():
|
||||||
return
|
return
|
||||||
}
|
case response, ok := <-request.Responses:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if f, ok := w.(http.Flusher); ok {
|
if err := enc.Encode(response); err != nil {
|
||||||
f.Flush()
|
slog.Error("Failed to encode response", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if f, ok := w.(http.Flusher); ok {
|
||||||
|
f.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
172
x/mlxrunner/server_test.go
Normal file
172
x/mlxrunner/server_test.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package mlxrunner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
|
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||||
|
"github.com/ollama/ollama/x/tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stubModel struct{}
|
||||||
|
|
||||||
|
func (stubModel) Forward(*mlx.Array, []cache.Cache) *mlx.Array { return nil }
|
||||||
|
func (stubModel) Unembed(*mlx.Array) *mlx.Array { return nil }
|
||||||
|
func (stubModel) NumLayers() int { return 0 }
|
||||||
|
func (stubModel) Tokenizer() *tokenizer.Tokenizer { return nil }
|
||||||
|
func (stubModel) LoadWeights(map[string]*mlx.Array) error { return nil }
|
||||||
|
|
||||||
|
func TestResolveSamplingConfigDefaults(t *testing.T) {
|
||||||
|
trueValue := true
|
||||||
|
falseValue := false
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model base.Model
|
||||||
|
req Request
|
||||||
|
want samplingConfig
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "generic model uses api defaults",
|
||||||
|
model: stubModel{},
|
||||||
|
req: Request{},
|
||||||
|
want: samplingConfig{
|
||||||
|
temperature: 0.8,
|
||||||
|
topP: 0.9,
|
||||||
|
minP: 0.0,
|
||||||
|
topK: 40,
|
||||||
|
repeatLastN: 64,
|
||||||
|
repeatPenalty: 1.1,
|
||||||
|
presencePenalty: 0.0,
|
||||||
|
frequencyPenalty: 0.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3.5 defaults to thinking profile when think unset",
|
||||||
|
model: &qwen3_5.Model{},
|
||||||
|
req: Request{},
|
||||||
|
want: samplingConfig{
|
||||||
|
temperature: 1.0,
|
||||||
|
topP: 0.95,
|
||||||
|
minP: 0.0,
|
||||||
|
topK: 20,
|
||||||
|
repeatLastN: 64,
|
||||||
|
repeatPenalty: 1.0,
|
||||||
|
presencePenalty: 1.5,
|
||||||
|
frequencyPenalty: 0.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3.5 thinking disabled defaults",
|
||||||
|
model: &qwen3_5.Model{},
|
||||||
|
req: Request{TextCompletionsRequest: TextCompletionsRequest{Think: &falseValue}},
|
||||||
|
want: samplingConfig{
|
||||||
|
temperature: 0.7,
|
||||||
|
topP: 0.8,
|
||||||
|
minP: 0.0,
|
||||||
|
topK: 20,
|
||||||
|
repeatLastN: 64,
|
||||||
|
repeatPenalty: 1.0,
|
||||||
|
presencePenalty: 1.5,
|
||||||
|
frequencyPenalty: 0.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3.5 thinking enabled defaults",
|
||||||
|
model: &qwen3_5.Model{},
|
||||||
|
req: Request{TextCompletionsRequest: TextCompletionsRequest{Think: &trueValue}},
|
||||||
|
want: samplingConfig{
|
||||||
|
temperature: 1.0,
|
||||||
|
topP: 0.95,
|
||||||
|
minP: 0.0,
|
||||||
|
topK: 20,
|
||||||
|
repeatLastN: 64,
|
||||||
|
repeatPenalty: 1.0,
|
||||||
|
presencePenalty: 1.5,
|
||||||
|
frequencyPenalty: 0.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := resolveSamplingConfig(tt.model, tt.req); got != tt.want {
|
||||||
|
t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveSamplingConfigOverridesSpecifiedValues(t *testing.T) {
|
||||||
|
trueValue := true
|
||||||
|
temperature := float32(0.4)
|
||||||
|
topP := float32(0.6)
|
||||||
|
minP := float32(0.05)
|
||||||
|
topK := 12
|
||||||
|
repeatLastN := 32
|
||||||
|
repeatPenalty := float32(1.1)
|
||||||
|
presencePenalty := float32(0.7)
|
||||||
|
frequencyPenalty := float32(0.2)
|
||||||
|
|
||||||
|
got := resolveSamplingConfig(stubModel{}, Request{
|
||||||
|
TextCompletionsRequest: TextCompletionsRequest{
|
||||||
|
Think: &trueValue,
|
||||||
|
Options: struct {
|
||||||
|
Temperature *float32 `json:"temperature"`
|
||||||
|
TopP *float32 `json:"top_p"`
|
||||||
|
MinP *float32 `json:"min_p"`
|
||||||
|
TopK *int `json:"top_k"`
|
||||||
|
RepeatLastN *int `json:"repeat_last_n"`
|
||||||
|
RepeatPenalty *float32 `json:"repeat_penalty"`
|
||||||
|
PresencePenalty *float32 `json:"presence_penalty"`
|
||||||
|
FrequencyPenalty *float32 `json:"frequency_penalty"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
NumPredict int `json:"num_predict"`
|
||||||
|
}{
|
||||||
|
Temperature: &temperature,
|
||||||
|
TopP: &topP,
|
||||||
|
MinP: &minP,
|
||||||
|
TopK: &topK,
|
||||||
|
RepeatLastN: &repeatLastN,
|
||||||
|
RepeatPenalty: &repeatPenalty,
|
||||||
|
PresencePenalty: &presencePenalty,
|
||||||
|
FrequencyPenalty: &frequencyPenalty,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
want := samplingConfig{
|
||||||
|
temperature: temperature,
|
||||||
|
topP: topP,
|
||||||
|
minP: minP,
|
||||||
|
topK: topK,
|
||||||
|
repeatLastN: repeatLastN,
|
||||||
|
repeatPenalty: repeatPenalty,
|
||||||
|
presencePenalty: presencePenalty,
|
||||||
|
frequencyPenalty: frequencyPenalty,
|
||||||
|
}
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveSamplingConfigMatchesGenericDefaults(t *testing.T) {
|
||||||
|
want := api.DefaultOptions()
|
||||||
|
got := defaultSamplingConfig(stubModel{}, nil)
|
||||||
|
|
||||||
|
if got.temperature != want.Temperature ||
|
||||||
|
got.topP != want.TopP ||
|
||||||
|
got.minP != want.MinP ||
|
||||||
|
got.topK != want.TopK ||
|
||||||
|
got.repeatLastN != want.RepeatLastN ||
|
||||||
|
got.repeatPenalty != want.RepeatPenalty ||
|
||||||
|
got.presencePenalty != want.PresencePenalty ||
|
||||||
|
got.frequencyPenalty != want.FrequencyPenalty {
|
||||||
|
t.Fatalf("defaultSamplingConfig() = %+v, want api defaults %+v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
|||||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||||
|
|
||||||
// MaxContextLength returns the maximum context length
|
// MaxContextLength returns the maximum context length
|
||||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
|
||||||
|
|
||||||
// VocabSize returns the vocabulary size
|
// VocabSize returns the vocabulary size
|
||||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||||
|
|||||||
@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,40 @@ type LinearLayer interface {
|
|||||||
OutputDim() int32
|
OutputDim() int32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Conv1d applies 1D convolution over NLC input.
|
||||||
|
type Conv1d struct {
|
||||||
|
Weight *mlx.Array
|
||||||
|
Bias *mlx.Array
|
||||||
|
Stride int32
|
||||||
|
Padding int32
|
||||||
|
Dilation int32
|
||||||
|
Groups int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
|
||||||
|
if stride <= 0 {
|
||||||
|
stride = 1
|
||||||
|
}
|
||||||
|
if dilation <= 0 {
|
||||||
|
dilation = 1
|
||||||
|
}
|
||||||
|
if groups <= 0 {
|
||||||
|
groups = 1
|
||||||
|
}
|
||||||
|
return &Conv1d{
|
||||||
|
Weight: weight,
|
||||||
|
Bias: bias,
|
||||||
|
Stride: stride,
|
||||||
|
Padding: padding,
|
||||||
|
Dilation: dilation,
|
||||||
|
Groups: groups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
|
||||||
|
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
|
||||||
|
}
|
||||||
|
|
||||||
// Linear applies an affine transformation: y = x @ W.T + b
|
// Linear applies an affine transformation: y = x @ W.T + b
|
||||||
type Linear struct {
|
type Linear struct {
|
||||||
Weight *mlx.Array
|
Weight *mlx.Array
|
||||||
|
|||||||
@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
1457
x/models/qwen3_5/qwen3_5.go
Normal file
1457
x/models/qwen3_5/qwen3_5.go
Normal file
File diff suppressed because it is too large
Load Diff
166
x/models/qwen3_5/qwen3_5_test.go
Normal file
166
x/models/qwen3_5/qwen3_5_test.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package qwen3_5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseConfigNestedDefaults(t *testing.T) {
|
||||||
|
data := []byte(`{
|
||||||
|
"model_type": "Qwen3_5MoeForConditionalGeneration",
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 4096,
|
||||||
|
"intermediate_size": 14336,
|
||||||
|
"num_hidden_layers": 8,
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"head_dim": 128,
|
||||||
|
"linear_num_value_heads": 64,
|
||||||
|
"linear_num_key_heads": 16,
|
||||||
|
"linear_key_head_dim": 128,
|
||||||
|
"linear_value_head_dim": 128,
|
||||||
|
"linear_conv_kernel_dim": 4,
|
||||||
|
"num_experts": 16,
|
||||||
|
"num_experts_per_tok": 4,
|
||||||
|
"moe_intermediate_size": 2048,
|
||||||
|
"shared_expert_intermediate_size": 4096,
|
||||||
|
"rope_parameters": {
|
||||||
|
"rope_theta": 500000,
|
||||||
|
"partial_rotary_factor": 0.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.RopeTheta != 500000 {
|
||||||
|
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
|
||||||
|
}
|
||||||
|
if cfg.RopeDim != 64 {
|
||||||
|
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||||
|
}
|
||||||
|
if cfg.FullAttentionInterval != 4 {
|
||||||
|
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
|
||||||
|
}
|
||||||
|
if !cfg.NormTopKProb {
|
||||||
|
t.Fatalf("norm_topk_prob should default to true for MoE")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLayerSelectionHelpers(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
NumHiddenLayers: 6,
|
||||||
|
FullAttentionInterval: 3,
|
||||||
|
NumExperts: 8,
|
||||||
|
DecoderSparseStep: 2,
|
||||||
|
MLPOnlyLayers: []int32{1},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !layerIsLinear(cfg, 0) {
|
||||||
|
t.Fatalf("layer 0 should be linear")
|
||||||
|
}
|
||||||
|
if layerIsLinear(cfg, 2) {
|
||||||
|
t.Fatalf("layer 2 should be full attention")
|
||||||
|
}
|
||||||
|
|
||||||
|
if layerUsesMoE(cfg, 1) {
|
||||||
|
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
|
||||||
|
}
|
||||||
|
if !layerUsesMoE(cfg, 3) {
|
||||||
|
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveTensorPathLayout(t *testing.T) {
|
||||||
|
dummy := mlx.New("dummy")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
wantContainer string
|
||||||
|
wantModel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "standard",
|
||||||
|
key: "model.embed_tokens.weight",
|
||||||
|
wantContainer: "",
|
||||||
|
wantModel: "model.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested language model with inner model",
|
||||||
|
key: "model.language_model.model.embed_tokens.weight",
|
||||||
|
wantContainer: "model.language_model.",
|
||||||
|
wantModel: "model.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested language model without inner model",
|
||||||
|
key: "model.language_model.embed_tokens.weight",
|
||||||
|
wantContainer: "model.language_model.",
|
||||||
|
wantModel: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
layout := resolveTensorPathLayout(map[string]*mlx.Array{
|
||||||
|
tt.key: dummy,
|
||||||
|
})
|
||||||
|
|
||||||
|
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
|
||||||
|
t.Fatalf(
|
||||||
|
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
|
||||||
|
layout.containerPrefix,
|
||||||
|
layout.modelPrefix,
|
||||||
|
tt.wantContainer,
|
||||||
|
tt.wantModel,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRuntimeDefaults(t *testing.T) {
|
||||||
|
m := &Model{}
|
||||||
|
if m.DisablePromptCache() {
|
||||||
|
t.Fatal("DisablePromptCache() = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCachesLayout(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Config: &Config{
|
||||||
|
LinearConvKernelDim: 4,
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearKeyHeadDim: 8,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 16,
|
||||||
|
},
|
||||||
|
Layers: []*Layer{
|
||||||
|
{IsLinear: true},
|
||||||
|
{IsLinear: false},
|
||||||
|
{IsLinear: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
caches := m.NewCaches()
|
||||||
|
if len(caches) != len(m.Layers) {
|
||||||
|
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
|
||||||
|
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
|
||||||
|
}
|
||||||
|
if _, ok := caches[1].(*cache.KVCache); !ok {
|
||||||
|
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
|
||||||
|
}
|
||||||
|
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
|
||||||
|
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
|
||||||
|
}
|
||||||
|
}
|
||||||
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
|
||||||
|
package qwen3_5_moe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
|
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
|
||||||
|
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
|
||||||
|
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
|
||||||
|
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user