Compare commits

...

1 Commit

Author SHA1 Message Date
Jesse Gross
4d5ff25724 mlxrunner: Report actual memory usage from runner
The MLX runner previously reported a static VRAM estimate that was
computed at load time and consisted only of the weights. This is
strictly less than the actual memory usage, as it does not include
the KV cache or compute graph.
2026-02-25 15:06:37 -08:00
7 changed files with 56 additions and 63 deletions

View File

@@ -74,8 +74,7 @@ type LlamaServer interface {
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
-	VRAMSize() uint64 // Total VRAM across all GPUs
-	TotalSize() uint64
+	MemorySize() (total, vram uint64)
 	VRAMByGPU(id ml.DeviceID) uint64
 	Pid() int
 	GetPort() int
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
 	// Windows CUDA should not use mmap for best performance
 	// Linux with a model larger than free space, mmap leads to thrashing
 	// For CPU loads we want the memory to be allocated, not FS cache
+	totalSize, _ := s.MemorySize()
 	if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
-		(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
+		(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
		(len(gpus) == 0 && s.options.UseMMap == nil) ||
		(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
		(s.options.UseMMap != nil && !*s.options.UseMMap) {
@@ -1848,17 +1848,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
 	return nil
 }

-func (s *llmServer) VRAMSize() uint64 {
+func (s *llmServer) MemorySize() (total, vram uint64) {
 	if s.mem == nil {
-		return 0
+		return 0, 0
 	}

-	var mem uint64
 	for _, g := range s.mem.GPUs {
-		mem += g.Size()
+		vram += g.Size()
 	}
+	total = s.mem.InputWeights + s.mem.CPU.Size() + vram

 	// Some elements are always on CPU. However, if we have allocated all layers
 	// on the GPU then include the CPU components as well, to represent complete offloading.
 	noCPULayers := true
@@ -1869,25 +1869,11 @@ func (s *llmServer) MemorySize() (total, vram uint64) {
 		}
 	}
 	if noCPULayers {
-		mem += s.mem.InputWeights
-		mem += s.mem.CPU.Graph
+		vram += s.mem.InputWeights
+		vram += s.mem.CPU.Graph
 	}

-	return mem
-}
-
-func (s *llmServer) TotalSize() uint64 {
-	if s.mem == nil {
-		return 0
-	}
-
-	mem := s.mem.InputWeights
-	mem += s.mem.CPU.Size()
-	for _, g := range s.mem.GPUs {
-		mem += g.Size()
-	}
-	return mem
+	return total, vram
 }

 func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {

View File

@@ -1951,6 +1951,9 @@ func (s *Server) PsHandler(c *gin.Context) {
 		}

 		if v.llama != nil {
 			mr.ContextLength = v.llama.ContextLength()
+			total, vram := v.llama.MemorySize()
+			mr.Size = int64(total)
+			mr.SizeVRAM = int64(vram)
 		}

 		// The scheduler waits to set expiresAt, so if a model is loading it's
 		// possible that it will be set to the unix epoch. For those cases, just

View File

@@ -536,6 +536,7 @@ iGPUScan:
 		}
 	}

+	totalSize, vramSize := llama.MemorySize()
 	runner := &runnerRef{
 		model:           req.model,
 		modelPath:       req.model.ModelPath,
@@ -545,8 +546,8 @@ iGPUScan:
 		sessionDuration: sessionDuration,
 		gpus:            gpuIDs,
 		discreteGPUs:    discreteGPUs,
-		vramSize:        llama.VRAMSize(),
-		totalSize:       llama.TotalSize(),
+		totalSize:       totalSize,
+		vramSize:        vramSize,
 		loading:         true,
 		pid:             llama.Pid(),
 	}
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
 		sessionDuration = req.sessionDuration.Duration
 	}

+	totalSize, vramSize := server.MemorySize()
 	runner := &runnerRef{
 		model:           req.model,
 		modelPath:       req.model.ModelPath,
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
 		loading:         false,
 		isImagegen:      isImagegen,
 		sessionDuration: sessionDuration,
-		totalSize:       server.TotalSize(),
-		vramSize:        server.VRAMSize(),
+		totalSize:       totalSize,
+		vramSize:        vramSize,
 	}

 	s.loadedMu.Lock()

View File

@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
 	s.closeCalled = true
 	return s.closeResp
 }
-func (s *mockLlm) VRAMSize() uint64                { return s.vramSize }
-func (s *mockLlm) TotalSize() uint64               { return s.totalSize }
+func (s *mockLlm) MemorySize() (uint64, uint64)    { return s.totalSize, s.vramSize }
 func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
 func (s *mockLlm) Pid() int                        { return -1 }
 func (s *mockLlm) GetPort() int                    { return -1 }

View File

@@ -374,14 +374,9 @@ func (s *Server) Close() error {
 	return nil
 }

-// VRAMSize returns the estimated VRAM usage.
-func (s *Server) VRAMSize() uint64 {
-	return s.vramSize
-}
-
-// TotalSize returns the total memory usage.
-func (s *Server) TotalSize() uint64 {
-	return s.vramSize
+// MemorySize returns the total and VRAM memory usage.
+func (s *Server) MemorySize() (total, vram uint64) {
+	return s.vramSize, s.vramSize
 }

 // VRAMByGPU returns VRAM usage for a specific GPU.

View File

@@ -24,14 +24,13 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/x/imagegen"
-	"github.com/ollama/ollama/x/imagegen/manifest"
 )

 // Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
 type Client struct {
 	port      int
 	modelName string
-	vramSize  uint64
+	memory    uint
 	done      chan error
 	client    *http.Client
 	lastErr   string
@@ -98,18 +97,9 @@ func NewClient(modelName string) (*Client, error) {
 		slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
 	}

-	// Estimate VRAM based on tensor size from manifest
-	var vramSize uint64
-	if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
-		vramSize = uint64(modelManifest.TotalTensorSize())
-	} else {
-		vramSize = 8 * 1024 * 1024 * 1024
-	}
-
 	c := &Client{
 		port:      port,
 		modelName: modelName,
-		vramSize:  vramSize,
 		done:      make(chan error, 1),
 		client:    &http.Client{Timeout: 10 * time.Minute},
 		cmd:       cmd,
@@ -347,9 +337,15 @@ func (c *Client) Pid() int {
 	return -1
 }

+type statusResponse struct {
+	Status   int
+	Progress int
+	Memory   uint
+}
+
 // Ping implements llm.LlamaServer.
 func (c *Client) Ping(ctx context.Context) error {
-	reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
+	reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
 	req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
 	if err != nil {
 		return err
@@ -362,6 +358,12 @@ func (c *Client) Ping(ctx context.Context) error {
 	if resp.StatusCode != http.StatusOK {
 		return fmt.Errorf("health check failed: %d", resp.StatusCode)
 	}
+	var status statusResponse
+	if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
+		return err
+	}
+	c.memory = status.Memory
 	return nil
 }
@@ -388,19 +390,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
 	return tokens, nil
 }

-// TotalSize implements llm.LlamaServer.
-func (c *Client) TotalSize() uint64 {
-	return c.vramSize
+func (c *Client) currentMemory() uint64 {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	if err := c.Ping(ctx); err != nil {
+		slog.Warn("failed to get current memory", "error", err)
+	}
+	return uint64(c.memory)
+}
+
+// MemorySize implements llm.LlamaServer.
+func (c *Client) MemorySize() (total, vram uint64) {
+	mem := c.currentMemory()
+	return mem, mem
 }

 // VRAMByGPU implements llm.LlamaServer.
 func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
-	return c.vramSize
-}
-
-// VRAMSize implements llm.LlamaServer.
-func (c *Client) VRAMSize() uint64 {
-	return c.vramSize
+	return c.currentMemory()
 }

 // WaitUntilRunning implements llm.LlamaServer.
// WaitUntilRunning implements llm.LlamaServer. // WaitUntilRunning implements llm.LlamaServer.

View File

@@ -50,9 +50,10 @@ func Execute(args []string) error {
 	mux := http.NewServeMux()
 	mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
-		if err := json.NewEncoder(w).Encode(map[string]any{
-			"status":   0,
-			"progress": 100,
+		if err := json.NewEncoder(w).Encode(statusResponse{
+			Status:   0,
+			Progress: 100,
+			Memory:   uint(mlx.ActiveMemory() + mlx.CacheMemory()),
 		}); err != nil {
 			slog.Error("Failed to encode response", "error", err)
 			http.Error(w, "Internal Server Error", http.StatusInternalServerError)