mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
feat: allow to set cors (#339)
This commit is contained in:
parent
ed5df1e68e
commit
6f54cab3f0
5 changed files with 199 additions and 65 deletions
69
api/api.go
69
api/api.go
|
@ -1,10 +1,8 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
|
@ -13,16 +11,18 @@ import (
|
|||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
|
||||
func App(opts ...AppOption) *fiber.App {
|
||||
options := newOptions(opts...)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
if debug {
|
||||
if options.debug {
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
|
||||
// Return errors as JSON responses
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
DisableStartupMessage: disableMessage,
|
||||
BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
DisableStartupMessage: options.disableMessage,
|
||||
// Override default error handler
|
||||
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
||||
// Status code defaults to 500
|
||||
|
@ -43,24 +43,24 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload
|
|||
},
|
||||
})
|
||||
|
||||
if debug {
|
||||
if options.debug {
|
||||
app.Use(logger.New(logger.Config{
|
||||
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
||||
}))
|
||||
}
|
||||
|
||||
cm := NewConfigMerger()
|
||||
if err := cm.LoadConfigs(loader.ModelPath); err != nil {
|
||||
if err := cm.LoadConfigs(options.loader.ModelPath); err != nil {
|
||||
log.Error().Msgf("error loading config files: %s", err.Error())
|
||||
}
|
||||
|
||||
if configFile != "" {
|
||||
if err := cm.LoadConfigFile(configFile); err != nil {
|
||||
if options.configFile != "" {
|
||||
if err := cm.LoadConfigFile(options.configFile); err != nil {
|
||||
log.Error().Msgf("error loading config file: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if debug {
|
||||
if options.debug {
|
||||
for _, v := range cm.ListConfigs() {
|
||||
cfg, _ := cm.GetConfig(v)
|
||||
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
||||
|
@ -68,46 +68,55 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload
|
|||
}
|
||||
// Default middleware config
|
||||
app.Use(recover.New())
|
||||
app.Use(cors.New())
|
||||
|
||||
if options.cors {
|
||||
if options.corsAllowOrigins == "" {
|
||||
app.Use(cors.New())
|
||||
} else {
|
||||
app.Use(cors.New(cors.Config{
|
||||
AllowOrigins: options.corsAllowOrigins,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
// LocalAI API endpoints
|
||||
applier := newGalleryApplier(loader.ModelPath)
|
||||
applier.start(c, cm)
|
||||
app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C))
|
||||
applier := newGalleryApplier(options.loader.ModelPath)
|
||||
applier.start(options.context, cm)
|
||||
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C))
|
||||
app.Get("/models/jobs/:uuid", getOpStatus(applier))
|
||||
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// chat
|
||||
app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
||||
app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
||||
app.Post("/v1/chat/completions", chatEndpoint(cm, options))
|
||||
app.Post("/chat/completions", chatEndpoint(cm, options))
|
||||
|
||||
// edit
|
||||
app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
||||
app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
||||
app.Post("/v1/edits", editEndpoint(cm, options))
|
||||
app.Post("/edits", editEndpoint(cm, options))
|
||||
|
||||
// completion
|
||||
app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
||||
app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
||||
app.Post("/v1/completions", completionEndpoint(cm, options))
|
||||
app.Post("/completions", completionEndpoint(cm, options))
|
||||
|
||||
// embeddings
|
||||
app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
||||
app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
||||
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
||||
app.Post("/v1/embeddings", embeddingsEndpoint(cm, options))
|
||||
app.Post("/embeddings", embeddingsEndpoint(cm, options))
|
||||
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options))
|
||||
|
||||
// audio
|
||||
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
||||
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options))
|
||||
|
||||
// images
|
||||
app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir))
|
||||
app.Post("/v1/images/generations", imageEndpoint(cm, options))
|
||||
|
||||
if imageDir != "" {
|
||||
app.Static("/generated-images", imageDir)
|
||||
if options.imageDir != "" {
|
||||
app.Static("/generated-images", options.imageDir)
|
||||
}
|
||||
|
||||
// models
|
||||
app.Get("/v1/models", listModels(loader, cm))
|
||||
app.Get("/models", listModels(loader, cm))
|
||||
app.Get("/v1/models", listModels(options.loader, cm))
|
||||
app.Get("/models", listModels(options.loader, cm))
|
||||
|
||||
return app
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue