Take on gallery (2)

This commit is contained in:
mudler 2023-06-23 20:04:36 +02:00
parent 5905187fe0
commit 80d30f658c
10 changed files with 193 additions and 68 deletions

View file

@ -106,8 +106,7 @@ func App(opts ...AppOption) (*fiber.App, error) {
applier.start(options.context, cm)
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries))
app.Post("/models/list", listModelFromGallery(options.galleries))
app.Get("/models/list", listModelFromGallery(options.galleries))
app.Get("/models/jobs/:uuid", getOpStatus(applier))
// openAI compatible API endpoint

View file

@ -2,25 +2,23 @@ package api
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"sync"
"time"
json "github.com/json-iterator/go"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)
type galleryOp struct {
req gallery.GalleryModel
id string
galleries []*gallery.Gallery
galleries []gallery.Gallery
galleryName string
}
@ -48,28 +46,28 @@ func newGalleryApplier(modelPath string) *galleryApplier {
}
}
func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
url, err := req.DecodeURL()
if err != nil {
return err
}
// Send a GET request to the URL
response, err := http.Get(url)
if err != nil {
return err
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
// Unmarshal YAML data into a Config struct
func applyModelFromGallery(modelPath string, name string, basePath string, req gallery.GalleryModel, cm *ConfigMerger, galleries []gallery.Gallery, downloadStatus func(string, string, string, float64)) error {
var config gallery.Config
err = yaml.Unmarshal(body, &config)
err := req.Get(&config)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
if err := gallery.ApplyModelFromGallery(galleries, name, modelPath, req, downloadStatus); err != nil {
return err
}
// Reload models
return cm.LoadConfigs(modelPath)
}
func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
var config gallery.Config
err := req.Get(&config)
if err != nil {
return err
}
@ -107,14 +105,17 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
updateError := func(e error) {
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
if op.galleryName != "" {
gallery.ApplyModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req.Name, op.req.Overrides, func(fileName string, current string, total string, percentage float64) {
if err := applyModelFromGallery(g.modelPath, op.galleryName, g.modelPath, op.req, cm, op.galleries, func(fileName string, current string, total string, percentage float64) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage)
})
}); err != nil {
updateError(err)
continue
}
} else {
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
@ -209,7 +210,7 @@ type GalleryModel struct {
gallery.GalleryModel
}
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []*gallery.Gallery) func(c *fiber.Ctx) error {
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
@ -234,12 +235,22 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, gal
}
}
func listModelFromGallery(galleries []*gallery.Gallery) func(c *fiber.Ctx) error {
func listModelFromGallery(galleries []gallery.Gallery) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", galleries)
models, err := gallery.AvailableModels(galleries)
if err != nil {
return err
}
return c.JSON(models)
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}

View file

@ -22,7 +22,7 @@ type Option struct {
preloadModelsFromPath string
corsAllowOrigins string
galleries []*gallery.Gallery
galleries []gallery.Gallery
backendAssets embed.FS
assetsDestination string
@ -69,7 +69,7 @@ func WithBackendAssets(f embed.FS) AppOption {
}
}
func WithGalleries(galleries []*gallery.Gallery) AppOption {
func WithGalleries(galleries []gallery.Gallery) AppOption {
return func(o *Option) {
o.galleries = append(o.galleries, galleries...)
}

8
go.mod
View file

@ -29,6 +29,12 @@ require (
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
)
require (
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/PuerkitoBio/purell v1.1.1 // indirect
@ -52,7 +58,7 @@ require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/mudler/go-piper v0.0.0-00010101000000-000000000000 // indirect
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760
github.com/otiai10/mint v1.5.1 // indirect
github.com/philhofer/fwd v1.1.2 // indirect
github.com/rivo/uniseg v0.2.0 // indirect

9
go.sum
View file

@ -62,6 +62,7 @@ github.com/gofiber/fiber/v2 v2.47.0/go.mod h1:mbFMVN1lQuzziTkkakgtKKdjfsXSw9BKR5
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
@ -75,6 +76,8 @@ github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.16.3 h1:XuJt9zzcnaz6a16/OU53ZjWp/v7/42WcR5t2a0PcNQY=
github.com/klauspost/compress v1.16.3/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@ -97,6 +100,12 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU=
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks=
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=

View file

@ -131,8 +131,9 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
Action: func(ctx *cli.Context) error {
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
galls := ctx.String("galleries")
var galleries []*gallery.Gallery
json.Unmarshal([]byte(galls), galleries)
var galleries []gallery.Gallery
err := json.Unmarshal([]byte(galls), &galleries)
fmt.Println(err)
app, err := api.App(
api.WithConfigFile(ctx.String("config-file")),
api.WithGalleries(galleries),

View file

@ -10,6 +10,7 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo"
@ -172,6 +173,7 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
// Verify SHA
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
if calculatedSHA != file.SHA256 {
log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
}
} else {
@ -294,47 +296,30 @@ func calculateSHA(filePath string) (string, error) {
}
type Gallery struct {
URL string `json:"url"`
Name string `json:"name"`
URL string `json:"url" yaml:"url"`
Name string `json:"name" yaml:"name"`
}
// Installs a model from the gallery (galleryname@modelname)
func ApplyModelFromGallery(galleries []*Gallery, name string, basePath, nameOverride string, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
func ApplyModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
models, err := AvailableModels(galleries)
if err != nil {
return err
}
applyModel := func(model *GalleryModel) error {
url, err := model.DecodeURL()
if err != nil {
return err
}
// Send a GET request to the URL
response, err := http.Get(url)
if err != nil {
return err
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
// Unmarshal YAML data into a Config struct
var config Config
err = yaml.Unmarshal(body, &config)
err := model.Get(&config)
if err != nil {
return err
}
if nameOverride != "" {
model.Name = nameOverride
if req.Name != "" {
model.Name = req.Name
}
// TODO model.Overrides could be merged with user overrides (not defined yet)
if err := mergo.Merge(&model.Overrides, configOverrides, mergo.WithOverride); err != nil {
if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil {
return err
}
@ -357,7 +342,7 @@ func ApplyModelFromGallery(galleries []*Gallery, name string, basePath, nameOver
// List available models
// Models galleries are a list of json files that are hosted on a remote server (for example github).
// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
func AvailableModels(galleries []*Gallery) ([]*GalleryModel, error) {
func AvailableModels(galleries []Gallery) ([]*GalleryModel, error) {
var models []*GalleryModel
// Get models from galleries
@ -372,9 +357,28 @@ func AvailableModels(galleries []*Gallery) ([]*GalleryModel, error) {
return models, nil
}
func getModels(gallery *Gallery) ([]*GalleryModel, error) {
var models []*GalleryModel
func getModels(gallery Gallery) ([]*GalleryModel, error) {
var models []*GalleryModel = []*GalleryModel{}
if strings.HasPrefix(gallery.URL, "file://") {
rawURL := strings.TrimPrefix(gallery.URL, "file://")
// Read the response body
body, err := ioutil.ReadFile(rawURL)
if err != nil {
return models, err
}
// Unmarshal YAML data into a struct
err = yaml.Unmarshal(body, &models)
if err != nil {
return models, err
}
// Add gallery to models
for _, model := range models {
model.Gallery = gallery
}
return models, nil
}
// Get list of models
resp, err := http.Get(gallery.URL)
if err != nil {

View file

@ -1,6 +1,7 @@
package gallery_test
import (
"io/ioutil"
"os"
"path/filepath"
@ -38,6 +39,45 @@ var _ = Describe("Model test", func() {
Expect(content["context_size"]).To(Equal(1024))
})
It("applies model from gallery correctly", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir)
gallery := []GalleryModel{{
Name: "bert",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
}}
out, err := yaml.Marshal(gallery)
Expect(err).ToNot(HaveOccurred())
err = ioutil.WriteFile(filepath.Join(tempdir, "gallery_simple.yaml"), out, 0644)
Expect(err).ToNot(HaveOccurred())
galleries := []Gallery{
{
Name: "test",
URL: "file://" + filepath.Join(tempdir, "gallery_simple.yaml"),
},
}
models, err := AvailableModels(galleries)
Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1))
Expect(models[0].Name).To(Equal("bert"))
Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"))
err = ApplyModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {})
Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
})
It("renames model correctly", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())

View file

@ -2,8 +2,12 @@ package gallery
import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"gopkg.in/yaml.v2"
)
// endpoints
@ -13,7 +17,7 @@ type GalleryModel struct {
Name string `json:"name" yaml:"name"`
Overrides map[string]interface{} `json:"overrides" yaml:"overrides"`
AdditionalFiles []File `json:"files" yaml:"files"`
Gallery *Gallery `json:"gallery" yaml:"gallery"`
Gallery Gallery `json:"gallery" yaml:"gallery"`
}
const (
@ -46,9 +50,48 @@ func (request GalleryModel) DecodeURL() (string, error) {
return "", fmt.Errorf("invalid URL: %w", err)
}
rawURL = u.String()
// check if it's a file path
} else if strings.HasPrefix(input, "file://") {
return input, nil
} else {
return "", fmt.Errorf("invalid URL format")
}
return rawURL, nil
}
// Get fetches a model from a URL and unmarshals it into a struct
func (request GalleryModel) Get(i interface{}) error {
url, err := request.DecodeURL()
if err != nil {
return err
}
if strings.HasPrefix(url, "file://") {
rawURL := strings.TrimPrefix(url, "file://")
// Read the response body
body, err := ioutil.ReadFile(rawURL)
if err != nil {
return err
}
// Unmarshal YAML data into a struct
return yaml.Unmarshal(body, i)
}
// Send a GET request to the URL
response, err := http.Get(url)
if err != nil {
return err
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
// Unmarshal YAML data into a struct
return yaml.Unmarshal(body, i)
}

View file

@ -6,9 +6,21 @@ import (
. "github.com/onsi/gomega"
)
type example struct {
Name string `yaml:"name"`
}
var _ = Describe("Gallery API tests", func() {
Context("requests", func() {
It("parses github with a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
var e example
err := req.Get(&e)
Expect(err).ToNot(HaveOccurred())
Expect(e.Name).To(Equal("gpt4all-j"))
})
It("parses github without a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
str, err := req.DecodeURL()
Expect(err).ToNot(HaveOccurred())