diff --git a/x/create/client/create.go b/x/create/client/create.go index 8a3beba64..00193a8c1 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -134,14 +134,18 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error { spinnerKey = "create" capabilities = []string{"completion"} - // Check if model supports thinking based on architecture - if supportsThinking(opts.ModelDir) { + configData, _ := os.ReadFile(filepath.Join(opts.ModelDir, "config.json")) + mcfg := parseModelConfig(configData) + + if mcfg.supportsThinking() { capabilities = append(capabilities, "thinking") } + if mcfg.supportsVision() { + capabilities = append(capabilities, "vision") + } - // Set parser and renderer name based on architecture - parserName = getParserName(opts.ModelDir) - rendererName = getRendererName(opts.ModelDir) + parserName = mcfg.parserName() + rendererName = mcfg.rendererName() } else { modelType = "image generation model" spinnerKey = "imagegen" @@ -438,145 +442,76 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) { return layers, nil } -// supportsThinking checks if the model supports thinking mode based on its architecture. -// This reads the config.json from the model directory and checks the architectures field. -func supportsThinking(modelDir string) bool { - configPath := filepath.Join(modelDir, "config.json") - data, err := os.ReadFile(configPath) - if err != nil { - return false - } +// modelConfig holds the fields from config.json needed during model creation. 
+type visionConfig struct { + Depth int32 `json:"depth"` +} - var cfg struct { - Architectures []string `json:"architectures"` - ModelType string `json:"model_type"` - } - if err := json.Unmarshal(data, &cfg); err != nil { - return false - } +type modelConfig struct { + Architectures []string `json:"architectures"` + ModelType string `json:"model_type"` + VisionConfig *visionConfig `json:"vision_config"` + ImageTokenID *int32 `json:"image_token_id"` + VisionStartTokenID *int32 `json:"vision_start_token_id"` + VisionEndTokenID *int32 `json:"vision_end_token_id"` +} - // Check architectures that support thinking - thinkingArchitectures := []string{ - "glm4moe", // GLM-4 MoE models - "deepseek", // DeepSeek models - "qwen3", // Qwen3 models - } +func parseModelConfig(data []byte) modelConfig { + var cfg modelConfig + _ = json.Unmarshal(data, &cfg) + return cfg +} - // Check the architecture list - for _, arch := range cfg.Architectures { +// archOrTypeContains returns true if any architecture or the model_type +// contains one of the given substrings (case-insensitive). +func (c *modelConfig) archOrTypeContains(substrs ...string) bool { + for _, arch := range c.Architectures { archLower := strings.ToLower(arch) - for _, thinkArch := range thinkingArchitectures { - if strings.Contains(archLower, thinkArch) { + for _, s := range substrs { + if strings.Contains(archLower, s) { return true } } } - - // Also check model_type - if cfg.ModelType != "" { - typeLower := strings.ToLower(cfg.ModelType) - for _, thinkArch := range thinkingArchitectures { - if strings.Contains(typeLower, thinkArch) { + if c.ModelType != "" { + typeLower := strings.ToLower(c.ModelType) + for _, s := range substrs { + if strings.Contains(typeLower, s) { return true } } } - return false } -// getParserName returns the parser name for a model based on its architecture. -// This reads the config.json from the model directory and determines the appropriate parser. 
-func getParserName(modelDir string) string { - configPath := filepath.Join(modelDir, "config.json") - data, err := os.ReadFile(configPath) - if err != nil { - return "" - } +func (c *modelConfig) supportsThinking() bool { + return c.archOrTypeContains("glm4moe", "deepseek", "qwen3") +} - var cfg struct { - Architectures []string `json:"architectures"` - ModelType string `json:"model_type"` - } - if err := json.Unmarshal(data, &cfg); err != nil { - return "" - } +func (c *modelConfig) supportsVision() bool { + return c.VisionConfig != nil || c.ImageTokenID != nil || c.VisionStartTokenID != nil || c.VisionEndTokenID != nil +} - // Check architectures for known parsers - for _, arch := range cfg.Architectures { - archLower := strings.ToLower(arch) - if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") { - return "glm-4.7" - } - if strings.Contains(archLower, "deepseek") { - return "deepseek3" - } - if strings.Contains(archLower, "qwen3") { - return "qwen3" - } +func (c *modelConfig) parserName() string { + switch { + case c.archOrTypeContains("glm4", "glm-4"): + return "glm-4.7" + case c.archOrTypeContains("deepseek"): + return "deepseek3" + case c.archOrTypeContains("qwen3"): + return "qwen3" } - - // Also check model_type - if cfg.ModelType != "" { - typeLower := strings.ToLower(cfg.ModelType) - if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") { - return "glm-4.7" - } - if strings.Contains(typeLower, "deepseek") { - return "deepseek3" - } - if strings.Contains(typeLower, "qwen3") { - return "qwen3" - } - } - return "" } -// getRendererName returns the renderer name for a model based on its architecture. -// This reads the config.json from the model directory and determines the appropriate renderer. 
-func getRendererName(modelDir string) string { - configPath := filepath.Join(modelDir, "config.json") - data, err := os.ReadFile(configPath) - if err != nil { - return "" +func (c *modelConfig) rendererName() string { + switch { + case c.archOrTypeContains("glm4", "glm-4"): + return "glm-4.7" + case c.archOrTypeContains("deepseek"): + return "deepseek3" + case c.archOrTypeContains("qwen3"): + return "qwen3-coder" } - - var cfg struct { - Architectures []string `json:"architectures"` - ModelType string `json:"model_type"` - } - if err := json.Unmarshal(data, &cfg); err != nil { - return "" - } - - // Check architectures for known renderers - for _, arch := range cfg.Architectures { - archLower := strings.ToLower(arch) - if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") { - return "glm-4.7" - } - if strings.Contains(archLower, "deepseek") { - return "deepseek3" - } - if strings.Contains(archLower, "qwen3") { - return "qwen3-coder" - } - } - - // Also check model_type - if cfg.ModelType != "" { - typeLower := strings.ToLower(cfg.ModelType) - if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") { - return "glm-4.7" - } - if strings.Contains(typeLower, "deepseek") { - return "deepseek3" - } - if strings.Contains(typeLower, "qwen3") { - return "qwen3-coder" - } - } - return "" } diff --git a/x/create/client/create_test.go b/x/create/client/create_test.go index d901e1e19..bf10aa11a 100644 --- a/x/create/client/create_test.go +++ b/x/create/client/create_test.go @@ -339,3 +339,34 @@ func TestCreateModelfileLayersIncludesParameters(t *testing.T) { t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7)) } } + +func TestSupportsVision(t *testing.T) { + t.Run("vision_config present", func(t *testing.T) { + cfg := parseModelConfig([]byte(`{ + "vision_config": {"depth": 2}, + "image_token_id": 151655 + }`)) + if !cfg.supportsVision() { + t.Fatal("supportsVision() = false, want true") + } + }) + + 
t.Run("token ids alone imply vision", func(t *testing.T) { + cfg := parseModelConfig([]byte(`{ + "vision_start_token_id": 10, + "vision_end_token_id": 11 + }`)) + if !cfg.supportsVision() { + t.Fatal("supportsVision() = false, want true") + } + }) + + t.Run("plain text model", func(t *testing.T) { + cfg := parseModelConfig([]byte(`{ + "architectures": ["Qwen3_5ForCausalLM"] + }`)) + if cfg.supportsVision() { + t.Fatal("supportsVision() = true, want false") + } + }) +} diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index a5709101d..4f8fdeb5b 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -366,6 +366,23 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) { c.enforceEvictionPolicy() } +// clear releases live caches and drops the trie so future requests cannot +// reuse prompt state keyed only by token IDs. +func (c *kvCache) clear() { + c.freeAll() + walkNodes(c.root, func(n *trieNode) bool { + for _, s := range n.snapshots { + if s != nil { + s.Close() + } + } + n.snapshots = nil + return true + }) + c.root = nil + c.activePath = nil +} + // freeAll releases all cache layers. func (c *kvCache) freeAll() { for _, kv := range c.caches { diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index 5eff1bd02..17ef4490a 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -106,6 +106,7 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error { // completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization. 
type completionRequest struct { Prompt string `json:"prompt"` + Images []llm.ImageData `json:"images,omitempty"` Options *completionOpts `json:"options,omitempty"` } @@ -155,6 +156,7 @@ func (c *Client) Close() error { func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { creq := completionRequest{ Prompt: req.Prompt, + Images: req.Images, } if req.Options != nil { creq.Options = &completionOpts{ diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index ff06092e9..84a9fe85d 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -304,6 +304,18 @@ func Exp(a *Array) *Array { return out } +func Sin(a *Array) *Array { + out := New("SIN") + C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +func Cos(a *Array) *Array { + out := New("COS") + C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + func Log(a *Array) *Array { out := New("LOG") C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx) diff --git a/x/mlxrunner/model/base/multimodal.go b/x/mlxrunner/model/base/multimodal.go new file mode 100644 index 000000000..91d7fd3ba --- /dev/null +++ b/x/mlxrunner/model/base/multimodal.go @@ -0,0 +1,32 @@ +package base + +import ( + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +// ImageInput is a single image attached to a prompt. +type ImageInput struct { + ID int + Data []byte +} + +// PromptTokenization contains tokenized prompt IDs plus optional request-scoped +// model metadata needed during forward. +type PromptTokenization struct { + Tokens []int32 + State any +} + +// MultimodalPromptTokenizerWithState is an optional model interface used by +// mlxrunner to expand tagged multimodal prompts into token IDs, returning +// request-scoped state to be attached to the forward pass. 
+type MultimodalPromptTokenizerWithState interface { + TokenizePromptWithImagesState(prompt string, images []ImageInput) (*PromptTokenization, error) +} + +// ForwardWithStateModel is an optional model interface for request-scoped +// forward metadata that should not be stored in shared caches. +type ForwardWithStateModel interface { + ForwardWithState(inputs *mlx.Array, cache []cache.Cache, state any) *mlx.Array +} diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index ea7e12a30..32ff1d1db 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -12,12 +12,42 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model/base" ) func prefillChunkSize() int { return 2 << 10 } +func (r *Runner) tokenizeRequest(request Request) ([]int32, any, error) { + if len(request.Images) > 0 { + // The shared trie cache keys only on token IDs today, so multimodal + // prompts must not reuse snapshots across distinct image inputs. 
+ r.cache.clear() + } + + if multimodalTokenizer, ok := r.Model.(base.MultimodalPromptTokenizerWithState); ok && len(request.Images) > 0 { + images := make([]base.ImageInput, len(request.Images)) + for i := range request.Images { + images[i] = base.ImageInput{ + ID: request.Images[i].ID, + Data: request.Images[i].Data, + } + } + + out, err := multimodalTokenizer.TokenizePromptWithImagesState(request.Prompt, images) + if err != nil { + return nil, nil, err + } + if out == nil { + return nil, nil, errors.New("empty multimodal tokenization result") + } + return out.Tokens, out.State, nil + } + + return r.Tokenizer.Encode(request.Prompt, true), nil, nil +} + func (r *Runner) TextGenerationPipeline(request Request) error { if r.Model == nil { return errors.New("model not loaded") @@ -55,7 +85,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error { slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory())) }() - inputs := r.Tokenizer.Encode(request.Prompt, true) + inputs, promptState, err := r.tokenizeRequest(request) + if err != nil { + return err + } if len(inputs) == 0 { return errors.New("empty prompt") } @@ -83,6 +116,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error { tokens := session.remaining prefillChunk := prefillChunkSize() + modelForward := func(tokens *mlx.Array) *mlx.Array { + if withState, ok := r.Model.(base.ForwardWithStateModel); ok { + return withState.ForwardWithState(tokens, caches, promptState) + } + return r.Model.Forward(tokens, caches) + } + materializeCaches := func() { state := make([]*mlx.Array, 0, 2*len(caches)) for _, c := range caches { @@ -114,7 +154,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } } - r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) + modelForward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0)) mlx.Sweep() materializeCaches() processed += n @@ -132,7 +172,7 @@ func (r *Runner) 
TextGenerationPipeline(request Request) error { } step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) { - fwd := r.Model.Forward(token.ExpandDims(0), caches) + fwd := modelForward(token.ExpandDims(0)) logits := r.Model.Unembed(fwd) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 08a376d43..77a933e89 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -11,6 +11,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model/base" @@ -29,7 +30,8 @@ type Request struct { } type TextCompletionsRequest struct { - Prompt string `json:"prompt"` + Prompt string `json:"prompt"` + Images []llm.ImageData `json:"images,omitempty"` Options struct { Temperature float32 `json:"temperature"` TopP float32 `json:"top_p"` diff --git a/x/models/qwen3_5/multimodal.go b/x/models/qwen3_5/multimodal.go new file mode 100644 index 000000000..b90514a16 --- /dev/null +++ b/x/models/qwen3_5/multimodal.go @@ -0,0 +1,354 @@ +package qwen3_5 + +import ( + "fmt" + "math" + "regexp" + "strconv" + + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model/base" +) + +var imageTagRE = regexp.MustCompile(`\[img-(\d+)\]`) + +type promptVisionSpan struct { + Start int32 + End int32 + + Main *mlx.Array + Grid *VisionGrid +} + +type promptVisionState struct { + Spans []promptVisionSpan + PositionCache []int32 +} + +func promptStartPosFromCaches(caches []cache.Cache) int32 { + offset := -1 + for _, c := range caches { + if c == nil { + continue + } + off := c.Offset() + if offset < 0 || off < offset { + offset = off + } + } + if offset < 0 { + return 0 + } + return int32(offset) +} + +func promptVisionStateFromState(state any) 
*promptVisionState { + typed, _ := state.(*promptVisionState) + return typed +} + +func overlapRange(chunkStart, chunkLen, spanStart, spanEnd int32) (int32, int32, int32, int32, bool) { + chunkEnd := chunkStart + chunkLen + overlapStart := max(chunkStart, spanStart) + overlapEnd := min(chunkEnd, spanEnd) + if overlapStart >= overlapEnd { + return 0, 0, 0, 0, false + } + + chunkLo := overlapStart - chunkStart + chunkHi := overlapEnd - chunkStart + spanLo := overlapStart - spanStart + spanHi := overlapEnd - spanStart + return chunkLo, chunkHi, spanLo, spanHi, true +} + +func (m *Model) applyPromptVisionEmbeddings(h *mlx.Array, startPos int32, state *promptVisionState) *mlx.Array { + if m == nil || h == nil || state == nil || len(state.Spans) == 0 { + return h + } + + dims := h.Dims() + if len(dims) != 3 { + return h + } + + L := int32(dims[1]) + for _, span := range state.Spans { + chunkLo, chunkHi, spanLo, spanHi, ok := overlapRange(startPos, L, span.Start, span.End) + if !ok || span.Main == nil || !span.Main.Valid() { + continue + } + + repl := span.Main.Slice( + mlx.Slice(), + mlx.Slice(int(spanLo), int(spanHi)), + mlx.Slice(), + ) + repl = repl.AsType(h.DType()) + h = h.SliceUpdate( + repl, + mlx.Slice(), + mlx.Slice(int(chunkLo), int(chunkHi)), + mlx.Slice(), + ) + } + + return h +} + +func findImageByID(images []base.ImageInput, id int) (base.ImageInput, bool) { + for i := range images { + if images[i].ID == id { + return images[i], true + } + } + return base.ImageInput{}, false +} + +func mapPromptPosition(state *promptVisionState, id int32) int32 { + if state == nil { + return id + } + if id < int32(len(state.PositionCache)) { + return state.PositionCache[id] + } + if len(state.PositionCache) > 0 { + return id - int32(len(state.PositionCache)) + state.PositionCache[len(state.PositionCache)-1] + 1 + } + return id +} + +func promptVisionGridSpan(grid *VisionGrid, merge int32, fallback int32) int32 { + if fallback <= 0 { + fallback = 1 + } + if grid == nil { + 
return fallback + } + if merge <= 0 { + merge = 1 + } + return max(max(int32(1), grid.Width/merge), max(int32(1), grid.Height/merge)) +} + +func normalizeMRoPESections(sections []int32) [4]int32 { + var out [4]int32 + for i := range min(4, len(sections)) { + if sections[i] > 0 { + out[i] = sections[i] + } + } + return out +} + +func mropePairComponent(pair int32, sections [4]int32, interleaved bool) int { + if interleaved { + if pair%3 == 1 && pair < 1+3*sections[1] { + return 1 + } + if pair%3 == 2 && pair < 2+3*sections[2] { + return 2 + } + if pair%3 == 0 && pair < 3*sections[0] { + return 0 + } + return 3 + } + + secW := sections[0] + sections[1] + secE := secW + sections[2] + switch { + case pair < sections[0]: + return 0 + case pair < secW: + return 1 + case pair < secE: + return 2 + default: + return 3 + } +} + +func (m *Model) buildPromptMRoPEPositions(state *promptVisionState, startPos, chunkLen int32) [4][]int32 { + var positions [4][]int32 + for i := range positions { + positions[i] = make([]int32, chunkLen) + } + + // positions[3] stays zero — it covers RoPE dims beyond the 3 MRoPE sections. 
+ for i := range chunkLen { + p := mapPromptPosition(state, startPos+i) + positions[0][i] = p + positions[1][i] = p + positions[2][i] = p + } + + merge := int32(1) + if m != nil && m.Config != nil && m.Config.Vision != nil { + merge = m.Config.Vision.SpatialMergeSize + } + for _, span := range state.Spans { + if span.Grid == nil { + continue + } + + chunkLo, chunkHi, spanLo, _, ok := overlapRange(startPos, chunkLen, span.Start, span.End) + if !ok { + continue + } + + w := max(int32(1), span.Grid.Width/merge) + for i := chunkLo; i < chunkHi; i++ { + rel := spanLo + (i - chunkLo) + positions[1][i] += rel / w + positions[2][i] += rel % w + } + } + + return positions +} + +func (m *Model) buildPromptMRoPECosSin(state *promptVisionState, startPos, chunkLen int32, dtype mlx.DType) (*mlx.Array, *mlx.Array) { + if m == nil || m.Config == nil || state == nil || chunkLen <= 0 || len(m.Config.MRoPESections) == 0 { + return nil, nil + } + + ropeDim := m.Config.RopeDim + if ropeDim%2 != 0 { + ropeDim-- + } + if ropeDim <= 0 { + return nil, nil + } + + half := ropeDim / 2 + positions := m.buildPromptMRoPEPositions(state, startPos, chunkLen) + sections := normalizeMRoPESections(m.Config.MRoPESections) + theta := m.Config.RopeTheta + if theta <= 0 { + theta = 100000.0 + } + + freqs := make([]float64, half) + for j := range half { + freqs[j] = math.Pow(float64(theta), -2.0*float64(j)/float64(ropeDim)) + } + + angles := make([]float32, chunkLen*ropeDim) + for i := range chunkLen { + base := i * ropeDim + for j := range half { + component := mropePairComponent(j, sections, m.Config.MRoPEInterleaved) + angle := float32(float64(positions[component][i]) * freqs[j]) + angles[base+j] = angle + angles[base+half+j] = angle + } + } + + arr := mlx.FromValues(angles, 1, 1, int(chunkLen), int(ropeDim)) + cos := mlx.Cos(arr) + sin := mlx.Sin(arr) + if dtype != 0 { + cos = cos.AsType(dtype) + sin = sin.AsType(dtype) + } + return cos, sin +} + +func (m *Model) tokenizePromptWithResolvedImages( + 
prompt string, + images []base.ImageInput, + resolve func([]byte) (*VisionEmbeddings, error), +) ([]int32, *promptVisionState, error) { + if m == nil || m.tok == nil { + return nil, nil, fmt.Errorf("qwen3_5: tokenizer not initialized") + } + + if m.Vision == nil || m.ImageProcessor == nil || resolve == nil { + return m.tok.Encode(prompt, true), nil, nil + } + + parts := imageTagRE.Split(prompt, -1) + matches := imageTagRE.FindAllStringSubmatch(prompt, -1) + + resolved := make(map[int]*VisionEmbeddings, len(images)) + var out []int32 + state := &promptVisionState{} + var p int32 + appendToken := func(tok, pos int32) { + out = append(out, tok) + state.PositionCache = append(state.PositionCache, pos) + } + for i, part := range parts { + for _, tok := range m.tok.Encode(part, i == 0) { + appendToken(tok, p) + p++ + } + + if i >= len(matches) { + continue + } + + imageID, err := strconv.Atoi(matches[i][1]) + if err != nil { + return nil, nil, fmt.Errorf("qwen3_5: invalid image tag %q: %w", matches[i][0], err) + } + + img, ok := findImageByID(images, imageID) + if !ok { + return nil, nil, fmt.Errorf("invalid image index: %d", imageID) + } + + embeds := resolved[imageID] + if embeds == nil { + embeds, err = resolve(img.Data) + if err != nil { + return nil, nil, err + } + resolved[imageID] = embeds + } + if embeds == nil || embeds.Main == nil || !embeds.Main.Valid() || embeds.Main.NumDims() < 2 { + return nil, nil, fmt.Errorf("qwen3_5: invalid vision embeddings") + } + + tokensPerImage := int32(embeds.Main.Dim(1)) + if tokensPerImage <= 0 { + return nil, nil, fmt.Errorf("qwen3_5: invalid image token count: %d", tokensPerImage) + } + + appendToken(m.VisionStartToken, p) + p++ + basePos := p + spanStart := int32(len(out)) + for range tokensPerImage { + appendToken(m.ImageTokenID, basePos) + } + spanEnd := int32(len(out)) + merge := int32(1) + if m.Config != nil && m.Config.Vision != nil { + merge = m.Config.Vision.SpatialMergeSize + } + gridSpan := 
promptVisionGridSpan(embeds.Grid, merge, tokensPerImage) + p += gridSpan + appendToken(m.VisionEndToken, p) + p++ + + state.Spans = append(state.Spans, promptVisionSpan{ + Start: spanStart, + End: spanEnd, + Main: embeds.Main, + Grid: embeds.Grid, + }) + } + + return out, state, nil +} + +func (m *Model) TokenizePromptWithImagesState(prompt string, images []base.ImageInput) (*base.PromptTokenization, error) { + tokens, state, err := m.tokenizePromptWithResolvedImages(prompt, images, m.EncodeVisionImage) + if err != nil { + return nil, err + } + return &base.PromptTokenization{Tokens: tokens, State: state}, nil +} diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go index b98830300..6bdac5da8 100644 --- a/x/models/qwen3_5/qwen3_5.go +++ b/x/models/qwen3_5/qwen3_5.go @@ -2,6 +2,7 @@ package qwen3_5 import ( + "cmp" "encoding/json" "fmt" "math" @@ -22,16 +23,24 @@ func init() { base.Register("Qwen3NextForConditionalGeneration", NewModel) } +var _ base.MultimodalPromptTokenizerWithState = (*Model)(nil) +var _ base.ForwardWithStateModel = (*Model)(nil) + // RopeParameters carries optional rope metadata embedded under rope_parameters. type RopeParameters struct { Type string `json:"type"` RopeType string `json:"rope_type"` RopeTheta float32 `json:"rope_theta"` PartialRotaryFactor float32 `json:"partial_rotary_factor"` + MRoPEInterleaved bool `json:"mrope_interleaved"` + MRoPESection []int32 `json:"mrope_section"` + DimensionSections []int32 `json:"dimension_sections"` } -// Config holds Qwen 3.5 text config (top-level or nested text_config). -type Config struct { +// TextConfig holds the Qwen 3.5 text-model architecture fields. +// In VLM configs these live under the "text_config" key; in text-only +// configs they appear at the top level. 
+type TextConfig struct { ModelType string `json:"model_type"` HiddenSize int32 `json:"hidden_size"` IntermediateSize int32 `json:"intermediate_size"` @@ -67,6 +76,19 @@ type Config struct { PartialRotaryFactor float32 `json:"partial_rotary_factor"` RopeScaling map[string]any `json:"rope_scaling"` RopeParameters *RopeParameters `json:"rope_parameters"` + MRoPESections []int32 `json:"mrope_sections"` + MRoPEInterleaved bool `json:"mrope_interleaved"` +} + +// Config is the full model config. It embeds TextConfig for the text-model +// fields and adds top-level-only fields (vision, token IDs, quantization). +type Config struct { + TextConfig + + Vision *VisionConfig `json:"vision_config"` + ImageTokenID int32 `json:"image_token_id"` + VisionStartToken int32 `json:"vision_start_token_id"` + VisionEndToken int32 `json:"vision_end_token_id"` // Quantization metadata. QuantGroupSize int `json:"-"` @@ -90,6 +112,9 @@ type Model struct { *Config weightPrefix string + + Vision *VisionModel + ImageProcessor *VisionImageProcessor } // Layer is a transformer decoder layer. @@ -190,17 +215,24 @@ func parseConfig(configData []byte) (Config, error) { var cfg Config activeRaw := rawTop + + // First pass: unmarshal the full config to pick up top-level fields + // (vision_config, image_token_id, etc.) and text fields for text-only models. + if err := json.Unmarshal(configData, &cfg); err != nil { + return Config{}, fmt.Errorf("parse config: %w", err) + } + + // Second pass: if text_config exists, unmarshal it into TextConfig so + // text-model fields from text_config take priority over any top-level + // duplicates. Top-level-only fields (Vision, token IDs) are unaffected + // because they live on Config, not TextConfig. 
if textRaw, ok := rawTop["text_config"]; ok { - if err := json.Unmarshal(textRaw, &cfg); err != nil { + if err := json.Unmarshal(textRaw, &cfg.TextConfig); err != nil { return Config{}, fmt.Errorf("parse text_config: %w", err) } if err := json.Unmarshal(textRaw, &activeRaw); err != nil { return Config{}, fmt.Errorf("parse text_config envelope: %w", err) } - } else { - if err := json.Unmarshal(configData, &cfg); err != nil { - return Config{}, fmt.Errorf("parse config: %w", err) - } } if cfg.HiddenSize <= 0 { @@ -225,12 +257,8 @@ func parseConfig(configData []byte) (Config, error) { return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim) } - if cfg.RMSNormEps == 0 { - cfg.RMSNormEps = 1e-6 - } - if cfg.LinearConvKernelDim <= 0 { - cfg.LinearConvKernelDim = 4 - } + cfg.RMSNormEps = cmp.Or(cfg.RMSNormEps, 1e-6) + cfg.LinearConvKernelDim = cmp.Or(cfg.LinearConvKernelDim, 4) if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 { return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)", cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim) @@ -246,14 +274,21 @@ func parseConfig(configData []byte) (Config, error) { if cfg.RopeParameters.PartialRotaryFactor > 0 { cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor } + if len(cfg.MRoPESections) == 0 { + switch { + case len(cfg.RopeParameters.MRoPESection) > 0: + cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.MRoPESection...) + case len(cfg.RopeParameters.DimensionSections) > 0: + cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.DimensionSections...) 
+ } + } + cfg.MRoPEInterleaved = cmp.Or(cfg.MRoPEInterleaved, cfg.RopeParameters.MRoPEInterleaved) } - if cfg.RopeTheta == 0 { - cfg.RopeTheta = 100000.0 + if len(cfg.MRoPESections) > 4 { + cfg.MRoPESections = cfg.MRoPESections[:4] } - if cfg.PartialRotaryFactor == 0 { - cfg.PartialRotaryFactor = 0.25 - } - if cfg.PartialRotaryFactor < 0 { + cfg.RopeTheta = cmp.Or(cfg.RopeTheta, 100000.0) + if cfg.PartialRotaryFactor <= 0 { cfg.PartialRotaryFactor = 0.25 } ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor) @@ -281,24 +316,23 @@ func parseConfig(configData []byte) (Config, error) { } if cfg.NumExperts > 0 { - if cfg.NumExpertsPerTok <= 0 { - cfg.NumExpertsPerTok = 1 - } - if cfg.MoeIntermediateSize <= 0 { - cfg.MoeIntermediateSize = cfg.IntermediateSize - } - if cfg.SharedExpertIntermediateSize <= 0 { - cfg.SharedExpertIntermediateSize = cfg.IntermediateSize - } + cfg.NumExpertsPerTok = cmp.Or(cfg.NumExpertsPerTok, int32(1)) + cfg.MoeIntermediateSize = cmp.Or(cfg.MoeIntermediateSize, cfg.IntermediateSize) + cfg.SharedExpertIntermediateSize = cmp.Or(cfg.SharedExpertIntermediateSize, cfg.IntermediateSize) if _, ok := activeRaw["norm_topk_prob"]; !ok { cfg.NormTopKProb = true } - if cfg.DecoderSparseStep <= 0 { - cfg.DecoderSparseStep = 1 - } + cfg.DecoderSparseStep = cmp.Or(cfg.DecoderSparseStep, int32(1)) } cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + + if cfg.Vision != nil { + cfg.Vision.applyDefaults() + } + cfg.ImageTokenID = cmp.Or(cfg.ImageTokenID, int32(151655)) + cfg.VisionStartToken = cmp.Or(cfg.VisionStartToken, int32(151652)) + cfg.VisionEndToken = cmp.Or(cfg.VisionEndToken, int32(151653)) return cfg, nil } @@ -364,6 +398,11 @@ func NewModel(root *model.Root) (base.Model, error) { if err != nil { return nil, err } + if cfg.Vision != nil { + if preprocessorData, err := root.Manifest.ReadConfig("preprocessor_config.json"); err == nil { + cfg.Vision.applyPreprocessorConfig(preprocessorData) + } + } if qt := root.QuantType(); qt != 
"" { cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt) @@ -1060,6 +1099,15 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { m.Layers[i] = layer } + if cfg.Vision != nil && cfg.Vision.Depth > 0 { + vision, processor, err := loadVisionComponents(tensors, linears, cfg, m.weightPrefix) + if err != nil { + return err + } + m.Vision = vision + m.ImageProcessor = processor + } + return nil } @@ -1117,7 +1165,51 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k, return q, k, v, z, b, a } -func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { +func rotateHalf(x *mlx.Array) *mlx.Array { + shape := x.Dims() + last := int32(shape[len(shape)-1]) + half := last / 2 + if half <= 0 { + return x + } + + x1 := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), half}) + x2 := mlx.SliceStartStop(x, []int32{0, 0, 0, half}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last}) + return mlx.Concatenate([]*mlx.Array{mlx.Neg(x2), x1}, -1) +} + +func applyTextRoPE(x, cos, sin *mlx.Array, ropeDim int32) *mlx.Array { + if x == nil || cos == nil || sin == nil || ropeDim <= 0 { + return x + } + + shape := x.Dims() + if len(shape) != 4 { + return x + } + + last := int32(shape[len(shape)-1]) + if ropeDim > last { + ropeDim = last + } + if ropeDim%2 != 0 { + ropeDim-- + } + if ropeDim <= 0 { + return x + } + + rot := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), ropeDim}) + rot = mlx.Add(mlx.Mul(rot, cos), mlx.Mul(rotateHalf(rot), sin)) + if ropeDim == last { + return rot + } + + tail := mlx.SliceStartStop(x, []int32{0, 0, 0, ropeDim}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last}) + return mlx.Concatenate([]*mlx.Array{rot, tail}, -1) +} + +func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, 
ropeSin *mlx.Array) *mlx.Array { qg := a.QProj.Forward(x) qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2) q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim}) @@ -1140,8 +1232,13 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co if c != nil { offset = c.Offset() } - q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) - k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) + if ropeCos != nil && ropeSin != nil { + q = applyTextRoPE(q, ropeCos, ropeSin, cfg.RopeDim) + k = applyTextRoPE(k, ropeCos, ropeSin, cfg.RopeDim) + } else { + q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) + k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) + } if c != nil { k, v = c.Update(k, v) @@ -1323,13 +1420,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array { return mlx.Reshape(y, B, L, cfg.HiddenSize) } -func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { +func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array { var r *mlx.Array normed := l.InputNorm.Forward(x, cfg.RMSNormEps) if l.IsLinear { r = l.Linear.Forward(normed, c, B, L, cfg) } else { - r = l.FullAttn.Forward(normed, c, B, L, cfg) + r = l.FullAttn.Forward(normed, c, B, L, cfg, ropeCos, ropeSin) } h := mlx.Add(x, r) r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg) @@ -1337,16 +1434,27 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m } func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + return m.ForwardWithState(tokens, caches, nil) +} + +func (m *Model) ForwardWithState(tokens *mlx.Array, caches []cache.Cache, state any) *mlx.Array { dims := tokens.Dims() B, L := int32(dims[0]), int32(dims[1]) + startPos := 
promptStartPosFromCaches(caches) + promptState := promptVisionStateFromState(state) h := m.EmbedTokens.Forward(tokens) + h = m.applyPromptVisionEmbeddings(h, startPos, promptState) + var ropeCos, ropeSin *mlx.Array + if len(m.MRoPESections) > 0 { + ropeCos, ropeSin = m.buildPromptMRoPECosSin(promptState, startPos, L, h.DType()) + } for i, layer := range m.Layers { var c cache.Cache if caches != nil && i < len(caches) { c = caches[i] } - h = layer.Forward(h, c, B, L, m.Config) + h = layer.Forward(h, c, B, L, m.Config, ropeCos, ropeSin) } out := m.Norm.Forward(h, m.RMSNormEps) return out diff --git a/x/models/qwen3_5/qwen3_5_test.go b/x/models/qwen3_5/qwen3_5_test.go index f425ee5c5..4c23d6fdc 100644 --- a/x/models/qwen3_5/qwen3_5_test.go +++ b/x/models/qwen3_5/qwen3_5_test.go @@ -1,10 +1,14 @@ package qwen3_5 import ( + "fmt" + "slices" "testing" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model/base" + "github.com/ollama/ollama/x/tokenizer" ) func skipIfNoMLX(t *testing.T) { @@ -60,13 +64,13 @@ func TestParseConfigNestedDefaults(t *testing.T) { } func TestLayerSelectionHelpers(t *testing.T) { - cfg := &Config{ + cfg := &Config{TextConfig: TextConfig{ NumHiddenLayers: 6, FullAttentionInterval: 3, NumExperts: 8, DecoderSparseStep: 2, MLPOnlyLayers: []int32{1}, - } + }} if !layerIsLinear(cfg, 0) { t.Fatalf("layer 0 should be linear") @@ -133,13 +137,13 @@ func TestResolveTensorPathLayout(t *testing.T) { func TestNewCachesLayout(t *testing.T) { m := &Model{ - Config: &Config{ + Config: &Config{TextConfig: TextConfig{ LinearConvKernelDim: 4, LinearNumKeyHeads: 2, LinearKeyHeadDim: 8, LinearNumValueHeads: 4, LinearValueHeadDim: 16, - }, + }}, Layers: []*Layer{ {IsLinear: true}, {IsLinear: false}, @@ -166,7 +170,7 @@ func TestNewCachesLayout(t *testing.T) { func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) { skipIfNoMLX(t) - cfg := &Config{ + cfg := 
&Config{TextConfig: TextConfig{ HiddenSize: 4, IntermediateSize: 8, NumHiddenLayers: 2, @@ -182,7 +186,7 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) { LinearValueHeadDim: 2, LinearConvKernelDim: 4, FullAttentionInterval: 2, - } + }} m := &Model{ Config: cfg, @@ -343,3 +347,389 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) { t.Fatalf("k norm dtype = %v, want %v", got, f32) } } + +func TestParseConfigVisionFields(t *testing.T) { + data := []byte(`{ + "text_config": { + "hidden_size": 4096, + "intermediate_size": 14336, + "num_hidden_layers": 4, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "head_dim": 128, + "linear_num_value_heads": 64, + "linear_num_key_heads": 16, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_conv_kernel_dim": 4, + "rope_parameters": { + "rope_theta": 10000000 + } + }, + "vision_config": { + "depth": 2, + "hidden_size": 256, + "num_heads": 8, + "in_channels": 3, + "patch_size": 14, + "spatial_merge_size": 2, + "layer_norm_epsilon": 0.000001, + "temporal_patch_size": 2, + "num_position_embeddings": 2304 + }, + "image_token_id": 111, + "vision_start_token_id": 112, + "vision_end_token_id": 113 + }`) + + cfg, err := parseConfig(data) + if err != nil { + t.Fatalf("parseConfig failed: %v", err) + } + + if cfg.Vision == nil { + t.Fatal("vision config should be parsed") + } + if cfg.Vision.Depth != 2 { + t.Fatalf("vision.depth mismatch: got %d", cfg.Vision.Depth) + } + if cfg.Vision.GridPerSide != 48 { + t.Fatalf("vision grid-per-side mismatch: got %d want 48", cfg.Vision.GridPerSide) + } + if cfg.Vision.RopeTheta != 10000 { + t.Fatalf("vision rope_theta should default to 10000, got %v", cfg.Vision.RopeTheta) + } + if cfg.RopeTheta != 10000000 { + t.Fatalf("text rope_theta mismatch: got %v", cfg.RopeTheta) + } + if cfg.ImageTokenID != 111 || cfg.VisionStartToken != 112 || cfg.VisionEndToken != 113 { + t.Fatalf("vision token ids mismatch: got image=%d 
start=%d end=%d", cfg.ImageTokenID, cfg.VisionStartToken, cfg.VisionEndToken) + } +} + +func TestParseConfigMRoPEFromRopeParameters(t *testing.T) { + data := []byte(`{ + "text_config": { + "hidden_size": 2048, + "intermediate_size": 8192, + "num_hidden_layers": 4, + "num_attention_heads": 16, + "num_key_value_heads": 2, + "head_dim": 256, + "linear_num_value_heads": 32, + "linear_num_key_heads": 16, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_conv_kernel_dim": 4, + "rope_parameters": { + "rope_theta": 10000000, + "partial_rotary_factor": 0.25, + "mrope_interleaved": true, + "mrope_section": [11, 11, 10] + } + } + }`) + + cfg, err := parseConfig(data) + if err != nil { + t.Fatalf("parseConfig failed: %v", err) + } + + if !cfg.MRoPEInterleaved { + t.Fatal("mrope_interleaved should be parsed from rope_parameters") + } + if !slices.Equal(cfg.MRoPESections, []int32{11, 11, 10}) { + t.Fatalf("mrope sections mismatch: got %v", cfg.MRoPESections) + } + if cfg.RopeDim != 64 { + t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim) + } +} + +func TestParseConfigVisionTokenDefaults(t *testing.T) { + data := []byte(`{ + "text_config": { + "hidden_size": 4096, + "intermediate_size": 14336, + "num_hidden_layers": 2, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "head_dim": 128, + "linear_num_value_heads": 64, + "linear_num_key_heads": 16, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_conv_kernel_dim": 4 + } + }`) + + cfg, err := parseConfig(data) + if err != nil { + t.Fatalf("parseConfig failed: %v", err) + } + + if cfg.ImageTokenID != 151655 { + t.Fatalf("default image token mismatch: got %d", cfg.ImageTokenID) + } + if cfg.VisionStartToken != 151652 { + t.Fatalf("default vision start token mismatch: got %d", cfg.VisionStartToken) + } + if cfg.VisionEndToken != 151653 { + t.Fatalf("default vision end token mismatch: got %d", cfg.VisionEndToken) + } +} + +func TestResolveVisionPrefix(t *testing.T) { + tests 
:= []struct { + name string + tensors map[string]*mlx.Array + want string + }{ + { + name: "legacy visual prefix", + tensors: map[string]*mlx.Array{ + "model.visual.patch_embed.proj.weight": mlx.New("patch"), + }, + want: "model.visual", + }, + { + name: "imported vision tower prefix", + tensors: map[string]*mlx.Array{ + "vision_tower.blocks.0.attn.qkv.weight": mlx.New("qkv"), + }, + want: "vision_tower", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := resolveVisionPrefix(tt.tensors, "language_model."); got != tt.want { + t.Fatalf("resolveVisionPrefix() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestVisionPreprocessorOverridesDefaults(t *testing.T) { + v := &VisionConfig{} + v.applyDefaults() + + v.applyPreprocessorConfig([]byte(`{ + "patch_size": 16, + "temporal_patch_size": 3, + "merge_size": 4, + "size": { + "shortest_edge": 1024, + "longest_edge": 8192 + }, + "image_mean": [0.1, 0.2, 0.3], + "image_std": [0.9, 0.8, 0.7] + }`)) + + if v.PatchSize != 16 { + t.Fatalf("patch_size mismatch: got %d want 16", v.PatchSize) + } + if v.TemporalPatchSize != 3 { + t.Fatalf("temporal_patch_size mismatch: got %d want 3", v.TemporalPatchSize) + } + if v.SpatialMergeSize != 4 { + t.Fatalf("merge_size mismatch: got %d want 4", v.SpatialMergeSize) + } + if v.Size.ShortestEdge != 1024 || v.Size.LongestEdge != 8192 { + t.Fatalf("size mismatch: got shortest=%d longest=%d", v.Size.ShortestEdge, v.Size.LongestEdge) + } + if v.ImageMean[0] != 0.1 || v.ImageStd[2] != 0.7 { + t.Fatalf("image preprocessing stats mismatch: mean=%v std=%v", v.ImageMean, v.ImageStd) + } +} + +func TestVisionImageProcessorUsesPreprocessorSize(t *testing.T) { + v := &VisionConfig{} + v.applyDefaults() + + v.applyPreprocessorConfig([]byte(`{ + "size": { + "shortest_edge": 65536, + "longest_edge": 16777216 + }, + "patch_size": 16, + "temporal_patch_size": 2, + "merge_size": 2, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] + }`)) + + p := 
newVisionImageProcessor(v) + if p == nil { + t.Fatal("newVisionImageProcessor returned nil") + } + + if p.shortestEdge != 65536 || p.longestEdge != 16777216 { + t.Fatalf("processor size mismatch: shortest=%d longest=%d", p.shortestEdge, p.longestEdge) + } +} + +func testTokenizer(t *testing.T) *tokenizer.Tokenizer { + t.Helper() + + tok, err := tokenizer.LoadFromBytes([]byte(`{ + "model": { + "type": "BPE", + "vocab": {"a": 0}, + "merges": [] + } + }`)) + if err != nil { + t.Fatalf("failed to load test tokenizer: %v", err) + } + + return tok +} + +func TestTokenizePromptWithResolvedImagesStoresVisionSpans(t *testing.T) { + skipIfNoMLX(t) + + m := &Model{ + tok: testTokenizer(t), + Config: &Config{ + ImageTokenID: 101, + VisionStartToken: 102, + VisionEndToken: 103, + Vision: &VisionConfig{SpatialMergeSize: 2}, + }, + Vision: &VisionModel{}, + ImageProcessor: &VisionImageProcessor{}, + } + + main := mlx.FromValues([]float32{ + 10, 11, + 20, 21, + }, 1, 2, 2) + + resolveCalls := 0 + got, state, err := m.tokenizePromptWithResolvedImages( + "a[img-7][img-7]a", + []base.ImageInput{{ID: 7, Data: []byte("img7")}}, + func(data []byte) (*VisionEmbeddings, error) { + if string(data) != "img7" { + return nil, fmt.Errorf("unexpected data: %q", string(data)) + } + resolveCalls++ + return &VisionEmbeddings{ + Main: main, + Grid: &VisionGrid{Height: 2, Width: 2, Temporal: 1}, + }, nil + }, + ) + if err != nil { + t.Fatalf("tokenizePromptWithResolvedImages returned error: %v", err) + } + if resolveCalls != 1 { + t.Fatalf("resolve calls mismatch: got %d want 1", resolveCalls) + } + + want := []int32{ + 0, + 102, 101, 101, 103, + 102, 101, 101, 103, + 0, + } + if !slices.Equal(got, want) { + t.Fatalf("expanded tokens mismatch: got %v want %v", got, want) + } + + if state == nil { + t.Fatal("expected prompt vision state") + } + if len(state.Spans) != 2 { + t.Fatalf("prompt span count mismatch: got %d want 2", len(state.Spans)) + } + if state.Spans[0].Start != 2 || state.Spans[0].End 
!= 4 { + t.Fatalf("first span mismatch: got [%d,%d)", state.Spans[0].Start, state.Spans[0].End) + } + if state.Spans[1].Start != 6 || state.Spans[1].End != 8 { + t.Fatalf("second span mismatch: got [%d,%d)", state.Spans[1].Start, state.Spans[1].End) + } + wantPos := []int32{0, 1, 2, 2, 3, 4, 5, 5, 6, 7} + if !slices.Equal(state.PositionCache, wantPos) { + t.Fatalf("position cache mismatch: got %v want %v", state.PositionCache, wantPos) + } +} + +func TestBuildPromptMRoPEPositions(t *testing.T) { + m := &Model{ + Config: &Config{ + Vision: &VisionConfig{SpatialMergeSize: 2}, + }, + } + state := &promptVisionState{ + PositionCache: []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6}, + Spans: []promptVisionSpan{ + { + Start: 2, + End: 8, + Grid: &VisionGrid{Height: 4, Width: 6, Temporal: 1}, + }, + }, + } + + pos := m.buildPromptMRoPEPositions(state, 0, 10) + if got, want := pos[0], []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6}; !slices.Equal(got, want) { + t.Fatalf("time positions mismatch: got %v want %v", got, want) + } + if got, want := pos[1], []int32{0, 1, 2, 2, 2, 3, 3, 3, 5, 6}; !slices.Equal(got, want) { + t.Fatalf("height positions mismatch: got %v want %v", got, want) + } + if got, want := pos[2], []int32{0, 1, 2, 3, 4, 2, 3, 4, 5, 6}; !slices.Equal(got, want) { + t.Fatalf("width positions mismatch: got %v want %v", got, want) + } +} + +func TestMapPromptPositionContinuesAfterCache(t *testing.T) { + state := &promptVisionState{PositionCache: []int32{0, 1, 2, 2, 3}} + + if got := mapPromptPosition(state, 3); got != 2 { + t.Fatalf("mapPromptPosition(3) = %d, want 2", got) + } + if got := mapPromptPosition(state, 5); got != 4 { + t.Fatalf("mapPromptPosition(5) = %d, want 4", got) + } +} + +func TestApplyPromptVisionEmbeddings(t *testing.T) { + skipIfNoMLX(t) + + m := &Model{} + state := &promptVisionState{ + Spans: []promptVisionSpan{ + { + Start: 1, + End: 3, + Main: mlx.FromValues([]float32{ + 10, 11, + 20, 21, + }, 1, 2, 2), + }, + }, + } + + h := mlx.FromValues([]float32{ + 0, 
1, + 2, 3, + 4, 5, + 6, 7, + }, 1, 4, 2) + + got := m.applyPromptVisionEmbeddings(h, 0, state) + mlx.Eval(got) + + want := []float32{ + 0, 1, + 10, 11, + 20, 21, + 6, 7, + } + if !slices.Equal(got.Floats(), want) { + t.Fatalf("embedding replacement mismatch: got %v want %v", got.Floats(), want) + } +} diff --git a/x/models/qwen3_5/vision.go b/x/models/qwen3_5/vision.go new file mode 100644 index 000000000..0977e6315 --- /dev/null +++ b/x/models/qwen3_5/vision.go @@ -0,0 +1,854 @@ +package qwen3_5 + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "image" + _ "image/jpeg" + _ "image/png" + "math" + + "github.com/ollama/ollama/model/imageproc" + "github.com/ollama/ollama/x/mlxrunner/mlx" + mlxmodel "github.com/ollama/ollama/x/mlxrunner/model" + "github.com/ollama/ollama/x/models/nn" +) + +var errNoVisionModel = errors.New("qwen3_5: no vision model") + +// VisionConfig mirrors Qwen3.5/Qwen3-Next vision_config. +type VisionConfig struct { + Depth int32 `json:"depth"` + HiddenSize int32 `json:"hidden_size"` + NumHeads int32 `json:"num_heads"` + InChannels int32 `json:"in_channels"` + PatchSize int32 `json:"patch_size"` + SpatialMergeSize int32 `json:"spatial_merge_size"` + LayerNormEpsilon float32 `json:"layer_norm_epsilon"` + RopeTheta float32 `json:"rope_theta"` + TemporalPatchSize int32 `json:"temporal_patch_size"` + NumPositionEmbeddings int32 `json:"num_position_embeddings"` + + Size struct { + ShortestEdge int32 `json:"shortest_edge"` + LongestEdge int32 `json:"longest_edge"` + } `json:"size"` + + ImageMean []float32 `json:"image_mean"` + ImageStd []float32 `json:"image_std"` + + GridPerSide int32 `json:"-"` +} + +func (v *VisionConfig) applyDefaults() { + if v == nil { + return + } + if v.HiddenSize <= 0 { + v.HiddenSize = 1280 + } + if v.NumHeads <= 0 { + v.NumHeads = 16 + } + if v.InChannels <= 0 { + v.InChannels = 3 + } + if v.PatchSize <= 0 { + v.PatchSize = 14 + } + if v.SpatialMergeSize <= 0 { + v.SpatialMergeSize = 2 + } + if v.LayerNormEpsilon 
== 0 { + v.LayerNormEpsilon = 1e-6 + } + if v.RopeTheta == 0 { + v.RopeTheta = 10000 + } + if v.TemporalPatchSize <= 0 { + v.TemporalPatchSize = 2 + } + if v.NumPositionEmbeddings <= 0 { + v.NumPositionEmbeddings = 2304 + } + if len(v.ImageMean) < 3 { + v.ImageMean = []float32{0.5, 0.5, 0.5} + } + if len(v.ImageStd) < 3 { + v.ImageStd = []float32{0.5, 0.5, 0.5} + } + if v.Size.ShortestEdge <= 0 { + v.Size.ShortestEdge = 64 << 10 + } + if v.Size.LongestEdge <= 0 { + v.Size.LongestEdge = 2 << 20 + } + + grid := int32(math.Sqrt(float64(v.NumPositionEmbeddings))) + if grid <= 0 { + grid = 48 + } + v.GridPerSide = grid +} + +func (v *VisionConfig) applyPreprocessorConfig(data []byte) { + if v == nil || len(data) == 0 { + return + } + + var pre struct { + Size struct { + ShortestEdge int32 `json:"shortest_edge"` + LongestEdge int32 `json:"longest_edge"` + } `json:"size"` + PatchSize int32 `json:"patch_size"` + TemporalPatchSize int32 `json:"temporal_patch_size"` + MergeSize int32 `json:"merge_size"` + ImageMean []float32 `json:"image_mean"` + ImageStd []float32 `json:"image_std"` + } + if err := json.Unmarshal(data, &pre); err != nil { + return + } + + if pre.PatchSize > 0 { + v.PatchSize = pre.PatchSize + } + if pre.TemporalPatchSize > 0 { + v.TemporalPatchSize = pre.TemporalPatchSize + } + if pre.MergeSize > 0 { + v.SpatialMergeSize = pre.MergeSize + } + if pre.Size.ShortestEdge > 0 { + v.Size.ShortestEdge = pre.Size.ShortestEdge + } + if pre.Size.LongestEdge > 0 { + v.Size.LongestEdge = pre.Size.LongestEdge + } + if len(pre.ImageMean) >= 3 { + v.ImageMean = pre.ImageMean + } + if len(pre.ImageStd) >= 3 { + v.ImageStd = pre.ImageStd + } + v.applyDefaults() +} + +// VisionGrid tracks patch-grid dimensions for an image. +type VisionGrid struct { + Height int32 + Width int32 + Temporal int32 +} + +// VisionImageProcessor reproduces qwen3vl image preprocessing. 
type VisionImageProcessor struct {
	numChannels       int32      // input channels (RGB => 3)
	patchSize         int32      // side length of one square patch, in pixels
	temporalPatchSize int32      // frames per temporal patch; static images replicate frame 0
	mergeSize         int32      // spatial merge window (patches merged per side downstream)
	shortestEdge      int32      // minimum total pixel budget (H*W) after resize
	longestEdge       int32      // maximum total pixel budget (H*W) after resize
	factor            int32      // patchSize * mergeSize; resized dims snap to multiples of this
	imageMean         [3]float32 // per-channel normalization mean
	imageStd          [3]float32 // per-channel normalization std
}

// newVisionImageProcessor builds a processor from the vision config.
// Returns nil when cfg is nil so callers can propagate "no vision model".
func newVisionImageProcessor(cfg *VisionConfig) *VisionImageProcessor {
	if cfg == nil {
		return nil
	}

	return &VisionImageProcessor{
		numChannels:       cfg.InChannels,
		patchSize:         cfg.PatchSize,
		temporalPatchSize: cfg.TemporalPatchSize,
		mergeSize:         cfg.SpatialMergeSize,
		shortestEdge:      cfg.Size.ShortestEdge,
		longestEdge:       cfg.Size.LongestEdge,
		factor:            cfg.PatchSize * cfg.SpatialMergeSize,
		imageMean:         [3]float32{cfg.ImageMean[0], cfg.ImageMean[1], cfg.ImageMean[2]},
		imageStd:          [3]float32{cfg.ImageStd[0], cfg.ImageStd[1], cfg.ImageStd[2]},
	}
}

// smartResize picks output dimensions that (a) are multiples of factor and
// (b) keep the total pixel count within [shortestEdge, longestEdge],
// preserving aspect ratio as closely as possible. Returns (height, width).
func (p *VisionImageProcessor) smartResize(height, width int) (int, int, error) {
	factor := int(p.factor)
	if factor <= 0 {
		return 0, 0, fmt.Errorf("invalid factor: %d", factor)
	}

	if height < factor || width < factor {
		return 0, 0, fmt.Errorf("height (%d) or width (%d) must be >= factor (%d)", height, width, factor)
	}
	if min(height, width) == 0 {
		return 0, 0, fmt.Errorf("invalid dimensions: %dx%d", width, height)
	}
	// Reject extreme aspect ratios; downstream patch grids would degenerate.
	if max(height, width)/min(height, width) > 200 {
		return 0, 0, fmt.Errorf("aspect ratio too large: %dx%d", width, height)
	}

	roundEven := func(x float64) int { return int(math.RoundToEven(x)) }

	// First snap each side to the nearest multiple of factor.
	hBar := roundEven(float64(height)/float64(factor)) * factor
	wBar := roundEven(float64(width)/float64(factor)) * factor

	// Then rescale uniformly if the pixel budget is exceeded (floor keeps us
	// under the cap) or not met (ceil keeps us over the floor).
	if hBar*wBar > int(p.longestEdge) {
		beta := math.Sqrt(float64(height*width) / float64(p.longestEdge))
		hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
		wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
	} else if hBar*wBar < int(p.shortestEdge) {
		beta := math.Sqrt(float64(p.shortestEdge) / float64(height*width))
		hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
		wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
	}

	return hBar, wBar, nil
}

// ProcessImage converts a decoded image into the flattened patch tensor the
// vision tower consumes, plus the patch-grid geometry. The returned array has
// shape [1, numPatches, C*T*P*P] (batch dim added via ExpandDims).
func (p *VisionImageProcessor) ProcessImage(img image.Image) (*mlx.Array, *VisionGrid, error) {
	if p == nil {
		return nil, nil, errNoVisionModel
	}

	// Flatten alpha onto an opaque background before resizing/normalizing.
	img = imageproc.Composite(img)
	origW := img.Bounds().Dx()
	origH := img.Bounds().Dy()

	resizedH, resizedW, err := p.smartResize(origH, origW)
	if err != nil {
		return nil, nil, err
	}

	resized := imageproc.Resize(
		img,
		image.Point{X: resizedW, Y: resizedH},
		imageproc.ResizeBilinear,
	)
	pixels := imageproc.Normalize(resized, p.imageMean, p.imageStd, true, true)

	grid := &VisionGrid{
		Height:   int32(resizedH / int(p.patchSize)),
		Width:    int32(resizedW / int(p.patchSize)),
		Temporal: 1, // static images occupy a single temporal slot
	}

	patches := p.createPatches(pixels, resizedH, resizedW, grid)

	patchDim := int(p.numChannels * p.temporalPatchSize * p.patchSize * p.patchSize)
	numPatches := int(grid.Height * grid.Width)
	pixelValues := mlx.FromValues(patches, numPatches, patchDim).ExpandDims(0)
	return pixelValues, grid, nil
}

// createPatches rearranges CHW pixel data into per-patch rows. Patches are
// emitted in merge-window order (mergeSize x mergeSize groups scanned
// row-major), matching the order mergedPatchCoordinates generates.
// Each row is laid out channel-major: [C][T][P][P], with temporal frames
// beyond the first copied from frame 0 (static image).
func (p *VisionImageProcessor) createPatches(pixels []float32, height, width int, grid *VisionGrid) []float32 {
	channels := int(p.numChannels)
	patchSize := int(p.patchSize)
	mergeSize := int(p.mergeSize)
	temporalPatchSize := int(p.temporalPatchSize)

	// Temporal is always 1 for static images; only spatial patches are created.
	numPatches := int(grid.Height * grid.Width)
	patchDim := channels * temporalPatchSize * patchSize * patchSize
	result := make([]float32, numPatches*patchDim)

	patchIndex := 0
	for h := 0; h < int(grid.Height); h += mergeSize {
		for w := 0; w < int(grid.Width); w += mergeSize {
			for mh := 0; mh < mergeSize; mh++ {
				for mw := 0; mw < mergeSize; mw++ {
					baseOffset := patchIndex * patchDim

					// Copy frame 0 of every channel from the normalized CHW buffer.
					for c := 0; c < channels; c++ {
						channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
						for py := 0; py < patchSize; py++ {
							for px := 0; px < patchSize; px++ {
								y := (h+mh)*patchSize + py
								x := (w+mw)*patchSize + px
								srcIdx := c*height*width + y*width + x
								dstIdx := channelOffset + py*patchSize + px
								// Bounds guards: defensive against grid/pixel size mismatch.
								if srcIdx < len(pixels) && dstIdx < len(result) {
									result[dstIdx] = pixels[srcIdx]
								}
							}
						}
					}

					// Replicate frame 0 into the remaining temporal slots.
					if temporalPatchSize > 1 {
						for c := 0; c < channels; c++ {
							channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
							frameSize := patchSize * patchSize
							for tp := 1; tp < temporalPatchSize; tp++ {
								cur := channelOffset + tp*frameSize
								copy(result[cur:cur+frameSize], result[channelOffset:channelOffset+frameSize])
							}
						}
					}

					patchIndex++
				}
			}
		}
	}

	return result
}

// VisionAttention runs one self-attention block inside the vision encoder.
type VisionAttention struct {
	QKV    nn.LinearLayer // fused qkv projection (legacy checkpoints); nil when split q/k/v are present
	Query  nn.LinearLayer
	Key    nn.LinearLayer
	Value  nn.LinearLayer
	Output nn.LinearLayer
}

// applyVisionRoPE applies rotary embeddings in the rotate-half formulation:
// x*cos + rotate_half(x)*sin.
func applyVisionRoPE(x, cos, sin *mlx.Array) *mlx.Array {
	return mlx.Add(mlx.Mul(x, cos), mlx.Mul(rotateHalf(x), sin))
}

// Forward runs bidirectional self-attention over x of shape [B, L, D].
// cos/sin are the precomputed 2D rotary tables from rotaryEmbeddings.
func (a *VisionAttention) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("vision attention expects [B,L,D], got %v", shape)
	}
	B, L, hidden := int32(shape[0]), int32(shape[1]), int32(shape[2])
	headDim := cfg.HiddenSize / cfg.NumHeads
	if headDim <= 0 {
		return nil, fmt.Errorf("invalid vision head dim: %d", headDim)
	}

	var q, k, v *mlx.Array
	if a.QKV != nil {
		// Fused layout: reshape to [B,L,3,H,Dh] then slice out q/k/v along axis 2.
		qkv := a.QKV.Forward(x)
		qkv = mlx.Reshape(qkv, B, L, 3, cfg.NumHeads, headDim)
		q = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, cfg.NumHeads, headDim}), 2)
		k = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, cfg.NumHeads, headDim}), 2)
		v = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, cfg.NumHeads, headDim}), 2)
	} else {
		if a.Query == nil || a.Key == nil || a.Value == nil {
			return nil, errors.New("vision attention is missing q/k/v projections")
		}
		q = mlx.Reshape(a.Query.Forward(x), B, L, cfg.NumHeads, headDim)
		k = mlx.Reshape(a.Key.Forward(x), B, L, cfg.NumHeads, headDim)
		v = mlx.Reshape(a.Value.Forward(x), B, L, cfg.NumHeads, headDim)
	}

	q = applyVisionRoPE(q, cos, sin)
	k = applyVisionRoPE(k, cos, sin)

	// [B,L,H,Dh] -> [B,H,L,Dh] for the attention kernel.
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	// Final argument disables the causal mask: vision attention is bidirectional.
	attn := mlx.ScaledDotProductAttentionCausal(q, k, v, scale, false)
	attn = mlx.Reshape(mlx.Transpose(attn, 0, 2, 1, 3), B, L, hidden)
	if a.Output == nil {
		return nil, errors.New("vision attention is missing output projection")
	}
	return a.Output.Forward(attn), nil
}

