mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-29 22:20:43 +00:00
feat(resume downloads): implement resumable downloads for interrupted transfers
- Adds support for resuming partially downloaded files - Uses HTTP Range header to continue from last byte position - Maintains download progress across interruptions - Preserves partial downloads with .partial extension - Validates SHA256 checksum after completion Signed-off-by: Saarthak Verma <saarthakverma739@gmail.com>
This commit is contained in:
parent
44d7869405
commit
a9bec0fc5f
1 changed files with 56 additions and 25 deletions
|
@ -2,7 +2,9 @@ package downloader
|
|||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -204,6 +206,15 @@ func removePartialFile(tmpFilePath string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func calculateHashForPartialFile(file *os.File) (hash.Hash, error) {
|
||||
hash := sha256.New()
|
||||
_, err := io.Copy(hash, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
|
||||
url := uri.ResolveURL()
|
||||
if uri.LooksLikeOCI() {
|
||||
|
@ -266,8 +277,32 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
|||
|
||||
log.Info().Msgf("Downloading %q", url)
|
||||
|
||||
// Download file
|
||||
resp, err := http.Get(url)
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request for %q: %v", filePath, err)
|
||||
}
|
||||
|
||||
/* TODO
|
||||
* 1. ~~Mock downloads~~
|
||||
* 2. Check if server supports partial downloads
|
||||
* 3. ~~Resume partial downloads~~
|
||||
* 4. Ensure progressWriter accurately reflects progress if a partial file is present
|
||||
* 5. MAYBE:
|
||||
* a. Delete file if calculatedSHA != sha
|
||||
*/
|
||||
|
||||
// save partial download to dedicated file
|
||||
tmpFilePath := filePath + ".partial"
|
||||
tmpFileInfo, err := os.Stat(tmpFilePath)
|
||||
if err == nil {
|
||||
startPos := tmpFileInfo.Size()
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", startPos))
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("failed to check file %q existence: %v", filePath, err)
|
||||
}
|
||||
|
||||
// Start the request
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download file %q: %v", filePath, err)
|
||||
}
|
||||
|
@ -282,33 +317,29 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create parent directory for file %q: %v", filePath, err)
|
||||
}
|
||||
/** Enabling partial downloads
|
||||
* - Do I remove the partial file
|
||||
* -
|
||||
*/
|
||||
// save partial download to dedicated file
|
||||
fmt.Printf("DELETEING PARTIAL FILE")
|
||||
tmpFilePath := filePath + ".partial"
|
||||
|
||||
// remove tmp file
|
||||
err = removePartialFile(tmpFilePath)
|
||||
// Create and write file
|
||||
outFile, err := os.OpenFile(tmpFilePath, os.O_APPEND|os.O_RDWR|os.O_CREATE, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create and write file content
|
||||
outFile, err := os.Create(tmpFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file %q: %v", tmpFilePath, err)
|
||||
return fmt.Errorf("failed to create / open file %q: %v", tmpFilePath, err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
outFileInfo, err := outFile.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file info: %v", err)
|
||||
}
|
||||
fileSize := outFileInfo.Size()
|
||||
hash, err := calculateHashForPartialFile(outFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate hash for partial file")
|
||||
}
|
||||
progress := &progressWriter{
|
||||
fileName: tmpFilePath,
|
||||
total: resp.ContentLength,
|
||||
hash: sha256.New(),
|
||||
hash: hash,
|
||||
fileNo: fileN,
|
||||
totalFiles: total,
|
||||
written: fileSize,
|
||||
downloadStatus: downloadStatus,
|
||||
}
|
||||
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
|
||||
|
@ -316,11 +347,6 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
|||
return fmt.Errorf("failed to write file %q: %v", filePath, err)
|
||||
}
|
||||
|
||||
err = os.Rename(tmpFilePath, filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err)
|
||||
}
|
||||
|
||||
if sha != "" {
|
||||
// Verify SHA
|
||||
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
|
||||
|
@ -332,6 +358,11 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
|
|||
log.Debug().Msgf("SHA missing for %q. Skipping validation", filePath)
|
||||
}
|
||||
|
||||
err = os.Rename(tmpFilePath, filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err)
|
||||
}
|
||||
|
||||
log.Info().Msgf("File %q downloaded and verified", filePath)
|
||||
if utils.IsArchive(filePath) {
|
||||
basePath := filepath.Dir(filePath)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue