mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 11:54:36 +02:00
Compare commits
5 Commits
pdevine/ml
...
pdevine/qw
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
578c32e42e | ||
|
|
a10d2625ca | ||
|
|
b960d769ad | ||
|
|
455a6099d1 | ||
|
|
7e6e8377eb |
@@ -134,14 +134,18 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||||||
spinnerKey = "create"
|
spinnerKey = "create"
|
||||||
capabilities = []string{"completion"}
|
capabilities = []string{"completion"}
|
||||||
|
|
||||||
// Check if model supports thinking based on architecture
|
configData, _ := os.ReadFile(filepath.Join(opts.ModelDir, "config.json"))
|
||||||
if supportsThinking(opts.ModelDir) {
|
mcfg := parseModelConfig(configData)
|
||||||
|
|
||||||
|
if mcfg.supportsThinking() {
|
||||||
capabilities = append(capabilities, "thinking")
|
capabilities = append(capabilities, "thinking")
|
||||||
}
|
}
|
||||||
|
if mcfg.supportsVision() {
|
||||||
|
capabilities = append(capabilities, "vision")
|
||||||
|
}
|
||||||
|
|
||||||
// Set parser and renderer name based on architecture
|
parserName = mcfg.parserName()
|
||||||
parserName = getParserName(opts.ModelDir)
|
rendererName = mcfg.rendererName()
|
||||||
rendererName = getRendererName(opts.ModelDir)
|
|
||||||
} else {
|
} else {
|
||||||
modelType = "image generation model"
|
modelType = "image generation model"
|
||||||
spinnerKey = "imagegen"
|
spinnerKey = "imagegen"
|
||||||
@@ -438,145 +442,76 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
|||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// supportsThinking checks if the model supports thinking mode based on its architecture.
|
// modelConfig holds the fields from config.json needed during model creation.
|
||||||
// This reads the config.json from the model directory and checks the architectures field.
|
type visionConfig struct {
|
||||||
func supportsThinking(modelDir string) bool {
|
Depth int32 `json:"depth"`
|
||||||
configPath := filepath.Join(modelDir, "config.json")
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg struct {
|
type modelConfig struct {
|
||||||
Architectures []string `json:"architectures"`
|
Architectures []string `json:"architectures"`
|
||||||
ModelType string `json:"model_type"`
|
ModelType string `json:"model_type"`
|
||||||
}
|
VisionConfig *visionConfig `json:"vision_config"`
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
ImageTokenID *int32 `json:"image_token_id"`
|
||||||
return false
|
VisionStartTokenID *int32 `json:"vision_start_token_id"`
|
||||||
|
VisionEndTokenID *int32 `json:"vision_end_token_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check architectures that support thinking
|
func parseModelConfig(data []byte) modelConfig {
|
||||||
thinkingArchitectures := []string{
|
var cfg modelConfig
|
||||||
"glm4moe", // GLM-4 MoE models
|
_ = json.Unmarshal(data, &cfg)
|
||||||
"deepseek", // DeepSeek models
|
return cfg
|
||||||
"qwen3", // Qwen3 models
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the architecture list
|
// archOrTypeContains returns true if any architecture or the model_type
|
||||||
for _, arch := range cfg.Architectures {
|
// contains one of the given substrings (case-insensitive).
|
||||||
|
func (c *modelConfig) archOrTypeContains(substrs ...string) bool {
|
||||||
|
for _, arch := range c.Architectures {
|
||||||
archLower := strings.ToLower(arch)
|
archLower := strings.ToLower(arch)
|
||||||
for _, thinkArch := range thinkingArchitectures {
|
for _, s := range substrs {
|
||||||
if strings.Contains(archLower, thinkArch) {
|
if strings.Contains(archLower, s) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if c.ModelType != "" {
|
||||||
// Also check model_type
|
typeLower := strings.ToLower(c.ModelType)
|
||||||
if cfg.ModelType != "" {
|
for _, s := range substrs {
|
||||||
typeLower := strings.ToLower(cfg.ModelType)
|
if strings.Contains(typeLower, s) {
|
||||||
for _, thinkArch := range thinkingArchitectures {
|
|
||||||
if strings.Contains(typeLower, thinkArch) {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// getParserName returns the parser name for a model based on its architecture.
|
func (c *modelConfig) supportsThinking() bool {
|
||||||
// This reads the config.json from the model directory and determines the appropriate parser.
|
return c.archOrTypeContains("glm4moe", "deepseek", "qwen3")
|
||||||
func getParserName(modelDir string) string {
|
|
||||||
configPath := filepath.Join(modelDir, "config.json")
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg struct {
|
func (c *modelConfig) supportsVision() bool {
|
||||||
Architectures []string `json:"architectures"`
|
return c.VisionConfig != nil || c.ImageTokenID != nil || c.VisionStartTokenID != nil || c.VisionEndTokenID != nil
|
||||||
ModelType string `json:"model_type"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check architectures for known parsers
|
func (c *modelConfig) parserName() string {
|
||||||
for _, arch := range cfg.Architectures {
|
switch {
|
||||||
archLower := strings.ToLower(arch)
|
case c.archOrTypeContains("glm4", "glm-4"):
|
||||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
|
||||||
return "glm-4.7"
|
return "glm-4.7"
|
||||||
}
|
case c.archOrTypeContains("deepseek"):
|
||||||
if strings.Contains(archLower, "deepseek") {
|
|
||||||
return "deepseek3"
|
return "deepseek3"
|
||||||
}
|
case c.archOrTypeContains("qwen3"):
|
||||||
if strings.Contains(archLower, "qwen3") {
|
|
||||||
return "qwen3"
|
return "qwen3"
|
||||||
}
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also check model_type
|
func (c *modelConfig) rendererName() string {
|
||||||
if cfg.ModelType != "" {
|
switch {
|
||||||
typeLower := strings.ToLower(cfg.ModelType)
|
case c.archOrTypeContains("glm4", "glm-4"):
|
||||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
|
||||||
return "glm-4.7"
|
return "glm-4.7"
|
||||||
}
|
case c.archOrTypeContains("deepseek"):
|
||||||
if strings.Contains(typeLower, "deepseek") {
|
|
||||||
return "deepseek3"
|
return "deepseek3"
|
||||||
}
|
case c.archOrTypeContains("qwen3"):
|
||||||
if strings.Contains(typeLower, "qwen3") {
|
|
||||||
return "qwen3"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRendererName returns the renderer name for a model based on its architecture.
|
|
||||||
// This reads the config.json from the model directory and determines the appropriate renderer.
|
|
||||||
func getRendererName(modelDir string) string {
|
|
||||||
configPath := filepath.Join(modelDir, "config.json")
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var cfg struct {
|
|
||||||
Architectures []string `json:"architectures"`
|
|
||||||
ModelType string `json:"model_type"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check architectures for known renderers
|
|
||||||
for _, arch := range cfg.Architectures {
|
|
||||||
archLower := strings.ToLower(arch)
|
|
||||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
|
||||||
return "glm-4.7"
|
|
||||||
}
|
|
||||||
if strings.Contains(archLower, "deepseek") {
|
|
||||||
return "deepseek3"
|
|
||||||
}
|
|
||||||
if strings.Contains(archLower, "qwen3") {
|
|
||||||
return "qwen3-coder"
|
return "qwen3-coder"
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Also check model_type
|
|
||||||
if cfg.ModelType != "" {
|
|
||||||
typeLower := strings.ToLower(cfg.ModelType)
|
|
||||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
|
||||||
return "glm-4.7"
|
|
||||||
}
|
|
||||||
if strings.Contains(typeLower, "deepseek") {
|
|
||||||
return "deepseek3"
|
|
||||||
}
|
|
||||||
if strings.Contains(typeLower, "qwen3") {
|
|
||||||
return "qwen3-coder"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -339,3 +339,34 @@ func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
|
|||||||
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
|
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSupportsVision(t *testing.T) {
|
||||||
|
t.Run("vision_config present", func(t *testing.T) {
|
||||||
|
cfg := parseModelConfig([]byte(`{
|
||||||
|
"vision_config": {"depth": 2},
|
||||||
|
"image_token_id": 151655
|
||||||
|
}`))
|
||||||
|
if !cfg.supportsVision() {
|
||||||
|
t.Fatal("supportsVision() = false, want true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("token ids alone imply vision", func(t *testing.T) {
|
||||||
|
cfg := parseModelConfig([]byte(`{
|
||||||
|
"vision_start_token_id": 10,
|
||||||
|
"vision_end_token_id": 11
|
||||||
|
}`))
|
||||||
|
if !cfg.supportsVision() {
|
||||||
|
t.Fatal("supportsVision() = false, want true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("plain text model", func(t *testing.T) {
|
||||||
|
cfg := parseModelConfig([]byte(`{
|
||||||
|
"architectures": ["Qwen3_5ForCausalLM"]
|
||||||
|
}`))
|
||||||
|
if cfg.supportsVision() {
|
||||||
|
t.Fatal("supportsVision() = true, want false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -366,6 +366,23 @@ func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
|
|||||||
c.enforceEvictionPolicy()
|
c.enforceEvictionPolicy()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clear releases live caches and drops the trie so future requests cannot
|
||||||
|
// reuse prompt state keyed only by token IDs.
|
||||||
|
func (c *kvCache) clear() {
|
||||||
|
c.freeAll()
|
||||||
|
walkNodes(c.root, func(n *trieNode) bool {
|
||||||
|
for _, s := range n.snapshots {
|
||||||
|
if s != nil {
|
||||||
|
s.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n.snapshots = nil
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
c.root = nil
|
||||||
|
c.activePath = nil
|
||||||
|
}
|
||||||
|
|
||||||
// freeAll releases all cache layers.
|
// freeAll releases all cache layers.
|
||||||
func (c *kvCache) freeAll() {
|
func (c *kvCache) freeAll() {
|
||||||
for _, kv := range c.caches {
|
for _, kv := range c.caches {
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
|||||||
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
||||||
type completionRequest struct {
|
type completionRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
Images []llm.ImageData `json:"images,omitempty"`
|
||||||
Options *completionOpts `json:"options,omitempty"`
|
Options *completionOpts `json:"options,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,6 +156,7 @@ func (c *Client) Close() error {
|
|||||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||||
creq := completionRequest{
|
creq := completionRequest{
|
||||||
Prompt: req.Prompt,
|
Prompt: req.Prompt,
|
||||||
|
Images: req.Images,
|
||||||
}
|
}
|
||||||
if req.Options != nil {
|
if req.Options != nil {
|
||||||
creq.Options = &completionOpts{
|
creq.Options = &completionOpts{
|
||||||
|
|||||||
@@ -304,6 +304,18 @@ func Exp(a *Array) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Sin(a *Array) *Array {
|
||||||
|
out := New("SIN")
|
||||||
|
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func Cos(a *Array) *Array {
|
||||||
|
out := New("COS")
|
||||||
|
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func Log(a *Array) *Array {
|
func Log(a *Array) *Array {
|
||||||
out := New("LOG")
|
out := New("LOG")
|
||||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
|||||||
32
x/mlxrunner/model/base/multimodal.go
Normal file
32
x/mlxrunner/model/base/multimodal.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package base
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ImageInput is a single image attached to a prompt.
|
||||||
|
type ImageInput struct {
|
||||||
|
ID int
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromptTokenization contains tokenized prompt IDs plus optional request-scoped
|
||||||
|
// model metadata needed during forward.
|
||||||
|
type PromptTokenization struct {
|
||||||
|
Tokens []int32
|
||||||
|
State any
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultimodalPromptTokenizerWithState is an optional model interface used by
|
||||||
|
// mlxrunner to expand tagged multimodal prompts into token IDs, returning
|
||||||
|
// request-scoped state to be attached to the forward pass.
|
||||||
|
type MultimodalPromptTokenizerWithState interface {
|
||||||
|
TokenizePromptWithImagesState(prompt string, images []ImageInput) (*PromptTokenization, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardWithStateModel is an optional model interface for request-scoped
|
||||||
|
// forward metadata that should not be stored in shared caches.
|
||||||
|
type ForwardWithStateModel interface {
|
||||||
|
ForwardWithState(inputs *mlx.Array, cache []cache.Cache, state any) *mlx.Array
|
||||||
|
}
|
||||||
@@ -12,12 +12,42 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
func prefillChunkSize() int {
|
func prefillChunkSize() int {
|
||||||
return 2 << 10
|
return 2 << 10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Runner) tokenizeRequest(request Request) ([]int32, any, error) {
|
||||||
|
if len(request.Images) > 0 {
|
||||||
|
// The shared trie cache keys only on token IDs today, so multimodal
|
||||||
|
// prompts must not reuse snapshots across distinct image inputs.
|
||||||
|
r.cache.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
if multimodalTokenizer, ok := r.Model.(base.MultimodalPromptTokenizerWithState); ok && len(request.Images) > 0 {
|
||||||
|
images := make([]base.ImageInput, len(request.Images))
|
||||||
|
for i := range request.Images {
|
||||||
|
images[i] = base.ImageInput{
|
||||||
|
ID: request.Images[i].ID,
|
||||||
|
Data: request.Images[i].Data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := multimodalTokenizer.TokenizePromptWithImagesState(request.Prompt, images)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if out == nil {
|
||||||
|
return nil, nil, errors.New("empty multimodal tokenization result")
|
||||||
|
}
|
||||||
|
return out.Tokens, out.State, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.Tokenizer.Encode(request.Prompt, true), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||||
if r.Model == nil {
|
if r.Model == nil {
|
||||||
return errors.New("model not loaded")
|
return errors.New("model not loaded")
|
||||||
@@ -55,7 +85,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
inputs, promptState, err := r.tokenizeRequest(request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if len(inputs) == 0 {
|
if len(inputs) == 0 {
|
||||||
return errors.New("empty prompt")
|
return errors.New("empty prompt")
|
||||||
}
|
}
|
||||||
@@ -83,6 +116,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
tokens := session.remaining
|
tokens := session.remaining
|
||||||
prefillChunk := prefillChunkSize()
|
prefillChunk := prefillChunkSize()
|
||||||
|
|
||||||
|
modelForward := func(tokens *mlx.Array) *mlx.Array {
|
||||||
|
if withState, ok := r.Model.(base.ForwardWithStateModel); ok {
|
||||||
|
return withState.ForwardWithState(tokens, caches, promptState)
|
||||||
|
}
|
||||||
|
return r.Model.Forward(tokens, caches)
|
||||||
|
}
|
||||||
|
|
||||||
materializeCaches := func() {
|
materializeCaches := func() {
|
||||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||||
for _, c := range caches {
|
for _, c := range caches {
|
||||||
@@ -114,7 +154,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
modelForward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0))
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
materializeCaches()
|
materializeCaches()
|
||||||
processed += n
|
processed += n
|
||||||
@@ -132,7 +172,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
fwd := modelForward(token.ExpandDims(0))
|
||||||
logits := r.Model.Unembed(fwd)
|
logits := r.Model.Unembed(fwd)
|
||||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
@@ -30,6 +31,7 @@ type Request struct {
|
|||||||
|
|
||||||
type TextCompletionsRequest struct {
|
type TextCompletionsRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
Images []llm.ImageData `json:"images,omitempty"`
|
||||||
Options struct {
|
Options struct {
|
||||||
Temperature float32 `json:"temperature"`
|
Temperature float32 `json:"temperature"`
|
||||||
TopP float32 `json:"top_p"`
|
TopP float32 `json:"top_p"`
|
||||||
|
|||||||
354
x/models/qwen3_5/multimodal.go
Normal file
354
x/models/qwen3_5/multimodal.go
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
package qwen3_5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
|
)
|
||||||
|
|
||||||
|
var imageTagRE = regexp.MustCompile(`\[img-(\d+)\]`)
|
||||||
|
|
||||||
|
type promptVisionSpan struct {
|
||||||
|
Start int32
|
||||||
|
End int32
|
||||||
|
|
||||||
|
Main *mlx.Array
|
||||||
|
Grid *VisionGrid
|
||||||
|
}
|
||||||
|
|
||||||
|
type promptVisionState struct {
|
||||||
|
Spans []promptVisionSpan
|
||||||
|
PositionCache []int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptStartPosFromCaches(caches []cache.Cache) int32 {
|
||||||
|
offset := -1
|
||||||
|
for _, c := range caches {
|
||||||
|
if c == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
off := c.Offset()
|
||||||
|
if offset < 0 || off < offset {
|
||||||
|
offset = off
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if offset < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return int32(offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptVisionStateFromState(state any) *promptVisionState {
|
||||||
|
typed, _ := state.(*promptVisionState)
|
||||||
|
return typed
|
||||||
|
}
|
||||||
|
|
||||||
|
func overlapRange(chunkStart, chunkLen, spanStart, spanEnd int32) (int32, int32, int32, int32, bool) {
|
||||||
|
chunkEnd := chunkStart + chunkLen
|
||||||
|
overlapStart := max(chunkStart, spanStart)
|
||||||
|
overlapEnd := min(chunkEnd, spanEnd)
|
||||||
|
if overlapStart >= overlapEnd {
|
||||||
|
return 0, 0, 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkLo := overlapStart - chunkStart
|
||||||
|
chunkHi := overlapEnd - chunkStart
|
||||||
|
spanLo := overlapStart - spanStart
|
||||||
|
spanHi := overlapEnd - spanStart
|
||||||
|
return chunkLo, chunkHi, spanLo, spanHi, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) applyPromptVisionEmbeddings(h *mlx.Array, startPos int32, state *promptVisionState) *mlx.Array {
|
||||||
|
if m == nil || h == nil || state == nil || len(state.Spans) == 0 {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
dims := h.Dims()
|
||||||
|
if len(dims) != 3 {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
L := int32(dims[1])
|
||||||
|
for _, span := range state.Spans {
|
||||||
|
chunkLo, chunkHi, spanLo, spanHi, ok := overlapRange(startPos, L, span.Start, span.End)
|
||||||
|
if !ok || span.Main == nil || !span.Main.Valid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
repl := span.Main.Slice(
|
||||||
|
mlx.Slice(),
|
||||||
|
mlx.Slice(int(spanLo), int(spanHi)),
|
||||||
|
mlx.Slice(),
|
||||||
|
)
|
||||||
|
repl = repl.AsType(h.DType())
|
||||||
|
h = h.SliceUpdate(
|
||||||
|
repl,
|
||||||
|
mlx.Slice(),
|
||||||
|
mlx.Slice(int(chunkLo), int(chunkHi)),
|
||||||
|
mlx.Slice(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func findImageByID(images []base.ImageInput, id int) (base.ImageInput, bool) {
|
||||||
|
for i := range images {
|
||||||
|
if images[i].ID == id {
|
||||||
|
return images[i], true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return base.ImageInput{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapPromptPosition(state *promptVisionState, id int32) int32 {
|
||||||
|
if state == nil {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if id < int32(len(state.PositionCache)) {
|
||||||
|
return state.PositionCache[id]
|
||||||
|
}
|
||||||
|
if len(state.PositionCache) > 0 {
|
||||||
|
return id - int32(len(state.PositionCache)) + state.PositionCache[len(state.PositionCache)-1] + 1
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptVisionGridSpan(grid *VisionGrid, merge int32, fallback int32) int32 {
|
||||||
|
if fallback <= 0 {
|
||||||
|
fallback = 1
|
||||||
|
}
|
||||||
|
if grid == nil {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
if merge <= 0 {
|
||||||
|
merge = 1
|
||||||
|
}
|
||||||
|
return max(max(int32(1), grid.Width/merge), max(int32(1), grid.Height/merge))
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeMRoPESections(sections []int32) [4]int32 {
|
||||||
|
var out [4]int32
|
||||||
|
for i := range min(4, len(sections)) {
|
||||||
|
if sections[i] > 0 {
|
||||||
|
out[i] = sections[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func mropePairComponent(pair int32, sections [4]int32, interleaved bool) int {
|
||||||
|
if interleaved {
|
||||||
|
if pair%3 == 1 && pair < 1+3*sections[1] {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if pair%3 == 2 && pair < 2+3*sections[2] {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
if pair%3 == 0 && pair < 3*sections[0] {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return 3
|
||||||
|
}
|
||||||
|
|
||||||
|
secW := sections[0] + sections[1]
|
||||||
|
secE := secW + sections[2]
|
||||||
|
switch {
|
||||||
|
case pair < sections[0]:
|
||||||
|
return 0
|
||||||
|
case pair < secW:
|
||||||
|
return 1
|
||||||
|
case pair < secE:
|
||||||
|
return 2
|
||||||
|
default:
|
||||||
|
return 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) buildPromptMRoPEPositions(state *promptVisionState, startPos, chunkLen int32) [4][]int32 {
|
||||||
|
var positions [4][]int32
|
||||||
|
for i := range positions {
|
||||||
|
positions[i] = make([]int32, chunkLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// positions[3] stays zero — it covers RoPE dims beyond the 3 MRoPE sections.
|
||||||
|
for i := range chunkLen {
|
||||||
|
p := mapPromptPosition(state, startPos+i)
|
||||||
|
positions[0][i] = p
|
||||||
|
positions[1][i] = p
|
||||||
|
positions[2][i] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
merge := int32(1)
|
||||||
|
if m != nil && m.Config != nil && m.Config.Vision != nil {
|
||||||
|
merge = m.Config.Vision.SpatialMergeSize
|
||||||
|
}
|
||||||
|
for _, span := range state.Spans {
|
||||||
|
if span.Grid == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkLo, chunkHi, spanLo, _, ok := overlapRange(startPos, chunkLen, span.Start, span.End)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
w := max(int32(1), span.Grid.Width/merge)
|
||||||
|
for i := chunkLo; i < chunkHi; i++ {
|
||||||
|
rel := spanLo + (i - chunkLo)
|
||||||
|
positions[1][i] += rel / w
|
||||||
|
positions[2][i] += rel % w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return positions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) buildPromptMRoPECosSin(state *promptVisionState, startPos, chunkLen int32, dtype mlx.DType) (*mlx.Array, *mlx.Array) {
|
||||||
|
if m == nil || m.Config == nil || state == nil || chunkLen <= 0 || len(m.Config.MRoPESections) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ropeDim := m.Config.RopeDim
|
||||||
|
if ropeDim%2 != 0 {
|
||||||
|
ropeDim--
|
||||||
|
}
|
||||||
|
if ropeDim <= 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
half := ropeDim / 2
|
||||||
|
positions := m.buildPromptMRoPEPositions(state, startPos, chunkLen)
|
||||||
|
sections := normalizeMRoPESections(m.Config.MRoPESections)
|
||||||
|
theta := m.Config.RopeTheta
|
||||||
|
if theta <= 0 {
|
||||||
|
theta = 100000.0
|
||||||
|
}
|
||||||
|
|
||||||
|
freqs := make([]float64, half)
|
||||||
|
for j := range half {
|
||||||
|
freqs[j] = math.Pow(float64(theta), -2.0*float64(j)/float64(ropeDim))
|
||||||
|
}
|
||||||
|
|
||||||
|
angles := make([]float32, chunkLen*ropeDim)
|
||||||
|
for i := range chunkLen {
|
||||||
|
base := i * ropeDim
|
||||||
|
for j := range half {
|
||||||
|
component := mropePairComponent(j, sections, m.Config.MRoPEInterleaved)
|
||||||
|
angle := float32(float64(positions[component][i]) * freqs[j])
|
||||||
|
angles[base+j] = angle
|
||||||
|
angles[base+half+j] = angle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
arr := mlx.FromValues(angles, 1, 1, int(chunkLen), int(ropeDim))
|
||||||
|
cos := mlx.Cos(arr)
|
||||||
|
sin := mlx.Sin(arr)
|
||||||
|
if dtype != 0 {
|
||||||
|
cos = cos.AsType(dtype)
|
||||||
|
sin = sin.AsType(dtype)
|
||||||
|
}
|
||||||
|
return cos, sin
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) tokenizePromptWithResolvedImages(
|
||||||
|
prompt string,
|
||||||
|
images []base.ImageInput,
|
||||||
|
resolve func([]byte) (*VisionEmbeddings, error),
|
||||||
|
) ([]int32, *promptVisionState, error) {
|
||||||
|
if m == nil || m.tok == nil {
|
||||||
|
return nil, nil, fmt.Errorf("qwen3_5: tokenizer not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Vision == nil || m.ImageProcessor == nil || resolve == nil {
|
||||||
|
return m.tok.Encode(prompt, true), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := imageTagRE.Split(prompt, -1)
|
||||||
|
matches := imageTagRE.FindAllStringSubmatch(prompt, -1)
|
||||||
|
|
||||||
|
resolved := make(map[int]*VisionEmbeddings, len(images))
|
||||||
|
var out []int32
|
||||||
|
state := &promptVisionState{}
|
||||||
|
var p int32
|
||||||
|
appendToken := func(tok, pos int32) {
|
||||||
|
out = append(out, tok)
|
||||||
|
state.PositionCache = append(state.PositionCache, pos)
|
||||||
|
}
|
||||||
|
for i, part := range parts {
|
||||||
|
for _, tok := range m.tok.Encode(part, i == 0) {
|
||||||
|
appendToken(tok, p)
|
||||||
|
p++
|
||||||
|
}
|
||||||
|
|
||||||
|
if i >= len(matches) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
imageID, err := strconv.Atoi(matches[i][1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("qwen3_5: invalid image tag %q: %w", matches[i][0], err)
|
||||||
|
}
|
||||||
|
|
||||||
|
img, ok := findImageByID(images, imageID)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, fmt.Errorf("invalid image index: %d", imageID)
|
||||||
|
}
|
||||||
|
|
||||||
|
embeds := resolved[imageID]
|
||||||
|
if embeds == nil {
|
||||||
|
embeds, err = resolve(img.Data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
resolved[imageID] = embeds
|
||||||
|
}
|
||||||
|
if embeds == nil || embeds.Main == nil || !embeds.Main.Valid() || embeds.Main.NumDims() < 2 {
|
||||||
|
return nil, nil, fmt.Errorf("qwen3_5: invalid vision embeddings")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokensPerImage := int32(embeds.Main.Dim(1))
|
||||||
|
if tokensPerImage <= 0 {
|
||||||
|
return nil, nil, fmt.Errorf("qwen3_5: invalid image token count: %d", tokensPerImage)
|
||||||
|
}
|
||||||
|
|
||||||
|
appendToken(m.VisionStartToken, p)
|
||||||
|
p++
|
||||||
|
basePos := p
|
||||||
|
spanStart := int32(len(out))
|
||||||
|
for range tokensPerImage {
|
||||||
|
appendToken(m.ImageTokenID, basePos)
|
||||||
|
}
|
||||||
|
spanEnd := int32(len(out))
|
||||||
|
merge := int32(1)
|
||||||
|
if m.Config != nil && m.Config.Vision != nil {
|
||||||
|
merge = m.Config.Vision.SpatialMergeSize
|
||||||
|
}
|
||||||
|
gridSpan := promptVisionGridSpan(embeds.Grid, merge, tokensPerImage)
|
||||||
|
p += gridSpan
|
||||||
|
appendToken(m.VisionEndToken, p)
|
||||||
|
p++
|
||||||
|
|
||||||
|
state.Spans = append(state.Spans, promptVisionSpan{
|
||||||
|
Start: spanStart,
|
||||||
|
End: spanEnd,
|
||||||
|
Main: embeds.Main,
|
||||||
|
Grid: embeds.Grid,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) TokenizePromptWithImagesState(prompt string, images []base.ImageInput) (*base.PromptTokenization, error) {
|
||||||
|
tokens, state, err := m.tokenizePromptWithResolvedImages(prompt, images, m.EncodeVisionImage)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &base.PromptTokenization{Tokens: tokens, State: state}, nil
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
package qwen3_5
|
package qwen3_5
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
@@ -22,16 +23,26 @@ func init() {
|
|||||||
base.Register("Qwen3NextForConditionalGeneration", NewModel)
|
base.Register("Qwen3NextForConditionalGeneration", NewModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ base.MultimodalPromptTokenizerWithState = (*Model)(nil)
|
||||||
|
_ base.ForwardWithStateModel = (*Model)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
// RopeParameters carries optional rope metadata embedded under rope_parameters.
|
// RopeParameters carries optional rope metadata embedded under rope_parameters.
|
||||||
type RopeParameters struct {
|
type RopeParameters struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
RopeType string `json:"rope_type"`
|
RopeType string `json:"rope_type"`
|
||||||
RopeTheta float32 `json:"rope_theta"`
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||||
|
MRoPEInterleaved bool `json:"mrope_interleaved"`
|
||||||
|
MRoPESection []int32 `json:"mrope_section"`
|
||||||
|
DimensionSections []int32 `json:"dimension_sections"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config holds Qwen 3.5 text config (top-level or nested text_config).
|
// TextConfig holds the Qwen 3.5 text-model architecture fields.
|
||||||
type Config struct {
|
// In VLM configs these live under the "text_config" key; in text-only
|
||||||
|
// configs they appear at the top level.
|
||||||
|
type TextConfig struct {
|
||||||
ModelType string `json:"model_type"`
|
ModelType string `json:"model_type"`
|
||||||
HiddenSize int32 `json:"hidden_size"`
|
HiddenSize int32 `json:"hidden_size"`
|
||||||
IntermediateSize int32 `json:"intermediate_size"`
|
IntermediateSize int32 `json:"intermediate_size"`
|
||||||
@@ -67,6 +78,19 @@ type Config struct {
|
|||||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||||
RopeScaling map[string]any `json:"rope_scaling"`
|
RopeScaling map[string]any `json:"rope_scaling"`
|
||||||
RopeParameters *RopeParameters `json:"rope_parameters"`
|
RopeParameters *RopeParameters `json:"rope_parameters"`
|
||||||
|
MRoPESections []int32 `json:"mrope_sections"`
|
||||||
|
MRoPEInterleaved bool `json:"mrope_interleaved"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config is the full model config. It embeds TextConfig for the text-model
|
||||||
|
// fields and adds top-level-only fields (vision, token IDs, quantization).
|
||||||
|
type Config struct {
|
||||||
|
TextConfig
|
||||||
|
|
||||||
|
Vision *VisionConfig `json:"vision_config"`
|
||||||
|
ImageTokenID int32 `json:"image_token_id"`
|
||||||
|
VisionStartToken int32 `json:"vision_start_token_id"`
|
||||||
|
VisionEndToken int32 `json:"vision_end_token_id"`
|
||||||
|
|
||||||
// Quantization metadata.
|
// Quantization metadata.
|
||||||
QuantGroupSize int `json:"-"`
|
QuantGroupSize int `json:"-"`
|
||||||
@@ -90,6 +114,9 @@ type Model struct {
|
|||||||
*Config
|
*Config
|
||||||
|
|
||||||
weightPrefix string
|
weightPrefix string
|
||||||
|
|
||||||
|
Vision *VisionModel
|
||||||
|
ImageProcessor *VisionImageProcessor
|
||||||
}
|
}
|
||||||
|
|
||||||
// Layer is a transformer decoder layer.
|
// Layer is a transformer decoder layer.
|
||||||
@@ -190,17 +217,24 @@ func parseConfig(configData []byte) (Config, error) {
|
|||||||
|
|
||||||
var cfg Config
|
var cfg Config
|
||||||
activeRaw := rawTop
|
activeRaw := rawTop
|
||||||
|
|
||||||
|
// First pass: unmarshal the full config to pick up top-level fields
|
||||||
|
// (vision_config, image_token_id, etc.) and text fields for text-only models.
|
||||||
|
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("parse config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second pass: if text_config exists, unmarshal it into TextConfig so
|
||||||
|
// text-model fields from text_config take priority over any top-level
|
||||||
|
// duplicates. Top-level-only fields (Vision, token IDs) are unaffected
|
||||||
|
// because they live on Config, not TextConfig.
|
||||||
if textRaw, ok := rawTop["text_config"]; ok {
|
if textRaw, ok := rawTop["text_config"]; ok {
|
||||||
if err := json.Unmarshal(textRaw, &cfg); err != nil {
|
if err := json.Unmarshal(textRaw, &cfg.TextConfig); err != nil {
|
||||||
return Config{}, fmt.Errorf("parse text_config: %w", err)
|
return Config{}, fmt.Errorf("parse text_config: %w", err)
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(textRaw, &activeRaw); err != nil {
|
if err := json.Unmarshal(textRaw, &activeRaw); err != nil {
|
||||||
return Config{}, fmt.Errorf("parse text_config envelope: %w", err)
|
return Config{}, fmt.Errorf("parse text_config envelope: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("parse config: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.HiddenSize <= 0 {
|
if cfg.HiddenSize <= 0 {
|
||||||
@@ -225,12 +259,8 @@ func parseConfig(configData []byte) (Config, error) {
|
|||||||
return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.RMSNormEps == 0 {
|
cfg.RMSNormEps = cmp.Or(cfg.RMSNormEps, 1e-6)
|
||||||
cfg.RMSNormEps = 1e-6
|
cfg.LinearConvKernelDim = cmp.Or(cfg.LinearConvKernelDim, 4)
|
||||||
}
|
|
||||||
if cfg.LinearConvKernelDim <= 0 {
|
|
||||||
cfg.LinearConvKernelDim = 4
|
|
||||||
}
|
|
||||||
if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 {
|
if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 {
|
||||||
return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)",
|
return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)",
|
||||||
cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim)
|
cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim)
|
||||||
@@ -246,14 +276,21 @@ func parseConfig(configData []byte) (Config, error) {
|
|||||||
if cfg.RopeParameters.PartialRotaryFactor > 0 {
|
if cfg.RopeParameters.PartialRotaryFactor > 0 {
|
||||||
cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor
|
cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor
|
||||||
}
|
}
|
||||||
|
if len(cfg.MRoPESections) == 0 {
|
||||||
|
switch {
|
||||||
|
case len(cfg.RopeParameters.MRoPESection) > 0:
|
||||||
|
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.MRoPESection...)
|
||||||
|
case len(cfg.RopeParameters.DimensionSections) > 0:
|
||||||
|
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.DimensionSections...)
|
||||||
}
|
}
|
||||||
if cfg.RopeTheta == 0 {
|
|
||||||
cfg.RopeTheta = 100000.0
|
|
||||||
}
|
}
|
||||||
if cfg.PartialRotaryFactor == 0 {
|
cfg.MRoPEInterleaved = cmp.Or(cfg.MRoPEInterleaved, cfg.RopeParameters.MRoPEInterleaved)
|
||||||
cfg.PartialRotaryFactor = 0.25
|
|
||||||
}
|
}
|
||||||
if cfg.PartialRotaryFactor < 0 {
|
if len(cfg.MRoPESections) > 4 {
|
||||||
|
cfg.MRoPESections = cfg.MRoPESections[:4]
|
||||||
|
}
|
||||||
|
cfg.RopeTheta = cmp.Or(cfg.RopeTheta, 100000.0)
|
||||||
|
if cfg.PartialRotaryFactor <= 0 {
|
||||||
cfg.PartialRotaryFactor = 0.25
|
cfg.PartialRotaryFactor = 0.25
|
||||||
}
|
}
|
||||||
ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor)
|
ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor)
|
||||||
@@ -281,24 +318,23 @@ func parseConfig(configData []byte) (Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cfg.NumExperts > 0 {
|
if cfg.NumExperts > 0 {
|
||||||
if cfg.NumExpertsPerTok <= 0 {
|
cfg.NumExpertsPerTok = cmp.Or(cfg.NumExpertsPerTok, int32(1))
|
||||||
cfg.NumExpertsPerTok = 1
|
cfg.MoeIntermediateSize = cmp.Or(cfg.MoeIntermediateSize, cfg.IntermediateSize)
|
||||||
}
|
cfg.SharedExpertIntermediateSize = cmp.Or(cfg.SharedExpertIntermediateSize, cfg.IntermediateSize)
|
||||||
if cfg.MoeIntermediateSize <= 0 {
|
|
||||||
cfg.MoeIntermediateSize = cfg.IntermediateSize
|
|
||||||
}
|
|
||||||
if cfg.SharedExpertIntermediateSize <= 0 {
|
|
||||||
cfg.SharedExpertIntermediateSize = cfg.IntermediateSize
|
|
||||||
}
|
|
||||||
if _, ok := activeRaw["norm_topk_prob"]; !ok {
|
if _, ok := activeRaw["norm_topk_prob"]; !ok {
|
||||||
cfg.NormTopKProb = true
|
cfg.NormTopKProb = true
|
||||||
}
|
}
|
||||||
if cfg.DecoderSparseStep <= 0 {
|
cfg.DecoderSparseStep = cmp.Or(cfg.DecoderSparseStep, int32(1))
|
||||||
cfg.DecoderSparseStep = 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||||
|
|
||||||
|
if cfg.Vision != nil {
|
||||||
|
cfg.Vision.applyDefaults()
|
||||||
|
}
|
||||||
|
cfg.ImageTokenID = cmp.Or(cfg.ImageTokenID, int32(151655))
|
||||||
|
cfg.VisionStartToken = cmp.Or(cfg.VisionStartToken, int32(151652))
|
||||||
|
cfg.VisionEndToken = cmp.Or(cfg.VisionEndToken, int32(151653))
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -364,6 +400,11 @@ func NewModel(root *model.Root) (base.Model, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if cfg.Vision != nil {
|
||||||
|
if preprocessorData, err := root.Manifest.ReadConfig("preprocessor_config.json"); err == nil {
|
||||||
|
cfg.Vision.applyPreprocessorConfig(preprocessorData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if qt := root.QuantType(); qt != "" {
|
if qt := root.QuantType(); qt != "" {
|
||||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||||
@@ -1060,6 +1101,15 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||||||
m.Layers[i] = layer
|
m.Layers[i] = layer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.Vision != nil && cfg.Vision.Depth > 0 {
|
||||||
|
vision, processor, err := loadVisionComponents(tensors, linears, cfg, m.weightPrefix)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.Vision = vision
|
||||||
|
m.ImageProcessor = processor
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1117,7 +1167,51 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k,
|
|||||||
return q, k, v, z, b, a
|
return q, k, v, z, b, a
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
func rotateHalf(x *mlx.Array) *mlx.Array {
|
||||||
|
shape := x.Dims()
|
||||||
|
last := int32(shape[len(shape)-1])
|
||||||
|
half := last / 2
|
||||||
|
if half <= 0 {
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
x1 := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), half})
|
||||||
|
x2 := mlx.SliceStartStop(x, []int32{0, 0, 0, half}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
|
||||||
|
return mlx.Concatenate([]*mlx.Array{mlx.Neg(x2), x1}, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyTextRoPE(x, cos, sin *mlx.Array, ropeDim int32) *mlx.Array {
|
||||||
|
if x == nil || cos == nil || sin == nil || ropeDim <= 0 {
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
shape := x.Dims()
|
||||||
|
if len(shape) != 4 {
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
last := int32(shape[len(shape)-1])
|
||||||
|
if ropeDim > last {
|
||||||
|
ropeDim = last
|
||||||
|
}
|
||||||
|
if ropeDim%2 != 0 {
|
||||||
|
ropeDim--
|
||||||
|
}
|
||||||
|
if ropeDim <= 0 {
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
rot := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), ropeDim})
|
||||||
|
rot = mlx.Add(mlx.Mul(rot, cos), mlx.Mul(rotateHalf(rot), sin))
|
||||||
|
if ropeDim == last {
|
||||||
|
return rot
|
||||||
|
}
|
||||||
|
|
||||||
|
tail := mlx.SliceStartStop(x, []int32{0, 0, 0, ropeDim}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
|
||||||
|
return mlx.Concatenate([]*mlx.Array{rot, tail}, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
|
||||||
qg := a.QProj.Forward(x)
|
qg := a.QProj.Forward(x)
|
||||||
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
|
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
|
||||||
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
|
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
|
||||||
@@ -1140,8 +1234,13 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
|||||||
if c != nil {
|
if c != nil {
|
||||||
offset = c.Offset()
|
offset = c.Offset()
|
||||||
}
|
}
|
||||||
|
if ropeCos != nil && ropeSin != nil {
|
||||||
|
q = applyTextRoPE(q, ropeCos, ropeSin, cfg.RopeDim)
|
||||||
|
k = applyTextRoPE(k, ropeCos, ropeSin, cfg.RopeDim)
|
||||||
|
} else {
|
||||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||||
|
}
|
||||||
|
|
||||||
if c != nil {
|
if c != nil {
|
||||||
k, v = c.Update(k, v)
|
k, v = c.Update(k, v)
|
||||||
@@ -1323,13 +1422,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
|||||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
|
||||||
var r *mlx.Array
|
var r *mlx.Array
|
||||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||||
if l.IsLinear {
|
if l.IsLinear {
|
||||||
r = l.Linear.Forward(normed, c, B, L, cfg)
|
r = l.Linear.Forward(normed, c, B, L, cfg)
|
||||||
} else {
|
} else {
|
||||||
r = l.FullAttn.Forward(normed, c, B, L, cfg)
|
r = l.FullAttn.Forward(normed, c, B, L, cfg, ropeCos, ropeSin)
|
||||||
}
|
}
|
||||||
h := mlx.Add(x, r)
|
h := mlx.Add(x, r)
|
||||||
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
|
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||||
@@ -1337,16 +1436,27 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||||
|
return m.ForwardWithState(tokens, caches, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) ForwardWithState(tokens *mlx.Array, caches []cache.Cache, state any) *mlx.Array {
|
||||||
dims := tokens.Dims()
|
dims := tokens.Dims()
|
||||||
B, L := int32(dims[0]), int32(dims[1])
|
B, L := int32(dims[0]), int32(dims[1])
|
||||||
|
|
||||||
|
startPos := promptStartPosFromCaches(caches)
|
||||||
|
promptState := promptVisionStateFromState(state)
|
||||||
h := m.EmbedTokens.Forward(tokens)
|
h := m.EmbedTokens.Forward(tokens)
|
||||||
|
h = m.applyPromptVisionEmbeddings(h, startPos, promptState)
|
||||||
|
var ropeCos, ropeSin *mlx.Array
|
||||||
|
if len(m.MRoPESections) > 0 {
|
||||||
|
ropeCos, ropeSin = m.buildPromptMRoPECosSin(promptState, startPos, L, h.DType())
|
||||||
|
}
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
var c cache.Cache
|
var c cache.Cache
|
||||||
if caches != nil && i < len(caches) {
|
if caches != nil && i < len(caches) {
|
||||||
c = caches[i]
|
c = caches[i]
|
||||||
}
|
}
|
||||||
h = layer.Forward(h, c, B, L, m.Config)
|
h = layer.Forward(h, c, B, L, m.Config, ropeCos, ropeSin)
|
||||||
}
|
}
|
||||||
out := m.Norm.Forward(h, m.RMSNormEps)
|
out := m.Norm.Forward(h, m.RMSNormEps)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
package qwen3_5
|
package qwen3_5
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
|
"github.com/ollama/ollama/x/tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
func skipIfNoMLX(t *testing.T) {
|
func skipIfNoMLX(t *testing.T) {
|
||||||
@@ -60,13 +64,13 @@ func TestParseConfigNestedDefaults(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLayerSelectionHelpers(t *testing.T) {
|
func TestLayerSelectionHelpers(t *testing.T) {
|
||||||
cfg := &Config{
|
cfg := &Config{TextConfig: TextConfig{
|
||||||
NumHiddenLayers: 6,
|
NumHiddenLayers: 6,
|
||||||
FullAttentionInterval: 3,
|
FullAttentionInterval: 3,
|
||||||
NumExperts: 8,
|
NumExperts: 8,
|
||||||
DecoderSparseStep: 2,
|
DecoderSparseStep: 2,
|
||||||
MLPOnlyLayers: []int32{1},
|
MLPOnlyLayers: []int32{1},
|
||||||
}
|
}}
|
||||||
|
|
||||||
if !layerIsLinear(cfg, 0) {
|
if !layerIsLinear(cfg, 0) {
|
||||||
t.Fatalf("layer 0 should be linear")
|
t.Fatalf("layer 0 should be linear")
|
||||||
@@ -133,13 +137,13 @@ func TestResolveTensorPathLayout(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewCachesLayout(t *testing.T) {
|
func TestNewCachesLayout(t *testing.T) {
|
||||||
m := &Model{
|
m := &Model{
|
||||||
Config: &Config{
|
Config: &Config{TextConfig: TextConfig{
|
||||||
LinearConvKernelDim: 4,
|
LinearConvKernelDim: 4,
|
||||||
LinearNumKeyHeads: 2,
|
LinearNumKeyHeads: 2,
|
||||||
LinearKeyHeadDim: 8,
|
LinearKeyHeadDim: 8,
|
||||||
LinearNumValueHeads: 4,
|
LinearNumValueHeads: 4,
|
||||||
LinearValueHeadDim: 16,
|
LinearValueHeadDim: 16,
|
||||||
},
|
}},
|
||||||
Layers: []*Layer{
|
Layers: []*Layer{
|
||||||
{IsLinear: true},
|
{IsLinear: true},
|
||||||
{IsLinear: false},
|
{IsLinear: false},
|
||||||
@@ -166,7 +170,7 @@ func TestNewCachesLayout(t *testing.T) {
|
|||||||
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
||||||
skipIfNoMLX(t)
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
cfg := &Config{
|
cfg := &Config{TextConfig: TextConfig{
|
||||||
HiddenSize: 4,
|
HiddenSize: 4,
|
||||||
IntermediateSize: 8,
|
IntermediateSize: 8,
|
||||||
NumHiddenLayers: 2,
|
NumHiddenLayers: 2,
|
||||||
@@ -182,7 +186,7 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
|||||||
LinearValueHeadDim: 2,
|
LinearValueHeadDim: 2,
|
||||||
LinearConvKernelDim: 4,
|
LinearConvKernelDim: 4,
|
||||||
FullAttentionInterval: 2,
|
FullAttentionInterval: 2,
|
||||||
}
|
}}
|
||||||
|
|
||||||
m := &Model{
|
m := &Model{
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
@@ -343,3 +347,389 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
|||||||
t.Fatalf("k norm dtype = %v, want %v", got, f32)
|
t.Fatalf("k norm dtype = %v, want %v", got, f32)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseConfigVisionFields(t *testing.T) {
|
||||||
|
data := []byte(`{
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 4096,
|
||||||
|
"intermediate_size": 14336,
|
||||||
|
"num_hidden_layers": 4,
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"head_dim": 128,
|
||||||
|
"linear_num_value_heads": 64,
|
||||||
|
"linear_num_key_heads": 16,
|
||||||
|
"linear_key_head_dim": 128,
|
||||||
|
"linear_value_head_dim": 128,
|
||||||
|
"linear_conv_kernel_dim": 4,
|
||||||
|
"rope_parameters": {
|
||||||
|
"rope_theta": 10000000
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"vision_config": {
|
||||||
|
"depth": 2,
|
||||||
|
"hidden_size": 256,
|
||||||
|
"num_heads": 8,
|
||||||
|
"in_channels": 3,
|
||||||
|
"patch_size": 14,
|
||||||
|
"spatial_merge_size": 2,
|
||||||
|
"layer_norm_epsilon": 0.000001,
|
||||||
|
"temporal_patch_size": 2,
|
||||||
|
"num_position_embeddings": 2304
|
||||||
|
},
|
||||||
|
"image_token_id": 111,
|
||||||
|
"vision_start_token_id": 112,
|
||||||
|
"vision_end_token_id": 113
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Vision == nil {
|
||||||
|
t.Fatal("vision config should be parsed")
|
||||||
|
}
|
||||||
|
if cfg.Vision.Depth != 2 {
|
||||||
|
t.Fatalf("vision.depth mismatch: got %d", cfg.Vision.Depth)
|
||||||
|
}
|
||||||
|
if cfg.Vision.GridPerSide != 48 {
|
||||||
|
t.Fatalf("vision grid-per-side mismatch: got %d want 48", cfg.Vision.GridPerSide)
|
||||||
|
}
|
||||||
|
if cfg.Vision.RopeTheta != 10000 {
|
||||||
|
t.Fatalf("vision rope_theta should default to 10000, got %v", cfg.Vision.RopeTheta)
|
||||||
|
}
|
||||||
|
if cfg.RopeTheta != 10000000 {
|
||||||
|
t.Fatalf("text rope_theta mismatch: got %v", cfg.RopeTheta)
|
||||||
|
}
|
||||||
|
if cfg.ImageTokenID != 111 || cfg.VisionStartToken != 112 || cfg.VisionEndToken != 113 {
|
||||||
|
t.Fatalf("vision token ids mismatch: got image=%d start=%d end=%d", cfg.ImageTokenID, cfg.VisionStartToken, cfg.VisionEndToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigMRoPEFromRopeParameters(t *testing.T) {
|
||||||
|
data := []byte(`{
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 2048,
|
||||||
|
"intermediate_size": 8192,
|
||||||
|
"num_hidden_layers": 4,
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_key_value_heads": 2,
|
||||||
|
"head_dim": 256,
|
||||||
|
"linear_num_value_heads": 32,
|
||||||
|
"linear_num_key_heads": 16,
|
||||||
|
"linear_key_head_dim": 128,
|
||||||
|
"linear_value_head_dim": 128,
|
||||||
|
"linear_conv_kernel_dim": 4,
|
||||||
|
"rope_parameters": {
|
||||||
|
"rope_theta": 10000000,
|
||||||
|
"partial_rotary_factor": 0.25,
|
||||||
|
"mrope_interleaved": true,
|
||||||
|
"mrope_section": [11, 11, 10]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.MRoPEInterleaved {
|
||||||
|
t.Fatal("mrope_interleaved should be parsed from rope_parameters")
|
||||||
|
}
|
||||||
|
if !slices.Equal(cfg.MRoPESections, []int32{11, 11, 10}) {
|
||||||
|
t.Fatalf("mrope sections mismatch: got %v", cfg.MRoPESections)
|
||||||
|
}
|
||||||
|
if cfg.RopeDim != 64 {
|
||||||
|
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigVisionTokenDefaults(t *testing.T) {
|
||||||
|
data := []byte(`{
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 4096,
|
||||||
|
"intermediate_size": 14336,
|
||||||
|
"num_hidden_layers": 2,
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"head_dim": 128,
|
||||||
|
"linear_num_value_heads": 64,
|
||||||
|
"linear_num_key_heads": 16,
|
||||||
|
"linear_key_head_dim": 128,
|
||||||
|
"linear_value_head_dim": 128,
|
||||||
|
"linear_conv_kernel_dim": 4
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.ImageTokenID != 151655 {
|
||||||
|
t.Fatalf("default image token mismatch: got %d", cfg.ImageTokenID)
|
||||||
|
}
|
||||||
|
if cfg.VisionStartToken != 151652 {
|
||||||
|
t.Fatalf("default vision start token mismatch: got %d", cfg.VisionStartToken)
|
||||||
|
}
|
||||||
|
if cfg.VisionEndToken != 151653 {
|
||||||
|
t.Fatalf("default vision end token mismatch: got %d", cfg.VisionEndToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveVisionPrefix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tensors map[string]*mlx.Array
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "legacy visual prefix",
|
||||||
|
tensors: map[string]*mlx.Array{
|
||||||
|
"model.visual.patch_embed.proj.weight": mlx.New("patch"),
|
||||||
|
},
|
||||||
|
want: "model.visual",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "imported vision tower prefix",
|
||||||
|
tensors: map[string]*mlx.Array{
|
||||||
|
"vision_tower.blocks.0.attn.qkv.weight": mlx.New("qkv"),
|
||||||
|
},
|
||||||
|
want: "vision_tower",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := resolveVisionPrefix(tt.tensors, "language_model."); got != tt.want {
|
||||||
|
t.Fatalf("resolveVisionPrefix() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVisionPreprocessorOverridesDefaults(t *testing.T) {
|
||||||
|
v := &VisionConfig{}
|
||||||
|
v.applyDefaults()
|
||||||
|
|
||||||
|
v.applyPreprocessorConfig([]byte(`{
|
||||||
|
"patch_size": 16,
|
||||||
|
"temporal_patch_size": 3,
|
||||||
|
"merge_size": 4,
|
||||||
|
"size": {
|
||||||
|
"shortest_edge": 1024,
|
||||||
|
"longest_edge": 8192
|
||||||
|
},
|
||||||
|
"image_mean": [0.1, 0.2, 0.3],
|
||||||
|
"image_std": [0.9, 0.8, 0.7]
|
||||||
|
}`))
|
||||||
|
|
||||||
|
if v.PatchSize != 16 {
|
||||||
|
t.Fatalf("patch_size mismatch: got %d want 16", v.PatchSize)
|
||||||
|
}
|
||||||
|
if v.TemporalPatchSize != 3 {
|
||||||
|
t.Fatalf("temporal_patch_size mismatch: got %d want 3", v.TemporalPatchSize)
|
||||||
|
}
|
||||||
|
if v.SpatialMergeSize != 4 {
|
||||||
|
t.Fatalf("merge_size mismatch: got %d want 4", v.SpatialMergeSize)
|
||||||
|
}
|
||||||
|
if v.Size.ShortestEdge != 1024 || v.Size.LongestEdge != 8192 {
|
||||||
|
t.Fatalf("size mismatch: got shortest=%d longest=%d", v.Size.ShortestEdge, v.Size.LongestEdge)
|
||||||
|
}
|
||||||
|
if v.ImageMean[0] != 0.1 || v.ImageStd[2] != 0.7 {
|
||||||
|
t.Fatalf("image preprocessing stats mismatch: mean=%v std=%v", v.ImageMean, v.ImageStd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVisionImageProcessorUsesPreprocessorSize(t *testing.T) {
|
||||||
|
v := &VisionConfig{}
|
||||||
|
v.applyDefaults()
|
||||||
|
|
||||||
|
v.applyPreprocessorConfig([]byte(`{
|
||||||
|
"size": {
|
||||||
|
"shortest_edge": 65536,
|
||||||
|
"longest_edge": 16777216
|
||||||
|
},
|
||||||
|
"patch_size": 16,
|
||||||
|
"temporal_patch_size": 2,
|
||||||
|
"merge_size": 2,
|
||||||
|
"image_mean": [0.5, 0.5, 0.5],
|
||||||
|
"image_std": [0.5, 0.5, 0.5]
|
||||||
|
}`))
|
||||||
|
|
||||||
|
p := newVisionImageProcessor(v)
|
||||||
|
if p == nil {
|
||||||
|
t.Fatal("newVisionImageProcessor returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.shortestEdge != 65536 || p.longestEdge != 16777216 {
|
||||||
|
t.Fatalf("processor size mismatch: shortest=%d longest=%d", p.shortestEdge, p.longestEdge)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testTokenizer(t *testing.T) *tokenizer.Tokenizer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tok, err := tokenizer.LoadFromBytes([]byte(`{
|
||||||
|
"model": {
|
||||||
|
"type": "BPE",
|
||||||
|
"vocab": {"a": 0},
|
||||||
|
"merges": []
|
||||||
|
}
|
||||||
|
}`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load test tokenizer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tok
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenizePromptWithResolvedImagesStoresVisionSpans(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
m := &Model{
|
||||||
|
tok: testTokenizer(t),
|
||||||
|
Config: &Config{
|
||||||
|
ImageTokenID: 101,
|
||||||
|
VisionStartToken: 102,
|
||||||
|
VisionEndToken: 103,
|
||||||
|
Vision: &VisionConfig{SpatialMergeSize: 2},
|
||||||
|
},
|
||||||
|
Vision: &VisionModel{},
|
||||||
|
ImageProcessor: &VisionImageProcessor{},
|
||||||
|
}
|
||||||
|
|
||||||
|
main := mlx.FromValues([]float32{
|
||||||
|
10, 11,
|
||||||
|
20, 21,
|
||||||
|
}, 1, 2, 2)
|
||||||
|
|
||||||
|
resolveCalls := 0
|
||||||
|
got, state, err := m.tokenizePromptWithResolvedImages(
|
||||||
|
"a[img-7][img-7]a",
|
||||||
|
[]base.ImageInput{{ID: 7, Data: []byte("img7")}},
|
||||||
|
func(data []byte) (*VisionEmbeddings, error) {
|
||||||
|
if string(data) != "img7" {
|
||||||
|
return nil, fmt.Errorf("unexpected data: %q", string(data))
|
||||||
|
}
|
||||||
|
resolveCalls++
|
||||||
|
return &VisionEmbeddings{
|
||||||
|
Main: main,
|
||||||
|
Grid: &VisionGrid{Height: 2, Width: 2, Temporal: 1},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tokenizePromptWithResolvedImages returned error: %v", err)
|
||||||
|
}
|
||||||
|
if resolveCalls != 1 {
|
||||||
|
t.Fatalf("resolve calls mismatch: got %d want 1", resolveCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []int32{
|
||||||
|
0,
|
||||||
|
102, 101, 101, 103,
|
||||||
|
102, 101, 101, 103,
|
||||||
|
0,
|
||||||
|
}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("expanded tokens mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state == nil {
|
||||||
|
t.Fatal("expected prompt vision state")
|
||||||
|
}
|
||||||
|
if len(state.Spans) != 2 {
|
||||||
|
t.Fatalf("prompt span count mismatch: got %d want 2", len(state.Spans))
|
||||||
|
}
|
||||||
|
if state.Spans[0].Start != 2 || state.Spans[0].End != 4 {
|
||||||
|
t.Fatalf("first span mismatch: got [%d,%d)", state.Spans[0].Start, state.Spans[0].End)
|
||||||
|
}
|
||||||
|
if state.Spans[1].Start != 6 || state.Spans[1].End != 8 {
|
||||||
|
t.Fatalf("second span mismatch: got [%d,%d)", state.Spans[1].Start, state.Spans[1].End)
|
||||||
|
}
|
||||||
|
wantPos := []int32{0, 1, 2, 2, 3, 4, 5, 5, 6, 7}
|
||||||
|
if !slices.Equal(state.PositionCache, wantPos) {
|
||||||
|
t.Fatalf("position cache mismatch: got %v want %v", state.PositionCache, wantPos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPromptMRoPEPositions(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Config: &Config{
|
||||||
|
Vision: &VisionConfig{SpatialMergeSize: 2},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
state := &promptVisionState{
|
||||||
|
PositionCache: []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6},
|
||||||
|
Spans: []promptVisionSpan{
|
||||||
|
{
|
||||||
|
Start: 2,
|
||||||
|
End: 8,
|
||||||
|
Grid: &VisionGrid{Height: 4, Width: 6, Temporal: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pos := m.buildPromptMRoPEPositions(state, 0, 10)
|
||||||
|
if got, want := pos[0], []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("time positions mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := pos[1], []int32{0, 1, 2, 2, 2, 3, 3, 3, 5, 6}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("height positions mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := pos[2], []int32{0, 1, 2, 3, 4, 2, 3, 4, 5, 6}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("width positions mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapPromptPositionContinuesAfterCache(t *testing.T) {
|
||||||
|
state := &promptVisionState{PositionCache: []int32{0, 1, 2, 2, 3}}
|
||||||
|
|
||||||
|
if got := mapPromptPosition(state, 3); got != 2 {
|
||||||
|
t.Fatalf("mapPromptPosition(3) = %d, want 2", got)
|
||||||
|
}
|
||||||
|
if got := mapPromptPosition(state, 5); got != 4 {
|
||||||
|
t.Fatalf("mapPromptPosition(5) = %d, want 4", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyPromptVisionEmbeddings(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
m := &Model{}
|
||||||
|
state := &promptVisionState{
|
||||||
|
Spans: []promptVisionSpan{
|
||||||
|
{
|
||||||
|
Start: 1,
|
||||||
|
End: 3,
|
||||||
|
Main: mlx.FromValues([]float32{
|
||||||
|
10, 11,
|
||||||
|
20, 21,
|
||||||
|
}, 1, 2, 2),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
h := mlx.FromValues([]float32{
|
||||||
|
0, 1,
|
||||||
|
2, 3,
|
||||||
|
4, 5,
|
||||||
|
6, 7,
|
||||||
|
}, 1, 4, 2)
|
||||||
|
|
||||||
|
got := m.applyPromptVisionEmbeddings(h, 0, state)
|
||||||
|
mlx.Eval(got)
|
||||||
|
|
||||||
|
want := []float32{
|
||||||
|
0, 1,
|
||||||
|
10, 11,
|
||||||
|
20, 21,
|
||||||
|
6, 7,
|
||||||
|
}
|
||||||
|
if !slices.Equal(got.Floats(), want) {
|
||||||
|
t.Fatalf("embedding replacement mismatch: got %v want %v", got.Floats(), want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
854
x/models/qwen3_5/vision.go
Normal file
854
x/models/qwen3_5/vision.go
Normal file
@@ -0,0 +1,854 @@
|
|||||||
|
package qwen3_5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
_ "image/jpeg"
|
||||||
|
_ "image/png"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model/imageproc"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
mlxmodel "github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
|
"github.com/ollama/ollama/x/models/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNoVisionModel = errors.New("qwen3_5: no vision model")
|
||||||
|
|
||||||
|
// VisionConfig mirrors Qwen3.5/Qwen3-Next vision_config.
|
||||||
|
type VisionConfig struct {
|
||||||
|
Depth int32 `json:"depth"`
|
||||||
|
HiddenSize int32 `json:"hidden_size"`
|
||||||
|
NumHeads int32 `json:"num_heads"`
|
||||||
|
InChannels int32 `json:"in_channels"`
|
||||||
|
PatchSize int32 `json:"patch_size"`
|
||||||
|
SpatialMergeSize int32 `json:"spatial_merge_size"`
|
||||||
|
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
TemporalPatchSize int32 `json:"temporal_patch_size"`
|
||||||
|
NumPositionEmbeddings int32 `json:"num_position_embeddings"`
|
||||||
|
|
||||||
|
Size struct {
|
||||||
|
ShortestEdge int32 `json:"shortest_edge"`
|
||||||
|
LongestEdge int32 `json:"longest_edge"`
|
||||||
|
} `json:"size"`
|
||||||
|
|
||||||
|
ImageMean []float32 `json:"image_mean"`
|
||||||
|
ImageStd []float32 `json:"image_std"`
|
||||||
|
|
||||||
|
GridPerSide int32 `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VisionConfig) applyDefaults() {
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v.HiddenSize <= 0 {
|
||||||
|
v.HiddenSize = 1280
|
||||||
|
}
|
||||||
|
if v.NumHeads <= 0 {
|
||||||
|
v.NumHeads = 16
|
||||||
|
}
|
||||||
|
if v.InChannels <= 0 {
|
||||||
|
v.InChannels = 3
|
||||||
|
}
|
||||||
|
if v.PatchSize <= 0 {
|
||||||
|
v.PatchSize = 14
|
||||||
|
}
|
||||||
|
if v.SpatialMergeSize <= 0 {
|
||||||
|
v.SpatialMergeSize = 2
|
||||||
|
}
|
||||||
|
if v.LayerNormEpsilon == 0 {
|
||||||
|
v.LayerNormEpsilon = 1e-6
|
||||||
|
}
|
||||||
|
if v.RopeTheta == 0 {
|
||||||
|
v.RopeTheta = 10000
|
||||||
|
}
|
||||||
|
if v.TemporalPatchSize <= 0 {
|
||||||
|
v.TemporalPatchSize = 2
|
||||||
|
}
|
||||||
|
if v.NumPositionEmbeddings <= 0 {
|
||||||
|
v.NumPositionEmbeddings = 2304
|
||||||
|
}
|
||||||
|
if len(v.ImageMean) < 3 {
|
||||||
|
v.ImageMean = []float32{0.5, 0.5, 0.5}
|
||||||
|
}
|
||||||
|
if len(v.ImageStd) < 3 {
|
||||||
|
v.ImageStd = []float32{0.5, 0.5, 0.5}
|
||||||
|
}
|
||||||
|
if v.Size.ShortestEdge <= 0 {
|
||||||
|
v.Size.ShortestEdge = 64 << 10
|
||||||
|
}
|
||||||
|
if v.Size.LongestEdge <= 0 {
|
||||||
|
v.Size.LongestEdge = 2 << 20
|
||||||
|
}
|
||||||
|
|
||||||
|
grid := int32(math.Sqrt(float64(v.NumPositionEmbeddings)))
|
||||||
|
if grid <= 0 {
|
||||||
|
grid = 48
|
||||||
|
}
|
||||||
|
v.GridPerSide = grid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VisionConfig) applyPreprocessorConfig(data []byte) {
|
||||||
|
if v == nil || len(data) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var pre struct {
|
||||||
|
Size struct {
|
||||||
|
ShortestEdge int32 `json:"shortest_edge"`
|
||||||
|
LongestEdge int32 `json:"longest_edge"`
|
||||||
|
} `json:"size"`
|
||||||
|
PatchSize int32 `json:"patch_size"`
|
||||||
|
TemporalPatchSize int32 `json:"temporal_patch_size"`
|
||||||
|
MergeSize int32 `json:"merge_size"`
|
||||||
|
ImageMean []float32 `json:"image_mean"`
|
||||||
|
ImageStd []float32 `json:"image_std"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &pre); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if pre.PatchSize > 0 {
|
||||||
|
v.PatchSize = pre.PatchSize
|
||||||
|
}
|
||||||
|
if pre.TemporalPatchSize > 0 {
|
||||||
|
v.TemporalPatchSize = pre.TemporalPatchSize
|
||||||
|
}
|
||||||
|
if pre.MergeSize > 0 {
|
||||||
|
v.SpatialMergeSize = pre.MergeSize
|
||||||
|
}
|
||||||
|
if pre.Size.ShortestEdge > 0 {
|
||||||
|
v.Size.ShortestEdge = pre.Size.ShortestEdge
|
||||||
|
}
|
||||||
|
if pre.Size.LongestEdge > 0 {
|
||||||
|
v.Size.LongestEdge = pre.Size.LongestEdge
|
||||||
|
}
|
||||||
|
if len(pre.ImageMean) >= 3 {
|
||||||
|
v.ImageMean = pre.ImageMean
|
||||||
|
}
|
||||||
|
if len(pre.ImageStd) >= 3 {
|
||||||
|
v.ImageStd = pre.ImageStd
|
||||||
|
}
|
||||||
|
v.applyDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionGrid tracks patch-grid dimensions for an image.
|
||||||
|
type VisionGrid struct {
|
||||||
|
Height int32
|
||||||
|
Width int32
|
||||||
|
Temporal int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionImageProcessor reproduces qwen3vl image preprocessing.
|
||||||
|
type VisionImageProcessor struct {
|
||||||
|
numChannels int32
|
||||||
|
patchSize int32
|
||||||
|
temporalPatchSize int32
|
||||||
|
mergeSize int32
|
||||||
|
shortestEdge int32
|
||||||
|
longestEdge int32
|
||||||
|
factor int32
|
||||||
|
imageMean [3]float32
|
||||||
|
imageStd [3]float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newVisionImageProcessor(cfg *VisionConfig) *VisionImageProcessor {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &VisionImageProcessor{
|
||||||
|
numChannels: cfg.InChannels,
|
||||||
|
patchSize: cfg.PatchSize,
|
||||||
|
temporalPatchSize: cfg.TemporalPatchSize,
|
||||||
|
mergeSize: cfg.SpatialMergeSize,
|
||||||
|
shortestEdge: cfg.Size.ShortestEdge,
|
||||||
|
longestEdge: cfg.Size.LongestEdge,
|
||||||
|
factor: cfg.PatchSize * cfg.SpatialMergeSize,
|
||||||
|
imageMean: [3]float32{cfg.ImageMean[0], cfg.ImageMean[1], cfg.ImageMean[2]},
|
||||||
|
imageStd: [3]float32{cfg.ImageStd[0], cfg.ImageStd[1], cfg.ImageStd[2]},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *VisionImageProcessor) smartResize(height, width int) (int, int, error) {
|
||||||
|
factor := int(p.factor)
|
||||||
|
if factor <= 0 {
|
||||||
|
return 0, 0, fmt.Errorf("invalid factor: %d", factor)
|
||||||
|
}
|
||||||
|
|
||||||
|
if height < factor || width < factor {
|
||||||
|
return 0, 0, fmt.Errorf("height (%d) or width (%d) must be >= factor (%d)", height, width, factor)
|
||||||
|
}
|
||||||
|
if min(height, width) == 0 {
|
||||||
|
return 0, 0, fmt.Errorf("invalid dimensions: %dx%d", width, height)
|
||||||
|
}
|
||||||
|
if max(height, width)/min(height, width) > 200 {
|
||||||
|
return 0, 0, fmt.Errorf("aspect ratio too large: %dx%d", width, height)
|
||||||
|
}
|
||||||
|
|
||||||
|
roundEven := func(x float64) int { return int(math.RoundToEven(x)) }
|
||||||
|
|
||||||
|
hBar := roundEven(float64(height)/float64(factor)) * factor
|
||||||
|
wBar := roundEven(float64(width)/float64(factor)) * factor
|
||||||
|
|
||||||
|
if hBar*wBar > int(p.longestEdge) {
|
||||||
|
beta := math.Sqrt(float64(height*width) / float64(p.longestEdge))
|
||||||
|
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||||
|
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||||
|
} else if hBar*wBar < int(p.shortestEdge) {
|
||||||
|
beta := math.Sqrt(float64(p.shortestEdge) / float64(height*width))
|
||||||
|
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||||
|
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||||
|
}
|
||||||
|
|
||||||
|
return hBar, wBar, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *VisionImageProcessor) ProcessImage(img image.Image) (*mlx.Array, *VisionGrid, error) {
|
||||||
|
if p == nil {
|
||||||
|
return nil, nil, errNoVisionModel
|
||||||
|
}
|
||||||
|
|
||||||
|
img = imageproc.Composite(img)
|
||||||
|
origW := img.Bounds().Dx()
|
||||||
|
origH := img.Bounds().Dy()
|
||||||
|
|
||||||
|
resizedH, resizedW, err := p.smartResize(origH, origW)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resized := imageproc.Resize(
|
||||||
|
img,
|
||||||
|
image.Point{X: resizedW, Y: resizedH},
|
||||||
|
imageproc.ResizeBilinear,
|
||||||
|
)
|
||||||
|
pixels := imageproc.Normalize(resized, p.imageMean, p.imageStd, true, true)
|
||||||
|
|
||||||
|
grid := &VisionGrid{
|
||||||
|
Height: int32(resizedH / int(p.patchSize)),
|
||||||
|
Width: int32(resizedW / int(p.patchSize)),
|
||||||
|
Temporal: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
patches := p.createPatches(pixels, resizedH, resizedW, grid)
|
||||||
|
|
||||||
|
patchDim := int(p.numChannels * p.temporalPatchSize * p.patchSize * p.patchSize)
|
||||||
|
numPatches := int(grid.Height * grid.Width)
|
||||||
|
pixelValues := mlx.FromValues(patches, numPatches, patchDim).ExpandDims(0)
|
||||||
|
return pixelValues, grid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *VisionImageProcessor) createPatches(pixels []float32, height, width int, grid *VisionGrid) []float32 {
|
||||||
|
channels := int(p.numChannels)
|
||||||
|
patchSize := int(p.patchSize)
|
||||||
|
mergeSize := int(p.mergeSize)
|
||||||
|
temporalPatchSize := int(p.temporalPatchSize)
|
||||||
|
|
||||||
|
// Temporal is always 1 for static images; only spatial patches are created.
|
||||||
|
numPatches := int(grid.Height * grid.Width)
|
||||||
|
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||||
|
result := make([]float32, numPatches*patchDim)
|
||||||
|
|
||||||
|
patchIndex := 0
|
||||||
|
for h := 0; h < int(grid.Height); h += mergeSize {
|
||||||
|
for w := 0; w < int(grid.Width); w += mergeSize {
|
||||||
|
for mh := range mergeSize {
|
||||||
|
for mw := range mergeSize {
|
||||||
|
baseOffset := patchIndex * patchDim
|
||||||
|
|
||||||
|
for c := range channels {
|
||||||
|
channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
|
||||||
|
for py := range patchSize {
|
||||||
|
for px := range patchSize {
|
||||||
|
y := (h+mh)*patchSize + py
|
||||||
|
x := (w+mw)*patchSize + px
|
||||||
|
srcIdx := c*height*width + y*width + x
|
||||||
|
dstIdx := channelOffset + py*patchSize + px
|
||||||
|
if srcIdx < len(pixels) && dstIdx < len(result) {
|
||||||
|
result[dstIdx] = pixels[srcIdx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if temporalPatchSize > 1 {
|
||||||
|
for c := range channels {
|
||||||
|
channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
|
||||||
|
frameSize := patchSize * patchSize
|
||||||
|
for tp := 1; tp < temporalPatchSize; tp++ {
|
||||||
|
cur := channelOffset + tp*frameSize
|
||||||
|
copy(result[cur:cur+frameSize], result[channelOffset:channelOffset+frameSize])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
patchIndex++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionAttention runs one self-attention block inside the vision encoder.
|
||||||
|
type VisionAttention struct {
|
||||||
|
QKV nn.LinearLayer
|
||||||
|
Query nn.LinearLayer
|
||||||
|
Key nn.LinearLayer
|
||||||
|
Value nn.LinearLayer
|
||||||
|
Output nn.LinearLayer
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyVisionRoPE(x, cos, sin *mlx.Array) *mlx.Array {
|
||||||
|
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(rotateHalf(x), sin))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *VisionAttention) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
|
||||||
|
shape := x.Dims()
|
||||||
|
if len(shape) != 3 {
|
||||||
|
return nil, fmt.Errorf("vision attention expects [B,L,D], got %v", shape)
|
||||||
|
}
|
||||||
|
B, L, hidden := int32(shape[0]), int32(shape[1]), int32(shape[2])
|
||||||
|
headDim := cfg.HiddenSize / cfg.NumHeads
|
||||||
|
if headDim <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid vision head dim: %d", headDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
var q, k, v *mlx.Array
|
||||||
|
if a.QKV != nil {
|
||||||
|
qkv := a.QKV.Forward(x)
|
||||||
|
qkv = mlx.Reshape(qkv, B, L, 3, cfg.NumHeads, headDim)
|
||||||
|
q = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, cfg.NumHeads, headDim}), 2)
|
||||||
|
k = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, cfg.NumHeads, headDim}), 2)
|
||||||
|
v = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, cfg.NumHeads, headDim}), 2)
|
||||||
|
} else {
|
||||||
|
if a.Query == nil || a.Key == nil || a.Value == nil {
|
||||||
|
return nil, errors.New("vision attention is missing q/k/v projections")
|
||||||
|
}
|
||||||
|
q = mlx.Reshape(a.Query.Forward(x), B, L, cfg.NumHeads, headDim)
|
||||||
|
k = mlx.Reshape(a.Key.Forward(x), B, L, cfg.NumHeads, headDim)
|
||||||
|
v = mlx.Reshape(a.Value.Forward(x), B, L, cfg.NumHeads, headDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
q = applyVisionRoPE(q, cos, sin)
|
||||||
|
k = applyVisionRoPE(k, cos, sin)
|
||||||
|
|
||||||
|
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||||
|
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||||
|
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||||
|
|
||||||
|
scale := float32(1.0 / math.Sqrt(float64(headDim)))
|
||||||
|
attn := mlx.ScaledDotProductAttentionCausal(q, k, v, scale, false)
|
||||||
|
attn = mlx.Reshape(mlx.Transpose(attn, 0, 2, 1, 3), B, L, hidden)
|
||||||
|
if a.Output == nil {
|
||||||
|
return nil, errors.New("vision attention is missing output projection")
|
||||||
|
}
|
||||||
|
return a.Output.Forward(attn), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionMLP is the vision feed-forward block.
|
||||||
|
type VisionMLP struct {
|
||||||
|
FC1 nn.LinearLayer
|
||||||
|
FC2 nn.LinearLayer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VisionMLP) Forward(x *mlx.Array) (*mlx.Array, error) {
|
||||||
|
if m.FC1 == nil || m.FC2 == nil {
|
||||||
|
return nil, errors.New("vision mlp is missing fc1/fc2")
|
||||||
|
}
|
||||||
|
return m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(x))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionEncoderLayer is one transformer block in the vision encoder.
|
||||||
|
type VisionEncoderLayer struct {
|
||||||
|
Norm1 *nn.LayerNorm
|
||||||
|
Attn *VisionAttention
|
||||||
|
Norm2 *nn.LayerNorm
|
||||||
|
MLP *VisionMLP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *VisionEncoderLayer) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
|
||||||
|
if l.Norm1 == nil || l.Norm2 == nil || l.Attn == nil || l.MLP == nil {
|
||||||
|
return nil, errors.New("vision layer is incomplete")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := x
|
||||||
|
a, err := l.Attn.Forward(l.Norm1.Forward(x), cos, sin, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x = mlx.Add(r, a)
|
||||||
|
|
||||||
|
r = x
|
||||||
|
m, err := l.MLP.Forward(l.Norm2.Forward(x))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return mlx.Add(r, m), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionPatchMerger projects merged spatial groups into language embedding space.
|
||||||
|
type VisionPatchMerger struct {
|
||||||
|
Norm *nn.LayerNorm
|
||||||
|
FC1 nn.LinearLayer
|
||||||
|
FC2 nn.LinearLayer
|
||||||
|
}
|
||||||
|
|
||||||
|
func groupMergedTokens(x *mlx.Array, merge int32) (*mlx.Array, error) {
|
||||||
|
shape := x.Dims()
|
||||||
|
if len(shape) != 3 {
|
||||||
|
return nil, fmt.Errorf("expected [B,L,D], got %v", shape)
|
||||||
|
}
|
||||||
|
if merge <= 0 {
|
||||||
|
merge = 1
|
||||||
|
}
|
||||||
|
B, L, D := int32(shape[0]), int32(shape[1]), int32(shape[2])
|
||||||
|
group := merge * merge
|
||||||
|
if group <= 0 || L%group != 0 {
|
||||||
|
return nil, fmt.Errorf("invalid merge layout: L=%d merge=%d", L, merge)
|
||||||
|
}
|
||||||
|
|
||||||
|
x = mlx.Reshape(x, B, L/group, group, D)
|
||||||
|
x = mlx.Reshape(x, B, L/group, group*D)
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VisionPatchMerger) Forward(x *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
|
||||||
|
if m == nil || m.Norm == nil || m.FC1 == nil || m.FC2 == nil {
|
||||||
|
return nil, errors.New("vision patch merger is incomplete")
|
||||||
|
}
|
||||||
|
|
||||||
|
x = m.Norm.Forward(x)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
x, err = groupMergedTokens(x, cfg.SpatialMergeSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
x = m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(x)))
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionModel contains the full Qwen vision tower.
|
||||||
|
type VisionModel struct {
|
||||||
|
PatchProjection nn.LinearLayer
|
||||||
|
PositionEmbed *nn.Embedding
|
||||||
|
Layers []*VisionEncoderLayer
|
||||||
|
PatchMerger *VisionPatchMerger
|
||||||
|
|
||||||
|
cfg *VisionConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergedPatchCoordinates(grid *VisionGrid, merge int32) [][2]int32 {
|
||||||
|
if merge <= 0 {
|
||||||
|
merge = 1
|
||||||
|
}
|
||||||
|
// Temporal is always 1 for static images; only spatial coordinates are generated.
|
||||||
|
coords := make([][2]int32, 0, grid.Height*grid.Width)
|
||||||
|
for h := int32(0); h < grid.Height; h += merge {
|
||||||
|
for w := int32(0); w < grid.Width; w += merge {
|
||||||
|
for mh := range merge {
|
||||||
|
for mw := range merge {
|
||||||
|
coords = append(coords, [2]int32{h + mh, w + mw})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return coords
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VisionModel) addPositionEmbedding(x *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
|
||||||
|
if m.PositionEmbed == nil {
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
shape := x.Dims()
|
||||||
|
if len(shape) != 3 {
|
||||||
|
return nil, fmt.Errorf("vision embeddings expect [B,L,D], got %v", shape)
|
||||||
|
}
|
||||||
|
B, D := int32(shape[0]), int32(shape[2])
|
||||||
|
coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
|
||||||
|
L := int32(len(coords))
|
||||||
|
if L != int32(shape[1]) {
|
||||||
|
return nil, fmt.Errorf("vision sequence mismatch: hidden L=%d coords=%d", shape[1], L)
|
||||||
|
}
|
||||||
|
|
||||||
|
stepH := float32(0)
|
||||||
|
if grid.Height > 1 {
|
||||||
|
stepH = float32(m.cfg.GridPerSide-1) / float32(grid.Height-1)
|
||||||
|
}
|
||||||
|
stepW := float32(0)
|
||||||
|
if grid.Width > 1 {
|
||||||
|
stepW = float32(m.cfg.GridPerSide-1) / float32(grid.Width-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
indices := make([]int32, 0, L*4)
|
||||||
|
weights := make([]float32, 0, L*4)
|
||||||
|
for _, c := range coords {
|
||||||
|
y := float32(c[0]) * stepH
|
||||||
|
x0 := float32(c[1]) * stepW
|
||||||
|
|
||||||
|
fy := int32(y)
|
||||||
|
fx := int32(x0)
|
||||||
|
cy := min(fy+1, m.cfg.GridPerSide-1)
|
||||||
|
cx := min(fx+1, m.cfg.GridPerSide-1)
|
||||||
|
|
||||||
|
indices = append(indices,
|
||||||
|
fy*m.cfg.GridPerSide+fx,
|
||||||
|
fy*m.cfg.GridPerSide+cx,
|
||||||
|
cy*m.cfg.GridPerSide+fx,
|
||||||
|
cy*m.cfg.GridPerSide+cx,
|
||||||
|
)
|
||||||
|
|
||||||
|
dy := y - float32(fy)
|
||||||
|
dx := x0 - float32(fx)
|
||||||
|
weights = append(weights,
|
||||||
|
(1-dy)*(1-dx),
|
||||||
|
(1-dy)*dx,
|
||||||
|
dy*(1-dx),
|
||||||
|
dy*dx,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
idxArr := mlx.FromValues(indices, int(L), 4)
|
||||||
|
wArr := mlx.FromValues(weights, int(L), 4, 1)
|
||||||
|
|
||||||
|
pos := m.PositionEmbed.Forward(idxArr)
|
||||||
|
wArr = wArr.AsType(pos.DType())
|
||||||
|
pos = mlx.Sum(mlx.Mul(pos, wArr), 1, false)
|
||||||
|
if D != int32(pos.Dim(1)) {
|
||||||
|
return nil, fmt.Errorf("position embedding dim mismatch: hidden=%d pos=%d", D, pos.Dim(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
pos = mlx.ExpandDims(pos, 0)
|
||||||
|
if B > 1 {
|
||||||
|
pos = mlx.Tile(pos, []int32{B, 1, 1})
|
||||||
|
}
|
||||||
|
|
||||||
|
return mlx.Add(x, pos), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VisionModel) rotaryEmbeddings(grid *VisionGrid) (*mlx.Array, *mlx.Array, error) {
|
||||||
|
headDim := m.cfg.HiddenSize / m.cfg.NumHeads
|
||||||
|
if headDim <= 0 {
|
||||||
|
return nil, nil, fmt.Errorf("invalid vision head dim: %d", headDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
|
||||||
|
L := int32(len(coords))
|
||||||
|
half := headDim / 2
|
||||||
|
quarter := half / 2
|
||||||
|
if quarter <= 0 {
|
||||||
|
return nil, nil, fmt.Errorf("invalid vision rotary layout: head_dim=%d", headDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
angles := make([]float32, L*headDim)
|
||||||
|
for i, c := range coords {
|
||||||
|
base := int32(i) * headDim
|
||||||
|
for j := range quarter {
|
||||||
|
freq := 1.0 / math.Pow(float64(m.cfg.RopeTheta), float64(2*j)/float64(half))
|
||||||
|
angles[base+j] = float32(float64(c[0]) * freq)
|
||||||
|
angles[base+quarter+j] = float32(float64(c[1]) * freq)
|
||||||
|
}
|
||||||
|
for j := range half {
|
||||||
|
angles[base+half+j] = angles[base+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
arr := mlx.FromValues(angles, int(L), int(headDim))
|
||||||
|
cos := mlx.ExpandDims(mlx.ExpandDims(mlx.Cos(arr), 0), 2)
|
||||||
|
sin := mlx.ExpandDims(mlx.ExpandDims(mlx.Sin(arr), 0), 2)
|
||||||
|
return cos, sin, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VisionModel) Forward(pixelValues *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
|
||||||
|
if m == nil || pixelValues == nil || grid == nil {
|
||||||
|
return nil, errNoVisionModel
|
||||||
|
}
|
||||||
|
if m.PatchProjection == nil || m.PatchMerger == nil {
|
||||||
|
return nil, errors.New("vision model is missing required projections")
|
||||||
|
}
|
||||||
|
|
||||||
|
x := m.PatchProjection.Forward(pixelValues)
|
||||||
|
var err error
|
||||||
|
x, err = m.addPositionEmbedding(x, grid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cos, sin, err := m.rotaryEmbeddings(grid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
x, err = layer.Forward(x, cos, sin, m.cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("vision layer %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
main, err := m.PatchMerger.Forward(x, m.cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("vision patch merger: %w", err)
|
||||||
|
}
|
||||||
|
return main, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type VisionEmbeddings struct {
|
||||||
|
Main *mlx.Array
|
||||||
|
Grid *VisionGrid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) EncodeVisionImage(multimodalData []byte) (*VisionEmbeddings, error) {
|
||||||
|
if m == nil || m.Vision == nil || m.ImageProcessor == nil {
|
||||||
|
return nil, errNoVisionModel
|
||||||
|
}
|
||||||
|
|
||||||
|
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pixelValues, grid, err := m.ImageProcessor.ProcessImage(img)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
main, err := m.Vision.Forward(pixelValues, grid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &VisionEmbeddings{Main: main, Grid: grid}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveVisionPrefix(tensors map[string]*mlx.Array, weightPrefix string) string {
|
||||||
|
candidates := []string{
|
||||||
|
"vision_tower",
|
||||||
|
weightPrefix + "vision_tower",
|
||||||
|
"model.visual",
|
||||||
|
"visual",
|
||||||
|
weightPrefix + "model.visual",
|
||||||
|
weightPrefix + "visual",
|
||||||
|
}
|
||||||
|
|
||||||
|
hasTensor := func(prefix string) bool {
|
||||||
|
for _, suffix := range []string{
|
||||||
|
".patch_embed.proj.weight",
|
||||||
|
".patch_embed.weight",
|
||||||
|
".pos_embed.weight",
|
||||||
|
".blocks.0.attn.qkv.weight",
|
||||||
|
".merger.linear_fc1.weight",
|
||||||
|
".merger.mlp.0.weight",
|
||||||
|
} {
|
||||||
|
if tensors[prefix+suffix] != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range candidates {
|
||||||
|
if hasTensor(prefix) {
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstLinear(linears mlxmodel.LinearFactory, paths ...string) nn.LinearLayer {
|
||||||
|
for _, p := range paths {
|
||||||
|
if l := linears.Make(p); l != nil {
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadLayerNorm(tensors map[string]*mlx.Array, eps float32, bases ...string) *nn.LayerNorm {
|
||||||
|
for _, base := range bases {
|
||||||
|
if w := tensors[base+".weight"]; w != nil {
|
||||||
|
return &nn.LayerNorm{Weight: w, Bias: tensors[base+".bias"], Eps: eps}
|
||||||
|
}
|
||||||
|
if w := tensors[base]; w != nil {
|
||||||
|
return &nn.LayerNorm{Weight: w, Bias: tensors[base+"_bias"], Eps: eps}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadVisionPatchMerger(
|
||||||
|
tensors map[string]*mlx.Array,
|
||||||
|
linears mlxmodel.LinearFactory,
|
||||||
|
eps float32,
|
||||||
|
bases ...string,
|
||||||
|
) *VisionPatchMerger {
|
||||||
|
for _, base := range bases {
|
||||||
|
norm := loadLayerNorm(tensors, eps, base+".norm", base+".ln_q")
|
||||||
|
fc1 := firstLinear(linears, base+".linear_fc1", base+".mlp.0")
|
||||||
|
fc2 := firstLinear(linears, base+".linear_fc2", base+".mlp.2")
|
||||||
|
if norm != nil && fc1 != nil && fc2 != nil {
|
||||||
|
return &VisionPatchMerger{Norm: norm, FC1: fc1, FC2: fc2}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flattenPatchEmbeddingWeight(w *mlx.Array) (*mlx.Array, error) {
|
||||||
|
if w == nil || !w.Valid() {
|
||||||
|
return nil, errors.New("missing patch embedding weight")
|
||||||
|
}
|
||||||
|
if w.NumDims() < 2 {
|
||||||
|
return nil, fmt.Errorf("patch embedding weight must be >=2D, got %dD", w.NumDims())
|
||||||
|
}
|
||||||
|
if w.NumDims() == 2 {
|
||||||
|
return w, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := int32(w.Dim(0))
|
||||||
|
in := int32(w.Size() / w.Dim(0))
|
||||||
|
return mlx.Reshape(w, out, in), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadVisionComponents(
|
||||||
|
tensors map[string]*mlx.Array,
|
||||||
|
linears mlxmodel.LinearFactory,
|
||||||
|
cfg *Config,
|
||||||
|
weightPrefix string,
|
||||||
|
) (*VisionModel, *VisionImageProcessor, error) {
|
||||||
|
if cfg == nil || cfg.Vision == nil || cfg.Vision.Depth <= 0 {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
cfg.Vision.applyDefaults()
|
||||||
|
|
||||||
|
visionPrefix := resolveVisionPrefix(tensors, weightPrefix)
|
||||||
|
if visionPrefix == "" {
|
||||||
|
return nil, nil, errors.New("vision enabled in config but vision tensors were not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
patchW, _ := tensorAny(
|
||||||
|
tensors,
|
||||||
|
visionPrefix+".patch_embed.proj.weight",
|
||||||
|
visionPrefix+".patch_embed.weight",
|
||||||
|
)
|
||||||
|
if patchW == nil {
|
||||||
|
return nil, nil, fmt.Errorf("missing vision patch embedding weight under %s", visionPrefix)
|
||||||
|
}
|
||||||
|
patchW, err := flattenPatchEmbeddingWeight(patchW)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
patchB, _ := tensorAny(
|
||||||
|
tensors,
|
||||||
|
visionPrefix+".patch_embed.proj.bias",
|
||||||
|
visionPrefix+".patch_embed.bias",
|
||||||
|
)
|
||||||
|
|
||||||
|
patchProj := nn.NewLinear(patchW, patchB)
|
||||||
|
if got := int32(patchW.Dim(1)); got != cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize {
|
||||||
|
return nil, nil, fmt.Errorf(
|
||||||
|
"vision patch embedding input dim mismatch: got %d expected %d",
|
||||||
|
got,
|
||||||
|
cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
posW, _ := tensorAny(
|
||||||
|
tensors,
|
||||||
|
visionPrefix+".pos_embed.weight",
|
||||||
|
visionPrefix+".position_embedding.weight",
|
||||||
|
)
|
||||||
|
if posW == nil {
|
||||||
|
return nil, nil, fmt.Errorf("missing vision position embedding under %s", visionPrefix)
|
||||||
|
}
|
||||||
|
cfg.Vision.NumPositionEmbeddings = int32(posW.Dim(0))
|
||||||
|
cfg.Vision.applyDefaults()
|
||||||
|
|
||||||
|
vm := &VisionModel{
|
||||||
|
PatchProjection: patchProj,
|
||||||
|
PositionEmbed: nn.NewEmbedding(posW),
|
||||||
|
Layers: make([]*VisionEncoderLayer, cfg.Vision.Depth),
|
||||||
|
cfg: cfg.Vision,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range cfg.Vision.Depth {
|
||||||
|
layerPrefix := fmt.Sprintf("%s.blocks.%d", visionPrefix, i)
|
||||||
|
layer := &VisionEncoderLayer{
|
||||||
|
Norm1: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm1"),
|
||||||
|
Norm2: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm2"),
|
||||||
|
Attn: &VisionAttention{
|
||||||
|
QKV: firstLinear(
|
||||||
|
linears,
|
||||||
|
layerPrefix+".attn.qkv",
|
||||||
|
layerPrefix+".attn_qkv",
|
||||||
|
),
|
||||||
|
Query: firstLinear(
|
||||||
|
linears,
|
||||||
|
layerPrefix+".attn.q_proj",
|
||||||
|
layerPrefix+".attn_q",
|
||||||
|
),
|
||||||
|
Key: firstLinear(
|
||||||
|
linears,
|
||||||
|
layerPrefix+".attn.k_proj",
|
||||||
|
layerPrefix+".attn_k",
|
||||||
|
),
|
||||||
|
Value: firstLinear(
|
||||||
|
linears,
|
||||||
|
layerPrefix+".attn.v_proj",
|
||||||
|
layerPrefix+".attn_v",
|
||||||
|
),
|
||||||
|
Output: firstLinear(
|
||||||
|
linears,
|
||||||
|
layerPrefix+".attn.proj",
|
||||||
|
layerPrefix+".attn_out",
|
||||||
|
layerPrefix+".attn.o_proj",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
MLP: &VisionMLP{
|
||||||
|
FC1: firstLinear(
|
||||||
|
linears,
|
||||||
|
layerPrefix+".mlp.fc1",
|
||||||
|
layerPrefix+".mlp.linear_fc1",
|
||||||
|
),
|
||||||
|
FC2: firstLinear(
|
||||||
|
linears,
|
||||||
|
layerPrefix+".mlp.fc2",
|
||||||
|
layerPrefix+".mlp.linear_fc2",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if layer.Norm1 == nil || layer.Norm2 == nil {
|
||||||
|
return nil, nil, fmt.Errorf("vision layer %d: missing norm1/norm2", i)
|
||||||
|
}
|
||||||
|
if layer.Attn.Output == nil || (layer.Attn.QKV == nil && (layer.Attn.Query == nil || layer.Attn.Key == nil || layer.Attn.Value == nil)) {
|
||||||
|
return nil, nil, fmt.Errorf("vision layer %d: missing attention projections", i)
|
||||||
|
}
|
||||||
|
if layer.MLP.FC1 == nil || layer.MLP.FC2 == nil {
|
||||||
|
return nil, nil, fmt.Errorf("vision layer %d: missing mlp projections", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
vm.Layers[i] = layer
|
||||||
|
}
|
||||||
|
|
||||||
|
vm.PatchMerger = loadVisionPatchMerger(
|
||||||
|
tensors,
|
||||||
|
linears,
|
||||||
|
cfg.Vision.LayerNormEpsilon,
|
||||||
|
visionPrefix+".merger",
|
||||||
|
)
|
||||||
|
if vm.PatchMerger == nil {
|
||||||
|
return nil, nil, fmt.Errorf("missing vision patch merger under %s", visionPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
return vm, newVisionImageProcessor(cfg.Vision), nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user