mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-28 14:35:00 +00:00
feat: embedded model configurations, add popular model examples, refactoring (#1532)
* move downloader out * separate startup functions for preloading configuration files * docs: add popular model examples Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * shorteners * Add llava * Add mistral-openorca * Better link to build section * docs: update * fixup * Drop code dups * Minor fixups * Apply suggestions from code review Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> * ci: try to cache gRPC build during tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * ci: do not build all images for tests, just necessary * ci: cache gRPC also in release pipeline * fixes * Update model_preload_test.go Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
parent
db926896bd
commit
09e5d9007b
26 changed files with 586 additions and 150 deletions
22
pkg/assets/list.go
Normal file
22
pkg/assets/list.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package assets
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
)
|
||||
|
||||
func ListFiles(content embed.FS) (files []string) {
|
||||
fs.WalkDir(content, ".", func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
files = append(files, path)
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
26
pkg/downloader/progress.go
Normal file
26
pkg/downloader/progress.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package downloader
|
||||
|
||||
import "hash"
|
||||
|
||||
type progressWriter struct {
|
||||
fileName string
|
||||
total int64
|
||||
written int64
|
||||
downloadStatus func(string, string, string, float64)
|
||||
hash hash.Hash
|
||||
}
|
||||
|
||||
func (pw *progressWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = pw.hash.Write(p)
|
||||
pw.written += int64(n)
|
||||
|
||||
if pw.total > 0 {
|
||||
percentage := float64(pw.written) / float64(pw.total) * 100
|
||||
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
} else {
|
||||
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -1,10 +1,9 @@
|
|||
package utils
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -12,9 +11,18 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
HuggingFacePrefix = "huggingface://"
|
||||
HTTPPrefix = "http://"
|
||||
HTTPSPrefix = "https://"
|
||||
GithubURI = "github:"
|
||||
GithubURI2 = "github://"
|
||||
)
|
||||
|
||||
func GetURI(url string, f func(url string, i []byte) error) error {
|
||||
url = ConvertURL(url)
|
||||
|
||||
|
@ -52,14 +60,6 @@ func GetURI(url string, f func(url string, i []byte) error) error {
|
|||
return f(url, body)
|
||||
}
|
||||
|
||||
const (
|
||||
HuggingFacePrefix = "huggingface://"
|
||||
HTTPPrefix = "http://"
|
||||
HTTPSPrefix = "https://"
|
||||
GithubURI = "github:"
|
||||
GithubURI2 = "github://"
|
||||
)
|
||||
|
||||
func LooksLikeURL(s string) bool {
|
||||
return strings.HasPrefix(s, HTTPPrefix) ||
|
||||
strings.HasPrefix(s, HTTPSPrefix) ||
|
||||
|
@ -229,10 +229,10 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
|
|||
}
|
||||
|
||||
log.Info().Msgf("File %q downloaded and verified", filePath)
|
||||
if IsArchive(filePath) {
|
||||
if utils.IsArchive(filePath) {
|
||||
basePath := filepath.Dir(filePath)
|
||||
log.Info().Msgf("File %q is an archive, uncompressing to %s", filePath, basePath)
|
||||
if err := ExtractArchive(filePath, basePath); err != nil {
|
||||
if err := utils.ExtractArchive(filePath, basePath); err != nil {
|
||||
log.Debug().Msgf("Failed decompressing %q: %s", filePath, err.Error())
|
||||
return err
|
||||
}
|
||||
|
@ -241,32 +241,35 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
|
|||
return nil
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
fileName string
|
||||
total int64
|
||||
written int64
|
||||
downloadStatus func(string, string, string, float64)
|
||||
hash hash.Hash
|
||||
}
|
||||
// this function check if the string is an URL, if it's an URL downloads the image in memory
|
||||
// encodes it in base64 and returns the base64 string
|
||||
func GetBase64Image(s string) (string, error) {
|
||||
if strings.HasPrefix(s, "http") {
|
||||
// download the image
|
||||
resp, err := http.Get(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
func (pw *progressWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = pw.hash.Write(p)
|
||||
pw.written += int64(n)
|
||||
// read the image data into memory
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if pw.total > 0 {
|
||||
percentage := float64(pw.written) / float64(pw.total) * 100
|
||||
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
} else {
|
||||
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
|
||||
// encode the image data in base64
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
// return the base64 string
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MD5 of a string
|
||||
func MD5(s string) string {
|
||||
return fmt.Sprintf("%x", md5.Sum([]byte(s)))
|
||||
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
|
||||
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
|
||||
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
|
||||
}
|
||||
return "", fmt.Errorf("not valid string")
|
||||
}
|
||||
|
||||
func formatBytes(bytes int64) string {
|
|
@ -1,7 +1,7 @@
|
|||
package utils_test
|
||||
package downloader_test
|
||||
|
||||
import (
|
||||
. "github.com/go-skynet/LocalAI/pkg/utils"
|
||||
. "github.com/go-skynet/LocalAI/pkg/downloader"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
|
@ -6,7 +6,7 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||
"github.com/imdario/mergo"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
@ -140,7 +140,7 @@ func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryMod
|
|||
|
||||
func findGalleryURLFromReferenceURL(url string) (string, error) {
|
||||
var refFile string
|
||||
err := utils.GetURI(url, func(url string, d []byte) error {
|
||||
err := downloader.GetURI(url, func(url string, d []byte) error {
|
||||
refFile = string(d)
|
||||
if len(refFile) == 0 {
|
||||
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
|
||||
|
@ -163,7 +163,7 @@ func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error)
|
|||
}
|
||||
}
|
||||
|
||||
err := utils.GetURI(gallery.URL, func(url string, d []byte) error {
|
||||
err := downloader.GetURI(gallery.URL, func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &models)
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
package gallery
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/imdario/mergo"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -66,7 +63,7 @@ type PromptTemplate struct {
|
|||
|
||||
func GetGalleryConfigFromURL(url string) (Config, error) {
|
||||
var config Config
|
||||
err := utils.GetURI(url, func(url string, d []byte) error {
|
||||
err := downloader.GetURI(url, func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &config)
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -114,7 +111,7 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
|
|||
// Create file path
|
||||
filePath := filepath.Join(basePath, file.Filename)
|
||||
|
||||
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil {
|
||||
if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -183,54 +180,3 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
fileName string
|
||||
total int64
|
||||
written int64
|
||||
downloadStatus func(string, string, string, float64)
|
||||
hash hash.Hash
|
||||
}
|
||||
|
||||
func (pw *progressWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = pw.hash.Write(p)
|
||||
pw.written += int64(n)
|
||||
|
||||
if pw.total > 0 {
|
||||
percentage := float64(pw.written) / float64(pw.total) * 100
|
||||
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
} else {
|
||||
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return strconv.FormatInt(bytes, 10) + " B"
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func calculateSHA(filePath string) (string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, file); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", hash.Sum(nil)), nil
|
||||
}
|
||||
|
|
54
pkg/startup/model_preload.go
Normal file
54
pkg/startup/model_preload.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package startup
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/embedded"
|
||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// PreloadModelsConfigurations will preload models from the given list of URLs
|
||||
// It will download the model if it is not already present in the model path
|
||||
// It will also try to resolve if the model is an embedded model YAML configuration
|
||||
func PreloadModelsConfigurations(modelPath string, models ...string) {
|
||||
for _, url := range models {
|
||||
url = embedded.ModelShortURL(url)
|
||||
|
||||
switch {
|
||||
case embedded.ExistsInModelsLibrary(url):
|
||||
modelYAML, err := embedded.ResolveContent(url)
|
||||
// If we resolve something, just save it to disk and continue
|
||||
if err != nil {
|
||||
log.Error().Msgf("error loading model: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug().Msgf("[startup] resolved embedded model: %s", url)
|
||||
md5Name := utils.MD5(url)
|
||||
if err := os.WriteFile(filepath.Join(modelPath, md5Name)+".yaml", modelYAML, os.ModePerm); err != nil {
|
||||
log.Error().Msgf("error loading model: %s", err.Error())
|
||||
}
|
||||
case downloader.LooksLikeURL(url):
|
||||
log.Debug().Msgf("[startup] resolved model to download: %s", url)
|
||||
|
||||
// md5 of model name
|
||||
md5Name := utils.MD5(url)
|
||||
|
||||
// check if file exists
|
||||
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
|
||||
err := downloader.DownloadFile(url, filepath.Join(modelPath, md5Name)+".yaml", "", func(fileName, current, total string, percent float64) {
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Msgf("error loading model: %s", err.Error())
|
||||
}
|
||||
}
|
||||
default:
|
||||
log.Warn().Msgf("[startup] failed resolving model '%s'", url)
|
||||
}
|
||||
}
|
||||
}
|
66
pkg/startup/model_preload_test.go
Normal file
66
pkg/startup/model_preload_test.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package startup_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/pkg/startup"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Preload test", func() {
|
||||
|
||||
Context("Preloading from strings", func() {
|
||||
It("loads from embedded full-urls", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
|
||||
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
|
||||
|
||||
PreloadModelsConfigurations(tmpdir, url)
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
content, err := os.ReadFile(resultFile)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: phi-2"))
|
||||
})
|
||||
It("loads from embedded short-urls", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
url := "phi-2"
|
||||
|
||||
PreloadModelsConfigurations(tmpdir, url)
|
||||
|
||||
entry, err := os.ReadDir(tmpdir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(entry).To(HaveLen(1))
|
||||
resultFile := entry[0].Name()
|
||||
|
||||
content, err := os.ReadFile(filepath.Join(tmpdir, resultFile))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: phi-2"))
|
||||
})
|
||||
It("loads from embedded models", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
url := "mistral-openorca"
|
||||
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
|
||||
|
||||
PreloadModelsConfigurations(tmpdir, url)
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
content, err := os.ReadFile(resultFile)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: mistral-openorca"))
|
||||
})
|
||||
})
|
||||
})
|
13
pkg/startup/startup_suite_test.go
Normal file
13
pkg/startup/startup_suite_test.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package startup_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestStartup(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "LocalAI startup test")
|
||||
}
|
10
pkg/utils/hash.go
Normal file
10
pkg/utils/hash.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func MD5(s string) string {
|
||||
return fmt.Sprintf("%x", md5.Sum([]byte(s)))
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue