mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 16:54:13 +02:00
Compare commits
5 Commits
hoyyeva/op
...
pdevine/qw
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
578c32e42e | ||
|
|
a10d2625ca | ||
|
|
b960d769ad | ||
|
|
455a6099d1 | ||
|
|
7e6e8377eb |
@@ -134,14 +134,18 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
spinnerKey = "create"
|
||||
capabilities = []string{"completion"}
|
||||
|
||||
// Check if model supports thinking based on architecture
|
||||
if supportsThinking(opts.ModelDir) {
|
||||
configData, _ := os.ReadFile(filepath.Join(opts.ModelDir, "config.json"))
|
||||
mcfg := parseModelConfig(configData)
|
||||
|
||||
if mcfg.supportsThinking() {
|
||||
capabilities = append(capabilities, "thinking")
|
||||
}
|
||||
if mcfg.supportsVision() {
|
||||
capabilities = append(capabilities, "vision")
|
||||
}
|
||||
|
||||
// Set parser and renderer name based on architecture
|
||||
parserName = getParserName(opts.ModelDir)
|
||||
rendererName = getRendererName(opts.ModelDir)
|
||||
parserName = mcfg.parserName()
|
||||
rendererName = mcfg.rendererName()
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
@@ -438,145 +442,76 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// supportsThinking checks if the model supports thinking mode based on its architecture.
|
||||
// This reads the config.json from the model directory and checks the architectures field.
|
||||
func supportsThinking(modelDir string) bool {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// modelConfig holds the fields from config.json needed during model creation.
|
||||
type visionConfig struct {
|
||||
Depth int32 `json:"depth"`
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
type modelConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
VisionConfig *visionConfig `json:"vision_config"`
|
||||
ImageTokenID *int32 `json:"image_token_id"`
|
||||
VisionStartTokenID *int32 `json:"vision_start_token_id"`
|
||||
VisionEndTokenID *int32 `json:"vision_end_token_id"`
|
||||
}
|
||||
|
||||
// Check architectures that support thinking
|
||||
thinkingArchitectures := []string{
|
||||
"glm4moe", // GLM-4 MoE models
|
||||
"deepseek", // DeepSeek models
|
||||
"qwen3", // Qwen3 models
|
||||
}
|
||||
func parseModelConfig(data []byte) modelConfig {
|
||||
var cfg modelConfig
|
||||
_ = json.Unmarshal(data, &cfg)
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Check the architecture list
|
||||
for _, arch := range cfg.Architectures {
|
||||
// archOrTypeContains returns true if any architecture or the model_type
|
||||
// contains one of the given substrings (case-insensitive).
|
||||
func (c *modelConfig) archOrTypeContains(substrs ...string) bool {
|
||||
for _, arch := range c.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
for _, thinkArch := range thinkingArchitectures {
|
||||
if strings.Contains(archLower, thinkArch) {
|
||||
for _, s := range substrs {
|
||||
if strings.Contains(archLower, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
for _, thinkArch := range thinkingArchitectures {
|
||||
if strings.Contains(typeLower, thinkArch) {
|
||||
if c.ModelType != "" {
|
||||
typeLower := strings.ToLower(c.ModelType)
|
||||
for _, s := range substrs {
|
||||
if strings.Contains(typeLower, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getParserName returns the parser name for a model based on its architecture.
|
||||
// This reads the config.json from the model directory and determines the appropriate parser.
|
||||
func getParserName(modelDir string) string {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
func (c *modelConfig) supportsThinking() bool {
|
||||
return c.archOrTypeContains("glm4moe", "deepseek", "qwen3")
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return ""
|
||||
}
|
||||
func (c *modelConfig) supportsVision() bool {
|
||||
return c.VisionConfig != nil || c.ImageTokenID != nil || c.VisionStartTokenID != nil || c.VisionEndTokenID != nil
|
||||
}
|
||||
|
||||
// Check architectures for known parsers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
func (c *modelConfig) parserName() string {
|
||||
switch {
|
||||
case c.archOrTypeContains("glm4", "glm-4"):
|
||||
return "glm-4.7"
|
||||
case c.archOrTypeContains("deepseek"):
|
||||
return "deepseek3"
|
||||
case c.archOrTypeContains("qwen3"):
|
||||
return "qwen3"
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// getRendererName returns the renderer name for a model based on its architecture.
|
||||
// This reads the config.json from the model directory and determines the appropriate renderer.
|
||||
func getRendererName(modelDir string) string {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
func (c *modelConfig) rendererName() string {
|
||||
switch {
|
||||
case c.archOrTypeContains("glm4", "glm-4"):
|
||||
return "glm-4.7"
|
||||
case c.archOrTypeContains("deepseek"):
|
||||
return "deepseek3"
|
||||
case c.archOrTypeContains("qwen3"):
|
||||
return "qwen3-coder"
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check architectures for known renderers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -339,3 +339,34 @@ func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
|
||||
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsVision(t *testing.T) {
|
||||
t.Run("vision_config present", func(t *testing.T) {
|
||||
cfg := parseModelConfig([]byte(`{
|
||||
"vision_config": {"depth": 2},
|
||||
"image_token_id": 151655
|
||||
}`))
|
||||
if !cfg.supportsVision() {
|
||||
t.Fatal("supportsVision() = false, want true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("token ids alone imply vision", func(t *testing.T) {
|
||||
cfg := parseModelConfig([]byte(`{
|
||||
"vision_start_token_id": 10,
|
||||
"vision_end_token_id": 11
|
||||
}`))
|
||||
if !cfg.supportsVision() {
|
||||
t.Fatal("supportsVision() = false, want true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("plain text model", func(t *testing.T) {
|
||||
cfg := parseModelConfig([]byte(`{
|
||||
"architectures": ["Qwen3_5ForCausalLM"]
|
||||
}`))
|
||||
if cfg.supportsVision() {
|
||||
t.Fatal("supportsVision() = true, want false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -366,6 +366,23 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
|
||||
c.enforceEvictionPolicy()
|
||||
}
|
||||
|
||||
// clear releases live caches and drops the trie so future requests cannot
|
||||
// reuse prompt state keyed only by token IDs.
|
||||
func (c *kvCache) clear() {
|
||||
c.freeAll()
|
||||
walkNodes(c.root, func(n *trieNode) bool {
|
||||
for _, s := range n.snapshots {
|
||||
if s != nil {
|
||||
s.Close()
|
||||
}
|
||||
}
|
||||
n.snapshots = nil
|
||||
return true
|
||||
})
|
||||
c.root = nil
|
||||
c.activePath = nil
|
||||
}
|
||||
|
||||
// freeAll releases all cache layers.
|
||||
func (c *kvCache) freeAll() {
|
||||
for _, kv := range c.caches {
|
||||
|
||||
@@ -106,6 +106,7 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
||||
type completionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Images []llm.ImageData `json:"images,omitempty"`
|
||||
Options *completionOpts `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
@@ -155,6 +156,7 @@ func (c *Client) Close() error {
|
||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
creq := completionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Images: req.Images,
|
||||
}
|
||||
if req.Options != nil {
|
||||
creq.Options = &completionOpts{
|
||||
|
||||
@@ -304,6 +304,18 @@ func Exp(a *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Sin(a *Array) *Array {
|
||||
out := New("SIN")
|
||||
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Cos(a *Array) *Array {
|
||||
out := New("COS")
|
||||
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG")
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
|
||||
32
x/mlxrunner/model/base/multimodal.go
Normal file
32
x/mlxrunner/model/base/multimodal.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// ImageInput is a single image attached to a prompt.
|
||||
type ImageInput struct {
|
||||
ID int
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// PromptTokenization contains tokenized prompt IDs plus optional request-scoped
|
||||
// model metadata needed during forward.
|
||||
type PromptTokenization struct {
|
||||
Tokens []int32
|
||||
State any
|
||||
}
|
||||
|
||||
// MultimodalPromptTokenizerWithState is an optional model interface used by
|
||||
// mlxrunner to expand tagged multimodal prompts into token IDs, returning
|
||||
// request-scoped state to be attached to the forward pass.
|
||||
type MultimodalPromptTokenizerWithState interface {
|
||||
TokenizePromptWithImagesState(prompt string, images []ImageInput) (*PromptTokenization, error)
|
||||
}
|
||||
|
||||
// ForwardWithStateModel is an optional model interface for request-scoped
|
||||
// forward metadata that should not be stored in shared caches.
|
||||
type ForwardWithStateModel interface {
|
||||
ForwardWithState(inputs *mlx.Array, cache []cache.Cache, state any) *mlx.Array
|
||||
}
|
||||
@@ -12,12 +12,42 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
func prefillChunkSize() int {
|
||||
return 2 << 10
|
||||
}
|
||||
|
||||
func (r *Runner) tokenizeRequest(request Request) ([]int32, any, error) {
|
||||
if len(request.Images) > 0 {
|
||||
// The shared trie cache keys only on token IDs today, so multimodal
|
||||
// prompts must not reuse snapshots across distinct image inputs.
|
||||
r.cache.clear()
|
||||
}
|
||||
|
||||
if multimodalTokenizer, ok := r.Model.(base.MultimodalPromptTokenizerWithState); ok && len(request.Images) > 0 {
|
||||
images := make([]base.ImageInput, len(request.Images))
|
||||
for i := range request.Images {
|
||||
images[i] = base.ImageInput{
|
||||
ID: request.Images[i].ID,
|
||||
Data: request.Images[i].Data,
|
||||
}
|
||||
}
|
||||
|
||||
out, err := multimodalTokenizer.TokenizePromptWithImagesState(request.Prompt, images)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if out == nil {
|
||||
return nil, nil, errors.New("empty multimodal tokenization result")
|
||||
}
|
||||
return out.Tokens, out.State, nil
|
||||
}
|
||||
|
||||
return r.Tokenizer.Encode(request.Prompt, true), nil, nil
|
||||
}
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
if r.Model == nil {
|
||||
return errors.New("model not loaded")
|
||||
@@ -55,7 +85,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
||||
}()
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
inputs, promptState, err := r.tokenizeRequest(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(inputs) == 0 {
|
||||
return errors.New("empty prompt")
|
||||
}
|
||||
@@ -83,6 +116,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
tokens := session.remaining
|
||||
prefillChunk := prefillChunkSize()
|
||||
|
||||
modelForward := func(tokens *mlx.Array) *mlx.Array {
|
||||
if withState, ok := r.Model.(base.ForwardWithStateModel); ok {
|
||||
return withState.ForwardWithState(tokens, caches, promptState)
|
||||
}
|
||||
return r.Model.Forward(tokens, caches)
|
||||
}
|
||||
|
||||
materializeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
@@ -114,7 +154,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
modelForward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0))
|
||||
mlx.Sweep()
|
||||
materializeCaches()
|
||||
processed += n
|
||||
@@ -132,7 +172,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
|
||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||
fwd := modelForward(token.ExpandDims(0))
|
||||
logits := r.Model.Unembed(fwd)
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
@@ -29,7 +30,8 @@ type Request struct {
|
||||
}
|
||||
|
||||
type TextCompletionsRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Prompt string `json:"prompt"`
|
||||
Images []llm.ImageData `json:"images,omitempty"`
|
||||
Options struct {
|
||||
Temperature float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
|
||||
354
x/models/qwen3_5/multimodal.go
Normal file
354
x/models/qwen3_5/multimodal.go
Normal file
@@ -0,0 +1,354 @@
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
var imageTagRE = regexp.MustCompile(`\[img-(\d+)\]`)
|
||||
|
||||
type promptVisionSpan struct {
|
||||
Start int32
|
||||
End int32
|
||||
|
||||
Main *mlx.Array
|
||||
Grid *VisionGrid
|
||||
}
|
||||
|
||||
type promptVisionState struct {
|
||||
Spans []promptVisionSpan
|
||||
PositionCache []int32
|
||||
}
|
||||
|
||||
func promptStartPosFromCaches(caches []cache.Cache) int32 {
|
||||
offset := -1
|
||||
for _, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
off := c.Offset()
|
||||
if offset < 0 || off < offset {
|
||||
offset = off
|
||||
}
|
||||
}
|
||||
if offset < 0 {
|
||||
return 0
|
||||
}
|
||||
return int32(offset)
|
||||
}
|
||||
|
||||
func promptVisionStateFromState(state any) *promptVisionState {
|
||||
typed, _ := state.(*promptVisionState)
|
||||
return typed
|
||||
}
|
||||
|
||||
func overlapRange(chunkStart, chunkLen, spanStart, spanEnd int32) (int32, int32, int32, int32, bool) {
|
||||
chunkEnd := chunkStart + chunkLen
|
||||
overlapStart := max(chunkStart, spanStart)
|
||||
overlapEnd := min(chunkEnd, spanEnd)
|
||||
if overlapStart >= overlapEnd {
|
||||
return 0, 0, 0, 0, false
|
||||
}
|
||||
|
||||
chunkLo := overlapStart - chunkStart
|
||||
chunkHi := overlapEnd - chunkStart
|
||||
spanLo := overlapStart - spanStart
|
||||
spanHi := overlapEnd - spanStart
|
||||
return chunkLo, chunkHi, spanLo, spanHi, true
|
||||
}
|
||||
|
||||
func (m *Model) applyPromptVisionEmbeddings(h *mlx.Array, startPos int32, state *promptVisionState) *mlx.Array {
|
||||
if m == nil || h == nil || state == nil || len(state.Spans) == 0 {
|
||||
return h
|
||||
}
|
||||
|
||||
dims := h.Dims()
|
||||
if len(dims) != 3 {
|
||||
return h
|
||||
}
|
||||
|
||||
L := int32(dims[1])
|
||||
for _, span := range state.Spans {
|
||||
chunkLo, chunkHi, spanLo, spanHi, ok := overlapRange(startPos, L, span.Start, span.End)
|
||||
if !ok || span.Main == nil || !span.Main.Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
repl := span.Main.Slice(
|
||||
mlx.Slice(),
|
||||
mlx.Slice(int(spanLo), int(spanHi)),
|
||||
mlx.Slice(),
|
||||
)
|
||||
repl = repl.AsType(h.DType())
|
||||
h = h.SliceUpdate(
|
||||
repl,
|
||||
mlx.Slice(),
|
||||
mlx.Slice(int(chunkLo), int(chunkHi)),
|
||||
mlx.Slice(),
|
||||
)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func findImageByID(images []base.ImageInput, id int) (base.ImageInput, bool) {
|
||||
for i := range images {
|
||||
if images[i].ID == id {
|
||||
return images[i], true
|
||||
}
|
||||
}
|
||||
return base.ImageInput{}, false
|
||||
}
|
||||
|
||||
func mapPromptPosition(state *promptVisionState, id int32) int32 {
|
||||
if state == nil {
|
||||
return id
|
||||
}
|
||||
if id < int32(len(state.PositionCache)) {
|
||||
return state.PositionCache[id]
|
||||
}
|
||||
if len(state.PositionCache) > 0 {
|
||||
return id - int32(len(state.PositionCache)) + state.PositionCache[len(state.PositionCache)-1] + 1
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func promptVisionGridSpan(grid *VisionGrid, merge int32, fallback int32) int32 {
|
||||
if fallback <= 0 {
|
||||
fallback = 1
|
||||
}
|
||||
if grid == nil {
|
||||
return fallback
|
||||
}
|
||||
if merge <= 0 {
|
||||
merge = 1
|
||||
}
|
||||
return max(max(int32(1), grid.Width/merge), max(int32(1), grid.Height/merge))
|
||||
}
|
||||
|
||||
func normalizeMRoPESections(sections []int32) [4]int32 {
|
||||
var out [4]int32
|
||||
for i := range min(4, len(sections)) {
|
||||
if sections[i] > 0 {
|
||||
out[i] = sections[i]
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mropePairComponent(pair int32, sections [4]int32, interleaved bool) int {
|
||||
if interleaved {
|
||||
if pair%3 == 1 && pair < 1+3*sections[1] {
|
||||
return 1
|
||||
}
|
||||
if pair%3 == 2 && pair < 2+3*sections[2] {
|
||||
return 2
|
||||
}
|
||||
if pair%3 == 0 && pair < 3*sections[0] {
|
||||
return 0
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
secW := sections[0] + sections[1]
|
||||
secE := secW + sections[2]
|
||||
switch {
|
||||
case pair < sections[0]:
|
||||
return 0
|
||||
case pair < secW:
|
||||
return 1
|
||||
case pair < secE:
|
||||
return 2
|
||||
default:
|
||||
return 3
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) buildPromptMRoPEPositions(state *promptVisionState, startPos, chunkLen int32) [4][]int32 {
|
||||
var positions [4][]int32
|
||||
for i := range positions {
|
||||
positions[i] = make([]int32, chunkLen)
|
||||
}
|
||||
|
||||
// positions[3] stays zero — it covers RoPE dims beyond the 3 MRoPE sections.
|
||||
for i := range chunkLen {
|
||||
p := mapPromptPosition(state, startPos+i)
|
||||
positions[0][i] = p
|
||||
positions[1][i] = p
|
||||
positions[2][i] = p
|
||||
}
|
||||
|
||||
merge := int32(1)
|
||||
if m != nil && m.Config != nil && m.Config.Vision != nil {
|
||||
merge = m.Config.Vision.SpatialMergeSize
|
||||
}
|
||||
for _, span := range state.Spans {
|
||||
if span.Grid == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
chunkLo, chunkHi, spanLo, _, ok := overlapRange(startPos, chunkLen, span.Start, span.End)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
w := max(int32(1), span.Grid.Width/merge)
|
||||
for i := chunkLo; i < chunkHi; i++ {
|
||||
rel := spanLo + (i - chunkLo)
|
||||
positions[1][i] += rel / w
|
||||
positions[2][i] += rel % w
|
||||
}
|
||||
}
|
||||
|
||||
return positions
|
||||
}
|
||||
|
||||
func (m *Model) buildPromptMRoPECosSin(state *promptVisionState, startPos, chunkLen int32, dtype mlx.DType) (*mlx.Array, *mlx.Array) {
|
||||
if m == nil || m.Config == nil || state == nil || chunkLen <= 0 || len(m.Config.MRoPESections) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ropeDim := m.Config.RopeDim
|
||||
if ropeDim%2 != 0 {
|
||||
ropeDim--
|
||||
}
|
||||
if ropeDim <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
half := ropeDim / 2
|
||||
positions := m.buildPromptMRoPEPositions(state, startPos, chunkLen)
|
||||
sections := normalizeMRoPESections(m.Config.MRoPESections)
|
||||
theta := m.Config.RopeTheta
|
||||
if theta <= 0 {
|
||||
theta = 100000.0
|
||||
}
|
||||
|
||||
freqs := make([]float64, half)
|
||||
for j := range half {
|
||||
freqs[j] = math.Pow(float64(theta), -2.0*float64(j)/float64(ropeDim))
|
||||
}
|
||||
|
||||
angles := make([]float32, chunkLen*ropeDim)
|
||||
for i := range chunkLen {
|
||||
base := i * ropeDim
|
||||
for j := range half {
|
||||
component := mropePairComponent(j, sections, m.Config.MRoPEInterleaved)
|
||||
angle := float32(float64(positions[component][i]) * freqs[j])
|
||||
angles[base+j] = angle
|
||||
angles[base+half+j] = angle
|
||||
}
|
||||
}
|
||||
|
||||
arr := mlx.FromValues(angles, 1, 1, int(chunkLen), int(ropeDim))
|
||||
cos := mlx.Cos(arr)
|
||||
sin := mlx.Sin(arr)
|
||||
if dtype != 0 {
|
||||
cos = cos.AsType(dtype)
|
||||
sin = sin.AsType(dtype)
|
||||
}
|
||||
return cos, sin
|
||||
}
|
||||
|
||||
func (m *Model) tokenizePromptWithResolvedImages(
|
||||
prompt string,
|
||||
images []base.ImageInput,
|
||||
resolve func([]byte) (*VisionEmbeddings, error),
|
||||
) ([]int32, *promptVisionState, error) {
|
||||
if m == nil || m.tok == nil {
|
||||
return nil, nil, fmt.Errorf("qwen3_5: tokenizer not initialized")
|
||||
}
|
||||
|
||||
if m.Vision == nil || m.ImageProcessor == nil || resolve == nil {
|
||||
return m.tok.Encode(prompt, true), nil, nil
|
||||
}
|
||||
|
||||
parts := imageTagRE.Split(prompt, -1)
|
||||
matches := imageTagRE.FindAllStringSubmatch(prompt, -1)
|
||||
|
||||
resolved := make(map[int]*VisionEmbeddings, len(images))
|
||||
var out []int32
|
||||
state := &promptVisionState{}
|
||||
var p int32
|
||||
appendToken := func(tok, pos int32) {
|
||||
out = append(out, tok)
|
||||
state.PositionCache = append(state.PositionCache, pos)
|
||||
}
|
||||
for i, part := range parts {
|
||||
for _, tok := range m.tok.Encode(part, i == 0) {
|
||||
appendToken(tok, p)
|
||||
p++
|
||||
}
|
||||
|
||||
if i >= len(matches) {
|
||||
continue
|
||||
}
|
||||
|
||||
imageID, err := strconv.Atoi(matches[i][1])
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("qwen3_5: invalid image tag %q: %w", matches[i][0], err)
|
||||
}
|
||||
|
||||
img, ok := findImageByID(images, imageID)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("invalid image index: %d", imageID)
|
||||
}
|
||||
|
||||
embeds := resolved[imageID]
|
||||
if embeds == nil {
|
||||
embeds, err = resolve(img.Data)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
resolved[imageID] = embeds
|
||||
}
|
||||
if embeds == nil || embeds.Main == nil || !embeds.Main.Valid() || embeds.Main.NumDims() < 2 {
|
||||
return nil, nil, fmt.Errorf("qwen3_5: invalid vision embeddings")
|
||||
}
|
||||
|
||||
tokensPerImage := int32(embeds.Main.Dim(1))
|
||||
if tokensPerImage <= 0 {
|
||||
return nil, nil, fmt.Errorf("qwen3_5: invalid image token count: %d", tokensPerImage)
|
||||
}
|
||||
|
||||
appendToken(m.VisionStartToken, p)
|
||||
p++
|
||||
basePos := p
|
||||
spanStart := int32(len(out))
|
||||
for range tokensPerImage {
|
||||
appendToken(m.ImageTokenID, basePos)
|
||||
}
|
||||
spanEnd := int32(len(out))
|
||||
merge := int32(1)
|
||||
if m.Config != nil && m.Config.Vision != nil {
|
||||
merge = m.Config.Vision.SpatialMergeSize
|
||||
}
|
||||
gridSpan := promptVisionGridSpan(embeds.Grid, merge, tokensPerImage)
|
||||
p += gridSpan
|
||||
appendToken(m.VisionEndToken, p)
|
||||
p++
|
||||
|
||||
state.Spans = append(state.Spans, promptVisionSpan{
|
||||
Start: spanStart,
|
||||
End: spanEnd,
|
||||
Main: embeds.Main,
|
||||
Grid: embeds.Grid,
|
||||
})
|
||||
}
|
||||
|
||||
return out, state, nil
|
||||
}
|
||||
|
||||
func (m *Model) TokenizePromptWithImagesState(prompt string, images []base.ImageInput) (*base.PromptTokenization, error) {
|
||||
tokens, state, err := m.tokenizePromptWithResolvedImages(prompt, images, m.EncodeVisionImage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &base.PromptTokenization{Tokens: tokens, State: state}, nil
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
@@ -22,16 +23,26 @@ func init() {
|
||||
base.Register("Qwen3NextForConditionalGeneration", NewModel)
|
||||
}
|
||||
|
||||
var (
|
||||
_ base.MultimodalPromptTokenizerWithState = (*Model)(nil)
|
||||
_ base.ForwardWithStateModel = (*Model)(nil)
|
||||
)
|
||||
|
||||
// RopeParameters carries optional rope metadata embedded under rope_parameters.
|
||||
type RopeParameters struct {
|
||||
Type string `json:"type"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
MRoPEInterleaved bool `json:"mrope_interleaved"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
DimensionSections []int32 `json:"dimension_sections"`
|
||||
}
|
||||
|
||||
// Config holds Qwen 3.5 text config (top-level or nested text_config).
|
||||
type Config struct {
|
||||
// TextConfig holds the Qwen 3.5 text-model architecture fields.
|
||||
// In VLM configs these live under the "text_config" key; in text-only
|
||||
// configs they appear at the top level.
|
||||
type TextConfig struct {
|
||||
ModelType string `json:"model_type"`
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
@@ -67,6 +78,19 @@ type Config struct {
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeScaling map[string]any `json:"rope_scaling"`
|
||||
RopeParameters *RopeParameters `json:"rope_parameters"`
|
||||
MRoPESections []int32 `json:"mrope_sections"`
|
||||
MRoPEInterleaved bool `json:"mrope_interleaved"`
|
||||
}
|
||||
|
||||
// Config is the full model config. It embeds TextConfig for the text-model
|
||||
// fields and adds top-level-only fields (vision, token IDs, quantization).
|
||||
type Config struct {
|
||||
TextConfig
|
||||
|
||||
Vision *VisionConfig `json:"vision_config"`
|
||||
ImageTokenID int32 `json:"image_token_id"`
|
||||
VisionStartToken int32 `json:"vision_start_token_id"`
|
||||
VisionEndToken int32 `json:"vision_end_token_id"`
|
||||
|
||||
// Quantization metadata.
|
||||
QuantGroupSize int `json:"-"`
|
||||
@@ -90,6 +114,9 @@ type Model struct {
|
||||
*Config
|
||||
|
||||
weightPrefix string
|
||||
|
||||
Vision *VisionModel
|
||||
ImageProcessor *VisionImageProcessor
|
||||
}
|
||||
|
||||
// Layer is a transformer decoder layer.
|
||||
@@ -190,17 +217,24 @@ func parseConfig(configData []byte) (Config, error) {
|
||||
|
||||
var cfg Config
|
||||
activeRaw := rawTop
|
||||
|
||||
// First pass: unmarshal the full config to pick up top-level fields
|
||||
// (vision_config, image_token_id, etc.) and text fields for text-only models.
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return Config{}, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Second pass: if text_config exists, unmarshal it into TextConfig so
|
||||
// text-model fields from text_config take priority over any top-level
|
||||
// duplicates. Top-level-only fields (Vision, token IDs) are unaffected
|
||||
// because they live on Config, not TextConfig.
|
||||
if textRaw, ok := rawTop["text_config"]; ok {
|
||||
if err := json.Unmarshal(textRaw, &cfg); err != nil {
|
||||
if err := json.Unmarshal(textRaw, &cfg.TextConfig); err != nil {
|
||||
return Config{}, fmt.Errorf("parse text_config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(textRaw, &activeRaw); err != nil {
|
||||
return Config{}, fmt.Errorf("parse text_config envelope: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return Config{}, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.HiddenSize <= 0 {
|
||||
@@ -225,12 +259,8 @@ func parseConfig(configData []byte) (Config, error) {
|
||||
return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
||||
}
|
||||
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
if cfg.LinearConvKernelDim <= 0 {
|
||||
cfg.LinearConvKernelDim = 4
|
||||
}
|
||||
cfg.RMSNormEps = cmp.Or(cfg.RMSNormEps, 1e-6)
|
||||
cfg.LinearConvKernelDim = cmp.Or(cfg.LinearConvKernelDim, 4)
|
||||
if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 {
|
||||
return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)",
|
||||
cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim)
|
||||
@@ -246,14 +276,21 @@ func parseConfig(configData []byte) (Config, error) {
|
||||
if cfg.RopeParameters.PartialRotaryFactor > 0 {
|
||||
cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor
|
||||
}
|
||||
if len(cfg.MRoPESections) == 0 {
|
||||
switch {
|
||||
case len(cfg.RopeParameters.MRoPESection) > 0:
|
||||
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.MRoPESection...)
|
||||
case len(cfg.RopeParameters.DimensionSections) > 0:
|
||||
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.DimensionSections...)
|
||||
}
|
||||
}
|
||||
cfg.MRoPEInterleaved = cmp.Or(cfg.MRoPEInterleaved, cfg.RopeParameters.MRoPEInterleaved)
|
||||
}
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 100000.0
|
||||
if len(cfg.MRoPESections) > 4 {
|
||||
cfg.MRoPESections = cfg.MRoPESections[:4]
|
||||
}
|
||||
if cfg.PartialRotaryFactor == 0 {
|
||||
cfg.PartialRotaryFactor = 0.25
|
||||
}
|
||||
if cfg.PartialRotaryFactor < 0 {
|
||||
cfg.RopeTheta = cmp.Or(cfg.RopeTheta, 100000.0)
|
||||
if cfg.PartialRotaryFactor <= 0 {
|
||||
cfg.PartialRotaryFactor = 0.25
|
||||
}
|
||||
ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor)
|
||||
@@ -281,24 +318,23 @@ func parseConfig(configData []byte) (Config, error) {
|
||||
}
|
||||
|
||||
if cfg.NumExperts > 0 {
|
||||
if cfg.NumExpertsPerTok <= 0 {
|
||||
cfg.NumExpertsPerTok = 1
|
||||
}
|
||||
if cfg.MoeIntermediateSize <= 0 {
|
||||
cfg.MoeIntermediateSize = cfg.IntermediateSize
|
||||
}
|
||||
if cfg.SharedExpertIntermediateSize <= 0 {
|
||||
cfg.SharedExpertIntermediateSize = cfg.IntermediateSize
|
||||
}
|
||||
cfg.NumExpertsPerTok = cmp.Or(cfg.NumExpertsPerTok, int32(1))
|
||||
cfg.MoeIntermediateSize = cmp.Or(cfg.MoeIntermediateSize, cfg.IntermediateSize)
|
||||
cfg.SharedExpertIntermediateSize = cmp.Or(cfg.SharedExpertIntermediateSize, cfg.IntermediateSize)
|
||||
if _, ok := activeRaw["norm_topk_prob"]; !ok {
|
||||
cfg.NormTopKProb = true
|
||||
}
|
||||
if cfg.DecoderSparseStep <= 0 {
|
||||
cfg.DecoderSparseStep = 1
|
||||
}
|
||||
cfg.DecoderSparseStep = cmp.Or(cfg.DecoderSparseStep, int32(1))
|
||||
}
|
||||
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
if cfg.Vision != nil {
|
||||
cfg.Vision.applyDefaults()
|
||||
}
|
||||
cfg.ImageTokenID = cmp.Or(cfg.ImageTokenID, int32(151655))
|
||||
cfg.VisionStartToken = cmp.Or(cfg.VisionStartToken, int32(151652))
|
||||
cfg.VisionEndToken = cmp.Or(cfg.VisionEndToken, int32(151653))
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
@@ -364,6 +400,11 @@ func NewModel(root *model.Root) (base.Model, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.Vision != nil {
|
||||
if preprocessorData, err := root.Manifest.ReadConfig("preprocessor_config.json"); err == nil {
|
||||
cfg.Vision.applyPreprocessorConfig(preprocessorData)
|
||||
}
|
||||
}
|
||||
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||
@@ -1060,6 +1101,15 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
if cfg.Vision != nil && cfg.Vision.Depth > 0 {
|
||||
vision, processor, err := loadVisionComponents(tensors, linears, cfg, m.weightPrefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Vision = vision
|
||||
m.ImageProcessor = processor
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1117,7 +1167,51 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k,
|
||||
return q, k, v, z, b, a
|
||||
}
|
||||
|
||||
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func rotateHalf(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Dims()
|
||||
last := int32(shape[len(shape)-1])
|
||||
half := last / 2
|
||||
if half <= 0 {
|
||||
return x
|
||||
}
|
||||
|
||||
x1 := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), half})
|
||||
x2 := mlx.SliceStartStop(x, []int32{0, 0, 0, half}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
|
||||
return mlx.Concatenate([]*mlx.Array{mlx.Neg(x2), x1}, -1)
|
||||
}
|
||||
|
||||
func applyTextRoPE(x, cos, sin *mlx.Array, ropeDim int32) *mlx.Array {
|
||||
if x == nil || cos == nil || sin == nil || ropeDim <= 0 {
|
||||
return x
|
||||
}
|
||||
|
||||
shape := x.Dims()
|
||||
if len(shape) != 4 {
|
||||
return x
|
||||
}
|
||||
|
||||
last := int32(shape[len(shape)-1])
|
||||
if ropeDim > last {
|
||||
ropeDim = last
|
||||
}
|
||||
if ropeDim%2 != 0 {
|
||||
ropeDim--
|
||||
}
|
||||
if ropeDim <= 0 {
|
||||
return x
|
||||
}
|
||||
|
||||
rot := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), ropeDim})
|
||||
rot = mlx.Add(mlx.Mul(rot, cos), mlx.Mul(rotateHalf(rot), sin))
|
||||
if ropeDim == last {
|
||||
return rot
|
||||
}
|
||||
|
||||
tail := mlx.SliceStartStop(x, []int32{0, 0, 0, ropeDim}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
|
||||
return mlx.Concatenate([]*mlx.Array{rot, tail}, -1)
|
||||
}
|
||||
|
||||
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
|
||||
qg := a.QProj.Forward(x)
|
||||
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
|
||||
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
|
||||
@@ -1140,8 +1234,13 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
if ropeCos != nil && ropeSin != nil {
|
||||
q = applyTextRoPE(q, ropeCos, ropeSin, cfg.RopeDim)
|
||||
k = applyTextRoPE(k, ropeCos, ropeSin, cfg.RopeDim)
|
||||
} else {
|
||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
@@ -1323,13 +1422,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
|
||||
var r *mlx.Array
|
||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||
if l.IsLinear {
|
||||
r = l.Linear.Forward(normed, c, B, L, cfg)
|
||||
} else {
|
||||
r = l.FullAttn.Forward(normed, c, B, L, cfg)
|
||||
r = l.FullAttn.Forward(normed, c, B, L, cfg, ropeCos, ropeSin)
|
||||
}
|
||||
h := mlx.Add(x, r)
|
||||
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
@@ -1337,16 +1436,27 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
return m.ForwardWithState(tokens, caches, nil)
|
||||
}
|
||||
|
||||
func (m *Model) ForwardWithState(tokens *mlx.Array, caches []cache.Cache, state any) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
startPos := promptStartPosFromCaches(caches)
|
||||
promptState := promptVisionStateFromState(state)
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = m.applyPromptVisionEmbeddings(h, startPos, promptState)
|
||||
var ropeCos, ropeSin *mlx.Array
|
||||
if len(m.MRoPESections) > 0 {
|
||||
ropeCos, ropeSin = m.buildPromptMRoPECosSin(promptState, startPos, L, h.DType())
|
||||
}
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
h = layer.Forward(h, c, B, L, m.Config, ropeCos, ropeSin)
|
||||
}
|
||||
out := m.Norm.Forward(h, m.RMSNormEps)
|
||||
return out
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
@@ -60,13 +64,13 @@ func TestParseConfigNestedDefaults(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLayerSelectionHelpers(t *testing.T) {
|
||||
cfg := &Config{
|
||||
cfg := &Config{TextConfig: TextConfig{
|
||||
NumHiddenLayers: 6,
|
||||
FullAttentionInterval: 3,
|
||||
NumExperts: 8,
|
||||
DecoderSparseStep: 2,
|
||||
MLPOnlyLayers: []int32{1},
|
||||
}
|
||||
}}
|
||||
|
||||
if !layerIsLinear(cfg, 0) {
|
||||
t.Fatalf("layer 0 should be linear")
|
||||
@@ -133,13 +137,13 @@ func TestResolveTensorPathLayout(t *testing.T) {
|
||||
|
||||
func TestNewCachesLayout(t *testing.T) {
|
||||
m := &Model{
|
||||
Config: &Config{
|
||||
Config: &Config{TextConfig: TextConfig{
|
||||
LinearConvKernelDim: 4,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearKeyHeadDim: 8,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 16,
|
||||
},
|
||||
}},
|
||||
Layers: []*Layer{
|
||||
{IsLinear: true},
|
||||
{IsLinear: false},
|
||||
@@ -166,7 +170,7 @@ func TestNewCachesLayout(t *testing.T) {
|
||||
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
cfg := &Config{
|
||||
cfg := &Config{TextConfig: TextConfig{
|
||||
HiddenSize: 4,
|
||||
IntermediateSize: 8,
|
||||
NumHiddenLayers: 2,
|
||||
@@ -182,7 +186,7 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
||||
LinearValueHeadDim: 2,
|
||||
LinearConvKernelDim: 4,
|
||||
FullAttentionInterval: 2,
|
||||
}
|
||||
}}
|
||||
|
||||
m := &Model{
|
||||
Config: cfg,
|
||||
@@ -343,3 +347,389 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
||||
t.Fatalf("k norm dtype = %v, want %v", got, f32)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigVisionFields(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"text_config": {
|
||||
"hidden_size": 4096,
|
||||
"intermediate_size": 14336,
|
||||
"num_hidden_layers": 4,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"linear_num_value_heads": 64,
|
||||
"linear_num_key_heads": 16,
|
||||
"linear_key_head_dim": 128,
|
||||
"linear_value_head_dim": 128,
|
||||
"linear_conv_kernel_dim": 4,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 10000000
|
||||
}
|
||||
},
|
||||
"vision_config": {
|
||||
"depth": 2,
|
||||
"hidden_size": 256,
|
||||
"num_heads": 8,
|
||||
"in_channels": 3,
|
||||
"patch_size": 14,
|
||||
"spatial_merge_size": 2,
|
||||
"layer_norm_epsilon": 0.000001,
|
||||
"temporal_patch_size": 2,
|
||||
"num_position_embeddings": 2304
|
||||
},
|
||||
"image_token_id": 111,
|
||||
"vision_start_token_id": 112,
|
||||
"vision_end_token_id": 113
|
||||
}`)
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Vision == nil {
|
||||
t.Fatal("vision config should be parsed")
|
||||
}
|
||||
if cfg.Vision.Depth != 2 {
|
||||
t.Fatalf("vision.depth mismatch: got %d", cfg.Vision.Depth)
|
||||
}
|
||||
if cfg.Vision.GridPerSide != 48 {
|
||||
t.Fatalf("vision grid-per-side mismatch: got %d want 48", cfg.Vision.GridPerSide)
|
||||
}
|
||||
if cfg.Vision.RopeTheta != 10000 {
|
||||
t.Fatalf("vision rope_theta should default to 10000, got %v", cfg.Vision.RopeTheta)
|
||||
}
|
||||
if cfg.RopeTheta != 10000000 {
|
||||
t.Fatalf("text rope_theta mismatch: got %v", cfg.RopeTheta)
|
||||
}
|
||||
if cfg.ImageTokenID != 111 || cfg.VisionStartToken != 112 || cfg.VisionEndToken != 113 {
|
||||
t.Fatalf("vision token ids mismatch: got image=%d start=%d end=%d", cfg.ImageTokenID, cfg.VisionStartToken, cfg.VisionEndToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigMRoPEFromRopeParameters(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"text_config": {
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 8192,
|
||||
"num_hidden_layers": 4,
|
||||
"num_attention_heads": 16,
|
||||
"num_key_value_heads": 2,
|
||||
"head_dim": 256,
|
||||
"linear_num_value_heads": 32,
|
||||
"linear_num_key_heads": 16,
|
||||
"linear_key_head_dim": 128,
|
||||
"linear_value_head_dim": 128,
|
||||
"linear_conv_kernel_dim": 4,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 10000000,
|
||||
"partial_rotary_factor": 0.25,
|
||||
"mrope_interleaved": true,
|
||||
"mrope_section": [11, 11, 10]
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.MRoPEInterleaved {
|
||||
t.Fatal("mrope_interleaved should be parsed from rope_parameters")
|
||||
}
|
||||
if !slices.Equal(cfg.MRoPESections, []int32{11, 11, 10}) {
|
||||
t.Fatalf("mrope sections mismatch: got %v", cfg.MRoPESections)
|
||||
}
|
||||
if cfg.RopeDim != 64 {
|
||||
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigVisionTokenDefaults(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"text_config": {
|
||||
"hidden_size": 4096,
|
||||
"intermediate_size": 14336,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"linear_num_value_heads": 64,
|
||||
"linear_num_key_heads": 16,
|
||||
"linear_key_head_dim": 128,
|
||||
"linear_value_head_dim": 128,
|
||||
"linear_conv_kernel_dim": 4
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.ImageTokenID != 151655 {
|
||||
t.Fatalf("default image token mismatch: got %d", cfg.ImageTokenID)
|
||||
}
|
||||
if cfg.VisionStartToken != 151652 {
|
||||
t.Fatalf("default vision start token mismatch: got %d", cfg.VisionStartToken)
|
||||
}
|
||||
if cfg.VisionEndToken != 151653 {
|
||||
t.Fatalf("default vision end token mismatch: got %d", cfg.VisionEndToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveVisionPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensors map[string]*mlx.Array
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "legacy visual prefix",
|
||||
tensors: map[string]*mlx.Array{
|
||||
"model.visual.patch_embed.proj.weight": mlx.New("patch"),
|
||||
},
|
||||
want: "model.visual",
|
||||
},
|
||||
{
|
||||
name: "imported vision tower prefix",
|
||||
tensors: map[string]*mlx.Array{
|
||||
"vision_tower.blocks.0.attn.qkv.weight": mlx.New("qkv"),
|
||||
},
|
||||
want: "vision_tower",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveVisionPrefix(tt.tensors, "language_model."); got != tt.want {
|
||||
t.Fatalf("resolveVisionPrefix() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVisionPreprocessorOverridesDefaults(t *testing.T) {
|
||||
v := &VisionConfig{}
|
||||
v.applyDefaults()
|
||||
|
||||
v.applyPreprocessorConfig([]byte(`{
|
||||
"patch_size": 16,
|
||||
"temporal_patch_size": 3,
|
||||
"merge_size": 4,
|
||||
"size": {
|
||||
"shortest_edge": 1024,
|
||||
"longest_edge": 8192
|
||||
},
|
||||
"image_mean": [0.1, 0.2, 0.3],
|
||||
"image_std": [0.9, 0.8, 0.7]
|
||||
}`))
|
||||
|
||||
if v.PatchSize != 16 {
|
||||
t.Fatalf("patch_size mismatch: got %d want 16", v.PatchSize)
|
||||
}
|
||||
if v.TemporalPatchSize != 3 {
|
||||
t.Fatalf("temporal_patch_size mismatch: got %d want 3", v.TemporalPatchSize)
|
||||
}
|
||||
if v.SpatialMergeSize != 4 {
|
||||
t.Fatalf("merge_size mismatch: got %d want 4", v.SpatialMergeSize)
|
||||
}
|
||||
if v.Size.ShortestEdge != 1024 || v.Size.LongestEdge != 8192 {
|
||||
t.Fatalf("size mismatch: got shortest=%d longest=%d", v.Size.ShortestEdge, v.Size.LongestEdge)
|
||||
}
|
||||
if v.ImageMean[0] != 0.1 || v.ImageStd[2] != 0.7 {
|
||||
t.Fatalf("image preprocessing stats mismatch: mean=%v std=%v", v.ImageMean, v.ImageStd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVisionImageProcessorUsesPreprocessorSize(t *testing.T) {
|
||||
v := &VisionConfig{}
|
||||
v.applyDefaults()
|
||||
|
||||
v.applyPreprocessorConfig([]byte(`{
|
||||
"size": {
|
||||
"shortest_edge": 65536,
|
||||
"longest_edge": 16777216
|
||||
},
|
||||
"patch_size": 16,
|
||||
"temporal_patch_size": 2,
|
||||
"merge_size": 2,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}`))
|
||||
|
||||
p := newVisionImageProcessor(v)
|
||||
if p == nil {
|
||||
t.Fatal("newVisionImageProcessor returned nil")
|
||||
}
|
||||
|
||||
if p.shortestEdge != 65536 || p.longestEdge != 16777216 {
|
||||
t.Fatalf("processor size mismatch: shortest=%d longest=%d", p.shortestEdge, p.longestEdge)
|
||||
}
|
||||
}
|
||||
|
||||
func testTokenizer(t *testing.T) *tokenizer.Tokenizer {
|
||||
t.Helper()
|
||||
|
||||
tok, err := tokenizer.LoadFromBytes([]byte(`{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"a": 0},
|
||||
"merges": []
|
||||
}
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load test tokenizer: %v", err)
|
||||
}
|
||||
|
||||
return tok
|
||||
}
|
||||
|
||||
func TestTokenizePromptWithResolvedImagesStoresVisionSpans(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
m := &Model{
|
||||
tok: testTokenizer(t),
|
||||
Config: &Config{
|
||||
ImageTokenID: 101,
|
||||
VisionStartToken: 102,
|
||||
VisionEndToken: 103,
|
||||
Vision: &VisionConfig{SpatialMergeSize: 2},
|
||||
},
|
||||
Vision: &VisionModel{},
|
||||
ImageProcessor: &VisionImageProcessor{},
|
||||
}
|
||||
|
||||
main := mlx.FromValues([]float32{
|
||||
10, 11,
|
||||
20, 21,
|
||||
}, 1, 2, 2)
|
||||
|
||||
resolveCalls := 0
|
||||
got, state, err := m.tokenizePromptWithResolvedImages(
|
||||
"a[img-7][img-7]a",
|
||||
[]base.ImageInput{{ID: 7, Data: []byte("img7")}},
|
||||
func(data []byte) (*VisionEmbeddings, error) {
|
||||
if string(data) != "img7" {
|
||||
return nil, fmt.Errorf("unexpected data: %q", string(data))
|
||||
}
|
||||
resolveCalls++
|
||||
return &VisionEmbeddings{
|
||||
Main: main,
|
||||
Grid: &VisionGrid{Height: 2, Width: 2, Temporal: 1},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("tokenizePromptWithResolvedImages returned error: %v", err)
|
||||
}
|
||||
if resolveCalls != 1 {
|
||||
t.Fatalf("resolve calls mismatch: got %d want 1", resolveCalls)
|
||||
}
|
||||
|
||||
want := []int32{
|
||||
0,
|
||||
102, 101, 101, 103,
|
||||
102, 101, 101, 103,
|
||||
0,
|
||||
}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("expanded tokens mismatch: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
if state == nil {
|
||||
t.Fatal("expected prompt vision state")
|
||||
}
|
||||
if len(state.Spans) != 2 {
|
||||
t.Fatalf("prompt span count mismatch: got %d want 2", len(state.Spans))
|
||||
}
|
||||
if state.Spans[0].Start != 2 || state.Spans[0].End != 4 {
|
||||
t.Fatalf("first span mismatch: got [%d,%d)", state.Spans[0].Start, state.Spans[0].End)
|
||||
}
|
||||
if state.Spans[1].Start != 6 || state.Spans[1].End != 8 {
|
||||
t.Fatalf("second span mismatch: got [%d,%d)", state.Spans[1].Start, state.Spans[1].End)
|
||||
}
|
||||
wantPos := []int32{0, 1, 2, 2, 3, 4, 5, 5, 6, 7}
|
||||
if !slices.Equal(state.PositionCache, wantPos) {
|
||||
t.Fatalf("position cache mismatch: got %v want %v", state.PositionCache, wantPos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptMRoPEPositions(t *testing.T) {
|
||||
m := &Model{
|
||||
Config: &Config{
|
||||
Vision: &VisionConfig{SpatialMergeSize: 2},
|
||||
},
|
||||
}
|
||||
state := &promptVisionState{
|
||||
PositionCache: []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6},
|
||||
Spans: []promptVisionSpan{
|
||||
{
|
||||
Start: 2,
|
||||
End: 8,
|
||||
Grid: &VisionGrid{Height: 4, Width: 6, Temporal: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
pos := m.buildPromptMRoPEPositions(state, 0, 10)
|
||||
if got, want := pos[0], []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6}; !slices.Equal(got, want) {
|
||||
t.Fatalf("time positions mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := pos[1], []int32{0, 1, 2, 2, 2, 3, 3, 3, 5, 6}; !slices.Equal(got, want) {
|
||||
t.Fatalf("height positions mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := pos[2], []int32{0, 1, 2, 3, 4, 2, 3, 4, 5, 6}; !slices.Equal(got, want) {
|
||||
t.Fatalf("width positions mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapPromptPositionContinuesAfterCache(t *testing.T) {
|
||||
state := &promptVisionState{PositionCache: []int32{0, 1, 2, 2, 3}}
|
||||
|
||||
if got := mapPromptPosition(state, 3); got != 2 {
|
||||
t.Fatalf("mapPromptPosition(3) = %d, want 2", got)
|
||||
}
|
||||
if got := mapPromptPosition(state, 5); got != 4 {
|
||||
t.Fatalf("mapPromptPosition(5) = %d, want 4", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyPromptVisionEmbeddings(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
m := &Model{}
|
||||
state := &promptVisionState{
|
||||
Spans: []promptVisionSpan{
|
||||
{
|
||||
Start: 1,
|
||||
End: 3,
|
||||
Main: mlx.FromValues([]float32{
|
||||
10, 11,
|
||||
20, 21,
|
||||
}, 1, 2, 2),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h := mlx.FromValues([]float32{
|
||||
0, 1,
|
||||
2, 3,
|
||||
4, 5,
|
||||
6, 7,
|
||||
}, 1, 4, 2)
|
||||
|
||||
got := m.applyPromptVisionEmbeddings(h, 0, state)
|
||||
mlx.Eval(got)
|
||||
|
||||
want := []float32{
|
||||
0, 1,
|
||||
10, 11,
|
||||
20, 21,
|
||||
6, 7,
|
||||
}
|
||||
if !slices.Equal(got.Floats(), want) {
|
||||
t.Fatalf("embedding replacement mismatch: got %v want %v", got.Floats(), want)
|
||||
}
|
||||
}
|
||||
|
||||
854
x/models/qwen3_5/vision.go
Normal file
854
x/models/qwen3_5/vision.go
Normal file
@@ -0,0 +1,854 @@
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
mlxmodel "github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
var errNoVisionModel = errors.New("qwen3_5: no vision model")
|
||||
|
||||
// VisionConfig mirrors Qwen3.5/Qwen3-Next vision_config.
|
||||
type VisionConfig struct {
|
||||
Depth int32 `json:"depth"`
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHeads int32 `json:"num_heads"`
|
||||
InChannels int32 `json:"in_channels"`
|
||||
PatchSize int32 `json:"patch_size"`
|
||||
SpatialMergeSize int32 `json:"spatial_merge_size"`
|
||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
TemporalPatchSize int32 `json:"temporal_patch_size"`
|
||||
NumPositionEmbeddings int32 `json:"num_position_embeddings"`
|
||||
|
||||
Size struct {
|
||||
ShortestEdge int32 `json:"shortest_edge"`
|
||||
LongestEdge int32 `json:"longest_edge"`
|
||||
} `json:"size"`
|
||||
|
||||
ImageMean []float32 `json:"image_mean"`
|
||||
ImageStd []float32 `json:"image_std"`
|
||||
|
||||
GridPerSide int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (v *VisionConfig) applyDefaults() {
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
if v.HiddenSize <= 0 {
|
||||
v.HiddenSize = 1280
|
||||
}
|
||||
if v.NumHeads <= 0 {
|
||||
v.NumHeads = 16
|
||||
}
|
||||
if v.InChannels <= 0 {
|
||||
v.InChannels = 3
|
||||
}
|
||||
if v.PatchSize <= 0 {
|
||||
v.PatchSize = 14
|
||||
}
|
||||
if v.SpatialMergeSize <= 0 {
|
||||
v.SpatialMergeSize = 2
|
||||
}
|
||||
if v.LayerNormEpsilon == 0 {
|
||||
v.LayerNormEpsilon = 1e-6
|
||||
}
|
||||
if v.RopeTheta == 0 {
|
||||
v.RopeTheta = 10000
|
||||
}
|
||||
if v.TemporalPatchSize <= 0 {
|
||||
v.TemporalPatchSize = 2
|
||||
}
|
||||
if v.NumPositionEmbeddings <= 0 {
|
||||
v.NumPositionEmbeddings = 2304
|
||||
}
|
||||
if len(v.ImageMean) < 3 {
|
||||
v.ImageMean = []float32{0.5, 0.5, 0.5}
|
||||
}
|
||||
if len(v.ImageStd) < 3 {
|
||||
v.ImageStd = []float32{0.5, 0.5, 0.5}
|
||||
}
|
||||
if v.Size.ShortestEdge <= 0 {
|
||||
v.Size.ShortestEdge = 64 << 10
|
||||
}
|
||||
if v.Size.LongestEdge <= 0 {
|
||||
v.Size.LongestEdge = 2 << 20
|
||||
}
|
||||
|
||||
grid := int32(math.Sqrt(float64(v.NumPositionEmbeddings)))
|
||||
if grid <= 0 {
|
||||
grid = 48
|
||||
}
|
||||
v.GridPerSide = grid
|
||||
}
|
||||
|
||||
func (v *VisionConfig) applyPreprocessorConfig(data []byte) {
|
||||
if v == nil || len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var pre struct {
|
||||
Size struct {
|
||||
ShortestEdge int32 `json:"shortest_edge"`
|
||||
LongestEdge int32 `json:"longest_edge"`
|
||||
} `json:"size"`
|
||||
PatchSize int32 `json:"patch_size"`
|
||||
TemporalPatchSize int32 `json:"temporal_patch_size"`
|
||||
MergeSize int32 `json:"merge_size"`
|
||||
ImageMean []float32 `json:"image_mean"`
|
||||
ImageStd []float32 `json:"image_std"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &pre); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if pre.PatchSize > 0 {
|
||||
v.PatchSize = pre.PatchSize
|
||||
}
|
||||
if pre.TemporalPatchSize > 0 {
|
||||
v.TemporalPatchSize = pre.TemporalPatchSize
|
||||
}
|
||||
if pre.MergeSize > 0 {
|
||||
v.SpatialMergeSize = pre.MergeSize
|
||||
}
|
||||
if pre.Size.ShortestEdge > 0 {
|
||||
v.Size.ShortestEdge = pre.Size.ShortestEdge
|
||||
}
|
||||
if pre.Size.LongestEdge > 0 {
|
||||
v.Size.LongestEdge = pre.Size.LongestEdge
|
||||
}
|
||||
if len(pre.ImageMean) >= 3 {
|
||||
v.ImageMean = pre.ImageMean
|
||||
}
|
||||
if len(pre.ImageStd) >= 3 {
|
||||
v.ImageStd = pre.ImageStd
|
||||
}
|
||||
v.applyDefaults()
|
||||
}
|
||||
|
||||
// VisionGrid tracks patch-grid dimensions for an image.
|
||||
type VisionGrid struct {
|
||||
Height int32
|
||||
Width int32
|
||||
Temporal int32
|
||||
}
|
||||
|
||||
// VisionImageProcessor reproduces qwen3vl image preprocessing.
|
||||
type VisionImageProcessor struct {
|
||||
numChannels int32
|
||||
patchSize int32
|
||||
temporalPatchSize int32
|
||||
mergeSize int32
|
||||
shortestEdge int32
|
||||
longestEdge int32
|
||||
factor int32
|
||||
imageMean [3]float32
|
||||
imageStd [3]float32
|
||||
}
|
||||
|
||||
func newVisionImageProcessor(cfg *VisionConfig) *VisionImageProcessor {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &VisionImageProcessor{
|
||||
numChannels: cfg.InChannels,
|
||||
patchSize: cfg.PatchSize,
|
||||
temporalPatchSize: cfg.TemporalPatchSize,
|
||||
mergeSize: cfg.SpatialMergeSize,
|
||||
shortestEdge: cfg.Size.ShortestEdge,
|
||||
longestEdge: cfg.Size.LongestEdge,
|
||||
factor: cfg.PatchSize * cfg.SpatialMergeSize,
|
||||
imageMean: [3]float32{cfg.ImageMean[0], cfg.ImageMean[1], cfg.ImageMean[2]},
|
||||
imageStd: [3]float32{cfg.ImageStd[0], cfg.ImageStd[1], cfg.ImageStd[2]},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VisionImageProcessor) smartResize(height, width int) (int, int, error) {
|
||||
factor := int(p.factor)
|
||||
if factor <= 0 {
|
||||
return 0, 0, fmt.Errorf("invalid factor: %d", factor)
|
||||
}
|
||||
|
||||
if height < factor || width < factor {
|
||||
return 0, 0, fmt.Errorf("height (%d) or width (%d) must be >= factor (%d)", height, width, factor)
|
||||
}
|
||||
if min(height, width) == 0 {
|
||||
return 0, 0, fmt.Errorf("invalid dimensions: %dx%d", width, height)
|
||||
}
|
||||
if max(height, width)/min(height, width) > 200 {
|
||||
return 0, 0, fmt.Errorf("aspect ratio too large: %dx%d", width, height)
|
||||
}
|
||||
|
||||
roundEven := func(x float64) int { return int(math.RoundToEven(x)) }
|
||||
|
||||
hBar := roundEven(float64(height)/float64(factor)) * factor
|
||||
wBar := roundEven(float64(width)/float64(factor)) * factor
|
||||
|
||||
if hBar*wBar > int(p.longestEdge) {
|
||||
beta := math.Sqrt(float64(height*width) / float64(p.longestEdge))
|
||||
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||
} else if hBar*wBar < int(p.shortestEdge) {
|
||||
beta := math.Sqrt(float64(p.shortestEdge) / float64(height*width))
|
||||
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar, nil
|
||||
}
|
||||
|
||||
func (p *VisionImageProcessor) ProcessImage(img image.Image) (*mlx.Array, *VisionGrid, error) {
|
||||
if p == nil {
|
||||
return nil, nil, errNoVisionModel
|
||||
}
|
||||
|
||||
img = imageproc.Composite(img)
|
||||
origW := img.Bounds().Dx()
|
||||
origH := img.Bounds().Dy()
|
||||
|
||||
resizedH, resizedW, err := p.smartResize(origH, origW)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
resized := imageproc.Resize(
|
||||
img,
|
||||
image.Point{X: resizedW, Y: resizedH},
|
||||
imageproc.ResizeBilinear,
|
||||
)
|
||||
pixels := imageproc.Normalize(resized, p.imageMean, p.imageStd, true, true)
|
||||
|
||||
grid := &VisionGrid{
|
||||
Height: int32(resizedH / int(p.patchSize)),
|
||||
Width: int32(resizedW / int(p.patchSize)),
|
||||
Temporal: 1,
|
||||
}
|
||||
|
||||
patches := p.createPatches(pixels, resizedH, resizedW, grid)
|
||||
|
||||
patchDim := int(p.numChannels * p.temporalPatchSize * p.patchSize * p.patchSize)
|
||||
numPatches := int(grid.Height * grid.Width)
|
||||
pixelValues := mlx.FromValues(patches, numPatches, patchDim).ExpandDims(0)
|
||||
return pixelValues, grid, nil
|
||||
}
|
||||
|
||||
func (p *VisionImageProcessor) createPatches(pixels []float32, height, width int, grid *VisionGrid) []float32 {
|
||||
channels := int(p.numChannels)
|
||||
patchSize := int(p.patchSize)
|
||||
mergeSize := int(p.mergeSize)
|
||||
temporalPatchSize := int(p.temporalPatchSize)
|
||||
|
||||
// Temporal is always 1 for static images; only spatial patches are created.
|
||||
numPatches := int(grid.Height * grid.Width)
|
||||
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||
result := make([]float32, numPatches*patchDim)
|
||||
|
||||
patchIndex := 0
|
||||
for h := 0; h < int(grid.Height); h += mergeSize {
|
||||
for w := 0; w < int(grid.Width); w += mergeSize {
|
||||
for mh := range mergeSize {
|
||||
for mw := range mergeSize {
|
||||
baseOffset := patchIndex * patchDim
|
||||
|
||||
for c := range channels {
|
||||
channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
|
||||
for py := range patchSize {
|
||||
for px := range patchSize {
|
||||
y := (h+mh)*patchSize + py
|
||||
x := (w+mw)*patchSize + px
|
||||
srcIdx := c*height*width + y*width + x
|
||||
dstIdx := channelOffset + py*patchSize + px
|
||||
if srcIdx < len(pixels) && dstIdx < len(result) {
|
||||
result[dstIdx] = pixels[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if temporalPatchSize > 1 {
|
||||
for c := range channels {
|
||||
channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
|
||||
frameSize := patchSize * patchSize
|
||||
for tp := 1; tp < temporalPatchSize; tp++ {
|
||||
cur := channelOffset + tp*frameSize
|
||||
copy(result[cur:cur+frameSize], result[channelOffset:channelOffset+frameSize])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
patchIndex++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// VisionAttention runs one self-attention block inside the vision encoder.
|
||||
type VisionAttention struct {
|
||||
QKV nn.LinearLayer
|
||||
Query nn.LinearLayer
|
||||
Key nn.LinearLayer
|
||||
Value nn.LinearLayer
|
||||
Output nn.LinearLayer
|
||||
}
|
||||
|
||||
func applyVisionRoPE(x, cos, sin *mlx.Array) *mlx.Array {
|
||||
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(rotateHalf(x), sin))
|
||||
}
|
||||
|
||||
func (a *VisionAttention) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
|
||||
shape := x.Dims()
|
||||
if len(shape) != 3 {
|
||||
return nil, fmt.Errorf("vision attention expects [B,L,D], got %v", shape)
|
||||
}
|
||||
B, L, hidden := int32(shape[0]), int32(shape[1]), int32(shape[2])
|
||||
headDim := cfg.HiddenSize / cfg.NumHeads
|
||||
if headDim <= 0 {
|
||||
return nil, fmt.Errorf("invalid vision head dim: %d", headDim)
|
||||
}
|
||||
|
||||
var q, k, v *mlx.Array
|
||||
if a.QKV != nil {
|
||||
qkv := a.QKV.Forward(x)
|
||||
qkv = mlx.Reshape(qkv, B, L, 3, cfg.NumHeads, headDim)
|
||||
q = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, cfg.NumHeads, headDim}), 2)
|
||||
k = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, cfg.NumHeads, headDim}), 2)
|
||||
v = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, cfg.NumHeads, headDim}), 2)
|
||||
} else {
|
||||
if a.Query == nil || a.Key == nil || a.Value == nil {
|
||||
return nil, errors.New("vision attention is missing q/k/v projections")
|
||||
}
|
||||
q = mlx.Reshape(a.Query.Forward(x), B, L, cfg.NumHeads, headDim)
|
||||
k = mlx.Reshape(a.Key.Forward(x), B, L, cfg.NumHeads, headDim)
|
||||
v = mlx.Reshape(a.Value.Forward(x), B, L, cfg.NumHeads, headDim)
|
||||
}
|
||||
|
||||
q = applyVisionRoPE(q, cos, sin)
|
||||
k = applyVisionRoPE(k, cos, sin)
|
||||
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(headDim)))
|
||||
attn := mlx.ScaledDotProductAttentionCausal(q, k, v, scale, false)
|
||||
attn = mlx.Reshape(mlx.Transpose(attn, 0, 2, 1, 3), B, L, hidden)
|
||||
if a.Output == nil {
|
||||
return nil, errors.New("vision attention is missing output projection")
|
||||
}
|
||||
return a.Output.Forward(attn), nil
|
||||
}
|
||||
|
||||
// VisionMLP is the vision feed-forward block.
|
||||
type VisionMLP struct {
|
||||
FC1 nn.LinearLayer
|
||||
FC2 nn.LinearLayer
|
||||
}
|
||||
|
||||
func (m *VisionMLP) Forward(x *mlx.Array) (*mlx.Array, error) {
|
||||
if m.FC1 == nil || m.FC2 == nil {
|
||||
return nil, errors.New("vision mlp is missing fc1/fc2")
|
||||
}
|
||||
return m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(x))), nil
|
||||
}
|
||||
|
||||
// VisionEncoderLayer is one transformer block in the vision encoder.
|
||||
type VisionEncoderLayer struct {
|
||||
Norm1 *nn.LayerNorm
|
||||
Attn *VisionAttention
|
||||
Norm2 *nn.LayerNorm
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (l *VisionEncoderLayer) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
|
||||
if l.Norm1 == nil || l.Norm2 == nil || l.Attn == nil || l.MLP == nil {
|
||||
return nil, errors.New("vision layer is incomplete")
|
||||
}
|
||||
|
||||
r := x
|
||||
a, err := l.Attn.Forward(l.Norm1.Forward(x), cos, sin, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x = mlx.Add(r, a)
|
||||
|
||||
r = x
|
||||
m, err := l.MLP.Forward(l.Norm2.Forward(x))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mlx.Add(r, m), nil
|
||||
}
|
||||
|
||||
// VisionPatchMerger projects merged spatial groups into language embedding space.
|
||||
type VisionPatchMerger struct {
|
||||
Norm *nn.LayerNorm
|
||||
FC1 nn.LinearLayer
|
||||
FC2 nn.LinearLayer
|
||||
}
|
||||
|
||||
func groupMergedTokens(x *mlx.Array, merge int32) (*mlx.Array, error) {
|
||||
shape := x.Dims()
|
||||
if len(shape) != 3 {
|
||||
return nil, fmt.Errorf("expected [B,L,D], got %v", shape)
|
||||
}
|
||||
if merge <= 0 {
|
||||
merge = 1
|
||||
}
|
||||
B, L, D := int32(shape[0]), int32(shape[1]), int32(shape[2])
|
||||
group := merge * merge
|
||||
if group <= 0 || L%group != 0 {
|
||||
return nil, fmt.Errorf("invalid merge layout: L=%d merge=%d", L, merge)
|
||||
}
|
||||
|
||||
x = mlx.Reshape(x, B, L/group, group, D)
|
||||
x = mlx.Reshape(x, B, L/group, group*D)
|
||||
return x, nil
|
||||
}
|
||||
|
||||
func (m *VisionPatchMerger) Forward(x *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
|
||||
if m == nil || m.Norm == nil || m.FC1 == nil || m.FC2 == nil {
|
||||
return nil, errors.New("vision patch merger is incomplete")
|
||||
}
|
||||
|
||||
x = m.Norm.Forward(x)
|
||||
|
||||
var err error
|
||||
x, err = groupMergedTokens(x, cfg.SpatialMergeSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
x = m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(x)))
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// VisionModel contains the full Qwen vision tower.
|
||||
type VisionModel struct {
|
||||
PatchProjection nn.LinearLayer
|
||||
PositionEmbed *nn.Embedding
|
||||
Layers []*VisionEncoderLayer
|
||||
PatchMerger *VisionPatchMerger
|
||||
|
||||
cfg *VisionConfig
|
||||
}
|
||||
|
||||
func mergedPatchCoordinates(grid *VisionGrid, merge int32) [][2]int32 {
|
||||
if merge <= 0 {
|
||||
merge = 1
|
||||
}
|
||||
// Temporal is always 1 for static images; only spatial coordinates are generated.
|
||||
coords := make([][2]int32, 0, grid.Height*grid.Width)
|
||||
for h := int32(0); h < grid.Height; h += merge {
|
||||
for w := int32(0); w < grid.Width; w += merge {
|
||||
for mh := range merge {
|
||||
for mw := range merge {
|
||||
coords = append(coords, [2]int32{h + mh, w + mw})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return coords
|
||||
}
|
||||
|
||||
func (m *VisionModel) addPositionEmbedding(x *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
|
||||
if m.PositionEmbed == nil {
|
||||
return x, nil
|
||||
}
|
||||
shape := x.Dims()
|
||||
if len(shape) != 3 {
|
||||
return nil, fmt.Errorf("vision embeddings expect [B,L,D], got %v", shape)
|
||||
}
|
||||
B, D := int32(shape[0]), int32(shape[2])
|
||||
coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
|
||||
L := int32(len(coords))
|
||||
if L != int32(shape[1]) {
|
||||
return nil, fmt.Errorf("vision sequence mismatch: hidden L=%d coords=%d", shape[1], L)
|
||||
}
|
||||
|
||||
stepH := float32(0)
|
||||
if grid.Height > 1 {
|
||||
stepH = float32(m.cfg.GridPerSide-1) / float32(grid.Height-1)
|
||||
}
|
||||
stepW := float32(0)
|
||||
if grid.Width > 1 {
|
||||
stepW = float32(m.cfg.GridPerSide-1) / float32(grid.Width-1)
|
||||
}
|
||||
|
||||
indices := make([]int32, 0, L*4)
|
||||
weights := make([]float32, 0, L*4)
|
||||
for _, c := range coords {
|
||||
y := float32(c[0]) * stepH
|
||||
x0 := float32(c[1]) * stepW
|
||||
|
||||
fy := int32(y)
|
||||
fx := int32(x0)
|
||||
cy := min(fy+1, m.cfg.GridPerSide-1)
|
||||
cx := min(fx+1, m.cfg.GridPerSide-1)
|
||||
|
||||
indices = append(indices,
|
||||
fy*m.cfg.GridPerSide+fx,
|
||||
fy*m.cfg.GridPerSide+cx,
|
||||
cy*m.cfg.GridPerSide+fx,
|
||||
cy*m.cfg.GridPerSide+cx,
|
||||
)
|
||||
|
||||
dy := y - float32(fy)
|
||||
dx := x0 - float32(fx)
|
||||
weights = append(weights,
|
||||
(1-dy)*(1-dx),
|
||||
(1-dy)*dx,
|
||||
dy*(1-dx),
|
||||
dy*dx,
|
||||
)
|
||||
}
|
||||
|
||||
idxArr := mlx.FromValues(indices, int(L), 4)
|
||||
wArr := mlx.FromValues(weights, int(L), 4, 1)
|
||||
|
||||
pos := m.PositionEmbed.Forward(idxArr)
|
||||
wArr = wArr.AsType(pos.DType())
|
||||
pos = mlx.Sum(mlx.Mul(pos, wArr), 1, false)
|
||||
if D != int32(pos.Dim(1)) {
|
||||
return nil, fmt.Errorf("position embedding dim mismatch: hidden=%d pos=%d", D, pos.Dim(1))
|
||||
}
|
||||
|
||||
pos = mlx.ExpandDims(pos, 0)
|
||||
if B > 1 {
|
||||
pos = mlx.Tile(pos, []int32{B, 1, 1})
|
||||
}
|
||||
|
||||
return mlx.Add(x, pos), nil
|
||||
}
|
||||
|
||||
func (m *VisionModel) rotaryEmbeddings(grid *VisionGrid) (*mlx.Array, *mlx.Array, error) {
|
||||
headDim := m.cfg.HiddenSize / m.cfg.NumHeads
|
||||
if headDim <= 0 {
|
||||
return nil, nil, fmt.Errorf("invalid vision head dim: %d", headDim)
|
||||
}
|
||||
|
||||
coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
|
||||
L := int32(len(coords))
|
||||
half := headDim / 2
|
||||
quarter := half / 2
|
||||
if quarter <= 0 {
|
||||
return nil, nil, fmt.Errorf("invalid vision rotary layout: head_dim=%d", headDim)
|
||||
}
|
||||
|
||||
angles := make([]float32, L*headDim)
|
||||
for i, c := range coords {
|
||||
base := int32(i) * headDim
|
||||
for j := range quarter {
|
||||
freq := 1.0 / math.Pow(float64(m.cfg.RopeTheta), float64(2*j)/float64(half))
|
||||
angles[base+j] = float32(float64(c[0]) * freq)
|
||||
angles[base+quarter+j] = float32(float64(c[1]) * freq)
|
||||
}
|
||||
for j := range half {
|
||||
angles[base+half+j] = angles[base+j]
|
||||
}
|
||||
}
|
||||
|
||||
arr := mlx.FromValues(angles, int(L), int(headDim))
|
||||
cos := mlx.ExpandDims(mlx.ExpandDims(mlx.Cos(arr), 0), 2)
|
||||
sin := mlx.ExpandDims(mlx.ExpandDims(mlx.Sin(arr), 0), 2)
|
||||
return cos, sin, nil
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(pixelValues *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
|
||||
if m == nil || pixelValues == nil || grid == nil {
|
||||
return nil, errNoVisionModel
|
||||
}
|
||||
if m.PatchProjection == nil || m.PatchMerger == nil {
|
||||
return nil, errors.New("vision model is missing required projections")
|
||||
}
|
||||
|
||||
x := m.PatchProjection.Forward(pixelValues)
|
||||
var err error
|
||||
x, err = m.addPositionEmbedding(x, grid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cos, sin, err := m.rotaryEmbeddings(grid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
x, err = layer.Forward(x, cos, sin, m.cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("vision layer %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
main, err := m.PatchMerger.Forward(x, m.cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("vision patch merger: %w", err)
|
||||
}
|
||||
return main, nil
|
||||
}
|
||||
|
||||
type VisionEmbeddings struct {
|
||||
Main *mlx.Array
|
||||
Grid *VisionGrid
|
||||
}
|
||||
|
||||
func (m *Model) EncodeVisionImage(multimodalData []byte) (*VisionEmbeddings, error) {
|
||||
if m == nil || m.Vision == nil || m.ImageProcessor == nil {
|
||||
return nil, errNoVisionModel
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues, grid, err := m.ImageProcessor.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
main, err := m.Vision.Forward(pixelValues, grid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &VisionEmbeddings{Main: main, Grid: grid}, nil
|
||||
}
|
||||
|
||||
func resolveVisionPrefix(tensors map[string]*mlx.Array, weightPrefix string) string {
|
||||
candidates := []string{
|
||||
"vision_tower",
|
||||
weightPrefix + "vision_tower",
|
||||
"model.visual",
|
||||
"visual",
|
||||
weightPrefix + "model.visual",
|
||||
weightPrefix + "visual",
|
||||
}
|
||||
|
||||
hasTensor := func(prefix string) bool {
|
||||
for _, suffix := range []string{
|
||||
".patch_embed.proj.weight",
|
||||
".patch_embed.weight",
|
||||
".pos_embed.weight",
|
||||
".blocks.0.attn.qkv.weight",
|
||||
".merger.linear_fc1.weight",
|
||||
".merger.mlp.0.weight",
|
||||
} {
|
||||
if tensors[prefix+suffix] != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, prefix := range candidates {
|
||||
if hasTensor(prefix) {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func firstLinear(linears mlxmodel.LinearFactory, paths ...string) nn.LinearLayer {
|
||||
for _, p := range paths {
|
||||
if l := linears.Make(p); l != nil {
|
||||
return l
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadLayerNorm(tensors map[string]*mlx.Array, eps float32, bases ...string) *nn.LayerNorm {
|
||||
for _, base := range bases {
|
||||
if w := tensors[base+".weight"]; w != nil {
|
||||
return &nn.LayerNorm{Weight: w, Bias: tensors[base+".bias"], Eps: eps}
|
||||
}
|
||||
if w := tensors[base]; w != nil {
|
||||
return &nn.LayerNorm{Weight: w, Bias: tensors[base+"_bias"], Eps: eps}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadVisionPatchMerger(
|
||||
tensors map[string]*mlx.Array,
|
||||
linears mlxmodel.LinearFactory,
|
||||
eps float32,
|
||||
bases ...string,
|
||||
) *VisionPatchMerger {
|
||||
for _, base := range bases {
|
||||
norm := loadLayerNorm(tensors, eps, base+".norm", base+".ln_q")
|
||||
fc1 := firstLinear(linears, base+".linear_fc1", base+".mlp.0")
|
||||
fc2 := firstLinear(linears, base+".linear_fc2", base+".mlp.2")
|
||||
if norm != nil && fc1 != nil && fc2 != nil {
|
||||
return &VisionPatchMerger{Norm: norm, FC1: fc1, FC2: fc2}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func flattenPatchEmbeddingWeight(w *mlx.Array) (*mlx.Array, error) {
|
||||
if w == nil || !w.Valid() {
|
||||
return nil, errors.New("missing patch embedding weight")
|
||||
}
|
||||
if w.NumDims() < 2 {
|
||||
return nil, fmt.Errorf("patch embedding weight must be >=2D, got %dD", w.NumDims())
|
||||
}
|
||||
if w.NumDims() == 2 {
|
||||
return w, nil
|
||||
}
|
||||
|
||||
out := int32(w.Dim(0))
|
||||
in := int32(w.Size() / w.Dim(0))
|
||||
return mlx.Reshape(w, out, in), nil
|
||||
}
|
||||
|
||||
func loadVisionComponents(
|
||||
tensors map[string]*mlx.Array,
|
||||
linears mlxmodel.LinearFactory,
|
||||
cfg *Config,
|
||||
weightPrefix string,
|
||||
) (*VisionModel, *VisionImageProcessor, error) {
|
||||
if cfg == nil || cfg.Vision == nil || cfg.Vision.Depth <= 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
cfg.Vision.applyDefaults()
|
||||
|
||||
visionPrefix := resolveVisionPrefix(tensors, weightPrefix)
|
||||
if visionPrefix == "" {
|
||||
return nil, nil, errors.New("vision enabled in config but vision tensors were not found")
|
||||
}
|
||||
|
||||
patchW, _ := tensorAny(
|
||||
tensors,
|
||||
visionPrefix+".patch_embed.proj.weight",
|
||||
visionPrefix+".patch_embed.weight",
|
||||
)
|
||||
if patchW == nil {
|
||||
return nil, nil, fmt.Errorf("missing vision patch embedding weight under %s", visionPrefix)
|
||||
}
|
||||
patchW, err := flattenPatchEmbeddingWeight(patchW)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
patchB, _ := tensorAny(
|
||||
tensors,
|
||||
visionPrefix+".patch_embed.proj.bias",
|
||||
visionPrefix+".patch_embed.bias",
|
||||
)
|
||||
|
||||
patchProj := nn.NewLinear(patchW, patchB)
|
||||
if got := int32(patchW.Dim(1)); got != cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize {
|
||||
return nil, nil, fmt.Errorf(
|
||||
"vision patch embedding input dim mismatch: got %d expected %d",
|
||||
got,
|
||||
cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize,
|
||||
)
|
||||
}
|
||||
|
||||
posW, _ := tensorAny(
|
||||
tensors,
|
||||
visionPrefix+".pos_embed.weight",
|
||||
visionPrefix+".position_embedding.weight",
|
||||
)
|
||||
if posW == nil {
|
||||
return nil, nil, fmt.Errorf("missing vision position embedding under %s", visionPrefix)
|
||||
}
|
||||
cfg.Vision.NumPositionEmbeddings = int32(posW.Dim(0))
|
||||
cfg.Vision.applyDefaults()
|
||||
|
||||
vm := &VisionModel{
|
||||
PatchProjection: patchProj,
|
||||
PositionEmbed: nn.NewEmbedding(posW),
|
||||
Layers: make([]*VisionEncoderLayer, cfg.Vision.Depth),
|
||||
cfg: cfg.Vision,
|
||||
}
|
||||
|
||||
for i := range cfg.Vision.Depth {
|
||||
layerPrefix := fmt.Sprintf("%s.blocks.%d", visionPrefix, i)
|
||||
layer := &VisionEncoderLayer{
|
||||
Norm1: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm1"),
|
||||
Norm2: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm2"),
|
||||
Attn: &VisionAttention{
|
||||
QKV: firstLinear(
|
||||
linears,
|
||||
layerPrefix+".attn.qkv",
|
||||
layerPrefix+".attn_qkv",
|
||||
),
|
||||
Query: firstLinear(
|
||||
linears,
|
||||
layerPrefix+".attn.q_proj",
|
||||
layerPrefix+".attn_q",
|
||||
),
|
||||
Key: firstLinear(
|
||||
linears,
|
||||
layerPrefix+".attn.k_proj",
|
||||
layerPrefix+".attn_k",
|
||||
),
|
||||
Value: firstLinear(
|
||||
linears,
|
||||
layerPrefix+".attn.v_proj",
|
||||
layerPrefix+".attn_v",
|
||||
),
|
||||
Output: firstLinear(
|
||||
linears,
|
||||
layerPrefix+".attn.proj",
|
||||
layerPrefix+".attn_out",
|
||||
layerPrefix+".attn.o_proj",
|
||||
),
|
||||
},
|
||||
MLP: &VisionMLP{
|
||||
FC1: firstLinear(
|
||||
linears,
|
||||
layerPrefix+".mlp.fc1",
|
||||
layerPrefix+".mlp.linear_fc1",
|
||||
),
|
||||
FC2: firstLinear(
|
||||
linears,
|
||||
layerPrefix+".mlp.fc2",
|
||||
layerPrefix+".mlp.linear_fc2",
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
if layer.Norm1 == nil || layer.Norm2 == nil {
|
||||
return nil, nil, fmt.Errorf("vision layer %d: missing norm1/norm2", i)
|
||||
}
|
||||
if layer.Attn.Output == nil || (layer.Attn.QKV == nil && (layer.Attn.Query == nil || layer.Attn.Key == nil || layer.Attn.Value == nil)) {
|
||||
return nil, nil, fmt.Errorf("vision layer %d: missing attention projections", i)
|
||||
}
|
||||
if layer.MLP.FC1 == nil || layer.MLP.FC2 == nil {
|
||||
return nil, nil, fmt.Errorf("vision layer %d: missing mlp projections", i)
|
||||
}
|
||||
|
||||
vm.Layers[i] = layer
|
||||
}
|
||||
|
||||
vm.PatchMerger = loadVisionPatchMerger(
|
||||
tensors,
|
||||
linears,
|
||||
cfg.Vision.LayerNormEpsilon,
|
||||
visionPrefix+".merger",
|
||||
)
|
||||
if vm.PatchMerger == nil {
|
||||
return nil, nil, fmt.Errorf("missing vision patch merger under %s", visionPrefix)
|
||||
}
|
||||
|
||||
return vm, newVisionImageProcessor(cfg.Vision), nil
|
||||
}
|
||||
Reference in New Issue
Block a user