From 1c30c2446c07a55252f4f0039aa3b82a7791a022 Mon Sep 17 00:00:00 2001 From: mudler Date: Fri, 23 Jun 2023 00:23:50 +0200 Subject: [PATCH] feat: add gallery repositories Signed-off-by: mudler --- api/api.go | 5 +- api/gallery.go | 104 +++++++---------- api/options.go | 9 ++ main.go | 11 ++ pkg/gallery/models.go | 107 ++++++++++++++++++ pkg/gallery/request.go | 54 +++++++++ .../gallery/request_test.go | 10 +- 7 files changed, 231 insertions(+), 69 deletions(-) create mode 100644 pkg/gallery/request.go rename api/gallery_test.go => pkg/gallery/request_test.go (69%) diff --git a/api/api.go b/api/api.go index 6f2ac143..4cb35627 100644 --- a/api/api.go +++ b/api/api.go @@ -104,7 +104,10 @@ func App(opts ...AppOption) (*fiber.App, error) { // LocalAI API endpoints applier := newGalleryApplier(options.loader.ModelPath) applier.start(options.context, cm) - app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C)) + + app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries)) + app.Post("/models/list", listModelFromGallery(options.galleries)) + app.Get("/models/jobs/:uuid", getOpStatus(applier)) // openAI compatible API endpoint diff --git a/api/gallery.go b/api/gallery.go index a9a87220..2206a45c 100644 --- a/api/gallery.go +++ b/api/gallery.go @@ -6,9 +6,7 @@ import ( "fmt" "io/ioutil" "net/http" - "net/url" "os" - "strings" "sync" "time" @@ -20,8 +18,10 @@ import ( ) type galleryOp struct { - req ApplyGalleryModelRequest - id string + req gallery.GalleryModel + id string + galleries []*gallery.Gallery + galleryName string } type galleryOpStatus struct { @@ -48,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier { } } -func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { +func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { url, err := req.DecodeURL() if err != nil { return err @@ -110,12 +110,19 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) } - if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) { - g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) - displayDownload(fileName, current, total, percentage) - }); err != nil { - updateError(err) - continue + if op.galleryName != "" { + gallery.ApplyModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req.Name, op.req.Overrides, func(fileName string, current string, total string, percentage float64) { + g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) + displayDownload(fileName, current, total, percentage) + }) + } else { + if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) { + g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) + displayDownload(fileName, current, total, percentage) + }); err != nil { + updateError(err) + continue + } } g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100}) @@ -154,7 +161,7 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { if err != nil { return err } - var requests []ApplyGalleryModelRequest + var requests []gallery.GalleryModel err = json.Unmarshal(dat, &requests) if err != nil { return err @@ -170,7 +177,7 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { } func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { - var requests []ApplyGalleryModelRequest + var requests []gallery.GalleryModel err := json.Unmarshal([]byte(s), &requests) if err != nil { return err @@ -185,52 +192,6 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { return nil } -// endpoints - -type ApplyGalleryModelRequest struct { - URL string `json:"url"` - Name string `json:"name"` - Overrides map[string]interface{} `json:"overrides"` - AdditionalFiles []gallery.File `json:"files"` -} - -const ( - githubURI = "github:" -) - -func (request ApplyGalleryModelRequest) DecodeURL() (string, error) { - input := request.URL - var rawURL string - - if strings.HasPrefix(input, githubURI) { - parts := strings.Split(input, ":") - repoParts := strings.Split(parts[1], "@") - branch := "main" - - if len(repoParts) > 1 { - branch = repoParts[1] - } - - repoPath := strings.Split(repoParts[0], "/") - org := repoPath[0] - project := repoPath[1] - projectPath := strings.Join(repoPath[2:], "/") - - rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) - } else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") { - // Handle regular URLs - u, err := url.Parse(input) - if err != nil { - return "", fmt.Errorf("invalid URL: %w", err) - } - rawURL = u.String() - } else { - return "", fmt.Errorf("invalid URL format") - } - - return rawURL, nil -} - func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { @@ -243,9 +204,14 @@ func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { } } -func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error { +type GalleryModel struct { + ID string `json:"id"` + gallery.GalleryModel +} + +func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []*gallery.Gallery) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - input := new(ApplyGalleryModelRequest) + input := new(GalleryModel) // Get input data from the request body if err := c.BodyParser(input); err != nil { return err @@ -256,8 +222,10 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) fun return err } g <- galleryOp{ - req: *input, - id: uuid.String(), + req: input.GalleryModel, + id: uuid.String(), + galleryName: input.ID, + galleries: galleries, } return c.JSON(struct { ID string `json:"uuid"` @@ -265,3 +233,13 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) fun }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) } } + +func listModelFromGallery(galleries []*gallery.Gallery) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + models, err := gallery.AvailableModels(galleries) + if err != nil { + return err + } + return c.JSON(models) + } +} diff --git a/api/options.go b/api/options.go index 3d94eaa8..59b0d3c9 100644 --- a/api/options.go +++ b/api/options.go @@ -4,6 +4,7 @@ import ( "context" "embed" + "github.com/go-skynet/LocalAI/pkg/gallery" model "github.com/go-skynet/LocalAI/pkg/model" ) @@ -21,6 +22,8 @@ type Option struct { preloadModelsFromPath string corsAllowOrigins string + galleries []*gallery.Gallery + backendAssets embed.FS assetsDestination string } @@ -66,6 +69,12 @@ func WithBackendAssets(f embed.FS) AppOption { } } +func WithGalleries(galleries []*gallery.Gallery) AppOption { + return func(o *Option) { + o.galleries = append(o.galleries, galleries...) + } +} + func WithContext(ctx context.Context) AppOption { return func(o *Option) { o.context = ctx diff --git a/main.go b/main.go index dc6968ad..a547d1cc 100644 --- a/main.go +++ b/main.go @@ -1,11 +1,13 @@ package main import ( + "encoding/json" "fmt" "os" "path/filepath" api "github.com/go-skynet/LocalAI/api" + "github.com/go-skynet/LocalAI/pkg/gallery" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -53,6 +55,11 @@ func main() { EnvVars: []string{"MODELS_PATH"}, Value: filepath.Join(path, "models"), }, + &cli.StringFlag{ + Name: "galleries", + Usage: "JSON list of galleries", + EnvVars: []string{"GALLERIES"}, + }, &cli.StringFlag{ Name: "preload-models", Usage: "A List of models to apply in JSON at start", @@ -123,8 +130,12 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings. Copyright: "go-skynet authors", Action: func(ctx *cli.Context) error { fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path")) + galls := ctx.String("galleries") + var galleries []*gallery.Gallery + json.Unmarshal([]byte(galls), galleries) app, err := api.App( api.WithConfigFile(ctx.String("config-file")), + api.WithGalleries(galleries), api.WithJSONStringPreload(ctx.String("preload-models")), api.WithYAMLConfigPreload(ctx.String("preload-models-config")), api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index 8d4cd296..acc1d2f5 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -5,6 +5,7 @@ import ( "fmt" "hash" "io" + "io/ioutil" "net/http" "os" "path/filepath" @@ -291,3 +292,109 @@ func calculateSHA(filePath string) (string, error) { return fmt.Sprintf("%x", hash.Sum(nil)), nil } + +type Gallery struct { + URL string `json:"url"` + Name string `json:"name"` +} + +// Installs a model from the gallery (galleryname@modelname) +func ApplyModelFromGallery(galleries []*Gallery, name string, basePath, nameOverride string, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { + models, err := AvailableModels(galleries) + if err != nil { + return err + } + + applyModel := func(model *GalleryModel) error { + url, err := model.DecodeURL() + if err != nil { + return err + } + // Send a GET request to the URL + response, err := http.Get(url) + if err != nil { + return err + } + defer response.Body.Close() + + // Read the response body + body, err := ioutil.ReadAll(response.Body) + if err != nil { + return err + } + + // Unmarshal YAML data into a Config struct + var config Config + err = yaml.Unmarshal(body, &config) + if err != nil { + return err + } + + if nameOverride != "" { + model.Name = nameOverride + } + // TODO model.Overrides could be merged with user overrides (not defined yet) + if err := mergo.Merge(&model.Overrides, configOverrides, mergo.WithOverride); err != nil { + return err + } + + if err := Apply(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil { + return err + } + + return nil + } + + for _, model := range models { + if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) { + return applyModel(model) + } + } + + return fmt.Errorf("no model found with name %q", name) +} + +// List available models +// Models galleries are a list of json files that are hosted on a remote server (for example github). +// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting. +func AvailableModels(galleries []*Gallery) ([]*GalleryModel, error) { + var models []*GalleryModel + + // Get models from galleries + for _, gallery := range galleries { + galleryModels, err := getModels(gallery) + if err != nil { + return nil, err + } + models = append(models, galleryModels...) + } + + return models, nil +} + +func getModels(gallery *Gallery) ([]*GalleryModel, error) { + var models []*GalleryModel + + // Get list of models + resp, err := http.Get(gallery.URL) + if err != nil { + return nil, fmt.Errorf("failed to get models: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get models: %s", resp.Status) + } + + err = yaml.NewDecoder(resp.Body).Decode(&models) + if err != nil { + return nil, fmt.Errorf("failed to decode models: %v", err) + } + + // Add gallery to models + for _, model := range models { + model.Gallery = gallery + } + + return models, nil +} diff --git a/pkg/gallery/request.go b/pkg/gallery/request.go new file mode 100644 index 00000000..b824c173 --- /dev/null +++ b/pkg/gallery/request.go @@ -0,0 +1,54 @@ +package gallery + +import ( + "fmt" + "net/url" + "strings" +) + +// endpoints + +type GalleryModel struct { + URL string `json:"url" yaml:"url"` + Name string `json:"name" yaml:"name"` + Overrides map[string]interface{} `json:"overrides" yaml:"overrides"` + AdditionalFiles []File `json:"files" yaml:"files"` + Gallery *Gallery `json:"gallery" yaml:"gallery"` +} + +const ( + githubURI = "github:" +) + +func (request GalleryModel) DecodeURL() (string, error) { + input := request.URL + var rawURL string + + if strings.HasPrefix(input, githubURI) { + parts := strings.Split(input, ":") + repoParts := strings.Split(parts[1], "@") + branch := "main" + + if len(repoParts) > 1 { + branch = repoParts[1] + } + + repoPath := strings.Split(repoParts[0], "/") + org := repoPath[0] + project := repoPath[1] + projectPath := strings.Join(repoPath[2:], "/") + + rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) + } else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") { + // Handle regular URLs + u, err := url.Parse(input) + if err != nil { + return "", fmt.Errorf("invalid URL: %w", err) + } + rawURL = u.String() + } else { + return "", fmt.Errorf("invalid URL format") + } + + return rawURL, nil +} diff --git a/api/gallery_test.go b/pkg/gallery/request_test.go similarity index 69% rename from api/gallery_test.go rename to pkg/gallery/request_test.go index 1c92c0d5..494168fd 100644 --- a/api/gallery_test.go +++ b/pkg/gallery/request_test.go @@ -1,7 +1,7 @@ -package api_test +package gallery_test import ( - . "github.com/go-skynet/LocalAI/api" + . "github.com/go-skynet/LocalAI/pkg/gallery" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -9,19 +9,19 @@ import ( var _ = Describe("Gallery API tests", func() { Context("requests", func() { It("parses github with a branch", func() { - req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} + req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} str, err := req.DecodeURL() Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) }) It("parses github without a branch", func() { - req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"} + req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"} str, err := req.DecodeURL() Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) }) It("parses URLS", func() { - req := ApplyGalleryModelRequest{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"} + req := GalleryModel{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"} str, err := req.DecodeURL() Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))