// VisionMLP is the vision feed-forward block.
type VisionMLP struct {
	FC1 nn.LinearLayer
	FC2 nn.LinearLayer
}

// Forward applies fc2(gelu(fc1(x))).
func (m *VisionMLP) Forward(x *mlx.Array) (*mlx.Array, error) {
	if m.FC1 == nil || m.FC2 == nil {
		return nil, errors.New("vision mlp is missing fc1/fc2")
	}
	return m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(x))), nil
}

// VisionEncoderLayer is one transformer block in the vision encoder.
type VisionEncoderLayer struct {
	Norm1 *nn.LayerNorm
	Attn  *VisionAttention
	Norm2 *nn.LayerNorm
	MLP   *VisionMLP
}

// Forward is the standard pre-norm residual block:
// x = x + attn(norm1(x)); x = x + mlp(norm2(x)).
func (l *VisionEncoderLayer) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	if l.Norm1 == nil || l.Norm2 == nil || l.Attn == nil || l.MLP == nil {
		return nil, errors.New("vision layer is incomplete")
	}

	r := x
	a, err := l.Attn.Forward(l.Norm1.Forward(x), cos, sin, cfg)
	if err != nil {
		return nil, err
	}
	x = mlx.Add(r, a)

	r = x
	m, err := l.MLP.Forward(l.Norm2.Forward(x))
	if err != nil {
		return nil, err
	}
	return mlx.Add(r, m), nil
}

// VisionPatchMerger projects merged spatial groups into language embedding space.
type VisionPatchMerger struct {
	Norm *nn.LayerNorm
	FC1  nn.LinearLayer
	FC2  nn.LinearLayer
}

// groupMergedTokens folds each merge*merge run of adjacent tokens into a
// single token of merge*merge*D features: [B, L, D] -> [B, L/(m*m), m*m*D].
// Requires L to be a multiple of merge*merge.
func groupMergedTokens(x *mlx.Array, merge int32) (*mlx.Array, error) {
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("expected [B,L,D], got %v", shape)
	}
	if merge <= 0 {
		merge = 1
	}
	B, L, D := int32(shape[0]), int32(shape[1]), int32(shape[2])
	group := merge * merge
	if group <= 0 || L%group != 0 {
		return nil, fmt.Errorf("invalid merge layout: L=%d merge=%d", L, merge)
	}

	x = mlx.Reshape(x, B, L/group, group, D)
	x = mlx.Reshape(x, B, L/group, group*D)
	return x, nil
}

// Forward norms, groups merge-windows, and projects through a 2-layer MLP
// into the language model's embedding space.
func (m *VisionPatchMerger) Forward(x *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	if m == nil || m.Norm == nil || m.FC1 == nil || m.FC2 == nil {
		return nil, errors.New("vision patch merger is incomplete")
	}

	x = m.Norm.Forward(x)

	var err error
	x, err = groupMergedTokens(x, cfg.SpatialMergeSize)
	if err != nil {
		return nil, err
	}

	x = m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(x)))
	return x, nil
}

