diff --git a/x/imagegen/cmd/engine/main.go b/x/imagegen/cmd/engine/main.go index 69ac8471d..b2d6a7350 100644 --- a/x/imagegen/cmd/engine/main.go +++ b/x/imagegen/cmd/engine/main.go @@ -11,9 +11,11 @@ import ( "os" "path/filepath" "runtime/pprof" + "strings" "github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/models/gemma3" + "github.com/ollama/ollama/x/imagegen/models/glm_image" "github.com/ollama/ollama/x/imagegen/models/gpt_oss" "github.com/ollama/ollama/x/imagegen/models/llama" "github.com/ollama/ollama/x/imagegen/models/qwen_image" @@ -61,6 +63,7 @@ func main() { // Legacy mode flags zimageFlag := flag.Bool("zimage", false, "Z-Image generation") + glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation") qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation") qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing") var inputImages stringSlice @@ -117,6 +120,33 @@ func main() { if err == nil { err = saveImageArray(img, *out) } + case *glmImageFlag: + m := &glm_image.Model{} + // Use LoadFromPath if model path looks like a directory, otherwise use Load (ollama manifest) + var loadErr error + if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") { + loadErr = m.LoadFromPath(*modelPath) + } else { + loadErr = m.Load(*modelPath) + } + if loadErr != nil { + log.Fatal(loadErr) + } + var img *mlx.Array + img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{ + Prompt: *prompt, + Width: int32(*width), + Height: int32(*height), + Steps: *steps, + Seed: *seed, + Temperature: float32(*temperature), + TopP: float32(*topP), + GuidanceScale: float32(*cfgScale), + MaxVisualTokens: int32(*maxTokens), + }) + if err == nil { + err = saveImageArray(img, *out) + } case *qwenImage: m, loadErr := qwen_image.LoadPersistent(*modelPath) if loadErr != nil { diff --git a/x/imagegen/create.go b/x/imagegen/create.go index c2e22d3df..a90cfe700 100644 
--- a/x/imagegen/create.go +++ b/x/imagegen/create.go @@ -48,7 +48,7 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, var totalParams int64 // Count parameters from original tensor shapes // Components to process - extract individual tensors from each - components := []string{"text_encoder", "transformer", "vae"} + components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"} for _, component := range components { componentDir := filepath.Join(modelDir, component) @@ -126,10 +126,13 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, "text_encoder/generation_config.json", "transformer/config.json", "vae/config.json", + "vision_language_encoder/config.json", "scheduler/scheduler_config.json", "tokenizer/tokenizer.json", "tokenizer/tokenizer_config.json", "tokenizer/vocab.json", + "processor/tokenizer.json", // GLM-Image main tokenizer + "processor/tokenizer_config.json", // GLM-Image tokenizer config } for _, cfgPath := range configFiles { diff --git a/x/imagegen/imagegen.md b/x/imagegen/imagegen.md new file mode 100644 index 000000000..be001de9f --- /dev/null +++ b/x/imagegen/imagegen.md @@ -0,0 +1,19 @@ +# Image generation models (experimental) + +Experimental image generation models are available for **macOS** in Ollama: + +## Available models + +- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo) + +``` +ollama run x/z-image-turbo +``` + +> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models + +More models coming soon: + +1. Qwen-Image-2512 +2. Qwen-Image-Edit-2511 +3. 
// ByT5Tokenizer is a trivial byte-level tokenizer for the ByT5 text encoder.
// ByT5 treats raw bytes as tokens: byte value b (0-255) maps to token ID b+3,
// leaving IDs 0, 1, 2 for the PAD, EOS, and UNK specials respectively.
type ByT5Tokenizer struct {
	PadTokenID int32 // 0
	EOSTokenID int32 // 1
	UNKTokenID int32 // 2
}

// NewByT5Tokenizer returns a ByT5 tokenizer with the standard special-token IDs.
func NewByT5Tokenizer() *ByT5Tokenizer {
	t := new(ByT5Tokenizer)
	t.PadTokenID = 0
	t.EOSTokenID = 1
	t.UNKTokenID = 2
	return t
}

// Encode converts a string to token IDs, one ID per UTF-8 byte (byte + 3).
func (t *ByT5Tokenizer) Encode(text string) []int32 {
	ids := make([]int32, 0, len(text))
	for _, raw := range []byte(text) {
		// Byte b becomes token b+3; IDs 0-2 are reserved for specials.
		ids = append(ids, int32(raw)+3)
	}
	return ids
}

// Decode converts token IDs back to a string, silently dropping any ID
// outside the byte range 3-258 (specials and out-of-range values).
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
	out := make([]byte, 0, len(tokens))
	for _, id := range tokens {
		if id < 3 || id > 258 {
			continue // special token (PAD/EOS/UNK) or invalid — no byte to emit
		}
		out = append(out, byte(id-3))
	}
	return string(out)
}

// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
	Prompt         string
	NegativePrompt string  // For CFG (optional, not typically used with GLM-Image)
	GuidanceScale  float32 // Guidance scale (default: 1.5)
	Width          int32   // Image width (default: 1024, must be divisible by 32)
	Height         int32   // Image height (default: 1024, must be divisible by 32)
	Steps          int     // Diffusion denoising steps (default: 50)
	Seed           int64   // Random seed
	Progress       ProgressFunc // Optional progress callback

	// AR generation options
	MaxVisualTokens int32   // Recomputed by generate from Width/Height (prev + target grids); caller value is ignored
	Temperature     float32 // AR sampling temperature (default: 0.9)
	TopP            float32 // Nucleus sampling (default: 0.75)
}

// ProgressFunc is called during generation with stage and step progress.
type ProgressFunc func(stage string, step, totalSteps int)
+func (m *Model) Load(modelName string) error { + fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName) + start := time.Now() + + if mlx.GPUIsAvailable() { + mlx.SetDefaultDeviceGPU() + mlx.EnableCompile() + } + + m.ModelName = modelName + + // Load manifest + manifest, err := imagegen.LoadManifest(modelName) + if err != nil { + return fmt.Errorf("load manifest: %w", err) + } + + // Create ByT5 tokenizer (byte-level, no vocabulary file needed) + // Used for T5 text encoder (glyph embeddings) + fmt.Print(" Creating ByT5 tokenizer... ") + m.Tokenizer = NewByT5Tokenizer() + fmt.Println("✓") + + // Load GLM tokenizer for AR model (visual token generation) + fmt.Print(" Loading GLM tokenizer... ") + glmTok, err := NewGLMTokenizer(manifest) + if err != nil { + return fmt.Errorf("glm tokenizer: %w", err) + } + m.GLMTokenizer = glmTok + fmt.Println("✓") + + // Load T5 text encoder (~830MB) + m.TextEncoder = &T5TextEncoder{} + if err := m.TextEncoder.Load(manifest); err != nil { + return fmt.Errorf("text encoder: %w", err) + } + mlx.Eval(mlx.Collect(m.TextEncoder)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load vision-language encoder (~19GB, 9B params) + m.VisionLanguageEncoder = &VisionLanguageEncoder{} + if err := m.VisionLanguageEncoder.Load(manifest); err != nil { + return fmt.Errorf("vision language encoder: %w", err) + } + mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load diffusion transformer (~13GB, 7B params) + m.Transformer = &DiffusionTransformer{} + if err := m.Transformer.Load(manifest); err != nil { + return fmt.Errorf("transformer: %w", err) + } + mlx.Eval(mlx.Collect(m.Transformer)...) 
+ fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load VAE decoder (~775MB) + m.VAEDecoder = &VAEDecoder{} + if err := m.VAEDecoder.Load(manifest); err != nil { + return fmt.Errorf("VAE decoder: %w", err) + } + mlx.Eval(mlx.Collect(m.VAEDecoder)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + mem := mlx.MetalGetActiveMemory() + fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024)) + + return nil +} + +// LoadFromPath loads the model from a directory path (not ollama manifest) +func (m *Model) LoadFromPath(modelPath string) error { + fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath) + start := time.Now() + + if mlx.GPUIsAvailable() { + mlx.SetDefaultDeviceGPU() + mlx.EnableCompile() + } + + m.ModelName = modelPath + + // Create ByT5 tokenizer (byte-level, no vocabulary file needed) + fmt.Print(" Creating ByT5 tokenizer... ") + m.Tokenizer = NewByT5Tokenizer() + fmt.Println("✓") + + // Load GLM tokenizer for AR model (visual token generation) + fmt.Print(" Loading GLM tokenizer... ") + glmTok, err := NewGLMTokenizerFromPath(modelPath) + if err != nil { + return fmt.Errorf("glm tokenizer: %w", err) + } + m.GLMTokenizer = glmTok + fmt.Println("✓") + + // Load T5 text encoder + m.TextEncoder = &T5TextEncoder{} + if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil { + return fmt.Errorf("text encoder: %w", err) + } + mlx.Eval(mlx.Collect(m.TextEncoder)...) 
+ fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load vision-language encoder + m.VisionLanguageEncoder = &VisionLanguageEncoder{} + if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil { + return fmt.Errorf("vision language encoder: %w", err) + } + mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load diffusion transformer + m.Transformer = &DiffusionTransformer{} + if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil { + return fmt.Errorf("transformer: %w", err) + } + mlx.Eval(mlx.Collect(m.Transformer)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load VAE decoder + m.VAEDecoder = &VAEDecoder{} + if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil { + return fmt.Errorf("VAE decoder: %w", err) + } + mlx.Eval(mlx.Collect(m.VAEDecoder)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + mem := mlx.MetalGetActiveMemory() + fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024)) + + return nil +} + +// Generate creates an image from a prompt. +func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { + return m.GenerateFromConfig(context.Background(), &GenerateConfig{ + Prompt: prompt, + Width: width, + Height: height, + Steps: steps, + Seed: seed, + }) +} + +// GenerateWithProgress creates an image with progress callback. 
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
	return m.GenerateFromConfig(context.Background(), &GenerateConfig{
		Prompt:   prompt,
		Width:    width,
		Height:   height,
		Steps:    steps,
		Seed:     seed,
		Progress: progress,
	})
}

// GenerateFromConfig generates an image using the unified config struct.
// It wraps the internal pipeline with wall-clock timing output.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
	start := time.Now()
	result, err := m.generate(ctx, cfg)
	if err != nil {
		return nil, err
	}
	fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
	return result, nil
}

// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
	return m.Generate(prompt, width, height, steps, seed)
}

// generate is the internal generation pipeline:
// T5 text encoding -> AR visual-token generation -> diffusion denoising -> VAE decode.
// It mutates cfg in place when filling defaults.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
	// Apply defaults
	if cfg.Width <= 0 {
		cfg.Width = 1024
	}
	if cfg.Height <= 0 {
		cfg.Height = 1024
	}
	if cfg.Steps <= 0 {
		cfg.Steps = 50
	}
	if cfg.GuidanceScale <= 0 {
		cfg.GuidanceScale = 1.5
	}
	// Calculate MaxVisualTokens based on image dimensions
	// GLM-Image generates TWO grids of visual tokens:
	//   1. First: prev (small) grid - prevTokenH × prevTokenW tokens
	//   2. Then: target (large) grid - tokenH × tokenW tokens
	// After generation, we extract only the TARGET grid tokens for diffusion.
	// NOTE(review): any caller-supplied cfg.MaxVisualTokens is overwritten below.
	factor := int32(32)
	tokenH := cfg.Height / factor
	tokenW := cfg.Width / factor
	targetGridTokens := tokenH * tokenW

	// Compute prev grid dimensions using diffusers formula:
	//   ratio = token_h / token_w
	//   prev_token_h = int(sqrt(ratio) * 16)
	//   prev_token_w = int(sqrt(1/ratio) * 16)
	ratio := float64(tokenH) / float64(tokenW)
	prevTokenH := int32(math.Sqrt(ratio) * 16)
	prevTokenW := int32(math.Sqrt(1/ratio) * 16)
	prevGridTokens := prevTokenH * prevTokenW

	// Total tokens to generate = prev grid + target grid
	// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
	cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
	if cfg.Temperature <= 0 {
		cfg.Temperature = 0.9
	}
	if cfg.TopP <= 0 {
		cfg.TopP = 0.75
	}

	// Ensure dimensions are divisible by 32
	cfg.Width = (cfg.Width / 32) * 32
	cfg.Height = (cfg.Height / 32) * 32

	tcfg := m.Transformer.Config
	// VAE operates at 1/8 spatial resolution of the output image.
	latentH := cfg.Height / 8
	latentW := cfg.Width / 8

	// Progress callback helper (no-op when cfg.Progress is nil)
	progress := func(stage string, step, total int) {
		if cfg.Progress != nil {
			cfg.Progress(stage, step, total)
		}
	}

	// === PHASE 1: T5 Text Encoding ===
	fmt.Println("[T5] Encoding glyph text...")
	progress("text_encoding", 0, 1)
	textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
	mlx.Keep(textEmbed)
	mlx.Eval(textEmbed)
	fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
	progress("text_encoding", 1, 1)

	// === PHASE 2: AR Visual Token Generation ===
	fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
	progress("ar_generation", 0, int(cfg.MaxVisualTokens))
	visualTokens := m.VisionLanguageEncoder.Generate(
		cfg.Prompt,
		m.GLMTokenizer,
		cfg.MaxVisualTokens,
		cfg.Temperature,
		cfg.TopP,
		cfg.Seed,
		cfg.Height,
		cfg.Width,
		func(step int) {
			// Log the first few steps and then every 100th to limit noise.
			if step%100 == 0 || step < 10 {
				fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
			}
			progress("ar_generation", step, int(cfg.MaxVisualTokens))
		},
	)
	mlx.Keep(visualTokens)
	mlx.Eval(visualTokens)
	fmt.Printf("[AR] Done generating visual tokens\n")
	progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))

	vtShape := visualTokens.Shape()
	totalGenerated := vtShape[1]
	fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)

	// Extract only the TARGET grid tokens (skip the prev grid tokens)
	// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
	// large_image_start_offset = prev_grid_size
	var targetGridVisualTokens *mlx.Array
	if totalGenerated >= prevGridTokens+targetGridTokens {
		// Full generation completed - extract target grid
		targetGridVisualTokens = mlx.Slice(visualTokens,
			[]int32{0, prevGridTokens},
			[]int32{1, prevGridTokens + targetGridTokens})
		mlx.Keep(targetGridVisualTokens)
		mlx.Eval(targetGridVisualTokens)
	} else if totalGenerated > prevGridTokens {
		// Partial target grid - take what we have (upsampleTokens pads the rest)
		actualTargetTokens := totalGenerated - prevGridTokens
		targetGridVisualTokens = mlx.Slice(visualTokens,
			[]int32{0, prevGridTokens},
			[]int32{1, totalGenerated})
		mlx.Keep(targetGridVisualTokens)
		mlx.Eval(targetGridVisualTokens)
		fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
			actualTargetTokens, targetGridTokens)
	} else {
		// Not enough tokens - EOS came too early
		return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
			totalGenerated, prevGridTokens)
	}

	// === PHASE 3: Diffusion Decoding ===
	// Setup scheduler with dynamic shift based on image size
	scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
	imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
	scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)

	// Initialize noise latents [B, C, H, W]
	latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
	mlx.Eval(latents)

	// Upsample TARGET grid visual tokens 2x to match patch count (matching diffusers)
	// target_grid tokens -> 2x upsample -> patch_count
	// e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
	visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)

	// Prepare prior embeddings from upsampled visual tokens (VQ codebook lookup + projection)
	priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
	mlx.Keep(priorEmbed)
	mlx.Eval(priorEmbed)

	// Prepare text conditioning (project T5 embeddings)
	textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
	mlx.Keep(textCond)
	mlx.Eval(textCond)

	// === CFG Setup ===
	// For classifier-free guidance, we need unconditional (negative) text embeddings
	// GLM-Image uses empty string "" for negative prompt
	doCFG := cfg.GuidanceScale > 1.0
	var negativeTextCond *mlx.Array
	if doCFG {
		// Encode empty string for negative prompt
		negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
		mlx.Keep(negativeTextEmbed)
		mlx.Eval(negativeTextEmbed)
		negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
		mlx.Keep(negativeTextCond)
		mlx.Eval(negativeTextCond)
		negativeTextEmbed.Free()
	}

	// Prepare conditioning inputs
	targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
	cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
	targetSize = mlx.ToBFloat16(targetSize)
	cropCoords = mlx.ToBFloat16(cropCoords)
	mlx.Keep(targetSize)
	mlx.Keep(cropCoords)
	mlx.Eval(targetSize, cropCoords)

	pH := latentH / tcfg.PatchSize
	pW := latentW / tcfg.PatchSize

	// Denoising loop
	fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
	progress("diffusion", 0, cfg.Steps)
	for i := 0; i < cfg.Steps; i++ {
		fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
		// Check for cancellation
		if ctx != nil {
			select {
			case <-ctx.Done():
				textEmbed.Free()
				visualTokens.Free()
				// NOTE(review): upsampleTokens builds a fresh array via NewArrayInt32,
				// so visualTokensUpsampled does NOT alias visualTokens. Also
				// targetGridVisualTokens, negativeTextCond, targetSize and cropCoords
				// are Keep'd above but not freed here — possible leak on cancel; TODO confirm.
				priorEmbed.Free()
				textCond.Free()
				latents.Free()
				return nil, ctx.Err()
			default:
			}
		}

		// Get timestep value for the transformer
		// scheduler.Timesteps contains raw timestep values (1000 down to ~20)
		// Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
		timestepVal := scheduler.Timesteps[i] - 1
		timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))

		// Patchify latents [B, C, H, W] -> [B, L, C*p*p]
		patches := PatchifyLatents(latents, tcfg.PatchSize)

		// Transformer forward with MMDiT architecture
		// Conditional pass (with text + prior embeddings)
		outputCond := m.Transformer.ForwardWithPriorDrop(
			patches,
			priorEmbed,
			textCond,
			timestep,
			targetSize,
			cropCoords,
			pH,
			pW,
			false, // priorTokenDrop = false for conditional
		)

		// Unpatchify [B, L, C*p*p] -> [B, C, H, W]
		noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)

		var noisePred *mlx.Array
		if doCFG {
			// Unconditional pass (empty text, dropped prior embeddings)
			outputUncond := m.Transformer.ForwardWithPriorDrop(
				patches,
				priorEmbed, // Still passed but will be ignored due to priorTokenDrop=true
				negativeTextCond,
				timestep,
				targetSize,
				cropCoords,
				pH,
				pW,
				true, // priorTokenDrop = true for unconditional
			)
			noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)

			// CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
			diff := mlx.Sub(noisePredCond, noisePredUncond)
			scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
			noisePred = mlx.Add(noisePredUncond, scaled)
		} else {
			noisePred = noisePredCond
		}

		// Scheduler step
		oldLatents := latents
		latents = scheduler.Step(noisePred, latents, i)
		mlx.Eval(latents)
		oldLatents.Free()

		progress("diffusion", i+1, cfg.Steps)
	}

	// Cleanup intermediate arrays
	textEmbed.Free()
	visualTokens.Free()
	// NOTE(review): visualTokensUpsampled is a fresh array (not an alias of
	// visualTokens); targetGridVisualTokens is Keep'd but never freed here —
	// presumably reclaimed by the mlx pool; TODO confirm.
	priorEmbed.Free()
	textCond.Free()
	if negativeTextCond != nil {
		negativeTextCond.Free()
	}
	targetSize.Free()
	cropCoords.Free()

	// === PHASE 4: VAE Decode ===
	progress("vae_decode", 0, 1)
	decoded := m.VAEDecoder.Decode(latents)
	mlx.Eval(decoded)
	latents.Free()
	progress("vae_decode", 1, 1)

	return decoded, nil
}

