mirror of
https://github.com/ollama/ollama.git
synced 2026-04-19 15:54:21 +02:00
x/grammar: add experimental GPU accelerated constrained decoding package
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -109,7 +110,11 @@ type input struct {
|
||||
Temperature float32
|
||||
TopP float32
|
||||
TopK int
|
||||
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
||||
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
||||
JSONMode bool // Enable JSON grammar constraint
|
||||
GrammarEBNF string // Raw EBNF grammar string
|
||||
GrammarStart string // Start rule name for grammar
|
||||
Vocab []string // Vocabulary for constrained decoding
|
||||
}
|
||||
|
||||
type output struct {
|
||||
@@ -127,9 +132,11 @@ type Decoder struct {
|
||||
temp float32
|
||||
topK int
|
||||
topP float32
|
||||
token *mlx.Array // Current token (kept across pools)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
image *mlx.Array // Optional image for multimodal prefill
|
||||
token *mlx.Array // Current token (kept across pools)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
image *mlx.Array // Optional image for multimodal prefill
|
||||
grammar *grammar.Engine // Optional grammar constraint engine
|
||||
grammarVocab []string // Vocab for grammar debug
|
||||
}
|
||||
|
||||
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
||||
@@ -145,6 +152,12 @@ func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
||||
}
|
||||
}
|
||||
|
||||
// SetGrammar enables constrained decoding with the given grammar engine.
// The engine g and the vocab slice are stored on the decoder as-is (vocab is
// not copied). prefill and step apply the grammar mask to the logits only
// while the stored engine is non-nil.
|
||||
func (d *Decoder) SetGrammar(g *grammar.Engine, vocab []string) {
|
||||
d.grammar = g
|
||||
d.grammarVocab = vocab
|
||||
}
|
||||
|
||||
// SetImage sets the image for multimodal prefill (call before prefill)
|
||||
func (d *Decoder) SetImage(img *mlx.Array) {
|
||||
d.image = img
|
||||
@@ -222,6 +235,16 @@ func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
|
||||
// Apply grammar constraints if enabled
|
||||
if d.grammar != nil {
|
||||
shape := logits.Shape()
|
||||
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
|
||||
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
|
||||
maskedLogits := d.grammar.ApplyMask(lastLogits)
|
||||
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
|
||||
}
|
||||
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep cache state (token auto-kept by AsyncEval)
|
||||
@@ -245,6 +268,15 @@ func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
func (d *Decoder) step() int32 {
|
||||
prevToken := d.token
|
||||
|
||||
// Sync on previous token FIRST to get its value and update grammar state
|
||||
// This must happen before computing the next mask
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Update grammar state with the token we just synced
|
||||
if d.grammar != nil {
|
||||
d.grammar.Accept(int(val))
|
||||
}
|
||||
|
||||
// Save old cache state (reuse preallocated slice)
|
||||
d.oldCacheState = d.oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
@@ -253,6 +285,18 @@ func (d *Decoder) step() int32 {
|
||||
|
||||
withStream(func() {
|
||||
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
||||
|
||||
// Apply grammar constraints if enabled
|
||||
if d.grammar != nil {
|
||||
// Get last position logits: [1, 1, vocab] -> [vocab]
|
||||
shape := logits.Shape()
|
||||
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
|
||||
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
|
||||
maskedLogits := d.grammar.ApplyMask(lastLogits)
|
||||
// Reshape back to [1, 1, vocab] for sample()
|
||||
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
|
||||
}
|
||||
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep token and new cache state so they survive cleanup
|
||||
@@ -262,9 +306,6 @@ func (d *Decoder) step() int32 {
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Sync on previous token (GPU already working on next step)
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Free old token and old cache state
|
||||
prevToken.Free()
|
||||
for _, arr := range d.oldCacheState {
|
||||
@@ -289,6 +330,48 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
tok := m.Tokenizer()
|
||||
dec := NewDecoder(m, temp, in.TopK, in.TopP)
|
||||
|
||||
// Set up grammar constraint if enabled
|
||||
var grammarEngine *grammar.Engine
|
||||
var grammarVocab []string
|
||||
if (in.JSONMode || in.GrammarEBNF != "") && len(in.Vocab) > 0 {
|
||||
var compiled *grammar.Grammar
|
||||
var err error
|
||||
|
||||
if in.GrammarEBNF != "" {
|
||||
// Custom EBNF grammar
|
||||
startRule := in.GrammarStart
|
||||
if startRule == "" {
|
||||
startRule = "root"
|
||||
}
|
||||
compiled, err = grammar.ParseEBNF(in.GrammarEBNF, startRule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse grammar: %w", err)
|
||||
}
|
||||
fmt.Printf("[Grammar mode: start=%s]\n", startRule)
|
||||
} else {
|
||||
// JSON object grammar (only allows objects at top level)
|
||||
compiled, err = grammar.JSONObjectGrammar()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create JSON grammar: %w", err)
|
||||
}
|
||||
fmt.Println("[JSON object mode enabled]")
|
||||
}
|
||||
|
||||
// Pad vocab with empty strings up to the model's vocab size if it is shorter
|
||||
grammarVocab = in.Vocab
|
||||
modelVocabSize := int(m.VocabSize())
|
||||
if len(grammarVocab) < modelVocabSize {
|
||||
padded := make([]string, modelVocabSize)
|
||||
copy(padded, grammarVocab)
|
||||
grammarVocab = padded
|
||||
}
|
||||
grammarEngine, err = grammar.NewEngine(compiled, grammarVocab)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create grammar engine: %w", err)
|
||||
}
|
||||
defer grammarEngine.Close()
|
||||
}
|
||||
|
||||
// Apply chat template - use image template if we have an image
|
||||
prompt := in.Prompt
|
||||
var tokens []int32
|
||||
@@ -304,6 +387,10 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
tokens = tok.Encode(prompt, true)
|
||||
}
|
||||
|
||||
if grammarEngine != nil {
|
||||
dec.SetGrammar(grammarEngine, grammarVocab)
|
||||
}
|
||||
|
||||
prefillStart := time.Now()
|
||||
prefillTokens := dec.prefill(tokens)
|
||||
// Prefill measurement should include time to first token (like mlx-lm)
|
||||
@@ -327,6 +414,11 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
// Check if grammar is complete after first token
|
||||
if dec.grammar != nil && dec.grammar.IsComplete() {
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
|
||||
return nil
|
||||
}
|
||||
|
||||
for n := 1; n < maxTokens; n++ {
|
||||
if ctx.Err() != nil {
|
||||
@@ -341,6 +433,10 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
// Check if grammar is complete (the constrained output, e.g. a JSON document, is finished)
|
||||
if dec.grammar != nil && dec.grammar.IsComplete() {
|
||||
break
|
||||
}
|
||||
|
||||
if n%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
|
||||
@@ -44,6 +44,9 @@ func main() {
|
||||
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
|
||||
topK := flag.Int("top-k", 40, "Top-k sampling")
|
||||
imagePath := flag.String("image", "", "Image path for multimodal models")
|
||||
jsonMode := flag.Bool("json", false, "Enable JSON grammar constraint (output will be valid JSON)")
|
||||
grammarFile := flag.String("grammar", "", "Path to EBNF grammar file for constrained decoding")
|
||||
grammarStart := flag.String("grammar-start", "root", "Start rule name for grammar (default: root)")
|
||||
|
||||
// Image generation params
|
||||
width := flag.Int("width", 1024, "Image width")
|
||||
@@ -186,6 +189,20 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// Get vocab for constrained decoding if needed
|
||||
var vocab []string
|
||||
var grammarEBNF string
|
||||
if *jsonMode || *grammarFile != "" {
|
||||
vocab = m.Tokenizer().Vocab()
|
||||
}
|
||||
if *grammarFile != "" {
|
||||
data, err := os.ReadFile(*grammarFile)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to read grammar file: %v", err)
|
||||
}
|
||||
grammarEBNF = string(data)
|
||||
}
|
||||
|
||||
err = generate(context.Background(), m, input{
|
||||
Prompt: *prompt,
|
||||
Image: image,
|
||||
@@ -194,6 +211,10 @@ func main() {
|
||||
TopP: float32(*topP),
|
||||
TopK: *topK,
|
||||
WiredLimitGB: *wiredLimitGB,
|
||||
JSONMode: *jsonMode,
|
||||
GrammarEBNF: grammarEBNF,
|
||||
GrammarStart: *grammarStart,
|
||||
Vocab: vocab,
|
||||
}, func(out output) {
|
||||
if out.Text != "" {
|
||||
fmt.Print(out.Text)
|
||||
|
||||
Reference in New Issue
Block a user