mirror of
https://github.com/ollama/ollama.git
synced 2026-04-19 08:54:21 +02:00
236 lines
5.7 KiB
Go
236 lines
5.7 KiB
Go
//go:build mlx
|
|
|
|
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/llama"
|
|
"github.com/ollama/ollama/x/grammar"
|
|
"github.com/ollama/ollama/x/grammar/schema"
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// jsonGBNF is the built-in GBNF grammar describing JSON values (objects,
// arrays, strings, numbers and the true/false/null literals). main hands it to
// llama.NewGrammar unless the -gbnf flag names a grammar file to use instead.
const jsonGBNF = `
root ::= value
value ::= object | array | string | number | "true" | "false" | "null"
object ::= "{" ws "}" | "{" members "}"
members ::= member ("," member)*
member ::= ws string ws ":" element
array ::= "[" ws "]" | "[" elements "]"
elements ::= element ("," element)*
element ::= ws value ws
string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]
number ::= "-"? integer fraction? exponent?
integer ::= "0" | [1-9] [0-9]*
fraction ::= "." [0-9]+
exponent ::= [eE] [+-]? [0-9]+
ws ::= [ \t\n\r]*
`
|
|
|
|
type result struct {
|
|
vocabSize int `json:"vocab_size"`
|
|
Iterations int `json:"iterations"`
|
|
Warmup int `json:"warmup"`
|
|
ConstrainedSource string `json:"constrained_source"`
|
|
LlamaSource string `json:"llama_source"`
|
|
LlamaApply string `json:"llama_apply"`
|
|
ConstrainedGraph string `json:"constrained_graph"`
|
|
ConstrainedWithEval string `json:"constrained_with_eval,omitempty"`
|
|
EvalOnly string `json:"eval_only,omitempty"`
|
|
ConstrainedEvalNet string `json:"constrained_eval_net,omitempty"`
|
|
}
|
|
|
|
func main() {
|
|
var (
|
|
vocabSize = flag.Int("vocab-size", 128000, "Vocabulary size")
|
|
iterations = flag.Int("iterations", 500, "Benchmark iterations")
|
|
warmup = flag.Int("warmup", 50, "Warmup iterations")
|
|
withEval = flag.Bool("eval", true, "Measure ApplyMask with mlx.Eval")
|
|
gbnfPath = flag.String("gbnf", "", "GBNF grammar file for llama.cpp")
|
|
schemaPath = flag.String("schema", "", "JSON Schema file for grammar constraints")
|
|
ebnfPath = flag.String("ebnf", "", "EBNF grammar file for grammar constraints")
|
|
startRule = flag.String("start", "root", "Start rule for EBNF")
|
|
)
|
|
flag.Parse()
|
|
|
|
if *vocabSize <= 0 || *iterations <= 0 || *warmup < 0 {
|
|
fmt.Fprintln(os.Stderr, "invalid flags")
|
|
os.Exit(2)
|
|
}
|
|
|
|
vocab := createVocab(*vocabSize)
|
|
|
|
if *schemaPath != "" && *ebnfPath != "" {
|
|
fmt.Fprintln(os.Stderr, "only one of -schema or -ebnf may be set")
|
|
os.Exit(2)
|
|
}
|
|
|
|
var constrainedSource string
|
|
var compiled *grammar.Grammar
|
|
var err error
|
|
switch {
|
|
case *schemaPath != "":
|
|
data, readErr := os.ReadFile(*schemaPath)
|
|
if readErr != nil {
|
|
fmt.Fprintf(os.Stderr, "read schema: %v\n", readErr)
|
|
os.Exit(1)
|
|
}
|
|
compiled, err = schema.Grammar(string(data))
|
|
constrainedSource = "schema:" + *schemaPath
|
|
case *ebnfPath != "":
|
|
data, readErr := os.ReadFile(*ebnfPath)
|
|
if readErr != nil {
|
|
fmt.Fprintf(os.Stderr, "read ebnf: %v\n", readErr)
|
|
os.Exit(1)
|
|
}
|
|
compiled, err = grammar.ParseEBNF(string(data), *startRule)
|
|
constrainedSource = "ebnf:" + *ebnfPath
|
|
default:
|
|
compiled, err = grammar.JSONGrammar()
|
|
constrainedSource = "json"
|
|
}
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "grammar: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
engine, err := grammar.NewEngine(compiled, vocab)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "engine: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer engine.Close()
|
|
|
|
logits := mlx.Ones(int32(*vocabSize))
|
|
mlx.Keep(logits)
|
|
|
|
for i := 0; i < *warmup; i++ {
|
|
masked := engine.ApplyMask(logits)
|
|
if *withEval {
|
|
mlx.Eval(masked)
|
|
}
|
|
}
|
|
|
|
graphAvg := measure(*iterations, func() {
|
|
_ = engine.ApplyMask(logits)
|
|
})
|
|
|
|
var evalAvg time.Duration
|
|
var evalOnlyAvg time.Duration
|
|
if *withEval {
|
|
evalOnlyAvg = measure(*iterations, func() {
|
|
baseline := mlx.MulScalar(logits, 1)
|
|
mlx.Eval(baseline)
|
|
baseline.Free()
|
|
})
|
|
|
|
evalAvg = measure(*iterations, func() {
|
|
masked := engine.ApplyMask(logits)
|
|
mlx.Eval(masked)
|
|
})
|
|
}
|
|
|
|
vocabIDs := make([]uint32, *vocabSize)
|
|
for i := range vocabIDs {
|
|
vocabIDs[i] = uint32(i)
|
|
}
|
|
eogTokens := []int32{0}
|
|
|
|
gbnf := jsonGBNF
|
|
llamaSource := "json"
|
|
if *gbnfPath != "" {
|
|
data, readErr := os.ReadFile(*gbnfPath)
|
|
if readErr != nil {
|
|
fmt.Fprintf(os.Stderr, "read gbnf: %v\n", readErr)
|
|
os.Exit(1)
|
|
}
|
|
gbnf = string(data)
|
|
llamaSource = *gbnfPath
|
|
}
|
|
|
|
llamaGrammar := llama.NewGrammar(gbnf, vocabIDs, vocab, eogTokens)
|
|
if llamaGrammar == nil {
|
|
fmt.Fprintln(os.Stderr, "llama grammar initialization failed")
|
|
os.Exit(1)
|
|
}
|
|
defer llamaGrammar.Free()
|
|
|
|
llamaTokens := make([]llama.TokenData, *vocabSize)
|
|
|
|
for i := 0; i < *warmup; i++ {
|
|
for j := range llamaTokens {
|
|
llamaTokens[j].Logit = 1.0
|
|
}
|
|
llamaGrammar.Apply(llamaTokens)
|
|
}
|
|
|
|
llamaAvg := measure(*iterations, func() {
|
|
for j := range llamaTokens {
|
|
llamaTokens[j].Logit = 1.0
|
|
}
|
|
llamaGrammar.Apply(llamaTokens)
|
|
})
|
|
|
|
out := result{
|
|
vocabSize: *vocabSize,
|
|
Iterations: *iterations,
|
|
Warmup: *warmup,
|
|
LlamaApply: llamaAvg.String(),
|
|
ConstrainedGraph: graphAvg.String(),
|
|
ConstrainedSource: constrainedSource,
|
|
LlamaSource: llamaSource,
|
|
}
|
|
if *withEval {
|
|
out.ConstrainedWithEval = evalAvg.String()
|
|
out.EvalOnly = evalOnlyAvg.String()
|
|
if evalAvg > evalOnlyAvg {
|
|
out.ConstrainedEvalNet = (evalAvg - evalOnlyAvg).String()
|
|
} else {
|
|
out.ConstrainedEvalNet = "0s"
|
|
}
|
|
}
|
|
|
|
enc := json.NewEncoder(os.Stdout)
|
|
if err := enc.Encode(out); err != nil {
|
|
fmt.Fprintf(os.Stderr, "encode: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// measure calls fn iterations times back to back and returns the average
// wall-clock time of one call.
//
// When iterations is zero or negative, fn is never called and 0 is returned;
// without this guard the averaging division would panic on a zero count.
func measure(iterations int, fn func()) time.Duration {
	if iterations <= 0 {
		return 0
	}

	start := time.Now()
	for i := 0; i < iterations; i++ {
		fn()
	}
	return time.Since(start) / time.Duration(iterations)
}
|
|
|
|
// createVocab returns a synthetic vocabulary with size entries. The lowest
// token ids hold the JSON punctuation, literal and whitespace tokens (cut
// short when size is smaller than that list); every remaining id i is filled
// with the placeholder "tok<i>".
func createVocab(size int) []string {
	structural := [...]string{
		"{", "}", "[", "]", ":", ",",
		"true", "false", "null",
		" ", "\n", "\t", "\r",
		"\"",
	}

	tokens := make([]string, size)

	// copy stops at the shorter of the two slices, so a small size simply
	// receives a prefix of the structural tokens.
	next := copy(tokens, structural[:])
	for id := next; id < size; id++ {
		tokens[id] = fmt.Sprintf("tok%d", id)
	}
	return tokens
}
|