Compare commits

...

1 Commits

Author SHA1 Message Date
Michael Yang
9ef2106b47 cmd: create blob in parallel with checksum
a simple optimisation where once a blob has been checksummed, immediately
upload it; don't wait for all files to be checksummed before starting
upload.
2026-01-20 12:09:02 -08:00
5 changed files with 246 additions and 355 deletions

View File

@@ -43,7 +43,6 @@ import (
"github.com/ollama/ollama/runner" "github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd" xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create" "github.com/ollama/ollama/x/create"
@@ -205,7 +204,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if err != nil { if err != nil {
return err return err
} }
spinner.Stop()
req.Model = modelName req.Model = modelName
quantize, _ := cmd.Flags().GetString("quantize") quantize, _ := cmd.Flags().GetString("quantize")
@@ -219,42 +217,29 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
var g errgroup.Group var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1)) g.SetLimit(runtime.GOMAXPROCS(0))
for blob, err := range createBlobs(req.Files, req.Adapters) {
files := syncmap.NewSyncMap[string, string]() if err != nil {
for f, digest := range req.Files {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err return err
} }
// TODO: this is incorrect since the file might be in a subdirectory
// instead this should take the path relative to the model directory
// but the current implementation does not allow this
files.Store(filepath.Base(f), digest)
return nil
})
}
adapters := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Adapters {
g.Go(func() error { g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil { _, err := createBlob(cmd, client, blob.Abs, blob.Digest, p)
return err return err
}
// TODO: same here
adapters.Store(filepath.Base(f), digest)
return nil
}) })
if _, ok := req.Files[blob.Rel]; ok {
req.Files[blob.Rel] = blob.Digest
} else if _, ok := req.Adapters[blob.Rel]; ok {
req.Adapters[blob.Rel] = blob.Digest
}
} }
if err := g.Wait(); err != nil { if err := g.Wait(); err != nil {
return err return err
} }
req.Files = files.Items() spinner.Stop()
req.Adapters = adapters.Items()
bars := make(map[string]*progress.Bar) bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
@@ -292,54 +277,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
realPath, err := filepath.EvalSymlinks(path)
if err != nil {
return "", err
}
bin, err := os.Open(realPath)
if err != nil {
return "", err
}
defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
var pw progressWriter
status := fmt.Sprintf("copying file %s 0%%", digest)
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
defer spinner.Stop()
done := make(chan struct{})
defer close(done)
go func() {
ticker := time.NewTicker(60 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
case <-done:
spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
return
}
}
}()
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil
}
type progressWriter struct { type progressWriter struct {
n atomic.Int64 n atomic.Int64
} }

103
cmd/create.go Normal file
View File

@@ -0,0 +1,103 @@
package cmd
import (
"crypto/sha256"
"fmt"
"io"
"iter"
"os"
"path/filepath"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
)
// blob describes one model file to upload: Rel is the path key used in the
// create request's Files/Adapters maps, Abs is the file's absolute path on
// disk (the "abs:" prefix already stripped), and Digest is its
// "sha256:"-prefixed hex checksum.
type blob struct {
	Rel, Abs, Digest string
}
// createBlob uploads the file at path to the server as a blob identified by
// digest, showing a spinner on p that reports the copy percentage. It
// resolves symlinks before opening the file and returns the digest on
// success.
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
	realPath, err := filepath.EvalSymlinks(path)
	if err != nil {
		return "", err
	}

	bin, err := os.Open(realPath)
	if err != nil {
		return "", err
	}
	defer bin.Close()

	// Get file info to retrieve the size, used for percentage reporting.
	fileInfo, err := bin.Stat()
	if err != nil {
		return "", err
	}
	fileSize := fileInfo.Size()

	var pw progressWriter
	status := fmt.Sprintf("copying file %s 0%%", digest)
	spinner := progress.NewSpinner(status)
	p.Add(status, spinner)
	defer spinner.Stop()

	// done stops the progress goroutine when the upload finishes; the
	// deferred close guarantees it exits on every return path.
	done := make(chan struct{})
	defer close(done)

	go func() {
		ticker := time.NewTicker(60 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				// Guard fileSize > 0: an empty file would otherwise
				// cause an integer division-by-zero panic here.
				if fileSize > 0 {
					spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
				}
			case <-done:
				spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
				return
			}
		}
	}()

	// TeeReader counts bytes into pw as the client streams the file up.
	if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
		return "", err
	}

	return digest, nil
}
// createBlobs yields a blob (with its computed sha256 digest) for every
// "abs:"-prefixed file path in the given relative-path -> path mappings.
// Entries without the "abs:" prefix are skipped. Iteration stops at the
// first error, which is yielded to the consumer.
func createBlobs(mappings ...map[string]string) iter.Seq2[blob, error] {
	return func(yield func(blob, error) bool) {
		for _, mapping := range mappings {
			for rel, abs := range mapping {
				abs, ok := strings.CutPrefix(abs, "abs:")
				if !ok {
					continue
				}

				digest, err := digestFile(abs)
				if err != nil {
					yield(blob{}, err)
					return
				}

				if !yield(blob{Rel: rel, Abs: abs, Digest: digest}, nil) {
					return
				}
			}
		}
	}
}

// digestFile returns the "sha256:"-prefixed hex digest of the file at path.
func digestFile(path string) (string, error) {
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	// Deferred close prevents a file-handle leak when io.Copy fails (the
	// original closed only on the success path); the explicit Close below
	// still surfaces close errors to the caller.
	defer f.Close()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return "", err
	}

	if err := f.Close(); err != nil {
		return "", err
	}

	return fmt.Sprintf("sha256:%x", h.Sum(nil)), nil
}

View File

@@ -3,22 +3,20 @@ package parser
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"crypto/sha256"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"maps"
"net/http" "net/http"
"os" "os"
"os/user" "os/user"
"path/filepath" "path/filepath"
"runtime"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
"golang.org/x/sync/errgroup"
"golang.org/x/text/encoding/unicode" "golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform" "golang.org/x/text/transform"
@@ -54,7 +52,10 @@ var deprecatedParameters = []string{
// CreateRequest creates a new *api.CreateRequest from an existing Modelfile // CreateRequest creates a new *api.CreateRequest from an existing Modelfile
func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) { func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) {
req := &api.CreateRequest{} req := &api.CreateRequest{
Files: make(map[string]string),
Adapters: make(map[string]string),
}
var messages []api.Message var messages []api.Message
var licenses []string var licenses []string
@@ -63,12 +64,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
for _, c := range f.Commands { for _, c := range f.Commands {
switch c.Name { switch c.Name {
case "model": case "model":
path, err := expandPath(c.Args, relativeDir) files, err := filesMap(c.Args, relativeDir)
if err != nil {
return nil, err
}
digestMap, err := fileDigestMap(path)
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
req.From = c.Args req.From = c.Args
continue continue
@@ -76,25 +72,14 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
return nil, err return nil, err
} }
if req.Files == nil { maps.Copy(req.Files, files)
req.Files = digestMap
} else {
for k, v := range digestMap {
req.Files[k] = v
}
}
case "adapter": case "adapter":
path, err := expandPath(c.Args, relativeDir) files, err := filesMap(c.Args, relativeDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
digestMap, err := fileDigestMap(path) maps.Copy(req.Adapters, files)
if err != nil {
return nil, err
}
req.Adapters = digestMap
case "template": case "template":
req.Template = c.Args req.Template = c.Args
case "system": case "system":
@@ -154,106 +139,66 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
return req, nil return req, nil
} }
func fileDigestMap(path string) (map[string]string, error) { func filesMap(args, base string) (map[string]string, error) {
fl := make(map[string]string) path, err := expandPath(args, base)
if err != nil {
return nil, err
}
fi, err := os.Stat(path) fi, err := os.Stat(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var files []string mapping := make(map[string]string)
if fi.IsDir() { if !fi.IsDir() {
fs, err := filesForModel(path) return map[string]string{
filepath.Base(path): "abs:" + path,
}, nil
}
root, err := os.OpenRoot(path)
if err != nil {
return nil, err
}
defer root.Close()
files, err := filesForModel(root)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, f := range fs { for _, file := range files {
f, err := filepath.EvalSymlinks(f) // create a temporary mapping from relative path to absolute path
if err != nil { mapping[file] = "abs:" + filepath.Join(root.Name(), file)
return nil, err
} }
rel, err := filepath.Rel(path, f) return mapping, nil
if err != nil {
return nil, err
} }
if !filepath.IsLocal(rel) { func filesForModel(root *os.Root) ([]string, error) {
return nil, fmt.Errorf("insecure path: %s", rel)
}
files = append(files, f)
}
} else {
files = []string{path}
}
var mu sync.Mutex
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
for _, f := range files {
g.Go(func() error {
digest, err := digestForFile(f)
if err != nil {
return err
}
mu.Lock()
defer mu.Unlock()
fl[f] = digest
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return fl, nil
}
func digestForFile(filename string) (string, error) {
filepath, err := filepath.EvalSymlinks(filename)
if err != nil {
return "", err
}
bin, err := os.Open(filepath)
if err != nil {
return "", err
}
defer bin.Close()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return "", err
}
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
}
func filesForModel(path string) ([]string, error) {
detectContentType := func(path string) (string, error) { detectContentType := func(path string) (string, error) {
f, err := os.Open(path) f, err := root.Open(path)
if err != nil { if err != nil {
return "", err return "", err
} }
defer f.Close() defer f.Close()
var b bytes.Buffer bts := make([]byte, 512)
b.Grow(512) n, err := io.ReadFull(f, bts)
if errors.Is(err, io.ErrUnexpectedEOF) {
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) { // short read, use what we have
bts = bts[:n]
} else if err != nil {
return "", err return "", err
} }
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";") contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
return contentType, nil return contentType, nil
} }
glob := func(pattern, contentType string) ([]string, error) { glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern) matches, err := fs.Glob(root.FS(), pattern)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -262,7 +207,7 @@ func filesForModel(path string) ([]string, error) {
if ct, err := detectContentType(match); err != nil { if ct, err := detectContentType(match); err != nil {
return nil, err return nil, err
} else if len(contentType) > 0 && ct != contentType { } else if len(contentType) > 0 && ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match) return nil, fmt.Errorf("invalid content type: expected %s for %s, got %s", ct, match, contentType)
} }
} }
@@ -271,25 +216,25 @@ func filesForModel(path string) ([]string, error) {
var files []string var files []string
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType // some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
if st, _ := glob(filepath.Join(path, "model*.safetensors"), ""); len(st) > 0 { if st, _ := glob("model*.safetensors", ""); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are // safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...) files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "consolidated*.safetensors"), ""); len(st) > 0 { } else if st, _ := glob("consolidated*.safetensors", ""); len(st) > 0 {
// covers consolidated.safetensors // covers consolidated.safetensors
files = append(files, st...) files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { } else if pt, _ := glob("pytorch_model*.bin", "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are // pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...) files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 { } else if pt, _ := glob("consolidated*.pth", "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are // pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth // covers consolidated.x.pth, consolidated.pth
files = append(files, pt...) files = append(files, pt...)
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 { } else if gg, _ := glob("*.gguf", "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .gguf // covers gguf files ending in .gguf
files = append(files, gg...) files = append(files, gg...)
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 { } else if gg, _ := glob("*.bin", "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .bin // covers gguf files ending in .bin
files = append(files, gg...) files = append(files, gg...)
} else { } else {
@@ -297,7 +242,7 @@ func filesForModel(path string) ([]string, error) {
} }
// add configuration files, json files are detected as text/plain // add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain") js, err := glob("*.json", "text/plain")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -305,7 +250,7 @@ func filesForModel(path string) ([]string, error) {
// bert models require a nested config.json // bert models require a nested config.json
// TODO(mxyng): merge this with the glob above // TODO(mxyng): merge this with the glob above
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain") js, err = glob("**/*.json", "text/plain")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -313,9 +258,9 @@ func filesForModel(path string) ([]string, error) {
// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob) // add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
// tokenizer.model might be a unresolved git lfs reference; error if it is // tokenizer.model might be a unresolved git lfs reference; error if it is
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { if tks, _ := glob("tokenizer.model", "application/octet-stream"); len(tks) > 0 {
files = append(files, tks...) files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { } else if tks, _ := glob("**/tokenizer.model", "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...) files = append(files, tks...)
} }

View File

@@ -2,7 +2,6 @@ package parser
import ( import (
"bytes" "bytes"
"crypto/sha256"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@@ -15,6 +14,7 @@ import (
"unicode/utf16" "unicode/utf16"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/text/encoding" "golang.org/x/text/encoding"
@@ -775,25 +775,13 @@ MESSAGE assistant Hi! How are you?
t.Error(err) t.Error(err)
} }
if diff := cmp.Diff(actual, c.expected); diff != "" { if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
} }
} }
func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) { func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) string {
t.Helper()
h := sha256.New()
n, err := io.Copy(h, r)
if err != nil {
t.Fatal(err)
}
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
t.Helper() t.Helper()
f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf") f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
@@ -808,19 +796,12 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
if err := ggml.WriteGGUF(f, base, ti); err != nil { if err := ggml.WriteGGUF(f, base, ti); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Calculate sha256 of file return f.Name()
if _, err := f.Seek(0, 0); err != nil {
t.Fatal(err)
}
digest, _ := getSHA256Digest(t, f)
return f.Name(), digest
} }
func TestCreateRequestFiles(t *testing.T) { func TestCreateRequestFiles(t *testing.T) {
n1, d1 := createBinFile(t, nil, nil) n1 := createBinFile(t, nil, nil)
n2, d2 := createBinFile(t, map[string]any{"foo": "bar"}, nil) n2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
cases := []struct { cases := []struct {
input string input string
@@ -828,11 +809,20 @@ func TestCreateRequestFiles(t *testing.T) {
}{ }{
{ {
fmt.Sprintf("FROM %s", n1), fmt.Sprintf("FROM %s", n1),
&api.CreateRequest{Files: map[string]string{n1: d1}}, &api.CreateRequest{
Files: map[string]string{
filepath.Base(n1): "abs:" + n1,
},
},
}, },
{ {
fmt.Sprintf("FROM %s\nFROM %s", n1, n2), fmt.Sprintf("FROM %s\nFROM %s", n1, n2),
&api.CreateRequest{Files: map[string]string{n1: d1, n2: d2}}, &api.CreateRequest{
Files: map[string]string{
filepath.Base(n1): "abs:" + n1,
filepath.Base(n2): "abs:" + n2,
},
},
}, },
} }
@@ -852,7 +842,7 @@ func TestCreateRequestFiles(t *testing.T) {
t.Error(err) t.Error(err)
} }
if diff := cmp.Diff(actual, c.expected); diff != "" { if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
} }
@@ -861,14 +851,14 @@ func TestCreateRequestFiles(t *testing.T) {
func TestFilesForModel(t *testing.T) { func TestFilesForModel(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
setup func(string) error setup func(*testing.T, *os.Root)
wantFiles []string want []string
wantErr bool wantErr error
expectErrType error
}{ }{
{ {
name: "safetensors model files", name: "safetensors model files",
setup: func(dir string) error { setup: func(t *testing.T, root *os.Root) {
t.Helper()
files := []string{ files := []string{
"model-00001-of-00002.safetensors", "model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors", "model-00002-of-00002.safetensors",
@@ -876,13 +866,12 @@ func TestFilesForModel(t *testing.T) {
"tokenizer.json", "tokenizer.json",
} }
for _, file := range files { for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil { if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
return err t.Fatal(err)
} }
} }
return nil
}, },
wantFiles: []string{ want: []string{
"model-00001-of-00002.safetensors", "model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors", "model-00002-of-00002.safetensors",
"config.json", "config.json",
@@ -891,7 +880,7 @@ func TestFilesForModel(t *testing.T) {
}, },
{ {
name: "safetensors with both tokenizer.json and tokenizer.model", name: "safetensors with both tokenizer.json and tokenizer.model",
setup: func(dir string) error { setup: func(t *testing.T, root *os.Root) {
// Create binary content for tokenizer.model (application/octet-stream) // Create binary content for tokenizer.model (application/octet-stream)
binaryContent := make([]byte, 512) binaryContent := make([]byte, 512)
for i := range binaryContent { for i := range binaryContent {
@@ -903,17 +892,16 @@ func TestFilesForModel(t *testing.T) {
"tokenizer.json", "tokenizer.json",
} }
for _, file := range files { for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil { if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
return err t.Fatal(err)
} }
} }
// Write tokenizer.model as binary // Write tokenizer.model as binary
if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil { if err := root.WriteFile("tokenizer.model", binaryContent, 0o644); err != nil {
return err t.Fatal(err)
} }
return nil
}, },
wantFiles: []string{ want: []string{
"model-00001-of-00001.safetensors", "model-00001-of-00001.safetensors",
"config.json", "config.json",
"tokenizer.json", "tokenizer.json",
@@ -922,46 +910,44 @@ func TestFilesForModel(t *testing.T) {
}, },
{ {
name: "safetensors with consolidated files - prefers model files", name: "safetensors with consolidated files - prefers model files",
setup: func(dir string) error { setup: func(t *testing.T, root *os.Root) {
files := []string{ files := []string{
"model-00001-of-00001.safetensors", "model-00001-of-00001.safetensors",
"consolidated.safetensors", "consolidated.safetensors",
"config.json", "config.json",
} }
for _, file := range files { for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil { if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
return err t.Fatal(err)
} }
} }
return nil
}, },
wantFiles: []string{ want: []string{
"model-00001-of-00001.safetensors", // consolidated files should be excluded "model-00001-of-00001.safetensors", // consolidated files should be excluded
"config.json", "config.json",
}, },
}, },
{ {
name: "safetensors without model-.safetensors files - uses consolidated", name: "safetensors without model-.safetensors files - uses consolidated",
setup: func(dir string) error { setup: func(t *testing.T, root *os.Root) {
files := []string{ files := []string{
"consolidated.safetensors", "consolidated.safetensors",
"config.json", "config.json",
} }
for _, file := range files { for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil { if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
return err t.Fatal(err)
} }
} }
return nil
}, },
wantFiles: []string{ want: []string{
"consolidated.safetensors", "consolidated.safetensors",
"config.json", "config.json",
}, },
}, },
{ {
name: "pytorch model files", name: "pytorch model files",
setup: func(dir string) error { setup: func(t *testing.T, root *os.Root) {
// Create a file that will be detected as application/zip // Create a file that will be detected as application/zip
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header
files := []string{ files := []string{
@@ -974,13 +960,12 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" { if file == "config.json" {
content = []byte(`{"config": true}`) content = []byte(`{"config": true}`)
} }
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil { if err := root.WriteFile(file, content, 0o644); err != nil {
return err t.Fatal(err)
} }
} }
return nil
}, },
wantFiles: []string{ want: []string{
"pytorch_model-00001-of-00002.bin", "pytorch_model-00001-of-00002.bin",
"pytorch_model-00002-of-00002.bin", "pytorch_model-00002-of-00002.bin",
"config.json", "config.json",
@@ -988,7 +973,7 @@ func TestFilesForModel(t *testing.T) {
}, },
{ {
name: "consolidated pth files", name: "consolidated pth files",
setup: func(dir string) error { setup: func(t *testing.T, root *os.Root) {
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} zipHeader := []byte{0x50, 0x4B, 0x03, 0x04}
files := []string{ files := []string{
"consolidated.00.pth", "consolidated.00.pth",
@@ -1000,13 +985,12 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" { if file == "config.json" {
content = []byte(`{"config": true}`) content = []byte(`{"config": true}`)
} }
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil { if err := root.WriteFile(file, content, 0o644); err != nil {
return err t.Fatal(err)
} }
} }
return nil
}, },
wantFiles: []string{ want: []string{
"consolidated.00.pth", "consolidated.00.pth",
"consolidated.01.pth", "consolidated.01.pth",
"config.json", "config.json",
@@ -1014,7 +998,7 @@ func TestFilesForModel(t *testing.T) {
}, },
{ {
name: "gguf files", name: "gguf files",
setup: func(dir string) error { setup: func(t *testing.T, root *os.Root) {
// Create binary content that will be detected as application/octet-stream // Create binary content that will be detected as application/octet-stream
binaryContent := make([]byte, 512) binaryContent := make([]byte, 512)
for i := range binaryContent { for i := range binaryContent {
@@ -1029,20 +1013,19 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" { if file == "config.json" {
content = []byte(`{"config": true}`) content = []byte(`{"config": true}`)
} }
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil { if err := root.WriteFile(file, content, 0o644); err != nil {
return err t.Fatal(err)
} }
} }
return nil
}, },
wantFiles: []string{ want: []string{
"model.gguf", "model.gguf",
"config.json", "config.json",
}, },
}, },
{ {
name: "bin files as gguf", name: "bin files as gguf",
setup: func(dir string) error { setup: func(t *testing.T, root *os.Root) {
binaryContent := make([]byte, 512) binaryContent := make([]byte, 512)
for i := range binaryContent { for i := range binaryContent {
binaryContent[i] = byte(i % 256) binaryContent[i] = byte(i % 256)
@@ -1056,35 +1039,32 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" { if file == "config.json" {
content = []byte(`{"config": true}`) content = []byte(`{"config": true}`)
} }
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil { if err := root.WriteFile(file, content, 0o644); err != nil {
return err t.Fatal(err)
} }
} }
return nil
}, },
wantFiles: []string{ want: []string{
"model.bin", "model.bin",
"config.json", "config.json",
}, },
}, },
{ {
name: "no model files found", name: "no model files found",
setup: func(dir string) error { setup: func(t *testing.T, root *os.Root) {
// Only create non-model files // Only create non-model files
files := []string{"README.md", "config.json"} files := []string{"README.md", "config.json"}
for _, file := range files { for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("content"), 0o644); err != nil { if err := root.WriteFile(file, []byte("content"), 0o644); err != nil {
return err t.Fatal(err)
} }
} }
return nil
}, },
wantErr: true, wantErr: ErrModelNotFound,
expectErrType: ErrModelNotFound,
}, },
{ {
name: "invalid content type for pytorch model", name: "invalid content type for pytorch model",
setup: func(dir string) error { setup: func(t *testing.T, root *os.Root) {
// Create pytorch model file with wrong content type (text instead of zip) // Create pytorch model file with wrong content type (text instead of zip)
files := []string{ files := []string{
"pytorch_model.bin", "pytorch_model.bin",
@@ -1092,68 +1072,32 @@ func TestFilesForModel(t *testing.T) {
} }
for _, file := range files { for _, file := range files {
content := []byte("plain text content") content := []byte("plain text content")
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil { if err := root.WriteFile(file, content, 0o644); err != nil {
return err t.Fatal(err)
} }
} }
return nil
}, },
wantErr: true, wantErr: ErrModelNotFound,
}, },
} }
tmpDir := t.TempDir()
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
testDir := filepath.Join(tmpDir, tt.name) root, err := os.OpenRoot(t.TempDir())
if err := os.MkdirAll(testDir, 0o755); err != nil {
t.Fatalf("Failed to create test directory: %v", err)
}
if err := tt.setup(testDir); err != nil {
t.Fatalf("Setup failed: %v", err)
}
files, err := filesForModel(testDir)
if tt.wantErr {
if err == nil {
t.Error("Expected error, but got none")
}
if tt.expectErrType != nil && err != tt.expectErrType {
t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
}
return
}
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Fatalf("Failed to open root: %v", err)
return }
defer root.Close()
tt.setup(t, root)
files, err := filesForModel(root)
if !errors.Is(err, tt.wantErr) {
t.Fatalf("want %v error, got %v", tt.wantErr, err)
} }
var relativeFiles []string if diff := cmp.Diff(tt.want, files); diff != "" {
for _, file := range files { t.Errorf("filesForModel() mismatch (-want +got):\n%s", diff)
rel, err := filepath.Rel(testDir, file)
if err != nil {
t.Fatalf("Failed to get relative path: %v", err)
}
relativeFiles = append(relativeFiles, rel)
}
if len(relativeFiles) != len(tt.wantFiles) {
t.Errorf("Expected %d files, got %d: %v", len(tt.wantFiles), len(relativeFiles), relativeFiles)
}
fileSet := make(map[string]bool)
for _, file := range relativeFiles {
fileSet[file] = true
}
for _, wantFile := range tt.wantFiles {
if !fileSet[wantFile] {
t.Errorf("Missing expected file: %s", wantFile)
}
} }
}) })
} }

View File

@@ -1,38 +0,0 @@
package syncmap
import (
"maps"
"sync"
)
// SyncMap is a simple, generic thread-safe map implementation.
type SyncMap[K comparable, V any] struct {
mu sync.RWMutex
m map[K]V
}
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
return &SyncMap[K, V]{
m: make(map[K]V),
}
}
func (s *SyncMap[K, V]) Load(key K) (V, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
val, ok := s.m[key]
return val, ok
}
func (s *SyncMap[K, V]) Store(key K, value V) {
s.mu.Lock()
defer s.mu.Unlock()
s.m[key] = value
}
func (s *SyncMap[K, V]) Items() map[K]V {
s.mu.RLock()
defer s.mu.RUnlock()
// shallow copy map items
return maps.Clone(s.m)
}