mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-29 22:20:43 +00:00
Add tests, simplify
This commit is contained in:
parent
80d30f658c
commit
3fb1648976
3 changed files with 91 additions and 21 deletions
|
@ -13,6 +13,7 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
. "github.com/go-skynet/LocalAI/api"
|
. "github.com/go-skynet/LocalAI/api"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/go-skynet/LocalAI/pkg/model"
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
@ -24,6 +25,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type modelApplyRequest struct {
|
type modelApplyRequest struct {
|
||||||
|
ID string `json:"id"`
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Overrides map[string]string `json:"overrides"`
|
Overrides map[string]string `json:"overrides"`
|
||||||
|
@ -52,6 +54,35 @@ func getModelStatus(url string) (response map[string]interface{}) {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getModels(url string) (response []gallery.GalleryModel) {
|
||||||
|
|
||||||
|
//url := "http://localhost:AI/models/apply"
|
||||||
|
|
||||||
|
// Create the request payload
|
||||||
|
|
||||||
|
// Create the HTTP request
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error reading response body:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal the response into a map[string]interface{}
|
||||||
|
err = json.Unmarshal(body, &response)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error unmarshaling JSON response:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
|
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
|
||||||
|
|
||||||
//url := "http://localhost:AI/models/apply"
|
//url := "http://localhost:AI/models/apply"
|
||||||
|
@ -118,7 +149,25 @@ var _ = Describe("API test", func() {
|
||||||
modelLoader = model.NewModelLoader(tmpdir)
|
modelLoader = model.NewModelLoader(tmpdir)
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
app, err = App(WithContext(c), WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir))
|
g := []gallery.GalleryModel{{
|
||||||
|
Name: "bert",
|
||||||
|
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
||||||
|
}}
|
||||||
|
out, err := yaml.Marshal(g)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
err = ioutil.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
galleries := []gallery.Gallery{
|
||||||
|
{
|
||||||
|
Name: "test",
|
||||||
|
URL: "file://" + filepath.Join(tmpdir, "gallery_simple.yaml"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app, err = App(WithContext(c),
|
||||||
|
WithGalleries(galleries),
|
||||||
|
WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
|
@ -143,6 +192,36 @@ var _ = Describe("API test", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("Applying models", func() {
|
Context("Applying models", func() {
|
||||||
|
It("applies models from a gallery", func() {
|
||||||
|
|
||||||
|
models := getModels("http://127.0.0.1:9090/models/list")
|
||||||
|
Expect(len(models)).To(Equal(1))
|
||||||
|
fmt.Println(models)
|
||||||
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||||
|
ID: "test@bert",
|
||||||
|
})
|
||||||
|
|
||||||
|
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||||
|
|
||||||
|
uuid := response["uuid"].(string)
|
||||||
|
resp := map[string]interface{}{}
|
||||||
|
Eventually(func() bool {
|
||||||
|
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||||
|
fmt.Println(response)
|
||||||
|
resp = response
|
||||||
|
return response["processed"].(bool)
|
||||||
|
}, "360s").Should(Equal(true))
|
||||||
|
Expect(resp["message"]).ToNot(ContainSubstring("error"))
|
||||||
|
|
||||||
|
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
content := map[string]interface{}{}
|
||||||
|
err = yaml.Unmarshal(dat, &content)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(content["backend"]).To(Equal("bert-embeddings"))
|
||||||
|
|
||||||
|
})
|
||||||
It("overrides models", func() {
|
It("overrides models", func() {
|
||||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||||
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
||||||
|
|
|
@ -46,24 +46,6 @@ func newGalleryApplier(modelPath string) *galleryApplier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyModelFromGallery(modelPath string, name string, basePath string, req gallery.GalleryModel, cm *ConfigMerger, galleries []gallery.Gallery, downloadStatus func(string, string, string, float64)) error {
|
|
||||||
var config gallery.Config
|
|
||||||
|
|
||||||
err := req.Get(&config)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
|
||||||
|
|
||||||
if err := gallery.ApplyModelFromGallery(galleries, name, modelPath, req, downloadStatus); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reload models
|
|
||||||
return cm.LoadConfigs(modelPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
||||||
var config gallery.Config
|
var config gallery.Config
|
||||||
|
|
||||||
|
@ -109,13 +91,21 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if op.galleryName != "" {
|
if op.galleryName != "" {
|
||||||
if err := applyModelFromGallery(g.modelPath, op.galleryName, g.modelPath, op.req, cm, op.galleries, func(fileName string, current string, total string, percentage float64) {
|
if err := gallery.ApplyModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, func(fileName string, current string, total string, percentage float64) {
|
||||||
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||||
displayDownload(fileName, current, total, percentage)
|
displayDownload(fileName, current, total, percentage)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
updateError(err)
|
updateError(err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reload models
|
||||||
|
err := cm.LoadConfigs(g.modelPath)
|
||||||
|
if err != nil {
|
||||||
|
updateError(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
|
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
|
||||||
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||||
|
|
|
@ -54,7 +54,8 @@ func (request GalleryModel) DecodeURL() (string, error) {
|
||||||
} else if strings.HasPrefix(input, "file://") {
|
} else if strings.HasPrefix(input, "file://") {
|
||||||
return input, nil
|
return input, nil
|
||||||
} else {
|
} else {
|
||||||
return "", fmt.Errorf("invalid URL format")
|
|
||||||
|
return "", fmt.Errorf("invalid URL format: %s", input)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawURL, nil
|
return rawURL, nil
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue