diff --git a/x/imagegen/transfer/download.go b/x/imagegen/transfer/download.go index 97738a870..64a6f8891 100644 --- a/x/imagegen/transfer/download.go +++ b/x/imagegen/transfer/download.go @@ -117,14 +117,25 @@ func (d *downloader) download(ctx context.Context, blob Blob) error { start := time.Now() n, err := d.downloadOnce(ctx, blob) if err == nil { - if s := time.Since(start).Seconds(); s > 0 { - d.speeds.record(float64(blob.Size) / s) + // Skip speed recording for tiny blobs — their transfer time is + // dominated by HTTP overhead, not throughput, and would pollute + // the median used for stall detection. + if blob.Size >= smallBlobSpeedThreshold { + if s := time.Since(start).Seconds(); s > 0 { + d.speeds.record(float64(blob.Size) / s) + } } return nil } d.progress.add(-n) // rollback + // Preserve partial .tmp files for large blobs to enable resume + if blob.Size < resumeThreshold { + dest := filepath.Join(d.destDir, digestToPath(blob.Digest)) + os.Remove(dest + ".tmp") + } + switch { case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): return err @@ -153,55 +164,106 @@ func (d *downloader) downloadOnce(ctx context.Context, blob Blob) (int64, error) return 0, err } + // Check for existing partial .tmp file for resume + dest := filepath.Join(d.destDir, digestToPath(blob.Digest)) + tmp := dest + ".tmp" + var existingSize int64 + if blob.Size >= resumeThreshold { + if fi, statErr := os.Stat(tmp); statErr == nil { + if fi.Size() < blob.Size { + existingSize = fi.Size() + } else if fi.Size() > blob.Size { + // .tmp larger than expected — discard + os.Remove(tmp) + } + // fi.Size() == blob.Size: not resumed — existingSize stays 0, so the blob is fetched in full and save recreates (truncates) the .tmp + } + } + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) req.Header.Set("User-Agent", d.userAgent) // Add auth only for same-host (not CDN) if u.Host == baseURL.Host && *d.token != "" { req.Header.Set("Authorization", "Bearer "+*d.token) } + if existingSize > 0 { + 
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existingSize)) + } resp, err := d.client.Do(req) if err != nil { return 0, err } - defer resp.Body.Close() + defer func() { io.Copy(io.Discard, resp.Body); resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { + switch resp.StatusCode { + case http.StatusOK: + // Full response — reset any partial state + existingSize = 0 + case http.StatusPartialContent: + // Resume succeeded + default: return 0, fmt.Errorf("status %d", resp.StatusCode) } - return d.save(ctx, blob, resp.Body) + return d.save(ctx, blob, resp.Body, existingSize) } -func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader) (int64, error) { +func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader, existingSize int64) (int64, error) { dest := filepath.Join(d.destDir, digestToPath(blob.Digest)) tmp := dest + ".tmp" os.MkdirAll(filepath.Dir(dest), 0o755) - f, err := os.Create(tmp) - if err != nil { - return 0, err + h := sha256.New() + + var f *os.File + var err error + + if existingSize > 0 { + // Resume — re-hash existing partial data, then append + f, err = os.OpenFile(tmp, os.O_RDWR, 0o644) + if err != nil { + // Can't open partial file, start fresh + existingSize = 0 + } else { + // Hash the existing data + if _, hashErr := io.CopyN(h, f, existingSize); hashErr != nil { + f.Close() + os.Remove(tmp) + existingSize = 0 + } else { + // Report resumed bytes as progress + d.progress.add(existingSize) + } + } + } + + if existingSize == 0 { + f, err = os.Create(tmp) + if err != nil { + return 0, err + } + setSparse(f) } defer f.Close() - setSparse(f) - h := sha256.New() n, err := d.copy(ctx, f, r, h) if err != nil { - os.Remove(tmp) - return n, err + // Don't remove .tmp here — download() handles cleanup based on blob size + return existingSize + n, err } f.Close() if got := fmt.Sprintf("sha256:%x", h.Sum(nil)); got != blob.Digest { os.Remove(tmp) - return n, fmt.Errorf("digest mismatch") + return existingSize + n, 
fmt.Errorf("digest mismatch") } - if n != blob.Size { + totalWritten := existingSize + n + if totalWritten != blob.Size { os.Remove(tmp) - return n, fmt.Errorf("size mismatch") + return totalWritten, fmt.Errorf("size mismatch") } - return n, os.Rename(tmp, dest) + return totalWritten, os.Rename(tmp, dest) } func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h io.Writer) (int64, error) { @@ -261,6 +323,9 @@ func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h i } } +// resolve follows redirects to find the final download URL. +// Uses GET (not HEAD) because registries may return 200 for HEAD without +// redirecting to CDN, while GET triggers the actual CDN redirect. func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, error) { u, _ := url.Parse(rawURL) for range 10 { @@ -274,6 +339,8 @@ func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, erro if err != nil { return nil, err } + // Drain body before close to enable HTTP connection reuse + io.Copy(io.Discard, resp.Body) resp.Body.Close() switch resp.StatusCode { diff --git a/x/imagegen/transfer/transfer.go b/x/imagegen/transfer/transfer.go index 05842f065..33e6d589e 100644 --- a/x/imagegen/transfer/transfer.go +++ b/x/imagegen/transfer/transfer.go @@ -11,7 +11,8 @@ // Key simplifications for many-small-blob workloads: // // - Whole-blob transfers: No part-based chunking. Each blob downloads/uploads as one unit. -// - No resume: If a transfer fails, it restarts from scratch (fine for small blobs). +// - Resume for large blobs: Blobs >= 64MB preserve partial .tmp files on failure +// and use HTTP Range requests on retry. Small blobs restart from scratch. // - Inline hashing: SHA256 computed during streaming, not asynchronously after parts complete. // - Stall and speed detection: Cancels on no data (stall) or speed below 10% of median. 
// @@ -99,6 +100,14 @@ const ( DefaultUploadConcurrency = 32 maxRetries = 6 defaultUserAgent = "ollama-transfer/1.0" + + // resumeThreshold is the minimum blob size for resume support. + // Blobs at or above this size keep partial .tmp files on failure. + resumeThreshold = 64 << 20 // 64 MB + + // smallBlobSpeedThreshold is the size below which speed samples are skipped, + // since their transfer time is dominated by HTTP overhead, not throughput. + smallBlobSpeedThreshold = 100 << 10 // 100 KB ) var errMaxRetriesExceeded = errors.New("max retries exceeded") @@ -134,6 +143,11 @@ func (p *progressTracker) add(n int64) { return } completed := p.completed.Add(n) + defer func() { + if r := recover(); r != nil { + slog.Debug("progress callback panic (likely closed channel)", "recovered", r) + } + }() p.callback(completed, p.total) } diff --git a/x/imagegen/transfer/transfer_test.go b/x/imagegen/transfer/transfer_test.go index 78e386ed9..589148125 100644 --- a/x/imagegen/transfer/transfer_test.go +++ b/x/imagegen/transfer/transfer_test.go @@ -1775,3 +1775,389 @@ func TestThroughput(t *testing.T) { t.Logf("Warning: total time %v exceeds 500ms target", dlElapsed+ulElapsed) } } + +// ==================== Resume Tests ==================== + +func TestResumeFromPartialFile(t *testing.T) { + // Create a blob large enough for resume (>= resumeThreshold) + blobSize := resumeThreshold + 1024 + data := make([]byte, blobSize) + for i := range data { + data[i] = byte((i * 13) % 256) + } + h := sha256.Sum256(data) + digest := fmt.Sprintf("sha256:%x", h) + blob := Blob{Digest: digest, Size: int64(blobSize)} + + var rangeHeader string + var mu sync.Mutex + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize)) + w.WriteHeader(http.StatusOK) + return + } + + mu.Lock() + rangeHeader = r.Header.Get("Range") + mu.Unlock() + + rng := r.Header.Get("Range") 
+ if rng != "" { + // Parse "bytes=N-" + var start int64 + fmt.Sscanf(rng, "bytes=%d-", &start) + if start > 0 && start < int64(blobSize) { + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize)) + w.WriteHeader(http.StatusPartialContent) + w.Write(data[start:]) + return + } + } + + w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize)) + w.WriteHeader(http.StatusOK) + w.Write(data) + })) + defer server.Close() + + clientDir := t.TempDir() + + // Pre-create a partial .tmp file (first half) + partialSize := blobSize / 2 + dest := filepath.Join(clientDir, digestToPath(digest)) + os.MkdirAll(filepath.Dir(dest), 0o755) + os.WriteFile(dest+".tmp", data[:partialSize], 0o644) + + err := Download(context.Background(), DownloadOptions{ + Blobs: []Blob{blob}, + BaseURL: server.URL, + DestDir: clientDir, + }) + if err != nil { + t.Fatalf("Resume download failed: %v", err) + } + + // Verify Range header was sent + mu.Lock() + if rangeHeader == "" { + t.Error("Expected Range header for resume, got none") + } else { + expected := fmt.Sprintf("bytes=%d-", partialSize) + if rangeHeader != expected { + t.Errorf("Range header = %q, want %q", rangeHeader, expected) + } + } + mu.Unlock() + + // Verify final file is correct + finalData, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("Failed to read final file: %v", err) + } + if len(finalData) != blobSize { + t.Errorf("Final file size = %d, want %d", len(finalData), blobSize) + } + finalHash := sha256.Sum256(finalData) + if fmt.Sprintf("sha256:%x", finalHash) != digest { + t.Error("Final file hash mismatch") + } +} + +func TestResumeCorruptPartialFile(t *testing.T) { + blobSize := resumeThreshold + 1024 + data := make([]byte, blobSize) + for i := range data { + data[i] = byte((i * 13) % 256) + } + h := sha256.Sum256(data) + digest := fmt.Sprintf("sha256:%x", h) + blob := Blob{Digest: digest, Size: int64(blobSize)} + + server := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize)) + w.WriteHeader(http.StatusOK) + return + } + + rng := r.Header.Get("Range") + if rng != "" { + var start int64 + fmt.Sscanf(rng, "bytes=%d-", &start) + if start > 0 && start < int64(blobSize) { + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize)) + w.WriteHeader(http.StatusPartialContent) + w.Write(data[start:]) + return + } + } + + w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize)) + w.WriteHeader(http.StatusOK) + w.Write(data) + })) + defer server.Close() + + clientDir := t.TempDir() + + // Pre-create a partial .tmp file with CORRUPT data + partialSize := blobSize / 2 + corruptData := make([]byte, partialSize) + for i := range corruptData { + corruptData[i] = 0xFF // All 0xFF — definitely wrong + } + dest := filepath.Join(clientDir, digestToPath(digest)) + os.MkdirAll(filepath.Dir(dest), 0o755) + os.WriteFile(dest+".tmp", corruptData, 0o644) + + err := Download(context.Background(), DownloadOptions{ + Blobs: []Blob{blob}, + BaseURL: server.URL, + DestDir: clientDir, + }) + // First attempt resumes with corrupt data → hash mismatch → retry. + // Retry should clean up .tmp and re-download fully. 
+ if err != nil { + t.Fatalf("Download with corrupt partial file failed: %v", err) + } + + // Verify final file is correct + finalData, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("Failed to read final file: %v", err) + } + finalHash := sha256.Sum256(finalData) + if fmt.Sprintf("sha256:%x", finalHash) != digest { + t.Error("Final file hash mismatch after corrupt resume recovery") + } +} + +func TestResumePartialFileLargerThanBlob(t *testing.T) { + blobSize := resumeThreshold + 1024 + data := make([]byte, blobSize) + for i := range data { + data[i] = byte((i * 13) % 256) + } + h := sha256.Sum256(data) + digest := fmt.Sprintf("sha256:%x", h) + blob := Blob{Digest: digest, Size: int64(blobSize)} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize)) + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusOK) + w.Write(data) + })) + defer server.Close() + + clientDir := t.TempDir() + + // Pre-create .tmp file LARGER than expected blob + oversizedData := make([]byte, blobSize+1000) + dest := filepath.Join(clientDir, digestToPath(digest)) + os.MkdirAll(filepath.Dir(dest), 0o755) + os.WriteFile(dest+".tmp", oversizedData, 0o644) + + err := Download(context.Background(), DownloadOptions{ + Blobs: []Blob{blob}, + BaseURL: server.URL, + DestDir: clientDir, + }) + if err != nil { + t.Fatalf("Download with oversized .tmp failed: %v", err) + } + + // Verify final file is correct + finalData, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("Failed to read final file: %v", err) + } + finalHash := sha256.Sum256(finalData) + if fmt.Sprintf("sha256:%x", finalHash) != digest { + t.Error("Final file hash mismatch") + } +} + +func TestResumeBelowThreshold(t *testing.T) { + // Blob below resume threshold should NOT attempt resume + blobSize := 1024 // Well below resumeThreshold + data := make([]byte, blobSize) + for i 
:= range data { + data[i] = byte(i % 256) + } + h := sha256.Sum256(data) + digest := fmt.Sprintf("sha256:%x", h) + blob := Blob{Digest: digest, Size: int64(blobSize)} + + var gotRange atomic.Bool + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize)) + w.WriteHeader(http.StatusOK) + return + } + if r.Header.Get("Range") != "" { + gotRange.Store(true) + } + w.WriteHeader(http.StatusOK) + w.Write(data) + })) + defer server.Close() + + clientDir := t.TempDir() + + // Pre-create a partial .tmp file + dest := filepath.Join(clientDir, digestToPath(digest)) + os.MkdirAll(filepath.Dir(dest), 0o755) + os.WriteFile(dest+".tmp", data[:blobSize/2], 0o644) + + err := Download(context.Background(), DownloadOptions{ + Blobs: []Blob{blob}, + BaseURL: server.URL, + DestDir: clientDir, + }) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + if gotRange.Load() { + t.Error("Range header sent for blob below resume threshold — should not attempt resume") + } + + // Verify final file + finalData, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("Failed to read final file: %v", err) + } + finalHash := sha256.Sum256(finalData) + if fmt.Sprintf("sha256:%x", finalHash) != digest { + t.Error("Final file hash mismatch") + } +} + +func TestResumeServerDoesNotSupportRange(t *testing.T) { + blobSize := resumeThreshold + 1024 + data := make([]byte, blobSize) + for i := range data { + data[i] = byte((i * 13) % 256) + } + h := sha256.Sum256(data) + digest := fmt.Sprintf("sha256:%x", h) + blob := Blob{Digest: digest, Size: int64(blobSize)} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize)) + w.WriteHeader(http.StatusOK) + return + } + // Ignore Range header — always return full content with 200 + 
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize)) + w.WriteHeader(http.StatusOK) + w.Write(data) + })) + defer server.Close() + + clientDir := t.TempDir() + + // Pre-create partial .tmp file + dest := filepath.Join(clientDir, digestToPath(digest)) + os.MkdirAll(filepath.Dir(dest), 0o755) + os.WriteFile(dest+".tmp", data[:blobSize/2], 0o644) + + err := Download(context.Background(), DownloadOptions{ + Blobs: []Blob{blob}, + BaseURL: server.URL, + DestDir: clientDir, + }) + if err != nil { + t.Fatalf("Download failed when server doesn't support Range: %v", err) + } + + // Verify final file is correct + finalData, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("Failed to read final file: %v", err) + } + finalHash := sha256.Sum256(finalData) + if fmt.Sprintf("sha256:%x", finalHash) != digest { + t.Error("Final file hash mismatch") + } +} + +func TestResumePartialFileExactSize(t *testing.T) { + blobSize := resumeThreshold + 1024 + data := make([]byte, blobSize) + for i := range data { + data[i] = byte((i * 13) % 256) + } + h := sha256.Sum256(data) + digest := fmt.Sprintf("sha256:%x", h) + blob := Blob{Digest: digest, Size: int64(blobSize)} + + var requestCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize)) + w.WriteHeader(http.StatusOK) + return + } + requestCount.Add(1) + + rng := r.Header.Get("Range") + if rng != "" { + var start int64 + fmt.Sscanf(rng, "bytes=%d-", &start) + if start >= int64(blobSize) { + // Nothing to send + w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) + return + } + if start > 0 { + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize)) + w.WriteHeader(http.StatusPartialContent) + w.Write(data[start:]) + return + } + } + + w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize)) + w.WriteHeader(http.StatusOK) + 
w.Write(data) + })) + defer server.Close() + + clientDir := t.TempDir() + + // Pre-create .tmp file with exact correct content (full size) + // This simulates a download that completed but wasn't renamed + dest := filepath.Join(clientDir, digestToPath(digest)) + os.MkdirAll(filepath.Dir(dest), 0o755) + os.WriteFile(dest+".tmp", data, 0o644) + + err := Download(context.Background(), DownloadOptions{ + Blobs: []Blob{blob}, + BaseURL: server.URL, + DestDir: clientDir, + }) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + // Verify final file is correct + finalData, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("Failed to read final file: %v", err) + } + finalHash := sha256.Sum256(finalData) + if fmt.Sprintf("sha256:%x", finalHash) != digest { + t.Error("Final file hash mismatch") + } +} diff --git a/x/imagegen/transfer/upload.go b/x/imagegen/transfer/upload.go index bb3ff928c..c75ee3e66 100644 --- a/x/imagegen/transfer/upload.go +++ b/x/imagegen/transfer/upload.go @@ -1,6 +1,7 @@ package transfer import ( + "bufio" "bytes" "cmp" "context" @@ -14,6 +15,8 @@ import ( "path/filepath" "time" + "github.com/ollama/ollama/logutil" + "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" ) @@ -76,23 +79,30 @@ func upload(ctx context.Context, opts UploadOptions) error { } } - // Filter to only blobs that need uploading + // Filter to only blobs that need uploading, but track total across all blobs var toUpload []Blob - var total int64 + var totalSize, alreadyExists int64 for i, blob := range opts.Blobs { + totalSize += blob.Size if needsUpload[i] { toUpload = append(toUpload, blob) - total += blob.Size + } else { + alreadyExists += blob.Size } } + // Progress includes all blobs — already-existing ones start as completed + u.progress = newProgressTracker(totalSize, opts.Progress) + u.progress.add(alreadyExists) + + logutil.Trace("upload plan", "total_blobs", len(opts.Blobs), "need_upload", len(toUpload), "already_exist", 
len(opts.Blobs)-len(toUpload), "total_bytes", totalSize, "existing_bytes", alreadyExists) + if len(toUpload) == 0 { if u.logger != nil { u.logger.Debug("all blobs exist, nothing to upload") } } else { // Phase 2: Upload blobs that don't exist - u.progress = newProgressTracker(total, opts.Progress) concurrency := cmp.Or(opts.Concurrency, DefaultUploadConcurrency) sem := semaphore.NewWeighted(int64(concurrency)) @@ -113,7 +123,12 @@ func upload(ctx context.Context, opts UploadOptions) error { } if len(opts.Manifest) > 0 && opts.ManifestRef != "" && opts.Repository != "" { - return u.pushManifest(ctx, opts.Repository, opts.ManifestRef, opts.Manifest) + logutil.Trace("pushing manifest", "repo", opts.Repository, "ref", opts.ManifestRef, "size", len(opts.Manifest)) + if err := u.pushManifest(ctx, opts.Repository, opts.ManifestRef, opts.Manifest); err != nil { + logutil.Trace("manifest push failed", "error", err) + return err + } + logutil.Trace("manifest push succeeded", "repo", opts.Repository, "ref", opts.ManifestRef) } return nil } @@ -124,7 +139,10 @@ func (u *uploader) upload(ctx context.Context, blob Blob) error { for attempt := range maxRetries { if attempt > 0 { - if err := backoff(ctx, attempt, time.Second< 0 { + // Start at 5s and cap at 30s — the server needs real breathing + // room when it's dropping Location headers under load. + if err := backoff(ctx, attempt, min(5*time.Second<