mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-27 14:05:00 +00:00
refactor: move backends into the backends directory (#1279)
* refactor: move backends into the backends directory Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactor: move main close to implementation for every backend Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
55461188a4
commit
ad0e30bca5
102 changed files with 156 additions and 190 deletions
|
@ -1,33 +0,0 @@
|
|||
package image
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
|
||||
)
|
||||
|
||||
type StableDiffusion struct {
|
||||
base.SingleThread
|
||||
stablediffusion *stablediffusion.StableDiffusion
|
||||
}
|
||||
|
||||
func (sd *StableDiffusion) Load(opts *pb.ModelOptions) error {
|
||||
var err error
|
||||
// Note: the Model here is a path to a directory containing the model files
|
||||
sd.stablediffusion, err = stablediffusion.New(opts.ModelFile)
|
||||
return err
|
||||
}
|
||||
|
||||
func (sd *StableDiffusion) GenerateImage(opts *pb.GenerateImageRequest) error {
|
||||
return sd.stablediffusion.GenerateImage(
|
||||
int(opts.Height),
|
||||
int(opts.Width),
|
||||
int(opts.Mode),
|
||||
int(opts.Step),
|
||||
int(opts.Seed),
|
||||
opts.PositivePrompt,
|
||||
opts.NegativePrompt,
|
||||
opts.Dst)
|
||||
}
|
|
@ -1,34 +0,0 @@
|
|||
package bert
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
bert "github.com/go-skynet/go-bert.cpp"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
type Embeddings struct {
|
||||
base.SingleThread
|
||||
bert *bert.Bert
|
||||
}
|
||||
|
||||
func (llm *Embeddings) Load(opts *pb.ModelOptions) error {
|
||||
model, err := bert.New(opts.ModelFile)
|
||||
llm.bert = model
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
|
||||
if len(opts.EmbeddingTokens) > 0 {
|
||||
tokens := []int{}
|
||||
for _, t := range opts.EmbeddingTokens {
|
||||
tokens = append(tokens, int(t))
|
||||
}
|
||||
return llm.bert.TokenEmbeddings(tokens, bert.SetThreads(int(opts.Threads)))
|
||||
}
|
||||
|
||||
return llm.bert.Embeddings(opts.Embeddings, bert.SetThreads(int(opts.Threads)))
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
package gpt4all
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
|
||||
)
|
||||
|
||||
type LLM struct {
|
||||
base.SingleThread
|
||||
|
||||
gpt4all *gpt4all.Model
|
||||
}
|
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
||||
model, err := gpt4all.New(opts.ModelFile,
|
||||
gpt4all.SetThreads(int(opts.Threads)),
|
||||
gpt4all.SetLibrarySearchPath(opts.LibrarySearchPath))
|
||||
llm.gpt4all = model
|
||||
return err
|
||||
}
|
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []gpt4all.PredictOption {
|
||||
predictOptions := []gpt4all.PredictOption{
|
||||
gpt4all.SetTemperature(float64(opts.Temperature)),
|
||||
gpt4all.SetTopP(float64(opts.TopP)),
|
||||
gpt4all.SetTopK(int(opts.TopK)),
|
||||
gpt4all.SetTokens(int(opts.Tokens)),
|
||||
}
|
||||
|
||||
if opts.Batch != 0 {
|
||||
predictOptions = append(predictOptions, gpt4all.SetBatch(int(opts.Batch)))
|
||||
}
|
||||
return predictOptions
|
||||
}
|
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
predictOptions := buildPredictOptions(opts)
|
||||
|
||||
go func() {
|
||||
llm.gpt4all.SetTokenCallback(func(token string) bool {
|
||||
results <- token
|
||||
return true
|
||||
})
|
||||
_, err := llm.gpt4all.Predict(opts.Prompt, predictOptions...)
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
llm.gpt4all.SetTokenCallback(nil)
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
package langchain
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/langchain"
|
||||
)
|
||||
|
||||
type LLM struct {
|
||||
base.Base
|
||||
|
||||
langchain *langchain.HuggingFace
|
||||
model string
|
||||
}
|
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
||||
llm.langchain, _ = langchain.NewHuggingFace(opts.Model)
|
||||
llm.model = opts.Model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
o := []langchain.PredictOption{
|
||||
langchain.SetModel(llm.model),
|
||||
langchain.SetMaxTokens(int(opts.Tokens)),
|
||||
langchain.SetTemperature(float64(opts.Temperature)),
|
||||
langchain.SetStopWords(opts.StopPrompts),
|
||||
}
|
||||
pred, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return pred.Completion, nil
|
||||
}
|
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
o := []langchain.PredictOption{
|
||||
langchain.SetModel(llm.model),
|
||||
langchain.SetMaxTokens(int(opts.Tokens)),
|
||||
langchain.SetTemperature(float64(opts.Temperature)),
|
||||
langchain.SetStopWords(opts.StopPrompts),
|
||||
}
|
||||
go func() {
|
||||
res, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
results <- res.Completion
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,204 +0,0 @@
|
|||
package llama
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/go-llama.cpp"
|
||||
)
|
||||
|
||||
type LLM struct {
|
||||
base.SingleThread
|
||||
|
||||
llama *llama.LLama
|
||||
}
|
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
||||
ropeFreqBase := float32(10000)
|
||||
ropeFreqScale := float32(1)
|
||||
|
||||
if opts.RopeFreqBase != 0 {
|
||||
ropeFreqBase = opts.RopeFreqBase
|
||||
}
|
||||
if opts.RopeFreqScale != 0 {
|
||||
ropeFreqScale = opts.RopeFreqScale
|
||||
}
|
||||
|
||||
llamaOpts := []llama.ModelOption{
|
||||
llama.WithRopeFreqBase(ropeFreqBase),
|
||||
llama.WithRopeFreqScale(ropeFreqScale),
|
||||
}
|
||||
|
||||
if opts.NGQA != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.WithGQA(int(opts.NGQA)))
|
||||
}
|
||||
|
||||
if opts.RMSNormEps != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.WithRMSNormEPS(opts.RMSNormEps))
|
||||
}
|
||||
|
||||
if opts.ContextSize != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize)))
|
||||
}
|
||||
if opts.F16Memory {
|
||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
||||
}
|
||||
if opts.Embeddings {
|
||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
|
||||
}
|
||||
if opts.NGPULayers != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
|
||||
}
|
||||
|
||||
llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap))
|
||||
llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU))
|
||||
llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit))
|
||||
if opts.NBatch != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch)))
|
||||
} else {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(512))
|
||||
}
|
||||
|
||||
if opts.NUMA {
|
||||
llamaOpts = append(llamaOpts, llama.EnableNUMA)
|
||||
}
|
||||
|
||||
if opts.LowVRAM {
|
||||
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
|
||||
}
|
||||
|
||||
model, err := llama.New(opts.ModelFile, llamaOpts...)
|
||||
llm.llama = model
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
|
||||
ropeFreqBase := float32(10000)
|
||||
ropeFreqScale := float32(1)
|
||||
|
||||
if opts.RopeFreqBase != 0 {
|
||||
ropeFreqBase = opts.RopeFreqBase
|
||||
}
|
||||
if opts.RopeFreqScale != 0 {
|
||||
ropeFreqScale = opts.RopeFreqScale
|
||||
}
|
||||
predictOptions := []llama.PredictOption{
|
||||
llama.SetTemperature(opts.Temperature),
|
||||
llama.SetTopP(opts.TopP),
|
||||
llama.SetTopK(int(opts.TopK)),
|
||||
llama.SetTokens(int(opts.Tokens)),
|
||||
llama.SetThreads(int(opts.Threads)),
|
||||
llama.WithGrammar(opts.Grammar),
|
||||
llama.SetRopeFreqBase(ropeFreqBase),
|
||||
llama.SetRopeFreqScale(ropeFreqScale),
|
||||
llama.SetNegativePromptScale(opts.NegativePromptScale),
|
||||
llama.SetNegativePrompt(opts.NegativePrompt),
|
||||
}
|
||||
|
||||
if opts.PromptCacheAll {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
|
||||
}
|
||||
|
||||
if opts.PromptCacheRO {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
|
||||
}
|
||||
|
||||
// Expected absolute path
|
||||
if opts.PromptCachePath != "" {
|
||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath))
|
||||
}
|
||||
|
||||
if opts.Mirostat != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat)))
|
||||
}
|
||||
|
||||
if opts.MirostatETA != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatETA(opts.MirostatETA))
|
||||
}
|
||||
|
||||
if opts.MirostatTAU != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatTAU(opts.MirostatTAU))
|
||||
}
|
||||
|
||||
if opts.Debug {
|
||||
predictOptions = append(predictOptions, llama.Debug)
|
||||
}
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...))
|
||||
|
||||
if opts.PresencePenalty != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetPenalty(opts.PresencePenalty))
|
||||
}
|
||||
|
||||
if opts.NKeep != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep)))
|
||||
}
|
||||
|
||||
if opts.Batch != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch)))
|
||||
}
|
||||
|
||||
if opts.F16KV {
|
||||
predictOptions = append(predictOptions, llama.EnableF16KV)
|
||||
}
|
||||
|
||||
if opts.IgnoreEOS {
|
||||
predictOptions = append(predictOptions, llama.IgnoreEOS)
|
||||
}
|
||||
|
||||
if opts.Seed != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed)))
|
||||
}
|
||||
|
||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(opts.FrequencyPenalty))
|
||||
predictOptions = append(predictOptions, llama.SetMlock(opts.MLock))
|
||||
predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit))
|
||||
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(opts.TailFreeSamplingZ))
|
||||
predictOptions = append(predictOptions, llama.SetTypicalP(opts.TypicalP))
|
||||
return predictOptions
|
||||
}
|
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
predictOptions := buildPredictOptions(opts)
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool {
|
||||
results <- token
|
||||
return true
|
||||
}))
|
||||
|
||||
go func() {
|
||||
_, err := llm.llama.Predict(opts.Prompt, predictOptions...)
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
predictOptions := buildPredictOptions(opts)
|
||||
|
||||
if len(opts.EmbeddingTokens) > 0 {
|
||||
tokens := []int{}
|
||||
for _, t := range opts.EmbeddingTokens {
|
||||
tokens = append(tokens, int(t))
|
||||
}
|
||||
return llm.llama.TokenEmbeddings(tokens, predictOptions...)
|
||||
}
|
||||
|
||||
return llm.llama.Embeddings(opts.Embeddings, predictOptions...)
|
||||
}
|
|
@ -1,257 +0,0 @@
|
|||
package llama
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/go-llama.cpp"
|
||||
)
|
||||
|
||||
type LLM struct {
|
||||
base.SingleThread
|
||||
|
||||
llama *llama.LLama
|
||||
draftModel *llama.LLama
|
||||
}
|
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
||||
ropeFreqBase := float32(10000)
|
||||
ropeFreqScale := float32(1)
|
||||
|
||||
if opts.RopeFreqBase != 0 {
|
||||
ropeFreqBase = opts.RopeFreqBase
|
||||
}
|
||||
if opts.RopeFreqScale != 0 {
|
||||
ropeFreqScale = opts.RopeFreqScale
|
||||
}
|
||||
|
||||
llamaOpts := []llama.ModelOption{
|
||||
llama.WithRopeFreqBase(ropeFreqBase),
|
||||
llama.WithRopeFreqScale(ropeFreqScale),
|
||||
}
|
||||
|
||||
if opts.NoMulMatQ {
|
||||
llamaOpts = append(llamaOpts, llama.SetMulMatQ(false))
|
||||
}
|
||||
|
||||
// Get base path of opts.ModelFile and use the same for lora (assume the same path)
|
||||
basePath := filepath.Dir(opts.ModelFile)
|
||||
|
||||
if opts.LoraAdapter != "" {
|
||||
llamaOpts = append(llamaOpts, llama.SetLoraAdapter(filepath.Join(basePath, opts.LoraAdapter)))
|
||||
}
|
||||
|
||||
if opts.LoraBase != "" {
|
||||
llamaOpts = append(llamaOpts, llama.SetLoraBase(filepath.Join(basePath, opts.LoraBase)))
|
||||
}
|
||||
|
||||
if opts.ContextSize != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize)))
|
||||
}
|
||||
if opts.F16Memory {
|
||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
||||
}
|
||||
if opts.Embeddings {
|
||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
|
||||
}
|
||||
if opts.NGPULayers != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
|
||||
}
|
||||
|
||||
llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap))
|
||||
llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU))
|
||||
llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit))
|
||||
if opts.NBatch != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch)))
|
||||
} else {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(512))
|
||||
}
|
||||
|
||||
if opts.NUMA {
|
||||
llamaOpts = append(llamaOpts, llama.EnableNUMA)
|
||||
}
|
||||
|
||||
if opts.LowVRAM {
|
||||
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
|
||||
}
|
||||
|
||||
if opts.DraftModel != "" {
|
||||
// https://github.com/ggerganov/llama.cpp/blob/71ca2fad7d6c0ef95ef9944fb3a1a843e481f314/examples/speculative/speculative.cpp#L40
|
||||
llamaOpts = append(llamaOpts, llama.SetPerplexity(true))
|
||||
}
|
||||
|
||||
model, err := llama.New(opts.ModelFile, llamaOpts...)
|
||||
|
||||
if opts.DraftModel != "" {
|
||||
// opts.DraftModel is relative to opts.ModelFile, so we need to get the basepath of opts.ModelFile
|
||||
if !filepath.IsAbs(opts.DraftModel) {
|
||||
dir := filepath.Dir(opts.ModelFile)
|
||||
opts.DraftModel = filepath.Join(dir, opts.DraftModel)
|
||||
}
|
||||
|
||||
draftModel, err := llama.New(opts.DraftModel, llamaOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
llm.draftModel = draftModel
|
||||
}
|
||||
|
||||
llm.llama = model
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
|
||||
ropeFreqBase := float32(10000)
|
||||
ropeFreqScale := float32(1)
|
||||
|
||||
if opts.RopeFreqBase != 0 {
|
||||
ropeFreqBase = opts.RopeFreqBase
|
||||
}
|
||||
if opts.RopeFreqScale != 0 {
|
||||
ropeFreqScale = opts.RopeFreqScale
|
||||
}
|
||||
predictOptions := []llama.PredictOption{
|
||||
llama.SetTemperature(opts.Temperature),
|
||||
llama.SetTopP(opts.TopP),
|
||||
llama.SetTopK(int(opts.TopK)),
|
||||
llama.SetTokens(int(opts.Tokens)),
|
||||
llama.SetThreads(int(opts.Threads)),
|
||||
llama.WithGrammar(opts.Grammar),
|
||||
llama.SetRopeFreqBase(ropeFreqBase),
|
||||
llama.SetRopeFreqScale(ropeFreqScale),
|
||||
llama.SetNegativePromptScale(opts.NegativePromptScale),
|
||||
llama.SetNegativePrompt(opts.NegativePrompt),
|
||||
}
|
||||
|
||||
if opts.PromptCacheAll {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
|
||||
}
|
||||
|
||||
if opts.PromptCacheRO {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
|
||||
}
|
||||
|
||||
// Expected absolute path
|
||||
if opts.PromptCachePath != "" {
|
||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath))
|
||||
}
|
||||
|
||||
if opts.Mirostat != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat)))
|
||||
}
|
||||
|
||||
if opts.MirostatETA != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatETA(opts.MirostatETA))
|
||||
}
|
||||
|
||||
if opts.MirostatTAU != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatTAU(opts.MirostatTAU))
|
||||
}
|
||||
|
||||
if opts.Debug {
|
||||
predictOptions = append(predictOptions, llama.Debug)
|
||||
}
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...))
|
||||
|
||||
if opts.PresencePenalty != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetPenalty(opts.PresencePenalty))
|
||||
}
|
||||
|
||||
if opts.NKeep != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep)))
|
||||
}
|
||||
|
||||
if opts.Batch != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch)))
|
||||
}
|
||||
|
||||
if opts.F16KV {
|
||||
predictOptions = append(predictOptions, llama.EnableF16KV)
|
||||
}
|
||||
|
||||
if opts.IgnoreEOS {
|
||||
predictOptions = append(predictOptions, llama.IgnoreEOS)
|
||||
}
|
||||
|
||||
if opts.Seed != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed)))
|
||||
}
|
||||
|
||||
if opts.NDraft != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetNDraft(int(opts.NDraft)))
|
||||
}
|
||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(opts.FrequencyPenalty))
|
||||
predictOptions = append(predictOptions, llama.SetMlock(opts.MLock))
|
||||
predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit))
|
||||
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(opts.TailFreeSamplingZ))
|
||||
predictOptions = append(predictOptions, llama.SetTypicalP(opts.TypicalP))
|
||||
return predictOptions
|
||||
}
|
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
if llm.draftModel != nil {
|
||||
return llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
predictOptions := buildPredictOptions(opts)
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool {
|
||||
results <- token
|
||||
return true
|
||||
}))
|
||||
|
||||
go func() {
|
||||
var err error
|
||||
if llm.draftModel != nil {
|
||||
_, err = llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...)
|
||||
} else {
|
||||
_, err = llm.llama.Predict(opts.Prompt, predictOptions...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
predictOptions := buildPredictOptions(opts)
|
||||
|
||||
if len(opts.EmbeddingTokens) > 0 {
|
||||
tokens := []int{}
|
||||
for _, t := range opts.EmbeddingTokens {
|
||||
tokens = append(tokens, int(t))
|
||||
}
|
||||
return llm.llama.TokenEmbeddings(tokens, predictOptions...)
|
||||
}
|
||||
|
||||
return llm.llama.Embeddings(opts.Embeddings, predictOptions...)
|
||||
}
|
||||
|
||||
func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
|
||||
predictOptions := buildPredictOptions(opts)
|
||||
l, tokens, err := llm.llama.TokenizeString(opts.Prompt, predictOptions...)
|
||||
if err != nil {
|
||||
return pb.TokenizationResponse{}, err
|
||||
}
|
||||
return pb.TokenizationResponse{
|
||||
Length: l,
|
||||
Tokens: tokens,
|
||||
}, nil
|
||||
}
|
|
@ -1,95 +0,0 @@
|
|||
package rwkv
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/donomii/go-rwkv.cpp"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// tokenizerSuffix is appended to the model file's base name to derive the
// default tokenizer file name when the request does not name a tokenizer.
const tokenizerSuffix = ".tokenizer.json"
|
||||
|
||||
type LLM struct {
|
||||
base.SingleThread
|
||||
|
||||
rwkv *rwkv.RwkvState
|
||||
}
|
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
||||
tokenizerFile := opts.Tokenizer
|
||||
if tokenizerFile == "" {
|
||||
modelFile := filepath.Base(opts.ModelFile)
|
||||
tokenizerFile = modelFile + tokenizerSuffix
|
||||
}
|
||||
modelPath := filepath.Dir(opts.ModelFile)
|
||||
tokenizerPath := filepath.Join(modelPath, tokenizerFile)
|
||||
|
||||
model := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads()))
|
||||
|
||||
if model == nil {
|
||||
return fmt.Errorf("could not load model")
|
||||
}
|
||||
llm.rwkv = model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
stopWord := "\n"
|
||||
if len(opts.StopPrompts) > 0 {
|
||||
stopWord = opts.StopPrompts[0]
|
||||
}
|
||||
|
||||
if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
response := llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), nil)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
|
||||
stopWord := "\n"
|
||||
if len(opts.StopPrompts) > 0 {
|
||||
stopWord = opts.StopPrompts[0]
|
||||
}
|
||||
|
||||
if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil {
|
||||
fmt.Println("Error processing input: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), func(s string) bool {
|
||||
results <- s
|
||||
return true
|
||||
})
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
|
||||
tokens, err := llm.rwkv.Tokenizer.Encode(opts.Prompt)
|
||||
if err != nil {
|
||||
return pb.TokenizationResponse{}, err
|
||||
}
|
||||
|
||||
l := len(tokens)
|
||||
i32Tokens := make([]int32, l)
|
||||
|
||||
for i, t := range tokens {
|
||||
i32Tokens[i] = int32(t.ID)
|
||||
}
|
||||
|
||||
return pb.TokenizationResponse{
|
||||
Length: int32(l),
|
||||
Tokens: i32Tokens,
|
||||
}, nil
|
||||
}
|
|
@ -1,44 +0,0 @@
|
|||
package transformers
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type Dolly struct {
|
||||
base.SingleThread
|
||||
|
||||
dolly *transformers.Dolly
|
||||
}
|
||||
|
||||
func (llm *Dolly) Load(opts *pb.ModelOptions) error {
|
||||
model, err := transformers.NewDolly(opts.ModelFile)
|
||||
llm.dolly = model
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
|
||||
go func() {
|
||||
res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
package transformers
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type Falcon struct {
|
||||
base.SingleThread
|
||||
|
||||
falcon *transformers.Falcon
|
||||
}
|
||||
|
||||
func (llm *Falcon) Load(opts *pb.ModelOptions) error {
|
||||
model, err := transformers.NewFalcon(opts.ModelFile)
|
||||
llm.falcon = model
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
package transformers
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type GPT2 struct {
|
||||
base.SingleThread
|
||||
|
||||
gpt2 *transformers.GPT2
|
||||
}
|
||||
|
||||
func (llm *GPT2) Load(opts *pb.ModelOptions) error {
|
||||
model, err := transformers.New(opts.ModelFile)
|
||||
llm.gpt2 = model
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
return nil
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
package transformers
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type GPTJ struct {
|
||||
base.SingleThread
|
||||
|
||||
gptj *transformers.GPTJ
|
||||
}
|
||||
|
||||
func (llm *GPTJ) Load(opts *pb.ModelOptions) error {
|
||||
model, err := transformers.NewGPTJ(opts.ModelFile)
|
||||
llm.gptj = model
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
return nil
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
package transformers
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type GPTNeoX struct {
|
||||
base.SingleThread
|
||||
|
||||
gptneox *transformers.GPTNeoX
|
||||
}
|
||||
|
||||
func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error {
|
||||
model, err := transformers.NewGPTNeoX(opts.ModelFile)
|
||||
llm.gptneox = model
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
return nil
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
package transformers
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type MPT struct {
|
||||
base.SingleThread
|
||||
|
||||
mpt *transformers.MPT
|
||||
}
|
||||
|
||||
func (llm *MPT) Load(opts *pb.ModelOptions) error {
|
||||
model, err := transformers.NewMPT(opts.ModelFile)
|
||||
llm.mpt = model
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
return nil
|
||||
}
|
|
@ -1,26 +0,0 @@
|
|||
package transformers
|
||||
|
||||
import (
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []transformers.PredictOption {
|
||||
predictOptions := []transformers.PredictOption{
|
||||
transformers.SetTemperature(float64(opts.Temperature)),
|
||||
transformers.SetTopP(float64(opts.TopP)),
|
||||
transformers.SetTopK(int(opts.TopK)),
|
||||
transformers.SetTokens(int(opts.Tokens)),
|
||||
transformers.SetThreads(int(opts.Threads)),
|
||||
}
|
||||
|
||||
if opts.Batch != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetBatch(int(opts.Batch)))
|
||||
}
|
||||
|
||||
if opts.Seed != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetSeed(int(opts.Seed)))
|
||||
}
|
||||
|
||||
return predictOptions
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
package transformers
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type Replit struct {
|
||||
base.SingleThread
|
||||
|
||||
replit *transformers.Replit
|
||||
}
|
||||
|
||||
func (llm *Replit) Load(opts *pb.ModelOptions) error {
|
||||
model, err := transformers.NewReplit(opts.ModelFile)
|
||||
llm.replit = model
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
return nil
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
package transformers
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type Starcoder struct {
|
||||
base.SingleThread
|
||||
|
||||
starcoder *transformers.Starcoder
|
||||
}
|
||||
|
||||
func (llm *Starcoder) Load(opts *pb.ModelOptions) error {
|
||||
model, err := transformers.NewStarcoder(opts.ModelFile)
|
||||
llm.starcoder = model
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,100 +0,0 @@
|
|||
package transcribe
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
)
|
||||
|
||||
// sh runs c through "/bin/sh -c" with the current process environment and
// returns the command's combined stdout and stderr as a string, together with
// the error reported by the command (nil on success).
//
// The string is interpreted by the shell, so callers must not embed untrusted
// text in c.
func sh(c string) (string, error) {
	shell := exec.Command("/bin/sh", "-c", c)
	shell.Env = os.Environ()

	combined, err := shell.CombinedOutput()
	return string(combined), err
}
|
||||
|
||||
// audioToWav converts the audio file at src into a WAV file at dst that is
// suitable for transcription (16 kHz, mono, signed 16-bit PCM) by running ffmpeg.
//
// ffmpeg is executed directly with an explicit argument vector instead of through
// a shell, so paths containing spaces or shell metacharacters are passed verbatim
// and can never be interpreted as shell syntax.
//
// On failure the returned error wraps the execution error and includes ffmpeg's
// combined stdout/stderr output.
// TODO: use https://github.com/mccoyst/ogg?
func audioToWav(src, dst string) error {
	cmd := exec.Command("ffmpeg",
		"-i", src,
		"-format", "s16le",
		"-ar", "16000",
		"-ac", "1",
		"-acodec", "pcm_s16le",
		dst,
	)

	// A nil cmd.Env means the child inherits the current process environment.
	out, err := cmd.CombinedOutput()
	if err != nil {
		return fmt.Errorf("error: %w out: %s", err, out)
	}

	return nil
}
|
||||
|
||||
func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) {
|
||||
res := schema.Result{}
|
||||
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
convertedPath := filepath.Join(dir, "converted.wav")
|
||||
|
||||
if err := audioToWav(audiopath, convertedPath); err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
defer fh.Close()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
|
||||
// Process samples
|
||||
context, err := model.NewContext()
|
||||
if err != nil {
|
||||
return res, err
|
||||
|
||||
}
|
||||
|
||||
context.SetThreads(threads)
|
||||
|
||||
if language != "" {
|
||||
context.SetLanguage(language)
|
||||
} else {
|
||||
context.SetLanguage("auto")
|
||||
}
|
||||
|
||||
if err := context.Process(data, nil, nil); err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
for {
|
||||
s, err := context.NextSegment()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
var tokens []int
|
||||
for _, t := range s.Tokens {
|
||||
tokens = append(tokens, t.Id)
|
||||
}
|
||||
|
||||
segment := schema.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens}
|
||||
res.Segments = append(res.Segments, segment)
|
||||
|
||||
res.Text += s.Text
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
|
@ -1,26 +0,0 @@
|
|||
package transcribe
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
type Whisper struct {
|
||||
base.SingleThread
|
||||
whisper whisper.Model
|
||||
}
|
||||
|
||||
func (sd *Whisper) Load(opts *pb.ModelOptions) error {
|
||||
// Note: the Model here is a path to a directory containing the model files
|
||||
w, err := whisper.New(opts.ModelFile)
|
||||
sd.whisper = w
|
||||
return err
|
||||
}
|
||||
|
||||
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) {
|
||||
return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads))
|
||||
}
|
|
@ -1,49 +0,0 @@
|
|||
package tts
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
piper "github.com/mudler/go-piper"
|
||||
)
|
||||
|
||||
type Piper struct {
|
||||
base.SingleThread
|
||||
piper *PiperB
|
||||
}
|
||||
|
||||
func (sd *Piper) Load(opts *pb.ModelOptions) error {
|
||||
if filepath.Ext(opts.ModelFile) != ".onnx" {
|
||||
return fmt.Errorf("unsupported model type %s (should end with .onnx)", opts.ModelFile)
|
||||
}
|
||||
var err error
|
||||
// Note: the Model here is a path to a directory containing the model files
|
||||
sd.piper, err = New(opts.LibrarySearchPath)
|
||||
return err
|
||||
}
|
||||
|
||||
func (sd *Piper) TTS(opts *pb.TTSRequest) error {
|
||||
return sd.piper.TTS(opts.Text, opts.Model, opts.Dst)
|
||||
}
|
||||
|
||||
// PiperB holds the configuration needed to call go-piper.
type PiperB struct {
	// assetDir is the directory handed to piper.TextToWav on every TTS call.
	assetDir string
}
|
||||
|
||||
func New(assetDir string) (*PiperB, error) {
|
||||
if _, err := os.Stat(assetDir); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PiperB{
|
||||
assetDir: assetDir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *PiperB) TTS(text, model, dst string) error {
|
||||
return piper.TextToWav(text, model, s.assetDir, "", dst)
|
||||
}
|
|
@ -2,7 +2,7 @@
|
|||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v4.23.4
|
||||
// source: pkg/grpc/proto/backend.proto
|
||||
// source: backend/backend.proto
|
||||
|
||||
package proto
|
||||
|
||||
|
|
|
@ -1,208 +0,0 @@
|
|||
syntax = "proto3";
|
||||
|
||||
option go_package = "github.com/go-skynet/LocalAI/pkg/grpc/proto";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "io.skynet.localai.backend";
|
||||
option java_outer_classname = "LocalAIBackend";
|
||||
|
||||
package backend;
|
||||
|
||||
// Backend is the gRPC service exposed by every backend process.
service Backend {
  // Health and status reporting.
  rpc Health(HealthMessage) returns (Reply) {}
  // Text prediction; returns the complete reply in one message.
  rpc Predict(PredictOptions) returns (Reply) {}
  // Model loading.
  rpc LoadModel(ModelOptions) returns (Result) {}
  // Text prediction; the reply is delivered as a server-side stream.
  rpc PredictStream(PredictOptions) returns (stream Reply) {}
  rpc Embedding(PredictOptions) returns (EmbeddingResult) {}
  rpc GenerateImage(GenerateImageRequest) returns (Result) {}
  rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
  rpc TTS(TTSRequest) returns (Result) {}
  rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
  rpc Status(HealthMessage) returns (StatusResponse) {}
}
|
||||
|
||||
// HealthMessage is the empty request used by the Health and Status rpcs.
message HealthMessage {}
|
||||
|
||||
// PredictOptions is the request message of the Predict, PredictStream, Embedding
// and TokenizeString rpcs. It carries the prompt together with the sampling and
// runtime parameters for the request.
// Field number 24 is not used by this message.
message PredictOptions {
  string Prompt = 1;
  int32 Seed = 2;
  int32 Threads = 3;
  int32 Tokens = 4;
  int32 TopK = 5;
  int32 Repeat = 6;
  int32 Batch = 7;
  int32 NKeep = 8;
  float Temperature = 9;
  float Penalty = 10;
  bool F16KV = 11;
  bool DebugMode = 12;
  repeated string StopPrompts = 13;
  bool IgnoreEOS = 14;
  float TailFreeSamplingZ = 15;
  float TypicalP = 16;
  float FrequencyPenalty = 17;
  float PresencePenalty = 18;
  int32 Mirostat = 19;
  float MirostatETA = 20;
  float MirostatTAU = 21;
  bool PenalizeNL = 22;
  string LogitBias = 23;
  bool MLock = 25;
  bool MMap = 26;
  bool PromptCacheAll = 27;
  bool PromptCacheRO = 28;
  string Grammar = 29;
  string MainGPU = 30;
  string TensorSplit = 31;
  float TopP = 32;
  string PromptCachePath = 33;
  bool Debug = 34;
  repeated int32 EmbeddingTokens = 35;
  string Embeddings = 36;
  float RopeFreqBase = 37;
  float RopeFreqScale = 38;
  float NegativePromptScale = 39;
  string NegativePrompt = 40;
  int32 NDraft = 41;
  repeated string Images = 42;
}
|
||||
|
||||
// Reply is the response message of the Health, Predict and PredictStream rpcs;
// the result is carried as raw bytes.
message Reply {
  bytes message = 1;
}
|
||||
|
||||
// ModelOptions is the request message of the LoadModel rpc. It bundles the
// options of all backends; the section comments below mark backend-specific
// groups. Field numbers are not assigned in declaration order.
message ModelOptions {
  string Model = 1;
  int32 ContextSize = 2;
  int32 Seed = 3;
  int32 NBatch = 4;
  bool F16Memory = 5;
  bool MLock = 6;
  bool MMap = 7;
  bool VocabOnly = 8;
  bool LowVRAM = 9;
  bool Embeddings = 10;
  bool NUMA = 11;
  int32 NGPULayers = 12;
  string MainGPU = 13;
  string TensorSplit = 14;
  int32 Threads = 15;
  string LibrarySearchPath = 16;
  float RopeFreqBase = 17;
  float RopeFreqScale = 18;
  float RMSNormEps = 19;
  int32 NGQA = 20;
  string ModelFile = 21;

  // AutoGPTQ
  string Device = 22;
  bool UseTriton = 23;
  string ModelBaseName = 24;
  bool UseFastTokenizer = 25;

  // Diffusers
  string PipelineType = 26;
  string SchedulerType = 27;
  bool CUDA = 28;
  float CFGScale = 29;
  bool IMG2IMG = 30;
  string CLIPModel = 31;
  string CLIPSubfolder = 32;
  int32 CLIPSkip = 33;

  // RWKV
  string Tokenizer = 34;

  // LLM (llama.cpp)
  string LoraBase = 35;
  string LoraAdapter = 36;
  float LoraScale = 42;

  bool NoMulMatQ = 37;
  string DraftModel = 39;

  string AudioPath = 38;

  // vllm
  string Quantization = 40;

  string MMProj = 41;

  string RopeScaling = 43;
  float YarnExtFactor = 44;
  float YarnAttnFactor = 45;
  float YarnBetaFast = 46;
  float YarnBetaSlow = 47;
}
|
||||
|
||||
// Result is the response message of the LoadModel, GenerateImage and TTS rpcs:
// a success flag plus a text message.
message Result {
  string message = 1;
  bool success = 2;
}
|
||||
|
||||
// EmbeddingResult is the response message of the Embedding rpc.
message EmbeddingResult {
  repeated float embeddings = 1;
}
|
||||
|
||||
// TranscriptRequest is the request message of the AudioTranscription rpc.
// Field number 1 is not used by this message.
message TranscriptRequest {
  string dst = 2;
  string language = 3;
  uint32 threads = 4;
}
|
||||
|
||||
// TranscriptResult is the response message of the AudioTranscription rpc: the
// individual segments plus a text field.
message TranscriptResult {
  repeated TranscriptSegment segments = 1;
  string text = 2;
}
|
||||
|
||||
// TranscriptSegment is one entry of TranscriptResult.segments.
message TranscriptSegment {
  int32 id = 1;
  int64 start = 2;
  int64 end = 3;
  string text = 4;
  repeated int32 tokens = 5;
}
|
||||
|
||||
// GenerateImageRequest is the request message of the GenerateImage rpc.
message GenerateImageRequest {
  int32 height = 1;
  int32 width = 2;
  int32 mode = 3;
  int32 step = 4;
  int32 seed = 5;
  string positive_prompt = 6;
  string negative_prompt = 7;
  string dst = 8;
  string src = 9;

  // Diffusers
  string EnableParameters = 10;
  int32 CLIPSkip = 11;
}
|
||||
|
||||
// TTSRequest is the request message of the TTS rpc.
message TTSRequest {
  string text = 1;
  string model = 2;
  string dst = 3;
}
|
||||
|
||||
// TokenizationResponse is the response message of the TokenizeString rpc.
message TokenizationResponse {
  int32 length = 1;
  repeated int32 tokens = 2;
}
|
||||
|
||||
// MemoryUsageData is the type of StatusResponse.memory: a total plus a
// breakdown map keyed by string.
message MemoryUsageData {
  uint64 total = 1;
  map<string, uint64> breakdown = 2;
}
|
||||
|
||||
// StatusResponse is the response message of the Status rpc.
message StatusResponse {
  // State enumerates the states a backend can report.
  enum State {
    UNINITIALIZED = 0;
    BUSY = 1;
    READY = 2;
    ERROR = -1;
  }
  State state = 1;
  MemoryUsageData memory = 2;
}
|
|
@ -2,7 +2,7 @@
|
|||
// versions:
|
||||
// - protoc-gen-go-grpc v1.2.0
|
||||
// - protoc v4.23.4
|
||||
// source: pkg/grpc/proto/backend.proto
|
||||
// source: backend/backend.proto
|
||||
|
||||
package proto
|
||||
|
||||
|
@ -453,5 +453,5 @@ var Backend_ServiceDesc = grpc.ServiceDesc{
|
|||
ServerStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "pkg/grpc/proto/backend.proto",
|
||||
Metadata: "backend/backend.proto",
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue