diff --git a/api/client.go b/api/client.go
index c70516899..ef638da5e 100644
--- a/api/client.go
+++ b/api/client.go
@@ -377,6 +377,15 @@ func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
 	return &lr, nil
 }
 
+// Usage returns usage statistics and system info.
+func (c *Client) Usage(ctx context.Context) (*UsageResponse, error) {
+	var ur UsageResponse
+	if err := c.do(ctx, http.MethodGet, "/api/usage", nil, &ur); err != nil {
+		return nil, err
+	}
+	return &ur, nil
+}
+
 // Copy copies a model - creating a model with another name from an existing
 // model.
 func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
diff --git a/api/types.go b/api/types.go
index 2434fe478..61dff2e9d 100644
--- a/api/types.go
+++ b/api/types.go
@@ -792,6 +792,33 @@ type ProcessResponse struct {
 	Models []ProcessModelResponse `json:"models"`
 }
 
+// UsageResponse is the response from [Client.Usage].
+type UsageResponse struct {
+	GPUs []GPUUsage `json:"gpus,omitempty"`
+}
+
+// GPUUsage contains GPU/device memory usage breakdown.
+type GPUUsage struct {
+	Name    string `json:"name"`    // Device name (e.g., "Apple M2 Max", "NVIDIA GeForce RTX 4090")
+	Backend string `json:"backend"` // CUDA, ROCm, Metal, etc.
+	Total   uint64 `json:"total"`
+	Free    uint64 `json:"free"`
+	Used    uint64 `json:"used"`  // Memory used by Ollama
+	Other   uint64 `json:"other"` // Memory used by other processes
+}
+
+// UsageStats contains usage statistics.
+type UsageStats struct {
+	Requests         int64            `json:"requests"`
+	TokensInput      int64            `json:"tokens_input"`
+	TokensOutput     int64            `json:"tokens_output"`
+	TotalTokens      int64            `json:"total_tokens"`
+	Models           map[string]int64 `json:"models,omitempty"`
+	Sources          map[string]int64 `json:"sources,omitempty"`
+	ToolCalls        int64            `json:"tool_calls,omitempty"`
+	StructuredOutput int64            `json:"structured_output,omitempty"`
+}
+
 // ListModelResponse is a single model description in [ListResponse].
 type ListModelResponse struct {
 	Name string `json:"name"`
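Reviewer note, not part of the patch: a minimal sketch of calling the new endpoint through the Usage client method added above. It assumes a running server reachable via the usual environment configuration:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	// Usage hits GET /api/usage, added in this patch.
	resp, err := client.Usage(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	for _, gpu := range resp.GPUs {
		// Used is memory attributed to Ollama; Other is everything else.
		fmt.Printf("%s (%s): %d of %d bytes used by Ollama, %d by others\n",
			gpu.Name, gpu.Backend, gpu.Used, gpu.Total, gpu.Other)
	}
}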
diff --git a/cmd/cmd.go b/cmd/cmd.go
index f56a1d4b7..bae4f5bfb 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -1833,6 +1833,7 @@ func NewCLI() *cobra.Command {
 		PreRunE: checkServerHeartbeat,
 		RunE:    ListRunningHandler,
 	}
+
 	copyCmd := &cobra.Command{
 		Use:   "cp SOURCE DESTINATION",
 		Short: "Copy a model",
diff --git a/envconfig/config.go b/envconfig/config.go
index 238e5e6e1..4e8c22aa6 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -206,6 +206,8 @@ var (
 	UseAuth = Bool("OLLAMA_AUTH")
 	// Enable Vulkan backend
 	EnableVulkan = Bool("OLLAMA_VULKAN")
+	// Usage enables usage statistics reporting
+	Usage = Bool("OLLAMA_USAGE")
 )
 
 func String(s string) func() string {
diff --git a/server/routes.go b/server/routes.go
index 977a13ff2..9fc9f1b5f 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -20,6 +20,7 @@ import (
 	"net/url"
 	"os"
 	"os/signal"
+	"runtime"
 	"slices"
 	"strings"
 	"sync/atomic"
@@ -44,6 +45,7 @@ import (
 	"github.com/ollama/ollama/model/renderers"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/registry"
+	"github.com/ollama/ollama/server/usage"
 	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/thinking"
 	"github.com/ollama/ollama/tools"
@@ -82,6 +84,7 @@ type Server struct {
 	addr    net.Addr
 	sched   *Scheduler
 	lowVRAM bool
+	stats   *usage.Stats
 }
 
 func init() {
@@ -104,6 +107,30 @@ var (
 	errBadTemplate = errors.New("template error")
 )
 
+// usage records a request to usage stats if enabled.
+func (s *Server) usage(c *gin.Context, endpoint, model, architecture string, promptTokens, completionTokens int, usedTools bool) {
+	if s.stats == nil {
+		return
+	}
+	s.stats.Record(&usage.Request{
+		Endpoint:         endpoint,
+		Model:            model,
+		Architecture:     architecture,
+		APIType:          usage.ClassifyAPIType(c.Request.URL.Path),
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		UsedTools:        usedTools,
+	})
+}
+
+// usageError records a failed request to usage stats if enabled.
+func (s *Server) usageError() {
+	if s.stats == nil {
+		return
+	}
+	s.stats.RecordError()
+}
+
 func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
@@ -374,7 +401,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		return
 	} else if err != nil {
-		handleScheduleError(c, req.Model, err)
+		s.handleScheduleError(c, req.Model, err)
 		return
 	}
 
@@ -561,6 +588,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			res.DoneReason = cr.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+			s.usage(c, "generate", m.ShortName, m.Config.ModelFamily, cr.PromptEvalCount, cr.EvalCount, false)
 
 			if !req.Raw {
 				tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
@@ -680,7 +708,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		handleScheduleError(c, req.Model, err)
+		s.handleScheduleError(c, req.Model, err)
 		return
 	}
 
@@ -790,6 +818,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		LoadDuration:    checkpointLoaded.Sub(checkpointStart),
 		PromptEvalCount: int(totalTokens),
 	}
+	s.usage(c, "embed", m.ShortName, m.Config.ModelFamily, int(totalTokens), 0, false)
 	c.JSON(http.StatusOK, resp)
 }
 
@@ -827,7 +856,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 
 	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		handleScheduleError(c, req.Model, err)
+		s.handleScheduleError(c, req.Model, err)
 		return
 	}
 
@@ -1531,6 +1560,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 
 	// Inference
 	r.GET("/api/ps", s.PsHandler)
+	r.GET("/api/usage", s.UsageHandler)
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/chat", s.ChatHandler)
 	r.POST("/api/embed", s.EmbedHandler)
@@ -1593,6 +1623,13 @@ func Serve(ln net.Listener) error {
 
 	s := &Server{addr: ln.Addr()}
 
+	// Initialize usage stats if enabled
+	if envconfig.Usage() {
+		s.stats = usage.New()
+		s.stats.Start()
+		slog.Info("usage stats enabled")
+	}
+
 	var rc *ollama.Registry
 	if useClient2 {
 		var err error
@@ -1632,6 +1669,9 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
+		if s.stats != nil {
+			s.stats.Stop()
+		}
 		srvr.Close()
 		schedDone()
 		sched.unloadAllRunners()
@@ -1649,6 +1689,24 @@ func Serve(ln net.Listener) error {
 		gpus := discover.GPUDevices(ctx, nil)
 		discover.LogDetails(gpus)
 
+		// Set GPU info for usage reporting
+		if s.stats != nil {
+			usage.GPUInfoFunc = func() []usage.GPU {
+				var result []usage.GPU
+				for _, gpu := range gpus {
+					result = append(result, usage.GPU{
+						Name:         gpu.Name,
+						VRAMBytes:    gpu.TotalMemory,
+						ComputeMajor: gpu.ComputeMajor,
+						ComputeMinor: gpu.ComputeMinor,
+						DriverMajor:  gpu.DriverMajor,
+						DriverMinor:  gpu.DriverMinor,
+					})
+				}
+				return result
+			}
+		}
+
 		var totalVRAM uint64
 		for _, gpu := range gpus {
 			totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
@@ -1852,6 +1910,63 @@ func (s *Server) PsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
+func (s *Server) UsageHandler(c *gin.Context) {
+	// Get total VRAM used by Ollama
+	s.sched.loadedMu.Lock()
+	var totalOllamaVRAM uint64
+	for _, runner := range s.sched.loaded {
+		totalOllamaVRAM += runner.vramSize
+	}
+	s.sched.loadedMu.Unlock()
+
+	var resp api.UsageResponse
+
+	// Get GPU/device info
+	gpus := discover.GPUDevices(c.Request.Context(), nil)
+
+	// On Apple Silicon, use system memory instead of Metal's recommendedMaxWorkingSetSize
+	// because unified memory means GPU and CPU share the same physical RAM pool
+	var sysTotal, sysFree uint64
+	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+		sysInfo := discover.GetSystemInfo()
+		sysTotal = sysInfo.TotalMemory
+		sysFree = sysInfo.FreeMemory
+	}
+
+	for _, gpu := range gpus {
+		total := gpu.TotalMemory
+		free := gpu.FreeMemory
+
+		// On Apple Silicon, override with system memory values
+		if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" && sysTotal > 0 {
+			total = sysTotal
+			free = sysFree
+		}
+
+		used := total - free
+		ollamaUsed := min(totalOllamaVRAM, used)
+		otherUsed := used - ollamaUsed
+
+		// Use Description for Name (actual device name like "Apple M2 Max")
+		// Fall back to backend name if Description is empty
+		name := gpu.Description
+		if name == "" {
+			name = gpu.Name
+		}
+
+		resp.GPUs = append(resp.GPUs, api.GPUUsage{
+			Name:    name,
+			Backend: gpu.Library,
+			Total:   total,
+			Free:    free,
+			Used:    ollamaUsed,
+			Other:   otherUsed,
+		})
+	}
+
+	c.JSON(http.StatusOK, resp)
+}
+
 func toolCallId() string {
 	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
 	b := make([]byte, 8)
@@ -2032,7 +2147,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
 	} else if err != nil {
-		handleScheduleError(c, req.Model, err)
+		s.handleScheduleError(c, req.Model, err)
 		return
 	}
 
@@ -2180,6 +2295,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				res.DoneReason = r.DoneReason.String()
 				res.TotalDuration = time.Since(checkpointStart)
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+				s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, r.PromptEvalCount, r.EvalCount, len(req.Tools) > 0)
 			}
 
 			if builtinParser != nil {
@@ -2355,6 +2471,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			resp.Message.ToolCalls = toolCalls
 		}
 
+		s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, resp.PromptEvalCount, resp.EvalCount, len(toolCalls) > 0)
 		c.JSON(http.StatusOK, resp)
 		return
 	}
@@ -2362,7 +2479,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func handleScheduleError(c *gin.Context, name string, err error) {
+func (s *Server) handleScheduleError(c *gin.Context, name string, err error) {
+	s.usageError()
 	switch {
 	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
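Reviewer note, not part of the patch: the accounting in UsageHandler above splits each device's memory into Ollama's share and everything else, clamping with min so Used never exceeds the overall memory in use. A worked sketch with illustrative numbers:

package main

import "fmt"

func main() {
	// Illustrative numbers: a 24 GiB device with 8 GiB free, while the
	// scheduler reports 12 GiB of VRAM held by Ollama's loaded runners.
	total := uint64(24 << 30)
	free := uint64(8 << 30)
	totalOllamaVRAM := uint64(12 << 30)

	used := total - free                     // 16 GiB in use overall
	ollamaUsed := min(totalOllamaVRAM, used) // clamped so Used can't exceed overall usage
	otherUsed := used - ollamaUsed           // 4 GiB attributed to other processes

	fmt.Println(used>>30, ollamaUsed>>30, otherUsed>>30) // 16 12 4
}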
diff --git a/server/routes_usage_test.go b/server/routes_usage_test.go
new file mode 100644
index 000000000..0e45243fb
--- /dev/null
+++ b/server/routes_usage_test.go
@@ -0,0 +1,60 @@
+package server
+
+import (
+	"encoding/json"
+	"net/http"
+	"testing"
+
+	"github.com/gin-gonic/gin"
+
+	"github.com/ollama/ollama/api"
+)
+
+func TestUsageHandler(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	t.Run("empty server", func(t *testing.T) {
+		s := Server{
+			sched: &Scheduler{
+				loaded: make(map[string]*runnerRef),
+			},
+		}
+
+		w := createRequest(t, s.UsageHandler, nil)
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status code 200, actual %d", w.Code)
+		}
+
+		var resp api.UsageResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		// GPUs may or may not be present depending on system
+		// Just verify we can decode the response
+	})
+
+	t.Run("response structure", func(t *testing.T) {
+		s := Server{
+			sched: &Scheduler{
+				loaded: make(map[string]*runnerRef),
+			},
+		}
+
+		w := createRequest(t, s.UsageHandler, nil)
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status code 200, actual %d", w.Code)
+		}
+
+		// Verify we can decode the response as valid JSON
+		var resp map[string]any
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		// The response should be a valid object (not null)
+		if resp == nil {
+			t.Error("expected non-nil response")
+		}
+	})
+}
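Reviewer note, not part of the patch: the next two files implement the heartbeat. A sketch of producing the payload it sends, using the types defined in server/usage below. Note that per-model counts are kept locally for GetModelStats, but Payload has no per-model field, so model names are not part of the heartbeat:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/server/usage"
)

func main() {
	s := usage.New()
	s.Record(&usage.Request{
		Endpoint:         "chat",
		Model:            "llama3.2:3b",
		Architecture:     "llama",
		APIType:          usage.APITypeOllama,
		PromptTokens:     12,
		CompletionTokens: 34,
	})
	// Snapshot drains the counters, exactly as the reporter does.
	b, _ := json.MarshalIndent(s.Snapshot(), "", "  ")
	fmt.Println(string(b))
}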
diff --git a/server/usage/reporter.go b/server/usage/reporter.go
new file mode 100644
index 000000000..ff8d1817c
--- /dev/null
+++ b/server/usage/reporter.go
@@ -0,0 +1,65 @@
+package usage
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/ollama/ollama/version"
+)
+
+const (
+	reportTimeout = 10 * time.Second
+	usageURL      = "https://ollama.com/api/usage"
+)
+
+// HeartbeatResponse is the response from the heartbeat endpoint.
+type HeartbeatResponse struct {
+	UpdateVersion string `json:"update_version,omitempty"`
+}
+
+// UpdateAvailable returns the available update version, if any.
+func (t *Stats) UpdateAvailable() string {
+	if v := t.updateAvailable.Load(); v != nil {
+		return v.(string)
+	}
+	return ""
+}
+
+// sendHeartbeat sends usage stats and checks for updates.
+func (t *Stats) sendHeartbeat(payload *Payload) {
+	data, err := json.Marshal(payload)
+	if err != nil {
+		return
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), reportTimeout)
+	defer cancel()
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, usageURL, bytes.NewReader(data))
+	if err != nil {
+		return
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s", version.Version))
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return
+	}
+
+	var heartbeat HeartbeatResponse
+	if err := json.NewDecoder(resp.Body).Decode(&heartbeat); err != nil {
+		return
+	}
+
+	t.updateAvailable.Store(heartbeat.UpdateVersion)
+}
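Reviewer note, not part of the patch: UpdateAvailable is exported but nothing in this patch calls it yet. A hypothetical consumer could look like:

package main

import (
	"log/slog"

	"github.com/ollama/ollama/server/usage"
)

// logUpdate is a hypothetical consumer of UpdateAvailable; the name and
// wiring are illustrative only.
func logUpdate(s *usage.Stats) {
	if v := s.UpdateAvailable(); v != "" {
		slog.Info("update available", "version", v)
	}
}

func main() {
	logUpdate(usage.New())
}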
diff --git a/server/usage/source.go b/server/usage/source.go
new file mode 100644
index 000000000..591f1e242
--- /dev/null
+++ b/server/usage/source.go
@@ -0,0 +1,23 @@
+package usage
+
+import (
+	"strings"
+)
+
+// API type constants
+const (
+	APITypeOllama    = "ollama"
+	APITypeOpenAI    = "openai"
+	APITypeAnthropic = "anthropic"
+)
+
+// ClassifyAPIType determines the API type from the request path.
+func ClassifyAPIType(path string) string {
+	if strings.HasPrefix(path, "/v1/messages") {
+		return APITypeAnthropic
+	}
+	if strings.HasPrefix(path, "/v1/") {
+		return APITypeOpenAI
+	}
+	return APITypeOllama
+}
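Reviewer note, not part of the patch: classification depends on check order; the narrower /v1/messages prefix must be tested before the general /v1/ prefix. Expected results:

package main

import (
	"fmt"

	"github.com/ollama/ollama/server/usage"
)

func main() {
	fmt.Println(usage.ClassifyAPIType("/v1/messages"))         // anthropic
	fmt.Println(usage.ClassifyAPIType("/v1/chat/completions")) // openai
	fmt.Println(usage.ClassifyAPIType("/api/chat"))            // ollama
}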
diff --git a/server/usage/usage.go b/server/usage/usage.go
new file mode 100644
index 000000000..3b3304d61
--- /dev/null
+++ b/server/usage/usage.go
@@ -0,0 +1,324 @@
+// Package usage provides in-memory usage statistics collection and reporting.
+package usage
+
+import (
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/ollama/ollama/discover"
+	"github.com/ollama/ollama/version"
+)
+
+// Stats collects usage statistics in memory and reports them periodically.
+type Stats struct {
+	mu sync.RWMutex
+
+	// Atomic counters for hot path
+	requestsTotal    atomic.Int64
+	tokensPrompt     atomic.Int64
+	tokensCompletion atomic.Int64
+	errorsTotal     atomic.Int64
+
+	// Map-based counters (require lock)
+	endpoints     map[string]int64
+	architectures map[string]int64
+	apis          map[string]int64
+	models        map[string]*ModelStats // per-model stats
+
+	// Feature usage
+	toolCalls        atomic.Int64
+	structuredOutput atomic.Int64
+
+	// Update info (set by reporter after pinging update endpoint)
+	updateAvailable atomic.Value // string
+
+	// Reporter
+	stopCh   chan struct{}
+	doneCh   chan struct{}
+	interval time.Duration
+	endpoint string
+}
+
+// ModelStats tracks per-model usage statistics.
+type ModelStats struct {
+	Requests     int64
+	TokensInput  int64
+	TokensOutput int64
+}
+
+// Request contains the data to record for a single request.
+type Request struct {
+	Endpoint         string // "chat", "generate", "embed"
+	Model            string // model name (e.g., "llama3.2:3b")
+	Architecture     string // model architecture (e.g., "llama", "qwen2")
+	APIType          string // one of the APIType* constants; see ClassifyAPIType
+	PromptTokens     int
+	CompletionTokens int
+	UsedTools        bool
+	StructuredOutput bool
+}
+
+// SystemInfo contains hardware information to report.
+type SystemInfo struct {
+	OS       string `json:"os"`
+	Arch     string `json:"arch"`
+	CPUCores int    `json:"cpu_cores"`
+	RAMBytes uint64 `json:"ram_bytes"`
+	GPUs     []GPU  `json:"gpus,omitempty"`
+}
+
+// GPU contains information about a GPU.
+type GPU struct {
+	Name         string `json:"name"`
+	VRAMBytes    uint64 `json:"vram_bytes"`
+	ComputeMajor int    `json:"compute_major,omitempty"`
+	ComputeMinor int    `json:"compute_minor,omitempty"`
+	DriverMajor  int    `json:"driver_major,omitempty"`
+	DriverMinor  int    `json:"driver_minor,omitempty"`
+}
+
+// Payload is the data sent to the heartbeat endpoint.
+type Payload struct {
+	Version string     `json:"version"`
+	Time    time.Time  `json:"time"`
+	System  SystemInfo `json:"system"`
+
+	Totals struct {
+		Requests     int64 `json:"requests"`
+		Errors       int64 `json:"errors"`
+		InputTokens  int64 `json:"input_tokens"`
+		OutputTokens int64 `json:"output_tokens"`
+	} `json:"totals"`
+
+	Endpoints     map[string]int64 `json:"endpoints"`
+	Architectures map[string]int64 `json:"architectures"`
+	APIs          map[string]int64 `json:"apis"`
+
+	Features struct {
+		ToolCalls        int64 `json:"tool_calls"`
+		StructuredOutput int64 `json:"structured_output"`
+	} `json:"features"`
+}
+
+const (
+	defaultInterval = 1 * time.Hour
+)
+
+// New creates a new Stats instance.
+func New(opts ...Option) *Stats {
+	t := &Stats{
+		endpoints:     make(map[string]int64),
+		architectures: make(map[string]int64),
+		apis:          make(map[string]int64),
+		models:        make(map[string]*ModelStats),
+		stopCh:        make(chan struct{}),
+		doneCh:        make(chan struct{}),
+		interval:      defaultInterval,
+	}
+
+	for _, opt := range opts {
+		opt(t)
+	}
+
+	return t
+}
+
+// Option configures the Stats instance.
+type Option func(*Stats)
+
+// WithInterval sets the reporting interval.
+func WithInterval(d time.Duration) Option {
+	return func(t *Stats) {
+		t.interval = d
+	}
+}
+
+// Record records a request. This is the hot path and should be fast.
+func (t *Stats) Record(r *Request) {
+	t.requestsTotal.Add(1)
+	t.tokensPrompt.Add(int64(r.PromptTokens))
+	t.tokensCompletion.Add(int64(r.CompletionTokens))
+
+	if r.UsedTools {
+		t.toolCalls.Add(1)
+	}
+	if r.StructuredOutput {
+		t.structuredOutput.Add(1)
+	}
+
+	t.mu.Lock()
+	t.endpoints[r.Endpoint]++
+	t.architectures[r.Architecture]++
+	t.apis[r.APIType]++
+
+	// Track per-model stats
+	if r.Model != "" {
+		if t.models[r.Model] == nil {
+			t.models[r.Model] = &ModelStats{}
+		}
+		t.models[r.Model].Requests++
+		t.models[r.Model].TokensInput += int64(r.PromptTokens)
+		t.models[r.Model].TokensOutput += int64(r.CompletionTokens)
+	}
+	t.mu.Unlock()
+}
+
+// RecordError records a failed request.
+func (t *Stats) RecordError() {
+	t.errorsTotal.Add(1)
+}
+
+// GetModelStats returns a copy of per-model statistics.
+func (t *Stats) GetModelStats() map[string]*ModelStats {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+
+	result := make(map[string]*ModelStats, len(t.models))
+	for k, v := range t.models {
+		result[k] = &ModelStats{
+			Requests:     v.Requests,
+			TokensInput:  v.TokensInput,
+			TokensOutput: v.TokensOutput,
+		}
+	}
+	return result
+}
+
+// View returns current stats without resetting counters.
+func (t *Stats) View() *Payload {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+
+	now := time.Now()
+
+	// Copy maps
+	endpoints := make(map[string]int64, len(t.endpoints))
+	for k, v := range t.endpoints {
+		endpoints[k] = v
+	}
+	architectures := make(map[string]int64, len(t.architectures))
+	for k, v := range t.architectures {
+		architectures[k] = v
+	}
+	apis := make(map[string]int64, len(t.apis))
+	for k, v := range t.apis {
+		apis[k] = v
+	}
+
+	p := &Payload{
+		Version:       version.Version,
+		Time:          now,
+		System:        getSystemInfo(),
+		Endpoints:     endpoints,
+		Architectures: architectures,
+		APIs:          apis,
+	}
+
+	p.Totals.Requests = t.requestsTotal.Load()
+	p.Totals.Errors = t.errorsTotal.Load()
+	p.Totals.InputTokens = t.tokensPrompt.Load()
+	p.Totals.OutputTokens = t.tokensCompletion.Load()
+	p.Features.ToolCalls = t.toolCalls.Load()
+	p.Features.StructuredOutput = t.structuredOutput.Load()
+
+	return p
+}
+
+// Snapshot returns current stats and resets counters.
+func (t *Stats) Snapshot() *Payload {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	now := time.Now()
+	p := &Payload{
+		Version:       version.Version,
+		Time:          now,
+		System:        getSystemInfo(),
+		Endpoints:     t.endpoints,
+		Architectures: t.architectures,
+		APIs:          t.apis,
+	}
+
+	p.Totals.Requests = t.requestsTotal.Swap(0)
+	p.Totals.Errors = t.errorsTotal.Swap(0)
+	p.Totals.InputTokens = t.tokensPrompt.Swap(0)
+	p.Totals.OutputTokens = t.tokensCompletion.Swap(0)
+	p.Features.ToolCalls = t.toolCalls.Swap(0)
+	p.Features.StructuredOutput = t.structuredOutput.Swap(0)
+
+	// Reset maps
+	t.endpoints = make(map[string]int64)
+	t.architectures = make(map[string]int64)
+	t.apis = make(map[string]int64)
+
+	return p
+}
+
+// getSystemInfo collects hardware information.
+func getSystemInfo() SystemInfo {
+	info := SystemInfo{
+		OS:   runtime.GOOS,
+		Arch: runtime.GOARCH,
+	}
+
+	// Get CPU and memory info
+	sysInfo := discover.GetSystemInfo()
+	info.CPUCores = sysInfo.ThreadCount
+	info.RAMBytes = sysInfo.TotalMemory
+
+	// Get GPU info
+	gpus := getGPUInfo()
+	info.GPUs = gpus
+
+	return info
+}
+
+// GPUInfoFunc is a function that returns GPU information.
+// It's set by the server package after GPU discovery.
+var GPUInfoFunc func() []GPU
+
+// getGPUInfo collects GPU information.
+func getGPUInfo() []GPU {
+	if GPUInfoFunc != nil {
+		return GPUInfoFunc()
+	}
+	return nil
+}
+
+// Start begins the periodic reporting goroutine.
+func (t *Stats) Start() {
+	go t.reportLoop()
+}
+
+// Stop stops reporting and waits for the final report.
+func (t *Stats) Stop() {
+	close(t.stopCh)
+	<-t.doneCh
+}
+
+// reportLoop runs the periodic reporting.
+func (t *Stats) reportLoop() {
+	defer close(t.doneCh)
+
+	ticker := time.NewTicker(t.interval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			t.report()
+		case <-t.stopCh:
+			// Send final report before stopping
+			t.report()
+			return
+		}
+	}
+}
+
+// report sends usage stats and checks for updates.
+func (t *Stats) report() {
+	payload := t.Snapshot()
+	t.sendHeartbeat(payload)
+}
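Reviewer note, not part of the patch: a sketch of the Stats lifecycle as wired up in Serve above, using the WithInterval option instead of the 1h default. Note that Stop triggers one final heartbeat to ollama.com before returning:

package main

import (
	"time"

	"github.com/ollama/ollama/server/usage"
)

func main() {
	s := usage.New(usage.WithInterval(10 * time.Minute))
	s.Start() // launches reportLoop in a goroutine

	s.Record(&usage.Request{
		Endpoint:     "generate",
		Model:        "llama3.2:3b",
		Architecture: "llama",
		APIType:      usage.APITypeOllama,
	})

	// Stop sends a final report and blocks until it has been attempted.
	s.Stop()
}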
diff --git a/server/usage/usage_test.go b/server/usage/usage_test.go
new file mode 100644
index 000000000..96898143f
--- /dev/null
+++ b/server/usage/usage_test.go
@@ -0,0 +1,194 @@
+package usage
+
+import (
+	"testing"
+)
+
+func TestNew(t *testing.T) {
+	stats := New()
+	if stats == nil {
+		t.Fatal("New() returned nil")
+	}
+}
+
+func TestRecord(t *testing.T) {
+	stats := New()
+
+	stats.Record(&Request{
+		Model:            "llama3:8b",
+		Endpoint:         "chat",
+		Architecture:     "llama",
+		APIType:          APITypeOllama,
+		PromptTokens:     100,
+		CompletionTokens: 50,
+		UsedTools:        true,
+		StructuredOutput: false,
+	})
+
+	// Check totals
+	payload := stats.View()
+	if payload.Totals.Requests != 1 {
+		t.Errorf("expected 1 request, got %d", payload.Totals.Requests)
+	}
+	if payload.Totals.InputTokens != 100 {
+		t.Errorf("expected 100 prompt tokens, got %d", payload.Totals.InputTokens)
+	}
+	if payload.Totals.OutputTokens != 50 {
+		t.Errorf("expected 50 completion tokens, got %d", payload.Totals.OutputTokens)
+	}
+	if payload.Features.ToolCalls != 1 {
+		t.Errorf("expected 1 tool call, got %d", payload.Features.ToolCalls)
+	}
+	if payload.Features.StructuredOutput != 0 {
+		t.Errorf("expected 0 structured outputs, got %d", payload.Features.StructuredOutput)
+	}
+}
+
+func TestGetModelStats(t *testing.T) {
+	stats := New()
+
+	// Record requests for multiple models
+	stats.Record(&Request{
+		Model:            "llama3:8b",
+		PromptTokens:     100,
+		CompletionTokens: 50,
+	})
+	stats.Record(&Request{
+		Model:            "llama3:8b",
+		PromptTokens:     200,
+		CompletionTokens: 100,
+	})
+	stats.Record(&Request{
+		Model:            "mistral:7b",
+		PromptTokens:     50,
+		CompletionTokens: 25,
+	})
+
+	modelStats := stats.GetModelStats()
+
+	// Check llama3:8b stats
+	llama := modelStats["llama3:8b"]
+	if llama == nil {
+		t.Fatal("expected llama3:8b stats")
+	}
+	if llama.Requests != 2 {
+		t.Errorf("expected 2 requests for llama3:8b, got %d", llama.Requests)
+	}
+	if llama.TokensInput != 300 {
+		t.Errorf("expected 300 input tokens for llama3:8b, got %d", llama.TokensInput)
+	}
+	if llama.TokensOutput != 150 {
+		t.Errorf("expected 150 output tokens for llama3:8b, got %d", llama.TokensOutput)
+	}
+
+	// Check mistral:7b stats
+	mistral := modelStats["mistral:7b"]
+	if mistral == nil {
+		t.Fatal("expected mistral:7b stats")
+	}
+	if mistral.Requests != 1 {
+		t.Errorf("expected 1 request for mistral:7b, got %d", mistral.Requests)
+	}
+	if mistral.TokensInput != 50 {
+		t.Errorf("expected 50 input tokens for mistral:7b, got %d", mistral.TokensInput)
+	}
+	if mistral.TokensOutput != 25 {
+		t.Errorf("expected 25 output tokens for mistral:7b, got %d", mistral.TokensOutput)
+	}
+}
+
+func TestRecordError(t *testing.T) {
+	stats := New()
+
+	stats.RecordError()
+	stats.RecordError()
+
+	payload := stats.View()
+	if payload.Totals.Errors != 2 {
+		t.Errorf("expected 2 errors, got %d", payload.Totals.Errors)
+	}
+}
+
+func TestView(t *testing.T) {
+	stats := New()
+
+	stats.Record(&Request{
+		Model:        "llama3:8b",
+		Endpoint:     "chat",
+		Architecture: "llama",
+		APIType:      APITypeOllama,
+	})
+
+	// First view
+	_ = stats.View()
+
+	// View should not reset counters
+	payload := stats.View()
+	if payload.Totals.Requests != 1 {
+		t.Errorf("View should not reset counters, expected 1 request, got %d", payload.Totals.Requests)
+	}
+}
+
+func TestSnapshot(t *testing.T) {
+	stats := New()
+
+	stats.Record(&Request{
+		Model:            "llama3:8b",
+		Endpoint:         "chat",
+		PromptTokens:     100,
+		CompletionTokens: 50,
+	})
+
+	// Snapshot should return data and reset counters
+	snapshot := stats.Snapshot()
+	if snapshot.Totals.Requests != 1 {
+		t.Errorf("expected 1 request in snapshot, got %d", snapshot.Totals.Requests)
+	}
+
+	// After snapshot, counters should be reset
+	payload2 := stats.View()
+	if payload2.Totals.Requests != 0 {
+		t.Errorf("expected 0 requests after snapshot, got %d", payload2.Totals.Requests)
+	}
+}
+
+func TestConcurrentAccess(t *testing.T) {
+	stats := New()
+
+	done := make(chan bool)
+
+	// Concurrent writes
+	for i := 0; i < 10; i++ {
+		go func() {
+			for j := 0; j < 100; j++ {
+				stats.Record(&Request{
+					Model:            "llama3:8b",
+					PromptTokens:     10,
+					CompletionTokens: 5,
+				})
+			}
+			done <- true
+		}()
+	}
+
+	// Concurrent reads
+	for i := 0; i < 5; i++ {
+		go func() {
+			for j := 0; j < 100; j++ {
+				_ = stats.View()
+				_ = stats.GetModelStats()
+			}
+			done <- true
+		}()
+	}
+
+	// Wait for all goroutines
+	for i := 0; i < 15; i++ {
+		<-done
+	}
+
+	payload := stats.View()
+	if payload.Totals.Requests != 1000 {
+		t.Errorf("expected 1000 requests, got %d", payload.Totals.Requests)
+	}
+}