Compare commits

...

5 Commits

Author SHA1 Message Date
Patrick Devine
578c32e42e still more linter stuff 2026-03-19 17:29:12 -07:00
Patrick Devine
a10d2625ca linters ftw 2026-03-19 17:20:59 -07:00
Patrick Devine
b960d769ad more linter fixes 2026-03-19 17:11:43 -07:00
Patrick Devine
455a6099d1 gofumpt the linter 2026-03-19 16:52:35 -07:00
Patrick Devine
7e6e8377eb mlx: qwen3.5 vision support 2026-03-19 16:35:08 -07:00
12 changed files with 1949 additions and 170 deletions

View File

@@ -134,14 +134,18 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
spinnerKey = "create" spinnerKey = "create"
capabilities = []string{"completion"} capabilities = []string{"completion"}
// Check if model supports thinking based on architecture configData, _ := os.ReadFile(filepath.Join(opts.ModelDir, "config.json"))
if supportsThinking(opts.ModelDir) { mcfg := parseModelConfig(configData)
if mcfg.supportsThinking() {
capabilities = append(capabilities, "thinking") capabilities = append(capabilities, "thinking")
} }
if mcfg.supportsVision() {
capabilities = append(capabilities, "vision")
}
// Set parser and renderer name based on architecture parserName = mcfg.parserName()
parserName = getParserName(opts.ModelDir) rendererName = mcfg.rendererName()
rendererName = getRendererName(opts.ModelDir)
} else { } else {
modelType = "image generation model" modelType = "image generation model"
spinnerKey = "imagegen" spinnerKey = "imagegen"
@@ -438,145 +442,76 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
return layers, nil return layers, nil
} }
// supportsThinking checks if the model supports thinking mode based on its architecture. // modelConfig holds the fields from config.json needed during model creation.
// This reads the config.json from the model directory and checks the architectures field. type visionConfig struct {
func supportsThinking(modelDir string) bool { Depth int32 `json:"depth"`
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return false
} }
var cfg struct { type modelConfig struct {
Architectures []string `json:"architectures"` Architectures []string `json:"architectures"`
ModelType string `json:"model_type"` ModelType string `json:"model_type"`
} VisionConfig *visionConfig `json:"vision_config"`
if err := json.Unmarshal(data, &cfg); err != nil { ImageTokenID *int32 `json:"image_token_id"`
return false VisionStartTokenID *int32 `json:"vision_start_token_id"`
VisionEndTokenID *int32 `json:"vision_end_token_id"`
} }
// Check architectures that support thinking func parseModelConfig(data []byte) modelConfig {
thinkingArchitectures := []string{ var cfg modelConfig
"glm4moe", // GLM-4 MoE models _ = json.Unmarshal(data, &cfg)
"deepseek", // DeepSeek models return cfg
"qwen3", // Qwen3 models
} }
// Check the architecture list // archOrTypeContains returns true if any architecture or the model_type
for _, arch := range cfg.Architectures { // contains one of the given substrings (case-insensitive).
func (c *modelConfig) archOrTypeContains(substrs ...string) bool {
for _, arch := range c.Architectures {
archLower := strings.ToLower(arch) archLower := strings.ToLower(arch)
for _, thinkArch := range thinkingArchitectures { for _, s := range substrs {
if strings.Contains(archLower, thinkArch) { if strings.Contains(archLower, s) {
return true return true
} }
} }
} }
if c.ModelType != "" {
// Also check model_type typeLower := strings.ToLower(c.ModelType)
if cfg.ModelType != "" { for _, s := range substrs {
typeLower := strings.ToLower(cfg.ModelType) if strings.Contains(typeLower, s) {
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(typeLower, thinkArch) {
return true return true
} }
} }
} }
return false return false
} }
// getParserName returns the parser name for a model based on its architecture. func (c *modelConfig) supportsThinking() bool {
// This reads the config.json from the model directory and determines the appropriate parser. return c.archOrTypeContains("glm4moe", "deepseek", "qwen3")
func getParserName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
} }
var cfg struct { func (c *modelConfig) supportsVision() bool {
Architectures []string `json:"architectures"` return c.VisionConfig != nil || c.ImageTokenID != nil || c.VisionStartTokenID != nil || c.VisionEndTokenID != nil
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
} }
// Check architectures for known parsers func (c *modelConfig) parserName() string {
for _, arch := range cfg.Architectures { switch {
archLower := strings.ToLower(arch) case c.archOrTypeContains("glm4", "glm-4"):
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7" return "glm-4.7"
} case c.archOrTypeContains("deepseek"):
if strings.Contains(archLower, "deepseek") {
return "deepseek3" return "deepseek3"
} case c.archOrTypeContains("qwen3"):
if strings.Contains(archLower, "qwen3") {
return "qwen3" return "qwen3"
} }
return ""
} }
// Also check model_type func (c *modelConfig) rendererName() string {
if cfg.ModelType != "" { switch {
typeLower := strings.ToLower(cfg.ModelType) case c.archOrTypeContains("glm4", "glm-4"):
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7" return "glm-4.7"
} case c.archOrTypeContains("deepseek"):
if strings.Contains(typeLower, "deepseek") {
return "deepseek3" return "deepseek3"
} case c.archOrTypeContains("qwen3"):
if strings.Contains(typeLower, "qwen3") {
return "qwen3"
}
}
return ""
}
// getRendererName returns the renderer name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate renderer.
func getRendererName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// Check architectures for known renderers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder" return "qwen3-coder"
} }
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
}
}
return "" return ""
} }

View File

@@ -339,3 +339,34 @@ func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7)) t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
} }
} }
func TestSupportsVision(t *testing.T) {
	cases := []struct {
		name   string
		config string
		want   bool
	}{
		{
			name:   "vision_config present",
			config: `{"vision_config": {"depth": 2}, "image_token_id": 151655}`,
			want:   true,
		},
		{
			name:   "token ids alone imply vision",
			config: `{"vision_start_token_id": 10, "vision_end_token_id": 11}`,
			want:   true,
		},
		{
			name:   "plain text model",
			config: `{"architectures": ["Qwen3_5ForCausalLM"]}`,
			want:   false,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := parseModelConfig([]byte(tc.config)).supportsVision(); got != tc.want {
				t.Fatalf("supportsVision() = %v, want %v", got, tc.want)
			}
		})
	}
}

View File

@@ -366,6 +366,23 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
c.enforceEvictionPolicy() c.enforceEvictionPolicy()
} }
// clear releases live caches and drops the trie so future requests cannot
// reuse prompt state keyed only by token IDs.
func (c *kvCache) clear() {
	c.freeAll()
	walkNodes(c.root, func(node *trieNode) bool {
		for _, snap := range node.snapshots {
			if snap == nil {
				continue
			}
			snap.Close()
		}
		node.snapshots = nil
		return true
	})
	c.root = nil
	c.activePath = nil
}
// freeAll releases all cache layers. // freeAll releases all cache layers.
func (c *kvCache) freeAll() { func (c *kvCache) freeAll() {
for _, kv := range c.caches { for _, kv := range c.caches {

View File

@@ -106,6 +106,7 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization. // completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
type completionRequest struct { type completionRequest struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Images []llm.ImageData `json:"images,omitempty"`
Options *completionOpts `json:"options,omitempty"` Options *completionOpts `json:"options,omitempty"`
} }
@@ -155,6 +156,7 @@ func (c *Client) Close() error {
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
creq := completionRequest{ creq := completionRequest{
Prompt: req.Prompt, Prompt: req.Prompt,
Images: req.Images,
} }
if req.Options != nil { if req.Options != nil {
creq.Options = &completionOpts{ creq.Options = &completionOpts{

View File

@@ -304,6 +304,18 @@ func Exp(a *Array) *Array {
return out return out
} }
// Sin returns a new array containing the element-wise sine of a,
// computed on the default MLX stream.
func Sin(a *Array) *Array {
	out := New("SIN")
	C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}
// Cos returns a new array containing the element-wise cosine of a,
// computed on the default MLX stream.
func Cos(a *Array) *Array {
	out := New("COS")
	C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}
func Log(a *Array) *Array { func Log(a *Array) *Array {
out := New("LOG") out := New("LOG")
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx) C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)

View File

@@ -0,0 +1,32 @@
package base
import (
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// ImageInput is a single image attached to a prompt.
type ImageInput struct {
	ID   int    // identifier referenced by [img-N] tags in the prompt text
	Data []byte // raw encoded image bytes
}
// PromptTokenization contains tokenized prompt IDs plus optional request-scoped
// model metadata needed during forward.
type PromptTokenization struct {
	Tokens []int32 // prompt token IDs, including any expanded image placeholder tokens
	State  any     // opaque per-request state passed to ForwardWithState; may be nil
}
// MultimodalPromptTokenizerWithState is an optional model interface used by
// mlxrunner to expand tagged multimodal prompts into token IDs, returning
// request-scoped state to be attached to the forward pass.
type MultimodalPromptTokenizerWithState interface {
	// TokenizePromptWithImagesState tokenizes prompt, resolving image tags
	// against the supplied images, and returns the tokens plus per-request state.
	TokenizePromptWithImagesState(prompt string, images []ImageInput) (*PromptTokenization, error)
}
// ForwardWithStateModel is an optional model interface for request-scoped
// forward metadata that should not be stored in shared caches.
type ForwardWithStateModel interface {
	// ForwardWithState runs a forward pass with the given per-request state
	// (as produced by tokenization); state may be nil.
	ForwardWithState(inputs *mlx.Array, cache []cache.Cache, state any) *mlx.Array
}

View File

@@ -12,12 +12,42 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
) )
func prefillChunkSize() int { func prefillChunkSize() int {
return 2 << 10 return 2 << 10
} }
// tokenizeRequest converts a request's prompt (and any attached images) into
// input token IDs plus optional request-scoped state for the forward pass.
func (r *Runner) tokenizeRequest(request Request) ([]int32, any, error) {
	hasImages := len(request.Images) > 0
	if hasImages {
		// The shared trie cache keys only on token IDs today, so multimodal
		// prompts must not reuse snapshots across distinct image inputs.
		r.cache.clear()
	}
	mm, ok := r.Model.(base.MultimodalPromptTokenizerWithState)
	if !ok || !hasImages {
		// Plain text path: no images, or the model has no multimodal tokenizer.
		return r.Tokenizer.Encode(request.Prompt, true), nil, nil
	}
	images := make([]base.ImageInput, 0, len(request.Images))
	for _, img := range request.Images {
		images = append(images, base.ImageInput{
			ID:   img.ID,
			Data: img.Data,
		})
	}
	out, err := mm.TokenizePromptWithImagesState(request.Prompt, images)
	if err != nil {
		return nil, nil, err
	}
	if out == nil {
		return nil, nil, errors.New("empty multimodal tokenization result")
	}
	return out.Tokens, out.State, nil
}
func (r *Runner) TextGenerationPipeline(request Request) error { func (r *Runner) TextGenerationPipeline(request Request) error {
if r.Model == nil { if r.Model == nil {
return errors.New("model not loaded") return errors.New("model not loaded")
@@ -55,7 +85,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory())) slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
}() }()
inputs := r.Tokenizer.Encode(request.Prompt, true) inputs, promptState, err := r.tokenizeRequest(request)
if err != nil {
return err
}
if len(inputs) == 0 { if len(inputs) == 0 {
return errors.New("empty prompt") return errors.New("empty prompt")
} }
@@ -83,6 +116,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
tokens := session.remaining tokens := session.remaining
prefillChunk := prefillChunkSize() prefillChunk := prefillChunkSize()
modelForward := func(tokens *mlx.Array) *mlx.Array {
if withState, ok := r.Model.(base.ForwardWithStateModel); ok {
return withState.ForwardWithState(tokens, caches, promptState)
}
return r.Model.Forward(tokens, caches)
}
materializeCaches := func() { materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches)) state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches { for _, c := range caches {
@@ -114,7 +154,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} }
} }
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) modelForward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0))
mlx.Sweep() mlx.Sweep()
materializeCaches() materializeCaches()
processed += n processed += n
@@ -132,7 +172,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} }
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) { step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
fwd := r.Model.Forward(token.ExpandDims(0), caches) fwd := modelForward(token.ExpandDims(0))
logits := r.Model.Unembed(fwd) logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

View File

@@ -11,6 +11,7 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -30,6 +31,7 @@ type Request struct {
type TextCompletionsRequest struct { type TextCompletionsRequest struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Images []llm.ImageData `json:"images,omitempty"`
Options struct { Options struct {
Temperature float32 `json:"temperature"` Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"` TopP float32 `json:"top_p"`

View File

@@ -0,0 +1,354 @@
package qwen3_5
import (
"fmt"
"math"
"regexp"
"strconv"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// imageTagRE matches [img-N] placeholders embedded in the prompt text;
// the captured group is the decimal image ID.
var imageTagRE = regexp.MustCompile(`\[img-(\d+)\]`)
// promptVisionSpan records where one image's placeholder tokens sit in the
// prompt and the embeddings that replace them during forward.
type promptVisionSpan struct {
	Start int32       // index of the first image placeholder token in the prompt
	End   int32       // one past the last placeholder token
	Main  *mlx.Array  // vision embeddings spliced over [Start, End)
	Grid  *VisionGrid // spatial grid used for MRoPE row/column offsets; may be nil
}
// promptVisionState is the request-scoped vision state attached to a prompt.
type promptVisionState struct {
	Spans         []promptVisionSpan // one span per image occurrence in the prompt
	PositionCache []int32            // MRoPE base position assigned to each prompt token
}
// promptStartPosFromCaches returns the smallest offset across the non-nil
// caches — the prompt position at which the next forward chunk begins.
// With no usable cache it returns 0.
func promptStartPosFromCaches(caches []cache.Cache) int32 {
	best := -1
	for _, kv := range caches {
		if kv == nil {
			continue
		}
		if o := kv.Offset(); best < 0 || o < best {
			best = o
		}
	}
	if best < 0 {
		return 0
	}
	return int32(best)
}
// promptVisionStateFromState down-casts the opaque request state, returning
// nil when the state is absent or of a different type.
func promptVisionStateFromState(state any) *promptVisionState {
	if typed, ok := state.(*promptVisionState); ok {
		return typed
	}
	return nil
}
// overlapRange intersects the half-open chunk [chunkStart, chunkStart+chunkLen)
// with the half-open span [spanStart, spanEnd). On overlap it returns the
// intersection expressed in chunk-local and span-local coordinates plus true;
// otherwise it returns all zeros and false.
func overlapRange(chunkStart, chunkLen, spanStart, spanEnd int32) (int32, int32, int32, int32, bool) {
	lo := spanStart
	if chunkStart > lo {
		lo = chunkStart
	}
	hi := spanEnd
	if end := chunkStart + chunkLen; end < hi {
		hi = end
	}
	if lo >= hi {
		return 0, 0, 0, 0, false
	}
	return lo - chunkStart, hi - chunkStart, lo - spanStart, hi - spanStart, true
}
// applyPromptVisionEmbeddings overwrites the image placeholder positions of
// the rank-3 hidden-state chunk h (assumed [batch, seq, dim] — the seq axis
// is dims[1]) with the precomputed vision embeddings recorded in state.
// startPos is the prompt-global position of the chunk's first token.
// h is returned unchanged when there is nothing to apply.
func (m *Model) applyPromptVisionEmbeddings(h *mlx.Array, startPos int32, state *promptVisionState) *mlx.Array {
	if m == nil || h == nil || state == nil || len(state.Spans) == 0 {
		return h
	}
	dims := h.Dims()
	if len(dims) != 3 {
		return h
	}
	L := int32(dims[1])
	for _, span := range state.Spans {
		chunkLo, chunkHi, spanLo, spanHi, ok := overlapRange(startPos, L, span.Start, span.End)
		if !ok || span.Main == nil || !span.Main.Valid() {
			continue
		}
		// Take the overlapping slice of the image embeddings along axis 1 ...
		repl := span.Main.Slice(
			mlx.Slice(),
			mlx.Slice(int(spanLo), int(spanHi)),
			mlx.Slice(),
		)
		repl = repl.AsType(h.DType())
		// ... and splice it over the matching placeholder tokens in the chunk.
		h = h.SliceUpdate(
			repl,
			mlx.Slice(),
			mlx.Slice(int(chunkLo), int(chunkHi)),
			mlx.Slice(),
		)
	}
	return h
}
// findImageByID returns the image whose ID matches id, and whether one exists.
func findImageByID(images []base.ImageInput, id int) (base.ImageInput, bool) {
	for _, img := range images {
		if img.ID == id {
			return img, true
		}
	}
	return base.ImageInput{}, false
}
// mapPromptPosition translates a raw token index into its MRoPE base position
// via the state's per-token position cache. Indices past the cache continue
// linearly from the last cached position; with no state or an empty cache,
// the index maps to itself.
func mapPromptPosition(state *promptVisionState, id int32) int32 {
	if state == nil {
		return id
	}
	cached := state.PositionCache
	n := int32(len(cached))
	switch {
	case id < n:
		return cached[id]
	case n > 0:
		return cached[n-1] + 1 + (id - n)
	default:
		return id
	}
}
// promptVisionGridSpan returns how many MRoPE positions an image occupies:
// the larger of its merged grid width and height, each clamped to at least 1.
// When no grid is available, fallback (clamped to at least 1) is used.
func promptVisionGridSpan(grid *VisionGrid, merge int32, fallback int32) int32 {
	if fallback <= 0 {
		fallback = 1
	}
	if grid == nil {
		return fallback
	}
	if merge <= 0 {
		merge = 1
	}
	w := grid.Width / merge
	if w < 1 {
		w = 1
	}
	h := grid.Height / merge
	if h < 1 {
		h = 1
	}
	if w > h {
		return w
	}
	return h
}
// normalizeMRoPESections copies up to the first four section sizes into a
// fixed array; non-positive entries are dropped (left as zero).
func normalizeMRoPESections(sections []int32) [4]int32 {
	var out [4]int32
	n := len(sections)
	if n > 4 {
		n = 4
	}
	for i := 0; i < n; i++ {
		if s := sections[i]; s > 0 {
			out[i] = s
		}
	}
	return out
}
// mropePairComponent maps a rotary frequency pair index to the MRoPE
// component (0, 1, 2, or 3 for pairs beyond the configured sections)
// whose position drives that pair's angle.
func mropePairComponent(pair int32, sections [4]int32, interleaved bool) int {
	if interleaved {
		// Interleaved layout cycles components every three pairs, each
		// limited to its configured section count.
		switch pair % 3 {
		case 0:
			if pair < 3*sections[0] {
				return 0
			}
		case 1:
			if pair < 1+3*sections[1] {
				return 1
			}
		case 2:
			if pair < 2+3*sections[2] {
				return 2
			}
		}
		return 3
	}
	// Contiguous layout: sections occupy consecutive pair ranges.
	switch {
	case pair < sections[0]:
		return 0
	case pair < sections[0]+sections[1]:
		return 1
	case pair < sections[0]+sections[1]+sections[2]:
		return 2
	default:
		return 3
	}
}
// buildPromptMRoPEPositions computes the four MRoPE position streams for one
// prompt chunk of chunkLen tokens starting at prompt position startPos.
// Text tokens use the same mapped position in the first three streams;
// image placeholder tokens additionally get row/column offsets derived from
// their vision grid.
// NOTE(review): state is dereferenced (state.Spans) without a nil check —
// the only visible caller, buildPromptMRoPECosSin, guards state != nil first.
func (m *Model) buildPromptMRoPEPositions(state *promptVisionState, startPos, chunkLen int32) [4][]int32 {
	var positions [4][]int32
	for i := range positions {
		positions[i] = make([]int32, chunkLen)
	}
	// positions[3] stays zero — it covers RoPE dims beyond the 3 MRoPE sections.
	for i := range chunkLen {
		p := mapPromptPosition(state, startPos+i)
		positions[0][i] = p
		positions[1][i] = p
		positions[2][i] = p
	}
	merge := int32(1)
	if m != nil && m.Config != nil && m.Config.Vision != nil {
		merge = m.Config.Vision.SpatialMergeSize
	}
	for _, span := range state.Spans {
		if span.Grid == nil {
			continue
		}
		chunkLo, chunkHi, spanLo, _, ok := overlapRange(startPos, chunkLen, span.Start, span.End)
		if !ok {
			continue
		}
		// Merged grid width; rel enumerates the span's placeholder tokens
		// row-major, so rel/w is the row offset and rel%w the column offset.
		w := max(int32(1), span.Grid.Width/merge)
		for i := chunkLo; i < chunkHi; i++ {
			rel := spanLo + (i - chunkLo)
			positions[1][i] += rel / w
			positions[2][i] += rel % w
		}
	}
	return positions
}
// buildPromptMRoPECosSin precomputes the cos/sin rotary tables for one prompt
// chunk using multimodal (MRoPE) positions. It returns (nil, nil) when MRoPE
// does not apply (no model/config, no vision state, empty chunk, or no
// configured mrope sections). Each returned array has shape
// [1, 1, chunkLen, ropeDim] and is cast to dtype when dtype is non-zero.
func (m *Model) buildPromptMRoPECosSin(state *promptVisionState, startPos, chunkLen int32, dtype mlx.DType) (*mlx.Array, *mlx.Array) {
	if m == nil || m.Config == nil || state == nil || chunkLen <= 0 || len(m.Config.MRoPESections) == 0 {
		return nil, nil
	}
	// Force an even rotary width; each angle drives one (x1, x2) pair.
	ropeDim := m.Config.RopeDim
	if ropeDim%2 != 0 {
		ropeDim--
	}
	if ropeDim <= 0 {
		return nil, nil
	}
	half := ropeDim / 2
	positions := m.buildPromptMRoPEPositions(state, startPos, chunkLen)
	sections := normalizeMRoPESections(m.Config.MRoPESections)
	theta := m.Config.RopeTheta
	if theta <= 0 {
		theta = 100000.0 // same default rope_theta that parseConfig applies
	}
	// Standard RoPE inverse frequencies: theta^(-2j/ropeDim) per pair j.
	freqs := make([]float64, half)
	for j := range half {
		freqs[j] = math.Pow(float64(theta), -2.0*float64(j)/float64(ropeDim))
	}
	// Angle layout per token is [a_0..a_{half-1}, a_0..a_{half-1}], duplicating
	// each pair's angle into both halves of the rotary dimension.
	angles := make([]float32, chunkLen*ropeDim)
	for i := range chunkLen {
		base := i * ropeDim
		for j := range half {
			// Each pair's angle is driven by the position of the MRoPE
			// component that pair belongs to.
			component := mropePairComponent(j, sections, m.Config.MRoPEInterleaved)
			angle := float32(float64(positions[component][i]) * freqs[j])
			angles[base+j] = angle
			angles[base+half+j] = angle
		}
	}
	arr := mlx.FromValues(angles, 1, 1, int(chunkLen), int(ropeDim))
	cos := mlx.Cos(arr)
	sin := mlx.Sin(arr)
	if dtype != 0 {
		cos = cos.AsType(dtype)
		sin = sin.AsType(dtype)
	}
	return cos, sin
}
// tokenizePromptWithResolvedImages tokenizes a prompt containing [img-N]
// tags, replacing each tag with the vision-start token, one placeholder
// token per image embedding, and the vision-end token. It records the
// placeholder spans and per-token MRoPE position cache in the returned
// promptVisionState. resolve turns raw image bytes into vision embeddings;
// each distinct image ID is resolved at most once. Models without vision
// components fall back to plain text tokenization (tags left as text).
func (m *Model) tokenizePromptWithResolvedImages(
	prompt string,
	images []base.ImageInput,
	resolve func([]byte) (*VisionEmbeddings, error),
) ([]int32, *promptVisionState, error) {
	if m == nil || m.tok == nil {
		return nil, nil, fmt.Errorf("qwen3_5: tokenizer not initialized")
	}
	if m.Vision == nil || m.ImageProcessor == nil || resolve == nil {
		// No vision pipeline available: tokenize the prompt as-is.
		return m.tok.Encode(prompt, true), nil, nil
	}
	// parts[i] is the text preceding matches[i]; the last part follows the final tag.
	parts := imageTagRE.Split(prompt, -1)
	matches := imageTagRE.FindAllStringSubmatch(prompt, -1)
	resolved := make(map[int]*VisionEmbeddings, len(images))
	var out []int32
	state := &promptVisionState{}
	// p is the MRoPE position assigned to the next token; it can advance
	// faster than len(out) because an image's grid span may differ from its
	// placeholder token count.
	var p int32
	appendToken := func(tok, pos int32) {
		out = append(out, tok)
		state.PositionCache = append(state.PositionCache, pos)
	}
	for i, part := range parts {
		// Only the first text segment gets the BOS/prefix treatment.
		for _, tok := range m.tok.Encode(part, i == 0) {
			appendToken(tok, p)
			p++
		}
		if i >= len(matches) {
			continue
		}
		imageID, err := strconv.Atoi(matches[i][1])
		if err != nil {
			return nil, nil, fmt.Errorf("qwen3_5: invalid image tag %q: %w", matches[i][0], err)
		}
		img, ok := findImageByID(images, imageID)
		if !ok {
			return nil, nil, fmt.Errorf("invalid image index: %d", imageID)
		}
		// Resolve embeddings once per image ID, even if referenced repeatedly.
		embeds := resolved[imageID]
		if embeds == nil {
			embeds, err = resolve(img.Data)
			if err != nil {
				return nil, nil, err
			}
			resolved[imageID] = embeds
		}
		if embeds == nil || embeds.Main == nil || !embeds.Main.Valid() || embeds.Main.NumDims() < 2 {
			return nil, nil, fmt.Errorf("qwen3_5: invalid vision embeddings")
		}
		// One placeholder token per embedding along axis 1.
		tokensPerImage := int32(embeds.Main.Dim(1))
		if tokensPerImage <= 0 {
			return nil, nil, fmt.Errorf("qwen3_5: invalid image token count: %d", tokensPerImage)
		}
		appendToken(m.VisionStartToken, p)
		p++
		// All placeholder tokens of one image share the same base position;
		// per-token row/column offsets are added later from the vision grid.
		basePos := p
		spanStart := int32(len(out))
		for range tokensPerImage {
			appendToken(m.ImageTokenID, basePos)
		}
		spanEnd := int32(len(out))
		merge := int32(1)
		if m.Config != nil && m.Config.Vision != nil {
			merge = m.Config.Vision.SpatialMergeSize
		}
		gridSpan := promptVisionGridSpan(embeds.Grid, merge, tokensPerImage)
		// Advance the position counter by the image's spatial extent so the
		// vision-end token is positioned past the whole grid.
		p += gridSpan
		appendToken(m.VisionEndToken, p)
		p++
		state.Spans = append(state.Spans, promptVisionSpan{
			Start: spanStart,
			End:   spanEnd,
			Main:  embeds.Main,
			Grid:  embeds.Grid,
		})
	}
	return out, state, nil
}
// TokenizePromptWithImagesState implements base.MultimodalPromptTokenizerWithState:
// it expands image tags in prompt into placeholder tokens via the model's
// vision encoder and returns the tokens together with request-scoped state.
func (m *Model) TokenizePromptWithImagesState(prompt string, images []base.ImageInput) (*base.PromptTokenization, error) {
	tokens, state, err := m.tokenizePromptWithResolvedImages(prompt, images, m.EncodeVisionImage)
	if err != nil {
		return nil, err
	}
	out := &base.PromptTokenization{
		Tokens: tokens,
		State:  state,
	}
	return out, nil
}

View File

@@ -2,6 +2,7 @@
package qwen3_5 package qwen3_5
import ( import (
"cmp"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math" "math"
@@ -22,16 +23,26 @@ func init() {
base.Register("Qwen3NextForConditionalGeneration", NewModel) base.Register("Qwen3NextForConditionalGeneration", NewModel)
} }
var (
_ base.MultimodalPromptTokenizerWithState = (*Model)(nil)
_ base.ForwardWithStateModel = (*Model)(nil)
)
// RopeParameters carries optional rope metadata embedded under rope_parameters. // RopeParameters carries optional rope metadata embedded under rope_parameters.
type RopeParameters struct { type RopeParameters struct {
Type string `json:"type"` Type string `json:"type"`
RopeType string `json:"rope_type"` RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"` RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"` PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPEInterleaved bool `json:"mrope_interleaved"`
MRoPESection []int32 `json:"mrope_section"`
DimensionSections []int32 `json:"dimension_sections"`
} }
// Config holds Qwen 3.5 text config (top-level or nested text_config). // TextConfig holds the Qwen 3.5 text-model architecture fields.
type Config struct { // In VLM configs these live under the "text_config" key; in text-only
// configs they appear at the top level.
type TextConfig struct {
ModelType string `json:"model_type"` ModelType string `json:"model_type"`
HiddenSize int32 `json:"hidden_size"` HiddenSize int32 `json:"hidden_size"`
IntermediateSize int32 `json:"intermediate_size"` IntermediateSize int32 `json:"intermediate_size"`
@@ -67,6 +78,19 @@ type Config struct {
PartialRotaryFactor float32 `json:"partial_rotary_factor"` PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling map[string]any `json:"rope_scaling"` RopeScaling map[string]any `json:"rope_scaling"`
RopeParameters *RopeParameters `json:"rope_parameters"` RopeParameters *RopeParameters `json:"rope_parameters"`
MRoPESections []int32 `json:"mrope_sections"`
MRoPEInterleaved bool `json:"mrope_interleaved"`
}
// Config is the full model config. It embeds TextConfig for the text-model
// fields and adds top-level-only fields (vision, token IDs, quantization).
type Config struct {
TextConfig
Vision *VisionConfig `json:"vision_config"`
ImageTokenID int32 `json:"image_token_id"`
VisionStartToken int32 `json:"vision_start_token_id"`
VisionEndToken int32 `json:"vision_end_token_id"`
// Quantization metadata. // Quantization metadata.
QuantGroupSize int `json:"-"` QuantGroupSize int `json:"-"`
@@ -90,6 +114,9 @@ type Model struct {
*Config *Config
weightPrefix string weightPrefix string
Vision *VisionModel
ImageProcessor *VisionImageProcessor
} }
// Layer is a transformer decoder layer. // Layer is a transformer decoder layer.
@@ -190,17 +217,24 @@ func parseConfig(configData []byte) (Config, error) {
var cfg Config var cfg Config
activeRaw := rawTop activeRaw := rawTop
// First pass: unmarshal the full config to pick up top-level fields
// (vision_config, image_token_id, etc.) and text fields for text-only models.
if err := json.Unmarshal(configData, &cfg); err != nil {
return Config{}, fmt.Errorf("parse config: %w", err)
}
// Second pass: if text_config exists, unmarshal it into TextConfig so
// text-model fields from text_config take priority over any top-level
// duplicates. Top-level-only fields (Vision, token IDs) are unaffected
// because they live on Config, not TextConfig.
if textRaw, ok := rawTop["text_config"]; ok { if textRaw, ok := rawTop["text_config"]; ok {
if err := json.Unmarshal(textRaw, &cfg); err != nil { if err := json.Unmarshal(textRaw, &cfg.TextConfig); err != nil {
return Config{}, fmt.Errorf("parse text_config: %w", err) return Config{}, fmt.Errorf("parse text_config: %w", err)
} }
if err := json.Unmarshal(textRaw, &activeRaw); err != nil { if err := json.Unmarshal(textRaw, &activeRaw); err != nil {
return Config{}, fmt.Errorf("parse text_config envelope: %w", err) return Config{}, fmt.Errorf("parse text_config envelope: %w", err)
} }
} else {
if err := json.Unmarshal(configData, &cfg); err != nil {
return Config{}, fmt.Errorf("parse config: %w", err)
}
} }
if cfg.HiddenSize <= 0 { if cfg.HiddenSize <= 0 {
@@ -225,12 +259,8 @@ func parseConfig(configData []byte) (Config, error) {
return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim) return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
} }
if cfg.RMSNormEps == 0 { cfg.RMSNormEps = cmp.Or(cfg.RMSNormEps, 1e-6)
cfg.RMSNormEps = 1e-6 cfg.LinearConvKernelDim = cmp.Or(cfg.LinearConvKernelDim, 4)
}
if cfg.LinearConvKernelDim <= 0 {
cfg.LinearConvKernelDim = 4
}
if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 { if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 {
return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)", return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)",
cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim) cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim)
@@ -246,14 +276,21 @@ func parseConfig(configData []byte) (Config, error) {
if cfg.RopeParameters.PartialRotaryFactor > 0 { if cfg.RopeParameters.PartialRotaryFactor > 0 {
cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor
} }
if len(cfg.MRoPESections) == 0 {
switch {
case len(cfg.RopeParameters.MRoPESection) > 0:
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.MRoPESection...)
case len(cfg.RopeParameters.DimensionSections) > 0:
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.DimensionSections...)
} }
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 100000.0
} }
if cfg.PartialRotaryFactor == 0 { cfg.MRoPEInterleaved = cmp.Or(cfg.MRoPEInterleaved, cfg.RopeParameters.MRoPEInterleaved)
cfg.PartialRotaryFactor = 0.25
} }
if cfg.PartialRotaryFactor < 0 { if len(cfg.MRoPESections) > 4 {
cfg.MRoPESections = cfg.MRoPESections[:4]
}
cfg.RopeTheta = cmp.Or(cfg.RopeTheta, 100000.0)
if cfg.PartialRotaryFactor <= 0 {
cfg.PartialRotaryFactor = 0.25 cfg.PartialRotaryFactor = 0.25
} }
ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor) ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor)
@@ -281,24 +318,23 @@ func parseConfig(configData []byte) (Config, error) {
} }
if cfg.NumExperts > 0 { if cfg.NumExperts > 0 {
if cfg.NumExpertsPerTok <= 0 { cfg.NumExpertsPerTok = cmp.Or(cfg.NumExpertsPerTok, int32(1))
cfg.NumExpertsPerTok = 1 cfg.MoeIntermediateSize = cmp.Or(cfg.MoeIntermediateSize, cfg.IntermediateSize)
} cfg.SharedExpertIntermediateSize = cmp.Or(cfg.SharedExpertIntermediateSize, cfg.IntermediateSize)
if cfg.MoeIntermediateSize <= 0 {
cfg.MoeIntermediateSize = cfg.IntermediateSize
}
if cfg.SharedExpertIntermediateSize <= 0 {
cfg.SharedExpertIntermediateSize = cfg.IntermediateSize
}
if _, ok := activeRaw["norm_topk_prob"]; !ok { if _, ok := activeRaw["norm_topk_prob"]; !ok {
cfg.NormTopKProb = true cfg.NormTopKProb = true
} }
if cfg.DecoderSparseStep <= 0 { cfg.DecoderSparseStep = cmp.Or(cfg.DecoderSparseStep, int32(1))
cfg.DecoderSparseStep = 1
}
} }
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
if cfg.Vision != nil {
cfg.Vision.applyDefaults()
}
cfg.ImageTokenID = cmp.Or(cfg.ImageTokenID, int32(151655))
cfg.VisionStartToken = cmp.Or(cfg.VisionStartToken, int32(151652))
cfg.VisionEndToken = cmp.Or(cfg.VisionEndToken, int32(151653))
return cfg, nil return cfg, nil
} }
@@ -364,6 +400,11 @@ func NewModel(root *model.Root) (base.Model, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if cfg.Vision != nil {
if preprocessorData, err := root.Manifest.ReadConfig("preprocessor_config.json"); err == nil {
cfg.Vision.applyPreprocessorConfig(preprocessorData)
}
}
if qt := root.QuantType(); qt != "" { if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt) cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
@@ -1060,6 +1101,15 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Layers[i] = layer m.Layers[i] = layer
} }
if cfg.Vision != nil && cfg.Vision.Depth > 0 {
vision, processor, err := loadVisionComponents(tensors, linears, cfg, m.weightPrefix)
if err != nil {
return err
}
m.Vision = vision
m.ImageProcessor = processor
}
return nil return nil
} }
@@ -1117,7 +1167,51 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k,
return q, k, v, z, b, a return q, k, v, z, b, a
} }
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { func rotateHalf(x *mlx.Array) *mlx.Array {
shape := x.Dims()
last := int32(shape[len(shape)-1])
half := last / 2
if half <= 0 {
return x
}
x1 := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), half})
x2 := mlx.SliceStartStop(x, []int32{0, 0, 0, half}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
return mlx.Concatenate([]*mlx.Array{mlx.Neg(x2), x1}, -1)
}
func applyTextRoPE(x, cos, sin *mlx.Array, ropeDim int32) *mlx.Array {
if x == nil || cos == nil || sin == nil || ropeDim <= 0 {
return x
}
shape := x.Dims()
if len(shape) != 4 {
return x
}
last := int32(shape[len(shape)-1])
if ropeDim > last {
ropeDim = last
}
if ropeDim%2 != 0 {
ropeDim--
}
if ropeDim <= 0 {
return x
}
rot := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), ropeDim})
rot = mlx.Add(mlx.Mul(rot, cos), mlx.Mul(rotateHalf(rot), sin))
if ropeDim == last {
return rot
}
tail := mlx.SliceStartStop(x, []int32{0, 0, 0, ropeDim}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
return mlx.Concatenate([]*mlx.Array{rot, tail}, -1)
}
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
qg := a.QProj.Forward(x) qg := a.QProj.Forward(x)
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2) qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim}) q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
@@ -1140,8 +1234,13 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
if c != nil { if c != nil {
offset = c.Offset() offset = c.Offset()
} }
if ropeCos != nil && ropeSin != nil {
q = applyTextRoPE(q, ropeCos, ropeSin, cfg.RopeDim)
k = applyTextRoPE(k, ropeCos, ropeSin, cfg.RopeDim)
} else {
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
}
if c != nil { if c != nil {
k, v = c.Update(k, v) k, v = c.Update(k, v)
@@ -1323,13 +1422,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
return mlx.Reshape(y, B, L, cfg.HiddenSize) return mlx.Reshape(y, B, L, cfg.HiddenSize)
} }
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
var r *mlx.Array var r *mlx.Array
normed := l.InputNorm.Forward(x, cfg.RMSNormEps) normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
if l.IsLinear { if l.IsLinear {
r = l.Linear.Forward(normed, c, B, L, cfg) r = l.Linear.Forward(normed, c, B, L, cfg)
} else { } else {
r = l.FullAttn.Forward(normed, c, B, L, cfg) r = l.FullAttn.Forward(normed, c, B, L, cfg, ropeCos, ropeSin)
} }
h := mlx.Add(x, r) h := mlx.Add(x, r)
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg) r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
@@ -1337,16 +1436,27 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
} }
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
return m.ForwardWithState(tokens, caches, nil)
}
func (m *Model) ForwardWithState(tokens *mlx.Array, caches []cache.Cache, state any) *mlx.Array {
dims := tokens.Dims() dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1]) B, L := int32(dims[0]), int32(dims[1])
startPos := promptStartPosFromCaches(caches)
promptState := promptVisionStateFromState(state)
h := m.EmbedTokens.Forward(tokens) h := m.EmbedTokens.Forward(tokens)
h = m.applyPromptVisionEmbeddings(h, startPos, promptState)
var ropeCos, ropeSin *mlx.Array
if len(m.MRoPESections) > 0 {
ropeCos, ropeSin = m.buildPromptMRoPECosSin(promptState, startPos, L, h.DType())
}
for i, layer := range m.Layers { for i, layer := range m.Layers {
var c cache.Cache var c cache.Cache
if caches != nil && i < len(caches) { if caches != nil && i < len(caches) {
c = caches[i] c = caches[i]
} }
h = layer.Forward(h, c, B, L, m.Config) h = layer.Forward(h, c, B, L, m.Config, ropeCos, ropeSin)
} }
out := m.Norm.Forward(h, m.RMSNormEps) out := m.Norm.Forward(h, m.RMSNormEps)
return out return out

View File

@@ -1,10 +1,14 @@
package qwen3_5 package qwen3_5
import ( import (
"fmt"
"slices"
"testing" "testing"
"github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/tokenizer"
) )
func skipIfNoMLX(t *testing.T) { func skipIfNoMLX(t *testing.T) {
@@ -60,13 +64,13 @@ func TestParseConfigNestedDefaults(t *testing.T) {
} }
func TestLayerSelectionHelpers(t *testing.T) { func TestLayerSelectionHelpers(t *testing.T) {
cfg := &Config{ cfg := &Config{TextConfig: TextConfig{
NumHiddenLayers: 6, NumHiddenLayers: 6,
FullAttentionInterval: 3, FullAttentionInterval: 3,
NumExperts: 8, NumExperts: 8,
DecoderSparseStep: 2, DecoderSparseStep: 2,
MLPOnlyLayers: []int32{1}, MLPOnlyLayers: []int32{1},
} }}
if !layerIsLinear(cfg, 0) { if !layerIsLinear(cfg, 0) {
t.Fatalf("layer 0 should be linear") t.Fatalf("layer 0 should be linear")
@@ -133,13 +137,13 @@ func TestResolveTensorPathLayout(t *testing.T) {
func TestNewCachesLayout(t *testing.T) { func TestNewCachesLayout(t *testing.T) {
m := &Model{ m := &Model{
Config: &Config{ Config: &Config{TextConfig: TextConfig{
LinearConvKernelDim: 4, LinearConvKernelDim: 4,
LinearNumKeyHeads: 2, LinearNumKeyHeads: 2,
LinearKeyHeadDim: 8, LinearKeyHeadDim: 8,
LinearNumValueHeads: 4, LinearNumValueHeads: 4,
LinearValueHeadDim: 16, LinearValueHeadDim: 16,
}, }},
Layers: []*Layer{ Layers: []*Layer{
{IsLinear: true}, {IsLinear: true},
{IsLinear: false}, {IsLinear: false},
@@ -166,7 +170,7 @@ func TestNewCachesLayout(t *testing.T) {
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) { func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
skipIfNoMLX(t) skipIfNoMLX(t)
cfg := &Config{ cfg := &Config{TextConfig: TextConfig{
HiddenSize: 4, HiddenSize: 4,
IntermediateSize: 8, IntermediateSize: 8,
NumHiddenLayers: 2, NumHiddenLayers: 2,
@@ -182,7 +186,7 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
LinearValueHeadDim: 2, LinearValueHeadDim: 2,
LinearConvKernelDim: 4, LinearConvKernelDim: 4,
FullAttentionInterval: 2, FullAttentionInterval: 2,
} }}
m := &Model{ m := &Model{
Config: cfg, Config: cfg,
@@ -343,3 +347,389 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
t.Fatalf("k norm dtype = %v, want %v", got, f32) t.Fatalf("k norm dtype = %v, want %v", got, f32)
} }
} }
// TestParseConfigVisionFields checks that parseConfig surfaces the nested
// vision_config block: GridPerSide is derived from num_position_embeddings
// (sqrt(2304) = 48), the vision rope theta falls back to its 10000 default,
// the text rope theta is read from text_config.rope_parameters, and the
// explicit image/vision token ids are used as-is.
func TestParseConfigVisionFields(t *testing.T) {
	data := []byte(`{
  "text_config": {
    "hidden_size": 4096,
    "intermediate_size": 14336,
    "num_hidden_layers": 4,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "head_dim": 128,
    "linear_num_value_heads": 64,
    "linear_num_key_heads": 16,
    "linear_key_head_dim": 128,
    "linear_value_head_dim": 128,
    "linear_conv_kernel_dim": 4,
    "rope_parameters": {
      "rope_theta": 10000000
    }
  },
  "vision_config": {
    "depth": 2,
    "hidden_size": 256,
    "num_heads": 8,
    "in_channels": 3,
    "patch_size": 14,
    "spatial_merge_size": 2,
    "layer_norm_epsilon": 0.000001,
    "temporal_patch_size": 2,
    "num_position_embeddings": 2304
  },
  "image_token_id": 111,
  "vision_start_token_id": 112,
  "vision_end_token_id": 113
}`)
	cfg, err := parseConfig(data)
	if err != nil {
		t.Fatalf("parseConfig failed: %v", err)
	}
	if cfg.Vision == nil {
		t.Fatal("vision config should be parsed")
	}
	if cfg.Vision.Depth != 2 {
		t.Fatalf("vision.depth mismatch: got %d", cfg.Vision.Depth)
	}
	if cfg.Vision.GridPerSide != 48 {
		t.Fatalf("vision grid-per-side mismatch: got %d want 48", cfg.Vision.GridPerSide)
	}
	if cfg.Vision.RopeTheta != 10000 {
		t.Fatalf("vision rope_theta should default to 10000, got %v", cfg.Vision.RopeTheta)
	}
	if cfg.RopeTheta != 10000000 {
		t.Fatalf("text rope_theta mismatch: got %v", cfg.RopeTheta)
	}
	if cfg.ImageTokenID != 111 || cfg.VisionStartToken != 112 || cfg.VisionEndToken != 113 {
		t.Fatalf("vision token ids mismatch: got image=%d start=%d end=%d", cfg.ImageTokenID, cfg.VisionStartToken, cfg.VisionEndToken)
	}
}
// TestParseConfigMRoPEFromRopeParameters checks that mrope_interleaved and
// mrope_section are read out of text_config.rope_parameters and that
// RopeDim comes out as head_dim * partial_rotary_factor (256 * 0.25 = 64).
func TestParseConfigMRoPEFromRopeParameters(t *testing.T) {
	data := []byte(`{
  "text_config": {
    "hidden_size": 2048,
    "intermediate_size": 8192,
    "num_hidden_layers": 4,
    "num_attention_heads": 16,
    "num_key_value_heads": 2,
    "head_dim": 256,
    "linear_num_value_heads": 32,
    "linear_num_key_heads": 16,
    "linear_key_head_dim": 128,
    "linear_value_head_dim": 128,
    "linear_conv_kernel_dim": 4,
    "rope_parameters": {
      "rope_theta": 10000000,
      "partial_rotary_factor": 0.25,
      "mrope_interleaved": true,
      "mrope_section": [11, 11, 10]
    }
  }
}`)
	cfg, err := parseConfig(data)
	if err != nil {
		t.Fatalf("parseConfig failed: %v", err)
	}
	if !cfg.MRoPEInterleaved {
		t.Fatal("mrope_interleaved should be parsed from rope_parameters")
	}
	if !slices.Equal(cfg.MRoPESections, []int32{11, 11, 10}) {
		t.Fatalf("mrope sections mismatch: got %v", cfg.MRoPESections)
	}
	if cfg.RopeDim != 64 {
		t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
	}
}
// TestParseConfigVisionTokenDefaults checks that the image/vision token ids
// fall back to the Qwen defaults (151655/151652/151653) when the config
// omits them.
func TestParseConfigVisionTokenDefaults(t *testing.T) {
	data := []byte(`{
  "text_config": {
    "hidden_size": 4096,
    "intermediate_size": 14336,
    "num_hidden_layers": 2,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "head_dim": 128,
    "linear_num_value_heads": 64,
    "linear_num_key_heads": 16,
    "linear_key_head_dim": 128,
    "linear_value_head_dim": 128,
    "linear_conv_kernel_dim": 4
  }
}`)
	cfg, err := parseConfig(data)
	if err != nil {
		t.Fatalf("parseConfig failed: %v", err)
	}
	if cfg.ImageTokenID != 151655 {
		t.Fatalf("default image token mismatch: got %d", cfg.ImageTokenID)
	}
	if cfg.VisionStartToken != 151652 {
		t.Fatalf("default vision start token mismatch: got %d", cfg.VisionStartToken)
	}
	if cfg.VisionEndToken != 151653 {
		t.Fatalf("default vision end token mismatch: got %d", cfg.VisionEndToken)
	}
}
// TestResolveVisionPrefix checks that resolveVisionPrefix detects both the
// legacy "model.visual" tensor layout and the imported "vision_tower"
// layout by probing for known weight names.
func TestResolveVisionPrefix(t *testing.T) {
	tests := []struct {
		name    string
		tensors map[string]*mlx.Array
		want    string
	}{
		{
			name: "legacy visual prefix",
			tensors: map[string]*mlx.Array{
				"model.visual.patch_embed.proj.weight": mlx.New("patch"),
			},
			want: "model.visual",
		},
		{
			name: "imported vision tower prefix",
			tensors: map[string]*mlx.Array{
				"vision_tower.blocks.0.attn.qkv.weight": mlx.New("qkv"),
			},
			want: "vision_tower",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := resolveVisionPrefix(tt.tensors, "language_model."); got != tt.want {
				t.Fatalf("resolveVisionPrefix() = %q, want %q", got, tt.want)
			}
		})
	}
}
// TestVisionPreprocessorOverridesDefaults checks that values from
// preprocessor_config.json replace the built-in vision defaults applied by
// applyDefaults.
func TestVisionPreprocessorOverridesDefaults(t *testing.T) {
	v := &VisionConfig{}
	v.applyDefaults()
	v.applyPreprocessorConfig([]byte(`{
  "patch_size": 16,
  "temporal_patch_size": 3,
  "merge_size": 4,
  "size": {
    "shortest_edge": 1024,
    "longest_edge": 8192
  },
  "image_mean": [0.1, 0.2, 0.3],
  "image_std": [0.9, 0.8, 0.7]
}`))
	if v.PatchSize != 16 {
		t.Fatalf("patch_size mismatch: got %d want 16", v.PatchSize)
	}
	if v.TemporalPatchSize != 3 {
		t.Fatalf("temporal_patch_size mismatch: got %d want 3", v.TemporalPatchSize)
	}
	if v.SpatialMergeSize != 4 {
		t.Fatalf("merge_size mismatch: got %d want 4", v.SpatialMergeSize)
	}
	if v.Size.ShortestEdge != 1024 || v.Size.LongestEdge != 8192 {
		t.Fatalf("size mismatch: got shortest=%d longest=%d", v.Size.ShortestEdge, v.Size.LongestEdge)
	}
	if v.ImageMean[0] != 0.1 || v.ImageStd[2] != 0.7 {
		t.Fatalf("image preprocessing stats mismatch: mean=%v std=%v", v.ImageMean, v.ImageStd)
	}
}
// TestVisionImageProcessorUsesPreprocessorSize checks that the image
// processor is built with the pixel budget (shortest/longest edge) taken
// from preprocessor_config.json.
func TestVisionImageProcessorUsesPreprocessorSize(t *testing.T) {
	v := &VisionConfig{}
	v.applyDefaults()
	v.applyPreprocessorConfig([]byte(`{
  "size": {
    "shortest_edge": 65536,
    "longest_edge": 16777216
  },
  "patch_size": 16,
  "temporal_patch_size": 2,
  "merge_size": 2,
  "image_mean": [0.5, 0.5, 0.5],
  "image_std": [0.5, 0.5, 0.5]
}`))
	p := newVisionImageProcessor(v)
	if p == nil {
		t.Fatal("newVisionImageProcessor returned nil")
	}
	if p.shortestEdge != 65536 || p.longestEdge != 16777216 {
		t.Fatalf("processor size mismatch: shortest=%d longest=%d", p.shortestEdge, p.longestEdge)
	}
}
// testTokenizer builds a minimal single-token BPE tokenizer ("a" -> 0) so
// prompt tokenization tests can run without real model assets.
func testTokenizer(t *testing.T) *tokenizer.Tokenizer {
	t.Helper()
	tok, err := tokenizer.LoadFromBytes([]byte(`{
  "model": {
    "type": "BPE",
    "vocab": {"a": 0},
    "merges": []
  }
}`))
	if err != nil {
		t.Fatalf("failed to load test tokenizer: %v", err)
	}
	return tok
}
// TestTokenizePromptWithResolvedImagesStoresVisionSpans checks that each
// [img-N] placeholder expands to vision-start + image tokens + vision-end,
// that the resolver runs only once for a repeated image id, and that the
// returned state records the image-token spans plus the position cache used
// for MRoPE.
func TestTokenizePromptWithResolvedImagesStoresVisionSpans(t *testing.T) {
	skipIfNoMLX(t)
	m := &Model{
		tok: testTokenizer(t),
		Config: &Config{
			ImageTokenID:     101,
			VisionStartToken: 102,
			VisionEndToken:   103,
			Vision:           &VisionConfig{SpatialMergeSize: 2},
		},
		Vision:         &VisionModel{},
		ImageProcessor: &VisionImageProcessor{},
	}
	// One image embedding with two merged tokens of dimension 2.
	main := mlx.FromValues([]float32{
		10, 11,
		20, 21,
	}, 1, 2, 2)
	resolveCalls := 0
	got, state, err := m.tokenizePromptWithResolvedImages(
		"a[img-7][img-7]a",
		[]base.ImageInput{{ID: 7, Data: []byte("img7")}},
		func(data []byte) (*VisionEmbeddings, error) {
			if string(data) != "img7" {
				return nil, fmt.Errorf("unexpected data: %q", string(data))
			}
			resolveCalls++
			return &VisionEmbeddings{
				Main: main,
				Grid: &VisionGrid{Height: 2, Width: 2, Temporal: 1},
			}, nil
		},
	)
	if err != nil {
		t.Fatalf("tokenizePromptWithResolvedImages returned error: %v", err)
	}
	if resolveCalls != 1 {
		t.Fatalf("resolve calls mismatch: got %d want 1", resolveCalls)
	}
	// "a" tokenizes to 0; each image becomes 102, two 101s, then 103.
	want := []int32{
		0,
		102, 101, 101, 103,
		102, 101, 101, 103,
		0,
	}
	if !slices.Equal(got, want) {
		t.Fatalf("expanded tokens mismatch: got %v want %v", got, want)
	}
	if state == nil {
		t.Fatal("expected prompt vision state")
	}
	if len(state.Spans) != 2 {
		t.Fatalf("prompt span count mismatch: got %d want 2", len(state.Spans))
	}
	if state.Spans[0].Start != 2 || state.Spans[0].End != 4 {
		t.Fatalf("first span mismatch: got [%d,%d)", state.Spans[0].Start, state.Spans[0].End)
	}
	if state.Spans[1].Start != 6 || state.Spans[1].End != 8 {
		t.Fatalf("second span mismatch: got [%d,%d)", state.Spans[1].Start, state.Spans[1].End)
	}
	wantPos := []int32{0, 1, 2, 2, 3, 4, 5, 5, 6, 7}
	if !slices.Equal(state.PositionCache, wantPos) {
		t.Fatalf("position cache mismatch: got %v want %v", state.PositionCache, wantPos)
	}
}
// TestBuildPromptMRoPEPositions checks the three MRoPE axes for a prompt
// with one image span covering tokens [2,8): the time axis stays flat over
// the span while the height and width axes advance over the merged grid
// (4x6 patches with spatial merge 2 -> 2x3 merged tokens).
func TestBuildPromptMRoPEPositions(t *testing.T) {
	m := &Model{
		Config: &Config{
			Vision: &VisionConfig{SpatialMergeSize: 2},
		},
	}
	state := &promptVisionState{
		PositionCache: []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6},
		Spans: []promptVisionSpan{
			{
				Start: 2,
				End:   8,
				Grid:  &VisionGrid{Height: 4, Width: 6, Temporal: 1},
			},
		},
	}
	pos := m.buildPromptMRoPEPositions(state, 0, 10)
	if got, want := pos[0], []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6}; !slices.Equal(got, want) {
		t.Fatalf("time positions mismatch: got %v want %v", got, want)
	}
	if got, want := pos[1], []int32{0, 1, 2, 2, 2, 3, 3, 3, 5, 6}; !slices.Equal(got, want) {
		t.Fatalf("height positions mismatch: got %v want %v", got, want)
	}
	if got, want := pos[2], []int32{0, 1, 2, 3, 4, 2, 3, 4, 5, 6}; !slices.Equal(got, want) {
		t.Fatalf("width positions mismatch: got %v want %v", got, want)
	}
}
// TestMapPromptPositionContinuesAfterCache checks that a token index inside
// the position cache maps through it, and an index past the cache keeps
// counting up from the final cached position.
func TestMapPromptPositionContinuesAfterCache(t *testing.T) {
	st := &promptVisionState{PositionCache: []int32{0, 1, 2, 2, 3}}
	// Index 3 falls inside the cache and maps to the stored position.
	if mapped := mapPromptPosition(st, 3); mapped != 2 {
		t.Fatalf("mapPromptPosition(3) = %d, want 2", mapped)
	}
	// Index 5 is past the cache end and continues from the last entry.
	if mapped := mapPromptPosition(st, 5); mapped != 4 {
		t.Fatalf("mapPromptPosition(5) = %d, want 4", mapped)
	}
}
// TestApplyPromptVisionEmbeddings checks that span embeddings overwrite the
// token embeddings at rows [Start, End) while rows outside the span keep
// their original values.
func TestApplyPromptVisionEmbeddings(t *testing.T) {
	skipIfNoMLX(t)
	m := &Model{}
	state := &promptVisionState{
		Spans: []promptVisionSpan{
			{
				Start: 1,
				End:   3,
				Main: mlx.FromValues([]float32{
					10, 11,
					20, 21,
				}, 1, 2, 2),
			},
		},
	}
	h := mlx.FromValues([]float32{
		0, 1,
		2, 3,
		4, 5,
		6, 7,
	}, 1, 4, 2)
	got := m.applyPromptVisionEmbeddings(h, 0, state)
	mlx.Eval(got)
	// Rows 1 and 2 are replaced by the span's embeddings; rows 0 and 3 stay.
	want := []float32{
		0, 1,
		10, 11,
		20, 21,
		6, 7,
	}
	if !slices.Equal(got.Floats(), want) {
		t.Fatalf("embedding replacement mismatch: got %v want %v", got.Floats(), want)
	}
}

854
x/models/qwen3_5/vision.go Normal file
View File

@@ -0,0 +1,854 @@
package qwen3_5
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"math"
"github.com/ollama/ollama/model/imageproc"
"github.com/ollama/ollama/x/mlxrunner/mlx"
mlxmodel "github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/models/nn"
)
var errNoVisionModel = errors.New("qwen3_5: no vision model")
// VisionConfig mirrors Qwen3.5/Qwen3-Next vision_config.
type VisionConfig struct {
	Depth                 int32   `json:"depth"`
	HiddenSize            int32   `json:"hidden_size"`
	NumHeads              int32   `json:"num_heads"`
	InChannels            int32   `json:"in_channels"`
	PatchSize             int32   `json:"patch_size"`
	SpatialMergeSize      int32   `json:"spatial_merge_size"`
	LayerNormEpsilon      float32 `json:"layer_norm_epsilon"`
	RopeTheta             float32 `json:"rope_theta"`
	TemporalPatchSize     int32   `json:"temporal_patch_size"`
	NumPositionEmbeddings int32   `json:"num_position_embeddings"`
	Size                  struct {
		ShortestEdge int32 `json:"shortest_edge"`
		LongestEdge  int32 `json:"longest_edge"`
	} `json:"size"`
	ImageMean []float32 `json:"image_mean"`
	ImageStd  []float32 `json:"image_std"`
	// GridPerSide is derived (one side of the square position-embedding
	// grid), not read from JSON.
	GridPerSide int32 `json:"-"`
}

// applyDefaults fills unset or invalid fields with the Qwen vision defaults
// and derives GridPerSide from NumPositionEmbeddings. Safe on a nil receiver.
func (v *VisionConfig) applyDefaults() {
	if v == nil {
		return
	}
	// Integer fields default when zero or negative.
	defaultPositive := func(field *int32, def int32) {
		if *field <= 0 {
			*field = def
		}
	}
	defaultPositive(&v.HiddenSize, 1280)
	defaultPositive(&v.NumHeads, 16)
	defaultPositive(&v.InChannels, 3)
	defaultPositive(&v.PatchSize, 14)
	defaultPositive(&v.SpatialMergeSize, 2)
	defaultPositive(&v.TemporalPatchSize, 2)
	defaultPositive(&v.NumPositionEmbeddings, 2304)
	defaultPositive(&v.Size.ShortestEdge, 64<<10)
	defaultPositive(&v.Size.LongestEdge, 2<<20)
	// Float fields default only when exactly zero, matching the original checks.
	if v.LayerNormEpsilon == 0 {
		v.LayerNormEpsilon = 1e-6
	}
	if v.RopeTheta == 0 {
		v.RopeTheta = 10000
	}
	if len(v.ImageMean) < 3 {
		v.ImageMean = []float32{0.5, 0.5, 0.5}
	}
	if len(v.ImageStd) < 3 {
		v.ImageStd = []float32{0.5, 0.5, 0.5}
	}
	// One side of the (assumed square) learned position-embedding grid.
	grid := int32(math.Sqrt(float64(v.NumPositionEmbeddings)))
	if grid <= 0 {
		grid = 48
	}
	v.GridPerSide = grid
}

// applyPreprocessorConfig overlays values from preprocessor_config.json onto
// the config, then re-derives defaults. Best effort: malformed or empty data
// leaves the configuration untouched.
func (v *VisionConfig) applyPreprocessorConfig(data []byte) {
	if v == nil || len(data) == 0 {
		return
	}
	var pre struct {
		Size struct {
			ShortestEdge int32 `json:"shortest_edge"`
			LongestEdge  int32 `json:"longest_edge"`
		} `json:"size"`
		PatchSize         int32     `json:"patch_size"`
		TemporalPatchSize int32     `json:"temporal_patch_size"`
		MergeSize         int32     `json:"merge_size"`
		ImageMean         []float32 `json:"image_mean"`
		ImageStd          []float32 `json:"image_std"`
	}
	if json.Unmarshal(data, &pre) != nil {
		return
	}
	// Only positive values override the existing configuration.
	override := func(dst *int32, src int32) {
		if src > 0 {
			*dst = src
		}
	}
	override(&v.PatchSize, pre.PatchSize)
	override(&v.TemporalPatchSize, pre.TemporalPatchSize)
	override(&v.SpatialMergeSize, pre.MergeSize)
	override(&v.Size.ShortestEdge, pre.Size.ShortestEdge)
	override(&v.Size.LongestEdge, pre.Size.LongestEdge)
	if len(pre.ImageMean) >= 3 {
		v.ImageMean = pre.ImageMean
	}
	if len(pre.ImageStd) >= 3 {
		v.ImageStd = pre.ImageStd
	}
	// Re-derive dependent values (e.g. GridPerSide) after the overrides.
	v.applyDefaults()
}
// VisionGrid tracks patch-grid dimensions for an image.
type VisionGrid struct {
	Height   int32 // patches along the vertical axis
	Width    int32 // patches along the horizontal axis
	Temporal int32 // temporal patches; always 1 for still images
}
// VisionImageProcessor reproduces qwen3vl image preprocessing.
type VisionImageProcessor struct {
	numChannels       int32
	patchSize         int32
	temporalPatchSize int32
	mergeSize         int32
	shortestEdge      int32 // minimum total pixel count after resize
	longestEdge       int32 // maximum total pixel count after resize
	factor            int32 // patchSize * mergeSize; resize granularity
	imageMean         [3]float32
	imageStd          [3]float32
}
// newVisionImageProcessor builds an image preprocessor from the vision
// config; it returns nil when no vision config is present.
func newVisionImageProcessor(cfg *VisionConfig) *VisionImageProcessor {
	if cfg == nil {
		return nil
	}
	p := &VisionImageProcessor{
		numChannels:       cfg.InChannels,
		patchSize:         cfg.PatchSize,
		temporalPatchSize: cfg.TemporalPatchSize,
		mergeSize:         cfg.SpatialMergeSize,
		shortestEdge:      cfg.Size.ShortestEdge,
		longestEdge:       cfg.Size.LongestEdge,
		// Images are resized to multiples of patch*merge so that merged
		// groups tile the image exactly.
		factor:    cfg.PatchSize * cfg.SpatialMergeSize,
		imageMean: [3]float32{cfg.ImageMean[0], cfg.ImageMean[1], cfg.ImageMean[2]},
		imageStd:  [3]float32{cfg.ImageStd[0], cfg.ImageStd[1], cfg.ImageStd[2]},
	}
	return p
}
// smartResize computes resized (height, width) that are multiples of the
// patch*merge factor while keeping the total pixel count within the
// [shortestEdge, longestEdge] budget: each side is rounded to the nearest
// factor multiple, then both sides are scaled down (floored) if the area
// exceeds longestEdge, or up (ceiled) if it falls below shortestEdge.
//
// It returns an error for a non-positive factor, a side smaller than the
// factor, or an aspect ratio beyond 200:1.
func (p *VisionImageProcessor) smartResize(height, width int) (int, int, error) {
	factor := int(p.factor)
	if factor <= 0 {
		return 0, 0, fmt.Errorf("invalid factor: %d", factor)
	}
	if height < factor || width < factor {
		return 0, 0, fmt.Errorf("height (%d) or width (%d) must be >= factor (%d)", height, width, factor)
	}
	// Both sides are now >= factor >= 1, so the min below cannot be zero.
	// (The original also checked min(height, width) == 0 here, but that
	// branch was unreachable and has been removed.)
	if max(height, width)/min(height, width) > 200 {
		return 0, 0, fmt.Errorf("aspect ratio too large: %dx%d", width, height)
	}
	roundEven := func(x float64) int { return int(math.RoundToEven(x)) }
	hBar := roundEven(float64(height)/float64(factor)) * factor
	wBar := roundEven(float64(width)/float64(factor)) * factor
	if hBar*wBar > int(p.longestEdge) {
		// Too many pixels: shrink both sides by sqrt(area/longestEdge),
		// flooring so the result stays under the cap.
		beta := math.Sqrt(float64(height*width) / float64(p.longestEdge))
		hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
		wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
	} else if hBar*wBar < int(p.shortestEdge) {
		// Too few pixels: grow both sides, ceiling to clear the minimum.
		beta := math.Sqrt(float64(p.shortestEdge) / float64(height*width))
		hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
		wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
	}
	return hBar, wBar, nil
}
// ProcessImage converts a decoded image into flattened patch vectors for
// the vision tower. It composites any alpha channel, smart-resizes to
// factor-aligned dimensions, normalizes with the configured mean/std, and
// returns a [1, numPatches, C*T*P*P] tensor plus the resulting patch grid.
func (p *VisionImageProcessor) ProcessImage(img image.Image) (*mlx.Array, *VisionGrid, error) {
	if p == nil {
		return nil, nil, errNoVisionModel
	}
	img = imageproc.Composite(img)
	origW := img.Bounds().Dx()
	origH := img.Bounds().Dy()
	resizedH, resizedW, err := p.smartResize(origH, origW)
	if err != nil {
		return nil, nil, err
	}
	resized := imageproc.Resize(
		img,
		image.Point{X: resizedW, Y: resizedH},
		imageproc.ResizeBilinear,
	)
	pixels := imageproc.Normalize(resized, p.imageMean, p.imageStd, true, true)
	// The grid counts patches per axis; Temporal is fixed at 1 for stills.
	grid := &VisionGrid{
		Height:   int32(resizedH / int(p.patchSize)),
		Width:    int32(resizedW / int(p.patchSize)),
		Temporal: 1,
	}
	patches := p.createPatches(pixels, resizedH, resizedW, grid)
	patchDim := int(p.numChannels * p.temporalPatchSize * p.patchSize * p.patchSize)
	numPatches := int(grid.Height * grid.Width)
	pixelValues := mlx.FromValues(patches, numPatches, patchDim).ExpandDims(0)
	return pixelValues, grid, nil
}
// createPatches rearranges normalized pixel data (CHW layout, as the
// srcIdx computation shows) into flattened patch vectors ordered in
// spatial-merge groups — the same ordering mergedPatchCoordinates produces.
// Each output patch is [C, temporalPatchSize, patchSize, patchSize]
// flattened; for still images the first temporal frame is filled and then
// replicated into the remaining temporal slots.
func (p *VisionImageProcessor) createPatches(pixels []float32, height, width int, grid *VisionGrid) []float32 {
	channels := int(p.numChannels)
	patchSize := int(p.patchSize)
	mergeSize := int(p.mergeSize)
	temporalPatchSize := int(p.temporalPatchSize)
	// Temporal is always 1 for static images; only spatial patches are created.
	numPatches := int(grid.Height * grid.Width)
	patchDim := channels * temporalPatchSize * patchSize * patchSize
	result := make([]float32, numPatches*patchDim)
	patchIndex := 0
	// Walk merge-sized blocks of the patch grid, then the patches inside
	// each block, so merged groups end up contiguous in the output.
	for h := 0; h < int(grid.Height); h += mergeSize {
		for w := 0; w < int(grid.Width); w += mergeSize {
			for mh := range mergeSize {
				for mw := range mergeSize {
					baseOffset := patchIndex * patchDim
					for c := range channels {
						channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
						for py := range patchSize {
							for px := range patchSize {
								// Source pixel coordinates within the full image.
								y := (h+mh)*patchSize + py
								x := (w+mw)*patchSize + px
								srcIdx := c*height*width + y*width + x
								dstIdx := channelOffset + py*patchSize + px
								// Bounds guard against a pixels buffer shorter
								// than height*width*channels.
								if srcIdx < len(pixels) && dstIdx < len(result) {
									result[dstIdx] = pixels[srcIdx]
								}
							}
						}
					}
					if temporalPatchSize > 1 {
						// Duplicate frame 0 into the remaining temporal slots.
						for c := range channels {
							channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
							frameSize := patchSize * patchSize
							for tp := 1; tp < temporalPatchSize; tp++ {
								cur := channelOffset + tp*frameSize
								copy(result[cur:cur+frameSize], result[channelOffset:channelOffset+frameSize])
							}
						}
					}
					patchIndex++
				}
			}
		}
	}
	return result
}
// VisionAttention runs one self-attention block inside the vision encoder.
// Checkpoints provide either a fused QKV projection or separate
// Query/Key/Value projections; Forward uses QKV when it is non-nil.
type VisionAttention struct {
	QKV    nn.LinearLayer // fused q/k/v projection, if present
	Query  nn.LinearLayer
	Key    nn.LinearLayer
	Value  nn.LinearLayer
	Output nn.LinearLayer
}
// applyVisionRoPE applies rotary position embedding in the rotate-half
// formulation: x*cos + rotate_half(x)*sin.
func applyVisionRoPE(x, cos, sin *mlx.Array) *mlx.Array {
	rotated := mlx.Mul(rotateHalf(x), sin)
	return mlx.Add(mlx.Mul(x, cos), rotated)
}
// Forward computes self-attention over the patch sequence. x is [B, L, D];
// cos/sin are the precomputed 2D rotary tables applied to q and k before
// the scaled dot-product. Both the fused-QKV and separate-q/k/v weight
// layouts are supported.
func (a *VisionAttention) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("vision attention expects [B,L,D], got %v", shape)
	}
	B, L, hidden := int32(shape[0]), int32(shape[1]), int32(shape[2])
	headDim := cfg.HiddenSize / cfg.NumHeads
	if headDim <= 0 {
		return nil, fmt.Errorf("invalid vision head dim: %d", headDim)
	}
	var q, k, v *mlx.Array
	if a.QKV != nil {
		// Fused layout: project once, then slice q/k/v out of axis 2.
		qkv := a.QKV.Forward(x)
		qkv = mlx.Reshape(qkv, B, L, 3, cfg.NumHeads, headDim)
		q = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, cfg.NumHeads, headDim}), 2)
		k = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, cfg.NumHeads, headDim}), 2)
		v = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, cfg.NumHeads, headDim}), 2)
	} else {
		if a.Query == nil || a.Key == nil || a.Value == nil {
			return nil, errors.New("vision attention is missing q/k/v projections")
		}
		q = mlx.Reshape(a.Query.Forward(x), B, L, cfg.NumHeads, headDim)
		k = mlx.Reshape(a.Key.Forward(x), B, L, cfg.NumHeads, headDim)
		v = mlx.Reshape(a.Value.Forward(x), B, L, cfg.NumHeads, headDim)
	}
	q = applyVisionRoPE(q, cos, sin)
	k = applyVisionRoPE(k, cos, sin)
	// [B, L, H, Dh] -> [B, H, L, Dh] for attention.
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	// NOTE(review): the trailing false presumably disables the causal mask
	// (vision attention is bidirectional) — confirm against the mlx binding.
	attn := mlx.ScaledDotProductAttentionCausal(q, k, v, scale, false)
	attn = mlx.Reshape(mlx.Transpose(attn, 0, 2, 1, 3), B, L, hidden)
	if a.Output == nil {
		return nil, errors.New("vision attention is missing output projection")
	}
	return a.Output.Forward(attn), nil
}
// VisionMLP is the vision feed-forward block (fc1 -> GELU -> fc2).
type VisionMLP struct {
	FC1 nn.LinearLayer
	FC2 nn.LinearLayer
}
// Forward applies fc1, approximate GELU, then fc2.
func (m *VisionMLP) Forward(x *mlx.Array) (*mlx.Array, error) {
	if m.FC1 == nil || m.FC2 == nil {
		return nil, errors.New("vision mlp is missing fc1/fc2")
	}
	hidden := mlx.GELUApprox(m.FC1.Forward(x))
	return m.FC2.Forward(hidden), nil
}
// VisionEncoderLayer is one transformer block in the vision encoder.
type VisionEncoderLayer struct {
	Norm1 *nn.LayerNorm // pre-attention norm
	Attn  *VisionAttention
	Norm2 *nn.LayerNorm // pre-MLP norm
	MLP   *VisionMLP
}
// Forward runs one pre-norm transformer block:
// x + attn(norm1(x)), then + mlp(norm2(...)).
func (l *VisionEncoderLayer) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	if l.Norm1 == nil || l.Norm2 == nil || l.Attn == nil || l.MLP == nil {
		return nil, errors.New("vision layer is incomplete")
	}
	attnOut, err := l.Attn.Forward(l.Norm1.Forward(x), cos, sin, cfg)
	if err != nil {
		return nil, err
	}
	hidden := mlx.Add(x, attnOut)
	mlpOut, err := l.MLP.Forward(l.Norm2.Forward(hidden))
	if err != nil {
		return nil, err
	}
	return mlx.Add(hidden, mlpOut), nil
}
// VisionPatchMerger projects merged spatial groups into language embedding space.
// Applied after the encoder: norm -> group merge tokens -> fc1 -> GELU -> fc2.
type VisionPatchMerger struct {
	Norm *nn.LayerNorm
	FC1  nn.LinearLayer
	FC2  nn.LinearLayer
}
// groupMergedTokens reshapes [B, L, D] vision features into
// [B, L/(merge*merge), merge*merge*D], concatenating each spatial-merge
// group of tokens along the feature axis for the patch merger. merge values
// <= 0 are treated as 1.
func groupMergedTokens(x *mlx.Array, merge int32) (*mlx.Array, error) {
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("expected [B,L,D], got %v", shape)
	}
	if merge <= 0 {
		merge = 1
	}
	B, L, D := int32(shape[0]), int32(shape[1]), int32(shape[2])
	// After the clamp above merge >= 1, so group >= 1; the original's
	// additional group <= 0 check was dead code and has been dropped.
	group := merge * merge
	if L%group != 0 {
		return nil, fmt.Errorf("invalid merge layout: L=%d merge=%d", L, merge)
	}
	x = mlx.Reshape(x, B, L/group, group, D)
	return mlx.Reshape(x, B, L/group, group*D), nil
}
// Forward normalizes, groups each spatial-merge set of tokens along the
// feature axis, and projects through the two-layer MLP into language
// embedding space.
func (m *VisionPatchMerger) Forward(x *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	if m == nil || m.Norm == nil || m.FC1 == nil || m.FC2 == nil {
		return nil, errors.New("vision patch merger is incomplete")
	}
	grouped, err := groupMergedTokens(m.Norm.Forward(x), cfg.SpatialMergeSize)
	if err != nil {
		return nil, err
	}
	return m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(grouped))), nil
}
// VisionModel contains the full Qwen vision tower.
type VisionModel struct {
	PatchProjection nn.LinearLayer // projects flattened patches into the hidden size
	PositionEmbed   *nn.Embedding  // learned position grid; optional
	Layers          []*VisionEncoderLayer
	PatchMerger     *VisionPatchMerger // maps merged tokens into language space
	cfg             *VisionConfig
}
// mergedPatchCoordinates lists (row, col) patch coordinates in the order the
// patch pipeline emits them: merge-sized spatial blocks, row-major within
// each block. Temporal is always 1 for still images, so no temporal axis is
// generated. merge values <= 0 are treated as 1.
func mergedPatchCoordinates(grid *VisionGrid, merge int32) [][2]int32 {
	if merge <= 0 {
		merge = 1
	}
	coords := make([][2]int32, 0, grid.Height*grid.Width)
	for row := int32(0); row < grid.Height; row += merge {
		for col := int32(0); col < grid.Width; col += merge {
			for dr := int32(0); dr < merge; dr++ {
				for dc := int32(0); dc < merge; dc++ {
					coords = append(coords, [2]int32{row + dr, col + dc})
				}
			}
		}
	}
	return coords
}
// addPositionEmbedding adds learned position embeddings to the patch
// sequence. The embedding table is a fixed GridPerSide x GridPerSide grid;
// each patch's (row, col) is scaled onto that grid and its embedding is
// bilinearly interpolated from the four surrounding table entries, letting
// arbitrary image grids share the fixed-size table. No-op when the model
// has no position-embedding table.
func (m *VisionModel) addPositionEmbedding(x *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
	if m.PositionEmbed == nil {
		return x, nil
	}
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("vision embeddings expect [B,L,D], got %v", shape)
	}
	B, D := int32(shape[0]), int32(shape[2])
	coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
	L := int32(len(coords))
	if L != int32(shape[1]) {
		return nil, fmt.Errorf("vision sequence mismatch: hidden L=%d coords=%d", shape[1], L)
	}
	// Scale factors mapping patch coordinates into [0, GridPerSide-1].
	stepH := float32(0)
	if grid.Height > 1 {
		stepH = float32(m.cfg.GridPerSide-1) / float32(grid.Height-1)
	}
	stepW := float32(0)
	if grid.Width > 1 {
		stepW = float32(m.cfg.GridPerSide-1) / float32(grid.Width-1)
	}
	// For every patch gather the 4 neighboring table indices and their
	// bilinear weights.
	indices := make([]int32, 0, L*4)
	weights := make([]float32, 0, L*4)
	for _, c := range coords {
		y := float32(c[0]) * stepH
		x0 := float32(c[1]) * stepW
		fy := int32(y)  // floor row
		fx := int32(x0) // floor col
		cy := min(fy+1, m.cfg.GridPerSide-1) // ceil row, clamped to the grid
		cx := min(fx+1, m.cfg.GridPerSide-1) // ceil col, clamped to the grid
		indices = append(indices,
			fy*m.cfg.GridPerSide+fx,
			fy*m.cfg.GridPerSide+cx,
			cy*m.cfg.GridPerSide+fx,
			cy*m.cfg.GridPerSide+cx,
		)
		dy := y - float32(fy)
		dx := x0 - float32(fx)
		weights = append(weights,
			(1-dy)*(1-dx),
			(1-dy)*dx,
			dy*(1-dx),
			dy*dx,
		)
	}
	// Weighted sum over the 4 neighbors: [L,4,D] * [L,4,1], summed on axis 1.
	idxArr := mlx.FromValues(indices, int(L), 4)
	wArr := mlx.FromValues(weights, int(L), 4, 1)
	pos := m.PositionEmbed.Forward(idxArr)
	wArr = wArr.AsType(pos.DType())
	pos = mlx.Sum(mlx.Mul(pos, wArr), 1, false)
	if D != int32(pos.Dim(1)) {
		return nil, fmt.Errorf("position embedding dim mismatch: hidden=%d pos=%d", D, pos.Dim(1))
	}
	pos = mlx.ExpandDims(pos, 0)
	if B > 1 {
		pos = mlx.Tile(pos, []int32{B, 1, 1})
	}
	return mlx.Add(x, pos), nil
}
// rotaryEmbeddings builds 2D rotary cos/sin tables for the patch sequence.
// Within each head: the first quarter of the dimension carries row angles,
// the second quarter column angles, and the second half mirrors the first so
// the tables pair up with rotateHalf. Returned shapes are [1, L, 1, headDim].
func (m *VisionModel) rotaryEmbeddings(grid *VisionGrid) (*mlx.Array, *mlx.Array, error) {
	headDim := m.cfg.HiddenSize / m.cfg.NumHeads
	if headDim <= 0 {
		return nil, nil, fmt.Errorf("invalid vision head dim: %d", headDim)
	}
	coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
	L := int32(len(coords))
	half := headDim / 2
	quarter := half / 2
	if quarter <= 0 {
		return nil, nil, fmt.Errorf("invalid vision rotary layout: head_dim=%d", headDim)
	}
	angles := make([]float32, L*headDim)
	for i, c := range coords {
		base := int32(i) * headDim
		for j := range quarter {
			// Inverse-frequency ladder over the half dimension.
			freq := 1.0 / math.Pow(float64(m.cfg.RopeTheta), float64(2*j)/float64(half))
			angles[base+j] = float32(float64(c[0]) * freq)         // row angle
			angles[base+quarter+j] = float32(float64(c[1]) * freq) // column angle
		}
		// Mirror the first half into the second for rotate-half pairing.
		for j := range half {
			angles[base+half+j] = angles[base+j]
		}
	}
	arr := mlx.FromValues(angles, int(L), int(headDim))
	cos := mlx.ExpandDims(mlx.ExpandDims(mlx.Cos(arr), 0), 2)
	sin := mlx.ExpandDims(mlx.ExpandDims(mlx.Sin(arr), 0), 2)
	return cos, sin, nil
}
// Forward runs the full vision tower: patch projection, interpolated
// position embeddings, the transformer blocks with 2D rotary tables, and
// finally the patch merger that maps into language embedding space.
func (m *VisionModel) Forward(pixelValues *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
	switch {
	case m == nil, pixelValues == nil, grid == nil:
		return nil, errNoVisionModel
	case m.PatchProjection == nil, m.PatchMerger == nil:
		return nil, errors.New("vision model is missing required projections")
	}
	hidden, err := m.addPositionEmbedding(m.PatchProjection.Forward(pixelValues), grid)
	if err != nil {
		return nil, err
	}
	cos, sin, err := m.rotaryEmbeddings(grid)
	if err != nil {
		return nil, err
	}
	for i, blk := range m.Layers {
		if hidden, err = blk.Forward(hidden, cos, sin, m.cfg); err != nil {
			return nil, fmt.Errorf("vision layer %d: %w", i, err)
		}
	}
	merged, err := m.PatchMerger.Forward(hidden, m.cfg)
	if err != nil {
		return nil, fmt.Errorf("vision patch merger: %w", err)
	}
	return merged, nil
}
// VisionEmbeddings bundles the vision encoder output with the grid
// metadata describing how the source image was split into patches.
type VisionEmbeddings struct {
	Main *mlx.Array  // merged patch features produced by VisionModel.Forward
	Grid *VisionGrid // patch grid returned by the image processor
}
// EncodeVisionImage decodes raw image bytes, preprocesses them into pixel
// values and a patch grid, and runs the vision tower to produce embeddings.
// It fails with errNoVisionModel when the model has no vision components.
func (m *Model) EncodeVisionImage(multimodalData []byte) (*VisionEmbeddings, error) {
	if m == nil || m.Vision == nil || m.ImageProcessor == nil {
		return nil, errNoVisionModel
	}
	decoded, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}
	pixels, grid, err := m.ImageProcessor.ProcessImage(decoded)
	if err != nil {
		return nil, err
	}
	features, err := m.Vision.Forward(pixels, grid)
	if err != nil {
		return nil, err
	}
	return &VisionEmbeddings{Main: features, Grid: grid}, nil
}
// resolveVisionPrefix returns the tensor-name prefix under which the vision
// tower's weights are stored, probing a fixed set of known naming schemes in
// priority order. It returns "" when no candidate prefix matches.
func resolveVisionPrefix(tensors map[string]*mlx.Array, weightPrefix string) string {
	// Any one of these suffixes existing under a prefix marks it as the root.
	probeSuffixes := []string{
		".patch_embed.proj.weight",
		".patch_embed.weight",
		".pos_embed.weight",
		".blocks.0.attn.qkv.weight",
		".merger.linear_fc1.weight",
		".merger.mlp.0.weight",
	}
	for _, prefix := range []string{
		"vision_tower",
		weightPrefix + "vision_tower",
		"model.visual",
		"visual",
		weightPrefix + "model.visual",
		weightPrefix + "visual",
	} {
		for _, suffix := range probeSuffixes {
			if tensors[prefix+suffix] != nil {
				return prefix
			}
		}
	}
	return ""
}
// firstLinear returns the first linear layer the factory can build from the
// given candidate weight paths, or nil when none of them resolve.
func firstLinear(linears mlxmodel.LinearFactory, paths ...string) nn.LinearLayer {
	var layer nn.LinearLayer
	for _, path := range paths {
		if layer = linears.Make(path); layer != nil {
			break
		}
	}
	return layer
}
// loadLayerNorm builds a LayerNorm from the first base name whose weight
// tensor exists. It accepts both "<base>.weight"/"<base>.bias" naming and
// the bare "<base>"/"<base>_bias" variant. Returns nil if nothing matches.
func loadLayerNorm(tensors map[string]*mlx.Array, eps float32, bases ...string) *nn.LayerNorm {
	for _, base := range bases {
		if weight, ok := tensors[base+".weight"]; ok && weight != nil {
			return &nn.LayerNorm{Weight: weight, Bias: tensors[base+".bias"], Eps: eps}
		}
		if weight, ok := tensors[base]; ok && weight != nil {
			return &nn.LayerNorm{Weight: weight, Bias: tensors[base+"_bias"], Eps: eps}
		}
	}
	return nil
}
// loadVisionPatchMerger builds the patch merger from the first base prefix
// that yields all three components (norm + two linears); nil otherwise.
func loadVisionPatchMerger(
	tensors map[string]*mlx.Array,
	linears mlxmodel.LinearFactory,
	eps float32,
	bases ...string,
) *VisionPatchMerger {
	for _, base := range bases {
		merger := &VisionPatchMerger{
			Norm: loadLayerNorm(tensors, eps, base+".norm", base+".ln_q"),
			FC1:  firstLinear(linears, base+".linear_fc1", base+".mlp.0"),
			FC2:  firstLinear(linears, base+".linear_fc2", base+".mlp.2"),
		}
		if merger.Norm != nil && merger.FC1 != nil && merger.FC2 != nil {
			return merger
		}
	}
	return nil
}
// flattenPatchEmbeddingWeight collapses a patch-embedding weight of rank > 2
// into a 2D (out, in) matrix so it can drive a linear projection; 2D weights
// are returned unchanged.
//
// Returns an error when the weight is missing/invalid, has fewer than two
// dimensions, or has an empty leading dimension (which would otherwise make
// the Size/Dim division below divide by zero).
func flattenPatchEmbeddingWeight(w *mlx.Array) (*mlx.Array, error) {
	if w == nil || !w.Valid() {
		return nil, errors.New("missing patch embedding weight")
	}
	if w.NumDims() < 2 {
		return nil, fmt.Errorf("patch embedding weight must be >=2D, got %dD", w.NumDims())
	}
	if w.NumDims() == 2 {
		return w, nil
	}
	// Guard the division below: a zero-sized leading dim would panic.
	if w.Dim(0) == 0 {
		return nil, errors.New("patch embedding weight has empty leading dimension")
	}
	out := int32(w.Dim(0))
	in := int32(w.Size() / w.Dim(0))
	return mlx.Reshape(w, out, in), nil
}
// loadVisionComponents builds the vision tower and its image processor from
// the loaded tensor map. It returns (nil, nil, nil) when the config carries
// no vision section (Depth <= 0), and an error when vision is configured but
// the expected tensors are missing or malformed.
func loadVisionComponents(
	tensors map[string]*mlx.Array,
	linears mlxmodel.LinearFactory,
	cfg *Config,
	weightPrefix string,
) (*VisionModel, *VisionImageProcessor, error) {
	// Vision is optional: a missing/empty vision config means text-only.
	if cfg == nil || cfg.Vision == nil || cfg.Vision.Depth <= 0 {
		return nil, nil, nil
	}
	cfg.Vision.applyDefaults()
	// Checkpoints use several naming schemes for the vision tower root.
	visionPrefix := resolveVisionPrefix(tensors, weightPrefix)
	if visionPrefix == "" {
		return nil, nil, errors.New("vision enabled in config but vision tensors were not found")
	}
	patchW, _ := tensorAny(
		tensors,
		visionPrefix+".patch_embed.proj.weight",
		visionPrefix+".patch_embed.weight",
	)
	if patchW == nil {
		return nil, nil, fmt.Errorf("missing vision patch embedding weight under %s", visionPrefix)
	}
	// Patch weights may be stored with more than two dims; collapse to a
	// 2D matrix so they can be used as a linear projection.
	patchW, err := flattenPatchEmbeddingWeight(patchW)
	if err != nil {
		return nil, nil, err
	}
	patchB, _ := tensorAny(
		tensors,
		visionPrefix+".patch_embed.proj.bias",
		visionPrefix+".patch_embed.bias",
	)
	patchProj := nn.NewLinear(patchW, patchB)
	// The flattened input dim must equal channels*temporal*patch*patch.
	if got := int32(patchW.Dim(1)); got != cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize {
		return nil, nil, fmt.Errorf(
			"vision patch embedding input dim mismatch: got %d expected %d",
			got,
			cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize,
		)
	}
	posW, _ := tensorAny(
		tensors,
		visionPrefix+".pos_embed.weight",
		visionPrefix+".position_embedding.weight",
	)
	if posW == nil {
		return nil, nil, fmt.Errorf("missing vision position embedding under %s", visionPrefix)
	}
	// Derive the position-table size from the checkpoint, then re-run
	// applyDefaults so anything depending on it is recomputed.
	cfg.Vision.NumPositionEmbeddings = int32(posW.Dim(0))
	cfg.Vision.applyDefaults()
	vm := &VisionModel{
		PatchProjection: patchProj,
		PositionEmbed: nn.NewEmbedding(posW),
		Layers: make([]*VisionEncoderLayer, cfg.Vision.Depth),
		cfg: cfg.Vision,
	}
	for i := range cfg.Vision.Depth {
		layerPrefix := fmt.Sprintf("%s.blocks.%d", visionPrefix, i)
		layer := &VisionEncoderLayer{
			Norm1: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm1"),
			Norm2: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm2"),
			Attn: &VisionAttention{
				// Attention is stored either fused (qkv) or as separate
				// q/k/v projections; the check below requires one of the
				// two layouts to be fully present.
				QKV: firstLinear(
					linears,
					layerPrefix+".attn.qkv",
					layerPrefix+".attn_qkv",
				),
				Query: firstLinear(
					linears,
					layerPrefix+".attn.q_proj",
					layerPrefix+".attn_q",
				),
				Key: firstLinear(
					linears,
					layerPrefix+".attn.k_proj",
					layerPrefix+".attn_k",
				),
				Value: firstLinear(
					linears,
					layerPrefix+".attn.v_proj",
					layerPrefix+".attn_v",
				),
				Output: firstLinear(
					linears,
					layerPrefix+".attn.proj",
					layerPrefix+".attn_out",
					layerPrefix+".attn.o_proj",
				),
			},
			MLP: &VisionMLP{
				FC1: firstLinear(
					linears,
					layerPrefix+".mlp.fc1",
					layerPrefix+".mlp.linear_fc1",
				),
				FC2: firstLinear(
					linears,
					layerPrefix+".mlp.fc2",
					layerPrefix+".mlp.linear_fc2",
				),
			},
		}
		// Validate the layer: norms, an attention output plus either the
		// fused or the split projections, and both MLP linears.
		if layer.Norm1 == nil || layer.Norm2 == nil {
			return nil, nil, fmt.Errorf("vision layer %d: missing norm1/norm2", i)
		}
		if layer.Attn.Output == nil || (layer.Attn.QKV == nil && (layer.Attn.Query == nil || layer.Attn.Key == nil || layer.Attn.Value == nil)) {
			return nil, nil, fmt.Errorf("vision layer %d: missing attention projections", i)
		}
		if layer.MLP.FC1 == nil || layer.MLP.FC2 == nil {
			return nil, nil, fmt.Errorf("vision layer %d: missing mlp projections", i)
		}
		vm.Layers[i] = layer
	}
	vm.PatchMerger = loadVisionPatchMerger(
		tensors,
		linears,
		cfg.Vision.LayerNormEpsilon,
		visionPrefix+".merger",
	)
	if vm.PatchMerger == nil {
		return nil, nil, fmt.Errorf("missing vision patch merger under %s", visionPrefix)
	}
	return vm, newVisionImageProcessor(cfg.Vision), nil
}