modelfiles: fix /save command and add shortname for safetensors based models (#15413)

This change fixes two issues with Modelfiles:

  1. If a user uses `ollama show --modelfile` to show a safetensors based
     model, the Modelfile would leave the "FROM" field blank, which wouldn't
     allow a user to recreate the model. This change adds the model's current
     canonical short name to the FROM field.
  2. If a user uses the `/save` command in the CLI, any messages which were
     saved in a previous model wouldn't get saved (only the set of messages
     from the current session would be).
This commit is contained in:
Patrick Devine
2026-04-08 21:05:39 -07:00
committed by GitHub
parent 6b5db12aa2
commit eb97274e5c
5 changed files with 179 additions and 45 deletions

View File

@@ -695,7 +695,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
opts.ParentModel = info.Details.ParentModel
applyShowResponseToRunOptions(&opts, info)
// Check if this is an embedding model
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
@@ -1411,23 +1411,30 @@ func PullHandler(cmd *cobra.Command, args []string) error {
type generateContextKey string
type runOptions struct {
Model string
ParentModel string
Prompt string
Messages []api.Message
WordWrap bool
Format string
System string
Images []api.ImageData
Options map[string]any
MultiModal bool
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
ShowConnect bool
Model string
ParentModel string
LoadedMessages []api.Message
Prompt string
Messages []api.Message
WordWrap bool
Format string
System string
Images []api.ImageData
Options map[string]any
MultiModal bool
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
ShowConnect bool
}
func (r runOptions) Copy() runOptions {
var loadedMessages []api.Message
if r.LoadedMessages != nil {
loadedMessages = make([]api.Message, len(r.LoadedMessages))
copy(loadedMessages, r.LoadedMessages)
}
var messages []api.Message
if r.Messages != nil {
messages = make([]api.Message, len(r.Messages))
@@ -1455,23 +1462,29 @@ func (r runOptions) Copy() runOptions {
}
return runOptions{
Model: r.Model,
ParentModel: r.ParentModel,
Prompt: r.Prompt,
Messages: messages,
WordWrap: r.WordWrap,
Format: r.Format,
System: r.System,
Images: images,
Options: opts,
MultiModal: r.MultiModal,
KeepAlive: r.KeepAlive,
Think: think,
HideThinking: r.HideThinking,
ShowConnect: r.ShowConnect,
Model: r.Model,
ParentModel: r.ParentModel,
LoadedMessages: loadedMessages,
Prompt: r.Prompt,
Messages: messages,
WordWrap: r.WordWrap,
Format: r.Format,
System: r.System,
Images: images,
Options: opts,
MultiModal: r.MultiModal,
KeepAlive: r.KeepAlive,
Think: think,
HideThinking: r.HideThinking,
ShowConnect: r.ShowConnect,
}
}
func applyShowResponseToRunOptions(opts *runOptions, info *api.ShowResponse) {
opts.ParentModel = info.Details.ParentModel
opts.LoadedMessages = slices.Clone(info.Messages)
}
type displayResponseState struct {
lineLength int
wordBuffer string

View File

@@ -1655,6 +1655,24 @@ func TestNewCreateRequest(t *testing.T) {
},
},
},
{
"loaded messages are preserved when saving",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
LoadedMessages: []api.Message{{Role: "assistant", Content: "loaded"}},
Messages: []api.Message{{Role: "user", Content: "new"}},
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
Messages: []api.Message{
{Role: "assistant", Content: "loaded"},
{Role: "user", Content: "new"},
},
},
},
}
for _, tt := range tests {
@@ -1667,15 +1685,43 @@ func TestNewCreateRequest(t *testing.T) {
}
}
func TestApplyShowResponseToRunOptions(t *testing.T) {
opts := runOptions{}
info := &api.ShowResponse{
Details: api.ModelDetails{
ParentModel: "parentmodel",
},
Messages: []api.Message{
{Role: "assistant", Content: "loaded"},
},
}
applyShowResponseToRunOptions(&opts, info)
if opts.ParentModel != "parentmodel" {
t.Fatalf("ParentModel = %q, want %q", opts.ParentModel, "parentmodel")
}
if !cmp.Equal(opts.LoadedMessages, info.Messages) {
t.Fatalf("LoadedMessages = %#v, want %#v", opts.LoadedMessages, info.Messages)
}
info.Messages[0].Content = "modified"
if opts.LoadedMessages[0].Content == "modified" {
t.Fatal("LoadedMessages should be copied independently from ShowResponse")
}
}
func TestRunOptions_Copy(t *testing.T) {
// Setup test data
originalKeepAlive := &api.Duration{Duration: 5 * time.Minute}
originalThink := &api.ThinkValue{Value: "test reasoning"}
original := runOptions{
Model: "test-model",
ParentModel: "parent-model",
Prompt: "test prompt",
Model: "test-model",
ParentModel: "parent-model",
LoadedMessages: []api.Message{{Role: "assistant", Content: "loaded hello"}},
Prompt: "test prompt",
Messages: []api.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi there"},
@@ -1715,6 +1761,7 @@ func TestRunOptions_Copy(t *testing.T) {
}{
{"Model", copied.Model, original.Model},
{"ParentModel", copied.ParentModel, original.ParentModel},
{"LoadedMessages", copied.LoadedMessages, original.LoadedMessages},
{"Prompt", copied.Prompt, original.Prompt},
{"WordWrap", copied.WordWrap, original.WordWrap},
{"Format", copied.Format, original.Format},
@@ -1819,13 +1866,18 @@ func TestRunOptions_Copy(t *testing.T) {
func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
// Test with empty slices and maps
original := runOptions{
Messages: []api.Message{},
Images: []api.ImageData{},
Options: map[string]any{},
LoadedMessages: []api.Message{},
Messages: []api.Message{},
Images: []api.ImageData{},
Options: map[string]any{},
}
copied := original.Copy()
if copied.LoadedMessages == nil {
t.Error("Empty LoadedMessages slice should remain empty, not nil")
}
if copied.Messages == nil {
t.Error("Empty Messages slice should remain empty, not nil")
}
@@ -1842,6 +1894,10 @@ func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
t.Error("Empty Messages slice should remain empty")
}
if len(copied.LoadedMessages) != 0 {
t.Error("Empty LoadedMessages slice should remain empty")
}
if len(copied.Images) != 0 {
t.Error("Empty Images slice should remain empty")
}
@@ -1987,16 +2043,20 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"}
original := runOptions{
Model: "original-model",
Messages: []api.Message{{Role: "user", Content: "original"}},
Options: map[string]any{"key": "value"},
Think: originalThink,
Model: "original-model",
LoadedMessages: []api.Message{{Role: "assistant", Content: "loaded"}},
Messages: []api.Message{{Role: "user", Content: "original"}},
Options: map[string]any{"key": "value"},
Think: originalThink,
}
copied := original.Copy()
// Modify original
original.Model = "modified-model"
if len(original.LoadedMessages) > 0 {
original.LoadedMessages[0].Content = "modified loaded"
}
if len(original.Messages) > 0 {
original.Messages[0].Content = "modified"
}
@@ -2010,6 +2070,10 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
t.Error("Copy Model should not be affected by original modification")
}
if len(copied.LoadedMessages) > 0 && copied.LoadedMessages[0].Content == "modified loaded" {
t.Error("Copy LoadedMessages should not be affected by original modification")
}
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
t.Error("Copy Messages should not be affected by original modification")
}

View File

@@ -214,10 +214,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
origOpts := opts.Copy()
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
opts.Model = args[1]
opts.Messages = []api.Message{}
opts.LoadedMessages = nil
fmt.Printf("Loading model '%s'\n", opts.Model)
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model})
if err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
@@ -226,6 +233,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
return err
}
applyShowResponseToRunOptions(&opts, info)
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkExplicitlySet)
if err != nil {
return err
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
@@ -561,8 +573,10 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
req.Parameters = opts.Options
}
if len(opts.Messages) > 0 {
req.Messages = opts.Messages
messages := slices.Clone(opts.LoadedMessages)
messages = append(messages, opts.Messages...)
if len(messages) > 0 {
req.Messages = messages
}
return req

View File

@@ -1308,9 +1308,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
var sb strings.Builder
fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"")
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
fmt.Fprint(&sb, m.String())
modelfile := m.String()
if m.IsMLX() {
fmt.Fprintf(&sb, "FROM %s\n", m.ShortName)
if _, rest, ok := strings.Cut(modelfile, "\n"); ok {
fmt.Fprint(&sb, rest)
}
} else {
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
fmt.Fprint(&sb, modelfile)
}
resp.Modelfile = sb.String()
// skip loading tensor information if this is a remote model

