Usage Features (#863)

This commit is contained in:
Dave 2023-08-18 15:23:14 -04:00 committed by GitHub
parent 2bacd0180d
commit 8cb1061c11
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
40 changed files with 1222 additions and 317 deletions

View file

@ -4,17 +4,39 @@ package base
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"fmt"
"os"
"sync"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type Base struct {
backendBusy sync.Mutex
State pb.StatusResponse_State
}
func (llm *Base) Busy() bool {
r := llm.backendBusy.TryLock()
if r {
llm.backendBusy.Unlock()
}
return r
}
func (llm *Base) Lock() {
llm.backendBusy.Lock()
llm.State = pb.StatusResponse_BUSY
}
func (llm *Base) Unlock() {
llm.State = pb.StatusResponse_READY
llm.backendBusy.Unlock()
}
func (llm *Base) Load(opts *pb.ModelOptions) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) Predict(opts *pb.PredictOptions) (string, error) {
@ -40,3 +62,32 @@ func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (api.Result, error) {
func (llm *Base) TTS(*pb.TTSRequest) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
return pb.TokenizationResponse{}, fmt.Errorf("unimplemented")
}
// backends may wish to call this to capture the gopsutil info, then enhance with additional memory usage details?
func (llm *Base) Status() (pb.StatusResponse, error) {
mud := pb.MemoryUsageData{
Breakdown: make(map[string]uint64),
}
pid := int32(os.Getpid())
backendProcess, err := gopsutil.NewProcess(pid)
if err == nil {
memInfo, err := backendProcess.MemoryInfo()
if err == nil {
mud.Total = memInfo.VMS // TEST, but rss seems reasonable first guess. Does include swap, but we might care about that.
mud.Breakdown["gopsutil-RSS"] = memInfo.RSS
}
}
return pb.StatusResponse{
State: llm.State,
Memory: &mud,
}, nil
}

View file

@ -158,3 +158,29 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
tresult.Text = res.Text
return tresult, err
}
func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
res, err := client.TokenizeString(ctx, in, opts...)
if err != nil {
return nil, err
}
return res, nil
}
func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.Status(ctx, &pb.HealthMessage{})
}

View file

