diff --git a/api/api.go b/api/api.go index 543e7566..a2702208 100644 --- a/api/api.go +++ b/api/api.go @@ -2,6 +2,7 @@ package api import ( "errors" + "strings" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/pkg/assets" @@ -83,6 +84,26 @@ func App(opts ...AppOption) (*fiber.App, error) { // Default middleware config app.Use(recover.New()) + // Auth middleware checking if API key is valid. If no API key is set, no auth is required. + auth := func(c *fiber.Ctx) error { + if options.apiKey != "" { + authHeader := c.Get("Authorization") + if authHeader == "" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) + } + authHeaderParts := strings.Split(authHeader, " ") + if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) + } + + apiKey := authHeaderParts[1] + if apiKey != options.apiKey { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) + } + } + return c.Next() + } + if options.preloadJSONModels != "" { if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm, options.galleries); err != nil { return nil, err @@ -109,42 +130,42 @@ func App(opts ...AppOption) (*fiber.App, error) { applier := newGalleryApplier(options.loader.ModelPath) applier.start(options.context, cm) - app.Get("/version", func(c *fiber.Ctx) error { + app.Get("/version", auth, func(c *fiber.Ctx) error { return c.JSON(struct { Version string `json:"version"` }{Version: internal.PrintableVersion()}) }) - app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries)) - app.Get("/models/available", listModelFromGallery(options.galleries, options.loader.ModelPath)) - app.Get("/models/jobs/:uuid", getOpStatus(applier)) + app.Post("/models/apply", auth, applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries)) + app.Get("/models/available", auth, listModelFromGallery(options.galleries, options.loader.ModelPath)) + app.Get("/models/jobs/:uuid", auth, getOpStatus(applier)) // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", chatEndpoint(cm, options)) - app.Post("/chat/completions", chatEndpoint(cm, options)) + app.Post("/v1/chat/completions", auth, chatEndpoint(cm, options)) + app.Post("/chat/completions", auth, chatEndpoint(cm, options)) // edit - app.Post("/v1/edits", editEndpoint(cm, options)) - app.Post("/edits", editEndpoint(cm, options)) + app.Post("/v1/edits", auth, editEndpoint(cm, options)) + app.Post("/edits", auth, editEndpoint(cm, options)) // completion - app.Post("/v1/completions", completionEndpoint(cm, options)) - app.Post("/completions", completionEndpoint(cm, options)) - app.Post("/v1/engines/:model/completions", completionEndpoint(cm, options)) + app.Post("/v1/completions", auth, completionEndpoint(cm, options)) + app.Post("/completions", auth, completionEndpoint(cm, options)) + app.Post("/v1/engines/:model/completions", auth, completionEndpoint(cm, options)) // embeddings - app.Post("/v1/embeddings", embeddingsEndpoint(cm, options)) - app.Post("/embeddings", embeddingsEndpoint(cm, options)) - app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options)) + app.Post("/v1/embeddings", auth, embeddingsEndpoint(cm, options)) + app.Post("/embeddings", auth, embeddingsEndpoint(cm, options)) + app.Post("/v1/engines/:model/embeddings", auth, embeddingsEndpoint(cm, options)) // audio - app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options)) - app.Post("/tts", ttsEndpoint(cm, options)) + app.Post("/v1/audio/transcriptions", auth, transcriptEndpoint(cm, options)) + app.Post("/tts", auth, ttsEndpoint(cm, options)) // images - app.Post("/v1/images/generations", imageEndpoint(cm, options)) + app.Post("/v1/images/generations", auth, imageEndpoint(cm, options)) if options.imageDir != "" { app.Static("/generated-images", options.imageDir) @@ -163,8 +184,8 @@ func App(opts ...AppOption) (*fiber.App, error) { app.Get("/readyz", ok) // models - app.Get("/v1/models", listModels(options.loader, cm)) - app.Get("/models", listModels(options.loader, cm)) + app.Get("/v1/models", auth, listModels(options.loader, cm)) + app.Get("/models", auth, listModels(options.loader, cm)) return app, nil } diff --git a/api/options.go b/api/options.go index 923288ac..41ef62b8 100644 --- a/api/options.go +++ b/api/options.go @@ -23,6 +23,7 @@ type Option struct { preloadJSONModels string preloadModelsFromPath string corsAllowOrigins string + apiKey string galleries []gallery.Gallery @@ -167,3 +168,9 @@ func WithImageDir(imageDir string) AppOption { o.imageDir = imageDir } } + +func WithApiKey(apiKey string) AppOption { + return func(o *Option) { + o.apiKey = apiKey + } +} diff --git a/main.go b/main.go index fc1dea09..118364c3 100644 --- a/main.go +++ b/main.go @@ -110,6 +110,11 @@ func main() { EnvVars: []string{"UPLOAD_LIMIT"}, Value: 15, }, + &cli.StringFlag{ + Name: "api-key", + Usage: "API Key to enable API authentication. When this is set, all the requests must be authenticated with this API key.", + EnvVars: []string{"API_KEY"}, + }, }, Description: ` LocalAI is a drop-in replacement OpenAI API which runs inference locally. @@ -145,7 +150,9 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit api.WithThreads(ctx.Int("threads")), api.WithBackendAssets(backendAssets), api.WithBackendAssetsOutput(ctx.String("backend-assets-path")), - api.WithUploadLimitMB(ctx.Int("upload-limit"))) + api.WithUploadLimitMB(ctx.Int("upload-limit")), + api.WithApiKey(ctx.String("api-key")), + ) if err != nil { return err }