// Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models.
package qwen3

import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
)

// Config holds the Qwen3 text encoder configuration.
type Config struct {
	HiddenSize        int32   `json:"hidden_size"`
	NumHiddenLayers   int32   `json:"num_hidden_layers"`
	IntermediateSize  int32   `json:"intermediate_size"`
	NumAttentionHeads int32   `json:"num_attention_heads"`
	NumKeyValueHeads  int32   `json:"num_key_value_heads"`
	VocabSize         int32   `json:"vocab_size"`
	RMSNormEps        float32 `json:"rms_norm_eps"`
	RopeTheta         float32 `json:"rope_theta"`
	HeadDim           int32   `json:"head_dim"`
}
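
// For illustration, a config of this shape is decoded from the model's
// config.json. The values below are hypothetical and only show the expected
// relationships (grouped-query attention requires NumAttentionHeads to be a
// multiple of NumKeyValueHeads; HeadDim need not equal
// HiddenSize/NumAttentionHeads):
//
//	cfg := Config{
//		HiddenSize:        1024,
//		NumHiddenLayers:   28,
//		NumAttentionHeads: 16,
//		NumKeyValueHeads:  8, // 16/8 = 2x KV head repetition in repeatKV
//		HeadDim:           128,
//		RMSNormEps:        1e-6,
//		RopeTheta:         1e6,
//	}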

// Attention implements Qwen3 attention with per-head Q/K RMS norms.
type Attention struct {
	QProj nn.LinearLayer `weight:"q_proj"`
	KProj nn.LinearLayer `weight:"k_proj"`
	VProj nn.LinearLayer `weight:"v_proj"`
	OProj nn.LinearLayer `weight:"o_proj"`
	QNorm *nn.RMSNorm    `weight:"q_norm"`
	KNorm *nn.RMSNorm    `weight:"k_norm"`

	// Computed fields, filled in by initComputedFields after loading.
	NHeads    int32
	NKVHeads  int32
	HeadDim   int32
	Scale     float32
	RopeTheta float32
}

// applyRoPEQwen3 applies rotate-half RoPE for the Qwen3 text encoder.
// x has shape [B, L, H, D]; the head dimension is split into halves
// (x1, x2) which are rotated by position-dependent angles:
//
//	out = concat(x1*cos - x2*sin, x1*sin + x2*cos)
//
// with frequencies freqs[i] = theta^(-i/half) for i in [0, half).
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	H := shape[2]
	D := shape[3]
	half := D / 2

	// freqs[i] = exp(-log(theta) * i / half) = theta^(-i/half)
	freqsArr := make([]float32, half)
	logTheta := float32(math.Log(float64(theta)))
	for i := int32(0); i < half; i++ {
		freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
	}
	freqs := mlx.NewArray(freqsArr, []int32{half})

	posArr := make([]float32, seqLen)
	for i := int32(0); i < seqLen; i++ {
		posArr[i] = float32(i)
	}
	pos := mlx.NewArray(posArr, []int32{seqLen})

	// Outer product: args[p, i] = p * freqs[i], shape [L, half].
	posExpanded := mlx.Reshape(pos, seqLen, 1)
	freqsExpanded := mlx.Reshape(freqs, 1, half)
	args := mlx.Mul(posExpanded, freqsExpanded)

	// Reshape to [L, 1, half] so cos/sin broadcast over batch and heads.
	cosVals := mlx.Cos(args)
	sinVals := mlx.Sin(args)
	cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
	sinVals = mlx.Reshape(sinVals, seqLen, 1, half)

	// Split the head dimension into two halves.
	x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
	x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})

	// Apply the 2D rotation to each (x1[i], x2[i]) pair.
	part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
	part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))

	return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
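
// As a quick sanity check of the rotation above: for a single pair at
// position p with frequency f, the map is the 2D rotation
//
//	(x1, x2) -> (x1*cos(p*f) - x2*sin(p*f), x1*sin(p*f) + x2*cos(p*f))
//
// so position 0 (cos=1, sin=0) leaves the input unchanged, and relative
// positions translate into relative rotation angles in the Q·K dot product.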

// Forward computes attention with causal masking and an optional padding mask.
func (attn *Attention) Forward(x *mlx.Array, mask *mlx.Array, maskMode string) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]

	q := attn.QProj.Forward(x)
	k := attn.KProj.Forward(x)
	v := attn.VProj.Forward(x)

	// [B, L, H*D] -> [B, L, H, D]
	q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
	k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
	v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)

	// QK norm uses a hardcoded 1e-6 eps (Qwen3 specific).
	q = attn.QNorm.Forward(q, 1e-6)
	k = attn.KNorm.Forward(k, 1e-6)

	q = applyRoPEQwen3(q, L, attn.RopeTheta)
	k = applyRoPEQwen3(k, L, attn.RopeTheta)

	// [B, L, H, D] -> [B, H, L, D]
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	// Grouped-query attention: expand KV heads to match the query heads.
	if attn.NKVHeads < attn.NHeads {
		repeats := attn.NHeads / attn.NKVHeads
		k = repeatKV(k, repeats)
		v = repeatKV(v, repeats)
	}

	out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, attn.Scale, maskMode, mask, nil)

	// [B, H, L, D] -> [B, L, H*D]
	out = mlx.Transpose(out, 0, 2, 1, 3)
	out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)

	return attn.OProj.Forward(out)
}

// repeatKV repeats key/value heads for grouped-query attention.
// x has shape [B, NKVHeads, L, D]; the result is [B, NKVHeads*repeats, L, D].
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
	if repeats == 1 {
		return x
	}
	shape := x.Shape()
	x = mlx.ExpandDims(x, 2)                      // [B, NKV, 1, L, D]
	x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1}) // [B, NKV, repeats, L, D]
	return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
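
// For example, with 8 KV heads and 16 query heads, repeats is 2 and a
// [B, 8, L, D] tensor becomes [B, 16, L, D], with each KV head duplicated
// in place so each consecutive pair of query heads shares one KV head.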

// MLP implements the Qwen3 SwiGLU feed-forward network.
type MLP struct {
	GateProj nn.LinearLayer `weight:"gate_proj"`
	UpProj   nn.LinearLayer `weight:"up_proj"`
	DownProj nn.LinearLayer `weight:"down_proj"`
}

// Forward applies the MLP: down(SiLU(gate(x)) * up(x)).
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
	gate := m.GateProj.Forward(x)
	gate = mlx.SiLU(gate)
	up := m.UpProj.Forward(x)
	h := mlx.Mul(gate, up)
	return m.DownProj.Forward(h)
}
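
// In equation form, with W_g, W_u, W_d the three projections:
//
//	MLP(x) = W_d * (SiLU(W_g*x) ⊙ (W_u*x))
//
// where ⊙ is elementwise multiplication and SiLU(t) = t * sigmoid(t).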

// Block is a single Qwen3 transformer block.
type Block struct {
	Attention         *Attention  `weight:"self_attn"`
	MLP               *MLP        `weight:"mlp"`
	InputLayerNorm    *nn.RMSNorm `weight:"input_layernorm"`
	PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}

// Forward applies the block with pre-norm residual connections.
func (qb *Block) Forward(x *mlx.Array, eps float32, mask *mlx.Array, maskMode string) *mlx.Array {
	h := qb.InputLayerNorm.Forward(x, eps)
	attnOut := qb.Attention.Forward(h, mask, maskMode)
	x = mlx.Add(x, attnOut)

	h = qb.PostAttnLayerNorm.Forward(x, eps)
	mlpOut := qb.MLP.Forward(h)
	x = mlx.Add(x, mlpOut)

	return x
}

// TextEncoder is the full Qwen3 encoder.
type TextEncoder struct {
	EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
	Layers      []*Block      `weight:"model.layers"`
	FinalNorm   *nn.RMSNorm   `weight:"model.norm"`

	*Config
}

// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(modelManifest *manifest.ModelManifest, configPath string) error {
	fmt.Print(" Loading text encoder... ")

	// Load config from blob.
	var cfg Config
	if err := modelManifest.ReadConfigJSON(configPath, &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = &cfg
	m.Layers = make([]*Block, cfg.NumHiddenLayers)

	// Load weights from tensor blobs.
	weights, err := manifest.LoadWeightsFromManifest(modelManifest, "text_encoder")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(0); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	return m.loadWeights(weights)
}

// loadWeights loads weights from any WeightSource into the model.
func (m *TextEncoder) loadWeights(weights safetensors.WeightSource) error {
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}
	m.initComputedFields()
	fmt.Println("✓")
	return nil
}

// initComputedFields initializes derived fields after the weights are loaded.
func (m *TextEncoder) initComputedFields() {
	cfg := m.Config
	m.FinalNorm.Eps = cfg.RMSNormEps
	for _, block := range m.Layers {
		// Attention
		block.Attention.NHeads = cfg.NumAttentionHeads
		block.Attention.NKVHeads = cfg.NumKeyValueHeads
		block.Attention.HeadDim = cfg.HeadDim
		block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
		block.Attention.RopeTheta = cfg.RopeTheta
		block.Attention.QNorm.Eps = cfg.RMSNormEps
		block.Attention.KNorm.Eps = cfg.RMSNormEps
		// Block norms
		block.InputLayerNorm.Eps = cfg.RMSNormEps
		block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
	}
}

// Forward encodes text tokens with the provided attention mask (LxL) and mask mode.
func (te *TextEncoder) Forward(tokens *mlx.Array, attnMask *mlx.Array, maskMode string) *mlx.Array {
	h := te.EmbedTokens.Forward(tokens)
	eps := te.RMSNormEps

	for _, layer := range te.Layers {
		h = layer.Forward(h, eps, attnMask, maskMode)
	}

	// Apply the final RMS norm.
	return te.FinalNorm.Forward(h, eps)
}

// ForwardWithLayerOutputs encodes text tokens and returns hidden states from the
// specified layers. This is used by Flux2, which needs embeddings from specific
// intermediate layers.
func (te *TextEncoder) ForwardWithLayerOutputs(tokens *mlx.Array, layerIndices []int, attnMask *mlx.Array, maskMode string) []*mlx.Array {
	h := te.EmbedTokens.Forward(tokens)
	eps := te.RMSNormEps

	// Map each requested layer index to its slot in the output slice.
	outputs := make([]*mlx.Array, len(layerIndices))
	layerSet := make(map[int]int)
	for i, idx := range layerIndices {
		layerSet[idx] = i
	}

	for i, layer := range te.Layers {
		h = layer.Forward(h, eps, attnMask, maskMode)
		if outIdx, ok := layerSet[i]; ok {
			outputs[outIdx] = h
		}
	}

	return outputs
}

// ApplyChatTemplate wraps a prompt in the Qwen3 chat format. If think is true,
// an empty <think></think> block is appended after the assistant tag, matching
// tokenizer.apply_chat_template with enable_thinking=False in Python.
func ApplyChatTemplate(prompt string, think bool) string {
	base := "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
	if think {
		return base + "<think>\n\n</think>\n\n"
	}
	return base
}
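
// For example, ApplyChatTemplate("a red fox", true) produces:
//
//	<|im_start|>user
//	a red fox<|im_end|>
//	<|im_start|>assistant
//	<think>
//
//	</think>
//
// followed by a trailing newline, leaving the model positioned to "answer"
// with its encoding of the prompt.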

// EncodePrompt encodes a text prompt using the tokenizer and encoder. If think
// is true, the empty <think></think> block is included in the chat template.
// It returns the final-layer embeddings [1, maxLen, hidden] and a [1, maxLen]
// padding mask (1 for real tokens, 0 for PAD positions).
func (te *TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int, think bool) (*mlx.Array, *mlx.Array) {
	formattedPrompt := ApplyChatTemplate(prompt, think)

	tokens := tok.Encode(formattedPrompt, false)

	if len(tokens) > maxLen {
		tokens = tokens[:maxLen]
	}

	// Padding mask: 1 for real tokens, 0 for PAD positions.
	maskData := make([]float32, maxLen)
	for i := 0; i < len(tokens); i++ {
		maskData[i] = 1.0
	}

	// Get the PAD token (distinct from EOS for Qwen3).
	padToken := tok.PAD()
	if padToken < 0 {
		padToken = tok.EOS() // fallback
	}

	paddedTokens := make([]int32, maxLen)
	copy(paddedTokens, tokens)
	for i := len(tokens); i < maxLen; i++ {
		paddedTokens[i] = padToken
	}

	tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
	maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})

	// Build a combined causal + PAD mask [L, L]:
	// mask[i,j] = 0 if (j <= i AND j is a real token), else -inf.
	L := int32(maxLen)
	validLen := int32(len(tokens))
	combinedMaskData := make([]float32, L*L)
	negInf := float32(-1e9) // large negative stand-in for -inf
	for i := int32(0); i < L; i++ {
		for j := int32(0); j < L; j++ {
			idx := i*L + j
			if j <= i && j < validLen {
				combinedMaskData[idx] = 0 // allowed: causal OK and not PAD
			} else {
				combinedMaskData[idx] = negInf // blocked: future or PAD
			}
		}
	}
	maskMat := mlx.NewArray(combinedMaskData, []int32{L, L})

	embeddings := te.Forward(tokensArr, maskMat, "")

	return embeddings, maskArr
}
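
// To make the mask concrete: with maxLen=4 and two real tokens (validLen=2),
// the [4, 4] additive mask is (0 = attend, -inf = blocked)
//
//	row 0:   0  -inf  -inf  -inf
//	row 1:   0    0   -inf  -inf
//	row 2:   0    0   -inf  -inf
//	row 3:   0    0   -inf  -inf
//
// so every query position, including the PAD rows, attends only to the real
// prefix. A hypothetical caller (the variable names are assumptions, not
// part of this package):
//
//	emb, mask := enc.EncodePrompt(tok, "a red fox in the snow", 256, false)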

// EncodePromptWithLayers encodes a text prompt and returns embeddings from the
// specified layers, concatenated along the hidden dimension. Used by Flux2,
// which concatenates embeddings from multiple intermediate layers. If think is
// true, the empty <think></think> block is included in the chat template.
// It returns the embeddings and the padded sequence length.
func (te *TextEncoder) EncodePromptWithLayers(tok *tokenizer.Tokenizer, prompt string, maxLen int, layerIndices []int, think bool) (*mlx.Array, int32) {
	formattedPrompt := ApplyChatTemplate(prompt, think)
	tokens := tok.Encode(formattedPrompt, false)

	if len(tokens) > maxLen {
		tokens = tokens[:maxLen]
	}

	// Pad to maxLen.
	padToken := tok.PAD()
	if padToken < 0 {
		padToken = tok.EOS() // fallback
	}
	padded := make([]int32, maxLen)
	copy(padded, tokens)
	for i := len(tokens); i < maxLen; i++ {
		padded[i] = padToken
	}
	tokensArr := mlx.NewArrayInt32(padded, []int32{1, int32(maxLen)})

	// Build a combined causal + PAD mask [L, L], as in EncodePrompt:
	// mask[i,j] = 0 if (j <= i AND j is a real token), else -inf.
	L := int32(maxLen)
	validLen := int32(len(tokens))
	maskData := make([]float32, L*L)
	negInf := float32(-1e9)
	for i := int32(0); i < L; i++ {
		for j := int32(0); j < L; j++ {
			idx := i*L + j
			if j <= i && j < validLen {
				maskData[idx] = 0 // allowed: causal OK and not PAD
			} else {
				maskData[idx] = negInf // blocked: future or PAD
			}
		}
	}
	maskMat := mlx.NewArray(maskData, []int32{L, L})

	layerOutputs := te.ForwardWithLayerOutputs(tokensArr, layerIndices, maskMat, "")

	// Concatenate layer outputs along the hidden dimension:
	// each output is [B, L, hidden_dim]; the result is [B, L, len(layerIndices)*hidden_dim].
	embeddings := mlx.Concatenate(layerOutputs, 2)

	return embeddings, int32(maxLen)
}
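
// A hypothetical caller (the layer indices here are placeholders for
// illustration, not the ones Flux2 actually uses):
//
//	emb, seqLen := enc.EncodePromptWithLayers(tok, "a red fox", 512, []int{8, 17, 26}, true)
//	// emb: [1, 512, 3*hidden], seqLen: 512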