mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* prefer rocm v6 on windows Avoid building with v7 - more changes are needed * MLX: add header vendoring and remove go build tag This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified. * ci: harden for flaky choco repo servers CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without. * review comments
561 lines
19 KiB
Go
561 lines
19 KiB
Go
package flux2
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/manifest"
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
"github.com/ollama/ollama/x/imagegen/nn"
|
|
"github.com/ollama/ollama/x/imagegen/safetensors"
|
|
)
|
|
|
|
// TransformerConfig mirrors transformer/config.json for a Flux2 transformer.
// The trailing comments record example values for the Klein-sized model;
// they are not defaults applied by this code.
type TransformerConfig struct {
	AttentionHeadDim         int32   `json:"attention_head_dim"`         // e.g. 128
	AxesDimsRoPE             []int32 `json:"axes_dims_rope"`             // e.g. [32, 32, 32, 32]
	Eps                      float32 `json:"eps"`                        // e.g. 1e-6
	GuidanceEmbeds           bool    `json:"guidance_embeds"`            // false for Klein
	InChannels               int32   `json:"in_channels"`                // e.g. 128
	JointAttentionDim        int32   `json:"joint_attention_dim"`        // e.g. 7680
	MLPRatio                 float32 `json:"mlp_ratio"`                  // e.g. 3.0
	NumAttentionHeads        int32   `json:"num_attention_heads"`        // e.g. 24
	NumLayers                int32   `json:"num_layers"`                 // dual-stream blocks, e.g. 5
	NumSingleLayers          int32   `json:"num_single_layers"`          // single-stream blocks, e.g. 20
	PatchSize                int32   `json:"patch_size"`                 // e.g. 1
	RopeTheta                int32   `json:"rope_theta"`                 // e.g. 2000
	TimestepGuidanceChannels int32   `json:"timestep_guidance_channels"` // e.g. 256
}

// InnerDim returns the attention width, NumAttentionHeads * AttentionHeadDim
// (e.g. 24 * 128 = 3072).
func (c *TransformerConfig) InnerDim() int32 {
	heads, headDim := c.NumAttentionHeads, c.AttentionHeadDim
	return heads * headDim
}

