feat: add /models/apply endpoint to prepare models (#286)

This commit is contained in:
Ettore Di Giacinto 2023-05-18 15:59:03 +02:00 committed by GitHub
parent 5617e50ebc
commit cc9aa9eb3f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 556 additions and 33 deletions

View file

@ -1,6 +1,7 @@
package api
import (
"context"
"errors"
model "github.com/go-skynet/LocalAI/pkg/model"
@ -12,7 +13,7 @@ import (
"github.com/rs/zerolog/log"
)
func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
@ -48,7 +49,7 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
}))
}
cm := make(ConfigMerger)
cm := NewConfigMerger()
if err := cm.LoadConfigs(loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
@ -60,39 +61,51 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
}
if debug {
for k, v := range cm {
log.Debug().Msgf("Model: %s (config: %+v)", k, v)
for _, v := range cm.ListConfigs() {
cfg, _ := cm.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
// Default middleware config
app.Use(recover.New())
app.Use(cors.New())
// LocalAI API endpoints
applier := newGalleryApplier(loader.ModelPath)
applier.start(c, cm)
app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C))
app.Get("/models/jobs/:uid", getOpStatus(applier))
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
// edit
app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
// completion
app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
// embeddings
app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
// /v1/engines/{engine_id}/embeddings
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
// audio
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16))
// images
app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir))
if imageDir != "" {
app.Static("/generated-images", imageDir)
}
// models
app.Get("/v1/models", listModels(loader, cm))
app.Get("/models", listModels(loader, cm))

View file

@ -22,10 +22,14 @@ var _ = Describe("API test", func() {
var modelLoader *model.ModelLoader
var client *openai.Client
var client2 *openaigo.Client
var c context.Context
var cancel context.CancelFunc
Context("API query", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
app = App("", modelLoader, 15, 1, 512, false, true, true, "")
c, cancel = context.WithCancel(context.Background())
app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "")
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@ -42,6 +46,7 @@ var _ = Describe("API test", func() {
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
})
It("returns the models list", func() {
@ -140,7 +145,9 @@ var _ = Describe("API test", func() {
Context("Config file", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "")
c, cancel = context.WithCancel(context.Background())
app = App(c, os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "")
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@ -155,10 +162,10 @@ var _ = Describe("API test", func() {
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
})
It("can generate chat completions from config file", func() {
models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(12))

View file

@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@ -43,8 +44,16 @@ type TemplateConfig struct {
Edit string `yaml:"edit"`
}
type ConfigMerger map[string]Config
type ConfigMerger struct {
configs map[string]Config
sync.Mutex
}
func NewConfigMerger() *ConfigMerger {
return &ConfigMerger{
configs: make(map[string]Config),
}
}
func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{}
f, err := os.ReadFile(file)
@ -72,28 +81,51 @@ func ReadConfig(file string) (*Config, error) {
}
func (cm ConfigMerger) LoadConfigFile(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfigFile(file)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
cm[cc.Name] = *cc
cm.configs[cc.Name] = *cc
}
return nil
}
func (cm ConfigMerger) LoadConfig(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfig(file)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cm[c.Name] = *c
cm.configs[c.Name] = *c
return nil
}
func (cm ConfigMerger) GetConfig(m string) (Config, bool) {
cm.Lock()
defer cm.Unlock()
v, exists := cm.configs[m]
return v, exists
}
func (cm ConfigMerger) ListConfigs() []string {
cm.Lock()
defer cm.Unlock()
var res []string
for k := range cm.configs {
res = append(res, k)
}
return res
}
func (cm ConfigMerger) LoadConfigs(path string) error {
cm.Lock()
defer cm.Unlock()
files, err := ioutil.ReadDir(path)
if err != nil {
return err
@ -106,7 +138,7 @@ func (cm ConfigMerger) LoadConfigs(path string) error {
}
c, err := ReadConfig(filepath.Join(path, file.Name()))
if err == nil {
cm[c.Name] = *c
cm.configs[c.Name] = *c
}
}
@ -253,7 +285,7 @@ func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (strin
return modelFile, input, nil
}
func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
@ -263,7 +295,7 @@ func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader
}
var config *Config
cfg, exists := cm[modelFile]
cfg, exists := cm.GetConfig(modelFile)
if !exists {
config = &Config{
OpenAIRequest: defaultRequest(modelFile),

146
api/gallery.go Normal file
View file

@ -0,0 +1,146 @@
package api
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"sync"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"gopkg.in/yaml.v3"
)
type galleryOp struct {
req ApplyGalleryModelRequest
id string
}
type galleryOpStatus struct {
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
}
type galleryApplier struct {
modelPath string
sync.Mutex
C chan galleryOp
statuses map[string]*galleryOpStatus
}
func newGalleryApplier(modelPath string) *galleryApplier {
return &galleryApplier{
modelPath: modelPath,
C: make(chan galleryOp),
statuses: make(map[string]*galleryOpStatus),
}
}
func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *galleryApplier) getstatus(s string) *galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.C:
g.updatestatus(op.id, &galleryOpStatus{Message: "processing"})
updateError := func(e error) {
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
}
// Send a GET request to the URL
response, err := http.Get(op.req.URL)
if err != nil {
updateError(err)
continue
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
updateError(err)
continue
}
// Unmarshal YAML data into a Config struct
var config gallery.Config
err = yaml.Unmarshal(body, &config)
if err != nil {
updateError(fmt.Errorf("failed to unmarshal YAML: %v", err))
continue
}
if err := gallery.Apply(g.modelPath, op.req.Name, &config); err != nil {
updateError(err)
continue
}
// Reload models
if err := cm.LoadConfigs(g.modelPath); err != nil {
updateError(err)
continue
}
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"})
}
}
}()
}
// endpoints
type ApplyGalleryModelRequest struct {
URL string `json:"url"`
Name string `json:"name"`
}
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := g.getstatus(c.Params("uid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(ApplyGalleryModelRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
g <- galleryOp{
req: *input,
id: uuid.String(),
}
return c.JSON(struct {
ID string `json:"uid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}

View file

@ -142,7 +142,7 @@ func defaultRequest(modelFile string) OpenAIRequest {
}
// https://platform.openai.com/docs/api-reference/completions
func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
@ -199,7 +199,7 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
}
// https://platform.openai.com/docs/api-reference/embeddings
func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
if err != nil {
@ -256,7 +256,7 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
}
}
func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
@ -378,7 +378,7 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
}
}
func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
if err != nil {
@ -449,7 +449,7 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
*
*/
func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error {
func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readInput(c, loader, false)
if err != nil {
@ -574,7 +574,7 @@ func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, image
}
// https://platform.openai.com/docs/api-reference/audio/create
func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readInput(c, loader, false)
if err != nil {
@ -641,7 +641,7 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
}
}
func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx) error {
func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := loader.ListModels()
if err != nil {
@ -655,7 +655,7 @@ func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx)
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
}
for k := range cm {
for _, k := range cm.ListConfigs() {
if _, exists := mm[k]; !exists {
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"})
}