mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
make use of new bindings for gpt4all (#232)
This commit is contained in:
parent
032dee256f
commit
59e3c02002
6 changed files with 92 additions and 54 deletions
|
@ -15,9 +15,9 @@ import (
|
|||
bloomz "github.com/go-skynet/bloomz.cpp"
|
||||
bert "github.com/go-skynet/go-bert.cpp"
|
||||
gpt2 "github.com/go-skynet/go-gpt2.cpp"
|
||||
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
gpt4all "github.com/nomic/gpt4all/gpt4all-bindings/golang"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
|
@ -26,7 +26,7 @@ type ModelLoader struct {
|
|||
mu sync.Mutex
|
||||
// TODO: this needs generics
|
||||
models map[string]*llama.LLama
|
||||
gptmodels map[string]*gptj.GPTJ
|
||||
gptmodels map[string]*gpt4all.Model
|
||||
gpt2models map[string]*gpt2.GPT2
|
||||
gptstablelmmodels map[string]*gpt2.StableLM
|
||||
dollymodels map[string]*gpt2.Dolly
|
||||
|
@ -42,7 +42,7 @@ func NewModelLoader(modelPath string) *ModelLoader {
|
|||
return &ModelLoader{
|
||||
ModelPath: modelPath,
|
||||
gpt2models: make(map[string]*gpt2.GPT2),
|
||||
gptmodels: make(map[string]*gptj.GPTJ),
|
||||
gptmodels: make(map[string]*gpt4all.Model),
|
||||
gptstablelmmodels: make(map[string]*gpt2.StableLM),
|
||||
dollymodels: make(map[string]*gpt2.Dolly),
|
||||
redpajama: make(map[string]*gpt2.RedPajama),
|
||||
|
@ -328,7 +328,7 @@ func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
|
|||
return model, err
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
|
||||
func (ml *ModelLoader) LoadGPT4AllModel(modelName string, opts ...gpt4all.ModelOption) (*gpt4all.Model, error) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
|
||||
|
@ -346,7 +346,7 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
|
|||
modelFile := filepath.Join(ml.ModelPath, modelName)
|
||||
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
|
||||
|
||||
model, err := gptj.New(modelFile)
|
||||
model, err := gpt4all.New(modelFile, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -470,8 +470,12 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
|
|||
return ml.LoadRedPajama(modelFile)
|
||||
case "gpt2":
|
||||
return ml.LoadGPT2Model(modelFile)
|
||||
case "gptj":
|
||||
return ml.LoadGPTJModel(modelFile)
|
||||
case "gpt4all-llama":
|
||||
return ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.LLaMAType))
|
||||
case "gpt4all-mpt":
|
||||
return ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.MPTType))
|
||||
case "gpt4all-j":
|
||||
return ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.GPTJType))
|
||||
case "bert-embeddings":
|
||||
return ml.LoadBERT(modelFile)
|
||||
case "rwkv":
|
||||
|
@ -514,7 +518,23 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt
|
|||
err = multierror.Append(err, modelerr)
|
||||
}
|
||||
|
||||
model, modelerr = ml.LoadGPTJModel(modelFile)
|
||||
model, modelerr = ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.GPTJType))
|
||||
if modelerr == nil {
|
||||
updateModels(model)
|
||||
return model, nil
|
||||
} else {
|
||||
err = multierror.Append(err, modelerr)
|
||||
}
|
||||
|
||||
model, modelerr = ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.LLaMAType))
|
||||
if modelerr == nil {
|
||||
updateModels(model)
|
||||
return model, nil
|
||||
} else {
|
||||
err = multierror.Append(err, modelerr)
|
||||
}
|
||||
|
||||
model, modelerr = ml.LoadGPT4AllModel(modelFile, gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.MPTType))
|
||||
if modelerr == nil {
|
||||
updateModels(model)
|
||||
return model, nil
|
||||
|
@ -553,14 +573,14 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt
|
|||
} else {
|
||||
err = multierror.Append(err, modelerr)
|
||||
}
|
||||
|
||||
model, modelerr = ml.LoadBloomz(modelFile)
|
||||
if modelerr == nil {
|
||||
updateModels(model)
|
||||
return model, nil
|
||||
} else {
|
||||
err = multierror.Append(err, modelerr)
|
||||
}
|
||||
// Do not autoload bloomz
|
||||
//model, modelerr = ml.LoadBloomz(modelFile)
|
||||
//if modelerr == nil {
|
||||
// updateModels(model)
|
||||
// return model, nil
|
||||
//} else {
|
||||
// err = multierror.Append(err, modelerr)
|
||||
//}
|
||||
|
||||
model, modelerr = ml.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
|
||||
if modelerr == nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue