Refactoring

Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
mudler 2023-06-24 00:27:52 +02:00
parent 30535d9832
commit d9a1fafffe
6 changed files with 175 additions and 188 deletions

View file

@ -46,7 +46,8 @@ func newGalleryApplier(modelPath string) *galleryApplier {
} }
} }
func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { // prepareModel applies a
func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
var config gallery.Config var config gallery.Config
err := req.Get(&config) err := req.Get(&config)
@ -56,21 +57,16 @@ func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger,
config.Files = append(config.Files, req.AdditionalFiles...) config.Files = append(config.Files, req.AdditionalFiles...)
if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides, downloadStatus); err != nil { return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
return err
}
// Reload models
return cm.LoadConfigs(modelPath)
} }
func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) { func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
g.Lock() g.Lock()
defer g.Unlock() defer g.Unlock()
g.statuses[s] = op g.statuses[s] = op
} }
func (g *galleryApplier) getstatus(s string) *galleryOpStatus { func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
g.Lock() g.Lock()
defer g.Unlock() defer g.Unlock()
@ -84,39 +80,40 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
case <-c.Done(): case <-c.Done():
return return
case op := <-g.C: case op := <-g.C:
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error
updateError := func(e error) { updateError := func(e error) {
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()}) g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
} }
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage)
}
var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.galleryName != "" { if op.galleryName != "" {
if err := gallery.ApplyModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, func(fileName string, current string, total string, percentage float64) { err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage)
}); err != nil {
updateError(err)
continue
}
// Reload models
err := cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
} else { } else {
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) { err = prepareModel(g.modelPath, op.req, cm, progressCallback)
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage)
}); err != nil {
updateError(err)
continue
}
} }
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100}) if err != nil {
updateError(err)
continue
}
// Reload models
err = cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
} }
} }
}() }()
@ -159,7 +156,7 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
} }
for _, r := range requests { for _, r := range requests {
if err := applyGallery(modelPath, r, cm, displayDownload); err != nil { if err := prepareModel(modelPath, r, cm, displayDownload); err != nil {
return err return err
} }
} }
@ -175,7 +172,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
} }
for _, r := range requests { for _, r := range requests {
if err := applyGallery(modelPath, r, cm, displayDownload); err != nil { if err := prepareModel(modelPath, r, cm, displayDownload); err != nil {
return err return err
} }
} }
@ -186,7 +183,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
status := g.getstatus(c.Params("uuid")) status := g.getStatus(c.Params("uuid"))
if status == nil { if status == nil {
return fmt.Errorf("could not find any status for ID") return fmt.Errorf("could not find any status for ID")
} }
@ -229,7 +226,7 @@ func listModelFromGallery(galleries []gallery.Gallery) func(c *fiber.Ctx) error
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", galleries) log.Debug().Msgf("Listing models from galleries: %+v", galleries)
models, err := gallery.AvailableModels(galleries) models, err := gallery.AvailableGalleryModels(galleries)
if err != nil { if err != nil {
return err return err
} }

92
pkg/gallery/gallery.go Normal file
View file

@ -0,0 +1,92 @@
package gallery
import (
"fmt"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo"
"gopkg.in/yaml.v2"
)
type Gallery struct {
URL string `json:"url" yaml:"url"`
Name string `json:"name" yaml:"name"`
}
// Installs a model from the gallery (galleryname@modelname)
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
models, err := AvailableGalleryModels(galleries)
if err != nil {
return err
}
applyModel := func(model *GalleryModel) error {
var config Config
err := model.Get(&config)
if err != nil {
return err
}
if req.Name != "" {
model.Name = req.Name
}
config.Files = append(config.Files, req.AdditionalFiles...)
config.Files = append(config.Files, model.AdditionalFiles...)
// TODO model.Overrides could be merged with user overrides (not defined yet)
if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil {
return err
}
if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil {
return err
}
return nil
}
for _, model := range models {
if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
return applyModel(model)
}
}
return fmt.Errorf("no model found with name %q", name)
}
// List available models
// Models galleries are a list of json files that are hosted on a remote server (for example github).
// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
func AvailableGalleryModels(galleries []Gallery) ([]*GalleryModel, error) {
var models []*GalleryModel
// Get models from galleries
for _, gallery := range galleries {
galleryModels, err := getGalleryModels(gallery)
if err != nil {
return nil, err
}
models = append(models, galleryModels...)
}
return models, nil
}
func getGalleryModels(gallery Gallery) ([]*GalleryModel, error) {
var models []*GalleryModel = []*GalleryModel{}
err := utils.GetURI(gallery.URL, func(d []byte) error {
return yaml.Unmarshal(d, &models)
})
if err != nil {
return models, err
}
// Add gallery to models
for _, model := range models {
model.Gallery = gallery
}
return models, nil
}

View file

