Extend tests

This commit is contained in:
mudler 2023-06-24 00:23:37 +02:00
parent 813ed1c0f1
commit d68f6e65df

View file

@ -149,10 +149,18 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(tmpdir) modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
g := []gallery.GalleryModel{{ g := []gallery.GalleryModel{
{
Name: "bert", Name: "bert",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
}} },
{
Name: "bert2",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Overrides: map[string]interface{}{"foo": "bar"},
AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
},
}
out, err := yaml.Marshal(g) out, err := yaml.Marshal(g)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = ioutil.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644) err = ioutil.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644)
@ -195,10 +203,9 @@ var _ = Describe("API test", func() {
It("applies models from a gallery", func() { It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/list") models := getModels("http://127.0.0.1:9090/models/list")
Expect(len(models)).To(Equal(1)) Expect(len(models)).To(Equal(2), fmt.Sprint(models))
fmt.Println(models)
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "test@bert", ID: "test@bert2",
}) })
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
@ -213,14 +220,17 @@ var _ = Describe("API test", func() {
}, "360s").Should(Equal(true)) }, "360s").Should(Equal(true))
Expect(resp["message"]).ToNot(ContainSubstring("error")) Expect(resp["message"]).ToNot(ContainSubstring("error"))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
Expect(err).ToNot(HaveOccurred())
_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{} content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content) err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings")) Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar"))
}) })
It("overrides models", func() { It("overrides models", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{