feat(gallery): uniform download from CLI (#2559)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-06-13 16:12:46 +02:00 committed by GitHub
parent f183fec232
commit 7b205510f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 89 additions and 51 deletions

View file

@ -2,19 +2,24 @@ package startup
import (
"errors"
"fmt"
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/embedded"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
// PreloadModelsConfigurations will preload models from the given list of URLs
// InstallModels will preload models from the given list of URLs and galleries
// It will download the model if it is not already present in the model path
// It will also try to resolve if the model is an embedded model YAML configuration
func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) {
func InstallModels(galleries []gallery.Gallery, modelLibraryURL string, modelPath string, downloadStatus func(string, string, string, float64), models ...string) error {
// create an error that groups all errors
var err error
for _, url := range models {
// As a best effort, try to resolve the model from the remote library
@ -32,18 +37,20 @@ func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, model
url = embedded.ModelShortURL(url)
switch {
case embedded.ExistsInModelsLibrary(url):
modelYAML, err := embedded.ResolveContent(url)
modelYAML, e := embedded.ResolveContent(url)
// If we resolve something, just save it to disk and continue
if err != nil {
log.Error().Err(err).Msg("error resolving model content")
if e != nil {
log.Error().Err(e).Msg("error resolving model content")
err = errors.Join(err, e)
continue
}
log.Debug().Msgf("[startup] resolved embedded model: %s", url)
md5Name := utils.MD5(url)
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil {
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); err != nil {
log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
err = errors.Join(err, e)
}
case downloader.LooksLikeURL(url):
log.Debug().Msgf("[startup] resolved model to download: %s", url)
@ -52,34 +59,70 @@ func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, model
md5Name := utils.MD5(url)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
if _, e := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(e, os.ErrNotExist) {
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
err := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
})
if err != nil {
log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
if e != nil {
log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
err = errors.Join(err, e)
}
}
default:
if _, err := os.Stat(url); err == nil {
if _, e := os.Stat(url); e == nil {
log.Debug().Msgf("[startup] resolved local model: %s", url)
// copy to modelPath
md5Name := utils.MD5(url)
modelYAML, err := os.ReadFile(url)
if err != nil {
log.Error().Err(err).Str("filepath", url).Msg("error reading model definition")
modelYAML, e := os.ReadFile(url)
if e != nil {
log.Error().Err(e).Str("filepath", url).Msg("error reading model definition")
err = errors.Join(err, e)
continue
}
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil {
if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); e != nil {
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s")
err = errors.Join(err, e)
}
} else {
log.Warn().Msgf("[startup] failed resolving model '%s'", url)
// Check if it's a model gallery, or print a warning
e, found := installModel(galleries, url, modelPath, downloadStatus)
if e != nil && found {
log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url)
err = errors.Join(err, e)
} else if !found {
log.Warn().Msgf("[startup] failed resolving model '%s'", url)
err = errors.Join(err, fmt.Errorf("failed resolving model '%s'", url))
}
}
}
}
return err
}
func installModel(galleries []gallery.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64)) (error, bool) {
models, err := gallery.AvailableGalleryModels(galleries, modelPath)
if err != nil {
return err, false
}
model := gallery.FindModel(models, modelName, modelPath)
if model == nil {
return err, false
}
if downloadStatus == nil {
downloadStatus = utils.DisplayDownloadFunction
}
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
err = gallery.InstallModelFromGallery(galleries, modelName, modelPath, gallery.GalleryModel{}, downloadStatus)
if err != nil {
return err, true
}
return nil, true
}

View file

@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/pkg/gallery"
. "github.com/go-skynet/LocalAI/pkg/startup"
"github.com/go-skynet/LocalAI/pkg/utils"
@ -21,7 +22,7 @@ var _ = Describe("Preload test", func() {
libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719")
PreloadModelsConfigurations(libraryURL, tmpdir, "phi-2")
InstallModels([]gallery.Gallery{}, libraryURL, tmpdir, nil, "phi-2")
resultFile := filepath.Join(tmpdir, fileName)
@ -37,7 +38,7 @@ var _ = Describe("Preload test", func() {
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
PreloadModelsConfigurations("", tmpdir, url)
InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url)
resultFile := filepath.Join(tmpdir, fileName)
@ -51,7 +52,7 @@ var _ = Describe("Preload test", func() {
Expect(err).ToNot(HaveOccurred())
url := "phi-2"
PreloadModelsConfigurations("", tmpdir, url)
InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url)
entry, err := os.ReadDir(tmpdir)
Expect(err).ToNot(HaveOccurred())
@ -69,7 +70,7 @@ var _ = Describe("Preload test", func() {
url := "mistral-openorca"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
PreloadModelsConfigurations("", tmpdir, url)
InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url)
resultFile := filepath.Join(tmpdir, fileName)