feat(tts): add Elevenlabs and OpenAI TTS compatibility layer (#1834)

* feat(elevenlabs): map elevenlabs API support to TTS

This allows elevenlabs Clients to work automatically with LocalAI by
supporting the elevenlabs API.

The elevenlabs server endpoint is implemented such as it is wired to the
TTS endpoints.

Fixes: https://github.com/mudler/LocalAI/issues/1809

* feat(openai/tts): compat layer with openai tts

Fixes: #1276

* fix: adapt tts CLI
This commit is contained in:
Ettore Di Giacinto 2024-03-14 23:08:34 +01:00 committed by GitHub
parent 45d520f913
commit 20136ca8b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 454 additions and 338 deletions

View file

@ -29,7 +29,7 @@ func generateUniqueFileName(dir, baseName, ext string) string {
}
}
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
@ -44,12 +44,12 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, appCon
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
piperModel, err := loader.BackendLoader(opts...)
ttsModel, err := loader.BackendLoader(opts...)
if err != nil {
return "", nil, err
}
if piperModel == nil {
if ttsModel == nil {
return "", nil, fmt.Errorf("could not load piper model")
}
@ -57,25 +57,31 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, appCon
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}
fileName := generateUniqueFileName(appConfig.AudioDir, "piper", ".wav")
fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)
// If the model file is not empty, we pass it joined with the model path
modelPath := ""
if modelFile != "" {
if bb != model.TransformersMusicGen {
modelPath = filepath.Join(loader.ModelPath, modelFile)
if err := utils.VerifyPath(modelPath, appConfig.ModelPath); err != nil {
// If the model file is not empty, we pass it joined with the model path
// Checking first that it exists and is not outside ModelPath
// TODO: we should actually first check if the modelFile is looking like
// a FS path
mp := filepath.Join(loader.ModelPath, modelFile)
if _, err := os.Stat(mp); err == nil {
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
return "", nil, err
}
modelPath = mp
} else {
modelPath = modelFile
}
}
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Dst: filePath,
})

View file

@ -6,6 +6,7 @@ import (
"os"
"strings"
"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
@ -21,6 +22,24 @@ import (
"github.com/gofiber/fiber/v2/middleware/recover"
)
func readAuthHeader(c *fiber.Ctx) string {
authHeader := c.Get("Authorization")
// elevenlabs
xApiKey := c.Get("xi-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
// anthropic
xApiKey = c.Get("x-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
return authHeader
}
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
// Return errors as JSON responses
app := fiber.New(fiber.Config{
@ -94,10 +113,12 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
return c.Next()
}
authHeader := c.Get("Authorization")
authHeader := readAuthHeader(c)
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
// If it's a bearer token
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
@ -111,7 +132,6 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if appConfig.CORS {
@ -147,6 +167,11 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
// openAI compatible API endpoint
// chat
@ -181,7 +206,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))

View file

@ -0,0 +1,55 @@
package elevenlabs
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.ElevenLabsTTSRequest)
voiceID := c.Params("voice-id")
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
}
log.Debug().Msgf("Request for model: %s", modelFile)
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg)
if err != nil {
return err
}
return c.Download(filePath)
}
}

View file

@ -46,7 +46,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
cfg.Backend = input.Backend
}
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg)
if err != nil {
return err
}

View file

@ -0,0 +1,6 @@
package schema
type ElevenLabsTTSRequest struct {
Text string `json:"text" yaml:"text"`
ModelID string `json:"model_id" yaml:"model_id"`
}

View file

@ -17,5 +17,6 @@ type BackendMonitorResponse struct {
type TTSRequest struct {
Model string `json:"model" yaml:"model"`
Input string `json:"input" yaml:"input"`
Voice string `json:"voice" yaml:"voice"`
Backend string `json:"backend" yaml:"backend"`
}