feat(elevenlabs): map elevenlabs API support to TTS

This allows elevenlabs Clients to work automatically with LocalAI by
supporting the elevenlabs API.

The elevenlabs server endpoint is implemented such as it is wired to the
TTS endpoints.

Fixes: https://github.com/mudler/LocalAI/issues/1809
This commit is contained in:
Ettore Di Giacinto 2024-03-14 18:06:23 +01:00
parent a6b540737f
commit d2934dd69f
24 changed files with 444 additions and 336 deletions

View file

@ -6,6 +6,7 @@ import (
"os"
"strings"
"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
@ -21,6 +22,24 @@ import (
"github.com/gofiber/fiber/v2/middleware/recover"
)
func readAuthHeader(c *fiber.Ctx) string {
authHeader := c.Get("Authorization")
// elevenlabs
xApiKey := c.Get("xi-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
// anthropic
xApiKey = c.Get("x-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
return authHeader
}
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
// Return errors as JSON responses
app := fiber.New(fiber.Config{
@ -94,10 +113,12 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
return c.Next()
}
authHeader := c.Get("Authorization")
authHeader := readAuthHeader(c)
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
// If it's a bearer token
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
@ -111,7 +132,6 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if appConfig.CORS {
@ -147,6 +167,9 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
// openAI compatible API endpoint
// chat

View file

@ -0,0 +1,55 @@
package elevenlabs
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.ElevenLabsTTSRequest)
voiceID := c.Params("voice-id")
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
}
log.Debug().Msgf("Request for model: %s", modelFile)
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg)
if err != nil {
return err
}
return c.Download(filePath)
}
}

View file

@ -46,7 +46,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
cfg.Backend = input.Backend
}
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg)
if err != nil {
return err
}