// VisionModel contains the full Qwen vision tower.
type VisionModel struct {
	PatchProjection nn.LinearLayer // linearized patch embedding (conv weight flattened to 2D)
	PositionEmbed   *nn.Embedding  // learned GridPerSide x GridPerSide position table
	Layers          []*VisionEncoderLayer
	PatchMerger     *VisionPatchMerger

	cfg *VisionConfig
}

// mergedPatchCoordinates lists (h, w) patch coordinates in merge-window order,
// matching the patch order produced by VisionImageProcessor.createPatches.
func mergedPatchCoordinates(grid *VisionGrid, merge int32) [][2]int32 {
	if merge <= 0 {
		merge = 1
	}
	// Temporal is always 1 for static images; only spatial coordinates are generated.
	coords := make([][2]int32, 0, grid.Height*grid.Width)
	for h := int32(0); h < grid.Height; h += merge {
		for w := int32(0); w < grid.Width; w += merge {
			for mh := int32(0); mh < merge; mh++ {
				for mw := int32(0); mw < merge; mw++ {
					coords = append(coords, [2]int32{h + mh, w + mw})
				}
			}
		}
	}
	return coords
}

// addPositionEmbedding adds bilinearly interpolated learned position
// embeddings: the GridPerSide x GridPerSide table is resampled onto the actual
// patch grid by blending the 4 nearest table entries per patch.
func (m *VisionModel) addPositionEmbedding(x *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
	if m.PositionEmbed == nil {
		return x, nil
	}
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("vision embeddings expect [B,L,D], got %v", shape)
	}
	B, D := int32(shape[0]), int32(shape[2])
	coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
	L := int32(len(coords))
	if L != int32(shape[1]) {
		return nil, fmt.Errorf("vision sequence mismatch: hidden L=%d coords=%d", shape[1], L)
	}

	// Fractional step from patch-grid coords into table coords.
	stepH := float32(0)
	if grid.Height > 1 {
		stepH = float32(m.cfg.GridPerSide-1) / float32(grid.Height-1)
	}
	stepW := float32(0)
	if grid.Width > 1 {
		stepW = float32(m.cfg.GridPerSide-1) / float32(grid.Width-1)
	}

	indices := make([]int32, 0, L*4)
	weights := make([]float32, 0, L*4)
	for _, c := range coords {
		y := float32(c[0]) * stepH
		x0 := float32(c[1]) * stepW

		// Floor/ceil corners, clamped to the table edge.
		fy := int32(y)
		fx := int32(x0)
		cy := min(fy+1, m.cfg.GridPerSide-1)
		cx := min(fx+1, m.cfg.GridPerSide-1)

		indices = append(indices,
			fy*m.cfg.GridPerSide+fx,
			fy*m.cfg.GridPerSide+cx,
			cy*m.cfg.GridPerSide+fx,
			cy*m.cfg.GridPerSide+cx,
		)

		// Standard bilinear weights for the 4 corners.
		dy := y - float32(fy)
		dx := x0 - float32(fx)
		weights = append(weights,
			(1-dy)*(1-dx),
			(1-dy)*dx,
			dy*(1-dx),
			dy*dx,
		)
	}

	idxArr := mlx.FromValues(indices, int(L), 4)
	wArr := mlx.FromValues(weights, int(L), 4, 1)

	pos := m.PositionEmbed.Forward(idxArr)
	wArr = wArr.AsType(pos.DType())
	pos = mlx.Sum(mlx.Mul(pos, wArr), 1, false) // weighted sum over the 4 corners
	if D != int32(pos.Dim(1)) {
		return nil, fmt.Errorf("position embedding dim mismatch: hidden=%d pos=%d", D, pos.Dim(1))
	}

	pos = mlx.ExpandDims(pos, 0)
	if B > 1 {
		pos = mlx.Tile(pos, []int32{B, 1, 1})
	}

	return mlx.Add(x, pos), nil
}

