diff --git a/api/types.go b/api/types.go index a0acd6641..b1ce8cc7d 100644 --- a/api/types.go +++ b/api/types.go @@ -64,6 +64,9 @@ type GenerateRequest struct { // the library at https://ollama.com/library Model string `json:"model"` + // Runner selects a runner variant from a manifest list. + Runner string `json:"runner,omitempty"` + // Prompt is the textual prompt to send to the model. Prompt string `json:"prompt"` @@ -148,6 +151,9 @@ type ChatRequest struct { // Model is the model name, as in [GenerateRequest]. Model string `json:"model"` + // Runner selects a runner variant from a manifest list. + Runner string `json:"runner,omitempty"` + // Messages is the messages of the chat - can be used to keep a chat memory. Messages []Message `json:"messages"` @@ -675,6 +681,9 @@ type CreateRequest struct { // From is the name of the model or file to use as the source. From string `json:"from,omitempty"` + // List is the list of local model tags to include in a manifest list. + List []string `json:"list,omitempty"` + // RemoteHost is the URL of the upstream ollama API for the model (if any). RemoteHost string `json:"remote_host,omitempty"` @@ -725,6 +734,7 @@ type DeleteRequest struct { // ShowRequest is the request passed to [Client.Show]. type ShowRequest struct { Model string `json:"model"` + Runner string `json:"runner,omitempty"` System string `json:"system"` // Template is deprecated @@ -829,6 +839,7 @@ type ProcessModelResponse struct { ExpiresAt time.Time `json:"expires_at"` SizeVRAM int64 `json:"size_vram"` ContextLength int `json:"context_length"` + Runner string `json:"runner,omitempty"` } type TokenResponse struct { diff --git a/cmd/cmd.go b/cmd/cmd.go index 43a2e7d3c..f482ea0a6 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -98,11 +98,11 @@ func init() { const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n" // ensureThinkingSupport emits a warning if the model does not advertise thinking support -func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { +func ensureThinkingSupport(ctx context.Context, client *api.Client, name, runner string) { if name == "" { return } - resp, err := client.Show(ctx, &api.ShowRequest{Model: name}) + resp, err := client.Show(ctx, &api.ShowRequest{Model: name, Runner: runner}) if err != nil { return } @@ -156,6 +156,45 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return fmt.Errorf("invalid model name: %s", modelName) } + list, _ := cmd.Flags().GetStringSlice("combine") + if len(list) > 0 { + if experimental, _ := cmd.Flags().GetBool("experimental"); experimental { + return errors.New("--combine cannot be used with --experimental") + } + if quantize, _ := cmd.Flags().GetString("quantize"); quantize != "" { + return errors.New("--combine cannot be used with --quantize") + } + if cmd.Flags().Changed("file") { + return errors.New("--combine cannot be used with --file") + } + + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + req := &api.CreateRequest{ + Model: modelName, + List: list, + } + + status := "creating manifest list" + spinner := progress.NewSpinner(status) + p.Add(status, spinner) + + fn := func(resp api.ProgressResponse) error { + if status != resp.Status { + spinner.Stop() + status = resp.Status + spinner = progress.NewSpinner(status) + p.Add(status, spinner) + } + return nil + } + + return client.Create(cmd.Context(), req, fn) + } + // Check for --experimental flag for safetensors model creation // This gates both safetensors LLM and imagegen model creation experimental, _ := cmd.Flags().GetBool("experimental") @@ -399,7 +438,7 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { requestedCloud := modelref.HasExplicitCloudSource(opts.Model) - if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil { + if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model, Runner: opts.Runner}); err != nil { return err } else if info.RemoteHost != "" || requestedCloud { // Cloud model, no need to load/unload @@ -431,6 +470,7 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { req := &api.GenerateRequest{ Model: opts.Model, + Runner: opts.Runner, KeepAlive: opts.KeepAlive, // pass Think here so we fail before getting to the chat prompt if the model doesn't support it @@ -562,6 +602,14 @@ func RunHandler(cmd *cobra.Command, args []string) error { ShowConnect: true, } + if flag := cmd.Flags().Lookup("runner"); flag != nil { + runner, err := cmd.Flags().GetString("runner") + if err != nil { + return err + } + opts.Runner = runner + } + format, err := cmd.Flags().GetString("format") if err != nil { return err @@ -651,7 +699,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { requestedCloud := modelref.HasExplicitCloudSource(name) info, err := func() (*api.ShowResponse, error) { - showReq := &api.ShowRequest{Name: name} + showReq := &api.ShowRequest{Name: name, Runner: opts.Runner} info, err := client.Show(cmd.Context(), showReq) var se api.StatusError if errors.As(err, &se) && se.StatusCode == http.StatusNotFound { @@ -661,7 +709,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { if err := PullHandler(cmd, []string{name}); err != nil { return nil, err } - return client.Show(cmd.Context(), &api.ShowRequest{Name: name}) + return client.Show(cmd.Context(), &api.ShowRequest{Name: name, Runner: opts.Runner}) } return info, err }() @@ -761,7 +809,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { // Use experimental agent loop with tools if isExperimental { - return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch) + return xcmd.GenerateInteractive(cmd, opts.Model, opts.Runner, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch) } return generateInteractive(cmd, opts) @@ -1000,12 +1048,12 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error { until = format.HumanTime(m.ExpiresAt, "Never") } ctxStr := strconv.Itoa(m.ContextLength) - data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, ctxStr, until}) + data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, ctxStr, m.Runner, until}) } } table := tablewriter.NewWriter(os.Stdout) - table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "CONTEXT", "UNTIL"}) + table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "CONTEXT", "RUNNER", "UNTIL"}) table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) table.SetAlignment(tablewriter.ALIGN_LEFT) table.SetHeaderLine(false) @@ -1412,6 +1460,7 @@ type generateContextKey string type runOptions struct { Model string + Runner string ParentModel string LoadedMessages []api.Message Prompt string @@ -1463,6 +1512,7 @@ func (r runOptions) Copy() runOptions { return runOptions{ Model: r.Model, + Runner: r.Runner, ParentModel: r.ParentModel, LoadedMessages: loadedMessages, Prompt: r.Prompt, @@ -1646,6 +1696,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { req := &api.ChatRequest{ Model: opts.Model, + Runner: opts.Runner, Messages: opts.Messages, Format: json.RawMessage(opts.Format), Options: opts.Options, @@ -1778,6 +1829,7 @@ func generate(cmd *cobra.Command, opts runOptions) error { request := api.GenerateRequest{ Model: opts.Model, + Runner: opts.Runner, Prompt: opts.Prompt, Context: generateContext, Images: opts.Images, @@ -2121,6 +2173,7 @@ func NewCLI() *cobra.Command { } createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")") + createCmd.Flags().StringSlice("combine", nil, "Create a manifest list from comma-separated local models") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)") createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation") @@ -2152,6 +2205,7 @@ func NewCLI() *cobra.Command { runCmd.Flags().Bool("insecure", false, "Use an insecure registry") runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().String("format", "", "Response format (e.g. json)") + runCmd.Flags().String("runner", "", "Runner to use for manifest list selection (mlx, ollama, llamacpp)") runCmd.Flags().String("think", "", "Enable thinking mode: true/false or high/medium/low for supported models") runCmd.Flags().Lookup("think").NoOptDefVal = "true" runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)") diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index d9c565630..498c230f2 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -479,6 +479,143 @@ func TestRunEmbeddingModel(t *testing.T) { } } +func TestListRunningHandlerShowsRunner(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/ps" || r.Method != http.MethodGet { + http.NotFound(w, r) + return + } + + if err := json.NewEncoder(w).Encode(api.ProcessResponse{ + Models: []api.ProcessModelResponse{ + { + Name: "test-model:latest", + Model: "test-model:latest", + Digest: "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890", + Size: 1024, + SizeVRAM: 1024, + ContextLength: 4096, + Runner: "mlx", + ExpiresAt: time.Now().Add(time.Hour), + }, + }, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + })) + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.SetContext(t.Context()) + + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + err := ListRunningHandler(cmd, nil) + w.Close() + os.Stdout = oldStdout + if err != nil { + t.Fatal(err) + } + + out, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + got := string(out) + for _, want := range []string{"CONTEXT", "RUNNER", "abcdef123456", "mlx"} { + if !strings.Contains(got, want) { + t.Fatalf("output missing %q:\n%s", want, got) + } + } +} + +func TestRunHandlerRunnerFlag(t *testing.T) { + showReqCh := make(chan api.ShowRequest, 1) + generateReqCh := make(chan api.GenerateRequest, 1) + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/api/show" && r.Method == http.MethodPost: + var req api.ShowRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + showReqCh <- req + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(api.ShowResponse{ + Capabilities: []model.Capability{model.CapabilityCompletion}, + ModelInfo: map[string]any{}, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + case r.URL.Path == "/api/generate" && r.Method == http.MethodPost: + var req api.GenerateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + generateReqCh <- req + w.Header().Set("Content-Type", "application/x-ndjson") + if err := json.NewEncoder(w).Encode(api.GenerateResponse{ + Model: "test-model", + Done: true, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + default: + http.NotFound(w, r) + } + })) + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.SetContext(t.Context()) + cmd.Flags().String("keepalive", "", "") + cmd.Flags().Bool("verbose", false, "") + cmd.Flags().Bool("insecure", false, "") + cmd.Flags().Bool("nowordwrap", false, "") + cmd.Flags().String("format", "", "") + cmd.Flags().String("runner", "", "") + cmd.Flags().String("think", "", "") + cmd.Flags().Bool("hidethinking", false, "") + if err := cmd.Flags().Set("runner", "llamacpp"); err != nil { + t.Fatal(err) + } + + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + err := RunHandler(cmd, []string{"test-model", "hello"}) + w.Close() + os.Stdout = oldStdout + if _, readErr := io.ReadAll(r); readErr != nil { + t.Fatal(readErr) + } + if err != nil { + t.Fatal(err) + } + + select { + case req := <-showReqCh: + if req.Runner != "llamacpp" { + t.Fatalf("show runner = %q, want %q", req.Runner, "llamacpp") + } + default: + t.Fatal("server did not receive show request") + } + select { + case req := <-generateReqCh: + if req.Runner != "llamacpp" { + t.Fatalf("generate runner = %q, want %q", req.Runner, "llamacpp") + } + default: + t.Fatal("server did not receive generate request") + } +} + func TestRunEmbeddingModelWithFlags(t *testing.T) { reqCh := make(chan api.EmbedRequest, 1) mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1524,6 +1661,66 @@ func TestCreateHandler(t *testing.T) { } } +func TestCreateHandlerManifestList(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/create" { + t.Errorf("unexpected request to %s", r.URL.Path) + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + + var req api.CreateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if req.Model != "parent" { + t.Errorf("model = %q, want %q", req.Model, "parent") + } + if !cmp.Equal(req.List, []string{"gguf", "safetensors"}) { + t.Errorf("list = %#v, want %#v", req.List, []string{"gguf", "safetensors"}) + } + if req.From != "" || len(req.Files) > 0 { + t.Errorf("manifest list create sent normal create fields: from=%q files=%v", req.From, req.Files) + } + + if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.(http.Flusher).Flush() + })) + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.Flags().String("file", "", "") + cmd.Flags().String("quantize", "", "") + cmd.Flags().Bool("experimental", false, "") + cmd.Flags().StringSlice("combine", nil, "") + cmd.SetContext(t.Context()) + if err := cmd.Flags().Set("combine", "gguf,safetensors"); err != nil { + t.Fatal(err) + } + + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + err := CreateHandler(cmd, []string{"parent"}) + w.Close() + os.Stderr = oldStderr + if _, readErr := io.ReadAll(r); readErr != nil { + t.Fatal(readErr) + } + + if err != nil { + t.Fatal(err) + } +} + func TestNewCreateRequest(t *testing.T) { tests := []struct { name string diff --git a/cmd/interactive.go b/cmd/interactive.go index ea1ed9841..6c0904ab3 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -224,7 +224,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { opts.Messages = []api.Message{} opts.LoadedMessages = nil fmt.Printf("Loading model '%s'\n", opts.Model) - info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}) + info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model, Runner: opts.Runner}) if err != nil { if strings.Contains(err.Error(), "not found") { fmt.Printf("Couldn't find model '%s'\n", opts.Model) @@ -323,7 +323,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { opts.Think = &thinkValue thinkExplicitlySet = true if client, err := api.ClientFromEnvironment(); err == nil { - ensureThinkingSupport(cmd.Context(), client, opts.Model) + ensureThinkingSupport(cmd.Context(), client, opts.Model, opts.Runner) } if maybeLevel != "" { fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel) @@ -334,7 +334,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { opts.Think = &api.ThinkValue{Value: false} thinkExplicitlySet = true if client, err := api.ClientFromEnvironment(); err == nil { - ensureThinkingSupport(cmd.Context(), client, opts.Model) + ensureThinkingSupport(cmd.Context(), client, opts.Model, opts.Runner) } fmt.Println("Set 'nothink' mode.") case "format": @@ -414,6 +414,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } req := &api.ShowRequest{ Name: opts.Model, + Runner: opts.Runner, System: opts.System, Options: opts.Options, } diff --git a/cmd/warn_thinking_test.go b/cmd/warn_thinking_test.go index 31dc4156b..4d5eb9b70 100644 --- a/cmd/warn_thinking_test.go +++ b/cmd/warn_thinking_test.go @@ -47,7 +47,7 @@ func TestWarnMissingThinking(t *testing.T) { oldStderr := os.Stderr r, w, _ := os.Pipe() os.Stderr = w - ensureThinkingSupport(t.Context(), client, "m") + ensureThinkingSupport(t.Context(), client, "m", "") w.Close() os.Stderr = oldStderr out, _ := io.ReadAll(r) diff --git a/llm/server.go b/llm/server.go index a8104f79f..362cee22e 100644 --- a/llm/server.go +++ b/llm/server.go @@ -120,6 +120,18 @@ type ollamaServer struct { tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding } +// RunnerName returns the runner implementation name for a LlamaServer. +func RunnerName(s LlamaServer) string { + switch s.(type) { + case *ollamaServer: + return "ollama" + case *llamaServer: + return "llamacpp" + default: + return "" + } +} + // LoadModel will load a model from disk. The model must be in the GGML format. // // It collects array values for arrays with a size less than or equal to diff --git a/manifest/manifest.go b/manifest/manifest.go index dc8026e71..e378867ca 100644 --- a/manifest/manifest.go +++ b/manifest/manifest.go @@ -17,7 +17,13 @@ import ( "github.com/ollama/ollama/types/model" ) -var blobFilenamePattern = regexp.MustCompile(`^sha256-[0-9a-fA-F]{64}$`) +var ( + blobFilenamePattern = regexp.MustCompile(`^sha256-[0-9a-fA-F]{64}$`) + + // ErrNoCompatibleManifest is returned when a manifest list does not contain + // a child manifest for the requested runner. + ErrNoCompatibleManifest = errors.New("no compatible manifest found") +) const ( MediaTypeManifest = "application/vnd.docker.distribution.manifest.v2+json" @@ -40,10 +46,15 @@ type Manifest struct { Format string `json:"format,omitempty"` Manifests []Manifest `json:"manifests,omitempty"` - filepath string - fi os.FileInfo - digest string - name model.Name + filepath string + fi os.FileInfo + digest string + selectedDigest string + name model.Name +} + +func (m Manifest) isReference() bool { + return m.MediaType != MediaTypeManifestList && m.digest != "" && m.Config.Digest == "" && len(m.Layers) == 0 } func (m Manifest) MarshalJSON() ([]byte, error) { @@ -59,6 +70,20 @@ func (m Manifest) MarshalJSON() ([]byte, error) { }) } + if m.isReference() { + return json.Marshal(struct { + MediaType string `json:"mediaType"` + Digest string `json:"digest"` + Runner string `json:"runner,omitempty"` + Format string `json:"format,omitempty"` + }{ + MediaType: m.MediaType, + Digest: m.BlobDigest(), + Runner: m.Runner, + Format: m.Format, + }) + } + return json.Marshal(struct { SchemaVersion int `json:"schemaVersion"` MediaType string `json:"mediaType"` @@ -76,6 +101,41 @@ func (m Manifest) MarshalJSON() ([]byte, error) { }) } +func (m *Manifest) UnmarshalJSON(data []byte) error { + var raw struct { + SchemaVersion int `json:"schemaVersion"` + MediaType string `json:"mediaType"` + Config Layer `json:"config"` + Layers []Layer `json:"layers"` + Runner string `json:"runner,omitempty"` + Format string `json:"format,omitempty"` + Manifests []Manifest `json:"manifests,omitempty"` + Digest string `json:"digest,omitempty"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + *m = Manifest{ + SchemaVersion: raw.SchemaVersion, + MediaType: raw.MediaType, + Config: raw.Config, + Layers: raw.Layers, + Runner: raw.Runner, + Format: raw.Format, + Manifests: raw.Manifests, + } + if raw.Digest != "" { + digest, err := canonicalBlobDigest(raw.Digest) + if err != nil { + return err + } + m.digest = strings.TrimPrefix(digest, "sha256:") + } + + return nil +} + func (m *Manifest) Size() (size int64) { for _, layer := range append(m.Layers, m.Config) { size += layer.Size @@ -88,6 +148,16 @@ func (m *Manifest) Digest() string { return m.digest } +// SelectedDigest returns the digest of the runnable manifest selected from a +// manifest list. For non-list manifests, it is the same as Digest. +func (m *Manifest) SelectedDigest() string { + if m.selectedDigest != "" { + return m.selectedDigest + } + + return m.digest +} + func (m *Manifest) BlobDigest() string { if m.digest == "" { return "" @@ -230,6 +300,13 @@ func referencedBlobDigestsForData(manifestDigest string, data []byte) ([]string, if err := add(child.BlobDigest()); err != nil { return nil, err } + if child.isReference() { + resolved, err := parseManifestBlob(child.BlobDigest()) + if err != nil { + return nil, err + } + child = resolved + } if err := addManifest(child); err != nil { return nil, err } @@ -288,6 +365,16 @@ func RemoveUnreferencedBlobs(candidates ...string) (int, error) { } func ParseNamedManifest(n model.Name) (*Manifest, error) { + return parseNamedManifest(n, runnerPreferences()) +} + +// ParseNamedManifestForRunner returns the named manifest selected for runner. +// If the named object is a manifest list, runner must match one child entry. +func ParseNamedManifestForRunner(n model.Name, runner string) (*Manifest, error) { + return parseNamedManifest(n, runnerPreferencesFor(runner)) +} + +func parseNamedManifest(n model.Name, preferences []string) (*Manifest, error) { if !n.IsFullyQualified() { return nil, model.Unqualified(n) } @@ -297,7 +384,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { return nil, err } - return parseManifestFile(normalizeLogicalName(n), p, root) + return parseManifestFile(normalizeLogicalName(n), p, root, preferences) } func ReadManifestData(n model.Name) ([]byte, error) { @@ -352,32 +439,38 @@ func ReadSelectedManifestData(n model.Name) ([]byte, error) { return data, nil } -func parseManifestFile(name model.Name, path, root string) (*Manifest, error) { +func parseManifestFile(name model.Name, path, root string, preferences []string) (*Manifest, error) { data, fi, digest, err := readVerifiedManifest(path, root) if err != nil { return nil, err } - return parseManifestData(name, path, fi, digest, data) + return parseManifestData(name, path, fi, digest, data, preferences) } -func parseManifestData(name model.Name, path string, fi os.FileInfo, digest string, data []byte) (*Manifest, error) { +func parseManifestData(name model.Name, path string, fi os.FileInfo, digest string, data []byte, preferences []string) (*Manifest, error) { m, err := parseManifest(data) if err != nil { return nil, err } if m.MediaType == MediaTypeManifestList { - child, err := selectManifest(m.Manifests) + child, err := selectManifestWithPreferences(m.Manifests, preferences) if err != nil { return nil, err } + selectedDigest := child.digest child.filepath = path child.fi = fi child.digest = digest + child.selectedDigest = selectedDigest child.name = name return child, nil } + if len(preferences) == 1 && m.Runner != "" && !strings.EqualFold(m.Runner, preferences[0]) { + return nil, fmt.Errorf("%w for runners: %s", ErrNoCompatibleManifest, preferences[0]) + } + m.filepath = path m.fi = fi m.digest = digest @@ -398,12 +491,29 @@ func selectManifestWithPreferences(manifests []Manifest, preferences []string) ( } if strings.EqualFold(manifests[i].Runner, runner) { child := manifests[i] + if child.isReference() { + childDigest := child.digest + resolved, err := parseManifestBlob(child.BlobDigest()) + if err != nil { + return nil, err + } + if resolved.Runner == "" { + resolved.Runner = child.Runner + } + if resolved.Format == "" { + resolved.Format = child.Format + } + if resolved.digest == "" { + resolved.digest = childDigest + } + child = *resolved + } return &child, nil } } } - return nil, fmt.Errorf("no compatible manifest found for runners: %s", strings.Join(preferences, ", ")) + return nil, fmt.Errorf("%w for runners: %s", ErrNoCompatibleManifest, strings.Join(preferences, ", ")) } func runnerPreferences() []string { @@ -414,6 +524,15 @@ func runnerPreferences() []string { return []string{RunnerOllama, RunnerLlamaCPP, RunnerMLX} } +func runnerPreferencesFor(runner string) []string { + runner = strings.ToLower(strings.TrimSpace(runner)) + if runner == "" { + return runnerPreferences() + } + + return []string{runner} +} + func parseManifest(data []byte) (*Manifest, error) { var m Manifest if err := json.Unmarshal(data, &m); err != nil { @@ -424,6 +543,9 @@ func parseManifest(data []byte) (*Manifest, error) { if m.Manifests[i].MediaType == MediaTypeManifestList { return nil, errors.New("nested manifest lists are not supported") } + if m.Manifests[i].isReference() { + continue + } canonical, err := json.Marshal(m.Manifests[i]) if err != nil { @@ -437,6 +559,37 @@ func parseManifest(data []byte) (*Manifest, error) { return &m, nil } +func readManifestBlob(digest string) ([]byte, error) { + digest, err := canonicalBlobDigest(digest) + if err != nil { + return nil, err + } + + blobPath, err := BlobsPath(digest) + if err != nil { + return nil, err + } + + data, err := os.ReadFile(blobPath) + if err != nil { + return nil, err + } + if err := checkBlobDigestReader(bytes.NewReader(data), digest); err != nil { + return nil, err + } + + return data, nil +} + +func parseManifestBlob(digest string) (*Manifest, error) { + data, err := readManifestBlob(digest) + if err != nil { + return nil, err + } + + return parseManifest(data) +} + func canonicalBlobDigest(digest string) (string, error) { if _, err := BlobsPath(digest); err != nil { return "", err @@ -460,11 +613,19 @@ func blobDigest(digest string) string { } func WriteManifest(name model.Name, config Layer, layers []Layer) error { + return WriteManifestWithMetadata(name, config, layers, "", "") +} + +// WriteManifestWithMetadata stores a single runnable manifest with optional +// runner and weight format metadata. +func WriteManifestWithMetadata(name model.Name, config Layer, layers []Layer, runner, format string) error { m := Manifest{ SchemaVersion: 2, MediaType: MediaTypeManifest, Config: config, Layers: layers, + Runner: runner, + Format: format, } var b bytes.Buffer @@ -475,6 +636,27 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error { return WriteManifestData(name, b.Bytes()) } +// NewManifestReference returns a manifest-list entry that points at an existing +// child manifest blob. +func NewManifestReference(digest, runner, format string) (Manifest, error) { + digest, err := canonicalBlobDigest(digest) + if err != nil { + return Manifest{}, err + } + + return Manifest{ + MediaType: MediaTypeManifest, + Runner: runner, + Format: format, + digest: strings.TrimPrefix(digest, "sha256:"), + }, nil +} + +// WriteManifestBlob stores raw manifest bytes as a content-addressed blob. +func WriteManifestBlob(data []byte) (string, error) { + return writeManifestBlob(data) +} + // WriteManifestData stores raw manifest bytes as a content-addressed blob and // updates the v2 named manifest path to reference that blob. Any legacy named // manifest for the same model is removed after the v2 write succeeds. @@ -815,7 +997,7 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) { ms := make(map[model.Name]*Manifest, len(refs)) for n, ref := range refs { - m, err := parseManifestFile(n, ref.path, ref.root) + m, err := parseManifestFile(n, ref.path, ref.root, runnerPreferences()) if err != nil { if !continueOnError { return nil, fmt.Errorf("%s %w", n, err) diff --git a/manifest/manifest_test.go b/manifest/manifest_test.go index e0b76b156..658008b81 100644 --- a/manifest/manifest_test.go +++ b/manifest/manifest_test.go @@ -192,6 +192,9 @@ func TestParseNamedManifestResolvesManifestList(t *testing.T) { if got := m.BlobDigest(); got != fmt.Sprintf("sha256:%x", parentSum) { t.Fatalf("blob digest = %q, want sha256:%x", got, parentSum) } + if got := m.SelectedDigest(); got != strings.TrimPrefix(ollamaDigest, "sha256:") { + t.Fatalf("selected digest = %q, want %q", got, strings.TrimPrefix(ollamaDigest, "sha256:")) + } if got := m.Runner; got != RunnerOllama { t.Fatalf("runner = %q, want %q", got, RunnerOllama) } @@ -201,6 +204,21 @@ func TestParseNamedManifestResolvesManifestList(t *testing.T) { if got := m.Config.Digest; got != "sha256:"+strings.Repeat("a", 64) { t.Fatalf("config digest = %q, want selected child config", got) } + + m, err = ParseNamedManifestForRunner(name, RunnerLlamaCPP) + if err != nil { + t.Fatal(err) + } + if got := m.Runner; got != RunnerLlamaCPP { + t.Fatalf("runner = %q, want %q", got, RunnerLlamaCPP) + } + if got := m.SelectedDigest(); got != strings.TrimPrefix(llamacppDigest, "sha256:") { + t.Fatalf("selected digest = %q, want %q", got, strings.TrimPrefix(llamacppDigest, "sha256:")) + } + if got := m.Config.Digest; got != "sha256:"+strings.Repeat("c", 64) { + t.Fatalf("config digest = %q, want selected child config", got) + } + referenced, err := ReferencedBlobDigestsForName(name) if err != nil { t.Fatal(err) diff --git a/server/create.go b/server/create.go index 4a2b943b6..42b7d97c4 100644 --- a/server/create.go +++ b/server/create.go @@ -104,6 +104,23 @@ func (s *Server) CreateHandler(c *gin.Context) { oldManifestDigests, _ := manifest.ReferencedBlobDigestsForName(name) + if len(r.List) > 0 { + if err := createManifestList(r, name, fn); err != nil { + ch <- gin.H{"error": err.Error()} + return + } + + if !envconfig.NoPrune() && len(oldManifestDigests) > 0 { + if _, err := manifest.RemoveUnreferencedBlobs(oldManifestDigests...); err != nil { + ch <- gin.H{"error": err.Error()} + return + } + } + + ch <- api.ProgressResponse{Status: "success"} + return + } + var baseLayers []*layerGGML var err error var remote bool @@ -599,13 +616,162 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, } fn(api.ProgressResponse{Status: "writing manifest"}) - if err := manifest.WriteManifest(name, *configLayer, layers); err != nil { + runner, format := manifestMetadataForConfig(*config) + if err := manifest.WriteManifestWithMetadata(name, *configLayer, layers, runner, format); err != nil { return err } return nil } +func createManifestList(r api.CreateRequest, name model.Name, fn func(resp api.ProgressResponse)) error { + if err := validateCreateManifestListRequest(r); err != nil { + return err + } + + manifests := make([]manifest.Manifest, 0, len(r.List)) + for _, ref := range r.List { + ref = strings.TrimSpace(ref) + if ref == "" { + return errors.New("manifest list contains an empty model") + } + + fn(api.ProgressResponse{Status: fmt.Sprintf("reading manifest %s", ref)}) + + modelRef, err := parseAndValidateModelRef(ref) + if err != nil { + return err + } + if modelRef.Source == modelSourceCloud { + return fmt.Errorf("manifest list entries must be local models: %s", ref) + } + + childName, err := getExistingName(modelRef.Name) + if err != nil { + return err + } + + data, err := manifest.ReadManifestData(childName) + if err != nil { + return fmt.Errorf("read manifest %s: %w", ref, err) + } + + var child manifest.Manifest + if err := json.Unmarshal(data, &child); err != nil { + return err + } + if child.MediaType == manifest.MediaTypeManifestList { + return fmt.Errorf("manifest list entry %s is already a manifest list", ref) + } + + if err := fillManifestMetadata(&child); err != nil { + return fmt.Errorf("manifest list entry %s: %w", ref, err) + } + + childData, err := json.Marshal(child) + if err != nil { + return err + } + childDigest, err := manifest.WriteManifestBlob(childData) + if err != nil { + return err + } + + childRef, err := manifest.NewManifestReference(childDigest, child.Runner, child.Format) + if err != nil { + return err + } + + manifests = append(manifests, childRef) + } + + parent := manifest.Manifest{ + SchemaVersion: 2, + MediaType: manifest.MediaTypeManifestList, + Manifests: manifests, + } + data, err := json.Marshal(parent) + if err != nil { + return err + } + + fn(api.ProgressResponse{Status: "writing manifest list"}) + return manifest.WriteManifestData(name, data) +} + +func validateCreateManifestListRequest(r api.CreateRequest) error { + if len(r.List) == 0 { + return errors.New("manifest list must contain at least one model") + } + + switch { + case r.From != "", r.RemoteHost != "", len(r.Files) > 0, len(r.Adapters) > 0: + return errors.New("manifest list creation cannot be combined with model creation options") + case r.Template != "", r.System != "", r.License != nil, len(r.Parameters) > 0, len(r.Messages) > 0: + return errors.New("manifest list creation cannot be combined with model creation options") + case r.Renderer != "", r.Parser != "", r.Requires != "", len(r.Info) > 0: + return errors.New("manifest list creation cannot be combined with model creation options") + case r.Quantize != "", r.Quantization != "": + return errors.New("manifest list creation cannot be combined with model creation options") + default: + return nil + } +} + +func fillManifestMetadata(m *manifest.Manifest) error { + if m.Runner != "" && m.Format != "" { + return nil + } + + config, err := readManifestConfig(m.Config.Digest) + if err != nil { + return err + } + + runner, format := manifestMetadataForConfig(config) + if m.Runner == "" { + m.Runner = runner + } + if m.Format == "" { + m.Format = format + } + if m.Runner == "" || m.Format == "" { + return errors.New("manifest is missing runner or format metadata") + } + + return nil +} + +func readManifestConfig(digest string) (model.ConfigV2, error) { + var config model.ConfigV2 + if digest == "" { + return config, errors.New("manifest is missing config digest") + } + + configPath, err := manifest.BlobsPath(digest) + if err != nil { + return config, err + } + configFile, err := os.Open(configPath) + if err != nil { + return config, err + } + defer configFile.Close() + + return config, json.NewDecoder(configFile).Decode(&config) +} + +func manifestMetadataForConfig(config model.ConfigV2) (runner, format string) { + switch strings.ToLower(config.ModelFormat) { + case manifest.FormatSafetensors: + return manifest.RunnerMLX, manifest.FormatSafetensors + case manifest.FormatGGUF, "ggml": + return manifest.RunnerOllama, manifest.FormatGGUF + default: + return "", strings.ToLower(config.ModelFormat) + } +} + func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.ProgressResponse)) (*layerGGML, error) { ft := layer.GGML.KV().FileType() var doneBytes atomic.Uint64 diff --git a/server/images.go b/server/images.go index 3036b9be5..19cd8ea4a 100644 --- a/server/images.go +++ b/server/images.go @@ -71,6 +71,8 @@ type Model struct { System string License []string Digest string + ManifestDigest string + Runner string Options map[string]any Messages []api.Message @@ -300,17 +302,30 @@ func (m *Model) String() string { } func GetModel(name string) (*Model, error) { + return GetModelForRunner(name, "") +} + +// GetModelForRunner returns model metadata for name, selecting runner from a +// manifest list when one is specified. +func GetModelForRunner(name, runner string) (*Model, error) { n := model.ParseName(name) - mf, err := manifest.ParseNamedManifest(n) + mf, err := manifest.ParseNamedManifestForRunner(n, runner) if err != nil { return nil, err } + manifestDigest := mf.SelectedDigest() + if manifestDigest == "" { + manifestDigest = mf.Digest() + } + m := &Model{ - Name: n.String(), - ShortName: n.DisplayShortest(), - Digest: mf.Digest(), - Template: template.DefaultTemplate, + Name: n.String(), + ShortName: n.DisplayShortest(), + Digest: mf.Digest(), + ManifestDigest: manifestDigest, + Runner: mf.Runner, + Template: template.DefaultTemplate, } if mf.Config.Digest != "" { diff --git a/server/routes.go b/server/routes.go index 6620180b0..f71c7cab5 100644 --- a/server/routes.go +++ b/server/routes.go @@ -141,14 +141,29 @@ func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Opt return opts, nil } +func normalizeRunner(runner string) (string, error) { + switch strings.ToLower(strings.TrimSpace(runner)) { + case "": + return "", nil + case manifest.RunnerMLX, "mlxrunner": + return manifest.RunnerMLX, nil + case manifest.RunnerOllama, "ggml": + return manifest.RunnerOllama, nil + case manifest.RunnerLlamaCPP, "llama.cpp", "llama-cpp", "llama_cpp": + return manifest.RunnerLlamaCPP, nil + default: + return "", fmt.Errorf("unknown runner %q", runner) + } +} + // scheduleRunner schedules a runner after validating inputs such as capabilities and model options. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise. -func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) { +func (s *Server) scheduleRunner(ctx context.Context, name, selectedRunner string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) { if name == "" { return nil, nil, nil, fmt.Errorf("model %w", errRequired) } - model, err := GetModel(name) + model, err := GetModelForRunner(name, selectedRunner) if err != nil { return nil, nil, nil, err } @@ -207,6 +222,13 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + if runner, err := normalizeRunner(req.Runner); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } else { + req.Runner = runner + } + modelRef, err := parseAndValidateModelRef(req.Model) if err != nil { writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model)) @@ -231,11 +253,13 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - m, err := GetModel(name.String()) + m, err := GetModelForRunner(name.String(), req.Runner) if err != nil { switch { case errors.Is(err, fs.ErrNotExist): c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + case errors.Is(err, manifest.ErrNoCompatibleManifest): + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) case err.Error() == errtypes.InvalidModelNameErrMsg: c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) default: @@ -405,7 +429,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { } } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) + r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), req.Runner, caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) return @@ -727,7 +751,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) + r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), "", []model.Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return @@ -882,7 +906,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { name := modelRef.Name - r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) + r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), "", []model.Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return @@ -1113,6 +1137,11 @@ func (s *Server) ShowHandler(c *gin.Context) { return } + if req.Runner, err = normalizeRunner(req.Runner); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + modelRef, err := parseAndValidateModelRef(req.Model) if err != nil { writeModelRefParseError(c, err, http.StatusBadRequest, err.Error()) @@ -1133,6 +1162,8 @@ func (s *Server) ShowHandler(c *gin.Context) { switch { case os.IsNotExist(err): c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + case errors.Is(err, manifest.ErrNoCompatibleManifest): + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) case errors.As(err, &statusErr): c.JSON(statusErr.StatusCode, gin.H{"error": statusErr.ErrorMessage}) case err.Error() == errtypes.InvalidModelNameErrMsg: @@ -1163,16 +1194,25 @@ func (s *Server) ShowHandler(c *gin.Context) { } func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { + runner, err := normalizeRunner(req.Runner) + if err != nil { + return nil, api.StatusError{ + StatusCode: http.StatusBadRequest, + ErrorMessage: err.Error(), + } + } + req.Runner = runner + name := model.ParseName(req.Model) if !name.IsValid() { return nil, model.Unqualified(name) } - name, err := getExistingName(name) + name, err = getExistingName(name) if err != nil { return nil, err } - m, err := GetModel(name.String()) + m, err := GetModelForRunner(name.String(), req.Runner) if err != nil { return nil, err } @@ -1231,7 +1271,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { msgs[i] = api.Message{Role: msg.Role, Content: msg.Content} } - mf, err := manifest.ParseNamedManifest(name) + mf, err := manifest.ParseNamedManifestForRunner(name, req.Runner) if err != nil { return nil, err } @@ -2038,6 +2078,14 @@ func (s *Server) PsHandler(c *gin.Context) { for _, v := range s.sched.loaded { model := v.model + digest := model.ManifestDigest + if digest == "" { + digest = model.Digest + } + runner := v.runner + if runner == "" { + runner = model.Runner + } modelDetails := api.ModelDetails{ Format: model.Config.ModelFormat, Family: model.Config.ModelFamily, @@ -2051,9 +2099,10 @@ func (s *Server) PsHandler(c *gin.Context) { Name: model.ShortName, Size: int64(v.totalSize), SizeVRAM: int64(v.vramSize), - Digest: model.Digest, + Digest: digest, Details: modelDetails, ExpiresAt: v.expiresAt, + Runner: runner, } if v.llama != nil { mr.ContextLength = v.llama.ContextLength() @@ -2106,6 +2155,13 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + if runner, err := normalizeRunner(req.Runner); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } else { + req.Runner = runner + } + modelRef, err := parseAndValidateModelRef(req.Model) if err != nil { writeModelRefParseError(c, err, http.StatusBadRequest, "model is required") @@ -2130,11 +2186,13 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - m, err := GetModel(name.String()) + m, err := GetModelForRunner(name.String(), req.Runner) if err != nil { switch { case os.IsNotExist(err): c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + case errors.Is(err, manifest.ErrNoCompatibleManifest): + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) case err.Error() == errtypes.InvalidModelNameErrMsg: c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) default: @@ -2283,7 +2341,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) + r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), req.Runner, caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) return @@ -2622,6 +2680,8 @@ func handleScheduleError(c *gin.Context, name string, err error) { switch { case errors.Is(err, errCapabilities), errors.Is(err, errRequired): c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + case errors.Is(err, manifest.ErrNoCompatibleManifest): + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) case errors.Is(err, context.Canceled): c.JSON(499, gin.H{"error": "request canceled"}) case errors.Is(err, ErrMaxQueue): @@ -2672,7 +2732,7 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo } // Schedule the runner for image generation - runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive) + runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, req.Runner, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 6655b88d1..d2bf7ac17 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -359,6 +359,113 @@ func TestCreateRemovesLayers(t *testing.T) { }) } +func writeManifestListVariant(t *testing.T, name, modelFormat string) { + t.Helper() + + configData, err := json.Marshal(model.ConfigV2{ + ModelFormat: modelFormat, + Capabilities: []string{"completion"}, + }) + if err != nil { + t.Fatal(err) + } + configLayer, err := manifest.NewLayer(bytes.NewReader(configData), "application/vnd.docker.container.image.v1+json") + if err != nil { + t.Fatal(err) + } + modelLayer, err := manifest.NewLayer(strings.NewReader(name+" layer"), "application/vnd.ollama.image.license") + if err != nil { + t.Fatal(err) + } + + if err := manifest.WriteManifest(model.ParseName(name), configLayer, []manifest.Layer{modelLayer}); err != nil { + t.Fatal(err) + } +} + +func TestCreateManifestList(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Setenv("OLLAMA_MODELS", t.TempDir()) + var s Server + + writeManifestListVariant(t, "test-gguf", manifest.FormatGGUF) + writeManifestListVariant(t, "test-safetensors", manifest.FormatSafetensors) + + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "test-list", + List: []string{"test-gguf", "test-safetensors"}, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d: %s", w.Code, w.Body.String()) + } + + data, err := manifest.ReadManifestData(model.ParseName("test-list")) + if err != nil { + t.Fatal(err) + } + + var parent manifest.Manifest + if err := json.Unmarshal(data, &parent); err != nil { + t.Fatal(err) + } + if parent.MediaType != manifest.MediaTypeManifestList { + t.Fatalf("mediaType = %q, want %q", parent.MediaType, manifest.MediaTypeManifestList) + } + if len(parent.Manifests) != 2 { + t.Fatalf("manifest count = %d, want 2", len(parent.Manifests)) + } + + selected, err := manifest.ParseNamedManifest(model.ParseName("test-list")) + if err != nil { + t.Fatal(err) + } + if selected.Config.Digest == "" { + t.Fatal("selected manifest is missing config") + } + + mlxInfo, err := GetModelInfo(api.ShowRequest{Model: "test-list", Runner: manifest.RunnerMLX}) + if err != nil { + t.Fatal(err) + } + if mlxInfo.Details.Format != manifest.FormatSafetensors { + t.Fatalf("mlx show format = %q, want %q", mlxInfo.Details.Format, manifest.FormatSafetensors) + } + + want := map[string]string{ + manifest.RunnerOllama: manifest.FormatGGUF, + manifest.RunnerMLX: manifest.FormatSafetensors, + } + for _, child := range parent.Manifests { + if got := want[child.Runner]; got != child.Format { + t.Fatalf("child runner/format = %q/%q, want one of %v", child.Runner, child.Format, want) + } + if child.BlobDigest() == "" { + t.Fatal("child manifest reference is missing digest") + } + if child.Config.Digest != "" || len(child.Layers) != 0 { + t.Fatalf("child manifest reference embedded config/layers: config=%q layers=%d", child.Config.Digest, len(child.Layers)) + } + + childBlob, err := manifest.BlobsPath(child.BlobDigest()) + if err != nil { + t.Fatal(err) + } + childData, err := os.ReadFile(childBlob) + if err != nil { + t.Fatalf("child manifest blob missing: %v", err) + } + var resolved manifest.Manifest + if err := json.Unmarshal(childData, &resolved); err != nil { + t.Fatal(err) + } + if resolved.Config.Digest == "" || len(resolved.Layers) == 0 { + t.Fatalf("resolved child manifest missing config/layers: config=%q layers=%d", resolved.Config.Digest, len(resolved.Layers)) + } + } +} + func TestCreateUnsetsSystem(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/server/routes_test.go b/server/routes_test.go index 77abec986..a70cb42b0 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -20,6 +20,7 @@ import ( "sort" "strings" "testing" + "time" "unicode" "github.com/gin-gonic/gin" @@ -33,6 +34,58 @@ import ( "github.com/ollama/ollama/version" ) +func TestPsHandlerUsesRunningManifestAndRunner(t *testing.T) { + gin.SetMode(gin.TestMode) + + childDigest := strings.Repeat("a", 64) + s := Server{ + sched: &Scheduler{ + loaded: map[string]*runnerRef{ + "test": { + model: &Model{ + ShortName: "test-model:latest", + Digest: strings.Repeat("b", 64), + ManifestDigest: childDigest, + Runner: manifest.RunnerMLX, + Config: model.ConfigV2{ + ModelFormat: manifest.FormatSafetensors, + }, + }, + runner: manifest.RunnerMLX, + totalSize: 1024, + vramSize: 1024, + expiresAt: time.Now().Add(time.Hour), + sessionDuration: time.Hour, + }, + }, + }, + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/api/ps", nil) + + s.PsHandler(c) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want %d: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp api.ProcessResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if len(resp.Models) != 1 { + t.Fatalf("model count = %d, want 1", len(resp.Models)) + } + if resp.Models[0].Digest != childDigest { + t.Fatalf("digest = %q, want child digest %q", resp.Models[0].Digest, childDigest) + } + if resp.Models[0].Runner != manifest.RunnerMLX { + t.Fatalf("runner = %q, want %q", resp.Models[0].Runner, manifest.RunnerMLX) + } +} + func createTestFile(t *testing.T, name string) (string, string) { t.Helper() diff --git a/server/sched.go b/server/sched.go index f040e34f3..9fd0e3528 100644 --- a/server/sched.go +++ b/server/sched.go @@ -84,7 +84,8 @@ func InitScheduler(ctx context.Context) *Scheduler { // schedulerModelKey returns the scheduler map key for a model. // GGUF-backed models use ModelPath; safetensors/image models without a -// ModelPath use manifest digest so distinct models don't collide. +// ModelPath use the selected manifest digest so distinct child manifests don't +// collide. func schedulerModelKey(m *Model) string { if m == nil { return "" @@ -92,6 +93,9 @@ func schedulerModelKey(m *Model) string { if m.ModelPath != "" { return m.ModelPath } + if m.ManifestDigest != "" { + return "manifest:" + m.ManifestDigest + } if m.Digest != "" { return "digest:" + m.Digest } @@ -530,6 +534,12 @@ iGPUScan: } totalSize, vramSize := llama.MemorySize() + runnerName := req.model.Runner + if req.model.IsMLX() && runnerName == "" { + runnerName = "mlx" + } else if name := llm.RunnerName(llama); name != "" { + runnerName = name + } runner := &runnerRef{ model: req.model, modelPath: req.model.ModelPath, @@ -540,6 +550,7 @@ iGPUScan: gpus: gpuIDs, discreteGPUs: discreteGPUs, isImagegen: slices.Contains(req.model.Config.Capabilities, "image"), + runner: runnerName, totalSize: totalSize, vramSize: vramSize, loading: true, @@ -640,6 +651,7 @@ type runnerRef struct { gpus []ml.DeviceID // Recorded at time of provisioning discreteGPUs bool // True if all devices are discrete GPUs - used to skip VRAM recovery check for iGPUs isImagegen bool // True if loaded via imagegen runner (vs mlxrunner) + runner string vramSize uint64 totalSize uint64 diff --git a/server/sched_test.go b/server/sched_test.go index f40dc117f..f533c51e4 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -499,6 +499,35 @@ func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) { require.Len(t, s.pendingReqCh, 1) } +func TestSchedGetRunnerUsesManifestDigestKeyWhenModelPathEmpty(t *testing.T) { + ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) + defer done() + + s := InitScheduler(ctx) + opts := api.DefaultOptions() + opts.NumCtx = 4 + + loadedModel := &Model{Name: "list", Digest: "parent", ManifestDigest: "child-a"} + loadedRunner := &runnerRef{ + model: loadedModel, + modelKey: schedulerModelKey(loadedModel), + llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}, + Options: &opts, + numParallel: 1, + } + + s.loadedMu.Lock() + s.loaded[loadedRunner.modelKey] = loadedRunner + s.loadedMu.Unlock() + + reqModel := &Model{Name: "list", Digest: "parent", ManifestDigest: "child-b"} + successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil) + + require.Empty(t, successCh) + require.Empty(t, errCh) + require.Len(t, s.pendingReqCh, 1) +} + func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) { ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) defer done() diff --git a/x/cmd/run.go b/x/cmd/run.go index e96c8385e..eecedbff0 100644 --- a/x/cmd/run.go +++ b/x/cmd/run.go @@ -142,6 +142,7 @@ func waitForOllamaSignin(ctx context.Context) error { // RunOptions contains options for running an interactive agent session. type RunOptions struct { Model string + Runner string Messages []api.Message WordWrap bool Format string @@ -260,6 +261,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) { for { req := &api.ChatRequest{ Model: opts.Model, + Runner: opts.Runner, Messages: messages, Format: json.RawMessage(opts.Format), Options: opts.Options, @@ -638,13 +640,13 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string { } // checkModelCapabilities checks if the model supports tools. -func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) { +func checkModelCapabilities(ctx context.Context, modelName, runner string) (supportsTools bool, err error) { client, err := api.ClientFromEnvironment() if err != nil { return false, err } - resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName}) + resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName, Runner: runner}) if err != nil { return false, err } @@ -662,7 +664,7 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool // This is called from cmd.go when --experimental flag is set. // If yoloMode is true, all tool approvals are skipped. // If enableWebsearch is true, the web search tool is registered. -func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error { +func GenerateInteractive(cmd *cobra.Command, modelName, runner string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error { scanner, err := readline.New(readline.Prompt{ Prompt: ">>> ", AltPrompt: "... ", @@ -677,7 +679,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op defer fmt.Printf(readline.EndBracketedPaste) // Check if model supports tools - supportsTools, err := checkModelCapabilities(cmd.Context(), modelName) + supportsTools, err := checkModelCapabilities(cmd.Context(), modelName, runner) if err != nil { fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m could not check model capabilities: %v\n", err) supportsTools = false @@ -807,7 +809,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op think = &thinkValue // Check if model supports thinking if client, err := api.ClientFromEnvironment(); err == nil { - if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil { + if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName, Runner: runner}); err == nil { if !slices.Contains(resp.Capabilities, model.CapabilityThinking) { fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName) } @@ -822,7 +824,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op think = &api.ThinkValue{Value: false} // Check if model supports thinking if client, err := api.ClientFromEnvironment(); err == nil { - if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil { + if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName, Runner: runner}); err == nil { if !slices.Contains(resp.Capabilities, model.CapabilityThinking) { fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName) } @@ -884,6 +886,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op } req := &api.ShowRequest{ Name: modelName, + Runner: runner, Options: options, } resp, err := client.Show(cmd.Context(), req) @@ -981,7 +984,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op } // Check if model exists and get its info - info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName}) + info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName, Runner: runner}) if err != nil { p.StopAndClear() if strings.Contains(err.Error(), "not found") { @@ -996,8 +999,9 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op if info.RemoteHost == "" { // Preload the model by sending an empty generate request req := &api.GenerateRequest{ - Model: newModelName, - Think: think, + Model: newModelName, + Runner: runner, + Think: think, } err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error { return nil @@ -1059,6 +1063,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op verbose, _ := cmd.Flags().GetBool("verbose") opts := RunOptions{ Model: modelName, + Runner: runner, Messages: messages, WordWrap: wordWrap, Format: format, diff --git a/x/create/client/create.go b/x/create/client/create.go index c5962fdd7..85c6bf9d4 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -389,7 +389,7 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re manifestLayers = append(manifestLayers, modelfileLayers...) } - return manifest.WriteManifest(name, configLayer, manifestLayers) + return manifest.WriteManifestWithMetadata(name, configLayer, manifestLayers, manifest.RunnerMLX, manifest.FormatSafetensors) } }