diff --git a/cmd/bench/bench.go b/cmd/bench/bench.go
index d6ea0ade2..1ecd42cfc 100644
--- a/cmd/bench/bench.go
+++ b/cmd/bench/bench.go
@@ -32,6 +32,7 @@ type flagOptions struct {
 	verbose      *bool
 	warmup       *int
 	promptTokens *int
+	numCtx       *int
 }
 
 type Metrics struct {
@@ -48,6 +49,7 @@ type ModelInfo struct {
 	Family    string
 	SizeBytes int64
 	VRAMBytes int64
+	NumCtx    int64
 }
 
 const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
@@ -64,9 +66,12 @@ var promptWordList = []string{
 	"old", "stone", "bridge", "that", "crosses", "winding", "river",
 }
 
+// tokensPerWord is the calibrated ratio of tokens to words for the current model.
+// Initialized with a heuristic, then updated during warmup based on actual tokenization.
+var tokensPerWord = 1.3
+
 func generatePromptForTokenCount(targetTokens int, epoch int) string {
-	// ~1.3 tokens per word heuristic
-	targetWords := int(float64(targetTokens) / 1.3)
+	targetWords := int(float64(targetTokens) / tokensPerWord)
 	if targetWords < 1 {
 		targetWords = 1
 	}
@@ -81,6 +86,17 @@ func generatePromptForTokenCount(targetTokens int, epoch int) string {
 	return strings.Join(words, " ")
 }
 
+// calibratePromptTokens adjusts tokensPerWord based on actual tokenization from a warmup run.
+func calibratePromptTokens(targetTokens, actualTokens, wordCount int) {
+	if actualTokens <= 0 || wordCount <= 0 {
+		return
+	}
+	tokensPerWord = float64(actualTokens) / float64(wordCount)
+	newWords := int(float64(targetTokens) / tokensPerWord)
+	fmt.Fprintf(os.Stderr, "bench: calibrated %.2f tokens/word (target=%d, got=%d, words=%d → %d)\n",
+		tokensPerWord, targetTokens, actualTokens, wordCount, newWords)
+}
+
 func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
 	options := make(map[string]interface{})
 	if *fOpt.maxTokens > 0 {
@@ -90,6 +106,9 @@ func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData,
 	if fOpt.seed != nil && *fOpt.seed > 0 {
 		options["seed"] = *fOpt.seed
 	}
+	if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
+		options["num_ctx"] = *fOpt.numCtx
+	}
 
 	var keepAliveDuration *api.Duration
 	if *fOpt.keepAlive > 0 {
@@ -146,7 +165,6 @@ func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (si
 			return m.Size, m.SizeVRAM
 		}
 	}
-	// Try prefix match (model names may include :latest or tags)
 	for _, m := range resp.Models {
 		if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
 			return m.Size, m.SizeVRAM
@@ -155,6 +173,19 @@ func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (si
 	return 0, 0
 }
 
+func fetchContextLength(ctx context.Context, client *api.Client, model string) int64 {
+	resp, err := client.ListRunning(ctx)
+	if err != nil {
+		return 0
+	}
+	for _, m := range resp.Models {
+		if m.Name == model || m.Model == model || strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
+			return int64(m.ContextLength)
+		}
+	}
+	return 0
+}
+
 func outputFormatHeader(w io.Writer, format string, verbose bool) {
 	switch format {
 	case "benchstat":
@@ -177,8 +208,12 @@ func outputModelInfo(w io.Writer, format string, info ModelInfo) {
 	if info.SizeBytes > 0 {
 		memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
 	}
-	fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s\n",
-		info.Name, params, quant, family, memStr)
+	ctxStr := ""
+	if info.NumCtx > 0 {
+		ctxStr = fmt.Sprintf(" | NumCtx: %d", info.NumCtx)
+	}
+	fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s%s\n",
+		info.Name, params, quant, family, memStr, ctxStr)
 }
 
 func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
@@ -276,21 +311,38 @@ func BenchmarkModel(fOpt flagOptions) error {
 		req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
 		ctx, cancel := context.WithTimeout(context.Background(),
 			time.Duration(*fOpt.timeout)*time.Second)
+		var warmupMetrics *api.Metrics
 		err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
+			if resp.Done {
+				warmupMetrics = &resp.Metrics
+			}
 			return nil
 		})
 		cancel()
 		if err != nil {
 			fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n",
 				i+1, *fOpt.warmup, model, err)
-		} else if *fOpt.debug {
-			fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
+		} else {
+			if *fOpt.debug {
+				fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
+			}
+			// Calibrate prompt token count on last warmup run
+			if i == *fOpt.warmup-1 && *fOpt.promptTokens > 0 && warmupMetrics != nil {
+				prompt := generatePromptForTokenCount(*fOpt.promptTokens, -(i + 1))
+				wordCount := len(strings.Fields(prompt))
+				calibratePromptTokens(*fOpt.promptTokens, warmupMetrics.PromptEvalCount, wordCount)
+			}
 		}
 	}
 
-	// Fetch memory usage once after warmup (model is loaded and stable)
+	// Fetch memory/context info once after warmup (model is loaded and stable)
 	memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
 	info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
+	if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
+		info.NumCtx = int64(*fOpt.numCtx)
+	} else {
+		info.NumCtx = fetchContextLength(memCtx, client, model)
+	}
 	memCancel()
 
 	outputModelInfo(out, *fOpt.format, info)
@@ -479,6 +531,7 @@ func main() {
 		debug:        flag.Bool("debug", false, "Show debug information"),
 		warmup:       flag.Int("warmup", 1, "Number of warmup requests before timing"),
 		promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
+		numCtx:       flag.Int("num-ctx", 0, "Context size (0 = server default)"),
 	}
 
 	flag.Usage = func() {