mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat: add rwkv support (#158)
Signed-off-by: mudler <mudler@mocaccino.org>
This commit is contained in:
parent
1ae7150810
commit
751b7eca62
7 changed files with 113 additions and 28 deletions
|
@ -79,7 +79,7 @@ var _ = Describe("API test", func() {
|
|||
It("returns errors", func() {
|
||||
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 4 errors occurred:"))
|
||||
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 5 errors occurred:"))
|
||||
})
|
||||
|
||||
})
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/donomii/go-rwkv.cpp"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
gpt2 "github.com/go-skynet/go-gpt2.cpp"
|
||||
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
|
||||
|
@ -13,6 +14,8 @@ import (
|
|||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
const tokenizerSuffix = ".tokenizer.json"
|
||||
|
||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
var mutexMap sync.Mutex
|
||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
|
||||
|
@ -20,7 +23,7 @@ var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
|
|||
var loadedModels map[string]interface{} = map[string]interface{}{}
|
||||
var muModels sync.Mutex
|
||||
|
||||
func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
|
||||
func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
|
||||
switch strings.ToLower(backendString) {
|
||||
case "llama":
|
||||
return loader.LoadLLaMAModel(modelFile, llamaOpts...)
|
||||
|
@ -30,12 +33,14 @@ func backendLoader(backendString string, loader *model.ModelLoader, modelFile st
|
|||
return loader.LoadGPT2Model(modelFile)
|
||||
case "gptj":
|
||||
return loader.LoadGPTJModel(modelFile)
|
||||
case "rwkv":
|
||||
return loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
|
||||
default:
|
||||
return nil, fmt.Errorf("backend unsupported: %s", backendString)
|
||||
}
|
||||
}
|
||||
|
||||
func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
|
||||
func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
|
||||
updateModels := func(model interface{}) {
|
||||
muModels.Lock()
|
||||
defer muModels.Unlock()
|
||||
|
@ -82,6 +87,14 @@ func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama
|
|||
err = multierror.Append(err, modelerr)
|
||||
}
|
||||
|
||||
model, modelerr = loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
|
||||
if modelerr == nil {
|
||||
updateModels(model)
|
||||
return model, nil
|
||||
} else {
|
||||
err = multierror.Append(err, modelerr)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
|
||||
}
|
||||
|
||||
|
@ -101,9 +114,9 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
|
|||
var inferenceModel interface{}
|
||||
var err error
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts)
|
||||
inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts, uint32(c.Threads))
|
||||
} else {
|
||||
inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts)
|
||||
inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts, uint32(c.Threads))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -112,6 +125,20 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
|
|||
var fn func() (string, error)
|
||||
|
||||
switch model := inferenceModel.(type) {
|
||||
case *rwkv.RwkvState:
|
||||
supportStreams = true
|
||||
|
||||
fn = func() (string, error) {
|
||||
//model.ProcessInput("You are a chatbot that is very good at chatting. blah blah blah")
|
||||
stopWord := "\n"
|
||||
if len(c.StopWords) > 0 {
|
||||
stopWord = c.StopWords[0]
|
||||
}
|
||||
|
||||
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
case *gpt2.StableLM:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue