feat: add tts with go-piper (#649)

Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-06-22 17:53:10 +02:00 committed by GitHub
parent cc31c58235
commit a7bb029d23
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 237 additions and 29 deletions

View file

@ -10,6 +10,7 @@ import (
"path/filepath"
"strconv"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v2"
@ -80,21 +81,6 @@ func ReadConfigFile(filePath string) (*Config, error) {
return &config, nil
}
func inTrustedRoot(path string, trustedRoot string) error {
for path != "/" {
path = filepath.Dir(path)
if path == trustedRoot {
return nil
}
}
return fmt.Errorf("path is outside of trusted root")
}
func verifyPath(path, basePath string) error {
c := filepath.Clean(filepath.Join(basePath, path))
return inTrustedRoot(c, basePath)
}
func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0755)
@ -110,7 +96,7 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
for _, file := range config.Files {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := verifyPath(file.Filename, basePath); err != nil {
if err := utils.VerifyPath(file.Filename, basePath); err != nil {
return err
}
// Create file path
@ -196,7 +182,7 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
// Write prompt template contents to separate files
for _, template := range config.PromptTemplates {
if err := verifyPath(template.Name+".tmpl", basePath); err != nil {
if err := utils.VerifyPath(template.Name+".tmpl", basePath); err != nil {
return err
}
// Create file path
@ -221,7 +207,7 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
name = nameOverride
}
if err := verifyPath(name+".yaml", basePath); err != nil {
if err := utils.VerifyPath(name+".yaml", basePath); err != nil {
return err
}

View file

@ -9,6 +9,7 @@ import (
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-skynet/LocalAI/pkg/langchain"
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
"github.com/go-skynet/LocalAI/pkg/tts"
bloomz "github.com/go-skynet/bloomz.cpp"
bert "github.com/go-skynet/go-bert.cpp"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
@ -39,6 +40,7 @@ const (
RwkvBackend = "rwkv"
WhisperBackend = "whisper"
StableDiffusionBackend = "stablediffusion"
PiperBackend = "piper"
LCHuggingFaceBackend = "langchain-huggingface"
)
@ -103,6 +105,12 @@ var stableDiffusion = func(assetDir string) (interface{}, error) {
return stablediffusion.New(assetDir)
}
func piperTTS(assetDir string) func(s string) (interface{}, error) {
return func(s string) (interface{}, error) {
return tts.New(assetDir)
}
}
var whisperModel = func(modelFile string) (interface{}, error) {
return whisper.New(modelFile)
}
@ -158,6 +166,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
return ml.LoadModel(modelFile, replit)
case StableDiffusionBackend:
return ml.LoadModel(modelFile, stableDiffusion)
case PiperBackend:
return ml.LoadModel(modelFile, piperTTS(filepath.Join(assetDir, "backend-assets", "espeak-ng-data")))
case StarcoderBackend:
return ml.LoadModel(modelFile, starCoder)
case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:

12
pkg/tts/generate.go Normal file
View file

@ -0,0 +1,12 @@
//go:build tts
// +build tts
package tts
import (
piper "github.com/mudler/go-piper"
)
func tts(text, model, assetDir, arLib, dst string) error {
return piper.TextToWav(text, model, assetDir, arLib, dst)
}

View file

@ -0,0 +1,10 @@
//go:build !tts
// +build !tts
package tts
import "fmt"
func tts(text, model, assetDir, arLib, dst string) error {
return fmt.Errorf("this version of LocalAI was built without the tts tag")
}

20
pkg/tts/piper.go Normal file
View file

@ -0,0 +1,20 @@
package tts
import "os"
type Piper struct {
assetDir string
}
func New(assetDir string) (*Piper, error) {
if _, err := os.Stat(assetDir); err != nil {
return nil, err
}
return &Piper{
assetDir: assetDir,
}, nil
}
func (s *Piper) TTS(text, model, dst string) error {
return tts(text, model, s.assetDir, "", dst)
}

22
pkg/utils/path.go Normal file
View file

@ -0,0 +1,22 @@
package utils
import (
"fmt"
"path/filepath"
)
func inTrustedRoot(path string, trustedRoot string) error {
for path != "/" {
path = filepath.Dir(path)
if path == trustedRoot {
return nil
}
}
return fmt.Errorf("path is outside of trusted root")
}
// VerifyPath verifies that path is based in basePath.
func VerifyPath(path, basePath string) error {
c := filepath.Clean(filepath.Join(basePath, path))
return inTrustedRoot(c, basePath)
}