From d9a1fafffebb87b1876d57d8df5d8dd095ee8170 Mon Sep 17 00:00:00 2001 From: mudler Date: Sat, 24 Jun 2023 00:27:52 +0200 Subject: [PATCH] Refactoring Signed-off-by: mudler --- api/gallery.go | 73 +++++++++++------------ pkg/gallery/gallery.go | 92 +++++++++++++++++++++++++++++ pkg/gallery/models.go | 116 +------------------------------------ pkg/gallery/models_test.go | 12 ++-- pkg/gallery/request.go | 33 ++--------- pkg/utils/uri.go | 37 ++++++++++++ 6 files changed, 175 insertions(+), 188 deletions(-) create mode 100644 pkg/gallery/gallery.go create mode 100644 pkg/utils/uri.go diff --git a/api/gallery.go b/api/gallery.go index bb74c669..dfdfa1c6 100644 --- a/api/gallery.go +++ b/api/gallery.go @@ -46,7 +46,8 @@ func newGalleryApplier(modelPath string) *galleryApplier { } } -func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { +// prepareModel applies a +func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { var config gallery.Config err := req.Get(&config) @@ -56,21 +57,16 @@ func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, config.Files = append(config.Files, req.AdditionalFiles...) - if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides, downloadStatus); err != nil { - return err - } - - // Reload models - return cm.LoadConfigs(modelPath) + return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) } -func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) { +func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) { g.Lock() defer g.Unlock() g.statuses[s] = op } -func (g *galleryApplier) getstatus(s string) *galleryOpStatus { +func (g *galleryApplier) getStatus(s string) *galleryOpStatus { g.Lock() defer g.Unlock() @@ -84,39 +80,40 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { case <-c.Done(): return case op := <-g.C: - g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) + g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) + // updates the status with an error updateError := func(e error) { - g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()}) + g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()}) } + // displayDownload displays the download progress + progressCallback := func(fileName string, current string, total string, percentage float64) { + g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) + displayDownload(fileName, current, total, percentage) + } + + var err error + // if the request contains a gallery name, we apply the gallery from the gallery list if op.galleryName != "" { - if err := gallery.ApplyModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, func(fileName string, current string, total string, percentage float64) { - g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) - displayDownload(fileName, current, total, percentage) - }); err != nil { - updateError(err) - continue - } - - // Reload models - err := cm.LoadConfigs(g.modelPath) - if err != nil { - updateError(err) - continue - } - + err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) } else { - if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) { - g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) - displayDownload(fileName, current, total, percentage) - }); err != nil { - updateError(err) - continue - } + err = prepareModel(g.modelPath, op.req, cm, progressCallback) } - g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100}) + if err != nil { + updateError(err) + continue + } + + // Reload models + err = cm.LoadConfigs(g.modelPath) + if err != nil { + updateError(err) + continue + } + + g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100}) } } }() @@ -159,7 +156,7 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { } for _, r := range requests { - if err := applyGallery(modelPath, r, cm, displayDownload); err != nil { + if err := prepareModel(modelPath, r, cm, displayDownload); err != nil { return err } } @@ -175,7 +172,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { } for _, r := range requests { - if err := applyGallery(modelPath, r, cm, displayDownload); err != nil { + if err := prepareModel(modelPath, r, cm, displayDownload); err != nil { return err } } @@ -186,7 +183,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - status := g.getstatus(c.Params("uuid")) + status := g.getStatus(c.Params("uuid")) if status == nil { return fmt.Errorf("could not find any status for ID") } @@ -229,7 +226,7 @@ func listModelFromGallery(galleries []gallery.Gallery) func(c *fiber.Ctx) error return func(c *fiber.Ctx) error { log.Debug().Msgf("Listing models from galleries: %+v", galleries) - models, err := gallery.AvailableModels(galleries) + models, err := gallery.AvailableGalleryModels(galleries) if err != nil { return err } diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go new file mode 100644 index 00000000..b10f2826 --- /dev/null +++ b/pkg/gallery/gallery.go @@ -0,0 +1,92 @@ +package gallery + +import ( + "fmt" + + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/imdario/mergo" + "gopkg.in/yaml.v2" +) + +type Gallery struct { + URL string `json:"url" yaml:"url"` + Name string `json:"name" yaml:"name"` +} + +// Installs a model from the gallery (galleryname@modelname) +func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { + models, err := AvailableGalleryModels(galleries) + if err != nil { + return err + } + + applyModel := func(model *GalleryModel) error { + var config Config + + err := model.Get(&config) + if err != nil { + return err + } + + if req.Name != "" { + model.Name = req.Name + } + + config.Files = append(config.Files, req.AdditionalFiles...) + config.Files = append(config.Files, model.AdditionalFiles...) + + // TODO model.Overrides could be merged with user overrides (not defined yet) + if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil { + return err + } + + if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil { + return err + } + + return nil + } + + for _, model := range models { + if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) { + return applyModel(model) + } + } + + return fmt.Errorf("no model found with name %q", name) +} + +// List available models +// Models galleries are a list of json files that are hosted on a remote server (for example github). +// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting. +func AvailableGalleryModels(galleries []Gallery) ([]*GalleryModel, error) { + var models []*GalleryModel + + // Get models from galleries + for _, gallery := range galleries { + galleryModels, err := getGalleryModels(gallery) + if err != nil { + return nil, err + } + models = append(models, galleryModels...) + } + + return models, nil +} + +func getGalleryModels(gallery Gallery) ([]*GalleryModel, error) { + var models []*GalleryModel = []*GalleryModel{} + + err := utils.GetURI(gallery.URL, func(d []byte) error { + return yaml.Unmarshal(d, &models) + }) + if err != nil { + return models, err + } + + // Add gallery to models + for _, model := range models { + model.Gallery = gallery + } + return models, nil +} diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index 7f1e6d16..ee28f03a 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -5,12 +5,10 @@ import ( "fmt" "hash" "io" - "io/ioutil" "net/http" "os" "path/filepath" "strconv" - "strings" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/imdario/mergo" @@ -83,7 +81,7 @@ func ReadConfigFile(filePath string) (*Config, error) { return &config, nil } -func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { +func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0755) if err != nil { @@ -301,115 +299,3 @@ func calculateSHA(filePath string) (string, error) { return fmt.Sprintf("%x", hash.Sum(nil)), nil } - -type Gallery struct { - URL string `json:"url" yaml:"url"` - Name string `json:"name" yaml:"name"` -} - -// Installs a model from the gallery (galleryname@modelname) -func ApplyModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { - models, err := AvailableModels(galleries) - if err != nil { - return err - } - - applyModel := func(model *GalleryModel) error { - var config Config - - err := model.Get(&config) - if err != nil { - return err - } - - if req.Name != "" { - model.Name = req.Name - } - - config.Files = append(config.Files, req.AdditionalFiles...) - config.Files = append(config.Files, model.AdditionalFiles...) - - // TODO model.Overrides could be merged with user overrides (not defined yet) - if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil { - return err - } - - if err := Apply(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil { - return err - } - - return nil - } - - for _, model := range models { - if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) { - return applyModel(model) - } - } - - return fmt.Errorf("no model found with name %q", name) -} - -// List available models -// Models galleries are a list of json files that are hosted on a remote server (for example github). -// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting. -func AvailableModels(galleries []Gallery) ([]*GalleryModel, error) { - var models []*GalleryModel - - // Get models from galleries - for _, gallery := range galleries { - galleryModels, err := getModels(gallery) - if err != nil { - return nil, err - } - models = append(models, galleryModels...) - } - - return models, nil -} - -func getModels(gallery Gallery) ([]*GalleryModel, error) { - var models []*GalleryModel = []*GalleryModel{} - if strings.HasPrefix(gallery.URL, "file://") { - rawURL := strings.TrimPrefix(gallery.URL, "file://") - // Read the response body - body, err := ioutil.ReadFile(rawURL) - if err != nil { - return models, err - } - - // Unmarshal YAML data into a struct - err = yaml.Unmarshal(body, &models) - if err != nil { - return models, err - } - - // Add gallery to models - for _, model := range models { - model.Gallery = gallery - } - return models, nil - } - // Get list of models - resp, err := http.Get(gallery.URL) - if err != nil { - return nil, fmt.Errorf("failed to get models: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("failed to get models: %s", resp.Status) - } - - err = yaml.NewDecoder(resp.Body).Decode(&models) - if err != nil { - return nil, fmt.Errorf("failed to decode models: %v", err) - } - - // Add gallery to models - for _, model := range models { - model.Gallery = gallery - } - - return models, nil -} diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go index 1a4c4ad6..c0612c03 100644 --- a/pkg/gallery/models_test.go +++ b/pkg/gallery/models_test.go @@ -20,7 +20,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) + err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { @@ -60,13 +60,13 @@ var _ = Describe("Model test", func() { }, } - models, err := AvailableModels(galleries) + models, err := AvailableGalleryModels(galleries) Expect(err).ToNot(HaveOccurred()) Expect(len(models)).To(Equal(1)) Expect(models[0].Name).To(Equal("bert")) Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml")) - err = ApplyModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {}) + err = InstallModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {}) Expect(err).ToNot(HaveOccurred()) dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml")) @@ -85,7 +85,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) + err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -101,7 +101,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}) + err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -127,7 +127,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) + err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) Expect(err).To(HaveOccurred()) }) }) diff --git a/pkg/gallery/request.go b/pkg/gallery/request.go index 9f4a6595..e6fde737 100644 --- a/pkg/gallery/request.go +++ b/pkg/gallery/request.go @@ -2,11 +2,10 @@ package gallery import ( "fmt" - "io/ioutil" - "net/http" "net/url" "strings" + "github.com/go-skynet/LocalAI/pkg/utils" "gopkg.in/yaml.v2" ) @@ -68,31 +67,7 @@ func (request GalleryModel) Get(i interface{}) error { return err } - if strings.HasPrefix(url, "file://") { - rawURL := strings.TrimPrefix(url, "file://") - // Read the response body - body, err := ioutil.ReadFile(rawURL) - if err != nil { - return err - } - - // Unmarshal YAML data into a struct - return yaml.Unmarshal(body, i) - } - - // Send a GET request to the URL - response, err := http.Get(url) - if err != nil { - return err - } - defer response.Body.Close() - - // Read the response body - body, err := ioutil.ReadAll(response.Body) - if err != nil { - return err - } - - // Unmarshal YAML data into a struct - return yaml.Unmarshal(body, i) + return utils.GetURI(url, func(d []byte) error { + return yaml.Unmarshal(d, i) + }) } diff --git a/pkg/utils/uri.go b/pkg/utils/uri.go new file mode 100644 index 00000000..753a2831 --- /dev/null +++ b/pkg/utils/uri.go @@ -0,0 +1,37 @@ +package utils + +import ( + "io/ioutil" + "net/http" + "strings" +) + +func GetURI(url string, f func(i []byte) error) error { + if strings.HasPrefix(url, "file://") { + rawURL := strings.TrimPrefix(url, "file://") + // Read the response body + body, err := ioutil.ReadFile(rawURL) + if err != nil { + return err + } + + // Unmarshal YAML data into a struct + return f(body) + } + + // Send a GET request to the URL + response, err := http.Get(url) + if err != nil { + return err + } + defer response.Body.Close() + + // Read the response body + body, err := ioutil.ReadAll(response.Body) + if err != nil { + return err + } + + // Unmarshal YAML data into a struct + return f(body) +}