// rotaryEmbeddings builds 2D rotary cos/sin tables over patch (h, w)
// coordinates: the first quarter of each head dim encodes the row, the second
// quarter the column, and the second half mirrors the first (rotate-half
// layout). Shapes are [1, L, 1, headDim] for broadcasting over batch/heads.
func (m *VisionModel) rotaryEmbeddings(grid *VisionGrid) (*mlx.Array, *mlx.Array, error) {
	headDim := m.cfg.HiddenSize / m.cfg.NumHeads
	if headDim <= 0 {
		return nil, nil, fmt.Errorf("invalid vision head dim: %d", headDim)
	}

	coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
	L := int32(len(coords))
	half := headDim / 2
	quarter := half / 2
	if quarter <= 0 {
		return nil, nil, fmt.Errorf("invalid vision rotary layout: head_dim=%d", headDim)
	}

	angles := make([]float32, L*headDim)
	for i, c := range coords {
		base := int32(i) * headDim
		for j := int32(0); j < quarter; j++ {
			freq := 1.0 / math.Pow(float64(m.cfg.RopeTheta), float64(2*j)/float64(half))
			angles[base+j] = float32(float64(c[0]) * freq)         // row angle
			angles[base+quarter+j] = float32(float64(c[1]) * freq) // column angle
		}
		// Duplicate first half into second half for rotate-half application.
		for j := int32(0); j < half; j++ {
			angles[base+half+j] = angles[base+j]
		}
	}

	arr := mlx.FromValues(angles, int(L), int(headDim))
	cos := mlx.ExpandDims(mlx.ExpandDims(mlx.Cos(arr), 0), 2)
	sin := mlx.ExpandDims(mlx.ExpandDims(mlx.Sin(arr), 0), 2)
	return cos, sin, nil
}

// Forward runs the full vision tower: patch projection, interpolated position
// embeddings, the encoder stack with 2D RoPE, then the patch merger.
func (m *VisionModel) Forward(pixelValues *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
	if m == nil || pixelValues == nil || grid == nil {
		return nil, errNoVisionModel
	}
	if m.PatchProjection == nil || m.PatchMerger == nil {
		return nil, errors.New("vision model is missing required projections")
	}

	x := m.PatchProjection.Forward(pixelValues)
	var err error
	x, err = m.addPositionEmbedding(x, grid)
	if err != nil {
		return nil, err
	}

	cos, sin, err := m.rotaryEmbeddings(grid)
	if err != nil {
		return nil, err
	}

	for i, layer := range m.Layers {
		x, err = layer.Forward(x, cos, sin, m.cfg)
		if err != nil {
			return nil, fmt.Errorf("vision layer %d: %w", i, err)
		}
	}

	main, err := m.PatchMerger.Forward(x, m.cfg)
	if err != nil {
		return nil, fmt.Errorf("vision patch merger: %w", err)
	}
	return main, nil
}

// VisionEmbeddings carries the encoded image tokens and their grid geometry.
type VisionEmbeddings struct {
	Main *mlx.Array
	Grid *VisionGrid
}

// EncodeVisionImage decodes raw image bytes (JPEG/PNG via the blank imports),
// preprocesses them, and runs the vision tower.
func (m *Model) EncodeVisionImage(multimodalData []byte) (*VisionEmbeddings, error) {
	if m == nil || m.Vision == nil || m.ImageProcessor == nil {
		return nil, errNoVisionModel
	}

	img, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}

	pixelValues, grid, err := m.ImageProcessor.ProcessImage(img)
	if err != nil {
		return nil, err
	}

	main, err := m.Vision.Forward(pixelValues, grid)
	if err != nil {
		return nil, err
	}

	return &VisionEmbeddings{Main: main, Grid: grid}, nil
}

