Take on gallery (2)

This commit is contained in:
mudler 2023-06-23 20:04:36 +02:00
parent 5905187fe0
commit 80d30f658c
10 changed files with 193 additions and 68 deletions

View file

@ -106,8 +106,7 @@ func App(opts ...AppOption) (*fiber.App, error) {
applier.start(options.context, cm) applier.start(options.context, cm)
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries)) app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries))
app.Post("/models/list", listModelFromGallery(options.galleries)) app.Get("/models/list", listModelFromGallery(options.galleries))
app.Get("/models/jobs/:uuid", getOpStatus(applier)) app.Get("/models/jobs/:uuid", getOpStatus(applier))
// openAI compatible API endpoint // openAI compatible API endpoint

View file

@ -2,25 +2,23 @@ package api
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"io/ioutil"
"net/http"
"os" "os"
"sync" "sync"
"time" "time"
json "github.com/json-iterator/go"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
) )
type galleryOp struct { type galleryOp struct {
req gallery.GalleryModel req gallery.GalleryModel
id string id string
galleries []*gallery.Gallery galleries []gallery.Gallery
galleryName string galleryName string
} }
@ -48,28 +46,28 @@ func newGalleryApplier(modelPath string) *galleryApplier {
} }
} }
func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { func applyModelFromGallery(modelPath string, name string, basePath string, req gallery.GalleryModel, cm *ConfigMerger, galleries []gallery.Gallery, downloadStatus func(string, string, string, float64)) error {
url, err := req.DecodeURL()
if err != nil {
return err
}
// Send a GET request to the URL
response, err := http.Get(url)
if err != nil {
return err
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
// Unmarshal YAML data into a Config struct
var config gallery.Config var config gallery.Config
err = yaml.Unmarshal(body, &config)
err := req.Get(&config)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
if err := gallery.ApplyModelFromGallery(galleries, name, modelPath, req, downloadStatus); err != nil {
return err
}
// Reload models
return cm.LoadConfigs(modelPath)
}
func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
var config gallery.Config
err := req.Get(&config)
if err != nil { if err != nil {
return err return err
} }
@ -107,14 +105,17 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
updateError := func(e error) { updateError := func(e error) {
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
} }
if op.galleryName != "" { if op.galleryName != "" {
gallery.ApplyModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req.Name, op.req.Overrides, func(fileName string, current string, total string, percentage float64) { if err := applyModelFromGallery(g.modelPath, op.galleryName, g.modelPath, op.req, cm, op.galleries, func(fileName string, current string, total string, percentage float64) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage) displayDownload(fileName, current, total, percentage)
}) }); err != nil {
updateError(err)
continue
}
} else { } else {
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) { if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
@ -209,7 +210,7 @@ type GalleryModel struct {
gallery.GalleryModel gallery.GalleryModel
} }
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []*gallery.Gallery) func(c *fiber.Ctx) error { func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
input := new(GalleryModel) input := new(GalleryModel)
// Get input data from the request body // Get input data from the request body
@ -234,12 +235,22 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, gal
} }
} }
func listModelFromGallery(galleries []*gallery.Gallery) func(c *fiber.Ctx) error { func listModelFromGallery(galleries []gallery.Gallery) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", galleries)
models, err := gallery.AvailableModels(galleries) models, err := gallery.AvailableModels(galleries)
if err != nil { if err != nil {
return err return err
} }
return c.JSON(models) log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
} }
} }

View file

