diff --git a/api/api_test.go b/api/api_test.go index c20b91b0..1273007b 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -149,10 +149,18 @@ var _ = Describe("API test", func() { modelLoader = model.NewModelLoader(tmpdir) c, cancel = context.WithCancel(context.Background()) - g := []gallery.GalleryModel{{ - Name: "bert", - URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", - }} + g := []gallery.GalleryModel{ + { + Name: "bert", + URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", + }, + { + Name: "bert2", + URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", + Overrides: map[string]interface{}{"foo": "bar"}, + AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}}, + }, + } out, err := yaml.Marshal(g) Expect(err).ToNot(HaveOccurred()) err = ioutil.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644) @@ -195,10 +203,9 @@ var _ = Describe("API test", func() { It("applies models from a gallery", func() { models := getModels("http://127.0.0.1:9090/models/list") - Expect(len(models)).To(Equal(1)) - fmt.Println(models) + Expect(len(models)).To(Equal(2), fmt.Sprint(models)) response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ - ID: "test@bert", + ID: "test@bert2", }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) @@ -213,14 +220,17 @@ var _ = Describe("API test", func() { }, "360s").Should(Equal(true)) Expect(resp["message"]).ToNot(ContainSubstring("error")) - dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) + dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml")) + Expect(err).ToNot(HaveOccurred()) + + _, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml")) Expect(err).ToNot(HaveOccurred()) content := map[string]interface{}{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["backend"]).To(Equal("bert-embeddings")) - + Expect(content["foo"]).To(Equal("bar")) }) It("overrides models", func() { response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{