feat: embedded model configurations, add popular model examples, refactoring (#1532)

* move downloader out

* separate startup functions for preloading configuration files

* docs: add popular model examples

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* shorteners

* Add llava

* Add mistral-openorca

* Better link to build section

* docs: update

* fixup

* Drop code dups

* Minor fixups

* Apply suggestions from code review

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

* ci: try to cache gRPC build during tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* ci: do not build all images for tests, just necessary

* ci: cache gRPC also in release pipeline

* fixes

* Update model_preload_test.go

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
Ettore Di Giacinto 2024-01-05 17:16:33 -05:00 committed by GitHub
parent db926896bd
commit 09e5d9007b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 586 additions and 150 deletions

22
pkg/assets/list.go Normal file
View file

@ -0,0 +1,22 @@
package assets
import (
"embed"
"io/fs"
)
func ListFiles(content embed.FS) (files []string) {
fs.WalkDir(content, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
files = append(files, path)
return nil
})
return
}

View file

@ -0,0 +1,26 @@
package downloader
import "hash"
type progressWriter struct {
fileName string
total int64
written int64
downloadStatus func(string, string, string, float64)
hash hash.Hash
}
func (pw *progressWriter) Write(p []byte) (n int, err error) {
n, err = pw.hash.Write(p)
pw.written += int64(n)
if pw.total > 0 {
percentage := float64(pw.written) / float64(pw.total) * 100
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
} else {
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
}
return
}

View file

