feat: display download progress when installing models (#543)

This commit is contained in:
Ettore Di Giacinto 2023-06-08 21:33:18 +02:00 committed by GitHub
parent c9bbba4872
commit 84946e9275
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 29 deletions

View file

@ -10,10 +10,12 @@ import (
"os"
"strings"
"sync"
"time"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)
@ -23,9 +25,12 @@ type galleryOp struct {
}
type galleryOpStatus struct {
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
Progress float64 `json:"progress"`
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
}
type galleryApplier struct {
@ -43,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier {
}
}
func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger) error {
func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
url, err := req.DecodeURL()
if err != nil {
return err
@ -71,7 +76,7 @@ func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerg
config.Files = append(config.Files, req.AdditionalFiles...)
if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides); err != nil {
if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides, downloadStatus); err != nil {
return err
}
@ -99,23 +104,51 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
case <-c.Done():
return
case op := <-g.C:
g.updatestatus(op.id, &galleryOpStatus{Message: "processing"})
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
updateError := func(e error) {
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
}
if err := applyGallery(g.modelPath, op.req, cm); err != nil {
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage)
}); err != nil {
updateError(err)
continue
}
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"})
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
}
}
}()
}
var lastProgress time.Time = time.Now()
var startTime time.Time = time.Now()
func displayDownload(fileName string, current string, total string, percentage float64) {
currentTime := time.Now()
if currentTime.Sub(lastProgress) >= 5*time.Second {
lastProgress = currentTime
// calculate ETA based on percentage and elapsed time
var eta time.Duration
if percentage > 0 {
elapsed := currentTime.Sub(startTime)
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
}
if total != "" {
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
} else {
log.Debug().Msgf("Downloading: %s", current)
}
}
}
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
dat, err := os.ReadFile(s)
if err != nil {
@ -128,13 +161,14 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
}
for _, r := range requests {
if err := applyGallery(modelPath, r, cm); err != nil {
if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
return err
}
}
return nil
}
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
var requests []ApplyGalleryModelRequest
err := json.Unmarshal([]byte(s), &requests)
@ -143,7 +177,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
}
for _, r := range requests {
if err := applyGallery(modelPath, r, cm); err != nil {
if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
return err
}
}