pull/push: refine safetensors (#14946)

* pull: refine safetensors pull

 - Body drain in resolve() — drain response body before close so Go's HTTP
   client can reuse TCP connections instead of opening a new one per blob
   (1,075 extra TCP+TLS handshakes eliminated)
 - Skip speed recording for tiny blobs (<100KB) — prevents
   HTTP-overhead-dominated transfer times from poisoning the median, which the
   stall detector uses to cancel "too slow" downloads
 - Resume support for large blobs (>=64MB) — on failure, preserves partial .tmp
   files; on retry, re-hashes existing data and sends Range header to download
   only remaining bytes; gracefully falls back to full download if server returns
   200 instead of 206; SHA256 verification catches corrupt partials

* harden push

- Prevents killing TCP connections after every request.
- Stronger backoff to handle server back-pressure and rate limiting
- Larger buffered reads to improve safetensors upload performance
- Better error message handling from server
- Handle 201 if server says blob exists
- Fix progress reporting on already uploaded blobs
- Trace logging to help troubleshoot and tune going forward

* review comments

* review comments
This commit is contained in:
Daniel Hiltgen
2026-04-08 14:15:39 -07:00
committed by GitHub
parent d17f482d50
commit b5918f9785
4 changed files with 592 additions and 68 deletions

View File

@@ -117,14 +117,25 @@ func (d *downloader) download(ctx context.Context, blob Blob) error {
start := time.Now()
n, err := d.downloadOnce(ctx, blob)
if err == nil {
if s := time.Since(start).Seconds(); s > 0 {
d.speeds.record(float64(blob.Size) / s)
// Skip speed recording for tiny blobs — their transfer time is
// dominated by HTTP overhead, not throughput, and would pollute
// the median used for stall detection.
if blob.Size >= smallBlobSpeedThreshold {
if s := time.Since(start).Seconds(); s > 0 {
d.speeds.record(float64(blob.Size) / s)
}
}
return nil
}
d.progress.add(-n) // rollback
// Preserve partial .tmp files for large blobs to enable resume
if blob.Size < resumeThreshold {
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
os.Remove(dest + ".tmp")
}
switch {
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
return err
@@ -153,55 +164,106 @@ func (d *downloader) downloadOnce(ctx context.Context, blob Blob) (int64, error)
return 0, err
}
// Check for existing partial .tmp file for resume
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
tmp := dest + ".tmp"
var existingSize int64
if blob.Size >= resumeThreshold {
if fi, statErr := os.Stat(tmp); statErr == nil {
if fi.Size() < blob.Size {
existingSize = fi.Size()
} else if fi.Size() > blob.Size {
// .tmp larger than expected — discard
os.Remove(tmp)
}
// fi.Size() == blob.Size handled in save (hash check + rename)
}
}
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
req.Header.Set("User-Agent", d.userAgent)
// Add auth only for same-host (not CDN)
if u.Host == baseURL.Host && *d.token != "" {
req.Header.Set("Authorization", "Bearer "+*d.token)
}
if existingSize > 0 {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existingSize))
}
resp, err := d.client.Do(req)
if err != nil {
return 0, err
}
defer resp.Body.Close()
defer func() { io.Copy(io.Discard, resp.Body); resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
switch resp.StatusCode {
case http.StatusOK:
// Full response — reset any partial state
existingSize = 0
case http.StatusPartialContent:
// Resume succeeded
default:
return 0, fmt.Errorf("status %d", resp.StatusCode)
}
return d.save(ctx, blob, resp.Body)
return d.save(ctx, blob, resp.Body, existingSize)
}
func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader) (int64, error) {
func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader, existingSize int64) (int64, error) {
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
tmp := dest + ".tmp"
os.MkdirAll(filepath.Dir(dest), 0o755)
f, err := os.Create(tmp)
if err != nil {
return 0, err
h := sha256.New()
var f *os.File
var err error
if existingSize > 0 {
// Resume — re-hash existing partial data, then append
f, err = os.OpenFile(tmp, os.O_RDWR, 0o644)
if err != nil {
// Can't open partial file, start fresh
existingSize = 0
} else {
// Hash the existing data
if _, hashErr := io.CopyN(h, f, existingSize); hashErr != nil {
f.Close()
os.Remove(tmp)
existingSize = 0
} else {
// Report resumed bytes as progress
d.progress.add(existingSize)
}
}
}
if existingSize == 0 {
f, err = os.Create(tmp)
if err != nil {
return 0, err
}
setSparse(f)
}
defer f.Close()
setSparse(f)
h := sha256.New()
n, err := d.copy(ctx, f, r, h)
if err != nil {
os.Remove(tmp)
return n, err
// Don't remove .tmp here — download() handles cleanup based on blob size
return existingSize + n, err
}
f.Close()
if got := fmt.Sprintf("sha256:%x", h.Sum(nil)); got != blob.Digest {
os.Remove(tmp)
return n, fmt.Errorf("digest mismatch")
return existingSize + n, fmt.Errorf("digest mismatch")
}
if n != blob.Size {
totalWritten := existingSize + n
if totalWritten != blob.Size {
os.Remove(tmp)
return n, fmt.Errorf("size mismatch")
return totalWritten, fmt.Errorf("size mismatch")
}
return n, os.Rename(tmp, dest)
return totalWritten, os.Rename(tmp, dest)
}
func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h io.Writer) (int64, error) {
@@ -261,6 +323,9 @@ func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h i
}
}
// resolve follows redirects to find the final download URL.
// Uses GET (not HEAD) because registries may return 200 for HEAD without
// redirecting to CDN, while GET triggers the actual CDN redirect.
func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, error) {
u, _ := url.Parse(rawURL)
for range 10 {
@@ -274,6 +339,8 @@ func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, erro
if err != nil {
return nil, err
}
// Drain body before close to enable HTTP connection reuse
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
switch resp.StatusCode {

View File

@@ -11,7 +11,8 @@
// Key simplifications for many-small-blob workloads:
//
// - Whole-blob transfers: No part-based chunking. Each blob downloads/uploads as one unit.
// - No resume: If a transfer fails, it restarts from scratch (fine for small blobs).
// - Resume for large blobs: Blobs >= 64MB preserve partial .tmp files on failure
// and use HTTP Range requests on retry. Small blobs restart from scratch.
// - Inline hashing: SHA256 computed during streaming, not asynchronously after parts complete.
// - Stall and speed detection: Cancels on no data (stall) or speed below 10% of median.
//
@@ -99,6 +100,14 @@ const (
DefaultUploadConcurrency = 32
maxRetries = 6
defaultUserAgent = "ollama-transfer/1.0"
// resumeThreshold is the minimum blob size for resume support.
// Only blobs above this size keep partial .tmp files on failure.
resumeThreshold = 64 << 20 // 64 MB
// smallBlobSpeedThreshold is the size below which speed samples are skipped,
// since their transfer time is dominated by HTTP overhead, not throughput.
smallBlobSpeedThreshold = 100 << 10 // 100 KB
)
var errMaxRetriesExceeded = errors.New("max retries exceeded")
@@ -134,6 +143,11 @@ func (p *progressTracker) add(n int64) {
return
}
completed := p.completed.Add(n)
defer func() {
if r := recover(); r != nil {
slog.Debug("progress callback panic (likely closed channel)", "recovered", r)
}
}()
p.callback(completed, p.total)
}

View File

@@ -1775,3 +1775,389 @@ func TestThroughput(t *testing.T) {
t.Logf("Warning: total time %v exceeds 500ms target", dlElapsed+ulElapsed)
}
}
// ==================== Resume Tests ====================
// TestResumeFromPartialFile verifies that a pre-existing partial .tmp file
// for a blob at or above resumeThreshold causes the client to issue a ranged
// request and that the completed file matches the expected digest.
func TestResumeFromPartialFile(t *testing.T) {
	// Blob must be >= resumeThreshold so the resume path engages.
	blobSize := resumeThreshold + 1024
	payload := make([]byte, blobSize)
	for i := range payload {
		payload[i] = byte((i * 13) % 256)
	}
	sum := sha256.Sum256(payload)
	digest := fmt.Sprintf("sha256:%x", sum)
	blob := Blob{Digest: digest, Size: int64(blobSize)}

	var (
		mu          sync.Mutex
		rangeHeader string
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		mu.Lock()
		rangeHeader = r.Header.Get("Range")
		mu.Unlock()
		if rng := r.Header.Get("Range"); rng != "" {
			// Parse "bytes=N-"
			var start int64
			fmt.Sscanf(rng, "bytes=%d-", &start)
			if start > 0 && start < int64(blobSize) {
				w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
				w.WriteHeader(http.StatusPartialContent)
				w.Write(payload[start:])
				return
			}
		}
		w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
	}))
	defer server.Close()

	clientDir := t.TempDir()
	// Seed a partial .tmp containing the first half of the blob.
	partialSize := blobSize / 2
	dest := filepath.Join(clientDir, digestToPath(digest))
	os.MkdirAll(filepath.Dir(dest), 0o755)
	os.WriteFile(dest+".tmp", payload[:partialSize], 0o644)

	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	if err != nil {
		t.Fatalf("Resume download failed: %v", err)
	}

	// The client should have requested exactly the missing suffix.
	mu.Lock()
	if rangeHeader == "" {
		t.Error("Expected Range header for resume, got none")
	} else if expected := fmt.Sprintf("bytes=%d-", partialSize); rangeHeader != expected {
		t.Errorf("Range header = %q, want %q", rangeHeader, expected)
	}
	mu.Unlock()

	// Completed file must be full-size and hash to the expected digest.
	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	if len(finalData) != blobSize {
		t.Errorf("Final file size = %d, want %d", len(finalData), blobSize)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch")
	}
}
// TestResumeCorruptPartialFile verifies recovery when the partial .tmp holds
// wrong bytes: the resumed attempt fails SHA256 verification, the partial is
// discarded, and a retry re-downloads the blob in full.
func TestResumeCorruptPartialFile(t *testing.T) {
	blobSize := resumeThreshold + 1024
	payload := make([]byte, blobSize)
	for i := range payload {
		payload[i] = byte((i * 13) % 256)
	}
	sum := sha256.Sum256(payload)
	digest := fmt.Sprintf("sha256:%x", sum)
	blob := Blob{Digest: digest, Size: int64(blobSize)}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		if rng := r.Header.Get("Range"); rng != "" {
			var start int64
			fmt.Sscanf(rng, "bytes=%d-", &start)
			if start > 0 && start < int64(blobSize) {
				w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
				w.WriteHeader(http.StatusPartialContent)
				w.Write(payload[start:])
				return
			}
		}
		w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
	}))
	defer server.Close()

	clientDir := t.TempDir()
	// Seed a half-size .tmp full of 0xFF — guaranteed not to match payload.
	partialSize := blobSize / 2
	corrupt := make([]byte, partialSize)
	for i := 0; i < len(corrupt); i++ {
		corrupt[i] = 0xFF
	}
	dest := filepath.Join(clientDir, digestToPath(digest))
	os.MkdirAll(filepath.Dir(dest), 0o755)
	os.WriteFile(dest+".tmp", corrupt, 0o644)

	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	// First attempt resumes with corrupt data → hash mismatch → retry.
	// Retry should clean up .tmp and re-download fully.
	if err != nil {
		t.Fatalf("Download with corrupt partial file failed: %v", err)
	}

	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch after corrupt resume recovery")
	}
}
// TestResumePartialFileLargerThanBlob verifies that a .tmp file larger than
// the expected blob size is discarded and the blob is downloaded from scratch.
func TestResumePartialFileLargerThanBlob(t *testing.T) {
	blobSize := resumeThreshold + 1024
	payload := make([]byte, blobSize)
	for i := 0; i < blobSize; i++ {
		payload[i] = byte((i * 13) % 256)
	}
	sum := sha256.Sum256(payload)
	digest := fmt.Sprintf("sha256:%x", sum)
	blob := Blob{Digest: digest, Size: int64(blobSize)}

	// Server always returns the full blob with 200.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
	}))
	defer server.Close()

	clientDir := t.TempDir()
	// Seed a .tmp that is LARGER than the blob — should be thrown away.
	oversized := make([]byte, blobSize+1000)
	dest := filepath.Join(clientDir, digestToPath(digest))
	os.MkdirAll(filepath.Dir(dest), 0o755)
	os.WriteFile(dest+".tmp", oversized, 0o644)

	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	if err != nil {
		t.Fatalf("Download with oversized .tmp failed: %v", err)
	}

	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch")
	}
}
// TestResumeBelowThreshold verifies that small blobs never attempt resume:
// no Range header is sent even when a partial .tmp file exists.
func TestResumeBelowThreshold(t *testing.T) {
	// 1 KB — far below resumeThreshold.
	blobSize := 1024
	payload := make([]byte, blobSize)
	for i := range payload {
		payload[i] = byte(i % 256)
	}
	sum := sha256.Sum256(payload)
	digest := fmt.Sprintf("sha256:%x", sum)
	blob := Blob{Digest: digest, Size: int64(blobSize)}

	var gotRange atomic.Bool
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		if r.Header.Get("Range") != "" {
			gotRange.Store(true)
		}
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
	}))
	defer server.Close()

	clientDir := t.TempDir()
	// Seed a half-size partial .tmp; it should simply be ignored.
	dest := filepath.Join(clientDir, digestToPath(digest))
	os.MkdirAll(filepath.Dir(dest), 0o755)
	os.WriteFile(dest+".tmp", payload[:blobSize/2], 0o644)

	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	if err != nil {
		t.Fatalf("Download failed: %v", err)
	}
	if gotRange.Load() {
		t.Error("Range header sent for blob below resume threshold — should not attempt resume")
	}

	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch")
	}
}
// TestResumeServerDoesNotSupportRange verifies the graceful fallback: when
// the server ignores the Range header and answers 200 with the full body,
// the client restarts from scratch and still produces a correct file.
func TestResumeServerDoesNotSupportRange(t *testing.T) {
	blobSize := resumeThreshold + 1024
	payload := make([]byte, blobSize)
	for i := 0; i < blobSize; i++ {
		payload[i] = byte((i * 13) % 256)
	}
	sum := sha256.Sum256(payload)
	digest := fmt.Sprintf("sha256:%x", sum)
	blob := Blob{Digest: digest, Size: int64(blobSize)}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		// Ignore Range header — always return full content with 200
		w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
	}))
	defer server.Close()

	clientDir := t.TempDir()
	// Seed a half-size partial .tmp so the client tries to resume.
	dest := filepath.Join(clientDir, digestToPath(digest))
	os.MkdirAll(filepath.Dir(dest), 0o755)
	os.WriteFile(dest+".tmp", payload[:blobSize/2], 0o644)

	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	if err != nil {
		t.Fatalf("Download failed when server doesn't support Range: %v", err)
	}

	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch")
	}
}
// TestResumePartialFileExactSize verifies the case of a .tmp whose size
// exactly equals the blob (a completed download that was never renamed):
// the final file must still end up correct.
func TestResumePartialFileExactSize(t *testing.T) {
	blobSize := resumeThreshold + 1024
	payload := make([]byte, blobSize)
	for i := range payload {
		payload[i] = byte((i * 13) % 256)
	}
	sum := sha256.Sum256(payload)
	digest := fmt.Sprintf("sha256:%x", sum)
	blob := Blob{Digest: digest, Size: int64(blobSize)}

	var requestCount atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		requestCount.Add(1)
		if rng := r.Header.Get("Range"); rng != "" {
			var start int64
			fmt.Sscanf(rng, "bytes=%d-", &start)
			if start >= int64(blobSize) {
				// Nothing left to serve.
				w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
				return
			}
			if start > 0 {
				w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
				w.WriteHeader(http.StatusPartialContent)
				w.Write(payload[start:])
				return
			}
		}
		w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
	}))
	defer server.Close()

	clientDir := t.TempDir()
	// Seed a .tmp with the exact, complete, correct content — simulates a
	// download that finished but wasn't renamed into place.
	dest := filepath.Join(clientDir, digestToPath(digest))
	os.MkdirAll(filepath.Dir(dest), 0o755)
	os.WriteFile(dest+".tmp", payload, 0o644)

	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	if err != nil {
		t.Fatalf("Download failed: %v", err)
	}

	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch")
	}
}

