Gallery repository (#663)

Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-06-24 08:18:17 +02:00 committed by GitHub
parent 2a45a99737
commit 60db5957d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 644 additions and 194 deletions

99
pkg/gallery/gallery.go Normal file
View file

@ -0,0 +1,99 @@
package gallery
import (
"fmt"
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo"
"gopkg.in/yaml.v2"
)
// Gallery identifies a remote model gallery: an index document that lists installable models.
type Gallery struct {
// URL is the location of the gallery index; it is fetched with utils.GetURI.
URL string `json:"url" yaml:"url"`
// Name is the gallery identifier, used as the prefix in "<gallery>@<model>" names.
Name string `json:"name" yaml:"name"`
}
// InstallModelFromGallery installs the model identified by name, which must be
// in the form "<gallery name>@<model name>", looking it up across galleries.
//
// The request req customizes the installation: a non-empty req.Name replaces
// the name the model is installed under, req.AdditionalFiles are added to the
// files listed by the model definition, and req.Overrides are merged on top of
// the overrides declared by the gallery entry (request values win).
// downloadStatus is handed to InstallModel unchanged.
//
// An error is returned when the galleries cannot be listed, when no entry
// matches name, or when fetching or installing the model fails.
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
	available, err := AvailableGalleryModels(galleries, basePath)
	if err != nil {
		return err
	}

	// Locate the first gallery entry whose qualified name matches the request.
	var selected *GalleryModel
	for _, candidate := range available {
		if candidate.Gallery.Name+"@"+candidate.Name == name {
			selected = candidate
			break
		}
	}
	if selected == nil {
		return fmt.Errorf("no model found with name %q", name)
	}

	// Fetch the model definition referenced by the gallery entry.
	var config Config
	if err := selected.Get(&config); err != nil {
		return err
	}

	if req.Name != "" {
		selected.Name = req.Name
	}

	// Files from the request come first, then the ones declared by the gallery entry.
	config.Files = append(append(config.Files, req.AdditionalFiles...), selected.AdditionalFiles...)

	// TODO model.Overrides could be merged with user overrides (not defined yet)
	if err := mergo.Merge(&selected.Overrides, req.Overrides, mergo.WithOverride); err != nil {
		return err
	}

	return InstallModel(basePath, selected.Name, &config, selected.Overrides, downloadStatus)
}
// AvailableGalleryModels returns the models advertised by every gallery in
// galleries, concatenated in gallery order.
//
// A gallery is a remotely hosted (for example on GitHub) index document that
// lists models which can be downloaded, optionally with overrides that define
// a new model setting. basePath is the local models directory and is used to
// flag the entries that are already installed.
//
// The first gallery that fails to load aborts the listing and its error is
// returned with a nil slice.
func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryModel, error) {
	var all []*GalleryModel

	for i := range galleries {
		fetched, err := getGalleryModels(galleries[i], basePath)
		if err != nil {
			return nil, err
		}
		all = append(all, fetched...)
	}

	return all, nil
}
// getGalleryModels downloads the index of a single gallery and returns its
// entries, each annotated with the gallery it came from and with whether it is
// already installed under basePath.
//
// Null entries in the index are dropped: they decode to nil pointers and would
// otherwise panic here or in callers that dereference the result. On failure
// the returned slice is nil and the error names the gallery that could not be
// loaded.
func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) {
	var fetched []*GalleryModel
	err := utils.GetURI(gallery.URL, func(d []byte) error {
		return yaml.Unmarshal(d, &fetched)
	})
	if err != nil {
		return nil, fmt.Errorf("loading gallery %q from %q: %w", gallery.Name, gallery.URL, err)
	}

	// Always return a non-nil slice on success, even for an empty index.
	models := make([]*GalleryModel, 0, len(fetched))
	for _, model := range fetched {
		if model == nil {
			continue
		}

		// Add gallery to models
		model.Gallery = gallery

		// we check if the model was already installed by checking if the config file exists
		// TODO: (what to do if the model doesn't install a config file?)
		if _, err := os.Stat(filepath.Join(basePath, fmt.Sprintf("%s.yaml", model.Name))); err == nil {
			model.Installed = true
		}

		models = append(models, model)
	}

	return models, nil
}

