mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* prefer rocm v6 on windows Avoid building with v7 - more changes are needed * MLX: add header vendoring and remove go build tag This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified. * ci: harden for flaky choco repo servers CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without. * review comments
358 lines
9.2 KiB
Go
358 lines
9.2 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/cache"
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
|
)
|
|
|
|
// Dedicated stream for generation (like mlx-lm's generation_stream)
|
|
var generationStream *mlx.Stream
|
|
|
|
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
// This handles tokenizers that split a multi-byte character across tokens.
// The zero value is ready to use.
type utf8Streamer struct {
	// buffer holds bytes received by Write that have not been emitted yet,
	// i.e. the truncated tail of a multi-byte sequence.
	buffer []byte
}

// Write appends text to the buffer and returns the longest prefix that does
// not end in the middle of a UTF-8 sequence. It returns "" when nothing can
// be emitted yet.
//
// Bytes that can never become valid UTF-8 are passed through unchanged, one
// byte at a time; only a sequence that is valid so far but truncated is held
// back for the next call (or for Flush).
func (s *utf8Streamer) Write(text string) string {
	s.buffer = append(s.buffer, text...)

	validLen := 0
	for validLen < len(s.buffer) {
		rest := s.buffer[validLen:]
		// FullRune is false only for a truncated-but-still-plausible
		// sequence. Invalid encodings count as "full" (they decode as a
		// width-1 error rune), so a stray byte never stalls the stream.
		if !utf8.FullRune(rest) {
			break
		}
		_, size := utf8.DecodeRune(rest)
		validLen += size
	}

	if validLen == 0 {
		return ""
	}

	result := string(s.buffer[:validLen])
	// Slide the pending tail to the front so the backing array is reused
	// instead of creeping forward on every call.
	n := copy(s.buffer, s.buffer[validLen:])
	s.buffer = s.buffer[:n]
	return result
}

// Flush returns any remaining buffered bytes (which may be an incomplete
// UTF-8 sequence) and resets the streamer.
func (s *utf8Streamer) Flush() string {
	if len(s.buffer) == 0 {
		return ""
	}
	result := string(s.buffer)
	s.buffer = nil
	return result
}
|
|
|
|
// withStream runs fn with the generation stream as default
|
|
func withStream(fn func()) {
|
|
// Lazy initialization of generationStream
|
|
if generationStream == nil {
|
|
generationStream = mlx.NewStream()
|
|
}
|
|
orig := mlx.GetDefaultStream()
|
|
mlx.SetDefaultStream(generationStream)
|
|
fn()
|
|
mlx.SetDefaultStream(orig)
|
|
}
|
|
|
|
type Model interface {
|
|
Tokenizer() *tokenizer.Tokenizer
|
|
VocabSize() int32
|
|
NewCache(maxSeqLen int32) []cache.Cache
|
|
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
|
|
}
|
|
|
|
// ChatModel is an optional interface for models that support chat formatting.
// generate applies it to the prompt before tokenizing when no image is used.
type ChatModel interface {
	// FormatPrompt wraps the raw user prompt in the model's chat template
	// and returns the text to tokenize.
	FormatPrompt(prompt string) string
}
|
|
|
|
// MultimodalModel is for models that support image input
|
|
type MultimodalModel interface {
|
|
Model
|
|
FormatPromptWithImage(prompt string) string
|
|
ExpandImageTokens(tokens []int32) []int32
|
|
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
|
|
ImageSize() int32 // Returns expected image size for preprocessing
|
|
}
|
|
|
|
// ImageLoader loads and preprocesses an image for multimodal models
|
|
// Returns nil if path is empty
|
|
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
|
|
|
|
type input struct {
|
|
Prompt string
|
|
Image *mlx.Array // Optional preprocessed image for multimodal models
|
|
MaxTokens int
|
|
Temperature float32
|
|
TopP float32
|
|
TopK int
|
|
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
|
}
|
|
|
|
// output is one event delivered to the generate callback: either a piece of
// decoded text, or the final event with Done set and throughput figures.
type output struct {
	// Text is newly decoded text; empty on the final event.
	Text string
	// Done marks the final event of a generation.
	Done bool
	// PrefillTokSec and GenTokSec are prompt-processing and generation
	// throughput in tokens per second; set only on the final event.
	PrefillTokSec, GenTokSec float64
}
|
|
|
|
// Decoder wraps model + cache for autoregressive generation.
|
|
type Decoder struct {
|
|
model Model
|
|
caches []cache.Cache
|
|
vocabSize int32
|
|
temp float32
|
|
topK int
|
|
topP float32
|
|
token *mlx.Array // Current token (kept across pools)
|
|
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
|
image *mlx.Array // Optional image for multimodal prefill
|
|
}
|
|
|
|
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
|
caches := m.NewCache(0)
|
|
return &Decoder{
|
|
model: m,
|
|
caches: caches,
|
|
vocabSize: m.VocabSize(),
|
|
temp: temp,
|
|
topK: topK,
|
|
topP: topP,
|
|
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
|
|
}
|
|
}
|
|
|
|
// SetImage sets the image for multimodal prefill (call before prefill)
|
|
func (d *Decoder) SetImage(img *mlx.Array) {
|
|
d.image = img
|
|
}
|
|
|
|
func (d *Decoder) prefill(inputIDs []int32) int {
|
|
processed := 0
|
|
|
|
// Track old cache state to free after each chunk
|
|
var oldCacheState []*mlx.Array
|
|
|
|
// For multimodal models with an image, we need to process all tokens together
|
|
// in the first forward pass so the image embeddings can be inserted properly.
|
|
// Skip chunking for multimodal prefill.
|
|
isMultimodal := d.image != nil
|
|
|
|
// Process all-but-1 tokens in chunks, eval cache state for memory management
|
|
// Skip chunking for multimodal - process everything in the final step
|
|
if !isMultimodal {
|
|
for len(inputIDs) > 1 {
|
|
chunkSize := min(2048, len(inputIDs)-1)
|
|
if chunkSize <= 0 {
|
|
break
|
|
}
|
|
chunk := inputIDs[:chunkSize]
|
|
|
|
// Save old cache state before forward
|
|
oldCacheState = oldCacheState[:0]
|
|
for _, c := range d.caches {
|
|
oldCacheState = append(oldCacheState, c.State()...)
|
|
}
|
|
|
|
var cacheState []*mlx.Array
|
|
withStream(func() {
|
|
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
|
|
d.model.Forward(x, d.caches)
|
|
for _, c := range d.caches {
|
|
cacheState = append(cacheState, c.State()...)
|
|
}
|
|
})
|
|
mlx.Eval(cacheState...)
|
|
|
|
// Free old cache state
|
|
for _, arr := range oldCacheState {
|
|
if arr != nil {
|
|
arr.Free()
|
|
}
|
|
}
|
|
|
|
inputIDs = inputIDs[chunkSize:]
|
|
processed += chunkSize
|
|
}
|
|
}
|
|
|
|
// Save old cache state before final step
|
|
oldCacheState = oldCacheState[:0]
|
|
for _, c := range d.caches {
|
|
oldCacheState = append(oldCacheState, c.State()...)
|
|
}
|
|
|
|
// Final token + sampling (or all tokens for multimodal)
|
|
withStream(func() {
|
|
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
|
|
mlx.Eval(x) // Materialize before any other evals
|
|
|
|
var logits *mlx.Array
|
|
// Use ForwardWithImage if we have an image and model supports it
|
|
if d.image != nil {
|
|
if mm, ok := d.model.(MultimodalModel); ok {
|
|
logits = mm.ForwardWithImage(x, d.image, d.caches)
|
|
d.image = nil // Only use image for first forward
|
|
} else {
|
|
logits = d.model.Forward(x, d.caches)
|
|
}
|
|
} else {
|
|
logits = d.model.Forward(x, d.caches)
|
|
}
|
|
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
|
})
|
|
// Keep cache state (token auto-kept by AsyncEval)
|
|
for _, c := range d.caches {
|
|
mlx.Keep(c.State()...)
|
|
}
|
|
mlx.AsyncEval(d.token)
|
|
|
|
// Free old cache state from before final step
|
|
for _, arr := range oldCacheState {
|
|
if arr != nil {
|
|
arr.Free()
|
|
}
|
|
}
|
|
|
|
mlx.ClearCache()
|
|
|
|
return processed + len(inputIDs)
|
|
}
|
|
|
|
func (d *Decoder) step() int32 {
|
|
prevToken := d.token
|
|
|
|
// Save old cache state (reuse preallocated slice)
|
|
d.oldCacheState = d.oldCacheState[:0]
|
|
for _, c := range d.caches {
|
|
d.oldCacheState = append(d.oldCacheState, c.State()...)
|
|
}
|
|
|
|
withStream(func() {
|
|
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
|
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
|
})
|
|
// Keep token and new cache state so they survive cleanup
|
|
mlx.Keep(d.token)
|
|
for _, c := range d.caches {
|
|
mlx.Keep(c.State()...)
|
|
}
|
|
mlx.AsyncEval(d.token)
|
|
|
|
// Sync on previous token (GPU already working on next step)
|
|
val := prevToken.ItemInt32()
|
|
|
|
// Free old token and old cache state
|
|
prevToken.Free()
|
|
for _, arr := range d.oldCacheState {
|
|
arr.Free()
|
|
}
|
|
return val
|
|
}
|
|
|
|
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
|
mlx.EnableCompile()
|
|
wiredLimit := in.WiredLimitGB
|
|
if wiredLimit <= 0 {
|
|
wiredLimit = 32 // default 32GB
|
|
}
|
|
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
|
|
|
|
temp := in.Temperature
|
|
if temp < 0 {
|
|
temp = 0.7
|
|
}
|
|
|
|
tok := m.Tokenizer()
|
|
dec := NewDecoder(m, temp, in.TopK, in.TopP)
|
|
|
|
// Apply chat template - use image template if we have an image
|
|
prompt := in.Prompt
|
|
var tokens []int32
|
|
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
|
|
prompt = mm.FormatPromptWithImage(prompt)
|
|
tokens = tok.Encode(prompt, true)
|
|
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
|
|
dec.SetImage(in.Image)
|
|
} else if cm, ok := m.(ChatModel); ok {
|
|
prompt = cm.FormatPrompt(prompt)
|
|
tokens = tok.Encode(prompt, true)
|
|
} else {
|
|
tokens = tok.Encode(prompt, true)
|
|
}
|
|
|
|
prefillStart := time.Now()
|
|
prefillTokens := dec.prefill(tokens)
|
|
// Prefill measurement should include time to first token (like mlx-lm)
|
|
// Step() waits for prefill to complete and returns first token
|
|
firstToken := dec.step()
|
|
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
|
|
|
|
genStart := time.Now()
|
|
maxTokens := max(in.MaxTokens, 100)
|
|
var genTokens int
|
|
|
|
// UTF-8 streamer to handle partial multi-byte characters
|
|
streamer := &utf8Streamer{}
|
|
|
|
// Handle first token
|
|
genTokens++
|
|
if tok.IsEOS(firstToken) {
|
|
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
|
|
return nil
|
|
}
|
|
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
|
|
cb(output{Text: text})
|
|
}
|
|
|
|
for n := 1; n < maxTokens; n++ {
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
token := dec.step()
|
|
genTokens++
|
|
|
|
if tok.IsEOS(token) {
|
|
break
|
|
}
|
|
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
|
|
cb(output{Text: text})
|
|
}
|
|
|
|
if n%256 == 0 {
|
|
mlx.ClearCache()
|
|
}
|
|
}
|
|
|
|
// Flush any remaining buffered bytes
|
|
if text := streamer.Flush(); text != "" {
|
|
cb(output{Text: text})
|
|
}
|
|
|
|
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
|
|
cb(output{Done: true, PrefillTokSec: prefillTokSec,
|
|
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
|
|
return nil
|
|
}
|