// MLPHiddenDim returns the MLP hidden width: InnerDim scaled by MLPRatio in
// float32, truncated toward zero (e.g. 3072 * 3.0 = 9216).
func (c *TransformerConfig) MLPHiddenDim() int32 {
	scaled := float32(c.InnerDim()) * c.MLPRatio
	return int32(scaled)
}
|
|
|
|
// TimestepEmbedder creates timestep embeddings
|
|
// Weight names: time_guidance_embed.timestep_embedder.linear_1.weight, linear_2.weight
|
|
type TimestepEmbedder struct {
|
|
Linear1 nn.LinearLayer `weight:"linear_1"`
|
|
Linear2 nn.LinearLayer `weight:"linear_2"`
|
|
EmbedDim int32 // 256
|
|
}
|
|
|
|
// Forward creates sinusoidal embeddings and projects them
|
|
func (t *TimestepEmbedder) Forward(timesteps *mlx.Array) *mlx.Array {
|
|
half := t.EmbedDim / 2
|
|
freqs := make([]float32, half)
|
|
for i := int32(0); i < half; i++ {
|
|
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
|
}
|
|
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
|
|
|
// timesteps: [B] -> [B, 1]
|
|
tExpanded := mlx.ExpandDims(timesteps, 1)
|
|
// args: [B, half]
|
|
args := mlx.Mul(tExpanded, freqsArr)
|
|
|
|
// [cos(args), sin(args)] -> [B, embed_dim]
|
|
sinEmbed := mlx.Concatenate([]*mlx.Array{mlx.Cos(args), mlx.Sin(args)}, 1)
|
|
|
|
// MLP: linear_1 -> silu -> linear_2
|
|
h := t.Linear1.Forward(sinEmbed)
|
|
h = mlx.SiLU(h)
|
|
return t.Linear2.Forward(h)
|
|
}
|
|
|
|
// TimeGuidanceEmbed wraps the timestep embedder
|
|
// Weight names: time_guidance_embed.timestep_embedder.*
|
|
type TimeGuidanceEmbed struct {
|
|
TimestepEmbedder *TimestepEmbedder `weight:"timestep_embedder"`
|
|
}
|
|
|
|
// Forward computes timestep embeddings
|
|
func (t *TimeGuidanceEmbed) Forward(timesteps *mlx.Array) *mlx.Array {
|
|
return t.TimestepEmbedder.Forward(timesteps)
|
|
}
|
|
|
|
// Modulation computes adaptive modulation parameters
|
|
// Weight names: double_stream_modulation_img.linear.weight, etc.
|
|
type Modulation struct {
|
|
Linear nn.LinearLayer `weight:"linear"`
|
|
}
|
|
|
|
// Forward computes modulation parameters
|
|
func (m *Modulation) Forward(temb *mlx.Array) *mlx.Array {
|
|
h := mlx.SiLU(temb)
|
|
return m.Linear.Forward(h)
|
|
}
|
|
|
|
// TransformerBlockAttn implements dual-stream attention
|
|
// Weight names: transformer_blocks.N.attn.*
|
|
type TransformerBlockAttn struct {
|
|
// Image stream (separate Q, K, V projections)
|
|
ToQ nn.LinearLayer `weight:"to_q"`
|
|
ToK nn.LinearLayer `weight:"to_k"`
|
|
ToV nn.LinearLayer `weight:"to_v"`
|
|
// Note: to_out has .0 suffix in weights, handled specially
|
|
ToOut0 nn.LinearLayer `weight:"to_out.0"`
|
|
|
|
// Text stream (add_ projections)
|
|
AddQProj nn.LinearLayer `weight:"add_q_proj"`
|
|
AddKProj nn.LinearLayer `weight:"add_k_proj"`
|
|
AddVProj nn.LinearLayer `weight:"add_v_proj"`
|
|
ToAddOut nn.LinearLayer `weight:"to_add_out"`
|
|
|
|
// QK norms for image stream
|
|
NormQ *mlx.Array `weight:"norm_q.weight"`
|
|
NormK *mlx.Array `weight:"norm_k.weight"`
|
|
|
|
// QK norms for text stream (added)
|
|
NormAddedQ *mlx.Array `weight:"norm_added_q.weight"`
|
|
NormAddedK *mlx.Array `weight:"norm_added_k.weight"`
|
|
}
|
|
|
|
// FeedForward implements SwiGLU MLP
|
|
// Weight names: transformer_blocks.N.ff.linear_in.weight, linear_out.weight
|
|
type FeedForward struct {
|
|
LinearIn nn.LinearLayer `weight:"linear_in"`
|
|
LinearOut nn.LinearLayer `weight:"linear_out"`
|
|
}
|
|
|
|
// Forward applies SwiGLU MLP
|
|
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
|
// LinearIn outputs 2x hidden dim for SwiGLU
|
|
h := ff.LinearIn.Forward(x)
|
|
shape := h.Shape()
|
|
half := shape[len(shape)-1] / 2
|
|
|
|
// Split into gate and up
|
|
gate := mlx.Slice(h, []int32{0, 0, 0}, []int32{shape[0], shape[1], half})
|
|
up := mlx.Slice(h, []int32{0, 0, half}, []int32{shape[0], shape[1], shape[2]})
|
|
|
|
// SwiGLU: silu(gate) * up
|
|
h = mlx.Mul(mlx.SiLU(gate), up)
|
|
return ff.LinearOut.Forward(h)
|
|
}
|
|
|
|
// TransformerBlock implements a dual-stream transformer block
|
|
// Weight names: transformer_blocks.N.*
|
|
type TransformerBlock struct {
|
|
Attn *TransformerBlockAttn `weight:"attn"`
|
|
FF *FeedForward `weight:"ff"`
|
|
FFContext *FeedForward `weight:"ff_context"`
|
|
|
|
// Config (set after loading)
|
|
NHeads int32
|
|
HeadDim int32
|
|
Scale float32
|
|
}
|
|
|
|
// Forward applies the dual-stream block
|
|
// imgHidden: [B, imgLen, dim]
|
|
// txtHidden: [B, txtLen, dim]
|
|
// imgMod, txtMod: modulation params [B, 6*dim] each
|
|
// cos, sin: RoPE values
|
|
func (block *TransformerBlock) Forward(imgHidden, txtHidden *mlx.Array, imgMod, txtMod *mlx.Array, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
imgShape := imgHidden.Shape()
|
|
B := imgShape[0]
|
|
imgLen := imgShape[1]
|
|
dim := imgShape[2]
|
|
txtLen := txtHidden.Shape()[1]
|
|
|
|
// Parse modulation: 6 params each (shift1, scale1, gate1, shift2, scale2, gate2)
|
|
imgShift1, imgScale1, imgGate1 := parseModulation3(imgMod, dim, 0)
|
|
imgShift2, imgScale2, imgGate2 := parseModulation3(imgMod, dim, 3)
|
|
txtShift1, txtScale1, txtGate1 := parseModulation3(txtMod, dim, 0)
|
|
txtShift2, txtScale2, txtGate2 := parseModulation3(txtMod, dim, 3)
|
|
|
|
// === Attention branch ===
|
|
// Modulate inputs
|
|
imgNorm := modulateLayerNorm(imgHidden, imgShift1, imgScale1)
|
|
txtNorm := modulateLayerNorm(txtHidden, txtShift1, txtScale1)
|
|
|
|
// Compute Q, K, V for image stream (separate projections)
|
|
imgQ := block.Attn.ToQ.Forward(imgNorm)
|
|
imgK := block.Attn.ToK.Forward(imgNorm)
|
|
imgV := block.Attn.ToV.Forward(imgNorm)
|
|
|
|
// Compute Q, K, V for text stream (add_ projections)
|
|
txtQ := block.Attn.AddQProj.Forward(txtNorm)
|
|
txtK := block.Attn.AddKProj.Forward(txtNorm)
|
|
txtV := block.Attn.AddVProj.Forward(txtNorm)
|
|
|
|
// Reshape for attention: [B, L, dim] -> [B, L, nheads, headDim]
|
|
imgQ = mlx.Reshape(imgQ, B, imgLen, block.NHeads, block.HeadDim)
|
|
imgK = mlx.Reshape(imgK, B, imgLen, block.NHeads, block.HeadDim)
|
|
imgV = mlx.Reshape(imgV, B, imgLen, block.NHeads, block.HeadDim)
|
|
txtQ = mlx.Reshape(txtQ, B, txtLen, block.NHeads, block.HeadDim)
|
|
txtK = mlx.Reshape(txtK, B, txtLen, block.NHeads, block.HeadDim)
|
|
txtV = mlx.Reshape(txtV, B, txtLen, block.NHeads, block.HeadDim)
|
|
|
|
// Apply QK norm (RMSNorm with learned scale)
|
|
imgQ = applyQKNorm(imgQ, block.Attn.NormQ)
|
|
imgK = applyQKNorm(imgK, block.Attn.NormK)
|
|
txtQ = applyQKNorm(txtQ, block.Attn.NormAddedQ)
|
|
txtK = applyQKNorm(txtK, block.Attn.NormAddedK)
|
|
|
|
// Concatenate for joint attention: text first, then image
|
|
q := mlx.Concatenate([]*mlx.Array{txtQ, imgQ}, 1)
|
|
k := mlx.Concatenate([]*mlx.Array{txtK, imgK}, 1)
|
|
v := mlx.Concatenate([]*mlx.Array{txtV, imgV}, 1)
|
|
|
|
// Apply RoPE
|
|
q = ApplyRoPE4D(q, cos, sin)
|
|
k = ApplyRoPE4D(k, cos, sin)
|
|
|
|
// Transpose for SDPA: [B, nheads, L, headDim]
|
|
q = mlx.Transpose(q, 0, 2, 1, 3)
|
|
k = mlx.Transpose(k, 0, 2, 1, 3)
|
|
v = mlx.Transpose(v, 0, 2, 1, 3)
|
|
|
|
// Scaled dot-product attention
|
|
out := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
|
|
|
|
// Transpose back: [B, L, nheads, headDim]
|
|
out = mlx.Transpose(out, 0, 2, 1, 3)
|
|
|
|
// Split back into txt and img
|
|
totalLen := txtLen + imgLen
|
|
txtOut := mlx.Slice(out, []int32{0, 0, 0, 0}, []int32{B, txtLen, block.NHeads, block.HeadDim})
|
|
imgOut := mlx.Slice(out, []int32{0, txtLen, 0, 0}, []int32{B, totalLen, block.NHeads, block.HeadDim})
|
|
|
|
// Reshape and project
|
|
txtOut = mlx.Reshape(txtOut, B, txtLen, dim)
|
|
imgOut = mlx.Reshape(imgOut, B, imgLen, dim)
|
|
txtOut = block.Attn.ToAddOut.Forward(txtOut)
|
|
imgOut = block.Attn.ToOut0.Forward(imgOut)
|
|
|
|
// Apply gates and residual
|
|
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate1, imgOut))
|
|
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate1, txtOut))
|
|
|
|
// === MLP branch ===
|
|
imgNorm = modulateLayerNorm(imgHidden, imgShift2, imgScale2)
|
|
txtNorm = modulateLayerNorm(txtHidden, txtShift2, txtScale2)
|
|
|
|
imgFFOut := block.FF.Forward(imgNorm)
|
|
txtFFOut := block.FFContext.Forward(txtNorm)
|
|
|
|
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate2, imgFFOut))
|
|
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate2, txtFFOut))
|
|
|
|
return imgHidden, txtHidden
|
|
}
|
|
|
|
// SingleTransformerBlockAttn implements attention for single-stream blocks
|
|
// Weight names: single_transformer_blocks.N.attn.*
|
|
type SingleTransformerBlockAttn struct {
|
|
ToQKVMlpProj nn.LinearLayer `weight:"to_qkv_mlp_proj"` // Fused QKV + MLP input
|
|
ToOut nn.LinearLayer `weight:"to_out"` // Fused attn_out + MLP out
|
|
NormQ *mlx.Array `weight:"norm_q.weight"`
|
|
NormK *mlx.Array `weight:"norm_k.weight"`
|
|
}
|
|
|
|
// SingleTransformerBlock implements a single-stream transformer block
|
|
// Weight names: single_transformer_blocks.N.*
|
|
type SingleTransformerBlock struct {
|
|
Attn *SingleTransformerBlockAttn `weight:"attn"`
|
|
|
|
// Config
|
|
NHeads int32
|
|
HeadDim int32
|
|
InnerDim int32
|
|
MLPHidDim int32
|
|
Scale float32
|
|
}
|
|
|
|
// Forward applies the single-stream block
|
|
// x: [B, L, dim] concatenated text+image
|
|
// mod: modulation [B, 3*dim]
|
|
func (block *SingleTransformerBlock) Forward(x *mlx.Array, mod *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
|
shape := x.Shape()
|
|
B := shape[0]
|
|
L := shape[1]
|
|
dim := shape[2]
|
|
|
|
// Parse modulation: (shift, scale, gate)
|
|
shift, scale, gate := parseModulation3(mod, dim, 0)
|
|
|
|
// Modulate input
|
|
h := modulateLayerNorm(x, shift, scale)
|
|
|
|
// Fused projection: QKV + MLP gate/up
|
|
// linear1 outputs: [q, k, v, mlp_gate, mlp_up] = [dim, dim, dim, mlpHid, mlpHid]
|
|
qkvMlp := block.Attn.ToQKVMlpProj.Forward(h)
|
|
|
|
// Split: first 3*dim is QKV, rest is MLP
|
|
qkvDim := 3 * block.InnerDim
|
|
qkv := mlx.Slice(qkvMlp, []int32{0, 0, 0}, []int32{B, L, qkvDim})
|
|
mlpIn := mlx.Slice(qkvMlp, []int32{0, 0, qkvDim}, []int32{B, L, qkvMlp.Shape()[2]})
|
|
|
|
// Split QKV
|
|
q, k, v := splitQKV(qkv, B, L, block.InnerDim)
|
|
|
|
// Reshape for attention
|
|
q = mlx.Reshape(q, B, L, block.NHeads, block.HeadDim)
|
|
k = mlx.Reshape(k, B, L, block.NHeads, block.HeadDim)
|
|
v = mlx.Reshape(v, B, L, block.NHeads, block.HeadDim)
|
|
|
|
// QK norm
|
|
q = applyQKNorm(q, block.Attn.NormQ)
|
|
k = applyQKNorm(k, block.Attn.NormK)
|
|
|
|
// Apply RoPE
|
|
q = ApplyRoPE4D(q, cos, sin)
|
|
k = ApplyRoPE4D(k, cos, sin)
|
|
|
|
// Transpose for SDPA
|
|
q = mlx.Transpose(q, 0, 2, 1, 3)
|
|
k = mlx.Transpose(k, 0, 2, 1, 3)
|
|
v = mlx.Transpose(v, 0, 2, 1, 3)
|
|
|
|
// SDPA
|
|
attnOut := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
|
|
|
|
// Transpose back and reshape
|
|
attnOut = mlx.Transpose(attnOut, 0, 2, 1, 3)
|
|
attnOut = mlx.Reshape(attnOut, B, L, block.InnerDim)
|
|
|
|
// MLP: SwiGLU
|
|
mlpShape := mlpIn.Shape()
|
|
half := mlpShape[2] / 2
|
|
mlpGate := mlx.Slice(mlpIn, []int32{0, 0, 0}, []int32{B, L, half})
|
|
mlpUp := mlx.Slice(mlpIn, []int32{0, 0, half}, []int32{B, L, mlpShape[2]})
|
|
mlpOut := mlx.Mul(mlx.SiLU(mlpGate), mlpUp)
|
|
|
|
// Concatenate attention and MLP for fused output
|
|
combined := mlx.Concatenate([]*mlx.Array{attnOut, mlpOut}, 2)
|
|
|
|
// Output projection
|
|
out := block.Attn.ToOut.Forward(combined)
|
|
|
|
// Apply gate and residual
|
|
return mlx.Add(x, mlx.Mul(gate, out))
|
|
}
|
|
|
|
// NormOut implements the output normalization with modulation
|
|
// Weight names: norm_out.linear.weight
|
|
type NormOut struct {
|
|
Linear nn.LinearLayer `weight:"linear"`
|
|
}
|
|
|
|
// Forward computes final modulated output
|
|
func (n *NormOut) Forward(x *mlx.Array, temb *mlx.Array) *mlx.Array {
|
|
shape := x.Shape()
|
|
B := shape[0]
|
|
dim := shape[2]
|
|
|
|
// Modulation: temb -> silu -> linear -> [shift, scale]
|
|
mod := mlx.SiLU(temb)
|
|
mod = n.Linear.Forward(mod)
|
|
|
|
// Split into scale and shift (diffusers order: scale first, shift second)
|
|
scale := mlx.Slice(mod, []int32{0, 0}, []int32{B, dim})
|
|
shift := mlx.Slice(mod, []int32{0, dim}, []int32{B, 2 * dim})
|
|
shift = mlx.ExpandDims(shift, 1)
|
|
scale = mlx.ExpandDims(scale, 1)
|
|
|
|
// Modulate with RMSNorm
|
|
return modulateLayerNorm(x, shift, scale)
|
|
}
|
|
|
|
// Flux2Transformer2DModel is the main Flux2 transformer
|
|
// Weight names at top level: time_guidance_embed.*, double_stream_modulation_*.*, etc.
|
|
type Flux2Transformer2DModel struct {
|
|
// Timestep embedding
|
|
TimeGuidanceEmbed *TimeGuidanceEmbed `weight:"time_guidance_embed"`
|
|
|
|
// Shared modulation
|
|
DoubleStreamModulationImg *Modulation `weight:"double_stream_modulation_img"`
|
|
DoubleStreamModulationTxt *Modulation `weight:"double_stream_modulation_txt"`
|
|
SingleStreamModulation *Modulation `weight:"single_stream_modulation"`
|
|
|
|
// Embedders
|
|
XEmbedder nn.LinearLayer `weight:"x_embedder"`
|
|
ContextEmbedder nn.LinearLayer `weight:"context_embedder"`
|
|
|
|
// Transformer blocks
|
|
TransformerBlocks []*TransformerBlock `weight:"transformer_blocks"`
|
|
SingleTransformerBlocks []*SingleTransformerBlock `weight:"single_transformer_blocks"`
|
|
|
|
// Output
|
|
NormOut *NormOut `weight:"norm_out"`
|
|
ProjOut nn.LinearLayer `weight:"proj_out"`
|
|
|
|
*TransformerConfig
|
|
}
|
|
|
|
// Load loads the Flux2 transformer from ollama blob storage.
|
|
func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) error {
|
|
fmt.Print(" Loading transformer... ")
|
|
|
|
// Load config from blob
|
|
var cfg TransformerConfig
|
|
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
|
return fmt.Errorf("config: %w", err)
|
|
}
|
|
m.TransformerConfig = &cfg
|
|
|
|
// Initialize slices
|
|
m.TransformerBlocks = make([]*TransformerBlock, cfg.NumLayers)
|
|
m.SingleTransformerBlocks = make([]*SingleTransformerBlock, cfg.NumSingleLayers)
|
|
|
|
// Initialize TimeGuidanceEmbed with embed dim
|
|
m.TimeGuidanceEmbed = &TimeGuidanceEmbed{
|
|
TimestepEmbedder: &TimestepEmbedder{EmbedDim: cfg.TimestepGuidanceChannels},
|
|
}
|
|
|
|
// Load weights from tensor blobs
|
|
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
|
|
if err != nil {
|
|
return fmt.Errorf("weights: %w", err)
|
|
}
|
|
if err := weights.Load(0); err != nil {
|
|
return fmt.Errorf("load weights: %w", err)
|
|
}
|
|
defer weights.ReleaseAll()
|
|
|
|
return m.loadWeights(weights)
|
|
}
|
|
|
|
// loadWeights loads weights from any WeightSource into the model
|
|
func (m *Flux2Transformer2DModel) loadWeights(weights safetensors.WeightSource) error {
|
|
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
|
return fmt.Errorf("load module: %w", err)
|
|
}
|
|
m.initComputedFields()
|
|
fmt.Println("✓")
|
|
return nil
|
|
}
|
|
|
|
// initComputedFields initializes computed fields after loading weights
|
|
func (m *Flux2Transformer2DModel) initComputedFields() {
|
|
cfg := m.TransformerConfig
|
|
innerDim := cfg.InnerDim()
|
|
scale := float32(1.0 / math.Sqrt(float64(cfg.AttentionHeadDim)))
|
|
|
|
// Initialize transformer blocks
|
|
for _, block := range m.TransformerBlocks {
|
|
block.NHeads = cfg.NumAttentionHeads
|
|
block.HeadDim = cfg.AttentionHeadDim
|
|
block.Scale = scale
|
|
}
|
|
|
|
// Initialize single transformer blocks
|
|
for _, block := range m.SingleTransformerBlocks {
|
|
block.NHeads = cfg.NumAttentionHeads
|
|
block.HeadDim = cfg.AttentionHeadDim
|
|
block.InnerDim = innerDim
|
|
block.MLPHidDim = cfg.MLPHiddenDim()
|
|
block.Scale = scale
|
|
}
|
|
}
|
|
|
|
// Forward runs the Flux2 transformer
|
|
func (m *Flux2Transformer2DModel) Forward(patches, txtEmbeds *mlx.Array, timesteps *mlx.Array, rope *RoPECache) *mlx.Array {
|
|
patchShape := patches.Shape()
|
|
B := patchShape[0]
|
|
imgLen := patchShape[1]
|
|
txtLen := txtEmbeds.Shape()[1]
|
|
|
|
// Scale timestep to 0-1000 range (diffusers multiplies by 1000)
|
|
scaledTimesteps := mlx.MulScalar(timesteps, 1000.0)
|
|
|
|
// Compute timestep embedding
|
|
temb := m.TimeGuidanceEmbed.Forward(scaledTimesteps)
|
|
|
|
// Embed patches and text
|
|
imgHidden := m.XEmbedder.Forward(patches)
|
|
txtHidden := m.ContextEmbedder.Forward(txtEmbeds)
|
|
|
|
// Compute shared modulation
|
|
imgMod := m.DoubleStreamModulationImg.Forward(temb)
|
|
txtMod := m.DoubleStreamModulationTxt.Forward(temb)
|
|
singleMod := m.SingleStreamModulation.Forward(temb)
|
|
|
|
// Double (dual-stream) blocks
|
|
for _, block := range m.TransformerBlocks {
|
|
imgHidden, txtHidden = block.Forward(imgHidden, txtHidden, imgMod, txtMod, rope.Cos, rope.Sin)
|
|
}
|
|
|
|
// Concatenate for single-stream: text first, then image
|
|
hidden := mlx.Concatenate([]*mlx.Array{txtHidden, imgHidden}, 1)
|
|
|
|
// Single-stream blocks
|
|
for _, block := range m.SingleTransformerBlocks {
|
|
hidden = block.Forward(hidden, singleMod, rope.Cos, rope.Sin)
|
|
}
|
|
|
|
// Extract image portion
|
|
totalLen := txtLen + imgLen
|
|
imgOut := mlx.Slice(hidden, []int32{0, txtLen, 0}, []int32{B, totalLen, hidden.Shape()[2]})
|
|
|
|
// Final norm and projection
|
|
imgOut = m.NormOut.Forward(imgOut, temb)
|
|
return m.ProjOut.Forward(imgOut)
|
|
}
|
|
|
|
// Note: QK normalization uses mlx.RMSNorm (the fast version) directly
|
|
// See applyQKNorm function below
|
|
|
|
// compiledSwiGLU fuses: silu(gate) * up
|
|
// Called 30x per step (10 in dual-stream + 20 in single-stream blocks)
|
|
var compiledSwiGLU *mlx.CompiledFunc
|
|
|
|
func getCompiledSwiGLU() *mlx.CompiledFunc {
|
|
if compiledSwiGLU == nil {
|
|
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
|
gate, up := inputs[0], inputs[1]
|
|
return []*mlx.Array{mlx.Mul(mlx.SiLU(gate), up)}
|
|
}, true)
|
|
}
|
|
return compiledSwiGLU
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// parseModulation3 extracts 3 modulation params (shift, scale, gate) starting at offset
|
|
func parseModulation3(mod *mlx.Array, dim int32, offset int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
|
|
B := mod.Shape()[0]
|
|
start := offset * dim
|
|
shift := mlx.Slice(mod, []int32{0, start}, []int32{B, start + dim})
|
|
scale := mlx.Slice(mod, []int32{0, start + dim}, []int32{B, start + 2*dim})
|
|
gate := mlx.Slice(mod, []int32{0, start + 2*dim}, []int32{B, start + 3*dim})
|
|
|
|
// Expand for broadcasting [B, dim] -> [B, 1, dim]
|
|
shift = mlx.ExpandDims(shift, 1)
|
|
scale = mlx.ExpandDims(scale, 1)
|
|
gate = mlx.ExpandDims(gate, 1)
|
|
|
|
return shift, scale, gate
|
|
}
|
|
|
|
// modulateLayerNorm applies LayerNorm then shift/scale modulation
|
|
// Diffusers uses LayerNorm(elementwise_affine=False) which centers the data
|
|
func modulateLayerNorm(x *mlx.Array, shift, scale *mlx.Array) *mlx.Array {
|
|
// Fast LayerNorm without learnable params
|
|
x = mlx.LayerNorm(x, 1e-6)
|
|
|
|
// Modulate: x * (1 + scale) + shift
|
|
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
|
|
return mlx.Add(x, shift)
|
|
}
|
|
|
|
// splitQKV splits a fused QKV tensor into Q, K, V
|
|
func splitQKV(qkv *mlx.Array, B, L, dim int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
|
|
q := mlx.Slice(qkv, []int32{0, 0, 0}, []int32{B, L, dim})
|
|
k := mlx.Slice(qkv, []int32{0, 0, dim}, []int32{B, L, 2 * dim})
|
|
v := mlx.Slice(qkv, []int32{0, 0, 2 * dim}, []int32{B, L, 3 * dim})
|
|
return q, k, v
|
|
}
|
|
|
|
// applyQKNorm applies RMSNorm with learned scale (no bias)
|
|
// Uses the optimized mlx_fast_rms_norm
|
|
func applyQKNorm(x *mlx.Array, scale *mlx.Array) *mlx.Array {
|
|
return mlx.RMSNorm(x, scale, 1e-6)
|
|
}
|