Compare commits

...

5 Commits

Author SHA1 Message Date
Patrick Devine
7bcdb250b9 fix failing client2 unit tests 2026-04-21 13:56:39 -07:00
Patrick Devine
7bbcd2e6be server: add v2 manifest path
This change adds a new manifests-v2/ path for models created with the
create/pull/copy commands. Under manifests-v2, manifests are stored as
content-addressable blobs, similar to tensor and config files. Each named tag is
a symlink, a hard link, or a copy of that blob, depending on what the file system supports.

Downgrades to older versions of ollama are still possible, but any create/pull/copy
done with the newer version will potentially have its blobs pruned by the older
version.

manifests-v2 also changes the default registry name to `ollama.com` instead of
`registry.ollama.ai`.
2026-04-21 12:05:54 -07:00
Jesse Gross
22d6c817f8 mlxrunner: fuse top-P and top-K into a single sort pass
When both filters are active, avoid paying for a full sort in top-P
and a partial sort in top-K. Single-filter paths are unchanged.
Improves generation throughput on gemma4:e4b by 1.5%.
2026-04-20 17:43:00 -07:00
Jesse Gross
ca01373b28 mlxrunner: use MaxAxis in the min-P sampler
One reduction op instead of Argmax + TakeAlongAxis.
2026-04-20 17:43:00 -07:00
Jesse Gross
24e038d56a mlxrunner: add logprobs support
Match the ollamarunner and OpenAI semantics: raw, full-vocab log-softmax
with the top-K ranked by probability. Skipped on the GPU when the request
doesn't ask for logprobs so decode doesn't pay for it otherwise.
2026-04-20 17:43:00 -07:00
23 changed files with 1549 additions and 386 deletions

View File

@@ -406,10 +406,6 @@ func TestAPIShowModel(t *testing.T) {
}
func TestAPIGenerateLogprobs(t *testing.T) {
if testModel != "" {
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
t.Skip("logprobs not supported by all runners")
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
@@ -523,10 +519,6 @@ func TestAPIGenerateLogprobs(t *testing.T) {
}
func TestAPIChatLogprobs(t *testing.T) {
if testModel != "" {
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
t.Skip("logprobs not supported by all runners")
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

View File

@@ -1,18 +1,23 @@
package manifest
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/types/model"
)
var blobFilenamePattern = regexp.MustCompile(`^sha256-[0-9a-fA-F]{64}$`)
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
@@ -22,6 +27,7 @@ type Manifest struct {
filepath string
fi os.FileInfo
digest string
name model.Name
}
func (m *Manifest) Size() (size int64) {
@@ -36,6 +42,14 @@ func (m *Manifest) Digest() string {
return m.digest
}
// BlobDigest returns the manifest digest in "sha256:<hex>" form, or the empty
// string when no digest has been recorded for the manifest.
func (m *Manifest) BlobDigest() string {
	if m.digest != "" {
		return "sha256:" + m.digest
	}
	return ""
}
// FileInfo returns the os.FileInfo stored on the manifest (the fi field).
func (m *Manifest) FileInfo() os.FileInfo {
return m.fi
}
@@ -59,16 +73,7 @@ func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
}
func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
}
manifests, err := Path()
if err != nil {
return err
}
return PruneDirectory(manifests)
return removeNamedManifestPaths(m.name)
}
func (m *Manifest) RemoveLayers() error {
@@ -80,6 +85,9 @@ func (m *Manifest) RemoveLayers() error {
// Build set of digests still in use by other manifests
inUse := make(map[string]struct{})
for _, other := range ms {
if other.BlobDigest() != "" {
inUse[other.BlobDigest()] = struct{}{}
}
for _, layer := range append(other.Layers, other.Config) {
if layer.Digest != "" {
inUse[layer.Digest] = struct{}{}
@@ -87,20 +95,27 @@ func (m *Manifest) RemoveLayers() error {
}
}
// Remove layers not used by any other manifest
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == "" {
digests := make([]string, 0, len(m.Layers)+2)
digests = append(digests, m.BlobDigest())
for _, layer := range m.Layers {
digests = append(digests, layer.Digest)
}
digests = append(digests, m.Config.Digest)
// Remove manifest and layer blobs not used by any other manifest
for _, digest := range digests {
if digest == "" {
continue
}
if _, used := inUse[layer.Digest]; used {
if _, used := inUse[digest]; used {
continue
}
blob, err := BlobsPath(layer.Digest)
blob, err := BlobsPath(digest)
if err != nil {
return err
}
if err := os.Remove(blob); os.IsNotExist(err) {
slog.Debug("layer does not exist", "digest", layer.Digest)
slog.Debug("blob does not exist", "digest", digest)
} else if err != nil {
return err
}
@@ -114,15 +129,36 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, model.Unqualified(n)
}
manifests, err := Path()
p, root, err := resolveManifestPath(n)
if err != nil {
return nil, err
}
p := filepath.Join(manifests, n.Filepath())
return parseManifestFile(normalizeLogicalName(n), p, root)
}
// ReadManifestData returns the raw bytes of the manifest stored for n.
// The name must be fully qualified; otherwise model.Unqualified(n) is
// returned. The manifest is located with resolveManifestPath and read through
// OpenVerifiedManifest, so any verification error is reported to the caller.
func ReadManifestData(n model.Name) ([]byte, error) {
	if !n.IsFullyQualified() {
		return nil, model.Unqualified(n)
	}

	manifestFile, root, err := resolveManifestPath(n)
	if err != nil {
		return nil, err
	}

	// readManifestPath performs the same open-verify-read sequence.
	return readManifestPath(manifestFile, root)
}
func parseManifestFile(name model.Name, path, root string) (*Manifest, error) {
var m Manifest
f, err := os.Open(p)
f, digest, err := OpenVerifiedManifest(path, root)
if err != nil {
return nil, err
}
@@ -133,35 +169,19 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err
}
sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
if err := json.NewDecoder(f).Decode(&m); err != nil {
return nil, err
}
m.filepath = p
m.filepath = path
m.fi = fi
m.digest = hex.EncodeToString(sha256sum.Sum(nil))
m.digest = digest
m.name = name
return &m, nil
}
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
manifests, err := Path()
if err != nil {
return err
}
p := filepath.Join(manifests, name.Filepath())
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
return err
}
f, err := os.Create(p)
if err != nil {
return err
}
defer f.Close()
m := Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
@@ -169,33 +189,371 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
Layers: layers,
}
return json.NewEncoder(f).Encode(m)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(m); err != nil {
return err
}
return WriteManifestData(name, b.Bytes())
}
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
// WriteManifestData persists raw manifest bytes as a content-addressed blob
// and then points the v2 named manifest path for name at that blob. Once the
// v2 write has succeeded, any legacy named manifest for the same model is
// removed. The name must be fully qualified.
func WriteManifestData(name model.Name, data []byte) error {
	if !name.IsFullyQualified() {
		return model.Unqualified(name)
	}

	digest, err := writeManifestBlob(data)
	if err == nil {
		err = LinkManifest(name, digest)
	}
	if err != nil {
		return err
	}

	// Only drop the legacy entries after the v2 entry is in place.
	return removeLegacyManifestPaths(name)
}
// LinkManifest makes the v2 named manifest path for name reference an
// existing manifest blob identified by digest. The blob must exist and hash
// to digest. Any previous entry at the named path is removed first; the new
// entry is created as a relative symlink when possible, then as a hard link,
// and finally as a byte-for-byte copy for filesystems without link support.
func LinkManifest(name model.Name, digest string) error {
	if !name.IsFullyQualified() {
		return model.Unqualified(name)
	}

	linkPath, err := V2PathForName(name)
	if err != nil {
		return err
	}

	blob, err := BlobsPath(digest)
	if err != nil {
		return err
	}

	// Refuse to link to a blob that is missing or whose contents do not
	// match the requested digest.
	if _, err = os.Stat(blob); err != nil {
		return err
	}
	if err = checkBlobDigest(blob, digest); err != nil {
		return err
	}

	linkDir := filepath.Dir(linkPath)
	if err = os.MkdirAll(linkDir, 0o755); err != nil {
		return err
	}

	// Clear any previous entry so the link calls below do not fail because
	// the destination already exists.
	if err = os.Remove(linkPath); err != nil && !os.IsNotExist(err) {
		return err
	}

	// Preferred: a symlink relative to the manifest's own directory.
	if rel, relErr := filepath.Rel(linkDir, blob); relErr == nil && os.Symlink(rel, linkPath) == nil {
		return nil
	}

	// Next: a hard link to the blob.
	if os.Link(blob, linkPath) == nil {
		return nil
	}

	// Last resort: copy the blob's bytes.
	return copyManifestFile(blob, linkPath)
}
// writeManifestBlob stores data in the blob directory under its sha256 digest
// and returns that digest in "sha256:<hex>" form. If a blob with identical
// contents is already present nothing is written. Otherwise the bytes are
// written to a temporary file in the blob directory and renamed into place.
func writeManifestBlob(data []byte) (string, error) {
	digest := fmt.Sprintf("sha256:%x", sha256.Sum256(data))

	target, err := BlobsPath(digest)
	if err != nil {
		return "", err
	}

	// Fast path: the blob already exists with exactly these bytes.
	if current, readErr := os.ReadFile(target); readErr == nil && bytes.Equal(current, data) {
		return digest, nil
	}

	blobDir, err := BlobsPath("")
	if err != nil {
		return "", err
	}

	tmp, err := os.CreateTemp(blobDir, "sha256-")
	if err != nil {
		return "", err
	}
	tmpPath := tmp.Name()
	// After a successful rename the temp path no longer exists, so this
	// cleanup only matters on the error paths.
	defer os.Remove(tmpPath)

	_, writeErr := tmp.Write(data)
	closeErr := tmp.Close()
	if writeErr != nil {
		return "", writeErr
	}
	if closeErr != nil {
		return "", closeErr
	}

	if err := os.Chmod(tmpPath, 0o644); err != nil {
		return "", err
	}

	if os.Rename(tmpPath, target) == nil {
		return digest, nil
	}

	// The rename failed; remove whatever is at the destination and retry once.
	if rmErr := os.Remove(target); rmErr != nil && !os.IsNotExist(rmErr) {
		return "", rmErr
	}
	if retryErr := os.Rename(tmpPath, target); retryErr != nil {
		return "", retryErr
	}
	return digest, nil
}
// copyManifestFile copies the contents of src to dst. The bytes are first
// written to a temporary ".manifest-*" file in dst's directory, given mode
// 0o644, and then renamed onto dst, so dst is only ever replaced by a
// completely written file.
func copyManifestFile(src, dst string) error {
	source, err := os.Open(src)
	if err != nil {
		return err
	}
	defer source.Close()

	staging, err := os.CreateTemp(filepath.Dir(dst), ".manifest-*")
	if err != nil {
		return err
	}
	stagingPath := staging.Name()
	// Harmless after a successful rename; cleans up on every failure path.
	defer os.Remove(stagingPath)

	_, copyErr := io.Copy(staging, source)
	closeErr := staging.Close()
	if copyErr != nil {
		return copyErr
	}
	if closeErr != nil {
		return closeErr
	}

	if err := os.Chmod(stagingPath, 0o644); err != nil {
		return err
	}
	return os.Rename(stagingPath, dst)
}
// OpenVerifiedManifest opens the named manifest at path, which is expected to
// live under the manifest directory root. On success it returns the open
// file, positioned at offset zero, together with the hex-encoded sha256 of
// its contents (without a "sha256:" prefix). The caller must close the file.
//
// Symlinked manifests must resolve to a file whose basename is sha256-<hex>,
// which is the same file as the blob-store entry for that digest, and whose
// bytes hash to that digest. Regular-file manifests are treated as
// legacy/copy fallback manifests: they must resolve to a location inside
// root and are hashed in place without mutating the local store.
func OpenVerifiedManifest(path, root string) (*os.File, string, error) {
	// Resolve root to an absolute, symlink-free path. target below is always
	// absolute, and filepath.Rel cannot relate an absolute path to a relative
	// base, so a relative root (e.g. a models directory configured as a
	// relative path) would otherwise make every regular-file manifest fail
	// the containment check.
	resolvedRoot, err := evalAbs(root)
	if err != nil {
		return nil, "", err
	}

	// Lstat so that a symlink is reported as a symlink rather than followed.
	info, err := os.Lstat(path)
	if err != nil {
		return nil, "", err
	}

	target, err := evalAbs(path)
	if err != nil {
		return nil, "", err
	}

	if info.Mode()&os.ModeSymlink != 0 {
		base := filepath.Base(target)
		if !blobFilenamePattern.MatchString(base) {
			return nil, "", fmt.Errorf("manifest symlink target %q is not a sha256 blob", target)
		}

		digest := strings.ToLower(strings.TrimPrefix(base, "sha256-"))
		blobPath, err := BlobsPath("sha256:" + digest)
		if err != nil {
			return nil, "", err
		}

		// A correctly named file outside the blob store is not acceptable;
		// the link must land on the blob store's own entry for this digest.
		if !sameFile(target, blobPath) {
			return nil, "", fmt.Errorf("manifest symlink target %q does not match blob %q", target, blobPath)
		}

		f, err := os.Open(path)
		if err != nil {
			return nil, "", err
		}

		if err := checkBlobDigestReader(f, "sha256:"+digest); err != nil {
			f.Close()
			return nil, "", err
		}

		// Hashing consumed the file; rewind so the caller reads from the start.
		if _, err := f.Seek(0, io.SeekStart); err != nil {
			f.Close()
			return nil, "", err
		}

		return f, digest, nil
	}

	if !pathWithin(target, resolvedRoot) {
		return nil, "", fmt.Errorf("manifest path %q resolves outside manifest directory", path)
	}

	f, err := os.Open(path)
	if err != nil {
		return nil, "", err
	}

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		f.Close()
		return nil, "", err
	}

	// Hashing consumed the file; rewind so the caller reads from the start.
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		f.Close()
		return nil, "", err
	}

	digest := fmt.Sprintf("%x", h.Sum(nil))
	return f, digest, nil
}
// MigrateManifestLinks moves legacy named manifests into manifests-v2. This is currently unwired but
// will be added in the future.
func MigrateManifestLinks() (int, error) {
manifests, err := Path()
if err != nil {
return nil, err
return 0, err
}
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
if err != nil {
return nil, err
return 0, err
}
ms := make(map[model.Name]*Manifest)
var migrated int
for _, match := range matches {
fi, err := os.Stat(match)
if err != nil {
return nil, err
return migrated, err
}
if fi.IsDir() {
continue
}
rel, err := filepath.Rel(manifests, match)
if err != nil {
return migrated, fmt.Errorf("%s %w", match, err)
}
n := model.ParseNameFromFilepath(rel)
if !n.IsFullyQualified() {
slog.Warn("bad manifest name", "path", rel)
continue
}
data, err := readManifestPath(match, manifests)
if err != nil {
return migrated, err
}
if err := WriteManifestData(normalizeLogicalName(n), data); err != nil {
return migrated, err
}
migrated++
}
return migrated, nil
}
// readManifestPath opens the manifest at path through OpenVerifiedManifest
// (using root as the manifest directory) and returns its full contents.
func readManifestPath(path, root string) ([]byte, error) {
	file, _, openErr := OpenVerifiedManifest(path, root)
	if openErr != nil {
		return nil, openErr
	}
	defer file.Close()

	contents, readErr := io.ReadAll(file)
	return contents, readErr
}
// pathWithin reports whether path is located strictly below root. It returns
// false when path equals root, when path lies outside root, or when no
// relative path from root to path can be computed.
func pathWithin(path, root string) bool {
	rel, err := filepath.Rel(root, path)
	if err != nil {
		return false
	}

	switch {
	case rel == ".", rel == "..":
		return false
	case strings.HasPrefix(rel, ".."+string(filepath.Separator)):
		// Climbs out of root before descending again.
		return false
	}
	return true
}
// evalAbs converts path to an absolute path and then resolves any symbolic
// links in it.
func evalAbs(path string) (string, error) {
	absolute, err := filepath.Abs(path)
	if err == nil {
		return filepath.EvalSymlinks(absolute)
	}
	return "", err
}
// sameFile reports whether a and b refer to the same underlying file. It
// returns false if either path cannot be stat'ed.
func sameFile(a, b string) bool {
	infoA, errA := os.Stat(a)
	if errA != nil {
		return false
	}

	infoB, errB := os.Stat(b)
	return errB == nil && os.SameFile(infoA, infoB)
}
// checkBlobDigest opens the file at path and verifies its contents against
// digest using checkBlobDigestReader. It returns the open error if the file
// cannot be opened.
func checkBlobDigest(path, digest string) error {
	blob, openErr := os.Open(path)
	if openErr != nil {
		return openErr
	}
	defer blob.Close()

	return checkBlobDigestReader(blob, digest)
}
// checkBlobDigestReader hashes everything readable from r with sha256 and
// compares the result with digest. digest may be written as "sha256:<hex>" or
// "sha256-<hex>" and is compared case-insensitively. It returns the read
// error from r, a "digest mismatch" error when the hashes differ, or nil.
func checkBlobDigestReader(r io.Reader, digest string) error {
	hasher := sha256.New()
	if _, err := io.Copy(hasher, r); err != nil {
		return err
	}

	// Normalize the filename form ("sha256-<hex>") to "sha256:<hex>".
	want := strings.ToLower(strings.Replace(digest, "-", ":", 1))
	if computed := fmt.Sprintf("sha256:%x", hasher.Sum(nil)); computed == want {
		return nil
	}
	return errors.New("digest mismatch")
}
// Manifests collects the manifests found under the v2 manifest root and then
// under the legacy manifest root into a single map keyed by model name. The
// v2 root is scanned first. continueOnError is forwarded to collectManifests
// for each root.
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
	found := make(map[model.Name]*Manifest)

	// Order matters: v2 first, legacy second.
	for _, rootFn := range []func() (string, error){V2Path, Path} {
		root, err := rootFn()
		if err != nil {
			return nil, err
		}
		if err := collectManifests(found, root, continueOnError); err != nil {
			return nil, err
		}
	}

	return found, nil
}
func collectManifests(ms map[model.Name]*Manifest, root string, continueOnError bool) error {
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(root, "*", "*", "*", "*"))
if err != nil {
return err
}
for _, match := range matches {
fi, err := os.Lstat(match)
if err != nil {
return err
}
if !fi.IsDir() {
rel, err := filepath.Rel(manifests, match)
rel, err := filepath.Rel(root, match)
if err != nil {
if !continueOnError {
return nil, fmt.Errorf("%s %w", match, err)
return fmt.Errorf("%s %w", match, err)
}
slog.Warn("bad filepath", "path", match, "error", err)
continue
@@ -204,16 +562,21 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
if !continueOnError {
return nil, fmt.Errorf("%s %w", rel, err)
return fmt.Errorf("invalid manifest name: %s", rel)
}
slog.Warn("bad manifest name", "path", rel)
continue
}
m, err := ParseNamedManifest(n)
n = normalizeLogicalName(n)
if _, ok := ms[n]; ok {
continue
}
m, err := parseManifestFile(n, match, root)
if err != nil {
if !continueOnError {
return nil, fmt.Errorf("%s %w", n, err)
return fmt.Errorf("%s %w", n, err)
}
slog.Warn("bad manifest", "name", n, "error", err)
continue
@@ -223,5 +586,5 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
}
}
return ms, nil
return nil
}

View File

@@ -1,19 +1,23 @@
package manifest
import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/ollama/ollama/types/model"
)
func createManifest(t *testing.T, path, name string) {
func createManifestAtRoot(t *testing.T, path, root, name string) {
t.Helper()
p := filepath.Join(path, "manifests", name)
p := filepath.Join(path, root, name)
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
t.Fatal(err)
}
@@ -29,6 +33,309 @@ func createManifest(t *testing.T, path, name string) {
}
}
// createManifest is a test helper that calls createManifestAtRoot with the
// "manifests" directory as the root for the given name under path.
func createManifest(t *testing.T, path, name string) {
t.Helper()
createManifestAtRoot(t, path, "manifests", name)
}
// TestWriteManifestStoresManifestAsBlob checks that WriteManifest leaves the
// same bytes at the v2 named path and in the blob store, and that the parsed
// manifest reports the matching digest in both hex and "sha256:" forms.
func TestWriteManifestStoresManifestAsBlob(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	name := model.ParseName("example")
	cfg := Layer{
		MediaType: "application/vnd.docker.container.image.v1+json",
		Digest:    "sha256:" + strings.Repeat("a", 64),
		Size:      12,
	}
	if err := WriteManifest(name, cfg, nil); err != nil {
		t.Fatal(err)
	}

	namedPath, err := V2PathForName(name)
	if err != nil {
		t.Fatal(err)
	}
	namedBytes, err := os.ReadFile(namedPath)
	if err != nil {
		t.Fatal(err)
	}

	sum := sha256.Sum256(namedBytes)
	wantHex := fmt.Sprintf("%x", sum)
	wantBlobDigest := "sha256:" + wantHex

	blobPath, err := BlobsPath(wantBlobDigest)
	if err != nil {
		t.Fatal(err)
	}
	storedBytes, err := os.ReadFile(blobPath)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(storedBytes, namedBytes) {
		t.Fatal("manifest path and blob content differ")
	}

	parsed, err := ParseNamedManifest(name)
	if err != nil {
		t.Fatal(err)
	}
	if got := parsed.Digest(); got != wantHex {
		t.Fatalf("digest = %q, want %x", got, sum)
	}
	if got := parsed.BlobDigest(); got != wantBlobDigest {
		t.Fatalf("blob digest = %q, want %q", got, wantBlobDigest)
	}
}
// TestParseNamedManifestLeavesLegacyManifestInPlace checks that reading a
// legacy manifest neither turns it into a symlink nor creates a blob for it.
func TestParseNamedManifestLeavesLegacyManifestInPlace(t *testing.T) {
	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	name := model.ParseName("example")
	createManifest(t, modelsDir, name.Filepath())

	legacyPath, err := PathForName(name)
	if err != nil {
		t.Fatal(err)
	}

	if _, err := ParseNamedManifest(name); err != nil {
		t.Fatal(err)
	}

	// The legacy entry must still be a regular file after the read.
	info, err := os.Lstat(legacyPath)
	if err != nil {
		t.Fatal(err)
	}
	if info.Mode()&os.ModeSymlink != 0 {
		t.Fatal("legacy manifest was converted to a symlink while reading")
	}

	raw, err := os.ReadFile(legacyPath)
	if err != nil {
		t.Fatal(err)
	}

	// No blob for the manifest contents may have been created.
	blobPath, err := BlobsPath(fmt.Sprintf("sha256:%x", sha256.Sum256(raw)))
	if err != nil {
		t.Fatal(err)
	}
	if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
		t.Fatalf("legacy manifest read created blob: %v", err)
	}
}
// TestMigrateManifestLinks checks that MigrateManifestLinks moves a legacy
// manifest into the v2 layout (named path plus matching blob), removes the
// legacy file, and that repeated runs migrate nothing and leave the v2
// manifest content unchanged.
func TestMigrateManifestLinks(t *testing.T) {
	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	name := model.ParseName("example")
	createManifest(t, modelsDir, name.Filepath())

	count, err := MigrateManifestLinks()
	if err != nil {
		t.Fatal(err)
	}
	if count != 1 {
		t.Fatalf("migrated = %d, want 1", count)
	}

	v2Path, err := V2PathForName(name)
	if err != nil {
		t.Fatal(err)
	}
	before, err := os.ReadFile(v2Path)
	if err != nil {
		t.Fatal(err)
	}

	blobPath, err := BlobsPath(fmt.Sprintf("sha256:%x", sha256.Sum256(before)))
	if err != nil {
		t.Fatal(err)
	}
	stored, err := os.ReadFile(blobPath)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(stored, before) {
		t.Fatal("migrated manifest path and blob content differ")
	}

	legacyPath, err := PathForName(name)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := os.Stat(legacyPath); !os.IsNotExist(err) {
		t.Fatalf("legacy manifest still exists: %v", err)
	}

	// A second run has nothing left to migrate.
	count, err = MigrateManifestLinks()
	if err != nil {
		t.Fatal(err)
	}
	if count != 0 {
		t.Fatalf("migrated on second run = %d, want 0", count)
	}

	// A further run must still succeed and leave the v2 manifest untouched.
	if _, err := MigrateManifestLinks(); err != nil {
		t.Fatal(err)
	}
	after, err := os.ReadFile(v2Path)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(after, before) {
		t.Fatal("second migration changed manifest content")
	}
}
// TestRemoveLayersRemovesUnreferencedManifestBlob checks that after a
// manifest is removed, RemoveLayers also deletes the manifest's own blob when
// no other manifest refers to it.
func TestRemoveLayersRemovesUnreferencedManifestBlob(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	name := model.ParseName("example")
	if err := WriteManifest(name, Layer{}, nil); err != nil {
		t.Fatal(err)
	}

	parsed, err := ParseNamedManifest(name)
	if err != nil {
		t.Fatal(err)
	}

	manifestBlob, err := BlobsPath(parsed.BlobDigest())
	if err != nil {
		t.Fatal(err)
	}
	// The blob must exist before removal for the final check to mean anything.
	if _, err := os.Stat(manifestBlob); err != nil {
		t.Fatal(err)
	}

	if err := parsed.Remove(); err != nil {
		t.Fatal(err)
	}
	if err := parsed.RemoveLayers(); err != nil {
		t.Fatal(err)
	}

	if _, err := os.Stat(manifestBlob); !os.IsNotExist(err) {
		t.Fatalf("manifest blob still exists: %v", err)
	}
}
// TestParseNamedManifestRejectsUnsafeSymlinks plants symlinks at the named
// manifest path and checks that ParseNamedManifest refuses targets that are
// not named like a sha256 blob, and targets that are named like one but are
// not the blob store's own file.
func TestParseNamedManifestRejectsUnsafeSymlinks(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	name := model.ParseName("example")
	manifestPath, err := PathForName(name)
	if err != nil {
		t.Fatal(err)
	}
	if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
		t.Fatal(err)
	}

	// plant writes payload to target and replaces the named manifest with a
	// symlink to it, skipping the subtest when symlinks are unavailable.
	plant := func(t *testing.T, target string, payload []byte) {
		t.Helper()
		if err := os.WriteFile(target, payload, 0o644); err != nil {
			t.Fatal(err)
		}
		if err := os.Remove(manifestPath); err != nil && !os.IsNotExist(err) {
			t.Fatal(err)
		}
		if err := os.Symlink(target, manifestPath); err != nil {
			t.Skipf("symlink unavailable: %v", err)
		}
	}

	t.Run("non blob basename", func(t *testing.T) {
		plant(t, filepath.Join(t.TempDir(), "not-a-blob"), []byte(`{"schemaVersion":2}`))

		_, err := ParseNamedManifest(name)
		if err == nil || !strings.Contains(err.Error(), "not a sha256 blob") {
			t.Fatalf("err = %v, want not a sha256 blob", err)
		}
	})

	t.Run("blob basename outside blob store", func(t *testing.T) {
		payload := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json"}`)
		sum := sha256.Sum256(payload)
		plant(t, filepath.Join(t.TempDir(), fmt.Sprintf("sha256-%x", sum)), payload)

		_, err := ParseNamedManifest(name)
		if err == nil || !strings.Contains(err.Error(), "does not match blob") {
			t.Fatalf("err = %v, want does not match blob", err)
		}
	})
}
// TestParseNamedManifestPrefersV2 writes a legacy manifest and then a v2
// manifest for the same name and checks that ParseNamedManifest returns the
// v2 content.
func TestParseNamedManifestPrefersV2(t *testing.T) {
	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	name := model.ParseName("example")

	legacyFile, err := PathForName(name)
	if err != nil {
		t.Fatal(err)
	}
	if err := os.MkdirAll(filepath.Dir(legacyFile), 0o755); err != nil {
		t.Fatal(err)
	}
	legacyJSON := []byte(`{"schemaVersion":2,"mediaType":"legacy"}`)
	if err := os.WriteFile(legacyFile, legacyJSON, 0o644); err != nil {
		t.Fatal(err)
	}

	v2JSON := []byte(`{"schemaVersion":2,"mediaType":"v2"}`)
	if err := WriteManifestData(name, v2JSON); err != nil {
		t.Fatal(err)
	}

	parsed, err := ParseNamedManifest(name)
	if err != nil {
		t.Fatal(err)
	}
	if parsed.MediaType != "v2" {
		t.Fatalf("media type = %q, want %q", parsed.MediaType, "v2")
	}
}
// TestManifestsV2ShadowsLegacy creates a legacy manifest, writes v2 manifest
// data for the same name, and checks that Manifests reports exactly one
// entry for that name carrying the v2 content.
func TestManifestsV2ShadowsLegacy(t *testing.T) {
	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	name := model.ParseName("example")
	createManifest(t, modelsDir, name.Filepath())

	if err := WriteManifestData(name, []byte(`{"schemaVersion":2,"mediaType":"v2"}`)); err != nil {
		t.Fatal(err)
	}

	all, err := Manifests(true)
	if err != nil {
		t.Fatal(err)
	}
	if len(all) != 1 {
		t.Fatalf("manifest count = %d, want 1", len(all))
	}

	// Look the entry up case-insensitively rather than by exact map key.
	var found *Manifest
	for key, candidate := range all {
		if !key.EqualFold(model.ParseName("example")) {
			continue
		}
		found = candidate
		break
	}
	if found == nil {
		t.Fatalf("missing v2 manifest for %s", name)
	}
	if found.MediaType != "v2" {
		t.Fatalf("media type = %q, want %q", found.MediaType, "v2")
	}
}
func TestManifests(t *testing.T) {
cases := map[string]struct {
ps []string

View File

@@ -14,8 +14,23 @@ import (
var ErrInvalidDigestFormat = errors.New("invalid digest format")
// Directory and host names used when locating manifests on disk.
const (
// legacyDirName is the directory, under the models path, returned by Path.
legacyDirName = "manifests"
// v2DirName is the directory, under the models path, returned by V2Path.
v2DirName = "manifests-v2"
// defaultPublicHost is the host that normalizeLogicalName assigns to names on a public host.
defaultPublicHost = "registry.ollama.ai"
// v2CanonicalHost is the host that canonicalV2Name assigns for manifests-v2 on-disk paths.
v2CanonicalHost = "ollama.com"
)
func Path() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
return manifestPath(legacyDirName)
}
// V2Path returns the result of manifestPath for the v2 manifest directory
// (v2DirName).
func V2Path() (string, error) {
return manifestPath(v2DirName)
}
func manifestPath(dir string) (string, error) {
path := filepath.Join(envconfig.Models(), dir)
if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
@@ -25,6 +40,10 @@ func Path() (string, error) {
// PathForName returns the path to the legacy manifest file for a specific
// model name. It delegates to LegacyPathForName and does not consult the v2
// manifest directory.
func PathForName(n model.Name) (string, error) {
return LegacyPathForName(n)
}
func LegacyPathForName(n model.Name) (string, error) {
if !n.IsValid() {
return "", os.ErrNotExist
}
@@ -37,6 +56,162 @@ func PathForName(n model.Name) (string, error) {
return filepath.Join(manifests, n.Filepath()), nil
}
// V2PathForName returns the path of the named manifest for n inside the v2
// manifest directory, using the canonical v2 form of the name. It returns
// os.ErrNotExist when n is not a valid name.
func V2PathForName(n model.Name) (string, error) {
	if !n.IsValid() {
		return "", os.ErrNotExist
	}

	root, err := V2Path()
	if err != nil {
		return "", err
	}

	relative := canonicalV2Name(n).Filepath()
	return filepath.Join(root, relative), nil
}
// ResolvePathForName returns the on-disk manifest path that
// resolveManifestPath selects for n, discarding the manifest root.
func ResolvePathForName(n model.Name) (string, error) {
	resolved, _, err := resolveManifestPath(n)
	return resolved, err
}
// resolveManifestPath locates the manifest for n and returns its path along
// with the manifest root directory it was found under. The v2 location is
// checked first; if nothing exists there, each legacy name candidate is
// checked under the legacy root. os.ErrNotExist is returned when n is invalid
// or no manifest exists in either location.
func resolveManifestPath(n model.Name) (string, string, error) {
	if !n.IsValid() {
		return "", "", os.ErrNotExist
	}

	v2Path, err := V2PathForName(n)
	if err != nil {
		return "", "", err
	}

	// Lstat rather than Stat: the entry counts as present even if it is a
	// symlink whose target is missing.
	switch _, statErr := os.Lstat(v2Path); {
	case statErr == nil:
		root, err := V2Path()
		return v2Path, root, err
	case !os.IsNotExist(statErr):
		return "", "", statErr
	}

	legacyRoot, err := Path()
	if err != nil {
		return "", "", err
	}

	for _, candidate := range legacyNameCandidates(n) {
		candidatePath := filepath.Join(legacyRoot, candidate.Filepath())
		_, statErr := os.Lstat(candidatePath)
		if statErr == nil {
			return candidatePath, legacyRoot, nil
		}
		if !os.IsNotExist(statErr) {
			return "", "", statErr
		}
	}

	return "", "", os.ErrNotExist
}
// removeNamedManifestPaths deletes the v2 named manifest for n and every
// legacy named manifest candidate for n, ignoring paths that do not exist,
// and then prunes both manifest roots. All paths are computed before any
// file is removed.
func removeNamedManifestPaths(n model.Name) error {
	legacyNames := legacyNameCandidates(n)
	targets := make([]string, 0, 1+len(legacyNames))

	v2Target, err := V2PathForName(n)
	if err != nil {
		return err
	}
	targets = append(targets, v2Target)

	for _, legacyName := range legacyNames {
		legacyTarget, err := LegacyPathForName(legacyName)
		if err != nil {
			return err
		}
		targets = append(targets, legacyTarget)
	}

	for _, target := range targets {
		err := os.Remove(target)
		if err == nil || os.IsNotExist(err) {
			continue
		}
		return err
	}

	return pruneManifestRoots()
}
// removeLegacyManifestPaths deletes every legacy named manifest candidate for
// n, ignoring paths that do not exist, and then prunes the legacy manifest
// root.
func removeLegacyManifestPaths(n model.Name) error {
	for _, candidate := range legacyNameCandidates(n) {
		target, err := LegacyPathForName(candidate)
		if err != nil {
			return err
		}

		rmErr := os.Remove(target)
		if rmErr != nil && !os.IsNotExist(rmErr) {
			return rmErr
		}
	}

	root, err := Path()
	if err != nil {
		return err
	}

	pruneErr := PruneDirectory(root)
	if pruneErr == nil || os.IsNotExist(pruneErr) {
		return nil
	}
	return pruneErr
}
// pruneManifestRoots runs PruneDirectory on the legacy manifest root and then
// on the v2 manifest root. A root that does not exist is not an error.
func pruneManifestRoots() error {
	for _, resolveRoot := range []func() (string, error){Path, V2Path} {
		dir, err := resolveRoot()
		if err != nil {
			return err
		}

		pruneErr := PruneDirectory(dir)
		if pruneErr == nil || os.IsNotExist(pruneErr) {
			continue
		}
		return pruneErr
	}
	return nil
}
// normalizeLogicalName rewrites the host of any name on a public host to the
// legacy default host, so that map keys share one identity no matter which
// host appears on disk. Names on other hosts are returned unchanged.
func normalizeLogicalName(n model.Name) model.Name {
	if !isDefaultPublicHost(n.Host) {
		return n
	}
	n.Host = defaultPublicHost
	return n
}
// canonicalV2Name rewrites the host of any name on a public host to the v2
// canonical host, which is the form used for on-disk paths under
// manifests-v2/. Names on other hosts are returned unchanged.
func canonicalV2Name(n model.Name) model.Name {
	if !isDefaultPublicHost(n.Host) {
		return n
	}
	n.Host = v2CanonicalHost
	return n
}
// legacyNameCandidates returns the names under which a legacy manifest for n
// may be stored. The first entry is always n itself. When n is on a public
// host, a second entry carries the other public host (the v2 canonical host
// if n uses the legacy default host, otherwise the legacy default host).
func legacyNameCandidates(n model.Name) []model.Name {
	if !isDefaultPublicHost(n.Host) {
		return []model.Name{n}
	}

	other := n
	if strings.EqualFold(n.Host, defaultPublicHost) {
		other.Host = v2CanonicalHost
	} else {
		other.Host = defaultPublicHost
	}
	return []model.Name{n, other}
}
// isDefaultPublicHost reports whether host equals, ignoring case, either the
// legacy default public host or the v2 canonical host.
func isDefaultPublicHost(host string) bool {
	for _, known := range []string{defaultPublicHost, v2CanonicalHost} {
		if strings.EqualFold(host, known) {
			return true
		}
	}
	return false
}
func BlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"

View File

@@ -411,31 +411,12 @@ func CopyModel(src, dst model.Name) error {
return nil
}
manifests, err := manifest.Path()
data, err := manifest.ReadManifestData(src)
if err != nil {
return err
}
dstpath := filepath.Join(manifests, dst.Filepath())
if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
return err
}
srcpath := filepath.Join(manifests, src.Filepath())
srcfile, err := os.Open(srcpath)
if err != nil {
return err
}
defer srcfile.Close()
dstfile, err := os.Create(dstpath)
if err != nil {
return err
}
defer dstfile.Close()
_, err = io.Copy(dstfile, srcfile)
return err
return manifest.WriteManifestData(dst, data)
}
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
@@ -446,6 +427,10 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
}
for _, manifest := range manifests {
if manifest.BlobDigest() != "" {
delete(deleteMap, manifest.BlobDigest())
}
for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest)
}
@@ -549,11 +534,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := manifest.PathForName(n)
if err != nil {
return err
}
manifestJSON, err := os.ReadFile(manifestPath)
manifestJSON, err := manifest.ReadManifestData(n)
if err != nil {
return err
}
@@ -610,6 +591,14 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if existingMf.Config.Digest != "" {
deleteMap[existingMf.Config.Digest] = struct{}{}
}
if existingMf.BlobDigest() != "" {
digest := existingMf.BlobDigest()
if blob, err := manifest.BlobsPath(digest); err == nil {
if _, err := os.Stat(blob); err == nil {
deleteMap[digest] = struct{}{}
}
}
}
}
if n.ProtocolScheme == "http" && !regOpts.Insecure {
@@ -679,21 +668,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
fn(api.ProgressResponse{Status: "writing manifest"})
fp, err := manifest.PathForName(n)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
if err := manifest.WriteManifestData(n, manifestData); err != nil {
slog.Info(fmt.Sprintf("couldn't write manifest for %s", n.DisplayShortest()))
return err
}
err = os.WriteFile(fp, manifestData, 0o644)
if err != nil {
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
return err
}
slog.Debug("manifest written", "path", fp, "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
slog.Debug("manifest written", "name", n.DisplayShortest(), "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
if !envconfig.NoPrune() && len(deleteMap) > 0 {
fn(api.ProgressResponse{Status: "removing unused layers"})
@@ -776,19 +756,11 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
fp, err := manifest.PathForName(n)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
if err := manifest.WriteManifestData(n, manifestData); err != nil {
return err
}
if err := os.WriteFile(fp, manifestData, 0o644); err != nil {
return err
}
slog.Debug("manifest written", "path", fp, "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
slog.Debug("manifest written", "name", n.DisplayShortest(), "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
return nil
}

View File

@@ -116,6 +116,10 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
proxied, err := func() (bool, error) {
switch r.URL.Path {
case "/api/delete":
if s.Fallback != nil {
s.Fallback.ServeHTTP(rec, r)
return true, nil
}
return false, s.handleDelete(rec, r)
case "/api/pull":
return false, s.handlePull(rec, r)

View File

@@ -1770,13 +1770,15 @@ func Serve(ln net.Listener) error {
return err
}
manifestsPath, err := manifest.Path()
if err != nil {
return err
}
for _, rootFn := range []func() (string, error){manifest.Path, manifest.V2Path} {
manifestsPath, err := rootFn()
if err != nil {
return err
}
if err := manifest.PruneDirectory(manifestsPath); err != nil {
return err
if err := manifest.PruneDirectory(manifestsPath); err != nil && !os.IsNotExist(err) {
return err
}
}
}
}

View File

@@ -109,12 +109,44 @@ func checkFileExists(t *testing.T, p string, expect []string) {
if err != nil {
t.Fatal(err)
}
if strings.HasSuffix(filepath.ToSlash(p), "/blobs/*") {
actual = slices.DeleteFunc(actual, isManifestBlobForTest)
}
if diff := gocmp.Diff(expect, actual, gocmpopts.SortSlices(strings.Compare), gocmpopts.EquateEmpty()); diff != "" {
t.Errorf("file exists mismatch (-want +got):\n%s", diff)
}
}
// checkManifestFiles asserts that the v2 manifest tree holds exactly the tag
// files for the given model names. Each name is parsed with model.ParseName
// and resolved through manifest.V2PathForName; the resulting paths are then
// compared by checkFileExists against a four-level glob rooted at
// <models dir>/manifests-v2.
func checkManifestFiles(t *testing.T, names ...string) {
	t.Helper()

	want := make([]string, 0, len(names))
	for _, n := range names {
		path, err := manifest.V2PathForName(model.ParseName(n))
		if err != nil {
			t.Fatal(err)
		}
		want = append(want, path)
	}

	pattern := filepath.Join(envconfig.Models(), "manifests-v2", "*", "*", "*", "*")
	checkFileExists(t, pattern, want)
}
// isManifestBlobForTest reports whether the file at path decodes as a JSON
// manifest: it needs a non-zero schema version, a non-empty media type, and
// either a config digest or at least one layer. Files that cannot be read or
// do not parse as JSON are reported as not being manifests.
func isManifestBlobForTest(path string) bool {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return false
	}

	var parsed manifest.Manifest
	if json.Unmarshal(raw, &parsed) != nil {
		return false
	}

	if parsed.SchemaVersion == 0 || parsed.MediaType == "" {
		return false
	}
	return parsed.Config.Digest != "" || len(parsed.Layers) > 0
}
func TestCreateFromBin(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -136,9 +168,7 @@ func TestCreateFromBin(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkManifestFiles(t, "test")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
@@ -196,9 +226,7 @@ func TestCreateFromModel(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkManifestFiles(t, "test")
w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test2",
@@ -210,10 +238,7 @@ func TestCreateFromModel(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkManifestFiles(t, "test", "test2")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
@@ -306,9 +331,7 @@ func TestCreateRemovesLayers(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkManifestFiles(t, "test")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
@@ -327,9 +350,7 @@ func TestCreateRemovesLayers(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkManifestFiles(t, "test")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
@@ -357,9 +378,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkManifestFiles(t, "test")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-0a666d113e8e0a3d27e9c7bd136a0bdfb6241037db50729d81568451ebfdbde8"),
@@ -378,9 +397,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkManifestFiles(t, "test")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
@@ -411,9 +428,7 @@ func TestCreateMergeParameters(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkManifestFiles(t, "test")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
@@ -436,10 +451,7 @@ func TestCreateMergeParameters(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkManifestFiles(t, "test", "test2")
// Display contents of each blob in the directory
blobDir := filepath.Join(p, "blobs")
@@ -495,10 +507,7 @@ func TestCreateMergeParameters(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkManifestFiles(t, "test", "test2")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"),
@@ -555,9 +564,7 @@ func TestCreateReplacesMessages(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkManifestFiles(t, "test")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
@@ -589,10 +596,7 @@ func TestCreateReplacesMessages(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkManifestFiles(t, "test", "test2")
// Old layers will not have been pruned
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
@@ -650,9 +654,7 @@ func TestCreateTemplateSystem(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkManifestFiles(t, "test")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-0a04d979734167da3b80811a1874d734697f366a689f3912589b99d2e86e7ad1"),
@@ -850,9 +852,7 @@ func TestCreateLicenses(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkManifestFiles(t, "test")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"),

View File

@@ -42,10 +42,7 @@ func TestDelete(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkManifestFiles(t, "test", "test2")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
@@ -60,9 +57,7 @@ func TestDelete(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkManifestFiles(t, "test2")
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
@@ -76,7 +71,7 @@ func TestDelete(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
checkManifestFiles(t)
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
}
@@ -109,7 +104,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
t.Errorf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
checkManifestFiles(t)
}
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
@@ -129,14 +124,12 @@ func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
})
checkManifestFiles(t, "gpt-oss:20b-cloud")
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
checkManifestFiles(t)
}

View File

@@ -658,11 +658,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
checkManifestList := func() {
t.Helper()
mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
mandir, err := manifest.V2Path()
if err != nil {
t.Fatalf("failed to resolve v2 manifest path: %v", err)
}
var entries []string
t.Logf("dir entries:")
fsys := os.DirFS(mandir)
err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
err = fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
if err != nil {
return err
}
@@ -685,7 +688,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
g := entries[0] // raw path
g = filepath.ToSlash(g)
w := model.ParseName(wantStableName).Filepath()
wp, err := manifest.V2PathForName(model.ParseName(wantStableName))
if err != nil {
t.Fatalf("failed to resolve expected manifest path: %v", err)
}
w, err := filepath.Rel(mandir, wp)
if err != nil {
t.Fatalf("failed to make expected manifest path relative: %v", err)
}
w = filepath.ToSlash(w)
if g != w {
t.Errorf("\ngot: %s\nwant: %s", g, w)

View File

@@ -11,6 +11,8 @@ import (
"strings"
"github.com/ollama/ollama/envconfig"
rootmanifest "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
// ManifestLayer represents a layer in the manifest.
@@ -49,9 +51,7 @@ func DefaultManifestDir() string {
// LoadManifest loads a manifest for the given model name.
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
func LoadManifest(modelName string) (*ModelManifest, error) {
manifestPath := resolveManifestPath(modelName)
data, err := os.ReadFile(manifestPath)
data, err := rootmanifest.ReadManifestData(model.ParseName(modelName))
if err != nil {
return nil, fmt.Errorf("read manifest: %w", err)
}
@@ -67,36 +67,6 @@ func LoadManifest(modelName string) (*ModelManifest, error) {
}, nil
}
// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
// Parse model name into components
// Default: registry.ollama.ai/library/<name>/<tag>
host := "registry.ollama.ai"
namespace := "library"
name := modelName
tag := "latest"
// Handle explicit tag
if idx := strings.LastIndex(name, ":"); idx != -1 {
tag = name[idx+1:]
name = name[:idx]
}
// Handle full path like "host/namespace/name"
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
host = parts[0]
namespace = parts[1]
name = parts[2]
case 2:
namespace = parts[0]
name = parts[1]
}
return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
}
// BlobPath returns the full path to a blob given its digest.
func (m *ModelManifest) BlobPath(digest string) string {
// Convert "sha256:abc123" to "sha256-abc123"

View File

@@ -1,8 +1,12 @@
package manifest
import (
"os"
"path/filepath"
"testing"
rootmanifest "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
func TestTotalTensorSize(t *testing.T) {
@@ -55,3 +59,39 @@ func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
}
}
// TestLoadManifestPrefersV2 writes a manifest for the same model name into
// both the legacy tree and the v2 tree, each with a distinct mediaType, and
// checks that LoadManifest returns the v2 copy.
func TestLoadManifestPrefersV2(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())
	name := model.ParseName("example")

	// seed creates the parent directory and writes body at path.
	seed := func(path, body string) {
		t.Helper()
		if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
			t.Fatal(err)
		}
		if err := os.WriteFile(path, []byte(body), 0o644); err != nil {
			t.Fatal(err)
		}
	}

	legacyPath, err := rootmanifest.PathForName(name)
	if err != nil {
		t.Fatal(err)
	}
	seed(legacyPath, `{"schemaVersion":2,"mediaType":"legacy"}`)

	v2Path, err := rootmanifest.V2PathForName(name)
	if err != nil {
		t.Fatal(err)
	}
	seed(v2Path, `{"schemaVersion":2,"mediaType":"v2"}`)

	m, err := LoadManifest(name.String())
	if err != nil {
		t.Fatal(err)
	}
	if got := m.Manifest.MediaType; got != "v2" {
		t.Fatalf("media type = %q, want %q", got, "v2")
	}
}

View File

@@ -151,22 +151,11 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
}
}
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
type completionRequest struct {
Prompt string `json:"prompt"`
Options *completionOpts `json:"options,omitempty"`
}
type completionOpts struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"`
TopK int `json:"top_k,omitempty"`
RepeatLastN int `json:"repeat_last_n,omitempty"`
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
type CompletionRequest struct {
Prompt string
Options api.Options
Logprobs bool
TopLogprobs int
}
type CompletionResponse struct {
@@ -179,6 +168,8 @@ type CompletionResponse struct {
EvalCount int
EvalDuration time.Duration
Logprobs []llm.Logprob
Error *api.StatusError
}
@@ -203,21 +194,13 @@ func (c *Client) Close() error {
// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
creq := completionRequest{
Prompt: req.Prompt,
creq := CompletionRequest{
Prompt: req.Prompt,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
}
if req.Options != nil {
creq.Options = &completionOpts{
Temperature: req.Options.Temperature,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TopK: req.Options.TopK,
RepeatLastN: req.Options.RepeatLastN,
RepeatPenalty: req.Options.RepeatPenalty,
PresencePenalty: req.Options.PresencePenalty,
FrequencyPenalty: req.Options.FrequencyPenalty,
NumPredict: req.Options.NumPredict,
}
creq.Options = *req.Options
}
body, err := json.Marshal(creq)
@@ -266,6 +249,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
PromptEvalDuration: raw.PromptEvalDuration,
EvalCount: raw.EvalCount,
EvalDuration: raw.EvalDuration,
Logprobs: raw.Logprobs,
}
fn(cresp)

View File

@@ -238,6 +238,9 @@ func (t Array) Float() float64 {
}
func (t Array) Ints() []int {
if dt := t.DType(); dt != DTypeInt32 {
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
}
ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f)
@@ -246,6 +249,9 @@ func (t Array) Ints() []int {
}
func (t Array) Floats() []float32 {
if dt := t.DType(); dt != DTypeFloat32 {
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
}
floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f)

View File

@@ -139,6 +139,12 @@ func (t *Array) Less(other *Array) *Array {
return out
}
// MaxAxis returns a new array filled in by mlx_max_axis applied to t along
// the given axis on the default stream. keepDims is forwarded unchanged.
// NOTE(review): presumably keepDims retains the reduced axis with size 1 —
// confirm against the mlx-c documentation.
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
out := New("MAX_AXIS")
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)

View File

@@ -7,11 +7,15 @@ import (
"fmt"
"log/slog"
"net/http"
"sort"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
)
func prefillChunkSize() int {
@@ -25,17 +29,14 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
mlx.ResetPeakMemory()
ctx := request.Ctx
var (
sample *mlx.Array
nextSample *mlx.Array
)
var sample, nextSample sampler.Result
defer func() {
if request.Sampler != nil {
request.Sampler.Free()
}
mlx.Unpin(sample)
mlx.Unpin(nextSample)
mlx.Unpin(sample.Arrays()...)
mlx.Unpin(nextSample.Arrays()...)
mlx.Sweep()
mlx.ClearCache()
@@ -60,10 +61,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
if request.Options.NumPredict <= 0 {
request.Options.NumPredict = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
}
request.Sampler.ResetHistory(inputs)
@@ -135,40 +136,38 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
mlx.ClearCache()
}
step := func(token *mlx.Array) *mlx.Array {
step := func(token *mlx.Array) sampler.Result {
fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
sample := request.Sampler.Sample(logits)
mlx.Pin(sample)
mlx.Pin(sample.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(sample)
mlx.AsyncEval(sample.Arrays()...)
return sample
}
sample = step(mlx.FromValues(tokens[processed:], total-processed))
var b bytes.Buffer
dec := decoder{tokenizer: r.Tokenizer}
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens {
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
for i := range request.Options.NumPredict {
if err := ctx.Err(); err != nil {
return err
}
request.Sampler.AppendToken(sample)
nextSample = step(sample)
request.Sampler.AppendToken(sample.Token)
nextSample = step(sample.Token)
if i == 0 {
mlx.Eval(sample)
mlx.Eval(sample.Arrays()...)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
output := int32(sample.Int())
output := int32(sample.Token.Int())
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
@@ -177,17 +176,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
break
}
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- CompletionResponse{
Content: r.Decode(output, &b),
}:
if resp, ok := dec.decode(sample); ok {
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- resp:
}
}
mlx.Unpin(sample)
sample = nextSample
nextSample = nil
mlx.Unpin(sample.Arrays()...)
sample, nextSample = nextSample, sampler.Result{}
if i%256 == 0 {
mlx.ClearCache()
@@ -203,13 +201,57 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
}
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
token := r.Tokenizer.Decode([]int32{sample})
// decoder serializes sampled tokens into response chunks, holding bytes
// whose UTF-8 sequence hasn't completed yet and the logprobs that belong
// with those bytes so Content and Logprobs stay aligned when a chunk does
// flush.
type decoder struct {
// tokenizer turns sampled token ids into text.
tokenizer *tokenizer.Tokenizer
// buf holds decoded bytes that have not been emitted as Content yet.
buf bytes.Buffer
// logprobs accumulates entries since the last emitted chunk; decode
// resets it to nil each time a chunk is returned.
logprobs []llm.Logprob
}
if _, err := b.WriteString(token); err != nil {
slog.Error("Failed to write token to buffer", "error", err)
return ""
// decode appends the text of the sampled token to the pending buffer and its
// logprob entry (if any) to the pending logprobs. When the buffer yields a
// non-empty valid UTF-8 prefix, that text is returned together with every
// logprob accumulated so far and the pending logprobs are cleared; otherwise
// it returns an empty response and false, keeping both for a later call.
func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
	id := int32(res.Token.Int())
	d.buf.WriteString(d.tokenizer.Decode([]int32{id}))
	d.logprobs = append(d.logprobs, buildLogprob(res, d.tokenizer.Decode)...)

	if text := flushValidUTF8Prefix(&d.buf); text != "" {
		pending := d.logprobs
		d.logprobs = nil
		return CompletionResponse{Content: text, Logprobs: pending}, true
	}
	return CompletionResponse{}, false
}
func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob {
if sample.Logprob == nil {
return nil
}
tok := func(id int32) string { return decode([]int32{id}) }
out := llm.Logprob{
TokenLogprob: llm.TokenLogprob{
Token: tok(int32(sample.Token.Int())),
Logprob: float64(sample.Logprob.Floats()[0]),
},
}
return flushValidUTF8Prefix(b)
if sample.TopTokens != nil {
ids := sample.TopTokens.Ints()
vals := sample.TopLogprobs.Floats()
pairs := make([]llm.TokenLogprob, len(ids))
for i, id := range ids {
pairs[i] = llm.TokenLogprob{
Token: tok(int32(id)),
Logprob: float64(vals[i]),
}
}
sort.Slice(pairs, func(i, j int) bool {
return pairs[i].Logprob > pairs[j].Logprob
})
out.TopLogprobs = pairs
}
return []llm.Logprob{out}
}

View File

@@ -19,7 +19,7 @@ import (
)
type Request struct {
TextCompletionsRequest
CompletionRequest
Responses chan CompletionResponse
Pipeline func(Request) error
@@ -28,24 +28,6 @@ type Request struct {
Sampler *sample.Sampler
}
type TextCompletionsRequest struct {
Prompt string `json:"prompt"`
Options struct {
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TopK int `json:"top_k"`
RepeatLastN int `json:"repeat_last_n"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead
NumPredict int `json:"num_predict"`
} `json:"options"`
}
type Runner struct {
Model base.Model
Tokenizer *tokenizer.Tokenizer

View File

@@ -0,0 +1,249 @@
//go:build mlx
package sample
import (
"math"
"sort"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// logprobEntry is the (token id, logprob) pair returned by the sampler's
// top-K extraction, used after the test-side descending sort.
type logprobEntry struct {
// id is the token id read from Result.TopTokens.
id int
// logprob is the matching Result.TopLogprobs value widened to float64.
logprob float64
}
// runSampleLogprobs builds a Sampler with Logprobs enabled and TopLogprobs
// set to topK, feeds it logits reshaped to [1, len(logits)], and evaluates
// the result. It returns the sampled token id, that token's logprob, and,
// when topK > 0 and the sampler produced top tokens, the top entries sorted
// by descending logprob (nil otherwise). The sampler is freed and the result
// arrays are unpinned when the helper returns.
func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) {
	t.Helper()

	smp := New(Options{Logprobs: true, TopLogprobs: topK})
	defer func() {
		smp.Free()
		mlx.Sweep()
	}()

	out := smp.Sample(mlx.FromValues(logits, 1, len(logits)))

	// Pin before sweeping so the result arrays survive until they are read.
	mlx.Pin(out.Arrays()...)
	defer mlx.Unpin(out.Arrays()...)
	mlx.Sweep()
	mlx.Eval(out.Arrays()...)

	tokenID := out.Token.Int()
	tokenLP := float64(out.Logprob.Floats()[0])
	if topK <= 0 || out.TopTokens == nil {
		return tokenID, tokenLP, nil
	}

	ids := out.TopTokens.Ints()
	lps := out.TopLogprobs.Floats()
	entries := make([]logprobEntry, len(ids))
	for i := range ids {
		entries[i] = logprobEntry{id: ids[i], logprob: float64(lps[i])}
	}
	sort.Slice(entries, func(a, b int) bool { return entries[a].logprob > entries[b].logprob })

	return tokenID, tokenLP, entries
}
// TestSampleLogprobsBasic checks the sampled token id and the number of
// top-K entries returned, with and without top logprobs requested.
func TestSampleLogprobsBasic(t *testing.T) {
	cases := []struct {
		label   string
		logits  []float32
		topK    int
		wantID  int
		wantLen int
	}{
		{
			label:   "single token without top logprobs",
			logits:  []float32{1.0, 0.5, 0.3, 0.1},
			topK:    0,
			wantID:  0,
			wantLen: 0,
		},
		{
			label:   "single token with top logprobs",
			logits:  []float32{1.0, 0.5, 0.3, 0.1},
			topK:    3,
			wantID:  0,
			wantLen: 3,
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			gotID, _, gotTop := runSampleLogprobs(t, tc.logits, tc.topK)
			if gotID != tc.wantID {
				t.Errorf("selected = %d, want %d", gotID, tc.wantID)
			}
			if n := len(gotTop); n != tc.wantLen {
				t.Errorf("top-K length = %d, want %d", n, tc.wantLen)
			}
		})
	}
}
// TestSampleLogprobsNumericalStability feeds very large logits and checks
// that every reported logprob is finite and that the sorted top entries are
// in non-increasing order.
func TestSampleLogprobsNumericalStability(t *testing.T) {
	_, selLP, top := runSampleLogprobs(t, []float32{1000.0, 999.0, 998.0}, 3)

	notFinite := func(v float64) bool { return math.IsInf(v, 0) || math.IsNaN(v) }

	if notFinite(selLP) {
		t.Errorf("selected logprob is not finite: %f", selLP)
	}
	for i := range top {
		if lp := top[i].logprob; notFinite(lp) {
			t.Errorf("top[%d] logprob is not finite: %f", i, lp)
		}
	}
	for i := 1; i < len(top); i++ {
		prev, cur := top[i-1].logprob, top[i].logprob
		if cur > prev {
			t.Errorf("top logprobs not descending: %f > %f", cur, prev)
		}
	}
}
// TestSampleLogprobsProbabilityCorrectness checks, for several logit
// patterns, that no logprob is positive, that uniform logits give probability
// 1/vocab for the selected token, that the sorted top entries are
// non-increasing, and that the selected token appears in the full top-K with
// the same logprob.
func TestSampleLogprobsProbabilityCorrectness(t *testing.T) {
	cases := []struct {
		label  string
		logits []float32
	}{
		{"uniform", []float32{1.0, 1.0, 1.0, 1.0}},
		{"different", []float32{2.0, 1.0, 0.5, 0.1}},
		{"negative", []float32{-1.0, -2.0, -3.0, -4.0}},
		{"mixed", []float32{5.0, -5.0, 0.0, 2.5}},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			vocab := len(tc.logits)
			chosen, chosenLP, ranked := runSampleLogprobs(t, tc.logits, vocab)

			if chosenLP > 0 {
				t.Errorf("selected logprob should be <= 0, got %f", chosenLP)
			}
			for i := range ranked {
				if lp := ranked[i].logprob; lp > 0 {
					t.Errorf("top[%d] logprob should be <= 0, got %f", i, lp)
				}
			}

			if tc.label == "uniform" {
				want := 1.0 / float64(vocab)
				if got := math.Exp(chosenLP); math.Abs(got-want) > 1e-6 {
					t.Errorf("uniform logits: selected prob = %f, want %f", got, want)
				}
			}

			for i := 1; i < len(ranked); i++ {
				prev, cur := ranked[i-1].logprob, ranked[i].logprob
				if cur > prev {
					t.Errorf("top logprobs not descending at %d: %f > %f",
						i, cur, prev)
				}
			}

			// Locate the first top-K entry for the selected token.
			match := -1
			for i := range ranked {
				if ranked[i].id == chosen {
					match = i
					break
				}
			}
			if match < 0 {
				t.Errorf("selected token %d not present in top-K", chosen)
				return
			}
			if lp := ranked[match].logprob; math.Abs(lp-chosenLP) > 1e-6 {
				t.Errorf("selected logprob mismatch: selLP=%f top=%f", chosenLP, lp)
			}
		})
	}
}
// TestSampleLogprobsSoftmaxCorrectness requests top-K over the whole
// vocabulary and checks that exp(logprob) lies in [0,1] for every entry and
// that the probabilities sum to 1 within 1e-5.
func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) {
	cases := []struct {
		label  string
		logits []float32
	}{
		{"small vocabulary", []float32{1.0, 2.0, 3.0}},
		{"large differences", []float32{10.0, 0.0, -10.0}},
		{"all equal", []float32{5.0, 5.0, 5.0, 5.0, 5.0}},
		{"very large values", []float32{500.0, 499.0, 498.0}},
		{"very small values", []float32{-500.0, -499.0, -498.0}},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			vocab := len(tc.logits)
			_, _, ranked := runSampleLogprobs(t, tc.logits, vocab)
			if len(ranked) != vocab {
				t.Fatalf("top-K length = %d, want %d", len(ranked), vocab)
			}

			total := 0.0
			for i := range ranked {
				prob := math.Exp(ranked[i].logprob)
				if prob < 0 || prob > 1 {
					t.Errorf("token %d: probability %f out of [0,1]", ranked[i].id, prob)
				}
				total += prob
			}
			if math.Abs(total-1.0) > 1e-5 {
				t.Errorf("probabilities sum = %f, want 1.0", total)
			}
		})
	}
}
// TestSampleLogprobsSelectedTokenCorrectness checks that greedy sampling
// selects the argmax of the logits and that the first entry of the sorted
// top-K is that same token with the same logprob.
func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) {
	logits := []float32{3.0, 1.0, 2.0, 0.5}

	// Reference argmax computed on the host.
	maxIdx := 0
	for i, v := range logits {
		if v > logits[maxIdx] {
			maxIdx = i
		}
	}

	selected, selLP, top := runSampleLogprobs(t, logits, len(logits))
	if selected != maxIdx {
		t.Errorf("selected = %d, want argmax %d", selected, maxIdx)
	}

	// Guard the top[0] accesses below: an empty top-K should be reported as
	// a test failure rather than crash the test binary with an index panic.
	if len(top) == 0 {
		t.Fatalf("top-K is empty, want %d entries", len(logits))
	}
	if top[0].id != maxIdx {
		t.Errorf("top[0].id = %d, want argmax %d", top[0].id, maxIdx)
	}
	if math.Abs(top[0].logprob-selLP) > 1e-6 {
		t.Errorf("top[0].logprob = %f, want selected %f", top[0].logprob, selLP)
	}
}
// TestSampleLogprobsTopKOrdering uses logits whose rank order differs from
// their index order and checks that the sorted top-K lists token ids in the
// expected rank order with non-increasing logprobs.
func TestSampleLogprobsTopKOrdering(t *testing.T) {
	// Logits chosen so argmax order differs from index order.
	logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}
	wantOrder := []int{1, 3, 4, 0, 2}

	_, _, ranked := runSampleLogprobs(t, logits, len(logits))
	if got, want := len(ranked), len(wantOrder); got != want {
		t.Fatalf("top-K length = %d, want %d", got, want)
	}

	for pos := range ranked {
		if id := ranked[pos].id; id != wantOrder[pos] {
			t.Errorf("top[%d].id = %d, want %d", pos, id, wantOrder[pos])
		}
	}
	for pos := 1; pos < len(ranked); pos++ {
		prev, cur := ranked[pos-1].logprob, ranked[pos].logprob
		if cur > prev {
			t.Errorf("top[%d].logprob (%f) > top[%d].logprob (%f)",
				pos, cur, pos-1, prev)
		}
	}
}

View File

@@ -8,7 +8,7 @@ import (
type Transform func(*Sampler, *mlx.Array) *mlx.Array
type Sampler struct {
type Options struct {
Temperature float32
TopP float32
MinP float32
@@ -18,45 +18,66 @@ type Sampler struct {
PresencePenalty float32
FrequencyPenalty float32
// Logprobs causes Sample to populate Result.Logprob with the selected
// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
Logprobs bool
TopLogprobs int
}
type Sampler struct {
Options
history *mlx.Array
historyLen int
transforms []Transform
}
func New(temp, top_p, min_p float32, top_k, repeatLastN int, repeatPenalty, presencePenalty, frequencyPenalty float32) *Sampler {
if repeatPenalty <= 0 {
repeatPenalty = 1
// Result bundles the outputs of one decode step. The logprob tensors are
// populated only when the sampler is configured to report them: Sample
// always sets Token, sets Logprob when Options.Logprobs is true, and sets
// TopTokens and TopLogprobs together when TopLogprobs is also > 0.
type Result struct {
Token *mlx.Array // sampled token id, shape [B]
Logprob *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs
TopTokens *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0
TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0
}
// Arrays returns the tensor fields as a slice so callers can drive the mlx
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
// fields stay nil; the mlx helpers skip them.
//
// The slice always has four elements, in the fixed order Token, Logprob,
// TopTokens, TopLogprobs.
func (r Result) Arrays() []*mlx.Array {
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
}
func New(opts Options) *Sampler {
if opts.RepeatPenalty <= 0 {
opts.RepeatPenalty = 1
}
s := &Sampler{
Temperature: temp,
TopP: top_p,
MinP: min_p,
TopK: top_k,
RepeatLastN: repeatLastN,
RepeatPenalty: repeatPenalty,
PresencePenalty: presencePenalty,
FrequencyPenalty: frequencyPenalty,
}
s := &Sampler{Options: opts}
var transforms []Transform
if s.usesHistory() {
transforms = append(transforms, penalty)
}
if top_p > 0 && top_p < 1 {
transforms = append(transforms, topP)
}
if min_p != 0 {
transforms = append(transforms, minP)
}
if top_k > 0 {
hasTopP := opts.TopP > 0 && opts.TopP < 1
hasTopK := opts.TopK > 0
switch {
case hasTopP:
// topKTopP always does a full descending sort for the top-P
// cumulative mask and opportunistically masks top-K during the
// same pass when it is also configured.
transforms = append(transforms, topKTopP)
case hasTopK:
// Argpartition (partial sort) is cheaper than a full sort.
transforms = append(transforms, topK)
}
if temp == 0 {
if opts.MinP != 0 {
transforms = append(transforms, minP)
}
if opts.Temperature == 0 {
transforms = append(transforms, greedy)
} else {
transforms = append(transforms, temperature)
@@ -123,76 +144,121 @@ func (s *Sampler) Free() {
s.setHistory(nil, 0)
}
func (s *Sampler) Sample(logits *mlx.Array) *mlx.Array {
// Sample runs the configured transform chain on the raw per-token logits
// and returns the sampled token id plus, when configured, the reported
// log-probability tensors for the selected token and the top-K tokens.
func (s *Sampler) Sample(logits *mlx.Array) Result {
scores := logits
for _, transform := range s.transforms {
logits = transform(s, logits)
scores = transform(s, scores)
}
return logits
}
res := Result{Token: scores}
func greedy(_ *Sampler, logits *mlx.Array) *mlx.Array {
return logits.Argmax(-1, false)
}
func temperature(s *Sampler, logits *mlx.Array) *mlx.Array {
return mlx.DivScalar(logits, s.Temperature).Categorical(-1)
}
func topP(s *Sampler, logits *mlx.Array) *mlx.Array {
if s.TopP <= 0 || s.TopP >= 1 {
return logits
if s.Logprobs {
// Compute log_softmax in fp32 and subtract the max before
// logsumexp so the final subtraction stays on small values.
// Otherwise it cancels two large numbers and loses precision.
lp := logits.AsType(mlx.DTypeFloat32)
lp = lp.Subtract(lp.MaxAxis(-1, true))
lp = lp.Subtract(lp.Logsumexp(true))
res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1)
if k := s.TopLogprobs; k > 0 {
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
k = vocab
}
// Argpartition on the negated values places the K largest
// (unsorted) in positions [0:K].
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
res.TopTokens = idx.AsType(mlx.DTypeInt32)
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
}
}
return res
}
order := logits.Negative().ArgsortAxis(-1)
sortedLogits := logits.TakeAlongAxis(order, -1)
sortedProbs := mlx.SoftmaxAxis(sortedLogits, -1, true)
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
// greedy is the deterministic terminal transform: it returns the argmax of
// scores along the last axis. The sampler argument is unused and only present
// so the function matches the Transform signature.
func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array {
	const lastAxis = -1
	// The second argument is presumably keepDims; false as in the original.
	best := scores.Argmax(lastAxis, false)
	return best
}
// temperature is the stochastic terminal transform: it divides scores by
// s.Temperature and then calls Categorical over the last axis (presumably
// drawing a token index from the resulting distribution — confirm against the
// mlx wrapper). New only installs it when Temperature is non-zero.
func temperature(s *Sampler, scores *mlx.Array) *mlx.Array {
	scaled := mlx.DivScalar(scores, s.Temperature)
	return scaled.Categorical(-1)
}
// topKTopP applies top-P in a descending sort pass and, when top-K is also
// configured, masks any surviving value below the K-th largest in the same
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only
// case uses a cheaper partial sort via the topK transform.
func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array {
vocab := scores.Dim(scores.NumDims() - 1)
applyTopK := s.TopK > 0 && s.TopK < vocab
order := scores.Negative().ArgsortAxis(-1)
sorted := scores.TakeAlongAxis(order, -1)
negInf := mlx.FromValue(float32(math.Inf(-1)))
// Top-P: in descending order, keep tokens whose exclusive cumulative
// probability is still below s.TopP.
probs := mlx.SoftmaxAxis(sorted, -1, true)
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
filtered := mlx.Where(keep, sortedLogits, mlx.FromValue(float32(math.Inf(-1))))
return logits.PutAlongAxis(order, filtered, -1)
}
sorted = mlx.Where(keep, sorted, negInf)
func minP(s *Sampler, logits *mlx.Array) *mlx.Array {
if s.MinP <= 0 || s.MinP > 1 {
return logits
out := scores.PutAlongAxis(order, sorted, -1)
// Top-K: sorted is already in descending order, so positions [K, V)
// are the ones to drop. Scatter -inf through their original-layout
// indices (order[K:]). Positional (not value-based) so exactly K
// tokens survive — ties at the K-th logit get broken by the sort
// order rather than promoted through the filter.
if applyTopK {
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
out = out.PutAlongAxis(dropOrder, negInf, -1)
}
maxLogits := logits.TakeAlongAxis(logits.Argmax(-1, true), -1)
minLogits := mlx.AddScalar(maxLogits, float32(math.Log(float64(s.MinP))))
return out
}
func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
if s.MinP <= 0 || s.MinP > 1 {
return scores
}
maxScore := scores.MaxAxis(-1, true)
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP))))
return mlx.Where(
logits.Less(minLogits),
scores.Less(threshold),
mlx.FromValue(float32(math.Inf(-1))),
logits,
scores,
)
}
func topK(s *Sampler, logits *mlx.Array) *mlx.Array {
func topK(s *Sampler, scores *mlx.Array) *mlx.Array {
if s.TopK <= 0 {
return logits
return scores
}
vocab := logits.Dim(logits.NumDims() - 1)
vocab := scores.Dim(scores.NumDims() - 1)
if s.TopK >= vocab {
return logits
return scores
}
mask := logits.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
return logits.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}
func penalty(s *Sampler, logits *mlx.Array) *mlx.Array {
func penalty(s *Sampler, scores *mlx.Array) *mlx.Array {
if s.historyLen == 0 {
return logits
return scores
}
tokenIndices := s.history
if logits.NumDims() > 1 {
if scores.NumDims() > 1 {
tokenIndices = tokenIndices.ExpandDims(0)
}
if s.RepeatPenalty != 1 || s.PresencePenalty != 0 {
adjusted := logits.TakeAlongAxis(tokenIndices, -1)
adjusted := scores.TakeAlongAxis(tokenIndices, -1)
if s.RepeatPenalty != 1 {
factor := mlx.Where(
adjusted.Less(mlx.FromValue(float32(0))),
@@ -204,12 +270,12 @@ func penalty(s *Sampler, logits *mlx.Array) *mlx.Array {
if s.PresencePenalty != 0 {
adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty)
}
logits = logits.PutAlongAxis(tokenIndices, adjusted, -1)
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
}
if s.FrequencyPenalty != 0 {
logits = logits.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1)
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1)
}
return logits
return scores
}

View File

@@ -10,8 +10,7 @@ import (
)
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
// RepeatLastN = 1, PresencePenalty = 6
s := New(0, 0, 0, 0, 1, 1, 6, 0)
s := New(Options{RepeatLastN: 1, PresencePenalty: 6})
defer func() {
s.Free()
mlx.Sweep()
@@ -21,7 +20,7 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits)
got := s.Sample(logits).Token
mlx.Eval(got)
// logits will be [0, -1, 4] after the penalty
@@ -33,7 +32,7 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
}
func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
s := New(0, 0, 0, 0, 1, 2, 0, 0)
s := New(Options{RepeatLastN: 1, RepeatPenalty: 2})
defer func() {
s.Free()
mlx.Sweep()
@@ -42,7 +41,7 @@ func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
s.ResetHistory([]int32{1})
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits)
got := s.Sample(logits).Token
mlx.Eval(got)
// token 1 is repeated and positive, so 5 / 2 falls below token 2.
@@ -53,7 +52,7 @@ func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
}
func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
s := New(0, 0, 0, 0, 4, 1, 0, 2)
s := New(Options{RepeatLastN: 4, FrequencyPenalty: 2})
defer func() {
s.Free()
mlx.Sweep()
@@ -62,7 +61,7 @@ func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
s.ResetHistory([]int32{1, 1})
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits)
got := s.Sample(logits).Token
mlx.Eval(got)
// token 1 appears twice, so 5 - (2 * 2) falls below token 2.
@@ -73,7 +72,7 @@ func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
}
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
s := New(0, 0, 0.5, 0, 0, 1, 0, 0)
s := New(Options{MinP: 0.5})
defer func() {
s.Free()
mlx.Sweep()

View File

@@ -2,7 +2,6 @@ package mlxrunner
import (
"bytes"
"cmp"
"context"
"encoding/json"
"flag"
@@ -87,25 +86,25 @@ func Execute(args []string) error {
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan CompletionResponse)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
if err := json.NewDecoder(r.Body).Decode(&request.CompletionRequest); err != nil {
slog.Error("Failed to decode request", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New(
request.Options.Temperature,
request.Options.TopP,
request.Options.MinP,
request.Options.TopK,
request.Options.RepeatLastN,
request.Options.RepeatPenalty,
request.Options.PresencePenalty,
request.Options.FrequencyPenalty,
)
request.Sampler = sample.New(sample.Options{
Temperature: request.Options.Temperature,
TopP: request.Options.TopP,
MinP: request.Options.MinP,
TopK: request.Options.TopK,
RepeatLastN: request.Options.RepeatLastN,
RepeatPenalty: request.Options.RepeatPenalty,
PresencePenalty: request.Options.PresencePenalty,
FrequencyPenalty: request.Options.FrequencyPenalty,
Logprobs: request.Logprobs,
TopLogprobs: request.TopLogprobs,
})
var cancel context.CancelFunc
request.Ctx, cancel = context.WithCancel(r.Context())

View File

@@ -144,6 +144,8 @@ func TestRouterForwardMatchesLegacy(t *testing.T) {
gotScores, gotInds := r.Forward(x, cfg)
wantScores, wantInds := legacyRouterForward(r, x, cfg)
gotInds = gotInds.AsType(mlx.DTypeInt32)
wantInds = wantInds.AsType(mlx.DTypeInt32)
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {

View File

@@ -169,8 +169,8 @@ func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
mlx.Eval(dequantizedWeight)
qOut := ql.Forward(input)
dOut := NewLinear(dequantizedWeight, nil).Forward(input)
qOut := ql.Forward(input).AsType(mlx.DTypeFloat32)
dOut := NewLinear(dequantizedWeight, nil).Forward(input).AsType(mlx.DTypeFloat32)
mlx.Eval(qOut, dOut)
got := qOut.Floats()