//go:build mlx

package flux2

import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen/manifest"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
	"github.com/ollama/ollama/x/imagegen/vae"
)

// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
	ActFn             string  `json:"act_fn"`                  // "silu"
	BatchNormEps      float32 `json:"batch_norm_eps"`          // 0.0001
	BatchNormMomentum float32 `json:"batch_norm_momentum"`     // 0.1
	BlockOutChannels  []int32 `json:"block_out_channels"`      // [128, 256, 512, 512]
	ForceUpcast       bool    `json:"force_upcast"`            // true
	InChannels        int32   `json:"in_channels"`             // 3
	LatentChannels    int32   `json:"latent_channels"`         // 32
	LayersPerBlock    int32   `json:"layers_per_block"`        // 2
	MidBlockAddAttn   bool    `json:"mid_block_add_attention"` // true
	NormNumGroups     int32   `json:"norm_num_groups"`         // 32
	OutChannels       int32   `json:"out_channels"`            // 3
	PatchSize         []int32 `json:"patch_size"`              // [2, 2]
	SampleSize        int32   `json:"sample_size"`             // 1024
	UsePostQuantConv  bool    `json:"use_post_quant_conv"`     // true
	UseQuantConv      bool    `json:"use_quant_conv"`          // true
}
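
// For reference, an illustrative vae/config.json matching the defaults noted in the
// field comments above might look like the following (the blob shipped with an actual
// model may differ):
//
//	{
//	  "act_fn": "silu",
//	  "batch_norm_eps": 0.0001,
//	  "batch_norm_momentum": 0.1,
//	  "block_out_channels": [128, 256, 512, 512],
//	  "force_upcast": true,
//	  "in_channels": 3,
//	  "latent_channels": 32,
//	  "layers_per_block": 2,
//	  "mid_block_add_attention": true,
//	  "norm_num_groups": 32,
//	  "out_channels": 3,
//	  "patch_size": [2, 2],
//	  "sample_size": 1024,
//	  "use_post_quant_conv": true,
//	  "use_quant_conv": true
//	}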

// BatchNorm2D implements 2D batch normalization with running statistics
type BatchNorm2D struct {
	RunningMean *mlx.Array // [C]
	RunningVar  *mlx.Array // [C]
	Weight      *mlx.Array // [C] gamma
	Bias        *mlx.Array // [C] beta
	Eps         float32
	Momentum    float32
}

// Forward applies batch normalization (inference mode - uses running stats)
// Input and output are in NHWC format [B, H, W, C]
func (bn *BatchNorm2D) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	C := shape[3]

	// Reshape stats for broadcasting [1, 1, 1, C]
	mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
	variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)

	// Normalize: (x - mean) / sqrt(var + eps)
	xNorm := mlx.Sub(x, mean)
	xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))

	// Scale and shift (only if affine=True)
	if bn.Weight != nil {
		weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
		xNorm = mlx.Mul(xNorm, weight)
	}
	if bn.Bias != nil {
		bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
		xNorm = mlx.Add(xNorm, bias)
	}

	return xNorm
}

// Denormalize inverts the batch normalization
// Used when decoding latents
func (bn *BatchNorm2D) Denormalize(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	C := shape[3]

	// Reshape stats for broadcasting [1, 1, 1, C]
	mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
	variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)

	// Inverse: first undo affine, then undo normalization
	// For affine=False: x_denorm = x * sqrt(var + eps) + mean
	if bn.Bias != nil {
		bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
		x = mlx.Sub(x, bias)
	}
	if bn.Weight != nil {
		weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
		x = mlx.Div(x, weight)
	}
	x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
	x = mlx.Add(x, mean)

	return x
}
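
// The two methods above are exact inverses (up to floating-point error). Forward computes
// x_norm = (x - mean) / sqrt(var + eps), then the optional affine gamma*x_norm + beta;
// Denormalize undoes the affine first and then computes x = x_norm * sqrt(var + eps) + mean.
// A quick sanity check with affine disabled: mean = 0.5, var = 0.25, eps = 0 gives
// sqrt(var + eps) = 0.5, so x = 1.0 normalizes to (1.0-0.5)/0.5 = 1.0 and denormalizes
// back to 1.0*0.5 + 0.5 = 1.0.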

// GroupNormLayer implements group normalization
// Reused from zimage package pattern
type GroupNormLayer struct {
	Weight    *mlx.Array `weight:"weight"`
	Bias      *mlx.Array `weight:"bias"`
	NumGroups int32
	Eps       float32
}

// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	// Reshape to [B, H, W, groups, C/groups]
	groupSize := C / gn.NumGroups
	x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)

	// Compute mean and variance per group
	mean := mlx.Mean(x, 1, true)
	mean = mlx.Mean(mean, 2, true)
	mean = mlx.Mean(mean, 4, true)

	xCentered := mlx.Sub(x, mean)

	sq := mlx.Square(xCentered)
	variance := mlx.Mean(sq, 1, true)
	variance = mlx.Mean(variance, 2, true)
	variance = mlx.Mean(variance, 4, true)

	// Normalize
	xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))

	// Reshape back to [B, H, W, C]
	xNorm = mlx.Reshape(xNorm, B, H, W, C)

	// Scale and shift
	if gn.Weight != nil {
		weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
		xNorm = mlx.Mul(xNorm, weight)
	}
	if gn.Bias != nil {
		bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
		xNorm = mlx.Add(xNorm, bias)
	}

	return xNorm
}

// Conv2D represents a 2D convolution layer (reused pattern)
type Conv2D struct {
	Weight  *mlx.Array `weight:"weight"`
	Bias    *mlx.Array `weight:"bias,optional"`
	Stride  int32
	Padding int32
}

// Transform implements safetensors.Transformer to transpose weights from PyTorch's OIHW to MLX's OHWI.
func (conv *Conv2D) Transform(field string, arr *mlx.Array) *mlx.Array {
	if field == "Weight" {
		return mlx.Transpose(arr, 0, 2, 3, 1)
	}
	return arr
}
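
// For example, a 3x3 convolution mapping 128 -> 256 channels is stored by PyTorch as
// [256, 128, 3, 3] (OIHW); the Transpose(0, 2, 3, 1) above rearranges it to
// [256, 3, 3, 128] (OHWI), the weight layout used with the NHWC activations in Forward.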

// Forward applies convolution (NHWC format)
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
	out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)

	if conv.Bias != nil {
		bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
		out = mlx.Add(out, bias)
	}

	return out
}

// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
	Norm1        *GroupNormLayer `weight:"norm1"`
	Conv1        *Conv2D         `weight:"conv1"`
	Norm2        *GroupNormLayer `weight:"norm2"`
	Conv2        *Conv2D         `weight:"conv2"`
	ConvShortcut *Conv2D         `weight:"conv_shortcut,optional"`
}

// Forward applies the ResNet block
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
	h := rb.Norm1.Forward(x)
	h = mlx.SiLU(h)
	h = rb.Conv1.Forward(h)

	h = rb.Norm2.Forward(h)
	h = mlx.SiLU(h)
	h = rb.Conv2.Forward(h)

	if rb.ConvShortcut != nil {
		x = rb.ConvShortcut.Forward(x)
	}

	return mlx.Add(h, x)
}

// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
	GroupNorm *GroupNormLayer `weight:"group_norm"`
	ToQ       nn.LinearLayer  `weight:"to_q"`
	ToK       nn.LinearLayer  `weight:"to_k"`
	ToV       nn.LinearLayer  `weight:"to_v"`
	ToOut     nn.LinearLayer  `weight:"to_out.0"`
}

// Forward applies attention (NHWC format)
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
	residual := x
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	h := ab.GroupNorm.Forward(x)
	h = mlx.Reshape(h, B, H*W, C)

	q := ab.ToQ.Forward(h)
	k := ab.ToK.Forward(h)
	v := ab.ToV.Forward(h)

	q = mlx.ExpandDims(q, 1)
	k = mlx.ExpandDims(k, 1)
	v = mlx.ExpandDims(v, 1)

	scale := float32(1.0 / math.Sqrt(float64(C)))
	out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
	out = mlx.Squeeze(out, 1)

	out = ab.ToOut.Forward(out)
	out = mlx.Reshape(out, B, H, W, C)
	out = mlx.Add(out, residual)

	return out
}

// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
	ResnetBlocks []*ResnetBlock2D
	Upsample     *Conv2D
}

// Forward applies the up decoder block
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
	for _, resnet := range ub.ResnetBlocks {
		x = resnet.Forward(x)
	}

	if ub.Upsample != nil {
		x = upsample2x(x)
		x = ub.Upsample.Forward(x)
	}

	return x
}

// upsample2x performs 2x nearest neighbor upsampling
func upsample2x(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	H := shape[1]
	W := shape[2]

	hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
	hIdx = mlx.Reshape(hIdx, H, 1)
	hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
	hIdx = mlx.Reshape(hIdx, H*2)

	wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
	wIdx = mlx.Reshape(wIdx, W, 1)
	wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
	wIdx = mlx.Reshape(wIdx, W*2)

	x = mlx.Take(x, hIdx, 1)
	x = mlx.Take(x, wIdx, 2)

	return x
}
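
// As a concrete example of the index construction above: for H = 3 the row indices are
// arange -> [0, 1, 2], reshaped to [3, 1], broadcast to [[0, 0], [1, 1], [2, 2]], and
// flattened to [0, 0, 1, 1, 2, 2]. Take along axis 1 then repeats every row once, and the
// same happens for the columns along axis 2, giving 2x nearest-neighbor upsampling without
// a dedicated resize op.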

// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
	Resnet1   *ResnetBlock2D
	Attention *VAEAttentionBlock
	Resnet2   *ResnetBlock2D
}

// Forward applies the mid block
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
	x = mb.Resnet1.Forward(x)
	x = mb.Attention.Forward(x)
	x = mb.Resnet2.Forward(x)
	return x
}

// DefaultTilingConfig returns reasonable defaults for tiled decoding
// Matches diffusers: tile_latent_min_size=64, tile_overlap_factor=0.25
func DefaultTilingConfig() *vae.TilingConfig {
	return vae.DefaultTilingConfig()
}

// AutoencoderKLFlux2 is the Flux2 VAE with BatchNorm
type AutoencoderKLFlux2 struct {
	Config *VAEConfig

	// Encoder components (for image editing)
	EncoderConvIn  *Conv2D
	EncoderMid     *VAEMidBlock
	EncoderDown    []*DownEncoderBlock2D
	EncoderNormOut *GroupNormLayer
	EncoderConvOut *Conv2D

	// Decoder components
	DecoderConvIn  *Conv2D
	DecoderMid     *VAEMidBlock
	DecoderUp      []*UpDecoderBlock2D
	DecoderNormOut *GroupNormLayer
	DecoderConvOut *Conv2D

	// Quant conv layers
	QuantConv     *Conv2D
	PostQuantConv *Conv2D

	// BatchNorm for latent normalization
	LatentBN *BatchNorm2D

	// Tiling configuration (nil = no tiling)
	Tiling *vae.TilingConfig
}

// DownEncoderBlock2D implements a downsampling encoder block
type DownEncoderBlock2D struct {
	ResnetBlocks []*ResnetBlock2D
	Downsample   *Conv2D
}

// Forward applies the down encoder block
func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
	for _, resnet := range db.ResnetBlocks {
		x = resnet.Forward(x)
	}

	if db.Downsample != nil {
		// Pad then conv with stride 2
		x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0})
		x = db.Downsample.Forward(x)
	}

	return x
}
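
// Note on the padding above: the []int32{0, 0, 0, 1, 0, 1, 0, 0} argument is read here as
// (low, high) pairs per NHWC axis, i.e. pad one row at the bottom and one column on the
// right only. That matches the asymmetric padding diffusers applies before its stride-2
// downsample convolutions (which are created with padding=0), so H and W are halved exactly.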

// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(modelManifest *manifest.ModelManifest) error {
	fmt.Print(" Loading VAE... ")

	// Load config from blob
	var cfg VAEConfig
	if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = &cfg

	// Load weights from tensor blobs
	weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(0); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	return m.loadWeights(weights, &cfg)
}

// loadWeights loads VAE weights from any WeightSource
func (m *AutoencoderKLFlux2) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
	var err error

	// Load encoder components (for image conditioning)
	if err := m.loadEncoderWeights(weights, cfg); err != nil {
		return fmt.Errorf("encoder: %w", err)
	}

	// Load decoder conv_in
	m.DecoderConvIn = &Conv2D{Stride: 1, Padding: 1}
	if err := safetensors.LoadModule(m.DecoderConvIn, weights, "decoder.conv_in"); err != nil {
		return fmt.Errorf("decoder.conv_in: %w", err)
	}

	// Load mid block
	m.DecoderMid, err = loadVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
	if err != nil {
		return fmt.Errorf("decoder.mid_block: %w", err)
	}

	// Load up blocks
	numBlocks := len(cfg.BlockOutChannels)
	m.DecoderUp = make([]*UpDecoderBlock2D, numBlocks)
	for i := 0; i < numBlocks; i++ {
		prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
		hasUpsample := i < numBlocks-1
		m.DecoderUp[i], err = loadUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
		if err != nil {
			return fmt.Errorf("%s: %w", prefix, err)
		}
	}

	// Load decoder conv_norm_out and conv_out
	m.DecoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
	if err := safetensors.LoadModule(m.DecoderNormOut, weights, "decoder.conv_norm_out"); err != nil {
		return fmt.Errorf("decoder.conv_norm_out: %w", err)
	}

	m.DecoderConvOut = &Conv2D{Stride: 1, Padding: 1}
	if err := safetensors.LoadModule(m.DecoderConvOut, weights, "decoder.conv_out"); err != nil {
		return fmt.Errorf("decoder.conv_out: %w", err)
	}

	// Load post_quant_conv
	if cfg.UsePostQuantConv {
		m.PostQuantConv = &Conv2D{Stride: 1, Padding: 0}
		if err := safetensors.LoadModule(m.PostQuantConv, weights, "post_quant_conv"); err != nil {
			return fmt.Errorf("post_quant_conv: %w", err)
		}
	}

	// Load latent BatchNorm (affine=False, so no weight/bias)
	bnMean, err := weights.GetTensor("bn.running_mean")
	if err != nil {
		return fmt.Errorf("bn.running_mean: %w", err)
	}
	bnVar, err := weights.GetTensor("bn.running_var")
	if err != nil {
		return fmt.Errorf("bn.running_var: %w", err)
	}
	m.LatentBN = &BatchNorm2D{
		RunningMean: bnMean,
		RunningVar:  bnVar,
		Weight:      nil, // affine=False
		Bias:        nil, // affine=False
		Eps:         cfg.BatchNormEps,
		Momentum:    cfg.BatchNormMomentum,
	}

	fmt.Println("✓")
	return nil
}

// loadVAEMidBlock loads the mid block.
func loadVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
	resnet1, err := loadResnetBlock2D(weights, prefix+".resnets.0", numGroups)
	if err != nil {
		return nil, err
	}

	attention, err := loadVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
	if err != nil {
		return nil, err
	}

	resnet2, err := loadResnetBlock2D(weights, prefix+".resnets.1", numGroups)
	if err != nil {
		return nil, err
	}

	return &VAEMidBlock{
		Resnet1:   resnet1,
		Attention: attention,
		Resnet2:   resnet2,
	}, nil
}

// loadResnetBlock2D loads a ResNet block.
func loadResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
	block := &ResnetBlock2D{
		Norm1:        &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
		Conv1:        &Conv2D{Stride: 1, Padding: 1},
		Norm2:        &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
		Conv2:        &Conv2D{Stride: 1, Padding: 1},
		ConvShortcut: &Conv2D{Stride: 1, Padding: 0}, // Pre-allocate for optional loading
	}
	if err := safetensors.LoadModule(block, weights, prefix); err != nil {
		return nil, err
	}
	// If ConvShortcut wasn't loaded (no weights found), nil it out
	if block.ConvShortcut.Weight == nil {
		block.ConvShortcut = nil
	}
	return block, nil
}

// loadVAEAttentionBlock loads an attention block using LoadModule.
func loadVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
	ab := &VAEAttentionBlock{
		GroupNorm: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
	}
	if err := safetensors.LoadModule(ab, weights, prefix); err != nil {
		return nil, err
	}
	return ab, nil
}

// loadUpDecoderBlock2D loads an up decoder block.
func loadUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
	resnets := make([]*ResnetBlock2D, numLayers)
	for i := int32(0); i < numLayers; i++ {
		resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
		resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
		if err != nil {
			return nil, err
		}
		resnets[i] = resnet
	}

	var upsample *Conv2D
	if hasUpsample {
		upsample = &Conv2D{Stride: 1, Padding: 1}
		if err := safetensors.LoadModule(upsample, weights, prefix+".upsamplers.0.conv"); err != nil {
			return nil, err
		}
	}

	return &UpDecoderBlock2D{
		ResnetBlocks: resnets,
		Upsample:     upsample,
	}, nil
}

// Patchify converts latents [B, C, H, W] to patches [B, H*W/4, C*4] using 2x2 patches.
// This is the patch layout fed to the transformer; Unpatchify below is its inverse.
func (vae *AutoencoderKLFlux2) Patchify(latents *mlx.Array) *mlx.Array {
	shape := latents.Shape()
	B := shape[0]
	C := shape[1]
	H := shape[2]
	W := shape[3]

	patchH := vae.Config.PatchSize[0]
	patchW := vae.Config.PatchSize[1]

	pH := H / patchH
	pW := W / patchW

	// [B, C, H, W] -> [B, C, pH, patchH, pW, patchW]
	x := mlx.Reshape(latents, B, C, pH, patchH, pW, patchW)
	// [B, C, pH, patchH, pW, patchW] -> [B, pH, pW, C, patchH, patchW]
	x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
	// [B, pH, pW, C, patchH, patchW] -> [B, pH*pW, C*patchH*patchW]
	return mlx.Reshape(x, B, pH*pW, C*patchH*patchW)
}
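
// Worked shape example with the default 2x2 patch size and 32 latent channels: a
// [1, 32, 128, 128] latent grid gives pH = pW = 64, so Patchify returns
// [1, 64*64, 32*2*2] = [1, 4096, 128] patch tokens.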

// Unpatchify converts patches [B, L, C*4] back to [B, C, H, W]
func (vae *AutoencoderKLFlux2) Unpatchify(patches *mlx.Array, pH, pW, C int32) *mlx.Array {
	shape := patches.Shape()
	B := shape[0]

	patchH := vae.Config.PatchSize[0]
	patchW := vae.Config.PatchSize[1]

	// [B, pH*pW, C*patchH*patchW] -> [B, pH, pW, C, patchH, patchW]
	x := mlx.Reshape(patches, B, pH, pW, C, patchH, patchW)
	// [B, pH, pW, C, patchH, patchW] -> [B, C, pH, patchH, pW, patchW]
	x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
	// [B, C, pH, patchH, pW, patchW] -> [B, C, H, W]
	H := pH * patchH
	W := pW * patchW
	return mlx.Reshape(x, B, C, H, W)
}

// denormalizePatchified applies inverse batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] denormalized
func (vae *AutoencoderKLFlux2) denormalizePatchified(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	C := shape[2] // 128

	// Reshape stats for broadcasting [1, 1, C]
	mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
	variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)

	// Inverse BN (affine=False): x_denorm = x * sqrt(var + eps) + mean
	if vae.LatentBN.Bias != nil {
		bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
		x = mlx.Sub(x, bias)
	}
	if vae.LatentBN.Weight != nil {
		weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
		x = mlx.Div(x, weight)
	}
	x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
	x = mlx.Add(x, mean)

	return x
}

// Decode decodes latent patches to images.
// If Tiling is set, uses tiled decoding to reduce memory for large images.
// latents: [B, L, C*4] patchified latents from transformer
// pH, pW: patch grid dimensions
// Returns: [B, 3, H, W] image tensor
func (v *AutoencoderKLFlux2) Decode(latents *mlx.Array, pH, pW int32) *mlx.Array {
	// Denormalize patchified latents
	z := v.denormalizePatchified(latents)

	// Unpatchify: [B, L, C*4] -> [B, C, H, W]
	z = v.Unpatchify(z, pH, pW, v.Config.LatentChannels)

	// Convert NCHW -> NHWC for processing
	z = mlx.Transpose(z, 0, 2, 3, 1)

	// Use tiled decoding if enabled
	if v.Tiling != nil {
		mlx.Eval(z)
		return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
	}

	// Direct decode (no tiling)
	h := v.decodeTile(z)
	h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
	h = mlx.Transpose(h, 0, 3, 1, 2)
	return h
}
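
// Minimal usage sketch for Decode (the surrounding pipeline is assumed; latents, pH and pW
// are illustrative names for the transformer output and its patch grid):
//
//	var v AutoencoderKLFlux2
//	// ... v.Load(modelManifest), run the transformer to get latents [B, pH*pW, 128] ...
//	v.Tiling = DefaultTilingConfig() // optional: decode in tiles to cap memory
//	img := v.Decode(latents, pH, pW) // [B, 3, H, W] pixels in [0, 1]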

// decodeTile decodes a single latent tile to pixels (internal helper)
// z: [B, H, W, C] latent tile in NHWC format
// Returns: [B, H*8, W*8, 3] pixel tile in NHWC format (before clipping)
func (vae *AutoencoderKLFlux2) decodeTile(z *mlx.Array) *mlx.Array {
	// Post-quant conv
	if vae.PostQuantConv != nil {
		z = vae.PostQuantConv.Forward(z)
	}

	// Decoder
	h := vae.DecoderConvIn.Forward(z)
	h = vae.DecoderMid.Forward(h)

	for _, upBlock := range vae.DecoderUp {
		h = upBlock.Forward(h)
	}

	h = vae.DecoderNormOut.Forward(h)
	h = mlx.SiLU(h)
	h = vae.DecoderConvOut.Forward(h)

	// VAE outputs [-1, 1], convert to [0, 1]
	h = mlx.MulScalar(h, 0.5)
	h = mlx.AddScalar(h, 0.5)

	return h
}

// loadEncoderWeights loads the encoder components for image conditioning
func (m *AutoencoderKLFlux2) loadEncoderWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
	var err error

	// Load encoder conv_in
	m.EncoderConvIn = &Conv2D{Stride: 1, Padding: 1}
	if err := safetensors.LoadModule(m.EncoderConvIn, weights, "encoder.conv_in"); err != nil {
		return fmt.Errorf("encoder.conv_in: %w", err)
	}

	// Load encoder down blocks
	numBlocks := len(cfg.BlockOutChannels)
	m.EncoderDown = make([]*DownEncoderBlock2D, numBlocks)
	for i := 0; i < numBlocks; i++ {
		prefix := fmt.Sprintf("encoder.down_blocks.%d", i)
		hasDownsample := i < numBlocks-1
		m.EncoderDown[i], err = loadDownEncoderBlock2D(weights, prefix, cfg.LayersPerBlock, cfg.NormNumGroups, hasDownsample)
		if err != nil {
			return fmt.Errorf("%s: %w", prefix, err)
		}
	}

	// Load encoder mid block
	m.EncoderMid, err = loadVAEMidBlock(weights, "encoder.mid_block", cfg.NormNumGroups)
	if err != nil {
		return fmt.Errorf("encoder.mid_block: %w", err)
	}

	// Load encoder conv_norm_out and conv_out
	m.EncoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
	if err := safetensors.LoadModule(m.EncoderNormOut, weights, "encoder.conv_norm_out"); err != nil {
		return fmt.Errorf("encoder.conv_norm_out: %w", err)
	}

	m.EncoderConvOut = &Conv2D{Stride: 1, Padding: 1}
	if err := safetensors.LoadModule(m.EncoderConvOut, weights, "encoder.conv_out"); err != nil {
		return fmt.Errorf("encoder.conv_out: %w", err)
	}

	// Load quant_conv (for encoding)
	if cfg.UseQuantConv {
		m.QuantConv = &Conv2D{Stride: 1, Padding: 0}
		if err := safetensors.LoadModule(m.QuantConv, weights, "quant_conv"); err != nil {
			return fmt.Errorf("quant_conv: %w", err)
		}
	}

	return nil
}

// loadDownEncoderBlock2D loads a down encoder block.
func loadDownEncoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasDownsample bool) (*DownEncoderBlock2D, error) {
	resnets := make([]*ResnetBlock2D, numLayers)
	for i := int32(0); i < numLayers; i++ {
		resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
		resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
		if err != nil {
			return nil, err
		}
		resnets[i] = resnet
	}

	var downsample *Conv2D
	if hasDownsample {
		downsample = &Conv2D{Stride: 2, Padding: 0}
		if err := safetensors.LoadModule(downsample, weights, prefix+".downsamplers.0.conv"); err != nil {
			return nil, err
		}
	}

	return &DownEncoderBlock2D{
		ResnetBlocks: resnets,
		Downsample:   downsample,
	}, nil
}

// EncodeImage encodes an image to normalized latents.
// image: [B, 3, H, W] image tensor in [-1, 1]
// Returns: [B, L, C*4] patchified normalized latents
func (vae *AutoencoderKLFlux2) EncodeImage(image *mlx.Array) *mlx.Array {
	// Convert NCHW -> NHWC
	x := mlx.Transpose(image, 0, 2, 3, 1)

	// Encoder
	h := vae.EncoderConvIn.Forward(x)

	for _, downBlock := range vae.EncoderDown {
		h = downBlock.Forward(h)
	}

	h = vae.EncoderMid.Forward(h)
	h = vae.EncoderNormOut.Forward(h)
	h = mlx.SiLU(h)
	h = vae.EncoderConvOut.Forward(h)

	// Quant conv outputs [B, H, W, 2*latent_channels] (mean + logvar)
	if vae.QuantConv != nil {
		h = vae.QuantConv.Forward(h)
	}

	// Take only the mean (first latent_channels) - deterministic encoding
	// h is [B, H, W, 64] -> take first 32 channels for mean
	shape := h.Shape()
	latentChannels := vae.Config.LatentChannels // 32
	h = mlx.Slice(h, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], latentChannels})

	// Convert NHWC -> NCHW for patchifying
	h = mlx.Transpose(h, 0, 3, 1, 2)

	// Patchify: [B, C, H, W] -> [B, L, C*4]
	h = vae.Patchify(h)

	// Apply BatchNorm on patchified latents [B, L, 128]
	// The BatchNorm has 128 channels matching the patchified dimension
	h = vae.normalizePatchified(h)

	return h
}
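
// Illustrative encode/decode round trip for image conditioning (hypothetical caller;
// variable names are placeholders, and it assumes the encoder's 8x spatial compression
// plus the 2x2 patching, so the patch grid is the image size divided by 16):
//
//	latents := v.EncodeImage(img)            // img: [1, 3, 1024, 1024] in [-1, 1]
//	pH, pW := int32(1024/16), int32(1024/16) // 64 x 64 patch grid
//	recon := v.Decode(latents, pH, pW)       // [1, 3, 1024, 1024] in [0, 1]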

// normalizePatchified applies batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] normalized
func (vae *AutoencoderKLFlux2) normalizePatchified(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	C := shape[2] // 128

	// Reshape stats for broadcasting [1, 1, C]
	mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
	variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)

	// Normalize: (x - mean) / sqrt(var + eps)
	xNorm := mlx.Sub(x, mean)
	xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))

	// Scale and shift (only if affine=True)
	if vae.LatentBN.Weight != nil {
		weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
		xNorm = mlx.Mul(xNorm, weight)
	}
	if vae.LatentBN.Bias != nil {
		bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
		xNorm = mlx.Add(xNorm, bias)
	}

	return xNorm
}