View file

@ -81,7 +81,7 @@ func ReadConfigFile(filePath string) (*Config, error) {
return &config, nil
}
func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0755)
if err != nil {
@ -171,6 +171,7 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
// Verify SHA
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
if calculatedSHA != file.SHA256 {
log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
}
} else {
@ -178,6 +179,13 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
}
log.Debug().Msgf("File %q downloaded and verified", file.Filename)
if utils.IsArchive(filePath) {
log.Debug().Msgf("File %q is an archive, uncompressing to %s", file.Filename, basePath)
if err := utils.ExtractArchive(filePath, basePath); err != nil {
log.Debug().Msgf("Failed decompressing %q: %s", file.Filename, err.Error())
return err
}
}
}
// Write prompt template contents to separate files
@ -211,33 +219,37 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
return err
}
configFilePath := filepath.Join(basePath, name+".yaml")
// write config file
if len(configOverrides) != 0 || len(config.ConfigFile) != 0 {
configFilePath := filepath.Join(basePath, name+".yaml")
// Read and update config file as map[string]interface{}
configMap := make(map[string]interface{})
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
if err != nil {
return fmt.Errorf("failed to unmarshal config YAML: %v", err)
// Read and update config file as map[string]interface{}
configMap := make(map[string]interface{})
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
if err != nil {
return fmt.Errorf("failed to unmarshal config YAML: %v", err)
}
configMap["name"] = name
if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil {
return err
}
// Write updated config file
updatedConfigYAML, err := yaml.Marshal(configMap)
if err != nil {
return fmt.Errorf("failed to marshal updated config YAML: %v", err)
}
err = os.WriteFile(configFilePath, updatedConfigYAML, 0644)
if err != nil {
return fmt.Errorf("failed to write updated config file: %v", err)
}
log.Debug().Msgf("Written config file %s", configFilePath)
}
configMap["name"] = name
if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil {
return err
}
// Write updated config file
updatedConfigYAML, err := yaml.Marshal(configMap)
if err != nil {
return fmt.Errorf("failed to marshal updated config YAML: %v", err)
}
err = os.WriteFile(configFilePath, updatedConfigYAML, 0644)
if err != nil {
return fmt.Errorf("failed to write updated config file: %v", err)
}
log.Debug().Msgf("Written config file %s", configFilePath)
return nil
}

View file

@ -1,6 +1,7 @@
package gallery_test
import (
"io/ioutil"
"os"
"path/filepath"
@ -19,7 +20,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@ -38,6 +39,51 @@ var _ = Describe("Model test", func() {
Expect(content["context_size"]).To(Equal(1024))
})
It("applies model from gallery correctly", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir)
gallery := []GalleryModel{{
Name: "bert",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
}}
out, err := yaml.Marshal(gallery)
Expect(err).ToNot(HaveOccurred())
err = ioutil.WriteFile(filepath.Join(tempdir, "gallery_simple.yaml"), out, 0644)
Expect(err).ToNot(HaveOccurred())
galleries := []Gallery{
{
Name: "test",
URL: "file://" + filepath.Join(tempdir, "gallery_simple.yaml"),
},
}
models, err := AvailableGalleryModels(galleries, tempdir)
Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1))
Expect(models[0].Name).To(Equal("bert"))
Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"))
Expect(models[0].Installed).To(BeFalse())
err = InstallModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {})
Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
models, err = AvailableGalleryModels(galleries, tempdir)
Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1))
Expect(models[0].Installed).To(BeTrue())
})
It("renames model correctly", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
@ -45,7 +91,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@ -61,7 +107,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@ -87,7 +133,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
Expect(err).To(HaveOccurred())
})
})

74
pkg/gallery/request.go Normal file
View file

