mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* prefer rocm v6 on windows Avoid building with v7 - more changes are needed * MLX: add header vendoring and remove go build tag This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified. * ci: harden for flaky choco repo servers CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without. * review comments
294 lines
6.3 KiB
Go
294 lines
6.3 KiB
Go
package imagegen
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"image"
|
|
"image/color"
|
|
"image/draw"
|
|
_ "image/jpeg"
|
|
"image/png"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// SaveImage saves an MLX array as a PNG image file.
|
|
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
|
func SaveImage(arr *mlx.Array, path string) error {
|
|
img, err := ArrayToImage(arr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if filepath.Ext(path) != ".png" {
|
|
path = path + ".png"
|
|
}
|
|
|
|
f, err := os.Create(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
return png.Encode(f, img)
|
|
}
|
|
|
|
// EncodeImageBase64 encodes an MLX array as a base64-encoded PNG.
|
|
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
|
func EncodeImageBase64(arr *mlx.Array) (string, error) {
|
|
img, err := ArrayToImage(arr)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := png.Encode(&buf, img); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
|
|
}
|
|
|
|
// ArrayToImage converts an MLX array to a Go image.RGBA.
|
|
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
|
func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
|
|
shape := arr.Shape()
|
|
if len(shape) != 4 {
|
|
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
|
|
}
|
|
|
|
// Transform to [H, W, C] for image conversion
|
|
// Free intermediate arrays to avoid memory leak
|
|
squeezed := mlx.Squeeze(arr, 0)
|
|
transposed := mlx.Transpose(squeezed, 1, 2, 0)
|
|
squeezed.Free()
|
|
img := mlx.Contiguous(transposed)
|
|
transposed.Free()
|
|
mlx.Eval(img)
|
|
|
|
imgShape := img.Shape()
|
|
H := int(imgShape[0])
|
|
W := int(imgShape[1])
|
|
C := int(imgShape[2])
|
|
|
|
if C != 3 {
|
|
img.Free()
|
|
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
|
|
}
|
|
|
|
// Copy to CPU and free GPU memory
|
|
data := img.Data()
|
|
img.Free()
|
|
|
|
// Write directly to Pix slice (faster than SetRGBA)
|
|
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
|
|
pix := goImg.Pix
|
|
for y := 0; y < H; y++ {
|
|
for x := 0; x < W; x++ {
|
|
srcIdx := (y*W + x) * C
|
|
dstIdx := (y*W + x) * 4
|
|
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
|
|
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
|
|
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
|
|
pix[dstIdx+3] = 255
|
|
}
|
|
}
|
|
|
|
return goImg, nil
|
|
}
|
|
|
|
// clampF limits v to the closed interval [lo, hi].
//
// NaN is mapped to lo. Converting NaN to an integer type is
// implementation-dependent in Go, so a caller doing uint8(clampF(...)) would
// otherwise get an arbitrary byte. The parameters are named lo/hi so they do
// not shadow the min/max builtins.
func clampF(v, lo, hi float32) float32 {
	switch {
	case !(v >= lo): // true for v < lo and for NaN
		return lo
	case v > hi:
		return hi
	default:
		return v
	}
}
|
|
|
|
// DecodeImage decodes image bytes with EXIF orientation applied.
|
|
// Transparent images are composited onto a white background.
|
|
func DecodeImage(data []byte) (image.Image, error) {
|
|
orientation := readJPEGOrientation(data)
|
|
|
|
img, _, err := image.Decode(bytes.NewReader(data))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
img = flattenAlpha(img)
|
|
return applyOrientation(img, orientation), nil
|
|
}
|
|
|
|
// flattenAlpha composites an image onto a white background, removing any
// transparency. This is needed because image generation models don't handle
// alpha channels well.
//
// Images that report themselves fully opaque through an Opaque method are
// returned unchanged; every standard-library image type has one. Everything
// else is drawn over white into a new *image.RGBA with the same bounds. That
// includes palette PNGs with a transparent entry (*image.Paletted), 16-bit
// PNGs with alpha (*image.NRGBA64), and types without an Opaque method.
func flattenAlpha(img image.Image) image.Image {
	if o, ok := img.(interface{ Opaque() bool }); ok && o.Opaque() {
		return img
	}

	bounds := img.Bounds()
	dst := image.NewRGBA(bounds)

	// Fill with white background.
	draw.Draw(dst, bounds, image.NewUniform(color.White), image.Point{}, draw.Src)

	// Composite the image on top.
	draw.Draw(dst, bounds, img, bounds.Min, draw.Over)

	return dst
}
|
|
|
|
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
|
|
// Returns 1 (normal) for non-JPEG or if orientation not found.
|
|
func readJPEGOrientation(data []byte) int {
|
|
if len(data) < 2 || data[0] != 0xFF || data[1] != 0xD8 {
|
|
return 1 // Not JPEG
|
|
}
|
|
|
|
r := bytes.NewReader(data[2:])
|
|
for {
|
|
var marker [2]byte
|
|
if _, err := r.Read(marker[:]); err != nil || marker[0] != 0xFF {
|
|
return 1
|
|
}
|
|
|
|
if marker[1] == 0xE1 { // APP1 (EXIF)
|
|
var lenBytes [2]byte
|
|
if _, err := r.Read(lenBytes[:]); err != nil {
|
|
return 1
|
|
}
|
|
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
|
|
if segLen < 14 {
|
|
r.Seek(int64(segLen), 1)
|
|
continue
|
|
}
|
|
seg := make([]byte, segLen)
|
|
if _, err := r.Read(seg); err != nil {
|
|
return 1
|
|
}
|
|
if string(seg[:4]) == "Exif" && seg[4] == 0 && seg[5] == 0 {
|
|
return parseTIFFOrientation(seg[6:])
|
|
}
|
|
continue
|
|
}
|
|
|
|
if marker[1] == 0xD9 || marker[1] == 0xDA {
|
|
return 1 // EOI or SOS
|
|
}
|
|
if marker[1] >= 0xD0 && marker[1] <= 0xD7 {
|
|
continue // RST markers
|
|
}
|
|
|
|
var lenBytes [2]byte
|
|
if _, err := r.Read(lenBytes[:]); err != nil {
|
|
return 1
|
|
}
|
|
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
|
|
if segLen > 0 {
|
|
r.Seek(int64(segLen), 1)
|
|
}
|
|
}
|
|
}
|
|
|
|
// parseTIFFOrientation reads the Orientation tag (0x0112) from the first IFD
// of a TIFF-structured EXIF payload, i.e. the bytes after "Exif\0\0".
// It returns the tag value when that value is in 1..8. Otherwise it returns
// 1 (normal), which covers:
//   - a payload too short to hold the header;
//   - an unknown byte-order mark or a bad magic number;
//   - an out-of-range IFD offset;
//   - a missing tag.
func parseTIFFOrientation(tiff []byte) int {
	if len(tiff) < 8 {
		return 1
	}

	// "MM" = big-endian, "II" = little-endian.
	var big bool
	switch string(tiff[:2]) {
	case "MM":
		big = true
	case "II":
		big = false
	default:
		return 1
	}

	u16 := func(b []byte) uint16 {
		if big {
			return uint16(b[0])<<8 | uint16(b[1])
		}
		return uint16(b[1])<<8 | uint16(b[0])
	}
	u32 := func(b []byte) uint32 {
		if big {
			return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
		}
		return uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
	}

	if u16(tiff[2:4]) != 42 {
		return 1
	}

	// Offsets come from untrusted input, so the arithmetic is done in int64.
	// On 32-bit platforms, int() of a uint32 offset >= 2^31 is negative. Such
	// an offset would pass the bounds check and then panic when slicing.
	size := int64(len(tiff))
	ifdOffset := int64(u32(tiff[4:8]))
	if ifdOffset+2 > size {
		return 1
	}

	// Each IFD entry is 12 bytes: tag(2) type(2) count(4) value/offset(4).
	// The orientation value is read from the first 2 bytes of the value field.
	numEntries := int64(u16(tiff[ifdOffset : ifdOffset+2]))
	for i := int64(0); i < numEntries; i++ {
		entry := ifdOffset + 2 + i*12
		if entry+12 > size {
			break
		}
		if u16(tiff[entry:entry+2]) == 0x0112 { // Orientation tag
			if o := int(u16(tiff[entry+8 : entry+10])); o >= 1 && o <= 8 {
				return o
			}
			return 1
		}
	}
	return 1
}
|
|
|
|
// applyOrientation returns img transformed according to an EXIF orientation
// value. Values outside 2..8 return img itself, unchanged. Otherwise a new
// *image.RGBA anchored at (0, 0) is returned; orientations 5..8 swap its
// width and height.
func applyOrientation(img image.Image, orientation int) image.Image {
	if orientation < 2 || orientation > 8 {
		return img
	}

	b := img.Bounds()
	srcW, srcH := b.Dx(), b.Dy()

	// dest maps a source pixel (x, y), relative to b.Min, to its output position.
	var dest func(x, y int) (int, int)
	switch orientation {
	case 2: // mirror horizontally
		dest = func(x, y int) (int, int) { return srcW - 1 - x, y }
	case 3: // rotate 180
		dest = func(x, y int) (int, int) { return srcW - 1 - x, srcH - 1 - y }
	case 4: // mirror vertically
		dest = func(x, y int) (int, int) { return x, srcH - 1 - y }
	case 5: // transpose
		dest = func(x, y int) (int, int) { return y, x }
	case 6: // rotate 90 clockwise
		dest = func(x, y int) (int, int) { return srcH - 1 - y, x }
	case 7: // transverse
		dest = func(x, y int) (int, int) { return srcH - 1 - y, srcW - 1 - x }
	default: // 8: rotate 90 counter-clockwise
		dest = func(x, y int) (int, int) { return y, srcW - 1 - x }
	}

	dstW, dstH := srcW, srcH
	if orientation >= 5 {
		dstW, dstH = srcH, srcW
	}

	out := image.NewRGBA(image.Rect(0, 0, dstW, dstH))
	for y := 0; y < srcH; y++ {
		for x := 0; x < srcW; x++ {
			dx, dy := dest(x, y)
			out.Set(dx, dy, img.At(b.Min.X+x, b.Min.Y+y))
		}
	}
	return out
}
|