// upsampleTokens performs nearest-neighbor upsampling of visual tokens
// Converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x)
// scale must be 2 or 4
//
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
// missing tokens are padded with 0 (visual token padding value).
// The result is a NEW host-built array; it does not alias the input.
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
	// tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
	// Each token at (i, j) becomes scale*scale tokens in the output

	mlx.Eval(tokens)
	tokenData := tokens.DataInt32()
	numTokens := int32(len(tokenData))
	expectedTokens := prevH * prevW

	// Warn if we got fewer tokens than expected (early EOS)
	if numTokens < expectedTokens {
		fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
			numTokens, expectedTokens)
	}

	targetH := prevH * scale
	targetW := prevW * scale
	upsampled := make([]int32, targetH*targetW)

	for i := int32(0); i < prevH; i++ {
		for j := int32(0); j < prevW; j++ {
			srcIdx := i*prevW + j

			// Handle early EOS: use 0 (padding) for missing tokens
			var val int32
			if srcIdx < numTokens {
				val = tokenData[srcIdx]
			} else {
				val = 0 // Padding token
			}

			// Place in scale*scale positions
			dstI := i * scale
			dstJ := j * scale
			for di := int32(0); di < scale; di++ {
				for dj := int32(0); dj < scale; dj++ {
					upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
				}
			}
		}
	}

	return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
}

// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
	shape := latents.Shape()
	B := shape[0]
	C := shape[1]
	H := shape[2]
	W := shape[3]

	pH := H / patchSize
	pW := W / patchSize

	// Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
	x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
	// Transpose: -> [B, pH, pW, C, p, p]
	x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
	// Flatten: -> [B, pH*pW, C*p*p]
	return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}

// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
	shape := patches.Shape()
	B := shape[0]

	pH := H / patchSize
	pW := W / patchSize

	// Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
	x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
	// Transpose: -> [B, C, pH, p, pW, p]
	x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
	// Reshape: -> [B, C, H, W]
	return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
}

// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
+func CalculateShift(imgSeqLen int32) float32 { + cfg := DefaultSchedulerConfig() + if !cfg.UseDynamicShifting { + return 0 + } + + // Sqrt-based shift calculation (matches diffusers) + m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen))) + return m*cfg.MaxShift + cfg.BaseShift +} + +// UpsampleTokens2x upsamples token IDs by 2x using nearest neighbor interpolation +// tokens: [B, H*W] -> [B, (H*2)*(W*2)] +// This matches diffusers' _upsample_token_ids function +func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array { + shape := tokens.Shape() + B := shape[0] + + // Reshape to [B, 1, H, W] for interpolation + tokens = mlx.Reshape(tokens, B, 1, gridH, gridW) + + // Convert to float for interpolation + tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32) + + // 2x nearest neighbor upsample + // [B, 1, H, W] -> [B, 1, H*2, W*2] + upsampled := nearestUpsample2x(tokensFloat) + + // Convert back to int and reshape to [B, H*2*W*2] + upsampled = mlx.AsType(upsampled, mlx.DtypeInt32) + return mlx.Reshape(upsampled, B, gridH*2*gridW*2) +} + +// nearestUpsample2x performs 2x nearest neighbor upsampling on NCHW tensor +func nearestUpsample2x(x *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + C := shape[1] + H := shape[2] + W := shape[3] + + // Repeat each element 2x2 + // [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2] + x = mlx.Reshape(x, B, C, H, 1, W, 1) + + // Tile to repeat each pixel 2x2 + x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2}) + + // Reshape to final size + return mlx.Reshape(x, B, C, H*2, W*2) +} diff --git a/x/imagegen/models/glm_image/glm_tokenizer.go b/x/imagegen/models/glm_image/glm_tokenizer.go new file mode 100644 index 000000000..f8b31c422 --- /dev/null +++ b/x/imagegen/models/glm_image/glm_tokenizer.go @@ -0,0 +1,358 @@ +//go:build mlx + +package glm_image + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + 
"github.com/ollama/ollama/x/imagegen" +) + +// GLMTokenizer implements the GLM tokenizer for the AR model +// This is a BPE-style tokenizer with ignore_merges=true, meaning it does +// greedy longest-match tokenization from the vocab without runtime merging. +type GLMTokenizer struct { + Vocab map[string]int32 // token string -> token ID + VocabReverse map[int32]string // token ID -> token string + SpecialTokens map[string]int32 // special token strings -> IDs + + // Special token IDs + SopTokenID int32 // = grid_bos_token (167845) + EopTokenID int32 // = grid_eos_token (167846) + BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384) + EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385) + PadTokenID int32 + + // Sorted vocab keys by length (longest first) for greedy matching + sortedTokens []string +} + +// tokenizerJSON represents the structure of tokenizer.json +type tokenizerJSON struct { + Model struct { + Vocab map[string]int32 `json:"vocab"` + } `json:"model"` + AddedTokens []struct { + ID int32 `json:"id"` + Content string `json:"content"` + Special bool `json:"special"` + } `json:"added_tokens"` +} + +// NewGLMTokenizer creates a GLM tokenizer from the model manifest +func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) { + // Read tokenizer.json from processor directory in manifest + data, err := manifest.ReadConfig("processor/tokenizer.json") + if err != nil { + return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err) + } + + var tj tokenizerJSON + if err := json.Unmarshal(data, &tj); err != nil { + return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err) + } + + tok := &GLMTokenizer{ + Vocab: make(map[string]int32), + VocabReverse: make(map[int32]string), + SpecialTokens: make(map[string]int32), + } + + // Load vocab from model section + for token, id := range tj.Model.Vocab { + tok.Vocab[token] = id + tok.VocabReverse[id] = token + } + + // Load added tokens (special tokens 
including dit_tokens) + for _, at := range tj.AddedTokens { + tok.Vocab[at.Content] = at.ID + tok.VocabReverse[at.ID] = at.Content + if at.Special { + tok.SpecialTokens[at.Content] = at.ID + } + } + + // Set special token IDs + tok.SopTokenID = 167845 // + tok.EopTokenID = 167846 // + tok.BosTokenID = 16384 // <|dit_token_16384|> + tok.EosTokenID = 16385 // <|dit_token_16385|> + tok.PadTokenID = 16385 // Same as EOS + + // Build sorted token list for greedy matching (longest first) + tok.sortedTokens = make([]string, 0, len(tok.Vocab)) + for token := range tok.Vocab { + tok.sortedTokens = append(tok.sortedTokens, token) + } + sort.Slice(tok.sortedTokens, func(i, j int) bool { + return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j]) + }) + + fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab)) + + return tok, nil +} + +// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path +func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) { + // Read tokenizer.json from processor directory + tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json") + data, err := os.ReadFile(tokenizerPath) + if err != nil { + return nil, fmt.Errorf("failed to read tokenizer.json: %w", err) + } + + var tj tokenizerJSON + if err := json.Unmarshal(data, &tj); err != nil { + return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err) + } + + tok := &GLMTokenizer{ + Vocab: make(map[string]int32), + VocabReverse: make(map[int32]string), + SpecialTokens: make(map[string]int32), + } + + // Load vocab from model section + for token, id := range tj.Model.Vocab { + tok.Vocab[token] = id + tok.VocabReverse[id] = token + } + + // Load added tokens (special tokens including dit_tokens) + for _, at := range tj.AddedTokens { + tok.Vocab[at.Content] = at.ID + tok.VocabReverse[at.ID] = at.Content + if at.Special { + tok.SpecialTokens[at.Content] = at.ID + } + } + + // Set special token IDs + tok.SopTokenID = 167845 // + tok.EopTokenID = 
167846 // + tok.BosTokenID = 16384 // <|dit_token_16384|> + tok.EosTokenID = 16385 // <|dit_token_16385|> + tok.PadTokenID = 16385 // Same as EOS + + // Build sorted token list for greedy matching (longest first) + tok.sortedTokens = make([]string, 0, len(tok.Vocab)) + for token := range tok.Vocab { + tok.sortedTokens = append(tok.sortedTokens, token) + } + sort.Slice(tok.sortedTokens, func(i, j int) bool { + return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j]) + }) + + fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab)) + + return tok, nil +} + +// Encode tokenizes a string into token IDs +// This uses greedy longest-match tokenization with GPT-2 style space handling +func (t *GLMTokenizer) Encode(text string) []int32 { + if text == "" { + return []int32{} + } + + var tokens []int32 + + // First, check for and handle special tokens + // Replace special tokens with placeholders, encode, then restore + specialReplacements := make(map[string]int32) + for special, id := range t.SpecialTokens { + if strings.Contains(text, special) { + specialReplacements[special] = id + } + } + + // Process text character by character with special token handling + i := 0 + isFirstToken := true + + for i < len(text) { + // Check for special tokens first + foundSpecial := false + for special, id := range specialReplacements { + if strings.HasPrefix(text[i:], special) { + tokens = append(tokens, id) + i += len(special) + isFirstToken = false + foundSpecial = true + break + } + } + if foundSpecial { + continue + } + + // Handle regular text with GPT-2 style space prefix + // "Ġ" (U+0120) represents a space before a token + remaining := text[i:] + + // Try to find the longest matching token + matched := false + for _, token := range t.sortedTokens { + // Skip special tokens in regular matching + if _, isSpecial := t.SpecialTokens[token]; isSpecial { + continue + } + + // Check if this token matches + tokenText := token + + // Handle the Ġ prefix (represents space) + if 
strings.HasPrefix(token, "Ġ") { + // This token expects a leading space + if i > 0 || !isFirstToken { + // Check if remaining starts with space + token content + tokenContent := token[len("Ġ"):] + if strings.HasPrefix(remaining, " "+tokenContent) { + tokens = append(tokens, t.Vocab[token]) + i += 1 + len(tokenContent) // space + content + isFirstToken = false + matched = true + break + } + } + } else { + // Regular token without space prefix + if strings.HasPrefix(remaining, tokenText) { + tokens = append(tokens, t.Vocab[token]) + i += len(tokenText) + isFirstToken = false + matched = true + break + } + } + } + + if !matched { + // No token found - skip this character (or use UNK) + // For now, just skip unknown characters + i++ + } + } + + return tokens +} + +// EncodeForGeneration encodes a prompt with grid tokens for image generation +// Format: {prompt}{token_h} {token_w}{prev_h} {prev_w}<|dit_token_16384|> +// +// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with +// space prefix), matching the HuggingFace tokenizer behavior. +func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 { + // Calculate grid dimensions + factor := int32(32) + height := (targetHeight / factor) * factor + width := (targetWidth / factor) * factor + tokenH := height / factor + tokenW := width / factor + + // Calculate previous grid dimensions + ratio := float64(tokenH) / float64(tokenW) + prevTokenH := int32(sqrt(ratio) * 16) + prevTokenW := int32(sqrt(1.0/ratio) * 16) + + // Encode the prompt text + promptTokens := t.Encode(prompt) + + // Build the full sequence: + // [prompt tokens] [tokenH] [Ġ+tokenW] [prevH] [Ġ+prevW] + // Note: HF tokenizer treats " 32" as "Ġ32" (single token), not "Ġ" + "32" + var tokens []int32 + tokens = append(tokens, promptTokens...) 
+ + // First grid: H W + // First number has no space prefix, second number has space prefix (Ġ) + tokens = append(tokens, t.SopTokenID) + tokens = append(tokens, t.encodeNumber(tokenH)...) + tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W + tokens = append(tokens, t.EopTokenID) + + // Second grid: prevH prevW + tokens = append(tokens, t.SopTokenID) + tokens = append(tokens, t.encodeNumber(prevTokenH)...) + tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW + tokens = append(tokens, t.EopTokenID) + + // BOS token (start of image generation) + tokens = append(tokens, t.BosTokenID) + + return tokens +} + +// encodeNumber encodes a number - first tries as a whole token, falls back to digit-by-digit +func (t *GLMTokenizer) encodeNumber(n int32) []int32 { + s := fmt.Sprintf("%d", n) + // First try: look up the whole number as a single token + if id, ok := t.Vocab[s]; ok { + return []int32{id} + } + // Fallback: encode digit by digit + var tokens []int32 + for _, c := range s { + if id, ok := t.Vocab[string(c)]; ok { + tokens = append(tokens, id) + } + } + return tokens +} + +// encodeSpaceNumber encodes " N" as "ĠN" (space-prefixed number) matching HF tokenizer +// GPT-2 style: " 32" becomes single token "Ġ32", not "Ġ" + "32" +func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 { + s := fmt.Sprintf("%d", n) + + // First try: look up "Ġ{number}" as a single token (e.g., "Ġ32") + spaceToken := "Ġ" + s + if id, ok := t.Vocab[spaceToken]; ok { + return []int32{id} + } + + // Fallback: bare space Ġ + number tokens + var tokens []int32 + if spaceID, ok := t.Vocab["Ġ"]; ok { + tokens = append(tokens, spaceID) + } + tokens = append(tokens, t.encodeNumber(n)...) 
// sqrt computes the square root of x via Newton's method, returning 0 for
// non-positive input (used for aspect-ratio math in EncodeForGeneration).
//
// The original ran a fixed 10 iterations, which is not enough for large x:
// Newton's iteration from z=x roughly halves z each step until it nears the
// root, so x=1e12 needs ~20+ steps. Iterate to relative convergence instead,
// with a generous iteration cap as a safety net.
func sqrt(x float64) float64 {
	if x <= 0 {
		return 0
	}
	z := x
	for i := 0; i < 64; i++ {
		next := z - (z*z-x)/(2*z)
		// Relative convergence test (iterates stay positive for x > 0).
		delta := next - z
		if delta < 0 {
			delta = -delta
		}
		if delta <= 1e-12*next {
			return next
		}
		z = next
	}
	return z
}
conditioning (unshifted) + Sigmas []float32 // Shifted sigmas for Euler step calculation + NumSteps int +} + +// NewFlowMatchScheduler creates a new scheduler +func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler { + return &FlowMatchScheduler{Config: cfg} +} + +// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size +// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for step calculation +func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) { + s.NumSteps = numSteps + + // Calculate shift (mu) based on image sequence length + mu := s.calculateShift(imgSeqLen) + + // Create timesteps: linspace from sigma_max_t to sigma_min_t + // sigma_max = 1.0, sigma_min ~= 0.001 (near 0 but not exactly 0) + // Then apply time shift and append terminal sigma=0 + s.Timesteps = make([]float32, numSteps) + s.Sigmas = make([]float32, numSteps+1) // +1 for terminal sigma + + numTrainTimesteps := float32(s.Config.NumTrainTimesteps) + + // Create base sigmas: linspace from 1.0 to small value (matching diffusers) + for i := 0; i < numSteps; i++ { + // linspace from 1000 to ~20 (sigma_min * num_train_timesteps) + tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1) + s.Timesteps[i] = tRaw + + // Convert to sigma [0, 1] + sigma := tRaw / numTrainTimesteps + + // Apply time shift if enabled + if s.Config.UseDynamicShifting && mu > 0 { + sigma = s.applyShift(mu, sigma) + } + + s.Sigmas[i] = sigma + } + + // Append terminal sigma = 0 (the final clean image) + s.Sigmas[numSteps] = 0 +} + +// calculateShift computes dynamic shift based on image sequence length +// Uses the sqrt-based formula from diffusers: +// m = (image_seq_len / base_seq_len) ** 0.5 +// mu = m * max_shift + base_shift +func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 { + cfg := s.Config + + if !cfg.UseDynamicShifting { + return 0 + } + + // Sqrt-based 
shift calculation (matches diffusers pipeline_glm_image.py) + m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen))) + mu := m*cfg.MaxShift + cfg.BaseShift + return mu +} + +// applyShift applies time shift transformation +// mu: the computed shift value +// t: sigma value in [0, 1] +func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 { + if t <= 0 { + return 0 + } + if t >= 1 { + return 1 + } + + // sigma=1.0 for both shift types + sigma := float32(1.0) + + if s.Config.TimeShiftType == "linear" { + // Linear: mu / (mu + (1/t - 1)^sigma) + return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma)))) + } + + // Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma) + expMu := float32(math.Exp(float64(mu))) + return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma)))) +} + +// Step performs one denoising step +func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array { + sigma := s.Sigmas[stepIdx] + sigmaNext := s.Sigmas[stepIdx+1] + + // Euler step: x_{t-dt} = x_t + dt * v_t + dt := sigmaNext - sigma // Negative (going from noise to clean) + + scaledOutput := mlx.MulScalar(modelOutput, dt) + return mlx.Add(sample, scaledOutput) +} + +// InitNoise creates initial noise +func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array { + return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16) +} + +// AddNoise adds noise to clean samples for a given timestep (for img2img) +func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array { + // In flow matching: x_t = (1-sigma) * x_0 + sigma * noise + // Use sigmas (shifted) for the interpolation + sigma := s.Sigmas[timestepIdx] + oneMinusSigma := 1.0 - sigma + + scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma) + scaledNoise := mlx.MulScalar(noise, sigma) + + return mlx.Add(scaledClean, scaledNoise) +} + +// GetTimesteps returns all timesteps 
+func (s *FlowMatchScheduler) GetTimesteps() []float32 { + return s.Timesteps +} diff --git a/x/imagegen/models/glm_image/text_encoder.go b/x/imagegen/models/glm_image/text_encoder.go new file mode 100644 index 000000000..76a5468fa --- /dev/null +++ b/x/imagegen/models/glm_image/text_encoder.go @@ -0,0 +1,497 @@ +//go:build mlx + +package glm_image + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + "regexp" + + "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// T5Config holds T5 encoder configuration +type T5Config struct { + DModel int32 `json:"d_model"` // 1472 + DFF int32 `json:"d_ff"` // 3584 + DKV int32 `json:"d_kv"` // 64 + NumHeads int32 `json:"num_heads"` // 6 + NumLayers int32 `json:"num_layers"` // 12 + VocabSize int32 `json:"vocab_size"` // 384 (byte-level) + LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6 + IsGatedAct bool `json:"is_gated_act"` // true (gated-gelu) + + // Relative position bias + RelativeAttentionNumBuckets int32 `json:"relative_attention_num_buckets"` // 32 + RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128 +} + +// T5TextEncoder is the T5 encoder for text conditioning +type T5TextEncoder struct { + Config *T5Config + + // Embedding (shared for ByT5) + SharedEmbed *nn.Embedding `weight:"shared"` + + // Encoder layers + Layers []*T5Block `weight:"encoder.block"` + + // Final layer norm + FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"` + + // Relative position bias (from first layer, shared across all) + RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"` +} + +// T5Block is a single T5 encoder block +type T5Block struct { + // Self attention + Layer0 *T5LayerSelfAttention `weight:"layer.0"` + // FFN + Layer1 *T5LayerFF `weight:"layer.1"` +} + +// T5LayerSelfAttention 
is T5's self-attention layer +type T5LayerSelfAttention struct { + SelfAttention *T5Attention `weight:"SelfAttention"` + LayerNorm *T5LayerNorm `weight:"layer_norm"` +} + +// T5Attention implements T5's relative attention +type T5Attention struct { + Q *mlx.Array `weight:"q.weight"` // No bias in T5 + K *mlx.Array `weight:"k.weight"` + V *mlx.Array `weight:"v.weight"` + O *mlx.Array `weight:"o.weight"` + + NHeads int32 + DKV int32 + Scale float32 +} + +// T5LayerFF is T5's feedforward layer with gated-gelu +type T5LayerFF struct { + DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"` + LayerNorm *T5LayerNorm `weight:"layer_norm"` +} + +// T5DenseGatedGelu is T5's gated-gelu FFN +type T5DenseGatedGelu struct { + Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection + Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection + Wo *mlx.Array `weight:"wo.weight"` // down projection +} + +// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction) +type T5LayerNorm struct { + Weight *mlx.Array `weight:"weight"` + Eps float32 +} + +// Load loads the T5 text encoder from manifest +func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error { + fmt.Print(" Loading T5 text encoder... 
") + + // Load config + var cfg T5Config + if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil { + return fmt.Errorf("config: %w", err) + } + m.Config = &cfg + + // Pre-allocate layers + m.Layers = make([]*T5Block, cfg.NumLayers) + + // Load weights + weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder") + if err != nil { + return fmt.Errorf("weights: %w", err) + } + if err := weights.Load(0); err != nil { + return fmt.Errorf("load weights: %w", err) + } + defer weights.ReleaseAll() + + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return fmt.Errorf("load module: %w", err) + } + + m.initComputedFields() + fmt.Println("✓") + return nil +} + +// LoadFromPath loads the T5 text encoder from a directory path +func (m *T5TextEncoder) LoadFromPath(path string) error { + fmt.Print(" Loading T5 text encoder... ") + + // Load config + var cfg T5Config + configPath := filepath.Join(path, "config.json") + data, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("read config: %w", err) + } + if err := json.Unmarshal(data, &cfg); err != nil { + return fmt.Errorf("parse config: %w", err) + } + m.Config = &cfg + + // Pre-allocate layers + m.Layers = make([]*T5Block, cfg.NumLayers) + + // Load weights from safetensors files + weights, err := safetensors.LoadModelWeights(path) + if err != nil { + return fmt.Errorf("weights: %w", err) + } + if err := weights.Load(0); err != nil { + return fmt.Errorf("load weights: %w", err) + } + defer weights.ReleaseAll() + + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return fmt.Errorf("load module: %w", err) + } + + m.initComputedFields() + fmt.Println("✓") + return nil +} + +func (m *T5TextEncoder) initComputedFields() { + cfg := m.Config + m.FinalNorm.Eps = cfg.LayerNormEps + for _, block := range m.Layers { + attn := block.Layer0.SelfAttention + attn.NHeads = cfg.NumHeads + attn.DKV = cfg.DKV + attn.Scale = float32(1.0 / 
// Glyph-extraction patterns, compiled once at package init. The originals were
// recompiled on every call; MustCompile at package scope also surfaces pattern
// errors immediately. Order matters: results are emitted single-quote matches
// first, then curly-quote, ASCII double-quote, and Japanese-bracket matches,
// matching the original behavior.
var glyphQuotePatterns = []*regexp.Regexp{
	regexp.MustCompile(`'([^']*)'`),   // single quotes: 'text'
	regexp.MustCompile(`“([^“”]*)”`),  // Unicode curly double quotes: “text”
	regexp.MustCompile(`"([^"]*)"`),   // ASCII double quotes: "text"
	regexp.MustCompile(`「([^「」]*)」`), // Japanese corner brackets: 「text」
}

// extractGlyphTexts extracts quoted text (glyphs) from the prompt.
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py.
// Glyph texts are used for text rendering guidance in the generated image.
// Returns nil when the prompt contains no quoted spans.
func extractGlyphTexts(prompt string) []string {
	var glyphTexts []string
	for _, re := range glyphQuotePatterns {
		for _, match := range re.FindAllStringSubmatch(prompt, -1) {
			if len(match) > 1 {
				glyphTexts = append(glyphTexts, match[1])
			}
		}
	}
	return glyphTexts
}
via the glyph projector +// +// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the +// full prompt. Glyph texts are used for text rendering guidance in the generated image. +// Multiple glyph texts are encoded and concatenated to form the conditioning signal. +// This matches diffusers' _get_glyph_embeds() behavior. +func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array { + // Extract glyph texts from prompt (text in quotes) + glyphTexts := extractGlyphTexts(prompt) + + // If no glyph texts found, encode empty string (matches diffusers: [""] fallback) + if len(glyphTexts) == 0 { + glyphTexts = []string{""} + } + + // Encode each glyph text and collect token sequences + // Matching diffusers' _get_glyph_embeds() which batches all glyph texts + var allTokenSeqs [][]int32 + + for _, glyphText := range glyphTexts { + // ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258) + tokens := tok.Encode(glyphText) + + // Add EOS token (1) at the end to match HuggingFace tokenizer behavior + tokens = append(tokens, tok.EOSTokenID) + + allTokenSeqs = append(allTokenSeqs, tokens) + } + + // Process each glyph text through the encoder + var allEmbeddings []*mlx.Array + for _, tokens := range allTokenSeqs { + tokenLen := len(tokens) + if tokenLen == 0 { + continue + } + + // Create token array [1, L] + tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)}) + + // Forward through encoder + output := m.Forward(tokensArr) + mlx.Eval(output) + + allEmbeddings = append(allEmbeddings, output) + } + + // Concatenate all glyph embeddings along sequence dimension + var output *mlx.Array + if len(allEmbeddings) == 0 { + // Fallback: return single zero embedding + output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16) + } else if len(allEmbeddings) == 1 { + output = allEmbeddings[0] + } else { + output = mlx.Concatenate(allEmbeddings, 1) + } + mlx.Eval(output) + + return output +} + +// 
// relativePosistionBucket maps a relative position (memory_pos - context_pos)
// onto one of numBuckets buckets, mirroring HuggingFace T5's
// _relative_position_bucket: the near half of the (per-direction) buckets is
// one-bucket-per-offset exact, the far half is log-spaced out to maxDistance,
// beyond which everything clamps into the last bucket.
//
// Bug fixed vs. the previous version: the clamp must apply only to the
// log-spaced component BEFORE the bidirectional direction offset is added.
// The old code clamped the combined value against numBuckets-1, which in
// bidirectional mode collapsed all far positive-direction positions back
// into the wrong (negative-direction) half. The call site currently passes
// bidirectional=false, so this was latent, but it now matches HF exactly.
//
// (Name keeps its historical misspelling so existing callers are unaffected.)
func relativePosistionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
	bucket := int32(0)
	distance := -relativePosition

	if bidirectional {
		// Forward and backward offsets get separate halves of the buckets.
		numBuckets /= 2
		if distance < 0 {
			bucket = numBuckets
			distance = -distance
		}
	} else if distance < 0 {
		// Causal: "future" positions all collapse to distance 0.
		distance = 0
	}

	// Half the buckets are for exact small offsets.
	maxExact := numBuckets / 2
	if distance < maxExact {
		return bucket + distance
	}

	// The other half covers [maxExact, maxDistance) logarithmically.
	scaled := math.Log(float64(distance)/float64(maxExact)) /
		math.Log(float64(maxDistance)/float64(maxExact))
	logBucket := maxExact + int32(scaled*float64(numBuckets-maxExact))
	if logBucket > numBuckets-1 {
		logBucket = numBuckets - 1
	}
	return bucket + logBucket
}
mlx.Matmul(attnWeights, v) + + // Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv] + out = mlx.Transpose(out, 0, 2, 1, 3) + // Reshape to [B, L, D] + out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV) + + // Output projection + out = mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0)) + + _ = D // Silence unused warning + return out +} + +// Forward for T5LayerFF +func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array { + // Pre-norm + normed := l.LayerNorm.Forward(x) + + // FFN + ffOut := l.DenseReluDense.Forward(normed) + + // Residual + return mlx.Add(x, ffOut) +} + +// geluNew implements the GELU activation with tanh approximation (gelu_new) +// This matches HuggingFace transformers' gelu_new/OpenAI GPT implementation +// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) +func geluNew(x *mlx.Array) *mlx.Array { + sqrt2OverPi := float32(0.7978845608) // sqrt(2/π) + coeff := float32(0.044715) + + x3 := mlx.Mul(mlx.Mul(x, x), x) + inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi) + return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0)) +} + +// Forward for T5DenseGatedGelu (gated-gelu activation) +func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array { + // Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new) + gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0)) + gate = geluNew(gate) + + // Up projection + up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0)) + + // Gated output + h := mlx.Mul(gate, up) + + // Down projection + return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0)) +} + +// Forward for T5LayerNorm (RMSNorm variant) +func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array { + // T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight + variance := mlx.Mean(mlx.Square(x), -1, true) + x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps))) + return mlx.Mul(x, ln.Weight) +} diff --git a/x/imagegen/models/glm_image/transformer.go b/x/imagegen/models/glm_image/transformer.go 
new file mode 100644 index 000000000..0d7addfa3 --- /dev/null +++ b/x/imagegen/models/glm_image/transformer.go @@ -0,0 +1,1255 @@ +//go:build mlx + +package glm_image + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +var debugOnce = true + +// DiffusionLayerKVCache holds KV cache for a single diffusion layer +type DiffusionLayerKVCache struct { + Keys *mlx.Array + Values *mlx.Array + Mode string // "write", "read", "skip", or "" +} + +// Store adds K,V to the cache +func (c *DiffusionLayerKVCache) Store(k, v *mlx.Array) { + if c.Keys == nil { + c.Keys = k + c.Values = v + } else { + oldK, oldV := c.Keys, c.Values + c.Keys = mlx.Concatenate([]*mlx.Array{oldK, k}, 1) + c.Values = mlx.Concatenate([]*mlx.Array{oldV, v}, 1) + oldK.Free() + oldV.Free() + } +} + +// Get returns cached K,V concatenated with new K,V +func (c *DiffusionLayerKVCache) Get(k, v *mlx.Array) (*mlx.Array, *mlx.Array) { + // Expand cache if batch size differs + kCache, vCache := c.Keys, c.Values + if c.Keys.Shape()[0] != k.Shape()[0] { + kCache = mlx.BroadcastTo(c.Keys, []int32{k.Shape()[0], c.Keys.Shape()[1], c.Keys.Shape()[2], c.Keys.Shape()[3]}) + vCache = mlx.BroadcastTo(c.Values, []int32{v.Shape()[0], c.Values.Shape()[1], c.Values.Shape()[2], c.Values.Shape()[3]}) + } + return mlx.Concatenate([]*mlx.Array{kCache, k}, 1), + mlx.Concatenate([]*mlx.Array{vCache, v}, 1) +} + +// Clear releases cached tensors +func (c *DiffusionLayerKVCache) Clear() { + if c.Keys != nil { + c.Keys.Free() + c.Keys = nil + } + if c.Values != nil { + c.Values.Free() + c.Values = nil + } + c.Mode = "" +} + +// DiffusionKVCache holds KV caches for all diffusion layers +type DiffusionKVCache struct { + Layers []*DiffusionLayerKVCache +} + +// NewDiffusionKVCache creates a cache for the given number of layers +func 
NewDiffusionKVCache(numLayers int32) *DiffusionKVCache { + layers := make([]*DiffusionLayerKVCache, numLayers) + for i := range layers { + layers[i] = &DiffusionLayerKVCache{} + } + return &DiffusionKVCache{Layers: layers} +} + +// SetMode sets the cache mode for all layers +func (c *DiffusionKVCache) SetMode(mode string) { + for _, layer := range c.Layers { + layer.Mode = mode + } +} + +// Clear releases all cached tensors +func (c *DiffusionKVCache) Clear() { + for _, layer := range c.Layers { + layer.Clear() + } +} + +// DiffusionConfig holds diffusion transformer configuration +type DiffusionConfig struct { + AttentionHeadDim int32 `json:"attention_head_dim"` // 128 + NumAttentionHeads int32 `json:"num_attention_heads"` // 32 + NumLayers int32 `json:"num_layers"` // 30 + InChannels int32 `json:"in_channels"` // 16 + OutChannels int32 `json:"out_channels"` // 16 + PatchSize int32 `json:"patch_size"` // 2 + TextEmbedDim int32 `json:"text_embed_dim"` // 1472 (T5 output) + TimeEmbedDim int32 `json:"time_embed_dim"` // 512 + ConditionDim int32 `json:"condition_dim"` // 256 + PriorVQCodebookSize int32 `json:"prior_vq_quantizer_codebook_size"` // 16384 + RopeTheta float32 `json:"rope_theta"` // 10000.0 + + // Computed + HiddenDim int32 // num_heads * head_dim = 4096 +} + +// DiffusionTransformer is the 7B diffusion decoder +type DiffusionTransformer struct { + Config *DiffusionConfig + + // Prior token embedding (VQ codebook) + PriorTokenEmbedding *nn.Embedding `weight:"prior_token_embedding"` + + // Projectors + PriorProjector *DiTMLPSiLU `weight:"prior_projector"` + ImageProjector *mlx.Array `weight:"image_projector.proj.weight"` + ImageProjectorBias *mlx.Array `weight:"image_projector.proj.bias"` + GlyphProjector *DiTMLP `weight:"glyph_projector"` + + // Time + condition embedding + TimeProj *mlx.Array `weight:"time_condition_embed.timestep_embedder.linear_1.weight"` + TimeProjBias *mlx.Array `weight:"time_condition_embed.timestep_embedder.linear_1.bias"` + 
TimeProj2 *mlx.Array `weight:"time_condition_embed.timestep_embedder.linear_2.weight"` + TimeProjBias2 *mlx.Array `weight:"time_condition_embed.timestep_embedder.linear_2.bias"` + ConditionProj *mlx.Array `weight:"time_condition_embed.condition_embedder.linear_1.weight"` + ConditionProjBias *mlx.Array `weight:"time_condition_embed.condition_embedder.linear_1.bias"` + ConditionProj2 *mlx.Array `weight:"time_condition_embed.condition_embedder.linear_2.weight"` + ConditionProjBias2 *mlx.Array `weight:"time_condition_embed.condition_embedder.linear_2.bias"` + + // Transformer blocks (single-stream) + Blocks []*DiTBlock `weight:"transformer_blocks"` + + // Output + NormOutLinear *mlx.Array `weight:"norm_out.linear.weight"` + NormOutLinearBias *mlx.Array `weight:"norm_out.linear.bias"` + ProjOut *mlx.Array `weight:"proj_out.weight"` + ProjOutBias *mlx.Array `weight:"proj_out.bias"` +} + +// DiTMLP is a simple MLP with GELU activation (used for glyph_projector) +type DiTMLP struct { + Linear1 *mlx.Array `weight:"net.0.proj.weight"` + Bias1 *mlx.Array `weight:"net.0.proj.bias"` + Linear2 *mlx.Array `weight:"net.2.weight"` + Bias2 *mlx.Array `weight:"net.2.bias"` +} + +// DiTMLPSiLU is an MLP with SiLU activation (used for prior_projector) +type DiTMLPSiLU struct { + Linear1 *mlx.Array `weight:"net.0.proj.weight"` + Bias1 *mlx.Array `weight:"net.0.proj.bias"` + Linear2 *mlx.Array `weight:"net.2.weight"` + Bias2 *mlx.Array `weight:"net.2.bias"` +} + +// DiTBlock is a single-stream transformer block +type DiTBlock struct { + // Single attention (no cross-attention) + Attn1 *DiTAttention `weight:"attn1"` + + // FFN + FF *DiTFeedForward `weight:"ff"` + + // AdaLN modulation + Norm1Linear *mlx.Array `weight:"norm1.linear.weight"` + Norm1LinearBias *mlx.Array `weight:"norm1.linear.bias"` +} + +// DiTAttention implements self-attention for DiT +type DiTAttention struct { + ToQ *mlx.Array `weight:"to_q.weight"` + ToK *mlx.Array `weight:"to_k.weight"` + ToV *mlx.Array 
`weight:"to_v.weight"` + ToOut *mlx.Array `weight:"to_out.0.weight"` + ToOutBias *mlx.Array `weight:"to_out.0.bias"` + + // Have biases + QBias *mlx.Array `weight:"to_q.bias"` + KBias *mlx.Array `weight:"to_k.bias"` + VBias *mlx.Array `weight:"to_v.bias"` + + NHeads int32 + HeadDim int32 + Scale float32 +} + +// DiTFeedForward is the FFN in DiT blocks +type DiTFeedForward struct { + Linear1 *mlx.Array `weight:"net.0.proj.weight"` + Bias1 *mlx.Array `weight:"net.0.proj.bias"` + Linear2 *mlx.Array `weight:"net.2.weight"` + Bias2 *mlx.Array `weight:"net.2.bias"` +} + +// Load loads the diffusion transformer +func (m *DiffusionTransformer) Load(manifest *imagegen.ModelManifest) error { + fmt.Print(" Loading diffusion transformer... ") + + // Load config + var cfg DiffusionConfig + if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil { + return fmt.Errorf("config: %w", err) + } + cfg.HiddenDim = cfg.NumAttentionHeads * cfg.AttentionHeadDim + if cfg.RopeTheta == 0 { + cfg.RopeTheta = 10000.0 // Default value matching diffusers + } + m.Config = &cfg + + // Pre-allocate blocks + m.Blocks = make([]*DiTBlock, cfg.NumLayers) + + // Load weights + weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer") + if err != nil { + return fmt.Errorf("weights: %w", err) + } + if err := weights.Load(mlx.DtypeBFloat16); err != nil { + return fmt.Errorf("load weights: %w", err) + } + defer weights.ReleaseAll() + + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return fmt.Errorf("load module: %w", err) + } + + m.initComputedFields() + fmt.Printf("✓ [%d layers]\n", cfg.NumLayers) + return nil +} + +// LoadFromPath loads the diffusion transformer from a directory path +func (m *DiffusionTransformer) LoadFromPath(path string) error { + fmt.Print(" Loading diffusion transformer... 
") + + // Load config + var cfg DiffusionConfig + configPath := filepath.Join(path, "config.json") + data, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("read config: %w", err) + } + if err := json.Unmarshal(data, &cfg); err != nil { + return fmt.Errorf("parse config: %w", err) + } + cfg.HiddenDim = cfg.NumAttentionHeads * cfg.AttentionHeadDim + if cfg.RopeTheta == 0 { + cfg.RopeTheta = 10000.0 // Default value matching diffusers + } + m.Config = &cfg + + // Pre-allocate blocks + m.Blocks = make([]*DiTBlock, cfg.NumLayers) + + // Load weights + weights, err := safetensors.LoadModelWeights(path) + if err != nil { + return fmt.Errorf("weights: %w", err) + } + if err := weights.Load(mlx.DtypeBFloat16); err != nil { + return fmt.Errorf("load weights: %w", err) + } + defer weights.ReleaseAll() + + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return fmt.Errorf("load module: %w", err) + } + + m.initComputedFields() + fmt.Printf("✓ [%d layers]\n", cfg.NumLayers) + return nil +} + +func (m *DiffusionTransformer) initComputedFields() { + cfg := m.Config + for _, block := range m.Blocks { + block.Attn1.NHeads = cfg.NumAttentionHeads + block.Attn1.HeadDim = cfg.AttentionHeadDim + block.Attn1.Scale = float32(1.0 / math.Sqrt(float64(cfg.AttentionHeadDim))) + } +} + +// EmbedPriorTokens converts visual token IDs to embeddings +func (m *DiffusionTransformer) EmbedPriorTokens(tokens *mlx.Array) *mlx.Array { + // Visual tokens from AR are already relative indices (0 to VisionVocabSize-1) + // stored in VisionLanguageEncoder.Generate() as: nextToken - ImageStartTokenID + // VQ codebook has PriorVQCodebookSize entries (16384), indices 0 to 16383 + // Note: VisionVocabSize (16512) > PriorVQCodebookSize (16384), so we clamp + tokensInt := mlx.AsType(tokens, mlx.DtypeInt32) + + // Clamp to valid range [0, codebook_size-1] to avoid out-of-bounds + // (VisionVocabSize may be larger than PriorVQCodebookSize) + codebookIndices := 
mlx.ClipScalar(tokensInt, 0, float32(m.Config.PriorVQCodebookSize-1), true, true) + codebookIndices = mlx.AsType(codebookIndices, mlx.DtypeInt32) + + // Lookup in VQ codebook + embedded := m.PriorTokenEmbedding.Forward(codebookIndices) + // Project to hidden dim via MLP (uses SiLU activation) + return m.PriorProjector.Forward(embedded) +} + +// ProjectTextEmbeddings projects T5 embeddings for conditioning +func (m *DiffusionTransformer) ProjectTextEmbeddings(textEmbed *mlx.Array) *mlx.Array { + return m.GlyphProjector.Forward(textEmbed) +} + +// Forward runs the diffusion transformer +// imgPatches: [B, L_img, C*p*p] patchified latents +// priorEmbed: [B, L_img, hidden_dim] visual token embeddings (same length as image patches!) +// textCond: [B, L_text, hidden_dim] text condition embeddings +// timestep: [B] timestep values (0-indexed, 0 to num_train_timesteps-1) +// targetSize: [B, 2] target height and width +// cropCoords: [B, 2] crop coordinates (top, left) +// pH, pW: patch grid dimensions +func (m *DiffusionTransformer) Forward( + imgPatches, priorEmbed, textCond *mlx.Array, + timestep *mlx.Array, + targetSize, cropCoords *mlx.Array, + pH, pW int32, +) *mlx.Array { + return m.ForwardWithPriorDrop(imgPatches, priorEmbed, textCond, timestep, targetSize, cropCoords, pH, pW, false) +} + +// ForwardWithPriorDrop runs the diffusion transformer with optional prior token dropping for CFG. 
+// priorTokenDrop: when true, zeros out prior embeddings (for unconditional CFG pass) +func (m *DiffusionTransformer) ForwardWithPriorDrop( + imgPatches, priorEmbed, textCond *mlx.Array, + timestep *mlx.Array, + targetSize, cropCoords *mlx.Array, + pH, pW int32, + priorTokenDrop bool, +) *mlx.Array { + cfg := m.Config + + // Project image patches to hidden dim + imgH := mlx.Matmul(imgPatches, mlx.Transpose(m.ImageProjector, 1, 0)) + if m.ImageProjectorBias != nil { + imgH = mlx.Add(imgH, m.ImageProjectorBias) + } + + // Add prior embeddings to image patches (element-wise, NOT concatenation!) + // This is the key architectural difference from a standard DiT + // For CFG unconditional pass, zero out prior embeddings (matches diffusers prior_token_drop) + if priorTokenDrop { + // Don't add prior embeddings - effectively zeros them out + } else { + imgH = mlx.Add(imgH, priorEmbed) + } + + // Compute timestep + condition embedding + temb := m.computeTimestepEmbedding(timestep, targetSize, cropCoords) + + // Sequence for attention: [text, image] + // Text = encoder_hidden_states (glyph embeddings) + // Image = hidden_states (image patches + prior embeddings) + textLen := textCond.Shape()[1] + imgLen := imgH.Shape()[1] + + // Compute 2D RoPE for IMAGE tokens ONLY + // Text tokens do NOT get RoPE + rope := ComputeRoPE2D(pH, pW, cfg.AttentionHeadDim, cfg.RopeTheta) + + // Forward through transformer blocks + // Note: textCond is encoder_hidden_states, imgH is hidden_states + for _, block := range m.Blocks { + imgH, textCond = block.ForwardMMDiT(imgH, textCond, temb, cfg.HiddenDim, rope.Cos, rope.Sin) + } + + // Final norm with modulation (only on image hidden states) + imgOut := m.applyOutputNorm(imgH, temb) + + // Project to output channels + output := mlx.Matmul(imgOut, mlx.Transpose(m.ProjOut, 1, 0)) + if m.ProjOutBias != nil { + output = mlx.Add(output, m.ProjOutBias) + } + + _ = textLen + _ = imgLen + + return output +} + +// computeTimestepEmbedding computes the 
timestep + condition embedding +// targetSize: [B, 2] - target height and width +// cropCoords: [B, 2] - crop top and left coordinates +func (m *DiffusionTransformer) computeTimestepEmbedding(timestep, targetSize, cropCoords *mlx.Array) *mlx.Array { + cfg := m.Config + + // Sinusoidal timestep embedding (flip_sin_to_cos=True, downscale_freq_shift=0) + halfDim := cfg.TimeEmbedDim / 2 + freqs := make([]float32, halfDim) + for i := int32(0); i < halfDim; i++ { + freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(halfDim))) + } + freqsArr := mlx.NewArray(freqs, []int32{halfDim}) + + // timestep: [B] -> [B, 1] * [1, halfDim] -> [B, halfDim] + t := mlx.Reshape(timestep, -1, 1) + freqsArr = mlx.Reshape(freqsArr, 1, halfDim) + args := mlx.Mul(t, freqsArr) + + // flip_sin_to_cos: concatenate cos first, then sin + cosEmb := mlx.Cos(args) + sinEmb := mlx.Sin(args) + temb := mlx.Concatenate([]*mlx.Array{cosEmb, sinEmb}, -1) // [B, TimeEmbedDim] + + // Project through TimestepEmbedding MLP: linear1 -> SiLU -> linear2 + temb = mlx.Matmul(temb, mlx.Transpose(m.TimeProj, 1, 0)) + if m.TimeProjBias != nil { + temb = mlx.Add(temb, m.TimeProjBias) + } + temb = mlx.SiLU(temb) + temb = mlx.Matmul(temb, mlx.Transpose(m.TimeProj2, 1, 0)) + if m.TimeProjBias2 != nil { + temb = mlx.Add(temb, m.TimeProjBias2) + } + + // Compute condition embedding from crop_coords and target_size + // Each is [B, 2] -> sinusoidal embed each value -> [B, 2*condition_dim] + condEmb := m.computeConditionEmbedding(cropCoords, targetSize) + + // Add condition embedding to timestep embedding + temb = mlx.Add(temb, condEmb) + + // Apply final SiLU + temb = mlx.SiLU(temb) + + return temb +} + +// computeConditionEmbedding computes sinusoidal embeddings for condition values +func (m *DiffusionTransformer) computeConditionEmbedding(cropCoords, targetSize *mlx.Array) *mlx.Array { + cfg := m.Config + + // Sinusoidal embedding for each condition value + halfDim := cfg.ConditionDim / 2 + freqs := 
make([]float32, halfDim) + for i := int32(0); i < halfDim; i++ { + freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(halfDim))) + } + freqsArr := mlx.NewArray(freqs, []int32{halfDim}) + + // Flatten crop_coords: [B, 2] -> [B*2] + cropFlat := mlx.Reshape(cropCoords, -1) + cropEmb := sinusoidalEmbed(cropFlat, freqsArr, halfDim) + // Reshape back: [B*2, condDim] -> [B, 2*condDim] + B := cropCoords.Shape()[0] + cropEmb = mlx.Reshape(cropEmb, B, 2*cfg.ConditionDim) + + // Same for target_size + targetFlat := mlx.Reshape(targetSize, -1) + targetEmb := sinusoidalEmbed(targetFlat, freqsArr, halfDim) + targetEmb = mlx.Reshape(targetEmb, B, 2*cfg.ConditionDim) + + // Concatenate: [B, 4*condDim] = pooled_projection_dim + condProj := mlx.Concatenate([]*mlx.Array{cropEmb, targetEmb}, -1) + + // Project through condition embedder MLP: linear1 -> SiLU -> linear2 + condEmb := mlx.Matmul(condProj, mlx.Transpose(m.ConditionProj, 1, 0)) + if m.ConditionProjBias != nil { + condEmb = mlx.Add(condEmb, m.ConditionProjBias) + } + condEmb = mlx.SiLU(condEmb) + condEmb = mlx.Matmul(condEmb, mlx.Transpose(m.ConditionProj2, 1, 0)) + if m.ConditionProjBias2 != nil { + condEmb = mlx.Add(condEmb, m.ConditionProjBias2) + } + + return condEmb +} + +// sinusoidalEmbed computes sinusoidal embeddings for a 1D array of values +func sinusoidalEmbed(x *mlx.Array, freqs *mlx.Array, halfDim int32) *mlx.Array { + // x: [N] -> [N, 1] + x = mlx.Reshape(x, -1, 1) + // freqs: [halfDim] -> [1, halfDim] + freqs = mlx.Reshape(freqs, 1, halfDim) + // args: [N, halfDim] + args := mlx.Mul(x, freqs) + + // flip_sin_to_cos: cos first, then sin + cosEmb := mlx.Cos(args) + sinEmb := mlx.Sin(args) + return mlx.Concatenate([]*mlx.Array{cosEmb, sinEmb}, -1) // [N, 2*halfDim] +} + +// applyOutputNorm applies the final norm with AdaLN modulation +func (m *DiffusionTransformer) applyOutputNorm(x *mlx.Array, temb *mlx.Array) *mlx.Array { + // Compute modulation parameters from temb + modParams := 
mlx.Matmul(temb, mlx.Transpose(m.NormOutLinear, 1, 0)) + if m.NormOutLinearBias != nil { + modParams = mlx.Add(modParams, m.NormOutLinearBias) + } + + shape := x.Shape() + B := shape[0] + L := shape[1] + D := shape[2] + + // Split into scale and shift (each is D-dimensional) + // IMPORTANT: diffusers does "scale, shift = chunk(2)" so scale comes FIRST + halfDim := D + modParams = mlx.Reshape(modParams, B, 1, -1) + // Assuming modParams has 2*D dimensions for scale and shift + modDim := modParams.Shape()[2] + if modDim >= 2*halfDim { + scale := mlx.Slice(modParams, []int32{0, 0, 0}, []int32{B, 1, halfDim}) + shift := mlx.Slice(modParams, []int32{0, 0, halfDim}, []int32{B, 1, 2 * halfDim}) + + // Apply LayerNorm then modulate + // LN(x) * (1 + scale) + shift + xNorm := layerNorm(x) + xNorm = mlx.Mul(xNorm, mlx.AddScalar(scale, 1.0)) + xNorm = mlx.Add(xNorm, shift) + return xNorm + } + + // Fallback: just apply layer norm + _ = L + return layerNorm(x) +} + +// layerNorm applies layer normalization +// Uses eps=1e-5 to match diffusers GlmImageAdaLayerNormZero +func layerNorm(x *mlx.Array) *mlx.Array { + eps := float32(1e-5) + mean := mlx.Mean(x, -1, true) + x = mlx.Sub(x, mean) + variance := mlx.Mean(mlx.Square(x), -1, true) + return mlx.Div(x, mlx.Sqrt(mlx.AddScalar(variance, eps))) +} + +// ForwardMMDiT implements the MMDiT-style transformer block for GLM-Image +// hiddenStates: image tokens [B, L_img, D] +// encoderHiddenStates: text tokens [B, L_text, D] +// RoPE is applied only to image tokens +// Returns updated (hiddenStates, encoderHiddenStates) +func (b *DiTBlock) ForwardMMDiT( + hiddenStates, encoderHiddenStates *mlx.Array, + temb *mlx.Array, + hiddenDim int32, + cos, sin *mlx.Array, +) (*mlx.Array, *mlx.Array) { + shape := hiddenStates.Shape() + B := shape[0] + imgSeqLen := shape[1] + textSeqLen := encoderHiddenStates.Shape()[1] + + // === 1. 
Timestep conditioning (AdaLN) === + // norm1 produces 12 modulation parameters + modParams := mlx.Matmul(temb, mlx.Transpose(b.Norm1Linear, 1, 0)) + if b.Norm1LinearBias != nil { + modParams = mlx.Add(modParams, b.Norm1LinearBias) + } + + // Extract 12 modulation parameters (NO tanh on gates, per diffusers reference) + // Order: shift_msa, c_shift_msa, scale_msa, c_scale_msa, gate_msa, c_gate_msa, + // shift_mlp, c_shift_mlp, scale_mlp, c_scale_mlp, gate_mlp, c_gate_mlp + modParams = mlx.Reshape(modParams, B, -1) + + shiftMsa := mlx.Reshape(mlx.Slice(modParams, []int32{0, 0}, []int32{B, hiddenDim}), B, 1, hiddenDim) + cShiftMsa := mlx.Reshape(mlx.Slice(modParams, []int32{0, hiddenDim}, []int32{B, 2 * hiddenDim}), B, 1, hiddenDim) + scaleMsa := mlx.Reshape(mlx.Slice(modParams, []int32{0, 2 * hiddenDim}, []int32{B, 3 * hiddenDim}), B, 1, hiddenDim) + cScaleMsa := mlx.Reshape(mlx.Slice(modParams, []int32{0, 3 * hiddenDim}, []int32{B, 4 * hiddenDim}), B, 1, hiddenDim) + gateMsa := mlx.Reshape(mlx.Slice(modParams, []int32{0, 4 * hiddenDim}, []int32{B, 5 * hiddenDim}), B, 1, hiddenDim) + cGateMsa := mlx.Reshape(mlx.Slice(modParams, []int32{0, 5 * hiddenDim}, []int32{B, 6 * hiddenDim}), B, 1, hiddenDim) + + shiftMlp := mlx.Reshape(mlx.Slice(modParams, []int32{0, 6 * hiddenDim}, []int32{B, 7 * hiddenDim}), B, 1, hiddenDim) + cShiftMlp := mlx.Reshape(mlx.Slice(modParams, []int32{0, 7 * hiddenDim}, []int32{B, 8 * hiddenDim}), B, 1, hiddenDim) + scaleMlp := mlx.Reshape(mlx.Slice(modParams, []int32{0, 8 * hiddenDim}, []int32{B, 9 * hiddenDim}), B, 1, hiddenDim) + cScaleMlp := mlx.Reshape(mlx.Slice(modParams, []int32{0, 9 * hiddenDim}, []int32{B, 10 * hiddenDim}), B, 1, hiddenDim) + gateMlp := mlx.Reshape(mlx.Slice(modParams, []int32{0, 10 * hiddenDim}, []int32{B, 11 * hiddenDim}), B, 1, hiddenDim) + cGateMlp := mlx.Reshape(mlx.Slice(modParams, []int32{0, 11 * hiddenDim}, []int32{B, 12 * hiddenDim}), B, 1, hiddenDim) + + // === 2. 
Apply LayerNorm and modulation === + // Image tokens: LN(x) * (1 + scale) + shift + normHiddenStates := layerNorm(hiddenStates) + normHiddenStates = mlx.Mul(normHiddenStates, mlx.AddScalar(scaleMsa, 1.0)) + normHiddenStates = mlx.Add(normHiddenStates, shiftMsa) + + // Text tokens (encoder_hidden_states): LN(x) * (1 + c_scale) + c_shift + normEncoderStates := layerNorm(encoderHiddenStates) + normEncoderStates = mlx.Mul(normEncoderStates, mlx.AddScalar(cScaleMsa, 1.0)) + normEncoderStates = mlx.Add(normEncoderStates, cShiftMsa) + + // === 3. Self-attention (joint over text + image) === + // Concatenate for joint attention: [text, image] + attnHiddenStates, attnEncoderStates := b.Attn1.ForwardMMDiT( + normHiddenStates, normEncoderStates, + cos, sin, + ) + + // Apply gated residual connection + hiddenStates = mlx.Add(hiddenStates, mlx.Mul(attnHiddenStates, gateMsa)) + encoderHiddenStates = mlx.Add(encoderHiddenStates, mlx.Mul(attnEncoderStates, cGateMsa)) + + // === 4. Feedforward === + // Apply norm and modulation + normHiddenStates = layerNorm(hiddenStates) + normHiddenStates = mlx.Mul(normHiddenStates, mlx.AddScalar(scaleMlp, 1.0)) + normHiddenStates = mlx.Add(normHiddenStates, shiftMlp) + + normEncoderStates = layerNorm(encoderHiddenStates) + normEncoderStates = mlx.Mul(normEncoderStates, mlx.AddScalar(cScaleMlp, 1.0)) + normEncoderStates = mlx.Add(normEncoderStates, cShiftMlp) + + // FFN (same network for both) + ffHiddenStates := b.FF.Forward(normHiddenStates) + ffEncoderStates := b.FF.Forward(normEncoderStates) + + // Apply gated residual connection + hiddenStates = mlx.Add(hiddenStates, mlx.Mul(ffHiddenStates, gateMlp)) + encoderHiddenStates = mlx.Add(encoderHiddenStates, mlx.Mul(ffEncoderStates, cGateMlp)) + + _ = imgSeqLen + _ = textSeqLen + + return hiddenStates, encoderHiddenStates +} + +// ForwardMMDiT implements joint attention for MMDiT +// hiddenStates: image tokens [B, L_img, D] - gets RoPE +// encoderHiddenStates: text tokens [B, L_text, D] - no RoPE 
+func (attn *DiTAttention) ForwardMMDiT( + hiddenStates, encoderHiddenStates *mlx.Array, + cos, sin *mlx.Array, +) (*mlx.Array, *mlx.Array) { + imgShape := hiddenStates.Shape() + textShape := encoderHiddenStates.Shape() + B := imgShape[0] + imgSeqLen := imgShape[1] + textSeqLen := textShape[1] + + // Concatenate: [text, image] + combined := mlx.Concatenate([]*mlx.Array{encoderHiddenStates, hiddenStates}, 1) + totalLen := textSeqLen + imgSeqLen + + // Q, K, V projections + q := mlx.Matmul(combined, mlx.Transpose(attn.ToQ, 1, 0)) + if attn.QBias != nil { + q = mlx.Add(q, attn.QBias) + } + k := mlx.Matmul(combined, mlx.Transpose(attn.ToK, 1, 0)) + if attn.KBias != nil { + k = mlx.Add(k, attn.KBias) + } + v := mlx.Matmul(combined, mlx.Transpose(attn.ToV, 1, 0)) + if attn.VBias != nil { + v = mlx.Add(v, attn.VBias) + } + + // Reshape to [B, L, nheads, head_dim] + q = mlx.Reshape(q, B, totalLen, attn.NHeads, attn.HeadDim) + k = mlx.Reshape(k, B, totalLen, attn.NHeads, attn.HeadDim) + v = mlx.Reshape(v, B, totalLen, attn.NHeads, attn.HeadDim) + + // Apply QK normalization if present (attn.norm_q, attn.norm_k) + // GLM-Image uses LayerNorm on Q and K + q = layerNormLastDim(q) + k = layerNormLastDim(k) + + // Apply RoPE to image tokens ONLY (after text tokens) + if cos != nil && sin != nil { + // Extract image Q and K + qImg := mlx.Slice(q, []int32{0, textSeqLen, 0, 0}, []int32{B, totalLen, attn.NHeads, attn.HeadDim}) + kImg := mlx.Slice(k, []int32{0, textSeqLen, 0, 0}, []int32{B, totalLen, attn.NHeads, attn.HeadDim}) + + // Apply RoPE + qImg = applyRoPE2D(qImg, cos, sin) + kImg = applyRoPE2D(kImg, cos, sin) + + // Reconstruct full Q and K + qText := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, textSeqLen, attn.NHeads, attn.HeadDim}) + kText := mlx.Slice(k, []int32{0, 0, 0, 0}, []int32{B, textSeqLen, attn.NHeads, attn.HeadDim}) + + q = mlx.Concatenate([]*mlx.Array{qText, qImg}, 1) + k = mlx.Concatenate([]*mlx.Array{kText, kImg}, 1) + } + + // Transpose to [B, nheads, L, 
head_dim] + q = mlx.Transpose(q, 0, 2, 1, 3) + k = mlx.Transpose(k, 0, 2, 1, 3) + v = mlx.Transpose(v, 0, 2, 1, 3) + + // SDPA (no causal mask for diffusion - all tokens attend to all) + out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false) + + // Transpose back and reshape + out = mlx.Transpose(out, 0, 2, 1, 3) + out = mlx.Reshape(out, B, totalLen, attn.NHeads*attn.HeadDim) + + // Output projection + out = mlx.Matmul(out, mlx.Transpose(attn.ToOut, 1, 0)) + if attn.ToOutBias != nil { + out = mlx.Add(out, attn.ToOutBias) + } + + // Split back into text and image + encoderOut := mlx.Slice(out, []int32{0, 0, 0}, []int32{B, textSeqLen, attn.NHeads * attn.HeadDim}) + hiddenOut := mlx.Slice(out, []int32{0, textSeqLen, 0}, []int32{B, totalLen, attn.NHeads * attn.HeadDim}) + + return hiddenOut, encoderOut +} + +// layerNormLastDim applies layer normalization on the last dimension +func layerNormLastDim(x *mlx.Array) *mlx.Array { + eps := float32(1e-5) + mean := mlx.Mean(x, -1, true) + x = mlx.Sub(x, mean) + variance := mlx.Mean(mlx.Square(x), -1, true) + return mlx.Div(x, mlx.Sqrt(mlx.AddScalar(variance, eps))) +} + +// Forward for DiTBlock (no RoPE, for compatibility) +func (b *DiTBlock) Forward(x *mlx.Array, temb *mlx.Array, hiddenDim int32, contextLen int32) *mlx.Array { + return b.ForwardWithRoPE(x, temb, hiddenDim, nil, nil, contextLen) +} + +// ForwardWithRoPE applies the block with optional RoPE (legacy interface) +// contextLen is the number of context tokens (prior + text) at the start of the sequence +func (b *DiTBlock) ForwardWithRoPE(x *mlx.Array, temb *mlx.Array, hiddenDim int32, cos, sin *mlx.Array, contextLen int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + + // AdaLN modulation: norm1 produces 12 parameters per hidden dim + modParams := mlx.Matmul(temb, mlx.Transpose(b.Norm1Linear, 1, 0)) + if b.Norm1LinearBias != nil { + modParams = mlx.Add(modParams, b.Norm1LinearBias) + } + modParams = mlx.Reshape(modParams, B, 1, -1) 
+ + modDim := modParams.Shape()[2] + + // Debug: print modDim vs expected once + if debugOnce { + fmt.Printf(" [DEBUG] modDim=%d, 12*hiddenDim=%d\n", modDim, 12*hiddenDim) + debugOnce = false + } + + if modDim >= 12*hiddenDim { + // Extract 12 modulation parameters (NO tanh on gates) + shiftMsa := mlx.Slice(modParams, []int32{0, 0, 0}, []int32{B, 1, hiddenDim}) + cShiftMsa := mlx.Slice(modParams, []int32{0, 0, hiddenDim}, []int32{B, 1, 2 * hiddenDim}) + scaleMsa := mlx.Slice(modParams, []int32{0, 0, 2 * hiddenDim}, []int32{B, 1, 3 * hiddenDim}) + cScaleMsa := mlx.Slice(modParams, []int32{0, 0, 3 * hiddenDim}, []int32{B, 1, 4 * hiddenDim}) + gateMsa := mlx.Slice(modParams, []int32{0, 0, 4 * hiddenDim}, []int32{B, 1, 5 * hiddenDim}) + cGateMsa := mlx.Slice(modParams, []int32{0, 0, 5 * hiddenDim}, []int32{B, 1, 6 * hiddenDim}) + + shiftMlp := mlx.Slice(modParams, []int32{0, 0, 6 * hiddenDim}, []int32{B, 1, 7 * hiddenDim}) + cShiftMlp := mlx.Slice(modParams, []int32{0, 0, 7 * hiddenDim}, []int32{B, 1, 8 * hiddenDim}) + scaleMlp := mlx.Slice(modParams, []int32{0, 0, 8 * hiddenDim}, []int32{B, 1, 9 * hiddenDim}) + cScaleMlp := mlx.Slice(modParams, []int32{0, 0, 9 * hiddenDim}, []int32{B, 1, 10 * hiddenDim}) + gateMlp := mlx.Slice(modParams, []int32{0, 0, 10 * hiddenDim}, []int32{B, 1, 11 * hiddenDim}) + cGateMlp := mlx.Slice(modParams, []int32{0, 0, 11 * hiddenDim}, []int32{B, 1, 12 * hiddenDim}) + + // Apply LayerNorm + xNorm := layerNorm(x) + + // Split context (prior + text) and image tokens + imgLen := L - contextLen + + // Apply different modulation to context vs image tokens + // Image tokens: use regular parameters + var xMod *mlx.Array + if contextLen > 0 && imgLen > 0 { + contextNorm := mlx.Slice(xNorm, []int32{0, 0, 0}, []int32{B, contextLen, hiddenDim}) + imgNorm := mlx.Slice(xNorm, []int32{0, contextLen, 0}, []int32{B, L, hiddenDim}) + + // Modulate context: (1 + c_scale) * x + c_shift + contextMod := mlx.Mul(contextNorm, mlx.AddScalar(cScaleMsa, 1.0)) + 
contextMod = mlx.Add(contextMod, cShiftMsa) + + // Modulate image: (1 + scale) * x + shift + imgMod := mlx.Mul(imgNorm, mlx.AddScalar(scaleMsa, 1.0)) + imgMod = mlx.Add(imgMod, shiftMsa) + + xMod = mlx.Concatenate([]*mlx.Array{contextMod, imgMod}, 1) + } else { + // All tokens treated the same (shouldn't happen normally) + xMod = mlx.Mul(xNorm, mlx.AddScalar(scaleMsa, 1.0)) + xMod = mlx.Add(xMod, shiftMsa) + } + + // Self-attention with RoPE + attnOut := b.Attn1.ForwardWithRoPE(xMod, cos, sin) + + // Apply different gates to context vs image + if contextLen > 0 && imgLen > 0 { + contextAttn := mlx.Slice(attnOut, []int32{0, 0, 0}, []int32{B, contextLen, hiddenDim}) + imgAttn := mlx.Slice(attnOut, []int32{0, contextLen, 0}, []int32{B, L, hiddenDim}) + + contextAttn = mlx.Mul(contextAttn, cGateMsa) + imgAttn = mlx.Mul(imgAttn, gateMsa) + + attnOut = mlx.Concatenate([]*mlx.Array{contextAttn, imgAttn}, 1) + } else { + attnOut = mlx.Mul(attnOut, gateMsa) + } + + x = mlx.Add(x, attnOut) + + // FFN with modulation + xNorm = layerNorm(x) + + if contextLen > 0 && imgLen > 0 { + contextNorm := mlx.Slice(xNorm, []int32{0, 0, 0}, []int32{B, contextLen, hiddenDim}) + imgNorm := mlx.Slice(xNorm, []int32{0, contextLen, 0}, []int32{B, L, hiddenDim}) + + // Modulate context + contextMod := mlx.Mul(contextNorm, mlx.AddScalar(cScaleMlp, 1.0)) + contextMod = mlx.Add(contextMod, cShiftMlp) + + // Modulate image + imgMod := mlx.Mul(imgNorm, mlx.AddScalar(scaleMlp, 1.0)) + imgMod = mlx.Add(imgMod, shiftMlp) + + xMod = mlx.Concatenate([]*mlx.Array{contextMod, imgMod}, 1) + } else { + xMod = mlx.Mul(xNorm, mlx.AddScalar(scaleMlp, 1.0)) + xMod = mlx.Add(xMod, shiftMlp) + } + + ffOut := b.FF.Forward(xMod) + + // Apply gates + if contextLen > 0 && imgLen > 0 { + contextFF := mlx.Slice(ffOut, []int32{0, 0, 0}, []int32{B, contextLen, hiddenDim}) + imgFF := mlx.Slice(ffOut, []int32{0, contextLen, 0}, []int32{B, L, hiddenDim}) + + contextFF = mlx.Mul(contextFF, cGateMlp) + imgFF = mlx.Mul(imgFF, 
gateMlp) + + ffOut = mlx.Concatenate([]*mlx.Array{contextFF, imgFF}, 1) + } else { + ffOut = mlx.Mul(ffOut, gateMlp) + } + + x = mlx.Add(x, ffOut) + } else { + // Fallback path without full modulation (shouldn't happen for GLM-Image) + xNorm := layerNorm(x) + attnOut := b.Attn1.ForwardWithRoPE(xNorm, cos, sin) + x = mlx.Add(x, attnOut) + + xNorm = layerNorm(x) + ffOut := b.FF.Forward(xNorm) + x = mlx.Add(x, ffOut) + } + + return x +} + +// Forward for DiTAttention with optional RoPE +func (attn *DiTAttention) Forward(x *mlx.Array) *mlx.Array { + return attn.ForwardWithRoPE(x, nil, nil) +} + +// ForwardWithRoPE applies attention with rotary position embeddings +// RoPE is applied ONLY to image tokens (after contextLen positions) +// cos, sin have shape [1, imgLen, 1, headDim] for image tokens only +func (attn *DiTAttention) ForwardWithRoPE(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + D := shape[2] + + // Q, K, V projections + q := mlx.Matmul(x, mlx.Transpose(attn.ToQ, 1, 0)) + if attn.QBias != nil { + q = mlx.Add(q, attn.QBias) + } + k := mlx.Matmul(x, mlx.Transpose(attn.ToK, 1, 0)) + if attn.KBias != nil { + k = mlx.Add(k, attn.KBias) + } + v := mlx.Matmul(x, mlx.Transpose(attn.ToV, 1, 0)) + if attn.VBias != nil { + v = mlx.Add(v, attn.VBias) + } + + // Reshape to [B, L, nheads, head_dim] + q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim) + k = mlx.Reshape(k, B, L, attn.NHeads, attn.HeadDim) + v = mlx.Reshape(v, B, L, attn.NHeads, attn.HeadDim) + + // Transpose to [B, nheads, L, head_dim] (before RoPE for easier slicing) + q = mlx.Transpose(q, 0, 2, 1, 3) + k = mlx.Transpose(k, 0, 2, 1, 3) + v = mlx.Transpose(v, 0, 2, 1, 3) + + // Apply RoPE to image tokens only (like CogView4) + // cos, sin are for image tokens only [1, imgLen, 1, headDim] + if cos != nil && sin != nil { + imgLen := cos.Shape()[1] + contextLen := L - imgLen + + if contextLen >= 0 && imgLen > 0 { + // Split Q and K into context and image 
parts + qContext := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, attn.NHeads, contextLen, attn.HeadDim}) + qImg := mlx.Slice(q, []int32{0, 0, contextLen, 0}, []int32{B, attn.NHeads, L, attn.HeadDim}) + + kContext := mlx.Slice(k, []int32{0, 0, 0, 0}, []int32{B, attn.NHeads, contextLen, attn.HeadDim}) + kImg := mlx.Slice(k, []int32{0, 0, contextLen, 0}, []int32{B, attn.NHeads, L, attn.HeadDim}) + + // Apply RoPE only to image tokens + // cos, sin: [1, imgLen, 1, headDim] -> need to transpose for [B, nheads, imgLen, headDim] + cosT := mlx.Transpose(cos, 0, 2, 1, 3) // [1, 1, imgLen, headDim] + sinT := mlx.Transpose(sin, 0, 2, 1, 3) + + qImgRope := applyRoPETransposed(qImg, cosT, sinT) + kImgRope := applyRoPETransposed(kImg, cosT, sinT) + + // Reconstruct full Q and K + q = mlx.Concatenate([]*mlx.Array{qContext, qImgRope}, 2) + k = mlx.Concatenate([]*mlx.Array{kContext, kImgRope}, 2) + } + } + + // SDPA (no causal mask for diffusion - all tokens attend to all) + out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false) + + // Transpose back and reshape + out = mlx.Transpose(out, 0, 2, 1, 3) + out = mlx.Reshape(out, B, L, D) + + // Output projection + out = mlx.Matmul(out, mlx.Transpose(attn.ToOut, 1, 0)) + if attn.ToOutBias != nil { + out = mlx.Add(out, attn.ToOutBias) + } + + return out +} + +// applyRoPETransposed applies RoPE when Q/K are in [B, nheads, L, headDim] format +// Uses split-half approach (use_real_unbind_dim=-2) to match diffusers GLM-Image +func applyRoPETransposed(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array { + // x: [B, nheads, L, headDim] + // cos, sin: [1, 1, L, headDim] (first half == second half, duplicated) + shape := x.Shape() + B := shape[0] + nHeads := shape[1] + L := shape[2] + headDim := shape[3] + halfDim := headDim / 2 + + // Split x into first and second half + x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, nHeads, L, halfDim}) + x2 := mlx.Slice(x, []int32{0, 0, 0, halfDim}, []int32{B, nHeads, L, headDim}) + + // Get first 
half of cos/sin (they're duplicated, so first half == second half) + cosHalf := mlx.Slice(cos, []int32{0, 0, 0, 0}, []int32{1, 1, L, halfDim}) + sinHalf := mlx.Slice(sin, []int32{0, 0, 0, 0}, []int32{1, 1, L, halfDim}) + + // Apply rotation: out1 = x1*cos - x2*sin, out2 = x2*cos + x1*sin + out1 := mlx.Sub(mlx.Mul(x1, cosHalf), mlx.Mul(x2, sinHalf)) + out2 := mlx.Add(mlx.Mul(x2, cosHalf), mlx.Mul(x1, sinHalf)) + + // Concatenate back to full dimension + return mlx.Concatenate([]*mlx.Array{out1, out2}, 3) +} + +// geluApproximate implements GELU with tanh approximation (matches diffusers "gelu-approximate") +// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) +func geluApproximate(x *mlx.Array) *mlx.Array { + // Constants + sqrt2OverPi := float32(0.7978845608) // sqrt(2/π) + coeff := float32(0.044715) + + // x³ + x3 := mlx.Mul(mlx.Mul(x, x), x) + + // inner = sqrt(2/π) * (x + 0.044715 * x³) + inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi) + + // 0.5 * x * (1 + tanh(inner)) + return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0)) +} + +// Forward for DiTFeedForward +func (ff *DiTFeedForward) Forward(x *mlx.Array) *mlx.Array { + // GELU approximate MLP (matches diffusers activation_fn="gelu-approximate") + h := mlx.Matmul(x, mlx.Transpose(ff.Linear1, 1, 0)) + if ff.Bias1 != nil { + h = mlx.Add(h, ff.Bias1) + } + h = geluApproximate(h) + + h = mlx.Matmul(h, mlx.Transpose(ff.Linear2, 1, 0)) + if ff.Bias2 != nil { + h = mlx.Add(h, ff.Bias2) + } + + return h +} + +// Forward for DiTMLP (projector MLPs with GELU - used for glyph_projector) +func (m *DiTMLP) Forward(x *mlx.Array) *mlx.Array { + h := mlx.Matmul(x, mlx.Transpose(m.Linear1, 1, 0)) + if m.Bias1 != nil { + h = mlx.Add(h, m.Bias1) + } + h = mlx.GELU(h) + + h = mlx.Matmul(h, mlx.Transpose(m.Linear2, 1, 0)) + if m.Bias2 != nil { + h = mlx.Add(h, m.Bias2) + } + + return h +} + +// Forward for DiTMLPSiLU (projector MLPs with SiLU - used for prior_projector) 
+func (m *DiTMLPSiLU) Forward(x *mlx.Array) *mlx.Array { + h := mlx.Matmul(x, mlx.Transpose(m.Linear1, 1, 0)) + if m.Bias1 != nil { + h = mlx.Add(h, m.Bias1) + } + h = mlx.SiLU(h) // SiLU activation for prior_projector (matches diffusers "linear-silu") + + h = mlx.Matmul(h, mlx.Transpose(m.Linear2, 1, 0)) + if m.Bias2 != nil { + h = mlx.Add(h, m.Bias2) + } + + return h +} + +// RoPE2DCache holds precomputed RoPE values for 2D image positions +type RoPE2DCache struct { + Cos *mlx.Array // [1, L, 1, head_dim] + Sin *mlx.Array // [1, L, 1, head_dim] +} + +// ComputeUnifiedRoPE computes RoPE for the full unified sequence (prior + text + image) +// Prior and text tokens get sequential 1D positions (h=0, w=index) +// Image tokens get 2D grid positions (h, w) from patch grid +func ComputeUnifiedRoPE(priorLen, textLen, pH, pW, headDim int32, theta float32) *RoPE2DCache { + imgLen := pH * pW + totalLen := priorLen + textLen + imgLen + + // Split head_dim between h and w dimensions + dimH := headDim / 2 + dimW := headDim / 2 + + // Compute inverse frequencies + hFreqs := make([]float32, dimH/2) + for i := int32(0); i < dimH/2; i++ { + hFreqs[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(dimH))) + } + + wFreqs := make([]float32, dimW/2) + for i := int32(0); i < dimW/2; i++ { + wFreqs[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(dimW))) + } + + cosVals := make([]float32, totalLen*headDim) + sinVals := make([]float32, totalLen*headDim) + + // Prior tokens: h=0, w=idx (sequential positions on w axis) + for idx := int32(0); idx < priorLen; idx++ { + offset := idx * headDim + h := float32(0) + w := float32(idx) + + for i := int32(0); i < dimH/2; i++ { + angle := h * hFreqs[i] + cosVals[offset+2*i] = float32(math.Cos(float64(angle))) + cosVals[offset+2*i+1] = float32(math.Cos(float64(angle))) + sinVals[offset+2*i] = float32(math.Sin(float64(angle))) + sinVals[offset+2*i+1] = float32(math.Sin(float64(angle))) + } + for i := int32(0); i < 
dimW/2; i++ { + angle := w * wFreqs[i] + idx2 := dimH + 2*i + cosVals[offset+idx2] = float32(math.Cos(float64(angle))) + cosVals[offset+idx2+1] = float32(math.Cos(float64(angle))) + sinVals[offset+idx2] = float32(math.Sin(float64(angle))) + sinVals[offset+idx2+1] = float32(math.Sin(float64(angle))) + } + } + + // Text tokens: h=0, w=priorLen+idx (continue sequential positions) + for idx := int32(0); idx < textLen; idx++ { + offset := (priorLen + idx) * headDim + h := float32(0) + w := float32(priorLen + idx) + + for i := int32(0); i < dimH/2; i++ { + angle := h * hFreqs[i] + cosVals[offset+2*i] = float32(math.Cos(float64(angle))) + cosVals[offset+2*i+1] = float32(math.Cos(float64(angle))) + sinVals[offset+2*i] = float32(math.Sin(float64(angle))) + sinVals[offset+2*i+1] = float32(math.Sin(float64(angle))) + } + for i := int32(0); i < dimW/2; i++ { + angle := w * wFreqs[i] + idx2 := dimH + 2*i + cosVals[offset+idx2] = float32(math.Cos(float64(angle))) + cosVals[offset+idx2+1] = float32(math.Cos(float64(angle))) + sinVals[offset+idx2] = float32(math.Sin(float64(angle))) + sinVals[offset+idx2+1] = float32(math.Sin(float64(angle))) + } + } + + // Image tokens: 2D grid positions (h, w) + for hPos := int32(0); hPos < pH; hPos++ { + for wPos := int32(0); wPos < pW; wPos++ { + patchIdx := hPos*pW + wPos + offset := (priorLen + textLen + patchIdx) * headDim + h := float32(hPos) + w := float32(wPos) + + for i := int32(0); i < dimH/2; i++ { + angle := h * hFreqs[i] + cosVals[offset+2*i] = float32(math.Cos(float64(angle))) + cosVals[offset+2*i+1] = float32(math.Cos(float64(angle))) + sinVals[offset+2*i] = float32(math.Sin(float64(angle))) + sinVals[offset+2*i+1] = float32(math.Sin(float64(angle))) + } + for i := int32(0); i < dimW/2; i++ { + angle := w * wFreqs[i] + idx2 := dimH + 2*i + cosVals[offset+idx2] = float32(math.Cos(float64(angle))) + cosVals[offset+idx2+1] = float32(math.Cos(float64(angle))) + sinVals[offset+idx2] = float32(math.Sin(float64(angle))) + 
sinVals[offset+idx2+1] = float32(math.Sin(float64(angle))) + } + } + } + + cos := mlx.NewArray(cosVals, []int32{1, totalLen, 1, headDim}) + sin := mlx.NewArray(sinVals, []int32{1, totalLen, 1, headDim}) + + cos = mlx.ToBFloat16(cos) + sin = mlx.ToBFloat16(sin) + + return &RoPE2DCache{Cos: cos, Sin: sin} +} + +// ComputeRoPE2D computes 2D rotary position embeddings for image patches +// Matches the diffusers GlmImageRotaryPosEmbed implementation exactly. +// pH, pW: patch grid dimensions (height, width in patches) +// headDim: attention head dimension (40 for GLM-Image) +// theta: RoPE theta (10000.0 for GLM-Image) +func ComputeRoPE2D(pH, pW, headDim int32, theta float32) *RoPE2DCache { + // Split head_dim between h and w dimensions + // For headDim=40: dimH = dimW = 20 + dimH := headDim / 2 + dimW := headDim / 2 + + // Compute inverse frequencies matching diffusers GlmImageRotaryPosEmbed: + // h_inv_freq = 1.0 / (theta ** (arange(0, dim_h, 2)[:dim_h // 2] / dim_h)) + // For dim_h=20: arange(0, 20, 2) = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + // [:10] keeps all 10 values, so numHFreqs = dim_h / 2 = 10 + numHFreqs := dimH / 2 + hFreqs := make([]float32, numHFreqs) + for i := int32(0); i < numHFreqs; i++ { + // exponent = (2*i) / dim_h + hFreqs[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(dimH))) + } + + numWFreqs := dimW / 2 + wFreqs := make([]float32, numWFreqs) + for i := int32(0); i < numWFreqs; i++ { + wFreqs[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(dimW))) + } + + // Build the full frequency tensor + numPatches := pH * pW + halfDim := headDim / 2 // dim_h/2 + dim_w/2 = headDim/2 + cosVals := make([]float32, numPatches*headDim) + sinVals := make([]float32, numPatches*headDim) + + for h := int32(0); h < pH; h++ { + for w := int32(0); w < pW; w++ { + patchIdx := h*pW + w + offset := patchIdx * headDim + + // Compute freqs for this position + // freqs = [freqs_h, freqs_w, freqs_h, freqs_w] (duplicated) + // First half: 
[freqs_h, freqs_w] + // Second half: same as first half + + // Height frequencies (first dim_h/2 values) + for i := int32(0); i < numHFreqs; i++ { + angle := float32(h) * hFreqs[i] + cosVals[offset+i] = float32(math.Cos(float64(angle))) + sinVals[offset+i] = float32(math.Sin(float64(angle))) + } + + // Width frequencies (next dim_w/2 values) + for i := int32(0); i < numWFreqs; i++ { + angle := float32(w) * wFreqs[i] + idx := numHFreqs + i + cosVals[offset+idx] = float32(math.Cos(float64(angle))) + sinVals[offset+idx] = float32(math.Sin(float64(angle))) + } + + // Duplicate for second half (freqs = cat([freqs, freqs], -1)) + for i := int32(0); i < halfDim; i++ { + cosVals[offset+halfDim+i] = cosVals[offset+i] + sinVals[offset+halfDim+i] = sinVals[offset+i] + } + } + } + + cos := mlx.NewArray(cosVals, []int32{1, numPatches, 1, headDim}) + sin := mlx.NewArray(sinVals, []int32{1, numPatches, 1, headDim}) + + cos = mlx.ToBFloat16(cos) + sin = mlx.ToBFloat16(sin) + + return &RoPE2DCache{Cos: cos, Sin: sin} +} + +// applyRoPE2D applies 2D rotary position embeddings to Q or K +// Uses split-half approach (use_real_unbind_dim=-2) to match diffusers GLM-Image +// x: [B, L, nheads, head_dim] +// cos, sin: [1, L, 1, head_dim] (first half == second half, duplicated) +func applyRoPE2D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array { + // Split-half RoPE (use_real_unbind_dim=-2): + // x1 = x[..., :head_dim/2], x2 = x[..., head_dim/2:] + // output = cat([x1*cos - x2*sin, x2*cos + x1*sin], dim=-1) + // Since cos/sin are duplicated (first half == second half), we use half values + + shape := x.Shape() + B := shape[0] + L := shape[1] + nHeads := shape[2] + headDim := shape[3] + halfDim := headDim / 2 + + // Split x into first and second half + x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, nHeads, halfDim}) + x2 := mlx.Slice(x, []int32{0, 0, 0, halfDim}, []int32{B, L, nHeads, headDim}) + + // Get first half of cos/sin (they're duplicated, so first half == second half) + cosHalf 
:= mlx.Slice(cos, []int32{0, 0, 0, 0}, []int32{1, L, 1, halfDim}) + sinHalf := mlx.Slice(sin, []int32{0, 0, 0, 0}, []int32{1, L, 1, halfDim}) + + // Apply rotation: out1 = x1*cos - x2*sin, out2 = x2*cos + x1*sin + out1 := mlx.Sub(mlx.Mul(x1, cosHalf), mlx.Mul(x2, sinHalf)) + out2 := mlx.Add(mlx.Mul(x2, cosHalf), mlx.Mul(x1, sinHalf)) + + // Concatenate back to full dimension + return mlx.Concatenate([]*mlx.Array{out1, out2}, 3) +} diff --git a/x/imagegen/models/glm_image/vae.go b/x/imagegen/models/glm_image/vae.go new file mode 100644 index 000000000..1ab5fb608 --- /dev/null +++ b/x/imagegen/models/glm_image/vae.go @@ -0,0 +1,477 @@ +//go:build mlx + +package glm_image + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// VAEConfig holds VAE decoder configuration +type VAEConfig struct { + InChannels int32 `json:"in_channels"` // 3 + OutChannels int32 `json:"out_channels"` // 3 + LatentChannels int32 `json:"latent_channels"` // 16 + BlockOutChannels []int32 `json:"block_out_channels"` // [128, 512, 1024, 1024] + LayersPerBlock int32 `json:"layers_per_block"` // 3 + NormNumGroups int32 `json:"norm_num_groups"` // 32 + ScalingFactor float32 `json:"scaling_factor"` // 0.18215 + ShiftFactor *float32 `json:"shift_factor"` // null + LatentsMean []float32 `json:"latents_mean"` // [16 values] + LatentsStd []float32 `json:"latents_std"` // [16 values] +} + +// VAEDecoder is the VAE latent decoder +type VAEDecoder struct { + Config *VAEConfig + + // Decoder components + ConvIn *VAEConv2d `weight:"decoder.conv_in"` + MidBlock *VAEMidBlock `weight:"decoder.mid_block"` + UpBlocks []*VAEUpBlock `weight:"decoder.up_blocks"` + ConvNormOut *GroupNorm `weight:"decoder.conv_norm_out"` + ConvOut *VAEConv2d `weight:"decoder.conv_out"` +} + +// VAEConv2d is a 2D convolution layer +type VAEConv2d struct { + Weight *mlx.Array 
`weight:"weight"`
	Bias    *mlx.Array `weight:"bias"`
	Stride  int32
	Padding int32
}

