mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* prefer rocm v6 on windows: avoid building with v7, since more changes are needed.
* MLX: add header vendoring and remove go build tag. This switches to a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake run first. This enables building the new MLX-based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support is verified.
* ci: harden for flaky choco repo servers. CI sometimes fails because choco does not actually install cache. Since it only speeds up the build, we can proceed without it.
* review comments
223 lines
7.8 KiB
Go
223 lines
7.8 KiB
Go
package flux2
|
|
|
|
import (
|
|
"math"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// RoPEConfig describes the 4D rotary position embedding (RoPE) layout used by Flux2.
type RoPEConfig struct {
	// Theta is the RoPE base frequency (2000 for Klein).
	Theta int32

	// AxesDims lists how many head dimensions are assigned to each of the
	// T, H, W and L axes, e.g. [32, 32, 32, 32].
	AxesDims []int32
}
|
|
|
|
// RoPECache holds precomputed RoPE cos/sin values
|
|
type RoPECache struct {
|
|
Cos *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
|
|
Sin *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
|
|
TextLen int32 // Length of text sequence
|
|
ImageLen int32 // Length of image sequence
|
|
}
|
|
|
|
// PrepareTextIDs creates position IDs for text tokens.
|
|
// Text tokens use: T=0, H=0, W=0, L=0..seqLen-1
|
|
// Returns: [seqLen, 4]
|
|
func PrepareTextIDs(seqLen int32) *mlx.Array {
|
|
ids := make([]float32, seqLen*4)
|
|
for i := int32(0); i < seqLen; i++ {
|
|
idx := i * 4
|
|
ids[idx+0] = 0 // T = 0
|
|
ids[idx+1] = 0 // H = 0
|
|
ids[idx+2] = 0 // W = 0
|
|
ids[idx+3] = float32(i) // L = sequence position
|
|
}
|
|
return mlx.NewArray(ids, []int32{seqLen, 4})
|
|
}
|
|
|
|
// PrepareLatentIDs creates position IDs for image latent tokens.
|
|
// Latent tokens use: T=0, H=0..height-1, W=0..width-1, L=0
|
|
// The latents are in row-major order (H then W).
|
|
// Returns: [height*width, 4]
|
|
func PrepareLatentIDs(height, width int32) *mlx.Array {
|
|
seqLen := height * width
|
|
ids := make([]float32, seqLen*4)
|
|
idx := 0
|
|
for h := int32(0); h < height; h++ {
|
|
for w := int32(0); w < width; w++ {
|
|
ids[idx*4+0] = 0 // T = 0
|
|
ids[idx*4+1] = float32(h) // H = row
|
|
ids[idx*4+2] = float32(w) // W = column
|
|
ids[idx*4+3] = 0 // L = 0
|
|
idx++
|
|
}
|
|
}
|
|
return mlx.NewArray(ids, []int32{seqLen, 4})
|
|
}
|
|
|
|
// PrepareImageIDs creates position IDs for reference image tokens (used in editing).
|
|
// Reference images use: T=scale*(i+1), H=0..h-1, W=0..w-1, L=0
|
|
// where i is the image index (0, 1, 2, ...) and scale separates images in T dimension.
|
|
// Returns: [total_tokens, 4]
|
|
func PrepareImageIDs(imageHeights, imageWidths []int32, scale int32) *mlx.Array {
|
|
// Calculate total tokens
|
|
totalTokens := int32(0)
|
|
for i := range imageHeights {
|
|
totalTokens += imageHeights[i] * imageWidths[i]
|
|
}
|
|
|
|
ids := make([]float32, totalTokens*4)
|
|
idx := int32(0)
|
|
for imgIdx, h := range imageHeights {
|
|
w := imageWidths[imgIdx]
|
|
tValue := float32(scale * int32(imgIdx+1))
|
|
for hi := int32(0); hi < h; hi++ {
|
|
for wi := int32(0); wi < w; wi++ {
|
|
ids[idx*4+0] = tValue // T = scale * (imgIdx + 1)
|
|
ids[idx*4+1] = float32(hi) // H = row
|
|
ids[idx*4+2] = float32(wi) // W = column
|
|
ids[idx*4+3] = 0 // L = 0
|
|
idx++
|
|
}
|
|
}
|
|
}
|
|
return mlx.NewArray(ids, []int32{totalTokens, 4})
|
|
}
|
|
|
|
// ComputeRoPE computes cos and sin for 4D rotary position embeddings.
|
|
// ids: [L, 4] with (T, H, W, L) coordinates
|
|
// axesDims: [32, 32, 32, 32] - each axis has this many dimensions (total = head_dim = 128)
|
|
// theta: base frequency (2000 for Klein)
|
|
// Returns: cos, sin each [1, L, 1, head_dim] with repeat_interleave applied
|
|
func ComputeRoPE(ids *mlx.Array, axesDims []int32, theta int32) (*mlx.Array, *mlx.Array) {
|
|
shape := ids.Shape()
|
|
seqLen := shape[0]
|
|
|
|
// Compute total head dim (sum of all axes dims)
|
|
headDim := int32(0)
|
|
for _, d := range axesDims {
|
|
headDim += d
|
|
}
|
|
|
|
// Extract each coordinate dimension
|
|
// ids[:, 0] = T, ids[:, 1] = H, ids[:, 2] = W, ids[:, 3] = L
|
|
posT := mlx.Slice(ids, []int32{0, 0}, []int32{seqLen, 1}) // [L, 1]
|
|
posH := mlx.Slice(ids, []int32{0, 1}, []int32{seqLen, 2}) // [L, 1]
|
|
posW := mlx.Slice(ids, []int32{0, 2}, []int32{seqLen, 3}) // [L, 1]
|
|
posL := mlx.Slice(ids, []int32{0, 3}, []int32{seqLen, 4}) // [L, 1]
|
|
|
|
// Compute frequencies for each axis
|
|
logTheta := float32(math.Log(float64(theta)))
|
|
cosArrs := make([]*mlx.Array, 4)
|
|
sinArrs := make([]*mlx.Array, 4)
|
|
positions := []*mlx.Array{posT, posH, posW, posL}
|
|
|
|
for i, axisDim := range axesDims {
|
|
half := axisDim / 2
|
|
|
|
// Create frequency array for this axis: theta^(-2j/dim) for j=0..half-1
|
|
// This matches diffusers: 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
|
|
freqs := make([]float32, half)
|
|
for j := int32(0); j < half; j++ {
|
|
freqs[j] = float32(math.Exp(float64(-logTheta * float32(2*j) / float32(axisDim))))
|
|
}
|
|
freqArr := mlx.NewArray(freqs, []int32{1, half})
|
|
|
|
// Compute pos * freq -> [L, half]
|
|
posExpanded := positions[i] // [L, 1]
|
|
args := mlx.Mul(posExpanded, freqArr) // [L, half]
|
|
|
|
// Compute cos and sin for this axis
|
|
cosAxis := mlx.Cos(args) // [L, half]
|
|
sinAxis := mlx.Sin(args) // [L, half]
|
|
|
|
// repeat_interleave(2): [c0, c1, ...] -> [c0, c0, c1, c1, ...]
|
|
// Reshape [L, half] -> [L, half, 1], tile to [L, half, 2], reshape to [L, axisDim]
|
|
cosAxis = mlx.ExpandDims(cosAxis, 2) // [L, half, 1]
|
|
cosAxis = mlx.Tile(cosAxis, []int32{1, 1, 2}) // [L, half, 2]
|
|
cosAxis = mlx.Reshape(cosAxis, seqLen, axisDim) // [L, axisDim]
|
|
|
|
sinAxis = mlx.ExpandDims(sinAxis, 2)
|
|
sinAxis = mlx.Tile(sinAxis, []int32{1, 1, 2})
|
|
sinAxis = mlx.Reshape(sinAxis, seqLen, axisDim)
|
|
|
|
cosArrs[i] = cosAxis
|
|
sinArrs[i] = sinAxis
|
|
}
|
|
|
|
// Concatenate all axes: [L, headDim]
|
|
cos := mlx.Concatenate(cosArrs, 1)
|
|
sin := mlx.Concatenate(sinArrs, 1)
|
|
|
|
// Reshape to [1, L, 1, headDim] for broadcasting with attention
|
|
cos = mlx.Reshape(cos, 1, seqLen, 1, headDim)
|
|
sin = mlx.Reshape(sin, 1, seqLen, 1, headDim)
|
|
|
|
return cos, sin
|
|
}
|
|
|
|
// ApplyRoPE4D applies 4D rotary position embeddings to queries and keys.
|
|
// x: [B, L, nheads, head_dim]
|
|
// cos, sin: [1, L, 1, head_dim] (with repeat_interleave applied)
|
|
// Returns: x with RoPE applied
|
|
// Matches diffusers apply_rotary_emb with use_real=True, use_real_unbind_dim=-1
|
|
func ApplyRoPE4D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
|
shape := x.Shape()
|
|
B := shape[0]
|
|
L := shape[1]
|
|
nheads := shape[2]
|
|
headDim := shape[3]
|
|
half := headDim / 2
|
|
|
|
// Reshape x to [B, L, nheads, half, 2] and split into real/imag
|
|
xReshaped := mlx.Reshape(x, B, L, nheads, half, 2)
|
|
|
|
// Extract real (index 0) and imag (index 1) parts
|
|
xReal := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, half, 1})
|
|
xImag := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, half, 2})
|
|
xReal = mlx.Squeeze(xReal, 4) // [B, L, nheads, half]
|
|
xImag = mlx.Squeeze(xImag, 4) // [B, L, nheads, half]
|
|
|
|
// x_rotated = stack([-x_imag, x_real], dim=-1).flatten(-2)
|
|
// This creates [-x_imag[0], x_real[0], -x_imag[1], x_real[1], ...]
|
|
negXImag := mlx.Neg(xImag)
|
|
negXImag = mlx.ExpandDims(negXImag, 4) // [B, L, nheads, half, 1]
|
|
xReal = mlx.ExpandDims(xReal, 4) // [B, L, nheads, half, 1]
|
|
xRotated := mlx.Concatenate([]*mlx.Array{negXImag, xReal}, 4) // [B, L, nheads, half, 2]
|
|
xRotated = mlx.Reshape(xRotated, B, L, nheads, headDim) // [B, L, nheads, headDim]
|
|
|
|
// out = x * cos + x_rotated * sin
|
|
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(xRotated, sin))
|
|
}
|
|
|
|
// PrepareRoPECache creates RoPE cache for text + noise, optionally with reference images.
|
|
// textLen: number of text tokens
|
|
// noiseH, noiseW: dimensions of the noise latent in patch tokens
|
|
// axesDims: [32, 32, 32, 32]
|
|
// theta: 2000
|
|
// refHeights, refWidths: optional reference image dimensions (pass nil/empty for no images)
|
|
// scale: time coordinate offset between reference images (e.g., 10)
|
|
func PrepareRoPECache(textLen, noiseH, noiseW int32, axesDims []int32, theta int32, refHeights, refWidths []int32, scale int32) *RoPECache {
|
|
textIDs := PrepareTextIDs(textLen)
|
|
noiseIDs := PrepareLatentIDs(noiseH, noiseW)
|
|
|
|
var allIDs *mlx.Array
|
|
imageLen := noiseH * noiseW
|
|
|
|
if len(refHeights) > 0 {
|
|
refIDs := PrepareImageIDs(refHeights, refWidths, scale)
|
|
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs, refIDs}, 0)
|
|
for i := range refHeights {
|
|
imageLen += refHeights[i] * refWidths[i]
|
|
}
|
|
} else {
|
|
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs}, 0)
|
|
}
|
|
|
|
cos, sin := ComputeRoPE(allIDs, axesDims, theta)
|
|
cos = mlx.ToBFloat16(cos)
|
|
sin = mlx.ToBFloat16(sin)
|
|
|
|
return &RoPECache{Cos: cos, Sin: sin, TextLen: textLen, ImageLen: imageLen}
|
|
}
|