Make use of the new bindings for gpt4all (#232)

This commit is contained in:
Ettore Di Giacinto 2023-05-11 14:31:19 +02:00 committed by GitHub
parent 032dee256f
commit 59e3c02002
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 92 additions and 54 deletions

View file

@ -15,9 +15,9 @@ import (
bloomz "github.com/go-skynet/bloomz.cpp"
bert "github.com/go-skynet/go-bert.cpp"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/hashicorp/go-multierror"
gpt4all "github.com/nomic/gpt4all/gpt4all-bindings/golang"
"github.com/rs/zerolog/log"
)
@ -26,7 +26,7 @@ type ModelLoader struct {
mu sync.Mutex
// TODO: these per-backend model maps should be collapsed into one generic container
models map[string]*llama.LLama
gptmodels map[string]*gptj.GPTJ
gptmodels map[string]*gpt4all.Model
gpt2models map[string]*gpt2.GPT2
gptstablelmmodels map[string]*gpt2.StableLM
dollymodels map[string]*gpt2.Dolly
@ -42,7 +42,7 @@ func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{
ModelPath: modelPath,
gpt2models: make(map[string]*gpt2.GPT2),
gptmodels: make(map[string]*gptj.GPTJ),
gptmodels: make(map[string]*gpt4all.Model),
gptstablelmmodels: make(map[string]*gpt2.StableLM),
dollymodels: make(map[string]*gpt2.Dolly),
redpajama: make(map[string]*gpt2.RedPajama),
@ -328,7 +328,7 @@ func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
return model, err
}
func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
func (ml *ModelLoader) LoadGPT4AllModel(modelName string, opts ...gpt4all.ModelOption) (*gpt4all.Model, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
@ -346,7 +346,7 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
modelFile := filepath.Join(ml.ModelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gptj.New(modelFile)
model, err := gpt4all.New(modelFile, opts...)
if err != nil {
return nil, err
}
@ -470,8 +470,12 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
return ml.LoadRedPajama(modelFile)
case "gpt2":
return ml.LoadGPT2Model(modelFile)
case "gptj":
return ml.LoadGPTJModel(modelFile)
case "gpt4all-llama":
return ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.LLaMAType))
case "gpt4all-mpt":
return ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.MPTType))
case "gpt4all-j":
return ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.GPTJType))
case "bert-embeddings":
return ml.LoadBERT(modelFile)
case "rwkv":
@ -514,7 +518,23 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt
err = multierror.Append(err, modelerr)
}
model, modelerr = ml.LoadGPTJModel(modelFile)
model, modelerr = ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.GPTJType))
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
model, modelerr = ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.LLaMAType))
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
model, modelerr = ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.MPTType))
if modelerr == nil {
updateModels(model)
return model, nil
@ -553,14 +573,14 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt
} else {
err = multierror.Append(err, modelerr)
}
model, modelerr = ml.LoadBloomz(modelFile)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
// Do not autoload bloomz
//model, modelerr = ml.LoadBloomz(modelFile)
//if modelerr == nil {
// updateModels(model)
// return model, nil
//} else {
// err = multierror.Append(err, modelerr)
//}
model, modelerr = ml.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
if modelerr == nil {