mirror of
https://github.com/ollama/ollama.git
synced 2026-04-23 09:15:44 +02:00
Remove the vendored GGML and llama.cpp backend, CGO runner, Go model implementations, and sample. llama-server (built from upstream llama.cpp via FetchContent) is now the sole inference engine for GGUF-based models. (Safetensor based models continue to run on the new MLX engine.) This allows us to more rapidly pick up new capabilities and fixes from llama.cpp as they come out. On windows this now requires recent AMD driver versions to support ROCm v7 as llama.cpp currently does not support building against v6.
1056 lines
30 KiB
Go
1056 lines
30 KiB
Go
package llm
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os/exec"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/ml"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
"golang.org/x/sync/semaphore"
|
|
)
|
|
|
|
func TestLlamaServerHealthParsing(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
body string
|
|
statusCode int
|
|
wantStatus ServerStatus
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "ready",
|
|
body: `{"status":"ok"}`,
|
|
statusCode: 200,
|
|
wantStatus: ServerStatusReady,
|
|
},
|
|
{
|
|
name: "loading",
|
|
body: `{"status":"loading model"}`,
|
|
statusCode: 503,
|
|
wantStatus: ServerStatusLoadingModel,
|
|
},
|
|
{
|
|
name: "no slots",
|
|
body: `{"status":"no slot available"}`,
|
|
statusCode: 503,
|
|
wantStatus: ServerStatusNoSlotsAvailable,
|
|
},
|
|
{
|
|
name: "error status",
|
|
body: `{"status":"error","message":"out of memory"}`,
|
|
statusCode: 500,
|
|
wantStatus: ServerStatusError,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/health" {
|
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
|
}
|
|
w.WriteHeader(tt.statusCode)
|
|
fmt.Fprint(w, tt.body)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
// Parse the port from the test server
|
|
parts := strings.Split(srv.URL, ":")
|
|
port := parts[len(parts)-1]
|
|
var portInt int
|
|
fmt.Sscanf(port, "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
}
|
|
|
|
status, err := runner.getServerStatus(t.Context())
|
|
if tt.wantErr && err == nil {
|
|
t.Error("expected error, got nil")
|
|
}
|
|
if !tt.wantErr && err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
if status != tt.wantStatus {
|
|
t.Errorf("status = %v, want %v", status, tt.wantStatus)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerCompletionSSEParsing(t *testing.T) {
|
|
// Simulate llama-server SSE streaming response
|
|
sseLines := []string{
|
|
`data: {"content":"Hello","stop":false}`,
|
|
``,
|
|
`data: {"content":" world","stop":false}`,
|
|
``,
|
|
`data: {"content":"","stop":true,"stop_type":"eos","timings":{"prompt_n":5,"prompt_ms":10.5,"predicted_n":2,"predicted_ms":20.3}}`,
|
|
``,
|
|
}
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
fmt.Fprint(w, `{"status":"ok"}`)
|
|
return
|
|
}
|
|
if r.URL.Path != "/completion" {
|
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
|
return
|
|
}
|
|
|
|
// Verify request body is valid
|
|
var reqBody llamaServerCompletionRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
|
t.Errorf("invalid request body: %v", err)
|
|
return
|
|
}
|
|
if reqBody.Prompt != "test prompt" {
|
|
t.Errorf("prompt = %q, want %q", reqBody.Prompt, "test prompt")
|
|
}
|
|
if !reqBody.Stream {
|
|
t.Error("stream should be true")
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
for _, line := range sseLines {
|
|
fmt.Fprintln(w, line)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
sem: semaphore.NewWeighted(1),
|
|
options: api.Options{Runner: api.Runner{NumCtx: 2048}},
|
|
}
|
|
|
|
var responses []CompletionResponse
|
|
opts := api.DefaultOptions()
|
|
err := runner.Completion(t.Context(), CompletionRequest{
|
|
Prompt: "test prompt",
|
|
Options: &opts,
|
|
}, func(cr CompletionResponse) {
|
|
responses = append(responses, cr)
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Completion error: %v", err)
|
|
}
|
|
|
|
if len(responses) != 3 {
|
|
t.Fatalf("got %d responses, want 3", len(responses))
|
|
}
|
|
|
|
// First token
|
|
if responses[0].Content != "Hello" {
|
|
t.Errorf("response[0].Content = %q, want %q", responses[0].Content, "Hello")
|
|
}
|
|
if responses[0].Done {
|
|
t.Error("response[0] should not be done")
|
|
}
|
|
|
|
// Second token
|
|
if responses[1].Content != " world" {
|
|
t.Errorf("response[1].Content = %q, want %q", responses[1].Content, " world")
|
|
}
|
|
|
|
// Final response
|
|
if !responses[2].Done {
|
|
t.Error("response[2] should be done")
|
|
}
|
|
if responses[2].DoneReason != DoneReasonStop {
|
|
t.Errorf("DoneReason = %v, want %v", responses[2].DoneReason, DoneReasonStop)
|
|
}
|
|
if responses[2].PromptEvalCount != 5 {
|
|
t.Errorf("PromptEvalCount = %d, want 5", responses[2].PromptEvalCount)
|
|
}
|
|
if responses[2].EvalCount != 2 {
|
|
t.Errorf("EvalCount = %d, want 2", responses[2].EvalCount)
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerCompletionLengthStop(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
fmt.Fprint(w, `{"status":"ok"}`)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
fmt.Fprintln(w, `data: {"content":"tok","stop":false}`)
|
|
fmt.Fprintln(w, ``)
|
|
fmt.Fprintln(w, `data: {"content":"","stop":true,"stop_type":"limit","timings":{"prompt_n":1,"prompt_ms":1,"predicted_n":1,"predicted_ms":1}}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
sem: semaphore.NewWeighted(1),
|
|
options: api.Options{Runner: api.Runner{NumCtx: 2048}},
|
|
}
|
|
|
|
var lastResp CompletionResponse
|
|
opts := api.DefaultOptions()
|
|
err := runner.Completion(t.Context(), CompletionRequest{
|
|
Prompt: "test",
|
|
Options: &opts,
|
|
}, func(cr CompletionResponse) {
|
|
lastResp = cr
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Completion error: %v", err)
|
|
}
|
|
if lastResp.DoneReason != DoneReasonLength {
|
|
t.Errorf("DoneReason = %v, want %v", lastResp.DoneReason, DoneReasonLength)
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerCompletionRequestFormat(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
format string
|
|
grammar string
|
|
wantGrammar bool
|
|
wantJsonSchema bool
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "no format",
|
|
},
|
|
{
|
|
name: "null format",
|
|
format: `null`,
|
|
},
|
|
{
|
|
name: "empty string format",
|
|
format: `""`,
|
|
},
|
|
{
|
|
name: "json format",
|
|
format: `"json"`,
|
|
wantGrammar: true,
|
|
},
|
|
{
|
|
name: "json schema",
|
|
format: `{"type":"object","properties":{"name":{"type":"string"}}}`,
|
|
wantJsonSchema: true,
|
|
},
|
|
{
|
|
name: "raw grammar",
|
|
grammar: `root ::= "hello"`,
|
|
wantGrammar: true,
|
|
},
|
|
{
|
|
name: "invalid format",
|
|
format: `"xml"`,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var capturedReq llamaServerCompletionRequest
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
fmt.Fprint(w, `{"status":"ok"}`)
|
|
return
|
|
}
|
|
json.NewDecoder(r.Body).Decode(&capturedReq)
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
fmt.Fprintln(w, `data: {"content":"ok","stop":true,"timings":{"prompt_n":1,"prompt_ms":1,"predicted_n":1,"predicted_ms":1}}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
sem: semaphore.NewWeighted(1),
|
|
options: api.Options{Runner: api.Runner{NumCtx: 2048}},
|
|
}
|
|
|
|
opts := api.DefaultOptions()
|
|
req := CompletionRequest{
|
|
Prompt: "test",
|
|
Options: &opts,
|
|
Grammar: tt.grammar,
|
|
}
|
|
if tt.format != "" {
|
|
req.Format = json.RawMessage(tt.format)
|
|
}
|
|
|
|
err := runner.Completion(t.Context(), req, func(cr CompletionResponse) {})
|
|
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if tt.wantGrammar && capturedReq.Grammar == "" {
|
|
t.Error("expected grammar to be set")
|
|
}
|
|
if tt.wantJsonSchema && capturedReq.JsonSchema == nil {
|
|
t.Error("expected json_schema to be set")
|
|
}
|
|
if !tt.wantGrammar && !tt.wantJsonSchema && capturedReq.Grammar != "" {
|
|
t.Errorf("unexpected grammar: %s", capturedReq.Grammar)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerTokenize(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/tokenize" {
|
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
|
return
|
|
}
|
|
var req map[string]string
|
|
json.NewDecoder(r.Body).Decode(&req)
|
|
if req["content"] != "hello world" {
|
|
t.Errorf("content = %q, want %q", req["content"], "hello world")
|
|
}
|
|
fmt.Fprint(w, `{"tokens":[1,2,3]}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{port: portInt, cmd: fakeRunningCmd()}
|
|
tokens, err := runner.Tokenize(t.Context(), "hello world")
|
|
if err != nil {
|
|
t.Fatalf("Tokenize error: %v", err)
|
|
}
|
|
if len(tokens) != 3 || tokens[0] != 1 || tokens[1] != 2 || tokens[2] != 3 {
|
|
t.Errorf("tokens = %v, want [1,2,3]", tokens)
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerDetokenize(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/detokenize" {
|
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
|
return
|
|
}
|
|
fmt.Fprint(w, `{"content":"hello world"}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{port: portInt, cmd: fakeRunningCmd()}
|
|
content, err := runner.Detokenize(t.Context(), []int{1, 2, 3})
|
|
if err != nil {
|
|
t.Fatalf("Detokenize error: %v", err)
|
|
}
|
|
if content != "hello world" {
|
|
t.Errorf("content = %q, want %q", content, "hello world")
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerEmbedding(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
fmt.Fprint(w, `{"status":"ok"}`)
|
|
return
|
|
}
|
|
if r.URL.Path != "/v1/embeddings" {
|
|
t.Errorf("unexpected path: %s, want /v1/embeddings", r.URL.Path)
|
|
return
|
|
}
|
|
// OAI-compatible format (used when sending "input" field)
|
|
fmt.Fprint(w, `{"data":[{"embedding":[0.1,0.2,0.3],"tokens_evaluated":2}],"usage":{"prompt_tokens":2}}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
sem: semaphore.NewWeighted(1),
|
|
}
|
|
|
|
embedding, count, err := runner.Embedding(t.Context(), "hello")
|
|
if err != nil {
|
|
t.Fatalf("Embedding error: %v", err)
|
|
}
|
|
if len(embedding) != 3 {
|
|
t.Errorf("embedding length = %d, want 3", len(embedding))
|
|
}
|
|
if count != 2 {
|
|
t.Errorf("prompt_eval_count = %d, want 2", count)
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerEmbeddingFallbackFormat(t *testing.T) {
|
|
// Fallback: non-OAI array format (from "content" field)
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
fmt.Fprint(w, `{"status":"ok"}`)
|
|
return
|
|
}
|
|
fmt.Fprint(w, `[{"index":0,"embedding":[[0.4,0.5,0.6]]}]`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
sem: semaphore.NewWeighted(1),
|
|
}
|
|
|
|
embedding, _, err := runner.Embedding(t.Context(), "hello")
|
|
if err != nil {
|
|
t.Fatalf("Embedding error: %v", err)
|
|
}
|
|
if len(embedding) != 3 {
|
|
t.Errorf("embedding length = %d, want 3", len(embedding))
|
|
}
|
|
if embedding[0] != 0.4 {
|
|
t.Errorf("embedding[0] = %v, want 0.4", embedding[0])
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerEmbeddingFlatArrayFallback(t *testing.T) {
|
|
// Non-OAI format with flat (non-nested) embedding array
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
fmt.Fprint(w, `{"status":"ok"}`)
|
|
return
|
|
}
|
|
fmt.Fprint(w, `[{"index":0,"embedding":[0.7,0.8,0.9]}]`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
sem: semaphore.NewWeighted(1),
|
|
}
|
|
|
|
embedding, _, err := runner.Embedding(t.Context(), "hello")
|
|
if err != nil {
|
|
t.Fatalf("Embedding error: %v", err)
|
|
}
|
|
if len(embedding) != 3 || embedding[0] != 0.7 {
|
|
t.Errorf("embedding = %v, want [0.7, 0.8, 0.9]", embedding)
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerEmbeddingTooLargeError(t *testing.T) {
|
|
// llama-server returns 500 for oversized input; adapter should normalize to 400
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
fmt.Fprint(w, `{"status":"ok"}`)
|
|
return
|
|
}
|
|
w.WriteHeader(500)
|
|
fmt.Fprint(w, `{"error":{"code":500,"message":"input is too large to process"}}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
sem: semaphore.NewWeighted(1),
|
|
}
|
|
|
|
_, _, err := runner.Embedding(t.Context(), "very long input")
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
// Should be normalized to 400 for the embed handler's truncation retry
|
|
var statusErr api.StatusError
|
|
if !errors.As(err, &statusErr) {
|
|
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
|
}
|
|
if statusErr.StatusCode != 400 {
|
|
t.Errorf("status code = %d, want 400", statusErr.StatusCode)
|
|
}
|
|
}
|
|
|
|
// TestNormalizeEmbeddingError exercises normalizeEmbeddingError, which
// rewrites llama-server embedding failures into the status code and
// message the embed handler expects: context-overflow errors become 400
// with a canonical message (enabling a truncation retry), while unrelated
// errors keep their original status and message.
func TestNormalizeEmbeddingError(t *testing.T) {
	tests := []struct {
		name       string
		statusCode int
		body       string
		wantStatus int
		wantMsg    string
	}{
		{
			// llama-server reports oversized input as a 500 with a
			// batch-size hint; it should become a 400 context error.
			name:       "physical batch size",
			statusCode: http.StatusInternalServerError,
			body:       `{"error":{"code":500,"message":"input (103 tokens) is too large to process. increase the physical batch size (current batch size: 30)"}}`,
			wantStatus: http.StatusBadRequest,
			wantMsg:    "the input length exceeds the context length",
		},
		{
			// The "error" field may be a plain string rather than an object.
			name:       "context length string error",
			statusCode: http.StatusInternalServerError,
			body:       `{"error":"input length exceeds the context length"}`,
			wantStatus: http.StatusBadRequest,
			wantMsg:    "the input length exceeds the context length",
		},
		{
			// Already a 400; the message is still canonicalized.
			name:       "available context",
			statusCode: http.StatusBadRequest,
			body:       `{"error":{"message":"request (302 tokens) exceeds the available context size (256 tokens), try increasing it"}}`,
			wantStatus: http.StatusBadRequest,
			wantMsg:    "the input length exceeds the context length",
		},
		{
			// Non-context errors pass through unchanged.
			name:       "unrelated error",
			statusCode: http.StatusInternalServerError,
			body:       `{"error":{"message":"backend failed"}}`,
			wantStatus: http.StatusInternalServerError,
			wantMsg:    "backend failed",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			status, msg := normalizeEmbeddingError(tt.statusCode, []byte(tt.body))
			if status != tt.wantStatus {
				t.Fatalf("status = %d, want %d", status, tt.wantStatus)
			}
			if msg != tt.wantMsg {
				t.Fatalf("message = %q, want %q", msg, tt.wantMsg)
			}
		})
	}
}
|
|
|
|
func TestLlamaServerCompletionWithLogprobs(t *testing.T) {
|
|
// Verify logprobs are parsed from SSE streaming responses
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
fmt.Fprint(w, `{"status":"ok"}`)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
fmt.Fprintln(w, `data: {"content":"Hi","stop":false,"completion_probabilities":[{"token":"Hi","logprob":-0.5,"top_logprobs":[{"token":"Hi","logprob":-0.5},{"token":"Hello","logprob":-1.2}]}]}`)
|
|
fmt.Fprintln(w, ``)
|
|
fmt.Fprintln(w, `data: {"content":"","stop":true,"stop_type":"eos","timings":{"prompt_n":1,"prompt_ms":1,"predicted_n":1,"predicted_ms":1}}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
sem: semaphore.NewWeighted(1),
|
|
options: api.Options{Runner: api.Runner{NumCtx: 2048}},
|
|
}
|
|
|
|
var responses []CompletionResponse
|
|
opts := api.DefaultOptions()
|
|
err := runner.Completion(t.Context(), CompletionRequest{
|
|
Prompt: "test",
|
|
Options: &opts,
|
|
Logprobs: true,
|
|
TopLogprobs: 2,
|
|
}, func(cr CompletionResponse) {
|
|
responses = append(responses, cr)
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Completion error: %v", err)
|
|
}
|
|
|
|
// First response should have logprobs
|
|
if len(responses) < 1 {
|
|
t.Fatal("expected at least 1 response")
|
|
}
|
|
if len(responses[0].Logprobs) == 0 {
|
|
t.Fatal("expected logprobs in first response")
|
|
}
|
|
if responses[0].Logprobs[0].Token != "Hi" {
|
|
t.Errorf("token = %q, want %q", responses[0].Logprobs[0].Token, "Hi")
|
|
}
|
|
if responses[0].Logprobs[0].Logprob != -0.5 {
|
|
t.Errorf("logprob = %v, want -0.5", responses[0].Logprobs[0].Logprob)
|
|
}
|
|
if len(responses[0].Logprobs[0].TopLogprobs) != 2 {
|
|
t.Errorf("top_logprobs len = %d, want 2", len(responses[0].Logprobs[0].TopLogprobs))
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerCompletionSamplingParams(t *testing.T) {
|
|
var capturedReq llamaServerCompletionRequest
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
fmt.Fprint(w, `{"status":"ok"}`)
|
|
return
|
|
}
|
|
json.NewDecoder(r.Body).Decode(&capturedReq)
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
fmt.Fprintln(w, `data: {"content":"ok","stop":true,"timings":{"prompt_n":1,"prompt_ms":1,"predicted_n":1,"predicted_ms":1}}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
sem: semaphore.NewWeighted(1),
|
|
options: api.Options{Runner: api.Runner{NumCtx: 2048}},
|
|
}
|
|
|
|
opts := api.Options{
|
|
Runner: api.Runner{NumCtx: 2048},
|
|
Temperature: 0.7,
|
|
TopK: 40,
|
|
TopP: 0.9,
|
|
MinP: 0.05,
|
|
NumPredict: 100,
|
|
Stop: []string{"</s>"},
|
|
RepeatPenalty: 1.1,
|
|
FrequencyPenalty: 0.5,
|
|
PresencePenalty: 0.3,
|
|
Seed: 42,
|
|
}
|
|
|
|
err := runner.Completion(t.Context(), CompletionRequest{
|
|
Prompt: "test",
|
|
Options: &opts,
|
|
}, func(cr CompletionResponse) {})
|
|
if err != nil {
|
|
t.Fatalf("Completion error: %v", err)
|
|
}
|
|
|
|
if capturedReq.Temperature != 0.7 {
|
|
t.Errorf("temperature = %v, want 0.7", capturedReq.Temperature)
|
|
}
|
|
if capturedReq.TopK != 40 {
|
|
t.Errorf("top_k = %v, want 40", capturedReq.TopK)
|
|
}
|
|
if capturedReq.TopP != 0.9 {
|
|
t.Errorf("top_p = %v, want 0.9", capturedReq.TopP)
|
|
}
|
|
if capturedReq.NPredict != 100 {
|
|
t.Errorf("n_predict = %v, want 100", capturedReq.NPredict)
|
|
}
|
|
if capturedReq.Seed != 42 {
|
|
t.Errorf("seed = %v, want 42", capturedReq.Seed)
|
|
}
|
|
if capturedReq.RepeatPenalty != 1.1 {
|
|
t.Errorf("repeat_penalty = %v, want 1.1", capturedReq.RepeatPenalty)
|
|
}
|
|
if len(capturedReq.Stop) != 1 || capturedReq.Stop[0] != "</s>" {
|
|
t.Errorf("stop = %v, want [</s>]", capturedReq.Stop)
|
|
}
|
|
}
|
|
|
|
func TestLlamaServerWaitUntilRunning(t *testing.T) {
|
|
callCount := 0
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
callCount++
|
|
if callCount < 3 {
|
|
w.WriteHeader(503)
|
|
fmt.Fprint(w, `{"status":"loading model"}`)
|
|
return
|
|
}
|
|
fmt.Fprint(w, `{"status":"ok"}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
parts := strings.Split(srv.URL, ":")
|
|
var portInt int
|
|
fmt.Sscanf(parts[len(parts)-1], "%d", &portInt)
|
|
|
|
runner := &llamaServerRunner{
|
|
port: portInt,
|
|
cmd: fakeRunningCmd(),
|
|
done: make(chan struct{}),
|
|
loadStart: time.Now(),
|
|
}
|
|
|
|
err := runner.WaitUntilRunning(t.Context())
|
|
if err != nil {
|
|
t.Fatalf("WaitUntilRunning error: %v", err)
|
|
}
|
|
if callCount < 3 {
|
|
t.Errorf("expected at least 3 health checks, got %d", callCount)
|
|
}
|
|
}
|
|
|
|
// TestMemoryParsingWriter feeds synthetic llama.cpp "model buffer size"
// log lines through memoryParsingWriter and checks that the runner's
// aggregate GPU and total memory counters are updated. Judging by the
// expectations below, host-side buffers (CPU, *_Host) count toward the
// total but not toward GPU memory, while *_Private/*_Mapped variants of a
// GPU backend count as device memory.
func TestMemoryParsingWriter(t *testing.T) {
	tests := []struct {
		name      string
		lines     []string
		wantGPU   float64 // MiB
		wantTotal float64 // MiB
	}{
		{
			name: "Metal + CPU",
			lines: []string{
				"llama_model_load_from_file_impl: Metal model buffer size = 1234.56 MiB\n",
				"llama_model_load_from_file_impl: CPU model buffer size = 56.78 MiB\n",
			},
			wantGPU:   1234.56,
			wantTotal: 1234.56 + 56.78,
		},
		{
			name: "CUDA multi-GPU + host",
			lines: []string{
				"llama_model_load_from_file_impl: CUDA0 model buffer size = 800.00 MiB\n",
				"llama_model_load_from_file_impl: CUDA1 model buffer size = 400.00 MiB\n",
				"llama_model_load_from_file_impl: CUDA_Host model buffer size = 100.00 MiB\n",
			},
			wantGPU:   1200.00,
			wantTotal: 1300.00,
		},
		{
			name: "ROCm + host",
			lines: []string{
				"llama_model_load_from_file_impl: ROCm0 model buffer size = 2000.00 MiB\n",
				"llama_model_load_from_file_impl: ROCm_Host model buffer size = 150.00 MiB\n",
			},
			wantGPU:   2000.00,
			wantTotal: 2150.00,
		},
		{
			name: "Vulkan + host",
			lines: []string{
				"llama_model_load_from_file_impl: Vulkan0 model buffer size = 500.00 MiB\n",
				"llama_model_load_from_file_impl: Vulkan_Host model buffer size = 50.00 MiB\n",
			},
			wantGPU:   500.00,
			wantTotal: 550.00,
		},
		{
			name: "Metal Private + Mapped (both GPU memory)",
			lines: []string{
				"llama_model_load_from_file_impl: Metal_Private model buffer size = 300.00 MiB\n",
				"llama_model_load_from_file_impl: Metal_Mapped model buffer size = 20.00 MiB\n",
			},
			wantGPU:   320.00, // both Private and Mapped are device memory
			wantTotal: 320.00,
		},
		{
			// Unrecognized lines must leave the counters at zero.
			name:      "no buffer lines",
			lines:     []string{"some random log line\n"},
			wantGPU:   0,
			wantTotal: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			runner := &llamaServerRunner{vramByDevice: make(map[string]uint64)}
			w := &memoryParsingWriter{
				inner:  io.Discard,
				runner: runner,
			}

			// Feed each log line through the writer as llama-server's
			// stderr would during model load.
			for _, line := range tt.lines {
				w.Write([]byte(line))
			}

			// Expectations are given in MiB; the counters hold bytes.
			expectedGPU := uint64(tt.wantGPU * 1024 * 1024)
			expectedTotal := uint64(tt.wantTotal * 1024 * 1024)

			if runner.memGPU != expectedGPU {
				t.Errorf("memGPU = %d, want %d", runner.memGPU, expectedGPU)
			}
			if runner.memTotal != expectedTotal {
				t.Errorf("memTotal = %d, want %d", runner.memTotal, expectedTotal)
			}

			// MemorySize should report the same aggregates (total, vram).
			total, vram := runner.MemorySize()
			if total != expectedTotal {
				t.Errorf("MemorySize total = %d, want %d", total, expectedTotal)
			}
			if vram != expectedGPU {
				t.Errorf("MemorySize vram = %d, want %d", vram, expectedGPU)
			}
		})
	}
}
|
|
|
|
// TestMemoryParsingPerDevice checks that memoryParsingWriter attributes
// model, KV-cache, and compute buffer sizes to individual GPU devices in
// runner.vramByDevice, and that host-side buffers (CPU*, *_Host) are
// excluded from per-device tracking. Expected values are approximate
// (integer MiB) with a small tolerance for float rounding.
func TestMemoryParsingPerDevice(t *testing.T) {
	tests := []struct {
		name       string
		lines      []string
		wantDevice map[string]uint64 // device name → expected MiB
	}{
		{
			name: "CUDA multi-GPU all buffer types",
			lines: []string{
				"load_tensors: CUDA0 model buffer size = 852.89 MiB\n",
				"load_tensors: CUDA1 model buffer size = 1065.46 MiB\n",
				"load_tensors: CPU_Mapped model buffer size = 308.23 MiB\n",
				"llama_kv_cache: CUDA0 KV buffer size = 1920.00 MiB\n",
				"llama_kv_cache: CUDA1 KV buffer size = 1664.00 MiB\n",
				"sched_reserve: CUDA0 compute buffer size = 378.04 MiB\n",
				"sched_reserve: CUDA1 compute buffer size = 408.55 MiB\n",
				"sched_reserve: CUDA_Host compute buffer size = 268.05 MiB\n",
			},
			wantDevice: map[string]uint64{
				"CUDA0": 852 + 1920 + 378, // model + KV + compute (approx MiB)
				"CUDA1": 1065 + 1664 + 408,
			},
		},
		{
			name: "Metal with mapped buffers",
			lines: []string{
				"load_tensors: MTL0_Mapped model buffer size = 1918.35 MiB\n",
				"llama_kv_cache: MTL0 KV buffer size = 448.00 MiB\n",
				"sched_reserve: MTL0 compute buffer size = 256.50 MiB\n",
				"sched_reserve: CPU compute buffer size = 20.01 MiB\n",
			},
			wantDevice: map[string]uint64{
				"MTL0": 1918 + 448 + 256, // Mapped model weights + KV + compute (all GPU)
			},
		},
		{
			name: "ROCm single GPU",
			lines: []string{
				"load_tensors: ROCm0 model buffer size = 1918.35 MiB\n",
				"llama_kv_cache: ROCm0 KV buffer size = 448.00 MiB\n",
				"sched_reserve: ROCm0 compute buffer size = 256.50 MiB\n",
				"sched_reserve: ROCm_Host compute buffer size = 20.01 MiB\n",
			},
			wantDevice: map[string]uint64{
				"ROCm0": 1918 + 448 + 256,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			runner := &llamaServerRunner{vramByDevice: make(map[string]uint64)}
			w := &memoryParsingWriter{inner: io.Discard, runner: runner}

			// Feed each synthetic log line through the writer.
			for _, line := range tt.lines {
				w.Write([]byte(line))
			}

			for dev, wantMiB := range tt.wantDevice {
				got := runner.vramByDevice[dev] / (1024 * 1024) // convert to MiB
				// Allow ~1 MiB tolerance for floating point
				if got < wantMiB-2 || got > wantMiB+2 {
					t.Errorf("vramByDevice[%q] = %d MiB, want ~%d MiB", dev, got, wantMiB)
				}
			}

			// Verify host/mapped buffers are NOT in per-device tracking
			for dev := range runner.vramByDevice {
				if !isGPUBuffer(dev) {
					t.Errorf("non-GPU buffer %q found in vramByDevice", dev)
				}
			}
		})
	}
}
|
|
|
|
func TestVRAMByGPU(t *testing.T) {
|
|
runner := &llamaServerRunner{
|
|
vramByDevice: map[string]uint64{
|
|
"CUDA0": 1000 * 1024 * 1024,
|
|
"CUDA1": 2000 * 1024 * 1024,
|
|
},
|
|
gpus: []ml.DeviceInfo{
|
|
{DeviceID: ml.DeviceID{ID: "0", Library: "CUDA"}, Name: "CUDA0"},
|
|
{DeviceID: ml.DeviceID{ID: "1", Library: "CUDA"}, Name: "CUDA1"},
|
|
},
|
|
}
|
|
|
|
got0 := runner.VRAMByGPU(ml.DeviceID{ID: "0", Library: "CUDA"})
|
|
if got0 != 1000*1024*1024 {
|
|
t.Errorf("VRAMByGPU(CUDA:0) = %d, want %d", got0, 1000*1024*1024)
|
|
}
|
|
|
|
got1 := runner.VRAMByGPU(ml.DeviceID{ID: "1", Library: "CUDA"})
|
|
if got1 != 2000*1024*1024 {
|
|
t.Errorf("VRAMByGPU(CUDA:1) = %d, want %d", got1, 2000*1024*1024)
|
|
}
|
|
|
|
// Unknown device returns 0
|
|
gotUnknown := runner.VRAMByGPU(ml.DeviceID{ID: "9", Library: "CUDA"})
|
|
if gotUnknown != 0 {
|
|
t.Errorf("VRAMByGPU(unknown) = %d, want 0", gotUnknown)
|
|
}
|
|
}
|
|
|
|
func TestGetDeviceInfos(t *testing.T) {
|
|
runner := &llamaServerRunner{
|
|
vramByDevice: map[string]uint64{
|
|
"CUDA0": 3000 * 1024 * 1024,
|
|
},
|
|
gpus: []ml.DeviceInfo{
|
|
{
|
|
DeviceID: ml.DeviceID{ID: "0", Library: "CUDA"},
|
|
Name: "CUDA0",
|
|
TotalMemory: 16000 * 1024 * 1024,
|
|
FreeMemory: 15000 * 1024 * 1024, // stale value from discovery
|
|
},
|
|
},
|
|
}
|
|
|
|
infos := runner.GetDeviceInfos(context.Background())
|
|
if len(infos) != 1 {
|
|
t.Fatalf("expected 1 device, got %d", len(infos))
|
|
}
|
|
// Free should be Total - Used, not the stale discovery value
|
|
expectedFree := uint64((16000 - 3000) * 1024 * 1024)
|
|
if infos[0].FreeMemory != expectedFree {
|
|
t.Errorf("FreeMemory = %d, want %d", infos[0].FreeMemory, expectedFree)
|
|
}
|
|
}
|
|
|
|
func TestGetDeviceInfosMinOfTwo(t *testing.T) {
|
|
// External consumer scenario: system reports less free than our accounting expects
|
|
runner := &llamaServerRunner{
|
|
vramByDevice: map[string]uint64{
|
|
"CUDA0": 3000 * 1024 * 1024, // we used 3GB
|
|
},
|
|
systemFreeAtLoad: map[string]uint64{
|
|
"CUDA0": 12000 * 1024 * 1024, // system said 12GB free at load time (external app using 4GB)
|
|
},
|
|
gpus: []ml.DeviceInfo{
|
|
{
|
|
DeviceID: ml.DeviceID{ID: "0", Library: "CUDA"},
|
|
Name: "CUDA0",
|
|
TotalMemory: 16000 * 1024 * 1024, // 16GB total
|
|
},
|
|
},
|
|
}
|
|
|
|
infos := runner.GetDeviceInfos(context.Background())
|
|
// Our accounting: 16000 - 3000 = 13000 MB free
|
|
// System-based: 12000 - 3000 = 9000 MB free (external consumer detected)
|
|
// Min = 9000 MB
|
|
expectedFree := uint64(9000 * 1024 * 1024)
|
|
if infos[0].FreeMemory != expectedFree {
|
|
t.Errorf("FreeMemory = %d MiB, want %d MiB (min-of-two should detect external consumer)",
|
|
infos[0].FreeMemory/(1024*1024), expectedFree/(1024*1024))
|
|
}
|
|
}
|
|
|
|
func TestGetDeviceInfosSystemOptimistic(t *testing.T) {
|
|
// Platform where system over-reports free (e.g., Metal shared memory)
|
|
runner := &llamaServerRunner{
|
|
vramByDevice: map[string]uint64{
|
|
"MTL0": 5000 * 1024 * 1024, // we used 5GB
|
|
},
|
|
systemFreeAtLoad: map[string]uint64{
|
|
"MTL0": 100000 * 1024 * 1024, // system says 100GB free (unified memory, unreliable)
|
|
},
|
|
gpus: []ml.DeviceInfo{
|
|
{
|
|
DeviceID: ml.DeviceID{ID: "0", Library: "Metal"},
|
|
Name: "MTL0",
|
|
TotalMemory: 100000 * 1024 * 1024,
|
|
},
|
|
},
|
|
}
|
|
|
|
infos := runner.GetDeviceInfos(context.Background())
|
|
// Our accounting: 100000 - 5000 = 95000 MB
|
|
// System-based: 100000 - 5000 = 95000 MB
|
|
// Min = 95000 MB (both agree, system isn't lying here)
|
|
expectedFree := uint64(95000 * 1024 * 1024)
|
|
if infos[0].FreeMemory != expectedFree {
|
|
t.Errorf("FreeMemory = %d MiB, want %d MiB",
|
|
infos[0].FreeMemory/(1024*1024), expectedFree/(1024*1024))
|
|
}
|
|
}
|
|
|
|
func TestIsGPUBuffer(t *testing.T) {
|
|
gpu := []string{
|
|
"Metal", "Metal_Private", "CUDA0", "CUDA1", "ROCm0", "Vulkan0", "MUSA0",
|
|
"MTL0_Mapped", "MTL0_REPACK", "CUDA0_Mapped",
|
|
}
|
|
for _, name := range gpu {
|
|
if !isGPUBuffer(name) {
|
|
t.Errorf("isGPUBuffer(%q) = false, want true", name)
|
|
}
|
|
}
|
|
notGPU := []string{
|
|
"CPU", "BLAS", "CUDA_Host", "ROCm_Host", "Vulkan_Host",
|
|
"CPU_Mapped", "CPU_REPACK",
|
|
}
|
|
for _, name := range notGPU {
|
|
if isGPUBuffer(name) {
|
|
t.Errorf("isGPUBuffer(%q) = true, want false", name)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestFindLlamaServer(t *testing.T) {
|
|
// This just tests that the function doesn't panic and returns a reasonable error
|
|
// when the binary doesn't exist in the expected locations
|
|
_, err := FindLlamaServer()
|
|
// In the test environment, it may or may not exist depending on whether
|
|
// cmake was run. Just verify it doesn't panic.
|
|
_ = err
|
|
}
|
|
|
|
// fakeRunningCmd returns an exec.Cmd that appears to still be running:
// ProcessState stays nil until Wait() completes, which is what the runner
// inspects. The spawned process is not explicitly reaped — we can't take
// a *testing.T here without changing every call site — so the OS kills
// the child (SIGKILL) when the test binary exits.
func fakeRunningCmd() *exec.Cmd {
	c := exec.Command("sleep", "3600")
	_ = c.Start() // best-effort: even on failure ProcessState remains nil
	return c
}
|