feat: move other backends to grpc

This finally makes everything more consistent

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-07-15 01:19:43 +02:00
parent 5dcfdbe51d
commit 1d0ed95a54
54 changed files with 3171 additions and 1712 deletions

View file

@ -1,34 +1,30 @@
package backend
import (
"context"
"regexp"
"strings"
"sync"
"github.com/donomii/go-rwkv.cpp"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/langchain"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/bloomz.cpp"
)
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
supportStreams := false
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
var inferenceModel interface{}
var inferenceModel *grpc.Client
var err error
opts := []model.Option{
model.WithLoadGRPCOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), // GPT4all uses this
model.WithLoadGRPCLLMModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModelFile(modelFile),
model.WithContext(o.Context),
}
if c.Backend == "" {
@ -41,95 +37,37 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
return nil, err
}
var fn func() (string, error)
switch model := inferenceModel.(type) {
case *rwkv.RwkvState:
supportStreams = true
fn = func() (string, error) {
stopWord := "\n"
if len(c.StopWords) > 0 {
stopWord = c.StopWords[0]
}
if err := model.ProcessInput(s); err != nil {
return "", err
}
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
return response, nil
}
case *bloomz.Bloomz:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []bloomz.PredictOption{
bloomz.SetTemperature(c.Temperature),
bloomz.SetTopP(c.TopP),
bloomz.SetTopK(c.TopK),
bloomz.SetTokens(c.Maxtokens),
bloomz.SetThreads(c.Threads),
}
if c.Seed != 0 {
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
}
return model.Predict(
s,
predictOptions...,
)
}
case *grpc.Client:
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
supportStreams = true
fn = func() (string, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
if tokenCallback != nil {
ss := ""
err := model.PredictStream(context.TODO(), opts, func(s string) {
tokenCallback(s)
ss += s
})
return ss, err
} else {
reply, err := model.Predict(context.TODO(), opts)
return reply.Message, err
}
}
case *langchain.HuggingFace:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []langchain.PredictOption{
langchain.SetModel(c.Model),
langchain.SetMaxTokens(c.Maxtokens),
langchain.SetTemperature(c.Temperature),
langchain.SetStopWords(c.StopWords),
}
pred, er := model.PredictHuggingFace(s, predictOptions...)
if er != nil {
return "", er
}
return pred.Completion, nil
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (string, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
if tokenCallback != nil {
ss := ""
err := inferenceModel.PredictStream(o.Context, opts, func(s string) {
tokenCallback(s)
ss += s
})
return ss, err
} else {
reply, err := inferenceModel.Predict(o.Context, opts)
return reply.Message, err
}
}
return func() (string, error) {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
l := Lock(modelFile)
mutexMap.Lock()
l, ok := mutexes[modelFile]
if !ok {
m := &sync.Mutex{}
mutexes[modelFile] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()
res, err := fn()
if tokenCallback != nil && !supportStreams {
tokenCallback(res)
}
return res, err
return fn()
}, nil
}