// GroupNorm is group normalization
type GroupNorm struct {
	Weight    *mlx.Array `weight:"weight"`
	Bias      *mlx.Array `weight:"bias"`
	NumGroups int32
	Eps       float32
}

// VAEMidBlock is the middle block of the VAE
type VAEMidBlock struct {
	Resnets []*VAEResnetBlock `weight:"resnets"`
}

// VAEUpBlock is an upsampling block
type VAEUpBlock struct {
	Resnets    []*VAEResnetBlock `weight:"resnets"`
	Upsamplers []*VAEUpsampler   `weight:"upsamplers"`
}

// VAEResnetBlock is a residual block
type VAEResnetBlock struct {
	Norm1        *GroupNorm `weight:"norm1"`
	Conv1        *VAEConv2d `weight:"conv1"`
	Norm2        *GroupNorm `weight:"norm2"`
	Conv2        *VAEConv2d `weight:"conv2"`
	ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatch
}

// VAEUpsampler is an upsampling layer
type VAEUpsampler struct {
	Conv *VAEConv2d `weight:"conv"`
}

// initStructure pre-allocates the decoder module tree from m.Config so that
// safetensors.LoadModule can populate it by weight path. Shared by Load and
// LoadFromPath (previously duplicated in both).
func (m *VAEDecoder) initStructure() {
	cfg := m.Config
	numBlocks := len(cfg.BlockOutChannels)
	m.UpBlocks = make([]*VAEUpBlock, numBlocks)

	// VAE mid_block typically has 2 resnets.
	m.MidBlock = &VAEMidBlock{
		Resnets: make([]*VAEResnetBlock, 2),
	}

	// VAE decoder has layers_per_block+1 resnets per up_block (to match the
	// encoder); all but the last up_block has a single upsampler.
	for i := 0; i < numBlocks; i++ {
		m.UpBlocks[i] = &VAEUpBlock{
			Resnets: make([]*VAEResnetBlock, cfg.LayersPerBlock+1),
		}
		if i < numBlocks-1 {
			m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
		}
	}
}

// loadWeightsInto loads already-opened model weights into the module tree and
// initializes the GroupNorm runtime parameters.
func (m *VAEDecoder) loadWeightsInto(weights *safetensors.ModelWeights) error {
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}

	m.initGroupNorms()
	return nil
}

// Load loads the VAE decoder from an ollama manifest.
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
	fmt.Print(" Loading VAE decoder... ")

	var cfg VAEConfig
	if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = &cfg
	m.initStructure()

	weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := m.loadWeightsInto(weights); err != nil {
		return err
	}

	fmt.Println("✓")
	return nil
}

// LoadFromPath loads the VAE decoder from a directory path.
func (m *VAEDecoder) LoadFromPath(path string) error {
	fmt.Print(" Loading VAE decoder... ")

	var cfg VAEConfig
	configPath := filepath.Join(path, "config.json")
	data, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}
	if err := json.Unmarshal(data, &cfg); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}
	m.Config = &cfg
	m.initStructure()

	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := m.loadWeightsInto(weights); err != nil {
		return err
	}

	fmt.Println("✓")
	return nil
}

// initGroupNorms sets the runtime GroupNorm parameters (group count and eps)
// on every norm in the module tree; these are not stored in the weights.
func (m *VAEDecoder) initGroupNorms() {
	cfg := m.Config
	numGroups := cfg.NormNumGroups
	eps := float32(1e-6) // Must match diffusers VAE (1e-6, not 1e-5)

	if m.ConvNormOut != nil {
		m.ConvNormOut.NumGroups = numGroups
		m.ConvNormOut.Eps = eps
	}

	if m.MidBlock != nil {
		for _, resnet := range m.MidBlock.Resnets {
			if resnet.Norm1 != nil {
				resnet.Norm1.NumGroups = numGroups
				resnet.Norm1.Eps = eps
			}
			if resnet.Norm2 != nil {
				resnet.Norm2.NumGroups = numGroups
				resnet.Norm2.Eps = eps
			}
		}
	}

	for _, upBlock := range m.UpBlocks {
		if upBlock == nil {
			continue
		}
		for _, resnet := range upBlock.Resnets {
			if resnet == nil {
				continue
			}
			if resnet.Norm1 != nil {
				resnet.Norm1.NumGroups = numGroups
				resnet.Norm1.Eps = eps
			}
			if resnet.Norm2 != nil {
				resnet.Norm2.NumGroups = numGroups
				resnet.Norm2.Eps = eps
			}
		}
	}
}

