From b6b8ab6c219415576431649bcb9999e1f2e0ef22 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 11 Jul 2024 15:04:05 +0200 Subject: [PATCH] feat(models): pull models from urls (#2750) * feat(models): pull models from urls When using `run` now we can point directly to hf models via URL, for instance: ```bash local-ai run huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf ``` Will pull the gguf model and place it in the models folder - of course this depends on the fact that the gguf file should be automatically detected by our guesser mechanism in order to this to make effective. Similarly now galleries can refer to single files in the API requests. This also changes the download code and `yaml` files now are treated in the same way, so now config files are saved with the appropriate name (and not hashed anymore). Signed-off-by: Ettore Di Giacinto * Adapt tests Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- pkg/startup/model_preload.go | 48 ++++++++++++++++++++++++++----- pkg/startup/model_preload_test.go | 18 ++++++++++-- 2 files changed, 57 insertions(+), 9 deletions(-) diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index 74a10e9e..9fa890b0 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -3,6 +3,7 @@ package startup import ( "errors" "fmt" + "net/url" "os" "path/filepath" "strings" @@ -77,19 +78,35 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName) case downloader.LooksLikeURL(url): - log.Debug().Msgf("[startup] resolved model to download: %s", url) + log.Debug().Msgf("[startup] downloading %s", url) - // md5 of model name - md5Name := utils.MD5(url) + // Extract filename from URL + fileName, e := filenameFromUrl(url) + if e != nil || fileName == "" { + fileName = utils.MD5(url) + if strings.HasSuffix(url, ".yaml") || strings.HasSuffix(url, ".yml") { + fileName = fileName + ".yaml" + } + log.Warn().Err(e).Str("url", url).Msg("error extracting filename from URL") + //err = errors.Join(err, e) + //continue + } + + modelPath := filepath.Join(modelPath, fileName) + + if e := utils.VerifyPath(fileName, modelPath); e != nil { + log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path") + err = errors.Join(err, e) + continue + } // check if file exists - if _, e := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(e, os.ErrNotExist) { - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { + if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) { + e := downloader.DownloadFile(url, modelPath, "", 0, 0, func(fileName, current, total string, percent float64) { utils.DisplayDownloadFunction(fileName, current, total, percent) }) if e != nil { - log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") + log.Error().Err(e).Str("url", url).Str("filepath", modelPath).Msg("error downloading model") err = errors.Join(err, e) } } @@ -150,3 +167,20 @@ func installModel(galleries []config.Gallery, modelName, modelPath string, downl return nil, true } + +func filenameFromUrl(urlstr string) (string, error) { + // strip anything after @ + if strings.Contains(urlstr, "@") { + urlstr = strings.Split(urlstr, "@")[0] + } + + u, err := url.Parse(urlstr) + if err != nil { + return "", fmt.Errorf("error due to parsing url: %w", err) + } + x, err := url.QueryUnescape(u.EscapedPath()) + if err != nil { + return "", fmt.Errorf("error due to escaping: %w", err) + } + return filepath.Base(x), nil +} diff --git a/pkg/startup/model_preload_test.go b/pkg/startup/model_preload_test.go index 939ad1a2..869fcd3e 100644 --- a/pkg/startup/model_preload_test.go +++ b/pkg/startup/model_preload_test.go @@ -20,7 +20,7 @@ var _ = Describe("Preload test", func() { tmpdir, err := os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml" - fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719") + fileName := fmt.Sprintf("%s.yaml", "phi-2") InstallModels([]config.Gallery{}, libraryURL, tmpdir, true, nil, "phi-2") @@ -36,7 +36,7 @@ var _ = Describe("Preload test", func() { tmpdir, err := os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml" - fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) + fileName := fmt.Sprintf("%s.yaml", "phi-2") InstallModels([]config.Gallery{}, "", tmpdir, true, nil, url) @@ -79,5 +79,19 @@ var _ = Describe("Preload test", func() { Expect(string(content)).To(ContainSubstring("name: mistral-openorca")) }) + It("downloads from urls", func() { + tmpdir, err := os.MkdirTemp("", "") + Expect(err).ToNot(HaveOccurred()) + url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf" + fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K") + + err = InstallModels([]config.Gallery{}, "", tmpdir, false, nil, url) + Expect(err).ToNot(HaveOccurred()) + + resultFile := filepath.Join(tmpdir, fileName) + + _, err = os.Stat(resultFile) + Expect(err).ToNot(HaveOccurred()) + }) }) })