fix: drop racy code, refactor and group API schema (#931)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

Parent: 28db83e17b
Commit: cc060a283d

55 changed files with 239 additions and 317 deletions
pkg/backend/image/stablediffusion.go (new file, 33 lines)
@@ -0,0 +1,33 @@
package image

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
	"github.com/go-skynet/LocalAI/pkg/stablediffusion"
)

type StableDiffusion struct {
	base.SingleThread
	stablediffusion *stablediffusion.StableDiffusion
}

func (sd *StableDiffusion) Load(opts *pb.ModelOptions) error {
	var err error
	// Note: the Model here is a path to a directory containing the model files
	sd.stablediffusion, err = stablediffusion.New(opts.ModelFile)
	return err
}

func (sd *StableDiffusion) GenerateImage(opts *pb.GenerateImageRequest) error {
	return sd.stablediffusion.GenerateImage(
		int(opts.Height),
		int(opts.Width),
		int(opts.Mode),
		int(opts.Step),
		int(opts.Seed),
		opts.PositivePrompt,
		opts.NegativePrompt,
		opts.Dst)
}

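Each wrapper like this is compiled into a standalone backend binary that exposes the gRPC service. A minimal sketch of such an entrypoint, assuming a StartServer helper in pkg/grpc (the helper name and signature are an assumption for illustration, not part of this diff):

package main

import (
	"flag"

	image "github.com/go-skynet/LocalAI/pkg/backend/image"
	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)

var addr = flag.String("addr", "localhost:50051", "address for the backend to listen on")

func main() {
	flag.Parse()
	// Register the wrapper as the Backend service implementation and serve
	// until the process is killed (assumed pkg/grpc helper).
	if err := grpc.StartServer(*addr, &image.StableDiffusion{}); err != nil {
		panic(err)
	}
}
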
pkg/backend/llm/bert/bert.go (new file, 34 lines)
@@ -0,0 +1,34 @@
package bert

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	bert "github.com/go-skynet/go-bert.cpp"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

type Embeddings struct {
	base.SingleThread
	bert *bert.Bert
}

func (llm *Embeddings) Load(opts *pb.ModelOptions) error {
	model, err := bert.New(opts.ModelFile)
	llm.bert = model
	return err
}

func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
	if len(opts.EmbeddingTokens) > 0 {
		tokens := []int{}
		for _, t := range opts.EmbeddingTokens {
			tokens = append(tokens, int(t))
		}
		return llm.bert.TokenEmbeddings(tokens, bert.SetThreads(int(opts.Threads)))
	}

	return llm.bert.Embeddings(opts.Embeddings, bert.SetThreads(int(opts.Threads)))
}

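For illustration, a sketch of driving the wrapper in-process, bypassing gRPC; the model path is hypothetical:

package main

import (
	"fmt"

	bert "github.com/go-skynet/LocalAI/pkg/backend/llm/bert"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

func main() {
	llm := &bert.Embeddings{}
	// Hypothetical model path, for illustration only.
	if err := llm.Load(&pb.ModelOptions{ModelFile: "/models/bert.bin"}); err != nil {
		panic(err)
	}
	// Embed a raw string; setting EmbeddingTokens instead takes the pre-tokenized path.
	emb, err := llm.Embeddings(&pb.PredictOptions{Embeddings: "hello world", Threads: 4})
	if err != nil {
		panic(err)
	}
	fmt.Println("embedding dimensions:", len(emb))
}
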
pkg/backend/llm/bloomz/bloomz.go (new file, 59 lines)
@@ -0,0 +1,59 @@
package bloomz

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	"github.com/go-skynet/bloomz.cpp"
)

type LLM struct {
	base.SingleThread

	bloomz *bloomz.Bloomz
}

func (llm *LLM) Load(opts *pb.ModelOptions) error {
	model, err := bloomz.New(opts.ModelFile)
	llm.bloomz = model
	return err
}

func buildPredictOptions(opts *pb.PredictOptions) []bloomz.PredictOption {
	predictOptions := []bloomz.PredictOption{
		bloomz.SetTemperature(float64(opts.Temperature)),
		bloomz.SetTopP(float64(opts.TopP)),
		bloomz.SetTopK(int(opts.TopK)),
		bloomz.SetTokens(int(opts.Tokens)),
		bloomz.SetThreads(int(opts.Threads)),
	}

	if opts.Seed != 0 {
		predictOptions = append(predictOptions, bloomz.SetSeed(int(opts.Seed)))
	}

	return predictOptions
}

func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

// fallback to Predict
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		res, err := llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		results <- res
		close(results)
	}()

	return nil
}

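The streaming fallback runs a blocking Predict in a goroutine and pushes the entire completion onto the channel as one chunk before closing it, so callers can treat it like any token stream. A hedged consumption sketch (hypothetical model path):

package main

import (
	"fmt"

	bloomz "github.com/go-skynet/LocalAI/pkg/backend/llm/bloomz"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

func main() {
	llm := &bloomz.LLM{}
	if err := llm.Load(&pb.ModelOptions{ModelFile: "/models/bloomz.bin"}); err != nil {
		panic(err)
	}
	results := make(chan string)
	if err := llm.PredictStream(&pb.PredictOptions{Prompt: "Once upon a time"}, results); err != nil {
		panic(err)
	}
	// With the fallback above, this loop receives a single chunk holding the full completion.
	for token := range results {
		fmt.Print(token)
	}
}
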
pkg/backend/llm/falcon/falcon.go (new file, 145 lines)
@@ -0,0 +1,145 @@
package falcon

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	ggllm "github.com/mudler/go-ggllm.cpp"
)

type LLM struct {
	base.SingleThread

	falcon *ggllm.Falcon
}

func (llm *LLM) Load(opts *pb.ModelOptions) error {
	ggllmOpts := []ggllm.ModelOption{}
	if opts.ContextSize != 0 {
		ggllmOpts = append(ggllmOpts, ggllm.SetContext(int(opts.ContextSize)))
	}
	// F16 doesn't seem to produce good output at all!
	//if c.F16 {
	//	llamaOpts = append(llamaOpts, llama.EnableF16Memory)
	//}

	if opts.NGPULayers != 0 {
		ggllmOpts = append(ggllmOpts, ggllm.SetGPULayers(int(opts.NGPULayers)))
	}

	ggllmOpts = append(ggllmOpts, ggllm.SetMMap(opts.MMap))
	ggllmOpts = append(ggllmOpts, ggllm.SetMainGPU(opts.MainGPU))
	ggllmOpts = append(ggllmOpts, ggllm.SetTensorSplit(opts.TensorSplit))
	if opts.NBatch != 0 {
		ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(int(opts.NBatch)))
	} else {
		ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(512))
	}

	model, err := ggllm.New(opts.ModelFile, ggllmOpts...)
	llm.falcon = model
	return err
}

func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption {
	predictOptions := []ggllm.PredictOption{
		ggllm.SetTemperature(float64(opts.Temperature)),
		ggllm.SetTopP(float64(opts.TopP)),
		ggllm.SetTopK(int(opts.TopK)),
		ggllm.SetTokens(int(opts.Tokens)),
		ggllm.SetThreads(int(opts.Threads)),
	}

	if opts.PromptCacheAll {
		predictOptions = append(predictOptions, ggllm.EnablePromptCacheAll)
	}

	if opts.PromptCacheRO {
		predictOptions = append(predictOptions, ggllm.EnablePromptCacheRO)
	}

	// Expected absolute path
	if opts.PromptCachePath != "" {
		predictOptions = append(predictOptions, ggllm.SetPathPromptCache(opts.PromptCachePath))
	}

	if opts.Mirostat != 0 {
		predictOptions = append(predictOptions, ggllm.SetMirostat(int(opts.Mirostat)))
	}

	if opts.MirostatETA != 0 {
		predictOptions = append(predictOptions, ggllm.SetMirostatETA(float64(opts.MirostatETA)))
	}

	if opts.MirostatTAU != 0 {
		predictOptions = append(predictOptions, ggllm.SetMirostatTAU(float64(opts.MirostatTAU)))
	}

	if opts.Debug {
		predictOptions = append(predictOptions, ggllm.Debug)
	}

	predictOptions = append(predictOptions, ggllm.SetStopWords(opts.StopPrompts...))

	if opts.PresencePenalty != 0 {
		predictOptions = append(predictOptions, ggllm.SetPenalty(float64(opts.PresencePenalty)))
	}

	if opts.NKeep != 0 {
		predictOptions = append(predictOptions, ggllm.SetNKeep(int(opts.NKeep)))
	}

	if opts.Batch != 0 {
		predictOptions = append(predictOptions, ggllm.SetBatch(int(opts.Batch)))
	}

	if opts.IgnoreEOS {
		predictOptions = append(predictOptions, ggllm.IgnoreEOS)
	}

	if opts.Seed != 0 {
		predictOptions = append(predictOptions, ggllm.SetSeed(int(opts.Seed)))
	}

	//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))

	predictOptions = append(predictOptions, ggllm.SetFrequencyPenalty(float64(opts.FrequencyPenalty)))
	predictOptions = append(predictOptions, ggllm.SetMlock(opts.MLock))
	predictOptions = append(predictOptions, ggllm.SetMemoryMap(opts.MMap))
	predictOptions = append(predictOptions, ggllm.SetPredictionMainGPU(opts.MainGPU))
	predictOptions = append(predictOptions, ggllm.SetPredictionTensorSplit(opts.TensorSplit))
	predictOptions = append(predictOptions, ggllm.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ)))
	predictOptions = append(predictOptions, ggllm.SetTypicalP(float64(opts.TypicalP)))
	return predictOptions
}

func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
	predictOptions := buildPredictOptions(opts)

	predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool {
		// skip emitting the EOS marker, but keep generating
		if token == "<|endoftext|>" {
			return true
		}
		results <- token
		return true
	}))

	go func() {
		_, err := llm.falcon.Predict(opts.Prompt, predictOptions...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		close(results)
	}()

	return nil
}

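buildPredictOptions translates the flat protobuf message into ggllm's functional options. For readers unfamiliar with the pattern, a self-contained miniature with invented names (not the bindings' actual types):

package main

import "fmt"

type predictConfig struct {
	temperature float64
	topK        int
}

// PredictOption mutates a config; callers pass any subset and defaults cover the rest.
type PredictOption func(*predictConfig)

func SetTemperature(t float64) PredictOption { return func(c *predictConfig) { c.temperature = t } }
func SetTopK(k int) PredictOption            { return func(c *predictConfig) { c.topK = k } }

func newConfig(opts ...PredictOption) *predictConfig {
	c := &predictConfig{temperature: 0.9, topK: 40} // defaults
	for _, o := range opts {
		o(c)
	}
	return c
}

func main() {
	c := newConfig(SetTemperature(0.2)) // topK keeps its default
	fmt.Printf("%+v\n", c)
}
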
pkg/backend/llm/gpt4all/gpt4all.go (new file, 62 lines)
@@ -0,0 +1,62 @@
package gpt4all

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
	gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
)

type LLM struct {
	base.SingleThread

	gpt4all *gpt4all.Model
}

func (llm *LLM) Load(opts *pb.ModelOptions) error {
	model, err := gpt4all.New(opts.ModelFile,
		gpt4all.SetThreads(int(opts.Threads)),
		gpt4all.SetLibrarySearchPath(opts.LibrarySearchPath))
	llm.gpt4all = model
	return err
}

func buildPredictOptions(opts *pb.PredictOptions) []gpt4all.PredictOption {
	predictOptions := []gpt4all.PredictOption{
		gpt4all.SetTemperature(float64(opts.Temperature)),
		gpt4all.SetTopP(float64(opts.TopP)),
		gpt4all.SetTopK(int(opts.TopK)),
		gpt4all.SetTokens(int(opts.Tokens)),
	}

	if opts.Batch != 0 {
		predictOptions = append(predictOptions, gpt4all.SetBatch(int(opts.Batch)))
	}
	return predictOptions
}

func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
	predictOptions := buildPredictOptions(opts)

	go func() {
		llm.gpt4all.SetTokenCallback(func(token string) bool {
			results <- token
			return true
		})
		_, err := llm.gpt4all.Predict(opts.Prompt, predictOptions...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		llm.gpt4all.SetTokenCallback(nil)
		close(results)
	}()

	return nil
}

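The gpt4all binding keeps the token callback as mutable state on the model (set before Predict, cleared after), exactly the kind of shared state the embedded base.SingleThread exists to serialize — part of the racy code this commit addresses. A sketch of what a serializing base could look like; this is an assumption for illustration, not the actual pkg/grpc/base implementation:

package base

import "sync"

// SingleThread marks a backend as requiring fully serialized access:
// the gRPC layer is expected to hold the lock around every call into the
// underlying bindings, so callback state like gpt4all's cannot race.
type SingleThread struct {
	mu sync.Mutex
}

func (s *SingleThread) Lock()   { s.mu.Lock() }
func (s *SingleThread) Unlock() { s.mu.Unlock() }
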
pkg/backend/llm/langchain/langchain.go (new file, 58 lines)
@@ -0,0 +1,58 @@
package langchain

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
	"github.com/go-skynet/LocalAI/pkg/langchain"
)

type LLM struct {
	base.Base

	langchain *langchain.HuggingFace
	model     string
}

func (llm *LLM) Load(opts *pb.ModelOptions) error {
	llm.langchain, _ = langchain.NewHuggingFace(opts.Model)
	llm.model = opts.Model
	return nil
}

func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
	o := []langchain.PredictOption{
		langchain.SetModel(llm.model),
		langchain.SetMaxTokens(int(opts.Tokens)),
		langchain.SetTemperature(float64(opts.Temperature)),
		langchain.SetStopWords(opts.StopPrompts),
	}
	pred, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...)
	if err != nil {
		return "", err
	}
	return pred.Completion, nil
}

func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
	o := []langchain.PredictOption{
		langchain.SetModel(llm.model),
		langchain.SetMaxTokens(int(opts.Tokens)),
		langchain.SetTemperature(float64(opts.Temperature)),
		langchain.SetStopWords(opts.StopPrompts),
	}
	go func() {
		res, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...)
		if err != nil {
			fmt.Println("err: ", err)
			close(results) // avoid dereferencing a nil result below
			return
		}
		results <- res.Completion
		close(results)
	}()

	return nil
}

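Unlike the other wrappers, this one embeds base.Base rather than base.SingleThread: it only proxies HTTP calls to the HuggingFace API, which are safe to issue concurrently. A hedged in-process usage sketch (hypothetical model identifier):

package main

import (
	"fmt"

	langchain "github.com/go-skynet/LocalAI/pkg/backend/llm/langchain"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

func main() {
	llm := &langchain.LLM{}
	// Model is a HuggingFace model identifier here, not a local file path.
	if err := llm.Load(&pb.ModelOptions{Model: "gpt2"}); err != nil {
		panic(err)
	}
	out, err := llm.Predict(&pb.PredictOptions{Prompt: "Hello", Tokens: 32, Temperature: 0.7})
	if err != nil {
		panic(err)
	}
	fmt.Println(out)
}
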
pkg/backend/llm/llama/llama.go (new file, 216 lines)
@@ -0,0 +1,216 @@
package llama

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
	"github.com/go-skynet/go-llama.cpp"
)

type LLM struct {
	base.SingleThread

	llama *llama.LLama
}

func (llm *LLM) Load(opts *pb.ModelOptions) error {
	ropeFreqBase := float32(10000)
	ropeFreqScale := float32(1)

	if opts.RopeFreqBase != 0 {
		ropeFreqBase = opts.RopeFreqBase
	}
	if opts.RopeFreqScale != 0 {
		ropeFreqScale = opts.RopeFreqScale
	}

	llamaOpts := []llama.ModelOption{
		llama.WithRopeFreqBase(ropeFreqBase),
		llama.WithRopeFreqScale(ropeFreqScale),
	}

	if opts.NGQA != 0 {
		llamaOpts = append(llamaOpts, llama.WithGQA(int(opts.NGQA)))
	}

	if opts.RMSNormEps != 0 {
		llamaOpts = append(llamaOpts, llama.WithRMSNormEPS(opts.RMSNormEps))
	}

	if opts.ContextSize != 0 {
		llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize)))
	}
	if opts.F16Memory {
		llamaOpts = append(llamaOpts, llama.EnableF16Memory)
	}
	if opts.Embeddings {
		llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
	}
	if opts.NGPULayers != 0 {
		llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
	}

	llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap))
	llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU))
	llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit))
	if opts.NBatch != 0 {
		llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch)))
	} else {
		llamaOpts = append(llamaOpts, llama.SetNBatch(512))
	}

	if opts.NUMA {
		llamaOpts = append(llamaOpts, llama.EnableNUMA)
	}

	if opts.LowVRAM {
		llamaOpts = append(llamaOpts, llama.EnabelLowVRAM) // sic: identifier spelling comes from the go-llama.cpp binding
	}

	model, err := llama.New(opts.ModelFile, llamaOpts...)
	llm.llama = model

	return err
}

func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
	ropeFreqBase := float32(10000)
	ropeFreqScale := float32(1)

	if opts.RopeFreqBase != 0 {
		ropeFreqBase = opts.RopeFreqBase
	}
	if opts.RopeFreqScale != 0 {
		ropeFreqScale = opts.RopeFreqScale
	}
	predictOptions := []llama.PredictOption{
		llama.SetTemperature(opts.Temperature),
		llama.SetTopP(opts.TopP),
		llama.SetTopK(int(opts.TopK)),
		llama.SetTokens(int(opts.Tokens)),
		llama.SetThreads(int(opts.Threads)),
		llama.WithGrammar(opts.Grammar),
		llama.SetRopeFreqBase(ropeFreqBase),
		llama.SetRopeFreqScale(ropeFreqScale),
		llama.SetNegativePromptScale(opts.NegativePromptScale),
		llama.SetNegativePrompt(opts.NegativePrompt),
	}

	if opts.PromptCacheAll {
		predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
	}

	if opts.PromptCacheRO {
		predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
	}

	// Expected absolute path
	if opts.PromptCachePath != "" {
		predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath))
	}

	if opts.Mirostat != 0 {
		predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat)))
	}

	if opts.MirostatETA != 0 {
		predictOptions = append(predictOptions, llama.SetMirostatETA(opts.MirostatETA))
	}

	if opts.MirostatTAU != 0 {
		predictOptions = append(predictOptions, llama.SetMirostatTAU(opts.MirostatTAU))
	}

	if opts.Debug {
		predictOptions = append(predictOptions, llama.Debug)
	}

	predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...))

	if opts.PresencePenalty != 0 {
		predictOptions = append(predictOptions, llama.SetPenalty(opts.PresencePenalty))
	}

	if opts.NKeep != 0 {
		predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep)))
	}

	if opts.Batch != 0 {
		predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch)))
	}

	if opts.F16KV {
		predictOptions = append(predictOptions, llama.EnableF16KV)
	}

	if opts.IgnoreEOS {
		predictOptions = append(predictOptions, llama.IgnoreEOS)
	}

	if opts.Seed != 0 {
		predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed)))
	}

	//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))

	predictOptions = append(predictOptions, llama.SetFrequencyPenalty(opts.FrequencyPenalty))
	predictOptions = append(predictOptions, llama.SetMlock(opts.MLock))
	predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap))
	predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU))
	predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit))
	predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(opts.TailFreeSamplingZ))
	predictOptions = append(predictOptions, llama.SetTypicalP(opts.TypicalP))
	return predictOptions
}

func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
	predictOptions := buildPredictOptions(opts)

	predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool {
		results <- token
		return true
	}))

	go func() {
		_, err := llm.llama.Predict(opts.Prompt, predictOptions...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		close(results)
	}()

	return nil
}

func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
	predictOptions := buildPredictOptions(opts)

	if len(opts.EmbeddingTokens) > 0 {
		tokens := []int{}
		for _, t := range opts.EmbeddingTokens {
			tokens = append(tokens, int(t))
		}
		return llm.llama.TokenEmbeddings(tokens, predictOptions...)
	}

	return llm.llama.Embeddings(opts.Embeddings, predictOptions...)
}

func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
	predictOptions := buildPredictOptions(opts)
	l, tokens, err := llm.llama.TokenizeString(opts.Prompt, predictOptions...)
	if err != nil {
		return pb.TokenizationResponse{}, err
	}
	return pb.TokenizationResponse{
		Length: l,
		Tokens: tokens,
	}, nil
}

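The llama wrapper is the only backend here that also exposes tokenization. A hedged sketch of calling it in-process (hypothetical model path):

package main

import (
	"fmt"

	llama "github.com/go-skynet/LocalAI/pkg/backend/llm/llama"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

func main() {
	llm := &llama.LLM{}
	if err := llm.Load(&pb.ModelOptions{ModelFile: "/models/llama.bin", ContextSize: 2048}); err != nil {
		panic(err)
	}
	resp, err := llm.TokenizeString(&pb.PredictOptions{Prompt: "The quick brown fox"})
	if err != nil {
		panic(err)
	}
	fmt.Println("tokens:", resp.Length, resp.Tokens)
}
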
pkg/backend/llm/rwkv/rwkv.go (new file, 70 lines)
@@ -0,0 +1,70 @@
package rwkv

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"
	"path/filepath"

	"github.com/donomii/go-rwkv.cpp"
	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

const tokenizerSuffix = ".tokenizer.json"

type LLM struct {
	base.SingleThread

	rwkv *rwkv.RwkvState
}

func (llm *LLM) Load(opts *pb.ModelOptions) error {
	modelPath := filepath.Dir(opts.ModelFile)
	modelFile := filepath.Base(opts.ModelFile)
	model := rwkv.LoadFiles(opts.ModelFile, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads()))

	if model == nil {
		return fmt.Errorf("could not load model")
	}
	llm.rwkv = model
	return nil
}

func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
	stopWord := "\n"
	if len(opts.StopPrompts) > 0 {
		stopWord = opts.StopPrompts[0]
	}

	if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil {
		return "", err
	}

	response := llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), nil)

	return response, nil
}

func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		stopWord := "\n"
		if len(opts.StopPrompts) > 0 {
			stopWord = opts.StopPrompts[0]
		}

		if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil {
			fmt.Println("Error processing input: ", err)
			close(results) // don't leave the consumer blocked on a channel that never closes
			return
		}

		llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), func(s string) bool {
			results <- s
			return true
		})
		close(results)
	}()

	return nil
}

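Load derives the tokenizer path by appending .tokenizer.json to the model file name, so the two files must sit side by side. A small sketch of the resolution (hypothetical path):

package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	// Mirrors what Load does with opts.ModelFile.
	model := "/models/rwkv-4-raven.bin"
	tokenizer := filepath.Join(filepath.Dir(model), filepath.Base(model)+".tokenizer.json")
	fmt.Println(tokenizer) // /models/rwkv-4-raven.bin.tokenizer.json
}
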
pkg/backend/llm/transformers/dolly.go (new file, 44 lines)
@@ -0,0 +1,44 @@
package transformers

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)

type Dolly struct {
	base.SingleThread

	dolly *transformers.Dolly
}

func (llm *Dolly) Load(opts *pb.ModelOptions) error {
	model, err := transformers.NewDolly(opts.ModelFile)
	llm.dolly = model
	return err
}

func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

// fallback to Predict
func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		results <- res
		close(results)
	}()

	return nil
}

pkg/backend/llm/transformers/falcon.go (new file, 43 lines)
@@ -0,0 +1,43 @@
package transformers

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)

type Falcon struct {
	base.SingleThread

	falcon *transformers.Falcon
}

func (llm *Falcon) Load(opts *pb.ModelOptions) error {
	model, err := transformers.NewFalcon(opts.ModelFile)
	llm.falcon = model
	return err
}

func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

// fallback to Predict
func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		results <- res
		close(results)
	}()

	return nil
}

pkg/backend/llm/transformers/gpt2.go (new file, 42 lines)
@@ -0,0 +1,42 @@
package transformers

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)

type GPT2 struct {
	base.SingleThread

	gpt2 *transformers.GPT2
}

func (llm *GPT2) Load(opts *pb.ModelOptions) error {
	model, err := transformers.New(opts.ModelFile)
	llm.gpt2 = model
	return err
}

func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

// fallback to Predict
func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		results <- res
		close(results)
	}()
	return nil
}

pkg/backend/llm/transformers/gptj.go (new file, 42 lines)
@@ -0,0 +1,42 @@
package transformers

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)

type GPTJ struct {
	base.SingleThread

	gptj *transformers.GPTJ
}

func (llm *GPTJ) Load(opts *pb.ModelOptions) error {
	model, err := transformers.NewGPTJ(opts.ModelFile)
	llm.gptj = model
	return err
}

func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

// fallback to Predict
func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		results <- res
		close(results)
	}()
	return nil
}

pkg/backend/llm/transformers/gptneox.go (new file, 42 lines)
@@ -0,0 +1,42 @@
package transformers

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)

type GPTNeoX struct {
	base.SingleThread

	gptneox *transformers.GPTNeoX
}

func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error {
	model, err := transformers.NewGPTNeoX(opts.ModelFile)
	llm.gptneox = model
	return err
}

func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

// fallback to Predict
func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		results <- res
		close(results)
	}()
	return nil
}

pkg/backend/llm/transformers/mpt.go (new file, 42 lines)
@@ -0,0 +1,42 @@
package transformers

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)

type MPT struct {
	base.SingleThread

	mpt *transformers.MPT
}

func (llm *MPT) Load(opts *pb.ModelOptions) error {
	model, err := transformers.NewMPT(opts.ModelFile)
	llm.mpt = model
	return err
}

func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

// fallback to Predict
func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		results <- res
		close(results)
	}()
	return nil
}

pkg/backend/llm/transformers/predict.go (new file, 26 lines)
@@ -0,0 +1,26 @@
package transformers

import (
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)

func buildPredictOptions(opts *pb.PredictOptions) []transformers.PredictOption {
	predictOptions := []transformers.PredictOption{
		transformers.SetTemperature(float64(opts.Temperature)),
		transformers.SetTopP(float64(opts.TopP)),
		transformers.SetTopK(int(opts.TopK)),
		transformers.SetTokens(int(opts.Tokens)),
		transformers.SetThreads(int(opts.Threads)),
	}

	if opts.Batch != 0 {
		predictOptions = append(predictOptions, transformers.SetBatch(int(opts.Batch)))
	}

	if opts.Seed != 0 {
		predictOptions = append(predictOptions, transformers.SetSeed(int(opts.Seed)))
	}

	return predictOptions
}

pkg/backend/llm/transformers/replit.go (new file, 42 lines)
@@ -0,0 +1,42 @@
package transformers

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)

type Replit struct {
	base.SingleThread

	replit *transformers.Replit
}

func (llm *Replit) Load(opts *pb.ModelOptions) error {
	model, err := transformers.NewReplit(opts.ModelFile)
	llm.replit = model
	return err
}

func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

// fallback to Predict
func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		results <- res
		close(results)
	}()
	return nil
}

pkg/backend/llm/transformers/starcoder.go (new file, 43 lines)
@@ -0,0 +1,43 @@
package transformers

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)

type Starcoder struct {
	base.SingleThread

	starcoder *transformers.Starcoder
}

func (llm *Starcoder) Load(opts *pb.ModelOptions) error {
	model, err := transformers.NewStarcoder(opts.ModelFile)
	llm.starcoder = model
	return err
}

func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) {
	return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
}

// fallback to Predict
func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
		if err != nil {
			fmt.Println("err: ", err)
		}
		results <- res
		close(results)
	}()

	return nil
}

pkg/backend/transcribe/transcript.go (new file, 100 lines)
@@ -0,0 +1,100 @@
package transcribe

import (
	"fmt"
	"os"
	"os/exec"
	"path/filepath"

	"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
	"github.com/go-audio/wav"
	"github.com/go-skynet/LocalAI/api/schema"
)

func sh(c string) (string, error) {
	cmd := exec.Command("/bin/sh", "-c", c)
	cmd.Env = os.Environ()
	o, err := cmd.CombinedOutput()
	return string(o), err
}

// audioToWav converts audio to wav for transcribe. It shells out to ffmpeg.
// TODO: use https://github.com/mccoyst/ogg?
func audioToWav(src, dst string) error {
	out, err := sh(fmt.Sprintf("ffmpeg -i %s -format s16le -ar 16000 -ac 1 -acodec pcm_s16le %s", src, dst))
	if err != nil {
		return fmt.Errorf("error: %w out: %s", err, out)
	}

	return nil
}

func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) {
	res := schema.Result{}

	dir, err := os.MkdirTemp("", "whisper")
	if err != nil {
		return res, err
	}
	defer os.RemoveAll(dir)

	convertedPath := filepath.Join(dir, "converted.wav")

	if err := audioToWav(audiopath, convertedPath); err != nil {
		return res, err
	}

	// Open samples
	fh, err := os.Open(convertedPath)
	if err != nil {
		return res, err
	}
	defer fh.Close()

	// Read samples
	d := wav.NewDecoder(fh)
	buf, err := d.FullPCMBuffer()
	if err != nil {
		return res, err
	}

	data := buf.AsFloat32Buffer().Data

	// Process samples
	context, err := model.NewContext()
	if err != nil {
		return res, err
	}

	context.SetThreads(threads)

	if language != "" {
		context.SetLanguage(language)
	} else {
		context.SetLanguage("auto")
	}

	if err := context.Process(data, nil, nil); err != nil {
		return res, err
	}

	for {
		s, err := context.NextSegment()
		if err != nil {
			break
		}

		var tokens []int
		for _, t := range s.Tokens {
			tokens = append(tokens, t.Id)
		}

		segment := schema.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens}
		res.Segments = append(res.Segments, segment)

		res.Text += s.Text
	}

	return res, nil
}

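Because audioToWav interpolates paths into a shell command line, file names containing spaces or shell metacharacters will break the conversion. A hedged alternative that passes the same ffmpeg flags without a shell (illustrative, not the committed code):

package transcribe

import (
	"fmt"
	"os/exec"
)

// audioToWavNoShell runs ffmpeg directly, so src and dst need no shell quoting.
func audioToWavNoShell(src, dst string) error {
	cmd := exec.Command("ffmpeg",
		"-i", src,
		"-format", "s16le",
		"-ar", "16000",
		"-ac", "1",
		"-acodec", "pcm_s16le",
		dst)
	if out, err := cmd.CombinedOutput(); err != nil {
		return fmt.Errorf("error: %w out: %s", err, out)
	}
	return nil
}
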
pkg/backend/transcribe/whisper.go (new file, 26 lines)
@@ -0,0 +1,26 @@
package transcribe

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
	"github.com/go-skynet/LocalAI/api/schema"
	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

type Whisper struct {
	base.SingleThread
	whisper whisper.Model
}

func (sd *Whisper) Load(opts *pb.ModelOptions) error {
	// Note: the Model here is a path to a directory containing the model files
	w, err := whisper.New(opts.ModelFile)
	sd.whisper = w
	return err
}

func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) {
	return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads))
}

pkg/backend/tts/piper.go (new file, 49 lines)
@@ -0,0 +1,49 @@
package tts

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"
	"os"
	"path/filepath"

	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
	piper "github.com/mudler/go-piper"
)

type Piper struct {
	base.SingleThread
	piper *PiperB
}

func (sd *Piper) Load(opts *pb.ModelOptions) error {
	if filepath.Ext(opts.ModelFile) != ".onnx" {
		return fmt.Errorf("unsupported model type %s (should end with .onnx)", opts.ModelFile)
	}
	var err error
	// Note: the Model here is a path to a directory containing the model files
	sd.piper, err = New(opts.LibrarySearchPath)
	return err
}

func (sd *Piper) TTS(opts *pb.TTSRequest) error {
	return sd.piper.TTS(opts.Text, opts.Model, opts.Dst)
}

type PiperB struct {
	assetDir string
}

func New(assetDir string) (*PiperB, error) {
	if _, err := os.Stat(assetDir); err != nil {
		return nil, err
	}
	return &PiperB{
		assetDir: assetDir,
	}, nil
}

func (s *PiperB) TTS(text, model, dst string) error {
	return piper.TextToWav(text, model, s.assetDir, "", dst)
}

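A hedged sketch of exercising the Piper wrapper in-process (paths are hypothetical; per the Load check above, the model must be an .onnx voice file):

package main

import (
	tts "github.com/go-skynet/LocalAI/pkg/backend/tts"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

func main() {
	p := &tts.Piper{}
	// LibrarySearchPath is assumed to point at piper's runtime assets.
	if err := p.Load(&pb.ModelOptions{ModelFile: "/models/en-voice.onnx", LibrarySearchPath: "/opt/piper"}); err != nil {
		panic(err)
	}
	if err := p.TTS(&pb.TTSRequest{Text: "Hello from LocalAI", Model: "/models/en-voice.onnx", Dst: "/tmp/out.wav"}); err != nil {
		panic(err)
	}
}
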