x/grammar: add experimental GPU accelerated constrained decoding package

This commit is contained in:
jmorganca
2026-01-10 16:42:45 -08:00
parent 7cc2a653f2
commit e23ddd84b8
38 changed files with 5819 additions and 36 deletions

View File

@@ -8,6 +8,7 @@ import (
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -109,7 +110,11 @@ type input struct {
Temperature float32
TopP float32
TopK int
WiredLimitGB int // Metal wired memory limit in GB (default 32)
WiredLimitGB int // Metal wired memory limit in GB (default 32)
JSONMode bool // Enable JSON grammar constraint
GrammarEBNF string // Raw EBNF grammar string
GrammarStart string // Start rule name for grammar
Vocab []string // Vocabulary for constrained decoding
}
type output struct {
@@ -127,9 +132,11 @@ type Decoder struct {
temp float32
topK int
topP float32
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
grammar *grammar.Engine // Optional grammar constraint engine
grammarVocab []string // Vocab for grammar debug
}
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
@@ -145,6 +152,12 @@ func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
}
}
// SetGrammar enables constrained decoding with the given grammar engine.
func (d *Decoder) SetGrammar(g *grammar.Engine, vocab []string) {
d.grammar = g
d.grammarVocab = vocab
}
// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
d.image = img
@@ -222,6 +235,16 @@ func (d *Decoder) prefill(inputIDs []int32) int {
} else {
logits = d.model.Forward(x, d.caches)
}
// Apply grammar constraints if enabled
if d.grammar != nil {
shape := logits.Shape()
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
maskedLogits := d.grammar.ApplyMask(lastLogits)
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
@@ -245,6 +268,15 @@ func (d *Decoder) prefill(inputIDs []int32) int {
func (d *Decoder) step() int32 {
prevToken := d.token
// Sync on previous token FIRST to get its value and update grammar state
// This must happen before computing the next mask
val := prevToken.ItemInt32()
// Update grammar state with the token we just synced
if d.grammar != nil {
d.grammar.Accept(int(val))
}
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
@@ -253,6 +285,18 @@ func (d *Decoder) step() int32 {
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
// Apply grammar constraints if enabled
if d.grammar != nil {
// Get last position logits: [1, 1, vocab] -> [vocab]
shape := logits.Shape()
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
maskedLogits := d.grammar.ApplyMask(lastLogits)
// Reshape back to [1, 1, vocab] for sample()
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
@@ -262,9 +306,6 @@ func (d *Decoder) step() int32 {
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
@@ -289,6 +330,48 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
tok := m.Tokenizer()
dec := NewDecoder(m, temp, in.TopK, in.TopP)
// Set up grammar constraint if enabled
var grammarEngine *grammar.Engine
var grammarVocab []string
if (in.JSONMode || in.GrammarEBNF != "") && len(in.Vocab) > 0 {
var compiled *grammar.Grammar
var err error
if in.GrammarEBNF != "" {
// Custom EBNF grammar
startRule := in.GrammarStart
if startRule == "" {
startRule = "root"
}
compiled, err = grammar.ParseEBNF(in.GrammarEBNF, startRule)
if err != nil {
return fmt.Errorf("failed to parse grammar: %w", err)
}
fmt.Printf("[Grammar mode: start=%s]\n", startRule)
} else {
// JSON object grammar (only allows objects at top level)
compiled, err = grammar.JSONObjectGrammar()
if err != nil {
return fmt.Errorf("failed to create JSON grammar: %w", err)
}
fmt.Println("[JSON object mode enabled]")
}
// Pad vocab to match model's vocab size if needed
grammarVocab = in.Vocab
modelVocabSize := int(m.VocabSize())
if len(grammarVocab) < modelVocabSize {
padded := make([]string, modelVocabSize)
copy(padded, grammarVocab)
grammarVocab = padded
}
grammarEngine, err = grammar.NewEngine(compiled, grammarVocab)
if err != nil {
return fmt.Errorf("failed to create grammar engine: %w", err)
}
defer grammarEngine.Close()
}
// Apply chat template - use image template if we have an image
prompt := in.Prompt
var tokens []int32
@@ -304,6 +387,10 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
tokens = tok.Encode(prompt, true)
}
if grammarEngine != nil {
dec.SetGrammar(grammarEngine, grammarVocab)
}
prefillStart := time.Now()
prefillTokens := dec.prefill(tokens)
// Prefill measurement should include time to first token (like mlx-lm)
@@ -327,6 +414,11 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
cb(output{Text: text})
}
// Check if grammar is complete after first token
if dec.grammar != nil && dec.grammar.IsComplete() {
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
return nil
}
for n := 1; n < maxTokens; n++ {
if ctx.Err() != nil {
@@ -341,6 +433,10 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
cb(output{Text: text})
}
// Check if grammar is complete (valid JSON document finished)
if dec.grammar != nil && dec.grammar.IsComplete() {
break
}
if n%256 == 0 {
mlx.ClearCache()

View File

@@ -44,6 +44,9 @@ func main() {
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
topK := flag.Int("top-k", 40, "Top-k sampling")
imagePath := flag.String("image", "", "Image path for multimodal models")
jsonMode := flag.Bool("json", false, "Enable JSON grammar constraint (output will be valid JSON)")
grammarFile := flag.String("grammar", "", "Path to EBNF grammar file for constrained decoding")
grammarStart := flag.String("grammar-start", "root", "Start rule name for grammar (default: root)")
// Image generation params
width := flag.Int("width", 1024, "Image width")
@@ -186,6 +189,20 @@ func main() {
}
}
// Get vocab for constrained decoding if needed
var vocab []string
var grammarEBNF string
if *jsonMode || *grammarFile != "" {
vocab = m.Tokenizer().Vocab()
}
if *grammarFile != "" {
data, err := os.ReadFile(*grammarFile)
if err != nil {
log.Fatalf("failed to read grammar file: %v", err)
}
grammarEBNF = string(data)
}
err = generate(context.Background(), m, input{
Prompt: *prompt,
Image: image,
@@ -194,6 +211,10 @@ func main() {
TopP: float32(*topP),
TopK: *topK,
WiredLimitGB: *wiredLimitGB,
JSONMode: *jsonMode,
GrammarEBNF: grammarEBNF,
GrammarStart: *grammarStart,
Vocab: vocab,
}, func(out output) {
if out.Text != "" {
fmt.Print(out.Text)