feat: fiber CSRF (#2482)

new config option - enables or disables the fiber csrf middleware

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-06-04 15:43:46 -04:00 committed by GitHub
parent 2fc6fe806b
commit 4e1463fec2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 28 additions and 0 deletions

View file

@ -43,6 +43,7 @@ type RunCMD struct {
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"`
@ -77,6 +78,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithModelLibraryURL(r.RemoteLibrary),
config.WithCors(r.CORS),
config.WithCorsAllowOrigins(r.CORSAllowOrigins),
config.WithCsrf(r.CSRF),
config.WithThreads(r.Threads),
config.WithBackendAssets(ctx.BackendAssets),
config.WithBackendAssetsOutput(r.BackendAssetsPath),

View file

@ -26,6 +26,7 @@ type ApplicationConfig struct {
DynamicConfigsDir string
DynamicConfigsDirPollInterval time.Duration
CORS bool
CSRF bool
PreloadJSONModels string
PreloadModelsFromPath string
CORSAllowOrigins string
@ -87,6 +88,12 @@ func WithCors(b bool) AppOption {
}
}
func WithCsrf(b bool) AppOption {
return func(o *ApplicationConfig) {
o.CSRF = b
}
}
func WithModelLibraryURL(url string) AppOption {
return func(o *ApplicationConfig) {
o.ModelLibraryURL = url

View file

@ -20,6 +20,7 @@ import (
"github.com/gofiber/contrib/fiberzerolog"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/csrf"
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/recover"
@ -167,6 +168,11 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
app.Use(c)
}
if appConfig.CSRF {
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
app.Use(csrf.New())
}
// Load config jsons
utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)