@ -0,0 +1,74 @@
package gallery
import (
"fmt"
"net/url"
"strings"
"github.com/go-skynet/LocalAI/pkg/utils"
"gopkg.in/yaml.v2"
)
// GalleryModel describes one model entry of a gallery index. The same type is
// also used as the install request passed to InstallModelFromGallery.
type GalleryModel struct {
// URL locates the model definition; DecodeURL accepts "github:", "http://", "https://" and "file://" forms.
URL string `json:"url" yaml:"url"`
// Name is the model name; it is matched as "<gallery>@<name>" and passed to InstallModel as the name to install under.
Name string `json:"name" yaml:"name"`
// Overrides are configuration values merged and passed to InstallModel as config overrides.
Overrides map[string]interface{} `json:"overrides" yaml:"overrides"`
// AdditionalFiles are extra files appended to the files listed by the model definition.
AdditionalFiles []File `json:"files" yaml:"files"`
// Gallery is the gallery this entry was loaded from; it is filled in when the gallery index is read.
Gallery Gallery `json:"gallery" yaml:"gallery"`
// Installed reports whether "<Name>.yaml" already existed in the models directory when the gallery was listed.
Installed bool `json:"installed" yaml:"installed"`
}
// githubURI is the scheme prefix of the "github:org/project/path[@branch]" shorthand URLs.
const githubURI = "github:"
// DecodeURL resolves the model URL into a location that can be fetched.
//
// Supported forms:
//   - "github:<org>/<project>/<path>[@<branch>]" is expanded to the matching
//     raw.githubusercontent.com URL; the branch defaults to "main" when it is
//     missing or empty.
//   - "http://" and "https://" URLs are parsed and returned in normalized form.
//   - "file://" URLs are returned unchanged.
//
// Any other input, or a github shorthand that does not contain at least
// "<org>/<project>", yields an error.
func (request GalleryModel) DecodeURL() (string, error) {
	input := request.URL

	switch {
	case strings.HasPrefix(input, githubURI):
		// Strip only the scheme prefix so that a ':' later in the path is preserved.
		spec := strings.TrimPrefix(input, githubURI)

		repoParts := strings.Split(spec, "@")
		branch := "main"
		if len(repoParts) > 1 && repoParts[1] != "" {
			branch = repoParts[1]
		}

		repoPath := strings.Split(repoParts[0], "/")
		if len(repoPath) < 2 {
			// Without this guard, indexing the project segment below would panic.
			return "", fmt.Errorf("invalid github URL format: %s", input)
		}
		org, project := repoPath[0], repoPath[1]
		projectPath := strings.Join(repoPath[2:], "/")

		return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath), nil

	case strings.HasPrefix(input, "http://"), strings.HasPrefix(input, "https://"):
		// Handle regular URLs
		u, err := url.Parse(input)
		if err != nil {
			return "", fmt.Errorf("invalid URL: %w", err)
		}
		return u.String(), nil

	case strings.HasPrefix(input, "file://"):
		// Local file paths are passed through untouched.
		return input, nil

	default:
		return "", fmt.Errorf("invalid URL format: %s", input)
	}
}
// Get resolves the model URL, downloads the document it points to and decodes
// the YAML payload into i, which must be a pointer to the destination value.
// It fails if the URL cannot be decoded, the document cannot be retrieved or
// the payload is not valid YAML for i.
func (request GalleryModel) Get(i interface{}) error {
	location, err := request.DecodeURL()
	if err != nil {
		return err
	}

	decode := func(payload []byte) error {
		return yaml.Unmarshal(payload, i)
	}
	return utils.GetURI(location, decode)
}

View file

@ -0,0 +1,42 @@
package gallery_test
import (
. "github.com/go-skynet/LocalAI/pkg/gallery"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// example is a minimal decode target that captures only the "name" key of a fetched YAML document.
type example struct {
Name string `yaml:"name"`
}
// Specs covering how GalleryModel resolves and fetches model URLs.
var _ = Describe("Gallery API tests", func() {
	Context("requests", func() {
		// Downloads the referenced document and checks the decoded content.
		It("fetches a model from github with a branch", func() {
			req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
			var e example
			err := req.Get(&e)
			Expect(err).ToNot(HaveOccurred())
			Expect(e.Name).To(Equal("gpt4all-j"))
		})

		// An explicit "@main" suffix selects the branch used in the raw URL.
		It("parses github with a branch", func() {
			req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
			str, err := req.DecodeURL()
			Expect(err).ToNot(HaveOccurred())
			Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
		})

		// Without a suffix the branch falls back to "main".
		It("parses github without a branch", func() {
			req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"}
			str, err := req.DecodeURL()
			Expect(err).ToNot(HaveOccurred())
			Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
		})

		// Plain HTTPS URLs are returned as they are.
		It("parses URLS", func() {
			req := GalleryModel{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"}
			str, err := req.DecodeURL()
			Expect(err).ToNot(HaveOccurred())
			Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
		})
	})
})

56
pkg/utils/untar.go Normal file
View file

@ -0,0 +1,56 @@
package utils
import (
"fmt"
"github.com/mholt/archiver/v3"
)
// IsArchive reports whether file, judged by its extension alone, is in a
// format that the archiver package can unarchive. Unknown extensions and
// formats that do not implement archiver.Unarchiver yield false.
func IsArchive(file string) bool {
	format, err := archiver.ByExtension(file)
	if err != nil {
		return false
	}

	if _, canUnarchive := format.(archiver.Unarchiver); canUnarchive {
		return true
	}
	return false
}
// ExtractArchive unpacks archive into the directory dst. The archive format is
// chosen from the file extension.
//
// Tar-based formats (plain or compressed) are extracted with a tar
// configuration that overwrites existing files, creates missing directories,
// adds no implicit top-level folder and continues past per-entry errors.
// Other formats use the defaults returned by archiver.ByExtension.
//
// An error is returned when the extension is unknown, when the format is not
// an archive format, or when extraction fails.
func ExtractArchive(archive, dst string) error {
	uaIface, err := archiver.ByExtension(archive)
	if err != nil {
		return err
	}

	un, ok := uaIface.(archiver.Unarchiver)
	if !ok {
		return fmt.Errorf("format specified by source filename is not an archive format: %s (%T)", archive, uaIface)
	}

	mytar := &archiver.Tar{
		OverwriteExisting:      true,
		MkdirAll:               true,
		ImplicitTopLevelFolder: false,
		ContinueOnError:        true,
	}

	switch v := uaIface.(type) {
	case *archiver.Tar:
		// Replace the unarchiver itself: reassigning uaIface would have no
		// effect because un was already derived from it above.
		un = mytar
	case *archiver.TarBrotli:
		v.Tar = mytar
	case *archiver.TarBz2:
		v.Tar = mytar
	case *archiver.TarGz:
		v.Tar = mytar
	case *archiver.TarLz4:
		v.Tar = mytar
	case *archiver.TarSz:
		v.Tar = mytar
	case *archiver.TarXz:
		v.Tar = mytar
	case *archiver.TarZstd:
		v.Tar = mytar
	}

	return un.Unarchive(archive, dst)
}

37
pkg/utils/uri.go Normal file
View file

@ -0,0 +1,37 @@
package utils
import (
	"fmt"
	"io/ioutil"
	"net/http"
	"strings"
)
// GetURI retrieves the document at url and passes its raw bytes to f,
// returning whatever f returns.
//
// URLs starting with "file://" are read from the local filesystem, using the
// remainder of the URL as the path. Anything else is fetched with an HTTP GET.
// A response whose status is outside the 2xx range is reported as an error and
// f is not called, so an error page is never mistaken for the document.
func GetURI(url string, f func(i []byte) error) error {
	if strings.HasPrefix(url, "file://") {
		body, err := ioutil.ReadFile(strings.TrimPrefix(url, "file://"))
		if err != nil {
			return err
		}
		return f(body)
	}

	// Send a GET request to the URL
	response, err := http.Get(url)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices {
		return fmt.Errorf("fetching %q: unexpected HTTP status %s", url, response.Status)
	}

	// Read the response body
	body, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return err
	}

	return f(body)
}