mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
feat: Centralized Request Processing middleware (#3847)
* squash past, centralize request middleware PR Signed-off-by: Dave Lee <dave@gray101.com> * migrate bruno request files to examples repo Signed-off-by: Dave Lee <dave@gray101.com> * fix Signed-off-by: Dave Lee <dave@gray101.com> * Update tests/e2e-aio/e2e_test.go Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> --------- Signed-off-by: Dave Lee <dave@gray101.com> Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
parent
c330360785
commit
3cddf24747
53 changed files with 240975 additions and 821 deletions
|
@ -130,7 +130,6 @@ func API(application *application.Application) (*fiber.App, error) {
|
|||
return metricsService.Shutdown()
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
// Health Checks should always be exempt from auth, so register these first
|
||||
routes.HealthRoutes(router)
|
||||
|
@ -167,13 +166,15 @@ func API(application *application.Application) (*fiber.App, error) {
|
|||
galleryService := services.NewGalleryService(application.ApplicationConfig())
|
||||
galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
|
||||
|
||||
routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
|
||||
routes.RegisterOpenAIRoutes(router, application)
|
||||
requestExtractor := middleware.NewRequestExtractor(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
routes.RegisterElevenLabsRoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
routes.RegisterLocalAIRoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
|
||||
routes.RegisterOpenAIRoutes(router, requestExtractor, application)
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
|
||||
}
|
||||
routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
routes.RegisterJINARoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
httpFS := http.FS(embedDirStatic)
|
||||
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
package fiberContext
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ModelFromContext returns the model from the context
|
||||
// If no model is specified, it will take the first available
|
||||
// Takes a model string as input which should be the one received from the user request.
|
||||
// It returns the model name resolved from the context and an error if any.
|
||||
func ModelFromContext(ctx *fiber.Ctx, cl *config.BackendConfigLoader, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
|
||||
if ctx.Params("model") != "" {
|
||||
modelInput = ctx.Params("model")
|
||||
}
|
||||
if ctx.Query("model") != "" {
|
||||
modelInput = ctx.Query("model")
|
||||
}
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // Reduced duplicate characters of Bearer
|
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
|
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelInput == "" && !bearerExists && firstModel {
|
||||
models, _ := services.ListModels(cl, loader, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
|
||||
if len(models) > 0 {
|
||||
modelInput = models[0]
|
||||
log.Debug().Msgf("No model specified, using: %s", modelInput)
|
||||
} else {
|
||||
log.Debug().Msgf("No model specified, returning error")
|
||||
return "", fmt.Errorf("no model specified")
|
||||
}
|
||||
}
|
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists {
|
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer)
|
||||
modelInput = bearer
|
||||
}
|
||||
return modelInput, nil
|
||||
}
|
|
@ -4,7 +4,7 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -17,45 +17,21 @@ import (
|
|||
// @Router /v1/sound-generation [post]
|
||||
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.ElevenLabsSoundGenerationRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
|
||||
if !ok || input.ModelID == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
|
||||
if err != nil {
|
||||
modelFile = input.ModelID
|
||||
log.Warn().Str("ModelID", input.ModelID).Msg("Model not found in context")
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||
config.LoadOptionDebug(appConfig.Debug),
|
||||
config.LoadOptionThreads(appConfig.Threads),
|
||||
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||
config.LoadOptionF16(appConfig.F16),
|
||||
)
|
||||
if err != nil {
|
||||
modelFile = input.ModelID
|
||||
log.Warn().Str("Request ModelID", input.ModelID).Err(err).Msg("error during LoadBackendConfigFileByName, using request ModelID")
|
||||
} else {
|
||||
if input.ModelID != "" {
|
||||
modelFile = input.ModelID
|
||||
} else {
|
||||
modelFile = cfg.Model
|
||||
}
|
||||
}
|
||||
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
|
||||
|
||||
if input.Duration != nil {
|
||||
log.Debug().Float32("duration", *input.Duration).Msg("duration set")
|
||||
}
|
||||
if input.Temperature != nil {
|
||||
log.Debug().Float32("temperature", *input.Temperature).Msg("temperature set")
|
||||
}
|
||||
|
||||
// TODO: Support uploading files?
|
||||
filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.SoundGeneration(input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package elevenlabs
|
|||
import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
@ -20,39 +20,21 @@ import (
|
|||
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(schema.ElevenLabsTTSRequest)
|
||||
voiceID := c.Params("voice-id")
|
||||
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
|
||||
if !ok || input.ModelID == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
|
||||
if err != nil {
|
||||
modelFile = input.ModelID
|
||||
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||
config.LoadOptionDebug(appConfig.Debug),
|
||||
config.LoadOptionThreads(appConfig.Threads),
|
||||
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||
config.LoadOptionF16(appConfig.F16),
|
||||
)
|
||||
if err != nil {
|
||||
modelFile = input.ModelID
|
||||
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
|
||||
} else {
|
||||
if input.ModelID != "" {
|
||||
modelFile = input.ModelID
|
||||
} else {
|
||||
modelFile = cfg.Model
|
||||
}
|
||||
}
|
||||
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||
log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request recieved")
|
||||
|
||||
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, "", voiceID, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -3,9 +3,9 @@ package jina
|
|||
import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
@ -19,58 +19,32 @@ import (
|
|||
// @Router /v1/rerank [post]
|
||||
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
req := new(schema.JINARerankRequest)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "Cannot parse JSON",
|
||||
})
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
input := new(schema.TTSRequest)
|
||||
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
|
||||
if err != nil {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
}
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||
config.LoadOptionDebug(appConfig.Debug),
|
||||
config.LoadOptionThreads(appConfig.Threads),
|
||||
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||
config.LoadOptionF16(appConfig.F16),
|
||||
)
|
||||
if err != nil {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
} else {
|
||||
modelFile = cfg.Model
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||
|
||||
if input.Backend != "" {
|
||||
cfg.Backend = input.Backend
|
||||
}
|
||||
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request recieved")
|
||||
|
||||
request := &proto.RerankRequest{
|
||||
Query: req.Query,
|
||||
TopN: int32(req.TopN),
|
||||
Documents: req.Documents,
|
||||
Query: input.Query,
|
||||
TopN: int32(input.TopN),
|
||||
Documents: input.Documents,
|
||||
}
|
||||
|
||||
results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg)
|
||||
results, err := backend.Rerank(request, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response := &schema.JINARerankResponse{
|
||||
Model: req.Model,
|
||||
Model: input.Model,
|
||||
}
|
||||
|
||||
for _, r := range results.Results {
|
||||
|
|
|
@ -4,13 +4,15 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// TODO: This is not yet in use. Needs middleware rework, since it is not referenced.
|
||||
|
||||
// TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
|
||||
//
|
||||
// @Summary Get TokenMetrics for Active Slot.
|
||||
|
@ -29,18 +31,13 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
|
|||
return err
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
|
||||
if err != nil {
|
||||
modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if !ok || modelFile != "" {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
}
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||
config.LoadOptionDebug(appConfig.Debug),
|
||||
config.LoadOptionThreads(appConfig.Threads),
|
||||
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||
config.LoadOptionF16(appConfig.F16),
|
||||
)
|
||||
cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
|
|
|
@ -4,10 +4,9 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// TokenizeEndpoint exposes a REST API to tokenize the content
|
||||
|
@ -16,42 +15,21 @@ import (
|
|||
// @Success 200 {object} schema.TokenizeResponse "Response"
|
||||
// @Router /v1/tokenize [post]
|
||||
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(schema.TokenizeRequest)
|
||||
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
|
||||
if err != nil {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||
config.LoadOptionDebug(appConfig.Debug),
|
||||
config.LoadOptionThreads(appConfig.Threads),
|
||||
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||
config.LoadOptionF16(appConfig.F16),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
} else {
|
||||
modelFile = cfg.Model
|
||||
}
|
||||
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||
|
||||
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(tokenResponse)
|
||||
return ctx.JSON(tokenResponse)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package localai
|
|||
import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
@ -24,37 +24,24 @@ import (
|
|||
// @Router /tts [post]
|
||||
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.TTSRequest)
|
||||
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
|
||||
if err != nil {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||
config.LoadOptionDebug(appConfig.Debug),
|
||||
config.LoadOptionThreads(appConfig.Threads),
|
||||
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||
config.LoadOptionF16(appConfig.F16),
|
||||
)
|
||||
log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request recieved")
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
} else {
|
||||
modelFile = cfg.Model
|
||||
}
|
||||
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||
|
||||
if input.Backend != "" {
|
||||
cfg.Backend = input.Backend
|
||||
if cfg.Backend == "" {
|
||||
if input.Backend != "" {
|
||||
cfg.Backend = input.Backend
|
||||
} else {
|
||||
cfg.Backend = model.PiperBackend
|
||||
}
|
||||
}
|
||||
|
||||
if input.Language != "" {
|
||||
|
@ -65,7 +52,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
|
|||
cfg.Voice = input.Voice
|
||||
}
|
||||
|
||||
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -4,9 +4,8 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
@ -19,45 +18,20 @@ import (
|
|||
// @Router /vad [post]
|
||||
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.VADRequest)
|
||||
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
|
||||
if err != nil {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||
config.LoadOptionDebug(appConfig.Debug),
|
||||
config.LoadOptionThreads(appConfig.Threads),
|
||||
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||
config.LoadOptionF16(appConfig.F16),
|
||||
)
|
||||
log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request recieved")
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
} else {
|
||||
modelFile = cfg.Model
|
||||
}
|
||||
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||
resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
|
||||
|
||||
opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))
|
||||
|
||||
vadModel, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req := proto.VADRequest{
|
||||
Audio: input.Audio,
|
||||
}
|
||||
resp, err := vadModel.VAD(c.Context(), &req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -5,18 +5,19 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
@ -174,26 +175,20 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
textContentToReturn = ""
|
||||
id = uuid.New().String()
|
||||
created = int(time.Now().Unix())
|
||||
// Set CorrelationID
|
||||
correlationID := c.Get("X-Correlation-ID")
|
||||
if len(strings.TrimSpace(correlationID)) == 0 {
|
||||
correlationID = id
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
c.Set("X-Correlation-ID", correlationID)
|
||||
|
||||
// Opt-in extra usage flag
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
|
||||
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
log.Debug().Msgf("Configuration read: %+v", config)
|
||||
log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
|
||||
|
||||
funcs := input.Functions
|
||||
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
|
||||
|
@ -543,7 +538,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
|
|||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
|
||||
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, o, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("model inference failed")
|
||||
return "", err
|
||||
|
|
|
@ -10,12 +10,13 @@ import (
|
|||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
@ -26,11 +27,11 @@ import (
|
|||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
|
||||
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
|
@ -63,22 +64,18 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
|||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Add Correlation
|
||||
c.Set("X-Correlation-ID", id)
|
||||
|
||||
// Opt-in extra usage flag
|
||||
// Handle Correlation
|
||||
id := c.Get("X-Correlation-ID", uuid.New().String())
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
|
||||
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Msgf("`input`: %+v", input)
|
||||
|
||||
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
if config.ResponseFormatMap != nil {
|
||||
|
@ -122,7 +119,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
|||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
go process(predInput, input, config, ml, responses, extraUsage)
|
||||
go process(id, predInput, input, config, ml, responses, extraUsage)
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
||||
|
|
|
@ -2,16 +2,17 @@ package openai
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -25,20 +26,21 @@ import (
|
|||
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
// Opt-in extra usage flag
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
|
||||
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
log.Debug().Msgf("Edit Endpoint Input : %+v", input)
|
||||
log.Debug().Msgf("Edit Endpoint Config: %+v", *config)
|
||||
|
||||
var result []schema.Choice
|
||||
totalTokenUsage := backend.TokenUsage{}
|
||||
|
|
|
@ -2,11 +2,11 @@ package openai
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -23,14 +23,14 @@ import (
|
|||
// @Router /v1/embeddings [post]
|
||||
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
|
@ -66,25 +67,23 @@ func downloadFile(url string) (string, error) {
|
|||
// @Router /v1/images/generations [post]
|
||||
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readRequest(c, cl, ml, appConfig, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
log.Error().Msg("Image Endpoint - Invalid Input")
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
if m == "" {
|
||||
m = "stablediffusion"
|
||||
}
|
||||
log.Debug().Msgf("Loading model: %+v", m)
|
||||
|
||||
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Image Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
src := ""
|
||||
if input.File != "" {
|
||||
|
||||
fileData := []byte{}
|
||||
var err error
|
||||
// check if input.File is an URL, if so download it and save it
|
||||
// to a temporary file
|
||||
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
|
||||
|
|
|
@ -37,7 +37,7 @@ func ComputeChoices(
|
|||
}
|
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, *config, o, tokenCallback)
|
||||
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, o, tokenCallback)
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, err
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -10,6 +9,8 @@ import (
|
|||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
@ -25,15 +26,16 @@ import (
|
|||
// @Router /v1/audio/transcriptions [post]
|
||||
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readRequest(c, cl, ml, appConfig, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request: %w", err)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
|
|
|
@ -1,326 +1,450 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type correlationIDKeyType string
|
||||
|
||||
// CorrelationIDKey to track request across process boundary
|
||||
const CorrelationIDKey correlationIDKeyType = "correlationID"
|
||||
|
||||
func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
|
||||
input := new(schema.OpenAIRequest)
|
||||
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
|
||||
}
|
||||
|
||||
received, _ := json.Marshal(input)
|
||||
// Extract or generate the correlation ID
|
||||
correlationID := c.Get("X-Correlation-ID", uuid.New().String())
|
||||
|
||||
ctx, cancel := context.WithCancel(o.Context)
|
||||
// Add the correlation ID to the new context
|
||||
ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID)
|
||||
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received))
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
|
||||
|
||||
return modelFile, input, err
|
||||
}
|
||||
|
||||
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != nil {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != nil {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
if input.ModelBaseName != "" {
|
||||
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.UseFastTokenizer {
|
||||
config.UseFastTokenizer = input.UseFastTokenizer
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != nil {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != nil {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
if input.ResponseFormat != nil {
|
||||
switch responseFormat := input.ResponseFormat.(type) {
|
||||
case string:
|
||||
config.ResponseFormat = responseFormat
|
||||
case map[string]interface{}:
|
||||
config.ResponseFormatMap = responseFormat
|
||||
}
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
for _, tool := range input.Tools {
|
||||
input.Functions = append(input.Functions, tool.Function)
|
||||
}
|
||||
}
|
||||
|
||||
if input.ToolsChoice != nil {
|
||||
var toolChoice functions.Tool
|
||||
|
||||
switch content := input.ToolsChoice.(type) {
|
||||
case string:
|
||||
_ = json.Unmarshal([]byte(content), &toolChoice)
|
||||
case map[string]interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
_ = json.Unmarshal(dat, &toolChoice)
|
||||
}
|
||||
input.FunctionCall = map[string]interface{}{
|
||||
"name": toolChoice.Function.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
imgIndex, vidIndex, audioIndex := 0, 0, 0
|
||||
for i, m := range input.Messages {
|
||||
nrOfImgsInMessage := 0
|
||||
nrOfVideosInMessage := 0
|
||||
nrOfAudiosInMessage := 0
|
||||
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []schema.Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
|
||||
textContent := ""
|
||||
// we will template this at the end
|
||||
|
||||
CONTENT:
|
||||
for _, pp := range c {
|
||||
switch pp.Type {
|
||||
case "text":
|
||||
textContent += pp.Text
|
||||
//input.Messages[i].StringContent = pp.Text
|
||||
case "video", "video_url":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding video: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
||||
vidIndex++
|
||||
nrOfVideosInMessage++
|
||||
case "audio_url", "audio":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding image: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
||||
audioIndex++
|
||||
nrOfAudiosInMessage++
|
||||
case "image_url", "image":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding image: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
|
||||
imgIndex++
|
||||
nrOfImgsInMessage++
|
||||
}
|
||||
}
|
||||
|
||||
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
||||
TotalImages: imgIndex,
|
||||
TotalVideos: vidIndex,
|
||||
TotalAudios: audioIndex,
|
||||
ImagesInMessage: nrOfImgsInMessage,
|
||||
VideosInMessage: nrOfVideosInMessage,
|
||||
AudiosInMessage: nrOfAudiosInMessage,
|
||||
}, textContent)
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.FrequencyPenalty != 0 {
|
||||
config.FrequencyPenalty = input.FrequencyPenalty
|
||||
}
|
||||
|
||||
if input.PresencePenalty != 0 {
|
||||
config.PresencePenalty = input.PresencePenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != nil {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.TypicalP != nil {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []interface{}:
|
||||
tokens := []int{}
|
||||
for _, ii := range i {
|
||||
tokens = append(tokens, int(ii.(float64)))
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a quality was defined as number, convert it to step
|
||||
if input.Quality != "" {
|
||||
q, err := strconv.Atoi(input.Quality)
|
||||
if err == nil {
|
||||
config.Step = q
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
|
||||
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
|
||||
config.LoadOptionDebug(debug),
|
||||
config.LoadOptionThreads(threads),
|
||||
config.LoadOptionContextSize(ctx),
|
||||
config.LoadOptionF16(f16),
|
||||
)
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateRequestConfig(cfg, input)
|
||||
|
||||
if !cfg.Validate() {
|
||||
return nil, nil, fmt.Errorf("failed to validate config")
|
||||
}
|
||||
|
||||
return cfg, input, err
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type correlationIDKeyType string
|
||||
|
||||
// CorrelationIDKey to track request across process boundary
|
||||
const CorrelationIDKey correlationIDKeyType = "correlationID"
|
||||
|
||||
type RequestExtractor struct {
|
||||
backendConfigLoader *config.BackendConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func NewRequestExtractor(backendConfigLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
||||
return &RequestExtractor{
|
||||
backendConfigLoader: backendConfigLoader,
|
||||
modelLoader: modelLoader,
|
||||
applicationConfig: applicationConfig,
|
||||
}
|
||||
}
|
||||
|
||||
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
|
||||
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
|
||||
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
|
||||
|
||||
// TODO: Refactor to not return error if unchanged
|
||||
func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
|
||||
model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if ok && model != "" {
|
||||
return
|
||||
}
|
||||
model = ctx.Params("model")
|
||||
|
||||
if (model == "") && ctx.Query("model") != "" {
|
||||
model = ctx.Query("model")
|
||||
}
|
||||
|
||||
if model == "" {
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
|
||||
if bearer != "" {
|
||||
exists, err := services.CheckIfModelExists(re.backendConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
|
||||
if err == nil && exists {
|
||||
model = bearer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
re.setModelNameFromRequest(ctx)
|
||||
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if !ok || localModelName == "" {
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
|
||||
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
|
||||
}
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.BackendConfigFilterFn) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
re.setModelNameFromRequest(ctx)
|
||||
localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if localModelName != "" { // Don't overwrite existing values
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
modelNames, err := services.ListModels(re.backendConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
if len(modelNames) == 0 {
|
||||
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
|
||||
// This is non-fatal - making it so was breaking the case of direct installation of raw models
|
||||
// return errors.New("this endpoint requires at least one model to be installed")
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
|
||||
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: If context and cancel above belong on all methods, move that part of above into here!
|
||||
// Otherwise, it's in its own method below for now
|
||||
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
input := initializer()
|
||||
if input == nil {
|
||||
return fmt.Errorf("unable to initialize body")
|
||||
}
|
||||
if err := ctx.BodyParser(input); err != nil {
|
||||
return fmt.Errorf("failed parsing request body: %w", err)
|
||||
}
|
||||
|
||||
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
|
||||
if input.ModelName(nil) == "" {
|
||||
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
if ok && localModelName != "" {
|
||||
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
|
||||
input.ModelName(&localModelName)
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := re.backendConfigLoader.LoadBackendConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
|
||||
} else if cfg.Model == "" && input.ModelName(nil) != "" {
|
||||
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
|
||||
cfg.Model = input.ModelName(nil)
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
|
||||
input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
// Extract or generate the correlation ID
|
||||
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
|
||||
ctx.Set("X-Correlation-ID", correlationID)
|
||||
|
||||
c1, cancel := context.WithCancel(re.applicationConfig.Context)
|
||||
// Add the correlation ID to the new context
|
||||
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
|
||||
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
err := mergeOpenAIRequestAndBackendConfig(cfg, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Model == "" {
|
||||
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
|
||||
cfg.Model = input.Model
|
||||
}
|
||||
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *schema.OpenAIRequest) error {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != nil {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != nil {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
if input.ModelBaseName != "" {
|
||||
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.UseFastTokenizer {
|
||||
config.UseFastTokenizer = input.UseFastTokenizer
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != nil {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != nil {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
if input.ResponseFormat != nil {
|
||||
switch responseFormat := input.ResponseFormat.(type) {
|
||||
case string:
|
||||
config.ResponseFormat = responseFormat
|
||||
case map[string]interface{}:
|
||||
config.ResponseFormatMap = responseFormat
|
||||
}
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
for _, tool := range input.Tools {
|
||||
input.Functions = append(input.Functions, tool.Function)
|
||||
}
|
||||
}
|
||||
|
||||
if input.ToolsChoice != nil {
|
||||
var toolChoice functions.Tool
|
||||
|
||||
switch content := input.ToolsChoice.(type) {
|
||||
case string:
|
||||
_ = json.Unmarshal([]byte(content), &toolChoice)
|
||||
case map[string]interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
_ = json.Unmarshal(dat, &toolChoice)
|
||||
}
|
||||
input.FunctionCall = map[string]interface{}{
|
||||
"name": toolChoice.Function.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
imgIndex, vidIndex, audioIndex := 0, 0, 0
|
||||
for i, m := range input.Messages {
|
||||
nrOfImgsInMessage := 0
|
||||
nrOfVideosInMessage := 0
|
||||
nrOfAudiosInMessage := 0
|
||||
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []schema.Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
|
||||
textContent := ""
|
||||
// we will template this at the end
|
||||
|
||||
CONTENT:
|
||||
for _, pp := range c {
|
||||
switch pp.Type {
|
||||
case "text":
|
||||
textContent += pp.Text
|
||||
//input.Messages[i].StringContent = pp.Text
|
||||
case "video", "video_url":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding video: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
|
||||
vidIndex++
|
||||
nrOfVideosInMessage++
|
||||
case "audio_url", "audio":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding image: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
|
||||
audioIndex++
|
||||
nrOfAudiosInMessage++
|
||||
case "image_url", "image":
|
||||
// Decode content as base64 either if it's an URL or base64 text
|
||||
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed encoding image: %s", err)
|
||||
continue CONTENT
|
||||
}
|
||||
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
|
||||
imgIndex++
|
||||
nrOfImgsInMessage++
|
||||
}
|
||||
}
|
||||
|
||||
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
|
||||
TotalImages: imgIndex,
|
||||
TotalVideos: vidIndex,
|
||||
TotalAudios: audioIndex,
|
||||
ImagesInMessage: nrOfImgsInMessage,
|
||||
VideosInMessage: nrOfVideosInMessage,
|
||||
AudiosInMessage: nrOfAudiosInMessage,
|
||||
}, textContent)
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.FrequencyPenalty != 0 {
|
||||
config.FrequencyPenalty = input.FrequencyPenalty
|
||||
}
|
||||
|
||||
if input.PresencePenalty != 0 {
|
||||
config.PresencePenalty = input.PresencePenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != nil {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.TypicalP != nil {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []interface{}:
|
||||
tokens := []int{}
|
||||
for _, ii := range i {
|
||||
tokens = append(tokens, int(ii.(float64)))
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a quality was defined as number, convert it to step
|
||||
if input.Quality != "" {
|
||||
q, err := strconv.Atoi(input.Quality)
|
||||
if err == nil {
|
||||
config.Step = q
|
||||
}
|
||||
}
|
||||
|
||||
if config.Validate() {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unable to validate configuration after merging")
|
||||
}
|
|
@ -4,17 +4,26 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func RegisterElevenLabsRoutes(app *fiber.App,
|
||||
re *middleware.RequestExtractor,
|
||||
cl *config.BackendConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
|
||||
// Elevenlabs
|
||||
app.Post("/v1/text-to-speech/:voice-id", elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/text-to-speech/:voice-id",
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }),
|
||||
elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
||||
|
||||
app.Post("/v1/sound-generation", elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/sound-generation",
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }),
|
||||
elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
|
||||
|
||||
}
|
||||
|
|
|
@ -3,16 +3,22 @@ package routes
|
|||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/jina"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func RegisterJINARoutes(app *fiber.App,
|
||||
re *middleware.RequestExtractor,
|
||||
cl *config.BackendConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
|
||||
// POST endpoint to mimic the reranking
|
||||
app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/rerank",
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }),
|
||||
jina.JINARerankEndpoint(cl, ml, appConfig))
|
||||
}
|
||||
|
|
|
@ -5,13 +5,16 @@ import (
|
|||
"github.com/gofiber/swagger"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func RegisterLocalAIRoutes(router *fiber.App,
|
||||
requestExtractor *middleware.RequestExtractor,
|
||||
cl *config.BackendConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
|
@ -33,8 +36,18 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
|||
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||
}
|
||||
|
||||
router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
|
||||
router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
|
||||
router.Post("/tts",
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
|
||||
localai.TTSEndpoint(cl, ml, appConfig))
|
||||
|
||||
vadChain := []fiber.Handler{
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }),
|
||||
localai.VADEndpoint(cl, ml, appConfig),
|
||||
}
|
||||
router.Post("/vad", vadChain...)
|
||||
router.Post("/v1/vad", vadChain...)
|
||||
|
||||
// Stores
|
||||
sl := model.NewModelLoader("")
|
||||
|
@ -47,10 +60,14 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
|||
router.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
||||
}
|
||||
|
||||
// Experimental Backend Statistics Module
|
||||
// Backend Statistics Module
|
||||
// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
|
||||
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
|
||||
router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
// The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
|
||||
router.Get("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.Post("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
|
||||
// p2p
|
||||
if p2p.IsP2PEnabled() {
|
||||
|
@ -67,6 +84,9 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
|||
router.Get("/system", localai.SystemInformations(ml, appConfig))
|
||||
|
||||
// misc
|
||||
router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
|
||||
router.Post("/v1/tokenize",
|
||||
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)),
|
||||
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }),
|
||||
localai.TokenizeEndpoint(cl, ml, appConfig))
|
||||
|
||||
}
|
||||
|
|
|
@ -3,51 +3,50 @@ package routes
|
|||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
func RegisterOpenAIRoutes(app *fiber.App,
|
||||
re *middleware.RequestExtractor,
|
||||
application *application.Application) {
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// chat
|
||||
app.Post("/v1/chat/completions",
|
||||
openai.ChatEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
app.Post("/chat/completions",
|
||||
openai.ChatEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
chatChain := []fiber.Handler{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.ChatEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
}
|
||||
app.Post("/v1/chat/completions", chatChain...)
|
||||
app.Post("/chat/completions", chatChain...)
|
||||
|
||||
// edit
|
||||
app.Post("/v1/edits",
|
||||
openai.EditEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
editChain := []fiber.Handler{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
|
||||
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.EditEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
}
|
||||
app.Post("/v1/edits", editChain...)
|
||||
app.Post("/edits", editChain...)
|
||||
|
||||
app.Post("/edits",
|
||||
openai.EditEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
// completion
|
||||
completionChain := []fiber.Handler{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
|
||||
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.CompletionEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
|
||||
}
|
||||
app.Post("/v1/completions", completionChain...)
|
||||
app.Post("/completions", completionChain...)
|
||||
app.Post("/v1/engines/:model/completions", completionChain...)
|
||||
|
||||
// assistant
|
||||
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
|
@ -81,45 +80,37 @@ func RegisterOpenAIRoutes(app *fiber.App,
|
|||
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
|
||||
// completion
|
||||
app.Post("/v1/completions",
|
||||
openai.CompletionEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
app.Post("/completions",
|
||||
openai.CompletionEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
app.Post("/v1/engines/:model/completions",
|
||||
openai.CompletionEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
// embeddings
|
||||
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
embeddingChain := []fiber.Handler{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
}
|
||||
app.Post("/v1/embeddings", embeddingChain...)
|
||||
app.Post("/embeddings", embeddingChain...)
|
||||
app.Post("/v1/engines/:model/embeddings", embeddingChain...)
|
||||
|
||||
// audio
|
||||
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/audio/transcriptions",
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
|
||||
)
|
||||
|
||||
app.Post("/v1/audio/speech",
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
|
||||
localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
|
||||
// images
|
||||
app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/images/generations",
|
||||
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
|
||||
re.SetOpenAIRequest,
|
||||
openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
|
||||
if application.ApplicationConfig().ImageDir != "" {
|
||||
app.Static("/generated-images", application.ApplicationConfig().ImageDir)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue