diff --git a/api/api_test.go b/api/api_test.go index af2193f3..c20b91b0 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -13,6 +13,7 @@ import ( "runtime" . "github.com/go-skynet/LocalAI/api" + "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" . "github.com/onsi/ginkgo/v2" @@ -24,6 +25,7 @@ import ( ) type modelApplyRequest struct { + ID string `json:"id"` URL string `json:"url"` Name string `json:"name"` Overrides map[string]string `json:"overrides"` @@ -52,6 +54,35 @@ func getModelStatus(url string) (response map[string]interface{}) { } return } + +func getModels(url string) (response []gallery.GalleryModel) { + + //url := "http://localhost:AI/models/apply" + + // Create the request payload + + // Create the HTTP request + resp, err := http.Get(url) + if err != nil { + return nil + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Println("Error reading response body:", err) + return + } + + // Unmarshal the response into a map[string]interface{} + err = json.Unmarshal(body, &response) + if err != nil { + fmt.Println("Error unmarshaling JSON response:", err) + return + } + return +} + func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { //url := "http://localhost:AI/models/apply" @@ -118,7 +149,25 @@ var _ = Describe("API test", func() { modelLoader = model.NewModelLoader(tmpdir) c, cancel = context.WithCancel(context.Background()) - app, err = App(WithContext(c), WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir)) + g := []gallery.GalleryModel{{ + Name: "bert", + URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", + }} + out, err := yaml.Marshal(g) + Expect(err).ToNot(HaveOccurred()) + err = ioutil.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644) + Expect(err).ToNot(HaveOccurred()) + + galleries := []gallery.Gallery{ + { + Name: "test", + URL: "file://" + filepath.Join(tmpdir, "gallery_simple.yaml"), + }, + } + + app, err = App(WithContext(c), + WithGalleries(galleries), + WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir)) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -143,6 +192,36 @@ var _ = Describe("API test", func() { }) Context("Applying models", func() { + It("applies models from a gallery", func() { + + models := getModels("http://127.0.0.1:9090/models/list") + Expect(len(models)).To(Equal(1)) + fmt.Println(models) + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + ID: "test@bert", + }) + + Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) + + uuid := response["uuid"].(string) + resp := map[string]interface{}{} + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + fmt.Println(response) + resp = response + return response["processed"].(bool) + }, "360s").Should(Equal(true)) + Expect(resp["message"]).ToNot(ContainSubstring("error")) + + dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) + Expect(err).ToNot(HaveOccurred()) + + content := map[string]interface{}{} + err = yaml.Unmarshal(dat, &content) + Expect(err).ToNot(HaveOccurred()) + Expect(content["backend"]).To(Equal("bert-embeddings")) + + }) It("overrides models", func() { response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", diff --git a/api/gallery.go b/api/gallery.go index ce5640b6..bb74c669 100644 --- a/api/gallery.go +++ b/api/gallery.go @@ -46,24 +46,6 @@ func newGalleryApplier(modelPath string) *galleryApplier { } } -func applyModelFromGallery(modelPath string, name string, basePath string, req gallery.GalleryModel, cm *ConfigMerger, galleries []gallery.Gallery, downloadStatus func(string, string, string, float64)) error { - var config gallery.Config - - err := req.Get(&config) - if err != nil { - return err - } - - config.Files = append(config.Files, req.AdditionalFiles...) - - if err := gallery.ApplyModelFromGallery(galleries, name, modelPath, req, downloadStatus); err != nil { - return err - } - - // Reload models - return cm.LoadConfigs(modelPath) -} - func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { var config gallery.Config @@ -109,13 +91,21 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { } if op.galleryName != "" { - if err := applyModelFromGallery(g.modelPath, op.galleryName, g.modelPath, op.req, cm, op.galleries, func(fileName string, current string, total string, percentage float64) { + if err := gallery.ApplyModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, func(fileName string, current string, total string, percentage float64) { g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) displayDownload(fileName, current, total, percentage) }); err != nil { updateError(err) continue } + + // Reload models + err := cm.LoadConfigs(g.modelPath) + if err != nil { + updateError(err) + continue + } + } else { if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) { g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) diff --git a/pkg/gallery/request.go b/pkg/gallery/request.go index 014ab52f..9f4a6595 100644 --- a/pkg/gallery/request.go +++ b/pkg/gallery/request.go @@ -54,7 +54,8 @@ func (request GalleryModel) DecodeURL() (string, error) { } else if strings.HasPrefix(input, "file://") { return input, nil } else { - return "", fmt.Errorf("invalid URL format") + + return "", fmt.Errorf("invalid URL format: %s", input) } return rawURL, nil