// resolveVisionPrefix probes the tensor map for known vision-tower layouts
// (legacy "model.visual"/"visual" vs imported "vision_tower", optionally under
// weightPrefix) and returns the first prefix that owns a recognizable tensor.
// Returns "" when no vision tensors are present.
func resolveVisionPrefix(tensors map[string]*mlx.Array, weightPrefix string) string {
	candidates := []string{
		"vision_tower",
		weightPrefix + "vision_tower",
		"model.visual",
		"visual",
		weightPrefix + "model.visual",
		weightPrefix + "visual",
	}

	hasTensor := func(prefix string) bool {
		for _, suffix := range []string{
			".patch_embed.proj.weight",
			".patch_embed.weight",
			".pos_embed.weight",
			".blocks.0.attn.qkv.weight",
			".merger.linear_fc1.weight",
			".merger.mlp.0.weight",
		} {
			if tensors[prefix+suffix] != nil {
				return true
			}
		}
		return false
	}

	for _, prefix := range candidates {
		if hasTensor(prefix) {
			return prefix
		}
	}

	return ""
}

// firstLinear returns the first path the factory can build a linear layer for,
// or nil if none exist (used to tolerate alternate checkpoint naming).
func firstLinear(linears mlxmodel.LinearFactory, paths ...string) nn.LinearLayer {
	for _, p := range paths {
		if l := linears.Make(p); l != nil {
			return l
		}
	}
	return nil
}

