Add tests, simplify

This commit is contained in:
mudler 2023-06-23 22:37:32 +02:00
parent 80d30f658c
commit 3fb1648976
3 changed files with 91 additions and 21 deletions

View file

@ -13,6 +13,7 @@ import (
"runtime" "runtime"
. "github.com/go-skynet/LocalAI/api" . "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
@ -24,6 +25,7 @@ import (
) )
type modelApplyRequest struct { type modelApplyRequest struct {
ID string `json:"id"`
URL string `json:"url"` URL string `json:"url"`
Name string `json:"name"` Name string `json:"name"`
Overrides map[string]string `json:"overrides"` Overrides map[string]string `json:"overrides"`
@ -52,6 +54,35 @@ func getModelStatus(url string) (response map[string]interface{}) {
} }
return return
} }
func getModels(url string) (response []gallery.GalleryModel) {
//url := "http://localhost:AI/models/apply"
// Create the request payload
// Create the HTTP request
resp, err := http.Get(url)
if err != nil {
return nil
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
return
}
// Unmarshal the response into a map[string]interface{}
err = json.Unmarshal(body, &response)
if err != nil {
fmt.Println("Error unmarshaling JSON response:", err)
return
}
return
}
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
//url := "http://localhost:AI/models/apply" //url := "http://localhost:AI/models/apply"
@ -118,7 +149,25 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(tmpdir) modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
app, err = App(WithContext(c), WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir)) g := []gallery.GalleryModel{{
Name: "bert",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
}}
out, err := yaml.Marshal(g)
Expect(err).ToNot(HaveOccurred())
err = ioutil.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644)
Expect(err).ToNot(HaveOccurred())
galleries := []gallery.Gallery{
{
Name: "test",
URL: "file://" + filepath.Join(tmpdir, "gallery_simple.yaml"),
},
}
app, err = App(WithContext(c),
WithGalleries(galleries),
WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
@ -143,6 +192,36 @@ var _ = Describe("API test", func() {
}) })
Context("Applying models", func() { Context("Applying models", func() {
It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/list")
Expect(len(models)).To(Equal(1))
fmt.Println(models)
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "test@bert",
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
resp := map[string]interface{}{}
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
resp = response
return response["processed"].(bool)
}, "360s").Should(Equal(true))
Expect(resp["message"]).ToNot(ContainSubstring("error"))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
})
It("overrides models", func() { It("overrides models", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",

View file

@ -46,24 +46,6 @@ func newGalleryApplier(modelPath string) *galleryApplier {
} }
} }
func applyModelFromGallery(modelPath string, name string, basePath string, req gallery.GalleryModel, cm *ConfigMerger, galleries []gallery.Gallery, downloadStatus func(string, string, string, float64)) error {
var config gallery.Config
err := req.Get(&config)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
if err := gallery.ApplyModelFromGallery(galleries, name, modelPath, req, downloadStatus); err != nil {
return err
}
// Reload models
return cm.LoadConfigs(modelPath)
}
func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
var config gallery.Config var config gallery.Config
@ -109,13 +91,21 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
} }
if op.galleryName != "" { if op.galleryName != "" {
if err := applyModelFromGallery(g.modelPath, op.galleryName, g.modelPath, op.req, cm, op.galleries, func(fileName string, current string, total string, percentage float64) { if err := gallery.ApplyModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, func(fileName string, current string, total string, percentage float64) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage) displayDownload(fileName, current, total, percentage)
}); err != nil { }); err != nil {
updateError(err) updateError(err)
continue continue
} }
// Reload models
err := cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
} else { } else {
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) { if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})

View file

@ -54,7 +54,8 @@ func (request GalleryModel) DecodeURL() (string, error) {
} else if strings.HasPrefix(input, "file://") { } else if strings.HasPrefix(input, "file://") {
return input, nil return input, nil
} else { } else {
return "", fmt.Errorf("invalid URL format")
return "", fmt.Errorf("invalid URL format: %s", input)
} }
return rawURL, nil return rawURL, nil