View File

@@ -580,6 +580,41 @@ func TestGetModelInfo_SafetensorsUsesStoredFileType(t *testing.T) {
}
}
func TestGetModelInfo_SafetensorsModelfileUsesShortName(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
cfgData, err := json.Marshal(model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"completion"},
})
if err != nil {
t.Fatalf("failed to marshal config: %v", err)
}
configLayer, err := manifest.NewLayer(bytes.NewReader(cfgData), "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatalf("failed to create config layer: %v", err)
}
name := model.ParseName("show-safetensors")
if err := manifest.WriteManifest(name, configLayer, nil); err != nil {
t.Fatalf("failed to write manifest: %v", err)
}
resp, err := GetModelInfo(api.ShowRequest{Model: name.String()})
if err != nil {
t.Fatalf("GetModelInfo() error = %v", err)
}
if !strings.Contains(resp.Modelfile, "FROM show-safetensors:latest\n") {
t.Fatalf("Modelfile = %q, want FROM show-safetensors:latest", resp.Modelfile)
}
if strings.Contains(resp.Modelfile, "# To build a new Modelfile based on this, replace FROM with:") {
t.Fatalf("Modelfile should not include replacement hint: %q", resp.Modelfile)
}
}
func casingShuffle(s string) string {
rr := []rune(s)
for i := range rr {