diff --git a/api/api.go b/api/api.go index 0bc1130b..527258b8 100644 --- a/api/api.go +++ b/api/api.go @@ -106,7 +106,7 @@ func App(opts ...AppOption) (*fiber.App, error) { applier.start(options.context, cm) app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries)) - app.Get("/models/list", listModelFromGallery(options.galleries)) + app.Get("/models/list", listModelFromGallery(options.galleries, options.loader.ModelPath)) app.Get("/models/jobs/:uuid", getOpStatus(applier)) // openAI compatible API endpoint diff --git a/api/api_test.go b/api/api_test.go index 1273007b..05d5e7be 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -204,6 +204,9 @@ var _ = Describe("API test", func() { models := getModels("http://127.0.0.1:9090/models/list") Expect(len(models)).To(Equal(2), fmt.Sprint(models)) + Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models)) + Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models)) + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ ID: "test@bert2", }) @@ -231,6 +234,18 @@ var _ = Describe("API test", func() { Expect(err).ToNot(HaveOccurred()) Expect(content["backend"]).To(Equal("bert-embeddings")) Expect(content["foo"]).To(Equal("bar")) + + models = getModels("http://127.0.0.1:9090/models/list") + Expect(len(models)).To(Equal(2), fmt.Sprint(models)) + Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2"))) + Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2"))) + for _, m := range models { + if m.Name == "bert2" { + Expect(m.Installed).To(BeTrue()) + } else { + Expect(m.Installed).To(BeFalse()) + } + } }) It("overrides models", func() { response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ diff --git a/api/gallery.go b/api/gallery.go index dfdfa1c6..46d92f9f 100644 --- a/api/gallery.go +++ b/api/gallery.go @@ -222,11 +222,11 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, gal } } -func listModelFromGallery(galleries []gallery.Gallery) func(c *fiber.Ctx) error { +func listModelFromGallery(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { log.Debug().Msgf("Listing models from galleries: %+v", galleries) - models, err := gallery.AvailableGalleryModels(galleries) + models, err := gallery.AvailableGalleryModels(galleries, basePath) if err != nil { return err } diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index b10f2826..d4440340 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -2,6 +2,8 @@ package gallery import ( "fmt" + "os" + "path/filepath" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/imdario/mergo" @@ -15,7 +17,7 @@ type Gallery struct { // Installs a model from the gallery (galleryname@modelname) func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { - models, err := AvailableGalleryModels(galleries) + models, err := AvailableGalleryModels(galleries, basePath) if err != nil { return err } @@ -59,12 +61,12 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, // List available models // Models galleries are a list of json files that are hosted on a remote server (for example github). // Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting. -func AvailableGalleryModels(galleries []Gallery) ([]*GalleryModel, error) { +func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryModel, error) { var models []*GalleryModel // Get models from galleries for _, gallery := range galleries { - galleryModels, err := getGalleryModels(gallery) + galleryModels, err := getGalleryModels(gallery, basePath) if err != nil { return nil, err } @@ -74,7 +76,7 @@ func AvailableGalleryModels(galleries []Gallery) ([]*GalleryModel, error) { return models, nil } -func getGalleryModels(gallery Gallery) ([]*GalleryModel, error) { +func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) { var models []*GalleryModel = []*GalleryModel{} err := utils.GetURI(gallery.URL, func(d []byte) error { @@ -87,6 +89,11 @@ func getGalleryModels(gallery Gallery) ([]*GalleryModel, error) { // Add gallery to models for _, model := range models { model.Gallery = gallery + // we check if the model was already installed by checking if the config file exists + // TODO: (what to do if the model doesn't install a config file?) + if _, err := os.Stat(filepath.Join(basePath, fmt.Sprintf("%s.yaml", model.Name))); err == nil { + model.Installed = true + } } return models, nil } diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go index c0612c03..9ea87a70 100644 --- a/pkg/gallery/models_test.go +++ b/pkg/gallery/models_test.go @@ -60,11 +60,12 @@ var _ = Describe("Model test", func() { }, } - models, err := AvailableGalleryModels(galleries) + models, err := AvailableGalleryModels(galleries, tempdir) Expect(err).ToNot(HaveOccurred()) Expect(len(models)).To(Equal(1)) Expect(models[0].Name).To(Equal("bert")) Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml")) + Expect(models[0].Installed).To(BeFalse()) err = InstallModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {}) Expect(err).ToNot(HaveOccurred()) @@ -76,6 +77,11 @@ var _ = Describe("Model test", func() { err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["backend"]).To(Equal("bert-embeddings")) + + models, err = AvailableGalleryModels(galleries, tempdir) + Expect(err).ToNot(HaveOccurred()) + Expect(len(models)).To(Equal(1)) + Expect(models[0].Installed).To(BeTrue()) }) It("renames model correctly", func() { diff --git a/pkg/gallery/request.go b/pkg/gallery/request.go index e6fde737..2e2da3e8 100644 --- a/pkg/gallery/request.go +++ b/pkg/gallery/request.go @@ -17,6 +17,7 @@ type GalleryModel struct { Overrides map[string]interface{} `json:"overrides" yaml:"overrides"` AdditionalFiles []File `json:"files" yaml:"files"` Gallery Gallery `json:"gallery" yaml:"gallery"` + Installed bool `json:"installed" yaml:"installed"` } const (