// Decode decodes latents to an image in [0, 1], NCHW layout.
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
	cfg := m.Config

	// Apply latent denormalization if mean/std are provided.
	// This matches diffusers GLM-Image: latents = latents * std + mean
	// Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs).
	if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
		latents = m.denormalizeLatents(latents)
	}

	// Convert from NCHW to NHWC for processing: [B, C, H, W] -> [B, H, W, C]
	x := mlx.Transpose(latents, 0, 2, 3, 1)

	// Initial convolution
	x = m.ConvIn.Forward(x)

	// Mid block
	x = m.MidBlock.Forward(x)

	// Up blocks (forward order - index 0 is at lowest resolution/highest channels)
	for i := 0; i < len(m.UpBlocks); i++ {
		if m.UpBlocks[i] != nil {
			x = m.UpBlocks[i].Forward(x)
		}
	}

	// Final normalization and convolution
	x = m.ConvNormOut.Forward(x)
	x = mlx.SiLU(x)
	x = m.ConvOut.Forward(x)

	// Convert back to NCHW: [B, H, W, C] -> [B, C, H, W]
	x = mlx.Transpose(x, 0, 3, 1, 2)

	// Clamp to [-1, 1] and rescale to [0, 1]
	x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
	x = mlx.AddScalar(x, 1.0)
	x = mlx.DivScalar(x, 2.0)

	return x
}

// denormalizeLatents applies the latent mean/std denormalization.
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
	cfg := m.Config

	// Create mean and std arrays [1, C, 1, 1] for broadcasting over NCHW.
	mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
	std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})

	// Denormalize: latents * std + mean
	latents = mlx.Mul(latents, std)
	latents = mlx.Add(latents, mean)

	return latents
}

// Forward for VAEConv2d.
// x: [B, H, W, C_in] (NHWC). PyTorch stores weight as OIHW [C_out, C_in, kH, kW];
// MLX conv2d expects OHWI [C_out, kH, kW, C_in], so the weight is transposed
// on every call. NOTE(review): this per-call transpose could be cached at load
// time if profiling shows it is hot — confirm mlx does not already fuse it.
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
	stride := c.Stride
	if stride == 0 {
		stride = 1
	}
	padding := c.Padding
	if padding == 0 {
		// Default to same padding for 3x3 kernels
		wShape := c.Weight.Shape()
		if len(wShape) >= 3 && wShape[2] == 3 {
			padding = 1
		}
	}

	// Transpose weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
	weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)

	out := mlx.Conv2d(x, weight, stride, padding)
	if c.Bias != nil {
		// Bias: [C_out] -> [1, 1, 1, C_out]
		bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
		out = mlx.Add(out, bias)
	}
	return out
}

// Forward for GroupNorm. Normalizes over H, W and the channels within each
// group, then applies the learned per-channel scale/shift.
// x: [B, H, W, C] (NHWC)
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	numGroups := gn.NumGroups
	if numGroups == 0 {
		numGroups = 32
	}
	groupSize := C / numGroups

	// Reshape to [B, H, W, groups, groupSize]
	x = mlx.Reshape(x, B, H, W, numGroups, groupSize)

	// Mean over H (axis 1), W (axis 2) and in-group channels (axis 4)
	mean := mlx.Mean(x, 1, true)
	mean = mlx.Mean(mean, 2, true)
	mean = mlx.Mean(mean, 4, true)

	xCentered := mlx.Sub(x, mean)
	variance := mlx.Mean(mlx.Square(xCentered), 1, true)
	variance = mlx.Mean(variance, 2, true)
	variance = mlx.Mean(variance, 4, true)

	// Normalize
	xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))

	// Reshape back
	xNorm = mlx.Reshape(xNorm, B, H, W, C)

	// Scale and shift
	if gn.Weight != nil {
		weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
		xNorm = mlx.Mul(xNorm, weight)
	}
	if gn.Bias != nil {
		bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
		xNorm = mlx.Add(xNorm, bias)
	}

	return xNorm
}

// Forward for VAEMidBlock: applies the resnets in order.
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
	for _, resnet := range mb.Resnets {
		x = resnet.Forward(x)
	}
	return x
}

// Forward for VAEUpBlock: resnets first, then the upsamplers.
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
	for _, resnet := range ub.Resnets {
		if resnet != nil {
			x = resnet.Forward(x)
		}
	}

	for _, upsampler := range ub.Upsamplers {
		if upsampler != nil {
			x = upsampler.Forward(x)
		}
	}

	return x
}

