diff --git a/pkg/downloader/uri_test.go b/pkg/downloader/uri_test.go index bcf13a32..6976c9b4 100644 --- a/pkg/downloader/uri_test.go +++ b/pkg/downloader/uri_test.go @@ -48,42 +48,77 @@ var _ = Describe("Gallery API tests", func() { }) }) +type RangeHeaderError struct { + msg string +} + +func (e *RangeHeaderError) Error() string { return e.msg } + var _ = Describe("Download Test", func() { var mockData []byte var mockDataSha string var filePath string - var getMockServer = func() *httptest.Server { + extractRangeHeader := func(rangeString string) (int, int, error) { + regex := regexp.MustCompile(`^bytes=(\d+)-(\d+|)$`) + matches := regex.FindStringSubmatch(rangeString) + rangeErr := RangeHeaderError{msg: "invalid / ill-formatted range"} + if matches == nil { + return -1, -1, &rangeErr + } + startPos, err := strconv.Atoi(matches[1]) + if err != nil { + return -1, -1, err + } + + endPos := -1 + if matches[2] != "" { + endPos, err = strconv.Atoi(matches[2]) + if err != nil { + return -1, -1, err + } + endPos += 1 // because range is inclusive in rangeString + } + return startPos, endPos, nil + } + + getMockServer := func(supportsRangeHeader bool) *httptest.Server { mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "HEAD" && r.Method != "GET" { + w.WriteHeader(http.StatusNotFound) + return + } + if r.Method == "HEAD" { + if supportsRangeHeader { + w.Header().Add("Accept-Ranges", "bytes") + } + w.WriteHeader(http.StatusOK) + return + } + // GET method + startPos := 0 + endPos := len(mockData) + var err error var respData []byte rangeString := r.Header.Get("Range") if rangeString != "" { - regex := regexp.MustCompile(`^bytes=(\d+)-(\d+|)$`) - matches := regex.FindStringSubmatch(rangeString) - if matches == nil { - w.WriteHeader(http.StatusBadRequest) - return - } - startPos := 0 - endPos := len(mockData) - var err error - if matches[1] != "" { - startPos, err = strconv.Atoi(matches[1]) + startPos, endPos, err = extractRangeHeader(rangeString) + if err != nil { + if _, ok := err.(*RangeHeaderError); ok { + w.WriteHeader(http.StatusBadRequest) + return + } Expect(err).ToNot(HaveOccurred()) } - if matches[2] != "" { - endPos, err = strconv.Atoi(matches[2]) - Expect(err).ToNot(HaveOccurred()) - endPos += 1 + if endPos == -1 { + endPos = len(mockData) } if startPos < 0 || startPos >= len(mockData) || endPos < 0 || endPos > len(mockData) || startPos > endPos { w.WriteHeader(http.StatusBadRequest) return } - respData = mockData[startPos:endPos] - } else { - respData = mockData } + respData = mockData[startPos:endPos] w.WriteHeader(http.StatusOK) w.Write(respData) })) @@ -107,17 +142,15 @@ var _ = Describe("Download Test", func() { Context("URI DownloadFile", func() { It("fetches files from mock server", func() { - mockServer := getMockServer() + mockServer := getMockServer(true) defer mockServer.Close() uri := URI(mockServer.URL) err := uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {}) Expect(err).ToNot(HaveOccurred()) - err = os.Remove(filePath) // cleanup, also checks existance of filePath` - Expect(err).ToNot(HaveOccurred()) }) It("resumes partially downloaded files", func() { - mockServer := getMockServer() + mockServer := getMockServer(true) defer mockServer.Close() uri := URI(mockServer.URL) // Create a partial file @@ -128,9 +161,25 @@ var _ = Describe("Download Test", func() { Expect(err).ToNot(HaveOccurred()) err = uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {}) Expect(err).ToNot(HaveOccurred()) - err = os.Remove(filePath) // cleanup, also checks existance of filePath` + }) + + It("restarts download from 0 if server doesn't support Range header", func() { + mockServer := getMockServer(false) + defer mockServer.Close() + uri := URI(mockServer.URL) + // Create a partial file + tmpFilePath := filePath + ".partial" + file, err := os.OpenFile(tmpFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + Expect(err).ToNot(HaveOccurred()) + _, err = file.Write(mockData[0:10000]) + Expect(err).ToNot(HaveOccurred()) + err = uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {}) Expect(err).ToNot(HaveOccurred()) }) - // It("deletes partial file if after completion hash of downloaded file doesn't match hash of the file in the server") + }) + + AfterEach(func() { + os.Remove(filePath) // cleanup, also checks existance of filePath` + os.Remove(filePath + ".partial") }) })