feat: Centralized Request Processing middleware (#3847)

* squash past, centralize request middleware PR

Signed-off-by: Dave Lee <dave@gray101.com>

* migrate bruno request files to examples repo

Signed-off-by: Dave Lee <dave@gray101.com>

* fix

Signed-off-by: Dave Lee <dave@gray101.com>

* Update tests/e2e-aio/e2e_test.go

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

---------

Signed-off-by: Dave Lee <dave@gray101.com>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
Dave 2025-02-10 06:06:16 -05:00 committed by GitHub
parent c330360785
commit 3cddf24747
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
53 changed files with 240975 additions and 821 deletions

View file

@ -3,51 +3,50 @@ package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/endpoints/openai"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
)
func RegisterOpenAIRoutes(app *fiber.App,
re *middleware.RequestExtractor,
application *application.Application) {
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions",
openai.ChatEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/chat/completions",
openai.ChatEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
chatChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.ChatEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
}
app.Post("/v1/chat/completions", chatChain...)
app.Post("/chat/completions", chatChain...)
// edit
app.Post("/v1/edits",
openai.EditEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
editChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.EditEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
}
app.Post("/v1/edits", editChain...)
app.Post("/edits", editChain...)
app.Post("/edits",
openai.EditEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// completion
completionChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.CompletionEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
}
app.Post("/v1/completions", completionChain...)
app.Post("/completions", completionChain...)
app.Post("/v1/engines/:model/completions", completionChain...)
// assistant
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
@ -81,45 +80,37 @@ func RegisterOpenAIRoutes(app *fiber.App,
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
// completion
app.Post("/v1/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/v1/engines/:model/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// embeddings
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
embeddingChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
}
app.Post("/v1/embeddings", embeddingChain...)
app.Post("/embeddings", embeddingChain...)
app.Post("/v1/engines/:model/embeddings", embeddingChain...)
// audio
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/audio/transcriptions",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
)
app.Post("/v1/audio/speech",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
// images
app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/images/generations",
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
if application.ApplicationConfig().ImageDir != "" {
app.Static("/generated-images", application.ApplicationConfig().ImageDir)