diff --git a/api/types.go b/api/types.go
index 82caf17dc..8ae120492 100644
--- a/api/types.go
+++ b/api/types.go
@@ -922,6 +922,22 @@ type UserResponse struct {
 	Plan string `json:"plan,omitempty"`
 }
 
+// UsageResponse is the response returned by the usage endpoint. Usage holds
+// one entry per model that has recorded requests, sorted by model name.
+type UsageResponse struct {
+	// Start is the time the server started tracking usage (UTC, RFC 3339).
+	Start time.Time        `json:"start"`
+	Usage []ModelUsageData `json:"usage"`
+}
+
+// ModelUsageData holds the request and token totals recorded for one model.
+type ModelUsageData struct {
+	Model            string `json:"model"`
+	Requests         int64  `json:"requests"`
+	PromptTokens     int64  `json:"prompt_tokens"`
+	CompletionTokens int64  `json:"completion_tokens"`
+}
+
 // Tensor describes the metadata for a given tensor.
 type Tensor struct {
 	Name string `json:"name"`
diff --git a/docs/api.md b/docs/api.md
index 150479e6a..8ae21af1a 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -15,6 +15,7 @@
 - [Push a Model](#push-a-model)
 - [Generate Embeddings](#generate-embeddings)
 - [List Running Models](#list-running-models)
+- [Usage](#usage)
 - [Version](#version)
 - [Experimental: Image Generation](#image-generation-experimental)
 
@@ -1854,6 +1855,53 @@ curl http://localhost:11434/api/embeddings -d '{
 }
 ```
 
+## Usage
+
+```
+GET /api/usage
+```
+
+Show aggregate usage statistics per model since the server started. Counts are kept in memory and reset when the server restarts. All timestamps are UTC in RFC 3339 format.
+
+### Examples
+
+#### Request
+
+```shell
+curl http://localhost:11434/api/usage
+```
+
+#### Response
+
+```json
+{
+  "start": "2025-01-27T20:00:00Z",
+  "usage": [
+    {
+      "model": "deepseek-r1",
+      "requests": 2,
+      "prompt_tokens": 48,
+      "completion_tokens": 312
+    },
+    {
+      "model": "llama3.2",
+      "requests": 5,
+      "prompt_tokens": 130,
+      "completion_tokens": 890
+    }
+  ]
+}
+```
+
+#### Response fields
+
+- `start`: when the server started tracking usage (UTC, RFC 3339)
+- `usage`: list of per-model usage statistics, sorted by model name
+  - `model`: model name
+  - `requests`: total number of completed requests
+  - `prompt_tokens`: total prompt tokens evaluated
+  - `completion_tokens`: total completion tokens generated
+
 ## Version
 
 ```
diff --git a/server/routes.go b/server/routes.go
index cbe771d9f..ef5356978 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -91,6 +91,8 @@ type Server struct {
 	aliasesOnce sync.Once
 	aliases     *store
 	aliasesErr  error
+	lowVRAM     bool
+	usage       *UsageTracker
 }
 
 func init() {
@@ -289,6 +291,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		c.Header("Content-Type", contentType)
 
 		fn := func(resp api.GenerateResponse) error {
+			if resp.Done {
+				s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
+			}
+
 			resp.Model = origModel
 			resp.RemoteModel = m.Config.RemoteModel
 			resp.RemoteHost = m.Config.RemoteHost
@@ -595,6 +601,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					}
 					res.Context = tokens
 				}
+
+				s.usage.Record(req.Model, cr.PromptEvalCount, cr.EvalCount)
 			}
 
 			if builtinParser != nil {
@@ -1622,6 +1630,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.POST("/api/experimental/aliases", s.CreateAliasHandler)
 	r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
 
+	r.GET("/api/usage", s.UsageHandler)
+
 	// Inference
 	r.GET("/api/ps", s.PsHandler)
 	r.POST("/api/generate", s.GenerateHandler)
@@ -1692,7 +1702,7 @@ func Serve(ln net.Listener) error {
 		}
 	}
 
-	s := &Server{addr: ln.Addr()}
+	s := &Server{addr: ln.Addr(), usage: NewUsageTracker()}
 
 	var rc *ollama.Registry
 	if useClient2 {
@@ -1927,6 +1937,11 @@ func (s *Server) SignoutHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, nil)
 }
 
+// UsageHandler responds with the per-model usage recorded by the server's tracker.
+func (s *Server) UsageHandler(c *gin.Context) {
+	c.JSON(http.StatusOK, s.usage.Stats())
+}
+
 func (s *Server) PsHandler(c *gin.Context) {
 	models := []api.ProcessModelResponse{}
 
@@ -2097,6 +2112,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		c.Header("Content-Type", contentType)
 
 		fn := func(resp api.ChatResponse) error {
+			if resp.Done {
+				s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
+			}
+
 			resp.Model = origModel
 			resp.RemoteModel = m.Config.RemoteModel
 			resp.RemoteHost = m.Config.RemoteHost
@@ -2317,6 +2336,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				res.DoneReason = r.DoneReason.String()
 				res.TotalDuration = time.Since(checkpointStart)
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
+				s.usage.Record(req.Model, r.PromptEvalCount, r.EvalCount)
 			}
 
 			if builtinParser != nil {
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index 0679b4262..d8c288a62 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -88,19 +88,39 @@ func TestGenerateChatRemote(t *testing.T) {
 		if r.Method != http.MethodPost {
 			t.Errorf("Expected POST request, got %s", r.Method)
 		}
-		if r.URL.Path != "/api/chat" {
-			t.Errorf("Expected path '/api/chat', got %s", r.URL.Path)
-		}
 
 		w.WriteHeader(http.StatusOK)
 		w.Header().Set("Content-Type", "application/json")
-		resp := api.ChatResponse{
-			Model:      "test",
-			Done:       true,
-			DoneReason: "load",
-		}
-		if err := json.NewEncoder(w).Encode(&resp); err != nil {
-			t.Fatal(err)
+
+		switch r.URL.Path {
+		case "/api/chat":
+			resp := api.ChatResponse{
+				Model:      "test",
+				Done:       true,
+				DoneReason: "load",
+				Metrics: api.Metrics{
+					PromptEvalCount: 10,
+					EvalCount:       20,
+				},
+			}
+			if err := json.NewEncoder(w).Encode(&resp); err != nil {
+				t.Errorf("encoding chat response: %v", err)
+			}
+		case "/api/generate":
+			resp := api.GenerateResponse{
+				Model:      "test",
+				Done:       true,
+				DoneReason: "stop",
+				Metrics: api.Metrics{
+					PromptEvalCount: 5,
+					EvalCount:       15,
+				},
+			}
+			if err := json.NewEncoder(w).Encode(&resp); err != nil {
+				t.Errorf("encoding generate response: %v", err)
+			}
+		default:
+			t.Errorf("unexpected path %s", r.URL.Path)
 		}
 	}))
 	defer rs.Close()
@@ -111,7 +131,7 @@ func TestGenerateChatRemote(t *testing.T) {
 	}
 	t.Setenv("OLLAMA_REMOTES", p.Hostname())
 
-	s := Server{}
+	s := Server{usage: NewUsageTracker()}
 	w := createRequest(t, s.CreateHandler, api.CreateRequest{
 		Model:      "test-cloud",
 		RemoteHost: rs.URL,
@@ -159,6 +179,61 @@ func TestGenerateChatRemote(t *testing.T) {
 			t.Errorf("expected done reason load, got %s", actual.DoneReason)
 		}
 	})
+
+	t.Run("remote chat usage tracking", func(t *testing.T) {
+		stats := s.usage.Stats()
+		found := false
+		for _, m := range stats.Usage {
+			if m.Model == "test-cloud" {
+				found = true
+				if m.Requests != 1 {
+					t.Errorf("expected 1 request, got %d", m.Requests)
+				}
+				if m.PromptTokens != 10 {
+					t.Errorf("expected 10 prompt tokens, got %d", m.PromptTokens)
+				}
+				if m.CompletionTokens != 20 {
+					t.Errorf("expected 20 completion tokens, got %d", m.CompletionTokens)
+				}
+			}
+		}
+		if !found {
+			t.Error("expected usage entry for test-cloud")
+		}
+	})
+
+	t.Run("remote generate usage tracking", func(t *testing.T) {
+		// Reset the tracker for a clean test
+		s.usage = NewUsageTracker()
+
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-cloud",
+			Prompt: "hello",
+		})
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		stats := s.usage.Stats()
+		found := false
+		for _, m := range stats.Usage {
+			if m.Model == "test-cloud" {
+				found = true
+				if m.Requests != 1 {
+					t.Errorf("expected 1 request, got %d", m.Requests)
+				}
+				if m.PromptTokens != 5 {
+					t.Errorf("expected 5 prompt tokens, got %d", m.PromptTokens)
+				}
+				if m.CompletionTokens != 15 {
+					t.Errorf("expected 15 completion tokens, got %d", m.CompletionTokens)
+				}
+			}
+		}
+		if !found {
+			t.Error("expected usage entry for test-cloud")
+		}
+	})
 }
 
 func TestGenerateChat(t *testing.T) {
diff --git a/server/usage.go b/server/usage.go
new file mode 100644
index 000000000..38be92e49
--- /dev/null
+++ b/server/usage.go
@@ -0,0 +1,92 @@
+package server
+
+import (
+	"sort"
+	"sync"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+// ModelUsage holds the running totals recorded for a single model.
+type ModelUsage struct {
+	Requests         int64
+	PromptTokens     int64
+	CompletionTokens int64
+}
+
+// UsageTracker accumulates per-model request and token counts in memory.
+// It is safe for concurrent use. A nil *UsageTracker is valid: Record is a
+// no-op and Stats reports no usage, so a Server built without a tracker
+// does not panic in the request handlers.
+type UsageTracker struct {
+	mu     sync.Mutex
+	start  time.Time
+	models map[string]*ModelUsage
+}
+
+// NewUsageTracker returns an empty tracker whose start time is the current
+// time in UTC.
+func NewUsageTracker() *UsageTracker {
+	return &UsageTracker{
+		start:  time.Now().UTC(),
+		models: make(map[string]*ModelUsage),
+	}
+}
+
+// Record counts one request for model and adds promptTokens and
+// completionTokens to that model's totals.
+func (u *UsageTracker) Record(model string, promptTokens, completionTokens int) {
+	if u == nil {
+		return
+	}
+
+	u.mu.Lock()
+	defer u.mu.Unlock()
+
+	if u.models == nil {
+		// Tolerate a zero-value tracker; writing to a nil map would panic.
+		u.models = make(map[string]*ModelUsage)
+	}
+
+	m, ok := u.models[model]
+	if !ok {
+		m = &ModelUsage{}
+		u.models[model] = m
+	}
+
+	m.Requests++
+	m.PromptTokens += int64(promptTokens)
+	m.CompletionTokens += int64(completionTokens)
+}
+
+// Stats returns a snapshot of the recorded usage. Entries are sorted by model
+// name so the output is stable across calls; Go map iteration order is random.
+// Usage is never nil, so it encodes as a JSON array even when empty.
+func (u *UsageTracker) Stats() api.UsageResponse {
+	if u == nil {
+		return api.UsageResponse{Usage: []api.ModelUsageData{}}
+	}
+
+	u.mu.Lock()
+	defer u.mu.Unlock()
+
+	byModel := make([]api.ModelUsageData, 0, len(u.models))
+	for model, usage := range u.models {
+		byModel = append(byModel, api.ModelUsageData{
+			Model:            model,
+			Requests:         usage.Requests,
+			PromptTokens:     usage.PromptTokens,
+			CompletionTokens: usage.CompletionTokens,
+		})
+	}
+
+	sort.Slice(byModel, func(i, j int) bool {
+		return byModel[i].Model < byModel[j].Model
+	})
+
+	return api.UsageResponse{
+		Start: u.start,
+		Usage: byModel,
+	}
+}
diff --git a/server/usage_test.go b/server/usage_test.go
new file mode 100644
index 000000000..7934395c2
--- /dev/null
+++ b/server/usage_test.go
@@ -0,0 +1,170 @@
+package server
+
+import (
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"testing"
+
+	"github.com/gin-gonic/gin"
+	"github.com/ollama/ollama/api"
+)
+
+func TestUsageTrackerRecord(t *testing.T) {
+	tracker := NewUsageTracker()
+
+	tracker.Record("model-a", 10, 20)
+	tracker.Record("model-a", 5, 15)
+	tracker.Record("model-b", 100, 200)
+
+	stats := tracker.Stats()
+
+	if len(stats.Usage) != 2 {
+		t.Fatalf("expected 2 models, got %d", len(stats.Usage))
+	}
+
+	lookup := make(map[string]api.ModelUsageData)
+	for _, m := range stats.Usage {
+		lookup[m.Model] = m
+	}
+
+	a := lookup["model-a"]
+	if a.Requests != 2 {
+		t.Errorf("model-a requests: expected 2, got %d", a.Requests)
+	}
+	if a.PromptTokens != 15 {
+		t.Errorf("model-a prompt tokens: expected 15, got %d", a.PromptTokens)
+	}
+	if a.CompletionTokens != 35 {
+		t.Errorf("model-a completion tokens: expected 35, got %d", a.CompletionTokens)
+	}
+
+	b := lookup["model-b"]
+	if b.Requests != 1 {
+		t.Errorf("model-b requests: expected 1, got %d", b.Requests)
+	}
+	if b.PromptTokens != 100 {
+		t.Errorf("model-b prompt tokens: expected 100, got %d", b.PromptTokens)
+	}
+	if b.CompletionTokens != 200 {
+		t.Errorf("model-b completion tokens: expected 200, got %d", b.CompletionTokens)
+	}
+}
+
+func TestUsageTrackerConcurrent(t *testing.T) {
+	tracker := NewUsageTracker()
+
+	var wg sync.WaitGroup
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			tracker.Record("model-a", 1, 2)
+		}()
+	}
+	wg.Wait()
+
+	stats := tracker.Stats()
+	if len(stats.Usage) != 1 {
+		t.Fatalf("expected 1 model, got %d", len(stats.Usage))
+	}
+
+	m := stats.Usage[0]
+	if m.Requests != 100 {
+		t.Errorf("requests: expected 100, got %d", m.Requests)
+	}
+	if m.PromptTokens != 100 {
+		t.Errorf("prompt tokens: expected 100, got %d", m.PromptTokens)
+	}
+	if m.CompletionTokens != 200 {
+		t.Errorf("completion tokens: expected 200, got %d", m.CompletionTokens)
+	}
+}
+
+func TestUsageTrackerStart(t *testing.T) {
+	tracker := NewUsageTracker()
+
+	stats := tracker.Stats()
+	if stats.Start.IsZero() {
+		t.Error("expected non-zero start time")
+	}
+}
+
+func TestUsageTrackerStatsSorted(t *testing.T) {
+	tracker := NewUsageTracker()
+
+	for _, model := range []string{"model-c", "model-a", "model-b"} {
+		tracker.Record(model, 1, 1)
+	}
+
+	stats := tracker.Stats()
+	if len(stats.Usage) != 3 {
+		t.Fatalf("expected 3 models, got %d", len(stats.Usage))
+	}
+
+	for i, want := range []string{"model-a", "model-b", "model-c"} {
+		if got := stats.Usage[i].Model; got != want {
+			t.Errorf("usage[%d]: expected model %s, got %s", i, want, got)
+		}
+	}
+}
+
+func TestUsageTrackerNil(t *testing.T) {
+	var tracker *UsageTracker
+
+	// Must not panic: handlers call Record on Servers built without a tracker.
+	tracker.Record("model-a", 1, 2)
+
+	stats := tracker.Stats()
+	if stats.Usage == nil {
+		t.Error("expected non-nil usage slice")
+	}
+	if len(stats.Usage) != 0 {
+		t.Errorf("expected no usage, got %d entries", len(stats.Usage))
+	}
+}
+
+func TestUsageHandler(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	s := &Server{
+		usage: NewUsageTracker(),
+	}
+
+	s.usage.Record("llama3", 50, 100)
+	s.usage.Record("llama3", 25, 50)
+
+	w := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(w)
+	c.Request = httptest.NewRequest(http.MethodGet, "/api/usage", nil)
+
+	s.UsageHandler(c)
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	var resp api.UsageResponse
+	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
+		t.Fatalf("failed to unmarshal response: %v", err)
+	}
+
+	if len(resp.Usage) != 1 {
+		t.Fatalf("expected 1 model, got %d", len(resp.Usage))
+	}
+
+	m := resp.Usage[0]
+	if m.Model != "llama3" {
+		t.Errorf("expected model llama3, got %s", m.Model)
+	}
+	if m.Requests != 2 {
+		t.Errorf("expected 2 requests, got %d", m.Requests)
+	}
+	if m.PromptTokens != 75 {
+		t.Errorf("expected 75 prompt tokens, got %d", m.PromptTokens)
+	}
+	if m.CompletionTokens != 150 {
+		t.Errorf("expected 150 completion tokens, got %d", m.CompletionTokens)
+	}
+}