// Forward for VAEResnetBlock: two norm+SiLU+conv stages plus a residual
// connection (through a 1x1 conv_shortcut when the channel count changes).
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
	residual := x

	h := rb.Norm1.Forward(x)
	h = mlx.SiLU(h)
	h = rb.Conv1.Forward(h)

	h = rb.Norm2.Forward(h)
	h = mlx.SiLU(h)
	h = rb.Conv2.Forward(h)

	// Shortcut for channel mismatch
	if rb.ConvShortcut != nil {
		residual = rb.ConvShortcut.Forward(residual)
	}

	return mlx.Add(h, residual)
}

// Forward for VAEUpsampler (2x nearest neighbor upsample + conv).
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
	// x: [B, H, W, C]
	x = upsample2x(x)

	if us.Conv != nil {
		x = us.Conv.Forward(x)
	}

	return x
}

// upsample2x performs 2x nearest neighbor upsampling.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
func upsample2x(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	// Create indices [0, 0, 1, 1, 2, 2, ...]
for nearest neighbor + hIndices := make([]int32, H*2) + for i := int32(0); i < H; i++ { + hIndices[i*2] = i + hIndices[i*2+1] = i + } + wIndices := make([]int32, W*2) + for i := int32(0); i < W; i++ { + wIndices[i*2] = i + wIndices[i*2+1] = i + } + + hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2}) + wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2}) + + // Take along height axis + x = mlx.Reshape(x, B*H, W, C) + x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C] + x = mlx.Reshape(x, B, H, W*2, C) + + // Take along width axis - transpose to [B, W*2, H, C], take, transpose back + x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C] + x = mlx.Reshape(x, B*(W*2), H, C) + x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C] + x = mlx.Reshape(x, B, W*2, H*2, C) + x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C] + + return x +} diff --git a/x/imagegen/models/glm_image/vision_language_encoder.go b/x/imagegen/models/glm_image/vision_language_encoder.go new file mode 100644 index 000000000..825272fd1 --- /dev/null +++ b/x/imagegen/models/glm_image/vision_language_encoder.go @@ -0,0 +1,982 @@ +//go:build mlx + +package glm_image + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// VisionLanguageConfig holds GLM-Image AR generator configuration +type VisionLanguageConfig struct { + // Text model config + HiddenSize int32 `json:"hidden_size"` // 4096 + NumHiddenLayers int32 `json:"num_hidden_layers"` // 40 + IntermediateSize int32 `json:"intermediate_size"` // 13696 + NumAttentionHeads int32 `json:"num_attention_heads"` // 32 + NumKeyValueHeads int32 `json:"num_key_value_heads"` // 2 + VocabSize int32 `json:"vocab_size"` // 168064 + RMSNormEps float32 `json:"rms_norm_eps"` // 1e-5 + + // RoPE config + RopeTheta float32 
`json:"rope_theta"` // 10000 + PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5 + MRoPESection []int32 `json:"mrope_section"` // [8, 12, 12] + + // Visual token config + VisionVocabSize int32 `json:"vision_vocab_size"` // 16512 + ImageStartTokenID int32 `json:"image_start_token_id"` // 16384 + ImageEndTokenID int32 `json:"image_end_token_id"` // 16385 + ImageTokenID int32 `json:"image_token_id"` // 167855 + + // Computed + HeadDim int32 +} + +// VisionLanguageEncoder is the 9B AR generator +type VisionLanguageEncoder struct { + Config *VisionLanguageConfig + + // Embedding + EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"` + + // Transformer layers + Layers []*GLMBlock `weight:"model.language_model.layers"` + + // Final norm + FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"` + + // LM Head + LMHead *mlx.Array `weight:"lm_head.weight"` +} + +// GLMBlock is a single transformer block in GLM-4 style +type GLMBlock struct { + // Pre-attention norm (GLM uses post-LN variant) + InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"` + PostSelfAttnNorm *nn.RMSNorm `weight:"post_self_attn_layernorm"` + PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"` + PostMLPLayerNorm *nn.RMSNorm `weight:"post_mlp_layernorm"` + + // Attention + SelfAttn *GLMAttention `weight:"self_attn"` + + // MLP (fused gate_up) + MLP *GLMMLP `weight:"mlp"` +} + +// GLMAttention implements GQA with partial rotary and MRoPE +type GLMAttention struct { + QProj *mlx.Array `weight:"q_proj.weight"` + KProj *mlx.Array `weight:"k_proj.weight"` + VProj *mlx.Array `weight:"v_proj.weight"` + OProj *mlx.Array `weight:"o_proj.weight"` + + // QKV have biases in GLM + QBias *mlx.Array `weight:"q_proj.bias"` + KBias *mlx.Array `weight:"k_proj.bias"` + VBias *mlx.Array `weight:"v_proj.bias"` + + // Computed + NHeads int32 + NKVHeads int32 + HeadDim int32 + Scale float32 + PartialRotary float32 // Only rotate this fraction of head_dim + RopeTheta float32 + 
MRoPESection []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width) +} + +// ARCache holds KV caches for all layers using the shared cache implementation +type ARCache struct { + Layers []cache.Cache +} + +// NewARCache creates a new cache for the given number of layers +func NewARCache(numLayers int32) *ARCache { + layers := make([]cache.Cache, numLayers) + for i := range layers { + layers[i] = cache.NewKVCache() + } + return &ARCache{Layers: layers} +} + +// Free releases all cached tensors +func (c *ARCache) Free() { + for _, layer := range c.Layers { + for _, arr := range layer.State() { + if arr != nil { + arr.Free() + } + } + } +} + +// GLMMLP implements fused gate_up SwiGLU MLP +type GLMMLP struct { + // GLM uses fused gate_up_proj: [hidden, 2*intermediate] + GateUpProj *mlx.Array `weight:"gate_up_proj.weight"` + DownProj *mlx.Array `weight:"down_proj.weight"` +} + +// Load loads the vision-language encoder from manifest +func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error { + fmt.Print(" Loading vision-language encoder... 
") + + // Load config + var rawCfg struct { + TextConfig struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + VisionVocabSize int32 `json:"vision_vocab_size"` + RopeParameters struct { + RopeTheta float32 `json:"rope_theta"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + MRoPESection []int32 `json:"mrope_section"` + } `json:"rope_parameters"` + } `json:"text_config"` + ImageStartTokenID int32 `json:"image_start_token_id"` + ImageEndTokenID int32 `json:"image_end_token_id"` + ImageTokenID int32 `json:"image_token_id"` + } + + if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil { + return fmt.Errorf("config: %w", err) + } + + cfg := &VisionLanguageConfig{ + HiddenSize: rawCfg.TextConfig.HiddenSize, + NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers, + IntermediateSize: rawCfg.TextConfig.IntermediateSize, + NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads, + NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads, + VocabSize: rawCfg.TextConfig.VocabSize, + RMSNormEps: rawCfg.TextConfig.RMSNormEps, + VisionVocabSize: rawCfg.TextConfig.VisionVocabSize, + RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta, + PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor, + MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection, + ImageStartTokenID: rawCfg.ImageStartTokenID, + ImageEndTokenID: rawCfg.ImageEndTokenID, + ImageTokenID: rawCfg.ImageTokenID, + } + + cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads + m.Config = cfg + + // Pre-allocate layers + m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers) + + // Load weights + weights, err := imagegen.LoadWeightsFromManifest(manifest, 
"vision_language_encoder") + if err != nil { + return fmt.Errorf("weights: %w", err) + } + if err := weights.Load(mlx.DtypeBFloat16); err != nil { + return fmt.Errorf("load weights: %w", err) + } + defer weights.ReleaseAll() + + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return fmt.Errorf("load module: %w", err) + } + + m.initComputedFields() + fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers) + return nil +} + +// LoadFromPath loads the vision-language encoder from a directory path +func (m *VisionLanguageEncoder) LoadFromPath(path string) error { + fmt.Print(" Loading vision-language encoder... ") + + // Load config + var rawCfg struct { + TextConfig struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + VisionVocabSize int32 `json:"vision_vocab_size"` + RopeParameters struct { + RopeTheta float32 `json:"rope_theta"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + MRoPESection []int32 `json:"mrope_section"` + } `json:"rope_parameters"` + } `json:"text_config"` + ImageStartTokenID int32 `json:"image_start_token_id"` + ImageEndTokenID int32 `json:"image_end_token_id"` + ImageTokenID int32 `json:"image_token_id"` + } + + configPath := filepath.Join(path, "config.json") + data, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("read config: %w", err) + } + if err := json.Unmarshal(data, &rawCfg); err != nil { + return fmt.Errorf("parse config: %w", err) + } + + cfg := &VisionLanguageConfig{ + HiddenSize: rawCfg.TextConfig.HiddenSize, + NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers, + IntermediateSize: rawCfg.TextConfig.IntermediateSize, + NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads, + NumKeyValueHeads: 
rawCfg.TextConfig.NumKeyValueHeads, + VocabSize: rawCfg.TextConfig.VocabSize, + RMSNormEps: rawCfg.TextConfig.RMSNormEps, + VisionVocabSize: rawCfg.TextConfig.VisionVocabSize, + RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta, + PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor, + MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection, + ImageStartTokenID: rawCfg.ImageStartTokenID, + ImageEndTokenID: rawCfg.ImageEndTokenID, + ImageTokenID: rawCfg.ImageTokenID, + } + + cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads + m.Config = cfg + + // Pre-allocate layers + m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers) + + // Load weights + weights, err := safetensors.LoadModelWeights(path) + if err != nil { + return fmt.Errorf("weights: %w", err) + } + if err := weights.Load(mlx.DtypeBFloat16); err != nil { + return fmt.Errorf("load weights: %w", err) + } + defer weights.ReleaseAll() + + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return fmt.Errorf("load module: %w", err) + } + + m.initComputedFields() + fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers) + return nil +} + +func (m *VisionLanguageEncoder) initComputedFields() { + cfg := m.Config + for _, block := range m.Layers { + block.SelfAttn.NHeads = cfg.NumAttentionHeads + block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads + block.SelfAttn.HeadDim = cfg.HeadDim + block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor + block.SelfAttn.RopeTheta = cfg.RopeTheta + block.SelfAttn.MRoPESection = cfg.MRoPESection + + // Set norm eps + block.InputLayerNorm.Eps = cfg.RMSNormEps + block.PostSelfAttnNorm.Eps = cfg.RMSNormEps + block.PostAttnLayerNorm.Eps = cfg.RMSNormEps + block.PostMLPLayerNorm.Eps = cfg.RMSNormEps + } + m.FinalNorm.Eps = cfg.RMSNormEps +} + +// Generate autoregressively generates visual tokens with KV caching +func (m *VisionLanguageEncoder) Generate( + prompt string, + tok 
*GLMTokenizer,
	maxTokens int32,
	temperature float32,
	topP float32,
	seed int64,
	targetHeight, targetWidth int32,
	progressFn func(int),
) *mlx.Array {
	cfg := m.Config

	// Encode prompt with grid tokens using GLM tokenizer
	// Format: {prompt}{h} {w}{prev_h} {prev_w}<|dit_token_16384|>
	tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)

	// Calculate grid dimensions for MRoPE position IDs.
	// The "prev" grid is a small preview grid of ~16x16 tokens with the same
	// aspect ratio as the target grid (sqrt(ratio)*16 by sqrt(1/ratio)*16).
	factor := int32(32)
	tokenH := targetHeight / factor
	tokenW := targetWidth / factor
	ratio := float64(tokenH) / float64(tokenW)
	prevTokenH := int32(math.Sqrt(ratio) * 16)
	prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
	prevGridSize := prevTokenH * prevTokenW

	// Create KV cache for all layers
	cache := NewARCache(cfg.NumHiddenLayers)
	defer cache.Free()

	// ===== PREFILL PHASE =====
	// Process entire prompt at once, populate cache
	promptLen := int32(len(tokens))
	tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
	h := m.EmbedTokens.Forward(tokenArr)
	tokenArr.Free()

	mlx.Eval(h)

	// Compute position IDs for prefill (text tokens use same position for all dims)
	prefillPositions := make([][]int32, 3)
	for dim := 0; dim < 3; dim++ {
		prefillPositions[dim] = make([]int32, promptLen)
		for i := int32(0); i < promptLen; i++ {
			prefillPositions[dim][i] = i
		}
	}

	// Forward through layers (prefill).
	// NOTE(review): the input to layer 0 (the embedding output) is never freed
	// here — only each layer's intermediate output is. Verify against mlx
	// array ownership whether that single array is intentionally leaked per call.
	for i, layer := range m.Layers {
		oldH := h
		h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
		if i > 0 {
			oldH.Free()
		}
	}
	// Eval h and cache arrays together so cache is materialized
	evalArgs := []*mlx.Array{h}
	for _, lc := range cache.Layers {
		evalArgs = append(evalArgs, lc.State()...)
	}
	mlx.Eval(evalArgs...)

	// Final norm and get logits for last position
	preNormH := h
	h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
	preNormH.Free()

	lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
	h.Free()
	lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
	logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
	lastH.Free()

	// Sample first token.
	// sampleCounter is advanced by sampleVisualToken so each step draws
	// different randomness from the same seed.
	var sampleCounter int64 = 0
	nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
	logits.Free()

	// AR generation loop with caching
	// Visual tokens are stored as VQ codebook indices [0, 16383]
	// The LM head outputs indices [0, 16511] where:
	//   - [0, 16383] are VQ codes
	//   - 16384 is BOS
	//   - 16385 is EOS
	visualTokens := make([]int32, 0, maxTokens)
	posOffset := promptLen
	visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation

	// Preallocate slice for old cache state to reuse
	oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)

	for i := int32(0); i < maxTokens; i++ {
		if progressFn != nil {
			progressFn(int(i))
		}

		// Check for end token (EOS = 16385)
		if nextToken == cfg.ImageEndTokenID {
			break
		}

		// Skip BOS token (16384), only store actual VQ codes [0, 16383]
		if nextToken == cfg.ImageStartTokenID {
			// BOS token - skip storing but continue generation
		} else if nextToken < cfg.ImageStartTokenID {
			// This is an actual VQ code [0, 16383] - store it
			visualTokens = append(visualTokens, nextToken)
		}
		// Tokens >= 16386 are other special tokens, skip them

		// ===== DECODE PHASE =====
		// Save old cache state before forward (to free after eval)
		oldCacheState = oldCacheState[:0]
		for _, lc := range cache.Layers {
			oldCacheState = append(oldCacheState, lc.State()...)
		}

		// Only process the new token, use cached K,V
		tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
		h := m.EmbedTokens.Forward(tokenArr)
		tokenArr.Free()

		// Compute MRoPE position IDs for this visual token
		// Visual tokens are arranged in two grids: prev grid then target grid
		// Position dimensions: [temporal, height, width]
		decodePositions := computeVisualTokenPositions(
			visualTokenIdx, posOffset, promptLen,
			prevTokenH, prevTokenW, prevGridSize,
			tokenH, tokenW,
		)

		// Forward through layers (decode with cache).
		// NOTE(review): as in prefill, the per-token embedding array is never
		// freed (only j>0 outputs are) — confirm this is intended.
		for j, layer := range m.Layers {
			oldH := h
			h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
			if j > 0 { // Don't free the embedding on first layer
				oldH.Free()
			}
		}

		// Eval h and new cache state
		newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
		for _, lc := range cache.Layers {
			newCacheState = append(newCacheState, lc.State()...)
		}
		mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)

		// Free old cache state (now that new state is evaluated)
		for _, arr := range oldCacheState {
			if arr != nil {
				arr.Free()
			}
		}

		// Final norm
		preNormH := h
		h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
		preNormH.Free()

		// Get logits (h is already [1, 1, hidden_size])
		h = mlx.Reshape(h, 1, cfg.HiddenSize)
		logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
		h.Free()

		// Sample next token
		nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
		logits.Free()

		posOffset++
		visualTokenIdx++

		// Periodically clear cache to release intermediate memory
		// (note: also fires on the first iteration, i == 0)
		if i%256 == 0 {
			mlx.ClearCache()
		}
	}

	if len(visualTokens) == 0 {
		// Return at least one token to avoid empty tensor issues
		visualTokens = append(visualTokens, 0)
	}

	return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
}

