mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-30 06:30:43 +00:00
Take on gallery (2)
This commit is contained in:
parent
5905187fe0
commit
80d30f658c
10 changed files with 193 additions and 68 deletions
|
@ -106,8 +106,7 @@ func App(opts ...AppOption) (*fiber.App, error) {
|
||||||
applier.start(options.context, cm)
|
applier.start(options.context, cm)
|
||||||
|
|
||||||
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries))
|
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries))
|
||||||
app.Post("/models/list", listModelFromGallery(options.galleries))
|
app.Get("/models/list", listModelFromGallery(options.galleries))
|
||||||
|
|
||||||
app.Get("/models/jobs/:uuid", getOpStatus(applier))
|
app.Get("/models/jobs/:uuid", getOpStatus(applier))
|
||||||
|
|
||||||
// openAI compatible API endpoint
|
// openAI compatible API endpoint
|
||||||
|
|
|
@ -2,25 +2,23 @@ package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
json "github.com/json-iterator/go"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type galleryOp struct {
|
type galleryOp struct {
|
||||||
req gallery.GalleryModel
|
req gallery.GalleryModel
|
||||||
id string
|
id string
|
||||||
galleries []*gallery.Gallery
|
galleries []gallery.Gallery
|
||||||
galleryName string
|
galleryName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,28 +46,28 @@ func newGalleryApplier(modelPath string) *galleryApplier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
func applyModelFromGallery(modelPath string, name string, basePath string, req gallery.GalleryModel, cm *ConfigMerger, galleries []gallery.Gallery, downloadStatus func(string, string, string, float64)) error {
|
||||||
url, err := req.DecodeURL()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send a GET request to the URL
|
|
||||||
response, err := http.Get(url)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer response.Body.Close()
|
|
||||||
|
|
||||||
// Read the response body
|
|
||||||
body, err := ioutil.ReadAll(response.Body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmarshal YAML data into a Config struct
|
|
||||||
var config gallery.Config
|
var config gallery.Config
|
||||||
err = yaml.Unmarshal(body, &config)
|
|
||||||
|
err := req.Get(&config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||||
|
|
||||||
|
if err := gallery.ApplyModelFromGallery(galleries, name, modelPath, req, downloadStatus); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload models
|
||||||
|
return cm.LoadConfigs(modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyGallery(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
||||||
|
var config gallery.Config
|
||||||
|
|
||||||
|
err := req.Get(&config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -107,14 +105,17 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
|
||||||
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
||||||
|
|
||||||
updateError := func(e error) {
|
updateError := func(e error) {
|
||||||
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
|
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
|
||||||
}
|
}
|
||||||
|
|
||||||
if op.galleryName != "" {
|
if op.galleryName != "" {
|
||||||
gallery.ApplyModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req.Name, op.req.Overrides, func(fileName string, current string, total string, percentage float64) {
|
if err := applyModelFromGallery(g.modelPath, op.galleryName, g.modelPath, op.req, cm, op.galleries, func(fileName string, current string, total string, percentage float64) {
|
||||||
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||||
displayDownload(fileName, current, total, percentage)
|
displayDownload(fileName, current, total, percentage)
|
||||||
})
|
}); err != nil {
|
||||||
|
updateError(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
|
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
|
||||||
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||||
|
@ -209,7 +210,7 @@ type GalleryModel struct {
|
||||||
gallery.GalleryModel
|
gallery.GalleryModel
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []*gallery.Gallery) func(c *fiber.Ctx) error {
|
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
input := new(GalleryModel)
|
input := new(GalleryModel)
|
||||||
// Get input data from the request body
|
// Get input data from the request body
|
||||||
|
@ -234,12 +235,22 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, gal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func listModelFromGallery(galleries []*gallery.Gallery) func(c *fiber.Ctx) error {
|
func listModelFromGallery(galleries []gallery.Gallery) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
|
log.Debug().Msgf("Listing models from galleries: %+v", galleries)
|
||||||
|
|
||||||
models, err := gallery.AvailableModels(galleries)
|
models, err := gallery.AvailableModels(galleries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.JSON(models)
|
log.Debug().Msgf("Models found from galleries: %+v", models)
|
||||||
|
for _, m := range models {
|
||||||
|
log.Debug().Msgf("Model found from galleries: %+v", m)
|
||||||
|
}
|
||||||
|
dat, err := json.Marshal(models)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Send(dat)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ type Option struct {
|
||||||
preloadModelsFromPath string
|
preloadModelsFromPath string
|
||||||
corsAllowOrigins string
|
corsAllowOrigins string
|
||||||
|
|
||||||
galleries []*gallery.Gallery
|
galleries []gallery.Gallery
|
||||||
|
|
||||||
backendAssets embed.FS
|
backendAssets embed.FS
|
||||||
assetsDestination string
|
assetsDestination string
|
||||||
|
@ -69,7 +69,7 @@ func WithBackendAssets(f embed.FS) AppOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithGalleries(galleries []*gallery.Gallery) AppOption {
|
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.galleries = append(o.galleries, galleries...)
|
o.galleries = append(o.galleries, galleries...)
|
||||||
}
|
}
|
||||||
|
|
8
go.mod
8
go.mod
|
@ -29,6 +29,12 @@ require (
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/KyleBanks/depth v1.2.1 // indirect
|
github.com/KyleBanks/depth v1.2.1 // indirect
|
||||||
github.com/PuerkitoBio/purell v1.1.1 // indirect
|
github.com/PuerkitoBio/purell v1.1.1 // indirect
|
||||||
|
@ -52,7 +58,7 @@ require (
|
||||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||||
github.com/mattn/go-runewidth v0.0.14 // indirect
|
github.com/mattn/go-runewidth v0.0.14 // indirect
|
||||||
github.com/mudler/go-piper v0.0.0-00010101000000-000000000000 // indirect
|
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760
|
||||||
github.com/otiai10/mint v1.5.1 // indirect
|
github.com/otiai10/mint v1.5.1 // indirect
|
||||||
github.com/philhofer/fwd v1.1.2 // indirect
|
github.com/philhofer/fwd v1.1.2 // indirect
|
||||||
github.com/rivo/uniseg v0.2.0 // indirect
|
github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
|
|
9
go.sum
9
go.sum
|
@ -62,6 +62,7 @@ github.com/gofiber/fiber/v2 v2.47.0/go.mod h1:mbFMVN1lQuzziTkkakgtKKdjfsXSw9BKR5
|
||||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
||||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
|
@ -75,6 +76,8 @@ github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
|
||||||
github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
|
github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
|
||||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||||
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/klauspost/compress v1.16.3 h1:XuJt9zzcnaz6a16/OU53ZjWp/v7/42WcR5t2a0PcNQY=
|
github.com/klauspost/compress v1.16.3 h1:XuJt9zzcnaz6a16/OU53ZjWp/v7/42WcR5t2a0PcNQY=
|
||||||
github.com/klauspost/compress v1.16.3/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
|
github.com/klauspost/compress v1.16.3/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
|
||||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
|
@ -97,6 +100,12 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP
|
||||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
||||||
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU=
|
||||||
|
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
|
||||||
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks=
|
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks=
|
||||||
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw=
|
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw=
|
||||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||||
|
|
5
main.go
5
main.go
|
@ -131,8 +131,9 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
|
||||||
Action: func(ctx *cli.Context) error {
|
Action: func(ctx *cli.Context) error {
|
||||||
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
|
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
|
||||||
galls := ctx.String("galleries")
|
galls := ctx.String("galleries")
|
||||||
var galleries []*gallery.Gallery
|
var galleries []gallery.Gallery
|
||||||
json.Unmarshal([]byte(galls), galleries)
|
err := json.Unmarshal([]byte(galls), &galleries)
|
||||||
|
fmt.Println(err)
|
||||||
app, err := api.App(
|
app, err := api.App(
|
||||||
api.WithConfigFile(ctx.String("config-file")),
|
api.WithConfigFile(ctx.String("config-file")),
|
||||||
api.WithGalleries(galleries),
|
api.WithGalleries(galleries),
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
"github.com/imdario/mergo"
|
"github.com/imdario/mergo"
|
||||||
|
@ -172,6 +173,7 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
|
||||||
// Verify SHA
|
// Verify SHA
|
||||||
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
|
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
|
||||||
if calculatedSHA != file.SHA256 {
|
if calculatedSHA != file.SHA256 {
|
||||||
|
log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
|
||||||
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
|
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -294,47 +296,30 @@ func calculateSHA(filePath string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Gallery struct {
|
type Gallery struct {
|
||||||
URL string `json:"url"`
|
URL string `json:"url" yaml:"url"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name" yaml:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Installs a model from the gallery (galleryname@modelname)
|
// Installs a model from the gallery (galleryname@modelname)
|
||||||
func ApplyModelFromGallery(galleries []*Gallery, name string, basePath, nameOverride string, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
|
func ApplyModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||||
models, err := AvailableModels(galleries)
|
models, err := AvailableModels(galleries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
applyModel := func(model *GalleryModel) error {
|
applyModel := func(model *GalleryModel) error {
|
||||||
url, err := model.DecodeURL()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Send a GET request to the URL
|
|
||||||
response, err := http.Get(url)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer response.Body.Close()
|
|
||||||
|
|
||||||
// Read the response body
|
|
||||||
body, err := ioutil.ReadAll(response.Body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmarshal YAML data into a Config struct
|
|
||||||
var config Config
|
var config Config
|
||||||
err = yaml.Unmarshal(body, &config)
|
|
||||||
|
err := model.Get(&config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if nameOverride != "" {
|
if req.Name != "" {
|
||||||
model.Name = nameOverride
|
model.Name = req.Name
|
||||||
}
|
}
|
||||||
// TODO model.Overrides could be merged with user overrides (not defined yet)
|
// TODO model.Overrides could be merged with user overrides (not defined yet)
|
||||||
if err := mergo.Merge(&model.Overrides, configOverrides, mergo.WithOverride); err != nil {
|
if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -357,7 +342,7 @@ func ApplyModelFromGallery(galleries []*Gallery, name string, basePath, nameOver
|
||||||
// List available models
|
// List available models
|
||||||
// Models galleries are a list of json files that are hosted on a remote server (for example github).
|
// Models galleries are a list of json files that are hosted on a remote server (for example github).
|
||||||
// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
|
// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
|
||||||
func AvailableModels(galleries []*Gallery) ([]*GalleryModel, error) {
|
func AvailableModels(galleries []Gallery) ([]*GalleryModel, error) {
|
||||||
var models []*GalleryModel
|
var models []*GalleryModel
|
||||||
|
|
||||||
// Get models from galleries
|
// Get models from galleries
|
||||||
|
@ -372,9 +357,28 @@ func AvailableModels(galleries []*Gallery) ([]*GalleryModel, error) {
|
||||||
return models, nil
|
return models, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getModels(gallery *Gallery) ([]*GalleryModel, error) {
|
func getModels(gallery Gallery) ([]*GalleryModel, error) {
|
||||||
var models []*GalleryModel
|
var models []*GalleryModel = []*GalleryModel{}
|
||||||
|
if strings.HasPrefix(gallery.URL, "file://") {
|
||||||
|
rawURL := strings.TrimPrefix(gallery.URL, "file://")
|
||||||
|
// Read the response body
|
||||||
|
body, err := ioutil.ReadFile(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return models, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal YAML data into a struct
|
||||||
|
err = yaml.Unmarshal(body, &models)
|
||||||
|
if err != nil {
|
||||||
|
return models, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add gallery to models
|
||||||
|
for _, model := range models {
|
||||||
|
model.Gallery = gallery
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
// Get list of models
|
// Get list of models
|
||||||
resp, err := http.Get(gallery.URL)
|
resp, err := http.Get(gallery.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package gallery_test
|
package gallery_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
@ -38,6 +39,45 @@ var _ = Describe("Model test", func() {
|
||||||
Expect(content["context_size"]).To(Equal(1024))
|
Expect(content["context_size"]).To(Equal(1024))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("applies model from gallery correctly", func() {
|
||||||
|
tempdir, err := os.MkdirTemp("", "test")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer os.RemoveAll(tempdir)
|
||||||
|
|
||||||
|
gallery := []GalleryModel{{
|
||||||
|
Name: "bert",
|
||||||
|
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
||||||
|
}}
|
||||||
|
out, err := yaml.Marshal(gallery)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
err = ioutil.WriteFile(filepath.Join(tempdir, "gallery_simple.yaml"), out, 0644)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
galleries := []Gallery{
|
||||||
|
{
|
||||||
|
Name: "test",
|
||||||
|
URL: "file://" + filepath.Join(tempdir, "gallery_simple.yaml"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
models, err := AvailableModels(galleries)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(len(models)).To(Equal(1))
|
||||||
|
Expect(models[0].Name).To(Equal("bert"))
|
||||||
|
Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"))
|
||||||
|
|
||||||
|
err = ApplyModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, func(s1, s2, s3 string, f float64) {})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
content := map[string]interface{}{}
|
||||||
|
err = yaml.Unmarshal(dat, &content)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(content["backend"]).To(Equal("bert-embeddings"))
|
||||||
|
})
|
||||||
|
|
||||||
It("renames model correctly", func() {
|
It("renames model correctly", func() {
|
||||||
tempdir, err := os.MkdirTemp("", "test")
|
tempdir, err := os.MkdirTemp("", "test")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -2,8 +2,12 @@ package gallery
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// endpoints
|
// endpoints
|
||||||
|
@ -13,7 +17,7 @@ type GalleryModel struct {
|
||||||
Name string `json:"name" yaml:"name"`
|
Name string `json:"name" yaml:"name"`
|
||||||
Overrides map[string]interface{} `json:"overrides" yaml:"overrides"`
|
Overrides map[string]interface{} `json:"overrides" yaml:"overrides"`
|
||||||
AdditionalFiles []File `json:"files" yaml:"files"`
|
AdditionalFiles []File `json:"files" yaml:"files"`
|
||||||
Gallery *Gallery `json:"gallery" yaml:"gallery"`
|
Gallery Gallery `json:"gallery" yaml:"gallery"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -46,9 +50,48 @@ func (request GalleryModel) DecodeURL() (string, error) {
|
||||||
return "", fmt.Errorf("invalid URL: %w", err)
|
return "", fmt.Errorf("invalid URL: %w", err)
|
||||||
}
|
}
|
||||||
rawURL = u.String()
|
rawURL = u.String()
|
||||||
|
// check if it's a file path
|
||||||
|
} else if strings.HasPrefix(input, "file://") {
|
||||||
|
return input, nil
|
||||||
} else {
|
} else {
|
||||||
return "", fmt.Errorf("invalid URL format")
|
return "", fmt.Errorf("invalid URL format")
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawURL, nil
|
return rawURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get fetches a model from a URL and unmarshals it into a struct
|
||||||
|
func (request GalleryModel) Get(i interface{}) error {
|
||||||
|
url, err := request.DecodeURL()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(url, "file://") {
|
||||||
|
rawURL := strings.TrimPrefix(url, "file://")
|
||||||
|
// Read the response body
|
||||||
|
body, err := ioutil.ReadFile(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal YAML data into a struct
|
||||||
|
return yaml.Unmarshal(body, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a GET request to the URL
|
||||||
|
response, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
// Read the response body
|
||||||
|
body, err := ioutil.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal YAML data into a struct
|
||||||
|
return yaml.Unmarshal(body, i)
|
||||||
|
}
|
||||||
|
|
|
@ -6,9 +6,21 @@ import (
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type example struct {
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
var _ = Describe("Gallery API tests", func() {
|
var _ = Describe("Gallery API tests", func() {
|
||||||
|
|
||||||
Context("requests", func() {
|
Context("requests", func() {
|
||||||
It("parses github with a branch", func() {
|
It("parses github with a branch", func() {
|
||||||
|
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
||||||
|
var e example
|
||||||
|
err := req.Get(&e)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(e.Name).To(Equal("gpt4all-j"))
|
||||||
|
})
|
||||||
|
It("parses github without a branch", func() {
|
||||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
||||||
str, err := req.DecodeURL()
|
str, err := req.DecodeURL()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue