sched: Model eviction for MLX

MLX runners (image generation and LLM) previously bypassed the
scheduler's standard load path via a separate loadMLX method. This meant
they skipped VRAM fitting checks and couldn't participate in model
eviction.

Now all model types flow through the same load function. Model eviction
for MLX is based on weight size alone, since the KV cache and compute
graph are sized dynamically.
This means that eviction does not take worst-case memory usage into
account, and models can still compete for memory, but it is nonetheless
a significant improvement.
This commit is contained in:
Jesse Gross
2026-03-02 15:27:34 -08:00
parent bcf6d55b54
commit bbbad97686
8 changed files with 291 additions and 290 deletions

View File

@@ -40,7 +40,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading // add small delay to simulate loading
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{
@@ -234,7 +234,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading // add small delay to simulate loading
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{

View File

@@ -45,7 +45,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{
llama: &mock, llama: &mock,
@@ -230,7 +230,7 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{
llama: &mock, llama: &mock,

View File

@@ -187,7 +187,7 @@ func TestGenerateChat(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading // add small delay to simulate loading
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{
@@ -904,7 +904,7 @@ func TestGenerate(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading // add small delay to simulate loading
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{
@@ -1388,7 +1388,7 @@ func TestGenerateLogprobs(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{llama: mock} req.successCh <- &runnerRef{llama: mock}
return false return false
}, },
@@ -1568,7 +1568,7 @@ func TestChatLogprobs(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{llama: mock} req.successCh <- &runnerRef{llama: mock}
return false return false
}, },
@@ -1678,7 +1678,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{llama: mock} req.successCh <- &runnerRef{llama: mock}
return false return false
@@ -2123,7 +2123,7 @@ func TestGenerateUnload(t *testing.T) {
newServerFn: newMockServer(&mockRunner{}), newServerFn: newMockServer(&mockRunner{}),
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFnCalled = true loadFnCalled = true
req.successCh <- &runnerRef{llama: &mockRunner{}} req.successCh <- &runnerRef{llama: &mockRunner{}}
return false return false
@@ -2225,7 +2225,7 @@ func TestGenerateWithImages(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{
llama: &mock, llama: &mock,

View File

@@ -265,7 +265,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 100 * time.Millisecond, waitForRecovery: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{
llama: &mock, llama: &mock,
} }
@@ -416,7 +416,7 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 100 * time.Millisecond, waitForRecovery: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{
llama: &mock, llama: &mock,
} }
@@ -598,7 +598,7 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
getGpuFn: getGpuFn, getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn, getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond, waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{ req.successCh <- &runnerRef{
llama: &mock, llama: &mock,
} }

View File

@@ -51,7 +51,7 @@ type Scheduler struct {
activeLoading llm.LlamaServer activeLoading llm.LlamaServer
loaded map[string]*runnerRef loaded map[string]*runnerRef
loadFn func(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool loadFn func(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool
newServerFn func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) newServerFn func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
getGpuFn func(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo getGpuFn func(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo
getSystemInfoFn func() ml.SystemInfo getSystemInfoFn func() ml.SystemInfo
@@ -220,33 +220,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus)) slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
} }
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadMLX(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.IsMLX() {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
break
}
continue
}
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
if err != nil {
pending.errCh <- err
break
}
// Update free memory from currently loaded models // Update free memory from currently loaded models
logutil.Trace("updating free space", "gpu_count", len(gpus), "model", pending.model.ModelPath) logutil.Trace("updating free space", "gpu_count", len(gpus), "model", pending.model.ModelPath)
s.updateFreeSpace(gpus) s.updateFreeSpace(gpus)
@@ -254,14 +227,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
if loadedCount == 0 { if loadedCount == 0 {
// No models loaded. Load the model but prefer the best fit. // No models loaded. Load the model but prefer the best fit.
slog.Debug("loading first model", "model", pending.model.ModelPath) slog.Debug("loading first model", "model", pending.model.ModelPath)
s.loadFn(pending, ggml, systemInfo, gpus, false) s.loadFn(pending, systemInfo, gpus, false)
break break
} }
// More than one loaded model, so we have to see if the // More than one loaded model, so we have to see if the
// new one fits // new one fits
logutil.Trace("loading additional model", "model", pending.model.ModelPath) logutil.Trace("loading additional model", "model", pending.model.ModelPath)
needEvict := s.loadFn(pending, ggml, systemInfo, gpus, true) needEvict := s.loadFn(pending, systemInfo, gpus, true)
if !needEvict { if !needEvict {
slog.Debug("new model fits with existing models, loading") slog.Debug("new model fits with existing models, loading")
break break
@@ -435,7 +408,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs // load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
// (if any). Returns whether the scheduler needs to evict a model to make this one fit. // (if any). Returns whether the scheduler needs to evict a model to make this one fit.
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool { func (s *Scheduler) load(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool {
numParallel := max(int(envconfig.NumParallel()), 1) numParallel := max(int(envconfig.NumParallel()), 1)
// Embedding models should always be loaded with parallel=1 // Embedding models should always be loaded with parallel=1
@@ -460,15 +433,33 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
if llama == nil { if llama == nil {
var err error var err error
llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel) if !req.model.IsMLX() {
if err != nil { f, loadErr := llm.LoadModel(req.model.ModelPath, 1024)
// some older models are not compatible with newer versions of llama.cpp if loadErr != nil {
// show a generalized compatibility error until there is a better way to slog.Info("failed to load model metadata", "model", req.model.ModelPath, "error", loadErr)
// check for model compatibility req.errCh <- loadErr
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") { s.loadedMu.Unlock()
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) return false
} }
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err) llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
}
}
} else {
modelName := req.model.ShortName
if slices.Contains(req.model.Config.Capabilities, "image") {
llama, err = imagegen.NewServer(modelName)
} else {
llama, err = mlxrunner.NewClient(modelName)
}
}
if err != nil {
slog.Info("failed to create server", "model", req.model.ShortName, "error", err)
req.errCh <- err req.errCh <- err
s.loadedMu.Unlock() s.loadedMu.Unlock()
return false return false
@@ -476,8 +467,12 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
s.activeLoading = llama s.activeLoading = llama
} else { } else {
if s.activeLoading.ModelPath() != req.model.ModelPath { wantPath := req.model.ModelPath
panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), req.model.ModelPath)) if wantPath == "" {
wantPath = req.model.ShortName
}
if s.activeLoading.ModelPath() != wantPath {
panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), wantPath))
} }
} }
@@ -544,6 +539,7 @@ iGPUScan:
sessionDuration: sessionDuration, sessionDuration: sessionDuration,
gpus: gpuIDs, gpus: gpuIDs,
discreteGPUs: discreteGPUs, discreteGPUs: discreteGPUs,
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
totalSize: totalSize, totalSize: totalSize,
vramSize: vramSize, vramSize: vramSize,
loading: true, loading: true,
@@ -591,59 +587,6 @@ iGPUScan:
return false return false
} }
// loadMLX loads an experimental safetensors model using MLX runners.
// Image models use x/imagegen; LLM models use x/mlxrunner.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
modelName := req.model.ShortName
var server llm.LlamaServer
var err error
if slices.Contains(req.model.Config.Capabilities, "image") {
server, err = imagegen.NewServer(modelName)
} else {
server, err = mlxrunner.NewClient(modelName)
}
if err != nil {
req.errCh <- err
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
totalSize, vramSize := server.MemorySize()
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: server,
Options: &req.opts,
loading: false,
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
sessionDuration: sessionDuration,
totalSize: totalSize,
vramSize: vramSize,
}
s.loadedMu.Lock()
s.loaded[runner.modelKey] = runner
s.loadedMu.Unlock()
// Set up expiration timer
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
return true
}
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) { func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
if len(allGpus) == 0 { if len(allGpus) == 0 {
return return

View File

@@ -39,10 +39,25 @@ func TestSchedLoad(t *testing.T) {
defer done() defer done()
s := InitScheduler(ctx) s := InitScheduler(ctx)
s.waitForRecovery = 10 * time.Millisecond s.waitForRecovery = 10 * time.Millisecond
var f *ggml.GGML // value not used in tests
modelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
})
req := &LlmRequest{ req := &LlmRequest{
ctx: ctx, ctx: ctx,
model: &Model{ModelPath: "foo"}, model: &Model{ModelPath: modelPath},
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1), errCh: make(chan error, 1),
@@ -54,7 +69,7 @@ func TestSchedLoad(t *testing.T) {
} }
gpus := []ml.DeviceInfo{} gpus := []ml.DeviceInfo{}
systemInfo := ml.SystemInfo{} systemInfo := ml.SystemInfo{}
s.load(req, f, systemInfo, gpus, false) s.load(req, systemInfo, gpus, false)
require.Empty(t, req.successCh) require.Empty(t, req.successCh)
require.Len(t, req.errCh, 1) require.Len(t, req.errCh, 1)
s.loadedMu.Lock() s.loadedMu.Lock()
@@ -68,7 +83,7 @@ func TestSchedLoad(t *testing.T) {
server.modelPath = model server.modelPath = model
return server, nil return server, nil
} }
s.load(req, f, systemInfo, gpus, false) s.load(req, systemInfo, gpus, false)
select { select {
case err := <-req.errCh: case err := <-req.errCh:
require.NoError(t, err) require.NoError(t, err)
@@ -80,9 +95,24 @@ func TestSchedLoad(t *testing.T) {
s.loadedMu.Unlock() s.loadedMu.Unlock()
} }
req.model.ModelPath = "dummy_model_path" modelPath2, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
})
req.model.ModelPath = modelPath2
server.waitResp = errors.New("wait failure") server.waitResp = errors.New("wait failure")
s.load(req, f, systemInfo, gpus, false) s.load(req, systemInfo, gpus, false)
select { select {
case err := <-req.errCh: case err := <-req.errCh:
require.Contains(t, err.Error(), "wait failure") require.Contains(t, err.Error(), "wait failure")
@@ -90,7 +120,7 @@ func TestSchedLoad(t *testing.T) {
t.Fatalf("unexpected success %v", resp) t.Fatalf("unexpected success %v", resp)
} }
s.loadedMu.Lock() s.loadedMu.Lock()
runner := s.loaded["dummy_model_path"] runner := s.loaded[modelPath2]
s.loadedMu.Unlock() s.loadedMu.Unlock()
require.NotNil(t, runner) require.NotNil(t, runner)
require.Equal(t, uint(0), runner.refCount) require.Equal(t, uint(0), runner.refCount)
@@ -103,7 +133,6 @@ type reqBundle struct {
ctxDone func() ctxDone func()
srv *mockLlm srv *mockLlm
req *LlmRequest req *LlmRequest
f *ggml.GGML
} }
func (scenario *reqBundle) newServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { func (scenario *reqBundle) newServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
@@ -132,11 +161,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vra
}) })
model := &Model{Name: modelName, ModelPath: p} model := &Model{Name: modelName, ModelPath: p}
f, err := llm.LoadModel(model.ModelPath, 0)
if err != nil {
t.Fatal(err)
}
b.f = f
if duration == nil { if duration == nil {
duration = &api.Duration{Duration: 5 * time.Millisecond} duration = &api.Duration{Duration: 5 * time.Millisecond}
} }
@@ -178,7 +202,6 @@ func TestSchedRequestsSameModelSameRequest(t *testing.T) {
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil) a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil)
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}, nil) b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}, nil)
b.req.model = a.req.model b.req.model = a.req.model
b.f = a.f
s.newServerFn = a.newServer s.newServerFn = a.newServer
slog.Info("a") slog.Info("a")
@@ -223,7 +246,6 @@ func TestSchedRequestsSimpleReloadSameModel(t *testing.T) {
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}, nil) b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}, nil)
tmpModel := *a.req.model tmpModel := *a.req.model
b.req.model = &tmpModel b.req.model = &tmpModel
b.f = a.f
s.newServerFn = a.newServer s.newServerFn = a.newServer
slog.Info("a") slog.Info("a")
@@ -518,16 +540,31 @@ func TestSchedExpireRunner(t *testing.T) {
defer done() defer done()
s := InitScheduler(ctx) s := InitScheduler(ctx)
s.waitForRecovery = 10 * time.Millisecond s.waitForRecovery = 10 * time.Millisecond
modelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
})
req := &LlmRequest{ req := &LlmRequest{
ctx: ctx, ctx: ctx,
model: &Model{ModelPath: "foo"}, model: &Model{ModelPath: modelPath},
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1), errCh: make(chan error, 1),
sessionDuration: &api.Duration{Duration: 2 * time.Minute}, sessionDuration: &api.Duration{Duration: 2 * time.Minute},
} }
var f *ggml.GGML
gpus := []ml.DeviceInfo{} gpus := []ml.DeviceInfo{}
systemInfo := ml.SystemInfo{} systemInfo := ml.SystemInfo{}
server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}} server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}}
@@ -535,7 +572,7 @@ func TestSchedExpireRunner(t *testing.T) {
server.modelPath = model server.modelPath = model
return server, nil return server, nil
} }
s.load(req, f, systemInfo, gpus, false) s.load(req, systemInfo, gpus, false)
select { select {
case err := <-req.errCh: case err := <-req.errCh:
@@ -550,7 +587,7 @@ func TestSchedExpireRunner(t *testing.T) {
s.loadedMu.Unlock() s.loadedMu.Unlock()
} }
s.expireRunner(&Model{ModelPath: "foo"}) s.expireRunner(&Model{ModelPath: modelPath})
s.finishedReqCh <- req s.finishedReqCh <- req
s.processCompleted(ctx) s.processCompleted(ctx)

View File

@@ -22,6 +22,7 @@ import (
"time" "time"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen/manifest" "github.com/ollama/ollama/x/imagegen/manifest"
@@ -43,13 +44,52 @@ type Server struct {
lastErrLock sync.Mutex lastErrLock sync.Mutex
} }
// NewServer spawns a new MLX runner subprocess and waits until it's ready. // NewServer prepares a new MLX runner server for image generation models.
// The subprocess is not started until Load() is called.
func NewServer(modelName string) (*Server, error) { func NewServer(modelName string) (*Server, error) {
// Validate platform support before attempting to start // Validate platform support before attempting to start
if err := CheckPlatformSupport(); err != nil { if err := CheckPlatformSupport(); err != nil {
return nil, err return nil, err
} }
return &Server{
modelName: modelName,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
}, nil
}
// ModelPath returns the path to the model.
func (s *Server) ModelPath() string {
return s.modelName
}
// Load checks whether the model fits in GPU memory and starts the subprocess.
func (s *Server) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
// Estimate VRAM based on tensor size from manifest
if modelManifest, err := manifest.LoadManifest(s.modelName); err == nil {
s.vramSize = uint64(modelManifest.TotalTensorSize())
} else {
s.vramSize = 8 * 1024 * 1024 * 1024
}
if len(gpus) > 0 {
available := gpus[0].FreeMemory
overhead := gpus[0].MinimumMemory() + envconfig.GpuOverhead()
if available > overhead {
available -= overhead
} else {
available = 0
}
if s.vramSize > available {
if requireFull {
return nil, llm.ErrLoadRequiredFull
}
return nil, fmt.Errorf("model requires %s but only %s are available (after %s overhead)", format.HumanBytes2(s.vramSize), format.HumanBytes2(available), format.HumanBytes2(overhead))
}
}
// Find a free port // Find a free port
port := 0 port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@@ -61,6 +101,7 @@ func NewServer(modelName string) (*Server, error) {
if port == 0 { if port == 0 {
port = rand.Intn(65535-49152) + 49152 port = rand.Intn(65535-49152) + 49152
} }
s.port = port
// Get the current executable path (we use the same binary with runner subcommand) // Get the current executable path (we use the same binary with runner subcommand)
exe, err := os.Executable() exe, err := os.Executable()
@@ -72,7 +113,7 @@ func NewServer(modelName string) (*Server, error) {
} }
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port> // Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", modelName, "--port", strconv.Itoa(port)) cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ() cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories // On Linux, set LD_LIBRARY_PATH to include MLX library directories
@@ -105,23 +146,7 @@ func NewServer(modelName string) (*Server, error) {
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
} }
// Estimate VRAM based on tensor size from manifest s.cmd = cmd
var vramSize uint64
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
// Fallback: default to 8GB if manifest can't be loaded
vramSize = 8 * 1024 * 1024 * 1024
}
s := &Server{
cmd: cmd,
port: port,
modelName: modelName,
vramSize: vramSize,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
}
// Forward subprocess stdout/stderr to server logs // Forward subprocess stdout/stderr to server logs
stdout, _ := cmd.StdoutPipe() stdout, _ := cmd.StdoutPipe()
@@ -143,7 +168,7 @@ func NewServer(modelName string) (*Server, error) {
} }
}() }()
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port) slog.Info("starting mlx runner subprocess", "model", s.modelName, "port", s.port)
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start mlx runner: %w", err) return nil, fmt.Errorf("failed to start mlx runner: %w", err)
} }
@@ -154,22 +179,6 @@ func NewServer(modelName string) (*Server, error) {
s.done <- err s.done <- err
}() }()
// Wait for subprocess to be ready
if err := s.waitUntilRunning(); err != nil {
s.Close()
return nil, err
}
return s, nil
}
// ModelPath returns the path to the model.
func (s *Server) ModelPath() string {
return s.modelName
}
// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil return nil, nil
} }
@@ -191,9 +200,15 @@ func (s *Server) Ping(ctx context.Context) error {
return nil return nil
} }
// waitUntilRunning waits for the subprocess to be ready. // getLastErr returns the last stderr line.
func (s *Server) waitUntilRunning() error { func (s *Server) getLastErr() string {
ctx := context.Background() s.lastErrLock.Lock()
defer s.lastErrLock.Unlock()
return s.lastErr
}
// WaitUntilRunning waits for the subprocess to be ready.
func (s *Server) WaitUntilRunning(ctx context.Context) error {
timeout := time.After(envconfig.LoadTimeout()) timeout := time.After(envconfig.LoadTimeout())
ticker := time.NewTicker(100 * time.Millisecond) ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
@@ -201,7 +216,6 @@ func (s *Server) waitUntilRunning() error {
for { for {
select { select {
case err := <-s.done: case err := <-s.done:
// Include recent stderr lines for better error context
errMsg := s.getLastErr() errMsg := s.getLastErr()
if errMsg != "" { if errMsg != "" {
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err) return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
@@ -222,18 +236,6 @@ func (s *Server) waitUntilRunning() error {
} }
} }
// getLastErr returns the last stderr line.
func (s *Server) getLastErr() string {
s.lastErrLock.Lock()
defer s.lastErrLock.Unlock()
return s.lastErr
}
// WaitUntilRunning satisfies the LlamaServer interface.
func (s *Server) WaitUntilRunning(ctx context.Context) error {
return nil
}
// Completion handles both text and image generation requests. // Completion handles both text and image generation requests.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
seed := req.Seed seed := req.Seed

View File

@@ -22,9 +22,12 @@ import (
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
) )
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models. // Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
@@ -41,118 +44,24 @@ type Client struct {
cmd *exec.Cmd cmd *exec.Cmd
} }
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready. // NewClient prepares a new MLX runner client for LLM models.
// The subprocess is not started until Load() is called.
func NewClient(modelName string) (*Client, error) { func NewClient(modelName string) (*Client, error) {
if err := imagegen.CheckPlatformSupport(); err != nil { if err := imagegen.CheckPlatformSupport(); err != nil {
return nil, err return nil, err
} }
// Find a free port
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
if l, err := net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
port = rand.Intn(65535-49152) + 49152
}
// Get the current executable path
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
// Spawn subprocess: ollama runner --mlx-engine --model <name> --port <port>
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// Set library path environment variable for MLX libraries
// Linux: LD_LIBRARY_PATH, Windows: PATH
var libPathEnvVar string
switch runtime.GOOS {
case "linux":
libPathEnvVar = "LD_LIBRARY_PATH"
case "windows":
libPathEnvVar = "PATH"
}
if libPathEnvVar != "" {
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
if existingPath, ok := os.LookupEnv(libPathEnvVar); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
found := false
for i := range cmd.Env {
envName := cmd.Env[i]
if runtime.GOOS == "windows" {
envName = strings.ToUpper(envName)
}
if strings.HasPrefix(envName, libPathEnvVar+"=") {
cmd.Env[i] = libPathEnvVar + "=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, libPathEnvVar+"="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", libPathEnvVar, pathEnvVal)
}
c := &Client{ c := &Client{
port: port,
modelName: modelName, modelName: modelName,
done: make(chan error, 1), done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute}, client: &http.Client{Timeout: 10 * time.Minute},
cmd: cmd,
} }
// Forward subprocess stdout/stderr to server logs modelManifest, err := manifest.LoadManifest(modelName)
stdout, _ := cmd.StdoutPipe() if err != nil {
stderr, _ := cmd.StderrPipe()
go func() {
io.Copy(os.Stderr, stdout) //nolint:errcheck
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
fmt.Fprintln(os.Stderr, line)
c.lastErrLock.Lock()
c.lastErr = line
c.lastErrLock.Unlock()
}
}()
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
go func() {
err := cmd.Wait()
c.done <- err
}()
// Wait for subprocess to be ready
if err := c.waitUntilRunning(); err != nil {
c.Close()
return nil, err return nil, err
} }
c.memory.Store(uint64(modelManifest.TotalTensorSize()))
return c, nil return c, nil
} }
@@ -163,14 +72,16 @@ func (c *Client) getLastErr() string {
return c.lastErr return c.lastErr
} }
func (c *Client) waitUntilRunning() error { // WaitUntilRunning waits for the subprocess to be ready.
ctx := context.Background() func (c *Client) WaitUntilRunning(ctx context.Context) error {
timeout := time.After(2 * time.Minute) timeout := time.After(2 * time.Minute)
ticker := time.NewTicker(100 * time.Millisecond) ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ctx.Done():
return ctx.Err()
case err := <-c.done: case err := <-c.done:
errMsg := c.getLastErr() errMsg := c.getLastErr()
if errMsg != "" { if errMsg != "" {
@@ -345,8 +256,123 @@ func (c *Client) HasExited() bool {
} }
} }
// Load implements llm.LlamaServer. // Load checks whether the model fits in GPU memory and starts the subprocess.
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) { func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
if len(gpus) > 0 {
modelSize := c.memory.Load()
// We currently only use the first GPU with MLX
available := gpus[0].FreeMemory
overhead := gpus[0].MinimumMemory() + envconfig.GpuOverhead()
if available > overhead {
available -= overhead
} else {
available = 0
}
if modelSize > available {
if requireFull {
return nil, llm.ErrLoadRequiredFull
}
return nil, fmt.Errorf("model requires %s but only %s are available (after %s overhead)", format.HumanBytes2(modelSize), format.HumanBytes2(available), format.HumanBytes2(overhead))
}
}
// Find a free port
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
if l, err := net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
port = rand.Intn(65535-49152) + 49152
}
c.port = port
// Get the current executable path
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
// Spawn subprocess: ollama runner --mlx-engine --model <name> --port <port>
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", c.modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// Set library path environment variable for MLX libraries
// Linux: LD_LIBRARY_PATH, Windows: PATH
var libPathEnvVar string
switch runtime.GOOS {
case "linux":
libPathEnvVar = "LD_LIBRARY_PATH"
case "windows":
libPathEnvVar = "PATH"
}
if libPathEnvVar != "" {
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
if existingPath, ok := os.LookupEnv(libPathEnvVar); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
found := false
for i := range cmd.Env {
envName := cmd.Env[i]
if runtime.GOOS == "windows" {
envName = strings.ToUpper(envName)
}
if strings.HasPrefix(envName, libPathEnvVar+"=") {
cmd.Env[i] = libPathEnvVar + "=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, libPathEnvVar+"="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", libPathEnvVar, pathEnvVal)
}
c.cmd = cmd
// Forward subprocess stdout/stderr to server logs
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
go func() {
io.Copy(os.Stderr, stdout) //nolint:errcheck
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
fmt.Fprintln(os.Stderr, line)
c.lastErrLock.Lock()
c.lastErr = line
c.lastErrLock.Unlock()
}
}()
slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
go func() {
err := cmd.Wait()
c.done <- err
}()
return nil, nil return nil, nil
} }
@@ -425,9 +451,7 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
func (c *Client) currentMemory() uint64 { func (c *Client) currentMemory() uint64 {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
if err := c.Ping(ctx); err != nil { c.Ping(ctx) //nolint:errcheck
slog.Warn("failed to get current memory", "error", err)
}
return c.memory.Load() return c.memory.Load()
} }
@@ -442,9 +466,4 @@ func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
return c.currentMemory() return c.currentMemory()
} }
// WaitUntilRunning implements llm.LlamaServer.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
return nil
}
var _ llm.LlamaServer = (*Client)(nil) var _ llm.LlamaServer = (*Client)(nil)