mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-30 15:35:01 +00:00
feat: add LangChainGo Huggingface backend (#446)
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
parent
7282668da1
commit
3ba07a5928
13 changed files with 241 additions and 0 deletions
47
pkg/langchain/huggingface.go
Normal file
47
pkg/langchain/huggingface.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package langchain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
"github.com/tmc/langchaingo/llms/huggingface"
|
||||
)
|
||||
|
||||
type HuggingFace struct {
|
||||
modelPath string
|
||||
}
|
||||
|
||||
func NewHuggingFace(repoId string) (*HuggingFace, error) {
|
||||
return &HuggingFace{
|
||||
modelPath: repoId,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *HuggingFace) PredictHuggingFace(text string, opts ...PredictOption) (*Predict, error) {
|
||||
po := NewPredictOptions(opts...)
|
||||
|
||||
// Init client
|
||||
llm, err := huggingface.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert from LocalAI to LangChainGo format of options
|
||||
co := []llms.CallOption{
|
||||
llms.WithModel(po.Model),
|
||||
llms.WithMaxTokens(po.MaxTokens),
|
||||
llms.WithTemperature(po.Temperature),
|
||||
llms.WithStopWords(po.StopWords),
|
||||
}
|
||||
|
||||
// Call Inference API
|
||||
ctx := context.Background()
|
||||
completion, err := llm.Call(ctx, text, co...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Predict{
|
||||
Completion: completion,
|
||||
}, nil
|
||||
}
|
57
pkg/langchain/langchain.go
Normal file
57
pkg/langchain/langchain.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package langchain
|
||||
|
||||
type PredictOptions struct {
|
||||
Model string `json:"model"`
|
||||
// MaxTokens is the maximum number of tokens to generate.
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
// Temperature is the temperature for sampling, between 0 and 1.
|
||||
Temperature float64 `json:"temperature"`
|
||||
// StopWords is a list of words to stop on.
|
||||
StopWords []string `json:"stop_words"`
|
||||
}
|
||||
|
||||
type PredictOption func(p *PredictOptions)
|
||||
|
||||
var DefaultOptions = PredictOptions{
|
||||
Model: "gpt2",
|
||||
MaxTokens: 200,
|
||||
Temperature: 0.96,
|
||||
StopWords: nil,
|
||||
}
|
||||
|
||||
type Predict struct {
|
||||
Completion string
|
||||
}
|
||||
|
||||
func SetModel(model string) PredictOption {
|
||||
return func(o *PredictOptions) {
|
||||
o.Model = model
|
||||
}
|
||||
}
|
||||
|
||||
func SetTemperature(temperature float64) PredictOption {
|
||||
return func(o *PredictOptions) {
|
||||
o.Temperature = temperature
|
||||
}
|
||||
}
|
||||
|
||||
func SetMaxTokens(maxTokens int) PredictOption {
|
||||
return func(o *PredictOptions) {
|
||||
o.MaxTokens = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
func SetStopWords(stopWords []string) PredictOption {
|
||||
return func(o *PredictOptions) {
|
||||
o.StopWords = stopWords
|
||||
}
|
||||
}
|
||||
|
||||
// NewPredictOptions Create a new PredictOptions object with the given options.
|
||||
func NewPredictOptions(opts ...PredictOption) PredictOptions {
|
||||
p := DefaultOptions
|
||||
for _, opt := range opts {
|
||||
opt(&p)
|
||||
}
|
||||
return p
|
||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
rwkv "github.com/donomii/go-rwkv.cpp"
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
"github.com/go-skynet/LocalAI/pkg/langchain"
|
||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
|
||||
bloomz "github.com/go-skynet/bloomz.cpp"
|
||||
bert "github.com/go-skynet/go-bert.cpp"
|
||||
|
@ -36,6 +37,7 @@ const (
|
|||
RwkvBackend = "rwkv"
|
||||
WhisperBackend = "whisper"
|
||||
StableDiffusionBackend = "stablediffusion"
|
||||
LCHuggingFaceBackend = "langchain-huggingface"
|
||||
)
|
||||
|
||||
var backends []string = []string{
|
||||
|
@ -100,6 +102,10 @@ var whisperModel = func(modelFile string) (interface{}, error) {
|
|||
return whisper.New(modelFile)
|
||||
}
|
||||
|
||||
var lcHuggingFace = func(repoId string) (interface{}, error) {
|
||||
return langchain.NewHuggingFace(repoId)
|
||||
}
|
||||
|
||||
func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) {
|
||||
return func(s string) (interface{}, error) {
|
||||
return llama.New(s, opts...)
|
||||
|
@ -159,6 +165,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
|
|||
return ml.LoadModel(modelFile, rwkvLM(filepath.Join(ml.ModelPath, modelFile+tokenizerSuffix), threads))
|
||||
case WhisperBackend:
|
||||
return ml.LoadModel(modelFile, whisperModel)
|
||||
case LCHuggingFaceBackend:
|
||||
return ml.LoadModel(modelFile, lcHuggingFace)
|
||||
default:
|
||||
return nil, fmt.Errorf("backend unsupported: %s", backendString)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue