mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-24 04:25:00 +00:00
feat: add bert.cpp embeddings (#222)
This commit is contained in:
parent
e6db14e2f1
commit
f8ee20991c
14 changed files with 104 additions and 53 deletions
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/donomii/go-rwkv.cpp"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
bert "github.com/go-skynet/go-bert.cpp"
|
||||
gpt2 "github.com/go-skynet/go-gpt2.cpp"
|
||||
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
|
@ -62,6 +63,14 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config)
|
|||
}
|
||||
return model.Embeddings(s, predictOptions...)
|
||||
}
|
||||
// bert embeddings
|
||||
case *bert.Bert:
|
||||
fn = func() ([]float32, error) {
|
||||
if len(tokens) > 0 {
|
||||
return nil, fmt.Errorf("embeddings endpoint for this model supports only string")
|
||||
}
|
||||
return model.Embeddings(s, bert.SetThreads(c.Threads))
|
||||
}
|
||||
default:
|
||||
fn = func() ([]float32, error) {
|
||||
return nil, fmt.Errorf("embeddings not supported by the backend")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue