mirror of
https://github.com/ollama/ollama.git
synced 2026-04-27 19:25:55 +02:00
tests: basic benchmarking test framework (#12964)
This change adds a basic benchmarking test framework for Ollama which can be used to determine the prefill, eval, load duration, and total duration for running a given model or models.
This commit is contained in:
309
cmd/bench/bench.go
Normal file
309
cmd/bench/bench.go
Normal file
@@ -0,0 +1,309 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type flagOptions struct {
|
||||
models *string
|
||||
epochs *int
|
||||
maxTokens *int
|
||||
temperature *float64
|
||||
seed *int
|
||||
timeout *int
|
||||
prompt *string
|
||||
imageFile *string
|
||||
keepAlive *float64
|
||||
format *string
|
||||
outputFile *string
|
||||
debug *bool
|
||||
verbose *bool
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
Model string
|
||||
Step string
|
||||
Count int
|
||||
Duration time.Duration
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
|
||||
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||
|
||||
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||
switch format {
|
||||
case "benchstat":
|
||||
if verbose {
|
||||
printHeader := func() {
|
||||
fmt.Printf("sysname: %s\n", runtime.GOOS)
|
||||
fmt.Printf("machine: %s\n", runtime.GOARCH)
|
||||
}
|
||||
once.Do(printHeader)
|
||||
}
|
||||
for _, m := range metrics {
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
if m.Count > 0 {
|
||||
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
|
||||
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
||||
} else {
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
|
||||
m.Model, m.Step, m.Count)
|
||||
}
|
||||
} else {
|
||||
var suffix string
|
||||
if m.Step == "load" {
|
||||
suffix = "/step=load"
|
||||
}
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
|
||||
m.Model, suffix, m.Duration.Nanoseconds())
|
||||
}
|
||||
}
|
||||
case "csv":
|
||||
printHeader := func() {
|
||||
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
||||
fmt.Fprintln(w, strings.Join(headings, ","))
|
||||
}
|
||||
once.Do(printHeader)
|
||||
|
||||
for _, m := range metrics {
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
var nsPerToken float64
|
||||
var tokensPerSec float64
|
||||
if m.Count > 0 {
|
||||
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||
}
|
||||
fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
||||
}
|
||||
}
|
||||
case "markdown":
|
||||
printHeader := func() {
|
||||
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
|
||||
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
|
||||
}
|
||||
once.Do(printHeader)
|
||||
|
||||
for _, m := range metrics {
|
||||
var nsPerToken, tokensPerSec float64
|
||||
var nsPerTokenStr, tokensPerSecStr string
|
||||
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
|
||||
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
|
||||
} else {
|
||||
nsPerTokenStr = "-"
|
||||
tokensPerSecStr = "-"
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
|
||||
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
|
||||
}
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkChat(fOpt flagOptions) error {
|
||||
models := strings.Split(*fOpt.models, ",")
|
||||
|
||||
// todo - add multi-image support
|
||||
var imgData api.ImageData
|
||||
var err error
|
||||
if *fOpt.imageFile != "" {
|
||||
imgData, err = readImage(*fOpt.imageFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if *fOpt.debug && imgData != nil {
|
||||
fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
for range *fOpt.epochs {
|
||||
options := make(map[string]interface{})
|
||||
if *fOpt.maxTokens > 0 {
|
||||
options["num_predict"] = *fOpt.maxTokens
|
||||
}
|
||||
options["temperature"] = *fOpt.temperature
|
||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||
options["seed"] = *fOpt.seed
|
||||
}
|
||||
|
||||
var keepAliveDuration *api.Duration
|
||||
if *fOpt.keepAlive > 0 {
|
||||
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
||||
keepAliveDuration = &duration
|
||||
}
|
||||
|
||||
req := &api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: *fOpt.prompt,
|
||||
},
|
||||
},
|
||||
Options: options,
|
||||
KeepAlive: keepAliveDuration,
|
||||
}
|
||||
|
||||
if imgData != nil {
|
||||
req.Messages[0].Images = []api.ImageData{imgData}
|
||||
}
|
||||
|
||||
var responseMetrics *api.Metrics
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
|
||||
}
|
||||
|
||||
if resp.Done {
|
||||
responseMetrics = &resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if responseMetrics == nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
||||
continue
|
||||
}
|
||||
|
||||
metrics := []Metrics{
|
||||
{
|
||||
Model: model,
|
||||
Step: "prefill",
|
||||
Count: responseMetrics.PromptEvalCount,
|
||||
Duration: responseMetrics.PromptEvalDuration,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "generate",
|
||||
Count: responseMetrics.EvalCount,
|
||||
Duration: responseMetrics.EvalDuration,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "load",
|
||||
Count: 1,
|
||||
Duration: responseMetrics.LoadDuration,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "total",
|
||||
Count: 1,
|
||||
Duration: responseMetrics.TotalDuration,
|
||||
},
|
||||
}
|
||||
|
||||
OutputMetrics(os.Stdout, *fOpt.format, metrics, *fOpt.verbose)
|
||||
|
||||
if *fOpt.keepAlive > 0 {
|
||||
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readImage(filePath string) (api.ImageData, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return api.ImageData(data), nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
fOpt := flagOptions{
|
||||
models: flag.String("model", "", "Model to benchmark"),
|
||||
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
|
||||
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
|
||||
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
|
||||
seed: flag.Int("seed", 0, "Random seed"),
|
||||
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
|
||||
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
||||
imageFile: flag.String("image", "", "Filename for an image to include"),
|
||||
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
||||
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
|
||||
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
||||
verbose: flag.Bool("v", false, "Show system information"),
|
||||
debug: flag.Bool("debug", false, "Show debug information"),
|
||||
}
|
||||
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "Description:\n")
|
||||
fmt.Fprintf(os.Stderr, " Model benchmarking tool with configurable parameters\n\n")
|
||||
fmt.Fprintf(os.Stderr, "Options:\n")
|
||||
flag.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
||||
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(*fOpt.models) == 0 {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n")
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
BenchmarkChat(fOpt)
|
||||
}
|
||||
Reference in New Issue
Block a user