mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 16:54:13 +02:00
* prefer rocm v6 on windows Avoid building with v7 - more changes are needed * MLX: add header vendoring and remove go build tag This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified. * ci: harden for flaky choco repo servers CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without. * review comments
214 lines
6.1 KiB
Go
214 lines
6.1 KiB
Go
// Package vae provides shared utilities for VAE (Variational Autoencoder) operations.
|
|
package vae
|
|
|
|
import (
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// TilingConfig holds configuration for tiled VAE decoding.
// Tiling is a general technique to reduce memory usage when decoding large latents.
//
// Both fields are expressed in latent-space cells, not output pixels.
// DecodeTiled advances by TileSize-Overlap latent cells between neighbouring
// tiles, so Overlap is expected to be smaller than TileSize.
type TilingConfig struct {
	TileSize int32 // Tile size in latent space (e.g., 64 latent → 512 pixels for 8x VAE)
	Overlap  int32 // Overlap in latent space (e.g., 16 latent = 25% of 64)
}
|
|
|
|
// DefaultTilingConfig returns reasonable defaults matching diffusers.
|
|
// tile_latent_min_size=64, tile_overlap_factor=0.25
|
|
func DefaultTilingConfig() *TilingConfig {
|
|
return &TilingConfig{
|
|
TileSize: 64, // 64 latent pixels
|
|
Overlap: 16, // 25% overlap
|
|
}
|
|
}
|
|
|
|
// decodedTile holds a decoded tile's pixel data and dimensions.
// It lets the blend and assembly steps work on plain Go slices after the
// decoded array has been read back.
type decodedTile struct {
	data   []float32 // flattened pixels, addressed as (y*width + x)*3 + c by the blend/assembly code
	height int32     // tile height in pixels (dimension 1 of the decoded array's shape)
	width  int32     // tile width in pixels (dimension 2 of the decoded array's shape)
}
|
|
|
|
// DecodeTiled decodes latents using tiled processing with overlap blending.
|
|
// This reduces memory usage for large images by processing in overlapping tiles.
|
|
//
|
|
// Parameters:
|
|
// - latents: [1, H, W, C] latent tensor in NHWC format
|
|
// - cfg: tiling configuration (tile size and overlap)
|
|
// - decoder: function to decode a single tile [1, H, W, C] -> [1, H*scale, W*scale, 3]
|
|
//
|
|
// Returns: [1, 3, H*scale, W*scale] decoded image in NCHW format
|
|
func DecodeTiled(latents *mlx.Array, cfg *TilingConfig, decoder func(*mlx.Array) *mlx.Array) *mlx.Array {
|
|
shape := latents.Shape()
|
|
H := shape[1] // latent height
|
|
W := shape[2] // latent width
|
|
C := shape[3]
|
|
|
|
tileLatentSize := cfg.TileSize
|
|
overlapLatent := cfg.Overlap
|
|
|
|
// If image is small enough, just decode normally
|
|
if H <= tileLatentSize && W <= tileLatentSize {
|
|
decoded := decoder(latents)
|
|
decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
|
|
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
|
|
decoded = mlx.Transpose(decoded, 0, 3, 1, 2) // NHWC -> NCHW
|
|
return decoded
|
|
}
|
|
|
|
// Calculate tiling parameters (matching diffusers)
|
|
overlapSize := tileLatentSize - overlapLatent // stride in latent space
|
|
|
|
// Blend extent in pixel space (assumes 8x upscale, adjust if needed)
|
|
// For other scale factors, this could be made configurable
|
|
tileSampleSize := tileLatentSize * 8 // tile size in pixels after 8x upscale
|
|
blendExtent := overlapLatent * 8 // blend region in pixels
|
|
rowLimit := tileSampleSize - blendExtent // non-overlapping region per tile
|
|
|
|
// Phase 1: Decode all tiles and store in 2D grid
|
|
var rows [][]decodedTile
|
|
|
|
for i := int32(0); i < H; i += overlapSize {
|
|
var row []decodedTile
|
|
for j := int32(0); j < W; j += overlapSize {
|
|
// Extract tile (may be smaller at edges)
|
|
i2 := min(i+tileLatentSize, H)
|
|
j2 := min(j+tileLatentSize, W)
|
|
|
|
tile := mlx.Slice(latents, []int32{0, i, j, 0}, []int32{1, i2, j2, C})
|
|
decoded := decoder(tile)
|
|
decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
|
|
mlx.Eval(decoded)
|
|
|
|
decodedShape := decoded.Shape()
|
|
tileH := decodedShape[1]
|
|
tileW := decodedShape[2]
|
|
tileData := decoded.Data()
|
|
decoded.Free()
|
|
|
|
row = append(row, decodedTile{data: tileData, height: tileH, width: tileW})
|
|
}
|
|
rows = append(rows, row)
|
|
}
|
|
|
|
// Phase 2: Blend adjacent tiles (modifies in place)
|
|
for i := range rows {
|
|
for j := range rows[i] {
|
|
tile := &rows[i][j]
|
|
|
|
// Blend with tile above
|
|
if i > 0 {
|
|
above := &rows[i-1][j]
|
|
blendV(above, tile, blendExtent)
|
|
}
|
|
|
|
// Blend with tile to the left
|
|
if j > 0 {
|
|
left := &rows[i][j-1]
|
|
blendH(left, tile, blendExtent)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Phase 3: Calculate crop dimensions for each tile
|
|
colWidths := make([]int32, len(rows[0]))
|
|
for j := range rows[0] {
|
|
keepW := rowLimit
|
|
if int32(j+1)*overlapSize >= W {
|
|
keepW = rows[0][j].width
|
|
}
|
|
colWidths[j] = keepW
|
|
}
|
|
|
|
rowHeights := make([]int32, len(rows))
|
|
for i := range rows {
|
|
keepH := rowLimit
|
|
if int32(i+1)*overlapSize >= H {
|
|
keepH = rows[i][0].height
|
|
}
|
|
rowHeights[i] = keepH
|
|
}
|
|
|
|
// Calculate total dimensions
|
|
var totalW, totalH int32
|
|
for _, w := range colWidths {
|
|
totalW += w
|
|
}
|
|
for _, h := range rowHeights {
|
|
totalH += h
|
|
}
|
|
|
|
// Phase 4: Assemble final image by interleaving tiles row-by-row
|
|
finalData := make([]float32, totalH*totalW*3)
|
|
|
|
dstY := int32(0)
|
|
for i, row := range rows {
|
|
keepH := rowHeights[i]
|
|
|
|
for y := int32(0); y < keepH; y++ {
|
|
dstX := int32(0)
|
|
for j, tile := range row {
|
|
keepW := colWidths[j]
|
|
|
|
for x := int32(0); x < keepW; x++ {
|
|
for c := int32(0); c < 3; c++ {
|
|
srcIdx := (y*tile.width + x) * 3 + c
|
|
dstIdx := ((dstY + y) * totalW + (dstX + x)) * 3 + c
|
|
finalData[dstIdx] = tile.data[srcIdx]
|
|
}
|
|
}
|
|
dstX += keepW
|
|
}
|
|
}
|
|
dstY += keepH
|
|
}
|
|
|
|
// Create mlx array [1, H, W, 3] then transpose to NCHW [1, 3, H, W]
|
|
result := mlx.NewArray(finalData, []int32{1, totalH, totalW, 3})
|
|
result = mlx.Transpose(result, 0, 3, 1, 2)
|
|
result = mlx.ClipScalar(result, 0.0, 1.0, true, true)
|
|
|
|
return result
|
|
}
|
|
|
|
// blendV blends the bottom of 'above' tile into top of 'current' tile (vertical blend)
|
|
// Matches diffusers blend_v formula
|
|
func blendV(above, current *decodedTile, blendExtent int32) {
|
|
blend := min(blendExtent, min(above.height, current.height))
|
|
if blend <= 0 {
|
|
return
|
|
}
|
|
|
|
w := min(above.width, current.width)
|
|
for y := int32(0); y < blend; y++ {
|
|
alpha := float32(y) / float32(blend)
|
|
for x := int32(0); x < w; x++ {
|
|
for c := int32(0); c < 3; c++ {
|
|
aboveIdx := ((above.height - blend + y) * above.width + x) * 3 + c
|
|
currIdx := (y * current.width + x) * 3 + c
|
|
current.data[currIdx] = above.data[aboveIdx]*(1-alpha) + current.data[currIdx]*alpha
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// blendH blends the right of 'left' tile into left of 'current' tile (horizontal blend)
|
|
// Matches diffusers blend_h formula
|
|
func blendH(left, current *decodedTile, blendExtent int32) {
|
|
blend := min(blendExtent, min(left.width, current.width))
|
|
if blend <= 0 {
|
|
return
|
|
}
|
|
|
|
h := min(left.height, current.height)
|
|
for y := int32(0); y < h; y++ {
|
|
for x := int32(0); x < blend; x++ {
|
|
alpha := float32(x) / float32(blend)
|
|
for c := int32(0); c < 3; c++ {
|
|
leftIdx := (y * left.width + (left.width - blend + x)) * 3 + c
|
|
currIdx := (y * current.width + x) * 3 + c
|
|
current.data[currIdx] = left.data[leftIdx]*(1-alpha) + current.data[currIdx]*alpha
|
|
}
|
|
}
|
|
}
|
|
}
|