mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
pull/push: refine safetensors (#14946)
* pull: refine safetensors pull - Body drain in resolve() — drain response body before close so Go's HTTP client can reuse TCP connections instead of opening a new one per blob (1,075 extra TCP+TLS handshakes eliminated) - Skip speed recording for tiny blobs (<100KB) — prevents HTTP-overhead-dominated transfer times from poisoning the median, which the stall detector uses to cancel "too slow" downloads - Resume support for large blobs (>=64MB) — on failure, preserves partial .tmp files; on retry, re-hashes existing datak and sends Range header to download only remaining bytes; gracefully falls back to full download if server returns 200 instead of 206; SHA256 verification catches corrupt partials * harden push - Prevents killing TCP connections after every request. - Stronger backoff to handle server back-pressure and rate limiting - Larger buffered reads for improve safetensor upload performance - Better error message handling from server - Handle 201 if server says blob exists - Fix progress reporting on already uploaded blobs - Trace logging to help troubleshoot and tune going forward * review comments * review comments
This commit is contained in:
@@ -117,14 +117,25 @@ func (d *downloader) download(ctx context.Context, blob Blob) error {
|
||||
start := time.Now()
|
||||
n, err := d.downloadOnce(ctx, blob)
|
||||
if err == nil {
|
||||
if s := time.Since(start).Seconds(); s > 0 {
|
||||
d.speeds.record(float64(blob.Size) / s)
|
||||
// Skip speed recording for tiny blobs — their transfer time is
|
||||
// dominated by HTTP overhead, not throughput, and would pollute
|
||||
// the median used for stall detection.
|
||||
if blob.Size >= smallBlobSpeedThreshold {
|
||||
if s := time.Since(start).Seconds(); s > 0 {
|
||||
d.speeds.record(float64(blob.Size) / s)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
d.progress.add(-n) // rollback
|
||||
|
||||
// Preserve partial .tmp files for large blobs to enable resume
|
||||
if blob.Size < resumeThreshold {
|
||||
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
|
||||
os.Remove(dest + ".tmp")
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
|
||||
return err
|
||||
@@ -153,55 +164,106 @@ func (d *downloader) downloadOnce(ctx context.Context, blob Blob) (int64, error)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Check for existing partial .tmp file for resume
|
||||
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
|
||||
tmp := dest + ".tmp"
|
||||
var existingSize int64
|
||||
if blob.Size >= resumeThreshold {
|
||||
if fi, statErr := os.Stat(tmp); statErr == nil {
|
||||
if fi.Size() < blob.Size {
|
||||
existingSize = fi.Size()
|
||||
} else if fi.Size() > blob.Size {
|
||||
// .tmp larger than expected — discard
|
||||
os.Remove(tmp)
|
||||
}
|
||||
// fi.Size() == blob.Size handled in save (hash check + rename)
|
||||
}
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
req.Header.Set("User-Agent", d.userAgent)
|
||||
// Add auth only for same-host (not CDN)
|
||||
if u.Host == baseURL.Host && *d.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+*d.token)
|
||||
}
|
||||
if existingSize > 0 {
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existingSize))
|
||||
}
|
||||
|
||||
resp, err := d.client.Do(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { io.Copy(io.Discard, resp.Body); resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
// Full response — reset any partial state
|
||||
existingSize = 0
|
||||
case http.StatusPartialContent:
|
||||
// Resume succeeded
|
||||
default:
|
||||
return 0, fmt.Errorf("status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return d.save(ctx, blob, resp.Body)
|
||||
return d.save(ctx, blob, resp.Body, existingSize)
|
||||
}
|
||||
|
||||
func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader) (int64, error) {
|
||||
func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader, existingSize int64) (int64, error) {
|
||||
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
|
||||
tmp := dest + ".tmp"
|
||||
os.MkdirAll(filepath.Dir(dest), 0o755)
|
||||
|
||||
f, err := os.Create(tmp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
h := sha256.New()
|
||||
|
||||
var f *os.File
|
||||
var err error
|
||||
|
||||
if existingSize > 0 {
|
||||
// Resume — re-hash existing partial data, then append
|
||||
f, err = os.OpenFile(tmp, os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
// Can't open partial file, start fresh
|
||||
existingSize = 0
|
||||
} else {
|
||||
// Hash the existing data
|
||||
if _, hashErr := io.CopyN(h, f, existingSize); hashErr != nil {
|
||||
f.Close()
|
||||
os.Remove(tmp)
|
||||
existingSize = 0
|
||||
} else {
|
||||
// Report resumed bytes as progress
|
||||
d.progress.add(existingSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if existingSize == 0 {
|
||||
f, err = os.Create(tmp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
setSparse(f)
|
||||
}
|
||||
defer f.Close()
|
||||
setSparse(f)
|
||||
|
||||
h := sha256.New()
|
||||
n, err := d.copy(ctx, f, r, h)
|
||||
if err != nil {
|
||||
os.Remove(tmp)
|
||||
return n, err
|
||||
// Don't remove .tmp here — download() handles cleanup based on blob size
|
||||
return existingSize + n, err
|
||||
}
|
||||
f.Close()
|
||||
|
||||
if got := fmt.Sprintf("sha256:%x", h.Sum(nil)); got != blob.Digest {
|
||||
os.Remove(tmp)
|
||||
return n, fmt.Errorf("digest mismatch")
|
||||
return existingSize + n, fmt.Errorf("digest mismatch")
|
||||
}
|
||||
if n != blob.Size {
|
||||
totalWritten := existingSize + n
|
||||
if totalWritten != blob.Size {
|
||||
os.Remove(tmp)
|
||||
return n, fmt.Errorf("size mismatch")
|
||||
return totalWritten, fmt.Errorf("size mismatch")
|
||||
}
|
||||
return n, os.Rename(tmp, dest)
|
||||
return totalWritten, os.Rename(tmp, dest)
|
||||
}
|
||||
|
||||
func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h io.Writer) (int64, error) {
|
||||
@@ -261,6 +323,9 @@ func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h i
|
||||
}
|
||||
}
|
||||
|
||||
// resolve follows redirects to find the final download URL.
|
||||
// Uses GET (not HEAD) because registries may return 200 for HEAD without
|
||||
// redirecting to CDN, while GET triggers the actual CDN redirect.
|
||||
func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, error) {
|
||||
u, _ := url.Parse(rawURL)
|
||||
for range 10 {
|
||||
@@ -274,6 +339,8 @@ func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, erro
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Drain body before close to enable HTTP connection reuse
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
|
||||
@@ -11,7 +11,8 @@
|
||||
// Key simplifications for many-small-blob workloads:
|
||||
//
|
||||
// - Whole-blob transfers: No part-based chunking. Each blob downloads/uploads as one unit.
|
||||
// - No resume: If a transfer fails, it restarts from scratch (fine for small blobs).
|
||||
// - Resume for large blobs: Blobs >= 64MB preserve partial .tmp files on failure
|
||||
// and use HTTP Range requests on retry. Small blobs restart from scratch.
|
||||
// - Inline hashing: SHA256 computed during streaming, not asynchronously after parts complete.
|
||||
// - Stall and speed detection: Cancels on no data (stall) or speed below 10% of median.
|
||||
//
|
||||
@@ -99,6 +100,14 @@ const (
|
||||
DefaultUploadConcurrency = 32
|
||||
maxRetries = 6
|
||||
defaultUserAgent = "ollama-transfer/1.0"
|
||||
|
||||
// resumeThreshold is the minimum blob size for resume support.
|
||||
// Only blobs above this size keep partial .tmp files on failure.
|
||||
resumeThreshold = 64 << 20 // 64 MB
|
||||
|
||||
// smallBlobSpeedThreshold is the size below which speed samples are skipped,
|
||||
// since their transfer time is dominated by HTTP overhead, not throughput.
|
||||
smallBlobSpeedThreshold = 100 << 10 // 100 KB
|
||||
)
|
||||
|
||||
var errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
@@ -134,6 +143,11 @@ func (p *progressTracker) add(n int64) {
|
||||
return
|
||||
}
|
||||
completed := p.completed.Add(n)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Debug("progress callback panic (likely closed channel)", "recovered", r)
|
||||
}
|
||||
}()
|
||||
p.callback(completed, p.total)
|
||||
}
|
||||
|
||||
|
||||
@@ -1775,3 +1775,389 @@ func TestThroughput(t *testing.T) {
|
||||
t.Logf("Warning: total time %v exceeds 500ms target", dlElapsed+ulElapsed)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Resume Tests ====================
|
||||
|
||||
func TestResumeFromPartialFile(t *testing.T) {
|
||||
// Create a blob large enough for resume (>= resumeThreshold)
|
||||
blobSize := resumeThreshold + 1024
|
||||
data := make([]byte, blobSize)
|
||||
for i := range data {
|
||||
data[i] = byte((i * 13) % 256)
|
||||
}
|
||||
h := sha256.Sum256(data)
|
||||
digest := fmt.Sprintf("sha256:%x", h)
|
||||
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
||||
|
||||
var rangeHeader string
|
||||
var mu sync.Mutex
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodHead {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
rangeHeader = r.Header.Get("Range")
|
||||
mu.Unlock()
|
||||
|
||||
rng := r.Header.Get("Range")
|
||||
if rng != "" {
|
||||
// Parse "bytes=N-"
|
||||
var start int64
|
||||
fmt.Sscanf(rng, "bytes=%d-", &start)
|
||||
if start > 0 && start < int64(blobSize) {
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
w.Write(data[start:])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
clientDir := t.TempDir()
|
||||
|
||||
// Pre-create a partial .tmp file (first half)
|
||||
partialSize := blobSize / 2
|
||||
dest := filepath.Join(clientDir, digestToPath(digest))
|
||||
os.MkdirAll(filepath.Dir(dest), 0o755)
|
||||
os.WriteFile(dest+".tmp", data[:partialSize], 0o644)
|
||||
|
||||
err := Download(context.Background(), DownloadOptions{
|
||||
Blobs: []Blob{blob},
|
||||
BaseURL: server.URL,
|
||||
DestDir: clientDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Resume download failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify Range header was sent
|
||||
mu.Lock()
|
||||
if rangeHeader == "" {
|
||||
t.Error("Expected Range header for resume, got none")
|
||||
} else {
|
||||
expected := fmt.Sprintf("bytes=%d-", partialSize)
|
||||
if rangeHeader != expected {
|
||||
t.Errorf("Range header = %q, want %q", rangeHeader, expected)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify final file is correct
|
||||
finalData, err := os.ReadFile(dest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read final file: %v", err)
|
||||
}
|
||||
if len(finalData) != blobSize {
|
||||
t.Errorf("Final file size = %d, want %d", len(finalData), blobSize)
|
||||
}
|
||||
finalHash := sha256.Sum256(finalData)
|
||||
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
||||
t.Error("Final file hash mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResumeCorruptPartialFile(t *testing.T) {
|
||||
blobSize := resumeThreshold + 1024
|
||||
data := make([]byte, blobSize)
|
||||
for i := range data {
|
||||
data[i] = byte((i * 13) % 256)
|
||||
}
|
||||
h := sha256.Sum256(data)
|
||||
digest := fmt.Sprintf("sha256:%x", h)
|
||||
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodHead {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
rng := r.Header.Get("Range")
|
||||
if rng != "" {
|
||||
var start int64
|
||||
fmt.Sscanf(rng, "bytes=%d-", &start)
|
||||
if start > 0 && start < int64(blobSize) {
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
w.Write(data[start:])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
clientDir := t.TempDir()
|
||||
|
||||
// Pre-create a partial .tmp file with CORRUPT data
|
||||
partialSize := blobSize / 2
|
||||
corruptData := make([]byte, partialSize)
|
||||
for i := range corruptData {
|
||||
corruptData[i] = 0xFF // All 0xFF — definitely wrong
|
||||
}
|
||||
dest := filepath.Join(clientDir, digestToPath(digest))
|
||||
os.MkdirAll(filepath.Dir(dest), 0o755)
|
||||
os.WriteFile(dest+".tmp", corruptData, 0o644)
|
||||
|
||||
err := Download(context.Background(), DownloadOptions{
|
||||
Blobs: []Blob{blob},
|
||||
BaseURL: server.URL,
|
||||
DestDir: clientDir,
|
||||
})
|
||||
// First attempt resumes with corrupt data → hash mismatch → retry.
|
||||
// Retry should clean up .tmp and re-download fully.
|
||||
if err != nil {
|
||||
t.Fatalf("Download with corrupt partial file failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify final file is correct
|
||||
finalData, err := os.ReadFile(dest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read final file: %v", err)
|
||||
}
|
||||
finalHash := sha256.Sum256(finalData)
|
||||
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
||||
t.Error("Final file hash mismatch after corrupt resume recovery")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResumePartialFileLargerThanBlob(t *testing.T) {
|
||||
blobSize := resumeThreshold + 1024
|
||||
data := make([]byte, blobSize)
|
||||
for i := range data {
|
||||
data[i] = byte((i * 13) % 256)
|
||||
}
|
||||
h := sha256.Sum256(data)
|
||||
digest := fmt.Sprintf("sha256:%x", h)
|
||||
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodHead {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
clientDir := t.TempDir()
|
||||
|
||||
// Pre-create .tmp file LARGER than expected blob
|
||||
oversizedData := make([]byte, blobSize+1000)
|
||||
dest := filepath.Join(clientDir, digestToPath(digest))
|
||||
os.MkdirAll(filepath.Dir(dest), 0o755)
|
||||
os.WriteFile(dest+".tmp", oversizedData, 0o644)
|
||||
|
||||
err := Download(context.Background(), DownloadOptions{
|
||||
Blobs: []Blob{blob},
|
||||
BaseURL: server.URL,
|
||||
DestDir: clientDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Download with oversized .tmp failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify final file is correct
|
||||
finalData, err := os.ReadFile(dest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read final file: %v", err)
|
||||
}
|
||||
finalHash := sha256.Sum256(finalData)
|
||||
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
||||
t.Error("Final file hash mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResumeBelowThreshold(t *testing.T) {
|
||||
// Blob below resume threshold should NOT attempt resume
|
||||
blobSize := 1024 // Well below resumeThreshold
|
||||
data := make([]byte, blobSize)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
h := sha256.Sum256(data)
|
||||
digest := fmt.Sprintf("sha256:%x", h)
|
||||
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
||||
|
||||
var gotRange atomic.Bool
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodHead {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Range") != "" {
|
||||
gotRange.Store(true)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
clientDir := t.TempDir()
|
||||
|
||||
// Pre-create a partial .tmp file
|
||||
dest := filepath.Join(clientDir, digestToPath(digest))
|
||||
os.MkdirAll(filepath.Dir(dest), 0o755)
|
||||
os.WriteFile(dest+".tmp", data[:blobSize/2], 0o644)
|
||||
|
||||
err := Download(context.Background(), DownloadOptions{
|
||||
Blobs: []Blob{blob},
|
||||
BaseURL: server.URL,
|
||||
DestDir: clientDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Download failed: %v", err)
|
||||
}
|
||||
|
||||
if gotRange.Load() {
|
||||
t.Error("Range header sent for blob below resume threshold — should not attempt resume")
|
||||
}
|
||||
|
||||
// Verify final file
|
||||
finalData, err := os.ReadFile(dest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read final file: %v", err)
|
||||
}
|
||||
finalHash := sha256.Sum256(finalData)
|
||||
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
||||
t.Error("Final file hash mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResumeServerDoesNotSupportRange(t *testing.T) {
|
||||
blobSize := resumeThreshold + 1024
|
||||
data := make([]byte, blobSize)
|
||||
for i := range data {
|
||||
data[i] = byte((i * 13) % 256)
|
||||
}
|
||||
h := sha256.Sum256(data)
|
||||
digest := fmt.Sprintf("sha256:%x", h)
|
||||
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodHead {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
// Ignore Range header — always return full content with 200
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
clientDir := t.TempDir()
|
||||
|
||||
// Pre-create partial .tmp file
|
||||
dest := filepath.Join(clientDir, digestToPath(digest))
|
||||
os.MkdirAll(filepath.Dir(dest), 0o755)
|
||||
os.WriteFile(dest+".tmp", data[:blobSize/2], 0o644)
|
||||
|
||||
err := Download(context.Background(), DownloadOptions{
|
||||
Blobs: []Blob{blob},
|
||||
BaseURL: server.URL,
|
||||
DestDir: clientDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Download failed when server doesn't support Range: %v", err)
|
||||
}
|
||||
|
||||
// Verify final file is correct
|
||||
finalData, err := os.ReadFile(dest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read final file: %v", err)
|
||||
}
|
||||
finalHash := sha256.Sum256(finalData)
|
||||
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
||||
t.Error("Final file hash mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResumePartialFileExactSize(t *testing.T) {
|
||||
blobSize := resumeThreshold + 1024
|
||||
data := make([]byte, blobSize)
|
||||
for i := range data {
|
||||
data[i] = byte((i * 13) % 256)
|
||||
}
|
||||
h := sha256.Sum256(data)
|
||||
digest := fmt.Sprintf("sha256:%x", h)
|
||||
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
||||
|
||||
var requestCount atomic.Int32
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodHead {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
requestCount.Add(1)
|
||||
|
||||
rng := r.Header.Get("Range")
|
||||
if rng != "" {
|
||||
var start int64
|
||||
fmt.Sscanf(rng, "bytes=%d-", &start)
|
||||
if start >= int64(blobSize) {
|
||||
// Nothing to send
|
||||
w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return
|
||||
}
|
||||
if start > 0 {
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
w.Write(data[start:])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
clientDir := t.TempDir()
|
||||
|
||||
// Pre-create .tmp file with exact correct content (full size)
|
||||
// This simulates a download that completed but wasn't renamed
|
||||
dest := filepath.Join(clientDir, digestToPath(digest))
|
||||
os.MkdirAll(filepath.Dir(dest), 0o755)
|
||||
os.WriteFile(dest+".tmp", data, 0o644)
|
||||
|
||||
err := Download(context.Background(), DownloadOptions{
|
||||
Blobs: []Blob{blob},
|
||||
BaseURL: server.URL,
|
||||
DestDir: clientDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Download failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify final file is correct
|
||||
finalData, err := os.ReadFile(dest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read final file: %v", err)
|
||||
}
|
||||
finalHash := sha256.Sum256(finalData)
|
||||
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
||||
t.Error("Final file hash mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
@@ -14,6 +15,8 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
@@ -76,23 +79,30 @@ func upload(ctx context.Context, opts UploadOptions) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Filter to only blobs that need uploading
|
||||
// Filter to only blobs that need uploading, but track total across all blobs
|
||||
var toUpload []Blob
|
||||
var total int64
|
||||
var totalSize, alreadyExists int64
|
||||
for i, blob := range opts.Blobs {
|
||||
totalSize += blob.Size
|
||||
if needsUpload[i] {
|
||||
toUpload = append(toUpload, blob)
|
||||
total += blob.Size
|
||||
} else {
|
||||
alreadyExists += blob.Size
|
||||
}
|
||||
}
|
||||
|
||||
// Progress includes all blobs — already-existing ones start as completed
|
||||
u.progress = newProgressTracker(totalSize, opts.Progress)
|
||||
u.progress.add(alreadyExists)
|
||||
|
||||
logutil.Trace("upload plan", "total_blobs", len(opts.Blobs), "need_upload", len(toUpload), "already_exist", len(opts.Blobs)-len(toUpload), "total_bytes", totalSize, "existing_bytes", alreadyExists)
|
||||
|
||||
if len(toUpload) == 0 {
|
||||
if u.logger != nil {
|
||||
u.logger.Debug("all blobs exist, nothing to upload")
|
||||
}
|
||||
} else {
|
||||
// Phase 2: Upload blobs that don't exist
|
||||
u.progress = newProgressTracker(total, opts.Progress)
|
||||
concurrency := cmp.Or(opts.Concurrency, DefaultUploadConcurrency)
|
||||
sem := semaphore.NewWeighted(int64(concurrency))
|
||||
|
||||
@@ -113,7 +123,12 @@ func upload(ctx context.Context, opts UploadOptions) error {
|
||||
}
|
||||
|
||||
if len(opts.Manifest) > 0 && opts.ManifestRef != "" && opts.Repository != "" {
|
||||
return u.pushManifest(ctx, opts.Repository, opts.ManifestRef, opts.Manifest)
|
||||
logutil.Trace("pushing manifest", "repo", opts.Repository, "ref", opts.ManifestRef, "size", len(opts.Manifest))
|
||||
if err := u.pushManifest(ctx, opts.Repository, opts.ManifestRef, opts.Manifest); err != nil {
|
||||
logutil.Trace("manifest push failed", "error", err)
|
||||
return err
|
||||
}
|
||||
logutil.Trace("manifest push succeeded", "repo", opts.Repository, "ref", opts.ManifestRef)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -124,7 +139,10 @@ func (u *uploader) upload(ctx context.Context, blob Blob) error {
|
||||
|
||||
for attempt := range maxRetries {
|
||||
if attempt > 0 {
|
||||
if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
|
||||
// Use longer backoff for uploads — server-side rate limiting
|
||||
// and S3 upload session creation need real breathing room.
|
||||
// attempt 1: up to 2s, attempt 2: up to 4s, attempt 3: up to 8s, etc.
|
||||
if err := backoff(ctx, attempt, 2*time.Second<<uint(attempt-1)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -132,13 +150,16 @@ func (u *uploader) upload(ctx context.Context, blob Blob) error {
|
||||
var err error
|
||||
n, err = u.uploadOnce(ctx, blob)
|
||||
if err == nil {
|
||||
logutil.Trace("blob upload complete", "digest", blob.Digest, "bytes", n, "attempt", attempt+1)
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
logutil.Trace("blob upload cancelled", "digest", blob.Digest, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logutil.Trace("blob upload failed, retrying", "digest", blob.Digest, "attempt", attempt+1, "bytes", n, "error", err)
|
||||
u.progress.add(-n)
|
||||
lastErr = err
|
||||
}
|
||||
@@ -178,6 +199,7 @@ func (u *uploader) exists(ctx context.Context, blob Blob) (bool, error) {
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
|
||||
@@ -191,58 +213,91 @@ func (u *uploader) exists(ctx context.Context, blob Blob) (bool, error) {
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
const maxInitRetries = 12
|
||||
|
||||
func (u *uploader) initUpload(ctx context.Context, blob Blob) (string, error) {
|
||||
endpoint, _ := url.Parse(fmt.Sprintf("%s/v2/%s/blobs/uploads/", u.baseURL, u.repository))
|
||||
q := endpoint.Query()
|
||||
q.Set("digest", blob.Digest)
|
||||
endpoint.RawQuery = q.Encode()
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), nil)
|
||||
req.Header.Set("User-Agent", u.userAgent)
|
||||
if *u.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+*u.token)
|
||||
}
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
|
||||
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
|
||||
if *u.token, err = u.getToken(ctx, ch); err != nil {
|
||||
return "", err
|
||||
var lastErr error
|
||||
for attempt := range maxInitRetries {
|
||||
if attempt > 0 {
|
||||
// Start at 5s and cap at 30s — the server needs real breathing
|
||||
// room when it's dropping Location headers under load.
|
||||
if err := backoff(ctx, attempt, min(5*time.Second<<uint(attempt-1), 30*time.Second)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
logutil.Trace("retrying init upload", "digest", blob.Digest, "attempt", attempt+1, "error", lastErr)
|
||||
}
|
||||
return u.initUpload(ctx, blob)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
return "", fmt.Errorf("init: status %d", resp.StatusCode)
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), nil)
|
||||
req.Header.Set("User-Agent", u.userAgent)
|
||||
if *u.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+*u.token)
|
||||
}
|
||||
|
||||
loc := resp.Header.Get("Docker-Upload-Location")
|
||||
if loc == "" {
|
||||
loc = resp.Header.Get("Location")
|
||||
}
|
||||
if loc == "" {
|
||||
return "", fmt.Errorf("no upload location")
|
||||
}
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
locURL, _ := url.Parse(loc)
|
||||
if !locURL.IsAbs() {
|
||||
base, _ := url.Parse(u.baseURL)
|
||||
locURL = base.ResolveReference(locURL)
|
||||
}
|
||||
q = locURL.Query()
|
||||
q.Set("digest", blob.Digest)
|
||||
locURL.RawQuery = q.Encode()
|
||||
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
|
||||
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
|
||||
if *u.token, err = u.getToken(ctx, ch); err != nil {
|
||||
return "", err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return locURL.String(), nil
|
||||
if resp.StatusCode == http.StatusCreated {
|
||||
// Blob was mounted or already exists — no upload needed
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
lastErr = fmt.Errorf("init upload: status %d", resp.StatusCode)
|
||||
continue
|
||||
}
|
||||
|
||||
loc := resp.Header.Get("Docker-Upload-Location")
|
||||
if loc == "" {
|
||||
loc = resp.Header.Get("Location")
|
||||
}
|
||||
if loc == "" {
|
||||
// Server returned 202 but no Location — retry, the server may
|
||||
// be under load and dropping headers.
|
||||
lastErr = fmt.Errorf("no upload location (server returned 202 without Location header)")
|
||||
continue
|
||||
}
|
||||
|
||||
locURL, _ := url.Parse(loc)
|
||||
if !locURL.IsAbs() {
|
||||
base, _ := url.Parse(u.baseURL)
|
||||
locURL = base.ResolveReference(locURL)
|
||||
}
|
||||
q = locURL.Query()
|
||||
q.Set("digest", blob.Digest)
|
||||
locURL.RawQuery = q.Encode()
|
||||
|
||||
return locURL.String(), nil
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size int64) (int64, error) {
|
||||
pr := &progressReader{reader: f, tracker: u.progress}
|
||||
// uploadURL is empty when initUpload determined the blob already exists (201 Created)
|
||||
if uploadURL == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Buffer reads for better throughput — 256KB reads instead of default 4KB
|
||||
br := bufio.NewReaderSize(f, 256*1024)
|
||||
pr := &progressReader{reader: br, tracker: u.progress}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPut, uploadURL, pr)
|
||||
req.ContentLength = size
|
||||
@@ -254,9 +309,9 @@ func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size i
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return pr.n, err
|
||||
return pr.n, fmt.Errorf("put request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { io.Copy(io.Discard, resp.Body); resp.Body.Close() }()
|
||||
|
||||
// Handle auth retry
|
||||
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
|
||||
@@ -274,7 +329,9 @@ func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size i
|
||||
loc, _ := resp.Location()
|
||||
f.Seek(0, 0)
|
||||
u.progress.add(-pr.n)
|
||||
pr2 := &progressReader{reader: f, tracker: u.progress}
|
||||
|
||||
br2 := bufio.NewReaderSize(f, 256*1024)
|
||||
pr2 := &progressReader{reader: br2, tracker: u.progress}
|
||||
|
||||
req2, _ := http.NewRequestWithContext(ctx, http.MethodPut, loc.String(), pr2)
|
||||
req2.ContentLength = size
|
||||
@@ -283,9 +340,9 @@ func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size i
|
||||
|
||||
resp2, err := u.client.Do(req2)
|
||||
if err != nil {
|
||||
return pr2.n, err
|
||||
return pr2.n, fmt.Errorf("cdn put request: %w", err)
|
||||
}
|
||||
defer resp2.Body.Close()
|
||||
defer func() { io.Copy(io.Discard, resp2.Body); resp2.Body.Close() }()
|
||||
|
||||
if resp2.StatusCode != http.StatusCreated && resp2.StatusCode != http.StatusAccepted {
|
||||
body, _ := io.ReadAll(resp2.Body)
|
||||
@@ -313,7 +370,7 @@ func (u *uploader) pushManifest(ctx context.Context, repo, ref string, manifest
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { io.Copy(io.Discard, resp.Body); resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
|
||||
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
|
||||
|
||||
Reference in New Issue
Block a user