mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 00:03:27 +02:00
The KV cache previously used a tree structure which could store multiple divergent sequences, which is good for cache reuse. However, this is typically used in conjunction with paged attention so each node in the tree can store just a chunk of the KV cache and they can be stitched together later. We don't currently do this, so the cache was storing copies of the full cache for each past sequence. This redundancy, plus the lack of resource limits, caused significant memory use as a conversation grew. Instead, this changes to store a single entry for the cache, which can be prefix matched. Although it is less ideal for multiple users, it largely matches Ollama's current behavior. It can be improved as additional pieces are fleshed out.
175 lines
4.5 KiB
Go
175 lines
4.5 KiB
Go
//go:build mlx
|
|
|
|
package mlxrunner
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
|
"github.com/ollama/ollama/x/mlxrunner/sample"
|
|
"github.com/ollama/ollama/x/tokenizer"
|
|
)
|
|
|
|
// Request bundles a text completion request with the channel and
// callback needed to process it on the runner's worker goroutine.
type Request struct {
	TextCompletionsRequest

	// Responses streams generated output back to the caller; it is
	// closed by the worker loop when processing finishes (see Run).
	Responses chan Response

	// Pipeline performs the actual generation for this request.
	Pipeline func(Request) error

	// Embedded sampler; presumably used to select output tokens —
	// TODO(review): confirm against the pipeline implementation.
	sample.Sampler

	// caches holds the KV caches used while generating this request.
	// NOTE(review): likely one per transformer layer — confirm.
	caches []cache.Cache
}
|
|
|
|
// TextCompletionsRequest is the JSON body accepted by the text
// completions endpoint.
type TextCompletionsRequest struct {
	// Prompt is the raw prompt text to complete.
	Prompt string `json:"prompt"`

	// Options carries sampling and generation parameters.
	Options struct {
		Temperature float32 `json:"temperature"`
		TopP        float32 `json:"top_p"`
		MinP        float32 `json:"min_p"`
		TopK        int     `json:"top_k"`

		// MaxTokens limits how many tokens may be generated.
		MaxTokens int `json:"max_tokens"`

		// Deprecated: use MaxTokens instead
		NumPredict int `json:"num_predict"`
	} `json:"options"`
}
|
|
|
|
// Response is one streamed chunk of a completion. The JSON keys mirror
// Ollama's generate API response fields.
type Response struct {
	Text     string    `json:"content,omitempty"`
	Token    int       `json:"token,omitempty"`
	Logprobs []float32 `json:"logprobs,omitempty"`

	// Done marks the final chunk of a response.
	Done bool `json:"done,omitempty"`
	// DoneReason encodes why generation stopped.
	// NOTE(review): the meaning of each int value is defined elsewhere
	// in the package — confirm before relying on specific codes.
	DoneReason int `json:"done_reason,omitempty"`

	// Token-count and timing statistics.
	// NOTE(review): presumably populated only on the final chunk —
	// confirm against the pipeline implementation.
	PromptTokens             int           `json:"prompt_eval_count,omitempty"`
	PromptTokensDuration     time.Duration `json:"prompt_eval_duration,omitempty"`
	CompletionTokens         int           `json:"eval_count,omitempty"`
	CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
	TotalTokens              int           `json:"total_tokens,omitempty"`
}
|
|
|
|
// Runner owns a loaded model and serializes generation requests onto a
// single worker goroutine (see Run).
type Runner struct {
	Model     base.Model
	Tokenizer *tokenizer.Tokenizer

	// Requests queues incoming completion requests for the worker loop.
	Requests chan Request

	// cache is the single, prefix-matchable KV cache entry shared
	// across requests. NOTE(review): CacheEntry is declared elsewhere
	// in this package and is not visible here.
	cache *CacheEntry
}
|
|
|
|
func (r *Runner) Load(modelName string) error {
|
|
root, err := model.Open(modelName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer root.Close()
|
|
|
|
m, err := base.New(root)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Load all tensor blobs from manifest
|
|
tensors, err := loadTensorsFromManifest(root)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Assign weights to model (model-specific logic)
|
|
loadWeights := base.Weights(m)
|
|
if err := loadWeights(tensors); err != nil {
|
|
return err
|
|
}
|
|
|
|
r.Model = m
|
|
r.Tokenizer = m.Tokenizer()
|
|
return nil
|
|
}
|
|
|
|
// loadTensorsFromManifest loads all tensor blobs from the manifest into a
|
|
// flat map, deduplicating by digest and remapping safetensors key suffixes.
|
|
//
|
|
// Uses a two-phase approach: first loads all raw tensors, then remaps
|
|
// .bias → _qbias with complete knowledge of which base names have .scale
|
|
// entries. This avoids a race condition where Go map iteration order could
|
|
// cause .bias to be processed before .scale within the same blob.
|
|
func loadTensorsFromManifest(root *model.Root) (map[string]*mlx.Array, error) {
|
|
// Phase 1: Load all tensors raw from all blobs
|
|
rawTensors := make(map[string]*mlx.Array)
|
|
seen := make(map[string]bool)
|
|
for _, layer := range root.Manifest.GetTensorLayers("") {
|
|
if seen[layer.Digest] {
|
|
continue
|
|
}
|
|
seen[layer.Digest] = true
|
|
blobPath := root.Manifest.BlobPath(layer.Digest)
|
|
for name, arr := range mlx.Load(blobPath) {
|
|
rawTensors[name] = arr
|
|
}
|
|
}
|
|
|
|
// Phase 2: Identify all base names that have .scale tensors and remap them
|
|
scaleBaseNames := make(map[string]bool)
|
|
allTensors := make(map[string]*mlx.Array, len(rawTensors))
|
|
for name, arr := range rawTensors {
|
|
if strings.HasSuffix(name, ".scale") {
|
|
baseName := strings.TrimSuffix(name, ".scale")
|
|
allTensors[baseName+"_scale"] = arr
|
|
scaleBaseNames[baseName] = true
|
|
}
|
|
}
|
|
|
|
// Phase 3: Process remaining tensors with complete scale knowledge
|
|
for name, arr := range rawTensors {
|
|
if strings.HasSuffix(name, ".scale") {
|
|
continue // already handled
|
|
}
|
|
if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
|
|
baseName := strings.TrimSuffix(name, ".bias")
|
|
if scaleBaseNames[baseName] {
|
|
allTensors[baseName+"_qbias"] = arr
|
|
} else {
|
|
allTensors[name] = arr
|
|
}
|
|
} else {
|
|
allTensors[name] = arr
|
|
}
|
|
}
|
|
|
|
slog.Info("Loaded tensors from manifest", "count", len(allTensors))
|
|
return allTensors, nil
|
|
}
|
|
|
|
func (r *Runner) Run(host, port string, mux http.Handler) error {
|
|
g, ctx := errgroup.WithContext(context.Background())
|
|
|
|
g.Go(func() error {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case request := <-r.Requests:
|
|
if err := request.Pipeline(request); err != nil {
|
|
break
|
|
}
|
|
|
|
close(request.Responses)
|
|
}
|
|
}
|
|
})
|
|
|
|
g.Go(func() error {
|
|
slog.Info("Starting HTTP server", "host", host, "port", port)
|
|
return http.ListenAndServe(net.JoinHostPort(host, port), mux)
|
|
})
|
|
|
|
return g.Wait()
|
|
}
|