Files
ollama-ollama/x/imagegen/image.go
Daniel Hiltgen 10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support are verified.

* ci: harden for flaky choco repo servers

CI sometimes fails due to choco not actually installing cache.  Since it just speeds up the build, we can proceed without.

* review comments
2026-03-09 17:24:45 -07:00

294 lines
6.3 KiB
Go

package imagegen
import (
"bytes"
"encoding/base64"
"fmt"
"image"
"image/color"
"image/draw"
_ "image/jpeg"
"image/png"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SaveImage saves an MLX array as a PNG image file.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
//
// If path does not already have the ".png" extension, ".png" is appended to
// it. The returned error is whichever of the conversion, file creation, PNG
// encoding or file close steps failed first.
func SaveImage(arr *mlx.Array, path string) error {
	img, err := ArrayToImage(arr)
	if err != nil {
		return err
	}

	if filepath.Ext(path) != ".png" {
		path += ".png"
	}

	f, err := os.Create(path)
	if err != nil {
		return err
	}

	if err := png.Encode(f, img); err != nil {
		// The encode error is the one worth reporting; closing here only
		// releases the descriptor.
		_ = f.Close()
		return err
	}

	// Close is checked rather than deferred: a failure here can mean the
	// data never reached the disk, and the caller must hear about it.
	return f.Close()
}
// EncodeImageBase64 renders an MLX array as a PNG and returns the PNG bytes
// as a standard (padded) base64 string.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func EncodeImageBase64(arr *mlx.Array) (string, error) {
	rgba, convErr := ArrayToImage(arr)
	if convErr != nil {
		return "", convErr
	}

	pngBytes := new(bytes.Buffer)
	if encErr := png.Encode(pngBytes, rgba); encErr != nil {
		return "", encErr
	}

	// Encode into an exactly-sized destination buffer.
	encoded := make([]byte, base64.StdEncoding.EncodedLen(pngBytes.Len()))
	base64.StdEncoding.Encode(encoded, pngBytes.Bytes())
	return string(encoded), nil
}
// ArrayToImage converts an MLX array to a Go image.RGBA.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
//
// Only a single image can be converted, so B must be 1. The shape is checked
// before any array operation is issued; an error is returned when the array
// is not 4D, when B != 1, when C != 3, or when the evaluated array yields
// fewer values than H*W*3. Values outside [0, 1] are clamped, and the alpha
// channel of every output pixel is set to 255.
func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
	const channels = 3

	shape := arr.Shape()
	if len(shape) != 4 {
		return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
	}
	// Axis 0 is squeezed away below and the result is treated as a 3D
	// image, which only makes sense for a batch of exactly one.
	if shape[0] != 1 {
		return nil, fmt.Errorf("expected batch size 1, got %d", shape[0])
	}
	// Checked up front so a bad input is rejected before any transform or
	// evaluation work is issued.
	if shape[1] != channels {
		return nil, fmt.Errorf("expected 3 channels (RGB), got %d", shape[1])
	}

	// Transform to [H, W, C] for image conversion
	// Free intermediate arrays to avoid memory leak
	squeezed := mlx.Squeeze(arr, 0)
	transposed := mlx.Transpose(squeezed, 1, 2, 0)
	squeezed.Free()
	img := mlx.Contiguous(transposed)
	transposed.Free()
	mlx.Eval(img)

	imgShape := img.Shape()
	H := int(imgShape[0])
	W := int(imgShape[1])

	// Copy to CPU and free GPU memory
	data := img.Data()
	img.Free()

	// Guard the indexing below. NOTE(review): assumes Data() returns one
	// float32 per element in row-major [H, W, C] order - confirm in mlx.
	if len(data) < H*W*channels {
		return nil, fmt.Errorf("expected %d values for a %dx%d RGB image, got %d", H*W*channels, W, H, len(data))
	}

	// Write directly to Pix slice (faster than SetRGBA). The image starts at
	// the origin, so its stride is 4*W and pixel i lives at Pix[4*i].
	goImg := image.NewRGBA(image.Rect(0, 0, W, H))
	pix := goImg.Pix
	for i, n := 0, H*W; i < n; i++ {
		src, dst := i*channels, i*4
		// +0.5 rounds to nearest before the truncating uint8 conversion.
		pix[dst+0] = uint8(clampF(data[src+0]*255+0.5, 0, 255))
		pix[dst+1] = uint8(clampF(data[src+1]*255+0.5, 0, 255))
		pix[dst+2] = uint8(clampF(data[src+2]*255+0.5, 0, 255))
		pix[dst+3] = 255
	}
	return goImg, nil
}
// clampF limits v to the closed interval [lo, hi].
//
// NaN is mapped to lo. Every comparison with NaN is false, so without this a
// NaN would pass through unchanged, and converting NaN to an integer type
// gives an implementation-dependent result in Go.
func clampF(v, lo, hi float32) float32 {
	// Written as !(v >= lo) rather than v < lo so that NaN also lands here.
	if !(v >= lo) {
		return lo
	}
	if v > hi {
		return hi
	}
	return v
}
// DecodeImage turns encoded image bytes into an image.Image, flattening any
// transparency onto a white background and then applying the EXIF
// orientation found in the raw bytes (JPEG only; other formats are treated
// as normally oriented).
// A decoding failure is returned unchanged.
func DecodeImage(data []byte) (image.Image, error) {
	decoded, _, decodeErr := image.Decode(bytes.NewReader(data))
	if decodeErr != nil {
		return nil, decodeErr
	}

	// The orientation is read from the raw bytes, not from the decoded image.
	exifOrientation := readJPEGOrientation(data)

	opaque := flattenAlpha(decoded)
	return applyOrientation(opaque, exifOrientation), nil
}
// flattenAlpha composites an image onto a white background,
// removing any transparency. This is needed because image
// generation models don't handle alpha channels well.
//
// *image.RGBA and *image.NRGBA inputs are always redrawn into a fresh
// *image.RGBA. Any other image that reports itself as not fully opaque via
// an Opaque() method (for example *image.NRGBA64 or an *image.Paletted with
// transparent palette entries) is flattened the same way. Images that are
// opaque, or that do not expose Opaque(), are returned unchanged.
func flattenAlpha(img image.Image) image.Image {
	switch src := img.(type) {
	case *image.RGBA, *image.NRGBA:
		// Always flattened.
	case interface{ Opaque() bool }:
		if src.Opaque() {
			// Nothing to composite, return as-is
			return img
		}
	default:
		// Transparency cannot be determined, return as-is
		return img
	}

	bounds := img.Bounds()
	dst := image.NewRGBA(bounds)

	// Fill with white background
	draw.Draw(dst, bounds, image.NewUniform(color.White), image.Point{}, draw.Src)

	// Composite the image on top
	draw.Draw(dst, bounds, img, bounds.Min, draw.Over)
	return dst
}
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
// Returns 1 (normal) for non-JPEG or if orientation not found.
func readJPEGOrientation(data []byte) int {
if len(data) < 2 || data[0] != 0xFF || data[1] != 0xD8 {
return 1 // Not JPEG
}
r := bytes.NewReader(data[2:])
for {
var marker [2]byte
if _, err := r.Read(marker[:]); err != nil || marker[0] != 0xFF {
return 1
}
if marker[1] == 0xE1 { // APP1 (EXIF)
var lenBytes [2]byte
if _, err := r.Read(lenBytes[:]); err != nil {
return 1
}
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
if segLen < 14 {
r.Seek(int64(segLen), 1)
continue
}
seg := make([]byte, segLen)
if _, err := r.Read(seg); err != nil {
return 1
}
if string(seg[:4]) == "Exif" && seg[4] == 0 && seg[5] == 0 {
return parseTIFFOrientation(seg[6:])
}
continue
}
if marker[1] == 0xD9 || marker[1] == 0xDA {
return 1 // EOI or SOS
}
if marker[1] >= 0xD0 && marker[1] <= 0xD7 {
continue // RST markers
}
var lenBytes [2]byte
if _, err := r.Read(lenBytes[:]); err != nil {
return 1
}
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
if segLen > 0 {
r.Seek(int64(segLen), 1)
}
}
}
// parseTIFFOrientation returns the EXIF orientation (1-8) stored in the
// first IFD of a TIFF structure, or 1 (normal) when it cannot be determined.
//
// tiff must start at the TIFF header: a byte-order mark ("II" little-endian
// or "MM" big-endian), the magic number 42 and the offset of the first IFD.
// 1 is returned when the header is invalid, when the IFD offset points
// inside the header or past the end of the data, when the first IFD has no
// Orientation tag (0x0112), or when the stored value is outside 1-8. The
// value is read as a 16-bit integer from the start of the entry's value
// field.
func parseTIFFOrientation(tiff []byte) int {
	const (
		headerSize     = 8  // byte order (2) + magic (2) + IFD offset (4)
		entrySize      = 12 // tag (2) + type (2) + count (4) + value (4)
		valueOffset    = 8  // position of the value field within an entry
		magic          = 42
		orientationTag = 0x0112
	)

	if len(tiff) < headerSize {
		return 1
	}

	var big bool
	switch string(tiff[:2]) {
	case "MM":
		big = true
	case "II":
		big = false
	default:
		return 1
	}

	u16 := func(b []byte) uint16 {
		if big {
			return uint16(b[0])<<8 | uint16(b[1])
		}
		return uint16(b[1])<<8 | uint16(b[0])
	}
	u32 := func(b []byte) uint32 {
		if big {
			return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
		}
		return uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
	}

	if u16(tiff[2:4]) != magic {
		return 1
	}

	// All offset arithmetic is done in uint64: the IFD offset is an
	// untrusted 32-bit value, and converting it to int could go negative on
	// platforms where int is 32 bits and slip past the bounds checks.
	size := uint64(len(tiff))
	ifd := uint64(u32(tiff[4:8]))
	if ifd < headerSize || ifd+2 > size {
		return 1
	}

	numEntries := int(u16(tiff[ifd : ifd+2]))
	for i := 0; i < numEntries; i++ {
		entry := ifd + 2 + uint64(i)*entrySize
		if entry+entrySize > size {
			break // entry table runs past the data
		}
		if u16(tiff[entry:entry+2]) != orientationTag {
			continue
		}
		o := int(u16(tiff[entry+valueOffset : entry+valueOffset+2]))
		if o >= 1 && o <= 8 {
			return o
		}
		return 1
	}
	return 1
}
// applyOrientation returns img transformed according to an EXIF orientation
// value. Orientations 2-8 produce a new *image.RGBA anchored at the origin;
// for 5-8 the output width and height are swapped relative to the input.
// Any other value (including 1, "normal") returns img itself, untouched.
func applyOrientation(img image.Image, orientation int) image.Image {
	if orientation < 2 || orientation > 8 {
		return img
	}

	src := img.Bounds()
	srcW, srcH := src.Dx(), src.Dy()
	lastX, lastY := srcW-1, srcH-1

	// place maps a source pixel (relative to the source origin) to its
	// position in the output. It is chosen once, outside the pixel loop.
	var place func(x, y int) (int, int)
	switch orientation {
	case 2: // mirror horizontally
		place = func(x, y int) (int, int) { return lastX - x, y }
	case 3: // rotate 180 degrees
		place = func(x, y int) (int, int) { return lastX - x, lastY - y }
	case 4: // mirror vertically
		place = func(x, y int) (int, int) { return x, lastY - y }
	case 5: // transpose
		place = func(x, y int) (int, int) { return y, x }
	case 6: // rotate 90 degrees clockwise
		place = func(x, y int) (int, int) { return lastY - y, x }
	case 7: // transverse
		place = func(x, y int) (int, int) { return lastY - y, lastX - x }
	case 8: // rotate 90 degrees counter-clockwise
		place = func(x, y int) (int, int) { return y, lastX - x }
	}

	dstW, dstH := srcW, srcH
	if orientation >= 5 {
		dstW, dstH = srcH, srcW
	}
	rotated := image.NewRGBA(image.Rect(0, 0, dstW, dstH))

	for row := 0; row < srcH; row++ {
		for col := 0; col < srcW; col++ {
			dstX, dstY := place(col, row)
			rotated.Set(dstX, dstY, img.At(src.Min.X+col, src.Min.Y+row))
		}
	}
	return rotated
}