@ -6,6 +6,7 @@ import (
)
type LLM interface {
Busy() bool
Predict(*pb.PredictOptions) (string, error)
PredictStream(*pb.PredictOptions, chan string) error
Load(*pb.ModelOptions) error
@ -13,6 +14,8 @@ type LLM interface {
GenerateImage(*pb.GenerateImageRequest) error
AudioTranscription(*pb.TranscriptRequest) (api.Result, error)
TTS(*pb.TTSRequest) error
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
Status() (pb.StatusResponse, error)
}
func newReply(s string) *pb.Reply {

View file

@ -4,6 +4,7 @@ package bert
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
bert "github.com/go-skynet/go-bert.cpp"
"github.com/rs/zerolog/log"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
@ -15,12 +16,21 @@ type Embeddings struct {
}
func (llm *Embeddings) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("bert backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := bert.New(opts.ModelFile)
llm.bert = model
return err
}
func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
if len(opts.EmbeddingTokens) > 0 {
tokens := []int{}
for _, t := range opts.EmbeddingTokens {

View file

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
"github.com/go-skynet/bloomz.cpp"
)
@ -18,6 +19,12 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("bloomz backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := bloomz.New(opts.ModelFile)
llm.bloomz = model
return err
@ -40,11 +47,16 @@ func buildPredictOptions(opts *pb.PredictOptions) []bloomz.PredictOption {
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -53,6 +65,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil

View file

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
ggllm "github.com/mudler/go-ggllm.cpp"
)
@ -18,6 +19,13 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("falcon backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
ggllmOpts := []ggllm.ModelOption{}
if opts.ContextSize != 0 {
ggllmOpts = append(ggllmOpts, ggllm.SetContext(int(opts.ContextSize)))
@ -118,10 +126,14 @@ func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption {
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
predictOptions := buildPredictOptions(opts)
predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool {
@ -138,6 +150,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
fmt.Println("err: ", err)
}
close(results)
llm.Base.Unlock()
}()
return nil

View file

@ -8,6 +8,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
"github.com/rs/zerolog/log"
)
type LLM struct {
@ -17,6 +18,13 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("gpt4all backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := gpt4all.New(opts.ModelFile,
gpt4all.SetThreads(int(opts.Threads)),
gpt4all.SetLibrarySearchPath(opts.LibrarySearchPath))
@ -39,10 +47,15 @@ func buildPredictOptions(opts *pb.PredictOptions) []gpt4all.PredictOption {
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
predictOptions := buildPredictOptions(opts)
go func() {
@ -56,6 +69,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
}
llm.gpt4all.SetTokenCallback(nil)
close(results)
llm.Base.Unlock()
}()
return nil

View file

@ -8,6 +8,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/langchain"
"github.com/rs/zerolog/log"
)
type LLM struct {
@ -18,12 +19,21 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("langchain backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
llm.langchain, _ = langchain.NewHuggingFace(opts.Model)
llm.model = opts.Model
return nil
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
o := []langchain.PredictOption{
langchain.SetModel(llm.model),
langchain.SetMaxTokens(int(opts.Tokens)),
@ -38,6 +48,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
o := []langchain.PredictOption{
langchain.SetModel(llm.model),
langchain.SetMaxTokens(int(opts.Tokens)),
@ -52,6 +63,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
}
results <- res.Completion
close(results)
llm.Base.Unlock()
}()
return nil

View file

@ -8,6 +8,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/go-llama.cpp"
"github.com/rs/zerolog/log"
)
type LLM struct {
@ -18,6 +19,13 @@ type LLM struct {
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("llama backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
ropeFreqBase := float32(10000)
ropeFreqScale := float32(1)
@ -73,6 +81,7 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
model, err := llama.New(opts.ModelFile, llamaOpts...)
llm.llama = model
return err
}
@ -167,10 +176,14 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
predictOptions := buildPredictOptions(opts)
predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool {
@ -184,12 +197,16 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
fmt.Println("err: ", err)
}
close(results)
llm.Base.Unlock()
}()
return nil
}
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
predictOptions := buildPredictOptions(opts)
if len(opts.EmbeddingTokens) > 0 {
@ -202,3 +219,18 @@ func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return llm.llama.Embeddings(opts.Embeddings, predictOptions...)
}
func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
predictOptions := buildPredictOptions(opts)
l, tokens, err := llm.llama.TokenizeString(opts.Prompt, predictOptions...)
if err != nil {
return pb.TokenizationResponse{}, err
}
return pb.TokenizationResponse{
Length: l,
Tokens: tokens,
}, nil
}

View file

@ -9,6 +9,7 @@ import (
"github.com/donomii/go-rwkv.cpp"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
)
const tokenizerSuffix = ".tokenizer.json"
@ -20,6 +21,12 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("rwkv backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
modelPath := filepath.Dir(opts.ModelFile)
modelFile := filepath.Base(opts.ModelFile)
model := rwkv.LoadFiles(opts.ModelFile, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads()))
@ -32,6 +39,8 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
stopWord := "\n"
if len(opts.StopPrompts) > 0 {
@ -48,6 +57,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
stopWord := "\n"
@ -65,6 +75,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
return true
})
close(results)
llm.Base.Unlock()
}()
return nil

View file

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,27 @@ type Dolly struct {
}
func (llm *Dolly) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("dolly backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewDolly(opts.ModelFile)
llm.dolly = model
return err
}
func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +48,7 @@ func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) er
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil

View file

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type Falcon struct {
}
func (llm *Falcon) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("transformers-falcon backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewFalcon(opts.ModelFile)
llm.falcon = model
return err
}
func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) e
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil

View file

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type GPT2 struct {
}
func (llm *GPT2) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("gpt2 backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.New(opts.ModelFile)
llm.gpt2 = model
return err
}
func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) err
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
}

View file

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type GPTJ struct {
}
func (llm *GPTJ) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("gptj backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewGPTJ(opts.ModelFile)
llm.gptj = model
return err
}
func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) err
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
}

View file

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type GPTNeoX struct {
}
func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("gptneox backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewGPTNeoX(opts.ModelFile)
llm.gptneox = model
return err
}
func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string)
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
}

View file

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,27 @@ type MPT struct {
}
func (llm *MPT) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("mpt backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewMPT(opts.ModelFile)
llm.mpt = model
return err
}
func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +48,7 @@ func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) erro
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
}

View file

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type Replit struct {
}
func (llm *Replit) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("replit backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewReplit(opts.ModelFile)
llm.replit = model
return err
}
func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) e
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil
}

View file

@ -7,6 +7,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
@ -18,17 +19,26 @@ type Starcoder struct {
}
func (llm *Starcoder) Load(opts *pb.ModelOptions) error {
if llm.Base.State != pb.StatusResponse_UNINITIALIZED {
log.Warn().Msgf("starcoder backend loading %s while already in state %s!", opts.Model, llm.Base.State.String())
}
llm.Base.Lock()
defer llm.Base.Unlock()
model, err := transformers.NewStarcoder(opts.ModelFile)
llm.starcoder = model
return err
}
func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) {
llm.Base.Lock()
defer llm.Base.Unlock()
return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error {
llm.Base.Lock()
go func() {
res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -37,6 +47,7 @@ func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string
}
results <- res
close(results)
llm.Base.Unlock()
}()
return nil

View file

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.15.8
// protoc-gen-go v1.27.1
// protoc v3.12.4
// source: pkg/grpc/proto/backend.proto
package proto
@ -20,6 +20,58 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type StatusResponse_State int32
const (
StatusResponse_UNINITIALIZED StatusResponse_State = 0
StatusResponse_BUSY StatusResponse_State = 1
StatusResponse_READY StatusResponse_State = 2
StatusResponse_ERROR StatusResponse_State = -1
)
// Enum value maps for StatusResponse_State.
var (
StatusResponse_State_name = map[int32]string{
0: "UNINITIALIZED",
1: "BUSY",
2: "READY",
-1: "ERROR",
}
StatusResponse_State_value = map[string]int32{
"UNINITIALIZED": 0,
"BUSY": 1,
"READY": 2,
"ERROR": -1,
}
)
func (x StatusResponse_State) Enum() *StatusResponse_State {
p := new(StatusResponse_State)
*p = x
return p
}
func (x StatusResponse_State) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (StatusResponse_State) Descriptor() protoreflect.EnumDescriptor {
return file_pkg_grpc_proto_backend_proto_enumTypes[0].Descriptor()
}
func (StatusResponse_State) Type() protoreflect.EnumType {
return &file_pkg_grpc_proto_backend_proto_enumTypes[0]
}
func (x StatusResponse_State) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use StatusResponse_State.Descriptor instead.
func (StatusResponse_State) EnumDescriptor() ([]byte, []int) {
return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{13, 0}
}
type HealthMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@ -1253,6 +1305,171 @@ func (x *TTSRequest) GetDst() string {
return ""
}
type TokenizationResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Length int32 `protobuf:"varint,1,opt,name=length,proto3" json:"length,omitempty"`
Tokens []int32 `protobuf:"varint,2,rep,packed,name=tokens,proto3" json:"tokens,omitempty"`
}
func (x *TokenizationResponse) Reset() {
*x = TokenizationResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[11]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *TokenizationResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*TokenizationResponse) ProtoMessage() {}
func (x *TokenizationResponse) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[11]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use TokenizationResponse.ProtoReflect.Descriptor instead.
func (*TokenizationResponse) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{11}
}
func (x *TokenizationResponse) GetLength() int32 {
if x != nil {
return x.Length
}
return 0
}
func (x *TokenizationResponse) GetTokens() []int32 {
if x != nil {
return x.Tokens
}
return nil
}
type MemoryUsageData struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Total uint64 `protobuf:"varint,1,opt,name=total,proto3" json:"total,omitempty"`
Breakdown map[string]uint64 `protobuf:"bytes,2,rep,name=breakdown,proto3" json:"breakdown,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"varint,2,opt,name=value,proto3"`
}
func (x *MemoryUsageData) Reset() {
*x = MemoryUsageData{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[12]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MemoryUsageData) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MemoryUsageData) ProtoMessage() {}
func (x *MemoryUsageData) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[12]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MemoryUsageData.ProtoReflect.Descriptor instead.
func (*MemoryUsageData) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{12}
}
func (x *MemoryUsageData) GetTotal() uint64 {
if x != nil {
return x.Total
}
return 0
}
func (x *MemoryUsageData) GetBreakdown() map[string]uint64 {
if x != nil {
return x.Breakdown
}
return nil
}
type StatusResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
State StatusResponse_State `protobuf:"varint,1,opt,name=state,proto3,enum=backend.StatusResponse_State" json:"state,omitempty"`
Memory *MemoryUsageData `protobuf:"bytes,2,opt,name=memory,proto3" json:"memory,omitempty"`
}
func (x *StatusResponse) Reset() {
*x = StatusResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[13]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *StatusResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StatusResponse) ProtoMessage() {}
func (x *StatusResponse) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_backend_proto_msgTypes[13]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead.
func (*StatusResponse) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{13}
}
func (x *StatusResponse) GetState() StatusResponse_State {
if x != nil {
return x.State
}
return StatusResponse_UNINITIALIZED
}
func (x *StatusResponse) GetMemory() *MemoryUsageData {
if x != nil {
return x.Memory
}
return nil
}
var File_pkg_grpc_proto_backend_proto protoreflect.FileDescriptor
var file_pkg_grpc_proto_backend_proto_rawDesc = []byte{
@ -1451,44 +1668,80 @@ var file_pkg_grpc_proto_backend_proto_rawDesc = []byte{
0x04, 0x74, 0x65, 0x78, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x65, 0x78,
0x74, 0x12, 0x14, 0x0a, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
0x52, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x73, 0x74, 0x18, 0x03,
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x73, 0x74, 0x32, 0xeb, 0x03, 0x0a, 0x07, 0x42, 0x61,
0x63, 0x6b, 0x65, 0x6e, 0x64, 0x12, 0x32, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12,
0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68,
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e,
0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x07, 0x50, 0x72, 0x65,
0x64, 0x69, 0x63, 0x74, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50,
0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e,
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x73, 0x74, 0x22, 0x46, 0x0a, 0x14, 0x54, 0x6f, 0x6b,
0x65, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x12, 0x16, 0x0a, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28,
0x05, 0x52, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x6f, 0x6b,
0x65, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x05, 0x52, 0x06, 0x74, 0x6f, 0x6b, 0x65, 0x6e,
0x73, 0x22, 0xac, 0x01, 0x0a, 0x0f, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x55, 0x73, 0x61, 0x67,
0x65, 0x44, 0x61, 0x74, 0x61, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, 0x01,
0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x12, 0x45, 0x0a, 0x09, 0x62,
0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f, 0x77, 0x6e, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x27,
0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x55,
0x73, 0x61, 0x67, 0x65, 0x44, 0x61, 0x74, 0x61, 0x2e, 0x42, 0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f,
0x77, 0x6e, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x09, 0x62, 0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f,
0x77, 0x6e, 0x1a, 0x3c, 0x0a, 0x0e, 0x42, 0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f, 0x77, 0x6e, 0x45,
0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28,
0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18,
0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01,
0x22, 0xbc, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x12, 0x33, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01,
0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x53, 0x74, 0x61, 0x74,
0x65, 0x52, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x12, 0x30, 0x0a, 0x06, 0x6d, 0x65, 0x6d, 0x6f,
0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65,
0x6e, 0x64, 0x2e, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x55, 0x73, 0x61, 0x67, 0x65, 0x44, 0x61,
0x74, 0x61, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x22, 0x43, 0x0a, 0x05, 0x53, 0x74,
0x61, 0x74, 0x65, 0x12, 0x11, 0x0a, 0x0d, 0x55, 0x4e, 0x49, 0x4e, 0x49, 0x54, 0x49, 0x41, 0x4c,
0x49, 0x5a, 0x45, 0x44, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x42, 0x55, 0x53, 0x59, 0x10, 0x01,
0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x41, 0x44, 0x59, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x05, 0x45,
0x52, 0x52, 0x4f, 0x52, 0x10, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01, 0x32,
0xf4, 0x04, 0x0a, 0x07, 0x42, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x12, 0x32, 0x0a, 0x06, 0x48,
0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e,
0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0e, 0x2e,
0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12,
0x35, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x62,
0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69,
0x6f, 0x6e, 0x73, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65,
0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x3c, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63,
0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e,
0x34, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63,
0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69,
0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65,
0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x35, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64,
0x65, 0x6c, 0x12, 0x15, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x4d, 0x6f, 0x64,
0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b,
0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x3c, 0x0a, 0x0d,
0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x17, 0x2e,
0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f,
0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64,
0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x12, 0x40, 0x0a, 0x09, 0x45, 0x6d,
0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e,
0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73,
0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79,
0x22, 0x00, 0x30, 0x01, 0x12, 0x40, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e,
0x67, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64,
0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x18, 0x2e, 0x62, 0x61, 0x63,
0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65,
0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x41, 0x0a, 0x0d, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61,
0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e,
0x64, 0x2e, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64,
0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x12, 0x41, 0x75, 0x64,
0x69, 0x6f, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12,
0x1a, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63,
0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x62, 0x61,
0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74,
0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x03, 0x54, 0x54, 0x53, 0x12,
0x13, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x54, 0x53, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52,
0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x42, 0x5a, 0x0a, 0x19, 0x69, 0x6f, 0x2e, 0x73, 0x6b,
0x79, 0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x62, 0x61, 0x63,
0x6b, 0x65, 0x6e, 0x64, 0x42, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x42, 0x61, 0x63,
0x6b, 0x65, 0x6e, 0x64, 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63,
0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x1a, 0x18, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64,
0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x41, 0x0a, 0x0d,
0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x2e,
0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65,
0x49, 0x6d, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62,
0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12,
0x4d, 0x0a, 0x12, 0x41, 0x75, 0x64, 0x69, 0x6f, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69,
0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e,
0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x19, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e,
0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x2d,
0x0a, 0x03, 0x54, 0x54, 0x53, 0x12, 0x13, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e,
0x54, 0x54, 0x53, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63,
0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x4a, 0x0a,
0x0e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x69, 0x7a, 0x65, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x12,
0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63,
0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65,
0x6e, 0x64, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x12, 0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x48, 0x65,
0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x17, 0x2e, 0x62, 0x61,
0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x5a, 0x0a, 0x19, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79,
0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x62, 0x61, 0x63, 0x6b,
0x65, 0x6e, 0x64, 0x42, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x42, 0x61, 0x63, 0x6b,
0x65, 0x6e, 0x64, 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f,
0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61,
0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@ -1503,43 +1756,56 @@ func file_pkg_grpc_proto_backend_proto_rawDescGZIP() []byte {
return file_pkg_grpc_proto_backend_proto_rawDescData
}
var file_pkg_grpc_proto_backend_proto_msgTypes = make([]protoimpl.MessageInfo, 11)
var file_pkg_grpc_proto_backend_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_pkg_grpc_proto_backend_proto_msgTypes = make([]protoimpl.MessageInfo, 15)
var file_pkg_grpc_proto_backend_proto_goTypes = []interface{}{
(*HealthMessage)(nil), // 0: backend.HealthMessage
(*PredictOptions)(nil), // 1: backend.PredictOptions
(*Reply)(nil), // 2: backend.Reply
(*ModelOptions)(nil), // 3: backend.ModelOptions
(*Result)(nil), // 4: backend.Result
(*EmbeddingResult)(nil), // 5: backend.EmbeddingResult
(*TranscriptRequest)(nil), // 6: backend.TranscriptRequest
(*TranscriptResult)(nil), // 7: backend.TranscriptResult
(*TranscriptSegment)(nil), // 8: backend.TranscriptSegment
(*GenerateImageRequest)(nil), // 9: backend.GenerateImageRequest
(*TTSRequest)(nil), // 10: backend.TTSRequest
(StatusResponse_State)(0), // 0: backend.StatusResponse.State
(*HealthMessage)(nil), // 1: backend.HealthMessage
(*PredictOptions)(nil), // 2: backend.PredictOptions
(*Reply)(nil), // 3: backend.Reply
(*ModelOptions)(nil), // 4: backend.ModelOptions
(*Result)(nil), // 5: backend.Result
(*EmbeddingResult)(nil), // 6: backend.EmbeddingResult
(*TranscriptRequest)(nil), // 7: backend.TranscriptRequest
(*TranscriptResult)(nil), // 8: backend.TranscriptResult
(*TranscriptSegment)(nil), // 9: backend.TranscriptSegment
(*GenerateImageRequest)(nil), // 10: backend.GenerateImageRequest
(*TTSRequest)(nil), // 11: backend.TTSRequest
(*TokenizationResponse)(nil), // 12: backend.TokenizationResponse
(*MemoryUsageData)(nil), // 13: backend.MemoryUsageData
(*StatusResponse)(nil), // 14: backend.StatusResponse
nil, // 15: backend.MemoryUsageData.BreakdownEntry
}
var file_pkg_grpc_proto_backend_proto_depIdxs = []int32{
8, // 0: backend.TranscriptResult.segments:type_name -> backend.TranscriptSegment
0, // 1: backend.Backend.Health:input_type -> backend.HealthMessage
1, // 2: backend.Backend.Predict:input_type -> backend.PredictOptions
3, // 3: backend.Backend.LoadModel:input_type -> backend.ModelOptions
1, // 4: backend.Backend.PredictStream:input_type -> backend.PredictOptions
1, // 5: backend.Backend.Embedding:input_type -> backend.PredictOptions
9, // 6: backend.Backend.GenerateImage:input_type -> backend.GenerateImageRequest
6, // 7: backend.Backend.AudioTranscription:input_type -> backend.TranscriptRequest
10, // 8: backend.Backend.TTS:input_type -> backend.TTSRequest
2, // 9: backend.Backend.Health:output_type -> backend.Reply
2, // 10: backend.Backend.Predict:output_type -> backend.Reply
4, // 11: backend.Backend.LoadModel:output_type -> backend.Result
2, // 12: backend.Backend.PredictStream:output_type -> backend.Reply
5, // 13: backend.Backend.Embedding:output_type -> backend.EmbeddingResult
4, // 14: backend.Backend.GenerateImage:output_type -> backend.Result
7, // 15: backend.Backend.AudioTranscription:output_type -> backend.TranscriptResult
4, // 16: backend.Backend.TTS:output_type -> backend.Result
9, // [9:17] is the sub-list for method output_type
1, // [1:9] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
9, // 0: backend.TranscriptResult.segments:type_name -> backend.TranscriptSegment
15, // 1: backend.MemoryUsageData.breakdown:type_name -> backend.MemoryUsageData.BreakdownEntry
0, // 2: backend.StatusResponse.state:type_name -> backend.StatusResponse.State
13, // 3: backend.StatusResponse.memory:type_name -> backend.MemoryUsageData
1, // 4: backend.Backend.Health:input_type -> backend.HealthMessage
2, // 5: backend.Backend.Predict:input_type -> backend.PredictOptions
4, // 6: backend.Backend.LoadModel:input_type -> backend.ModelOptions
2, // 7: backend.Backend.PredictStream:input_type -> backend.PredictOptions
2, // 8: backend.Backend.Embedding:input_type -> backend.PredictOptions
10, // 9: backend.Backend.GenerateImage:input_type -> backend.GenerateImageRequest
7, // 10: backend.Backend.AudioTranscription:input_type -> backend.TranscriptRequest
11, // 11: backend.Backend.TTS:input_type -> backend.TTSRequest
2, // 12: backend.Backend.TokenizeString:input_type -> backend.PredictOptions
1, // 13: backend.Backend.Status:input_type -> backend.HealthMessage
3, // 14: backend.Backend.Health:output_type -> backend.Reply
3, // 15: backend.Backend.Predict:output_type -> backend.Reply
5, // 16: backend.Backend.LoadModel:output_type -> backend.Result
3, // 17: backend.Backend.PredictStream:output_type -> backend.Reply
6, // 18: backend.Backend.Embedding:output_type -> backend.EmbeddingResult
5, // 19: backend.Backend.GenerateImage:output_type -> backend.Result
8, // 20: backend.Backend.AudioTranscription:output_type -> backend.TranscriptResult
5, // 21: backend.Backend.TTS:output_type -> backend.Result
12, // 22: backend.Backend.TokenizeString:output_type -> backend.TokenizationResponse
14, // 23: backend.Backend.Status:output_type -> backend.StatusResponse
14, // [14:24] is the sub-list for method output_type
4, // [4:14] is the sub-list for method input_type
4, // [4:4] is the sub-list for extension type_name
4, // [4:4] is the sub-list for extension extendee
0, // [0:4] is the sub-list for field type_name
}
func init() { file_pkg_grpc_proto_backend_proto_init() }
@ -1680,19 +1946,56 @@ func file_pkg_grpc_proto_backend_proto_init() {
return nil
}
}
file_pkg_grpc_proto_backend_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TokenizationResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_proto_backend_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MemoryUsageData); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_proto_backend_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*StatusResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_proto_backend_proto_rawDesc,
NumEnums: 0,
NumMessages: 11,
NumEnums: 1,
NumMessages: 15,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_proto_backend_proto_goTypes,
DependencyIndexes: file_pkg_grpc_proto_backend_proto_depIdxs,
EnumInfos: file_pkg_grpc_proto_backend_proto_enumTypes,
MessageInfos: file_pkg_grpc_proto_backend_proto_msgTypes,
}.Build()
File_pkg_grpc_proto_backend_proto = out.File

View file

@ -16,6 +16,8 @@ service Backend {
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
rpc TTS(TTSRequest) returns (Result) {}
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
rpc Status(HealthMessage) returns (StatusResponse) {}
}
message HealthMessage {}
@ -157,3 +159,24 @@ message TTSRequest {
string model = 2;
string dst = 3;
}
message TokenizationResponse {
int32 length = 1;
repeated int32 tokens = 2;
}
message MemoryUsageData {
uint64 total = 1;
map<string, uint64> breakdown = 2;
}
message StatusResponse {
enum State {
UNINITIALIZED = 0;
BUSY = 1;
READY = 2;
ERROR = -1;
}
State state = 1;
MemoryUsageData memory = 2;
}

View file

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.15.8
// - protoc-gen-go-grpc v1.3.0
// - protoc v3.12.4
// source: pkg/grpc/proto/backend.proto
package proto
@ -18,6 +18,19 @@ import (
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
const (
Backend_Health_FullMethodName = "/backend.Backend/Health"
Backend_Predict_FullMethodName = "/backend.Backend/Predict"
Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel"
Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream"
Backend_Embedding_FullMethodName = "/backend.Backend/Embedding"
Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage"
Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription"
Backend_TTS_FullMethodName = "/backend.Backend/TTS"
Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString"
Backend_Status_FullMethodName = "/backend.Backend/Status"
)
// BackendClient is the client API for Backend service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
@ -30,6 +43,8 @@ type BackendClient interface {
GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error)
AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error)
TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error)
TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error)
Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error)
}
type backendClient struct {
@ -42,7 +57,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply)
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -51,7 +66,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply)
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -60,7 +75,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ..
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -68,7 +83,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ..
}
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...)
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...)
if err != nil {
return nil, err
}
@ -101,7 +116,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) {
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
out := new(EmbeddingResult)
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -110,7 +125,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -119,7 +134,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
out := new(TranscriptResult)
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -128,7 +143,25 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...)
err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) {
out := new(TokenizationResponse)
err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) {
out := new(StatusResponse)
err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
@ -147,6 +180,8 @@ type BackendServer interface {
GenerateImage(context.Context, *GenerateImageRequest) (*Result, error)
AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error)
TTS(context.Context, *TTSRequest) (*Result, error)
TokenizeString(context.Context, *PredictOptions) (*TokenizationResponse, error)
Status(context.Context, *HealthMessage) (*StatusResponse, error)
mustEmbedUnimplementedBackendServer()
}
@ -178,6 +213,12 @@ func (UnimplementedBackendServer) AudioTranscription(context.Context, *Transcrip
func (UnimplementedBackendServer) TTS(context.Context, *TTSRequest) (*Result, error) {
return nil, status.Errorf(codes.Unimplemented, "method TTS not implemented")
}
func (UnimplementedBackendServer) TokenizeString(context.Context, *PredictOptions) (*TokenizationResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method TokenizeString not implemented")
}
func (UnimplementedBackendServer) Status(context.Context, *HealthMessage) (*StatusResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Status not implemented")
}
func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {}
// UnsafeBackendServer may be embedded to opt out of forward compatibility for this service.
@ -201,7 +242,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/Health",
FullMethod: Backend_Health_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Health(ctx, req.(*HealthMessage))
@ -219,7 +260,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/Predict",
FullMethod: Backend_Predict_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))
@ -237,7 +278,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/LoadModel",
FullMethod: Backend_LoadModel_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))
@ -276,7 +317,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/Embedding",
FullMethod: Backend_Embedding_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))
@ -294,7 +335,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/GenerateImage",
FullMethod: Backend_GenerateImage_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))
@ -312,7 +353,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/AudioTranscription",
FullMethod: Backend_AudioTranscription_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))
@ -330,7 +371,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/TTS",
FullMethod: Backend_TTS_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))
@ -338,6 +379,42 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
return interceptor(ctx, in, info, handler)
}
func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(PredictOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).TokenizeString(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_TokenizeString_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HealthMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).Status(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_Status_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Status(ctx, req.(*HealthMessage))
}
return interceptor(ctx, in, info, handler)
}
// Backend_ServiceDesc is the grpc.ServiceDesc for Backend service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@ -373,6 +450,14 @@ var Backend_ServiceDesc = grpc.ServiceDesc{
MethodName: "TTS",
Handler: _Backend_TTS_Handler,
},
{
MethodName: "TokenizeString",
Handler: _Backend_TokenizeString_Handler,
},
{
MethodName: "Status",
Handler: _Backend_Status_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View file

@ -110,6 +110,32 @@ func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictS
return nil
}
func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) {
res, err := s.llm.TokenizeString(in)
if err != nil {
return nil, err
}
castTokens := make([]int32, len(res.Tokens))
for i, v := range res.Tokens {
castTokens[i] = int32(v)
}
return &pb.TokenizationResponse{
Length: int32(res.Length),
Tokens: castTokens,
}, err
}
func (s *server) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusResponse, error) {
res, err := s.llm.Status()
if err != nil {
return nil, err
}
return &res, nil
}
func StartServer(address string, model LLM) error {
lis, err := net.Listen("tcp", address)
if err != nil {

View file

@ -6,6 +6,7 @@ import (
"os"
"os/signal"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
@ -64,10 +65,33 @@ var AutoLoadBackends []string = []string{
PiperBackend,
}
func (ml *ModelLoader) StopGRPC() {
for _, p := range ml.grpcProcesses {
p.Stop()
func (ml *ModelLoader) GetGRPCPID(id string) (int, error) {
p, exists := ml.grpcProcesses[id]
if !exists {
return -1, fmt.Errorf("no grpc backend found for %s", id)
}
return strconv.Atoi(p.PID)
}
type GRPCProcessFilter = func(p *process.Process) bool
func includeAllProcesses(_ *process.Process) bool {
return true
}
func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) {
for _, p := range ml.grpcProcesses {
if filter(p) {
p.Stop()
}
}
}
func (ml *ModelLoader) StopAllGRPC() {
ml.StopGRPC(includeAllProcesses)
// for _, p := range ml.grpcProcesses {
// p.Stop()
// }
}
func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
@ -252,7 +276,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
// Is this really needed? BackendLoader already does this
ml.mu.Lock()
if m := ml.checkIsLoaded(o.model); m != nil {
if m := ml.CheckIsLoaded(o.model); m != nil {
log.Debug().Msgf("Model '%s' already loaded", o.model)
ml.mu.Unlock()
return m, nil

View file

@ -103,7 +103,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
defer ml.mu.Unlock()
// Check if we already have a loaded model
if model := ml.checkIsLoaded(modelName); model != nil {
if model := ml.CheckIsLoaded(modelName); model != nil {
return model, nil
}
@ -128,7 +128,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
return model, nil
}
func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client {
func (ml *ModelLoader) CheckIsLoaded(s string) *grpc.Client {
if m, ok := ml.models[s]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", s)