feat: Centralized Request Processing middleware (#3847)

* squash past, centralize request middleware PR

Signed-off-by: Dave Lee <dave@gray101.com>

* migrate bruno request files to examples repo

Signed-off-by: Dave Lee <dave@gray101.com>

* fix

Signed-off-by: Dave Lee <dave@gray101.com>

* Update tests/e2e-aio/e2e_test.go

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

---------

Signed-off-by: Dave Lee <dave@gray101.com>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
Dave 2025-02-10 06:06:16 -05:00 committed by GitHub
parent c330360785
commit 3cddf24747
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
53 changed files with 240975 additions and 821 deletions

View file

@ -4,7 +4,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
@ -17,45 +17,21 @@ import (
// @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.ElevenLabsSoundGenerationRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Str("ModelID", input.ModelID).Msg("Model not found in context")
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Str("Request ModelID", input.ModelID).Err(err).Msg("error during LoadBackendConfigFileByName, using request ModelID")
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
}
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
if input.Duration != nil {
log.Debug().Float32("duration", *input.Duration).Msg("duration set")
}
if input.Temperature != nil {
log.Debug().Float32("temperature", *input.Temperature).Msg("temperature set")
}
// TODO: Support uploading files?
filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
filePath, _, err := backend.SoundGeneration(input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
if err != nil {
return err
}

View file

@ -3,7 +3,7 @@ package elevenlabs
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@ -20,39 +20,21 @@ import (
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.ElevenLabsTTSRequest)
voiceID := c.Params("voice-id")
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
}
log.Debug().Msgf("Request for model: %s", modelFile)
log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request recieved")
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, "", voiceID, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
if err != nil {
return err
}

View file

@ -3,9 +3,9 @@ package jina
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
@ -19,58 +19,32 @@ import (
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
req := new(schema.JINARerankRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "Cannot parse JSON",
})
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
}
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request recieved")
request := &proto.RerankRequest{
Query: req.Query,
TopN: int32(req.TopN),
Documents: req.Documents,
Query: input.Query,
TopN: int32(input.TopN),
Documents: input.Documents,
}
results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg)
results, err := backend.Rerank(request, ml, appConfig, *cfg)
if err != nil {
return err
}
response := &schema.JINARerankResponse{
Model: req.Model,
Model: input.Model,
}
for _, r := range results.Results {

View file

@ -4,13 +4,15 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log"
"github.com/mudler/LocalAI/pkg/model"
)
// TODO: This is not yet in use. Needs middleware rework, since it is not referenced.
// TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
//
// @Summary Get TokenMetrics for Active Slot.
@ -29,18 +31,13 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
return err
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || modelFile != "" {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig)
if err != nil {
log.Err(err)

View file

@ -4,10 +4,9 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
// TokenizeEndpoint exposes a REST API to tokenize the content
@ -16,42 +15,21 @@ import (
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TokenizeRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
return func(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
if err != nil {
return err
}
return c.JSON(tokenResponse)
return ctx.JSON(tokenResponse)
}
}

View file

@ -3,7 +3,7 @@ package localai
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@ -24,37 +24,24 @@ import (
// @Router /tts [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request recieved")
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
if cfg.Backend == "" {
if input.Backend != "" {
cfg.Backend = input.Backend
} else {
cfg.Backend = model.PiperBackend
}
}
if input.Language != "" {
@ -65,7 +52,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
cfg.Voice = input.Voice
}
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
if err != nil {
return err
}

View file

@ -4,9 +4,8 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@ -19,45 +18,20 @@ import (
// @Router /vad [post]
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.VADRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request recieved")
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))
vadModel, err := ml.Load(opts...)
if err != nil {
return err
}
req := proto.VADRequest{
Audio: input.Audio,
}
resp, err := vadModel.VAD(c.Context(), &req)
if err != nil {
return err
}

View file

@ -5,18 +5,19 @@ import (
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
@ -174,26 +175,20 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
// Set CorrelationID
correlationID := c.Get("X-Correlation-ID")
if len(strings.TrimSpace(correlationID)) == 0 {
correlationID = id
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
c.Set("X-Correlation-ID", correlationID)
// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
funcs := input.Functions
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
@ -543,7 +538,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
audios = append(audios, m.StringAudios...)
}
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, o, nil)
if err != nil {
log.Error().Err(err).Msg("model inference failed")
return "", err

View file

@ -10,12 +10,13 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
@ -26,11 +27,11 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
@ -63,22 +64,18 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
}
return func(c *fiber.Ctx) error {
// Add Correlation
c.Set("X-Correlation-ID", id)
// Opt-in extra usage flag
// Handle Correlation
id := c.Get("X-Correlation-ID", uuid.New().String())
extraUsage := c.Get("Extra-Usage", "") != ""
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
log.Debug().Msgf("`input`: %+v", input)
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
if config.ResponseFormatMap != nil {
@ -122,7 +119,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, ml, responses, extraUsage)
go process(id, predInput, input, config, ml, responses, extraUsage)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {

View file

@ -2,16 +2,17 @@ package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/rs/zerolog/log"
@ -25,20 +26,21 @@ import (
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
log.Debug().Msgf("Edit Endpoint Input : %+v", input)
log.Debug().Msgf("Edit Endpoint Config: %+v", *config)
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}

View file

@ -2,11 +2,11 @@ package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/google/uuid"
@ -23,14 +23,14 @@ import (
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
log.Debug().Msgf("Parameter Config: %+v", config)

View file

@ -15,6 +15,7 @@ import (
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend"
@ -66,25 +67,23 @@ func downloadFile(url string) (string, error) {
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, cl, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
log.Error().Msg("Image Endpoint - Invalid Input")
return fiber.ErrBadRequest
}
if m == "" {
m = "stablediffusion"
}
log.Debug().Msgf("Loading model: %+v", m)
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
log.Error().Msg("Image Endpoint - Invalid Config")
return fiber.ErrBadRequest
}
src := ""
if input.File != "" {
fileData := []byte{}
var err error
// check if input.File is an URL, if so download it and save it
// to a temporary file
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {

View file

@ -37,7 +37,7 @@ func ComputeChoices(
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, *config, o, tokenCallback)
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, o, tokenCallback)
if err != nil {
return result, backend.TokenUsage{}, err
}

View file

@ -1,326 +0,0 @@
package openai
import (
"context"
"encoding/json"
"fmt"
"strconv"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
type correlationIDKeyType string
// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"
func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
}
received, _ := json.Marshal(input)
// Extract or generate the correlation ID
correlationID := c.Get("X-Correlation-ID", uuid.New().String())
ctx, cancel := context.WithCancel(o.Context)
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID)
input.Context = ctxWithCorrelationID
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
return modelFile, input, err
}
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != nil {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
if input.ResponseFormat != nil {
switch responseFormat := input.ResponseFormat.(type) {
case string:
config.ResponseFormat = responseFormat
case map[string]interface{}:
config.ResponseFormatMap = responseFormat
}
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
imgIndex, vidIndex, audioIndex := 0, 0, 0
for i, m := range input.Messages {
nrOfImgsInMessage := 0
nrOfVideosInMessage := 0
nrOfAudiosInMessage := 0
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
textContent := ""
// we will template this at the end
CONTENT:
for _, pp := range c {
switch pp.Type {
case "text":
textContent += pp.Text
//input.Messages[i].StringContent = pp.Text
case "video", "video_url":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding video: %s", err)
continue CONTENT
}
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
vidIndex++
nrOfVideosInMessage++
case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
audioIndex++
nrOfAudiosInMessage++
case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
imgIndex++
nrOfImgsInMessage++
}
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
TotalVideos: vidIndex,
TotalAudios: audioIndex,
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage,
AudiosInMessage: nrOfAudiosInMessage,
}, textContent)
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
// If a quality was defined as number, convert it to step
if input.Quality != "" {
q, err := strconv.Atoi(input.Quality)
if err == nil {
config.Step = q
}
}
}
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
config.LoadOptionDebug(debug),
config.LoadOptionThreads(threads),
config.LoadOptionContextSize(ctx),
config.LoadOptionF16(f16),
)
// Set the parameters for the language model prediction
updateRequestConfig(cfg, input)
if !cfg.Validate() {
return nil, nil, fmt.Errorf("failed to validate config")
}
return cfg, input, err
}

View file

@ -1,7 +1,6 @@
package openai
import (
"fmt"
"io"
"net/http"
"os"
@ -10,6 +9,8 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@ -25,15 +26,16 @@ import (
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, cl, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request: %w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {