Compare commits

...

21 Commits

Author SHA1 Message Date
Michael Yang
dc474f9b83 handle intermediate blobs 2024-05-02 17:05:49 -07:00
Michael Yang
41ae232e10 split model layer into metadata and data layers 2024-05-02 17:05:49 -07:00
Michael Yang
122b35c784 s/DisplayLongest/String/ 2024-05-02 17:05:26 -07:00
Michael Yang
3244a25c79 only quantize language models 2024-05-02 17:05:26 -07:00
Michael Yang
b535afe35c no iterator 2024-05-02 17:05:26 -07:00
Michael Yang
fd071eab8b rebase 2024-05-02 17:05:26 -07:00
Michael Yang
da0bb5d772 comments 2024-05-02 17:05:26 -07:00
Michael Yang
1909e624ce update tests 2024-05-02 17:05:26 -07:00
Michael Yang
1d8c850f38 quantize any fp16/fp32 model
- FROM /path/to/{safetensors,pytorch}
- FROM /path/to/fp{16,32}.bin
- FROM model:fp{16,32}
2024-05-02 17:05:26 -07:00
Michael Yang
e9ae607ece Merge pull request #3892 from ollama/mxyng/parser
refactor modelfile parser
2024-05-02 17:04:47 -07:00
Michael Yang
93707fa3f2 Merge pull request #4108 from ollama/mxyng/lf
fix line ending
2024-05-02 14:55:15 -07:00
Michael Yang
94c369095f fix line ending
replace CRLF with LF
2024-05-02 14:53:13 -07:00
Michael Yang
5ea844964e cmd: import regexp 2024-05-01 09:53:45 -07:00
Michael Yang
bd8eed57fc fix parser name 2024-05-01 09:52:54 -07:00
Michael Yang
9cf0f2e973 use parser.Format instead of templating modelfile 2024-05-01 09:52:54 -07:00
Michael Yang
176ad3aa6e parser: add commands format 2024-05-01 09:52:54 -07:00
Michael Yang
4d08363580 comments 2024-05-01 09:52:54 -07:00
Michael Yang
8907bf51d2 fix multiline 2024-05-01 09:52:54 -07:00
Michael Yang
abe614c705 tests 2024-05-01 09:52:54 -07:00
Michael Yang
238715037d linting 2024-05-01 09:52:54 -07:00
Michael Yang
c0a00f68ae refactor modelfile parser 2024-05-01 09:52:54 -07:00
20 changed files with 1649 additions and 873 deletions

View File

@@ -57,12 +57,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
modelfile, err := os.ReadFile(filename) modelfile, err := os.Open(filename)
if err != nil { if err != nil {
return err return err
} }
defer modelfile.Close()
commands, err := parser.Parse(bytes.NewReader(modelfile)) commands, err := parser.Parse(modelfile)
if err != nil { if err != nil {
return err return err
} }
@@ -76,10 +77,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner := progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
for _, c := range commands { for i := range commands {
switch c.Name { switch commands[i].Name {
case "model", "adapter": case "model", "adapter":
path := c.Args path := commands[i].Args
if path == "~" { if path == "~" {
path = home path = home
} else if strings.HasPrefix(path, "~/") { } else if strings.HasPrefix(path, "~/") {
@@ -91,7 +92,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
fi, err := os.Stat(path) fi, err := os.Stat(path)
if errors.Is(err, os.ErrNotExist) && c.Name == "model" { if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" {
continue continue
} else if err != nil { } else if err != nil {
return err return err
@@ -114,13 +115,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
name := c.Name commands[i].Args = "@"+digest
if c.Name == "model" {
name = "from"
}
re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
} }
} }
@@ -150,7 +145,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
quantization, _ := cmd.Flags().GetString("quantization") quantization, _ := cmd.Flags().GetString("quantization")
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization} request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization}
if err := client.Create(cmd.Context(), &request, fn); err != nil { if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err return err
} }

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
@@ -47,7 +48,7 @@ type ByteOrder interface {
type ModelArch interface { type ModelArch interface {
GetTensors() error GetTensors() error
LoadVocab() error LoadVocab() error
WriteGGUF() (string, error) WriteGGUF(io.WriteSeeker) error
} }
type ModelFormat interface { type ModelFormat interface {

View File

@@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error {
return nil return nil
} }
func (m *GemmaModel) WriteGGUF() (string, error) { func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{ kv := llm.KV{
"general.architecture": "gemma", "general.architecture": "gemma",
"general.name": m.Name, "general.name": m.Name,
@@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) {
"tokenizer.ggml.add_eos_token": false, "tokenizer.ggml.add_eos_token": false,
} }
f, err := os.CreateTemp("", "ollama-gguf") return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
return f.Name(), nil
} }

View File

@@ -132,7 +132,7 @@ func (m *LlamaModel) LoadVocab() error {
return nil return nil
} }
func (m *LlamaModel) WriteGGUF() (string, error) { func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{ kv := llm.KV{
"general.architecture": "llama", "general.architecture": "llama",
"general.name": m.Name, "general.name": m.Name,
@@ -161,16 +161,9 @@ func (m *LlamaModel) WriteGGUF() (string, error) {
f, err := os.CreateTemp("", "ollama-gguf") f, err := os.CreateTemp("", "ollama-gguf")
if err != nil { if err != nil {
return "", err return err
} }
defer f.Close() defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder) return llm.NewGGUFV3(m.Params.ByteOrder).Encode(f, kv, m.Tensors)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
return f.Name(), nil
} }

View File

@@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
return nil return nil
} }
func (m *MistralModel) WriteGGUF() (string, error) { func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{ kv := llm.KV{
"general.architecture": "llama", "general.architecture": "llama",
"general.name": m.Name, "general.name": m.Name,
@@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
"tokenizer.ggml.unknown_token_id": uint32(0), "tokenizer.ggml.unknown_token_id": uint32(0),
} }
f, err := os.CreateTemp("", "ollama-gguf") return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
return f.Name(), nil
} }

View File

@@ -1,7 +1,7 @@
package convert package convert
import ( import (
"os" "io"
"regexp" "regexp"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
@@ -47,7 +47,7 @@ func (m *MixtralModel) LoadVocab() error {
return nil return nil
} }
func (m *MixtralModel) WriteGGUF() (string, error) { func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{ kv := llm.KV{
"general.architecture": "llama", "general.architecture": "llama",
"general.name": m.Name, "general.name": m.Name,
@@ -81,16 +81,5 @@ func (m *MixtralModel) WriteGGUF() (string, error) {
"tokenizer.ggml.add_eos_token": false, "tokenizer.ggml.add_eos_token": false,
} }
f, err := os.CreateTemp("", "ollama-gguf") return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
return f.Name(), nil
} }

View File

@@ -107,7 +107,7 @@ func startServer(ctx context.Context, ollamaHost string) error {
if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost { if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
slog.Info("setting env", "OLLAMA_HOST", ollamaHost) slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
os.Setenv("OLLAMA_HOST", ollamaHost) t.Setenv("OLLAMA_HOST", ollamaHost)
} }
slog.Info("starting server", "url", ollamaHost) slog.Info("starting server", "url", ollamaHost)

140
llm/filetype.go Normal file
View File

@@ -0,0 +1,140 @@
package llm
import "fmt"
type fileType uint32
const (
fileTypeF32 fileType = iota
fileTypeF16
fileTypeQ4_0
fileTypeQ4_1
fileTypeQ4_1_F16
fileTypeQ4_2 // unused
fileTypeQ4_3 // unused
fileTypeQ8_0
fileTypeQ5_0
fileTypeQ5_1
fileTypeQ2_K
fileTypeQ3_K_S
fileTypeQ3_K_M
fileTypeQ3_K_L
fileTypeQ4_K_S
fileTypeQ4_K_M
fileTypeQ5_K_S
fileTypeQ5_K_M
fileTypeQ6_K
fileTypeIQ2_XXS
fileTypeIQ2_XS
fileTypeQ2_K_S
fileTypeQ3_K_XS
fileTypeIQ3_XXS
fileTypeUnknown
)
func ParseFileType(s string) (fileType, error) {
switch s {
case "F32":
return fileTypeF32, nil
case "F16":
return fileTypeF16, nil
case "Q4_0":
return fileTypeQ4_0, nil
case "Q4_1":
return fileTypeQ4_1, nil
case "Q4_1_F16":
return fileTypeQ4_1_F16, nil
case "Q8_0":
return fileTypeQ8_0, nil
case "Q5_0":
return fileTypeQ5_0, nil
case "Q5_1":
return fileTypeQ5_1, nil
case "Q2_K":
return fileTypeQ2_K, nil
case "Q3_K_S":
return fileTypeQ3_K_S, nil
case "Q3_K_M":
return fileTypeQ3_K_M, nil
case "Q3_K_L":
return fileTypeQ3_K_L, nil
case "Q4_K_S":
return fileTypeQ4_K_S, nil
case "Q4_K_M":
return fileTypeQ4_K_M, nil
case "Q5_K_S":
return fileTypeQ5_K_S, nil
case "Q5_K_M":
return fileTypeQ5_K_M, nil
case "Q6_K":
return fileTypeQ6_K, nil
case "IQ2_XXS":
return fileTypeIQ2_XXS, nil
case "IQ2_XS":
return fileTypeIQ2_XS, nil
case "Q2_K_S":
return fileTypeQ2_K_S, nil
case "Q3_K_XS":
return fileTypeQ3_K_XS, nil
case "IQ3_XXS":
return fileTypeIQ3_XXS, nil
default:
return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
}
}
func (t fileType) String() string {
switch t {
case fileTypeF32:
return "F32"
case fileTypeF16:
return "F16"
case fileTypeQ4_0:
return "Q4_0"
case fileTypeQ4_1:
return "Q4_1"
case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0"
case fileTypeQ5_0:
return "Q5_0"
case fileTypeQ5_1:
return "Q5_1"
case fileTypeQ2_K:
return "Q2_K"
case fileTypeQ3_K_S:
return "Q3_K_S"
case fileTypeQ3_K_M:
return "Q3_K_M"
case fileTypeQ3_K_L:
return "Q3_K_L"
case fileTypeQ4_K_S:
return "Q4_K_S"
case fileTypeQ4_K_M:
return "Q4_K_M"
case fileTypeQ5_K_S:
return "Q5_K_S"
case fileTypeQ5_K_M:
return "Q5_K_M"
case fileTypeQ6_K:
return "Q6_K"
case fileTypeIQ2_XXS:
return "IQ2_XXS"
case fileTypeIQ2_XS:
return "IQ2_XS"
case fileTypeQ2_K_S:
return "Q2_K_S"
case fileTypeQ3_K_XS:
return "Q3_K_XS"
case fileTypeIQ3_XXS:
return "IQ3_XXS"
default:
return "unknown"
}
}
func (t fileType) Value() uint32 {
return uint32(t)
}

View File

@@ -33,6 +33,7 @@ func (c *containerGGLA) Decode(rs io.ReadSeeker) (model, error) {
type ggla struct { type ggla struct {
*containerGGLA *containerGGLA
offset int64
kv KV kv KV
tensors []*Tensor tensors []*Tensor
@@ -53,6 +54,10 @@ func (llm *ggla) Tensors() Tensors {
return llm.tensors return llm.tensors
} }
func (llm *ggla) Offset() int64 {
return llm.offset
}
func (llm *ggla) decode(rs io.ReadSeeker) error { func (llm *ggla) decode(rs io.ReadSeeker) error {
var r uint32 var r uint32
if err := binary.Read(rs, binary.LittleEndian, &r); err != nil { if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
@@ -66,6 +71,13 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
} }
llm.kv["alpha"] = alpha llm.kv["alpha"] = alpha
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
llm.offset = offset
for { for {
var dims uint32 var dims uint32
if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil { if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil {

View File

@@ -13,85 +13,10 @@ type GGML struct {
model model
} }
const (
fileTypeF32 uint32 = iota
fileTypeF16
fileTypeQ4_0
fileTypeQ4_1
fileTypeQ4_1_F16
fileTypeQ8_0 uint32 = iota + 2
fileTypeQ5_0
fileTypeQ5_1
fileTypeQ2_K
fileTypeQ3_K_S
fileTypeQ3_K_M
fileTypeQ3_K_L
fileTypeQ4_K_S
fileTypeQ4_K_M
fileTypeQ5_K_S
fileTypeQ5_K_M
fileTypeQ6_K
fileTypeIQ2_XXS
fileTypeIQ2_XS
fileTypeQ2_K_S
fileTypeQ3_K_XS
fileTypeIQ3_XXS
)
func fileType(fileType uint32) string {
switch fileType {
case fileTypeF32:
return "F32"
case fileTypeF16:
return "F16"
case fileTypeQ4_0:
return "Q4_0"
case fileTypeQ4_1:
return "Q4_1"
case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0"
case fileTypeQ5_0:
return "Q5_0"
case fileTypeQ5_1:
return "Q5_1"
case fileTypeQ2_K:
return "Q2_K"
case fileTypeQ3_K_S:
return "Q3_K_S"
case fileTypeQ3_K_M:
return "Q3_K_M"
case fileTypeQ3_K_L:
return "Q3_K_L"
case fileTypeQ4_K_S:
return "Q4_K_S"
case fileTypeQ4_K_M:
return "Q4_K_M"
case fileTypeQ5_K_S:
return "Q5_K_S"
case fileTypeQ5_K_M:
return "Q5_K_M"
case fileTypeQ6_K:
return "Q6_K"
case fileTypeIQ2_XXS:
return "IQ2_XXS"
case fileTypeIQ2_XS:
return "IQ2_XS"
case fileTypeQ2_K_S:
return "Q2_K_S"
case fileTypeQ3_K_XS:
return "Q3_K_XS"
case fileTypeIQ3_XXS:
return "IQ3_XXS"
default:
return "unknown"
}
}
type model interface { type model interface {
KV() KV KV() KV
Tensors() Tensors Tensors() Tensors
Offset() int64
} }
type KV map[string]any type KV map[string]any
@@ -123,7 +48,7 @@ func (kv KV) ParameterCount() uint64 {
func (kv KV) FileType() string { func (kv KV) FileType() string {
if u64 := kv.u64("general.file_type"); u64 > 0 { if u64 := kv.u64("general.file_type"); u64 > 0 {
return fileType(uint32(u64)) return fileType(uint32(u64)).String()
} }
return "unknown" return "unknown"
@@ -286,6 +211,23 @@ const (
var ErrUnsupportedFormat = errors.New("unsupported model format") var ErrUnsupportedFormat = errors.New("unsupported model format")
func DetectGGMLType(b []byte) string {
switch binary.LittleEndian.Uint32(b[:4]) {
case FILE_MAGIC_GGML:
return "ggml"
case FILE_MAGIC_GGMF:
return "ggmf"
case FILE_MAGIC_GGJT:
return "ggjt"
case FILE_MAGIC_GGLA:
return "ggla"
case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
return "gguf"
default:
return ""
}
}
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) { func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
var magic uint32 var magic uint32
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {

View File

@@ -55,7 +55,7 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
model := newGGUF(c) model := newGGUF(c)
slog.Debug(fmt.Sprintf("model = %#v", model)) slog.Debug(fmt.Sprintf("model = %#v", model))
if err := model.Decode(rs); err != nil { if err := model.decode(rs); err != nil {
return nil, err return nil, err
} }
@@ -90,6 +90,7 @@ const (
type gguf struct { type gguf struct {
*containerGGUF *containerGGUF
offset int64
kv KV kv KV
tensors []*Tensor tensors []*Tensor
@@ -116,6 +117,10 @@ func (llm *gguf) Tensors() Tensors {
return llm.tensors return llm.tensors
} }
func (llm *gguf) Offset() int64 {
return llm.offset
}
func (llm *gguf) numTensor() uint64 { func (llm *gguf) numTensor() uint64 {
switch llm.Version { switch llm.Version {
case 1: case 1:
@@ -138,7 +143,7 @@ func (llm *gguf) numKV() uint64 {
} }
} }
func (llm *gguf) Decode(rs io.ReadSeeker) error { func (llm *gguf) decode(rs io.ReadSeeker) error {
// decode key-values // decode key-values
for i := 0; uint64(i) < llm.numKV(); i++ { for i := 0; uint64(i) < llm.numKV(); i++ {
k, err := readGGUFString(llm, rs) k, err := readGGUFString(llm, rs)
@@ -250,6 +255,8 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
return err return err
} }
llm.offset = offset + padding
for _, tensor := range llm.tensors { for _, tensor := range llm.tensors {
if _, err := rs.Seek(int64(tensor.size()), io.SeekCurrent); err != nil { if _, err := rs.Seek(int64(tensor.size()), io.SeekCurrent); err != nil {
return err return err

View File

@@ -20,7 +20,7 @@ func SystemInfo() string {
return C.GoString(C.llama_print_system_info()) return C.GoString(C.llama_print_system_info())
} }
func Quantize(infile, outfile, filetype string) error { func Quantize(infile, outfile string, ftype fileType) error {
cinfile := C.CString(infile) cinfile := C.CString(infile)
defer C.free(unsafe.Pointer(cinfile)) defer C.free(unsafe.Pointer(cinfile))
@@ -29,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error {
params := C.llama_model_quantize_default_params() params := C.llama_model_quantize_default_params()
params.nthread = -1 params.nthread = -1
params.ftype = ftype.Value()
switch filetype { if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
case "F32": return fmt.Errorf("llama_model_quantize: %d", rc)
params.ftype = fileTypeF32
case "F16":
params.ftype = fileTypeF16
case "Q4_0":
params.ftype = fileTypeQ4_0
case "Q4_1":
params.ftype = fileTypeQ4_1
case "Q4_1_F16":
params.ftype = fileTypeQ4_1_F16
case "Q8_0":
params.ftype = fileTypeQ8_0
case "Q5_0":
params.ftype = fileTypeQ5_0
case "Q5_1":
params.ftype = fileTypeQ5_1
case "Q2_K":
params.ftype = fileTypeQ2_K
case "Q3_K_S":
params.ftype = fileTypeQ3_K_S
case "Q3_K_M":
params.ftype = fileTypeQ3_K_M
case "Q3_K_L":
params.ftype = fileTypeQ3_K_L
case "Q4_K_S":
params.ftype = fileTypeQ4_K_S
case "Q4_K_M":
params.ftype = fileTypeQ4_K_M
case "Q5_K_S":
params.ftype = fileTypeQ5_K_S
case "Q5_K_M":
params.ftype = fileTypeQ5_K_M
case "Q6_K":
params.ftype = fileTypeQ6_K
case "IQ2_XXS":
params.ftype = fileTypeIQ2_XXS
case "IQ2_XS":
params.ftype = fileTypeIQ2_XS
case "Q2_K_S":
params.ftype = fileTypeQ2_K_S
case "Q3_K_XS":
params.ftype = fileTypeQ3_K_XS
case "IQ3_XXS":
params.ftype = fileTypeIQ3_XXS
default:
return fmt.Errorf("unknown filetype: %s", filetype)
}
if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
return fmt.Errorf("llama_model_quantize: %d", retval)
} }
return nil return nil

View File

@@ -6,8 +6,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog" "strconv"
"slices" "strings"
) )
type Command struct { type Command struct {
@@ -15,118 +15,283 @@ type Command struct {
Args string Args string
} }
func (c *Command) Reset() { type state int
c.Name = ""
c.Args = ""
}
func Parse(reader io.Reader) ([]Command, error) { const (
var commands []Command stateNil state = iota
var command, modelCommand Command stateName
stateValue
stateParameter
stateMessage
stateComment
)
scanner := bufio.NewScanner(reader) var (
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize) errMissingFrom = errors.New("no FROM line")
scanner.Split(scanModelfile) errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
for scanner.Scan() { errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
line := scanner.Bytes() )
fields := bytes.SplitN(line, []byte(" "), 2) func Format(cmds []Command) string {
if len(fields) == 0 || len(fields[0]) == 0 { var sb strings.Builder
continue for _, cmd := range cmds {
} name := cmd.Name
args := cmd.Args
switch string(bytes.ToUpper(fields[0])) { switch cmd.Name {
case "FROM": case "model":
command.Name = "model" name = "from"
command.Args = string(bytes.TrimSpace(fields[1])) args = cmd.Args
// copy command for validation case "license", "template", "system", "adapter":
modelCommand = command args = quote(args)
case "ADAPTER": case "message":
command.Name = string(bytes.ToLower(fields[0])) role, message, _ := strings.Cut(cmd.Args, ": ")
command.Args = string(bytes.TrimSpace(fields[1])) args = role + " " + quote(message)
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1])
case "PARAMETER":
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("missing value for %s", fields)
}
command.Name = string(fields[0])
command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
case "MESSAGE":
command.Name = string(bytes.ToLower(fields[0]))
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("should be in the format <role> <message>")
}
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
}
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
default: default:
if !bytes.HasPrefix(fields[0], []byte("#")) { name = "parameter"
// log a warning for unknown commands args = cmd.Name + " " + quote(cmd.Args)
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
} }
fmt.Fprintln(&sb, strings.ToUpper(name), args)
}
return sb.String()
}
func Parse(r io.Reader) (cmds []Command, err error) {
var cmd Command
var curr state
var b bytes.Buffer
var role string
br := bufio.NewReader(r)
for {
r, _, err := br.ReadRune()
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return nil, err
}
next, r, err := parseRuneForState(r, curr)
if errors.Is(err, io.ErrUnexpectedEOF) {
return nil, fmt.Errorf("%w: %s", err, b.String())
} else if err != nil {
return nil, err
}
// process the state transition, some transitions need to be intercepted and redirected
if next != curr {
switch curr {
case stateName:
if !isValidCommand(b.String()) {
return nil, errInvalidCommand
}
// next state sometimes depends on the current buffer value
switch s := strings.ToLower(b.String()); s {
case "from":
cmd.Name = "model"
case "parameter":
// transition to stateParameter which sets command name
next = stateParameter
case "message":
// transition to stateMessage which validates the message role
next = stateMessage
fallthrough
default:
cmd.Name = s
}
case stateParameter:
cmd.Name = b.String()
case stateMessage:
if !isValidMessageRole(b.String()) {
return nil, errInvalidMessageRole
}
role = b.String()
case stateComment, stateNil:
// pass
case stateValue:
s, ok := unquote(b.String())
if !ok || isSpace(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
continue continue
} }
commands = append(commands, command) if role != "" {
command.Reset() s = role + ": " + s
role = ""
} }
if modelCommand.Args == "" { cmd.Args = s
return nil, errors.New("no FROM line for the model was specified") cmds = append(cmds, cmd)
} }
return commands, scanner.Err() b.Reset()
curr = next
}
if strconv.IsPrint(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
}
}
// flush the buffer
switch curr {
case stateComment, stateNil:
// pass; nothing to flush
case stateValue:
s, ok := unquote(b.String())
if !ok {
return nil, io.ErrUnexpectedEOF
}
if role != "" {
s = role + ": " + s
}
cmd.Args = s
cmds = append(cmds, cmd)
default:
return nil, io.ErrUnexpectedEOF
}
for _, cmd := range cmds {
if cmd.Name == "model" {
return cmds, nil
}
}
return nil, errMissingFrom
} }
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) { func parseRuneForState(r rune, cs state) (state, rune, error) {
advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF) switch cs {
if err != nil { case stateNil:
return 0, nil, err switch {
case r == '#':
return stateComment, 0, nil
case isSpace(r), isNewline(r):
return stateNil, 0, nil
default:
return stateName, r, nil
} }
case stateName:
if advance > 0 && token != nil { switch {
return advance, token, nil case isAlpha(r):
return stateName, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, errInvalidCommand
} }
case stateValue:
advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF) switch {
if err != nil { case isNewline(r):
return 0, nil, err return stateNil, r, nil
case isSpace(r):
return stateNil, r, nil
default:
return stateValue, r, nil
} }
case stateParameter:
if advance > 0 && token != nil { switch {
return advance, token, nil case isAlpha(r), isNumber(r), r == '_':
return stateParameter, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, io.ErrUnexpectedEOF
}
case stateMessage:
switch {
case isAlpha(r):
return stateMessage, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, io.ErrUnexpectedEOF
}
case stateComment:
switch {
case isNewline(r):
return stateNil, 0, nil
default:
return stateComment, 0, nil
}
default:
return stateNil, 0, errors.New("")
} }
return bufio.ScanLines(data, atEOF)
} }
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) { func quote(s string) string {
newline := bytes.IndexByte(data, '\n') if strings.Contains(s, "\n") || strings.HasPrefix(s, " ") || strings.HasSuffix(s, " ") {
if strings.Contains(s, "\"") {
if start := bytes.Index(data, openBytes); start >= 0 && start < newline { return `"""` + s + `"""`
end := bytes.Index(data[start+len(openBytes):], closeBytes)
if end < 0 {
if atEOF {
return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
} else {
return 0, nil, nil
}
} }
n := start + len(openBytes) + end + len(closeBytes) return `"` + s + `"`
newData := data[:start]
newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
return n, newData, nil
} }
return 0, nil, nil return s
}
func unquote(s string) (string, bool) {
if len(s) == 0 {
return "", false
}
// TODO: single quotes
if len(s) >= 3 && s[:3] == `"""` {
if len(s) >= 6 && s[len(s)-3:] == `"""` {
return s[3 : len(s)-3], true
}
return "", false
}
if len(s) >= 1 && s[0] == '"' {
if len(s) >= 2 && s[len(s)-1] == '"' {
return s[1 : len(s)-1], true
}
return "", false
}
return s, true
}
func isAlpha(r rune) bool {
return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
}
func isNumber(r rune) bool {
return r >= '0' && r <= '9'
}
func isSpace(r rune) bool {
return r == ' ' || r == '\t'
}
func isNewline(r rune) bool {
return r == '\r' || r == '\n'
}
func isValidMessageRole(role string) bool {
return role == "system" || role == "user" || role == "assistant"
}
func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) {
case "from", "license", "template", "system", "adapter", "parameter", "message":
return true
default:
return false
}
} }

View File

@@ -1,14 +1,16 @@
package parser package parser
import ( import (
"bytes"
"fmt"
"io"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Parser(t *testing.T) { func TestParser(t *testing.T) {
input := ` input := `
FROM model1 FROM model1
ADAPTER adapter1 ADAPTER adapter1
@@ -35,21 +37,62 @@ TEMPLATE template1
assert.Equal(t, expectedCommands, commands) assert.Equal(t, expectedCommands, commands)
} }
func Test_Parser_NoFromLine(t *testing.T) { func TestParserFrom(t *testing.T) {
var cases = []struct {
input string
expected []Command
err error
}{
{
"FROM foo",
[]Command{{Name: "model", Args: "foo"}},
nil,
},
{
"FROM /path/to/model",
[]Command{{Name: "model", Args: "/path/to/model"}},
nil,
},
{
"FROM /path/to/model/fp16.bin",
[]Command{{Name: "model", Args: "/path/to/model/fp16.bin"}},
nil,
},
{
"FROM llama3:latest",
[]Command{{Name: "model", Args: "llama3:latest"}},
nil,
},
{
"FROM llama3:7b-instruct-q4_K_M",
[]Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}},
nil,
},
{
"", nil, errMissingFrom,
},
{
"PARAMETER param1 value1",
nil,
errMissingFrom,
},
{
"PARAMETER param1 value1\nFROM foo",
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
nil,
},
}
input := ` for _, c := range cases {
PARAMETER param1 value1 t.Run("", func(t *testing.T) {
PARAMETER param2 value2 commands, err := Parse(strings.NewReader(c.input))
` assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
reader := strings.NewReader(input) })
}
_, err := Parse(reader)
assert.ErrorContains(t, err, "no FROM line")
} }
func Test_Parser_MissingValue(t *testing.T) { func TestParserParametersMissingValue(t *testing.T) {
input := ` input := `
FROM foo FROM foo
PARAMETER param1 PARAMETER param1
@@ -58,41 +101,401 @@ PARAMETER param1
reader := strings.NewReader(input) reader := strings.NewReader(input)
_, err := Parse(reader) _, err := Parse(reader)
assert.ErrorContains(t, err, "missing value for [param1]") assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
}
func TestParserBadCommand(t *testing.T) {
input := `
FROM foo
BADCOMMAND param1 value1
`
_, err := Parse(strings.NewReader(input))
assert.ErrorIs(t, err, errInvalidCommand)
} }
func Test_Parser_Messages(t *testing.T) { func TestParserMessages(t *testing.T) {
var cases = []struct {
input := ` input string
expected []Command
err error
}{
{
`
FROM foo
MESSAGE system You are a Parser. Always Parse things.
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
},
nil,
},
{
`
FROM foo
MESSAGE system You are a Parser. Always Parse things.`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
},
nil,
},
{
`
FROM foo FROM foo
MESSAGE system You are a Parser. Always Parse things. MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there! MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things! MESSAGE assistant Hello, I want to parse all the things!
` `,
[]Command{
reader := strings.NewReader(input)
commands, err := Parse(reader)
assert.Nil(t, err)
expectedCommands := []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."}, {Name: "message", Args: "system: You are a Parser. Always Parse things."},
{Name: "message", Args: "user: Hey there!"}, {Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, {Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
} },
nil,
assert.Equal(t, expectedCommands, commands) },
} {
`
func Test_Parser_Messages_BadRole(t *testing.T) { FROM foo
MESSAGE system """
input := ` You are a multiline Parser. Always Parse things.
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"},
},
nil,
},
{
`
FROM foo FROM foo
MESSAGE badguy I'm a bad guy! MESSAGE badguy I'm a bad guy!
` `,
nil,
errInvalidMessageRole,
},
{
`
FROM foo
MESSAGE system
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
MESSAGE system`,
nil,
io.ErrUnexpectedEOF,
},
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.input))
assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
})
}
}
func TestParserQuoted(t *testing.T) {
var cases = []struct {
multiline string
expected []Command
err error
}{
{
`
FROM foo
SYSTEM """
This is a
multiline system.
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "\nThis is a\nmultiline system.\n"},
},
nil,
},
{
`
FROM foo
SYSTEM """
This is a
multiline system."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "\nThis is a\nmultiline system."},
},
nil,
},
{
`
FROM foo
SYSTEM """This is a
multiline system."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "This is a\nmultiline system."},
},
nil,
},
{
`
FROM foo
SYSTEM """This is a multiline system."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "This is a multiline system."},
},
nil,
},
{
`
FROM foo
SYSTEM """This is a multiline system.""
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
SYSTEM "
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
SYSTEM """
This is a multiline system with "quotes".
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"},
},
nil,
},
{
`
FROM foo
SYSTEM """"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: ""},
},
nil,
},
{
`
FROM foo
SYSTEM ""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: ""},
},
nil,
},
{
`
FROM foo
SYSTEM "'"
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "'"},
},
nil,
},
{
`
FROM foo
SYSTEM """''"'""'""'"'''''""'""'"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: `''"'""'""'"'''''""'""'`},
},
nil,
},
{
`
FROM foo
TEMPLATE """
{{ .Prompt }}
"""`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "\n{{ .Prompt }}\n"},
},
nil,
},
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.multiline))
assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
})
}
}
func TestParserParameters(t *testing.T) {
var cases = map[string]struct {
name, value string
}{
"numa true": {"numa", "true"},
"num_ctx 1": {"num_ctx", "1"},
"num_batch 1": {"num_batch", "1"},
"num_gqa 1": {"num_gqa", "1"},
"num_gpu 1": {"num_gpu", "1"},
"main_gpu 1": {"main_gpu", "1"},
"low_vram true": {"low_vram", "true"},
"f16_kv true": {"f16_kv", "true"},
"logits_all true": {"logits_all", "true"},
"vocab_only true": {"vocab_only", "true"},
"use_mmap true": {"use_mmap", "true"},
"use_mlock true": {"use_mlock", "true"},
"num_thread 1": {"num_thread", "1"},
"num_keep 1": {"num_keep", "1"},
"seed 1": {"seed", "1"},
"num_predict 1": {"num_predict", "1"},
"top_k 1": {"top_k", "1"},
"top_p 1.0": {"top_p", "1.0"},
"tfs_z 1.0": {"tfs_z", "1.0"},
"typical_p 1.0": {"typical_p", "1.0"},
"repeat_last_n 1": {"repeat_last_n", "1"},
"temperature 1.0": {"temperature", "1.0"},
"repeat_penalty 1.0": {"repeat_penalty", "1.0"},
"presence_penalty 1.0": {"presence_penalty", "1.0"},
"frequency_penalty 1.0": {"frequency_penalty", "1.0"},
"mirostat 1": {"mirostat", "1"},
"mirostat_tau 1.0": {"mirostat_tau", "1.0"},
"mirostat_eta 1.0": {"mirostat_eta", "1.0"},
"penalize_newline true": {"penalize_newline", "true"},
"stop ### User:": {"stop", "### User:"},
"stop ### User: ": {"stop", "### User: "},
"stop \"### User:\"": {"stop", "### User:"},
"stop \"### User: \"": {"stop", "### User: "},
"stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
"stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"},
"stop <|endoftext|>": {"stop", "<|endoftext|>"},
"stop <|eot_id|>": {"stop", "<|eot_id|>"},
"stop </s>": {"stop", "</s>"},
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
var b bytes.Buffer
fmt.Fprintln(&b, "FROM foo")
fmt.Fprintln(&b, "PARAMETER", k)
commands, err := Parse(&b)
assert.Nil(t, err)
assert.Equal(t, []Command{
{Name: "model", Args: "foo"},
{Name: v.name, Args: v.value},
}, commands)
})
}
}
func TestParserComments(t *testing.T) {
var cases = []struct {
input string
expected []Command
}{
{
`
# comment
FROM foo
`,
[]Command{
{Name: "model", Args: "foo"},
},
},
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.input))
assert.Nil(t, err)
assert.Equal(t, c.expected, commands)
})
}
}
func TestParseFormatParse(t *testing.T) {
var cases = []string{
`
FROM foo
ADAPTER adapter1
LICENSE MIT
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`,
`
FROM foo
ADAPTER adapter1
LICENSE MIT
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
MESSAGE system """
You are a store greeter. Always responsed with "Hello!".
"""
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`,
`
FROM foo
ADAPTER adapter1
LICENSE """
Very long and boring legal text.
Blah blah blah.
"Oh look, a quote!"
"""
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
MESSAGE system """
You are a store greeter. Always responsed with "Hello!".
"""
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`,
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c))
assert.NoError(t, err)
commands2, err := Parse(strings.NewReader(Format(commands)))
assert.NoError(t, err)
assert.Equal(t, commands, commands2)
})
}
reader := strings.NewReader(input)
_, err := Parse(reader)
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
} }

View File

@@ -1,8 +1,8 @@
package server package server
import ( import (
"archive/zip"
"bytes" "bytes"
"cmp"
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
@@ -11,7 +11,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"log" "log"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -21,13 +20,11 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"text/template"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth" "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
@@ -64,6 +61,48 @@ func (m *Model) IsEmbedding() bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
} }
func (m *Model) Commands() (cmds []parser.Command) {
cmds = append(cmds, parser.Command{Name: "model", Args: m.ModelPath})
if m.Template != "" {
cmds = append(cmds, parser.Command{Name: "template", Args: m.Template})
}
if m.System != "" {
cmds = append(cmds, parser.Command{Name: "system", Args: m.System})
}
for _, adapter := range m.AdapterPaths {
cmds = append(cmds, parser.Command{Name: "adapter", Args: adapter})
}
for _, projector := range m.ProjectorPaths {
cmds = append(cmds, parser.Command{Name: "projector", Args: projector})
}
for k, v := range m.Options {
switch v := v.(type) {
case []any:
for _, s := range v {
cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", s)})
}
default:
cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", v)})
}
}
for _, license := range m.License {
cmds = append(cmds, parser.Command{Name: "license", Args: license})
}
for _, msg := range m.Messages {
cmds = append(cmds, parser.Command{Name: "message", Args: fmt.Sprintf("%s %s", msg.Role, msg.Content)})
}
return cmds
}
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
@@ -89,36 +128,6 @@ type ConfigV2 struct {
RootFS RootFS `json:"rootfs"` RootFS RootFS `json:"rootfs"`
} }
func (c *ConfigV2) SetModelFormat(format string) {
if c.ModelFormat == "" {
c.ModelFormat = format
}
}
func (c *ConfigV2) SetModelFamily(families ...string) {
for _, family := range families {
if c.ModelFamily == "" {
c.ModelFamily = family
}
if !slices.Contains(c.ModelFamilies, family) {
c.ModelFamilies = append(c.ModelFamilies, family)
}
}
}
func (c *ConfigV2) SetModelType(modelType string) {
if c.ModelType == "" {
c.ModelType = modelType
}
}
func (c *ConfigV2) SetFileType(fileType string) {
if c.FileType == "" {
c.FileType = fileType
}
}
type RootFS struct { type RootFS struct {
Type string `json:"type"` Type string `json:"type"`
DiffIDs []string `json:"diff_ids"` DiffIDs []string `json:"diff_ids"`
@@ -199,6 +208,14 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.model": case "application/vnd.ollama.image.model":
model.ModelPath = filename
model.ParentModel = layer.From
case "application/vnd.ollama.image.model+metadata", "application/vnd.ollama.image.model+data":
filename, err = GetBlobsPath(layer.MergeBase)
if err != nil {
return nil, err
}
model.ModelPath = filename model.ModelPath = filename
model.ParentModel = layer.From model.ParentModel = layer.From
case "application/vnd.ollama.image.embed": case "application/vnd.ollama.image.embed":
@@ -263,7 +280,7 @@ func GetModel(name string) (*Model, error) {
return model, nil return model, nil
} }
func realpath(mfDir, from string) string { func realpath(rel, from string) string {
abspath, err := filepath.Abs(from) abspath, err := filepath.Abs(from)
if err != nil { if err != nil {
return from return from
@@ -280,22 +297,15 @@ func realpath(mfDir, from string) string {
return filepath.Join(home, from[2:]) return filepath.Join(home, from[2:])
} }
if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil { if _, err := os.Stat(filepath.Join(rel, from)); err == nil {
// this is a file relative to the Modelfile // this is a file relative to the Modelfile
return filepath.Join(mfDir, from) return filepath.Join(rel, from)
} }
return abspath return abspath
} }
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error { func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) (err error) {
deleteMap := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range append(manifest.Layers, manifest.Config) {
deleteMap[layer.Digest] = struct{}{}
}
}
config := ConfigV2{ config := ConfigV2{
OS: "linux", OS: "linux",
Architecture: "amd64", Architecture: "amd64",
@@ -304,250 +314,222 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
}, },
} }
var layers Layers var messages []*api.Message
messages := []string{} parameters := make(map[string]any)
params := make(map[string][]string)
fromParams := make(map[string]any)
var layers []*Layer
for _, c := range commands { for _, c := range commands {
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name { switch c.Name {
case "model": case "model", "adapter":
if strings.HasPrefix(c.Args, "@") { var baseLayers []*layerWithGGML
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@")) if name := model.ParseName(c.Args); name.IsValid() {
baseLayers, err = parseFromModel(ctx, name, fn)
if err != nil {
return err
}
} else if strings.HasPrefix(c.Args, "@") {
blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil { if err != nil {
return err return err
} }
c.Args = blobPath blob, err := os.Open(blobpath)
}
pathName := realpath(modelFileDir, c.Args)
ggufName, err := convertModel(name, pathName, fn)
if err != nil {
var pathErr *fs.PathError
switch {
case errors.Is(err, zip.ErrFormat):
// it's not a safetensor archive
case errors.As(err, &pathErr):
// it's not a file on disk, could be a model reference
default:
return err
}
}
if ggufName != "" {
pathName = ggufName
defer os.RemoveAll(ggufName)
if quantization != "" {
quantization = strings.ToUpper(quantization)
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
if err != nil { if err != nil {
return err return err
} }
defer os.RemoveAll(tempfile.Name()) defer blob.Close()
if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil { baseLayers, err = parseFromFile(ctx, blob, fn)
return err
}
if err := tempfile.Close(); err != nil {
return err
}
pathName = tempfile.Name()
}
}
bin, err := os.Open(pathName)
if err != nil {
// not a file on disk so must be a model reference
modelpath := ParseModelPath(c.Args)
manifest, _, err := GetManifest(modelpath)
switch {
case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &registryOptions{}, fn); err != nil {
return err
}
manifest, _, err = GetManifest(modelpath)
if err != nil { if err != nil {
return err return err
} }
case err != nil: } else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
defer file.Close()
baseLayers, err = parseFromFile(ctx, file, fn)
if err != nil {
return err return err
} }
} else {
return fmt.Errorf("invalid model reference: %s", c.Args)
}
fn(api.ProgressResponse{Status: "reading model metadata"}) for _, baseLayer := range baseLayers {
fromConfigPath, err := GetBlobsPath(manifest.Config.Digest) if quantization != "" && baseLayer.MediaType == "application/vnd.ollama.image.model" {
ftype, err := llm.ParseFileType(quantization)
if err != nil { if err != nil {
return err return err
} }
fromConfigFile, err := os.Open(fromConfigPath) filetype := baseLayer.GGML.KV().FileType()
if err != nil { if !slices.Contains([]string{"F16", "F32"}, filetype) {
return err return errors.New("quantization is only supported for F16 and F32 models")
}
defer fromConfigFile.Close()
var fromConfig ConfigV2
if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
return err
} }
// if the model is still not in gguf format, error out fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", filetype, quantization)})
if fromConfig.ModelFormat != "gguf" {
return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args)
}
config.SetModelFormat(fromConfig.ModelFormat) blob, err := GetBlobsPath(baseLayer.Digest)
config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
config.SetModelType(fromConfig.ModelType)
config.SetFileType(fromConfig.FileType)
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
if layer.MediaType == "application/vnd.ollama.image.params" {
fromParamsPath, err := GetBlobsPath(layer.Digest)
if err != nil { if err != nil {
return err return err
} }
fromParamsFile, err := os.Open(fromParamsPath) temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
if err != nil { if err != nil {
return err return err
} }
defer fromParamsFile.Close() defer temp.Close()
defer os.Remove(temp.Name())
if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil { if err := llm.Quantize(blob, temp.Name(), ftype); err != nil {
return err
}
baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
if err != nil {
return err return err
} }
} }
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) if baseLayer.GGML != nil {
config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType())
config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
f, err := baseLayer.Layer.Open()
if err != nil {
return err
}
defer f.Close()
metadata := io.NewSectionReader(f, 0, baseLayer.GGML.Offset())
metadataLayer, err := NewLayer(metadata, "application/vnd.ollama.image.model+metadata")
if err != nil {
return err
}
metadataLayer.Intermediate = true
metadataLayer.MergeBase = baseLayer.Digest
layers = append(layers, metadataLayer)
metadataPath, err := GetBlobsPath(metadataLayer.Digest)
if err != nil {
return err
}
defer os.Remove(metadataPath)
stat, err := f.Stat()
if err != nil { if err != nil {
return err return err
} }
layers.Add(layer) data := io.NewSectionReader(f, baseLayer.GGML.Offset(), stat.Size())
dataLayer, err := NewLayer(data, "application/vnd.ollama.image.model+data")
if err != nil {
return err
} }
dataLayer.Intermediate = true
dataLayer.MergeBase = baseLayer.Digest
deleteMap[manifest.Config.Digest] = struct{}{} layers = append(layers, dataLayer)
dataPath, err := GetBlobsPath(dataLayer.Digest)
if err != nil {
return err
}
defer os.Remove(dataPath)
continue continue
} }
defer bin.Close()
var offset int64 layers = append(layers, baseLayer.Layer)
for {
fn(api.ProgressResponse{Status: "creating model layer"})
if _, err := bin.Seek(offset, io.SeekStart); err != nil {
return err
} }
case "license", "template", "system":
ggml, size, err := llm.DecodeGGML(bin) blob := strings.NewReader(c.Args)
if errors.Is(err, io.EOF) { layer, err := NewLayer(blob, mediatype)
break
} else if errors.Is(err, llm.ErrUnsupportedFormat) {
return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
} else if err != nil {
return err
}
config.SetModelFormat(ggml.Name())
config.SetModelFamily(ggml.KV().Architecture())
config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount()))
config.SetFileType(ggml.KV().FileType())
mediatype := mediatype
if ggml.KV().Architecture() == "clip" {
mediatype = "application/vnd.ollama.image.projector"
}
sr := io.NewSectionReader(bin, offset, size)
layer, err := NewLayer(sr, mediatype)
if err != nil { if err != nil {
return err return err
} }
layers.Add(layer) if c.Name != "license" {
// replace
offset += size layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
} return layer.MediaType == mediatype
case "adapter": })
if strings.HasPrefix(c.Args, "@") {
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
} }
c.Args = blobPath layers = append(layers, layer)
}
fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(modelFileDir, c.Args))
if err != nil {
return err
}
defer bin.Close()
_, size, err := llm.DecodeGGML(bin)
if err != nil {
return err
}
sr := io.NewSectionReader(bin, 0, size)
layer, err := NewLayer(sr, mediatype)
if err != nil {
return err
}
layers.Add(layer)
case "license":
fn(api.ProgressResponse{Status: "creating license layer"})
bin := strings.NewReader(c.Args)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
layers.Add(layer)
case "template", "system":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
bin := strings.NewReader(c.Args)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
layers.Replace(layer)
case "message": case "message":
messages = append(messages, c.Args) role, content, ok := strings.Cut(c.Args, ": ")
default: if !ok {
params[c.Name] = append(params[c.Name], c.Args) return fmt.Errorf("invalid message: %s", c.Args)
} }
messages = append(messages, &api.Message{Role: role, Content: content})
default:
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
if err != nil {
return err
}
for k, v := range ps {
if ks, ok := parameters[k].([]string); ok {
parameters[k] = append(ks, v.([]string)...)
} else if vs, ok := v.([]string); ok {
parameters[k] = vs
} else {
parameters[k] = v
}
}
}
}
var err2 error
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
switch layer.MediaType {
case "application/vnd.ollama.image.message":
// if there are new messages, remove the inherited ones
if len(messages) > 0 {
return true
}
return false
case "application/vnd.ollama.image.params":
// merge inherited parameters with new ones
r, err := layer.Open()
if err != nil {
err2 = err
return false
}
defer r.Close()
var ps map[string]any
if err := json.NewDecoder(r).Decode(&ps); err != nil {
err2 = err
return false
}
for k, v := range ps {
if _, ok := parameters[k]; !ok {
parameters[k] = v
}
}
return true
default:
return false
}
})
if err2 != nil {
return err2
} }
if len(messages) > 0 { if len(messages) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
msgs := make([]api.Message, 0)
for _, m := range messages {
// todo: handle images
msg := strings.SplitN(m, ": ", 2)
msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
}
var b bytes.Buffer var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(msgs); err != nil { if err := json.NewEncoder(&b).Encode(messages); err != nil {
return err return err
} }
@@ -556,39 +538,25 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
return err return err
} }
layers.Replace(layer) layers = append(layers, layer)
}
if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
formattedParams, err := api.FormatParams(params)
if err != nil {
return err
}
for k, v := range fromParams {
if _, ok := formattedParams[k]; !ok {
formattedParams[k] = v
}
} }
if len(parameters) > 0 {
var b bytes.Buffer var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(formattedParams); err != nil { if err := json.NewEncoder(&b).Encode(parameters); err != nil {
return err return err
} }
fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := NewLayer(&b, "application/vnd.ollama.image.params") layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil { if err != nil {
return err return err
} }
layers.Replace(layer) layers = append(layers, layer)
} }
digests := make([]string, len(layers.items)) digests := make([]string, len(layers))
for i, layer := range layers.items { for i, layer := range layers {
digests[i] = layer.Digest digests[i] = layer.Digest
} }
@@ -599,36 +567,38 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
return err return err
} }
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json") layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil { if err != nil {
return err return err
} }
delete(deleteMap, configLayer.Digest) for _, layer := range append(layers, layer) {
if layer.message != "" {
for _, layer := range append(layers.items, configLayer) { fn(api.ProgressResponse{Status: layer.message})
committed, err := layer.Commit() }
if err != nil {
return err
} }
status := "writing layer" unref := make(map[string]struct{})
if !committed { if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
status = "using already created layer" for _, layer := range manifest.Layers {
if !slices.Contains(digests, layer.Digest) {
unref[layer.Digest] = struct{}{}
}
} }
fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)}) if manifest.Config.Digest != layer.Digest {
unref[manifest.Config.Digest] = struct{}{}
delete(deleteMap, layer.Digest) }
} }
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, configLayer, layers.items); err != nil { if err := WriteManifest(name, layer, layers); err != nil {
return err return err
} }
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { if os.Getenv("OLLAMA_NOPRUNE") == "" && len(unref) > 0 {
if err := deleteUnusedLayers(nil, deleteMap, false); err != nil { fn(api.ProgressResponse{Status: "removing unused layers"})
if err := deleteUnusedLayers(nil, unref, false); err != nil {
return err return err
} }
} }
@@ -637,74 +607,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
return nil return nil
} }
func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
r, err := zip.OpenReader(path)
if err != nil {
return "", err
}
defer r.Close()
tempDir, err := os.MkdirTemp("", "ollama-convert")
if err != nil {
return "", err
}
defer os.RemoveAll(tempDir)
fn(api.ProgressResponse{Status: "unpacking model metadata"})
for _, f := range r.File {
fpath := filepath.Join(tempDir, f.Name)
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
if err != nil {
return "", err
}
rc, err := f.Open()
if err != nil {
return "", err
}
_, err = io.Copy(outFile, rc)
if err != nil {
return "", err
}
outFile.Close()
rc.Close()
}
mf, err := convert.GetModelFormat(tempDir)
if err != nil {
return "", err
}
params, err := mf.GetParams(tempDir)
if err != nil {
return "", err
}
mArch, err := mf.GetModelArch(name, tempDir, params)
if err != nil {
return "", err
}
fn(api.ProgressResponse{Status: "processing tensors"})
if err := mArch.GetTensors(); err != nil {
return "", err
}
if err := mArch.LoadVocab(); err != nil {
return "", err
}
fn(api.ProgressResponse{Status: "converting model"})
path, err = mArch.WriteGGUF()
if err != nil {
return "", err
}
return path, nil
}
func CopyModel(src, dst model.Name) error { func CopyModel(src, dst model.Name) error {
if !dst.IsFullyQualified() { if !dst.IsFullyQualified() {
return model.Unqualified(dst) return model.Unqualified(dst)
@@ -774,6 +676,9 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{},
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest) delete(deleteMap, layer.Digest)
if layer.MergeBase != "" {
delete(deleteMap, layer.MergeBase)
}
} }
delete(deleteMap, manifest.Config.Digest) delete(deleteMap, manifest.Config.Digest)
@@ -880,6 +785,9 @@ func DeleteModel(name string) error {
deleteMap := make(map[string]struct{}) deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{} deleteMap[layer.Digest] = struct{}{}
if layer.MergeBase != "" {
deleteMap[layer.MergeBase] = struct{}{}
}
} }
deleteMap[manifest.Config.Digest] = struct{}{} deleteMap[manifest.Config.Digest] = struct{}{}
@@ -901,67 +809,6 @@ func DeleteModel(name string) error {
return nil return nil
} }
func ShowModelfile(model *Model) (string, error) {
var mt struct {
*Model
From string
Parameters map[string][]any
}
mt.Parameters = make(map[string][]any)
for k, v := range model.Options {
if s, ok := v.([]any); ok {
mt.Parameters[k] = s
continue
}
mt.Parameters[k] = []any{v}
}
mt.Model = model
mt.From = model.ModelPath
if model.ParentModel != "" {
mt.From = model.ParentModel
}
modelFile := `# Modelfile generated by "ollama show"
# To build a new Modelfile based on this one, replace the FROM line with:
# FROM {{ .ShortName }}
FROM {{ .From }}
TEMPLATE """{{ .Template }}"""
{{- if .System }}
SYSTEM """{{ .System }}"""
{{- end }}
{{- range $adapter := .AdapterPaths }}
ADAPTER {{ $adapter }}
{{- end }}
{{- range $k, $v := .Parameters }}
{{- range $parameter := $v }}
PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
{{- end }}
{{- end }}`
tmpl, err := template.New("").Parse(modelFile)
if err != nil {
slog.Info(fmt.Sprintf("error parsing template: %q", err))
return "", err
}
var buf bytes.Buffer
if err = tmpl.Execute(&buf, mt); err != nil {
slog.Info(fmt.Sprintf("error executing template: %q", err))
return "", err
}
return buf.String(), nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"}) fn(api.ProgressResponse{Status: "retrieving manifest"})
@@ -980,6 +827,49 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Layers...) layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config) layers = append(layers, manifest.Config)
for _, layer := range layers {
if !layer.Intermediate {
continue
}
switch layer.MediaType {
case "application/vnd.ollama.image.model+metadata", "application/vnd.ollama.image.model+data":
if _, err := GetBlobsPath(layer.MergeBase); errors.Is(err, os.ErrNotExist) {
filename, err := GetBlobsPath(layer.MergeBase)
if err != nil {
return err
}
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
ggml, size, err := llm.DecodeGGML(f)
if err != nil {
return err
}
if _, err := f.Seek(0, io.SeekStart); err != nil {
return err
}
metadata := io.NewSectionReader(f, 0, ggml.Offset())
if _, err := NewLayer(metadata, "application/vnd.ollama.image.model+metadata"); err != nil {
return err
}
data := io.NewSectionReader(f, ggml.Offset(), size)
if _, err := NewLayer(data, "application/vnd.ollama.image.model+metadata"); err != nil {
return err
}
} else if err != nil {
return err
}
}
}
for _, layer := range layers { for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err)) slog.Info(fmt.Sprintf("error uploading blob: %v", err))
@@ -1049,6 +939,27 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config) layers = append(layers, manifest.Config)
for _, layer := range layers { for _, layer := range layers {
if layer.Intermediate {
filename, err := GetBlobsPath(layer.MergeBase)
if err != nil {
return err
}
if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) {
// pass
} else if err != nil {
return err
} else {
fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", layer.Digest[7:19]),
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
})
continue
}
}
if err := downloadBlob( if err := downloadBlob(
ctx, ctx,
downloadOpts{ downloadOpts{
@@ -1063,9 +974,59 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
} }
delete(deleteMap, manifest.Config.Digest) delete(deleteMap, manifest.Config.Digest)
type mergedLayer struct {
Metadata, Data *Layer
}
mergedLayers := make(map[string]mergedLayer)
for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.MergeBase)
if err != nil {
return err
}
if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) {
merged := mergedLayers[layer.MergeBase]
if layer.MediaType == "application/vnd.ollama.image.model+metadata" {
merged.Metadata = layer
} else if layer.MediaType == "application/vnd.ollama.image.model+data" {
merged.Data = layer
} else {
continue
}
mergedLayers[layer.MergeBase] = merged
} else if err != nil {
return err
} else {
continue
}
}
for _, mergedLayer := range mergedLayers {
fn(api.ProgressResponse{Status: "merging layers"})
metadata, err := mergedLayer.Metadata.Open()
if err != nil {
return err
}
defer metadata.Close()
data, err := mergedLayer.Data.Open()
if err != nil {
return err
}
defer data.Close()
if _, err := NewLayer(io.MultiReader(metadata, data), "application/vnd.ollama.image.model"); err != nil {
return err
}
}
fn(api.ProgressResponse{Status: "verifying sha256 digest"}) fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers { for _, layer := range layers {
if err := verifyBlob(layer.Digest); err != nil { if err := verifyBlob(layer.Digest); errors.Is(err, os.ErrNotExist) && layer.Intermediate {
// pass
} else if err != nil {
if errors.Is(err, errDigestMismatch) { if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob // something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest) fp, err := GetBlobsPath(layer.Digest)

View File

@@ -5,39 +5,18 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"strings"
"golang.org/x/exp/slices"
) )
type Layers struct {
items []*Layer
}
func (ls *Layers) Add(layer *Layer) {
if layer.Size > 0 {
ls.items = append(ls.items, layer)
}
}
func (ls *Layers) Replace(layer *Layer) {
if layer.Size > 0 {
mediatype := layer.MediaType
layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
return l.MediaType == mediatype
})
ls.items = append(layers, layer)
}
}
type Layer struct { type Layer struct {
MediaType string `json:"mediaType"` MediaType string `json:"mediaType"`
Digest string `json:"digest"` Digest string `json:"digest"`
Size int64 `json:"size"` Size int64 `json:"size"`
From string `json:"from,omitempty"` From string `json:"from,omitempty"`
tempFileName string Intermediate bool `json:"intermediate,omitempty"`
MergeBase string `json:"merge_base,omitempty"`
message string
} }
func NewLayer(r io.Reader, mediatype string) (*Layer, error) { func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
@@ -46,14 +25,12 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
return nil, err return nil, err
} }
const delimiter = "-" temp, err := os.CreateTemp(blobs, "sha256-")
pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
temp, err := os.CreateTemp(blobs, pattern)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer temp.Close() defer temp.Close()
defer os.Remove(temp.Name())
sha256sum := sha256.New() sha256sum := sha256.New()
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r) n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
@@ -61,11 +38,29 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
return nil, err return nil, err
} }
if err := temp.Close(); err != nil {
return nil, err
}
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
blob, err := GetBlobsPath(digest)
if err != nil {
return nil, err
}
status := "using existing layer"
if _, err := os.Stat(blob); err != nil {
status = "creating new layer"
if err := os.Rename(temp.Name(), blob); err != nil {
return nil, err
}
}
return &Layer{ return &Layer{
MediaType: mediatype, MediaType: mediatype,
Digest: fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)), Digest: digest,
Size: n, Size: n,
tempFileName: temp.Name(), message: fmt.Sprintf("%s %s", status, digest),
}, nil }, nil
} }
@@ -85,21 +80,15 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
Digest: digest, Digest: digest,
Size: fi.Size(), Size: fi.Size(),
From: from, From: from,
message: fmt.Sprintf("using existing layer %s", digest),
}, nil }, nil
} }
func (l *Layer) Commit() (bool, error) { func (l *Layer) Open() (*os.File, error) {
// always remove temp
defer os.Remove(l.tempFileName)
blob, err := GetBlobsPath(l.Digest) blob, err := GetBlobsPath(l.Digest)
if err != nil { if err != nil {
return false, err return nil, err
} }
if _, err := os.Stat(blob); err != nil { return os.Open(blob)
return true, os.Rename(l.tempFileName, blob)
}
return false, nil
} }

259
server/model.go Normal file
View File

@@ -0,0 +1,259 @@
package server
import (
"archive/zip"
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/types/model"
)
type layerWithGGML struct {
*Layer
*llm.GGML
}
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
modelpath := ParseModelPath(name.String())
manifest, _, err := GetManifest(modelpath)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}
modelpath = ParseModelPath(name.String())
manifest, _, err = GetManifest(modelpath)
if err != nil {
return nil, err
}
case err != nil:
return nil, err
}
for _, layer := range manifest.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
if err != nil {
return nil, err
}
switch layer.MediaType {
case "application/vnd.ollama.image.model",
"application/vnd.ollama.image.projector",
"application/vnd.ollama.image.adapter":
blobpath, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
blob, err := os.Open(blobpath)
if err != nil {
return nil, err
}
defer blob.Close()
ggml, _, err := llm.DecodeGGML(blob)
if err != nil {
return nil, err
}
layers = append(layers, &layerWithGGML{layer, ggml})
default:
layers = append(layers, &layerWithGGML{layer, nil})
}
}
return layers, nil
}
func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
stat, err := file.Stat()
if err != nil {
return nil, err
}
r, err := zip.NewReader(file, stat.Size())
if err != nil {
return nil, err
}
tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
if err != nil {
return nil, err
}
defer os.RemoveAll(tempdir)
fn(api.ProgressResponse{Status: "unpacking model metadata"})
for _, f := range r.File {
// TODO(mxyng): this should not write out all files to disk
outfile, err := os.Create(filepath.Join(tempdir, f.Name))
if err != nil {
return nil, err
}
infile, err := f.Open()
if err != nil {
return nil, err
}
if _, err = io.Copy(outfile, infile); err != nil {
return nil, err
}
if err := outfile.Close(); err != nil {
return nil, err
}
if err := infile.Close(); err != nil {
return nil, err
}
}
mf, err := convert.GetModelFormat(tempdir)
if err != nil {
return nil, err
}
params, err := mf.GetParams(tempdir)
if err != nil {
return nil, err
}
mArch, err := mf.GetModelArch("", tempdir, params)
if err != nil {
return nil, err
}
fn(api.ProgressResponse{Status: "processing tensors"})
if err := mArch.GetTensors(); err != nil {
return nil, err
}
if err := mArch.LoadVocab(); err != nil {
return nil, err
}
fn(api.ProgressResponse{Status: "converting model"})
// TODO(mxyng): this should write directly into a layer
// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
temp, err := os.CreateTemp(tempdir, "fp16")
if err != nil {
return nil, err
}
defer temp.Close()
defer os.Remove(temp.Name())
if err = mArch.WriteGGUF(temp); err != nil {
return nil, err
}
if _, err := temp.Seek(0, io.SeekStart); err != nil {
return nil, err
}
layer, err := NewLayer(temp, "application/vnd.ollama.image.model")
if err != nil {
return nil, fmt.Errorf("aaa: %w", err)
}
blobpath, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
bin, err := os.Open(blobpath)
if err != nil {
return nil, err
}
defer bin.Close()
ggml, _, err := llm.DecodeGGML(bin)
if err != nil {
return nil, err
}
layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
if err != nil {
return nil, err
}
layers = append(layers, &layerWithGGML{layer, ggml})
return layers, nil
}
func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
sr := io.NewSectionReader(file, 0, 512)
contentType, err := detectContentType(sr)
if err != nil {
return nil, err
}
switch contentType {
case "gguf", "ggla":
// noop
case "application/zip":
return parseFromZipFile(ctx, file, fn)
default:
return nil, fmt.Errorf("unsupported content type: %s", contentType)
}
stat, err := file.Stat()
if err != nil {
return nil, err
}
var offset int64
for offset < stat.Size() {
ggml, n, err := llm.DecodeGGML(file)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return nil, err
}
mediatype := "application/vnd.ollama.image.model"
if ggml.Name() == "ggla" {
mediatype = "application/vnd.ollama.image.adapter"
} else if ggml.KV().Architecture() == "clip" {
mediatype = "application/vnd.ollama.image.projector"
}
layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
if err != nil {
return nil, err
}
layers = append(layers, &layerWithGGML{layer, ggml})
offset = n
}
return layers, nil
}
func detectContentType(r io.Reader) (string, error) {
var b bytes.Buffer
if _, err := io.Copy(&b, r); err != nil {
return "", err
}
if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" {
return contentType, nil
}
if contentType := http.DetectContentType(b.Bytes()); contentType != "application/octet-stream" {
return contentType, nil
}
return "unknown", nil
}

View File

@@ -580,7 +580,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil { if err := CreateModel(ctx, model, filepath.Dir(req.Path), strings.ToUpper(req.Quantization), commands, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@@ -728,12 +728,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
} }
mf, err := ShowModelfile(model) var sb strings.Builder
if err != nil { fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
return nil, err fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
} fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
fmt.Fprint(&sb, parser.Format(model.Commands()))
resp.Modelfile = mf resp.Modelfile = sb.String()
return resp, nil return resp, nil
} }
@@ -872,11 +872,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return return
} }
if _, err := layer.Commit(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.Status(http.StatusCreated) c.Status(http.StatusCreated)
} }

View File

@@ -124,14 +124,12 @@ func Test_Routes(t *testing.T) {
Method: http.MethodPost, Method: http.MethodPost,
Path: "/api/create", Path: "/api/create",
Setup: func(t *testing.T, req *http.Request) { Setup: func(t *testing.T, req *http.Request) {
f, err := os.CreateTemp(t.TempDir(), "ollama-model") fname := createTestFile(t, "ollama-model")
assert.Nil(t, err)
defer f.Close()
stream := false stream := false
createReq := api.CreateRequest{ createReq := api.CreateRequest{
Name: "t-bone", Name: "t-bone",
Modelfile: fmt.Sprintf("FROM %s", f.Name()), Modelfile: fmt.Sprintf("FROM %s", fname),
Stream: &stream, Stream: &stream,
} }
jsonData, err := json.Marshal(createReq) jsonData, err := json.Marshal(createReq)
@@ -216,13 +214,10 @@ func Test_Routes(t *testing.T) {
httpSrv := httptest.NewServer(router) httpSrv := httptest.NewServer(router)
t.Cleanup(httpSrv.Close) t.Cleanup(httpSrv.Close)
workDir, err := os.MkdirTemp("", "ollama-test") t.Setenv("OLLAMA_MODELS", t.TempDir())
assert.Nil(t, err)
defer os.RemoveAll(workDir)
os.Setenv("OLLAMA_MODELS", workDir)
for _, tc := range testCases { for _, tc := range testCases {
t.Logf("Running Test: [%s]", tc.Name) t.Run(tc.Name, func(t *testing.T) {
u := httpSrv.URL + tc.Path u := httpSrv.URL + tc.Path
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
assert.Nil(t, err) assert.Nil(t, err)
@@ -238,6 +233,6 @@ func Test_Routes(t *testing.T) {
if tc.Expected != nil { if tc.Expected != nil {
tc.Expected(t, resp) tc.Expected(t, resp)
} }
})
} }
} }