Gallery repository (#663)

Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-06-24 08:18:17 +02:00 committed by GitHub
parent 2a45a99737
commit 60db5957d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 644 additions and 194 deletions

View file

@ -104,7 +104,9 @@ func App(opts ...AppOption) (*fiber.App, error) {
// LocalAI API endpoints
applier := newGalleryApplier(options.loader.ModelPath)
applier.start(options.context, cm)
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C))
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries))
app.Get("/models/list", listModelFromGallery(options.galleries, options.loader.ModelPath))
app.Get("/models/jobs/:uuid", getOpStatus(applier))
// openAI compatible API endpoint
@ -120,6 +122,7 @@ func App(opts ...AppOption) (*fiber.App, error) {
// completion
app.Post("/v1/completions", completionEndpoint(cm, options))
app.Post("/completions", completionEndpoint(cm, options))
app.Post("/v1/engines/:model/completions", completionEndpoint(cm, options))
// embeddings
app.Post("/v1/embeddings", embeddingsEndpoint(cm, options))

View file

@ -13,6 +13,7 @@ import (
"runtime"
. "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2"
@ -24,6 +25,7 @@ import (
)
type modelApplyRequest struct {
ID string `json:"id"`
URL string `json:"url"`
Name string `json:"name"`
Overrides map[string]string `json:"overrides"`
@ -52,6 +54,35 @@ func getModelStatus(url string) (response map[string]interface{}) {
}
return
}
func getModels(url string) (response []gallery.GalleryModel) {
//url := "http://localhost:AI/models/apply"
// Create the request payload
// Create the HTTP request
resp, err := http.Get(url)
if err != nil {
return nil
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
return
}
// Unmarshal the response into a map[string]interface{}
err = json.Unmarshal(body, &response)
if err != nil {
fmt.Println("Error unmarshaling JSON response:", err)
return
}
return
}
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
//url := "http://localhost:AI/models/apply"
@ -118,7 +149,33 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background())
app, err = App(WithContext(c), WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir))
g := []gallery.GalleryModel{
{
Name: "bert",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
},
{
Name: "bert2",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Overrides: map[string]interface{}{"foo": "bar"},
AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
},
}
out, err := yaml.Marshal(g)
Expect(err).ToNot(HaveOccurred())
err = ioutil.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644)
Expect(err).ToNot(HaveOccurred())
galleries := []gallery.Gallery{
{
Name: "test",
URL: "file://" + filepath.Join(tmpdir, "gallery_simple.yaml"),
},
}
app, err = App(WithContext(c),
WithGalleries(galleries),
WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir))
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@ -143,6 +200,53 @@ var _ = Describe("API test", func() {
})
Context("Applying models", func() {
It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/list")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "test@bert2",
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
resp := map[string]interface{}{}
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
resp = response
return response["processed"].(bool)
}, "360s").Should(Equal(true))
Expect(resp["message"]).ToNot(ContainSubstring("error"))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
Expect(err).ToNot(HaveOccurred())
_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar"))
models = getModels("http://127.0.0.1:9090/models/list")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))
for _, m := range models {
if m.Name == "bert2" {
Expect(m.Installed).To(BeTrue())
} else {
Expect(m.Installed).To(BeFalse())
}
}
})
It("overrides models", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",

View file

@ -97,7 +97,7 @@ func ReadConfig(file string) (*Config, error) {
return c, nil
}
func (cm ConfigMerger) LoadConfigFile(file string) error {
func (cm *ConfigMerger) LoadConfigFile(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfigFile(file)
@ -111,7 +111,7 @@ func (cm ConfigMerger) LoadConfigFile(file string) error {
return nil
}
func (cm ConfigMerger) LoadConfig(file string) error {
func (cm *ConfigMerger) LoadConfig(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfig(file)
@ -123,14 +123,14 @@ func (cm ConfigMerger) LoadConfig(file string) error {
return nil
}
func (cm ConfigMerger) GetConfig(m string) (Config, bool) {
func (cm *ConfigMerger) GetConfig(m string) (Config, bool) {
cm.Lock()
defer cm.Unlock()
v, exists := cm.configs[m]
return v, exists
}
func (cm ConfigMerger) ListConfigs() []string {
func (cm *ConfigMerger) ListConfigs() []string {
cm.Lock()
defer cm.Unlock()
var res []string
@ -140,7 +140,7 @@ func (cm ConfigMerger) ListConfigs() []string {
return res
}
func (cm ConfigMerger) LoadConfigs(path string) error {
func (cm *ConfigMerger) LoadConfigs(path string) error {
cm.Lock()
defer cm.Unlock()
entries, err := os.ReadDir(path)
@ -316,20 +316,32 @@ func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (strin
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
}
var config *Config
cfg, exists := cm.GetConfig(modelFile)
if !exists {
defaults := func() {
config = defaultConfig(modelFile)
config.ContextSize = ctx
config.Threads = threads
config.F16 = f16
config.Debug = debug
}
cfg, exists := cm.GetConfig(modelFile)
if !exists {
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfg, exists = cm.GetConfig(modelFile)
if exists {
config = &cfg
} else {
defaults()
}
} else {
defaults()
}
} else {
config = &cfg
}

View file

@ -2,26 +2,24 @@ package api
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
json "github.com/json-iterator/go"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)
type galleryOp struct {
req ApplyGalleryModelRequest
id string
req gallery.GalleryModel
id string
galleries []gallery.Gallery
galleryName string
}
type galleryOpStatus struct {
@ -48,49 +46,27 @@ func newGalleryApplier(modelPath string) *galleryApplier {
}
}
func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
url, err := req.DecodeURL()
if err != nil {
return err
}
// Send a GET request to the URL
response, err := http.Get(url)
if err != nil {
return err
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
// Unmarshal YAML data into a Config struct
// prepareModel applies a
func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
var config gallery.Config
err = yaml.Unmarshal(body, &config)
err := req.Get(&config)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides, downloadStatus); err != nil {
return err
}
// Reload models
return cm.LoadConfigs(modelPath)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}
func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) {
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *galleryApplier) getstatus(s string) *galleryOpStatus {
func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
g.Lock()
defer g.Unlock()
@ -104,21 +80,40 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
case <-c.Done():
return
case op := <-g.C:
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error
updateError := func(e error) {
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage)
}); err != nil {
}
var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.galleryName != "" {
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
} else {
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
}
if err != nil {
updateError(err)
continue
}
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
// Reload models
err = cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
}
}
}()
@ -154,14 +149,14 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
if err != nil {
return err
}
var requests []ApplyGalleryModelRequest
var requests []gallery.GalleryModel
err = json.Unmarshal(dat, &requests)
if err != nil {
return err
}
for _, r := range requests {
if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
if err := prepareModel(modelPath, r, cm, displayDownload); err != nil {
return err
}
}
@ -170,14 +165,14 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
}
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
var requests []ApplyGalleryModelRequest
var requests []gallery.GalleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}
for _, r := range requests {
if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
if err := prepareModel(modelPath, r, cm, displayDownload); err != nil {
return err
}
}
@ -185,56 +180,10 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
return nil
}
// endpoints
type ApplyGalleryModelRequest struct {
URL string `json:"url"`
Name string `json:"name"`
Overrides map[string]interface{} `json:"overrides"`
AdditionalFiles []gallery.File `json:"files"`
}
const (
githubURI = "github:"
)
func (request ApplyGalleryModelRequest) DecodeURL() (string, error) {
input := request.URL
var rawURL string
if strings.HasPrefix(input, githubURI) {
parts := strings.Split(input, ":")
repoParts := strings.Split(parts[1], "@")
branch := "main"
if len(repoParts) > 1 {
branch = repoParts[1]
}
repoPath := strings.Split(repoParts[0], "/")
org := repoPath[0]
project := repoPath[1]
projectPath := strings.Join(repoPath[2:], "/")
rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
} else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
// Handle regular URLs
u, err := url.Parse(input)
if err != nil {
return "", fmt.Errorf("invalid URL: %w", err)
}
rawURL = u.String()
} else {
return "", fmt.Errorf("invalid URL format")
}
return rawURL, nil
}
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := g.getstatus(c.Params("uuid"))
status := g.getStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
@ -243,9 +192,14 @@ func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
}
}
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error {
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(ApplyGalleryModelRequest)
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
@ -256,8 +210,10 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) fun
return err
}
g <- galleryOp{
req: *input,
id: uuid.String(),
req: input.GalleryModel,
id: uuid.String(),
galleryName: input.ID,
galleries: galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
@ -265,3 +221,23 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) fun
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func listModelFromGallery(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", galleries)
models, err := gallery.AvailableGalleryModels(galleries, basePath)
if err != nil {
return err
}
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}

View file

@ -1,30 +0,0 @@
package api_test
import (
. "github.com/go-skynet/LocalAI/api"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Gallery API tests", func() {
Context("requests", func() {
It("parses github with a branch", func() {
req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
str, err := req.DecodeURL()
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
})
It("parses github without a branch", func() {
req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"}
str, err := req.DecodeURL()
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
})
It("parses URLS", func() {
req := ApplyGalleryModelRequest{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"}
str, err := req.DecodeURL()
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
})
})
})

View file

@ -4,6 +4,7 @@ import (
"context"
"embed"
"github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model"
)
@ -21,6 +22,8 @@ type Option struct {
preloadModelsFromPath string
corsAllowOrigins string
galleries []gallery.Gallery
backendAssets embed.FS
assetsDestination string
}
@ -66,6 +69,12 @@ func WithBackendAssets(f embed.FS) AppOption {
}
}
func WithGalleries(galleries []gallery.Gallery) AppOption {
return func(o *Option) {
o.galleries = append(o.galleries, galleries...)
}
}
func WithContext(ctx context.Context) AppOption {
return func(o *Option) {
o.context = ctx