feat: allow to specify default backend for model (#156)

Signed-off-by: mudler <mudler@c3os.io>
Ettore Di Giacinto 2023-05-03 00:31:28 +02:00, committed by GitHub
parent 70caf9bf8c
commit 1ae7150810
7 changed files with 97 additions and 77 deletions


@@ -10,22 +10,86 @@ import (
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/hashicorp/go-multierror"
)
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
var mutexMap sync.Mutex
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
var loadedModels map[string]interface{} = map[string]interface{}{}
var muModels sync.Mutex
func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
switch strings.ToLower(backendString) {
case "llama":
return loader.LoadLLaMAModel(modelFile, llamaOpts...)
case "stablelm":
return loader.LoadStableLMModel(modelFile)
case "gpt2":
return loader.LoadGPT2Model(modelFile)
case "gptj":
return loader.LoadGPTJModel(modelFile)
default:
return nil, fmt.Errorf("backend unsupported: %s", backendString)
}
}
func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
updateModels := func(model interface{}) {
muModels.Lock()
defer muModels.Unlock()
loadedModels[modelFile] = model
}
muModels.Lock()
m, exists := loadedModels[modelFile]
if exists {
muModels.Unlock()
return m, nil
}
muModels.Unlock()
model, modelerr := loader.LoadLLaMAModel(modelFile, llamaOpts...)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
model, modelerr = loader.LoadGPTJModel(modelFile)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
model, modelerr = loader.LoadGPT2Model(modelFile)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
model, modelerr = loader.LoadStableLMModel(modelFile)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
}
func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback func(string) bool) (func() (string, error), error) {
var model *llama.LLama
var gptModel *gptj.GPTJ
var gpt2Model *gpt2.GPT2
var stableLMModel *gpt2.StableLM
supportStreams := false
modelFile := c.Model
// Try to load the model
var llamaerr, gpt2err, gptjerr, stableerr error
llamaOpts := []llama.ModelOption{}
if c.ContextSize != 0 {
llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize))
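Above, backendLoader maps an explicit backend name to a single loader call, while greedyLoader keeps a process-wide cache (loadedModels) and tries each backend in turn. The cache follows a check-then-load shape: look up under the mutex, release it for the slow load, then re-lock to store. A minimal self-contained sketch of that pattern (modelCache and getOrLoad are illustrative names, not LocalAI's API):

package main

import (
    "fmt"
    "sync"
)

// modelCache mirrors greedyLoader's check-then-load pattern: consult the
// cache under the lock, load outside it, then store the result.
// (Illustrative sketch only; not part of this commit.)
type modelCache struct {
    mu     sync.Mutex
    models map[string]interface{}
}

func (c *modelCache) getOrLoad(file string, load func(string) (interface{}, error)) (interface{}, error) {
    c.mu.Lock()
    if m, ok := c.models[file]; ok {
        c.mu.Unlock()
        return m, nil
    }
    c.mu.Unlock()

    m, err := load(file) // loading can be slow, so it runs outside the lock
    if err != nil {
        return nil, err
    }

    c.mu.Lock()
    c.models[file] = m
    c.mu.Unlock()
    return m, nil
}

func main() {
    cache := &modelCache{models: map[string]interface{}{}}
    m, _ := cache.getOrLoad("ggml-model.bin", func(f string) (interface{}, error) {
        return "loaded " + f, nil
    })
    fmt.Println(m)
}

Like greedyLoader itself, this releases the lock while loading, so two concurrent requests for the same file can both run the load; the later store simply overwrites the earlier one.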
@@ -34,25 +98,21 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
         llamaOpts = append(llamaOpts, llama.EnableF16Memory)
     }
 
-    // TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation..
-    model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
-    if llamaerr != nil {
-        gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
-        if gptjerr != nil {
-            gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
-            if gpt2err != nil {
-                stableLMModel, stableerr = loader.LoadStableLMModel(modelFile)
-                if stableerr != nil {
-                    return nil, fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so we want to catch both errors
-                }
-            }
-        }
-    }
+    var inferenceModel interface{}
+    var err error
+    if c.Backend == "" {
+        inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts)
+    } else {
+        inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts)
+    }
+    if err != nil {
+        return nil, err
+    }
 
     var fn func() (string, error)
 
-    switch {
-    case stableLMModel != nil:
+    switch model := inferenceModel.(type) {
+    case *gpt2.StableLM:
         fn = func() (string, error) {
             // Generate the prediction using the language model
             predictOptions := []gpt2.PredictOption{
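The new Config.Backend field is what the commit title describes: when it is empty, ModelInference falls back to the greedy trial-and-error path; otherwise backendLoader dispatches directly to the named backend. A hypothetical caller in the same package, sketched using only the ModelInference signature and Config fields visible in this diff (the model file name is made up):

// Hypothetical helper, not part of this commit. Config and ModelInference
// are the type and function shown in the diff above.
func predictWithDefaultBackend(loader *model.ModelLoader) (string, error) {
    c := Config{
        Model:   "ggml-gpt4all-j.bin", // illustrative file name
        Backend: "gptj",               // "" would select greedyLoader instead
    }
    // nil tokenCallback: no streaming is requested here.
    predict, err := ModelInference("Hello", loader, c, nil)
    if err != nil {
        return "", err
    }
    return predict()
}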
@@ -71,12 +131,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
                 predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
             }
 
-            return stableLMModel.Predict(
+            return model.Predict(
                 s,
                 predictOptions...,
             )
         }
-    case gpt2Model != nil:
+    case *gpt2.GPT2:
         fn = func() (string, error) {
             // Generate the prediction using the language model
             predictOptions := []gpt2.PredictOption{
@@ -95,12 +155,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
                 predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
             }
 
-            return gpt2Model.Predict(
+            return model.Predict(
                 s,
                 predictOptions...,
             )
         }
-    case gptModel != nil:
+    case *gptj.GPTJ:
         fn = func() (string, error) {
             // Generate the prediction using the language model
             predictOptions := []gptj.PredictOption{
@@ -119,12 +179,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
                 predictOptions = append(predictOptions, gptj.SetSeed(c.Seed))
             }
 
-            return gptModel.Predict(
+            return model.Predict(
                 s,
                 predictOptions...,
             )
         }
-    case model != nil:
+    case *llama.LLama:
         supportStreams = true
         fn = func() (string, error) {
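The rewritten switch is a plain Go type switch: the loader returns an interface{}, each case binds model to one concrete backend type, and the closure stored in fn calls that backend's Predict. A stand-alone sketch of the same dispatch shape, with stand-in types (llamaModel and gptjModel below are not the real bindings):

package main

import "fmt"

// Stand-in model types; the real code switches on *llama.LLama,
// *gptj.GPTJ, *gpt2.GPT2 and *gpt2.StableLM.
type llamaModel struct{}
type gptjModel struct{}

func (llamaModel) Predict(s string) string { return "llama: " + s }
func (gptjModel) Predict(s string) string  { return "gptj: " + s }

// predictionFn mirrors ModelInference's shape: type-switch on the loaded
// model and return a closure bound to the matching backend.
func predictionFn(m interface{}, prompt string) (func() (string, error), error) {
    switch model := m.(type) {
    case llamaModel:
        return func() (string, error) { return model.Predict(prompt), nil }, nil
    case gptjModel:
        return func() (string, error) { return model.Predict(prompt), nil }, nil
    default:
        return nil, fmt.Errorf("unsupported model type %T", m)
    }
}

func main() {
    fn, err := predictionFn(gptjModel{}, "hello")
    if err != nil {
        panic(err)
    }
    out, _ := fn()
    fmt.Println(out) // prints "gptj: hello"
}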