sched: Model eviction for MLX

MLX runners (image generation and LLM) previously bypassed the
scheduler's standard load path via a separate loadMLX method. This meant
they skipped VRAM fitting checks and couldn't participate in model
eviction.

Now all model types flow through the same load function. Eviction sizing
for MLX is based on model weights only, since the KV cache and compute
graph are allocated dynamically. This means eviction does not account for
worst-case memory use and models can still compete for memory, but it is
a significant improvement.
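
For orientation, here is a condensed sketch of the backend selection that
now happens inside the scheduler's load path. The helper wrapper and its
name are hypothetical (the real logic lives inline in load(), shown in the
diff below); locking, eviction bookkeeping, and most error handling are
omitted.

    // Sketch only: dispatch between the GGUF (llama.cpp) path and the MLX
    // runners, using the same load entry point for every model type.
    func (s *Scheduler) pickBackend(req *LlmRequest, systemInfo ml.SystemInfo,
        gpus []ml.DeviceInfo, numParallel int) (llm.LlamaServer, error) {
        if !req.model.IsMLX() {
            // GGUF path: model metadata is now loaded here rather than in
            // processPending, so fitting and eviction share one code path.
            f, err := llm.LoadModel(req.model.ModelPath, 1024)
            if err != nil {
                return nil, err
            }
            return s.newServerFn(systemInfo, gpus, req.model.ModelPath, f,
                req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
        }
        // MLX path: image models use x/imagegen, LLM models use x/mlxrunner.
        if slices.Contains(req.model.Config.Capabilities, "image") {
            return imagegen.NewServer(req.model.ShortName)
        }
        return mlxrunner.NewClient(req.model.ShortName)
    }

Eviction sizing for an MLX runner then comes from the server's MemorySize(),
which reflects the weights rather than a worst-case estimate that would
include the KV cache and compute graph.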

Author: Jesse Gross
Date:   2026-03-02 15:27:34 -08:00
Parent: bcf6d55b54
Commit: bbbad97686

8 changed files with 291 additions and 290 deletions


@@ -40,7 +40,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
@@ -234,7 +234,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{


@@ -45,7 +45,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
@@ -230,7 +230,7 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,


@@ -187,7 +187,7 @@ func TestGenerateChat(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
@@ -904,7 +904,7 @@ func TestGenerate(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
@@ -1388,7 +1388,7 @@ func TestGenerateLogprobs(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{llama: mock}
return false
},
@@ -1568,7 +1568,7 @@ func TestChatLogprobs(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{llama: mock}
return false
},
@@ -1678,7 +1678,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{llama: mock}
return false
@@ -2123,7 +2123,7 @@ func TestGenerateUnload(t *testing.T) {
newServerFn: newMockServer(&mockRunner{}),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFnCalled = true
req.successCh <- &runnerRef{llama: &mockRunner{}}
return false
@@ -2225,7 +2225,7 @@ func TestGenerateWithImages(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,


@@ -265,7 +265,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}
@@ -416,7 +416,7 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}
@@ -598,7 +598,7 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}


@@ -51,7 +51,7 @@ type Scheduler struct {
activeLoading llm.LlamaServer
loaded map[string]*runnerRef
loadFn func(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool
loadFn func(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool
newServerFn func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
getGpuFn func(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo
getSystemInfoFn func() ml.SystemInfo
@@ -220,33 +220,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadMLX(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.IsMLX() {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
break
}
continue
}
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
if err != nil {
pending.errCh <- err
break
}
// Update free memory from currently loaded models
logutil.Trace("updating free space", "gpu_count", len(gpus), "model", pending.model.ModelPath)
s.updateFreeSpace(gpus)
@@ -254,14 +227,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
if loadedCount == 0 {
// No models loaded. Load the model but prefer the best fit.
slog.Debug("loading first model", "model", pending.model.ModelPath)
s.loadFn(pending, ggml, systemInfo, gpus, false)
s.loadFn(pending, systemInfo, gpus, false)
break
}
// More than one loaded model, so we have to see if the
// new one fits
logutil.Trace("loading additional model", "model", pending.model.ModelPath)
needEvict := s.loadFn(pending, ggml, systemInfo, gpus, true)
needEvict := s.loadFn(pending, systemInfo, gpus, true)
if !needEvict {
slog.Debug("new model fits with existing models, loading")
break
@@ -435,7 +408,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
// (if any). Returns whether the scheduler needs to evict a model to make this one fit.
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool {
func (s *Scheduler) load(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool {
numParallel := max(int(envconfig.NumParallel()), 1)
// Embedding models should always be loaded with parallel=1
@@ -460,15 +433,33 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
if llama == nil {
var err error
llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
if !req.model.IsMLX() {
f, loadErr := llm.LoadModel(req.model.ModelPath, 1024)
if loadErr != nil {
slog.Info("failed to load model metadata", "model", req.model.ModelPath, "error", loadErr)
req.errCh <- loadErr
s.loadedMu.Unlock()
return false
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
}
}
} else {
modelName := req.model.ShortName
if slices.Contains(req.model.Config.Capabilities, "image") {
llama, err = imagegen.NewServer(modelName)
} else {
llama, err = mlxrunner.NewClient(modelName)
}
}
if err != nil {
slog.Info("failed to create server", "model", req.model.ShortName, "error", err)
req.errCh <- err
s.loadedMu.Unlock()
return false
@@ -476,8 +467,12 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
s.activeLoading = llama
} else {
if s.activeLoading.ModelPath() != req.model.ModelPath {
panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), req.model.ModelPath))
wantPath := req.model.ModelPath
if wantPath == "" {
wantPath = req.model.ShortName
}
if s.activeLoading.ModelPath() != wantPath {
panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), wantPath))
}
}
@@ -544,6 +539,7 @@ iGPUScan:
sessionDuration: sessionDuration,
gpus: gpuIDs,
discreteGPUs: discreteGPUs,
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
totalSize: totalSize,
vramSize: vramSize,
loading: true,
@@ -591,59 +587,6 @@ iGPUScan:
return false
}
// loadMLX loads an experimental safetensors model using MLX runners.
// Image models use x/imagegen; LLM models use x/mlxrunner.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
modelName := req.model.ShortName
var server llm.LlamaServer
var err error
if slices.Contains(req.model.Config.Capabilities, "image") {
server, err = imagegen.NewServer(modelName)
} else {
server, err = mlxrunner.NewClient(modelName)
}
if err != nil {
req.errCh <- err
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
totalSize, vramSize := server.MemorySize()
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: server,
Options: &req.opts,
loading: false,
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
sessionDuration: sessionDuration,
totalSize: totalSize,
vramSize: vramSize,
}
s.loadedMu.Lock()
s.loaded[runner.modelKey] = runner
s.loadedMu.Unlock()
// Set up expiration timer
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
return true
}
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
if len(allGpus) == 0 {
return


@@ -39,10 +39,25 @@ func TestSchedLoad(t *testing.T) {
defer done()
s := InitScheduler(ctx)
s.waitForRecovery = 10 * time.Millisecond
var f *ggml.GGML // value not used in tests
modelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
})
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
model: &Model{ModelPath: modelPath},
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
@@ -54,7 +69,7 @@ func TestSchedLoad(t *testing.T) {
}
gpus := []ml.DeviceInfo{}
systemInfo := ml.SystemInfo{}
s.load(req, f, systemInfo, gpus, false)
s.load(req, systemInfo, gpus, false)
require.Empty(t, req.successCh)
require.Len(t, req.errCh, 1)
s.loadedMu.Lock()
@@ -68,7 +83,7 @@ func TestSchedLoad(t *testing.T) {
server.modelPath = model
return server, nil
}
s.load(req, f, systemInfo, gpus, false)
s.load(req, systemInfo, gpus, false)
select {
case err := <-req.errCh:
require.NoError(t, err)
@@ -80,9 +95,24 @@ func TestSchedLoad(t *testing.T) {
s.loadedMu.Unlock()
}
req.model.ModelPath = "dummy_model_path"
modelPath2, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
})
req.model.ModelPath = modelPath2
server.waitResp = errors.New("wait failure")
s.load(req, f, systemInfo, gpus, false)
s.load(req, systemInfo, gpus, false)
select {
case err := <-req.errCh:
require.Contains(t, err.Error(), "wait failure")
@@ -90,7 +120,7 @@ func TestSchedLoad(t *testing.T) {
t.Fatalf("unexpected success %v", resp)
}
s.loadedMu.Lock()
runner := s.loaded["dummy_model_path"]
runner := s.loaded[modelPath2]
s.loadedMu.Unlock()
require.NotNil(t, runner)
require.Equal(t, uint(0), runner.refCount)
@@ -103,7 +133,6 @@ type reqBundle struct {
ctxDone func()
srv *mockLlm
req *LlmRequest
f *ggml.GGML
}
func (scenario *reqBundle) newServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
@@ -132,11 +161,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vra
})
model := &Model{Name: modelName, ModelPath: p}
f, err := llm.LoadModel(model.ModelPath, 0)
if err != nil {
t.Fatal(err)
}
b.f = f
if duration == nil {
duration = &api.Duration{Duration: 5 * time.Millisecond}
}
@@ -178,7 +202,6 @@ func TestSchedRequestsSameModelSameRequest(t *testing.T) {
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil)
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}, nil)
b.req.model = a.req.model
b.f = a.f
s.newServerFn = a.newServer
slog.Info("a")
@@ -223,7 +246,6 @@ func TestSchedRequestsSimpleReloadSameModel(t *testing.T) {
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}, nil)
tmpModel := *a.req.model
b.req.model = &tmpModel
b.f = a.f
s.newServerFn = a.newServer
slog.Info("a")
@@ -518,16 +540,31 @@ func TestSchedExpireRunner(t *testing.T) {
defer done()
s := InitScheduler(ctx)
s.waitForRecovery = 10 * time.Millisecond
modelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
})
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
model: &Model{ModelPath: modelPath},
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: &api.Duration{Duration: 2 * time.Minute},
}
var f *ggml.GGML
gpus := []ml.DeviceInfo{}
systemInfo := ml.SystemInfo{}
server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}}
@@ -535,7 +572,7 @@ func TestSchedExpireRunner(t *testing.T) {
server.modelPath = model
return server, nil
}
s.load(req, f, systemInfo, gpus, false)
s.load(req, systemInfo, gpus, false)
select {
case err := <-req.errCh:
@@ -550,7 +587,7 @@ func TestSchedExpireRunner(t *testing.T) {
s.loadedMu.Unlock()
}
s.expireRunner(&Model{ModelPath: "foo"})
s.expireRunner(&Model{ModelPath: modelPath})
s.finishedReqCh <- req
s.processCompleted(ctx)