feat: add rwkv support (#158)

Signed-off-by: mudler <mudler@mocaccino.org>
Ettore Di Giacinto 2023-05-03 11:45:22 +02:00 committed by GitHub
parent 1ae7150810
commit 751b7eca62
7 changed files with 113 additions and 28 deletions

@@ -79,7 +79,7 @@ var _ = Describe("API test", func() {
     It("returns errors", func() {
         _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
         Expect(err).To(HaveOccurred())
-        Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 4 errors occurred:"))
+        Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 5 errors occurred:"))
     })
 })
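The expected substring changes from "4 errors occurred" to "5 errors occurred" because greedyLoader (changed below) now probes one additional backend, RWKV, and every failed probe is appended to a hashicorp/go-multierror error whose default formatting leads with the error count. A minimal sketch of that behavior, not part of this commit:

    package main

    import (
        "errors"
        "fmt"

        "github.com/hashicorp/go-multierror"
    )

    func main() {
        var err error
        // one Append per failed backend probe; with RWKV there are now five
        for i := 0; i < 5; i++ {
            err = multierror.Append(err, errors.New("failed loading model"))
        }
        // Error() renders "5 errors occurred:" plus the individual errors,
        // which is the substring the updated test asserts on
        fmt.Printf("could not load model - all backends returned error: %s\n", err.Error())
    }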

@@ -6,6 +6,7 @@ import (
     "strings"
     "sync"

+    "github.com/donomii/go-rwkv.cpp"
     model "github.com/go-skynet/LocalAI/pkg/model"
     gpt2 "github.com/go-skynet/go-gpt2.cpp"
     gptj "github.com/go-skynet/go-gpt4all-j.cpp"
@@ -13,6 +14,8 @@ import (
     "github.com/hashicorp/go-multierror"
 )

+const tokenizerSuffix = ".tokenizer.json"
+
 // mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
 var mutexMap sync.Mutex
 var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
@@ -20,7 +23,7 @@ var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
 var loadedModels map[string]interface{} = map[string]interface{}{}
 var muModels sync.Mutex

-func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
+func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
     switch strings.ToLower(backendString) {
     case "llama":
         return loader.LoadLLaMAModel(modelFile, llamaOpts...)
@@ -30,12 +33,14 @@ func backendLoader(backendString string, loader *model.ModelLoader, modelFile st
         return loader.LoadGPT2Model(modelFile)
     case "gptj":
         return loader.LoadGPTJModel(modelFile)
+    case "rwkv":
+        return loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
     default:
         return nil, fmt.Errorf("backend unsupported: %s", backendString)
     }
 }

-func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
+func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
     updateModels := func(model interface{}) {
         muModels.Lock()
         defer muModels.Unlock()
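The new "rwkv" case derives the tokenizer path by appending tokenizerSuffix to the model path, so an RWKV model must ship with a sibling .tokenizer.json file. A minimal sketch of that convention (the model path here is illustrative, not from this commit):

    package main

    import "fmt"

    // mirrors the constant introduced in this commit
    const tokenizerSuffix = ".tokenizer.json"

    func main() {
        modelFile := "models/rwkv-4-raven" // hypothetical path
        // LoadRWKV receives the model file plus a tokenizer path derived from it
        fmt.Println(modelFile + tokenizerSuffix) // models/rwkv-4-raven.tokenizer.json
    }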
@@ -82,6 +87,14 @@ func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama
         err = multierror.Append(err, modelerr)
     }

+    model, modelerr = loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
+    if modelerr == nil {
+        updateModels(model)
+        return model, nil
+    } else {
+        err = multierror.Append(err, modelerr)
+    }
+
     return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
 }
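greedyLoader keeps its existing shape: try each backend in order, return the first one that loads, and fold every failure into the accumulated multierror. The RWKV probe is appended after the existing ones, so established backends keep priority. A condensed sketch of that control flow, with hypothetical loader closures (greedyLoader itself inlines this logic per backend):

    package main

    import (
        "errors"
        "fmt"

        "github.com/hashicorp/go-multierror"
    )

    // tryInOrder is a hypothetical condensation of the greedy strategy:
    // the first successful backend wins, failures accumulate.
    func tryInOrder(loaders []func() (interface{}, error)) (interface{}, error) {
        var err error
        for _, load := range loaders {
            m, e := load()
            if e == nil {
                return m, nil
            }
            err = multierror.Append(err, e)
        }
        return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
    }

    func main() {
        bad := func() (interface{}, error) { return nil, errors.New("unsupported format") }
        ok := func() (interface{}, error) { return "loaded", nil }
        m, _ := tryInOrder([]func() (interface{}, error){bad, ok})
        fmt.Println(m) // loaded
    }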
@@ -101,9 +114,9 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
     var inferenceModel interface{}
     var err error
     if c.Backend == "" {
-        inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts)
+        inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts, uint32(c.Threads))
     } else {
-        inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts)
+        inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts, uint32(c.Threads))
     }
     if err != nil {
         return nil, err
@@ -112,6 +125,20 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
     var fn func() (string, error)

     switch model := inferenceModel.(type) {
+    case *rwkv.RwkvState:
+        supportStreams = true
+
+        fn = func() (string, error) {
+            //model.ProcessInput("You are a chatbot that is very good at chatting. blah blah blah")
+            stopWord := "\n"
+            if len(c.StopWords) > 0 {
+                stopWord = c.StopWords[0]
+            }
+
+            response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
+
+            return response, nil
+        }
     case *gpt2.StableLM:
         fn = func() (string, error) {
             // Generate the prediction using the language model