deps: update gpt4all bindings, fix search path on new versions (#592)

This commit is contained in:
Ettore Di Giacinto 2023-06-14 13:24:53 +02:00 committed by GitHub
parent 467e88d305
commit e37361985c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 33 additions and 52 deletions

View file

@ -135,7 +135,7 @@ func rwkvLM(tokenFile string, threads uint32) func(string) (interface{}, error)
}
}
func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, llamaOpts []llama.ModelOption, threads uint32, assetDir string) (model interface{}, err error) {
log.Debug().Msgf("Loading model %s from %s", backendString, modelFile)
switch strings.ToLower(backendString) {
case LlamaBackend:
@ -161,7 +161,7 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
case StarcoderBackend:
return ml.LoadModel(modelFile, starCoder)
case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads))))
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetLibrarySearchPath(filepath.Join(assetDir, "backend-assets", "gpt4all"))))
case BertEmbeddingsBackend:
return ml.LoadModel(modelFile, bertEmbeddings)
case RwkvBackend:
@ -175,7 +175,7 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
}
}
func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (interface{}, error) {
func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32, assetDir string) (interface{}, error) {
log.Debug().Msgf("Loading model '%s' greedly", modelFile)
ml.mu.Lock()
@ -193,7 +193,7 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt
continue
}
log.Debug().Msgf("[%s] Attempting to load", b)
model, modelerr := ml.BackendLoader(b, modelFile, llamaOpts, threads)
model, modelerr := ml.BackendLoader(b, modelFile, llamaOpts, threads, assetDir)
if modelerr == nil && model != nil {
log.Debug().Msgf("[%s] Loads OK", b)
return model, nil