// computeVisualTokenPositions computes MRoPE position IDs for a visual token
// Returns [3][1] position IDs for temporal, height, and width dimensions
//
// MRoPE position encoding for GLM-Image visual tokens:
//   - temporal: CONSTANT within each grid (= decode_pos at grid start)
//   - height: decode_pos + row index within grid
//   - width: decode_pos + column index within grid
//
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
// sufficient positional separation.
+func computeVisualTokenPositions(
+	visualIdx int32, absPos int32, promptLen int32,
+	prevH, prevW, prevSize int32,
+	targetH, targetW int32,
+) [][]int32 {
+	// positions[0]=temporal, positions[1]=height, positions[2]=width; one entry per dim
+	positions := make([][]int32, 3)
+	for dim := 0; dim < 3; dim++ {
+		positions[dim] = make([]int32, 1)
+	}
+
+	// First grid (prev grid) starts at decode_pos = promptLen
+	prevGridDecodePos := promptLen
+
+	// Second grid (target grid) starts after first grid
+	// next_pos = prev_decode_pos + max(prevH, prevW)
+	maxPrev := prevH
+	if prevW > maxPrev {
+		maxPrev = prevW
+	}
+	targetGridDecodePos := prevGridDecodePos + maxPrev
+
+	// Compute position IDs based on which grid the token is in
+	if visualIdx < prevSize {
+		// Token is in the prev grid (prev_token_h × prev_token_w)
+		row := visualIdx / prevW
+		col := visualIdx % prevW
+
+		// temporal is CONSTANT for all tokens in this grid
+		positions[0][0] = prevGridDecodePos
+		// height and width are relative to grid's decode_pos
+		positions[1][0] = prevGridDecodePos + row
+		positions[2][0] = prevGridDecodePos + col
+	} else {
+		// Token is in the target grid (token_h × token_w)
+		targetIdx := visualIdx - prevSize
+		row := targetIdx / targetW
+		col := targetIdx % targetW
+
+		// temporal is CONSTANT for all tokens in this grid
+		positions[0][0] = targetGridDecodePos
+		// height and width are relative to grid's decode_pos
+		positions[1][0] = targetGridDecodePos + row
+		positions[2][0] = targetGridDecodePos + col
+	}
+
+	// targetH and absPos are accepted but intentionally unused (see notes below);
+	// the blank assignments keep the compiler and linters quiet.
+	_ = targetH // Used for documentation clarity
+	_ = absPos  // No longer used - kept for API compatibility
+	return positions
+}
+
+// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling
+// Note: For GLM-Image, greedy decoding is not allowed as it may cause repetitive outputs
+// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table
+// sampleCounter is incremented for each call to ensure different random values
+func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
+	// The LMHead outputs logits for visual tokens only (shape [1, 16512])
+	// Output index directly corresponds to vocab ID [0, 16511]
+	// No offset needed - the visual tokens are at vocab IDs [0, 16511]
+	visualLogits := logits
+
+	// Apply temperature
+	if temperature != 1.0 && temperature > 0 {
+		visualLogits = mlx.DivScalar(visualLogits, temperature)
+	}
+
+	// Apply softmax to get probabilities
+	probs := mlx.Softmax(visualLogits, -1)
+	mlx.Eval(probs)
+
+	// NOTE(review): visualLogits/probs are never Freed here, unlike the caller's
+	// explicit Free discipline in the decode loop — presumably reclaimed by the
+	// loop's periodic mlx.ClearCache(); confirm this does not grow memory.
+
+	// Get the sampled index using top-p sampling
+	// This directly gives us the vocab ID in [0, 16511]
+	// Special tokens: 16384 = BOS, 16385 = EOS
+	// Use seed + counter for reproducible but different random values
+	effectiveSeed := seed + *sampleCounter
+	*sampleCounter++
+	return sampleTopP(probs, topP, effectiveSeed)
+}
+
+// sampleTopP implements nucleus (top-p) sampling
+// probs: [1, vocab_size] probability distribution
+// topP: cumulative probability threshold (e.g., 0.75)
+// seed: random seed for reproducible sampling
+func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
+	// Negate probs for descending sort (Argsort only does ascending)
+	negProbs := mlx.MulScalar(probs, -1)
+	sortedIndices := mlx.Argsort(negProbs, -1)
+	sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
+	cumProbs := mlx.Cumsum(sortedProbs, -1)
+	mlx.Eval(sortedIndices, sortedProbs, cumProbs)
+
+	// Find cutoff index where cumulative probability exceeds topP
+	probsData := sortedProbs.Data()
+	cumProbsData := cumProbs.Data()
+	indicesData := sortedIndices.DataInt32()
+
+	// Calculate cutoff and renormalize
+	// totalProb accumulates the same prefix sum as cumProbsData (in Go floats),
+	// so the renormalization below divides by exactly the kept probability mass.
+	var cutoffIdx int
+	var totalProb float32
+	for i, cp := range cumProbsData {
+		totalProb += probsData[i]
+		if cp >= topP {
+			cutoffIdx = i + 1 // Include this token
+			break
+		}
+	}
+	if cutoffIdx == 0 {
+		cutoffIdx = len(probsData) // Use all tokens if topP is very high
+	}
+
+	// Sample from the truncated distribution
+	// Renormalize the truncated probabilities
+	truncatedProbs := make([]float32, cutoffIdx)
+	for i := 0; i < cutoffIdx; i++ {
+		truncatedProbs[i] = probsData[i] / totalProb
+	}
+
+	// Sample using random number with provided seed for reproducibility
+	r := mlx.RandomUniform([]int32{1}, uint64(seed))
+	mlx.Eval(r)
+	randVal := r.Data()[0]
+
+	// Find the sampled token (inverse-CDF walk over the truncated set)
+	var cumulative float32
+	for i := 0; i < cutoffIdx; i++ {
+		cumulative += truncatedProbs[i]
+		if randVal < cumulative {
+			return indicesData[i]
+		}
+	}
+
+	// Fallback to the last token in truncated set
+	return indicesData[cutoffIdx-1]
+}
+
+// Forward for GLMBlock
+func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
+	return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
+}
+
+// ForwardWithCache performs block forward with optional KV caching and MRoPE
+// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
+func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
+	// Pre-attention norm
+	normed := b.InputLayerNorm.Forward(x, eps)
+
+	// Self-attention with RoPE/MRoPE and cache
+	attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)
+
+	// Post-attention norm (GLM-4 style)
+	attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)
+
+	// Residual connection
+	x = mlx.Add(x, attnOut)
+
+	// Post-attention layer norm
+	normed = b.PostAttnLayerNorm.Forward(x, eps)
+
+	// MLP
+	mlpOut := b.MLP.Forward(normed)
+
+	// Post-MLP norm
+	mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)
+
+	// Residual connection
+	x = mlx.Add(x, mlpOut)
+
+	return x
+}
+
+// Forward for GLMAttention (without cache - used for prefill)
+func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
+	return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
+}
+
+// ForwardWithCache performs attention with optional KV caching and MRoPE
+// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode)
+// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode)
+// kvcache is updated in-place if provided
+func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
+	// Batch and sequence length come from x itself; seqLen is effectively
+	// redundant with L here (both are passed through below).
+	shape := x.Shape()
+	B := shape[0]
+	L := shape[1]
+
+	// Q, K, V projections
+	q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
+	k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
+	v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))
+
+	// Add biases
+	if attn.QBias != nil {
+		q = mlx.Add(q, attn.QBias)
+	}
+	if attn.KBias != nil {
+		k = mlx.Add(k, attn.KBias)
+	}
+	if attn.VBias != nil {
+		v = mlx.Add(v, attn.VBias)
+	}
+
+	// Reshape to [B, L, nheads, head_dim]
+	q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
+	k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
+	v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
+
+	// Apply partial RoPE or MRoPE (only the first rotaryDim channels are rotated)
+	rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
+	if len(attn.MRoPESection) == 3 && positionIDs != nil {
+		// Use MRoPE with explicit position IDs
+		q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
+		k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
+	} else if len(attn.MRoPESection) == 3 {
+		// Use MRoPE with sequential positions (same for all dims - text mode)
+		seqPositions := make([][]int32, 3)
+		for dim := 0; dim < 3; dim++ {
+			seqPositions[dim] = make([]int32, L)
+			for i := int32(0); i < L; i++ {
+				seqPositions[dim][i] = i + posOffset
+			}
+		}
+		q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
+		k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
+	} else {
+		// Fallback to standard RoPE
+		q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
+		k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
+	}
+
+	// Transpose to [B, nheads, L, head_dim]
+	q = mlx.Transpose(q, 0, 2, 1, 3)
+	k = mlx.Transpose(k, 0, 2, 1, 3)
+	v = mlx.Transpose(v, 0, 2, 1, 3)
+
+	// Update cache and get full K, V for attention
+	if kvcache != nil {
+		k, v = kvcache.Update(k, v, int(L))
+	}
+
+	// Repeat KV for GQA
+	kExpanded := k
+	vExpanded := v
+	if attn.NKVHeads < attn.NHeads {
+		repeats := attn.NHeads / attn.NKVHeads
+		kExpanded = repeatKV(k, repeats)
+		vExpanded = repeatKV(v, repeats)
+	}
+
+	// Scaled dot-product attention with causal mask
+	// NOTE(review): causal=true while K/V may be longer than Q (cached decode, L=1)
+	// assumes the mask is aligned to the last query position — confirm mlx's
+	// ScaledDotProductAttention semantics match that.
+	out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)
+
+	// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
+	out = mlx.Transpose(out, 0, 2, 1, 3)
+	// Reshape to [B, L, hidden_size]
+	out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
+
+	// Output projection
+	out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))
+
+	return out
+}
+
+// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
+func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
+	return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
+}
+
+// applyPartialRoPEWithOffset applies RoPE with a position offset
+// x: [B, L, H, D]. Note seqLen is not read — the sequence length is taken
+// from x's shape below.
+func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
+	shape := x.Shape()
+	B := shape[0]
+	L := shape[1]
+	H := shape[2]
+	D := shape[3]
+
+	// Clamp rotaryDim into (0, D]; out-of-range means "rotate everything"
+	if rotaryDim <= 0 || rotaryDim > D {
+		rotaryDim = D
+	}
+
+	// Split into rotary and pass-through parts
+	xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
+	xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
+
+	// Apply RoPE to rotary part with position offset
+	xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)
+
+	// Concatenate back
+	return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
+}
+// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions
+// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
+// mrope_section: [8, 12, 12] - frequency pairs per dimension
+// For text tokens: all 3 dimensions have the same sequential position
+// For image tokens: temporal=seq_idx, height=row, width=col
+func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
+	shape := x.Shape()
+	B := shape[0]
+	L := shape[1]
+	H := shape[2]
+	D := shape[3]
+
+	// Clamp rotaryDim into (0, D]; out-of-range means "rotate everything"
+	if rotaryDim <= 0 || rotaryDim > D {
+		rotaryDim = D
+	}
+
+	// Split into rotary and pass-through parts
+	xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
+	xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
+
+	// Apply MRoPE to rotary part
+	xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)
+
+	// Concatenate back
+	return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
+}
+
+// applyMRoPE applies multi-dimensional rotary position embedding
+// x: [B, L, H, D] where D is the rotary dimension
+// positionIDs: [3][L] - positions for temporal, height, width dimensions
+// mropeSection: [8, 12, 12] - frequency pairs per dimension
+func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
+	shape := x.Shape()
+	B := shape[0]
+	L := shape[1]
+	H := shape[2]
+	D := shape[3]
+	half := D / 2
+
+	// Validate mrope_section sums to half (number of frequency pairs)
+	var totalPairs int32
+	for _, s := range mropeSection {
+		totalPairs += s
+	}
+	if totalPairs != half {
+		// Fallback to standard RoPE if section doesn't match
+		// NOTE(review): this fallback starts positions at 0 and ignores positionIDs,
+		// which would restart positions during cached decode — confirm it is
+		// unreachable for correctly configured models.
+		return applyRoPEWithOffset(x, L, 0, theta)
+	}
+
+	// Build angles for each position dimension (matching Python's MRoPE approach)
+	// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate
+	// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
+	angleVals := make([]*mlx.Array, 3)
+
+	freqOffset := int32(0)
+	for dim := 0; dim < 3; dim++ {
+		numPairs := mropeSection[dim]
+		if numPairs == 0 {
+			// NOTE(review): a zero-sized section leaves a nil entry in angleVals that
+			// the Concatenate below would receive — unreachable for [8, 12, 12], but
+			// worth guarding if sections ever become configurable.
+			continue
+		}
+
+		// Compute inverse frequencies for this section
+		// Each dimension uses DIFFERENT frequency ranges:
+		// - Temporal: frequencies 0 to section[0]-1
+		// - Height: frequencies section[0] to section[0]+section[1]-1
+		// - Width: frequencies section[0]+section[1] to sum(section)-1
+		freqsArr := make([]float32, numPairs)
+		for i := int32(0); i < numPairs; i++ {
+			globalIdx := freqOffset + i
+			freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
+		}
+		freqs := mlx.NewArray(freqsArr, []int32{numPairs})
+
+		// Position indices for this dimension
+		posArr := make([]float32, L)
+		for i := int32(0); i < L; i++ {
+			posArr[i] = float32(positionIDs[dim][i])
+		}
+		pos := mlx.NewArray(posArr, []int32{L})
+
+		// Compute angles: [L, numPairs] = outer(pos, freqs)
+		posExpanded := mlx.Reshape(pos, L, 1)
+		freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
+		angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)
+
+		freqOffset += numPairs
+	}
+
+	// Concatenate all sections: [L, half] = [L, 32]
+	allAngles := mlx.Concatenate(angleVals, 1)
+
+	// Duplicate AFTER concatenation: [L, D] = [L, 64]
+	// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
+	allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)
+
+	// Compute cos/sin
+	allCos := mlx.Cos(allAngles)
+	allSin := mlx.Sin(allAngles)
+
+	// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
+	allCos = mlx.Reshape(allCos, 1, L, 1, D)
+	allSin = mlx.Reshape(allSin, 1, L, 1, D)
+
+	// x_rotated = cat([-x_imag, x_real], dim=-1)
+	x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})  // x_real
+	x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
+	x2Neg := mlx.MulScalar(x2, -1)                                  // -x_imag
+	xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3)         // [-x_imag, x_real]
+
+	// out = x * cos + x_rotated * sin
+	return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
+}
+
+// applyRoPE applies rotary position embedding
+func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
+	return applyRoPEWithOffset(x, seqLen, 0, theta)
+}
+
+// applyRoPEWithOffset applies rotary position embedding with position offset
+// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
+// x: [B, L, H, D]. Note seqLen is not read — L is taken from x's shape.
+func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
+	shape := x.Shape()
+	B := shape[0]
+	L := shape[1]
+	H := shape[2]
+	D := shape[3]
+	half := D / 2
+
+	// Compute inverse frequencies: 1 / (theta^(2i/d))
+	freqsArr := make([]float32, half)
+	for i := int32(0); i < half; i++ {
+		freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
+	}
+	freqs := mlx.NewArray(freqsArr, []int32{half})
+
+	// Position indices with offset
+	posArr := make([]float32, L)
+	for i := int32(0); i < L; i++ {
+		posArr[i] = float32(i + posOffset)
+	}
+	pos := mlx.NewArray(posArr, []int32{L})
+
+	// Compute angles: [L, half] = outer(pos, freqs)
+	posExpanded := mlx.Reshape(pos, L, 1)
+	freqsExpanded := mlx.Reshape(freqs, 1, half)
+	angles := mlx.Mul(posExpanded, freqsExpanded)
+
+	// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
+	anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)
+
+	// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
+	cosVals := mlx.Cos(anglesDup)
+	sinVals := mlx.Sin(anglesDup)
+	cosVals = mlx.Reshape(cosVals, L, 1, D)
+	sinVals = mlx.Reshape(sinVals, L, 1, D)
+
+	// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
+	x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})  // x_real
+	x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
+	x2Neg := mlx.MulScalar(x2, -1)                                  // -x_imag
+	xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3)         // [-x_imag, x_real]
+
+	// out = x * cos + x_rotated * sin
+	return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
+}
+
+// repeatKV repeats key/value heads for GQA
+// x: [B, nkvheads, L, head_dim] -> [B, nkvheads*repeats, L, head_dim]
+func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
+	if repeats == 1 {
+		return x
+	}
+	shape := x.Shape()
+	// x: [B, nkvheads, L, head_dim]
+	x = mlx.ExpandDims(x, 2)
+	// x: [B, nkvheads, 1, L, head_dim]
+	x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
+	// x: [B, nkvheads, repeats, L, head_dim]
+	return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
+}
+
+// Forward for GLMMLP (fused gate_up SwiGLU)
+func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
+	// gate_up_proj outputs [gate, up] concatenated
+	gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))
+
+	shape := gateUp.Shape()
+	halfDim := shape[len(shape)-1] / 2
+
+	// Split into gate and up
+	// NOTE(review): the Slice bounds assume rank-3 [B, L, 2*halfDim] activations —
+	// confirm this is never called with 2-D input.
+	gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
+	up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})
+
+	// SwiGLU: silu(gate) * up
+	gate = mlx.SiLU(gate)
+	h := mlx.Mul(gate, up)
+
+	// Down projection
+	return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
+}
diff --git a/x/imagegen/runner/runner.go b/x/imagegen/runner/runner.go
index d00748188..7399b0323 100644
--- a/x/imagegen/runner/runner.go
+++ b/x/imagegen/runner/runner.go
@@ -19,9 +19,15 @@ import (
 	"github.com/ollama/ollama/x/imagegen"
 	"github.com/ollama/ollama/x/imagegen/mlx"
+	"github.com/ollama/ollama/x/imagegen/models/glm_image"
 	"github.com/ollama/ollama/x/imagegen/models/zimage"
 )
 
+// ImageModel is the interface for image generation models
+type ImageModel interface {
+	GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error)
+}
+
 // Request is the image generation request format
 type Request struct {
 	Prompt string `json:"prompt"`
@@ -41,8 +47,9 @@ type Response struct {
 // Server holds the model and handles requests
 type Server struct {
 	mu sync.Mutex
-	model *zimage.Model
+	model ImageModel
 	modelName string
+	modelType string // pipeline class name from detectModelType, e.g. "ZImagePipeline" or "GlmImagePipeline"
 }
 
 // Execute is the entry point for the image runner subprocess
@@ -72,15 +79,35 @@ func Execute(args []string) error {
 		requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
 	}
 
-	// Load model
-	model := &zimage.Model{}
-	if err := model.Load(*modelName); err != nil {
-		return fmt.Errorf("failed to load model: %w", err)
+	// Detect model type and load appropriate model
+	modelType, err := detectModelType(*modelName)
+	if err != nil {
+		return fmt.Errorf("failed to detect model type: %w", err)
+	}
+
+	var model ImageModel
+	switch modelType {
+	case "GlmImagePipeline":
+		slog.Info("loading GLM-Image model")
+		m := &glm_image.Model{}
+		if err := m.Load(*modelName); err != nil {
+			return fmt.Errorf("failed to load GLM-Image model: %w", err)
+		}
+		model = m
+	default:
+		// Default to zimage for ZImagePipeline, FluxPipeline, and unknown types
+		// NOTE(review): only the exact string "GlmImagePipeline" selects GLM-Image;
+		// an ollama-format manifest whose "architecture" uses another spelling will
+		// silently fall back to Z-Image — confirm the expected values.
+		slog.Info("loading Z-Image model")
+		m := &zimage.Model{}
+		if err := m.Load(*modelName); err != nil {
+			return fmt.Errorf("failed to load Z-Image model: %w", err)
+		}
+		model = m
 	}
 
 	server := &Server{
 		model: model,
 		modelName: *modelName,
+		modelType: modelType,
 	}
 
 	// Set up HTTP handlers
@@ -144,7 +171,13 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 		req.Height = 1024
 	}
 	if req.Steps <= 0 {
+		// Default steps depend on model type
+		switch s.modelType {
+		case "GlmImagePipeline":
+			req.Steps = 50 // GLM-Image default
+		default:
+			req.Steps = 9 // Z-Image turbo default
+		}
 	}
 	if req.Seed <= 0 {
 		req.Seed = time.Now().UnixNano()
@@ -159,25 +192,9 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// Generate image
+	// Generate image using interface method
+	// NOTE(review): the per-step Progress streaming from the old zimage path is
+	// dropped here — clients no longer receive "Generating: step x/y" updates
+	// while the image renders. Confirm this regression is intentional.
 	ctx := r.Context()
-	img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
-		Prompt: req.Prompt,
-		Width: req.Width,
-		Height: req.Height,
-		Steps: req.Steps,
-		Seed: req.Seed,
-		Progress: func(step, total int) {
-			resp := Response{
-				Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
-				Done: false,
-			}
-			data, _ := json.Marshal(resp)
-			w.Write(data)
-			w.Write([]byte("\n"))
-			flusher.Flush()
-		},
-	})
+	img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed)
 
 	if err != nil {
 		// Don't send error for cancellation
@@ -216,3 +233,35 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 	w.Write([]byte("\n"))
 	flusher.Flush()
 }
+
+// detectModelType reads the model manifest and returns the pipeline class name
+// (e.g. "ZImagePipeline", "GlmImagePipeline"). All failure modes other than a
+// missing manifest degrade to the Z-Image default rather than erroring.
+func detectModelType(modelName string) (string, error) {
+	manifest, err := imagegen.LoadManifest(modelName)
+	if err != nil {
+		return "", err
+	}
+
+	data, err := manifest.ReadConfig("model_index.json")
+	if err != nil {
+		return "ZImagePipeline", nil // Default to Z-Image
+	}
+
+	// Try both _class_name (diffusers format) and architecture (ollama format)
+	var index struct {
+		ClassName string `json:"_class_name"`
+		Architecture string `json:"architecture"`
+	}
+	if err := json.Unmarshal(data, &index); err != nil {
+		return "ZImagePipeline", nil
+	}
+
+	// Prefer _class_name, fall back to architecture
+	className := index.ClassName
+	if className == "" {
+		className = index.Architecture
+	}
+	if className == "" {
+		return "ZImagePipeline", nil
+	}
+	return className, nil
+}