// loadLayerNorm tries each base name under two conventions — "<base>.weight" /
// "<base>.bias" or bare "<base>" / "<base>_bias" — returning nil if no weight
// is found. Bias may legitimately be nil.
func loadLayerNorm(tensors map[string]*mlx.Array, eps float32, bases ...string) *nn.LayerNorm {
	for _, base := range bases {
		if w := tensors[base+".weight"]; w != nil {
			return &nn.LayerNorm{Weight: w, Bias: tensors[base+".bias"], Eps: eps}
		}
		if w := tensors[base]; w != nil {
			return &nn.LayerNorm{Weight: w, Bias: tensors[base+"_bias"], Eps: eps}
		}
	}
	return nil
}

func loadVisionPatchMerger(
tensors map[string]*mlx.Array, + linears mlxmodel.LinearFactory, + eps float32, + bases ...string, +) *VisionPatchMerger { + for _, base := range bases { + norm := loadLayerNorm(tensors, eps, base+".norm", base+".ln_q") + fc1 := firstLinear(linears, base+".linear_fc1", base+".mlp.0") + fc2 := firstLinear(linears, base+".linear_fc2", base+".mlp.2") + if norm != nil && fc1 != nil && fc2 != nil { + return &VisionPatchMerger{Norm: norm, FC1: fc1, FC2: fc2} + } + } + return nil +} + +func flattenPatchEmbeddingWeight(w *mlx.Array) (*mlx.Array, error) { + if w == nil || !w.Valid() { + return nil, errors.New("missing patch embedding weight") + } + if w.NumDims() < 2 { + return nil, fmt.Errorf("patch embedding weight must be >=2D, got %dD", w.NumDims()) + } + if w.NumDims() == 2 { + return w, nil + } + + out := int32(w.Dim(0)) + in := int32(w.Size() / w.Dim(0)) + return mlx.Reshape(w, out, in), nil +} + +func loadVisionComponents( + tensors map[string]*mlx.Array, + linears mlxmodel.LinearFactory, + cfg *Config, + weightPrefix string, +) (*VisionModel, *VisionImageProcessor, error) { + if cfg == nil || cfg.Vision == nil || cfg.Vision.Depth <= 0 { + return nil, nil, nil + } + cfg.Vision.applyDefaults() + + visionPrefix := resolveVisionPrefix(tensors, weightPrefix) + if visionPrefix == "" { + return nil, nil, errors.New("vision enabled in config but vision tensors were not found") + } + + patchW, _ := tensorAny( + tensors, + visionPrefix+".patch_embed.proj.weight", + visionPrefix+".patch_embed.weight", + ) + if patchW == nil { + return nil, nil, fmt.Errorf("missing vision patch embedding weight under %s", visionPrefix) + } + patchW, err := flattenPatchEmbeddingWeight(patchW) + if err != nil { + return nil, nil, err + } + patchB, _ := tensorAny( + tensors, + visionPrefix+".patch_embed.proj.bias", + visionPrefix+".patch_embed.bias", + ) + + patchProj := nn.NewLinear(patchW, patchB) + if got := int32(patchW.Dim(1)); got != 
cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize { + return nil, nil, fmt.Errorf( + "vision patch embedding input dim mismatch: got %d expected %d", + got, + cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize, + ) + } + + posW, _ := tensorAny( + tensors, + visionPrefix+".pos_embed.weight", + visionPrefix+".position_embedding.weight", + ) + if posW == nil { + return nil, nil, fmt.Errorf("missing vision position embedding under %s", visionPrefix) + } + cfg.Vision.NumPositionEmbeddings = int32(posW.Dim(0)) + cfg.Vision.applyDefaults() + + vm := &VisionModel{ + PatchProjection: patchProj, + PositionEmbed: nn.NewEmbedding(posW), + Layers: make([]*VisionEncoderLayer, cfg.Vision.Depth), + cfg: cfg.Vision, + } + + for i := int32(0); i < cfg.Vision.Depth; i++ { + layerPrefix := fmt.Sprintf("%s.blocks.%d", visionPrefix, i) + layer := &VisionEncoderLayer{ + Norm1: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm1"), + Norm2: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm2"), + Attn: &VisionAttention{ + QKV: firstLinear( + linears, + layerPrefix+".attn.qkv", + layerPrefix+".attn_qkv", + ), + Query: firstLinear( + linears, + layerPrefix+".attn.q_proj", + layerPrefix+".attn_q", + ), + Key: firstLinear( + linears, + layerPrefix+".attn.k_proj", + layerPrefix+".attn_k", + ), + Value: firstLinear( + linears, + layerPrefix+".attn.v_proj", + layerPrefix+".attn_v", + ), + Output: firstLinear( + linears, + layerPrefix+".attn.proj", + layerPrefix+".attn_out", + layerPrefix+".attn.o_proj", + ), + }, + MLP: &VisionMLP{ + FC1: firstLinear( + linears, + layerPrefix+".mlp.fc1", + layerPrefix+".mlp.linear_fc1", + ), + FC2: firstLinear( + linears, + layerPrefix+".mlp.fc2", + layerPrefix+".mlp.linear_fc2", + ), + }, + } + + if layer.Norm1 == nil || layer.Norm2 == nil { + return nil, nil, fmt.Errorf("vision layer %d: missing norm1/norm2", i) + } + if 
layer.Attn.Output == nil || (layer.Attn.QKV == nil && (layer.Attn.Query == nil || layer.Attn.Key == nil || layer.Attn.Value == nil)) { + return nil, nil, fmt.Errorf("vision layer %d: missing attention projections", i) + } + if layer.MLP.FC1 == nil || layer.MLP.FC2 == nil { + return nil, nil, fmt.Errorf("vision layer %d: missing mlp projections", i) + } + + vm.Layers[i] = layer + } + + vm.PatchMerger = loadVisionPatchMerger( + tensors, + linears, + cfg.Vision.LayerNormEpsilon, + visionPrefix+".merger", + ) + if vm.PatchMerger == nil { + return nil, nil, fmt.Errorf("missing vision patch merger under %s", visionPrefix) + } + + return vm, newVisionImageProcessor(cfg.Vision), nil +}