feat(silero): add Silero-vad backend (#4204)

* feat(vad): add silero-vad backend (WIP)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(vad): add API endpoint

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(vad): correctly place the onnxruntime libs

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore(vad): hook silero-vad to binary and container builds

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(gRPC): register VAD Server

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(Makefile): consume ONNX_OS consistently

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(Makefile): handle macOS

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
Ettore Di Giacinto 2024-11-20 14:48:40 +01:00 committed by GitHub
parent 9892d7d584
commit b1ea9318e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 255 additions and 1 deletions

View file

@ -0,0 +1,68 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
// VADEndpoint returns the fiber handler for the Voice Activity Detection (VAD) endpoint.
//
// The handler parses a schema.VADRequest from the request body, resolves the
// model name and its backend configuration, loads the backend through ml and
// replies with the backend's VAD response encoded as JSON. Any failure while
// parsing the body, loading the configuration, loading the model or running
// the detection is returned to fiber as an error.
// @Summary Detect voice fragments in an audio stream
// @Accept json
// @Param request body schema.VADRequest true "query params"
// @Success 200 {object} proto.VADResponse "Response"
// @Router /vad [post]
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(schema.VADRequest)

		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
		if err != nil {
			// Best effort: fall back to the model name supplied in the request.
			modelFile = input.Model
			log.Warn().Msgf("Model not found in context: %s", input.Model)
		}

		cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
			config.LoadOptionDebug(appConfig.Debug),
			config.LoadOptionThreads(appConfig.Threads),
			config.LoadOptionContextSize(appConfig.ContextSize),
			config.LoadOptionF16(appConfig.F16),
		)
		if err != nil {
			// Without a configuration there is no backend to select, and cfg
			// must not be dereferenced below, so stop here instead of
			// continuing with a fallback model name.
			log.Error().Err(err).Msgf("Failed loading backend config for model: %s", modelFile)
			return err
		}
		modelFile = cfg.Model

		log.Debug().Msgf("Request for model: %s", modelFile)

		opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))
		vadModel, err := ml.Load(opts...)
		if err != nil {
			return err
		}

		req := proto.VADRequest{
			Audio: input.Audio,
		}
		resp, err := vadModel.VAD(c.Context(), &req)
		if err != nil {
			return err
		}

		return c.JSON(resp)
	}
}

View file

@ -34,6 +34,7 @@ func RegisterLocalAIRoutes(app *fiber.App,
}
app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
// Stores
sl := model.NewModelLoader("")

View file

@ -30,10 +30,16 @@ type TTSRequest struct {
Input string `json:"input" yaml:"input"` // text input
Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id
Backend string `json:"backend" yaml:"backend"`
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
}
// VADRequest is the JSON body parsed by the /vad endpoint handler.
// @Description VAD request body
type VADRequest struct {
Model string `json:"model" yaml:"model"` // model name or full path
Audio []float32 `json:"audio" yaml:"audio"` // audio data forwarded to the VAD backend (presumably raw samples; expected sample rate/format is not shown here — TODO confirm)
}
type StoresSet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`