feat: add new gpt4all-j binding (#142)

This commit is contained in:
Ettore Di Giacinto 2023-05-01 20:00:15 +02:00 committed by GitHub
parent ac70252d70
commit 92452d46da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 7 additions and 11 deletions

View file

@ -37,7 +37,7 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
// TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation..
model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
if llamaerr != nil {
gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
gptModel, gptjerr = loader.LoadGPTJModel(modelFile, gptj.SetThreads(c.Threads))
if gptjerr != nil {
gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
if gpt2err != nil {
@ -108,17 +108,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
gptj.SetTopP(c.TopP),
gptj.SetTopK(c.TopK),
gptj.SetTokens(c.Maxtokens),
gptj.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gptj.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gptj.SetSeed(c.Seed))
}
return gptModel.Predict(
s,
predictOptions...,