View File

@@ -1,6 +1,7 @@
package transfer
import (
"bufio"
"bytes"
"cmp"
"context"
@@ -14,6 +15,8 @@ import (
"path/filepath"
"time"
"github.com/ollama/ollama/logutil"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)
@@ -76,23 +79,30 @@ func upload(ctx context.Context, opts UploadOptions) error {
}
}
// Filter to only blobs that need uploading
// Filter to only blobs that need uploading, but track total across all blobs
var toUpload []Blob
var total int64
var totalSize, alreadyExists int64
for i, blob := range opts.Blobs {
totalSize += blob.Size
if needsUpload[i] {
toUpload = append(toUpload, blob)
total += blob.Size
} else {
alreadyExists += blob.Size
}
}
// Progress includes all blobs — already-existing ones start as completed
u.progress = newProgressTracker(totalSize, opts.Progress)
u.progress.add(alreadyExists)
logutil.Trace("upload plan", "total_blobs", len(opts.Blobs), "need_upload", len(toUpload), "already_exist", len(opts.Blobs)-len(toUpload), "total_bytes", totalSize, "existing_bytes", alreadyExists)
if len(toUpload) == 0 {
if u.logger != nil {
u.logger.Debug("all blobs exist, nothing to upload")
}
} else {
// Phase 2: Upload blobs that don't exist
u.progress = newProgressTracker(total, opts.Progress)
concurrency := cmp.Or(opts.Concurrency, DefaultUploadConcurrency)
sem := semaphore.NewWeighted(int64(concurrency))
@@ -113,7 +123,12 @@ func upload(ctx context.Context, opts UploadOptions) error {
}
if len(opts.Manifest) > 0 && opts.ManifestRef != "" && opts.Repository != "" {
return u.pushManifest(ctx, opts.Repository, opts.ManifestRef, opts.Manifest)
logutil.Trace("pushing manifest", "repo", opts.Repository, "ref", opts.ManifestRef, "size", len(opts.Manifest))
if err := u.pushManifest(ctx, opts.Repository, opts.ManifestRef, opts.Manifest); err != nil {
logutil.Trace("manifest push failed", "error", err)
return err
}
logutil.Trace("manifest push succeeded", "repo", opts.Repository, "ref", opts.ManifestRef)
}
return nil
}
@@ -124,7 +139,10 @@ func (u *uploader) upload(ctx context.Context, blob Blob) error {
for attempt := range maxRetries {
if attempt > 0 {
if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
// Use longer backoff for uploads — server-side rate limiting
// and S3 upload session creation need real breathing room.
// attempt 1: up to 2s, attempt 2: up to 4s, attempt 3: up to 8s, etc.
if err := backoff(ctx, attempt, 2*time.Second<<uint(attempt-1)); err != nil {
return err
}
}
@@ -132,13 +150,16 @@ func (u *uploader) upload(ctx context.Context, blob Blob) error {
var err error
n, err = u.uploadOnce(ctx, blob)
if err == nil {
logutil.Trace("blob upload complete", "digest", blob.Digest, "bytes", n, "attempt", attempt+1)
return nil
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
logutil.Trace("blob upload cancelled", "digest", blob.Digest, "error", err)
return err
}
logutil.Trace("blob upload failed, retrying", "digest", blob.Digest, "attempt", attempt+1, "bytes", n, "error", err)
u.progress.add(-n)
lastErr = err
}
@@ -178,6 +199,7 @@ func (u *uploader) exists(ctx context.Context, blob Blob) (bool, error) {
if err != nil {
return false, err
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
@@ -191,58 +213,91 @@ func (u *uploader) exists(ctx context.Context, blob Blob) (bool, error) {
return resp.StatusCode == http.StatusOK, nil
}
const maxInitRetries = 12
func (u *uploader) initUpload(ctx context.Context, blob Blob) (string, error) {
endpoint, _ := url.Parse(fmt.Sprintf("%s/v2/%s/blobs/uploads/", u.baseURL, u.repository))
q := endpoint.Query()
q.Set("digest", blob.Digest)
endpoint.RawQuery = q.Encode()
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), nil)
req.Header.Set("User-Agent", u.userAgent)
if *u.token != "" {
req.Header.Set("Authorization", "Bearer "+*u.token)
}
resp, err := u.client.Do(req)
if err != nil {
return "", err
}
resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
if *u.token, err = u.getToken(ctx, ch); err != nil {
return "", err
var lastErr error
for attempt := range maxInitRetries {
if attempt > 0 {
// Start at 5s and cap at 30s — the server needs real breathing
// room when it's dropping Location headers under load.
if err := backoff(ctx, attempt, min(5*time.Second<<uint(attempt-1), 30*time.Second)); err != nil {
return "", err
}
logutil.Trace("retrying init upload", "digest", blob.Digest, "attempt", attempt+1, "error", lastErr)
}
return u.initUpload(ctx, blob)
}
if resp.StatusCode != http.StatusAccepted {
return "", fmt.Errorf("init: status %d", resp.StatusCode)
}
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), nil)
req.Header.Set("User-Agent", u.userAgent)
if *u.token != "" {
req.Header.Set("Authorization", "Bearer "+*u.token)
}
loc := resp.Header.Get("Docker-Upload-Location")
if loc == "" {
loc = resp.Header.Get("Location")
}
if loc == "" {
return "", fmt.Errorf("no upload location")
}
resp, err := u.client.Do(req)
if err != nil {
lastErr = err
continue
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
locURL, _ := url.Parse(loc)
if !locURL.IsAbs() {
base, _ := url.Parse(u.baseURL)
locURL = base.ResolveReference(locURL)
}
q = locURL.Query()
q.Set("digest", blob.Digest)
locURL.RawQuery = q.Encode()
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
if *u.token, err = u.getToken(ctx, ch); err != nil {
return "", err
}
continue
}
return locURL.String(), nil
if resp.StatusCode == http.StatusCreated {
// Blob was mounted or already exists — no upload needed
return "", nil
}
if resp.StatusCode != http.StatusAccepted {
lastErr = fmt.Errorf("init upload: status %d", resp.StatusCode)
continue
}
loc := resp.Header.Get("Docker-Upload-Location")
if loc == "" {
loc = resp.Header.Get("Location")
}
if loc == "" {
// Server returned 202 but no Location — retry, the server may
// be under load and dropping headers.
lastErr = fmt.Errorf("no upload location (server returned 202 without Location header)")
continue
}
locURL, _ := url.Parse(loc)
if !locURL.IsAbs() {
base, _ := url.Parse(u.baseURL)
locURL = base.ResolveReference(locURL)
}
q = locURL.Query()
q.Set("digest", blob.Digest)
locURL.RawQuery = q.Encode()
return locURL.String(), nil
}
return "", lastErr
}
func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size int64) (int64, error) {
pr := &progressReader{reader: f, tracker: u.progress}
// uploadURL is empty when initUpload determined the blob already exists (201 Created)
if uploadURL == "" {
return 0, nil
}
// Buffer reads for better throughput — 256KB reads instead of default 4KB
br := bufio.NewReaderSize(f, 256*1024)
pr := &progressReader{reader: br, tracker: u.progress}
req, _ := http.NewRequestWithContext(ctx, http.MethodPut, uploadURL, pr)
req.ContentLength = size
@@ -254,9 +309,9 @@ func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size i
resp, err := u.client.Do(req)
if err != nil {
return pr.n, err
return pr.n, fmt.Errorf("put request: %w", err)
}
defer resp.Body.Close()
defer func() { io.Copy(io.Discard, resp.Body); resp.Body.Close() }()
// Handle auth retry
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
@@ -274,7 +329,9 @@ func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size i
loc, _ := resp.Location()
f.Seek(0, 0)
u.progress.add(-pr.n)
pr2 := &progressReader{reader: f, tracker: u.progress}
br2 := bufio.NewReaderSize(f, 256*1024)
pr2 := &progressReader{reader: br2, tracker: u.progress}
req2, _ := http.NewRequestWithContext(ctx, http.MethodPut, loc.String(), pr2)
req2.ContentLength = size
@@ -283,9 +340,9 @@ func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size i
resp2, err := u.client.Do(req2)
if err != nil {
return pr2.n, err
return pr2.n, fmt.Errorf("cdn put request: %w", err)
}
defer resp2.Body.Close()
defer func() { io.Copy(io.Discard, resp2.Body); resp2.Body.Close() }()
if resp2.StatusCode != http.StatusCreated && resp2.StatusCode != http.StatusAccepted {
body, _ := io.ReadAll(resp2.Body)
@@ -313,7 +370,7 @@ func (u *uploader) pushManifest(ctx context.Context, repo, ref string, manifest
if err != nil {
return err
}
defer resp.Body.Close()
defer func() { io.Copy(io.Discard, resp.Body); resp.Body.Close() }()
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))