mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 20:54:12 +02:00
Compare commits
1 Commits
pdevine/sa
...
mxyng/asyn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ef2106b47 |
89
cmd/cmd.go
89
cmd/cmd.go
@@ -43,7 +43,6 @@ import (
|
|||||||
"github.com/ollama/ollama/runner"
|
"github.com/ollama/ollama/runner"
|
||||||
"github.com/ollama/ollama/server"
|
"github.com/ollama/ollama/server"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/types/syncmap"
|
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
xcmd "github.com/ollama/ollama/x/cmd"
|
xcmd "github.com/ollama/ollama/x/cmd"
|
||||||
"github.com/ollama/ollama/x/create"
|
"github.com/ollama/ollama/x/create"
|
||||||
@@ -205,7 +204,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
spinner.Stop()
|
|
||||||
|
|
||||||
req.Model = modelName
|
req.Model = modelName
|
||||||
quantize, _ := cmd.Flags().GetString("quantize")
|
quantize, _ := cmd.Flags().GetString("quantize")
|
||||||
@@ -219,42 +217,29 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var g errgroup.Group
|
var g errgroup.Group
|
||||||
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
|
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||||
|
for blob, err := range createBlobs(req.Files, req.Adapters) {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
files := syncmap.NewSyncMap[string, string]()
|
|
||||||
for f, digest := range req.Files {
|
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
|
_, err := createBlob(cmd, client, blob.Abs, blob.Digest, p)
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: this is incorrect since the file might be in a subdirectory
|
|
||||||
// instead this should take the path relative to the model directory
|
|
||||||
// but the current implementation does not allow this
|
|
||||||
files.Store(filepath.Base(f), digest)
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
|
||||||
adapters := syncmap.NewSyncMap[string, string]()
|
if _, ok := req.Files[blob.Rel]; ok {
|
||||||
for f, digest := range req.Adapters {
|
req.Files[blob.Rel] = blob.Digest
|
||||||
g.Go(func() error {
|
} else if _, ok := req.Adapters[blob.Rel]; ok {
|
||||||
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
|
req.Adapters[blob.Rel] = blob.Digest
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: same here
|
|
||||||
adapters.Store(filepath.Base(f), digest)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.Wait(); err != nil {
|
if err := g.Wait(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Files = files.Items()
|
spinner.Stop()
|
||||||
req.Adapters = adapters.Items()
|
|
||||||
|
|
||||||
bars := make(map[string]*progress.Bar)
|
bars := make(map[string]*progress.Bar)
|
||||||
fn := func(resp api.ProgressResponse) error {
|
fn := func(resp api.ProgressResponse) error {
|
||||||
@@ -292,54 +277,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
|
|
||||||
realPath, err := filepath.EvalSymlinks(path)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
bin, err := os.Open(realPath)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer bin.Close()
|
|
||||||
|
|
||||||
// Get file info to retrieve the size
|
|
||||||
fileInfo, err := bin.Stat()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
fileSize := fileInfo.Size()
|
|
||||||
|
|
||||||
var pw progressWriter
|
|
||||||
status := fmt.Sprintf("copying file %s 0%%", digest)
|
|
||||||
spinner := progress.NewSpinner(status)
|
|
||||||
p.Add(status, spinner)
|
|
||||||
defer spinner.Stop()
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
defer close(done)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(60 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
|
|
||||||
case <-done:
|
|
||||||
spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return digest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type progressWriter struct {
|
type progressWriter struct {
|
||||||
n atomic.Int64
|
n atomic.Int64
|
||||||
}
|
}
|
||||||
|
|||||||
103
cmd/create.go
Normal file
103
cmd/create.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"iter"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/progress"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
type blob struct {
|
||||||
|
Rel, Abs, Digest string
|
||||||
|
}
|
||||||
|
|
||||||
|
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
|
||||||
|
realPath, err := filepath.EvalSymlinks(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
bin, err := os.Open(realPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer bin.Close()
|
||||||
|
|
||||||
|
// Get file info to retrieve the size
|
||||||
|
fileInfo, err := bin.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
fileSize := fileInfo.Size()
|
||||||
|
|
||||||
|
var pw progressWriter
|
||||||
|
status := fmt.Sprintf("copying file %s 0%%", digest)
|
||||||
|
spinner := progress.NewSpinner(status)
|
||||||
|
p.Add(status, spinner)
|
||||||
|
defer spinner.Stop()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(60 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
|
||||||
|
case <-done:
|
||||||
|
spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return digest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createBlobs(mappings ...map[string]string) iter.Seq2[blob, error] {
|
||||||
|
return func(yield func(blob, error) bool) {
|
||||||
|
for _, mapping := range mappings {
|
||||||
|
for rel, abs := range mapping {
|
||||||
|
if abs, ok := strings.CutPrefix(abs, "abs:"); ok {
|
||||||
|
f, err := os.Open(abs)
|
||||||
|
if err != nil {
|
||||||
|
yield(blob{}, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h := sha256.New()
|
||||||
|
if _, err := io.Copy(h, f); err != nil {
|
||||||
|
yield(blob{}, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := f.Close(); err != nil {
|
||||||
|
yield(blob{}, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !yield(blob{
|
||||||
|
Rel: rel,
|
||||||
|
Abs: abs,
|
||||||
|
Digest: fmt.Sprintf("sha256:%x", h.Sum(nil)),
|
||||||
|
}, nil) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
161
parser/parser.go
161
parser/parser.go
@@ -3,22 +3,20 @@ package parser
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/sha256"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"maps"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/mod/semver"
|
"golang.org/x/mod/semver"
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
"golang.org/x/text/encoding/unicode"
|
"golang.org/x/text/encoding/unicode"
|
||||||
"golang.org/x/text/transform"
|
"golang.org/x/text/transform"
|
||||||
|
|
||||||
@@ -54,7 +52,10 @@ var deprecatedParameters = []string{
|
|||||||
|
|
||||||
// CreateRequest creates a new *api.CreateRequest from an existing Modelfile
|
// CreateRequest creates a new *api.CreateRequest from an existing Modelfile
|
||||||
func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) {
|
func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) {
|
||||||
req := &api.CreateRequest{}
|
req := &api.CreateRequest{
|
||||||
|
Files: make(map[string]string),
|
||||||
|
Adapters: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
var licenses []string
|
var licenses []string
|
||||||
@@ -63,12 +64,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
for _, c := range f.Commands {
|
for _, c := range f.Commands {
|
||||||
switch c.Name {
|
switch c.Name {
|
||||||
case "model":
|
case "model":
|
||||||
path, err := expandPath(c.Args, relativeDir)
|
files, err := filesMap(c.Args, relativeDir)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
digestMap, err := fileDigestMap(path)
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
req.From = c.Args
|
req.From = c.Args
|
||||||
continue
|
continue
|
||||||
@@ -76,25 +72,14 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Files == nil {
|
maps.Copy(req.Files, files)
|
||||||
req.Files = digestMap
|
|
||||||
} else {
|
|
||||||
for k, v := range digestMap {
|
|
||||||
req.Files[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "adapter":
|
case "adapter":
|
||||||
path, err := expandPath(c.Args, relativeDir)
|
files, err := filesMap(c.Args, relativeDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
digestMap, err := fileDigestMap(path)
|
maps.Copy(req.Adapters, files)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Adapters = digestMap
|
|
||||||
case "template":
|
case "template":
|
||||||
req.Template = c.Args
|
req.Template = c.Args
|
||||||
case "system":
|
case "system":
|
||||||
@@ -154,106 +139,66 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fileDigestMap(path string) (map[string]string, error) {
|
func filesMap(args, base string) (map[string]string, error) {
|
||||||
fl := make(map[string]string)
|
path, err := expandPath(args, base)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
fi, err := os.Stat(path)
|
fi, err := os.Stat(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var files []string
|
mapping := make(map[string]string)
|
||||||
if fi.IsDir() {
|
if !fi.IsDir() {
|
||||||
fs, err := filesForModel(path)
|
return map[string]string{
|
||||||
if err != nil {
|
filepath.Base(path): "abs:" + path,
|
||||||
return nil, err
|
}, nil
|
||||||
}
|
|
||||||
|
|
||||||
for _, f := range fs {
|
|
||||||
f, err := filepath.EvalSymlinks(f)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rel, err := filepath.Rel(path, f)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !filepath.IsLocal(rel) {
|
|
||||||
return nil, fmt.Errorf("insecure path: %s", rel)
|
|
||||||
}
|
|
||||||
|
|
||||||
files = append(files, f)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
files = []string{path}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var mu sync.Mutex
|
root, err := os.OpenRoot(path)
|
||||||
var g errgroup.Group
|
if err != nil {
|
||||||
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
|
return nil, err
|
||||||
for _, f := range files {
|
|
||||||
g.Go(func() error {
|
|
||||||
digest, err := digestForFile(f)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
fl[f] = digest
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
defer root.Close()
|
||||||
|
|
||||||
if err := g.Wait(); err != nil {
|
files, err := filesForModel(root)
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return fl, nil
|
for _, file := range files {
|
||||||
|
// create a temporary mapping from relative path to absolute path
|
||||||
|
mapping[file] = "abs:" + filepath.Join(root.Name(), file)
|
||||||
|
}
|
||||||
|
|
||||||
|
return mapping, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func digestForFile(filename string) (string, error) {
|
func filesForModel(root *os.Root) ([]string, error) {
|
||||||
filepath, err := filepath.EvalSymlinks(filename)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
bin, err := os.Open(filepath)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer bin.Close()
|
|
||||||
|
|
||||||
hash := sha256.New()
|
|
||||||
if _, err := io.Copy(hash, bin); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func filesForModel(path string) ([]string, error) {
|
|
||||||
detectContentType := func(path string) (string, error) {
|
detectContentType := func(path string) (string, error) {
|
||||||
f, err := os.Open(path)
|
f, err := root.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
var b bytes.Buffer
|
bts := make([]byte, 512)
|
||||||
b.Grow(512)
|
n, err := io.ReadFull(f, bts)
|
||||||
|
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
|
// short read, use what we have
|
||||||
|
bts = bts[:n]
|
||||||
|
} else if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
|
contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
|
||||||
return contentType, nil
|
return contentType, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
glob := func(pattern, contentType string) ([]string, error) {
|
glob := func(pattern, contentType string) ([]string, error) {
|
||||||
matches, err := filepath.Glob(pattern)
|
matches, err := fs.Glob(root.FS(), pattern)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -262,7 +207,7 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
if ct, err := detectContentType(match); err != nil {
|
if ct, err := detectContentType(match); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if len(contentType) > 0 && ct != contentType {
|
} else if len(contentType) > 0 && ct != contentType {
|
||||||
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
|
return nil, fmt.Errorf("invalid content type: expected %s for %s, got %s", ct, match, contentType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -271,25 +216,25 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
|
|
||||||
var files []string
|
var files []string
|
||||||
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
|
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
|
||||||
if st, _ := glob(filepath.Join(path, "model*.safetensors"), ""); len(st) > 0 {
|
if st, _ := glob("model*.safetensors", ""); len(st) > 0 {
|
||||||
// safetensors files might be unresolved git lfs references; skip if they are
|
// safetensors files might be unresolved git lfs references; skip if they are
|
||||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||||
files = append(files, st...)
|
files = append(files, st...)
|
||||||
} else if st, _ := glob(filepath.Join(path, "consolidated*.safetensors"), ""); len(st) > 0 {
|
} else if st, _ := glob("consolidated*.safetensors", ""); len(st) > 0 {
|
||||||
// covers consolidated.safetensors
|
// covers consolidated.safetensors
|
||||||
files = append(files, st...)
|
files = append(files, st...)
|
||||||
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
} else if pt, _ := glob("pytorch_model*.bin", "application/zip"); len(pt) > 0 {
|
||||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||||
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
||||||
files = append(files, pt...)
|
files = append(files, pt...)
|
||||||
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
|
} else if pt, _ := glob("consolidated*.pth", "application/zip"); len(pt) > 0 {
|
||||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||||
// covers consolidated.x.pth, consolidated.pth
|
// covers consolidated.x.pth, consolidated.pth
|
||||||
files = append(files, pt...)
|
files = append(files, pt...)
|
||||||
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 {
|
} else if gg, _ := glob("*.gguf", "application/octet-stream"); len(gg) > 0 {
|
||||||
// covers gguf files ending in .gguf
|
// covers gguf files ending in .gguf
|
||||||
files = append(files, gg...)
|
files = append(files, gg...)
|
||||||
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 {
|
} else if gg, _ := glob("*.bin", "application/octet-stream"); len(gg) > 0 {
|
||||||
// covers gguf files ending in .bin
|
// covers gguf files ending in .bin
|
||||||
files = append(files, gg...)
|
files = append(files, gg...)
|
||||||
} else {
|
} else {
|
||||||
@@ -297,7 +242,7 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// add configuration files, json files are detected as text/plain
|
// add configuration files, json files are detected as text/plain
|
||||||
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
|
js, err := glob("*.json", "text/plain")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -305,7 +250,7 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
|
|
||||||
// bert models require a nested config.json
|
// bert models require a nested config.json
|
||||||
// TODO(mxyng): merge this with the glob above
|
// TODO(mxyng): merge this with the glob above
|
||||||
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
|
js, err = glob("**/*.json", "text/plain")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -313,9 +258,9 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
|
|
||||||
// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
|
// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
|
||||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
if tks, _ := glob("tokenizer.model", "application/octet-stream"); len(tks) > 0 {
|
||||||
files = append(files, tks...)
|
files = append(files, tks...)
|
||||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
} else if tks, _ := glob("**/tokenizer.model", "text/plain"); len(tks) > 0 {
|
||||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||||
files = append(files, tks...)
|
files = append(files, tks...)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package parser
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -15,6 +14,7 @@ import (
|
|||||||
"unicode/utf16"
|
"unicode/utf16"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/text/encoding"
|
"golang.org/x/text/encoding"
|
||||||
@@ -775,25 +775,13 @@ MESSAGE assistant Hi! How are you?
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(actual, c.expected); diff != "" {
|
if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" {
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) {
|
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) string {
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
h := sha256.New()
|
|
||||||
n, err := io.Copy(h, r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
|
||||||
}
|
|
||||||
|
|
||||||
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
|
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
|
f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
|
||||||
@@ -808,19 +796,12 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
|||||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
// Calculate sha256 of file
|
return f.Name()
|
||||||
if _, err := f.Seek(0, 0); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
digest, _ := getSHA256Digest(t, f)
|
|
||||||
|
|
||||||
return f.Name(), digest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateRequestFiles(t *testing.T) {
|
func TestCreateRequestFiles(t *testing.T) {
|
||||||
n1, d1 := createBinFile(t, nil, nil)
|
n1 := createBinFile(t, nil, nil)
|
||||||
n2, d2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
|
n2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
|
||||||
|
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
input string
|
input string
|
||||||
@@ -828,11 +809,20 @@ func TestCreateRequestFiles(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
fmt.Sprintf("FROM %s", n1),
|
fmt.Sprintf("FROM %s", n1),
|
||||||
&api.CreateRequest{Files: map[string]string{n1: d1}},
|
&api.CreateRequest{
|
||||||
|
Files: map[string]string{
|
||||||
|
filepath.Base(n1): "abs:" + n1,
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
fmt.Sprintf("FROM %s\nFROM %s", n1, n2),
|
fmt.Sprintf("FROM %s\nFROM %s", n1, n2),
|
||||||
&api.CreateRequest{Files: map[string]string{n1: d1, n2: d2}},
|
&api.CreateRequest{
|
||||||
|
Files: map[string]string{
|
||||||
|
filepath.Base(n1): "abs:" + n1,
|
||||||
|
filepath.Base(n2): "abs:" + n2,
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -852,7 +842,7 @@ func TestCreateRequestFiles(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(actual, c.expected); diff != "" {
|
if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" {
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -860,15 +850,15 @@ func TestCreateRequestFiles(t *testing.T) {
|
|||||||
|
|
||||||
func TestFilesForModel(t *testing.T) {
|
func TestFilesForModel(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setup func(string) error
|
setup func(*testing.T, *os.Root)
|
||||||
wantFiles []string
|
want []string
|
||||||
wantErr bool
|
wantErr error
|
||||||
expectErrType error
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "safetensors model files",
|
name: "safetensors model files",
|
||||||
setup: func(dir string) error {
|
setup: func(t *testing.T, root *os.Root) {
|
||||||
|
t.Helper()
|
||||||
files := []string{
|
files := []string{
|
||||||
"model-00001-of-00002.safetensors",
|
"model-00001-of-00002.safetensors",
|
||||||
"model-00002-of-00002.safetensors",
|
"model-00002-of-00002.safetensors",
|
||||||
@@ -876,13 +866,12 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
"tokenizer.json",
|
"tokenizer.json",
|
||||||
}
|
}
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
wantFiles: []string{
|
want: []string{
|
||||||
"model-00001-of-00002.safetensors",
|
"model-00001-of-00002.safetensors",
|
||||||
"model-00002-of-00002.safetensors",
|
"model-00002-of-00002.safetensors",
|
||||||
"config.json",
|
"config.json",
|
||||||
@@ -891,7 +880,7 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "safetensors with both tokenizer.json and tokenizer.model",
|
name: "safetensors with both tokenizer.json and tokenizer.model",
|
||||||
setup: func(dir string) error {
|
setup: func(t *testing.T, root *os.Root) {
|
||||||
// Create binary content for tokenizer.model (application/octet-stream)
|
// Create binary content for tokenizer.model (application/octet-stream)
|
||||||
binaryContent := make([]byte, 512)
|
binaryContent := make([]byte, 512)
|
||||||
for i := range binaryContent {
|
for i := range binaryContent {
|
||||||
@@ -903,17 +892,16 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
"tokenizer.json",
|
"tokenizer.json",
|
||||||
}
|
}
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Write tokenizer.model as binary
|
// Write tokenizer.model as binary
|
||||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil {
|
if err := root.WriteFile("tokenizer.model", binaryContent, 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
wantFiles: []string{
|
want: []string{
|
||||||
"model-00001-of-00001.safetensors",
|
"model-00001-of-00001.safetensors",
|
||||||
"config.json",
|
"config.json",
|
||||||
"tokenizer.json",
|
"tokenizer.json",
|
||||||
@@ -922,46 +910,44 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "safetensors with consolidated files - prefers model files",
|
name: "safetensors with consolidated files - prefers model files",
|
||||||
setup: func(dir string) error {
|
setup: func(t *testing.T, root *os.Root) {
|
||||||
files := []string{
|
files := []string{
|
||||||
"model-00001-of-00001.safetensors",
|
"model-00001-of-00001.safetensors",
|
||||||
"consolidated.safetensors",
|
"consolidated.safetensors",
|
||||||
"config.json",
|
"config.json",
|
||||||
}
|
}
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
wantFiles: []string{
|
want: []string{
|
||||||
"model-00001-of-00001.safetensors", // consolidated files should be excluded
|
"model-00001-of-00001.safetensors", // consolidated files should be excluded
|
||||||
"config.json",
|
"config.json",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "safetensors without model-.safetensors files - uses consolidated",
|
name: "safetensors without model-.safetensors files - uses consolidated",
|
||||||
setup: func(dir string) error {
|
setup: func(t *testing.T, root *os.Root) {
|
||||||
files := []string{
|
files := []string{
|
||||||
"consolidated.safetensors",
|
"consolidated.safetensors",
|
||||||
"config.json",
|
"config.json",
|
||||||
}
|
}
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
wantFiles: []string{
|
want: []string{
|
||||||
"consolidated.safetensors",
|
"consolidated.safetensors",
|
||||||
"config.json",
|
"config.json",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "pytorch model files",
|
name: "pytorch model files",
|
||||||
setup: func(dir string) error {
|
setup: func(t *testing.T, root *os.Root) {
|
||||||
// Create a file that will be detected as application/zip
|
// Create a file that will be detected as application/zip
|
||||||
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header
|
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header
|
||||||
files := []string{
|
files := []string{
|
||||||
@@ -974,13 +960,12 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
if file == "config.json" {
|
if file == "config.json" {
|
||||||
content = []byte(`{"config": true}`)
|
content = []byte(`{"config": true}`)
|
||||||
}
|
}
|
||||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
if err := root.WriteFile(file, content, 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
wantFiles: []string{
|
want: []string{
|
||||||
"pytorch_model-00001-of-00002.bin",
|
"pytorch_model-00001-of-00002.bin",
|
||||||
"pytorch_model-00002-of-00002.bin",
|
"pytorch_model-00002-of-00002.bin",
|
||||||
"config.json",
|
"config.json",
|
||||||
@@ -988,7 +973,7 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "consolidated pth files",
|
name: "consolidated pth files",
|
||||||
setup: func(dir string) error {
|
setup: func(t *testing.T, root *os.Root) {
|
||||||
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04}
|
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04}
|
||||||
files := []string{
|
files := []string{
|
||||||
"consolidated.00.pth",
|
"consolidated.00.pth",
|
||||||
@@ -1000,13 +985,12 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
if file == "config.json" {
|
if file == "config.json" {
|
||||||
content = []byte(`{"config": true}`)
|
content = []byte(`{"config": true}`)
|
||||||
}
|
}
|
||||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
if err := root.WriteFile(file, content, 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
wantFiles: []string{
|
want: []string{
|
||||||
"consolidated.00.pth",
|
"consolidated.00.pth",
|
||||||
"consolidated.01.pth",
|
"consolidated.01.pth",
|
||||||
"config.json",
|
"config.json",
|
||||||
@@ -1014,7 +998,7 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "gguf files",
|
name: "gguf files",
|
||||||
setup: func(dir string) error {
|
setup: func(t *testing.T, root *os.Root) {
|
||||||
// Create binary content that will be detected as application/octet-stream
|
// Create binary content that will be detected as application/octet-stream
|
||||||
binaryContent := make([]byte, 512)
|
binaryContent := make([]byte, 512)
|
||||||
for i := range binaryContent {
|
for i := range binaryContent {
|
||||||
@@ -1029,20 +1013,19 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
if file == "config.json" {
|
if file == "config.json" {
|
||||||
content = []byte(`{"config": true}`)
|
content = []byte(`{"config": true}`)
|
||||||
}
|
}
|
||||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
if err := root.WriteFile(file, content, 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
wantFiles: []string{
|
want: []string{
|
||||||
"model.gguf",
|
"model.gguf",
|
||||||
"config.json",
|
"config.json",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "bin files as gguf",
|
name: "bin files as gguf",
|
||||||
setup: func(dir string) error {
|
setup: func(t *testing.T, root *os.Root) {
|
||||||
binaryContent := make([]byte, 512)
|
binaryContent := make([]byte, 512)
|
||||||
for i := range binaryContent {
|
for i := range binaryContent {
|
||||||
binaryContent[i] = byte(i % 256)
|
binaryContent[i] = byte(i % 256)
|
||||||
@@ -1056,35 +1039,32 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
if file == "config.json" {
|
if file == "config.json" {
|
||||||
content = []byte(`{"config": true}`)
|
content = []byte(`{"config": true}`)
|
||||||
}
|
}
|
||||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
if err := root.WriteFile(file, content, 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
wantFiles: []string{
|
want: []string{
|
||||||
"model.bin",
|
"model.bin",
|
||||||
"config.json",
|
"config.json",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no model files found",
|
name: "no model files found",
|
||||||
setup: func(dir string) error {
|
setup: func(t *testing.T, root *os.Root) {
|
||||||
// Only create non-model files
|
// Only create non-model files
|
||||||
files := []string{"README.md", "config.json"}
|
files := []string{"README.md", "config.json"}
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("content"), 0o644); err != nil {
|
if err := root.WriteFile(file, []byte("content"), 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: ErrModelNotFound,
|
||||||
expectErrType: ErrModelNotFound,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid content type for pytorch model",
|
name: "invalid content type for pytorch model",
|
||||||
setup: func(dir string) error {
|
setup: func(t *testing.T, root *os.Root) {
|
||||||
// Create pytorch model file with wrong content type (text instead of zip)
|
// Create pytorch model file with wrong content type (text instead of zip)
|
||||||
files := []string{
|
files := []string{
|
||||||
"pytorch_model.bin",
|
"pytorch_model.bin",
|
||||||
@@ -1092,68 +1072,32 @@ func TestFilesForModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
content := []byte("plain text content")
|
content := []byte("plain text content")
|
||||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
if err := root.WriteFile(file, content, 0o644); err != nil {
|
||||||
return err
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: ErrModelNotFound,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
testDir := filepath.Join(tmpDir, tt.name)
|
root, err := os.OpenRoot(t.TempDir())
|
||||||
if err := os.MkdirAll(testDir, 0o755); err != nil {
|
|
||||||
t.Fatalf("Failed to create test directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tt.setup(testDir); err != nil {
|
|
||||||
t.Fatalf("Setup failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
files, err := filesForModel(testDir)
|
|
||||||
|
|
||||||
if tt.wantErr {
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error, but got none")
|
|
||||||
}
|
|
||||||
if tt.expectErrType != nil && err != tt.expectErrType {
|
|
||||||
t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Unexpected error: %v", err)
|
t.Fatalf("Failed to open root: %v", err)
|
||||||
return
|
}
|
||||||
|
defer root.Close()
|
||||||
|
|
||||||
|
tt.setup(t, root)
|
||||||
|
|
||||||
|
files, err := filesForModel(root)
|
||||||
|
if !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Fatalf("want %v error, got %v", tt.wantErr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var relativeFiles []string
|
if diff := cmp.Diff(tt.want, files); diff != "" {
|
||||||
for _, file := range files {
|
t.Errorf("filesForModel() mismatch (-want +got):\n%s", diff)
|
||||||
rel, err := filepath.Rel(testDir, file)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get relative path: %v", err)
|
|
||||||
}
|
|
||||||
relativeFiles = append(relativeFiles, rel)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(relativeFiles) != len(tt.wantFiles) {
|
|
||||||
t.Errorf("Expected %d files, got %d: %v", len(tt.wantFiles), len(relativeFiles), relativeFiles)
|
|
||||||
}
|
|
||||||
|
|
||||||
fileSet := make(map[string]bool)
|
|
||||||
for _, file := range relativeFiles {
|
|
||||||
fileSet[file] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, wantFile := range tt.wantFiles {
|
|
||||||
if !fileSet[wantFile] {
|
|
||||||
t.Errorf("Missing expected file: %s", wantFile)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
package syncmap
|
|
||||||
|
|
||||||
import (
|
|
||||||
"maps"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SyncMap is a simple, generic thread-safe map implementation.
|
|
||||||
type SyncMap[K comparable, V any] struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
m map[K]V
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
|
|
||||||
return &SyncMap[K, V]{
|
|
||||||
m: make(map[K]V),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SyncMap[K, V]) Load(key K) (V, bool) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
val, ok := s.m[key]
|
|
||||||
return val, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SyncMap[K, V]) Store(key K, value V) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.m[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SyncMap[K, V]) Items() map[K]V {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
// shallow copy map items
|
|
||||||
return maps.Clone(s.m)
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user