diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 1d226330..2e0363c8 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -215,6 +215,16 @@ func calculateHashForPartialFile(file *os.File) (hash.Hash, error) { return hash, nil } +func (uri URI) checkSeverSupportsRangeHeader() (bool, error) { + url := uri.ResolveURL() + resp, err := http.Head(url) + if err != nil { + return false, err + } + defer resp.Body.Close() + return resp.Header.Get("Accept-Ranges") == "bytes", nil +} + func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error { url := uri.ResolveURL() if uri.LooksLikeOCI() { @@ -282,21 +292,23 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat return fmt.Errorf("failed to create request for %q: %v", filePath, err) } - /* TODO - * 1. ~~Mock downloads~~ - * 2. Check if server supports partial downloads - * 3. ~~Resume partial downloads~~ - * 4. Ensure progressWriter accurately reflects progress if a partial file is present - * 5. MAYBE: - * a. Delete file if calculatedSHA != sha - */ - // save partial download to dedicated file tmpFilePath := filePath + ".partial" tmpFileInfo, err := os.Stat(tmpFilePath) if err == nil { - startPos := tmpFileInfo.Size() - req.Header.Set("Range", fmt.Sprintf("bytes=%d-", startPos)) + support, err := uri.checkSeverSupportsRangeHeader() + if err != nil { + return fmt.Errorf("failed to check if uri server supports range header: %v", err) + } + if support { + startPos := tmpFileInfo.Size() + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", startPos)) + } else { + err := removePartialFile(tmpFilePath) + if err != nil { + return err + } + } } else if !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("failed to check file %q existence: %v", filePath, err) } @@ -341,6 +353,11 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat return fmt.Errorf("failed to write file %q: %v", filePath, err) } + err = os.Rename(tmpFilePath, filePath) + if err != nil { + return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err) + } + if sha != "" { // Verify SHA calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) @@ -352,11 +369,6 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat log.Debug().Msgf("SHA missing for %q. Skipping validation", filePath) } - err = os.Rename(tmpFilePath, filePath) - if err != nil { - return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err) - } - log.Info().Msgf("File %q downloaded and verified", filePath) if utils.IsArchive(filePath) { basePath := filepath.Dir(filePath)