Identify installed models

Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
mudler 2023-06-24 01:06:12 +02:00
parent e103fa4a87
commit ece70a2268
6 changed files with 37 additions and 8 deletions

View file

@ -106,7 +106,7 @@ func App(opts ...AppOption) (*fiber.App, error) {
applier.start(options.context, cm)
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries))
app.Get("/models/list", listModelFromGallery(options.galleries))
app.Get("/models/list", listModelFromGallery(options.galleries, options.loader.ModelPath))
app.Get("/models/jobs/:uuid", getOpStatus(applier))
// openAI compatible API endpoint

View file

@ -204,6 +204,9 @@ var _ = Describe("API test", func() {
models := getModels("http://127.0.0.1:9090/models/list")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "test@bert2",
})
@ -231,6 +234,18 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar"))
models = getModels("http://127.0.0.1:9090/models/list")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))
for _, m := range models {
if m.Name == "bert2" {
Expect(m.Installed).To(BeTrue())
} else {
Expect(m.Installed).To(BeFalse())
}
}
})
It("overrides models", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{

View file

@ -222,11 +222,11 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, gal
}
}
func listModelFromGallery(galleries []gallery.Gallery) func(c *fiber.Ctx) error {
func listModelFromGallery(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", galleries)
models, err := gallery.AvailableGalleryModels(galleries)
models, err := gallery.AvailableGalleryModels(galleries, basePath)
if err != nil {
return err
}

View file

@ -2,6 +2,8 @@ package gallery
import (
"fmt"
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo"
@ -15,7 +17,7 @@ type Gallery struct {
// Installs a model from the gallery (galleryname@modelname)
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
models, err := AvailableGalleryModels(galleries)
models, err := AvailableGalleryModels(galleries, basePath)
if err != nil {
return err
}
@ -59,12 +61,12 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
// List available models
// Models galleries are a list of json files that are hosted on a remote server (for example github).
// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
func AvailableGalleryModels(galleries []Gallery) ([]*GalleryModel, error) {
func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryModel, error) {
var models []*GalleryModel
// Get models from galleries
for _, gallery := range galleries {
galleryModels, err := getGalleryModels(gallery)
galleryModels, err := getGalleryModels(gallery, basePath)
if err != nil {
return nil, err
}
@ -74,7 +76,7 @@ func AvailableGalleryModels(galleries []Gallery) ([]*GalleryModel, error) {
return models, nil
}
func getGalleryModels(gallery Gallery) ([]*GalleryModel, error) {
func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) {
var models []*GalleryModel = []*GalleryModel{}
err := utils.GetURI(gallery.URL, func(d []byte) error {
@ -87,6 +89,11 @@ func getGalleryModels(gallery Gallery) ([]*GalleryModel, error) {
// Add gallery to models
for _, model := range models {
model.Gallery = gallery
// we check if the model was already installed by checking if the config file exists
// TODO: (what to do if the model doesn't install a config file?)
if _, err := os.Stat(filepath.Join(basePath, fmt.Sprintf("%s.yaml", model.Name))); err == nil {
model.Installed = true
}
}
return models, nil
}

View file

@ -60,11 +60,12 @@ var _ = Describe("Model test", func() {
},
}
models, err := AvailableGalleryModels(galleries)
models, err := AvailableGalleryModels(galleries, tempdir)
Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1))
Expect(models[0].Name).To(Equal("bert"))
Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"))
Expect(models[0].Installed).To(BeFalse())
err = InstallModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {})
Expect(err).ToNot(HaveOccurred())
@ -76,6 +77,11 @@ var _ = Describe("Model test", func() {
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
models, err = AvailableGalleryModels(galleries, tempdir)
Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1))
Expect(models[0].Installed).To(BeTrue())
})
It("renames model correctly", func() {

View file

@ -17,6 +17,7 @@ type GalleryModel struct {
Overrides map[string]interface{} `json:"overrides" yaml:"overrides"`
AdditionalFiles []File `json:"files" yaml:"files"`
Gallery Gallery `json:"gallery" yaml:"gallery"`
Installed bool `json:"installed" yaml:"installed"`
}
const (