@ -22,7 +22,7 @@ type Option struct {
preloadModelsFromPath string preloadModelsFromPath string
corsAllowOrigins string corsAllowOrigins string
galleries []*gallery.Gallery galleries []gallery.Gallery
backendAssets embed.FS backendAssets embed.FS
assetsDestination string assetsDestination string
@ -69,7 +69,7 @@ func WithBackendAssets(f embed.FS) AppOption {
} }
} }
func WithGalleries(galleries []*gallery.Gallery) AppOption { func WithGalleries(galleries []gallery.Gallery) AppOption {
return func(o *Option) { return func(o *Option) {
o.galleries = append(o.galleries, galleries...) o.galleries = append(o.galleries, galleries...)
} }

8
go.mod
View file

@ -29,6 +29,12 @@ require (
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
)
require ( require (
github.com/KyleBanks/depth v1.2.1 // indirect github.com/KyleBanks/depth v1.2.1 // indirect
github.com/PuerkitoBio/purell v1.1.1 // indirect github.com/PuerkitoBio/purell v1.1.1 // indirect
@ -52,7 +58,7 @@ require (
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/mudler/go-piper v0.0.0-00010101000000-000000000000 // indirect github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760
github.com/otiai10/mint v1.5.1 // indirect github.com/otiai10/mint v1.5.1 // indirect
github.com/philhofer/fwd v1.1.2 // indirect github.com/philhofer/fwd v1.1.2 // indirect
github.com/rivo/uniseg v0.2.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect

9
go.sum
View file

@ -62,6 +62,7 @@ github.com/gofiber/fiber/v2 v2.47.0/go.mod h1:mbFMVN1lQuzziTkkakgtKKdjfsXSw9BKR5
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
@ -75,6 +76,8 @@ github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.16.3 h1:XuJt9zzcnaz6a16/OU53ZjWp/v7/42WcR5t2a0PcNQY= github.com/klauspost/compress v1.16.3 h1:XuJt9zzcnaz6a16/OU53ZjWp/v7/42WcR5t2a0PcNQY=
github.com/klauspost/compress v1.16.3/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/compress v1.16.3/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@ -97,6 +100,12 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU=
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks= github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks=
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=

View file

@ -131,8 +131,9 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
Action: func(ctx *cli.Context) error { Action: func(ctx *cli.Context) error {
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path")) fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
galls := ctx.String("galleries") galls := ctx.String("galleries")
var galleries []*gallery.Gallery var galleries []gallery.Gallery
json.Unmarshal([]byte(galls), galleries) err := json.Unmarshal([]byte(galls), &galleries)
fmt.Println(err)
app, err := api.App( app, err := api.App(
api.WithConfigFile(ctx.String("config-file")), api.WithConfigFile(ctx.String("config-file")),
api.WithGalleries(galleries), api.WithGalleries(galleries),

View file

@ -10,6 +10,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo" "github.com/imdario/mergo"
@ -172,6 +173,7 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
// Verify SHA // Verify SHA
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
if calculatedSHA != file.SHA256 { if calculatedSHA != file.SHA256 {
log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
} }
} else { } else {
@ -294,47 +296,30 @@ func calculateSHA(filePath string) (string, error) {
} }
type Gallery struct { type Gallery struct {
URL string `json:"url"` URL string `json:"url" yaml:"url"`
Name string `json:"name"` Name string `json:"name" yaml:"name"`
} }
// Installs a model from the gallery (galleryname@modelname) // Installs a model from the gallery (galleryname@modelname)
func ApplyModelFromGallery(galleries []*Gallery, name string, basePath, nameOverride string, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { func ApplyModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
models, err := AvailableModels(galleries) models, err := AvailableModels(galleries)
if err != nil { if err != nil {
return err return err
} }
applyModel := func(model *GalleryModel) error { applyModel := func(model *GalleryModel) error {
url, err := model.DecodeURL()
if err != nil {
return err
}
// Send a GET request to the URL
response, err := http.Get(url)
if err != nil {
return err
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
// Unmarshal YAML data into a Config struct
var config Config var config Config
err = yaml.Unmarshal(body, &config)
err := model.Get(&config)
if err != nil { if err != nil {
return err return err
} }
if nameOverride != "" { if req.Name != "" {
model.Name = nameOverride model.Name = req.Name
} }
// TODO model.Overrides could be merged with user overrides (not defined yet) // TODO model.Overrides could be merged with user overrides (not defined yet)
if err := mergo.Merge(&model.Overrides, configOverrides, mergo.WithOverride); err != nil { if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil {
return err return err
} }
@ -357,7 +342,7 @@ func ApplyModelFromGallery(galleries []*Gallery, name string, basePath, nameOver
// List available models // List available models
// Models galleries are a list of json files that are hosted on a remote server (for example github). // Models galleries are a list of json files that are hosted on a remote server (for example github).
// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting. // Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
func AvailableModels(galleries []*Gallery) ([]*GalleryModel, error) { func AvailableModels(galleries []Gallery) ([]*GalleryModel, error) {
var models []*GalleryModel var models []*GalleryModel
// Get models from galleries // Get models from galleries
@ -372,9 +357,28 @@ func AvailableModels(galleries []*Gallery) ([]*GalleryModel, error) {
return models, nil return models, nil
} }
func getModels(gallery *Gallery) ([]*GalleryModel, error) { func getModels(gallery Gallery) ([]*GalleryModel, error) {
var models []*GalleryModel var models []*GalleryModel = []*GalleryModel{}
if strings.HasPrefix(gallery.URL, "file://") {
rawURL := strings.TrimPrefix(gallery.URL, "file://")
// Read the response body
body, err := ioutil.ReadFile(rawURL)
if err != nil {
return models, err
}
// Unmarshal YAML data into a struct
err = yaml.Unmarshal(body, &models)
if err != nil {
return models, err
}
// Add gallery to models
for _, model := range models {
model.Gallery = gallery
}
return models, nil
}
// Get list of models // Get list of models
resp, err := http.Get(gallery.URL) resp, err := http.Get(gallery.URL)
if err != nil { if err != nil {

View file

@ -1,6 +1,7 @@
package gallery_test package gallery_test
import ( import (
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
@ -38,6 +39,45 @@ var _ = Describe("Model test", func() {
Expect(content["context_size"]).To(Equal(1024)) Expect(content["context_size"]).To(Equal(1024))
}) })
It("applies model from gallery correctly", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir)
gallery := []GalleryModel{{
Name: "bert",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
}}
out, err := yaml.Marshal(gallery)
Expect(err).ToNot(HaveOccurred())
err = ioutil.WriteFile(filepath.Join(tempdir, "gallery_simple.yaml"), out, 0644)
Expect(err).ToNot(HaveOccurred())
galleries := []Gallery{
{
Name: "test",
URL: "file://" + filepath.Join(tempdir, "gallery_simple.yaml"),
},
}
models, err := AvailableModels(galleries)
Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1))
Expect(models[0].Name).To(Equal("bert"))
Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"))
err = ApplyModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {})
Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
})
It("renames model correctly", func() { It("renames model correctly", func() {
tempdir, err := os.MkdirTemp("", "test") tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -2,8 +2,12 @@ package gallery
import ( import (
"fmt" "fmt"
"io/ioutil"
"net/http"
"net/url" "net/url"
"strings" "strings"
"gopkg.in/yaml.v2"
) )
// endpoints // endpoints
@ -13,7 +17,7 @@ type GalleryModel struct {
Name string `json:"name" yaml:"name"` Name string `json:"name" yaml:"name"`
Overrides map[string]interface{} `json:"overrides" yaml:"overrides"` Overrides map[string]interface{} `json:"overrides" yaml:"overrides"`
AdditionalFiles []File `json:"files" yaml:"files"` AdditionalFiles []File `json:"files" yaml:"files"`
Gallery *Gallery `json:"gallery" yaml:"gallery"` Gallery Gallery `json:"gallery" yaml:"gallery"`
} }
const ( const (
@ -46,9 +50,48 @@ func (request GalleryModel) DecodeURL() (string, error) {
return "", fmt.Errorf("invalid URL: %w", err) return "", fmt.Errorf("invalid URL: %w", err)
} }
rawURL = u.String() rawURL = u.String()
// check if it's a file path
} else if strings.HasPrefix(input, "file://") {
return input, nil
} else { } else {
return "", fmt.Errorf("invalid URL format") return "", fmt.Errorf("invalid URL format")
} }
return rawURL, nil return rawURL, nil
} }
// Get fetches a model from a URL and unmarshals it into a struct
func (request GalleryModel) Get(i interface{}) error {
url, err := request.DecodeURL()
if err != nil {
return err
}
if strings.HasPrefix(url, "file://") {
rawURL := strings.TrimPrefix(url, "file://")
// Read the response body
body, err := ioutil.ReadFile(rawURL)
if err != nil {
return err
}
// Unmarshal YAML data into a struct
return yaml.Unmarshal(body, i)
}
// Send a GET request to the URL
response, err := http.Get(url)
if err != nil {
return err
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
// Unmarshal YAML data into a struct
return yaml.Unmarshal(body, i)
}

View file

@ -6,9 +6,21 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
type example struct {
Name string `yaml:"name"`
}
var _ = Describe("Gallery API tests", func() { var _ = Describe("Gallery API tests", func() {
Context("requests", func() { Context("requests", func() {
It("parses github with a branch", func() { It("parses github with a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
var e example
err := req.Get(&e)
Expect(err).ToNot(HaveOccurred())
Expect(e.Name).To(Equal("gpt4all-j"))
})
It("parses github without a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
str, err := req.DecodeURL() str, err := req.DecodeURL()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())