@ -5,12 +5,10 @@ import (
"fmt" "fmt"
"hash" "hash"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo" "github.com/imdario/mergo"
@ -83,7 +81,7 @@ func ReadConfigFile(filePath string) (*Config, error) {
return &config, nil return &config, nil
} }
func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist // Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0755) err := os.MkdirAll(basePath, 0755)
if err != nil { if err != nil {
@ -301,115 +299,3 @@ func calculateSHA(filePath string) (string, error) {
return fmt.Sprintf("%x", hash.Sum(nil)), nil return fmt.Sprintf("%x", hash.Sum(nil)), nil
} }
type Gallery struct {
URL string `json:"url" yaml:"url"`
Name string `json:"name" yaml:"name"`
}
// Installs a model from the gallery (galleryname@modelname)
func ApplyModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
models, err := AvailableModels(galleries)
if err != nil {
return err
}
applyModel := func(model *GalleryModel) error {
var config Config
err := model.Get(&config)
if err != nil {
return err
}
if req.Name != "" {
model.Name = req.Name
}
config.Files = append(config.Files, req.AdditionalFiles...)
config.Files = append(config.Files, model.AdditionalFiles...)
// TODO model.Overrides could be merged with user overrides (not defined yet)
if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil {
return err
}
if err := Apply(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil {
return err
}
return nil
}
for _, model := range models {
if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
return applyModel(model)
}
}
return fmt.Errorf("no model found with name %q", name)
}
// List available models
// Models galleries are a list of json files that are hosted on a remote server (for example github).
// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
func AvailableModels(galleries []Gallery) ([]*GalleryModel, error) {
var models []*GalleryModel
// Get models from galleries
for _, gallery := range galleries {
galleryModels, err := getModels(gallery)
if err != nil {
return nil, err
}
models = append(models, galleryModels...)
}
return models, nil
}
func getModels(gallery Gallery) ([]*GalleryModel, error) {
var models []*GalleryModel = []*GalleryModel{}
if strings.HasPrefix(gallery.URL, "file://") {
rawURL := strings.TrimPrefix(gallery.URL, "file://")
// Read the response body
body, err := ioutil.ReadFile(rawURL)
if err != nil {
return models, err
}
// Unmarshal YAML data into a struct
err = yaml.Unmarshal(body, &models)
if err != nil {
return models, err
}
// Add gallery to models
for _, model := range models {
model.Gallery = gallery
}
return models, nil
}
// Get list of models
resp, err := http.Get(gallery.URL)
if err != nil {
return nil, fmt.Errorf("failed to get models: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get models: %s", resp.Status)
}
err = yaml.NewDecoder(resp.Body).Decode(&models)
if err != nil {
return nil, fmt.Errorf("failed to decode models: %v", err)
}
// Add gallery to models
for _, model := range models {
model.Gallery = gallery
}
return models, nil
}

View file

@ -20,7 +20,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@ -60,13 +60,13 @@ var _ = Describe("Model test", func() {
}, },
} }
models, err := AvailableModels(galleries) models, err := AvailableGalleryModels(galleries)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1)) Expect(len(models)).To(Equal(1))
Expect(models[0].Name).To(Equal("bert")) Expect(models[0].Name).To(Equal("bert"))
Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml")) Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"))
err = ApplyModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {}) err = InstallModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml")) dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
@ -85,7 +85,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@ -101,7 +101,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}) err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@ -127,7 +127,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
}) })

View file

@ -2,11 +2,10 @@ package gallery
import ( import (
"fmt" "fmt"
"io/ioutil"
"net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/go-skynet/LocalAI/pkg/utils"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -68,31 +67,7 @@ func (request GalleryModel) Get(i interface{}) error {
return err return err
} }
if strings.HasPrefix(url, "file://") { return utils.GetURI(url, func(d []byte) error {
rawURL := strings.TrimPrefix(url, "file://") return yaml.Unmarshal(d, i)
// Read the response body })
body, err := ioutil.ReadFile(rawURL)
if err != nil {
return err
}
// Unmarshal YAML data into a struct
return yaml.Unmarshal(body, i)
}
// Send a GET request to the URL
response, err := http.Get(url)
if err != nil {
return err
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
// Unmarshal YAML data into a struct
return yaml.Unmarshal(body, i)
} }

37
pkg/utils/uri.go Normal file
View file

@ -0,0 +1,37 @@
package utils
import (
"io/ioutil"
"net/http"
"strings"
)
func GetURI(url string, f func(i []byte) error) error {
if strings.HasPrefix(url, "file://") {
rawURL := strings.TrimPrefix(url, "file://")
// Read the response body
body, err := ioutil.ReadFile(rawURL)
if err != nil {
return err
}
// Unmarshal YAML data into a struct
return f(body)
}
// Send a GET request to the URL
response, err := http.Get(url)
if err != nil {
return err
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
// Unmarshal YAML data into a struct
return f(body)
}