@ -1,10 +1,9 @@
package utils
package downloader
import (
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"fmt"
"hash"
"io"
"net/http"
"os"
@ -12,9 +11,18 @@ import (
"strconv"
"strings"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
const (
HuggingFacePrefix = "huggingface://"
HTTPPrefix = "http://"
HTTPSPrefix = "https://"
GithubURI = "github:"
GithubURI2 = "github://"
)
func GetURI(url string, f func(url string, i []byte) error) error {
url = ConvertURL(url)
@ -52,14 +60,6 @@ func GetURI(url string, f func(url string, i []byte) error) error {
return f(url, body)
}
const (
HuggingFacePrefix = "huggingface://"
HTTPPrefix = "http://"
HTTPSPrefix = "https://"
GithubURI = "github:"
GithubURI2 = "github://"
)
func LooksLikeURL(s string) bool {
return strings.HasPrefix(s, HTTPPrefix) ||
strings.HasPrefix(s, HTTPSPrefix) ||
@ -229,10 +229,10 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
}
log.Info().Msgf("File %q downloaded and verified", filePath)
if IsArchive(filePath) {
if utils.IsArchive(filePath) {
basePath := filepath.Dir(filePath)
log.Info().Msgf("File %q is an archive, uncompressing to %s", filePath, basePath)
if err := ExtractArchive(filePath, basePath); err != nil {
if err := utils.ExtractArchive(filePath, basePath); err != nil {
log.Debug().Msgf("Failed decompressing %q: %s", filePath, err.Error())
return err
}
@ -241,32 +241,35 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
return nil
}
type progressWriter struct {
fileName string
total int64
written int64
downloadStatus func(string, string, string, float64)
hash hash.Hash
}
// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
func GetBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
func (pw *progressWriter) Write(p []byte) (n int, err error) {
n, err = pw.hash.Write(p)
pw.written += int64(n)
// read the image data into memory
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
if pw.total > 0 {
percentage := float64(pw.written) / float64(pw.total) * 100
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
} else {
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
return
}
// MD5 of a string
func MD5(s string) string {
return fmt.Sprintf("%x", md5.Sum([]byte(s)))
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}
func formatBytes(bytes int64) string {

View file

@ -1,7 +1,7 @@
package utils_test
package downloader_test
import (
. "github.com/go-skynet/LocalAI/pkg/utils"
. "github.com/go-skynet/LocalAI/pkg/downloader"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

View file

@ -6,7 +6,7 @@ import (
"path/filepath"
"strings"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/imdario/mergo"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v2"
@ -140,7 +140,7 @@ func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryMod
func findGalleryURLFromReferenceURL(url string) (string, error) {
var refFile string
err := utils.GetURI(url, func(url string, d []byte) error {
err := downloader.GetURI(url, func(url string, d []byte) error {
refFile = string(d)
if len(refFile) == 0 {
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
@ -163,7 +163,7 @@ func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error)
}
}
err := utils.GetURI(gallery.URL, func(url string, d []byte) error {
err := downloader.GetURI(gallery.URL, func(url string, d []byte) error {
return yaml.Unmarshal(d, &models)
})
if err != nil {

View file

@ -1,14 +1,11 @@
package gallery
import (
"crypto/sha256"
"fmt"
"hash"
"io"
"os"
"path/filepath"
"strconv"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo"
"github.com/rs/zerolog/log"
@ -66,7 +63,7 @@ type PromptTemplate struct {
func GetGalleryConfigFromURL(url string) (Config, error) {
var config Config
err := utils.GetURI(url, func(url string, d []byte) error {
err := downloader.GetURI(url, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
@ -114,7 +111,7 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
// Create file path
filePath := filepath.Join(basePath, file.Filename)
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil {
if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil {
return err
}
}
@ -183,54 +180,3 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
return nil
}
type progressWriter struct {
fileName string
total int64
written int64
downloadStatus func(string, string, string, float64)
hash hash.Hash
}
func (pw *progressWriter) Write(p []byte) (n int, err error) {
n, err = pw.hash.Write(p)
pw.written += int64(n)
if pw.total > 0 {
percentage := float64(pw.written) / float64(pw.total) * 100
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
} else {
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
}
return
}
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return strconv.FormatInt(bytes, 10) + " B"
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
func calculateSHA(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return "", err
}
return fmt.Sprintf("%x", hash.Sum(nil)), nil
}

View file

@ -0,0 +1,54 @@
package startup
import (
"errors"
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/embedded"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
// PreloadModelsConfigurations will preload models from the given list of URLs
// It will download the model if it is not already present in the model path
// It will also try to resolve if the model is an embedded model YAML configuration
func PreloadModelsConfigurations(modelPath string, models ...string) {
for _, url := range models {
url = embedded.ModelShortURL(url)
switch {
case embedded.ExistsInModelsLibrary(url):
modelYAML, err := embedded.ResolveContent(url)
// If we resolve something, just save it to disk and continue
if err != nil {
log.Error().Msgf("error loading model: %s", err.Error())
continue
}
log.Debug().Msgf("[startup] resolved embedded model: %s", url)
md5Name := utils.MD5(url)
if err := os.WriteFile(filepath.Join(modelPath, md5Name)+".yaml", modelYAML, os.ModePerm); err != nil {
log.Error().Msgf("error loading model: %s", err.Error())
}
case downloader.LooksLikeURL(url):
log.Debug().Msgf("[startup] resolved model to download: %s", url)
// md5 of model name
md5Name := utils.MD5(url)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(url, filepath.Join(modelPath, md5Name)+".yaml", "", func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
})
if err != nil {
log.Error().Msgf("error loading model: %s", err.Error())
}
}
default:
log.Warn().Msgf("[startup] failed resolving model '%s'", url)
}
}
}

View file

@ -0,0 +1,66 @@
package startup_test
import (
"fmt"
"os"
"path/filepath"
. "github.com/go-skynet/LocalAI/pkg/startup"
"github.com/go-skynet/LocalAI/pkg/utils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Preload test", func() {
Context("Preloading from strings", func() {
It("loads from embedded full-urls", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
PreloadModelsConfigurations(tmpdir, url)
resultFile := filepath.Join(tmpdir, fileName)
content, err := os.ReadFile(resultFile)
Expect(err).ToNot(HaveOccurred())
Expect(string(content)).To(ContainSubstring("name: phi-2"))
})
It("loads from embedded short-urls", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
url := "phi-2"
PreloadModelsConfigurations(tmpdir, url)
entry, err := os.ReadDir(tmpdir)
Expect(err).ToNot(HaveOccurred())
Expect(entry).To(HaveLen(1))
resultFile := entry[0].Name()
content, err := os.ReadFile(filepath.Join(tmpdir, resultFile))
Expect(err).ToNot(HaveOccurred())
Expect(string(content)).To(ContainSubstring("name: phi-2"))
})
It("loads from embedded models", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
url := "mistral-openorca"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
PreloadModelsConfigurations(tmpdir, url)
resultFile := filepath.Join(tmpdir, fileName)
content, err := os.ReadFile(resultFile)
Expect(err).ToNot(HaveOccurred())
Expect(string(content)).To(ContainSubstring("name: mistral-openorca"))
})
})
})

View file

@ -0,0 +1,13 @@
package startup_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestStartup(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "LocalAI startup test")
}

10
pkg/utils/hash.go Normal file
View file

@ -0,0 +1,10 @@
package utils
import (
"crypto/md5"
"fmt"
)
func MD5(s string) string {
return fmt.Sprintf("%x", md5.Sum([]byte(s)))
}