Files
ollama-ollama/x/imagegen/models/flux2/rope.go
Daniel Hiltgen 10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to vendoring the mlx-c headers so that Go can build without first
running cmake.  This enables building the new MLX-based code by default.  Every
time cmake runs, the headers are refreshed, so we can easily keep them in sync
when we bump MLX versions.  Basic Windows and Linux support has been verified.

* ci: harden for flaky choco repo servers

CI sometimes fails because choco does not actually install the cache tool.  Since the cache only speeds up the build, we can proceed without it.

* review comments
2026-03-09 17:24:45 -07:00

223 lines
7.8 KiB
Go

package flux2
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// RoPEConfig describes the four-axis rotary position embedding used by Flux2.
type RoPEConfig struct {
	// Theta is the RoPE base frequency (2000 for Klein).
	Theta int32
	// AxesDims is the rotary width of each position axis in (T, H, W, L)
	// order, e.g. [32, 32, 32, 32]. ComputeRoPE uses their sum as head_dim.
	AxesDims []int32
}
// RoPECache holds precomputed RoPE cos/sin tables for a joint token sequence
// laid out as [text tokens | image tokens], as built by PrepareRoPECache.
type RoPECache struct {
	// Cos and Sin are [1, TotalSeqLen, 1, head_dim], where
	// TotalSeqLen = TextLen + ImageLen. Each frequency is already repeated for
	// both elements of its (real, imag) pair, so the tables broadcast directly
	// against [B, L, nheads, head_dim] inputs in ApplyRoPE4D.
	// PrepareRoPECache stores them as bfloat16.
	Cos *mlx.Array
	Sin *mlx.Array
	// TextLen is the number of text tokens at the start of the sequence.
	TextLen int32
	// ImageLen is the number of image tokens following the text: the noise
	// latent tokens plus any reference image tokens.
	ImageLen int32
}
// PrepareTextIDs returns the (T, H, W, L) position IDs for seqLen text tokens.
//
// Only the L coordinate varies. Row i is (0, 0, 0, i), so text tokens share
// the T/H/W origin and are ordered purely by sequence position.
//
// Returns: [seqLen, 4]
func PrepareTextIDs(seqLen int32) *mlx.Array {
	const axes = 4
	// make zero-fills, so only the L column needs to be written.
	coords := make([]float32, seqLen*axes)
	for pos := int32(0); pos < seqLen; pos++ {
		coords[pos*axes+3] = float32(pos)
	}
	return mlx.NewArray(coords, []int32{seqLen, axes})
}
// PrepareLatentIDs returns the (T, H, W, L) position IDs for a height x width
// grid of image latent tokens.
//
// Tokens are enumerated in row-major order (all columns of row 0, then row 1,
// ...). Each token gets (0, row, col, 0).
//
// Returns: [height*width, 4]
func PrepareLatentIDs(height, width int32) *mlx.Array {
	count := height * width
	coords := make([]float32, count*4)
	cursor := 0
	for row := int32(0); row < height; row++ {
		for col := int32(0); col < width; col++ {
			// T (cursor+0) and L (cursor+3) stay at their zero value.
			coords[cursor+1] = float32(row)
			coords[cursor+2] = float32(col)
			cursor += 4
		}
	}
	return mlx.NewArray(coords, []int32{count, 4})
}
// PrepareImageIDs returns the (T, H, W, L) position IDs for reference image
// tokens (used in editing).
//
// Image k (0-based) of size imageHeights[k] x imageWidths[k] contributes one
// row per token, in row-major order, with coordinates
// (scale*(k+1), row, col, 0). The T offset gives each reference image its own
// time slot. With scale > 0 this also keeps them apart from the noise latent,
// which PrepareLatentIDs places at T = 0.
//
// imageWidths must have at least len(imageHeights) entries.
//
// Returns: [sum_k imageHeights[k]*imageWidths[k], 4]
func PrepareImageIDs(imageHeights, imageWidths []int32, scale int32) *mlx.Array {
	var total int32
	for k, h := range imageHeights {
		total += h * imageWidths[k]
	}

	coords := make([]float32, total*4)
	next := int32(0)
	for k, h := range imageHeights {
		w := imageWidths[k]
		slot := float32(scale * int32(k+1))
		for row := int32(0); row < h; row++ {
			for col := int32(0); col < w; col++ {
				coords[next] = slot
				coords[next+1] = float32(row)
				coords[next+2] = float32(col)
				// L (next+3) stays at its zero value.
				next += 4
			}
		}
	}
	return mlx.NewArray(coords, []int32{total, 4})
}
// ComputeRoPE computes the cos and sin tables for multi-axis rotary position
// embeddings.
//
// ids has shape [L, K]: one row of coordinates per token and one column per
// axis. For Flux2 the columns are (T, H, W, L), so K == 4 (see
// PrepareTextIDs, PrepareLatentIDs and PrepareImageIDs). axesDims gives the
// rotary width of each axis (e.g. [32, 32, 32, 32]). It must have exactly K
// entries, each one even, and head_dim is their sum. theta is the base
// frequency (2000 for Klein) and must be positive.
//
// Axis i with width d uses inverse frequencies theta^(-2j/d) for
// j = 0..d/2-1. This matches diffusers'
// 1.0 / (theta ** (torch.arange(0, dim, 2) / dim)). Each frequency's cos/sin
// is repeat-interleaved ([c0, c0, c1, c1, ...]) so it lines up with the
// (real, imag) element pairs consumed by ApplyRoPE4D.
//
// Returns cos and sin, each [1, L, 1, head_dim], for broadcasting against
// [B, L, nheads, head_dim] attention inputs.
//
// Panics if the shapes or theta violate the requirements above; these are
// configuration bugs, not runtime conditions.
func ComputeRoPE(ids *mlx.Array, axesDims []int32, theta int32) (*mlx.Array, *mlx.Array) {
	shape := ids.Shape()
	numAxes := int32(len(axesDims))
	if numAxes == 0 || len(shape) != 2 || shape[1] != numAxes {
		panic("flux2: ComputeRoPE: ids must be [L, len(axesDims)] with at least one axis")
	}
	if theta <= 0 {
		panic("flux2: ComputeRoPE: theta must be positive")
	}
	seqLen := shape[0]

	// Compute log(theta) and the exponents in float64 and round once to
	// float32, instead of accumulating float32 rounding in both terms.
	logTheta := math.Log(float64(theta))

	// repeatInterleave2 maps [L, half] -> [L, width] as [c0, c0, c1, c1, ...],
	// where width == 2*half.
	repeatInterleave2 := func(a *mlx.Array, width int32) *mlx.Array {
		a = mlx.ExpandDims(a, 2)             // [L, half, 1]
		a = mlx.Tile(a, []int32{1, 1, 2})    // [L, half, 2]
		return mlx.Reshape(a, seqLen, width) // [L, width]
	}

	headDim := int32(0)
	cosArrs := make([]*mlx.Array, 0, numAxes)
	sinArrs := make([]*mlx.Array, 0, numAxes)
	for i, axisDim := range axesDims {
		if axisDim <= 0 || axisDim%2 != 0 {
			panic("flux2: ComputeRoPE: each axis dimension must be positive and even")
		}
		headDim += axisDim
		half := axisDim / 2

		freqs := make([]float32, half)
		for j := range freqs {
			freqs[j] = float32(math.Exp(-logTheta * float64(2*j) / float64(axisDim)))
		}
		freqArr := mlx.NewArray(freqs, []int32{1, half})

		// Column i of ids holds this axis' coordinate for every token: [L, 1].
		col := int32(i)
		pos := mlx.Slice(ids, []int32{0, col}, []int32{seqLen, col + 1})
		args := mlx.Mul(pos, freqArr) // [L, half]

		cosArrs = append(cosArrs, repeatInterleave2(mlx.Cos(args), axisDim))
		sinArrs = append(sinArrs, repeatInterleave2(mlx.Sin(args), axisDim))
	}

	// Concatenate all axes -> [L, headDim], then add broadcast dims for
	// batch and heads -> [1, L, 1, headDim].
	cos := mlx.Reshape(mlx.Concatenate(cosArrs, 1), 1, seqLen, 1, headDim)
	sin := mlx.Reshape(mlx.Concatenate(sinArrs, 1), 1, seqLen, 1, headDim)
	return cos, sin
}
// ApplyRoPE4D applies precomputed rotary position embeddings to x (queries or
// keys).
//
// x is [B, L, nheads, head_dim]. cos and sin are [1, L, 1, head_dim], with each
// frequency repeated for both elements of a pair (as produced by ComputeRoPE).
// Adjacent elements (x[2k], x[2k+1]) are treated as the real and imaginary
// parts of one complex value, and the result is
//
//	x*cos + rotate(x)*sin,  where rotate(x) = [-x1, x0, -x3, x2, ...]
//
// This matches diffusers' apply_rotary_emb with use_real=True and
// use_real_unbind_dim=-1.
func ApplyRoPE4D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
	dims := x.Shape()
	batch, seq, heads, width := dims[0], dims[1], dims[2], dims[3]
	pairs := width / 2

	// View the last axis as (pairs, 2): index 0 is the real part, index 1 the
	// imaginary part. The trailing size-1 axis is kept so the two halves can be
	// re-joined directly.
	paired := mlx.Reshape(x, batch, seq, heads, pairs, 2)
	re := mlx.Slice(paired, []int32{0, 0, 0, 0, 0}, []int32{batch, seq, heads, pairs, 1}) // [B, L, H, pairs, 1]
	im := mlx.Slice(paired, []int32{0, 0, 0, 0, 1}, []int32{batch, seq, heads, pairs, 2}) // [B, L, H, pairs, 1]

	// Interleave (-im, re) back into [B, L, H, head_dim].
	rotated := mlx.Concatenate([]*mlx.Array{mlx.Neg(im), re}, 4)
	rotated = mlx.Reshape(rotated, batch, seq, heads, width)

	return mlx.Add(mlx.Mul(x, cos), mlx.Mul(rotated, sin))
}
// PrepareRoPECache builds the RoPE cos/sin tables for a joint sequence laid out
// as [text tokens | noise latent tokens | reference image tokens].
//
//   - textLen: number of text tokens
//   - noiseH, noiseW: noise latent size in patch tokens
//   - axesDims: per-axis rotary widths, e.g. [32, 32, 32, 32]
//   - theta: RoPE base frequency, e.g. 2000
//   - refHeights, refWidths: reference image sizes in patch tokens (nil or
//     empty for no reference images)
//   - scale: T-coordinate spacing between reference images (e.g. 10)
//
// The tables are converted to bfloat16. ImageLen counts the noise tokens plus
// all reference image tokens.
func PrepareRoPECache(textLen, noiseH, noiseW int32, axesDims []int32, theta int32, refHeights, refWidths []int32, scale int32) *RoPECache {
	segments := []*mlx.Array{
		PrepareTextIDs(textLen),
		PrepareLatentIDs(noiseH, noiseW),
	}
	imageLen := noiseH * noiseW

	if len(refHeights) != 0 {
		segments = append(segments, PrepareImageIDs(refHeights, refWidths, scale))
		for i, h := range refHeights {
			imageLen += h * refWidths[i]
		}
	}

	cos, sin := ComputeRoPE(mlx.Concatenate(segments, 0), axesDims, theta)
	return &RoPECache{
		Cos:      mlx.ToBFloat16(cos),
		Sin:      mlx.ToBFloat16(sin),
		TextLen:  textLen,
		ImageLen: imageLen,
	}
}