diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go
index 3d7dd5b00..fe721fb39 100644
--- a/x/mlxrunner/cache.go
+++ b/x/mlxrunner/cache.go
@@ -90,13 +90,6 @@ func (c *kvCache) begin(seqID int, m base.Model, inputs []int32) *cacheSession {
 	c.ensureCaches(m)
 	c.ensureRoot()

-	// Ensure the sequence is registered in all cache layers.
-	for _, kv := range c.caches {
-		if kv != nil {
-			kv.SetSeqs([]int{seqID})
-		}
-	}
-
 	matchPath, matched := findBestMatch(c.root, inputs)
 	originalMatched := matched
diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go
index fe85bad2c..3f2414a98 100644
--- a/x/mlxrunner/pipeline.go
+++ b/x/mlxrunner/pipeline.go
@@ -2,15 +2,9 @@ package mlxrunner

 import (
 	"bytes"
-	"context"
 	"errors"
 	"fmt"
 	"log/slog"
-	"time"
-
-	"github.com/ollama/ollama/logutil"
-	"github.com/ollama/ollama/x/mlxrunner/batch"
-	"github.com/ollama/ollama/x/mlxrunner/mlx"
 )

 func prefillChunkSize() int {
@@ -46,186 +40,6 @@ func (r *Runner) Prepare(request *Request) error {
 	return nil
 }

-func (r *Runner) TextGenerationPipeline(request Request) error {
-	enableCompile := true
-	if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
-		enableCompile = modelCompile.EnableCompile()
-	}
-	if enableCompile {
-		mlx.EnableCompile()
-	} else {
-		mlx.DisableCompile()
-	}
-	mlx.ResetPeakMemory()
-	ctx := request.Ctx
-	var (
-		sample, logprobs         *mlx.Array
-		nextSample, nextLogprobs *mlx.Array
-	)
-
-	defer func() {
-		if request.Sampler != nil {
-			request.Sampler.Free()
-		}
-		mlx.Unpin(sample, logprobs)
-		mlx.Unpin(nextSample, nextLogprobs)
-		mlx.Sweep()
-		mlx.ClearCache()
-
-		if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
-			mlx.LogArrays()
-			r.cache.dumpTree()
-		}
-		slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
-	}()
-
-	inputs := request.Tokens
-	request.Sampler.ResetHistory(inputs)
-
-	session := r.cache.begin(0, r.Model, inputs)
-	defer session.close()
-
-	caches := session.caches
-	tokens := session.remaining
-	prefillChunk := prefillChunkSize()
-
-	// Request periodic snapshots during prefill and near the end of the
-	// prompt so that long prompts can be partially restored and
-	// thinking/generation can be retried without full reprocessing.
-	const snapshotInterval = 8192
-	for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
-		session.requestSnapshot(offset)
-	}
-
-	const preThinking = 4
-	if end := len(inputs) - preThinking; end > 0 {
-		session.requestSnapshot(end)
-	}
-
-	materializeCaches := func() {
-		state := make([]*mlx.Array, 0, 2*len(caches))
-		for _, c := range caches {
-			state = append(state, c.State()...)
-		}
-		if len(state) == 0 {
-			return
-		}
-		mlx.Eval(state...)
-	}
-
-	now := time.Now()
-	total, processed := len(tokens), 0
-	for total-processed > 1 {
-		if err := ctx.Err(); err != nil {
-			return err
-		}
-
-		n := min(prefillChunk, total-processed-1)
-
-		// If there's a pending snapshot, split the batch so we can
-		// capture it at the exact offset.
-		if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
-			baseOffset := len(session.inputs) - len(tokens)
-			tokensUntilSnapshot := snapOffset - (baseOffset + processed)
-			if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
-				n = tokensUntilSnapshot
-			}
-		}
-
-		r.Model.Forward(&batch.ForwardBatch{
-			InputIDs: mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0),
-			SeqIDs:   []int{0},
-			SeqLens:  []int{n},
-		}, caches)
-		mlx.Sweep()
-		materializeCaches()
-		processed += n
-		slog.Info("Prompt processing progress", "processed", processed, "total", total)
-
-		// Create snapshot if we've reached a pending offset.
-		if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
-			baseOffset := len(session.inputs) - len(tokens)
-			if baseOffset+processed >= snapOffset {
-				session.snapshot()
-			}
-		}
-
-		mlx.ClearCache()
-	}
-
-	step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
-		fwd := r.Model.Forward(&batch.ForwardBatch{
-			InputIDs: token.ExpandDims(0),
-			SeqIDs:   []int{0},
-			SeqLens:  []int{1},
-		}, caches)
-		logits := r.Model.Unembed(fwd)
-		logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
-
-		logprobs := logits.Subtract(logits.Logsumexp(true))
-		sample := request.Sampler.Sample(logprobs)
-
-		mlx.Pin(sample, logprobs)
-		mlx.Sweep()
-		mlx.AsyncEval(sample, logprobs)
-
-		return sample, logprobs
-	}
-
-	sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
-
-	var b bytes.Buffer
-
-	final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
-	for i := range request.Options.MaxTokens {
-		if err := ctx.Err(); err != nil {
-			return err
-		}
-
-		request.Sampler.AppendToken(sample)
-		nextSample, nextLogprobs = step(sample)
-
-		if i == 0 {
-			mlx.Eval(sample)
-			final.PromptEvalDuration = time.Since(now)
-			now = time.Now()
-		}
-
-		output := int32(sample.Int())
-		session.outputs = append(session.outputs, output)
-
-		if r.Tokenizer.IsEOS(output) {
-			final.DoneReason = 0
-			final.EvalCount = i
-			break
-		}
-
-		select {
-		case <-ctx.Done():
-			return ctx.Err()
-		case request.Responses <- CompletionResponse{
-			Content: r.Decode(output, &b),
-		}:
-		}
-
-		mlx.Unpin(sample, logprobs)
-		sample, logprobs = nextSample, nextLogprobs
-		nextSample, nextLogprobs = nil, nil
-
-		if i%256 == 0 {
-			mlx.ClearCache()
-		}
-	}
-
-	final.EvalDuration = time.Since(now)
-	select {
-	case <-ctx.Done():
-		return ctx.Err()
-	case request.Responses <- final:
-		return nil
-	}
-}
-
 func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
 	token := r.Tokenizer.Decode([]int32{sample})
diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go
index 8f2f3f6d7..cea8e328f 100644
--- a/x/mlxrunner/runner.go
+++ b/x/mlxrunner/runner.go
@@ -2,7 +2,6 @@ package mlxrunner

 import (
 	"context"
-	"errors"
 	"log/slog"
 	"net"
 	"net/http"
@@ -10,7 +9,6 @@ import (

 	"golang.org/x/sync/errgroup"

-	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	"github.com/ollama/ollama/x/mlxrunner/model"
 	"github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -21,7 +19,6 @@ import (
 type Request struct {
 	TextCompletionsRequest
 	Responses chan CompletionResponse
-	Pipeline  func(Request) error

 	Ctx    context.Context
 	Tokens []int32
@@ -139,30 +136,9 @@ func loadTensorsFromManifest(root *model.Root) (map[string]*mlx.Array, error) {

 func (r *Runner) Run(host, port string, mux http.Handler) error {
 	g, ctx := errgroup.WithContext(context.Background())
+	sched := r.newScheduler()
 	g.Go(func() error {
-		for {
-			select {
-			case <-ctx.Done():
-				return nil
-			case request := <-r.Requests:
-				if err := request.Pipeline(request); err != nil {
-					slog.Info("Request terminated", "error", err)
-					var statusErr api.StatusError
-					if !errors.As(err, &statusErr) {
-						statusErr = api.StatusError{
-							StatusCode:   http.StatusInternalServerError,
-							ErrorMessage: err.Error(),
-						}
-					}
-					select {
-					case request.Responses <- CompletionResponse{Error: &statusErr}:
-					case <-request.Ctx.Done():
-					}
-				}
-
-				close(request.Responses)
-			}
-		}
+		return sched.run(ctx)
 	})

 	g.Go(func() error {
diff --git a/x/mlxrunner/scheduler.go b/x/mlxrunner/scheduler.go
new file mode 100644
index 000000000..b4e260af3
--- /dev/null
+++ b/x/mlxrunner/scheduler.go
@@ -0,0 +1,437 @@
+package mlxrunner
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"log/slog"
+	"net/http"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/logutil"
+	"github.com/ollama/ollama/x/mlxrunner/batch"
+	"github.com/ollama/ollama/x/mlxrunner/mlx"
+)
+
+// activeSeq tracks a single sequence in the decode batch.
+type activeSeq struct {
+	seqID   int
+	session *cacheSession
+	request Request
+
+	// Decode state — pinned arrays from the previous step.
+	sample, logprobs *mlx.Array
+
+	buf       bytes.Buffer
+	generated int
+	final     CompletionResponse
+	decodeAt  time.Time // set after prefill completes
+}
+
+func (s *activeSeq) cleanup() {
+	if s.request.Sampler != nil {
+		s.request.Sampler.Free()
+	}
+	mlx.Unpin(s.sample, s.logprobs)
+}
+
+const maxParallel = 4
+
+// scheduler manages prefill and decode for all active sequences.
+type scheduler struct {
+	runner *Runner
+	active []*activeSeq
+	used   [maxParallel]bool // seqID slot allocation
+}
+
+func (r *Runner) newScheduler() *scheduler {
+	return &scheduler{runner: r}
+}
+
+// allocSeqID returns the lowest free seqID slot.
+func (s *scheduler) allocSeqID() int {
+	for i, used := range s.used {
+		if !used {
+			s.used[i] = true
+			return i
+		}
+	}
+	panic("no free sequence slots")
+}
+
+// freeSeqID returns a seqID slot to the pool.
+func (s *scheduler) freeSeqID(seqID int) {
+	s.used[seqID] = false
+}
+
+func (s *scheduler) run(ctx context.Context) error {
+	r := s.runner
+
+	enableCompile := true
+	if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
+		enableCompile = modelCompile.EnableCompile()
+	}
+	if enableCompile {
+		mlx.EnableCompile()
+	} else {
+		mlx.DisableCompile()
+	}
+
+	for {
+		if len(s.active) == 0 {
+			// No active sequences — block waiting for a request.
+			select {
+			case <-ctx.Done():
+				return nil
+			case request := <-r.Requests:
+				s.admitRequest(ctx, request)
+			}
+		} else {
+			// Active sequences decoding — check for new requests non-blocking.
+			select {
+			case <-ctx.Done():
+				s.finishAll()
+				return nil
+			case request := <-r.Requests:
+				s.admitRequest(ctx, request)
+			default:
+			}
+
+			// Run one decode step for all active sequences.
+			s.decodeStep(ctx)
+		}
+	}
+}
+
+// admitRequest prefills a new request and adds it to the decode batch.
+func (s *scheduler) admitRequest(ctx context.Context, request Request) {
+	mlx.ResetPeakMemory()
+
+	seqID := s.allocSeqID()
+
+	seq := &activeSeq{
+		seqID:   seqID,
+		request: request,
+		final: CompletionResponse{
+			Done:            true,
+			PromptEvalCount: len(request.Tokens),
+			EvalCount:       request.Options.MaxTokens,
+			DoneReason:      1,
+		},
+	}
+
+	// Ensure caches exist with all pool slots registered. SetSeqs is
+	// a no-op after the first call since the slot set never changes.
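+	// This replaces the per-sequence SetSeqs call that begin() used to make
+	// in cache.go: registering the whole fixed pool up front means admitting
+	// a new sequence never changes the slot set mid-flight.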
+	s.runner.cache.ensureCaches(s.runner.Model)
+	allSlots := make([]int, maxParallel)
+	for i := range allSlots {
+		allSlots[i] = i
+	}
+	for _, kv := range s.runner.cache.caches {
+		if kv != nil {
+			kv.SetSeqs(allSlots)
+		}
+	}
+
+	if err := s.prefill(ctx, seq); err != nil {
+		slog.Info("Prefill failed", "seq", seqID, "error", err)
+		seq.cleanup()
+		s.freeSeqID(seqID)
+		s.sendError(request, err)
+		return
+	}
+
+	// Materialize all cache state so existing sequences' decode steps
+	// see clean buffer data (not lazy graphs from prefill/restore).
+	s.materializeCaches()
+
+	s.active = append(s.active, seq)
+}
+
+func (s *scheduler) prefill(ctx context.Context, seq *activeSeq) error {
+	r := s.runner
+	inputs := seq.request.Tokens
+	seq.request.Sampler.ResetHistory(inputs)
+
+	session := r.cache.begin(seq.seqID, r.Model, inputs)
+	seq.session = session
+
+	caches := session.caches
+	tokens := session.remaining
+
+	// Schedule periodic snapshots during prefill.
+	const snapshotInterval = 8192
+	for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
+		session.requestSnapshot(offset)
+	}
+	const preThinking = 4
+	if end := len(inputs) - preThinking; end > 0 {
+		session.requestSnapshot(end)
+	}
+
+	prefillChunk := prefillChunkSize()
+	total, processed := len(tokens), 0
+	for total-processed > 1 {
+		if err := ctx.Err(); err != nil {
+			return err
+		}
+		if err := seq.request.Ctx.Err(); err != nil {
+			return err
+		}
+
+		n := min(prefillChunk, total-processed-1)
+
+		if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
+			baseOffset := len(session.inputs) - len(tokens)
+			tokensUntilSnapshot := snapOffset - (baseOffset + processed)
+			if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
+				n = tokensUntilSnapshot
+			}
+		}
+
+		r.Model.Forward(&batch.ForwardBatch{
+			InputIDs: mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0),
+			SeqIDs:   []int{seq.seqID},
+			SeqLens:  []int{n},
+		}, caches)
+		mlx.Sweep()
+		s.materializeCaches()
+		processed += n
+		slog.Info("Prompt processing progress", "seq", seq.seqID, "processed", processed, "total", total)
+
+		if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
+			baseOffset := len(session.inputs) - len(tokens)
+			if baseOffset+processed >= snapOffset {
+				session.snapshot()
+			}
+		}
+
+		mlx.ClearCache()
+	}
+
+	// First decode step: process final token(s) and get initial sample.
+	// Eval the sample AND the cache state so everything is materialized
+	// before any cache transitions (snapshot/restore/rebuild).
+	seq.sample, seq.logprobs = s.singleStep(seq, mlx.FromValues(tokens[processed:], total-processed))
+	evalArrays := []*mlx.Array{seq.sample, seq.logprobs}
+	for _, c := range caches {
+		evalArrays = append(evalArrays, c.State()...)
+	}
+	mlx.Eval(evalArrays...)
+	seq.decodeAt = time.Now()
+
+	return nil
+}
+
+// singleStep runs a single-sequence forward+sample (used during prefill's
+// final token and as fallback).
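+// It pins and async-evaluates the sampled token and logprobs, mirroring the
+// per-token step of the removed TextGenerationPipeline.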
+func (s *scheduler) singleStep(seq *activeSeq, token *mlx.Array) (*mlx.Array, *mlx.Array) {
+	r := s.runner
+	caches := seq.session.caches
+
+	fwd := r.Model.Forward(&batch.ForwardBatch{
+		InputIDs: token.ExpandDims(0),
+		SeqIDs:   []int{seq.seqID},
+		SeqLens:  []int{1},
+	}, caches)
+	logits := r.Model.Unembed(fwd)
+	logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
+
+	logprobs := logits.Subtract(logits.Logsumexp(true))
+	sample := seq.request.Sampler.Sample(logprobs)
+
+	mlx.Pin(sample, logprobs)
+	mlx.Sweep()
+	mlx.AsyncEval(sample, logprobs)
+
+	return sample, logprobs
+}
+
+// decodeStep runs one batched decode iteration for all active sequences.
+func (s *scheduler) decodeStep(ctx context.Context) {
+	r := s.runner
+
+	// Check for cancelled sequences and remove them.
+	s.reapCancelled(ctx)
+	if len(s.active) == 0 {
+		return
+	}
+
+	// Read token values from previous step's samples. This forces
+	// evaluation of the lazy computation from the prior step.
+	inputTokens := make([]int32, len(s.active))
+	for i, seq := range s.active {
+		if seq.generated == 0 {
+			mlx.Eval(seq.sample)
+			seq.final.PromptEvalDuration = time.Since(seq.decodeAt)
+			seq.decodeAt = time.Now()
+		}
+		inputTokens[i] = int32(seq.sample.Int())
+	}
+
+	// Process previous step's outputs: stream tokens, check EOS.
+	var completed []*activeSeq
+	for i, seq := range s.active {
+		output := inputTokens[i]
+		seq.session.outputs = append(seq.session.outputs, output)
+		seq.generated++
+
+		if r.Tokenizer.IsEOS(output) {
+			seq.final.DoneReason = 0
+			seq.final.EvalCount = seq.generated - 1
+			completed = append(completed, seq)
+			continue
+		}
+
+		if seq.generated >= seq.request.Options.MaxTokens {
+			seq.final.EvalCount = seq.generated
+			completed = append(completed, seq)
+			continue
+		}
+
+		// Stream token to client.
+		select {
+		case <-seq.request.Ctx.Done():
+			completed = append(completed, seq)
+		case seq.request.Responses <- CompletionResponse{
+			Content: r.Decode(output, &seq.buf),
+		}:
+		}
+	}
+
+	// Finish completed sequences and remove from active list.
+	if len(completed) > 0 {
+		completedSet := make(map[int]bool, len(completed))
+		for _, seq := range completed {
+			s.finishSeq(seq)
+			completedSet[seq.seqID] = true
+		}
+		alive := s.active[:0]
+		for _, seq := range s.active {
+			if !completedSet[seq.seqID] {
+				alive = append(alive, seq)
+			}
+		}
+		s.active = alive
+		mlx.ClearCache()
+	}
+
+	if len(s.active) == 0 {
+		return
+	}
+
+	// Batched forward pass: one token per sequence.
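+	// Each sequence's pinned sample from the previous step becomes its input
+	// token here; once the value has been read, unpinning it lets the next
+	// Sweep release the previous step's arrays.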
+	seqIDs := make([]int, len(s.active))
+	seqLens := make([]int, len(s.active))
+	nextTokens := make([]int32, len(s.active))
+	for i, seq := range s.active {
+		seq.request.Sampler.AppendToken(seq.sample)
+		nextTokens[i] = int32(seq.sample.Int())
+		seqIDs[i] = seq.seqID
+		seqLens[i] = 1
+		mlx.Unpin(seq.sample, seq.logprobs)
+		seq.sample, seq.logprobs = nil, nil
+	}
+
+	fwd := r.Model.Forward(&batch.ForwardBatch{
+		InputIDs: mlx.FromValues(nextTokens, len(nextTokens)).ExpandDims(0),
+		SeqIDs:   seqIDs,
+		SeqLens:  seqLens,
+	}, r.cache.caches)
+	logits := r.Model.Unembed(fwd)
+
+	for i, seq := range s.active {
+		seqLogits := logits.Slice(mlx.Slice(), mlx.Slice(i, i+1), mlx.Slice()).Squeeze(1)
+		lp := seqLogits.Subtract(seqLogits.Logsumexp(true))
+		sample := seq.request.Sampler.Sample(lp)
+		mlx.Pin(sample, lp)
+		seq.sample = sample
+		seq.logprobs = lp
+	}
+
+	mlx.Sweep()
+
+	evalArrays := make([]*mlx.Array, 0, 2*len(s.active))
+	for _, seq := range s.active {
+		evalArrays = append(evalArrays, seq.sample, seq.logprobs)
+	}
+	mlx.AsyncEval(evalArrays...)
+}
+
+// reapCancelled removes sequences whose request context has been cancelled.
+func (s *scheduler) reapCancelled(ctx context.Context) {
+	var alive []*activeSeq
+	for _, seq := range s.active {
+		if ctx.Err() != nil || seq.request.Ctx.Err() != nil {
+			s.finishSeq(seq)
+		} else {
+			alive = append(alive, seq)
+		}
+	}
+	if len(alive) != len(s.active) {
+		s.active = alive
+	}
+}
+
+// finishSeq sends the final response, saves to trie, and cleans up.
+// It does NOT remove from s.active — the caller is responsible for that.
+func (s *scheduler) finishSeq(seq *activeSeq) {
+	seq.final.EvalDuration = time.Since(seq.decodeAt)
+
+	// Send final response.
+	if seq.request.Ctx.Err() == nil {
+		select {
+		case seq.request.Responses <- seq.final:
+		case <-seq.request.Ctx.Done():
+		}
+	}
+
+	// Save to trie and clean up.
+	if seq.session != nil && seq.generated > 0 {
+		seq.session.close()
+	}
+	s.freeSeqID(seq.seqID)
+	seq.cleanup()
+	close(seq.request.Responses)
+
+	if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
+		s.runner.cache.dumpTree()
+	}
+	slog.Info("sequence complete", "seq", seq.seqID, "generated", seq.generated,
+		"peak_memory", mlx.PrettyBytes(mlx.PeakMemory()))
+}
+
+func (s *scheduler) sendError(request Request, err error) {
+	slog.Info("Request terminated", "error", err)
+	var statusErr api.StatusError
+	if !errors.As(err, &statusErr) {
+		statusErr = api.StatusError{
+			StatusCode:   http.StatusInternalServerError,
+			ErrorMessage: err.Error(),
+		}
+	}
+	select {
+	case request.Responses <- CompletionResponse{Error: &statusErr}:
+	case <-request.Ctx.Done():
+	}
+	close(request.Responses)
+}
+
+func (s *scheduler) finishAll() {
+	for _, seq := range s.active {
+		s.finishSeq(seq)
+	}
+	s.active = nil
+}
+
+func (s *scheduler) materializeCaches() {
+	state := make([]*mlx.Array, 0, 2*len(s.runner.cache.caches))
+	for _, c := range s.runner.cache.caches {
+		state = append(state, c.State()...)
+	}
+	if len(state) == 0 {
+		return
+	}
+	mlx.Eval(state...)
+}
diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go
index fb4c18773..f4de2cfc0 100644
--- a/x/mlxrunner/server.go
+++ b/x/mlxrunner/server.go
@@ -95,7 +95,6 @@ func Execute(args []string) error {
 	request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
-	request.Pipeline = runner.TextGenerationPipeline
 	request.Sampler = sample.New(
 		request.Options.Temperature,
 		request.Options.TopP,