mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-28 06:25:00 +00:00
feat: move other backends to grpc
This finally makes everything more consistent Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
5dcfdbe51d
commit
1d0ed95a54
54 changed files with 3171 additions and 1712 deletions
|
@ -5,12 +5,15 @@ package transformers
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type Dolly struct {
|
||||
base.Base
|
||||
|
||||
dolly *transformers.Dolly
|
||||
}
|
||||
|
||||
|
@ -20,16 +23,12 @@ func (llm *Dolly) Load(opts *pb.ModelOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (llm *Dolly) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) {
|
||||
func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
|
@ -39,4 +38,6 @@ func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) {
|
|||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
43
pkg/grpc/llm/transformers/falcon.go
Normal file
43
pkg/grpc/llm/transformers/falcon.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package transformers
|
||||
|
||||
// This is a wrapper to statisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type Falcon struct {
|
||||
base.Base
|
||||
|
||||
falcon *transformers.Falcon
|
||||
}
|
||||
|
||||
func (llm *Falcon) Load(opts *pb.ModelOptions) error {
|
||||
model, err := transformers.NewFalcon(opts.Model)
|
||||
llm.falcon = model
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
|
@ -5,12 +5,15 @@ package transformers
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type GPT2 struct {
|
||||
base.Base
|
||||
|
||||
gpt2 *transformers.GPT2
|
||||
}
|
||||
|
||||
|
@ -20,16 +23,12 @@ func (llm *GPT2) Load(opts *pb.ModelOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (llm *GPT2) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) {
|
||||
func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
|
@ -39,4 +38,5 @@ func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) {
|
|||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,12 +5,15 @@ package transformers
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type GPTJ struct {
|
||||
base.Base
|
||||
|
||||
gptj *transformers.GPTJ
|
||||
}
|
||||
|
||||
|
@ -20,16 +23,12 @@ func (llm *GPTJ) Load(opts *pb.ModelOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (llm *GPTJ) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) {
|
||||
func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
|
@ -39,4 +38,5 @@ func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) {
|
|||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,12 +5,15 @@ package transformers
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type GPTNeoX struct {
|
||||
base.Base
|
||||
|
||||
gptneox *transformers.GPTNeoX
|
||||
}
|
||||
|
||||
|
@ -20,16 +23,12 @@ func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (llm *GPTNeoX) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) {
|
||||
func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
|
@ -39,4 +38,5 @@ func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string)
|
|||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,12 +5,15 @@ package transformers
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type MPT struct {
|
||||
base.Base
|
||||
|
||||
mpt *transformers.MPT
|
||||
}
|
||||
|
||||
|
@ -20,16 +23,12 @@ func (llm *MPT) Load(opts *pb.ModelOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (llm *MPT) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) {
|
||||
func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
|
@ -39,4 +38,5 @@ func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) {
|
|||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,12 +5,15 @@ package transformers
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type Replit struct {
|
||||
base.Base
|
||||
|
||||
replit *transformers.Replit
|
||||
}
|
||||
|
||||
|
@ -20,16 +23,12 @@ func (llm *Replit) Load(opts *pb.ModelOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (llm *Replit) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) {
|
||||
func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
|
@ -39,4 +38,5 @@ func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) {
|
|||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,12 +5,15 @@ package transformers
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
)
|
||||
|
||||
type Starcoder struct {
|
||||
base.Base
|
||||
|
||||
starcoder *transformers.Starcoder
|
||||
}
|
||||
|
||||
|
@ -20,16 +23,12 @@ func (llm *Starcoder) Load(opts *pb.ModelOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (llm *Starcoder) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
// fallback to Predict
|
||||
func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) {
|
||||
func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
go func() {
|
||||
res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
|
||||
|
@ -39,4 +38,6 @@ func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string
|
|||
results <- res
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue