mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-25 04:54:59 +00:00
feat: add new gpt4all-j binding (#142)
This commit is contained in:
parent
ac70252d70
commit
92452d46da
5 changed files with 7 additions and 11 deletions
|
@ -37,7 +37,7 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
|
|||
// TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation..
|
||||
model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
|
||||
if llamaerr != nil {
|
||||
gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
|
||||
gptModel, gptjerr = loader.LoadGPTJModel(modelFile, gptj.SetThreads(c.Threads))
|
||||
if gptjerr != nil {
|
||||
gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
|
||||
if gpt2err != nil {
|
||||
|
@ -108,17 +108,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
|
|||
gptj.SetTopP(c.TopP),
|
||||
gptj.SetTopK(c.TopK),
|
||||
gptj.SetTokens(c.Maxtokens),
|
||||
gptj.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, gptj.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, gptj.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return gptModel.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue